XeTLA v0.3.6
Intel® Xe Templates for Linear Algebra - API Definition Document
 
layer_norm_bwd_xe.hpp
1/*******************************************************************************
2* Copyright (c) 2022-2023 Intel Corporation
3*
4* Licensed under the Apache License, Version 2.0 (the "License");
5* you may not use this file except in compliance with the License.
6* You may obtain a copy of the License at
7*
8* http://www.apache.org/licenses/LICENSE-2.0
9*
10* Unless required by applicable law or agreed to in writing, software
11* distributed under the License is distributed on an "AS IS" BASIS,
12* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13* See the License for the specific language governing permissions and
14* limitations under the License.
15*******************************************************************************/
16
19
20#pragma once
21
26
27namespace gpu::xetla::kernel {
28
37template <typename dtype_x_, typename dtype_y_, typename dtype_weight_,
38 typename dtype_acc_, typename layer_norm_attr_,
39 typename ln_bwd_fused_op_>
40struct layer_norm_bwd_t<dtype_x_, dtype_y_, dtype_weight_, dtype_acc_,
41 layer_norm_attr_, gpu_arch::Xe, ln_bwd_fused_op_> {
42 using dtype_x = dtype_x_;
43 using dtype_y = dtype_y_;
44 using dtype_weight = dtype_weight_;
45 using dtype_acc = dtype_acc_;
46 using layer_norm_attr = layer_norm_attr_;
47 using ln_bwd_fused_op = ln_bwd_fused_op_;
48 using ln_fused_op_arguments_t = typename ln_bwd_fused_op::arguments_t;
49 static constexpr uint32_t wg_tile_m = layer_norm_attr::wg_tile_m;
50 static constexpr uint32_t wg_tile_n = layer_norm_attr::wg_tile_n;
51 static constexpr uint32_t sg_tile_m = layer_norm_attr::sg_tile_m;
52 static constexpr uint32_t sg_tile_n = layer_norm_attr::sg_tile_n;
53 static constexpr uint32_t wg_num_m = layer_norm_attr::wg_num_m;
54 static constexpr uint32_t wg_num_n = layer_norm_attr::wg_num_n;
55
56 static constexpr uint32_t wg_size_x
57 = (wg_tile_n + sg_tile_n - 1) / sg_tile_n;
58 static constexpr uint32_t wg_size_y
59 = (wg_tile_m + sg_tile_m - 1) / sg_tile_m;
61 static_assert((wg_size_x <= 32) && ((wg_size_x & (wg_size_x - 1)) == 0),
62 "Current only support wg_size_x <=32");
63 static constexpr uint32_t count_col_reduce
64 = (wg_size_x > 1) ? wg_size_y : 0;
65 static constexpr uint32_t count_row_reduce
66 = (wg_size_y > 1) ? wg_size_x : 0;
67 struct get_barrier_count {
68 static constexpr uint32_t count = count_col_reduce + count_row_reduce;
69 };
70 // 4 = (grad0 + grad1) * double buffering
71 static constexpr uint32_t size_col_reduce = (wg_size_x > 1)
72 ? wg_size_x * wg_size_y * 4 * sizeof(dtype_acc)
73 : 0;
74 // wg_size_y rows, each holding wg_size_x * sg_tile_n accumulators
75 static constexpr uint32_t size_row_reduce = (wg_size_y > 1)
76 ? wg_size_y * wg_size_x * sg_tile_n * sizeof(dtype_acc)
77 : 0;
78 struct get_slm_size {
79 static constexpr uint32_t size = size_col_reduce + size_row_reduce;
80 };
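// Worked example (hypothetical tile configuration, not taken from this file):
// with wg_tile_m = 8, wg_tile_n = 1024, sg_tile_m = 2, sg_tile_n = 256 and
// dtype_acc = float, the derived sizes are
//   wg_size_x = ceil(1024 / 256) = 4,  wg_size_y = ceil(8 / 2) = 4,
//   get_barrier_count::count = 4 + 4 = 8 named barriers,
//   size_col_reduce = 4 * 4 * 4 * sizeof(float) = 256 bytes,
//   size_row_reduce = 4 * 4 * 256 * sizeof(float) = 16384 bytes,
//   get_slm_size::size = 16640 bytes of shared local memory per workgroup.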
81
82 using ln_bwd_tile_desc_t = subgroup::tile_desc_t<sg_tile_n, 1, sg_tile_n, 1,
88
92 subgroup::msg_type_v<ln_bwd_tile_desc_t, mem_space::global>,
97 subgroup::msg_type_v<ln_bwd_tile_desc_t, mem_space::global>,
102 subgroup::msg_type_v<ln_bwd_tile_desc_t, mem_space::global>,
107
110 wg_size_x, wg_size_y, 32, gpu_arch::Xe>;
111
114 struct arguments_t {
115 dtype_y *dy_in_ptr;
116 dtype_x *x_in_ptr;
117 dtype_weight *gamma_in_ptr;
118 dtype_x *dx_out_ptr;
119 dtype_acc *dgamma_acc_ptr;
120 dtype_acc *dbeta_acc_ptr;
121 dtype_acc *mu_ptr;
122 dtype_acc *rs_ptr;
124
125 uint32_t matrix_m;
126 uint32_t matrix_n;
127 uint32_t mat_ld;
128 dtype_acc epsilon = 1e-5;
129 };
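// Note: mat_ld is the row pitch (in elements) shared by the x, dy and dx payloads
// below; mu_ptr and rs_ptr are expected to hold the per-row mean and reciprocal
// standard deviation saved by the forward pass, which is why epsilon is not
// re-applied anywhere in this backward kernel.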
130
131private:
141 template <typename T, uint32_t SZ, uint32_t N, reduce_op Op,
142 uint32_t wg_size_x, uint32_t wg_size_y,
143 gpu_arch arch_ = gpu_arch::Xe>
144 struct ln_group_all_reduce_t {
145 uint32_t itr_count;
146 uint32_t slm_base_0;
147 uint32_t slm_base_1;
148
150 group_reduce;
151
158 inline ln_group_all_reduce_t(uint32_t sg_idx = 0, uint32_t sg_idy = 0,
159 uint32_t slm_base = 0, uint32_t nbarrier_base = 0) {
160 slm_base_0 = slm_base + sg_idy * wg_size_x * N * sizeof(T);
161 slm_base_1 = slm_base_0 + wg_size_x * wg_size_y * N * sizeof(T);
162 itr_count = 0;
163 group_reduce.init(sg_idx, sg_idy + nbarrier_base, slm_base_0);
164 }
165
170 inline KERNEL_FUNC xetla_vector<T, N> operator()(
171 xetla_vector<T, SZ * N> buffer) {
172 uint32_t slm_base = (itr_count & 1) ? slm_base_1 : slm_base_0;
173 group_reduce.set_slm_base(slm_base);
174 xetla_vector<T, N> ret = group_reduce(buffer);
175 itr_count += 1;
176 return ret;
177 }
178 };
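// Note: operator() above alternates between slm_base_0 and slm_base_1 on the
// parity of itr_count, so consecutive reductions issued from the row loop below
// land in different SLM buffers and do not have to wait for the previous buffer
// to be fully drained.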
179
187 template <uint32_t SIMD = 64 / sizeof(dtype_acc)>
188 __XETLA_API static xetla_vector<dtype_acc, sg_tile_n> get_x_temp(
189 xetla_vector<dtype_x, sg_tile_n> x, dtype_acc rs, dtype_acc mu) {
190 xetla_vector<dtype_acc, sg_tile_n> x_temp;
191 xetla_vector<dtype_acc, sg_tile_n> x_acc
192 = xetla_cvt<dtype_acc, dtype_x>(x);
194#pragma unroll
195 for (uint32_t i = 0; i < sg_tile_n / SIMD; i++) {
196 x_temp.xetla_select<SIMD, 1>(i * SIMD)
197 = rs * (x_acc.xetla_select<SIMD, 1>(i * SIMD) - mu);
198 }
199 if constexpr ((sg_tile_n % SIMD) != 0) {
200 constexpr uint32_t start = sg_tile_n / SIMD * SIMD;
201 constexpr uint32_t SIMD_tail = sg_tile_n % SIMD;
202 x_temp.xetla_select<SIMD_tail, 1>(start)
203 = rs * (x_acc.xetla_select<SIMD_tail, 1>(start) - mu);
204 }
205 return x_temp;
206 }
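// SIMD defaults to 64 / sizeof(dtype_acc) lanes (16 for float accumulators); the
// unrolled loop handles full SIMD-wide chunks of the sg_tile_n row and the
// constexpr tail branch covers the remainder when sg_tile_n is not a multiple of
// SIMD. get_dy_temp below follows the same pattern for gamma * dy.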
207
214 template <uint32_t SIMD = 64 / sizeof(dtype_acc)>
215 __XETLA_API static xetla_vector<dtype_acc, sg_tile_n> get_dy_temp(
216 xetla_vector<dtype_weight, sg_tile_n> gamma,
217 xetla_vector<dtype_acc, sg_tile_n> dy) {
218 xetla_vector<dtype_acc, sg_tile_n> dy_temp;
219 xetla_vector<dtype_acc, sg_tile_n> gamma_acc
220 = xetla_cvt<dtype_acc, dtype_weight>(gamma);
222#pragma unroll
223 for (uint32_t i = 0; i < sg_tile_n / SIMD; i++) {
224 dy_temp.xetla_select<SIMD, 1>(i * SIMD)
225 = gamma_acc.xetla_select<SIMD, 1>(i * SIMD)
226 * dy.xetla_select<SIMD, 1>(i * SIMD);
227 }
228 if constexpr ((sg_tile_n % SIMD) != 0) {
229 constexpr uint32_t start = sg_tile_n / SIMD * SIMD;
230 constexpr uint32_t SIMD_tail = sg_tile_n % SIMD;
231 dy_temp.xetla_select<SIMD_tail, 1>(start)
232 = gamma_acc.xetla_select<SIMD_tail, 1>(start)
233 * dy.xetla_select<SIMD_tail, 1>(start);
234 }
235 return dy_temp;
236 }
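// With x_hat = (x - mu) * rs from get_x_temp and g = gamma * dy from get_dy_temp,
// the row loop in call() below realizes the standard layer-norm backward pass:
//   dgamma += dy * x_hat,  dbeta += dy,
//   dx = rs * (g - mean_n(g) - x_hat * mean_n(g * x_hat)),
// where mean_n() averages over the wg_tile_n normalized columns; grad_0 and grad_1
// are exactly those two means, produced by the workgroup column reduction and
// scaled by wg_rn = 1 / wg_tile_n.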
237 using wg_col_reduce_t = ln_group_all_reduce_t<dtype_acc, sg_tile_n, 2,
238 reduce_op::sum, wg_size_x, wg_size_y, gpu_arch::Xe>;
239
240public:
241 __XETLA_API static void call(sycl::nd_item<3> &item, arguments_t *args,
242 uint32_t slm_base = 0, uint32_t nbarrier_base = 0,
243 ln_fused_op_arguments_t *fused_op_args = nullptr) {
244 work_group_t g;
245 g.init(item.get_local_linear_id());
246 uint32_t sg_idx = g.get_id() % wg_size_x;
247 uint32_t sg_idy = g.get_id() / wg_size_x;
248 uint32_t wg_idx = item.get_group(2);
249 uint32_t wg_idy = item.get_group(1);
250 int start_n = wg_idx * wg_tile_n + sg_idx * sg_tile_n;
251 int start_m = wg_idy * wg_tile_m + sg_idy * sg_tile_m;
252
253 x_in_t x_in;
254 x_in_payload_t x_in_payload;
255 dy_in_t dy_in;
256 dy_in_payload_t dy_in_payload;
257 gamma_in_t gamma_in;
258 gamma_in_payload_t gamma_in_payload;
259 dx_out_t dx_out;
260 dx_out_payload_t dx_out_payload;
261 ln_bwd_fused_op fused_op;
262 x_in_payload.init(args->x_in_ptr, args->matrix_n, args->matrix_m,
263 args->mat_ld, start_n, start_m);
264 dy_in_payload.init(args->dy_in_ptr, args->matrix_n, args->matrix_m,
265 args->mat_ld, start_n, start_m);
266 gamma_in_payload.init(args->gamma_in_ptr, args->matrix_n, 1,
267 args->mat_ld, start_n, 0);
268 dx_out_payload.init(args->dx_out_ptr, args->matrix_n, args->matrix_m,
269 args->mat_ld, start_n, start_m);
270 fused_op.init(fused_op_args, wg_idx, wg_idy, sg_idx, sg_idy);
271 subgroup::tile_load(gamma_in, gamma_in_payload);
272
273 const dtype_acc wg_rn = 1.0f / wg_tile_n;
274
275 wg_col_reduce_t wg_col_reduce(sg_idx, sg_idy, slm_base, nbarrier_base);
276
277 xetla_vector<dtype_acc, sg_tile_n> dgamma = 0;
278 xetla_vector<dtype_acc, sg_tile_n> dbeta = 0;
279
280 for (uint32_t row = start_m; row < args->matrix_m;
281 row += wg_num_m * wg_tile_m) {
282 subgroup::tile_load(dy_in, dy_in_payload);
283 subgroup::tile_load(x_in, x_in_payload);
287 args->mu_ptr, row * sizeof(dtype_acc));
291 args->rs_ptr, row * sizeof(dtype_acc));
292 dtype_acc mu = mu_v[0];
293 dtype_acc rs = rs_v[0];
294 dy_in_payload.update_tdesc(wg_num_m * wg_tile_m * args->mat_ld);
295 x_in_payload.update_tdesc(wg_num_m * wg_tile_m * args->mat_ld);
296
297 xetla_vector<dtype_acc, sg_tile_n> dy
298 = xetla_cvt<dtype_acc, dtype_y>(dy_in.reg);
299 dy = fused_op.pre_op(dy);
300 xetla_vector<dtype_acc, sg_tile_n> x_temp
301 = get_x_temp(x_in.reg, rs, mu);
302 xetla_vector<dtype_acc, sg_tile_n> dy_temp
303 = get_dy_temp(gamma_in.reg, dy);
304 dgamma += dy * x_temp;
305 dbeta += dy;
306 xetla_vector<dtype_acc, 2 * sg_tile_n> buffer;
307 auto buffer_2d = buffer.xetla_format<dtype_acc, 2, sg_tile_n>();
308 buffer_2d.row(0) = dy_temp;
309 buffer_2d.row(1) = x_temp * dy_temp;
310 xetla_vector<dtype_acc, 2> grad_0_1 = wg_col_reduce(buffer);
311 dtype_acc grad_0 = grad_0_1[0] * wg_rn;
312 dtype_acc grad_1 = grad_0_1[1] * wg_rn;
313 xetla_vector<dtype_acc, sg_tile_n> dx
314 = rs * (dy_temp - (grad_1 * x_temp + grad_0));
315 dx = fused_op.post_op(dx);
316 dx_out.reg = xetla_cvt<dtype_x, dtype_acc>(dx);
317 subgroup::tile_store<cache_hint::uncached>(dx_out, dx_out_payload);
318 dx_out_payload.update_tdesc(wg_num_m * wg_tile_m * args->mat_ld);
319 }
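// Each workgroup strides through the rows with step wg_num_m * wg_tile_m, so after
// the loop dgamma and dbeta hold this workgroup's partial per-column sums. The row
// reduction below writes one partial result per wg_idy into dgamma_acc_ptr and
// dbeta_acc_ptr (a wg_num_m x matrix_n layout, judging from the store arguments),
// leaving the final reduction over wg_num_m to a follow-up kernel.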
320
321 ln_group_row_reduce_store_t ln_group_row_reduce;
322 uint32_t slm_row_reduce_base = slm_base + size_col_reduce;
323 uint32_t nbarrier_row_reduce_base = nbarrier_base + count_col_reduce;
324 ln_group_row_reduce.init(
325 sg_idx, sg_idy, slm_row_reduce_base, nbarrier_row_reduce_base);
327 ln_group_row_reduce(args->dgamma_acc_ptr, args->matrix_n, wg_num_m,
328 args->matrix_n, start_n, wg_idy, dgamma);
329 ln_group_row_reduce(args->dbeta_acc_ptr, args->matrix_n, wg_num_m,
330 args->matrix_n, start_n, wg_idy, dbeta);
331 fused_op.final_op(ln_group_row_reduce);
332 }
333};
334
335} // namespace gpu::xetla::kernel
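A minimal host-side launch sketch, for orientation only. The attribute struct and the no-op fused op below are hand-written stand-ins that merely provide the members this header reads (the library's layer_norm_attr_t and ln_bwd_fused_op_t play these roles); the umbrella xetla.hpp include, the KERNEL_MAIN launch macro and the xetla_local_init / xetla_nbarrier_init helpers are assumed from other XeTLA samples; the bf16/float buffer setup is illustrative and assumes matrix_n == wg_tile_n so that the workgroup column reduction spans the whole normalized row.

  #include <sycl/sycl.hpp>
  #include "xetla.hpp"

  using namespace gpu::xetla;

  // Stand-in attribute type: layer_norm_bwd_t only reads these six constants.
  struct ln_attr {
      static constexpr uint32_t wg_tile_m = 8;
      static constexpr uint32_t wg_tile_n = 1024;
      static constexpr uint32_t sg_tile_m = 2;
      static constexpr uint32_t sg_tile_n = 256;
      static constexpr uint32_t wg_num_m = 64;
      static constexpr uint32_t wg_num_n = 1;
  };

  // Stand-in fused op mirroring only the interface call() actually uses.
  struct noop_fused_op {
      struct arguments_t {};
      template <typename... Ts>
      inline void init(Ts &&...) {}
      template <typename T>
      inline T pre_op(T v) { return v; }
      template <typename T>
      inline T post_op(T v) { return v; }
      template <typename T>
      inline void final_op(T &) {}
  };

  using ln_bwd = kernel::layer_norm_bwd_t<bf16, bf16, bf16, float, ln_attr,
          gpu_arch::Xe, noop_fused_op>;

  void launch_ln_bwd(sycl::queue &q, bf16 *dy, bf16 *x, bf16 *gamma, bf16 *dx,
          float *mu, float *rs, float *dgamma_acc, float *dbeta_acc,
          uint32_t matrix_m, uint32_t matrix_n) {
      // One subgroup per sg_tile: local range mirrors wg_size_y x wg_size_x and
      // the group range mirrors wg_num_m x wg_num_n (dim 1 = rows, dim 2 =
      // columns, matching item.get_group(1/2) inside call()).
      sycl::range<3> local {1, ln_attr::wg_tile_m / ln_attr::sg_tile_m,
              ln_attr::wg_tile_n / ln_attr::sg_tile_n};
      sycl::range<3> global {1, ln_attr::wg_num_m * local[1],
              ln_attr::wg_num_n * local[2]};
      q.submit([&](sycl::handler &cgh) {
          cgh.parallel_for(sycl::nd_range<3> {global, local},
                  [=](sycl::nd_item<3> item) KERNEL_MAIN {
                      // SLM and named-barrier setup, sized by the kernel's own
                      // queries (helper names assumed from other XeTLA samples).
                      xetla_local_init<ln_bwd::get_slm_size::size>();
                      xetla_nbarrier_init<ln_bwd::get_barrier_count::count>();
                      typename ln_bwd::arguments_t args;
                      args.matrix_m = matrix_m;
                      args.matrix_n = matrix_n; // assumes matrix_n == wg_tile_n
                      args.mat_ld = matrix_n;
                      args.dy_in_ptr = dy;
                      args.x_in_ptr = x;
                      args.gamma_in_ptr = gamma;
                      args.dx_out_ptr = dx;
                      args.mu_ptr = mu;
                      args.rs_ptr = rs;
                      args.dgamma_acc_ptr = dgamma_acc;
                      args.dbeta_acc_ptr = dbeta_acc;
                      ln_bwd::call(item, &args);
                  });
      });
  }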