XeTLA v0.3.6
Intel® Xe Templates for Linear Algebra - API Definition Document
 
Loading...
Searching...
No Matches
mha_core_attn.hpp
Go to the documentation of this file.
1/*******************************************************************************
2* Copyright (c) 2022-2023 Intel Corporation
3*
4* Licensed under the Apache License, Version 2.0 (the "License");
5* you may not use this file except in compliance with the License.
6* You may obtain a copy of the License at
7*
8* http://www.apache.org/licenses/LICENSE-2.0
9*
10* Unless required by applicable law or agreed to in writing, software
11* distributed under the License is distributed on an "AS IS" BASIS,
12* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13* See the License for the specific language governing permissions and
14* limitations under the License.
15*******************************************************************************/
16
17#pragma once
18
19#include "common/common.hpp"
20#include "group/group.hpp"
21#include "subgroup/subgroup.hpp"
22
23namespace gpu::xetla::kernel {
24
25#define list_width 16 // number of uint32_t entries in one per-batch record read from mList_ptr
26#define rand_threshold_const 0x80000000 // initial value of rand_threshold in the forward call()
27#define SIGN_BIT_DW 0x80000000 // bit 31 set (top bit of a 32-bit value)
28#define SIGN_BIT_W16 0x8000 // bit 15 set; OR-ed into 16-bit softmax outputs through drop_mk_w
29#define SIGN_BIT_B8 0x80 // bit 7 set; OR-ed into 8-bit softmax outputs through drop_mk_b
30
41template <typename dtype_bin_, typename dtype_bot_, typename dtype_sfx_,
42 typename dtype_acc_, int HWThreadNum, bool Dopt_RandGenflag = true,
43 uint16_t RandSIMD = 16, int Max_SeqLen = 512>
45 using dtype_bin = dtype_bin_;
46 using dtype_bot = dtype_bot_;
47 using dtype_sfx = dtype_sfx_;
48 using dtype_acc = dtype_acc_;
49
50 static constexpr int ThreadNum = HWThreadNum;
51 static constexpr int max_seqlen = Max_SeqLen;
55 static constexpr uint16_t Rand_SIMD = RandSIMD;
56
61
64
68
69 static constexpr uint32_t periodic_sync_interval = 0;
70 static constexpr uint32_t prefetch_distance = 3;
71 static constexpr uint32_t k_stride
72 = 32 / sizeof(dtype_bin); //gemm_t::k_stride;
75
79
87
95
96 static constexpr uint32_t global_kslicing = 1;
97 static constexpr uint16_t sfx_type_size = sizeof(dtype_sfx);
98 static_assert((sfx_type_size == 1) || (sfx_type_size == 2)
99 || (sfx_type_size == 4));
100
102
110
119
120 using gemm_arguments_128x128 = typename gemm_op_128x128_t::arguments_t;
121 using gemm_arguments_128x256 = typename gemm_op_128x256_t::arguments_t;
122 using gemm_arguments_128x64 = typename gemm_op_128x64_t::arguments_t;
123
124 using matAcc_128x128_t = typename gemm_op_128x128_t::matAcc_t;
125 using matAcc_128x256_t = typename gemm_op_128x256_t::matAcc_t;
126 using matAcc_128x64_t = typename gemm_op_128x64_t::matAcc_t;
127
129 = subgroup::tile_desc_t<matAcc_128x128_t::tile_desc::tile_size_x,
130 matAcc_128x128_t::tile_desc::tile_size_y,
131 matAcc_128x128_t::tile_desc::block_size_x,
132 matAcc_128x128_t::tile_desc::block_size_y,
135 = subgroup::tile_desc_t<matAcc_128x256_t::tile_desc::tile_size_x,
136 matAcc_128x256_t::tile_desc::tile_size_y,
137 matAcc_128x256_t::tile_desc::block_size_x,
138 matAcc_128x256_t::tile_desc::block_size_y,
141 = subgroup::tile_desc_t<matAcc_128x64_t::tile_desc::tile_size_x,
142 matAcc_128x64_t::tile_desc::tile_size_y,
143 matAcc_128x64_t::tile_desc::block_size_x,
144 matAcc_128x64_t::tile_desc::block_size_y,
155 gpu_arch::Xe>;
160 gpu_arch::Xe>;
165 gpu_arch::Xe>;
166
167 //512 = 16x32 or 8x64
177 subgroup::msg_type_v<matElem_tile_desc_t, mem_space::global>,
184 subgroup::msg_type_v<matElem_tile_desc_t, mem_space::global>,
189
193 struct arguments_t {
194 // assume base address, surface width, height, pitch, start coordinate was set
195 uint32_t *mList_ptr;
199 uint32_t *matMkin_ptr;
200 uint32_t *matMkdpot_ptr;
203 float Pinv;
204 float Scaling;
205 };
206
213 __XETLA_API static void call(sycl::nd_item<3> &item, arguments_t *args) {
214
215 int tru_seqlen = 0;
216 int tru_seqlen_ex = 0;
217 int seqlen_entry = 0;
218
219 int groupid = item.get_group(0);
220 int hiddensize = 1024;
221 int numhead = 16;
222 int hdsz = 64;
223 int wg_tile_QKT_k = hdsz; //args->matrix_k;
224 int wg_tile_out_k;
225 int batchid = groupid / numhead;
226 int headid = groupid % numhead;
227
228 work_group_t g_thd32_tid;
229 int tid_linear = item.get_local_linear_id();
230 g_thd32_tid.init(tid_linear);
231
232 uint32_t batch_offset = sizeof(uint32_t) * list_width * batchid;
234 = xetla_vector_gen<uint32_t, list_width>(0, 1);
235 list_offsets *= sizeof(uint32_t);
236 list_offsets += batch_offset;
237
241 list_width>(args->mList_ptr, list_offsets);
242 tru_seqlen = list_vec[0];
243 seqlen_entry = list_vec[1];
244 wg_tile_out_k = tru_seqlen;
245 tru_seqlen_ex = tru_seqlen; //DW align
246 if (sfx_type_size == 2)
247 tru_seqlen_ex = (((tru_seqlen + 1) >> 1) << 1);
248 else if (sfx_type_size == 1)
249 tru_seqlen_ex = (((tru_seqlen + 3) >> 2) << 2);
250 //float totalscaling = args->Pinv * args->Scaling;
251
253 uint32_t rand_threshold = rand_threshold_const;
254 if constexpr (Dopt_RandGenflag == true) {
255 uint64_t rand_seed = 67280421310721;
256 uint64_t rand_subseq
257 = (groupid * ThreadNum + tid_linear) * Rand_SIMD;
258 uint64_t rand_offset = list_vec.xetla_format<uint64_t>()[1];
259 if (list_vec[4] != 0) rand_threshold = list_vec[4];
260 if (rand_offset == 0) {
262 rand_offset = time_stamp.xetla_format<uint64_t>()[0];
263 }
264 Rand_Gen.init(rand_seed, rand_subseq, rand_offset);
265 }
266
267 //std_seqlen = 256: the initial values below are the plan for the 256 bucket
268 int all_vert_loop_num = 2; // trip count of the outer all_vert128_loop (128 rows per step)
269 int blk_128x128_one = 0; // trip count (0 or 1) of the 128x128 GEMM loop
270 int blk_128x256_loop_num = 1; // trip count of the 128x256 GEMM loop (256 columns per step)
271 int offset_blk_128x128 = 0; // column offset at which the 128x128 block starts
272
273 int std_seqlen; // tru_seqlen bucketed to 128 / 256 / 384 / 512
274 if (tru_seqlen <= 128) {
275 std_seqlen = 128;
276 all_vert_loop_num = 1;
277 blk_128x128_one = 1;
278 blk_128x256_loop_num = 0;
279 } else if (tru_seqlen <= 256) // keeps the initial values set above
280 std_seqlen = 256;
281 else if (tru_seqlen <= 384) {
282 std_seqlen = 384;
283 all_vert_loop_num = 3;
284 blk_128x128_one = 1;
285 blk_128x256_loop_num = 1;
286 offset_blk_128x128 = 256; // the 128x128 block follows the single 256-wide block
287 } else { // every tru_seqlen above 384 lands here; no upper bound is checked
288 std_seqlen = 512;
289 all_vert_loop_num = 4;
290 blk_128x128_one = 0;
291 blk_128x256_loop_num = 2;
292 }
293
298
299 for (int all_vert128_loop = 0; all_vert128_loop < all_vert_loop_num;
300 all_vert128_loop++) {
301 for (int hor_256_loop = 0; hor_256_loop < blk_128x256_loop_num;
302 hor_256_loop++) {
303 gemm_arguments_128x256 gemm_arg_128x256;
304 matAcc_128x256_t matAcc_128x256;
305 matC_128x256_t matC_128x256;
306 matC_128x256_payload_t matC_128x256_payload;
307
308 uint32_t width_a = (headid + 1) * hdsz;
309 uint32_t height_a = tru_seqlen + seqlen_entry;
310 uint32_t pitch_a = hiddensize;
311 int start_x_a = headid * hdsz;
312 int start_y_a = all_vert128_loop * 128 + seqlen_entry;
313
314 gemm_arg_128x256.matA_base_desc.init({args->matQ_ptr},
315 {width_a, height_a, pitch_a}, {start_x_a, start_y_a});
316
317 uint32_t width_b = (headid + 1) * hdsz;
318 uint32_t height_b = tru_seqlen + seqlen_entry;
319 uint32_t pitch_b = hiddensize;
320 int start_x_b = headid * hdsz;
321 int start_y_b = hor_256_loop * 256 + seqlen_entry;
322
323 //B transpose
324 gemm_arg_128x256.matB_base_desc.init({args->matK_ptr},
325 {height_b, width_b, pitch_b}, {start_y_b, start_x_b});
326
327 gemm_arg_128x256.inner_loop_count
328 = (wg_tile_QKT_k + k_stride - 1) / k_stride;
329
330 matAcc_128x256.init(0);
331 gemm_op_128x256_t gemm_op_128x256;
332
333 gemm_op_128x256(g_thd32_tid, matAcc_128x256, gemm_arg_128x256);
334
335 uint32_t width_c = max_seqlen;
336 uint32_t height_c
337 = max_seqlen * (batchid * numhead + headid + 1);
338 uint32_t pitch_c = max_seqlen;
339 int start_x_c
340 = gemm_op_128x256_t::get_matC_offset_x(g_thd32_tid)
341 + hor_256_loop * 256;
342 int start_y_c = (batchid * numhead + headid) * max_seqlen
343 + all_vert128_loop * 128
344 + gemm_op_128x256_t::get_matC_offset_y(g_thd32_tid);
345
346 matC_128x256_payload.init(args->matQKT_ptr, width_c, height_c,
347 pitch_c, start_x_c, start_y_c);
348 subgroup::elemwise_cvt<matC_128x256_t, matAcc_128x256_t>(
349 matC_128x256, matAcc_128x256);
350 subgroup::tile_store(matC_128x256, matC_128x256_payload);
351 xetla_fence<memory_kind::untyped_global>();
352 }
353
354 for (int blk_128x128_loop = 0; blk_128x128_loop < blk_128x128_one;
355 blk_128x128_loop++) {
356 gemm_arguments_128x128 gemm_arg_128x128;
357 matAcc_128x128_t matAcc_128x128;
358 matC_128x128_t matC_128x128;
359 matC_128x128_payload_t matC_128x128_payload;
360
361 uint32_t width_a = (headid + 1) * hdsz;
362 uint32_t height_a = tru_seqlen + seqlen_entry;
363 uint32_t pitch_a = hiddensize;
364 int start_x_a = headid * hdsz;
365 int start_y_a = all_vert128_loop * 128 + seqlen_entry;
366
367 gemm_arg_128x128.matA_base_desc.init({args->matQ_ptr},
368 {width_a, height_a, pitch_a}, {start_x_a, start_y_a});
369
370 uint32_t width_b = (headid + 1) * hdsz;
371 uint32_t height_b = tru_seqlen + seqlen_entry;
372 uint32_t pitch_b = hiddensize;
373 int start_x_b = headid * hdsz;
374 int start_y_b = offset_blk_128x128 + seqlen_entry;
375
376 //B transpose
377 gemm_arg_128x128.matB_base_desc.init({args->matK_ptr},
378 {height_b, width_b, pitch_b}, {start_y_b, start_x_b});
379
380 gemm_arg_128x128.inner_loop_count
381 = (wg_tile_QKT_k + k_stride - 1) / k_stride;
382
383 matAcc_128x128.init(0);
384 gemm_op_128x128_t gemm_op_128x128;
385
386 gemm_op_128x128(g_thd32_tid, matAcc_128x128, gemm_arg_128x128);
387
388 uint32_t width_c = max_seqlen;
389 uint32_t height_c
390 = max_seqlen * (batchid * numhead + headid + 1);
391 uint32_t pitch_c = max_seqlen;
392 int start_x_c = offset_blk_128x128
393 + gemm_op_128x128_t::get_matC_offset_x(g_thd32_tid);
394 int start_y_c = (batchid * numhead + headid) * max_seqlen
395 + all_vert128_loop * 128
396 + gemm_op_128x128_t::get_matC_offset_y(g_thd32_tid);
397
398 matC_128x128_payload.init(args->matQKT_ptr, width_c, height_c,
399 pitch_c, start_x_c, start_y_c);
400 subgroup::elemwise_cvt<matC_128x128_t, matAcc_128x128_t>(
401 matC_128x128, matAcc_128x128);
402 subgroup::tile_store(matC_128x128, matC_128x128_payload);
403 xetla_fence<memory_kind::untyped_global>();
404 }
405
406 //fwd softmax
407 {
408 int elem_Ln512_loop_num = 4;
409 int height_8x64_512 = 8 * sfx_type_size;
410 int width_8x16_512 = 64 / sfx_type_size;
411 int height_elem_offset
412 = (max_seqlen * (batchid * numhead + headid)
413 + (all_vert128_loop * 128) + (tid_linear * 4))
414 * height_8x64_512;
415 int width_elem = width_8x16_512;
416 int height_elem;
417 int pitch_elem = width_elem;
418 int start_x_elem = 0;
419 int start_y_elem;
420 int bndy_mk_lp_start = (tru_seqlen + 31) >> 5; //32
421 int bndy_mk_lp_shift
422 = 32 - (bndy_mk_lp_start << 5) + tru_seqlen;
423
425 uint32_t mk_attn_all
426 = sizeof(uint32_t) * (max_seqlen / 32) * (batchid);
427 xetla_vector<uint32_t, 16> mk_attn_offsets
428 = xetla_vector_gen<uint32_t, 16>(0, 1);
429 mk_attn_offsets *= sizeof(uint32_t);
430 mk_attn_offsets += mk_attn_all;
431 mkin_vec16 = xetla_load_global<uint32_t, 1,
434 args->matMkin_ptr, mk_attn_offsets);
435
436 uint32_t mk_offset_all = sizeof(uint32_t) * (max_seqlen / 32)
437 * ((batchid * numhead + headid) * max_seqlen
438 + (all_vert128_loop * 128) + tid_linear * 4);
440 = xetla_vector_gen<uint32_t, 16>(0, 1);
441 mk_offsets *= sizeof(uint32_t);
442 mk_offsets += mk_offset_all;
443
444 first_nbarr.arrive();
445 first_nbarr.wait();
446
447 for (int elem_Ln512_loop = 0;
448 elem_Ln512_loop < elem_Ln512_loop_num;
449 elem_Ln512_loop++) {
450 matElem_ld_t matQKT_rd;
451 matElem_ld_payload_t matQKT_rd_payload;
452 matElem_st_t matQKT_st;
453 matElem_st_payload_t matQKT_st_payload;
454 matElem_reg_t matQKT_reg16x32;
455
456 xetla_vector<uint32_t, 16> mkdpot_vec16;
457
458 start_y_elem = height_elem_offset
459 + elem_Ln512_loop * height_8x64_512;
460 height_elem = start_y_elem
461 + ((std_seqlen * sfx_type_size) / 64);
462
463 matQKT_rd_payload.init(args->matQKT_ptr, width_elem,
464 height_elem, pitch_elem, start_x_elem,
465 start_y_elem);
466 matQKT_st_payload.init(args->matQKT_ptr, width_elem,
467 height_elem, pitch_elem, start_x_elem,
468 start_y_elem);
469
470 subgroup::tile_load<cache_hint::cached, cache_hint::cached>(
471 matQKT_rd, matQKT_rd_payload);
472
473 if constexpr (Dopt_RandGenflag == false) {
474 mkdpot_vec16 = xetla_load_global<uint32_t, 1,
477 16>(args->matMkdpot_ptr, mk_offsets);
478 }
479
480 mk_offsets += sizeof(uint32_t) * (max_seqlen / 32);
481
482 for (int j = bndy_mk_lp_start; j < 16; j++)
483 mkin_vec16[j] = 0xFFFFFFFF;
484 if (bndy_mk_lp_shift < 32) {
485 uint32_t tmp = 0xFFFFFFFF;
486 tmp >>= bndy_mk_lp_shift;
487 tmp <<= bndy_mk_lp_shift;
488 mkin_vec16[bndy_mk_lp_start - 1] |= tmp;
489 //mkin_vec16[bndy_mk_lp_start - 1] <<= bndy_mk_lp_shift;
490 //mkin_vec16[bndy_mk_lp_start - 1] >>= bndy_mk_lp_shift;
491 }
492
493 matQKT_reg16x32.reg
494 = xetla_cvt<float, dtype_sfx>(matQKT_rd.reg);
495 matQKT_reg16x32.reg = matQKT_reg16x32.reg * args->Pinv;
496
497#pragma unroll
498 for (int j = 0; j < 16; j++) {
499 uint32_t mkdata_i = mkin_vec16[j];
500 xetla_mask_int<32> mkdata
501 = xetla_mask_int_gen<32>(mkdata_i);
502 matQKT_reg16x32.reg.xetla_format<float>()
503 .xetla_select<32, 1>(j * 32)
504 .xetla_merge(-1e32,
505 matQKT_reg16x32.reg
506 .xetla_format<float>()
507 .xetla_select<32, 1>(j * 32),
508 mkdata);
509 }
510
511 xetla_vector<float, 16> QKT_reg16_f;
512 QKT_reg16_f = -1e32;
513#pragma unroll
514 for (int j = 0; j < 32; j++) {
515 xetla_mask<16> filter_max = (QKT_reg16_f
516 > matQKT_reg16x32.reg.xetla_format<float>()
517 .xetla_select<16, 1>(j * 16));
518 QKT_reg16_f.xetla_merge(QKT_reg16_f,
519 matQKT_reg16x32.reg.xetla_format<float>()
520 .xetla_select<16, 1>(j * 16),
521 filter_max);
522 }
523
524 xetla_mask<8> filter_max8
525 = (QKT_reg16_f.xetla_select<8, 1>(0)
526 > QKT_reg16_f.xetla_select<8, 1>(8));
527 QKT_reg16_f.xetla_select<8, 1>(0).xetla_merge(
528 QKT_reg16_f.select<8, 1>(0),
529 QKT_reg16_f.select<8, 1>(8), filter_max8);
530 xetla_mask<4> filter_max4
531 = (QKT_reg16_f.xetla_select<4, 1>(0)
532 > QKT_reg16_f.xetla_select<4, 1>(4));
533 QKT_reg16_f.xetla_select<4, 1>(0).xetla_merge(
534 QKT_reg16_f.select<4, 1>(0),
535 QKT_reg16_f.select<4, 1>(4), filter_max4);
536 xetla_mask<2> filter_max2
537 = (QKT_reg16_f.xetla_select<2, 1>(0)
538 > QKT_reg16_f.xetla_select<2, 1>(2));
539 QKT_reg16_f.xetla_select<2, 1>(0).xetla_merge(
540 QKT_reg16_f.select<2, 1>(0),
541 QKT_reg16_f.select<2, 1>(2), filter_max2);
542 xetla_mask<1> filter_max1
543 = (QKT_reg16_f.xetla_select<1, 1>(0)
544 > QKT_reg16_f.xetla_select<1, 1>(1));
545 QKT_reg16_f.xetla_select<1, 1>(0).xetla_merge(
546 QKT_reg16_f.xetla_select<1, 1>(0),
547 QKT_reg16_f.xetla_select<1, 1>(1), filter_max1);
548
549 {
550 float tmp_max = QKT_reg16_f[0];
551 matQKT_reg16x32.reg = matQKT_reg16x32.reg - tmp_max;
552 }
553
554#pragma unroll
555 for (int j = 0; j < 16; j++)
556 matQKT_reg16x32.reg.xetla_format<float>()
557 .xetla_select<32, 1>(j * 32)
558 = xetla_exp<float, 32>(
559 matQKT_reg16x32.reg
560 .xetla_format<float>()
561 .xetla_select<32, 1>(j * 32));
562
563 QKT_reg16_f = matQKT_reg16x32.reg.xetla_format<float>()
564 .xetla_select<16, 1>(0)
565 + matQKT_reg16x32.reg.xetla_format<float>()
566 .xetla_select<16, 1>(16);
567#pragma unroll
568 for (int j = 2; j < 32; j++)
569 QKT_reg16_f = QKT_reg16_f
570 + matQKT_reg16x32.reg.xetla_format<float>()
571 .xetla_select<16, 1>(j * 16);
572
573 QKT_reg16_f.xetla_select<8, 1>(0)
574 += QKT_reg16_f.xetla_select<8, 1>(8);
575 QKT_reg16_f.xetla_select<4, 1>(0)
576 += QKT_reg16_f.xetla_select<4, 1>(4);
577 QKT_reg16_f.xetla_select<2, 1>(0)
578 += QKT_reg16_f.xetla_select<2, 1>(2);
579 QKT_reg16_f.xetla_select<1, 1>(0)
580 += QKT_reg16_f.xetla_select<1, 1>(1);
581
582 QKT_reg16_f.xetla_select<1, 1>(0) = xetla_inv<float, 1>(
583 QKT_reg16_f.xetla_select<1, 1>(0));
584 {
585 float tmp = QKT_reg16_f[0];
586 QKT_reg16_f = tmp;
587 }
588
589#pragma unroll
590 for (int j = 0; j < 32; j++)
591 matQKT_reg16x32.reg.xetla_format<float>()
592 .xetla_select<16, 1>(j * 16)
593 *= QKT_reg16_f;
594
595 xetla_mask<(Max_SeqLen >> 2)> rand_bit;
597
598 matQKT_reg16x32.reg = matQKT_reg16x32.reg * args->Scaling;
599
600 using matElem_reg_w_t = subgroup::tile_t<uint16_t,
601 subgroup::tile_desc_t<32, 1, 32, 1,
603 using matElem_reg_b_t = subgroup::tile_t<uint8_t,
604 subgroup::tile_desc_t<32, 1, 32, 1,
606 matElem_reg_w_t drop_mk_w;
607 matElem_reg_b_t drop_mk_b;
608
609 if constexpr (Dopt_RandGenflag == true) {
610 matQKT_st.reg = xetla_cvt<dtype_sfx, float>(
611 matQKT_reg16x32.reg);
612
613 using matElem_reg_w_t
616 32, 1, reg_layout::tiled>>;
617 using matElem_reg_b_t
620 32, 1, reg_layout::tiled>>;
621 matElem_reg_w_t drop_mk_w;
622 matElem_reg_b_t drop_mk_b;
623
624#pragma unroll
625 for (int i = 0; i < (Max_SeqLen / (4 * 4 * RandSIMD));
626 i++) {
627 rand_data = Rand_Gen.rand();
628 rand_bit.xetla_select<4 * RandSIMD, 1>(
629 i * (4 * RandSIMD))
630 = rand_data > rand_threshold;
631 }
632#pragma unroll
633 for (int j = 0; j < 4; j++) {
634
635 if constexpr (sfx_type_size == 2) {
636 drop_mk_w.reg.xetla_select<32, 1>(0)
638 rand_bit.xetla_select<32, 1>(
639 j * 32));
640 matQKT_st.reg.xetla_format<uint16_t>()
641 .xetla_select<32, 1>(j * 32)
642 |= drop_mk_w.reg.xetla_select<32, 1>(0);
643 }
644 if constexpr (sfx_type_size == 1) {
645 drop_mk_b.reg.xetla_select<32, 1>(0)
647 rand_bit.xetla_select<32, 1>(
648 j * 32));
649 matQKT_st.reg.xetla_format<uint8_t>()
650 .xetla_select<32, 1>(j * 32)
651 |= drop_mk_b.reg.xetla_select<32, 1>(0);
652 }
653 }
654
655 if (std_seqlen > 128) {
656#pragma unroll
657 for (int i = 0;
658 i < (Max_SeqLen / (4 * 4 * RandSIMD));
659 i++) {
660 rand_data = Rand_Gen.rand();
661 rand_bit.xetla_select<4 * RandSIMD, 1>(
662 i * (4 * RandSIMD))
663 = rand_data > rand_threshold;
664 }
665#pragma unroll
666 for (int j = 4; j < 8; j++) {
667 if constexpr (sfx_type_size == 2) {
668 drop_mk_w.reg.xetla_select<32, 1>(0)
670 rand_bit.xetla_select<32,
671 1>((j - 4) * 32));
672 matQKT_st.reg.xetla_format<uint16_t>()
673 .xetla_select<32, 1>(j * 32)
674 |= drop_mk_w.reg
675 .xetla_select<32, 1>(0);
676 }
677 if constexpr (sfx_type_size == 1) {
678 drop_mk_b.reg.xetla_select<32, 1>(0)
680 rand_bit.xetla_select<32,
681 1>((j - 4) * 32));
682 matQKT_st.reg.xetla_format<uint8_t>()
683 .xetla_select<32, 1>(j * 32)
684 |= drop_mk_b.reg
685 .xetla_select<32, 1>(0);
686 }
687 }
688
689 if (std_seqlen > 256) {
690#pragma unroll
691 for (int i = 0;
692 i < (Max_SeqLen / (4 * 4 * RandSIMD));
693 i++) {
694 rand_data = Rand_Gen.rand();
695 rand_bit.xetla_select<4 * RandSIMD, 1>(
696 i * (4 * RandSIMD))
697 = rand_data > rand_threshold;
698 }
699#pragma unroll
700 for (int j = 8; j < 12; j++) {
701 if constexpr (sfx_type_size == 2) {
702 drop_mk_w.reg.xetla_select<32, 1>(0)
704 rand_bit.xetla_select<
705 32, 1>(
706 (j - 8) * 32));
707 matQKT_st.reg.xetla_format<uint16_t>()
708 .xetla_select<32, 1>(j * 32)
709 |= drop_mk_w.reg
710 .xetla_select<32, 1>(
711 0);
712 }
713 if constexpr (sfx_type_size == 1) {
714 drop_mk_b.reg.xetla_select<32, 1>(0)
716 rand_bit.xetla_select<
717 32, 1>(
718 (j - 8) * 32));
719 matQKT_st.reg.xetla_format<uint8_t>()
720 .xetla_select<32, 1>(j * 32)
721 |= drop_mk_b.reg
722 .xetla_select<32, 1>(
723 0);
724 }
725 }
726 if (std_seqlen > 384) {
727#pragma unroll
728 for (int i = 0; i
729 < (Max_SeqLen / (4 * 4 * RandSIMD));
730 i++) {
731 rand_data = Rand_Gen.rand();
732 rand_bit.xetla_select<4 * RandSIMD, 1>(
733 i * (4 * RandSIMD))
734 = rand_data > rand_threshold;
735 }
736#pragma unroll
737 for (int j = 12; j < 16; j++) {
738 if constexpr (sfx_type_size == 2) {
739 drop_mk_w.reg.xetla_select<32, 1>(0)
741 0,
742 rand_bit.xetla_select<
743 32, 1>(
744 (j - 12)
745 * 32));
746 matQKT_st.reg
747 .xetla_format<uint16_t>()
748 .xetla_select<32, 1>(j * 32)
749 |= drop_mk_w.reg
750 .xetla_select<32,
751 1>(0);
752 }
753 if constexpr (sfx_type_size == 1) {
754 drop_mk_b.reg.xetla_select<32, 1>(0)
756 rand_bit.xetla_select<
757 32, 1>(
758 (j - 12)
759 * 32));
760 matQKT_st.reg
761 .xetla_format<uint8_t>()
762 .xetla_select<32, 1>(j * 32)
763 |= drop_mk_b.reg
764 .xetla_select<32,
765 1>(0);
766 }
767 }
768 }
769 }
770 }
771 } else {
772 matQKT_st.reg = xetla_cvt<dtype_sfx, float>(
773 matQKT_reg16x32.reg); // convert the scaled float tile to dtype_sfx once, before masking
774#pragma unroll
775 for (int j = 0; j < 16; j++) {
776 uint32_t mkdata_i = mkdpot_vec16[j]; // j-th 32-bit mask word loaded from matMkdpot_ptr
777 xetla_mask_int<32> mkdata
778 = xetla_mask_int_gen<32>(mkdata_i);
779 if constexpr (sfx_type_size == 2) {
780 drop_mk_w.reg.xetla_select<32, 1>(0)
781 .xetla_merge(SIGN_BIT_W16, 0, mkdata);
782 matQKT_st.reg.xetla_format<uint16_t>()
783 .xetla_select<32, 1>(j * 32)
784 |= drop_mk_w.reg.xetla_select<32, 1>(0); // OR the 16-bit marker into elements j*32 .. j*32+31
785 }
786 if constexpr (sfx_type_size == 1) {
787 drop_mk_b.reg.xetla_select<32, 1>(0)
788 .xetla_merge(SIGN_BIT_B8, 0, mkdata);
789 matQKT_st.reg.xetla_format<uint8_t>()
790 .xetla_select<32, 1>(j * 32)
791 |= drop_mk_b.reg.xetla_select<32, 1>(0); // OR the 8-bit marker into elements j*32 .. j*32+31
792 }
793 }
794 // Do not convert matQKT_reg16x32.reg into matQKT_st.reg again here:
795 // a second xetla_cvt would overwrite the tile and discard the
796 // SIGN_BIT_* bits OR-ed in by the loop above before tile_store.
797 }
798
799 subgroup::tile_store(matQKT_st, matQKT_st_payload);
800 xetla_fence<memory_kind::untyped_global>();
801 }
802
803 second_nbarr.arrive();
804 second_nbarr.wait();
805 }
806
807 //QKtV
808 {
809 gemm_arguments_128x64 gemm_arg_128x64;
810 matAcc_128x64_t matAcc_128x64;
811 matC_128x64_t matC_128x64;
812 matC_128x64_payload_t matC_128x64_payload;
813
814 uint32_t width_a = tru_seqlen_ex;
815 uint32_t height_a = (batchid * numhead + headid) * max_seqlen
816 + tru_seqlen;
817 uint32_t pitch_a = max_seqlen;
818 int start_x_a = 0;
819 int start_y_a = (batchid * numhead + headid) * max_seqlen
820 + all_vert128_loop * 128;
821
822 gemm_arg_128x64.matA_base_desc.init({args->matQKT_ptr},
823 {width_a, height_a, pitch_a}, {start_x_a, start_y_a});
824
825 uint32_t width_b = (headid + 1) * hdsz;
826 uint32_t height_b = tru_seqlen + seqlen_entry;
827 uint32_t pitch_b = hiddensize;
828 int start_x_b = headid * hdsz;
829 int start_y_b = seqlen_entry;
830
831 gemm_arg_128x64.matB_base_desc.init({args->matV_ptr},
832 {width_b, height_b, pitch_b}, {start_x_b, start_y_b});
833
834 gemm_arg_128x64.inner_loop_count
835 = (wg_tile_out_k + k_stride - 1) / k_stride;
836
837 matAcc_128x64.init(0);
838 gemm_op_128x64_t gemm_op_128x64;
839 gemm_op_128x64(g_thd32_tid, matAcc_128x64, gemm_arg_128x64);
840
841 uint32_t width_c = (headid + 1) * hdsz;
842 uint32_t height_c = tru_seqlen + seqlen_entry;
843 uint32_t pitch_c = hiddensize;
844 int start_x_c = headid * hdsz
845 + gemm_op_128x64_t::get_matC_offset_x(g_thd32_tid);
846 int start_y_c = all_vert128_loop * 128 + seqlen_entry
847 + gemm_op_128x64_t::get_matC_offset_y(g_thd32_tid);
848
849 matC_128x64_payload.init(args->matOut_ptr, width_c, height_c,
850 pitch_c, start_x_c, start_y_c);
851 subgroup::elemwise_cvt<matC_128x64_t, matAcc_128x64_t>(
852 matC_128x64, matAcc_128x64);
853 subgroup::tile_store(matC_128x64, matC_128x64_payload);
854 }
855
856 } //all_vert128_loop
857 } //xetla_softmax_fwd_t::call()
858}; //struct xetla_softmax_fwd_t
859
869template <typename dtype_bwd_bin_, typename dtype_bwd_bot_,
870 typename dtype_bwd_sfx_, typename dtype_bwd_acc_, int HWThreadNum,
871 bool Dopt_RandGenflag = true, bool Mkin_flag = false>
873 using dtype_bin = dtype_bwd_bin_;
874 using dtype_bot = dtype_bwd_bot_;
875 using dtype_sfx = dtype_bwd_sfx_;
876 using dtype_acc = dtype_bwd_acc_;
877
878 static constexpr int ThreadNum = HWThreadNum;
879 static_assert(ThreadNum == 32);
883
889
894
898
899 static constexpr uint32_t periodic_sync_interval = 0;
900 static constexpr uint32_t prefetch_distance = 3;
901
902 static constexpr uint32_t k_stride
903 = 32 / sizeof(dtype_bin); //gemm_t::k_stride;
910
918
926
934
935 static constexpr uint32_t global_kslicing = 1;
936 static constexpr uint16_t sfx_type_size = sizeof(dtype_sfx);
937 static_assert((sfx_type_size == 1) || (sfx_type_size == 2)
938 || (sfx_type_size == 4));
939
941
956
977
978 using gemm_arguments_128x128 = typename gemm_op_128x128_t::arguments_t;
979 using gemm_arguments_128x256 = typename gemm_op_128x256_t::arguments_t;
980 using gemm_arguments_128x64 = typename gemm_op_128x64_t::arguments_t;
982 typename gemm_op_128x64_trnp_a_t::arguments_t;
984 typename gemm_op_256x64_trnp_a_t::arguments_t;
986 typename gemm_op_128x64_trnp_af_t::arguments_t;
988 typename gemm_op_256x64_trnp_af_t::arguments_t;
989
990 using matAcc_128x128_t = typename gemm_op_128x128_t::matAcc_t;
991 using matAcc_128x256_t = typename gemm_op_128x256_t::matAcc_t;
992 using matAcc_128x64_t = typename gemm_op_128x64_t::matAcc_t;
993 using matAcc_128x64_trnp_a_t = typename gemm_op_128x64_trnp_a_t::matAcc_t;
994 using matAcc_256x64_trnp_a_t = typename gemm_op_256x64_trnp_a_t::matAcc_t;
995 using matAcc_128x64_trnp_af_t = typename gemm_op_128x64_trnp_af_t::matAcc_t;
996 using matAcc_256x64_trnp_af_t = typename gemm_op_256x64_trnp_af_t::matAcc_t;
997
999 = subgroup::tile_desc_t<matAcc_128x128_t::tile_desc::tile_size_x,
1000 matAcc_128x128_t::tile_desc::tile_size_y,
1001 matAcc_128x128_t::tile_desc::block_size_x,
1002 matAcc_128x128_t::tile_desc::block_size_y,
1005 = subgroup::tile_desc_t<matAcc_128x256_t::tile_desc::tile_size_x,
1006 matAcc_128x256_t::tile_desc::tile_size_y,
1007 matAcc_128x256_t::tile_desc::block_size_x,
1008 matAcc_128x256_t::tile_desc::block_size_y,
1011 = subgroup::tile_desc_t<matAcc_128x64_t::tile_desc::tile_size_x,
1012 matAcc_128x64_t::tile_desc::tile_size_y,
1013 matAcc_128x64_t::tile_desc::block_size_x,
1014 matAcc_128x64_t::tile_desc::block_size_y,
1017 matAcc_128x64_trnp_a_t::tile_desc::tile_size_x,
1018 matAcc_128x64_trnp_a_t::tile_desc::tile_size_y,
1019 matAcc_128x64_trnp_a_t::tile_desc::block_size_x,
1020 matAcc_128x64_trnp_a_t::tile_desc::block_size_y, reg_layout::tiled>;
1022 matAcc_256x64_trnp_a_t::tile_desc::tile_size_x,
1023 matAcc_256x64_trnp_a_t::tile_desc::tile_size_y,
1024 matAcc_256x64_trnp_a_t::tile_desc::block_size_x,
1025 matAcc_256x64_trnp_a_t::tile_desc::block_size_y, reg_layout::tiled>;
1027 matAcc_128x64_trnp_af_t::tile_desc::tile_size_x,
1028 matAcc_128x64_trnp_af_t::tile_desc::tile_size_y,
1029 matAcc_128x64_trnp_af_t::tile_desc::block_size_x,
1030 matAcc_128x64_trnp_af_t::tile_desc::block_size_y,
1033 matAcc_256x64_trnp_af_t::tile_desc::tile_size_x,
1034 matAcc_256x64_trnp_af_t::tile_desc::tile_size_y,
1035 matAcc_256x64_trnp_af_t::tile_desc::block_size_x,
1036 matAcc_256x64_trnp_af_t::tile_desc::block_size_y,
1051
1055 (global_kslicing > 1)
1057 : subgroup::msg_type_v<matC_128x128_tile_desc_t,
1058 mem_space_c>,
1059 gpu_arch::Xe>;
1063 (global_kslicing > 1)
1065 : subgroup::msg_type_v<matC_128x256_tile_desc_t,
1066 mem_space_c>,
1067 gpu_arch::Xe>;
1072 : subgroup::msg_type_v<
1074 gpu_arch::Xe>;
1078 (global_kslicing > 1)
1080 : subgroup::msg_type_v<matC_128x64_trnp_a_tile_desc_t,
1081 mem_space_c>,
1082 gpu_arch::Xe>;
1086 (global_kslicing > 1)
1088 : subgroup::msg_type_v<matC_256x64_trnp_a_tile_desc_t,
1089 mem_space_c>,
1090 gpu_arch::Xe>;
1094 (global_kslicing > 1)
1096 : subgroup::msg_type_v<matC_128x64_trnp_af_tile_desc_t,
1097 mem_space_c>,
1098 gpu_arch::Xe>;
1102 (global_kslicing > 1)
1104 : subgroup::msg_type_v<matC_256x64_trnp_af_tile_desc_t,
1105 mem_space_c>,
1106 gpu_arch::Xe>;
1107
1108 //512 = 16x32 or 8x64
1120 subgroup::msg_type_v<matElem_tile_desc_t, mem_space::global>,
1121 gpu_arch::Xe>;
1126 gpu::xetla::subgroup::tile_desc_t<32, 16, 32, 16,
1131 // assume base address, surface width, height, pitch, start coordinate was set
1132 uint32_t *mList_ptr;
1136 uint32_t *matMkin_ptr;
1137 uint32_t *matMkdpot_ptr;
1144 float Pinv;
1145 float Scaling;
1146 };
1147
1154 __XETLA_API static void call(sycl::nd_item<3> &item, arguments_t *args) {
1155
1156 int tru_seqlen = 0;
1157 int tru_seqlen_ex = 0;
1158 int seqlen_entry = 0;
1159 int hiddensize = 1024;
1160 int numhead = 16;
1161 int hdsz = 64;
1162 int max_seqlen = 512;
1163 int wg_tile_QKT_k = hdsz; //args->matrix_k;
1164 int wg_tile_out_k;
1165
1166 int groupid = item.get_group(0);
1167 int batchid = groupid / numhead;
1168 int headid = groupid % numhead;
1169
1170 //float totalscaling = args->Pinv * args->Scaling;
1171
1172 uint32_t batch_offset = sizeof(uint32_t) * list_width * batchid;
1174 = xetla_vector_gen<uint32_t, list_width>(0, 1);
1175 list_offsets *= sizeof(uint32_t);
1176 list_offsets += batch_offset;
1180 list_width>(args->mList_ptr, list_offsets);
1181 tru_seqlen = list_vec[0];
1182 seqlen_entry = list_vec[1];
1183 wg_tile_out_k = tru_seqlen;
1184 tru_seqlen_ex = tru_seqlen; //4: dw aligned
1185 if (sfx_type_size == 2)
1186 tru_seqlen_ex = ((tru_seqlen + 1) >> 1) << 1;
1187 else if (sfx_type_size == 1)
1188 tru_seqlen_ex = ((tru_seqlen + 3) >> 2) << 2;
1189
1190 //reset for all std_seqlen
1191 int all_vert_loop_num = 0;
1192 int transp128_loop_num = 0;
1193 int transp256_loop_num = 0;
1194 int blk_128x128_one = 0;
1195 int blk_128x256_loop_num = 0;
1196 int offset_blk_128x128 = 0;
1197 int std_seqlen;
1198 if (tru_seqlen <= 128) {
1199 std_seqlen = 128;
1200 all_vert_loop_num = 1;
1201 transp128_loop_num = 1;
1202 blk_128x128_one = 1;
1203 } else if (tru_seqlen <= 256) {
1204 std_seqlen = 256;
1205 all_vert_loop_num = 2;
1206 transp256_loop_num = 1;
1207 blk_128x256_loop_num = 1;
1208 } else if (tru_seqlen <= 384) {
1209 std_seqlen = 384;
1210 all_vert_loop_num = 3;
1211 transp128_loop_num = 1;
1212 transp256_loop_num = 1;
1213 blk_128x128_one = 1;
1214 blk_128x256_loop_num = 1;
1215 offset_blk_128x128 = 256;
1216 } else {
1217 std_seqlen = 512;
1218 all_vert_loop_num = 4;
1219 transp256_loop_num = 2;
1220 blk_128x256_loop_num = 2;
1221 }
1222
1223 work_group_t g_thd32_tid;
1224 int tid_linear = item.get_local_linear_id();
1225 g_thd32_tid.init(tid_linear);
1226
1227 static_assert(ThreadNum == 32, "All Thread Sync");
1230
1231 int max_2d_nbar_id = ThreadNum >> 1;
1232 first_nbarr.init_nbarrier(
1233 max_2d_nbar_id, nbarrier_role::producer_consumer);
1234 second_nbarr.init_nbarrier(
1235 max_2d_nbar_id + 1, nbarrier_role::producer_consumer);
1236
1238 all_nbarr.init_nbarrier(
1240
1241 for (int transp128_loop = 0; transp128_loop < transp128_loop_num;
1242 transp128_loop++) {
1243 gemm_arguments_128x64_trnp_af gemm_arg_128x64;
1244 matAcc_128x64_trnp_af_t matAcc_128x64;
1245 matC_128x64_trnp_af_t matC_128x64;
1246 matC_128x64_trnp_af_payload_t matC_128x64_payload;
1247
1248 uint32_t width_a = tru_seqlen_ex;
1249 uint32_t height_a
1250 = (batchid * numhead + headid) * max_seqlen + tru_seqlen;
1251 uint32_t pitch_a = max_seqlen;
1252 int start_x_a = transp128_loop * 128 + offset_blk_128x128;
1253 int start_y_a = (batchid * numhead + headid) * max_seqlen;
1254
1255 gemm_arg_128x64.matA_base_desc.init({args->matW_ptr},
1256 {height_a, width_a, pitch_a}, {start_y_a, start_x_a});
1257
1258 uint32_t width_b = (headid + 1) * hdsz;
1259 uint32_t height_b = tru_seqlen + seqlen_entry;
1260 uint32_t pitch_b = hiddensize;
1261 int start_x_b = headid * hdsz;
1262 int start_y_b = seqlen_entry;
1263
1264 gemm_arg_128x64.matB_base_desc.init({args->matdO_ptr},
1265 {width_b, height_b, pitch_b}, {start_x_b, start_y_b});
1266 gemm_arg_128x64.inner_loop_count
1267 = (wg_tile_out_k + k_stride - 1) / k_stride;
1268 matAcc_128x64.init(0);
1269
1270 gemm_op_128x64_trnp_af_t gemm_op_128x64_trnp_af;
1271 gemm_op_128x64_trnp_af(g_thd32_tid, matAcc_128x64, gemm_arg_128x64);
1272
1273 int width_c = (headid + 1) * hdsz;
1274 int height_c = tru_seqlen + seqlen_entry;
1275 int pitch_c = hiddensize;
1276 int start_x_c = headid * hdsz
1277 + gemm_op_128x64_trnp_af_t::get_matC_offset_x(g_thd32_tid);
1278 int start_y_c = transp128_loop * 128 + seqlen_entry
1279 + offset_blk_128x128
1280 + gemm_op_128x64_trnp_af_t::get_matC_offset_y(g_thd32_tid);
1281
1282 matC_128x64_payload.init(args->matdV_ptr, width_c, height_c,
1283 pitch_c, start_x_c, start_y_c);
1285 matAcc_128x64_trnp_af_t>(matC_128x64, matAcc_128x64);
1286 subgroup::tile_store(matC_128x64, matC_128x64_payload);
1287
1288 //add global sync if nbarr used inside gemm
1289 all_nbarr.arrive();
1290 all_nbarr.wait();
1291 }
1292
1293 for (int transp256_loop = 0; transp256_loop < transp256_loop_num;
1294 transp256_loop++) {
1295 gemm_arguments_256x64_trnp_af gemm_arg_256x64;
1296 matAcc_256x64_trnp_af_t matAcc_256x64;
1297 matC_256x64_trnp_af_t matC_256x64;
1298 matC_256x64_trnp_af_payload_t matC_256x64_payload;
1299
1300 uint32_t width_a = tru_seqlen_ex;
1301 uint32_t height_a
1302 = (batchid * numhead + headid) * max_seqlen + tru_seqlen;
1303 uint32_t pitch_a = max_seqlen;
1304 int start_x_a = transp256_loop * 256;
1305 int start_y_a = (batchid * numhead + headid) * max_seqlen;
1306 gemm_arg_256x64.matA_base_desc.init({args->matW_ptr},
1307 {height_a, width_a, pitch_a}, {start_y_a, start_x_a});
1308
1309 uint32_t width_b = (headid + 1) * hdsz;
1310 uint32_t height_b = tru_seqlen + seqlen_entry;
1311 uint32_t pitch_b = hiddensize;
1312 int start_x_b = headid * hdsz;
1313 int start_y_b = seqlen_entry;
1314 gemm_arg_256x64.matB_base_desc.init({args->matdO_ptr},
1315 {width_b, height_b, pitch_b}, {start_x_b, start_y_b});
1316
1317 gemm_arg_256x64.inner_loop_count
1318 = (wg_tile_out_k + k_stride - 1) / k_stride;
1319
1320 matAcc_256x64.init(0);
1321
1322 gemm_op_256x64_trnp_af_t gemm_op_256x64_trnp_af;
1323 gemm_op_256x64_trnp_af(g_thd32_tid, matAcc_256x64, gemm_arg_256x64);
1324
1325 int width_c = (headid + 1) * hdsz;
1326 int height_c = tru_seqlen + seqlen_entry;
1327 int pitch_c = hiddensize;
1328 int start_x_c = headid * hdsz
1329 + gemm_op_256x64_trnp_af_t::get_matC_offset_x(g_thd32_tid);
1330 int start_y_c = transp256_loop * 256 + seqlen_entry
1331 + gemm_op_256x64_trnp_af_t::get_matC_offset_y(g_thd32_tid);
1332
1333 matC_256x64_payload.init(args->matdV_ptr, width_c, height_c,
1334 pitch_c, start_x_c, start_y_c);
1336 matAcc_256x64_trnp_af_t>(matC_256x64, matAcc_256x64);
1337 subgroup::tile_store(matC_256x64, matC_256x64_payload);
1338
1339 //add global sync if nbarr used inside gemm
1340 all_nbarr.arrive();
1341 all_nbarr.wait();
1342 }
1343
1344 for (int all_vert128_loop = 0; all_vert128_loop < all_vert_loop_num;
1345 all_vert128_loop++) {
1346 //dW
1347 for (int hor_256_loop = 0; hor_256_loop < blk_128x256_loop_num;
1348 hor_256_loop++) {
1349 gemm_arguments_128x256 gemm_arg_128x256;
1350 matAcc_128x256_t matAcc_128x256;
1351 matC_128x256_t matC_128x256;
1352 matC_128x256_payload_t matC_128x256_payload;
1353
1354 uint32_t width_a = (headid + 1) * hdsz;
1355 uint32_t height_a = tru_seqlen + seqlen_entry;
1356 uint32_t pitch_a = hiddensize;
1357 int start_x_a = headid * hdsz;
1358 int start_y_a = all_vert128_loop * 128 + seqlen_entry;
1359
1360 gemm_arg_128x256.matA_base_desc.init({args->matdO_ptr},
1361 {width_a, height_a, pitch_a}, {start_x_a, start_y_a});
1362
1363 uint32_t width_b = (headid + 1) * hdsz;
1364 uint32_t height_b = tru_seqlen + seqlen_entry;
1365 uint32_t pitch_b = hiddensize;
1366 int start_x_b = headid * hdsz;
1367 int start_y_b = hor_256_loop * 256 + seqlen_entry;
1368
1369 //B transpose, be swapped in init
1370 gemm_arg_128x256.matB_base_desc.init({args->matV_ptr},
1371 {height_b, width_b, pitch_b}, {start_y_b, start_x_b});
1372
1373 gemm_arg_128x256.inner_loop_count
1374 = (wg_tile_QKT_k + k_stride - 1) / k_stride;
1375
1376 matAcc_128x256.init(0);
1377
1378 gemm_op_128x256_t gemm_op_128x256;
1379 gemm_op_128x256(g_thd32_tid, matAcc_128x256, gemm_arg_128x256);
1380
1381 int width_c = max_seqlen;
1382 int height_c = max_seqlen * (batchid * numhead + headid + 1);
1383 int pitch_c = max_seqlen;
1384 int start_x_c
1385 = gemm_op_128x256_t::get_matC_offset_x(g_thd32_tid)
1386 + hor_256_loop * 256;
1387 int start_y_c = (batchid * numhead + headid) * max_seqlen
1388 + all_vert128_loop * 128
1389 + gemm_op_128x256_t::get_matC_offset_y(g_thd32_tid);
1390
1391 matC_128x256_payload.init(args->matdW_ptr, width_c, height_c,
1392 pitch_c, start_x_c, start_y_c);
1393 subgroup::elemwise_cvt<matC_128x256_t, matAcc_128x256_t>(
1394 matC_128x256, matAcc_128x256);
1395 subgroup::tile_store(matC_128x256, matC_128x256_payload);
1396 xetla_fence<memory_kind::untyped_global>();
1397 }
1398
1399 for (int blk_128x128_loop = 0; blk_128x128_loop < blk_128x128_one;
1400 blk_128x128_loop++) {
1401 gemm_arguments_128x128 gemm_arg_128x128;
1402 matAcc_128x128_t matAcc_128x128;
1403 matC_128x128_t matC_128x128;
1404 matC_128x128_payload_t matC_128x128_payload;
1405
1406 uint32_t width_a = (headid + 1) * hdsz;
1407 uint32_t height_a = tru_seqlen + seqlen_entry;
1408 uint32_t pitch_a = hiddensize;
1409 int start_x_a = headid * hdsz;
1410 int start_y_a = all_vert128_loop * 128 + seqlen_entry;
1411
1412 gemm_arg_128x128.matA_base_desc.init({args->matdO_ptr},
1413 {width_a, height_a, pitch_a}, {start_x_a, start_y_a});
1414
1415 uint32_t width_b = (headid + 1) * hdsz;
1416 uint32_t height_b = tru_seqlen + seqlen_entry;
1417 uint32_t pitch_b = hiddensize;
1418 int start_x_b = headid * hdsz;
1419 int start_y_b = offset_blk_128x128 + seqlen_entry;
1420
1421 //B transpose, be swapped in init
1422 gemm_arg_128x128.matB_base_desc.init({args->matV_ptr},
1423 {height_b, width_b, pitch_b}, {start_y_b, start_x_b});
1424
1425 gemm_arg_128x128.inner_loop_count
1426 = (wg_tile_QKT_k + k_stride - 1) / k_stride;
1427
1428 matAcc_128x128.init(0);
1429
1430 gemm_op_128x128_t gemm_op_128x128;
1431 gemm_op_128x128(g_thd32_tid, matAcc_128x128, gemm_arg_128x128);
1432
1433 int width_c = max_seqlen;
1434 int height_c = max_seqlen * (batchid * numhead + headid + 1);
1435 int pitch_c = max_seqlen;
1436 int start_x_c = offset_blk_128x128
1437 + gemm_op_128x128_t::get_matC_offset_x(g_thd32_tid);
1438 int start_y_c = (batchid * numhead + headid) * max_seqlen
1439 + all_vert128_loop * 128
1440 + gemm_op_128x128_t::get_matC_offset_y(g_thd32_tid);
1441
1442 matC_128x128_payload.init(args->matdW_ptr, width_c, height_c,
1443 pitch_c, start_x_c, start_y_c);
1444 subgroup::elemwise_cvt<matC_128x128_t, matAcc_128x128_t>(
1445 matC_128x128, matAcc_128x128);
1446 subgroup::tile_store(matC_128x128, matC_128x128_payload);
1447 xetla_fence<memory_kind::untyped_global>();
1448 }
1449
1450 int elem_Ln512_loop_num = 4;
1451 int height_8x64_512 = 8 * sfx_type_size;
1452 int width_8x16_512 = 64 / sfx_type_size;
1453 int height_elem_offset
1454 = (max_seqlen * (batchid * numhead + headid)
1455 + (all_vert128_loop * 128) + (tid_linear * 4))
1456 * height_8x64_512;
1457 int width_elem = width_8x16_512;
1458 int height_elem;
1459 int pitch_elem = width_elem;
1460 int start_x_elem = 0;
1461 int start_y_elem;
1462
1463 xetla_vector<uint32_t, 16> mkin_vec16;
1464 if constexpr (Mkin_flag == true) {
1465 uint32_t mk_attn_all
1466 = sizeof(uint32_t) * (max_seqlen / 32) * (batchid);
1467 xetla_vector<uint32_t, 16> mk_attn_offsets
1468 = xetla_vector_gen<uint32_t, 16>(0, 1);
1469 mk_attn_offsets *= sizeof(uint32_t);
1470 mk_attn_offsets += mk_attn_all;
1471 mkin_vec16 = xetla_load_global<uint32_t, 1,
1473 cache_hint::cached, 16>(
1474 args->matMkin_ptr, mk_attn_offsets);
1475 }
1476
1477 uint32_t mk_offset_all;
1479 = xetla_vector_gen<uint32_t, 16>(0, 1);
1480 if constexpr (Dopt_RandGenflag == false) {
1481 mk_offset_all = sizeof(uint32_t) * (max_seqlen / 32)
1482 * ((batchid * numhead + headid) * max_seqlen
1483 + (all_vert128_loop * 128) + tid_linear * 4);
1484 mk_offsets *= sizeof(uint32_t);
1485 mk_offsets += mk_offset_all;
1486 }
1487
1488 first_nbarr.arrive();
1489 first_nbarr.wait();
1490
1491 for (int elem_Ln512_loop = 0; elem_Ln512_loop < elem_Ln512_loop_num;
1492 elem_Ln512_loop++) {
1493 matElem_ld_t matdW_rd;
1494 matElem_ld_payload_t matdW_rd_payload;
1495 matElem_ld_t matW_rd;
1496 matElem_ld_payload_t matW_rd_payload;
1497 matElem_st_t matdW_st;
1498 matElem_st_payload_t matdW_st_payload;
1499 matElem_reg_t matdW_reg16x32;
1500 matElem_reg_t matW_reg16x32;
1501 xetla_vector<uint32_t, 16> mkdpot_vec16;
1502
1503 start_y_elem = height_elem_offset
1504 + elem_Ln512_loop * height_8x64_512;
1505 height_elem
1506 = start_y_elem + ((std_seqlen * sfx_type_size) / 64);
1507
1508 matdW_rd_payload.init(args->matdW_ptr, width_elem, height_elem,
1509 pitch_elem, start_x_elem, start_y_elem);
1510 matW_rd_payload.init(args->matW_ptr, width_elem, height_elem,
1511 pitch_elem, start_x_elem, start_y_elem);
1512 matdW_st_payload.init(args->matdW_ptr, width_elem, height_elem,
1513 pitch_elem, start_x_elem, start_y_elem);
1514
1515 subgroup::tile_load<cache_hint::cached, cache_hint::cached>(
1516 matdW_rd, matdW_rd_payload);
1517 subgroup::tile_load<cache_hint::cached, cache_hint::cached>(
1518 matW_rd, matW_rd_payload);
1519
1520 if constexpr (Dopt_RandGenflag == false) {
1521 mkdpot_vec16 = xetla_load_global<uint32_t, 1,
1524 16>(args->matMkdpot_ptr, mk_offsets);
1525 mk_offsets += sizeof(uint32_t) * (max_seqlen / 32);
1526 }
1527
1528 matdW_reg16x32.reg = xetla_cvt<float, dtype_sfx>(matdW_rd.reg);
1529 matW_reg16x32.reg = xetla_cvt<float, dtype_sfx>(matW_rd.reg);
1530
1531 if constexpr (Dopt_RandGenflag == false) {
1532
1533#pragma unroll
1534 for (int j = 0; j < 16; j++) {
1535 uint32_t mkdata_i = mkdpot_vec16[j];
1536 xetla_mask_int<32> mkdata
1537 = xetla_mask_int_gen<32>(mkdata_i);
1538 matdW_reg16x32.reg.xetla_format<float>()
1539 .xetla_select<32, 1>(j * 32)
1540 .xetla_merge(0.0,
1541 matdW_reg16x32.reg.xetla_format<float>()
1542 .xetla_select<32, 1>(j * 32),
1543 mkdata);
1544 }
1545 matdW_reg16x32.reg = matW_reg16x32.reg * matdW_reg16x32.reg;
1546 } else {
1547#pragma unroll
1548 for (int j = 0; j < 16; j++) {
1549 xetla_mask<32> mask;
1550 if constexpr (sfx_type_size == 2) {
1551 mask = matW_rd.reg.xetla_format<int16_t>()
1552 .xetla_select<32, 1>(j * 32)
1553 < 0;
1554 matW_rd.reg.xetla_format<uint16_t>()
1555 .xetla_select<32, 1>(j * 32)
1556 &= 0x7FFF;
1557 }
1558 if constexpr (sfx_type_size == 1) {
1559 mask = matW_rd.reg.xetla_format<int8_t>()
1560 .xetla_select<32, 1>(j * 32)
1561 < 0;
1562 matW_rd.reg.xetla_format<uint8_t>()
1563 .xetla_select<32, 1>(j * 32)
1564 &= 0x7F;
1565 }
1566 matW_reg16x32.reg.xetla_format<float>()
1567 .xetla_select<32, 1>(j * 32)
1568 .xetla_merge(0.0, mask);
1569 }
1570
1571 matdW_reg16x32.reg = matW_reg16x32.reg * matdW_reg16x32.reg;
1572
1573 matW_reg16x32.reg
1574 = xetla_cvt<float, dtype_sfx>(matW_rd.reg);
1575 matW_reg16x32.reg *= args->Scaling;
1576 }
1577
1579 = matdW_reg16x32.reg.xetla_select<16, 1>(0);
1580#pragma unroll
1581 for (int j = 1; j < 32; j++)
1582 mdw_sum = mdw_sum
1583 + matdW_reg16x32.reg.xetla_select<16, 1>(j * 16);
1584
1585 mdw_sum.xetla_select<8, 1>(0) = mdw_sum.xetla_select<8, 1>(0)
1586 + mdw_sum.xetla_select<8, 1>(8);
1587 mdw_sum.xetla_select<4, 1>(0) = mdw_sum.xetla_select<4, 1>(0)
1588 + mdw_sum.xetla_select<4, 1>(4);
1589 mdw_sum.xetla_select<2, 1>(0) = mdw_sum.xetla_select<2, 1>(0)
1590 + mdw_sum.xetla_select<2, 1>(2);
1591 mdw_sum.xetla_select<1, 1>(0) = mdw_sum.xetla_select<1, 1>(0)
1592 + mdw_sum.xetla_select<1, 1>(1);
1593 {
1594 float sumtmp = mdw_sum[0];
1595 matW_reg16x32.reg = matW_reg16x32.reg * sumtmp;
1596 }
1597
1598 matdW_reg16x32.reg -= matW_reg16x32.reg;
1599
1600 matdW_reg16x32.reg = matdW_reg16x32.reg * args->Pinv;
1601
1602 if constexpr (Mkin_flag == true) {
1603#pragma unroll
1604 for (int j = 0; j < 16; j++) {
1605 uint32_t mkdata_i = mkin_vec16[j];
1606 xetla_mask_int<32> mkdata
1607 = xetla_mask_int_gen<32>(mkdata_i);
1608 matdW_reg16x32.reg.xetla_format<float>()
1609 .xetla_select<32, 1>(j * 32)
1610 .xetla_merge(0.0,
1611 matdW_reg16x32.reg.xetla_format<float>()
1612 .xetla_select<32, 1>(j * 32),
1613 mkdata);
1614 }
1615 }
1616
1617 matdW_st.reg = xetla_cvt<dtype_sfx, float>(matdW_reg16x32.reg);
1618
1619 subgroup::tile_store(matdW_st, matdW_st_payload);
1620 xetla_fence<memory_kind::untyped_global>();
1621 }
1622
1623 second_nbarr.arrive();
1624 second_nbarr.wait();
1625
1626 { //dQ
1627 gemm_arguments_128x64 gemm_arg_128x64;
1628 matAcc_128x64_t matAcc_128x64;
1629 matC_128x64_t matC_128x64;
1630 matC_128x64_payload_t matC_128x64_payload;
1631
1632 uint32_t width_a = tru_seqlen_ex;
1633 uint32_t height_a = (batchid * numhead + headid) * max_seqlen
1634 + tru_seqlen;
1635 uint32_t pitch_a = max_seqlen;
1636 int start_x_a = 0;
1637 int start_y_a = (batchid * numhead + headid) * max_seqlen
1638 + all_vert128_loop * 128;
1639
1640 gemm_arg_128x64.matA_base_desc.init({args->matdW_ptr},
1641 {width_a, height_a, pitch_a}, {start_x_a, start_y_a});
1642
1643 uint32_t width_b = (headid + 1) * hdsz;
1644 uint32_t height_b = tru_seqlen + seqlen_entry;
1645 uint32_t pitch_b = hiddensize;
1646 int start_x_b = headid * hdsz;
1647 int start_y_b = seqlen_entry;
1648
1649 gemm_arg_128x64.matB_base_desc.init({args->matK_ptr},
1650 {width_b, height_b, pitch_b}, {start_x_b, start_y_b});
1651
1652 gemm_arg_128x64.inner_loop_count
1653 = (wg_tile_out_k + k_stride - 1) / k_stride;
1654
1655 matAcc_128x64.init(0);
1656
1657 gemm_op_128x64_t gemm_op_128x64;
1658 gemm_op_128x64(g_thd32_tid, matAcc_128x64, gemm_arg_128x64);
1659
1660 int width_c = (headid + 1) * hdsz;
1661 int height_c = tru_seqlen + seqlen_entry;
1662 int pitch_c = hiddensize;
1663 int start_x_c = headid * hdsz
1664 + gemm_op_128x64_t::get_matC_offset_x(g_thd32_tid);
1665 int start_y_c = all_vert128_loop * 128 + seqlen_entry
1666 + gemm_op_128x64_t::get_matC_offset_y(g_thd32_tid);
1667
1668 matC_128x64_payload.init(args->matdQ_ptr, width_c, height_c,
1669 pitch_c, start_x_c, start_y_c);
1670 subgroup::elemwise_cvt<matC_128x64_t, matAcc_128x64_t>(
1671 matC_128x64, matAcc_128x64);
1672 subgroup::tile_store(matC_128x64, matC_128x64_payload);
1673 }
1674 } //all_vert128_loop
1675
1676 for (int transp256_loop = 0; transp256_loop < transp256_loop_num;
1677 transp256_loop++) {
1678 gemm_arguments_256x64_trnp_a gemm_arg_256x64;
1679 matAcc_256x64_trnp_a_t matAcc_256x64;
1680 matC_256x64_trnp_a_t matC_256x64;
1681 matC_256x64_trnp_a_payload_t matC_256x64_payload;
1682
1683 uint32_t width_a = tru_seqlen_ex;
1684 uint32_t height_a
1685 = (batchid * numhead + headid) * max_seqlen + tru_seqlen;
1686 uint32_t pitch_a = max_seqlen;
1687 int start_x_a = transp256_loop * 256;
1688 int start_y_a = (batchid * numhead + headid) * max_seqlen;
1689
1690 gemm_arg_256x64.matA_base_desc.init({args->matdW_ptr},
1691 {height_a, width_a, pitch_a}, {start_y_a, start_x_a});
1692
1693 uint32_t width_b = (headid + 1) * hdsz;
1694 uint32_t height_b = tru_seqlen + seqlen_entry;
1695 uint32_t pitch_b = hiddensize;
1696 int start_x_b = headid * hdsz;
1697 int start_y_b = seqlen_entry;
1698
1699 gemm_arg_256x64.matB_base_desc.init({args->matQ_ptr},
1700 {width_b, height_b, pitch_b}, {start_x_b, start_y_b});
1701
1702 gemm_arg_256x64.inner_loop_count
1703 = (wg_tile_out_k + k_stride - 1) / k_stride;
1704
1705 matAcc_256x64.init(0);
1706 gemm_op_256x64_trnp_a_t gemm_op_256x64_trnp_a;
1707 gemm_op_256x64_trnp_a(g_thd32_tid, matAcc_256x64, gemm_arg_256x64);
1708
1709 int width_c = (headid + 1) * hdsz;
1710 int height_c = tru_seqlen + seqlen_entry;
1711 int pitch_c = hiddensize;
1712 int start_x_c = headid * hdsz
1713 + gemm_op_256x64_trnp_a_t::get_matC_offset_x(g_thd32_tid);
1714 int start_y_c = transp256_loop * 256 + seqlen_entry
1715 + gemm_op_256x64_trnp_a_t::get_matC_offset_y(g_thd32_tid);
1716
1717 matC_256x64_payload.init(args->matdK_ptr, width_c, height_c,
1718 pitch_c, start_x_c, start_y_c);
1720 matAcc_256x64_trnp_a_t>(matC_256x64, matAcc_256x64);
1721 subgroup::tile_store(matC_256x64, matC_256x64_payload);
1722
1723 all_nbarr.arrive();
1724 all_nbarr.wait();
1725 }
1726
1727 for (int transp128_loop = 0; transp128_loop < transp128_loop_num;
1728 transp128_loop++) {
1729 gemm_arguments_128x64_trnp_a gemm_arg_128x64;
1730 matAcc_128x64_trnp_a_t matAcc_128x64;
1731 matC_128x64_trnp_a_t matC_128x64;
1732 matC_128x64_trnp_a_payload_t matC_128x64_payload;
1733
1734 uint32_t width_a = tru_seqlen_ex;
1735 uint32_t height_a
1736 = (batchid * numhead + headid) * max_seqlen + tru_seqlen;
1737 uint32_t pitch_a = max_seqlen;
1738 int start_x_a = transp128_loop * 128 + offset_blk_128x128;
1739 int start_y_a = (batchid * numhead + headid) * max_seqlen;
1740
1741 gemm_arg_128x64.matA_base_desc.init({args->matdW_ptr},
1742 {height_a, width_a, pitch_a}, {start_y_a, start_x_a});
1743
1744 uint32_t width_b = (headid + 1) * hdsz;
1745 uint32_t height_b = tru_seqlen + seqlen_entry;
1746 uint32_t pitch_b = hiddensize;
1747 int start_x_b = headid * hdsz;
1748 int start_y_b = seqlen_entry;
1749
1750 gemm_arg_128x64.matB_base_desc.init({args->matQ_ptr},
1751 {width_b, height_b, pitch_b}, {start_x_b, start_y_b});
1752
1753 gemm_arg_128x64.inner_loop_count
1754 = (wg_tile_out_k + k_stride - 1) / k_stride;
1755
1756 matAcc_128x64.init(0);
1757
1758 gemm_op_128x64_trnp_a_t gemm_op_128x64_trnp_a;
1759 gemm_op_128x64_trnp_a(g_thd32_tid, matAcc_128x64, gemm_arg_128x64);
1760
1761 int width_c = (headid + 1) * hdsz;
1762 int height_c = tru_seqlen + seqlen_entry;
1763 int pitch_c = hiddensize;
1764 int start_x_c = headid * hdsz
1765 + gemm_op_128x64_trnp_a_t::get_matC_offset_x(g_thd32_tid);
1766 int start_y_c = transp128_loop * 128 + seqlen_entry
1767 + offset_blk_128x128
1768 + gemm_op_128x64_trnp_a_t::get_matC_offset_y(g_thd32_tid);
1769
1770 matC_128x64_payload.init(args->matdK_ptr, width_c, height_c,
1771 pitch_c, start_x_c, start_y_c);
1773 matAcc_128x64_trnp_a_t>(matC_128x64, matAcc_128x64);
1774 subgroup::tile_store(matC_128x64, matC_128x64_payload);
1775
1776 all_nbarr.arrive();
1777 all_nbarr.wait();
1778 } //transp128_loop
1779
1780 } //xetla_softmax_bwd_t::call
1781}; //struct xetla_softmax_bwd_t
1782
1783} // namespace gpu::xetla::kernel
Gemm functor.
Definition api.hpp:52
Definition limitation.hpp:738
#define __XETLA_API
Definition common.hpp:43
#define xetla_merge
xetla merge.
Definition base_ops.hpp:60
__ESIMD_NS::simd_mask< N > xetla_mask_int
wrapper for xetla_mask_int.
Definition base_types.hpp:172
__ESIMD_NS::simd< native_type_t< Ty >, N > xetla_vector
wrapper for xetla_vector.
Definition base_types.hpp:149
__ESIMD_NS::simd_mask< N > xetla_mask
wrapper for xetla_mask.
Definition base_types.hpp:165
__XETLA_API xetla_vector< Ty, N *NElts > xetla_load_global(Ty *p, xetla_vector< Toffset, N > offsets, xetla_mask< N > pred=1)
Stateless scattered load.
Definition memory.hpp:245
__XETLA_API xetla_vector< uint32_t, 4 > get_time_stamp()
Returns time stamp.
Definition misc.hpp:57
#define rand_threshold_const
Definition mha_attn_reg.hpp:26
#define list_width
Definition mha_attn_reg.hpp:25
#define SIGN_BIT_W16
Definition mha_attn_reg.hpp:28
#define SIGN_BIT_B8
Definition mha_attn_reg.hpp:29
Definition limitation.hpp:734
__XETLA_API std::enable_if_t<(T_src::register_layout !=reg_layout::linear) &&(T_dst::register_layout !=reg_layout::linear) &&is_same_layout< T_dst, T_src >::value &&(!is_floating_to_integer< T_dst, T_src >::value)> elemwise_cvt(T_dst &dst, T_src &src)
Performs element-wise data conversion; the src and dst tiles should have the same layout.
Definition op_function.hpp:40
__XETLA_API std::enable_if_t< detail::check_store_type< tile_t, payload_t >::is_global_2d_xe > tile_store(tile_t &tile, payload_t &payload)
Stores data from the register file to global memory.
Definition store_xe.hpp:91
mem_space
Definition common.hpp:77
gpu_arch
Definition common.hpp:73
msg_type
Definition common.hpp:78
mem_layout
Definition common.hpp:76
Compute attribute for gemm.
Definition common.hpp:32
Compute policy for xmx engine.
Definition compute_policy.hpp:35
Fine-tune knobs for gemm.
Definition common.hpp:43
Gemm default pre_processing functor.
Definition api.hpp:33
Gemm pre_processing functor that applies a relu op to matA.
Definition api.hpp:39
Workgroup level tile shape description.
Definition tile_shape.hpp:34
Arguments for xetla_softmax_bwd_t::run.
Definition mha_core_attn.hpp:1130
uint32_t * matMkin_ptr
Definition mha_core_attn.hpp:1136
float Scaling
Definition mha_core_attn.hpp:1145
dtype_bin * matdO_ptr
Definition mha_core_attn.hpp:1138
dtype_sfx * matdW_ptr
Definition mha_core_attn.hpp:1140
dtype_bot * matdQ_ptr
Definition mha_core_attn.hpp:1142
dtype_bin * matK_ptr
Definition mha_core_attn.hpp:1134
dtype_bin * matV_ptr
Definition mha_core_attn.hpp:1135
dtype_sfx * matW_ptr
Definition mha_core_attn.hpp:1139
float Pinv
Definition mha_core_attn.hpp:1144
dtype_bin * matQ_ptr
Definition mha_core_attn.hpp:1133
uint32_t * mList_ptr
Definition mha_core_attn.hpp:1132
dtype_bot * matdK_ptr
Definition mha_core_attn.hpp:1143
dtype_bot * matdV_ptr
Definition mha_core_attn.hpp:1141
uint32_t * matMkdpot_ptr
Definition mha_core_attn.hpp:1137
Definition mha_core_attn.hpp:872
mem_desc_t< dtype_sfx, gemm_mem_layout_a, gemm_mem_space_a > mem_desc_a_out
Definition mha_core_attn.hpp:920
subgroup::tile_desc_t< matAcc_128x128_t::tile_desc::tile_size_x, matAcc_128x128_t::tile_desc::tile_size_y, matAcc_128x128_t::tile_desc::block_size_x, matAcc_128x128_t::tile_desc::block_size_y, reg_layout::tiled > matC_128x128_tile_desc_t
Definition mha_core_attn.hpp:1003
static constexpr int ThreadNum
Definition mha_core_attn.hpp:878
typename gemm_op_128x128_t::arguments_t gemm_arguments_128x128
Definition mha_core_attn.hpp:978
static constexpr mem_layout gemm_mem_layout_out_b
Definition mha_core_attn.hpp:897
typename gemm_op_128x64_t::arguments_t gemm_arguments_128x64
Definition mha_core_attn.hpp:980
typename gemm_op_128x64_trnp_a_t::matAcc_t matAcc_128x64_trnp_a_t
Definition mha_core_attn.hpp:993
subgroup::tile_t< dtype_bot, matC_128x64_trnp_af_tile_desc_t > matC_128x64_trnp_af_t
Definition mha_core_attn.hpp:1048
group::compute_policy_default_xmx< group::compute_attr_t< dtype_bin, dtype_bin, dtype_acc >, bgm_perf_tuning_knob, gpu_arch::Xe > compute_policy_QKT
Definition mha_core_attn.hpp:917
static constexpr mem_layout mem_layout_QKT_b
Definition mha_core_attn.hpp:886
static constexpr mem_layout mem_layout_a
Definition mha_core_attn.hpp:884
static constexpr mem_layout mem_layout_trnp_a
Definition mha_core_attn.hpp:885
mem_desc_t< dtype_bin, gemm_mem_layout_a, gemm_mem_space_a > mem_desc_a_QKT
Definition mha_core_attn.hpp:912
typename gemm_op_128x64_t::matAcc_t matAcc_128x64_t
Definition mha_core_attn.hpp:992
subgroup::tile_desc_t< matAcc_256x64_trnp_a_t::tile_desc::tile_size_x, matAcc_256x64_trnp_a_t::tile_desc::tile_size_y, matAcc_256x64_trnp_a_t::tile_desc::block_size_x, matAcc_256x64_trnp_a_t::tile_desc::block_size_y, reg_layout::tiled > matC_256x64_trnp_a_tile_desc_t
Definition mha_core_attn.hpp:1025
group::perf_tuning_knob_t< k_stride, prefetch_distance, periodic_sync_interval > bgm_perf_tuning_knob
Definition mha_core_attn.hpp:905
typename gemm_op_128x128_t::matAcc_t matAcc_128x128_t
Definition mha_core_attn.hpp:990
group::compute_policy_default_xmx< group::compute_attr_t< dtype_sfx, dtype_bin, dtype_acc >, bgm_perf_tuning_knob, gpu_arch::Xe > compute_policy_out
Definition mha_core_attn.hpp:925
static constexpr uint32_t periodic_sync_interval
Definition mha_core_attn.hpp:899
gpu::xetla::subgroup::tile_desc_t< 64/sfx_type_size, 8 *sfx_type_size, 64/sfx_type_size, 8 *sfx_type_size, reg_layout::tiled > matElem_tile_desc_t
Definition mha_core_attn.hpp:1112
static constexpr mem_layout gemm_mem_layout_QKT_b
Definition mha_core_attn.hpp:896
typename gemm_op_128x64_trnp_af_t::matAcc_t matAcc_128x64_trnp_af_t
Definition mha_core_attn.hpp:995
dtype_bwd_acc_ dtype_acc
Definition mha_core_attn.hpp:876
subgroup::tile_t< dtype_bot, matC_256x64_trnp_af_tile_desc_t > matC_256x64_trnp_af_t
Definition mha_core_attn.hpp:1050
static constexpr mem_space gemm_mem_space_b
Definition mha_core_attn.hpp:895
typename gemm_op_128x256_t::matAcc_t matAcc_128x256_t
Definition mha_core_attn.hpp:991
mem_desc_t< dtype_bin, gemm_mem_layout_out_b, gemm_mem_space_b > mem_desc_b_out_b_trnp_a
Definition mha_core_attn.hpp:930
static __XETLA_API void call(sycl::nd_item< 3 > &item, arguments_t *args)
Main execution function for fused mha softmax.
Definition mha_core_attn.hpp:1154
static constexpr uint16_t sfx_type_size
Definition mha_core_attn.hpp:936
mem_desc_t< dtype_bin, gemm_mem_layout_out_b, gemm_mem_space_b > mem_desc_b_out
Definition mha_core_attn.hpp:922
static constexpr mem_space gemm_mem_space_a
Definition mha_core_attn.hpp:890
static constexpr uint32_t k_stride
Definition mha_core_attn.hpp:903
group::tile_shape_t< 256, 128, 32, 32 > tile_attr_128x256
Definition mha_core_attn.hpp:907
typename gemm_op_128x64_trnp_af_t::arguments_t gemm_arguments_128x64_trnp_af
Definition mha_core_attn.hpp:986
static constexpr mem_layout gemm_mem_layout_trnp_a
Definition mha_core_attn.hpp:893
group::tile_shape_t< 64, 256, 16, 32 > tile_attr_256x64
Definition mha_core_attn.hpp:908
static constexpr mem_space mem_space_c
Definition mha_core_attn.hpp:882
typename gemm_op_256x64_trnp_a_t::arguments_t gemm_arguments_256x64_trnp_a
Definition mha_core_attn.hpp:984
typename gemm_op_256x64_trnp_af_t::matAcc_t matAcc_256x64_trnp_af_t
Definition mha_core_attn.hpp:996
subgroup::tile_t< dtype_bot, matC_128x64_trnp_a_tile_desc_t > matC_128x64_trnp_a_t
Definition mha_core_attn.hpp:1044
static constexpr mem_space gemm_mem_space_trnp_a
Definition mha_core_attn.hpp:891
static constexpr mem_layout gemm_mem_layout_a
Definition mha_core_attn.hpp:892
subgroup::tile_desc_t< matAcc_256x64_trnp_af_t::tile_desc::tile_size_x, matAcc_256x64_trnp_af_t::tile_desc::tile_size_y, matAcc_256x64_trnp_af_t::tile_desc::block_size_x, matAcc_256x64_trnp_af_t::tile_desc::block_size_y, reg_layout::tiled > matC_256x64_trnp_af_tile_desc_t
Definition mha_core_attn.hpp:1037
typename gemm_op_128x256_t::arguments_t gemm_arguments_128x256
Definition mha_core_attn.hpp:979
dtype_bwd_sfx_ dtype_sfx
Definition mha_core_attn.hpp:875
subgroup::tile_desc_t< matAcc_128x64_trnp_a_t::tile_desc::tile_size_x, matAcc_128x64_trnp_a_t::tile_desc::tile_size_y, matAcc_128x64_trnp_a_t::tile_desc::block_size_x, matAcc_128x64_trnp_a_t::tile_desc::block_size_y, reg_layout::tiled > matC_128x64_trnp_a_tile_desc_t
Definition mha_core_attn.hpp:1020
typename gemm_op_128x64_trnp_a_t::arguments_t gemm_arguments_128x64_trnp_a
Definition mha_core_attn.hpp:982
group::compute_policy_default_xmx< group::compute_attr_t< dtype_sfx, dtype_bin, dtype_acc >, bgm_perf_tuning_knob, gpu_arch::Xe > compute_policy_out_b_trnp_a
Definition mha_core_attn.hpp:933
static constexpr mem_space mem_space_a
Definition mha_core_attn.hpp:880
static constexpr uint32_t prefetch_distance
Definition mha_core_attn.hpp:900
dtype_bwd_bin_ dtype_bin
Definition mha_core_attn.hpp:873
static constexpr mem_space mem_space_b
Definition mha_core_attn.hpp:881
work_group_t< ThreadNum > work_group_t
Definition mha_core_attn.hpp:940
subgroup::tile_desc_t< matAcc_128x64_t::tile_desc::tile_size_x, matAcc_128x64_t::tile_desc::tile_size_y, matAcc_128x64_t::tile_desc::block_size_x, matAcc_128x64_t::tile_desc::block_size_y, reg_layout::tiled > matC_128x64_tile_desc_t
Definition mha_core_attn.hpp:1015
mem_desc_t< dtype_bin, gemm_mem_layout_QKT_b, gemm_mem_space_b > mem_desc_b_QKT
Definition mha_core_attn.hpp:914
static constexpr mem_layout mem_layout_c
Definition mha_core_attn.hpp:888
subgroup::tile_desc_t< matAcc_128x256_t::tile_desc::tile_size_x, matAcc_128x256_t::tile_desc::tile_size_y, matAcc_128x256_t::tile_desc::block_size_x, matAcc_128x256_t::tile_desc::block_size_y, reg_layout::tiled > matC_128x256_tile_desc_t
Definition mha_core_attn.hpp:1009
mem_desc_t< dtype_sfx, gemm_mem_layout_trnp_a, gemm_mem_space_trnp_a > mem_desc_a_out_b_trnp_a
Definition mha_core_attn.hpp:928
dtype_bwd_bot_ dtype_bot
Definition mha_core_attn.hpp:874
subgroup::tile_t< dtype_bot, matC_256x64_trnp_a_tile_desc_t > matC_256x64_trnp_a_t
Definition mha_core_attn.hpp:1046
typename gemm_op_256x64_trnp_af_t::arguments_t gemm_arguments_256x64_trnp_af
Definition mha_core_attn.hpp:988
subgroup::tile_desc_t< matAcc_128x64_trnp_af_t::tile_desc::tile_size_x, matAcc_128x64_trnp_af_t::tile_desc::tile_size_y, matAcc_128x64_trnp_af_t::tile_desc::block_size_x, matAcc_128x64_trnp_af_t::tile_desc::block_size_y, reg_layout::tiled > matC_128x64_trnp_af_tile_desc_t
Definition mha_core_attn.hpp:1031
group::tile_shape_t< 64, 128, 16, 16 > tile_attr_128x64
Definition mha_core_attn.hpp:909
static constexpr uint32_t global_kslicing
Definition mha_core_attn.hpp:935
typename gemm_op_256x64_trnp_a_t::matAcc_t matAcc_256x64_trnp_a_t
Definition mha_core_attn.hpp:994
static constexpr mem_layout mem_layout_out_b
Definition mha_core_attn.hpp:887
group::tile_shape_t< 128, 128, 32, 16 > tile_attr_128x128
Definition mha_core_attn.hpp:906
Arguments for xetla_softmax_fwd_t::run.
Definition mha_core_attn.hpp:193
dtype_bin * matK_ptr
Definition mha_core_attn.hpp:197
uint32_t * matMkin_ptr
Definition mha_core_attn.hpp:199
uint32_t * mList_ptr
Definition mha_core_attn.hpp:195
dtype_sfx * matQKT_ptr
Definition mha_core_attn.hpp:201
dtype_bin * matV_ptr
Definition mha_core_attn.hpp:198
dtype_bot * matOut_ptr
Definition mha_core_attn.hpp:202
dtype_bin * matQ_ptr
Definition mha_core_attn.hpp:196
uint32_t * matMkdpot_ptr
Definition mha_core_attn.hpp:200
float Scaling
Definition mha_core_attn.hpp:204
Definition mha_core_attn.hpp:44
subgroup::tile_desc_t< matAcc_128x64_t::tile_desc::tile_size_x, matAcc_128x64_t::tile_desc::tile_size_y, matAcc_128x64_t::tile_desc::block_size_x, matAcc_128x64_t::tile_desc::block_size_y, reg_layout::tiled > matC_128x64_tile_desc_t
Definition mha_core_attn.hpp:145
mem_desc_t< dtype_bin, gemm_mem_layout_a, gemm_mem_space_a > mem_desc_a_QKT
Definition mha_core_attn.hpp:81
typename gemm_op_128x128_t::matAcc_t matAcc_128x128_t
Definition mha_core_attn.hpp:124
static constexpr int max_seqlen
Definition mha_core_attn.hpp:51
typename gemm_op_128x256_t::matAcc_t matAcc_128x256_t
Definition mha_core_attn.hpp:125
static constexpr uint16_t Rand_SIMD
Definition mha_core_attn.hpp:55
static constexpr mem_layout mem_layout_out_b
Definition mha_core_attn.hpp:59
static constexpr uint32_t global_kslicing
Definition mha_core_attn.hpp:96
static constexpr mem_layout mem_layout_QKT_b
Definition mha_core_attn.hpp:58
group::perf_tuning_knob_t< k_stride, prefetch_distance, periodic_sync_interval > bgm_perf_tuning_knob
Definition mha_core_attn.hpp:74
typename gemm_op_128x128_t::arguments_t gemm_arguments_128x128
Definition mha_core_attn.hpp:120
subgroup::tile_desc_t< matAcc_128x128_t::tile_desc::tile_size_x, matAcc_128x128_t::tile_desc::tile_size_y, matAcc_128x128_t::tile_desc::block_size_x, matAcc_128x128_t::tile_desc::block_size_y, reg_layout::tiled > matC_128x128_tile_desc_t
Definition mha_core_attn.hpp:133
group::compute_policy_default_xmx< group::compute_attr_t< dtype_bin, dtype_bin, dtype_acc >, bgm_perf_tuning_knob, gpu_arch::Xe > compute_policy_QKT
Definition mha_core_attn.hpp:86
mem_desc_t< dtype_bin, gemm_mem_layout_QKT_b, gemm_mem_space_b > mem_desc_b_QKT
Definition mha_core_attn.hpp:83
static constexpr uint32_t periodic_sync_interval
Definition mha_core_attn.hpp:69
typename gemm_op_128x256_t::arguments_t gemm_arguments_128x256
Definition mha_core_attn.hpp:121
static constexpr mem_space mem_space_b
Definition mha_core_attn.hpp:53
static constexpr mem_layout gemm_mem_layout_QKT_b
Definition mha_core_attn.hpp:66
work_group_t< ThreadNum > work_group_t
Definition mha_core_attn.hpp:101
dtype_bot_ dtype_bot
Definition mha_core_attn.hpp:46
dtype_acc_ dtype_acc
Definition mha_core_attn.hpp:48
static constexpr uint32_t k_stride
Definition mha_core_attn.hpp:72
static constexpr mem_space gemm_mem_space_b
Definition mha_core_attn.hpp:65
group::compute_policy_default_xmx< group::compute_attr_t< dtype_sfx, dtype_bin, dtype_acc >, bgm_perf_tuning_knob, gpu_arch::Xe > compute_policy_out
Definition mha_core_attn.hpp:94
dtype_sfx_ dtype_sfx
Definition mha_core_attn.hpp:47
static constexpr int ThreadNum
Definition mha_core_attn.hpp:50
static constexpr mem_layout mem_layout_a
Definition mha_core_attn.hpp:57
mem_desc_t< dtype_bin, gemm_mem_layout_out_b, gemm_mem_space_b > mem_desc_b_out
Definition mha_core_attn.hpp:91
gpu::xetla::subgroup::tile_desc_t< 64/sfx_type_size, 8 *sfx_type_size, 64/sfx_type_size, 8 *sfx_type_size, reg_layout::tiled > matElem_tile_desc_t
Definition mha_core_attn.hpp:171
group::tile_shape_t< 128, 128, 32, 16 > tile_attr_128x128
Definition mha_core_attn.hpp:76
static constexpr mem_layout gemm_mem_layout_out_b
Definition mha_core_attn.hpp:67
static constexpr mem_space gemm_mem_space_a
Definition mha_core_attn.hpp:62
static constexpr uint16_t sfx_type_size
Definition mha_core_attn.hpp:97
group::tile_shape_t< 256, 128, 32, 32 > tile_attr_128x256
Definition mha_core_attn.hpp:77
typename gemm_op_128x64_t::arguments_t gemm_arguments_128x64
Definition mha_core_attn.hpp:122
static constexpr uint32_t prefetch_distance
Definition mha_core_attn.hpp:70
typename gemm_op_128x64_t::matAcc_t matAcc_128x64_t
Definition mha_core_attn.hpp:126
static constexpr mem_space mem_space_c
Definition mha_core_attn.hpp:54
static constexpr mem_layout mem_layout_c
Definition mha_core_attn.hpp:60
subgroup::tile_desc_t< matAcc_128x256_t::tile_desc::tile_size_x, matAcc_128x256_t::tile_desc::tile_size_y, matAcc_128x256_t::tile_desc::block_size_x, matAcc_128x256_t::tile_desc::block_size_y, reg_layout::tiled > matC_128x256_tile_desc_t
Definition mha_core_attn.hpp:139
dtype_bin_ dtype_bin
Definition mha_core_attn.hpp:45
static constexpr mem_layout gemm_mem_layout_a
Definition mha_core_attn.hpp:63
static __XETLA_API void call(sycl::nd_item< 3 > &item, arguments_t *args)
Main execution function for fused mha softmax.
Definition mha_core_attn.hpp:213
mem_desc_t< dtype_sfx, gemm_mem_layout_a, gemm_mem_space_a > mem_desc_a_out
Definition mha_core_attn.hpp:89
group::tile_shape_t< 64, 128, 16, 16 > tile_attr_128x64
Definition mha_core_attn.hpp:78
static constexpr mem_space mem_space_a
Definition mha_core_attn.hpp:52
Definition memory_descriptor.hpp:139
Describes the memory information.
Definition api.hpp:44
Describes the tile information of a sub-matrix.
Definition api.hpp:64
A struct that contains some register file.
Definition api.hpp:99
xetla_vector< dtype, tile_desc::tile_elems > reg
Definition api.hpp:102
XeTLA nbarrier definition API.
Definition raw_send_nbarrier.hpp:43
__XETLA_API void arrive()
Named barrier signal from subgroup.
Definition raw_send_nbarrier.hpp:65
__XETLA_API void init_nbarrier(uint8_t nbarrier_id, nbarrier_role role=nbarrier_role::producer_consumer)
Definition raw_send_nbarrier.hpp:55
__XETLA_API void wait()
Named barrier wait within subgroup.
Definition raw_send_nbarrier.hpp:76
Definition rand.hpp:30
__XETLA_API xetla_vector< uint32_t, 4 *SIMD > rand()
Definition rand.hpp:57
__XETLA_API void init(uint64_t seed, uint64_t subseq, uint64_t offset)
Definition rand.hpp:38