XeTLA v0.3.6
Intel® Xe Templates for Linear Algebra - API Definition Document
 
kslicing_xe.hpp
/*******************************************************************************
* Copyright (c) 2022-2023 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/

/// @file
/// C++ API

#pragma once

#include "kernel/gemm/api.hpp"
#include "kernel/gemm/common.hpp"
#include "kernel/gemm/dispatch_policy.hpp"

namespace gpu::xetla::kernel {

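/// @brief GEMM_UNIVERSAL functor, specialized for the k-slicing dispatch
/// policy on Xe. The K dimension can be split across work-groups
/// (num_global_kslicing_, combined through a global-memory reduction) and
/// across sub-group sets inside one work-group (num_local_kslicing_,
/// combined through a cooperative reduction in shared local memory).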
template <int num_global_kslicing_, int num_local_kslicing_, typename gemm_t_,
        typename epilogue_t_, typename group_swizzle_>
class gemm_universal_t<dispatch_policy_kslicing<group_swizzle_,
                num_global_kslicing_, num_local_kslicing_>,
        gemm_t_, epilogue_t_,
        std::enable_if_t<(group_swizzle_::arch_tag == gpu_arch::Xe)>> {
    using gemm_t = gemm_t_;
    using epilogue_t = epilogue_t_;
    using gemm_args_t = typename gemm_t::arguments_t;
    using epilogue_args_t = typename epilogue_t::arguments_t;
    using tile_shape = typename gemm_t::tile_shape;
    using group_swizzle_t = group_swizzle_;

    static constexpr uint32_t wg_tile_m = tile_shape::wg_tile_size_y;
    static constexpr uint32_t wg_tile_n = tile_shape::wg_tile_size_x;
    static constexpr uint32_t sg_tile_m = tile_shape::sg_tile_size_y;
    static constexpr uint32_t sg_tile_n = tile_shape::sg_tile_size_x;
    static constexpr uint32_t wg_size_y = tile_shape::wg_size_y;
    static constexpr uint32_t wg_size_x = tile_shape::wg_size_x;
    static constexpr uint32_t real_wg_tile_m = sg_tile_m * wg_size_y;
    static constexpr uint32_t real_wg_tile_n = sg_tile_n * wg_size_x;

    static constexpr uint32_t k_stride = gemm_t::k_stride;
    using work_group_t = typename gemm_t::work_group_t;
    static constexpr uint32_t work_group_size = work_group_t::size;

    static constexpr gpu_arch arch_tag = group_swizzle_t::arch_tag;
    static_assert(arch_tag == gemm_t::arch_tag, "arch_tag should be the same");
    static_assert(
            arch_tag == epilogue_t::arch_tag, "arch_tag should be the same");
    static_assert(std::is_same<typename gemm_t::tile_shape,
                          typename epilogue_t::tile_shape>::value,
            "tile_shape should be the same");

    using mem_desc_a_t = typename gemm_t::mem_desc_a_t;
    using mem_desc_b_t = typename gemm_t::mem_desc_b_t;
    using mem_desc_c_t = typename epilogue_t::mem_desc_c_t;
    using matA_base_t = typename mem_desc_a_t::base_t;
    using matB_base_t = typename mem_desc_b_t::base_t;
    using matC_base_t = typename mem_desc_c_t::base_t;
    using dtype_a = typename mem_desc_a_t::dtype;
    using dtype_b = typename mem_desc_b_t::dtype;
    using dtype_c = typename mem_desc_c_t::dtype;
    using matAcc_t = typename gemm_t::matAcc_t;
    using dtype_acc = typename matAcc_t::dtype;
    using mem_desc_acc_t
            = mem_desc_t<dtype_acc, mem_layout::row_major, mem_space::global>;
    using mem_desc_cnt_t
            = mem_desc_t<uint32_t, mem_layout::row_major, mem_space::global>;
    using acc_base_t = typename mem_desc_acc_t::base_t;
    using cnt_base_t = typename mem_desc_cnt_t::base_t;

    static constexpr uint32_t num_global_kslicing = num_global_kslicing_;
    static constexpr uint32_t num_local_kslicing = num_local_kslicing_;
    static_assert((num_global_kslicing > 0) && (num_local_kslicing > 0),
            "min slicing ratio is 1");

    static_assert((num_local_kslicing & (num_local_kslicing - 1)) == 0,
            "num_local_kslicing should be a power of 2!");
    using kslicing_t = group::cooperative_reduce_t<reduce_op::sum, tile_shape,
            matAcc_t, num_local_kslicing, arch_tag>;
    using mat_slice_t = typename kslicing_t::mat_slice_t;
    static constexpr uint32_t ks_coop_num_x = kslicing_t::coop_num_x;
    static constexpr uint32_t ks_coop_num_y = kslicing_t::coop_num_y;

    static constexpr uint32_t gemm_nbarr_count = gemm_t::barrier_count;
    static constexpr uint32_t gemm_slm_size = gemm_t::slm_size;

    static constexpr uint32_t epilogue_nbarr_count = epilogue_t::barrier_count;
    static constexpr uint32_t epilogue_slm_size = epilogue_t::slm_size;

    static constexpr uint32_t kslicing_nbarr_count = kslicing_t::barrier_count;
    static constexpr uint32_t kslicing_slm_size = kslicing_t::slm_size;

    static constexpr uint32_t counter_size = 8;

    static constexpr uint32_t alignment = 8 / sizeof(dtype_acc);

    using tile_shape_cnt = group::tile_shape_t<ks_coop_num_x * wg_size_x,
            ks_coop_num_y * wg_size_y, ks_coop_num_x, ks_coop_num_y>;

    using global_group_reduce_t = group::global_reduce_t<reduce_op::sum,
            tile_shape, tile_shape_cnt, mem_desc_acc_t, mem_desc_cnt_t,
            num_global_kslicing, counter_size, arch_tag>;

public:
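    /// @brief Kernel arguments: problem sizes, base pointers and leading
    /// dimensions of A/B/C, the scratch buffers used by global k-slicing
    /// (acc_base, cnt_base), and the epilogue arguments.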
    struct arguments_t {
        uint32_t matrix_m;
        uint32_t matrix_k;
        uint32_t matrix_n;
        matA_base_t matA_base;
        uint32_t matA_ld;
        matB_base_t matB_base;
        uint32_t matB_ld;
        matC_base_t matC_base;
        uint32_t matC_ld;
        acc_base_t acc_base;
        cnt_base_t cnt_base;
        epilogue_args_t epilogue_args;

        inline arguments_t() = default;

        static constexpr bool host_callable = true;

        // Be aware of the risks: Rule of three (copy constructor, copy
        // assignment, destructor). Please check whether you need to add a
        // self-defined destructor.
        // ~arguments_t(){}

        /// @brief Constructs arguments with initialization list.
        inline arguments_t(uint32_t matrix_m_, uint32_t matrix_k_,
                uint32_t matrix_n_, matA_base_t matA_base_, uint32_t matA_ld_,
                matB_base_t matB_base_, uint32_t matB_ld_,
                matC_base_t matC_base_, uint32_t matC_ld_,
                acc_base_t acc_base_ = {}, cnt_base_t cnt_base_ = {},
                epilogue_args_t epilogue_args_ = {})
            : matrix_m(matrix_m_)
            , matrix_k(matrix_k_)
            , matrix_n(matrix_n_)
            , matA_base(matA_base_)
            , matA_ld(matA_ld_)
            , matB_base(matB_base_)
            , matB_ld(matB_ld_)
            , matC_base(matC_base_)
            , matC_ld(matC_ld_)
            , acc_base(acc_base_)
            , cnt_base(cnt_base_)
            , epilogue_args(epilogue_args_) {}
        inline arguments_t(const arguments_t &args)
            : matrix_m(args.matrix_m)
            , matrix_k(args.matrix_k)
            , matrix_n(args.matrix_n)
            , matA_base(args.matA_base)
            , matA_ld(args.matA_ld)
            , matB_base(args.matB_base)
            , matB_ld(args.matB_ld)
            , matC_base(args.matC_base)
            , matC_ld(args.matC_ld)
            , acc_base(args.acc_base)
            , cnt_base(args.cnt_base)
            , epilogue_args(args.epilogue_args) {}
        inline arguments_t &operator=(const arguments_t &args) {
            this->matrix_m = args.matrix_m;
            this->matrix_k = args.matrix_k;
            this->matrix_n = args.matrix_n;
            this->matA_base = args.matA_base;
            this->matA_ld = args.matA_ld;
            this->matB_base = args.matB_base;
            this->matB_ld = args.matB_ld;
            this->matC_base = args.matC_base;
            this->matC_ld = args.matC_ld;
            this->acc_base = args.acc_base;
            this->cnt_base = args.cnt_base;
            this->epilogue_args = args.epilogue_args;
            return *this;
        }
    };

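    /// @brief Returns the total number of named barriers one work-group
    /// consumes: a private GEMM and epilogue barrier range per local
    /// k-slice, plus the barriers of the cross-slice reduction.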
    __XETLA_API static constexpr uint32_t get_barrier_count() {
        constexpr uint32_t count = gemm_nbarr_count * num_local_kslicing
                + kslicing_nbarr_count
                + epilogue_nbarr_count * num_local_kslicing;
        static_assert(
                count <= 32, "The named_barrier count should not exceed 32!");
        return count;
    }

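    /// @brief Returns the total shared local memory one work-group consumes,
    /// laid out as per-slice GEMM regions followed by the reduction and
    /// epilogue regions.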
    __XETLA_API static constexpr uint32_t get_slm_size() {
        constexpr uint32_t size = gemm_slm_size * num_local_kslicing
                + kslicing_slm_size + epilogue_slm_size * num_local_kslicing;
        static_assert(size <= (128 * 1024),
                "The local memory size should not exceed 128KB!");
        return size;
    }

    /// @brief Host helper function to get the expected local range under the
    /// current GEMM_UNIVERSAL config.
    static cl::sycl::range<3> get_local_range() {
        uint32_t local_range_m = (wg_tile_m + sg_tile_m - 1) / sg_tile_m;
        uint32_t local_range_n = (wg_tile_n + sg_tile_n - 1) / sg_tile_n;
        std::cout << "Local range: {" << num_local_kslicing << ", "
                  << local_range_m << ", " << local_range_n << "}\n";
        assert(local_range_m * local_range_n * num_local_kslicing <= 32);
        return cl::sycl::range<3> {
                num_local_kslicing, local_range_m, local_range_n};
    }

    /// @brief Host helper function to get the expected group range under the
    /// current GEMM_UNIVERSAL config.
    static cl::sycl::range<3> get_group_range(
            uint32_t matrix_m, uint32_t matrix_n) {
        uint32_t group_range_m = (matrix_m + wg_tile_m - 1) / wg_tile_m;
        uint32_t group_range_n = (matrix_n + wg_tile_n - 1) / wg_tile_n;
        group_swizzle_t::update_group_range(group_range_m, group_range_n);
        std::cout << "Group range: {" << num_global_kslicing << ", "
                  << group_range_m << ", " << group_range_n << "}\n";
        return cl::sycl::range<3> {
                num_global_kslicing, group_range_m, group_range_n};
    }

    /// @brief Host helper function to get the expected nd_range of the
    /// current GEMM_UNIVERSAL config.
    static cl::sycl::nd_range<3> get_nd_range(arguments_t &args) {
        cl::sycl::range<3> local_range = get_local_range();
        cl::sycl::range<3> group_range
                = get_group_range(args.matrix_m, args.matrix_n);
        return cl::sycl::nd_range<3> {group_range * local_range, local_range};
    }

    /// @brief Host helper function to get the expected accumulation buffer
    /// size of the current GEMM_UNIVERSAL config.
    static size_t get_acc_buf_size(uint32_t matrix_m, uint32_t matrix_n) {
        size_t aligned_n = (matrix_n + alignment - 1) / alignment * alignment;
        return matrix_m * aligned_n;
    }

    /// @brief Host helper function to get the expected counter buffer size of
    /// the current GEMM_UNIVERSAL config.
    static size_t get_cnt_buf_size(uint32_t matrix_m, uint32_t matrix_n) {
        size_t group_range_m = (matrix_m + wg_tile_m - 1) / wg_tile_m;
        size_t group_range_n = (matrix_n + wg_tile_n - 1) / wg_tile_n;
        return group_range_m * group_range_n * wg_size_y * wg_size_x
                * ks_coop_num_y * ks_coop_num_x * counter_size;
    }

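    /// @brief Host helper that checks whether the given problem is
    /// implementable: validates the 2D-block tensor constraints or the
    /// 1D-message alignment for the A, B, and C surfaces.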
    static bool can_implement(arguments_t &args) {
        bool implementable = true;
        if (gemm_t::msg_type_a != msg_type::unaligned_2d) {
            if (gemm_t::msg_type_a == msg_type::block_2d) {
                implementable &= kernel::block_2d<gpu_arch::Xe,
                        dtype_a>::check_tensor((uint64_t)(args.matA_base.base),
                        gemm_t::is_col_major_a ? args.matrix_m : args.matrix_k,
                        gemm_t::is_col_major_a ? args.matrix_k : args.matrix_m,
                        args.matA_ld);
            } else {
                implementable &= kernel::general_1d<gpu_arch::Xe,
                        dtype_a>::check_alignment(args.matA_base.base,
                        args.matA_ld);
            }
        }
        if (gemm_t::msg_type_b != msg_type::unaligned_2d) {
            if (gemm_t::msg_type_b == msg_type::block_2d) {
                implementable &= kernel::block_2d<gpu_arch::Xe,
                        dtype_b>::check_tensor((uint64_t)(args.matB_base.base),
                        gemm_t::is_col_major_b ? args.matrix_k : args.matrix_n,
                        gemm_t::is_col_major_b ? args.matrix_n : args.matrix_k,
                        args.matB_ld);
            } else {
                implementable &= kernel::general_1d<gpu_arch::Xe,
                        dtype_b>::check_alignment(args.matB_base.base,
                        args.matB_ld);
            }
        }
        if (epilogue_t::msg_type_c != msg_type::unaligned_2d) {
            if (epilogue_t::msg_type_c == msg_type::block_2d) {
                implementable &= kernel::block_2d<gpu_arch::Xe,
                        dtype_c>::check_tensor((uint64_t)(args.matC_base.base),
                        args.matrix_n, args.matrix_m, args.matC_ld);
            } else {
                implementable &= kernel::general_1d<gpu_arch::Xe,
                        dtype_c>::check_alignment(args.matC_base.base,
                        args.matC_ld);
            }
        }

        return implementable;
    }

    /// @brief Main execution function for GEMM_UNIVERSAL.
    __XETLA_API KERNEL_FUNC void operator()(sycl::nd_item<3> &item,
            const arguments_t &args, uint32_t slm_base = 0,
            uint32_t nbarrier_base = 0) {
        // set up workgroup level coordinates and boundaries
        work_group_t g(item.get_local_linear_id() % work_group_size);
        uint32_t wg_id = item.get_local_linear_id() / work_group_size;
        group_swizzle_t group_swizzle;
        int start_m = group_swizzle.template get_tile_idx<1>(item) * wg_tile_m;
        int start_n = group_swizzle.template get_tile_idx<2>(item) * wg_tile_n;
        int start_k = 0;
        uint32_t wg_tile_k = args.matrix_k;
        uint32_t boundary_n = (start_n + wg_tile_n) > args.matrix_n
                ? args.matrix_n
                : (start_n + wg_tile_n);
        uint32_t boundary_m = (start_m + wg_tile_m) > args.matrix_m
                ? args.matrix_m
                : (start_m + wg_tile_m);
        uint32_t boundary_k = wg_tile_k;
        if constexpr (num_global_kslicing > 1) {
            wg_tile_k = (wg_tile_k + num_global_kslicing - 1)
                    / num_global_kslicing;
            start_k = start_k
                    + group_swizzle.template get_tile_idx<0>(item) * wg_tile_k;
            boundary_k = (start_k + wg_tile_k) > boundary_k
                    ? boundary_k
                    : (start_k + wg_tile_k);
        }
        if constexpr (num_local_kslicing > 1) {
            wg_tile_k = (wg_tile_k + num_local_kslicing - 1)
                    / num_local_kslicing;
            start_k = start_k + wg_id * wg_tile_k;
            boundary_k = (start_k + wg_tile_k) > boundary_k
                    ? boundary_k
                    : (start_k + wg_tile_k);
        }

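        // Carve up SLM and named barriers: each local k-slice owns a private
        // GEMM region; the reduction scratch and the epilogue region are
        // placed after all per-slice GEMM regions.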
        // set up arguments
        uint32_t gemm_slm_base = slm_base;
        uint32_t gemm_nbarr_base = nbarrier_base;
        if constexpr (num_local_kslicing > 1) {
            gemm_slm_base = slm_base + wg_id * gemm_slm_size;
            gemm_nbarr_base = nbarrier_base + wg_id * gemm_nbarr_count;
        }
        uint32_t kslicing_slm_base
                = slm_base + num_local_kslicing * gemm_slm_size;
        uint32_t kslicing_nbarr_base
                = nbarrier_base + num_local_kslicing * gemm_nbarr_count;
        uint32_t epilogue_slm_base = kslicing_slm_base + kslicing_slm_size;
        uint32_t epilogue_nbarr_base
                = kslicing_nbarr_base + kslicing_nbarr_count;

        mem_desc_a_t mem_desc_a;
        mem_desc_b_t mem_desc_b;
        mem_desc_c_t mem_desc_c;
        // setup for matA
        if constexpr (mem_desc_a_t::is_local) {
            mem_desc_a.init(args.matA_base,
                    {wg_tile_k, real_wg_tile_m, wg_tile_k}, {0, 0});
        } else {
            mem_desc_a.init(args.matA_base,
                    {boundary_k, boundary_m, args.matA_ld}, {start_k, start_m});
        }
        // setup for matB
        if constexpr (mem_desc_b_t::is_local) {
            mem_desc_b.init(args.matB_base,
                    {real_wg_tile_n, wg_tile_k, real_wg_tile_n}, {0, 0});
        } else {
            mem_desc_b.init(args.matB_base,
                    {boundary_n, boundary_k, args.matB_ld}, {start_n, start_k});
        }

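        // Each slice now runs the inner GEMM over its own [start_k,
        // boundary_k) range in steps of k_stride, accumulating into matAcc.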
        uint32_t inner_loop_count = (wg_tile_k + k_stride - 1) / k_stride;
        gemm_args_t gemm_args(mem_desc_a, mem_desc_b, inner_loop_count);
        matAcc_t matAcc;
        matAcc.init(0);
        gemm_t gemm;
        gemm(g, matAcc, gemm_args, gemm_slm_base, gemm_nbarr_base);

        kslicing_t kslicing(wg_id);
        mat_slice_t mat_slice;
        kslicing(g, mat_slice, matAcc, kslicing_slm_base, kslicing_nbarr_base);
        if (kslicing.is_valid_post_process_wg()) {
            // setup for matC
            // set up cooperative offset for matC store
            int32_t coop_offset_x
                    = kslicing.coop_id_x * mat_slice_t::tile_size_x;
            int32_t coop_offset_y
                    = kslicing.coop_id_y * mat_slice_t::tile_size_y;
            int32_t acc_start_x = start_n + coop_offset_x;
            int32_t acc_start_y = start_m + coop_offset_y;
            int32_t cnt_start_x = group_swizzle.template get_tile_idx<2>(item)
                            * tile_shape_cnt::wg_tile_size_x
                    + kslicing.coop_id_x;
            int32_t cnt_start_y = group_swizzle.template get_tile_idx<1>(item)
                            * tile_shape_cnt::wg_tile_size_y
                    + kslicing.coop_id_y;
            uint32_t group_range_x = item.get_group_range(2);
            uint32_t group_range_y = item.get_group_range(1);
            uint32_t cnt_size_x
                    = group_range_x * tile_shape_cnt::wg_tile_size_x;
            uint32_t cnt_size_y
                    = group_range_y * tile_shape_cnt::wg_tile_size_y;

            uint32_t acc_aligned_n
                    = (args.matrix_n + alignment - 1) / alignment * alignment;

            uint32_t acc_boundary_n = (start_n + wg_tile_n) > acc_aligned_n
                    ? acc_aligned_n
                    : start_n + wg_tile_n;

            mem_desc_acc_t mem_desc_acc(args.acc_base,
                    {acc_boundary_n, boundary_m, acc_aligned_n},
                    {acc_start_x, acc_start_y});
            mem_desc_cnt_t mem_desc_cnt(args.cnt_base,
                    {cnt_size_x, cnt_size_y, cnt_size_x},
                    {cnt_start_x, cnt_start_y});

            global_group_reduce_t global_group_reduce;
            global_group_reduce(g, mat_slice, mem_desc_acc, mem_desc_cnt);

            if (global_group_reduce.is_last_group()) {
                if constexpr (mem_desc_c_t::is_local) {
                    mem_desc_c.init(args.matC_base,
                            {real_wg_tile_n, real_wg_tile_m, real_wg_tile_n},
                            {coop_offset_x, coop_offset_y});
                } else {
                    mem_desc_c.init(args.matC_base,
                            {boundary_n, boundary_m, args.matC_ld},
                            {start_n + coop_offset_x, start_m + coop_offset_y});
                }
                epilogue_t epilogue;
                epilogue(g, mat_slice, mem_desc_c, args.epilogue_args,
                        epilogue_slm_base, epilogue_nbarr_base);
            }
        }
    }
};

} // namespace gpu::xetla::kernel
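
For orientation, a minimal host-side launch sketch follows. It is not part of the header: launch_kslicing_gemm is a hypothetical helper, kernel_t stands for any concrete gemm_universal_t<dispatch_policy_kslicing<...>, ...> specialization, and the SYCL_ESIMD_KERNEL attribute and include paths are assumptions that may differ across XeTLA versions. The acc_base/cnt_base scratch buffers passed in arguments_t are expected to be sized with get_acc_buf_size/get_cnt_buf_size, and the counter buffer is typically zero-filled before launch so that the last-group detection of the global reduction starts from a clean state.

#include <stdexcept>

#include <sycl/sycl.hpp>
#include <sycl/ext/intel/esimd.hpp> // provides SYCL_ESIMD_KERNEL

// Hypothetical launcher: kernel_t is a gemm_universal_t specialization with
// the kslicing dispatch policy, as defined in kslicing_xe.hpp above.
template <typename kernel_t>
sycl::event launch_kslicing_gemm(
        sycl::queue &q, typename kernel_t::arguments_t args) {
    // Reject unsupported shapes/alignments before touching the device.
    if (!kernel_t::can_implement(args)) {
        throw std::runtime_error("GEMM config not implementable");
    }
    // The nd_range encodes {num_global_kslicing, group_m, group_n} groups and
    // {num_local_kslicing, sg_m, sg_n} work-items per group.
    sycl::nd_range<3> nd_range = kernel_t::get_nd_range(args);
    return q.submit([&](sycl::handler &cgh) {
        cgh.parallel_for(nd_range,
                [=](sycl::nd_item<3> item) SYCL_ESIMD_KERNEL {
                    kernel_t kernel;
                    kernel(item, args);
                });
    });
}

Because arguments_t declares host_callable = true and provides value semantics, it can be built on the host and captured by value into the device lambda as shown.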