XeTLA v0.3.6
Intel® Xe Templates for Linear Algebra - API Definition Document
 
Loading...
Searching...
No Matches
mha_attn_reg.hpp
Go to the documentation of this file.
1/*******************************************************************************
2* Copyright (c) 2022-2023 Intel Corporation
3*
4* Licensed under the Apache License, Version 2.0 (the "License");
5* you may not use this file except in compliance with the License.
6* You may obtain a copy of the License at
7*
8* http://www.apache.org/licenses/LICENSE-2.0
9*
10* Unless required by applicable law or agreed to in writing, software
11* distributed under the License is distributed on an "AS IS" BASIS,
12* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13* See the License for the specific language governing permissions and
14* limitations under the License.
15*******************************************************************************/
16
17#pragma once
18
19#include "common/common.hpp"
20#include "group/group.hpp"
21#include "subgroup/subgroup.hpp"
22
23namespace gpu::xetla::kernel {
24
// Number of uint32_t entries in one per-batch record of the list read through
// arguments_t::mList_ptr; call() locates a batch's record at byte offset
// sizeof(uint32_t) * list_width * batchid.
25#define list_width 16
// Default value of the dropout threshold used by call(). It is replaced by
// entry 4 of the per-batch list when that entry is non-zero. Generated random
// values are compared with `> rand_threshold` to produce the rand_bit mask.
26#define rand_threshold_const 0x80000000
// Most-significant bit of a 32-bit value.
// NOTE(review): not referenced in the part of the file visible in this listing.
27#define SIGN_BIT_DW 0x80000000
// Most-significant bit of a 16-bit value. When sizeof(dtype_sfx) == 2, call()
// merges it into a mask wherever rand_bit compares > 0 and ORs that mask into
// the converted output tile.
28#define SIGN_BIT_W16 0x8000
// Most-significant bit of an 8-bit value; used the same way as SIGN_BIT_W16
// when sizeof(dtype_sfx) == 1.
29#define SIGN_BIT_B8 0x80
30
31template <typename dtype_bin_, typename dtype_bot_, typename dtype_sfx_,
32 typename dtype_acc_, int HWThreadNum, bool Dopt_RandGenflag = true,
33 uint16_t RandSIMD = 16, int Max_SeqLen = 2048>
35 using dtype_bin = dtype_bin_;
36 using dtype_bot = dtype_bot_;
37 using dtype_sfx = dtype_sfx_;
38 using dtype_acc = dtype_acc_;
39
40 static constexpr int ThreadNum = HWThreadNum;
41 static constexpr int max_seqlen = Max_SeqLen;
45 static constexpr uint16_t Rand_SIMD = RandSIMD;
46
51
54
58
59 static constexpr uint32_t periodic_sync_interval = 0;
60 static constexpr uint32_t prefetch_distance = 3;
61 static constexpr uint32_t k_stride
62 = 32 / sizeof(dtype_bin); //gemm_t::k_stride;
65
73
81
89
90 static constexpr uint32_t global_kslicing = 1;
91 static constexpr uint16_t sfx_type_size = sizeof(dtype_sfx);
92 static_assert((sfx_type_size == 1) || (sfx_type_size == 2)
93 || (sfx_type_size == 4));
94
96
112
131
132 using gemm_arguments_128x128 = typename gemm_op_128x128_t::arguments_t;
133 using gemm_arguments_128x256 = typename gemm_op_128x256_t::arguments_t;
134 using gemm_arguments_64x384 = typename gemm_op_64x384_t::arguments_t;
135 using gemm_arguments_64x512 = typename gemm_op_64x512_t::arguments_t;
136 using gemm_arguments_32x1024 = typename gemm_op_32x1024_t::arguments_t;
137 using gemm_arguments_16x2048 = typename gemm_op_16x2048_t::arguments_t;
138 using gemm_arguments_128x64 = typename gemm_op_128x64_t::arguments_t;
139
140 using matAcc_128x128_t = typename gemm_op_128x128_t::matAcc_t;
141 using matAcc_128x256_t = typename gemm_op_128x256_t::matAcc_t;
142 using matAcc_64x384_t = typename gemm_op_64x384_t::matAcc_t;
143 using matAcc_64x512_t = typename gemm_op_64x512_t::matAcc_t;
144 using matAcc_32x1024_t = typename gemm_op_32x1024_t::matAcc_t;
145 using matAcc_16x2048_t = typename gemm_op_16x2048_t::matAcc_t;
146 using matAcc_128x64_t = typename gemm_op_128x64_t::matAcc_t;
147
149 = subgroup::tile_desc_t<matAcc_128x128_t::tile_desc::tile_size_x,
150 matAcc_128x128_t::tile_desc::tile_size_y,
151 matAcc_128x128_t::tile_desc::block_size_x,
152 matAcc_128x128_t::tile_desc::block_size_y,
155 = subgroup::tile_desc_t<matAcc_128x256_t::tile_desc::tile_size_x,
156 matAcc_128x256_t::tile_desc::tile_size_y,
157 matAcc_128x256_t::tile_desc::block_size_x,
158 matAcc_128x256_t::tile_desc::block_size_y,
161 = subgroup::tile_desc_t<matAcc_64x384_t::tile_desc::tile_size_x,
162 matAcc_64x384_t::tile_desc::tile_size_y,
163 matAcc_64x384_t::tile_desc::block_size_x,
164 matAcc_64x384_t::tile_desc::block_size_y,
167 = subgroup::tile_desc_t<matAcc_64x512_t::tile_desc::tile_size_x,
168 matAcc_64x512_t::tile_desc::tile_size_y,
169 matAcc_64x512_t::tile_desc::block_size_x,
170 matAcc_64x512_t::tile_desc::block_size_y,
173 = subgroup::tile_desc_t<matAcc_32x1024_t::tile_desc::tile_size_x,
174 matAcc_32x1024_t::tile_desc::tile_size_y,
175 matAcc_32x1024_t::tile_desc::block_size_x,
176 matAcc_32x1024_t::tile_desc::block_size_y,
179 = subgroup::tile_desc_t<matAcc_16x2048_t::tile_desc::tile_size_x,
180 matAcc_16x2048_t::tile_desc::tile_size_y,
181 matAcc_16x2048_t::tile_desc::block_size_x,
182 matAcc_16x2048_t::tile_desc::block_size_y,
185 = subgroup::tile_desc_t<matAcc_128x64_t::tile_desc::tile_size_x,
186 matAcc_128x64_t::tile_desc::tile_size_y,
187 matAcc_128x64_t::tile_desc::block_size_x,
188 matAcc_128x64_t::tile_desc::block_size_y,
197
202 : subgroup::msg_type_v<
204 gpu_arch::Xe>;
209 : subgroup::msg_type_v<
211 gpu_arch::Xe>;
215 (global_kslicing > 1)
217 : subgroup::msg_type_v<mat_64x384_tile_desc_t, mem_space_c>,
218 gpu_arch::Xe>;
222 (global_kslicing > 1)
224 : subgroup::msg_type_v<mat_64x512_tile_desc_t, mem_space_c>,
225 gpu_arch::Xe>;
230 : subgroup::msg_type_v<
232 gpu_arch::Xe>;
237 : subgroup::msg_type_v<
239 gpu_arch::Xe>;
243 (global_kslicing > 1)
245 : subgroup::msg_type_v<mat_128x64_tile_desc_t, mem_space_c>,
246 gpu_arch::Xe>;
247
262
266 subgroup::msg_type_v<mat_128x128_tile_desc_t, mem_space_c>,
271 subgroup::msg_type_v<mat_128x256_tile_desc_t, mem_space_c>,
276 subgroup::msg_type_v<mat_64x384_tile_desc_t, mem_space_c>,
281 subgroup::msg_type_v<mat_64x512_tile_desc_t, mem_space_c>,
286 subgroup::msg_type_v<mat_32x1024_tile_desc_t, mem_space_c>,
291 subgroup::msg_type_v<mat_16x2048_tile_desc_t, mem_space_c>,
296 subgroup::msg_type_v<mat_128x64_tile_desc_t, mem_space_c>,
298
// Kernel arguments consumed by call().
// NOTE(review): call() also dereferences args->matQ_ptr, args->matK_ptr and
// args->matQKT_ptr; their declarations (listing lines 304-306 and 309-310) are
// missing from this extracted view -- confirm against the original header.
301 struct arguments_t {
302 // assume base address, surface width, height, pitch, start coordinate were set
// Per-batch list of list_width uint32_t entries. call() reads entry 0 as the
// true sequence length, entry 1 as the starting row offset (seqlen_entry),
// entries 2-3 (viewed as one uint64_t) as the random-generator offset, and
// entry 4 as an optional dropout threshold (used only when non-zero).
303 uint32_t *mList_ptr;
// Attention-mask input. call() loads 16 dwords per thread and views them as
// 64 int8_t values; score columns whose mask value is > 0 are replaced with
// -1e32 before the row maximum is computed.
307 uint32_t *matMkin_ptr;
// Dropout-mask input. Only read (through a uint8_t* cast) when
// Dopt_RandGenflag == false; otherwise the mask is generated in-kernel.
308 uint32_t *matMkdpot_ptr;
// Global buffer for per-row maxima: each thread passes its local row maxima
// to a global-memory operation on this pointer, then reloads the values after
// first_nbarr.wait() and subtracts them before the exponential.
// NOTE(review): the name of the writing operation is not visible in this
// listing -- presumably an atomic max; confirm.
311 float *Max_ptr;
// Global buffer for per-row sums of the exponentials; written and reloaded
// around second_nbarr in the same way as Max_ptr, then inverted.
// NOTE(review): presumably accumulated with an atomic add; confirm.
312 float *Sum_ptr;
// Factor multiplied into the GEMM accumulator (matAcc.reg * Pinv) before the
// softmax.
313 float Pinv;
// Factor multiplied into the inverted row sum, i.e. applied to every
// normalised output element.
314 float Scaling;
315 };
316
320 __XETLA_API static void call(sycl::nd_item<3> &item, arguments_t *args) {
321
322 int tru_seqlen = 0;
323 int tru_seqlen_ex = 0;
324 int seqlen_entry = 0;
325
326 int groupid = item.get_group(0);
327 int hiddensize = 1024;
328 int numhead = 16;
329 int hdsz = 64;
330 int wg_tile_QKT_k = hdsz; //args->matrix_k;
331 int wg_tile_out_k;
332 int batchid = groupid / numhead;
333 int headid = groupid % numhead;
334
335 work_group_t g_thd32_tid;
336 int tid_linear = item.get_local_linear_id();
337 g_thd32_tid.init(tid_linear);
338
339 uint32_t batch_offset = sizeof(uint32_t) * list_width * batchid;
341 = xetla_vector_gen<uint32_t, list_width>(0, 1);
342 list_offsets *= sizeof(uint32_t);
343 list_offsets += batch_offset;
344
348 list_width>(args->mList_ptr, list_offsets);
349 tru_seqlen = list_vec[0];
350 seqlen_entry = list_vec[1];
351 wg_tile_out_k = tru_seqlen;
352 tru_seqlen_ex = tru_seqlen; //DW align
353 if (sfx_type_size == 2)
354 tru_seqlen_ex = (((tru_seqlen + 1) >> 1) << 1);
355 else if (sfx_type_size == 1)
356 tru_seqlen_ex = (((tru_seqlen + 3) >> 2) << 2);
357 //float totalscaling = args->Pinv * args->Scaling;
358
360 uint32_t rand_threshold = rand_threshold_const;
361 if constexpr (Dopt_RandGenflag == true) {
362 uint64_t rand_seed = 67280421310721;
363 uint64_t rand_subseq
364 = (groupid * ThreadNum + tid_linear) * Rand_SIMD;
365 uint64_t rand_offset = list_vec.xetla_format<uint64_t>()[1];
366 if (list_vec[4] != 0) rand_threshold = list_vec[4];
367 if (rand_offset == 0) {
369 rand_offset = time_stamp.xetla_format<uint64_t>()[0];
370 }
371 Rand_Gen.init(rand_seed, rand_subseq, rand_offset);
372 }
373
374 //std_leqlen = 256
375 int all_vert_loop_num = 2;
376 int all_vert_stride = 128;
377 int all_vert128_shift = 0;
378 int block_16x16_num = 4;
379 int tid_x_shift = 0;
380
381 int std_seqlen;
382 if (tru_seqlen <= 128) {
383 std_seqlen = 128;
384 tid_x_shift = 2; // 16x32 128/32 = 4
385 all_vert_loop_num = 1;
386 block_16x16_num = 2;
387 } else if (tru_seqlen <= 256) {
388 std_seqlen = 256;
389 tid_x_shift = 2; // 16x64 256/64 = 4
390 } else if (tru_seqlen <= 384) {
391 std_seqlen = 384;
392 all_vert_stride = 64;
393 all_vert128_shift = 1;
394 block_16x16_num = 3;
395 tid_x_shift = 3; // 16x48 384/48 = 8
396 all_vert_loop_num = (tru_seqlen + all_vert_stride - 1) >> 6;
397 } else if (tru_seqlen <= 512) {
398 std_seqlen = 512;
399 all_vert_stride = 64;
400 all_vert128_shift = 1;
401 tid_x_shift = 3; // 16x64 512/64 = 8
402 all_vert_loop_num = (tru_seqlen + all_vert_stride - 1) >> 6;
403 } else if (tru_seqlen <= 1024) {
404 std_seqlen = 1024;
405 all_vert_stride = 32;
406 all_vert128_shift = 2;
407 tid_x_shift = 4; // 16x64 1024/64 = 16
408 all_vert_loop_num = (tru_seqlen + all_vert_stride - 1) >> 5;
409 } else if (tru_seqlen <= 2048) {
410 std_seqlen = 2048;
411 all_vert_stride = 16;
412 all_vert128_shift = 3;
413 tid_x_shift = 5; // 16x64 2048/64 = 32
414 all_vert_loop_num = (tru_seqlen + all_vert_stride - 1) >> 4;
415 }
416 all_vert_loop_num = ((all_vert_loop_num + (1 << all_vert128_shift) - 1)
417 >> all_vert128_shift)
418 << all_vert128_shift;
419 int tid_x = tid_linear & ((1 << tid_x_shift) - 1);
420 int tid_y = tid_linear >> tid_x_shift;
421
428
429 xetla_vector<int8_t, 4 * 16> attn_mk_4x16;
430 int valid_block_16x16_x = (tid_x + 1) * 16 * block_16x16_num;
431 {
432 int bndy_block_num = 0;
433 if (valid_block_16x16_x <= tru_seqlen)
434 valid_block_16x16_x = block_16x16_num;
435 else {
436 bndy_block_num = valid_block_16x16_x;
437 valid_block_16x16_x = (tru_seqlen + 15 + 16 * block_16x16_num
438 - valid_block_16x16_x)
439 >> 4;
440 bndy_block_num = bndy_block_num
441 + (valid_block_16x16_x - block_16x16_num) * 16
442 - tru_seqlen;
443 }
444
445 xetla_vector<uint32_t, 16> address_attn_mk
446 = xetla_vector_gen<uint32_t, 16>(0, 1);
447 int attn_mk_address_offset
448 = (batchid * Max_SeqLen) + (tid_x * 16 * block_16x16_num);
449 address_attn_mk *= sizeof(uint32_t);
450 address_attn_mk += attn_mk_address_offset;
451 attn_mk_4x16.xetla_format<uint32_t>().xetla_select<16, 1>(0)
454 16>(args->matMkin_ptr, address_attn_mk);
455
456 for (int i = 1; i <= bndy_block_num; i++)
457 attn_mk_4x16[valid_block_16x16_x * 16 - i] = 1;
458 }
459
460 for (int all_vert_loop = 0; all_vert_loop < all_vert_loop_num;
461 all_vert_loop++) {
462
463 xetla_vector<float, 4 * 16 * 16> matElem_reg_4x16x16;
465 bool valid_compute = true;
466
467 if (((all_vert_loop * all_vert_stride + tid_y * 16) >= tru_seqlen)
468 || ((tid_x * 16 * block_16x16_num) >= tru_seqlen))
469 valid_compute = false;
470
471 if (valid_compute) {
472
473 switch (std_seqlen) {
474 case 128: {
475 gemm_arguments_128x128 gemm_arg_128x128;
476 matAcc_128x128_t matAcc_128x128;
477
478 uint32_t width_a = (headid + 1) * hdsz;
479 uint32_t height_a = tru_seqlen + seqlen_entry;
480 uint32_t pitch_a = hiddensize;
481 int start_x_a = headid * hdsz;
482 int start_y_a = all_vert_loop * all_vert_stride
483 + seqlen_entry;
484
485 gemm_arg_128x128.matA_base_desc.init({args->matQ_ptr},
486 {width_a, height_a, pitch_a},
487 {start_x_a, start_y_a});
488
489 uint32_t width_b = (headid + 1) * hdsz;
490 uint32_t height_b = tru_seqlen + seqlen_entry;
491 uint32_t pitch_b = hiddensize;
492 int start_x_b = headid * hdsz;
493 int start_y_b = seqlen_entry;
494
495 //B transpose
496 gemm_arg_128x128.matB_base_desc.init({args->matK_ptr},
497 {height_b, width_b, pitch_b},
498 {start_y_b, start_x_b});
499
500 gemm_arg_128x128.inner_loop_count
501 = (wg_tile_QKT_k + k_stride - 1) / k_stride;
502
503 matAcc_128x128.init(0);
504
505 gemm_op_128x128_t gemm_op_128x128;
506
507 gemm_op_128x128(
508 g_thd32_tid, matAcc_128x128, gemm_arg_128x128);
509
510 matElem_reg_4x16x16.xetla_format<float>()
511 .xetla_select<16 * 32, 1>(0)
512 = matAcc_128x128.reg * args->Pinv;
513 } break;
514 case 256: {
515 gemm_arguments_128x256 gemm_arg_128x256;
516 matAcc_128x256_t matAcc_128x256;
517
518 uint32_t width_a = (headid + 1) * hdsz;
519 uint32_t height_a = tru_seqlen + seqlen_entry;
520 uint32_t pitch_a = hiddensize;
521 int start_x_a = headid * hdsz;
522 int start_y_a = all_vert_loop * all_vert_stride
523 + seqlen_entry;
524
525 gemm_arg_128x256.matA_base_desc.init({args->matQ_ptr},
526 {width_a, height_a, pitch_a},
527 {start_x_a, start_y_a});
528
529 uint32_t width_b = (headid + 1) * hdsz;
530 uint32_t height_b = tru_seqlen + seqlen_entry;
531 uint32_t pitch_b = hiddensize;
532 int start_x_b = headid * hdsz;
533 int start_y_b = seqlen_entry;
534
535 //B transpose
536 gemm_arg_128x256.matB_base_desc.init({args->matK_ptr},
537 {height_b, width_b, pitch_b},
538 {start_y_b, start_x_b});
539
540 gemm_arg_128x256.inner_loop_count
541 = (wg_tile_QKT_k + k_stride - 1) / k_stride;
542
543 matAcc_128x256.init(0);
544
545 gemm_op_128x256_t gemm_op_128x256;
546
547 gemm_op_128x256(
548 g_thd32_tid, matAcc_128x256, gemm_arg_128x256);
549
550 matElem_reg_4x16x16.xetla_format<float>()
551 .xetla_select<4 * 16 * 16, 1>(0)
552 = matAcc_128x256.reg * args->Pinv;
553
554 } break;
555 case 384: {
556 gemm_arguments_64x384 gemm_arg_64x384;
557 matAcc_64x384_t matAcc_64x384;
558
559 uint32_t width_a = (headid + 1) * hdsz;
560 uint32_t height_a = tru_seqlen + seqlen_entry;
561 uint32_t pitch_a = hiddensize;
562 int start_x_a = headid * hdsz;
563 int start_y_a = all_vert_loop * all_vert_stride
564 + seqlen_entry;
565
566 gemm_arg_64x384.matA_base_desc.init({args->matQ_ptr},
567 {width_a, height_a, pitch_a},
568 {start_x_a, start_y_a});
569
570 uint32_t width_b = (headid + 1) * hdsz;
571 uint32_t height_b = tru_seqlen + seqlen_entry;
572 uint32_t pitch_b = hiddensize;
573 int start_x_b = headid * hdsz;
574 int start_y_b = seqlen_entry;
575
576 //B transpose
577 gemm_arg_64x384.matB_base_desc.init({args->matK_ptr},
578 {height_b, width_b, pitch_b},
579 {start_y_b, start_x_b});
580
581 gemm_arg_64x384.inner_loop_count
582 = (wg_tile_QKT_k + k_stride - 1) / k_stride;
583
584 matAcc_64x384.init(0);
585
586 gemm_op_64x384_t gemm_op_64x384;
587 gemm_op_64x384(
588 g_thd32_tid, matAcc_64x384, gemm_arg_64x384);
589
590 matElem_reg_4x16x16.xetla_format<float>()
591 .xetla_select<3 * 16 * 16, 1>(0)
592 = matAcc_64x384.reg * args->Pinv;
593
594 } break;
595 case 512: {
596 gemm_arguments_64x512 gemm_arg_64x512;
597 matAcc_64x512_t matAcc_64x512;
598
599 uint32_t width_a = (headid + 1) * hdsz;
600 uint32_t height_a = tru_seqlen + seqlen_entry;
601 uint32_t pitch_a = hiddensize;
602 int start_x_a = headid * hdsz;
603 int start_y_a = all_vert_loop * all_vert_stride
604 + seqlen_entry;
605
606 gemm_arg_64x512.matA_base_desc.init({args->matQ_ptr},
607 {width_a, height_a, pitch_a},
608 {start_x_a, start_y_a});
609
610 uint32_t width_b = (headid + 1) * hdsz;
611 uint32_t height_b = tru_seqlen + seqlen_entry;
612 uint32_t pitch_b = hiddensize;
613 int start_x_b = headid * hdsz;
614 int start_y_b = seqlen_entry;
615
616 //B transpose
617 gemm_arg_64x512.matB_base_desc.init({args->matK_ptr},
618 {height_b, width_b, pitch_b},
619 {start_y_b, start_x_b});
620
621 gemm_arg_64x512.inner_loop_count
622 = (wg_tile_QKT_k + k_stride - 1) / k_stride;
623
624 matAcc_64x512.init(0);
625
626 gemm_op_64x512_t gemm_op_64x512;
627 gemm_op_64x512(
628 g_thd32_tid, matAcc_64x512, gemm_arg_64x512);
629
630 matElem_reg_4x16x16.xetla_format<float>()
631 .xetla_select<4 * 16 * 16, 1>(0)
632 = matAcc_64x512.reg * args->Pinv;
633
634 } break;
635 case 1024: {
636 gemm_arguments_32x1024 gemm_arg_32x1024;
637 matAcc_32x1024_t matAcc_32x1024;
638
639 uint32_t width_a = (headid + 1) * hdsz;
640 uint32_t height_a = tru_seqlen + seqlen_entry;
641 uint32_t pitch_a = hiddensize;
642 int start_x_a = headid * hdsz;
643 int start_y_a = all_vert_loop * all_vert_stride
644 + seqlen_entry;
645
646 gemm_arg_32x1024.matA_base_desc.init({args->matQ_ptr},
647 {width_a, height_a, pitch_a},
648 {start_x_a, start_y_a});
649
650 uint32_t width_b = (headid + 1) * hdsz;
651 uint32_t height_b = tru_seqlen + seqlen_entry;
652 uint32_t pitch_b = hiddensize;
653 int start_x_b = headid * hdsz;
654 int start_y_b = seqlen_entry;
655
656 //B transpose
657 gemm_arg_32x1024.matB_base_desc.init({args->matK_ptr},
658 {height_b, width_b, pitch_b},
659 {start_y_b, start_x_b});
660
661 gemm_arg_32x1024.inner_loop_count
662 = (wg_tile_QKT_k + k_stride - 1) / k_stride;
663
664 matAcc_32x1024.init(0);
665 gemm_op_32x1024_t gemm_op_32x1024;
666 gemm_op_32x1024(
667 g_thd32_tid, matAcc_32x1024, gemm_arg_32x1024);
668
669 matElem_reg_4x16x16.xetla_format<float>()
670 .xetla_select<4 * 16 * 16, 1>(0)
671 = matAcc_32x1024.reg * args->Pinv;
672 } break;
673 case 2048: {
674 gemm_arguments_16x2048 gemm_arg_16x2048;
675 matAcc_16x2048_t matAcc_16x2048;
676
677 uint32_t width_a = (headid + 1) * hdsz;
678 uint32_t height_a = tru_seqlen + seqlen_entry;
679 uint32_t pitch_a = hiddensize;
680 int start_x_a = headid * hdsz;
681 int start_y_a = all_vert_loop * all_vert_stride
682 + seqlen_entry;
683
684 gemm_arg_16x2048.matA_base_desc.init({args->matQ_ptr},
685 {width_a, height_a, pitch_a},
686 {start_x_a, start_y_a});
687
688 uint32_t width_b = (headid + 1) * hdsz;
689 uint32_t height_b = tru_seqlen + seqlen_entry;
690 uint32_t pitch_b = hiddensize;
691 int start_x_b = headid * hdsz;
692 int start_y_b = seqlen_entry;
693
694 //B transpose
695 gemm_arg_16x2048.matB_base_desc.init({args->matK_ptr},
696 {height_b, width_b, pitch_b},
697 {start_y_b, start_x_b});
698
699 gemm_arg_16x2048.inner_loop_count
700 = (wg_tile_QKT_k + k_stride - 1) / k_stride;
701
702 matAcc_16x2048.init(0);
703 gemm_op_16x2048_t gemm_op_16x2048;
704 gemm_op_16x2048(
705 g_thd32_tid, matAcc_16x2048, gemm_arg_16x2048);
706
707 matElem_reg_4x16x16.xetla_format<float>()
708 .xetla_select<4 * 16 * 16, 1>(0)
709 = matAcc_16x2048.reg * args->Pinv;
710 }
711 } //switch
712
713 { //softmax
714 xetla_vector<float, 16 * 1> matElem_reg_max_local;
715 xetla_vector<float, 16 * 1> matElem_reg_max_global;
716
717 xetla_vector<uint32_t, 16> address_fmax
718 = xetla_vector_gen<uint32_t, 16>(0, 1);
719 int address_offset
720 = (batchid * numhead + headid) * Max_SeqLen
721 + all_vert_stride * all_vert_loop + tid_y * 16;
722 address_fmax += address_offset;
723 address_fmax *= sizeof(float);
724
725 {
726 xetla_vector<float, 16 * 16> matElem_reg_Max;
727 xetla_vector<float, 16 * 8> matElem_reg_Max_8;
728 xetla_vector<float, 16 * 4> matElem_reg_Max_4;
729 xetla_vector<float, 16 * 2> matElem_reg_Max_2;
730
731 {
732#pragma unroll
733 for (int i = 0; i < 16; i++) {
734 matElem_reg_4x16x16.xetla_select<16, 1>(16 * i)
735 .merge(-1e32,
736 attn_mk_4x16.xetla_select<16,
737 1>(0)
738 > 0);
739 }
740
741 matElem_reg_Max
742 = matElem_reg_4x16x16
743 .xetla_select<16 * 16, 1>(0);
744 }
745
746 if (valid_block_16x16_x > 1) {
747#pragma unroll
748 for (int i = 0; i < 16; i++) {
749 matElem_reg_4x16x16
750 .xetla_select<16, 1>(
751 16 * i + 16 * 16 * 1)
752 .merge(-1e32,
753 attn_mk_4x16.xetla_select<16,
754 1>(16)
755 > 0);
756 }
757 matElem_reg_Max.merge(
758 matElem_reg_4x16x16
759 .xetla_select<16 * 16, 1>(
760 16 * 16 * 1),
761 matElem_reg_4x16x16
762 .xetla_select<16 * 16, 1>(
763 16 * 16 * 1)
764 > matElem_reg_Max);
765
766 if (valid_block_16x16_x > 2) {
767#pragma unroll
768 for (int i = 0; i < 16; i++) {
769 matElem_reg_4x16x16
770 .xetla_select<16, 1>(
771 16 * i + 16 * 16 * 2)
772 .merge(-1e32,
773 attn_mk_4x16.xetla_select<
774 16, 1>(16 * 2)
775 > 0);
776 }
777 matElem_reg_Max.merge(
778 matElem_reg_4x16x16
779 .xetla_select<16 * 16, 1>(
780 16 * 16 * 2),
781 matElem_reg_4x16x16.xetla_select<
782 16 * 16, 1>(16 * 16 * 2)
783 > matElem_reg_Max);
784 if (valid_block_16x16_x > 3) {
785#pragma unroll
786 for (int i = 0; i < 16; i++) {
787 matElem_reg_4x16x16
788 .xetla_select<16, 1>(
789 16 * i + 16 * 16 * 3)
790 .merge(-1e32,
791 attn_mk_4x16.xetla_select<
792 16, 1>(16 * 3)
793 > 0);
794 }
795 matElem_reg_Max.merge(
796 matElem_reg_4x16x16
797 .xetla_select<16 * 16, 1>(
798 16 * 16 * 3),
799 matElem_reg_4x16x16.xetla_select<
800 16 * 16, 1>(16 * 16 * 3)
801 > matElem_reg_Max);
802 }
803 }
804 }
805
806 matElem_reg_Max_8.xetla_format<float, 16, 8>()
807 .xetla_select<16, 1, 8, 1>(0, 0)
808 .merge(matElem_reg_Max
809 .xetla_format<float, 16, 16>()
810 .xetla_select<16, 1, 8, 1>(
811 0, 0),
812 matElem_reg_Max
813 .xetla_format<float, 16, 16>()
814 .xetla_select<16, 1, 8, 1>(
815 0, 8),
816 matElem_reg_Max.xetla_format<float, 16,
817 16>()
818 .xetla_select<16, 1, 8,
819 1>(0, 0)
820 > matElem_reg_Max
821 .xetla_format<float,
822 16, 16>()
823 .xetla_select<16, 1,
824 8, 1>(0, 8));
825 matElem_reg_Max_4.xetla_format<float, 16, 4>()
826 .xetla_select<16, 1, 4, 1>(0, 0)
827 .merge(matElem_reg_Max_8
828 .xetla_format<float, 16, 8>()
829 .xetla_select<16, 1, 4, 1>(
830 0, 0),
831 matElem_reg_Max_8
832 .xetla_format<float, 16, 8>()
833 .xetla_select<16, 1, 4, 1>(
834 0, 4),
835 matElem_reg_Max_8
836 .xetla_format<float, 16,
837 8>()
838 .xetla_select<16, 1, 4,
839 1>(0, 0)
840 > matElem_reg_Max_8
841 .xetla_format<float,
842 16, 8>()
843 .xetla_select<16, 1,
844 4, 1>(0, 4));
845 matElem_reg_Max_2.xetla_format<float, 16, 2>()
846 .xetla_select<16, 1, 2, 1>(0, 0)
847 .merge(matElem_reg_Max_4
848 .xetla_format<float, 16, 4>()
849 .xetla_select<16, 1, 2, 1>(
850 0, 0),
851 matElem_reg_Max_4
852 .xetla_format<float, 16, 4>()
853 .xetla_select<16, 1, 2, 1>(
854 0, 2),
855 matElem_reg_Max_4
856 .xetla_format<float, 16,
857 4>()
858 .xetla_select<16, 1, 2,
859 1>(0, 0)
860 > matElem_reg_Max_4
861 .xetla_format<float,
862 16, 4>()
863 .xetla_select<16, 1,
864 2, 1>(0, 2));
865 matElem_reg_max_local.xetla_format<float, 16, 1>()
866 .xetla_select<16, 1, 1, 1>(0, 0)
867 .merge(matElem_reg_Max_2
868 .xetla_format<float, 16, 2>()
869 .xetla_select<16, 1, 1, 1>(
870 0, 0),
871 matElem_reg_Max_2
872 .xetla_format<float, 16, 2>()
873 .xetla_select<16, 1, 1, 1>(
874 0, 1),
875 matElem_reg_Max_2
876 .xetla_format<float, 16,
877 2>()
878 .xetla_select<16, 1, 1,
879 1>(0, 0)
880 > matElem_reg_Max_2
881 .xetla_format<float,
882 16, 2>()
883 .xetla_select<16, 1,
884 1, 1>(0, 1));
885
886 xetla_mask<16> pred = 1;
889 (uint64_t)args->Max_ptr, address_fmax,
890 matElem_reg_max_local.xetla_select<16, 1>(0),
891 pred);
892 }
893
894 first_nbarr.arrive();
895 if constexpr (Dopt_RandGenflag == true) {
897#pragma unroll
898 for (int i = 0; i < ((16 * 16) / (2 * 4 * RandSIMD));
899 i++) {
900 rand_data = Rand_Gen.rand();
901 rand_bit.xetla_select<4 * RandSIMD, 1>(
902 i * (4 * RandSIMD))
903 = rand_data > rand_threshold;
904 }
905 }
906 first_nbarr.wait();
907
908 {
909 matElem_reg_max_global = xetla_load_global<float, 1,
912 16>(args->Max_ptr, address_fmax);
913
914 auto matElem_reg_max_use = matElem_reg_max_global;
915
916#pragma unroll
917 for (int i = 0; i < 16; i++) {
918 matElem_reg_4x16x16.xetla_select<16 * 16, 1>(0)
919 .xetla_select<16, 1>(i * 16)
920 = matElem_reg_4x16x16
921 .xetla_select<16 * 16, 1>(0)
922 .xetla_select<16, 1>(i * 16)
923 - matElem_reg_max_use[i];
924
925 matElem_reg_4x16x16.xetla_select<16 * 16, 1>(0)
926 .xetla_select<16, 1>(i * 16)
927 = xetla_exp<float, 16>(
928 matElem_reg_4x16x16
929 .xetla_select<16 * 16, 1>(0)
930 .xetla_select<16, 1>(
931 i * 16));
932 }
933
934 if (valid_block_16x16_x > 1) {
935#pragma unroll
936 for (int i = 0; i < 16; i++) {
937 matElem_reg_4x16x16
938 .xetla_select<16 * 16, 1>(16 * 16 * 1)
939 .xetla_select<16, 1>(i * 16)
940 = matElem_reg_4x16x16
941 .xetla_select<16 * 16, 1>(
942 16 * 16 * 1)
943 .xetla_select<16, 1>(i * 16)
944 - matElem_reg_max_use[i];
945
946 matElem_reg_4x16x16
947 .xetla_select<16 * 16, 1>(16 * 16 * 1)
948 .xetla_select<16, 1>(i * 16)
949 = xetla_exp<float, 16>(
950 matElem_reg_4x16x16
951 .xetla_select<16 * 16,
952 1>(16 * 16 * 1)
953 .xetla_select<16, 1>(
954 i * 16));
955 }
956
957 if (valid_block_16x16_x > 2) {
958#pragma unroll
959 for (int i = 0; i < 16; i++) {
960 matElem_reg_4x16x16
961 .xetla_select<16 * 16, 1>(
962 16 * 16 * 2)
963 .xetla_select<16, 1>(i * 16)
964 = matElem_reg_4x16x16
965 .xetla_select<16 * 16, 1>(
966 16 * 16 * 2)
967 .xetla_select<16, 1>(
968 i * 16)
969 - matElem_reg_max_use[i];
970 matElem_reg_4x16x16
971 .xetla_select<16 * 16, 1>(
972 16 * 16 * 2)
973 .xetla_select<16, 1>(i * 16)
974 = xetla_exp<float, 16>(
975 matElem_reg_4x16x16
977 16 * 16, 1>(
978 16 * 16 * 2)
979 .xetla_select<16,
980 1>(i * 16));
981 }
982
983 if (valid_block_16x16_x > 3) {
984#pragma unroll
985 for (int i = 0; i < 16; i++) {
986 matElem_reg_4x16x16
987 .xetla_select<16 * 16, 1>(
988 16 * 16 * 3)
989 .xetla_select<16, 1>(i * 16)
990 = matElem_reg_4x16x16
991 .xetla_select<16 * 16,
992 1>(
993 16 * 16 * 3)
994 .xetla_select<16, 1>(
995 i * 16)
996 - matElem_reg_max_use[i];
997 matElem_reg_4x16x16
998 .xetla_select<16 * 16, 1>(
999 16 * 16 * 3)
1000 .xetla_select<16, 1>(i * 16)
1001 = xetla_exp<float, 16>(
1002 matElem_reg_4x16x16
1003 .xetla_select<
1004 16 * 16,
1005 1>(16
1006 * 16
1007 * 3)
1008 .xetla_select<
1009 16, 1>(i
1010 * 16));
1011 }
1012 }
1013 }
1014 }
1015 }
1016
1017 xetla_vector<float, 16 * 1> matElem_reg_Sum_1;
1018
1019 {
1020 xetla_vector<float, 16 * 16> matElem_reg_Sum;
1021 xetla_vector<float, 16 * 8> matElem_reg_Sum_8;
1022 xetla_vector<float, 16 * 4> matElem_reg_Sum_4;
1023 xetla_vector<float, 16 * 2> matElem_reg_Sum_2;
1024
1025 matElem_reg_Sum
1026 = matElem_reg_4x16x16.xetla_select<16 * 16, 1>(
1027 0);
1028
1029 if (valid_block_16x16_x > 1) {
1030 matElem_reg_Sum
1031 += matElem_reg_4x16x16
1032 .xetla_select<16 * 16, 1>(
1033 16 * 16 * 1);
1034 if (valid_block_16x16_x > 2) {
1035 matElem_reg_Sum
1036 += matElem_reg_4x16x16
1037 .xetla_select<16 * 16, 1>(
1038 16 * 16 * 2);
1039 if (valid_block_16x16_x > 3)
1040 matElem_reg_Sum
1041 += matElem_reg_4x16x16.xetla_select<
1042 16 * 16, 1>(16 * 16 * 3);
1043 }
1044 }
1045 matElem_reg_Sum_8.xetla_format<float, 16, 8>()
1046 = matElem_reg_Sum.xetla_format<float, 16, 16>()
1047 .xetla_select<16, 1, 8, 1>(0, 0)
1048 + matElem_reg_Sum.xetla_format<float, 16, 16>()
1049 .xetla_select<16, 1, 8, 1>(0, 8);
1050
1051 matElem_reg_Sum_4.xetla_format<float, 16, 4>()
1052 = matElem_reg_Sum_8.xetla_format<float, 16, 8>()
1053 .xetla_select<16, 1, 4, 1>(0, 0)
1054 + matElem_reg_Sum_8.xetla_format<float, 16, 8>()
1055 .xetla_select<16, 1, 4, 1>(0, 4);
1056
1057 matElem_reg_Sum_2.xetla_format<float, 16, 2>()
1058 = matElem_reg_Sum_4.xetla_format<float, 16, 4>()
1059 .xetla_select<16, 1, 2, 1>(0, 0)
1060 + matElem_reg_Sum_4.xetla_format<float, 16, 4>()
1061 .xetla_select<16, 1, 2, 1>(0, 2);
1062
1063 matElem_reg_Sum_1.xetla_format<float, 16, 1>()
1064 = matElem_reg_Sum_2.xetla_format<float, 16, 2>()
1065 .xetla_select<16, 1, 1, 1>(0, 0)
1066 + matElem_reg_Sum_2.xetla_format<float, 16, 2>()
1067 .xetla_select<16, 1, 1, 1>(0, 1);
1068
1069 xetla_mask<16> pred = 1;
1072 (uint64_t)args->Sum_ptr, address_fmax,
1073 matElem_reg_Sum_1.xetla_select<16, 1>(0), pred);
1074 }
1075
1076 second_nbarr.arrive();
1077 if constexpr (Dopt_RandGenflag == true) {
1079#pragma unroll
1080 for (int i = ((16 * 16) / (2 * 4 * RandSIMD));
1081 i < ((16 * 16) / (4 * RandSIMD)); i++) {
1082 rand_data = Rand_Gen.rand();
1083 rand_bit.xetla_select<4 * RandSIMD, 1>(
1084 i * (4 * RandSIMD))
1085 = rand_data > rand_threshold;
1086 }
1087 }
1088 second_nbarr.wait();
1089
1090 {
1091 matElem_reg_Sum_1 = xetla_load_global<float, 1,
1094 16>(args->Sum_ptr, address_fmax);
1095
1096 matElem_reg_Sum_1
1097 = xetla_inv<float, 16>(matElem_reg_Sum_1);
1098 matElem_reg_Sum_1 *= args->Scaling;
1099
1100#pragma unroll
1101 for (int i = 0; i < 16; i++) {
1102 matElem_reg_4x16x16.xetla_select<16 * 16, 1>(0)
1103 .xetla_select<16, 1>(i * 16)
1104 = matElem_reg_4x16x16
1105 .xetla_select<16 * 16, 1>(0)
1106 .xetla_select<16, 1>(i * 16)
1107 * matElem_reg_Sum_1[i];
1108 }
1109
1110 if (valid_block_16x16_x > 1) {
1111#pragma unroll
1112 for (int i = 0; i < 16; i++) {
1113 matElem_reg_4x16x16
1114 .xetla_select<16 * 16, 1>(16 * 16 * 1)
1115 .xetla_select<16, 1>(i * 16)
1116 = matElem_reg_4x16x16
1117 .xetla_select<16 * 16, 1>(
1118 16 * 16 * 1)
1119 .xetla_select<16, 1>(i * 16)
1120 * matElem_reg_Sum_1[i];
1121 }
1122
1123 if constexpr (Dopt_RandGenflag == true) {
1125#pragma unroll
1126 for (int i = 0;
1127 i < ((16 * 16) / (4 * RandSIMD)); i++) {
1128 rand_data = Rand_Gen.rand();
1129 rand_bit.xetla_select<4 * RandSIMD, 1>(
1130 (i * (4 * RandSIMD))
1131 + (16 * 16 * 1))
1132 = rand_data > rand_threshold;
1133 }
1134 }
1135
1136 if (valid_block_16x16_x > 2) {
1137#pragma unroll
1138 for (int i = 0; i < 16; i++) {
1139 matElem_reg_4x16x16
1140 .xetla_select<16 * 16, 1>(
1141 16 * 16 * 2)
1142 .xetla_select<16, 1>(i * 16)
1143 = matElem_reg_4x16x16
1144 .xetla_select<16 * 16, 1>(
1145 16 * 16 * 2)
1146 .xetla_select<16, 1>(
1147 i * 16)
1148 * matElem_reg_Sum_1[i];
1149 }
1150
1151 if constexpr (Dopt_RandGenflag == true) {
1153 rand_data;
1154#pragma unroll
1155 for (int i = 0;
1156 i < ((16 * 16) / (4 * RandSIMD));
1157 i++) {
1158 rand_data = Rand_Gen.rand();
1159 rand_bit.xetla_select<4 * RandSIMD, 1>(
1160 (i * (4 * RandSIMD))
1161 + (16 * 16 * 2))
1162 = rand_data > rand_threshold;
1163 }
1164 }
1165
1166 if (valid_block_16x16_x > 3) {
1167#pragma unroll
1168 for (int i = 0; i < 16; i++) {
1169 matElem_reg_4x16x16
1170 .xetla_select<16 * 16, 1>(
1171 16 * 16 * 3)
1172 .xetla_select<16, 1>(i * 16)
1173 = matElem_reg_4x16x16
1174 .xetla_select<16 * 16,
1175 1>(
1176 16 * 16 * 3)
1177 .xetla_select<16, 1>(
1178 i * 16)
1179 * matElem_reg_Sum_1[i];
1180 }
1181
1182 if constexpr (Dopt_RandGenflag == true) {
1184 rand_data;
1185#pragma unroll
1186 for (int i = 0; i
1187 < ((16 * 16) / (4 * RandSIMD));
1188 i++) {
1189 rand_data = Rand_Gen.rand();
1190 rand_bit.xetla_select<4 * RandSIMD,
1191 1>((i * (4 * RandSIMD))
1192 + (16 * 16 * 3))
1193 = rand_data
1194 > rand_threshold;
1195 }
1196 }
1197 }
1198 }
1199 }
1200 }
1201 } //softmax
1202
1203 //store
1204 switch (std_seqlen) {
1205 case 128: {
1206 matC_128x128_t matC_128x128;
1207 matC_128x128_payload_t matC_128x128_payload;
1208 matDpotMk_128x128_t matDpotMk_128x128;
1209 matDpotMk_128x128_payload_t matDpotMk_128x128_payload;
1210
1211 int width_c = max_seqlen;
1212 int height_c
1213 = max_seqlen * (batchid * numhead + headid + 1);
1214 int pitch_c = max_seqlen;
1215 int start_x_c = gemm_op_128x128_t::get_matC_offset_x(
1216 g_thd32_tid);
1217 int start_y_c
1218 = (batchid * numhead + headid) * max_seqlen
1219 + all_vert_loop * all_vert_stride
1220 + gemm_op_128x128_t::get_matC_offset_y(
1221 g_thd32_tid);
1222 matC_128x128_payload.init(args->matQKT_ptr, width_c,
1223 height_c, pitch_c, start_x_c, start_y_c);
1224
1225 if constexpr (Dopt_RandGenflag == false) {
1226 uint8_t *matMkdpot_byte_ptr
1227 = (uint8_t *)(args->matMkdpot_ptr);
1228 matDpotMk_128x128_payload.init(matMkdpot_byte_ptr,
1229 width_c, height_c, pitch_c, start_x_c,
1230 start_y_c);
1231 subgroup::tile_load(matDpotMk_128x128,
1232 matDpotMk_128x128_payload);
1233 }
1234
1235 xetla_vector<float, 16 * 32> matElem_reg_store
1236 = matElem_reg_4x16x16.xetla_format<float>()
1237 .xetla_select<16 * 32, 1>(0);
1238 matC_128x128.reg = xetla_cvt<dtype_sfx, float>(
1239 matElem_reg_store);
1240
1241 if constexpr (Dopt_RandGenflag == false) {
1242 rand_bit.xetla_select<16 * 16 * 2, 1>(0)
1243 = matDpotMk_128x128.reg;
1244 }
1245
1246 if constexpr (sfx_type_size == 2) {
1248 drop_mk_w.xetla_merge(SIGN_BIT_W16,
1249 rand_bit.xetla_select<16 * 16 * 2, 1>(0)
1250 > 0);
1251 matC_128x128.reg.xetla_format<uint16_t>()
1252 |= drop_mk_w;
1253 }
1254 if constexpr (sfx_type_size == 1) {
1256 drop_mk_b.xetla_merge(SIGN_BIT_B8,
1257 rand_bit.xetla_select<16 * 16 * 2, 1>(0)
1258 > 0);
1259 matC_128x128.reg.xetla_format<uint8_t>()
1260 |= drop_mk_b;
1261 }
1262
1264 matC_128x128, matC_128x128_payload);
1265 xetla_fence<memory_kind::untyped_global>();
1266
1267 } break;
1268 case 256: {
1269 matC_128x256_t matC_128x256;
1270 matC_128x256_payload_t matC_128x256_payload;
1271 matDpotMk_128x256_t matDpotMk_128x256;
1272 matDpotMk_128x256_payload_t matDpotMk_128x256_payload;
1273
1274 int width_c = max_seqlen;
1275 int height_c
1276 = max_seqlen * (batchid * numhead + headid + 1);
1277 int pitch_c = max_seqlen;
1278 int start_x_c = gemm_op_128x256_t::get_matC_offset_x(
1279 g_thd32_tid);
1280 int start_y_c
1281 = (batchid * numhead + headid) * max_seqlen
1282 + all_vert_loop * all_vert_stride
1283 + gemm_op_128x256_t::get_matC_offset_y(
1284 g_thd32_tid);
1285
1286 matC_128x256_payload.init(args->matQKT_ptr, width_c,
1287 height_c, pitch_c, start_x_c, start_y_c);
1288
1289 if constexpr (Dopt_RandGenflag == false) {
1290 uint8_t *matMkdpot_byte_ptr
1291 = (uint8_t *)(args->matMkdpot_ptr);
1292 matDpotMk_128x256_payload.init(matMkdpot_byte_ptr,
1293 width_c, height_c, pitch_c, start_x_c,
1294 start_y_c);
1295 subgroup::tile_load(matDpotMk_128x256,
1296 matDpotMk_128x256_payload);
1297 }
1298
1299 matC_128x256.reg = xetla_cvt<dtype_sfx, float>(
1300 matElem_reg_4x16x16);
1301
1302 if constexpr (Dopt_RandGenflag == false) {
1303 rand_bit = matDpotMk_128x256.reg;
1304 }
1305
1306 if constexpr (sfx_type_size == 2) {
1308 drop_mk_w.xetla_merge(SIGN_BIT_W16,
1309 rand_bit.xetla_select<16 * 16 * 4, 1>(0)
1310 > 0);
1311 matC_128x256.reg.xetla_format<uint16_t>()
1312 |= drop_mk_w;
1313 }
1314 if constexpr (sfx_type_size == 1) {
1316 drop_mk_b.xetla_merge(SIGN_BIT_B8,
1317 rand_bit.xetla_select<16 * 16 * 4, 1>(0)
1318 > 0);
1319 matC_128x256.reg.xetla_format<uint8_t>()
1320 |= drop_mk_b;
1321 }
1322
1324 matC_128x256, matC_128x256_payload);
1325 xetla_fence<memory_kind::untyped_global>();
1326
1327 } break;
1328 case 384: {
1329 matC_64x384_t matC_64x384;
1330 matC_64x384_payload_t matC_64x384_payload;
1331 matDpotMk_64x384_t matDpotMk_64x384;
1332 matDpotMk_64x384_payload_t matDpotMk_64x384_payload;
1333
1334 int width_c = max_seqlen;
1335 int height_c
1336 = max_seqlen * (batchid * numhead + headid + 1);
1337 int pitch_c = max_seqlen;
1338 int start_x_c = gemm_op_64x384_t::get_matC_offset_x(
1339 g_thd32_tid);
1340 int start_y_c
1341 = (batchid * numhead + headid) * max_seqlen
1342 + all_vert_loop * all_vert_stride
1343 + gemm_op_64x384_t::get_matC_offset_y(
1344 g_thd32_tid);
1345
1346 matC_64x384_payload.init(args->matQKT_ptr, width_c,
1347 height_c, pitch_c, start_x_c, start_y_c);
1348
1349 if constexpr (Dopt_RandGenflag == false) {
1350 uint8_t *matMkdpot_byte_ptr
1351 = (uint8_t *)(args->matMkdpot_ptr);
1352 matDpotMk_64x384_payload.init(matMkdpot_byte_ptr,
1353 width_c, height_c, pitch_c, start_x_c,
1354 start_y_c);
1356 matDpotMk_64x384, matDpotMk_64x384_payload);
1357 }
1358
1359 xetla_vector<float, 3 * 16 * 16> matElem_reg_store
1360 = matElem_reg_4x16x16.xetla_format<float>()
1361 .xetla_select<3 * 16 * 16, 1>(0);
1362 matC_64x384.reg = xetla_cvt<dtype_sfx, float>(
1363 matElem_reg_store);
1364
1365 if constexpr (Dopt_RandGenflag == false) {
1366 rand_bit.xetla_select<16 * 16 * 3, 1>(0)
1367 = matDpotMk_64x384.reg;
1368 }
1369
1370 if constexpr (sfx_type_size == 2) {
1372 drop_mk_w.xetla_merge(SIGN_BIT_W16,
1373 rand_bit.xetla_select<16 * 16 * 3, 1>(0)
1374 > 0);
1375 matC_64x384.reg.xetla_format<uint16_t>()
1376 |= drop_mk_w;
1377 }
1378 if constexpr (sfx_type_size == 1) {
1380 drop_mk_b.xetla_merge(SIGN_BIT_B8,
1381 rand_bit.xetla_select<16 * 16 * 3, 1>(0)
1382 > 0);
1383 matC_64x384.reg.xetla_format<uint8_t>()
1384 |= drop_mk_b;
1385 }
1386
1387 subgroup::tile_store(matC_64x384, matC_64x384_payload);
1388 xetla_fence<memory_kind::untyped_global>();
1389 } break;
1390 case 512: {
1391 matC_64x512_t matC_64x512;
1392 matC_64x512_payload_t matC_64x512_payload;
1393 matDpotMk_64x512_t matDpotMk_64x512;
1394 matDpotMk_64x512_payload_t matDpotMk_64x512_payload;
1395
1396 int width_c = max_seqlen;
1397 int height_c
1398 = max_seqlen * (batchid * numhead + headid + 1);
1399 int pitch_c = max_seqlen;
1400 int start_x_c = gemm_op_64x512_t::get_matC_offset_x(
1401 g_thd32_tid);
1402 int start_y_c
1403 = (batchid * numhead + headid) * max_seqlen
1404 + all_vert_loop * all_vert_stride
1405 + gemm_op_64x512_t::get_matC_offset_y(
1406 g_thd32_tid);
1407 matC_64x512_payload.init(args->matQKT_ptr, width_c,
1408 height_c, pitch_c, start_x_c, start_y_c);
1409
1410 if constexpr (Dopt_RandGenflag == false) {
1411 uint8_t *matMkdpot_byte_ptr
1412 = (uint8_t *)(args->matMkdpot_ptr);
1413 matDpotMk_64x512_payload.init(matMkdpot_byte_ptr,
1414 width_c, height_c, pitch_c, start_x_c,
1415 start_y_c);
1417 matDpotMk_64x512, matDpotMk_64x512_payload);
1418 }
1419
1420 matC_64x512.reg = xetla_cvt<dtype_sfx, float>(
1421 matElem_reg_4x16x16);
1422
1423 if constexpr (Dopt_RandGenflag == false) {
1424 rand_bit = matDpotMk_64x512.reg;
1425 }
1426
1427 if constexpr (sfx_type_size == 2) {
1429 drop_mk_w.xetla_merge(SIGN_BIT_W16,
1430 rand_bit.xetla_select<16 * 16 * 4, 1>(0)
1431 > 0);
1432 matC_64x512.reg.xetla_format<uint16_t>()
1433 |= drop_mk_w;
1434 }
1435 if constexpr (sfx_type_size == 1) {
1437 drop_mk_b.xetla_merge(SIGN_BIT_B8,
1438 rand_bit.xetla_select<16 * 16 * 4, 1>(0)
1439 > 0);
1440 matC_64x512.reg.xetla_format<uint8_t>()
1441 |= drop_mk_b;
1442 }
1443
1444 subgroup::tile_store(matC_64x512, matC_64x512_payload);
1445 xetla_fence<memory_kind::untyped_global>();
1446 } break;
1447 case 1024: {
1448 matC_32x1024_t matC_32x1024;
1449 matC_32x1024_payload_t matC_32x1024_payload;
1450 matDpotMk_32x1024_t matDpotMk_32x1024;
1451 matDpotMk_32x1024_payload_t matDpotMk_32x1024_payload;
1452
1453 int width_c = max_seqlen;
1454 int height_c
1455 = max_seqlen * (batchid * numhead + headid + 1);
1456 int pitch_c = max_seqlen;
1457 int start_x_c = gemm_op_32x1024_t::get_matC_offset_x(
1458 g_thd32_tid);
1459 int start_y_c
1460 = (batchid * numhead + headid) * max_seqlen
1461 + all_vert_loop * all_vert_stride
1462 + gemm_op_32x1024_t::get_matC_offset_y(
1463 g_thd32_tid);
1464
1465 matC_32x1024_payload.init(args->matQKT_ptr, width_c,
1466 height_c, pitch_c, start_x_c, start_y_c);
1467
1468 if constexpr (Dopt_RandGenflag == false) {
1469 uint8_t *matMkdpot_byte_ptr
1470 = (uint8_t *)(args->matMkdpot_ptr);
1471 matDpotMk_32x1024_payload.init(matMkdpot_byte_ptr,
1472 width_c, height_c, pitch_c, start_x_c,
1473 start_y_c);
1474 subgroup::tile_load(matDpotMk_32x1024,
1475 matDpotMk_32x1024_payload);
1476 }
1477
1478 matC_32x1024.reg = xetla_cvt<dtype_sfx, float>(
1479 matElem_reg_4x16x16);
1480
1481 if constexpr (Dopt_RandGenflag == false) {
1482 rand_bit = matDpotMk_32x1024.reg;
1483 }
1484
1485 if constexpr (sfx_type_size == 2) {
1487 drop_mk_w.xetla_merge(SIGN_BIT_W16,
1488 rand_bit.xetla_select<16 * 16 * 4, 1>(0)
1489 > 0);
1490 matC_32x1024.reg.xetla_format<uint16_t>()
1491 |= drop_mk_w;
1492 }
1493 if constexpr (sfx_type_size == 1) {
1495 drop_mk_b.xetla_merge(SIGN_BIT_B8,
1496 rand_bit.xetla_select<16 * 16 * 4, 1>(0)
1497 > 0);
1498 matC_32x1024.reg.xetla_format<uint8_t>()
1499 |= drop_mk_b;
1500 }
1501
1503 matC_32x1024, matC_32x1024_payload);
1504 xetla_fence<memory_kind::untyped_global>();
1505 } break;
1506 case 2048: {
1507 matC_16x2048_t matC_16x2048;
1508 matC_16x2048_payload_t matC_16x2048_payload;
1509 matDpotMk_16x2048_t matDpotMk_16x2048;
1510 matDpotMk_16x2048_payload_t matDpotMk_16x2048_payload;
1511
1512 int width_c = max_seqlen;
1513 int height_c
1514 = max_seqlen * (batchid * numhead + headid + 1);
1515 int pitch_c = max_seqlen;
1516 int start_x_c = gemm_op_16x2048_t::get_matC_offset_x(
1517 g_thd32_tid);
1518 int start_y_c
1519 = (batchid * numhead + headid) * max_seqlen
1520 + all_vert_loop * all_vert_stride
1521 + gemm_op_16x2048_t::get_matC_offset_y(
1522 g_thd32_tid);
1523
1524 matC_16x2048_payload.init(args->matQKT_ptr, width_c,
1525 height_c, pitch_c, start_x_c, start_y_c);
1526
1527 if constexpr (Dopt_RandGenflag == false) {
1528 uint8_t *matMkdpot_byte_ptr
1529 = (uint8_t *)(args->matMkdpot_ptr);
1530 matDpotMk_16x2048_payload.init(matMkdpot_byte_ptr,
1531 width_c, height_c, pitch_c, start_x_c,
1532 start_y_c);
1533 subgroup::tile_load(matDpotMk_16x2048,
1534 matDpotMk_16x2048_payload);
1535 }
1536
1537 matC_16x2048.reg = xetla_cvt<dtype_sfx, float>(
1538 matElem_reg_4x16x16);
1539
1540 if constexpr (Dopt_RandGenflag == false) {
1541 rand_bit = matDpotMk_16x2048.reg;
1542 }
1543
1544 if constexpr (sfx_type_size == 2) {
1546 drop_mk_w.xetla_merge(SIGN_BIT_W16,
1547 rand_bit.xetla_select<16 * 16 * 4, 1>(0)
1548 > 0);
1549 matC_16x2048.reg.xetla_format<uint16_t>()
1550 |= drop_mk_w;
1551 }
1552 if constexpr (sfx_type_size == 1) {
1554 drop_mk_b.xetla_merge(SIGN_BIT_B8,
1555 rand_bit.xetla_select<16 * 16 * 4, 1>(0)
1556 > 0);
1557 matC_16x2048.reg.xetla_format<uint8_t>()
1558 |= drop_mk_b;
1559 }
1560
1562 matC_16x2048, matC_16x2048_payload);
1563 xetla_fence<memory_kind::untyped_global>();
1564 }
1565 } //store switch
1566
1567 } else { //valid_compute
1568 first_nbarr.arrive();
1569 first_nbarr.wait();
1570
1571 second_nbarr.arrive();
1572 second_nbarr.wait();
1573 }
1574
1575 //QKTV
1576 int all_vert128_loop = all_vert_loop >> all_vert128_shift;
1577 if (((((all_vert128_loop + 1) << all_vert128_shift) - 1)
1578 == all_vert_loop)
1579 || (all_vert128_shift == 0)) {
1580
1581 third_nbarr.arrive();
1582 third_nbarr.wait();
1583
1584 gemm_arguments_128x64 gemm_arg_128x64;
1585 matAcc_128x64_t matAcc_128x64;
1586 matC_128x64_t matC_128x64;
1587 matC_128x64_payload_t matC_128x64_payload;
1588
1589 uint32_t width_a = tru_seqlen_ex;
1590 uint32_t height_a = (batchid * numhead + headid) * max_seqlen
1591 + tru_seqlen;
1592 uint32_t pitch_a = max_seqlen;
1593 int start_x_a = 0;
1594 int start_y_a = (batchid * numhead + headid) * max_seqlen
1595 + all_vert128_loop * 128;
1596
1597 gemm_arg_128x64.matA_base_desc.init({args->matQKT_ptr},
1598 {width_a, height_a, pitch_a}, {start_x_a, start_y_a});
1599
1600 uint32_t width_b = (headid + 1) * hdsz;
1601 uint32_t height_b = tru_seqlen + seqlen_entry;
1602 uint32_t pitch_b = hiddensize;
1603 int start_x_b = headid * hdsz;
1604 int start_y_b = seqlen_entry;
1605
1606 gemm_arg_128x64.matB_base_desc.init({args->matV_ptr},
1607 {width_b, height_b, pitch_b}, {start_x_b, start_y_b});
1608
1609 gemm_arg_128x64.inner_loop_count
1610 = (wg_tile_out_k + k_stride - 1) / k_stride;
1611
1612 matAcc_128x64.init(0);
1613
1614 gemm_op_128x64_t gemm_op_128x64;
1615
1616 gemm_op_128x64(g_thd32_tid, matAcc_128x64, gemm_arg_128x64);
1617
1618 int width_c = (headid + 1) * hdsz;
1619 int height_c = tru_seqlen + seqlen_entry;
1620 int pitch_c = hiddensize;
1621 int start_x_c = headid * hdsz
1622 + gemm_op_128x64_t::get_matC_offset_x(g_thd32_tid);
1623 int start_y_c = all_vert128_loop * 128 + seqlen_entry
1624 + gemm_op_128x64_t::get_matC_offset_y(g_thd32_tid);
1625
1626 matC_128x64_payload.init(args->matOut_ptr, width_c, height_c,
1627 pitch_c, start_x_c, start_y_c);
1628 subgroup::elemwise_cvt<matC_128x64_t, matAcc_128x64_t>(
1629 matC_128x64, matAcc_128x64);
1630 subgroup::tile_store(matC_128x64, matC_128x64_payload);
1631 }
1632
1633 } //all_vert128_loop
1634 } //xetla_softmax_fwd_t::call()
1635}; //struct xetla_softmax_fwd_t
1636
1637template <typename dtype_bwd_bin_, typename dtype_bwd_bot_,
1638 typename dtype_bwd_sfx_, typename dtype_bwd_acc_, int HWThreadNum,
1639 bool Dopt_RandGenflag = true, bool Mkin_flag = false,
1640 int Max_SeqLen = 512>
1642 using dtype_bin = dtype_bwd_bin_;
1643 using dtype_bot = dtype_bwd_bot_;
1644 using dtype_sfx = dtype_bwd_sfx_;
1645 using dtype_acc = dtype_bwd_acc_;
1646
1647 static constexpr int ThreadNum = HWThreadNum;
1648 static_assert(ThreadNum == 32);
1652
1658
1663
1667
1668 static constexpr uint32_t periodic_sync_interval = 0;
1669 static constexpr uint32_t prefetch_distance = 3;
1670
1671 static constexpr uint32_t k_stride
1672 = 32 / sizeof(dtype_bin); //gemm_t::k_stride;
1675
1684
1692
1700
1708
1709 static constexpr uint32_t global_kslicing = 1;
1710 static constexpr uint16_t sfx_type_size = sizeof(dtype_sfx);
1711 static_assert((sfx_type_size == 1) || (sfx_type_size == 2)
1712 || (sfx_type_size == 4));
1713
1715
1734 gpu_arch::Xe>;
1737 gpu_arch::Xe>;
1738
1755
1770
1771 using gemm_arguments_128x128 = typename gemm_op_128x128_t::arguments_t;
1772 using gemm_arguments_128x256 = typename gemm_op_128x256_t::arguments_t;
1773 using gemm_arguments_64x384 = typename gemm_op_64x384_t::arguments_t;
1774 using gemm_arguments_64x512 = typename gemm_op_64x512_t::arguments_t;
1775 using gemm_arguments_32x1024 = typename gemm_op_32x1024_t::arguments_t;
1776 using gemm_arguments_16x2048 = typename gemm_op_16x2048_t::arguments_t;
1777
1778 using gemm_arguments_128x64 = typename gemm_op_128x64_t::arguments_t;
1780 typename gemm_op_128x64_trnp_a_t::arguments_t;
1782 typename gemm_op_256x64_trnp_a_t::arguments_t;
1784 typename gemm_op_128x64_trnp_af_t::arguments_t;
1786 typename gemm_op_256x64_trnp_af_t::arguments_t;
1787
1788 using matAcc_128x128_t = typename gemm_op_128x128_t::matAcc_t;
1789 using matAcc_128x256_t = typename gemm_op_128x256_t::matAcc_t;
1790 using matAcc_64x384_t = typename gemm_op_64x384_t::matAcc_t;
1791 using matAcc_64x512_t = typename gemm_op_64x512_t::matAcc_t;
1792 using matAcc_32x1024_t = typename gemm_op_32x1024_t::matAcc_t;
1793 using matAcc_16x2048_t = typename gemm_op_16x2048_t::matAcc_t;
1794
1795 using matAcc_128x64_t = typename gemm_op_128x64_t::matAcc_t;
1796 using matAcc_128x64_trnp_a_t = typename gemm_op_128x64_trnp_a_t::matAcc_t;
1797 using matAcc_256x64_trnp_a_t = typename gemm_op_256x64_trnp_a_t::matAcc_t;
1798 using matAcc_128x64_trnp_af_t = typename gemm_op_128x64_trnp_af_t::matAcc_t;
1799 using matAcc_256x64_trnp_af_t = typename gemm_op_256x64_trnp_af_t::matAcc_t;
1800
1802 = subgroup::tile_desc_t<matAcc_128x128_t::tile_desc::tile_size_x,
1803 matAcc_128x128_t::tile_desc::tile_size_y,
1804 matAcc_128x128_t::tile_desc::block_size_x,
1805 matAcc_128x128_t::tile_desc::block_size_y,
1808 = subgroup::tile_desc_t<matAcc_128x256_t::tile_desc::tile_size_x,
1809 matAcc_128x256_t::tile_desc::tile_size_y,
1810 matAcc_128x256_t::tile_desc::block_size_x,
1811 matAcc_128x256_t::tile_desc::block_size_y,
1814 = subgroup::tile_desc_t<matAcc_64x384_t::tile_desc::tile_size_x,
1815 matAcc_64x384_t::tile_desc::tile_size_y,
1816 matAcc_64x384_t::tile_desc::block_size_x,
1817 matAcc_64x384_t::tile_desc::block_size_y,
1820 = subgroup::tile_desc_t<matAcc_64x512_t::tile_desc::tile_size_x,
1821 matAcc_64x512_t::tile_desc::tile_size_y,
1822 matAcc_64x512_t::tile_desc::block_size_x,
1823 matAcc_64x512_t::tile_desc::block_size_y,
1826 = subgroup::tile_desc_t<matAcc_32x1024_t::tile_desc::tile_size_x,
1827 matAcc_32x1024_t::tile_desc::tile_size_y,
1828 matAcc_32x1024_t::tile_desc::block_size_x,
1829 matAcc_32x1024_t::tile_desc::block_size_y,
1832 = subgroup::tile_desc_t<matAcc_16x2048_t::tile_desc::tile_size_x,
1833 matAcc_16x2048_t::tile_desc::tile_size_y,
1834 matAcc_16x2048_t::tile_desc::block_size_x,
1835 matAcc_16x2048_t::tile_desc::block_size_y,
1847
1851 (global_kslicing > 1)
1853 : subgroup::msg_type_v<matC_128x128_tile_desc_t,
1854 mem_space_c>,
1855 gpu_arch::Xe>;
1859 (global_kslicing > 1)
1861 : subgroup::msg_type_v<matC_128x256_tile_desc_t,
1862 mem_space_c>,
1863 gpu_arch::Xe>;
1868 : subgroup::msg_type_v<
1870 gpu_arch::Xe>;
1875 : subgroup::msg_type_v<
1877 gpu_arch::Xe>;
1881 (global_kslicing > 1)
1883 : subgroup::msg_type_v<matC_32x1024_tile_desc_t,
1884 mem_space_c>,
1885 gpu_arch::Xe>;
1889 (global_kslicing > 1)
1891 : subgroup::msg_type_v<matC_16x2048_tile_desc_t,
1892 mem_space_c>,
1893 gpu_arch::Xe>;
1894
1896 = subgroup::tile_desc_t<matAcc_128x64_t::tile_desc::tile_size_x,
1897 matAcc_128x64_t::tile_desc::tile_size_y,
1898 matAcc_128x64_t::tile_desc::block_size_x,
1899 matAcc_128x64_t::tile_desc::block_size_y,
1902 matAcc_128x64_trnp_a_t::tile_desc::tile_size_x,
1903 matAcc_128x64_trnp_a_t::tile_desc::tile_size_y,
1904 matAcc_128x64_trnp_a_t::tile_desc::block_size_x,
1905 matAcc_128x64_trnp_a_t::tile_desc::block_size_y, reg_layout::tiled>;
1907 matAcc_256x64_trnp_a_t::tile_desc::tile_size_x,
1908 matAcc_256x64_trnp_a_t::tile_desc::tile_size_y,
1909 matAcc_256x64_trnp_a_t::tile_desc::block_size_x,
1910 matAcc_256x64_trnp_a_t::tile_desc::block_size_y, reg_layout::tiled>;
1912 matAcc_128x64_trnp_af_t::tile_desc::tile_size_x,
1913 matAcc_128x64_trnp_af_t::tile_desc::tile_size_y,
1914 matAcc_128x64_trnp_af_t::tile_desc::block_size_x,
1915 matAcc_128x64_trnp_af_t::tile_desc::block_size_y,
1918 matAcc_256x64_trnp_af_t::tile_desc::tile_size_x,
1919 matAcc_256x64_trnp_af_t::tile_desc::tile_size_y,
1920 matAcc_256x64_trnp_af_t::tile_desc::block_size_x,
1921 matAcc_256x64_trnp_af_t::tile_desc::block_size_y,
1932
1937 : subgroup::msg_type_v<
1939 gpu_arch::Xe>;
1943 (global_kslicing > 1)
1945 : subgroup::msg_type_v<matC_128x64_trnp_a_tile_desc_t,
1946 mem_space_c>,
1947 gpu_arch::Xe>;
1951 subgroup::msg_type_v<matC_256x64_trnp_a_tile_desc_t, mem_space_c>,
1952 gpu_arch::Xe>;
1956 (global_kslicing > 1)
1958 : subgroup::msg_type_v<matC_128x64_trnp_af_tile_desc_t,
1959 mem_space_c>,
1960 gpu_arch::Xe>;
1964 (global_kslicing > 1)
1966 : subgroup::msg_type_v<matC_256x64_trnp_af_tile_desc_t,
1967 mem_space_c>,
1968 gpu_arch::Xe>;
1969
1980
1984 subgroup::msg_type_v<matC_128x128_tile_desc_t, mem_space_c>,
1985 gpu_arch::Xe>;
1989 subgroup::msg_type_v<matC_128x256_tile_desc_t, mem_space_c>,
1990 gpu_arch::Xe>;
1994 subgroup::msg_type_v<matC_64x384_tile_desc_t, mem_space_c>,
1995 gpu_arch::Xe>;
1999 subgroup::msg_type_v<matC_64x512_tile_desc_t, mem_space_c>,
2000 gpu_arch::Xe>;
2004 subgroup::msg_type_v<matC_32x1024_tile_desc_t, mem_space_c>,
2005 gpu_arch::Xe>;
2009 subgroup::msg_type_v<matC_16x2048_tile_desc_t, mem_space_c>,
2010 gpu_arch::Xe>;
2011
2012#if 0
2013 //512 = 16x32 or 8x64
2017 matElem_tile_desc,
2018 subgroup::msg_type_v<matElem_tile_desc, mem_space::global>>;
2020 matElem_tile_desc,
2022#endif
2023
2027 // Assumes the base address, surface width, height, pitch, and start coordinates have already been set.
2028 uint32_t *mList_ptr;
2032 uint32_t *matMkin_ptr;
2033 uint32_t *matMkdpot_ptr;
2041 float Pinv;
2042 float Scaling;
2043 };
2044
2048 __XETLA_API static void call(sycl::nd_item<3> &item, arguments_t *args) {
2049
2050 int tru_seqlen = 0;
2051 int tru_seqlen_ex = 0;
2052 int seqlen_entry = 0;
2053 int hiddensize = 1024;
2054 int numhead = 16;
2055 int hdsz = 64;
2056 int max_seqlen = Max_SeqLen;
2057 int wg_tile_QKT_k = hdsz; //args->matrix_k;
2058 int wg_tile_out_k;
2059
2060 int groupid = item.get_group(0);
2061 int batchid = groupid / numhead;
2062 int headid = groupid % numhead;
2063
2064 work_group_t g_thd32_tid;
2065 int tid_linear = item.get_local_linear_id();
2066 g_thd32_tid.init(tid_linear);
2067
2068 //float totalscaling = args->Pinv * args->Scaling;
2069
2070 uint32_t batch_offset = sizeof(uint32_t) * list_width * batchid;
2072 = xetla_vector_gen<uint32_t, list_width>(0, 1);
2073 list_offsets *= sizeof(uint32_t);
2074 list_offsets += batch_offset;
2078 list_width>(args->mList_ptr, list_offsets);
2079 tru_seqlen = list_vec[0];
2080 seqlen_entry = list_vec[1];
2081 wg_tile_out_k = tru_seqlen;
2082 tru_seqlen_ex = tru_seqlen; //4: dw aligned
2083 if (sfx_type_size == 2)
2084 tru_seqlen_ex = ((tru_seqlen + 1) >> 1) << 1;
2085 else if (sfx_type_size == 1)
2086 tru_seqlen_ex = ((tru_seqlen + 3) >> 2) << 2;
2087
2088 //reset for all std_seqlen
2089 int all_vert_loop_num = 0;
2090 int transp128_loop_num = 0;
2091 int transp256_loop_num = 0;
2092 int offset_blk_128x128 = 0;
2093 int all_vert_stride = 0;
2094 int all_vert128_shift = 0;
2095 int block_16x16_num = 0;
2096 int tid_x_shift = 0;
2097
2098 int std_seqlen;
2099 if (tru_seqlen <= 128) {
2100 std_seqlen = 128;
2101 all_vert_loop_num = 1;
2102 transp128_loop_num = 1;
2103 tid_x_shift = 2; // 16x32 128/32 = 4
2104 all_vert_loop_num = 1;
2105 all_vert_stride = 128;
2106 block_16x16_num = 2;
2107 } else if (tru_seqlen <= 256) {
2108 std_seqlen = 256;
2109 transp256_loop_num = 1;
2110 all_vert_loop_num = 2;
2111 all_vert_stride = 128;
2112 all_vert128_shift = 0;
2113 block_16x16_num = 4;
2114 tid_x_shift = 2; // 16x64 256/64 = 4
2115 } else if (tru_seqlen <= 384) {
2116 std_seqlen = 384;
2117 transp128_loop_num = 1;
2118 transp256_loop_num = 1;
2119 offset_blk_128x128 = 256;
2120 all_vert_stride = 64;
2121 all_vert128_shift = 1;
2122 block_16x16_num = 3;
2123 tid_x_shift = 3; // 16x48 384/48 = 8
2124 all_vert_loop_num = (tru_seqlen + all_vert_stride - 1) >> 6;
2125 } else if (tru_seqlen <= 512) {
2126 std_seqlen = 512;
2127 transp256_loop_num = 2;
2128 all_vert_stride = 64;
2129 all_vert128_shift = 1;
2130 block_16x16_num = 4;
2131 tid_x_shift = 3; // 16x64 512/64 = 8
2132 all_vert_loop_num = (tru_seqlen + all_vert_stride - 1) >> 6;
2133 } else if (tru_seqlen <= 1024) {
2134 std_seqlen = 1024;
2135 transp256_loop_num = 4;
2136 all_vert_stride = 32;
2137 all_vert128_shift = 2;
2138 block_16x16_num = 4;
2139 tid_x_shift = 4; // 16x64 1024/64 = 16
2140 all_vert_loop_num = (tru_seqlen + all_vert_stride - 1) >> 5;
2141 } else if (tru_seqlen <= 2048) {
2142 std_seqlen = 2048;
2143 transp256_loop_num = 8;
2144 all_vert_stride = 16;
2145 all_vert128_shift = 3;
2146 block_16x16_num = 4;
2147 tid_x_shift = 5; // 16x64 2048/64 = 32
2148 all_vert_loop_num = (tru_seqlen + all_vert_stride - 1) >> 4;
2149 }
2150 all_vert_loop_num = ((all_vert_loop_num + (1 << all_vert128_shift) - 1)
2151 >> all_vert128_shift)
2152 << all_vert128_shift;
2153 int tid_x = tid_linear & ((1 << tid_x_shift) - 1);
2154 int tid_y = tid_linear >> tid_x_shift;
2155
2156 static_assert(ThreadNum == 32, "All Thread Sync");
2159
2160 int max_2d_nbar_id = ThreadNum >> 1;
2161 first_nbarr.init_nbarrier(
2162 max_2d_nbar_id, nbarrier_role::producer_consumer);
2163 second_nbarr.init_nbarrier(
2164 max_2d_nbar_id + 1, nbarrier_role::producer_consumer);
2165
2167 all_nbarr.init_nbarrier(
2169
2170 for (int transp128_loop = 0; transp128_loop < transp128_loop_num;
2171 transp128_loop++) {
2172 gemm_arguments_128x64_trnp_af gemm_arg_128x64;
2173 matAcc_128x64_trnp_af_t matAcc_128x64;
2174 matC_128x64_trnp_af_t matC_128x64;
2175 matC_128x64_trnp_af_payload_t matC_128x64_payload;
2176
2177 uint32_t width_a = tru_seqlen_ex;
2178 uint32_t height_a
2179 = (batchid * numhead + headid) * max_seqlen + tru_seqlen;
2180 uint32_t pitch_a = max_seqlen;
2181 int start_x_a = transp128_loop * 128 + offset_blk_128x128;
2182 int start_y_a = (batchid * numhead + headid) * max_seqlen;
2183
2184 gemm_arg_128x64.matA_base_desc.init({args->matW_ptr},
2185 {height_a, width_a, pitch_a}, {start_y_a, start_x_a});
2186
2187 uint32_t width_b = (headid + 1) * hdsz;
2188 uint32_t height_b = tru_seqlen + seqlen_entry;
2189 uint32_t pitch_b = hiddensize;
2190 int start_x_b = headid * hdsz;
2191 int start_y_b = seqlen_entry;
2192
2193 gemm_arg_128x64.matB_base_desc.init({args->matdO_ptr},
2194 {width_b, height_b, pitch_b}, {start_x_b, start_y_b});
2195
2196 gemm_arg_128x64.inner_loop_count
2197 = (wg_tile_out_k + k_stride - 1) / k_stride;
2198
2199 matAcc_128x64.init(0);
2200 gemm_op_128x64_trnp_af_t gemm_op_128x64_trnp_af;
2201 gemm_op_128x64_trnp_af(g_thd32_tid, matAcc_128x64, gemm_arg_128x64);
2202
2203 int width_c = (headid + 1) * hdsz;
2204 int height_c = tru_seqlen + seqlen_entry;
2205 int pitch_c = hiddensize;
2206 int start_x_c = headid * hdsz
2207 + gemm_op_128x64_trnp_af_t::get_matC_offset_x(g_thd32_tid);
2208 int start_y_c = transp128_loop * 128 + seqlen_entry
2209 + offset_blk_128x128
2210 + gemm_op_128x64_trnp_af_t::get_matC_offset_y(g_thd32_tid);
2211
2212 matC_128x64_payload.init(args->matdV_ptr, width_c, height_c,
2213 pitch_c, start_x_c, start_y_c);
2215 matAcc_128x64_trnp_af_t>(matC_128x64, matAcc_128x64);
2216 subgroup::tile_store(matC_128x64, matC_128x64_payload);
2217
2218 //add global sync if nbarr used inside gemm
2219 all_nbarr.arrive();
2220 all_nbarr.wait();
2221 }
2222
2223 for (int transp256_loop = 0; transp256_loop < transp256_loop_num;
2224 transp256_loop++) {
2225 gemm_arguments_256x64_trnp_af gemm_arg_256x64;
2226 matAcc_256x64_trnp_af_t matAcc_256x64;
2227 matC_256x64_trnp_af_t matC_256x64;
2228 matC_256x64_trnp_af_payload_t matC_256x64_payload;
2229
2230 uint32_t width_a = tru_seqlen_ex;
2231 uint32_t height_a
2232 = (batchid * numhead + headid) * max_seqlen + tru_seqlen;
2233 uint32_t pitch_a = max_seqlen;
2234 int start_x_a = transp256_loop * 256;
2235 int start_y_a = (batchid * numhead + headid) * max_seqlen;
2236
2237 gemm_arg_256x64.matA_base_desc.init({args->matW_ptr},
2238 {height_a, width_a, pitch_a}, {start_y_a, start_x_a});
2239
2240 uint32_t width_b = (headid + 1) * hdsz;
2241 uint32_t height_b = tru_seqlen + seqlen_entry;
2242 uint32_t pitch_b = hiddensize;
2243 int start_x_b = headid * hdsz;
2244 int start_y_b = seqlen_entry;
2245
2246 gemm_arg_256x64.matB_base_desc.init({args->matdO_ptr},
2247 {width_b, height_b, pitch_b}, {start_x_b, start_y_b});
2248
2249 gemm_arg_256x64.inner_loop_count
2250 = (wg_tile_out_k + k_stride - 1) / k_stride;
2251
2252 matAcc_256x64.init(0);
2253 gemm_op_256x64_trnp_af_t gemm_op_256x64_trnp_af;
2254 gemm_op_256x64_trnp_af(g_thd32_tid, matAcc_256x64, gemm_arg_256x64);
2255
2256 int width_c = (headid + 1) * hdsz;
2257 int height_c = tru_seqlen + seqlen_entry;
2258 int pitch_c = hiddensize;
2259 int start_x_c = headid * hdsz
2260 + gemm_op_256x64_trnp_af_t::get_matC_offset_x(g_thd32_tid);
2261 int start_y_c = transp256_loop * 256 + seqlen_entry
2262 + gemm_op_256x64_trnp_af_t::get_matC_offset_y(g_thd32_tid);
2263
2264 matC_256x64_payload.init(args->matdV_ptr, width_c, height_c,
2265 pitch_c, start_x_c, start_y_c);
2267 matAcc_256x64_trnp_af_t>(matC_256x64, matAcc_256x64);
2268 subgroup::tile_store(matC_256x64, matC_256x64_payload);
2269
2270 //add global sync if nbarr used inside gemm
2271 all_nbarr.arrive();
2272 all_nbarr.wait();
2273 }
2274
2275 int valid_block_16x16_x = (tid_x + 1) * 16 * block_16x16_num;
2276 {
2277 int bndy_block_num = 0;
2278 if (valid_block_16x16_x <= tru_seqlen)
2279 valid_block_16x16_x = block_16x16_num;
2280 else {
2281 bndy_block_num = valid_block_16x16_x;
2282 valid_block_16x16_x = (tru_seqlen + 15 + 16 * block_16x16_num
2283 - valid_block_16x16_x)
2284 >> 4;
2285 bndy_block_num = bndy_block_num
2286 + (valid_block_16x16_x - block_16x16_num) * 16
2287 - tru_seqlen;
2288 }
2289 }
2290
2291 for (int all_vert_loop = 0; all_vert_loop < all_vert_loop_num;
2292 all_vert_loop++) {
2293 xetla_vector<float, 4 * 16 * 16> matElem_reg_4x16x16;
2294 xetla_vector<float, 4 * 16 * 16> matW_reg_4x16x16;
2295 xetla_vector<uint8_t, 4 * 16 * 16> Sign_reg_4x16x16 = 0;
2296 bool valid_compute = true;
2297
2298 int ld_st_width_c = max_seqlen;
2299 int ld_st_height_c = max_seqlen * (batchid * numhead + headid + 1);
2300 int ld_st_pitch_c = max_seqlen;
2301 int ld_st_start_x_c = 0;
2302 int ld_st_start_y_c = 0;
2303
2304 if (((all_vert_loop * all_vert_stride + tid_y * 16) >= tru_seqlen)
2305 || ((tid_x * 16 * block_16x16_num) >= tru_seqlen))
2306 valid_compute = false;
2307
2308 if (valid_compute) {
2309
2310 switch (std_seqlen) {
2311 case 128: {
2312 gemm_arguments_128x128 gemm_arg_128x128;
2313 matAcc_128x128_t matAcc_128x128;
2314
2315 matW_128x128_t matW_128x128;
2316 matW_128x128_payload_t matW_128x128_payload;
2317
2318 ld_st_start_x_c = gemm_op_128x128_t::get_matC_offset_x(
2319 g_thd32_tid);
2320 ld_st_start_y_c
2321 = (batchid * numhead + headid) * max_seqlen
2322 + all_vert_loop * all_vert_stride
2323 + gemm_op_128x128_t::get_matC_offset_y(
2324 g_thd32_tid);
2325
2326 matW_128x128_payload.init(args->matW_ptr, ld_st_width_c,
2327 ld_st_height_c, ld_st_pitch_c, ld_st_start_x_c,
2328 ld_st_start_y_c);
2329 subgroup::tile_load(matW_128x128, matW_128x128_payload);
2330
2331 uint32_t width_a = (headid + 1) * hdsz;
2332 uint32_t height_a = tru_seqlen + seqlen_entry;
2333 uint32_t pitch_a = hiddensize;
2334 int start_x_a = headid * hdsz;
2335 int start_y_a = all_vert_loop * all_vert_stride
2336 + seqlen_entry;
2337
2338 gemm_arg_128x128.matA_base_desc.init({args->matdO_ptr},
2339 {width_a, height_a, pitch_a},
2340 {start_x_a, start_y_a});
2341
2342 uint32_t width_b = (headid + 1) * hdsz;
2343 uint32_t height_b = tru_seqlen + seqlen_entry;
2344 uint32_t pitch_b = hiddensize;
2345 int start_x_b = headid * hdsz;
2346 int start_y_b = seqlen_entry;
2347
2348 // B is transposed, so width/height and start x/y are passed swapped to init
2349 gemm_arg_128x128.matB_base_desc.init({args->matV_ptr},
2350 {height_b, width_b, pitch_b},
2351 {start_y_b, start_x_b});
2352
2353 gemm_arg_128x128.inner_loop_count
2354 = (wg_tile_QKT_k + k_stride - 1) / k_stride;
2355
2356 matAcc_128x128.init(0);
2357 gemm_op_128x128_t gemm_op_128x128;
2358 gemm_op_128x128(
2359 g_thd32_tid, matAcc_128x128, gemm_arg_128x128);
2360
2361 matElem_reg_4x16x16.xetla_format<float>()
2362 .xetla_select<16 * 32, 1>(0)
2363 = matAcc_128x128.reg;
2364
2365 if constexpr (sfx_type_size == 1) {
2366 Sign_reg_4x16x16.xetla_select<16 * 16 * 2, 1>(0)
2367 .merge(1,
2368 matW_128x128.reg.xetla_format<
2369 int8_t>()
2370 < 0);
2371 matW_128x128.reg.xetla_format<uint8_t>() &= 0x7F;
2372 }
2373 if constexpr (sfx_type_size == 2) {
2374 Sign_reg_4x16x16.xetla_select<16 * 16 * 2, 1>(0)
2375 .merge(1,
2376 matW_128x128.reg.xetla_format<
2377 int16_t>()
2378 < 0);
2379 matW_128x128.reg.xetla_format<uint16_t>() &= 0x7FFF;
2380 }
2381
2382 xetla_vector<float, 2 * 16 * 16> matElem_reg_conv;
2383 matElem_reg_conv
2384 = xetla_cvt<float, dtype_sfx>(matW_128x128.reg);
2385 matW_reg_4x16x16.xetla_select<16 * 16 * 2, 1>(0)
2386 = matElem_reg_conv;
2387 } break;
2388
2389 case 256: {
2390 gemm_arguments_128x256 gemm_arg_128x256;
2391 matAcc_128x256_t matAcc_128x256;
2392
2393 matW_128x256_t matW_128x256;
2394 matW_128x256_payload_t matW_128x256_payload;
2395
2396 ld_st_start_x_c = gemm_op_128x256_t::get_matC_offset_x(
2397 g_thd32_tid);
2398 ld_st_start_y_c
2399 = (batchid * numhead + headid) * max_seqlen
2400 + all_vert_loop * all_vert_stride
2401 + gemm_op_128x256_t::get_matC_offset_y(
2402 g_thd32_tid);
2403
2404 matW_128x256_payload.init(args->matW_ptr, ld_st_width_c,
2405 ld_st_height_c, ld_st_pitch_c, ld_st_start_x_c,
2406 ld_st_start_y_c);
2407 subgroup::tile_load(matW_128x256, matW_128x256_payload);
2408
2409 uint32_t width_a = (headid + 1) * hdsz;
2410 uint32_t height_a = tru_seqlen + seqlen_entry;
2411 uint32_t pitch_a = hiddensize;
2412 int start_x_a = headid * hdsz;
2413 int start_y_a = all_vert_loop * all_vert_stride
2414 + seqlen_entry;
2415
2416 gemm_arg_128x256.matA_base_desc.init({args->matdO_ptr},
2417 {width_a, height_a, pitch_a},
2418 {start_x_a, start_y_a});
2419
2420 uint32_t width_b = (headid + 1) * hdsz;
2421 uint32_t height_b = tru_seqlen + seqlen_entry;
2422 uint32_t pitch_b = hiddensize;
2423 int start_x_b = headid * hdsz;
2424 int start_y_b = seqlen_entry;
2425
2426 // B is transposed, so width/height and start x/y are passed swapped to init
2427 gemm_arg_128x256.matB_base_desc.init({args->matV_ptr},
2428 {height_b, width_b, pitch_b},
2429 {start_y_b, start_x_b});
2430
2431 gemm_arg_128x256.inner_loop_count
2432 = (wg_tile_QKT_k + k_stride - 1) / k_stride;
2433
2434 matAcc_128x256.init(0);
2435 gemm_op_128x256_t gemm_op_128x256;
2436 gemm_op_128x256(
2437 g_thd32_tid, matAcc_128x256, gemm_arg_128x256);
2438
2439 matElem_reg_4x16x16.xetla_format<float>()
2440 .xetla_select<16 * 16 * 4, 1>(0)
2441 = matAcc_128x256.reg;
2442
2443 if constexpr (sfx_type_size == 1) {
2444 Sign_reg_4x16x16.xetla_select<16 * 16 * 4, 1>(0)
2445 .merge(1,
2446 matW_128x256.reg.xetla_format<
2447 int8_t>()
2448 < 0);
2449 matW_128x256.reg.xetla_format<uint8_t>() &= 0x7F;
2450 }
2451 if constexpr (sfx_type_size == 2) {
2452 Sign_reg_4x16x16.xetla_select<16 * 16 * 4, 1>(0)
2453 .merge(1,
2454 matW_128x256.reg.xetla_format<
2455 int16_t>()
2456 < 0);
2457 matW_128x256.reg.xetla_format<uint16_t>() &= 0x7FFF;
2458 }
2459
2460 matW_reg_4x16x16.xetla_select<16 * 16 * 4, 1>(0)
2461 = xetla_cvt<float, dtype_sfx>(matW_128x256.reg);
2462 } break;
2463
2464 case 384: {
2465 gemm_arguments_64x384 gemm_arg_64x384;
2466 matAcc_64x384_t matAcc_64x384;
2467
2468 matW_64x384_t matW_64x384;
2469 matW_64x384_payload_t matW_64x384_payload;
2470
2471 ld_st_start_x_c = gemm_op_64x384_t::get_matC_offset_x(
2472 g_thd32_tid);
2473 ld_st_start_y_c
2474 = (batchid * numhead + headid) * max_seqlen
2475 + all_vert_loop * all_vert_stride
2476 + gemm_op_64x384_t::get_matC_offset_y(
2477 g_thd32_tid);
2478 matW_64x384_payload.init(args->matW_ptr, ld_st_width_c,
2479 ld_st_height_c, ld_st_pitch_c, ld_st_start_x_c,
2480 ld_st_start_y_c);
2481 subgroup::tile_load(matW_64x384, matW_64x384_payload);
2482
2483 uint32_t width_a = (headid + 1) * hdsz;
2484 uint32_t height_a = tru_seqlen + seqlen_entry;
2485 uint32_t pitch_a = hiddensize;
2486 int start_x_a = headid * hdsz;
2487 int start_y_a = all_vert_loop * all_vert_stride
2488 + seqlen_entry;
2489
2490 gemm_arg_64x384.matA_base_desc.init({args->matdO_ptr},
2491 {width_a, height_a, pitch_a},
2492 {start_x_a, start_y_a});
2493
2494 uint32_t width_b = (headid + 1) * hdsz;
2495 uint32_t height_b = tru_seqlen + seqlen_entry;
2496 uint32_t pitch_b = hiddensize;
2497 int start_x_b = headid * hdsz;
2498 int start_y_b = seqlen_entry;
2499
2500                     //B is transposed: dims/offsets are swapped in init
2501 gemm_arg_64x384.matB_base_desc.init({args->matV_ptr},
2502 {height_b, width_b, pitch_b},
2503 {start_y_b, start_x_b});
2504
2505 gemm_arg_64x384.inner_loop_count
2506 = (wg_tile_QKT_k + k_stride - 1) / k_stride;
2507
2508 matAcc_64x384.init(0);
2509 gemm_op_64x384_t gemm_op_64x384;
2510 gemm_op_64x384(
2511 g_thd32_tid, matAcc_64x384, gemm_arg_64x384);
2512
2513 matElem_reg_4x16x16.xetla_format<float>()
2514 .xetla_select<16 * 16 * 3, 1>(0)
2515 = matAcc_64x384.reg;
2516
2517 if constexpr (sfx_type_size == 1) {
2518 Sign_reg_4x16x16.xetla_select<16 * 16 * 3, 1>(0)
2519 .merge(1,
2520 matW_64x384.reg.xetla_format<
2521 int8_t>()
2522 < 0);
2523 matW_64x384.reg.xetla_format<uint8_t>() &= 0x7F;
2524 }
2525 if constexpr (sfx_type_size == 2) {
2526 Sign_reg_4x16x16.xetla_select<16 * 16 * 3, 1>(0)
2527 .merge(1,
2528 matW_64x384.reg.xetla_format<
2529 int16_t>()
2530 < 0);
2531 matW_64x384.reg.xetla_format<uint16_t>() &= 0x7FFF;
2532 }
2533
2534 xetla_vector<float, 3 * 16 * 16> matElem_reg_conv;
2535 matElem_reg_conv
2536 = xetla_cvt<float, dtype_sfx>(matW_64x384.reg);
2537 matW_reg_4x16x16.xetla_select<16 * 16 * 3, 1>(0)
2538 = matElem_reg_conv;
2539 } break;
2540
2541 case 512: {
2542 gemm_arguments_64x512 gemm_arg_64x512;
2543 matAcc_64x512_t matAcc_64x512;
2544
2545 matW_64x512_t matW_64x512;
2546 matW_64x512_payload_t matW_64x512_payload;
2547
2548 ld_st_start_x_c = gemm_op_64x512_t::get_matC_offset_x(
2549 g_thd32_tid);
2550 ld_st_start_y_c
2551 = (batchid * numhead + headid) * max_seqlen
2552 + all_vert_loop * all_vert_stride
2553 + gemm_op_64x512_t::get_matC_offset_y(
2554 g_thd32_tid);
2555 matW_64x512_payload.init(args->matW_ptr, ld_st_width_c,
2556 ld_st_height_c, ld_st_pitch_c, ld_st_start_x_c,
2557 ld_st_start_y_c);
2558 subgroup::tile_load(matW_64x512, matW_64x512_payload);
2559
2560 uint32_t width_a = (headid + 1) * hdsz;
2561 uint32_t height_a = tru_seqlen + seqlen_entry;
2562 uint32_t pitch_a = hiddensize;
2563 int start_x_a = headid * hdsz;
2564 int start_y_a = all_vert_loop * all_vert_stride
2565 + seqlen_entry;
2566
2567 gemm_arg_64x512.matA_base_desc.init({args->matdO_ptr},
2568 {width_a, height_a, pitch_a},
2569 {start_x_a, start_y_a});
2570
2571 uint32_t width_b = (headid + 1) * hdsz;
2572 uint32_t height_b = tru_seqlen + seqlen_entry;
2573 uint32_t pitch_b = hiddensize;
2574 int start_x_b = headid * hdsz;
2575 int start_y_b = seqlen_entry;
2576
2577                     //B is transposed: dims/offsets are swapped in init
2578 gemm_arg_64x512.matB_base_desc.init({args->matV_ptr},
2579 {height_b, width_b, pitch_b},
2580 {start_y_b, start_x_b});
2581
2582 gemm_arg_64x512.inner_loop_count
2583 = (wg_tile_QKT_k + k_stride - 1) / k_stride;
2584
2585 matAcc_64x512.init(0);
2586 gemm_op_64x512_t gemm_op_64x512;
2587 gemm_op_64x512(
2588 g_thd32_tid, matAcc_64x512, gemm_arg_64x512);
2589
2590 matElem_reg_4x16x16.xetla_format<float>()
2591 .xetla_select<16 * 16 * 4, 1>(0)
2592 = matAcc_64x512.reg;
2593
2594 if constexpr (sfx_type_size == 1) {
2595 Sign_reg_4x16x16.xetla_select<16 * 16 * 4, 1>(0)
2596 .merge(1,
2597 matW_64x512.reg.xetla_format<
2598 int8_t>()
2599 < 0);
2600 matW_64x512.reg.xetla_format<uint8_t>() &= 0x7F;
2601 }
2602 if constexpr (sfx_type_size == 2) {
2603 Sign_reg_4x16x16.xetla_select<16 * 16 * 4, 1>(0)
2604 .merge(1,
2605 matW_64x512.reg.xetla_format<
2606 int16_t>()
2607 < 0);
2608 matW_64x512.reg.xetla_format<uint16_t>() &= 0x7FFF;
2609 }
2610
2611 matW_reg_4x16x16.xetla_select<16 * 16 * 4, 1>(0)
2612 = xetla_cvt<float, dtype_sfx>(matW_64x512.reg);
2613 } break;
2614
2615 case 1024: {
2616 gemm_arguments_32x1024 gemm_arg_32x1024;
2617 matAcc_32x1024_t matAcc_32x1024;
2618
2619 matW_32x1024_t matW_32x1024;
2620 matW_32x1024_payload_t matW_32x1024_payload;
2621
2622 ld_st_start_x_c = gemm_op_32x1024_t::get_matC_offset_x(
2623 g_thd32_tid);
2624 ld_st_start_y_c
2625 = (batchid * numhead + headid) * max_seqlen
2626 + all_vert_loop * all_vert_stride
2627 + gemm_op_32x1024_t::get_matC_offset_y(
2628 g_thd32_tid);
2629 matW_32x1024_payload.init(args->matW_ptr, ld_st_width_c,
2630 ld_st_height_c, ld_st_pitch_c, ld_st_start_x_c,
2631 ld_st_start_y_c);
2632 subgroup::tile_load(matW_32x1024, matW_32x1024_payload);
2633
2634 uint32_t width_a = (headid + 1) * hdsz;
2635 uint32_t height_a = tru_seqlen + seqlen_entry;
2636 uint32_t pitch_a = hiddensize;
2637 int start_x_a = headid * hdsz;
2638 int start_y_a = all_vert_loop * all_vert_stride
2639 + seqlen_entry;
2640
2641 gemm_arg_32x1024.matA_base_desc.init({args->matdO_ptr},
2642 {width_a, height_a, pitch_a},
2643 {start_x_a, start_y_a});
2644
2645 uint32_t width_b = (headid + 1) * hdsz;
2646 uint32_t height_b = tru_seqlen + seqlen_entry;
2647 uint32_t pitch_b = hiddensize;
2648 int start_x_b = headid * hdsz;
2649 int start_y_b = seqlen_entry;
2650
2651                     //B is transposed: dims/offsets are swapped in init
2652 gemm_arg_32x1024.matB_base_desc.init({args->matV_ptr},
2653 {height_b, width_b, pitch_b},
2654 {start_y_b, start_x_b});
2655
2656 gemm_arg_32x1024.inner_loop_count
2657 = (wg_tile_QKT_k + k_stride - 1) / k_stride;
2658
2659 matAcc_32x1024.init(0);
2660 gemm_op_32x1024_t gemm_op_32x1024;
2661 gemm_op_32x1024(
2662 g_thd32_tid, matAcc_32x1024, gemm_arg_32x1024);
2663
2664 matElem_reg_4x16x16.xetla_format<float>()
2665 .xetla_select<16 * 16 * 4, 1>(0)
2666 = matAcc_32x1024.reg;
2667
2668 if constexpr (sfx_type_size == 1) {
2669 Sign_reg_4x16x16.xetla_select<16 * 16 * 4, 1>(0)
2670 .merge(1,
2671 matW_32x1024.reg.xetla_format<
2672 int8_t>()
2673 < 0);
2674 matW_32x1024.reg.xetla_format<uint8_t>() &= 0x7F;
2675 }
2676 if constexpr (sfx_type_size == 2) {
2677 Sign_reg_4x16x16.xetla_select<16 * 16 * 4, 1>(0)
2678 .merge(1,
2679 matW_32x1024.reg.xetla_format<
2680 int16_t>()
2681 < 0);
2682 matW_32x1024.reg.xetla_format<uint16_t>() &= 0x7FFF;
2683 }
2684
2685 matW_reg_4x16x16.xetla_select<16 * 16 * 4, 1>(0)
2686 = xetla_cvt<float, dtype_sfx>(matW_32x1024.reg);
2687 } break;
2688
2689 case 2048: {
2690 gemm_arguments_16x2048 gemm_arg_16x2048;
2691 matAcc_16x2048_t matAcc_16x2048;
2692
2693 matW_16x2048_t matW_16x2048;
2694 matW_16x2048_payload_t matW_16x2048_payload;
2695
2696 ld_st_start_x_c = gemm_op_16x2048_t::get_matC_offset_x(
2697 g_thd32_tid);
2698 ld_st_start_y_c
2699 = (batchid * numhead + headid) * max_seqlen
2700 + all_vert_loop * all_vert_stride
2701 + gemm_op_16x2048_t::get_matC_offset_y(
2702 g_thd32_tid);
2703 matW_16x2048_payload.init(args->matW_ptr, ld_st_width_c,
2704 ld_st_height_c, ld_st_pitch_c, ld_st_start_x_c,
2705 ld_st_start_y_c);
2706 subgroup::tile_load(matW_16x2048, matW_16x2048_payload);
2707
2708 uint32_t width_a = (headid + 1) * hdsz;
2709 uint32_t height_a = tru_seqlen + seqlen_entry;
2710 uint32_t pitch_a = hiddensize;
2711 int start_x_a = headid * hdsz;
2712 int start_y_a = all_vert_loop * all_vert_stride
2713 + seqlen_entry;
2714
2715 gemm_arg_16x2048.matA_base_desc.init({args->matdO_ptr},
2716 {width_a, height_a, pitch_a},
2717 {start_x_a, start_y_a});
2718
2719 uint32_t width_b = (headid + 1) * hdsz;
2720 uint32_t height_b = tru_seqlen + seqlen_entry;
2721 uint32_t pitch_b = hiddensize;
2722 int start_x_b = headid * hdsz;
2723 int start_y_b = seqlen_entry;
2724
2725                     //B is transposed: dims/offsets are swapped in init
2726 gemm_arg_16x2048.matB_base_desc.init({args->matV_ptr},
2727 {height_b, width_b, pitch_b},
2728 {start_y_b, start_x_b});
2729
2730 gemm_arg_16x2048.inner_loop_count
2731 = (wg_tile_QKT_k + k_stride - 1) / k_stride;
2732
2733 matAcc_16x2048.init(0);
2734 gemm_op_16x2048_t gemm_op_16x2048;
2735 gemm_op_16x2048(
2736 g_thd32_tid, matAcc_16x2048, gemm_arg_16x2048);
2737
2738 matElem_reg_4x16x16.xetla_format<float>()
2739 .xetla_select<16 * 16 * 4, 1>(0)
2740 = matAcc_16x2048.reg;
2741
2742 if constexpr (sfx_type_size == 1) {
2743 Sign_reg_4x16x16.xetla_select<16 * 16 * 4, 1>(0)
2744 .merge(1,
2745 matW_16x2048.reg.xetla_format<
2746 int8_t>()
2747 < 0);
2748 matW_16x2048.reg.xetla_format<uint8_t>() &= 0x7F;
2749 }
2750 if constexpr (sfx_type_size == 2) {
2751 Sign_reg_4x16x16.xetla_select<16 * 16 * 4, 1>(0)
2752 .merge(1,
2753 matW_16x2048.reg.xetla_format<
2754 int16_t>()
2755 < 0);
2756 matW_16x2048.reg.xetla_format<uint16_t>() &= 0x7FFF;
2757 }
2758
2759 matW_reg_4x16x16.xetla_select<16 * 16 * 4, 1>(0)
2760 = xetla_cvt<float, dtype_sfx>(matW_16x2048.reg);
2761 } break;
2762 } //switch
2763
2764 //softmax
2765 {
2766
2767 xetla_vector<float, 16 * 1> matElem_reg_Sum_1;
2768 xetla_vector<float, 16 * 16> matElem_reg_Sum;
2769 xetla_vector<float, 16 * 8> matElem_reg_Sum_8;
2770 xetla_vector<float, 16 * 4> matElem_reg_Sum_4;
2771 xetla_vector<float, 16 * 2> matElem_reg_Sum_2;
2772
2773 matElem_reg_4x16x16.xetla_format<float>()
2774 .xetla_select<16 * 16 * 2, 1>(0)
2775 *= matW_reg_4x16x16.xetla_select<16 * 16 * 2, 1>(0);
2776
2777 matElem_reg_4x16x16.xetla_format<float>()
2778 .xetla_select<16 * 16 * 2, 1>(0)
2779 .merge(0.0,
2780 Sign_reg_4x16x16.xetla_select<16 * 16 * 2,
2781 1>(0)
2782 > 0);
2783
2784 matElem_reg_Sum = matElem_reg_4x16x16.xetla_format<float>()
2785 .xetla_select<16 * 16, 1>(0)
2786 + matElem_reg_4x16x16.xetla_format<float>()
2787 .xetla_select<16 * 16, 1>(16 * 16);
2788
2789 if (valid_block_16x16_x > 2) {
2790
2791 matElem_reg_4x16x16.xetla_format<float>()
2792 .xetla_select<16 * 16, 1>(16 * 16 * 2)
2793 *= matW_reg_4x16x16.xetla_select<16 * 16, 1>(
2794 16 * 16 * 2);
2795
2796 matElem_reg_4x16x16.xetla_format<float>()
2797 .xetla_select<16 * 16, 1>(16 * 16 * 2)
2798 .merge(0.0,
2799 Sign_reg_4x16x16.xetla_select<16 * 16,
2800 1>(16 * 16 * 2)
2801 > 0);
2802
2803 matElem_reg_Sum = matElem_reg_Sum
2804 + matElem_reg_4x16x16.xetla_format<float>()
2805 .xetla_select<16 * 16, 1>(
2806 16 * 16 * 2);
2807
2808 if (valid_block_16x16_x > 3) {
2809 matElem_reg_4x16x16.xetla_format<float>()
2810 .xetla_select<16 * 16, 1>(16 * 16 * 3)
2811 *= matW_reg_4x16x16
2812 .xetla_select<16 * 16, 1>(
2813 16 * 16 * 3);
2814
2815 matElem_reg_4x16x16.xetla_format<float>()
2816 .xetla_select<16 * 16, 1>(16 * 16 * 3)
2817 .merge(0.0,
2818 Sign_reg_4x16x16.xetla_select<
2819 16 * 16, 1>(16 * 16 * 3)
2820 > 0);
2821
2822 matElem_reg_Sum = matElem_reg_Sum
2823 + matElem_reg_4x16x16.xetla_format<float>()
2824 .xetla_select<16 * 16, 1>(
2825 16 * 16 * 3);
2826 }
2827 }
2828
2829 matElem_reg_Sum_8.xetla_format<float, 16, 8>()
2830 = matElem_reg_Sum.xetla_format<float, 16, 16>()
2831 .xetla_select<16, 1, 8, 1>(0, 0)
2832 + matElem_reg_Sum.xetla_format<float, 16, 16>()
2833 .xetla_select<16, 1, 8, 1>(0, 8);
2834
2835 matElem_reg_Sum_4.xetla_format<float, 16, 4>()
2836 = matElem_reg_Sum_8.xetla_format<float, 16, 8>()
2837 .xetla_select<16, 1, 4, 1>(0, 0)
2838 + matElem_reg_Sum_8.xetla_format<float, 16, 8>()
2839 .xetla_select<16, 1, 4, 1>(0, 4);
2840
2841 matElem_reg_Sum_2.xetla_format<float, 16, 2>()
2842 = matElem_reg_Sum_4.xetla_format<float, 16, 4>()
2843 .xetla_select<16, 1, 2, 1>(0, 0)
2844 + matElem_reg_Sum_4.xetla_format<float, 16, 4>()
2845 .xetla_select<16, 1, 2, 1>(0, 2);
2846
2847 matElem_reg_Sum_1.xetla_format<float, 16, 1>()
2848 = matElem_reg_Sum_2.xetla_format<float, 16, 2>()
2849 .xetla_select<16, 1, 1, 1>(0, 0)
2850 + matElem_reg_Sum_2.xetla_format<float, 16, 2>()
2851 .xetla_select<16, 1, 1, 1>(0, 1);
2852
2853 xetla_vector<uint32_t, 16> address_fsum
2854 = xetla_vector_gen<uint32_t, 16>(0, 1);
2855 int address_offset
2856 = (batchid * numhead + headid) * Max_SeqLen
2857 + all_vert_stride * all_vert_loop + tid_y * 16;
2858 address_fsum += address_offset;
2859 address_fsum *= sizeof(float);
2860
2861 xetla_mask<16> pred = 1;
2864 (uint64_t)args->matSum_ptr, address_fsum,
2865 matElem_reg_Sum_1.xetla_select<16, 1>(0), pred);
2866
2867 first_nbarr.arrive();
2868 first_nbarr.wait();
2869
2870 matElem_reg_Sum_1 = xetla_load_global<float, 1,
2873 16>(args->matSum_ptr, address_fsum);
2874
2875 matElem_reg_Sum_1 *= args->Scaling;
2876
2877#pragma unroll
2878 for (int i = 0; i < 16; i++) {
2879 matW_reg_4x16x16.xetla_select<16 * 16, 1>(16 * 16 * 0)
2880 .xetla_select<16, 1>(i * 16)
2881 = matW_reg_4x16x16
2882 .xetla_select<16 * 16, 1>(16 * 16 * 0)
2883 .xetla_select<16, 1>(i * 16)
2884 * matElem_reg_Sum_1[i];
2885
2886 matElem_reg_4x16x16
2887 .xetla_select<16 * 16, 1>(16 * 16 * 0)
2888 .xetla_select<16, 1>(i * 16)
2889 = matElem_reg_4x16x16
2890 .xetla_select<16 * 16, 1>(16 * 16 * 0)
2891 .xetla_select<16, 1>(i * 16)
2892 - matW_reg_4x16x16
2893 .xetla_select<16 * 16, 1>(16 * 16 * 0)
2894 .xetla_select<16, 1>(i * 16);
2895
2896 matElem_reg_4x16x16
2897 .xetla_select<16 * 16, 1>(16 * 16 * 0)
2898 .xetla_select<16, 1>(i * 16)
2899 *= args->Pinv;
2900 }
2901
2902 if (valid_block_16x16_x > 1) {
2903#pragma unroll
2904 for (int i = 0; i < 16; i++) {
2905 matW_reg_4x16x16
2906 .xetla_select<16 * 16, 1>(16 * 16 * 1)
2907 .xetla_select<16, 1>(i * 16)
2908 = matW_reg_4x16x16
2909 .xetla_select<16 * 16, 1>(
2910 16 * 16 * 1)
2911 .xetla_select<16, 1>(i * 16)
2912 * matElem_reg_Sum_1[i];
2913
2914 matElem_reg_4x16x16
2915 .xetla_select<16 * 16, 1>(16 * 16 * 1)
2916 .xetla_select<16, 1>(i * 16)
2917 = matElem_reg_4x16x16
2918 .xetla_select<16 * 16, 1>(
2919 16 * 16 * 1)
2920 .xetla_select<16, 1>(i * 16)
2921 - matW_reg_4x16x16
2922 .xetla_select<16 * 16, 1>(
2923 16 * 16 * 1)
2924 .xetla_select<16, 1>(i * 16);
2925
2926 matElem_reg_4x16x16
2927 .xetla_select<16 * 16, 1>(16 * 16 * 1)
2928 .xetla_select<16, 1>(i * 16)
2929 *= args->Pinv;
2930 }
2931
2932 if (valid_block_16x16_x > 2) {
2933#pragma unroll
2934 for (int i = 0; i < 16; i++) {
2935 matW_reg_4x16x16
2936 .xetla_select<16 * 16, 1>(16 * 16 * 2)
2937 .xetla_select<16, 1>(i * 16)
2938 = matW_reg_4x16x16
2939 .xetla_select<16 * 16, 1>(
2940 16 * 16 * 2)
2941 .xetla_select<16, 1>(i * 16)
2942 * matElem_reg_Sum_1[i];
2943
2944 matElem_reg_4x16x16
2945 .xetla_select<16 * 16, 1>(16 * 16 * 2)
2946 .xetla_select<16, 1>(i * 16)
2947 = matElem_reg_4x16x16
2948 .xetla_select<16 * 16, 1>(
2949 16 * 16 * 2)
2950 .xetla_select<16, 1>(i * 16)
2951 - matW_reg_4x16x16
2952 .xetla_select<16 * 16, 1>(
2953 16 * 16 * 2)
2954 .xetla_select<16, 1>(i * 16);
2955
2956 matElem_reg_4x16x16
2957 .xetla_select<16 * 16, 1>(16 * 16 * 2)
2958 .xetla_select<16, 1>(i * 16)
2959 *= args->Pinv;
2960 }
2961
2962 if (valid_block_16x16_x > 3) {
2963#pragma unroll
2964 for (int i = 0; i < 16; i++) {
2965 matW_reg_4x16x16
2966 .xetla_select<16 * 16, 1>(
2967 16 * 16 * 3)
2968 .xetla_select<16, 1>(i * 16)
2969 = matW_reg_4x16x16
2970 .xetla_select<16 * 16, 1>(
2971 16 * 16 * 3)
2972 .xetla_select<16, 1>(
2973 i * 16)
2974 * matElem_reg_Sum_1[i];
2975
2976 matElem_reg_4x16x16
2977 .xetla_select<16 * 16, 1>(
2978 16 * 16 * 3)
2979 .xetla_select<16, 1>(i * 16)
2980 = matElem_reg_4x16x16
2981 .xetla_select<16 * 16, 1>(
2982 16 * 16 * 3)
2983 .xetla_select<16, 1>(
2984 i * 16)
2985 - matW_reg_4x16x16
2986 .xetla_select<16 * 16, 1>(
2987 16 * 16 * 3)
2988 .xetla_select<16, 1>(
2989 i * 16);
2990
2991 matElem_reg_4x16x16
2992 .xetla_select<16 * 16, 1>(
2993 16 * 16 * 3)
2994 .xetla_select<16, 1>(i * 16)
2995 *= args->Pinv;
2996 }
2997 }
2998 }
2999 }
3000 }
3001
3002 //store
3003 switch (std_seqlen) {
3004 case 128: {
3005 matC_128x128_t matC_128x128;
3006 matC_128x128_payload_t matC_128x128_payload;
3007
3008 matC_128x128_payload.init(args->matdW_ptr,
3009 ld_st_width_c, ld_st_height_c, ld_st_pitch_c,
3010 ld_st_start_x_c, ld_st_start_y_c);
3011
3012 xetla_vector<float, 16 * 32> matElem_reg_store
3013 = matElem_reg_4x16x16.xetla_format<float>()
3014 .xetla_select<16 * 32, 1>(0);
3015 matC_128x128.reg = xetla_cvt<dtype_sfx, float>(
3016 matElem_reg_store);
3017
3019 matC_128x128, matC_128x128_payload);
3020 xetla_fence<memory_kind::untyped_global>();
3021 } break;
3022
3023 case 256: {
3024 matC_128x256_t matC_128x256;
3025 matC_128x256_payload_t matC_128x256_payload;
3026
3027 matC_128x256_payload.init(args->matdW_ptr,
3028 ld_st_width_c, ld_st_height_c, ld_st_pitch_c,
3029 ld_st_start_x_c, ld_st_start_y_c);
3030
3031 matC_128x256.reg = xetla_cvt<dtype_sfx, float>(
3032 matElem_reg_4x16x16);
3033
3035 matC_128x256, matC_128x256_payload);
3036 xetla_fence<memory_kind::untyped_global>();
3037 } break;
3038
3039 case 384: {
3040 matC_64x384_t matC_64x384;
3041 matC_64x384_payload_t matC_64x384_payload;
3042
3043 matC_64x384_payload.init(args->matdW_ptr, ld_st_width_c,
3044 ld_st_height_c, ld_st_pitch_c, ld_st_start_x_c,
3045 ld_st_start_y_c);
3046
3047 xetla_vector<float, 16 * 16 * 3> matElem_reg_store
3048 = matElem_reg_4x16x16.xetla_format<float>()
3049 .xetla_select<16 * 16 * 3, 1>(0);
3050 matC_64x384.reg = xetla_cvt<dtype_sfx, float>(
3051 matElem_reg_store);
3052
3053 subgroup::tile_store(matC_64x384, matC_64x384_payload);
3054 xetla_fence<memory_kind::untyped_global>();
3055 } break;
3056
3057 case 512: {
3058 matC_64x512_t matC_64x512;
3059 matC_64x512_payload_t matC_64x512_payload;
3060
3061 matC_64x512_payload.init(args->matdW_ptr, ld_st_width_c,
3062 ld_st_height_c, ld_st_pitch_c, ld_st_start_x_c,
3063 ld_st_start_y_c);
3064
3065 matC_64x512.reg = xetla_cvt<dtype_sfx, float>(
3066 matElem_reg_4x16x16);
3067
3068 subgroup::tile_store(matC_64x512, matC_64x512_payload);
3069 xetla_fence<memory_kind::untyped_global>();
3070 } break;
3071
3072 case 1024: {
3073 matC_32x1024_t matC_32x1024;
3074 matC_32x1024_payload_t matC_32x1024_payload;
3075
3076 matC_32x1024_payload.init(args->matdW_ptr,
3077 ld_st_width_c, ld_st_height_c, ld_st_pitch_c,
3078 ld_st_start_x_c, ld_st_start_y_c);
3079
3080 matC_32x1024.reg = xetla_cvt<dtype_sfx, float>(
3081 matElem_reg_4x16x16);
3082
3084 matC_32x1024, matC_32x1024_payload);
3085 xetla_fence<memory_kind::untyped_global>();
3086 } break;
3087
3088 case 2048: {
3089 matC_16x2048_t matC_16x2048;
3090 matC_16x2048_payload_t matC_16x2048_payload;
3091
3092 matC_16x2048_payload.init(args->matdW_ptr,
3093 ld_st_width_c, ld_st_height_c, ld_st_pitch_c,
3094 ld_st_start_x_c, ld_st_start_y_c);
3095
3096 matC_16x2048.reg = xetla_cvt<dtype_sfx, float>(
3097 matElem_reg_4x16x16);
3098
3100 matC_16x2048, matC_16x2048_payload);
3101 xetla_fence<memory_kind::untyped_global>();
3102 } break;
3103 } //switch
3104
3105             } //valid computing
3106 else {
3107 first_nbarr.arrive();
3108 first_nbarr.wait();
3109 }
3110
3111 second_nbarr.arrive();
3112 second_nbarr.wait();
3113
3114 int all_vert128_loop = all_vert_loop >> all_vert128_shift;
3115 if (((((all_vert128_loop + 1) << all_vert128_shift) - 1)
3116 == all_vert_loop)
3117 || (all_vert128_shift == 0)) { //dQ
3118 gemm_arguments_128x64 gemm_arg_128x64;
3119 matAcc_128x64_t matAcc_128x64;
3120 matC_128x64_t matC_128x64;
3121 matC_128x64_payload_t matC_128x64_payload;
3122
3123 uint32_t width_a = tru_seqlen_ex;
3124 uint32_t height_a = (batchid * numhead + headid) * max_seqlen
3125 + tru_seqlen;
3126 uint32_t pitch_a = max_seqlen;
3127 int start_x_a = 0;
3128 int start_y_a = (batchid * numhead + headid) * max_seqlen
3129 + all_vert128_loop * 128;
3130
3131 gemm_arg_128x64.matA_base_desc.init({args->matdW_ptr},
3132 {width_a, height_a, pitch_a}, {start_x_a, start_y_a});
3133
3134 uint32_t width_b = (headid + 1) * hdsz;
3135 uint32_t height_b = tru_seqlen + seqlen_entry;
3136 uint32_t pitch_b = hiddensize;
3137 int start_x_b = headid * hdsz;
3138 int start_y_b = seqlen_entry;
3139
3140 gemm_arg_128x64.matB_base_desc.init({args->matK_ptr},
3141 {width_b, height_b, pitch_b}, {start_x_b, start_y_b});
3142
3143 gemm_arg_128x64.inner_loop_count
3144 = (wg_tile_out_k + k_stride - 1) / k_stride;
3145
3146 matAcc_128x64.init(0);
3147 gemm_op_128x64_t gemm_op_128x64;
3148 gemm_op_128x64(g_thd32_tid, matAcc_128x64, gemm_arg_128x64);
3149
3150 int ld_st_width_c = (headid + 1) * hdsz;
3151 int height_c = tru_seqlen + seqlen_entry;
3152 int pitch_c = hiddensize;
3153 int start_x_c = headid * hdsz
3154 + gemm_op_128x64_t::get_matC_offset_x(g_thd32_tid);
3155 int start_y_c = all_vert128_loop * 128 + seqlen_entry
3156 + gemm_op_128x64_t::get_matC_offset_y(g_thd32_tid);
3157
3158 matC_128x64_payload.init(args->matdQ_ptr, ld_st_width_c,
3159 height_c, pitch_c, start_x_c, start_y_c);
3160 subgroup::elemwise_cvt<matC_128x64_t, matAcc_128x64_t>(
3161 matC_128x64, matAcc_128x64);
3162 subgroup::tile_store(matC_128x64, matC_128x64_payload);
3163 }
3164 } //all_vert128_loop
3165
3166 for (int transp256_loop = 0; transp256_loop < transp256_loop_num;
3167 transp256_loop++) {
3168 gemm_arguments_256x64_trnp_a gemm_arg_256x64;
3169 matAcc_256x64_trnp_a_t matAcc_256x64;
3170 matC_256x64_trnp_a_t matC_256x64;
3171 matC_256x64_trnp_a_payload_t matC_256x64_payload;
3172
3173 uint32_t width_a = tru_seqlen_ex;
3174 uint32_t height_a
3175 = (batchid * numhead + headid) * max_seqlen + tru_seqlen;
3176 uint32_t pitch_a = max_seqlen;
3177 int start_x_a = transp256_loop * 256;
3178 int start_y_a = (batchid * numhead + headid) * max_seqlen;
3179
3180 gemm_arg_256x64.matA_base_desc.init({args->matdW_ptr},
3181 {height_a, width_a, pitch_a}, {start_y_a, start_x_a});
3182
3183 uint32_t width_b = (headid + 1) * hdsz;
3184 uint32_t height_b = tru_seqlen + seqlen_entry;
3185 uint32_t pitch_b = hiddensize;
3186 int start_x_b = headid * hdsz;
3187 int start_y_b = seqlen_entry;
3188
3189 gemm_arg_256x64.matB_base_desc.init({args->matQ_ptr},
3190 {width_b, height_b, pitch_b}, {start_x_b, start_y_b});
3191
3192 gemm_arg_256x64.inner_loop_count
3193 = (wg_tile_out_k + k_stride - 1) / k_stride;
3194
3195 matAcc_256x64.init(0);
3196 gemm_op_256x64_trnp_a_t gemm_op_256x64_trnp_a;
3197 gemm_op_256x64_trnp_a(g_thd32_tid, matAcc_256x64, gemm_arg_256x64);
3198
3199 int width_c = (headid + 1) * hdsz;
3200 int height_c = tru_seqlen + seqlen_entry;
3201 int pitch_c = hiddensize;
3202 int start_x_c = headid * hdsz
3203 + gemm_op_256x64_trnp_a_t::get_matC_offset_x(g_thd32_tid);
3204 int start_y_c = transp256_loop * 256 + seqlen_entry
3205 + gemm_op_256x64_trnp_a_t::get_matC_offset_y(g_thd32_tid);
3206
3207 matC_256x64_payload.init(args->matdK_ptr, width_c, height_c,
3208 pitch_c, start_x_c, start_y_c);
3210 matAcc_256x64_trnp_a_t>(matC_256x64, matAcc_256x64);
3211 subgroup::tile_store(matC_256x64, matC_256x64_payload);
3212
3213 all_nbarr.arrive();
3214 all_nbarr.wait();
3215 }
3216
3217 for (int transp128_loop = 0; transp128_loop < transp128_loop_num;
3218 transp128_loop++) {
3219 gemm_arguments_128x64_trnp_a gemm_arg_128x64;
3220 matAcc_128x64_trnp_a_t matAcc_128x64;
3221 matC_128x64_trnp_a_t matC_128x64;
3222 matC_128x64_trnp_a_payload_t matC_128x64_payload;
3223
3224 uint32_t width_a = tru_seqlen_ex;
3225 uint32_t height_a
3226 = (batchid * numhead + headid) * max_seqlen + tru_seqlen;
3227 uint32_t pitch_a = max_seqlen;
3228 int start_x_a = transp128_loop * 128 + offset_blk_128x128;
3229 int start_y_a = (batchid * numhead + headid) * max_seqlen;
3230
3231 gemm_arg_128x64.matA_base_desc.init({args->matdW_ptr},
3232 {height_a, width_a, pitch_a}, {start_y_a, start_x_a});
3233
3234 uint32_t width_b = (headid + 1) * hdsz;
3235 uint32_t height_b = tru_seqlen + seqlen_entry;
3236 uint32_t pitch_b = hiddensize;
3237 int start_x_b = headid * hdsz;
3238 int start_y_b = seqlen_entry;
3239
3240 gemm_arg_128x64.matB_base_desc.init({args->matQ_ptr},
3241 {width_b, height_b, pitch_b}, {start_x_b, start_y_b});
3242
3243 gemm_arg_128x64.inner_loop_count
3244 = (wg_tile_out_k + k_stride - 1) / k_stride;
3245
3246 matAcc_128x64.init(0);
3247 gemm_op_128x64_trnp_a_t gemm_op_128x64_trnp_a;
3248 gemm_op_128x64_trnp_a(g_thd32_tid, matAcc_128x64, gemm_arg_128x64);
3249
3250 int width_c = (headid + 1) * hdsz;
3251 int height_c = tru_seqlen + seqlen_entry;
3252 int pitch_c = hiddensize;
3253 int start_x_c = headid * hdsz
3254 + gemm_op_128x64_trnp_a_t::get_matC_offset_x(g_thd32_tid);
3255 int start_y_c = transp128_loop * 128 + seqlen_entry
3256 + offset_blk_128x128
3257 + gemm_op_128x64_trnp_a_t::get_matC_offset_y(g_thd32_tid);
3258
3259 matC_128x64_payload.init(args->matdK_ptr, width_c, height_c,
3260 pitch_c, start_x_c, start_y_c);
3262 matAcc_128x64_trnp_a_t>(matC_128x64, matAcc_128x64);
3263 subgroup::tile_store(matC_128x64, matC_128x64_payload);
3264
3265 all_nbarr.arrive();
3266 all_nbarr.wait();
3267 } //transp128_loop
3268
3269 } //xetla_softmax_bwd_t::call
3270}; //struct xetla_softmax_bwd_t
3271
3272} // namespace gpu::xetla::kernel
Gemm functor.
Definition api.hpp:52
#define __XETLA_API
Definition common.hpp:43
#define xetla_select
xetla select.
Definition base_ops.hpp:49
#define xetla_format
xetla format.
Definition base_ops.hpp:38
__ESIMD_NS::simd< native_type_t< Ty >, N > xetla_vector
wrapper for xetla_vector.
Definition base_types.hpp:149
__ESIMD_NS::simd_mask< N > xetla_mask
wrapper for xetla_mask.
Definition base_types.hpp:165
__XETLA_API xetla_vector< Ty, N *NElts > xetla_load_global(Ty *p, xetla_vector< Toffset, N > offsets, xetla_mask< N > pred=1)
Stateless scattered load.
Definition memory.hpp:245
__XETLA_API xetla_vector< uint32_t, 4 > get_time_stamp()
Returns time stamp.
Definition misc.hpp:57
__XETLA_API std::enable_if_t< arch_tag==gpu_arch::Xe, void > xetla_tatomic_store_global(uint64_t base_address, xetla_vector< Toffset, N > offset, xetla_vector< Ty, N > data, xetla_mask< N > pred=1)
Tensor atomic store API.
Definition raw_send_load_store.hpp:294
#define rand_threshold_const
Definition mha_attn_reg.hpp:26
#define list_width
Definition mha_attn_reg.hpp:25
#define SIGN_BIT_W16
Definition mha_attn_reg.hpp:28
#define SIGN_BIT_B8
Definition mha_attn_reg.hpp:29
Definition limitation.hpp:734
__XETLA_API std::enable_if_t<(T_src::register_layout !=reg_layout::linear) &&(T_dst::register_layout !=reg_layout::linear) &&is_same_layout< T_dst, T_src >::value &&(!is_floating_to_integer< T_dst, T_src >::value)> elemwise_cvt(T_dst &dst, T_src &src)
Element-wise data conversion; the src and dst tiles should have the same layout.
Definition op_function.hpp:40
__XETLA_API std::enable_if_t< detail::check_store_type< tile_t, payload_t >::is_global_2d_xe > tile_store(tile_t &tile, payload_t &payload)
Stores data from the register file to global memory.
Definition store_xe.hpp:91
__XETLA_API std::enable_if_t< detail::check_load_type< tile_t, payload_t >::is_global_2d_xe > tile_load(tile_t &tile, payload_t &payload)
This function loads data from 2D memory surface.
Definition load_xe.hpp:76
mem_space
Definition common.hpp:77
@ fmax
Atomic store the float max of src1 and memory data and return the old value. see
@ fadd
Atomic float add of src1 from memory data and return the old value. see
gpu_arch
Definition common.hpp:73
mem_layout
Definition common.hpp:76
Compute attribute for gemm.
Definition common.hpp:32
Compute policy for xmx engine.
Definition compute_policy.hpp:35
Fine-tune knobs for gemm.
Definition common.hpp:43
Gemm default pre_processing functor.
Definition api.hpp:33
Gemm pre_processing functor with applying relu op to matA.
Definition api.hpp:39
Workgroup level tile shape description.
Definition tile_shape.hpp:34
Arguments for xetla_softmax_bwd_t::run.
Definition mha_attn_reg.hpp:2026
dtype_sfx * matW_ptr
Definition mha_attn_reg.hpp:2035
float * matSum_ptr
Definition mha_attn_reg.hpp:2040
dtype_bin * matV_ptr
Definition mha_attn_reg.hpp:2031
uint32_t * matMkdpot_ptr
Definition mha_attn_reg.hpp:2033
dtype_sfx * matdW_ptr
Definition mha_attn_reg.hpp:2036
float Pinv
Definition mha_attn_reg.hpp:2041
dtype_bin * matdO_ptr
Definition mha_attn_reg.hpp:2034
dtype_bin * matQ_ptr
Definition mha_attn_reg.hpp:2029
dtype_bot * matdK_ptr
Definition mha_attn_reg.hpp:2039
dtype_bin * matK_ptr
Definition mha_attn_reg.hpp:2030
uint32_t * matMkin_ptr
Definition mha_attn_reg.hpp:2032
uint32_t * mList_ptr
Definition mha_attn_reg.hpp:2028
dtype_bot * matdQ_ptr
Definition mha_attn_reg.hpp:2038
dtype_bot * matdV_ptr
Definition mha_attn_reg.hpp:2037
float Scaling
Definition mha_attn_reg.hpp:2042
Definition mha_attn_reg.hpp:1641
static constexpr uint32_t prefetch_distance
Definition mha_attn_reg.hpp:1669
dtype_bwd_acc_ dtype_acc
Definition mha_attn_reg.hpp:1645
subgroup::tile_desc_t< matAcc_32x1024_t::tile_desc::tile_size_x, matAcc_32x1024_t::tile_desc::tile_size_y, matAcc_32x1024_t::tile_desc::block_size_x, matAcc_32x1024_t::tile_desc::block_size_y, reg_layout::tiled > matC_32x1024_tile_desc_t
Definition mha_attn_reg.hpp:1830
group::compute_policy_default_xmx< group::compute_attr_t< dtype_sfx, dtype_bin, dtype_acc >, bgm_perf_tuning_knob, gpu_arch::Xe > compute_policy_out_b_trnp_a
Definition mha_attn_reg.hpp:1707
static constexpr mem_space mem_space_a
Definition mha_attn_reg.hpp:1649
subgroup::tile_desc_t< matAcc_64x384_t::tile_desc::tile_size_x, matAcc_64x384_t::tile_desc::tile_size_y, matAcc_64x384_t::tile_desc::block_size_x, matAcc_64x384_t::tile_desc::block_size_y, reg_layout::tiled > matC_64x384_tile_desc_t
Definition mha_attn_reg.hpp:1818
typename gemm_op_128x64_trnp_a_t::matAcc_t matAcc_128x64_trnp_a_t
Definition mha_attn_reg.hpp:1796
group::tile_shape_t< 64, 128, 16, 16 > tile_attr_128x64
Definition mha_attn_reg.hpp:1683
typename gemm_op_64x512_t::matAcc_t matAcc_64x512_t
Definition mha_attn_reg.hpp:1791
typename gemm_op_16x2048_t::arguments_t gemm_arguments_16x2048
Definition mha_attn_reg.hpp:1776
static constexpr mem_layout gemm_mem_layout_QKT_b
Definition mha_attn_reg.hpp:1665
group::perf_tuning_knob_t< k_stride, prefetch_distance, periodic_sync_interval > bgm_perf_tuning_knob
Definition mha_attn_reg.hpp:1674
group::tile_shape_t< 256, 128, 64, 16 > tile_attr_128x256
Definition mha_attn_reg.hpp:1677
subgroup::tile_desc_t< matAcc_128x64_t::tile_desc::tile_size_x, matAcc_128x64_t::tile_desc::tile_size_y, matAcc_128x64_t::tile_desc::block_size_x, matAcc_128x64_t::tile_desc::block_size_y, reg_layout::tiled > matC_128x64_tile_desc_t
Definition mha_attn_reg.hpp:1900
typename gemm_op_128x64_trnp_a_t::arguments_t gemm_arguments_128x64_trnp_a
Definition mha_attn_reg.hpp:1780
dtype_bwd_sfx_ dtype_sfx
Definition mha_attn_reg.hpp:1644
typename gemm_op_256x64_trnp_af_t::matAcc_t matAcc_256x64_trnp_af_t
Definition mha_attn_reg.hpp:1799
typename gemm_op_256x64_trnp_af_t::arguments_t gemm_arguments_256x64_trnp_af
Definition mha_attn_reg.hpp:1786
group::compute_policy_default_xmx< group::compute_attr_t< dtype_sfx, dtype_bin, dtype_acc >, bgm_perf_tuning_knob, gpu_arch::Xe > compute_policy_out
Definition mha_attn_reg.hpp:1699
subgroup::tile_t< dtype_bot, matC_128x64_trnp_a_tile_desc_t > matC_128x64_trnp_a_t
Definition mha_attn_reg.hpp:1925
work_group_t< ThreadNum > work_group_t
Definition mha_attn_reg.hpp:1714
subgroup::tile_desc_t< matAcc_128x128_t::tile_desc::tile_size_x, matAcc_128x128_t::tile_desc::tile_size_y, matAcc_128x128_t::tile_desc::block_size_x, matAcc_128x128_t::tile_desc::block_size_y, reg_layout::tiled > matC_128x128_tile_desc_t
Definition mha_attn_reg.hpp:1806
static constexpr mem_layout gemm_mem_layout_a
Definition mha_attn_reg.hpp:1661
dtype_bwd_bot_ dtype_bot
Definition mha_attn_reg.hpp:1643
group::tile_shape_t< 64, 256, 16, 32 > tile_attr_256x64
Definition mha_attn_reg.hpp:1682
typename gemm_op_256x64_trnp_a_t::arguments_t gemm_arguments_256x64_trnp_a
Definition mha_attn_reg.hpp:1782
typename gemm_op_128x64_trnp_af_t::matAcc_t matAcc_128x64_trnp_af_t
Definition mha_attn_reg.hpp:1798
group::tile_shape_t< 512, 64, 64, 16 > tile_attr_64x512
Definition mha_attn_reg.hpp:1679
static constexpr mem_space gemm_mem_space_trnp_a
Definition mha_attn_reg.hpp:1660
static constexpr uint32_t global_kslicing
Definition mha_attn_reg.hpp:1709
subgroup::tile_desc_t< matAcc_128x64_trnp_af_t::tile_desc::tile_size_x, matAcc_128x64_trnp_af_t::tile_desc::tile_size_y, matAcc_128x64_trnp_af_t::tile_desc::block_size_x, matAcc_128x64_trnp_af_t::tile_desc::block_size_y, reg_layout::tiled > matC_128x64_trnp_af_tile_desc_t
Definition mha_attn_reg.hpp:1916
static constexpr mem_layout mem_layout_out_b
Definition mha_attn_reg.hpp:1656
typename gemm_op_128x256_t::matAcc_t matAcc_128x256_t
Definition mha_attn_reg.hpp:1789
subgroup::tile_desc_t< matAcc_128x64_trnp_a_t::tile_desc::tile_size_x, matAcc_128x64_trnp_a_t::tile_desc::tile_size_y, matAcc_128x64_trnp_a_t::tile_desc::block_size_x, matAcc_128x64_trnp_a_t::tile_desc::block_size_y, reg_layout::tiled > matC_128x64_trnp_a_tile_desc_t
Definition mha_attn_reg.hpp:1905
static constexpr mem_layout mem_layout_trnp_a
Definition mha_attn_reg.hpp:1654
subgroup::tile_t< dtype_bot, matC_256x64_trnp_a_tile_desc_t > matC_256x64_trnp_a_t
Definition mha_attn_reg.hpp:1927
typename gemm_op_32x1024_t::matAcc_t matAcc_32x1024_t
Definition mha_attn_reg.hpp:1792
static constexpr mem_layout mem_layout_QKT_b
Definition mha_attn_reg.hpp:1655
typename gemm_op_128x128_t::arguments_t gemm_arguments_128x128
Definition mha_attn_reg.hpp:1771
mem_desc_t< dtype_bin, gemm_mem_layout_QKT_b, gemm_mem_space_b > mem_desc_b_QKT
Definition mha_attn_reg.hpp:1688
group::tile_shape_t< 1024, 32, 64, 16 > tile_attr_32x1024
Definition mha_attn_reg.hpp:1680
subgroup::tile_desc_t< matAcc_64x512_t::tile_desc::tile_size_x, matAcc_64x512_t::tile_desc::tile_size_y, matAcc_64x512_t::tile_desc::block_size_x, matAcc_64x512_t::tile_desc::block_size_y, reg_layout::tiled > matC_64x512_tile_desc_t
Definition mha_attn_reg.hpp:1824
subgroup::tile_t< dtype_bot, matC_256x64_trnp_af_tile_desc_t > matC_256x64_trnp_af_t
Definition mha_attn_reg.hpp:1931
static constexpr uint32_t k_stride
Definition mha_attn_reg.hpp:1672
typename gemm_op_64x384_t::matAcc_t matAcc_64x384_t
Definition mha_attn_reg.hpp:1790
static constexpr mem_layout gemm_mem_layout_out_b
Definition mha_attn_reg.hpp:1666
subgroup::tile_t< dtype_bot, matC_128x64_trnp_af_tile_desc_t > matC_128x64_trnp_af_t
Definition mha_attn_reg.hpp:1929
typename gemm_op_128x128_t::matAcc_t matAcc_128x128_t
Definition mha_attn_reg.hpp:1788
group::tile_shape_t< 2048, 16, 64, 16 > tile_attr_16x2048
Definition mha_attn_reg.hpp:1681
mem_desc_t< dtype_sfx, gemm_mem_layout_trnp_a, gemm_mem_space_trnp_a > mem_desc_a_out_b_trnp_a
Definition mha_attn_reg.hpp:1702
typename gemm_op_32x1024_t::arguments_t gemm_arguments_32x1024
Definition mha_attn_reg.hpp:1775
typename gemm_op_64x384_t::arguments_t gemm_arguments_64x384
Definition mha_attn_reg.hpp:1773
mem_desc_t< dtype_sfx, gemm_mem_layout_a, gemm_mem_space_a > mem_desc_a_out
Definition mha_attn_reg.hpp:1694
subgroup::tile_desc_t< matAcc_128x256_t::tile_desc::tile_size_x, matAcc_128x256_t::tile_desc::tile_size_y, matAcc_128x256_t::tile_desc::block_size_x, matAcc_128x256_t::tile_desc::block_size_y, reg_layout::tiled > matC_128x256_tile_desc_t
Definition mha_attn_reg.hpp:1812
static constexpr mem_space mem_space_c
Definition mha_attn_reg.hpp:1651
typename gemm_op_128x64_t::matAcc_t matAcc_128x64_t
Definition mha_attn_reg.hpp:1795
group::compute_policy_default_xmx< group::compute_attr_t< dtype_bin, dtype_bin, dtype_acc >, bgm_perf_tuning_knob, gpu_arch::Xe > compute_policy_QKT
Definition mha_attn_reg.hpp:1691
mem_desc_t< dtype_bin, gemm_mem_layout_a, gemm_mem_space_a > mem_desc_a_QKT
Definition mha_attn_reg.hpp:1686
typename gemm_op_128x64_t::arguments_t gemm_arguments_128x64
Definition mha_attn_reg.hpp:1778
group::tile_shape_t< 384, 64, 48, 16 > tile_attr_64x384
Definition mha_attn_reg.hpp:1678
typename gemm_op_64x512_t::arguments_t gemm_arguments_64x512
Definition mha_attn_reg.hpp:1774
static constexpr mem_space gemm_mem_space_b
Definition mha_attn_reg.hpp:1664
static constexpr uint32_t periodic_sync_interval
Definition mha_attn_reg.hpp:1668
static constexpr mem_layout gemm_mem_layout_trnp_a
Definition mha_attn_reg.hpp:1662
static constexpr mem_space mem_space_b
Definition mha_attn_reg.hpp:1650
typename gemm_op_256x64_trnp_a_t::matAcc_t matAcc_256x64_trnp_a_t
Definition mha_attn_reg.hpp:1797
static __XETLA_API void call(sycl::nd_item< 3 > &item, arguments_t *args)
Main execution function for fused MHA softmax. The basic process is GEMM -> Softmax -> GEMM.
Definition mha_attn_reg.hpp:2048
static constexpr mem_space gemm_mem_space_a
Definition mha_attn_reg.hpp:1659
subgroup::tile_desc_t< matAcc_256x64_trnp_af_t::tile_desc::tile_size_x, matAcc_256x64_trnp_af_t::tile_desc::tile_size_y, matAcc_256x64_trnp_af_t::tile_desc::block_size_x, matAcc_256x64_trnp_af_t::tile_desc::block_size_y, reg_layout::tiled > matC_256x64_trnp_af_tile_desc_t
Definition mha_attn_reg.hpp:1922
static constexpr int ThreadNum
Definition mha_attn_reg.hpp:1647
subgroup::tile_desc_t< matAcc_16x2048_t::tile_desc::tile_size_x, matAcc_16x2048_t::tile_desc::tile_size_y, matAcc_16x2048_t::tile_desc::block_size_x, matAcc_16x2048_t::tile_desc::block_size_y, reg_layout::tiled > matC_16x2048_tile_desc_t
Definition mha_attn_reg.hpp:1836
static constexpr uint16_t sfx_type_size
Definition mha_attn_reg.hpp:1710
mem_desc_t< dtype_bin, gemm_mem_layout_out_b, gemm_mem_space_b > mem_desc_b_out
Definition mha_attn_reg.hpp:1696
dtype_bwd_bin_ dtype_bin
Definition mha_attn_reg.hpp:1642
typename gemm_op_128x64_trnp_af_t::arguments_t gemm_arguments_128x64_trnp_af
Definition mha_attn_reg.hpp:1784
typename gemm_op_128x256_t::arguments_t gemm_arguments_128x256
Definition mha_attn_reg.hpp:1772
mem_desc_t< dtype_bin, gemm_mem_layout_out_b, gemm_mem_space_b > mem_desc_b_out_b_trnp_a
Definition mha_attn_reg.hpp:1704
static constexpr mem_layout mem_layout_c
Definition mha_attn_reg.hpp:1657
group::tile_shape_t< 128, 128, 32, 16 > tile_attr_128x128
Definition mha_attn_reg.hpp:1676
subgroup::tile_desc_t< matAcc_256x64_trnp_a_t::tile_desc::tile_size_x, matAcc_256x64_trnp_a_t::tile_desc::tile_size_y, matAcc_256x64_trnp_a_t::tile_desc::block_size_x, matAcc_256x64_trnp_a_t::tile_desc::block_size_y, reg_layout::tiled > matC_256x64_trnp_a_tile_desc_t
Definition mha_attn_reg.hpp:1910
typename gemm_op_16x2048_t::matAcc_t matAcc_16x2048_t
Definition mha_attn_reg.hpp:1793
static constexpr mem_layout mem_layout_a
Definition mha_attn_reg.hpp:1653
Arguments for xetla_softmax_fwd_t::run.
Definition mha_attn_reg.hpp:301
dtype_bot * matOut_ptr
Definition mha_attn_reg.hpp:310
uint32_t * matMkdpot_ptr
Definition mha_attn_reg.hpp:308
dtype_sfx * matQKT_ptr
Definition mha_attn_reg.hpp:309
float * Max_ptr
Definition mha_attn_reg.hpp:311
dtype_bin * matV_ptr
Definition mha_attn_reg.hpp:306
float Scaling
Definition mha_attn_reg.hpp:314
float * Sum_ptr
Definition mha_attn_reg.hpp:312
uint32_t * matMkin_ptr
Definition mha_attn_reg.hpp:307
uint32_t * mList_ptr
Definition mha_attn_reg.hpp:303
dtype_bin * matQ_ptr
Definition mha_attn_reg.hpp:304
dtype_bin * matK_ptr
Definition mha_attn_reg.hpp:305
Definition mha_attn_reg.hpp:34
subgroup::tile_desc_t< matAcc_32x1024_t::tile_desc::tile_size_x, matAcc_32x1024_t::tile_desc::tile_size_y, matAcc_32x1024_t::tile_desc::block_size_x, matAcc_32x1024_t::tile_desc::block_size_y, reg_layout::tiled > mat_32x1024_tile_desc_t
Definition mha_attn_reg.hpp:177
typename gemm_op_16x2048_t::matAcc_t matAcc_16x2048_t
Definition mha_attn_reg.hpp:145
group::compute_policy_default_xmx< group::compute_attr_t< dtype_sfx, dtype_bin, dtype_acc >, bgm_perf_tuning_knob, gpu_arch::Xe > compute_policy_out
Definition mha_attn_reg.hpp:88
static constexpr mem_layout mem_layout_out_b
Definition mha_attn_reg.hpp:49
static constexpr mem_layout mem_layout_QKT_b
Definition mha_attn_reg.hpp:48
subgroup::tile_desc_t< matAcc_16x2048_t::tile_desc::tile_size_x, matAcc_16x2048_t::tile_desc::tile_size_y, matAcc_16x2048_t::tile_desc::block_size_x, matAcc_16x2048_t::tile_desc::block_size_y, reg_layout::tiled > mat_16x2048_tile_desc_t
Definition mha_attn_reg.hpp:183
static constexpr mem_layout mem_layout_c
Definition mha_attn_reg.hpp:50
group::tile_shape_t< 64, 128, 16, 16 > tile_attr_128x64
Definition mha_attn_reg.hpp:72
dtype_acc_ dtype_acc
Definition mha_attn_reg.hpp:38
static constexpr mem_space mem_space_a
Definition mha_attn_reg.hpp:42
static constexpr mem_layout gemm_mem_layout_QKT_b
Definition mha_attn_reg.hpp:56
group::tile_shape_t< 1024, 32, 64, 16 > tile_attr_32x1024
Definition mha_attn_reg.hpp:70
static constexpr uint32_t prefetch_distance
Definition mha_attn_reg.hpp:60
static constexpr uint32_t k_stride
Definition mha_attn_reg.hpp:62
typename gemm_op_64x384_t::matAcc_t matAcc_64x384_t
Definition mha_attn_reg.hpp:142
static constexpr uint16_t sfx_type_size
Definition mha_attn_reg.hpp:91
group::perf_tuning_knob_t< k_stride, prefetch_distance, periodic_sync_interval > bgm_perf_tuning_knob
Definition mha_attn_reg.hpp:64
static constexpr uint16_t Rand_SIMD
Definition mha_attn_reg.hpp:45
mem_desc_t< dtype_sfx, gemm_mem_layout_a, gemm_mem_space_a > mem_desc_a_out
Definition mha_attn_reg.hpp:83
dtype_bin_ dtype_bin
Definition mha_attn_reg.hpp:35
group::compute_policy_default_xmx< group::compute_attr_t< dtype_bin, dtype_bin, dtype_acc >, bgm_perf_tuning_knob, gpu_arch::Xe > compute_policy_QKT
Definition mha_attn_reg.hpp:80
typename gemm_op_64x512_t::arguments_t gemm_arguments_64x512
Definition mha_attn_reg.hpp:135
static constexpr mem_layout gemm_mem_layout_a
Definition mha_attn_reg.hpp:53
subgroup::tile_desc_t< matAcc_128x128_t::tile_desc::tile_size_x, matAcc_128x128_t::tile_desc::tile_size_y, matAcc_128x128_t::tile_desc::block_size_x, matAcc_128x128_t::tile_desc::block_size_y, reg_layout::tiled > mat_128x128_tile_desc_t
Definition mha_attn_reg.hpp:153
mem_desc_t< dtype_bin, gemm_mem_layout_QKT_b, gemm_mem_space_b > mem_desc_b_QKT
Definition mha_attn_reg.hpp:77
group::tile_shape_t< 256, 128, 64, 16 > tile_attr_128x256
Definition mha_attn_reg.hpp:67
group::tile_shape_t< 384, 64, 48, 16 > tile_attr_64x384
Definition mha_attn_reg.hpp:68
typename gemm_op_128x256_t::matAcc_t matAcc_128x256_t
Definition mha_attn_reg.hpp:141
dtype_bot_ dtype_bot
Definition mha_attn_reg.hpp:36
group::tile_shape_t< 512, 64, 64, 16 > tile_attr_64x512
Definition mha_attn_reg.hpp:69
dtype_sfx_ dtype_sfx
Definition mha_attn_reg.hpp:37
static constexpr int ThreadNum
Definition mha_attn_reg.hpp:40
group::tile_shape_t< 2048, 16, 64, 16 > tile_attr_16x2048
Definition mha_attn_reg.hpp:71
subgroup::tile_desc_t< matAcc_128x256_t::tile_desc::tile_size_x, matAcc_128x256_t::tile_desc::tile_size_y, matAcc_128x256_t::tile_desc::block_size_x, matAcc_128x256_t::tile_desc::block_size_y, reg_layout::tiled > mat_128x256_tile_desc_t
Definition mha_attn_reg.hpp:159
subgroup::tile_desc_t< matAcc_64x384_t::tile_desc::tile_size_x, matAcc_64x384_t::tile_desc::tile_size_y, matAcc_64x384_t::tile_desc::block_size_x, matAcc_64x384_t::tile_desc::block_size_y, reg_layout::tiled > mat_64x384_tile_desc_t
Definition mha_attn_reg.hpp:165
subgroup::tile_desc_t< matAcc_128x64_t::tile_desc::tile_size_x, matAcc_128x64_t::tile_desc::tile_size_y, matAcc_128x64_t::tile_desc::block_size_x, matAcc_128x64_t::tile_desc::block_size_y, reg_layout::tiled > mat_128x64_tile_desc_t
Definition mha_attn_reg.hpp:189
group::tile_shape_t< 128, 128, 32, 16 > tile_attr_128x128
Definition mha_attn_reg.hpp:66
typename gemm_op_128x128_t::arguments_t gemm_arguments_128x128
Definition mha_attn_reg.hpp:132
typename gemm_op_128x64_t::matAcc_t matAcc_128x64_t
Definition mha_attn_reg.hpp:146
work_group_t< ThreadNum > work_group_t
Definition mha_attn_reg.hpp:95
typename gemm_op_64x384_t::arguments_t gemm_arguments_64x384
Definition mha_attn_reg.hpp:134
static constexpr mem_space mem_space_b
Definition mha_attn_reg.hpp:43
typename gemm_op_128x64_t::arguments_t gemm_arguments_128x64
Definition mha_attn_reg.hpp:138
mem_desc_t< dtype_bin, gemm_mem_layout_out_b, gemm_mem_space_b > mem_desc_b_out
Definition mha_attn_reg.hpp:85
static constexpr mem_space gemm_mem_space_b
Definition mha_attn_reg.hpp:55
typename gemm_op_128x128_t::matAcc_t matAcc_128x128_t
Definition mha_attn_reg.hpp:140
mem_desc_t< dtype_bin, gemm_mem_layout_a, gemm_mem_space_a > mem_desc_a_QKT
Definition mha_attn_reg.hpp:75
static constexpr mem_space mem_space_c
Definition mha_attn_reg.hpp:44
typename gemm_op_16x2048_t::arguments_t gemm_arguments_16x2048
Definition mha_attn_reg.hpp:137
static constexpr mem_layout gemm_mem_layout_out_b
Definition mha_attn_reg.hpp:57
typename gemm_op_64x512_t::matAcc_t matAcc_64x512_t
Definition mha_attn_reg.hpp:143
static constexpr uint32_t global_kslicing
Definition mha_attn_reg.hpp:90
static constexpr uint32_t periodic_sync_interval
Definition mha_attn_reg.hpp:59
typename gemm_op_32x1024_t::arguments_t gemm_arguments_32x1024
Definition mha_attn_reg.hpp:136
static constexpr mem_space gemm_mem_space_a
Definition mha_attn_reg.hpp:52
static __XETLA_API void call(sycl::nd_item< 3 > &item, arguments_t *args)
Main execution function for fused MHA softmax. The basic process is GEMM -> Softmax -> GEMM.
Definition mha_attn_reg.hpp:320
subgroup::tile_desc_t< matAcc_64x512_t::tile_desc::tile_size_x, matAcc_64x512_t::tile_desc::tile_size_y, matAcc_64x512_t::tile_desc::block_size_x, matAcc_64x512_t::tile_desc::block_size_y, reg_layout::tiled > mat_64x512_tile_desc_t
Definition mha_attn_reg.hpp:171
typename gemm_op_32x1024_t::matAcc_t matAcc_32x1024_t
Definition mha_attn_reg.hpp:144
static constexpr int max_seqlen
Definition mha_attn_reg.hpp:41
static constexpr mem_layout mem_layout_a
Definition mha_attn_reg.hpp:47
typename gemm_op_128x256_t::arguments_t gemm_arguments_128x256
Definition mha_attn_reg.hpp:133
Definition memory_descriptor.hpp:139
Describes the memory information.
Definition api.hpp:44
Describes the tile information of a sub-matrix.
Definition api.hpp:64
Is a struct that contains some register file.
Definition api.hpp:99
xetla_vector< dtype, tile_desc::tile_elems > reg
Definition api.hpp:102
xetla nbarrier definition API.
Definition raw_send_nbarrier.hpp:43
__XETLA_API void arrive()
named barrier signal from subgroup.
Definition raw_send_nbarrier.hpp:65
__XETLA_API void init_nbarrier(uint8_t nbarrier_id, nbarrier_role role=nbarrier_role::producer_consumer)
Definition raw_send_nbarrier.hpp:55
__XETLA_API void wait()
named barrier wait within subgroup.
Definition raw_send_nbarrier.hpp:76
Definition rand.hpp:30
__XETLA_API xetla_vector< uint32_t, 4 *SIMD > rand()
Definition rand.hpp:57
__XETLA_API void init(uint64_t seed, uint64_t subseq, uint64_t offset)
Definition rand.hpp:38