// Header of a partial specialization for the gpu_arch::Xe target, followed by
// the compile-time tiling constants derived from layer_norm_attr.
// NOTE(review): the listing line that names the specialized template is not
// part of this excerpt, and the bare integers at the start of lines are line
// numbers carried over from the listing this text was extracted from.
37template <
typename dtype_x_,
typename dtype_y_,
typename dtype_weight_,
38 typename dtype_acc_,
typename layer_norm_attr_,
39 typename ln_bwd_fused_op_>
41 layer_norm_attr_,
gpu_arch::
Xe, ln_bwd_fused_op_> {
// Tile extents and work-group counts are forwarded unchanged from the
// attribute type.
49 static constexpr uint32_t wg_tile_m = layer_norm_attr::wg_tile_m;
50 static constexpr uint32_t wg_tile_n = layer_norm_attr::wg_tile_n;
51 static constexpr uint32_t sg_tile_m = layer_norm_attr::sg_tile_m;
52 static constexpr uint32_t sg_tile_n = layer_norm_attr::sg_tile_n;
53 static constexpr uint32_t wg_num_m = layer_norm_attr::wg_num_m;
54 static constexpr uint32_t wg_num_n = layer_norm_attr::wg_num_n;
// wg_size_x = ceil(wg_tile_n / sg_tile_n): number of subgroup tiles needed to
// cover the work-group tile along n.
56 static constexpr uint32_t wg_size_x
57 = (wg_tile_n + sg_tile_n - 1) / sg_tile_n;
// wg_size_y = ceil(wg_tile_m / sg_tile_m): the same along m.
58 static constexpr uint32_t wg_size_y
59 = (wg_tile_m + sg_tile_m - 1) / sg_tile_m;
// Requires wg_size_x to be at most 32 AND to have no more than one bit set
// (the `x & (x - 1)` test). The message only mentions the first condition.
61 static_assert((wg_size_x <= 32) && ((wg_size_x & (wg_size_x - 1)) == 0),
62 "Current only support wg_size_x <=32");
// wg_size_y when there is more than one subgroup along x, otherwise 0.
63 static constexpr uint32_t count_col_reduce
64 = (wg_size_x > 1) ? wg_size_y : 0;
// wg_size_x when there is more than one subgroup along y, otherwise 0.
65 static constexpr uint32_t count_row_reduce
66 = (wg_size_y > 1) ? wg_size_x : 0;
// Exposes the sum of the two reduce counts above as `count`.
67 struct get_barrier_count {
68 static constexpr uint32_t count = count_col_reduce + count_row_reduce;
// Byte sizes (each product ends in sizeof(dtype_acc)) that are summed into
// `size` below. The enclosing struct header and the false branch of each
// conditional are not visible in this excerpt.
71 static constexpr uint32_t size_col_reduce = (wg_size_x > 1)
72 ? wg_size_x * wg_size_y * 4 *
sizeof(
dtype_acc)
75 static constexpr uint32_t size_row_reduce = (wg_size_y > 1)
76 ? wg_size_y * wg_size_x * sg_tile_n *
sizeof(
dtype_acc)
79 static constexpr uint32_t size = size_col_reduce + size_row_reduce;
92 subgroup::msg_type_v<ln_bwd_tile_desc_t, mem_space::global>,
97 subgroup::msg_type_v<ln_bwd_tile_desc_t, mem_space::global>,
102 subgroup::msg_type_v<ln_bwd_tile_desc_t, mem_space::global>,
// Helper that wraps a `group_reduce` member and alternates it between two
// base offsets (slm_base_0 / slm_base_1) from one call to the next.
// NOTE(review): the remaining template parameters, the member declarations
// and the header of the function containing listing lines 172-173 are not
// visible in this excerpt.
141 template <
typename T, uint32_t SZ, uint32_t N,
reduce_op Op,
142 uint32_t wg_size_x, uint32_t wg_size_y,
144 struct ln_group_all_reduce_t {
// Constructor: derives both base offsets from slm_base and initialises
// group_reduce with the first one.
158 inline ln_group_all_reduce_t(uint32_t sg_idx = 0, uint32_t sg_idy = 0,
159 uint32_t slm_base = 0, uint32_t nbarrier_base = 0) {
// First base: slm_base advanced by sg_idy * wg_size_x * N * sizeof(T) bytes,
// so each sg_idy value gets its own offset.
160 slm_base_0 = slm_base + sg_idy * wg_size_x * N *
sizeof(T);
// Second base: the first one advanced by wg_size_x * wg_size_y * N * sizeof(T)
// bytes.
161 slm_base_1 = slm_base_0 + wg_size_x * wg_size_y * N *
sizeof(T);
// The second argument passed to init is sg_idy shifted by nbarrier_base.
163 group_reduce.init(sg_idx, sg_idy + nbarrier_base, slm_base_0);
// Odd itr_count selects slm_base_1, even selects slm_base_0. This local
// `slm_base` belongs to a different function body than the constructor
// parameter of the same name.
172 uint32_t slm_base = (itr_count & 1) ? slm_base_1 : slm_base_0;
173 group_reduce.set_slm_base(slm_base);
/// Computes rs * (x - mu) element-wise on a vector of sg_tile_n elements,
/// after converting x from dtype_x to dtype_acc.
/// @tparam SIMD Number of elements processed per loop iteration; defaults to
///         64 / sizeof(dtype_acc).
/// @param x  Input vector of sg_tile_n dtype_x elements.
/// @param rs Scalar multiplied into every centred element.
/// @param mu Scalar subtracted from every element before scaling.
/// @return Declared as xetla_vector<dtype_acc, sg_tile_n>. The return
///         statement is not visible in this excerpt (presumably x_temp --
///         TODO confirm against the full header).
187 template <uint32_t
SIMD = 64 /
sizeof(dtype_acc)>
188 __XETLA_API static xetla_vector<dtype_acc, sg_tile_n> get_x_temp(
189 xetla_vector<dtype_x, sg_tile_n> x, dtype_acc rs, dtype_acc mu) {
190 xetla_vector<dtype_acc, sg_tile_n> x_temp;
// Convert the whole input to the accumulation type once, up front.
191 xetla_vector<dtype_acc, sg_tile_n> x_acc
192 = xetla_cvt<dtype_acc, dtype_x>(x);
// Full chunks: iteration i covers elements [i * SIMD, (i + 1) * SIMD).
195 for (uint32_t i = 0; i < sg_tile_n /
SIMD; i++) {
196 x_temp.xetla_select<
SIMD, 1>(i *
SIMD)
197 = rs * (x_acc.xetla_select<
SIMD, 1>(i *
SIMD) - mu);
// Tail: only compiled when sg_tile_n is not a multiple of SIMD. `start` is the
// first element after the last full chunk and SIMD_tail is the number of
// elements left over.
199 if constexpr ((sg_tile_n %
SIMD) != 0) {
200 constexpr uint32_t start = sg_tile_n /
SIMD *
SIMD;
201 constexpr uint32_t SIMD_tail = sg_tile_n %
SIMD;
202 x_temp.xetla_select<SIMD_tail, 1>(start)
203 = rs * (x_acc.xetla_select<SIMD_tail, 1>(start) - mu);
/// Computes gamma * dy element-wise on a vector of sg_tile_n elements, after
/// converting gamma from dtype_weight to dtype_acc.
/// @tparam SIMD Number of elements processed per loop iteration; defaults to
///         64 / sizeof(dtype_acc).
/// @param gamma Vector of sg_tile_n dtype_weight elements.
/// @param dy    Vector of sg_tile_n dtype_acc elements.
/// @return Declared as xetla_vector<dtype_acc, sg_tile_n>. The return
///         statement is not visible in this excerpt (presumably dy_temp --
///         TODO confirm against the full header).
214 template <uint32_t
SIMD = 64 /
sizeof(dtype_acc)>
215 __XETLA_API static xetla_vector<dtype_acc, sg_tile_n> get_dy_temp(
216 xetla_vector<dtype_weight, sg_tile_n> gamma,
217 xetla_vector<dtype_acc, sg_tile_n> dy) {
218 xetla_vector<dtype_acc, sg_tile_n> dy_temp;
// Convert gamma to the accumulation type once, up front.
219 xetla_vector<dtype_acc, sg_tile_n> gamma_acc
220 = xetla_cvt<dtype_acc, dtype_weight>(gamma);
// Full chunks: iteration i covers elements [i * SIMD, (i + 1) * SIMD).
223 for (uint32_t i = 0; i < sg_tile_n /
SIMD; i++) {
224 dy_temp.xetla_select<
SIMD, 1>(i *
SIMD)
225 = gamma_acc.xetla_select<
SIMD, 1>(i *
SIMD)
226 * dy.xetla_select<
SIMD, 1>(i *
SIMD);
// Tail: only compiled when sg_tile_n is not a multiple of SIMD. `start` is the
// first element after the last full chunk and SIMD_tail is the number of
// elements left over.
228 if constexpr ((sg_tile_n %
SIMD) != 0) {
229 constexpr uint32_t start = sg_tile_n /
SIMD *
SIMD;
230 constexpr uint32_t SIMD_tail = sg_tile_n %
SIMD;
231 dy_temp.xetla_select<SIMD_tail, 1>(start)
232 = gamma_acc.xetla_select<SIMD_tail, 1>(start)
233 * dy.xetla_select<SIMD_tail, 1>(start);
237 using wg_col_reduce_t = ln_group_all_reduce_t<dtype_acc, sg_tile_n, 2,
// Fragment of the kernel entry point: trailing part of the parameter list and
// selected statements of the body.
// NOTE(review): many listing lines are missing from this excerpt, including
// the declarations of g, the payload/tile objects, rs, mu, dgamma, dbeta,
// grad_0, grad_1 and buffer, plus the left-hand sides of several assignments.
// The comments below describe only what the visible statements do.
242 uint32_t slm_base = 0, uint32_t nbarrier_base = 0,
245 g.init(item.get_local_linear_id());
// Split the id reported by g into x/y coordinates, x varying fastest, with
// wg_size_x entries per row.
246 uint32_t sg_idx = g.get_id() % wg_size_x;
247 uint32_t sg_idy = g.get_id() / wg_size_x;
// Work-group coordinates: x from dimension 2 and y from dimension 1 of the
// nd_item.
248 uint32_t wg_idx = item.get_group(2);
249 uint32_t wg_idy = item.get_group(1);
// Starting offsets along n and m: the work-group tile origin plus this
// subgroup's offset inside it.
250 int start_n = wg_idx * wg_tile_n + sg_idx * sg_tile_n;
251 int start_m = wg_idy * wg_tile_m + sg_idy * sg_tile_m;
// x, dy and dx payloads share the same extents (matrix_n, matrix_m), leading
// dimension (mat_ld) and starting offsets. The gamma payload is initialised
// with an extent of 1 and an offset of 0 in the m position.
262 x_in_payload.init(args->x_in_ptr, args->matrix_n, args->matrix_m,
263 args->mat_ld, start_n, start_m);
264 dy_in_payload.init(args->dy_in_ptr, args->matrix_n, args->matrix_m,
265 args->mat_ld, start_n, start_m);
266 gamma_in_payload.init(args->gamma_in_ptr, args->matrix_n, 1,
267 args->mat_ld, start_n, 0);
268 dx_out_payload.init(args->dx_out_ptr, args->matrix_n, args->matrix_m,
269 args->mat_ld, start_n, start_m);
270 fused_op.init(fused_op_args, wg_idx, wg_idy, sg_idx, sg_idy);
// Reciprocal of the work-group tile width along n.
273 const dtype_acc wg_rn = 1.0f / wg_tile_n;
275 wg_col_reduce_t wg_col_reduce(sg_idx, sg_idy, slm_base, nbarrier_base);
// Row loop: starts at start_m and advances by wg_num_m * wg_tile_m per
// iteration until matrix_m is reached.
// NOTE(review): row is uint32_t but is initialised from the int start_m --
// confirm start_m can never be negative.
280 for (uint32_t row = start_m; row < args->matrix_m;
281 row += wg_num_m * wg_tile_m) {
// Advance both input descriptors by the loop stride scaled by mat_ld.
294 dy_in_payload.update_tdesc(wg_num_m * wg_tile_m * args->mat_ld);
295 x_in_payload.update_tdesc(wg_num_m * wg_tile_m * args->mat_ld);
// Convert the loaded dy registers to dtype_acc, then let the fused op
// transform them.
298 = xetla_cvt<dtype_acc, dtype_y>(dy_in.
reg);
299 dy = fused_op.pre_op(dy);
// get_x_temp evaluates rs * (x - mu); get_dy_temp evaluates gamma * dy.
301 = get_x_temp(x_in.
reg, rs, mu);
303 = get_dy_temp(gamma_in.
reg, dy);
// Accumulate dy * x_temp into dgamma on every loop iteration.
304 dgamma += dy * x_temp;
// View buffer as 2 rows of sg_tile_n dtype_acc elements: row 0 holds dy_temp,
// row 1 holds x_temp * dy_temp.
307 auto buffer_2d = buffer.xetla_format<
dtype_acc, 2, sg_tile_n>();
308 buffer_2d.row(0) = dy_temp;
309 buffer_2d.row(1) = x_temp * dy_temp;
// Right-hand side of an assignment whose target is not visible (presumably
// dx -- TODO confirm): rs * (dy_temp - (grad_1 * x_temp + grad_0)).
314 = rs * (dy_temp - (grad_1 * x_temp + grad_0));
315 dx = fused_op.post_op(dx);
// Convert back to dtype_x, store with the uncached hint, and advance the
// output descriptor by the same stride as the inputs.
316 dx_out.
reg = xetla_cvt<dtype_x, dtype_acc>(dx);
317 subgroup::tile_store<cache_hint::uncached>(dx_out, dx_out_payload);
318 dx_out_payload.update_tdesc(wg_num_m * wg_tile_m * args->mat_ld);
// The row-reduce bases are placed directly after the column-reduce share:
// size_col_reduce past slm_base and count_col_reduce past nbarrier_base.
322 uint32_t slm_row_reduce_base = slm_base + size_col_reduce;
323 uint32_t nbarrier_row_reduce_base = nbarrier_base + count_col_reduce;
324 ln_group_row_reduce.init(
325 sg_idx, sg_idy, slm_row_reduce_base, nbarrier_row_reduce_base);
// The same reducer is invoked twice with identical geometry arguments: first
// for dgamma, then for dbeta.
327 ln_group_row_reduce(args->dgamma_acc_ptr, args->matrix_n, wg_num_m,
328 args->matrix_n, start_n, wg_idy, dgamma);
329 ln_group_row_reduce(args->dbeta_acc_ptr, args->matrix_n, wg_num_m,
330 args->matrix_n, start_n, wg_idy, dbeta);
// Hand the reducer to the fused op for its final step.
331 fused_op.final_op(ln_group_row_reduce);
#define __XETLA_API
Definition common.hpp:43
#define SIMD
Definition gemm_softmax.cpp:23
__ESIMD_NS::simd< native_type_t< Ty >, N > xetla_vector
wrapper for xetla_vector.
Definition base_types.hpp:149
__XETLA_API xetla_vector< Ty, N *NElts > xetla_load_global(Ty *p, xetla_vector< Toffset, N > offsets, xetla_mask< N > pred=1)
Stateless scattered load.
Definition memory.hpp:245
#define KERNEL_FUNC
KERNEL_FUNC macro.
Definition common.hpp:39
Definition limitation.hpp:734
__XETLA_API std::enable_if_t< detail::check_load_type< tile_t, payload_t >::is_global_2d_xe > tile_load(tile_t &tile, payload_t &payload)
This function loads data from 2D memory surface.
Definition load_xe.hpp:76
reduce_op
xetla reduce op
Definition common.hpp:217
gpu_arch
Definition common.hpp:73
This is the group reduction.
Definition reduction_api.hpp:36
This is the group row reduction (reduce_sum) plus a cooperative write-out.
Definition reduction_api.hpp:39
dtype_acc_ dtype_acc
Definition layer_norm_bwd_xe.hpp:45
static __XETLA_API void call(sycl::nd_item< 3 > &item, arguments_t *args, uint32_t slm_base=0, uint32_t nbarrier_base=0, ln_fused_op_arguments_t *fused_op_args=nullptr)
Definition layer_norm_bwd_xe.hpp:241
dtype_y_ dtype_y
Definition layer_norm_bwd_xe.hpp:43
layer_norm_attr_ layer_norm_attr
Definition layer_norm_bwd_xe.hpp:46
dtype_weight_ dtype_weight
Definition layer_norm_bwd_xe.hpp:44
typename ln_bwd_fused_op::arguments_t ln_fused_op_arguments_t
Definition layer_norm_bwd_xe.hpp:48
work_group_t< wg_size_x *wg_size_y > work_group_t
Definition layer_norm_bwd_xe.hpp:60
ln_bwd_fused_op_ ln_bwd_fused_op
Definition layer_norm_bwd_xe.hpp:47
dtype_x_ dtype_x
Definition layer_norm_bwd_xe.hpp:42
dtype_acc * dgamma_acc_ptr
Definition layer_norm_bwd_xe.hpp:122
dtype_x * x_in_ptr
Definition layer_norm_bwd_xe.hpp:116
uint32_t matrix_m
Definition layer_norm_bwd_xe.hpp:125
dtype_weight * gamma_in_ptr
Definition layer_norm_bwd_xe.hpp:117
dtype_y * dy_in_ptr
Definition layer_norm_bwd_xe.hpp:115
uint32_t mat_ld
Definition layer_norm_bwd_xe.hpp:127
dtype_acc * rs_ptr
Definition layer_norm_bwd_xe.hpp:118
uint32_t matrix_n
Definition layer_norm_bwd_xe.hpp:126
dtype_acc * mu_ptr
Definition layer_norm_bwd_xe.hpp:119
dtype_acc * dbeta_acc_ptr
Definition layer_norm_bwd_xe.hpp:123
dtype_x * dx_out_ptr
Definition layer_norm_bwd_xe.hpp:121
Definition memory_descriptor.hpp:139
Describes the memory information.
Definition api.hpp:44
Describes the tile information of a sub-matrix.
Definition api.hpp:64
Is a struct that contains some register file.
Definition api.hpp:99
xetla_vector< dtype, tile_desc::tile_elems > reg
Definition api.hpp:102