DPC++ Runtime
Runtime libraries for oneAPI DPC++
matrix-jit.hpp
Go to the documentation of this file.
1 //==---------------- matrix-jit.hpp - SYCL matrix --------------*- C++ -*---==//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 // ===--------------------------------------------------------------------=== //
8 
9 #pragma once
10 
11 #include <CL/__spirv/spirv_ops.hpp>
14 #include <sycl/feature_test.hpp>
15 
16 namespace sycl {
18 namespace ext::oneapi::experimental::matrix {
19 
21 
22 template <matrix_layout Layout> struct spv_matrix_layout_traits {
23  static constexpr __spv::MatrixLayout value = __spv::MatrixLayout::RowMajor;
24 };
25 
26 #define SPV_MATRIX_LAYOUT_TRAITS(LAYOUT, SPV_LAYOUT) \
27  template <> struct spv_matrix_layout_traits<LAYOUT> { \
28  static constexpr __spv::MatrixLayout value = SPV_LAYOUT; \
29  };
30 
31 SPV_MATRIX_LAYOUT_TRAITS(matrix_layout::row_major,
33 SPV_MATRIX_LAYOUT_TRAITS(matrix_layout::col_major,
35 SPV_MATRIX_LAYOUT_TRAITS(matrix_layout::packed_a, __spv::MatrixLayout::PackedA)
36 SPV_MATRIX_LAYOUT_TRAITS(matrix_layout::packed_b, __spv::MatrixLayout::PackedB)
37 
38 template <typename G> struct spv_scope_traits {};
39 template <> struct spv_scope_traits<sycl::sub_group> {
40  constexpr static auto value = __spv::Scope::Subgroup;
41 };
42 template <int D> struct spv_scope_traits<sycl::group<D>> {
43  constexpr static auto value = __spv::Scope::Workgroup;
44 };
45 
46 template <typename T, size_t NumRows, size_t NumCols,
47  matrix_layout Layout = matrix_layout::row_major,
48  typename Group = sycl::sub_group>
49 class wi_data;
50 
51 template <typename T, size_t NumRows, size_t NumCols,
52  matrix_layout Layout = matrix_layout::row_major,
53  typename Group = sycl::sub_group>
54 struct joint_matrix {
55 public:
57  T, NumRows, NumCols, spv_matrix_layout_traits<Layout>::value,
59  joint_matrix(Group sg) {
60 #ifndef __SYCL_DEVICE_ONLY__
61  (void)sg;
62  throw runtime_error("joint matrix is not supported on host device.",
63  PI_ERROR_INVALID_DEVICE);
64 #endif // __SYCL_DEVICE_ONLY__
65  }
66 
70  }
71 };
72 
73 template <typename Group, typename T, size_t NumRows, size_t NumCols,
74  matrix_layout Layout = matrix_layout::row_major,
78  multi_ptr<T, Space, IsDecorated> src, size_t stride, matrix_layout MemL) {
79 #ifdef __SYCL_DEVICE_ONLY__
80  T *Ptr = src.get();
81  switch (MemL) {
82  default:
83  assert(false && "Invalid Memory Layout!");
84  case matrix_layout::row_major:
85  res.spvm =
86  __spirv_JointMatrixLoadINTEL<T, NumRows, NumCols,
88  Ptr, stride, __spv::MatrixLayout::RowMajor,
90  break;
91  case matrix_layout::col_major:
92  res.spvm =
93  __spirv_JointMatrixLoadINTEL<T, NumRows, NumCols,
97  break;
98  case matrix_layout::packed_a:
99  res.spvm =
100  __spirv_JointMatrixLoadINTEL<T, NumRows, NumCols,
102  Ptr, stride, __spv::MatrixLayout::PackedA,
104  break;
105  case matrix_layout::packed_b:
106  res.spvm =
107  __spirv_JointMatrixLoadINTEL<T, NumRows, NumCols,
109  Ptr, stride, __spv::MatrixLayout::PackedB,
111  break;
112  }
113 #else
114  (void)sg;
115  (void)res;
116  (void)src;
117  (void)stride;
118  (void)MemL;
119  throw runtime_error("joint matrix is not supported on host device.",
120  PI_ERROR_INVALID_DEVICE);
121 #endif // __SYCL_DEVICE_ONLY__
122 }
123 
124 template <typename Group, typename T, size_t NumRows, size_t NumCols,
125  matrix_layout MatL = matrix_layout::row_major,
129  multi_ptr<T, Space, IsDecorated> res, size_t stride, matrix_layout MemL) {
130 #ifdef __SYCL_DEVICE_ONLY__
131  T *Ptr = res.get();
132  switch (MemL) {
133  default:
134  assert(false && "Invalid Memory Layout!");
135  case matrix_layout::row_major:
136  __spirv_JointMatrixStoreINTEL<T, NumRows, NumCols,
138  Ptr, src.spvm, stride, __spv::MatrixLayout::RowMajor,
140  break;
141  case matrix_layout::col_major:
142  __spirv_JointMatrixStoreINTEL<T, NumRows, NumCols,
144  Ptr, src.spvm, stride, __spv::MatrixLayout::ColumnMajor,
146  break;
147  case matrix_layout::packed_a:
148  __spirv_JointMatrixStoreINTEL<T, NumRows, NumCols,
150  Ptr, src.spvm, stride, __spv::MatrixLayout::PackedA,
152  break;
153  case matrix_layout::packed_b:
154  __spirv_JointMatrixStoreINTEL<T, NumRows, NumCols,
156  Ptr, src.spvm, stride, __spv::MatrixLayout::PackedB,
158  break;
159  }
160 #else
161  (void)sg;
162  (void)src;
163  (void)res;
164  (void)stride;
165  (void)MemL;
166  throw runtime_error("joint matrix is not supported on host device.",
167  PI_ERROR_INVALID_DEVICE);
168 #endif // __SYCL_DEVICE_ONLY__
169 }
170 
171 template <typename Group, typename T1, typename T2, typename T3, size_t M,
172  size_t K, size_t N, matrix_layout LayoutA, matrix_layout LayoutB,
173  matrix_layout LayoutC>
174 inline __SYCL_ALWAYS_INLINE joint_matrix<T3, M, N, LayoutC, Group>
178 #ifdef __SYCL_DEVICE_ONLY__
180  if constexpr (std::is_same<T1, uint16_t>::value &&
181  std::is_same<T2, uint16_t>::value &&
182  std::is_same<T3, float>::value)
183  res.spvm = __spirv_JointMatrixMadINTEL(mA.spvm, mB.spvm, mC.spvm);
184  else if constexpr (std::is_unsigned<T1>::value && std::is_unsigned<T2>::value)
185  res.spvm = __spirv_JointMatrixUUMadINTEL(mA.spvm, mB.spvm, mC.spvm);
186  else if constexpr (std::is_signed<T1>::value && std::is_unsigned<T2>::value)
187  res.spvm = __spirv_JointMatrixSUMadINTEL(mA.spvm, mB.spvm, mC.spvm);
188  else if constexpr (std::is_unsigned<T1>::value && std::is_signed<T2>::value)
189  res.spvm = __spirv_JointMatrixUSMadINTEL(mA.spvm, mB.spvm, mC.spvm);
190  else
191  res.spvm = __spirv_JointMatrixMadINTEL(mA.spvm, mB.spvm, mC.spvm);
192  return res;
193 #else
194  (void)sg;
195  (void)mA;
196  (void)mB;
197  (void)mC;
198  throw runtime_error("joint matrix is not supported on host device.",
199  PI_ERROR_INVALID_DEVICE);
200 #endif // __SYCL_DEVICE_ONLY__
201 }
202 
203 template <typename Group, typename T, size_t NumRows, size_t NumCols,
204  matrix_layout Layout, typename T2>
205 inline __SYCL_ALWAYS_INLINE void
208  const T2 v) {
209  // We kept the unused "sg" in joint_matrix_fill to match the other DPC++
210  // functions
211  (void)sg;
212 #ifdef __SYCL_DEVICE_ONLY__
213  res.spvm =
214  __spirv_CompositeConstruct<T, NumRows, NumCols,
216  static_cast<T>(v));
217 
218 #else
219  (void)res;
220  (void)v;
221 #endif // __SYCL_DEVICE_ONLY__
222 }
223 
224 template <typename T, size_t NumRows, size_t NumCols,
225  matrix_layout Layout = matrix_layout::row_major,
226  typename Group = sycl::sub_group>
227 class wi_element {
228  joint_matrix<T, NumRows, NumCols, Layout, Group> &M;
229  std::size_t idx;
230 
231 public:
233  std::size_t i)
234  : M(Mat), idx(i) {}
235  operator T() {
236 #ifdef __SYCL_DEVICE_ONLY__
237  return __spirv_VectorExtractDynamic(M.spvm, idx);
238 #else
239  throw runtime_error("joint matrix is not supported on host device.",
240  PI_ERROR_INVALID_DEVICE);
241 #endif // __SYCL_DEVICE_ONLY__
242  }
243 
244  explicit operator bool() {
245 #ifdef __SYCL_DEVICE_ONLY__
246  return __spirv_VectorExtractDynamic(M.spvm, idx) != static_cast<T>(0);
247 #else
248  throw runtime_error("joint matrix is not supported on host device.",
249  PI_ERROR_INVALID_DEVICE);
250 #endif // __SYCL_DEVICE_ONLY__
251  }
252 
253  template <typename T2> wi_element &operator=(const T2 &rhs) {
254 #ifdef __SYCL_DEVICE_ONLY__
255  M.spvm = __spirv_VectorInsertDynamic(M.spvm, static_cast<T>(rhs), idx);
256  return *this;
257 #else
258  (void)rhs;
259  throw runtime_error("joint matrix is not supported on host device.",
260  PI_ERROR_INVALID_DEVICE);
261 #endif // __SYCL_DEVICE_ONLY__
262  }
263 
264  wi_element &
266 #ifdef __SYCL_DEVICE_ONLY__
267  M.spvm = __spirv_VectorInsertDynamic(
268  M.spvm, __spirv_VectorExtractDynamic(rhs.M.spvm, rhs.idx), idx);
269  return *this;
270 #else
271  (void)rhs;
272  throw runtime_error("joint matrix is not supported on host device.",
273  PI_ERROR_INVALID_DEVICE);
274 #endif // __SYCL_DEVICE_ONLY__
275  }
276 
277 #if __SYCL_DEVICE_ONLY__
278 #define OP(op) \
279  template <typename T2> wi_element &operator op##=(const T2 &rhs) { \
280  M.spvm = __spirv_VectorInsertDynamic( \
281  M.spvm, \
282  static_cast<T>(__spirv_VectorExtractDynamic(M.spvm, idx) \
283  op static_cast<T>(rhs)), \
284  idx); \
285  return *this; \
286  }
287 #else // __SYCL_DEVICE_ONLY__
288 #define OP(op) \
289  template <typename T2> wi_element &operator op##=(const T2 &rhs) { \
290  (void)rhs; \
291  throw runtime_error("joint matrix is not supported on host device.", \
292  PI_ERROR_INVALID_DEVICE); \
293  }
294 #endif // __SYCL_DEVICE_ONLY__
295  OP(+)
296  OP(-)
297  OP(*)
298  OP(/)
299 #undef OP
300 };
301 
302 // Note that similarly to the other matrix functions, uint16_t is used here to
303 // represent bf16 type. Since the AMX and DPAS implementations don't support
304 // uint16_t, this interpretation is possible. This design choice was made before
305 // the introduction of SYCL experimental bfloat16 type. Our plan is to move
306 // towards using the SYCL bfloat16. But since it is still experimental, we will
307 // probably keep both uint16 interpretation and SYCL bfloat16.
308 template <size_t NumRows, size_t NumCols, matrix_layout Layout, typename Group>
309 class wi_element<uint16_t, NumRows, NumCols, Layout, Group> {
311  std::size_t idx;
312 
313 public:
315  std::size_t i)
316  : M(Mat), idx(i) {}
317  operator uint16_t() {
318 #ifdef __SYCL_DEVICE_ONLY__
319  return __spirv_VectorExtractDynamic(M.spvm, idx);
320 #else
321  throw runtime_error("joint matrix is not supported on host device.",
322  PI_ERROR_INVALID_DEVICE);
323 #endif // __SYCL_DEVICE_ONLY__
324  }
325 
326  explicit operator bool() {
327 #ifdef __SYCL_DEVICE_ONLY__
328  return std::fabs(make_fp32(__spirv_VectorExtractDynamic(M.spvm, idx))) >=
329  std::numeric_limits<float>::epsilon();
330 #else
331  throw runtime_error("joint matrix is not supported on host device.",
332  PI_ERROR_INVALID_DEVICE);
333 #endif // __SYCL_DEVICE_ONLY__
334  }
335 
336  wi_element &operator=(const uint16_t &rhs) {
337 #ifdef __SYCL_DEVICE_ONLY__
338  M.spvm = __spirv_VectorInsertDynamic(M.spvm, rhs, idx);
339  return *this;
340 #else
341  (void)rhs;
342  throw runtime_error("joint matrix is not supported on host device.",
343  PI_ERROR_INVALID_DEVICE);
344 #endif // __SYCL_DEVICE_ONLY__
345  }
346 
347  wi_element &
349 #ifdef __SYCL_DEVICE_ONLY__
350  M.spvm = __spirv_VectorInsertDynamic(
351  M.spvm, __spirv_VectorExtractDynamic(rhs.M.spvm, rhs.idx), idx);
352  return *this;
353 #else
354  (void)rhs;
355  throw runtime_error("joint matrix is not supported on host device.",
356  PI_ERROR_INVALID_DEVICE);
357 #endif // __SYCL_DEVICE_ONLY__
358  }
359 
360  // We use here the following functions for conversion (bf16=>fp32 and
361  // fp32=>bf16). This is a workaround until we are able to use
362  // __spirv_ConvertFToBF16INTEL and __spirv_ConvertBF16ToFINTEL once these are
363  // supported in the CPU backend
364  static float make_fp32(uint16_t x) {
365  unsigned int y = x;
366  y = y << 16;
367  float *res = reinterpret_cast<float *>(&y);
368  return *res;
369  }
370 
371  static uint16_t make_bf16(float x) {
372  int *res = reinterpret_cast<int *>(&x);
373  *res = *res >> 16;
374  return (uint16_t)*res;
375  }
376 
377 #if __SYCL_DEVICE_ONLY__
378 #define OP(op) \
379  wi_element &operator op##=(const uint16_t &rhs) { \
380  M.spvm = __spirv_VectorInsertDynamic( \
381  M.spvm, \
382  make_bf16(make_fp32(__spirv_VectorExtractDynamic(M.spvm, idx) \
383  op make_fp32(rhs))), \
384  idx); \
385  return *this; \
386  }
387 #else // __SYCL_DEVICE_ONLY__
388 #define OP(op) \
389  wi_element &operator op##=(const uint16_t &rhs) { \
390  (void)rhs; \
391  throw runtime_error("joint matrix is not supported on host device.", \
392  PI_ERROR_INVALID_DEVICE); \
393  }
394 #endif // __SYCL_DEVICE_ONLY__
395  OP(+)
396  OP(-)
397  OP(*)
398  OP(/)
399 #undef OP
400 
401  template <typename T1, typename T2> struct Converter {
402  static T2 convert(const T1 &from) { return static_cast<T2>(from); }
403  };
404 
405  template <typename T> struct Converter<T, uint16_t> {
406  static uint16_t convert(const T &from) { return make_bf16(from); }
407  };
408 #if __SYCL_DEVICE_ONLY__
409 #define OP(input_type, type, op) \
410  friend type operator op( \
411  const wi_element<uint16_t, NumRows, NumCols, Layout, Group> &lhs, \
412  const uint16_t &rhs) { \
413  return Converter<input_type, type>::convert(make_fp32( \
414  __spirv_VectorExtractDynamic(lhs.M.spvm, lhs.idx)) op make_fp32(rhs)); \
415  } \
416  friend type operator op( \
417  const uint16_t &lhs, \
418  const wi_element<uint16_t, NumRows, NumCols, Layout, Group> &rhs) { \
419  return Converter<input_type, type>::convert(make_fp32( \
420  __spirv_VectorExtractDynamic(rhs.M.spvm, rhs.idx)) op make_fp32(lhs)); \
421  }
422 #else // __SYCL_DEVICE_ONLY__
423 #define OP(input_type, type, op) \
424  friend type operator op( \
425  const wi_element<uint16_t, NumRows, NumCols, Layout, Group> &lhs, \
426  const uint16_t &rhs) { \
427  (void)lhs; \
428  (void)rhs; \
429  throw runtime_error("joint matrix is not supported on host device.", \
430  PI_ERROR_INVALID_DEVICE); \
431  } \
432  friend type operator op( \
433  const uint16_t &lhs, \
434  const wi_element<uint16_t, NumRows, NumCols, Layout, Group> &rhs) { \
435  (void)lhs; \
436  (void)rhs; \
437  throw runtime_error("joint matrix is not supported on host device.", \
438  PI_ERROR_INVALID_DEVICE); \
439  }
440 #endif // __SYCL_DEVICE_ONLY__
441  OP(float, uint16_t, +)
442  OP(float, uint16_t, -)
443  OP(float, uint16_t, *)
444  OP(float, uint16_t, /)
445  OP(bool, bool, ==)
446  OP(bool, bool, !=)
447  OP(bool, bool, <)
448  OP(bool, bool, >)
449  OP(bool, bool, <=)
450  OP(bool, bool, >=)
451 #undef OP
452 };
453 
454 template <size_t NumRows, size_t NumCols, matrix_layout Layout, typename Group>
455 class wi_element<sycl::ext::oneapi::bfloat16, NumRows, NumCols, Layout, Group> {
457  std::size_t idx;
458 
459 public:
460  wi_element(joint_matrix<sycl::ext::oneapi::bfloat16, NumRows, NumCols, Layout,
461  Group> &Mat,
462  std::size_t i)
463  : M(Mat), idx(i) {}
464  operator sycl::ext::oneapi::bfloat16() {
465 #ifdef __SYCL_DEVICE_ONLY__
466  return __spirv_VectorExtractDynamic(M.spvm, idx);
467 #else
468  throw runtime_error("joint matrix is not supported on host device.",
469  PI_ERROR_INVALID_DEVICE);
470 #endif // __SYCL_DEVICE_ONLY__
471  }
472 
473  explicit operator bool() {
474 #ifdef __SYCL_DEVICE_ONLY__
475  return std::fabs(static_cast<float>(__spirv_VectorExtractDynamic(
476  M.spvm, idx))) >= std::numeric_limits<float>::epsilon();
477 #else
478  throw runtime_error("joint matrix is not supported on host device.",
479  PI_ERROR_INVALID_DEVICE);
480 #endif // __SYCL_DEVICE_ONLY__
481  }
482 
483  wi_element &operator=(const sycl::ext::oneapi::bfloat16 &rhs) {
484 #ifdef __SYCL_DEVICE_ONLY__
485  M.spvm = __spirv_VectorInsertDynamic(M.spvm, rhs, idx);
486  return *this;
487 #else
488  (void)rhs;
489  throw runtime_error("joint matrix is not supported on host device.",
490  PI_ERROR_INVALID_DEVICE);
491 #endif // __SYCL_DEVICE_ONLY__
492  }
493 
494  wi_element &operator=(const wi_element<sycl::ext::oneapi::bfloat16, NumRows,
495  NumCols, Layout, Group> &rhs) {
496 #ifdef __SYCL_DEVICE_ONLY__
497  M.spvm = __spirv_VectorInsertDynamic(
498  M.spvm, __spirv_VectorExtractDynamic(rhs.M.spvm, rhs.idx), idx);
499  return *this;
500 #else
501  (void)rhs;
502  throw runtime_error("joint matrix is not supported on host device.",
503  PI_ERROR_INVALID_DEVICE);
504 #endif // __SYCL_DEVICE_ONLY__
505  }
506 
507 #if __SYCL_DEVICE_ONLY__
508 #define OP(opassign, op) \
509  wi_element &operator opassign(const sycl::ext::oneapi::bfloat16 &rhs) { \
510  M.spvm = __spirv_VectorInsertDynamic( \
511  M.spvm, __spirv_VectorExtractDynamic(M.spvm, idx) op rhs, idx); \
512  return *this; \
513  }
514 #else // __SYCL_DEVICE_ONLY__
515 #define OP(opassign, op) \
516  wi_element &operator opassign(const sycl::ext::oneapi::bfloat16 &rhs) { \
517  (void)rhs; \
518  throw runtime_error("joint matrix is not supported on host device.", \
519  PI_ERROR_INVALID_DEVICE); \
520  }
521 #endif // __SYCL_DEVICE_ONLY__
522  OP(+=, +)
523  OP(-=, -)
524  OP(*=, *)
525  OP(/=, /)
526 #undef OP
527 
528 #if __SYCL_DEVICE_ONLY__
529 #define OP(type, op) \
530  friend type operator op( \
531  const wi_element<sycl::ext::oneapi::bfloat16, NumRows, NumCols, Layout, \
532  Group> &lhs, \
533  const sycl::ext::oneapi::bfloat16 &rhs) { \
534  return __spirv_VectorExtractDynamic(lhs.M.spvm, lhs.idx) op rhs; \
535  } \
536  friend type operator op( \
537  const sycl::ext::oneapi::bfloat16 &lhs, \
538  const wi_element<sycl::ext::oneapi::bfloat16, NumRows, NumCols, Layout, \
539  Group> &rhs) { \
540  return __spirv_VectorExtractDynamic(rhs.M.spvm, rhs.idx) op lhs; \
541  }
542  OP(sycl::ext::oneapi::bfloat16, +)
543  OP(sycl::ext::oneapi::bfloat16, -)
544  OP(sycl::ext::oneapi::bfloat16, *)
545  OP(sycl::ext::oneapi::bfloat16, /)
546 #undef OP
547 #define OP(type, op) \
548  friend type operator op( \
549  const wi_element<sycl::ext::oneapi::bfloat16, NumRows, NumCols, Layout, \
550  Group> &lhs, \
551  const sycl::ext::oneapi::bfloat16 &rhs) { \
552  return type{static_cast<float>(__spirv_VectorExtractDynamic( \
553  lhs.M.spvm, lhs.idx)) op static_cast<float>(rhs)}; \
554  } \
555  friend type operator op( \
556  const sycl::ext::oneapi::bfloat16 &lhs, \
557  const wi_element<sycl::ext::oneapi::bfloat16, NumRows, NumCols, Layout, \
558  Group> &rhs) { \
559  return type{static_cast<float>(__spirv_VectorExtractDynamic( \
560  rhs.M.spvm, rhs.idx)) op static_cast<float>(lhs)}; \
561  }
562  OP(bool, ==)
563  OP(bool, !=)
564  OP(bool, <)
565  OP(bool, >)
566  OP(bool, <=)
567  OP(bool, >=)
568 #undef OP
569 #else // __SYCL_DEVICE_ONLY__
570 #define OP(type, op) \
571  friend type operator op(const wi_element<sycl::ext::oneapi::bfloat16, \
572  NumRows, NumCols, Layout, Group> &, \
573  const sycl::ext::oneapi::bfloat16 &) { \
574  throw runtime_error("joint matrix is not supported on host device.", \
575  PI_ERROR_INVALID_DEVICE); \
576  } \
577  friend type operator op( \
578  const sycl::ext::oneapi::bfloat16 &, \
579  const wi_element<sycl::ext::oneapi::bfloat16, NumRows, NumCols, Layout, \
580  Group> &) { \
581  throw runtime_error("joint matrix is not supported on host device.", \
582  PI_ERROR_INVALID_DEVICE); \
583  }
584  OP(sycl::ext::oneapi::bfloat16, +)
585  OP(sycl::ext::oneapi::bfloat16, -)
586  OP(sycl::ext::oneapi::bfloat16, *)
587  OP(sycl::ext::oneapi::bfloat16, /)
588  OP(bool, ==)
589  OP(bool, !=)
590  OP(bool, <)
591  OP(bool, >)
592  OP(bool, <=)
593  OP(bool, >=)
594 #undef OP
595 #endif // __SYCL_DEVICE_ONLY__
596 };
597 
598 template <typename T, size_t NumRows, size_t NumCols, matrix_layout Layout,
599  typename Group>
600 class wi_data {
601  joint_matrix<T, NumRows, NumCols, Layout, Group> &M;
602 
603 public:
605  size_t length() {
606 #ifdef __SYCL_DEVICE_ONLY__
607  return __spirv_JointMatrixWorkItemLengthINTEL(M.spvm);
608 #else
609  throw runtime_error("joint matrix is not supported on host device.",
610  PI_ERROR_INVALID_DEVICE);
611 #endif // __SYCL_DEVICE_ONLY__
612  }
615  }
616 };
617 
618 #undef SPV_MATRIX_LAYOUT_TRAITS
619 
620 } // namespace ext::oneapi::experimental::matrix
621 } // __SYCL_INLINE_VER_NAMESPACE(_V1)
622 } // namespace sycl
wi_element< T, NumRows, NumCols, Layout, Group > operator[](size_t i)
Definition: matrix-jit.hpp:613
wi_data(joint_matrix< T, NumRows, NumCols, Layout, Group > &Mat)
Definition: matrix-jit.hpp:604
wi_element & operator=(const wi_element< sycl::ext::oneapi::bfloat16, NumRows, NumCols, Layout, Group > &rhs)
Definition: matrix-jit.hpp:494
wi_element(joint_matrix< sycl::ext::oneapi::bfloat16, NumRows, NumCols, Layout, Group > &Mat, std::size_t i)
Definition: matrix-jit.hpp:460
wi_element & operator=(const wi_element< uint16_t, NumRows, NumCols, Layout, Group > &rhs)
Definition: matrix-jit.hpp:348
wi_element(joint_matrix< uint16_t, NumRows, NumCols, Layout, Group > &Mat, std::size_t i)
Definition: matrix-jit.hpp:314
wi_element & operator=(const wi_element< T, NumRows, NumCols, Layout, Group > &rhs)
Definition: matrix-jit.hpp:265
wi_element(joint_matrix< T, NumRows, NumCols, Layout, Group > &Mat, std::size_t i)
Definition: matrix-jit.hpp:232
Provides constructors for address space qualified and non address space qualified pointers to allow i...
Definition: multi_ptr.hpp:78
pointer get() const
Definition: multi_ptr.hpp:247
#define __SYCL_INLINE_VER_NAMESPACE(X)
#define __SYCL_ALWAYS_INLINE
#define SPV_MATRIX_LAYOUT_TRAITS(LAYOUT, SPV_LAYOUT)
Definition: matrix-jit.hpp:26
#define OP(op)
Definition: matrix-jit.hpp:570
__SYCL_ALWAYS_INLINE void joint_matrix_store(Group sg, joint_matrix< T, NumRows, NumCols, MatL, Group > &src, multi_ptr< T, Space, IsDecorated > res, size_t stride, matrix_layout MemL)
Definition: matrix-jit.hpp:127
__SYCL_ALWAYS_INLINE void joint_matrix_load(Group sg, joint_matrix< T, NumRows, NumCols, Layout, Group > &res, multi_ptr< T, Space, IsDecorated > src, size_t stride, matrix_layout MemL)
Definition: matrix-jit.hpp:76
__SYCL_ALWAYS_INLINE void joint_matrix_fill(Group sg, joint_matrix< T, NumRows, NumCols, Layout, Group > &res, const T2 v)
Definition: matrix-jit.hpp:206
__SYCL_ALWAYS_INLINE joint_matrix< T3, M, N, LayoutC, Group > joint_matrix_mad(Group sg, joint_matrix< T1, M, K, LayoutA, Group > &mA, joint_matrix< T2, K, N, LayoutB, Group > &mB, joint_matrix< T3, M, N, LayoutC, Group > &mC)
Definition: matrix-jit.hpp:175
std::enable_if_t< detail::is_bf16_storage_type< T >::value, T > fabs(T x)
---— Error handling, matching OpenCL plugin semantics.
Definition: access.hpp:14
__spv::__spirv_JointMatrixINTEL< T, NumRows, NumCols, spv_matrix_layout_traits< Layout >::value, spv_scope_traits< Group >::value > * spvm
Definition: matrix-jit.hpp:58
__SYCL_ALWAYS_INLINE wi_data< T, NumRows, NumCols, Layout, Group > get_wi_data()
Definition: matrix-jit.hpp:68