DPC++ Runtime
Runtime libraries for oneAPI DPC++
static-query-use.hpp
Go to the documentation of this file.
1 //===---------- static-query-use.hpp - SYCL matrix ------------*- C++ -*---===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 // ===--------------------------------------------------------------------=== //
8 // This file implements the static query interface for the joint_matrix
9 // experimental extension. Intel(R) Advanced Matrix Extensions (Intel(R) AMX),
10 // and Intel(R) Xe Matrix Extensions (Intel(R) XMX) support different logical
11 // sizes and types. The query interface is used to validate user code and inform
12 // them about supported types, sizes, scopes, and layouts by the current
13 // implementation. Note that this query interface is a compile-time query, so
14 // there will be no runtime errors. The query interface provides three
15 // functionalities: 1- At compile time, inform the user whether a specific
16 // combination is valid or not. 2- Construct the matrices using a default shape
17 // if the user does not provide a combination. 3- General query interface for sizes,
18 // types, scopes. This is needed to avoid padding by the user, for tuning, and
19 // efficient code generation if used by a library.
20 
21 #pragma once
22 
23 #include <sycl/aliases.hpp> // for half
25 #include <sycl/ext/oneapi/matrix/matrix-unified-utils.hpp> // for use, layout
26 #include <sycl/ext/oneapi/matrix/matrix-unified.hpp> // for joint_matrix
27 
28 #include <cstddef> // for size_t
29 #include <stdint.h> // for uint32_t, int8_t
30 #include <type_traits> // for enable_if
31 
32 namespace sycl {
33 inline namespace _V1 {
34 namespace ext {
35 namespace oneapi {
36 namespace experimental::matrix {
37 
38 template <architecture u, typename Ta, typename Tb, typename Tc,
39  typename Td = Tc, size_t sM = 0, size_t sN = 0, size_t sK = 0,
40  typename Enabled = void>
42 
// Checks whether the element types Ta/Tb/Tc together with the requested
// sizes sM x sN x sK form a combination accepted for Intel AMX:
//   - 8-bit integer A and B (int8_t or uint8_t, any signedness mix) with an
//     int accumulator: sM <= 16, sN <= 16, sK <= 64.
//   - bf16 A and B (spelled as unsigned short) with a float accumulator:
//     sM <= 16, sN <= 16, sK <= 32.
// Only upper bounds are tested, so zero sizes are accepted by this helper.
template <typename Ta, typename Tb, typename Tc>
constexpr bool is_combination_valid_amx(size_t sM, size_t sN, size_t sK) {
  // std::is_same_v is a C++17 feature.
  constexpr bool AIsInt8 =
      std::is_same_v<Ta, int8_t> || std::is_same_v<Ta, uint8_t>;
  constexpr bool BIsInt8 =
      std::is_same_v<Tb, int8_t> || std::is_same_v<Tb, uint8_t>;
  constexpr bool Int8Types = AIsInt8 && BIsInt8 && std::is_same_v<Tc, int>;
  constexpr bool Bf16Types = std::is_same_v<Ta, unsigned short> &&
                             std::is_same_v<Tb, unsigned short> &&
                             std::is_same_v<Tc, float>;

  // M and N share one limit across both type families; only the K limit
  // differs between them.
  const bool MNFit = sM <= 16 && sN <= 16;
  return MNFit && ((Int8Types && sK <= 64) || (Bf16Types && sK <= 32));
}
62 
// Checks whether the element types alone (no sizes) are accepted for Intel
// AMX: either 8-bit integer A and B (int8_t or uint8_t, any signedness mix)
// with an int accumulator, or unsigned short A and B with a float
// accumulator.
template <typename Ta, typename Tb, typename Tc>
constexpr bool are_types_valid_amx() {
  constexpr bool AIsInt8 =
      std::is_same_v<Ta, int8_t> || std::is_same_v<Ta, uint8_t>;
  constexpr bool BIsInt8 =
      std::is_same_v<Tb, int8_t> || std::is_same_v<Tb, uint8_t>;
  constexpr bool UShortInputs = std::is_same_v<Ta, unsigned short> &&
                                std::is_same_v<Tb, unsigned short>;

  return (AIsInt8 && BIsInt8 && std::is_same_v<Tc, int>) ||
         (UShortInputs && std::is_same_v<Tc, float>);
}
79 
80 // Default values query
81 // Specialization for when only types are given, need to query only sizes
82 template <typename Ta, typename Tb, typename Tc, typename Td>
84  architecture::intel_cpu_spr, Ta, Tb, Tc, Td, 0, 0, 0,
85  typename std::enable_if<(!std::is_same_v<Ta, void> &&
86  !std::is_same_v<Tb, void> &&
87  !std::is_same_v<Tc, void>)>::type> {
88  static_assert((are_types_valid_amx<Ta, Tb, Tc>()),
89  "Invalid types for AMX, supported types are int8_t, uint8_t, "
90  "and bf16 (Note that unsigned short should be used in the"
91  "DPC++ code to implement bf16) ");
92 
93  // construct the matrices using the default sizes
94  static constexpr std::size_t M = 16;
95  static constexpr std::size_t N = 16;
96  static constexpr std::size_t K = ((sizeof(Ta) == 1) ? 64 : 32);
97 
98  template <typename Group, layout Layout>
100  template <typename Group, layout Layout>
102  template <typename Group>
104  template <typename Group>
106 };
107 
108 // Validation query
109 // Specialization when both types and sizes are given
110 template <typename Ta, typename Tb, typename Tc, typename Td, size_t sM,
111  size_t sN, size_t sK>
113  architecture::intel_cpu_spr, Ta, Tb, Tc, Td, sM, sN, sK,
114  typename std::enable_if<(
115  !std::is_same_v<Ta, void> && !std::is_same_v<Tb, void> &&
116  !std::is_same_v<Tc, void> && sM != 0 && sN != 0 && sK != 0)>::type> {
117  // Validate that parameters are supported
118  static_assert(
119  (sM == 0 && sN == 0 && sK == 0) ||
120  (is_combination_valid_amx<Ta, Tb, Tc>(sM, sN, sK)),
121  "Invalid parameters for AMX, query valid types and maximum sizes "
122  "using: matrix_params<architecture::intel_cpu_spr> myparams; and then "
123  "check out "
124  "myparams.combinations array");
125 
126  // if combination is valid, construct the matrices
127 
128  static constexpr std::size_t M = sM;
129  static constexpr std::size_t N = sN;
130  static constexpr std::size_t K = sK;
131 
132  template <typename Group, layout Layout>
134  template <typename Group, layout Layout>
136  template <typename Group>
138  template <typename Group>
140 };
141 
142 // Intel XMX with SIMD8 capability
143 // The Intel XMX implementation supports the logical capabilities of the
144 // HW. So in this case, M, N, K sizes returned by the query represent the logical
145 // capabilities of the Intel XMX hardware.
146 
147 template <typename Ta, typename Tb, typename Tc>
148 constexpr bool is_combination_valid_xmx8(size_t sM, size_t sN, size_t sK) {
149  if ((std::is_same_v<Ta, int8_t> && std::is_same_v<Tb, int8_t> &&
150  std::is_same_v<Tc, int> && (sM >= 1 && sM <= 8) && sN == 8 &&
151  sK == 32) ||
152  (std::is_same_v<Ta, int8_t> && std::is_same_v<Tb, uint8_t> &&
153  std::is_same_v<Tc, int> && (sM >= 1 && sM <= 8) && sN == 8 &&
154  sK == 32) ||
155  (std::is_same_v<Ta, uint8_t> && std::is_same_v<Tb, int8_t> &&
156  std::is_same_v<Tc, int> && (sM >= 1 && sM <= 8) && sN == 8 &&
157  sK == 32) ||
158  (std::is_same_v<Ta, uint8_t> && std::is_same_v<Tb, uint8_t> &&
159  std::is_same_v<Tc, int> && (sM >= 1 && sM <= 8) && sN == 8 &&
160  sK == 32) ||
161  (std::is_same_v<Ta, half> && std::is_same_v<Tb, half> &&
162  std::is_same_v<Tc, float> && (sM >= 1 && sM <= 8) && sN == 8 &&
163  sK == 16) ||
164  (std::is_same_v<Ta, unsigned short> &&
165  std::is_same_v<Tb, unsigned short> && std::is_same_v<Tc, float> &&
166  (sM >= 1 && sM <= 8) && sN == 8 && sK == 16))
167  return true;
168  else
169  return false;
170 }
171 
172 template <typename Ta, typename Tb, typename Tc>
173 constexpr bool are_types_valid_xmx8() {
174  if ((std::is_same_v<Ta, int8_t> && std::is_same_v<Tb, int8_t> &&
175  std::is_same_v<Tc, int>) ||
176  (std::is_same_v<Ta, uint8_t> && std::is_same_v<Tb, int8_t> &&
177  std::is_same_v<Tc, int>) ||
178  (std::is_same_v<Ta, int8_t> && std::is_same_v<Tb, uint8_t> &&
179  std::is_same_v<Tc, int>) ||
180  (std::is_same_v<Ta, uint8_t> && std::is_same_v<Tb, uint8_t> &&
181  std::is_same_v<Tc, int>) ||
182  (std::is_same_v<Ta, half> && std::is_same_v<Tb, half> &&
183  std::is_same_v<Tc, float>) ||
184  (std::is_same_v<Ta, unsigned short> &&
185  std::is_same_v<Tb, unsigned short> && std::is_same_v<Tc, float>))
186  return true;
187  else
188  return false;
189 }
190 
191 // Default-values query:
192 // Specialization for when only types are given, need to query only sizes
193 
194 template <typename Ta, typename Tb, typename Tc, typename Td>
196  architecture::intel_gpu_dg2_g10, Ta, Tb, Tc, Td, 0, 0, 0,
197  typename std::enable_if<(!std::is_same_v<Ta, void> &&
198  !std::is_same_v<Tb, void> &&
199  !std::is_same_v<Tc, void>)>::type> {
200  static_assert((are_types_valid_xmx8<Ta, Tb, Tc>()),
201  "Invalid types for architecture::intel_gpu_dg2_g10, supported "
202  "types are int8_t, uint8_t, half, and bf16");
203 
204  // construct the matrices using the default sizes
205 
206  static constexpr std::size_t M = 8;
207  static constexpr std::size_t N = 8;
208  static constexpr std::size_t K = ((sizeof(Ta) == 1) ? 32 : 16);
209 
210  template <typename Group, layout Layout>
212  template <typename Group, layout Layout>
214  template <typename Group>
216  template <typename Group>
218 };
219 
220 // Validation query:
221 // Specialization when both types and sizes are given
222 template <typename Ta, typename Tb, typename Tc, typename Td, size_t sM,
223  size_t sN, size_t sK>
225  architecture::intel_gpu_dg2_g10, Ta, Tb, Tc, Td, sM, sN, sK,
226  typename std::enable_if<(
227  !std::is_same_v<Ta, void> && !std::is_same_v<Tb, void> &&
228  !std::is_same_v<Tc, void> && sM != 0 && sN != 0 && sK != 0)>::type> {
229  // Validate that parameters are supported
230  static_assert(
231  (sM == 0 && sN == 0 && sK == 0) ||
232  (is_combination_valid_xmx8<Ta, Tb, Tc>(sM, sN, sK)),
233  "Invalid parameters for XMX8, query valid combinations "
234  "using: "
235  "q.get_device().get_info<sycl::info::device::matrix::combinations>()");
236 
237  // if combination is valid, construct the matrices
238  static constexpr std::size_t M = sM;
239  static constexpr std::size_t N = sN;
240  static constexpr std::size_t K = sK;
241 
242  template <typename Group, layout Layout>
244  template <typename Group, layout Layout>
246  template <typename Group>
248  template <typename Group>
250 };
251 
252 // Default-values query:
253 // Specialization for when only types are given, need to query only sizes
254 
255 template <typename Ta, typename Tb, typename Tc, typename Td>
257  architecture::intel_gpu_dg2_g11, Ta, Tb, Tc, Td, 0, 0, 0,
258  typename std::enable_if<(!std::is_same_v<Ta, void> &&
259  !std::is_same_v<Tb, void> &&
260  !std::is_same_v<Tc, void>)>::type> {
261  static_assert((are_types_valid_xmx8<Ta, Tb, Tc>()),
262  "Invalid types for architecture::intel_gpu_dg2_g11, supported"
263  "types are int8_t, uint8_t, half, and bf16");
264 
265  // construct the matrices using the default sizes
266 
267  static constexpr std::size_t M = 8;
268  static constexpr std::size_t N = 8;
269  static constexpr std::size_t K = ((sizeof(Ta) == 1) ? 32 : 16);
270 
271  template <typename Group, layout Layout>
273  template <typename Group, layout Layout>
275  template <typename Group>
277  template <typename Group>
279 };
280 
281 // Validation query:
282 // Specialization when both types and sizes are given
283 template <typename Ta, typename Tb, typename Tc, typename Td, size_t sM,
284  size_t sN, size_t sK>
286  architecture::intel_gpu_dg2_g11, Ta, Tb, Tc, Td, sM, sN, sK,
287  typename std::enable_if<(
288  !std::is_same_v<Ta, void> && !std::is_same_v<Tb, void> &&
289  !std::is_same_v<Tc, void> && sM != 0 && sN != 0 && sK != 0)>::type> {
290  // Validate that parameters are supported
291  static_assert(
292  (sM == 0 && sN == 0 && sK == 0) ||
293  (is_combination_valid_xmx8<Ta, Tb, Tc>(sM, sN, sK)),
294  "Invalid parameters for XMX8, query valid combinations "
295  "using: "
296  "q.get_device().get_info<sycl::info::device::matrix::combinations>()");
297 
298  // if combination is valid, construct the matrices
299  static constexpr std::size_t M = sM;
300  static constexpr std::size_t N = sN;
301  static constexpr std::size_t K = sK;
302 
303  template <typename Group, layout Layout>
305  template <typename Group, layout Layout>
307  template <typename Group>
309  template <typename Group>
311 };
312 
313 // Default-values query:
314 // Specialization for when only types are given, need to query only sizes
315 
316 template <typename Ta, typename Tb, typename Tc, typename Td>
318  architecture::intel_gpu_dg2_g12, Ta, Tb, Tc, Td, 0, 0, 0,
319  typename std::enable_if<(!std::is_same_v<Ta, void> &&
320  !std::is_same_v<Tb, void> &&
321  !std::is_same_v<Tc, void>)>::type> {
322  static_assert((are_types_valid_xmx8<Ta, Tb, Tc>()),
323  "Invalid types for architecture::intel_gpu_dg2_g12, supported "
324  "types are int8_t, uint8_t, half, and bf16");
325 
326  // construct the matrices using the default sizes
327 
328  static constexpr std::size_t M = 8;
329  static constexpr std::size_t N = 8;
330  static constexpr std::size_t K = ((sizeof(Ta) == 1) ? 32 : 16);
331 
332  template <typename Group, layout Layout>
334  template <typename Group, layout Layout>
336  template <typename Group>
338  template <typename Group>
340 };
341 
342 // Validation query:
343 // Specialization when both types and sizes are given
344 template <typename Ta, typename Tb, typename Tc, typename Td, size_t sM,
345  size_t sN, size_t sK>
347  architecture::intel_gpu_dg2_g12, Ta, Tb, Tc, Td, sM, sN, sK,
348  typename std::enable_if<(
349  !std::is_same_v<Ta, void> && !std::is_same_v<Tb, void> &&
350  !std::is_same_v<Tc, void> && sM != 0 && sN != 0 && sK != 0)>::type> {
351  // Validate that parameters are supported
352  static_assert(
353  (sM == 0 && sN == 0 && sK == 0) ||
354  (is_combination_valid_xmx8<Ta, Tb, Tc>(sM, sN, sK)),
355  "Invalid parameters for XMX8, query valid combinations "
356  "using: "
357  "q.get_device().get_info<sycl::info::device::matrix::combinations>()");
358 
359  // if combination is valid, construct the matrices
360  static constexpr std::size_t M = sM;
361  static constexpr std::size_t N = sN;
362  static constexpr std::size_t K = sK;
363 
364  template <typename Group, layout Layout>
366  template <typename Group, layout Layout>
368  template <typename Group>
370  template <typename Group>
372 };
373 
374 // Intel XMX with SIMD16 capability
375 // The Intel XMX implementation supports the logical capabilities of the
376 // HW. So in this case, M, N, K sizes returned by the query represent the logical
377 // capabilities of the Intel XMX hardware.
378 
379 template <typename Ta, typename Tb, typename Tc>
380 constexpr bool is_combination_valid_xmx16(size_t sM, size_t sN, size_t sK) {
381  if ((std::is_same_v<Ta, int8_t> && std::is_same_v<Tb, int8_t> &&
382  std::is_same_v<Tc, int> && (sM >= 1 && sM <= 8) && sN == 16 &&
383  sK == 32) ||
384  (std::is_same_v<Ta, int8_t> && std::is_same_v<Tb, uint8_t> &&
385  std::is_same_v<Tc, int> && (sM >= 1 && sM <= 8) && sN == 16 &&
386  sK == 32) ||
387  (std::is_same_v<Ta, uint8_t> && std::is_same_v<Tb, int8_t> &&
388  std::is_same_v<Tc, int> && (sM >= 1 && sM <= 8) && sN == 16 &&
389  sK == 32) ||
390  (std::is_same_v<Ta, uint8_t> && std::is_same_v<Tb, uint8_t> &&
391  std::is_same_v<Tc, int> && (sM >= 1 && sM <= 8) && sN == 16 &&
392  sK == 32) ||
393  (std::is_same_v<Ta, half> && std::is_same_v<Tb, half> &&
394  std::is_same_v<Tc, float> && (sM >= 1 && sM <= 8) && sN == 16 &&
395  sK == 16) ||
396  (std::is_same_v<Ta, unsigned short> &&
397  std::is_same_v<Tb, unsigned short> && std::is_same_v<Tc, float> &&
398  (sM >= 1 && sM <= 8) && sN == 16 && sK == 16))
399  return true;
400  else
401  return false;
402 }
403 
404 template <typename Ta, typename Tb, typename Tc>
405 constexpr bool are_types_valid_xmx16() {
406  if ((std::is_same_v<Ta, int8_t> && std::is_same_v<Tb, int8_t> &&
407  std::is_same_v<Tc, int>) ||
408  (std::is_same_v<Ta, uint8_t> && std::is_same_v<Tb, int8_t> &&
409  std::is_same_v<Tc, int>) ||
410  (std::is_same_v<Ta, int8_t> && std::is_same_v<Tb, uint8_t> &&
411  std::is_same_v<Tc, int>) ||
412  (std::is_same_v<Ta, uint8_t> && std::is_same_v<Tb, uint8_t> &&
413  std::is_same_v<Tc, int>) ||
414  (std::is_same_v<Ta, half> && std::is_same_v<Tb, half> &&
415  std::is_same_v<Tc, float>) ||
416  (std::is_same_v<Ta, unsigned short> &&
417  std::is_same_v<Tb, unsigned short> && std::is_same_v<Tc, float>))
418  return true;
419  else
420  return false;
421 }
422 
423 // Default values query:
424 // Specialization for when only types are given, need to query only sizes
425 
426 template <typename Ta, typename Tb, typename Tc, typename Td>
428  architecture::intel_gpu_pvc, Ta, Tb, Tc, Td, 0, 0, 0,
429  typename std::enable_if<(!std::is_same_v<Ta, void> &&
430  !std::is_same_v<Tb, void> &&
431  !std::is_same_v<Tc, void>)>::type> {
432  static_assert((are_types_valid_xmx16<Ta, Tb, Tc>()),
433  "Invalid types for architecture::intel_gpu_pvc, supported "
434  "types are int8_t, uint8_t, "
435  "half, and bf16");
436 
437  // construct the matrices using the default sizes
438 
439  static constexpr std::size_t M = 8;
440  static constexpr std::size_t N = 16;
441  static constexpr std::size_t K = ((sizeof(Ta) == 1) ? 32 : 16);
442 
443  template <typename Group, layout Layout>
445  template <typename Group, layout Layout>
447  template <typename Group>
449  template <typename Group>
451 };
452 
453 // Validation query:
454 // Specialization when both types and sizes are given
455 template <typename Ta, typename Tb, typename Tc, typename Td, size_t sM,
456  size_t sN, size_t sK>
458  architecture::intel_gpu_pvc, Ta, Tb, Tc, Td, sM, sN, sK,
459  typename std::enable_if<(
460  !std::is_same_v<Ta, void> && !std::is_same_v<Tb, void> &&
461  !std::is_same_v<Tc, void> && sM != 0 && sN != 0 && sK != 0)>::type> {
462  // Validate that parameters are supported
463  static_assert(
464  (sM == 0 && sN == 0 && sK == 0) ||
465  (is_combination_valid_xmx16<Ta, Tb, Tc>(sM, sN, sK)),
466  "Invalid parameters for architecture::intel_gpu_pvc, query valid "
467  "combinations "
468  "using: "
469  "q.get_device().get_info<sycl::info::device::matrix::combinations>()");
470 
471  // if combination is valid, construct the matrices
472  static constexpr std::size_t M = sM;
473  static constexpr std::size_t N = sN;
474  static constexpr std::size_t K = sK;
475 
476  template <typename Group, layout Layout>
478  template <typename Group, layout Layout>
480  template <typename Group>
482  template <typename Group>
484 };
485 
489 
490 template <typename Ta, typename Tc>
491 constexpr bool is_combination_valid_amd_gfx90a(size_t sM, size_t sN,
492  size_t sK) {
493  return (std::is_same_v<Ta, half> && std::is_same_v<Tc, float> &&
494  ((sM == 32 && sN == 32 && sK == 8) ||
495  (sM == 16 && sN == 16 && sK == 16))) ||
496  (std::is_same_v<Ta, int8_t> && std::is_same_v<Tc, int32_t> &&
497  ((sM == 32 && sN == 32 && sK == 8) ||
498  (sM == 16 && sN == 16 && sK == 16))) ||
499  (std::is_same_v<Ta, bfloat16> && std::is_same_v<Tc, float> &&
500  ((sM == 32 && sN == 32 && sK == 8) ||
501  (sM == 16 && sN == 16 && sK == 16))) ||
502  (std::is_same_v<Ta, double> && std::is_same_v<Tc, double> &&
503  (sM == 16 && sN == 16 && sK == 4));
504 }
505 
506 template <typename Ta, typename Tc>
507 constexpr bool are_types_valid_amd_gfx90a() {
508  return (std::is_same_v<Ta, half> && std::is_same_v<Tc, float>) ||
509  (std::is_same_v<Ta, int8_t> && std::is_same_v<Tc, int32_t>) ||
510  (std::is_same_v<Ta, bfloat16> && std::is_same_v<Tc, float>) ||
511  (std::is_same_v<Ta, double> && std::is_same_v<Tc, double>);
512 }
513 
514 // Default-values query:
515 // Specialization for when only types are given, need to query only sizes
516 template <typename Ta, typename Tb, typename Tc, typename Td>
518  architecture::amd_gpu_gfx90a, Ta, Tb, Tc, Td, 0, 0, 0,
519  typename std::enable_if_t<(
520  !std::is_same_v<Ta, void> && !std::is_same_v<Tb, void> &&
521  !std::is_same_v<Tc, void> && !std::is_same_v<Td, void> &&
522  std::is_same_v<Ta, Tb> && std::is_same_v<Tc, Td>)>> {
523  static_assert(
524  are_types_valid_amd_gfx90a<Ta, Tc>(),
525  "Invalid types for AMD gfx90a, supported types are half, float, "
526  "int8_t, int32_t, double and bfloat16 ");
527 
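528 // Default sizes for AMD gfx90a were chosen to represent a square matrix.
// NOTE(review): K below is 16 when sizeof(Ta) == 8 and 4 otherwise, whereas
// is_combination_valid_amd_gfx90a accepts 16x16x4 only for double and
// 16x16x16 for half/int8_t/bfloat16 - the ternary looks inverted; confirm.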
529  static constexpr std::size_t M = 16;
530  static constexpr std::size_t N = 16;
531  static constexpr std::size_t K = ((sizeof(Ta) == 8) ? 16 : 4);
532 
533  template <typename Group, layout Layout>
535  template <typename Group, layout Layout>
537  template <typename Group>
539  template <typename Group>
541 };
542 
543 // Validation query
544 // Specialization when both types and sizes are given
545 template <typename Ta, typename Tb, typename Tc, typename Td, size_t sM,
546  size_t sN, size_t sK>
548  architecture::amd_gpu_gfx90a, Ta, Tb, Tc, Td, sM, sN, sK,
549  typename std::enable_if_t<(
550  !std::is_same_v<Ta, void> && !std::is_same_v<Tb, void> &&
551  !std::is_same_v<Tc, void> && !std::is_same_v<Td, void> &&
552  std::is_same_v<Ta, Tb> && std::is_same_v<Tc, Td> && sM != 0 &&
553  sN != 0 && sK != 0)>> {
554  static_assert(
555  is_combination_valid_amd_gfx90a<Ta, Tc>(sM, sN, sK),
556  "Invalid parameters for AMD gfx90a, query valid combinations "
557  "using: "
558  "q.get_device().get_info<sycl::info::device::matrix::combinations>()");
559 
560  static constexpr std::size_t M = sM;
561  static constexpr std::size_t N = sN;
562  static constexpr std::size_t K = sK;
563 
564  template <typename Group, layout Layout>
566  template <typename Group, layout Layout>
568  template <typename Group>
570  template <typename Group>
572 };
573 
577 
578 template <typename Ta, typename Tc, typename Td>
579 constexpr bool are_types_valid_cuda_sm70() {
580  return (std::is_same_v<Ta, half> && std::is_same_v<Tc, float> &&
581  std::is_same_v<Td, float>) ||
582  (std::is_same_v<Ta, half> && std::is_same_v<Tc, half> &&
583  std::is_same_v<Td, half>) ||
584  (std::is_same_v<Ta, half> && std::is_same_v<Tc, float> &&
585  std::is_same_v<Td, half>) ||
586  (std::is_same_v<Ta, half> && std::is_same_v<Tc, half> &&
587  std::is_same_v<Td, float>);
588 }
589 
// Checks whether the input type Ta, accumulator type Tc and result type Td
// alone (no sizes) match the integer combinations added for nvidia sm72:
// Ta is int8_t or uint8_t, and both Tc and Td are int32_t.
template <typename Ta, typename Tc, typename Td>
constexpr bool are_types_valid_cuda_sm72() {
  constexpr bool AIsInt8 =
      std::is_same_v<Ta, int8_t> || std::is_same_v<Ta, uint8_t>;
  constexpr bool Int32Accum =
      std::is_same_v<Tc, int32_t> && std::is_same_v<Td, int32_t>;

  return AIsInt8 && Int32Accum;
}
597 
598 template <typename Ta, typename Tc, typename Td>
599 constexpr bool are_types_valid_cuda_sm80() {
600  return (std::is_same_v<Ta, precision::tf32> && std::is_same_v<Tc, float> &&
601  std::is_same_v<Td, float>) ||
602  (std::is_same_v<Ta, bfloat16> && std::is_same_v<Tc, float> &&
603  std::is_same_v<Td, float>) ||
604  (std::is_same_v<Ta, double> && std::is_same_v<Tc, double> &&
605  std::is_same_v<Td, double>);
606 }
607 
608 template <typename Ta, typename Tc, typename Td>
609 constexpr bool is_combination_valid_cuda_sm70(size_t sM, size_t sN, size_t sK) {
610  return are_types_valid_cuda_sm70<Ta, Tc, Td>() &&
611  ((sM == 8 && sN == 32 && sK == 16) ||
612  (sM == 16 && sN == 16 && sK == 16) ||
613  (sM == 32 && sN == 8 && sK == 16));
614 }
615 
616 template <typename Ta, typename Tc, typename Td>
617 constexpr bool is_combination_valid_cuda_sm72(size_t sM, size_t sN, size_t sK) {
618  return are_types_valid_cuda_sm72<Ta, Tc, Td>() &&
619  ((sM == 8 && sN == 32 && sK == 16) ||
620  (sM == 16 && sN == 16 && sK == 16) ||
621  (sM == 32 && sN == 8 && sK == 16));
622 }
623 
624 template <typename Ta, typename Tc, typename Td>
625 constexpr bool is_combination_valid_cuda_sm80(size_t sM, size_t sN, size_t sK) {
626  return ((std::is_same_v<Ta, precision::tf32> && std::is_same_v<Tc, float> &&
627  std::is_same_v<Td, float>)&&(sM == 16 && sN == 16 && sK == 8)) ||
628  ((std::is_same_v<Ta, bfloat16> && std::is_same_v<Tc, float> &&
629  std::is_same_v<Td, float>)&&((sM == 16 && sN == 16 && sK == 16) ||
630  (sM == 8 && sN == 32 && sK == 16) ||
631  (sM == 32 && sN == 8 && sK == 16))) ||
632  ((std::is_same_v<Ta, double> && std::is_same_v<Tc, double> &&
633  std::is_same_v<Td, double>)&&(sM == 8 && sN == 8 && sK == 4));
634 }
635 
636 // Default-values query (nvidia sm70):
637 // Specialization for when only types are given, need to query only sizes
638 template <typename Ta, typename Tb, typename Tc, typename Td>
640  architecture::nvidia_gpu_sm_70, Ta, Tb, Tc, Td, 0, 0, 0,
641  typename std::enable_if_t<(
642  !std::is_same_v<Ta, void> && !std::is_same_v<Tb, void> &&
643  !std::is_same_v<Tc, void> && !std::is_same_v<Td, void> &&
644  std::is_same_v<Ta, Tb>)>> {
645  static_assert(
646  are_types_valid_cuda_sm70<Ta, Tc, Td>(),
647  "Invalid types for nvidia sm70, supported types are half and float ");
648 
649  // Default sizes for nvidia sm70 were chosen to represent a square matrix
650  static constexpr std::size_t M = 16;
651  static constexpr std::size_t N = 16;
652  static constexpr std::size_t K = 16;
653 
654  template <typename Group, layout Layout>
656  template <typename Group, layout Layout>
658  template <typename Group>
660  template <typename Group>
662 };
663 
664 // Default-values query (nvidia sm72):
665 // Specialization for when only types are given, need to query only sizes
666 template <typename Ta, typename Tb, typename Tc, typename Td>
668  architecture::nvidia_gpu_sm_72, Ta, Tb, Tc, Td, 0, 0, 0,
669  typename std::enable_if<(
670  !std::is_same_v<Ta, void> && !std::is_same_v<Tb, void> &&
671  !std::is_same_v<Tc, void> && !std::is_same_v<Td, void> &&
672  std::is_same_v<Ta, Tb>)>::type> {
673  static_assert(
674  are_types_valid_cuda_sm70<Ta, Tc, Td>() ||
675  are_types_valid_cuda_sm72<Ta, Tc, Td>(),
676  "Invalid types for nvidia sm72, supported types are half, float "
677  "int8_t, uint8_t and int32_t ");
678 
679  static constexpr std::size_t M = 16;
680  static constexpr std::size_t N = 16;
681  static constexpr std::size_t K = 16;
682 
683  template <typename Group, layout Layout>
685  template <typename Group, layout Layout>
687  template <typename Group>
689  template <typename Group>
691 };
692 
693 // Default-values query (nvidia sm80):
694 // Specialization for when only types are given, need to query only sizes
695 template <typename Ta, typename Tb, typename Tc, typename Td>
697  architecture::nvidia_gpu_sm_80, Ta, Tb, Tc, Td, 0, 0, 0,
698  typename std::enable_if_t<(
699  !std::is_same_v<Ta, void> && !std::is_same_v<Tb, void> &&
700  !std::is_same_v<Tc, void> && !std::is_same_v<Td, void> &&
701  std::is_same_v<Ta, Tb>)>> {
702  static_assert(
703  are_types_valid_cuda_sm70<Ta, Tc, Td>() ||
704  are_types_valid_cuda_sm72<Ta, Tc, Td>() ||
705  are_types_valid_cuda_sm80<Ta, Tc, Td>(),
706  "Invalid types for nvidia sm80, supported types are half, float "
707  "int8_t, uint8_t, int32_t, double, tf32 and bfloat16 ");
708 
709  static constexpr std::size_t M = (sizeof(Ta) == 8) ? 8 : 16;
710  static constexpr std::size_t N = (sizeof(Ta) == 8) ? 8 : 16;
711  static constexpr std::size_t K =
712  std::is_same_v<Ta, precision::tf32> ? 8 : (sizeof(Ta) == 8 ? 4 : 16);
713 
714  template <typename Group, layout Layout>
716  template <typename Group, layout Layout>
718  template <typename Group>
720  template <typename Group>
722 };
723 
724 // Validation query (nvidia sm70)
725 // Specialization when both types and sizes are given
726 template <typename Ta, typename Tb, typename Tc, typename Td, size_t sM,
727  size_t sN, size_t sK>
729  architecture::nvidia_gpu_sm_70, Ta, Tb, Tc, Td, sM, sN, sK,
730  typename std::enable_if_t<(
731  !std::is_same_v<Ta, void> && !std::is_same_v<Tb, void> &&
732  !std::is_same_v<Tc, void> && !std::is_same_v<Td, void> &&
733  std::is_same_v<Ta, Tb> && sM != 0 && sN != 0 && sK != 0)>> {
734  static_assert(
735  is_combination_valid_cuda_sm70<Ta, Tc, Td>(sM, sN, sK),
736  "Invalid parameters for nvidia sm70, query valid combinations "
737  "using: "
738  "q.get_device().get_info<sycl::info::device::matrix::combinations>()");
739 
740  static constexpr std::size_t M = sM;
741  static constexpr std::size_t N = sN;
742  static constexpr std::size_t K = sK;
743 
744  template <typename Group, layout Layout>
746  template <typename Group, layout Layout>
748  template <typename Group>
750  template <typename Group>
752 };
753 
754 // Validation query (nvidia sm72)
755 // Specialization when both types and sizes are given
756 template <typename Ta, typename Tb, typename Tc, typename Td, size_t sM,
757  size_t sN, size_t sK>
759  architecture::nvidia_gpu_sm_72, Ta, Tb, Tc, Td, sM, sN, sK,
760  typename std::enable_if_t<(
761  !std::is_same_v<Ta, void> && !std::is_same_v<Tb, void> &&
762  !std::is_same_v<Tc, void> && !std::is_same_v<Td, void> &&
763  std::is_same_v<Ta, Tb> && sM != 0 && sN != 0 && sK != 0)>> {
764  static_assert(
765  is_combination_valid_cuda_sm70<Ta, Tc, Td>(sM, sN, sK) ||
766  is_combination_valid_cuda_sm72<Ta, Tc, Td>(sM, sN, sK),
767  "Invalid parameters for nvidia sm72, query valid combinations "
768  "using: "
769  "q.get_device().get_info<sycl::info::device::matrix::combinations>()");
770 
771  static constexpr std::size_t M = sM;
772  static constexpr std::size_t N = sN;
773  static constexpr std::size_t K = sK;
774 
775  template <typename Group, layout Layout>
777  template <typename Group, layout Layout>
779  template <typename Group>
781  template <typename Group>
783 };
784 
785 // Validation query (nvidia sm80)
786 // Specialization when both types and sizes are given
787 template <typename Ta, typename Tb, typename Tc, typename Td, size_t sM,
788  size_t sN, size_t sK>
790  architecture::nvidia_gpu_sm_80, Ta, Tb, Tc, Td, sM, sN, sK,
791  typename std::enable_if_t<(
792  !std::is_same_v<Ta, void> && !std::is_same_v<Tb, void> &&
793  !std::is_same_v<Tc, void> && !std::is_same_v<Td, void> &&
794  std::is_same_v<Ta, Tb> && sM != 0 && sN != 0 && sK != 0)>> {
795  static_assert(
796  is_combination_valid_cuda_sm70<Ta, Tc, Td>(sM, sN, sK) ||
797  is_combination_valid_cuda_sm72<Ta, Tc, Td>(sM, sN, sK) ||
798  is_combination_valid_cuda_sm80<Ta, Tc, Td>(sM, sN, sK),
799  "Invalid parameters for nvidia sm80, query valid combinations "
800  "using: "
801  "q.get_device().get_info<sycl::info::device::matrix::combinations>()");
802 
803  static constexpr std::size_t M = sM;
804  static constexpr std::size_t N = sN;
805  static constexpr std::size_t K = sK;
806 
807  template <typename Group, layout Layout>
809  template <typename Group, layout Layout>
811  template <typename Group>
813  template <typename Group>
815 };
816 
817 } // namespace experimental::matrix
818 } // namespace oneapi
819 } // namespace ext
820 } // namespace _V1
821 } // namespace sycl
constexpr bool is_combination_valid_cuda_sm72(size_t sM, size_t sN, size_t sK)
constexpr bool is_combination_valid_amx(size_t sM, size_t sN, size_t sK)
constexpr bool are_types_valid_cuda_sm70()
CUDA Tensor Cores - sm70, sm72 and sm80 ///.
constexpr bool is_combination_valid_xmx16(size_t sM, size_t sN, size_t sK)
constexpr bool is_combination_valid_cuda_sm80(size_t sM, size_t sN, size_t sK)
constexpr bool is_combination_valid_xmx8(size_t sM, size_t sN, size_t sK)
constexpr bool is_combination_valid_amd_gfx90a(size_t sM, size_t sN, size_t sK)
AMD Matrix Cores - GFX90A architecture ///.
constexpr bool is_combination_valid_cuda_sm70(size_t sM, size_t sN, size_t sK)
Definition: access.hpp:18