XeTLA v0.3.6
Intel® Xe Templates for Linear Algebra - API Definition Document
 
Loading...
Searching...
No Matches
dummy_policy.hpp
Go to the documentation of this file.
1/*******************************************************************************
2 * Copyright (c) 2022-2023 Intel Corporation
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 *******************************************************************************/
16
17#pragma once
18
20
21namespace gpu::xetla {
22template <param_optimizer_tag tag_, typename dict_t_, typename... candidates_t>
24 struct impl {
25 enum class eval_tag : uint8_t {
26 et_type,
30 };
31
32 template <auto key_, typename T, typename U, eval_tag eval_tag_,
33 typename = void>
35
36 template <auto key_, typename T, typename U>
38 static constexpr int value = []() constexpr {
39 return (std::is_same<T, U>::value) ? 0 : 1;
40 }
41 ();
42 };
43
44 template <auto key_, typename T, typename U>
46 static constexpr int value = []() constexpr {
47 auto l = T::template find_elem_v<key_>;
48 auto r = U::template find_elem_v<key_>;
49 return (l == r) ? 0 : 1;
50 }
51 ();
52 };
53
/// Compile-time maximum of two values.
/// @param l first operand
/// @param r second operand
/// @return `l` when `l > r` holds, otherwise `r` (so ties yield `r`).
template <typename T>
static constexpr T const_max(const T &l, const T &r) {
    if (l > r) { return l; }
    return r;
}
58
/// Compile-time minimum of two values.
/// @param l first operand
/// @param r second operand
/// @return `r` when `l > r` holds, otherwise `l` (so ties yield `l`).
template <typename T>
static constexpr T const_min(const T &l, const T &r) {
    if (l > r) { return r; }
    return l;
}
63
64 template <auto key_, typename T, typename U>
66 static constexpr int value = []() constexpr {
67 auto l = T::template find_elem_v<key_>;
68 auto r = U::template find_elem_v<key_>;
69 return (const_max(l, r) - const_min(l, r));
70 }
71 ();
72 };
73
74 template <auto key_, typename T, typename U>
76 static constexpr int value = []() constexpr {
77 auto l = T::template find_elem_v<key_>;
78 auto r = U::template find_elem_v<key_>;
79 auto ret = (l - r);
80 return ret * ret;
81 }
82 ();
83 };
84
85 template <auto key_, typename T, typename U>
87 static constexpr int value = []() constexpr {
88 using eval_fcn = typename std::conditional<
89 ((key_ == tune_key::data_type_a)
90 || (key_ == tune_key::data_type_b)
91 || (key_ == tune_key::data_type_c)
92 || (key_ == tune_key::data_type_acc)
93 || (key_ == tune_key::epilogue_policy)),
95 typename std::conditional<
97 || (key_
98 == tune_key::
100 || (key_ == tune_key::wg_tile_k)
101 || (key_ == tune_key::prefetch_distance)
102 || (key_
103 == tune_key::
105 param_distance_eval_fcn<key_, T, U,
107 param_distance_eval_fcn<key_, T, U,
109 switch (key_) {
110 case tune_key::wg_tile_k: return 10 * eval_fcn::value;
113 return 1000 * eval_fcn::value;
114 case tune_key::data_type_acc: return 10 * eval_fcn::value;
115 default: return 10000000 * eval_fcn::value;
116 }
117 return 0;
118 }
119 ();
120 };
121
122 template <typename T, typename U>
124 static constexpr int value = []() constexpr {
125 using T_L = typename T::template find_elem_t<
127 using T_R = typename T::template find_elem_t<
129
130 int l_x = T_L::template dim<0>();
131 int l_y = T_L::template dim<1>();
132
133 int r_x = T_R::template dim<0>();
134 int r_y = T_R::template dim<1>();
135
136 return (const_max(l_x, r_x) - const_min(l_x, r_x))
137 + (const_max(l_y, r_y) - const_min(l_y, r_y));
138 }
139 ();
140 };
141
142 template <typename T, typename U>
144 static constexpr int value = []() constexpr {
145 using T_L = typename T::template find_elem_t<
147 using T_R = typename T::template find_elem_t<
149
150 int l_x = T_L::template dim<0>();
151 int l_y = T_L::template dim<1>();
152
153 int r_x = T_R::template dim<0>();
154 int r_y = T_R::template dim<1>();
155
156 return (const_max(l_x, r_x) - const_min(l_x, r_x))
157 + (const_max(l_y, r_y) - const_min(l_y, r_y));
158 }
159 ();
160 };
161
162 template <typename T, typename U>
164 static constexpr int value = []() constexpr {
165 int sum = 0;
168 U>::value;
170 U>::value;
173 U>::value;
175 U>::value;
178 U>::value;
180 U>::value;
182 U>::value;
183 if constexpr (tag_ == param_optimizer_tag::work_group) {
185 U>::value;
187 U>::value;
189 U>::value;
190 }
191 if constexpr (tag_ == param_optimizer_tag::kernel) {
193 T, U>::value;
195 T, U>::value;
196 }
198 U>::value;
200 if constexpr (tag_ == param_optimizer_tag::work_group) {
202 U>::value;
204 U>::value;
205 }
207 U>::value;
209 U>::value;
213 U>::value;
214 if constexpr (tag_ == param_optimizer_tag::kernel) {
216 U>::value;
218 T, U>::value;
219 }
220 return sum;
221 }
222 ();
223 };
224
225 template <int opt_val_, typename opt_t_, typename... elems>
227
228 template <int opt_val_, typename opt_t_>
229 struct finder_impl<opt_val_, opt_t_> {
230 using type = opt_t_;
231 static constexpr int value = opt_val_;
232 };
233
234 template <int opt_val_, typename opt_t_, typename elem_,
235 typename... elems>
236 struct finder_impl<opt_val_, opt_t_, elem_, elems...> {
237 static constexpr int can_val
239 using cur_opt_t = typename std::conditional<(can_val < opt_val_),
240 elem_, opt_t_>::type;
241 static constexpr int cur_opt_val = const_min(opt_val_, can_val);
242
243 using nxt_result = finder_impl<cur_opt_val, cur_opt_t, elems...>;
244
245 using type = typename nxt_result::type;
246 static constexpr int value = nxt_result::value;
247 };
248
249 template <typename opt_t_, typename... elems>
251 : finder_impl<param_distance<dict_t_, opt_t_>::value, opt_t_,
252 elems...> {};
253
254 using type = typename finder_impl_helper<candidates_t...>::type;
256 };
257 static constexpr bool use_fallback
258 = !(param_optimizer_base::template validate_attribute<dict_t_,
259 typename impl::type>::value);
260 using type = typename std::conditional<use_fallback,
261 typename impl::fallback_type, impl>::type::type;
262};
263
264} // namespace gpu::xetla
Definition arch_config.hpp:24
tune_key
Definition common.hpp:27
param_optimizer_tag
Definition common.hpp:70
typename std::conditional<(can_val < opt_val_), elem_, opt_t_>::type cur_opt_t
Definition dummy_policy.hpp:240
typename nxt_result::type type
Definition dummy_policy.hpp:245
Definition dummy_policy.hpp:226
static constexpr int value
Definition dummy_policy.hpp:87
static constexpr int value
Definition dummy_policy.hpp:164
Definition dummy_policy.hpp:24
static constexpr T const_max(const T &l, const T &r)
Definition dummy_policy.hpp:55
eval_tag
Definition dummy_policy.hpp:25
typename finder_impl_helper< candidates_t... >::type type
Definition dummy_policy.hpp:254
static constexpr T const_min(const T &l, const T &r)
Definition dummy_policy.hpp:60
fallback_optimizer< dict_t_, type > fallback_type
Definition dummy_policy.hpp:255
Definition dummy_policy.hpp:23
typename std::conditional< use_fallback, typename impl::fallback_type, impl >::type::type type
Definition dummy_policy.hpp:261
static constexpr bool use_fallback
Definition dummy_policy.hpp:258
Definition decision_tree_policy.hpp:268
Definition common.hpp:75