XeTLA v0.3.6
Intel® Xe Templates for Linear Algebra - API Definition Document
 
int4_dequantize_kslicing_xe.hpp
/*******************************************************************************
* Copyright (c) 2022-2023 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/

/// @file
/// C++ API

#pragma once

#include "kernel/gemm/api.hpp"
#include "kernel/gemm/common.hpp"
#include "kernel/gemm/dispatch_policy.hpp"

namespace gpu::xetla::kernel {

/// @brief GEMM_UNIVERSAL functor, specialized for the int4-dequantize
/// k-slicing dispatch policy.
template <int num_global_kslicing_, int num_local_kslicing_, typename gemm_t_,
        typename epilogue_t_, typename group_swizzle_>
class gemm_universal_t<dispatch_policy_int4_dequantize_kslicing<group_swizzle_,
                num_global_kslicing_, num_local_kslicing_>,
        gemm_t_, epilogue_t_> {
    using gemm_t = gemm_t_;
    using epilogue_t = epilogue_t_;
    using gemm_args_t = typename gemm_t::arguments_t;
    using epilogue_args_t = typename epilogue_t::arguments_t;
    using tile_shape = typename gemm_t::tile_shape;
    using group_swizzle_t = group_swizzle_;
    static constexpr uint32_t wg_tile_m = tile_shape::wg_tile_size_y;
    static constexpr uint32_t wg_tile_n = tile_shape::wg_tile_size_x;
    static constexpr uint32_t sg_tile_m = tile_shape::sg_tile_size_y;
    static constexpr uint32_t sg_tile_n = tile_shape::sg_tile_size_x;
    static constexpr uint32_t wg_size_y = tile_shape::wg_size_y;
    static constexpr uint32_t wg_size_x = tile_shape::wg_size_x;
    static constexpr uint32_t real_wg_tile_m = sg_tile_m * wg_size_y;
    static constexpr uint32_t real_wg_tile_n = sg_tile_n * wg_size_x;

    static constexpr uint32_t k_stride = gemm_t::k_stride;
    static constexpr uint32_t dequant_s = gemm_t::dequant_s;
    static constexpr uint32_t pack_ratio = gemm_t::pack_ratio;
    using work_group_t = typename gemm_t::work_group_t;
    static constexpr uint32_t work_group_size = work_group_t::size;

    static constexpr gpu_arch arch_tag = gpu_arch::Xe;
    static_assert(arch_tag == gemm_t::arch_tag, "arch_tag should be the same");
    static_assert(
            arch_tag == epilogue_t::arch_tag, "arch_tag should be the same");
    static_assert(std::is_same<typename gemm_t::tile_shape,
                          typename epilogue_t::tile_shape>::value,
            "tile_shape should be the same");

    using mem_desc_a_t = typename gemm_t::mem_desc_a_t;
    using mem_desc_b_t = typename gemm_t::mem_desc_b_t;
    using mem_desc_scale_t = typename gemm_t::mem_desc_scale_t;
    using mem_desc_zero_pt_t = typename gemm_t::mem_desc_zero_pt_t;
    using mem_desc_c_t = typename epilogue_t::mem_desc_c_t;
    using matA_base_t = typename mem_desc_a_t::base_t;
    using matB_base_t = typename mem_desc_b_t::base_t;
    using matC_base_t = typename mem_desc_c_t::base_t;
    using scale_base_t = typename mem_desc_scale_t::base_t;
    using zero_pt_base_t = typename mem_desc_zero_pt_t::base_t;

    using dtype_a = typename mem_desc_a_t::dtype;
    using dtype_b = typename mem_desc_b_t::dtype;
    using dtype_c = typename mem_desc_c_t::dtype;
    using dtype_scale = typename mem_desc_scale_t::dtype;
    using dtype_zero_pt = typename mem_desc_zero_pt_t::dtype;
    using matAcc_t = typename gemm_t::matAcc_t;
    using dtype_acc = typename matAcc_t::dtype;
    using mem_desc_acc_t
            = mem_desc_t<dtype_acc, mem_layout::row_major, mem_space::global>;
    using mem_desc_cnt_t
            = mem_desc_t<uint32_t, mem_layout::row_major, mem_space::global>;
    using acc_base_t = typename mem_desc_acc_t::base_t;
    using cnt_base_t = typename mem_desc_cnt_t::base_t;

    static_assert(gemm_t::compute_policy::is_int4_matB_policy,
            "should match with 4bit gemm impl");

    static constexpr uint32_t num_global_kslicing = num_global_kslicing_;
    static constexpr uint32_t num_local_kslicing = num_local_kslicing_;
    static_assert((num_global_kslicing > 0) && (num_local_kslicing > 0),
            "min slicing ratio is 1");

    static_assert((num_local_kslicing & (num_local_kslicing - 1)) == 0,
            "num_local_kslicing should be a power of 2!");

    using kslicing_t = group::cooperative_reduce_t<reduce_op::sum, tile_shape,
            matAcc_t, num_local_kslicing, gpu_arch::Xe>;
    using mat_slice_t = typename kslicing_t::mat_slice_t;
    static constexpr uint32_t ks_coop_num_x = kslicing_t::coop_num_x;
    static constexpr uint32_t ks_coop_num_y = kslicing_t::coop_num_y;

    static constexpr uint32_t gemm_nbarr_count = gemm_t::barrier_count;
    static constexpr uint32_t gemm_slm_size = gemm_t::slm_size;

    static constexpr uint32_t epilogue_nbarr_count = epilogue_t::barrier_count;
    static constexpr uint32_t epilogue_slm_size = epilogue_t::slm_size;

    static constexpr uint32_t kslicing_nbarr_count = kslicing_t::barrier_count;
    static constexpr uint32_t kslicing_slm_size = kslicing_t::slm_size;

    static constexpr uint32_t counter_size = 8;

    using tile_shape_cnt = group::tile_shape_t<ks_coop_num_x * wg_size_x,
            ks_coop_num_y * wg_size_y, ks_coop_num_x, ks_coop_num_y>;

    using global_group_reduce_t = group::global_reduce_t<reduce_op::sum,
            tile_shape, tile_shape_cnt, mem_desc_acc_t, mem_desc_cnt_t,
            num_global_kslicing, counter_size, gpu_arch::Xe>;

public:
    /// @brief GEMM arguments.
    struct arguments_t {
        /// Is the size of the m dimension of the matrix multiplication (m x k x n).
        uint32_t matrix_m;
        /// Is the size of the k dimension of the matrix multiplication (m x k x n).
        uint32_t matrix_k;
        /// Is the size of the n dimension of the matrix multiplication (m x k x n).
        uint32_t matrix_n;
        /// Is the leading dimension (pitch) size of the matrix A in memory.
        uint32_t matA_ld;
        /// Is the leading dimension (pitch) size of the matrix B in memory.
        uint32_t matB_ld;
        /// Is the leading dimension (pitch) size of the matrix C in memory.
        uint32_t matC_ld;
        /// Is the base address of matrix A.
        matA_base_t matA_base;
        /// Is the base address of matrix B.
        matB_base_t matB_base;
        /// Is the base address of matrix C.
        matC_base_t matC_base;
        /// Is the base address of the accumulation buffer.
        acc_base_t acc_base;
        /// Is the base address of the counter buffer.
        cnt_base_t cnt_base;
        /// Is the epilogue arguments.
        epilogue_args_t epilogue_args;

        scale_base_t scale_base;
        zero_pt_base_t zero_pt_base;
        uint32_t scale_ld;
        uint32_t zero_pt_ld;

        /// @brief Default construct.
        inline arguments_t() = default;

        static constexpr bool host_callable = true;

        // Be aware of the risks: rule of three (copy constructor, copy
        // assignment, destructor). Check whether a user-defined destructor
        // is needed before adding one.
        // ~arguments_t(){}

        /// @brief Constructs arguments with initialization list.
        inline arguments_t(uint32_t matrix_m_, uint32_t matrix_k_,
                uint32_t matrix_n_, matA_base_t matA_base_, uint32_t matA_ld_,
                matB_base_t matB_base_, uint32_t matB_ld_,
                matC_base_t matC_base_, uint32_t matC_ld_,
                scale_base_t scale_base_, uint32_t scale_ld_,
                zero_pt_base_t zero_pt_base_, uint32_t zero_pt_ld_,
                acc_base_t acc_base_ = {}, cnt_base_t cnt_base_ = {},
                epilogue_args_t epilogue_args_ = {})
            : matrix_m(matrix_m_)
            , matrix_k(matrix_k_)
            , matrix_n(matrix_n_)
            , matA_ld(matA_ld_)
            , matB_ld(matB_ld_)
            , matC_ld(matC_ld_)
            , matA_base(matA_base_)
            , matB_base(matB_base_)
            , matC_base(matC_base_)
            , acc_base(acc_base_)
            , cnt_base(cnt_base_)
            , epilogue_args(epilogue_args_)
            , scale_base(scale_base_)
            , zero_pt_base(zero_pt_base_)
            , scale_ld(scale_ld_)
            , zero_pt_ld(zero_pt_ld_) {}
        inline arguments_t(const arguments_t &args)
            : matrix_m(args.matrix_m)
            , matrix_k(args.matrix_k)
            , matrix_n(args.matrix_n)
            , matA_ld(args.matA_ld)
            , matB_ld(args.matB_ld)
            , matC_ld(args.matC_ld)
            , matA_base(args.matA_base)
            , matB_base(args.matB_base)
            , matC_base(args.matC_base)
            , acc_base(args.acc_base)
            , cnt_base(args.cnt_base)
            , epilogue_args(args.epilogue_args)
            , scale_base(args.scale_base)
            , zero_pt_base(args.zero_pt_base)
            , scale_ld(args.scale_ld)
            , zero_pt_ld(args.zero_pt_ld) {}
        // Be aware of the risks: rule of three (copy constructor, copy
        // assignment, destructor). Check whether a user-defined destructor
        // is needed before adding one.
        // inline ~arguments_t(){}
        inline arguments_t &operator=(const arguments_t &args) {
            this->matrix_m = args.matrix_m;
            this->matrix_k = args.matrix_k;
            this->matrix_n = args.matrix_n;
            this->matA_base = args.matA_base;
            this->matA_ld = args.matA_ld;
            this->matB_base = args.matB_base;
            this->matB_ld = args.matB_ld;
            this->matC_base = args.matC_base;
            this->matC_ld = args.matC_ld;
            this->scale_base = args.scale_base;
            this->scale_ld = args.scale_ld;
            this->zero_pt_base = args.zero_pt_base;
            this->zero_pt_ld = args.zero_pt_ld;
            this->acc_base = args.acc_base;
            this->cnt_base = args.cnt_base;
            this->epilogue_args = args.epilogue_args;
            return *this;
        }
    };

    /// @brief Gets named_barrier id consumption count.
    __XETLA_API static constexpr uint32_t get_barrier_count() {
        constexpr uint32_t count = gemm_nbarr_count * num_local_kslicing
                + kslicing_nbarr_count
                + epilogue_nbarr_count * num_local_kslicing;
        static_assert(
                count <= 32, "The named_barrier count must not exceed 32!");
        return count;
    }

    /// @brief Gets local memory size consumption.
    __XETLA_API static constexpr uint32_t get_slm_size() {
        constexpr uint32_t size = gemm_slm_size * num_local_kslicing
                + kslicing_slm_size + epilogue_slm_size * num_local_kslicing;
        static_assert(size <= (128 * 1024),
                "The local memory size must not exceed 128KB!");
        return size;
    }
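
    // Illustrative budget check (hypothetical sizes, not taken from this
    // file): with num_local_kslicing = 2, gemm_slm_size = 16 KB,
    // kslicing_slm_size = 8 KB and epilogue_slm_size = 0, get_slm_size()
    // yields 2 * 16 KB + 8 KB + 0 = 40 KB, well under the 128 KB limit.
    // get_barrier_count() follows the same shape: each of the
    // num_local_kslicing gemm instances gets its own barrier slice, plus the
    // shared k-slicing reduction barriers.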

    /// @brief Host helper function to get the expected local range under the current GEMM config.
    static cl::sycl::range<3> get_local_range() {
        uint32_t local_range_m = (wg_tile_m + sg_tile_m - 1) / sg_tile_m;
        uint32_t local_range_n = (wg_tile_n + sg_tile_n - 1) / sg_tile_n;
        std::cout << "Local range: {" << num_local_kslicing << ", "
                  << local_range_m << ", " << local_range_n << "} \n";
        assert(local_range_m * local_range_n * num_local_kslicing <= 32);
        return cl::sycl::range<3> {
                num_local_kslicing, local_range_m, local_range_n};
    }
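
    // Illustrative numbers (hypothetical tile config): a 256x256 workgroup
    // tile with 32x64 subgroup tiles gives local_range_m = 256 / 32 = 8 and
    // local_range_n = 256 / 64 = 4, i.e. 8 * 4 = 32 subgroups. That already
    // meets the asserted product limit of 32, so any num_local_kslicing > 1
    // would require a smaller workgroup tile or larger subgroup tiles.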

    /// @brief Host helper function to get the expected group range under the current GEMM config.
    static cl::sycl::range<3> get_group_range(
            uint32_t matrix_m, uint32_t matrix_n) {
        uint32_t group_range_m = (matrix_m + wg_tile_m - 1) / wg_tile_m;
        uint32_t group_range_n = (matrix_n + wg_tile_n - 1) / wg_tile_n;
        group_swizzle_t::update_group_range(group_range_m, group_range_n);
        std::cout << "Group range: {" << num_global_kslicing << ", "
                  << group_range_m << ", " << group_range_n << "} \n";
        return cl::sycl::range<3> {
                num_global_kslicing, group_range_m, group_range_n};
    }

    /// @brief Host helper function to get the expected nd_range under the current GEMM config.
    static cl::sycl::nd_range<3> get_nd_range(arguments_t &args) {
        cl::sycl::range<3> local_range = get_local_range();
        cl::sycl::range<3> group_range
                = get_group_range(args.matrix_m, args.matrix_n);
        return cl::sycl::nd_range<3> {group_range * local_range, local_range};
    }

    /// @brief Host helper function to get the expected accumulation buffer size of the current GEMM config.
    static size_t get_acc_buf_size(uint32_t matrix_m, uint32_t matrix_n) {
        return matrix_m * matrix_n;
    }

    /// @brief Host helper function to get the expected counter buffer size of the current GEMM config.
    static size_t get_cnt_buf_size(uint32_t matrix_m, uint32_t matrix_n) {
        size_t group_range_m = (matrix_m + wg_tile_m - 1) / wg_tile_m;
        size_t group_range_n = (matrix_n + wg_tile_n - 1) / wg_tile_n;
        return group_range_m * group_range_n * wg_size_y * wg_size_x
                * ks_coop_num_y * ks_coop_num_x * counter_size;
    }
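
    // Illustrative sizing (hypothetical config): matrix_m = matrix_n = 4096
    // with 256x256 workgroup tiles gives a 16 x 16 group grid; with
    // wg_size_y x wg_size_x = 4 x 8 and ks_coop_num 2 x 2, each group owns
    // 4 * 8 * 2 * 2 = 128 counter slots, so get_cnt_buf_size returns
    // 16 * 16 * 128 * 8 (counter_size) = 262144 entries.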

    /// @brief Check if the arguments can be implemented.
    static bool can_implement(arguments_t &args) {
        bool implementable = true;
        if (gemm_t::msg_type_a != msg_type::unaligned_2d) {
            if (gemm_t::msg_type_a == msg_type::block_2d) {
                implementable &= kernel::block_2d<gpu_arch::Xe,
                        dtype_a>::check_tensor((uint64_t)(args.matA_base.base),
                        gemm_t::is_col_major_a ? args.matrix_m : args.matrix_k,
                        gemm_t::is_col_major_a ? args.matrix_k : args.matrix_m,
                        args.matA_ld);
            } else {
                implementable &= kernel::general_1d<gpu_arch::Xe,
                        dtype_a>::check_alignment(args.matA_base.base,
                        args.matA_ld);
            }
        }
        if (gemm_t::msg_type_b != msg_type::unaligned_2d) {
            if (gemm_t::msg_type_b == msg_type::block_2d) {
                implementable &= kernel::block_2d<gpu_arch::Xe,
                        dtype_b>::check_tensor((uint64_t)(args.matB_base.base),
                        args.matB_ld / pack_ratio,
                        gemm_t::is_col_major_b ? args.matrix_n : args.matrix_k,
                        args.matB_ld / pack_ratio);
            } else {
                implementable &= kernel::general_1d<gpu_arch::Xe,
                        dtype_b>::check_alignment(args.matB_base.base,
                        args.matB_ld / pack_ratio);
            }
        }
        if (epilogue_t::msg_type_c != msg_type::unaligned_2d) {
            if (epilogue_t::msg_type_c == msg_type::block_2d) {
                implementable &= kernel::block_2d<gpu_arch::Xe,
                        dtype_c>::check_tensor((uint64_t)(args.matC_base.base),
                        args.matrix_n, args.matrix_m, args.matC_ld);
            } else {
                implementable &= kernel::general_1d<gpu_arch::Xe,
                        dtype_c>::check_alignment(args.matC_base.base,
                        args.matC_ld);
            }
        }
        // check for int4x2 packing
        implementable &= ((args.matB_ld % pack_ratio == 0)
                && (args.zero_pt_ld % pack_ratio == 0)
                && (args.matrix_n % pack_ratio == 0));

        return implementable;
    }
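
    // Note on the int4x2 check (explanatory, not from this file): pack_ratio
    // counts how many 4-bit values share one storage element of dtype_b
    // (e.g. 2 for a byte-sized int4x2 element), so matB_ld, zero_pt_ld and
    // matrix_n must all be multiples of it; otherwise a row would end in the
    // middle of a packed element.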

    /// @brief Main execution function for GEMM.
    __XETLA_API KERNEL_FUNC void operator()(sycl::nd_item<3> &item,
            const arguments_t &args, uint32_t slm_base = 0,
            uint32_t nbarrier_base = 0) {
        // set up workgroup level coordinates and boundaries
        work_group_t g(item.get_local_linear_id() % work_group_size);
        uint32_t wg_id = item.get_local_linear_id() / work_group_size;
        group_swizzle_t group_swizzle;
        int start_m = group_swizzle.template get_tile_idx<1>(item) * wg_tile_m;
        int start_n = group_swizzle.template get_tile_idx<2>(item) * wg_tile_n;
        int start_k = 0;
        uint32_t wg_tile_k = args.matrix_k;
        uint32_t boundary_n = (start_n + wg_tile_n) > args.matrix_n
                ? args.matrix_n
                : (start_n + wg_tile_n);
        uint32_t boundary_m = (start_m + wg_tile_m) > args.matrix_m
                ? args.matrix_m
                : (start_m + wg_tile_m);
        uint32_t boundary_k = wg_tile_k;
        if constexpr (num_global_kslicing > 1) {
            wg_tile_k = (wg_tile_k + num_global_kslicing - 1)
                    / num_global_kslicing;
            start_k = start_k
                    + group_swizzle.template get_tile_idx<0>(item) * wg_tile_k;
            boundary_k = (start_k + wg_tile_k) > boundary_k
                    ? boundary_k
                    : (start_k + wg_tile_k);
        }
        if constexpr (num_local_kslicing > 1) {
            wg_tile_k = (wg_tile_k + num_local_kslicing - 1)
                    / num_local_kslicing;
            start_k = start_k + wg_id * wg_tile_k;
            boundary_k = (start_k + wg_tile_k) > boundary_k
                    ? boundary_k
                    : (start_k + wg_tile_k);
        }
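
        // Illustrative split (hypothetical shapes): matrix_k = 4096 with
        // num_global_kslicing = 2 first halves wg_tile_k to 2048, giving the
        // two groups start_k = 0 and 2048; num_local_kslicing = 2 then halves
        // it again to 1024, so the four k-slices cover [0, 1024),
        // [1024, 2048), [2048, 3072) and [3072, 4096). boundary_k is clamped
        // above so a ragged final slice never reads past matrix_k.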

        int start_x_scale = start_n;
        int start_y_scale = start_k / dequant_s;

        int start_x_zero_pt = start_n / pack_ratio;
        int start_y_zero_pt = start_k / dequant_s;

        // set up arguments
        uint32_t gemm_slm_base = slm_base;
        uint32_t gemm_nbarr_base = nbarrier_base;
        if constexpr (num_local_kslicing > 1) {
            gemm_slm_base = slm_base + wg_id * gemm_slm_size;
            gemm_nbarr_base = nbarrier_base + wg_id * gemm_nbarr_count;
        }
        uint32_t kslicing_slm_base
                = slm_base + num_local_kslicing * gemm_slm_size;
        uint32_t kslicing_nbarr_base
                = nbarrier_base + num_local_kslicing * gemm_nbarr_count;
        uint32_t epilogue_slm_base = kslicing_slm_base + kslicing_slm_size;
        uint32_t epilogue_nbarr_base
                = kslicing_nbarr_base + kslicing_nbarr_count;

        mem_desc_a_t mem_desc_a;
        mem_desc_b_t mem_desc_b;
        mem_desc_c_t mem_desc_c;
        // set up matA
        mem_desc_a.init(args.matA_base, {boundary_k, boundary_m, args.matA_ld},
                {start_k, start_m});
        mem_desc_b.init(args.matB_base,
                {boundary_n / pack_ratio, boundary_k,
                        args.matB_ld / pack_ratio},
                {int(start_n / pack_ratio), start_k});

        uint32_t scale_size_y = ((args.matrix_k + dequant_s - 1) / dequant_s);
        mem_desc_scale_t mem_desc_scale(args.scale_base,
                {args.matrix_n, scale_size_y, args.scale_ld},
                {start_x_scale, start_y_scale});
        mem_desc_zero_pt_t mem_desc_zero_pt(args.zero_pt_base,
                {args.matrix_n / pack_ratio, scale_size_y,
                        args.zero_pt_ld / pack_ratio},
                {start_x_zero_pt, start_y_zero_pt});

        uint32_t inner_loop_count = (wg_tile_k + k_stride - 1) / k_stride;
        gemm_args_t gemm_args(mem_desc_a, mem_desc_b, inner_loop_count,
                mem_desc_scale, mem_desc_zero_pt);
        matAcc_t matAcc;
        matAcc.init(0);
        gemm_t gemm;
        gemm(g, matAcc, gemm_args, gemm_slm_base, gemm_nbarr_base);

        kslicing_t kslicing(wg_id);
        mat_slice_t mat_slice;
        kslicing(g, mat_slice, matAcc, kslicing_slm_base, kslicing_nbarr_base);

        if (kslicing.is_valid_post_process_wg()) {
            // set up cooperative offsets for the matC store
            int32_t coop_offset_x
                    = kslicing.coop_id_x * mat_slice_t::tile_size_x;
            int32_t coop_offset_y
                    = kslicing.coop_id_y * mat_slice_t::tile_size_y;
            int32_t acc_start_x = start_n + coop_offset_x;
            int32_t acc_start_y = start_m + coop_offset_y;
            int32_t cnt_start_x = group_swizzle.template get_tile_idx<2>(item)
                            * tile_shape_cnt::wg_tile_size_x
                    + kslicing.coop_id_x;
            int32_t cnt_start_y = group_swizzle.template get_tile_idx<1>(item)
                            * tile_shape_cnt::wg_tile_size_y
                    + kslicing.coop_id_y;
            uint32_t group_range_x = item.get_group_range(2);
            uint32_t group_range_y = item.get_group_range(1);
            uint32_t cnt_size_x
                    = group_range_x * tile_shape_cnt::wg_tile_size_x;
            uint32_t cnt_size_y
                    = group_range_y * tile_shape_cnt::wg_tile_size_y;
            mem_desc_acc_t mem_desc_acc(args.acc_base,
                    {boundary_n, boundary_m, args.matrix_n},
                    {acc_start_x, acc_start_y});
            mem_desc_cnt_t mem_desc_cnt(args.cnt_base,
                    {cnt_size_x, cnt_size_y, cnt_size_x},
                    {cnt_start_x, cnt_start_y});

            global_group_reduce_t global_group_reduce;
            global_group_reduce(g, mat_slice, mem_desc_acc, mem_desc_cnt);

            if (global_group_reduce.is_last_group()) {
                if constexpr (mem_desc_c_t::is_local) {
                    mem_desc_c.init(args.matC_base,
                            {real_wg_tile_n, real_wg_tile_m, real_wg_tile_n},
                            {coop_offset_x, coop_offset_y});
                } else {
                    mem_desc_c.init(args.matC_base,
                            {boundary_n, boundary_m, args.matC_ld},
                            {start_n + coop_offset_x, start_m + coop_offset_y});
                }
                epilogue_t epilogue;
                epilogue(g, mat_slice, mem_desc_c, args.epilogue_args,
                        epilogue_slm_base, epilogue_nbarr_base);
            }
        }
    }
};

} // namespace gpu::xetla::kernel
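
Usage sketch (illustrative, not part of this header): assuming a gemm_t built from XeTLA's int4-dequantize compute policy and a matching epilogue_t, both defined elsewhere, a host program could configure and launch this specialization roughly as follows. The identifiers my_gemm_t, my_epilogue_t, queue, and the device pointers are placeholders; group_swizzle_default and the KERNEL_MAIN launch macro follow the pattern used in the XeTLA examples.

using group_swizzle_t
        = gpu::xetla::kernel::group_swizzle_default<gpu_arch::Xe>;
using gemm_op_t = gpu::xetla::kernel::gemm_universal_t<
        gpu::xetla::kernel::dispatch_policy_int4_dequantize_kslicing<
                group_swizzle_t, 2 /*global*/, 2 /*local*/>,
        my_gemm_t, my_epilogue_t>;

// Scratch buffers for the global split-k reduction, sized by the host helpers.
auto *acc = sycl::malloc_device<float>( // assumes dtype_acc == float
        gemm_op_t::get_acc_buf_size(m, n), queue);
auto *cnt = sycl::malloc_device<uint32_t>(
        gemm_op_t::get_cnt_buf_size(m, n), queue);

// Argument order follows the arguments_t constructor above.
typename gemm_op_t::arguments_t args(m, k, n, A, lda, B, ldb, C, ldc,
        scale, scale_ld, zero_pt, zero_pt_ld, acc, cnt);
if (!gemm_op_t::can_implement(args)) { /* fall back or re-tile */ }

queue.submit([&](sycl::handler &cgh) {
    cgh.parallel_for(gemm_op_t::get_nd_range(args),
            [=](sycl::nd_item<3> item) KERNEL_MAIN {
                gemm_op_t gemm_op;
                gemm_op(item, args);
            });
});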