22namespace decision_tree_rule {
28 typename T::template find_elem_t<tune_key::data_type_acc>::type>>>;
31template <
typename dict_t_>
34 template <uint32_t wg_tile_shape_n_, uint32_t wg_tile_shape_m_,
35 uint32_t wg_tile_k_, uint32_t sg_tile_shape_n_,
36 uint32_t sg_tile_shape_m_>
53 if (z >= 0) {
return z; }
57 template <
typename T,
typename U>
60 sum +=
const_abs(T::wg_tile_shape_m - U::wg_tile_shape_m);
61 sum +=
const_abs(T::wg_tile_shape_n - U::wg_tile_shape_n);
63 sum +=
const_abs(T::sg_tile_shape_m - U::sg_tile_shape_m);
64 sum +=
const_abs(T::sg_tile_shape_n - U::sg_tile_shape_n);
76 template <
typename ref,
typename... elems>
82 template <
typename ref,
typename elem,
typename... elems>
85 static constexpr int cur_distance = distance_fcn<ref, elem>();
87 static constexpr bool use_next
88 = (
sizeof...(elems) > 0) && (cur_distance > nxt::distance);
92 = use_next ? nxt::distance : cur_distance;
95 template <
typename ref,
typename elem>
98 static constexpr int distance = distance_fcn<ref, elem>();
101 template <
typename T>
108 = wg_tile_shape::template dim<0>();
110 = wg_tile_shape::template dim<1>();
112 = T::template find_elem_v<tune_key::wg_tile_k>;
114 = sg_tile_shape::template dim<0>();
116 = sg_tile_shape::template dim<1>();
122 template <
typename T>
135 using type =
typename dict_t_::template update_dict_t<
136 typename update_config::to_dict>;
139template <
typename dict_t_>
142 template <uint32_t global_kslicing_ratio_,
143 uint32_t local_kslicing_ratio_, uint32_t wg_tile_shape_n_,
144 uint32_t wg_tile_shape_m_, uint32_t wg_tile_k_,
145 uint32_t sg_tile_shape_n_, uint32_t sg_tile_shape_m_>
148 = global_kslicing_ratio_;
150 = local_kslicing_ratio_;
173 template <
template <
typename>
typename G>
174 using apply =
typename G<this_t>::type;
177 template <
typename T>
180 = T::template find_elem_v<tune_key::global_kslicing_ratio>;
182 = T::template find_elem_v<tune_key::local_kslicing_ratio>;
188 = wg_tile_shape::template dim<0>();
190 = wg_tile_shape::template dim<1>();
192 = T::template find_elem_v<tune_key::wg_tile_k>;
194 = sg_tile_shape::template dim<0>();
196 = sg_tile_shape::template dim<1>();
203 template <
typename T>
209 template <
typename T>
212 = T::global_kslicing_ratio;
214 = T::local_kslicing_ratio;
250 template find_elem_t<local_kslicing_ratio>::type;
253 using type =
typename orig::template apply<local_kslicing_handler>;
258 using type =
typename std::conditional<
259 (dict_t_::template find_elem_v<tune_key::
261 typename dict_t_::template update_dict_t<
262 typename update_config::to_dict>,
267template <
typename dict_t_,
typename opt_dict_t_>
269 using type =
typename opt_dict_t_::template update_t<
271 typename dict_t_::template find_elem_t<
274 typename dict_t_::template find_elem_t<
277 typename dict_t_::template find_elem_t<
280 dict_t_::template find_elem_v<tune_key::memory_layout_a>>,
282 dict_t_::template find_elem_v<tune_key::memory_layout_b>>,
284 dict_t_::template find_elem_v<tune_key::memory_layout_c>>,
286 dict_t_::template find_elem_v<
289 dict_t_::template find_elem_v<
292 dict_t_::template find_elem_v<
295 dict_t_::template find_elem_v<tune_key::gpu_arch>>>;
301 using type =
typename dict_t_ ::template update_generator_t<
303 template update_generator_t<
305 template update_generator_t<
Definition arch_config.hpp:24
@ dispatch_policy_kslicing
param_optimizer_tag
Definition common.hpp:70
Definition decision_tree_policy.hpp:300
typename dict_t_ ::template update_generator_t< decision_tree_rule::data_type_handler >::template update_generator_t< decision_tree_rule::tile_shape_handler >::template update_generator_t< decision_tree_rule::kslicing_handler > type
Definition decision_tree_policy.hpp:306
fallback_optimizer< dict_t_, type > fallback_type
Definition decision_tree_policy.hpp:307
Definition decision_tree_policy.hpp:299
typename std::conditional< use_fallback, typename impl::fallback_type, impl >::type::type type
Definition decision_tree_policy.hpp:313
static constexpr bool use_fallback
Definition decision_tree_policy.hpp:310
Definition decision_tree_policy.hpp:25
typename T::template update_dict_t< dict_t< elem_t_t< tune_key::data_type_acc, typename T::template find_elem_t< tune_key::data_type_acc >::type > > > type
Definition decision_tree_policy.hpp:28
Definition decision_tree_policy.hpp:178
static constexpr uint32_t sg_tile_shape_n
Definition decision_tree_policy.hpp:194
static constexpr uint32_t local_kslicing_ratio
Definition decision_tree_policy.hpp:182
kslicing_config< global_kslicing_ratio, local_kslicing_ratio, wg_tile_shape_n, wg_tile_shape_m, wg_tile_k, sg_tile_shape_n, sg_tile_shape_m > type
Definition decision_tree_policy.hpp:200
static constexpr uint32_t wg_tile_shape_m
Definition decision_tree_policy.hpp:190
typename T::template find_elem_t< tune_key::sg_tile_shape >::type sg_tile_shape
Definition decision_tree_policy.hpp:186
static constexpr uint32_t sg_tile_shape_m
Definition decision_tree_policy.hpp:196
typename T::template find_elem_t< tune_key::wg_tile_shape >::type wg_tile_shape
Definition decision_tree_policy.hpp:184
static constexpr uint32_t wg_tile_shape_n
Definition decision_tree_policy.hpp:188
static constexpr uint32_t wg_tile_k
Definition decision_tree_policy.hpp:192
static constexpr uint32_t global_kslicing_ratio
Definition decision_tree_policy.hpp:180
Definition decision_tree_policy.hpp:146
static constexpr uint32_t sg_tile_shape_m
Definition decision_tree_policy.hpp:155
typename G< this_t >::type apply
Definition decision_tree_policy.hpp:174
static constexpr uint32_t wg_tile_k
Definition decision_tree_policy.hpp:153
static constexpr uint32_t local_kslicing_ratio
Definition decision_tree_policy.hpp:150
static constexpr uint32_t wg_tile_shape_m
Definition decision_tree_policy.hpp:152
static constexpr uint32_t sg_tile_shape_n
Definition decision_tree_policy.hpp:154
static constexpr uint32_t wg_tile_shape_n
Definition decision_tree_policy.hpp:151
static constexpr uint32_t global_kslicing_ratio
Definition decision_tree_policy.hpp:148
Definition decision_tree_policy.hpp:210
static constexpr uint32_t global_kslicing_ratio
Definition decision_tree_policy.hpp:212
static constexpr uint32_t sg_tile_shape_n
Definition decision_tree_policy.hpp:224
typename dict_t< elem_t_t< 1U, kslicing_config< global_kslicing_ratio, local_kslicing_ratio, wg_tile_shape_n, wg_tile_shape_m, wg_tile_k, sg_tile_shape_n, sg_tile_shape_m > >, elem_t_t< 2U, kslicing_config< global_kslicing_ratio, local_kslicing_ratio, 128, 64, 32, 32, 16 > >, elem_t_t< 4U, kslicing_config< global_kslicing_ratio, local_kslicing_ratio, 64, 64, 32, 32, 16 > >, elem_t_t< 8U, kslicing_config< global_kslicing_ratio, local_kslicing_ratio, 64, 32, 32, 32, 16 > >, elem_t_t< 16U, kslicing_config< global_kslicing_ratio, local_kslicing_ratio, 64, 16, 32, 32, 16 > > >::template find_elem_t< local_kslicing_ratio >::type type
Definition decision_tree_policy.hpp:250
static constexpr uint32_t sg_tile_shape_m
Definition decision_tree_policy.hpp:226
static constexpr uint32_t local_kslicing_ratio
Definition decision_tree_policy.hpp:214
static constexpr uint32_t wg_tile_shape_m
Definition decision_tree_policy.hpp:220
static constexpr uint32_t wg_tile_k
Definition decision_tree_policy.hpp:222
static constexpr uint32_t wg_tile_shape_n
Definition decision_tree_policy.hpp:217
Definition decision_tree_policy.hpp:206
typename orig::template apply< local_kslicing_handler > type
Definition decision_tree_policy.hpp:253
from_dict< dict_t_ > orig
Definition decision_tree_policy.hpp:207
Definition decision_tree_policy.hpp:141
typename from_dict_impl< T >::type from_dict
Definition decision_tree_policy.hpp:204
Definition decision_tree_policy.hpp:140
typename std::conditional<(dict_t_::template find_elem_v< tune_key::dispatch_policy >==tune_key_value::dispatch_policy_kslicing), typename dict_t_::template update_dict_t< typename update_config::to_dict >, dict_t_ >::type type
Definition decision_tree_policy.hpp:263
typename impl::update_config_impl::type update_config
Definition decision_tree_policy.hpp:257
elem type
Definition decision_tree_policy.hpp:97
typename std::conditional< use_next, typename nxt::type, cur_type >::type type
Definition decision_tree_policy.hpp:90
ref cur_type
Definition decision_tree_policy.hpp:84
Definition decision_tree_policy.hpp:77
static constexpr int distance
Definition decision_tree_policy.hpp:79
ref type
Definition decision_tree_policy.hpp:78
Definition decision_tree_policy.hpp:102
typename T::template find_elem_t< tune_key::wg_tile_shape >::type wg_tile_shape
Definition decision_tree_policy.hpp:104
static constexpr uint32_t wg_tile_shape_n
Definition decision_tree_policy.hpp:108
tile_shape_config< wg_tile_shape_n, wg_tile_shape_m, wg_tile_k, sg_tile_shape_n, sg_tile_shape_m > type
Definition decision_tree_policy.hpp:119
static constexpr uint32_t wg_tile_k
Definition decision_tree_policy.hpp:112
typename T::template find_elem_t< tune_key::sg_tile_shape >::type sg_tile_shape
Definition decision_tree_policy.hpp:106
static constexpr uint32_t wg_tile_shape_m
Definition decision_tree_policy.hpp:110
static constexpr uint32_t sg_tile_shape_n
Definition decision_tree_policy.hpp:114
static constexpr uint32_t sg_tile_shape_m
Definition decision_tree_policy.hpp:116
Definition decision_tree_policy.hpp:37
static constexpr uint32_t sg_tile_shape_m
Definition decision_tree_policy.hpp:42
static constexpr uint32_t wg_tile_shape_n
Definition decision_tree_policy.hpp:38
static constexpr uint32_t wg_tile_shape_m
Definition decision_tree_policy.hpp:39
static constexpr uint32_t sg_tile_shape_n
Definition decision_tree_policy.hpp:41
static constexpr uint32_t wg_tile_k
Definition decision_tree_policy.hpp:40
Definition decision_tree_policy.hpp:125
typename find_min_elem< orig, wg_256x256_k32_sg_32x64, wg_256x256_k32_sg_64x32, wg_128x512_k16_sg_32x64, wg_512x128_k16_sg_64x32, wg_32x256_k32_sg_16x16, wg_512x64_k32_sg_32x32, wg_64x64_k32_sg_16x8 >::type type
Definition decision_tree_policy.hpp:130
from_dict< dict_t_ > orig
Definition decision_tree_policy.hpp:126
Definition decision_tree_policy.hpp:33
static constexpr int distance_fcn()
Definition decision_tree_policy.hpp:58
tile_shape_config< 32, 256, 32, 16, 16 > wg_32x256_k32_sg_16x16
Definition decision_tree_policy.hpp:72
tile_shape_config< 128, 512, 16, 32, 64 > wg_128x512_k16_sg_32x64
Definition decision_tree_policy.hpp:70
tile_shape_config< 256, 256, 32, 64, 32 > wg_256x256_k32_sg_64x32
Definition decision_tree_policy.hpp:69
tile_shape_config< 512, 64, 32, 32, 32 > wg_512x64_k32_sg_32x32
Definition decision_tree_policy.hpp:73
tile_shape_config< 256, 256, 32, 32, 64 > wg_256x256_k32_sg_32x64
Definition decision_tree_policy.hpp:68
tile_shape_config< 512, 128, 16, 64, 32 > wg_512x128_k16_sg_64x32
Definition decision_tree_policy.hpp:71
static constexpr int const_abs(const int &z)
Definition decision_tree_policy.hpp:52
typename from_dict_impl< T >::type from_dict
Definition decision_tree_policy.hpp:123
Definition decision_tree_policy.hpp:32
typename impl::update_config_impl::type update_config
Definition decision_tree_policy.hpp:134
typename dict_t_::template update_dict_t< typename update_config::to_dict > type
Definition decision_tree_policy.hpp:136
Definition decision_tree_policy.hpp:268
typename opt_dict_t_::template update_t< elem_t_t< tune_key::data_type_a, typename dict_t_::template find_elem_t< tune_key::data_type_a >::type >, elem_t_t< tune_key::data_type_b, typename dict_t_::template find_elem_t< tune_key::data_type_b >::type >, elem_t_t< tune_key::data_type_c, typename dict_t_::template find_elem_t< tune_key::data_type_c >::type >, elem_v_t< tune_key::memory_layout_a, dict_t_::template find_elem_v< tune_key::memory_layout_a > >, elem_v_t< tune_key::memory_layout_b, dict_t_::template find_elem_v< tune_key::memory_layout_b > >, elem_v_t< tune_key::memory_layout_c, dict_t_::template find_elem_v< tune_key::memory_layout_c > >, elem_v_t< tune_key::memory_alignment_a, dict_t_::template find_elem_v< tune_key::memory_alignment_a > >, elem_v_t< tune_key::memory_alignment_b, dict_t_::template find_elem_v< tune_key::memory_alignment_b > >, elem_v_t< tune_key::memory_alignment_c, dict_t_::template find_elem_v< tune_key::memory_alignment_c > >, elem_v_t< tune_key::gpu_arch, dict_t_::template find_elem_v< tune_key::gpu_arch > > > type
Definition decision_tree_policy.hpp:295