29template <
typename tile_t,
typename payload_t>
38 && (tile_t::tile_size_y == 1) && (tile_t::block_size_y == 1)
78 using load_dtype =
typename payload_t::mem_dtype;
81 static constexpr uint32_t tile_size_x = tile_desc::tile_size_x;
82 static constexpr uint32_t tile_size_y = tile_desc::tile_size_y;
83 static constexpr uint32_t block_size_x = tile_desc::block_size_x;
84 static constexpr uint32_t block_size_y = tile_desc::block_size_y;
85 static constexpr uint32_t remained_size_y = tile_desc::remained_size_y;
87 static constexpr uint32_t block_elems = tile_desc::block_elems;
89 static constexpr uint32_t num_block_x = tile_desc::num_block_x;
90 static constexpr uint32_t num_block_y = tile_desc::num_block_y;
91 static constexpr uint32_t num_block = tile_desc::num_block;
93 static constexpr gpu_arch arch_tag = payload_t::arch_tag;
95 static constexpr reg_layout reg_layout_ = tile_desc::register_layout;
96 static constexpr bool is_vnni_reverse = payload_t::mem_dword_transpose
99 static constexpr bool reg_transpose = tile_desc::reg_transpose;
101 static constexpr bool mem_transpose = payload_t::mem_transpose;
102 static constexpr bool trans = reg_transpose ^ mem_transpose;
103 static constexpr uint32_t scale_factor = payload_t::scale_factor;
105 static constexpr bool mem_transform = payload_t::mem_transform;
108 arch_tag>::template load_store_attr<msg_type::block_2d>;
109 static constexpr uint32_t elems_per_CL
110 = load_store_attr::cache_line_size_in_bytes /
sizeof(dtype);
111 static constexpr uint32_t elems_per_reg
114 static constexpr int32_t max_load_block_height
115 = load_store_attr::max_load_height_in_elem;
116 static constexpr int32_t max_block_width
117 = load_store_attr::max_load_width_in_bytes /
sizeof(dtype);
118 static constexpr int32_t max_trans_block_width
119 = load_store_attr::max_trans_load_width_in_bytes /
sizeof(dtype);
121 static constexpr uint32_t ld_blk_size_y_limit
122 = mem_transpose ? max_trans_block_width : max_load_block_height;
123 static constexpr uint32_t ld_blk_size_y = reg_transpose
125 : (block_size_y > ld_blk_size_y_limit ? ld_blk_size_y_limit
130 static constexpr uint8_t arr_len_candidate
135 || ((block_size_y * block_size_x) % elems_per_reg != 0)
137 || (((tile_size_y % block_size_y) * block_size_x)
140 || (block_size_y > ld_blk_size_y_limit)
142 : (((tile_size_x % elems_per_CL) == 0)
143 ? (((elems_per_CL % block_size_x) == 0)
144 ? elems_per_CL / block_size_x
146 : ((tile_size_x < elems_per_CL)
147 ? (tile_size_x / block_size_x)
149 static constexpr bool is_valid_arr_len_candidate = (arr_len_candidate == 1)
150 || (arr_len_candidate == 2) || (arr_len_candidate == 4);
152 static constexpr uint8_t arr_len
153 = is_valid_arr_len_candidate ? arr_len_candidate : 1;
155 static_assert(reg_transpose || mem_transpose
157 && (block_size_x * arr_len) <= max_block_width),
158 "When reg_transpose was disabled, check 2d block width "
160 static_assert(!reg_transpose
162 && (block_size_x * arr_len)
163 <= max_trans_block_width)
165 && (block_size_y * arr_len) <= max_block_width),
166 "When reg_transpose was enabled, check 2d block width "
168 static_assert(!reg_transpose
170 && (block_size_y <= max_load_block_height))
172 && (block_size_x) <= max_load_block_height),
173 "When reg_transpose was enabled, check 2d block height "
175 static_assert(tile_size_x % (block_size_x * arr_len) == 0,
176 "tile_size_x should be a multiple of (block_size_x * arr_len)");
179 && ((block_size_x *
sizeof(dtype)) %
sizeof(load_dtype)
181 || ((block_size_y *
sizeof(dtype)) %
sizeof(load_dtype)
183 "check vnni limitation for DW transpose");
185 auto payload_2d = payload.payloads.xetla_format<uint32_t, num_block, 16>();
187 for (uint32_t i = 0; i < num_block_y; ++i) {
188 constexpr uint32_t load_block_elems = block_elems * arr_len;
189 auto payload_row = payload_2d.xetla_select<num_block_x, 1, 16, 1>(
191 detail::reset_tile_desc_core<num_block_x, block_size_x, ld_blk_size_y,
192 scale_factor, arr_len, mem_transpose>(payload_row);
194 for (uint32_t j = 0; j < num_block_x; j += arr_len) {
196 auto reg_blk =
tile.reg.xetla_select<load_block_elems, 1>(
197 (i * num_block_x + j) * block_elems);
198 constexpr uint32_t ld_blk_height = (reg_transpose && trans)
199 ? detail::getNextPowerOf2<ld_blk_size_y>()
201 constexpr uint32_t tmp_size
202 = ld_blk_height * block_size_x * arr_len;
205 for (uint32_t ii = 0; ii < block_size_y / ld_blk_size_y; ++ii) {
206 constexpr uint32_t load_elems
207 = ld_blk_size_y * block_size_x * arr_len;
211 ld_blk_height * block_size_x * arr_len
213 L1, L2, trans, mem_transform, arch_tag>(tdesc);
215 if constexpr (reg_transpose && trans) {
216 reg_blk.xetla_select<load_elems, 1>(ii * load_elems)
218 = reg_tmp.xetla_format<load_dtype,
219 block_size_x / scale_factor,
222 1, ld_blk_size_y, 1>(0, 0);
224 reg_blk.xetla_select<tmp_size, 1>(ii * tmp_size) = reg_tmp;
227 if constexpr (mem_transpose) {
229 ld_blk_size_y / scale_factor);
232 tdesc.xetla_format<uint32_t>(), ld_blk_size_y);
236 if constexpr (block_size_y % ld_blk_size_y != 0) {
237 constexpr uint32_t remained_start_y
238 = block_size_y / ld_blk_size_y * ld_blk_size_y;
239 constexpr uint32_t remained_start
240 = remained_start_y * block_size_x * arr_len;
241 constexpr uint32_t remained_blk_size_y
242 = block_size_y % ld_blk_size_y;
243 constexpr uint32_t load_elems
244 = remained_blk_size_y * block_size_x * arr_len;
246 constexpr uint8_t block_width = mem_transpose
247 ? (remained_blk_size_y / scale_factor)
249 constexpr uint8_t block_height
250 = trans ? block_size_x : remained_blk_size_y;
251 constexpr uint32_t block_widthx_widthy_arrlen
252 = (block_width - 1) | ((block_height - 1) << 8);
254 tdesc.xetla_format<uint32_t>(),
255 block_widthx_widthy_arrlen);
257 reg_blk.xetla_select<load_elems, 1>(remained_start)
260 (load_elems / scale_factor), L1, L2, trans,
261 mem_transform, arch_tag>(tdesc);
266 if constexpr (remained_size_y > 0) {
267 constexpr uint32_t remained_block_elems
268 = block_size_x * remained_size_y;
269 constexpr uint32_t processed_elems
270 = num_block_y * num_block_x * block_elems;
271 constexpr uint32_t remained_ld_blk_size_y
272 = (!reg_transpose && (remained_size_y > ld_blk_size_y_limit))
273 ? ld_blk_size_y_limit
275 auto payload_row = payload_2d.xetla_select<num_block_x, 1, 16, 1>(
276 num_block_y * num_block_x, 0);
277 detail::reset_tile_desc_core<num_block_x, block_size_x,
278 remained_ld_blk_size_y, scale_factor, arr_len, mem_transpose>(
281 for (uint32_t j = 0; j < num_block_x; j += arr_len) {
284 =
tile.reg.xetla_select<remained_block_elems * arr_len, 1>(
285 processed_elems + j * remained_block_elems);
286 constexpr uint32_t ld_blk_height = (reg_transpose && trans)
287 ? detail::getNextPowerOf2<remained_ld_blk_size_y>()
288 : remained_ld_blk_size_y;
289 constexpr uint32_t tmp_size
290 = ld_blk_height * block_size_x * arr_len;
293 for (uint32_t ii = 0; ii < remained_size_y / remained_ld_blk_size_y;
295 constexpr uint32_t load_elems
296 = remained_ld_blk_size_y * block_size_x * arr_len;
300 (ld_blk_height * block_size_x * arr_len
302 L1, L2, trans, mem_transform, arch_tag>(tdesc);
304 if constexpr (reg_transpose && trans) {
305 reg_blk.xetla_select<load_elems, 1>(ii * load_elems)
307 = reg_tmp.xetla_format<load_dtype,
308 block_size_x / scale_factor,
311 1, remained_ld_blk_size_y, 1>(
314 reg_blk.xetla_select<tmp_size, 1>(ii * tmp_size) = reg_tmp;
316 if constexpr (mem_transpose) {
318 remained_ld_blk_size_y / scale_factor);
321 remained_ld_blk_size_y);
324 constexpr uint32_t final_ld_blk_size_y
325 = remained_size_y % remained_ld_blk_size_y;
326 if constexpr (final_ld_blk_size_y != 0) {
327 constexpr uint32_t final_start = remained_size_y
328 / remained_ld_blk_size_y * remained_ld_blk_size_y
329 * block_size_x * arr_len;
330 constexpr uint32_t final_load_elems
331 = final_ld_blk_size_y * block_size_x * arr_len;
332 constexpr uint8_t block_width = mem_transpose
333 ? (final_ld_blk_size_y / scale_factor)
335 constexpr uint8_t block_height
336 = trans ? block_size_x : final_ld_blk_size_y;
337 constexpr uint32_t block_widthx_widthy_arrlen
338 = (block_width - 1) | ((block_height - 1) << 8);
340 tdesc.xetla_format<uint32_t>(),
341 block_widthx_widthy_arrlen);
342 reg_blk.xetla_select<final_load_elems, 1>(final_start)
345 final_load_elems / scale_factor, L1, L2, trans,
346 mem_transform, arch_tag>(tdesc);
351 if constexpr (is_vnni_reverse) {
374 using load_dtype =
typename payload_t::mem_dtype;
376 static constexpr uint32_t tile_size_x = tile_t::tile_size_x;
377 static constexpr uint32_t scale_factor = payload_t::scale_factor;
378 constexpr uint32_t load_len = tile_size_x / scale_factor;
380 if constexpr (load_len >= 64) {
382 for (uint32_t i = 0; i < load_len / 64; i++) {
383 uint32_t offset_x = i * 64 * scale_factor;
385 =
tile.reg.xetla_select<64 * scale_factor, 1>(offset_x);
386 uint32_t address_offset = offset_x *
sizeof(dtype);
389 payload.base_ptr, payload.base_offset + address_offset);
392 constexpr uint32_t tail_len = load_len % 64;
393 uint32_t tail_offset = load_len / 64 * 64 * scale_factor;
394 detail::process_1d_tail<tail_len, 32, detail::process_flag::load, L1, L2>(
395 tile, payload, tail_offset);
412 typename oob_check_tag = global_atomic_oob_check_on_tag>
416 [[maybe_unused]] oob_check_tag tag = {}) {
417 constexpr bool oob_check = std::is_same<oob_check_tag,
418 global_atomic_oob_check_on_tag>::value;
419 using dtype =
typename payload_t::dtype;
420 using tile_desc =
typename payload_t::tile_desc;
421 using load_dtype =
typename payload_t::mem_dtype;
422 constexpr uint32_t num_channel_y = payload_t::num_channel_y;
423 constexpr uint32_t load_elems = num_channel_y * payload_t::num_channel_x;
424 constexpr uint32_t scale_factor = payload_t::scale_factor;
427 for (uint32_t i = 0; i < tile_desc::tile_size_y / tile_desc::block_size_y;
429 uint32_t offset_y = i * tile_desc::block_size_y;
431 for (uint32_t j = 0; j < tile_desc::num_block_x; j++) {
432 uint32_t offset_x = j * tile_desc::block_size_x;
433 auto reg_sub =
tile.reg.xetla_select<tile_desc::block_elems, 1>(
434 (i * tile_desc::num_block_x + j) * tile_desc::block_elems);
436 ? payload.step_x + payload.base_x + offset_x
437 < payload.width_in_elems
440 for (uint32_t sub_block_y = 0;
441 sub_block_y < tile_desc::block_size_y;
442 sub_block_y += num_channel_y) {
445 ? payload.step_y + payload.base_y + offset_y
447 < payload.height_in_elems
450 uint32_t address_offset = payload_t::trans
451 ? offset_x * payload.pitch_in_bytes
452 + (offset_y + sub_block_y) *
sizeof(dtype)
453 : offset_x *
sizeof(dtype)
454 + (offset_y + sub_block_y)
455 * payload.pitch_in_bytes;
460 payload.channel_offset + payload.base_offset
463 reg_tmp.xetla_merge(reg_tmp, 0, pred_x && pred_y);
465 reg_sub.xetla_select<load_elems * scale_factor, 1>(
466 sub_block_y * tile_desc::block_size_x)
467 .xetla_format<load_dtype>()
473 if constexpr ((tile_desc::tile_size_y % tile_desc::block_size_y) != 0) {
474 constexpr uint32_t remained_size_y = tile_desc::remained_size_y;
475 constexpr uint32_t offset_y = tile_desc::tile_size_y - remained_size_y;
476 constexpr uint32_t processed_elems = offset_y * tile_desc::tile_size_x;
477 constexpr uint32_t remain_block_elems
478 = remained_size_y * tile_desc::block_size_x;
480 for (uint32_t j = 0; j < tile_desc::num_block_x; j++) {
481 uint32_t offset_x = j * tile_desc::block_size_x;
482 auto reg_sub =
tile.reg.xetla_select<remain_block_elems, 1>(
483 processed_elems + j * remain_block_elems);
484 xetla_mask<load_elems> pred_x = oob_check
485 ? payload.step_x + payload.base_x + offset_x
486 < payload.width_in_elems
489 for (uint32_t sub_block_y = 0; sub_block_y < remained_size_y;
490 sub_block_y += num_channel_y) {
491 xetla_vector<load_dtype, load_elems> reg_tmp;
492 xetla_mask<load_elems> pred_y = oob_check
493 ? payload.step_y + payload.base_y + offset_y
495 < payload.height_in_elems
498 uint32_t address_offset = payload_t::trans
499 ? offset_x * payload.pitch_in_bytes
500 + (offset_y + sub_block_y) *
sizeof(dtype)
501 : offset_x *
sizeof(dtype)
502 + (offset_y + sub_block_y)
503 * payload.pitch_in_bytes;
508 payload.channel_offset + payload.base_offset
512 reg_tmp.xetla_merge(reg_tmp, 0, pred_x && pred_y);
515 sub_block_y * tile_desc::block_size_x)
516 .xetla_format<load_dtype>()
522 if constexpr (payload_t::mem_transform) {
545 using dtype =
typename payload_t::dtype;
546 using tile_desc =
typename payload_t::tile_desc;
547 using load_dtype =
typename payload_t::mem_dtype;
549 constexpr uint32_t num_channel_y = payload_t::num_channel_y;
550 constexpr uint32_t load_elems = num_channel_y * tile_desc::block_size_x;
551 static constexpr bool mem_transform = payload_t::mem_transform;
554 for (uint32_t i = 0; i < tile_desc::tile_size_y / tile_desc::block_size_y;
556 uint32_t offset_y = i * tile_desc::block_size_y;
558 for (uint32_t j = 0; j < tile_desc::num_block_x; j++) {
559 uint32_t offset_x = j * tile_desc::block_size_x;
560 auto reg_sub =
tile.reg.xetla_select<tile_desc::block_elems, 1>(
561 (i * tile_desc::num_block_x + j) * tile_desc::block_elems);
563 for (uint32_t sub_block_y = 0;
564 sub_block_y < tile_desc::block_size_y;
565 sub_block_y += num_channel_y) {
566 uint32_t address_offset = offset_x *
sizeof(dtype)
567 + (sub_block_y + offset_y) * payload.pitch_in_bytes;
568 reg_sub.xetla_select<load_elems, 1>(
569 sub_block_y * tile_desc::block_size_x)
570 .xetla_format<load_dtype>()
571 = xetla_load_local<load_dtype>(
572 payload.address + address_offset);
577 if constexpr ((tile_desc::tile_size_y % tile_desc::block_size_y) != 0) {
578 constexpr uint32_t remained_size_y = tile_desc::remained_size_y;
579 constexpr uint32_t offset_y = tile_desc::tile_size_y - remained_size_y;
580 constexpr uint32_t processed_elems = offset_y * tile_desc::tile_size_x;
581 constexpr uint32_t remain_block_elems
582 = remained_size_y * tile_desc::block_size_x;
584 for (uint32_t j = 0; j < tile_desc::num_block_x; j++) {
585 uint32_t offset_x = j * tile_desc::block_size_x;
586 auto reg_sub =
tile.reg.xetla_select<remain_block_elems, 1>(
587 processed_elems + j * remain_block_elems);
589 for (uint32_t sub_block_y = 0; sub_block_y < remained_size_y;
590 sub_block_y += num_channel_y) {
591 uint32_t address_offset = offset_x *
sizeof(dtype)
592 + (sub_block_y + offset_y) * payload.pitch_in_bytes;
593 reg_sub.xetla_select<load_elems, 1>(
594 sub_block_y * tile_desc::block_size_x)
595 .xetla_format<load_dtype>()
596 = xetla_load_local<load_dtype>(
597 payload.address + address_offset);
601 if constexpr (mem_transform) {
626 using load_dtype =
typename payload_t::mem_dtype;
628 constexpr uint32_t scale_factor = payload_t::scale_factor;
629 constexpr uint32_t load_len = tile_desc::tile_size_x / scale_factor;
630 if constexpr (load_len >= 64) {
632 for (uint32_t j = 0; j < load_len / 64; j++) {
633 uint32_t offset_x = j * 64 * scale_factor;
635 =
tile.reg.xetla_select<64 * scale_factor, 1>(offset_x);
636 uint32_t address_offset = offset_x *
sizeof(dtype);
637 reg_sub.xetla_format<load_dtype>()
638 = xetla_load_local<load_dtype, 64, data_size::default_size>(
639 payload.address + address_offset);
643 L2>(
tile, payload, load_len / 64 * 64 * scale_factor);
#define SW_BARRIER()
SW_BARRIER inserts a software scheduling barrier for better control over the generated code.
Definition common.hpp:227
#define __XETLA_API
Definition common.hpp:43
#define xetla_select
xetla select.
Definition base_ops.hpp:49
#define xetla_format
xetla format.
Definition base_ops.hpp:38
xetla_vector< uint32_t, 16 > xetla_tdescriptor
Description of an n-dimensional tensor descriptor used for load and store.
Definition base_types.hpp:155
typename native_type< T >::type native_type_t
Return the native data type of T.
Definition base_types.hpp:106
__ESIMD_NS::simd< native_type_t< Ty >, N > xetla_vector
Wrapper around __ESIMD_NS::simd.
Definition base_types.hpp:149
__ESIMD_NS::simd_mask< N > xetla_mask
Wrapper around __ESIMD_NS::simd_mask.
Definition base_types.hpp:165
__XETLA_API xetla_vector< Ty, N *NElts > xetla_load_global(Ty *p, xetla_vector< Toffset, N > offsets, xetla_mask< N > pred=1)
Stateless scattered load.
Definition memory.hpp:245
__XETLA_API void xetla_update_tdesc_offsetx(xetla_tdescriptor_ref tdesc, int32_t doffset_x)
Update the x coordinate in the given tensor descriptor.
Definition raw_send_load_store.hpp:152
__XETLA_API void xetla_update_tdesc_offsety(xetla_tdescriptor_ref tdesc, int32_t doffset_y)
Update the y coordinate in the given tensor descriptor.
Definition raw_send_load_store.hpp:161
__XETLA_API std::enable_if_t< arch_tag==gpu_arch::Xe, xetla_vector< Ty, N > > xetla_tload_global(xetla_tdescriptor tdesc)
Tensor load API.
Definition raw_send_load_store.hpp:183
__XETLA_API void xetla_set_block_widthx_widthy_arrlen(xetla_tdescriptor_ref desc, uint32_t block_widthx_widthy_arrlen)
Definition tensor_descriptor.hpp:79
__XETLA_API uint32_t uint32_t uint32_t scale_factor
Definition common.hpp:195
__XETLA_API std::enable_if_t< base_len==0 > process_1d_tail(tile_t &tile, payload_t &payload, uint32_t offset)
Definition common.hpp:96
Definition limitation.hpp:457
__XETLA_API std::enable_if_t< T::register_layout==reg_layout::vnni_tiled > vnni_convert(T &mat_Acc)
Converts tiled layout to vnni_tiled layout format.
Definition op_function.hpp:118
__XETLA_API std::enable_if_t< T::register_layout==reg_layout::tiled > vnni_reverse(T &mat_Acc)
Converts vnni_tiled layout format to tiled layout.
Definition op_function.hpp:196
__XETLA_API std::enable_if_t< detail::check_load_type< tile_t, payload_t >::is_global_2d_xe > tile_load(tile_t &tile, payload_t &payload)
This function loads data from a 2D memory surface.
Definition load_xe.hpp:76
cache_hint
L1 or L2 cache hint kinds.
Definition common.hpp:89
reg_layout
Tile layout in registers. linear: linear layout with one tile; tiled: 2D blocks stacked in raster order; v...
Definition common.hpp:209
@ tile
flush out to the local scope
gpu_arch
Definition common.hpp:73
Definition arch_config.hpp:72
Definition load_xe.hpp:30
static constexpr bool is_global_block_1d_xe
Definition load_xe.hpp:37
static constexpr bool is_local_block_1d_xe
Definition load_xe.hpp:53
static constexpr bool is_global_unaligned_2d_xe
Definition load_xe.hpp:43
static constexpr bool is_global_2d_xe
Definition load_xe.hpp:32
static constexpr bool is_local_scatter_xe
Definition load_xe.hpp:48
A struct that contains a register file.
Definition api.hpp:99
tile_desc_ tile_desc
Definition api.hpp:101
dtype_ dtype
Definition api.hpp:100