28 #include <type_traits>
31 inline namespace _V1 {
34 namespace experimental {
41 #define SPV_MATRIX_LAYOUT_TRAITS(LAYOUT, SPV_LAYOUT) \
42 template <> struct spv_matrix_layout_traits<LAYOUT> { \
43 static constexpr __spv::MatrixLayout value = SPV_LAYOUT; \
55 #define SPV_MATRIX_USE_TRAITS(USE, SPV_USE) \
56 template <> struct spv_matrix_use_traits<USE> { \
57 static constexpr __spv::MatrixUse value = SPV_USE; \
73 template <
typename Group,
typename T,
use Use,
size_t Rows,
size_t Cols,
94 using namespace sycl::ext::oneapi::experimental::matrix;
97 template <
typename T,
size_t NumRows,
size_t NumCols,
100 sycl::ext::oneapi::experimental::matrix::layout::dynamic,
112 Group, T, Use, NumRows, NumCols, Layout> &Mat,
117 #if defined(__SYCL_DEVICE_ONLY__)
118 #ifndef __SPIRV_USE_COOPERATIVE_MATRIX
119 __ocl_vec_t<uint32_t, 2> coord =
120 __spirv_JointMatrixGetElementCoordINTEL(M.spvm, idx);
122 __ocl_vec_t<uint32_t, 2> coord =
123 __spirv_JointMatrixGetElementCoordINTEL(M.spvm, idx);
125 const size_t row = coord[0];
126 const size_t col = coord[1];
130 "joint matrix is not supported on host.");
135 #ifdef __SYCL_DEVICE_ONLY__
136 #ifndef __SPIRV_USE_COOPERATIVE_MATRIX
139 spv_matrix_use_traits<Use>::value,
140 spv_matrix_layout_traits<Layout>::value,
141 spv_scope_traits<Group>::value>(M.spvm,
146 spv_matrix_use_traits<Use>::value,
147 spv_scope_traits<Group>::value>(&M.spvm, idx);
153 "joint matrix is not supported on host.");
157 explicit operator bool() {
158 #ifdef __SYCL_DEVICE_ONLY__
159 #ifndef __SPIRV_USE_COOPERATIVE_MATRIX
162 spv_matrix_use_traits<Use>::value,
163 spv_matrix_layout_traits<Layout>::value,
164 spv_scope_traits<Group>::value>(
169 spv_matrix_use_traits<Use>::value,
170 spv_scope_traits<Group>::value>(&M.spvm, idx);
175 "joint matrix is not supported on host.");
180 #ifdef __SYCL_DEVICE_ONLY__
181 #ifndef __SPIRV_USE_COOPERATIVE_MATRIX
182 M.spvm = __spirv_VectorInsertDynamic(
187 spv_matrix_use_traits<Use>::value,
188 spv_scope_traits<Group>::value>(&M.spvm, idx);
195 "joint matrix is not supported on host.");
201 #ifdef __SYCL_DEVICE_ONLY__
202 #ifndef __SPIRV_USE_COOPERATIVE_MATRIX
203 M.spvm = __spirv_VectorInsertDynamic(
206 spv_matrix_use_traits<Use>::value,
207 spv_matrix_layout_traits<Layout>::value,
208 spv_scope_traits<Group>::value>(rhs.M.spvm,
214 spv_matrix_use_traits<Use>::value,
215 spv_scope_traits<Group>::value>(&rhs.M.spvm,
219 spv_matrix_use_traits<Use>::value,
220 spv_scope_traits<Group>::value>(&M.spvm, idx);
221 *InsertP = *ExtractP;
227 "joint matrix is not supported on host.");
231 #if __SYCL_DEVICE_ONLY__
232 #ifndef __SPIRV_USE_COOPERATIVE_MATRIX
234 template <typename T2> wi_element &operator op##=(const T2 & rhs) { \
235 M.spvm = __spirv_VectorInsertDynamic( \
237 static_cast<storage_element_type>( \
238 __spirv_VectorExtractDynamic< \
239 storage_element_type, T, NumRows, NumCols, \
240 spv_matrix_use_traits<Use>::value, \
241 spv_matrix_layout_traits<Layout>::value, \
242 spv_scope_traits<Group>::value>(M.spvm, idx) \
243 op static_cast<storage_element_type>(rhs)), \
249 template <typename T2> wi_element &operator op##=(const T2 & rhs) { \
250 storage_element_type *ExtractP = \
251 __spirv_AccessChain<storage_element_type, T, NumRows, NumCols, \
252 spv_matrix_use_traits<Use>::value, \
253 spv_scope_traits<Group>::value>(&rhs.M.spvm, \
255 storage_element_type *InsertP = \
256 __spirv_AccessChain<storage_element_type, T, NumRows, NumCols, \
257 spv_matrix_use_traits<Use>::value, \
258 spv_scope_traits<Group>::value>(&M.spvm, idx); \
259 *InsertP = *ExtractP op static_cast<storage_element_type>(rhs); \
265 template <typename T2> wi_element &operator op##=(const T2 & rhs) { \
267 throw exception(make_error_code(errc::runtime), \
268 "joint matrix is not supported on host."); \
278 template <
size_t NumRows,
size_t NumCols,
296 #if defined(__SYCL_DEVICE_ONLY__)
297 #ifndef __SPIRV_USE_COOPERATIVE_MATRIX
298 __ocl_vec_t<uint32_t, 2> coord =
299 __spirv_JointMatrixGetElementCoordINTEL(M.spvm, idx);
301 __ocl_vec_t<uint32_t, 2> coord =
302 __spirv_JointMatrixGetElementCoordINTEL(M.spvm, idx);
304 const uint32_t row = coord[0];
305 const uint32_t col = coord[1];
309 "joint matrix is not supported on host.");
314 #ifdef __SYCL_DEVICE_ONLY__
315 #ifndef __SPIRV_USE_COOPERATIVE_MATRIX
316 return __spirv_VectorExtractDynamic<
318 NumCols, spv_matrix_use_traits<Use>::value,
319 spv_matrix_layout_traits<Layout>::value,
320 spv_scope_traits<Group>::value>(M.spvm, idx);
325 spv_matrix_use_traits<Use>::value,
326 spv_scope_traits<Group>::value>(&M.spvm, idx);
331 "joint matrix is not supported on host.");
335 explicit operator bool() {
336 #ifdef __SYCL_DEVICE_ONLY__
337 #ifndef __SPIRV_USE_COOPERATIVE_MATRIX
339 __spirv_VectorExtractDynamic<
341 NumRows, NumCols, spv_matrix_use_traits<Use>::value,
342 spv_matrix_layout_traits<Layout>::value,
343 spv_scope_traits<Group>::value>(M.spvm, idx))) >=
344 std::numeric_limits<float>::epsilon();
349 spv_matrix_use_traits<Use>::value,
350 spv_scope_traits<Group>::value>(&M.spvm, idx);
352 return sycl::fabs(
static_cast<float>(Elem)) >=
353 std::numeric_limits<float>::epsilon();
357 "joint matrix is not supported on host.");
362 #ifdef __SYCL_DEVICE_ONLY__
363 #ifndef __SPIRV_USE_COOPERATIVE_MATRIX
364 M.spvm = __spirv_VectorInsertDynamic(M.spvm, rhs, idx);
369 spv_matrix_use_traits<Use>::value,
370 spv_scope_traits<Group>::value>(&M.spvm, idx);
377 "joint matrix is not supported on host.");
382 NumCols, Use, Layout, Group> &rhs) {
383 #ifdef __SYCL_DEVICE_ONLY__
384 #ifndef __SPIRV_USE_COOPERATIVE_MATRIX
385 M.spvm = __spirv_VectorInsertDynamic(
389 NumCols, spv_matrix_use_traits<Use>::value,
390 spv_matrix_layout_traits<Layout>::value,
391 spv_scope_traits<Group>::value>(rhs.M.spvm,
399 spv_matrix_use_traits<Use>::value,
400 spv_scope_traits<Group>::value>(&rhs.M.spvm,
405 spv_matrix_use_traits<Use>::value,
406 spv_scope_traits<Group>::value>(&M.spvm, idx);
407 *InsertP = *ExtractP;
413 "joint matrix is not supported on host.");
417 #if __SYCL_DEVICE_ONLY__
418 #ifndef __SPIRV_USE_COOPERATIVE_MATRIX
419 #define OP(opassign, op) \
420 wi_element &operator opassign(const sycl::ext::oneapi::bfloat16 & rhs) { \
421 M.spvm = __spirv_VectorInsertDynamic( \
423 __spirv_VectorExtractDynamic< \
424 sycl::ext::oneapi::bfloat16, sycl::ext::oneapi::bfloat16, NumRows, \
425 NumCols, spv_matrix_use_traits<Use>::value, \
426 spv_matrix_layout_traits<Layout>::value, \
427 spv_scope_traits<Group>::value>(M.spvm, idx) op rhs, \
432 #define OP(opassign, op) \
433 wi_element &operator opassign(const sycl::ext::oneapi::bfloat16 & rhs) { \
434 sycl::ext::oneapi::bfloat16 *ExtractP = \
435 __spirv_AccessChain<sycl::ext::oneapi::bfloat16, \
436 sycl::ext::oneapi::bfloat16, NumRows, NumCols, \
437 spv_matrix_use_traits<Use>::value, \
438 spv_scope_traits<Group>::value>(&M.spvm, idx); \
439 sycl::ext::oneapi::bfloat16 *InsertP = \
440 __spirv_AccessChain<sycl::ext::oneapi::bfloat16, \
441 sycl::ext::oneapi::bfloat16, NumRows, NumCols, \
442 spv_matrix_use_traits<Use>::value, \
443 spv_scope_traits<Group>::value>(&M.spvm, idx); \
444 *InsertP = *ExtractP op rhs; \
449 #define OP(opassign, op) \
450 wi_element &operator opassign(const sycl::ext::oneapi::bfloat16 & rhs) { \
452 throw exception(make_error_code(errc::runtime), \
453 "joint matrix is not supported on host."); \
462 #if __SYCL_DEVICE_ONLY__
463 #ifndef __SPIRV_USE_COOPERATIVE_MATRIX
464 #define OP(type, op) \
465 friend type operator op( \
466 const wi_element<sycl::ext::oneapi::bfloat16, NumRows, NumCols, Use, \
467 Layout, Group> &lhs, \
468 const sycl::ext::oneapi::bfloat16 &rhs) { \
469 return __spirv_VectorExtractDynamic< \
470 sycl::ext::oneapi::bfloat16, sycl::ext::oneapi::bfloat16, NumRows, \
471 NumCols, spv_matrix_use_traits<Use>::value, \
472 spv_matrix_layout_traits<Layout>::value, \
473 spv_scope_traits<Group>::value>(lhs.M.spvm, lhs.idx) op rhs; \
475 friend type operator op( \
476 const sycl::ext::oneapi::bfloat16 &lhs, \
477 const wi_element<sycl::ext::oneapi::bfloat16, NumRows, NumCols, Use, \
478 Layout, Group> &rhs) { \
479 return __spirv_VectorExtractDynamic< \
480 sycl::ext::oneapi::bfloat16, sycl::ext::oneapi::bfloat16, NumRows, \
481 NumCols, spv_matrix_use_traits<Use>::value, \
482 spv_matrix_layout_traits<Layout>::value, \
483 spv_scope_traits<Group>::value>(rhs.M.spvm, rhs.idx) op lhs; \
486 #define OP(type, op) \
487 friend type operator op( \
488 const wi_element<sycl::ext::oneapi::bfloat16, NumRows, NumCols, Use, \
489 Layout, Group> &lhs, \
490 const sycl::ext::oneapi::bfloat16 &rhs) { \
491 sycl::ext::oneapi::bfloat16 *ExtractP = \
492 __spirv_AccessChain<sycl::ext::oneapi::bfloat16, \
493 sycl::ext::oneapi::bfloat16, NumRows, NumCols, \
494 spv_matrix_use_traits<Use>::value, \
495 spv_scope_traits<Group>::value>(&lhs.M.spvm, \
497 return *ExtractP op rhs; \
499 friend type operator op( \
500 const sycl::ext::oneapi::bfloat16 &lhs, \
501 const wi_element<sycl::ext::oneapi::bfloat16, NumRows, NumCols, Use, \
502 Layout, Group> &rhs) { \
503 sycl::ext::oneapi::bfloat16 *ExtractP = \
504 __spirv_AccessChain<sycl::ext::oneapi::bfloat16, \
505 sycl::ext::oneapi::bfloat16, NumRows, NumCols, \
506 spv_matrix_use_traits<Use>::value, \
507 spv_scope_traits<Group>::value>(&rhs.M.spvm, \
509 return *ExtractP op lhs; \
517 #ifndef __SPIRV_USE_COOPERATIVE_MATRIX
518 #define OP(type, op) \
519 friend type operator op( \
520 const wi_element<sycl::ext::oneapi::bfloat16, NumRows, NumCols, Use, \
521 Layout, Group> &lhs, \
522 const sycl::ext::oneapi::bfloat16 &rhs) { \
523 return type{static_cast<float>( \
524 __spirv_VectorExtractDynamic< \
525 sycl::ext::oneapi::bfloat16, sycl::ext::oneapi::bfloat16, NumRows, \
526 NumCols, spv_matrix_use_traits<Use>::value, \
527 spv_matrix_layout_traits<Layout>::value, \
528 spv_scope_traits<Group>::value>(lhs.M.spvm, lhs.idx)) \
529 op static_cast<float>(rhs)}; \
531 friend type operator op( \
532 const sycl::ext::oneapi::bfloat16 &lhs, \
533 const wi_element<sycl::ext::oneapi::bfloat16, NumRows, NumCols, Use, \
534 Layout, Group> &rhs) { \
535 return type{static_cast<float>( \
536 __spirv_VectorExtractDynamic< \
537 sycl::ext::oneapi::bfloat16, sycl::ext::oneapi::bfloat16, NumRows, \
538 NumCols, spv_matrix_use_traits<Use>::value, \
539 spv_matrix_layout_traits<Layout>::value, \
540 spv_scope_traits<Group>::value>(rhs.M.spvm, rhs.idx)) \
541 op static_cast<float>(lhs)}; \
544 #define OP(type, op) \
545 friend type operator op( \
546 const wi_element<sycl::ext::oneapi::bfloat16, NumRows, NumCols, Use, \
547 Layout, Group> &lhs, \
548 const sycl::ext::oneapi::bfloat16 &rhs) { \
549 sycl::ext::oneapi::bfloat16 *ExtractP = \
550 __spirv_AccessChain<sycl::ext::oneapi::bfloat16, \
551 sycl::ext::oneapi::bfloat16, NumRows, NumCols, \
552 spv_matrix_use_traits<Use>::value, \
553 spv_scope_traits<Group>::value>(&lhs.M.spvm, \
555 return type{static_cast<float>(*ExtractP) op static_cast<float>(rhs)}; \
557 friend type operator op( \
558 const sycl::ext::oneapi::bfloat16 &lhs, \
559 const wi_element<sycl::ext::oneapi::bfloat16, NumRows, NumCols, Use, \
560 Layout, Group> &rhs) { \
561 sycl::ext::oneapi::bfloat16 *ExtractP = \
562 __spirv_AccessChain<sycl::ext::oneapi::bfloat16, \
563 sycl::ext::oneapi::bfloat16, NumRows, NumCols, \
564 spv_matrix_use_traits<Use>::value, \
565 spv_scope_traits<Group>::value>(&rhs.M.spvm, \
567 return type{static_cast<float>(*ExtractP) op static_cast<float>(lhs)}; \
578 #define OP(type, op) \
579 friend type operator op( \
580 const wi_element<sycl::ext::oneapi::bfloat16, NumRows, NumCols, Use, \
582 const sycl::ext::oneapi::bfloat16 &) { \
583 throw exception(make_error_code(errc::runtime), \
584 "joint matrix is not supported on host."); \
586 friend type operator op( \
587 const sycl::ext::oneapi::bfloat16 &, \
588 const wi_element<sycl::ext::oneapi::bfloat16, NumRows, NumCols, Use, \
589 Layout, Group> &) { \
590 throw exception(make_error_code(errc::runtime), \
591 "joint matrix is not supported on host."); \
611 template <
typename Group,
typename T,
620 Group, T, Use, Rows, Cols, Layout> &_jm)
623 template <
typename Grp,
typename Type,
627 friend decltype(
auto)
629 Grp, Type, UseJm, NumRows, NumCols, LayoutJm> &);
633 #if __SYCL_DEVICE_ONLY__
634 #ifndef __SPIRV_USE_COOPERATIVE_MATRIX
635 return __spirv_JointMatrixWorkItemLengthINTEL(jm.spvm);
637 return __spirv_CooperativeMatrixLengthKHR(jm.spvm);
641 "joint matrix is not supported on host.");
645 decltype(
auto) operator[](
size_t i) {
650 template <
typename Group,
typename T,
655 Group, T, Use, Rows, Cols, Layout> &jm) {
664 namespace intel::experimental::matrix {
666 typename Group,
typename T,
typename Tp,
676 Group, Tp, Use, NumRows, NumCols, Layout> &src,
678 #if defined(__SYCL_DEVICE_ONLY__)
680 "Joint Matrix doesn't support store to private memory!");
681 #if defined(__NVPTX__)
684 std::ignore = stride;
687 "This version of the matrix extension is only currently supported on "
691 using DecorT =
typename sycl::detail::DecoratedType<T, Space>::type;
692 DecorT *Ptr = sycl::detail::getDecorated<DecorT>(dst);
693 #ifndef __SPIRV_USE_COOPERATIVE_MATRIX
694 __spirv_JointMatrixStoreINTEL<DecorT, Tp, NumRows, NumCols,
699 Ptr, src.spvm, stride,
704 __spirv_CooperativeMatrixStoreKHR<
705 DecorT, Tp, NumRows, NumCols,
719 std::ignore = stride;
721 "joint matrix is not supported on host.");
726 typename Group,
typename T,
typename Tp,
729 typename PropertyListT,
736 Group, Tp, Use, NumRows, NumCols, Layout> &src,
739 #if defined(__SYCL_DEVICE_ONLY__)
740 #if defined(__NVPTX__)
743 std::ignore = stride;
746 "This version of the matrix extension is only currently supported on "
751 #ifndef __SPIRV_USE_COOPERATIVE_MATRIX
752 __spirv_JointMatrixStoreINTEL<T, Tp, NumRows, NumCols,
757 Ptr, src.spvm, stride,
762 __spirv_CooperativeMatrixStoreKHR<
763 T, Tp, NumRows, NumCols,
777 std::ignore = stride;
779 "joint matrix is not supported on host.");
783 template <
typename Group,
typename T,
792 #if defined(__SYCL_DEVICE_ONLY__)
793 #if defined(__NVPTX__)
795 for (
int i = 0; i < jm.matrix_impl.wi_marray.size(); i++) {
796 lambda(jm.matrix_impl.wi_marray[i]);
799 using storage_element_type =
801 T>::storage_element_type;
803 for (
int i = 0; i < wi_data_c.length(); i++) {
804 storage_element_type element = wi_data_c[i];
805 auto [row, col] = wi_data_c[i].get_coord();
806 lambda(element, row, col);
807 wi_data_c[i] = element;
813 std::ignore = lambda;
815 "joint matrix is not supported on host.");
819 using namespace sycl::ext::oneapi::experimental::matrix;
823 template <
typename Group,
typename T,
size_t NumRows,
size_t NumCols,
use Use,
824 layout Layout,
typename T2>
826 Group, joint_matrix<Group, T, Use, NumRows, NumCols, Layout> &Res,
827 const T2 &Value,
size_t Height,
size_t Width,
size_t CoordX,
829 #if defined(__SYCL_DEVICE_ONLY__)
830 using storage_element_type =
832 T>::storage_element_type;
833 Res.spvm = __spirv_CooperativeMatrixConstructCheckedINTEL<
834 storage_element_type, T, NumRows, NumCols,
835 spv_matrix_use_traits<Use>::value,
836 spv_matrix_layout_traits<Layout>::value>(
837 CoordX, CoordY, Height, Width,
static_cast<storage_element_type
>(Value));
841 std::ignore = Height;
843 std::ignore = CoordX;
844 std::ignore = CoordY;
846 "joint matrix is not supported on host.");
851 typename Group,
typename S,
typename T,
size_t NumRows,
size_t NumCols,
853 std::enable_if_t<std::is_same<S, std::remove_const_t<T>>::value,
bool> =
857 joint_matrix<Group, S, use::accumulator, NumRows, NumCols, layout::dynamic>
860 size_t Height,
size_t Width,
size_t CoordX,
size_t CoordY) {
861 #if defined(__SYCL_DEVICE_ONLY__)
863 "Joint Matrix doesn't support load from private memory!");
865 using DecorT =
typename sycl::detail::DecoratedType<T, Space>::type;
866 DecorT *Ptr = sycl::detail::getDecorated<DecorT>(Src);
867 Res.spvm = __spirv_CooperativeMatrixLoadCheckedINTEL<
868 DecorT, S, NumRows, NumCols,
869 spv_matrix_use_traits<use::accumulator>::value,
870 spv_matrix_layout_traits<layout::dynamic>::value>(
872 Height, Width, Stride);
877 std::ignore = Stride;
878 std::ignore = Height;
880 std::ignore = Layout;
881 std::ignore = CoordX;
882 std::ignore = CoordY;
884 "joint matrix is not supported on host.");
889 typename Group,
typename S,
typename T,
use Use,
size_t NumRows,
892 std::enable_if_t<std::is_same<S, std::remove_const_t<T>>::value ||
893 (std::is_same<S, precision::tf32>::value &&
894 std::is_same<std::remove_const_t<T>,
float>::value),
897 Group sg, joint_matrix<Group, S, Use, NumRows, NumCols, Layout> &Res,
899 size_t Width,
size_t CoordX,
size_t CoordY) {
900 #if defined(__SYCL_DEVICE_ONLY__)
902 "Joint Matrix doesn't support load from private memory!");
904 using DecorT =
typename sycl::detail::DecoratedType<T, Space>::type;
905 DecorT *Ptr = sycl::detail::getDecorated<DecorT>(Src);
906 Res.spvm = __spirv_CooperativeMatrixLoadCheckedINTEL<
907 DecorT, S, NumRows, NumCols, spv_matrix_use_traits<Use>::value,
908 spv_matrix_layout_traits<Layout>::value>(
909 Ptr, CoordX, CoordY, spv_matrix_layout_traits<Layout>::value, Height,
915 std::ignore = Stride;
916 std::ignore = Height;
918 std::ignore = CoordX;
919 std::ignore = CoordY;
921 "joint matrix is not supported on host.");
925 template <
typename Group,
typename T,
size_t NumRows,
size_t NumCols,
929 joint_matrix<Group, T, use::accumulator, NumRows, NumCols, layout::dynamic>
932 size_t Height,
size_t Width,
size_t CoordX,
size_t CoordY) {
933 #if defined(__SYCL_DEVICE_ONLY__)
935 "Joint Matrix doesn't support store to private memory!");
937 using DecorT =
typename sycl::detail::DecoratedType<T, Space>::type;
938 DecorT *Ptr = sycl::detail::getDecorated<DecorT>(Dst);
939 __spirv_CooperativeMatrixStoreCheckedINTEL<
940 DecorT, T, NumRows, NumCols,
941 spv_matrix_use_traits<use::accumulator>::value,
942 spv_matrix_layout_traits<layout::dynamic>::value>(
943 Ptr, CoordX, CoordY, Src.spvm,
949 std::ignore = Stride;
950 std::ignore = Height;
952 std::ignore = Layout;
953 std::ignore = CoordX;
954 std::ignore = CoordY;
956 "joint matrix is not supported on host.");
960 template <
typename Group,
typename T,
typename Tp,
use Use,
size_t NumRows,
963 std::enable_if_t<Use == use::a || Use == use::b, bool> =
true>
965 Group sg,
const joint_matrix<Group, Tp, Use, NumRows, NumCols, Layout> &Src,
967 size_t Width,
size_t CoordX,
size_t CoordY) {
968 #if defined(__SYCL_DEVICE_ONLY__)
970 "Joint Matrix doesn't support store to private memory!");
972 using DecorT =
typename sycl::detail::DecoratedType<T, Space>::type;
973 DecorT *Ptr = sycl::detail::getDecorated<DecorT>(Dst);
974 __spirv_CooperativeMatrixStoreCheckedINTEL<
975 DecorT, Tp, NumRows, NumCols, spv_matrix_use_traits<Use>::value,
976 spv_matrix_layout_traits<Layout>::value>(
977 Ptr, CoordX, CoordY, Src.spvm, spv_matrix_layout_traits<Layout>::value,
978 Height, Width, Stride);
983 std::ignore = Stride;
984 std::ignore = Height;
986 std::ignore = CoordX;
987 std::ignore = CoordY;
989 "joint matrix is not supported on host.");
994 template <
typename Group,
typename S,
typename T,
size_t NumRows,
995 size_t NumCols,
typename PropertyListT,
996 std::enable_if_t<std::is_same<S, std::remove_const_t<T>>::value,
1000 joint_matrix<Group, S, use::accumulator, NumRows, NumCols, layout::dynamic>
1003 size_t Stride,
layout Layout,
size_t Height,
size_t Width,
size_t CoordX,
1005 #if defined(__SYCL_DEVICE_ONLY__)
1008 Res.spvm = __spirv_CooperativeMatrixLoadCheckedINTEL<
1009 T, S, NumRows, NumCols, spv_matrix_use_traits<use::accumulator>::value,
1010 spv_matrix_layout_traits<layout::dynamic>::value>(
1012 Height, Width, Stride);
1017 std::ignore = Stride;
1018 std::ignore = Height;
1019 std::ignore = Width;
1020 std::ignore = Layout;
1021 std::ignore = CoordX;
1022 std::ignore = CoordY;
1024 "joint matrix is not supported on host.");
1029 typename Group,
typename S,
typename T,
use Use,
size_t NumRows,
1030 size_t NumCols,
layout Layout,
typename PropertyListT,
1031 std::enable_if_t<std::is_same<S, std::remove_const_t<T>>::value ||
1032 (std::is_same<S, precision::tf32>::value &&
1033 std::is_same<std::remove_const_t<T>,
float>::value),
1036 Group sg, joint_matrix<Group, S, Use, NumRows, NumCols, Layout> &Res,
1038 size_t Stride,
size_t Height,
size_t Width,
size_t CoordX,
size_t CoordY) {
1039 #if defined(__SYCL_DEVICE_ONLY__)
1042 Res.spvm = __spirv_CooperativeMatrixLoadCheckedINTEL<
1043 T, S, NumRows, NumCols, spv_matrix_use_traits<Use>::value,
1044 spv_matrix_layout_traits<Layout>::value>(
1045 Ptr, CoordX, CoordY, spv_matrix_layout_traits<Layout>::value, Height,
1051 std::ignore = Stride;
1052 std::ignore = Height;
1053 std::ignore = Width;
1054 std::ignore = CoordX;
1055 std::ignore = CoordY;
1057 "joint matrix is not supported on host.");
1061 template <
typename Group,
typename T,
size_t NumRows,
size_t NumCols,
1062 typename PropertyListT>
1065 joint_matrix<Group, T, use::accumulator, NumRows, NumCols, layout::dynamic>
1068 size_t Stride,
layout Layout,
size_t Height,
size_t Width,
size_t CoordX,
1070 #if defined(__SYCL_DEVICE_ONLY__)
1073 __spirv_CooperativeMatrixStoreCheckedINTEL<
1074 T, T, NumRows, NumCols, spv_matrix_use_traits<use::accumulator>::value,
1075 spv_matrix_layout_traits<layout::dynamic>::value>(
1076 Ptr, CoordX, CoordY, Src.spvm,
1082 std::ignore = Stride;
1083 std::ignore = Height;
1084 std::ignore = Width;
1085 std::ignore = Layout;
1086 std::ignore = CoordX;
1087 std::ignore = CoordY;
1089 "joint matrix is not supported on host.");
1093 template <
typename Group,
typename T,
typename Tp,
use Use,
size_t NumRows,
1094 size_t NumCols,
layout Layout,
typename PropertyListT,
1095 std::enable_if_t<Use == use::a || Use == use::b, bool> =
true>
1097 Group sg,
const joint_matrix<Group, Tp, Use, NumRows, NumCols, Layout> &Src,
1099 size_t Stride,
size_t Height,
size_t Width,
size_t CoordX,
size_t CoordY) {
1100 #if defined(__SYCL_DEVICE_ONLY__)
1103 __spirv_CooperativeMatrixStoreCheckedINTEL<
1104 T, Tp, NumRows, NumCols, spv_matrix_use_traits<Use>::value,
1105 spv_matrix_layout_traits<Layout>::value>(
1106 Ptr, CoordX, CoordY, Src.spvm, spv_matrix_layout_traits<Layout>::value,
1107 Height, Width, Stride);
1112 std::ignore = Stride;
1113 std::ignore = Height;
1114 std::ignore = Width;
1115 std::ignore = CoordX;
1116 std::ignore = CoordY;
1118 "joint matrix is not supported on host.");
wi_element(sycl::ext::oneapi::experimental::matrix::joint_matrix< Group, sycl::ext::oneapi::bfloat16, Use, NumRows, NumCols, Layout > &Mat, std::size_t i)
wi_element & operator=(const wi_element< sycl::ext::oneapi::bfloat16, NumRows, NumCols, Use, Layout, Group > &rhs)
__SYCL_ALWAYS_INLINE std::tuple< uint32_t, uint32_t > get_coord()
wi_element & operator=(const sycl::ext::oneapi::bfloat16 &rhs)
wi_element & operator=(const wi_element< T, NumRows, NumCols, Use, Layout, Group > &rhs)
wi_element(sycl::ext::oneapi::experimental::matrix::joint_matrix< Group, T, Use, NumRows, NumCols, Layout > &Mat, std::size_t i)
wi_element & operator=(const T2 &rhs)
typename oneapi::detail::jm_type_interpretation_helper_trait< T >::storage_element_type storage_element_type
__SYCL_ALWAYS_INLINE std::tuple< size_t, size_t > get_coord()
#define __SYCL_ALWAYS_INLINE
#define SPV_MATRIX_USE_TRAITS(USE, SPV_USE)
#define SPV_MATRIX_LAYOUT_TRAITS(LAYOUT, SPV_LAYOUT)
__SYCL_ALWAYS_INLINE __spv::MatrixLayout joint_matrix_layout_to_spv(sycl::ext::oneapi::experimental::matrix::layout Layout)
constexpr tuple< Ts... > make_tuple(Ts... Args)
sycl::ext::oneapi::bfloat16 bfloat16
__SYCL_ALWAYS_INLINE void joint_matrix_store(Group, const sycl::ext::oneapi::experimental::matrix::joint_matrix< Group, Tp, Use, NumRows, NumCols, Layout > &src, ext::oneapi::experimental::annotated_ptr< T, PropertyListT > dst, size_t stride)
__SYCL_ALWAYS_INLINE void joint_matrix_load_checked(Group sg, joint_matrix< Group, S, Use, NumRows, NumCols, Layout > &Res, ext::oneapi::experimental::annotated_ptr< T, PropertyListT > Src, size_t Stride, size_t Height, size_t Width, size_t CoordX, size_t CoordY)
__SYCL_ALWAYS_INLINE void joint_matrix_apply(Group sg, sycl::ext::oneapi::experimental::matrix::joint_matrix< Group, T, Use, Rows, Cols, Layout > &jm, F &&lambda)
__SYCL_ALWAYS_INLINE void joint_matrix_store_checked(Group sg, const joint_matrix< Group, Tp, Use, NumRows, NumCols, Layout > &Src, ext::oneapi::experimental::annotated_ptr< T, PropertyListT > Dst, size_t Stride, size_t Height, size_t Width, size_t CoordX, size_t CoordY)
__SYCL_ALWAYS_INLINE void joint_matrix_fill_checked(Group, joint_matrix< Group, T, Use, NumRows, NumCols, Layout > &Res, const T2 &Value, size_t Height, size_t Width, size_t CoordX, size_t CoordY)
decltype(auto) __SYCL_ALWAYS_INLINE get_wi_data(Group sg, sycl::ext::oneapi::experimental::matrix::joint_matrix< Group, T, Use, Rows, Cols, Layout > &jm)
std::enable_if_t< detail::is_bf16_storage_type< T >::value, T > fabs(T x)
std::error_code make_error_code(sycl::errc E) noexcept
Constructs an error code using e and sycl_category()
float storage_element_type
static constexpr __spv::MatrixLayout value
static constexpr __spv::MatrixUse value