XeTLA v0.3.6
Intel® Xe Templates for Linear Algebra - API Definition Document
 
Loading...
Searching...
No Matches
op_function.hpp
Go to the documentation of this file.
1/*******************************************************************************
2 * Copyright (c) 2022-2023 Intel Corporation
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 *******************************************************************************/
16
19
20#pragma once
21
22#include "subgroup/tile/api.hpp"
24
25namespace gpu::xetla::subgroup {
26
34template <typename T_dst, typename T_src>
36 typename std::enable_if_t<(T_src::register_layout != reg_layout::linear)
37 && (T_dst::register_layout != reg_layout::linear)
40 elemwise_cvt(T_dst &dst, T_src &src) {
41 constexpr uint32_t block_size_x = T_dst::block_size_x;
42 constexpr uint32_t tile_elems = T_dst::tile_elems;
43 using dtype_src = typename T_src::dtype;
44 using dtype_dst = typename T_dst::dtype;
45 if constexpr (std::is_same<dtype_src, dtype_dst>::value) {
46 dst.reg = src.reg;
47 } else {
48#pragma unroll
49 for (uint32_t i = 0; i < tile_elems; i += block_size_x) {
50 dst.reg.xetla_select<block_size_x, 1>(i)
51 = xetla_cvt<dtype_dst, dtype_src, block_size_x>(
52 src.reg.xetla_select<block_size_x, 1>(i));
53 }
54 }
55}
56
64template <typename T_dst, typename T_src>
66 typename std::enable_if_t<(T_src::register_layout != reg_layout::linear)
67 && (T_dst::register_layout != reg_layout::linear)
70 elemwise_cvt(T_dst &dst, T_src &src) {
71 constexpr uint32_t block_size_x = T_dst::block_size_x;
72 constexpr uint32_t tile_elems = T_dst::tile_elems;
73 using dtype_src = typename T_src::dtype;
74 using dtype_dst = typename T_dst::dtype;
75
77 //rnde
78#pragma unroll
79 for (uint32_t i = 0; i < tile_elems; i += block_size_x) {
80 rnde_reg.xetla_select<block_size_x, 1>(i)
81 = xetla_rnde<dtype_src, block_size_x>(
82 src.reg.xetla_select<block_size_x, 1>(i));
83 }
84 //sat
85#pragma unroll
86 for (uint32_t i = 0; i < tile_elems; i += block_size_x) {
87 dst.reg.xetla_select<block_size_x, 1>(i)
88 = xetla_sat<dtype_dst, dtype_src, block_size_x>(
89 rnde_reg.xetla_select<block_size_x, 1>(i));
90 }
91}
92
100template <typename T_dst, typename T_src>
102 typename std::enable_if_t<(T_src::register_layout != reg_layout::linear)
103 && (T_dst::register_layout != reg_layout::linear)
105 elemwise_cvt(T_dst &dst, T_src &src, float scale) {
106 dst.reg = xetla_cvt<typename T_dst::dtype, typename T_src::dtype>(
107 src.reg, scale);
108}
109
115template <typename T>
117 typename std::enable_if_t<T::register_layout == reg_layout::vnni_tiled>
118 vnni_convert(T &mat_Acc) {
119 constexpr uint32_t tile_size_y = T::tile_size_y;
120 constexpr uint32_t tile_size_x = T::tile_size_x;
121 constexpr uint32_t tile_elems = tile_size_y * tile_size_x;
122 constexpr uint32_t block_size_y = T::block_size_y;
123 constexpr uint32_t block_size_x = T::block_size_x;
124 constexpr uint32_t block_elems = block_size_y * block_size_x;
125 constexpr int32_t num_block_x = tile_size_x / block_size_x;
126 using dtype = typename T::dtype;
127 constexpr int32_t vnni_stride = sizeof(uint32_t) / sizeof(dtype);
128 constexpr int32_t move_cols = block_size_x * vnni_stride;
129 constexpr int32_t move_rows = block_size_y / vnni_stride;
131 static_assert(block_size_y % vnni_stride == 0, "vnni alignement check");
132 if constexpr (tile_size_x == 1) { return; }
133#pragma unroll
134 for (uint32_t i = 0; i < tile_size_y / block_size_y; i++) {
135#pragma unroll
136 for (uint32_t j = 0; j < num_block_x; j++) {
137 auto reg = (mat_Acc.reg)
138 .xetla_select<block_elems, 1>(
139 (i * num_block_x + j) * block_elems);
140 auto reg_2d = reg.xetla_format<native_type_t<dtype>, block_size_y,
141 block_size_x>();
142 auto reg_dst = rdst.xetla_select<block_elems, 1>(
143 (i * num_block_x + j) * block_elems);
144 auto reg_dst_2d = reg_dst.xetla_format<native_type_t<dtype>,
145 move_rows, move_cols>();
146#pragma unroll
147 for (uint32_t vnni_i = 0; vnni_i < vnni_stride; vnni_i++) {
148 reg_dst_2d
149 .xetla_select<move_rows, 1, block_size_x, vnni_stride>(
150 0, vnni_i)
151 = reg_2d.xetla_select<move_rows, vnni_stride,
152 block_size_x, 1>(vnni_i, 0);
153 }
154 }
155 }
156 // process the tail
157 if constexpr ((tile_size_y % block_size_y) != 0) {
158 constexpr int i = tile_size_y / block_size_y;
159 constexpr uint32_t remain_elems_start = i * block_size_y * tile_size_x;
160 constexpr uint32_t remain_size_y = tile_size_y % block_size_y;
161 constexpr uint32_t remain_block_elems = remain_size_y * block_size_x;
162 static_assert(
163 remain_size_y % vnni_stride == 0, "vnni alignement check");
164 constexpr int32_t remain_move_cols = block_size_x * vnni_stride;
165 constexpr int32_t remain_move_rows = remain_size_y / vnni_stride;
166#pragma unroll
167 for (uint32_t j = 0; j < num_block_x; j++) {
168 auto reg = (mat_Acc.reg)
169 .xetla_select<remain_block_elems, 1>(
170 remain_elems_start
171 + j * remain_block_elems);
172 auto reg_2d = reg.xetla_format<native_type_t<dtype>, remain_size_y,
173 block_size_x>();
174 auto reg_dst = rdst.xetla_select<remain_block_elems, 1>(
175 remain_elems_start + j * remain_block_elems);
176 auto reg_dst_2d = reg_dst.xetla_format<native_type_t<dtype>,
177 remain_move_rows, remain_move_cols>();
178 for (uint32_t vnni_i = 0; vnni_i < vnni_stride; vnni_i++) {
179 reg_dst_2d.xetla_select<remain_move_rows, 1, block_size_x,
180 vnni_stride>(0, vnni_i)
181 = reg_2d.xetla_select<remain_move_rows, vnni_stride,
182 block_size_x, 1>(vnni_i, 0);
183 }
184 }
185 }
186 mat_Acc.reg = rdst;
187}
188
194template <typename T>
195__XETLA_API typename std::enable_if_t<T::register_layout == reg_layout::tiled>
196vnni_reverse(T &mat_Acc) {
197 constexpr uint32_t tile_size_y = T::tile_size_y;
198 constexpr uint32_t tile_size_x = T::tile_size_x;
199 constexpr uint32_t tile_elems = tile_size_y * tile_size_x;
200 constexpr uint32_t block_size_y = T::block_size_y;
201 constexpr uint32_t block_size_x = T::block_size_x;
202 constexpr uint32_t block_elems = block_size_y * block_size_x;
203 constexpr int32_t num_block_x = tile_size_x / block_size_x;
204 using dtype = typename T::dtype;
205 constexpr int32_t vnni_stride = sizeof(uint32_t) / sizeof(dtype);
206 constexpr int32_t move_cols = block_size_x * vnni_stride;
207 constexpr int32_t move_rows = block_size_y / vnni_stride;
209 static_assert(block_size_y % vnni_stride == 0, "vnni alignement check");
210 if constexpr (tile_size_x == 1) { return; }
211#pragma unroll
212 for (uint32_t i = 0; i < tile_size_y / block_size_y; i++) {
213#pragma unroll
214 for (uint32_t j = 0; j < num_block_x; j++) {
215 auto reg = (mat_Acc.reg)
216 .xetla_select<block_elems, 1>(
217 (i * num_block_x + j) * block_elems);
218 auto reg_2d = reg.xetla_format<native_type_t<dtype>, move_rows,
219 move_cols>();
220 auto reg_dst = rdst.xetla_select<block_elems, 1>(
221 (i * num_block_x + j) * block_elems);
222 auto reg_dst_2d = reg_dst.xetla_format<native_type_t<dtype>,
223 block_size_y, block_size_x>();
224#pragma unroll
225 for (uint32_t vnni_i = 0; vnni_i < vnni_stride; vnni_i++) {
226 reg_dst_2d
227 .xetla_select<move_rows, vnni_stride, block_size_x, 1>(
228 vnni_i, 0)
229 = reg_2d.xetla_select<move_rows, 1, block_size_x,
230 vnni_stride>(0, vnni_i);
231 }
232 }
233 }
234 // process the tail
235 if constexpr ((tile_size_y % block_size_y) != 0) {
236 constexpr int i = tile_size_y / block_size_y;
237 constexpr uint32_t remain_elems_start = i * block_size_y * tile_size_x;
238 constexpr uint32_t remain_size_y = tile_size_y % block_size_y;
239 constexpr uint32_t remain_block_elems = remain_size_y * block_size_x;
240 static_assert(
241 remain_size_y % vnni_stride == 0, "vnni alignement check");
242 constexpr int32_t remain_move_cols = block_size_x * vnni_stride;
243 constexpr int32_t remain_move_rows = remain_size_y / vnni_stride;
244#pragma unroll
245 for (uint32_t j = 0; j < num_block_x; j++) {
246 auto reg = (mat_Acc.reg)
247 .xetla_select<remain_block_elems, 1>(
248 remain_elems_start
249 + j * remain_block_elems);
250 auto reg_2d = reg.xetla_format<native_type_t<dtype>,
251 remain_move_rows, remain_move_cols>();
252 auto reg_dst = rdst.xetla_select<remain_block_elems, 1>(
253 remain_elems_start + j * remain_block_elems);
254 auto reg_dst_2d = reg_dst.xetla_format<native_type_t<dtype>,
255 remain_size_y, block_size_x>();
256 for (uint32_t vnni_i = 0; vnni_i < vnni_stride; vnni_i++) {
257 reg_dst_2d.xetla_select<remain_move_rows, vnni_stride,
258 block_size_x, 1>(vnni_i, 0)
259 = reg_2d.xetla_select<remain_move_rows, 1, block_size_x,
260 vnni_stride>(0, vnni_i);
261 }
262 }
263 }
264 mat_Acc.reg = rdst;
265}
266
272template <typename T>
273__XETLA_API typename std::enable_if_t<T::register_layout
275vnni_reverse(T &mat_Acc) {
276 constexpr uint32_t tile_size_y = T::tile_size_y;
277 constexpr uint32_t tile_size_x = T::tile_size_x;
278 constexpr uint32_t tile_elems = tile_size_y * tile_size_x;
279 constexpr uint32_t block_size_y = T::block_size_y;
280 constexpr uint32_t block_size_x = T::block_size_x;
281 constexpr uint32_t block_elems = block_size_y * block_size_x;
282 constexpr int32_t num_block_x = tile_size_x / block_size_x;
283 using dtype = typename T::dtype;
284 constexpr int32_t vnni_stride = sizeof(uint32_t) / sizeof(dtype);
285 constexpr int32_t move_cols = block_size_y * vnni_stride;
286 constexpr int32_t move_rows = block_size_x / vnni_stride;
288 static_assert(block_size_x % vnni_stride == 0, "vnni alignement check");
289 if constexpr (tile_size_y == 1) { return; }
290#pragma unroll
291 for (uint32_t i = 0; i < tile_size_y / block_size_y; i++) {
292#pragma unroll
293 for (uint32_t j = 0; j < num_block_x; j++) {
294 auto reg = (mat_Acc.reg)
295 .xetla_select<block_elems, 1>(
296 (i * num_block_x + j) * block_elems);
297 auto reg_2d = reg.xetla_format<native_type_t<dtype>, move_rows,
298 move_cols>();
299 auto reg_dst = rdst.xetla_select<block_elems, 1>(
300 (i * num_block_x + j) * block_elems);
301 //transpose
302 auto reg_dst_2d = reg_dst.xetla_format<native_type_t<dtype>,
303 block_size_x, block_size_y>();
304 for (uint32_t vnni_i = 0; vnni_i < vnni_stride; vnni_i++) {
305 reg_dst_2d
306 .xetla_select<move_rows, vnni_stride, block_size_y, 1>(
307 vnni_i, 0)
308 = reg_2d.xetla_select<move_rows, 1, block_size_y,
309 vnni_stride>(0, vnni_i);
310 }
311 }
312 }
313 // process the tail
314 if constexpr ((tile_size_y % block_size_y) != 0) {
315 constexpr int i = tile_size_y / block_size_y;
316 constexpr uint32_t remain_elems_start = i * block_size_y * tile_size_x;
317 constexpr uint32_t remain_size_y = tile_size_y % block_size_y;
318 constexpr uint32_t remain_block_elems = remain_size_y * block_size_x;
319 constexpr int32_t remain_move_cols = remain_size_y * vnni_stride;
320 constexpr int32_t remain_move_rows = block_size_x / vnni_stride;
321#pragma unroll
322 for (uint32_t j = 0; j < num_block_x; j++) {
323 auto reg = (mat_Acc.reg)
324 .xetla_select<remain_block_elems, 1>(
325 remain_elems_start
326 + j * remain_block_elems);
327 auto reg_2d = reg.xetla_format<native_type_t<dtype>,
328 remain_move_rows, remain_move_cols>();
329 auto reg_dst = rdst.xetla_select<remain_block_elems, 1>(
330 remain_elems_start + j * remain_block_elems);
331 //transpose
332 auto reg_dst_2d = reg_dst.xetla_format<native_type_t<dtype>,
333 block_size_x, remain_size_y>();
334
335 for (uint32_t vnni_i = 0; vnni_i < vnni_stride; vnni_i++) {
336 reg_dst_2d.xetla_select<remain_move_rows, vnni_stride,
337 remain_size_y, 1>(vnni_i, 0)
338 = reg_2d.xetla_select<remain_move_rows, 1,
339 remain_size_y, vnni_stride>(0, vnni_i);
340 }
341 }
342 }
343 mat_Acc.reg = rdst;
344}
345
353template <typename T_dst, typename T_src>
354__XETLA_API typename std::enable_if_t<is_same_layout<T_dst, T_src>::value>
355vnni_transform(T_dst &dst, T_src &src) {
356 constexpr uint32_t tile_size_y = T_dst::tile_size_y;
357 constexpr uint32_t tile_size_x = T_dst::tile_size_x;
358 constexpr uint32_t tile_elems = tile_size_y * tile_size_x;
359 constexpr uint32_t block_size_y = T_dst::block_size_y;
360 constexpr uint32_t block_size_x = T_dst::block_size_x;
361 constexpr uint32_t block_elems = block_size_y * block_size_x;
362 constexpr int32_t num_block_x = tile_size_x / block_size_x;
363 using dtype_dst = typename T_dst::dtype;
364 using dtype_src = typename T_src::dtype;
365 constexpr uint32_t vnni_row_src = sizeof(uint32_t) / sizeof(dtype_src);
366 constexpr uint32_t vnni_row_dst = sizeof(uint32_t) / sizeof(dtype_dst);
367 constexpr int32_t vnni_row
368 = vnni_row_src > vnni_row_dst ? vnni_row_src : vnni_row_dst;
369 static_assert(block_size_y % vnni_row == 0);
370 static_assert(tile_size_y % vnni_row == 0);
371 constexpr int32_t move_elems = vnni_row * block_size_x;
373 = xetla_cvt<dtype_dst, dtype_src, tile_elems>(src.reg);
374 if constexpr (sizeof(dtype_src) == sizeof(dtype_dst)) {
375 dst.reg = reg_src;
376 return;
377 }
379 constexpr uint32_t scale_factor
381 using move_dtype = get_uint_type_t<sizeof(dtype_dst) * scale_factor>;
382 constexpr uint32_t select_stride = vnni_row / scale_factor;
383#pragma unroll
384 for (uint32_t i = 0; i < tile_size_y / block_size_y; i++) {
385#pragma unroll
386 for (uint32_t j = 0; j < num_block_x; j++) {
387 auto reg_src_blk = reg_src.xetla_select<block_elems, 1>(
388 (i * num_block_x + j) * block_elems);
389 auto reg_dst_blk = reg_dst.xetla_select<block_elems, 1>(
390 (i * num_block_x + j) * block_elems);
391 for (uint32_t row_i = 0; row_i < block_size_y; row_i += vnni_row) {
392 auto reg_src_move
393 = reg_src_blk
394 .xetla_select<move_elems, 1>(
395 row_i * block_size_x)
397 auto reg_dst_move
398 = reg_dst_blk
399 .xetla_select<move_elems, 1>(
400 row_i * block_size_x)
402#pragma unroll
403 for (uint32_t move_i = 0; move_i < select_stride; move_i++) {
404 if constexpr (sizeof(dtype_dst) > sizeof(dtype_src)) {
405 reg_dst_move.xetla_select<block_size_x, 1>(
406 move_i * block_size_x)
407 = reg_src_move.xetla_select<block_size_x,
408 select_stride>(move_i);
409 } else {
410 reg_dst_move.xetla_select<block_size_x, select_stride>(
411 move_i)
412 = reg_src_move.xetla_select<block_size_x, 1>(
413 move_i * block_size_x);
414 }
415 }
416 }
417 }
418 }
419 // process the tail
420 if constexpr ((tile_size_y % block_size_y) != 0) {
421 constexpr int i = tile_size_y / block_size_y;
422 constexpr uint32_t remain_elems_start = i * block_size_y * tile_size_x;
423 constexpr uint32_t remain_size_y = tile_size_y % block_size_y;
424 constexpr uint32_t remain_block_elems = remain_size_y * block_size_x;
425#pragma unroll
426 for (uint32_t j = 0; j < num_block_x; j++) {
427 auto reg_src_blk = reg_src.xetla_select<remain_block_elems, 1>(
428 remain_elems_start + j * remain_block_elems);
429 auto reg_dst_blk = reg_dst.xetla_select<remain_block_elems, 1>(
430 remain_elems_start + j * remain_block_elems);
431 // for mma, here we can guarantee that the remaining is a multiple of
432 // vnni_row
433 for (uint32_t row_i = 0; row_i < remain_size_y; row_i += vnni_row) {
434 auto reg_src_move
435 = reg_src_blk
436 .xetla_select<move_elems, 1>(
437 row_i * block_size_x)
439 auto reg_dst_move
440 = reg_dst_blk
441 .xetla_select<move_elems, 1>(
442 row_i * block_size_x)
444#pragma unroll
445 for (uint32_t move_i = 0; move_i < select_stride; move_i++) {
446 if constexpr (sizeof(dtype_dst) > sizeof(dtype_src)) {
447 reg_dst_move.xetla_select<block_size_x, 1>(
448 move_i * block_size_x)
449 = reg_src_move.xetla_select<block_size_x,
450 select_stride>(move_i);
451 } else {
452 reg_dst_move.xetla_select<block_size_x, select_stride>(
453 move_i)
454 = reg_src_move.xetla_select<block_size_x, 1>(
455 move_i * block_size_x);
456 }
457 }
458 }
459 }
460 }
461 dst.reg = reg_dst;
462}
463
471template <typename T_dst, typename T_src>
473 typename std::enable_if_t<(T_dst::register_layout == reg_layout::tiled)
474 && (T_src::register_layout == reg_layout::tiled)
475 && (T_src::tile_size_x == T_dst::tile_size_x)
476 && (T_src::tile_size_y == 1)>
477 row_broadcast(T_dst &dst, T_src &src) {
478 static constexpr uint32_t dst_tile_size_y = T_dst::tile_size_y;
479 static constexpr uint32_t dst_tile_size_x = T_dst::tile_size_x;
480 static constexpr uint32_t dst_block_size_y = T_dst::block_size_y;
481 static constexpr uint32_t dst_block_size_x = T_dst::block_size_x;
482 static constexpr uint32_t dst_block_elems = T_dst::block_elems;
483 static constexpr int32_t dst_num_block_x = T_dst::num_block_x;
484 using dst_dtype = typename T_dst::dtype;
485 using src_dtype = typename T_src::dtype;
486
487#pragma unroll
488 for (uint32_t i = 0; i < dst_tile_size_y / dst_block_size_y; i++) {
489#pragma unroll
490 for (uint32_t j = 0; j < dst_num_block_x; j++) {
491 auto dst_reg
492 = (dst.reg)
493 .xetla_select<dst_block_elems, 1>(
494 (i * dst_num_block_x + j)
495 * dst_block_elems)
497 dst_block_size_y, dst_block_size_x>();
498#pragma unroll
499 for (uint32_t row_i = 0; row_i < dst_block_size_y; row_i++) {
500 auto src_reg = src.reg.xetla_select<dst_block_size_x, 1>(
501 j * dst_block_size_x);
502 dst_reg.row(row_i)
503 = xetla_cvt<dst_dtype, src_dtype, dst_block_size_x>(
504 src_reg);
505 }
506 }
507 }
508
509 // process the tail
510 if constexpr ((dst_tile_size_y % dst_block_size_y) != 0) {
511 constexpr uint32_t tail_start_y
512 = dst_tile_size_y / dst_block_size_y * dst_block_size_y;
513 constexpr int32_t dst_tail_size_y = dst_tile_size_y % dst_block_size_y;
514 constexpr int32_t dst_tail_block_elems
515 = dst_tail_size_y * dst_block_size_x;
516#pragma unroll
517 for (uint32_t j = 0; j < dst_num_block_x; j++) {
518 auto dst_reg = (dst.reg)
519 .xetla_select<dst_tail_block_elems, 1>(
520 tail_start_y * dst_tile_size_x
521 + j * dst_tail_block_elems)
522 .xetla_format<native_type_t<dst_dtype>,
523 dst_tail_size_y, dst_block_size_x>();
524#pragma unroll
525 for (uint32_t row_i = 0; row_i < dst_tail_size_y; row_i++) {
526 auto src_reg = src.reg.xetla_select<dst_block_size_x, 1>(
527 j * dst_block_size_x);
528 dst_reg.row(row_i)
529 = xetla_cvt<dst_dtype, src_dtype, dst_block_size_x>(
530 src_reg);
531 }
532 }
533 }
534}
535
543template <typename T_dst, typename T_src>
544__XETLA_API typename std::enable_if_t<(T_dst::register_layout
546 && (T_src::register_layout == reg_layout::tiled)
547 && (T_src::tile_size_x == T_dst::tile_size_x)
548 && (T_src::tile_size_y == T_dst::tile_size_y)
549 && (T_dst::tile_size_x == T_dst::block_size_x)
550 && (T_dst::tile_size_y == T_dst::block_size_y)
551 && (std::is_same<typename T_dst::dtype, typename T_src::dtype>::value)>
552layout_convert(T_dst &dst, T_src &src) {
553 using tile_desc = typename T_src::tile_desc;
554 using dtype = typename T_dst::dtype;
555 static constexpr uint32_t num_block_x = tile_desc::num_block_x;
556 static constexpr uint32_t num_block_y = tile_desc::num_block_y;
557 static constexpr uint32_t block_elems = tile_desc::block_elems;
558 static constexpr uint32_t block_size_x = tile_desc::block_size_x;
559 static constexpr uint32_t block_size_y = tile_desc::block_size_y;
560 static constexpr uint32_t tile_size_x = tile_desc::tile_size_x;
561 static constexpr uint32_t tile_size_y = tile_desc::tile_size_y;
562
563 auto dst_reg = dst.reg.xetla_format<native_type_t<dtype>, tile_size_y,
564 tile_size_x>();
565#pragma unroll
566 for (uint32_t i = 0; i < num_block_y; ++i) {
567 uint32_t offset_y = i * block_size_y;
568#pragma unroll
569 for (uint32_t j = 0; j < num_block_x; ++j) {
570 uint32_t offset_x = j * block_size_x;
571 auto src_reg = src.reg.xetla_select<block_elems, 1>(
572 (i * num_block_x + j) * block_elems);
573 dst_reg.xetla_select<block_size_y, 1, block_size_x, 1>(
574 offset_y, offset_x)
575 = src_reg;
576 }
577 }
578 // process the tail
579 if constexpr (tile_desc::remained_size_y > 0) {
580 constexpr uint32_t remained_size_y = tile_desc::remained_size_y;
581 constexpr uint32_t offset_y = tile_size_y - remained_size_y;
582 constexpr uint32_t processed_elems = offset_y * tile_size_x;
583 constexpr uint32_t remained_block_elems
584 = remained_size_y * block_size_x;
585#pragma unroll
586 for (uint32_t j = 0; j < num_block_x; ++j) {
587 uint32_t offset_x = j * block_size_x;
588 auto src_reg = src.reg.xetla_select<remained_block_elems, 1>(
589 processed_elems + j * remained_block_elems);
590 dst_reg.xetla_select<remained_size_y, 1, block_size_x, 1>(
591 offset_y, offset_x)
592 = src_reg;
593 }
594 }
595}
596
604template <typename T_dst, typename T_src>
605__XETLA_API typename std::enable_if_t<(T_dst::register_layout
607 && (T_src::register_layout == reg_layout::linear)
608 && (T_dst::tile_size_x == T_src::tile_size_x)
609 && (T_dst::tile_size_y == T_src::tile_size_y)
610 && (T_src::tile_size_x == T_src::block_size_x)
611 && (T_src::tile_size_y == T_src::block_size_y)
612 && (std::is_same<typename T_dst::dtype, typename T_src::dtype>::value)>
613layout_convert(T_dst &dst, T_src &src) {
614 using tile_desc = typename T_dst::tile_desc;
615 using dtype = typename T_dst::dtype;
616 static constexpr uint32_t num_block_x = tile_desc::num_block_x;
617 static constexpr uint32_t num_block_y = tile_desc::num_block_y;
618 static constexpr uint32_t block_elems = tile_desc::block_elems;
619 static constexpr uint32_t block_size_x = tile_desc::block_size_x;
620 static constexpr uint32_t block_size_y = tile_desc::block_size_y;
621 static constexpr uint32_t tile_size_x = tile_desc::tile_size_x;
622 static constexpr uint32_t tile_size_y = tile_desc::tile_size_y;
623
624 auto src_reg = src.reg.xetla_format<native_type_t<dtype>, tile_size_y,
625 tile_size_x>();
626#pragma unroll
627 for (uint32_t i = 0; i < num_block_y; ++i) {
628 uint32_t offset_y = i * block_size_y;
629#pragma unroll
630 for (uint32_t j = 0; j < num_block_x; ++j) {
631 uint32_t offset_x = j * block_size_x;
632 auto dst_reg = dst.reg.xetla_select<block_elems, 1>(
633 (i * num_block_x + j) * block_elems);
634 dst_reg = src_reg.xetla_select<block_size_y, 1, block_size_x, 1>(
635 offset_y, offset_x);
636 }
637 }
638 // process the tail
639 if constexpr (tile_desc::remained_size_y > 0) {
640 constexpr uint32_t remained_size_y = tile_desc::remained_size_y;
641 constexpr uint32_t offset_y = tile_size_y - remained_size_y;
642 constexpr uint32_t processed_elems = offset_y * tile_size_x;
643 constexpr uint32_t remained_block_elems
644 = remained_size_y * block_size_x;
645#pragma unroll
646 for (uint32_t j = 0; j < num_block_x; ++j) {
647 uint32_t offset_x = j * block_size_x;
648 auto dst_reg = dst.reg.xetla_select<remained_block_elems, 1>(
649 processed_elems + j * remained_block_elems);
650 dst_reg = src_reg.xetla_select<remained_size_y, 1, block_size_x, 1>(
651 offset_y, offset_x);
652 }
653 }
654}
655} // namespace gpu::xetla::subgroup
#define __XETLA_API
Definition common.hpp:43
#define xetla_format
xetla format.
Definition base_ops.hpp:38
typename get_uint_type< Size >::type get_uint_type_t
Return the uint representation based on Size.
Definition base_types.hpp:141
typename native_type< T >::type native_type_t
Return the native data type of T.
Definition base_types.hpp:106
__ESIMD_NS::simd< native_type_t< Ty >, N > xetla_vector
wrapper for xetla_vector.
Definition base_types.hpp:149
Definition limitation.hpp:457
__XETLA_API std::enable_if_t< T::register_layout==reg_layout::vnni_tiled > vnni_convert(T &mat_Acc)
Converts tiled layout to vnni_tiled layout format.
Definition op_function.hpp:118
__XETLA_API std::enable_if_t<(T_src::register_layout !=reg_layout::linear) &&(T_dst::register_layout !=reg_layout::linear) &&is_same_layout< T_dst, T_src >::value &&(!is_floating_to_integer< T_dst, T_src >::value)> elemwise_cvt(T_dst &dst, T_src &src)
Is the element wise data conversion, the src and dst tile should have the same layout.
Definition op_function.hpp:40
__XETLA_API std::enable_if_t< T::register_layout==reg_layout::tiled > vnni_reverse(T &mat_Acc)
Converts vnni_tiled layout format to tiled layout.
Definition op_function.hpp:196
__XETLA_API std::enable_if_t<(T_dst::register_layout==reg_layout::tiled) &&(T_src::register_layout==reg_layout::tiled) &&(T_src::tile_size_x==T_dst::tile_size_x) &&(T_src::tile_size_y==1)> row_broadcast(T_dst &dst, T_src &src)
Broadcasts 1d src tile to the entire 2d tile, as well as do the data conversion.
Definition op_function.hpp:477
__XETLA_API std::enable_if_t< is_same_layout< T_dst, T_src >::value > vnni_transform(T_dst &dst, T_src &src)
Changes vnni layout.
Definition op_function.hpp:355
__XETLA_API std::enable_if_t<(T_dst::register_layout==reg_layout::linear) &&(T_src::register_layout==reg_layout::tiled) &&(T_src::tile_size_x==T_dst::tile_size_x) &&(T_src::tile_size_y==T_dst::tile_size_y) &&(T_dst::tile_size_x==T_dst::block_size_x) &&(T_dst::tile_size_y==T_dst::block_size_y) &&(std::is_same< typename T_dst::dtype, typename T_src::dtype >::value)> layout_convert(T_dst &dst, T_src &src)
convert 2d tile in a tiled register layout to a 2d tile in a linear register layout
Definition op_function.hpp:552
Definition common.hpp:80
static constexpr bool value
Definition common.hpp:223
static constexpr bool value
Definition common.hpp:214
C++ API.
C++ API.