23#include "kernel/gemm/common.hpp"
24#include "kernel/gemm/dispatch_policy.hpp"
29template <
typename dtype_a,
mem_layout mem_layout_a, uint32_t alignment_a,
30 typename dtype_b,
mem_layout mem_layout_b, uint32_t alignment_b,
31 typename dtype_c,
mem_layout mem_layout_c, uint32_t alignment_c,
33 typename tune_option = dict_t<>>
36 typename param_optimizer<param_optimizer_tag::kernel,
37 typename default_param_t::template update_dict_t<
38 typename tune_option::template update_t<
39 elem_t_t<tune_key::data_type_a, dtype_a>,
40 elem_v_t<tune_key::memory_layout_a,
42 elem_v_t<tune_key::memory_alignment_a,
44 elem_t_t<tune_key::data_type_b, dtype_b>,
45 elem_v_t<tune_key::memory_layout_b,
47 elem_v_t<tune_key::memory_alignment_b,
49 elem_t_t<tune_key::data_type_c, dtype_c>,
50 elem_v_t<tune_key::memory_layout_c,
52 elem_v_t<tune_key::memory_alignment_c,
54 elem_t_t<tune_key::data_type_acc,
56 elem_v_t<tune_key::gpu_arch,
57 gpu_arch_tag>>>>::type> {};
59template <
typename dtype_a,
mem_layout mem_layout_a, uint32_t alignment_a,
60 typename dtype_b,
mem_layout mem_layout_b, uint32_t alignment_b,
61 typename dtype_c,
mem_layout mem_layout_c, uint32_t alignment_c,
66 mem_layout_b, alignment_b, dtype_c, mem_layout_c, alignment_c,
67 dtype_acc, gpu_arch_tag, tune_option>::type {};
70template <
typename dict_t_>
72 static constexpr bool use_rule
73 = (dict_t_::impl::template find_elem_index<tune_key::
74 param_optimizer_type> != dict_t_::impl::key_not_found)
75 && (dict_t_::template find_elem_v<
tune_key::
77 using type =
typename std::conditional<use_rule,
85template <
typename dict_t_>
88 using param =
typename dict_t_::template update_t<
103 static constexpr auto dispatch_policy_tag
104 = param::template find_elem_v<tune_key::dispatch_policy>;
105 static constexpr int num_global_splitk
106 = param::template find_elem_v<tune_key::global_kslicing_ratio>;
107 static constexpr int num_local_splitk
108 = param::template find_elem_v<tune_key::local_kslicing_ratio>;
114 num_global_splitk, num_local_splitk>>,
118 >::template find_elem_t<dispatch_policy_tag>::type;
124template <
typename dtype_a,
mem_layout mem_layout_a, uint32_t alignment_a,
126 uint32_t alignment_b,
mem_space mem_space_b,
typename dtype_acc,
131 typename param_optimizer<param_optimizer_tag::work_group,
132 typename default_param_t::template update_dict_t<
133 typename tune_option::template update_t<
134 elem_t_t<tune_key::data_type_a, dtype_a>,
135 elem_v_t<tune_key::memory_layout_a,
137 elem_v_t<tune_key::memory_alignment_a,
139 elem_v_t<tune_key::memory_space_a,
141 elem_t_t<tune_key::data_type_b, dtype_b>,
142 elem_v_t<tune_key::memory_layout_b,
144 elem_v_t<tune_key::memory_alignment_b,
146 elem_v_t<tune_key::memory_space_b,
148 elem_t_t<tune_key::data_type_acc,
150 elem_t_t<tune_key::wg_tile_shape,
152 elem_v_t<tune_key::wg_tile_k, wg_tile_k>,
153 elem_v_t<tune_key::gpu_arch,
154 gpu_arch_tag>>>>::type> {};
156template <
typename dtype_a,
mem_layout mem_layout_a, uint32_t alignment_a,
158 uint32_t alignment_b,
mem_space mem_space_b,
typename dtype_acc,
163 mem_space_a, dtype_b, mem_layout_b, alignment_b, mem_space_b,
164 dtype_acc, wg_shape, wg_tile_k, gpu_arch_tag, tune_option>::type {
167template <
typename dtype_c,
mem_layout mem_layout_c, uint32_t alignment_c,
172 typename param_optimizer<param_optimizer_tag::work_group,
173 typename default_param_t::template update_dict_t<
174 typename tune_option::template update_t<
175 elem_t_t<tune_key::data_type_c, dtype_c>,
176 elem_v_t<tune_key::memory_layout_c,
178 elem_v_t<tune_key::memory_alignment_c,
180 elem_v_t<tune_key::memory_space_c,
182 elem_t_t<tune_key::wg_tile_shape,
184 elem_v_t<tune_key::wg_tile_k, wg_tile_k>,
185 elem_v_t<tune_key::gpu_arch,
186 gpu_arch_tag>>>>::type> {};
188template <
typename dtype_c,
mem_layout mem_layout_c, uint32_t alignment_c,
193 mem_space_c, wg_shape, wg_tile_k, gpu_arch_tag,
194 tune_option>::type {};
197template <
typename dict_t_>
199 static constexpr bool use_rule
200 = (dict_t_::impl::template find_elem_index<tune_key::
201 param_optimizer_type> != dict_t_::impl::key_not_found)
202 && (dict_t_::template find_elem_v<
tune_key::
204 using type =
typename std::conditional<use_rule,
210template <
typename dict_t_>
217 typename param::template find_elem_t<tune_key::data_type_a>::type;
219 typename param::template find_elem_t<tune_key::data_type_b>::type;
220 static constexpr auto mem_layout_a
221 = param::template find_elem_v<tune_key::memory_layout_a>;
222 static constexpr auto mem_layout_b
223 = param::template find_elem_v<tune_key::memory_layout_b>;
224 static constexpr auto mem_space_a
225 = param::template find_elem_v<tune_key::memory_space_a>;
226 static constexpr auto mem_space_b
227 = param::template find_elem_v<tune_key::memory_space_b>;
228 static constexpr auto mem_alignment_a
229 = param::template find_elem_v<tune_key::memory_alignment_a>;
230 static constexpr auto mem_alignment_b
231 = param::template find_elem_v<tune_key::memory_alignment_b>;
237 base_t::prefetch_distance, base_t::periodic_sync_interval>;
242 typename std::conditional<
246 base_t::gpu_arch_tag>::value),
251 base_t::gpu_arch_tag>>::type>,
253 typename std::conditional<
257 base_t::gpu_arch_tag>::value),
261 template find_elem_t<base_t::mma_engine_tag>::type;
268 static constexpr auto pre_processing_tag
269 = param::template find_elem_v<tune_key::pre_processing>;
274 base_t::gpu_arch_tag>,
276 base_t::gpu_arch_tag>>::type;
284template <
typename dict_t_>
290 typename param::template find_elem_t<tune_key::data_type_c>::type;
291 static constexpr auto mem_layout_c
292 = param::template find_elem_v<tune_key::memory_layout_c>;
293 static constexpr auto mem_alignment_c
294 = param::template find_elem_v<tune_key::memory_alignment_c>;
295 static constexpr auto mem_space_c
296 = param::template find_elem_v<tune_key::memory_space_c>;
Definition selector_xe.hpp:30
The epilogue functor.
Definition api.hpp:35
Gemm functor.
Definition api.hpp:52
GEMM_UNIVERSAL functor.
Definition api.hpp:39
default_param_t::template update_t< elem_t_t< tune_key::data_type_acc, float >, elem_t_t< tune_key::wg_tile_shape, shape< 256, 256 > >, elem_v_t< tune_key::wg_tile_k, 32UL, uint32_t >, elem_t_t< tune_key::sg_tile_shape, shape< 64, 32 > >, elem_v_t< tune_key::prefetch_distance, 3UL, uint32_t >, elem_v_t< tune_key::periodic_sync_interval, 8UL, uint32_t >, elem_t_t< tune_key::epilogue_policy, group::epilogue_policy_default< gpu_arch::Xe > > > param_dict1_wg_t
Definition gemm_preset.hpp:115
default_param_t::template update_t< elem_v_t< tune_key::global_kslicing_ratio, 1UL, uint32_t >, elem_v_t< tune_key::local_kslicing_ratio, 2UL, uint32_t >, elem_t_t< tune_key::wg_tile_shape, shape< 128, 64 > >, elem_v_t< tune_key::wg_tile_k, 32UL, uint32_t >, elem_t_t< tune_key::sg_tile_shape, shape< 32, 16 > >, elem_v_t< tune_key::dispatch_policy, tune_key_value::dispatch_policy_kslicing > > param_kslicing_g1l2_t
Definition gemm_preset.hpp:102
default_param_t::template update_t< elem_v_t< tune_key::global_kslicing_ratio, 1UL, uint32_t >, elem_v_t< tune_key::local_kslicing_ratio, 1UL, uint32_t >, elem_t_t< tune_key::wg_tile_shape, shape< 256, 256 > >, elem_v_t< tune_key::wg_tile_k, 32UL, uint32_t >, elem_t_t< tune_key::sg_tile_shape, shape< 64, 32 > >, elem_v_t< tune_key::dispatch_policy, tune_key_value::dispatch_policy_kslicing > > param_kslicing_g1l1_t
Definition gemm_preset.hpp:84
default_param_t::template update_t< elem_v_t< tune_key::global_kslicing_ratio, 2UL, uint32_t >, elem_v_t< tune_key::local_kslicing_ratio, 1UL, uint32_t >, elem_t_t< tune_key::wg_tile_shape, shape< 256, 256 > >, elem_v_t< tune_key::wg_tile_k, 32UL, uint32_t >, elem_t_t< tune_key::sg_tile_shape, shape< 64, 32 > >, elem_v_t< tune_key::dispatch_policy, tune_key_value::dispatch_policy_kslicing > > param_kslicing_g2l1_t
Definition gemm_preset.hpp:93
Definition arch_config.hpp:24
param_adaptor_tag
Definition common.hpp:114
mem_space
Definition common.hpp:77
tune_key
Definition common.hpp:27
gpu_arch
Definition common.hpp:73
tune_key_value
Definition common.hpp:58
@ pre_processing_mata_neg_filter
@ dispatch_policy_default
@ dispatch_policy_stream_k
@ param_optimizer_decision_tree
@ dispatch_policy_kslicing
param_optimizer_tag
Definition common.hpp:70
mem_layout
Definition common.hpp:76
Definition decision_tree_policy.hpp:299
Definition dummy_policy.hpp:23
Compute attribute for gemm.
Definition common.hpp:32
Compute policy for fpu engine.
Definition compute_policy.hpp:105
Compute policy for xmx engine.
Definition compute_policy.hpp:35
Compute policy for unaligned shape and xmx engine.
Definition compute_policy.hpp:70
Definition default_gemm.hpp:186
Definition default_gemm.hpp:194
Definition default_gemm.hpp:154
Definition default_gemm.hpp:164
Fine-tune knobs for gemm.
Definition common.hpp:43
Gemm default pre_processing functor.
Definition api.hpp:33
Gemm pre_processing functor that applies a relu op to matA.
Definition api.hpp:39
Workgroup level tile shape description.
Definition tile_shape.hpp:34
Definition default_gemm.hpp:57
Definition default_gemm.hpp:67
Default GEMM_UNIVERSAL implementation.
Definition dispatch_policy.hpp:116
Kslicing GEMM_UNIVERSAL implementation.
Definition dispatch_policy.hpp:129
StreamK GEMM implementation.
Definition dispatch_policy.hpp:142
Definition memory_descriptor.hpp:139
typename dict_t_::template update_t< elem_v_t< tune_key::memory_space_a, mem_space::global >, elem_v_t< tune_key::memory_space_b, mem_space::global >, elem_v_t< tune_key::memory_space_c, mem_space::global > > param
Definition default_gemm.hpp:91
typename dict_t< elem_t_t< tune_key_value::dispatch_policy_default, kernel::dispatch_policy_default< group_swizzle > >, elem_t_t< tune_key_value::dispatch_policy_kslicing, kernel::dispatch_policy_kslicing< group_swizzle, num_global_splitk, num_local_splitk > >, elem_t_t< tune_key_value::dispatch_policy_stream_k, kernel::dispatch_policy_stream_k< base_t::gpu_arch_tag > > >::template find_elem_t< dispatch_policy_tag >::type dispatch_policy
Definition default_gemm.hpp:118
typename param_adaptor< param_adaptor_tag::work_group_epilogue, param >::type epilogue_t
Definition default_gemm.hpp:98
typename param::template find_elem_t< tune_key::group_swizzle_policy >::type group_swizzle
Definition default_gemm.hpp:101
typename param_adaptor< param_adaptor_tag::work_group_gemm, param >::type gemm_t
Definition default_gemm.hpp:95
typename param::template find_elem_t< tune_key::epilogue_policy >::type epilogue_policy
Definition default_gemm.hpp:299
typename param::template find_elem_t< tune_key::data_type_c >::type dtype_c
Definition default_gemm.hpp:290
dict_t_ param
Definition default_gemm.hpp:286
typename std::conditional<(pre_processing_tag==tune_key_value::pre_processing_mata_neg_filter), group::pre_processing_matA_neg_filter_t< typename base_t::tile_shape, base_t::gpu_arch_tag >, group::pre_processing_default_t< typename base_t::tile_shape, base_t::gpu_arch_tag > >::type pre_processing
Definition default_gemm.hpp:276
typename param::template find_elem_t< tune_key::data_type_a >::type dtype_a
Definition default_gemm.hpp:217
dict_t_ param
Definition default_gemm.hpp:213
typename param::template find_elem_t< tune_key::data_type_b >::type dtype_b
Definition default_gemm.hpp:219
typename dict_t< elem_t_t< mma_engine::xmx, typename std::conditional<(group::detail::check_2d_block_pitch_alignment< dtype_a, dtype_b, mem_alignment_a, mem_alignment_b, base_t::gpu_arch_tag >::value), group::compute_policy_default_xmx< compute_attr, perf_tuning_knob, base_t::gpu_arch_tag >, group::compute_policy_unaligned_xmx< compute_attr, perf_tuning_knob, base_t::gpu_arch_tag > >::type >, elem_t_t< mma_engine::fpu, typename std::conditional<(group::detail::check_2d_block_pitch_alignment< dtype_a, dtype_b, mem_alignment_a, mem_alignment_b, base_t::gpu_arch_tag >::value), group::compute_policy_default_fpu< compute_attr, perf_tuning_knob, base_t::gpu_arch_tag >, void >::type > >::template find_elem_t< base_t::mma_engine_tag >::type compute_policy
Definition default_gemm.hpp:261
Definition common.hpp:124
typename dict_t_::template find_elem_t< tune_key::data_type_acc >::type dtype_acc
Definition common.hpp:126
Definition common.hpp:121
typename std::conditional< use_rule, decision_tree_optimizer< param_optimizer_tag::kernel, dict_t_ >, dummy_optimizer< param_optimizer_tag::kernel, dict_t_, kernel::param_kslicing_g1l1_t, kernel::param_kslicing_g2l1_t, kernel::param_kslicing_g1l2_t > >::type::type type
Definition default_gemm.hpp:82
typename std::conditional< use_rule, decision_tree_optimizer< param_optimizer_tag::work_group, dict_t_ >, dummy_optimizer< param_optimizer_tag::work_group, dict_t_, group::param_dict1_wg_t > >::type::type type
Definition default_gemm.hpp:207