XeTLA v0.3.6
Intel® Xe Templates for Linear Algebra - API Definition Document
 
stream_k_xe.hpp
1/*******************************************************************************
2* Copyright (c) 2022-2023 Intel Corporation
3*
4* Licensed under the Apache License, Version 2.0 (the "License");
5* you may not use this file except in compliance with the License.
6* You may obtain a copy of the License at
7*
8* http://www.apache.org/licenses/LICENSE-2.0
9*
10* Unless required by applicable law or agreed to in writing, software
11* distributed under the License is distributed on an "AS IS" BASIS,
12* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13* See the License for the specific language governing permissions and
14* limitations under the License.
15*******************************************************************************/
16
19
20#pragma once
21
22#include "kernel/gemm/api.hpp"
23#include "kernel/gemm/common.hpp"
24#include "kernel/gemm/dispatch_policy.hpp"
25namespace gpu::xetla::kernel {
26
29
34template <typename gemm_t_, typename epilogue_t_>
35class gemm_universal_t<dispatch_policy_stream_k<gpu_arch::Xe>, gemm_t_,
36 epilogue_t_> {
37 using gemm_t = gemm_t_;
38 using epilogue_t = epilogue_t_;
39 using gemm_args_t = typename gemm_t::arguments_t;
40 using epilogue_args_t = typename epilogue_t::arguments_t;
42
43 //Scratchspace to accumulate partials
44 using mem_desc_d_t
45 = mem_desc_t<float, mem_layout::row_major, mem_space::global>;
46 //Workspace to sync across xecores
47 using mem_desc_atomic_sync_t
48 = mem_desc_t<uint32_t, mem_layout::row_major, mem_space::global>;
49
50 using tile_shape = typename gemm_t::tile_shape;
51 static constexpr uint32_t wg_tile_m = tile_shape::wg_tile_size_y;
52 static constexpr uint32_t wg_tile_n = tile_shape::wg_tile_size_x;
53 static constexpr uint32_t sg_tile_m = tile_shape::sg_tile_size_y;
54 static constexpr uint32_t sg_tile_n = tile_shape::sg_tile_size_x;
55 static constexpr uint32_t wg_size_y = tile_shape::wg_size_y;
56 static constexpr uint32_t wg_size_x = tile_shape::wg_size_x;
57 static constexpr uint32_t real_wg_tile_m = sg_tile_m * wg_size_y;
58 static constexpr uint32_t real_wg_tile_n = sg_tile_n * wg_size_x;
59 //tile_k used in GEMMs
60 static constexpr uint32_t k_stride = gemm_t::k_stride;
61
62 using work_group_t = typename gemm_t::work_group_t;
63
64 static constexpr gpu_arch arch_tag = gpu_arch::Xe;
65 static_assert(arch_tag == gemm_t::arch_tag, "arch_tag should be the same");
66
67 using mem_desc_a_t = typename gemm_t::mem_desc_a_t;
68 using mem_desc_b_t = typename gemm_t::mem_desc_b_t;
69 using mem_desc_c_t = typename epilogue_t::mem_desc_c_t;
70 using matA_base_t = typename mem_desc_a_t::base_t;
71 using matB_base_t = typename mem_desc_b_t::base_t;
72 using matC_base_t = typename mem_desc_c_t::base_t;
73 using matD_base_t = typename mem_desc_d_t::base_t;
74 using matatomic_sync_base_t = typename mem_desc_atomic_sync_t::base_t;
75 using dtype_a = typename mem_desc_a_t::dtype;
76 using dtype_b = typename mem_desc_b_t::dtype;
77 using dtype_c = typename mem_desc_c_t::dtype;
78 using matAcc_t = typename gemm_t::matAcc_t;
81
82public:
85 struct arguments_t {
86 /// Is the size of the m dimension of the matrix multiplication (m x k x n).
87 uint32_t matrix_m;
88 /// Is the size of the k dimension of the matrix multiplication (m x k x n).
89 uint32_t matrix_k;
90 /// Is the size of the n dimension of the matrix multiplication (m x k x n).
91 uint32_t matrix_n;
92 /// Is the leading dimension (pitch) size of the matrix A in memory.
93 uint32_t matA_ld;
94 /// Is the leading dimension (pitch) size of the matrix B in memory.
95 uint32_t matB_ld;
96 /// Is the leading dimension (pitch) size of the matrix C in memory.
97 uint32_t matC_ld;
98 /// Is the leading dimension (pitch) size of the matrix D in memory.
99 uint32_t matD_ld;
100 /// Is the leading dimension (pitch) size of the atomic_sync space in memory.
101 uint32_t matatomic_sync_ld;
103 matA_base_t matA_base;
105 matB_base_t matB_base;
107 matC_base_t matC_base;
109 matD_base_t matD_base;
110 /// Is the base address of groupsync buf.
111 matatomic_sync_base_t matatomic_sync_base;
112 /// Is the workgroup split streamk arguments.
113 dispatch_stream_k stream_k_args;
115 epilogue_args_t epilogue_args;
116
118 inline arguments_t() = default;
119 // Be aware of the risks: Rule of three (copy constructor, copy assignment, destructor)
120 // Please check if you need to add self-define destructor
121 // ~arguments_t(){}
122
124 static constexpr bool host_callable = true;
125
126 /// Constructs arguments with initialization list.
141 inline arguments_t(uint32_t matrix_m_, uint32_t matrix_k_,
142 uint32_t matrix_n_, matA_base_t matA_base_, uint32_t matA_ld_,
143 matB_base_t matB_base_, uint32_t matB_ld_,
144 matC_base_t matC_base_, uint32_t matC_ld_,
145 matD_base_t matD_base_, uint32_t matD_ld_,
146 matatomic_sync_base_t matatomic_sync_base_,
147 uint32_t matatomic_sync_ld_, dispatch_stream_k &stream_k_args_,
148 epilogue_args_t epilogue_args_ = {})
149 : matrix_m(matrix_m_)
150 , matrix_k(matrix_k_)
151 , matrix_n(matrix_n_)
152 , matA_ld(matA_ld_)
153 , matB_ld(matB_ld_)
154 , matC_ld(matC_ld_)
155 , matD_ld(matD_ld_)
156 , matatomic_sync_ld(matatomic_sync_ld_)
157 , matA_base(matA_base_)
158 , matB_base(matB_base_)
159 , matC_base(matC_base_)
160 , matD_base(matD_base_)
161 , matatomic_sync_base(matatomic_sync_base_)
162 , stream_k_args(stream_k_args_)
163 , epilogue_args(epilogue_args_) {}
164 // Be aware of the risks: Rule of three (copy constructor, copy assignment, destructor)
165 // Please check if you need to add self-define destructor
166 // inline ~arguments_t(){}
167 inline arguments_t(const arguments_t &args)
168 : matrix_m(args.matrix_m)
169 , matrix_k(args.matrix_k)
170 , matrix_n(args.matrix_n)
171 , matA_ld(args.matA_ld)
172 , matB_ld(args.matB_ld)
173 , matC_ld(args.matC_ld)
174 , matD_ld(args.matD_ld)
175 , matatomic_sync_ld(args.matatomic_sync_ld)
176 , matA_base(args.matA_base)
177 , matB_base(args.matB_base)
178 , matC_base(args.matC_base)
179 , matD_base(args.matD_base)
180 , matatomic_sync_base(args.matatomic_sync_base)
181 , stream_k_args(args.stream_k_args)
182 , epilogue_args(args.epilogue_args) {}
183 inline arguments_t &operator=(const arguments_t &args) {
184 this->matrix_m = args.matrix_m;
185 this->matrix_k = args.matrix_k;
186 this->matrix_n = args.matrix_n;
187 this->matA_base = args.matA_base;
188 this->matA_ld = args.matA_ld;
189 this->matB_base = args.matB_base;
190 this->matB_ld = args.matB_ld;
191 this->matC_base = args.matC_base;
192 this->matC_ld = args.matC_ld;
193 this->matD_base = args.matD_base;
194 this->matD_ld = args.matD_ld;
195 this->matatomic_sync_base = args.matatomic_sync_base;
196 this->matatomic_sync_ld = args.matatomic_sync_ld;
197 this->stream_k_args = args.stream_k_args;
198 this->epilogue_args = args.epilogue_args;
199 return *this;
200 }
201 };
202
203 /// Gets named_barrier id consumption count.
206 __XETLA_API static constexpr uint32_t get_barrier_count() {
207 constexpr uint32_t count = gemm_t::barrier_count
208 + epilogue_t::barrier_count
209 + epilogue_stream_k_t::barrier_count;
210 static_assert(
211 count <= 32, "The named_barrier count should be less than 32!");
212 return count;
213 }
214
215 /// Gets local memory size consumption.
218 __XETLA_API static constexpr uint32_t get_slm_size() {
219 constexpr uint32_t size = gemm_t::slm_size + epilogue_t::slm_size;
220 static_assert(size <= (128 * 1024),
221 "The local memory size should be less than 128KB!");
222 return size;
223 };
224
225 /// Host helper function to get the expected local range under the current GEMM config.
227 static cl::sycl::range<3> get_local_range() {
228 uint32_t local_range_m = (wg_tile_m + sg_tile_m - 1) / sg_tile_m;
229 uint32_t local_range_n = (wg_tile_n + sg_tile_n - 1) / sg_tile_n;
230 std::cout << "Local range: {" << 1 << ", " << local_range_m << ", "
231 << local_range_n << "} \n";
232 assert(local_range_m * local_range_n <= 32);
233 //Linearize for stream_k algorithm
234 return cl::sycl::range<3> {1, 1, local_range_m * local_range_n};
235 };
236
237 /// Host helper function to get the expected nd_range under the current GEMM config.
239 static cl::sycl::nd_range<3> get_nd_range(arguments_t &args) {
240 cl::sycl::range<3> local_range = get_local_range();
241 cl::sycl::range<3> group_range = args.stream_k_args.get_group_range();
242 return cl::sycl::nd_range<3> {group_range * local_range, local_range};
243 };
244
245 /// Host helper function to get the expected accumulation buffer size of the current STREAMK_GEMM_UNIVERSAL config.
249 static size_t get_acc_buf_size(dispatch_stream_k &stream_k_args) {
250 return stream_k_args.matrix_m * stream_k_args.matrix_n;
251 };
252
253 /// Host helper function to get the expected counter buffer size of the current STREAMK_GEMM_UNIVERSAL config.
255 static size_t get_cnt_buf_size(dispatch_stream_k &stream_k_args) {
256
257 //For atomic reduction each SK group needs a synchronization flag
258 uint32_t num_flags
259 = stream_k_args.sk_regions * stream_k_args.sk_groups_per_region;
260 const int barrier_size = sizeof(uint32_t);
261 uint32_t atomic_space_bytes
262 = cacheline_align_up(num_flags * barrier_size);
263 uint32_t atomic_space_elements = atomic_space_bytes / barrier_size;
264
265 return atomic_space_elements;
266 };
267
268 /// Check if the arguments can be implemented.
271 static bool can_implement(arguments_t &args) {
272 bool implementable = true;
273 if (gemm_t::msg_type_a != msg_type::unaligned_2d) {
274 if (gemm_t::msg_type_a == msg_type::block_2d) {
275 implementable
276 &= kernel::block_2d<arch_tag, dtype_a>::check_tensor(
277 (uint64_t)(args.matA_base.base),
278 gemm_t::is_col_major_a ? args.matrix_m
279 : args.matrix_k,
280 gemm_t::is_col_major_a ? args.matrix_k
281 : args.matrix_m,
282 args.matA_ld);
283 } else {
284 implementable &= kernel::general_1d<arch_tag,
285 dtype_a>::check_alignment(args.matA_base.base,
286 args.matA_ld);
287 }
288 }
289 if (gemm_t::msg_type_b != msg_type::unaligned_2d) {
290 if (gemm_t::msg_type_b == msg_type::block_2d) {
291 implementable
292 &= kernel::block_2d<arch_tag, dtype_b>::check_tensor(
293 (uint64_t)(args.matB_base.base),
294 gemm_t::is_col_major_b ? args.matrix_k
295 : args.matrix_n,
296 gemm_t::is_col_major_b ? args.matrix_n
297 : args.matrix_k,
298 args.matB_ld);
299 } else {
300 implementable &= kernel::general_1d<arch_tag,
301 dtype_b>::check_alignment(args.matB_base.base,
302 args.matB_ld);
303 }
304 }
305 if (epilogue_t::msg_type_c != msg_type::unaligned_2d) {
306 if (epilogue_t::msg_type_c == msg_type::block_2d) {
307 implementable
308 &= kernel::block_2d<arch_tag, dtype_c>::check_tensor(
309 (uint64_t)(args.matC_base.base), args.matrix_n,
310 args.matrix_m, args.matC_ld);
311 } else {
312 implementable &= kernel::general_1d<arch_tag,
313 dtype_c>::check_alignment(args.matC_base.base,
314 args.matC_ld);
315 }
316 }
317
318 return implementable;
319 }
320
321protected:
323 struct TileWorkDesc {
324
325 //The location of this tile in group-tile coordinates in the output matrix
326 int tile_offset_m;
327 int tile_offset_n;
328
329 //The first global-scoped MAC-iteration this group will perform for this tile
330 int iter_begin;
331
332 //The starting index in the k-domain for MAC-iterations this group will perform for this tile
333 int k_begin;
334
335 //The end index in the k-domain for MAC-iterations this group will perform for this tile
336 int k_end;
337
338 //The number of remaining MAC-iterations this group will perform for this tile
339 int k_iters_remaining;
340 };
341
342 __XETLA_API static void init_dp_tile_work(TileWorkDesc &tile_work,
343 int tile_idx, int matrix_k, int iters_per_tile) {
344
345 //The first global-scoped MAC-iteration this workgroup will perform for this tile
346 tile_work.iter_begin = tile_idx * iters_per_tile;
347
348 //The number of MAC-iterations this workgroup will perform
349 tile_work.k_iters_remaining = iters_per_tile;
350
351 //The starting index in the k-domain for MAC iterations this workgroup will perform for this tile
352 tile_work.k_begin = 0;
353
354 //The ending index (one-past) in the k-domain for MAC-iterations this workgroup will perform for this tile
355 tile_work.k_end = matrix_k;
356 }
357
358 __XETLA_API static void init_sk_tile_work(TileWorkDesc &tile_work,
359 int tile_idx, int group_iter_begin, int group_iter_end,
360 int matrix_k, int iters_per_tile,
361 [[maybe_unused]] uint32_t stride_k) {
362
363 //The first global-scoped MAC iteration for this tile
364 int tile_iter_begin = tile_idx * iters_per_tile;
365
366 //The first global-scoped MAC-iteration this workgroup will perform for this tile
367 tile_work.iter_begin = xetla_max(group_iter_begin, tile_iter_begin);
368
369 //The first tile-scoped MAC-iteration this workgroup will perform for this tile
370 int k_iter_begin = tile_work.iter_begin - tile_iter_begin;
371
372 //The last(one past) tile-scoped MAC-iteration this workgroup will perform for this tile
373 int k_iter_end = group_iter_end - tile_iter_begin;
374
375 //The number of MAC-iterations this workgroup will perform for this tile
376 tile_work.k_iters_remaining = k_iter_end - k_iter_begin;
377
378 //Starting index in the k-domain for MAC-iterations this workgroup will perform for this tile
379 tile_work.k_begin = k_iter_begin * k_stride;
380
381 //Ending index (one past) in the k-domain for MAC-iterations this workgroup will perform for this tile
382 tile_work.k_end = xetla_min(matrix_k, int(k_iter_end * k_stride));
383 }
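// Illustrative walk-through (numbers below are assumptions, not from the source):
// with k_stride = 32 and matrix_k = 4096, iters_per_tile = 128. An SK group
// assigned the global MAC-iteration range [400, 512) that lands on tile_idx = 3
// (tile_iter_begin = 3 * 128 = 384) would get:
//   iter_begin        = max(400, 384)       = 400
//   k_iter_begin      = 400 - 384           = 16
//   k_iter_end        = 512 - 384           = 128
//   k_iters_remaining = 128 - 16            = 112
//   k_begin           = 16 * 32             = 512
//   k_end             = min(4096, 128 * 32) = 4096
// i.e. this group finishes the tail of tile 3, while the first 16 K-iterations
// of that tile are produced as partial sums by another group.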
384
385public:
386 /// Main execution function for stream_k GEMM.
392 __XETLA_API KERNEL_FUNC void operator()(sycl::nd_item<3> &item,
393 const arguments_t &args, uint32_t slm_base = 0,
394 uint32_t nbarrier_base = 0) {
395 const dispatch_stream_k &workgroup_mapping = args.stream_k_args;
396 int group_idx = item.get_group(2);
397
398 int tile_idx = 0;
399 int group_iter_begin = 0;
400 int group_iters_remaining = 0;
401
402 uint32_t gemm_slm_base = slm_base;
403 uint32_t gemm_nbarr_base = nbarrier_base;
404 uint32_t epilogue_slm_base = gemm_slm_base + gemm_t::slm_size;
405 uint32_t epilogue_nbarr_base = gemm_nbarr_base + gemm_t::barrier_count;
406
407 int iters_per_tile = workgroup_mapping.get_iters_per_tile();
408 int dp_start_group_idx
409 = workgroup_mapping.sk_waves * workgroup_mapping.avail_xecores;
410
411 bool dp_group = (group_idx >= dp_start_group_idx);
412
413 //setup for matatomic_sync / flag space
414 mem_desc_atomic_sync_t mem_desc_atomic_sync;
415 mem_desc_atomic_sync.init(args.matatomic_sync_base,
416 {args.matatomic_sync_ld, 1, args.matatomic_sync_ld}, {0, 0});
417
418 //Initialize tile-work descriptor
419 TileWorkDesc tile_work;
420
421 if (dp_group) {
422
423 int dp_group_idx = group_idx - dp_start_group_idx;
424 int first_dp_tile = workgroup_mapping.sk_tiles;
425
426 tile_idx = first_dp_tile + dp_group_idx;
427
428 group_iters_remaining = iters_per_tile;
429
430 init_dp_tile_work(
431 tile_work, tile_idx, args.matrix_k, group_iters_remaining);
432
433 } else {
434
435 //This is a SK group
436 int group_iter_end;
437 workgroup_mapping.get_iter_extents(
438 group_idx, group_iter_begin, group_iter_end);
439 group_iters_remaining = group_iter_end - group_iter_begin;
440
441 tile_idx = workgroup_mapping.get_sk_tile_idx(group_iter_end - 1);
442 init_sk_tile_work(tile_work, tile_idx, group_iter_begin,
443 group_iter_begin + group_iters_remaining, args.matrix_k,
444 iters_per_tile, k_stride);
445 }
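// Illustrative mapping (numbers are assumptions, not from the source): with
// sk_waves = 1, avail_xecores = 64 and sk_tiles = 10, dp_start_group_idx = 64.
// Groups [0, 64) are SK groups that share the 10 * iters_per_tile MAC-iterations
// of tiles [0, 10) as dictated by get_iter_extents(); group 64 becomes a DP group
// owning tile 10, group 65 owns tile 11, and each DP group later strides forward
// by avail_xecores tiles (see the bottom of the processing loop).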
446
447 //Tile offset in M and N
448 workgroup_mapping.get_tile_offsets(
449 tile_idx, tile_work.tile_offset_m, tile_work.tile_offset_n);
450
451 epilogue_stream_k_t epilogue_stream_k;
452
453 //StreamK processing loop body
454 while (true) {
456 {
457 // set up workgroup level coordinates and boundaries
458 int start_n = tile_work.tile_offset_n * wg_tile_n;
459 int start_m = tile_work.tile_offset_m * wg_tile_m;
460 int start_k = tile_work.k_begin;
461 uint32_t boundary_n = (start_n + wg_tile_n) > args.matrix_n
462 ? args.matrix_n
463 : (start_n + wg_tile_n);
464 uint32_t boundary_m = (start_m + wg_tile_m) > args.matrix_m
465 ? args.matrix_m
466 : (start_m + wg_tile_m);
467 uint32_t boundary_k = tile_work.k_end;
468
469 int first_group_idx = workgroup_mapping.get_first_group_idx(
470 tile_idx, group_idx);
471 bool tile_finished = (boundary_k == args.matrix_k);
472 bool tile_started = (tile_work.k_begin == 0);
473
474 // set up arguments
475 work_group_t g;
476 g.init(item.get_local_linear_id());
477 mem_desc_a_t mem_desc_a;
478 mem_desc_b_t mem_desc_b;
479 mem_desc_c_t mem_desc_c;
480 mem_desc_d_t mem_desc_d;
481
482 //setup for matA
483 if constexpr (mem_desc_a_t::is_local) {
484 mem_desc_a.init(args.matA_base,
485 {args.matrix_k, real_wg_tile_m, args.matrix_k},
486 {0, 0});
487 } else {
488 mem_desc_a.init(args.matA_base,
489 {boundary_k, boundary_m, args.matA_ld},
490 {start_k, start_m});
491 }
492
493 //setup for matB
494 if constexpr (mem_desc_b_t::is_local) {
495 mem_desc_b.init(args.matB_base,
496 {real_wg_tile_n, args.matrix_k, real_wg_tile_n},
497 {0, 0});
498 } else {
499 mem_desc_b.init(args.matB_base,
500 {boundary_n, boundary_k, args.matB_ld},
501 {start_n, start_k});
502 }
503 //setup for matC
504 if constexpr (mem_desc_c_t::is_local) {
505 mem_desc_c.init(args.matC_base,
506 {real_wg_tile_n, real_wg_tile_m, real_wg_tile_n},
507 {0, 0});
508 } else {
509 mem_desc_c.init(args.matC_base,
510 {boundary_n, boundary_m, args.matC_ld},
511 {start_n, start_m});
512 }
513
514 //setup for scratchspace matD
515 mem_desc_d.init(args.matD_base,
516 {boundary_n, boundary_m, args.matD_ld},
517 {start_n, start_m});
518
519 matAcc_t matAcc(0);
520
521 uint32_t inner_loop_count = tile_work.k_iters_remaining;
522
523 gemm_args_t gemm_args(mem_desc_a, mem_desc_b, inner_loop_count);
524 gemm_t gemm;
525
526 gemm(g, matAcc, gemm_args, gemm_slm_base, gemm_nbarr_base);
527
528 epilogue_stream_k(g, matAcc, mem_desc_c, mem_desc_d,
529 mem_desc_atomic_sync, group_idx, first_group_idx,
530 tile_finished, tile_started, args.epilogue_args,
531 epilogue_slm_base, epilogue_nbarr_base);
532 }
533
534 group_iters_remaining -= tile_work.k_iters_remaining;
535 if (group_iters_remaining == 0) { break; }
536
537 //Continue to next tile
538 if (dp_group) {
539 //DP groups consume their tiles at stride
540 tile_idx += workgroup_mapping.avail_xecores;
541 init_dp_tile_work(tile_work, tile_idx, args.matrix_k,
542 group_iters_remaining);
543 } else {
544 //SK groups consume their tiles in backwards order
545 tile_idx--;
546 init_sk_tile_work(tile_work, tile_idx, group_iter_begin,
547 group_iter_begin + group_iters_remaining, args.matrix_k,
548 iters_per_tile, k_stride);
549 }
550 //Tile offset in M and N
551 workgroup_mapping.get_tile_offsets(
552 tile_idx, tile_work.tile_offset_m, tile_work.tile_offset_n);
553 }
554 }
555};
556
558
559} // namespace gpu::xetla::kernel
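A minimal host-side sketch of how this functor is typically sized and launched, assuming gemm_t and epilogue_t have been composed elsewhere and that stream_k_args is an already-populated dispatch_stream_k object; the buffer names (A, B, C, d_buf, sync_buf), leading dimensions, and problem sizes below are illustrative only, and the base arguments are shown as raw device pointers for brevity:

    using namespace gpu::xetla;
    using stream_k_kernel_t = kernel::gemm_universal_t<
            kernel::dispatch_policy_stream_k<gpu_arch::Xe>, gemm_t, epilogue_t>;

    // Size the scratch buffers on the host before launch.
    size_t acc_elems = stream_k_kernel_t::get_acc_buf_size(stream_k_args); // partial-sum scratch (matrix D)
    size_t cnt_elems = stream_k_kernel_t::get_cnt_buf_size(stream_k_args); // cross-group sync flags

    typename stream_k_kernel_t::arguments_t args(matrix_m, matrix_k, matrix_n,
            A, lda, B, ldb, C, ldc, d_buf, ldd, sync_buf, cnt_elems,
            stream_k_args);

    bool ok = stream_k_kernel_t::can_implement(args);
    cl::sycl::nd_range<3> nd_range = stream_k_kernel_t::get_nd_range(args);
    // Inside the SYCL kernel body launched over nd_range, the functor is then
    // invoked as: stream_k_kernel_t kernel; kernel(item, args);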