XeTLA v0.3.6
Intel® Xe Templates for Linear Algebra - API Definition Document
 
load_xe.hpp
/*******************************************************************************
* Copyright (c) 2022-2023 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/

#pragma once

#include "subgroup/tile/api.hpp"

namespace gpu::xetla::subgroup {

namespace detail {
template <typename tile_t, typename payload_t>
struct check_load_type {
    static constexpr bool is_global_2d_xe
            = (payload_t::memory_space == mem_space::global
                    && (payload_t::message_type == msg_type::block_2d)
                    && (payload_t::arch_tag == gpu_arch::Xe));

    static constexpr bool is_global_block_1d_xe
            = ((payload_t::memory_space == mem_space::global)
                    && (tile_t::tile_size_y == 1) && (tile_t::block_size_y == 1)
                    && (payload_t::message_type == msg_type::block_1d)
                    && (payload_t::arch_tag == gpu_arch::Xe));

    static constexpr bool is_global_unaligned_2d_xe
            = ((payload_t::memory_space == mem_space::global)
                    && (payload_t::message_type == msg_type::unaligned_2d)
                    && (payload_t::arch_tag == gpu_arch::Xe));

    static constexpr bool is_local_scatter_xe
            = ((payload_t::memory_space == mem_space::local)
                    && (payload_t::message_type == msg_type::scatter)
                    && (payload_t::arch_tag == gpu_arch::Xe));

    static constexpr bool is_local_block_1d_xe
            = ((payload_t::memory_space == mem_space::local)
                    && (payload_t::message_type == msg_type::block_1d)
                    && (payload_t::arch_tag == gpu_arch::Xe));
};

} // namespace detail
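
// Each tile_load overload below is enabled via std::enable_if_t on exactly
// one of the predicates above, so a payload's memory space, message type and
// architecture tag statically pick the matching load path. A minimal sketch
// of the dispatch, with hypothetical tile/payload types:
//
//     // my_tile_t / my_payload_t are placeholders; any pair whose traits
//     // satisfy is_global_2d_xe resolves to the 2D block-load overload.
//     static_assert(detail::check_load_type<my_tile_t,
//             my_payload_t>::is_global_2d_xe);
//     tile_load(my_tile, my_payload); // compiles against the 2D block path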

/// @brief This function loads data from 2D memory surface.
template <cache_hint L1 = cache_hint::cached,
        cache_hint L2 = cache_hint::cached, typename tile_t, typename payload_t>
__XETLA_API typename std::enable_if_t<
        detail::check_load_type<tile_t, payload_t>::is_global_2d_xe>
tile_load(tile_t &tile, payload_t &payload) {
    using dtype = typename tile_t::dtype;
    using load_dtype = typename payload_t::mem_dtype;
    using tile_desc = typename tile_t::tile_desc;

    static constexpr uint32_t tile_size_x = tile_desc::tile_size_x;
    static constexpr uint32_t tile_size_y = tile_desc::tile_size_y;
    static constexpr uint32_t block_size_x = tile_desc::block_size_x;
    static constexpr uint32_t block_size_y = tile_desc::block_size_y;
    static constexpr uint32_t remained_size_y = tile_desc::remained_size_y;

    static constexpr uint32_t block_elems = tile_desc::block_elems;

    static constexpr uint32_t num_block_x = tile_desc::num_block_x;
    static constexpr uint32_t num_block_y = tile_desc::num_block_y;
    static constexpr uint32_t num_block = tile_desc::num_block;

    static constexpr gpu_arch arch_tag = payload_t::arch_tag;

    static constexpr reg_layout reg_layout_ = tile_desc::register_layout;
    static constexpr bool is_vnni_reverse = payload_t::mem_dword_transpose
            && ((reg_layout_ == reg_layout::tiled)
                    || (reg_layout_ == reg_layout::transpose_tiled));
    static constexpr bool reg_transpose = tile_desc::reg_transpose;

    static constexpr bool mem_transpose = payload_t::mem_transpose;
    static constexpr bool trans = reg_transpose ^ mem_transpose;
    static constexpr uint32_t scale_factor = payload_t::scale_factor;

    static constexpr bool mem_transform = payload_t::mem_transform;

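    // `trans` requests an in-flight transpose when exactly one of the register
    // and memory layouts is transposed, and `mem_transform` requests the VNNI
    // transform; both are forwarded as template arguments to
    // xetla_tload_global below. `scale_factor` appears to be the element-count
    // ratio between load_dtype and dtype (an inference from how it divides
    // widths and element counts throughout this function).
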
    using load_store_attr = typename arch_attr_t<
            arch_tag>::template load_store_attr<msg_type::block_2d>;
    static constexpr uint32_t elems_per_CL
            = load_store_attr::cache_line_size_in_bytes / sizeof(dtype);
    static constexpr uint32_t elems_per_reg
            = arch_attr_t<arch_tag>::template register_attr<>::reg_in_bytes
            / sizeof(dtype);
    static constexpr int32_t max_load_block_height
            = load_store_attr::max_load_height_in_elem;
    static constexpr int32_t max_block_width
            = load_store_attr::max_load_width_in_bytes / sizeof(dtype);
    static constexpr int32_t max_trans_block_width
            = load_store_attr::max_trans_load_width_in_bytes / sizeof(dtype);

    static constexpr uint32_t ld_blk_size_y_limit
            = mem_transpose ? max_trans_block_width : max_load_block_height;
    static constexpr uint32_t ld_blk_size_y = reg_transpose
            ? block_size_y
            : (block_size_y > ld_blk_size_y_limit ? ld_blk_size_y_limit
                                                  : block_size_y);
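    // Illustrative example (numbers not taken from arch_config): if
    // max_load_height_in_elem were 32 and block_size_y were 64, ld_blk_size_y
    // would clamp to 32 and each block would be fetched as two stacked 2D
    // block loads by the inner loop below.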

    // arr_len keeps memory loads cache-line aligned;
    // it is disabled when register or memory transpose is in use
    static constexpr uint8_t arr_len_candidate
            = (reg_transpose
                      || mem_transpose
                      // block elements should be an integer
                      // multiple of register bytes
                      || ((block_size_y * block_size_x) % elems_per_reg != 0)
                      // tail blocks also need to meet the above condition
                      || (((tile_size_y % block_size_y) * block_size_x)
                                 % elems_per_reg
                              != 0))
                    || (block_size_y > ld_blk_size_y_limit)
            ? 1
            : (((tile_size_x % elems_per_CL) == 0)
                            ? (((elems_per_CL % block_size_x) == 0)
                                            ? elems_per_CL / block_size_x
                                            : 1)
                            : ((tile_size_x < elems_per_CL)
                                            ? (tile_size_x / block_size_x)
                                            : 1));
    static constexpr bool is_valid_arr_len_candidate = (arr_len_candidate == 1)
            || (arr_len_candidate == 2) || (arr_len_candidate == 4);

    static constexpr uint8_t arr_len
            = is_valid_arr_len_candidate ? arr_len_candidate : 1;
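    // Worked example (illustrative values, not from arch_config): with a
    // 64-byte cache line and a 2-byte dtype, elems_per_CL = 32; for
    // tile_size_x = 64 and block_size_x = 16 (no transpose), the candidate is
    // elems_per_CL / block_size_x = 2, so pairs of adjacent x-blocks are
    // fetched by a single cache-line-aligned block-array load.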

    static_assert(reg_transpose || mem_transpose
                    || (!mem_transpose
                            && (block_size_x * arr_len) <= max_block_width),
            "When reg_transpose was disabled, check 2d block width "
            "restriction");
    static_assert(!reg_transpose
                    || (!mem_transpose
                            && (block_size_x * arr_len)
                                    <= max_trans_block_width)
                    || (mem_transpose
                            && (block_size_y * arr_len) <= max_block_width),
            "When reg_transpose was enabled, check 2d block width "
            "restriction");
    static_assert(!reg_transpose
                    || (!mem_transpose
                            && (block_size_y <= max_load_block_height))
                    || (mem_transpose
                            && (block_size_x) <= max_load_block_height),
            "When reg_transpose was enabled, check 2d block height "
            "restriction");
    static_assert(tile_size_x % (block_size_x * arr_len) == 0,
            "tile_size_x should be a multiple of (block_size_x * arr_len)");
    static_assert(
            (reg_transpose
                    && ((block_size_x * sizeof(dtype)) % sizeof(load_dtype)
                            == 0))
                    || ((block_size_y * sizeof(dtype)) % sizeof(load_dtype)
                            == 0),
            "check vnni limitation for DW transpose");

    auto payload_2d = payload.payloads.xetla_format<uint32_t, num_block, 16>();
#pragma unroll
    for (uint32_t i = 0; i < num_block_y; ++i) {
        constexpr uint32_t load_block_elems = block_elems * arr_len;
        auto payload_row = payload_2d.xetla_select<num_block_x, 1, 16, 1>(
                i * num_block_x, 0);
        detail::reset_tile_desc_core<num_block_x, block_size_x, ld_blk_size_y,
                scale_factor, arr_len, mem_transpose>(payload_row);
#pragma unroll
        for (uint32_t j = 0; j < num_block_x; j += arr_len) {
            xetla_tdescriptor tdesc = payload_row.row(j);
            auto reg_blk = tile.reg.xetla_select<load_block_elems, 1>(
                    (i * num_block_x + j) * block_elems);
            constexpr uint32_t ld_blk_height = (reg_transpose && trans)
                    ? detail::getNextPowerOf2<ld_blk_size_y>()
                    : ld_blk_size_y;
            constexpr uint32_t tmp_size
                    = ld_blk_height * block_size_x * arr_len;
            xetla_vector<dtype, tmp_size> reg_tmp;
#pragma unroll
            for (uint32_t ii = 0; ii < block_size_y / ld_blk_size_y; ++ii) {
                constexpr uint32_t load_elems
                        = ld_blk_size_y * block_size_x * arr_len;

                reg_tmp.xetla_format<native_type_t<load_dtype>>()
                        = xetla_tload_global<load_dtype,
                                ld_blk_height * block_size_x * arr_len
                                        / scale_factor,
                                L1, L2, trans, mem_transform, arch_tag>(tdesc);

                if constexpr (reg_transpose && trans) {
                    reg_blk.xetla_select<load_elems, 1>(ii * load_elems)
                            .xetla_format<load_dtype>()
                            = reg_tmp.xetla_format<load_dtype,
                                             block_size_x / scale_factor,
                                             ld_blk_height>()
                                      .xetla_select<block_size_x / scale_factor,
                                              1, ld_blk_size_y, 1>(0, 0);
                } else {
                    reg_blk.xetla_select<tmp_size, 1>(ii * tmp_size) = reg_tmp;
                }

                if constexpr (mem_transpose) {
                    xetla_update_tdesc_offsetx(tdesc.xetla_format<uint32_t>(),
                            ld_blk_size_y / scale_factor);
                } else {
                    xetla_update_tdesc_offsety(
                            tdesc.xetla_format<uint32_t>(), ld_blk_size_y);
                }
            }
            // load the remaining rows that exceed the HW block-height limit
            if constexpr (block_size_y % ld_blk_size_y != 0) {
                constexpr uint32_t remained_start_y
                        = block_size_y / ld_blk_size_y * ld_blk_size_y;
                constexpr uint32_t remained_start
                        = remained_start_y * block_size_x * arr_len;
                constexpr uint32_t remained_blk_size_y
                        = block_size_y % ld_blk_size_y;
                constexpr uint32_t load_elems
                        = remained_blk_size_y * block_size_x * arr_len;

                constexpr uint8_t block_width = mem_transpose
                        ? (remained_blk_size_y / scale_factor)
                        : block_size_x;
                constexpr uint8_t block_height
                        = trans ? block_size_x : remained_blk_size_y;
                constexpr uint32_t block_widthx_widthy_arrlen
                        = (block_width - 1) | ((block_height - 1) << 8);
                xetla_set_block_widthx_widthy_arrlen(
                        tdesc.xetla_format<uint32_t>(),
                        block_widthx_widthy_arrlen);

                reg_blk.xetla_select<load_elems, 1>(remained_start)
                        .xetla_format<native_type_t<load_dtype>>()
                        = xetla_tload_global<load_dtype,
                                (load_elems / scale_factor), L1, L2, trans,
                                mem_transform, arch_tag>(tdesc);
            }
        }
    }
    // process tail
    if constexpr (remained_size_y > 0) {
        constexpr uint32_t remained_block_elems
                = block_size_x * remained_size_y;
        constexpr uint32_t processed_elems
                = num_block_y * num_block_x * block_elems;
        constexpr uint32_t remained_ld_blk_size_y
                = (!reg_transpose && (remained_size_y > ld_blk_size_y_limit))
                ? ld_blk_size_y_limit
                : remained_size_y;
        auto payload_row = payload_2d.xetla_select<num_block_x, 1, 16, 1>(
                num_block_y * num_block_x, 0);
        detail::reset_tile_desc_core<num_block_x, block_size_x,
                remained_ld_blk_size_y, scale_factor, arr_len, mem_transpose>(
                payload_row);
#pragma unroll
        for (uint32_t j = 0; j < num_block_x; j += arr_len) {
            xetla_tdescriptor tdesc = payload_row.row(j);
            auto reg_blk
                    = tile.reg.xetla_select<remained_block_elems * arr_len, 1>(
                            processed_elems + j * remained_block_elems);
            constexpr uint32_t ld_blk_height = (reg_transpose && trans)
                    ? detail::getNextPowerOf2<remained_ld_blk_size_y>()
                    : remained_ld_blk_size_y;
            constexpr uint32_t tmp_size
                    = ld_blk_height * block_size_x * arr_len;
            xetla_vector<dtype, tmp_size> reg_tmp;
#pragma unroll
            for (uint32_t ii = 0; ii < remained_size_y / remained_ld_blk_size_y;
                    ++ii) {
                constexpr uint32_t load_elems
                        = remained_ld_blk_size_y * block_size_x * arr_len;

                reg_tmp.xetla_format<native_type_t<load_dtype>>()
                        = xetla_tload_global<load_dtype,
                                (ld_blk_height * block_size_x * arr_len
                                        / scale_factor),
                                L1, L2, trans, mem_transform, arch_tag>(tdesc);

                if constexpr (reg_transpose && trans) {
                    reg_blk.xetla_select<load_elems, 1>(ii * load_elems)
                            .xetla_format<load_dtype>()
                            = reg_tmp.xetla_format<load_dtype,
                                             block_size_x / scale_factor,
                                             ld_blk_height>()
                                      .xetla_select<block_size_x / scale_factor,
                                              1, remained_ld_blk_size_y, 1>(
                                              0, 0);
                } else {
                    reg_blk.xetla_select<tmp_size, 1>(ii * tmp_size) = reg_tmp;
                }
                if constexpr (mem_transpose) {
                    xetla_update_tdesc_offsetx(tdesc.xetla_format<uint32_t>(),
                            remained_ld_blk_size_y / scale_factor);
                } else {
                    xetla_update_tdesc_offsety(tdesc.xetla_format<uint32_t>(),
                            remained_ld_blk_size_y);
                }
            }
            constexpr uint32_t final_ld_blk_size_y
                    = remained_size_y % remained_ld_blk_size_y;
            if constexpr (final_ld_blk_size_y != 0) {
                constexpr uint32_t final_start = remained_size_y
                        / remained_ld_blk_size_y * remained_ld_blk_size_y
                        * block_size_x * arr_len;
                constexpr uint32_t final_load_elems
                        = final_ld_blk_size_y * block_size_x * arr_len;
                constexpr uint8_t block_width = mem_transpose
                        ? (final_ld_blk_size_y / scale_factor)
                        : block_size_x;
                constexpr uint8_t block_height
                        = trans ? block_size_x : final_ld_blk_size_y;
                constexpr uint32_t block_widthx_widthy_arrlen
                        = (block_width - 1) | ((block_height - 1) << 8);
                xetla_set_block_widthx_widthy_arrlen(
                        tdesc.xetla_format<uint32_t>(),
                        block_widthx_widthy_arrlen);
                reg_blk.xetla_select<final_load_elems, 1>(final_start)
                        .xetla_format<native_type_t<load_dtype>>()
                        = xetla_tload_global<load_dtype,
                                final_load_elems / scale_factor, L1, L2, trans,
                                mem_transform, arch_tag>(tdesc);
            }
        }
    }

    if constexpr (is_vnni_reverse) {
        SW_BARRIER();
        vnni_reverse(tile);
    }
}

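/// @brief This function loads a one-row (1D) block of data from global memory.
/// Data is fetched in chunks of 64 load_dtype elements with xetla_load_global;
/// the remainder is handled by detail::process_1d_tail.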
template <cache_hint L1 = cache_hint::cached,
        cache_hint L2 = cache_hint::cached, typename tile_t, typename payload_t>
__XETLA_API typename std::enable_if_t<
        detail::check_load_type<tile_t, payload_t>::is_global_block_1d_xe>
tile_load(tile_t &tile, payload_t &payload) {
    using dtype = typename tile_t::dtype;
    using load_dtype = typename payload_t::mem_dtype;

    static constexpr uint32_t tile_size_x = tile_t::tile_size_x;
    static constexpr uint32_t scale_factor = payload_t::scale_factor;
    constexpr uint32_t load_len = tile_size_x / scale_factor;

    if constexpr (load_len >= 64) {
#pragma unroll
        for (uint32_t i = 0; i < load_len / 64; i++) {
            uint32_t offset_x = i * 64 * scale_factor;
            auto reg_sub
                    = tile.reg.xetla_select<64 * scale_factor, 1>(offset_x);
            uint32_t address_offset = offset_x * sizeof(dtype);
            reg_sub.xetla_format<load_dtype>() = xetla_load_global<load_dtype,
                    64, data_size::default_size, L1, L2>(
                    payload.base_ptr, payload.base_offset + address_offset);
        }
    }
    constexpr uint32_t tail_len = load_len % 64;
    uint32_t tail_offset = load_len / 64 * 64 * scale_factor;
    detail::process_1d_tail<tail_len, 32, detail::process_flag::load, L1, L2>(
            tile, payload, tail_offset);
}

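/// @brief This function loads data from an unaligned 2D memory surface.
/// Each load is predicated per channel against the surface width and height
/// (unless an oob-check-off tag is passed), and out-of-bounds lanes are
/// zero-filled in the destination registers.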
template <cache_hint L1 = cache_hint::cached,
        cache_hint L3 = cache_hint::cached, typename tile_t, typename payload_t,
        typename oob_check_tag = global_atomic_oob_check_on_tag>
__XETLA_API typename std::enable_if_t<
        detail::check_load_type<tile_t, payload_t>::is_global_unaligned_2d_xe>
tile_load(tile_t &tile, payload_t &payload,
        [[maybe_unused]] oob_check_tag tag = {}) {
    constexpr bool oob_check = std::is_same<oob_check_tag,
            global_atomic_oob_check_on_tag>::value;
    using dtype = typename payload_t::dtype;
    using tile_desc = typename payload_t::tile_desc;
    using load_dtype = typename payload_t::mem_dtype;
    constexpr uint32_t num_channel_y = payload_t::num_channel_y;
    constexpr uint32_t load_elems = num_channel_y * payload_t::num_channel_x;
    constexpr uint32_t scale_factor = payload_t::scale_factor;

#pragma unroll
    for (uint32_t i = 0; i < tile_desc::tile_size_y / tile_desc::block_size_y;
            i++) {
        uint32_t offset_y = i * tile_desc::block_size_y;
#pragma unroll
        for (uint32_t j = 0; j < tile_desc::num_block_x; j++) {
            uint32_t offset_x = j * tile_desc::block_size_x;
            auto reg_sub = tile.reg.xetla_select<tile_desc::block_elems, 1>(
                    (i * tile_desc::num_block_x + j) * tile_desc::block_elems);
            xetla_mask<load_elems> pred_x = oob_check
                    ? payload.step_x + payload.base_x + offset_x
                            < payload.width_in_elems
                    : 1;
#pragma unroll
            for (uint32_t sub_block_y = 0;
                    sub_block_y < tile_desc::block_size_y;
                    sub_block_y += num_channel_y) {
                xetla_vector<load_dtype, load_elems> reg_tmp;
                xetla_mask<load_elems> pred_y = oob_check
                        ? payload.step_y + payload.base_y + offset_y
                                        + sub_block_y
                                < payload.height_in_elems
                        : 1;

                uint32_t address_offset = payload_t::trans
                        ? offset_x * payload.pitch_in_bytes
                                + (offset_y + sub_block_y) * sizeof(dtype)
                        : offset_x * sizeof(dtype)
                                + (offset_y + sub_block_y)
                                        * payload.pitch_in_bytes;

                reg_tmp = xetla_load_global<load_dtype, 1,
                        data_size::default_size, L1, L3, load_elems>(
                        payload.base_ptr,
                        payload.channel_offset + payload.base_offset
                                + address_offset,
                        pred_x && pred_y);
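                // Lanes whose predicate is off return unspecified data from
                // xetla_load_global; the merge below forces those lanes to
                // zero so out-of-bounds elements land in registers as 0.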
                reg_tmp.xetla_merge(reg_tmp, 0, pred_x && pred_y);

                reg_sub.xetla_select<load_elems * scale_factor, 1>(
                               sub_block_y * tile_desc::block_size_x)
                        .xetla_format<load_dtype>()
                        = reg_tmp;
            }
        }
    }
    // process the tail
    if constexpr ((tile_desc::tile_size_y % tile_desc::block_size_y) != 0) {
        constexpr uint32_t remained_size_y = tile_desc::remained_size_y;
        constexpr uint32_t offset_y = tile_desc::tile_size_y - remained_size_y;
        constexpr uint32_t processed_elems = offset_y * tile_desc::tile_size_x;
        constexpr uint32_t remain_block_elems
                = remained_size_y * tile_desc::block_size_x;
#pragma unroll
        for (uint32_t j = 0; j < tile_desc::num_block_x; j++) {
            uint32_t offset_x = j * tile_desc::block_size_x;
            auto reg_sub = tile.reg.xetla_select<remain_block_elems, 1>(
                    processed_elems + j * remain_block_elems);
            xetla_mask<load_elems> pred_x = oob_check
                    ? payload.step_x + payload.base_x + offset_x
                            < payload.width_in_elems
                    : 1;
#pragma unroll
            for (uint32_t sub_block_y = 0; sub_block_y < remained_size_y;
                    sub_block_y += num_channel_y) {
                xetla_vector<load_dtype, load_elems> reg_tmp;
                xetla_mask<load_elems> pred_y = oob_check
                        ? payload.step_y + payload.base_y + offset_y
                                        + sub_block_y
                                < payload.height_in_elems
                        : 1;

                uint32_t address_offset = payload_t::trans
                        ? offset_x * payload.pitch_in_bytes
                                + (offset_y + sub_block_y) * sizeof(dtype)
                        : offset_x * sizeof(dtype)
                                + (offset_y + sub_block_y)
                                        * payload.pitch_in_bytes;

                reg_tmp = xetla_load_global<load_dtype, 1,
                        data_size::default_size, L1, L3, load_elems>(
                        payload.base_ptr,
                        payload.channel_offset + payload.base_offset
                                + address_offset,
                        pred_x && pred_y);

                reg_tmp.xetla_merge(reg_tmp, 0, pred_x && pred_y);

                reg_sub.xetla_select<load_elems * scale_factor, 1>(
                               sub_block_y * tile_desc::block_size_x)
                        .xetla_format<load_dtype>()
                        = reg_tmp;
            }
        }
    }

    if constexpr (payload_t::mem_transform) {
        SW_BARRIER();
        vnni_convert(tile);
    }
}

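/// @brief This function loads data from shared local memory (SLM).
/// Blocks are read with xetla_load_local, num_channel_y rows per message;
/// when the payload requests mem_transform, the tile is converted to VNNI
/// layout afterwards.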
template <cache_hint L1 = cache_hint::cached,
        cache_hint L2 = cache_hint::cached, typename tile_t, typename payload_t>
__XETLA_API typename std::enable_if_t<
        detail::check_load_type<tile_t, payload_t>::is_local_scatter_xe>
tile_load(tile_t &tile, payload_t &payload) {
    using dtype = typename payload_t::dtype;
    using tile_desc = typename payload_t::tile_desc;
    using load_dtype = typename payload_t::mem_dtype;

    constexpr uint32_t num_channel_y = payload_t::num_channel_y;
    constexpr uint32_t load_elems = num_channel_y * tile_desc::block_size_x;
    static constexpr bool mem_transform = payload_t::mem_transform;

#pragma unroll
    for (uint32_t i = 0; i < tile_desc::tile_size_y / tile_desc::block_size_y;
            i++) {
        uint32_t offset_y = i * tile_desc::block_size_y;
#pragma unroll
        for (uint32_t j = 0; j < tile_desc::num_block_x; j++) {
            uint32_t offset_x = j * tile_desc::block_size_x;
            auto reg_sub = tile.reg.xetla_select<tile_desc::block_elems, 1>(
                    (i * tile_desc::num_block_x + j) * tile_desc::block_elems);
#pragma unroll
            for (uint32_t sub_block_y = 0;
                    sub_block_y < tile_desc::block_size_y;
                    sub_block_y += num_channel_y) {
                uint32_t address_offset = offset_x * sizeof(dtype)
                        + (sub_block_y + offset_y) * payload.pitch_in_bytes;
                reg_sub.xetla_select<load_elems, 1>(
                               sub_block_y * tile_desc::block_size_x)
                        .xetla_format<load_dtype>()
                        = xetla_load_local<load_dtype>(
                                payload.address + address_offset);
            }
        }
    }
    // process the tail
    if constexpr ((tile_desc::tile_size_y % tile_desc::block_size_y) != 0) {
        constexpr uint32_t remained_size_y = tile_desc::remained_size_y;
        constexpr uint32_t offset_y = tile_desc::tile_size_y - remained_size_y;
        constexpr uint32_t processed_elems = offset_y * tile_desc::tile_size_x;
        constexpr uint32_t remain_block_elems
                = remained_size_y * tile_desc::block_size_x;
#pragma unroll
        for (uint32_t j = 0; j < tile_desc::num_block_x; j++) {
            uint32_t offset_x = j * tile_desc::block_size_x;
            auto reg_sub = tile.reg.xetla_select<remain_block_elems, 1>(
                    processed_elems + j * remain_block_elems);
#pragma unroll
            for (uint32_t sub_block_y = 0; sub_block_y < remained_size_y;
                    sub_block_y += num_channel_y) {
                uint32_t address_offset = offset_x * sizeof(dtype)
                        + (sub_block_y + offset_y) * payload.pitch_in_bytes;
                reg_sub.xetla_select<load_elems, 1>(
                               sub_block_y * tile_desc::block_size_x)
                        .xetla_format<load_dtype>()
                        = xetla_load_local<load_dtype>(
                                payload.address + address_offset);
            }
        }
    }
    if constexpr (mem_transform) {
        SW_BARRIER();
        vnni_convert(tile);
    }
}

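/// @brief This function loads a one-row (1D) block of data from shared local
/// memory, reading full 64-element chunks with xetla_load_local and delegating
/// the remainder to detail::process_1d_tail.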
template <cache_hint L1 = cache_hint::cached,
        cache_hint L2 = cache_hint::cached, typename tile_t, typename payload_t>
__XETLA_API typename std::enable_if_t<
        detail::check_load_type<tile_t, payload_t>::is_local_block_1d_xe>
tile_load(tile_t &tile, payload_t &payload) {
    using dtype = typename tile_t::dtype;
    using tile_desc = typename tile_t::tile_desc;
    using load_dtype = typename payload_t::mem_dtype;

    constexpr uint32_t scale_factor = payload_t::scale_factor;
    constexpr uint32_t load_len = tile_desc::tile_size_x / scale_factor;
    if constexpr (load_len >= 64) {
#pragma unroll
        for (uint32_t j = 0; j < load_len / 64; j++) {
            uint32_t offset_x = j * 64 * scale_factor;
            auto reg_sub
                    = tile.reg.xetla_select<64 * scale_factor, 1>(offset_x);
            uint32_t address_offset = offset_x * sizeof(dtype);
            reg_sub.xetla_format<load_dtype>()
                    = xetla_load_local<load_dtype, 64, data_size::default_size>(
                            payload.address + address_offset);
        }
    }
    detail::process_1d_tail<load_len % 64, 32, detail::process_flag::load, L1,
            L2>(tile, payload, load_len / 64 * 64 * scale_factor);
}

} // namespace gpu::xetla::subgroup
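
// A minimal usage sketch, assuming hypothetical my_tile_t / my_payload_t
// types whose traits (global memory space, msg_type::block_2d, gpu_arch::Xe)
// select the 2D block overload; the payload construction below is
// illustrative rather than the exact payload interface:
//
//     using namespace gpu::xetla;
//     subgroup::my_tile_t mat_a;               // register tile (hypothetical)
//     subgroup::my_payload_t payload(/*...*/); // describes the 2D surface
//     subgroup::tile_load<cache_hint::cached, cache_hint::cached>(
//             mat_a, payload);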