DPC++ Runtime
Runtime libraries for oneAPI DPC++
matrix-tensorcores.hpp
Go to the documentation of this file.
1 
2 //===-------- matrix-tensorcores.hpp - matrix ext impl ---*- C++ -*-------===//
3 //
4 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
5 // See https://llvm.org/LICENSE.txt for license information.
6 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
7 //
8 // ===-------------------------------------------------------------------=== //
9 
10 #pragma once
11 
12 #include "matrix-unified-utils.hpp"
13 
14 #include <sycl/aliases.hpp> // for half
15 #include <sycl/ext/oneapi/bfloat16.hpp> // for bfloat16
16 #include <sycl/half_type.hpp> // for half
17 #include <sycl/marray.hpp> // for marray
18 
19 #include <stddef.h> // for size_t
20 #include <stdint.h> // for int8_t, uint8_t, int32_t
21 #include <type_traits> // for enable_if_t
22 
23 namespace sycl {
24 inline namespace _V1 {
25 namespace ext {
26 namespace oneapi {
27 namespace experimental {
28 namespace matrix {
29 
// Forward declaration of the joint_matrix class template in
// sycl::ext::oneapi::experimental::matrix. Only the declaration appears here;
// the Layout parameter defaults to layout::dynamic.
30 template <typename Group, typename T, use Use, size_t Rows, size_t Cols,
31  layout Layout = layout::dynamic>
32 struct joint_matrix;
33 
34 } // namespace matrix
35 } // namespace experimental
36 
37 namespace detail {
38 
// Template head of a declaration in namespace detail, parameterised on the
// element type T, the matrix use, the Rows x Cols shape, a layout-related
// parameter and a trailing `Cond` type that defaults to void.
// NOTE(review): listing lines 41 and 44 are absent from this extract, so the
// name of the layout parameter and the declared entity are not visible here.
// Presumably this is the primary template of joint_matrix_cuda (the
// specialisations generated below all target that name) with layout::dynamic
// as the default layout -- confirm against the original header.
39 template <typename T, sycl::ext::oneapi::experimental::matrix::use Use,
40  size_t Rows, size_t Cols,
42  sycl::ext::oneapi::experimental::matrix::layout::dynamic,
43  typename Cond = void>
45 
// __SYCL_JOINT_MATRIX_OVERLOAD_ARR(TYPE, USE, M, N, SIZE)
// Generates a partial specialisation of joint_matrix_cuda for element type
// TYPE, matrix use `use::USE` and shape M x N. The specialisation is still
// templated on Layout and is enabled (through std::enable_if_t in the last
// template argument) only when Layout is row_major or col_major. Its single
// data member is `marray<TYPE, SIZE> wi_marray`, so SIZE is the number of
// TYPE elements stored in that marray.
// NOTE(review): the listing numbers below skip 59, 61-62, 71-72 and 81-82,
// so some invocations of the original header are not visible in this extract.
46 #define __SYCL_JOINT_MATRIX_OVERLOAD_ARR(TYPE, USE, M, N, SIZE) \
47  template <sycl::ext::oneapi::experimental::matrix::layout Layout> \
48  struct joint_matrix_cuda< \
49  TYPE, sycl::ext::oneapi::experimental::matrix::use::USE, M, N, Layout, \
50  typename std::enable_if_t< \
51  Layout == \
52  sycl::ext::oneapi::experimental::matrix::layout::row_major || \
53  Layout == \
54  sycl::ext::oneapi::experimental::matrix::layout::col_major>> { \
55  marray<TYPE, SIZE> wi_marray; \
56  };
57 
// Each group below is labelled with an mMnNkK shape; use `a` entries are
// M x K and use `b` entries are K x N for that shape.
58 // m8n32k16
60 __SYCL_JOINT_MATRIX_OVERLOAD_ARR(sycl::ext::oneapi::bfloat16, b, 16, 32, 16)
63 
64 __SYCL_JOINT_MATRIX_OVERLOAD_ARR(int8_t, a, 8, 16, 4)
65 __SYCL_JOINT_MATRIX_OVERLOAD_ARR(int8_t, b, 16, 32, 16)
66 __SYCL_JOINT_MATRIX_OVERLOAD_ARR(uint8_t, a, 8, 16, 4)
67 __SYCL_JOINT_MATRIX_OVERLOAD_ARR(uint8_t, b, 16, 32, 16)
68 // m32n8k16
69 __SYCL_JOINT_MATRIX_OVERLOAD_ARR(sycl::ext::oneapi::bfloat16, a, 32, 16, 16)
70 __SYCL_JOINT_MATRIX_OVERLOAD_ARR(sycl::ext::oneapi::bfloat16, b, 16, 8, 4)
73 
74 __SYCL_JOINT_MATRIX_OVERLOAD_ARR(int8_t, a, 32, 16, 16)
75 __SYCL_JOINT_MATRIX_OVERLOAD_ARR(int8_t, b, 16, 8, 4)
76 __SYCL_JOINT_MATRIX_OVERLOAD_ARR(uint8_t, a, 32, 16, 16)
77 __SYCL_JOINT_MATRIX_OVERLOAD_ARR(uint8_t, b, 16, 8, 4)
78 // m16n16k16
79 __SYCL_JOINT_MATRIX_OVERLOAD_ARR(sycl::ext::oneapi::bfloat16, a, 16, 16, 8)
80 __SYCL_JOINT_MATRIX_OVERLOAD_ARR(sycl::ext::oneapi::bfloat16, b, 16, 16, 8)
83 
84 __SYCL_JOINT_MATRIX_OVERLOAD_ARR(int8_t, a, 16, 16, 8)
85 __SYCL_JOINT_MATRIX_OVERLOAD_ARR(int8_t, b, 16, 16, 8)
86 __SYCL_JOINT_MATRIX_OVERLOAD_ARR(uint8_t, a, 16, 16, 8)
87 __SYCL_JOINT_MATRIX_OVERLOAD_ARR(uint8_t, b, 16, 16, 8)
88 // m8n8k4 double only
89 __SYCL_JOINT_MATRIX_OVERLOAD_ARR(double, a, 8, 4, 1)
90 __SYCL_JOINT_MATRIX_OVERLOAD_ARR(double, b, 4, 8, 1)
91 
// The helper macro is local to this header and is removed once the
// specialisations have been generated.
92 #undef __SYCL_JOINT_MATRIX_OVERLOAD_ARR
93 
// __SYCL_JOINT_MATRIX_OVERLOAD_ARR_ACC(TYPE, M, N, SIZE)
// Generates a full specialisation of joint_matrix_cuda for accumulator
// matrices: element type TYPE, use::accumulator, shape M x N and
// layout::dynamic. As with the multiplicand specialisations, the only member
// is `marray<TYPE, SIZE> wi_marray`.
// NOTE(review): listing numbers 102-103, 105-106, 108 and 111 are missing from
// this extract, so only four of the original invocations are visible.
94 #define __SYCL_JOINT_MATRIX_OVERLOAD_ARR_ACC(TYPE, M, N, SIZE) \
95  template <> \
96  struct joint_matrix_cuda< \
97  TYPE, sycl::ext::oneapi::experimental::matrix::use::accumulator, M, N, \
98  sycl::ext::oneapi::experimental::matrix::layout::dynamic> { \
99  marray<TYPE, SIZE> wi_marray; \
100  };
101 
104 __SYCL_JOINT_MATRIX_OVERLOAD_ARR_ACC(int32_t, 8, 32, 8)
107 __SYCL_JOINT_MATRIX_OVERLOAD_ARR_ACC(int32_t, 32, 8, 8)
109 __SYCL_JOINT_MATRIX_OVERLOAD_ARR_ACC(float, 16, 16, 8)
110 __SYCL_JOINT_MATRIX_OVERLOAD_ARR_ACC(int32_t, 16, 16, 8)
112 
// Helper macro is not needed past this point.
113 #undef __SYCL_JOINT_MATRIX_OVERLOAD_ARR_ACC
114 
// __SYCL_JOINT_MATRIX_OVERLOAD_ARR_PRECISION(PRECISION, USE, M, N, TYPE, SIZE)
// Same shape as the multiplicand macro above, except that the type naming the
// specialisation (PRECISION) differs from the type stored in the marray
// (TYPE): the generated struct is joint_matrix_cuda<PRECISION, use::USE, M, N,
// Layout, ...> with member `marray<TYPE, SIZE> wi_marray`. It is enabled only
// for row_major or col_major Layout.
// NOTE(review): listing numbers 129-130 and 132 are missing from this extract,
// so the macro name that opens each of the two invocations below (and the
// arguments of the first) are not visible; the stray `4)` lines are the tails
// of those invocations.
115 #define __SYCL_JOINT_MATRIX_OVERLOAD_ARR_PRECISION(PRECISION, USE, M, N, TYPE, \
116  SIZE) \
117  template <sycl::ext::oneapi::experimental::matrix::layout Layout> \
118  struct joint_matrix_cuda< \
119  PRECISION, sycl::ext::oneapi::experimental::matrix::use::USE, M, N, \
120  Layout, \
121  typename std::enable_if_t< \
122  Layout == \
123  sycl::ext::oneapi::experimental::matrix::layout::row_major || \
124  Layout == \
125  sycl::ext::oneapi::experimental::matrix::layout::col_major>> { \
126  marray<TYPE, SIZE> wi_marray; \
127  };
128 // m16n16k8 tf32 only
131  4)
133  sycl::ext::oneapi::experimental::matrix::precision::tf32, b, 8, 16, float,
134  4)
135 
// Helper macro is not needed past this point.
136 #undef __SYCL_JOINT_MATRIX_OVERLOAD_ARR_PRECISION
137 #if defined(__SYCL_DEVICE_ONLY__) && defined(__NVPTX__)
// Compile-time mapping from a matrix layout to the integer that the load and
// store calls in this file receive as their final argument. The primary
// template is declared without a body; only the row_major and col_major
// specialisations below are defined.
138 template <sycl::ext::oneapi::experimental::matrix::layout Layout>
139 constexpr int get_layout_id();
140 
// row_major maps to 0.
141 template <>
142 constexpr int
143 get_layout_id<sycl::ext::oneapi::experimental::matrix::layout::row_major>() {
144  return 0;
145 }
146 
// col_major maps to 1.
147 template <>
148 constexpr int
149 get_layout_id<sycl::ext::oneapi::experimental::matrix::layout::col_major>() {
150  return 1;
151 }
152 
// Loads an accumulator fragment from `src` into `res.wi_marray` for a layout
// that is fixed at compile time (template parameter Layout).
//
// Dispatch is entirely `if constexpr`: first on the fragment element type S
// (int32_t, float, half, double), then on NumRows x NumCols (16x16, 8x32,
// 32x8). A combination that matches no branch emits no call at all.
// `stride` and get_layout_id<Layout>() are forwarded unchanged to the selected
// `__imma_*` / `__hmma_*` / `__dmma_*` function; those functions are not
// declared in this file (presumably compiler builtins for the NVPTX target --
// confirm).
//
// NOTE(review): listing line 156, which carries the return type and function
// name, is missing from this extract.
153 template <sycl::ext::oneapi::experimental::matrix::layout Layout, typename S,
154  typename T, size_t NumRows, size_t NumCols,
155  access::address_space Space, access::decorated IsDecorated>
157  joint_matrix_cuda<
158  S, sycl::ext::oneapi::experimental::matrix::use::accumulator, NumRows,
159  NumCols, sycl::ext::oneapi::experimental::matrix::layout::dynamic> &res,
160  multi_ptr<T, Space, IsDecorated> src, size_t stride) {
// int32_t accumulators: the marray is viewed as int32_t storage.
161  if constexpr (std::is_same_v<S, int32_t>) {
162  auto destptr = reinterpret_cast<int32_t *>(&res.wi_marray);
163  if constexpr (NumRows == 16 && NumCols == 16) {
164  __imma_m16n16k16_ld_c(destptr, src.get(), stride,
165  get_layout_id<Layout>());
166  } else if constexpr (NumRows == 8 && NumCols == 32) {
167  __imma_m8n32k16_ld_c(destptr, src.get(), stride, get_layout_id<Layout>());
168  } else if constexpr (NumRows == 32 && NumCols == 8) {
169  __imma_m32n8k16_ld_c(destptr, src.get(), stride, get_layout_id<Layout>());
170  }
// float accumulators: the marray is viewed as float storage.
171  } else if constexpr (std::is_same_v<S, float>) {
172  auto dstptr = reinterpret_cast<float *>(&res.wi_marray);
173  if constexpr (NumRows == 16 && NumCols == 16) {
174  __hmma_m16n16k16_ld_c_f32(dstptr, src.get(), stride,
175  get_layout_id<Layout>());
176  } else if constexpr (NumRows == 8 && NumCols == 32) {
177  __hmma_m8n32k16_ld_c_f32(dstptr, src.get(), stride,
178  get_layout_id<Layout>());
179  } else if constexpr (NumRows == 32 && NumCols == 8) {
180  __hmma_m32n8k16_ld_c_f32(dstptr, src.get(), stride,
181  get_layout_id<Layout>());
182  }
// half accumulators: both the source pointer and the marray are reinterpreted
// as int32_t pointers before the call.
183  } else if constexpr (std::is_same_v<S, half>) {
184  auto tileptr = reinterpret_cast<const int32_t *>(src.get());
185  auto dstptr = reinterpret_cast<int32_t *>(&res.wi_marray);
186  if constexpr (NumRows == 32 && NumCols == 8) {
187  __hmma_m32n8k16_ld_c_f16(dstptr, tileptr, stride,
188  get_layout_id<Layout>());
189  } else if constexpr (NumRows == 8 && NumCols == 32) {
190  __hmma_m8n32k16_ld_c_f16(dstptr, tileptr, stride,
191  get_layout_id<Layout>());
192  } else if constexpr (NumRows == 16 && NumCols == 16) {
193  __hmma_m16n16k16_ld_c_f16(dstptr, tileptr, stride,
194  get_layout_id<Layout>());
195  }
// double accumulators: a single m8n8k4 load, selected without any
// NumRows/NumCols check.
196  } else if constexpr (std::is_same_v<S, double>) {
197  __dmma_m8n8k4_ld_c(reinterpret_cast<double *>(&res.wi_marray), src.get(),
198  stride, get_layout_id<Layout>());
199  }
// NOTE(review): the ';' after the closing brace is an empty declaration; it
// is harmless but not needed after a function body.
200 };
201 
// Entry point for loading an accumulator when the layout is selected by a
// switch rather than a template argument: `Layout` is switched on, and
// (res, src, stride) are forwarded to a callee instantiated for row_major or
// col_major respectively. Any other value reaches the `default` label and
// trips assert(false).
//
// NOTE(review): listing lines 209, 212 and 217 are missing from this extract,
// so the declaration of `Layout` (presumably the last function parameter) and
// the name of the callee are not visible here.
// NOTE(review): `assert` is used but <cassert> is not among the includes
// visible at the top of this file -- presumably it arrives transitively;
// confirm.
202 template <typename S, typename T, size_t NumRows, size_t NumCols,
203  access::address_space Space, access::decorated IsDecorated>
204 void load_accumulator_cuda(
205  joint_matrix_cuda<
206  S, sycl::ext::oneapi::experimental::matrix::use::accumulator, NumRows,
207  NumCols, sycl::ext::oneapi::experimental::matrix::layout::dynamic> &res,
208  multi_ptr<T, Space, IsDecorated> src, size_t stride,
210  switch (Layout) {
211  case sycl::ext::oneapi::experimental::matrix::layout::row_major:
213  sycl::ext::oneapi::experimental::matrix::layout::row_major>(res, src,
214  stride);
215  break;
216  case sycl::ext::oneapi::experimental::matrix::layout::col_major:
218  sycl::ext::oneapi::experimental::matrix::layout::col_major>(res, src,
219  stride);
220  break;
// Anything other than row_major / col_major is treated as a programming error.
221  default:
222  assert(false && "Invalid layout specified!");
223  }
224 }
225 
// Loads a multiplicand fragment from `src` into `res.wi_marray`.
//
// The overload participates only when Layout is row_major or col_major (the
// trailing std::enable_if_t<..., bool> = true template parameter). Dispatch is
// `if constexpr` on the fragment type S (bfloat16, uint8_t, int8_t, half,
// precision::tf32, double) and then on NumRows x NumCols. For all types except
// double the source pointer and the marray are reinterpreted as int32_t
// pointers before the call; for double the marray is viewed as double storage
// and src.get() is passed through. `stride` and get_layout_id<Layout>() are
// forwarded unchanged. Unmatched combinations emit no call.
//
// NOTE(review): listing lines 228-229, 243, 247, 268, 272, 293, 297, 318, 321,
// 347 and 350 are missing from this extract. They presumably held the Use and
// Layout template parameters and the `Use == use::a` / `use::b` conditions that
// pick between the *_ld_a and *_ld_b calls for 16x16 (and for double), which
// is why several `if`/`else if` heads below look truncated.
226 template <
227  typename S, typename T, size_t NumRows, size_t NumCols,
230  access::address_space Space, access::decorated IsDecorated,
231  std::enable_if_t<
232  Layout == sycl::ext::oneapi::experimental::matrix::layout::row_major ||
233  Layout ==
234  sycl::ext::oneapi::experimental::matrix::layout::col_major,
235  bool> = true>
236 void load_multiplicand_cuda(
237  joint_matrix_cuda<S, Use, NumRows, NumCols, Layout> &res,
238  multi_ptr<T, Space, IsDecorated> src, size_t stride) {
// bfloat16 fragments.
239  if constexpr (std::is_same_v<S, sycl::ext::oneapi::bfloat16>) {
240  auto tileptr = reinterpret_cast<const int32_t *>(src.get());
241  auto destptr = reinterpret_cast<int32_t *>(&res.wi_marray);
// 16x16 is ambiguous between operand A and operand B, so Use decides.
242  if constexpr (NumRows == 16 && NumCols == 16) {
244  __mma_bf16_m16n16k16_ld_a(destptr, tileptr, stride,
245  get_layout_id<Layout>());
246  } else if constexpr (Use ==
248  __mma_bf16_m16n16k16_ld_b(destptr, tileptr, stride,
249  get_layout_id<Layout>());
250  }
// For the remaining shapes the dimensions alone identify the operand.
251  } else if constexpr (NumRows == 8 && NumCols == 16) {
252  __mma_bf16_m8n32k16_ld_a(destptr, tileptr, stride,
253  get_layout_id<Layout>());
254  } else if constexpr (NumRows == 16 && NumCols == 32) {
255  __mma_bf16_m8n32k16_ld_b(destptr, tileptr, stride,
256  get_layout_id<Layout>());
257  } else if constexpr (NumRows == 32 && NumCols == 16) {
258  __mma_bf16_m32n8k16_ld_a(destptr, tileptr, stride,
259  get_layout_id<Layout>());
260  } else if constexpr (NumRows == 16 && NumCols == 8) {
261  __mma_bf16_m32n8k16_ld_b(destptr, tileptr, stride,
262  get_layout_id<Layout>());
263  }
// uint8_t fragments: same shape dispatch, *_u8 calls.
264  } else if constexpr (std::is_same_v<S, uint8_t>) {
265  auto tileptr = reinterpret_cast<const int32_t *>(src.get());
266  auto destptr = reinterpret_cast<int32_t *>(&res.wi_marray);
267  if constexpr (NumRows == 16 && NumCols == 16) {
269  __imma_m16n16k16_ld_a_u8(destptr, tileptr, stride,
270  get_layout_id<Layout>());
271  } else if constexpr (Use ==
273  __imma_m16n16k16_ld_b_u8(destptr, tileptr, stride,
274  get_layout_id<Layout>());
275  }
276  } else if constexpr (NumRows == 8 && NumCols == 16) {
277  __imma_m8n32k16_ld_a_u8(destptr, tileptr, stride,
278  get_layout_id<Layout>());
279  } else if constexpr (NumRows == 16 && NumCols == 32) {
280  __imma_m8n32k16_ld_b_u8(destptr, tileptr, stride,
281  get_layout_id<Layout>());
282  } else if constexpr (NumRows == 32 && NumCols == 16) {
283  __imma_m32n8k16_ld_a_u8(destptr, tileptr, stride,
284  get_layout_id<Layout>());
285  } else if constexpr (NumRows == 16 && NumCols == 8) {
286  __imma_m32n8k16_ld_b_u8(destptr, tileptr, stride,
287  get_layout_id<Layout>());
288  }
// int8_t fragments: same shape dispatch, *_s8 calls.
289  } else if constexpr (std::is_same_v<S, int8_t>) {
290  auto tileptr = reinterpret_cast<const int32_t *>(src.get());
291  auto destptr = reinterpret_cast<int32_t *>(&res.wi_marray);
292  if constexpr (NumRows == 16 && NumCols == 16) {
294  __imma_m16n16k16_ld_a_s8(destptr, tileptr, stride,
295  get_layout_id<Layout>());
296  } else if constexpr (Use ==
298  __imma_m16n16k16_ld_b_s8(destptr, tileptr, stride,
299  get_layout_id<Layout>());
300  }
301  } else if constexpr (NumRows == 8 && NumCols == 16) {
302  __imma_m8n32k16_ld_a_s8(destptr, tileptr, stride,
303  get_layout_id<Layout>());
304  } else if constexpr (NumRows == 16 && NumCols == 32) {
305  __imma_m8n32k16_ld_b_s8(destptr, tileptr, stride,
306  get_layout_id<Layout>());
307  } else if constexpr (NumRows == 32 && NumCols == 16) {
308  __imma_m32n8k16_ld_a_s8(destptr, tileptr, stride,
309  get_layout_id<Layout>());
310  } else if constexpr (NumRows == 16 && NumCols == 8) {
311  __imma_m32n8k16_ld_b_s8(destptr, tileptr, stride,
312  get_layout_id<Layout>());
313  }
// half fragments: same shape dispatch, __hmma_* calls.
314  } else if constexpr (std::is_same_v<S, half>) {
315  auto tileptr = reinterpret_cast<const int32_t *>(src.get());
316  auto dstptr = reinterpret_cast<int32_t *>(&res.wi_marray);
317  if constexpr (NumRows == 16 && NumCols == 16) {
319  __hmma_m16n16k16_ld_a(dstptr, tileptr, stride, get_layout_id<Layout>());
320  } else if constexpr (Use ==
322  __hmma_m16n16k16_ld_b(dstptr, tileptr, stride, get_layout_id<Layout>());
323  }
324  } else if constexpr (NumRows == 8 && NumCols == 16) {
325  __hmma_m8n32k16_ld_a(dstptr, tileptr, stride, get_layout_id<Layout>());
326  } else if constexpr (NumRows == 16 && NumCols == 32) {
327  __hmma_m8n32k16_ld_b(dstptr, tileptr, stride, get_layout_id<Layout>());
328  } else if constexpr (NumRows == 32 && NumCols == 16) {
329  __hmma_m32n8k16_ld_a(dstptr, tileptr, stride, get_layout_id<Layout>());
330  } else if constexpr (NumRows == 16 && NumCols == 8) {
331  __hmma_m32n8k16_ld_b(dstptr, tileptr, stride, get_layout_id<Layout>());
332  }
333 
// tf32 fragments: only the 16x8 (ld_a) and 8x16 (ld_b) shapes are handled.
334  } else if constexpr (std::is_same_v<S, sycl::ext::oneapi::experimental::
335  matrix::precision::tf32>) {
336  auto tileptr = reinterpret_cast<const int32_t *>(src.get());
337  auto dstptr = reinterpret_cast<int32_t *>(&res.wi_marray);
338  if constexpr (NumRows == 16 && NumCols == 8) {
339  __mma_tf32_m16n16k8_ld_a(dstptr, tileptr, stride,
340  get_layout_id<Layout>());
341  } else if constexpr (NumRows == 8 && NumCols == 16) {
342  __mma_tf32_m16n16k8_ld_b(dstptr, tileptr, stride,
343  get_layout_id<Layout>());
344  }
// double fragments: no shape check is visible; the operand is chosen by Use.
345  } else if constexpr (std::is_same_v<S, double>) {
346  auto dstptr = reinterpret_cast<double *>(&res.wi_marray);
348  __dmma_m8n8k4_ld_a(dstptr, src.get(), stride, get_layout_id<Layout>());
349  } else if constexpr (Use ==
351  __dmma_m8n8k4_ld_b(dstptr, src.get(), stride, get_layout_id<Layout>());
352  }
353  }
354 }
355 
// Stores the accumulator fragment held in `src.wi_marray` to `dst` for a
// layout fixed at compile time (template parameter Layout).
//
// Dispatch is `if constexpr` on the shape first (16x16, 8x32, 32x8) and then
// on the element type T (float, int32_t, half). If none of the three shapes
// matches and T is double, the m8n8k4 store is used without a further shape
// check. For float, int32_t and double the destination is dst.get() and the
// source is the address of the first marray element; for half both pointers
// are reinterpreted as int32_t pointers. `stride` and get_layout_id<Layout>()
// are forwarded unchanged. Unmatched combinations emit no call.
356 template <sycl::ext::oneapi::experimental::matrix::layout Layout, typename T,
357  size_t NumRows, size_t NumCols, access::address_space Space,
358  access::decorated IsDecorated>
359 void store_layoutT(
360  const joint_matrix_cuda<
361  T, sycl::ext::oneapi::experimental::matrix::use::accumulator, NumRows,
362  NumCols, sycl::ext::oneapi::experimental::matrix::layout::dynamic> &src,
363  multi_ptr<T, Space, IsDecorated> dst, size_t stride) {
// 16x16 accumulators.
364  if constexpr (NumRows == 16 && NumCols == 16) {
365  if constexpr (std::is_same_v<T, float>) {
366  __hmma_m16n16k16_st_c_f32(dst.get(), &src.wi_marray[0], stride,
367  get_layout_id<Layout>());
368  } else if constexpr (std::is_same_v<T, int32_t>) {
369  __imma_m16n16k16_st_c_i32(dst.get(), &src.wi_marray[0], stride,
370  get_layout_id<Layout>());
371  } else if constexpr (std::is_same_v<T, half>) {
372  __hmma_m16n16k16_st_c_f16(
373  reinterpret_cast<int32_t *>(dst.get()),
374  reinterpret_cast<const int32_t *>(&src.wi_marray[0]), stride,
375  get_layout_id<Layout>());
376  }
// 8x32 accumulators.
377  } else if constexpr (NumRows == 8 && NumCols == 32) {
378  if constexpr (std::is_same_v<T, float>) {
379  __hmma_m8n32k16_st_c_f32(dst.get(), &src.wi_marray[0], stride,
380  get_layout_id<Layout>());
381  } else if constexpr (std::is_same_v<T, int32_t>) {
382  __imma_m8n32k16_st_c_i32(dst.get(), &src.wi_marray[0], stride,
383  get_layout_id<Layout>());
384  } else if constexpr (std::is_same_v<T, half>) {
385  __hmma_m8n32k16_st_c_f16(
386  reinterpret_cast<int32_t *>(dst.get()),
387  reinterpret_cast<const int32_t *>(&src.wi_marray[0]), stride,
388  get_layout_id<Layout>());
389  }
// 32x8 accumulators.
390  } else if constexpr (NumRows == 32 && NumCols == 8) {
391  if constexpr (std::is_same_v<T, float>) {
392  __hmma_m32n8k16_st_c_f32(dst.get(), &src.wi_marray[0], stride,
393  get_layout_id<Layout>());
394  } else if constexpr (std::is_same_v<T, int32_t>) {
395  __imma_m32n8k16_st_c_i32(dst.get(), &src.wi_marray[0], stride,
396  get_layout_id<Layout>());
397  } else if constexpr (std::is_same_v<T, half>) {
398  __hmma_m32n8k16_st_c_f16(
399  reinterpret_cast<int32_t *>(dst.get()),
400  reinterpret_cast<const int32_t *>(&src.wi_marray[0]), stride,
401  get_layout_id<Layout>());
402  }
// Any other shape: only double is handled, via the m8n8k4 store.
403  } else if constexpr (std::is_same_v<T, double>) {
404  __dmma_m8n8k4_st_c_f64(dst.get(), &src.wi_marray[0], stride,
405  get_layout_id<Layout>());
406  }
407 }
408 
// Entry point for storing an accumulator when the layout is selected by a
// switch rather than a template argument: `Layout` is switched on and
// (src, dst, stride) are forwarded to store_layoutT instantiated for row_major
// or col_major. Any other value reaches `default` and trips assert(false).
//
// NOTE(review): listing line 416 is missing from this extract; it presumably
// declares `Layout` as the last function parameter and opens the body.
// NOTE(review): `assert` is used but <cassert> is not among the includes
// visible at the top of this file -- presumably it arrives transitively;
// confirm.
409 template <typename T, size_t NumRows, size_t NumCols,
410  access::address_space Space, access::decorated IsDecorated>
411 void joint_matrix_store_cuda(
412  const joint_matrix_cuda<
413  T, sycl::ext::oneapi::experimental::matrix::use::accumulator, NumRows,
414  NumCols, sycl::ext::oneapi::experimental::matrix::layout::dynamic> &src,
415  multi_ptr<T, Space, IsDecorated> dst, size_t stride,
417  switch (Layout) {
418  case sycl::ext::oneapi::experimental::matrix::layout::row_major:
419  store_layoutT<sycl::ext::oneapi::experimental::matrix::layout::row_major>(
420  src, dst, stride);
421  break;
422  case sycl::ext::oneapi::experimental::matrix::layout::col_major:
423  store_layoutT<sycl::ext::oneapi::experimental::matrix::layout::col_major>(
424  src, dst, stride);
425  break;
// Anything other than row_major / col_major is treated as a programming error.
426  default:
427  assert(false && "Invalid layout specified!");
428  }
429 }
430 
// Compile-time mapping from a pair of layouts to the integer id that
// joint_matrix_mad_cuda passes to the multiply calls. The primary template is
// declared without a body; the four specialisations below cover every
// row_major / col_major combination:
//   (row_major, row_major) -> 0    (row_major, col_major) -> 1
//   (col_major, row_major) -> 2    (col_major, col_major) -> 3
// NOTE(review): listing lines 431-432, which held the template parameter list
// of the primary declaration, are missing from this extract.
433 constexpr int get_layout_pair_id();
434 
435 template <>
436 constexpr int get_layout_pair_id<
437  sycl::ext::oneapi::experimental::matrix::layout::row_major,
438  sycl::ext::oneapi::experimental::matrix::layout::row_major>() {
439  return 0;
440 }
441 
442 template <>
443 constexpr int get_layout_pair_id<
444  sycl::ext::oneapi::experimental::matrix::layout::row_major,
445  sycl::ext::oneapi::experimental::matrix::layout::col_major>() {
446  return 1;
447 }
448 
449 template <>
450 constexpr int get_layout_pair_id<
451  sycl::ext::oneapi::experimental::matrix::layout::col_major,
452  sycl::ext::oneapi::experimental::matrix::layout::row_major>() {
453  return 2;
454 }
455 
456 template <>
457 constexpr int get_layout_pair_id<
458  sycl::ext::oneapi::experimental::matrix::layout::col_major,
459  sycl::ext::oneapi::experimental::matrix::layout::col_major>() {
460  return 3;
461 }
462 
// Issues the matrix multiply call for operands A (M x K, use::a) and
// B (K x N, use::b) with accumulator input C and accumulator output D (both
// M x N, layout::dynamic); presumably D = A * B + C -- confirm against the
// builtin documentation.
//
// The overload participates only when both LayoutA and LayoutB are row_major
// or col_major (trailing std::enable_if_t<..., bool> = true parameter).
// Dispatch is `if constexpr` on the (M, N, K) shape and then on the
// multiplicand type Tm, the input accumulator type Tc and the output
// accumulator type Td. Each fragment's wi_marray is reinterpreted as the
// pointer type the selected call expects. Every call receives
// get_layout_pair_id<LayoutA, LayoutB>() followed by a literal 0.
// NOTE(review): the meaning of that trailing 0 is not visible in this file
// (presumably a saturation flag) -- confirm.
// Unmatched combinations emit no call.
//
// NOTE(review): listing lines 465-466 are missing from this extract; they
// presumably declare the remaining template parameters (N, LayoutA, LayoutB).
// NOTE(review): std::is_same<Td, float>::value and std::is_same_v are both
// used below; they are equivalent spellings.
463 template <
464  typename Tm, typename Tc, typename Td, std::size_t M, std::size_t K,
467  std::enable_if_t<
468  (LayoutA ==
469  sycl::ext::oneapi::experimental::matrix::layout::row_major ||
470  LayoutA ==
471  sycl::ext::oneapi::experimental::matrix::layout::col_major) &&
472  (LayoutB ==
473  sycl::ext::oneapi::experimental::matrix::layout::row_major ||
474  LayoutB ==
475  sycl::ext::oneapi::experimental::matrix::layout::col_major),
476  bool> = true>
477 void joint_matrix_mad_cuda(
478  joint_matrix_cuda<
479  Td, sycl::ext::oneapi::experimental::matrix::use::accumulator, M, N,
480  sycl::ext::oneapi::experimental::matrix::layout::dynamic> &D,
481  const joint_matrix_cuda<Tm, sycl::ext::oneapi::experimental::matrix::use::a,
482  M, K, LayoutA> &A,
483  const joint_matrix_cuda<Tm, sycl::ext::oneapi::experimental::matrix::use::b,
484  K, N, LayoutB> &B,
485  const joint_matrix_cuda<
486  Tc, sycl::ext::oneapi::experimental::matrix::use::accumulator, M, N,
487  sycl::ext::oneapi::experimental::matrix::layout::dynamic> &C) {
// ---- m16n16k16 ----
488  if constexpr (M == 16 && N == 16 && K == 16) {
// Integer path: selected by the accumulator type, then int8_t vs uint8_t.
489  if constexpr (std::is_same_v<Tc, int32_t>) {
490  auto ptrA = reinterpret_cast<const int32_t *>(&A.wi_marray);
491  auto ptrB = reinterpret_cast<const int32_t *>(&B.wi_marray);
492  auto ptrC = reinterpret_cast<const int32_t *>(&C.wi_marray);
493  auto ptrD = reinterpret_cast<int32_t *>(&D.wi_marray);
494  if constexpr (std::is_same_v<Tm, int8_t>) {
495  __imma_m16n16k16_mma_s8(ptrD, ptrA, ptrB, ptrC,
496  get_layout_pair_id<LayoutA, LayoutB>(), 0);
497  } else if constexpr (std::is_same_v<Tm, uint8_t>) {
498  __imma_m16n16k16_mma_u8(ptrD, ptrA, ptrB, ptrC,
499  get_layout_pair_id<LayoutA, LayoutB>(), 0);
500  }
// half path: the call suffix encodes <output><input> accumulator types; any
// Td other than float takes the f16-output variant.
501  } else if constexpr (std::is_same_v<Tm, half>) {
502  auto ptrA = reinterpret_cast<const int32_t *>(&A.wi_marray);
503  auto ptrB = reinterpret_cast<const int32_t *>(&B.wi_marray);
504  if constexpr (std::is_same_v<Tc, float>) {
505  if constexpr (std::is_same<Td, float>::value) {
506  __hmma_m16n16k16_mma_f32f32(
507  reinterpret_cast<float *>(&D.wi_marray), ptrA, ptrB,
508  reinterpret_cast<const float *>(&C.wi_marray),
509  get_layout_pair_id<LayoutA, LayoutB>(), 0);
510  } else {
511  __hmma_m16n16k16_mma_f16f32(
512  reinterpret_cast<int32_t *>(&D.wi_marray), ptrA, ptrB,
513  reinterpret_cast<const float *>(&C.wi_marray),
514  get_layout_pair_id<LayoutA, LayoutB>(), 0);
515  }
516  } else if constexpr (std::is_same_v<Tc, half>) {
517  if constexpr (std::is_same<Td, float>::value) {
518  __hmma_m16n16k16_mma_f32f16(
519  reinterpret_cast<float *>(&D.wi_marray), ptrA, ptrB,
520  reinterpret_cast<const int32_t *>(&C.wi_marray),
521  get_layout_pair_id<LayoutA, LayoutB>(), 0);
522  } else {
523  __hmma_m16n16k16_mma_f16f16(
524  reinterpret_cast<int32_t *>(&D.wi_marray), ptrA, ptrB,
525  reinterpret_cast<const int32_t *>(&C.wi_marray),
526  get_layout_pair_id<LayoutA, LayoutB>(), 0);
527  }
528  }
// bfloat16 path: C and D are always viewed as float storage.
529  } else if constexpr (std::is_same_v<Tm, sycl::ext::oneapi::bfloat16>) {
530  __mma_bf16_m16n16k16_mma_f32(
531  reinterpret_cast<float *>(&D.wi_marray),
532  reinterpret_cast<const int32_t *>(&A.wi_marray),
533  reinterpret_cast<const int32_t *>(&B.wi_marray),
534  reinterpret_cast<const float *>(&C.wi_marray),
535  get_layout_pair_id<LayoutA, LayoutB>(), 0);
536  }
// ---- m8n32k16: same type dispatch as above ----
537  } else if constexpr (M == 8 && N == 32 && K == 16) {
538  if constexpr (std::is_same_v<Tc, int32_t>) {
539  auto ptrA = reinterpret_cast<const int32_t *>(&A.wi_marray);
540  auto ptrB = reinterpret_cast<const int32_t *>(&B.wi_marray);
541  auto ptrC = reinterpret_cast<const int32_t *>(&C.wi_marray);
542  auto ptrD = reinterpret_cast<int32_t *>(&D.wi_marray);
543  if constexpr (std::is_same_v<Tm, int8_t>) {
544  __imma_m8n32k16_mma_s8(ptrD, ptrA, ptrB, ptrC,
545  get_layout_pair_id<LayoutA, LayoutB>(), 0);
546  } else if constexpr (std::is_same_v<Tm, uint8_t>) {
547  __imma_m8n32k16_mma_u8(ptrD, ptrA, ptrB, ptrC,
548  get_layout_pair_id<LayoutA, LayoutB>(), 0);
549  }
550  } else if constexpr (std::is_same_v<Tm, half>) {
551  auto ptrA = reinterpret_cast<const int32_t *>(&A.wi_marray);
552  auto ptrB = reinterpret_cast<const int32_t *>(&B.wi_marray);
553  if constexpr (std::is_same_v<Tc, float>) {
554  if constexpr (std::is_same<Td, float>::value) {
555  __hmma_m8n32k16_mma_f32f32(
556  reinterpret_cast<float *>(&D.wi_marray), ptrA, ptrB,
557  reinterpret_cast<const float *>(&C.wi_marray),
558  get_layout_pair_id<LayoutA, LayoutB>(), 0);
559  } else {
560  __hmma_m8n32k16_mma_f16f32(
561  reinterpret_cast<int32_t *>(&D.wi_marray), ptrA, ptrB,
562  reinterpret_cast<const float *>(&C.wi_marray),
563  get_layout_pair_id<LayoutA, LayoutB>(), 0);
564  }
565  } else if constexpr (std::is_same_v<Tc, half>) {
566  if constexpr (std::is_same<Td, float>::value) {
567  __hmma_m8n32k16_mma_f32f16(
568  reinterpret_cast<float *>(&D.wi_marray), ptrA, ptrB,
569  reinterpret_cast<const int32_t *>(&C.wi_marray),
570  get_layout_pair_id<LayoutA, LayoutB>(), 0);
571  } else {
572  __hmma_m8n32k16_mma_f16f16(
573  reinterpret_cast<int32_t *>(&D.wi_marray), ptrA, ptrB,
574  reinterpret_cast<const int32_t *>(&C.wi_marray),
575  get_layout_pair_id<LayoutA, LayoutB>(), 0);
576  }
577  }
578  } else if constexpr (std::is_same_v<Tm, sycl::ext::oneapi::bfloat16>) {
579  __mma_bf16_m8n32k16_mma_f32(
580  reinterpret_cast<float *>(&D.wi_marray),
581  reinterpret_cast<const int32_t *>(&A.wi_marray),
582  reinterpret_cast<const int32_t *>(&B.wi_marray),
583  reinterpret_cast<const float *>(&C.wi_marray),
584  get_layout_pair_id<LayoutA, LayoutB>(), 0);
585  }
// ---- m32n8k16: same type dispatch; bfloat16 is tested before half here ----
586  } else if constexpr (M == 32 && N == 8 && K == 16) {
587  if constexpr (std::is_same_v<Tc, int32_t>) {
588  auto ptrA = reinterpret_cast<const int32_t *>(&A.wi_marray);
589  auto ptrB = reinterpret_cast<const int32_t *>(&B.wi_marray);
590  auto ptrC = reinterpret_cast<const int32_t *>(&C.wi_marray);
591  auto ptrD = reinterpret_cast<int32_t *>(&D.wi_marray);
592  if constexpr (std::is_same_v<Tm, int8_t>) {
593  __imma_m32n8k16_mma_s8(ptrD, ptrA, ptrB, ptrC,
594  get_layout_pair_id<LayoutA, LayoutB>(), 0);
595  } else if constexpr (std::is_same_v<Tm, uint8_t>) {
596  __imma_m32n8k16_mma_u8(ptrD, ptrA, ptrB, ptrC,
597  get_layout_pair_id<LayoutA, LayoutB>(), 0);
598  }
599  } else if constexpr (std::is_same_v<Tm, sycl::ext::oneapi::bfloat16>) {
600  __mma_bf16_m32n8k16_mma_f32(
601  reinterpret_cast<float *>(&D.wi_marray),
602  reinterpret_cast<const int32_t *>(&A.wi_marray),
603  reinterpret_cast<const int32_t *>(&B.wi_marray),
604  reinterpret_cast<const float *>(&C.wi_marray),
605  get_layout_pair_id<LayoutA, LayoutB>(), 0);
606  } else if constexpr (std::is_same_v<Tm, half>) {
607 
608  auto ptrA = reinterpret_cast<const int32_t *>(&A.wi_marray);
609  auto ptrB = reinterpret_cast<const int32_t *>(&B.wi_marray);
610  if constexpr (std::is_same_v<Tc, float>) {
611  if constexpr (std::is_same<Td, float>::value) {
612  __hmma_m32n8k16_mma_f32f32(
613  reinterpret_cast<float *>(&D.wi_marray), ptrA, ptrB,
614  reinterpret_cast<const float *>(&C.wi_marray),
615  get_layout_pair_id<LayoutA, LayoutB>(), 0);
616  } else {
617  __hmma_m32n8k16_mma_f16f32(
618  reinterpret_cast<int32_t *>(&D.wi_marray), ptrA, ptrB,
619  reinterpret_cast<const float *>(&C.wi_marray),
620  get_layout_pair_id<LayoutA, LayoutB>(), 0);
621  }
622  } else if constexpr (std::is_same_v<Tc, half>) {
623  if constexpr (std::is_same<Td, float>::value) {
624  __hmma_m32n8k16_mma_f32f16(
625  reinterpret_cast<float *>(&D.wi_marray), ptrA, ptrB,
626  reinterpret_cast<const int32_t *>(&C.wi_marray),
627  get_layout_pair_id<LayoutA, LayoutB>(), 0);
628  } else {
629  __hmma_m32n8k16_mma_f16f16(
630  reinterpret_cast<int32_t *>(&D.wi_marray), ptrA, ptrB,
631  reinterpret_cast<const int32_t *>(&C.wi_marray),
632  get_layout_pair_id<LayoutA, LayoutB>(), 0);
633  }
634  }
635  }
// ---- m16n16k8: selected by shape alone (Tm is not checked); issues the tf32
// call with C and D viewed as float storage ----
636  } else if constexpr (M == 16 && N == 16 && K == 8) {
637  __mma_tf32_m16n16k8_mma_f32(reinterpret_cast<float *>(&D.wi_marray),
638  reinterpret_cast<const int32_t *>(&A.wi_marray),
639  reinterpret_cast<const int32_t *>(&B.wi_marray),
640  reinterpret_cast<const float *>(&C.wi_marray),
641  get_layout_pair_id<LayoutA, LayoutB>(), 0);
// ---- any other shape: only Tm == double is handled, via the m8n8k4 call;
// the shape itself is not checked ----
642  } else if constexpr (std::is_same_v<Tm, double>) {
643  __dmma_m8n8k4_mma_f64(reinterpret_cast<double *>(&D.wi_marray),
644  reinterpret_cast<const double *>(&A.wi_marray),
645  reinterpret_cast<const double *>(&B.wi_marray),
646  reinterpret_cast<const double *>(&C.wi_marray),
647  get_layout_pair_id<LayoutA, LayoutB>(), 0);
648  }
649 }
650 
651 #endif // defined(__SYCL_DEVICE_ONLY__) && defined(__NVPTX__)
652 
653 } // namespace detail
654 } // namespace oneapi
655 } // namespace ext
656 } // namespace _V1
657 } // namespace sycl
#define __SYCL_JOINT_MATRIX_OVERLOAD_ARR_PRECISION(PRECISION, USE, M, N, TYPE, SIZE)
#define __SYCL_JOINT_MATRIX_OVERLOAD_ARR_ACC(TYPE, M, N, SIZE)
#define __SYCL_JOINT_MATRIX_OVERLOAD_ARR(TYPE, USE, M, N, SIZE)
void store_layoutT(const joint_matrix_hip< T, sycl::ext::oneapi::experimental::matrix::use::accumulator, M, N, sycl::ext::oneapi::experimental::matrix::layout::dynamic > &src, multi_ptr< T, Space, IsDecorated > dst, size_t stride, Group &sg)
Definition: matrix-hip.hpp:249
void load_accumulator_layoutT(joint_matrix_hip< S, sycl::ext::oneapi::experimental::matrix::use::accumulator, M, N, sycl::ext::oneapi::experimental::matrix::layout::dynamic > &res, multi_ptr< T, Space, IsDecorated > src, size_t stride, Group &sg)
Definition: matrix-hip.hpp:115
auto autodecltype(a) b
sycl::detail::half_impl::half half
Definition: aliases.hpp:101
Definition: access.hpp:18