XeTLA v0.3.6
Intel® Xe Templates for Linear Algebra - API Definition Document
 
dispatch_policy.hpp
/*******************************************************************************
* Copyright (c) 2022-2023 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/

#pragma once

#include "kernel/gemm/common.hpp"

namespace gpu::xetla::kernel {

/// @brief Default GROUP_SWIZZLE implementation.
template <gpu_arch arch_tag_>
class group_swizzle_default {
public:
    static constexpr gpu_arch arch_tag = arch_tag_;

    inline group_swizzle_default() = default;

    template <int idx>
    static __XETLA_API int get_tile_idx(sycl::nd_item<3> &item) {
        return item.get_group(idx);
    }

    // correct group range, nothing will be done under this swizzle policy
    static __XETLA_API void update_group_range(
            [[maybe_unused]] uint32_t &group_range_m,
            [[maybe_unused]] uint32_t &group_range_n) {}
};

/// @brief GROUP_SWIZZLE implementation of snake curve.
template <int wg_num_n_, gpu_arch arch_tag_>
class group_swizzle_snake {
public:
    static constexpr gpu_arch arch_tag = arch_tag_;
    inline group_swizzle_snake() = default;

    // get dim0 group id
    template <int idx>
    static __XETLA_API typename std::enable_if_t<idx == 0, int> get_tile_idx(
            sycl::nd_item<3> &item) {
        return item.get_group(idx);
    }
    // get transformed dim1 group id
    template <int idx>
    static __XETLA_API typename std::enable_if_t<idx == 1, int> get_tile_idx(
            sycl::nd_item<3> &item) {
        uint32_t group_range_n = item.get_group_range(2);
        uint32_t wg_repeat_n = group_range_n / wg_num_n;
        uint32_t repeat_id = get_2d_group_linear_id(item) / max_wg_num;
        uint32_t repeat_id_m = repeat_id / wg_repeat_n;
        uint32_t repeat_start_m = repeat_id_m * wg_num_m;
        uint32_t wg_inner_id = get_2d_group_linear_id(item) % max_wg_num;
        uint32_t wg_coord_m = wg_inner_id / wg_num_n;
        int start_m_id = repeat_start_m + wg_coord_m;
        return start_m_id;
    }
    // get transformed dim2 group id
    template <int idx>
    static __XETLA_API typename std::enable_if_t<idx == 2, int> get_tile_idx(
            sycl::nd_item<3> &item) {
        uint32_t group_range_n = item.get_group_range(2);
        uint32_t wg_repeat_n = group_range_n / wg_num_n;
        uint32_t repeat_id = get_2d_group_linear_id(item) / max_wg_num;
        uint32_t repeat_id_n = repeat_id % wg_repeat_n;
        uint32_t repeat_id_m = repeat_id / wg_repeat_n;
        uint32_t repeat_start_n_0 = repeat_id_n * wg_num_n;
        uint32_t repeat_start_n_1 = (wg_repeat_n - repeat_id_n - 1) * wg_num_n;
        uint32_t repeat_start_n
                = (repeat_id_m & 1) == 0 ? repeat_start_n_0 : repeat_start_n_1;
        uint32_t wg_inner_id = get_2d_group_linear_id(item) % max_wg_num;
        uint32_t wg_coord_n = wg_inner_id % wg_num_n;
        int start_n_id = repeat_start_n + wg_coord_n;
        return start_n_id;
    }
    // correct group range, workgroups will be padded to fit the given
    // wg_num_n under this swizzle policy
    static __XETLA_API void update_group_range(
            uint32_t &group_range_m, uint32_t &group_range_n) {
        group_range_m = (group_range_m + wg_num_m - 1) / wg_num_m * wg_num_m;
        group_range_n = (group_range_n + wg_num_n - 1) / wg_num_n * wg_num_n;
    }

private:
    static constexpr uint32_t max_wg_num = arch_attr_t<arch_tag>::max_wg_num;
    static constexpr uint32_t wg_num_n = wg_num_n_;
    static_assert(!(max_wg_num % wg_num_n),
            "max_wg_num must be divisible by the given wg_num_n!");
    static constexpr uint32_t wg_num_m = max_wg_num / wg_num_n;
};
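
// Worked example (illustrative numbers, not from the original source): with
// wg_num_n = 4 and arch_attr_t<arch_tag>::max_wg_num = 32, wg_num_m = 8.
// update_group_range pads a 10 x 6 group range up to 16 x 8, so the grid
// divides evenly into 8 x 4 blocks of workgroups. Each run of 32 consecutive
// linear group ids fills one such block, and blocks are traversed
// left-to-right on even block rows and right-to-left on odd ones
// ((repeat_id_m & 1) above), keeping successive workgroups on nearby output
// tiles for better cache reuse.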

/// @brief Default GEMM_UNIVERSAL implementation.
template <typename group_swizzle_policy_>
struct dispatch_policy_default {
    using group_swizzle_policy = group_swizzle_policy_;
    static constexpr gpu_arch arch_tag = group_swizzle_policy::arch_tag;
};

/// @brief Kslicing GEMM_UNIVERSAL implementation.
template <typename group_swizzle_policy_, int global_ratio_ = 1,
        int local_ratio_ = 1>
struct dispatch_policy_kslicing {
    using group_swizzle_policy = group_swizzle_policy_;
    static constexpr int global_ratio = global_ratio_;
    static constexpr int local_ratio = local_ratio_;
    static constexpr gpu_arch arch_tag = group_swizzle_policy::arch_tag;
};
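
// Example instantiation (illustrative; the ratios are assumed here to name
// the number of K splits performed in global memory and in shared local
// memory, respectively):
//   using kslicing_policy = dispatch_policy_kslicing<
//           group_swizzle_default<gpu_arch::Xe>, 2, 4>;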

/// @brief StreamK GEMM implementation.
template <gpu_arch arch_tag_ = gpu_arch::Xe>
struct dispatch_policy_stream_k {

    static constexpr gpu_arch arch_tag = arch_tag_;

    uint32_t matrix_m;
    uint32_t matrix_k;
    uint32_t matrix_n;

    uint32_t wg_tile_m;
    uint32_t wg_tile_k;
    uint32_t wg_tile_n;

    uint32_t sg_tile_m;
    uint32_t sg_tile_n;

    /// Number of xecores available for stream_k load balancing
    uint32_t avail_xecores;

    uint32_t num_workgroups;
    /// Number of data-parallel workgroups
    uint32_t dp_groups;

    uint32_t sk_tiles;
    uint32_t sk_waves;
    uint32_t sk_big_groups_per_region;
    uint32_t sk_iters_per_region;
    uint32_t sk_regions;
    uint32_t sk_groups_per_region;

    //FastDivMod counters initialized on the host to use multiply and shift
    //operations in kernel code for modulus and division
    FastDivMod div_mod_tiles_m;
    FastDivMod div_mod_tiles_n;
    FastDivMod div_mod_iters_per_tile;
    FastDivMod div_mod_sk_regions;
    FastDivMod div_mod_sk_groups_per_region;
    FastDivMod div_mod_sk_iters_per_normal_group;
    FastDivMod div_mod_sk_iters_per_region;
    FastDivMod div_mod_sk_iters_per_big_group;
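
    // Illustrative semantics (assumed from fastmath.hpp's documented
    // interface): FastDivMod d(3) precomputes a multiplier/shift pair on the
    // host; in kernel code, d.div(7) returns 2 and d.fast_divmod(q, r, 7)
    // yields q = 2, r = 1 without emitting a hardware integer division.
    // static_cast<int>(d) is assumed to recover the stored divisor, which is
    // what the accessors below rely on.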

    /// Minimum number of MAC-iterations per streamk group
    static int const kMinItersPerSkGroup = 2;

    //Host+Device interface functions

    /// @brief Host helper function to get the expected nd_range under the
    /// current GEMM config.
    cl::sycl::range<3> get_group_range() const {
        cl::sycl::range<3> group_range
                = cl::sycl::range<3> {1, 1, num_workgroups};
        return group_range;
    }

    /// @brief Host helper function to compute sk_groups to dispatch for a
    /// given number of sk_tiles.
    void get_sk_workgroups(int &sk_groups, int &savings_iters, int sk_tiles,
            int iters_per_tile, int avail_xecores,
            bool allow_partial_wave) const {

        savings_iters = INT_MIN;
        sk_groups = 0;

        if (sk_tiles == 0) { return; }

        int sk_iters = sk_tiles * iters_per_tile;

        int dp_equiv_waves = (sk_tiles + avail_xecores - 1) / avail_xecores;
        int dp_equiv_iters = iters_per_tile * dp_equiv_waves;

        int min_sk_groups = (allow_partial_wave)
                ? std::min(avail_xecores, sk_tiles + 1)
                : avail_xecores;
        int max_sk_groups
                = std::min(avail_xecores, sk_iters / kMinItersPerSkGroup);

        for (int trial_sk_groups = min_sk_groups;
                trial_sk_groups <= max_sk_groups; trial_sk_groups++) {

            int sk_waves
                    = (trial_sk_groups + avail_xecores - 1) / avail_xecores;
            int max_sk_iters_per_group
                    = (sk_iters + trial_sk_groups - 1) / trial_sk_groups;
            int sk_iter_equiv = max_sk_iters_per_group * sk_waves;

            int num_peers = ((trial_sk_groups + sk_tiles - 1) / sk_tiles) + 1;
            float iter_cost = 0.02f * float(num_peers) * float(sk_iter_equiv);

            if (trial_sk_groups % sk_tiles == 0) {

                //aligned
                num_peers = (trial_sk_groups / sk_tiles);
                iter_cost = 0.0f;
            }

            float peer_cost = 2.0f * float(num_peers);
            float base_cost = 2.0f * float(sk_waves);

            int fixup_iter_equiv = int(base_cost + iter_cost + peer_cost);

            int trial_savings_iter
                    = dp_equiv_iters - sk_iter_equiv - fixup_iter_equiv;

            if (trial_savings_iter >= savings_iters) {

                savings_iters = trial_savings_iter;
                sk_groups = trial_sk_groups;
            }
        }
    }
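
    // Worked example (illustrative numbers): sk_tiles = 2, iters_per_tile
    // = 32, avail_xecores = 8, allow_partial_wave = true. Then sk_iters = 64
    // and dp_equiv_iters = 32. Trying trial_sk_groups = 8: sk_iter_equiv = 8;
    // the split is tile-aligned (8 % 2 == 0), so iter_cost = 0 and num_peers
    // = 4; fixup_iter_equiv = int(2.0f + 0.0f + 8.0f) = 10, giving
    // savings_iters = 32 - 8 - 10 = 14, the best of all trials, and
    // sk_groups = 8 wins.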

    /// @brief Determine the populations of DP and SK groups to invoke for
    /// the given number of output tiles.
    void get_groups(int &dp_tiles, int &sk_groups, int output_tiles,
            int iters_per_tile, int avail_xecores) {

        int full_waves = output_tiles / avail_xecores;
        int full_wave_tiles = full_waves * avail_xecores;
        int partial_wave_tiles = output_tiles - full_wave_tiles;

        if (partial_wave_tiles == 0) {
            //No tails
            return;
        }
        int score = -1;
        dp_tiles = output_tiles;
        sk_groups = 0;

        if (full_waves < 1) {

            //Less than one full wave of output tiles: hand everything to the
            //SK heuristic, allowing a partial wave
            dp_tiles = full_wave_tiles;

            get_sk_workgroups(sk_groups, score, partial_wave_tiles,
                    iters_per_tile, avail_xecores, true);

            if (score < 0) {
                //Not profitable
                dp_tiles = output_tiles;
                sk_groups = 0;
            }

            return;
        }

        //Form the SK wave by combining the last full wave and the partial wave
        dp_tiles = full_wave_tiles - avail_xecores;

        get_sk_workgroups(sk_groups, score, partial_wave_tiles + avail_xecores,
                iters_per_tile, avail_xecores,
                false); // cannot run with less than a full wave of SK-groups

        std::cout << "SK Score: " << score << "\n\n";

        if (score < 0) { //Not profitable for stream_k split

            sk_groups = 0;
            dp_tiles = output_tiles;
        }
    }
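
    // Worked example (illustrative numbers): output_tiles = 20, avail_xecores
    // = 8 gives full_waves = 2, full_wave_tiles = 16, partial_wave_tiles = 4.
    // Since at least one full wave exists, dp_tiles becomes 16 - 8 = 8 and
    // the SK heuristic is offered the remaining 4 + 8 = 12 tiles; a negative
    // score reverts everything to 20 data-parallel tiles.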

    inline dispatch_policy_stream_k() = default;

    //static constexpr bool host_callable = true;

    /// @brief Set for device copyable.
    inline dispatch_policy_stream_k(uint32_t matrix_m_, uint32_t matrix_k_,
            uint32_t matrix_n_, uint32_t wg_tile_m_, uint32_t wg_tile_k_,
            uint32_t wg_tile_n_, uint32_t sg_tile_m_, uint32_t sg_tile_n_,
            uint32_t avail_xecores_ = arch_attr_t<arch_tag>::max_wg_num)
        : matrix_m(matrix_m_)
        , matrix_k(matrix_k_)
        , matrix_n(matrix_n_)
        , wg_tile_m(wg_tile_m_)
        , wg_tile_k(wg_tile_k_)
        , wg_tile_n(wg_tile_n_)
        , sg_tile_m(sg_tile_m_)
        , sg_tile_n(sg_tile_n_)
        , avail_xecores(avail_xecores_) {

        int iters_per_tile = (matrix_k + wg_tile_k - 1) / wg_tile_k;

        //Default values for sk parameters
        int sk_iters_per_normal_group = 0;
        int sk_iters_per_big_group = 0;

        // Default: a single region of iteration space across all SK tiles
        sk_regions = 1;

        sk_waves = 0;

        int num_tiles_m = (matrix_m + wg_tile_m - 1) / wg_tile_m;
        int num_tiles_n = (matrix_n + wg_tile_n - 1) / wg_tile_n;

        int output_tiles = num_tiles_m * num_tiles_n;

        int dp_tiles = output_tiles;
        int sk_groups = 0;

        //Use heuristics to get stream_k split
        get_groups(dp_tiles, sk_groups, output_tiles, iters_per_tile,
                avail_xecores);

        sk_tiles = output_tiles - dp_tiles;

        // Compute SK group iteration details
        if (sk_groups > 0) {

            sk_waves = (sk_groups + avail_xecores - 1) / avail_xecores;
            //Compute global iteration space - tiles_m * tiles_n * k_iters
            int sk_iters = sk_tiles * iters_per_tile;
            sk_groups = std::min(sk_groups, sk_iters);

            //sk_groups may not divide sk_iters evenly; some groups perform
            //one additional iteration
            sk_iters_per_normal_group = sk_iters / sk_groups;
            int extra_sk_iters
                    = sk_iters - (sk_iters_per_normal_group * sk_groups);
            int sk_big_groups = extra_sk_iters;
            sk_iters_per_big_group = sk_iters_per_normal_group + 1;

            //KSlicing to fill up multiple regions within groups
            uint32_t current_sk_groups = sk_groups;
            if ((current_sk_groups > sk_tiles) && (sk_groups % sk_tiles == 0)) {
                sk_regions = sk_tiles;
            }

            sk_groups_per_region = sk_groups / sk_regions;
            sk_big_groups_per_region = sk_big_groups / sk_regions;
            sk_iters_per_region = sk_iters / sk_regions;

            //Initialize fast divmod counters related to SK
            div_mod_sk_regions = FastDivMod(sk_regions);
            div_mod_sk_groups_per_region = FastDivMod(sk_groups_per_region);
            div_mod_sk_iters_per_region = FastDivMod(sk_iters_per_region);
            div_mod_sk_iters_per_normal_group
                    = FastDivMod(sk_iters_per_normal_group);
            div_mod_sk_iters_per_big_group = FastDivMod(sk_iters_per_big_group);
        }
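
        // Worked example (illustrative numbers): sk_tiles = 5,
        // iters_per_tile = 2, sk_groups = 4 gives sk_iters = 10,
        // sk_iters_per_normal_group = 2 and extra_sk_iters = 2, so the first
        // two (big) groups run 3 iterations each and the last two run 2:
        // 2 * 3 + 2 * 2 = 10. With sk_regions = 1 this yields
        // sk_groups_per_region = 4 and sk_big_groups_per_region = 2.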

        div_mod_tiles_m = FastDivMod(num_tiles_m);
        div_mod_tiles_n = FastDivMod(num_tiles_n);
        div_mod_iters_per_tile = FastDivMod(iters_per_tile);

        dp_groups = dp_tiles;
        num_workgroups = get_num_active_groups();

        //Print the stats
        uint32_t total_tiles = num_tiles_m * num_tiles_n;
        std::cout << " problem size: (" << matrix_m << "," << matrix_n << ")"
                  << ", tiled_shape: (" << num_tiles_m << "," << num_tiles_n
                  << ")"
                  << ", tiles: " << total_tiles
                  << ", dp_tiles: " << total_tiles - sk_tiles
                  << ", sk_tiles: " << sk_tiles
                  << ", iters_per_tile: " << iters_per_tile
                  << ", num_workgroups: " << num_workgroups
                  << ", dp_workgroups: " << dp_groups
                  << ", dp_waves: " << dp_groups / avail_xecores
                  << ", sk_groups_per_region: " << sk_groups_per_region
                  << ", sk_regions: " << sk_regions
                  << ", sk_waves: " << sk_waves
                  << ", sk_iters_per_normal_group: "
                  << sk_iters_per_normal_group
                  << ", sk_big_groups_per_region: " << sk_big_groups_per_region
                  << ", avail_xecores: " << avail_xecores << "\n\n";
    }

    /// @brief Host helper function to return number of groups after stream_k
    /// split.
    int get_num_active_groups() const {
        return (sk_waves * avail_xecores) + dp_groups;
    }

    /// @brief Kernel helper function to return number of K-iters per output
    /// tile.
    __XETLA_API KERNEL_FUNC int get_iters_per_tile() const {
        return static_cast<int>(div_mod_iters_per_tile);
    }

    /// @brief Kernel helper function to return number of K-iters for normal
    /// sk groups.
    __XETLA_API KERNEL_FUNC int get_sk_iters_per_normal_group() const {
        return static_cast<int>(div_mod_sk_iters_per_normal_group);
    }

    /// @brief Kernel helper function to return number of SK regions.
    __XETLA_API KERNEL_FUNC int get_sk_regions() const {
        return static_cast<int>(div_mod_sk_regions);
    }

    /// @brief Kernel helper function to return number of SK groups per
    /// region.
    __XETLA_API KERNEL_FUNC int get_sk_groups_per_region() const {
        return static_cast<int>(div_mod_sk_groups_per_region);
    }

    /// @brief Kernel function to get tile offset for m and n.
    __XETLA_API KERNEL_FUNC void get_tile_offsets(
            int tile_idx, int &tile_offset_m, int &tile_offset_n) const {

        int tiles_m = static_cast<int>(div_mod_tiles_m);
        int tiles_n = static_cast<int>(div_mod_tiles_n);
        if (tiles_m > tiles_n) {
            div_mod_tiles_n.fast_divmod(tile_offset_m, tile_offset_n, tile_idx);
        } else {
            div_mod_tiles_m.fast_divmod(tile_offset_n, tile_offset_m, tile_idx);
        }
    }
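
    // Worked example (illustrative numbers): with tiles_m = 4 > tiles_n = 2,
    // tile_idx = 5 is decomposed by div_mod_tiles_n into tile_offset_m
    // = 5 / 2 = 2 and tile_offset_n = 5 % 2 = 1, i.e. tiles rasterize
    // fastest along the smaller n dimension.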

    /// @brief Kernel function to return tile idx for current sk iteration.
    __XETLA_API KERNEL_FUNC int get_sk_tile_idx(int iter) const {

        int tile_idx = div_mod_iters_per_tile.div(iter);
        return tile_idx;
    }

    /// @brief Kernel function to get iteration extents for stream_k split.
    __XETLA_API KERNEL_FUNC void get_iter_extents(int sk_group_idx,
            int &group_iter_begin, int &group_iter_end) const {
        int region_idx;
        int group_idx_in_region;
        div_mod_sk_groups_per_region.fast_divmod(
                region_idx, group_idx_in_region, sk_group_idx);

        group_iter_begin = (region_idx * sk_iters_per_region)
                + (group_idx_in_region * get_sk_iters_per_normal_group());

        //Adjust extents for the first sk_big_groups_per_region groups that
        //get one extra iteration
        int group_iters = get_sk_iters_per_normal_group();
        uint32_t current_group_idx_in_region = group_idx_in_region;
        if (current_group_idx_in_region < sk_big_groups_per_region) {

            //This is a big group: one extra iteration
            group_iter_begin += group_idx_in_region;
            group_iters += 1;
        } else {

            //This is a regular group
            group_iter_begin += sk_big_groups_per_region;
        }

        group_iter_end = group_iter_begin + group_iters;
    }
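
    // Worked example (continuing the numbers above): sk_iters_per_region
    // = 10, get_sk_iters_per_normal_group() = 2, sk_big_groups_per_region
    // = 2. Groups 0..3 of region 0 then receive the half-open iteration
    // ranges [0,3), [3,6), [6,8) and [8,10): big groups 0 and 1 absorb the
    // two extra iterations, and normal groups 2 and 3 are shifted past them.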

    /// @brief Kernel function to get the first sk group index writing the
    /// sliced output tile.
    __XETLA_API KERNEL_FUNC int get_first_group_idx(
            int tile_idx, int group_idx) const {
        uint32_t current_tile_idx = tile_idx;
        if (current_tile_idx >= sk_tiles) {
            //DP group
            return group_idx;
        }

        int iter = tile_idx * get_iters_per_tile();

        int region_idx, iter_in_region;

        div_mod_sk_iters_per_region.fast_divmod(
                region_idx, iter_in_region, iter);

        //Number of iterations in the big group region
        int big_group_iters = sk_big_groups_per_region
                * (get_sk_iters_per_normal_group() + 1);

        //Number of iterations in the normal group region
        int normal_group_iters = iter_in_region - big_group_iters;

        uint32_t big_group_idx_in_region
                = div_mod_sk_iters_per_big_group.div(iter_in_region);

        int normal_group_idx_in_region = sk_big_groups_per_region
                + div_mod_sk_iters_per_normal_group.div(normal_group_iters);

        int group_idx_in_region
                = (big_group_idx_in_region < sk_big_groups_per_region)
                ? big_group_idx_in_region
                : normal_group_idx_in_region;

        int owning_group_idx = (get_sk_groups_per_region() * region_idx)
                + group_idx_in_region;

        return owning_group_idx;
    }
};

} // namespace gpu::xetla::kernel
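
Usage sketch (illustrative, not part of the header): host code builds the
stream-k policy from the GEMM shape and the workgroup/subgroup tiling, then
sizes the dispatch from get_group_range(). The matrix sizes, tile sizes, and
local range below are assumed example values, not ones mandated by this file.

#include "kernel/gemm/dispatch_policy.hpp"

using namespace gpu::xetla::kernel;

void launch_example(sycl::queue &queue) {
    // 4096 x 4096 x 4096 GEMM with 256 x 32 x 256 workgroup tiles and
    // 32 x 64 subgroup tiles -- example numbers only. The constructor runs
    // the stream-k heuristics and prints the resulting dispatch stats.
    dispatch_policy_stream_k<gpu_arch::Xe> policy(
            4096, 4096, 4096, 256, 32, 256, 32, 64);

    // One workgroup per dispatched group, all packed into dimension 2.
    cl::sycl::range<3> group_range = policy.get_group_range();

    // The local range is dictated by the GEMM implementation's subgroup
    // layout; {1, 8, 4} here is purely illustrative.
    cl::sycl::range<3> local_range {1, 8, 4};
    cl::sycl::nd_range<3> nd_range(group_range * local_range, local_range);

    // ... submit the GEMM kernel functor with nd_range via queue ...
    (void)queue;
}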