// none op functor, for placeholder purpose: leaves the accumulator tile untouched.
template <typename matAcc_t, typename coord_t>
__XETLA_API KERNEL_FUNC void operator()([[maybe_unused]] matAcc_t &matAcc,
        [[maybe_unused]] const coord_t &coord,
        [[maybe_unused]] const arguments_t &args,
        [[maybe_unused]] uint32_t slm_base = 0,
        [[maybe_unused]] uint32_t nbarrier_base = 0) {}
// element-wise relu op functor: zero out every non-positive element of the tile.
template <typename matAcc_t, typename coord_t>
__XETLA_API KERNEL_FUNC void operator()(matAcc_t &matAcc,
        [[maybe_unused]] const coord_t &coord,
        [[maybe_unused]] const arguments_t &args,
        [[maybe_unused]] uint32_t slm_base = 0,
        [[maybe_unused]] uint32_t nbarrier_base = 0) {
    // ... (mask flags the elements of matAcc.reg that are <= 0)
    matAcc.reg.xetla_merge(0, mask);
}
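// Reference sketch (plain C++, not XeTLA): the merge above writes 0 into every
// lane whose mask bit is set; shown here on a small row-major array. The
// function name and signature are illustrative only.
#include <cstddef>
inline void relu_merge_ref(float *reg, std::size_t n) {
    for (std::size_t i = 0; i < n; ++i) {
        bool mask = reg[i] <= 0.0f; // lanes to overwrite
        if (mask) reg[i] = 0.0f;    // xetla_merge(0, mask)
    }
}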
// element-wise tanh op functor: apply tanh to the tile, one block of elements
// at a time.
template <typename matAcc_t, typename coord_t>
__XETLA_API KERNEL_FUNC void operator()(matAcc_t &matAcc,
        [[maybe_unused]] const coord_t &coord,
        [[maybe_unused]] uint32_t slm_base = 0,
        [[maybe_unused]] uint32_t nbarrier_base = 0) {
    constexpr int elems = matAcc_t::tile_desc::block_elems;
    constexpr int rounds = matAcc_t::tile_desc::tile_elems / elems;
    using dtype = typename matAcc_t::dtype;
    for (uint32_t i = 0; i < rounds; ++i) {
        auto sub_vec = matAcc.reg.xetla_select<elems, 1>(elems * i);
        sub_vec = xetla_tanh<dtype, elems>(sub_vec);
    }
    // tail when tile_elems is not a multiple of block_elems
    constexpr int remained_elems = matAcc_t::tile_desc::tile_elems % elems;
    if constexpr (remained_elems != 0) {
        auto sub_vec = matAcc.reg.xetla_select<remained_elems, 1>(
                elems * (matAcc_t::tile_elems / elems));
        sub_vec = xetla_tanh<dtype, remained_elems>(sub_vec);
    }
}
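// Reference sketch (plain C++, not XeTLA): the block/remainder pattern used
// above, written as scalar loops -- process the tile in chunks of `elems`,
// then handle the tail. Names and signature are illustrative only.
#include <cmath>
#include <cstddef>
inline void tanh_tile_ref(float *tile, std::size_t tile_elems,
        std::size_t elems) {
    std::size_t rounds = tile_elems / elems;
    for (std::size_t i = 0; i < rounds; ++i)
        for (std::size_t k = 0; k < elems; ++k)
            tile[i * elems + k] = std::tanh(tile[i * elems + k]);
    // tail: tile_elems is not necessarily a multiple of elems
    for (std::size_t k = rounds * elems; k < tile_elems; ++k)
        tile[k] = std::tanh(tile[k]);
}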
// element-wise sigmoid op functor: sigmoid(x) = exp(x) / (exp(x) + 1), with the
// result forced to 1 on the lanes where the input is large enough to saturate.
template <typename matAcc_t, typename coord_t>
__XETLA_API KERNEL_FUNC void operator()(matAcc_t &matAcc,
        [[maybe_unused]] const coord_t &coord,
        [[maybe_unused]] uint32_t slm_base = 0,
        [[maybe_unused]] uint32_t nbarrier_base = 0) {
    constexpr int elems = matAcc_t::tile_desc::block_elems;
    constexpr int rounds = matAcc_t::tile_desc::tile_elems / elems;
    constexpr float one = 1.0f;
    for (uint32_t i = 0; i < rounds; ++i) {
        auto sub_vec = matAcc.reg.xetla_select<elems, 1>(elems * i);
        // ... (mask flags the lanes whose sigmoid has already saturated to 1)
        auto temp_vec
                = xetla_exp<typename matAcc_t::dtype, elems>(sub_vec);
        auto sigmoid_value
                = temp_vec / (temp_vec + one);
        sigmoid_value.xetla_merge(1, mask);
        sub_vec = sigmoid_value;
    }
    // tail when tile_elems is not a multiple of block_elems
    constexpr int remained_elems = matAcc_t::tile_desc::tile_elems % elems;
    if constexpr (remained_elems != 0) {
        auto sub_vec = matAcc.reg.xetla_select<remained_elems, 1>(
                elems * (matAcc_t::tile_elems / elems));
        // ... (same mask on the tail lanes)
        auto temp_vec
                = xetla_exp<typename matAcc_t::dtype, remained_elems>(sub_vec);
        auto sigmoid_value
                = temp_vec / (temp_vec + one);
        sigmoid_value.xetla_merge(1, mask);
        sub_vec = sigmoid_value;
    }
}
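// Reference sketch (plain C++, not XeTLA): the sigmoid above computes
// exp(x) / (exp(x) + 1) and overwrites the lanes where the input is large
// enough that the result has already saturated to 1. The saturation threshold
// used here is an assumption of this sketch, not taken from the source.
#include <cmath>
inline float sigmoid_ref(float x, float saturate_at = 10.0f) {
    if (x >= saturate_at) return 1.0f; // the xetla_merge(1, mask) path
    float e = std::exp(x);
    return e / (e + 1.0f);
}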
// element-wise gelu inference forward op functor: gelu(x) computed in place
// with the tanh approximation
//   gelu(x) = 0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3)))
template <typename matAcc_t, typename coord_t>
__XETLA_API KERNEL_FUNC void operator()(matAcc_t &matAcc,
        [[maybe_unused]] const coord_t &coord,
        [[maybe_unused]] const arguments_t &args,
        [[maybe_unused]] uint32_t slm_base = 0,
        [[maybe_unused]] uint32_t nbarrier_base = 0) {
    using dtype = typename matAcc_t::dtype;
    constexpr dtype C0 = 0.044715f;
    constexpr dtype sqrt_two_over_pi = 0.79788458347320556640625f;
    // process the tile in fixed chunks of 8 * 16 elements
    constexpr int elems = 8 * 16;
    constexpr int rounds = matAcc_t::tile_elems / elems;
    for (uint32_t i = 0; i < rounds; ++i) {
        auto sub_vec = matAcc.reg.xetla_select<elems, 1>(elems * i);
        xetla_vector<dtype, elems> sub_vec_x = (sqrt_two_over_pi * sub_vec
                * (1.f + C0 * sub_vec * sub_vec));
        xetla_vector<dtype, elems> tanh_value
                = xetla_tanh<dtype, elems>(sub_vec_x);
        sub_vec = 0.5f * sub_vec * (1.f + tanh_value);
    }
    // tail when tile_elems is not a multiple of the chunk size
    constexpr int remained_elems = matAcc_t::tile_elems % elems;
    if constexpr (remained_elems != 0) {
        auto sub_vec = matAcc.reg.xetla_select<remained_elems, 1>(
                elems * (matAcc_t::tile_elems / elems));
        xetla_vector<dtype, remained_elems> sub_vec_x = (sqrt_two_over_pi
                * sub_vec * (1.f + C0 * sub_vec * sub_vec));
        xetla_vector<dtype, remained_elems> tanh_value
                = xetla_tanh<dtype, remained_elems>(sub_vec_x);
        sub_vec = 0.5f * sub_vec * (1.f + tanh_value);
    }
}
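// Reference sketch (plain C++, not XeTLA): the tanh-approximation GELU applied
// above, written as a scalar function with the same constants. The function
// name is illustrative only.
#include <cmath>
inline float gelu_tanh_ref(float x) {
    constexpr float c0 = 0.044715f;
    constexpr float sqrt_two_over_pi = 0.79788458347320556640625f;
    float inner = sqrt_two_over_pi * (x + c0 * x * x * x);
    return 0.5f * x * (1.0f + std::tanh(inner));
}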
// element-wise gelu training forward op functor: computes gelu(x) in place and
// stores w = d gelu(x)/dx (converted to dtype_out) to memory for the backward
// pass.
template <typename dtype_out, gpu_arch arch_tag, class enable = void>
struct /*...*/; // primary template

template <typename dtype_out_, gpu_arch arch_tag>
struct /*...*/ <dtype_out_, arch_tag,
        std::enable_if_t<(arch_tag == gpu_arch::Xe)>> {
    // ...
    using shape_t = typename mem_desc_w_t::shape_t;
    using coord_t = typename mem_desc_w_t::coord_t;
    using base_t = typename mem_desc_w_t::base_t;

    struct arguments_t {
        // ...
        arguments_t(base_t base_, shape_t shape_)
            : shape(shape_), base(base_) {}
    };

    template <typename matAcc_t>
    __XETLA_API KERNEL_FUNC void operator()(matAcc_t &matAcc,
            const coord_t &coord, const arguments_t &args,
            [[maybe_unused]] uint32_t slm_base = 0,
            [[maybe_unused]] uint32_t nbarrier_base = 0) {
        using dtype_acc = typename matAcc_t::dtype;
        static constexpr uint32_t tile_size_x = matAcc_t::tile_size_x;
        static constexpr uint32_t tile_size_y = matAcc_t::tile_size_y;
        static constexpr uint32_t block_size_x = matAcc_t::block_size_x;
        static constexpr uint32_t block_size_y = matAcc_t::block_size_y;

        using bwd_w_tile_desc_t = tile_desc_t<block_size_x, block_size_y,
                /*...*/>;
        // ... (bwd_w tile, bwd_w_payload_t, and mem_desc_w built from
        //      args.base / args.shape / coord)
        bwd_w_payload_t bwd_w_payload(mem_desc_w);

        constexpr dtype_acc c0 = 0.044715f;
        constexpr dtype_acc d0 = 0.134145f; // 3 * c0
        constexpr dtype_acc sqrt_two_over_pi = 0.79788458347320556640625f;
        constexpr uint32_t block_elems = matAcc_t::block_elems;
        constexpr uint32_t num_block_x = matAcc_t::num_block_x;

        for (uint32_t i = 0; i < tile_size_y / block_size_y; ++i) {
            for (uint32_t j = 0; j < num_block_x; ++j) {
                auto x = matAcc.reg.xetla_select<block_elems, 1>(
                        block_elems * (i * num_block_x + j));
                xetla_vector<dtype_acc, block_elems> z
                        = xetla_tanh<dtype_acc, block_elems>(
                                sqrt_two_over_pi * (x + c0 * x * x * x));
                // w = d gelu(x)/dx for the tanh approximation
                xetla_vector<dtype_acc, block_elems> w = (0.5f * (1.f + z)
                        + 0.5f * x * (1.f - z * z)
                                * (sqrt_two_over_pi * (1.f + d0 * x * x)));
                x = 0.5f * x * (1.f + z);
                bwd_w.reg = xetla_cvt<dtype_out, dtype_acc, block_elems>(w);
                tile_store<cache_hint::uncached>(bwd_w, bwd_w_payload);
                bwd_w_payload.template update_tdesc<tdesc_update_dir::x_dir>(
                        /*...*/);
            }
            // rewind to the first block column, then step down one block row
            bwd_w_payload.template update_tdesc<tdesc_update_dir::x_dir>(
                    /*...*/);
            bwd_w_payload.template update_tdesc<tdesc_update_dir::y_dir>(
                    /*...*/);
        }
        // tail rows when tile_size_y is not a multiple of block_size_y
        if constexpr (tile_size_y % block_size_y != 0) {
            constexpr uint32_t remain_size_y = tile_size_y % block_size_y;
            constexpr uint32_t remain_y_start
                    = tile_size_y / block_size_y * block_size_y;
            constexpr uint32_t remain_elems_start
                    = remain_y_start * tile_size_x;
            constexpr uint32_t remain_block_elems
                    = remain_size_y * block_size_x;

            using remain_bwd_w_tile_desc_t
                    = tile_desc_t<block_size_x, remain_size_y, block_size_x,
                            /*...*/>;
            using remain_bwd_w_tile_t = /*...*/;
            // ... (remain_bwd_w_payload_t)
            mem_desc_w.update_coord_y(remain_y_start);
            remain_bwd_w_payload_t remain_bwd_w_payload(mem_desc_w);
            remain_bwd_w_tile_t remain_bwd_w;
            for (uint32_t j = 0; j < num_block_x; ++j) {
                auto x = matAcc.reg.xetla_select<remain_block_elems, 1>(
                        remain_elems_start + remain_block_elems * j);
                xetla_vector<dtype_acc, remain_block_elems> z
                        = xetla_tanh<dtype_acc, remain_block_elems>(
                                sqrt_two_over_pi * (x + c0 * x * x * x));
                xetla_vector<dtype_acc, remain_block_elems> w
                        = (0.5f * (1.f + z)
                        + 0.5f * x * (1.f - z * z)
                                * (sqrt_two_over_pi * (1.f + d0 * x * x)));
                x = 0.5f * x * (1.f + z);
                remain_bwd_w.reg
                        = xetla_cvt<dtype_out, dtype_acc, remain_block_elems>(
                                w);
                tile_store<cache_hint::uncached>(
                        remain_bwd_w, remain_bwd_w_payload);
                remain_bwd_w_payload
                        .template update_tdesc<tdesc_update_dir::x_dir>(
                                /*...*/);
            }
        }
    }
};
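// Reference sketch (plain C++, not XeTLA): the stored backward weight is the
// derivative of the tanh-approximation GELU; note d0 = 3 * c0 = 0.134145.
// The function name is illustrative only.
#include <cmath>
inline float gelu_tanh_grad_ref(float x) {
    constexpr float c0 = 0.044715f;
    constexpr float d0 = 0.134145f; // 3 * c0
    constexpr float sqrt_two_over_pi = 0.79788458347320556640625f;
    float z = std::tanh(sqrt_two_over_pi * (x + c0 * x * x * x));
    // d/dx [0.5 * x * (1 + z)] with z = tanh(sqrt(2/pi) * (x + c0 * x^3))
    return 0.5f * (1.0f + z)
            + 0.5f * x * (1.0f - z * z)
                    * (sqrt_two_over_pi * (1.0f + d0 * x * x));
}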
// element-wise gelu backward op functor: loads the forward input x from memory,
// recomputes w = d gelu(x)/dx, and scales the incoming gradient tile dy (held
// in matAcc) by it.
template <typename dtype_in, gpu_arch arch_tag, class enable = void>
struct /*...*/; // primary template

template <typename dtype_in_, gpu_arch arch_tag>
struct /*...*/ <dtype_in_, arch_tag,
        std::enable_if_t<(arch_tag == gpu_arch::Xe)>> {
    // ...
    using shape_t = typename mem_desc_x_t::shape_t;
    using coord_t = typename mem_desc_x_t::coord_t;
    using base_t = typename mem_desc_x_t::base_t;

    struct arguments_t {
        // ...
        arguments_t(base_t base_, shape_t shape_)
            : shape(shape_), base(base_) {}
    };

    template <typename matAcc_t>
    __XETLA_API KERNEL_FUNC void operator()(matAcc_t &matAcc,
            const coord_t &coord, const arguments_t &args,
            [[maybe_unused]] uint32_t slm_base = 0,
            [[maybe_unused]] uint32_t nbarrier_base = 0) {
        using dtype_acc = typename matAcc_t::dtype;
        static constexpr uint32_t tile_size_x = matAcc_t::tile_size_x;
        static constexpr uint32_t tile_size_y = matAcc_t::tile_size_y;
        static constexpr uint32_t block_size_x = matAcc_t::block_size_x;
        static constexpr uint32_t block_size_y = matAcc_t::block_size_y;

        using bwd_x_tile_desc_t = tile_desc_t<tile_size_x, tile_size_y,
                /*...*/>;
        // ... (bwd_x tile, bwd_x_payload_t, and mem_desc_x built from
        //      args.base / args.shape / coord)
        bwd_x_payload_t bwd_x_payload(mem_desc_x);
        tile_load<cache_hint::cached, cache_hint::cached>(bwd_x, bwd_x_payload);

        constexpr dtype_acc c0 = 0.044715f;
        constexpr dtype_acc d0 = 0.134145f; // 3 * c0
        constexpr dtype_acc sqrt_two_over_pi = 0.79788458347320556640625f;
        constexpr uint32_t block_elems = matAcc_t::block_elems;
        constexpr uint32_t num_block_x = matAcc_t::num_block_x;

        for (uint32_t i = 0; i < tile_size_y / block_size_y; ++i) {
            for (uint32_t j = 0; j < num_block_x; ++j) {
                auto x_in = bwd_x.reg.xetla_select<block_elems, 1>(
                        block_elems * (i * num_block_x + j));
                auto x = xetla_cvt<dtype_acc, dtype_in, block_elems>(x_in);
                auto dy = matAcc.reg.xetla_select<block_elems, 1>(
                        block_elems * (i * num_block_x + j));
                xetla_vector<dtype_acc, block_elems> z
                        = xetla_tanh<dtype_acc, block_elems>(
                                sqrt_two_over_pi * (x + c0 * x * x * x));
                xetla_vector<dtype_acc, block_elems> w = (0.5f * (1.f + z)
                        + 0.5f * x * (1.f - z * z)
                                * (sqrt_two_over_pi * (1.f + d0 * x * x)));
                // ... (dy is scaled in place by w)
            }
        }
        // tail rows when tile_size_y is not a multiple of block_size_y
        if constexpr (tile_size_y % block_size_y != 0) {
            constexpr uint32_t remain_size_y = tile_size_y % block_size_y;
            constexpr uint32_t remain_y_start
                    = tile_size_y / block_size_y * block_size_y;
            constexpr uint32_t remain_elems_start
                    = remain_y_start * tile_size_x;
            constexpr uint32_t remain_block_elems
                    = remain_size_y * block_size_x;
            for (uint32_t j = 0; j < num_block_x; ++j) {
                auto x_in = bwd_x.reg.xetla_select<remain_block_elems, 1>(
                        remain_elems_start + remain_block_elems * j);
                auto x = xetla_cvt<dtype_acc, dtype_in, remain_block_elems>(
                        x_in);
                auto dy = matAcc.reg.xetla_select<remain_block_elems, 1>(
                        remain_elems_start + remain_block_elems * j);
                xetla_vector<dtype_acc, remain_block_elems> z
                        = xetla_tanh<dtype_acc, remain_block_elems>(
                                sqrt_two_over_pi * (x + c0 * x * x * x));
                xetla_vector<dtype_acc, remain_block_elems> w
                        = (0.5f * (1.f + z)
                        + 0.5f * x * (1.f - z * z)
                                * (sqrt_two_over_pi * (1.f + d0 * x * x)));
                // ... (dy is scaled in place by w)
            }
        }
    }
};
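// Reference sketch (plain C++, not XeTLA): the backward pass is the chain
// rule -- the incoming gradient dy is multiplied by d gelu(x)/dx from the
// gelu_tanh_grad_ref sketch above. The function name is illustrative only.
inline float gelu_tanh_bwd_ref(float dy, float x) {
    return dy * gelu_tanh_grad_ref(x);
}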
// bias_add op functor: loads a 1 x tile_size_x bias row starting at column
// coord.x and adds it to every row of the accumulator tile.
template <typename dtype_bias, gpu_arch arch_tag, class enable = void>
struct /*...*/; // primary template

template <typename dtype_bias_, gpu_arch arch_tag>
struct /*...*/ <dtype_bias_, arch_tag,
        std::enable_if_t<(arch_tag == gpu_arch::Xe)>> {
    // ...
    using shape_t = typename mem_desc_bias_t::shape_t;
    using coord_t = typename mem_desc_bias_t::coord_t;
    using base_t = typename mem_desc_bias_t::base_t;

    struct arguments_t {
        // ...
        arguments_t(base_t base_, shape_t shape_)
            : shape(shape_), base(base_) {}
    };

    template <typename matAcc_t>
    __XETLA_API KERNEL_FUNC void operator()(matAcc_t &matAcc,
            const coord_t &coord, const arguments_t &args,
            [[maybe_unused]] uint32_t slm_base = 0,
            [[maybe_unused]] uint32_t nbarrier_base = 0) {
        using dtype_acc = typename matAcc_t::dtype;
        static constexpr uint32_t tile_size_x = matAcc_t::tile_size_x;
        static constexpr uint32_t tile_size_y = matAcc_t::tile_size_y;
        static constexpr uint32_t block_size_x = matAcc_t::block_size_x;
        static constexpr uint32_t block_size_y = matAcc_t::block_size_y;
        static constexpr int32_t num_block_x = matAcc_t::num_block_x;
        static constexpr uint32_t block_elems = matAcc_t::block_elems;

        // 1-row bias tile covering this tile's columns
        using bias_tile_desc_t = tile_desc_t<tile_size_x, 1, block_size_x, 1,
                /*...*/>;
        using bias_payload_t = /*...*/<
                /*...*/,
                msg_type_v<bias_tile_desc_t, mem_desc_bias_t::space>, arch_tag>;
        coord_t bias_coord(coord.x, 0);
        // ... (bias tile and mem_desc_bias built from args.base / args.shape)
        bias_payload_t bias_payload(mem_desc_bias);
        tile_load<cache_hint::cached, cache_hint::cached>(bias, bias_payload);

        for (uint32_t i = 0; i < tile_size_y / block_size_y; i++) {
            for (uint32_t j = 0; j < num_block_x; j++) {
                auto dst_reg = matAcc.reg
                        .xetla_select<block_elems, 1>(
                                (i * num_block_x + j) * block_elems)
                        .xetla_format<dtype_acc, block_size_y,
                                block_size_x>();
                for (uint32_t row_i = 0; row_i < block_size_y; row_i++) {
                    auto src_reg = bias.reg.xetla_select<block_size_x, 1>(
                            j * block_size_x);
                    dst_reg.row(row_i)
                            = xetla_cvt<dtype_acc, dtype_bias, block_size_x>(
                                      src_reg)
                            + dst_reg.row(row_i);
                }
            }
        }
        // tail rows when tile_size_y is not a multiple of block_size_y
        if constexpr ((tile_size_y % block_size_y) != 0) {
            constexpr uint32_t tail_start_y
                    = tile_size_y / block_size_y * block_size_y;
            constexpr int32_t tail_size_y = tile_size_y % block_size_y;
            constexpr int32_t tail_block_elems = tail_size_y * block_size_x;
            for (uint32_t j = 0; j < num_block_x; j++) {
                auto dst_reg = matAcc.reg
                        .xetla_select<tail_block_elems, 1>(
                                tail_start_y * tile_size_x
                                + j * tail_block_elems)
                        .xetla_format<dtype_acc, tail_size_y, block_size_x>();
                for (uint32_t row_i = 0; row_i < tail_size_y; row_i++) {
                    auto src_reg = bias.reg.xetla_select<block_size_x, 1>(
                            j * block_size_x);
                    dst_reg.row(row_i)
                            = xetla_cvt<dtype_acc, dtype_bias, block_size_x>(
                                      src_reg)
                            + dst_reg.row(row_i);
                }
            }
        }
    }
};
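// Reference sketch (plain C++, not XeTLA): bias_add broadcasts a 1 x N bias
// row over every row of an M x N row-major accumulator tile. Names and
// signature are illustrative only.
#include <cstddef>
inline void bias_add_ref(float *acc, const float *bias,
        std::size_t rows, std::size_t cols) {
    for (std::size_t r = 0; r < rows; ++r)
        for (std::size_t c = 0; c < cols; ++c)
            acc[r * cols + c] += bias[c];
}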
// MatAcc * vector scale + vector offset: loads per-column scale and offset rows
// and applies acc = scale * acc + offset to every row of the accumulator tile.
template <typename scale_dtype, typename offset_dtype, gpu_arch arch_tag,
        class enable = void>
struct /*...*/; // primary template

template <typename scale_dtype_, typename offset_dtype_, gpu_arch arch_tag>
struct /*...*/ <scale_dtype_, offset_dtype_, arch_tag,
        std::enable_if_t<(arch_tag == gpu_arch::Xe)>> {
    // ...
    using coord_t = typename scale_mem_desc_t::coord_t;

    struct arguments_t {
        // ...
        arguments_t(scale_base_t scale_base_, scale_shape_t scale_shape_,
                offset_base_t offset_base_, offset_shape_t offset_shape_)
            : scale_base(scale_base_)
            , scale_shape(scale_shape_)
            , offset_base(offset_base_)
            , offset_shape(offset_shape_) {}
    };

    template <typename matAcc_t>
    __XETLA_API KERNEL_FUNC void operator()(matAcc_t &matAcc,
            const coord_t &coord, const arguments_t &args,
            [[maybe_unused]] uint32_t slm_base = 0,
            [[maybe_unused]] uint32_t nbarrier_base = 0) {
        static constexpr uint32_t tile_size_x = matAcc_t::tile_size_x;
        static constexpr uint32_t tile_size_y = matAcc_t::tile_size_y;
        static constexpr uint32_t block_size_x = matAcc_t::block_size_x;
        static constexpr uint32_t block_size_y = matAcc_t::block_size_y;
        static constexpr int32_t num_block_x = matAcc_t::num_block_x;
        static constexpr uint32_t block_elems = matAcc_t::block_elems;

        // load the 1 x tile_size_x scale row at column coord.x
        using scale_tile_desc_t = tile_desc_t<tile_size_x, 1, block_size_x, 1,
                /*...*/>;
        using scale_payload_t = /*...*/<
                /*...*/,
                msg_type_v<scale_tile_desc_t, scale_mem_desc_t::space>,
                /*...*/>;
        coord_t scale_coord(coord.x, 0);
        scale_mem_desc_t scale_mem_desc(
                args.scale_base, args.scale_shape, scale_coord);
        scale_tile_t scale_tile;
        scale_payload_t scale_payload(scale_mem_desc);
        tile_load<cache_hint::cached, cache_hint::cached>(
                scale_tile, scale_payload);

        // load the 1 x tile_size_x offset row at column coord.x
        using offset_tile_desc_t = tile_desc_t<tile_size_x, 1, block_size_x, 1,
                /*...*/>;
        using offset_payload_t = /*...*/<
                /*...*/,
                msg_type_v<offset_tile_desc_t, offset_mem_desc_t::space>,
                /*...*/>;
        coord_t offset_coord(coord.x, 0);
        offset_mem_desc_t offset_mem_desc(
                args.offset_base, args.offset_shape, offset_coord);
        offset_tile_t offset_tile;
        offset_payload_t offset_payload(offset_mem_desc);
        tile_load<cache_hint::cached, cache_hint::cached>(
                offset_tile, offset_payload);

        for (uint32_t i = 0; i < tile_size_y / block_size_y; i++) {
            for (uint32_t j = 0; j < num_block_x; j++) {
                auto acc_reg = matAcc.reg.xetla_select<block_elems, 1>(
                        (i * num_block_x + j) * block_elems);
                auto offset_reg = offset_tile.reg.xetla_select<block_size_x, 1>(
                        j * block_size_x);
                auto scale_reg = scale_tile.reg.xetla_select<block_size_x, 1>(
                        j * block_size_x);
                for (uint32_t row_i = 0; row_i < block_size_y; row_i++) {
                    // acc = scale * acc + offset, one block row at a time
                    acc_reg.xetla_select<block_size_x, 1>(row_i * block_size_x)
                            = scale_reg
                            * acc_reg.xetla_select<block_size_x, 1>(
                                    row_i * block_size_x)
                            + offset_reg;
                }
            }
        }
        // tail rows when tile_size_y is not a multiple of block_size_y
        if constexpr ((tile_size_y % block_size_y) != 0) {
            constexpr uint32_t tail_start_y
                    = tile_size_y / block_size_y * block_size_y;
            constexpr int32_t tail_size_y = tile_size_y % block_size_y;
            constexpr int32_t tail_block_elems = tail_size_y * block_size_x;
            for (uint32_t j = 0; j < num_block_x; j++) {
                auto acc_reg = matAcc.reg.xetla_select<tail_block_elems, 1>(
                        tail_start_y * tile_size_x + j * tail_block_elems);
                auto offset_reg = offset_tile.reg.xetla_select<block_size_x, 1>(
                        j * block_size_x);
                auto scale_reg = scale_tile.reg.xetla_select<block_size_x, 1>(
                        j * block_size_x);
                for (uint32_t row_i = 0; row_i < tail_size_y; row_i++) {
                    acc_reg.xetla_select<block_size_x, 1>(row_i * block_size_x)
                            = scale_reg
                            * acc_reg.xetla_select<block_size_x, 1>(
                                    row_i * block_size_x)
                            + offset_reg;
                }
            }
        }
    }
};
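// Reference sketch (plain C++, not XeTLA): per-column scale and offset applied
// to an M x N row-major accumulator tile. Names and signature are illustrative
// only.
#include <cstddef>
inline void scale_offset_ref(float *acc, const float *scale,
        const float *offset, std::size_t rows, std::size_t cols) {
    for (std::size_t r = 0; r < rows; ++r)
        for (std::size_t c = 0; c < cols; ++c)
            acc[r * cols + c] = scale[c] * acc[r * cols + c] + offset[c];
}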
// MatAcc * vector scale: loads a per-column scale row and multiplies every row
// of the accumulator tile by it.
template <typename scale_dtype, gpu_arch arch_tag, class enable = void>
struct /*...*/; // primary template

template <typename scale_dtype_, gpu_arch arch_tag>
struct /*...*/ <scale_dtype_, arch_tag,
        std::enable_if_t<(arch_tag == gpu_arch::Xe)>> {
    // ...
    using coord_t = typename scale_mem_desc_t::coord_t;

    struct arguments_t {
        // ...
        arguments_t(scale_base_t scale_base_, scale_shape_t scale_shape_)
            : scale_base(scale_base_), scale_shape(scale_shape_) {}
    };

    template <typename matAcc_t>
    __XETLA_API KERNEL_FUNC void operator()(matAcc_t &matAcc,
            const coord_t &coord, const arguments_t &args,
            [[maybe_unused]] uint32_t slm_base = 0,
            [[maybe_unused]] uint32_t nbarrier_base = 0) {
        static constexpr uint32_t tile_size_x = matAcc_t::tile_size_x;
        static constexpr uint32_t tile_size_y = matAcc_t::tile_size_y;
        static constexpr uint32_t block_size_x = matAcc_t::block_size_x;
        static constexpr uint32_t block_size_y = matAcc_t::block_size_y;
        static constexpr int32_t num_block_x = matAcc_t::num_block_x;
        static constexpr uint32_t block_elems = matAcc_t::block_elems;

        using scale_tile_desc_t = tile_desc_t<tile_size_x, 1, block_size_x, 1,
                /*...*/>;
        using scale_payload_t = /*...*/<
                /*...*/,
                msg_type_v<scale_tile_desc_t, scale_mem_desc_t::space>,
                /*...*/>;
        coord_t scale_coord(coord.x, 0);
        scale_mem_desc_t scale_mem_desc(
                args.scale_base, args.scale_shape, scale_coord);
        scale_tile_t scale_tile;
        scale_payload_t scale_payload(scale_mem_desc);
        tile_load<cache_hint::cached, cache_hint::cached>(
                scale_tile, scale_payload);

        for (uint32_t i = 0; i < tile_size_y / block_size_y; i++) {
            for (uint32_t j = 0; j < num_block_x; j++) {
                auto acc_reg = matAcc.reg.xetla_select<block_elems, 1>(
                        (i * num_block_x + j) * block_elems);
                auto scale_reg = scale_tile.reg.xetla_select<block_size_x, 1>(
                        j * block_size_x);
                for (uint32_t row_i = 0; row_i < block_size_y; row_i++) {
                    acc_reg.xetla_select<block_size_x, 1>(row_i * block_size_x)
                            = scale_reg
                            * acc_reg.xetla_select<block_size_x, 1>(
                                    row_i * block_size_x);
                }
            }
        }
        // tail rows when tile_size_y is not a multiple of block_size_y
        if constexpr ((tile_size_y % block_size_y) != 0) {
            constexpr uint32_t tail_start_y
                    = tile_size_y / block_size_y * block_size_y;
            constexpr int32_t tail_size_y = tile_size_y % block_size_y;
            constexpr int32_t tail_block_elems = tail_size_y * block_size_x;
            for (uint32_t j = 0; j < num_block_x; j++) {
                auto acc_reg = matAcc.reg.xetla_select<tail_block_elems, 1>(
                        tail_start_y * tile_size_x + j * tail_block_elems);
                auto scale_reg = scale_tile.reg.xetla_select<block_size_x, 1>(
                        j * block_size_x);
                for (uint32_t row_i = 0; row_i < tail_size_y; row_i++) {
                    acc_reg.xetla_select<block_size_x, 1>(row_i * block_size_x)
                            = scale_reg
                            * acc_reg.xetla_select<block_size_x, 1>(
                                    row_i * block_size_x);
                }
            }
        }
    }
};
// element-wise reduce op functor: loads a tile of dtype_in from memory and
// reduces it into the accumulator with the given reduce_op (e.g. sum).
template <reduce_op reduce_kind_, typename dtype_in_, gpu_arch arch_tag>
struct /*...*/ <reduce_kind_, dtype_in_, arch_tag,
        std::enable_if_t<(arch_tag == gpu_arch::Xe)>> {
    // ...
    using shape_t = typename mem_desc_in_t::shape_t;
    using coord_t = typename mem_desc_in_t::coord_t;
    using base_t = typename mem_desc_in_t::base_t;

    struct arguments_t {
        // ...
        arguments_t(base_t base_, shape_t shape_)
            : shape(shape_), base(base_) {}
    };

    template <typename matAcc_t>
    __XETLA_API KERNEL_FUNC void operator()(matAcc_t &matAcc,
            const coord_t &coord, const arguments_t &args,
            [[maybe_unused]] uint32_t slm_base = 0,
            [[maybe_unused]] uint32_t nbarrier_base = 0) {
        using dtype_acc = typename matAcc_t::dtype;
        static constexpr uint32_t tile_size_x = matAcc_t::tile_size_x;
        static constexpr uint32_t tile_size_y = matAcc_t::tile_size_y;
        static constexpr uint32_t block_size_x = matAcc_t::block_size_x;
        static constexpr uint32_t block_size_y = matAcc_t::block_size_y;
        static constexpr int32_t num_block_x = matAcc_t::num_block_x;
        static constexpr uint32_t block_elems = matAcc_t::block_elems;

        // one block row (tile_size_x x block_size_y) is loaded per iteration
        using mat_in_tile_desc_t = tile_desc_t<tile_size_x, block_size_y,
                /*...*/>;
        using mat_in_payload_t = /*...*/<
                /*...*/,
                msg_type_v<mat_in_tile_desc_t, mem_desc_in_t::space>, arch_tag>;
        // ... (mat_in_tile_t / mat_in_tile_acc_t, mem_desc_in built from
        //      args.base / args.shape / coord)
        mat_in_tile_t mat_in;
        mat_in_payload_t mat_in_payload(mem_desc_in);
        mat_in_tile_acc_t mat_in_acc;

        for (uint32_t i = 0; i < tile_size_y / block_size_y; i++) {
            tile_load<cache_hint::cached, cache_hint::cached>(
                    mat_in, mat_in_payload);
            // ... (convert the loaded tile into mat_in_acc as dtype_acc)
            for (uint32_t j = 0; j < num_block_x; j++) {
                auto dst_reg = matAcc.reg.xetla_select<block_elems, 1>(
                        (i * num_block_x + j) * block_elems);
                auto src_reg = mat_in_acc.reg.xetla_select<block_elems, 1>(
                        j * block_elems);
                dst_reg = reduce_helper<reduce_kind, dtype_acc, block_elems>(
                        src_reg, dst_reg);
            }
            mat_in_payload.template update_tdesc<tdesc_update_dir::y_dir>(
                    block_size_y);
        }
        // tail rows when tile_size_y is not a multiple of block_size_y
        if constexpr ((tile_size_y % block_size_y) != 0) {
            constexpr uint32_t tail_start_y
                    = tile_size_y / block_size_y * block_size_y;
            constexpr int32_t tail_size_y = tile_size_y % block_size_y;
            constexpr int32_t tail_block_elems = tail_size_y * block_size_x;

            using mat_tail_in_tile_desc_t = tile_desc_t<tile_size_x,
                    /*...*/>;
            using mat_tail_in_tile_t = /*...*/;
            using mat_tail_in_payload_t = /*...*/<
                    /*...*/,
                    mat_tail_in_tile_desc_t,
                    msg_type_v<mat_tail_in_tile_desc_t, mem_desc_in_t::space>,
                    /*...*/>;
            using mat_tail_in_tile_acc_t = /*...*/;
            mat_tail_in_tile_t mat_tail_in;
            mat_tail_in_payload_t mat_tail_in_payload(mem_desc_in);
            mat_tail_in_tile_acc_t mat_tail_in_acc;
            mat_tail_in_payload.template update_tdesc<tdesc_update_dir::y_dir>(
                    tail_start_y);
            tile_load<cache_hint::cached, cache_hint::cached>(
                    mat_tail_in, mat_tail_in_payload);
            // ... (convert the loaded tail tile into mat_tail_in_acc)
            for (uint32_t j = 0; j < num_block_x; j++) {
                auto dst_reg = matAcc.reg.xetla_select<tail_block_elems, 1>(
                        tail_start_y * tile_size_x + j * tail_block_elems);
                auto src_reg
                        = mat_tail_in_acc.reg.xetla_select<tail_block_elems, 1>(
                                j * tail_block_elems);
                dst_reg = reduce_helper<reduce_kind, dtype_acc,
                        tail_block_elems>(src_reg, dst_reg);
            }
        }
    }
};
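// Reference sketch (plain C++, not XeTLA): for reduce_op::sum the functor is an
// element-wise accumulation of a tile loaded from memory into the accumulator.
// Names and signature are illustrative only.
#include <cstddef>
inline void elemwise_sum_ref(float *acc, const float *in, std::size_t n) {
    for (std::size_t i = 0; i < n; ++i)
        acc[i] += in[i]; // reduce_helper<reduce_op::sum>(in, acc)
}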
// element-wise reduce op functor, specialized for stream_k dispatch: loads the
// partial sums from scratch memory block by block and reduces them into the
// accumulator.
template <reduce_op reduce_kind, typename dtype_in, gpu_arch arch_tag,
        class enable = void>
struct /*...*/; // primary template

template <reduce_op reduce_kind_, typename dtype_in_>
struct /*...*/ {
    // ...
    using shape_t = typename mem_desc_in_t::shape_t;
    using coord_t = typename mem_desc_in_t::coord_t;
    using base_t = typename mem_desc_in_t::base_t;

    struct arguments_t {
        // ...
        arguments_t(base_t base_, shape_t shape_)
            : shape(shape_), base(base_) {}
    };

    template <typename matAcc_t>
    __XETLA_API KERNEL_FUNC void operator()(matAcc_t &matAcc,
            const coord_t &coord, const arguments_t &args,
            [[maybe_unused]] uint32_t slm_base = 0,
            [[maybe_unused]] uint32_t nbarrier_base = 0) {
        using dtype_acc = typename matAcc_t::dtype;
        static constexpr uint32_t tile_size_x = matAcc_t::tile_size_x;
        static constexpr uint32_t tile_size_y = matAcc_t::tile_size_y;
        static constexpr uint32_t block_size_x = matAcc_t::block_size_x;
        static constexpr uint32_t block_size_y = matAcc_t::block_size_y;
        static constexpr int32_t num_block_x = matAcc_t::num_block_x;
        static constexpr uint32_t block_elems = matAcc_t::block_elems;

        // one block (block_size_x x block_size_y) is loaded per inner iteration
        using mat_in_tile_desc_t = tile_desc_t<block_size_x, block_size_y,
                /*...*/>;
        using mat_in_payload_t = /*...*/<
                /*...*/,
                msg_type_v<mat_in_tile_desc_t, mem_desc_in_t::space>,
                /*...*/>;
        // ... (mem_desc_in built from args.base / args.shape / coord)
        mat_in_tile_t mat_in;
        mat_in_tile_t mat_zero(0);
        mat_in_payload_t mat_in_payload(mem_desc_in);

        for (uint32_t i = 0; i < tile_size_y / block_size_y; i++) {
            for (uint32_t j = 0; j < num_block_x; j++) {
                tile_load<cache_hint::cached, cache_hint::cached>(
                        mat_in, mat_in_payload);
                auto dst_reg = matAcc.reg.xetla_select<block_elems, 1>(
                        (i * num_block_x + j) * block_elems);
                auto src_reg = mat_in.reg;
                dst_reg = reduce_helper<reduce_kind, dtype_acc, block_elems>(
                        src_reg, dst_reg);
                // ...
                mat_in_payload.template update_tdesc<tdesc_update_dir::x_dir>(
                        /*...*/);
            }
            // step back to the first block column, then down one block row
            mat_in_payload.template update_tdesc<tdesc_update_dir::x_dir>(
                    -num_block_x * block_size_x);
            mat_in_payload.template update_tdesc<tdesc_update_dir::y_dir>(
                    /*...*/);
        }
        // tail rows when tile_size_y is not a multiple of block_size_y
        if constexpr ((tile_size_y % block_size_y) != 0) {
            constexpr uint32_t tail_start_y
                    = tile_size_y / block_size_y * block_size_y;
            constexpr int32_t tail_size_y = tile_size_y % block_size_y;
            constexpr int32_t tail_block_elems = tail_size_y * block_size_x;
            for (uint32_t j = 0; j < num_block_x; j++) {
                tile_load<cache_hint::cached, cache_hint::cached>(
                        mat_in, mat_in_payload);
                auto dst_reg = matAcc.reg.xetla_select<tail_block_elems, 1>(
                        tail_start_y * tile_size_x + j * tail_block_elems);
                auto src_reg = mat_in.reg.xetla_select<tail_block_elems, 1>(
                        tail_start_y * tile_size_x + j * tail_block_elems);
                dst_reg = reduce_helper<reduce_kind, dtype_acc,
                        tail_block_elems>(src_reg, dst_reg);
                // ...
                mat_in_payload.template update_tdesc<tdesc_update_dir::x_dir>(
                        /*...*/);
            }
        }
    }
};
// dropout op functor: loads a precomputed mask tile; elements whose mask value
// is non-zero are dropped (set to 0), the rest are multiplied by args.scale.
template <typename dtype_mask, gpu_arch arch_tag, class enable = void>
struct /*...*/; // primary template

template <typename dtype_mask_, gpu_arch arch_tag>
struct /*...*/ <dtype_mask_, arch_tag,
        std::enable_if_t<(arch_tag == gpu_arch::Xe)>> {
    // ...
    using shape_t = typename mem_desc_mask_t::shape_t;
    using coord_t = typename mem_desc_mask_t::coord_t;
    using base_t = typename mem_desc_mask_t::base_t;
    static constexpr uint32_t num_flag = 4;
    static constexpr uint32_t unroll_size = num_flag * 16;

    struct arguments_t {
        // ...
        arguments_t(base_t base_, shape_t shape_, float prob_, float scale_)
            : shape(shape_), base(base_), prob(prob_), scale(scale_) {}
    };

    template <typename matAcc_t>
    __XETLA_API KERNEL_FUNC void operator()(matAcc_t &matAcc,
            const coord_t &coord, const arguments_t &args,
            [[maybe_unused]] uint32_t slm_base = 0,
            [[maybe_unused]] uint32_t nbarrier_base = 0) {
        static constexpr uint32_t tile_size_x = matAcc_t::tile_size_x;
        static constexpr uint32_t tile_size_y = matAcc_t::tile_size_y;
        static constexpr uint32_t block_size_x = matAcc_t::block_size_x;
        static constexpr uint32_t block_size_y = matAcc_t::block_size_y;
        static constexpr uint32_t tile_elems = matAcc_t::tile_elems;
        if (args.prob == 0) {
            return;
        }
        using mask_in_tile_desc_t = tile_desc_t<tile_size_x, tile_size_y,
                /*...*/>;
        using mask_in_payload_t = /*...*/<
                /*...*/,
                msg_type_v<mask_in_tile_desc_t, mem_desc_mask_t::space>,
                /*...*/>;
        // ... (mask tile and mem_desc_mask built from args.base / args.shape)
        mask_in_tile_t mask_in;
        mask_in_payload_t mask_in_payload(mem_desc_mask);
        tile_load<cache_hint::cached, cache_hint::cached>(
                mask_in, mask_in_payload);

        for (uint32_t i = 0; i < tile_elems / unroll_size; i++) {
            xetla_mask<unroll_size> mask_flag
                    = mask_in.reg.xetla_select<unroll_size, 1>(i * unroll_size)
                    > 0;
            auto dst_reg
                    = matAcc.reg.xetla_select<unroll_size, 1>(i * unroll_size);
            dst_reg *= args.scale;
            dst_reg.xetla_merge(0, mask_flag);
        }
        // tail when tile_elems is not a multiple of unroll_size
        if constexpr (tile_elems % unroll_size != 0) {
            constexpr uint32_t remain_len = tile_elems % unroll_size;
            constexpr uint32_t remain_start
                    = tile_elems / unroll_size * unroll_size;
            xetla_mask<remain_len> mask_flag
                    = mask_in.reg.xetla_select<remain_len, 1>(remain_start) > 0;
            auto dst_reg = matAcc.reg.xetla_select<remain_len, 1>(remain_start);
            dst_reg *= args.scale;
            dst_reg.xetla_merge(0, mask_flag);
        }
    }
};
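// Reference sketch (plain C++, not XeTLA): dropout with a precomputed mask --
// a non-zero mask drops the element, survivors are rescaled (typically by
// 1 / (1 - p), supplied here through `scale`). Names are illustrative only.
#include <cstddef>
#include <cstdint>
inline void dropout_ref(float *acc, const uint8_t *mask, float scale,
        std::size_t n) {
    for (std::size_t i = 0; i < n; ++i)
        acc[i] = (mask[i] > 0) ? 0.0f : acc[i] * scale;
}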
// random number generator + dropout op functor: draws uniform random numbers
// per element, drops elements with probability args.prob, rescales the
// survivors by 1 / (1 - prob), and stores the generated mask to memory.
template <typename dtype_mask, gpu_arch arch_tag, class enable = void>
struct /*...*/; // primary template

template <typename dtype_mask_, gpu_arch arch_tag>
struct /*...*/ <dtype_mask_, arch_tag,
        std::enable_if_t<(arch_tag == gpu_arch::Xe)>> {
    // ...
    using shape_t = typename mem_desc_mask_t::shape_t;
    using coord_t = typename mem_desc_mask_t::coord_t;
    using base_t = typename mem_desc_mask_t::base_t;
    static constexpr uint32_t random_simd = 16;
    static constexpr uint32_t random_len = 4 * random_simd;

    struct arguments_t {
        // ...
        arguments_t(base_t mask_base_, shape_t mask_shape_, float prob_,
                uint64_t *rand_offset_ptr_,
                uint64_t rand_seed_ = 67280421310721)
            : mask_shape(mask_shape_)
            , mask_base(mask_base_)
            , rand_offset_ptr(rand_offset_ptr_)
            , prob(prob_)
            , rand_seed(rand_seed_) {}
    };

    template <typename matAcc_t>
    __XETLA_API KERNEL_FUNC void operator()(matAcc_t &matAcc,
            const coord_t &coord, const arguments_t &args,
            [[maybe_unused]] uint32_t slm_base = 0,
            [[maybe_unused]] uint32_t nbarrier_base = 0) {
        static constexpr uint32_t tile_size_x = matAcc_t::tile_size_x;
        static constexpr uint32_t tile_size_y = matAcc_t::tile_size_y;
        static constexpr uint32_t block_size_x = matAcc_t::block_size_x;
        static constexpr uint32_t block_size_y = matAcc_t::block_size_y;
        static constexpr uint32_t tile_elems = matAcc_t::tile_elems;

        using mask_out_tile_desc_t = tile_desc_t<tile_size_x, tile_size_y,
                /*...*/>;
        using mask_out_payload_t = /*...*/<
                /*...*/, mask_out_tile_desc_t,
                msg_type_v<mask_out_tile_desc_t, mem_desc_mask_t::space>,
                /*...*/>;
        if (args.prob == 0) {
            return;
        }
        // survivors are rescaled by 1 / (1 - p); an element is dropped when its
        // 32-bit random draw falls below prob * 2^32
        float scale = 1.f / (1.f - args.prob);
        uint32_t threshold = uint32_t(args.prob * float(4294967296));
        // ... (rand_offset_v loaded via xetla_load_global from
        //      args.rand_offset_ptr, 0)
        uint64_t rand_offset = rand_offset_v[0];
        uint64_t rand_subseq = uint64_t(coord.y) << 32 | uint64_t(coord.x);
        rand_gen.init(args.rand_seed, rand_subseq, rand_offset);

        mem_desc_mask_t mem_desc_mask(args.mask_base, args.mask_shape, coord);
        mask_out_tile_t mask_out;
        mask_out_payload_t mask_out_payload(mem_desc_mask);

        for (uint32_t i = 0; i < tile_elems / random_len; i++) {
            auto out_sub
                    = matAcc.reg.xetla_select<random_len, 1>(i * random_len);
            auto mask_sub
                    = mask_out.reg.xetla_select<random_len, 1>(i * random_len);
            // ... (draw random_len 32-bit values with rand_gen.rand(), set
            //      mask_flag where the draw is below threshold, and scale the
            //      surviving lanes of out_sub)
            out_sub.xetla_merge(0, mask_flag);
            mask_sub.xetla_merge(1, 0, mask_flag);
        }
        // tail when tile_elems is not a multiple of random_len
        if constexpr (tile_elems % random_len != 0) {
            constexpr uint32_t remain_len = tile_elems % random_len;
            constexpr uint32_t remain_start
                    = tile_elems / random_len * random_len;
            auto out_sub = matAcc.reg.xetla_select<remain_len, 1>(remain_start);
            auto mask_sub
                    = mask_out.reg.xetla_select<remain_len, 1>(remain_start);
            // ... (same random draw and scaling for the tail lanes)
            out_sub.xetla_merge(0, mask_flag.xetla_select<remain_len, 1>(0));
            mask_sub.xetla_merge(
                    1, 0, mask_flag.xetla_select<remain_len, 1>(0));
        }
        tile_store<cache_hint::streaming>(mask_out, mask_out_payload);
    }
};
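// Reference sketch (plain C++, not XeTLA): the keep/drop decision sketched per
// element -- compare a uniform 32-bit draw against threshold = prob * 2^32,
// zero the dropped lane, scale the survivor, and record the decision in the
// output mask. Names and signature are illustrative only.
#include <cstdint>
inline float rng_dropout_ref(float x, uint32_t rand_u32, float prob,
        uint8_t *mask_out) {
    uint32_t threshold = static_cast<uint32_t>(prob * 4294967296.0f);
    bool drop = rand_u32 < threshold; // true with probability ~prob
    *mask_out = drop ? 1 : 0;         // mask written back to memory
    float scale = 1.0f / (1.0f - prob);
    return drop ? 0.0f : x * scale;
}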
// scalar_multiply op functor: multiplies the whole accumulator tile by a scalar
// of the same type as the accumulator.
template <typename dtype_in, gpu_arch arch_tag, class enable = void>
struct /*...*/; // primary template

template <typename dtype_in_, gpu_arch arch_tag>
struct /*...*/ <dtype_in_, arch_tag,
        std::enable_if_t<(arch_tag == gpu_arch::Xe)>> {
    // ...
    struct arguments_t {
        // ... (holds the dtype_in multiplier)
    };

    template <typename matAcc_t>
    __XETLA_API KERNEL_FUNC void operator()(matAcc_t &matAcc,
            [[maybe_unused]] coord_t coord, [[maybe_unused]] arguments_t args,
            [[maybe_unused]] uint32_t slm_base = 0,
            [[maybe_unused]] uint32_t nbarrier_base = 0) {
        using dtype_acc = typename matAcc_t::dtype;
        static_assert(std::is_same<dtype_in, dtype_acc>::value,
                "Given multiplier must have same type as matAcc!");
        matAcc.reg *= args.multiplier;
    }
};
// linear_op functor: matAcc = alpha * matAcc + beta * C, where C is a tile of
// dtype_in loaded from memory.
template <typename dtype_in, gpu_arch arch_tag, class enable = void>
struct /*...*/; // primary template

template <typename dtype_in_, gpu_arch arch_tag>
struct /*...*/ <dtype_in_, arch_tag,
        std::enable_if_t<(arch_tag == gpu_arch::Xe)>> {
    // ...
    using base_t = typename mem_desc_in_t::base_t;

    struct arguments_t {
        // ...
        arguments_t(base_t base_, shape_t shape_, dtype_in alpha_,
                dtype_in beta_)
            : shape(shape_), base(base_), alpha(alpha_), beta(beta_) {}
    };

    template <typename matAcc_t>
    __XETLA_API KERNEL_FUNC void operator()(matAcc_t &matAcc,
            const coord_t &coord, const arguments_t &args,
            [[maybe_unused]] uint32_t slm_base = 0,
            [[maybe_unused]] uint32_t nbarrier_base = 0) {
        using dtype_acc = typename matAcc_t::dtype;
        static constexpr uint32_t tile_size_x = matAcc_t::tile_size_x;
        static constexpr uint32_t tile_size_y = matAcc_t::tile_size_y;
        static constexpr uint32_t block_size_x = matAcc_t::block_size_x;
        static constexpr uint32_t block_size_y = matAcc_t::block_size_y;
        static constexpr uint32_t num_block_x = matAcc_t::num_block_x;
        static constexpr uint32_t num_block_y = matAcc_t::num_block_y;
        static constexpr uint32_t block_elems = matAcc_t::block_elems;
        static constexpr uint32_t remained_size_y = tile_size_y % block_size_y;

        using mat_in_tile_desc_t = tile_desc_t<tile_size_x, block_size_y,
                /*...*/>;
        using mat_in_payload_t = /*...*/<
                /*...*/,
                msg_type_v<mat_in_tile_desc_t, mem_desc_in_t::space>, arch_tag>;
        // ... (mat_in_tile_t / mat_in_tile_acc_t, mem_desc_in built from
        //      args.base / args.shape / coord)
        mat_in_tile_t mat_in;
        mat_in_payload_t mat_in_payload(mem_desc_in);
        mat_in_tile_acc_t mat_in_acc;

        dtype_acc alpha = dtype_acc(args.alpha);
        dtype_acc beta = dtype_acc(args.beta);
        matAcc.reg *= alpha;

        for (uint32_t i = 0; i < num_block_y; ++i) {
            tile_load<cache_hint::cached, cache_hint::cached>(
                    mat_in, mat_in_payload);
            // ... (convert the loaded tile into mat_in_acc as dtype_acc)
            mat_in_acc.reg *= beta;
            for (uint32_t j = 0; j < num_block_x; ++j) {
                auto dst_reg = matAcc.reg.xetla_select<block_elems, 1>(
                        (i * num_block_x + j) * block_elems);
                auto src_reg = mat_in_acc.reg.xetla_select<block_elems, 1>(
                        j * block_elems);
                dst_reg = reduce_helper<reduce_op::sum, dtype_acc, block_elems>(
                        src_reg, dst_reg);
            }
            mat_in_payload.template update_tdesc<tdesc_update_dir::y_dir>(
                    block_size_y);
        }
        // tail rows when tile_size_y is not a multiple of block_size_y
        if constexpr (remained_size_y > 0) {
            constexpr uint32_t tail_start_y = num_block_y * block_size_y;
            constexpr uint32_t tail_block_elems
                    = remained_size_y * block_size_x;

            using mat_tail_in_tile_desc_t
                    = tile_desc_t<tile_size_x, remained_size_y, block_size_x,
                            /*...*/>;
            using mat_tail_in_tile_t = /*...*/;
            using mat_tail_in_payload_t = /*...*/<
                    /*...*/,
                    mat_tail_in_tile_desc_t,
                    msg_type_v<mat_tail_in_tile_desc_t, mem_desc_in_t::space>,
                    /*...*/>;
            using mat_tail_in_tile_acc_t = /*...*/;
            mat_tail_in_tile_t mat_tail_in;
            mat_tail_in_payload_t mat_tail_in_payload(mem_desc_in);
            mat_tail_in_tile_acc_t mat_tail_in_acc;
            mat_tail_in_payload.template update_tdesc<tdesc_update_dir::y_dir>(
                    tail_start_y);
            tile_load<cache_hint::cached, cache_hint::cached>(
                    mat_tail_in, mat_tail_in_payload);
            // ... (convert the loaded tail tile into mat_tail_in_acc)
            mat_tail_in_acc.reg *= beta;
            for (uint32_t j = 0; j < num_block_x; ++j) {
                auto dst_reg = matAcc.reg.xetla_select<tail_block_elems, 1>(
                        tail_start_y * tile_size_x + j * tail_block_elems);
                auto src_reg
                        = mat_tail_in_acc.reg.xetla_select<tail_block_elems, 1>(
                                j * tail_block_elems);
                dst_reg = reduce_helper<reduce_op::sum, dtype_acc,
                        tail_block_elems>(src_reg, dst_reg);
            }
        }
    }
};
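// Reference sketch (plain C++, not XeTLA): the GEMM-style epilogue computed by
// linear_op, acc = alpha * acc + beta * c, element by element. Names and
// signature are illustrative only.
#include <cstddef>
inline void linear_ref(float *acc, const float *c, float alpha, float beta,
        std::size_t n) {
    for (std::size_t i = 0; i < n; ++i)
        acc[i] = alpha * acc[i] + beta * c[i];
}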
// ---------------------------------------------------------------------------
// Cross-reference summary (condensed from the generated Doxygen tooltips):
//
// Helpers used above come from other XeTLA headers: __XETLA_API and KERNEL_FUNC
// (common.hpp), xetla_vector / xetla_mask (base_types.hpp), xetla_load_global
// (memory.hpp), reduce_helper (misc.hpp), elemwise_cvt (op_function.hpp),
// tile_store (store_xe.hpp), and xetla_rand_t with its init() / rand() members
// (rand.hpp).
//
// tile_op_functor.hpp itself defines the following post-processing functors,
// each exposing an arguments_t and an
// operator()(matAcc, coord, args, slm_base = 0, nbarrier_base = 0)
// entry point (the tanh and sigmoid functors take no arguments_t):
//   - none op (placeholder), relu, tanh, sigmoid
//   - gelu inference forward, gelu training forward (also stores the backward
//     weight tile w), and gelu backward
//   - bias_add
//   - MatAcc * vector scale + vector offset, and MatAcc * vector scale
//   - element-wise reduce, plus a stream_k-specialized variant that loads
//     partial sums from scratch memory
//   - dropout (precomputed mask) and the RNG + dropout functor (generates and
//     stores the mask)
//   - scalar_multiply and linear_op (alpha * MatAcc + beta * C)
// ---------------------------------------------------------------------------