DPC++ Runtime
Runtime libraries for oneAPI DPC++
matrix-unified.hpp
Go to the documentation of this file.
1 //===------- matrix-unified.hpp - SYCL matrix extension ----*- C++ -*------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 // ===--------------------------------------------------------------------=== //
8 
9 #pragma once
10 
11 #include "matrix-intel.hpp"
12 
13 #if defined(__SYCL_DEVICE_ONLY__)
14 #if defined(__NVPTX__)
15 #include "matrix-tensorcores.hpp"
16 #elif defined(__gfx90a__)
17 #include "matrix-hip.hpp"
18 #endif // defined(__NVPTX__)
19 #endif // defined(__SYCL_DEVICE_ONLY__)
20 
21 #include <sycl/access/access.hpp> // for address_space
22 #include <sycl/detail/defines_elementary.hpp> // for __SYCL_ALWAYS_...
23 #include <sycl/detail/pi.h> // for PI_ERROR_INVAL...
24 #include <sycl/exception.hpp> // for runtime_error
25 #include <sycl/ext/oneapi/matrix/matrix-unified-utils.hpp> // for layout, use, tf32, convertMatrixUseEnumToString
26 #include <sycl/ext/oneapi/matrix/query-types.hpp> // for convertTypeToMatrixTypeString
27 #include <sycl/marray.hpp> // for marray
28 #include <sycl/multi_ptr.hpp> // for multi_ptr
29 
30 #include <cstring> // for size_t, memcpy
31 #include <stdint.h> // for uint32_t
32 #include <tuple> // for ignore, _Swall...
33 #include <type_traits> // for is_same, remov...
34 
35 namespace sycl {
36 inline namespace _V1 {
37 namespace ext {
38 namespace oneapi {
39 namespace experimental {
40 namespace matrix {
41 
// Portable joint-matrix handle of the SYCL matrix extension.  Backend
// storage is chosen by preprocessor dispatch below: CUDA and HIP carry a
// `matrix_impl` member, SPIR-V targets carry opaque compiler-managed state,
// and any other device backend hits the static_assert.
// NOTE(review): this extraction dropped several hyperlinked lines (the
// impl member types, the default-constructor signature) — consult the
// original header before editing.
42 template <typename Group, typename T, use Use, size_t Rows, size_t Cols,
43  layout Layout>
44 struct joint_matrix {
45 
46 #if defined(__SYCL_DEVICE_ONLY__)
47 #if defined(__NVPTX__)
// CUDA backend storage member (type elided by extraction).
49  matrix_impl;
50 #elif defined(__HIP_PLATFORM_AMD_MFMA__)
// HIP MFMA backend storage member (type elided by extraction).
52  matrix_impl;
53 #elif defined(__SPIR__) || defined(__SPIRV__)
57 #else
58  static_assert(false, "The joint_matrix API is only supported by the Intel, "
59  "CUDA and HIP (GFX90A) backends");
60 #endif // defined(__NVPTX__)
61 #endif // defined(__SYCL_DEVICE_ONLY__)
62 
63 #if defined(__SYCL_DEVICE_ONLY__)
// IR attributes propagate the matrix element type, use, and shape to the
// device compiler.
64  [[__sycl_detail__::add_ir_attributes_function(
65  "sycl-joint-matrix-type", "sycl-joint-matrix-use",
66  "sycl-joint-matrix-rows", "sycl-joint-matrix-cols",
67  sycl::detail::convertTypeToMatrixTypeString<T>(),
69 #endif // defined(__SYCL_DEVICE_ONLY__)
71 #ifndef __SYCL_DEVICE_ONLY__
// Constructing a joint_matrix on the host is a hard runtime error.
72  throw runtime_error("joint matrix is not supported on host device.",
73  PI_ERROR_INVALID_DEVICE);
74 #endif
75  }
76 #ifdef __SYCL_DEVICE_ONLY__
77 #if defined(__SPIR__) || defined(__SPIRV__)
// Copy and copy-assignment are deleted on SPIR-V targets.
78  joint_matrix(const joint_matrix &other) = delete;
79  joint_matrix &operator=(const joint_matrix &rhs) = delete;
80 #endif // defined(__SPIR__) || defined(__SPIRV__)
81 #endif
82 };
83 
// Applies `lambda` element-wise to every work-item-owned element of the
// matrix, writing the (possibly modified) value back.
// NOTE(review): the line carrying the function name and first parameters
// was dropped by this extraction; per the trailing index this is
// joint_matrix_apply(Group sg, joint_matrix &jm, F &&lambda).
84 template <typename Group, typename T, use Use, size_t M, size_t N,
85  layout Layout, typename F>
86 inline __SYCL_ALWAYS_INLINE void
88  F &&lambda) {
89 #if defined(__SYCL_DEVICE_ONLY__)
90 #if defined(__NVPTX__) || defined(__HIP_PLATFORM_AMD_MFMA__)
// CUDA/HIP: mutate the per-work-item marray fragment in place.
91  std::ignore = sg;
92  for (int i = 0; i < jm.matrix_impl.wi_marray.size(); i++) {
93  lambda(jm.matrix_impl.wi_marray[i]);
94  }
95 #else // NVPTX
// Intel/SPIR-V: read each element into its storage type, apply the
// lambda, then store the element back through the wi_data proxy.
96  using storage_element_type =
98  T>::storage_element_type;
99  auto wi_data_c = sycl::ext::oneapi::detail::get_wi_data(sg, jm);
100  for (int i = 0; i < wi_data_c.length(); i++) {
101  storage_element_type element = wi_data_c[i];
102  lambda(element);
103  wi_data_c[i] = element;
104  }
105 #endif
106 #else
107  std::ignore = sg;
108  std::ignore = jm;
109  std::ignore = lambda;
110  throw runtime_error("joint matrix is not supported on host device.",
111  PI_ERROR_INVALID_DEVICE);
112 #endif
113  return;
114 }
115 
// Two-matrix overload: applies `lambda(src_elem, dest_elem)` element-wise;
// only the destination element is written back (the source matrix is not
// modified on the SPIR-V path).
// NOTE(review): the signature lines naming `jmsrc`/`jmdest` were dropped
// by this extraction.
116 template <typename Group, typename T, use Use, size_t M, size_t N,
117  layout Layout, typename F>
118 inline __SYCL_ALWAYS_INLINE void
121  F &&lambda) {
122 #if defined(__SYCL_DEVICE_ONLY__)
123 #if defined(__NVPTX__) || defined(__HIP_PLATFORM_AMD_MFMA__)
// CUDA/HIP: lambda operates directly on the two marray fragments.
124  std::ignore = sg;
125  for (int i = 0; i < jmsrc.matrix_impl.wi_marray.size(); i++) {
126  lambda(jmsrc.matrix_impl.wi_marray[i], jmdest.matrix_impl.wi_marray[i]);
127  }
128 #else // NVPTX
129  using storage_element_type =
131  T>::storage_element_type;
132  auto wi_data_c = sycl::ext::oneapi::detail::get_wi_data(sg, jmsrc);
133  auto wi_data_d = sycl::ext::oneapi::detail::get_wi_data(sg, jmdest);
134  for (int i = 0; i < wi_data_c.length(); i++) {
135  storage_element_type elementsrc = wi_data_c[i];
136  storage_element_type elementdest = wi_data_d[i];
137  lambda(elementsrc, elementdest);
// Only the destination element is stored back.
138  wi_data_d[i] = elementdest;
139  }
140 #endif
141 #else
142  std::ignore = sg;
143  std::ignore = jmsrc;
144  std::ignore = jmdest;
145  std::ignore = lambda;
146  throw runtime_error("joint matrix is not supported on host device.",
147  PI_ERROR_INVALID_DEVICE);
148 #endif
149  return;
150 }
151 
// Fills every element of `res` with the scalar `v` (converted to the
// matrix element's storage type on the SPIR-V path).
// NOTE(review): the signature lines naming the Group and `res` parameters
// and part of the __spirv_CompositeConstruct template argument list were
// dropped by this extraction.
152 template <typename Group, typename T, size_t NumRows, size_t NumCols, use Use,
153  layout Layout, typename T2>
154 inline __SYCL_ALWAYS_INLINE void
157  const T2 &v) {
158 #if defined(__SYCL_DEVICE_ONLY__)
159 #if defined(__NVPTX__) || defined(__HIP_PLATFORM_AMD_MFMA__)
// CUDA/HIP: marray broadcast-assign fills the work-item fragment.
160  res.matrix_impl.wi_marray = v;
161 #else
162  using storage_element_type =
164  T>::storage_element_type;
165  res.spvm =
166  __spirv_CompositeConstruct<storage_element_type, T, NumRows, NumCols,
169  static_cast<storage_element_type>(v));
170 #endif // defined(__NVPTX__)
171 #else
172  std::ignore = res;
173  std::ignore = v;
174  throw runtime_error("joint matrix is not supported on host device.",
175  PI_ERROR_INVALID_DEVICE);
176 #endif // defined(__SYCL_DEVICE_ONLY__)
177 }
178 
// Loads an accumulator matrix (dynamic layout chosen at runtime via the
// trailing `Layout` argument) from a multi_ptr.  Loading from private
// memory is rejected at compile time.  enable_if requires the matrix
// element type S to equal remove_const_t<T> of the pointer.
// NOTE(review): the function-name line and the trailing
// `matrix::layout Layout` parameter line were dropped by this extraction.
179 template <
180  typename Group, typename S, typename T, size_t NumRows, size_t NumCols,
181  access::address_space Space, access::decorated IsDecorated,
182  std::enable_if_t<std::is_same<S, std::remove_const_t<T>>::value, bool> =
183  true>
185  Group sg,
186  joint_matrix<Group, S, use::accumulator, NumRows, NumCols,
187  sycl::ext::oneapi::experimental::matrix::layout::dynamic> &res,
188  multi_ptr<T, Space, IsDecorated> src, size_t stride,
190 #if defined(__SYCL_DEVICE_ONLY__)
191  static_assert(Space != access::address_space::private_space,
192  "Joint Matrix doesn't support load from private memory!");
193 #if defined(__NVPTX__)
194  std::ignore = sg;
195  sycl::ext::oneapi::detail::load_accumulator_cuda(res.matrix_impl, src, stride,
196  Layout);
197 #elif defined(__HIP_PLATFORM_AMD_MFMA__)
// HIP helper also needs the sub-group for its distribution scheme.
198  sycl::ext::oneapi::detail::load_accumulator_hip(res.matrix_impl, src, stride,
199  Layout, sg);
200 #else
// Intel/SPIR-V: decorate the pointer for the target address space, then
// lower directly to the JointMatrixLoad intrinsic with the runtime layout.
201  std::ignore = sg;
202  using DecorT = typename sycl::detail::DecoratedType<T, Space>::type;
203  DecorT *Ptr = sycl::detail::getDecorated<DecorT>(src);
204  res.spvm = __spirv_JointMatrixLoadINTEL<
205  DecorT, S, NumRows, NumCols,
208  Ptr, stride, sycl::detail::joint_matrix_layout_to_spv(Layout),
210 #endif // defined(__NVPTX__)
211 #else
212  std::ignore = sg;
213  std::ignore = res;
214  std::ignore = src;
215  std::ignore = stride;
216  std::ignore = Layout;
217  throw runtime_error("joint matrix is not supported on host device.",
218  PI_ERROR_INVALID_DEVICE);
219 #endif // defined(__SYCL_DEVICE_ONLY__)
220 }
221 
// Loads an A/B (multiplicand) matrix whose layout is fixed at compile time
// by the `Layout` template parameter.  The enable_if also admits the
// tf32 precision tag when the underlying pointer type is float.
// NOTE(review): the signature lines naming `res` and part of the SPIR-V
// intrinsic's template-argument list were dropped by this extraction.
222 template <
223  typename Group, typename S, typename T, use Use, size_t NumRows,
224  size_t NumCols, matrix::layout Layout, access::address_space Space,
225  access::decorated IsDecorated,
226  std::enable_if_t<std::is_same<S, std::remove_const_t<T>>::value ||
227  (std::is_same<S, precision::tf32>::value &&
228  std::is_same<std::remove_const_t<T>, float>::value),
229  bool> = true>
230 inline __SYCL_ALWAYS_INLINE void
233  multi_ptr<T, Space, IsDecorated> src, size_t stride) {
234 #if defined(__SYCL_DEVICE_ONLY__)
235  static_assert(Space != access::address_space::private_space,
236  "Joint Matrix doesn't support load from private memory!");
237 #if defined(__NVPTX__)
238  std::ignore = sg;
239  sycl::ext::oneapi::detail::load_multiplicand_cuda<S, T, NumRows, NumCols, Use,
240  Layout, Space>(
241  res.matrix_impl, src, stride);
242 #elif defined(__HIP_PLATFORM_AMD_MFMA__)
244  NumCols, Use, Layout, Space>(
245  res.matrix_impl, src, stride, sg);
246 #else
247  std::ignore = sg;
248  using DecorT = typename sycl::detail::DecoratedType<T, Space>::type;
249  DecorT *Ptr = sycl::detail::getDecorated<DecorT>(src);
250  res.spvm =
251  __spirv_JointMatrixLoadINTEL<DecorT, S, NumRows, NumCols,
256 #endif // defined(__NVPTX__)
257 #else
258  std::ignore = sg;
259  std::ignore = res;
260  std::ignore = src;
261  std::ignore = stride;
262  throw runtime_error("joint matrix is not supported on host device.",
263  PI_ERROR_INVALID_DEVICE);
264 #endif // defined(__SYCL_DEVICE_ONLY__)
265 }
266 
// Accumulator load from an annotated pointer (PropertyListT).  Only the
// Intel/SPIR-V backend supports this form; CUDA and HIP deliberately
// throw and direct the user to the multi_ptr overload.
// NOTE(review): the function-name line and the annotated-pointer
// parameter line were dropped by this extraction.
267 template <typename Group, typename S, typename T, size_t NumRows,
268  size_t NumCols, typename PropertyListT,
269  std::enable_if_t<std::is_same<S, std::remove_const_t<T>>::value,
270  bool> = true>
272  Group sg,
273  joint_matrix<Group, S, use::accumulator, NumRows, NumCols,
274  sycl::ext::oneapi::experimental::matrix::layout::dynamic> &res,
276  size_t stride, sycl::ext::oneapi::experimental::matrix::layout Layout) {
277 #if defined(__SYCL_DEVICE_ONLY__)
278 #if defined(__NVPTX__)
279  std::ignore = sg;
280  throw runtime_error("Use joint_matrix_load on multi_ptr on Nvidia device.",
281  PI_ERROR_INVALID_DEVICE);
282 #elif defined(__HIP_PLATFORM_AMD_MFMA__)
283  throw runtime_error("Use joint_matrix_load on multi_ptr on AMD device.",
284  PI_ERROR_INVALID_DEVICE);
285 #else
// Raw pointer is extracted from the annotated pointer before lowering.
286  std::ignore = sg;
287  T *Ptr = src.get();
288  res.spvm = __spirv_JointMatrixLoadINTEL<
289  T, S, NumRows, NumCols, spv_matrix_use_traits<use::accumulator>::value,
291  Ptr, stride, sycl::detail::joint_matrix_layout_to_spv(Layout),
293 #endif // defined(__NVPTX__)
294 #else
295  std::ignore = sg;
296  std::ignore = res;
297  std::ignore = src;
298  std::ignore = stride;
299  std::ignore = Layout;
300  throw runtime_error("joint matrix is not supported on host device.",
301  PI_ERROR_INVALID_DEVICE);
302 #endif // defined(__SYCL_DEVICE_ONLY__)
303 }
304 
// Multiplicand (A/B) load from an annotated pointer with compile-time
// layout.  Intel/SPIR-V only; CUDA and HIP throw, directing the user to
// the multi_ptr overload.  tf32-over-float is admitted by the enable_if.
// NOTE(review): the function-name/parameter lines and part of the SPIR-V
// intrinsic's argument list were dropped by this extraction.
305 template <
306  typename Group, typename S, typename T, use Use, size_t NumRows,
307  size_t NumCols, matrix::layout Layout, typename PropertyListT,
308  std::enable_if_t<std::is_same<S, std::remove_const_t<T>>::value ||
309  (std::is_same<S, precision::tf32>::value &&
310  std::is_same<std::remove_const_t<T>, float>::value),
311  bool> = true>
315  size_t stride) {
316 #if defined(__SYCL_DEVICE_ONLY__)
317 #if defined(__NVPTX__)
318  std::ignore = sg;
319  throw runtime_error("Use joint_matrix_load on multi_ptr on Nvidia device.",
320  PI_ERROR_INVALID_DEVICE);
321 #elif defined(__HIP_PLATFORM_AMD_MFMA__)
322  throw runtime_error("Use joint_matrix_load on multi_ptr on AMD device.",
323  PI_ERROR_INVALID_DEVICE);
324 #else
325  std::ignore = sg;
326  T *Ptr = src.get();
327  res.spvm =
328  __spirv_JointMatrixLoadINTEL<T, S, NumRows, NumCols,
333 #endif // defined(__NVPTX__)
334 #else
335  std::ignore = sg;
336  std::ignore = res;
337  std::ignore = src;
338  std::ignore = stride;
339  throw runtime_error("joint matrix is not supported on host device.",
340  PI_ERROR_INVALID_DEVICE);
341 #endif // defined(__SYCL_DEVICE_ONLY__)
342 }
343 
// Stores an accumulator matrix to a multi_ptr with a runtime-chosen
// layout.  Storing to private memory is rejected at compile time.
// NOTE(review): the function-name line and the trailing layout parameter
// line were dropped by this extraction.
344 template <typename Group, typename T, size_t NumRows, size_t NumCols,
345  access::address_space Space, access::decorated IsDecorated>
347  Group sg,
348  const joint_matrix<Group, T, use::accumulator, NumRows, NumCols,
349  sycl::ext::oneapi::experimental::matrix::layout::dynamic>
350  &src,
351  multi_ptr<T, Space, IsDecorated> dst, size_t stride,
353 #if defined(__SYCL_DEVICE_ONLY__)
354  static_assert(Space != access::address_space::private_space,
355  "Joint Matrix doesn't support store to private memory!");
356 #if defined(__NVPTX__)
357  std::ignore = sg;
358  sycl::ext::oneapi::detail::joint_matrix_store_cuda<T, NumRows, NumCols,
359  Space>(
360  src.matrix_impl, dst, stride, Layout);
361 #elif defined(__HIP_PLATFORM_AMD_MFMA__)
362  sycl::ext::oneapi::detail::joint_matrix_store_hip<Group, T, NumRows, NumCols,
363  Space>(src.matrix_impl, dst,
364  stride, Layout, sg);
365 #else
// Intel/SPIR-V: decorate the destination pointer and lower to the
// JointMatrixStore intrinsic with the runtime layout.
366  std::ignore = sg;
367  using DecorT = typename sycl::detail::DecoratedType<T, Space>::type;
368  DecorT *Ptr = sycl::detail::getDecorated<DecorT>(dst);
369  __spirv_JointMatrixStoreINTEL<
370  DecorT, T, NumRows, NumCols,
373  Ptr, src.spvm, stride, sycl::detail::joint_matrix_layout_to_spv(Layout),
375 #endif // defined(__NVPTX__)
376 #else
377  std::ignore = sg;
378  std::ignore = src;
379  std::ignore = dst;
380  std::ignore = stride;
381  std::ignore = Layout;
382  throw runtime_error("joint matrix is not supported on host device.",
383  PI_ERROR_INVALID_DEVICE);
384 #endif // defined(__SYCL_DEVICE_ONLY__)
385 }
386 
// Accumulator store through an annotated pointer (PropertyListT).  Only
// the Intel/SPIR-V backend supports this form; CUDA and HIP throw and
// direct the user to the multi_ptr overload.
// NOTE(review): the function-name line and the annotated-pointer `dst`
// parameter line were dropped by this extraction.
387 template <typename Group, typename T, size_t NumRows, size_t NumCols,
388  typename PropertyListT>
390  Group sg,
391  const joint_matrix<Group, T, use::accumulator, NumRows, NumCols,
392  sycl::ext::oneapi::experimental::matrix::layout::dynamic>
393  &src,
395  size_t stride, sycl::ext::oneapi::experimental::matrix::layout Layout) {
396 #if defined(__SYCL_DEVICE_ONLY__)
397 #if defined(__NVPTX__)
398  std::ignore = sg;
399  throw runtime_error("Use joint_matrix_store on multi_ptr on Nvidia device.",
400  PI_ERROR_INVALID_DEVICE);
401 #elif defined(__HIP_PLATFORM_AMD_MFMA__)
402  throw runtime_error("Use joint_matrix_store on multi_ptr on AMD device.",
403  PI_ERROR_INVALID_DEVICE);
404 #else
405  std::ignore = sg;
406  T *Ptr = dst.get();
407  __spirv_JointMatrixStoreINTEL<
408  T, T, NumRows, NumCols, spv_matrix_use_traits<use::accumulator>::value,
410  Ptr, src.spvm, stride, sycl::detail::joint_matrix_layout_to_spv(Layout),
412 #endif // defined(__NVPTX__)
413 #else
414  std::ignore = sg;
415  std::ignore = src;
416  std::ignore = dst;
417  std::ignore = stride;
418  std::ignore = Layout;
419  throw runtime_error("joint matrix is not supported on host device.",
420  PI_ERROR_INVALID_DEVICE);
421 #endif // defined(__SYCL_DEVICE_ONLY__)
422 }
423 
// Matrix multiply-add: D = A * B + C, where A is MxK, B is KxN and C/D
// are MxN accumulators.  CUDA/HIP require Ta == Tb (enforced with a
// device-side assert); the Intel/SPIR-V path dispatches to the
// signed/unsigned intrinsic variant matching the element signedness.
// NOTE(review): the function-name line and the `A`/`B` parameter lines
// were dropped by this extraction.
424 template <typename Group, typename Ta, typename Tb, typename Tc, typename Td,
425  std::size_t M, std::size_t K, std::size_t N, layout LayoutA,
426  layout LayoutB>
427 #if defined(__SYCL_DEVICE_ONLY__)
// IR attributes expose the mad operand types and shape to the device
// compiler for validation/selection.
428 [[__sycl_detail__::add_ir_attributes_function(
429  "sycl-joint-matrix-mad-type-A", "sycl-joint-matrix-mad-type-B",
430  "sycl-joint-matrix-mad-type-C", "sycl-joint-matrix-mad-type-D",
431  "sycl-joint-matrix-mad-size-M", "sycl-joint-matrix-mad-size-K",
432  "sycl-joint-matrix-mad-size-N",
433  sycl::detail::convertTypeToMatrixTypeString<Ta>(),
434  sycl::detail::convertTypeToMatrixTypeString<Tb>(),
435  sycl::detail::convertTypeToMatrixTypeString<Tc>(),
436  sycl::detail::convertTypeToMatrixTypeString<Td>(), M, K, N)]]
437 #endif // defined(__SYCL_DEVICE_ONLY__)
438 inline __SYCL_ALWAYS_INLINE void
440  Group,
441  joint_matrix<Group, Td, use::accumulator, M, N,
442  sycl::ext::oneapi::experimental::matrix::layout::dynamic> &D,
445  const joint_matrix<Group, Tc, use::accumulator, M, N,
446  sycl::ext::oneapi::experimental::matrix::layout::dynamic>
447  &C) {
448 #if defined(__SYCL_DEVICE_ONLY__)
449 #if defined(__NVPTX__)
450  if constexpr (std::is_same<Ta, Tb>::value) {
451  sycl::ext::oneapi::detail::joint_matrix_mad_cuda<Ta, Tc, Td, M, K, N,
452  LayoutA, LayoutB>(
453  D.matrix_impl, A.matrix_impl, B.matrix_impl, C.matrix_impl);
454  } else {
455  assert(false && "Ta != Tb : In the CUDA backend joint_matrix_mad "
456  "requires that joint_matrix data types Ta and Tb match");
457  }
458 #elif defined(__HIP_PLATFORM_AMD_MFMA__)
459  if constexpr (std::is_same<Ta, Tb>::value) {
460  sycl::ext::oneapi::detail::joint_matrix_mad_hip<Ta, Tc, M, K, N, LayoutA,
461  LayoutB>(
462  D.matrix_impl, A.matrix_impl, B.matrix_impl, C.matrix_impl);
463  } else {
464  assert(false && "Ta != Tb : In the HIP backend joint_matrix_mad "
465  "requires that joint_matrix data types Ta and Tb match");
466  }
467 #else
// uint16_t operands with float accumulator are treated as bf16 and take
// the plain Mad intrinsic; otherwise pick UU/SU/US by signedness.
468  if constexpr (std::is_same<Ta, uint16_t>::value &&
469  std::is_same<Tb, uint16_t>::value &&
470  std::is_same<Tc, float>::value)
471  D.spvm = __spirv_JointMatrixMadINTEL(A.spvm, B.spvm, C.spvm);
472  else if constexpr (std::is_unsigned<Ta>::value && std::is_unsigned<Tb>::value)
473  D.spvm = __spirv_JointMatrixUUMadINTEL(A.spvm, B.spvm, C.spvm);
474  else if constexpr (std::is_signed<Ta>::value && std::is_unsigned<Tb>::value)
475  D.spvm = __spirv_JointMatrixSUMadINTEL(A.spvm, B.spvm, C.spvm);
476  else if constexpr (std::is_unsigned<Ta>::value && std::is_signed<Tb>::value)
477  D.spvm = __spirv_JointMatrixUSMadINTEL(A.spvm, B.spvm, C.spvm);
478  else
479  D.spvm = __spirv_JointMatrixMadINTEL(A.spvm, B.spvm, C.spvm);
480 #endif // defined(__NVPTX__)
481 #else
482  std::ignore = A;
483  std::ignore = B;
484  std::ignore = C;
485  std::ignore = D;
486  throw runtime_error("joint matrix is not supported on host device.",
487  PI_ERROR_INVALID_DEVICE);
488 #endif // defined(__SYCL_DEVICE_ONLY__)
489 }
490 
// Element-wise copy from `src` to `dst`, converting each element to the
// destination's storage type on the SPIR-V path; shapes must match (Rows,
// Cols shared), while element type, use, and layout may differ.
// NOTE(review): the signature lines naming `sg`, `src`, and `dst` were
// dropped by this extraction.
491 template <typename Group, typename T1, typename T2, size_t Rows, size_t Cols,
492  use Use1, use Use2, layout Layout1, layout Layout2>
496 #if defined(__SYCL_DEVICE_ONLY__)
497 #if defined(__NVPTX__) || defined(__HIP_PLATFORM_AMD_MFMA__)
// CUDA/HIP: whole-fragment marray assignment.
498  std::ignore = sg;
499  dst.matrix_impl.wi_marray = src.matrix_impl.wi_marray;
500 #else
501  using storage_element_type =
503  T2>::storage_element_type;
504  auto wi_data_c = sycl::ext::oneapi::detail::get_wi_data(sg, src);
505  auto wi_data_dst = sycl::ext::oneapi::detail::get_wi_data(sg, dst);
506  for (int i = 0; i < wi_data_c.length(); i++) {
// Convert through the destination storage type element by element.
507  wi_data_dst[i] = static_cast<storage_element_type>(wi_data_c[i]);
508  }
509 #endif // defined(__NVPTX__)
510 #else
511  std::ignore = sg;
512  std::ignore = dst;
513  std::ignore = src;
514  throw runtime_error("joint matrix is not supported on host device.",
515  PI_ERROR_INVALID_DEVICE);
516 #endif // defined(__SYCL_DEVICE_ONLY__)
517 }
518 
/// Rounds a float to tf32 precision: the value is rounded to nearest at
/// mantissa bit 13 and the bottom 13 mantissa bits are then zeroed, leaving
/// the 10-bit tf32 mantissa (plus sign and 8-bit exponent) intact.
///
/// @param a  input single-precision value
/// @return   `a` rounded to the nearest representable tf32 value
inline __SYCL_ALWAYS_INLINE float round_to_tf32(const float &a) {
#if defined(__SYCL_DEVICE_ONLY__)
#if defined(__NVPTX__)
  // CUDA: hardware round-to-nearest float->tf32 conversion.
  int32_t tmp_int = __nvvm_f2tf32_rna(a);
  return __nvvm_bitcast_i2f(tmp_int);
#else
  return __spirv_RoundFToTF32INTEL(a);
#endif // defined(__NVPTX__)
#else
  // Host emulation: add half of the truncated ULP (bit 12) so that the
  // kept 10-bit mantissa is rounded to nearest, then clear the low 13
  // bits.  Read the float's bits with memcpy as well as writing them:
  // the previous reinterpret_cast<const uint32_t &> violated strict
  // aliasing, which is undefined behavior in standard C++.
  uint32_t tmp_uint = 0;
  std::memcpy(&tmp_uint, &a, sizeof(float));
  tmp_uint += 0x1000u;
  tmp_uint &= 0xFFFFE000u;
  float ret = 0;
  std::memcpy(&ret, &tmp_uint, sizeof(float));
  return ret;
#endif // defined(__SYCL_DEVICE_ONLY__)
}
538 
// Issues a cooperative prefetch hint for an NumRows x NumCols tile at
// `Ptr` with row stride `stride`.  Only the Intel/SPIR-V backend supports
// this; CUDA and HIP throw.  The cache level is taken from the
// prefetch_hint property in `properties`.
// NOTE(review): the `Layout` parameter line and the final intrinsic
// argument line were dropped by this extraction.
539 template <size_t NumRows, size_t NumCols, typename Group, typename T,
540  typename Properties = ext::oneapi::experimental::empty_properties_t>
541 inline __SYCL_ALWAYS_INLINE void
542 joint_matrix_prefetch(Group sg, T *Ptr, size_t stride,
544  Properties properties = {}) {
545 #if defined(__SYCL_DEVICE_ONLY__)
546 #if defined(__NVPTX__)
547  std::ignore = sg;
548  std::ignore = properties;
549  throw runtime_error(
550  "joint_matrix_prefetch is not supported on Nvidia device.",
551  PI_ERROR_INVALID_DEVICE);
552 #elif defined(__HIP_PLATFORM_AMD_MFMA__)
553  std::ignore = sg;
554  std::ignore = properties;
555  throw runtime_error("joint_matrix_prefetch is not supported on AMD device.",
556  PI_ERROR_INVALID_DEVICE);
557 #else
// Extract the prefetch hint property and lower to the cooperative-matrix
// prefetch intrinsic.
558  std::ignore = sg;
559  auto prop = properties.template get_property<prefetch_hint_key>();
560  __spirv_CooperativeMatrixPrefetchINTEL<T>(
561  Ptr, NumRows, NumCols, detail::PropertyMetaInfo<decltype(prop)>::value,
563 #endif // defined(__NVPTX__)
564 #else
565  std::ignore = sg;
566  std::ignore = Ptr;
567  std::ignore = stride;
568  std::ignore = Layout;
569  std::ignore = properties;
570  throw runtime_error("joint matrix is not supported on host device.",
571  PI_ERROR_INVALID_DEVICE);
572 #endif // defined(__SYCL_DEVICE_ONLY__)
573 }
574 
575 } // namespace matrix
576 } // namespace experimental
577 } // namespace oneapi
578 } // namespace ext
579 } // namespace _V1
580 } // namespace sycl
#define __SYCL_ALWAYS_INLINE
__SYCL_ALWAYS_INLINE __spv::MatrixLayout joint_matrix_layout_to_spv(sycl::ext::oneapi::experimental::matrix::layout Layout)
constexpr const char * convertMatrixUseEnumToString(ext::oneapi::experimental::matrix::use Use)
void joint_matrix_store_hip(const joint_matrix_hip< T, sycl::ext::oneapi::experimental::matrix::use::accumulator, M, N, sycl::ext::oneapi::experimental::matrix::layout::dynamic > &src, multi_ptr< T, Space, IsDecorated > dst, size_t stride, sycl::ext::oneapi::experimental::matrix::layout layout, Group &sg)
Definition: matrix-hip.hpp:317
void joint_matrix_mad_hip(joint_matrix_hip< Tc, sycl::ext::oneapi::experimental::matrix::use::accumulator, M, N, sycl::ext::oneapi::experimental::matrix::layout::dynamic > &D, const joint_matrix_hip< Tm, sycl::ext::oneapi::experimental::matrix::use::a, M, K, LayoutA > &A, const joint_matrix_hip< Tm, sycl::ext::oneapi::experimental::matrix::use::b, K, N, LayoutB > &B, const joint_matrix_hip< Tc, sycl::ext::oneapi::experimental::matrix::use::accumulator, M, N, sycl::ext::oneapi::experimental::matrix::layout::dynamic > &C)
Definition: matrix-hip.hpp:337
void load_accumulator_hip(joint_matrix_hip< S, sycl::ext::oneapi::experimental::matrix::use::accumulator, M, N, sycl::ext::oneapi::experimental::matrix::layout::dynamic > &res, multi_ptr< T, Space, IsDecorated > src, size_t stride, sycl::ext::oneapi::experimental::matrix::layout layout, Group &sg)
Definition: matrix-hip.hpp:185
void load_multiplicand_hip(joint_matrix_hip< S, Use, M, N, Layout > &res, multi_ptr< T, Space, IsDecorated > src, size_t stride, Group &sg)
Definition: matrix-hip.hpp:211
decltype(auto) __SYCL_ALWAYS_INLINE get_wi_data(Group sg, sycl::ext::oneapi::experimental::matrix::joint_matrix< Group, T, Use, Rows, Cols, Layout > &jm)
__SYCL_ALWAYS_INLINE float round_to_tf32(const float &a)
void joint_matrix_copy(Group sg, joint_matrix< Group, T1, Use1, Rows, Cols, Layout1 > &src, joint_matrix< Group, T2, Use2, Rows, Cols, Layout2 > &dst)
__SYCL_ALWAYS_INLINE void joint_matrix_load(Group sg, joint_matrix< Group, S, use::accumulator, NumRows, NumCols, sycl::ext::oneapi::experimental::matrix::layout::dynamic > &res, multi_ptr< T, Space, IsDecorated > src, size_t stride, sycl::ext::oneapi::experimental::matrix::layout Layout)
__SYCL_ALWAYS_INLINE void joint_matrix_fill(Group, joint_matrix< Group, T, Use, NumRows, NumCols, Layout > &res, const T2 &v)
__SYCL_ALWAYS_INLINE void joint_matrix_store(Group sg, const joint_matrix< Group, T, use::accumulator, NumRows, NumCols, sycl::ext::oneapi::experimental::matrix::layout::dynamic > &src, multi_ptr< T, Space, IsDecorated > dst, size_t stride, sycl::ext::oneapi::experimental::matrix::layout Layout)
__SYCL_ALWAYS_INLINE void joint_matrix_mad(Group, joint_matrix< Group, Td, use::accumulator, M, N, sycl::ext::oneapi::experimental::matrix::layout::dynamic > &D, const joint_matrix< Group, Ta, use::a, M, K, LayoutA > &A, const joint_matrix< Group, Tb, use::b, K, N, LayoutB > &B, const joint_matrix< Group, Tc, use::accumulator, M, N, sycl::ext::oneapi::experimental::matrix::layout::dynamic > &C)
__SYCL_ALWAYS_INLINE void joint_matrix_prefetch(Group sg, T *Ptr, size_t stride, sycl::ext::oneapi::experimental::matrix::layout Layout, Properties properties={})
__SYCL_ALWAYS_INLINE void joint_matrix_apply(Group sg, joint_matrix< Group, T, Use, M, N, Layout > &jm, F &&lambda)
annotated_arg & operator=(annotated_arg &)=default
Definition: access.hpp:18