template <typename dtype_in, typename dtype_out, typename dtype_acc>
struct ln_fwd_fused_op_arguments_t { /* ... members sketched below ... */ };
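// Sketch of the argument-pack members, reconstructed from the member
// definitions in this header (defaults and exact ordering may differ in the
// full file); every specialization below reads it through `arguments_t *args`:
//
//     dtype_in *bias_ptr;               // bias row, added before dropout
//     dtype_in *res_add_ptr;            // residual, added after dropout
//     dtype_out *bias_dropout_res_ptr;  // fused bias+dropout+residual output
//     uint8_t *mask_ptr;                // dropout mask (input or output)
//     uint32_t matrix_m, matrix_n;      // rows, columns
//     uint32_t mat_ld, mask_ld;         // leading dims of data and mask
//     uint64_t rand_seed;               // seed for the RNG-dropout variants
//     uint64_t *rand_offset_ptr;        // RNG counter offset
//     float dropout_prob;               // drop probability p
//     float dropout_scale;              // typically 1 / (1 - p)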
template <typename dtype_in_, typename dtype_out_, typename dtype_acc_,
        typename layer_norm_attr_>
struct ln_fwd_fused_op_t<ln_fwd_fused_kind::none, dtype_in_, dtype_out_,
        dtype_acc_, layer_norm_attr_, gpu_arch::Xe> {
    using dtype_acc = dtype_acc_;
    using dtype_in = dtype_in_;
    using dtype_out = dtype_out_;
    static constexpr uint32_t wg_tile_m = layer_norm_attr_::wg_tile_m;
    static constexpr uint32_t wg_tile_n = layer_norm_attr_::wg_tile_n;
    static constexpr uint32_t sg_tile_m = layer_norm_attr_::sg_tile_m;
    static constexpr uint32_t sg_tile_n = layer_norm_attr_::sg_tile_n;
    static constexpr uint32_t wg_num_m = layer_norm_attr_::wg_num_m;
    static constexpr uint32_t wg_num_n = layer_norm_attr_::wg_num_n;
    static constexpr uint32_t chunk_size = layer_norm_attr_::chunk_size;
    static constexpr uint32_t n_chunks = sg_tile_n / chunk_size;
    __XETLA_API void init([[maybe_unused]] arguments_t *args,
            [[maybe_unused]] uint32_t wg_idx, [[maybe_unused]] uint32_t wg_idy,
            [[maybe_unused]] uint32_t sg_idx, [[maybe_unused]] uint32_t sg_idy,
            [[maybe_unused]] uint32_t start_m) {}
    // ... (pre_op / post_op simply return their input in this variant)
};
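// Illustrative only: a compile-time attribute pack providing the members this
// header reads. The real layer_norm_attr_ type comes from the layer-norm
// kernel API headers; the values below are made up for the example.
struct layer_norm_attr_example {
    static constexpr uint32_t wg_tile_m = 4;    // rows per workgroup
    static constexpr uint32_t wg_tile_n = 512;  // cols per workgroup
    static constexpr uint32_t sg_tile_m = 4;    // rows per subgroup
    static constexpr uint32_t sg_tile_n = 128;  // cols per subgroup
    static constexpr uint32_t wg_num_m = 8;     // workgroups along m
    static constexpr uint32_t wg_num_n = 2;     // workgroups along n
    static constexpr uint32_t chunk_size = 32;  // cols processed per hook call
    // => n_chunks = sg_tile_n / chunk_size = 4
};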
template <typename dtype_in_, typename dtype_out_, typename dtype_acc_,
        typename layer_norm_attr_>
struct ln_fwd_fused_op_t<ln_fwd_fused_kind::bias_dropout_resAdd_ln, dtype_in_,
        dtype_out_, dtype_acc_, layer_norm_attr_, gpu_arch::Xe> {
    using dtype_acc = dtype_acc_;
    using dtype_in = dtype_in_;
    using dtype_out = dtype_out_;
    using dtype_mask = uint8_t;
    static constexpr uint32_t wg_tile_m = layer_norm_attr_::wg_tile_m;
    static constexpr uint32_t wg_tile_n = layer_norm_attr_::wg_tile_n;
    static constexpr uint32_t sg_tile_m = layer_norm_attr_::sg_tile_m;
    static constexpr uint32_t sg_tile_n = layer_norm_attr_::sg_tile_n;
    static constexpr uint32_t wg_num_m = layer_norm_attr_::wg_num_m;
    static constexpr uint32_t wg_num_n = layer_norm_attr_::wg_num_n;
    static constexpr uint32_t chunk_size = layer_norm_attr_::chunk_size;
    static constexpr uint32_t n_chunks = sg_tile_n / chunk_size;
    static constexpr uint32_t wg_size_x = (wg_tile_n + sg_tile_n - 1) / sg_tile_n;
    static constexpr uint32_t wg_size_y = (wg_tile_m + sg_tile_m - 1) / sg_tile_m;
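    // wg_size_x / wg_size_y are the subgroup counts per workgroup in each
    // dimension, computed as ceiling divisions: ceil(wg_tile_n / sg_tile_n)
    // and ceil(wg_tile_m / sg_tile_m).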
    static_assert(
            sg_tile_n % (sizeof(uint32_t) / sizeof(dtype_mask)) == 0,
            "sg_tile_n needs to be DW aligned");
    __XETLA_API void init(arguments_t *args, uint32_t wg_idx,
            [[maybe_unused]] uint32_t wg_idy, uint32_t sg_idx,
            [[maybe_unused]] uint32_t sg_idy, uint32_t start_m) {
        int start_n = wg_idx * wg_tile_n + sg_idx * sg_tile_n;
        // ... (matrix sizes, strides and dropout_prob/scale copied from args)
        bias_in_payload.init(args->bias_ptr, matrix_n, 1, mat_ld, start_n, 0);
        res_in_payload.init(args->res_add_ptr, matrix_n, matrix_m, mat_ld,
                start_n, start_m);
        bias_dropout_res_out_payload.init(args->bias_dropout_res_ptr,
                matrix_n, matrix_m, mat_ld, start_n, start_m);
        mask_in_payload.init(
                args->mask_ptr, matrix_n, matrix_m, mask_ld, start_n, start_m);
        if constexpr (n_chunks == 1) {
            // ...
        }
        // ...
    }
    __XETLA_API xetla_vector<dtype_acc, chunk_size> pre_op(
            xetla_vector<dtype_acc, chunk_size> input) {
        subgroup::tile_load(res_in, res_in_payload);
        if constexpr (n_chunks == 1) {
            res_in_payload.update_tdesc(wg_num_m * wg_tile_m * mat_ld);
        } else {
            res_in_payload.update_tdesc(chunk_size);
        }
        subgroup::tile_load(bias_in, bias_in_payload);
        if constexpr (n_chunks != 1) {
            bias_in_payload.update_tdesc(chunk_size);
        }
        xetla_vector<dtype_acc, chunk_size> bias_input
                = xetla_cvt<dtype_acc, dtype_in>(bias_in.reg);
        // Add the bias to the incoming chunk.
        xetla_vector<dtype_acc, chunk_size> output
                = reduce_helper<reduce_op::sum, dtype_acc, chunk_size>(
                        bias_input, input);
        if (dropout_prob != 0) {
            // Apply the precomputed dropout mask loaded from memory.
            subgroup::tile_load(mask_in, mask_in_payload);
            if constexpr (n_chunks == 1) {
                mask_in_payload.update_tdesc(wg_num_m * wg_tile_m * mask_ld);
            } else {
                mask_in_payload.update_tdesc(chunk_size);
            }
            output = drop_out<dtype_acc, chunk_size>(
                    output, mask_in.reg, dropout_scale);
        }
        // Residual add, then store the fused result for later use.
        xetla_vector<dtype_acc, chunk_size> res_input
                = xetla_cvt<dtype_acc, dtype_in>(res_in.reg);
        output = reduce_helper<reduce_op::sum, dtype_acc, chunk_size>(
                res_input, output);
        bias_dropout_res_out.reg = xetla_cvt<dtype_out, dtype_acc>(output);
        subgroup::tile_store<cache_hint::uncached>(
                bias_dropout_res_out, bias_dropout_res_out_payload);
        if constexpr (n_chunks == 1) {
            bias_dropout_res_out_payload.update_tdesc(
                    wg_num_m * wg_tile_m * mat_ld);
        } else {
            bias_dropout_res_out_payload.update_tdesc(chunk_size);
        }
        return output;
    }
    // ...
};
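// Advance pattern, for intuition: with a single chunk per subgroup tile
// (n_chunks == 1) each payload jumps a full grid stride of rows,
// wg_num_m * wg_tile_m * mat_ld elements (e.g. 8 * 4 * 1024 = 32768 with the
// sketch attributes above and mat_ld = 1024); otherwise it slides chunk_size
// columns to the next chunk of the same rows.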
template <typename dtype_in_, typename dtype_out_, typename dtype_acc_,
        typename layer_norm_attr_>
struct ln_fwd_fused_op_t<ln_fwd_fused_kind::ln_dropout, dtype_in_, dtype_out_,
        dtype_acc_, layer_norm_attr_, gpu_arch::Xe> {
    using dtype_acc = dtype_acc_;
    using dtype_in = dtype_in_;
    using dtype_out = dtype_out_;
    using dtype_mask = uint8_t;
    static constexpr uint32_t wg_tile_m = layer_norm_attr_::wg_tile_m;
    static constexpr uint32_t wg_tile_n = layer_norm_attr_::wg_tile_n;
    static constexpr uint32_t sg_tile_m = layer_norm_attr_::sg_tile_m;
    static constexpr uint32_t sg_tile_n = layer_norm_attr_::sg_tile_n;
    static constexpr uint32_t wg_num_m = layer_norm_attr_::wg_num_m;
    static constexpr uint32_t wg_num_n = layer_norm_attr_::wg_num_n;
    static constexpr uint32_t chunk_size = layer_norm_attr_::chunk_size;
    static constexpr uint32_t n_chunks = sg_tile_n / chunk_size;
    static constexpr uint32_t wg_size_x = (wg_tile_n + sg_tile_n - 1) / sg_tile_n;
    static constexpr uint32_t wg_size_y = (wg_tile_m + sg_tile_m - 1) / sg_tile_m;
    static_assert(
            sg_tile_n % (sizeof(uint32_t) / sizeof(dtype_mask)) == 0,
            "sg_tile_n needs to be DW aligned");
    __XETLA_API void init(arguments_t *args, uint32_t wg_idx,
            [[maybe_unused]] uint32_t wg_idy, uint32_t sg_idx,
            [[maybe_unused]] uint32_t sg_idy, uint32_t start_m) {
        int start_n = wg_idx * wg_tile_n + sg_idx * sg_tile_n;
        // ... (matrix sizes, strides and dropout_scale copied from args)
        mask_in_payload.init(
                args->mask_ptr, matrix_n, matrix_m, mask_ld, start_n, start_m);
        // ...
    }
    __XETLA_API xetla_vector<dtype_acc, chunk_size> post_op(
            xetla_vector<dtype_acc, chunk_size> input) {
        // Apply a precomputed dropout mask to the layer-norm output.
        subgroup::tile_load(mask_in, mask_in_payload);
        if constexpr (n_chunks == 1) {
            mask_in_payload.update_tdesc(wg_num_m * wg_tile_m * mask_ld);
        } else {
            mask_in_payload.update_tdesc(chunk_size);
        }
        xetla_vector<dtype_acc, chunk_size> output = drop_out<dtype_acc,
                chunk_size>(input, mask_in.reg, dropout_scale);
        return output;
    }
};
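// Scalar model of the mask-driven drop_out<> used above, for intuition only;
// drop_out_ref is a hypothetical helper, not part of the library, and it
// assumes a nonzero mask byte marks a dropped element:
inline float drop_out_ref(float x, uint8_t mask, float dropout_scale) {
    // Surviving lanes are rescaled so the expected value is preserved
    // (dropout_scale is typically 1 / (1 - dropout_prob)).
    return mask ? 0.0f : x * dropout_scale;
}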
template <typename dtype_in_, typename dtype_out_, typename dtype_acc_,
        typename layer_norm_attr_>
struct ln_fwd_fused_op_t<ln_fwd_fused_kind::bias_rng_dropout_resAdd_ln,
        dtype_in_, dtype_out_, dtype_acc_, layer_norm_attr_, gpu_arch::Xe> {
    using dtype_acc = dtype_acc_;
    using dtype_in = dtype_in_;
    using dtype_out = dtype_out_;
    using dtype_mask = uint8_t;
    static constexpr uint32_t wg_tile_m = layer_norm_attr_::wg_tile_m;
    static constexpr uint32_t wg_tile_n = layer_norm_attr_::wg_tile_n;
    static constexpr uint32_t sg_tile_m = layer_norm_attr_::sg_tile_m;
    static constexpr uint32_t sg_tile_n = layer_norm_attr_::sg_tile_n;
    static constexpr uint32_t wg_num_m = layer_norm_attr_::wg_num_m;
    static constexpr uint32_t wg_num_n = layer_norm_attr_::wg_num_n;
    static constexpr uint32_t chunk_size = layer_norm_attr_::chunk_size;
    static constexpr uint32_t n_chunks = sg_tile_n / chunk_size;
    static constexpr uint32_t wg_size_x = (wg_tile_n + sg_tile_n - 1) / sg_tile_n;
    static constexpr uint32_t wg_size_y = (wg_tile_m + sg_tile_m - 1) / sg_tile_m;
    __XETLA_API void init(arguments_t *args, uint32_t wg_idx, uint32_t wg_idy,
            uint32_t sg_idx, uint32_t sg_idy, uint32_t start_m) {
        int start_n = wg_idx * wg_tile_n + sg_idx * sg_tile_n;
        // ... (matrix sizes/strides from args; rand_offset_ptr_v is loaded
        // from args->rand_offset_ptr)
        uint32_t threshold
                = uint32_t(args->dropout_prob * float(4294967296)); // p * 2^32
        bias_in_payload.init(args->bias_ptr, matrix_n, 1, mat_ld, start_n, 0);
        res_in_payload.init(args->res_add_ptr, matrix_n, matrix_m, mat_ld,
                start_n, start_m);
        bias_dropout_res_out_payload.init(args->bias_dropout_res_ptr,
                matrix_n, matrix_m, mat_ld, start_n, start_m);
        mask_out_payload.init(
                args->mask_ptr, matrix_n, matrix_m, mask_ld, start_n, start_m);
        int linear_idx = (wg_idy * wg_size_y + sg_idy) * (wg_size_x * wg_num_n)
                + wg_idx * wg_size_x + sg_idx;
        dropout_fwd.init(args->rand_seed, linear_idx, rand_offset_ptr_v[0],
                threshold, args->dropout_scale);
        if constexpr (n_chunks == 1) {
            // ...
        }
        // ...
    }
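    // Threshold arithmetic: the 32-bit RNG draw r is uniform in [0, 2^32), so
    // dropping a lane when r < uint32_t(p * 2^32) drops with probability ~p.
    // E.g. p = 0.1 gives threshold = uint32_t(0.1 * 4294967296.0) ~= 429496729.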
    __XETLA_API xetla_vector<dtype_acc, chunk_size> pre_op(
            xetla_vector<dtype_acc, chunk_size> input) {
        subgroup::tile_load(res_in, res_in_payload);
        if constexpr (n_chunks == 1) {
            res_in_payload.update_tdesc(wg_num_m * wg_tile_m * mat_ld);
        } else {
            res_in_payload.update_tdesc(chunk_size);
        }
        subgroup::tile_load(bias_in, bias_in_payload);
        if constexpr (n_chunks != 1) {
            bias_in_payload.update_tdesc(chunk_size);
        }
        xetla_vector<dtype_acc, chunk_size> bias_input
                = xetla_cvt<dtype_acc, dtype_in>(bias_in.reg);
        xetla_vector<dtype_acc, chunk_size> output
                = reduce_helper<reduce_op::sum, dtype_acc, chunk_size>(
                        bias_input, input);
        if (dropout_prob != 0) {
            // RNG-driven dropout: generate the mask on the fly and persist it.
            output = dropout_fwd.template process<dtype_acc>(output);
            mask_out.reg = dropout_fwd.get_mask();
            subgroup::tile_store<cache_hint::uncached>(
                    mask_out, mask_out_payload);
            if constexpr (n_chunks == 1) {
                mask_out_payload.update_tdesc(wg_num_m * wg_tile_m * mask_ld);
            } else {
                mask_out_payload.update_tdesc(chunk_size);
            }
        }
        xetla_vector<dtype_acc, chunk_size> res_input
                = xetla_cvt<dtype_acc, dtype_in>(res_in.reg);
        output = reduce_helper<reduce_op::sum, dtype_acc, chunk_size>(
                res_input, output);
        bias_dropout_res_out.reg = xetla_cvt<dtype_out, dtype_acc>(output);
        subgroup::tile_store<cache_hint::uncached>(
                bias_dropout_res_out, bias_dropout_res_out_payload);
        if constexpr (n_chunks == 1) {
            bias_dropout_res_out_payload.update_tdesc(
                    wg_num_m * wg_tile_m * mat_ld);
        } else {
            bias_dropout_res_out_payload.update_tdesc(chunk_size);
        }
        return output;
    }
    // ...
};
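// Unlike the bias_dropout_resAdd_ln path, which consumes a precomputed mask,
// this variant derives the dropout decisions from (rand_seed, linear_idx,
// rand_offset) and stores the resulting mask through mask_out so a backward
// pass can replay the same pattern.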
template <typename dtype_in_, typename dtype_out_, typename dtype_acc_,
        typename layer_norm_attr_>
struct ln_fwd_fused_op_t<ln_fwd_fused_kind::ln_rng_dropout, dtype_in_,
        dtype_out_, dtype_acc_, layer_norm_attr_, gpu_arch::Xe> {
    using dtype_acc = dtype_acc_;
    using dtype_in = dtype_in_;
    using dtype_out = dtype_out_;
    using dtype_mask = uint8_t;
    static constexpr uint32_t wg_tile_m = layer_norm_attr_::wg_tile_m;
    static constexpr uint32_t wg_tile_n = layer_norm_attr_::wg_tile_n;
    static constexpr uint32_t sg_tile_m = layer_norm_attr_::sg_tile_m;
    static constexpr uint32_t sg_tile_n = layer_norm_attr_::sg_tile_n;
    static constexpr uint32_t wg_num_m = layer_norm_attr_::wg_num_m;
    static constexpr uint32_t wg_num_n = layer_norm_attr_::wg_num_n;
    static constexpr uint32_t chunk_size = layer_norm_attr_::chunk_size;
    static constexpr uint32_t n_chunks = sg_tile_n / chunk_size;
    static_assert(sg_tile_n % chunk_size == 0,
            "Current impl does not support a tail (remainder) mechanism");
    static constexpr uint32_t wg_size_x = (wg_tile_n + sg_tile_n - 1) / sg_tile_n;
    static constexpr uint32_t wg_size_y = (wg_tile_m + sg_tile_m - 1) / sg_tile_m;
    __XETLA_API void init(arguments_t *args, uint32_t wg_idx, uint32_t wg_idy,
            uint32_t sg_idx, uint32_t sg_idy, uint32_t start_m) {
        int start_n = wg_idx * wg_tile_n + sg_idx * sg_tile_n;
        // ... (matrix sizes/strides from args; rand_offset_ptr_v is loaded
        // from args->rand_offset_ptr)
        uint32_t threshold
                = uint32_t(args->dropout_prob * float(4294967296)); // p * 2^32
        mask_out_payload.init(
                args->mask_ptr, matrix_n, matrix_m, mask_ld, start_n, start_m);
        int linear_idx = (wg_idy * wg_size_y + sg_idy) * (wg_size_x * wg_num_n)
                + wg_idx * wg_size_x + sg_idx;
        dropout_fwd.init(args->rand_seed, linear_idx, rand_offset_ptr_v[0],
                threshold, args->dropout_scale);
        // ...
    }
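    // linear_idx enumerates subgroups uniquely across the whole dispatch so
    // each one gets its own RNG subsequence. E.g. with wg_size_x = 4,
    // wg_size_y = 1, wg_num_n = 2: subgroup sg_idx = 2 in workgroup
    // (wg_idx = 1, wg_idy = 0) gets linear_idx = 0 * 8 + 1 * 4 + 2 = 6.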
    __XETLA_API xetla_vector<dtype_acc, chunk_size> post_op(
            xetla_vector<dtype_acc, chunk_size> input) {
        // RNG dropout applied after layer-norm; the generated mask is stored.
        xetla_vector<dtype_acc, chunk_size> output
                = dropout_fwd.template process<dtype_acc>(input);
        mask_out.reg = dropout_fwd.get_mask();
        subgroup::tile_store<cache_hint::uncached>(mask_out, mask_out_payload);
        if constexpr (n_chunks == 1) {
            mask_out_payload.update_tdesc(wg_num_m * wg_tile_m * mask_ld);
        } else {
            mask_out_payload.update_tdesc(chunk_size);
        }
        return output;
    }
};
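// How a layer-norm kernel is expected to drive these hooks, as a sketch only
// (simplified control flow; load_chunk / layer_norm / store_chunk stand in
// for the real tile pipeline and are not library calls):
//
//   fused_op.init(&args, wg_idx, wg_idy, sg_idx, sg_idy, start_m);
//   for (/* each row block assigned to this subgroup */;;) {
//       for (uint32_t i = 0; i < n_chunks; ++i) {
//           auto x = load_chunk(i);    // xetla_vector<dtype_acc, chunk_size>
//           x = fused_op.pre_op(x);    // bias / dropout / residual add
//           x = layer_norm(x);         // normalize, apply gamma/beta
//           x = fused_op.post_op(x);   // e.g. dropout after layer-norm
//           store_chunk(i, x);
//       }
//   }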