template <typename dtype_x_, typename dtype_y_, typename dtype_weight_,
        typename dtype_acc_, typename layer_norm_attr_, bool store_for_bwd_,
        typename ln_fwd_fused_op_>
struct layer_norm_fwd_t<dtype_x_, dtype_y_, dtype_weight_, dtype_acc_,
        layer_norm_attr_, store_for_bwd_, gpu_arch::Xe, ln_fwd_fused_op_> {
    using dtype_x = dtype_x_;
    using dtype_y = dtype_y_;
    using dtype_weight = dtype_weight_;
    using dtype_acc = dtype_acc_;
    using layer_norm_attr = layer_norm_attr_;
    using ln_fwd_fused_op = ln_fwd_fused_op_;
    using ln_fused_op_arguments_t = typename ln_fwd_fused_op::arguments_t;
    static constexpr bool store_for_bwd = store_for_bwd_;

    static constexpr uint32_t wg_tile_m = layer_norm_attr::wg_tile_m;
    static constexpr uint32_t wg_tile_n = layer_norm_attr::wg_tile_n;
    static constexpr uint32_t sg_tile_m = layer_norm_attr::sg_tile_m;
    static constexpr uint32_t sg_tile_n = layer_norm_attr::sg_tile_n;
    static constexpr uint32_t wg_num_m = layer_norm_attr::wg_num_m;
    static constexpr uint32_t wg_num_n = layer_norm_attr::wg_num_n;
    static constexpr uint32_t chunk_size = layer_norm_attr::chunk_size;
    static constexpr uint32_t n_chunks = sg_tile_n / chunk_size;
    static_assert(sg_tile_n % chunk_size == 0,
            "Current impl does not support tailing mechanism");

    // each subgroup owns an sg_tile_m x sg_tile_n slice of the workgroup tile
    static constexpr uint32_t wg_size_x
            = (wg_tile_n + sg_tile_n - 1) / sg_tile_n;
    static constexpr uint32_t wg_size_y
            = (wg_tile_m + sg_tile_m - 1) / sg_tile_m;
    using work_group_t = work_group_t<wg_size_x * wg_size_y>;
    static_assert((wg_size_x <= 32) && ((wg_size_x & (wg_size_x - 1)) == 0),
            "Current only support wg_size_x <=32");
    struct get_barrier_count {
        static constexpr uint32_t count = (wg_size_x > 1) ? wg_size_y : 0;
    };
    struct get_slm_size {
        static constexpr uint32_t size = (wg_size_x > 1)
                ? wg_size_x * wg_size_y * 4 * sizeof(dtype_acc)
                : 0;
    };
    // ... (ln_fwd_tile_desc_t plus the x_in / gamma_in / beta_in / y_out tile
    // and payload aliases elided; each payload uses
    // subgroup::msg_type_v<ln_fwd_tile_desc_t, mem_space::global> to select the
    // global-memory message type)

    struct arguments_t {
        dtype_x *x_in_ptr;
        dtype_weight *gamma_ptr;
        dtype_weight *beta_ptr;
        dtype_y *y_out_ptr;
        dtype_acc *rs_ptr;
        dtype_acc *mu_ptr;
        uint32_t matrix_m;
        uint32_t matrix_n;
        uint32_t mat_ld;
    };
    template <typename T, uint32_t SZ, uint32_t N>
    struct parallel_mu_m2_t {
        static xetla_vector<T, 2> call(
                xetla_vector<T, SZ> mu_vec, xetla_vector<T, SZ> m2_vec) {
            // pairwise merge of per-lane (mu, m2) partials, halving SZ each step
            auto mu_vec_a = mu_vec.xetla_select<SZ / 2, 1>(0);
            auto mu_vec_b = mu_vec.xetla_select<SZ / 2, 1>(SZ / 2);
            auto m2_vec_a = m2_vec.xetla_select<SZ / 2, 1>(0);
            auto m2_vec_b = m2_vec.xetla_select<SZ / 2, 1>(SZ / 2);
            xetla_vector<T, SZ / 2> mu_vec_new = (mu_vec_a + mu_vec_b) / (T)2;
            xetla_vector<T, SZ / 2> m2_vec_new = m2_vec_a + m2_vec_b
                    + (mu_vec_a - mu_vec_b) * (mu_vec_a - mu_vec_b) * (T)N
                            / (T)2;
            return parallel_mu_m2_t<T, SZ / 2, N * 2>::call(
                    mu_vec_new, m2_vec_new);
        }
    };
    template <typename T, uint32_t N>
    struct parallel_mu_m2_t<T, 1, N> {
        static xetla_vector<T, 2> call(
                xetla_vector<T, 1> mu_vec, xetla_vector<T, 1> m2_vec) {
            // recursion terminates with a single merged (mu, m2) pair
            // ...
        }
    };
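    // Note: parallel_mu_m2_t implements the standard pairwise mean/M2
    // combination (Chan et al.). For two partials that each cover N elements,
    //     mu = (mu_a + mu_b) / 2
    //     M2 = M2_a + M2_b + (mu_a - mu_b)^2 * N / 2
    // so after log2(SZ) merge steps one (mu, M2) pair covers SZ * N elements,
    // and the variance of that range is M2 / (SZ * N).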
    static __XETLA_API void call(sycl::nd_item<3> &item, arguments_t *args,
            uint32_t slm_base = 0, uint32_t nbarrier_base = 0,
            ln_fused_op_arguments_t *fused_op_args = nullptr) {
        work_group_t g;
        g.init(item.get_local_linear_id());
        int sg_idx = g.get_id() % wg_size_x;
        int sg_idy = g.get_id() / wg_size_x;
        int wg_idx = item.get_group(2);
        int wg_idy = item.get_group(1);
        int start_n = wg_idx * wg_tile_n + sg_idx * sg_tile_n;
        int start_m = wg_idy * wg_tile_m + sg_idy * sg_tile_m;
        // ... (tile, payload, nbarrier and fused_op declarations elided)
        x_in_payload.init(args->x_in_ptr, args->matrix_n, args->matrix_m,
                args->mat_ld, start_n, start_m);
        if constexpr (n_chunks == 1) {
            fused_op.init(
                    fused_op_args, wg_idx, wg_idy, sg_idx, sg_idy, start_m);
            gamma_in_payload.init(args->gamma_ptr, args->matrix_n, 1,
                    args->mat_ld, start_n, 0);
            beta_in_payload.init(args->beta_ptr, args->matrix_n, 1,
                    args->mat_ld, start_n, 0);
        }
        y_out_payload.init(args->y_out_ptr, args->matrix_n, args->matrix_m,
                args->mat_ld, start_n, start_m);
        const dtype_acc sg_rn = 1.0f / sg_tile_n;
        const dtype_acc wg_rn = 1.0f / wg_tile_n;
        // slm layout: one (mu, m2) pair per subgroup, double buffered (ping-pong)
        uint32_t slm_store_base_0 = sg_idx * 2 * sizeof(dtype_acc)
                + sg_idy * wg_size_x * 2 * sizeof(dtype_acc) + slm_base;
        uint32_t slm_load_base_0
                = sg_idy * wg_size_x * 2 * sizeof(dtype_acc) + slm_base;
        uint32_t slm_store_base_1 = slm_store_base_0
                + wg_size_x * wg_size_y * 2 * sizeof(dtype_acc);
        uint32_t slm_load_base_1 = slm_load_base_0
                + wg_size_x * wg_size_y * 2 * sizeof(dtype_acc);
        uint32_t itr_count = 0;
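        // Worked example with assumed sizes (not from the source): for
        // wg_size_x = 4, wg_size_y = 2, sizeof(dtype_acc) = 4 and slm_base = 0,
        // a subgroup with sg_idy = 1 stores its (mu, m2) pair at byte offset
        // 8 * sg_idx + 32 and reads the whole 4-pair row back from offset 32;
        // the *_base_1 copies start 64 bytes higher, so successive row
        // iterations ping-pong between the two buffers.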
        for (uint32_t row = start_m; row < args->matrix_m;
                row += wg_num_m * wg_tile_m) {
            if constexpr (n_chunks > 1) {
                fused_op.init(
                        fused_op_args, wg_idx, wg_idy, sg_idx, sg_idy, row);
            }
            // ...
            if constexpr (n_chunks > 1) {
                x_in_payload.init(args->x_in_ptr, args->matrix_n,
                        args->matrix_m, args->mat_ld, start_n, row);
            }
            // first pass over the row: accumulate the per-subgroup mean
            for (uint32_t i = 0; i < n_chunks; i++) {
                // ...
                x_in_payload.update_tdesc(chunk_size);
                input = xetla_cvt<dtype_acc, dtype_x>(x_in.reg);
                // ...
                input = fused_op.pre_op(input);
                // ...
            }
            // ...
            if constexpr (n_chunks > 1) {
                fused_op.init(
                        fused_op_args, wg_idx, wg_idy, sg_idx, sg_idy, row);
                x_in_payload.init(args->x_in_ptr, args->matrix_n,
                        args->matrix_m, args->mat_ld, start_n, row);
            }
            // second pass over the row: accumulate the per-subgroup m2
            for (uint32_t i = 0; i < n_chunks; i++) {
                if constexpr (n_chunks > 1) {
                    // ...
                    x_in_payload.update_tdesc(chunk_size);
                    input = xetla_cvt<dtype_acc, dtype_x>(x_in.reg);
                    // ...
                    input = fused_op.pre_op(input);
                }
                // ...
            }
            // reduce (mu, m2) across the subgroups of a row through slm
            if constexpr (wg_size_x > 1) {
                uint32_t slm_store_base = (itr_count & 1) == 0
                        ? slm_store_base_0
                        : slm_store_base_1;
                xetla_store_local<dtype_acc, 2>(slm_store_base, mu_m2);
                xetla_fence<memory_kind::shared_local>();
                nbarrier.arrive();
                uint32_t slm_load_base = (itr_count & 1) == 0 ? slm_load_base_0
                                                              : slm_load_base_1;
                nbarrier.wait();
                xetla_vector<dtype_acc, wg_size_x * 2> mu_m2_vec
                        = xetla_load_local<dtype_acc, wg_size_x * 2>(
                                slm_load_base);
                auto mu_vec = mu_m2_vec.xetla_select<wg_size_x, 2>(0);
                auto m2_vec = mu_m2_vec.xetla_select<wg_size_x, 2>(1);
                mu_m2 = parallel_mu_m2_t<dtype_acc, wg_size_x, sg_tile_n>::call(
                        mu_vec, m2_vec);
            }
            // ... (itr_count update and per-row mu / rs computation elided)
            if constexpr (store_for_bwd) {
                // ... (write the per-row mu and rs to args->mu_ptr / args->rs_ptr)
            }
            if constexpr (chunk_size > 1) {
                gamma_in_payload.init(args->gamma_ptr, args->matrix_n, 1,
                        args->mat_ld, start_n, 0);
                beta_in_payload.init(args->beta_ptr, args->matrix_n, 1,
                        args->mat_ld, start_n, 0);
            }
            // ...
            if constexpr (n_chunks > 1) {
                fused_op.init(
                        fused_op_args, wg_idx, wg_idy, sg_idx, sg_idy, row);
                x_in_payload.init(args->x_in_ptr, args->matrix_n,
                        args->matrix_m, args->mat_ld, start_n, row);
            }
            // third pass over the row: normalize, apply gamma / beta and store y
            for (uint32_t i = 0; i < n_chunks; i++) {
                if constexpr (n_chunks > 1) {
                    // ...
                    gamma_in_payload.update_tdesc(chunk_size);
                    // ...
                    beta_in_payload.update_tdesc(chunk_size);
                }
                // ...
                x_in_payload.update_tdesc(chunk_size);
                input = xetla_cvt<dtype_acc, dtype_x>(x_in.reg);
                // ...
                input = fused_op.pre_op(input);
                // ...
                beta = xetla_cvt<dtype_acc, dtype_weight, chunk_size>(
                        beta_in.reg);
                gamma = xetla_cvt<dtype_acc, dtype_weight>(gamma_in.reg);
                output = beta + (rs * (input - mu)) * gamma;
                // ...
                output = fused_op.post_op(output);
                y_out.reg = xetla_cvt<dtype_y, dtype_acc, chunk_size>(output);
                // ... (tile_store of y_out elided)
                y_out_payload.update_tdesc(chunk_size);
            }
            // advance the payloads to the next row block handled by this workgroup
            x_in_payload.update_tdesc(
                    wg_num_m * wg_tile_m * args->mat_ld - sg_tile_n);
            y_out_payload.update_tdesc(
                    wg_num_m * wg_tile_m * args->mat_ld - sg_tile_n);
        }
    }
};
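// A minimal kernel-side usage sketch, not part of this file. The attribute and
// fused-op types (my_ln_attr_t, my_fused_op_t), the fp16/float dtype choices,
// the queue / nd_range setup and the kernel-prologue calls are assumptions for
// illustration; only the arguments_t members, call()'s signature and the
// get_barrier_count / get_slm_size helpers come from the listing above.
//
//   using ln_fwd_t = layer_norm_fwd_t<fp16, fp16, fp16, float, my_ln_attr_t,
//           /*store_for_bwd_=*/true, gpu_arch::Xe, my_fused_op_t>;
//   queue.parallel_for(nd_range, [=](sycl::nd_item<3> item) SYCL_ESIMD_KERNEL {
//       xetla_local_init<ln_fwd_t::get_slm_size::size>();
//       xetla_nbarrier_init<ln_fwd_t::get_barrier_count::count>();
//       typename ln_fwd_t::arguments_t args;
//       args.x_in_ptr = x;      args.y_out_ptr = y;
//       args.gamma_ptr = gamma; args.beta_ptr = beta;
//       args.mu_ptr = mu;       args.rs_ptr = rs;
//       args.matrix_m = m;      args.matrix_n = n;      args.mat_ld = n;
//       ln_fwd_t::call(item, &args);
//   });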