DPC++ Runtime
Runtime libraries for oneAPI DPC++
matrix-intel.hpp
Go to the documentation of this file.
1 //==------------------ matrix-intel.hpp - SYCL matrix ----------*- C++ -*---==//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 // ===--------------------------------------------------------------------=== //
8 
9 #pragma once
10 
11 #include "matrix-unified-utils.hpp" // for use, layout, tf32, matrix
12 #include "utils.hpp" // for getDecorated
13 
14 #include <CL/__spirv/spirv_types.hpp> // for MatrixLayout, MatrixUse
15 #include <sycl/access/access.hpp> // for address_space, decorated
16 #include <sycl/builtins.hpp> // for fabs
17 #include <sycl/detail/defines_elementary.hpp> // for __SYCL_ALWAYS_INLINE
18 #include <sycl/detail/pi.h> // for PI_ERROR_INVALID_DEVICE
19 #include <sycl/exception.hpp> // for runtime_error
20 #include <sycl/ext/oneapi/bfloat16.hpp> // for bfloat16
22 #include <sycl/group.hpp> // for group
23 #include <sycl/multi_ptr.hpp> // for multi_ptr
24 #include <sycl/sub_group.hpp> // for sub_group
25 
26 #include <cstddef> // for size_t
27 #include <stdint.h> // for uint32_t
28 #include <tuple> // for ignore, tuple, _Swallo...
29 #include <type_traits> // for enable_if_t
30 
31 namespace sycl {
32 inline namespace _V1 {
33 namespace ext {
34 namespace oneapi {
35 namespace experimental {
36 namespace matrix {
37 
38 template <layout Layout> struct spv_matrix_layout_traits {
40 };
41 
// Emits the explicit specialization of spv_matrix_layout_traits that maps the
// SYCL-level layout LAYOUT to the SPIR-V matrix layout SPV_LAYOUT.
#define SPV_MATRIX_LAYOUT_TRAITS(LAYOUT, SPV_LAYOUT)                            \
  template <> struct spv_matrix_layout_traits<LAYOUT> {                        \
    constexpr static __spv::MatrixLayout value = SPV_LAYOUT;                   \
  };
46 
51 
52 template <use Use> struct spv_matrix_use_traits {
53  static constexpr __spv::MatrixUse value = __spv::MatrixUse::MatrixA;
54 };
55 
// Emits the explicit specialization of spv_matrix_use_traits that maps the
// SYCL-level use USE to the SPIR-V matrix use SPV_USE.
#define SPV_MATRIX_USE_TRAITS(USE, SPV_USE)                                     \
  template <> struct spv_matrix_use_traits<USE> {                              \
    constexpr static __spv::MatrixUse value = SPV_USE;                         \
  };
60 
64 
65 template <typename G> struct spv_scope_traits {};
66 template <> struct spv_scope_traits<sycl::sub_group> {
67  constexpr static auto value = __spv::Scope::Subgroup;
68 };
69 template <int D> struct spv_scope_traits<sycl::group<D>> {
70  constexpr static auto value = __spv::Scope::Workgroup;
71 };
72 
73 // forward declarations
74 template <typename Group, typename T, use Use, size_t Rows, size_t Cols,
75  layout Layout>
76 struct joint_matrix;
77 
78 } // namespace matrix
79 } // namespace experimental
80 
81 namespace detail {
82 // Differentiating between the "element type" and the "storage element type"
83 template <typename T> struct jm_type_interpretation_helper_trait {
84  using element_type = T;
86 };
87 
88 template <>
92  using storage_element_type = float;
93 };
94 
95 using namespace sycl::ext::oneapi::experimental::matrix;
96 // Begin wi_element definition
97 
98 template <typename T, size_t NumRows, size_t NumCols,
101  sycl::ext::oneapi::experimental::matrix::layout::dynamic,
102  typename Group = sycl::sub_group>
103 class wi_element {
105  NumCols, Layout> &M;
106  std::size_t idx;
107 
108 public:
113  Group, T, Use, NumRows, NumCols, Layout> &Mat,
114  std::size_t i)
115  : M(Mat), idx(i) {}
116 
117  inline __SYCL_ALWAYS_INLINE std::tuple<size_t, size_t> get_coord() {
118 #if defined(__SYCL_DEVICE_ONLY__)
119  __ocl_vec_t<uint32_t, 2> coord =
120  __spirv_JointMatrixGetElementCoordINTEL(M.spvm, idx);
121  const size_t row = coord[0];
122  const size_t col = coord[1];
123  return std::make_tuple(row, col);
124 #else
125  throw runtime_error("joint matrix is not supported on host device.",
126  PI_ERROR_INVALID_DEVICE);
127 #endif // __SYCL_DEVICE_ONLY__
128  }
129 
130  operator storage_element_type() {
131 #ifdef __SYCL_DEVICE_ONLY__
133  __spirv_VectorExtractDynamic<storage_element_type, T, NumRows, NumCols,
134  spv_matrix_use_traits<Use>::value,
135  spv_matrix_layout_traits<Layout>::value,
136  spv_scope_traits<Group>::value>(M.spvm,
137  idx);
138  return elem;
139 #else
140  throw runtime_error("joint matrix is not supported on host device.",
141  PI_ERROR_INVALID_DEVICE);
142 #endif // __SYCL_DEVICE_ONLY__
143  }
144 
145  explicit operator bool() {
146 #ifdef __SYCL_DEVICE_ONLY__
147  return __spirv_VectorExtractDynamic<storage_element_type, T, NumRows,
148  NumCols,
149  spv_matrix_use_traits<Use>::value,
150  spv_matrix_layout_traits<Layout>::value,
151  spv_scope_traits<Group>::value>(
152  M.spvm, idx) != static_cast<storage_element_type>(0);
153 #else
154  throw runtime_error("joint matrix is not supported on host device.",
155  PI_ERROR_INVALID_DEVICE);
156 #endif // __SYCL_DEVICE_ONLY__
157  }
158 
159  template <typename T2> wi_element &operator=(const T2 &rhs) {
160 #ifdef __SYCL_DEVICE_ONLY__
161  M.spvm = __spirv_VectorInsertDynamic(
162  M.spvm, static_cast<storage_element_type>(rhs), idx);
163  return *this;
164 #else
165  (void)rhs;
166  throw runtime_error("joint matrix is not supported on host device.",
167  PI_ERROR_INVALID_DEVICE);
168 #endif // __SYCL_DEVICE_ONLY__
169  }
170 
171  wi_element &
173 #ifdef __SYCL_DEVICE_ONLY__
174  M.spvm = __spirv_VectorInsertDynamic(
175  M.spvm,
176  __spirv_VectorExtractDynamic<storage_element_type, T, NumRows, NumCols,
177  spv_matrix_use_traits<Use>::value,
178  spv_matrix_layout_traits<Layout>::value,
179  spv_scope_traits<Group>::value>(rhs.M.spvm,
180  rhs.idx),
181  idx);
182  return *this;
183 #else
184  (void)rhs;
185  throw runtime_error("joint matrix is not supported on host device.",
186  PI_ERROR_INVALID_DEVICE);
187 #endif // __SYCL_DEVICE_ONLY__
188  }
189 
190 #if __SYCL_DEVICE_ONLY__
191 #define OP(op) \
192  template <typename T2> wi_element &operator op##=(const T2 & rhs) { \
193  M.spvm = __spirv_VectorInsertDynamic( \
194  M.spvm, \
195  static_cast<storage_element_type>( \
196  __spirv_VectorExtractDynamic< \
197  storage_element_type, T, NumRows, NumCols, \
198  spv_matrix_use_traits<Use>::value, \
199  spv_matrix_layout_traits<Layout>::value, \
200  spv_scope_traits<Group>::value>(M.spvm, idx) \
201  op static_cast<storage_element_type>(rhs)), \
202  idx); \
203  return *this; \
204  }
205 #else // __SYCL_DEVICE_ONLY__
206 #define OP(op) \
207  template <typename T2> wi_element &operator op##=(const T2 & rhs) { \
208  (void)rhs; \
209  throw runtime_error("joint matrix is not supported on host device.", \
210  PI_ERROR_INVALID_DEVICE); \
211  }
212 #endif // __SYCL_DEVICE_ONLY__
213  OP(+)
214  OP(-)
215  OP(*)
216  OP(/)
217 #undef OP
218 };
219 
220 template <size_t NumRows, size_t NumCols,
223  typename Group>
224 class wi_element<sycl::ext::oneapi::bfloat16, NumRows, NumCols, Use, Layout,
225  Group> {
227  Group, sycl::ext::oneapi::bfloat16, Use, NumRows, NumCols, Layout> &M;
228  std::size_t idx;
229 
230 public:
232  Group, sycl::ext::oneapi::bfloat16, Use, NumRows, NumCols,
233  Layout> &Mat,
234  std::size_t i)
235  : M(Mat), idx(i) {}
236 
237  inline __SYCL_ALWAYS_INLINE std::tuple<uint32_t, uint32_t> get_coord() {
238 #if defined(__SYCL_DEVICE_ONLY__)
239  __ocl_vec_t<uint32_t, 2> coord =
240  __spirv_JointMatrixGetElementCoordINTEL(M.spvm, idx);
241  const uint32_t row = coord[0];
242  const uint32_t col = coord[1];
243  return std::make_tuple(row, col);
244 #else
245  throw runtime_error("joint matrix is not supported on host device.",
246  PI_ERROR_INVALID_DEVICE);
247 #endif // __SYCL_DEVICE_ONLY__
248  }
249 
251 #ifdef __SYCL_DEVICE_ONLY__
252  return __spirv_VectorExtractDynamic<
254  NumCols, spv_matrix_use_traits<Use>::value,
255  spv_matrix_layout_traits<Layout>::value,
256  spv_scope_traits<Group>::value>(M.spvm, idx);
257 #else
258  throw runtime_error("joint matrix is not supported on host device.",
259  PI_ERROR_INVALID_DEVICE);
260 #endif // __SYCL_DEVICE_ONLY__
261  }
262 
263  explicit operator bool() {
264 #ifdef __SYCL_DEVICE_ONLY__
265  return sycl::fabs(static_cast<float>(
266  __spirv_VectorExtractDynamic<
268  NumRows, NumCols, spv_matrix_use_traits<Use>::value,
269  spv_matrix_layout_traits<Layout>::value,
270  spv_scope_traits<Group>::value>(M.spvm, idx))) >=
271  std::numeric_limits<float>::epsilon();
272 #else
273  throw runtime_error("joint matrix is not supported on host device.",
274  PI_ERROR_INVALID_DEVICE);
275 #endif // __SYCL_DEVICE_ONLY__
276  }
277 
279 #ifdef __SYCL_DEVICE_ONLY__
280  M.spvm = __spirv_VectorInsertDynamic(M.spvm, rhs, idx);
281  return *this;
282 #else
283  (void)rhs;
284  throw runtime_error("joint matrix is not supported on host device.",
285  PI_ERROR_INVALID_DEVICE);
286 #endif // __SYCL_DEVICE_ONLY__
287  }
288 
290  NumCols, Use, Layout, Group> &rhs) {
291 #ifdef __SYCL_DEVICE_ONLY__
292  M.spvm = __spirv_VectorInsertDynamic(
293  M.spvm,
294  __spirv_VectorExtractDynamic<sycl::ext::oneapi::bfloat16,
296  NumCols, spv_matrix_use_traits<Use>::value,
297  spv_matrix_layout_traits<Layout>::value,
298  spv_scope_traits<Group>::value>(rhs.M.spvm,
299  rhs.idx),
300  idx);
301  return *this;
302 #else
303  (void)rhs;
304  throw runtime_error("joint matrix is not supported on host device.",
305  PI_ERROR_INVALID_DEVICE);
306 #endif // __SYCL_DEVICE_ONLY__
307  }
308 
309 #if __SYCL_DEVICE_ONLY__
310 #define OP(opassign, op) \
311  wi_element &operator opassign(const sycl::ext::oneapi::bfloat16 & rhs) { \
312  M.spvm = __spirv_VectorInsertDynamic( \
313  M.spvm, \
314  __spirv_VectorExtractDynamic< \
315  sycl::ext::oneapi::bfloat16, sycl::ext::oneapi::bfloat16, NumRows, \
316  NumCols, spv_matrix_use_traits<Use>::value, \
317  spv_matrix_layout_traits<Layout>::value, \
318  spv_scope_traits<Group>::value>(M.spvm, idx) op rhs, \
319  idx); \
320  return *this; \
321  }
322 #else // __SYCL_DEVICE_ONLY__
323 #define OP(opassign, op) \
324  wi_element &operator opassign(const sycl::ext::oneapi::bfloat16 & rhs) { \
325  (void)rhs; \
326  throw runtime_error("joint matrix is not supported on host device.", \
327  PI_ERROR_INVALID_DEVICE); \
328  }
329 #endif // __SYCL_DEVICE_ONLY__
330  OP(+=, +)
331  OP(-=, -)
332  OP(*=, *)
333  OP(/=, /)
334 #undef OP
335 
336 #if __SYCL_DEVICE_ONLY__
337 #define OP(type, op) \
338  friend type operator op( \
339  const wi_element<sycl::ext::oneapi::bfloat16, NumRows, NumCols, Use, \
340  Layout, Group> &lhs, \
341  const sycl::ext::oneapi::bfloat16 &rhs) { \
342  return __spirv_VectorExtractDynamic< \
343  sycl::ext::oneapi::bfloat16, sycl::ext::oneapi::bfloat16, NumRows, \
344  NumCols, spv_matrix_use_traits<Use>::value, \
345  spv_matrix_layout_traits<Layout>::value, \
346  spv_scope_traits<Group>::value>(lhs.M.spvm, lhs.idx) op rhs; \
347  } \
348  friend type operator op( \
349  const sycl::ext::oneapi::bfloat16 &lhs, \
350  const wi_element<sycl::ext::oneapi::bfloat16, NumRows, NumCols, Use, \
351  Layout, Group> &rhs) { \
352  return __spirv_VectorExtractDynamic< \
353  sycl::ext::oneapi::bfloat16, sycl::ext::oneapi::bfloat16, NumRows, \
354  NumCols, spv_matrix_use_traits<Use>::value, \
355  spv_matrix_layout_traits<Layout>::value, \
356  spv_scope_traits<Group>::value>(rhs.M.spvm, rhs.idx) op lhs; \
357  }
359  OP(sycl::ext::oneapi::bfloat16, -)
360  OP(sycl::ext::oneapi::bfloat16, *)
361  OP(sycl::ext::oneapi::bfloat16, /)
362 #undef OP
363 #define OP(type, op) \
364  friend type operator op( \
365  const wi_element<sycl::ext::oneapi::bfloat16, NumRows, NumCols, Use, \
366  Layout, Group> &lhs, \
367  const sycl::ext::oneapi::bfloat16 &rhs) { \
368  return type{static_cast<float>( \
369  __spirv_VectorExtractDynamic< \
370  sycl::ext::oneapi::bfloat16, sycl::ext::oneapi::bfloat16, NumRows, \
371  NumCols, spv_matrix_use_traits<Use>::value, \
372  spv_matrix_layout_traits<Layout>::value, \
373  spv_scope_traits<Group>::value>(lhs.M.spvm, lhs.idx)) \
374  op static_cast<float>(rhs)}; \
375  } \
376  friend type operator op( \
377  const sycl::ext::oneapi::bfloat16 &lhs, \
378  const wi_element<sycl::ext::oneapi::bfloat16, NumRows, NumCols, Use, \
379  Layout, Group> &rhs) { \
380  return type{static_cast<float>( \
381  __spirv_VectorExtractDynamic< \
382  sycl::ext::oneapi::bfloat16, sycl::ext::oneapi::bfloat16, NumRows, \
383  NumCols, spv_matrix_use_traits<Use>::value, \
384  spv_matrix_layout_traits<Layout>::value, \
385  spv_scope_traits<Group>::value>(rhs.M.spvm, rhs.idx)) \
386  op static_cast<float>(lhs)}; \
387  }
388  OP(bool, ==)
389  OP(bool, !=)
390  OP(bool, <)
391  OP(bool, >)
392  OP(bool, <=)
393  OP(bool, >=)
394 #undef OP
395 #else // __SYCL_DEVICE_ONLY__
396 #define OP(type, op) \
397  friend type operator op( \
398  const wi_element<sycl::ext::oneapi::bfloat16, NumRows, NumCols, Use, \
399  Layout, Group> &, \
400  const sycl::ext::oneapi::bfloat16 &) { \
401  throw runtime_error("joint matrix is not supported on host device.", \
402  PI_ERROR_INVALID_DEVICE); \
403  } \
404  friend type operator op( \
405  const sycl::ext::oneapi::bfloat16 &, \
406  const wi_element<sycl::ext::oneapi::bfloat16, NumRows, NumCols, Use, \
407  Layout, Group> &) { \
408  throw runtime_error("joint matrix is not supported on host device.", \
409  PI_ERROR_INVALID_DEVICE); \
410  }
412  OP(sycl::ext::oneapi::bfloat16, -)
413  OP(sycl::ext::oneapi::bfloat16, *)
414  OP(sycl::ext::oneapi::bfloat16, /)
415  OP(bool, ==)
416  OP(bool, !=)
417  OP(bool, <)
418  OP(bool, >)
419  OP(bool, <=)
420  OP(bool, >=)
421 #undef OP
422 #endif // __SYCL_DEVICE_ONLY__
423 };
424 
425 // End wi_element definition
426 
427 // Begin wi_data definition
428 
429 template <typename Group, typename T,
432 class wi_data {
433 
435  Cols, Layout> &jm;
436 
438  Group, T, Use, Rows, Cols, Layout> &_jm)
439  : jm(_jm){};
440 
441  template <typename Grp, typename Type,
442  sycl::ext::oneapi::experimental::matrix::use UseJm, size_t NumRows,
443  size_t NumCols,
445  friend decltype(auto)
447  Grp, Type, UseJm, NumRows, NumCols, LayoutJm> &);
448 
449 public:
450  size_t length() {
451 #if __SYCL_DEVICE_ONLY__
452  return __spirv_JointMatrixWorkItemLengthINTEL(jm.spvm);
453 #else
454  throw runtime_error("joint matrix is not supported on host device.",
455  PI_ERROR_INVALID_DEVICE);
456 #endif
457  };
458 
459  decltype(auto) operator[](size_t i) {
461  };
462 };
463 
464 template <typename Group, typename T,
467 inline __SYCL_ALWAYS_INLINE decltype(auto)
468 get_wi_data(Group sg, sycl::ext::oneapi::experimental::matrix::joint_matrix<
469  Group, T, Use, Rows, Cols, Layout> &jm) {
470  std::ignore = sg;
471  return wi_data(jm);
472 }
473 
474 // End wi_data definition
475 } // namespace detail
476 } // namespace oneapi
477 
478 namespace intel::experimental::matrix {
479 template <
480  typename Group, typename T, typename Tp,
483  access::address_space Space, access::decorated IsDecorated,
484  std::enable_if_t<Use == sycl::ext::oneapi::experimental::matrix::use::a ||
486  bool> = true>
487 inline __SYCL_ALWAYS_INLINE void
490  Group, Tp, Use, NumRows, NumCols, Layout> &src,
491  multi_ptr<T, Space, IsDecorated> dst, size_t stride) {
492 #if defined(__SYCL_DEVICE_ONLY__)
493  static_assert(Space != access::address_space::private_space,
494  "Joint Matrix doesn't support store to private memory!");
495 #if defined(__NVPTX__)
496  std::ignore = src;
497  std::ignore = dst;
498  std::ignore = stride;
499  throw runtime_error(
500  "This version of the matrix extension is only currently supported on "
501  "intel devices",
502  PI_ERROR_INVALID_DEVICE);
503 #else
504  // intel's impl
505  using DecorT = typename sycl::detail::DecoratedType<T, Space>::type;
506  DecorT *Ptr = sycl::detail::getDecorated<DecorT>(dst);
507  __spirv_JointMatrixStoreINTEL<DecorT, Tp, NumRows, NumCols,
512  Ptr, src.spvm, stride,
514  Layout>::value,
516 #endif // defined(__NVPTX__)
517 #else
518  std::ignore = src;
519  std::ignore = dst;
520  std::ignore = stride;
521  throw runtime_error("joint matrix is not supported on host device.",
522  PI_ERROR_INVALID_DEVICE);
523 #endif // defined(__SYCL_DEVICE_ONLY__)
524 }
525 
526 template <
527  typename Group, typename T, typename Tp,
530  typename PropertyListT,
531  std::enable_if_t<Use == sycl::ext::oneapi::experimental::matrix::use::a ||
533  bool> = true>
535  Group,
537  Group, Tp, Use, NumRows, NumCols, Layout> &src,
539  size_t stride) {
540 #if defined(__SYCL_DEVICE_ONLY__)
541 #if defined(__NVPTX__)
542  std::ignore = src;
543  std::ignore = dst;
544  std::ignore = stride;
545  throw runtime_error(
546  "This version of the matrix extension is only currently supported on "
547  "intel devices",
548  PI_ERROR_INVALID_DEVICE);
549 #else
550  // intel's impl
551  T *Ptr = dst.get();
552  __spirv_JointMatrixStoreINTEL<T, Tp, NumRows, NumCols,
557  Ptr, src.spvm, stride,
559  Layout>::value,
561 #endif // defined(__NVPTX__)
562 #else
563  std::ignore = src;
564  std::ignore = dst;
565  std::ignore = stride;
566  throw runtime_error("joint matrix is not supported on host device.",
567  PI_ERROR_INVALID_DEVICE);
568 #endif // defined(__SYCL_DEVICE_ONLY__)
569 }
570 
571 template <typename Group, typename T,
574  typename F>
576  Group sg,
578  Cols, Layout> &jm,
579  F &&lambda) {
580 #if defined(__SYCL_DEVICE_ONLY__)
581 #if defined(__NVPTX__)
582  std::ignore = sg;
583  for (int i = 0; i < jm.matrix_impl.wi_marray.size(); i++) {
584  lambda(jm.matrix_impl.wi_marray[i]);
585  }
586 #else // NVPTX
587  using storage_element_type =
589  T>::storage_element_type;
590  auto wi_data_c = sycl::ext::oneapi::detail::get_wi_data(sg, jm);
591  for (int i = 0; i < wi_data_c.length(); i++) {
592  storage_element_type element = wi_data_c[i];
593  auto [row, col] = wi_data_c[i].get_coord();
594  lambda(element, row, col);
595  wi_data_c[i] = element;
596  }
597 #endif
598 #else
599  std::ignore = sg;
600  std::ignore = jm;
601  std::ignore = lambda;
602  throw runtime_error("joint matrix is not supported on host device.",
603  PI_ERROR_INVALID_DEVICE);
604 #endif
605 }
606 
607 using namespace sycl::ext::oneapi::experimental::matrix;
608 
609 // Begin out-of-bounds API
610 
611 template <typename Group, typename T, size_t NumRows, size_t NumCols, use Use,
612  layout Layout, typename T2>
614  Group, joint_matrix<Group, T, Use, NumRows, NumCols, Layout> &Res,
615  const T2 &Value, size_t Height, size_t Width, size_t CoordX,
616  size_t CoordY) {
617 #if defined(__SYCL_DEVICE_ONLY__)
618  using storage_element_type =
620  T>::storage_element_type;
621  Res.spvm = __spirv_CooperativeMatrixConstructCheckedINTEL<
622  storage_element_type, T, NumRows, NumCols,
623  spv_matrix_use_traits<Use>::value,
624  spv_matrix_layout_traits<Layout>::value>(
625  CoordX, CoordY, Height, Width, static_cast<storage_element_type>(Value));
626 #else
627  std::ignore = Res;
628  std::ignore = Value;
629  std::ignore = Height;
630  std::ignore = Width;
631  std::ignore = CoordX;
632  std::ignore = CoordY;
633  throw runtime_error("joint matrix is not supported on host device.",
634  PI_ERROR_INVALID_DEVICE);
635 #endif // defined(__SYCL_DEVICE_ONLY__)
636 }
637 
638 template <
639  typename Group, typename S, typename T, size_t NumRows, size_t NumCols,
640  access::address_space Space, access::decorated IsDecorated,
641  std::enable_if_t<std::is_same<S, std::remove_const_t<T>>::value, bool> =
642  true>
644  Group sg,
645  joint_matrix<Group, S, use::accumulator, NumRows, NumCols, layout::dynamic>
646  &Res,
647  multi_ptr<T, Space, IsDecorated> Src, size_t Stride, layout Layout,
648  size_t Height, size_t Width, size_t CoordX, size_t CoordY) {
649 #if defined(__SYCL_DEVICE_ONLY__)
650  static_assert(Space != access::address_space::private_space,
651  "Joint Matrix doesn't support load from private memory!");
652  std::ignore = sg;
653  using DecorT = typename sycl::detail::DecoratedType<T, Space>::type;
654  DecorT *Ptr = sycl::detail::getDecorated<DecorT>(Src);
655  Res.spvm = __spirv_CooperativeMatrixLoadCheckedINTEL<
656  DecorT, S, NumRows, NumCols,
657  spv_matrix_use_traits<use::accumulator>::value,
658  spv_matrix_layout_traits<layout::dynamic>::value>(
659  Ptr, CoordX, CoordY, sycl::detail::joint_matrix_layout_to_spv(Layout),
660  Height, Width, Stride);
661 #else
662  std::ignore = sg;
663  std::ignore = Res;
664  std::ignore = Src;
665  std::ignore = Stride;
666  std::ignore = Height;
667  std::ignore = Width;
668  std::ignore = Layout;
669  std::ignore = CoordX;
670  std::ignore = CoordY;
671  throw runtime_error("joint matrix is not supported on host device.",
672  PI_ERROR_INVALID_DEVICE);
673 #endif // defined(__SYCL_DEVICE_ONLY__)
674 }
675 
676 template <
677  typename Group, typename S, typename T, use Use, size_t NumRows,
678  size_t NumCols, layout Layout, access::address_space Space,
679  access::decorated IsDecorated,
680  std::enable_if_t<std::is_same<S, std::remove_const_t<T>>::value ||
681  (std::is_same<S, precision::tf32>::value &&
682  std::is_same<std::remove_const_t<T>, float>::value),
683  bool> = true>
685  Group sg, joint_matrix<Group, S, Use, NumRows, NumCols, Layout> &Res,
686  multi_ptr<T, Space, IsDecorated> Src, size_t Stride, size_t Height,
687  size_t Width, size_t CoordX, size_t CoordY) {
688 #if defined(__SYCL_DEVICE_ONLY__)
689  static_assert(Space != access::address_space::private_space,
690  "Joint Matrix doesn't support load from private memory!");
691  std::ignore = sg;
692  using DecorT = typename sycl::detail::DecoratedType<T, Space>::type;
693  DecorT *Ptr = sycl::detail::getDecorated<DecorT>(Src);
694  Res.spvm = __spirv_CooperativeMatrixLoadCheckedINTEL<
695  DecorT, S, NumRows, NumCols, spv_matrix_use_traits<Use>::value,
696  spv_matrix_layout_traits<Layout>::value>(
697  Ptr, CoordX, CoordY, spv_matrix_layout_traits<Layout>::value, Height,
698  Width, Stride);
699 #else
700  std::ignore = sg;
701  std::ignore = Res;
702  std::ignore = Src;
703  std::ignore = Stride;
704  std::ignore = Height;
705  std::ignore = Width;
706  std::ignore = CoordX;
707  std::ignore = CoordY;
708  throw runtime_error("joint matrix is not supported on host device.",
709  PI_ERROR_INVALID_DEVICE);
710 #endif // defined(__SYCL_DEVICE_ONLY__)
711 }
712 
713 template <typename Group, typename T, size_t NumRows, size_t NumCols,
714  access::address_space Space, access::decorated IsDecorated>
716  Group sg,
717  joint_matrix<Group, T, use::accumulator, NumRows, NumCols, layout::dynamic>
718  &Src,
719  multi_ptr<T, Space, IsDecorated> Dst, size_t Stride, layout Layout,
720  size_t Height, size_t Width, size_t CoordX, size_t CoordY) {
721 #if defined(__SYCL_DEVICE_ONLY__)
722  static_assert(Space != access::address_space::private_space,
723  "Joint Matrix doesn't support store to private memory!");
724  std::ignore = sg;
725  using DecorT = typename sycl::detail::DecoratedType<T, Space>::type;
726  DecorT *Ptr = sycl::detail::getDecorated<DecorT>(Dst);
727  __spirv_CooperativeMatrixStoreCheckedINTEL<
728  DecorT, T, NumRows, NumCols,
729  spv_matrix_use_traits<use::accumulator>::value,
730  spv_matrix_layout_traits<layout::dynamic>::value>(
731  Ptr, CoordX, CoordY, Src.spvm,
732  sycl::detail::joint_matrix_layout_to_spv(Layout), Height, Width, Stride);
733 #else
734  std::ignore = sg;
735  std::ignore = Src;
736  std::ignore = Dst;
737  std::ignore = Stride;
738  std::ignore = Height;
739  std::ignore = Width;
740  std::ignore = Layout;
741  std::ignore = CoordX;
742  std::ignore = CoordY;
743  throw runtime_error("joint matrix is not supported on host device.",
744  PI_ERROR_INVALID_DEVICE);
745 #endif // defined(__SYCL_DEVICE_ONLY__)
746 }
747 
748 template <typename Group, typename T, typename Tp, use Use, size_t NumRows,
749  size_t NumCols, layout Layout, access::address_space Space,
750  access::decorated IsDecorated,
751  std::enable_if_t<Use == use::a || Use == use::b, bool> = true>
753  Group sg, const joint_matrix<Group, Tp, Use, NumRows, NumCols, Layout> &Src,
754  multi_ptr<T, Space, IsDecorated> Dst, size_t Stride, size_t Height,
755  size_t Width, size_t CoordX, size_t CoordY) {
756 #if defined(__SYCL_DEVICE_ONLY__)
757  static_assert(Space != access::address_space::private_space,
758  "Joint Matrix doesn't support store to private memory!");
759  std::ignore = sg;
760  using DecorT = typename sycl::detail::DecoratedType<T, Space>::type;
761  DecorT *Ptr = sycl::detail::getDecorated<DecorT>(Dst);
762  __spirv_CooperativeMatrixStoreCheckedINTEL<
763  DecorT, Tp, NumRows, NumCols, spv_matrix_use_traits<Use>::value,
764  spv_matrix_layout_traits<Layout>::value>(
765  Ptr, CoordX, CoordY, Src.spvm, spv_matrix_layout_traits<Layout>::value,
766  Height, Width, Stride);
767 #else
768  std::ignore = sg;
769  std::ignore = Src;
770  std::ignore = Dst;
771  std::ignore = Stride;
772  std::ignore = Height;
773  std::ignore = Width;
774  std::ignore = CoordX;
775  std::ignore = CoordY;
776  throw runtime_error("joint matrix is not supported on host device.",
777  PI_ERROR_INVALID_DEVICE);
778 #endif // defined(__SYCL_DEVICE_ONLY__)
779 }
780 
781 // Annotated pointer overloads:
782 template <typename Group, typename S, typename T, size_t NumRows,
783  size_t NumCols, typename PropertyListT,
784  std::enable_if_t<std::is_same<S, std::remove_const_t<T>>::value,
785  bool> = true>
787  Group sg,
788  joint_matrix<Group, S, use::accumulator, NumRows, NumCols, layout::dynamic>
789  &Res,
791  size_t Stride, layout Layout, size_t Height, size_t Width, size_t CoordX,
792  size_t CoordY) {
793 #if defined(__SYCL_DEVICE_ONLY__)
794  std::ignore = sg;
795  T *Ptr = Src.get();
796  Res.spvm = __spirv_CooperativeMatrixLoadCheckedINTEL<
797  T, S, NumRows, NumCols, spv_matrix_use_traits<use::accumulator>::value,
798  spv_matrix_layout_traits<layout::dynamic>::value>(
799  Ptr, CoordX, CoordY, sycl::detail::joint_matrix_layout_to_spv(Layout),
800  Height, Width, Stride);
801 #else
802  std::ignore = sg;
803  std::ignore = Res;
804  std::ignore = Src;
805  std::ignore = Stride;
806  std::ignore = Height;
807  std::ignore = Width;
808  std::ignore = Layout;
809  std::ignore = CoordX;
810  std::ignore = CoordY;
811  throw runtime_error("joint matrix is not supported on host device.",
812  PI_ERROR_INVALID_DEVICE);
813 #endif // defined(__SYCL_DEVICE_ONLY__)
814 }
815 
816 template <
817  typename Group, typename S, typename T, use Use, size_t NumRows,
818  size_t NumCols, layout Layout, typename PropertyListT,
819  std::enable_if_t<std::is_same<S, std::remove_const_t<T>>::value ||
820  (std::is_same<S, precision::tf32>::value &&
821  std::is_same<std::remove_const_t<T>, float>::value),
822  bool> = true>
824  Group sg, joint_matrix<Group, S, Use, NumRows, NumCols, Layout> &Res,
826  size_t Stride, size_t Height, size_t Width, size_t CoordX, size_t CoordY) {
827 #if defined(__SYCL_DEVICE_ONLY__)
828  std::ignore = sg;
829  T *Ptr = Src.get();
830  Res.spvm = __spirv_CooperativeMatrixLoadCheckedINTEL<
831  T, S, NumRows, NumCols, spv_matrix_use_traits<Use>::value,
832  spv_matrix_layout_traits<Layout>::value>(
833  Ptr, CoordX, CoordY, spv_matrix_layout_traits<Layout>::value, Height,
834  Width, Stride);
835 #else
836  std::ignore = sg;
837  std::ignore = Res;
838  std::ignore = Src;
839  std::ignore = Stride;
840  std::ignore = Height;
841  std::ignore = Width;
842  std::ignore = CoordX;
843  std::ignore = CoordY;
844  throw runtime_error("joint matrix is not supported on host device.",
845  PI_ERROR_INVALID_DEVICE);
846 #endif // defined(__SYCL_DEVICE_ONLY__)
847 }
848 
849 template <typename Group, typename T, size_t NumRows, size_t NumCols,
850  typename PropertyListT>
852  Group sg,
853  joint_matrix<Group, T, use::accumulator, NumRows, NumCols, layout::dynamic>
854  &Src,
856  size_t Stride, layout Layout, size_t Height, size_t Width, size_t CoordX,
857  size_t CoordY) {
858 #if defined(__SYCL_DEVICE_ONLY__)
859  std::ignore = sg;
860  T *Ptr = Dst.get();
861  __spirv_CooperativeMatrixStoreCheckedINTEL<
862  T, T, NumRows, NumCols, spv_matrix_use_traits<use::accumulator>::value,
863  spv_matrix_layout_traits<layout::dynamic>::value>(
864  Ptr, CoordX, CoordY, Src.spvm,
865  sycl::detail::joint_matrix_layout_to_spv(Layout), Height, Width, Stride);
866 #else
867  std::ignore = sg;
868  std::ignore = Src;
869  std::ignore = Dst;
870  std::ignore = Stride;
871  std::ignore = Height;
872  std::ignore = Width;
873  std::ignore = Layout;
874  std::ignore = CoordX;
875  std::ignore = CoordY;
876  throw runtime_error("joint matrix is not supported on host device.",
877  PI_ERROR_INVALID_DEVICE);
878 #endif // defined(__SYCL_DEVICE_ONLY__)
879 }
880 
881 template <typename Group, typename T, typename Tp, use Use, size_t NumRows,
882  size_t NumCols, layout Layout, typename PropertyListT,
883  std::enable_if_t<Use == use::a || Use == use::b, bool> = true>
885  Group sg, const joint_matrix<Group, Tp, Use, NumRows, NumCols, Layout> &Src,
887  size_t Stride, size_t Height, size_t Width, size_t CoordX, size_t CoordY) {
888 #if defined(__SYCL_DEVICE_ONLY__)
889  std::ignore = sg;
890  T *Ptr = Dst.get();
891  __spirv_CooperativeMatrixStoreCheckedINTEL<
892  T, Tp, NumRows, NumCols, spv_matrix_use_traits<Use>::value,
893  spv_matrix_layout_traits<Layout>::value>(
894  Ptr, CoordX, CoordY, Src.spvm, spv_matrix_layout_traits<Layout>::value,
895  Height, Width, Stride);
896 #else
897  std::ignore = sg;
898  std::ignore = Src;
899  std::ignore = Dst;
900  std::ignore = Stride;
901  std::ignore = Height;
902  std::ignore = Width;
903  std::ignore = CoordX;
904  std::ignore = CoordY;
905  throw runtime_error("joint matrix is not supported on host device.",
906  PI_ERROR_INVALID_DEVICE);
907 #endif // defined(__SYCL_DEVICE_ONLY__)
908 }
909 // End out-of-bounds API
910 
911 } // namespace intel::experimental::matrix
912 
913 } // namespace ext
914 } // namespace _V1
915 } // namespace sycl
wi_element(sycl::ext::oneapi::experimental::matrix::joint_matrix< Group, sycl::ext::oneapi::bfloat16, Use, NumRows, NumCols, Layout > &Mat, std::size_t i)
wi_element & operator=(const wi_element< sycl::ext::oneapi::bfloat16, NumRows, NumCols, Use, Layout, Group > &rhs)
wi_element & operator=(const wi_element< T, NumRows, NumCols, Use, Layout, Group > &rhs)
wi_element(sycl::ext::oneapi::experimental::matrix::joint_matrix< Group, T, Use, NumRows, NumCols, Layout > &Mat, std::size_t i)
wi_element & operator=(const T2 &rhs)
typename oneapi::detail::jm_type_interpretation_helper_trait< T >::storage_element_type storage_element_type
__SYCL_ALWAYS_INLINE std::tuple< size_t, size_t > get_coord()
#define __SYCL_ALWAYS_INLINE
#define SPV_MATRIX_USE_TRAITS(USE, SPV_USE)
#define SPV_MATRIX_LAYOUT_TRAITS(LAYOUT, SPV_LAYOUT)
#define OP(op)
__SYCL_ALWAYS_INLINE __spv::MatrixLayout joint_matrix_layout_to_spv(sycl::ext::oneapi::experimental::matrix::layout Layout)
constexpr tuple< Ts... > make_tuple(Ts... Args)
Definition: tuple.hpp:35
sycl::ext::oneapi::bfloat16 bfloat16
__SYCL_ALWAYS_INLINE void joint_matrix_store(Group, const sycl::ext::oneapi::experimental::matrix::joint_matrix< Group, Tp, Use, NumRows, NumCols, Layout > &src, ext::oneapi::experimental::annotated_ptr< T, PropertyListT > dst, size_t stride)
__SYCL_ALWAYS_INLINE void joint_matrix_load_checked(Group sg, joint_matrix< Group, S, Use, NumRows, NumCols, Layout > &Res, ext::oneapi::experimental::annotated_ptr< T, PropertyListT > Src, size_t Stride, size_t Height, size_t Width, size_t CoordX, size_t CoordY)
__SYCL_ALWAYS_INLINE void joint_matrix_apply(Group sg, sycl::ext::oneapi::experimental::matrix::joint_matrix< Group, T, Use, Rows, Cols, Layout > &jm, F &&lambda)
__SYCL_ALWAYS_INLINE void joint_matrix_store_checked(Group sg, const joint_matrix< Group, Tp, Use, NumRows, NumCols, Layout > &Src, ext::oneapi::experimental::annotated_ptr< T, PropertyListT > Dst, size_t Stride, size_t Height, size_t Width, size_t CoordX, size_t CoordY)
__SYCL_ALWAYS_INLINE void joint_matrix_fill_checked(Group, joint_matrix< Group, T, Use, NumRows, NumCols, Layout > &Res, const T2 &Value, size_t Height, size_t Width, size_t CoordX, size_t CoordY)
decltype(auto) __SYCL_ALWAYS_INLINE get_wi_data(Group sg, sycl::ext::oneapi::experimental::matrix::joint_matrix< Group, T, Use, Rows, Cols, Layout > &jm)
std::enable_if_t< detail::is_bf16_storage_type< T >::value, T > fabs(T x)
Definition: access.hpp:18