XeTLA v0.3.6
Intel® Xe Templates for Linear Algebra - API Definition Document
 
layer_norm_fwd_xe.hpp
1/*******************************************************************************
2* Copyright (c) 2022-2023 Intel Corporation
3*
4* Licensed under the Apache License, Version 2.0 (the "License");
5* you may not use this file except in compliance with the License.
6* You may obtain a copy of the License at
7*
8* http://www.apache.org/licenses/LICENSE-2.0
9*
10* Unless required by applicable law or agreed to in writing, software
11* distributed under the License is distributed on an "AS IS" BASIS,
12* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13* See the License for the specific language governing permissions and
14* limitations under the License.
15*******************************************************************************/
16
19
20#pragma once
21
26
27namespace gpu::xetla::kernel {
28
38template <typename dtype_x_, typename dtype_y_, typename dtype_weight_,
39 typename dtype_acc_, typename layer_norm_attr_, bool store_for_bwd_,
40 typename ln_fwd_fused_op_>
41struct layer_norm_fwd_t<dtype_x_, dtype_y_, dtype_weight_, dtype_acc_,
42 layer_norm_attr_, store_for_bwd_, gpu_arch::Xe, ln_fwd_fused_op_> {
43 using dtype_x = dtype_x_;
44 using dtype_y = dtype_y_;
45 using dtype_weight = dtype_weight_;
46 using dtype_acc = dtype_acc_;
47 using layer_norm_attr = layer_norm_attr_;
48 using ln_fwd_fused_op = ln_fwd_fused_op_;
49 using ln_fused_op_arguments_t = typename ln_fwd_fused_op::arguments_t;
50 static constexpr bool store_for_bwd = store_for_bwd_;
51
52 static constexpr uint32_t wg_tile_m = layer_norm_attr::wg_tile_m;
53 static constexpr uint32_t wg_tile_n = layer_norm_attr::wg_tile_n;
54 static constexpr uint32_t sg_tile_m = layer_norm_attr::sg_tile_m;
55 static constexpr uint32_t sg_tile_n = layer_norm_attr::sg_tile_n;
56 static constexpr uint32_t wg_num_m = layer_norm_attr::wg_num_m;
57 static constexpr uint32_t wg_num_n = layer_norm_attr::wg_num_n;
58 static constexpr uint32_t chunk_size = layer_norm_attr::chunk_size;
59 static constexpr uint32_t n_chunks = sg_tile_n / chunk_size;
60 static_assert(sg_tile_n % chunk_size == 0,
61 "Current impl does not support tailing mechanism");
62
63 static constexpr uint32_t wg_size_x
64 = (wg_tile_n + sg_tile_n - 1) / sg_tile_n;
65 static constexpr uint32_t wg_size_y
66 = (wg_tile_m + sg_tile_m - 1) / sg_tile_m;
67 using work_group_t = work_group_t<wg_size_x * wg_size_y>;
68 static_assert((wg_size_x <= 32) && ((wg_size_x & (wg_size_x - 1)) == 0),
69 "Current only support wg_size_x <=32");
70
73 struct get_barrier_count {
74 static constexpr uint32_t count = (wg_size_x > 1) ? wg_size_y : 0;
75 };
76
79 struct get_slm_size {
80 // 4 = (mu + m2) * double buffering
81 static constexpr uint32_t size = (wg_size_x > 1)
82 ? wg_size_x * wg_size_y * 4 * sizeof(dtype_acc)
83 : 0;
84 };
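 // With the illustrative sizes above and dtype_acc = float, this gives
 // get_barrier_count::count = wg_size_y = 8 named barriers and
 // get_slm_size::size = 4 * 8 * 4 * sizeof(float) = 512 bytes of SLM,
 // i.e. a double-buffered (mu, m2) pair per subgroup.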
85
 86 using ln_fwd_tile_desc_t = subgroup::tile_desc_t<chunk_size, 1, chunk_size,
 87 1, reg_layout::tiled>;
 88 using x_in_t = subgroup::tile_t<dtype_x, ln_fwd_tile_desc_t>;
 89 using x_in_payload_t = subgroup::mem_payload_t<dtype_x, ln_fwd_tile_desc_t,
 90 subgroup::msg_type_v<ln_fwd_tile_desc_t, mem_space::global>,
 91 mem_layout::row_major, mem_space::global, gpu_arch::Xe>;
 92
 93 using y_out_t = subgroup::tile_t<dtype_y, ln_fwd_tile_desc_t>;
 94 using y_out_payload_t = subgroup::mem_payload_t<dtype_y,
 95 ln_fwd_tile_desc_t,
 96 subgroup::msg_type_v<ln_fwd_tile_desc_t, mem_space::global>,
 97 mem_layout::row_major, mem_space::global, gpu_arch::Xe>;
 98 using gamma_in_t = subgroup::tile_t<dtype_weight, ln_fwd_tile_desc_t>;
 99 using gamma_in_payload_t = subgroup::mem_payload_t<dtype_weight,
 100 ln_fwd_tile_desc_t,
 101 subgroup::msg_type_v<ln_fwd_tile_desc_t, mem_space::global>,
 102 mem_layout::row_major, mem_space::global, gpu_arch::Xe>;
 103 using beta_in_t = subgroup::tile_t<dtype_weight, ln_fwd_tile_desc_t>;
 104 using beta_in_payload_t = subgroup::mem_payload_t<dtype_weight,
 105 ln_fwd_tile_desc_t,
 106 subgroup::msg_type_v<ln_fwd_tile_desc_t, mem_space::global>,
 107 mem_layout::row_major, mem_space::global, gpu_arch::Xe>;
 111
 114 struct arguments_t {
 115 dtype_x *x_in_ptr;
 116 dtype_y *y_out_ptr;
 117 dtype_weight *gamma_ptr;
 118 dtype_weight *beta_ptr;
 119 dtype_acc *rs_ptr;
 120 dtype_acc *mu_ptr;
 121 uint32_t matrix_m;
122 uint32_t matrix_n;
123 uint32_t mat_ld;
124 dtype_acc epsilon = 1e-5;
125 };
126
132 template <typename T, uint32_t SZ, uint32_t N>
 133 struct parallel_mu_m2_t {
 134 static xetla_vector<T, 2> call(
 135 xetla_vector<T, SZ> mu_vec, xetla_vector<T, SZ> m2_vec) {
 136 auto mu_vec_a = mu_vec.xetla_select<SZ / 2, 1>(0);
137 auto mu_vec_b = mu_vec.xetla_select<SZ / 2, 1>(SZ / 2);
138 auto m2_vec_a = m2_vec.xetla_select<SZ / 2, 1>(0);
139 auto m2_vec_b = m2_vec.xetla_select<SZ / 2, 1>(SZ / 2);
140 xetla_vector<T, SZ / 2> mu_vec_new = (mu_vec_a + mu_vec_b) / (T)2;
141 xetla_vector<T, SZ / 2> m2_vec_new = m2_vec_a + m2_vec_b
142 + (mu_vec_a - mu_vec_b) * (mu_vec_a - mu_vec_b) * (T)N
143 / (T)2;
144 return parallel_mu_m2_t<T, SZ / 2, N * 2>::call(
145 mu_vec_new, m2_vec_new);
146 }
147 };
148
153 template <typename T, uint32_t N>
154 struct parallel_mu_m2_t<T, 1, N> {
155
 161 static xetla_vector<T, 2> call(
 162 xetla_vector<T, 1> mu_vec, xetla_vector<T, 1> m2_vec) {
 163 xetla_vector<T, 2> ret;
 164 ret[0] = mu_vec[0];
165 ret[1] = m2_vec[0];
166 return ret;
167 }
168 };
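 // parallel_mu_m2_t performs a pairwise (tree) merge of per-partition means
 // and sums of squared deviations: for two partitions of N elements each,
 //   mu = (mu_a + mu_b) / 2
 //   m2 = m2_a + m2_b + (mu_a - mu_b)^2 * N / 2,
 // the equal-size special case of the standard parallel variance update.
 // Each recursion step halves the vector width SZ and doubles N until a
 // single (mu, m2) pair remains; see the standalone sketch at the end of
 // this listing for a numeric check of this identity.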
169
178 __XETLA_API static void call(sycl::nd_item<3> &item, arguments_t *args,
179 uint32_t slm_base = 0, uint32_t nbarrier_base = 0,
180 ln_fused_op_arguments_t *fused_op_args = nullptr) {
181 work_group_t g;
182 g.init(item.get_local_linear_id());
183 int sg_idx = g.get_id() % wg_size_x;
184 int sg_idy = g.get_id() / wg_size_x;
185 int wg_idx = item.get_group(2);
186 int wg_idy = item.get_group(1);
187 int start_n = wg_idx * wg_tile_n + sg_idx * sg_tile_n;
188 int start_m = wg_idy * wg_tile_m + sg_idy * sg_tile_m;
189
 190 xetla_nbarrier_t<wg_size_x, wg_size_x> nbarrier;
 191 nbarrier.init_nbarrier(
192 sg_idy + nbarrier_base, nbarrier_role::producer_consumer);
193
194 x_in_t x_in;
195 x_in_payload_t x_in_payload;
196 gamma_in_t gamma_in;
197 gamma_in_payload_t gamma_in_payload;
198 beta_in_t beta_in;
199 beta_in_payload_t beta_in_payload;
200 y_out_t y_out;
201 y_out_payload_t y_out_payload;
202 ln_fwd_fused_op fused_op;
203 x_in_payload.init(args->x_in_ptr, args->matrix_n, args->matrix_m,
204 args->mat_ld, start_n, start_m);
205 // >>>>>>>>>> fused op fwd init
206
207 if constexpr (n_chunks == 1) {
208 fused_op.init(
209 fused_op_args, wg_idx, wg_idy, sg_idx, sg_idy, start_m);
210 gamma_in_payload.init(args->gamma_ptr, args->matrix_n, 1,
211 args->mat_ld, start_n, 0);
212 beta_in_payload.init(args->beta_ptr, args->matrix_n, 1,
213 args->mat_ld, start_n, 0);
214 subgroup::tile_load(gamma_in, gamma_in_payload);
215 subgroup::tile_load(beta_in, beta_in_payload);
216 }
217 y_out_payload.init(args->y_out_ptr, args->matrix_n, args->matrix_m,
218 args->mat_ld, start_n, start_m);
219 const dtype_acc sg_rn = 1.0f / sg_tile_n;
220 const dtype_acc wg_rn = 1.0f / wg_tile_n;
221 uint32_t slm_store_base_0 = sg_idx * 2 * sizeof(dtype_acc)
222 + sg_idy * wg_size_x * 2 * sizeof(dtype_acc) + slm_base;
223 uint32_t slm_load_base_0
224 = sg_idy * wg_size_x * 2 * sizeof(dtype_acc) + slm_base;
225 uint32_t slm_store_base_1 = slm_store_base_0
226 + wg_size_x * wg_size_y * 2 * sizeof(dtype_acc);
227 uint32_t slm_load_base_1 = slm_load_base_0
228 + wg_size_x * wg_size_y * 2 * sizeof(dtype_acc);
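 // SLM layout: one (mu, m2) pair of dtype_acc per subgroup, stored row by
 // row, kept in two copies (base 0 / base 1). Stores and loads alternate
 // between the two copies on the parity of itr_count, so a subgroup that
 // has already moved on to the next row iteration cannot overwrite data a
 // slower subgroup in the same row still needs to read.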
229 uint32_t itr_count = 0;
230
231 for (uint32_t row = start_m; row < args->matrix_m;
232 row += wg_num_m * wg_tile_m) {
233 if constexpr (n_chunks > 1) {
234 fused_op.init(
235 fused_op_args, wg_idx, wg_idy, sg_idx, sg_idy, row);
236 }
 237 xetla_vector<dtype_acc, chunk_size> input;
 238 xetla_vector<dtype_acc, 2> mu_m2;
 239 mu_m2[0] = 0;
240 mu_m2[1] = 0;
241 if constexpr (n_chunks > 1) {
242 x_in_payload.init(args->x_in_ptr, args->matrix_n,
243 args->matrix_m, args->mat_ld, start_n, row);
244 }
245#pragma unroll
246 for (uint32_t i = 0; i < n_chunks; i++) {
247 subgroup::tile_load(x_in, x_in_payload);
248 x_in_payload.update_tdesc(chunk_size);
249 input = xetla_cvt<dtype_acc, dtype_x>(x_in.reg);
250 // >>>>>>>>>> fused op pre-processing
251 input = fused_op.pre_op(input);
252 // >>>>>>>>>> first do sg_level reduction
253 mu_m2[0] += xetla_reduce<dtype_acc, dtype_acc, chunk_size,
254 reduce_op::sum>(input);
255 }
256 mu_m2[0] *= sg_rn;
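 // mu_m2[0] now holds this subgroup's mean over its sg_tile_n columns.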
257 if constexpr (n_chunks > 1) {
258 fused_op.init(
259 fused_op_args, wg_idx, wg_idy, sg_idx, sg_idy, row);
260 x_in_payload.init(args->x_in_ptr, args->matrix_n,
261 args->matrix_m, args->mat_ld, start_n, row);
262 }
263#pragma unroll
264 for (uint32_t i = 0; i < n_chunks; i++) {
265 if constexpr (n_chunks > 1) {
266 subgroup::tile_load(x_in, x_in_payload);
267 x_in_payload.update_tdesc(chunk_size);
268 input = xetla_cvt<dtype_acc, dtype_x>(x_in.reg);
269 // >>>>>>>>>> fused op pre-processing
270 input = fused_op.pre_op(input);
271 }
272
 273 xetla_vector<dtype_acc, chunk_size> diff
 274 = input - dtype_acc(mu_m2[0]);
275 mu_m2[1] += xetla_reduce<dtype_acc, dtype_acc, chunk_size,
276 reduce_op::sum>(diff * diff);
277 }
278 // >>>>>>>>>> then do wg_level reduction
279 if constexpr (wg_size_x > 1) {
280 uint32_t slm_store_base = (itr_count & 1) == 0
281 ? slm_store_base_0
282 : slm_store_base_1;
283 xetla_store_local<dtype_acc, 2>(slm_store_base, mu_m2);
284 xetla_fence<memory_kind::shared_local>();
285 nbarrier.arrive();
286 uint32_t slm_load_base = (itr_count & 1) == 0 ? slm_load_base_0
287 : slm_load_base_1;
288 itr_count += 1;
289 nbarrier.wait();
290
 291 xetla_vector<dtype_acc, wg_size_x * 2> mu_m2_vec
 292 = xetla_load_local<dtype_acc, wg_size_x * 2>(
 293 slm_load_base);
 294 xetla_vector<dtype_acc, wg_size_x> mu_vec
 295 = mu_m2_vec.xetla_select<wg_size_x, 2>(0);
 296 xetla_vector<dtype_acc, wg_size_x> m2_vec
 297 = mu_m2_vec.xetla_select<wg_size_x, 2>(1);
298 mu_m2 = parallel_mu_m2_t<dtype_acc, wg_size_x, sg_tile_n>::call(
299 mu_vec, m2_vec);
300 }
301 dtype_acc mu = mu_m2[0];
302 dtype_acc m2 = mu_m2[1];
303 dtype_acc rs = xetla_rsqrt(m2 * wg_rn + args->epsilon);
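 // rs is the reciprocal standard deviation over the full wg_tile_n row:
 // 1 / sqrt(m2 / wg_tile_n + epsilon); it is applied below as
 // y = beta + rs * (x - mu) * gamma.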
304
305 if constexpr (store_for_bwd) {
306 if (sg_idx == 0) {
 307 xetla_store_global<dtype_acc, 1, data_size::default_size,
 308 cache_hint::write_back, cache_hint::write_back>(
 309 args->mu_ptr, row * sizeof(dtype_acc),
 310 xetla_vector<dtype_acc, 1>(mu));
 311 xetla_store_global<dtype_acc, 1, data_size::default_size,
 312 cache_hint::write_back, cache_hint::write_back>(
 313 args->rs_ptr, row * sizeof(dtype_acc),
 314 xetla_vector<dtype_acc, 1>(rs));
315 }
316 }
317 // to generate mixed instruction
318 if constexpr (chunk_size > 1) {
319 gamma_in_payload.init(args->gamma_ptr, args->matrix_n, 1,
320 args->mat_ld, start_n, 0);
321 beta_in_payload.init(args->beta_ptr, args->matrix_n, 1,
322 args->mat_ld, start_n, 0);
323 }
 324
 325 xetla_vector<dtype_acc, chunk_size> output;
 326
327 if constexpr (n_chunks > 1) {
328 fused_op.init(
329 fused_op_args, wg_idx, wg_idy, sg_idx, sg_idy, row);
330 x_in_payload.init(args->x_in_ptr, args->matrix_n,
331 args->matrix_m, args->mat_ld, start_n, row);
332 }
333#pragma unroll
334 for (uint32_t i = 0; i < n_chunks; i++) {
335 if constexpr (n_chunks > 1) {
336 subgroup::tile_load(gamma_in, gamma_in_payload);
337 gamma_in_payload.update_tdesc(chunk_size);
338
339 subgroup::tile_load(beta_in, beta_in_payload);
340 beta_in_payload.update_tdesc(chunk_size);
341
342 subgroup::tile_load(x_in, x_in_payload);
343 x_in_payload.update_tdesc(chunk_size);
344 input = xetla_cvt<dtype_acc, dtype_x>(x_in.reg);
345 // >>>>>>>>>> fused op pre-processing
346 input = fused_op.pre_op(input);
347 }
 348 xetla_vector<dtype_acc, chunk_size> beta
 349 = xetla_cvt<dtype_acc, dtype_weight, chunk_size>(
 350 beta_in.reg);
 351 xetla_vector<dtype_acc, chunk_size> gamma
 352 = xetla_cvt<dtype_acc, dtype_weight>(gamma_in.reg);
353
354 output = beta + (rs * (input - mu)) * gamma;
355 // >>>>>>>>>> fused op post-processing
356 output = fused_op.post_op(output);
357 y_out.reg = xetla_cvt<dtype_y, dtype_acc, chunk_size>(output);
 358 subgroup::tile_store<cache_hint::uncached,
 359 cache_hint::write_back>(y_out, y_out_payload);
360 y_out_payload.update_tdesc(chunk_size);
361 }
362 x_in_payload.update_tdesc(
363 wg_num_m * wg_tile_m * args->mat_ld - sg_tile_n);
364 y_out_payload.update_tdesc(
365 wg_num_m * wg_tile_m * args->mat_ld - sg_tile_n);
366 }
367 }
368};
369
370} // namespace gpu::xetla::kernel
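The workgroup-level reduction above leans on the identity implemented by parallel_mu_m2_t for combining the mean and sum of squared deviations (m2) of two equal-sized partitions. The sketch below is not part of XeTLA; it is a minimal standalone C++ check of that identity under the assumption of equal partition sizes, and the helper names merge_mu_m2 and direct_mu_m2 are illustrative only.

    #include <cassert>
    #include <cmath>
    #include <cstddef>
    #include <cstdio>
    #include <vector>

    // Merge mean/M2 of two equal-sized partitions of n elements each; this is
    // the identity applied by parallel_mu_m2_t at every recursion step.
    static void merge_mu_m2(double mu_a, double m2_a, double mu_b, double m2_b,
            double n, double &mu, double &m2) {
        mu = (mu_a + mu_b) / 2.0;
        m2 = m2_a + m2_b + (mu_a - mu_b) * (mu_a - mu_b) * n / 2.0;
    }

    // Reference: direct mean and sum of squared deviations over x[lo, hi).
    static void direct_mu_m2(const std::vector<double> &x, std::size_t lo,
            std::size_t hi, double &mu, double &m2) {
        mu = 0.0;
        for (std::size_t i = lo; i < hi; ++i) { mu += x[i]; }
        mu /= double(hi - lo);
        m2 = 0.0;
        for (std::size_t i = lo; i < hi; ++i) { m2 += (x[i] - mu) * (x[i] - mu); }
    }

    int main() {
        std::vector<double> x(16);
        for (std::size_t i = 0; i < x.size(); ++i) {
            x[i] = std::sin(double(i)) * 3.0 + 1.0;
        }
        double mu_a, m2_a, mu_b, m2_b;
        direct_mu_m2(x, 0, 8, mu_a, m2_a);   // left half
        direct_mu_m2(x, 8, 16, mu_b, m2_b);  // right half

        double mu_merged, m2_merged, mu_ref, m2_ref;
        merge_mu_m2(mu_a, m2_a, mu_b, m2_b, 8.0, mu_merged, m2_merged);
        direct_mu_m2(x, 0, 16, mu_ref, m2_ref);

        std::printf("merged: mu=%.12f m2=%.12f\n", mu_merged, m2_merged);
        std::printf("direct: mu=%.12f m2=%.12f\n", mu_ref, m2_ref);
        assert(std::fabs(mu_merged - mu_ref) < 1e-9);
        assert(std::fabs(m2_merged - m2_ref) < 1e-9);
        return 0;
    }

Compiled with any C++11 compiler, both assertions pass; this is why the kernel can reduce per-chunk and per-subgroup statistics in a tree without revisiting the input.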