// Argument pack for the fused layer-norm backward ops (struct body elided):
// pointers dbias_acc_ptr, dx_resAdd_ptr, gradAdd_ptr and mask_ptr, the matrix
// shape matrix_m x matrix_n, leading dimensions mat_ld and mask_ld, and the
// dropout parameters dropout_prob and dropout_scale_inv.
template <typename dtype_in, typename dtype_out, typename dtype_acc>
// ...
    typename dtype_out_, typename dtype_acc_, typename layer_norm_attr_>
  // ...
  static constexpr uint32_t wg_tile_m = layer_norm_attr_::wg_tile_m;
  static constexpr uint32_t wg_tile_n = layer_norm_attr_::wg_tile_n;
  static constexpr uint32_t sg_tile_m = layer_norm_attr_::sg_tile_m;
  static constexpr uint32_t sg_tile_n = layer_norm_attr_::sg_tile_n;
  static constexpr uint32_t wg_num_m = layer_norm_attr_::wg_num_m;
  static constexpr uint32_t wg_num_n = layer_norm_attr_::wg_num_n;
  // ...
  // init() ignores its arguments and does nothing in this variant.
  __XETLA_API void init([[maybe_unused]] arguments_t *args,
      [[maybe_unused]] uint32_t wg_idx, [[maybe_unused]] uint32_t wg_idy,
      [[maybe_unused]] uint32_t sg_idx, [[maybe_unused]] uint32_t sg_idy) {}
  // ... (pre_op and post_op hooks elided)
  template <typename reduce_t>
  // ...
// First gpu_arch::Xe partial specialization (class name and leading template
// arguments elided):
template <typename dtype_in_, typename dtype_out_, typename dtype_acc_,
    typename layer_norm_attr_>
// ...
    dtype_out_, dtype_acc_, layer_norm_attr_, gpu_arch::Xe> {
  // ... (dtype_in / dtype_out / dtype_acc aliases elided; dtype_mask is uint8_t)
  static constexpr uint32_t wg_tile_m = layer_norm_attr_::wg_tile_m;
  static constexpr uint32_t wg_tile_n = layer_norm_attr_::wg_tile_n;
  static constexpr uint32_t sg_tile_m = layer_norm_attr_::sg_tile_m;
  static constexpr uint32_t sg_tile_n = layer_norm_attr_::sg_tile_n;
  static constexpr uint32_t wg_num_m = layer_norm_attr_::wg_num_m;
  static constexpr uint32_t wg_num_n = layer_norm_attr_::wg_num_n;
  // Number of subgroups per work-group along n and m (ceiling division).
  static constexpr uint32_t wg_size_x = (wg_tile_n + sg_tile_n - 1) / sg_tile_n;
  static constexpr uint32_t wg_size_y = (wg_tile_m + sg_tile_m - 1) / sg_tile_m;
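The ceiling division rounds the subgroup count up so that a partial tile at the edge of the work-group tile still gets a subgroup. With illustrative sizes (not taken from this listing) wg_tile_n = 100 and sg_tile_n = 32, wg_size_x = (100 + 31) / 32 = 4, whereas plain integer division would give 3 and leave the last four columns uncovered.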
  static_assert(
      (sg_tile_n % (sizeof(uint32_t) / sizeof(dtype_mask)) == 0),
      "sg_tile_n need to be DW aligned");
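Since dtype_mask is uint8_t in these specializations, sizeof(uint32_t) / sizeof(dtype_mask) equals 4, so the assert requires sg_tile_n to be a multiple of 4: each subgroup row of mask bytes must fill whole 32-bit words (DWs). A minimal stand-alone sketch of the same check; the helper name and the example tile size are illustrative, not taken from this header:

#include <cstdint>

// Illustrative helper: true when a row of n mask elements fills whole DWs.
template <typename mask_t>
constexpr bool dw_aligned(std::uint32_t n) {
  return n % (sizeof(std::uint32_t) / sizeof(mask_t)) == 0;
}

static_assert(dw_aligned<std::uint8_t>(128), "128 uint8_t mask bytes = 32 whole DWs");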
  // ... (tile, payload, stride, dbias and dropout members elided)
  __XETLA_API void init(arguments_t *args, uint32_t wg_idx, uint32_t wg_idy,
      uint32_t sg_idx, uint32_t sg_idy) {
    int start_n = wg_idx * wg_tile_n + sg_idx * sg_tile_n;
    int start_m = wg_idy * wg_tile_m + sg_idy * sg_tile_m;
    // ...
    dx_resAdd_out_payload.init(
        args->dx_resAdd_ptr, matrix_n, matrix_m, mat_ld, start_n, start_m);
    mask_in_payload.init(
        args->mask_ptr, matrix_n, matrix_m, mask_ld, start_n, start_m);
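The start offsets place each subgroup's tile inside its work-group's tile before the payloads are bound to global memory. With illustrative sizes (not taken from this listing) wg_tile_n = 256 and sg_tile_n = 64, the subgroup at wg_idx = 1, sg_idx = 2 begins at column start_n = 1 * 256 + 2 * 64 = 384, and both payloads are initialized to address their 2D blocks at exactly that (start_n, start_m) offset.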
  // ...
    // Store the residual-add result in the output dtype, then advance the
    // store descriptor past the wg_num_m * wg_tile_m rows covered by one pass
    // of the work-group grid.
    dx_resAdd_out.reg = xetla_cvt<dtype_out, dtype_acc>(input);
    subgroup::tile_store<cache_hint::uncached>(
        dx_resAdd_out, dx_resAdd_out_payload);
    dx_resAdd_out_payload.update_tdesc(wg_num_m * wg_tile_m * mat_ld);
    // When dropout is enabled, advance the mask descriptor the same way and
    // rescale the surviving elements of the running output.
    if (dropout_prob != 0) {
      // ...
      mask_in_payload.update_tdesc(wg_num_m * wg_tile_m * mask_ld);
      output = drop_out<dtype_acc, sg_tile_n>(
          output, mask_in.reg, dropout_scale_inv);
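A minimal scalar sketch of the rescaling this drop_out call is expected to perform; the element-wise form and the mask polarity (nonzero byte = dropped element) are assumptions rather than details taken from this header, and the real helper operates on whole xetla_vector registers:

#include <cstdint>

// Illustrative only: zero out dropped elements and rescale the survivors by
// dropout_scale_inv, which is typically 1 / (1 - dropout_prob).
inline float apply_dropout(float value, std::uint8_t mask_byte,
                           float dropout_scale_inv) {
  return (mask_byte != 0) ? 0.0f : value * dropout_scale_inv;
}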
  // ...
  // Final reduction step: combine the per-work-group partial dbias rows.
  template <typename reduce_t>
  __XETLA_API void final_op(reduce_t &ln_group_row_reduce) {
    ln_group_row_reduce(
        dbias_acc_ptr, matrix_n, wg_num_m, matrix_n, dbias_n, dbias_m, dbias);
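Conceptually, dbias is the column-wise sum of the incoming gradient, and each of the wg_num_m work-group rows contributes one partial row of width matrix_n. A minimal serial sketch of that final combination, assuming the partial rows are stored back-to-back at dbias_acc_ptr (the function and parameter names below are illustrative, not the library's):

#include <cstdint>

// Illustrative only: sum num_partials partial rows of width row_len
// (stored contiguously at partials) into one final dbias row.
inline void combine_partial_dbias(const float *partials, float *dbias,
                                  std::uint32_t num_partials,
                                  std::uint32_t row_len) {
  for (std::uint32_t n = 0; n < row_len; ++n) {
    float acc = 0.0f;
    for (std::uint32_t m = 0; m < num_partials; ++m) {
      acc += partials[m * row_len + n];
    }
    dbias[n] = acc;
  }
}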
// Second gpu_arch::Xe partial specialization (class name and leading template
// arguments elided):
template <typename dtype_in_, typename dtype_out_, typename dtype_acc_,
    typename layer_norm_attr_>
// ...
    dtype_out_, dtype_acc_, layer_norm_attr_, gpu_arch::Xe> {
  // ...
  static constexpr uint32_t wg_tile_m = layer_norm_attr_::wg_tile_m;
  static constexpr uint32_t wg_tile_n = layer_norm_attr_::wg_tile_n;
  static constexpr uint32_t sg_tile_m = layer_norm_attr_::sg_tile_m;
  static constexpr uint32_t sg_tile_n = layer_norm_attr_::sg_tile_n;
  static constexpr uint32_t wg_num_m = layer_norm_attr_::wg_num_m;
  static constexpr uint32_t wg_num_n = layer_norm_attr_::wg_num_n;
  static constexpr uint32_t wg_size_x = (wg_tile_n + sg_tile_n - 1) / sg_tile_n;
  static constexpr uint32_t wg_size_y = (wg_tile_m + sg_tile_m - 1) / sg_tile_m;
  static_assert(
      (sg_tile_n % (sizeof(uint32_t) / sizeof(dtype_mask)) == 0),
      "sg_tile_n need to be DW aligned");
  // ...
  __XETLA_API void init(arguments_t *args, uint32_t wg_idx, uint32_t wg_idy,
      uint32_t sg_idx, uint32_t sg_idy) {
    int start_n = wg_idx * wg_tile_n + sg_idx * sg_tile_n;
    int start_m = wg_idy * wg_tile_m + sg_idy * sg_tile_m;
    // ...
    grad_in_payload.init(
        args->gradAdd_ptr, matrix_n, matrix_m, mat_ld, start_n, start_m);
    mask_in_payload.init(
        args->mask_ptr, matrix_n, matrix_m, mask_ld, start_n, start_m);
    // ...
    grad_in_payload.update_tdesc(wg_num_m * wg_tile_m * mat_ld);
    // Convert the loaded gradient tile to the accumulation type and fold it
    // into a running sum via reduce_helper (assignment targets elided):
    // ...
        = xetla_cvt<dtype_acc, dtype_in>(grad_in.reg);
    // ...
        = reduce_helper<reduce_op::sum, dtype_acc, sg_tile_n>(
    // ...
    if (dropout_prob != 0) {
      // ...
      mask_in_payload.update_tdesc(wg_num_m * wg_tile_m * mask_ld);
      output = drop_out<dtype_acc, sg_tile_n>(
          output, mask_in.reg, dropout_scale_inv);
  // ...
  template <typename reduce_t>
  // ...
template <typename dtype_in_, typename dtype_out_, typename dtype_acc_,
    typename layer_norm_attr_>
// ...
  static constexpr uint32_t wg_tile_m = layer_norm_attr_::wg_tile_m;
  static constexpr uint32_t wg_tile_n = layer_norm_attr_::wg_tile_n;
  static constexpr uint32_t sg_tile_m = layer_norm_attr_::sg_tile_m;
  static constexpr uint32_t sg_tile_n = layer_norm_attr_::sg_tile_n;
  static constexpr uint32_t wg_num_m = layer_norm_attr_::wg_num_m;
  static constexpr uint32_t wg_num_n = layer_norm_attr_::wg_num_n;
  static constexpr uint32_t wg_size_x = (wg_tile_n + sg_tile_n - 1) / sg_tile_n;
  static constexpr uint32_t wg_size_y = (wg_tile_m + sg_tile_m - 1) / sg_tile_m;
  // ...
  __XETLA_API void init(arguments_t *args, uint32_t wg_idx, uint32_t wg_idy,
      uint32_t sg_idx, uint32_t sg_idy) {
    int start_m = wg_idy * wg_tile_m + sg_idy * sg_tile_m;
    int start_n = wg_idx * wg_tile_n + sg_idx * sg_tile_n;
    // ...
    mask_in_payload.init(
        args->mask_ptr, matrix_n, matrix_m, mask_ld, start_n, start_m);
    // ...
    mask_in_payload.update_tdesc(wg_num_m * wg_tile_m * mask_ld);
    output = drop_out<dtype_acc, sg_tile_n>(
        input, mask_in.reg, dropout_scale_inv);
  // ...
  template <typename reduce_t>
  // ...