XeTLA v0.3.6
Intel® Xe Templates for Linear Algebra - API Definition Document
 
limitation.hpp
/*******************************************************************************
 * Copyright (c) 2022-2023 Intel Corporation
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 *******************************************************************************/

/// @file
/// C++ API

#pragma once

#define IN_RANGE(x, l, r) ((x) >= (l) && (x) <= (r))

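// Note: IN_RANGE is inclusive on both ends; for example, IN_RANGE(32, 1, 32)
// evaluates to true while IN_RANGE(0, 1, 32) evaluates to false.
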
namespace gpu::xetla {

namespace core {
template <gpu_arch arch, typename T>
struct general_1d {};
template <gpu_arch arch, typename T>
class block_2d {};

template <typename T>
struct general_1d<gpu_arch::Xe, T> {
    template <uint8_t NElts>
    static inline bool check_restriction(uint64_t offset, uint64_t p = 0) {
        static_assert(sizeof(T) == 4 || sizeof(T) == 8,
                "block 1d only supports D32/D64");
        static_assert(NElts == 1 || NElts == 2 || NElts == 3 || NElts == 4
                        || NElts == 8 || NElts == 16 || NElts == 32
                        || NElts == 64,
                "block 1d does not support this NElts");

        bool ret = ((p + offset) % sizeof(T) == 0);
        XETLA_ASSERT(
                ret, "block 1d requires size-aligned data but offset is %u",
                offset);
        return ret;
    }

    template <uint8_t NElts, int N, typename Toffset = uint32_t>
    static inline bool check_restriction(
            xetla_vector<Toffset, N> offsets, uint64_t p = 0) {
        static_assert(
                N == 1 || N == 2 || N == 4 || N == 8 || N == 16 || N == 32,
                "scatter does not support this N");
        if constexpr (NElts != 1) {
            static_assert(NElts == 2 || NElts == 3 || NElts == 4 || NElts == 8,
                    "scatter does not support this NElts");
            for (auto i = 0; i < N; i++) {
                bool ret = ((p + offsets[i]) % sizeof(T) == 0);
                XETLA_ASSERT(ret,
                        "scatter requires size-aligned data but offset is %u",
                        offsets[i]);
                if (!ret) { return false; }
            }
        }

        return true;
    }
};
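
// A minimal usage sketch (hypothetical offset value; not part of the original
// file): the scalar overload takes a byte offset, which must be aligned to
// sizeof(T).
//
//   bool ok = core::general_1d<gpu_arch::Xe, float>::check_restriction<8>(
//           /*offset=*/64); // 64 % sizeof(float) == 0, so this passes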

template <typename T>
class block_2d<gpu_arch::Xe, T> {
public:
    template <bool transpose, bool vnni_transform>
    static inline bool check_load(xetla_tdescriptor tdesc) {
        if (!check_common(tdesc)) { return false; }

        bool ret = false;
        uint32_t block_width = xetla_get_block_width_x(tdesc);
        uint32_t block_height = xetla_get_block_width_y(tdesc);
        uint8_t array_len = xetla_get_block_array_len(tdesc);
        int32_t block_start_x = xetla_get_tensor_offset_x(tdesc);

        ret = ((block_width * block_height * element_size)
                <= (32 * bytes_per_grf));
        XETLA_ASSERT(ret,
                "2D block loads can read up to 32 GRFs but is %u:%u",
                block_width, block_height);
        if (!ret) { return false; }

        if constexpr (transpose || vnni_transform) {
            if constexpr (transpose) {
                ret = (array_len == 1);
                XETLA_ASSERT(ret,
                        "Transposed loads do not allow an array length "
                        "greater than 1 but is %d",
                        array_len);
                if (!ret) { return false; }

                ret = (element_size == 4 || element_size == 8);
                XETLA_ASSERT(ret,
                        "For 2D block load with transpose, the only allowed "
                        "data sizes are d32 and d64 but is %u",
                        element_size);
                if (!ret) { return false; }

                auto block_height_in_bytes = block_height * element_size;
                ret = (block_height_in_bytes == 4 || block_height_in_bytes == 8
                        || block_height_in_bytes == 16
                        || block_height_in_bytes == 32
                        || block_height_in_bytes == 64
                        || block_height_in_bytes == 128);
                XETLA_ASSERT(ret,
                        "For 2D load with transpose, the pre-operation block "
                        "height is padded up to the next power-of-two value "
                        "(minimum 4 bytes) but is %u:%u",
                        block_height, element_size);
                if (!ret) { return false; }

                if (element_size == 4) {
                    ret = (IN_RANGE(block_height, 1, 32)
                            && IN_RANGE(block_width, 1, 8));
                    XETLA_ASSERT(ret,
                            "block height must be in 1 ~ 32 and width in "
                            "1 ~ 8 but is %u:%u",
                            block_height, block_width);
                    if (!ret) { return false; }
                } else {
                    ret = (block_start_x >= 0);
                    XETLA_ASSERT(ret,
                            "When element is d64, block X offset must be "
                            "non-negative but is %d",
                            block_start_x);
                    if (!ret) { return false; }

                    ret = ((block_height == 8)
                            && (block_width == 1 || block_width == 2
                                    || block_width == 4));
                    XETLA_ASSERT(ret,
                            "block height must be 8 and width 1/2/4 but is "
                            "%u:%u",
                            block_height, block_width);
                    if (!ret) { return false; }
                }
            }

            if constexpr (vnni_transform) {
                ret = (element_size == 1 || element_size == 2);
                XETLA_ASSERT(ret,
                        "For 2D block load with VNNI transform, the only "
                        "allowed data sizes are d8 and d16 but is %u",
                        element_size);
                if (!ret) { return false; }

                if constexpr (element_size == 1) {
                    ret = IN_RANGE(block_height, 4, 32)
                            && IN_RANGE(block_width, 4, 16)
                            && (array_len == 1 || array_len == 2
                                    || array_len == 4)
                            && (block_width * array_len) <= 64
                            && (block_height * element_size) % 4 == 0;
                    XETLA_ASSERT(ret,
                            "For 2D block load with VNNI transform, height "
                            "must be in 4 ~ 32, width in 4 ~ 16, array len "
                            "1/2/4, and width x array_len <= 64 but is "
                            "%u:%u:%u",
                            block_height, block_width, array_len);
                    if (!ret) { return false; }
                } else {
                    ret = IN_RANGE(block_height, 2, 32)
                            && IN_RANGE(block_width, 2, 16)
                            && (array_len == 1 || array_len == 2
                                    || array_len == 4)
                            && (block_width * array_len) <= 32
                            && (block_height * element_size) % 4 == 0;
                    XETLA_ASSERT(ret,
                            "For 2D block load with VNNI transform, height "
                            "must be in 2 ~ 32, width in 2 ~ 16, array len "
                            "1/2/4, and width x array_len <= 32 but is "
                            "%u:%u:%u",
                            block_height, block_width, array_len);
                    if (!ret) { return false; }
                }
            }
        } else {
            ret = ((block_width * array_len * element_size) <= 64);
            XETLA_ASSERT(ret,
                    "For 2D block loads, block_width x array_len x element "
                    "size must not exceed 64 bytes but is %u:%u",
                    block_width, array_len);
            if (!ret) { return false; }
        }

        return true;
    }
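    // Recap of the limits enforced above: a plain 2D block load must satisfy
    // block_width * array_len * element_size <= 64 bytes; a transposed load
    // requires d32/d64 with array_len == 1; a VNNI-transformed load requires
    // d8/d16. All variants are additionally bounded by 32 GRFs of payload.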

    static inline bool check_store(xetla_tdescriptor tdesc) {
        if (!check_common(tdesc)) { return false; }

        uint32_t block_width = xetla_get_block_width_x(tdesc);
        uint32_t block_height = xetla_get_block_width_y(tdesc);
        uint8_t array_len = xetla_get_block_array_len(tdesc);

        bool ret = false;

        ret = (element_size == 1 || element_size == 2 || element_size == 4
                || element_size == 8);
        XETLA_ASSERT(ret,
                "2D block store only supports element sizes 1/2/4/8 but is "
                "%u",
                element_size);
        if (!ret) { return false; }

        ret = IN_RANGE(block_height, 1, 8);
        XETLA_ASSERT(ret, "2D block height only supports 1 ~ 8 but is %u",
                block_height);
        if (!ret) { return false; }

        ret = (array_len == 1);
        XETLA_ASSERT(
                ret, "2D block array len only supports 1 but is %u", array_len);
        if (!ret) { return false; }

        ret = ((block_width * element_size) >= 4
                && (block_width * element_size) <= 64
                && (block_width * block_height * element_size) <= 512);
        XETLA_ASSERT(ret,
                "For 2D block store, block width x element size must be in "
                "4 ~ 64 and total GRF data must not exceed 512 bytes but is "
                "%u:%u:%u",
                block_width, block_height, element_size);
        if (!ret) { return false; }

        return true;
    }

    static inline bool check_surface(
            uint64_t base, uint32_t width, uint32_t height, uint32_t pitch) {
        if (check_base_address(base) && check_surface_width(width)
                && check_surface_height(height)
                && check_surface_pitch(pitch, width)) {
            return true;
        }

        return false;
    }
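    // Example (hypothetical shape): a 1024 x 512 fp32 surface with a
    // 4096-byte pitch passes all four checks, provided the base address is
    // 64-byte aligned: width = 1024 * 4 = 4096 bytes (>= 64 B, DW-aligned,
    // < 2^24) and pitch == width (>= 64 B, a multiple of 8, < 2^24).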

private:
    static constexpr auto element_size = sizeof(T);
    static constexpr uint32_t max_24bit = 16 * 1024 * 1024; // 2^24
    static constexpr auto bytes_per_grf
            = register_attr_t<grf_mode::double_grf, gpu_arch::Xe>::reg_in_bytes;

    static inline bool check_base_address(uint64_t base) {
        bool ret = ((base % 64) == 0);
        XETLA_ASSERT(
                ret, "Base address must be CL (64B) aligned but is %p", base);
        if (!ret) { return false; }

        return ret;
    }

    static inline bool check_surface_width(uint32_t width_in_bytes) {
        bool ret = (width_in_bytes <= max_24bit);
        XETLA_ASSERT(ret,
                "Only 24 bits are supported for surface width (in bytes) "
                "but is %u",
                width_in_bytes);
        if (!ret) { return false; }

        ret = (width_in_bytes >= 64);
        XETLA_ASSERT(ret,
                "Surface width (in bytes) must be greater than or equal to "
                "64B but is %u",
                width_in_bytes);
        if (!ret) { return false; }

        ret = ((width_in_bytes % (element_size > 4 ? element_size : 4)) == 0);
        XETLA_ASSERT(ret,
                "Surface width (in bytes) must be aligned to MAX(DW, "
                "element_size) but is %u",
                width_in_bytes);
        if (!ret) { return false; }

        return ret;
    }

    static inline bool check_surface_height(uint32_t height_in_elements) {
        bool ret = (height_in_elements < max_24bit);
        XETLA_ASSERT(ret,
                "Only 24 bits are supported for surface height (in "
                "elements) but is %u",
                height_in_elements);
        if (!ret) { return false; }

        return ret;
    }

    static inline bool check_surface_pitch(
            uint32_t pitch_in_bytes, uint32_t width_in_bytes) {
        bool ret = (pitch_in_bytes >= width_in_bytes);
        XETLA_ASSERT(ret,
                "Pitch (in bytes) must be greater than or equal to width "
                "(in bytes) but is %u:%u",
                pitch_in_bytes, width_in_bytes);
        if (!ret) { return false; }

        ret = (pitch_in_bytes <= max_24bit);
        XETLA_ASSERT(ret,
                "Only 24 bits are supported for surface pitch (in bytes) "
                "but is %u",
                pitch_in_bytes);
        if (!ret) { return false; }

        ret = (pitch_in_bytes >= 64);
        XETLA_ASSERT(ret,
                "Surface pitch (in bytes) must be greater than or equal to "
                "64B but is %u",
                pitch_in_bytes);
        if (!ret) { return false; }

        ret = ((pitch_in_bytes % 8) == 0);
        XETLA_ASSERT(ret,
                "Surface pitch (in bytes) must be a multiple of QW (8 "
                "bytes) but is %u",
                pitch_in_bytes);
        if (!ret) { return false; }

        return ret;
    }

    static inline bool check_block_start_x(int32_t x_in_elements) {
        bool ret = true;

        if constexpr (element_size == 1 || element_size == 2) {
            ret = ((element_size == 1 && (x_in_elements % 4) == 0)
                    || (element_size == 2 && (x_in_elements % 2) == 0));
            XETLA_ASSERT(ret,
                    "For element d8, block start X must be a multiple of 4; "
                    "for element d16, block start X must be a multiple of 2; "
                    "but is %d:%u",
                    x_in_elements, element_size);
            if (!ret) { return false; }
        }

        return ret;
    }

    static inline bool check_block_width(uint32_t width_in_elements) {
        bool ret = IN_RANGE(width_in_elements, 1, 64);
        XETLA_ASSERT(ret,
                "Block width (in elements) must be between 1 and 64 but is %u",
                width_in_elements);
        if (!ret) { return false; }

        auto width_in_bytes = width_in_elements * element_size;

        ret = (width_in_bytes % 4 == 0);
        XETLA_ASSERT(ret,
                "Block width in bytes must be a multiple of 4 bytes but is %u",
                width_in_bytes);
        if (!ret) { return false; }

        ret = (width_in_bytes == 4 || width_in_bytes == 8
                || width_in_bytes == 16 || width_in_bytes == 32
                || width_in_bytes == 64 || width_in_bytes == 128
                || width_in_bytes == 256 || width_in_bytes == 512);
        XETLA_ASSERT(ret,
                "For 2D block load/store, the block width in bytes is padded "
                "up to 2^X in the GRF but is %u:%u",
                width_in_elements, element_size);
        if (!ret) { return false; }

        return ret;
    }

    static inline bool check_block_height(uint32_t height_in_elements) {
        bool ret = IN_RANGE(height_in_elements, 1, 32);
        XETLA_ASSERT(ret,
                "Block height (in elements) must be between 1 and 32 but is %u",
                height_in_elements);
        if (!ret) { return false; }
        return ret;
    }

    static inline bool check_array_len(uint32_t len) {
        bool ret = IN_RANGE(len, 1, 4);
        XETLA_ASSERT(ret, "Array length must be in 1 ~ 4 but is %u", len);
        if (!ret) { return false; }
        return ret;
    }

    static inline bool check_block(int32_t x, [[maybe_unused]] int32_t y,
            uint32_t width, uint32_t height, uint8_t array_len) {
        if (check_block_start_x(x) && check_block_width(width)
                && check_block_height(height) && check_array_len(array_len)) {
            return true;
        }
        return false;
    }

    static inline bool check_common(xetla_tdescriptor tdesc) {
        uint64_t base = xetla_get_tensor_base_address(tdesc);
        uint32_t surface_width = xetla_get_tensor_width_x(tdesc);
        uint32_t surface_height = xetla_get_tensor_width_y(tdesc);
        uint32_t surface_pitch = xetla_get_tensor_pitch_x(tdesc);
        int32_t block_start_x = xetla_get_tensor_offset_x(tdesc);
        int32_t block_start_y = xetla_get_tensor_offset_y(tdesc);

        uint32_t block_width = xetla_get_block_width_x(tdesc);
        uint32_t block_height = xetla_get_block_width_y(tdesc);
        uint8_t array_len = xetla_get_block_array_len(tdesc);

        if (check_surface(base, surface_width, surface_height, surface_pitch)
                && check_block(block_start_x, block_start_y, block_width,
                        block_height, array_len)) {
            return true;
        }

        return false;
    }
};
} // namespace core

namespace subgroup {
template <gpu_arch arch, typename dtype, typename mem_dtype>
struct check_load {};
template <gpu_arch arch, typename dtype, typename mem_dtype = uint32_t>
struct check_store {};

template <typename dtype, typename mem_dtype>
struct check_load<gpu_arch::Xe, dtype, mem_dtype> {
    template <bool mem_transform, size_t block_size_x>
    struct global_2d {
        using load_store_attr = typename arch_attr_t<
                gpu_arch::Xe>::template load_store_attr<msg_type::block_2d>;
        static constexpr int32_t max_vnni_block_width
                = load_store_attr::max_vnni_load_width_in_elems;
        static_assert(!mem_transform || block_size_x <= max_vnni_block_width,
                "For VNNI transform, the maximum block width is 16 elements.");
        static constexpr int32_t max_block_width
                = load_store_attr::max_load_width_in_bytes / sizeof(dtype);
        static_assert((max_block_width % block_size_x) == 0,
                "max_block_width should be a multiple of block_size_x.");
    };
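    // For instance, assuming max_load_width_in_bytes is 64 on this
    // architecture, dtype = bf16 gives max_block_width = 32 elements, so
    // block_size_x must be one of 1/2/4/8/16/32.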

    struct global_1d {
        static_assert(sizeof(mem_dtype) == 4 || sizeof(mem_dtype) == 8,
                "tile 1d only supports D32/D64");
    };

    template <bool mem_transform, size_t block_size_x>
    struct unaligned_2d {
        using load_store_attr = typename arch_attr_t<
                gpu_arch::Xe>::template load_store_attr<msg_type::block_2d>;
        static constexpr int32_t max_vnni_block_width
                = load_store_attr::max_vnni_load_width_in_elems;
        static_assert(!mem_transform || block_size_x <= max_vnni_block_width,
                "For VNNI transform, the maximum block width is 16 elements.");
        static constexpr int32_t max_block_width
                = load_store_attr::max_load_width_in_bytes / sizeof(dtype);
        static_assert((max_block_width % block_size_x) == 0,
                "max_block_width should be a multiple of block_size_x.");
    };

    template <mem_layout memory_layout, size_t block_size_x, size_t tile_bytes,
            size_t min_bytes, size_t block_bytes, size_t num_channel_x,
            size_t num_channel>
    struct local_scatter {
        static_assert(memory_layout == mem_layout::row_major,
                "only row major is supported in local load; you can use "
                "local store to do the transpose");
        static_assert(sizeof(mem_dtype) >= 4,
                "load size should be at least DW aligned");
        static_assert((block_size_x * sizeof(dtype) % sizeof(mem_dtype)) == 0,
                "bytes per row should be a multiple of sizeof(load_dtype)");
        static_assert(((tile_bytes % min_bytes) == 0
                              && (block_bytes % min_bytes) == 0),
                "currently, we are not able to handle this corner case");
        static_assert((num_channel_x > 0) && (num_channel_x <= num_channel),
                "The number of SIMD channels in x should be greater than 0 "
                "and not exceed num_channel");
    };

    struct local_1d {
        static_assert(sizeof(mem_dtype) == 4 || sizeof(mem_dtype) == 8,
                "tile 1d only supports D32/D64");
    };
};
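
// A minimal compile-time usage sketch (hypothetical parameters; not part of
// the original file). Instantiating one of the nested checker structs fires
// its static_asserts at compile time:
//
//   using load_checks = subgroup::check_load<gpu_arch::Xe, bf16, uint32_t>;
//   using checked = load_checks::global_2d</*mem_transform=*/true,
//           /*block_size_x=*/16>;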

template <typename dtype, typename mem_dtype>
struct check_store<gpu_arch::Xe, dtype, mem_dtype> {
    template <size_t block_size_x>
    struct global_2d {
        using load_store_attr = typename arch_attr_t<
                gpu_arch::Xe>::template load_store_attr<msg_type::block_2d>;

        static constexpr int32_t max_block_width
                = load_store_attr::max_load_width_in_bytes / sizeof(dtype);
        static_assert((max_block_width % block_size_x) == 0,
                "max_block_width should be a multiple of block_size_x.");
    };

    struct global_1d {
        static_assert(sizeof(mem_dtype) == 4 || sizeof(mem_dtype) == 8,
                "tile 1d only supports D32/D64");
    };

    template <size_t block_size_x>
    struct unaligned_2d {
        using load_store_attr = typename arch_attr_t<
                gpu_arch::Xe>::template load_store_attr<msg_type::block_2d>;

        static constexpr int32_t max_block_width
                = load_store_attr::max_load_width_in_bytes / sizeof(dtype);
        static_assert((max_block_width % block_size_x) == 0,
                "max_block_width should be a multiple of block_size_x.");
    };

554
555 template <size_t tile_bytes, size_t min_store_bytes, size_t block_bytes,
556 size_t num_channel_x, size_t num_channel>
557 struct global_atomic {
558 static_assert(std::is_same<remove_const_t<dtype>, uint32_t>::value
559 || std::is_same<remove_const_t<dtype>, uint64_t>::value
560 || std::is_same<remove_const_t<dtype>, int>::value
561 || std::is_same<remove_const_t<dtype>, float>::value
562 || std::is_same<remove_const_t<dtype>, double>::value,
563 "for global atomic add, we only support fp32 and fp64, "
564 "uin32_t, uint64_t, int");
565 static_assert(((tile_bytes % min_store_bytes) == 0
566 && (block_bytes % min_store_bytes) == 0),
567 "currently, we are not able to handle the corner case");
568 static_assert((num_channel_x > 0) && (num_channel_x <= num_channel),
569 "The number of simd channel x should be greater than 0 and "
570 "less "
571 "than num_channel");
572 static_assert(sizeof(dtype) >= 4, "Only support DW and QW atomic add");
573 };
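    // For example, dtype = float (DW) or double (QW) satisfies the
    // atomic-add dtype check above, while bf16 would be rejected at compile
    // time.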

    template <size_t tile_bytes, size_t min_bytes, size_t block_bytes,
            size_t num_channel_x, size_t num_channel>
    struct local_scatter {
        static_assert(((tile_bytes % min_bytes) == 0
                              && (block_bytes % min_bytes) == 0),
                "currently, we are not able to handle this corner case");
        static_assert((num_channel_x > 0) && (num_channel_x <= num_channel),
                "The number of SIMD channels in x should be greater than 0 "
                "and not exceed num_channel");
    };

    template <size_t tile_bytes, size_t min_store_bytes, size_t block_bytes,
            size_t num_channel_x, size_t num_channel>
    struct local_scatter_vnni_col {
        static_assert(((tile_bytes % min_store_bytes) == 0
                              && (block_bytes % min_store_bytes) == 0),
                "currently, we are not able to handle this corner case");

        static_assert((num_channel_x > 0) && (num_channel_x <= num_channel),
                "The number of SIMD channels in x should be greater than 0 "
                "and not exceed num_channel");
    };

    struct local_1d {
        static_assert(sizeof(mem_dtype) == 4 || sizeof(mem_dtype) == 8,
                "tile 1d only supports D32/D64");
    };
};
} // namespace subgroup

namespace group {
template <gpu_arch arch>
struct gemm {};

template <>
struct gemm<gpu_arch::Xe> {
    struct default_fpu {
        template <typename dtype_a, typename dtype_b, typename dtype_mma_a,
                typename dtype_mma_b, typename dtype_mma_acc>
        struct check_dtype_default {
            static_assert(
                    std::is_same<remove_const_t<dtype_mma_a>, float>::value,
                    "currently only sgemm is supported");
            static_assert(
                    std::is_same<remove_const_t<dtype_mma_b>, float>::value,
                    "currently only sgemm is supported");
            static_assert(
                    std::is_same<remove_const_t<dtype_mma_acc>, float>::value,
                    "currently only sgemm is supported");
        };

        template <mem_layout mem_layout_a, mem_layout mem_layout_b,
                mem_space mem_space_a, mem_space mem_space_b>
        struct check_memory_default {
            static constexpr bool is_col_major_a
                    = mem_layout_a == mem_layout::col_major;
            static constexpr bool is_col_major_b
                    = mem_layout_b == mem_layout::col_major;
            static constexpr bool is_local_a = mem_space_a == mem_space::local;
            static constexpr bool is_local_b = mem_space_b == mem_space::local;
            static_assert(!is_local_a,
                    "loading matA from local memory is not currently "
                    "supported");
            static_assert(!is_local_b,
                    "loading matB from local memory is not currently "
                    "supported");
        };

        template <typename dtype_mma, int tile_size_x_a, int tile_size_y_a,
                int block_size_x_a, int block_size_y_a, int tile_size_x_b,
                int tile_size_y_b, int block_size_x_b, int block_size_y_b>
        struct check_tile_size_default {
            static constexpr uint32_t reg_in_bytes
                    = register_attr_t<grf_mode::double_grf,
                            gpu_arch::Xe>::reg_in_bytes;
            static constexpr uint32_t simd_len
                    = reg_in_bytes / sizeof(dtype_mma);

            static_assert((block_size_x_b % simd_len == 0),
                    "block_size_x_b should be a multiple of simd_len");
            static_assert((tile_size_x_a % block_size_x_a) == 0);
            static_assert((tile_size_y_b % block_size_y_b) == 0);
            static_assert(block_size_x_a == block_size_y_b);
        };
    };
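    // For example, assuming reg_in_bytes is 64 in double-GRF mode,
    // dtype_mma = float gives simd_len = 16, so block_size_x_b must be a
    // multiple of 16.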

    struct default_xmx {
        template <typename dtype_a, typename dtype_b, typename dtype_mma_a,
                typename dtype_mma_b>
        struct check_dtype_default {
            static_assert(std::is_same<remove_const_t<dtype_mma_a>,
                                  remove_const_t<dtype_mma_b>>::value,
                    "dtype_mma_a should be the same as dtype_mma_b in Xe "
                    "arch");
            static_assert((sizeof(dtype_mma_a) == sizeof(dtype_a))
                            || (sizeof(dtype_mma_a) == 2 * sizeof(dtype_a))
                            || (2 * sizeof(dtype_mma_a) == sizeof(dtype_a)),
                    "We currently cannot support fp32 <-> fp8, since it "
                    "would hit a lot of HW limitations.");
            static_assert((sizeof(dtype_mma_b) == sizeof(dtype_b))
                            || (sizeof(dtype_mma_b) == 2 * sizeof(dtype_b))
                            || (2 * sizeof(dtype_mma_b) == sizeof(dtype_b)),
                    "We currently cannot support fp32 <-> fp8, since it "
                    "would hit a lot of HW limitations.");
        };

        template <mem_layout mem_layout_a, mem_layout mem_layout_b,
                mem_space mem_space_a, mem_space mem_space_b>
        struct check_memory_default {
            static constexpr bool is_col_major_a
                    = mem_layout_a == mem_layout::col_major;
            static constexpr bool is_col_major_b
                    = mem_layout_b == mem_layout::col_major;
            static constexpr bool is_local_a = mem_space_a == mem_space::local;
            static constexpr bool is_local_b = mem_space_b == mem_space::local;
            static_assert(!is_local_a || !is_col_major_a,
                    "if matA is loaded from local memory, then matA should "
                    "be row-major");
            static_assert(!is_local_b || !is_col_major_b,
                    "if matB is loaded from local memory, then matB should "
                    "be row-major");
        };

        template <typename dtype_mma, int tile_size_x_a, int tile_size_y_a,
                int block_size_x_a, int block_size_y_a, int tile_size_x_b,
                int tile_size_y_b, int block_size_x_b, int block_size_y_b>
        struct check_tile_size_default {
            using mma_attr = mma_attr_t<gpu_arch::Xe>;
            static constexpr int32_t mma_m = mma_attr::mma_m_in_elem;
            static constexpr int32_t mma_n = mma_attr::mma_n_in_elem;
            static constexpr int32_t mma_k
                    = mma_attr::mma_k_in_bytes / sizeof(dtype_mma);

            static_assert(tile_size_x_a % mma_k == 0,
                    "tile_size_x_a should be a multiple of mma_k");
            static_assert(block_size_x_a == mma_k,
                    "block_size_x_a should be equal to mma_k");
            static_assert(tile_size_y_a % mma_m == 0,
                    "tile_size_y_a should be a multiple of mma_m");
            static_assert(block_size_y_a % mma_m == 0,
                    "block_size_y_a should be a multiple of mma_m");

            static_assert(tile_size_x_b % mma_n == 0,
                    "tile_size_x_b should be a multiple of mma_n");
            static_assert(block_size_x_b == mma_n,
                    "block_size_x_b should be equal to mma_n");
            static_assert(tile_size_y_b % mma_k == 0,
                    "tile_size_y_b should be a multiple of mma_k");
            static_assert(block_size_y_b % mma_k == 0,
                    "block_size_y_b should be a multiple of mma_k");
        };
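        // For example, assuming mma_attr reports mma_m = 8, mma_n = 16, and
        // mma_k_in_bytes = 32 (hypothetical values), dtype_mma = bf16 gives
        // mma_k = 16, so block_size_x_a must equal 16 and block_size_x_b
        // must equal 16.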
    };
};
} // namespace group

namespace kernel {
template <gpu_arch arch, typename T>
class general_1d {};
template <gpu_arch arch, typename T>
class block_2d {};

template <typename T>
class general_1d<gpu_arch::Xe, T> {
public:
    static inline bool check_alignment(T *base, uint32_t pitch) {
        auto pitch_in_bytes = pitch * element_size;
        bool ret = ((pitch_in_bytes % pitch_alignment_bytes) == 0);
        XETLA_ASSERT(ret,
                "Pitch in bytes must be a multiple of 4 but is %u:%u",
                pitch, element_size);
        if (!ret) { return false; }

        ret = (pitch_in_bytes >= min_pitch_bytes);
        XETLA_ASSERT(ret,
                "Pitch in bytes must be greater than or equal to 4B but is "
                "%u:%u",
                pitch, element_size);
        if (!ret) { return false; }

        ret = ((((uint64_t)base) % base_alignment_bytes) == 0);
        XETLA_ASSERT(ret, "Base address must be 4B aligned but is %p", base);
        return ret;
    }

private:
    static constexpr size_t element_size = sizeof(T);
    static constexpr size_t pitch_alignment_bytes = 4;
    static constexpr size_t min_pitch_bytes = 4;
    static constexpr size_t base_alignment_bytes = 4;
};
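
// A minimal usage sketch (hypothetical pointer and pitch; not part of the
// original file):
//
//   float *buf = ...; // assumed to come from a 4B-aligned allocation
//   bool ok = kernel::general_1d<gpu_arch::Xe, float>::check_alignment(
//           buf, /*pitch=*/256); // 256 * 4 bytes, a multiple of 4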

template <typename T>
class block_2d<gpu_arch::Xe, T> {
public:
    static inline bool check_tensor(
            uint64_t base, uint32_t width, uint32_t height, uint32_t pitch) {
        return core::block_2d<gpu_arch::Xe, T>::check_surface(
                base, width * element_size, height, pitch * element_size);
    }

private:
    static constexpr size_t element_size = sizeof(T);
};
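// Example (hypothetical dimensions): for T = float, a 1024 x 512 tensor with
// a pitch of 1024 elements converts to a 4096-byte-wide surface with a
// 4096-byte pitch, which core::block_2d::check_surface accepts when the base
// address is 64B-aligned.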
} // namespace kernel

} // namespace gpu::xetla