19#include "common/common.hpp"
20#include "group/group.hpp"
21#include "subgroup/subgroup.hpp"
26#define rand_threshold_const 0x80000000
27#define SIGN_BIT_DW 0x80000000
28#define SIGN_BIT_W16 0x8000
29#define SIGN_BIT_B8 0x80
41template <
typename dtype_bin_,
typename dtype_bot_,
typename dtype_sfx_,
42 typename dtype_acc_,
int HWThreadNum,
bool Dopt_RandGenflag =
true,
43 uint16_t RandSIMD = 16,
int Max_SeqLen = 512>
130 matAcc_128x128_t::tile_desc::tile_size_y,
131 matAcc_128x128_t::tile_desc::block_size_x,
132 matAcc_128x128_t::tile_desc::block_size_y,
136 matAcc_128x256_t::tile_desc::tile_size_y,
137 matAcc_128x256_t::tile_desc::block_size_x,
138 matAcc_128x256_t::tile_desc::block_size_y,
142 matAcc_128x64_t::tile_desc::tile_size_y,
143 matAcc_128x64_t::tile_desc::block_size_x,
144 matAcc_128x64_t::tile_desc::block_size_y,
177 subgroup::msg_type_v<matElem_tile_desc_t, mem_space::global>,
184 subgroup::msg_type_v<matElem_tile_desc_t, mem_space::global>,
216 int tru_seqlen_ex = 0;
217 int seqlen_entry = 0;
219 int groupid = item.get_group(0);
220 int hiddensize = 1024;
223 int wg_tile_QKT_k = hdsz;
225 int batchid = groupid / numhead;
226 int headid = groupid % numhead;
229 int tid_linear = item.get_local_linear_id();
230 g_thd32_tid.init(tid_linear);
232 uint32_t batch_offset =
sizeof(uint32_t) *
list_width * batchid;
234 = xetla_vector_gen<uint32_t, list_width>(0, 1);
235 list_offsets *=
sizeof(uint32_t);
236 list_offsets += batch_offset;
242 tru_seqlen = list_vec[0];
243 seqlen_entry = list_vec[1];
244 wg_tile_out_k = tru_seqlen;
245 tru_seqlen_ex = tru_seqlen;
247 tru_seqlen_ex = (((tru_seqlen + 1) >> 1) << 1);
249 tru_seqlen_ex = (((tru_seqlen + 3) >> 2) << 2);
254 if constexpr (Dopt_RandGenflag ==
true) {
255 uint64_t rand_seed = 67280421310721;
258 uint64_t rand_offset = list_vec.xetla_format<uint64_t>()[1];
259 if (list_vec[4] != 0) rand_threshold = list_vec[4];
260 if (rand_offset == 0) {
262 rand_offset = time_stamp.xetla_format<uint64_t>()[0];
264 Rand_Gen.
init(rand_seed, rand_subseq, rand_offset);
268 int all_vert_loop_num = 2;
269 int blk_128x128_one = 0;
270 int blk_128x256_loop_num = 1;
271 int offset_blk_128x128 = 0;
274 if (tru_seqlen <= 128) {
276 all_vert_loop_num = 1;
278 blk_128x256_loop_num = 0;
279 }
else if (tru_seqlen <= 256)
281 else if (tru_seqlen <= 384) {
283 all_vert_loop_num = 3;
285 blk_128x256_loop_num = 1;
286 offset_blk_128x128 = 256;
289 all_vert_loop_num = 4;
291 blk_128x256_loop_num = 2;
299 for (
int all_vert128_loop = 0; all_vert128_loop < all_vert_loop_num;
300 all_vert128_loop++) {
301 for (
int hor_256_loop = 0; hor_256_loop < blk_128x256_loop_num;
308 uint32_t width_a = (headid + 1) * hdsz;
309 uint32_t height_a = tru_seqlen + seqlen_entry;
310 uint32_t pitch_a = hiddensize;
311 int start_x_a = headid * hdsz;
312 int start_y_a = all_vert128_loop * 128 + seqlen_entry;
314 gemm_arg_128x256.matA_base_desc.init({args->
matQ_ptr},
315 {width_a, height_a, pitch_a}, {start_x_a, start_y_a});
317 uint32_t width_b = (headid + 1) * hdsz;
318 uint32_t height_b = tru_seqlen + seqlen_entry;
319 uint32_t pitch_b = hiddensize;
320 int start_x_b = headid * hdsz;
321 int start_y_b = hor_256_loop * 256 + seqlen_entry;
324 gemm_arg_128x256.matB_base_desc.init({args->
matK_ptr},
325 {height_b, width_b, pitch_b}, {start_y_b, start_x_b});
327 gemm_arg_128x256.inner_loop_count
330 matAcc_128x256.init(0);
333 gemm_op_128x256(g_thd32_tid, matAcc_128x256, gemm_arg_128x256);
337 =
max_seqlen * (batchid * numhead + headid + 1);
340 = gemm_op_128x256_t::get_matC_offset_x(g_thd32_tid)
341 + hor_256_loop * 256;
342 int start_y_c = (batchid * numhead + headid) *
max_seqlen
343 + all_vert128_loop * 128
344 + gemm_op_128x256_t::get_matC_offset_y(g_thd32_tid);
346 matC_128x256_payload.init(args->
matQKT_ptr, width_c, height_c,
347 pitch_c, start_x_c, start_y_c);
348 subgroup::elemwise_cvt<matC_128x256_t, matAcc_128x256_t>(
349 matC_128x256, matAcc_128x256);
351 xetla_fence<memory_kind::untyped_global>();
354 for (
int blk_128x128_loop = 0; blk_128x128_loop < blk_128x128_one;
355 blk_128x128_loop++) {
361 uint32_t width_a = (headid + 1) * hdsz;
362 uint32_t height_a = tru_seqlen + seqlen_entry;
363 uint32_t pitch_a = hiddensize;
364 int start_x_a = headid * hdsz;
365 int start_y_a = all_vert128_loop * 128 + seqlen_entry;
367 gemm_arg_128x128.matA_base_desc.init({args->
matQ_ptr},
368 {width_a, height_a, pitch_a}, {start_x_a, start_y_a});
370 uint32_t width_b = (headid + 1) * hdsz;
371 uint32_t height_b = tru_seqlen + seqlen_entry;
372 uint32_t pitch_b = hiddensize;
373 int start_x_b = headid * hdsz;
374 int start_y_b = offset_blk_128x128 + seqlen_entry;
377 gemm_arg_128x128.matB_base_desc.init({args->
matK_ptr},
378 {height_b, width_b, pitch_b}, {start_y_b, start_x_b});
380 gemm_arg_128x128.inner_loop_count
383 matAcc_128x128.init(0);
386 gemm_op_128x128(g_thd32_tid, matAcc_128x128, gemm_arg_128x128);
390 =
max_seqlen * (batchid * numhead + headid + 1);
392 int start_x_c = offset_blk_128x128
393 + gemm_op_128x128_t::get_matC_offset_x(g_thd32_tid);
394 int start_y_c = (batchid * numhead + headid) *
max_seqlen
395 + all_vert128_loop * 128
396 + gemm_op_128x128_t::get_matC_offset_y(g_thd32_tid);
398 matC_128x128_payload.init(args->
matQKT_ptr, width_c, height_c,
399 pitch_c, start_x_c, start_y_c);
400 subgroup::elemwise_cvt<matC_128x128_t, matAcc_128x128_t>(
401 matC_128x128, matAcc_128x128);
403 xetla_fence<memory_kind::untyped_global>();
408 int elem_Ln512_loop_num = 4;
411 int height_elem_offset
413 + (all_vert128_loop * 128) + (tid_linear * 4))
415 int width_elem = width_8x16_512;
417 int pitch_elem = width_elem;
418 int start_x_elem = 0;
420 int bndy_mk_lp_start = (tru_seqlen + 31) >> 5;
422 = 32 - (bndy_mk_lp_start << 5) + tru_seqlen;
426 =
sizeof(uint32_t) * (
max_seqlen / 32) * (batchid);
428 = xetla_vector_gen<uint32_t, 16>(0, 1);
429 mk_attn_offsets *=
sizeof(uint32_t);
430 mk_attn_offsets += mk_attn_all;
436 uint32_t mk_offset_all =
sizeof(uint32_t) * (
max_seqlen / 32)
438 + (all_vert128_loop * 128) + tid_linear * 4);
440 = xetla_vector_gen<uint32_t, 16>(0, 1);
441 mk_offsets *=
sizeof(uint32_t);
442 mk_offsets += mk_offset_all;
447 for (
int elem_Ln512_loop = 0;
448 elem_Ln512_loop < elem_Ln512_loop_num;
458 start_y_elem = height_elem_offset
459 + elem_Ln512_loop * height_8x64_512;
460 height_elem = start_y_elem
463 matQKT_rd_payload.init(args->
matQKT_ptr, width_elem,
464 height_elem, pitch_elem, start_x_elem,
466 matQKT_st_payload.init(args->
matQKT_ptr, width_elem,
467 height_elem, pitch_elem, start_x_elem,
470 subgroup::tile_load<cache_hint::cached, cache_hint::cached>(
471 matQKT_rd, matQKT_rd_payload);
473 if constexpr (Dopt_RandGenflag ==
false) {
480 mk_offsets +=
sizeof(uint32_t) * (
max_seqlen / 32);
482 for (
int j = bndy_mk_lp_start; j < 16; j++)
483 mkin_vec16[j] = 0xFFFFFFFF;
484 if (bndy_mk_lp_shift < 32) {
485 uint32_t tmp = 0xFFFFFFFF;
486 tmp >>= bndy_mk_lp_shift;
487 tmp <<= bndy_mk_lp_shift;
488 mkin_vec16[bndy_mk_lp_start - 1] |= tmp;
494 = xetla_cvt<float, dtype_sfx>(matQKT_rd.
reg);
495 matQKT_reg16x32.
reg = matQKT_reg16x32.
reg * args->
Pinv;
498 for (
int j = 0; j < 16; j++) {
499 uint32_t mkdata_i = mkin_vec16[j];
501 = xetla_mask_int_gen<32>(mkdata_i);
502 matQKT_reg16x32.
reg.xetla_format<
float>()
503 .xetla_select<32, 1>(j * 32)
506 .xetla_format<
float>()
507 .xetla_select<32, 1>(j * 32),
514 for (
int j = 0; j < 32; j++) {
516 > matQKT_reg16x32.
reg.xetla_format<
float>()
517 .xetla_select<16, 1>(j * 16));
518 QKT_reg16_f.xetla_merge(QKT_reg16_f,
519 matQKT_reg16x32.
reg.xetla_format<
float>()
520 .xetla_select<16, 1>(j * 16),
525 = (QKT_reg16_f.xetla_select<8, 1>(0)
526 > QKT_reg16_f.xetla_select<8, 1>(8));
528 QKT_reg16_f.select<8, 1>(0),
529 QKT_reg16_f.select<8, 1>(8), filter_max8);
531 = (QKT_reg16_f.xetla_select<4, 1>(0)
532 > QKT_reg16_f.xetla_select<4, 1>(4));
534 QKT_reg16_f.select<4, 1>(0),
535 QKT_reg16_f.select<4, 1>(4), filter_max4);
537 = (QKT_reg16_f.xetla_select<2, 1>(0)
538 > QKT_reg16_f.xetla_select<2, 1>(2));
540 QKT_reg16_f.select<2, 1>(0),
541 QKT_reg16_f.select<2, 1>(2), filter_max2);
543 = (QKT_reg16_f.xetla_select<1, 1>(0)
544 > QKT_reg16_f.xetla_select<1, 1>(1));
546 QKT_reg16_f.xetla_select<1, 1>(0),
547 QKT_reg16_f.xetla_select<1, 1>(1), filter_max1);
550 float tmp_max = QKT_reg16_f[0];
551 matQKT_reg16x32.
reg = matQKT_reg16x32.
reg - tmp_max;
555 for (
int j = 0; j < 16; j++)
556 matQKT_reg16x32.
reg.xetla_format<
float>()
557 .xetla_select<32, 1>(j * 32)
558 = xetla_exp<float, 32>(
560 .xetla_format<
float>()
561 .xetla_select<32, 1>(j * 32));
563 QKT_reg16_f = matQKT_reg16x32.
reg.xetla_format<
float>()
564 .xetla_select<16, 1>(0)
565 + matQKT_reg16x32.
reg.xetla_format<
float>()
566 .xetla_select<16, 1>(16);
568 for (
int j = 2; j < 32; j++)
569 QKT_reg16_f = QKT_reg16_f
570 + matQKT_reg16x32.
reg.xetla_format<
float>()
571 .xetla_select<16, 1>(j * 16);
573 QKT_reg16_f.xetla_select<8, 1>(0)
574 += QKT_reg16_f.xetla_select<8, 1>(8);
575 QKT_reg16_f.xetla_select<4, 1>(0)
576 += QKT_reg16_f.xetla_select<4, 1>(4);
577 QKT_reg16_f.xetla_select<2, 1>(0)
578 += QKT_reg16_f.xetla_select<2, 1>(2);
579 QKT_reg16_f.xetla_select<1, 1>(0)
580 += QKT_reg16_f.xetla_select<1, 1>(1);
582 QKT_reg16_f.xetla_select<1, 1>(0) = xetla_inv<float, 1>(
583 QKT_reg16_f.xetla_select<1, 1>(0));
585 float tmp = QKT_reg16_f[0];
590 for (
int j = 0; j < 32; j++)
591 matQKT_reg16x32.
reg.xetla_format<
float>()
592 .xetla_select<16, 1>(j * 16)
606 matElem_reg_w_t drop_mk_w;
607 matElem_reg_b_t drop_mk_b;
609 if constexpr (Dopt_RandGenflag ==
true) {
610 matQKT_st.
reg = xetla_cvt<dtype_sfx, float>(
611 matQKT_reg16x32.
reg);
613 using matElem_reg_w_t
617 using matElem_reg_b_t
621 matElem_reg_w_t drop_mk_w;
622 matElem_reg_b_t drop_mk_b;
625 for (
int i = 0; i < (Max_SeqLen / (4 * 4 * RandSIMD));
627 rand_data = Rand_Gen.
rand();
628 rand_bit.xetla_select<4 * RandSIMD, 1>(
630 = rand_data > rand_threshold;
633 for (
int j = 0; j < 4; j++) {
636 drop_mk_w.reg.xetla_select<32, 1>(0)
638 rand_bit.xetla_select<32, 1>(
640 matQKT_st.
reg.xetla_format<uint16_t>()
641 .xetla_select<32, 1>(j * 32)
642 |= drop_mk_w.
reg.xetla_select<32, 1>(0);
645 drop_mk_b.reg.xetla_select<32, 1>(0)
647 rand_bit.xetla_select<32, 1>(
649 matQKT_st.
reg.xetla_format<uint8_t>()
650 .xetla_select<32, 1>(j * 32)
651 |= drop_mk_b.
reg.xetla_select<32, 1>(0);
655 if (std_seqlen > 128) {
658 i < (Max_SeqLen / (4 * 4 * RandSIMD));
660 rand_data = Rand_Gen.
rand();
661 rand_bit.xetla_select<4 * RandSIMD, 1>(
663 = rand_data > rand_threshold;
666 for (
int j = 4; j < 8; j++) {
668 drop_mk_w.reg.xetla_select<32, 1>(0)
670 rand_bit.xetla_select<32,
672 matQKT_st.
reg.xetla_format<uint16_t>()
673 .xetla_select<32, 1>(j * 32)
675 .xetla_select<32, 1>(0);
678 drop_mk_b.reg.xetla_select<32, 1>(0)
680 rand_bit.xetla_select<32,
682 matQKT_st.
reg.xetla_format<uint8_t>()
683 .xetla_select<32, 1>(j * 32)
685 .xetla_select<32, 1>(0);
689 if (std_seqlen > 256) {
692 i < (Max_SeqLen / (4 * 4 * RandSIMD));
694 rand_data = Rand_Gen.
rand();
695 rand_bit.xetla_select<4 * RandSIMD, 1>(
697 = rand_data > rand_threshold;
700 for (
int j = 8; j < 12; j++) {
702 drop_mk_w.reg.xetla_select<32, 1>(0)
704 rand_bit.xetla_select<
707 matQKT_st.
reg.xetla_format<uint16_t>()
708 .xetla_select<32, 1>(j * 32)
710 .xetla_select<32, 1>(
714 drop_mk_b.reg.xetla_select<32, 1>(0)
716 rand_bit.xetla_select<
719 matQKT_st.
reg.xetla_format<uint8_t>()
720 .xetla_select<32, 1>(j * 32)
722 .xetla_select<32, 1>(
726 if (std_seqlen > 384) {
729 < (Max_SeqLen / (4 * 4 * RandSIMD));
731 rand_data = Rand_Gen.
rand();
732 rand_bit.xetla_select<4 * RandSIMD, 1>(
734 = rand_data > rand_threshold;
737 for (
int j = 12; j < 16; j++) {
739 drop_mk_w.reg.xetla_select<32, 1>(0)
742 rand_bit.xetla_select<
747 .xetla_format<uint16_t>()
748 .xetla_select<32, 1>(j * 32)
754 drop_mk_b.reg.xetla_select<32, 1>(0)
756 rand_bit.xetla_select<
761 .xetla_format<uint8_t>()
762 .xetla_select<32, 1>(j * 32)
772 matQKT_st.
reg = xetla_cvt<dtype_sfx, float>(
773 matQKT_reg16x32.
reg);
775 for (
int j = 0; j < 16; j++) {
776 uint32_t mkdata_i = mkdpot_vec16[j];
778 = xetla_mask_int_gen<32>(mkdata_i);
780 drop_mk_w.reg.xetla_select<32, 1>(0)
782 matQKT_st.
reg.xetla_format<uint16_t>()
783 .xetla_select<32, 1>(j * 32)
784 |= drop_mk_w.
reg.xetla_select<32, 1>(0);
787 drop_mk_b.reg.xetla_select<32, 1>(0)
789 matQKT_st.
reg.xetla_format<uint8_t>()
790 .xetla_select<32, 1>(j * 32)
791 |= drop_mk_b.
reg.xetla_select<32, 1>(0);
795 matQKT_st.
reg = xetla_cvt<dtype_sfx, float>(
796 matQKT_reg16x32.
reg);
800 xetla_fence<memory_kind::untyped_global>();
814 uint32_t width_a = tru_seqlen_ex;
815 uint32_t height_a = (batchid * numhead + headid) *
max_seqlen
819 int start_y_a = (batchid * numhead + headid) *
max_seqlen
820 + all_vert128_loop * 128;
822 gemm_arg_128x64.matA_base_desc.init({args->
matQKT_ptr},
823 {width_a, height_a, pitch_a}, {start_x_a, start_y_a});
825 uint32_t width_b = (headid + 1) * hdsz;
826 uint32_t height_b = tru_seqlen + seqlen_entry;
827 uint32_t pitch_b = hiddensize;
828 int start_x_b = headid * hdsz;
829 int start_y_b = seqlen_entry;
831 gemm_arg_128x64.matB_base_desc.init({args->
matV_ptr},
832 {width_b, height_b, pitch_b}, {start_x_b, start_y_b});
834 gemm_arg_128x64.inner_loop_count
837 matAcc_128x64.init(0);
839 gemm_op_128x64(g_thd32_tid, matAcc_128x64, gemm_arg_128x64);
841 uint32_t width_c = (headid + 1) * hdsz;
842 uint32_t height_c = tru_seqlen + seqlen_entry;
843 uint32_t pitch_c = hiddensize;
844 int start_x_c = headid * hdsz
845 + gemm_op_128x64_t::get_matC_offset_x(g_thd32_tid);
846 int start_y_c = all_vert128_loop * 128 + seqlen_entry
847 + gemm_op_128x64_t::get_matC_offset_y(g_thd32_tid);
849 matC_128x64_payload.init(args->
matOut_ptr, width_c, height_c,
850 pitch_c, start_x_c, start_y_c);
851 subgroup::elemwise_cvt<matC_128x64_t, matAcc_128x64_t>(
852 matC_128x64, matAcc_128x64);
869template <
typename dtype_bwd_bin_,
typename dtype_bwd_bot_,
870 typename dtype_bwd_sfx_,
typename dtype_bwd_acc_,
int HWThreadNum,
871 bool Dopt_RandGenflag =
true,
bool Mkin_flag =
false>
982 typename gemm_op_128x64_trnp_a_t::arguments_t;
984 typename gemm_op_256x64_trnp_a_t::arguments_t;
986 typename gemm_op_128x64_trnp_af_t::arguments_t;
988 typename gemm_op_256x64_trnp_af_t::arguments_t;
1000 matAcc_128x128_t::tile_desc::tile_size_y,
1001 matAcc_128x128_t::tile_desc::block_size_x,
1002 matAcc_128x128_t::tile_desc::block_size_y,
1006 matAcc_128x256_t::tile_desc::tile_size_y,
1007 matAcc_128x256_t::tile_desc::block_size_x,
1008 matAcc_128x256_t::tile_desc::block_size_y,
1012 matAcc_128x64_t::tile_desc::tile_size_y,
1013 matAcc_128x64_t::tile_desc::block_size_x,
1014 matAcc_128x64_t::tile_desc::block_size_y,
1017 matAcc_128x64_trnp_a_t::tile_desc::tile_size_x,
1018 matAcc_128x64_trnp_a_t::tile_desc::tile_size_y,
1019 matAcc_128x64_trnp_a_t::tile_desc::block_size_x,
1022 matAcc_256x64_trnp_a_t::tile_desc::tile_size_x,
1023 matAcc_256x64_trnp_a_t::tile_desc::tile_size_y,
1024 matAcc_256x64_trnp_a_t::tile_desc::block_size_x,
1027 matAcc_128x64_trnp_af_t::tile_desc::tile_size_x,
1028 matAcc_128x64_trnp_af_t::tile_desc::tile_size_y,
1029 matAcc_128x64_trnp_af_t::tile_desc::block_size_x,
1030 matAcc_128x64_trnp_af_t::tile_desc::block_size_y,
1033 matAcc_256x64_trnp_af_t::tile_desc::tile_size_x,
1034 matAcc_256x64_trnp_af_t::tile_desc::tile_size_y,
1035 matAcc_256x64_trnp_af_t::tile_desc::block_size_x,
1036 matAcc_256x64_trnp_af_t::tile_desc::block_size_y,
1072 : subgroup::msg_type_v<
1120 subgroup::msg_type_v<matElem_tile_desc_t, mem_space::global>,
1157 int tru_seqlen_ex = 0;
1158 int seqlen_entry = 0;
1159 int hiddensize = 1024;
1162 int max_seqlen = 512;
1163 int wg_tile_QKT_k = hdsz;
1166 int groupid = item.get_group(0);
1167 int batchid = groupid / numhead;
1168 int headid = groupid % numhead;
1172 uint32_t batch_offset =
sizeof(uint32_t) *
list_width * batchid;
1174 = xetla_vector_gen<uint32_t, list_width>(0, 1);
1175 list_offsets *=
sizeof(uint32_t);
1176 list_offsets += batch_offset;
1181 tru_seqlen = list_vec[0];
1182 seqlen_entry = list_vec[1];
1183 wg_tile_out_k = tru_seqlen;
1184 tru_seqlen_ex = tru_seqlen;
1186 tru_seqlen_ex = ((tru_seqlen + 1) >> 1) << 1;
1188 tru_seqlen_ex = ((tru_seqlen + 3) >> 2) << 2;
1191 int all_vert_loop_num = 0;
1192 int transp128_loop_num = 0;
1193 int transp256_loop_num = 0;
1194 int blk_128x128_one = 0;
1195 int blk_128x256_loop_num = 0;
1196 int offset_blk_128x128 = 0;
1198 if (tru_seqlen <= 128) {
1200 all_vert_loop_num = 1;
1201 transp128_loop_num = 1;
1202 blk_128x128_one = 1;
1203 }
else if (tru_seqlen <= 256) {
1205 all_vert_loop_num = 2;
1206 transp256_loop_num = 1;
1207 blk_128x256_loop_num = 1;
1208 }
else if (tru_seqlen <= 384) {
1210 all_vert_loop_num = 3;
1211 transp128_loop_num = 1;
1212 transp256_loop_num = 1;
1213 blk_128x128_one = 1;
1214 blk_128x256_loop_num = 1;
1215 offset_blk_128x128 = 256;
1218 all_vert_loop_num = 4;
1219 transp256_loop_num = 2;
1220 blk_128x256_loop_num = 2;
1224 int tid_linear = item.get_local_linear_id();
1225 g_thd32_tid.init(tid_linear);
1227 static_assert(
ThreadNum == 32,
"All Thread Sync");
1241 for (
int transp128_loop = 0; transp128_loop < transp128_loop_num;
1248 uint32_t width_a = tru_seqlen_ex;
1250 = (batchid * numhead + headid) * max_seqlen + tru_seqlen;
1251 uint32_t pitch_a = max_seqlen;
1252 int start_x_a = transp128_loop * 128 + offset_blk_128x128;
1253 int start_y_a = (batchid * numhead + headid) * max_seqlen;
1255 gemm_arg_128x64.matA_base_desc.init({args->
matW_ptr},
1256 {height_a, width_a, pitch_a}, {start_y_a, start_x_a});
1258 uint32_t width_b = (headid + 1) * hdsz;
1259 uint32_t height_b = tru_seqlen + seqlen_entry;
1260 uint32_t pitch_b = hiddensize;
1261 int start_x_b = headid * hdsz;
1262 int start_y_b = seqlen_entry;
1264 gemm_arg_128x64.matB_base_desc.init({args->
matdO_ptr},
1265 {width_b, height_b, pitch_b}, {start_x_b, start_y_b});
1266 gemm_arg_128x64.inner_loop_count
1268 matAcc_128x64.init(0);
1271 gemm_op_128x64_trnp_af(g_thd32_tid, matAcc_128x64, gemm_arg_128x64);
1273 int width_c = (headid + 1) * hdsz;
1274 int height_c = tru_seqlen + seqlen_entry;
1275 int pitch_c = hiddensize;
1276 int start_x_c = headid * hdsz
1277 + gemm_op_128x64_trnp_af_t::get_matC_offset_x(g_thd32_tid);
1278 int start_y_c = transp128_loop * 128 + seqlen_entry
1279 + offset_blk_128x128
1280 + gemm_op_128x64_trnp_af_t::get_matC_offset_y(g_thd32_tid);
1282 matC_128x64_payload.init(args->
matdV_ptr, width_c, height_c,
1283 pitch_c, start_x_c, start_y_c);
1293 for (
int transp256_loop = 0; transp256_loop < transp256_loop_num;
1300 uint32_t width_a = tru_seqlen_ex;
1302 = (batchid * numhead + headid) * max_seqlen + tru_seqlen;
1303 uint32_t pitch_a = max_seqlen;
1304 int start_x_a = transp256_loop * 256;
1305 int start_y_a = (batchid * numhead + headid) * max_seqlen;
1306 gemm_arg_256x64.matA_base_desc.init({args->
matW_ptr},
1307 {height_a, width_a, pitch_a}, {start_y_a, start_x_a});
1309 uint32_t width_b = (headid + 1) * hdsz;
1310 uint32_t height_b = tru_seqlen + seqlen_entry;
1311 uint32_t pitch_b = hiddensize;
1312 int start_x_b = headid * hdsz;
1313 int start_y_b = seqlen_entry;
1314 gemm_arg_256x64.matB_base_desc.init({args->
matdO_ptr},
1315 {width_b, height_b, pitch_b}, {start_x_b, start_y_b});
1317 gemm_arg_256x64.inner_loop_count
1320 matAcc_256x64.init(0);
1323 gemm_op_256x64_trnp_af(g_thd32_tid, matAcc_256x64, gemm_arg_256x64);
1325 int width_c = (headid + 1) * hdsz;
1326 int height_c = tru_seqlen + seqlen_entry;
1327 int pitch_c = hiddensize;
1328 int start_x_c = headid * hdsz
1329 + gemm_op_256x64_trnp_af_t::get_matC_offset_x(g_thd32_tid);
1330 int start_y_c = transp256_loop * 256 + seqlen_entry
1331 + gemm_op_256x64_trnp_af_t::get_matC_offset_y(g_thd32_tid);
1333 matC_256x64_payload.init(args->
matdV_ptr, width_c, height_c,
1334 pitch_c, start_x_c, start_y_c);
1344 for (
int all_vert128_loop = 0; all_vert128_loop < all_vert_loop_num;
1345 all_vert128_loop++) {
1347 for (
int hor_256_loop = 0; hor_256_loop < blk_128x256_loop_num;
1354 uint32_t width_a = (headid + 1) * hdsz;
1355 uint32_t height_a = tru_seqlen + seqlen_entry;
1356 uint32_t pitch_a = hiddensize;
1357 int start_x_a = headid * hdsz;
1358 int start_y_a = all_vert128_loop * 128 + seqlen_entry;
1360 gemm_arg_128x256.matA_base_desc.init({args->
matdO_ptr},
1361 {width_a, height_a, pitch_a}, {start_x_a, start_y_a});
1363 uint32_t width_b = (headid + 1) * hdsz;
1364 uint32_t height_b = tru_seqlen + seqlen_entry;
1365 uint32_t pitch_b = hiddensize;
1366 int start_x_b = headid * hdsz;
1367 int start_y_b = hor_256_loop * 256 + seqlen_entry;
1370 gemm_arg_128x256.matB_base_desc.init({args->
matV_ptr},
1371 {height_b, width_b, pitch_b}, {start_y_b, start_x_b});
1373 gemm_arg_128x256.inner_loop_count
1376 matAcc_128x256.init(0);
1379 gemm_op_128x256(g_thd32_tid, matAcc_128x256, gemm_arg_128x256);
1381 int width_c = max_seqlen;
1382 int height_c = max_seqlen * (batchid * numhead + headid + 1);
1383 int pitch_c = max_seqlen;
1385 = gemm_op_128x256_t::get_matC_offset_x(g_thd32_tid)
1386 + hor_256_loop * 256;
1387 int start_y_c = (batchid * numhead + headid) * max_seqlen
1388 + all_vert128_loop * 128
1389 + gemm_op_128x256_t::get_matC_offset_y(g_thd32_tid);
1391 matC_128x256_payload.init(args->
matdW_ptr, width_c, height_c,
1392 pitch_c, start_x_c, start_y_c);
1393 subgroup::elemwise_cvt<matC_128x256_t, matAcc_128x256_t>(
1394 matC_128x256, matAcc_128x256);
1396 xetla_fence<memory_kind::untyped_global>();
1399 for (
int blk_128x128_loop = 0; blk_128x128_loop < blk_128x128_one;
1400 blk_128x128_loop++) {
1406 uint32_t width_a = (headid + 1) * hdsz;
1407 uint32_t height_a = tru_seqlen + seqlen_entry;
1408 uint32_t pitch_a = hiddensize;
1409 int start_x_a = headid * hdsz;
1410 int start_y_a = all_vert128_loop * 128 + seqlen_entry;
1412 gemm_arg_128x128.matA_base_desc.init({args->
matdO_ptr},
1413 {width_a, height_a, pitch_a}, {start_x_a, start_y_a});
1415 uint32_t width_b = (headid + 1) * hdsz;
1416 uint32_t height_b = tru_seqlen + seqlen_entry;
1417 uint32_t pitch_b = hiddensize;
1418 int start_x_b = headid * hdsz;
1419 int start_y_b = offset_blk_128x128 + seqlen_entry;
1422 gemm_arg_128x128.matB_base_desc.init({args->
matV_ptr},
1423 {height_b, width_b, pitch_b}, {start_y_b, start_x_b});
1425 gemm_arg_128x128.inner_loop_count
1428 matAcc_128x128.init(0);
1431 gemm_op_128x128(g_thd32_tid, matAcc_128x128, gemm_arg_128x128);
1433 int width_c = max_seqlen;
1434 int height_c = max_seqlen * (batchid * numhead + headid + 1);
1435 int pitch_c = max_seqlen;
1436 int start_x_c = offset_blk_128x128
1437 + gemm_op_128x128_t::get_matC_offset_x(g_thd32_tid);
1438 int start_y_c = (batchid * numhead + headid) * max_seqlen
1439 + all_vert128_loop * 128
1440 + gemm_op_128x128_t::get_matC_offset_y(g_thd32_tid);
1442 matC_128x128_payload.init(args->
matdW_ptr, width_c, height_c,
1443 pitch_c, start_x_c, start_y_c);
1444 subgroup::elemwise_cvt<matC_128x128_t, matAcc_128x128_t>(
1445 matC_128x128, matAcc_128x128);
1447 xetla_fence<memory_kind::untyped_global>();
1450 int elem_Ln512_loop_num = 4;
1453 int height_elem_offset
1454 = (max_seqlen * (batchid * numhead + headid)
1455 + (all_vert128_loop * 128) + (tid_linear * 4))
1457 int width_elem = width_8x16_512;
1459 int pitch_elem = width_elem;
1460 int start_x_elem = 0;
1464 if constexpr (Mkin_flag ==
true) {
1465 uint32_t mk_attn_all
1466 =
sizeof(uint32_t) * (max_seqlen / 32) * (batchid);
1468 = xetla_vector_gen<uint32_t, 16>(0, 1);
1469 mk_attn_offsets *=
sizeof(uint32_t);
1470 mk_attn_offsets += mk_attn_all;
1477 uint32_t mk_offset_all;
1479 = xetla_vector_gen<uint32_t, 16>(0, 1);
1480 if constexpr (Dopt_RandGenflag ==
false) {
1481 mk_offset_all =
sizeof(uint32_t) * (max_seqlen / 32)
1482 * ((batchid * numhead + headid) * max_seqlen
1483 + (all_vert128_loop * 128) + tid_linear * 4);
1484 mk_offsets *=
sizeof(uint32_t);
1485 mk_offsets += mk_offset_all;
1491 for (
int elem_Ln512_loop = 0; elem_Ln512_loop < elem_Ln512_loop_num;
1492 elem_Ln512_loop++) {
1503 start_y_elem = height_elem_offset
1504 + elem_Ln512_loop * height_8x64_512;
1508 matdW_rd_payload.init(args->
matdW_ptr, width_elem, height_elem,
1509 pitch_elem, start_x_elem, start_y_elem);
1510 matW_rd_payload.init(args->
matW_ptr, width_elem, height_elem,
1511 pitch_elem, start_x_elem, start_y_elem);
1512 matdW_st_payload.init(args->
matdW_ptr, width_elem, height_elem,
1513 pitch_elem, start_x_elem, start_y_elem);
1515 subgroup::tile_load<cache_hint::cached, cache_hint::cached>(
1516 matdW_rd, matdW_rd_payload);
1517 subgroup::tile_load<cache_hint::cached, cache_hint::cached>(
1518 matW_rd, matW_rd_payload);
1520 if constexpr (Dopt_RandGenflag ==
false) {
1525 mk_offsets +=
sizeof(uint32_t) * (max_seqlen / 32);
1528 matdW_reg16x32.
reg = xetla_cvt<float, dtype_sfx>(matdW_rd.
reg);
1529 matW_reg16x32.
reg = xetla_cvt<float, dtype_sfx>(matW_rd.
reg);
1531 if constexpr (Dopt_RandGenflag ==
false) {
1534 for (
int j = 0; j < 16; j++) {
1535 uint32_t mkdata_i = mkdpot_vec16[j];
1537 = xetla_mask_int_gen<32>(mkdata_i);
1538 matdW_reg16x32.
reg.xetla_format<
float>()
1539 .xetla_select<32, 1>(j * 32)
1541 matdW_reg16x32.
reg.xetla_format<
float>()
1542 .xetla_select<32, 1>(j * 32),
1545 matdW_reg16x32.
reg = matW_reg16x32.
reg * matdW_reg16x32.
reg;
1548 for (
int j = 0; j < 16; j++) {
1551 mask = matW_rd.
reg.xetla_format<int16_t>()
1552 .xetla_select<32, 1>(j * 32)
1554 matW_rd.
reg.xetla_format<uint16_t>()
1555 .xetla_select<32, 1>(j * 32)
1559 mask = matW_rd.
reg.xetla_format<int8_t>()
1560 .xetla_select<32, 1>(j * 32)
1562 matW_rd.
reg.xetla_format<uint8_t>()
1563 .xetla_select<32, 1>(j * 32)
1566 matW_reg16x32.
reg.xetla_format<
float>()
1567 .xetla_select<32, 1>(j * 32)
1568 .xetla_merge(0.0, mask);
1571 matdW_reg16x32.
reg = matW_reg16x32.
reg * matdW_reg16x32.
reg;
1574 = xetla_cvt<float, dtype_sfx>(matW_rd.
reg);
1579 = matdW_reg16x32.
reg.xetla_select<16, 1>(0);
1581 for (
int j = 1; j < 32; j++)
1583 + matdW_reg16x32.
reg.xetla_select<16, 1>(j * 16);
1585 mdw_sum.xetla_select<8, 1>(0) = mdw_sum.xetla_select<8, 1>(0)
1586 + mdw_sum.xetla_select<8, 1>(8);
1587 mdw_sum.xetla_select<4, 1>(0) = mdw_sum.xetla_select<4, 1>(0)
1588 + mdw_sum.xetla_select<4, 1>(4);
1589 mdw_sum.xetla_select<2, 1>(0) = mdw_sum.xetla_select<2, 1>(0)
1590 + mdw_sum.xetla_select<2, 1>(2);
1591 mdw_sum.xetla_select<1, 1>(0) = mdw_sum.xetla_select<1, 1>(0)
1592 + mdw_sum.xetla_select<1, 1>(1);
1594 float sumtmp = mdw_sum[0];
1595 matW_reg16x32.
reg = matW_reg16x32.
reg * sumtmp;
1598 matdW_reg16x32.
reg -= matW_reg16x32.
reg;
1600 matdW_reg16x32.
reg = matdW_reg16x32.
reg * args->
Pinv;
1602 if constexpr (Mkin_flag ==
true) {
1604 for (
int j = 0; j < 16; j++) {
1605 uint32_t mkdata_i = mkin_vec16[j];
1607 = xetla_mask_int_gen<32>(mkdata_i);
1608 matdW_reg16x32.
reg.xetla_format<
float>()
1609 .xetla_select<32, 1>(j * 32)
1611 matdW_reg16x32.
reg.xetla_format<
float>()
1612 .xetla_select<32, 1>(j * 32),
1617 matdW_st.
reg = xetla_cvt<dtype_sfx, float>(matdW_reg16x32.
reg);
1620 xetla_fence<memory_kind::untyped_global>();
1624 second_nbarr.
wait();
1632 uint32_t width_a = tru_seqlen_ex;
1633 uint32_t height_a = (batchid * numhead + headid) * max_seqlen
1635 uint32_t pitch_a = max_seqlen;
1637 int start_y_a = (batchid * numhead + headid) * max_seqlen
1638 + all_vert128_loop * 128;
1640 gemm_arg_128x64.matA_base_desc.init({args->
matdW_ptr},
1641 {width_a, height_a, pitch_a}, {start_x_a, start_y_a});
1643 uint32_t width_b = (headid + 1) * hdsz;
1644 uint32_t height_b = tru_seqlen + seqlen_entry;
1645 uint32_t pitch_b = hiddensize;
1646 int start_x_b = headid * hdsz;
1647 int start_y_b = seqlen_entry;
1649 gemm_arg_128x64.matB_base_desc.init({args->
matK_ptr},
1650 {width_b, height_b, pitch_b}, {start_x_b, start_y_b});
1652 gemm_arg_128x64.inner_loop_count
1655 matAcc_128x64.init(0);
1658 gemm_op_128x64(g_thd32_tid, matAcc_128x64, gemm_arg_128x64);
1660 int width_c = (headid + 1) * hdsz;
1661 int height_c = tru_seqlen + seqlen_entry;
1662 int pitch_c = hiddensize;
1663 int start_x_c = headid * hdsz
1664 + gemm_op_128x64_t::get_matC_offset_x(g_thd32_tid);
1665 int start_y_c = all_vert128_loop * 128 + seqlen_entry
1666 + gemm_op_128x64_t::get_matC_offset_y(g_thd32_tid);
1668 matC_128x64_payload.init(args->
matdQ_ptr, width_c, height_c,
1669 pitch_c, start_x_c, start_y_c);
1670 subgroup::elemwise_cvt<matC_128x64_t, matAcc_128x64_t>(
1671 matC_128x64, matAcc_128x64);
1676 for (
int transp256_loop = 0; transp256_loop < transp256_loop_num;
1683 uint32_t width_a = tru_seqlen_ex;
1685 = (batchid * numhead + headid) * max_seqlen + tru_seqlen;
1686 uint32_t pitch_a = max_seqlen;
1687 int start_x_a = transp256_loop * 256;
1688 int start_y_a = (batchid * numhead + headid) * max_seqlen;
1690 gemm_arg_256x64.matA_base_desc.init({args->
matdW_ptr},
1691 {height_a, width_a, pitch_a}, {start_y_a, start_x_a});
1693 uint32_t width_b = (headid + 1) * hdsz;
1694 uint32_t height_b = tru_seqlen + seqlen_entry;
1695 uint32_t pitch_b = hiddensize;
1696 int start_x_b = headid * hdsz;
1697 int start_y_b = seqlen_entry;
1699 gemm_arg_256x64.matB_base_desc.init({args->
matQ_ptr},
1700 {width_b, height_b, pitch_b}, {start_x_b, start_y_b});
1702 gemm_arg_256x64.inner_loop_count
1705 matAcc_256x64.init(0);
1707 gemm_op_256x64_trnp_a(g_thd32_tid, matAcc_256x64, gemm_arg_256x64);
1709 int width_c = (headid + 1) * hdsz;
1710 int height_c = tru_seqlen + seqlen_entry;
1711 int pitch_c = hiddensize;
1712 int start_x_c = headid * hdsz
1713 + gemm_op_256x64_trnp_a_t::get_matC_offset_x(g_thd32_tid);
1714 int start_y_c = transp256_loop * 256 + seqlen_entry
1715 + gemm_op_256x64_trnp_a_t::get_matC_offset_y(g_thd32_tid);
1717 matC_256x64_payload.init(args->
matdK_ptr, width_c, height_c,
1718 pitch_c, start_x_c, start_y_c);
1727 for (
int transp128_loop = 0; transp128_loop < transp128_loop_num;
1734 uint32_t width_a = tru_seqlen_ex;
1736 = (batchid * numhead + headid) * max_seqlen + tru_seqlen;
1737 uint32_t pitch_a = max_seqlen;
1738 int start_x_a = transp128_loop * 128 + offset_blk_128x128;
1739 int start_y_a = (batchid * numhead + headid) * max_seqlen;
1741 gemm_arg_128x64.matA_base_desc.init({args->
matdW_ptr},
1742 {height_a, width_a, pitch_a}, {start_y_a, start_x_a});
1744 uint32_t width_b = (headid + 1) * hdsz;
1745 uint32_t height_b = tru_seqlen + seqlen_entry;
1746 uint32_t pitch_b = hiddensize;
1747 int start_x_b = headid * hdsz;
1748 int start_y_b = seqlen_entry;
1750 gemm_arg_128x64.matB_base_desc.init({args->
matQ_ptr},
1751 {width_b, height_b, pitch_b}, {start_x_b, start_y_b});
1753 gemm_arg_128x64.inner_loop_count
1756 matAcc_128x64.init(0);
1759 gemm_op_128x64_trnp_a(g_thd32_tid, matAcc_128x64, gemm_arg_128x64);
1761 int width_c = (headid + 1) * hdsz;
1762 int height_c = tru_seqlen + seqlen_entry;
1763 int pitch_c = hiddensize;
1764 int start_x_c = headid * hdsz
1765 + gemm_op_128x64_trnp_a_t::get_matC_offset_x(g_thd32_tid);
1766 int start_y_c = transp128_loop * 128 + seqlen_entry
1767 + offset_blk_128x128
1768 + gemm_op_128x64_trnp_a_t::get_matC_offset_y(g_thd32_tid);
1770 matC_128x64_payload.init(args->
matdK_ptr, width_c, height_c,
1771 pitch_c, start_x_c, start_y_c);
Gemm functor.
Definition api.hpp:52
Definition limitation.hpp:738
#define __XETLA_API
Definition common.hpp:43
#define xetla_merge
xetla merge.
Definition base_ops.hpp:60
__ESIMD_NS::simd_mask< N > xetla_mask_int
wrapper for xetla_mask_int.
Definition base_types.hpp:172
__ESIMD_NS::simd< native_type_t< Ty >, N > xetla_vector
wrapper for xetla_vector.
Definition base_types.hpp:149
__ESIMD_NS::simd_mask< N > xetla_mask
wrapper for xetla_mask.
Definition base_types.hpp:165
__XETLA_API xetla_vector< Ty, N *NElts > xetla_load_global(Ty *p, xetla_vector< Toffset, N > offsets, xetla_mask< N > pred=1)
Stateless scattered load.
Definition memory.hpp:245
__XETLA_API xetla_vector< uint32_t, 4 > get_time_stamp()
Returns time stamp.
Definition misc.hpp:57
#define rand_threshold_const
Definition mha_attn_reg.hpp:26
#define list_width
Definition mha_attn_reg.hpp:25
#define SIGN_BIT_W16
Definition mha_attn_reg.hpp:28
#define SIGN_BIT_B8
Definition mha_attn_reg.hpp:29
Definition limitation.hpp:734
__XETLA_API std::enable_if_t<(T_src::register_layout !=reg_layout::linear) &&(T_dst::register_layout !=reg_layout::linear) &&is_same_layout< T_dst, T_src >::value &&(!is_floating_to_integer< T_dst, T_src >::value)> elemwise_cvt(T_dst &dst, T_src &src)
Is the element wise data conversion, the src and dst tile should have the same layout.
Definition op_function.hpp:40
__XETLA_API std::enable_if_t< detail::check_store_type< tile_t, payload_t >::is_global_2d_xe > tile_store(tile_t &tile, payload_t &payload)
Is the func storing data from register file to global memory.
Definition store_xe.hpp:91
mem_space
Definition common.hpp:77
gpu_arch
Definition common.hpp:73
msg_type
Definition common.hpp:78
mem_layout
Definition common.hpp:76
Compute attribute for gemm.
Definition common.hpp:32
Compute policy for xmx engine.
Definition compute_policy.hpp:35
Fine-tune knobs for gemm.
Definition common.hpp:43
Gemm default pre_processing functor.
Definition api.hpp:33
Gemm pre_processing functor with applying relu op to matA.
Definition api.hpp:39
Workgroup level tile shape description.
Definition tile_shape.hpp:34
Arguments for xetla_softmax_bwd_t::run.
Definition mha_core_attn.hpp:1130
uint32_t * matMkin_ptr
Definition mha_core_attn.hpp:1136
float Scaling
Definition mha_core_attn.hpp:1145
dtype_bin * matdO_ptr
Definition mha_core_attn.hpp:1138
dtype_sfx * matdW_ptr
Definition mha_core_attn.hpp:1140
dtype_bot * matdQ_ptr
Definition mha_core_attn.hpp:1142
dtype_bin * matK_ptr
Definition mha_core_attn.hpp:1134
dtype_bin * matV_ptr
Definition mha_core_attn.hpp:1135
dtype_sfx * matW_ptr
Definition mha_core_attn.hpp:1139
float Pinv
Definition mha_core_attn.hpp:1144
dtype_bin * matQ_ptr
Definition mha_core_attn.hpp:1133
uint32_t * mList_ptr
Definition mha_core_attn.hpp:1132
dtype_bot * matdK_ptr
Definition mha_core_attn.hpp:1143
dtype_bot * matdV_ptr
Definition mha_core_attn.hpp:1141
uint32_t * matMkdpot_ptr
Definition mha_core_attn.hpp:1137
Definition mha_core_attn.hpp:872
mem_desc_t< dtype_sfx, gemm_mem_layout_a, gemm_mem_space_a > mem_desc_a_out
Definition mha_core_attn.hpp:920
subgroup::tile_desc_t< matAcc_128x128_t::tile_desc::tile_size_x, matAcc_128x128_t::tile_desc::tile_size_y, matAcc_128x128_t::tile_desc::block_size_x, matAcc_128x128_t::tile_desc::block_size_y, reg_layout::tiled > matC_128x128_tile_desc_t
Definition mha_core_attn.hpp:1003
static constexpr int ThreadNum
Definition mha_core_attn.hpp:878
typename gemm_op_128x128_t::arguments_t gemm_arguments_128x128
Definition mha_core_attn.hpp:978
static constexpr mem_layout gemm_mem_layout_out_b
Definition mha_core_attn.hpp:897
typename gemm_op_128x64_t::arguments_t gemm_arguments_128x64
Definition mha_core_attn.hpp:980
typename gemm_op_128x64_trnp_a_t::matAcc_t matAcc_128x64_trnp_a_t
Definition mha_core_attn.hpp:993
subgroup::tile_t< dtype_bot, matC_128x64_trnp_af_tile_desc_t > matC_128x64_trnp_af_t
Definition mha_core_attn.hpp:1048
group::compute_policy_default_xmx< group::compute_attr_t< dtype_bin, dtype_bin, dtype_acc >, bgm_perf_tuning_knob, gpu_arch::Xe > compute_policy_QKT
Definition mha_core_attn.hpp:917
static constexpr mem_layout mem_layout_QKT_b
Definition mha_core_attn.hpp:886
static constexpr mem_layout mem_layout_a
Definition mha_core_attn.hpp:884
static constexpr mem_layout mem_layout_trnp_a
Definition mha_core_attn.hpp:885
mem_desc_t< dtype_bin, gemm_mem_layout_a, gemm_mem_space_a > mem_desc_a_QKT
Definition mha_core_attn.hpp:912
typename gemm_op_128x64_t::matAcc_t matAcc_128x64_t
Definition mha_core_attn.hpp:992
subgroup::tile_desc_t< matAcc_256x64_trnp_a_t::tile_desc::tile_size_x, matAcc_256x64_trnp_a_t::tile_desc::tile_size_y, matAcc_256x64_trnp_a_t::tile_desc::block_size_x, matAcc_256x64_trnp_a_t::tile_desc::block_size_y, reg_layout::tiled > matC_256x64_trnp_a_tile_desc_t
Definition mha_core_attn.hpp:1025
group::perf_tuning_knob_t< k_stride, prefetch_distance, periodic_sync_interval > bgm_perf_tuning_knob
Definition mha_core_attn.hpp:905
typename gemm_op_128x128_t::matAcc_t matAcc_128x128_t
Definition mha_core_attn.hpp:990
group::compute_policy_default_xmx< group::compute_attr_t< dtype_sfx, dtype_bin, dtype_acc >, bgm_perf_tuning_knob, gpu_arch::Xe > compute_policy_out
Definition mha_core_attn.hpp:925
static constexpr uint32_t periodic_sync_interval
Definition mha_core_attn.hpp:899
gpu::xetla::subgroup::tile_desc_t< 64/sfx_type_size, 8 *sfx_type_size, 64/sfx_type_size, 8 *sfx_type_size, reg_layout::tiled > matElem_tile_desc_t
Definition mha_core_attn.hpp:1112
static constexpr mem_layout gemm_mem_layout_QKT_b
Definition mha_core_attn.hpp:896
typename gemm_op_128x64_trnp_af_t::matAcc_t matAcc_128x64_trnp_af_t
Definition mha_core_attn.hpp:995
dtype_bwd_acc_ dtype_acc
Definition mha_core_attn.hpp:876
subgroup::tile_t< dtype_bot, matC_256x64_trnp_af_tile_desc_t > matC_256x64_trnp_af_t
Definition mha_core_attn.hpp:1050
static constexpr mem_space gemm_mem_space_b
Definition mha_core_attn.hpp:895
typename gemm_op_128x256_t::matAcc_t matAcc_128x256_t
Definition mha_core_attn.hpp:991
mem_desc_t< dtype_bin, gemm_mem_layout_out_b, gemm_mem_space_b > mem_desc_b_out_b_trnp_a
Definition mha_core_attn.hpp:930
static __XETLA_API void call(sycl::nd_item< 3 > &item, arguments_t *args)
Main execution function for fused mha softmax.
Definition mha_core_attn.hpp:1154
static constexpr uint16_t sfx_type_size
Definition mha_core_attn.hpp:936
mem_desc_t< dtype_bin, gemm_mem_layout_out_b, gemm_mem_space_b > mem_desc_b_out
Definition mha_core_attn.hpp:922
static constexpr mem_space gemm_mem_space_a
Definition mha_core_attn.hpp:890
static constexpr uint32_t k_stride
Definition mha_core_attn.hpp:903
group::tile_shape_t< 256, 128, 32, 32 > tile_attr_128x256
Definition mha_core_attn.hpp:907
typename gemm_op_128x64_trnp_af_t::arguments_t gemm_arguments_128x64_trnp_af
Definition mha_core_attn.hpp:986
static constexpr mem_layout gemm_mem_layout_trnp_a
Definition mha_core_attn.hpp:893
group::tile_shape_t< 64, 256, 16, 32 > tile_attr_256x64
Definition mha_core_attn.hpp:908
static constexpr mem_space mem_space_c
Definition mha_core_attn.hpp:882
typename gemm_op_256x64_trnp_a_t::arguments_t gemm_arguments_256x64_trnp_a
Definition mha_core_attn.hpp:984
typename gemm_op_256x64_trnp_af_t::matAcc_t matAcc_256x64_trnp_af_t
Definition mha_core_attn.hpp:996
subgroup::tile_t< dtype_bot, matC_128x64_trnp_a_tile_desc_t > matC_128x64_trnp_a_t
Definition mha_core_attn.hpp:1044
static constexpr mem_space gemm_mem_space_trnp_a
Definition mha_core_attn.hpp:891
static constexpr mem_layout gemm_mem_layout_a
Definition mha_core_attn.hpp:892
subgroup::tile_desc_t< matAcc_256x64_trnp_af_t::tile_desc::tile_size_x, matAcc_256x64_trnp_af_t::tile_desc::tile_size_y, matAcc_256x64_trnp_af_t::tile_desc::block_size_x, matAcc_256x64_trnp_af_t::tile_desc::block_size_y, reg_layout::tiled > matC_256x64_trnp_af_tile_desc_t
Definition mha_core_attn.hpp:1037
typename gemm_op_128x256_t::arguments_t gemm_arguments_128x256
Definition mha_core_attn.hpp:979
dtype_bwd_sfx_ dtype_sfx
Definition mha_core_attn.hpp:875
subgroup::tile_desc_t< matAcc_128x64_trnp_a_t::tile_desc::tile_size_x, matAcc_128x64_trnp_a_t::tile_desc::tile_size_y, matAcc_128x64_trnp_a_t::tile_desc::block_size_x, matAcc_128x64_trnp_a_t::tile_desc::block_size_y, reg_layout::tiled > matC_128x64_trnp_a_tile_desc_t
Definition mha_core_attn.hpp:1020
typename gemm_op_128x64_trnp_a_t::arguments_t gemm_arguments_128x64_trnp_a
Definition mha_core_attn.hpp:982
group::compute_policy_default_xmx< group::compute_attr_t< dtype_sfx, dtype_bin, dtype_acc >, bgm_perf_tuning_knob, gpu_arch::Xe > compute_policy_out_b_trnp_a
Definition mha_core_attn.hpp:933
static constexpr mem_space mem_space_a
Definition mha_core_attn.hpp:880
static constexpr uint32_t prefetch_distance
Definition mha_core_attn.hpp:900
dtype_bwd_bin_ dtype_bin
Definition mha_core_attn.hpp:873
static constexpr mem_space mem_space_b
Definition mha_core_attn.hpp:881
work_group_t< ThreadNum > work_group_t
Definition mha_core_attn.hpp:940
subgroup::tile_desc_t< matAcc_128x64_t::tile_desc::tile_size_x, matAcc_128x64_t::tile_desc::tile_size_y, matAcc_128x64_t::tile_desc::block_size_x, matAcc_128x64_t::tile_desc::block_size_y, reg_layout::tiled > matC_128x64_tile_desc_t
Definition mha_core_attn.hpp:1015
mem_desc_t< dtype_bin, gemm_mem_layout_QKT_b, gemm_mem_space_b > mem_desc_b_QKT
Definition mha_core_attn.hpp:914
static constexpr mem_layout mem_layout_c
Definition mha_core_attn.hpp:888
subgroup::tile_desc_t< matAcc_128x256_t::tile_desc::tile_size_x, matAcc_128x256_t::tile_desc::tile_size_y, matAcc_128x256_t::tile_desc::block_size_x, matAcc_128x256_t::tile_desc::block_size_y, reg_layout::tiled > matC_128x256_tile_desc_t
Definition mha_core_attn.hpp:1009
mem_desc_t< dtype_sfx, gemm_mem_layout_trnp_a, gemm_mem_space_trnp_a > mem_desc_a_out_b_trnp_a
Definition mha_core_attn.hpp:928
dtype_bwd_bot_ dtype_bot
Definition mha_core_attn.hpp:874
subgroup::tile_t< dtype_bot, matC_256x64_trnp_a_tile_desc_t > matC_256x64_trnp_a_t
Definition mha_core_attn.hpp:1046
typename gemm_op_256x64_trnp_af_t::arguments_t gemm_arguments_256x64_trnp_af
Definition mha_core_attn.hpp:988
subgroup::tile_desc_t< matAcc_128x64_trnp_af_t::tile_desc::tile_size_x, matAcc_128x64_trnp_af_t::tile_desc::tile_size_y, matAcc_128x64_trnp_af_t::tile_desc::block_size_x, matAcc_128x64_trnp_af_t::tile_desc::block_size_y, reg_layout::tiled > matC_128x64_trnp_af_tile_desc_t
Definition mha_core_attn.hpp:1031
group::tile_shape_t< 64, 128, 16, 16 > tile_attr_128x64
Definition mha_core_attn.hpp:909
static constexpr uint32_t global_kslicing
Definition mha_core_attn.hpp:935
typename gemm_op_256x64_trnp_a_t::matAcc_t matAcc_256x64_trnp_a_t
Definition mha_core_attn.hpp:994
static constexpr mem_layout mem_layout_out_b
Definition mha_core_attn.hpp:887
group::tile_shape_t< 128, 128, 32, 16 > tile_attr_128x128
Definition mha_core_attn.hpp:906
Arguments for xetla_softmax_fwd_t::run.
Definition mha_core_attn.hpp:193
dtype_bin * matK_ptr
Definition mha_core_attn.hpp:197
uint32_t * matMkin_ptr
Definition mha_core_attn.hpp:199
uint32_t * mList_ptr
Definition mha_core_attn.hpp:195
dtype_sfx * matQKT_ptr
Definition mha_core_attn.hpp:201
dtype_bin * matV_ptr
Definition mha_core_attn.hpp:198
dtype_bot * matOut_ptr
Definition mha_core_attn.hpp:202
dtype_bin * matQ_ptr
Definition mha_core_attn.hpp:196
uint32_t * matMkdpot_ptr
Definition mha_core_attn.hpp:200
float Pinv
Definition mha_core_attn.hpp:203
float Scaling
Definition mha_core_attn.hpp:204
Definition mha_core_attn.hpp:44
subgroup::tile_desc_t< matAcc_128x64_t::tile_desc::tile_size_x, matAcc_128x64_t::tile_desc::tile_size_y, matAcc_128x64_t::tile_desc::block_size_x, matAcc_128x64_t::tile_desc::block_size_y, reg_layout::tiled > matC_128x64_tile_desc_t
Definition mha_core_attn.hpp:145
mem_desc_t< dtype_bin, gemm_mem_layout_a, gemm_mem_space_a > mem_desc_a_QKT
Definition mha_core_attn.hpp:81
typename gemm_op_128x128_t::matAcc_t matAcc_128x128_t
Definition mha_core_attn.hpp:124
static constexpr int max_seqlen
Definition mha_core_attn.hpp:51
typename gemm_op_128x256_t::matAcc_t matAcc_128x256_t
Definition mha_core_attn.hpp:125
static constexpr uint16_t Rand_SIMD
Definition mha_core_attn.hpp:55
static constexpr mem_layout mem_layout_out_b
Definition mha_core_attn.hpp:59
static constexpr uint32_t global_kslicing
Definition mha_core_attn.hpp:96
static constexpr mem_layout mem_layout_QKT_b
Definition mha_core_attn.hpp:58
group::perf_tuning_knob_t< k_stride, prefetch_distance, periodic_sync_interval > bgm_perf_tuning_knob
Definition mha_core_attn.hpp:74
typename gemm_op_128x128_t::arguments_t gemm_arguments_128x128
Definition mha_core_attn.hpp:120
subgroup::tile_desc_t< matAcc_128x128_t::tile_desc::tile_size_x, matAcc_128x128_t::tile_desc::tile_size_y, matAcc_128x128_t::tile_desc::block_size_x, matAcc_128x128_t::tile_desc::block_size_y, reg_layout::tiled > matC_128x128_tile_desc_t
Definition mha_core_attn.hpp:133
group::compute_policy_default_xmx< group::compute_attr_t< dtype_bin, dtype_bin, dtype_acc >, bgm_perf_tuning_knob, gpu_arch::Xe > compute_policy_QKT
Definition mha_core_attn.hpp:86
mem_desc_t< dtype_bin, gemm_mem_layout_QKT_b, gemm_mem_space_b > mem_desc_b_QKT
Definition mha_core_attn.hpp:83
static constexpr uint32_t periodic_sync_interval
Definition mha_core_attn.hpp:69
typename gemm_op_128x256_t::arguments_t gemm_arguments_128x256
Definition mha_core_attn.hpp:121
static constexpr mem_space mem_space_b
Definition mha_core_attn.hpp:53
static constexpr mem_layout gemm_mem_layout_QKT_b
Definition mha_core_attn.hpp:66
work_group_t< ThreadNum > work_group_t
Definition mha_core_attn.hpp:101
dtype_bot_ dtype_bot
Definition mha_core_attn.hpp:46
dtype_acc_ dtype_acc
Definition mha_core_attn.hpp:48
static constexpr uint32_t k_stride
Definition mha_core_attn.hpp:72
static constexpr mem_space gemm_mem_space_b
Definition mha_core_attn.hpp:65
group::compute_policy_default_xmx< group::compute_attr_t< dtype_sfx, dtype_bin, dtype_acc >, bgm_perf_tuning_knob, gpu_arch::Xe > compute_policy_out
Definition mha_core_attn.hpp:94
dtype_sfx_ dtype_sfx
Definition mha_core_attn.hpp:47
static constexpr int ThreadNum
Definition mha_core_attn.hpp:50
static constexpr mem_layout mem_layout_a
Definition mha_core_attn.hpp:57
mem_desc_t< dtype_bin, gemm_mem_layout_out_b, gemm_mem_space_b > mem_desc_b_out
Definition mha_core_attn.hpp:91
gpu::xetla::subgroup::tile_desc_t< 64/sfx_type_size, 8 *sfx_type_size, 64/sfx_type_size, 8 *sfx_type_size, reg_layout::tiled > matElem_tile_desc_t
Definition mha_core_attn.hpp:171
group::tile_shape_t< 128, 128, 32, 16 > tile_attr_128x128
Definition mha_core_attn.hpp:76
static constexpr mem_layout gemm_mem_layout_out_b
Definition mha_core_attn.hpp:67
static constexpr mem_space gemm_mem_space_a
Definition mha_core_attn.hpp:62
static constexpr uint16_t sfx_type_size
Definition mha_core_attn.hpp:97
group::tile_shape_t< 256, 128, 32, 32 > tile_attr_128x256
Definition mha_core_attn.hpp:77
typename gemm_op_128x64_t::arguments_t gemm_arguments_128x64
Definition mha_core_attn.hpp:122
static constexpr uint32_t prefetch_distance
Definition mha_core_attn.hpp:70
typename gemm_op_128x64_t::matAcc_t matAcc_128x64_t
Definition mha_core_attn.hpp:126
static constexpr mem_space mem_space_c
Definition mha_core_attn.hpp:54
static constexpr mem_layout mem_layout_c
Definition mha_core_attn.hpp:60
subgroup::tile_desc_t< matAcc_128x256_t::tile_desc::tile_size_x, matAcc_128x256_t::tile_desc::tile_size_y, matAcc_128x256_t::tile_desc::block_size_x, matAcc_128x256_t::tile_desc::block_size_y, reg_layout::tiled > matC_128x256_tile_desc_t
Definition mha_core_attn.hpp:139
dtype_bin_ dtype_bin
Definition mha_core_attn.hpp:45
static constexpr mem_layout gemm_mem_layout_a
Definition mha_core_attn.hpp:63
static __XETLA_API void call(sycl::nd_item< 3 > &item, arguments_t *args)
Main execution function for fused mha softmax.
Definition mha_core_attn.hpp:213
mem_desc_t< dtype_sfx, gemm_mem_layout_a, gemm_mem_space_a > mem_desc_a_out
Definition mha_core_attn.hpp:89
group::tile_shape_t< 64, 128, 16, 16 > tile_attr_128x64
Definition mha_core_attn.hpp:78
static constexpr mem_space mem_space_a
Definition mha_core_attn.hpp:52
Definition memory_descriptor.hpp:139
Is to illustrate the memory information.
Definition api.hpp:44
Is to illustrate the tile information about a sub matrix.
Definition api.hpp:64
Is a struct contains some register file.
Definition api.hpp:99
xetla_vector< dtype, tile_desc::tile_elems > reg
Definition api.hpp:102
xetla nbarrier definition API.
Definition raw_send_nbarrier.hpp:43
__XETLA_API void arrive()
named barrier signal from subgroup.
Definition raw_send_nbarrier.hpp:65
__XETLA_API void init_nbarrier(uint8_t nbarrier_id, nbarrier_role role=nbarrier_role::producer_consumer)
Definition raw_send_nbarrier.hpp:55
__XETLA_API void wait()
named barrier wait within subgroup.
Definition raw_send_nbarrier.hpp:76
__XETLA_API xetla_vector< uint32_t, 4 *SIMD > rand()
Definition rand.hpp:57
__XETLA_API void init(uint64_t seed, uint64_t subseq, uint64_t offset)
Definition rand.hpp:38