DPC++ Runtime
Runtime libraries for oneAPI DPC++
matrix-jit.hpp
Go to the documentation of this file.
1 //==---------------- matrix-jit.hpp - SYCL matrix --------------*- C++ -*---==//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 // ===--------------------------------------------------------------------=== //
8 
9 #pragma once
10 
11 #include <CL/__spirv/spirv_ops.hpp>
13 #include <CL/sycl/feature_test.hpp>
#include <cstring>
15 
17 namespace sycl {
18 namespace ext {
19 namespace oneapi {
20 namespace experimental::matrix {
21 
// Layout tags for joint matrices. Used in this file both as the `Layout`
// template parameter of joint_matrix and as the runtime `MemL` argument of
// joint_matrix_load / joint_matrix_store, where each enumerator is mapped to
// the corresponding __spv::MatrixLayout value.
22 enum class matrix_layout { row_major, col_major, packed_a, packed_b };
23 
24 template <matrix_layout Layout> struct spv_matrix_layout_traits {
26 };
27 
28 #define SPV_MATRIX_LAYOUT_TRAITS(LAYOUT, SPV_LAYOUT) \
29  template <> struct spv_matrix_layout_traits<LAYOUT> { \
30  static constexpr __spv::MatrixLayout value = SPV_LAYOUT; \
31  };
32 
33 SPV_MATRIX_LAYOUT_TRAITS(matrix_layout::row_major,
35 SPV_MATRIX_LAYOUT_TRAITS(matrix_layout::col_major,
39 
// Maps a SYCL group type to the __spv::Scope value passed to the joint matrix
// SPIR-V built-ins. The primary template is intentionally empty, so using a
// group type without a specialization fails to compile when `value` is read.
40 template <typename G> struct spv_scope_traits {};
// sycl::sub_group -> __spv::Scope::Subgroup.
41 template <> struct spv_scope_traits<sycl::sub_group> {
42  constexpr static auto value = __spv::Scope::Subgroup;
43 };
// sycl::group<D> of any dimensionality D -> __spv::Scope::Workgroup.
44 template <int D> struct spv_scope_traits<sycl::group<D>> {
45  constexpr static auto value = __spv::Scope::Workgroup;
46 };
47 
48 template <typename T, size_t NumRows, size_t NumCols,
49  matrix_layout Layout = matrix_layout::row_major,
50  typename Group = sycl::sub_group>
51 class wi_data;
52 
53 template <typename T, size_t NumRows, size_t NumCols,
54  matrix_layout Layout = matrix_layout::row_major,
55  typename Group = sycl::sub_group>
56 struct joint_matrix {
57 public:
60  joint_matrix(Group sg) {
61 #ifndef __SYCL_DEVICE_ONLY__
62  (void)sg;
63  throw runtime_error("joint matrix is not supported on host device.",
64  PI_ERROR_INVALID_DEVICE);
65 #endif // __SYCL_DEVICE_ONLY__
66  }
67 
71  }
72 };
73 
74 template <typename Group, typename T, size_t NumRows, size_t NumCols,
75  matrix_layout Layout = matrix_layout::row_major,
77 inline __SYCL_ALWAYS_INLINE void
80  multi_ptr<T, Space> src, size_t stride, matrix_layout MemL) {
81 #ifdef __SYCL_DEVICE_ONLY__
82  T *Ptr = src.get();
83  switch (MemL) {
84  default:
85  assert(false && "Invalid Memory Layout!");
86  case matrix_layout::row_major:
87  res.spvm =
88  __spirv_JointMatrixLoadINTEL<T, NumRows, NumCols,
90  Ptr, stride, __spv::MatrixLayout::RowMajor,
91  spv_scope_traits<Group>::value);
92  break;
93  case matrix_layout::col_major:
94  res.spvm =
95  __spirv_JointMatrixLoadINTEL<T, NumRows, NumCols,
98  spv_scope_traits<Group>::value);
99  break;
100  case matrix_layout::packed_a:
101  res.spvm =
102  __spirv_JointMatrixLoadINTEL<T, NumRows, NumCols,
104  Ptr, stride, __spv::MatrixLayout::PackedA,
105  spv_scope_traits<Group>::value);
106  break;
107  case matrix_layout::packed_b:
108  res.spvm =
109  __spirv_JointMatrixLoadINTEL<T, NumRows, NumCols,
111  Ptr, stride, __spv::MatrixLayout::PackedB,
112  spv_scope_traits<Group>::value);
113  break;
114  }
115 #else
116  (void)sg;
117  (void)res;
118  (void)src;
119  (void)stride;
120  (void)MemL;
121  throw runtime_error("joint matrix is not supported on host device.",
122  PI_ERROR_INVALID_DEVICE);
123 #endif // __SYCL_DEVICE_ONLY__
124 }
125 
126 template <typename Group, typename T, size_t NumRows, size_t NumCols,
127  matrix_layout MatL = matrix_layout::row_major,
128  access::address_space Space>
129 inline __SYCL_ALWAYS_INLINE void
132  multi_ptr<T, Space> res, size_t stride, matrix_layout MemL) {
133 #ifdef __SYCL_DEVICE_ONLY__
134  T *Ptr = res.get();
135  switch (MemL) {
136  default:
137  assert(false && "Invalid Memory Layout!");
138  case matrix_layout::row_major:
139  __spirv_JointMatrixStoreINTEL<T, NumRows, NumCols,
141  Ptr, src.spvm, stride, __spv::MatrixLayout::RowMajor,
142  spv_scope_traits<Group>::value);
143  break;
144  case matrix_layout::col_major:
145  __spirv_JointMatrixStoreINTEL<T, NumRows, NumCols,
147  Ptr, src.spvm, stride, __spv::MatrixLayout::ColumnMajor,
148  spv_scope_traits<Group>::value);
149  break;
150  case matrix_layout::packed_a:
151  __spirv_JointMatrixStoreINTEL<T, NumRows, NumCols,
153  Ptr, src.spvm, stride, __spv::MatrixLayout::PackedA,
154  spv_scope_traits<Group>::value);
155  break;
156  case matrix_layout::packed_b:
157  __spirv_JointMatrixStoreINTEL<T, NumRows, NumCols,
159  Ptr, src.spvm, stride, __spv::MatrixLayout::PackedB,
160  spv_scope_traits<Group>::value);
161  break;
162  }
163 #else
164  (void)sg;
165  (void)src;
166  (void)res;
167  (void)stride;
168  (void)MemL;
169  throw runtime_error("joint matrix is not supported on host device.",
170  PI_ERROR_INVALID_DEVICE);
171 #endif // __SYCL_DEVICE_ONLY__
172 }
173 
174 template <typename Group, typename T1, typename T2, typename T3, size_t M,
175  size_t K, size_t N, matrix_layout LayoutA, matrix_layout LayoutB,
176  matrix_layout LayoutC>
177 inline __SYCL_ALWAYS_INLINE joint_matrix<T3, M, N, LayoutC, Group>
181 #ifdef __SYCL_DEVICE_ONLY__
183  if constexpr (std::is_same<T1, uint16_t>::value &&
184  std::is_same<T2, uint16_t>::value &&
185  std::is_same<T3, float>::value)
186  res.spvm = __spirv_JointMatrixMadINTEL(mA.spvm, mB.spvm, mC.spvm);
187  else if constexpr (std::is_unsigned<T1>::value && std::is_unsigned<T2>::value)
188  res.spvm = __spirv_JointMatrixUUMadINTEL(mA.spvm, mB.spvm, mC.spvm);
189  else if constexpr (std::is_signed<T1>::value && std::is_unsigned<T2>::value)
190  res.spvm = __spirv_JointMatrixSUMadINTEL(mA.spvm, mB.spvm, mC.spvm);
191  else if constexpr (std::is_unsigned<T1>::value && std::is_signed<T2>::value)
192  res.spvm = __spirv_JointMatrixUSMadINTEL(mA.spvm, mB.spvm, mC.spvm);
193  else
194  res.spvm = __spirv_JointMatrixMadINTEL(mA.spvm, mB.spvm, mC.spvm);
195  return res;
196 #else
197  (void)sg;
198  (void)mA;
199  (void)mB;
200  (void)mC;
201  throw runtime_error("joint matrix is not supported on host device.",
202  PI_ERROR_INVALID_DEVICE);
203 #endif // __SYCL_DEVICE_ONLY__
204 }
205 
206 template <typename Group, typename T, size_t NumRows, size_t NumCols,
207  matrix_layout Layout, typename T2>
208 inline __SYCL_ALWAYS_INLINE void
211  const T2 v) {
212  // We kept the unused "sg" in joint_matrix_fill to match the other DPC++
213  // functions
214  (void)sg;
215 #ifdef __SYCL_DEVICE_ONLY__
216  res.spvm =
217  __spirv_CompositeConstruct<T, NumRows, NumCols,
219  static_cast<T>(v));
220 
221 #else
222  (void)res;
223  (void)v;
224 #endif // __SYCL_DEVICE_ONLY__
225 }
226 
227 template <typename T, size_t NumRows, size_t NumCols,
228  matrix_layout Layout = matrix_layout::row_major,
229  typename Group = sycl::sub_group>
230 class wi_element {
232  std::size_t idx;
233 
234 public:
236  std::size_t i)
237  : M(Mat), idx(i) {}
238  operator T() {
239 #ifdef __SYCL_DEVICE_ONLY__
240  return __spirv_VectorExtractDynamic(M.spvm, idx);
241 #else
242  throw runtime_error("joint matrix is not supported on host device.",
243  PI_ERROR_INVALID_DEVICE);
244 #endif // __SYCL_DEVICE_ONLY__
245  }
246 
247  explicit operator bool() {
248 #ifdef __SYCL_DEVICE_ONLY__
249  return __spirv_VectorExtractDynamic(M.spvm, idx) != static_cast<T>(0);
250 #else
251  throw runtime_error("joint matrix is not supported on host device.",
252  PI_ERROR_INVALID_DEVICE);
253 #endif // __SYCL_DEVICE_ONLY__
254  }
255 
256  template <typename T2> wi_element &operator=(const T2 &rhs) {
257 #ifdef __SYCL_DEVICE_ONLY__
258  M.spvm = __spirv_VectorInsertDynamic(M.spvm, static_cast<T>(rhs), idx);
259  return *this;
260 #else
261  (void)rhs;
262  throw runtime_error("joint matrix is not supported on host device.",
263  PI_ERROR_INVALID_DEVICE);
264 #endif // __SYCL_DEVICE_ONLY__
265  }
266 
267  wi_element &
269 #ifdef __SYCL_DEVICE_ONLY__
270  M.spvm = __spirv_VectorInsertDynamic(
271  M.spvm, __spirv_VectorExtractDynamic(rhs.M.spvm, rhs.idx), idx);
272  return *this;
273 #else
274  (void)rhs;
275  throw runtime_error("joint matrix is not supported on host device.",
276  PI_ERROR_INVALID_DEVICE);
277 #endif // __SYCL_DEVICE_ONLY__
278  }
279 
280 #if __SYCL_DEVICE_ONLY__
281 #define OP(op) \
282  template <typename T2> wi_element &operator op##=(const T2 &rhs) { \
283  M.spvm = __spirv_VectorInsertDynamic( \
284  M.spvm, \
285  static_cast<T>(__spirv_VectorExtractDynamic(M.spvm, idx) \
286  op static_cast<T>(rhs)), \
287  idx); \
288  return *this; \
289  }
290 #else // __SYCL_DEVICE_ONLY__
291 #define OP(op) \
292  template <typename T2> wi_element &operator op##=(const T2 &rhs) { \
293  (void)rhs; \
294  throw runtime_error("joint matrix is not supported on host device.", \
295  PI_ERROR_INVALID_DEVICE); \
296  }
297 #endif // __SYCL_DEVICE_ONLY__
298  OP(+)
299  OP(-)
300  OP(*)
301  OP(/)
302 #undef OP
303 };
304 
305 // Note that similarly to the other matrix functions, uint16_t is used here to
306 // represent bf16 type. Since the AMX and DPAS implementations don't support
307 // uint16_t, this interpretation is possible. This design choice was made before
308 // the introduction of SYCL experimental bfloat16 type. Our plan is to move
309 // towards using the SYCL bfloat16. But since it is still experimental, we will
310 // probably keep both uint16 interpretation and SYCL bfloat16.
311 template <size_t NumRows, size_t NumCols, matrix_layout Layout, typename Group>
312 class wi_element<uint16_t, NumRows, NumCols, Layout, Group> {
314  std::size_t idx;
315 
316 public:
318  std::size_t i)
319  : M(Mat), idx(i) {}
320  operator uint16_t() {
321 #ifdef __SYCL_DEVICE_ONLY__
322  return __spirv_VectorExtractDynamic(M.spvm, idx);
323 #else
324  throw runtime_error("joint matrix is not supported on host device.",
325  PI_ERROR_INVALID_DEVICE);
326 #endif // __SYCL_DEVICE_ONLY__
327  }
328 
329  explicit operator bool() {
330 #ifdef __SYCL_DEVICE_ONLY__
331  return std::fabs(make_fp32(__spirv_VectorExtractDynamic(M.spvm, idx))) >=
332  std::numeric_limits<float>::epsilon();
333 #else
334  throw runtime_error("joint matrix is not supported on host device.",
335  PI_ERROR_INVALID_DEVICE);
336 #endif // __SYCL_DEVICE_ONLY__
337  }
338 
339  wi_element &operator=(const uint16_t &rhs) {
340 #ifdef __SYCL_DEVICE_ONLY__
341  M.spvm = __spirv_VectorInsertDynamic(M.spvm, rhs, idx);
342  return *this;
343 #else
344  (void)rhs;
345  throw runtime_error("joint matrix is not supported on host device.",
346  PI_ERROR_INVALID_DEVICE);
347 #endif // __SYCL_DEVICE_ONLY__
348  }
349 
350  wi_element &
352 #ifdef __SYCL_DEVICE_ONLY__
353  M.spvm = __spirv_VectorInsertDynamic(
354  M.spvm, __spirv_VectorExtractDynamic(rhs.M.spvm, rhs.idx), idx);
355  return *this;
356 #else
357  (void)rhs;
358  throw runtime_error("joint matrix is not supported on host device.",
359  PI_ERROR_INVALID_DEVICE);
360 #endif // __SYCL_DEVICE_ONLY__
361  }
362 
363  // We use here the following functions for conversion (bf16=>fp32 and
364  // fp32=>bf16). This is a workaround until we are able to use
365  // __spirv_ConvertFToBF16INTEL and __spirv_ConvertBF16ToFINTEL once these are
366  // supported in the CPU backend
/// Widens a bf16 bit pattern to the fp32 value it encodes.
///
/// The 16 input bits become the upper half of a 32-bit word and the lower
/// half is zero-filled; that word is then reinterpreted as a float.
///
/// The bits are copied with std::memcpy instead of being read through a
/// reinterpret_cast'ed `float *`: accessing an integer object through a float
/// lvalue violates the strict-aliasing rules and is undefined behaviour.
///
/// \param x bf16 value stored as its raw 16-bit pattern.
/// \return the float whose upper 16 bits equal \p x and whose lower 16 bits
///         are zero.
static float make_fp32(uint16_t x) {
  static_assert(sizeof(float) == sizeof(uint32_t),
                "make_fp32 requires a 32-bit float");
  const uint32_t bits = static_cast<uint32_t>(x) << 16;
  float res;
  std::memcpy(&res, &bits, sizeof(res));
  return res;
}
373 
/// Narrows an fp32 value to a bf16 bit pattern by truncation.
///
/// Returns the upper 16 bits of the float's 32-bit representation; the lower
/// 16 bits are discarded without rounding.
///
/// The float's bits are copied into an unsigned integer with std::memcpy.
/// Modifying the float through an `int *` is a strict-aliasing violation
/// (undefined behaviour), and right-shifting a negative signed value is
/// implementation-defined before C++20; the unsigned shift used here is fully
/// defined and yields the same 16 bits.
///
/// \param x value to convert.
/// \return the upper 16 bits of \p x's bit pattern.
static uint16_t make_bf16(float x) {
  static_assert(sizeof(float) == sizeof(uint32_t),
                "make_bf16 requires a 32-bit float");
  uint32_t bits;
  std::memcpy(&bits, &x, sizeof(bits));
  return static_cast<uint16_t>(bits >> 16);
}
379 
380 #if __SYCL_DEVICE_ONLY__
381 #define OP(op) \
382  wi_element &operator op##=(const uint16_t &rhs) { \
383  M.spvm = __spirv_VectorInsertDynamic( \
384  M.spvm, \
385  make_bf16(make_fp32(__spirv_VectorExtractDynamic(M.spvm, idx) \
386  op make_fp32(rhs))), \
387  idx); \
388  return *this; \
389  }
390 #else // __SYCL_DEVICE_ONLY__
391 #define OP(op) \
392  wi_element &operator op##=(const uint16_t &rhs) { \
393  (void)rhs; \
394  throw runtime_error("joint matrix is not supported on host device.", \
395  PI_ERROR_INVALID_DEVICE); \
396  }
397 #endif // __SYCL_DEVICE_ONLY__
398  OP(+)
399  OP(-)
400  OP(*)
401  OP(/)
402 #undef OP
403 
404  template <typename T1, typename T2> struct Converter {
405  static T2 convert(const T1 &from) { return static_cast<T2>(from); }
406  };
407 
408  template <typename T> struct Converter<T, uint16_t> {
409  static uint16_t convert(const T &from) { return make_bf16(from); }
410  };
411 #if __SYCL_DEVICE_ONLY__
// Device-side binary operators between a matrix element and a uint16_t (bf16
// bit pattern) scalar. Both operands are widened with make_fp32, combined with
// `op` in the order written (lhs op rhs), and the fp32 result is converted to
// `type` through Converter<input_type, type>. Keeping the written operand
// order matters for the non-commutative operators (-, /, <, >, <=, >=): the
// scalar-on-the-left overload must compute `scalar op element`, not the
// reverse.
#define OP(input_type, type, op) \
  friend type operator op( \
      const wi_element<uint16_t, NumRows, NumCols, Layout, Group> &lhs, \
      const uint16_t &rhs) { \
    return Converter<input_type, type>::convert( \
        make_fp32(__spirv_VectorExtractDynamic(lhs.M.spvm, lhs.idx)) \
            op make_fp32(rhs)); \
  } \
  friend type operator op( \
      const uint16_t &lhs, \
      const wi_element<uint16_t, NumRows, NumCols, Layout, Group> &rhs) { \
    return Converter<input_type, type>::convert( \
        make_fp32(lhs) op make_fp32( \
            __spirv_VectorExtractDynamic(rhs.M.spvm, rhs.idx))); \
  }
425 #else // __SYCL_DEVICE_ONLY__
426 #define OP(input_type, type, op) \
427  friend type operator op( \
428  const wi_element<uint16_t, NumRows, NumCols, Layout, Group> &lhs, \
429  const uint16_t &rhs) { \
430  (void)lhs; \
431  (void)rhs; \
432  throw runtime_error("joint matrix is not supported on host device.", \
433  PI_ERROR_INVALID_DEVICE); \
434  } \
435  friend type operator op( \
436  const uint16_t &lhs, \
437  const wi_element<uint16_t, NumRows, NumCols, Layout, Group> &rhs) { \
438  (void)lhs; \
439  (void)rhs; \
440  throw runtime_error("joint matrix is not supported on host device.", \
441  PI_ERROR_INVALID_DEVICE); \
442  }
443 #endif // __SYCL_DEVICE_ONLY__
444  OP(float, uint16_t, +)
445  OP(float, uint16_t, -)
446  OP(float, uint16_t, *)
447  OP(float, uint16_t, /)
448  OP(bool, bool, ==)
449  OP(bool, bool, !=)
450  OP(bool, bool, <)
451  OP(bool, bool, >)
452  OP(bool, bool, <=)
453  OP(bool, bool, >=)
454 #undef OP
455 };
456 
457 template <size_t NumRows, size_t NumCols, matrix_layout Layout, typename Group>
459  Layout, Group> {
461  Layout, Group> &M;
462  std::size_t idx;
463 
464 public:
466  NumCols, Layout, Group> &Mat,
467  std::size_t i)
468  : M(Mat), idx(i) {}
470 #ifdef __SYCL_DEVICE_ONLY__
471  return __spirv_VectorExtractDynamic(M.spvm, idx);
472 #else
473  throw runtime_error("joint matrix is not supported on host device.",
474  PI_ERROR_INVALID_DEVICE);
475 #endif // __SYCL_DEVICE_ONLY__
476  }
477 
478  explicit operator bool() {
479 #ifdef __SYCL_DEVICE_ONLY__
480  return std::fabs(static_cast<float>(__spirv_VectorExtractDynamic(
481  M.spvm, idx))) >= std::numeric_limits<float>::epsilon();
482 #else
483  throw runtime_error("joint matrix is not supported on host device.",
484  PI_ERROR_INVALID_DEVICE);
485 #endif // __SYCL_DEVICE_ONLY__
486  }
487 
489 #ifdef __SYCL_DEVICE_ONLY__
490  M.spvm = __spirv_VectorInsertDynamic(M.spvm, rhs, idx);
491  return *this;
492 #else
493  (void)rhs;
494  throw runtime_error("joint matrix is not supported on host device.",
495  PI_ERROR_INVALID_DEVICE);
496 #endif // __SYCL_DEVICE_ONLY__
497  }
498 
499  wi_element &
501  NumCols, Layout, Group> &rhs) {
502 #ifdef __SYCL_DEVICE_ONLY__
503  M.spvm = __spirv_VectorInsertDynamic(
504  M.spvm, __spirv_VectorExtractDynamic(rhs.M.spvm, rhs.idx), idx);
505  return *this;
506 #else
507  (void)rhs;
508  throw runtime_error("joint matrix is not supported on host device.",
509  PI_ERROR_INVALID_DEVICE);
510 #endif // __SYCL_DEVICE_ONLY__
511  }
512 
513 #if __SYCL_DEVICE_ONLY__
514 #define OP(opassign, op) \
515  wi_element &operator opassign( \
516  const sycl::ext::oneapi::experimental::bfloat16 &rhs) { \
517  M.spvm = __spirv_VectorInsertDynamic( \
518  M.spvm, __spirv_VectorExtractDynamic(M.spvm, idx) op rhs, idx); \
519  return *this; \
520  }
521 #else // __SYCL_DEVICE_ONLY__
522 #define OP(opassign, op) \
523  wi_element &operator opassign( \
524  const sycl::ext::oneapi::experimental::bfloat16 &rhs) { \
525  (void)rhs; \
526  throw runtime_error("joint matrix is not supported on host device.", \
527  PI_ERROR_INVALID_DEVICE); \
528  }
529 #endif // __SYCL_DEVICE_ONLY__
530  OP(+=, +)
531  OP(-=, -)
532  OP(*=, *)
533  OP(/=, /)
534 #undef OP
535 
536 #if __SYCL_DEVICE_ONLY__
// Device-side arithmetic operators between a bfloat16 matrix element and a
// bfloat16 scalar. The element is read with __spirv_VectorExtractDynamic and
// combined with the scalar in the order written (lhs op rhs), so the
// scalar-on-the-left overload computes `scalar op element`; this matters for
// the non-commutative operators - and /.
#define OP(type, op) \
  friend type operator op( \
      const wi_element<sycl::ext::oneapi::experimental::bfloat16, NumRows, \
                       NumCols, Layout, Group> &lhs, \
      const sycl::ext::oneapi::experimental::bfloat16 &rhs) { \
    return __spirv_VectorExtractDynamic(lhs.M.spvm, lhs.idx) op rhs; \
  } \
  friend type operator op( \
      const sycl::ext::oneapi::experimental::bfloat16 &lhs, \
      const wi_element<sycl::ext::oneapi::experimental::bfloat16, NumRows, \
                       NumCols, Layout, Group> &rhs) { \
    return lhs op __spirv_VectorExtractDynamic(rhs.M.spvm, rhs.idx); \
  }
551  OP(sycl::ext::oneapi::experimental::bfloat16, -)
552  OP(sycl::ext::oneapi::experimental::bfloat16, *)
553  OP(sycl::ext::oneapi::experimental::bfloat16, /)
554 #undef OP
// Device-side comparison operators between a bfloat16 matrix element and a
// bfloat16 scalar. Both operands are converted to float and compared in the
// order written (lhs op rhs), so the scalar-on-the-left overload evaluates
// `scalar op element`; this matters for <, >, <= and >=.
#define OP(type, op) \
  friend type operator op( \
      const wi_element<sycl::ext::oneapi::experimental::bfloat16, NumRows, \
                       NumCols, Layout, Group> &lhs, \
      const sycl::ext::oneapi::experimental::bfloat16 &rhs) { \
    return type{static_cast<float>(__spirv_VectorExtractDynamic( \
        lhs.M.spvm, lhs.idx)) op static_cast<float>(rhs)}; \
  } \
  friend type operator op( \
      const sycl::ext::oneapi::experimental::bfloat16 &lhs, \
      const wi_element<sycl::ext::oneapi::experimental::bfloat16, NumRows, \
                       NumCols, Layout, Group> &rhs) { \
    return type{static_cast<float>(lhs) op static_cast<float>( \
        __spirv_VectorExtractDynamic(rhs.M.spvm, rhs.idx))}; \
  }
570  OP(bool, ==)
571  OP(bool, !=)
572  OP(bool, <)
573  OP(bool, >)
574  OP(bool, <=)
575  OP(bool, >=)
576 #undef OP
577 #else // __SYCL_DEVICE_ONLY__
578 #define OP(type, op) \
579  friend type operator op( \
580  const wi_element<sycl::ext::oneapi::experimental::bfloat16, NumRows, \
581  NumCols, Layout, Group> &, \
582  const sycl::ext::oneapi::experimental::bfloat16 &) { \
583  throw runtime_error("joint matrix is not supported on host device.", \
584  PI_ERROR_INVALID_DEVICE); \
585  } \
586  friend type operator op( \
587  const sycl::ext::oneapi::experimental::bfloat16 &, \
588  const wi_element<sycl::ext::oneapi::experimental::bfloat16, NumRows, \
589  NumCols, Layout, Group> &) { \
590  throw runtime_error("joint matrix is not supported on host device.", \
591  PI_ERROR_INVALID_DEVICE); \
592  }
593  OP(sycl::ext::oneapi::experimental::bfloat16, +)
594  OP(sycl::ext::oneapi::experimental::bfloat16, -)
595  OP(sycl::ext::oneapi::experimental::bfloat16, *)
596  OP(sycl::ext::oneapi::experimental::bfloat16, /)
597  OP(bool, ==)
598  OP(bool, !=)
599  OP(bool, <)
600  OP(bool, >)
601  OP(bool, <=)
602  OP(bool, >=)
603 #undef OP
604 #endif // __SYCL_DEVICE_ONLY__
605 };
606 
607 template <typename T, size_t NumRows, size_t NumCols, matrix_layout Layout,
608  typename Group>
609 class wi_data {
610  joint_matrix<T, NumRows, NumCols, Layout, Group> &M;
611 
612 public:
614  size_t length() {
615 #ifdef __SYCL_DEVICE_ONLY__
616  return __spirv_JointMatrixWorkItemLengthINTEL(M.spvm);
617 #else
618  throw runtime_error("joint matrix is not supported on host device.",
619  PI_ERROR_INVALID_DEVICE);
620 #endif // __SYCL_DEVICE_ONLY__
621  }
624  }
625 };
626 
627 } // namespace experimental::matrix
628 } // namespace oneapi
629 } // namespace ext
630 } // namespace sycl
631 } // __SYCL_INLINE_NAMESPACE(cl)
spirv_ops.hpp
cl::sycl::ext::oneapi::experimental::matrix::joint_matrix::spvm
__spv::__spirv_JointMatrixINTEL< T, NumRows, NumCols, spv_matrix_layout_traits< Layout >::value > * spvm
Definition: matrix-jit.hpp:59
__spv::Scope::Workgroup
@ Workgroup
Definition: spirv_types.hpp:30
T
__spv::MatrixLayout::RowMajor
@ RowMajor
cl::sycl::ext::oneapi::experimental::matrix::wi_element< uint16_t, NumRows, NumCols, Layout, Group >::make_fp32
static float make_fp32(uint16_t x)
Definition: matrix-jit.hpp:367
cl::sycl::ext::oneapi::experimental::matrix::joint_matrix_store
__SYCL_ALWAYS_INLINE void joint_matrix_store(Group sg, joint_matrix< T, NumRows, NumCols, MatL, Group > &src, multi_ptr< T, Space > res, size_t stride, matrix_layout MemL)
Definition: matrix-jit.hpp:130
defines_elementary.hpp
cl::sycl::multi_ptr::get
pointer_t get() const
Definition: multi_ptr.hpp:234
OP
#define OP(op)
Definition: matrix-jit.hpp:578
__spv::MatrixLayout::PackedA
@ PackedA
cl::sycl::group
Encapsulates all functionality required to represent a particular work-group within a parallel execut...
Definition: helpers.hpp:29
cl::sycl::ext::oneapi::experimental::matrix::wi_element< uint16_t, NumRows, NumCols, Layout, Group >::make_bf16
static uint16_t make_bf16(float x)
Definition: matrix-jit.hpp:374
cl::sycl::ext::oneapi::experimental::matrix::matrix_layout
matrix_layout
Definition: matrix-jit.hpp:22
cl::sycl::ext::oneapi::experimental::matrix::wi_element< uint16_t, NumRows, NumCols, Layout, Group >::Converter::convert
static T2 convert(const T1 &from)
Definition: matrix-jit.hpp:405
__spv::__spirv_JointMatrixINTEL
Definition: spirv_types.hpp:134
__spv::MatrixLayout::PackedB
@ PackedB
sycl
Definition: invoke_simd.hpp:68
SPV_MATRIX_LAYOUT_TRAITS
#define SPV_MATRIX_LAYOUT_TRAITS(LAYOUT, SPV_LAYOUT)
Definition: matrix-jit.hpp:28
cl::sycl::multi_ptr
Provides constructors for address space qualified and non address space qualified pointers to allow i...
Definition: atomic.hpp:33
cl::sycl::ext::oneapi::experimental::matrix::joint_matrix_load
__SYCL_ALWAYS_INLINE void joint_matrix_load(Group sg, joint_matrix< T, NumRows, NumCols, Layout, Group > &res, multi_ptr< T, Space > src, size_t stride, matrix_layout MemL)
Definition: matrix-jit.hpp:78
cl::sycl::ext::oneapi::experimental::matrix::joint_matrix::joint_matrix
joint_matrix(Group sg)
Definition: matrix-jit.hpp:60
cl::sycl::ext::oneapi::experimental::matrix::spv_matrix_layout_traits
Definition: matrix-jit.hpp:24
cl::sycl::ext::oneapi::experimental::matrix::wi_element::operator=
wi_element & operator=(const wi_element< T, NumRows, NumCols, Layout, Group > &rhs)
Definition: matrix-jit.hpp:268
cl::sycl::ext::oneapi::experimental::matrix::joint_matrix_mad
__SYCL_ALWAYS_INLINE joint_matrix< T3, M, N, LayoutC, Group > joint_matrix_mad(Group sg, joint_matrix< T1, M, K, LayoutA, Group > &mA, joint_matrix< T2, K, N, LayoutB, Group > &mB, joint_matrix< T3, M, N, LayoutC, Group > &mC)
Definition: matrix-jit.hpp:178
cl::sycl::ext::oneapi::experimental::matrix::joint_matrix::get_wi_data
__SYCL_ALWAYS_INLINE wi_data< T, NumRows, NumCols, Layout, Group > get_wi_data()
Definition: matrix-jit.hpp:69
cl::sycl::ext::oneapi::experimental::matrix::wi_element< uint16_t, NumRows, NumCols, Layout, Group >::operator=
wi_element & operator=(const wi_element< uint16_t, NumRows, NumCols, Layout, Group > &rhs)
Definition: matrix-jit.hpp:351
cl::sycl::ext::oneapi::experimental::matrix::wi_element::operator=
wi_element & operator=(const T2 &rhs)
Definition: matrix-jit.hpp:256
cl::sycl::ext::oneapi::experimental::matrix::wi_element< uint16_t, NumRows, NumCols, Layout, Group >
Definition: matrix-jit.hpp:312
__SYCL_ALWAYS_INLINE
#define __SYCL_ALWAYS_INLINE
Definition: defines_elementary.hpp:29
__spv::Scope::Subgroup
@ Subgroup
Definition: spirv_types.hpp:31
cl::sycl::fabs
detail::enable_if_t< detail::is_genfloat< T >::value, T > fabs(T x) __NOEXC
Definition: builtins.hpp:178
cl
We provide new interfaces for matrix multiply in this patch:
Definition: access.hpp:13
bfloat16.hpp
cl::sycl::ext::oneapi::experimental::matrix::wi_data::operator[]
wi_element< T, NumRows, NumCols, Layout, Group > operator[](size_t i)
Definition: matrix-jit.hpp:622
cl::sycl::ext::oneapi::experimental::matrix::joint_matrix_fill
__SYCL_ALWAYS_INLINE void joint_matrix_fill(Group sg, joint_matrix< T, NumRows, NumCols, Layout, Group > &res, const T2 v)
Definition: matrix-jit.hpp:209
cl::sycl::ext::oneapi::experimental::bfloat16
Definition: bfloat16.hpp:20
cl::sycl::ext::oneapi::experimental::matrix::wi_element
Definition: matrix-jit.hpp:230
cl::sycl::access::address_space
address_space
Definition: access.hpp:45
cl::sycl::ext::oneapi::experimental::matrix::wi_data::length
size_t length()
Definition: matrix-jit.hpp:614
cl::sycl::ext::oneapi::experimental::matrix::wi_element< sycl::ext::oneapi::experimental::bfloat16, NumRows, NumCols, Layout, Group >::operator=
wi_element & operator=(const wi_element< sycl::ext::oneapi::experimental::bfloat16, NumRows, NumCols, Layout, Group > &rhs)
Definition: matrix-jit.hpp:500
cl::sycl::ext::oneapi::sub_group
Definition: sub_group.hpp:108
__spv::MatrixLayout::ColumnMajor
@ ColumnMajor
cl::sycl::ext::oneapi::experimental::matrix::wi_element< sycl::ext::oneapi::experimental::bfloat16, NumRows, NumCols, Layout, Group >::wi_element
wi_element(joint_matrix< sycl::ext::oneapi::experimental::bfloat16, NumRows, NumCols, Layout, Group > &Mat, std::size_t i)
Definition: matrix-jit.hpp:465
cl::sycl::ext::oneapi::experimental::matrix::wi_element< sycl::ext::oneapi::experimental::bfloat16, NumRows, NumCols, Layout, Group >::operator=
wi_element & operator=(const sycl::ext::oneapi::experimental::bfloat16 &rhs)
Definition: matrix-jit.hpp:488
cl::sycl::ext::oneapi::experimental::matrix::wi_element::wi_element
wi_element(joint_matrix< T, NumRows, NumCols, Layout, Group > &Mat, std::size_t i)
Definition: matrix-jit.hpp:235
__spv::MatrixLayout
MatrixLayout
Definition: spirv_types.hpp:111
cl::sycl::ext::oneapi::experimental::matrix::joint_matrix
Definition: matrix-jit.hpp:56
feature_test.hpp
cl::sycl::ext::oneapi::experimental::matrix::wi_data::wi_data
wi_data(joint_matrix< T, NumRows, NumCols, Layout, Group > &Mat)
Definition: matrix-jit.hpp:613
cl::sycl::ext::oneapi::experimental::matrix::wi_element< uint16_t, NumRows, NumCols, Layout, Group >::Converter< T, uint16_t >::convert
static uint16_t convert(const T &from)
Definition: matrix-jit.hpp:409
cl::sycl::ext::oneapi::experimental::matrix::wi_data
Definition: matrix-jit.hpp:51
cl::sycl::ext::oneapi::experimental::matrix::wi_element< uint16_t, NumRows, NumCols, Layout, Group >::wi_element
wi_element(joint_matrix< uint16_t, NumRows, NumCols, Layout, Group > &Mat, std::size_t i)
Definition: matrix-jit.hpp:317
cl::sycl::ext::oneapi::experimental::matrix::wi_element< uint16_t, NumRows, NumCols, Layout, Group >::operator=
wi_element & operator=(const uint16_t &rhs)
Definition: matrix-jit.hpp:339
__SYCL_INLINE_NAMESPACE
#define __SYCL_INLINE_NAMESPACE(X)
Definition: defines_elementary.hpp:12