XeTLA v0.3.6
Intel® Xe Templates for Linear Algebra - API Definition Document
 
Loading...
Searching...
No Matches
default_gemm.hpp
Go to the documentation of this file.
1/*******************************************************************************
2* Copyright (c) 2022-2023 Intel Corporation
3*
4* Licensed under the Apache License, Version 2.0 (the "License");
5* you may not use this file except in compliance with the License.
6* You may obtain a copy of the License at
7*
8* http://www.apache.org/licenses/LICENSE-2.0
9*
10* Unless required by applicable law or agreed to in writing, software
11* distributed under the License is distributed on an "AS IS" BASIS,
12* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13* See the License for the specific language governing permissions and
14* limitations under the License.
15*******************************************************************************/
16
19
20#pragma once
21
23#include "kernel/gemm/common.hpp"
24#include "kernel/gemm/dispatch_policy.hpp"
26
27namespace gpu::xetla {
28namespace kernel {
29template <typename dtype_a, mem_layout mem_layout_a, uint32_t alignment_a,
30 typename dtype_b, mem_layout mem_layout_b, uint32_t alignment_b,
31 typename dtype_c, mem_layout mem_layout_c, uint32_t alignment_c,
32 typename dtype_acc, gpu_arch gpu_arch_tag = gpu_arch::Xe,
33 typename tune_option = dict_t<>>
35 : param_adaptor<param_adaptor_tag::kernel,
36 typename param_optimizer<param_optimizer_tag::kernel,
37 typename default_param_t::template update_dict_t<
38 typename tune_option::template update_t<
39 elem_t_t<tune_key::data_type_a, dtype_a>,
40 elem_v_t<tune_key::memory_layout_a,
41 mem_layout_a>,
42 elem_v_t<tune_key::memory_alignment_a,
43 alignment_a>,
44 elem_t_t<tune_key::data_type_b, dtype_b>,
45 elem_v_t<tune_key::memory_layout_b,
46 mem_layout_b>,
47 elem_v_t<tune_key::memory_alignment_b,
48 alignment_b>,
49 elem_t_t<tune_key::data_type_c, dtype_c>,
50 elem_v_t<tune_key::memory_layout_c,
51 mem_layout_c>,
52 elem_v_t<tune_key::memory_alignment_c,
53 alignment_c>,
54 elem_t_t<tune_key::data_type_acc,
55 dtype_acc>,
56 elem_v_t<tune_key::gpu_arch,
57 gpu_arch_tag>>>>::type> {};
58
59template <typename dtype_a, mem_layout mem_layout_a, uint32_t alignment_a,
60 typename dtype_b, mem_layout mem_layout_b, uint32_t alignment_b,
61 typename dtype_c, mem_layout mem_layout_c, uint32_t alignment_c,
62 typename dtype_acc, gpu_arch gpu_arch_tag = gpu_arch::Xe,
63 typename tune_option = dict_t<>>
65 : default_gemm_config_t<dtype_a, mem_layout_a, alignment_a, dtype_b,
66 mem_layout_b, alignment_b, dtype_c, mem_layout_c, alignment_c,
67 dtype_acc, gpu_arch_tag, tune_option>::type {};
68} // namespace kernel
69
70template <typename dict_t_>
72 static constexpr bool use_rule
73 = (dict_t_::impl::template find_elem_index<tune_key::
74 param_optimizer_type> != dict_t_::impl::key_not_found)
75 && (dict_t_::template find_elem_v<tune_key::
77 using type = typename std::conditional<use_rule,
83};
84
85template <typename dict_t_>
87 : param_adaptor_base<dict_t_> {
88 using param = typename dict_t_::template update_t<
93
95 param>::type;
96 using epilogue_t =
98 param>::type;
99
100 using group_swizzle = typename param::template find_elem_t<
102
103 static constexpr auto dispatch_policy_tag
104 = param::template find_elem_v<tune_key::dispatch_policy>;
105 static constexpr int num_global_splitk
106 = param::template find_elem_v<tune_key::global_kslicing_ratio>;
107 static constexpr int num_local_splitk
108 = param::template find_elem_v<tune_key::local_kslicing_ratio>;
109 using dispatch_policy = typename dict_t<
114 num_global_splitk, num_local_splitk>>,
117
118 >::template find_elem_t<dispatch_policy_tag>::type;
119
121};
122
123namespace group {
124template <typename dtype_a, mem_layout mem_layout_a, uint32_t alignment_a,
125 mem_space mem_space_a, typename dtype_b, mem_layout mem_layout_b,
126 uint32_t alignment_b, mem_space mem_space_b, typename dtype_acc,
127 typename wg_shape, uint32_t wg_tile_k,
128 gpu_arch gpu_arch_tag = gpu_arch::Xe, typename tune_option = dict_t<>>
130 : param_adaptor<param_adaptor_tag::work_group_gemm,
131 typename param_optimizer<param_optimizer_tag::work_group,
132 typename default_param_t::template update_dict_t<
133 typename tune_option::template update_t<
134 elem_t_t<tune_key::data_type_a, dtype_a>,
135 elem_v_t<tune_key::memory_layout_a,
136 mem_layout_a>,
137 elem_v_t<tune_key::memory_alignment_a,
138 alignment_a>,
139 elem_v_t<tune_key::memory_space_a,
140 mem_space_a>,
141 elem_t_t<tune_key::data_type_b, dtype_b>,
142 elem_v_t<tune_key::memory_layout_b,
143 mem_layout_b>,
144 elem_v_t<tune_key::memory_alignment_b,
145 alignment_b>,
146 elem_v_t<tune_key::memory_space_b,
147 mem_space_b>,
148 elem_t_t<tune_key::data_type_acc,
149 dtype_acc>,
150 elem_t_t<tune_key::wg_tile_shape,
151 wg_shape>,
152 elem_v_t<tune_key::wg_tile_k, wg_tile_k>,
153 elem_v_t<tune_key::gpu_arch,
154 gpu_arch_tag>>>>::type> {};
155
156template <typename dtype_a, mem_layout mem_layout_a, uint32_t alignment_a,
157 mem_space mem_space_a, typename dtype_b, mem_layout mem_layout_b,
158 uint32_t alignment_b, mem_space mem_space_b, typename dtype_acc,
159 typename wg_shape, uint32_t wg_tile_k,
160 gpu_arch gpu_arch_tag = gpu_arch::Xe, typename tune_option = dict_t<>>
162 : default_gemm_selector_config_t<dtype_a, mem_layout_a, alignment_a,
163 mem_space_a, dtype_b, mem_layout_b, alignment_b, mem_space_b,
164 dtype_acc, wg_shape, wg_tile_k, gpu_arch_tag, tune_option>::type {
165};
166
167template <typename dtype_c, mem_layout mem_layout_c, uint32_t alignment_c,
168 mem_space mem_space_c, typename wg_shape, uint32_t wg_tile_k,
169 gpu_arch gpu_arch_tag = gpu_arch::Xe, typename tune_option = dict_t<>>
171 : param_adaptor<param_adaptor_tag::work_group_epilogue,
172 typename param_optimizer<param_optimizer_tag::work_group,
173 typename default_param_t::template update_dict_t<
174 typename tune_option::template update_t<
175 elem_t_t<tune_key::data_type_c, dtype_c>,
176 elem_v_t<tune_key::memory_layout_c,
177 mem_layout_c>,
178 elem_v_t<tune_key::memory_alignment_c,
179 alignment_c>,
180 elem_v_t<tune_key::memory_space_c,
181 mem_space_c>,
182 elem_t_t<tune_key::wg_tile_shape,
183 wg_shape>,
184 elem_v_t<tune_key::wg_tile_k, wg_tile_k>,
185 elem_v_t<tune_key::gpu_arch,
186 gpu_arch_tag>>>>::type> {};
187
188template <typename dtype_c, mem_layout mem_layout_c, uint32_t alignment_c,
189 mem_space mem_space_c, typename wg_shape, uint32_t wg_tile_k,
190 gpu_arch gpu_arch_tag = gpu_arch::Xe, typename tune_option = dict_t<>>
192 : default_epilogue_selector_config_t<dtype_c, mem_layout_c, alignment_c,
193 mem_space_c, wg_shape, wg_tile_k, gpu_arch_tag,
194 tune_option>::type {};
195} // namespace group
196
197template <typename dict_t_>
199 static constexpr bool use_rule
200 = (dict_t_::impl::template find_elem_index<tune_key::
201 param_optimizer_type> != dict_t_::impl::key_not_found)
202 && (dict_t_::template find_elem_v<tune_key::
204 using type = typename std::conditional<use_rule,
207 group::param_dict1_wg_t>>::type::type;
208};
209
210template <typename dict_t_>
212 : param_adaptor_base<dict_t_> {
213 using param = dict_t_;
215
216 using dtype_a =
217 typename param::template find_elem_t<tune_key::data_type_a>::type;
218 using dtype_b =
219 typename param::template find_elem_t<tune_key::data_type_b>::type;
220 static constexpr auto mem_layout_a
221 = param::template find_elem_v<tune_key::memory_layout_a>;
222 static constexpr auto mem_layout_b
223 = param::template find_elem_v<tune_key::memory_layout_b>;
224 static constexpr auto mem_space_a
225 = param::template find_elem_v<tune_key::memory_space_a>;
226 static constexpr auto mem_space_b
227 = param::template find_elem_v<tune_key::memory_space_b>;
228 static constexpr auto mem_alignment_a
229 = param::template find_elem_v<tune_key::memory_alignment_a>;
230 static constexpr auto mem_alignment_b
231 = param::template find_elem_v<tune_key::memory_alignment_b>;
232
234 typename base_t::dtype_acc>;
235
237 base_t::prefetch_distance, base_t::periodic_sync_interval>;
238
239 // specify the computation, performance tuning and computation core
240 using compute_policy = typename dict_t<
242 typename std::conditional<
244 dtype_a, dtype_b, mem_alignment_a,
245 mem_alignment_b,
246 base_t::gpu_arch_tag>::value),
248 perf_tuning_knob, base_t::gpu_arch_tag>,
251 base_t::gpu_arch_tag>>::type>,
253 typename std::conditional<
255 dtype_a, dtype_b, mem_alignment_a,
256 mem_alignment_b,
257 base_t::gpu_arch_tag>::value),
259 perf_tuning_knob, base_t::gpu_arch_tag>,
260 void>::type>>::
261 template find_elem_t<base_t::mma_engine_tag>::type;
262
267
268 static constexpr auto pre_processing_tag
269 = param::template find_elem_v<tune_key::pre_processing>;
270 using pre_processing = typename std::conditional<
271 (pre_processing_tag
274 base_t::gpu_arch_tag>,
276 base_t::gpu_arch_tag>>::type;
277
280
281 using type = gemm_t;
282};
283
284template <typename dict_t_>
286 using param = dict_t_;
288
289 using dtype_c =
290 typename param::template find_elem_t<tune_key::data_type_c>::type;
291 static constexpr auto mem_layout_c
292 = param::template find_elem_v<tune_key::memory_layout_c>;
293 static constexpr auto mem_alignment_c
294 = param::template find_elem_v<tune_key::memory_alignment_c>;
295 static constexpr auto mem_space_c
296 = param::template find_elem_v<tune_key::memory_space_c>;
297
298 using epilogue_policy = typename param::template find_elem_t<
300
302 typename base_t::tile_shape,
304
306};
307} // namespace gpu::xetla
The epilogue functor.
Definition api.hpp:35
Gemm functor.
Definition api.hpp:52
GEMM_UNIVERSAL functor.
Definition api.hpp:39
default_param_t::template update_t< elem_t_t< tune_key::data_type_acc, float >, elem_t_t< tune_key::wg_tile_shape, shape< 256, 256 > >, elem_v_t< tune_key::wg_tile_k, 32UL, uint32_t >, elem_t_t< tune_key::sg_tile_shape, shape< 64, 32 > >, elem_v_t< tune_key::prefetch_distance, 3UL, uint32_t >, elem_v_t< tune_key::periodic_sync_interval, 8UL, uint32_t >, elem_t_t< tune_key::epilogue_policy, group::epilogue_policy_default< gpu_arch::Xe > > > param_dict1_wg_t
Definition gemm_preset.hpp:115
default_param_t::template update_t< elem_v_t< tune_key::global_kslicing_ratio, 1UL, uint32_t >, elem_v_t< tune_key::local_kslicing_ratio, 2UL, uint32_t >, elem_t_t< tune_key::wg_tile_shape, shape< 128, 64 > >, elem_v_t< tune_key::wg_tile_k, 32UL, uint32_t >, elem_t_t< tune_key::sg_tile_shape, shape< 32, 16 > >, elem_v_t< tune_key::dispatch_policy, tune_key_value::dispatch_policy_kslicing > > param_kslicing_g1l2_t
Definition gemm_preset.hpp:102
default_param_t::template update_t< elem_v_t< tune_key::global_kslicing_ratio, 1UL, uint32_t >, elem_v_t< tune_key::local_kslicing_ratio, 1UL, uint32_t >, elem_t_t< tune_key::wg_tile_shape, shape< 256, 256 > >, elem_v_t< tune_key::wg_tile_k, 32UL, uint32_t >, elem_t_t< tune_key::sg_tile_shape, shape< 64, 32 > >, elem_v_t< tune_key::dispatch_policy, tune_key_value::dispatch_policy_kslicing > > param_kslicing_g1l1_t
Definition gemm_preset.hpp:84
default_param_t::template update_t< elem_v_t< tune_key::global_kslicing_ratio, 2UL, uint32_t >, elem_v_t< tune_key::local_kslicing_ratio, 1UL, uint32_t >, elem_t_t< tune_key::wg_tile_shape, shape< 256, 256 > >, elem_v_t< tune_key::wg_tile_k, 32UL, uint32_t >, elem_t_t< tune_key::sg_tile_shape, shape< 64, 32 > >, elem_v_t< tune_key::dispatch_policy, tune_key_value::dispatch_policy_kslicing > > param_kslicing_g2l1_t
Definition gemm_preset.hpp:93
Definition arch_config.hpp:24
param_adaptor_tag
Definition common.hpp:114
mem_space
Definition common.hpp:77
tune_key
Definition common.hpp:27
gpu_arch
Definition common.hpp:73
tune_key_value
Definition common.hpp:58
param_optimizer_tag
Definition common.hpp:70
mem_layout
Definition common.hpp:76
Definition decision_tree_policy.hpp:299
Definition dict.hpp:103
Definition dummy_policy.hpp:23
Definition dict.hpp:97
Definition dict.hpp:100
Compute attribute for gemm.
Definition common.hpp:32
Compute policy for fpu engine.
Definition compute_policy.hpp:105
Compute policy for xmx engine.
Definition compute_policy.hpp:35
Compute policy for unaligned shape and xmx engine.
Definition compute_policy.hpp:70
Definition default_gemm.hpp:194
Definition default_gemm.hpp:164
Fine-tune knobs for gemm.
Definition common.hpp:43
Gemm default pre_processing functor.
Definition api.hpp:33
Gemm pre_processing functor that applies the relu op to matA.
Definition api.hpp:39
Workgroup level tile shape description.
Definition tile_shape.hpp:34
Definition default_gemm.hpp:57
Definition default_gemm.hpp:67
Default GEMM_UNIVERSAL implementation.
Definition dispatch_policy.hpp:116
Kslicing GEMM_UNIVERSAL implementation.
Definition dispatch_policy.hpp:129
StreamK GEMM implementation.
Definition dispatch_policy.hpp:142
Definition memory_descriptor.hpp:139
typename dict_t_::template update_t< elem_v_t< tune_key::memory_space_a, mem_space::global >, elem_v_t< tune_key::memory_space_b, mem_space::global >, elem_v_t< tune_key::memory_space_c, mem_space::global > > param
Definition default_gemm.hpp:91
typename dict_t< elem_t_t< tune_key_value::dispatch_policy_default, kernel::dispatch_policy_default< group_swizzle > >, elem_t_t< tune_key_value::dispatch_policy_kslicing, kernel::dispatch_policy_kslicing< group_swizzle, num_global_splitk, num_local_splitk > >, elem_t_t< tune_key_value::dispatch_policy_stream_k, kernel::dispatch_policy_stream_k< base_t::gpu_arch_tag > > >::template find_elem_t< dispatch_policy_tag >::type dispatch_policy
Definition default_gemm.hpp:118
typename param_adaptor< param_adaptor_tag::work_group_epilogue, param >::type epilogue_t
Definition default_gemm.hpp:98
typename param::template find_elem_t< tune_key::group_swizzle_policy >::type group_swizzle
Definition default_gemm.hpp:101
typename param_adaptor< param_adaptor_tag::work_group_gemm, param >::type gemm_t
Definition default_gemm.hpp:95
typename param::template find_elem_t< tune_key::epilogue_policy >::type epilogue_policy
Definition default_gemm.hpp:299
typename param::template find_elem_t< tune_key::data_type_c >::type dtype_c
Definition default_gemm.hpp:290
typename std::conditional<(pre_processing_tag==tune_key_value::pre_processing_mata_neg_filter), group::pre_processing_matA_neg_filter_t< typename base_t::tile_shape, base_t::gpu_arch_tag >, group::pre_processing_default_t< typename base_t::tile_shape, base_t::gpu_arch_tag > >::type pre_processing
Definition default_gemm.hpp:276
typename param::template find_elem_t< tune_key::data_type_a >::type dtype_a
Definition default_gemm.hpp:217
typename param::template find_elem_t< tune_key::data_type_b >::type dtype_b
Definition default_gemm.hpp:219
typename dict_t< elem_t_t< mma_engine::xmx, typename std::conditional<(group::detail::check_2d_block_pitch_alignment< dtype_a, dtype_b, mem_alignment_a, mem_alignment_b, base_t::gpu_arch_tag >::value), group::compute_policy_default_xmx< compute_attr, perf_tuning_knob, base_t::gpu_arch_tag >, group::compute_policy_unaligned_xmx< compute_attr, perf_tuning_knob, base_t::gpu_arch_tag > >::type >, elem_t_t< mma_engine::fpu, typename std::conditional<(group::detail::check_2d_block_pitch_alignment< dtype_a, dtype_b, mem_alignment_a, mem_alignment_b, base_t::gpu_arch_tag >::value), group::compute_policy_default_fpu< compute_attr, perf_tuning_knob, base_t::gpu_arch_tag >, void >::type > >::template find_elem_t< base_t::mma_engine_tag >::type compute_policy
Definition default_gemm.hpp:261
Definition common.hpp:124
typename dict_t_::template find_elem_t< tune_key::data_type_acc >::type dtype_acc
Definition common.hpp:126
Definition common.hpp:121
typename std::conditional< use_rule, decision_tree_optimizer< param_optimizer_tag::kernel, dict_t_ >, dummy_optimizer< param_optimizer_tag::kernel, dict_t_, kernel::param_kslicing_g1l1_t, kernel::param_kslicing_g2l1_t, kernel::param_kslicing_g1l2_t > >::type::type type
Definition default_gemm.hpp:82
typename std::conditional< use_rule, decision_tree_optimizer< param_optimizer_tag::work_group, dict_t_ >, dummy_optimizer< param_optimizer_tag::work_group, dict_t_, group::param_dict1_wg_t > >::type::type type
Definition default_gemm.hpp:207
Definition common.hpp:73