DPC++ Runtime
Runtime libraries for oneAPI DPC++
matrix-intel.hpp
Go to the documentation of this file.
1 //==------------------ matrix-intel.hpp - SYCL matrix ----------*- C++ -*---==//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 // ===--------------------------------------------------------------------=== //
8 
9 #pragma once
10 
11 #include "matrix-unified-utils.hpp" // for use, layout, tf32, matrix
12 #include "utils.hpp" // for getDecorated
13 
14 #include <CL/__spirv/spirv_types.hpp> // for MatrixLayout, MatrixUse
15 #include <sycl/access/access.hpp> // for address_space, decorated
16 #include <sycl/builtins.hpp> // for fabs
17 #include <sycl/detail/defines_elementary.hpp> // for __SYCL_ALWAYS_INLINE
18 #include <sycl/exception.hpp>
19 #include <sycl/ext/oneapi/bfloat16.hpp> // for bfloat16
21 #include <sycl/group.hpp> // for group
22 #include <sycl/multi_ptr.hpp> // for multi_ptr
23 #include <sycl/sub_group.hpp> // for sub_group
24 
25 #include <cstddef> // for size_t
26 #include <stdint.h> // for uint32_t
27 #include <tuple> // for ignore, tuple, _Swallo...
28 #include <type_traits> // for enable_if_t
29 
30 namespace sycl {
31 inline namespace _V1 {
32 namespace ext {
33 namespace oneapi {
34 namespace experimental {
35 namespace matrix {
36 
37 template <layout Layout> struct spv_matrix_layout_traits {
39 };
40 
// Defines the spv_matrix_layout_traits specialization for the SYCL layout
// enumerator SYCL_LAYOUT; its `value` member is the SPIR-V layout SPIRV_LAYOUT.
#define SPV_MATRIX_LAYOUT_TRAITS(SYCL_LAYOUT, SPIRV_LAYOUT)                    \
  template <> struct spv_matrix_layout_traits<SYCL_LAYOUT> {                   \
    constexpr static __spv::MatrixLayout value = SPIRV_LAYOUT;                 \
  };
45 
50 
51 template <use Use> struct spv_matrix_use_traits {
52  static constexpr __spv::MatrixUse value = __spv::MatrixUse::MatrixA;
53 };
54 
// Defines the spv_matrix_use_traits specialization for the SYCL use enumerator
// SYCL_USE; its `value` member is the SPIR-V matrix use SPIRV_USE.
#define SPV_MATRIX_USE_TRAITS(SYCL_USE, SPIRV_USE)                             \
  template <> struct spv_matrix_use_traits<SYCL_USE> {                         \
    constexpr static __spv::MatrixUse value = SPIRV_USE;                       \
  };
59 
63 
64 template <typename G> struct spv_scope_traits {};
65 template <> struct spv_scope_traits<sycl::sub_group> {
66  constexpr static auto value = __spv::Scope::Subgroup;
67 };
68 template <int D> struct spv_scope_traits<sycl::group<D>> {
69  constexpr static auto value = __spv::Scope::Workgroup;
70 };
71 
72 // forward declarations
73 template <typename Group, typename T, use Use, size_t Rows, size_t Cols,
74  layout Layout>
75 struct joint_matrix;
76 
77 } // namespace matrix
78 } // namespace experimental
79 
80 namespace detail {
81 // Differentiating between the "element type" and the "storage element type"
82 template <typename T> struct jm_type_interpretation_helper_trait {
83  using element_type = T;
85 };
86 
87 template <>
91  using storage_element_type = float;
92 };
93 
94 using namespace sycl::ext::oneapi::experimental::matrix;
95 // Begin wi_element definition
96 
97 template <typename T, size_t NumRows, size_t NumCols,
100  sycl::ext::oneapi::experimental::matrix::layout::dynamic,
101  typename Group = sycl::sub_group>
102 class wi_element {
104  NumCols, Layout> &M;
105  std::size_t idx;
106 
107 public:
112  Group, T, Use, NumRows, NumCols, Layout> &Mat,
113  std::size_t i)
114  : M(Mat), idx(i) {}
115 
116  inline __SYCL_ALWAYS_INLINE std::tuple<size_t, size_t> get_coord() {
117 #if defined(__SYCL_DEVICE_ONLY__)
118 #ifndef __SPIRV_USE_COOPERATIVE_MATRIX
119  __ocl_vec_t<uint32_t, 2> coord =
120  __spirv_JointMatrixGetElementCoordINTEL(M.spvm, idx);
121 #else
122  __ocl_vec_t<uint32_t, 2> coord =
123  __spirv_JointMatrixGetElementCoordINTEL(M.spvm, idx);
124 #endif // __SPIRV_USE_COOPERATIVE_MATRIX
125  const size_t row = coord[0];
126  const size_t col = coord[1];
127  return std::make_tuple(row, col);
128 #else
130  "joint matrix is not supported on host.");
131 #endif // __SYCL_DEVICE_ONLY__
132  }
133 
134  operator storage_element_type() {
135 #ifdef __SYCL_DEVICE_ONLY__
136 #ifndef __SPIRV_USE_COOPERATIVE_MATRIX
138  __spirv_VectorExtractDynamic<storage_element_type, T, NumRows, NumCols,
139  spv_matrix_use_traits<Use>::value,
140  spv_matrix_layout_traits<Layout>::value,
141  spv_scope_traits<Group>::value>(M.spvm,
142  idx);
143 #else
144  storage_element_type *ExtractP =
145  __spirv_AccessChain<storage_element_type, T, NumRows, NumCols,
146  spv_matrix_use_traits<Use>::value,
147  spv_scope_traits<Group>::value>(&M.spvm, idx);
148  storage_element_type elem = *ExtractP;
149 #endif // __SPIRV_USE_COOPERATIVE_MATRIX
150  return elem;
151 #else
153  "joint matrix is not supported on host.");
154 #endif // __SYCL_DEVICE_ONLY__
155  }
156 
157  explicit operator bool() {
158 #ifdef __SYCL_DEVICE_ONLY__
159 #ifndef __SPIRV_USE_COOPERATIVE_MATRIX
160  return __spirv_VectorExtractDynamic<storage_element_type, T, NumRows,
161  NumCols,
162  spv_matrix_use_traits<Use>::value,
163  spv_matrix_layout_traits<Layout>::value,
164  spv_scope_traits<Group>::value>(
165  M.spvm, idx) != static_cast<storage_element_type>(0);
166 #else
167  storage_element_type *ExtractP =
168  __spirv_AccessChain<storage_element_type, T, NumRows, NumCols,
169  spv_matrix_use_traits<Use>::value,
170  spv_scope_traits<Group>::value>(&M.spvm, idx);
171  return *ExtractP != static_cast<storage_element_type>(0);
172 #endif // __SPIRV_USE_COOPERATIVE_MATRIX
173 #else
175  "joint matrix is not supported on host.");
176 #endif // __SYCL_DEVICE_ONLY__
177  }
178 
179  template <typename T2> wi_element &operator=(const T2 &rhs) {
180 #ifdef __SYCL_DEVICE_ONLY__
181 #ifndef __SPIRV_USE_COOPERATIVE_MATRIX
182  M.spvm = __spirv_VectorInsertDynamic(
183  M.spvm, static_cast<storage_element_type>(rhs), idx);
184 #else
185  storage_element_type *InsertP =
186  __spirv_AccessChain<storage_element_type, T, NumRows, NumCols,
187  spv_matrix_use_traits<Use>::value,
188  spv_scope_traits<Group>::value>(&M.spvm, idx);
189  *InsertP = static_cast<storage_element_type>(rhs);
190 #endif // __SPIRV_USE_COOPERATIVE_MATRIX
191  return *this;
192 #else
193  (void)rhs;
195  "joint matrix is not supported on host.");
196 #endif // __SYCL_DEVICE_ONLY__
197  }
198 
199  wi_element &
201 #ifdef __SYCL_DEVICE_ONLY__
202 #ifndef __SPIRV_USE_COOPERATIVE_MATRIX
203  M.spvm = __spirv_VectorInsertDynamic(
204  M.spvm,
205  __spirv_VectorExtractDynamic<storage_element_type, T, NumRows, NumCols,
206  spv_matrix_use_traits<Use>::value,
207  spv_matrix_layout_traits<Layout>::value,
208  spv_scope_traits<Group>::value>(rhs.M.spvm,
209  rhs.idx),
210  idx);
211 #else
212  storage_element_type *ExtractP =
213  __spirv_AccessChain<storage_element_type, T, NumRows, NumCols,
214  spv_matrix_use_traits<Use>::value,
215  spv_scope_traits<Group>::value>(&rhs.M.spvm,
216  rhs.idx);
217  storage_element_type *InsertP =
218  __spirv_AccessChain<storage_element_type, T, NumRows, NumCols,
219  spv_matrix_use_traits<Use>::value,
220  spv_scope_traits<Group>::value>(&M.spvm, idx);
221  *InsertP = *ExtractP;
222 #endif // __SPIRV_USE_COOPERATIVE_MATRIX
223  return *this;
224 #else
225  (void)rhs;
227  "joint matrix is not supported on host.");
228 #endif // __SYCL_DEVICE_ONLY__
229  }
230 
231 #if __SYCL_DEVICE_ONLY__
232 #ifndef __SPIRV_USE_COOPERATIVE_MATRIX
233 #define OP(op) \
234  template <typename T2> wi_element &operator op##=(const T2 & rhs) { \
235  M.spvm = __spirv_VectorInsertDynamic( \
236  M.spvm, \
237  static_cast<storage_element_type>( \
238  __spirv_VectorExtractDynamic< \
239  storage_element_type, T, NumRows, NumCols, \
240  spv_matrix_use_traits<Use>::value, \
241  spv_matrix_layout_traits<Layout>::value, \
242  spv_scope_traits<Group>::value>(M.spvm, idx) \
243  op static_cast<storage_element_type>(rhs)), \
244  idx); \
245  return *this; \
246  }
247 #else // __SPIRV_USE_COOPERATIVE_MATRIX
248 #define OP(op) \
249  template <typename T2> wi_element &operator op##=(const T2 & rhs) { \
250  storage_element_type *ExtractP = \
251  __spirv_AccessChain<storage_element_type, T, NumRows, NumCols, \
252  spv_matrix_use_traits<Use>::value, \
253  spv_scope_traits<Group>::value>(&rhs.M.spvm, \
254  rhs.idx); \
255  storage_element_type *InsertP = \
256  __spirv_AccessChain<storage_element_type, T, NumRows, NumCols, \
257  spv_matrix_use_traits<Use>::value, \
258  spv_scope_traits<Group>::value>(&M.spvm, idx); \
259  *InsertP = *ExtractP op static_cast<storage_element_type>(rhs); \
260  return *this; \
261  }
262 #endif // __SPIRV_USE_COOPERATIVE_MATRIX
263 #else // __SYCL_DEVICE_ONLY__
264 #define OP(op) \
265  template <typename T2> wi_element &operator op##=(const T2 & rhs) { \
266  (void)rhs; \
267  throw exception(make_error_code(errc::runtime), \
268  "joint matrix is not supported on host."); \
269  }
270 #endif // __SYCL_DEVICE_ONLY__
271  OP(+)
272  OP(-)
273  OP(*)
274  OP(/)
275 #undef OP
276 };
277 
278 template <size_t NumRows, size_t NumCols,
281  typename Group>
282 class wi_element<sycl::ext::oneapi::bfloat16, NumRows, NumCols, Use, Layout,
283  Group> {
285  Group, sycl::ext::oneapi::bfloat16, Use, NumRows, NumCols, Layout> &M;
286  std::size_t idx;
287 
288 public:
290  Group, sycl::ext::oneapi::bfloat16, Use, NumRows, NumCols,
291  Layout> &Mat,
292  std::size_t i)
293  : M(Mat), idx(i) {}
294 
295  inline __SYCL_ALWAYS_INLINE std::tuple<uint32_t, uint32_t> get_coord() {
296 #if defined(__SYCL_DEVICE_ONLY__)
297 #ifndef __SPIRV_USE_COOPERATIVE_MATRIX
298  __ocl_vec_t<uint32_t, 2> coord =
299  __spirv_JointMatrixGetElementCoordINTEL(M.spvm, idx);
300 #else
301  __ocl_vec_t<uint32_t, 2> coord =
302  __spirv_JointMatrixGetElementCoordINTEL(M.spvm, idx);
303 #endif // __SPIRV_USE_COOPERATIVE_MATRIX
304  const uint32_t row = coord[0];
305  const uint32_t col = coord[1];
306  return std::make_tuple(row, col);
307 #else
309  "joint matrix is not supported on host.");
310 #endif // __SYCL_DEVICE_ONLY__
311  }
312 
314 #ifdef __SYCL_DEVICE_ONLY__
315 #ifndef __SPIRV_USE_COOPERATIVE_MATRIX
316  return __spirv_VectorExtractDynamic<
318  NumCols, spv_matrix_use_traits<Use>::value,
319  spv_matrix_layout_traits<Layout>::value,
320  spv_scope_traits<Group>::value>(M.spvm, idx);
321 #else
322  sycl::ext::oneapi::bfloat16 *ExtractP =
323  __spirv_AccessChain<sycl::ext::oneapi::bfloat16,
324  sycl::ext::oneapi::bfloat16, NumRows, NumCols,
325  spv_matrix_use_traits<Use>::value,
326  spv_scope_traits<Group>::value>(&M.spvm, idx);
327  return *ExtractP;
328 #endif // __SPIRV_USE_COOPERATIVE_MATRIX
329 #else
331  "joint matrix is not supported on host.");
332 #endif // __SYCL_DEVICE_ONLY__
333  }
334 
335  explicit operator bool() {
336 #ifdef __SYCL_DEVICE_ONLY__
337 #ifndef __SPIRV_USE_COOPERATIVE_MATRIX
338  return sycl::fabs(static_cast<float>(
339  __spirv_VectorExtractDynamic<
341  NumRows, NumCols, spv_matrix_use_traits<Use>::value,
342  spv_matrix_layout_traits<Layout>::value,
343  spv_scope_traits<Group>::value>(M.spvm, idx))) >=
344  std::numeric_limits<float>::epsilon();
345 #else
346  sycl::ext::oneapi::bfloat16 *ExtractP =
347  __spirv_AccessChain<sycl::ext::oneapi::bfloat16,
348  sycl::ext::oneapi::bfloat16, NumRows, NumCols,
349  spv_matrix_use_traits<Use>::value,
350  spv_scope_traits<Group>::value>(&M.spvm, idx);
351  sycl::ext::oneapi::bfloat16 Elem = *ExtractP;
352  return sycl::fabs(static_cast<float>(Elem)) >=
353  std::numeric_limits<float>::epsilon();
354 #endif // __SPIRV_USE_COOPERATIVE_MATRIX
355 #else
357  "joint matrix is not supported on host.");
358 #endif // __SYCL_DEVICE_ONLY__
359  }
360 
362 #ifdef __SYCL_DEVICE_ONLY__
363 #ifndef __SPIRV_USE_COOPERATIVE_MATRIX
364  M.spvm = __spirv_VectorInsertDynamic(M.spvm, rhs, idx);
365 #else
366  sycl::ext::oneapi::bfloat16 *InsertP =
367  __spirv_AccessChain<sycl::ext::oneapi::bfloat16,
368  sycl::ext::oneapi::bfloat16, NumRows, NumCols,
369  spv_matrix_use_traits<Use>::value,
370  spv_scope_traits<Group>::value>(&M.spvm, idx);
371  *InsertP = rhs;
372 #endif // __SPIRV_USE_COOPERATIVE_MATRIX
373  return *this;
374 #else
375  (void)rhs;
377  "joint matrix is not supported on host.");
378 #endif // __SYCL_DEVICE_ONLY__
379  }
380 
382  NumCols, Use, Layout, Group> &rhs) {
383 #ifdef __SYCL_DEVICE_ONLY__
384 #ifndef __SPIRV_USE_COOPERATIVE_MATRIX
385  M.spvm = __spirv_VectorInsertDynamic(
386  M.spvm,
387  __spirv_VectorExtractDynamic<sycl::ext::oneapi::bfloat16,
389  NumCols, spv_matrix_use_traits<Use>::value,
390  spv_matrix_layout_traits<Layout>::value,
391  spv_scope_traits<Group>::value>(rhs.M.spvm,
392  rhs.idx),
393  idx);
394  return *this;
395 #else
396  sycl::ext::oneapi::bfloat16 *ExtractP =
397  __spirv_AccessChain<sycl::ext::oneapi::bfloat16,
398  sycl::ext::oneapi::bfloat16, NumRows, NumCols,
399  spv_matrix_use_traits<Use>::value,
400  spv_scope_traits<Group>::value>(&rhs.M.spvm,
401  rhs.idx);
402  sycl::ext::oneapi::bfloat16 *InsertP =
403  __spirv_AccessChain<sycl::ext::oneapi::bfloat16,
404  sycl::ext::oneapi::bfloat16, NumRows, NumCols,
405  spv_matrix_use_traits<Use>::value,
406  spv_scope_traits<Group>::value>(&M.spvm, idx);
407  *InsertP = *ExtractP;
408  return *this;
409 #endif // __SPIRV_USE_COOPERATIVE_MATRIX
410 #else
411  (void)rhs;
413  "joint matrix is not supported on host.");
414 #endif // __SYCL_DEVICE_ONLY__
415  }
416 
417 #if __SYCL_DEVICE_ONLY__
418 #ifndef __SPIRV_USE_COOPERATIVE_MATRIX
419 #define OP(opassign, op) \
420  wi_element &operator opassign(const sycl::ext::oneapi::bfloat16 & rhs) { \
421  M.spvm = __spirv_VectorInsertDynamic( \
422  M.spvm, \
423  __spirv_VectorExtractDynamic< \
424  sycl::ext::oneapi::bfloat16, sycl::ext::oneapi::bfloat16, NumRows, \
425  NumCols, spv_matrix_use_traits<Use>::value, \
426  spv_matrix_layout_traits<Layout>::value, \
427  spv_scope_traits<Group>::value>(M.spvm, idx) op rhs, \
428  idx); \
429  return *this; \
430  }
431 #else
432 #define OP(opassign, op) \
433  wi_element &operator opassign(const sycl::ext::oneapi::bfloat16 & rhs) { \
434  sycl::ext::oneapi::bfloat16 *ExtractP = \
435  __spirv_AccessChain<sycl::ext::oneapi::bfloat16, \
436  sycl::ext::oneapi::bfloat16, NumRows, NumCols, \
437  spv_matrix_use_traits<Use>::value, \
438  spv_scope_traits<Group>::value>(&M.spvm, idx); \
439  sycl::ext::oneapi::bfloat16 *InsertP = \
440  __spirv_AccessChain<sycl::ext::oneapi::bfloat16, \
441  sycl::ext::oneapi::bfloat16, NumRows, NumCols, \
442  spv_matrix_use_traits<Use>::value, \
443  spv_scope_traits<Group>::value>(&M.spvm, idx); \
444  *InsertP = *ExtractP op rhs; \
445  return *this; \
446  }
447 #endif // __SPIRV_USE_COOPERATIVE_MATRIX
448 #else // __SYCL_DEVICE_ONLY__
449 #define OP(opassign, op) \
450  wi_element &operator opassign(const sycl::ext::oneapi::bfloat16 & rhs) { \
451  (void)rhs; \
452  throw exception(make_error_code(errc::runtime), \
453  "joint matrix is not supported on host."); \
454  }
455 #endif // __SYCL_DEVICE_ONLY__
456  OP(+=, +)
457  OP(-=, -)
458  OP(*=, *)
459  OP(/=, /)
460 #undef OP
461 
462 #if __SYCL_DEVICE_ONLY__
463 #ifndef __SPIRV_USE_COOPERATIVE_MATRIX
464 #define OP(type, op) \
465  friend type operator op( \
466  const wi_element<sycl::ext::oneapi::bfloat16, NumRows, NumCols, Use, \
467  Layout, Group> &lhs, \
468  const sycl::ext::oneapi::bfloat16 &rhs) { \
469  return __spirv_VectorExtractDynamic< \
470  sycl::ext::oneapi::bfloat16, sycl::ext::oneapi::bfloat16, NumRows, \
471  NumCols, spv_matrix_use_traits<Use>::value, \
472  spv_matrix_layout_traits<Layout>::value, \
473  spv_scope_traits<Group>::value>(lhs.M.spvm, lhs.idx) op rhs; \
474  } \
475  friend type operator op( \
476  const sycl::ext::oneapi::bfloat16 &lhs, \
477  const wi_element<sycl::ext::oneapi::bfloat16, NumRows, NumCols, Use, \
478  Layout, Group> &rhs) { \
479  return __spirv_VectorExtractDynamic< \
480  sycl::ext::oneapi::bfloat16, sycl::ext::oneapi::bfloat16, NumRows, \
481  NumCols, spv_matrix_use_traits<Use>::value, \
482  spv_matrix_layout_traits<Layout>::value, \
483  spv_scope_traits<Group>::value>(rhs.M.spvm, rhs.idx) op lhs; \
484  }
485 #else
486 #define OP(type, op) \
487  friend type operator op( \
488  const wi_element<sycl::ext::oneapi::bfloat16, NumRows, NumCols, Use, \
489  Layout, Group> &lhs, \
490  const sycl::ext::oneapi::bfloat16 &rhs) { \
491  sycl::ext::oneapi::bfloat16 *ExtractP = \
492  __spirv_AccessChain<sycl::ext::oneapi::bfloat16, \
493  sycl::ext::oneapi::bfloat16, NumRows, NumCols, \
494  spv_matrix_use_traits<Use>::value, \
495  spv_scope_traits<Group>::value>(&lhs.M.spvm, \
496  lhs.idx); \
497  return *ExtractP op rhs; \
498  } \
499  friend type operator op( \
500  const sycl::ext::oneapi::bfloat16 &lhs, \
501  const wi_element<sycl::ext::oneapi::bfloat16, NumRows, NumCols, Use, \
502  Layout, Group> &rhs) { \
503  sycl::ext::oneapi::bfloat16 *ExtractP = \
504  __spirv_AccessChain<sycl::ext::oneapi::bfloat16, \
505  sycl::ext::oneapi::bfloat16, NumRows, NumCols, \
506  spv_matrix_use_traits<Use>::value, \
507  spv_scope_traits<Group>::value>(&rhs.M.spvm, \
508  rhs.idx); \
509  return *ExtractP op lhs; \
510  }
511 #endif // __SPIRV_USE_COOPERATIVE_MATRIX
513  OP(sycl::ext::oneapi::bfloat16, -)
514  OP(sycl::ext::oneapi::bfloat16, *)
515  OP(sycl::ext::oneapi::bfloat16, /)
516 #undef OP
517 #ifndef __SPIRV_USE_COOPERATIVE_MATRIX
518 #define OP(type, op) \
519  friend type operator op( \
520  const wi_element<sycl::ext::oneapi::bfloat16, NumRows, NumCols, Use, \
521  Layout, Group> &lhs, \
522  const sycl::ext::oneapi::bfloat16 &rhs) { \
523  return type{static_cast<float>( \
524  __spirv_VectorExtractDynamic< \
525  sycl::ext::oneapi::bfloat16, sycl::ext::oneapi::bfloat16, NumRows, \
526  NumCols, spv_matrix_use_traits<Use>::value, \
527  spv_matrix_layout_traits<Layout>::value, \
528  spv_scope_traits<Group>::value>(lhs.M.spvm, lhs.idx)) \
529  op static_cast<float>(rhs)}; \
530  } \
531  friend type operator op( \
532  const sycl::ext::oneapi::bfloat16 &lhs, \
533  const wi_element<sycl::ext::oneapi::bfloat16, NumRows, NumCols, Use, \
534  Layout, Group> &rhs) { \
535  return type{static_cast<float>( \
536  __spirv_VectorExtractDynamic< \
537  sycl::ext::oneapi::bfloat16, sycl::ext::oneapi::bfloat16, NumRows, \
538  NumCols, spv_matrix_use_traits<Use>::value, \
539  spv_matrix_layout_traits<Layout>::value, \
540  spv_scope_traits<Group>::value>(rhs.M.spvm, rhs.idx)) \
541  op static_cast<float>(lhs)}; \
542  }
543 #else
544 #define OP(type, op) \
545  friend type operator op( \
546  const wi_element<sycl::ext::oneapi::bfloat16, NumRows, NumCols, Use, \
547  Layout, Group> &lhs, \
548  const sycl::ext::oneapi::bfloat16 &rhs) { \
549  sycl::ext::oneapi::bfloat16 *ExtractP = \
550  __spirv_AccessChain<sycl::ext::oneapi::bfloat16, \
551  sycl::ext::oneapi::bfloat16, NumRows, NumCols, \
552  spv_matrix_use_traits<Use>::value, \
553  spv_scope_traits<Group>::value>(&lhs.M.spvm, \
554  lhs.idx); \
555  return type{static_cast<float>(*ExtractP) op static_cast<float>(rhs)}; \
556  } \
557  friend type operator op( \
558  const sycl::ext::oneapi::bfloat16 &lhs, \
559  const wi_element<sycl::ext::oneapi::bfloat16, NumRows, NumCols, Use, \
560  Layout, Group> &rhs) { \
561  sycl::ext::oneapi::bfloat16 *ExtractP = \
562  __spirv_AccessChain<sycl::ext::oneapi::bfloat16, \
563  sycl::ext::oneapi::bfloat16, NumRows, NumCols, \
564  spv_matrix_use_traits<Use>::value, \
565  spv_scope_traits<Group>::value>(&rhs.M.spvm, \
566  rhs.idx); \
567  return type{static_cast<float>(*ExtractP) op static_cast<float>(lhs)}; \
568  }
569 #endif // __SPIRV_USE_COOPERATIVE_MATRIX
570  OP(bool, ==)
571  OP(bool, !=)
572  OP(bool, <)
573  OP(bool, >)
574  OP(bool, <=)
575  OP(bool, >=)
576 #undef OP
577 #else // __SYCL_DEVICE_ONLY__
578 #define OP(type, op) \
579  friend type operator op( \
580  const wi_element<sycl::ext::oneapi::bfloat16, NumRows, NumCols, Use, \
581  Layout, Group> &, \
582  const sycl::ext::oneapi::bfloat16 &) { \
583  throw exception(make_error_code(errc::runtime), \
584  "joint matrix is not supported on host."); \
585  } \
586  friend type operator op( \
587  const sycl::ext::oneapi::bfloat16 &, \
588  const wi_element<sycl::ext::oneapi::bfloat16, NumRows, NumCols, Use, \
589  Layout, Group> &) { \
590  throw exception(make_error_code(errc::runtime), \
591  "joint matrix is not supported on host."); \
592  }
594  OP(sycl::ext::oneapi::bfloat16, -)
595  OP(sycl::ext::oneapi::bfloat16, *)
596  OP(sycl::ext::oneapi::bfloat16, /)
597  OP(bool, ==)
598  OP(bool, !=)
599  OP(bool, <)
600  OP(bool, >)
601  OP(bool, <=)
602  OP(bool, >=)
603 #undef OP
604 #endif // __SYCL_DEVICE_ONLY__
605 };
606 
607 // End wi_element definition
608 
609 // Begin wi_data definition
610 
611 template <typename Group, typename T,
614 class wi_data {
615 
617  Cols, Layout> &jm;
618 
620  Group, T, Use, Rows, Cols, Layout> &_jm)
621  : jm(_jm){};
622 
623  template <typename Grp, typename Type,
624  sycl::ext::oneapi::experimental::matrix::use UseJm, size_t NumRows,
625  size_t NumCols,
627  friend decltype(auto)
629  Grp, Type, UseJm, NumRows, NumCols, LayoutJm> &);
630 
631 public:
632  size_t length() {
633 #if __SYCL_DEVICE_ONLY__
634 #ifndef __SPIRV_USE_COOPERATIVE_MATRIX
635  return __spirv_JointMatrixWorkItemLengthINTEL(jm.spvm);
636 #else
637  return __spirv_CooperativeMatrixLengthKHR(jm.spvm);
638 #endif // __SPIRV_USE_COOPERATIVE_MATRIX
639 #else
641  "joint matrix is not supported on host.");
642 #endif
643  };
644 
645  decltype(auto) operator[](size_t i) {
647  };
648 };
649 
650 template <typename Group, typename T,
653 inline __SYCL_ALWAYS_INLINE decltype(auto)
654 get_wi_data(Group sg, sycl::ext::oneapi::experimental::matrix::joint_matrix<
655  Group, T, Use, Rows, Cols, Layout> &jm) {
656  std::ignore = sg;
657  return wi_data(jm);
658 }
659 
660 // End wi_data definition
661 } // namespace detail
662 } // namespace oneapi
663 
664 namespace intel::experimental::matrix {
665 template <
666  typename Group, typename T, typename Tp,
669  access::address_space Space, access::decorated IsDecorated,
670  std::enable_if_t<Use == sycl::ext::oneapi::experimental::matrix::use::a ||
672  bool> = true>
673 inline __SYCL_ALWAYS_INLINE void
676  Group, Tp, Use, NumRows, NumCols, Layout> &src,
677  multi_ptr<T, Space, IsDecorated> dst, size_t stride) {
678 #if defined(__SYCL_DEVICE_ONLY__)
679  static_assert(Space != access::address_space::private_space,
680  "Joint Matrix doesn't support store to private memory!");
681 #if defined(__NVPTX__)
682  std::ignore = src;
683  std::ignore = dst;
684  std::ignore = stride;
685  throw exception(
687  "This version of the matrix extension is only currently supported on "
688  "intel devices");
689 #else
690  // intel's impl
691  using DecorT = typename sycl::detail::DecoratedType<T, Space>::type;
692  DecorT *Ptr = sycl::detail::getDecorated<DecorT>(dst);
693 #ifndef __SPIRV_USE_COOPERATIVE_MATRIX
694  __spirv_JointMatrixStoreINTEL<DecorT, Tp, NumRows, NumCols,
699  Ptr, src.spvm, stride,
701  Layout>::value,
703 #else
704  __spirv_CooperativeMatrixStoreKHR<
705  DecorT, Tp, NumRows, NumCols,
707  Use>::value,
709  Layout>::value>(
710  Ptr, src.spvm,
712  Layout>::value,
713  stride);
714 #endif // __SPIRV_USE_COOPERATIVE_MATRIX
715 #endif // defined(__NVPTX__)
716 #else
717  std::ignore = src;
718  std::ignore = dst;
719  std::ignore = stride;
721  "joint matrix is not supported on host.");
722 #endif // defined(__SYCL_DEVICE_ONLY__)
723 }
724 
725 template <
726  typename Group, typename T, typename Tp,
729  typename PropertyListT,
730  std::enable_if_t<Use == sycl::ext::oneapi::experimental::matrix::use::a ||
732  bool> = true>
734  Group,
736  Group, Tp, Use, NumRows, NumCols, Layout> &src,
738  size_t stride) {
739 #if defined(__SYCL_DEVICE_ONLY__)
740 #if defined(__NVPTX__)
741  std::ignore = src;
742  std::ignore = dst;
743  std::ignore = stride;
744  throw exception(
746  "This version of the matrix extension is only currently supported on "
747  "intel devices");
748 #else
749  // intel's impl
750  T *Ptr = dst.get();
751 #ifndef __SPIRV_USE_COOPERATIVE_MATRIX
752  __spirv_JointMatrixStoreINTEL<T, Tp, NumRows, NumCols,
757  Ptr, src.spvm, stride,
759  Layout>::value,
761 #else
762  __spirv_CooperativeMatrixStoreKHR<
763  T, Tp, NumRows, NumCols,
765  Use>::value,
767  Layout>::value>(
768  Ptr, src.spvm,
770  Layout>::value,
771  stride);
772 #endif // __SPIRV_USE_COOPERATIVE_MATRIX
773 #endif // defined(__NVPTX__)
774 #else
775  std::ignore = src;
776  std::ignore = dst;
777  std::ignore = stride;
779  "joint matrix is not supported on host.");
780 #endif // defined(__SYCL_DEVICE_ONLY__)
781 }
782 
783 template <typename Group, typename T,
786  typename F>
788  Group sg,
790  Cols, Layout> &jm,
791  F &&lambda) {
792 #if defined(__SYCL_DEVICE_ONLY__)
793 #if defined(__NVPTX__)
794  std::ignore = sg;
795  for (int i = 0; i < jm.matrix_impl.wi_marray.size(); i++) {
796  lambda(jm.matrix_impl.wi_marray[i]);
797  }
798 #else // NVPTX
799  using storage_element_type =
801  T>::storage_element_type;
802  auto wi_data_c = sycl::ext::oneapi::detail::get_wi_data(sg, jm);
803  for (int i = 0; i < wi_data_c.length(); i++) {
804  storage_element_type element = wi_data_c[i];
805  auto [row, col] = wi_data_c[i].get_coord();
806  lambda(element, row, col);
807  wi_data_c[i] = element;
808  }
809 #endif
810 #else
811  std::ignore = sg;
812  std::ignore = jm;
813  std::ignore = lambda;
815  "joint matrix is not supported on host.");
816 #endif
817 }
818 
819 using namespace sycl::ext::oneapi::experimental::matrix;
820 
821 // Begin out-of-bounds API
822 
823 template <typename Group, typename T, size_t NumRows, size_t NumCols, use Use,
824  layout Layout, typename T2>
826  Group, joint_matrix<Group, T, Use, NumRows, NumCols, Layout> &Res,
827  const T2 &Value, size_t Height, size_t Width, size_t CoordX,
828  size_t CoordY) {
829 #if defined(__SYCL_DEVICE_ONLY__)
830  using storage_element_type =
832  T>::storage_element_type;
833  Res.spvm = __spirv_CooperativeMatrixConstructCheckedINTEL<
834  storage_element_type, T, NumRows, NumCols,
835  spv_matrix_use_traits<Use>::value,
836  spv_matrix_layout_traits<Layout>::value>(
837  CoordX, CoordY, Height, Width, static_cast<storage_element_type>(Value));
838 #else
839  std::ignore = Res;
840  std::ignore = Value;
841  std::ignore = Height;
842  std::ignore = Width;
843  std::ignore = CoordX;
844  std::ignore = CoordY;
846  "joint matrix is not supported on host.");
847 #endif // defined(__SYCL_DEVICE_ONLY__)
848 }
849 
850 template <
851  typename Group, typename S, typename T, size_t NumRows, size_t NumCols,
852  access::address_space Space, access::decorated IsDecorated,
853  std::enable_if_t<std::is_same<S, std::remove_const_t<T>>::value, bool> =
854  true>
856  Group sg,
857  joint_matrix<Group, S, use::accumulator, NumRows, NumCols, layout::dynamic>
858  &Res,
859  multi_ptr<T, Space, IsDecorated> Src, size_t Stride, layout Layout,
860  size_t Height, size_t Width, size_t CoordX, size_t CoordY) {
861 #if defined(__SYCL_DEVICE_ONLY__)
862  static_assert(Space != access::address_space::private_space,
863  "Joint Matrix doesn't support load from private memory!");
864  std::ignore = sg;
865  using DecorT = typename sycl::detail::DecoratedType<T, Space>::type;
866  DecorT *Ptr = sycl::detail::getDecorated<DecorT>(Src);
867  Res.spvm = __spirv_CooperativeMatrixLoadCheckedINTEL<
868  DecorT, S, NumRows, NumCols,
869  spv_matrix_use_traits<use::accumulator>::value,
870  spv_matrix_layout_traits<layout::dynamic>::value>(
871  Ptr, CoordX, CoordY, sycl::detail::joint_matrix_layout_to_spv(Layout),
872  Height, Width, Stride);
873 #else
874  std::ignore = sg;
875  std::ignore = Res;
876  std::ignore = Src;
877  std::ignore = Stride;
878  std::ignore = Height;
879  std::ignore = Width;
880  std::ignore = Layout;
881  std::ignore = CoordX;
882  std::ignore = CoordY;
884  "joint matrix is not supported on host.");
885 #endif // defined(__SYCL_DEVICE_ONLY__)
886 }
887 
888 template <
889  typename Group, typename S, typename T, use Use, size_t NumRows,
890  size_t NumCols, layout Layout, access::address_space Space,
891  access::decorated IsDecorated,
892  std::enable_if_t<std::is_same<S, std::remove_const_t<T>>::value ||
893  (std::is_same<S, precision::tf32>::value &&
894  std::is_same<std::remove_const_t<T>, float>::value),
895  bool> = true>
897  Group sg, joint_matrix<Group, S, Use, NumRows, NumCols, Layout> &Res,
898  multi_ptr<T, Space, IsDecorated> Src, size_t Stride, size_t Height,
899  size_t Width, size_t CoordX, size_t CoordY) {
900 #if defined(__SYCL_DEVICE_ONLY__)
901  static_assert(Space != access::address_space::private_space,
902  "Joint Matrix doesn't support load from private memory!");
903  std::ignore = sg;
904  using DecorT = typename sycl::detail::DecoratedType<T, Space>::type;
905  DecorT *Ptr = sycl::detail::getDecorated<DecorT>(Src);
906  Res.spvm = __spirv_CooperativeMatrixLoadCheckedINTEL<
907  DecorT, S, NumRows, NumCols, spv_matrix_use_traits<Use>::value,
908  spv_matrix_layout_traits<Layout>::value>(
909  Ptr, CoordX, CoordY, spv_matrix_layout_traits<Layout>::value, Height,
910  Width, Stride);
911 #else
912  std::ignore = sg;
913  std::ignore = Res;
914  std::ignore = Src;
915  std::ignore = Stride;
916  std::ignore = Height;
917  std::ignore = Width;
918  std::ignore = CoordX;
919  std::ignore = CoordY;
921  "joint matrix is not supported on host.");
922 #endif // defined(__SYCL_DEVICE_ONLY__)
923 }
924 
925 template <typename Group, typename T, size_t NumRows, size_t NumCols,
926  access::address_space Space, access::decorated IsDecorated>
928  Group sg,
929  joint_matrix<Group, T, use::accumulator, NumRows, NumCols, layout::dynamic>
930  &Src,
931  multi_ptr<T, Space, IsDecorated> Dst, size_t Stride, layout Layout,
932  size_t Height, size_t Width, size_t CoordX, size_t CoordY) {
933 #if defined(__SYCL_DEVICE_ONLY__)
934  static_assert(Space != access::address_space::private_space,
935  "Joint Matrix doesn't support store to private memory!");
936  std::ignore = sg;
937  using DecorT = typename sycl::detail::DecoratedType<T, Space>::type;
938  DecorT *Ptr = sycl::detail::getDecorated<DecorT>(Dst);
939  __spirv_CooperativeMatrixStoreCheckedINTEL<
940  DecorT, T, NumRows, NumCols,
941  spv_matrix_use_traits<use::accumulator>::value,
942  spv_matrix_layout_traits<layout::dynamic>::value>(
943  Ptr, CoordX, CoordY, Src.spvm,
944  sycl::detail::joint_matrix_layout_to_spv(Layout), Height, Width, Stride);
945 #else
946  std::ignore = sg;
947  std::ignore = Src;
948  std::ignore = Dst;
949  std::ignore = Stride;
950  std::ignore = Height;
951  std::ignore = Width;
952  std::ignore = Layout;
953  std::ignore = CoordX;
954  std::ignore = CoordY;
956  "joint matrix is not supported on host.");
957 #endif // defined(__SYCL_DEVICE_ONLY__)
958 }
959 
// Checked store of a joint matrix whose use is a or b (enforced by the
// enable_if_t constraint below) to memory addressed through a multi_ptr.
// Device builds forward to __spirv_CooperativeMatrixStoreCheckedINTEL; the
// host branch only marks every parameter as unused.
// NOTE(review): this listing carries embedded line numbers and skips 964 and
// 988, so the line naming the function and the line that opens the host-side
// statement are not visible here. Presumably this is a
// joint_matrix_store_checked overload - confirm against the upstream header.
960 template <typename Group, typename T, typename Tp, use Use, size_t NumRows,
961  size_t NumCols, layout Layout, access::address_space Space,
962  access::decorated IsDecorated,
963  std::enable_if_t<Use == use::a || Use == use::b, bool> = true>
965  Group sg, const joint_matrix<Group, Tp, Use, NumRows, NumCols, Layout> &Src,
966  multi_ptr<T, Space, IsDecorated> Dst, size_t Stride, size_t Height,
967  size_t Width, size_t CoordX, size_t CoordY) {
968 #if defined(__SYCL_DEVICE_ONLY__)
// Destinations in the private address space are rejected at compile time.
969  static_assert(Space != access::address_space::private_space,
970  "Joint Matrix doesn't support store to private memory!");
// sg is not read by the device path.
971  std::ignore = sg;
// Obtain a pointer of the decorated element type from the multi_ptr.
972  using DecorT = typename sycl::detail::DecoratedType<T, Space>::type;
973  DecorT *Ptr = sycl::detail::getDecorated<DecorT>(Dst);
// The compile-time Layout of Src is used both as a template argument and as
// the layout call argument. Call arguments are passed in the order
// pointer, CoordX, CoordY, matrix handle, layout, Height, Width, Stride.
974  __spirv_CooperativeMatrixStoreCheckedINTEL<
975  DecorT, Tp, NumRows, NumCols, spv_matrix_use_traits<Use>::value,
976  spv_matrix_layout_traits<Layout>::value>(
977  Ptr, CoordX, CoordY, Src.spvm, spv_matrix_layout_traits<Layout>::value,
978  Height, Width, Stride);
979 #else
// Host build: every parameter is explicitly marked as unused.
980  std::ignore = sg;
981  std::ignore = Src;
982  std::ignore = Dst;
983  std::ignore = Stride;
984  std::ignore = Height;
985  std::ignore = Width;
986  std::ignore = CoordX;
987  std::ignore = CoordY;
// Tail of a statement whose opening line (988) is absent from this listing.
989  "joint matrix is not supported on host.");
990 #endif // defined(__SYCL_DEVICE_ONLY__)
991 }
992 
993 // Annotated pointer overloads:
// Checked load into an accumulator joint matrix declared with
// layout::dynamic; the memory layout is supplied at run time through the
// Layout parameter. The enable_if_t constraint requires S to equal T with
// const removed.
// NOTE(review): the embedded numbering skips 998, 1002 and 1023, so the line
// naming the function, the declaration of Src and the opening of the
// host-side statement are not visible here. Presumably this is a
// joint_matrix_load_checked overload taking an annotated_ptr Src - confirm
// against the upstream header.
994 template <typename Group, typename S, typename T, size_t NumRows,
995  size_t NumCols, typename PropertyListT,
996  std::enable_if_t<std::is_same<S, std::remove_const_t<T>>::value,
997  bool> = true>
999  Group sg,
1000  joint_matrix<Group, S, use::accumulator, NumRows, NumCols, layout::dynamic>
1001  &Res,
1003  size_t Stride, layout Layout, size_t Height, size_t Width, size_t CoordX,
1004  size_t CoordY) {
1005 #if defined(__SYCL_DEVICE_ONLY__)
// sg is not read by the device path.
1006  std::ignore = sg;
// Raw pointer extracted from Src through its get() member.
1007  T *Ptr = Src.get();
// The result of the SPIR-V checked load is stored in Res.spvm. The run-time
// Layout is converted with joint_matrix_layout_to_spv before being passed.
1008  Res.spvm = __spirv_CooperativeMatrixLoadCheckedINTEL<
1009  T, S, NumRows, NumCols, spv_matrix_use_traits<use::accumulator>::value,
1010  spv_matrix_layout_traits<layout::dynamic>::value>(
1011  Ptr, CoordX, CoordY, sycl::detail::joint_matrix_layout_to_spv(Layout),
1012  Height, Width, Stride);
1013 #else
// Host build: every parameter is explicitly marked as unused.
1014  std::ignore = sg;
1015  std::ignore = Res;
1016  std::ignore = Src;
1017  std::ignore = Stride;
1018  std::ignore = Height;
1019  std::ignore = Width;
1020  std::ignore = Layout;
1021  std::ignore = CoordX;
1022  std::ignore = CoordY;
// Tail of a statement whose opening line (1023) is absent from this listing.
1024  "joint matrix is not supported on host.");
1025 #endif // defined(__SYCL_DEVICE_ONLY__)
1026 }
1027 
// Checked load into a joint matrix of arbitrary Use and compile-time Layout.
// The enable_if_t constraint accepts either S equal to T with const removed,
// or S being precision::tf32 while T with const removed is float.
// NOTE(review): the embedded numbering skips 1035, 1037 and 1056, so the
// line naming the function, the declaration of Src and the opening of the
// host-side statement are not visible here. Presumably this is the
// joint_matrix_load_checked overload taking an annotated_ptr Src - confirm
// against the upstream header.
1028 template <
1029  typename Group, typename S, typename T, use Use, size_t NumRows,
1030  size_t NumCols, layout Layout, typename PropertyListT,
1031  std::enable_if_t<std::is_same<S, std::remove_const_t<T>>::value ||
1032  (std::is_same<S, precision::tf32>::value &&
1033  std::is_same<std::remove_const_t<T>, float>::value),
1034  bool> = true>
1036  Group sg, joint_matrix<Group, S, Use, NumRows, NumCols, Layout> &Res,
1038  size_t Stride, size_t Height, size_t Width, size_t CoordX, size_t CoordY) {
1039 #if defined(__SYCL_DEVICE_ONLY__)
// sg is not read by the device path.
1040  std::ignore = sg;
// Raw pointer extracted from Src through its get() member.
1041  T *Ptr = Src.get();
// The result of the SPIR-V checked load is stored in Res.spvm. The
// compile-time Layout of Res serves both as template argument and as the
// layout call argument.
1042  Res.spvm = __spirv_CooperativeMatrixLoadCheckedINTEL<
1043  T, S, NumRows, NumCols, spv_matrix_use_traits<Use>::value,
1044  spv_matrix_layout_traits<Layout>::value>(
1045  Ptr, CoordX, CoordY, spv_matrix_layout_traits<Layout>::value, Height,
1046  Width, Stride);
1047 #else
// Host build: every parameter is explicitly marked as unused.
1048  std::ignore = sg;
1049  std::ignore = Res;
1050  std::ignore = Src;
1051  std::ignore = Stride;
1052  std::ignore = Height;
1053  std::ignore = Width;
1054  std::ignore = CoordX;
1055  std::ignore = CoordY;
// Tail of a statement whose opening line (1056) is absent from this listing.
1057  "joint matrix is not supported on host.");
1058 #endif // defined(__SYCL_DEVICE_ONLY__)
1059 }
1060 
// Checked store of an accumulator joint matrix declared with layout::dynamic;
// the memory layout is supplied at run time through the Layout parameter.
// NOTE(review): the embedded numbering skips 1063, 1067 and 1088, so the
// line naming the function, the declaration of Dst and the opening of the
// host-side statement are not visible here. Presumably this is a
// joint_matrix_store_checked overload taking an annotated_ptr Dst - confirm
// against the upstream header.
1061 template <typename Group, typename T, size_t NumRows, size_t NumCols,
1062  typename PropertyListT>
1064  Group sg,
1065  joint_matrix<Group, T, use::accumulator, NumRows, NumCols, layout::dynamic>
1066  &Src,
1068  size_t Stride, layout Layout, size_t Height, size_t Width, size_t CoordX,
1069  size_t CoordY) {
1070 #if defined(__SYCL_DEVICE_ONLY__)
// sg is not read by the device path.
1071  std::ignore = sg;
// Raw pointer extracted from Dst through its get() member.
1072  T *Ptr = Dst.get();
// T is used for both leading template arguments of the SPIR-V checked store.
// The run-time Layout is converted with joint_matrix_layout_to_spv before
// being passed.
1073  __spirv_CooperativeMatrixStoreCheckedINTEL<
1074  T, T, NumRows, NumCols, spv_matrix_use_traits<use::accumulator>::value,
1075  spv_matrix_layout_traits<layout::dynamic>::value>(
1076  Ptr, CoordX, CoordY, Src.spvm,
1077  sycl::detail::joint_matrix_layout_to_spv(Layout), Height, Width, Stride);
1078 #else
// Host build: every parameter is explicitly marked as unused.
1079  std::ignore = sg;
1080  std::ignore = Src;
1081  std::ignore = Dst;
1082  std::ignore = Stride;
1083  std::ignore = Height;
1084  std::ignore = Width;
1085  std::ignore = Layout;
1086  std::ignore = CoordX;
1087  std::ignore = CoordY;
// Tail of a statement whose opening line (1088) is absent from this listing.
1089  "joint matrix is not supported on host.");
1090 #endif // defined(__SYCL_DEVICE_ONLY__)
1091 }
1092 
// Checked store of a joint matrix whose use is a or b (enforced by the
// enable_if_t constraint below), using the compile-time Layout of Src.
// NOTE(review): the embedded numbering skips 1096, 1098 and 1117, so the
// line naming the function, the declaration of Dst and the opening of the
// host-side statement are not visible here. Presumably this is the
// joint_matrix_store_checked overload taking an annotated_ptr Dst - confirm
// against the upstream header.
1093 template <typename Group, typename T, typename Tp, use Use, size_t NumRows,
1094  size_t NumCols, layout Layout, typename PropertyListT,
1095  std::enable_if_t<Use == use::a || Use == use::b, bool> = true>
1097  Group sg, const joint_matrix<Group, Tp, Use, NumRows, NumCols, Layout> &Src,
1099  size_t Stride, size_t Height, size_t Width, size_t CoordX, size_t CoordY) {
1100 #if defined(__SYCL_DEVICE_ONLY__)
// sg is not read by the device path.
1101  std::ignore = sg;
// Raw pointer extracted from Dst through its get() member.
1102  T *Ptr = Dst.get();
// The pointer element type T and the matrix element type Tp are passed as
// separate template arguments. The compile-time Layout of Src serves both as
// template argument and as the layout call argument.
1103  __spirv_CooperativeMatrixStoreCheckedINTEL<
1104  T, Tp, NumRows, NumCols, spv_matrix_use_traits<Use>::value,
1105  spv_matrix_layout_traits<Layout>::value>(
1106  Ptr, CoordX, CoordY, Src.spvm, spv_matrix_layout_traits<Layout>::value,
1107  Height, Width, Stride);
1108 #else
// Host build: every parameter is explicitly marked as unused.
1109  std::ignore = sg;
1110  std::ignore = Src;
1111  std::ignore = Dst;
1112  std::ignore = Stride;
1113  std::ignore = Height;
1114  std::ignore = Width;
1115  std::ignore = CoordX;
1116  std::ignore = CoordY;
// Tail of a statement whose opening line (1117) is absent from this listing.
1118  "joint matrix is not supported on host.");
1119 #endif // defined(__SYCL_DEVICE_ONLY__)
1120 }
1121 // End out-of-bounds API
1122 
1123 } // namespace intel::experimental::matrix
1124 
1125 } // namespace ext
1126 } // namespace _V1
1127 } // namespace sycl
wi_element(sycl::ext::oneapi::experimental::matrix::joint_matrix< Group, sycl::ext::oneapi::bfloat16, Use, NumRows, NumCols, Layout > &Mat, std::size_t i)
wi_element & operator=(const wi_element< sycl::ext::oneapi::bfloat16, NumRows, NumCols, Use, Layout, Group > &rhs)
wi_element & operator=(const wi_element< T, NumRows, NumCols, Use, Layout, Group > &rhs)
wi_element(sycl::ext::oneapi::experimental::matrix::joint_matrix< Group, T, Use, NumRows, NumCols, Layout > &Mat, std::size_t i)
wi_element & operator=(const T2 &rhs)
typename oneapi::detail::jm_type_interpretation_helper_trait< T >::storage_element_type storage_element_type
__SYCL_ALWAYS_INLINE std::tuple< size_t, size_t > get_coord()
#define __SYCL_ALWAYS_INLINE
#define SPV_MATRIX_USE_TRAITS(USE, SPV_USE)
#define SPV_MATRIX_LAYOUT_TRAITS(LAYOUT, SPV_LAYOUT)
#define OP(op)
__SYCL_ALWAYS_INLINE __spv::MatrixLayout joint_matrix_layout_to_spv(sycl::ext::oneapi::experimental::matrix::layout Layout)
constexpr tuple< Ts... > make_tuple(Ts... Args)
Definition: tuple.hpp:35
sycl::ext::oneapi::bfloat16 bfloat16
__SYCL_ALWAYS_INLINE void joint_matrix_store(Group, const sycl::ext::oneapi::experimental::matrix::joint_matrix< Group, Tp, Use, NumRows, NumCols, Layout > &src, ext::oneapi::experimental::annotated_ptr< T, PropertyListT > dst, size_t stride)
__SYCL_ALWAYS_INLINE void joint_matrix_load_checked(Group sg, joint_matrix< Group, S, Use, NumRows, NumCols, Layout > &Res, ext::oneapi::experimental::annotated_ptr< T, PropertyListT > Src, size_t Stride, size_t Height, size_t Width, size_t CoordX, size_t CoordY)
__SYCL_ALWAYS_INLINE void joint_matrix_apply(Group sg, sycl::ext::oneapi::experimental::matrix::joint_matrix< Group, T, Use, Rows, Cols, Layout > &jm, F &&lambda)
__SYCL_ALWAYS_INLINE void joint_matrix_store_checked(Group sg, const joint_matrix< Group, Tp, Use, NumRows, NumCols, Layout > &Src, ext::oneapi::experimental::annotated_ptr< T, PropertyListT > Dst, size_t Stride, size_t Height, size_t Width, size_t CoordX, size_t CoordY)
__SYCL_ALWAYS_INLINE void joint_matrix_fill_checked(Group, joint_matrix< Group, T, Use, NumRows, NumCols, Layout > &Res, const T2 &Value, size_t Height, size_t Width, size_t CoordX, size_t CoordY)
decltype(auto) __SYCL_ALWAYS_INLINE get_wi_data(Group sg, sycl::ext::oneapi::experimental::matrix::joint_matrix< Group, T, Use, Rows, Cols, Layout > &jm)
std::enable_if_t< detail::is_bf16_storage_type< T >::value, T > fabs(T x)
std::error_code make_error_code(sycl::errc E) noexcept
Constructs an error code using e and sycl_category()
Definition: exception.cpp:65
Definition: access.hpp:18