22#include "common/common.hpp"
34template <u
int32_t N, u
int32_t K,
bool K_gt_eq_N>
41template <u
int32_t N, u
int32_t K>
43 static constexpr uint32_t
get() {
return K; }
50template <u
int32_t N, u
int32_t K>
52 static constexpr uint32_t
get() {
79template <u
int32_t a, u
int32_t b>
89 static constexpr uint32_t
value = a;
94template <uint32_t remained_len, uint32_t base_len,
process_flag flag,
97 [[maybe_unused]]
tile_t &
tile, [[maybe_unused]] payload_t &payload,
98 [[maybe_unused]] uint32_t offset) {}
100template <uint32_t remained_len, uint32_t base_len,
process_flag flag,
105 using dtype =
typename payload_t::dtype;
106 using mem_dtype =
typename payload_t::mem_dtype;
107 if constexpr (remained_len >= base_len) {
108 uint32_t address_offset = offset *
sizeof(dtype);
110 =
tile.reg.xetla_select<base_len * payload_t::scale_factor, 1>(
115 payload.base_ptr, payload.base_offset + address_offset);
118 L2>(payload.base_ptr, payload.base_offset + address_offset,
119 reg_sub.xetla_format<mem_dtype>());
121 process_1d_tail<remained_len - base_len, (base_len >> 1), flag, L1, L2>(
122 tile, payload, offset + base_len * payload_t::scale_factor);
125 tile, payload, offset);
129template <uint32_t remained_len, uint32_t base_len,
process_flag flag,
134 using mem_dtype =
typename payload_t::mem_dtype;
135 if constexpr (remained_len >= base_len) {
137 =
tile.reg.xetla_select<base_len * payload_t::scale_factor, 1>(
139 uint32_t address_offset = offset *
sizeof(
typename tile_t::dtype);
143 payload.address + address_offset);
145 xetla_store_local<mem_dtype, base_len>(
146 payload.address + address_offset,
147 reg_sub.xetla_format<mem_dtype>());
149 process_1d_tail<remained_len - base_len, (base_len >> 1), flag, L1, L2,
151 tile, payload, offset + base_len * payload_t::scale_factor);
153 process_1d_tail<remained_len, (base_len >> 1), flag, L1, L2, payload_t,
160template <uint32_t remained_len, uint32_t base_len,
cache_hint L1,
163 payload_t &payload, uint32_t offset) {
164 using dtype =
typename payload_t::dtype;
165 using prefetch_dtype =
typename payload_t::prefetch_dtype;
166 uint32_t address_offset = offset *
sizeof(dtype);
167 constexpr uint32_t prefetch_min_size = 64 /
sizeof(prefetch_dtype);
168 if constexpr (remained_len > 0) {
171 payload.base_ptr, payload.base_offset + address_offset);
175template <uint32_t remained_len, uint32_t base_len,
cache_hint L1,
178 payload_t &payload, uint32_t offset) {
179 using dtype =
typename payload_t::dtype;
180 using prefetch_dtype =
typename payload_t::prefetch_dtype;
181 if constexpr (remained_len >= base_len) {
182 uint32_t address_offset = offset *
sizeof(dtype);
184 L1, L2>(payload.base_ptr, payload.base_offset + address_offset);
187 payload, offset + base_len * payload_t::scale_factor);
199 for (uint32_t j = 0; j < num_tdesc; j++) {
200 constexpr uint8_t block_width
202 constexpr uint8_t block_height = trans ?
size_x :
size_y;
203 constexpr uint32_t block_widthx_widthy_arrlen = (block_width - 1)
204 | ((block_height - 1) << 8) | ((
arr_len - 1) << 16);
206 payload_row.row(j), block_widthx_widthy_arrlen);
212template <
typename T_dst,
typename T_src>
214 static constexpr bool value = (T_src::block_size_y == T_dst::block_size_y)
215 && (T_src::block_size_x == T_dst::block_size_x)
216 && (T_src::tile_size_y == T_dst::tile_size_y)
217 && (T_src::tile_size_x == T_dst::tile_size_x);
220template <
typename T_dst,
typename T_src>
227template <
typename tile_desc_,
mem_space memory_space,
231 ? (((tile_desc_::tile_size_y == 1)
235 : (((tile_desc_::tile_size_y == 1)
241template <
typename tile_desc_, mem_space memory_space>
244template <
typename dtype, uint32_t tile_size_x, uint32_t tile_size_y,
249template <
typename dtype, u
int32_t tile_size_x, u
int32_t tile_size_y>
255 static constexpr uint32_t max_load_height_in_elem
256 = load_store_attr::max_load_height_in_elem;
257 static constexpr uint32_t max_load_width_in_bytes
258 = load_store_attr::max_load_width_in_bytes;
259 static constexpr uint32_t max_load_width_in_elem
260 = max_load_width_in_bytes /
sizeof(dtype);
264 static constexpr uint32_t block_size_x
266 static constexpr uint32_t block_size_y
267 = max_load_height_in_elem > tile_size_y ? tile_size_y
268 : max_load_height_in_elem;
271template <
typename dtype, uint32_t tile_size_x, uint32_t tile_size_y,
276template <
typename dtype, u
int32_t tile_size_x, u
int32_t tile_size_y>
282 static constexpr uint32_t max_store_height_in_elem
283 = load_store_attr::max_store_height_in_elem;
284 static constexpr uint32_t max_store_width_in_bytes
285 = load_store_attr::max_store_width_in_bytes;
286 static constexpr uint32_t max_store_width_in_elem
287 = max_store_width_in_bytes /
sizeof(dtype);
291 static constexpr uint32_t block_size_x
293 static constexpr uint32_t block_size_y
294 = max_store_height_in_elem > tile_size_y ? tile_size_y
295 : max_store_height_in_elem;
#define __XETLA_API
Definition common.hpp:43
Workaround for ESIMD matrix(2D) ref type.
Definition base_types.hpp:202
#define __REF__
Workaround for ESIMD reference usage.
Definition base_types.hpp:177
__XETLA_API void xetla_prefetch_global(Ty *p, xetla_vector< uint32_t, N > offsets, xetla_mask< N > pred=1)
Stateless scattered prefetch.
Definition memory.hpp:187
__XETLA_API xetla_vector< Ty, N *NElts > xetla_load_global(Ty *p, xetla_vector< Toffset, N > offsets, xetla_mask< N > pred=1)
Stateless scattered load.
Definition memory.hpp:245
__XETLA_API xetla_vector< Ty, N *NElts > xetla_load_local(xetla_vector< uint32_t, N > offsets, xetla_mask< N > pred=1)
SLM scattered load.
Definition memory.hpp:464
__XETLA_API void xetla_store_global(Ty *p, xetla_vector< Toffset, N > offsets, xetla_vector< Ty, N *NElts > vals, xetla_mask< N > pred=1)
Stateless scattered store.
Definition memory.hpp:316
__XETLA_API void xetla_set_block_widthx_widthy_arrlen(xetla_tdescriptor_ref desc, uint32_t block_widthx_widthy_arrlen)
Definition tensor_descriptor.hpp:79
__XETLA_API uint32_t uint32_t size_y
Definition common.hpp:194
process_flag
Definition common.hpp:92
constexpr uint32_t getNextPowerOf2()
Get the Next Power Of2 object.
Definition common.hpp:62
__XETLA_API uint32_t uint32_t uint32_t uint8_t arr_len
Definition common.hpp:195
__XETLA_API uint32_t uint32_t uint32_t scale_factor
Definition common.hpp:195
__XETLA_API std::enable_if_t< base_len==0 > process_1d_tail(tile_t &tile, payload_t &payload, uint32_t offset)
Definition common.hpp:96
constexpr uint32_t getNextPowerOf2< 0 >()
Get the Next Power Of2<0> object.
Definition common.hpp:71
__XETLA_API uint32_t size_x
Definition common.hpp:194
Definition limitation.hpp:457
constexpr msg_type msg_type_v
Definition common.hpp:242
cache_hint
L1 or L2 cache hint kinds.
Definition common.hpp:89
reg_layout
tile layout in register linear: linear layout with one tile tiled: 2d block stacked in raster order v...
Definition common.hpp:209
@ tile
flush out to the local scope
mem_space
Definition common.hpp:77
gpu_arch
Definition common.hpp:73
msg_type
Definition common.hpp:78
mem_layout
Definition common.hpp:76
Definition arch_config.hpp:72
Used to check if the type is floating_point.
Definition base_types.hpp:75
Used to check if the type is floating_point.
Definition base_types.hpp:86
static constexpr uint32_t get()
Definition common.hpp:52
static constexpr uint32_t get()
Definition common.hpp:43
Compute next power of 2 of a constexpr with guaranteed compile-time evaluation.
Definition common.hpp:35
static constexpr uint32_t value
Definition common.hpp:81
Definition common.hpp:247
Definition common.hpp:274
Definition common.hpp:302
Definition common.hpp:299
Definition common.hpp:221
static constexpr bool value
Definition common.hpp:223
Definition common.hpp:213
static constexpr bool value
Definition common.hpp:214
Definition common.hpp:229
static constexpr msg_type value
Definition common.hpp:230
Is a struct contains some register file.
Definition api.hpp:99
dtype_ dtype
Definition api.hpp:100