XeTLA v0.3.6
Intel® Xe Templates for Linear Algebra - API Definition Document
 
Loading...
Searching...
No Matches
decision_tree_policy.hpp
Go to the documentation of this file.
1/*******************************************************************************
2 * Copyright (c) 2022-2023 Intel Corporation
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 *******************************************************************************/
16
17#pragma once
18
20
21namespace gpu::xetla {
22namespace decision_tree_rule {
23
24template <typename T>
26 using type = typename T::template update_dict_t<dict_t<elem_t_t<
28 typename T::template find_elem_t<tune_key::data_type_acc>::type>>>;
29};
30
31template <typename dict_t_>
33 struct impl {
34 template <uint32_t wg_tile_shape_n_, uint32_t wg_tile_shape_m_,
35 uint32_t wg_tile_k_, uint32_t sg_tile_shape_n_,
36 uint32_t sg_tile_shape_m_>
38 static constexpr uint32_t wg_tile_shape_n = wg_tile_shape_n_;
39 static constexpr uint32_t wg_tile_shape_m = wg_tile_shape_m_;
40 static constexpr uint32_t wg_tile_k = wg_tile_k_;
41 static constexpr uint32_t sg_tile_shape_n = sg_tile_shape_n_;
42 static constexpr uint32_t sg_tile_shape_m = sg_tile_shape_m_;
43
44 using to_dict
50 };
51
/// @brief Absolute value of an int, usable in constant expressions.
/// @param z Value to take the magnitude of. Must not be INT_MIN, whose
///          negation is not representable in int (in a constant expression
///          that case is diagnosed as a compile error).
/// @return |z|.
static constexpr int const_abs(int z) {
    // Scalars are passed by value; a const reference adds nothing here.
    return (z < 0) ? -z : z;
}
56
/// @brief Manhattan (L1) distance between two tile-shape configurations.
///
/// Sums the absolute differences of the five tile parameters
/// (wg_tile_shape_m, wg_tile_shape_n, wg_tile_k, sg_tile_shape_m,
/// sg_tile_shape_n). Each type is expected to expose these as unsigned
/// compile-time constants.
/// @tparam T First configuration.
/// @tparam U Second configuration.
/// @return The summed distance; symmetric in T and U, 0 when they match.
template <typename T, typename U>
static constexpr int distance_fcn() {
    // Take |a - b| without leaving the unsigned domain. Subtracting two
    // uint32_t values and converting the wrapped result to int depends on
    // modular wrap-around plus an implementation-defined (pre-C++20)
    // narrowing conversion.
    constexpr auto abs_diff = [](uint32_t a, uint32_t b) {
        return (a > b) ? (a - b) : (b - a);
    };
    uint32_t sum = 0;
    sum += abs_diff(T::wg_tile_shape_m, U::wg_tile_shape_m);
    sum += abs_diff(T::wg_tile_shape_n, U::wg_tile_shape_n);
    sum += abs_diff(T::wg_tile_k, U::wg_tile_k);
    sum += abs_diff(T::sg_tile_shape_m, U::sg_tile_shape_m);
    sum += abs_diff(T::sg_tile_shape_n, U::sg_tile_shape_n);
    return static_cast<int>(sum);
}
67
75
76 template <typename ref, typename... elems>
78 using type = ref;
79 static constexpr int distance = -1;
80 };
81
82 template <typename ref, typename elem, typename... elems>
83 struct find_min_elem<ref, elem, elems...> {
84 using cur_type = ref;
85 static constexpr int cur_distance = distance_fcn<ref, elem>();
86 using nxt = find_min_elem<ref, elems...>;
87 static constexpr bool use_next
88 = (sizeof...(elems) > 0) && (cur_distance > nxt::distance);
89 using type = typename std::conditional<use_next, typename nxt::type,
91 static constexpr int distance
92 = use_next ? nxt::distance : cur_distance;
93 };
94
95 template <typename ref, typename elem>
96 struct find_min_elem<ref, elem> {
97 using type = elem;
98 static constexpr int distance = distance_fcn<ref, elem>();
99 };
100
101 template <typename T>
103 using wg_tile_shape = typename T::template find_elem_t<
105 using sg_tile_shape = typename T::template find_elem_t<
107 static constexpr uint32_t wg_tile_shape_n
108 = wg_tile_shape::template dim<0>();
109 static constexpr uint32_t wg_tile_shape_m
110 = wg_tile_shape::template dim<1>();
111 static constexpr uint32_t wg_tile_k
112 = T::template find_elem_v<tune_key::wg_tile_k>;
113 static constexpr uint32_t sg_tile_shape_n
114 = sg_tile_shape::template dim<0>();
115 static constexpr uint32_t sg_tile_shape_m
116 = sg_tile_shape::template dim<1>();
117
120 };
121
122 template <typename T>
124
131 };
132 };
133
135 using type = typename dict_t_::template update_dict_t<
136 typename update_config::to_dict>;
137};
138
139template <typename dict_t_>
141 struct impl {
142 template <uint32_t global_kslicing_ratio_,
143 uint32_t local_kslicing_ratio_, uint32_t wg_tile_shape_n_,
144 uint32_t wg_tile_shape_m_, uint32_t wg_tile_k_,
145 uint32_t sg_tile_shape_n_, uint32_t sg_tile_shape_m_>
147 static constexpr uint32_t global_kslicing_ratio
148 = global_kslicing_ratio_;
149 static constexpr uint32_t local_kslicing_ratio
150 = local_kslicing_ratio_;
151 static constexpr uint32_t wg_tile_shape_n = wg_tile_shape_n_;
152 static constexpr uint32_t wg_tile_shape_m = wg_tile_shape_m_;
153 static constexpr uint32_t wg_tile_k = wg_tile_k_;
154 static constexpr uint32_t sg_tile_shape_n = sg_tile_shape_n_;
155 static constexpr uint32_t sg_tile_shape_m = sg_tile_shape_m_;
156
160
172
173 template <template <typename> typename G>
174 using apply = typename G<this_t>::type;
175 };
176
177 template <typename T>
179 static constexpr uint32_t global_kslicing_ratio
180 = T::template find_elem_v<tune_key::global_kslicing_ratio>;
181 static constexpr uint32_t local_kslicing_ratio
182 = T::template find_elem_v<tune_key::local_kslicing_ratio>;
183 using wg_tile_shape = typename T::template find_elem_t<
185 using sg_tile_shape = typename T::template find_elem_t<
187 static constexpr uint32_t wg_tile_shape_n
188 = wg_tile_shape::template dim<0>();
189 static constexpr uint32_t wg_tile_shape_m
190 = wg_tile_shape::template dim<1>();
191 static constexpr uint32_t wg_tile_k
192 = T::template find_elem_v<tune_key::wg_tile_k>;
193 static constexpr uint32_t sg_tile_shape_n
194 = sg_tile_shape::template dim<0>();
195 static constexpr uint32_t sg_tile_shape_m
196 = sg_tile_shape::template dim<1>();
197
201 };
202
203 template <typename T>
205
208
209 template <typename T>
211 static constexpr uint32_t global_kslicing_ratio
212 = T::global_kslicing_ratio;
213 static constexpr uint32_t local_kslicing_ratio
214 = T::local_kslicing_ratio;
215
216 static constexpr uint32_t wg_tile_shape_n
217 = (local_kslicing_ratio == 2) ? 128
218 : T::wg_tile_shape_n;
219 static constexpr uint32_t wg_tile_shape_m
220 = (local_kslicing_ratio == 2) ? 64 : T::wg_tile_shape_m;
221 static constexpr uint32_t wg_tile_k
222 = (local_kslicing_ratio == 2) ? 32 : T::wg_tile_k;
223 static constexpr uint32_t sg_tile_shape_n
224 = (local_kslicing_ratio == 2) ? 32 : T::sg_tile_shape_n;
225 static constexpr uint32_t sg_tile_shape_m
226 = (local_kslicing_ratio == 2) ? 16 : T::sg_tile_shape_m;
227
228 using type = typename dict_t<
229 elem_t_t<1U,
234 elem_t_t<2U,
236 local_kslicing_ratio, 128, 64, 32, 32,
237 16>>,
238 elem_t_t<4U,
240 local_kslicing_ratio, 64, 64, 32, 32,
241 16>>,
242 elem_t_t<8U,
244 local_kslicing_ratio, 64, 32, 32, 32,
245 16>>,
246 elem_t_t<16U,
248 local_kslicing_ratio, 64, 16, 32, 32,
249 16>>>::
250 template find_elem_t<local_kslicing_ratio>::type;
251 };
252
253 using type = typename orig::template apply<local_kslicing_handler>;
254 };
255 };
256
258 using type = typename std::conditional<
259 (dict_t_::template find_elem_v<tune_key::
260 dispatch_policy> == tune_key_value::dispatch_policy_kslicing),
261 typename dict_t_::template update_dict_t<
262 typename update_config::to_dict>,
263 dict_t_>::type;
264};
265} // namespace decision_tree_rule
266
267template <typename dict_t_, typename opt_dict_t_>
269 using type = typename opt_dict_t_::template update_t<
271 typename dict_t_::template find_elem_t<
274 typename dict_t_::template find_elem_t<
277 typename dict_t_::template find_elem_t<
280 dict_t_::template find_elem_v<tune_key::memory_layout_a>>,
282 dict_t_::template find_elem_v<tune_key::memory_layout_b>>,
284 dict_t_::template find_elem_v<tune_key::memory_layout_c>>,
286 dict_t_::template find_elem_v<
289 dict_t_::template find_elem_v<
292 dict_t_::template find_elem_v<
295 dict_t_::template find_elem_v<tune_key::gpu_arch>>>;
296};
297
298template <param_optimizer_tag tag_, typename dict_t_, typename... candidates_t>
300 struct impl {
301 using type = typename dict_t_ ::template update_generator_t<
303 template update_generator_t<
305 template update_generator_t<
308 };
309 static constexpr bool use_fallback
310 = !(param_optimizer_base::template validate_attribute<dict_t_,
311 typename impl::type>::value);
312 using type = typename std::conditional<use_fallback,
313 typename impl::fallback_type, impl>::type::type;
314};
315
316} // namespace gpu::xetla
Definition arch_config.hpp:24
param_optimizer_tag
Definition common.hpp:70
Definition decision_tree_policy.hpp:300
typename dict_t_ ::template update_generator_t< decision_tree_rule::data_type_handler >::template update_generator_t< decision_tree_rule::tile_shape_handler >::template update_generator_t< decision_tree_rule::kslicing_handler > type
Definition decision_tree_policy.hpp:306
fallback_optimizer< dict_t_, type > fallback_type
Definition decision_tree_policy.hpp:307
Definition decision_tree_policy.hpp:299
typename std::conditional< use_fallback, typename impl::fallback_type, impl >::type::type type
Definition decision_tree_policy.hpp:313
static constexpr bool use_fallback
Definition decision_tree_policy.hpp:310
Definition decision_tree_policy.hpp:25
typename T::template update_dict_t< dict_t< elem_t_t< tune_key::data_type_acc, typename T::template find_elem_t< tune_key::data_type_acc >::type > > > type
Definition decision_tree_policy.hpp:28
static constexpr uint32_t sg_tile_shape_n
Definition decision_tree_policy.hpp:194
static constexpr uint32_t local_kslicing_ratio
Definition decision_tree_policy.hpp:182
kslicing_config< global_kslicing_ratio, local_kslicing_ratio, wg_tile_shape_n, wg_tile_shape_m, wg_tile_k, sg_tile_shape_n, sg_tile_shape_m > type
Definition decision_tree_policy.hpp:200
static constexpr uint32_t wg_tile_shape_m
Definition decision_tree_policy.hpp:190
typename T::template find_elem_t< tune_key::sg_tile_shape >::type sg_tile_shape
Definition decision_tree_policy.hpp:186
static constexpr uint32_t sg_tile_shape_m
Definition decision_tree_policy.hpp:196
typename T::template find_elem_t< tune_key::wg_tile_shape >::type wg_tile_shape
Definition decision_tree_policy.hpp:184
static constexpr uint32_t wg_tile_shape_n
Definition decision_tree_policy.hpp:188
static constexpr uint32_t wg_tile_k
Definition decision_tree_policy.hpp:192
static constexpr uint32_t global_kslicing_ratio
Definition decision_tree_policy.hpp:180
static constexpr uint32_t sg_tile_shape_m
Definition decision_tree_policy.hpp:155
typename G< this_t >::type apply
Definition decision_tree_policy.hpp:174
static constexpr uint32_t wg_tile_k
Definition decision_tree_policy.hpp:153
static constexpr uint32_t local_kslicing_ratio
Definition decision_tree_policy.hpp:150
static constexpr uint32_t wg_tile_shape_m
Definition decision_tree_policy.hpp:152
static constexpr uint32_t sg_tile_shape_n
Definition decision_tree_policy.hpp:154
static constexpr uint32_t wg_tile_shape_n
Definition decision_tree_policy.hpp:151
static constexpr uint32_t global_kslicing_ratio
Definition decision_tree_policy.hpp:148
static constexpr uint32_t global_kslicing_ratio
Definition decision_tree_policy.hpp:212
static constexpr uint32_t sg_tile_shape_n
Definition decision_tree_policy.hpp:224
typename dict_t< elem_t_t< 1U, kslicing_config< global_kslicing_ratio, local_kslicing_ratio, wg_tile_shape_n, wg_tile_shape_m, wg_tile_k, sg_tile_shape_n, sg_tile_shape_m > >, elem_t_t< 2U, kslicing_config< global_kslicing_ratio, local_kslicing_ratio, 128, 64, 32, 32, 16 > >, elem_t_t< 4U, kslicing_config< global_kslicing_ratio, local_kslicing_ratio, 64, 64, 32, 32, 16 > >, elem_t_t< 8U, kslicing_config< global_kslicing_ratio, local_kslicing_ratio, 64, 32, 32, 32, 16 > >, elem_t_t< 16U, kslicing_config< global_kslicing_ratio, local_kslicing_ratio, 64, 16, 32, 32, 16 > > >::template find_elem_t< local_kslicing_ratio >::type type
Definition decision_tree_policy.hpp:250
static constexpr uint32_t sg_tile_shape_m
Definition decision_tree_policy.hpp:226
static constexpr uint32_t local_kslicing_ratio
Definition decision_tree_policy.hpp:214
static constexpr uint32_t wg_tile_shape_m
Definition decision_tree_policy.hpp:220
static constexpr uint32_t wg_tile_k
Definition decision_tree_policy.hpp:222
static constexpr uint32_t wg_tile_shape_n
Definition decision_tree_policy.hpp:217
typename orig::template apply< local_kslicing_handler > type
Definition decision_tree_policy.hpp:253
from_dict< dict_t_ > orig
Definition decision_tree_policy.hpp:207
Definition decision_tree_policy.hpp:141
typename from_dict_impl< T >::type from_dict
Definition decision_tree_policy.hpp:204
Definition decision_tree_policy.hpp:140
typename std::conditional<(dict_t_::template find_elem_v< tune_key::dispatch_policy >==tune_key_value::dispatch_policy_kslicing), typename dict_t_::template update_dict_t< typename update_config::to_dict >, dict_t_ >::type type
Definition decision_tree_policy.hpp:263
typename impl::update_config_impl::type update_config
Definition decision_tree_policy.hpp:257
typename std::conditional< use_next, typename nxt::type, cur_type >::type type
Definition decision_tree_policy.hpp:90
static constexpr int distance
Definition decision_tree_policy.hpp:79
typename T::template find_elem_t< tune_key::wg_tile_shape >::type wg_tile_shape
Definition decision_tree_policy.hpp:104
static constexpr uint32_t wg_tile_shape_n
Definition decision_tree_policy.hpp:108
tile_shape_config< wg_tile_shape_n, wg_tile_shape_m, wg_tile_k, sg_tile_shape_n, sg_tile_shape_m > type
Definition decision_tree_policy.hpp:119
static constexpr uint32_t wg_tile_k
Definition decision_tree_policy.hpp:112
typename T::template find_elem_t< tune_key::sg_tile_shape >::type sg_tile_shape
Definition decision_tree_policy.hpp:106
static constexpr uint32_t wg_tile_shape_m
Definition decision_tree_policy.hpp:110
static constexpr uint32_t sg_tile_shape_n
Definition decision_tree_policy.hpp:114
static constexpr uint32_t sg_tile_shape_m
Definition decision_tree_policy.hpp:116
static constexpr uint32_t sg_tile_shape_m
Definition decision_tree_policy.hpp:42
static constexpr uint32_t wg_tile_shape_n
Definition decision_tree_policy.hpp:38
static constexpr uint32_t wg_tile_shape_m
Definition decision_tree_policy.hpp:39
static constexpr uint32_t sg_tile_shape_n
Definition decision_tree_policy.hpp:41
static constexpr uint32_t wg_tile_k
Definition decision_tree_policy.hpp:40
typename find_min_elem< orig, wg_256x256_k32_sg_32x64, wg_256x256_k32_sg_64x32, wg_128x512_k16_sg_32x64, wg_512x128_k16_sg_64x32, wg_32x256_k32_sg_16x16, wg_512x64_k32_sg_32x32, wg_64x64_k32_sg_16x8 >::type type
Definition decision_tree_policy.hpp:130
from_dict< dict_t_ > orig
Definition decision_tree_policy.hpp:126
Definition decision_tree_policy.hpp:33
static constexpr int distance_fcn()
Definition decision_tree_policy.hpp:58
tile_shape_config< 32, 256, 32, 16, 16 > wg_32x256_k32_sg_16x16
Definition decision_tree_policy.hpp:72
tile_shape_config< 128, 512, 16, 32, 64 > wg_128x512_k16_sg_32x64
Definition decision_tree_policy.hpp:70
tile_shape_config< 256, 256, 32, 64, 32 > wg_256x256_k32_sg_64x32
Definition decision_tree_policy.hpp:69
tile_shape_config< 512, 64, 32, 32, 32 > wg_512x64_k32_sg_32x32
Definition decision_tree_policy.hpp:73
tile_shape_config< 256, 256, 32, 32, 64 > wg_256x256_k32_sg_32x64
Definition decision_tree_policy.hpp:68
tile_shape_config< 512, 128, 16, 64, 32 > wg_512x128_k16_sg_64x32
Definition decision_tree_policy.hpp:71
static constexpr int const_abs(const int &z)
Definition decision_tree_policy.hpp:52
typename from_dict_impl< T >::type from_dict
Definition decision_tree_policy.hpp:123
Definition decision_tree_policy.hpp:32
typename impl::update_config_impl::type update_config
Definition decision_tree_policy.hpp:134
typename dict_t_::template update_dict_t< typename update_config::to_dict > type
Definition decision_tree_policy.hpp:136
Definition dict.hpp:103
Definition dict.hpp:97
Definition dict.hpp:100
Definition decision_tree_policy.hpp:268
typename opt_dict_t_::template update_t< elem_t_t< tune_key::data_type_a, typename dict_t_::template find_elem_t< tune_key::data_type_a >::type >, elem_t_t< tune_key::data_type_b, typename dict_t_::template find_elem_t< tune_key::data_type_b >::type >, elem_t_t< tune_key::data_type_c, typename dict_t_::template find_elem_t< tune_key::data_type_c >::type >, elem_v_t< tune_key::memory_layout_a, dict_t_::template find_elem_v< tune_key::memory_layout_a > >, elem_v_t< tune_key::memory_layout_b, dict_t_::template find_elem_v< tune_key::memory_layout_b > >, elem_v_t< tune_key::memory_layout_c, dict_t_::template find_elem_v< tune_key::memory_layout_c > >, elem_v_t< tune_key::memory_alignment_a, dict_t_::template find_elem_v< tune_key::memory_alignment_a > >, elem_v_t< tune_key::memory_alignment_b, dict_t_::template find_elem_v< tune_key::memory_alignment_b > >, elem_v_t< tune_key::memory_alignment_c, dict_t_::template find_elem_v< tune_key::memory_alignment_c > >, elem_v_t< tune_key::gpu_arch, dict_t_::template find_elem_v< tune_key::gpu_arch > > > type
Definition decision_tree_policy.hpp:295
Definition common.hpp:75
Definition dict.hpp:59