DPC++ Runtime
Runtime libraries for oneAPI DPC++
matrix-intel.hpp
Go to the documentation of this file.
1 //==------------------ matrix-intel.hpp - SYCL matrix ----------*- C++ -*---==//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 // ===--------------------------------------------------------------------=== //
8 
9 #pragma once
10 
11 #include "matrix-unified-utils.hpp" // for use, layout, tf32, matrix
12 #include "utils.hpp" // for getDecorated
13 
14 #include <CL/__spirv/spirv_types.hpp> // for MatrixLayout, MatrixUse
15 #include <sycl/access/access.hpp> // for address_space, decorated
16 #include <sycl/builtins.hpp> // for fabs
17 #include <sycl/detail/defines_elementary.hpp> // for __SYCL_ALWAYS_INLINE
18 #include <sycl/exception.hpp>
19 #include <sycl/ext/oneapi/bfloat16.hpp> // for bfloat16
21 #include <sycl/group.hpp> // for group
22 #include <sycl/multi_ptr.hpp> // for multi_ptr
23 #include <sycl/sub_group.hpp> // for sub_group
24 
25 #include <cstddef> // for size_t
26 #include <stdint.h> // for uint32_t
27 #include <tuple> // for ignore, tuple, _Swallo...
28 #include <type_traits> // for enable_if_t
29 
30 namespace sycl {
31 inline namespace _V1 {
32 namespace ext {
33 namespace oneapi {
34 namespace experimental {
35 namespace matrix {
36 
// Maps a SYCL joint_matrix `layout` to a SPIR-V `__spv::MatrixLayout`
// constant exposed as `value`; per-layout explicit specializations are
// generated with the SPV_MATRIX_LAYOUT_TRAITS macro.
// NOTE(review): the primary template's member (doxygen line 38) is elided
// from this listing.
37 template <layout Layout> struct spv_matrix_layout_traits {
39 };
40 
// Defines the explicit specialization spv_matrix_layout_traits<SYCL_LAYOUT>,
// whose `value` is the SPIR-V layout constant SPIRV_LAYOUT. The expansion
// already ends with `};`.
#define SPV_MATRIX_LAYOUT_TRAITS(SYCL_LAYOUT, SPIRV_LAYOUT)                    \
  template <>                                                                  \
  struct spv_matrix_layout_traits<SYCL_LAYOUT> {                               \
    static constexpr __spv::MatrixLayout value = SPIRV_LAYOUT;                 \
  };
45 
50 
51 template <use Use> struct spv_matrix_use_traits {
52  static constexpr __spv::MatrixUse value = __spv::MatrixUse::MatrixA;
53 };
54 
// Defines the explicit specialization spv_matrix_use_traits<SYCL_USE>, whose
// `value` is the SPIR-V matrix-use constant SPIRV_USE. The expansion already
// ends with `};`.
#define SPV_MATRIX_USE_TRAITS(SYCL_USE, SPIRV_USE)                             \
  template <>                                                                  \
  struct spv_matrix_use_traits<SYCL_USE> {                                     \
    static constexpr __spv::MatrixUse value = SPIRV_USE;                       \
  };
59 
63 
64 template <typename G> struct spv_scope_traits {};
65 template <> struct spv_scope_traits<sycl::sub_group> {
66  constexpr static auto value = __spv::Scope::Subgroup;
67 };
68 template <int D> struct spv_scope_traits<sycl::group<D>> {
69  constexpr static auto value = __spv::Scope::Workgroup;
70 };
71 
72 // forward declarations
73 template <typename Group, typename T, use Use, size_t Rows, size_t Cols,
74  layout Layout>
75 struct joint_matrix;
76 
77 } // namespace matrix
78 } // namespace experimental
79 
80 namespace detail {
81 // Differentiating between the "element type" and the "storage element type"
// `element_type` is always `T`. The companion member `storage_element_type`
// is what the helpers in this header use for the value actually held in the
// SPIR-V matrix object (read via
// jm_type_interpretation_helper_trait<T>::storage_element_type).
// NOTE(review): that member (doxygen line 84) is elided from this listing.
82 template <typename T> struct jm_type_interpretation_helper_trait {
83  using element_type = T;
85 };
86 
// Explicit specialization whose storage_element_type is float.
// NOTE(review): the specialized type (doxygen lines 88-90) is elided from
// this listing; presumably precision::tf32, whose matrices are loaded from
// float memory by joint_matrix_load_checked — confirm.
87 template <>
91  using storage_element_type = float;
92 };
93 
94 using namespace sycl::ext::oneapi::experimental::matrix;
95 // Begin wi_element definition
96 
// Proxy for the element at index `idx` of the calling work-item's portion of
// joint matrix `M`. Reads and writes go through the SPIR-V dynamic vector
// extract/insert built-ins on `M.spvm`, using storage_element_type as the
// value type. Outside device compilation each member fails with
// "joint matrix is not supported on host.".
97 template <typename T, size_t NumRows, size_t NumCols,
100  sycl::ext::oneapi::experimental::matrix::layout::dynamic,
101  typename Group = sycl::sub_group>
102 class wi_element {
104  NumCols, Layout> &M;
105  std::size_t idx;
106
107 public:
112  Group, T, Use, NumRows, NumCols, Layout> &Mat,
113  std::size_t i)
114  : M(Mat), idx(i) {}
115
  // Returns the (row, col) position of this element, as reported by
  // __spirv_JointMatrixGetElementCoordINTEL.
116  inline __SYCL_ALWAYS_INLINE std::tuple<size_t, size_t> get_coord() {
117 #if defined(__SYCL_DEVICE_ONLY__)
118  __ocl_vec_t<uint32_t, 2> coord =
119  __spirv_JointMatrixGetElementCoordINTEL(M.spvm, idx);
120  const size_t row = coord[0];
121  const size_t col = coord[1];
122  return std::make_tuple(row, col);
123 #else
125  "joint matrix is not supported on host.");
126 #endif // __SYCL_DEVICE_ONLY__
127  }
128
  // Reads the element value as storage_element_type.
129  operator storage_element_type() {
130 #ifdef __SYCL_DEVICE_ONLY__
132  __spirv_VectorExtractDynamic<storage_element_type, T, NumRows, NumCols,
133  spv_matrix_use_traits<Use>::value,
134  spv_matrix_layout_traits<Layout>::value,
135  spv_scope_traits<Group>::value>(M.spvm,
136  idx);
137  return elem;
138 #else
140  "joint matrix is not supported on host.");
141 #endif // __SYCL_DEVICE_ONLY__
142  }
143
  // True when the element compares unequal to zero.
144  explicit operator bool() {
145 #ifdef __SYCL_DEVICE_ONLY__
146  return __spirv_VectorExtractDynamic<storage_element_type, T, NumRows,
147  NumCols,
148  spv_matrix_use_traits<Use>::value,
149  spv_matrix_layout_traits<Layout>::value,
150  spv_scope_traits<Group>::value>(
151  M.spvm, idx) != static_cast<storage_element_type>(0);
152 #else
154  "joint matrix is not supported on host.");
155 #endif // __SYCL_DEVICE_ONLY__
156  }
157
  // Writes `rhs`, converted to storage_element_type, into this element.
158  template <typename T2> wi_element &operator=(const T2 &rhs) {
159 #ifdef __SYCL_DEVICE_ONLY__
160  M.spvm = __spirv_VectorInsertDynamic(
161  M.spvm, static_cast<storage_element_type>(rhs), idx);
162  return *this;
163 #else
164  (void)rhs;
166  "joint matrix is not supported on host.");
167 #endif // __SYCL_DEVICE_ONLY__
168  }
169
  // Copies the value held by another element proxy `rhs` into this element.
170  wi_element &
172 #ifdef __SYCL_DEVICE_ONLY__
173  M.spvm = __spirv_VectorInsertDynamic(
174  M.spvm,
175  __spirv_VectorExtractDynamic<storage_element_type, T, NumRows, NumCols,
176  spv_matrix_use_traits<Use>::value,
177  spv_matrix_layout_traits<Layout>::value,
178  spv_scope_traits<Group>::value>(rhs.M.spvm,
179  rhs.idx),
180  idx);
181  return *this;
182 #else
183  (void)rhs;
185  "joint matrix is not supported on host.");
186 #endif // __SYCL_DEVICE_ONLY__
187  }
188
  // Compound assignment (+=, -=, *=, /=): computes `element op rhs` with both
  // sides cast to storage_element_type and writes the result back.
189 #if __SYCL_DEVICE_ONLY__
190 #define OP(op) \
191  template <typename T2> wi_element &operator op##=(const T2 & rhs) { \
192  M.spvm = __spirv_VectorInsertDynamic( \
193  M.spvm, \
194  static_cast<storage_element_type>( \
195  __spirv_VectorExtractDynamic< \
196  storage_element_type, T, NumRows, NumCols, \
197  spv_matrix_use_traits<Use>::value, \
198  spv_matrix_layout_traits<Layout>::value, \
199  spv_scope_traits<Group>::value>(M.spvm, idx) \
200  op static_cast<storage_element_type>(rhs)), \
201  idx); \
202  return *this; \
203  }
204 #else // __SYCL_DEVICE_ONLY__
205 #define OP(op) \
206  template <typename T2> wi_element &operator op##=(const T2 & rhs) { \
207  (void)rhs; \
208  throw exception(make_error_code(errc::runtime), \
209  "joint matrix is not supported on host."); \
210  }
211 #endif // __SYCL_DEVICE_ONLY__
212  OP(+)
213  OP(-)
214  OP(*)
215  OP(/)
216 #undef OP
217 };
218 
219 template <size_t NumRows, size_t NumCols,
222  typename Group>
223 class wi_element<sycl::ext::oneapi::bfloat16, NumRows, NumCols, Use, Layout,
224  Group> {
226  Group, sycl::ext::oneapi::bfloat16, Use, NumRows, NumCols, Layout> &M;
227  std::size_t idx;
228 
229 public:
231  Group, sycl::ext::oneapi::bfloat16, Use, NumRows, NumCols,
232  Layout> &Mat,
233  std::size_t i)
234  : M(Mat), idx(i) {}
235 
236  inline __SYCL_ALWAYS_INLINE std::tuple<uint32_t, uint32_t> get_coord() {
237 #if defined(__SYCL_DEVICE_ONLY__)
238  __ocl_vec_t<uint32_t, 2> coord =
239  __spirv_JointMatrixGetElementCoordINTEL(M.spvm, idx);
240  const uint32_t row = coord[0];
241  const uint32_t col = coord[1];
242  return std::make_tuple(row, col);
243 #else
245  "joint matrix is not supported on host.");
246 #endif // __SYCL_DEVICE_ONLY__
247  }
248 
250 #ifdef __SYCL_DEVICE_ONLY__
251  return __spirv_VectorExtractDynamic<
253  NumCols, spv_matrix_use_traits<Use>::value,
254  spv_matrix_layout_traits<Layout>::value,
255  spv_scope_traits<Group>::value>(M.spvm, idx);
256 #else
258  "joint matrix is not supported on host.");
259 #endif // __SYCL_DEVICE_ONLY__
260  }
261 
262  explicit operator bool() {
263 #ifdef __SYCL_DEVICE_ONLY__
264  return sycl::fabs(static_cast<float>(
265  __spirv_VectorExtractDynamic<
267  NumRows, NumCols, spv_matrix_use_traits<Use>::value,
268  spv_matrix_layout_traits<Layout>::value,
269  spv_scope_traits<Group>::value>(M.spvm, idx))) >=
270  std::numeric_limits<float>::epsilon();
271 #else
273  "joint matrix is not supported on host.");
274 #endif // __SYCL_DEVICE_ONLY__
275  }
276 
278 #ifdef __SYCL_DEVICE_ONLY__
279  M.spvm = __spirv_VectorInsertDynamic(M.spvm, rhs, idx);
280  return *this;
281 #else
282  (void)rhs;
284  "joint matrix is not supported on host.");
285 #endif // __SYCL_DEVICE_ONLY__
286  }
287 
289  NumCols, Use, Layout, Group> &rhs) {
290 #ifdef __SYCL_DEVICE_ONLY__
291  M.spvm = __spirv_VectorInsertDynamic(
292  M.spvm,
293  __spirv_VectorExtractDynamic<sycl::ext::oneapi::bfloat16,
295  NumCols, spv_matrix_use_traits<Use>::value,
296  spv_matrix_layout_traits<Layout>::value,
297  spv_scope_traits<Group>::value>(rhs.M.spvm,
298  rhs.idx),
299  idx);
300  return *this;
301 #else
302  (void)rhs;
304  "joint matrix is not supported on host.");
305 #endif // __SYCL_DEVICE_ONLY__
306  }
307 
308 #if __SYCL_DEVICE_ONLY__
309 #define OP(opassign, op) \
310  wi_element &operator opassign(const sycl::ext::oneapi::bfloat16 & rhs) { \
311  M.spvm = __spirv_VectorInsertDynamic( \
312  M.spvm, \
313  __spirv_VectorExtractDynamic< \
314  sycl::ext::oneapi::bfloat16, sycl::ext::oneapi::bfloat16, NumRows, \
315  NumCols, spv_matrix_use_traits<Use>::value, \
316  spv_matrix_layout_traits<Layout>::value, \
317  spv_scope_traits<Group>::value>(M.spvm, idx) op rhs, \
318  idx); \
319  return *this; \
320  }
321 #else // __SYCL_DEVICE_ONLY__
322 #define OP(opassign, op) \
323  wi_element &operator opassign(const sycl::ext::oneapi::bfloat16 & rhs) { \
324  (void)rhs; \
325  throw exception(make_error_code(errc::runtime), \
326  "joint matrix is not supported on host."); \
327  }
328 #endif // __SYCL_DEVICE_ONLY__
329  OP(+=, +)
330  OP(-=, -)
331  OP(*=, *)
332  OP(/=, /)
333 #undef OP
334 
335 #if __SYCL_DEVICE_ONLY__
336 #define OP(type, op) \
337  friend type operator op( \
338  const wi_element<sycl::ext::oneapi::bfloat16, NumRows, NumCols, Use, \
339  Layout, Group> &lhs, \
340  const sycl::ext::oneapi::bfloat16 &rhs) { \
341  return __spirv_VectorExtractDynamic< \
342  sycl::ext::oneapi::bfloat16, sycl::ext::oneapi::bfloat16, NumRows, \
343  NumCols, spv_matrix_use_traits<Use>::value, \
344  spv_matrix_layout_traits<Layout>::value, \
345  spv_scope_traits<Group>::value>(lhs.M.spvm, lhs.idx) op rhs; \
346  } \
347  friend type operator op( \
348  const sycl::ext::oneapi::bfloat16 &lhs, \
349  const wi_element<sycl::ext::oneapi::bfloat16, NumRows, NumCols, Use, \
350  Layout, Group> &rhs) { \
351  return __spirv_VectorExtractDynamic< \
352  sycl::ext::oneapi::bfloat16, sycl::ext::oneapi::bfloat16, NumRows, \
353  NumCols, spv_matrix_use_traits<Use>::value, \
354  spv_matrix_layout_traits<Layout>::value, \
355  spv_scope_traits<Group>::value>(rhs.M.spvm, rhs.idx) op lhs; \
356  }
358  OP(sycl::ext::oneapi::bfloat16, -)
359  OP(sycl::ext::oneapi::bfloat16, *)
360  OP(sycl::ext::oneapi::bfloat16, /)
361 #undef OP
362 #define OP(type, op) \
363  friend type operator op( \
364  const wi_element<sycl::ext::oneapi::bfloat16, NumRows, NumCols, Use, \
365  Layout, Group> &lhs, \
366  const sycl::ext::oneapi::bfloat16 &rhs) { \
367  return type{static_cast<float>( \
368  __spirv_VectorExtractDynamic< \
369  sycl::ext::oneapi::bfloat16, sycl::ext::oneapi::bfloat16, NumRows, \
370  NumCols, spv_matrix_use_traits<Use>::value, \
371  spv_matrix_layout_traits<Layout>::value, \
372  spv_scope_traits<Group>::value>(lhs.M.spvm, lhs.idx)) \
373  op static_cast<float>(rhs)}; \
374  } \
375  friend type operator op( \
376  const sycl::ext::oneapi::bfloat16 &lhs, \
377  const wi_element<sycl::ext::oneapi::bfloat16, NumRows, NumCols, Use, \
378  Layout, Group> &rhs) { \
379  return type{static_cast<float>( \
380  __spirv_VectorExtractDynamic< \
381  sycl::ext::oneapi::bfloat16, sycl::ext::oneapi::bfloat16, NumRows, \
382  NumCols, spv_matrix_use_traits<Use>::value, \
383  spv_matrix_layout_traits<Layout>::value, \
384  spv_scope_traits<Group>::value>(rhs.M.spvm, rhs.idx)) \
385  op static_cast<float>(lhs)}; \
386  }
387  OP(bool, ==)
388  OP(bool, !=)
389  OP(bool, <)
390  OP(bool, >)
391  OP(bool, <=)
392  OP(bool, >=)
393 #undef OP
394 #else // __SYCL_DEVICE_ONLY__
395 #define OP(type, op) \
396  friend type operator op( \
397  const wi_element<sycl::ext::oneapi::bfloat16, NumRows, NumCols, Use, \
398  Layout, Group> &, \
399  const sycl::ext::oneapi::bfloat16 &) { \
400  throw exception(make_error_code(errc::runtime), \
401  "joint matrix is not supported on host."); \
402  } \
403  friend type operator op( \
404  const sycl::ext::oneapi::bfloat16 &, \
405  const wi_element<sycl::ext::oneapi::bfloat16, NumRows, NumCols, Use, \
406  Layout, Group> &) { \
407  throw exception(make_error_code(errc::runtime), \
408  "joint matrix is not supported on host."); \
409  }
411  OP(sycl::ext::oneapi::bfloat16, -)
412  OP(sycl::ext::oneapi::bfloat16, *)
413  OP(sycl::ext::oneapi::bfloat16, /)
414  OP(bool, ==)
415  OP(bool, !=)
416  OP(bool, <)
417  OP(bool, >)
418  OP(bool, <=)
419  OP(bool, >=)
420 #undef OP
421 #endif // __SYCL_DEVICE_ONLY__
422 };
423 
424 // End wi_element definition
425 
426 // Begin wi_data definition
427 
// View over the elements of joint matrix `jm` held by the calling work-item.
// It stores a reference to `jm`, so it must not outlive the matrix. The
// constructor is private (it precedes `public:`); instances are created by
// get_wi_data, which is declared a friend for that reason.
428 template <typename Group, typename T,
431 class wi_data {
432
434  Cols, Layout> &jm;
435
437  Group, T, Use, Rows, Cols, Layout> &_jm)
438  : jm(_jm){};
439
  // Friend declaration granting get_wi_data access to the private constructor.
440  template <typename Grp, typename Type,
441  sycl::ext::oneapi::experimental::matrix::use UseJm, size_t NumRows,
442  size_t NumCols,
444  friend decltype(auto)
446  Grp, Type, UseJm, NumRows, NumCols, LayoutJm> &);
447
448 public:
  // Number of matrix elements held by the calling work-item, as reported by
  // __spirv_JointMatrixWorkItemLengthINTEL.
449  size_t length() {
450 #if __SYCL_DEVICE_ONLY__
451  return __spirv_JointMatrixWorkItemLengthINTEL(jm.spvm);
452 #else
454  "joint matrix is not supported on host.");
455 #endif
456  };
457
  // Returns a proxy for element `i`; joint_matrix_apply converts the result to
  // storage_element_type, calls get_coord() on it, and assigns to it.
458  decltype(auto) operator[](size_t i) {
460  };
461 };
462 
// Returns a wi_data view over the calling work-item's elements of `jm`.
// `sg` is unused. The returned view holds a reference to `jm` and must not
// outlive it.
463 template <typename Group, typename T,
466 inline __SYCL_ALWAYS_INLINE decltype(auto)
467 get_wi_data(Group sg, sycl::ext::oneapi::experimental::matrix::joint_matrix<
468  Group, T, Use, Rows, Cols, Layout> &jm) {
469  std::ignore = sg;
470  return wi_data(jm);
471 }
472 
473 // End wi_data definition
474 } // namespace detail
475 } // namespace oneapi
476 
477 namespace intel::experimental::matrix {
// Stores a use::a (and, presumably, use::b; the second enable_if condition
// line is elided) joint matrix `src` to `dst` through
// __spirv_JointMatrixStoreINTEL; `stride` is forwarded unchanged.
// Storing to private memory is rejected at compile time. On NVPTX devices
// this Intel-specific entry point throws; on the host it reports
// "joint matrix is not supported on host.".
478 template <
479  typename Group, typename T, typename Tp,
482  access::address_space Space, access::decorated IsDecorated,
483  std::enable_if_t<Use == sycl::ext::oneapi::experimental::matrix::use::a ||
485  bool> = true>
486 inline __SYCL_ALWAYS_INLINE void
489  Group, Tp, Use, NumRows, NumCols, Layout> &src,
490  multi_ptr<T, Space, IsDecorated> dst, size_t stride) {
491 #if defined(__SYCL_DEVICE_ONLY__)
492  static_assert(Space != access::address_space::private_space,
493  "Joint Matrix doesn't support store to private memory!");
494 #if defined(__NVPTX__)
495  std::ignore = src;
496  std::ignore = dst;
497  std::ignore = stride;
498  throw exception(
500  "This version of the matrix extension is only currently supported on "
501  "intel devices");
502 #else
503  // intel's impl
  // Strip the multi_ptr wrapper down to a raw pointer in the matching
  // (decorated) address space for the SPIR-V built-in.
504  using DecorT = typename sycl::detail::DecoratedType<T, Space>::type;
505  DecorT *Ptr = sycl::detail::getDecorated<DecorT>(dst);
506  __spirv_JointMatrixStoreINTEL<DecorT, Tp, NumRows, NumCols,
511  Ptr, src.spvm, stride,
513  Layout>::value,
515 #endif // defined(__NVPTX__)
516 #else
517  std::ignore = src;
518  std::ignore = dst;
519  std::ignore = stride;
521  "joint matrix is not supported on host.");
522 #endif // defined(__SYCL_DEVICE_ONLY__)
523 }
524 
// Overload of joint_matrix_store whose destination is an annotated_ptr: the
// raw pointer from dst.get() is passed to __spirv_JointMatrixStoreINTEL
// together with `stride`. Unlike the multi_ptr overload there is no
// private-memory static_assert. Throws on NVPTX devices; on the host it
// reports "joint matrix is not supported on host.".
525 template <
526  typename Group, typename T, typename Tp,
529  typename PropertyListT,
530  std::enable_if_t<Use == sycl::ext::oneapi::experimental::matrix::use::a ||
532  bool> = true>
534  Group,
536  Group, Tp, Use, NumRows, NumCols, Layout> &src,
538  size_t stride) {
539 #if defined(__SYCL_DEVICE_ONLY__)
540 #if defined(__NVPTX__)
541  std::ignore = src;
542  std::ignore = dst;
543  std::ignore = stride;
544  throw exception(
546  "This version of the matrix extension is only currently supported on "
547  "intel devices");
548 #else
549  // intel's impl
550  T *Ptr = dst.get();
551  __spirv_JointMatrixStoreINTEL<T, Tp, NumRows, NumCols,
556  Ptr, src.spvm, stride,
558  Layout>::value,
560 #endif // defined(__NVPTX__)
561 #else
562  std::ignore = src;
563  std::ignore = dst;
564  std::ignore = stride;
566  "joint matrix is not supported on host.");
567 #endif // defined(__SYCL_DEVICE_ONLY__)
568 }
569 
// Invokes `lambda` on every element of `jm` held by the calling work-item.
// - NVPTX: the lambda receives each jm.matrix_impl.wi_marray[i].
// - Otherwise: the lambda receives (element, row, col). `element` is a
//   storage_element_type copy of the value and (row, col) come from
//   get_coord(). The copy is written back after the call, so a lambda taking
//   its first parameter by reference can modify the matrix.
// Loop indices are size_t to match wi_data::length() / marray::size() and
// avoid a signed/unsigned comparison.
570 template <typename Group, typename T,
573  typename F>
575  Group sg,
577  Cols, Layout> &jm,
578  F &&lambda) {
579 #if defined(__SYCL_DEVICE_ONLY__)
580 #if defined(__NVPTX__)
581  std::ignore = sg;
582  for (size_t i = 0; i < jm.matrix_impl.wi_marray.size(); i++) {
583  lambda(jm.matrix_impl.wi_marray[i]);
584  }
585 #else // NVPTX
586  using storage_element_type =
588  T>::storage_element_type;
589  auto wi_data_c = sycl::ext::oneapi::detail::get_wi_data(sg, jm);
590  for (size_t i = 0; i < wi_data_c.length(); i++) {
591  storage_element_type element = wi_data_c[i];
592  auto [row, col] = wi_data_c[i].get_coord();
593  lambda(element, row, col);
594  wi_data_c[i] = element;
595  }
596 #endif
597 #else
598  std::ignore = sg;
599  std::ignore = jm;
600  std::ignore = lambda;
602  "joint matrix is not supported on host.");
603 #endif
604 }
605 
606 using namespace sycl::ext::oneapi::experimental::matrix;
607 
608 // Begin out-of-bounds API
609 
// Bounds-checked fill: replaces `Res` with a matrix built by
// __spirv_CooperativeMatrixConstructCheckedINTEL from `Value` converted to
// storage_element_type. (CoordX, CoordY) and Height x Width are forwarded to
// the built-in (note the argument order differs from this signature);
// presumably the tile origin and the extent of the full matrix — confirm
// against the SPV_INTEL_joint_matrix spec.
610 template <typename Group, typename T, size_t NumRows, size_t NumCols, use Use,
611  layout Layout, typename T2>
613  Group, joint_matrix<Group, T, Use, NumRows, NumCols, Layout> &Res,
614  const T2 &Value, size_t Height, size_t Width, size_t CoordX,
615  size_t CoordY) {
616 #if defined(__SYCL_DEVICE_ONLY__)
617  using storage_element_type =
619  T>::storage_element_type;
620  Res.spvm = __spirv_CooperativeMatrixConstructCheckedINTEL<
621  storage_element_type, T, NumRows, NumCols,
622  spv_matrix_use_traits<Use>::value,
623  spv_matrix_layout_traits<Layout>::value>(
624  CoordX, CoordY, Height, Width, static_cast<storage_element_type>(Value));
625 #else
626  std::ignore = Res;
627  std::ignore = Value;
628  std::ignore = Height;
629  std::ignore = Width;
630  std::ignore = CoordX;
631  std::ignore = CoordY;
633  "joint matrix is not supported on host.");
634 #endif // defined(__SYCL_DEVICE_ONLY__)
635 }
636 
// Bounds-checked load of a dynamic-layout accumulator `Res` from `Src`. The
// memory layout is chosen at run time by `Layout` (converted with
// joint_matrix_layout_to_spv). (CoordX, CoordY), Height, Width and Stride are
// forwarded to __spirv_CooperativeMatrixLoadCheckedINTEL. Enabled only when
// S is T without const; loading from private memory is rejected at compile
// time.
637 template <
638  typename Group, typename S, typename T, size_t NumRows, size_t NumCols,
639  access::address_space Space, access::decorated IsDecorated,
640  std::enable_if_t<std::is_same<S, std::remove_const_t<T>>::value, bool> =
641  true>
643  Group sg,
644  joint_matrix<Group, S, use::accumulator, NumRows, NumCols, layout::dynamic>
645  &Res,
646  multi_ptr<T, Space, IsDecorated> Src, size_t Stride, layout Layout,
647  size_t Height, size_t Width, size_t CoordX, size_t CoordY) {
648 #if defined(__SYCL_DEVICE_ONLY__)
649  static_assert(Space != access::address_space::private_space,
650  "Joint Matrix doesn't support load from private memory!");
651  std::ignore = sg;
652  using DecorT = typename sycl::detail::DecoratedType<T, Space>::type;
653  DecorT *Ptr = sycl::detail::getDecorated<DecorT>(Src);
654  Res.spvm = __spirv_CooperativeMatrixLoadCheckedINTEL<
655  DecorT, S, NumRows, NumCols,
656  spv_matrix_use_traits<use::accumulator>::value,
657  spv_matrix_layout_traits<layout::dynamic>::value>(
658  Ptr, CoordX, CoordY, sycl::detail::joint_matrix_layout_to_spv(Layout),
659  Height, Width, Stride);
660 #else
661  std::ignore = sg;
662  std::ignore = Res;
663  std::ignore = Src;
664  std::ignore = Stride;
665  std::ignore = Height;
666  std::ignore = Width;
667  std::ignore = Layout;
668  std::ignore = CoordX;
669  std::ignore = CoordY;
671  "joint matrix is not supported on host.");
672 #endif // defined(__SYCL_DEVICE_ONLY__)
673 }
674 
// Bounds-checked load of `Res`, whose memory layout is fixed at compile time
// by the `Layout` template argument. Enabled when S is T without const, or
// when S is precision::tf32 and T is (const) float. Loading from private
// memory is rejected at compile time.
675 template <
676  typename Group, typename S, typename T, use Use, size_t NumRows,
677  size_t NumCols, layout Layout, access::address_space Space,
678  access::decorated IsDecorated,
679  std::enable_if_t<std::is_same<S, std::remove_const_t<T>>::value ||
680  (std::is_same<S, precision::tf32>::value &&
681  std::is_same<std::remove_const_t<T>, float>::value),
682  bool> = true>
684  Group sg, joint_matrix<Group, S, Use, NumRows, NumCols, Layout> &Res,
685  multi_ptr<T, Space, IsDecorated> Src, size_t Stride, size_t Height,
686  size_t Width, size_t CoordX, size_t CoordY) {
687 #if defined(__SYCL_DEVICE_ONLY__)
688  static_assert(Space != access::address_space::private_space,
689  "Joint Matrix doesn't support load from private memory!");
690  std::ignore = sg;
691  using DecorT = typename sycl::detail::DecoratedType<T, Space>::type;
692  DecorT *Ptr = sycl::detail::getDecorated<DecorT>(Src);
693  Res.spvm = __spirv_CooperativeMatrixLoadCheckedINTEL<
694  DecorT, S, NumRows, NumCols, spv_matrix_use_traits<Use>::value,
695  spv_matrix_layout_traits<Layout>::value>(
696  Ptr, CoordX, CoordY, spv_matrix_layout_traits<Layout>::value, Height,
697  Width, Stride);
698 #else
699  std::ignore = sg;
700  std::ignore = Res;
701  std::ignore = Src;
702  std::ignore = Stride;
703  std::ignore = Height;
704  std::ignore = Width;
705  std::ignore = CoordX;
706  std::ignore = CoordY;
708  "joint matrix is not supported on host.");
709 #endif // defined(__SYCL_DEVICE_ONLY__)
710 }
711 
// Bounds-checked store of a dynamic-layout accumulator `Src` to `Dst`. The
// memory layout is chosen at run time by `Layout` (converted with
// joint_matrix_layout_to_spv). (CoordX, CoordY), Height, Width and Stride are
// forwarded to __spirv_CooperativeMatrixStoreCheckedINTEL. Storing to private
// memory is rejected at compile time. `Src` is only read, so it is taken by
// const reference (as in the use::a / use::b overload), which lets const
// accumulators be stored.
712 template <typename Group, typename T, size_t NumRows, size_t NumCols,
713  access::address_space Space, access::decorated IsDecorated>
715  Group sg,
716  const joint_matrix<Group, T, use::accumulator, NumRows, NumCols, layout::dynamic>
717  &Src,
718  multi_ptr<T, Space, IsDecorated> Dst, size_t Stride, layout Layout,
719  size_t Height, size_t Width, size_t CoordX, size_t CoordY) {
720 #if defined(__SYCL_DEVICE_ONLY__)
721  static_assert(Space != access::address_space::private_space,
722  "Joint Matrix doesn't support store to private memory!");
723  std::ignore = sg;
724  using DecorT = typename sycl::detail::DecoratedType<T, Space>::type;
725  DecorT *Ptr = sycl::detail::getDecorated<DecorT>(Dst);
726  __spirv_CooperativeMatrixStoreCheckedINTEL<
727  DecorT, T, NumRows, NumCols,
728  spv_matrix_use_traits<use::accumulator>::value,
729  spv_matrix_layout_traits<layout::dynamic>::value>(
730  Ptr, CoordX, CoordY, Src.spvm,
731  sycl::detail::joint_matrix_layout_to_spv(Layout), Height, Width, Stride);
732 #else
733  std::ignore = sg;
734  std::ignore = Src;
735  std::ignore = Dst;
736  std::ignore = Stride;
737  std::ignore = Height;
738  std::ignore = Width;
739  std::ignore = Layout;
740  std::ignore = CoordX;
741  std::ignore = CoordY;
743  "joint matrix is not supported on host.");
744 #endif // defined(__SYCL_DEVICE_ONLY__)
745 }
746 
// Bounds-checked store of a use::a or use::b matrix `Src` (layout fixed at
// compile time by `Layout`) to `Dst` through
// __spirv_CooperativeMatrixStoreCheckedINTEL. Storing to private memory is
// rejected at compile time.
747 template <typename Group, typename T, typename Tp, use Use, size_t NumRows,
748  size_t NumCols, layout Layout, access::address_space Space,
749  access::decorated IsDecorated,
750  std::enable_if_t<Use == use::a || Use == use::b, bool> = true>
752  Group sg, const joint_matrix<Group, Tp, Use, NumRows, NumCols, Layout> &Src,
753  multi_ptr<T, Space, IsDecorated> Dst, size_t Stride, size_t Height,
754  size_t Width, size_t CoordX, size_t CoordY) {
755 #if defined(__SYCL_DEVICE_ONLY__)
756  static_assert(Space != access::address_space::private_space,
757  "Joint Matrix doesn't support store to private memory!");
758  std::ignore = sg;
759  using DecorT = typename sycl::detail::DecoratedType<T, Space>::type;
760  DecorT *Ptr = sycl::detail::getDecorated<DecorT>(Dst);
761  __spirv_CooperativeMatrixStoreCheckedINTEL<
762  DecorT, Tp, NumRows, NumCols, spv_matrix_use_traits<Use>::value,
763  spv_matrix_layout_traits<Layout>::value>(
764  Ptr, CoordX, CoordY, Src.spvm, spv_matrix_layout_traits<Layout>::value,
765  Height, Width, Stride);
766 #else
767  std::ignore = sg;
768  std::ignore = Src;
769  std::ignore = Dst;
770  std::ignore = Stride;
771  std::ignore = Height;
772  std::ignore = Width;
773  std::ignore = CoordX;
774  std::ignore = CoordY;
776  "joint matrix is not supported on host.");
777 #endif // defined(__SYCL_DEVICE_ONLY__)
778 }
779 
780 // Annotated pointer overloads:
// annotated_ptr overload of the bounds-checked dynamic-layout accumulator
// load: the raw pointer from Src.get() is passed to
// __spirv_CooperativeMatrixLoadCheckedINTEL, with the run-time `Layout`
// converted by joint_matrix_layout_to_spv. Enabled only when S is T without
// const.
781 template <typename Group, typename S, typename T, size_t NumRows,
782  size_t NumCols, typename PropertyListT,
783  std::enable_if_t<std::is_same<S, std::remove_const_t<T>>::value,
784  bool> = true>
786  Group sg,
787  joint_matrix<Group, S, use::accumulator, NumRows, NumCols, layout::dynamic>
788  &Res,
790  size_t Stride, layout Layout, size_t Height, size_t Width, size_t CoordX,
791  size_t CoordY) {
792 #if defined(__SYCL_DEVICE_ONLY__)
793  std::ignore = sg;
794  T *Ptr = Src.get();
795  Res.spvm = __spirv_CooperativeMatrixLoadCheckedINTEL<
796  T, S, NumRows, NumCols, spv_matrix_use_traits<use::accumulator>::value,
797  spv_matrix_layout_traits<layout::dynamic>::value>(
798  Ptr, CoordX, CoordY, sycl::detail::joint_matrix_layout_to_spv(Layout),
799  Height, Width, Stride);
800 #else
801  std::ignore = sg;
802  std::ignore = Res;
803  std::ignore = Src;
804  std::ignore = Stride;
805  std::ignore = Height;
806  std::ignore = Width;
807  std::ignore = Layout;
808  std::ignore = CoordX;
809  std::ignore = CoordY;
811  "joint matrix is not supported on host.");
812 #endif // defined(__SYCL_DEVICE_ONLY__)
813 }
814 
// annotated_ptr overload of the bounds-checked load for a matrix whose layout
// is fixed at compile time by `Layout`. Enabled when S is T without const, or
// when S is precision::tf32 and T is (const) float.
815 template <
816  typename Group, typename S, typename T, use Use, size_t NumRows,
817  size_t NumCols, layout Layout, typename PropertyListT,
818  std::enable_if_t<std::is_same<S, std::remove_const_t<T>>::value ||
819  (std::is_same<S, precision::tf32>::value &&
820  std::is_same<std::remove_const_t<T>, float>::value),
821  bool> = true>
823  Group sg, joint_matrix<Group, S, Use, NumRows, NumCols, Layout> &Res,
825  size_t Stride, size_t Height, size_t Width, size_t CoordX, size_t CoordY) {
826 #if defined(__SYCL_DEVICE_ONLY__)
827  std::ignore = sg;
828  T *Ptr = Src.get();
829  Res.spvm = __spirv_CooperativeMatrixLoadCheckedINTEL<
830  T, S, NumRows, NumCols, spv_matrix_use_traits<Use>::value,
831  spv_matrix_layout_traits<Layout>::value>(
832  Ptr, CoordX, CoordY, spv_matrix_layout_traits<Layout>::value, Height,
833  Width, Stride);
834 #else
835  std::ignore = sg;
836  std::ignore = Res;
837  std::ignore = Src;
838  std::ignore = Stride;
839  std::ignore = Height;
840  std::ignore = Width;
841  std::ignore = CoordX;
842  std::ignore = CoordY;
844  "joint matrix is not supported on host.");
845 #endif // defined(__SYCL_DEVICE_ONLY__)
846 }
847 
// annotated_ptr overload of the bounds-checked dynamic-layout accumulator
// store: the raw pointer from Dst.get() is passed to
// __spirv_CooperativeMatrixStoreCheckedINTEL, with the run-time `Layout`
// converted by joint_matrix_layout_to_spv. `Src` is only read, so it is taken
// by const reference (as in the use::a / use::b overload), which lets const
// accumulators be stored.
848 template <typename Group, typename T, size_t NumRows, size_t NumCols,
849  typename PropertyListT>
851  Group sg,
852  const joint_matrix<Group, T, use::accumulator, NumRows, NumCols, layout::dynamic>
853  &Src,
855  size_t Stride, layout Layout, size_t Height, size_t Width, size_t CoordX,
856  size_t CoordY) {
857 #if defined(__SYCL_DEVICE_ONLY__)
858  std::ignore = sg;
859  T *Ptr = Dst.get();
860  __spirv_CooperativeMatrixStoreCheckedINTEL<
861  T, T, NumRows, NumCols, spv_matrix_use_traits<use::accumulator>::value,
862  spv_matrix_layout_traits<layout::dynamic>::value>(
863  Ptr, CoordX, CoordY, Src.spvm,
864  sycl::detail::joint_matrix_layout_to_spv(Layout), Height, Width, Stride);
865 #else
866  std::ignore = sg;
867  std::ignore = Src;
868  std::ignore = Dst;
869  std::ignore = Stride;
870  std::ignore = Height;
871  std::ignore = Width;
872  std::ignore = Layout;
873  std::ignore = CoordX;
874  std::ignore = CoordY;
876  "joint matrix is not supported on host.");
877 #endif // defined(__SYCL_DEVICE_ONLY__)
878 }
879 
// annotated_ptr overload of the bounds-checked store for a use::a or use::b
// matrix `Src` (layout fixed at compile time by `Layout`): the raw pointer
// from Dst.get() is passed to __spirv_CooperativeMatrixStoreCheckedINTEL.
880 template <typename Group, typename T, typename Tp, use Use, size_t NumRows,
881  size_t NumCols, layout Layout, typename PropertyListT,
882  std::enable_if_t<Use == use::a || Use == use::b, bool> = true>
884  Group sg, const joint_matrix<Group, Tp, Use, NumRows, NumCols, Layout> &Src,
886  size_t Stride, size_t Height, size_t Width, size_t CoordX, size_t CoordY) {
887 #if defined(__SYCL_DEVICE_ONLY__)
888  std::ignore = sg;
889  T *Ptr = Dst.get();
890  __spirv_CooperativeMatrixStoreCheckedINTEL<
891  T, Tp, NumRows, NumCols, spv_matrix_use_traits<Use>::value,
892  spv_matrix_layout_traits<Layout>::value>(
893  Ptr, CoordX, CoordY, Src.spvm, spv_matrix_layout_traits<Layout>::value,
894  Height, Width, Stride);
895 #else
896  std::ignore = sg;
897  std::ignore = Src;
898  std::ignore = Dst;
899  std::ignore = Stride;
900  std::ignore = Height;
901  std::ignore = Width;
902  std::ignore = CoordX;
903  std::ignore = CoordY;
905  "joint matrix is not supported on host.");
906 #endif // defined(__SYCL_DEVICE_ONLY__)
907 }
908 // End out-of-bounds API
909 
910 } // namespace intel::experimental::matrix
911 
912 } // namespace ext
913 } // namespace _V1
914 } // namespace sycl
wi_element(sycl::ext::oneapi::experimental::matrix::joint_matrix< Group, sycl::ext::oneapi::bfloat16, Use, NumRows, NumCols, Layout > &Mat, std::size_t i)
wi_element & operator=(const wi_element< sycl::ext::oneapi::bfloat16, NumRows, NumCols, Use, Layout, Group > &rhs)
wi_element & operator=(const wi_element< T, NumRows, NumCols, Use, Layout, Group > &rhs)
wi_element(sycl::ext::oneapi::experimental::matrix::joint_matrix< Group, T, Use, NumRows, NumCols, Layout > &Mat, std::size_t i)
wi_element & operator=(const T2 &rhs)
typename oneapi::detail::jm_type_interpretation_helper_trait< T >::storage_element_type storage_element_type
__SYCL_ALWAYS_INLINE std::tuple< size_t, size_t > get_coord()
#define __SYCL_ALWAYS_INLINE
#define SPV_MATRIX_USE_TRAITS(USE, SPV_USE)
#define SPV_MATRIX_LAYOUT_TRAITS(LAYOUT, SPV_LAYOUT)
#define OP(op)
__SYCL_ALWAYS_INLINE __spv::MatrixLayout joint_matrix_layout_to_spv(sycl::ext::oneapi::experimental::matrix::layout Layout)
constexpr tuple< Ts... > make_tuple(Ts... Args)
Definition: tuple.hpp:35
sycl::ext::oneapi::bfloat16 bfloat16
__SYCL_ALWAYS_INLINE void joint_matrix_store(Group, const sycl::ext::oneapi::experimental::matrix::joint_matrix< Group, Tp, Use, NumRows, NumCols, Layout > &src, ext::oneapi::experimental::annotated_ptr< T, PropertyListT > dst, size_t stride)
__SYCL_ALWAYS_INLINE void joint_matrix_load_checked(Group sg, joint_matrix< Group, S, Use, NumRows, NumCols, Layout > &Res, ext::oneapi::experimental::annotated_ptr< T, PropertyListT > Src, size_t Stride, size_t Height, size_t Width, size_t CoordX, size_t CoordY)
__SYCL_ALWAYS_INLINE void joint_matrix_apply(Group sg, sycl::ext::oneapi::experimental::matrix::joint_matrix< Group, T, Use, Rows, Cols, Layout > &jm, F &&lambda)
__SYCL_ALWAYS_INLINE void joint_matrix_store_checked(Group sg, const joint_matrix< Group, Tp, Use, NumRows, NumCols, Layout > &Src, ext::oneapi::experimental::annotated_ptr< T, PropertyListT > Dst, size_t Stride, size_t Height, size_t Width, size_t CoordX, size_t CoordY)
__SYCL_ALWAYS_INLINE void joint_matrix_fill_checked(Group, joint_matrix< Group, T, Use, NumRows, NumCols, Layout > &Res, const T2 &Value, size_t Height, size_t Width, size_t CoordX, size_t CoordY)
decltype(auto) __SYCL_ALWAYS_INLINE get_wi_data(Group sg, sycl::ext::oneapi::experimental::matrix::joint_matrix< Group, T, Use, Rows, Cols, Layout > &jm)
std::enable_if_t< detail::is_bf16_storage_type< T >::value, T > fabs(T x)
std::error_code make_error_code(sycl::errc E) noexcept
Constructs an error code using e and sycl_category()
Definition: exception.cpp:64
Definition: access.hpp:18