XeTLA v0.3.6
Intel® Xe Templates for Linear Algebra - API Definition Document
 
layer_norm_fused_op_bwd_xe.hpp
/*******************************************************************************
* Copyright (c) 2022-2023 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/

#pragma once

#include "layer_norm_fused_op_api.hpp"

namespace gpu::xetla::group {

/// @brief Arguments for the layer_norm bwd fused op.
///
/// @tparam dtype_in Is the data type of the fused op input.
/// @tparam dtype_out Is the data type of the fused op output.
/// @tparam dtype_acc Is the data type used for accumulation.
template <typename dtype_in, typename dtype_out, typename dtype_acc>
struct ln_bwd_fused_op_arguments_t {
    dtype_acc *dbias_acc_ptr;
    dtype_out *dx_resAdd_ptr;
    dtype_in *gradAdd_ptr;
    uint8_t *mask_ptr;
    uint32_t matrix_m;
    uint32_t matrix_n;
    uint32_t mat_ld;
    uint32_t mask_ld;
    float dropout_prob;
    // dropout_scale_inv = 1 / (1 - dropout_prob)
    float dropout_scale_inv;
};
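
// A host-side sketch of how these arguments are typically filled in before
// launching a layer_norm bwd kernel with a fused op. Buffer names and sizes
// are illustrative assumptions, not part of this header:
//
// @code
// ln_bwd_fused_op_arguments_t<bf16, bf16, float> args;
// args.dbias_acc_ptr = dbias_acc;  // [wg_num_m x matrix_n] partial-sum buffer
// args.dx_resAdd_ptr = dx_res_add; // gradient of the residual-add branch
// args.gradAdd_ptr = grad_add;     // extra gradient stream to be summed in
// args.mask_ptr = dropout_mask;    // one uint8_t per element
// args.matrix_m = m;
// args.matrix_n = n;
// args.mat_ld = n;                 // row-major leading dimensions
// args.mask_ld = n;
// args.dropout_prob = 0.1f;
// args.dropout_scale_inv = 1.0f / (1.0f - args.dropout_prob);
// @endcode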

/// @brief Is the layer_norm bwd fused op functor, specialized for Xe.
/// This generic version performs no fusion: every hook is an identity or
/// no-op, so the layer_norm bwd kernel runs unmodified.
template <ln_bwd_fused_kind ln_fused_op_kind_, typename dtype_in_,
        typename dtype_out_, typename dtype_acc_, typename layer_norm_attr_>
struct ln_bwd_fused_op_t<ln_fused_op_kind_, dtype_in_, dtype_out_, dtype_acc_,
        layer_norm_attr_, gpu_arch::Xe> {
    static constexpr ln_bwd_fused_kind fused_op_kind = ln_fused_op_kind_;
    using dtype_acc = dtype_acc_;
    using dtype_in = dtype_in_;
    using dtype_out = dtype_out_;
    using arguments_t
            = ln_bwd_fused_op_arguments_t<dtype_in, dtype_out, dtype_acc>;
    static constexpr uint32_t wg_tile_m = layer_norm_attr_::wg_tile_m;
    static constexpr uint32_t wg_tile_n = layer_norm_attr_::wg_tile_n;
    static constexpr uint32_t sg_tile_m = layer_norm_attr_::sg_tile_m;
    static constexpr uint32_t sg_tile_n = layer_norm_attr_::sg_tile_n;
    static constexpr uint32_t wg_num_m = layer_norm_attr_::wg_num_m;
    static constexpr uint32_t wg_num_n = layer_norm_attr_::wg_num_n;

    /// @brief Initializes the fused op; nothing to do in the no-fusion case.
    __XETLA_API void init([[maybe_unused]] arguments_t *args,
            [[maybe_unused]] uint32_t wg_idx, [[maybe_unused]] uint32_t wg_idy,
            [[maybe_unused]] uint32_t sg_idx,
            [[maybe_unused]] uint32_t sg_idy) {}

    /// @brief Hook applied before the layer_norm bwd core; identity here.
    __XETLA_API xetla_vector<dtype_acc, sg_tile_n> pre_op(
            xetla_vector<dtype_acc, sg_tile_n> input) {
        return input;
    }

    /// @brief Hook applied after the layer_norm bwd core; identity here.
    __XETLA_API xetla_vector<dtype_acc, sg_tile_n> post_op(
            xetla_vector<dtype_acc, sg_tile_n> input) {
        return input;
    }

    /// @brief Hook for the group row-reduction stage; no-op here.
    template <typename reduce_t>
    __XETLA_API void final_op([[maybe_unused]] reduce_t &ln_group_row_reduce) {}
};
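
// A minimal sketch of the call sequence the enclosing layer_norm bwd kernel
// is expected to drive. The loop structure and the ln_bwd_core helper are
// assumptions for illustration; only the hook order comes from this header:
//
// @code
// fused_op_t fused_op;
// fused_op.init(&args, wg_idx, wg_idy, sg_idx, sg_idy);
// for (uint32_t row = start_m; row < matrix_m; row += wg_num_m * wg_tile_m) {
//     auto dy = fused_op.pre_op(dy_in);  // adjust the incoming gradient
//     auto dx = ln_bwd_core(dy);         // main layer_norm bwd math (assumed)
//     dx = fused_op.post_op(dx);         // emit fused outputs
// }
// fused_op.final_op(ln_group_row_reduce);  // cross-workgroup reductions
// @endcode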

/// @brief Is the layer_norm bwd fused op functor, specialized for
/// bias_dropout_resAdd_ln fusion on Xe: it stores the residual-add gradient,
/// applies dropout, and accumulates the bias gradient.
template <typename dtype_in_, typename dtype_out_, typename dtype_acc_,
        typename layer_norm_attr_>
struct ln_bwd_fused_op_t<ln_bwd_fused_kind::bias_dropout_resAdd_ln, dtype_in_,
        dtype_out_, dtype_acc_, layer_norm_attr_, gpu_arch::Xe> {
    static constexpr ln_bwd_fused_kind fused_op_kind
            = ln_bwd_fused_kind::bias_dropout_resAdd_ln;
    using dtype_acc = dtype_acc_;
    using dtype_in = dtype_in_;
    using dtype_out = dtype_out_;
    using dtype_mask = uint8_t;
    using arguments_t
            = ln_bwd_fused_op_arguments_t<dtype_in, dtype_out, dtype_acc>;
    static constexpr uint32_t wg_tile_m = layer_norm_attr_::wg_tile_m;
    static constexpr uint32_t wg_tile_n = layer_norm_attr_::wg_tile_n;
    static constexpr uint32_t sg_tile_m = layer_norm_attr_::sg_tile_m;
    static constexpr uint32_t sg_tile_n = layer_norm_attr_::sg_tile_n;
    static constexpr uint32_t wg_num_m = layer_norm_attr_::wg_num_m;
    static constexpr uint32_t wg_num_n = layer_norm_attr_::wg_num_n;

    static constexpr uint32_t wg_size_x
            = (wg_tile_n + sg_tile_n - 1) / sg_tile_n;
    static constexpr uint32_t wg_size_y
            = (wg_tile_m + sg_tile_m - 1) / sg_tile_m;

    static_assert((sg_tile_n % (sizeof(uint32_t) / sizeof(dtype_mask)) == 0),
            "sg_tile_n needs to be DW aligned");
    using ln_bwd_tile_desc_t = subgroup::tile_desc_t<sg_tile_n, 1, sg_tile_n, 1,
            reg_layout::tiled>;
    using dx_resAdd_out_t = subgroup::tile_t<dtype_out, ln_bwd_tile_desc_t>;
    using dx_resAdd_out_payload_t = subgroup::mem_payload_t<
            mem_desc_t<dtype_out, mem_layout::row_major, mem_space::global>,
            ln_bwd_tile_desc_t, msg_type::block_2d, gpu_arch::Xe>;
    using mask_in_t = subgroup::tile_t<dtype_mask, ln_bwd_tile_desc_t>;
    using mask_in_payload_t = subgroup::mem_payload_t<
            mem_desc_t<dtype_mask, mem_layout::row_major, mem_space::global>,
            ln_bwd_tile_desc_t, msg_type::block_2d, gpu_arch::Xe>;

    dx_resAdd_out_t dx_resAdd_out;
    dx_resAdd_out_payload_t dx_resAdd_out_payload;
    mask_in_t mask_in;
    mask_in_payload_t mask_in_payload;
    dtype_acc *dbias_acc_ptr;
    float dropout_prob;
    float dropout_scale_inv;
    uint32_t mat_ld;
    uint32_t mask_ld;
    uint32_t matrix_n;
    uint32_t matrix_m;
    int32_t dbias_n;
    int32_t dbias_m;
    xetla_vector<dtype_acc, sg_tile_n> dbias;

    /// @brief Binds pointers and strides from the arguments and positions the
    /// output/mask payloads at this subgroup's starting tile.
    __XETLA_API void init(arguments_t *args, uint32_t wg_idx, uint32_t wg_idy,
            uint32_t sg_idx, uint32_t sg_idy) {
        int start_n = wg_idx * wg_tile_n + sg_idx * sg_tile_n;
        int start_m = wg_idy * wg_tile_m + sg_idy * sg_tile_m;
        dbias = 0;
        mat_ld = args->mat_ld;
        mask_ld = args->mask_ld;
        matrix_n = args->matrix_n;
        matrix_m = args->matrix_m;
        dx_resAdd_out_payload.init(args->dx_resAdd_ptr, matrix_n, matrix_m,
                mat_ld, start_n, start_m);
        mask_in_payload.init(
                args->mask_ptr, matrix_n, matrix_m, mask_ld, start_n, start_m);
        dbias_acc_ptr = args->dbias_acc_ptr;
        dropout_scale_inv = args->dropout_scale_inv;
        dropout_prob = args->dropout_prob;
        dbias_n = start_n;
        dbias_m = wg_idy;
    }

    /// @brief Hook before the layer_norm bwd core; identity for this fusion.
    __XETLA_API xetla_vector<dtype_acc, sg_tile_n> pre_op(
            xetla_vector<dtype_acc, sg_tile_n> input) {
        return input;
    }

    /// @brief Hook after the layer_norm bwd core: stores the residual-add
    /// gradient (dx_resAdd), applies dropout, and accumulates dbias.
    __XETLA_API xetla_vector<dtype_acc, sg_tile_n> post_op(
            xetla_vector<dtype_acc, sg_tile_n> input) {
        xetla_vector<dtype_acc, sg_tile_n> output = input;
        dx_resAdd_out.reg = xetla_cvt<dtype_out, dtype_acc>(input);
        subgroup::tile_store<cache_hint::uncached>(
                dx_resAdd_out, dx_resAdd_out_payload);
        dx_resAdd_out_payload.update_tdesc(wg_num_m * wg_tile_m * mat_ld);
        if (dropout_prob != 0) {
            subgroup::tile_load(mask_in, mask_in_payload);
            SW_BARRIER();
            mask_in_payload.update_tdesc(wg_num_m * wg_tile_m * mask_ld);
            output = drop_out<dtype_acc, sg_tile_n>(
                    output, mask_in.reg, dropout_scale_inv);
        }
        dbias += output;
        return output;
    }

    /// @brief Finalizes the fusion: row-reduces the per-subgroup dbias
    /// partial sums into the dbias_acc buffer.
    template <typename reduce_t>
    __XETLA_API void final_op(reduce_t &ln_group_row_reduce) {
        ln_group_row_reduce(dbias_acc_ptr, matrix_n, wg_num_m, matrix_n,
                dbias_n, dbias_m, dbias);
    }
};
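
// Judging from the final_op call above, dbias_acc_ptr is treated as a
// [wg_num_m x matrix_n] row-major matrix of partial sums, with each
// workgroup row writing at y = wg_idy. A follow-up reduction over the
// wg_num_m partials then yields the bias gradient; a host-side sketch
// (illustrative, not part of this header):
//
// @code
// for (uint32_t n = 0; n < matrix_n; ++n) {
//     float acc = 0.f;
//     for (uint32_t r = 0; r < wg_num_m; ++r)
//         acc += dbias_acc[r * matrix_n + n];
//     dbias[n] = acc;
// }
// @endcode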

/// @brief Is the layer_norm bwd fused op functor, specialized for
/// ln_dropout_gradAdd fusion on Xe: it sums an extra gradient stream into
/// the incoming gradient and then applies dropout.
template <typename dtype_in_, typename dtype_out_, typename dtype_acc_,
        typename layer_norm_attr_>
struct ln_bwd_fused_op_t<ln_bwd_fused_kind::ln_dropout_gradAdd, dtype_in_,
        dtype_out_, dtype_acc_, layer_norm_attr_, gpu_arch::Xe> {
    static constexpr ln_bwd_fused_kind fused_op_kind
            = ln_bwd_fused_kind::ln_dropout_gradAdd;
    using dtype_acc = dtype_acc_;
    using dtype_in = dtype_in_;
    using dtype_out = dtype_out_;
    using dtype_mask = uint8_t;
    using arguments_t
            = ln_bwd_fused_op_arguments_t<dtype_in, dtype_out, dtype_acc>;
    static constexpr uint32_t wg_tile_m = layer_norm_attr_::wg_tile_m;
    static constexpr uint32_t wg_tile_n = layer_norm_attr_::wg_tile_n;
    static constexpr uint32_t sg_tile_m = layer_norm_attr_::sg_tile_m;
    static constexpr uint32_t sg_tile_n = layer_norm_attr_::sg_tile_n;
    static constexpr uint32_t wg_num_m = layer_norm_attr_::wg_num_m;
    static constexpr uint32_t wg_num_n = layer_norm_attr_::wg_num_n;

    static constexpr uint32_t wg_size_x
            = (wg_tile_n + sg_tile_n - 1) / sg_tile_n;
    static constexpr uint32_t wg_size_y
            = (wg_tile_m + sg_tile_m - 1) / sg_tile_m;

    static_assert((sg_tile_n % (sizeof(uint32_t) / sizeof(dtype_mask)) == 0),
            "sg_tile_n needs to be DW aligned");
    using ln_bwd_tile_desc_t = subgroup::tile_desc_t<sg_tile_n, 1, sg_tile_n, 1,
            reg_layout::tiled>;
    using grad_in_t = subgroup::tile_t<dtype_in, ln_bwd_tile_desc_t>;
    using grad_in_payload_t = subgroup::mem_payload_t<
            mem_desc_t<dtype_in, mem_layout::row_major, mem_space::global>,
            ln_bwd_tile_desc_t, msg_type::block_2d, gpu_arch::Xe>;
    using mask_in_t = subgroup::tile_t<dtype_mask, ln_bwd_tile_desc_t>;
    using mask_in_payload_t = subgroup::mem_payload_t<
            mem_desc_t<dtype_mask, mem_layout::row_major, mem_space::global>,
            ln_bwd_tile_desc_t, msg_type::block_2d, gpu_arch::Xe>;

    grad_in_t grad_in;
    grad_in_payload_t grad_in_payload;
    mask_in_t mask_in;
    mask_in_payload_t mask_in_payload;

    uint32_t mat_ld;
    uint32_t mask_ld;
    uint32_t matrix_n;
    uint32_t matrix_m;
    float dropout_prob;
    float dropout_scale_inv;

    /// @brief Binds pointers and strides from the arguments and positions the
    /// gradient/mask payloads at this subgroup's starting tile.
    __XETLA_API void init(arguments_t *args, uint32_t wg_idx, uint32_t wg_idy,
            uint32_t sg_idx, uint32_t sg_idy) {
        int start_n = wg_idx * wg_tile_n + sg_idx * sg_tile_n;
        int start_m = wg_idy * wg_tile_m + sg_idy * sg_tile_m;
        mat_ld = args->mat_ld;
        mask_ld = args->mask_ld;
        matrix_n = args->matrix_n;
        matrix_m = args->matrix_m;
        grad_in_payload.init(args->gradAdd_ptr, matrix_n, matrix_m, mat_ld,
                start_n, start_m);
        mask_in_payload.init(
                args->mask_ptr, matrix_n, matrix_m, mask_ld, start_n, start_m);
        dropout_scale_inv = args->dropout_scale_inv;
        dropout_prob = args->dropout_prob;
    }

    /// @brief Hook before the layer_norm bwd core: adds the extra gradient
    /// stream (grad_add) and then applies dropout.
    __XETLA_API xetla_vector<dtype_acc, sg_tile_n> pre_op(
            xetla_vector<dtype_acc, sg_tile_n> input) {
        subgroup::tile_load(grad_in, grad_in_payload);
        grad_in_payload.update_tdesc(wg_num_m * wg_tile_m * mat_ld);
        xetla_vector<dtype_acc, sg_tile_n> grad_input
                = xetla_cvt<dtype_acc, dtype_in>(grad_in.reg);
        // grad_add
        xetla_vector<dtype_acc, sg_tile_n> output
                = reduce_helper<reduce_op::sum, dtype_acc, sg_tile_n>(
                        input, grad_input);
        if (dropout_prob != 0) {
            // dropout
            subgroup::tile_load(mask_in, mask_in_payload);
            SW_BARRIER();
            mask_in_payload.update_tdesc(wg_num_m * wg_tile_m * mask_ld);
            output = drop_out<dtype_acc, sg_tile_n>(
                    output, mask_in.reg, dropout_scale_inv);
        }
        return output;
    }

    /// @brief Hook after the layer_norm bwd core; identity for this fusion.
    __XETLA_API xetla_vector<dtype_acc, sg_tile_n> post_op(
            xetla_vector<dtype_acc, sg_tile_n> input) {
        return input;
    }

    /// @brief Hook for the group row-reduction stage; no-op here.
    template <typename reduce_t>
    __XETLA_API void final_op([[maybe_unused]] reduce_t &ln_group_row_reduce) {}
};
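
// Elementwise, pre_op above computes (per sg_tile_n-wide row chunk, when
// dropout_prob != 0; otherwise the dropout step is skipped):
//
//     output[i] = drop_out(input[i] + grad_add[i])
//
// i.e. the extra gradient stream is summed in first, then the kept elements
// are rescaled by dropout_scale_inv. Which mask value means "kept" is
// defined by the drop_out helper, not by this file.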

/// @brief Is the layer_norm bwd fused op functor, specialized for ln_dropout
/// fusion on Xe: it applies dropout to the incoming gradient.
template <typename dtype_in_, typename dtype_out_, typename dtype_acc_,
        typename layer_norm_attr_>
struct ln_bwd_fused_op_t<ln_bwd_fused_kind::ln_dropout, dtype_in_, dtype_out_,
        dtype_acc_, layer_norm_attr_, gpu_arch::Xe> {
    static constexpr ln_bwd_fused_kind fused_op_kind
            = ln_bwd_fused_kind::ln_dropout;
    using dtype_acc = dtype_acc_;
    using dtype_in = dtype_in_;
    using dtype_out = dtype_out_;
    using dtype_mask = uint8_t;
    using arguments_t
            = ln_bwd_fused_op_arguments_t<dtype_in, dtype_out, dtype_acc>;
    static constexpr uint32_t wg_tile_m = layer_norm_attr_::wg_tile_m;
    static constexpr uint32_t wg_tile_n = layer_norm_attr_::wg_tile_n;
    static constexpr uint32_t sg_tile_m = layer_norm_attr_::sg_tile_m;
    static constexpr uint32_t sg_tile_n = layer_norm_attr_::sg_tile_n;
    static constexpr uint32_t wg_num_m = layer_norm_attr_::wg_num_m;
    static constexpr uint32_t wg_num_n = layer_norm_attr_::wg_num_n;

    static constexpr uint32_t wg_size_x
            = (wg_tile_n + sg_tile_n - 1) / sg_tile_n;
    static constexpr uint32_t wg_size_y
            = (wg_tile_m + sg_tile_m - 1) / sg_tile_m;
    using ln_bwd_tile_desc_t = subgroup::tile_desc_t<sg_tile_n, 1, sg_tile_n, 1,
            reg_layout::tiled>;
    using mask_in_t = subgroup::tile_t<dtype_mask, ln_bwd_tile_desc_t>;
    using mask_in_payload_t = subgroup::mem_payload_t<
            mem_desc_t<dtype_mask, mem_layout::row_major, mem_space::global>,
            ln_bwd_tile_desc_t, msg_type::block_2d, gpu_arch::Xe>;

    mask_in_t mask_in;
    mask_in_payload_t mask_in_payload;
    uint32_t matrix_n;
    uint32_t matrix_m;
    uint32_t mask_ld;
    float dropout_scale_inv;

    /// @brief Binds the mask pointer and strides from the arguments and
    /// positions the mask payload at this subgroup's starting tile.
    __XETLA_API void init(arguments_t *args, uint32_t wg_idx, uint32_t wg_idy,
            uint32_t sg_idx, uint32_t sg_idy) {
        int start_m = wg_idy * wg_tile_m + sg_idy * sg_tile_m;
        int start_n = wg_idx * wg_tile_n + sg_idx * sg_tile_n;
        mask_ld = args->mask_ld;
        matrix_m = args->matrix_m;
        matrix_n = args->matrix_n;
        mask_in_payload.init(
                args->mask_ptr, matrix_n, matrix_m, mask_ld, start_n, start_m);
        dropout_scale_inv = args->dropout_scale_inv;
    }

    /// @brief Hook before the layer_norm bwd core: applies dropout to the
    /// incoming gradient.
    __XETLA_API xetla_vector<dtype_acc, sg_tile_n> pre_op(
            xetla_vector<dtype_acc, sg_tile_n> input) {
        // dropout
        subgroup::tile_load(mask_in, mask_in_payload);
        SW_BARRIER();
        mask_in_payload.update_tdesc(wg_num_m * wg_tile_m * mask_ld);
        xetla_vector<dtype_acc, sg_tile_n> output
                = drop_out<dtype_acc, sg_tile_n>(
                        input, mask_in.reg, dropout_scale_inv);
        return output;
    }

    /// @brief Hook after the layer_norm bwd core; identity for this fusion.
    __XETLA_API xetla_vector<dtype_acc, sg_tile_n> post_op(
            xetla_vector<dtype_acc, sg_tile_n> input) {
        return input;
    }

    /// @brief Hook for the group row-reduction stage; no-op here.
    template <typename reduce_t>
    __XETLA_API void final_op([[maybe_unused]] reduce_t &ln_group_row_reduce) {}
};
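
// Instantiation sketch: picking a fusion kind selects the matching
// specialization above. The attribute type shown here (ln_attr_t, with the
// tile-shape members this header reads) is an assumed stand-in for whatever
// the layer_norm bwd kernel supplies:
//
// @code
// struct ln_attr_t {
//     static constexpr uint32_t wg_tile_m = 4, wg_tile_n = 1024;
//     static constexpr uint32_t sg_tile_m = 1, sg_tile_n = 256;
//     static constexpr uint32_t wg_num_m = 64, wg_num_n = 1;
// };
// using fused_op_t = ln_bwd_fused_op_t<ln_bwd_fused_kind::ln_dropout, bf16,
//         bf16, float, ln_attr_t, gpu_arch::Xe>;
// @endcode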

} // namespace gpu::xetla::group