DPC++ Runtime
Runtime libraries for oneAPI DPC++
matrix-tensorcore.hpp
1 //===---- matrix-tensorcore.hpp - SYCL tensor cores matrix ----*- C++ -*---===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 // ===--------------------------------------------------------------------=== //
8 
9 #pragma once
10 
11 __SYCL_INLINE_NAMESPACE(cl) {
12 namespace sycl {
13 namespace ext {
14 namespace oneapi {
15 namespace experimental::matrix {
16 
17 enum class matrix_use { a, b, accumulator };
18 
19 enum class matrix_layout { row_major, col_major, packed_a, packed_b };
20 
21 namespace precision {
22 class tf32 {};
23 } // namespace precision
24 
25 template <typename T, matrix_use Use, size_t Rows = sycl::dynamic_extent,
26  size_t Cols = sycl::dynamic_extent,
27  matrix_layout Layout = matrix_layout::row_major,
28  typename Group = sycl::sub_group, typename Cond = void>
29 struct joint_matrix;
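// Illustrative sketch (editorial, not part of the header): with the
// specializations instantiated below, the fragments of a 16x16x16 half/float
// multiply would be declared roughly as follows inside a kernel; the names
// a_frag/b_frag/c_frag are made up for the example:
//
//   joint_matrix<half, matrix_use::a, 16, 16> a_frag;
//   joint_matrix<half, matrix_use::b, 16, 16> b_frag;
//   joint_matrix<float, matrix_use::accumulator, 16, 16> c_frag;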
30 
31 #define __SYCL_JOINT_MATRIX_OVERLOAD(type, use, M, N, frag_type, frag_size) \
32  template <matrix_layout Layout> \
33  struct joint_matrix< \
34  type, matrix_use::use, M, N, Layout, sycl::sub_group, \
35  typename std::enable_if_t<Layout == matrix_layout::row_major || \
36  Layout == matrix_layout::col_major>> { \
37  frag_type data[frag_size]; \
38  };
39 
40 // m8n8k4 double only
41 __SYCL_JOINT_MATRIX_OVERLOAD(double, a, 8, 4, double, 1)
42 __SYCL_JOINT_MATRIX_OVERLOAD(double, b, 4, 8, double, 1)
43 __SYCL_JOINT_MATRIX_OVERLOAD(double, accumulator, 8, 8, double, 2)
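// For reference (editorial): the first overload above expands via the macro to
// roughly the following specialization:
//
//   template <matrix_layout Layout>
//   struct joint_matrix<double, matrix_use::a, 8, 4, Layout, sycl::sub_group,
//                       typename std::enable_if_t<
//                           Layout == matrix_layout::row_major ||
//                           Layout == matrix_layout::col_major>> {
//     double data[1];
//   };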
44 
45 // m8n32k16
46 // bf16 data format uses uint16_t data type
47 __SYCL_JOINT_MATRIX_OVERLOAD(uint16_t, a, 8, 16, int32_t, 2)
48 __SYCL_JOINT_MATRIX_OVERLOAD(uint16_t, b, 16, 32, int32_t, 8)
49 __SYCL_JOINT_MATRIX_OVERLOAD(half, a, 8, 16, int32_t, 8)
50 __SYCL_JOINT_MATRIX_OVERLOAD(half, b, 16, 32, int32_t, 8)
51 __SYCL_JOINT_MATRIX_OVERLOAD(float, accumulator, 8, 32, float, 8)
52 __SYCL_JOINT_MATRIX_OVERLOAD(half, accumulator, 8, 32, int32_t, 4)
53 
54 __SYCL_JOINT_MATRIX_OVERLOAD(int8_t, a, 8, 16, int32_t, 1)
55 __SYCL_JOINT_MATRIX_OVERLOAD(int8_t, b, 16, 32, int32_t, 4)
56 __SYCL_JOINT_MATRIX_OVERLOAD(uint8_t, a, 8, 16, int32_t, 1)
57 __SYCL_JOINT_MATRIX_OVERLOAD(uint8_t, b, 16, 32, int32_t, 4)
58 __SYCL_JOINT_MATRIX_OVERLOAD(int32_t, accumulator, 8, 32, int32_t, 8)
59 
60 // m32n8k16
61 __SYCL_JOINT_MATRIX_OVERLOAD(uint16_t, a, 32, 16, int32_t, 8)
62 __SYCL_JOINT_MATRIX_OVERLOAD(uint16_t, b, 16, 8, int32_t, 2)
63 __SYCL_JOINT_MATRIX_OVERLOAD(half, a, 32, 16, int32_t, 8)
64 __SYCL_JOINT_MATRIX_OVERLOAD(half, b, 16, 8, int32_t, 8)
65 __SYCL_JOINT_MATRIX_OVERLOAD(float, accumulator, 32, 8, float, 8)
66 __SYCL_JOINT_MATRIX_OVERLOAD(half, accumulator, 32, 8, int32_t, 4)
67 
68 __SYCL_JOINT_MATRIX_OVERLOAD(int8_t, a, 32, 16, int32_t, 4)
69 __SYCL_JOINT_MATRIX_OVERLOAD(int8_t, b, 16, 8, int32_t, 1)
70 __SYCL_JOINT_MATRIX_OVERLOAD(uint8_t, a, 32, 16, int32_t, 4)
71 __SYCL_JOINT_MATRIX_OVERLOAD(uint8_t, b, 16, 8, int32_t, 1)
72 __SYCL_JOINT_MATRIX_OVERLOAD(int32_t, accumulator, 32, 8, int32_t, 8)
73 
74 // m16n16k16
75 __SYCL_JOINT_MATRIX_OVERLOAD(uint16_t, a, 16, 16, int32_t, 4)
76 __SYCL_JOINT_MATRIX_OVERLOAD(uint16_t, b, 16, 16, int32_t, 4)
77 __SYCL_JOINT_MATRIX_OVERLOAD(half, a, 16, 16, int32_t, 8)
78 __SYCL_JOINT_MATRIX_OVERLOAD(half, b, 16, 16, int32_t, 8)
79 __SYCL_JOINT_MATRIX_OVERLOAD(float, accumulator, 16, 16, float, 8)
80 __SYCL_JOINT_MATRIX_OVERLOAD(half, accumulator, 16, 16, int32_t, 4)
81 
82 __SYCL_JOINT_MATRIX_OVERLOAD(int8_t, a, 16, 16, int32_t, 2)
83 __SYCL_JOINT_MATRIX_OVERLOAD(int8_t, b, 16, 16, int32_t, 2)
84 __SYCL_JOINT_MATRIX_OVERLOAD(uint8_t, a, 16, 16, int32_t, 2)
85 __SYCL_JOINT_MATRIX_OVERLOAD(uint8_t, b, 16, 16, int32_t, 2)
86 __SYCL_JOINT_MATRIX_OVERLOAD(int32_t, accumulator, 16, 16, int32_t, 8)
87 
88 // m16n16k8 tf32
89 __SYCL_JOINT_MATRIX_OVERLOAD(precision::tf32, a, 16, 8, float, 4)
90 __SYCL_JOINT_MATRIX_OVERLOAD(precision::tf32, b, 8, 16, float, 4)
91 
92 #undef __SYCL_JOINT_MATRIX_OVERLOAD
93 } // namespace experimental::matrix
94 
95 namespace detail {
96 
97 template <typename S, typename T,
98  sycl::ext::oneapi::experimental::matrix::matrix_use Use,
99  size_t NumRows, size_t NumCols,
100  sycl::ext::oneapi::experimental::matrix::matrix_layout Layout,
101  access::address_space Space, typename Cond = void>
102 struct joint_matrix_load_impl {
103  void load(sycl::ext::oneapi::experimental::matrix::joint_matrix<
104  S, Use, NumRows, NumCols, Layout, sycl::sub_group> &res,
105  multi_ptr<T, Space> src, size_t stride);
106 };
107 
108 template <sycl::ext::oneapi::experimental::matrix::matrix_layout Layout>
109 constexpr int get_layout_id();
110 
111 template <>
112 constexpr int get_layout_id<
113  sycl::ext::oneapi::experimental::matrix::matrix_layout::row_major>() {
114  return 0;
115 }
116 
117 template <>
118 constexpr int get_layout_id<
119  sycl::ext::oneapi::experimental::matrix::matrix_layout::col_major>() {
120  return 1;
121 }
122 
123 template <typename S, typename T,
124  sycl::ext::oneapi::experimental::matrix::matrix_use Use,
125  size_t NumRows, size_t NumCols,
126  sycl::ext::oneapi::experimental::matrix::matrix_layout Layout,
127  access::address_space Space>
128 struct joint_matrix_load_impl<
129  S, T, Use, NumRows, NumCols, Layout, Space,
130  typename std::enable_if_t<Layout == sycl::ext::oneapi::experimental::
131  matrix::matrix_layout::row_major ||
132  Layout == sycl::ext::oneapi::experimental::
133  matrix::matrix_layout::col_major>> {
134  void load(sycl::ext::oneapi::experimental::matrix::joint_matrix<
135  S, Use, NumRows, NumCols, Layout, sycl::sub_group> &res,
136  multi_ptr<T, Space> src, size_t stride) {
137  if constexpr (std::is_same<T, uint16_t>::value) {
138  int32_t *tileptr = reinterpret_cast<int32_t *>(src.get());
139  if constexpr (NumRows == 16 && NumCols == 16) {
140  if constexpr (Use ==
141  sycl::ext::oneapi::experimental::matrix::matrix_use::a) {
142  __mma_bf16_m16n16k16_ld_a(res.data, tileptr, stride,
143  get_layout_id<Layout>());
144  } else if constexpr (Use == sycl::ext::oneapi::experimental::matrix::
145  matrix_use::b) {
146  __mma_bf16_m16n16k16_ld_b(res.data, tileptr, stride,
147  get_layout_id<Layout>());
148  }
149  } else if constexpr (NumRows == 8 && NumCols == 16) {
150  __mma_bf16_m8n32k16_ld_a(res.data, tileptr, stride,
151  get_layout_id<Layout>());
152  } else if constexpr (NumRows == 16 && NumCols == 32) {
153  __mma_bf16_m8n32k16_ld_b(res.data, tileptr, stride,
154  get_layout_id<Layout>());
155  } else if constexpr (NumRows == 32 && NumCols == 16) {
156  __mma_bf16_m32n8k16_ld_a(res.data, tileptr, stride,
157  get_layout_id<Layout>());
158  } else if constexpr (NumRows == 16 && NumCols == 8) {
159  __mma_bf16_m32n8k16_ld_b(res.data, tileptr, stride,
160  get_layout_id<Layout>());
161  }
162  } else if constexpr (std::is_same<T, uint8_t>::value) {
163  int32_t *tileptr = reinterpret_cast<int32_t *>(src.get());
164  if constexpr (NumRows == 16 && NumCols == 16) {
165  if constexpr (Use ==
166  sycl::ext::oneapi::experimental::matrix::matrix_use::a) {
167  __imma_m16n16k16_ld_a_u8(res.data, tileptr, stride,
168  get_layout_id<Layout>());
169  } else if constexpr (Use == sycl::ext::oneapi::experimental::matrix::
170  matrix_use::b) {
171  __imma_m16n16k16_ld_b_u8(res.data, tileptr, stride,
172  get_layout_id<Layout>());
173  }
174  } else if constexpr (NumRows == 8 && NumCols == 16) {
175  __imma_m8n32k16_ld_a_u8(res.data, tileptr, stride,
176  get_layout_id<Layout>());
177  } else if constexpr (NumRows == 16 && NumCols == 32) {
178  __imma_m8n32k16_ld_b_u8(res.data, tileptr, stride,
179  get_layout_id<Layout>());
180  } else if constexpr (NumRows == 32 && NumCols == 16) {
181  __imma_m32n8k16_ld_a_u8(res.data, tileptr, stride,
182  get_layout_id<Layout>());
183  } else if constexpr (NumRows == 16 && NumCols == 8) {
184  __imma_m32n8k16_ld_b_u8(res.data, tileptr, stride,
185  get_layout_id<Layout>());
186  }
187  } else if constexpr (std::is_same<T, int8_t>::value) {
188  int32_t *tileptr = reinterpret_cast<int32_t *>(src.get());
189  if constexpr (NumRows == 16 && NumCols == 16) {
190  if constexpr (Use ==
191  sycl::ext::oneapi::experimental::matrix::matrix_use::a) {
192  __imma_m16n16k16_ld_a_s8(res.data, tileptr, stride,
193  get_layout_id<Layout>());
194  } else if constexpr (Use == sycl::ext::oneapi::experimental::matrix::
195  matrix_use::b) {
196  __imma_m16n16k16_ld_b_s8(res.data, tileptr, stride,
197  get_layout_id<Layout>());
198  }
199  } else if constexpr (NumRows == 8 && NumCols == 16) {
200  __imma_m8n32k16_ld_a_s8(res.data, tileptr, stride,
201  get_layout_id<Layout>());
202  } else if constexpr (NumRows == 16 && NumCols == 32) {
203  __imma_m8n32k16_ld_b_s8(res.data, tileptr, stride,
204  get_layout_id<Layout>());
205  } else if constexpr (NumRows == 32 && NumCols == 16) {
206  __imma_m32n8k16_ld_a_s8(res.data, tileptr, stride,
207  get_layout_id<Layout>());
208  } else if constexpr (NumRows == 16 && NumCols == 8) {
209  __imma_m32n8k16_ld_b_s8(res.data, tileptr, stride,
210  get_layout_id<Layout>());
211  }
212  } else if constexpr (std::is_same<T, half>::value) {
213  int32_t *tileptr = reinterpret_cast<int32_t *>(src.get());
214  if constexpr (NumRows == 16 && NumCols == 16) {
215  if constexpr (Use ==
216  sycl::ext::oneapi::experimental::matrix::matrix_use::a) {
217  __hmma_m16n16k16_ld_a(res.data, tileptr, stride,
218  get_layout_id<Layout>());
219  } else if constexpr (Use == sycl::ext::oneapi::experimental::matrix::
220  matrix_use::b) {
221  __hmma_m16n16k16_ld_b(res.data, tileptr, stride,
222  get_layout_id<Layout>());
223  } else if constexpr (Use == sycl::ext::oneapi::experimental::matrix::
224  matrix_use::accumulator) {
225  __hmma_m16n16k16_ld_c_f16(res.data, tileptr, stride,
226  get_layout_id<Layout>());
227  }
228  } else if constexpr (NumRows == 8 && NumCols == 16) {
229  __hmma_m8n32k16_ld_a(res.data, tileptr, stride,
230  get_layout_id<Layout>());
231  } else if constexpr (NumRows == 16 && NumCols == 32) {
232  __hmma_m8n32k16_ld_b(res.data, tileptr, stride,
233  get_layout_id<Layout>());
234  } else if constexpr (NumRows == 32 && NumCols == 16) {
235  __hmma_m32n8k16_ld_a(res.data, tileptr, stride,
236  get_layout_id<Layout>());
237  } else if constexpr (NumRows == 16 && NumCols == 8) {
238  __hmma_m32n8k16_ld_b(res.data, tileptr, stride,
239  get_layout_id<Layout>());
240  } else if constexpr (NumRows == 32 && NumCols == 8) {
241  __hmma_m32n8k16_ld_c_f16(res.data, tileptr, stride,
242  get_layout_id<Layout>());
243  } else if constexpr (NumRows == 8 && NumCols == 32) {
244  __hmma_m8n32k16_ld_c_f16(res.data, tileptr, stride,
245  get_layout_id<Layout>());
246  }
247 
248  } else if constexpr (std::is_same<T, int32_t>::value) {
249  if constexpr (NumRows == 16 && NumCols == 16) {
250  __imma_m16n16k16_ld_c(res.data, src.get(), stride,
251  get_layout_id<Layout>());
252  } else if constexpr (NumRows == 8 && NumCols == 32) {
253  __imma_m8n32k16_ld_c(res.data, src.get(), stride,
254  get_layout_id<Layout>());
255  } else if constexpr (NumRows == 32 && NumCols == 8) {
256  __imma_m32n8k16_ld_c(res.data, src.get(), stride,
257  get_layout_id<Layout>());
258  }
259  } else if constexpr (std::is_same<T, float>::value) {
260  if (std::is_same<S, float>::value) {
261  if constexpr (NumRows == 16 && NumCols == 16) {
262  __hmma_m16n16k16_ld_c_f32(res.data, src.get(), stride,
263  get_layout_id<Layout>());
264  } else if constexpr (NumRows == 8 && NumCols == 32) {
265  __hmma_m8n32k16_ld_c_f32(res.data, src.get(), stride,
266  get_layout_id<Layout>());
267  } else if constexpr (NumRows == 32 && NumCols == 8) {
268  __hmma_m32n8k16_ld_c_f32(res.data, src.get(), stride,
269  get_layout_id<Layout>());
270  }
271  } else if (std::is_same<S, sycl::ext::oneapi::experimental::matrix::
272  precision::tf32>::value) {
273  int32_t *tileptr = reinterpret_cast<int32_t *>(src.get());
274  if constexpr (NumRows == 16 && NumCols == 8) {
275  __mma_tf32_m16n16k8_ld_a(reinterpret_cast<int32_t *>(res.data),
276  tileptr, stride, get_layout_id<Layout>());
277  } else if constexpr (NumRows == 8 && NumCols == 16) {
278  __mma_tf32_m16n16k8_ld_b(reinterpret_cast<int32_t *>(res.data),
279  tileptr, stride, get_layout_id<Layout>());
280  }
281  }
282  } else if constexpr (std::is_same<T, double>::value) {
283  if constexpr (Use ==
284  sycl::ext::oneapi::experimental::matrix::matrix_use::a) {
285  __dmma_m8n8k4_ld_a(res.data, src.get(), stride,
286  get_layout_id<Layout>());
287  } else if constexpr (Use == sycl::ext::oneapi::experimental::matrix::
288  matrix_use::b) {
289  __dmma_m8n8k4_ld_b(res.data, src.get(), stride,
290  get_layout_id<Layout>());
291  } else if constexpr (Use == sycl::ext::oneapi::experimental::matrix::
292  matrix_use::accumulator) {
293  __dmma_m8n8k4_ld_c(res.data, src.get(), stride,
294  get_layout_id<Layout>());
295  }
296  }
297  }
298 };
299 
300 template <typename T, size_t NumRows, size_t NumCols,
301  sycl::ext::oneapi::experimental::matrix::matrix_layout Layout,
302  access::address_space Space, typename Cond = void>
303 struct joint_matrix_store_impl {
304  void
305  store(sycl::ext::oneapi::experimental::matrix::joint_matrix<
306  T, sycl::ext::oneapi::experimental::matrix::matrix_use::accumulator,
307  NumRows, NumCols, Layout, sycl::sub_group> &src,
308  multi_ptr<T, Space> dst, size_t stride);
309 };
310 
311 template <typename T, size_t NumRows, size_t NumCols,
312  sycl::ext::oneapi::experimental::matrix::matrix_layout Layout,
313  access::address_space Space>
314 struct joint_matrix_store_impl<
315  T, NumRows, NumCols, Layout, Space,
316  typename std::enable_if_t<Layout == sycl::ext::oneapi::experimental::
317  matrix::matrix_layout::row_major ||
318  Layout == sycl::ext::oneapi::experimental::
319  matrix::matrix_layout::col_major>> {
320  void
321  store(sycl::ext::oneapi::experimental::matrix::joint_matrix<
322  T, sycl::ext::oneapi::experimental::matrix::matrix_use::accumulator,
323  NumRows, NumCols, Layout, sycl::sub_group> &src,
324  multi_ptr<T, Space> dst, size_t stride) {
325  if (NumRows == 16 && NumCols == 16) {
326  if constexpr (std::is_same<T, float>::value) {
327  __hmma_m16n16k16_st_c_f32(dst.get(), src.data, stride,
328  get_layout_id<Layout>());
329  } else if constexpr (std::is_same<T, int32_t>::value) {
330  __imma_m16n16k16_st_c_i32(dst.get(), src.data, stride,
331  get_layout_id<Layout>());
332  } else if constexpr (std::is_same<T, half>::value) {
333  int32_t *tileptr = reinterpret_cast<int32_t *>(dst.get());
334  __hmma_m16n16k16_st_c_f16(tileptr, src.data, stride,
335  get_layout_id<Layout>());
336  }
337  } else if (NumRows == 8 && NumCols == 32) {
338  if constexpr (std::is_same<T, float>::value) {
339  __hmma_m8n32k16_st_c_f32(dst.get(), src.data, stride,
340  get_layout_id<Layout>());
341  } else if constexpr (std::is_same<T, int32_t>::value) {
342  __imma_m8n32k16_st_c_i32(dst.get(), src.data, stride,
343  get_layout_id<Layout>());
344  } else if constexpr (std::is_same<T, half>::value) {
345  int32_t *tileptr = reinterpret_cast<int32_t *>(dst.get());
346  __hmma_m8n32k16_st_c_f16(tileptr, src.data, stride,
347  get_layout_id<Layout>());
348  }
349  } else if (NumRows == 32 && NumCols == 8) {
350  if constexpr (std::is_same<T, float>::value) {
351  __hmma_m32n8k16_st_c_f32(dst.get(), src.data, stride,
352  get_layout_id<Layout>());
353  } else if constexpr (std::is_same<T, int32_t>::value) {
354  __imma_m32n8k16_st_c_i32(dst.get(), src.data, stride,
355  get_layout_id<Layout>());
356  } else if constexpr (std::is_same<T, half>::value) {
357  int32_t *tileptr = reinterpret_cast<int32_t *>(dst.get());
358  __hmma_m32n8k16_st_c_f16(tileptr, src.data, stride,
359  get_layout_id<Layout>());
360  }
361  } else if constexpr (std::is_same<T, double>::value) {
362  __dmma_m8n8k4_st_c_f64(dst.get(), src.data, stride,
363  get_layout_id<Layout>());
364  }
365  }
366 };
367 
368 template <typename T1, typename T2, std::size_t M, std::size_t K, std::size_t N,
369  sycl::ext::oneapi::experimental::matrix::matrix_layout LayoutA,
370  sycl::ext::oneapi::experimental::matrix::matrix_layout LayoutB,
371  sycl::ext::oneapi::experimental::matrix::matrix_layout LayoutC,
372  typename Cond = void>
373 struct joint_matrix_mad_impl {
374  sycl::ext::oneapi::experimental::matrix::joint_matrix<
375  T2, sycl::ext::oneapi::experimental::matrix::matrix_use::accumulator, M,
376  N, LayoutC, sycl::sub_group>
377  mad(sycl::ext::oneapi::experimental::matrix::joint_matrix<
378  T1, sycl::ext::oneapi::experimental::matrix::matrix_use::a, M, K,
379  LayoutA, sycl::sub_group>
380  A,
381  sycl::ext::oneapi::experimental::matrix::joint_matrix<
382  T1, sycl::ext::oneapi::experimental::matrix::matrix_use::b, K, N,
383  LayoutB, sycl::sub_group>
384  B,
385  sycl::ext::oneapi::experimental::matrix::joint_matrix<
386  T2, sycl::ext::oneapi::experimental::matrix::matrix_use::accumulator,
387  M, N, LayoutC, sycl::sub_group>
388  C);
389 };
390 
391 template <sycl::ext::oneapi::experimental::matrix::matrix_layout LayoutA,
392  sycl::ext::oneapi::experimental::matrix::matrix_layout LayoutB>
393 constexpr int get_layout_pair_id();
394 
395 template <>
396 constexpr int get_layout_pair_id<
397  sycl::ext::oneapi::experimental::matrix::matrix_layout::row_major,
398  sycl::ext::oneapi::experimental::matrix::matrix_layout::row_major>() {
399  return 0;
400 }
401 
402 template <>
403 constexpr int get_layout_pair_id<
404  sycl::ext::oneapi::experimental::matrix::matrix_layout::row_major,
405  sycl::ext::oneapi::experimental::matrix::matrix_layout::col_major>() {
406  return 1;
407 }
408 
409 template <>
410 constexpr int get_layout_pair_id<
411  sycl::ext::oneapi::experimental::matrix::matrix_layout::col_major,
412  sycl::ext::oneapi::experimental::matrix::matrix_layout::row_major>() {
413  return 2;
414 }
415 
416 template <>
417 constexpr int get_layout_pair_id<
418  sycl::ext::oneapi::experimental::matrix::matrix_layout::col_major,
419  sycl::ext::oneapi::experimental::matrix::matrix_layout::col_major>() {
420  return 3;
421 }
422 
423 template <typename T1, typename T2, std::size_t M, std::size_t K, std::size_t N,
424  sycl::ext::oneapi::experimental::matrix::matrix_layout LayoutA,
425  sycl::ext::oneapi::experimental::matrix::matrix_layout LayoutB,
426  sycl::ext::oneapi::experimental::matrix::matrix_layout LayoutC>
427 struct joint_matrix_mad_impl<
428  T1, T2, M, K, N, LayoutA, LayoutB, LayoutC,
429  typename std::enable_if_t<
430  (LayoutA == sycl::ext::oneapi::experimental::matrix::matrix_layout::
431  row_major ||
432  LayoutA == sycl::ext::oneapi::experimental::matrix::matrix_layout::
433  col_major) &&
434  (LayoutB == sycl::ext::oneapi::experimental::matrix::matrix_layout::
435  row_major ||
436  LayoutB == sycl::ext::oneapi::experimental::matrix::matrix_layout::
437  col_major) &&
438  (LayoutC == sycl::ext::oneapi::experimental::matrix::matrix_layout::
439  row_major ||
440  LayoutC == sycl::ext::oneapi::experimental::matrix::matrix_layout::
441  col_major)>> {
442  sycl::ext::oneapi::experimental::matrix::joint_matrix<
443  T2, sycl::ext::oneapi::experimental::matrix::matrix_use::accumulator, M,
444  N, LayoutC, sycl::sub_group>
445  mad(sycl::ext::oneapi::experimental::matrix::joint_matrix<
446  T1, sycl::ext::oneapi::experimental::matrix::matrix_use::a, M, K,
447  LayoutA, sycl::sub_group>
448  A,
449  sycl::ext::oneapi::experimental::matrix::joint_matrix<
450  T1, sycl::ext::oneapi::experimental::matrix::matrix_use::b, K, N,
451  LayoutB, sycl::sub_group>
452  B,
453  sycl::ext::oneapi::experimental::matrix::joint_matrix<
454  T2, sycl::ext::oneapi::experimental::matrix::matrix_use::accumulator,
455  M, N, LayoutC, sycl::sub_group>
456  C) {
457  sycl::ext::oneapi::experimental::matrix::joint_matrix<
458  T2, sycl::ext::oneapi::experimental::matrix::matrix_use::accumulator, M,
459  N, LayoutC, sycl::sub_group>
460  D;
461  if constexpr (M == 16 && N == 16 && K == 16) {
462  if constexpr (std::is_same<T1, int8_t>::value) {
463  __imma_m16n16k16_mma_s8(D.data, A.data, B.data, C.data,
464  get_layout_pair_id<LayoutA, LayoutB>(), 0);
465  } else if constexpr (std::is_same<T1, uint8_t>::value) {
466  __imma_m16n16k16_mma_u8(D.data, A.data, B.data, C.data,
467  get_layout_pair_id<LayoutA, LayoutB>(), 0);
468  } else if constexpr (std::is_same<T1, half>::value) {
469  if constexpr (std::is_same<T2, float>::value) {
470  __hmma_m16n16k16_mma_f32f32(D.data, A.data, B.data, C.data,
471  get_layout_pair_id<LayoutA, LayoutB>(),
472  0);
473  } else if constexpr (std::is_same<T2, half>::value) {
474  __hmma_m16n16k16_mma_f16f16(D.data, A.data, B.data, C.data,
475  get_layout_pair_id<LayoutA, LayoutB>(),
476  0);
477  }
478  } else if constexpr (std::is_same<T1, uint16_t>::value) {
479  __mma_bf16_m16n16k16_mma_f32(D.data, A.data, B.data, C.data,
480  get_layout_pair_id<LayoutA, LayoutB>(), 0);
481  }
482  } else if constexpr (M == 8 && N == 32 && K == 16) {
483  if constexpr (std::is_same<T1, int8_t>::value) {
484  __imma_m8n32k16_mma_s8(D.data, A.data, B.data, C.data,
485  get_layout_pair_id<LayoutA, LayoutB>(), 0);
486  } else if constexpr (std::is_same<T1, uint8_t>::value) {
487  __imma_m8n32k16_mma_u8(D.data, A.data, B.data, C.data,
488  get_layout_pair_id<LayoutA, LayoutB>(), 0);
489  } else if constexpr (std::is_same<T1, half>::value) {
490  if constexpr (std::is_same<T2, float>::value) {
491  __hmma_m8n32k16_mma_f32f32(D.data, A.data, B.data, C.data,
492  get_layout_pair_id<LayoutA, LayoutB>(), 0);
493  } else if constexpr (std::is_same<T2, half>::value) {
494  __hmma_m8n32k16_mma_f16f16(D.data, A.data, B.data, C.data,
495  get_layout_pair_id<LayoutA, LayoutB>(), 0);
496  }
497  } else if constexpr (std::is_same<T1, uint16_t>::value) {
498  __mma_bf16_m8n32k16_mma_f32(D.data, A.data, B.data, C.data,
499  get_layout_pair_id<LayoutA, LayoutB>(), 0);
500  }
501  } else if constexpr (M == 32 && N == 8 && K == 16) {
502  if constexpr (std::is_same<T1, int8_t>::value) {
503  __imma_m32n8k16_mma_s8(D.data, A.data, B.data, C.data,
504  get_layout_pair_id<LayoutA, LayoutB>(), 0);
505  } else if constexpr (std::is_same<T1, uint8_t>::value) {
506  __imma_m32n8k16_mma_u8(D.data, A.data, B.data, C.data,
507  get_layout_pair_id<LayoutA, LayoutB>(), 0);
508  } else if constexpr (std::is_same<T1, uint16_t>::value) {
509  __mma_bf16_m32n8k16_mma_f32(D.data, A.data, B.data, C.data,
510  get_layout_pair_id<LayoutA, LayoutB>(), 0);
511  } else if constexpr (std::is_same<T1, half>::value) {
512  if constexpr (std::is_same<T2, float>::value) {
513  __hmma_m32n8k16_mma_f32f32(D.data, A.data, B.data, C.data,
514  get_layout_pair_id<LayoutA, LayoutB>(), 0);
515  } else if constexpr (std::is_same<T2, half>::value) {
516  __hmma_m32n8k16_mma_f16f16(D.data, A.data, B.data, C.data,
517  get_layout_pair_id<LayoutA, LayoutB>(), 0);
518  }
519  }
520  } else if constexpr (M == 16 && N == 16 && K == 8) {
521  __mma_tf32_m16n16k8_mma_f32(D.data, reinterpret_cast<int32_t *>(A.data),
522  reinterpret_cast<int32_t *>(B.data), C.data,
523  get_layout_pair_id<LayoutA, LayoutB>(), 0);
524  } else if constexpr (std::is_same<T1, double>::value) {
525  __dmma_m8n8k4_mma_f64(D.data, A.data, B.data, C.data,
526  get_layout_pair_id<LayoutA, LayoutB>(), 0);
527  }
528  return D;
529  }
530 };
531 
532 } // namespace detail
533 
534 namespace experimental::matrix {
535 
536 template <typename Group, typename S, typename T, matrix_use Use,
537  size_t NumRows, size_t NumCols, matrix_layout Layout,
538  access::address_space Space,
539  std::enable_if_t<std::is_same<S, T>::value ||
540  (std::is_same<S, precision::tf32>::value &&
541  std::is_same<T, float>::value),
542  bool> = true>
543 void joint_matrix_load(
544  Group sg, joint_matrix<S, Use, NumRows, NumCols, Layout, Group> &res,
545  multi_ptr<T, Space> src, size_t stride) {
546 #if defined(__SYCL_DEVICE_ONLY__) && defined(__NVPTX__)
547  sycl::ext::oneapi::detail::joint_matrix_load_impl<S, T, Use, NumRows, NumCols,
548  Layout, Space>{}
549  .load(res, src, stride);
550 #else
551  (void)sg;
552  (void)res;
553  (void)src;
554  (void)stride;
555  throw runtime_error(
556  "When using SYCL_EXT_ONEAPI_MATRIX=3 joint_matrix_load is "
557  "only supported by CUDA devices",
558  PI_ERROR_INVALID_DEVICE);
559 #endif // defined(__SYCL_DEVICE_ONLY__) && defined(__NVPTX__)
560 }
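// Editorial note: per the enable_if constraint on joint_matrix_load above, the
// fragment element type S may differ from the pointed-to type T only in the
// tf32 case (S == precision::tf32, T == float). An illustrative call, where
// sg, p_float and lda are assumed to exist in the caller:
//
//   joint_matrix<precision::tf32, matrix_use::a, 16, 8> a_frag;
//   joint_matrix_load(sg, a_frag, p_float, lda);  // p_float: multi_ptr<float, ...>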
561 
562 template <typename Group, typename T, size_t NumRows, size_t NumCols,
563  matrix_layout Layout, access::address_space Space>
564 void joint_matrix_store(Group sg,
565  joint_matrix<T, matrix_use::accumulator, NumRows,
566  NumCols, Layout, Group> &src,
567  multi_ptr<T, Space> dst, size_t stride) {
568 #if defined(__SYCL_DEVICE_ONLY__) && defined(__NVPTX__)
569  sycl::ext::oneapi::detail::joint_matrix_store_impl<T, NumRows, NumCols,
570  Layout, Space>{}
571  .store(src, dst, stride);
572 #else
573  (void)sg;
574  (void)src;
575  (void)dst;
576  (void)stride;
577  throw runtime_error(
578  "When using SYCL_EXT_ONEAPI_MATRIX=3 joint_matrix_store is "
579  "only supported by CUDA devices",
580  PI_ERROR_INVALID_DEVICE);
581 #endif // defined(__SYCL_DEVICE_ONLY__) && defined(__NVPTX__)
582 }
583 
584 template <typename Group, typename T1, typename T2, std::size_t M,
585  std::size_t K, std::size_t N, matrix_layout LayoutA,
586  matrix_layout LayoutB, matrix_layout LayoutC>
587 joint_matrix<T2, matrix_use::accumulator, M, N, LayoutC, Group>
588 joint_matrix_mad(
589  Group sg, joint_matrix<T1, matrix_use::a, M, K, LayoutA, Group> A,
590  joint_matrix<T1, matrix_use::b, K, N, LayoutB, Group> B,
591  joint_matrix<T2, matrix_use::accumulator, M, N, LayoutC, Group> C) {
592 #if defined(__SYCL_DEVICE_ONLY__) && defined(__NVPTX__)
593  return sycl::ext::oneapi::detail::joint_matrix_mad_impl<
594  T1, T2, M, K, N, LayoutA, LayoutB, LayoutC>{}
595  .mad(A, B, C);
596 #else
597  (void)sg;
598  (void)A;
599  (void)B;
600  (void)C;
601  throw runtime_error("When using SYCL_EXT_ONEAPI_MATRIX=3 joint_matrix_mad is "
602  "only supported by CUDA devices",
603  PI_ERROR_INVALID_DEVICE);
604 #endif // defined(__SYCL_DEVICE_ONLY__) && defined(__NVPTX__)
605 }
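// Editorial sketch of how the three entry points above combine inside a kernel
// (illustrative only; sg, the fragments, the multi_ptrs and the strides are
// assumed to be set up by the caller):
//
//   joint_matrix_load(sg, a_frag, p_a, lda);
//   joint_matrix_load(sg, b_frag, p_b, ldb);
//   joint_matrix_load(sg, c_frag, p_c, ldc);
//   c_frag = joint_matrix_mad(sg, a_frag, b_frag, c_frag);
//   joint_matrix_store(sg, c_frag, p_c, ldc);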
606 
607 // This function rounds the bottom 13 mantissa bits of a float up or down to
608 // the nearest tf32 value and then zeros out those bottom bits.
609 float round_to_tf32(float a) {
610 #if defined(__SYCL_DEVICE_ONLY__) && defined(__NVPTX__)
611  int32_t tmp_int = __nvvm_f2tf32_rna(a);
612  return __nvvm_bitcast_i2f(tmp_int);
613 #else
614  uint32_t tmp_uint = reinterpret_cast<uint32_t &>(a);
615  tmp_uint += 0x1000u;
616  tmp_uint &= 0xFFFFE000u;
617  float ret = reinterpret_cast<float &>(tmp_uint);
618  return ret;
619 #endif // defined(__SYCL_DEVICE_ONLY__) && defined(__NVPTX__)
620 }
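// Worked example for the host fallback above (editorial): the bit pattern
// 0x3F800FFF becomes 0x3F801FFF after adding 0x1000 and 0x3F800000 after
// masking with 0xFFFFE000 (the low 13 mantissa bits round down), whereas
// 0x3F801000 becomes 0x3F802000 (it rounds up).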
621 
622 } // namespace experimental::matrix
623 } // namespace oneapi
624 } // namespace ext
625 } // namespace sycl
626 } // __SYCL_INLINE_NAMESPACE(cl)