XeTLA v0.3.6
Intel® Xe Templates for Linear Algebra - API Definition Document
 
layer_norm_fused_op_fwd_xe.hpp
1/*******************************************************************************
2* Copyright (c) 2022-2023 Intel Corporation
3*
4* Licensed under the Apache License, Version 2.0 (the "License");
5* you may not use this file except in compliance with the License.
6* You may obtain a copy of the License at
7*
8* http://www.apache.org/licenses/LICENSE-2.0
9*
10* Unless required by applicable law or agreed to in writing, software
11* distributed under the License is distributed on an "AS IS" BASIS,
12* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13* See the License for the specific language governing permissions and
14* limitations under the License.
15*******************************************************************************/
16
19
20#pragma once
21
23
24namespace gpu::xetla::group {
25
31template <typename dtype_in, typename dtype_out, typename dtype_acc>
32struct ln_fwd_fused_op_arguments_t {
33 dtype_in *bias_ptr;
34 dtype_in *res_add_ptr;
35 dtype_out *bias_dropout_res_ptr;
36 uint8_t *mask_ptr;
37 uint32_t matrix_m;
38 uint32_t matrix_n;
39 uint32_t mat_ld;
40 uint32_t mask_ld;
41 uint64_t rand_seed = 67280421310721;
42 uint64_t *rand_offset_ptr;
43 float dropout_prob;
44 // dropout_scale = 1 / (1-dropout_prob)
45 float dropout_scale;
46};
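// --- Example (editor's sketch, not part of the original XeTLA header) -------
// A minimal illustration of how a host-side caller might populate
// ln_fwd_fused_op_arguments_t before launching a fused layer-norm kernel.
// The helper name and the buffer pointers are hypothetical; only the field
// names and the dropout_scale = 1 / (1 - dropout_prob) relationship come from
// the struct above.
inline ln_fwd_fused_op_arguments_t<float, float, float>
example_make_ln_fwd_args(float *bias, float *res_add, float *out,
        uint8_t *mask, uint64_t *rand_offset, uint32_t m, uint32_t n,
        float dropout_prob) {
    ln_fwd_fused_op_arguments_t<float, float, float> args{};
    args.bias_ptr = bias;
    args.res_add_ptr = res_add;
    args.bias_dropout_res_ptr = out;
    args.mask_ptr = mask;
    args.matrix_m = m;
    args.matrix_n = n;
    args.mat_ld = n; // row-major, densely packed (assumed layout)
    args.mask_ld = n; // one mask byte per element (assumed layout)
    args.rand_offset_ptr = rand_offset;
    args.dropout_prob = dropout_prob;
    args.dropout_scale = 1.0f / (1.0f - dropout_prob);
    return args;
}
// -----------------------------------------------------------------------------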
47
55template <ln_fwd_fused_kind ln_fused_op_kind_, typename dtype_in_,
56 typename dtype_out_, typename dtype_acc_, typename layer_norm_attr_>
57struct ln_fwd_fused_op_t<ln_fused_op_kind_, dtype_in_, dtype_out_, dtype_acc_,
58 layer_norm_attr_, gpu_arch::Xe> {
59 static constexpr ln_fwd_fused_kind fused_op_kind = ln_fused_op_kind_;
60 using dtype_acc = dtype_acc_;
61 using dtype_in = dtype_in_;
62 using dtype_out = dtype_out_;
63 using arguments_t
64 = ln_fwd_fused_op_arguments_t<dtype_in, dtype_out, dtype_acc>;
65 static constexpr uint32_t wg_tile_m = layer_norm_attr_::wg_tile_m;
66 static constexpr uint32_t wg_tile_n = layer_norm_attr_::wg_tile_n;
67 static constexpr uint32_t sg_tile_m = layer_norm_attr_::sg_tile_m;
68 static constexpr uint32_t sg_tile_n = layer_norm_attr_::sg_tile_n;
69 static constexpr uint32_t wg_num_m = layer_norm_attr_::wg_num_m;
70 static constexpr uint32_t wg_num_n = layer_norm_attr_::wg_num_n;
71 static constexpr uint32_t chunk_size = layer_norm_attr_::chunk_size;
72 static constexpr uint32_t n_chunks = sg_tile_n / chunk_size;
73
82 __XETLA_API void init([[maybe_unused]] arguments_t *args,
83 [[maybe_unused]] uint32_t wg_idx, [[maybe_unused]] uint32_t wg_idy,
84 [[maybe_unused]] uint32_t sg_idx, [[maybe_unused]] uint32_t sg_idy,
85 [[maybe_unused]] uint32_t start_m) {}
86
91 __XETLA_API xetla_vector<dtype_acc, chunk_size> pre_op(
92 xetla_vector<dtype_acc, chunk_size> input) {
93 return input;
94 }
95
100 __XETLA_API xetla_vector<dtype_acc, chunk_size> post_op(
101 xetla_vector<dtype_acc, chunk_size> input) {
102 return input;
103 }
104};
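// --- Example (editor's sketch, not part of the original XeTLA header) -------
// The hook contract the layer-norm forward kernel relies on: pre_op()
// transforms each chunk of the input row before the normalization statistics
// are computed, and post_op() transforms each chunk of the normalized output.
// The default specialization above keeps both hooks as pass-throughs.  A
// scalar reference of that flow; the helper is hypothetical, gamma/beta/eps
// are assumed parameters, and std::sqrt requires <cmath>.
inline void example_ln_row_with_hooks(const float *x, float *y, int n,
        float gamma, float beta, float eps) {
    auto pre_op = [](float v) { return v; }; // identity, as in this struct
    auto post_op = [](float v) { return v; }; // identity, as in this struct
    // statistics are computed on the pre_op-transformed input
    float mean = 0.f;
    for (int i = 0; i < n; ++i) { mean += pre_op(x[i]); }
    mean /= float(n);
    float var = 0.f;
    for (int i = 0; i < n; ++i) {
        float d = pre_op(x[i]) - mean;
        var += d * d;
    }
    var /= float(n);
    float rstd = 1.f / std::sqrt(var + eps);
    for (int i = 0; i < n; ++i) {
        y[i] = post_op(gamma * (pre_op(x[i]) - mean) * rstd + beta);
    }
}
// -----------------------------------------------------------------------------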
109template <typename dtype_in_, typename dtype_out_, typename dtype_acc_,
110 typename layer_norm_attr_>
111struct ln_fwd_fused_op_t<ln_fwd_fused_kind::bias_dropout_resAdd_ln, dtype_in_,
112 dtype_out_, dtype_acc_, layer_norm_attr_, gpu_arch::Xe> {
113 static constexpr ln_fwd_fused_kind fused_op_kind
114 = ln_fwd_fused_kind::bias_dropout_resAdd_ln;
115 using dtype_acc = dtype_acc_;
116 using dtype_in = dtype_in_;
117 using dtype_out = dtype_out_;
118 using dtype_mask = uint8_t;
119 using arguments_t
120 = ln_fwd_fused_op_arguments_t<dtype_in, dtype_out, dtype_acc>;
121 static constexpr uint32_t wg_tile_m = layer_norm_attr_::wg_tile_m;
122 static constexpr uint32_t wg_tile_n = layer_norm_attr_::wg_tile_n;
123 static constexpr uint32_t sg_tile_m = layer_norm_attr_::sg_tile_m;
124 static constexpr uint32_t sg_tile_n = layer_norm_attr_::sg_tile_n;
125 static constexpr uint32_t wg_num_m = layer_norm_attr_::wg_num_m;
126 static constexpr uint32_t wg_num_n = layer_norm_attr_::wg_num_n;
127 static constexpr uint32_t chunk_size = layer_norm_attr_::chunk_size;
128 static constexpr uint32_t n_chunks = sg_tile_n / chunk_size;
129
130 static constexpr uint32_t wg_size_x
131 = (wg_tile_n + sg_tile_n - 1) / sg_tile_n;
132 static constexpr uint32_t wg_size_y
133 = (wg_tile_m + sg_tile_m - 1) / sg_tile_m;
134
135 static_assert((sg_tile_n % (sizeof(uint32_t) / sizeof(dtype_mask)) == 0),
136 "sg_tile_n need to be DW aligned");
137
138 using ln_fwd_tile_desc_t = subgroup::tile_desc_t<chunk_size, 1, chunk_size,
139 1, reg_layout::tiled>;
165 uint32_t mat_ld;
166 uint32_t mask_ld;
167 uint32_t matrix_n;
168 uint32_t matrix_m;
169 float dropout_prob;
170 float dropout_scale;
171
180 __XETLA_API void init(arguments_t *args, uint32_t wg_idx,
181 [[maybe_unused]] uint32_t wg_idy, uint32_t sg_idx,
182 [[maybe_unused]] uint32_t sg_idy, uint32_t start_m) {
183 int start_n = wg_idx * wg_tile_n + sg_idx * sg_tile_n;
184 mat_ld = args->mat_ld;
185 mask_ld = args->mask_ld;
186 matrix_n = args->matrix_n;
187 matrix_m = args->matrix_m;
188 dropout_scale = args->dropout_scale;
189 dropout_prob = args->dropout_prob;
190 bias_in_payload.init(args->bias_ptr, matrix_n, 1, mat_ld, start_n, 0);
191 res_in_payload.init(args->res_add_ptr, matrix_n, matrix_m, mat_ld,
192 start_n, start_m);
193 bias_dropout_res_out_payload.init(args->bias_dropout_res_ptr, matrix_n,
194 matrix_m, mat_ld, start_n, start_m);
195 mask_in_payload.init(
196 args->mask_ptr, matrix_n, matrix_m, mask_ld, start_n, start_m);
197 if constexpr (n_chunks == 1) {
198 subgroup::tile_load(bias_in, bias_in_payload);
199 }
200 }
201
206 __XETLA_API xetla_vector<dtype_acc, chunk_size> pre_op(
207 xetla_vector<dtype_acc, chunk_size> input) {
208 subgroup::tile_load(res_in, res_in_payload);
209 if constexpr (n_chunks == 1) {
210 res_in_payload.update_tdesc(wg_num_m * wg_tile_m * mat_ld);
211 } else {
212 res_in_payload.update_tdesc(chunk_size);
213 }
214 if constexpr (n_chunks != 1) {
215 subgroup::tile_load(bias_in, bias_in_payload);
216 bias_in_payload.update_tdesc(chunk_size);
217 }
218 xetla_vector<dtype_acc, chunk_size> bias_input
219 = xetla_cvt<dtype_acc, dtype_in>(bias_in.reg);
220 // bias_add
221 xetla_vector<dtype_acc, chunk_size> output
222 = reduce_helper<reduce_op::sum, dtype_acc, chunk_size>(
223 input, bias_input);
224 if (dropout_prob != 0) {
225 // dropout
226 subgroup::tile_load(mask_in, mask_in_payload);
227 SW_BARRIER();
228 if constexpr (n_chunks == 1) {
229 mask_in_payload.update_tdesc(wg_num_m * wg_tile_m * mask_ld);
230 } else {
231 mask_in_payload.update_tdesc(chunk_size);
232 }
233 output = drop_out<dtype_acc, chunk_size>(
234 output, mask_in.reg, dropout_scale);
235 }
236 // res_add, generate mixed mode
237 xetla_vector<dtype_acc, chunk_size> res_input
238 = xetla_cvt<dtype_acc, dtype_in>(res_in.reg);
239 output = reduce_helper<reduce_op::sum, dtype_acc, chunk_size>(
240 output, res_input);
241 bias_dropout_res_out.reg = xetla_cvt<dtype_out, dtype_acc>(output);
242 subgroup::tile_store<cache_hint::uncached>(
243 bias_dropout_res_out, bias_dropout_res_out_payload);
244 if constexpr (n_chunks == 1) {
245 bias_dropout_res_out_payload.update_tdesc(
246 wg_num_m * wg_tile_m * mat_ld);
247 } else {
248 bias_dropout_res_out_payload.update_tdesc(chunk_size);
249 }
250 return output;
251 }
252
257 __XETLA_API xetla_vector<dtype_acc, chunk_size> post_op(
258 xetla_vector<dtype_acc, chunk_size> input) {
259 return input;
260 }
261};
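// --- Example (editor's sketch, not part of the original XeTLA header) -------
// Element-wise reference of what the pre_op above computes for one value,
// using a precomputed byte mask.  The helper is hypothetical; the mask
// polarity (non-zero byte = dropped) is an assumption about the drop_out<>
// helper, and dropout_scale is expected to be 1 / (1 - dropout_prob).  The
// result is both returned (to be layer-normalized) and written back through
// bias_dropout_res_ptr by the code above.
inline float example_bias_dropout_resadd_reference(float in, float bias,
        float res, uint8_t mask, float dropout_prob, float dropout_scale) {
    float v = in + bias;                           // bias_add
    if (dropout_prob != 0.f)                       // dropout (mask-driven)
        v = (mask != 0) ? 0.f : v * dropout_scale; // assumed mask polarity
    return v + res;                                // res_add
}
// -----------------------------------------------------------------------------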
262
269template <typename dtype_in_, typename dtype_out_, typename dtype_acc_,
270 typename layer_norm_attr_>
271struct ln_fwd_fused_op_t<ln_fwd_fused_kind::ln_dropout, dtype_in_, dtype_out_,
272 dtype_acc_, layer_norm_attr_, gpu_arch::Xe> {
273 static constexpr ln_fwd_fused_kind fused_op_kind
274 = ln_fwd_fused_kind::ln_dropout;
275 using dtype_acc = dtype_acc_;
276 using dtype_in = dtype_in_;
277 using dtype_out = dtype_out_;
278 using dtype_mask = uint8_t;
279 using arguments_t
280 = ln_fwd_fused_op_arguments_t<dtype_in, dtype_out, dtype_acc>;
281 static constexpr uint32_t wg_tile_m = layer_norm_attr_::wg_tile_m;
282 static constexpr uint32_t wg_tile_n = layer_norm_attr_::wg_tile_n;
283 static constexpr uint32_t sg_tile_m = layer_norm_attr_::sg_tile_m;
284 static constexpr uint32_t sg_tile_n = layer_norm_attr_::sg_tile_n;
285 static constexpr uint32_t wg_num_m = layer_norm_attr_::wg_num_m;
286 static constexpr uint32_t wg_num_n = layer_norm_attr_::wg_num_n;
287 static constexpr uint32_t chunk_size = layer_norm_attr_::chunk_size;
288 static constexpr uint32_t n_chunks = sg_tile_n / chunk_size;
289
290 static constexpr uint32_t wg_size_x
291 = (wg_tile_n + sg_tile_n - 1) / sg_tile_n;
292 static constexpr uint32_t wg_size_y
293 = (wg_tile_m + sg_tile_m - 1) / sg_tile_m;
294
295 static_assert((sg_tile_n % (sizeof(uint32_t) / sizeof(dtype_mask)) == 0),
296 "sg_tile_n need to be DW aligned");
297
298 using ln_fwd_tile_desc_t = subgroup::tile_desc_t<chunk_size, 1, chunk_size,
299 1, reg_layout::tiled>;
306 uint32_t mask_ld;
307 uint32_t matrix_m;
308 uint32_t matrix_n;
309 float dropout_scale;
310
319 __XETLA_API void init(arguments_t *args, uint32_t wg_idx,
320 [[maybe_unused]] uint32_t wg_idy, uint32_t sg_idx,
321 [[maybe_unused]] uint32_t sg_idy, uint32_t start_m) {
322 int start_n = wg_idx * wg_tile_n + sg_idx * sg_tile_n;
323 dropout_scale = args->dropout_scale;
324 mask_ld = args->mask_ld;
325 matrix_m = args->matrix_m;
326 matrix_n = args->matrix_n;
327 mask_in_payload.init(
328 args->mask_ptr, matrix_n, matrix_m, mask_ld, start_n, start_m);
329 }
330
335 __XETLA_API xetla_vector<dtype_acc, chunk_size> pre_op(
336 xetla_vector<dtype_acc, chunk_size> input) {
337 return input;
338 }
339
344 __XETLA_API xetla_vector<dtype_acc, chunk_size> post_op(
345 xetla_vector<dtype_acc, chunk_size> input) {
346 // dropout
347 subgroup::tile_load(mask_in, mask_in_payload);
348 SW_BARRIER();
349 if constexpr (n_chunks == 1) {
350 mask_in_payload.update_tdesc(wg_num_m * wg_tile_m * mask_ld);
351 } else {
352 mask_in_payload.update_tdesc(chunk_size);
353 }
354 xetla_vector<dtype_acc, chunk_size> output
355 = drop_out<dtype_acc, chunk_size>(
356 input, mask_in.reg, dropout_scale);
357 return output;
358 }
359};
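// --- Example (editor's sketch, not part of the original XeTLA header) -------
// The ln_dropout specialization leaves pre_op as identity and applies a
// mask-driven dropout to the normalized output in post_op.  A scalar
// reference with a hypothetical helper name and the same assumed mask
// polarity as above (non-zero byte = dropped); scaling the kept elements by
// dropout_scale = 1 / (1 - dropout_prob) keeps the expected value of the
// output unchanged.
inline float example_ln_dropout_post_op_reference(
        float normalized, uint8_t mask, float dropout_scale) {
    return (mask != 0) ? 0.f : normalized * dropout_scale;
}
// -----------------------------------------------------------------------------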
360
367template <typename dtype_in_, typename dtype_out_, typename dtype_acc_,
368 typename layer_norm_attr_>
369struct ln_fwd_fused_op_t<ln_fwd_fused_kind::bias_rng_dropout_resAdd_ln,
370 dtype_in_, dtype_out_, dtype_acc_, layer_norm_attr_, gpu_arch::Xe> {
371 static constexpr ln_fwd_fused_kind fused_op_kind
372 = ln_fwd_fused_kind::bias_rng_dropout_resAdd_ln;
373 using dtype_acc = dtype_acc_;
374 using dtype_in = dtype_in_;
375 using dtype_out = dtype_out_;
376 using dtype_mask = uint8_t;
377 using arguments_t
378 = ln_fwd_fused_op_arguments_t<dtype_in, dtype_out, dtype_acc>;
379 static constexpr uint32_t wg_tile_m = layer_norm_attr_::wg_tile_m;
380 static constexpr uint32_t wg_tile_n = layer_norm_attr_::wg_tile_n;
381 static constexpr uint32_t sg_tile_m = layer_norm_attr_::sg_tile_m;
382 static constexpr uint32_t sg_tile_n = layer_norm_attr_::sg_tile_n;
383 static constexpr uint32_t wg_num_m = layer_norm_attr_::wg_num_m;
384 static constexpr uint32_t wg_num_n = layer_norm_attr_::wg_num_n;
385 static constexpr uint32_t chunk_size = layer_norm_attr_::chunk_size;
386 static constexpr uint32_t n_chunks = sg_tile_n / chunk_size;
387
388 static constexpr uint32_t wg_size_x
389 = (wg_tile_n + sg_tile_n - 1) / sg_tile_n;
390 static constexpr uint32_t wg_size_y
391 = (wg_tile_m + sg_tile_m - 1) / sg_tile_m;
392
393 using ln_fwd_tile_desc_t = subgroup::tile_desc_t<chunk_size, 1, chunk_size,
394 1, reg_layout::tiled>;
412
421 uint32_t mat_ld;
422 uint32_t mask_ld;
423 uint32_t matrix_n;
424 uint32_t matrix_m;
427
436 __XETLA_API void init(arguments_t *args, uint32_t wg_idx, uint32_t wg_idy,
437 uint32_t sg_idx, uint32_t sg_idy, uint32_t start_m) {
438 int start_n = wg_idx * wg_tile_n + sg_idx * sg_tile_n;
439 xetla_vector<uint64_t, 1> rand_offset_ptr_v
442 args->rand_offset_ptr, 0);
443 mat_ld = args->mat_ld;
444 mask_ld = args->mask_ld;
445 matrix_n = args->matrix_n;
446 matrix_m = args->matrix_m;
447 uint32_t threshold = uint32_t(args->dropout_prob * float(4294967296));
448 dropout_prob = args->dropout_prob;
449 bias_in_payload.init(args->bias_ptr, matrix_n, 1, mat_ld, start_n, 0);
450 res_in_payload.init(args->res_add_ptr, matrix_n, matrix_m, mat_ld,
451 start_n, start_m);
452 bias_dropout_res_out_payload.init(args->bias_dropout_res_ptr, matrix_n,
453 matrix_m, mat_ld, start_n, start_m);
454 mask_out_payload.init(
455 args->mask_ptr, matrix_n, matrix_m, mask_ld, start_n, start_m);
456 int linear_idx = (wg_idy * wg_size_y + sg_idy) * (wg_size_x * wg_num_n)
457 + wg_idx * wg_size_x + sg_idx;
458 dropout_fwd.init(args->rand_seed, linear_idx, rand_offset_ptr_v[0],
459 threshold, args->dropout_scale);
460 if constexpr (n_chunks == 1) {
461 subgroup::tile_load(bias_in, bias_in_payload);
462 }
463 }
464
469 __XETLA_API xetla_vector<dtype_acc, chunk_size> pre_op(
470 xetla_vector<dtype_acc, chunk_size> input) {
471 subgroup::tile_load(res_in, res_in_payload);
472 if constexpr (n_chunks == 1) {
473 res_in_payload.update_tdesc(wg_num_m * wg_tile_m * mat_ld);
474 } else {
475 res_in_payload.update_tdesc(chunk_size);
476 }
477 if constexpr (n_chunks != 1) {
478 subgroup::tile_load(bias_in, bias_in_payload);
479 bias_in_payload.update_tdesc(chunk_size);
480 }
481 xetla_vector<dtype_acc, chunk_size> bias_input
482 = xetla_cvt<dtype_acc, dtype_in>(bias_in.reg);
483 // bias_add
484 xetla_vector<dtype_acc, chunk_size> output
485 = reduce_helper<reduce_op::sum, dtype_acc, chunk_size>(
486 input, bias_input);
487 if (dropout_prob != 0) {
488 // dropout
489 output = dropout_fwd.template process<dtype_acc>(output);
490 mask_out.reg = dropout_fwd.get_mask();
491 subgroup::tile_store<cache_hint::uncached>(
492 mask_out, mask_out_payload);
493 if constexpr (n_chunks == 1) {
494 mask_out_payload.update_tdesc(wg_num_m * wg_tile_m * mask_ld);
495 } else {
496 mask_out_payload.update_tdesc(chunk_size);
497 }
498 }
499 // res_add, generate mixed mode
500 xetla_vector<dtype_acc, chunk_size> res_input
501 = xetla_cvt<dtype_acc, dtype_in>(res_in.reg);
502 output = reduce_helper<reduce_op::sum, dtype_acc, chunk_size>(
503 output, res_input);
504 bias_dropout_res_out.reg = xetla_cvt<dtype_out, dtype_acc>(output);
505 subgroup::tile_store<cache_hint::uncached>(
506 bias_dropout_res_out, bias_dropout_res_out_payload);
507 if constexpr (n_chunks == 1) {
508 bias_dropout_res_out_payload.update_tdesc(
509 wg_num_m * wg_tile_m * mat_ld);
510 } else {
511 bias_dropout_res_out_payload.update_tdesc(chunk_size);
512 }
513 return output;
514 }
515
520 __XETLA_API xetla_vector<dtype_acc, chunk_size> post_op(
521 xetla_vector<dtype_acc, chunk_size> input) {
522 return input;
523 }
524};
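// --- Example (editor's sketch, not part of the original XeTLA header) -------
// The RNG-based specialization converts dropout_prob into a 32-bit integer
// threshold (line 447 above): threshold = uint32_t(dropout_prob * 2^32).  The
// helper below is hypothetical and only restates that arithmetic; whether the
// comparison "random < threshold" means "drop" is an assumption about the
// dropout_fwd_t internals.
inline bool example_rng_drop_decision(float dropout_prob, uint32_t rand32) {
    uint32_t threshold = uint32_t(dropout_prob * 4294967296.0f); // p * 2^32
    // e.g. dropout_prob = 0.25f gives threshold = 0x40000000, so roughly a
    // quarter of all uniformly distributed 32-bit draws fall below it.
    return rand32 < threshold;
}
// -----------------------------------------------------------------------------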
525
532template <typename dtype_in_, typename dtype_out_, typename dtype_acc_,
533 typename layer_norm_attr_>
534struct ln_fwd_fused_op_t<ln_fwd_fused_kind::ln_rng_dropout, dtype_in_,
535 dtype_out_, dtype_acc_, layer_norm_attr_, gpu_arch::Xe> {
536 static constexpr ln_fwd_fused_kind fused_op_kind
537 = ln_fwd_fused_kind::ln_rng_dropout;
538 using dtype_acc = dtype_acc_;
539 using dtype_in = dtype_in_;
540 using dtype_out = dtype_out_;
541 using dtype_mask = uint8_t;
542 using arguments_t
543 = ln_fwd_fused_op_arguments_t<dtype_in, dtype_out, dtype_acc>;
544 static constexpr uint32_t wg_tile_m = layer_norm_attr_::wg_tile_m;
545 static constexpr uint32_t wg_tile_n = layer_norm_attr_::wg_tile_n;
546 static constexpr uint32_t sg_tile_m = layer_norm_attr_::sg_tile_m;
547 static constexpr uint32_t sg_tile_n = layer_norm_attr_::sg_tile_n;
548 static constexpr uint32_t wg_num_m = layer_norm_attr_::wg_num_m;
549 static constexpr uint32_t wg_num_n = layer_norm_attr_::wg_num_n;
550 static constexpr uint32_t chunk_size = layer_norm_attr_::chunk_size;
551 static constexpr uint32_t n_chunks = sg_tile_n / chunk_size;
552 static_assert(sg_tile_n % chunk_size == 0,
553 "Current impl does not support tailing mechanism");
554
555 static constexpr uint32_t wg_size_x
556 = (wg_tile_n + sg_tile_n - 1) / sg_tile_n;
557 static constexpr uint32_t wg_size_y
558 = (wg_tile_m + sg_tile_m - 1) / sg_tile_m;
559
560 using ln_fwd_tile_desc_t = subgroup::tile_desc_t<chunk_size, 1,
561 chunk_size, 1, reg_layout::tiled>;
569 uint32_t mask_ld;
570 uint32_t matrix_m;
571 uint32_t matrix_n;
572
581 __XETLA_API void init(arguments_t *args, uint32_t wg_idx, uint32_t wg_idy,
582 uint32_t sg_idx, uint32_t sg_idy, uint32_t start_m) {
583 int start_n = wg_idx * wg_tile_n + sg_idx * sg_tile_n;
584 xetla_vector<uint64_t, 1> rand_offset_ptr_v
587 args->rand_offset_ptr, 0);
588 mask_ld = args->mask_ld;
589 matrix_m = args->matrix_m;
590 matrix_n = args->matrix_n;
591 uint32_t threshold = uint32_t(args->dropout_prob * float(4294967296));
592 mask_out_payload.init(
593 args->mask_ptr, matrix_n, matrix_m, mask_ld, start_n, start_m);
594 int linear_idx = (wg_idy * wg_size_y + sg_idy) * (wg_size_x * wg_num_n)
595 + wg_idx * wg_size_x + sg_idx;
596 dropout_fwd.init(args->rand_seed, linear_idx, rand_offset_ptr_v[0],
597 threshold, args->dropout_scale);
598 }
599
604 __XETLA_API xetla_vector<dtype_acc, chunk_size> pre_op(
605 xetla_vector<dtype_acc, chunk_size> input) {
606 return input;
607 }
608
613 __XETLA_API xetla_vector<dtype_acc, chunk_size> post_op(
614 xetla_vector<dtype_acc, chunk_size> input) {
615 xetla_vector<dtype_acc, chunk_size> output
616 = dropout_fwd.template process<dtype_acc>(input);
617 mask_out.reg = dropout_fwd.get_mask();
618 subgroup::tile_store<cache_hint::uncached>(mask_out, mask_out_payload);
619 if constexpr (n_chunks == 1) {
620 mask_out_payload.update_tdesc(wg_num_m * wg_tile_m * mask_ld);
621 } else {
622 mask_out_payload.update_tdesc(chunk_size);
623 }
624 return output;
625 }
626};
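// --- Example (editor's sketch, not part of the original XeTLA header) -------
// Both RNG specializations seed dropout_fwd with (rand_seed, linear_idx,
// rand_offset): each subgroup gets its own subsequence of the counter-based
// RNG so that subgroups do not reuse the same random draws (the intent of the
// subseq argument to dropout_fwd.init).  The hypothetical helper below simply
// restates the linear_idx formula computed in the init() functions above.
inline int example_rng_linear_index(uint32_t wg_idx, uint32_t wg_idy,
        uint32_t sg_idx, uint32_t sg_idy, uint32_t wg_size_x,
        uint32_t wg_size_y, uint32_t wg_num_n) {
    return int((wg_idy * wg_size_y + sg_idy) * (wg_size_x * wg_num_n)
            + wg_idx * wg_size_x + sg_idx);
}
// -----------------------------------------------------------------------------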
627
628} // namespace gpu::xetla::group