XeTLA v0.3.6
Intel® Xe Templates for Linear Algebra - API Definition Document
 
int4_dequantize_xmx_xe.hpp
1/*******************************************************************************
2* Copyright (c) 2022-2023 Intel Corporation
3*
4* Licensed under the Apache License, Version 2.0 (the "License");
5* you may not use this file except in compliance with the License.
6* You may obtain a copy of the License at
7*
8* http://www.apache.org/licenses/LICENSE-2.0
9*
10* Unless required by applicable law or agreed to in writing, software
11* distributed under the License is distributed on an "AS IS" BASIS,
12* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13* See the License for the specific language governing permissions and
14* limitations under the License.
15*******************************************************************************/
16
19
20#pragma once
21
22#include "group/gemm/api.hpp"
23#include "group/gemm/compute_policy.hpp"
24
25namespace gpu::xetla::group {
26
29
31template <typename compute_attr_, typename perf_tuning_knob_,
32 typename tile_shape_, typename mem_desc_a_t_, typename mem_desc_b_t_,
33 typename dtype_scale_, typename dtype_zero_pt_, int dequant_s_,
34 typename pre_processing_t_>
35class gemm_t<
36 compute_policy_int4_dequantize_xmx<compute_attr_, perf_tuning_knob_,
37 dtype_scale_, dtype_zero_pt_, dequant_s_, gpu_arch::Xe>,
38 tile_shape_, // tile shape of workgroup-level gemm
39 mem_desc_a_t_, // memory attribute of matA
40 mem_desc_b_t_, // memory attribute of matB
41 pre_processing_t_ // pre_processing functor
42 > {
43public:
44 using mem_desc_a_t = mem_desc_a_t_;
45 using mem_desc_b_t = mem_desc_b_t_;
46 using tile_shape = tile_shape_;
47 using pre_processing_t = pre_processing_t_;
48 using compute_policy = compute_policy_int4_dequantize_xmx<compute_attr_,
49 perf_tuning_knob_, dtype_scale_, dtype_zero_pt_, dequant_s_,
50 gpu_arch::Xe>;
51 static constexpr uint32_t k_stride = compute_policy::k_stride;
52
53 static constexpr uint32_t sg_tile_m = tile_shape::sg_tile_size_y;
54 static constexpr uint32_t sg_tile_n = tile_shape::sg_tile_size_x;
55 static constexpr uint32_t wg_size_x = tile_shape::wg_size_x;
56 static constexpr uint32_t wg_size_y = tile_shape::wg_size_y;
57 using work_group_t = typename tile_shape::work_group_t;
58
59 constexpr static gpu_arch arch_tag = compute_policy::arch_tag;
60 static constexpr uint32_t dequant_s = compute_policy::dequant_s;
61 using dtype_b = typename mem_desc_b_t::dtype;
62 using dtype_zero_pt = typename compute_policy::dtype_zero_pt;
63 static constexpr uint32_t pack_ratio = sizeof(dtype_b) * 2;
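// For example, with dtype_b = int4x2 (one byte holding two 4-bit weights),
// pack_ratio = sizeof(int4x2) * 2 = 2, so the matB tile widths below are divided
// by 2 relative to the logical (unpacked) number of columns.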
64
65 static constexpr mem_layout mem_layout_a = mem_desc_a_t::layout;
66 static constexpr mem_layout mem_layout_b = mem_desc_b_t::layout;
67 static constexpr bool is_col_major_a
68 = mem_layout_a == mem_layout::col_major;
69 static constexpr bool is_col_major_b
70 = mem_layout_b == mem_layout::col_major;
71
72private:
73 /******** set data type **********/
74 using dtype_a = typename mem_desc_a_t::dtype;
75 using dtype_mma_acc = typename compute_policy::dtype_mma_acc;
76 using dtype_mma_a = typename compute_policy::dtype_mma_a;
77 using dtype_mma_b = typename compute_policy::dtype_mma_b;
78 using dtype_scale = typename compute_policy::dtype_scale;
79
80 static_assert(std::is_same<remove_const_t<dtype_b>,
81 remove_const_t<int4x2>>::value,
82 "this is for 4bit matB ");
83 static_assert(std::is_same<remove_const_t<dtype_zero_pt>,
84 remove_const_t<int4x2>>::value,
85 "this is for 4bit zero_pt ");
86
87 /******** set memory attribute **********/
88 static constexpr mem_space mem_space_a = mem_desc_a_t::space;
89 static constexpr mem_space mem_space_b = mem_desc_b_t::space;
90
91 static constexpr bool is_local_a = mem_space_a == mem_space::local;
92 static constexpr bool is_local_b = mem_space_b == mem_space::local;
93 static constexpr tdesc_update_dir update_dir_a = is_col_major_a
94 ? tdesc_update_dir::y_dir
95 : tdesc_update_dir::x_dir;
96 static constexpr tdesc_update_dir update_dir_b = is_col_major_b
97 ? tdesc_update_dir::x_dir
98 : tdesc_update_dir::y_dir;
99 static_assert(!is_col_major_b, "only support MatB row-major for now");
100 static_assert((!is_local_a) && (!is_local_b),
101 "only support from global memory for now");
102
103 static constexpr uint32_t stages = compute_policy::stages;
104 static constexpr uint32_t sync_freq = compute_policy::sync_freq;
105
106 /******** set tile layout && worker scope **********/
107 static constexpr uint32_t tile_size_x_a = k_stride;
108 static constexpr uint32_t tile_size_y_a = sg_tile_m;
109 static constexpr uint32_t tile_size_x_b = sg_tile_n;
110 static constexpr uint32_t tile_size_y_b = k_stride;
111 static constexpr uint32_t tile_size_x_c = sg_tile_n;
112 static constexpr uint32_t tile_size_y_c = sg_tile_m;
113 static constexpr uint32_t block_size_x_a
114 = compute_policy::block_bytes_x_a / sizeof(dtype_mma_a);
115 static constexpr uint32_t block_size_y_a
116 = (compute_policy::block_size_y_a > tile_size_y_a)
117 ? tile_size_y_a
118 : compute_policy::block_size_y_a;
119
120 static constexpr uint32_t block_size_x_b = compute_policy::block_size_x_b;
121 static constexpr uint32_t block_size_y_b
122 = compute_policy::block_bytes_y_b / sizeof(dtype_mma_b);
123
124 /******** set tile **********/
125 static constexpr bool is_vnni_tiled_a
126 = (sizeof(dtype_a) < sizeof(uint32_t)) && is_col_major_a;
127 static constexpr reg_layout reg_layout_a
128 = is_vnni_tiled_a ? reg_layout::vnni_tiled : reg_layout::tiled;
129 using matA_tile_desc_t = subgroup::tile_desc_t<tile_size_x_a, tile_size_y_a,
130 block_size_x_a, block_size_y_a, reg_layout_a>;
131 using matA_t = subgroup::tile_t<dtype_a, matA_tile_desc_t>;
132 using matA_payload_t
133 = subgroup::mem_payload_t<mem_desc_a_t, matA_tile_desc_t,
134 subgroup::msg_type_v<matA_tile_desc_t, mem_space_a>, arch_tag>;
135 using matA_acc_t = subgroup::tile_t<dtype_mma_a, matA_tile_desc_t>;
136 using matA_prefetch_payload_t
137 = subgroup::prefetch_payload_t<mem_desc_a_t, matA_tile_desc_t,
138 wg_size_x, arch_tag>;
139
140 //note: plane format, row-major
141 //note: 4bit x 2, row-major
142 using matB_tile_desc_t = subgroup::tile_desc_t<tile_size_x_b / pack_ratio,
143 tile_size_y_b, block_size_x_b / pack_ratio, block_size_y_b,
144 reg_layout::tiled>;
145 using matB_t = subgroup::tile_t<dtype_b, matB_tile_desc_t>;
146 using matB_payload_t
147 = subgroup::mem_payload_t<mem_desc_b_t, matB_tile_desc_t,
148 subgroup::msg_type_v<matB_tile_desc_t, mem_space_b>, arch_tag>;
149 using matB_prefetch_payload_t = subgroup::prefetch_payload_t<mem_desc_b_t,
150 matB_tile_desc_t, wg_size_y, arch_tag>;
151
152 using matB_acc_tile_desc_t
153 = subgroup::tile_desc_t<tile_size_x_b, tile_size_y_b,
154 block_size_x_b, block_size_y_b, reg_layout::vnni_tiled>;
155 using matB_acc_t = subgroup::tile_t<dtype_mma_b, matB_acc_tile_desc_t>;
156
157public:
158 static_assert((k_stride % (block_size_y_b) == 0),
159 "k_stride%(block_size_y_b) == 0");
160 static_assert((dequant_s % (block_size_y_b) == 0),
161 "dequant_s%(block_size_y_b) == 0");
162 static_assert(
163 (k_stride % (dequant_s) == 0) || (dequant_s % (k_stride) == 0),
164 "k_stride should match with dequant_s");
165
166 //num_block_y set to 1
167 static constexpr uint32_t block_size_y_scale
168 = (k_stride + dequant_s - 1) / dequant_s;
169 static constexpr uint32_t tile_size_y_scale = block_size_y_scale;
170 static constexpr uint32_t block_size_y_zero_pt
171 = (k_stride + dequant_s - 1) / dequant_s;
172 static constexpr uint32_t tile_size_y_zero_pt = block_size_y_zero_pt;
173
174 static constexpr uint32_t scale_addr_update_freq
175 = (k_stride < dequant_s) ? dequant_s / k_stride : 1;
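// Worked example (values are illustrative only): with k_stride = 32 and dequant_s = 128,
// one k-iteration covers 32 rows of matB while one scale/zero_pt row covers 128 rows,
// so block_size_y_scale = (32 + 128 - 1) / 128 = 1 and scale_addr_update_freq = 128 / 32 = 4,
// i.e. the scale/zero_pt descriptors only advance every 4th k-iteration. Conversely, with
// k_stride = 64 and dequant_s = 32, block_size_y_scale = 2 and the descriptors advance
// every iteration (by scale_t::tile_size_y rows).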
176
177 using mem_desc_scale_t = mem_desc_t<dtype_scale,
178 mem_layout::row_major, mem_space::global>;
179 using mem_desc_zero_pt_t = mem_desc_t<dtype_zero_pt,
180 mem_layout::row_major, mem_space::global>;
181
182 using matAcc_tile_desc_t = subgroup::tile_desc_t<tile_size_x_c,
183 tile_size_y_c, block_size_x_b, block_size_y_a, reg_layout::tiled>;
184 using matAcc_t = subgroup::tile_t<dtype_mma_acc, matAcc_tile_desc_t>;
185
186private:
187 using scale_tile_desc_t
188 = subgroup::tile_desc_t<tile_size_x_b, tile_size_y_scale,
189 block_size_x_b, block_size_y_scale, reg_layout::tiled>;
190 using scale_t = subgroup::tile_t<dtype_scale, scale_tile_desc_t>;
191 using scale_payload_t
192 = subgroup::mem_payload_t<mem_desc_scale_t, scale_tile_desc_t,
193 subgroup::msg_type_v<scale_tile_desc_t, mem_space::global>,
194 arch_tag>;
195 using zero_pt_tile_desc_t
196 = subgroup::tile_desc_t<tile_size_x_b / pack_ratio,
197 tile_size_y_zero_pt, block_size_x_b / pack_ratio,
198 block_size_y_zero_pt, reg_layout::tiled>;
199 using zero_pt_t = subgroup::tile_t<dtype_zero_pt, zero_pt_tile_desc_t>;
200 using zero_pt_payload_t
201 = subgroup::mem_payload_t<mem_desc_zero_pt_t, zero_pt_tile_desc_t,
202 subgroup::msg_type_v<zero_pt_tile_desc_t, mem_space::global>,
203 arch_tag>;
204 using scale_prefetch_payload_t
205 = subgroup::prefetch_payload_t<mem_desc_scale_t, scale_tile_desc_t,
206 1, arch_tag>;
207 using zero_pt_prefetch_payload_t
208 = subgroup::prefetch_payload_t<mem_desc_zero_pt_t,
209 zero_pt_tile_desc_t, 1, arch_tag>;
210
211 using tile_mma = subgroup::tile_mma_t<matAcc_t, matAcc_t, matB_acc_t,
212 matA_acc_t, mma_engine::xmx, arch_tag>;
213 static constexpr bool enable_periodic_sync = (sync_freq != 0);
214 static constexpr uint32_t barrier_count_x = wg_size_y > 1 ? wg_size_x : 0;
215 static constexpr uint32_t barrier_count_y = wg_size_x > 1 ? wg_size_y : 0;
216
217public:
218 static constexpr uint32_t barrier_count
219 = enable_periodic_sync ? barrier_count_x + barrier_count_y : 0;
220 // currently, only matA may come from SLM
221 static constexpr uint32_t slm_size = is_local_a
222 ? sg_tile_m * wg_size_y * k_stride * sizeof(dtype_a)
223 : 0;
224
225 static constexpr msg_type msg_type_a = matA_payload_t::message_type;
226 static constexpr msg_type msg_type_b = matB_payload_t::message_type;
227
228 /// @brief Arguments for gemm.
229 /// User should prepare matA_base_desc, matB_base_desc, scale_base_desc, zero_pt_base_desc, inner_loop_count...
230 struct arguments_t {
231 /// @brief Is the memory description of matA, including base, shape and coordinate.
232 mem_desc_a_t matA_base_desc;
233 /// @brief Is the memory description of matB, including base, shape and coordinate.
234 mem_desc_b_t matB_base_desc;
235 /// @brief Is the total inner loop count required to compute the entire K-dim.
236 uint32_t inner_loop_count;
237 /// @brief Is the memory description of the scale buffer. Scale size: (matrix_k/dequant_s)x(matrix_n)
238 mem_desc_scale_t scale_base_desc;
239 /// @brief Is the memory description of zero_pt buffer. Zero_pt size: (matrix_k/dequant_s)x(matrix_n/pack_ratio)
240 mem_desc_zero_pt_t zero_pt_base_desc;
241
243 inline arguments_t() = default;
244
245 inline arguments_t(mem_desc_a_t matA_desc, mem_desc_b_t matB_desc,
246 uint32_t loop_count, mem_desc_scale_t scale_desc,
247 mem_desc_zero_pt_t zero_pt_desc)
248 : matA_base_desc(matA_desc)
249 , matB_base_desc(matB_desc)
250 , inner_loop_count(loop_count)
251 , scale_base_desc(scale_desc)
252 , zero_pt_base_desc(zero_pt_desc) {}
253 // Be aware of the risks: rule of three (copy constructor, copy assignment, destructor).
254 // Please check whether you need to add a self-defined destructor.
255 // inline ~arguments_t(){}
256 inline arguments_t(const arguments_t &args)
257 : matA_base_desc(args.matA_base_desc)
258 , matB_base_desc(args.matB_base_desc)
259 , inner_loop_count(args.inner_loop_count)
260 , scale_base_desc(args.scale_base_desc)
261 , zero_pt_base_desc(args.zero_pt_base_desc) {}
262 inline arguments_t &operator=(const arguments_t &args) {
263 this->matA_base_desc = args.matA_base_desc;
264 this->matB_base_desc = args.matB_base_desc;
265 this->inner_loop_count = args.inner_loop_count;
266 this->scale_base_desc = args.scale_base_desc;
267 this->zero_pt_base_desc = args.zero_pt_base_desc;
268 return *this;
269 }
270 inline void init(mem_desc_a_t matA_desc, mem_desc_b_t matB_desc,
271 uint32_t loop_count, mem_desc_scale_t scale_desc,
272 mem_desc_zero_pt_t zero_pt_desc) {
273 matA_base_desc = matA_desc;
274 matB_base_desc = matB_desc;
275 inner_loop_count = loop_count;
276 scale_base_desc = scale_desc;
277 zero_pt_base_desc = zero_pt_desc;
278 }
279 };
280
281 /// @brief Gets the subgroup-level tile offset x.
282 /// @param g Is the workgroup of the current tile.
283 /// @return Subgroup-level tile offset x.
284 __XETLA_API static int get_matC_offset_x(work_group_t &g) {
285 int32_t sg_idx = g.get_id() % wg_size_x;
286 return sg_idx * sg_tile_n;
287 }
288
289 /// @brief Gets the subgroup-level tile offset y.
290 /// @param g Is the workgroup of the current tile.
291 /// @return Subgroup-level tile offset y.
292 __XETLA_API static int get_matC_offset_y(work_group_t &g) {
293 int32_t sg_idy = g.get_id() / wg_size_x;
294 return sg_idy * sg_tile_m;
295 }
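// Example (illustrative values): with wg_size_x = 4, sg_tile_n = 32 and sg_tile_m = 16,
// the subgroup with linear id 5 gets sg_idx = 5 % 4 = 1 and sg_idy = 5 / 4 = 1, so its
// C tile starts at column offset 1 * 32 = 32 and row offset 1 * 16 = 16.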
296
298 "This release function will wait until all the r/w and nbarrier "
299 "id used in this gemm have been committed. By default, it will "
300 "use barrier_id 0 to do the entire workgroup sync if wg_size > 1. "
301 "If you call this function, please set a free barrier id or make "
302 "sure barrier_id 0 is not being occupied and you need to allocate "
303 "one more barrier count in addition to the gemm barrier counts.")
304 __XETLA_API static void release(uint8_t nbarrier_id = 0) {
305 static constexpr bool need_local_fence
306 = (mem_space_a == mem_space::local)
307 || (mem_space_b == mem_space::local);
308 if constexpr (need_local_fence) {
309 xetla_fence<memory_kind::shared_local>();
310 }
311 xetla_fence<memory_kind::untyped_global>();
312 static constexpr uint32_t wg_size = wg_size_x * wg_size_y;
313 if constexpr (wg_size > 1) {
314 xetla_nbarrier_t<wg_size, wg_size, arch_tag> nbarrier;
315 nbarrier.init_nbarrier(
316 nbarrier_id, nbarrier_role::producer_consumer);
317 nbarrier.arrive_wait();
318 }
319 }
320
321 /// @brief Main execution function for gemm.
322 /// The basic process is load data -> matrix multiply.
323 /// @param g Is the workgroup of the current tile.
324 /// @param matAcc Is the reference of the accumulation buffer.
325 /// @param args Is the gemm::arguments_t.
326 /// @param slm_base Is the slm base address.
327 /// @param nbarrier_base Is the named barrier base.
328 __XETLA_API KERNEL_FUNC void operator()(work_group_t &g, matAcc_t &matAcc,
329 arguments_t args, [[maybe_unused]] uint32_t slm_base = 0,
330 uint32_t nbarrier_base = 0) {
331 int32_t sg_idx = g.get_id() % wg_size_x;
332 int32_t sg_idy = g.get_id() / wg_size_x;
333 update_sg_tile_tdesc(args, sg_idx, sg_idy);
334
335 matA_t matA;
336 matB_t matB;
337 scale_t scale;
338 zero_pt_t zero_pt;
339
340 matA_payload_t matA_payload(args.matA_base_desc);
341 matB_payload_t matB_payload(args.matB_base_desc);
342 scale_payload_t scale_payload(args.scale_base_desc);
343 zero_pt_payload_t zero_pt_payload(args.zero_pt_base_desc);
344 matA_prefetch_payload_t matA_prefetch_payload(
345 args.matA_base_desc, sg_idx);
346 matB_prefetch_payload_t matB_prefetch_payload(
347 args.matB_base_desc, sg_idy);
348 scale_prefetch_payload_t scale_prefetch_payload(
349 args.scale_base_desc, 0);
350 zero_pt_prefetch_payload_t zero_pt_prefetch_payload(
351 args.zero_pt_base_desc, 0);
352
353 xetla_nbarrier_t<wg_size_x, wg_size_x, arch_tag> nbarrier_a;
354 nbarrier_a.init_nbarrier(
355 sg_idy + nbarrier_base, nbarrier_role::producer_consumer);
356 xetla_nbarrier_t<wg_size_y, wg_size_y, arch_tag> nbarrier_b;
357 nbarrier_b.init_nbarrier(sg_idx + barrier_count_y + nbarrier_base,
358 nbarrier_role::producer_consumer);
359
360 int scale_prefetch_addr_i = 0;
361 int scale_load_addr_i = 0;
362 SW_BARRIER();
363#pragma unroll
364 for (uint32_t i = 0; i < stages; i++) {
365 subgroup::tile_prefetch<cache_hint::cached, cache_hint::cached>(
366 matA_prefetch_payload);
367 subgroup::tile_prefetch<cache_hint::cached, cache_hint::cached>(
368 matB_prefetch_payload);
369 subgroup::tile_prefetch<cache_hint::cached, cache_hint::cached>(
370 scale_prefetch_payload);
371 subgroup::tile_prefetch<cache_hint::cached, cache_hint::cached>(
372 zero_pt_prefetch_payload);
373 scale_prefetch_addr_i++;
374 matA_prefetch_payload.template update_tdesc<update_dir_a>(
375 matA_t::tile_size_x);
376 matB_prefetch_payload.template update_tdesc<update_dir_b>(
377 matB_t::tile_size_y);
378 if ((scale_prefetch_addr_i % scale_addr_update_freq) == 0) {
379 scale_prefetch_payload
380 .template update_tdesc<tdesc_update_dir::y_dir>(
381 scale_t::tile_size_y);
382 zero_pt_prefetch_payload
383 .template update_tdesc<tdesc_update_dir::y_dir>(
384 zero_pt_t::tile_size_y);
385 }
386 }
387
388 for (uint32_t i = 0; i < args.inner_loop_count; i++) {
389 if constexpr (enable_periodic_sync) {
390 if ((i % sync_freq) == 0) {
391 if constexpr (wg_size_x > 1) { nbarrier_a.arrive(); }
392 if constexpr (wg_size_y > 1) { nbarrier_b.arrive(); }
393 }
394 }
395 subgroup::tile_load<cache_hint::cached, cache_hint::cached>(
396 matA, matA_payload);
397 subgroup::tile_load<cache_hint::cached, cache_hint::cached>(
398 matB, matB_payload);
399 subgroup::tile_load<cache_hint::cached, cache_hint::cached>(
400 scale, scale_payload);
401 subgroup::tile_load<cache_hint::cached, cache_hint::cached>(
402 zero_pt, zero_pt_payload);
403 scale_load_addr_i++;
404 SW_BARRIER();
405 if constexpr (stages != 0) {
406 subgroup::tile_prefetch<cache_hint::cached, cache_hint::cached>(
407 matA_prefetch_payload);
408 subgroup::tile_prefetch<cache_hint::cached, cache_hint::cached>(
409 matB_prefetch_payload);
410 subgroup::tile_prefetch<cache_hint::cached, cache_hint::cached>(
411 scale_prefetch_payload);
412 subgroup::tile_prefetch<cache_hint::cached, cache_hint::cached>(
413 zero_pt_prefetch_payload);
414 scale_prefetch_addr_i++;
415 }
416 SW_BARRIER();
417 matA_payload.template update_tdesc<update_dir_a>(
418 matA_t::tile_size_x);
419 matB_payload.template update_tdesc<update_dir_b>(
420 matB_t::tile_size_y);
421 if ((scale_load_addr_i % scale_addr_update_freq) == 0) {
422 scale_payload.template update_tdesc<tdesc_update_dir::y_dir>(
423 scale_t::tile_size_y);
424 zero_pt_payload.template update_tdesc<tdesc_update_dir::y_dir>(
425 zero_pt_t::tile_size_y);
426 }
427 if constexpr (stages != 0) {
428 matA_prefetch_payload.template update_tdesc<update_dir_a>(
429 matA_t::tile_size_x);
430 matB_prefetch_payload.template update_tdesc<update_dir_b>(
431 matB_t::tile_size_y);
432 if ((scale_prefetch_addr_i % scale_addr_update_freq) == 0) {
433 scale_prefetch_payload
434 .template update_tdesc<tdesc_update_dir::y_dir>(
435 scale_t::tile_size_y);
436 zero_pt_prefetch_payload
437 .template update_tdesc<tdesc_update_dir::y_dir>(
438 zero_pt_t::tile_size_y);
439 }
440 }
441 SW_BARRIER();
442 matA_acc_t matA_acc;
443 matB_acc_t matB_acc;
444 if constexpr (is_vnni_tiled_a) { subgroup::vnni_reverse(matA); }
445 subgroup::elemwise_cvt(matA_acc, matA);
446 dequantize(matB_acc, matB, scale, zero_pt);
447 SW_BARRIER();
448 tile_mma::mma(matAcc, matAcc, matB_acc, matA_acc);
449 SW_BARRIER();
450 if constexpr (enable_periodic_sync) {
451 if ((i % sync_freq) == 0) {
452 if constexpr (wg_size_x > 1) { nbarrier_a.wait(); }
453 if constexpr (wg_size_y > 1) { nbarrier_b.wait(); }
454 }
455 }
456 }
457 SW_BARRIER();
458 }
459
460private:
461 inline void dequantize(matB_acc_t &matB_acc, matB_t &matB, scale_t &scale,
462 zero_pt_t &zero_pt) {
463 //no tail, because this is matB
464 constexpr uint32_t num_block_x = tile_size_x_b / block_size_x_b;
465 constexpr uint32_t num_block_y = tile_size_y_b / block_size_y_b;
466 constexpr uint32_t vnni_rows = sizeof(uint32_t) / sizeof(dtype_mma_b);
467 constexpr uint32_t block_b_y_per_scale = dequant_s / block_size_y_b;
468#pragma unroll
469 for (uint32_t i = 0; i < num_block_y; ++i) {
470#pragma unroll
471 for (uint32_t j = 0; j < num_block_x; ++j) {
472 int block_id = (i * num_block_x + j);
473 auto matB_blk = matB.reg.xetla_select<matB_t::block_elems, 1>(
474 block_id * matB_t::block_elems)
475 .xetla_format<uint8_t>();
476 int scale_block_id
477 = (i / block_b_y_per_scale * num_block_x + j);
478 auto scale_vec
479 = scale.reg.xetla_select<scale_t::block_size_x, 1>(
480 scale_block_id * scale_t::block_size_x);
481 auto zero_pt_vec
482 = zero_pt.reg
483 .xetla_select<zero_pt_t::block_size_x, 1>(
484 scale_block_id
485 * zero_pt_t::block_size_x)
486 .xetla_format<uint8_t>();
487
488 auto dst_blk
489 = matB_acc.reg.xetla_select<matB_acc_t::block_elems, 1>(
490 block_id * matB_acc_t::block_elems);
491
492 xetla_vector<uint8_t, block_size_x_b> zero_pt_sub;
493 //2: each int8 byte packs two 4-bit values.
494 zero_pt_sub.xetla_select<block_size_x_b / 2, 2>(0)
495 = zero_pt_vec & 0x0f;
496 zero_pt_sub.xetla_select<block_size_x_b / 2, 2>(1)
497 = zero_pt_vec >> 4;
498
499 xetla_vector<uint8_t, block_size_x_b * block_size_y_b>
500 zero_pt_blk;
501#pragma unroll
502 for (uint32_t row = 0; row < block_size_y_b; row++) {
503 zero_pt_blk
504 .xetla_select<block_size_x_b, 1>(
505 row * block_size_x_b)
506 .xetla_format<int8_t>()
507 = zero_pt_sub.xetla_format<int8_t>() + int8_t(1);
508 }
509
510 xetla_vector<uint8_t, matB_t::block_elems * 2> cvt_blk;
511 cvt_blk.xetla_select<matB_t::block_elems, 2>(0)
512 = matB_blk & 0x0f;
513 cvt_blk.xetla_select<matB_t::block_elems, 2>(1) = matB_blk >> 4;
514
515 xetla_vector<int32_t, matB_t::block_elems * 2>
516 cvt_blk_i32;
517
518 cvt_blk_i32 = (cvt_blk.xetla_format<int8_t>()
519 - zero_pt_blk.xetla_format<int8_t>());
520
521 xetla_vector<dtype_mma_b, matB_acc_t::block_elems * vnni_rows>
522 temp_blk;
523 temp_blk.xetla_select<matB_acc_t::block_elems, vnni_rows>(0)
524 = cvt_blk_i32;
525
526#pragma unroll
527 for (uint32_t k = 0; k < block_size_y_b; k += vnni_rows) {
528#pragma unroll
529 for (uint32_t row = 0; row < vnni_rows; row++) {
530 temp_blk.xetla_select<block_size_x_b, vnni_rows>(
531 row + block_size_x_b * k * vnni_rows)
532 = temp_blk.xetla_select<block_size_x_b,
533 vnni_rows>(
534 (k + row) * block_size_x_b * vnni_rows);
535 }
536 }
537
538 xetla_vector<dtype_scale, block_size_x_b * vnni_rows> scale_blk;
539#pragma unroll
540 for (uint32_t row = 0; row < vnni_rows; row++) {
541 scale_blk.xetla_select<block_size_x_b, vnni_rows>(row)
542 = scale_vec;
543 }
544
545#pragma unroll
546 for (uint32_t k = 0; k < block_size_y_b; k += vnni_rows) {
547 dst_blk.xetla_select<block_size_x_b * vnni_rows, 1>(
548 k * block_size_x_b)
549 = temp_blk.xetla_select<block_size_x_b * vnni_rows,
550 1>(k * block_size_x_b * vnni_rows)
551 * scale_blk;
552 }
553 }
554 }
555 }
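// Per-element view of the dequantization above (reference only): each byte of matB packs
// two unsigned 4-bit weights, low nibble first. For a 4-bit weight q, its 4-bit zero
// point zp and its per-group scale s, the dequantized value is
//   w = (int(q) - (int(zp) + 1)) * s,
// where the "+ 1" matches the zero_pt_blk adjustment above; the result is then laid out
// in the VNNI arrangement expected by the XMX mma.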
556 /// @brief Updates tile base descriptor based on the tid.
557 __XETLA_API static void update_sg_tile_tdesc(
558 arguments_t &args, int32_t sg_idx, int32_t sg_idy) {
559 int32_t tile_offset_n = sg_idx * sg_tile_n;
560 int32_t tile_offset_m = sg_idy * sg_tile_m;
561
562 args.matA_base_desc.update_coord_y(tile_offset_m);
563 args.matB_base_desc.update_coord_x(tile_offset_n / pack_ratio);
564 args.scale_base_desc.update_coord_x(tile_offset_n);
565 args.zero_pt_base_desc.update_coord_x(tile_offset_n / pack_ratio);
566 }
567};
568
570
571} // namespace gpu::xetla::group
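For orientation, here is a minimal sketch of how this specialization could be instantiated. Only the gemm_t parameter order, compute_policy_int4_dequantize_xmx, arguments_t and mem_desc_t are taken from this file; the helper aliases compute_attr_t, perf_tuning_knob_t, tile_shape_t and pre_processing_default_t, their parameter orders, the include path, and all numeric sizes are assumptions for illustration and should be checked against the other XeTLA headers.

#include <xetla.hpp> // assumed umbrella header
using namespace gpu::xetla;

// All helper signatures and numeric values below are assumptions, not taken from this file.
using compute_attr = group::compute_attr_t<fp16, fp16, float>;
using perf_knob = group::perf_tuning_knob_t<32 /*k_stride*/, 3 /*stages*/, 8 /*sync_freq*/>;
using tile_shape = group::tile_shape_t<256 /*wg_n*/, 64 /*wg_m*/, 64 /*sg_n*/, 16 /*sg_m*/>;
using mem_desc_a = mem_desc_t<fp16, mem_layout::row_major, mem_space::global>;
using mem_desc_b = mem_desc_t<int4x2, mem_layout::row_major, mem_space::global>;
using compute_policy = group::compute_policy_int4_dequantize_xmx<compute_attr, perf_knob,
        fp16 /*dtype_scale*/, int4x2 /*dtype_zero_pt*/, 128 /*dequant_s*/, gpu_arch::Xe>;
using pre_processing = group::pre_processing_default_t<tile_shape, gpu_arch::Xe>;
using gemm = group::gemm_t<compute_policy, tile_shape, mem_desc_a, mem_desc_b, pre_processing>;

// Inside a kernel, each work-group would then roughly do:
//   gemm::arguments_t args(matA_desc, matB_desc, inner_loop_count, scale_desc, zero_pt_desc);
//   gemm::matAcc_t matAcc; // accumulator tile, zeroed by the caller
//   gemm gemm_op;
//   gemm_op(g, matAcc, args, slm_base, nbarrier_base);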
XETLA_MARKER("This release function will wait until all the r/w and nbarrier " "id used in this gemm have been committed. By default, it will " "use barrier_id 0 to do the entire workgroup sync if wg_size > 1. " "If you call this function, please set a free barrier id or make " "sure barrier_id 0 is not being occupied and you need to allocate " "one more barrier count in addition to the gemm barrier counts.") __XETLA_API static void release(uint8_t nbarrier_id=0)
Definition int4_dequantize_xmx_xe.hpp:297
__XETLA_API KERNEL_FUNC void operator()(work_group_t &g, matAcc_t &matAcc, arguments_t args, uint32_t slm_base=0, uint32_t nbarrier_base=0)
Main execution function for gemm.
Definition int4_dequantize_xmx_xe.hpp:328
Gemm functor.
Definition api.hpp:52
typename std::remove_const< T >::type remove_const_t
Definition common.hpp:26
#define SW_BARRIER()
SW_BARRIER, insert software scheduling barrier, for better code control.
Definition common.hpp:227
#define __XETLA_API
Definition common.hpp:43
C++ API.
__ESIMD_NS::simd< native_type_t< Ty >, N > xetla_vector
wrapper for xetla_vector.
Definition base_types.hpp:149
#define KERNEL_FUNC
KERNEL_FUNC macro.
Definition common.hpp:39
Definition limitation.hpp:607
__XETLA_API std::enable_if_t<(T_src::register_layout !=reg_layout::linear) &&(T_dst::register_layout !=reg_layout::linear) &&is_same_layout< T_dst, T_src >::value &&(!is_floating_to_integer< T_dst, T_src >::value)> elemwise_cvt(T_dst &dst, T_src &src)
Is the element wise data conversion, the src and dst tile should have the same layout.
Definition op_function.hpp:40
__XETLA_API std::enable_if_t< T::register_layout==reg_layout::tiled > vnni_reverse(T &mat_Acc)
Converts vnni_tiled layout format to tiled layout.
Definition op_function.hpp:196
reg_layout
tile layout in register linear: linear layout with one tile tiled: 2d block stacked in raster order v...
Definition common.hpp:209
mem_space
Definition common.hpp:77
gpu_arch
Definition common.hpp:73
msg_type
Definition common.hpp:78
tdesc_update_dir
Definition common.hpp:228
mem_layout
Definition common.hpp:76
arguments_t(mem_desc_a_t matA_desc, mem_desc_b_t matB_desc, uint32_t loop_count, mem_desc_scale_t scale_desc, mem_desc_zero_pt_t zero_pt_desc)
Definition int4_dequantize_xmx_xe.hpp:245
void init(mem_desc_a_t matA_desc, mem_desc_b_t matB_desc, uint32_t loop_count, mem_desc_scale_t scale_desc, mem_desc_zero_pt_t zero_pt_desc)
Definition int4_dequantize_xmx_xe.hpp:270
mem_desc_zero_pt_t zero_pt_base_desc
Is the memory description of zero_pt buffer. Zero_pt size: (matrix_k/dequant_s)x(matrix_n/pack_ratio)
Definition int4_dequantize_xmx_xe.hpp:240
Is to illustrate the memory information.
Definition api.hpp:44
Is to illustrate the memory information to prefetch data to cache.
Definition api.hpp:53
Is to illustrate the tile information about a sub matrix.
Definition api.hpp:64
Is the xetla tile mma operation definition API.
Definition api.hpp:36
Is a struct contains some register file.
Definition api.hpp:99
xetla nbarrier definition API.
Definition raw_send_nbarrier.hpp:43
__XETLA_API void arrive()
named barrier signal from subgroup.
Definition raw_send_nbarrier.hpp:65
__XETLA_API void arrive_wait()
named barrier signal from subgroup.
Definition raw_send_nbarrier.hpp:80
__XETLA_API void init_nbarrier(uint8_t nbarrier_id, nbarrier_role role=nbarrier_role::producer_consumer)
Definition raw_send_nbarrier.hpp:55
__XETLA_API void wait()
named barrier wait within subgroup.
Definition raw_send_nbarrier.hpp:76