DPC++ Runtime
Runtime libraries for oneAPI DPC++
matrix-unified.hpp
Go to the documentation of this file.
1 //===------- matrix-unified.hpp - SYCL matrix extension ----*- C++ -*------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 // ===--------------------------------------------------------------------=== //
8 
9 #pragma once
10 
11 #include "matrix-intel.hpp"
12 
13 #if defined(__SYCL_DEVICE_ONLY__)
14 #if defined(__NVPTX__)
15 #include "matrix-tensorcores.hpp"
16 #elif defined(__gfx90a__)
17 #include "matrix-hip.hpp"
18 #endif // defined(__NVPTX__)
19 #endif // defined(__SYCL_DEVICE_ONLY__)
20 
21 #include <sycl/access/access.hpp> // for address_space
22 #include <sycl/detail/defines_elementary.hpp> // for __SYCL_ALWAYS_...
23 #include <sycl/exception.hpp>
24 #include <sycl/ext/oneapi/matrix/matrix-unified-utils.hpp> // for layout, use, tf32, convertMatrixUseEnumToString
25 #include <sycl/ext/oneapi/matrix/query-types.hpp> // for convertTypeToMatrixTypeString
26 #include <sycl/marray.hpp> // for marray
27 #include <sycl/multi_ptr.hpp> // for multi_ptr
28 
29 #include <cstring> // for size_t, memcpy
30 #include <stdint.h> // for uint32_t
31 #include <tuple> // for ignore, _Swall...
32 #include <type_traits> // for is_same, remov...
33 
34 namespace sycl {
35 inline namespace _V1 {
36 namespace ext {
37 namespace oneapi {
38 namespace experimental {
39 namespace matrix {
40 
// joint_matrix: a Rows x Cols matrix of element type T that plays role `Use`
// (use::a, use::b and use::accumulator appear elsewhere in this file) and is
// stored with `Layout`. `Group` is the same Group type that the
// joint_matrix_* functions below take as their first argument.
//
// Data member per device target:
//  - NVPTX and HIP MFMA: `matrix_impl` (its type is on a line not visible in
//    this listing); the functions below access `matrix_impl.wi_marray`.
//  - SPIR / SPIR-V: the member lines are not visible here, but the functions
//    below read and write a member named `spvm`.
//  - any other device target: compilation stops at the static_assert.
41 template <typename Group, typename T, use Use, size_t Rows, size_t Cols,
42  layout Layout>
43 struct joint_matrix {
44 
45 #if defined(__SYCL_DEVICE_ONLY__)
46 #if defined(__NVPTX__)
48  matrix_impl;
49 #elif defined(__HIP_PLATFORM_AMD_MFMA__)
51  matrix_impl;
52 #elif defined(__SPIR__) || defined(__SPIRV__)
56 #else
57  static_assert(false, "The joint_matrix API is only supported by the Intel, "
58  "CUDA and HIP (GFX90A) backends");
59 #endif // defined(__NVPTX__)
60 #endif // defined(__SYCL_DEVICE_ONLY__)
61 
// Device only: the following member (presumably the default constructor; its
// declarator line is not visible here) is tagged with IR attributes naming the
// matrix element type, use, rows and cols. On host its body reports
// "joint matrix is not supported on host." (the reporting statement is only
// partly visible here).
62 #if defined(__SYCL_DEVICE_ONLY__)
63  [[__sycl_detail__::add_ir_attributes_function(
64  "sycl-joint-matrix-type", "sycl-joint-matrix-use",
65  "sycl-joint-matrix-rows", "sycl-joint-matrix-cols",
66  sycl::detail::convertTypeToMatrixTypeString<T>(),
68 #endif // defined(__SYCL_DEVICE_ONLY__)
70 #ifndef __SYCL_DEVICE_ONLY__
72  "joint matrix is not supported on host.");
73 #endif
74  }
// SPIR / SPIR-V device code: copy construction and copy assignment are
// disabled.
75 #ifdef __SYCL_DEVICE_ONLY__
76 #if defined(__SPIR__) || defined(__SPIRV__)
77  joint_matrix(const joint_matrix &other) = delete;
78  joint_matrix &operator=(const joint_matrix &rhs) = delete;
79 #endif // defined(__SPIR__) || defined(__SPIRV__)
80 #endif
81 };
82 
// joint_matrix_apply: calls `lambda` once for each element of `jm` exposed to
// the calling work-item (via `wi_marray` or `get_wi_data`), keeping any change
// the lambda makes to its argument.
//  - NVPTX / HIP MFMA: `lambda` receives jm.matrix_impl.wi_marray[i] directly.
//  - SPIR / SPIR-V: each element is copied into a local of
//    storage_element_type, passed to `lambda`, then written back.
//  - host: not supported; reports "joint matrix is not supported on host."
//    (the reporting statement is only partly visible here).
// NOTE(review): `int i` is compared with size()/length(); if those return an
// unsigned type this is a signed/unsigned comparison — confirm.
83 template <typename Group, typename T, use Use, size_t M, size_t N,
84  layout Layout, typename F>
85 inline __SYCL_ALWAYS_INLINE void
87  F &&lambda) {
88 #if defined(__SYCL_DEVICE_ONLY__)
89 #if defined(__NVPTX__) || defined(__HIP_PLATFORM_AMD_MFMA__)
90  std::ignore = sg;
91  for (int i = 0; i < jm.matrix_impl.wi_marray.size(); i++) {
92  lambda(jm.matrix_impl.wi_marray[i]);
93  }
94 #else // SPIR / SPIR-V (neither NVPTX nor HIP MFMA)
95  using storage_element_type =
97  T>::storage_element_type;
98  auto wi_data_c = sycl::ext::oneapi::detail::get_wi_data(sg, jm);
// Copy out, apply, copy back: the lambda operates on a local copy.
99  for (int i = 0; i < wi_data_c.length(); i++) {
100  storage_element_type element = wi_data_c[i];
101  lambda(element);
102  wi_data_c[i] = element;
103  }
104 #endif
105 #else
106  std::ignore = sg;
107  std::ignore = jm;
108  std::ignore = lambda;
110  "joint matrix is not supported on host.");
111 #endif
112  return;
113 }
114 
// joint_matrix_apply (two-matrix form): calls `lambda(src_elem, dest_elem)`
// for each pair of corresponding elements of `jmsrc` and `jmdest` exposed to
// the calling work-item, and keeps the changes made to `dest_elem`. The loop
// bound is taken from `jmsrc` only.
//  - NVPTX / HIP MFMA: both elements are passed straight from
//    matrix_impl.wi_marray.
//  - SPIR / SPIR-V: both elements are copied into locals; only the
//    destination element is written back, so changes to the first argument
//    are discarded.
//  - host: not supported.
// NOTE(review): on NVPTX / HIP, whether changes to the first argument reach
// `jmsrc` depends on how `jmsrc` is declared (that line is not visible here);
// this may differ from the SPIR / SPIR-V behavior — confirm.
115 template <typename Group, typename T, use Use, size_t M, size_t N,
116  layout Layout, typename F>
117 inline __SYCL_ALWAYS_INLINE void
120  F &&lambda) {
121 #if defined(__SYCL_DEVICE_ONLY__)
122 #if defined(__NVPTX__) || defined(__HIP_PLATFORM_AMD_MFMA__)
123  std::ignore = sg;
124  for (int i = 0; i < jmsrc.matrix_impl.wi_marray.size(); i++) {
125  lambda(jmsrc.matrix_impl.wi_marray[i], jmdest.matrix_impl.wi_marray[i]);
126  }
127 #else // SPIR / SPIR-V (neither NVPTX nor HIP MFMA)
128  using storage_element_type =
130  T>::storage_element_type;
131  auto wi_data_c = sycl::ext::oneapi::detail::get_wi_data(sg, jmsrc);
132  auto wi_data_d = sycl::ext::oneapi::detail::get_wi_data(sg, jmdest);
133  for (int i = 0; i < wi_data_c.length(); i++) {
134  storage_element_type elementsrc = wi_data_c[i];
135  storage_element_type elementdest = wi_data_d[i];
136  lambda(elementsrc, elementdest);
// Only the destination copy is stored back.
137  wi_data_d[i] = elementdest;
138  }
139 #endif
140 #else
141  std::ignore = sg;
142  std::ignore = jmsrc;
143  std::ignore = jmdest;
144  std::ignore = lambda;
146  "joint matrix is not supported on host.");
147 #endif
148  return;
149 }
150 
// joint_matrix_fill: fills `res` with the value `v`.
//  - NVPTX / HIP MFMA: assigns `v` to res.matrix_impl.wi_marray.
//  - SPIR / SPIR-V: builds `res.spvm` with __spirv_CompositeConstruct from `v`
//    converted to storage_element_type.
//  - host: not supported.
// NOTE(review): the NVPTX / HIP path does not static_cast `v` the way the
// SPIR / SPIR-V path does; any T2 -> T conversion there is implicit.
151 template <typename Group, typename T, size_t NumRows, size_t NumCols, use Use,
152  layout Layout, typename T2>
153 inline __SYCL_ALWAYS_INLINE void
156  const T2 &v) {
157 #if defined(__SYCL_DEVICE_ONLY__)
158 #if defined(__NVPTX__) || defined(__HIP_PLATFORM_AMD_MFMA__)
159  res.matrix_impl.wi_marray = v;
160 #else
161  using storage_element_type =
163  T>::storage_element_type;
164  res.spvm =
165  __spirv_CompositeConstruct<storage_element_type, T, NumRows, NumCols,
168  static_cast<storage_element_type>(v));
169 #endif // defined(__NVPTX__)
170 #else
171  std::ignore = res;
172  std::ignore = v;
174  "joint matrix is not supported on host.");
175 #endif // defined(__SYCL_DEVICE_ONLY__)
176 }
177 
// joint_matrix_load (accumulator, multi_ptr source): loads `res` from `src`
// with `stride` and the run-time `Layout` argument. Enabled only when the
// matrix element type S equals T with const removed. Loading from private
// memory is rejected at compile time.
//  - NVPTX: load_accumulator_cuda.
//  - HIP MFMA: load_accumulator_hip (also receives `sg`).
//  - SPIR / SPIR-V: __spirv_JointMatrixLoadINTEL on the decorated pointer.
//  - host: not supported.
178 template <
179  typename Group, typename S, typename T, size_t NumRows, size_t NumCols,
180  access::address_space Space, access::decorated IsDecorated,
181  std::enable_if_t<std::is_same<S, std::remove_const_t<T>>::value, bool> =
182  true>
184  Group sg,
185  joint_matrix<Group, S, use::accumulator, NumRows, NumCols,
186  sycl::ext::oneapi::experimental::matrix::layout::dynamic> &res,
187  multi_ptr<T, Space, IsDecorated> src, size_t stride,
189 #if defined(__SYCL_DEVICE_ONLY__)
190  static_assert(Space != access::address_space::private_space,
191  "Joint Matrix doesn't support load from private memory!");
192 #if defined(__NVPTX__)
193  std::ignore = sg;
194  sycl::ext::oneapi::detail::load_accumulator_cuda(res.matrix_impl, src, stride,
195  Layout);
196 #elif defined(__HIP_PLATFORM_AMD_MFMA__)
197  sycl::ext::oneapi::detail::load_accumulator_hip(res.matrix_impl, src, stride,
198  Layout, sg);
199 #else
200  std::ignore = sg;
201  using DecorT = typename sycl::detail::DecoratedType<T, Space>::type;
202  DecorT *Ptr = sycl::detail::getDecorated<DecorT>(src);
203  res.spvm = __spirv_JointMatrixLoadINTEL<
204  DecorT, S, NumRows, NumCols,
207  Ptr, stride, sycl::detail::joint_matrix_layout_to_spv(Layout),
209 #endif // defined(__NVPTX__)
210 #else
211  std::ignore = sg;
212  std::ignore = res;
213  std::ignore = src;
214  std::ignore = stride;
215  std::ignore = Layout;
217  "joint matrix is not supported on host.");
218 #endif // defined(__SYCL_DEVICE_ONLY__)
219 }
220 
// joint_matrix_load (matrix with compile-time Layout, multi_ptr source):
// loads `res` from `src` with `stride`; the layout is the `Layout` template
// parameter of the matrix type. Enabled when S equals T without const, or
// when S is precision::tf32 and T is (const) float. Loading from private
// memory is rejected at compile time.
//  - NVPTX: load_multiplicand_cuda.
//  - HIP MFMA: a HIP helper that also receives `sg` (its name is on a line not
//    visible here; presumably load_multiplicand_hip).
//  - SPIR / SPIR-V: __spirv_JointMatrixLoadINTEL on the decorated pointer.
//  - host: not supported.
221 template <
222  typename Group, typename S, typename T, use Use, size_t NumRows,
223  size_t NumCols, matrix::layout Layout, access::address_space Space,
224  access::decorated IsDecorated,
225  std::enable_if_t<std::is_same<S, std::remove_const_t<T>>::value ||
226  (std::is_same<S, precision::tf32>::value &&
227  std::is_same<std::remove_const_t<T>, float>::value),
228  bool> = true>
229 inline __SYCL_ALWAYS_INLINE void
232  multi_ptr<T, Space, IsDecorated> src, size_t stride) {
233 #if defined(__SYCL_DEVICE_ONLY__)
234  static_assert(Space != access::address_space::private_space,
235  "Joint Matrix doesn't support load from private memory!");
236 #if defined(__NVPTX__)
237  std::ignore = sg;
238  sycl::ext::oneapi::detail::load_multiplicand_cuda<S, T, NumRows, NumCols, Use,
239  Layout, Space>(
240  res.matrix_impl, src, stride);
241 #elif defined(__HIP_PLATFORM_AMD_MFMA__)
243  NumCols, Use, Layout, Space>(
244  res.matrix_impl, src, stride, sg);
245 #else
246  std::ignore = sg;
247  using DecorT = typename sycl::detail::DecoratedType<T, Space>::type;
248  DecorT *Ptr = sycl::detail::getDecorated<DecorT>(src);
249  res.spvm =
250  __spirv_JointMatrixLoadINTEL<DecorT, S, NumRows, NumCols,
255 #endif // defined(__NVPTX__)
256 #else
257  std::ignore = sg;
258  std::ignore = res;
259  std::ignore = src;
260  std::ignore = stride;
262  "joint matrix is not supported on host.");
263 #endif // defined(__SYCL_DEVICE_ONLY__)
264 }
265 
// joint_matrix_load (accumulator, PropertyListT-parameterized source): the
// counterpart of the multi_ptr accumulator load for a `src` whose declaration
// line is not visible here (presumably annotated_ptr<T, PropertyListT> —
// confirm). Takes `stride` and a run-time `Layout`.
//  - SPIR / SPIR-V: takes the raw pointer with src.get() and calls
//    __spirv_JointMatrixLoadINTEL.
//  - NVPTX / HIP MFMA: not supported; the visible messages direct users to the
//    multi_ptr overload (the diagnosing statement is only partly visible).
//  - host: not supported.
266 template <typename Group, typename S, typename T, size_t NumRows,
267  size_t NumCols, typename PropertyListT,
268  std::enable_if_t<std::is_same<S, std::remove_const_t<T>>::value,
269  bool> = true>
271  Group sg,
272  joint_matrix<Group, S, use::accumulator, NumRows, NumCols,
273  sycl::ext::oneapi::experimental::matrix::layout::dynamic> &res,
275  size_t stride, sycl::ext::oneapi::experimental::matrix::layout Layout) {
276 #if defined(__SYCL_DEVICE_ONLY__)
277 #if defined(__NVPTX__)
278  std::ignore = sg;
280  "Use joint_matrix_load on multi_ptr on Nvidia device.");
281 #elif defined(__HIP_PLATFORM_AMD_MFMA__)
283  "Use joint_matrix_load on multi_ptr on AMD device.");
284 #else
285  std::ignore = sg;
286  T *Ptr = src.get();
287  res.spvm = __spirv_JointMatrixLoadINTEL<
288  T, S, NumRows, NumCols, spv_matrix_use_traits<use::accumulator>::value,
290  Ptr, stride, sycl::detail::joint_matrix_layout_to_spv(Layout),
292 #endif // defined(__NVPTX__)
293 #else
294  std::ignore = sg;
295  std::ignore = res;
296  std::ignore = src;
297  std::ignore = stride;
298  std::ignore = Layout;
300  "joint matrix is not supported on host.");
301 #endif // defined(__SYCL_DEVICE_ONLY__)
302 }
303 
// joint_matrix_load (matrix with compile-time Layout, PropertyListT-
// parameterized source): the counterpart of the multi_ptr load above for a
// `src` whose declaration line is not visible here (presumably
// annotated_ptr<T, PropertyListT> — confirm). Same S/T enabling rule: S equals
// T without const, or S is precision::tf32 and T is (const) float.
//  - SPIR / SPIR-V: takes the raw pointer with src.get() and calls
//    __spirv_JointMatrixLoadINTEL.
//  - NVPTX / HIP MFMA: not supported; the visible messages direct users to the
//    multi_ptr overload.
//  - host: not supported.
304 template <
305  typename Group, typename S, typename T, use Use, size_t NumRows,
306  size_t NumCols, matrix::layout Layout, typename PropertyListT,
307  std::enable_if_t<std::is_same<S, std::remove_const_t<T>>::value ||
308  (std::is_same<S, precision::tf32>::value &&
309  std::is_same<std::remove_const_t<T>, float>::value),
310  bool> = true>
314  size_t stride) {
315 #if defined(__SYCL_DEVICE_ONLY__)
316 #if defined(__NVPTX__)
317  std::ignore = sg;
319  "Use joint_matrix_load on multi_ptr on Nvidia device.");
320 #elif defined(__HIP_PLATFORM_AMD_MFMA__)
322  "Use joint_matrix_load on multi_ptr on AMD device.");
323 #else
324  std::ignore = sg;
325  T *Ptr = src.get();
326  res.spvm =
327  __spirv_JointMatrixLoadINTEL<T, S, NumRows, NumCols,
332 #endif // defined(__NVPTX__)
333 #else
334  std::ignore = sg;
335  std::ignore = res;
336  std::ignore = src;
337  std::ignore = stride;
339  "joint matrix is not supported on host.");
340 #endif // defined(__SYCL_DEVICE_ONLY__)
341 }
342 
// joint_matrix_store (accumulator, multi_ptr destination): writes `src` to
// `dst` with `stride` and the run-time `Layout`. Storing to private memory is
// rejected at compile time.
//  - NVPTX: joint_matrix_store_cuda.
//  - HIP MFMA: joint_matrix_store_hip (also receives `sg`).
//  - SPIR / SPIR-V: __spirv_JointMatrixStoreINTEL on the decorated pointer.
//  - host: not supported.
343 template <typename Group, typename T, size_t NumRows, size_t NumCols,
344  access::address_space Space, access::decorated IsDecorated>
346  Group sg,
347  const joint_matrix<Group, T, use::accumulator, NumRows, NumCols,
348  sycl::ext::oneapi::experimental::matrix::layout::dynamic>
349  &src,
350  multi_ptr<T, Space, IsDecorated> dst, size_t stride,
352 #if defined(__SYCL_DEVICE_ONLY__)
353  static_assert(Space != access::address_space::private_space,
354  "Joint Matrix doesn't support store to private memory!");
355 #if defined(__NVPTX__)
356  std::ignore = sg;
357  sycl::ext::oneapi::detail::joint_matrix_store_cuda<T, NumRows, NumCols,
358  Space>(
359  src.matrix_impl, dst, stride, Layout);
360 #elif defined(__HIP_PLATFORM_AMD_MFMA__)
361  sycl::ext::oneapi::detail::joint_matrix_store_hip<Group, T, NumRows, NumCols,
362  Space>(src.matrix_impl, dst,
363  stride, Layout, sg);
364 #else
365  std::ignore = sg;
366  using DecorT = typename sycl::detail::DecoratedType<T, Space>::type;
367  DecorT *Ptr = sycl::detail::getDecorated<DecorT>(dst);
368  __spirv_JointMatrixStoreINTEL<
369  DecorT, T, NumRows, NumCols,
372  Ptr, src.spvm, stride, sycl::detail::joint_matrix_layout_to_spv(Layout),
374 #endif // defined(__NVPTX__)
375 #else
376  std::ignore = sg;
377  std::ignore = src;
378  std::ignore = dst;
379  std::ignore = stride;
380  std::ignore = Layout;
382  "joint matrix is not supported on host.");
383 #endif // defined(__SYCL_DEVICE_ONLY__)
384 }
385 
// joint_matrix_store (accumulator, PropertyListT-parameterized destination):
// the counterpart of the multi_ptr store for a `dst` whose declaration line is
// not visible here (presumably annotated_ptr<T, PropertyListT> — confirm).
// Takes `stride` and a run-time `Layout`.
//  - SPIR / SPIR-V: takes the raw pointer with dst.get() and calls
//    __spirv_JointMatrixStoreINTEL.
//  - NVPTX / HIP MFMA: not supported; the visible messages direct users to the
//    multi_ptr overload.
//  - host: not supported.
386 template <typename Group, typename T, size_t NumRows, size_t NumCols,
387  typename PropertyListT>
389  Group sg,
390  const joint_matrix<Group, T, use::accumulator, NumRows, NumCols,
391  sycl::ext::oneapi::experimental::matrix::layout::dynamic>
392  &src,
394  size_t stride, sycl::ext::oneapi::experimental::matrix::layout Layout) {
395 #if defined(__SYCL_DEVICE_ONLY__)
396 #if defined(__NVPTX__)
397  std::ignore = sg;
399  "Use joint_matrix_store on multi_ptr on Nvidia device.");
400 #elif defined(__HIP_PLATFORM_AMD_MFMA__)
402  "Use joint_matrix_store on multi_ptr on AMD device.");
403 #else
404  std::ignore = sg;
405  T *Ptr = dst.get();
406  __spirv_JointMatrixStoreINTEL<
407  T, T, NumRows, NumCols, spv_matrix_use_traits<use::accumulator>::value,
409  Ptr, src.spvm, stride, sycl::detail::joint_matrix_layout_to_spv(Layout),
411 #endif // defined(__NVPTX__)
412 #else
413  std::ignore = sg;
414  std::ignore = src;
415  std::ignore = dst;
416  std::ignore = stride;
417  std::ignore = Layout;
419  "joint matrix is not supported on host.");
420 #endif // defined(__SYCL_DEVICE_ONLY__)
421 }
422 
// joint_matrix_mad: multiply-add of A (M x K, use::a) and B (K x N, use::b)
// with the M x N accumulator C, producing the M x N accumulator D (presumably
// D = A * B + C, per the `mad` naming).
// Device only: tagged with IR attributes recording the A/B/C/D element types
// and the M, K, N sizes.
//  - NVPTX: requires Ta == Tb; calls joint_matrix_mad_cuda.
//  - HIP MFMA: requires Ta == Tb; calls joint_matrix_mad_hip, whose D
//    parameter uses element type Tc, so Td presumably has to equal Tc there.
//  - SPIR / SPIR-V: picks the Mad / UUMad / SUMad / USMad intrinsic from the
//    signedness of Ta and Tb.
//  - host: not supported.
// NOTE(review): on NVPTX / HIP a Ta != Tb instantiation compiles and is only
// caught by a run-time assert (which NDEBUG removes, leaving D unwritten); a
// dependent-false static_assert would reject it at compile time.
423 template <typename Group, typename Ta, typename Tb, typename Tc, typename Td,
424  std::size_t M, std::size_t K, std::size_t N, layout LayoutA,
425  layout LayoutB>
426 #if defined(__SYCL_DEVICE_ONLY__)
427 [[__sycl_detail__::add_ir_attributes_function(
428  "sycl-joint-matrix-mad-type-A", "sycl-joint-matrix-mad-type-B",
429  "sycl-joint-matrix-mad-type-C", "sycl-joint-matrix-mad-type-D",
430  "sycl-joint-matrix-mad-size-M", "sycl-joint-matrix-mad-size-K",
431  "sycl-joint-matrix-mad-size-N",
432  sycl::detail::convertTypeToMatrixTypeString<Ta>(),
433  sycl::detail::convertTypeToMatrixTypeString<Tb>(),
434  sycl::detail::convertTypeToMatrixTypeString<Tc>(),
435  sycl::detail::convertTypeToMatrixTypeString<Td>(), M, K, N)]]
436 #endif // defined(__SYCL_DEVICE_ONLY__)
437 inline __SYCL_ALWAYS_INLINE void
439  Group,
440  joint_matrix<Group, Td, use::accumulator, M, N,
441  sycl::ext::oneapi::experimental::matrix::layout::dynamic> &D,
444  const joint_matrix<Group, Tc, use::accumulator, M, N,
445  sycl::ext::oneapi::experimental::matrix::layout::dynamic>
446  &C) {
447 #if defined(__SYCL_DEVICE_ONLY__)
448 #if defined(__NVPTX__)
449  if constexpr (std::is_same<Ta, Tb>::value) {
450  sycl::ext::oneapi::detail::joint_matrix_mad_cuda<Ta, Tc, Td, M, K, N,
451  LayoutA, LayoutB>(
452  D.matrix_impl, A.matrix_impl, B.matrix_impl, C.matrix_impl);
453  } else {
454  assert(false && "Ta != Tb : In the CUDA backend joint_matrix_mad "
455  "requires that joint_matrix data types Ta and Tb match");
456  }
457 #elif defined(__HIP_PLATFORM_AMD_MFMA__)
458  if constexpr (std::is_same<Ta, Tb>::value) {
459  sycl::ext::oneapi::detail::joint_matrix_mad_hip<Ta, Tc, M, K, N, LayoutA,
460  LayoutB>(
461  D.matrix_impl, A.matrix_impl, B.matrix_impl, C.matrix_impl);
462  } else {
463  assert(false && "Ta != Tb : In the HIP backend joint_matrix_mad "
464  "requires that joint_matrix data types Ta and Tb match");
465  }
466 #else
// uint16_t x uint16_t with a float accumulator is checked first so it uses the
// plain Mad instead of falling into the unsigned (UUMad) branch below
// (presumably uint16_t carries bfloat16 bits here — confirm).
467  if constexpr (std::is_same<Ta, uint16_t>::value &&
468  std::is_same<Tb, uint16_t>::value &&
469  std::is_same<Tc, float>::value)
470  D.spvm = __spirv_JointMatrixMadINTEL(A.spvm, B.spvm, C.spvm);
471  else if constexpr (std::is_unsigned<Ta>::value && std::is_unsigned<Tb>::value)
472  D.spvm = __spirv_JointMatrixUUMadINTEL(A.spvm, B.spvm, C.spvm);
473  else if constexpr (std::is_signed<Ta>::value && std::is_unsigned<Tb>::value)
474  D.spvm = __spirv_JointMatrixSUMadINTEL(A.spvm, B.spvm, C.spvm);
475  else if constexpr (std::is_unsigned<Ta>::value && std::is_signed<Tb>::value)
476  D.spvm = __spirv_JointMatrixUSMadINTEL(A.spvm, B.spvm, C.spvm);
477  else
478  D.spvm = __spirv_JointMatrixMadINTEL(A.spvm, B.spvm, C.spvm);
479 #endif // defined(__NVPTX__)
480 #else
481  std::ignore = A;
482  std::ignore = B;
483  std::ignore = C;
484  std::ignore = D;
486  "joint matrix is not supported on host.");
487 #endif // defined(__SYCL_DEVICE_ONLY__)
488 }
489 
// joint_matrix_copy: copies the elements of `src` into `dst`, converting from
// element type T1 to T2. Both matrices are Rows x Cols but may differ in use
// (Use1/Use2) and layout (Layout1/Layout2).
//  - NVPTX / HIP MFMA: assigns the whole wi_marray.
//  - SPIR / SPIR-V: copies element by element, static_cast to dst's
//    storage_element_type; the loop bound is taken from `src`.
//  - host: not supported.
490 template <typename Group, typename T1, typename T2, size_t Rows, size_t Cols,
491  use Use1, use Use2, layout Layout1, layout Layout2>
495 #if defined(__SYCL_DEVICE_ONLY__)
496 #if defined(__NVPTX__) || defined(__HIP_PLATFORM_AMD_MFMA__)
497  std::ignore = sg;
498  dst.matrix_impl.wi_marray = src.matrix_impl.wi_marray;
499 #else
500  using storage_element_type =
502  T2>::storage_element_type;
503  auto wi_data_c = sycl::ext::oneapi::detail::get_wi_data(sg, src);
504  auto wi_data_dst = sycl::ext::oneapi::detail::get_wi_data(sg, dst);
505  for (int i = 0; i < wi_data_c.length(); i++) {
506  wi_data_dst[i] = static_cast<storage_element_type>(wi_data_c[i]);
507  }
508 #endif // defined(__NVPTX__)
509 #else
510  std::ignore = sg;
511  std::ignore = dst;
512  std::ignore = src;
514  "joint matrix is not supported on host.");
515 #endif // defined(__SYCL_DEVICE_ONLY__)
516 }
517 
518 // This function rounds the bottom 13 bits up or down, and then zeros out the
519 // bottom bits
520 inline __SYCL_ALWAYS_INLINE float round_to_tf32(const float &a) {
521 #if defined(__SYCL_DEVICE_ONLY__)
522 #if defined(__NVPTX__)
523  int32_t tmp_int = __nvvm_f2tf32_rna(a);
524  return __nvvm_bitcast_i2f(tmp_int);
525 #else
526  return __spirv_RoundFToTF32INTEL(a);
527 #endif // defined(__NVPTX__)
528 #else
529  uint32_t tmp_uint = reinterpret_cast<const uint32_t &>(a);
530  tmp_uint += 0x1000u;
531  tmp_uint &= 0xFFFFE000u;
532  float ret = 0;
533  std::memcpy(&ret, &tmp_uint, sizeof(float));
534  return ret;
535 #endif // defined(__SYCL_DEVICE_ONLY__)
536 }
537 
// joint_matrix_prefetch: requests a prefetch of a NumRows x NumCols block at
// `Ptr` with `stride` and `Layout`. NumRows and NumCols do not appear in the
// parameter list, so callers must spell them out as template arguments.
//  - SPIR / SPIR-V: reads the prefetch_hint_key property from `properties`
//    and calls __spirv_CooperativeMatrixPrefetchINTEL.
//  - NVPTX / HIP MFMA: not supported (see the messages below; the diagnosing
//    statements are only partly visible here).
//  - host: not supported.
538 template <size_t NumRows, size_t NumCols, typename Group, typename T,
539  typename Properties = ext::oneapi::experimental::empty_properties_t>
540 inline __SYCL_ALWAYS_INLINE void
541 joint_matrix_prefetch(Group sg, T *Ptr, size_t stride,
543  Properties properties = {}) {
544 #if defined(__SYCL_DEVICE_ONLY__)
545 #if defined(__NVPTX__)
546  std::ignore = sg;
547  std::ignore = properties;
549  "joint_matrix_prefetch is not supported on Nvidia device.");
550 #elif defined(__HIP_PLATFORM_AMD_MFMA__)
551  std::ignore = sg;
552  std::ignore = properties;
554  "joint_matrix_prefetch is not supported on AMD device.");
555 #else
556  std::ignore = sg;
557  auto prop = properties.template get_property<prefetch_hint_key>();
558  __spirv_CooperativeMatrixPrefetchINTEL<T>(
559  Ptr, NumRows, NumCols, detail::PropertyMetaInfo<decltype(prop)>::value,
561 #endif // defined(__NVPTX__)
562 #else
563  std::ignore = sg;
564  std::ignore = Ptr;
565  std::ignore = stride;
566  std::ignore = Layout;
567  std::ignore = properties;
569  "joint matrix is not supported on host.");
570 #endif // defined(__SYCL_DEVICE_ONLY__)
571 }
572 
573 } // namespace matrix
574 } // namespace experimental
575 } // namespace oneapi
576 } // namespace ext
577 } // namespace _V1
578 } // namespace sycl
#define __SYCL_ALWAYS_INLINE
__SYCL_ALWAYS_INLINE __spv::MatrixLayout joint_matrix_layout_to_spv(sycl::ext::oneapi::experimental::matrix::layout Layout)
constexpr const char * convertMatrixUseEnumToString(ext::oneapi::experimental::matrix::use Use)
void joint_matrix_store_hip(const joint_matrix_hip< T, sycl::ext::oneapi::experimental::matrix::use::accumulator, M, N, sycl::ext::oneapi::experimental::matrix::layout::dynamic > &src, multi_ptr< T, Space, IsDecorated > dst, size_t stride, sycl::ext::oneapi::experimental::matrix::layout layout, Group &sg)
Definition: matrix-hip.hpp:317
void joint_matrix_mad_hip(joint_matrix_hip< Tc, sycl::ext::oneapi::experimental::matrix::use::accumulator, M, N, sycl::ext::oneapi::experimental::matrix::layout::dynamic > &D, const joint_matrix_hip< Tm, sycl::ext::oneapi::experimental::matrix::use::a, M, K, LayoutA > &A, const joint_matrix_hip< Tm, sycl::ext::oneapi::experimental::matrix::use::b, K, N, LayoutB > &B, const joint_matrix_hip< Tc, sycl::ext::oneapi::experimental::matrix::use::accumulator, M, N, sycl::ext::oneapi::experimental::matrix::layout::dynamic > &C)
Definition: matrix-hip.hpp:337
void load_accumulator_hip(joint_matrix_hip< S, sycl::ext::oneapi::experimental::matrix::use::accumulator, M, N, sycl::ext::oneapi::experimental::matrix::layout::dynamic > &res, multi_ptr< T, Space, IsDecorated > src, size_t stride, sycl::ext::oneapi::experimental::matrix::layout layout, Group &sg)
Definition: matrix-hip.hpp:185
void load_multiplicand_hip(joint_matrix_hip< S, Use, M, N, Layout > &res, multi_ptr< T, Space, IsDecorated > src, size_t stride, Group &sg)
Definition: matrix-hip.hpp:211
decltype(auto) __SYCL_ALWAYS_INLINE get_wi_data(Group sg, sycl::ext::oneapi::experimental::matrix::joint_matrix< Group, T, Use, Rows, Cols, Layout > &jm)
__SYCL_ALWAYS_INLINE float round_to_tf32(const float &a)
void joint_matrix_copy(Group sg, joint_matrix< Group, T1, Use1, Rows, Cols, Layout1 > &src, joint_matrix< Group, T2, Use2, Rows, Cols, Layout2 > &dst)
__SYCL_ALWAYS_INLINE void joint_matrix_load(Group sg, joint_matrix< Group, S, use::accumulator, NumRows, NumCols, sycl::ext::oneapi::experimental::matrix::layout::dynamic > &res, multi_ptr< T, Space, IsDecorated > src, size_t stride, sycl::ext::oneapi::experimental::matrix::layout Layout)
__SYCL_ALWAYS_INLINE void joint_matrix_fill(Group, joint_matrix< Group, T, Use, NumRows, NumCols, Layout > &res, const T2 &v)
__SYCL_ALWAYS_INLINE void joint_matrix_store(Group sg, const joint_matrix< Group, T, use::accumulator, NumRows, NumCols, sycl::ext::oneapi::experimental::matrix::layout::dynamic > &src, multi_ptr< T, Space, IsDecorated > dst, size_t stride, sycl::ext::oneapi::experimental::matrix::layout Layout)
__SYCL_ALWAYS_INLINE void joint_matrix_mad(Group, joint_matrix< Group, Td, use::accumulator, M, N, sycl::ext::oneapi::experimental::matrix::layout::dynamic > &D, const joint_matrix< Group, Ta, use::a, M, K, LayoutA > &A, const joint_matrix< Group, Tb, use::b, K, N, LayoutB > &B, const joint_matrix< Group, Tc, use::accumulator, M, N, sycl::ext::oneapi::experimental::matrix::layout::dynamic > &C)
__SYCL_ALWAYS_INLINE void joint_matrix_prefetch(Group sg, T *Ptr, size_t stride, sycl::ext::oneapi::experimental::matrix::layout Layout, Properties properties={})
__SYCL_ALWAYS_INLINE void joint_matrix_apply(Group sg, joint_matrix< Group, T, Use, M, N, Layout > &jm, F &&lambda)
annotated_arg & operator=(annotated_arg &)=default
std::error_code make_error_code(sycl::errc E) noexcept
Constructs an error code using e and sycl_category()
Definition: exception.cpp:65
Definition: access.hpp:18