19#include "common/common.hpp"
20#include "group/group.hpp"
21#include "subgroup/subgroup.hpp"
// 0x80000000: only bit 31 set (half of the unsigned 32-bit range).
// NOTE(review): presumably the default for `rand_threshold`, which is later
// compared against generated random words (`rand_data > rand_threshold`);
// the initialisation itself is not visible in this excerpt - confirm.
26#define rand_threshold_const 0x80000000
// Top (sign) bit of a 32-bit double word.
27#define SIGN_BIT_DW 0x80000000
// Top (sign) bit of a 16-bit word.
28#define SIGN_BIT_W16 0x8000
// Top (sign) bit of an 8-bit byte.
29#define SIGN_BIT_B8 0x80
31template <
typename dtype_bin_,
typename dtype_bot_,
typename dtype_sfx_,
32 typename dtype_acc_,
int HWThreadNum,
bool Dopt_RandGenflag =
true,
33 uint16_t RandSIMD = 16,
int Max_SeqLen = 2048>
150 matAcc_128x128_t::tile_desc::tile_size_y,
151 matAcc_128x128_t::tile_desc::block_size_x,
152 matAcc_128x128_t::tile_desc::block_size_y,
156 matAcc_128x256_t::tile_desc::tile_size_y,
157 matAcc_128x256_t::tile_desc::block_size_x,
158 matAcc_128x256_t::tile_desc::block_size_y,
162 matAcc_64x384_t::tile_desc::tile_size_y,
163 matAcc_64x384_t::tile_desc::block_size_x,
164 matAcc_64x384_t::tile_desc::block_size_y,
168 matAcc_64x512_t::tile_desc::tile_size_y,
169 matAcc_64x512_t::tile_desc::block_size_x,
170 matAcc_64x512_t::tile_desc::block_size_y,
174 matAcc_32x1024_t::tile_desc::tile_size_y,
175 matAcc_32x1024_t::tile_desc::block_size_x,
176 matAcc_32x1024_t::tile_desc::block_size_y,
180 matAcc_16x2048_t::tile_desc::tile_size_y,
181 matAcc_16x2048_t::tile_desc::block_size_x,
182 matAcc_16x2048_t::tile_desc::block_size_y,
186 matAcc_128x64_t::tile_desc::tile_size_y,
187 matAcc_128x64_t::tile_desc::block_size_x,
188 matAcc_128x64_t::tile_desc::block_size_y,
202 : subgroup::msg_type_v<
209 : subgroup::msg_type_v<
230 : subgroup::msg_type_v<
237 : subgroup::msg_type_v<
266 subgroup::msg_type_v<mat_128x128_tile_desc_t, mem_space_c>,
271 subgroup::msg_type_v<mat_128x256_tile_desc_t, mem_space_c>,
276 subgroup::msg_type_v<mat_64x384_tile_desc_t, mem_space_c>,
281 subgroup::msg_type_v<mat_64x512_tile_desc_t, mem_space_c>,
286 subgroup::msg_type_v<mat_32x1024_tile_desc_t, mem_space_c>,
291 subgroup::msg_type_v<mat_16x2048_tile_desc_t, mem_space_c>,
296 subgroup::msg_type_v<mat_128x64_tile_desc_t, mem_space_c>,
// Per-work-group setup: derive the (batch, head) this group handles and read
// the per-batch sequence-length record.
323 int tru_seqlen_ex = 0;
324 int seqlen_entry = 0;
// One work-group per (batch, head) pair: the group index is split below.
326 int groupid = item.get_group(0);
// Hard-coded value; used further down as the pitch of the Q/K/V/Out
// matrix descriptors.
327 int hiddensize = 1024;
330 int wg_tile_QKT_k = hdsz;
332 int batchid = groupid / numhead;
333 int headid = groupid % numhead;
// Linear thread id inside the work-group; also seeds the g_thd32_tid helper.
336 int tid_linear = item.get_local_linear_id();
337 g_thd32_tid.init(tid_linear);
// Byte offset of this batch's record: list_width uint32_t entries per batch.
339 uint32_t batch_offset =
sizeof(uint32_t) *
list_width * batchid;
341 = xetla_vector_gen<uint32_t, list_width>(0, 1);
// Scale the generated per-element offsets to bytes, then add the batch base.
// NOTE(review): assumes xetla_vector_gen(0, 1) yields 0, 1, 2, ... - confirm.
342 list_offsets *=
sizeof(uint32_t);
343 list_offsets += batch_offset;
// Record layout as read here: [0] = sequence length, [1] = starting row
// (seqlen_entry is added to row coordinates of the Q/K/V/Out descriptors).
349 tru_seqlen = list_vec[0];
350 seqlen_entry = list_vec[1];
351 wg_tile_out_k = tru_seqlen;
352 tru_seqlen_ex = tru_seqlen;
// Round the sequence length up to a multiple of 2 ...
354 tru_seqlen_ex = (((tru_seqlen + 1) >> 1) << 1);
// ... or up to a multiple of 4.
// NOTE(review): the conditions selecting between these two roundings are on
// lines not visible in this excerpt - presumably element-size dependent.
356 tru_seqlen_ex = (((tru_seqlen + 3) >> 2) << 2);
361 if constexpr (Dopt_RandGenflag ==
true) {
362 uint64_t rand_seed = 67280421310721;
365 uint64_t rand_offset = list_vec.xetla_format<uint64_t>()[1];
366 if (list_vec[4] != 0) rand_threshold = list_vec[4];
367 if (rand_offset == 0) {
369 rand_offset = time_stamp.xetla_format<uint64_t>()[0];
371 Rand_Gen.
init(rand_seed, rand_subseq, rand_offset);
// Tiling parameters chosen from the sequence-length bucket.
//   all_vert_loop_num : number of iterations of the main row loop
//   all_vert_stride   : rows advanced per iteration (row start is
//                       all_vert_loop * all_vert_stride)
//   all_vert128_shift : log2 of how many iterations make up 128 rows
//                       (128 -> 0, 64 -> 1, 32 -> 2, 16 -> 3)
//   block_16x16_num   : multiplier in the per-thread column span
//                       (tid_x * 16 * block_16x16_num)
// Defaults: two iterations of 128 rows each.
375 int all_vert_loop_num = 2;
376 int all_vert_stride = 128;
377 int all_vert128_shift = 0;
378 int block_16x16_num = 4;
// <= 128 rows: a single 128-row iteration.
382 if (tru_seqlen <= 128) {
385 all_vert_loop_num = 1;
387 }
// <= 256 rows: no overrides are visible in this excerpt, so the loop count and
// stride defaults above appear to apply - TODO confirm against the full file.
else if (tru_seqlen <= 256) {
390 }
// <= 384 rows: 64-row iterations; loop count = ceil(tru_seqlen / 64).
else if (tru_seqlen <= 384) {
392 all_vert_stride = 64;
393 all_vert128_shift = 1;
396 all_vert_loop_num = (tru_seqlen + all_vert_stride - 1) >> 6;
397 }
// <= 512 rows: 64-row iterations; loop count = ceil(tru_seqlen / 64).
else if (tru_seqlen <= 512) {
399 all_vert_stride = 64;
400 all_vert128_shift = 1;
402 all_vert_loop_num = (tru_seqlen + all_vert_stride - 1) >> 6;
403 }
// <= 1024 rows: 32-row iterations; loop count = ceil(tru_seqlen / 32).
else if (tru_seqlen <= 1024) {
405 all_vert_stride = 32;
406 all_vert128_shift = 2;
408 all_vert_loop_num = (tru_seqlen + all_vert_stride - 1) >> 5;
409 }
// <= 2048 rows: 16-row iterations; loop count = ceil(tru_seqlen / 16).
else if (tru_seqlen <= 2048) {
411 all_vert_stride = 16;
412 all_vert128_shift = 3;
414 all_vert_loop_num = (tru_seqlen + all_vert_stride - 1) >> 4;
// Round the loop count up to a multiple of (1 << all_vert128_shift), i.e. to a
// whole number of 128-row groups.
416 all_vert_loop_num = ((all_vert_loop_num + (1 << all_vert128_shift) - 1)
417 >> all_vert128_shift)
418 << all_vert128_shift;
// Split the linear thread id: the low tid_x_shift bits give tid_x, the
// remaining high bits give tid_y.
419 int tid_x = tid_linear & ((1 << tid_x_shift) - 1);
420 int tid_y = tid_linear >> tid_x_shift;
430 int valid_block_16x16_x = (tid_x + 1) * 16 * block_16x16_num;
432 int bndy_block_num = 0;
433 if (valid_block_16x16_x <= tru_seqlen)
434 valid_block_16x16_x = block_16x16_num;
436 bndy_block_num = valid_block_16x16_x;
437 valid_block_16x16_x = (tru_seqlen + 15 + 16 * block_16x16_num
438 - valid_block_16x16_x)
440 bndy_block_num = bndy_block_num
441 + (valid_block_16x16_x - block_16x16_num) * 16
446 = xetla_vector_gen<uint32_t, 16>(0, 1);
447 int attn_mk_address_offset
448 = (batchid * Max_SeqLen) + (tid_x * 16 * block_16x16_num);
449 address_attn_mk *=
sizeof(uint32_t);
450 address_attn_mk += attn_mk_address_offset;
451 attn_mk_4x16.xetla_format<uint32_t>().xetla_select<16, 1>(0)
456 for (
int i = 1; i <= bndy_block_num; i++)
457 attn_mk_4x16[valid_block_16x16_x * 16 - i] = 1;
460 for (
int all_vert_loop = 0; all_vert_loop < all_vert_loop_num;
465 bool valid_compute =
true;
467 if (((all_vert_loop * all_vert_stride + tid_y * 16) >= tru_seqlen)
468 || ((tid_x * 16 * block_16x16_num) >= tru_seqlen))
469 valid_compute =
false;
473 switch (std_seqlen) {
// First GEMM for the 128x128 case: operand A is read from args->matQ_ptr,
// operand B from args->matK_ptr.
// A surface: columns [headid * hdsz, (headid + 1) * hdsz) with a row pitch of
// hiddensize; the row bound is tru_seqlen + seqlen_entry.
478 uint32_t width_a = (headid + 1) * hdsz;
479 uint32_t height_a = tru_seqlen + seqlen_entry;
480 uint32_t pitch_a = hiddensize;
481 int start_x_a = headid * hdsz;
// Row start advances by all_vert_stride per loop iteration (the rest of this
// expression is on a line not visible in this excerpt).
482 int start_y_a = all_vert_loop * all_vert_stride
485 gemm_arg_128x128.matA_base_desc.init({args->
matQ_ptr},
486 {width_a, height_a, pitch_a},
487 {start_x_a, start_y_a});
// B surface: same head column slice, rows starting at seqlen_entry.
489 uint32_t width_b = (headid + 1) * hdsz;
490 uint32_t height_b = tru_seqlen + seqlen_entry;
491 uint32_t pitch_b = hiddensize;
492 int start_x_b = headid * hdsz;
493 int start_y_b = seqlen_entry;
// Unlike A, the B descriptor receives (height, width) and (start_y, start_x),
// i.e. with the two dimensions swapped.
// NOTE(review): presumably because B is consumed transposed - confirm.
496 gemm_arg_128x128.matB_base_desc.init({args->
matK_ptr},
497 {height_b, width_b, pitch_b},
498 {start_y_b, start_x_b});
500 gemm_arg_128x128.inner_loop_count
// Zero the accumulator before running the GEMM.
503 matAcc_128x128.init(0);
508 g_thd32_tid, matAcc_128x128, gemm_arg_128x128);
// Scale the accumulator by args->Pinv and store it into the first 16 * 32
// float elements of the shared element register.
510 matElem_reg_4x16x16.xetla_format<
float>()
511 .xetla_select<16 * 32, 1>(0)
512 = matAcc_128x128.reg * args->
Pinv;
518 uint32_t width_a = (headid + 1) * hdsz;
519 uint32_t height_a = tru_seqlen + seqlen_entry;
520 uint32_t pitch_a = hiddensize;
521 int start_x_a = headid * hdsz;
522 int start_y_a = all_vert_loop * all_vert_stride
525 gemm_arg_128x256.matA_base_desc.init({args->
matQ_ptr},
526 {width_a, height_a, pitch_a},
527 {start_x_a, start_y_a});
529 uint32_t width_b = (headid + 1) * hdsz;
530 uint32_t height_b = tru_seqlen + seqlen_entry;
531 uint32_t pitch_b = hiddensize;
532 int start_x_b = headid * hdsz;
533 int start_y_b = seqlen_entry;
536 gemm_arg_128x256.matB_base_desc.init({args->
matK_ptr},
537 {height_b, width_b, pitch_b},
538 {start_y_b, start_x_b});
540 gemm_arg_128x256.inner_loop_count
543 matAcc_128x256.init(0);
548 g_thd32_tid, matAcc_128x256, gemm_arg_128x256);
550 matElem_reg_4x16x16.xetla_format<
float>()
551 .xetla_select<4 * 16 * 16, 1>(0)
552 = matAcc_128x256.reg * args->
Pinv;
559 uint32_t width_a = (headid + 1) * hdsz;
560 uint32_t height_a = tru_seqlen + seqlen_entry;
561 uint32_t pitch_a = hiddensize;
562 int start_x_a = headid * hdsz;
563 int start_y_a = all_vert_loop * all_vert_stride
566 gemm_arg_64x384.matA_base_desc.init({args->
matQ_ptr},
567 {width_a, height_a, pitch_a},
568 {start_x_a, start_y_a});
570 uint32_t width_b = (headid + 1) * hdsz;
571 uint32_t height_b = tru_seqlen + seqlen_entry;
572 uint32_t pitch_b = hiddensize;
573 int start_x_b = headid * hdsz;
574 int start_y_b = seqlen_entry;
577 gemm_arg_64x384.matB_base_desc.init({args->
matK_ptr},
578 {height_b, width_b, pitch_b},
579 {start_y_b, start_x_b});
581 gemm_arg_64x384.inner_loop_count
584 matAcc_64x384.init(0);
588 g_thd32_tid, matAcc_64x384, gemm_arg_64x384);
590 matElem_reg_4x16x16.xetla_format<
float>()
591 .xetla_select<3 * 16 * 16, 1>(0)
592 = matAcc_64x384.reg * args->
Pinv;
599 uint32_t width_a = (headid + 1) * hdsz;
600 uint32_t height_a = tru_seqlen + seqlen_entry;
601 uint32_t pitch_a = hiddensize;
602 int start_x_a = headid * hdsz;
603 int start_y_a = all_vert_loop * all_vert_stride
606 gemm_arg_64x512.matA_base_desc.init({args->
matQ_ptr},
607 {width_a, height_a, pitch_a},
608 {start_x_a, start_y_a});
610 uint32_t width_b = (headid + 1) * hdsz;
611 uint32_t height_b = tru_seqlen + seqlen_entry;
612 uint32_t pitch_b = hiddensize;
613 int start_x_b = headid * hdsz;
614 int start_y_b = seqlen_entry;
617 gemm_arg_64x512.matB_base_desc.init({args->
matK_ptr},
618 {height_b, width_b, pitch_b},
619 {start_y_b, start_x_b});
621 gemm_arg_64x512.inner_loop_count
624 matAcc_64x512.init(0);
628 g_thd32_tid, matAcc_64x512, gemm_arg_64x512);
630 matElem_reg_4x16x16.xetla_format<
float>()
631 .xetla_select<4 * 16 * 16, 1>(0)
632 = matAcc_64x512.reg * args->
Pinv;
639 uint32_t width_a = (headid + 1) * hdsz;
640 uint32_t height_a = tru_seqlen + seqlen_entry;
641 uint32_t pitch_a = hiddensize;
642 int start_x_a = headid * hdsz;
643 int start_y_a = all_vert_loop * all_vert_stride
646 gemm_arg_32x1024.matA_base_desc.init({args->
matQ_ptr},
647 {width_a, height_a, pitch_a},
648 {start_x_a, start_y_a});
650 uint32_t width_b = (headid + 1) * hdsz;
651 uint32_t height_b = tru_seqlen + seqlen_entry;
652 uint32_t pitch_b = hiddensize;
653 int start_x_b = headid * hdsz;
654 int start_y_b = seqlen_entry;
657 gemm_arg_32x1024.matB_base_desc.init({args->
matK_ptr},
658 {height_b, width_b, pitch_b},
659 {start_y_b, start_x_b});
661 gemm_arg_32x1024.inner_loop_count
664 matAcc_32x1024.init(0);
667 g_thd32_tid, matAcc_32x1024, gemm_arg_32x1024);
669 matElem_reg_4x16x16.xetla_format<
float>()
670 .xetla_select<4 * 16 * 16, 1>(0)
671 = matAcc_32x1024.reg * args->
Pinv;
677 uint32_t width_a = (headid + 1) * hdsz;
678 uint32_t height_a = tru_seqlen + seqlen_entry;
679 uint32_t pitch_a = hiddensize;
680 int start_x_a = headid * hdsz;
681 int start_y_a = all_vert_loop * all_vert_stride
684 gemm_arg_16x2048.matA_base_desc.init({args->
matQ_ptr},
685 {width_a, height_a, pitch_a},
686 {start_x_a, start_y_a});
688 uint32_t width_b = (headid + 1) * hdsz;
689 uint32_t height_b = tru_seqlen + seqlen_entry;
690 uint32_t pitch_b = hiddensize;
691 int start_x_b = headid * hdsz;
692 int start_y_b = seqlen_entry;
695 gemm_arg_16x2048.matB_base_desc.init({args->
matK_ptr},
696 {height_b, width_b, pitch_b},
697 {start_y_b, start_x_b});
699 gemm_arg_16x2048.inner_loop_count
702 matAcc_16x2048.init(0);
705 g_thd32_tid, matAcc_16x2048, gemm_arg_16x2048);
707 matElem_reg_4x16x16.xetla_format<
float>()
708 .xetla_select<4 * 16 * 16, 1>(0)
709 = matAcc_16x2048.reg * args->
Pinv;
718 = xetla_vector_gen<uint32_t, 16>(0, 1);
720 = (batchid * numhead + headid) * Max_SeqLen
721 + all_vert_stride * all_vert_loop + tid_y * 16;
722 address_fmax += address_offset;
723 address_fmax *=
sizeof(float);
733 for (
int i = 0; i < 16; i++) {
734 matElem_reg_4x16x16.xetla_select<16, 1>(16 * i)
736 attn_mk_4x16.xetla_select<16,
742 = matElem_reg_4x16x16
743 .xetla_select<16 * 16, 1>(0);
746 if (valid_block_16x16_x > 1) {
748 for (
int i = 0; i < 16; i++) {
750 .xetla_select<16, 1>(
751 16 * i + 16 * 16 * 1)
753 attn_mk_4x16.xetla_select<16,
757 matElem_reg_Max.merge(
759 .xetla_select<16 * 16, 1>(
762 .xetla_select<16 * 16, 1>(
766 if (valid_block_16x16_x > 2) {
768 for (
int i = 0; i < 16; i++) {
770 .xetla_select<16, 1>(
771 16 * i + 16 * 16 * 2)
773 attn_mk_4x16.xetla_select<
777 matElem_reg_Max.merge(
779 .xetla_select<16 * 16, 1>(
781 matElem_reg_4x16x16.xetla_select<
782 16 * 16, 1>(16 * 16 * 2)
784 if (valid_block_16x16_x > 3) {
786 for (
int i = 0; i < 16; i++) {
788 .xetla_select<16, 1>(
789 16 * i + 16 * 16 * 3)
791 attn_mk_4x16.xetla_select<
795 matElem_reg_Max.merge(
797 .xetla_select<16 * 16, 1>(
799 matElem_reg_4x16x16.xetla_select<
800 16 * 16, 1>(16 * 16 * 3)
806 matElem_reg_Max_8.xetla_format<float, 16, 8>()
807 .xetla_select<16, 1, 8, 1>(0, 0)
808 .merge(matElem_reg_Max
809 .xetla_format<float, 16, 16>()
810 .xetla_select<16, 1, 8, 1>(
813 .xetla_format<float, 16, 16>()
814 .xetla_select<16, 1, 8, 1>(
816 matElem_reg_Max.xetla_format<
float, 16,
818 .xetla_select<16, 1, 8,
825 matElem_reg_Max_4.xetla_format<float, 16, 4>()
826 .xetla_select<16, 1, 4, 1>(0, 0)
827 .merge(matElem_reg_Max_8
828 .xetla_format<float, 16, 8>()
829 .xetla_select<16, 1, 4, 1>(
832 .xetla_format<float, 16, 8>()
833 .xetla_select<16, 1, 4, 1>(
838 .xetla_select<16, 1, 4,
845 matElem_reg_Max_2.xetla_format<float, 16, 2>()
846 .xetla_select<16, 1, 2, 1>(0, 0)
847 .merge(matElem_reg_Max_4
848 .xetla_format<float, 16, 4>()
849 .xetla_select<16, 1, 2, 1>(
852 .xetla_format<float, 16, 4>()
853 .xetla_select<16, 1, 2, 1>(
858 .xetla_select<16, 1, 2,
865 matElem_reg_max_local.xetla_format<float, 16, 1>()
866 .xetla_select<16, 1, 1, 1>(0, 0)
867 .merge(matElem_reg_Max_2
868 .xetla_format<float, 16, 2>()
869 .xetla_select<16, 1, 1, 1>(
872 .xetla_format<float, 16, 2>()
873 .xetla_select<16, 1, 1, 1>(
878 .xetla_select<16, 1, 1,
889 (uint64_t)args->
Max_ptr, address_fmax,
890 matElem_reg_max_local.xetla_select<16, 1>(0),
895 if constexpr (Dopt_RandGenflag ==
true) {
898 for (
int i = 0; i < ((16 * 16) / (2 * 4 * RandSIMD));
900 rand_data = Rand_Gen.
rand();
901 rand_bit.xetla_select<4 * RandSIMD, 1>(
903 = rand_data > rand_threshold;
912 16>(args->
Max_ptr, address_fmax);
914 auto matElem_reg_max_use = matElem_reg_max_global;
917 for (
int i = 0; i < 16; i++) {
918 matElem_reg_4x16x16.xetla_select<16 * 16, 1>(0)
919 .xetla_select<16, 1>(i * 16)
920 = matElem_reg_4x16x16
921 .xetla_select<16 * 16, 1>(0)
922 .xetla_select<16, 1>(i * 16)
923 - matElem_reg_max_use[i];
925 matElem_reg_4x16x16.xetla_select<16 * 16, 1>(0)
926 .xetla_select<16, 1>(i * 16)
927 = xetla_exp<float, 16>(
929 .xetla_select<16 * 16, 1>(0)
930 .xetla_select<16, 1>(
934 if (valid_block_16x16_x > 1) {
936 for (
int i = 0; i < 16; i++) {
938 .xetla_select<16 * 16, 1>(16 * 16 * 1)
939 .xetla_select<16, 1>(i * 16)
940 = matElem_reg_4x16x16
941 .xetla_select<16 * 16, 1>(
943 .xetla_select<16, 1>(i * 16)
944 - matElem_reg_max_use[i];
947 .xetla_select<16 * 16, 1>(16 * 16 * 1)
948 .xetla_select<16, 1>(i * 16)
949 = xetla_exp<float, 16>(
953 .xetla_select<16, 1>(
957 if (valid_block_16x16_x > 2) {
959 for (
int i = 0; i < 16; i++) {
961 .xetla_select<16 * 16, 1>(
963 .xetla_select<16, 1>(i * 16)
964 = matElem_reg_4x16x16
965 .xetla_select<16 * 16, 1>(
967 .xetla_select<16, 1>(
969 - matElem_reg_max_use[i];
971 .xetla_select<16 * 16, 1>(
973 .xetla_select<16, 1>(i * 16)
974 = xetla_exp<float, 16>(
983 if (valid_block_16x16_x > 3) {
985 for (
int i = 0; i < 16; i++) {
987 .xetla_select<16 * 16, 1>(
989 .xetla_select<16, 1>(i * 16)
990 = matElem_reg_4x16x16
991 .xetla_select<16 * 16,
994 .xetla_select<16, 1>(
996 - matElem_reg_max_use[i];
998 .xetla_select<16 * 16, 1>(
1000 .xetla_select<16, 1>(i * 16)
1001 = xetla_exp<float, 16>(
1026 = matElem_reg_4x16x16.xetla_select<16 * 16, 1>(
1029 if (valid_block_16x16_x > 1) {
1031 += matElem_reg_4x16x16
1032 .xetla_select<16 * 16, 1>(
1034 if (valid_block_16x16_x > 2) {
1036 += matElem_reg_4x16x16
1037 .xetla_select<16 * 16, 1>(
1039 if (valid_block_16x16_x > 3)
1041 += matElem_reg_4x16x16.xetla_select<
1042 16 * 16, 1>(16 * 16 * 3);
1045 matElem_reg_Sum_8.xetla_format<float, 16, 8>()
1046 = matElem_reg_Sum.xetla_format<
float, 16, 16>()
1047 .xetla_select<16, 1, 8, 1>(0, 0)
1048 + matElem_reg_Sum.xetla_format<
float, 16, 16>()
1049 .xetla_select<16, 1, 8, 1>(0, 8);
1051 matElem_reg_Sum_4.xetla_format<float, 16, 4>()
1052 = matElem_reg_Sum_8.xetla_format<
float, 16, 8>()
1053 .xetla_select<16, 1, 4, 1>(0, 0)
1054 + matElem_reg_Sum_8.xetla_format<
float, 16, 8>()
1055 .xetla_select<16, 1, 4, 1>(0, 4);
1057 matElem_reg_Sum_2.xetla_format<float, 16, 2>()
1058 = matElem_reg_Sum_4.xetla_format<
float, 16, 4>()
1059 .xetla_select<16, 1, 2, 1>(0, 0)
1060 + matElem_reg_Sum_4.xetla_format<
float, 16, 4>()
1061 .xetla_select<16, 1, 2, 1>(0, 2);
1063 matElem_reg_Sum_1.xetla_format<float, 16, 1>()
1064 = matElem_reg_Sum_2.xetla_format<
float, 16, 2>()
1065 .xetla_select<16, 1, 1, 1>(0, 0)
1066 + matElem_reg_Sum_2.xetla_format<
float, 16, 2>()
1067 .xetla_select<16, 1, 1, 1>(0, 1);
1072 (uint64_t)args->
Sum_ptr, address_fmax,
1073 matElem_reg_Sum_1.xetla_select<16, 1>(0), pred);
1077 if constexpr (Dopt_RandGenflag ==
true) {
1080 for (
int i = ((16 * 16) / (2 * 4 * RandSIMD));
1081 i < ((16 * 16) / (4 * RandSIMD)); i++) {
1082 rand_data = Rand_Gen.
rand();
1083 rand_bit.xetla_select<4 * RandSIMD, 1>(
1085 = rand_data > rand_threshold;
1088 second_nbarr.
wait();
1094 16>(args->
Sum_ptr, address_fmax);
1097 = xetla_inv<float, 16>(matElem_reg_Sum_1);
1098 matElem_reg_Sum_1 *= args->
Scaling;
1101 for (
int i = 0; i < 16; i++) {
1102 matElem_reg_4x16x16.xetla_select<16 * 16, 1>(0)
1103 .xetla_select<16, 1>(i * 16)
1104 = matElem_reg_4x16x16
1105 .xetla_select<16 * 16, 1>(0)
1106 .xetla_select<16, 1>(i * 16)
1107 * matElem_reg_Sum_1[i];
1110 if (valid_block_16x16_x > 1) {
1112 for (
int i = 0; i < 16; i++) {
1114 .xetla_select<16 * 16, 1>(16 * 16 * 1)
1115 .xetla_select<16, 1>(i * 16)
1116 = matElem_reg_4x16x16
1117 .xetla_select<16 * 16, 1>(
1119 .xetla_select<16, 1>(i * 16)
1120 * matElem_reg_Sum_1[i];
1123 if constexpr (Dopt_RandGenflag ==
true) {
1127 i < ((16 * 16) / (4 * RandSIMD)); i++) {
1128 rand_data = Rand_Gen.
rand();
1129 rand_bit.xetla_select<4 * RandSIMD, 1>(
1130 (i * (4 * RandSIMD))
1132 = rand_data > rand_threshold;
1136 if (valid_block_16x16_x > 2) {
1138 for (
int i = 0; i < 16; i++) {
1140 .xetla_select<16 * 16, 1>(
1142 .xetla_select<16, 1>(i * 16)
1143 = matElem_reg_4x16x16
1144 .xetla_select<16 * 16, 1>(
1146 .xetla_select<16, 1>(
1148 * matElem_reg_Sum_1[i];
1151 if constexpr (Dopt_RandGenflag ==
true) {
1156 i < ((16 * 16) / (4 * RandSIMD));
1158 rand_data = Rand_Gen.
rand();
1159 rand_bit.xetla_select<4 * RandSIMD, 1>(
1160 (i * (4 * RandSIMD))
1162 = rand_data > rand_threshold;
1166 if (valid_block_16x16_x > 3) {
1168 for (
int i = 0; i < 16; i++) {
1170 .xetla_select<16 * 16, 1>(
1172 .xetla_select<16, 1>(i * 16)
1173 = matElem_reg_4x16x16
1174 .xetla_select<16 * 16,
1177 .xetla_select<16, 1>(
1179 * matElem_reg_Sum_1[i];
1182 if constexpr (Dopt_RandGenflag ==
true) {
1187 < ((16 * 16) / (4 * RandSIMD));
1189 rand_data = Rand_Gen.
rand();
1190 rand_bit.xetla_select<4 * RandSIMD,
1191 1>((i * (4 * RandSIMD))
1204 switch (std_seqlen) {
1213 =
max_seqlen * (batchid * numhead + headid + 1);
1215 int start_x_c = gemm_op_128x128_t::get_matC_offset_x(
1219 + all_vert_loop * all_vert_stride
1220 + gemm_op_128x128_t::get_matC_offset_y(
1222 matC_128x128_payload.init(args->
matQKT_ptr, width_c,
1223 height_c, pitch_c, start_x_c, start_y_c);
1225 if constexpr (Dopt_RandGenflag ==
false) {
1226 uint8_t *matMkdpot_byte_ptr
1228 matDpotMk_128x128_payload.init(matMkdpot_byte_ptr,
1229 width_c, height_c, pitch_c, start_x_c,
1232 matDpotMk_128x128_payload);
1236 = matElem_reg_4x16x16.xetla_format<
float>()
1237 .xetla_select<16 * 32, 1>(0);
1238 matC_128x128.
reg = xetla_cvt<dtype_sfx, float>(
1241 if constexpr (Dopt_RandGenflag ==
false) {
1242 rand_bit.xetla_select<16 * 16 * 2, 1>(0)
1243 = matDpotMk_128x128.
reg;
1249 rand_bit.xetla_select<16 * 16 * 2, 1>(0)
1251 matC_128x128.
reg.xetla_format<uint16_t>()
1257 rand_bit.xetla_select<16 * 16 * 2, 1>(0)
1259 matC_128x128.
reg.xetla_format<uint8_t>()
1264 matC_128x128, matC_128x128_payload);
1265 xetla_fence<memory_kind::untyped_global>();
1276 =
max_seqlen * (batchid * numhead + headid + 1);
1278 int start_x_c = gemm_op_128x256_t::get_matC_offset_x(
1282 + all_vert_loop * all_vert_stride
1283 + gemm_op_128x256_t::get_matC_offset_y(
1286 matC_128x256_payload.init(args->
matQKT_ptr, width_c,
1287 height_c, pitch_c, start_x_c, start_y_c);
1289 if constexpr (Dopt_RandGenflag ==
false) {
1290 uint8_t *matMkdpot_byte_ptr
1292 matDpotMk_128x256_payload.init(matMkdpot_byte_ptr,
1293 width_c, height_c, pitch_c, start_x_c,
1296 matDpotMk_128x256_payload);
1299 matC_128x256.
reg = xetla_cvt<dtype_sfx, float>(
1300 matElem_reg_4x16x16);
1302 if constexpr (Dopt_RandGenflag ==
false) {
1303 rand_bit = matDpotMk_128x256.
reg;
1309 rand_bit.xetla_select<16 * 16 * 4, 1>(0)
1311 matC_128x256.
reg.xetla_format<uint16_t>()
1317 rand_bit.xetla_select<16 * 16 * 4, 1>(0)
1319 matC_128x256.
reg.xetla_format<uint8_t>()
1324 matC_128x256, matC_128x256_payload);
1325 xetla_fence<memory_kind::untyped_global>();
1336 =
max_seqlen * (batchid * numhead + headid + 1);
1338 int start_x_c = gemm_op_64x384_t::get_matC_offset_x(
1342 + all_vert_loop * all_vert_stride
1343 + gemm_op_64x384_t::get_matC_offset_y(
1346 matC_64x384_payload.init(args->
matQKT_ptr, width_c,
1347 height_c, pitch_c, start_x_c, start_y_c);
1349 if constexpr (Dopt_RandGenflag ==
false) {
1350 uint8_t *matMkdpot_byte_ptr
1352 matDpotMk_64x384_payload.init(matMkdpot_byte_ptr,
1353 width_c, height_c, pitch_c, start_x_c,
1356 matDpotMk_64x384, matDpotMk_64x384_payload);
1360 = matElem_reg_4x16x16.xetla_format<
float>()
1361 .xetla_select<3 * 16 * 16, 1>(0);
1362 matC_64x384.
reg = xetla_cvt<dtype_sfx, float>(
1365 if constexpr (Dopt_RandGenflag ==
false) {
1366 rand_bit.xetla_select<16 * 16 * 3, 1>(0)
1367 = matDpotMk_64x384.
reg;
1373 rand_bit.xetla_select<16 * 16 * 3, 1>(0)
1375 matC_64x384.
reg.xetla_format<uint16_t>()
1381 rand_bit.xetla_select<16 * 16 * 3, 1>(0)
1383 matC_64x384.
reg.xetla_format<uint8_t>()
1388 xetla_fence<memory_kind::untyped_global>();
1398 =
max_seqlen * (batchid * numhead + headid + 1);
1400 int start_x_c = gemm_op_64x512_t::get_matC_offset_x(
1404 + all_vert_loop * all_vert_stride
1405 + gemm_op_64x512_t::get_matC_offset_y(
1407 matC_64x512_payload.init(args->
matQKT_ptr, width_c,
1408 height_c, pitch_c, start_x_c, start_y_c);
1410 if constexpr (Dopt_RandGenflag ==
false) {
1411 uint8_t *matMkdpot_byte_ptr
1413 matDpotMk_64x512_payload.init(matMkdpot_byte_ptr,
1414 width_c, height_c, pitch_c, start_x_c,
1417 matDpotMk_64x512, matDpotMk_64x512_payload);
1420 matC_64x512.
reg = xetla_cvt<dtype_sfx, float>(
1421 matElem_reg_4x16x16);
1423 if constexpr (Dopt_RandGenflag ==
false) {
1424 rand_bit = matDpotMk_64x512.
reg;
1430 rand_bit.xetla_select<16 * 16 * 4, 1>(0)
1432 matC_64x512.
reg.xetla_format<uint16_t>()
1438 rand_bit.xetla_select<16 * 16 * 4, 1>(0)
1440 matC_64x512.
reg.xetla_format<uint8_t>()
1445 xetla_fence<memory_kind::untyped_global>();
1455 =
max_seqlen * (batchid * numhead + headid + 1);
1457 int start_x_c = gemm_op_32x1024_t::get_matC_offset_x(
1461 + all_vert_loop * all_vert_stride
1462 + gemm_op_32x1024_t::get_matC_offset_y(
1465 matC_32x1024_payload.init(args->
matQKT_ptr, width_c,
1466 height_c, pitch_c, start_x_c, start_y_c);
1468 if constexpr (Dopt_RandGenflag ==
false) {
1469 uint8_t *matMkdpot_byte_ptr
1471 matDpotMk_32x1024_payload.init(matMkdpot_byte_ptr,
1472 width_c, height_c, pitch_c, start_x_c,
1475 matDpotMk_32x1024_payload);
1478 matC_32x1024.
reg = xetla_cvt<dtype_sfx, float>(
1479 matElem_reg_4x16x16);
1481 if constexpr (Dopt_RandGenflag ==
false) {
1482 rand_bit = matDpotMk_32x1024.
reg;
1488 rand_bit.xetla_select<16 * 16 * 4, 1>(0)
1490 matC_32x1024.
reg.xetla_format<uint16_t>()
1496 rand_bit.xetla_select<16 * 16 * 4, 1>(0)
1498 matC_32x1024.
reg.xetla_format<uint8_t>()
1503 matC_32x1024, matC_32x1024_payload);
1504 xetla_fence<memory_kind::untyped_global>();
1514 =
max_seqlen * (batchid * numhead + headid + 1);
1516 int start_x_c = gemm_op_16x2048_t::get_matC_offset_x(
1520 + all_vert_loop * all_vert_stride
1521 + gemm_op_16x2048_t::get_matC_offset_y(
1524 matC_16x2048_payload.init(args->
matQKT_ptr, width_c,
1525 height_c, pitch_c, start_x_c, start_y_c);
1527 if constexpr (Dopt_RandGenflag ==
false) {
1528 uint8_t *matMkdpot_byte_ptr
1530 matDpotMk_16x2048_payload.init(matMkdpot_byte_ptr,
1531 width_c, height_c, pitch_c, start_x_c,
1534 matDpotMk_16x2048_payload);
1537 matC_16x2048.
reg = xetla_cvt<dtype_sfx, float>(
1538 matElem_reg_4x16x16);
1540 if constexpr (Dopt_RandGenflag ==
false) {
1541 rand_bit = matDpotMk_16x2048.
reg;
1547 rand_bit.xetla_select<16 * 16 * 4, 1>(0)
1549 matC_16x2048.
reg.xetla_format<uint16_t>()
1555 rand_bit.xetla_select<16 * 16 * 4, 1>(0)
1557 matC_16x2048.
reg.xetla_format<uint8_t>()
1562 matC_16x2048, matC_16x2048_payload);
1563 xetla_fence<memory_kind::untyped_global>();
1572 second_nbarr.
wait();
1576 int all_vert128_loop = all_vert_loop >> all_vert128_shift;
1577 if (((((all_vert128_loop + 1) << all_vert128_shift) - 1)
1579 || (all_vert128_shift == 0)) {
// Second GEMM of the forward pass: operand A is read back from
// args->matQKT_ptr, operand B from args->matV_ptr, and the result is written
// through a payload on args->matOut_ptr.
// A surface: width is the rounded-up sequence length; rows are addressed in
// per-(batch, head) slabs of max_seqlen rows.
1589 uint32_t width_a = tru_seqlen_ex;
1590 uint32_t height_a = (batchid * numhead + headid) *
max_seqlen
1594 int start_y_a = (batchid * numhead + headid) *
max_seqlen
1595 + all_vert128_loop * 128;
1597 gemm_arg_128x64.matA_base_desc.init({args->
matQKT_ptr},
1598 {width_a, height_a, pitch_a}, {start_x_a, start_y_a});
// B surface: this head's hdsz-wide column slice, rows starting at
// seqlen_entry, with a row pitch of hiddensize.
1600 uint32_t width_b = (headid + 1) * hdsz;
1601 uint32_t height_b = tru_seqlen + seqlen_entry;
1602 uint32_t pitch_b = hiddensize;
1603 int start_x_b = headid * hdsz;
1604 int start_y_b = seqlen_entry;
1606 gemm_arg_128x64.matB_base_desc.init({args->
matV_ptr},
1607 {width_b, height_b, pitch_b}, {start_x_b, start_y_b});
1609 gemm_arg_128x64.inner_loop_count
// Zero the accumulator, then run the 128x64 GEMM.
1612 matAcc_128x64.init(0);
1616 gemm_op_128x64(g_thd32_tid, matAcc_128x64, gemm_arg_128x64);
// Output surface: same head column slice; each thread's tile origin is the
// 128-row group base plus the offsets reported by the GEMM op for this thread.
1618 int width_c = (headid + 1) * hdsz;
1619 int height_c = tru_seqlen + seqlen_entry;
1620 int pitch_c = hiddensize;
1621 int start_x_c = headid * hdsz
1622 + gemm_op_128x64_t::get_matC_offset_x(g_thd32_tid);
1623 int start_y_c = all_vert128_loop * 128 + seqlen_entry
1624 + gemm_op_128x64_t::get_matC_offset_y(g_thd32_tid);
1626 matC_128x64_payload.init(args->
matOut_ptr, width_c, height_c,
1627 pitch_c, start_x_c, start_y_c);
// Convert the accumulator tile into the output tile type.
1628 subgroup::elemwise_cvt<matC_128x64_t, matAcc_128x64_t>(
1629 matC_128x64, matAcc_128x64);
1637template <
typename dtype_bwd_bin_,
typename dtype_bwd_bot_,
1638 typename dtype_bwd_sfx_,
typename dtype_bwd_acc_,
int HWThreadNum,
1639 bool Dopt_RandGenflag =
true,
bool Mkin_flag =
false,
1640 int Max_SeqLen = 512>
1780 typename gemm_op_128x64_trnp_a_t::arguments_t;
1782 typename gemm_op_256x64_trnp_a_t::arguments_t;
1784 typename gemm_op_128x64_trnp_af_t::arguments_t;
1786 typename gemm_op_256x64_trnp_af_t::arguments_t;
1803 matAcc_128x128_t::tile_desc::tile_size_y,
1804 matAcc_128x128_t::tile_desc::block_size_x,
1805 matAcc_128x128_t::tile_desc::block_size_y,
1809 matAcc_128x256_t::tile_desc::tile_size_y,
1810 matAcc_128x256_t::tile_desc::block_size_x,
1811 matAcc_128x256_t::tile_desc::block_size_y,
1815 matAcc_64x384_t::tile_desc::tile_size_y,
1816 matAcc_64x384_t::tile_desc::block_size_x,
1817 matAcc_64x384_t::tile_desc::block_size_y,
1821 matAcc_64x512_t::tile_desc::tile_size_y,
1822 matAcc_64x512_t::tile_desc::block_size_x,
1823 matAcc_64x512_t::tile_desc::block_size_y,
1827 matAcc_32x1024_t::tile_desc::tile_size_y,
1828 matAcc_32x1024_t::tile_desc::block_size_x,
1829 matAcc_32x1024_t::tile_desc::block_size_y,
1833 matAcc_16x2048_t::tile_desc::tile_size_y,
1834 matAcc_16x2048_t::tile_desc::block_size_x,
1835 matAcc_16x2048_t::tile_desc::block_size_y,
1868 : subgroup::msg_type_v<
1875 : subgroup::msg_type_v<
1897 matAcc_128x64_t::tile_desc::tile_size_y,
1898 matAcc_128x64_t::tile_desc::block_size_x,
1899 matAcc_128x64_t::tile_desc::block_size_y,
1902 matAcc_128x64_trnp_a_t::tile_desc::tile_size_x,
1903 matAcc_128x64_trnp_a_t::tile_desc::tile_size_y,
1904 matAcc_128x64_trnp_a_t::tile_desc::block_size_x,
1907 matAcc_256x64_trnp_a_t::tile_desc::tile_size_x,
1908 matAcc_256x64_trnp_a_t::tile_desc::tile_size_y,
1909 matAcc_256x64_trnp_a_t::tile_desc::block_size_x,
1912 matAcc_128x64_trnp_af_t::tile_desc::tile_size_x,
1913 matAcc_128x64_trnp_af_t::tile_desc::tile_size_y,
1914 matAcc_128x64_trnp_af_t::tile_desc::block_size_x,
1915 matAcc_128x64_trnp_af_t::tile_desc::block_size_y,
1918 matAcc_256x64_trnp_af_t::tile_desc::tile_size_x,
1919 matAcc_256x64_trnp_af_t::tile_desc::tile_size_y,
1920 matAcc_256x64_trnp_af_t::tile_desc::block_size_x,
1921 matAcc_256x64_trnp_af_t::tile_desc::block_size_y,
1937 : subgroup::msg_type_v<
1951 subgroup::msg_type_v<matC_256x64_trnp_a_tile_desc_t, mem_space_c>,
1984 subgroup::msg_type_v<matC_128x128_tile_desc_t, mem_space_c>,
1989 subgroup::msg_type_v<matC_128x256_tile_desc_t, mem_space_c>,
1994 subgroup::msg_type_v<matC_64x384_tile_desc_t, mem_space_c>,
1999 subgroup::msg_type_v<matC_64x512_tile_desc_t, mem_space_c>,
2004 subgroup::msg_type_v<matC_32x1024_tile_desc_t, mem_space_c>,
2009 subgroup::msg_type_v<matC_16x2048_tile_desc_t, mem_space_c>,
2018 subgroup::msg_type_v<matElem_tile_desc, mem_space::global>>;
2051 int tru_seqlen_ex = 0;
2052 int seqlen_entry = 0;
2053 int hiddensize = 1024;
2056 int max_seqlen = Max_SeqLen;
2057 int wg_tile_QKT_k = hdsz;
2060 int groupid = item.get_group(0);
2061 int batchid = groupid / numhead;
2062 int headid = groupid % numhead;
2065 int tid_linear = item.get_local_linear_id();
2066 g_thd32_tid.init(tid_linear);
2070 uint32_t batch_offset =
sizeof(uint32_t) *
list_width * batchid;
2072 = xetla_vector_gen<uint32_t, list_width>(0, 1);
2073 list_offsets *=
sizeof(uint32_t);
2074 list_offsets += batch_offset;
2079 tru_seqlen = list_vec[0];
2080 seqlen_entry = list_vec[1];
2081 wg_tile_out_k = tru_seqlen;
2082 tru_seqlen_ex = tru_seqlen;
2084 tru_seqlen_ex = ((tru_seqlen + 1) >> 1) << 1;
2086 tru_seqlen_ex = ((tru_seqlen + 3) >> 2) << 2;
// Tiling parameters for the backward kernel, chosen from the sequence-length
// bucket.
//   all_vert_loop_num  : number of iterations of the main row loop
//   transp128_loop_num : iterations of the 128-wide transposed GEMM loop
//   transp256_loop_num : iterations of the 256-wide transposed GEMM loop
//   offset_blk_128x128 : offset added to the 128-wide block's coordinates
//   all_vert_stride    : rows advanced per main-loop iteration
//   all_vert128_shift  : log2 of how many main-loop iterations make 128 rows
//   block_16x16_num    : multiplier in the per-thread column span
//   tid_x_shift        : number of low bits of the thread id that form tid_x
2089 int all_vert_loop_num = 0;
2090 int transp128_loop_num = 0;
2091 int transp256_loop_num = 0;
2092 int offset_blk_128x128 = 0;
2093 int all_vert_stride = 0;
2094 int all_vert128_shift = 0;
2095 int block_16x16_num = 0;
2096 int tid_x_shift = 0;
// <= 128: one 128-wide transposed block, one 128-row main iteration.
2099 if (tru_seqlen <= 128) {
2101 all_vert_loop_num = 1;
2102 transp128_loop_num = 1;
// NOTE(review): all_vert_loop_num is assigned 1 a second time here in the
// visible lines; looks redundant - confirm against the full file.
2104 all_vert_loop_num = 1;
2105 all_vert_stride = 128;
2106 block_16x16_num = 2;
2107 }
// <= 256: one 256-wide transposed block, two 128-row main iterations.
else if (tru_seqlen <= 256) {
2109 transp256_loop_num = 1;
2110 all_vert_loop_num = 2;
2111 all_vert_stride = 128;
2112 all_vert128_shift = 0;
2113 block_16x16_num = 4;
2115 }
// <= 384: one 256-wide block plus one 128-wide block placed at offset 256;
// 64-row main iterations, count = ceil(tru_seqlen / 64).
else if (tru_seqlen <= 384) {
2117 transp128_loop_num = 1;
2118 transp256_loop_num = 1;
2119 offset_blk_128x128 = 256;
2120 all_vert_stride = 64;
2121 all_vert128_shift = 1;
2122 block_16x16_num = 3;
2124 all_vert_loop_num = (tru_seqlen + all_vert_stride - 1) >> 6;
2125 }
// <= 512: two 256-wide blocks; 64-row main iterations.
else if (tru_seqlen <= 512) {
2127 transp256_loop_num = 2;
2128 all_vert_stride = 64;
2129 all_vert128_shift = 1;
2130 block_16x16_num = 4;
2132 all_vert_loop_num = (tru_seqlen + all_vert_stride - 1) >> 6;
2133 }
// <= 1024: four 256-wide blocks; 32-row main iterations.
else if (tru_seqlen <= 1024) {
2135 transp256_loop_num = 4;
2136 all_vert_stride = 32;
2137 all_vert128_shift = 2;
2138 block_16x16_num = 4;
2140 all_vert_loop_num = (tru_seqlen + all_vert_stride - 1) >> 5;
2141 }
// <= 2048: eight 256-wide blocks; 16-row main iterations.
else if (tru_seqlen <= 2048) {
2143 transp256_loop_num = 8;
2144 all_vert_stride = 16;
2145 all_vert128_shift = 3;
2146 block_16x16_num = 4;
2148 all_vert_loop_num = (tru_seqlen + all_vert_stride - 1) >> 4;
// Round the main loop count up to a multiple of (1 << all_vert128_shift), i.e.
// to a whole number of 128-row groups.
2150 all_vert_loop_num = ((all_vert_loop_num + (1 << all_vert128_shift) - 1)
2151 >> all_vert128_shift)
2152 << all_vert128_shift;
// Split the linear thread id: the low tid_x_shift bits give tid_x, the
// remaining high bits give tid_y.
2153 int tid_x = tid_linear & ((1 << tid_x_shift) - 1);
2154 int tid_y = tid_linear >> tid_x_shift;
2156 static_assert(
ThreadNum == 32,
"All Thread Sync");
2170 for (
int transp128_loop = 0; transp128_loop < transp128_loop_num;
// Body of one transp128_loop iteration: a 128x64 GEMM with operand A read from
// args->matW_ptr and operand B from args->matdO_ptr; the result is addressed
// through a payload on args->matdV_ptr.
2177 uint32_t width_a = tru_seqlen_ex;
2179 = (batchid * numhead + headid) * max_seqlen + tru_seqlen;
2180 uint32_t pitch_a = max_seqlen;
// Column start of this 128-wide block; offset_blk_128x128 shifts it past any
// 256-wide blocks (it is 256 in the <= 384 bucket, 0 otherwise).
2181 int start_x_a = transp128_loop * 128 + offset_blk_128x128;
// Row start of this (batch, head) slab: max_seqlen rows per slab.
2182 int start_y_a = (batchid * numhead + headid) * max_seqlen;
// The A descriptor receives (height, width) and (start_y, start_x), i.e. with
// the two dimensions swapped relative to the B descriptor below.
// NOTE(review): presumably because A is consumed transposed (the GEMM op is
// named *_trnp_af) - confirm.
2184 gemm_arg_128x64.matA_base_desc.init({args->
matW_ptr},
2185 {height_a, width_a, pitch_a}, {start_y_a, start_x_a});
// B surface: this head's hdsz-wide column slice, rows starting at
// seqlen_entry, with a row pitch of hiddensize.
2187 uint32_t width_b = (headid + 1) * hdsz;
2188 uint32_t height_b = tru_seqlen + seqlen_entry;
2189 uint32_t pitch_b = hiddensize;
2190 int start_x_b = headid * hdsz;
2191 int start_y_b = seqlen_entry;
2193 gemm_arg_128x64.matB_base_desc.init({args->
matdO_ptr},
2194 {width_b, height_b, pitch_b}, {start_x_b, start_y_b});
2196 gemm_arg_128x64.inner_loop_count
// Zero the accumulator, then run the GEMM.
2199 matAcc_128x64.init(0);
2201 gemm_op_128x64_trnp_af(g_thd32_tid, matAcc_128x64, gemm_arg_128x64);
// Output surface: same head column slice; each thread's tile origin is the
// block base plus the offsets reported by the GEMM op for this thread.
2203 int width_c = (headid + 1) * hdsz;
2204 int height_c = tru_seqlen + seqlen_entry;
2205 int pitch_c = hiddensize;
2206 int start_x_c = headid * hdsz
2207 + gemm_op_128x64_trnp_af_t::get_matC_offset_x(g_thd32_tid);
2208 int start_y_c = transp128_loop * 128 + seqlen_entry
2209 + offset_blk_128x128
2210 + gemm_op_128x64_trnp_af_t::get_matC_offset_y(g_thd32_tid);
2212 matC_128x64_payload.init(args->
matdV_ptr, width_c, height_c,
2213 pitch_c, start_x_c, start_y_c);
2223 for (
int transp256_loop = 0; transp256_loop < transp256_loop_num;
2230 uint32_t width_a = tru_seqlen_ex;
2232 = (batchid * numhead + headid) * max_seqlen + tru_seqlen;
2233 uint32_t pitch_a = max_seqlen;
2234 int start_x_a = transp256_loop * 256;
2235 int start_y_a = (batchid * numhead + headid) * max_seqlen;
2237 gemm_arg_256x64.matA_base_desc.init({args->
matW_ptr},
2238 {height_a, width_a, pitch_a}, {start_y_a, start_x_a});
2240 uint32_t width_b = (headid + 1) * hdsz;
2241 uint32_t height_b = tru_seqlen + seqlen_entry;
2242 uint32_t pitch_b = hiddensize;
2243 int start_x_b = headid * hdsz;
2244 int start_y_b = seqlen_entry;
2246 gemm_arg_256x64.matB_base_desc.init({args->
matdO_ptr},
2247 {width_b, height_b, pitch_b}, {start_x_b, start_y_b});
2249 gemm_arg_256x64.inner_loop_count
2252 matAcc_256x64.init(0);
2254 gemm_op_256x64_trnp_af(g_thd32_tid, matAcc_256x64, gemm_arg_256x64);
2256 int width_c = (headid + 1) * hdsz;
2257 int height_c = tru_seqlen + seqlen_entry;
2258 int pitch_c = hiddensize;
2259 int start_x_c = headid * hdsz
2260 + gemm_op_256x64_trnp_af_t::get_matC_offset_x(g_thd32_tid);
2261 int start_y_c = transp256_loop * 256 + seqlen_entry
2262 + gemm_op_256x64_trnp_af_t::get_matC_offset_y(g_thd32_tid);
2264 matC_256x64_payload.init(args->
matdV_ptr, width_c, height_c,
2265 pitch_c, start_x_c, start_y_c);
2275 int valid_block_16x16_x = (tid_x + 1) * 16 * block_16x16_num;
2277 int bndy_block_num = 0;
2278 if (valid_block_16x16_x <= tru_seqlen)
2279 valid_block_16x16_x = block_16x16_num;
2281 bndy_block_num = valid_block_16x16_x;
2282 valid_block_16x16_x = (tru_seqlen + 15 + 16 * block_16x16_num
2283 - valid_block_16x16_x)
2285 bndy_block_num = bndy_block_num
2286 + (valid_block_16x16_x - block_16x16_num) * 16
2291 for (
int all_vert_loop = 0; all_vert_loop < all_vert_loop_num;
2296 bool valid_compute =
true;
2298 int ld_st_width_c = max_seqlen;
2299 int ld_st_height_c = max_seqlen * (batchid * numhead + headid + 1);
2300 int ld_st_pitch_c = max_seqlen;
2301 int ld_st_start_x_c = 0;
2302 int ld_st_start_y_c = 0;
2304 if (((all_vert_loop * all_vert_stride + tid_y * 16) >= tru_seqlen)
2305 || ((tid_x * 16 * block_16x16_num) >= tru_seqlen))
2306 valid_compute =
false;
2308 if (valid_compute) {
2310 switch (std_seqlen) {
2318 ld_st_start_x_c = gemm_op_128x128_t::get_matC_offset_x(
2321 = (batchid * numhead + headid) * max_seqlen
2322 + all_vert_loop * all_vert_stride
2323 + gemm_op_128x128_t::get_matC_offset_y(
2326 matW_128x128_payload.init(args->
matW_ptr, ld_st_width_c,
2327 ld_st_height_c, ld_st_pitch_c, ld_st_start_x_c,
2331 uint32_t width_a = (headid + 1) * hdsz;
2332 uint32_t height_a = tru_seqlen + seqlen_entry;
2333 uint32_t pitch_a = hiddensize;
2334 int start_x_a = headid * hdsz;
2335 int start_y_a = all_vert_loop * all_vert_stride
2338 gemm_arg_128x128.matA_base_desc.init({args->
matdO_ptr},
2339 {width_a, height_a, pitch_a},
2340 {start_x_a, start_y_a});
2342 uint32_t width_b = (headid + 1) * hdsz;
2343 uint32_t height_b = tru_seqlen + seqlen_entry;
2344 uint32_t pitch_b = hiddensize;
2345 int start_x_b = headid * hdsz;
2346 int start_y_b = seqlen_entry;
2349 gemm_arg_128x128.matB_base_desc.init({args->
matV_ptr},
2350 {height_b, width_b, pitch_b},
2351 {start_y_b, start_x_b});
2353 gemm_arg_128x128.inner_loop_count
2356 matAcc_128x128.init(0);
2359 g_thd32_tid, matAcc_128x128, gemm_arg_128x128);
2361 matElem_reg_4x16x16.xetla_format<
float>()
2362 .xetla_select<16 * 32, 1>(0)
2363 = matAcc_128x128.reg;
2366 Sign_reg_4x16x16.xetla_select<16 * 16 * 2, 1>(0)
2368 matW_128x128.
reg.xetla_format<
2371 matW_128x128.
reg.xetla_format<uint8_t>() &= 0x7F;
2374 Sign_reg_4x16x16.xetla_select<16 * 16 * 2, 1>(0)
2376 matW_128x128.
reg.xetla_format<
2379 matW_128x128.
reg.xetla_format<uint16_t>() &= 0x7FFF;
2384 = xetla_cvt<float, dtype_sfx>(matW_128x128.
reg);
2385 matW_reg_4x16x16.xetla_select<16 * 16 * 2, 1>(0)
2396 ld_st_start_x_c = gemm_op_128x256_t::get_matC_offset_x(
2399 = (batchid * numhead + headid) * max_seqlen
2400 + all_vert_loop * all_vert_stride
2401 + gemm_op_128x256_t::get_matC_offset_y(
2404 matW_128x256_payload.init(args->
matW_ptr, ld_st_width_c,
2405 ld_st_height_c, ld_st_pitch_c, ld_st_start_x_c,
2409 uint32_t width_a = (headid + 1) * hdsz;
2410 uint32_t height_a = tru_seqlen + seqlen_entry;
2411 uint32_t pitch_a = hiddensize;
2412 int start_x_a = headid * hdsz;
2413 int start_y_a = all_vert_loop * all_vert_stride
2416 gemm_arg_128x256.matA_base_desc.init({args->
matdO_ptr},
2417 {width_a, height_a, pitch_a},
2418 {start_x_a, start_y_a});
2420 uint32_t width_b = (headid + 1) * hdsz;
2421 uint32_t height_b = tru_seqlen + seqlen_entry;
2422 uint32_t pitch_b = hiddensize;
2423 int start_x_b = headid * hdsz;
2424 int start_y_b = seqlen_entry;
2427 gemm_arg_128x256.matB_base_desc.init({args->
matV_ptr},
2428 {height_b, width_b, pitch_b},
2429 {start_y_b, start_x_b});
2431 gemm_arg_128x256.inner_loop_count
2434 matAcc_128x256.init(0);
2437 g_thd32_tid, matAcc_128x256, gemm_arg_128x256);
2439 matElem_reg_4x16x16.xetla_format<
float>()
2440 .xetla_select<16 * 16 * 4, 1>(0)
2441 = matAcc_128x256.reg;
2444 Sign_reg_4x16x16.xetla_select<16 * 16 * 4, 1>(0)
2446 matW_128x256.
reg.xetla_format<
2449 matW_128x256.
reg.xetla_format<uint8_t>() &= 0x7F;
2452 Sign_reg_4x16x16.xetla_select<16 * 16 * 4, 1>(0)
2454 matW_128x256.
reg.xetla_format<
2457 matW_128x256.
reg.xetla_format<uint16_t>() &= 0x7FFF;
2460 matW_reg_4x16x16.xetla_select<16 * 16 * 4, 1>(0)
2461 = xetla_cvt<float, dtype_sfx>(matW_128x256.
reg);
2471 ld_st_start_x_c = gemm_op_64x384_t::get_matC_offset_x(
2474 = (batchid * numhead + headid) * max_seqlen
2475 + all_vert_loop * all_vert_stride
2476 + gemm_op_64x384_t::get_matC_offset_y(
2478 matW_64x384_payload.init(args->
matW_ptr, ld_st_width_c,
2479 ld_st_height_c, ld_st_pitch_c, ld_st_start_x_c,
2483 uint32_t width_a = (headid + 1) * hdsz;
2484 uint32_t height_a = tru_seqlen + seqlen_entry;
2485 uint32_t pitch_a = hiddensize;
2486 int start_x_a = headid * hdsz;
2487 int start_y_a = all_vert_loop * all_vert_stride
2490 gemm_arg_64x384.matA_base_desc.init({args->
matdO_ptr},
2491 {width_a, height_a, pitch_a},
2492 {start_x_a, start_y_a});
2494 uint32_t width_b = (headid + 1) * hdsz;
2495 uint32_t height_b = tru_seqlen + seqlen_entry;
2496 uint32_t pitch_b = hiddensize;
2497 int start_x_b = headid * hdsz;
2498 int start_y_b = seqlen_entry;
2501 gemm_arg_64x384.matB_base_desc.init({args->
matV_ptr},
2502 {height_b, width_b, pitch_b},
2503 {start_y_b, start_x_b});
2505 gemm_arg_64x384.inner_loop_count
2508 matAcc_64x384.init(0);
2511 g_thd32_tid, matAcc_64x384, gemm_arg_64x384);
2513 matElem_reg_4x16x16.xetla_format<
float>()
2514 .xetla_select<16 * 16 * 3, 1>(0)
2515 = matAcc_64x384.reg;
2518 Sign_reg_4x16x16.xetla_select<16 * 16 * 3, 1>(0)
2520 matW_64x384.
reg.xetla_format<
2523 matW_64x384.
reg.xetla_format<uint8_t>() &= 0x7F;
2526 Sign_reg_4x16x16.xetla_select<16 * 16 * 3, 1>(0)
2528 matW_64x384.
reg.xetla_format<
2531 matW_64x384.
reg.xetla_format<uint16_t>() &= 0x7FFF;
2536 = xetla_cvt<float, dtype_sfx>(matW_64x384.
reg);
2537 matW_reg_4x16x16.xetla_select<16 * 16 * 3, 1>(0)
2548 ld_st_start_x_c = gemm_op_64x512_t::get_matC_offset_x(
2551 = (batchid * numhead + headid) * max_seqlen
2552 + all_vert_loop * all_vert_stride
2553 + gemm_op_64x512_t::get_matC_offset_y(
2555 matW_64x512_payload.init(args->
matW_ptr, ld_st_width_c,
2556 ld_st_height_c, ld_st_pitch_c, ld_st_start_x_c,
2560 uint32_t width_a = (headid + 1) * hdsz;
2561 uint32_t height_a = tru_seqlen + seqlen_entry;
2562 uint32_t pitch_a = hiddensize;
2563 int start_x_a = headid * hdsz;
2564 int start_y_a = all_vert_loop * all_vert_stride
2567 gemm_arg_64x512.matA_base_desc.init({args->
matdO_ptr},
2568 {width_a, height_a, pitch_a},
2569 {start_x_a, start_y_a});
2571 uint32_t width_b = (headid + 1) * hdsz;
2572 uint32_t height_b = tru_seqlen + seqlen_entry;
2573 uint32_t pitch_b = hiddensize;
2574 int start_x_b = headid * hdsz;
2575 int start_y_b = seqlen_entry;
2578 gemm_arg_64x512.matB_base_desc.init({args->
matV_ptr},
2579 {height_b, width_b, pitch_b},
2580 {start_y_b, start_x_b});
2582 gemm_arg_64x512.inner_loop_count
2585 matAcc_64x512.init(0);
2588 g_thd32_tid, matAcc_64x512, gemm_arg_64x512);
2590 matElem_reg_4x16x16.xetla_format<
float>()
2591 .xetla_select<16 * 16 * 4, 1>(0)
2592 = matAcc_64x512.reg;
2595 Sign_reg_4x16x16.xetla_select<16 * 16 * 4, 1>(0)
2597 matW_64x512.
reg.xetla_format<
2600 matW_64x512.
reg.xetla_format<uint8_t>() &= 0x7F;
2603 Sign_reg_4x16x16.xetla_select<16 * 16 * 4, 1>(0)
2605 matW_64x512.
reg.xetla_format<
2608 matW_64x512.
reg.xetla_format<uint16_t>() &= 0x7FFF;
2611 matW_reg_4x16x16.xetla_select<16 * 16 * 4, 1>(0)
2612 = xetla_cvt<float, dtype_sfx>(matW_64x512.
reg);
2622 ld_st_start_x_c = gemm_op_32x1024_t::get_matC_offset_x(
2625 = (batchid * numhead + headid) * max_seqlen
2626 + all_vert_loop * all_vert_stride
2627 + gemm_op_32x1024_t::get_matC_offset_y(
2629 matW_32x1024_payload.init(args->
matW_ptr, ld_st_width_c,
2630 ld_st_height_c, ld_st_pitch_c, ld_st_start_x_c,
2634 uint32_t width_a = (headid + 1) * hdsz;
2635 uint32_t height_a = tru_seqlen + seqlen_entry;
2636 uint32_t pitch_a = hiddensize;
2637 int start_x_a = headid * hdsz;
2638 int start_y_a = all_vert_loop * all_vert_stride
2641 gemm_arg_32x1024.matA_base_desc.init({args->
matdO_ptr},
2642 {width_a, height_a, pitch_a},
2643 {start_x_a, start_y_a});
2645 uint32_t width_b = (headid + 1) * hdsz;
2646 uint32_t height_b = tru_seqlen + seqlen_entry;
2647 uint32_t pitch_b = hiddensize;
2648 int start_x_b = headid * hdsz;
2649 int start_y_b = seqlen_entry;
2652 gemm_arg_32x1024.matB_base_desc.init({args->
matV_ptr},
2653 {height_b, width_b, pitch_b},
2654 {start_y_b, start_x_b});
2656 gemm_arg_32x1024.inner_loop_count
2659 matAcc_32x1024.init(0);
2662 g_thd32_tid, matAcc_32x1024, gemm_arg_32x1024);
2664 matElem_reg_4x16x16.xetla_format<
float>()
2665 .xetla_select<16 * 16 * 4, 1>(0)
2666 = matAcc_32x1024.reg;
2669 Sign_reg_4x16x16.xetla_select<16 * 16 * 4, 1>(0)
2671 matW_32x1024.
reg.xetla_format<
2674 matW_32x1024.
reg.xetla_format<uint8_t>() &= 0x7F;
2677 Sign_reg_4x16x16.xetla_select<16 * 16 * 4, 1>(0)
2679 matW_32x1024.
reg.xetla_format<
2682 matW_32x1024.
reg.xetla_format<uint16_t>() &= 0x7FFF;
2685 matW_reg_4x16x16.xetla_select<16 * 16 * 4, 1>(0)
2686 = xetla_cvt<float, dtype_sfx>(matW_32x1024.
reg);
2696 ld_st_start_x_c = gemm_op_16x2048_t::get_matC_offset_x(
2699 = (batchid * numhead + headid) * max_seqlen
2700 + all_vert_loop * all_vert_stride
2701 + gemm_op_16x2048_t::get_matC_offset_y(
2703 matW_16x2048_payload.init(args->
matW_ptr, ld_st_width_c,
2704 ld_st_height_c, ld_st_pitch_c, ld_st_start_x_c,
2708 uint32_t width_a = (headid + 1) * hdsz;
2709 uint32_t height_a = tru_seqlen + seqlen_entry;
2710 uint32_t pitch_a = hiddensize;
2711 int start_x_a = headid * hdsz;
2712 int start_y_a = all_vert_loop * all_vert_stride
2715 gemm_arg_16x2048.matA_base_desc.init({args->
matdO_ptr},
2716 {width_a, height_a, pitch_a},
2717 {start_x_a, start_y_a});
2719 uint32_t width_b = (headid + 1) * hdsz;
2720 uint32_t height_b = tru_seqlen + seqlen_entry;
2721 uint32_t pitch_b = hiddensize;
2722 int start_x_b = headid * hdsz;
2723 int start_y_b = seqlen_entry;
2726 gemm_arg_16x2048.matB_base_desc.init({args->
matV_ptr},
2727 {height_b, width_b, pitch_b},
2728 {start_y_b, start_x_b});
2730 gemm_arg_16x2048.inner_loop_count
2733 matAcc_16x2048.init(0);
2736 g_thd32_tid, matAcc_16x2048, gemm_arg_16x2048);
2738 matElem_reg_4x16x16.xetla_format<
float>()
2739 .xetla_select<16 * 16 * 4, 1>(0)
2740 = matAcc_16x2048.reg;
2743 Sign_reg_4x16x16.xetla_select<16 * 16 * 4, 1>(0)
2745 matW_16x2048.
reg.xetla_format<
2748 matW_16x2048.
reg.xetla_format<uint8_t>() &= 0x7F;
2751 Sign_reg_4x16x16.xetla_select<16 * 16 * 4, 1>(0)
2753 matW_16x2048.
reg.xetla_format<
2756 matW_16x2048.
reg.xetla_format<uint16_t>() &= 0x7FFF;
2759 matW_reg_4x16x16.xetla_select<16 * 16 * 4, 1>(0)
2760 = xetla_cvt<float, dtype_sfx>(matW_16x2048.
reg);
2773 matElem_reg_4x16x16.xetla_format<
float>()
2774 .xetla_select<16 * 16 * 2, 1>(0)
2775 *= matW_reg_4x16x16.xetla_select<16 * 16 * 2, 1>(0);
2777 matElem_reg_4x16x16.xetla_format<
float>()
2778 .xetla_select<16 * 16 * 2, 1>(0)
2780 Sign_reg_4x16x16.xetla_select<16 * 16 * 2,
2784 matElem_reg_Sum = matElem_reg_4x16x16.xetla_format<
float>()
2785 .xetla_select<16 * 16, 1>(0)
2786 + matElem_reg_4x16x16.xetla_format<
float>()
2787 .xetla_select<16 * 16, 1>(16 * 16);
2789 if (valid_block_16x16_x > 2) {
2791 matElem_reg_4x16x16.xetla_format<
float>()
2792 .xetla_select<16 * 16, 1>(16 * 16 * 2)
2793 *= matW_reg_4x16x16.xetla_select<16 * 16, 1>(
2796 matElem_reg_4x16x16.xetla_format<
float>()
2797 .xetla_select<16 * 16, 1>(16 * 16 * 2)
2799 Sign_reg_4x16x16.xetla_select<16 * 16,
2803 matElem_reg_Sum = matElem_reg_Sum
2804 + matElem_reg_4x16x16.xetla_format<
float>()
2805 .xetla_select<16 * 16, 1>(
2808 if (valid_block_16x16_x > 3) {
2809 matElem_reg_4x16x16.xetla_format<
float>()
2810 .xetla_select<16 * 16, 1>(16 * 16 * 3)
2812 .xetla_select<16 * 16, 1>(
2815 matElem_reg_4x16x16.xetla_format<
float>()
2816 .xetla_select<16 * 16, 1>(16 * 16 * 3)
2818 Sign_reg_4x16x16.xetla_select<
2819 16 * 16, 1>(16 * 16 * 3)
2822 matElem_reg_Sum = matElem_reg_Sum
2823 + matElem_reg_4x16x16.xetla_format<
float>()
2824 .xetla_select<16 * 16, 1>(
2829 matElem_reg_Sum_8.xetla_format<float, 16, 8>()
2830 = matElem_reg_Sum.xetla_format<
float, 16, 16>()
2831 .xetla_select<16, 1, 8, 1>(0, 0)
2832 + matElem_reg_Sum.xetla_format<
float, 16, 16>()
2833 .xetla_select<16, 1, 8, 1>(0, 8);
2835 matElem_reg_Sum_4.xetla_format<float, 16, 4>()
2836 = matElem_reg_Sum_8.xetla_format<
float, 16, 8>()
2837 .xetla_select<16, 1, 4, 1>(0, 0)
2838 + matElem_reg_Sum_8.xetla_format<
float, 16, 8>()
2839 .xetla_select<16, 1, 4, 1>(0, 4);
2841 matElem_reg_Sum_2.xetla_format<float, 16, 2>()
2842 = matElem_reg_Sum_4.xetla_format<
float, 16, 4>()
2843 .xetla_select<16, 1, 2, 1>(0, 0)
2844 + matElem_reg_Sum_4.xetla_format<
float, 16, 4>()
2845 .xetla_select<16, 1, 2, 1>(0, 2);
2847 matElem_reg_Sum_1.xetla_format<float, 16, 1>()
2848 = matElem_reg_Sum_2.xetla_format<
float, 16, 2>()
2849 .xetla_select<16, 1, 1, 1>(0, 0)
2850 + matElem_reg_Sum_2.xetla_format<
float, 16, 2>()
2851 .xetla_select<16, 1, 1, 1>(0, 1);
2854 = xetla_vector_gen<uint32_t, 16>(0, 1);
2856 = (batchid * numhead + headid) * Max_SeqLen
2857 + all_vert_stride * all_vert_loop + tid_y * 16;
2858 address_fsum += address_offset;
2859 address_fsum *=
sizeof(float);
2865 matElem_reg_Sum_1.xetla_select<16, 1>(0), pred);
2875 matElem_reg_Sum_1 *= args->
Scaling;
2878 for (
int i = 0; i < 16; i++) {
2879 matW_reg_4x16x16.xetla_select<16 * 16, 1>(16 * 16 * 0)
2880 .xetla_select<16, 1>(i * 16)
2882 .xetla_select<16 * 16, 1>(16 * 16 * 0)
2883 .xetla_select<16, 1>(i * 16)
2884 * matElem_reg_Sum_1[i];
2887 .xetla_select<16 * 16, 1>(16 * 16 * 0)
2888 .xetla_select<16, 1>(i * 16)
2889 = matElem_reg_4x16x16
2890 .xetla_select<16 * 16, 1>(16 * 16 * 0)
2891 .xetla_select<16, 1>(i * 16)
2893 .xetla_select<16 * 16, 1>(16 * 16 * 0)
2894 .xetla_select<16, 1>(i * 16);
2897 .xetla_select<16 * 16, 1>(16 * 16 * 0)
2898 .xetla_select<16, 1>(i * 16)
2902 if (valid_block_16x16_x > 1) {
2904 for (
int i = 0; i < 16; i++) {
2906 .xetla_select<16 * 16, 1>(16 * 16 * 1)
2907 .xetla_select<16, 1>(i * 16)
2909 .xetla_select<16 * 16, 1>(
2911 .xetla_select<16, 1>(i * 16)
2912 * matElem_reg_Sum_1[i];
2915 .xetla_select<16 * 16, 1>(16 * 16 * 1)
2916 .xetla_select<16, 1>(i * 16)
2917 = matElem_reg_4x16x16
2918 .xetla_select<16 * 16, 1>(
2920 .xetla_select<16, 1>(i * 16)
2922 .xetla_select<16 * 16, 1>(
2924 .xetla_select<16, 1>(i * 16);
2927 .xetla_select<16 * 16, 1>(16 * 16 * 1)
2928 .xetla_select<16, 1>(i * 16)
2932 if (valid_block_16x16_x > 2) {
2934 for (
int i = 0; i < 16; i++) {
2936 .xetla_select<16 * 16, 1>(16 * 16 * 2)
2937 .xetla_select<16, 1>(i * 16)
2939 .xetla_select<16 * 16, 1>(
2941 .xetla_select<16, 1>(i * 16)
2942 * matElem_reg_Sum_1[i];
2945 .xetla_select<16 * 16, 1>(16 * 16 * 2)
2946 .xetla_select<16, 1>(i * 16)
2947 = matElem_reg_4x16x16
2948 .xetla_select<16 * 16, 1>(
2950 .xetla_select<16, 1>(i * 16)
2952 .xetla_select<16 * 16, 1>(
2954 .xetla_select<16, 1>(i * 16);
2957 .xetla_select<16 * 16, 1>(16 * 16 * 2)
2958 .xetla_select<16, 1>(i * 16)
2962 if (valid_block_16x16_x > 3) {
2964 for (
int i = 0; i < 16; i++) {
2966 .xetla_select<16 * 16, 1>(
2968 .xetla_select<16, 1>(i * 16)
2970 .xetla_select<16 * 16, 1>(
2972 .xetla_select<16, 1>(
2974 * matElem_reg_Sum_1[i];
2977 .xetla_select<16 * 16, 1>(
2979 .xetla_select<16, 1>(i * 16)
2980 = matElem_reg_4x16x16
2981 .xetla_select<16 * 16, 1>(
2983 .xetla_select<16, 1>(
2986 .xetla_select<16 * 16, 1>(
2988 .xetla_select<16, 1>(
2992 .xetla_select<16 * 16, 1>(
2994 .xetla_select<16, 1>(i * 16)
3003 switch (std_seqlen) {
3008 matC_128x128_payload.init(args->
matdW_ptr,
3009 ld_st_width_c, ld_st_height_c, ld_st_pitch_c,
3010 ld_st_start_x_c, ld_st_start_y_c);
3013 = matElem_reg_4x16x16.xetla_format<
float>()
3014 .xetla_select<16 * 32, 1>(0);
3015 matC_128x128.
reg = xetla_cvt<dtype_sfx, float>(
3019 matC_128x128, matC_128x128_payload);
3020 xetla_fence<memory_kind::untyped_global>();
3027 matC_128x256_payload.init(args->
matdW_ptr,
3028 ld_st_width_c, ld_st_height_c, ld_st_pitch_c,
3029 ld_st_start_x_c, ld_st_start_y_c);
3031 matC_128x256.
reg = xetla_cvt<dtype_sfx, float>(
3032 matElem_reg_4x16x16);
3035 matC_128x256, matC_128x256_payload);
3036 xetla_fence<memory_kind::untyped_global>();
3043 matC_64x384_payload.init(args->
matdW_ptr, ld_st_width_c,
3044 ld_st_height_c, ld_st_pitch_c, ld_st_start_x_c,
3048 = matElem_reg_4x16x16.xetla_format<
float>()
3049 .xetla_select<16 * 16 * 3, 1>(0);
3050 matC_64x384.
reg = xetla_cvt<dtype_sfx, float>(
3054 xetla_fence<memory_kind::untyped_global>();
3061 matC_64x512_payload.init(args->
matdW_ptr, ld_st_width_c,
3062 ld_st_height_c, ld_st_pitch_c, ld_st_start_x_c,
3065 matC_64x512.
reg = xetla_cvt<dtype_sfx, float>(
3066 matElem_reg_4x16x16);
3069 xetla_fence<memory_kind::untyped_global>();
3076 matC_32x1024_payload.init(args->
matdW_ptr,
3077 ld_st_width_c, ld_st_height_c, ld_st_pitch_c,
3078 ld_st_start_x_c, ld_st_start_y_c);
3080 matC_32x1024.
reg = xetla_cvt<dtype_sfx, float>(
3081 matElem_reg_4x16x16);
3084 matC_32x1024, matC_32x1024_payload);
3085 xetla_fence<memory_kind::untyped_global>();
3092 matC_16x2048_payload.init(args->
matdW_ptr,
3093 ld_st_width_c, ld_st_height_c, ld_st_pitch_c,
3094 ld_st_start_x_c, ld_st_start_y_c);
3096 matC_16x2048.
reg = xetla_cvt<dtype_sfx, float>(
3097 matElem_reg_4x16x16);
3100 matC_16x2048, matC_16x2048_payload);
3101 xetla_fence<memory_kind::untyped_global>();
3112 second_nbarr.
wait();
3114 int all_vert128_loop = all_vert_loop >> all_vert128_shift;
3115 if (((((all_vert128_loop + 1) << all_vert128_shift) - 1)
3117 || (all_vert128_shift == 0)) {
3123 uint32_t width_a = tru_seqlen_ex;
3124 uint32_t height_a = (batchid * numhead + headid) * max_seqlen
3126 uint32_t pitch_a = max_seqlen;
3128 int start_y_a = (batchid * numhead + headid) * max_seqlen
3129 + all_vert128_loop * 128;
3131 gemm_arg_128x64.matA_base_desc.init({args->
matdW_ptr},
3132 {width_a, height_a, pitch_a}, {start_x_a, start_y_a});
3134 uint32_t width_b = (headid + 1) * hdsz;
3135 uint32_t height_b = tru_seqlen + seqlen_entry;
3136 uint32_t pitch_b = hiddensize;
3137 int start_x_b = headid * hdsz;
3138 int start_y_b = seqlen_entry;
3140 gemm_arg_128x64.matB_base_desc.init({args->
matK_ptr},
3141 {width_b, height_b, pitch_b}, {start_x_b, start_y_b});
3143 gemm_arg_128x64.inner_loop_count
3146 matAcc_128x64.init(0);
3148 gemm_op_128x64(g_thd32_tid, matAcc_128x64, gemm_arg_128x64);
3150 int ld_st_width_c = (headid + 1) * hdsz;
3151 int height_c = tru_seqlen + seqlen_entry;
3152 int pitch_c = hiddensize;
3153 int start_x_c = headid * hdsz
3154 + gemm_op_128x64_t::get_matC_offset_x(g_thd32_tid);
3155 int start_y_c = all_vert128_loop * 128 + seqlen_entry
3156 + gemm_op_128x64_t::get_matC_offset_y(g_thd32_tid);
3158 matC_128x64_payload.init(args->
matdQ_ptr, ld_st_width_c,
3159 height_c, pitch_c, start_x_c, start_y_c);
3160 subgroup::elemwise_cvt<matC_128x64_t, matAcc_128x64_t>(
3161 matC_128x64, matAcc_128x64);
3166 for (
int transp256_loop = 0; transp256_loop < transp256_loop_num;
3173 uint32_t width_a = tru_seqlen_ex;
3175 = (batchid * numhead + headid) * max_seqlen + tru_seqlen;
3176 uint32_t pitch_a = max_seqlen;
3177 int start_x_a = transp256_loop * 256;
3178 int start_y_a = (batchid * numhead + headid) * max_seqlen;
3180 gemm_arg_256x64.matA_base_desc.init({args->
matdW_ptr},
3181 {height_a, width_a, pitch_a}, {start_y_a, start_x_a});
3183 uint32_t width_b = (headid + 1) * hdsz;
3184 uint32_t height_b = tru_seqlen + seqlen_entry;
3185 uint32_t pitch_b = hiddensize;
3186 int start_x_b = headid * hdsz;
3187 int start_y_b = seqlen_entry;
3189 gemm_arg_256x64.matB_base_desc.init({args->
matQ_ptr},
3190 {width_b, height_b, pitch_b}, {start_x_b, start_y_b});
3192 gemm_arg_256x64.inner_loop_count
3195 matAcc_256x64.init(0);
3197 gemm_op_256x64_trnp_a(g_thd32_tid, matAcc_256x64, gemm_arg_256x64);
3199 int width_c = (headid + 1) * hdsz;
3200 int height_c = tru_seqlen + seqlen_entry;
3201 int pitch_c = hiddensize;
3202 int start_x_c = headid * hdsz
3203 + gemm_op_256x64_trnp_a_t::get_matC_offset_x(g_thd32_tid);
3204 int start_y_c = transp256_loop * 256 + seqlen_entry
3205 + gemm_op_256x64_trnp_a_t::get_matC_offset_y(g_thd32_tid);
3207 matC_256x64_payload.init(args->
matdK_ptr, width_c, height_c,
3208 pitch_c, start_x_c, start_y_c);
3217 for (
int transp128_loop = 0; transp128_loop < transp128_loop_num;
3224 uint32_t width_a = tru_seqlen_ex;
3226 = (batchid * numhead + headid) * max_seqlen + tru_seqlen;
3227 uint32_t pitch_a = max_seqlen;
3228 int start_x_a = transp128_loop * 128 + offset_blk_128x128;
3229 int start_y_a = (batchid * numhead + headid) * max_seqlen;
3231 gemm_arg_128x64.matA_base_desc.init({args->
matdW_ptr},
3232 {height_a, width_a, pitch_a}, {start_y_a, start_x_a});
3234 uint32_t width_b = (headid + 1) * hdsz;
3235 uint32_t height_b = tru_seqlen + seqlen_entry;
3236 uint32_t pitch_b = hiddensize;
3237 int start_x_b = headid * hdsz;
3238 int start_y_b = seqlen_entry;
3240 gemm_arg_128x64.matB_base_desc.init({args->
matQ_ptr},
3241 {width_b, height_b, pitch_b}, {start_x_b, start_y_b});
3243 gemm_arg_128x64.inner_loop_count
3246 matAcc_128x64.init(0);
3248 gemm_op_128x64_trnp_a(g_thd32_tid, matAcc_128x64, gemm_arg_128x64);
3250 int width_c = (headid + 1) * hdsz;
3251 int height_c = tru_seqlen + seqlen_entry;
3252 int pitch_c = hiddensize;
3253 int start_x_c = headid * hdsz
3254 + gemm_op_128x64_trnp_a_t::get_matC_offset_x(g_thd32_tid);
3255 int start_y_c = transp128_loop * 128 + seqlen_entry
3256 + offset_blk_128x128
3257 + gemm_op_128x64_trnp_a_t::get_matC_offset_y(g_thd32_tid);
3259 matC_128x64_payload.init(args->
matdK_ptr, width_c, height_c,
3260 pitch_c, start_x_c, start_y_c);
Gemm functor.
Definition api.hpp:52
#define __XETLA_API
Definition common.hpp:43
#define xetla_select
xetla select.
Definition base_ops.hpp:49
#define xetla_format
xetla format.
Definition base_ops.hpp:38
__ESIMD_NS::simd< native_type_t< Ty >, N > xetla_vector
wrapper for xetla_vector.
Definition base_types.hpp:149
__ESIMD_NS::simd_mask< N > xetla_mask
wrapper for xetla_mask.
Definition base_types.hpp:165
__XETLA_API xetla_vector< Ty, N *NElts > xetla_load_global(Ty *p, xetla_vector< Toffset, N > offsets, xetla_mask< N > pred=1)
Stateless scattered load.
Definition memory.hpp:245
__XETLA_API xetla_vector< uint32_t, 4 > get_time_stamp()
Returns time stamp.
Definition misc.hpp:57
__XETLA_API std::enable_if_t< arch_tag==gpu_arch::Xe, void > xetla_tatomic_store_global(uint64_t base_address, xetla_vector< Toffset, N > offset, xetla_vector< Ty, N > data, xetla_mask< N > pred=1)
Tensor atomic store API.
Definition raw_send_load_store.hpp:294
#define rand_threshold_const
Definition mha_attn_reg.hpp:26
#define list_width
Definition mha_attn_reg.hpp:25
#define SIGN_BIT_W16
Definition mha_attn_reg.hpp:28
#define SIGN_BIT_B8
Definition mha_attn_reg.hpp:29
Definition limitation.hpp:734
__XETLA_API std::enable_if_t<(T_src::register_layout !=reg_layout::linear) &&(T_dst::register_layout !=reg_layout::linear) &&is_same_layout< T_dst, T_src >::value &&(!is_floating_to_integer< T_dst, T_src >::value)> elemwise_cvt(T_dst &dst, T_src &src)
Element-wise data conversion; the src and dst tiles must have the same layout.
Definition op_function.hpp:40
__XETLA_API std::enable_if_t< detail::check_store_type< tile_t, payload_t >::is_global_2d_xe > tile_store(tile_t &tile, payload_t &payload)
Stores data from the register file to global memory.
Definition store_xe.hpp:91
__XETLA_API std::enable_if_t< detail::check_load_type< tile_t, payload_t >::is_global_2d_xe > tile_load(tile_t &tile, payload_t &payload)
This function loads data from 2D memory surface.
Definition load_xe.hpp:76
mem_space
Definition common.hpp:77
@ fmax
Atomically stores the float max of src1 and the memory data, and returns the old value.
@ fadd
Atomic float add of src1 and the memory data; returns the old value.
gpu_arch
Definition common.hpp:73
mem_layout
Definition common.hpp:76
Compute attribute for gemm.
Definition common.hpp:32
Compute policy for xmx engine.
Definition compute_policy.hpp:35
Fine-tune knobs for gemm.
Definition common.hpp:43
Gemm default pre_processing functor.
Definition api.hpp:33
Gemm pre_processing functor that applies the relu op to matA.
Definition api.hpp:39
Workgroup level tile shape description.
Definition tile_shape.hpp:34
Arguments for xetla_softmax_bwd_t::run.
Definition mha_attn_reg.hpp:2026
dtype_sfx * matW_ptr
Definition mha_attn_reg.hpp:2035
float * matSum_ptr
Definition mha_attn_reg.hpp:2040
dtype_bin * matV_ptr
Definition mha_attn_reg.hpp:2031
uint32_t * matMkdpot_ptr
Definition mha_attn_reg.hpp:2033
dtype_sfx * matdW_ptr
Definition mha_attn_reg.hpp:2036
float Pinv
Definition mha_attn_reg.hpp:2041
dtype_bin * matdO_ptr
Definition mha_attn_reg.hpp:2034
dtype_bin * matQ_ptr
Definition mha_attn_reg.hpp:2029
dtype_bot * matdK_ptr
Definition mha_attn_reg.hpp:2039
dtype_bin * matK_ptr
Definition mha_attn_reg.hpp:2030
uint32_t * matMkin_ptr
Definition mha_attn_reg.hpp:2032
uint32_t * mList_ptr
Definition mha_attn_reg.hpp:2028
dtype_bot * matdQ_ptr
Definition mha_attn_reg.hpp:2038
dtype_bot * matdV_ptr
Definition mha_attn_reg.hpp:2037
float Scaling
Definition mha_attn_reg.hpp:2042
Definition mha_attn_reg.hpp:1641
static constexpr uint32_t prefetch_distance
Definition mha_attn_reg.hpp:1669
dtype_bwd_acc_ dtype_acc
Definition mha_attn_reg.hpp:1645
subgroup::tile_desc_t< matAcc_32x1024_t::tile_desc::tile_size_x, matAcc_32x1024_t::tile_desc::tile_size_y, matAcc_32x1024_t::tile_desc::block_size_x, matAcc_32x1024_t::tile_desc::block_size_y, reg_layout::tiled > matC_32x1024_tile_desc_t
Definition mha_attn_reg.hpp:1830
group::compute_policy_default_xmx< group::compute_attr_t< dtype_sfx, dtype_bin, dtype_acc >, bgm_perf_tuning_knob, gpu_arch::Xe > compute_policy_out_b_trnp_a
Definition mha_attn_reg.hpp:1707
static constexpr mem_space mem_space_a
Definition mha_attn_reg.hpp:1649
subgroup::tile_desc_t< matAcc_64x384_t::tile_desc::tile_size_x, matAcc_64x384_t::tile_desc::tile_size_y, matAcc_64x384_t::tile_desc::block_size_x, matAcc_64x384_t::tile_desc::block_size_y, reg_layout::tiled > matC_64x384_tile_desc_t
Definition mha_attn_reg.hpp:1818
typename gemm_op_128x64_trnp_a_t::matAcc_t matAcc_128x64_trnp_a_t
Definition mha_attn_reg.hpp:1796
group::tile_shape_t< 64, 128, 16, 16 > tile_attr_128x64
Definition mha_attn_reg.hpp:1683
typename gemm_op_64x512_t::matAcc_t matAcc_64x512_t
Definition mha_attn_reg.hpp:1791
typename gemm_op_16x2048_t::arguments_t gemm_arguments_16x2048
Definition mha_attn_reg.hpp:1776
static constexpr mem_layout gemm_mem_layout_QKT_b
Definition mha_attn_reg.hpp:1665
group::perf_tuning_knob_t< k_stride, prefetch_distance, periodic_sync_interval > bgm_perf_tuning_knob
Definition mha_attn_reg.hpp:1674
group::tile_shape_t< 256, 128, 64, 16 > tile_attr_128x256
Definition mha_attn_reg.hpp:1677
subgroup::tile_desc_t< matAcc_128x64_t::tile_desc::tile_size_x, matAcc_128x64_t::tile_desc::tile_size_y, matAcc_128x64_t::tile_desc::block_size_x, matAcc_128x64_t::tile_desc::block_size_y, reg_layout::tiled > matC_128x64_tile_desc_t
Definition mha_attn_reg.hpp:1900
typename gemm_op_128x64_trnp_a_t::arguments_t gemm_arguments_128x64_trnp_a
Definition mha_attn_reg.hpp:1780
dtype_bwd_sfx_ dtype_sfx
Definition mha_attn_reg.hpp:1644
typename gemm_op_256x64_trnp_af_t::matAcc_t matAcc_256x64_trnp_af_t
Definition mha_attn_reg.hpp:1799
typename gemm_op_256x64_trnp_af_t::arguments_t gemm_arguments_256x64_trnp_af
Definition mha_attn_reg.hpp:1786
group::compute_policy_default_xmx< group::compute_attr_t< dtype_sfx, dtype_bin, dtype_acc >, bgm_perf_tuning_knob, gpu_arch::Xe > compute_policy_out
Definition mha_attn_reg.hpp:1699
subgroup::tile_t< dtype_bot, matC_128x64_trnp_a_tile_desc_t > matC_128x64_trnp_a_t
Definition mha_attn_reg.hpp:1925
work_group_t< ThreadNum > work_group_t
Definition mha_attn_reg.hpp:1714
subgroup::tile_desc_t< matAcc_128x128_t::tile_desc::tile_size_x, matAcc_128x128_t::tile_desc::tile_size_y, matAcc_128x128_t::tile_desc::block_size_x, matAcc_128x128_t::tile_desc::block_size_y, reg_layout::tiled > matC_128x128_tile_desc_t
Definition mha_attn_reg.hpp:1806
static constexpr mem_layout gemm_mem_layout_a
Definition mha_attn_reg.hpp:1661
dtype_bwd_bot_ dtype_bot
Definition mha_attn_reg.hpp:1643
group::tile_shape_t< 64, 256, 16, 32 > tile_attr_256x64
Definition mha_attn_reg.hpp:1682
typename gemm_op_256x64_trnp_a_t::arguments_t gemm_arguments_256x64_trnp_a
Definition mha_attn_reg.hpp:1782
typename gemm_op_128x64_trnp_af_t::matAcc_t matAcc_128x64_trnp_af_t
Definition mha_attn_reg.hpp:1798
group::tile_shape_t< 512, 64, 64, 16 > tile_attr_64x512
Definition mha_attn_reg.hpp:1679
static constexpr mem_space gemm_mem_space_trnp_a
Definition mha_attn_reg.hpp:1660
static constexpr uint32_t global_kslicing
Definition mha_attn_reg.hpp:1709
subgroup::tile_desc_t< matAcc_128x64_trnp_af_t::tile_desc::tile_size_x, matAcc_128x64_trnp_af_t::tile_desc::tile_size_y, matAcc_128x64_trnp_af_t::tile_desc::block_size_x, matAcc_128x64_trnp_af_t::tile_desc::block_size_y, reg_layout::tiled > matC_128x64_trnp_af_tile_desc_t
Definition mha_attn_reg.hpp:1916
static constexpr mem_layout mem_layout_out_b
Definition mha_attn_reg.hpp:1656
typename gemm_op_128x256_t::matAcc_t matAcc_128x256_t
Definition mha_attn_reg.hpp:1789
subgroup::tile_desc_t< matAcc_128x64_trnp_a_t::tile_desc::tile_size_x, matAcc_128x64_trnp_a_t::tile_desc::tile_size_y, matAcc_128x64_trnp_a_t::tile_desc::block_size_x, matAcc_128x64_trnp_a_t::tile_desc::block_size_y, reg_layout::tiled > matC_128x64_trnp_a_tile_desc_t
Definition mha_attn_reg.hpp:1905
static constexpr mem_layout mem_layout_trnp_a
Definition mha_attn_reg.hpp:1654
subgroup::tile_t< dtype_bot, matC_256x64_trnp_a_tile_desc_t > matC_256x64_trnp_a_t
Definition mha_attn_reg.hpp:1927
typename gemm_op_32x1024_t::matAcc_t matAcc_32x1024_t
Definition mha_attn_reg.hpp:1792
static constexpr mem_layout mem_layout_QKT_b
Definition mha_attn_reg.hpp:1655
typename gemm_op_128x128_t::arguments_t gemm_arguments_128x128
Definition mha_attn_reg.hpp:1771
mem_desc_t< dtype_bin, gemm_mem_layout_QKT_b, gemm_mem_space_b > mem_desc_b_QKT
Definition mha_attn_reg.hpp:1688
group::tile_shape_t< 1024, 32, 64, 16 > tile_attr_32x1024
Definition mha_attn_reg.hpp:1680
subgroup::tile_desc_t< matAcc_64x512_t::tile_desc::tile_size_x, matAcc_64x512_t::tile_desc::tile_size_y, matAcc_64x512_t::tile_desc::block_size_x, matAcc_64x512_t::tile_desc::block_size_y, reg_layout::tiled > matC_64x512_tile_desc_t
Definition mha_attn_reg.hpp:1824
subgroup::tile_t< dtype_bot, matC_256x64_trnp_af_tile_desc_t > matC_256x64_trnp_af_t
Definition mha_attn_reg.hpp:1931
static constexpr uint32_t k_stride
Definition mha_attn_reg.hpp:1672
typename gemm_op_64x384_t::matAcc_t matAcc_64x384_t
Definition mha_attn_reg.hpp:1790
static constexpr mem_layout gemm_mem_layout_out_b
Definition mha_attn_reg.hpp:1666
subgroup::tile_t< dtype_bot, matC_128x64_trnp_af_tile_desc_t > matC_128x64_trnp_af_t
Definition mha_attn_reg.hpp:1929
typename gemm_op_128x128_t::matAcc_t matAcc_128x128_t
Definition mha_attn_reg.hpp:1788
group::tile_shape_t< 2048, 16, 64, 16 > tile_attr_16x2048
Definition mha_attn_reg.hpp:1681
mem_desc_t< dtype_sfx, gemm_mem_layout_trnp_a, gemm_mem_space_trnp_a > mem_desc_a_out_b_trnp_a
Definition mha_attn_reg.hpp:1702
typename gemm_op_32x1024_t::arguments_t gemm_arguments_32x1024
Definition mha_attn_reg.hpp:1775
typename gemm_op_64x384_t::arguments_t gemm_arguments_64x384
Definition mha_attn_reg.hpp:1773
mem_desc_t< dtype_sfx, gemm_mem_layout_a, gemm_mem_space_a > mem_desc_a_out
Definition mha_attn_reg.hpp:1694
subgroup::tile_desc_t< matAcc_128x256_t::tile_desc::tile_size_x, matAcc_128x256_t::tile_desc::tile_size_y, matAcc_128x256_t::tile_desc::block_size_x, matAcc_128x256_t::tile_desc::block_size_y, reg_layout::tiled > matC_128x256_tile_desc_t
Definition mha_attn_reg.hpp:1812
static constexpr mem_space mem_space_c
Definition mha_attn_reg.hpp:1651
typename gemm_op_128x64_t::matAcc_t matAcc_128x64_t
Definition mha_attn_reg.hpp:1795
group::compute_policy_default_xmx< group::compute_attr_t< dtype_bin, dtype_bin, dtype_acc >, bgm_perf_tuning_knob, gpu_arch::Xe > compute_policy_QKT
Definition mha_attn_reg.hpp:1691
mem_desc_t< dtype_bin, gemm_mem_layout_a, gemm_mem_space_a > mem_desc_a_QKT
Definition mha_attn_reg.hpp:1686
typename gemm_op_128x64_t::arguments_t gemm_arguments_128x64
Definition mha_attn_reg.hpp:1778
group::tile_shape_t< 384, 64, 48, 16 > tile_attr_64x384
Definition mha_attn_reg.hpp:1678
typename gemm_op_64x512_t::arguments_t gemm_arguments_64x512
Definition mha_attn_reg.hpp:1774
static constexpr mem_space gemm_mem_space_b
Definition mha_attn_reg.hpp:1664
static constexpr uint32_t periodic_sync_interval
Definition mha_attn_reg.hpp:1668
static constexpr mem_layout gemm_mem_layout_trnp_a
Definition mha_attn_reg.hpp:1662
static constexpr mem_space mem_space_b
Definition mha_attn_reg.hpp:1650
typename gemm_op_256x64_trnp_a_t::matAcc_t matAcc_256x64_trnp_a_t
Definition mha_attn_reg.hpp:1797
static __XETLA_API void call(sycl::nd_item< 3 > &item, arguments_t *args)
Main execution function for fused MHA softmax. The basic process is GEMM -> Softmax -> GEMM.
Definition mha_attn_reg.hpp:2048
static constexpr mem_space gemm_mem_space_a
Definition mha_attn_reg.hpp:1659
subgroup::tile_desc_t< matAcc_256x64_trnp_af_t::tile_desc::tile_size_x, matAcc_256x64_trnp_af_t::tile_desc::tile_size_y, matAcc_256x64_trnp_af_t::tile_desc::block_size_x, matAcc_256x64_trnp_af_t::tile_desc::block_size_y, reg_layout::tiled > matC_256x64_trnp_af_tile_desc_t
Definition mha_attn_reg.hpp:1922
static constexpr int ThreadNum
Definition mha_attn_reg.hpp:1647
subgroup::tile_desc_t< matAcc_16x2048_t::tile_desc::tile_size_x, matAcc_16x2048_t::tile_desc::tile_size_y, matAcc_16x2048_t::tile_desc::block_size_x, matAcc_16x2048_t::tile_desc::block_size_y, reg_layout::tiled > matC_16x2048_tile_desc_t
Definition mha_attn_reg.hpp:1836
static constexpr uint16_t sfx_type_size
Definition mha_attn_reg.hpp:1710
mem_desc_t< dtype_bin, gemm_mem_layout_out_b, gemm_mem_space_b > mem_desc_b_out
Definition mha_attn_reg.hpp:1696
dtype_bwd_bin_ dtype_bin
Definition mha_attn_reg.hpp:1642
typename gemm_op_128x64_trnp_af_t::arguments_t gemm_arguments_128x64_trnp_af
Definition mha_attn_reg.hpp:1784
typename gemm_op_128x256_t::arguments_t gemm_arguments_128x256
Definition mha_attn_reg.hpp:1772
mem_desc_t< dtype_bin, gemm_mem_layout_out_b, gemm_mem_space_b > mem_desc_b_out_b_trnp_a
Definition mha_attn_reg.hpp:1704
static constexpr mem_layout mem_layout_c
Definition mha_attn_reg.hpp:1657
group::tile_shape_t< 128, 128, 32, 16 > tile_attr_128x128
Definition mha_attn_reg.hpp:1676
subgroup::tile_desc_t< matAcc_256x64_trnp_a_t::tile_desc::tile_size_x, matAcc_256x64_trnp_a_t::tile_desc::tile_size_y, matAcc_256x64_trnp_a_t::tile_desc::block_size_x, matAcc_256x64_trnp_a_t::tile_desc::block_size_y, reg_layout::tiled > matC_256x64_trnp_a_tile_desc_t
Definition mha_attn_reg.hpp:1910
typename gemm_op_16x2048_t::matAcc_t matAcc_16x2048_t
Definition mha_attn_reg.hpp:1793
static constexpr mem_layout mem_layout_a
Definition mha_attn_reg.hpp:1653
Arguments for xetla_softmax_fwd_t::run.
Definition mha_attn_reg.hpp:301
dtype_bot * matOut_ptr
Definition mha_attn_reg.hpp:310
uint32_t * matMkdpot_ptr
Definition mha_attn_reg.hpp:308
dtype_sfx * matQKT_ptr
Definition mha_attn_reg.hpp:309
float * Max_ptr
Definition mha_attn_reg.hpp:311
dtype_bin * matV_ptr
Definition mha_attn_reg.hpp:306
float Scaling
Definition mha_attn_reg.hpp:314
float * Sum_ptr
Definition mha_attn_reg.hpp:312
float Pinv
Definition mha_attn_reg.hpp:313
uint32_t * matMkin_ptr
Definition mha_attn_reg.hpp:307
uint32_t * mList_ptr
Definition mha_attn_reg.hpp:303
dtype_bin * matQ_ptr
Definition mha_attn_reg.hpp:304
dtype_bin * matK_ptr
Definition mha_attn_reg.hpp:305
Definition mha_attn_reg.hpp:34
subgroup::tile_desc_t< matAcc_32x1024_t::tile_desc::tile_size_x, matAcc_32x1024_t::tile_desc::tile_size_y, matAcc_32x1024_t::tile_desc::block_size_x, matAcc_32x1024_t::tile_desc::block_size_y, reg_layout::tiled > mat_32x1024_tile_desc_t
Definition mha_attn_reg.hpp:177
typename gemm_op_16x2048_t::matAcc_t matAcc_16x2048_t
Definition mha_attn_reg.hpp:145
group::compute_policy_default_xmx< group::compute_attr_t< dtype_sfx, dtype_bin, dtype_acc >, bgm_perf_tuning_knob, gpu_arch::Xe > compute_policy_out
Definition mha_attn_reg.hpp:88
static constexpr mem_layout mem_layout_out_b
Definition mha_attn_reg.hpp:49
static constexpr mem_layout mem_layout_QKT_b
Definition mha_attn_reg.hpp:48
subgroup::tile_desc_t< matAcc_16x2048_t::tile_desc::tile_size_x, matAcc_16x2048_t::tile_desc::tile_size_y, matAcc_16x2048_t::tile_desc::block_size_x, matAcc_16x2048_t::tile_desc::block_size_y, reg_layout::tiled > mat_16x2048_tile_desc_t
Definition mha_attn_reg.hpp:183
static constexpr mem_layout mem_layout_c
Definition mha_attn_reg.hpp:50
group::tile_shape_t< 64, 128, 16, 16 > tile_attr_128x64
Definition mha_attn_reg.hpp:72
dtype_acc_ dtype_acc
Definition mha_attn_reg.hpp:38
static constexpr mem_space mem_space_a
Definition mha_attn_reg.hpp:42
static constexpr mem_layout gemm_mem_layout_QKT_b
Definition mha_attn_reg.hpp:56
group::tile_shape_t< 1024, 32, 64, 16 > tile_attr_32x1024
Definition mha_attn_reg.hpp:70
static constexpr uint32_t prefetch_distance
Definition mha_attn_reg.hpp:60
static constexpr uint32_t k_stride
Definition mha_attn_reg.hpp:62
typename gemm_op_64x384_t::matAcc_t matAcc_64x384_t
Definition mha_attn_reg.hpp:142
static constexpr uint16_t sfx_type_size
Definition mha_attn_reg.hpp:91
group::perf_tuning_knob_t< k_stride, prefetch_distance, periodic_sync_interval > bgm_perf_tuning_knob
Definition mha_attn_reg.hpp:64
static constexpr uint16_t Rand_SIMD
Definition mha_attn_reg.hpp:45
mem_desc_t< dtype_sfx, gemm_mem_layout_a, gemm_mem_space_a > mem_desc_a_out
Definition mha_attn_reg.hpp:83
dtype_bin_ dtype_bin
Definition mha_attn_reg.hpp:35
group::compute_policy_default_xmx< group::compute_attr_t< dtype_bin, dtype_bin, dtype_acc >, bgm_perf_tuning_knob, gpu_arch::Xe > compute_policy_QKT
Definition mha_attn_reg.hpp:80
typename gemm_op_64x512_t::arguments_t gemm_arguments_64x512
Definition mha_attn_reg.hpp:135
static constexpr mem_layout gemm_mem_layout_a
Definition mha_attn_reg.hpp:53
subgroup::tile_desc_t< matAcc_128x128_t::tile_desc::tile_size_x, matAcc_128x128_t::tile_desc::tile_size_y, matAcc_128x128_t::tile_desc::block_size_x, matAcc_128x128_t::tile_desc::block_size_y, reg_layout::tiled > mat_128x128_tile_desc_t
Definition mha_attn_reg.hpp:153
mem_desc_t< dtype_bin, gemm_mem_layout_QKT_b, gemm_mem_space_b > mem_desc_b_QKT
Definition mha_attn_reg.hpp:77
group::tile_shape_t< 256, 128, 64, 16 > tile_attr_128x256
Definition mha_attn_reg.hpp:67
group::tile_shape_t< 384, 64, 48, 16 > tile_attr_64x384
Definition mha_attn_reg.hpp:68
typename gemm_op_128x256_t::matAcc_t matAcc_128x256_t
Definition mha_attn_reg.hpp:141
dtype_bot_ dtype_bot
Definition mha_attn_reg.hpp:36
group::tile_shape_t< 512, 64, 64, 16 > tile_attr_64x512
Definition mha_attn_reg.hpp:69
dtype_sfx_ dtype_sfx
Definition mha_attn_reg.hpp:37
static constexpr int ThreadNum
Definition mha_attn_reg.hpp:40
group::tile_shape_t< 2048, 16, 64, 16 > tile_attr_16x2048
Definition mha_attn_reg.hpp:71
subgroup::tile_desc_t< matAcc_128x256_t::tile_desc::tile_size_x, matAcc_128x256_t::tile_desc::tile_size_y, matAcc_128x256_t::tile_desc::block_size_x, matAcc_128x256_t::tile_desc::block_size_y, reg_layout::tiled > mat_128x256_tile_desc_t
Definition mha_attn_reg.hpp:159
subgroup::tile_desc_t< matAcc_64x384_t::tile_desc::tile_size_x, matAcc_64x384_t::tile_desc::tile_size_y, matAcc_64x384_t::tile_desc::block_size_x, matAcc_64x384_t::tile_desc::block_size_y, reg_layout::tiled > mat_64x384_tile_desc_t
Definition mha_attn_reg.hpp:165
subgroup::tile_desc_t< matAcc_128x64_t::tile_desc::tile_size_x, matAcc_128x64_t::tile_desc::tile_size_y, matAcc_128x64_t::tile_desc::block_size_x, matAcc_128x64_t::tile_desc::block_size_y, reg_layout::tiled > mat_128x64_tile_desc_t
Definition mha_attn_reg.hpp:189
group::tile_shape_t< 128, 128, 32, 16 > tile_attr_128x128
Definition mha_attn_reg.hpp:66
typename gemm_op_128x128_t::arguments_t gemm_arguments_128x128
Definition mha_attn_reg.hpp:132
typename gemm_op_128x64_t::matAcc_t matAcc_128x64_t
Definition mha_attn_reg.hpp:146
work_group_t< ThreadNum > work_group_t
Definition mha_attn_reg.hpp:95
typename gemm_op_64x384_t::arguments_t gemm_arguments_64x384
Definition mha_attn_reg.hpp:134
static constexpr mem_space mem_space_b
Definition mha_attn_reg.hpp:43
typename gemm_op_128x64_t::arguments_t gemm_arguments_128x64
Definition mha_attn_reg.hpp:138
mem_desc_t< dtype_bin, gemm_mem_layout_out_b, gemm_mem_space_b > mem_desc_b_out
Definition mha_attn_reg.hpp:85
static constexpr mem_space gemm_mem_space_b
Definition mha_attn_reg.hpp:55
typename gemm_op_128x128_t::matAcc_t matAcc_128x128_t
Definition mha_attn_reg.hpp:140
mem_desc_t< dtype_bin, gemm_mem_layout_a, gemm_mem_space_a > mem_desc_a_QKT
Definition mha_attn_reg.hpp:75
static constexpr mem_space mem_space_c
Definition mha_attn_reg.hpp:44
typename gemm_op_16x2048_t::arguments_t gemm_arguments_16x2048
Definition mha_attn_reg.hpp:137
static constexpr mem_layout gemm_mem_layout_out_b
Definition mha_attn_reg.hpp:57
typename gemm_op_64x512_t::matAcc_t matAcc_64x512_t
Definition mha_attn_reg.hpp:143
static constexpr uint32_t global_kslicing
Definition mha_attn_reg.hpp:90
static constexpr uint32_t periodic_sync_interval
Definition mha_attn_reg.hpp:59
typename gemm_op_32x1024_t::arguments_t gemm_arguments_32x1024
Definition mha_attn_reg.hpp:136
static constexpr mem_space gemm_mem_space_a
Definition mha_attn_reg.hpp:52
static __XETLA_API void call(sycl::nd_item< 3 > &item, arguments_t *args)
Main execution function for fused MHA softmax. The basic process is GEMM -> Softmax -> GEMM.
Definition mha_attn_reg.hpp:320
subgroup::tile_desc_t< matAcc_64x512_t::tile_desc::tile_size_x, matAcc_64x512_t::tile_desc::tile_size_y, matAcc_64x512_t::tile_desc::block_size_x, matAcc_64x512_t::tile_desc::block_size_y, reg_layout::tiled > mat_64x512_tile_desc_t
Definition mha_attn_reg.hpp:171
typename gemm_op_32x1024_t::matAcc_t matAcc_32x1024_t
Definition mha_attn_reg.hpp:144
static constexpr int max_seqlen
Definition mha_attn_reg.hpp:41
static constexpr mem_layout mem_layout_a
Definition mha_attn_reg.hpp:47
typename gemm_op_128x256_t::arguments_t gemm_arguments_128x256
Definition mha_attn_reg.hpp:133
Definition memory_descriptor.hpp:139
Illustrates the memory information.
Definition api.hpp:44
Illustrates the tile information of a sub-matrix.
Definition api.hpp:64
A struct that contains some register file.
Definition api.hpp:99
xetla_vector< dtype, tile_desc::tile_elems > reg
Definition api.hpp:102
xetla nbarrier definition API.
Definition raw_send_nbarrier.hpp:43
__XETLA_API void arrive()
named barrier signal from subgroup.
Definition raw_send_nbarrier.hpp:65
__XETLA_API void init_nbarrier(uint8_t nbarrier_id, nbarrier_role role=nbarrier_role::producer_consumer)
Definition raw_send_nbarrier.hpp:55
__XETLA_API void wait()
named barrier wait within subgroup.
Definition raw_send_nbarrier.hpp:76
__XETLA_API xetla_vector< uint32_t, 4 *SIMD > rand()
Definition rand.hpp:57
__XETLA_API void init(uint64_t seed, uint64_t subseq, uint64_t offset)
Definition rand.hpp:38