XeTLA v0.3.6
Intel® Xe Templates for Linear Algebra - API Definition Document
 
Loading...
Searching...
No Matches
compute_policy.hpp
Go to the documentation of this file.
1/*******************************************************************************
2* Copyright (c) 2022-2023 Intel Corporation
3*
4* Licensed under the Apache License, Version 2.0 (the "License");
5* you may not use this file except in compliance with the License.
6* You may obtain a copy of the License at
7*
8* http://www.apache.org/licenses/LICENSE-2.0
9*
10* Unless required by applicable law or agreed to in writing, software
11* distributed under the License is distributed on an "AS IS" BASIS,
12* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13* See the License for the specific language governing permissions and
14* limitations under the License.
15*******************************************************************************/
16
19
20#pragma once
21
22#include "group/gemm/common.hpp"
23
24namespace gpu::xetla::group {
25
28
33template <typename compute_attr_, typename perf_tuning_knob_,
34 gpu_arch arch_tag_>
36
38template <typename compute_attr_, typename perf_tuning_knob_>
39struct compute_policy_default_xmx<compute_attr_, perf_tuning_knob_,
40 gpu_arch::Xe> {
41 using compute_attr = compute_attr_;
42 using perf_tuning_knob = perf_tuning_knob_;
43 static constexpr int k_stride = perf_tuning_knob::k_stride;
44 static constexpr int stages = perf_tuning_knob::stages;
45 static constexpr int sync_freq = perf_tuning_knob::sync_freq;
46 static constexpr gpu_arch arch_tag = gpu_arch::Xe;
47 using dtype_mma_acc = typename compute_attr::dtype_acc;
48 using dtype_mma_a = typename compute_attr::dtype_a;
49 using dtype_mma_b = typename compute_attr::dtype_b;
50
51 static constexpr uint32_t block_bytes_x_a = 32;
52 static constexpr uint32_t block_size_x_a
53 = block_bytes_x_a / sizeof(dtype_mma_a);
54 static constexpr uint32_t block_size_y_a = 16;
55
56 static constexpr uint32_t block_size_x_b = 16;
57 static constexpr uint32_t block_bytes_y_b = 32;
58 static constexpr uint32_t block_size_y_b
59 = block_bytes_y_b / sizeof(dtype_mma_b);
60 static_assert(block_size_x_a == block_size_y_b,
61 "mat_a x need to match with mat_b y");
62};
63
68template <typename compute_attr_, typename perf_tuning_knob_,
69 gpu_arch arch_tag_ = gpu_arch::Xe>
71
73template <typename compute_attr_, typename perf_tuning_knob_>
74struct compute_policy_unaligned_xmx<compute_attr_, perf_tuning_knob_,
75 gpu_arch::Xe> {
76 using compute_attr = compute_attr_;
77 using perf_tuning_knob = perf_tuning_knob_;
78 static constexpr int k_stride = perf_tuning_knob::k_stride;
79 static constexpr int stages = perf_tuning_knob::stages;
80 static constexpr int sync_freq = perf_tuning_knob::sync_freq;
81 static constexpr gpu_arch arch_tag = gpu_arch::Xe;
82 using dtype_mma_acc = typename compute_attr::dtype_acc;
83 using dtype_mma_a = typename compute_attr::dtype_a;
84 using dtype_mma_b = typename compute_attr::dtype_b;
85
86 static constexpr uint32_t block_bytes_x_a = 32;
87 static constexpr uint32_t block_size_x_a
88 = block_bytes_x_a / sizeof(dtype_mma_a);
89 static constexpr uint32_t block_size_y_a = 16;
90
91 static constexpr uint32_t block_size_x_b = 16;
92 static constexpr uint32_t block_bytes_y_b = 32;
93 static constexpr uint32_t block_size_y_b
94 = block_bytes_y_b / sizeof(dtype_mma_b);
95 static_assert(block_size_x_a == block_size_y_b,
96 "mat_a x need to match with mat_b y");
97};
98
103template <typename compute_attr_, typename perf_tuning_knob_,
104 gpu_arch arch_tag_>
106
108template <typename compute_attr_, typename perf_tuning_knob_>
109struct compute_policy_default_fpu<compute_attr_, perf_tuning_knob_,
110 gpu_arch::Xe> {
111 using compute_attr = compute_attr_;
112 using perf_tuning_knob = perf_tuning_knob_;
113 static constexpr int k_stride = perf_tuning_knob::k_stride;
114 static constexpr int stages = perf_tuning_knob::stages;
115 static constexpr int sync_freq = perf_tuning_knob::sync_freq;
116 static constexpr gpu_arch arch_tag = gpu_arch::Xe;
117 using dtype_mma_acc = typename compute_attr::dtype_acc;
118 using dtype_mma_a = typename compute_attr::dtype_a;
119 using dtype_mma_b = typename compute_attr::dtype_b;
120
121 static constexpr uint32_t block_bytes_x_a = 32;
122 static constexpr uint32_t block_size_x_a
123 = block_bytes_x_a / sizeof(dtype_mma_a);
124 static constexpr uint32_t block_size_y_a = 16;
125 static constexpr uint32_t block_bytes_x_b = 64;
126 static constexpr uint32_t block_size_x_b
127 = block_bytes_x_b / sizeof(dtype_mma_b);
128 static constexpr uint32_t block_size_y_b = block_size_x_a;
129};
130
132
133} // namespace gpu::xetla::group
Definition limitation.hpp:607
gpu_arch
Definition common.hpp:73
typename compute_attr::dtype_acc dtype_mma_acc
Definition compute_policy.hpp:117
typename compute_attr::dtype_b dtype_mma_b
Definition compute_policy.hpp:119
typename compute_attr::dtype_a dtype_mma_a
Definition compute_policy.hpp:118
Compute policy for fpu engine.
Definition compute_policy.hpp:105
typename compute_attr::dtype_b dtype_mma_b
Definition compute_policy.hpp:49
typename compute_attr::dtype_acc dtype_mma_acc
Definition compute_policy.hpp:47
typename compute_attr::dtype_a dtype_mma_a
Definition compute_policy.hpp:48
Compute policy for xmx engine.
Definition compute_policy.hpp:35
typename compute_attr::dtype_a dtype_mma_a
Definition compute_policy.hpp:83
typename compute_attr::dtype_acc dtype_mma_acc
Definition compute_policy.hpp:82
typename compute_attr::dtype_b dtype_mma_b
Definition compute_policy.hpp:84
Compute policy for unaligned shape and xmx engine.
Definition compute_policy.hpp:70