XeTLA v0.3.6
Intel® Xe Templates for Linear Algebra - API Definition Document
 
tile_op_functor.hpp
/*******************************************************************************
* Copyright (c) 2022-2023 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/

#pragma once

#include "subgroup/tile/api.hpp"

namespace gpu::xetla::subgroup {
/// @brief Is none op functor, for placeholder purpose.
struct none_op_t {
    struct arguments_t {};
    template <typename matAcc_t, typename coord_t>
    __XETLA_API KERNEL_FUNC void operator()([[maybe_unused]] matAcc_t &matAcc,
            [[maybe_unused]] const coord_t &coord,
            [[maybe_unused]] const arguments_t &args,
            [[maybe_unused]] uint32_t slm_base = 0,
            [[maybe_unused]] uint32_t nbarrier_base = 0) {}
};

/// @brief Is the element-wise relu op functor.
struct relu_op_t {
    struct arguments_t {};
    template <typename matAcc_t, typename coord_t>
    __XETLA_API KERNEL_FUNC void operator()(matAcc_t &matAcc,
            [[maybe_unused]] const coord_t &coord,
            [[maybe_unused]] const arguments_t &args,
            [[maybe_unused]] uint32_t slm_base = 0,
            [[maybe_unused]] uint32_t nbarrier_base = 0) {
        xetla_mask<matAcc_t::tile_elems> mask = matAcc.reg <= 0;
        matAcc.reg.xetla_merge(0, mask);
    }
};
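
// Illustrative usage sketch (not part of the original header): tile op
// functors run on the accumulator tile inside a GEMM epilogue; matAcc and
// coord are assumed to come from that surrounding epilogue.
//
//     relu_op_t relu_op;
//     relu_op_t::arguments_t relu_args {};
//     relu_op(matAcc, coord, relu_args); // zeroes every lane where reg <= 0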

/// @brief Is the element-wise tanh op functor.
struct tanh_op_t {
    struct arguments_t {};
    template <typename matAcc_t, typename coord_t>
    __XETLA_API KERNEL_FUNC void operator()(matAcc_t &matAcc,
            [[maybe_unused]] const coord_t &coord,
            [[maybe_unused]] uint32_t slm_base = 0,
            [[maybe_unused]] uint32_t nbarrier_base = 0) {
        constexpr int elems = matAcc_t::tile_desc::block_elems;
        constexpr int rounds = matAcc_t::tile_desc::tile_elems / elems;
        using dtype = typename matAcc_t::dtype;
#pragma unroll
        for (uint32_t i = 0; i < rounds; ++i) {
            auto sub_vec = matAcc.reg.xetla_select<elems, 1>(elems * i);
            sub_vec = xetla_tanh<dtype, elems>(sub_vec);
        }
        constexpr int remained_elems = matAcc_t::tile_desc::tile_elems % elems;
        if constexpr (remained_elems != 0) {
            auto sub_vec = matAcc.reg.xetla_select<remained_elems, 1>(
                    elems * (matAcc_t::tile_elems / elems));
            sub_vec = xetla_tanh<dtype, remained_elems>(sub_vec);
        }
    }
};

/// @brief Is the element-wise sigmoid op functor.
struct sigmoid_op_t {
    struct arguments_t {};
    template <typename matAcc_t, typename coord_t>
    __XETLA_API KERNEL_FUNC void operator()(matAcc_t &matAcc,
            [[maybe_unused]] const coord_t &coord,
            [[maybe_unused]] uint32_t slm_base = 0,
            [[maybe_unused]] uint32_t nbarrier_base = 0) {
        constexpr int elems = matAcc_t::tile_desc::block_elems;
        constexpr int rounds = matAcc_t::tile_desc::tile_elems / elems;
        constexpr float one = 1.0f;
#pragma unroll
        for (uint32_t i = 0; i < rounds; ++i) {
            auto sub_vec = matAcc.reg.xetla_select<elems, 1>(elems * i);
            xetla_mask<elems> mask = sub_vec >= 10;
            xetla_vector<typename matAcc_t::dtype, elems> temp_vec
                    = xetla_exp<typename matAcc_t::dtype, elems>(sub_vec);
            xetla_vector<typename matAcc_t::dtype, elems> sigmoid_value
                    = temp_vec / (temp_vec + one);
            sigmoid_value.xetla_merge(1, mask);
            sub_vec = sigmoid_value;
        }
        constexpr int remained_elems = matAcc_t::tile_desc::tile_elems % elems;
        if constexpr (remained_elems != 0) {
            auto sub_vec = matAcc.reg.xetla_select<remained_elems, 1>(
                    elems * (matAcc_t::tile_elems / elems));
            xetla_mask<remained_elems> mask = sub_vec >= 250;
            xetla_vector<typename matAcc_t::dtype, remained_elems> temp_vec
                    = xetla_exp<typename matAcc_t::dtype, remained_elems>(
                            sub_vec);
            xetla_vector<typename matAcc_t::dtype, remained_elems> sigmoid_value
                    = temp_vec / (temp_vec + one);
            sigmoid_value.xetla_merge(1, mask);
            sub_vec = sigmoid_value;
        }
    }
};
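
// Note on the masks above: exp(x) grows quickly, so lanes at or above the
// threshold are merged to 1 directly instead of evaluating
// exp(x) / (exp(x) + 1), which is sigmoid(x) and tends to 1 for large x.
// Minimal usage sketch (matAcc and coord assumed from the epilogue):
//
//     sigmoid_op_t sigmoid_op;
//     sigmoid_op(matAcc, coord);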

/// @brief Is the element-wise gelu inference forward op functor.
struct gelu_fwd_op_t {
    struct arguments_t {};
    template <typename matAcc_t, typename coord_t>
    __XETLA_API KERNEL_FUNC void operator()(matAcc_t &matAcc,
            [[maybe_unused]] const coord_t &coord,
            [[maybe_unused]] const arguments_t &args,
            [[maybe_unused]] uint32_t slm_base = 0,
            [[maybe_unused]] uint32_t nbarrier_base = 0) {
        using dtype = typename matAcc_t::dtype;
        constexpr dtype C0 = 0.044715f;
        constexpr dtype sqrt_two_over_pi = 0.79788458347320556640625f;
        // total flag register
        constexpr int elems = 8 * 16;
        constexpr int rounds = matAcc_t::tile_elems / elems;
#pragma unroll
        for (uint32_t i = 0; i < rounds; ++i) {
            auto sub_vec = matAcc.reg.xetla_select<elems, 1>(elems * i);
            xetla_vector<dtype, elems> sub_vec_x = (sqrt_two_over_pi * sub_vec
                    * (1.f + C0 * sub_vec * sub_vec));
            xetla_vector<dtype, elems> tanh_value
                    = xetla_tanh<dtype, elems>(sub_vec_x);
            sub_vec = 0.5f * sub_vec * (1.f + tanh_value);
        }
        constexpr int remained_elems = matAcc_t::tile_elems % elems;
        if constexpr (remained_elems != 0) {
            auto sub_vec = matAcc.reg.xetla_select<remained_elems, 1>(
                    elems * (matAcc_t::tile_elems / elems));
            xetla_vector<dtype, remained_elems> sub_vec_x = (sqrt_two_over_pi
                    * sub_vec * (1.f + C0 * sub_vec * sub_vec));
            xetla_vector<dtype, remained_elems> tanh_value
                    = xetla_tanh<dtype, remained_elems>(sub_vec_x);
            sub_vec = 0.5f * sub_vec * (1.f + tanh_value);
        }
    }
};
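
// The loop above evaluates the tanh approximation of GELU in place:
//     gelu(x) = 0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3)))
// where sqrt(2/pi) ~= 0.7978845608 is the sqrt_two_over_pi constant.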

/// @brief Is the element-wise gelu training forward op functor.
template <typename dtype_out, gpu_arch arch_tag, class enable = void>
struct gelu_fwd_w_op_t {};
/// Specialization for the Xe architecture.
template <typename dtype_out_, gpu_arch arch_tag>
struct gelu_fwd_w_op_t<dtype_out_, arch_tag,
        std::enable_if_t<(arch_tag == gpu_arch::Xe)>> {
    using dtype_out = dtype_out_;
    using mem_desc_w_t
            = mem_desc_t<dtype_out, mem_layout::row_major, mem_space::global>;
    using shape_t = typename mem_desc_w_t::shape_t;
    using coord_t = typename mem_desc_w_t::coord_t;
    using base_t = typename mem_desc_w_t::base_t;

    struct arguments_t {
        shape_t shape;
        base_t base;
        inline arguments_t() = default;
        inline arguments_t(base_t base_, shape_t shape_)
                : shape(shape_), base(base_) {}
    };
    template <typename matAcc_t>
    __XETLA_API KERNEL_FUNC void operator()(matAcc_t &matAcc,
            const coord_t &coord, const arguments_t &args,
            [[maybe_unused]] uint32_t slm_base = 0,
            [[maybe_unused]] uint32_t nbarrier_base = 0) {
        using dtype_acc = typename matAcc_t::dtype;
        static constexpr uint32_t tile_size_x = matAcc_t::tile_size_x;
        static constexpr uint32_t tile_size_y = matAcc_t::tile_size_y;
        static constexpr uint32_t block_size_x = matAcc_t::block_size_x;
        static constexpr uint32_t block_size_y = matAcc_t::block_size_y;

        mem_desc_w_t mem_desc_w(args.base, args.shape, coord);
        using bwd_w_tile_desc_t = tile_desc_t<block_size_x, block_size_y,
                block_size_x, block_size_y, reg_layout::tiled>;
        using bwd_w_tile_t = tile_t<dtype_out, bwd_w_tile_desc_t>;
        using bwd_w_payload_t = mem_payload_t<mem_desc_w_t, bwd_w_tile_desc_t,
                msg_type::block_2d, arch_tag>;
        bwd_w_tile_t bwd_w;
        bwd_w_payload_t bwd_w_payload(mem_desc_w);
        // start compute
        constexpr dtype_acc c0 = 0.044715f;
        constexpr dtype_acc d0 = 0.134145f;
        constexpr dtype_acc sqrt_two_over_pi = 0.79788458347320556640625f;
        constexpr uint32_t block_elems = matAcc_t::block_elems;
        constexpr uint32_t num_block_x = matAcc_t::num_block_x;
#pragma unroll
        for (uint32_t i = 0; i < tile_size_y / block_size_y; ++i) {
#pragma unroll
            for (uint32_t j = 0; j < num_block_x; ++j) {
                auto x = matAcc.reg.xetla_select<block_elems, 1>(
                        block_elems * (i * num_block_x + j));
                xetla_vector<dtype_acc, block_elems> z
                        = xetla_tanh<dtype_acc, block_elems>(
                                sqrt_two_over_pi * (x + c0 * x * x * x));
                xetla_vector<dtype_acc, block_elems> w = (0.5f * (1.f + z)
                        + 0.5f * x * (1.f - z * z)
                                * (sqrt_two_over_pi * (1.f + d0 * x * x)));
                x = 0.5f * x * (1.f + z);
                bwd_w.reg = xetla_cvt<dtype_out, dtype_acc, block_elems>(w);
                tile_store<cache_hint::uncached>(bwd_w, bwd_w_payload);
                bwd_w_payload.template update_tdesc<tdesc_update_dir::x_dir>(
                        block_size_x);
            }
            bwd_w_payload.template update_tdesc<tdesc_update_dir::x_dir>(
                    -1 * tile_size_x);
            bwd_w_payload.template update_tdesc<tdesc_update_dir::y_dir>(
                    block_size_y);
        }
        if constexpr (tile_size_y % block_size_y != 0) {
            constexpr uint32_t remain_size_y = tile_size_y % block_size_y;
            constexpr uint32_t remain_y_start
                    = tile_size_y / block_size_y * block_size_y;
            constexpr uint32_t remain_elems_start
                    = remain_y_start * tile_size_x;
            constexpr uint32_t remain_block_elems
                    = remain_size_y * block_size_x;

            using remain_bwd_w_tile_desc_t
                    = tile_desc_t<block_size_x, remain_size_y, block_size_x,
                            remain_size_y, reg_layout::tiled>;
            using remain_bwd_w_tile_t
                    = tile_t<dtype_out, remain_bwd_w_tile_desc_t>;
            using remain_bwd_w_payload_t = mem_payload_t<mem_desc_w_t,
                    remain_bwd_w_tile_desc_t, msg_type::block_2d, arch_tag>;

            mem_desc_w.update_coord_y(remain_y_start);
            remain_bwd_w_payload_t remain_bwd_w_payload(mem_desc_w);
            remain_bwd_w_tile_t remain_bwd_w;
#pragma unroll
            for (uint32_t j = 0; j < num_block_x; ++j) {
                auto x = matAcc.reg.xetla_select<remain_block_elems, 1>(
                        remain_elems_start + remain_block_elems * j);
                xetla_vector<dtype_acc, remain_block_elems> z
                        = xetla_tanh<dtype_acc, remain_block_elems>(
                                sqrt_two_over_pi * (x + c0 * x * x * x));
                xetla_vector<dtype_acc, remain_block_elems> w = (0.5f
                                * (1.f + z)
                        + 0.5f * x * (1.f - z * z)
                                * (sqrt_two_over_pi * (1.f + d0 * x * x)));
                x = 0.5f * x * (1.f + z);
                remain_bwd_w.reg
                        = xetla_cvt<dtype_out, dtype_acc, remain_block_elems>(
                                w);
                tile_store<cache_hint::uncached>(
                        remain_bwd_w, remain_bwd_w_payload);
                remain_bwd_w_payload
                        .template update_tdesc<tdesc_update_dir::x_dir>(
                                block_size_x);
            }
        }
    }
};
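
// Besides applying gelu(x) to matAcc in place, this functor materializes the
// analytic derivative
//     w = d gelu(x)/dx
//       = 0.5 * (1 + z) + 0.5 * x * (1 - z^2) * sqrt(2/pi) * (1 + d0 * x^2)
// with z = tanh(sqrt(2/pi) * (x + c0 * x^3)) and d0 = 3 * c0 = 0.134145, and
// stores it for the training backward pass. Hypothetical setup, assuming
// w_base and w_shape describe the derivative buffer:
//
//     gelu_fwd_w_op_t<bf16, gpu_arch::Xe> gelu_w_op;
//     decltype(gelu_w_op)::arguments_t w_args(w_base, w_shape);
//     gelu_w_op(matAcc, coord, w_args);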

/// @brief Is the element-wise gelu backward op functor.
template <typename dtype_in, gpu_arch arch_tag, class enable = void>
struct gelu_bwd_op_t {};
/// Specialization for the Xe architecture.
template <typename dtype_in_, gpu_arch arch_tag>
struct gelu_bwd_op_t<dtype_in_, arch_tag,
        std::enable_if_t<(arch_tag == gpu_arch::Xe)>> {
    using dtype_in = dtype_in_;
    using mem_desc_x_t
            = mem_desc_t<dtype_in, mem_layout::row_major, mem_space::global>;
    using shape_t = typename mem_desc_x_t::shape_t;
    using coord_t = typename mem_desc_x_t::coord_t;
    using base_t = typename mem_desc_x_t::base_t;
    struct arguments_t {
        shape_t shape;
        base_t base;
        inline arguments_t() = default;
        inline arguments_t(base_t base_, shape_t shape_)
                : shape(shape_), base(base_) {}
    };
    template <typename matAcc_t>
    __XETLA_API KERNEL_FUNC void operator()(matAcc_t &matAcc,
            const coord_t &coord, const arguments_t &args,
            [[maybe_unused]] uint32_t slm_base = 0,
            [[maybe_unused]] uint32_t nbarrier_base = 0) {
        using dtype_acc = typename matAcc_t::dtype;
        static constexpr uint32_t tile_size_x = matAcc_t::tile_size_x;
        static constexpr uint32_t tile_size_y = matAcc_t::tile_size_y;
        static constexpr uint32_t block_size_x = matAcc_t::block_size_x;
        static constexpr uint32_t block_size_y = matAcc_t::block_size_y;

        using bwd_x_tile_desc_t = tile_desc_t<tile_size_x, tile_size_y,
                block_size_x, block_size_y, reg_layout::tiled>;
        using bwd_x_tile_t = tile_t<dtype_in, bwd_x_tile_desc_t>;
        using bwd_x_payload_t = mem_payload_t<mem_desc_x_t, bwd_x_tile_desc_t,
                msg_type::block_2d, arch_tag>;
        bwd_x_tile_t bwd_x;
        // init tdesc
        mem_desc_x_t mem_desc_x(args.base, args.shape, coord);
        bwd_x_payload_t bwd_x_payload(mem_desc_x);
        tile_load<cache_hint::cached, cache_hint::cached>(bwd_x, bwd_x_payload);
        // start compute
        constexpr dtype_acc c0 = 0.044715f;
        constexpr dtype_acc d0 = 0.134145f;
        constexpr dtype_acc sqrt_two_over_pi = 0.79788458347320556640625f;
        constexpr uint32_t block_elems = matAcc_t::block_elems;
        constexpr uint32_t num_block_x = matAcc_t::num_block_x;
#pragma unroll
        for (uint32_t i = 0; i < tile_size_y / block_size_y; ++i) {
#pragma unroll
            for (uint32_t j = 0; j < num_block_x; ++j) {
                auto x_in = bwd_x.reg.xetla_select<block_elems, 1>(
                        block_elems * (i * num_block_x + j));
                auto x = xetla_cvt<dtype_acc, dtype_in, block_elems>(x_in);
                auto dy = matAcc.reg.xetla_select<block_elems, 1>(
                        block_elems * (i * num_block_x + j));
                xetla_vector<dtype_acc, block_elems> z
                        = xetla_tanh<dtype_acc, block_elems>(
                                sqrt_two_over_pi * (x + c0 * x * x * x));
                xetla_vector<dtype_acc, block_elems> w = (0.5f * (1.f + z)
                        + 0.5f * x * (1.f - z * z)
                                * (sqrt_two_over_pi * (1.f + d0 * x * x)));
                dy = w * dy;
            }
        }
        if constexpr (tile_size_y % block_size_y != 0) {
            constexpr uint32_t remain_size_y = tile_size_y % block_size_y;
            constexpr uint32_t remain_y_start
                    = tile_size_y / block_size_y * block_size_y;
            constexpr uint32_t remain_elems_start
                    = remain_y_start * tile_size_x;
            constexpr uint32_t remain_block_elems
                    = remain_size_y * block_size_x;
#pragma unroll
            for (uint32_t j = 0; j < num_block_x; ++j) {
                auto x_in = bwd_x.reg.xetla_select<remain_block_elems, 1>(
                        remain_elems_start + remain_block_elems * j);
                auto x = xetla_cvt<dtype_acc, dtype_in, remain_block_elems>(
                        x_in);
                auto dy = matAcc.reg.xetla_select<remain_block_elems, 1>(
                        remain_elems_start + remain_block_elems * j);
                xetla_vector<dtype_acc, remain_block_elems> z
                        = xetla_tanh<dtype_acc, remain_block_elems>(
                                sqrt_two_over_pi * (x + c0 * x * x * x));
                xetla_vector<dtype_acc, remain_block_elems> w = (0.5f
                                * (1.f + z)
                        + 0.5f * x * (1.f - z * z)
                                * (sqrt_two_over_pi * (1.f + d0 * x * x)));
                dy = w * dy;
            }
        }
    }
};
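
// The backward functor computes dx = w * dy in place: dy arrives in matAcc,
// x is reloaded from the buffer saved in the forward pass, and w is the same
// analytic derivative of the tanh-approximated GELU as in gelu_fwd_w_op_t.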

/// @brief Is the bias_add op functor.
template <typename dtype_bias, gpu_arch arch_tag, class enable = void>
struct bias_add_op_t {};
/// Specialization for the Xe architecture.
template <typename dtype_bias_, gpu_arch arch_tag>
struct bias_add_op_t<dtype_bias_, arch_tag,
        std::enable_if_t<(arch_tag == gpu_arch::Xe)>> {
    using dtype_bias = dtype_bias_;
    using mem_desc_bias_t
            = mem_desc_t<dtype_bias, mem_layout::row_major, mem_space::global>;
    using shape_t = typename mem_desc_bias_t::shape_t;
    using coord_t = typename mem_desc_bias_t::coord_t;
    using base_t = typename mem_desc_bias_t::base_t;

    struct arguments_t {
        shape_t shape;
        base_t base;
        inline arguments_t() = default;
        inline arguments_t(base_t base_, shape_t shape_)
                : shape(shape_), base(base_) {}
    };
    template <typename matAcc_t>
    __XETLA_API KERNEL_FUNC void operator()(matAcc_t &matAcc,
            const coord_t &coord, const arguments_t &args,
            [[maybe_unused]] uint32_t slm_base = 0,
            [[maybe_unused]] uint32_t nbarrier_base = 0) {
        using dtype_acc = typename matAcc_t::dtype;
        static constexpr uint32_t tile_size_x = matAcc_t::tile_size_x;
        static constexpr uint32_t tile_size_y = matAcc_t::tile_size_y;
        static constexpr uint32_t block_size_x = matAcc_t::block_size_x;
        static constexpr uint32_t block_size_y = matAcc_t::block_size_y;
        static constexpr int32_t num_block_x = matAcc_t::num_block_x;
        static constexpr uint32_t block_elems = matAcc_t::block_elems;

        using bias_tile_desc_t = tile_desc_t<tile_size_x, 1, block_size_x, 1,
                reg_layout::tiled>;
        using bias_t = tile_t<dtype_bias, bias_tile_desc_t>;
        using bias_payload_t = mem_payload_t<mem_desc_bias_t, bias_tile_desc_t,
                msg_type_v<bias_tile_desc_t, mem_desc_bias_t::space>, arch_tag>;
        coord_t bias_coord(coord.x, 0);
        mem_desc_bias_t mem_desc_bias(args.base, args.shape, bias_coord);
        bias_t bias;
        bias_payload_t bias_payload(mem_desc_bias);
        tile_load<cache_hint::cached, cache_hint::cached>(bias, bias_payload);

#pragma unroll
        for (uint32_t i = 0; i < tile_size_y / block_size_y; i++) {
#pragma unroll
            for (uint32_t j = 0; j < num_block_x; j++) {
                auto dst_reg
                        = matAcc.reg
                                  .xetla_select<block_elems, 1>(
                                          (i * num_block_x + j) * block_elems)
                                  .xetla_format<dtype_acc, block_size_y,
                                          block_size_x>();
#pragma unroll
                for (uint32_t row_i = 0; row_i < block_size_y; row_i++) {
                    auto src_reg = bias.reg.xetla_select<block_size_x, 1>(
                            j * block_size_x);
                    dst_reg.row(row_i)
                            = xetla_cvt<dtype_acc, dtype_bias, block_size_x>(
                                      src_reg)
                            + dst_reg.row(row_i);
                }
            }
        }
        // process the tail
        if constexpr ((tile_size_y % block_size_y) != 0) {
            constexpr uint32_t tail_start_y
                    = tile_size_y / block_size_y * block_size_y;
            constexpr int32_t tail_size_y = tile_size_y % block_size_y;
            constexpr int32_t tail_block_elems = tail_size_y * block_size_x;
#pragma unroll
            for (uint32_t j = 0; j < num_block_x; j++) {
                auto dst_reg = matAcc.reg
                                       .xetla_select<tail_block_elems, 1>(
                                               tail_start_y * tile_size_x
                                               + j * tail_block_elems)
                                       .xetla_format<dtype_acc, tail_size_y,
                                               block_size_x>();
#pragma unroll
                for (uint32_t row_i = 0; row_i < tail_size_y; row_i++) {
                    auto src_reg = bias.reg.xetla_select<block_size_x, 1>(
                            j * block_size_x);
                    dst_reg.row(row_i)
                            = xetla_cvt<dtype_acc, dtype_bias, block_size_x>(
                                      src_reg)
                            + dst_reg.row(row_i);
                }
            }
        }
    }
};
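
// Illustrative usage sketch (bias_ptr, n and the {width, height, pitch}
// shape ordering are assumptions): the bias row vector is loaded once at
// (coord.x, 0) and broadcast-added to every row of the accumulator tile:
//
//     using bias_add_t = bias_add_op_t<float, gpu_arch::Xe>;
//     bias_add_t bias_add;
//     bias_add_t::arguments_t bias_args(bias_ptr, {n, 1, n});
//     bias_add(matAcc, coord, bias_args);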

/// @brief Is MatAcc * vector scale + vector offset.
template <typename scale_dtype, typename offset_dtype, gpu_arch arch_tag,
        class enable = void>
struct scale_v_offset_v_op_t {};
/// Specialization for the Xe architecture.
template <typename scale_dtype_, typename offset_dtype_, gpu_arch arch_tag>
struct scale_v_offset_v_op_t<scale_dtype_, offset_dtype_, arch_tag,
        std::enable_if_t<(arch_tag == gpu_arch::Xe)>> {
    using scale_dtype = scale_dtype_;
    using offset_dtype = offset_dtype_;

    using scale_mem_desc_t = mem_desc_t<scale_dtype, mem_layout::row_major,
            mem_space::global>;
    using offset_mem_desc_t = mem_desc_t<offset_dtype, mem_layout::row_major,
            mem_space::global>;

    using scale_shape_t = typename scale_mem_desc_t::shape_t;
    using scale_base_t = typename scale_mem_desc_t::base_t;

    using offset_shape_t = typename offset_mem_desc_t::shape_t;
    using offset_base_t = typename offset_mem_desc_t::base_t;

    using coord_t = typename scale_mem_desc_t::coord_t;

    struct arguments_t {
        scale_base_t scale_base;
        scale_shape_t scale_shape;
        offset_base_t offset_base;
        offset_shape_t offset_shape;
        inline arguments_t() = default;
        inline arguments_t(scale_base_t scale_base_, scale_shape_t scale_shape_,
                offset_base_t offset_base_, offset_shape_t offset_shape_)
                : scale_base(scale_base_)
                , scale_shape(scale_shape_)
                , offset_base(offset_base_)
                , offset_shape(offset_shape_) {}
    };
    template <typename matAcc_t>
    __XETLA_API KERNEL_FUNC void operator()(matAcc_t &matAcc,
            const coord_t &coord, const arguments_t &args,
            [[maybe_unused]] uint32_t slm_base = 0,
            [[maybe_unused]] uint32_t nbarrier_base = 0) {

        static constexpr uint32_t tile_size_x = matAcc_t::tile_size_x;
        static constexpr uint32_t tile_size_y = matAcc_t::tile_size_y;
        static constexpr uint32_t block_size_x = matAcc_t::block_size_x;
        static constexpr uint32_t block_size_y = matAcc_t::block_size_y;
        static constexpr int32_t num_block_x = matAcc_t::num_block_x;
        static constexpr uint32_t block_elems = matAcc_t::block_elems;

        using scale_tile_desc_t = tile_desc_t<tile_size_x, 1, block_size_x, 1,
                reg_layout::tiled>;
        using scale_tile_t = tile_t<scale_dtype, scale_tile_desc_t>;
        using scale_payload_t
                = mem_payload_t<scale_mem_desc_t, scale_tile_desc_t,
                        msg_type_v<scale_tile_desc_t, scale_mem_desc_t::space>,
                        arch_tag>;
        coord_t scale_coord(coord.x, 0);
        scale_mem_desc_t scale_mem_desc(
                args.scale_base, args.scale_shape, scale_coord);
        scale_tile_t scale_tile;
        scale_payload_t scale_payload(scale_mem_desc);
        tile_load<cache_hint::cached, cache_hint::cached>(
                scale_tile, scale_payload);

        using offset_tile_desc_t = tile_desc_t<tile_size_x, 1, block_size_x, 1,
                reg_layout::tiled>;
        using offset_tile_t = tile_t<offset_dtype, offset_tile_desc_t>;
        using offset_payload_t = mem_payload_t<offset_mem_desc_t,
                offset_tile_desc_t,
                msg_type_v<offset_tile_desc_t, offset_mem_desc_t::space>,
                arch_tag>;
        coord_t offset_coord(coord.x, 0);
        offset_mem_desc_t offset_mem_desc(
                args.offset_base, args.offset_shape, offset_coord);
        offset_tile_t offset_tile;
        offset_payload_t offset_payload(offset_mem_desc);
        tile_load<cache_hint::cached, cache_hint::cached>(
                offset_tile, offset_payload);

#pragma unroll
        for (uint32_t i = 0; i < tile_size_y / block_size_y; i++) {
#pragma unroll
            for (uint32_t j = 0; j < num_block_x; j++) {
                auto acc_reg = matAcc.reg.xetla_select<block_elems, 1>(
                        (i * num_block_x + j) * block_elems);
                auto offset_reg = offset_tile.reg.xetla_select<block_size_x, 1>(
                        j * block_size_x);
                auto scale_reg = scale_tile.reg.xetla_select<block_size_x, 1>(
                        j * block_size_x);
#pragma unroll
                for (uint32_t row_i = 0; row_i < block_size_y; row_i++) {
                    acc_reg.xetla_select<block_size_x, 1>(row_i * block_size_x)
                            = scale_reg
                                    * acc_reg.xetla_select<block_size_x, 1>(
                                            row_i * block_size_x)
                            + offset_reg;
                }
            }
        }
        // process the tail
        if constexpr ((tile_size_y % block_size_y) != 0) {
            constexpr uint32_t tail_start_y
                    = tile_size_y / block_size_y * block_size_y;
            constexpr int32_t tail_size_y = tile_size_y % block_size_y;
            constexpr int32_t tail_block_elems = tail_size_y * block_size_x;
#pragma unroll
            for (uint32_t j = 0; j < num_block_x; j++) {
                auto acc_reg = matAcc.reg.xetla_select<tail_block_elems, 1>(
                        tail_start_y * tile_size_x + j * tail_block_elems);
                auto offset_reg = offset_tile.reg.xetla_select<block_size_x, 1>(
                        j * block_size_x);
                auto scale_reg = scale_tile.reg.xetla_select<block_size_x, 1>(
                        j * block_size_x);
#pragma unroll
                for (uint32_t row_i = 0; row_i < tail_size_y; row_i++) {
                    acc_reg.xetla_select<block_size_x, 1>(row_i * block_size_x)
                            = scale_reg
                                    * acc_reg.xetla_select<block_size_x, 1>(
                                            row_i * block_size_x)
                            + offset_reg;
                }
            }
        }
    }
};
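
// Per column j this computes matAcc[i][j] = scale[j] * matAcc[i][j]
// + offset[j]; both row vectors are loaded once at (coord.x, 0) and
// broadcast down the rows, e.g. for a fused dequantization epilogue.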

/// @brief Is MatAcc * vector scale.
template <typename scale_dtype, gpu_arch arch_tag, class enable = void>
struct scale_v_op_t {};
/// Specialization for the Xe architecture.
template <typename scale_dtype_, gpu_arch arch_tag>
struct scale_v_op_t<scale_dtype_, arch_tag,
        std::enable_if_t<(arch_tag == gpu_arch::Xe)>> {
    using scale_dtype = scale_dtype_;

    using scale_mem_desc_t = mem_desc_t<scale_dtype, mem_layout::row_major,
            mem_space::global>;

    using scale_shape_t = typename scale_mem_desc_t::shape_t;
    using scale_base_t = typename scale_mem_desc_t::base_t;
    using coord_t = typename scale_mem_desc_t::coord_t;

    struct arguments_t {
        scale_base_t scale_base;
        scale_shape_t scale_shape;

        inline arguments_t() = default;
        inline arguments_t(scale_base_t scale_base_, scale_shape_t scale_shape_)
                : scale_base(scale_base_), scale_shape(scale_shape_) {}
    };
    template <typename matAcc_t>
    __XETLA_API KERNEL_FUNC void operator()(matAcc_t &matAcc,
            const coord_t &coord, const arguments_t &args,
            [[maybe_unused]] uint32_t slm_base = 0,
            [[maybe_unused]] uint32_t nbarrier_base = 0) {

        static constexpr uint32_t tile_size_x = matAcc_t::tile_size_x;
        static constexpr uint32_t tile_size_y = matAcc_t::tile_size_y;
        static constexpr uint32_t block_size_x = matAcc_t::block_size_x;
        static constexpr uint32_t block_size_y = matAcc_t::block_size_y;
        static constexpr int32_t num_block_x = matAcc_t::num_block_x;
        static constexpr uint32_t block_elems = matAcc_t::block_elems;

        using scale_tile_desc_t = tile_desc_t<tile_size_x, 1, block_size_x, 1,
                reg_layout::tiled>;
        using scale_tile_t = tile_t<scale_dtype, scale_tile_desc_t>;
        using scale_payload_t
                = mem_payload_t<scale_mem_desc_t, scale_tile_desc_t,
                        msg_type_v<scale_tile_desc_t, scale_mem_desc_t::space>,
                        arch_tag>;
        coord_t scale_coord(coord.x, 0);
        scale_mem_desc_t scale_mem_desc(
                args.scale_base, args.scale_shape, scale_coord);
        scale_tile_t scale_tile;
        scale_payload_t scale_payload(scale_mem_desc);
        tile_load<cache_hint::cached, cache_hint::cached>(
                scale_tile, scale_payload);

#pragma unroll
        for (uint32_t i = 0; i < tile_size_y / block_size_y; i++) {
#pragma unroll
            for (uint32_t j = 0; j < num_block_x; j++) {
                auto acc_reg = matAcc.reg.xetla_select<block_elems, 1>(
                        (i * num_block_x + j) * block_elems);
                auto scale_reg = scale_tile.reg.xetla_select<block_size_x, 1>(
                        j * block_size_x);
#pragma unroll
                for (uint32_t row_i = 0; row_i < block_size_y; row_i++) {
                    acc_reg.xetla_select<block_size_x, 1>(row_i * block_size_x)
                            = scale_reg
                            * acc_reg.xetla_select<block_size_x, 1>(
                                    row_i * block_size_x);
                }
            }
        }
        // process the tail
        if constexpr ((tile_size_y % block_size_y) != 0) {
            constexpr uint32_t tail_start_y
                    = tile_size_y / block_size_y * block_size_y;
            constexpr int32_t tail_size_y = tile_size_y % block_size_y;
            constexpr int32_t tail_block_elems = tail_size_y * block_size_x;
#pragma unroll
            for (uint32_t j = 0; j < num_block_x; j++) {
                auto acc_reg = matAcc.reg.xetla_select<tail_block_elems, 1>(
                        tail_start_y * tile_size_x + j * tail_block_elems);
                auto scale_reg = scale_tile.reg.xetla_select<block_size_x, 1>(
                        j * block_size_x);
#pragma unroll
                for (uint32_t row_i = 0; row_i < tail_size_y; row_i++) {
                    acc_reg.xetla_select<block_size_x, 1>(row_i * block_size_x)
                            = scale_reg
                            * acc_reg.xetla_select<block_size_x, 1>(
                                    row_i * block_size_x);
                }
            }
        }
    }
};
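
// Same broadcast pattern as scale_v_offset_v_op_t without the offset term:
// matAcc[i][j] = scale[j] * matAcc[i][j].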

/// @brief Is the element-wise reduce op functor.
template <reduce_op reduce_kind, typename dtype_in, gpu_arch arch_tag,
        class enable = void>
struct elemwise_reduce_op_t {};
/// Specialization for the Xe architecture.
template <reduce_op reduce_kind_, typename dtype_in_, gpu_arch arch_tag>
struct elemwise_reduce_op_t<reduce_kind_, dtype_in_, arch_tag,
        std::enable_if_t<(arch_tag == gpu_arch::Xe)>> {
    using dtype_in = dtype_in_;
    using mem_desc_in_t
            = mem_desc_t<dtype_in, mem_layout::row_major, mem_space::global>;
    using shape_t = typename mem_desc_in_t::shape_t;
    using coord_t = typename mem_desc_in_t::coord_t;
    using base_t = typename mem_desc_in_t::base_t;
    static constexpr reduce_op reduce_kind = reduce_kind_;

    struct arguments_t {
        shape_t shape;
        base_t base;
        inline arguments_t() = default;
        inline arguments_t(base_t base_, shape_t shape_)
                : shape(shape_), base(base_) {}
    };
    template <typename matAcc_t>
    __XETLA_API KERNEL_FUNC void operator()(matAcc_t &matAcc,
            const coord_t &coord, const arguments_t &args,
            [[maybe_unused]] uint32_t slm_base = 0,
            [[maybe_unused]] uint32_t nbarrier_base = 0) {
        using dtype_acc = typename matAcc_t::dtype;
        static constexpr uint32_t tile_size_x = matAcc_t::tile_size_x;
        static constexpr uint32_t tile_size_y = matAcc_t::tile_size_y;
        static constexpr uint32_t block_size_x = matAcc_t::block_size_x;
        static constexpr uint32_t block_size_y = matAcc_t::block_size_y;
        static constexpr int32_t num_block_x = matAcc_t::num_block_x;
        static constexpr uint32_t block_elems = matAcc_t::block_elems;

        using mat_in_tile_desc_t = tile_desc_t<tile_size_x, block_size_y,
                block_size_x, block_size_y, reg_layout::tiled>;
        using mat_in_tile_t = tile_t<dtype_in, mat_in_tile_desc_t>;
        using mat_in_payload_t = mem_payload_t<mem_desc_in_t,
                mat_in_tile_desc_t,
                msg_type_v<mat_in_tile_desc_t, mem_desc_in_t::space>, arch_tag>;
        using mat_in_tile_acc_t = tile_t<dtype_acc, mat_in_tile_desc_t>;
        mem_desc_in_t mem_desc_in(args.base, args.shape, coord);
        mat_in_tile_t mat_in;
        mat_in_payload_t mat_in_payload(mem_desc_in);
        mat_in_tile_acc_t mat_in_acc;

#pragma unroll
        for (uint32_t i = 0; i < tile_size_y / block_size_y; i++) {
            tile_load<cache_hint::cached, cache_hint::cached>(
                    mat_in, mat_in_payload);
            elemwise_cvt(mat_in_acc, mat_in);
#pragma unroll
            for (uint32_t j = 0; j < num_block_x; j++) {
                auto dst_reg = matAcc.reg.xetla_select<block_elems, 1>(
                        (i * num_block_x + j) * block_elems);
                auto src_reg = mat_in_acc.reg.xetla_select<block_elems, 1>(
                        j * block_elems);
                dst_reg = reduce_helper<reduce_kind, dtype_acc, block_elems>(
                        src_reg, dst_reg);
            }
            mat_in_payload.template update_tdesc<tdesc_update_dir::y_dir>(
                    block_size_y);
        }
        // process the tail
        if constexpr ((tile_size_y % block_size_y) != 0) {
            constexpr uint32_t tail_start_y
                    = tile_size_y / block_size_y * block_size_y;
            constexpr int32_t tail_size_y = tile_size_y % block_size_y;
            constexpr int32_t tail_block_elems = tail_size_y * block_size_x;

            using mat_tail_in_tile_desc_t = tile_desc_t<tile_size_x,
                    tail_size_y, block_size_x, tail_size_y, reg_layout::tiled>;
            using mat_tail_in_tile_t
                    = tile_t<dtype_in, mat_tail_in_tile_desc_t>;
            using mat_tail_in_payload_t = mem_payload_t<mem_desc_in_t,
                    mat_tail_in_tile_desc_t,
                    msg_type_v<mat_tail_in_tile_desc_t, mem_desc_in_t::space>,
                    arch_tag>;
            using mat_tail_in_tile_acc_t
                    = tile_t<dtype_acc, mat_tail_in_tile_desc_t>;
            mat_tail_in_tile_t mat_tail_in;
            mat_tail_in_payload_t mat_tail_in_payload(mem_desc_in);
            mat_tail_in_tile_acc_t mat_tail_in_acc;
            mat_tail_in_payload.template update_tdesc<tdesc_update_dir::y_dir>(
                    tail_start_y);
            tile_load<cache_hint::cached, cache_hint::cached>(
                    mat_tail_in, mat_tail_in_payload);
            elemwise_cvt(mat_tail_in_acc, mat_tail_in);
#pragma unroll
            for (uint32_t j = 0; j < num_block_x; j++) {
                auto dst_reg = matAcc.reg.xetla_select<tail_block_elems, 1>(
                        tail_start_y * tile_size_x + j * tail_block_elems);
                auto src_reg
                        = mat_tail_in_acc.reg.xetla_select<tail_block_elems, 1>(
                                j * tail_block_elems);
                dst_reg = reduce_helper<reduce_kind, dtype_acc,
                        tail_block_elems>(src_reg, dst_reg);
            }
        }
    }
};
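
// Illustrative usage sketch (res_ptr, m, n and the {width, height, pitch}
// shape ordering are assumptions): a fused residual add sums a tile loaded
// from res_ptr into matAcc:
//
//     using residual_add_t
//             = elemwise_reduce_op_t<reduce_op::sum, bf16, gpu_arch::Xe>;
//     residual_add_t residual_add;
//     residual_add_t::arguments_t res_args(res_ptr, {n, m, n});
//     residual_add(matAcc, coord, res_args);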

/// @brief Is the element-wise reduce op functor, specialized for stream_k
/// dispatch: partial sums are loaded from the scratch buffer, reduced into
/// matAcc, and the scratch tile is cleared afterwards.
template <reduce_op reduce_kind, typename dtype_in,
        gpu_arch arch_tag = gpu_arch::Xe>
struct elemwise_reduce_op_stream_k_t {};
/// Specialization for the Xe architecture.
template <reduce_op reduce_kind_, typename dtype_in_>
struct elemwise_reduce_op_stream_k_t<reduce_kind_, dtype_in_, gpu_arch::Xe> {
    using dtype_in = dtype_in_;
    using mem_desc_in_t
            = mem_desc_t<dtype_in, mem_layout::row_major, mem_space::global>;
    using shape_t = typename mem_desc_in_t::shape_t;
    using coord_t = typename mem_desc_in_t::coord_t;
    using base_t = typename mem_desc_in_t::base_t;
    static constexpr reduce_op reduce_kind = reduce_kind_;

    struct arguments_t {
        shape_t shape;
        base_t base;
        inline arguments_t() = default;
        inline arguments_t(base_t base_, shape_t shape_)
                : shape(shape_), base(base_) {}
    };
    template <typename matAcc_t>
    __XETLA_API KERNEL_FUNC void operator()(matAcc_t &matAcc,
            const coord_t &coord, const arguments_t &args,
            [[maybe_unused]] uint32_t slm_base = 0,
            [[maybe_unused]] uint32_t nbarrier_base = 0) {
        using dtype_acc = typename matAcc_t::dtype;
        static constexpr uint32_t tile_size_x = matAcc_t::tile_size_x;
        static constexpr uint32_t tile_size_y = matAcc_t::tile_size_y;
        static constexpr uint32_t block_size_x = matAcc_t::block_size_x;
        static constexpr uint32_t block_size_y = matAcc_t::block_size_y;
        static constexpr int32_t num_block_x = matAcc_t::num_block_x;
        static constexpr uint32_t block_elems = matAcc_t::block_elems;

        using mat_in_tile_desc_t = tile_desc_t<block_size_x, block_size_y,
                block_size_x, block_size_y, reg_layout::tiled>;
        using mat_in_tile_t = tile_t<dtype_in, mat_in_tile_desc_t>;
        using mat_in_payload_t
                = mem_payload_t<mem_desc_in_t, mat_in_tile_desc_t,
                        msg_type_v<mat_in_tile_desc_t, mem_desc_in_t::space>,
                        gpu_arch::Xe>;
        mem_desc_in_t mem_desc_in(args.base, args.shape, coord);
        mat_in_tile_t mat_in;
        mat_in_tile_t mat_zero(0);
        mat_in_payload_t mat_in_payload(mem_desc_in);

#pragma unroll
        for (uint32_t i = 0; i < tile_size_y / block_size_y; i++) {
#pragma unroll
            for (uint32_t j = 0; j < num_block_x; j++) {
                tile_load<cache_hint::cached, cache_hint::cached>(
                        mat_in, mat_in_payload);
                auto dst_reg = matAcc.reg.xetla_select<block_elems, 1>(
                        (i * num_block_x + j) * block_elems);
                auto src_reg = mat_in.reg;
                dst_reg = reduce_helper<reduce_kind, dtype_acc, block_elems>(
                        src_reg, dst_reg);
                tile_store<cache_hint::uncached,
                        cache_hint::write_back>(mat_zero, mat_in_payload);
                mat_in_payload.template update_tdesc<tdesc_update_dir::x_dir>(
                        block_size_x);
            }
            mat_in_payload.template update_tdesc<tdesc_update_dir::x_dir>(
                    -num_block_x * block_size_x);
            mat_in_payload.template update_tdesc<tdesc_update_dir::y_dir>(
                    block_size_y);
        }
        // process the tail
        if constexpr ((tile_size_y % block_size_y) != 0) {
            constexpr uint32_t tail_start_y
                    = tile_size_y / block_size_y * block_size_y;
            constexpr int32_t tail_size_y = tile_size_y % block_size_y;
            constexpr int32_t tail_block_elems = tail_size_y * block_size_x;
#pragma unroll
            for (uint32_t j = 0; j < num_block_x; j++) {
                tile_load<cache_hint::cached, cache_hint::cached>(
                        mat_in, mat_in_payload);
                auto dst_reg = matAcc.reg.xetla_select<tail_block_elems, 1>(
                        tail_start_y * tile_size_x + j * tail_block_elems);
                auto src_reg = mat_in.reg.xetla_select<tail_block_elems, 1>(
                        tail_start_y * tile_size_x + j * tail_block_elems);
                dst_reg = reduce_helper<reduce_kind, dtype_acc,
                        tail_block_elems>(src_reg, dst_reg);
                tile_store<cache_hint::uncached,
                        cache_hint::write_back>(mat_zero, mat_in_payload);
                mat_in_payload.template update_tdesc<tdesc_update_dir::x_dir>(
                        block_size_x);
            }
        }
    }
};
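
// Unlike elemwise_reduce_op_t, the stream_k variant stores mat_zero back to
// each scratch tile right after reducing it, leaving the scratch buffer
// cleared for the next stream-k wave without a separate pass.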

/// @brief Is the dropout op functor.
template <typename dtype_mask, gpu_arch arch_tag, class enable = void>
struct dropout_op_t {};
/// Specialization for the Xe architecture.
template <typename dtype_mask_, gpu_arch arch_tag>
struct dropout_op_t<dtype_mask_, arch_tag,
        std::enable_if_t<(arch_tag == gpu_arch::Xe)>> {
    using dtype_mask = dtype_mask_;
    using mem_desc_mask_t
            = mem_desc_t<dtype_mask, mem_layout::row_major, mem_space::global>;
    using shape_t = typename mem_desc_mask_t::shape_t;
    using coord_t = typename mem_desc_mask_t::coord_t;
    using base_t = typename mem_desc_mask_t::base_t;
    static constexpr uint32_t num_flag = 4;
    static constexpr uint32_t unroll_size = num_flag * 16;
    struct arguments_t {
        shape_t shape;
        base_t base;
        float prob;
        float scale;
        inline arguments_t() = default;
        inline arguments_t(
                base_t base_, shape_t shape_, float prob_, float scale_)
                : shape(shape_), base(base_), prob(prob_), scale(scale_) {}
    };

    template <typename matAcc_t>
    __XETLA_API KERNEL_FUNC void operator()(matAcc_t &matAcc,
            const coord_t &coord, const arguments_t &args,
            [[maybe_unused]] uint32_t slm_base = 0,
            [[maybe_unused]] uint32_t nbarrier_base = 0) {
        static constexpr uint32_t tile_size_x = matAcc_t::tile_size_x;
        static constexpr uint32_t tile_size_y = matAcc_t::tile_size_y;
        static constexpr uint32_t block_size_x = matAcc_t::block_size_x;
        static constexpr uint32_t block_size_y = matAcc_t::block_size_y;
        static constexpr uint32_t tile_elems = matAcc_t::tile_elems;
        if (args.prob == 0) { return; }
        using mask_in_tile_desc_t = tile_desc_t<tile_size_x, tile_size_y,
                block_size_x, block_size_y, reg_layout::tiled>;
        using mask_in_tile_t = tile_t<dtype_mask, mask_in_tile_desc_t>;
        using mask_in_payload_t
                = mem_payload_t<mem_desc_mask_t, mask_in_tile_desc_t,
                        msg_type_v<mask_in_tile_desc_t, mem_desc_mask_t::space>,
                        arch_tag>;
        mem_desc_mask_t mem_desc_mask(args.base, args.shape, coord);
        mask_in_tile_t mask_in;
        mask_in_payload_t mask_in_payload(mem_desc_mask);
        tile_load<cache_hint::cached, cache_hint::cached>(
                mask_in, mask_in_payload);
#pragma unroll
        for (uint32_t i = 0; i < tile_elems / unroll_size; i++) {
            xetla_mask<unroll_size> mask_flag
                    = mask_in.reg.xetla_select<unroll_size, 1>(i * unroll_size)
                    > 0;
            auto dst_reg
                    = matAcc.reg.xetla_select<unroll_size, 1>(i * unroll_size);
            dst_reg *= args.scale;
            dst_reg.xetla_merge(0, mask_flag);
        }
        if constexpr (tile_elems % unroll_size != 0) {
            constexpr uint32_t remain_len = tile_elems % unroll_size;
            constexpr uint32_t remain_start
                    = tile_elems / unroll_size * unroll_size;
            xetla_mask<remain_len> mask_flag
                    = mask_in.reg.xetla_select<remain_len, 1>(remain_start) > 0;
            auto dst_reg = matAcc.reg.xetla_select<remain_len, 1>(remain_start);
            dst_reg *= args.scale;
            dst_reg.xetla_merge(0, mask_flag);
        }
    }
};
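
// Mask semantics above: a nonzero mask element marks a dropped lane;
// xetla_merge(0, mask_flag) zeroes those lanes after the surviving ones are
// rescaled by args.scale (conventionally 1 / (1 - prob), supplied by the
// caller rather than derived internally).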

/// @brief Is the random number generator and dropout op functor.
template <typename dtype_mask, gpu_arch arch_tag, class enable = void>
struct rng_dropout_op_t {};
/// Specialization for the Xe architecture.
template <typename dtype_mask_, gpu_arch arch_tag>
struct rng_dropout_op_t<dtype_mask_, arch_tag,
        std::enable_if_t<(arch_tag == gpu_arch::Xe)>> {
    using dtype_mask = dtype_mask_;
    using mem_desc_mask_t
            = mem_desc_t<dtype_mask, mem_layout::row_major, mem_space::global>;
    using shape_t = typename mem_desc_mask_t::shape_t;
    using coord_t = typename mem_desc_mask_t::coord_t;
    using base_t = typename mem_desc_mask_t::base_t;
    static constexpr uint32_t random_simd = 16;
    static constexpr uint32_t random_len = 4 * random_simd;
    xetla_rand_t<random_simd> rand_gen;

    struct arguments_t {
        shape_t mask_shape;
        base_t mask_base;
        uint64_t *rand_offset_ptr;
        float prob;
        uint64_t rand_seed;

        inline arguments_t() = default;
        inline arguments_t(base_t mask_base_, shape_t mask_shape_, float prob_,
                uint64_t *rand_offset_ptr_,
                uint64_t rand_seed_ = 67280421310721)
                : mask_shape(mask_shape_)
                , mask_base(mask_base_)
                , rand_offset_ptr(rand_offset_ptr_)
                , prob(prob_)
                , rand_seed(rand_seed_) {}
    };

    template <typename matAcc_t>
    __XETLA_API KERNEL_FUNC void operator()(matAcc_t &matAcc,
            const coord_t &coord, const arguments_t &args,
            [[maybe_unused]] uint32_t slm_base = 0,
            [[maybe_unused]] uint32_t nbarrier_base = 0) {
        static constexpr uint32_t tile_size_x = matAcc_t::tile_size_x;
        static constexpr uint32_t tile_size_y = matAcc_t::tile_size_y;
        static constexpr uint32_t block_size_x = matAcc_t::block_size_x;
        static constexpr uint32_t block_size_y = matAcc_t::block_size_y;
        static constexpr uint32_t tile_elems = matAcc_t::tile_elems;

        using mask_out_tile_desc_t = tile_desc_t<tile_size_x, tile_size_y,
                block_size_x, block_size_y, reg_layout::tiled>;
        using mask_out_tile_t = tile_t<dtype_mask, mask_out_tile_desc_t>;
        using mask_out_payload_t = mem_payload_t<mem_desc_mask_t,
                mask_out_tile_desc_t,
                msg_type_v<mask_out_tile_desc_t, mem_desc_mask_t::space>,
                arch_tag>;
        if (args.prob == 0) { return; }
        // calculate the scale internally
        float scale = 1.f / (1.f - args.prob);
        uint32_t threshold = uint32_t(args.prob * float(4294967296));
        xetla_vector<uint64_t, 1> rand_offset_v
                = xetla_load_global<uint64_t, 1>(
                        args.rand_offset_ptr, 0);
        uint64_t rand_offset = rand_offset_v[0];
        uint64_t rand_subseq = uint64_t(coord.y) << 32 | uint64_t(coord.x);
        rand_gen.init(args.rand_seed, rand_subseq, rand_offset);

        mem_desc_mask_t mem_desc_mask(args.mask_base, args.mask_shape, coord);
        mask_out_tile_t mask_out;
        mask_out_payload_t mask_out_payload(mem_desc_mask);

#pragma unroll
        for (uint32_t i = 0; i < tile_elems / random_len; i++) {
            auto out_sub
                    = matAcc.reg.xetla_select<random_len, 1>(i * random_len);
            auto mask_sub
                    = mask_out.reg.xetla_select<random_len, 1>(i * random_len);
            xetla_vector<uint32_t, random_len> rand_val = rand_gen.rand();
            xetla_mask<random_len> mask_flag = rand_val < threshold;
            out_sub *= scale;
            out_sub.xetla_merge(0, mask_flag);
            mask_sub.xetla_merge(1, 0, mask_flag);
        }
        if constexpr (tile_elems % random_len != 0) {
            constexpr uint32_t remain_len = tile_elems % random_len;
            constexpr uint32_t remain_start
                    = tile_elems / random_len * random_len;
            auto out_sub = matAcc.reg.xetla_select<remain_len, 1>(remain_start);
            auto mask_sub
                    = mask_out.reg.xetla_select<remain_len, 1>(remain_start);
            // dropout, still generate random_len
            xetla_vector<uint32_t, random_len> rand_val = rand_gen.rand();
            xetla_mask<random_len> mask_flag = rand_val < threshold;
            out_sub *= scale;
            out_sub.xetla_merge(0, mask_flag.xetla_select<remain_len, 1>(0));
            mask_sub.xetla_merge(
                    1, 0, mask_flag.xetla_select<remain_len, 1>(0));
        }
        tile_store<cache_hint::streaming>(mask_out, mask_out_payload);
    }
};
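
// The RNG path derives everything internally: scale = 1 / (1 - prob), and a
// lane is dropped when its 32-bit random draw falls below
// threshold = prob * 2^32, i.e. with probability prob. The generated mask
// (1 for dropped lanes, 0 for kept ones) is streamed out so a backward pass
// can replay the same pattern with dropout_op_t.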

/// @brief Is the scalar_multiply op functor.
template <typename dtype_in, gpu_arch arch_tag, class enable = void>
struct scalar_mul_op_t {};
/// Specialization for the Xe architecture.
template <typename dtype_in_, gpu_arch arch_tag>
struct scalar_mul_op_t<dtype_in_, arch_tag,
        std::enable_if_t<(arch_tag == gpu_arch::Xe)>> {
    using dtype_in = dtype_in_;
    using mem_desc_in_t
            = mem_desc_t<dtype_in, mem_layout::row_major, mem_space::global>;
    using coord_t = typename mem_desc_in_t::coord_t;

    struct arguments_t {
        dtype_in multiplier;
        inline arguments_t() = default;
        inline arguments_t(dtype_in multiplier_) : multiplier(multiplier_) {}
    };

    template <typename matAcc_t>
    __XETLA_API KERNEL_FUNC void operator()(matAcc_t &matAcc,
            [[maybe_unused]] coord_t coord, [[maybe_unused]] arguments_t args,
            [[maybe_unused]] uint32_t slm_base = 0,
            [[maybe_unused]] uint32_t nbarrier_base = 0) {
        using dtype_acc = typename matAcc_t::dtype;
        static_assert(std::is_same<dtype_in, dtype_acc>::value,
                "Given multiplier must have same type as matAcc!");
        matAcc.reg *= args.multiplier;
    }
};
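
// Minimal usage sketch: the multiplier type must match the accumulator type,
// which the static_assert above enforces at compile time:
//
//     scalar_mul_op_t<float, gpu_arch::Xe> mul_op;
//     mul_op(matAcc, coord, {0.125f});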

/// @brief Is the linear_op functor.
template <typename dtype_in, gpu_arch arch_tag, class enable = void>
struct linear_op_t {};
/// Specialization for the Xe architecture.
template <typename dtype_in_, gpu_arch arch_tag>
struct linear_op_t<dtype_in_, arch_tag,
        std::enable_if_t<(arch_tag == gpu_arch::Xe)>> {
    using dtype_in = dtype_in_;
    using mem_desc_in_t
            = mem_desc_t<dtype_in, mem_layout::row_major, mem_space::global>;
    using shape_t = typename mem_desc_in_t::shape_t;
    using coord_t = typename mem_desc_in_t::coord_t;
    using base_t = typename mem_desc_in_t::base_t;

    struct arguments_t {
        shape_t shape;
        base_t base;
        dtype_in alpha;
        dtype_in beta;
        inline arguments_t() = default;
        inline arguments_t(
                base_t base_, shape_t shape_, dtype_in alpha_, dtype_in beta_)
                : shape(shape_), base(base_), alpha(alpha_), beta(beta_) {}
    };

    template <typename matAcc_t>
    __XETLA_API KERNEL_FUNC void operator()(matAcc_t &matAcc,
            const coord_t &coord, const arguments_t &args,
            [[maybe_unused]] uint32_t slm_base = 0,
            [[maybe_unused]] uint32_t nbarrier_base = 0) {
        using dtype_acc = typename matAcc_t::dtype;
        static constexpr uint32_t tile_size_x = matAcc_t::tile_size_x;
        static constexpr uint32_t tile_size_y = matAcc_t::tile_size_y;
        static constexpr uint32_t block_size_x = matAcc_t::block_size_x;
        static constexpr uint32_t block_size_y = matAcc_t::block_size_y;
        static constexpr uint32_t num_block_x = matAcc_t::num_block_x;
        static constexpr uint32_t num_block_y = matAcc_t::num_block_y;
        static constexpr uint32_t block_elems = matAcc_t::block_elems;
        static constexpr uint32_t remained_size_y = tile_size_y % block_size_y;

        using mat_in_tile_desc_t = tile_desc_t<tile_size_x, block_size_y,
                block_size_x, block_size_y, reg_layout::tiled>;
        using mat_in_tile_t = tile_t<dtype_in, mat_in_tile_desc_t>;
        using mat_in_payload_t = mem_payload_t<mem_desc_in_t,
                mat_in_tile_desc_t,
                msg_type_v<mat_in_tile_desc_t, mem_desc_in_t::space>, arch_tag>;
        using mat_in_tile_acc_t = tile_t<dtype_acc, mat_in_tile_desc_t>;
        mem_desc_in_t mem_desc_in(args.base, args.shape, coord);
        mat_in_tile_t mat_in;
        mat_in_payload_t mat_in_payload(mem_desc_in);
        mat_in_tile_acc_t mat_in_acc;

        dtype_acc alpha = dtype_acc(args.alpha);
        dtype_acc beta = dtype_acc(args.beta);
        matAcc.reg *= alpha;

#pragma unroll
        for (uint32_t i = 0; i < num_block_y; ++i) {
            tile_load<cache_hint::cached, cache_hint::cached>(
                    mat_in, mat_in_payload);
            elemwise_cvt(mat_in_acc, mat_in);
            mat_in_acc.reg *= beta;
#pragma unroll
            for (uint32_t j = 0; j < num_block_x; ++j) {
                auto dst_reg = matAcc.reg.xetla_select<block_elems, 1>(
                        (i * num_block_x + j) * block_elems);
                auto src_reg = mat_in_acc.reg.xetla_select<block_elems, 1>(
                        j * block_elems);
                dst_reg = reduce_helper<reduce_op::sum, dtype_acc, block_elems>(
                        src_reg, dst_reg);
            }
            mat_in_payload.template update_tdesc<tdesc_update_dir::y_dir>(
                    block_size_y);
        }
        // process the tail
        if constexpr (remained_size_y > 0) {
            constexpr uint32_t tail_start_y = num_block_y * block_size_y;
            constexpr uint32_t tail_block_elems
                    = remained_size_y * block_size_x;

            using mat_tail_in_tile_desc_t
                    = tile_desc_t<tile_size_x, remained_size_y, block_size_x,
                            remained_size_y, reg_layout::tiled>;
            using mat_tail_in_tile_t
                    = tile_t<dtype_in, mat_tail_in_tile_desc_t>;
            using mat_tail_in_payload_t = mem_payload_t<mem_desc_in_t,
                    mat_tail_in_tile_desc_t,
                    msg_type_v<mat_tail_in_tile_desc_t, mem_desc_in_t::space>,
                    arch_tag>;
            using mat_tail_in_tile_acc_t
                    = tile_t<dtype_acc, mat_tail_in_tile_desc_t>;

            mat_tail_in_tile_t mat_tail_in;
            mat_tail_in_payload_t mat_tail_in_payload(mem_desc_in);
            mat_tail_in_tile_acc_t mat_tail_in_acc;
            mat_tail_in_payload.template update_tdesc<tdesc_update_dir::y_dir>(
                    tail_start_y);
            tile_load<cache_hint::cached, cache_hint::cached>(
                    mat_tail_in, mat_tail_in_payload);
            elemwise_cvt(mat_tail_in_acc, mat_tail_in);
            mat_tail_in_acc.reg *= beta;
#pragma unroll
            for (uint32_t j = 0; j < num_block_x; ++j) {
                auto dst_reg = matAcc.reg.xetla_select<tail_block_elems, 1>(
                        tail_start_y * tile_size_x + j * tail_block_elems);
                auto src_reg
                        = mat_tail_in_acc.reg.xetla_select<tail_block_elems, 1>(
                                j * tail_block_elems);
                dst_reg = reduce_helper<reduce_op::sum, dtype_acc,
                        tail_block_elems>(src_reg, dst_reg);
            }
        }
    }
};
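
// linear_op_t computes matAcc = alpha * matAcc + beta * mat_in, the classic
// GEMM epilogue C = alpha * (A x B) + beta * C when mat_in aliases the output
// buffer. Illustrative setup (c_ptr, m, n and the {width, height, pitch}
// shape ordering are assumptions):
//
//     using linear_t = linear_op_t<float, gpu_arch::Xe>;
//     linear_t linear;
//     linear_t::arguments_t lin_args(c_ptr, {n, m, n}, 1.0f, 0.5f);
//     linear(matAcc, coord, lin_args);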

} // namespace gpu::xetla::subgroup