XeTLA v0.3.6
Intel® Xe Templates for Linear Algebra - API Definition Document
 
Loading...
Searching...
No Matches
payload_xe.hpp
Go to the documentation of this file.
1/*******************************************************************************
2* Copyright (c) 2022-2023 Intel Corporation
3*
4* Licensed under the Apache License, Version 2.0 (the "License");
5* you may not use this file except in compliance with the License.
6* You may obtain a copy of the License at
7*
8* http://www.apache.org/licenses/LICENSE-2.0
9*
10* Unless required by applicable law or agreed to in writing, software
11* distributed under the License is distributed on an "AS IS" BASIS,
12* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13* See the License for the specific language governing permissions and
14* limitations under the License.
15*******************************************************************************/
16
19
20#pragma once
21
22#include "subgroup/tile/api.hpp"
24
25namespace gpu::xetla::subgroup {
26
// ---------------------------------------------------------------------------
// mem_payload_t specialization: 2D block messages (msg_type::block_2d) on
// global memory, enabled only for gpu_arch::Xe.  It pre-builds one 16-DW
// tensor descriptor ("tdesc") per tile sub-block so tile loads/stores can
// issue 2D block messages directly.
// NOTE(review): this listing appears to be a Doxygen HTML extraction; lines
// that carried hyperlinks (the `struct mem_payload_t<...>` header, the
// `using this_payload_t = ...` alias, the `payloads` member declaration, the
// assignment-operator signature, and some call names) were dropped.  Verify
// against the original payload_xe.hpp before editing.
// ---------------------------------------------------------------------------
36template <typename dtype_, typename tile_desc_, mem_layout mem_layout_,
37 gpu_arch arch_tag_, uint32_t alignment_>
39 mem_desc_t<dtype_, mem_layout_, mem_space::global, alignment_>,
40 tile_desc_, msg_type::block_2d, arch_tag_,
41 std::enable_if_t<(arch_tag_ == gpu_arch::Xe)>> {
42 using tile_desc = tile_desc_;
45 using dtype = dtype_;
46 static constexpr msg_type message_type = msg_type::block_2d;
47 static constexpr mem_space memory_space = mem_space::global;
48 static constexpr mem_layout memory_layout = mem_layout_;
49 static constexpr gpu_arch arch_tag = arch_tag_;
50
51private:
// Tile blocking parameters forwarded from tile_desc.
52 static constexpr uint32_t block_size_x = tile_desc::block_size_x;
53 static constexpr uint32_t block_size_y = tile_desc::block_size_y;
54 static constexpr uint32_t num_block_x = tile_desc::num_block_x;
55 static constexpr uint32_t num_block_y = tile_desc::num_block_y;
56 static constexpr uint32_t num_block = tile_desc::num_block;
57 static constexpr uint32_t remained_size_y = tile_desc::remained_size_y;
59 msg_type::block_2d, arch_tag>;
60
61public:
// Memory-side transpose: col-major memory layout.
62 static constexpr bool mem_transpose
63 = memory_layout == mem_layout::col_major;
64
65 static constexpr reg_layout register_layout = tile_desc::register_layout;
66 static constexpr bool reg_transpose
67 = register_layout == reg_layout::transpose_tiled;
// A hardware transpose is needed when exactly one side (memory or
// register) is transposed (XOR).
68 static constexpr bool trans = mem_transpose ^ reg_transpose;
69
// VNNI transform applies only to sub-DWord element types.
70 static constexpr bool mem_transform = (sizeof(dtype) < 4) && !mem_transpose
71 && (register_layout == reg_layout::vnni_tiled
72 || register_layout == reg_layout::vnni_tiled_col_major);
// Sub-DWord transpose is mimicked by transposing uint32_t units instead
// (see the "mimic dw transpose" comments in prepare_tile_desc_core).
73 static constexpr bool mem_dword_transpose = (sizeof(dtype) < 4) && trans;
74
75 using mem_dtype = typename std::conditional<mem_dword_transpose, uint32_t,
76 dtype>::type;
// Number of dtype elements packed per mem_dtype element; x offsets/widths
// below are divided by this to convert element units to mem_dtype units.
77 static constexpr uint32_t scale_factor = sizeof(mem_dtype) / sizeof(dtype);
78
80
// Copy constructor: copies the prepared descriptor payload wholesale.
81 inline mem_payload_t(const this_payload_t &rhs) {
82 this->payload = rhs.payload;
83 }
84
// Construct from a mem_desc_t: take its base tdesc, rescale the x offset
// to mem_dtype units, then expand it into per-block descriptors.
// (Part of the offset expression was lost by the Doxygen extraction.)
85 inline mem_payload_t(mem_desc_t &mem_desc) {
86 xetla_tdescriptor base_tdesc = mem_desc.get_tdesc();
87 int32_t offset
89 / int32_t(scale_factor);
91 base_tdesc.xetla_format<uint32_t>(), offset);
92 prepare_tdesc(base_tdesc);
93 }
94
// Construct from raw surface parameters (pointer, width/height/pitch in
// the units expected by xetla_fill_tdesc, plus an element offset).
95 inline mem_payload_t(dtype *p, uint32_t surface_width,
96 uint32_t surface_height, uint32_t surface_pitch,
97 int32_t surface_offset_x = 0, int32_t surface_offset_y = 0) {
98 xetla_tdescriptor base_tdesc;
99 xetla_fill_tdesc(base_tdesc.xetla_format<uint32_t>(), p, surface_width,
100 surface_height, surface_pitch,
101 surface_offset_x / int32_t(scale_factor), surface_offset_y);
102 prepare_tdesc(base_tdesc);
103 }
104
// Re-initialize from a mem_desc_t (same steps as the constructor).
105 __XETLA_API void init(mem_desc_t &mem_desc) {
106 xetla_tdescriptor base_tdesc = mem_desc.get_tdesc();
107 int32_t offset
109 / int32_t(scale_factor);
111 base_tdesc.xetla_format<uint32_t>(), offset);
112 prepare_tdesc(base_tdesc);
113 }
114
// Re-initialize from raw surface parameters.
115 __XETLA_API void init(dtype *p, uint32_t surface_width,
116 uint32_t surface_height, uint32_t surface_pitch,
117 int32_t surface_offset_x = 0, int32_t surface_offset_y = 0) {
118 xetla_tdescriptor base_tdesc;
119 xetla_fill_tdesc(base_tdesc.xetla_format<uint32_t>(), p, surface_width,
120 surface_height, surface_pitch,
121 surface_offset_x / int32_t(scale_factor), surface_offset_y);
122 prepare_tdesc(base_tdesc);
123 }
124
125 inline mem_payload_t() = default;
126 // Be aware of the risks: Rule of three (copy constructor, copy assignment, destructor)
127 // Please check if you need to add self-define destructor
128 // ~mem_payload_t(){}
129
// Copy assignment (signature line dropped by the extraction).
131 this->payload = rhs.payload;
132 return *this;
133 }
134
// Shift every block descriptor by `offset` elements along x or y.
// x offsets are rescaled to mem_dtype units; y offsets are not.
135 template <tdesc_update_dir update_dir = tdesc_update_dir::x_dir>
136 __XETLA_API void update_tdesc(int offset) {
137 auto payloads_2d = payloads.xetla_format<uint32_t, num_block, 16>();
138 if constexpr (update_dir == tdesc_update_dir::x_dir) {
139#pragma unroll
140 for (uint32_t i = 0; i < num_block; i++) {
142 payloads_2d.row(i), offset / int32_t(scale_factor));
143 }
144 } else {
145#pragma unroll
146 for (uint32_t i = 0; i < num_block; i++) {
147 xetla_update_tdesc_offsety(payloads_2d.row(i), offset);
148 }
149 }
150 }
151
152private:
// Expand the base descriptor into one descriptor per sub-block, row of
// blocks by row of blocks, then handle the remainder rows (if any).
153 __XETLA_API void prepare_tdesc(xetla_tdescriptor base_tdesc) {
154 auto payloads_2d = payloads.xetla_format<uint32_t, num_block, 16>();
155 uint32_t base_offset_y = 0;
156#pragma unroll
157 for (uint32_t i = 0; i < num_block_y; i++) {
158 auto tdesc_row_2d = payloads_2d.xetla_select<num_block_x, 1, 16, 1>(
159 i * num_block_x, 0);
160 prepare_tile_desc_core<num_block_x, block_size_x, block_size_y, 1,
161 mem_transpose>(tdesc_row_2d, base_tdesc, base_offset_y);
162 base_offset_y += block_size_y;
163 }
164 // process the tail
165 if constexpr (remained_size_y > 0) {
166 auto tdesc_row_2d = payloads_2d.xetla_select<num_block_x, 1, 16, 1>(
167 num_block_y * num_block_x, 0);
168 prepare_tile_desc_core<num_block_x, block_size_x, remained_size_y,
169 1, mem_transpose>(tdesc_row_2d, base_tdesc, base_offset_y);
170 }
171 }
172
// Fill one row of `num_tdesc` descriptors: copy the base tdesc, encode the
// block shape, and set per-block x/y offsets (swapped when `trans`).
173 template <uint32_t num_tdesc, uint32_t size_x, uint32_t size_y,
174 uint8_t arr_len, bool trans>
175 __XETLA_API static void prepare_tile_desc_core(
176 xetla_matrix_ref<uint32_t, num_tdesc, 16> __REF__ payloads_row_2d,
177 xetla_tdescriptor base_tdesc, uint32_t base_offset_y) {
178 uint32_t base_offset_x = 0;
179#pragma unroll
180 for (uint32_t j = 0; j < num_tdesc; j++) {
181 payloads_row_2d.row(j) = base_tdesc;
182 // To mimic dw transpose for word/byte data type with transpose and pack
183 constexpr uint8_t block_width
184 = trans ? (size_y / scale_factor) : (size_x / scale_factor);
185 constexpr uint8_t block_height = trans ? size_x : size_y;
// Pack (width-1) | (height-1)<<8 | (array_len-1)<<16 into one DWord,
// the encoded block-shape field of the descriptor.
186 constexpr uint32_t block_widthx_widthy_arrlen = (block_width - 1)
187 | ((block_height - 1) << 8) | ((arr_len - 1) << 16);
189 payloads_row_2d.row(j), block_widthx_widthy_arrlen);
190
191 // To mimic dw transpose for word/byte data type with transpose and pack
192 uint32_t offset_width = trans
193 ? (base_offset_y / int32_t(scale_factor))
194 : (base_offset_x / int32_t(scale_factor));
195 uint32_t offset_height = trans ? base_offset_x : base_offset_y;
196
197 xetla_update_tdesc_offsetx(payloads_row_2d.row(j), offset_width);
198 xetla_update_tdesc_offsety(payloads_row_2d.row(j), offset_height);
199 base_offset_x += size_x * arr_len;
200 }
201 }
202};
203
// ---------------------------------------------------------------------------
// mem_payload_t specialization: 1D block messages (msg_type::block_1d) on
// global row-major memory, Xe only.  The tile must be a single row
// (tile_size_y == 1, enforced below); state is just a base pointer, a byte
// offset, and the surface pitch.
// NOTE(review): Doxygen-extraction artifact — the struct header, the
// this_payload_t/mem_desc_t alias lines, the `base_ptr`/`pitch_in_bytes`
// member declarations, and the assignment-operator signature are missing
// from this listing.
// ---------------------------------------------------------------------------
212template <typename dtype_, typename tile_desc_, gpu_arch arch_tag_,
213 uint32_t alignment_>
215 mem_space::global, alignment_>,
216 tile_desc_, msg_type::block_1d, arch_tag_,
217 std::enable_if_t<(arch_tag_ == gpu_arch::Xe)>> {
219 mem_space::global, alignment_>;
220 using dtype = dtype_;
221 using tile_desc = tile_desc_;
222 static constexpr mem_space memory_space = mem_space::global;
223 static constexpr mem_layout memory_layout = mem_layout::row_major;
224 static constexpr msg_type message_type = msg_type::block_1d;
225 static constexpr uint32_t alignment_in_bytes
226 = mem_desc_t::alignment_in_bytes;
227 static constexpr gpu_arch arch_tag = arch_tag_;
228 static_assert((alignment_in_bytes % sizeof(uint32_t)) == 0,
229 "alignment should at least DW aligned");
230
231private:
232 static constexpr uint32_t tile_size_x = tile_desc::tile_size_x;
233 static constexpr uint32_t tile_size_y = tile_desc::tile_size_y;
234 static_assert(tile_size_y == 1,
235 "For tile_size_y > 1 case, please use 2d block message! ");
237 msg_type::block_1d, arch_tag>;
238
239public:
240 static constexpr uint32_t bytes_per_row = tile_size_x * sizeof(dtype);
// Choose the widest vector lane type the row size and alignment allow:
// uint64_t, else uint32_t, else the raw dtype.
241 using mem_dtype = typename std::conditional<
242 (bytes_per_row % sizeof(uint64_t) == 0)
243 && (alignment_in_bytes % sizeof(uint64_t) == 0),
244 uint64_t,
245 typename std::conditional<(bytes_per_row % sizeof(uint32_t) == 0),
246 uint32_t, dtype>::type>::type;
// dtype elements per mem_dtype lane.
247 static constexpr uint32_t scale_factor = sizeof(mem_dtype) / sizeof(dtype);
248
// Byte offset of the tile origin relative to base_ptr.
249 uint64_t base_offset;
252
// Construct from a mem_desc_t: derive pitch and the byte offset of
// coordinate (x, y) from the surface base.
253 inline mem_payload_t(mem_desc_t &mem_tdesc) {
254 pitch_in_bytes = mem_tdesc.shape.stride * sizeof(dtype);
255 uint32_t offset_x = mem_tdesc.coord.x;
256 uint32_t offset_y = mem_tdesc.coord.y;
257 base_offset = offset_y * pitch_in_bytes + offset_x * sizeof(dtype);
258 base_ptr = (mem_dtype *)mem_tdesc.base.base;
259 }
260
// Construct from raw surface parameters; width/height are unused because
// 1D messages do not do bounds handling here.
261 inline mem_payload_t(dtype *p, [[maybe_unused]] int surface_width,
262 [[maybe_unused]] int surface_height, int surface_pitch,
263 int surface_offset_x, int surface_offset_y) {
264 pitch_in_bytes = surface_pitch * sizeof(dtype);
265 uint32_t offset_x = surface_offset_x;
266 uint32_t offset_y = surface_offset_y;
267 base_offset = offset_y * pitch_in_bytes + offset_x * sizeof(dtype);
268 base_ptr = (mem_dtype *)p;
269 }
270
// Re-initialize from a mem_desc_t (same computation as the constructor).
271 __XETLA_API void init(mem_desc_t &mem_tdesc) {
272 pitch_in_bytes = mem_tdesc.shape.stride * sizeof(dtype);
273 uint32_t offset_x = mem_tdesc.coord.x;
274 uint32_t offset_y = mem_tdesc.coord.y;
275 base_offset = offset_y * pitch_in_bytes + offset_x * sizeof(dtype);
276 base_ptr = (mem_dtype *)mem_tdesc.base.base;
277 }
278
// Re-initialize from raw surface parameters.
279 __XETLA_API void init(dtype *p, [[maybe_unused]] int surface_width,
280 [[maybe_unused]] int surface_height, int surface_pitch,
281 int surface_offset_x, int surface_offset_y) {
282 pitch_in_bytes = surface_pitch * sizeof(dtype);
283 uint32_t offset_x = surface_offset_x;
284 uint32_t offset_y = surface_offset_y;
285 base_offset = offset_y * pitch_in_bytes + offset_x * sizeof(dtype);
286 base_ptr = (mem_dtype *)p;
287 }
288
289 inline mem_payload_t(const this_payload_t &rhs) {
290 this->base_offset = rhs.base_offset;
291 this->base_ptr = rhs.base_ptr;
292 this->pitch_in_bytes = rhs.pitch_in_bytes;
293 }
294
295 inline mem_payload_t() = default;
// Copy assignment (signature line dropped by the extraction).
297 this->base_offset = rhs.base_offset;
298 this->base_ptr = rhs.base_ptr;
299 this->pitch_in_bytes = rhs.pitch_in_bytes;
300 return *this;
301 }
302
// Advance the payload by `offset` elements along x, or `offset` rows
// along y (pitch units).
303 template <tdesc_update_dir update_dir = tdesc_update_dir::x_dir>
304 __XETLA_API void update_tdesc(int offset) {
305 if constexpr (update_dir == tdesc_update_dir::x_dir) {
306 base_offset += int64_t(offset) * sizeof(dtype);
307 } else {
308 base_offset += int64_t(offset) * pitch_in_bytes;
309 }
310 }
311};
312
// ---------------------------------------------------------------------------
// mem_payload_t specialization: atomic-add messages (msg_type::atomic_add)
// on global row-major memory, Xe only.  Addresses are scattered per SIMD
// channel; only DWord/QWord element types are supported (static_assert).
// NOTE(review): Doxygen-extraction artifact — the struct header, the alias
// lines, and several member declarations (pitch_in_bytes, width_in_elems,
// height_in_elems, channel_index, step_x/step_y, channel_offset) are missing
// from this listing.
// ---------------------------------------------------------------------------
319template <typename dtype_, typename tile_desc_, gpu_arch arch_tag_,
320 uint32_t alignment_>
322 mem_space::global, alignment_>,
323 tile_desc_, msg_type::atomic_add, arch_tag_,
324 std::enable_if_t<(arch_tag_ == gpu_arch::Xe)>> {
326 mem_space::global, alignment_>;
327 using dtype = dtype_;
328 using tile_desc = tile_desc_;
329 static constexpr mem_space memory_space = mem_space::global;
330 static constexpr mem_layout memory_layout = mem_layout::row_major;
331 static constexpr msg_type message_type = msg_type::atomic_add;
332 static constexpr uint32_t alignment_in_bytes
333 = mem_desc_t::alignment_in_bytes;
334 static constexpr gpu_arch arch_tag = arch_tag_;
335 static_assert(
336 sizeof(dtype) >= 4, "for atomic add, we only support DW or QW");
337
338private:
339 static constexpr uint32_t tile_size_x = tile_desc::tile_size_x;
340 static constexpr uint32_t tile_size_y = tile_desc::tile_size_y;
341 static constexpr uint32_t block_size_x = tile_desc::block_size_x;
342 static constexpr uint32_t block_size_y = tile_desc::block_size_y;
344 msg_type::atomic_add, arch_tag>;
345
346public:
347 static constexpr uint32_t tile_bytes
348 = tile_size_x * tile_size_y * sizeof(dtype);
349 static constexpr uint32_t block_bytes
350 = block_size_x * block_size_y * sizeof(dtype);
351
352 // for pvc, we can use simd16 or simd32
353 static constexpr uint32_t min_store_bytes = 16 * sizeof(dtype);
354 static constexpr uint32_t max_store_bytes = 32 * sizeof(dtype);
// 32 channels when both tile and block are a whole number of simd32
// stores; otherwise fall back to simd16.
355 static constexpr uint32_t num_channel
356 = ((tile_bytes % max_store_bytes) == 0
357 && (block_bytes % max_store_bytes) == 0)
358 ? 32
359 : 16;
360
// Channel grid: one channel per element in x, remaining channels span y.
361 static constexpr uint32_t num_channel_x = block_size_x;
362 static constexpr uint32_t num_channel_y = num_channel / num_channel_x;
363 static constexpr uint32_t store_elems = num_channel_y * block_size_x;
364
// Current tile origin (elements) and absolute base address (bytes).
371 uint32_t base_x;
372 uint32_t base_y;
373 uint64_t base_pointer;
374
// Construct from a mem_desc_t: record shape/origin, point base_pointer at
// the tile origin, then precompute per-channel x/y steps & byte offsets.
375 inline mem_payload_t(mem_desc_t &mem_tdesc) {
376 pitch_in_bytes = mem_tdesc.shape.stride * sizeof(dtype);
377 base_x = mem_tdesc.coord.x;
378 base_y = mem_tdesc.coord.y;
379 width_in_elems = mem_tdesc.shape.x;
380 height_in_elems = mem_tdesc.shape.y;
381 base_pointer = (uint64_t)mem_tdesc.base.base;
382 base_pointer += base_y * pitch_in_bytes + base_x * sizeof(dtype);
384 = xetla_vector_gen<uint32_t, num_channel>(0, 1);
385 step_x = channel_index % num_channel_x;
386 step_y = channel_index / num_channel_x;
387 channel_offset = step_x * sizeof(dtype) + step_y * pitch_in_bytes;
388 }
389
// Construct from raw surface parameters (same per-channel setup).
390 inline mem_payload_t(dtype *p, int surface_width, int surface_height,
391 int surface_pitch, int surface_offset_x, int surface_offset_y) {
392 pitch_in_bytes = surface_pitch * sizeof(dtype);
393 base_x = surface_offset_x;
394 base_y = surface_offset_y;
395 width_in_elems = surface_width;
396 height_in_elems = surface_height;
397 base_pointer = (uint64_t)p;
398 base_pointer += base_y * pitch_in_bytes + base_x * sizeof(dtype);
400 = xetla_vector_gen<uint32_t, num_channel>(0, 1);
401 step_x = channel_index % num_channel_x;
402 step_y = channel_index / num_channel_x;
403 channel_offset = step_x * sizeof(dtype) + step_y * pitch_in_bytes;
404 }
405
// Re-initialize from raw surface parameters.
406 __XETLA_API void init(dtype *p, int surface_width, int surface_height,
407 int surface_pitch, int surface_offset_x, int surface_offset_y) {
408 pitch_in_bytes = surface_pitch * sizeof(dtype);
409 base_x = surface_offset_x;
410 base_y = surface_offset_y;
411 width_in_elems = surface_width;
412 height_in_elems = surface_height;
413 base_pointer = (uint64_t)p;
414 base_pointer += base_y * pitch_in_bytes + base_x * sizeof(dtype);
416 = xetla_vector_gen<uint32_t, num_channel>(0, 1);
417 step_x = channel_index % num_channel_x;
418 step_y = channel_index / num_channel_x;
419 channel_offset = step_x * sizeof(dtype) + step_y * pitch_in_bytes;
420 }
421
// Re-initialize from a mem_desc_t.
422 __XETLA_API void init(mem_desc_t &mem_tdesc) {
423 pitch_in_bytes = mem_tdesc.shape.stride * sizeof(dtype);
424 base_x = mem_tdesc.coord.x;
425 base_y = mem_tdesc.coord.y;
426 width_in_elems = mem_tdesc.shape.x;
427 height_in_elems = mem_tdesc.shape.y;
428 base_pointer = (uint64_t)mem_tdesc.base.base;
429 base_pointer += base_y * pitch_in_bytes + base_x * sizeof(dtype);
431 = xetla_vector_gen<uint32_t, num_channel>(0, 1);
432 step_x = channel_index % num_channel_x;
433 step_y = channel_index / num_channel_x;
434 channel_offset = step_x * sizeof(dtype) + step_y * pitch_in_bytes;
435 }
436
437 inline mem_payload_t(const this_payload_t &rhs) {
438 this->pitch_in_bytes = rhs.pitch_in_bytes;
439 this->width_in_elems = rhs.width_in_elems;
440 this->height_in_elems = rhs.height_in_elems;
441 this->base_x = rhs.base_x;
442 this->base_y = rhs.base_y;
443 this->base_pointer = rhs.base_pointer;
444 this->channel_offset = rhs.channel_offset;
445 this->step_x = rhs.step_x;
446 this->step_y = rhs.step_y;
447 }
448
449 inline mem_payload_t() = default;
// Copy assignment (signature line dropped by the extraction).
451 this->pitch_in_bytes = rhs.pitch_in_bytes;
452 this->width_in_elems = rhs.width_in_elems;
453 this->height_in_elems = rhs.height_in_elems;
454 this->base_x = rhs.base_x;
455 this->base_y = rhs.base_y;
456 this->base_pointer = rhs.base_pointer;
457 this->channel_offset = rhs.channel_offset;
458 this->step_x = rhs.step_x;
459 this->step_y = rhs.step_y;
460 return *this;
461 }
462
// Move the tile origin by `offset` elements along x or rows along y,
// keeping base_x/base_y coherent with base_pointer.
463 template <tdesc_update_dir update_dir = tdesc_update_dir::x_dir>
464 __XETLA_API void update_tdesc(int offset) {
465 if constexpr (update_dir == tdesc_update_dir::x_dir) {
466 base_pointer += int64_t(offset) * sizeof(dtype);
467 base_x += offset;
468 } else {
469 base_pointer += int64_t(offset) * pitch_in_bytes;
470 base_y += offset;
471 }
472 }
473};
474
// ---------------------------------------------------------------------------
// mem_payload_t specialization: 1D block messages (msg_type::block_1d) on
// local (shared) row-major memory, Xe only.  State is a single 32-bit byte
// address into the local surface plus the row pitch.
// NOTE(review): Doxygen-extraction artifact — the struct header, the alias
// lines, the this_payload_t alias, and the `pitch_in_bytes` member
// declaration are missing from this listing.
// ---------------------------------------------------------------------------
480template <typename dtype_, typename tile_desc_, gpu_arch arch_tag_,
481 uint32_t alignment_>
483 mem_desc_t<dtype_, mem_layout::row_major, mem_space::local, alignment_>,
484 tile_desc_, msg_type::block_1d, arch_tag_,
485 std::enable_if_t<(arch_tag_ == gpu_arch::Xe)>> {
487 mem_space::local, alignment_>;
488 using dtype = dtype_;
489 using tile_desc = tile_desc_;
490 static constexpr mem_space memory_space = mem_space::local;
491 static constexpr mem_layout memory_layout = mem_layout::row_major;
492 static constexpr msg_type message_type = msg_type::block_1d;
493 static constexpr uint32_t alignment_in_bytes
494 = mem_desc_t::alignment_in_bytes;
495 static_assert((alignment_in_bytes % sizeof(uint32_t)) == 0,
496 "alignment should at least DW aligned");
497 static constexpr gpu_arch arch_tag = arch_tag_;
498
499private:
500 static constexpr uint32_t tile_size_x = tile_desc::tile_size_x;
501 static constexpr uint32_t tile_size_y = tile_desc::tile_size_y;
502 static constexpr uint32_t block_size_x = tile_desc::block_size_x;
503 static constexpr uint32_t block_size_y = tile_desc::block_size_y;
505 msg_type::block_1d, arch_tag>;
506
507public:
508 static constexpr uint32_t tile_bytes
509 = tile_size_x * tile_size_y * sizeof(dtype);
510 static constexpr uint32_t block_bytes
511 = block_size_x * block_size_y * sizeof(dtype);
512 static constexpr uint32_t bytes_per_row = block_size_x * sizeof(dtype);
// Widest lane type permitted by the per-block row size and alignment:
// uint64_t, else uint32_t, else the raw dtype.
513 using mem_dtype = typename std::conditional<
514 (bytes_per_row % sizeof(uint64_t) == 0)
515 && (alignment_in_bytes % sizeof(uint64_t) == 0),
516 uint64_t,
517 typename std::conditional<(bytes_per_row % sizeof(uint32_t) == 0),
518 uint32_t, dtype>::type>::type;
519 static constexpr uint32_t scale_factor = sizeof(mem_dtype) / sizeof(dtype);
520
// Byte address of the tile origin inside the local surface.
521 uint32_t address;
523
// Construct from a mem_desc_t: address = base + y * pitch + x * elem_size.
524 inline mem_payload_t(mem_desc_t &mem_tdesc) {
525 pitch_in_bytes = mem_tdesc.shape.stride * sizeof(dtype);
526 uint32_t offset_x = mem_tdesc.coord.x;
527 uint32_t offset_y = mem_tdesc.coord.y;
528 address = mem_tdesc.base.base + offset_y * pitch_in_bytes
529 + offset_x * sizeof(dtype);
530 }
// Construct from raw surface parameters; width/height are unused for 1D
// local accesses.
531 inline mem_payload_t(uint32_t base, [[maybe_unused]] int surface_width,
532 [[maybe_unused]] int surface_height, int surface_pitch,
533 int surface_offset_x, int surface_offset_y) {
534 uint32_t offset_x = surface_offset_x;
535 uint32_t offset_y = surface_offset_y;
536 pitch_in_bytes = surface_pitch * sizeof(dtype);
537 address = base + offset_y * pitch_in_bytes + offset_x * sizeof(dtype);
538 }
539
// Re-initialize from a mem_desc_t.
540 __XETLA_API void init(mem_desc_t &mem_tdesc) {
541 pitch_in_bytes = mem_tdesc.shape.stride * sizeof(dtype);
542 uint32_t offset_x = mem_tdesc.coord.x;
543 uint32_t offset_y = mem_tdesc.coord.y;
544 address = mem_tdesc.base.base + offset_y * pitch_in_bytes
545 + offset_x * sizeof(dtype);
546 }
547
// Re-initialize from raw surface parameters.
548 __XETLA_API void init(uint32_t base, [[maybe_unused]] int surface_width,
549 [[maybe_unused]] int surface_height, int surface_pitch,
550 int surface_offset_x, int surface_offset_y) {
551 uint32_t offset_x = surface_offset_x;
552 uint32_t offset_y = surface_offset_y;
553 pitch_in_bytes = surface_pitch * sizeof(dtype);
554 address = base + offset_y * pitch_in_bytes + offset_x * sizeof(dtype);
555 }
556
557 inline mem_payload_t(const this_payload_t &rhs) {
558 this->address = rhs.address;
559 this->pitch_in_bytes = rhs.pitch_in_bytes;
560 }
561
562 inline mem_payload_t() = default;
// Copy assignment (signature line dropped by the extraction).
564 this->address = rhs.address;
565 this->pitch_in_bytes = rhs.pitch_in_bytes;
566 return *this;
567 }
568
// Advance the address by `offset` elements along x or rows along y.
// 32-bit arithmetic only: local-memory addresses fit in uint32_t here.
569 template <tdesc_update_dir update_dir = tdesc_update_dir::x_dir>
570 __XETLA_API void update_tdesc(int offset) {
571 if constexpr (update_dir == tdesc_update_dir::x_dir) {
572 address += offset * sizeof(dtype);
573 } else {
574 address += offset * pitch_in_bytes;
575 }
576 }
577};
578
// ---------------------------------------------------------------------------
// mem_payload_t specialization: unaligned 2D access (msg_type::unaligned_2d)
// on global memory, Xe only.  2D tile access is emulated with per-channel
// scattered accesses, so it works when the surface does not meet the
// 2D-block-message alignment requirements.
// NOTE(review): Doxygen-extraction artifact — the struct header, alias
// lines, and several member declarations (base_ptr, pitch_in_bytes,
// width_in_elems, height_in_elems, channel_index, step_x/step_y,
// channel_offset) are missing from this listing.
// Fix vs. original: `arch_tag` previously hard-coded `gpu_arch::Xe`; every
// sibling specialization propagates `arch_tag_`, and the enable_if already
// guarantees `arch_tag_ == gpu_arch::Xe`, so this is a behavior-preserving
// consistency fix.
// ---------------------------------------------------------------------------
588template <typename dtype_, typename tile_desc_, mem_layout mem_layout_,
589 uint32_t alignment_, gpu_arch arch_tag_>
591 mem_desc_t<dtype_, mem_layout_, mem_space::global, alignment_>,
592 tile_desc_, msg_type::unaligned_2d, arch_tag_,
593 std::enable_if_t<(arch_tag_ == gpu_arch::Xe)>> {
594 using dtype = dtype_;
597 using tile_desc = tile_desc_;
598 static constexpr mem_space memory_space = mem_space::global;
599 static constexpr mem_layout memory_layout = mem_layout_;
600 static constexpr msg_type message_type = msg_type::unaligned_2d;
601 static constexpr uint32_t alignment_in_bytes
602 = mem_desc_t::alignment_in_bytes;
603 static constexpr gpu_arch arch_tag = arch_tag_;
604
605private:
606 static constexpr uint32_t tile_size_x = tile_desc::tile_size_x;
607 static constexpr uint32_t tile_size_y = tile_desc::tile_size_y;
608 static constexpr uint32_t block_size_x = tile_desc::block_size_x;
609 static constexpr uint32_t block_size_y = tile_desc::block_size_y;
610
613
614public:
// Memory-side transpose: col-major memory layout.
615 static constexpr bool mem_transpose
616 = memory_layout == mem_layout::col_major;
617
618 static constexpr reg_layout register_layout = tile_desc::register_layout;
619 static constexpr bool reg_transpose
620 = register_layout == reg_layout::transpose_tiled;
// Transpose needed when exactly one of memory/register is transposed.
621 static constexpr bool trans = mem_transpose ^ reg_transpose;
622
// VNNI transform applies only to sub-DWord element types.
623 static constexpr bool mem_transform = (sizeof(dtype) < 4)
624 && (register_layout == reg_layout::vnni_tiled
625 || register_layout == reg_layout::vnni_tiled_col_major);
626
627 static constexpr uint32_t tile_bytes
628 = tile_size_x * tile_size_y * sizeof(dtype);
629 static constexpr uint32_t block_bytes
630 = block_size_x * block_size_y * sizeof(dtype);
631
// Widest lane type the alignment allows: uint64_t, else uint32_t, else dtype.
632 using mem_dtype = typename std::conditional<
633 (alignment_in_bytes % sizeof(uint64_t) == 0), uint64_t,
634 typename std::conditional<(alignment_in_bytes % sizeof(uint32_t)
635 == 0),
636 uint32_t, dtype>::type>::type;
637 static constexpr uint32_t scale_factor = sizeof(mem_dtype) / sizeof(dtype);
638
639 // for pvc, we can use simd16 or simd32
640 static constexpr uint32_t min_store_bytes = 16 * sizeof(dtype);
641 static constexpr uint32_t max_store_bytes = 32 * sizeof(dtype);
// simd32 when tile and block sizes divide evenly, else simd16.
642 static constexpr uint32_t num_channel
643 = ((tile_bytes % max_store_bytes) == 0
644 && (block_bytes % max_store_bytes) == 0)
645 ? 32
646 : 16;
647
// Channels along x cover one block row in mem_dtype units.
648 static constexpr uint32_t num_channel_x
649 = block_size_x * sizeof(dtype) / sizeof(mem_dtype);
650 static constexpr uint32_t num_channel_y = num_channel / num_channel_x;
651
655
// Byte offset of the tile origin and its (x, y) element coordinates.
656 uint64_t base_offset;
657 uint32_t base_x;
658 uint32_t base_y;
661
664
// Construct from a mem_desc_t: record pitch/shape/origin, compute the
// origin byte offset (swapping x/y roles when `trans`), and precompute
// per-channel scatter offsets.
665 inline mem_payload_t(mem_desc_t &mem_tdesc) {
666 pitch_in_bytes = mem_tdesc.shape.stride * sizeof(dtype);
667 base_x = mem_tdesc.coord.x;
668 base_y = mem_tdesc.coord.y;
669 width_in_elems = mem_tdesc.shape.x;
670 height_in_elems = mem_tdesc.shape.y;
671 base_offset = trans ? base_x * pitch_in_bytes + base_y * sizeof(dtype)
672 : base_y * pitch_in_bytes + base_x * sizeof(dtype);
673 base_ptr = (mem_dtype *)mem_tdesc.base.base;
674
676 = xetla_vector_gen<uint32_t, num_channel>(0, 1);
677 step_x = channel_index % num_channel_x;
678 step_y = channel_index / num_channel_x;
679 channel_offset = trans
680 ? step_y * sizeof(mem_dtype) + step_x * pitch_in_bytes
681 : step_x * sizeof(mem_dtype) + step_y * pitch_in_bytes;
682 }
683
// Construct from raw surface parameters (same per-channel setup).
684 inline mem_payload_t(dtype *p, int surface_width, int surface_height,
685 int surface_pitch, int surface_offset_x, int surface_offset_y) {
686 pitch_in_bytes = surface_pitch * sizeof(dtype);
687 base_x = surface_offset_x;
688 base_y = surface_offset_y;
689 width_in_elems = surface_width;
690 height_in_elems = surface_height;
691 base_offset = trans ? base_x * pitch_in_bytes + base_y * sizeof(dtype)
692 : base_y * pitch_in_bytes + base_x * sizeof(dtype);
693 base_ptr = (mem_dtype *)p;
694
696 = xetla_vector_gen<uint32_t, num_channel>(0, 1);
697 step_x = channel_index % num_channel_x;
698 step_y = channel_index / num_channel_x;
699 channel_offset = trans
700 ? step_y * sizeof(mem_dtype) + step_x * pitch_in_bytes
701 : step_x * sizeof(mem_dtype) + step_y * pitch_in_bytes;
702 }
703
// Re-initialize from a mem_desc_t.
704 __XETLA_API void init(mem_desc_t &mem_tdesc) {
705 pitch_in_bytes = mem_tdesc.shape.stride * sizeof(dtype);
706 base_x = mem_tdesc.coord.x;
707 base_y = mem_tdesc.coord.y;
708 width_in_elems = mem_tdesc.shape.x;
709 height_in_elems = mem_tdesc.shape.y;
710 base_offset = trans ? base_x * pitch_in_bytes + base_y * sizeof(dtype)
711 : base_y * pitch_in_bytes + base_x * sizeof(dtype);
712 base_ptr = (mem_dtype *)mem_tdesc.base.base;
713
715 = xetla_vector_gen<uint32_t, num_channel>(0, 1);
716 step_x = channel_index % num_channel_x;
717 step_y = channel_index / num_channel_x;
718 channel_offset = trans
719 ? step_y * sizeof(mem_dtype) + step_x * pitch_in_bytes
720 : step_x * sizeof(mem_dtype) + step_y * pitch_in_bytes;
721 }
722
// Re-initialize from raw surface parameters.
723 __XETLA_API void init(dtype *p, int surface_width, int surface_height,
724 int surface_pitch, int surface_offset_x, int surface_offset_y) {
725 pitch_in_bytes = surface_pitch * sizeof(dtype);
726 base_x = surface_offset_x;
727 base_y = surface_offset_y;
728 width_in_elems = surface_width;
729 height_in_elems = surface_height;
730 base_offset = trans ? base_x * pitch_in_bytes + base_y * sizeof(dtype)
731 : base_y * pitch_in_bytes + base_x * sizeof(dtype);
732 base_ptr = (mem_dtype *)p;
733
735 = xetla_vector_gen<uint32_t, num_channel>(0, 1);
736 step_x = channel_index % num_channel_x;
737 step_y = channel_index / num_channel_x;
738 channel_offset = trans
739 ? step_y * sizeof(mem_dtype) + step_x * pitch_in_bytes
740 : step_x * sizeof(mem_dtype) + step_y * pitch_in_bytes;
741 }
742
743 inline mem_payload_t(const this_payload_t &rhs) {
744 this->base_offset = rhs.base_offset;
745 this->base_ptr = rhs.base_ptr;
746 this->pitch_in_bytes = rhs.pitch_in_bytes;
747 this->base_x = rhs.base_x;
748 this->base_y = rhs.base_y;
749 this->width_in_elems = rhs.width_in_elems;
750 this->height_in_elems = rhs.height_in_elems;
751
752 this->step_x = rhs.step_x;
753 this->step_y = rhs.step_y;
754
755 this->channel_offset = rhs.channel_offset;
756 }
757
758 inline mem_payload_t() = default;
// Copy assignment (signature line dropped by the extraction).
760 this->base_offset = rhs.base_offset;
761 this->base_ptr = rhs.base_ptr;
762 this->pitch_in_bytes = rhs.pitch_in_bytes;
763 this->base_x = rhs.base_x;
764 this->base_y = rhs.base_y;
765 this->width_in_elems = rhs.width_in_elems;
766 this->height_in_elems = rhs.height_in_elems;
767
768 this->step_x = rhs.step_x;
769 this->step_y = rhs.step_y;
770 this->channel_offset = rhs.channel_offset;
771
772 return *this;
773 }
774
// Move the tile origin by `offset` elements along x or rows along y; the
// tracked coordinate that advances depends on `trans`.
775 template <tdesc_update_dir update_dir = tdesc_update_dir::x_dir>
776 __XETLA_API void update_tdesc(int offset) {
777 if constexpr (update_dir == tdesc_update_dir::x_dir) {
778 base_offset += int64_t(offset) * sizeof(dtype);
779 trans ? base_y += offset : base_x += offset;
780 } else {
781 base_offset += int64_t(offset) * pitch_in_bytes;
782 trans ? base_x += offset : base_y += offset;
783 }
784 }
785};
786
// ---------------------------------------------------------------------------
// mem_payload_t specialization: scattered messages (msg_type::scatter) on
// local (shared) row-major memory, Xe only.  Holds a per-channel vector of
// byte addresses into the local surface.
// NOTE(review): Doxygen-extraction artifact — the struct header, alias
// lines, the tail of the `using this_payload_t` declaration, and member
// declarations (address, pitch_in_bytes, wg_width_in_bytes,
// wg_height_in_elems, channel_index) are missing from this listing.
// ---------------------------------------------------------------------------
792template <typename dtype_, typename tile_desc_, gpu_arch arch_tag_,
793 uint32_t alignment_>
795 mem_desc_t<dtype_, mem_layout::row_major, mem_space::local, alignment_>,
796 tile_desc_, msg_type::scatter, arch_tag_,
797 std::enable_if_t<(arch_tag_ == gpu_arch::Xe)>> {
799 mem_space::local, alignment_>;
800 using dtype = dtype_;
801 using tile_desc = tile_desc_;
802 static constexpr mem_space memory_space = mem_space::local;
803 static constexpr mem_layout memory_layout = mem_layout::row_major;
804 static constexpr msg_type message_type = msg_type::scatter;
805 static constexpr uint32_t alignment_in_bytes
806 = mem_desc_t::alignment_in_bytes;
807 static constexpr gpu_arch arch_tag = arch_tag_;
808
809private:
810 static constexpr uint32_t tile_size_x = tile_desc::tile_size_x;
811 static constexpr uint32_t tile_size_y = tile_desc::tile_size_y;
812 static constexpr uint32_t block_size_x = tile_desc::block_size_x;
813 static constexpr uint32_t block_size_y = tile_desc::block_size_y;
814 using this_payload_t
816
817public:
818 static constexpr reg_layout register_layout = tile_desc::register_layout;
// VNNI transform applies only to sub-DWord element types.
819 static constexpr bool mem_transform
820 = (sizeof(dtype) < 4) && register_layout == reg_layout::vnni_tiled;
821
822 static constexpr uint32_t tile_bytes
823 = tile_size_x * tile_size_y * sizeof(dtype);
824 static constexpr uint32_t block_bytes
825 = block_size_x * block_size_y * sizeof(dtype);
// Widest lane type a block's byte count allows for 16-channel accesses.
826 using mem_dtype = typename std::conditional<
827 (block_bytes % (16 * sizeof(uint64_t)) == 0), uint64_t,
828 typename std::conditional<(block_bytes % (16 * sizeof(uint32_t))
829 == 0),
830 uint32_t, dtype>::type>::type;
831 // we can use simd16 or simd32
832 static constexpr uint32_t min_bytes = 16 * sizeof(mem_dtype);
833 static constexpr uint32_t max_bytes = 32 * sizeof(mem_dtype);
834
// simd32 when tile and block sizes divide evenly, else simd16.
835 static constexpr uint32_t num_channel
836 = ((tile_bytes % max_bytes) == 0 && (block_bytes % max_bytes) == 0)
837 ? 32
838 : 16;
839 static constexpr uint32_t num_channel_x
840 = block_size_x * sizeof(dtype) / sizeof(mem_dtype);
841 static constexpr uint32_t num_channel_y = num_channel / num_channel_x;
846
// Construct from a mem_desc_t.  Reads raw tdesc DWords as consumed here:
// [0]=base address, [2]=width in bytes, [3]=height in elems, [4]=pitch in
// bytes, [5]/[6]=x/y offsets.  Then builds the per-channel address vector.
847 inline mem_payload_t(mem_desc_t &mem_tdesc) {
848 xetla_tdescriptor base_tdesc = mem_tdesc.get_tdesc();
849 pitch_in_bytes = base_tdesc[4];
850 wg_width_in_bytes = base_tdesc[2];
851 wg_height_in_elems = base_tdesc[3];
852 uint32_t offset_x = base_tdesc[5];
853 uint32_t offset_y = base_tdesc[6];
854 uint32_t start_address = base_tdesc[0];
855 start_address += offset_y * pitch_in_bytes + offset_x * sizeof(dtype);
857 = xetla_vector_gen<uint32_t, num_channel>(0, 1);
858 address = start_address
859 + (channel_index % num_channel_x) * sizeof(mem_dtype)
860 + (channel_index / num_channel_x) * pitch_in_bytes;
861 }
862
// Construct from raw surface parameters (same per-channel addressing).
863 inline mem_payload_t(uint32_t base, int surface_width, int surface_height,
864 int surface_pitch, int surface_offset_x, int surface_offset_y) {
865 pitch_in_bytes = surface_pitch * sizeof(dtype);
866 wg_width_in_bytes = surface_width * sizeof(dtype);
867 wg_height_in_elems = surface_height;
868 uint32_t offset_x = surface_offset_x;
869 uint32_t offset_y = surface_offset_y;
870 uint32_t start_address = base;
871 start_address += offset_y * pitch_in_bytes + offset_x * sizeof(dtype);
873 = xetla_vector_gen<uint32_t, num_channel>(0, 1);
874 address = start_address
875 + (channel_index % num_channel_x) * sizeof(mem_dtype)
876 + (channel_index / num_channel_x) * pitch_in_bytes;
877 }
878
// Re-initialize from raw surface parameters.
879 __XETLA_API void init(uint32_t base, int surface_width, int surface_height,
880 int surface_pitch, int surface_offset_x, int surface_offset_y) {
881 pitch_in_bytes = surface_pitch * sizeof(dtype);
882 wg_width_in_bytes = surface_width * sizeof(dtype);
883 wg_height_in_elems = surface_height;
884 uint32_t offset_x = surface_offset_x;
885 uint32_t offset_y = surface_offset_y;
886 uint32_t start_address = base;
887 start_address += offset_y * pitch_in_bytes + offset_x * sizeof(dtype);
889 = xetla_vector_gen<uint32_t, num_channel>(0, 1);
890 address = start_address
891 + (channel_index % num_channel_x) * sizeof(mem_dtype)
892 + (channel_index / num_channel_x) * pitch_in_bytes;
893 }
894
// Re-initialize from a mem_desc_t (same tdesc-field layout as above).
895 __XETLA_API void init(mem_desc_t &mem_tdesc) {
896 xetla_tdescriptor base_tdesc = mem_tdesc.get_tdesc();
897 pitch_in_bytes = base_tdesc[4];
898 wg_width_in_bytes = base_tdesc[2];
899 wg_height_in_elems = base_tdesc[3];
900 uint32_t offset_x = base_tdesc[5];
901 uint32_t offset_y = base_tdesc[6];
902 uint32_t start_address = base_tdesc[0];
903 start_address += offset_y * pitch_in_bytes + offset_x * sizeof(dtype);
905 = xetla_vector_gen<uint32_t, num_channel>(0, 1);
906 address = start_address
907 + (channel_index % num_channel_x) * sizeof(mem_dtype)
908 + (channel_index / num_channel_x) * pitch_in_bytes;
909 }
910
911 inline mem_payload_t(const this_payload_t &rhs) {
912 this->address = rhs.address;
913 this->pitch_in_bytes = rhs.pitch_in_bytes;
914 this->wg_width_in_bytes = rhs.wg_width_in_bytes;
915 this->wg_height_in_elems = rhs.wg_height_in_elems;
916 }
917
918 inline mem_payload_t() = default;
// Copy assignment (signature line dropped by the extraction).
920 this->address = rhs.address;
921 this->pitch_in_bytes = rhs.pitch_in_bytes;
922 this->wg_width_in_bytes = rhs.wg_width_in_bytes;
923 this->wg_height_in_elems = rhs.wg_height_in_elems;
924 return *this;
925 }
926
// Shift every channel's address by `offset` elements along x or rows
// along y (vector-wide add on `address`).
927 template <tdesc_update_dir update_dir = tdesc_update_dir::x_dir>
928 __XETLA_API void update_tdesc(int offset) {
929 if constexpr (update_dir == tdesc_update_dir::x_dir) {
930 address += offset * sizeof(dtype);
931 } else {
932 address += offset * pitch_in_bytes;
933 }
934 }
935};
936
943template <typename dtype_, uint32_t tile_size_x_, uint32_t tile_size_y_,
944 uint32_t block_size_x_, uint32_t block_size_y_, gpu_arch arch_tag_,
945 uint32_t alignment_>
947 mem_desc_t<dtype_, mem_layout::row_major, mem_space::local, alignment_>,
948 tile_desc_t<tile_size_x_, tile_size_y_, block_size_x_, block_size_y_,
949 reg_layout::vnni_tiled_col_major>,
950 msg_type::scatter, arch_tag_,
951 std::enable_if_t<(arch_tag_ == gpu_arch::Xe)>> {
953 mem_space::local, alignment_>;
954 using dtype = dtype_;
955 using tile_desc = tile_desc_t<tile_size_x_, tile_size_y_, block_size_x_,
956 block_size_y_, reg_layout::vnni_tiled_col_major>;
957 static constexpr mem_space memory_space = mem_space::local;
958 static constexpr mem_layout memory_layout = mem_layout::row_major;
959 static constexpr msg_type message_type = msg_type::scatter;
960 static constexpr uint32_t alignment_in_bytes
961 = mem_desc_t::alignment_in_bytes;
962 static constexpr gpu_arch arch_tag = arch_tag_;
963
964private:
965 static constexpr uint32_t tile_size_x = tile_desc::tile_size_x;
966 static constexpr uint32_t tile_size_y = tile_desc::tile_size_y;
967 static constexpr uint32_t block_size_x = tile_desc::block_size_x;
968 static constexpr uint32_t block_size_y = tile_desc::block_size_y;
969 using this_payload_t
971
972public:
973 static constexpr uint32_t tile_bytes
974 = tile_size_x * tile_size_y * sizeof(dtype);
975 static constexpr uint32_t block_bytes
976 = block_size_x * block_size_y * sizeof(dtype);
977 using store_dtype = uint32_t;
978 static constexpr uint32_t vnni_scale_factor
979 = sizeof(store_dtype) / sizeof(dtype);
980 static constexpr uint32_t is_simd16_vec
981 = (block_size_x == 16) && ((tile_size_y & (tile_size_y - 1)) == 0);
982 static constexpr uint32_t num_vector_size = is_simd16_vec
983 ? detail::gcd<tile_size_y / vnni_scale_factor, 8>::value
984 : 1;
985
986 static constexpr uint32_t min_store_bytes = 16 * sizeof(store_dtype);
987 static constexpr uint32_t max_store_bytes = 32 * sizeof(store_dtype);
988 static constexpr uint32_t num_channel = is_simd16_vec
989 ? 16
990 : (((tile_bytes % max_store_bytes) == 0
991 && (block_bytes % max_store_bytes) == 0)
992 ? 32
993 : 16);
994 static constexpr uint32_t num_channel_x = block_size_x;
995 static constexpr uint32_t num_channel_y
996 = is_simd16_vec ? 1 : num_channel / num_channel_x;
997 static constexpr uint32_t store_elems = num_channel_y * num_vector_size
998 * vnni_scale_factor * block_size_x;
1004
1005 // Be aware of the risks: Rule of three (copy constructor, copy assignment, destructor)
1006 // Please check if you need to add self-define destructor
1007 // ~mem_payload_t(){}
1008 inline mem_payload_t(mem_desc_t mem_tdesc) {
1009 xetla_tdescriptor base_tdesc = mem_tdesc.get_tdesc();
1010 cyclic_count = 0;
1011 pitch_in_bytes = base_tdesc[4];
1012 wg_width_in_bytes = base_tdesc[2];
1013 wg_height_in_elems = base_tdesc[3];
1014 uint32_t offset_x = base_tdesc[5];
1015 uint32_t offset_y = base_tdesc[6];
1016 uint32_t start_address = base_tdesc[0];
1017 start_address += offset_y * pitch_in_bytes + offset_x * sizeof(dtype);
1019 = xetla_vector_gen<uint32_t, num_channel>(0, 1);
1020 address = start_address
1021 + (channel_index % num_channel_x) * pitch_in_bytes
1022 + (channel_index / num_channel_x) * sizeof(store_dtype);
1023 }
1024
1025 inline mem_payload_t(uint32_t base, int surface_width, int surface_height,
1026 int surface_pitch, int surface_offset_x, int surface_offset_y) {
1027 pitch_in_bytes = surface_pitch * sizeof(dtype);
1028 wg_width_in_bytes = surface_width * sizeof(dtype);
1029 wg_height_in_elems = surface_height;
1030 uint32_t offset_x = surface_offset_x;
1031 uint32_t offset_y = surface_offset_y;
1032 uint32_t start_address = base;
1033 start_address += offset_y * pitch_in_bytes + offset_x * sizeof(dtype);
1035 = xetla_vector_gen<uint32_t, num_channel>(0, 1);
1036 address = start_address
1037 + ((channel_index % num_channel_x) * pitch_in_bytes
1038 + (channel_index / num_channel_x)
1039 * sizeof(store_dtype));
1040 cyclic_count = 0;
1041 }
1042
1043 __XETLA_API void init(uint32_t base, int surface_width, int surface_height,
1044 int surface_pitch, int surface_offset_x, int surface_offset_y) {
1045 pitch_in_bytes = surface_pitch * sizeof(dtype);
1046 wg_width_in_bytes = surface_width * sizeof(dtype);
1047 wg_height_in_elems = surface_height;
1048 uint32_t offset_x = surface_offset_x;
1049 uint32_t offset_y = surface_offset_y;
1050 uint32_t start_address = base;
1051 start_address += offset_y * pitch_in_bytes + offset_x * sizeof(dtype);
1053 = xetla_vector_gen<uint32_t, num_channel>(0, 1);
1054 address = start_address
1055 + ((channel_index % num_channel_x) * pitch_in_bytes
1056 + (channel_index / num_channel_x)
1057 * sizeof(store_dtype));
1058 cyclic_count = 0;
1059 }
1060
1061 __XETLA_API void init(mem_desc_t mem_tdesc) {
1062 xetla_tdescriptor base_tdesc = mem_tdesc.get_tdesc();
1063 cyclic_count = 0;
1064 pitch_in_bytes = base_tdesc[4];
1065 wg_width_in_bytes = base_tdesc[2];
1066 wg_height_in_elems = base_tdesc[3];
1067 uint32_t offset_x = base_tdesc[5];
1068 uint32_t offset_y = base_tdesc[6];
1069 uint32_t start_address = base_tdesc[0];
1070 start_address += offset_y * pitch_in_bytes + offset_x * sizeof(dtype);
1072 = xetla_vector_gen<uint32_t, num_channel>(0, 1);
1073 address = start_address
1074 + (channel_index % num_channel_x) * pitch_in_bytes
1075 + (channel_index / num_channel_x) * sizeof(store_dtype);
1076 }
1077
1078 inline mem_payload_t(const this_payload_t &rhs) {
1079 this->address = rhs.address;
1080 this->pitch_in_bytes = rhs.pitch_in_bytes;
1081 this->cyclic_count = 0;
1082 this->wg_width_in_bytes = rhs.wg_width_in_bytes;
1083 this->wg_height_in_elems = rhs.wg_height_in_elems;
1084 }
1085
1086 inline mem_payload_t() = default;
1088 this->address = rhs.address;
1089 this->pitch_in_bytes = rhs.pitch_in_bytes;
1090 this->cyclic_count = 0;
1091 this->wg_width_in_bytes = rhs.wg_width_in_bytes;
1092 this->wg_height_in_elems = rhs.wg_height_in_elems;
1093 return *this;
1094 }
1095
1096 template <tdesc_update_dir update_dir = tdesc_update_dir::x_dir>
1097 __XETLA_API void update_tdesc(int offset) {
1098 if constexpr (update_dir == tdesc_update_dir::x_dir) {
1099 address += offset * sizeof(dtype);
1100 } else {
1101 address += offset * pitch_in_bytes;
1102 }
1103 }
1104};
1105
1112template <typename dtype_, uint32_t tile_size_x_, uint32_t tile_size_y_,
1113 uint32_t block_size_x_, uint32_t block_size_y_, mem_layout mem_layout_,
1114 uint32_t alignment_, uint32_t num_coop_sg_, reg_layout reg_layout_,
1115 gpu_arch arch_tag_>
1117 mem_desc_t<dtype_, mem_layout_, mem_space::global, alignment_>,
1118 tile_desc_t<tile_size_x_, tile_size_y_, block_size_x_, block_size_y_,
1119 reg_layout_>,
1120 num_coop_sg_, arch_tag_,
1121 std::enable_if_t<(arch_tag_ == gpu_arch::Xe)>> {
1122 using dtype = dtype_;
1125 using tile_desc = tile_desc_t<tile_size_x_, tile_size_y_, block_size_x_,
1126 block_size_y_, reg_layout_>;
1127 static constexpr mem_space memory_space = mem_space::global;
1128 static constexpr mem_layout memory_layout = mem_layout_;
1129 static constexpr gpu_arch arch_tag = arch_tag_;
1130
1131private:
1132 static constexpr uint32_t tile_size_x = tile_desc::tile_size_x;
1133 static constexpr uint32_t tile_size_y = tile_desc::tile_size_y;
1134 static constexpr bool is_col_major = mem_layout_ == mem_layout::col_major;
1135 static constexpr uint32_t mem_tile_size_w
1136 = is_col_major ? tile_size_y : tile_size_x;
1137 static constexpr uint32_t mem_tile_size_h
1138 = is_col_major ? tile_size_x : tile_size_y;
1139 using load_store_attr = typename arch_attr_t<
1140 arch_tag>::template load_store_attr<msg_type::block_2d>;
1141 static constexpr uint32_t special_prefetch_width
1142 = load_store_attr::special_prefetch_width_in_bytes / sizeof(dtype);
1143 static constexpr uint32_t normal_prefetch_width
1144 = load_store_attr::max_load_width_in_bytes / sizeof(dtype);
1145 static constexpr bool is_special_prefetch
1146 = (mem_tile_size_w % special_prefetch_width) == 0;
1147
1148 static constexpr uint32_t block_size_w = is_special_prefetch
1149 ? special_prefetch_width
1150 : (normal_prefetch_width > mem_tile_size_w ? mem_tile_size_w
1151 : normal_prefetch_width);
1152 static constexpr uint32_t block_size_h
1153 = load_store_attr::max_load_height_in_elem;
1154 //could have over-prefetch, but that's should be fine
1155 static constexpr uint32_t max_num_block_w
1156 = (mem_tile_size_w + block_size_w - 1) / block_size_w;
1157 static constexpr uint32_t num_coop_sg = num_coop_sg_;
1158 static constexpr uint32_t num_coop_sg_w
1160 static constexpr uint32_t num_coop_sg_h = num_coop_sg / num_coop_sg_w;
1161
1162 static constexpr uint32_t num_block_w = max_num_block_w / num_coop_sg_w;
1163 static constexpr uint32_t tile_size_w = block_size_w * num_block_w;
1164 static constexpr uint32_t tile_size_h
1165 = (mem_tile_size_h + num_coop_sg_h - 1) / num_coop_sg_h;
1166 static constexpr uint32_t num_block_h
1167 = (tile_size_h + block_size_h - 1) / block_size_h;
1168 using this_payload_t
1170
1171public:
1172 static constexpr uint32_t num_tdesc = num_block_w * num_block_h;
1174
1176 this->tdesc_prefetch = rhs.tdesc_prefetch;
1177 }
1178
1179 inline prefetch_payload_t() = default;
1180
1182 this->tdesc_prefetch = rhs.tdesc_prefetch;
1183 return *this;
1184 }
1185
1186 inline prefetch_payload_t(mem_desc_t &mem_desc, uint32_t coop_id = 0) {
1187 xetla_tdescriptor base_tdesc = mem_desc.get_tdesc();
1188 uint32_t coop_id_x = coop_id % num_coop_sg_w;
1189 uint32_t coop_id_y = coop_id / num_coop_sg_w;
1191 base_tdesc.xetla_format<uint32_t>(), coop_id_x * tile_size_w);
1193 base_tdesc.xetla_format<uint32_t>(), coop_id_y * tile_size_h);
1194 prepare_tdesc(base_tdesc);
1195 }
1196
1197 inline prefetch_payload_t(dtype *p, int surface_width, int surface_height,
1198 int surface_pitch, int surface_offset_x, int surface_offset_y,
1199 uint32_t coop_id = 0) {
1200 uint32_t coop_id_x = coop_id % num_coop_sg_w;
1201 uint32_t coop_id_y = coop_id / num_coop_sg_w;
1202 xetla_tdescriptor base_tdesc;
1203 xetla_fill_tdesc(base_tdesc.xetla_format<uint32_t>(), p, surface_width,
1204 surface_height, surface_pitch,
1205 surface_offset_x + coop_id_x * tile_size_w,
1206 surface_offset_y + coop_id_y * tile_size_h);
1207 prepare_tdesc(base_tdesc);
1208 }
1209 // Be aware of the risks: Rule of three (copy constructor, copy assignment, destructor)
1210 // Please check if you need to add self-define destructor
1211 // ~prefetch_payload_t(){}
1212
1213 template <tdesc_update_dir update_dir = tdesc_update_dir::x_dir>
1214 __XETLA_API void update_tdesc(int offset) {
1215 auto tdesc_2d = tdesc_prefetch.xetla_format<uint32_t, num_tdesc, 16>();
1216 if constexpr (update_dir == tdesc_update_dir::x_dir) {
1217#pragma unroll
1218 for (uint32_t i = 0; i < num_tdesc; i++) {
1219 xetla_update_tdesc_offsetx(tdesc_2d.row(i), offset);
1220 }
1221 } else {
1222#pragma unroll
1223 for (uint32_t i = 0; i < num_tdesc; i++) {
1224 xetla_update_tdesc_offsety(tdesc_2d.row(i), offset);
1225 }
1226 }
1227 }
1228
1229private:
1230 __XETLA_API void prepare_tdesc(xetla_tdescriptor base_tdesc) {
1231 auto tdesc_2d = tdesc_prefetch.xetla_format<uint32_t, num_tdesc, 16>();
1232 uint32_t base_offset_y = 0;
1233#pragma unroll
1234 for (uint32_t i = 0; i < tile_size_h / block_size_h; i++) {
1235 auto tdesc_row_2d = tdesc_2d.xetla_select<num_block_w, 1, 16, 1>(
1236 i * num_block_w, 0);
1237 prepare_tile_desc_core<num_block_w, block_size_w, block_size_h>(
1238 tdesc_row_2d, base_tdesc, base_offset_y);
1239 base_offset_y += block_size_h;
1240 }
1241 if constexpr ((tile_size_h % block_size_h) != 0) {
1242 constexpr int i = tile_size_h / block_size_h;
1243 auto tdesc_row_2d = tdesc_2d.xetla_select<num_block_w, 1, 16, 1>(
1244 i * num_block_w, 0);
1245 constexpr uint32_t remain_size_y = tile_size_h % block_size_h;
1246 prepare_tile_desc_core<num_block_w, block_size_w, remain_size_y>(
1247 tdesc_row_2d, base_tdesc, base_offset_y);
1248 }
1249 }
1250
1251 template <int32_t num_tdesc, uint32_t size_x, uint32_t size_y>
1252 __XETLA_API static void prepare_tile_desc_core(
1253 xetla_matrix_ref<uint32_t, num_tdesc, 16> __REF__ tdesc_2d,
1254 xetla_tdescriptor base_tdesc, uint32_t base_offset_y) {
1255 uint32_t base_offset_x = 0;
1256#pragma unroll
1257 for (int j = 0; j < num_tdesc; j++) {
1258 tdesc_2d.row(j) = base_tdesc;
1259
1260 constexpr uint32_t block_widthx_widthy_arrlen
1261 = (size_x - 1) | ((size_y - 1) << 8);
1263 tdesc_2d.row(j), block_widthx_widthy_arrlen);
1264
1265 xetla_update_tdesc_offsetx(tdesc_2d.row(j), base_offset_x);
1266 xetla_update_tdesc_offsety(tdesc_2d.row(j), base_offset_y);
1267 base_offset_x += size_x;
1268 }
1269 }
1270};
1271
1278template <typename dtype_, uint32_t tile_size_x_, uint32_t block_size_x_,
1279 mem_layout mem_layout_, uint32_t alignment_, uint32_t num_coop_sg_,
1280 reg_layout reg_layout_, gpu_arch arch_tag_>
1282 mem_desc_t<dtype_, mem_layout_, mem_space::global, alignment_>,
1283 tile_desc_t<tile_size_x_, 1, block_size_x_, 1, reg_layout_>,
1284 num_coop_sg_, arch_tag_,
1285 std::enable_if_t<(arch_tag_ == gpu_arch::Xe)>> {
1286 using dtype = dtype_;
1289 // CL aligned, so we can use uint64_t
1290 using prefetch_dtype = uint64_t;
1293 static constexpr mem_space memory_space = mem_space::global;
1294 static constexpr mem_layout memory_layout = mem_layout_;
1295 static constexpr gpu_arch arch_tag = arch_tag_;
1296
1297private:
1298 // Fetches the entire CL.
1299 static constexpr uint32_t cacheline_elems = 64 / sizeof(dtype);
1300 static constexpr uint32_t mem_block_nums
1301 = (tile_desc::tile_size_x + cacheline_elems - 1) / cacheline_elems;
1302 static constexpr uint32_t num_coop_sg = num_coop_sg_;
1303
1304 // For mem_tile_nums < num_coop_sg cases, mem_tile_size_x will be CL length
1305 // which might lead to illegal read.
1306 // there are num_coop_sg threads to prefetch mem_block_nums
1307 // each thread will prefetch mem_tile_size_x elements
1308 static constexpr uint32_t mem_tile_size_x = mem_block_nums > num_coop_sg
1309 ? (mem_block_nums + num_coop_sg - 1) / num_coop_sg *cacheline_elems
1310 : 0;
1311 using this_payload_t
1313
1314 // Fixed prefetch_dtype, close this assertion
1315 // static_assert(sizeof(prefetch_dtype) >= 4,
1316 // "prefetch dtype size should at least DW aligned");
1317
1318public:
1319 static constexpr uint32_t scale_factor
1320 = sizeof(prefetch_dtype) / sizeof(dtype);
1321 uint32_t base_offset;
1324
1326 this->base_offset = rhs.base_offset;
1327 this->base_ptr = rhs.base_ptr;
1328 this->pitch_in_bytes = rhs.pitch_in_bytes;
1329 }
1330
1331 inline prefetch_payload_t() = default;
1332
1334 this->base_offset = rhs.base_offset;
1335 this->base_ptr = rhs.base_ptr;
1336 this->pitch_in_bytes = rhs.pitch_in_bytes;
1337 return *this;
1338 }
1339
1340 inline prefetch_payload_t(mem_desc_t &mem_desc, uint32_t coop_id = 0) {
1341 pitch_in_bytes = mem_desc.shape.stride * sizeof(dtype);
1342 uint32_t offset_x = mem_desc.coord.x;
1343 uint32_t offset_y = mem_desc.coord.y;
1344 base_offset = offset_y * pitch_in_bytes + offset_x * sizeof(dtype);
1345 uint64_t ptr_temp = (uint64_t)mem_desc.base.base;
1346 base_ptr = (prefetch_dtype *)ptr_temp
1347 + (coop_id % num_coop_sg) * mem_tile_size_x;
1348 }
1349
1350 inline prefetch_payload_t(dtype *p, [[maybe_unused]] int surface_width,
1351 [[maybe_unused]] int surface_height, int surface_pitch,
1352 int surface_offset_x, int surface_offset_y, uint32_t coop_id = 0) {
1353 pitch_in_bytes = surface_pitch * sizeof(dtype);
1354 uint32_t offset_x = surface_offset_x;
1355 uint32_t offset_y = surface_offset_y;
1356 base_offset = offset_y * pitch_in_bytes + offset_x * sizeof(dtype);
1357 base_ptr = (prefetch_dtype *)p
1358 + (coop_id % num_coop_sg) * mem_tile_size_x;
1359 }
1360
1361 template <tdesc_update_dir update_dir = tdesc_update_dir::x_dir>
1362 __XETLA_API void update_tdesc(int offset) {
1363 if constexpr (update_dir == tdesc_update_dir::x_dir) {
1364 base_offset += offset * sizeof(dtype);
1365 } else {
1366 base_offset += offset * pitch_in_bytes;
1367 }
1368 }
1369};
1370
1377template <typename dtype_, typename tile_desc_, mem_layout mem_layout_,
1378 uint32_t alignment_, uint32_t num_coop_sg_, gpu_arch arch_tag_>
1380 mem_desc_t<dtype_, mem_layout_, mem_space::local, alignment_>,
1381 tile_desc_, num_coop_sg_, arch_tag_,
1382 std::enable_if_t<(arch_tag_ == gpu_arch::Xe)>> {
1383 using dtype = dtype_;
1386 using tile_desc = tile_desc_;
1387 static constexpr mem_space memory_space = mem_space::local;
1388 static constexpr mem_layout memory_layout = mem_layout_;
1389 static constexpr gpu_arch arch_tag = arch_tag_;
1390
 // Shared local memory needs no prefetch: all constructors and update_tdesc
 // are deliberate no-ops so generic code can use this payload uniformly.
1391 inline prefetch_payload_t([[maybe_unused]] mem_desc_t &mem_desc,
1392 [[maybe_unused]] uint32_t coop_id = 0) {}
1393
1394 inline prefetch_payload_t([[maybe_unused]] dtype *p,
1395 [[maybe_unused]] int surface_width,
1396 [[maybe_unused]] int surface_height,
1397 [[maybe_unused]] int surface_pitch,
1398 [[maybe_unused]] int surface_offset_x,
1399 [[maybe_unused]] int surface_offset_y,
1400 [[maybe_unused]] uint32_t coop_id = 0) {}
1401
 // No descriptor to advance; intentionally empty.
1402 template <tdesc_update_dir update_dir = tdesc_update_dir::x_dir>
1403 __XETLA_API void update_tdesc([[maybe_unused]] int offset) {}
1404};
1405
1406} // namespace gpu::xetla::subgroup
#define __XETLA_API
Definition common.hpp:43
#define __REF__
Workaround for ESIMD reference usage.
Definition base_types.hpp:177
xetla_vector< uint32_t, 16 > xetla_tdescriptor
Description of nd tensor descriptor for load and store.
Definition base_types.hpp:155
__ESIMD_NS::simd< native_type_t< Ty >, N > xetla_vector
wrapper for xetla_vector.
Definition base_types.hpp:149
__XETLA_API void xetla_update_tdesc_offsetx(xetla_tdescriptor_ref tdesc, int32_t doffset_x)
Update the x coordinate in the given tensor descriptor.
Definition raw_send_load_store.hpp:152
__XETLA_API void xetla_update_tdesc_offsety(xetla_tdescriptor_ref tdesc, int32_t doffset_y)
Update the y coordinate in the given tensor descriptor.
Definition raw_send_load_store.hpp:161
__XETLA_API void xetla_fill_tdesc(xetla_tdescriptor_ref tdesc, Ty *p, int tensor_width, int tensor_height, int tensor_pitch, int offset_x, int offset_y)
Tensor descriptor construction(global memory version).
Definition raw_send_load_store.hpp:52
__XETLA_API void xetla_set_tensor_offset_x(xetla_tdescriptor_ref desc, int32_t offset_x)
Definition tensor_descriptor.hpp:63
__XETLA_API void xetla_set_block_widthx_widthy_arrlen(xetla_tdescriptor_ref desc, uint32_t block_widthx_widthy_arrlen)
Definition tensor_descriptor.hpp:79
__XETLA_API int32_t xetla_get_tensor_offset_x(xetla_tdescriptor desc)
Definition tensor_descriptor.hpp:67
Definition limitation.hpp:457
reg_layout
tile layout in register linear: linear layout with one tile tiled: 2d block stacked in raster order v...
Definition common.hpp:209
@ vnni_tiled_col_major
this is vnni tiled format, but for each block, they are stored in col major order
mem_space
Definition common.hpp:77
gpu_arch
Definition common.hpp:73
msg_type
Definition common.hpp:78
mem_layout
Definition common.hpp:76
Definition arch_config.hpp:72
Definition common.hpp:80
mem_payload_t(dtype *p, int surface_width, int surface_height, int surface_pitch, int surface_offset_x, int surface_offset_y)
Definition payload_xe.hpp:390
__XETLA_API void init(dtype *p, int surface_width, int surface_height, int surface_pitch, int surface_offset_x, int surface_offset_y)
Definition payload_xe.hpp:406
mem_payload_t(dtype *p, int surface_width, int surface_height, int surface_pitch, int surface_offset_x, int surface_offset_y)
Definition payload_xe.hpp:684
typename std::conditional<(alignment_in_bytes % sizeof(uint64_t)==0), uint64_t, typename std::conditional<(alignment_in_bytes % sizeof(uint32_t)==0), uint32_t, dtype >::type >::type mem_dtype
Definition payload_xe.hpp:636
__XETLA_API void init(dtype *p, int surface_width, int surface_height, int surface_pitch, int surface_offset_x, int surface_offset_y)
Definition payload_xe.hpp:723
mem_payload_t(dtype *p, uint32_t surface_width, uint32_t surface_height, uint32_t surface_pitch, int32_t surface_offset_x=0, int32_t surface_offset_y=0)
Definition payload_xe.hpp:95
__XETLA_API void init(dtype *p, uint32_t surface_width, uint32_t surface_height, uint32_t surface_pitch, int32_t surface_offset_x=0, int32_t surface_offset_y=0)
Definition payload_xe.hpp:115
__XETLA_API void init(dtype *p, int surface_width, int surface_height, int surface_pitch, int surface_offset_x, int surface_offset_y)
Definition payload_xe.hpp:279
mem_payload_t(dtype *p, int surface_width, int surface_height, int surface_pitch, int surface_offset_x, int surface_offset_y)
Definition payload_xe.hpp:261
typename std::conditional<(bytes_per_row % sizeof(uint64_t)==0) &&(alignment_in_bytes % sizeof(uint64_t)==0), uint64_t, typename std::conditional<(bytes_per_row % sizeof(uint32_t)==0), uint32_t, dtype >::type >::type mem_dtype
Definition payload_xe.hpp:246
mem_payload_t(uint32_t base, int surface_width, int surface_height, int surface_pitch, int surface_offset_x, int surface_offset_y)
Definition payload_xe.hpp:863
typename std::conditional<(block_bytes %(16 *sizeof(uint64_t))==0), uint64_t, typename std::conditional<(block_bytes %(16 *sizeof(uint32_t))==0), uint32_t, dtype >::type >::type mem_dtype
Definition payload_xe.hpp:830
__XETLA_API void init(uint32_t base, int surface_width, int surface_height, int surface_pitch, int surface_offset_x, int surface_offset_y)
Definition payload_xe.hpp:879
typename std::conditional<(bytes_per_row % sizeof(uint64_t)==0) &&(alignment_in_bytes % sizeof(uint64_t)==0), uint64_t, typename std::conditional<(bytes_per_row % sizeof(uint32_t)==0), uint32_t, dtype >::type >::type mem_dtype
Definition payload_xe.hpp:518
__XETLA_API void init(uint32_t base, int surface_width, int surface_height, int surface_pitch, int surface_offset_x, int surface_offset_y)
Definition payload_xe.hpp:548
mem_payload_t(uint32_t base, int surface_width, int surface_height, int surface_pitch, int surface_offset_x, int surface_offset_y)
Definition payload_xe.hpp:531
Is to illustrate the memory information.
Definition api.hpp:44
prefetch_payload_t(dtype *p, int surface_width, int surface_height, int surface_pitch, int surface_offset_x, int surface_offset_y, uint32_t coop_id=0)
Definition payload_xe.hpp:1394
prefetch_payload_t(dtype *p, int surface_width, int surface_height, int surface_pitch, int surface_offset_x, int surface_offset_y, uint32_t coop_id=0)
Definition payload_xe.hpp:1350
Is to illustrate the memory information to prefetch data to cache.
Definition api.hpp:53
Is to illustrate the tile information about a sub matrix.
Definition api.hpp:64
C++ API.