DPC++ Runtime
Runtime libraries for oneAPI DPC++
matrix-intel.hpp
Go to the documentation of this file.
1 //==------------------ matrix-intel.hpp - SYCL matrix ----------*- C++ -*---==//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 // ===--------------------------------------------------------------------=== //
8 
9 #pragma once
10 
11 #include "matrix-unified-utils.hpp" // for use, layout, tf32, matrix
12 #include "utils.hpp" // for getDecorated
13 
14 #include <CL/__spirv/spirv_types.hpp> // for MatrixLayout, MatrixUse
15 #include <sycl/access/access.hpp> // for address_space, decorated
16 #include <sycl/builtins.hpp> // for fabs
17 #include <sycl/detail/defines_elementary.hpp> // for __SYCL_ALWAYS_INLINE
18 #include <sycl/detail/pi.h> // for PI_ERROR_INVALID_DEVICE
19 #include <sycl/exception.hpp> // for runtime_error
20 #include <sycl/ext/oneapi/bfloat16.hpp> // for bfloat16
22 #include <sycl/group.hpp> // for group
23 #include <sycl/multi_ptr.hpp> // for multi_ptr
24 #include <sycl/sub_group.hpp> // for sub_group
25 
26 #include <cstddef> // for size_t
27 #include <stdint.h> // for uint32_t
28 #include <tuple> // for ignore, tuple, _Swallo...
29 #include <type_traits> // for enable_if_t
30 
31 namespace sycl {
32 inline namespace _V1 {
33 namespace ext {
34 namespace oneapi {
35 namespace experimental {
36 namespace matrix {
37 
38 template <layout Layout> struct spv_matrix_layout_traits {
40 };
41 
// Stamps out an explicit specialization of spv_matrix_layout_traits that maps
// the SYCL layout SYCL_LAYOUT to the SPIR-V matrix layout SPIRV_LAYOUT.
#define SPV_MATRIX_LAYOUT_TRAITS(SYCL_LAYOUT, SPIRV_LAYOUT) \
  template <> struct spv_matrix_layout_traits<SYCL_LAYOUT> { \
    constexpr static __spv::MatrixLayout value = SPIRV_LAYOUT; \
  };
46 
51 
52 template <use Use> struct spv_matrix_use_traits {
53  static constexpr __spv::MatrixUse value = __spv::MatrixUse::MatrixA;
54 };
55 
// Stamps out an explicit specialization of spv_matrix_use_traits that maps
// the SYCL matrix use SYCL_USE to the SPIR-V matrix use SPIRV_USE.
#define SPV_MATRIX_USE_TRAITS(SYCL_USE, SPIRV_USE) \
  template <> struct spv_matrix_use_traits<SYCL_USE> { \
    constexpr static __spv::MatrixUse value = SPIRV_USE; \
  };
60 
64 
65 template <typename G> struct spv_scope_traits {};
66 template <> struct spv_scope_traits<sycl::sub_group> {
67  constexpr static auto value = __spv::Scope::Subgroup;
68 };
69 template <int D> struct spv_scope_traits<sycl::group<D>> {
70  constexpr static auto value = __spv::Scope::Workgroup;
71 };
72 
73 // forward declarations
74 template <typename Group, typename T, use Use, size_t Rows, size_t Cols,
75  layout Layout>
76 struct joint_matrix;
77 
78 } // namespace matrix
79 } // namespace experimental
80 
81 namespace detail {
82 // Differentiating between the "element type" and the "storage element type"
83 template <typename T> struct jm_type_interpretation_helper_trait {
84  using element_type = T;
86 };
87 
88 template <>
92  using storage_element_type = float;
93 };
94 
95 using namespace sycl::ext::oneapi::experimental::matrix;
96 // Begin wi_element definition
97 
98 template <typename T, size_t NumRows, size_t NumCols,
101  sycl::ext::oneapi::experimental::matrix::layout::dynamic,
102  typename Group = sycl::sub_group>
103 class wi_element {
105  NumCols, Layout> &M;
106  std::size_t idx;
107 
108 public:
113  Group, T, Use, NumRows, NumCols, Layout> &Mat,
114  std::size_t i)
115  : M(Mat), idx(i) {}
116 
117  inline __SYCL_ALWAYS_INLINE std::tuple<size_t, size_t> get_coord() {
118 #if defined(__SYCL_DEVICE_ONLY__)
119  __ocl_vec_t<uint32_t, 2> coord =
120  __spirv_JointMatrixGetElementCoordINTEL(M.spvm, idx);
121  const size_t row = coord[0];
122  const size_t col = coord[1];
123  return std::make_tuple(row, col);
124 #else
125  throw runtime_error("joint matrix is not supported on host device.",
126  PI_ERROR_INVALID_DEVICE);
127 #endif // __SYCL_DEVICE_ONLY__
128  }
129 
130  operator storage_element_type() {
131 #ifdef __SYCL_DEVICE_ONLY__
133  __spirv_VectorExtractDynamic<storage_element_type, T, NumRows, NumCols,
134  spv_matrix_use_traits<Use>::value,
135  spv_matrix_layout_traits<Layout>::value,
136  spv_scope_traits<Group>::value>(M.spvm,
137  idx);
138  return elem;
139 #else
140  throw runtime_error("joint matrix is not supported on host device.",
141  PI_ERROR_INVALID_DEVICE);
142 #endif // __SYCL_DEVICE_ONLY__
143  }
144 
145  explicit operator bool() {
146 #ifdef __SYCL_DEVICE_ONLY__
147  return __spirv_VectorExtractDynamic<storage_element_type, T, NumRows,
148  NumCols,
149  spv_matrix_use_traits<Use>::value,
150  spv_matrix_layout_traits<Layout>::value,
151  spv_scope_traits<Group>::value>(
152  M.spvm, idx) != static_cast<storage_element_type>(0);
153 #else
154  throw runtime_error("joint matrix is not supported on host device.",
155  PI_ERROR_INVALID_DEVICE);
156 #endif // __SYCL_DEVICE_ONLY__
157  }
158 
159  template <typename T2> wi_element &operator=(const T2 &rhs) {
160 #ifdef __SYCL_DEVICE_ONLY__
161  M.spvm = __spirv_VectorInsertDynamic(
162  M.spvm, static_cast<storage_element_type>(rhs), idx);
163  return *this;
164 #else
165  (void)rhs;
166  throw runtime_error("joint matrix is not supported on host device.",
167  PI_ERROR_INVALID_DEVICE);
168 #endif // __SYCL_DEVICE_ONLY__
169  }
170 
171  wi_element &
173 #ifdef __SYCL_DEVICE_ONLY__
174  M.spvm = __spirv_VectorInsertDynamic(
175  M.spvm,
176  __spirv_VectorExtractDynamic<storage_element_type, T, NumRows, NumCols,
177  spv_matrix_use_traits<Use>::value,
178  spv_matrix_layout_traits<Layout>::value,
179  spv_scope_traits<Group>::value>(rhs.M.spvm,
180  rhs.idx),
181  idx);
182  return *this;
183 #else
184  (void)rhs;
185  throw runtime_error("joint matrix is not supported on host device.",
186  PI_ERROR_INVALID_DEVICE);
187 #endif // __SYCL_DEVICE_ONLY__
188  }
189 
190 #if __SYCL_DEVICE_ONLY__
191 #define OP(op) \
192  template <typename T2> wi_element &operator op##=(const T2 & rhs) { \
193  M.spvm = __spirv_VectorInsertDynamic( \
194  M.spvm, \
195  static_cast<storage_element_type>( \
196  __spirv_VectorExtractDynamic< \
197  storage_element_type, T, NumRows, NumCols, \
198  spv_matrix_use_traits<Use>::value, \
199  spv_matrix_layout_traits<Layout>::value, \
200  spv_scope_traits<Group>::value>(M.spvm, idx) \
201  op static_cast<storage_element_type>(rhs)), \
202  idx); \
203  return *this; \
204  }
205 #else // __SYCL_DEVICE_ONLY__
206 #define OP(op) \
207  template <typename T2> wi_element &operator op##=(const T2 & rhs) { \
208  (void)rhs; \
209  throw runtime_error("joint matrix is not supported on host device.", \
210  PI_ERROR_INVALID_DEVICE); \
211  }
212 #endif // __SYCL_DEVICE_ONLY__
213  OP(+)
214  OP(-)
215  OP(*)
216  OP(/)
217 #undef OP
218 };
219 
220 template <size_t NumRows, size_t NumCols,
223  typename Group>
224 class wi_element<sycl::ext::oneapi::bfloat16, NumRows, NumCols, Use, Layout,
225  Group> {
227  Group, sycl::ext::oneapi::bfloat16, Use, NumRows, NumCols, Layout> &M;
228  std::size_t idx;
229 
230 public:
232  Group, sycl::ext::oneapi::bfloat16, Use, NumRows, NumCols,
233  Layout> &Mat,
234  std::size_t i)
235  : M(Mat), idx(i) {}
236 
237  inline __SYCL_ALWAYS_INLINE std::tuple<uint32_t, uint32_t> get_coord() {
238 #if defined(__SYCL_DEVICE_ONLY__)
239  __ocl_vec_t<uint32_t, 2> coord =
240  __spirv_JointMatrixGetElementCoordINTEL(M.spvm, idx);
241  const uint32_t row = coord[0];
242  const uint32_t col = coord[1];
243  return std::make_tuple(row, col);
244 #else
245  throw runtime_error("joint matrix is not supported on host device.",
246  PI_ERROR_INVALID_DEVICE);
247 #endif // __SYCL_DEVICE_ONLY__
248  }
249 
251 #ifdef __SYCL_DEVICE_ONLY__
252  return __spirv_VectorExtractDynamic<
254  NumCols, spv_matrix_use_traits<Use>::value,
255  spv_matrix_layout_traits<Layout>::value,
256  spv_scope_traits<Group>::value>(M.spvm, idx);
257 #else
258  throw runtime_error("joint matrix is not supported on host device.",
259  PI_ERROR_INVALID_DEVICE);
260 #endif // __SYCL_DEVICE_ONLY__
261  }
262 
263  explicit operator bool() {
264 #ifdef __SYCL_DEVICE_ONLY__
265  return sycl::fabs(static_cast<float>(
266  __spirv_VectorExtractDynamic<
268  NumRows, NumCols, spv_matrix_use_traits<Use>::value,
269  spv_matrix_layout_traits<Layout>::value,
270  spv_scope_traits<Group>::value>(M.spvm, idx))) >=
271  std::numeric_limits<float>::epsilon();
272 #else
273  throw runtime_error("joint matrix is not supported on host device.",
274  PI_ERROR_INVALID_DEVICE);
275 #endif // __SYCL_DEVICE_ONLY__
276  }
277 
279 #ifdef __SYCL_DEVICE_ONLY__
280  M.spvm = __spirv_VectorInsertDynamic(M.spvm, rhs, idx);
281  return *this;
282 #else
283  (void)rhs;
284  throw runtime_error("joint matrix is not supported on host device.",
285  PI_ERROR_INVALID_DEVICE);
286 #endif // __SYCL_DEVICE_ONLY__
287  }
288 
290  NumCols, Use, Layout, Group> &rhs) {
291 #ifdef __SYCL_DEVICE_ONLY__
292  M.spvm = __spirv_VectorInsertDynamic(
293  M.spvm,
294  __spirv_VectorExtractDynamic<sycl::ext::oneapi::bfloat16,
296  NumCols, spv_matrix_use_traits<Use>::value,
297  spv_matrix_layout_traits<Layout>::value,
298  spv_scope_traits<Group>::value>(rhs.M.spvm,
299  rhs.idx),
300  idx);
301  return *this;
302 #else
303  (void)rhs;
304  throw runtime_error("joint matrix is not supported on host device.",
305  PI_ERROR_INVALID_DEVICE);
306 #endif // __SYCL_DEVICE_ONLY__
307  }
308 
309 #if __SYCL_DEVICE_ONLY__
310 #define OP(opassign, op) \
311  wi_element &operator opassign(const sycl::ext::oneapi::bfloat16 & rhs) { \
312  M.spvm = __spirv_VectorInsertDynamic( \
313  M.spvm, \
314  __spirv_VectorExtractDynamic< \
315  sycl::ext::oneapi::bfloat16, sycl::ext::oneapi::bfloat16, NumRows, \
316  NumCols, spv_matrix_use_traits<Use>::value, \
317  spv_matrix_layout_traits<Layout>::value, \
318  spv_scope_traits<Group>::value>(M.spvm, idx) op rhs, \
319  idx); \
320  return *this; \
321  }
322 #else // __SYCL_DEVICE_ONLY__
323 #define OP(opassign, op) \
324  wi_element &operator opassign(const sycl::ext::oneapi::bfloat16 & rhs) { \
325  (void)rhs; \
326  throw runtime_error("joint matrix is not supported on host device.", \
327  PI_ERROR_INVALID_DEVICE); \
328  }
329 #endif // __SYCL_DEVICE_ONLY__
330  OP(+=, +)
331  OP(-=, -)
332  OP(*=, *)
333  OP(/=, /)
334 #undef OP
335 
336 #if __SYCL_DEVICE_ONLY__
// Binary arithmetic (+, -, *, /) between a bfloat16 joint_matrix element and a
// bfloat16 scalar, for both operand orders. Each overload keeps the operands in
// call-site order. As a result, the non-commutative forms `scalar - element`
// and `scalar / element` compute exactly that. The previous version computed
// `element - scalar` / `element / scalar` for them.
#define OP(type, op) \
  friend type operator op( \
      const wi_element<sycl::ext::oneapi::bfloat16, NumRows, NumCols, Use, \
                       Layout, Group> &lhs, \
      const sycl::ext::oneapi::bfloat16 &rhs) { \
    const sycl::ext::oneapi::bfloat16 elem = __spirv_VectorExtractDynamic< \
        sycl::ext::oneapi::bfloat16, sycl::ext::oneapi::bfloat16, NumRows, \
        NumCols, spv_matrix_use_traits<Use>::value, \
        spv_matrix_layout_traits<Layout>::value, \
        spv_scope_traits<Group>::value>(lhs.M.spvm, lhs.idx); \
    return elem op rhs; \
  } \
  friend type operator op( \
      const sycl::ext::oneapi::bfloat16 &lhs, \
      const wi_element<sycl::ext::oneapi::bfloat16, NumRows, NumCols, Use, \
                       Layout, Group> &rhs) { \
    const sycl::ext::oneapi::bfloat16 elem = __spirv_VectorExtractDynamic< \
        sycl::ext::oneapi::bfloat16, sycl::ext::oneapi::bfloat16, NumRows, \
        NumCols, spv_matrix_use_traits<Use>::value, \
        spv_matrix_layout_traits<Layout>::value, \
        spv_scope_traits<Group>::value>(rhs.M.spvm, rhs.idx); \
    /* The scalar is the left operand here and must stay on the left. */ \
    return lhs op elem; \
  }
359  OP(sycl::ext::oneapi::bfloat16, -)
360  OP(sycl::ext::oneapi::bfloat16, *)
361  OP(sycl::ext::oneapi::bfloat16, /)
362 #undef OP
// Relational operators (==, !=, <, >, <=, >=) between a bfloat16 joint_matrix
// element and a bfloat16 scalar, for both operand orders. Both sides are
// compared as float. The operands keep their call-site order, so
// `scalar < element` is no longer evaluated as `element < scalar`.
#define OP(type, op) \
  friend type operator op( \
      const wi_element<sycl::ext::oneapi::bfloat16, NumRows, NumCols, Use, \
                       Layout, Group> &lhs, \
      const sycl::ext::oneapi::bfloat16 &rhs) { \
    const float elem = static_cast<float>(__spirv_VectorExtractDynamic< \
        sycl::ext::oneapi::bfloat16, sycl::ext::oneapi::bfloat16, NumRows, \
        NumCols, spv_matrix_use_traits<Use>::value, \
        spv_matrix_layout_traits<Layout>::value, \
        spv_scope_traits<Group>::value>(lhs.M.spvm, lhs.idx)); \
    return type{elem op static_cast<float>(rhs)}; \
  } \
  friend type operator op( \
      const sycl::ext::oneapi::bfloat16 &lhs, \
      const wi_element<sycl::ext::oneapi::bfloat16, NumRows, NumCols, Use, \
                       Layout, Group> &rhs) { \
    const float elem = static_cast<float>(__spirv_VectorExtractDynamic< \
        sycl::ext::oneapi::bfloat16, sycl::ext::oneapi::bfloat16, NumRows, \
        NumCols, spv_matrix_use_traits<Use>::value, \
        spv_matrix_layout_traits<Layout>::value, \
        spv_scope_traits<Group>::value>(rhs.M.spvm, rhs.idx)); \
    /* The scalar is the left operand here and must stay on the left. */ \
    return type{static_cast<float>(lhs) op elem}; \
  }
388  OP(bool, ==)
389  OP(bool, !=)
390  OP(bool, <)
391  OP(bool, >)
392  OP(bool, <=)
393  OP(bool, >=)
394 #undef OP
395 #else // __SYCL_DEVICE_ONLY__
// Host fallback for the bfloat16 element/scalar operators. Both operand orders
// are declared, and each one throws runtime_error because joint matrix is not
// supported on the host.
#define OP(type, op) \
  friend type operator op( \
      const sycl::ext::oneapi::bfloat16 & /*lhs*/, \
      const wi_element<sycl::ext::oneapi::bfloat16, NumRows, NumCols, Use, \
                       Layout, Group> & /*rhs*/) { \
    throw runtime_error("joint matrix is not supported on host device.", \
                        PI_ERROR_INVALID_DEVICE); \
  } \
  friend type operator op( \
      const wi_element<sycl::ext::oneapi::bfloat16, NumRows, NumCols, Use, \
                       Layout, Group> & /*lhs*/, \
      const sycl::ext::oneapi::bfloat16 & /*rhs*/) { \
    throw runtime_error("joint matrix is not supported on host device.", \
                        PI_ERROR_INVALID_DEVICE); \
  }
412  OP(sycl::ext::oneapi::bfloat16, -)
413  OP(sycl::ext::oneapi::bfloat16, *)
414  OP(sycl::ext::oneapi::bfloat16, /)
415  OP(bool, ==)
416  OP(bool, !=)
417  OP(bool, <)
418  OP(bool, >)
419  OP(bool, <=)
420  OP(bool, >=)
421 #undef OP
422 #endif // __SYCL_DEVICE_ONLY__
423 };
424 
425 // End wi_element definition
426 
427 // Begin wi_data definition
428 
429 template <typename Group, typename T,
432 class wi_data {
433 
435  Cols, Layout> &jm;
436 
438  Group, T, Use, Rows, Cols, Layout> &_jm)
439  : jm(_jm){};
440 
441  template <typename Grp, typename Type,
442  sycl::ext::oneapi::experimental::matrix::use UseJm, size_t NumRows,
443  size_t NumCols,
445  friend decltype(auto)
447  Grp, Type, UseJm, NumRows, NumCols, LayoutJm> &);
448 
449 public:
450  size_t length() {
451 #if __SYCL_DEVICE_ONLY__
452  return __spirv_JointMatrixWorkItemLengthINTEL(jm.spvm);
453 #else
454  throw runtime_error("joint matrix is not supported on host device.",
455  PI_ERROR_INVALID_DEVICE);
456 #endif
457  };
458 
459  decltype(auto) operator[](size_t i) {
461  };
462 };
463 
464 template <typename Group, typename T,
467 inline __SYCL_ALWAYS_INLINE decltype(auto)
468 get_wi_data(Group sg, sycl::ext::oneapi::experimental::matrix::joint_matrix<
469  Group, T, Use, Rows, Cols, Layout> &jm) {
470  std::ignore = sg;
471  return wi_data(jm);
472 }
473 
474 // End wi_data definition
475 } // namespace detail
476 } // namespace oneapi
477 
478 namespace intel::experimental::matrix {
479 template <
480  typename Group, typename T, typename Tp,
483  access::address_space Space, access::decorated IsDecorated,
484  std::enable_if_t<Use == sycl::ext::oneapi::experimental::matrix::use::a ||
486  bool> = true>
487 inline __SYCL_ALWAYS_INLINE void
490  Group, Tp, Use, NumRows, NumCols, Layout> &src,
491  multi_ptr<T, Space, IsDecorated> dst, size_t stride) {
492 #if defined(__SYCL_DEVICE_ONLY__)
493  static_assert(Space != access::address_space::private_space,
494  "Joint Matrix doesn't support store to private memory!");
495 #if defined(__NVPTX__)
496  std::ignore = src;
497  std::ignore = dst;
498  std::ignore = stride;
499  throw runtime_error(
500  "This version of the matrix extension is only currently supported on "
501  "intel devices",
502  PI_ERROR_INVALID_DEVICE);
503 #else
504  // intel's impl
505  using DecorT = typename sycl::detail::DecoratedType<T, Space>::type;
506  DecorT *Ptr = sycl::detail::getDecorated<DecorT>(dst);
507  __spirv_JointMatrixStoreINTEL<DecorT, Tp, NumRows, NumCols,
512  Ptr, src.spvm, stride,
514  Layout>::value,
516 #endif // defined(__NVPTX__)
517 #else
518  std::ignore = src;
519  std::ignore = dst;
520  std::ignore = stride;
521  throw runtime_error("joint matrix is not supported on host device.",
522  PI_ERROR_INVALID_DEVICE);
523 #endif // defined(__SYCL_DEVICE_ONLY__)
524 }
525 
526 template <
527  typename Group, typename T, typename Tp,
530  typename PropertyListT,
531  std::enable_if_t<Use == sycl::ext::oneapi::experimental::matrix::use::a ||
533  bool> = true>
535  Group,
537  Group, Tp, Use, NumRows, NumCols, Layout> &src,
539  size_t stride) {
540 #if defined(__SYCL_DEVICE_ONLY__)
541 #if defined(__NVPTX__)
542  std::ignore = src;
543  std::ignore = dst;
544  std::ignore = stride;
545  throw runtime_error(
546  "This version of the matrix extension is only currently supported on "
547  "intel devices",
548  PI_ERROR_INVALID_DEVICE);
549 #else
550  // intel's impl
551  T *Ptr = dst.get();
552  __spirv_JointMatrixStoreINTEL<T, Tp, NumRows, NumCols,
557  Ptr, src.spvm, stride,
559  Layout>::value,
561 #endif // defined(__NVPTX__)
562 #else
563  std::ignore = src;
564  std::ignore = dst;
565  std::ignore = stride;
566  throw runtime_error("joint matrix is not supported on host device.",
567  PI_ERROR_INVALID_DEVICE);
568 #endif // defined(__SYCL_DEVICE_ONLY__)
569 }
570 
571 template <typename Group, typename T,
574  typename F>
576  Group sg,
578  Cols, Layout> &jm,
579  F &&lambda) {
580 #if defined(__SYCL_DEVICE_ONLY__)
581 #if defined(__NVPTX__)
582  std::ignore = sg;
583  for (int i = 0; i < jm.matrix_impl.wi_marray.size(); i++) {
584  lambda(jm.matrix_impl.wi_marray[i]);
585  }
586 #else // NVPTX
587  using storage_element_type =
589  T>::storage_element_type;
590  auto wi_data_c = sycl::ext::oneapi::detail::get_wi_data(sg, jm);
591  for (int i = 0; i < wi_data_c.length(); i++) {
592  storage_element_type element = wi_data_c[i];
593  auto [row, col] = wi_data_c[i].get_coord();
594  lambda(element, row, col);
595  wi_data_c[i] = element;
596  }
597 #endif
598 #else
599  std::ignore = sg;
600  std::ignore = jm;
601  std::ignore = lambda;
602  throw runtime_error("joint matrix is not supported on host device.",
603  PI_ERROR_INVALID_DEVICE);
604 #endif
605 }
606 
607 using namespace sycl::ext::oneapi::experimental::matrix;
608 
609 // Begin out-of-bounds API
610 
611 template <typename Group, typename T, size_t NumRows, size_t NumCols, use Use,
612  layout Layout, typename T2>
614  Group, joint_matrix<Group, T, Use, NumRows, NumCols, Layout> &Res,
615  const T2 &Value, size_t Stride, size_t Height, size_t Width, size_t CoordX,
616  size_t CoordY) {
617 #if defined(__SYCL_DEVICE_ONLY__)
618  using storage_element_type =
620  T>::storage_element_type;
621  Res.spvm = __spirv_CooperativeMatrixConstructCheckedINTEL<
622  storage_element_type, T, NumRows, NumCols,
623  spv_matrix_use_traits<Use>::value,
624  spv_matrix_layout_traits<Layout>::value>(
625  static_cast<storage_element_type>(Value), Stride, Height, Width, CoordX,
626  CoordY);
627 #else
628  std::ignore = Res;
629  std::ignore = Value;
630  std::ignore = Stride;
631  std::ignore = Height;
632  std::ignore = Width;
633  std::ignore = CoordX;
634  std::ignore = CoordY;
635  throw runtime_error("joint matrix is not supported on host device.",
636  PI_ERROR_INVALID_DEVICE);
637 #endif // defined(__SYCL_DEVICE_ONLY__)
638 }
639 
640 template <
641  typename Group, typename S, typename T, size_t NumRows, size_t NumCols,
642  access::address_space Space, access::decorated IsDecorated,
643  std::enable_if_t<std::is_same<S, std::remove_const_t<T>>::value, bool> =
644  true>
646  Group sg,
647  joint_matrix<Group, S, use::accumulator, NumRows, NumCols, layout::dynamic>
648  &Res,
649  multi_ptr<T, Space, IsDecorated> Src, size_t Stride, layout Layout,
650  size_t Height, size_t Width, size_t CoordX, size_t CoordY) {
651 #if defined(__SYCL_DEVICE_ONLY__)
652  static_assert(Space != access::address_space::private_space,
653  "Joint Matrix doesn't support load from private memory!");
654  std::ignore = sg;
655  using DecorT = typename sycl::detail::DecoratedType<T, Space>::type;
656  DecorT *Ptr = sycl::detail::getDecorated<DecorT>(Src);
657  Res.spvm = __spirv_JointMatrixLoadCheckedINTEL<
658  DecorT, S, NumRows, NumCols,
659  spv_matrix_use_traits<use::accumulator>::value,
660  spv_matrix_layout_traits<layout::dynamic>::value>(
661  Ptr, Stride, Height, Width, CoordX, CoordY,
663  spv_scope_traits<Group>::value);
664 #else
665  std::ignore = sg;
666  std::ignore = Res;
667  std::ignore = Src;
668  std::ignore = Stride;
669  std::ignore = Height;
670  std::ignore = Width;
671  std::ignore = Layout;
672  std::ignore = CoordX;
673  std::ignore = CoordY;
674  throw runtime_error("joint matrix is not supported on host device.",
675  PI_ERROR_INVALID_DEVICE);
676 #endif // defined(__SYCL_DEVICE_ONLY__)
677 }
678 
679 template <
680  typename Group, typename S, typename T, use Use, size_t NumRows,
681  size_t NumCols, layout Layout, access::address_space Space,
682  access::decorated IsDecorated,
683  std::enable_if_t<std::is_same<S, std::remove_const_t<T>>::value ||
684  (std::is_same<S, precision::tf32>::value &&
685  std::is_same<std::remove_const_t<T>, float>::value),
686  bool> = true>
688  Group sg, joint_matrix<Group, S, Use, NumRows, NumCols, Layout> &Res,
689  multi_ptr<T, Space, IsDecorated> Src, size_t Stride, size_t Height,
690  size_t Width, size_t CoordX, size_t CoordY) {
691 #if defined(__SYCL_DEVICE_ONLY__)
692  static_assert(Space != access::address_space::private_space,
693  "Joint Matrix doesn't support load from private memory!");
694  std::ignore = sg;
695  using DecorT = typename sycl::detail::DecoratedType<T, Space>::type;
696  DecorT *Ptr = sycl::detail::getDecorated<DecorT>(Src);
697  Res.spvm = __spirv_JointMatrixLoadCheckedINTEL<
698  DecorT, S, NumRows, NumCols, spv_matrix_use_traits<Use>::value,
699  spv_matrix_layout_traits<Layout>::value>(
700  Ptr, Stride, Height, Width, CoordX, CoordY,
701  spv_matrix_layout_traits<Layout>::value, spv_scope_traits<Group>::value);
702 #else
703  std::ignore = sg;
704  std::ignore = Res;
705  std::ignore = Src;
706  std::ignore = Stride;
707  std::ignore = Height;
708  std::ignore = Width;
709  std::ignore = CoordX;
710  std::ignore = CoordY;
711  throw runtime_error("joint matrix is not supported on host device.",
712  PI_ERROR_INVALID_DEVICE);
713 #endif // defined(__SYCL_DEVICE_ONLY__)
714 }
715 
716 template <typename Group, typename T, size_t NumRows, size_t NumCols,
717  access::address_space Space, access::decorated IsDecorated>
719  Group sg,
720  joint_matrix<Group, T, use::accumulator, NumRows, NumCols, layout::dynamic>
721  &Src,
722  multi_ptr<T, Space, IsDecorated> Dst, size_t Stride, layout Layout,
723  size_t Height, size_t Width, size_t CoordX, size_t CoordY) {
724 #if defined(__SYCL_DEVICE_ONLY__)
725  static_assert(Space != access::address_space::private_space,
726  "Joint Matrix doesn't support store to private memory!");
727  std::ignore = sg;
728  using DecorT = typename sycl::detail::DecoratedType<T, Space>::type;
729  DecorT *Ptr = sycl::detail::getDecorated<DecorT>(Dst);
730  __spirv_JointMatrixStoreCheckedINTEL<
731  DecorT, T, NumRows, NumCols,
732  spv_matrix_use_traits<use::accumulator>::value,
733  spv_matrix_layout_traits<layout::dynamic>::value>(
734  Ptr, Src.spvm, Stride, Height, Width, CoordX, CoordY,
736  spv_scope_traits<Group>::value);
737 #else
738  std::ignore = sg;
739  std::ignore = Src;
740  std::ignore = Dst;
741  std::ignore = Stride;
742  std::ignore = Height;
743  std::ignore = Width;
744  std::ignore = Layout;
745  std::ignore = CoordX;
746  std::ignore = CoordY;
747  throw runtime_error("joint matrix is not supported on host device.",
748  PI_ERROR_INVALID_DEVICE);
749 #endif // defined(__SYCL_DEVICE_ONLY__)
750 }
751 
752 template <typename Group, typename T, typename Tp, use Use, size_t NumRows,
753  size_t NumCols, layout Layout, access::address_space Space,
754  access::decorated IsDecorated,
755  std::enable_if_t<Use == use::a || Use == use::b, bool> = true>
757  Group sg, const joint_matrix<Group, Tp, Use, NumRows, NumCols, Layout> &Src,
758  multi_ptr<T, Space, IsDecorated> Dst, size_t Stride, size_t Height,
759  size_t Width, size_t CoordX, size_t CoordY) {
760 #if defined(__SYCL_DEVICE_ONLY__)
761  static_assert(Space != access::address_space::private_space,
762  "Joint Matrix doesn't support store to private memory!");
763  std::ignore = sg;
764  using DecorT = typename sycl::detail::DecoratedType<T, Space>::type;
765  DecorT *Ptr = sycl::detail::getDecorated<DecorT>(Dst);
766  __spirv_JointMatrixStoreCheckedINTEL<DecorT, Tp, NumRows, NumCols,
767  spv_matrix_use_traits<Use>::value,
768  spv_matrix_layout_traits<Layout>::value>(
769  Ptr, Src.spvm, Stride, Height, Width, CoordX, CoordY,
770  spv_matrix_layout_traits<Layout>::value, spv_scope_traits<Group>::value);
771 #else
772  std::ignore = sg;
773  std::ignore = Src;
774  std::ignore = Dst;
775  std::ignore = Stride;
776  std::ignore = Height;
777  std::ignore = Width;
778  std::ignore = CoordX;
779  std::ignore = CoordY;
780  throw runtime_error("joint matrix is not supported on host device.",
781  PI_ERROR_INVALID_DEVICE);
782 #endif // defined(__SYCL_DEVICE_ONLY__)
783 }
784 
785 // Annotated pointer overloads:
786 template <typename Group, typename S, typename T, size_t NumRows,
787  size_t NumCols, typename PropertyListT,
788  std::enable_if_t<std::is_same<S, std::remove_const_t<T>>::value,
789  bool> = true>
791  Group sg,
792  joint_matrix<Group, S, use::accumulator, NumRows, NumCols, layout::dynamic>
793  &Res,
795  size_t Stride, layout Layout, size_t Height, size_t Width, size_t CoordX,
796  size_t CoordY) {
797 #if defined(__SYCL_DEVICE_ONLY__)
798  std::ignore = sg;
799  T *Ptr = Src.get();
800  Res.spvm = __spirv_JointMatrixLoadCheckedINTEL<
801  T, S, NumRows, NumCols, spv_matrix_use_traits<use::accumulator>::value,
802  spv_matrix_layout_traits<layout::dynamic>::value>(
803  Ptr, Stride, Height, Width, CoordX, CoordY,
805  spv_scope_traits<Group>::value);
806 #else
807  std::ignore = sg;
808  std::ignore = Res;
809  std::ignore = Src;
810  std::ignore = Stride;
811  std::ignore = Height;
812  std::ignore = Width;
813  std::ignore = Layout;
814  std::ignore = CoordX;
815  std::ignore = CoordY;
816  throw runtime_error("joint matrix is not supported on host device.",
817  PI_ERROR_INVALID_DEVICE);
818 #endif // defined(__SYCL_DEVICE_ONLY__)
819 }
820 
821 template <
822  typename Group, typename S, typename T, use Use, size_t NumRows,
823  size_t NumCols, layout Layout, typename PropertyListT,
824  std::enable_if_t<std::is_same<S, std::remove_const_t<T>>::value ||
825  (std::is_same<S, precision::tf32>::value &&
826  std::is_same<std::remove_const_t<T>, float>::value),
827  bool> = true>
829  Group sg, joint_matrix<Group, S, Use, NumRows, NumCols, Layout> &Res,
831  size_t Stride, size_t Height, size_t Width, size_t CoordX, size_t CoordY) {
832 #if defined(__SYCL_DEVICE_ONLY__)
833  std::ignore = sg;
834  T *Ptr = Src.get();
835  Res.spvm = __spirv_JointMatrixLoadCheckedINTEL<
836  T, S, NumRows, NumCols, spv_matrix_use_traits<Use>::value,
837  spv_matrix_layout_traits<Layout>::value>(
838  Ptr, Stride, Height, Width, CoordX, CoordY,
839  spv_matrix_layout_traits<Layout>::value, spv_scope_traits<Group>::value);
840 #else
841  std::ignore = sg;
842  std::ignore = Res;
843  std::ignore = Src;
844  std::ignore = Stride;
845  std::ignore = Height;
846  std::ignore = Width;
847  std::ignore = CoordX;
848  std::ignore = CoordY;
849  throw runtime_error("joint matrix is not supported on host device.",
850  PI_ERROR_INVALID_DEVICE);
851 #endif // defined(__SYCL_DEVICE_ONLY__)
852 }
853 
854 template <typename Group, typename T, size_t NumRows, size_t NumCols,
855  typename PropertyListT>
857  Group sg,
858  joint_matrix<Group, T, use::accumulator, NumRows, NumCols, layout::dynamic>
859  &Src,
861  size_t Stride, layout Layout, size_t Height, size_t Width, size_t CoordX,
862  size_t CoordY) {
863 #if defined(__SYCL_DEVICE_ONLY__)
864  std::ignore = sg;
865  T *Ptr = Dst.get();
866  __spirv_JointMatrixStoreCheckedINTEL<
867  T, T, NumRows, NumCols, spv_matrix_use_traits<use::accumulator>::value,
868  spv_matrix_layout_traits<layout::dynamic>::value>(
869  Ptr, Src.spvm, Stride, Height, Width, CoordX, CoordY,
871  spv_scope_traits<Group>::value);
872 #else
873  std::ignore = sg;
874  std::ignore = Src;
875  std::ignore = Dst;
876  std::ignore = Stride;
877  std::ignore = Height;
878  std::ignore = Width;
879  std::ignore = Layout;
880  std::ignore = CoordX;
881  std::ignore = CoordY;
882  throw runtime_error("joint matrix is not supported on host device.",
883  PI_ERROR_INVALID_DEVICE);
884 #endif // defined(__SYCL_DEVICE_ONLY__)
885 }
886 
887 template <typename Group, typename T, typename Tp, use Use, size_t NumRows,
888  size_t NumCols, layout Layout, typename PropertyListT,
889  std::enable_if_t<Use == use::a || Use == use::b, bool> = true>
891  Group sg, const joint_matrix<Group, Tp, Use, NumRows, NumCols, Layout> &Src,
893  size_t Stride, size_t Height, size_t Width, size_t CoordX, size_t CoordY) {
894 #if defined(__SYCL_DEVICE_ONLY__)
895  std::ignore = sg;
896  T *Ptr = Dst.get();
897  __spirv_JointMatrixStoreCheckedINTEL<T, Tp, NumRows, NumCols,
898  spv_matrix_use_traits<Use>::value,
899  spv_matrix_layout_traits<Layout>::value>(
900  Ptr, Src.spvm, Stride, Height, Width, CoordX, CoordY,
901  spv_matrix_layout_traits<Layout>::value, spv_scope_traits<Group>::value);
902 #else
903  std::ignore = sg;
904  std::ignore = Src;
905  std::ignore = Dst;
906  std::ignore = Stride;
907  std::ignore = Height;
908  std::ignore = Width;
909  std::ignore = CoordX;
910  std::ignore = CoordY;
911  throw runtime_error("joint matrix is not supported on host device.",
912  PI_ERROR_INVALID_DEVICE);
913 #endif // defined(__SYCL_DEVICE_ONLY__)
914 }
915 // End out-of-bounds API
916 
917 } // namespace intel::experimental::matrix
918 
919 } // namespace ext
920 } // namespace _V1
921 } // namespace sycl
wi_element(sycl::ext::oneapi::experimental::matrix::joint_matrix< Group, sycl::ext::oneapi::bfloat16, Use, NumRows, NumCols, Layout > &Mat, std::size_t i)
wi_element & operator=(const wi_element< sycl::ext::oneapi::bfloat16, NumRows, NumCols, Use, Layout, Group > &rhs)
wi_element & operator=(const wi_element< T, NumRows, NumCols, Use, Layout, Group > &rhs)
wi_element(sycl::ext::oneapi::experimental::matrix::joint_matrix< Group, T, Use, NumRows, NumCols, Layout > &Mat, std::size_t i)
wi_element & operator=(const T2 &rhs)
typename oneapi::detail::jm_type_interpretation_helper_trait< T >::storage_element_type storage_element_type
__SYCL_ALWAYS_INLINE std::tuple< size_t, size_t > get_coord()
#define __SYCL_ALWAYS_INLINE
#define SPV_MATRIX_USE_TRAITS(USE, SPV_USE)
#define SPV_MATRIX_LAYOUT_TRAITS(LAYOUT, SPV_LAYOUT)
#define OP(op)
__SYCL_ALWAYS_INLINE __spv::MatrixLayout joint_matrix_layout_to_spv(sycl::ext::oneapi::experimental::matrix::layout Layout)
constexpr tuple< Ts... > make_tuple(Ts... Args)
Definition: tuple.hpp:35
__SYCL_ALWAYS_INLINE void joint_matrix_fill_checked(Group, joint_matrix< Group, T, Use, NumRows, NumCols, Layout > &Res, const T2 &Value, size_t Stride, size_t Height, size_t Width, size_t CoordX, size_t CoordY)
__SYCL_ALWAYS_INLINE void joint_matrix_store(Group, const sycl::ext::oneapi::experimental::matrix::joint_matrix< Group, Tp, Use, NumRows, NumCols, Layout > &src, ext::oneapi::experimental::annotated_ptr< T, PropertyListT > dst, size_t stride)
__SYCL_ALWAYS_INLINE void joint_matrix_load_checked(Group sg, joint_matrix< Group, S, Use, NumRows, NumCols, Layout > &Res, ext::oneapi::experimental::annotated_ptr< T, PropertyListT > Src, size_t Stride, size_t Height, size_t Width, size_t CoordX, size_t CoordY)
__SYCL_ALWAYS_INLINE void joint_matrix_apply(Group sg, sycl::ext::oneapi::experimental::matrix::joint_matrix< Group, T, Use, Rows, Cols, Layout > &jm, F &&lambda)
__SYCL_ALWAYS_INLINE void joint_matrix_store_checked(Group sg, const joint_matrix< Group, Tp, Use, NumRows, NumCols, Layout > &Src, ext::oneapi::experimental::annotated_ptr< T, PropertyListT > Dst, size_t Stride, size_t Height, size_t Width, size_t CoordX, size_t CoordY)
decltype(auto) __SYCL_ALWAYS_INLINE get_wi_data(Group sg, sycl::ext::oneapi::experimental::matrix::joint_matrix< Group, T, Use, Rows, Cols, Layout > &jm)
std::enable_if_t< detail::is_bf16_storage_type< T >::value, T > fabs(T x)
Definition: access.hpp:18