19#include "common/common.hpp"
20#include "group/group.hpp"
21#include "subgroup/subgroup.hpp"
72template <param_optimizer_tag tag_,
typename dict_t_>
76 template <
typename T,
typename U>
78 static constexpr bool value = []()
constexpr {
80 valid &= std::is_same<
typename T::template find_elem_t<
82 typename U::template find_elem_t<
84 valid &= T::template find_elem_v<tune_key::
85 memory_layout_a> == U::template find_elem_v<tune_key::memory_layout_a>;
86 valid &= T::template find_elem_v<tune_key::
87 memory_alignment_a> == U::template find_elem_v<tune_key::memory_alignment_a>;
88 valid &= std::is_same<
typename T::template find_elem_t<
90 typename U::template find_elem_t<
92 valid &= T::template find_elem_v<tune_key::
93 memory_layout_b> == U::template find_elem_v<tune_key::memory_layout_b>;
94 valid &= T::template find_elem_v<tune_key::
95 memory_alignment_b> == U::template find_elem_v<tune_key::memory_alignment_b>;
96 valid &= std::is_same<
typename T::template find_elem_t<
98 typename U::template find_elem_t<
100 valid &= T::template find_elem_v<tune_key::
101 memory_layout_c> == U::template find_elem_v<tune_key::memory_layout_c>;
102 valid &= T::template find_elem_v<tune_key::
103 memory_alignment_c> == U::template find_elem_v<tune_key::memory_alignment_c>;
104 valid &= T::template find_elem_v<tune_key::
105 gpu_arch> == U::template find_elem_v<tune_key::gpu_arch>;
120template <param_adaptor_tag tag_,
typename dict_t_>
123template <
typename dict_t_>
125 using dtype_acc =
typename dict_t_::template find_elem_t<
129 static constexpr uint32_t
wg_tile_n = wg_tile_shape::template dim<0>();
130 static constexpr uint32_t
wg_tile_m = wg_tile_shape::template dim<1>();
132 = dict_t_::template find_elem_v<tune_key::wg_tile_k>;
135 static constexpr uint32_t
sg_tile_n = sg_tile_shape::template dim<0>();
136 static constexpr uint32_t
sg_tile_m = sg_tile_shape::template dim<1>();
138 = dict_t_::template find_elem_v<tune_key::prefetch_distance>;
140 = dict_t_::template find_elem_v<tune_key::periodic_sync_interval>;
142 = dict_t_::template find_elem_v<tune_key::mma_engine>;
144 = dict_t_::template find_elem_v<tune_key::gpu_arch>;
Definition arch_config.hpp:24
param_adaptor_tag
Definition common.hpp:114
mma_engine
Definition common.hpp:225
tune_key
Definition common.hpp:27
gpu_arch
Definition common.hpp:73
tune_key_value
Definition common.hpp:58
@ pre_processing_mata_neg_filter
@ dispatch_policy_default
@ dispatch_policy_stream_k
@ param_optimizer_decision_tree
@ dispatch_policy_kslicing
param_optimizer_tag
Definition common.hpp:70
Workgroup level tile shape description.
Definition tile_shape.hpp:34
Definition common.hpp:124
typename dict_t_::template find_elem_t< tune_key::sg_tile_shape >::type sg_tile_shape
Definition common.hpp:134
static constexpr uint32_t periodic_sync_interval
Definition common.hpp:140
static constexpr uint32_t sg_tile_n
Definition common.hpp:135
typename dict_t_::template find_elem_t< tune_key::data_type_acc >::type dtype_acc
Definition common.hpp:126
static constexpr auto mma_engine_tag
Definition common.hpp:142
static constexpr auto gpu_arch_tag
Definition common.hpp:144
static constexpr uint32_t sg_tile_m
Definition common.hpp:136
static constexpr uint32_t wg_tile_n
Definition common.hpp:129
typename dict_t_::template find_elem_t< tune_key::wg_tile_shape >::type wg_tile_shape
Definition common.hpp:128
static constexpr uint32_t prefetch_distance
Definition common.hpp:138
static constexpr uint32_t wg_tile_m
Definition common.hpp:130
static constexpr uint32_t wg_tile_k
Definition common.hpp:132
Definition common.hpp:121
static constexpr bool value
Definition common.hpp:78