30 #include <type_traits>
33 inline namespace _V1 {
36 namespace experimental::matrix {
38 template <
architecture u,
typename Ta,
typename Tb,
typename Tc,
39 typename Td = Tc,
size_t sM = 0,
size_t sN = 0,
size_t sK = 0,
40 typename Enabled =
void>
43 template <
typename Ta,
typename Tb,
typename Tc>
46 if ((std::is_same_v<Ta, int8_t> && std::is_same_v<Tb, int8_t> &&
47 std::is_same_v<Tc, int> && sM <= 16 && sN <= 16 && sK <= 64) ||
48 (std::is_same_v<Ta, uint8_t> && std::is_same_v<Tb, uint8_t> &&
49 std::is_same_v<Tc, int> && sM <= 16 && sN <= 16 && sK <= 64) ||
50 (std::is_same_v<Ta, int8_t> && std::is_same_v<Tb, uint8_t> &&
51 std::is_same_v<Tc, int> && sM <= 16 && sN <= 16 && sK <= 64) ||
52 (std::is_same_v<Ta, uint8_t> && std::is_same_v<Tb, int8_t> &&
53 std::is_same_v<Tc, int> && sM <= 16 && sN <= 16 && sK <= 64) ||
55 (std::is_same_v<Ta, unsigned short> &&
56 std::is_same_v<Tb, unsigned short> && std::is_same_v<Tc, float> &&
57 sM <= 16 && sN <= 16 && sK <= 32))
63 template <
typename Ta,
typename Tb,
typename Tc>
65 if ((std::is_same_v<Ta, int8_t> && std::is_same_v<Tb, int8_t> &&
66 std::is_same_v<Tc, int>) ||
67 (std::is_same_v<Ta, uint8_t> && std::is_same_v<Tb, uint8_t> &&
68 std::is_same_v<Tc, int>) ||
69 (std::is_same_v<Ta, int8_t> && std::is_same_v<Tb, uint8_t> &&
70 std::is_same_v<Tc, int>) ||
71 (std::is_same_v<Ta, uint8_t> && std::is_same_v<Tb, int8_t> &&
72 std::is_same_v<Tc, int>) ||
73 (std::is_same_v<Ta, unsigned short> &&
74 std::is_same_v<Tb, unsigned short> && std::is_same_v<Tc, float>))
82 template <
typename Ta,
typename Tb,
typename Tc,
typename Td>
85 typename std::enable_if<(!std::is_same_v<Ta, void> &&
86 !std::is_same_v<Tb, void> &&
87 !std::is_same_v<Tc, void>)>::type> {
88 static_assert((are_types_valid_amx<Ta, Tb, Tc>()),
89 "Invalid types for AMX, supported types are int8_t, uint8_t, "
90 "and bf16 (Note that unsigned short should be used in the"
91 "DPC++ code to implement bf16) ");
94 static constexpr std::size_t M = 16;
95 static constexpr std::size_t N = 16;
96 static constexpr std::size_t K = ((
sizeof(Ta) == 1) ? 64 : 32);
98 template <
typename Group, layout Layout>
100 template <
typename Group, layout Layout>
102 template <
typename Group>
104 template <
typename Group>
110 template <
typename Ta,
typename Tb,
typename Tc,
typename Td,
size_t sM,
111 size_t sN,
size_t sK>
114 typename std::enable_if<(
115 !std::is_same_v<Ta, void> && !std::is_same_v<Tb, void> &&
116 !std::is_same_v<Tc, void> && sM != 0 && sN != 0 && sK != 0)>::type> {
119 (sM == 0 && sN == 0 && sK == 0) ||
120 (is_combination_valid_amx<Ta, Tb, Tc>(sM, sN, sK)),
121 "Invalid parameters for AMX, query valid types and maximum sizes "
122 "using: matrix_params<architecture::intel_cpu_spr> myparams; and then "
124 "myparams.combinations array");
128 static constexpr std::size_t M = sM;
129 static constexpr std::size_t N = sN;
130 static constexpr std::size_t K = sK;
132 template <
typename Group, layout Layout>
134 template <
typename Group, layout Layout>
136 template <
typename Group>
138 template <
typename Group>
147 template <
typename Ta,
typename Tb,
typename Tc>
149 if ((std::is_same_v<Ta, int8_t> && std::is_same_v<Tb, int8_t> &&
150 std::is_same_v<Tc, int> && (sM >= 1 && sM <= 8) && sN == 8 &&
152 (std::is_same_v<Ta, int8_t> && std::is_same_v<Tb, uint8_t> &&
153 std::is_same_v<Tc, int> && (sM >= 1 && sM <= 8) && sN == 8 &&
155 (std::is_same_v<Ta, uint8_t> && std::is_same_v<Tb, int8_t> &&
156 std::is_same_v<Tc, int> && (sM >= 1 && sM <= 8) && sN == 8 &&
158 (std::is_same_v<Ta, uint8_t> && std::is_same_v<Tb, uint8_t> &&
159 std::is_same_v<Tc, int> && (sM >= 1 && sM <= 8) && sN == 8 &&
161 (std::is_same_v<Ta, half> && std::is_same_v<Tb, half> &&
162 std::is_same_v<Tc, float> && (sM >= 1 && sM <= 8) && sN == 8 &&
164 (std::is_same_v<Ta, unsigned short> &&
165 std::is_same_v<Tb, unsigned short> && std::is_same_v<Tc, float> &&
166 (sM >= 1 && sM <= 8) && sN == 8 && sK == 16))
172 template <
typename Ta,
typename Tb,
typename Tc>
174 if ((std::is_same_v<Ta, int8_t> && std::is_same_v<Tb, int8_t> &&
175 std::is_same_v<Tc, int>) ||
176 (std::is_same_v<Ta, uint8_t> && std::is_same_v<Tb, int8_t> &&
177 std::is_same_v<Tc, int>) ||
178 (std::is_same_v<Ta, int8_t> && std::is_same_v<Tb, uint8_t> &&
179 std::is_same_v<Tc, int>) ||
180 (std::is_same_v<Ta, uint8_t> && std::is_same_v<Tb, uint8_t> &&
181 std::is_same_v<Tc, int>) ||
182 (std::is_same_v<Ta, half> && std::is_same_v<Tb, half> &&
183 std::is_same_v<Tc, float>) ||
184 (std::is_same_v<Ta, unsigned short> &&
185 std::is_same_v<Tb, unsigned short> && std::is_same_v<Tc, float>))
194 template <
typename Ta,
typename Tb,
typename Tc,
typename Td>
197 typename std::enable_if<(!std::is_same_v<Ta, void> &&
198 !std::is_same_v<Tb, void> &&
199 !std::is_same_v<Tc, void>)>::type> {
200 static_assert((are_types_valid_xmx8<Ta, Tb, Tc>()),
201 "Invalid types for architecture::intel_gpu_dg2_g10, supported "
202 "types are int8_t, uint8_t, half, and bf16");
206 static constexpr std::size_t M = 8;
207 static constexpr std::size_t N = 8;
208 static constexpr std::size_t K = ((
sizeof(Ta) == 1) ? 32 : 16);
210 template <
typename Group, layout Layout>
212 template <
typename Group, layout Layout>
214 template <
typename Group>
216 template <
typename Group>
222 template <
typename Ta,
typename Tb,
typename Tc,
typename Td,
size_t sM,
223 size_t sN,
size_t sK>
226 typename std::enable_if<(
227 !std::is_same_v<Ta, void> && !std::is_same_v<Tb, void> &&
228 !std::is_same_v<Tc, void> && sM != 0 && sN != 0 && sK != 0)>::type> {
231 (sM == 0 && sN == 0 && sK == 0) ||
232 (is_combination_valid_xmx8<Ta, Tb, Tc>(sM, sN, sK)),
233 "Invalid parameters for XMX8, query valid combinations "
235 "q.get_device().get_info<sycl::info::device::matrix::combinations>()");
238 static constexpr std::size_t M = sM;
239 static constexpr std::size_t N = sN;
240 static constexpr std::size_t K = sK;
242 template <
typename Group, layout Layout>
244 template <
typename Group, layout Layout>
246 template <
typename Group>
248 template <
typename Group>
255 template <
typename Ta,
typename Tb,
typename Tc,
typename Td>
258 typename std::enable_if<(!std::is_same_v<Ta, void> &&
259 !std::is_same_v<Tb, void> &&
260 !std::is_same_v<Tc, void>)>::type> {
261 static_assert((are_types_valid_xmx8<Ta, Tb, Tc>()),
262 "Invalid types for architecture::intel_gpu_dg2_g11, supported"
263 "types are int8_t, uint8_t, half, and bf16");
267 static constexpr std::size_t M = 8;
268 static constexpr std::size_t N = 8;
269 static constexpr std::size_t K = ((
sizeof(Ta) == 1) ? 32 : 16);
271 template <
typename Group, layout Layout>
273 template <
typename Group, layout Layout>
275 template <
typename Group>
277 template <
typename Group>
283 template <
typename Ta,
typename Tb,
typename Tc,
typename Td,
size_t sM,
284 size_t sN,
size_t sK>
287 typename std::enable_if<(
288 !std::is_same_v<Ta, void> && !std::is_same_v<Tb, void> &&
289 !std::is_same_v<Tc, void> && sM != 0 && sN != 0 && sK != 0)>::type> {
292 (sM == 0 && sN == 0 && sK == 0) ||
293 (is_combination_valid_xmx8<Ta, Tb, Tc>(sM, sN, sK)),
294 "Invalid parameters for XMX8, query valid combinations "
296 "q.get_device().get_info<sycl::info::device::matrix::combinations>()");
299 static constexpr std::size_t M = sM;
300 static constexpr std::size_t N = sN;
301 static constexpr std::size_t K = sK;
303 template <
typename Group, layout Layout>
305 template <
typename Group, layout Layout>
307 template <
typename Group>
309 template <
typename Group>
316 template <
typename Ta,
typename Tb,
typename Tc,
typename Td>
319 typename std::enable_if<(!std::is_same_v<Ta, void> &&
320 !std::is_same_v<Tb, void> &&
321 !std::is_same_v<Tc, void>)>::type> {
322 static_assert((are_types_valid_xmx8<Ta, Tb, Tc>()),
323 "Invalid types for architecture::intel_gpu_dg2_g12, supported "
324 "types are int8_t, uint8_t, half, and bf16");
328 static constexpr std::size_t M = 8;
329 static constexpr std::size_t N = 8;
330 static constexpr std::size_t K = ((
sizeof(Ta) == 1) ? 32 : 16);
332 template <
typename Group, layout Layout>
334 template <
typename Group, layout Layout>
336 template <
typename Group>
338 template <
typename Group>
344 template <
typename Ta,
typename Tb,
typename Tc,
typename Td,
size_t sM,
345 size_t sN,
size_t sK>
348 typename std::enable_if<(
349 !std::is_same_v<Ta, void> && !std::is_same_v<Tb, void> &&
350 !std::is_same_v<Tc, void> && sM != 0 && sN != 0 && sK != 0)>::type> {
353 (sM == 0 && sN == 0 && sK == 0) ||
354 (is_combination_valid_xmx8<Ta, Tb, Tc>(sM, sN, sK)),
355 "Invalid parameters for XMX8, query valid combinations "
357 "q.get_device().get_info<sycl::info::device::matrix::combinations>()");
360 static constexpr std::size_t M = sM;
361 static constexpr std::size_t N = sN;
362 static constexpr std::size_t K = sK;
364 template <
typename Group, layout Layout>
366 template <
typename Group, layout Layout>
368 template <
typename Group>
370 template <
typename Group>
379 template <
typename Ta,
typename Tb,
typename Tc>
381 if ((std::is_same_v<Ta, int8_t> && std::is_same_v<Tb, int8_t> &&
382 std::is_same_v<Tc, int> && (sM >= 1 && sM <= 8) && sN == 16 &&
384 (std::is_same_v<Ta, int8_t> && std::is_same_v<Tb, uint8_t> &&
385 std::is_same_v<Tc, int> && (sM >= 1 && sM <= 8) && sN == 16 &&
387 (std::is_same_v<Ta, uint8_t> && std::is_same_v<Tb, int8_t> &&
388 std::is_same_v<Tc, int> && (sM >= 1 && sM <= 8) && sN == 16 &&
390 (std::is_same_v<Ta, uint8_t> && std::is_same_v<Tb, uint8_t> &&
391 std::is_same_v<Tc, int> && (sM >= 1 && sM <= 8) && sN == 16 &&
393 (std::is_same_v<Ta, half> && std::is_same_v<Tb, half> &&
394 std::is_same_v<Tc, float> && (sM >= 1 && sM <= 8) && sN == 16 &&
396 (std::is_same_v<Ta, unsigned short> &&
397 std::is_same_v<Tb, unsigned short> && std::is_same_v<Tc, float> &&
398 (sM >= 1 && sM <= 8) && sN == 16 && sK == 16))
404 template <
typename Ta,
typename Tb,
typename Tc>
406 if ((std::is_same_v<Ta, int8_t> && std::is_same_v<Tb, int8_t> &&
407 std::is_same_v<Tc, int>) ||
408 (std::is_same_v<Ta, uint8_t> && std::is_same_v<Tb, int8_t> &&
409 std::is_same_v<Tc, int>) ||
410 (std::is_same_v<Ta, int8_t> && std::is_same_v<Tb, uint8_t> &&
411 std::is_same_v<Tc, int>) ||
412 (std::is_same_v<Ta, uint8_t> && std::is_same_v<Tb, uint8_t> &&
413 std::is_same_v<Tc, int>) ||
414 (std::is_same_v<Ta, half> && std::is_same_v<Tb, half> &&
415 std::is_same_v<Tc, float>) ||
416 (std::is_same_v<Ta, unsigned short> &&
417 std::is_same_v<Tb, unsigned short> && std::is_same_v<Tc, float>))
426 template <
typename Ta,
typename Tb,
typename Tc,
typename Td>
429 typename std::enable_if<(!std::is_same_v<Ta, void> &&
430 !std::is_same_v<Tb, void> &&
431 !std::is_same_v<Tc, void>)>::type> {
432 static_assert((are_types_valid_xmx16<Ta, Tb, Tc>()),
433 "Invalid types for architecture::intel_gpu_pvc, supported "
434 "types are int8_t, uint8_t, "
439 static constexpr std::size_t M = 8;
440 static constexpr std::size_t N = 16;
441 static constexpr std::size_t K = ((
sizeof(Ta) == 1) ? 32 : 16);
443 template <
typename Group, layout Layout>
445 template <
typename Group, layout Layout>
447 template <
typename Group>
449 template <
typename Group>
455 template <
typename Ta,
typename Tb,
typename Tc,
typename Td,
size_t sM,
456 size_t sN,
size_t sK>
459 typename std::enable_if<(
460 !std::is_same_v<Ta, void> && !std::is_same_v<Tb, void> &&
461 !std::is_same_v<Tc, void> && sM != 0 && sN != 0 && sK != 0)>::type> {
464 (sM == 0 && sN == 0 && sK == 0) ||
465 (is_combination_valid_xmx16<Ta, Tb, Tc>(sM, sN, sK)),
466 "Invalid parameters for architecture::intel_gpu_pvc, query valid "
469 "q.get_device().get_info<sycl::info::device::matrix::combinations>()");
472 static constexpr std::size_t M = sM;
473 static constexpr std::size_t N = sN;
474 static constexpr std::size_t K = sK;
476 template <
typename Group, layout Layout>
478 template <
typename Group, layout Layout>
480 template <
typename Group>
482 template <
typename Group>
490 template <
typename Ta,
typename Tc>
493 return (std::is_same_v<Ta, half> && std::is_same_v<Tc, float> &&
494 ((sM == 32 && sN == 32 && sK == 8) ||
495 (sM == 16 && sN == 16 && sK == 16))) ||
496 (std::is_same_v<Ta, int8_t> && std::is_same_v<Tc, int32_t> &&
497 ((sM == 32 && sN == 32 && sK == 8) ||
498 (sM == 16 && sN == 16 && sK == 16))) ||
499 (std::is_same_v<Ta, bfloat16> && std::is_same_v<Tc, float> &&
500 ((sM == 32 && sN == 32 && sK == 8) ||
501 (sM == 16 && sN == 16 && sK == 16))) ||
502 (std::is_same_v<Ta, double> && std::is_same_v<Tc, double> &&
503 (sM == 16 && sN == 16 && sK == 4));
506 template <
typename Ta,
typename Tc>
508 return (std::is_same_v<Ta, half> && std::is_same_v<Tc, float>) ||
509 (std::is_same_v<Ta, int8_t> && std::is_same_v<Tc, int32_t>) ||
510 (std::is_same_v<Ta, bfloat16> && std::is_same_v<Tc, float>) ||
511 (std::is_same_v<Ta, double> && std::is_same_v<Tc, double>);
516 template <
typename Ta,
typename Tb,
typename Tc,
typename Td>
519 typename std::enable_if_t<(
520 !std::is_same_v<Ta, void> && !std::is_same_v<Tb, void> &&
521 !std::is_same_v<Tc, void> && !std::is_same_v<Td, void> &&
522 std::is_same_v<Ta, Tb> && std::is_same_v<Tc, Td>)>> {
524 are_types_valid_amd_gfx90a<Ta, Tc>(),
525 "Invalid types for AMD gfx90a, supported types are half, float, "
526 "int8_t, int32_t, double and bfloat16 ");
529 static constexpr std::size_t M = 16;
530 static constexpr std::size_t N = 16;
531 static constexpr std::size_t K = ((
sizeof(Ta) == 8) ? 16 : 4);
533 template <
typename Group, layout Layout>
535 template <
typename Group, layout Layout>
537 template <
typename Group>
539 template <
typename Group>
545 template <
typename Ta,
typename Tb,
typename Tc,
typename Td,
size_t sM,
546 size_t sN,
size_t sK>
549 typename std::enable_if_t<(
550 !std::is_same_v<Ta, void> && !std::is_same_v<Tb, void> &&
551 !std::is_same_v<Tc, void> && !std::is_same_v<Td, void> &&
552 std::is_same_v<Ta, Tb> && std::is_same_v<Tc, Td> && sM != 0 &&
553 sN != 0 && sK != 0)>> {
555 is_combination_valid_amd_gfx90a<Ta, Tc>(sM, sN, sK),
556 "Invalid parameters for AMD gfx90a, query valid combinations "
558 "q.get_device().get_info<sycl::info::device::matrix::combinations>()");
560 static constexpr std::size_t M = sM;
561 static constexpr std::size_t N = sN;
562 static constexpr std::size_t K = sK;
564 template <
typename Group, layout Layout>
566 template <
typename Group, layout Layout>
568 template <
typename Group>
570 template <
typename Group>
578 template <
typename Ta,
typename Tc,
typename Td>
580 return (std::is_same_v<Ta, half> && std::is_same_v<Tc, float> &&
581 std::is_same_v<Td, float>) ||
582 (std::is_same_v<Ta, half> && std::is_same_v<Tc, half> &&
583 std::is_same_v<Td, half>) ||
584 (std::is_same_v<Ta, half> && std::is_same_v<Tc, float> &&
585 std::is_same_v<Td, half>) ||
586 (std::is_same_v<Ta, half> && std::is_same_v<Tc, half> &&
587 std::is_same_v<Td, float>);
590 template <
typename Ta,
typename Tc,
typename Td>
592 return (std::is_same_v<Ta, int8_t> && std::is_same_v<Tc, int32_t> &&
593 std::is_same_v<Td, int32_t>) ||
594 (std::is_same_v<Ta, uint8_t> && std::is_same_v<Tc, int32_t> &&
595 std::is_same_v<Td, int32_t>);
598 template <
typename Ta,
typename Tc,
typename Td>
600 return (std::is_same_v<Ta, precision::tf32> && std::is_same_v<Tc, float> &&
601 std::is_same_v<Td, float>) ||
602 (std::is_same_v<Ta, bfloat16> && std::is_same_v<Tc, float> &&
603 std::is_same_v<Td, float>) ||
604 (std::is_same_v<Ta, double> && std::is_same_v<Tc, double> &&
605 std::is_same_v<Td, double>);
608 template <
typename Ta,
typename Tc,
typename Td>
610 return are_types_valid_cuda_sm70<Ta, Tc, Td>() &&
611 ((sM == 8 && sN == 32 && sK == 16) ||
612 (sM == 16 && sN == 16 && sK == 16) ||
613 (sM == 32 && sN == 8 && sK == 16));
616 template <
typename Ta,
typename Tc,
typename Td>
618 return are_types_valid_cuda_sm72<Ta, Tc, Td>() &&
619 ((sM == 8 && sN == 32 && sK == 16) ||
620 (sM == 16 && sN == 16 && sK == 16) ||
621 (sM == 32 && sN == 8 && sK == 16));
624 template <
typename Ta,
typename Tc,
typename Td>
626 return ((std::is_same_v<Ta, precision::tf32> && std::is_same_v<Tc, float> &&
627 std::is_same_v<Td, float>)&&(sM == 16 && sN == 16 && sK == 8)) ||
628 ((std::is_same_v<Ta, bfloat16> && std::is_same_v<Tc, float> &&
629 std::is_same_v<Td, float>)&&((sM == 16 && sN == 16 && sK == 16) ||
630 (sM == 8 && sN == 32 && sK == 16) ||
631 (sM == 32 && sN == 8 && sK == 16))) ||
632 ((std::is_same_v<Ta, double> && std::is_same_v<Tc, double> &&
633 std::is_same_v<Td, double>)&&(sM == 8 && sN == 8 && sK == 4));
638 template <
typename Ta,
typename Tb,
typename Tc,
typename Td>
641 typename std::enable_if_t<(
642 !std::is_same_v<Ta, void> && !std::is_same_v<Tb, void> &&
643 !std::is_same_v<Tc, void> && !std::is_same_v<Td, void> &&
644 std::is_same_v<Ta, Tb>)>> {
646 are_types_valid_cuda_sm70<Ta, Tc, Td>(),
647 "Invalid types for nvidia sm70, supported types are half and float ");
650 static constexpr std::size_t M = 16;
651 static constexpr std::size_t N = 16;
652 static constexpr std::size_t K = 16;
654 template <
typename Group, layout Layout>
656 template <
typename Group, layout Layout>
658 template <
typename Group>
660 template <
typename Group>
666 template <
typename Ta,
typename Tb,
typename Tc,
typename Td>
669 typename std::enable_if<(
670 !std::is_same_v<Ta, void> && !std::is_same_v<Tb, void> &&
671 !std::is_same_v<Tc, void> && !std::is_same_v<Td, void> &&
672 std::is_same_v<Ta, Tb>)>::type> {
674 are_types_valid_cuda_sm70<Ta, Tc, Td>() ||
675 are_types_valid_cuda_sm72<Ta, Tc, Td>(),
676 "Invalid types for nvidia sm72, supported types are half, float "
677 "int8_t, uint8_t and int32_t ");
679 static constexpr std::size_t M = 16;
680 static constexpr std::size_t N = 16;
681 static constexpr std::size_t K = 16;
683 template <
typename Group, layout Layout>
685 template <
typename Group, layout Layout>
687 template <
typename Group>
689 template <
typename Group>
695 template <
typename Ta,
typename Tb,
typename Tc,
typename Td>
698 typename std::enable_if_t<(
699 !std::is_same_v<Ta, void> && !std::is_same_v<Tb, void> &&
700 !std::is_same_v<Tc, void> && !std::is_same_v<Td, void> &&
701 std::is_same_v<Ta, Tb>)>> {
703 are_types_valid_cuda_sm70<Ta, Tc, Td>() ||
704 are_types_valid_cuda_sm72<Ta, Tc, Td>() ||
705 are_types_valid_cuda_sm80<Ta, Tc, Td>(),
706 "Invalid types for nvidia sm80, supported types are half, float "
707 "int8_t, uint8_t, int32_t, double, tf32 and bfloat16 ");
709 static constexpr std::size_t M = (
sizeof(Ta) == 8) ? 8 : 16;
710 static constexpr std::size_t N = (
sizeof(Ta) == 8) ? 8 : 16;
711 static constexpr std::size_t K =
712 std::is_same_v<Ta, precision::tf32> ? 8 : (
sizeof(Ta) == 8 ? 4 : 16);
714 template <
typename Group, layout Layout>
716 template <
typename Group, layout Layout>
718 template <
typename Group>
720 template <
typename Group>
726 template <
typename Ta,
typename Tb,
typename Tc,
typename Td,
size_t sM,
727 size_t sN,
size_t sK>
730 typename std::enable_if_t<(
731 !std::is_same_v<Ta, void> && !std::is_same_v<Tb, void> &&
732 !std::is_same_v<Tc, void> && !std::is_same_v<Td, void> &&
733 std::is_same_v<Ta, Tb> && sM != 0 && sN != 0 && sK != 0)>> {
735 is_combination_valid_cuda_sm70<Ta, Tc, Td>(sM, sN, sK),
736 "Invalid parameters for nvidia sm70, query valid combinations "
738 "q.get_device().get_info<sycl::info::device::matrix::combinations>()");
740 static constexpr std::size_t M = sM;
741 static constexpr std::size_t N = sN;
742 static constexpr std::size_t K = sK;
744 template <
typename Group, layout Layout>
746 template <
typename Group, layout Layout>
748 template <
typename Group>
750 template <
typename Group>
756 template <
typename Ta,
typename Tb,
typename Tc,
typename Td,
size_t sM,
757 size_t sN,
size_t sK>
760 typename std::enable_if_t<(
761 !std::is_same_v<Ta, void> && !std::is_same_v<Tb, void> &&
762 !std::is_same_v<Tc, void> && !std::is_same_v<Td, void> &&
763 std::is_same_v<Ta, Tb> && sM != 0 && sN != 0 && sK != 0)>> {
765 is_combination_valid_cuda_sm70<Ta, Tc, Td>(sM, sN, sK) ||
766 is_combination_valid_cuda_sm72<Ta, Tc, Td>(sM, sN, sK),
767 "Invalid parameters for nvidia sm72, query valid combinations "
769 "q.get_device().get_info<sycl::info::device::matrix::combinations>()");
771 static constexpr std::size_t M = sM;
772 static constexpr std::size_t N = sN;
773 static constexpr std::size_t K = sK;
775 template <
typename Group, layout Layout>
777 template <
typename Group, layout Layout>
779 template <
typename Group>
781 template <
typename Group>
787 template <
typename Ta,
typename Tb,
typename Tc,
typename Td,
size_t sM,
788 size_t sN,
size_t sK>
791 typename std::enable_if_t<(
792 !std::is_same_v<Ta, void> && !std::is_same_v<Tb, void> &&
793 !std::is_same_v<Tc, void> && !std::is_same_v<Td, void> &&
794 std::is_same_v<Ta, Tb> && sM != 0 && sN != 0 && sK != 0)>> {
796 is_combination_valid_cuda_sm70<Ta, Tc, Td>(sM, sN, sK) ||
797 is_combination_valid_cuda_sm72<Ta, Tc, Td>(sM, sN, sK) ||
798 is_combination_valid_cuda_sm80<Ta, Tc, Td>(sM, sN, sK),
799 "Invalid parameters for nvidia sm80, query valid combinations "
801 "q.get_device().get_info<sycl::info::device::matrix::combinations>()");
803 static constexpr std::size_t M = sM;
804 static constexpr std::size_t N = sN;
805 static constexpr std::size_t K = sK;
807 template <
typename Group, layout Layout>
809 template <
typename Group, layout Layout>
811 template <
typename Group>
813 template <
typename Group>
constexpr bool are_types_valid_cuda_sm72()
constexpr bool is_combination_valid_cuda_sm72(size_t sM, size_t sN, size_t sK)
constexpr bool is_combination_valid_amx(size_t sM, size_t sN, size_t sK)
constexpr bool are_types_valid_amd_gfx90a()
constexpr bool are_types_valid_cuda_sm70()
CUDA Tensor Cores - sm70, sm72 and sm80 ///.
constexpr bool is_combination_valid_xmx16(size_t sM, size_t sN, size_t sK)
constexpr bool are_types_valid_cuda_sm80()
constexpr bool is_combination_valid_cuda_sm80(size_t sM, size_t sN, size_t sK)
constexpr bool are_types_valid_amx()
constexpr bool are_types_valid_xmx16()
constexpr bool is_combination_valid_xmx8(size_t sM, size_t sN, size_t sK)
constexpr bool are_types_valid_xmx8()
constexpr bool is_combination_valid_amd_gfx90a(size_t sM, size_t sN, size_t sK)
AMD Matrix Cores - GFX90A architecture ///.
constexpr bool is_combination_valid_cuda_sm70(size_t sM, size_t sN, size_t sK)