28 #include <type_traits>
31 inline namespace _V1 {
34 namespace experimental {
41 #define SPV_MATRIX_LAYOUT_TRAITS(LAYOUT, SPV_LAYOUT) \
42 template <> struct spv_matrix_layout_traits<LAYOUT> { \
43 static constexpr __spv::MatrixLayout value = SPV_LAYOUT; \
55 #define SPV_MATRIX_USE_TRAITS(USE, SPV_USE) \
56 template <> struct spv_matrix_use_traits<USE> { \
57 static constexpr __spv::MatrixUse value = SPV_USE; \
73 template <
typename Group,
typename T,
use Use,
size_t Rows,
size_t Cols,
94 using namespace sycl::ext::oneapi::experimental::matrix;
97 template <
typename T,
size_t NumRows,
size_t NumCols,
100 sycl::ext::oneapi::experimental::matrix::layout::dynamic,
112 Group, T, Use, NumRows, NumCols, Layout> &Mat,
117 #if defined(__SYCL_DEVICE_ONLY__)
118 __ocl_vec_t<uint32_t, 2> coord =
119 __spirv_JointMatrixGetElementCoordINTEL(M.spvm, idx);
120 const size_t row = coord[0];
121 const size_t col = coord[1];
125 "joint matrix is not supported on host.");
130 #ifdef __SYCL_DEVICE_ONLY__
133 spv_matrix_use_traits<Use>::value,
134 spv_matrix_layout_traits<Layout>::value,
135 spv_scope_traits<Group>::value>(M.spvm,
140 "joint matrix is not supported on host.");
144 explicit operator bool() {
145 #ifdef __SYCL_DEVICE_ONLY__
148 spv_matrix_use_traits<Use>::value,
149 spv_matrix_layout_traits<Layout>::value,
150 spv_scope_traits<Group>::value>(
154 "joint matrix is not supported on host.");
159 #ifdef __SYCL_DEVICE_ONLY__
160 M.spvm = __spirv_VectorInsertDynamic(
166 "joint matrix is not supported on host.");
172 #ifdef __SYCL_DEVICE_ONLY__
173 M.spvm = __spirv_VectorInsertDynamic(
176 spv_matrix_use_traits<Use>::value,
177 spv_matrix_layout_traits<Layout>::value,
178 spv_scope_traits<Group>::value>(rhs.M.spvm,
185 "joint matrix is not supported on host.");
189 #if __SYCL_DEVICE_ONLY__
191 template <typename T2> wi_element &operator op##=(const T2 & rhs) { \
192 M.spvm = __spirv_VectorInsertDynamic( \
194 static_cast<storage_element_type>( \
195 __spirv_VectorExtractDynamic< \
196 storage_element_type, T, NumRows, NumCols, \
197 spv_matrix_use_traits<Use>::value, \
198 spv_matrix_layout_traits<Layout>::value, \
199 spv_scope_traits<Group>::value>(M.spvm, idx) \
200 op static_cast<storage_element_type>(rhs)), \
206 template <typename T2> wi_element &operator op##=(const T2 & rhs) { \
208 throw exception(make_error_code(errc::runtime), \
209 "joint matrix is not supported on host."); \
219 template <
size_t NumRows,
size_t NumCols,
237 #if defined(__SYCL_DEVICE_ONLY__)
238 __ocl_vec_t<uint32_t, 2> coord =
239 __spirv_JointMatrixGetElementCoordINTEL(M.spvm, idx);
240 const uint32_t row = coord[0];
241 const uint32_t col = coord[1];
245 "joint matrix is not supported on host.");
250 #ifdef __SYCL_DEVICE_ONLY__
251 return __spirv_VectorExtractDynamic<
253 NumCols, spv_matrix_use_traits<Use>::value,
254 spv_matrix_layout_traits<Layout>::value,
255 spv_scope_traits<Group>::value>(M.spvm, idx);
258 "joint matrix is not supported on host.");
262 explicit operator bool() {
263 #ifdef __SYCL_DEVICE_ONLY__
265 __spirv_VectorExtractDynamic<
267 NumRows, NumCols, spv_matrix_use_traits<Use>::value,
268 spv_matrix_layout_traits<Layout>::value,
269 spv_scope_traits<Group>::value>(M.spvm, idx))) >=
270 std::numeric_limits<float>::epsilon();
273 "joint matrix is not supported on host.");
278 #ifdef __SYCL_DEVICE_ONLY__
279 M.spvm = __spirv_VectorInsertDynamic(M.spvm, rhs, idx);
284 "joint matrix is not supported on host.");
289 NumCols, Use, Layout, Group> &rhs) {
290 #ifdef __SYCL_DEVICE_ONLY__
291 M.spvm = __spirv_VectorInsertDynamic(
295 NumCols, spv_matrix_use_traits<Use>::value,
296 spv_matrix_layout_traits<Layout>::value,
297 spv_scope_traits<Group>::value>(rhs.M.spvm,
304 "joint matrix is not supported on host.");
308 #if __SYCL_DEVICE_ONLY__
309 #define OP(opassign, op) \
310 wi_element &operator opassign(const sycl::ext::oneapi::bfloat16 & rhs) { \
311 M.spvm = __spirv_VectorInsertDynamic( \
313 __spirv_VectorExtractDynamic< \
314 sycl::ext::oneapi::bfloat16, sycl::ext::oneapi::bfloat16, NumRows, \
315 NumCols, spv_matrix_use_traits<Use>::value, \
316 spv_matrix_layout_traits<Layout>::value, \
317 spv_scope_traits<Group>::value>(M.spvm, idx) op rhs, \
322 #define OP(opassign, op) \
323 wi_element &operator opassign(const sycl::ext::oneapi::bfloat16 & rhs) { \
325 throw exception(make_error_code(errc::runtime), \
326 "joint matrix is not supported on host."); \
335 #if __SYCL_DEVICE_ONLY__
336 #define OP(type, op) \
337 friend type operator op( \
338 const wi_element<sycl::ext::oneapi::bfloat16, NumRows, NumCols, Use, \
339 Layout, Group> &lhs, \
340 const sycl::ext::oneapi::bfloat16 &rhs) { \
341 return __spirv_VectorExtractDynamic< \
342 sycl::ext::oneapi::bfloat16, sycl::ext::oneapi::bfloat16, NumRows, \
343 NumCols, spv_matrix_use_traits<Use>::value, \
344 spv_matrix_layout_traits<Layout>::value, \
345 spv_scope_traits<Group>::value>(lhs.M.spvm, lhs.idx) op rhs; \
347 friend type operator op( \
348 const sycl::ext::oneapi::bfloat16 &lhs, \
349 const wi_element<sycl::ext::oneapi::bfloat16, NumRows, NumCols, Use, \
350 Layout, Group> &rhs) { \
351 return __spirv_VectorExtractDynamic< \
352 sycl::ext::oneapi::bfloat16, sycl::ext::oneapi::bfloat16, NumRows, \
353 NumCols, spv_matrix_use_traits<Use>::value, \
354 spv_matrix_layout_traits<Layout>::value, \
355 spv_scope_traits<Group>::value>(rhs.M.spvm, rhs.idx) op lhs; \
362 #define OP(type, op) \
363 friend type operator op( \
364 const wi_element<sycl::ext::oneapi::bfloat16, NumRows, NumCols, Use, \
365 Layout, Group> &lhs, \
366 const sycl::ext::oneapi::bfloat16 &rhs) { \
367 return type{static_cast<float>( \
368 __spirv_VectorExtractDynamic< \
369 sycl::ext::oneapi::bfloat16, sycl::ext::oneapi::bfloat16, NumRows, \
370 NumCols, spv_matrix_use_traits<Use>::value, \
371 spv_matrix_layout_traits<Layout>::value, \
372 spv_scope_traits<Group>::value>(lhs.M.spvm, lhs.idx)) \
373 op static_cast<float>(rhs)}; \
375 friend type operator op( \
376 const sycl::ext::oneapi::bfloat16 &lhs, \
377 const wi_element<sycl::ext::oneapi::bfloat16, NumRows, NumCols, Use, \
378 Layout, Group> &rhs) { \
379 return type{static_cast<float>( \
380 __spirv_VectorExtractDynamic< \
381 sycl::ext::oneapi::bfloat16, sycl::ext::oneapi::bfloat16, NumRows, \
382 NumCols, spv_matrix_use_traits<Use>::value, \
383 spv_matrix_layout_traits<Layout>::value, \
384 spv_scope_traits<Group>::value>(rhs.M.spvm, rhs.idx)) \
385 op static_cast<float>(lhs)}; \
395 #define OP(type, op) \
396 friend type operator op( \
397 const wi_element<sycl::ext::oneapi::bfloat16, NumRows, NumCols, Use, \
399 const sycl::ext::oneapi::bfloat16 &) { \
400 throw exception(make_error_code(errc::runtime), \
401 "joint matrix is not supported on host."); \
403 friend type operator op( \
404 const sycl::ext::oneapi::bfloat16 &, \
405 const wi_element<sycl::ext::oneapi::bfloat16, NumRows, NumCols, Use, \
406 Layout, Group> &) { \
407 throw exception(make_error_code(errc::runtime), \
408 "joint matrix is not supported on host."); \
428 template <
typename Group,
typename T,
437 Group, T, Use, Rows, Cols, Layout> &_jm)
440 template <
typename Grp,
typename Type,
444 friend decltype(
auto)
446 Grp, Type, UseJm, NumRows, NumCols, LayoutJm> &);
450 #if __SYCL_DEVICE_ONLY__
451 return __spirv_JointMatrixWorkItemLengthINTEL(jm.spvm);
454 "joint matrix is not supported on host.");
458 decltype(
auto) operator[](
size_t i) {
463 template <
typename Group,
typename T,
468 Group, T, Use, Rows, Cols, Layout> &jm) {
477 namespace intel::experimental::matrix {
479 typename Group,
typename T,
typename Tp,
489 Group, Tp, Use, NumRows, NumCols, Layout> &src,
491 #if defined(__SYCL_DEVICE_ONLY__)
493 "Joint Matrix doesn't support store to private memory!");
494 #if defined(__NVPTX__)
497 std::ignore = stride;
500 "This version of the matrix extension is only currently supported on "
504 using DecorT =
typename sycl::detail::DecoratedType<T, Space>::type;
505 DecorT *Ptr = sycl::detail::getDecorated<DecorT>(dst);
506 __spirv_JointMatrixStoreINTEL<DecorT, Tp, NumRows, NumCols,
511 Ptr, src.spvm, stride,
519 std::ignore = stride;
521 "joint matrix is not supported on host.");
526 typename Group,
typename T,
typename Tp,
529 typename PropertyListT,
536 Group, Tp, Use, NumRows, NumCols, Layout> &src,
539 #if defined(__SYCL_DEVICE_ONLY__)
540 #if defined(__NVPTX__)
543 std::ignore = stride;
546 "This version of the matrix extension is only currently supported on "
551 __spirv_JointMatrixStoreINTEL<T, Tp, NumRows, NumCols,
556 Ptr, src.spvm, stride,
564 std::ignore = stride;
566 "joint matrix is not supported on host.");
570 template <
typename Group,
typename T,
579 #if defined(__SYCL_DEVICE_ONLY__)
580 #if defined(__NVPTX__)
582 for (
int i = 0; i < jm.matrix_impl.wi_marray.size(); i++) {
583 lambda(jm.matrix_impl.wi_marray[i]);
586 using storage_element_type =
588 T>::storage_element_type;
590 for (
int i = 0; i < wi_data_c.length(); i++) {
591 storage_element_type element = wi_data_c[i];
592 auto [row, col] = wi_data_c[i].get_coord();
593 lambda(element, row, col);
594 wi_data_c[i] = element;
600 std::ignore = lambda;
602 "joint matrix is not supported on host.");
606 using namespace sycl::ext::oneapi::experimental::matrix;
610 template <
typename Group,
typename T,
size_t NumRows,
size_t NumCols,
use Use,
611 layout Layout,
typename T2>
613 Group, joint_matrix<Group, T, Use, NumRows, NumCols, Layout> &Res,
614 const T2 &Value,
size_t Height,
size_t Width,
size_t CoordX,
616 #if defined(__SYCL_DEVICE_ONLY__)
617 using storage_element_type =
619 T>::storage_element_type;
620 Res.spvm = __spirv_CooperativeMatrixConstructCheckedINTEL<
621 storage_element_type, T, NumRows, NumCols,
622 spv_matrix_use_traits<Use>::value,
623 spv_matrix_layout_traits<Layout>::value>(
624 CoordX, CoordY, Height, Width,
static_cast<storage_element_type
>(Value));
628 std::ignore = Height;
630 std::ignore = CoordX;
631 std::ignore = CoordY;
633 "joint matrix is not supported on host.");
638 typename Group,
typename S,
typename T,
size_t NumRows,
size_t NumCols,
640 std::enable_if_t<std::is_same<S, std::remove_const_t<T>>::value,
bool> =
644 joint_matrix<Group, S, use::accumulator, NumRows, NumCols, layout::dynamic>
647 size_t Height,
size_t Width,
size_t CoordX,
size_t CoordY) {
648 #if defined(__SYCL_DEVICE_ONLY__)
650 "Joint Matrix doesn't support load from private memory!");
652 using DecorT =
typename sycl::detail::DecoratedType<T, Space>::type;
653 DecorT *Ptr = sycl::detail::getDecorated<DecorT>(Src);
654 Res.spvm = __spirv_CooperativeMatrixLoadCheckedINTEL<
655 DecorT, S, NumRows, NumCols,
656 spv_matrix_use_traits<use::accumulator>::value,
657 spv_matrix_layout_traits<layout::dynamic>::value>(
659 Height, Width, Stride);
664 std::ignore = Stride;
665 std::ignore = Height;
667 std::ignore = Layout;
668 std::ignore = CoordX;
669 std::ignore = CoordY;
671 "joint matrix is not supported on host.");
676 typename Group,
typename S,
typename T,
use Use,
size_t NumRows,
679 std::enable_if_t<std::is_same<S, std::remove_const_t<T>>::value ||
680 (std::is_same<S, precision::tf32>::value &&
681 std::is_same<std::remove_const_t<T>,
float>::value),
684 Group sg, joint_matrix<Group, S, Use, NumRows, NumCols, Layout> &Res,
686 size_t Width,
size_t CoordX,
size_t CoordY) {
687 #if defined(__SYCL_DEVICE_ONLY__)
689 "Joint Matrix doesn't support load from private memory!");
691 using DecorT =
typename sycl::detail::DecoratedType<T, Space>::type;
692 DecorT *Ptr = sycl::detail::getDecorated<DecorT>(Src);
693 Res.spvm = __spirv_CooperativeMatrixLoadCheckedINTEL<
694 DecorT, S, NumRows, NumCols, spv_matrix_use_traits<Use>::value,
695 spv_matrix_layout_traits<Layout>::value>(
696 Ptr, CoordX, CoordY, spv_matrix_layout_traits<Layout>::value, Height,
702 std::ignore = Stride;
703 std::ignore = Height;
705 std::ignore = CoordX;
706 std::ignore = CoordY;
708 "joint matrix is not supported on host.");
712 template <
typename Group,
typename T,
size_t NumRows,
size_t NumCols,
716 joint_matrix<Group, T, use::accumulator, NumRows, NumCols, layout::dynamic>
719 size_t Height,
size_t Width,
size_t CoordX,
size_t CoordY) {
720 #if defined(__SYCL_DEVICE_ONLY__)
722 "Joint Matrix doesn't support store to private memory!");
724 using DecorT =
typename sycl::detail::DecoratedType<T, Space>::type;
725 DecorT *Ptr = sycl::detail::getDecorated<DecorT>(Dst);
726 __spirv_CooperativeMatrixStoreCheckedINTEL<
727 DecorT, T, NumRows, NumCols,
728 spv_matrix_use_traits<use::accumulator>::value,
729 spv_matrix_layout_traits<layout::dynamic>::value>(
730 Ptr, CoordX, CoordY, Src.spvm,
736 std::ignore = Stride;
737 std::ignore = Height;
739 std::ignore = Layout;
740 std::ignore = CoordX;
741 std::ignore = CoordY;
743 "joint matrix is not supported on host.");
747 template <
typename Group,
typename T,
typename Tp,
use Use,
size_t NumRows,
750 std::enable_if_t<Use == use::a || Use == use::b, bool> =
true>
752 Group sg,
const joint_matrix<Group, Tp, Use, NumRows, NumCols, Layout> &Src,
754 size_t Width,
size_t CoordX,
size_t CoordY) {
755 #if defined(__SYCL_DEVICE_ONLY__)
757 "Joint Matrix doesn't support store to private memory!");
759 using DecorT =
typename sycl::detail::DecoratedType<T, Space>::type;
760 DecorT *Ptr = sycl::detail::getDecorated<DecorT>(Dst);
761 __spirv_CooperativeMatrixStoreCheckedINTEL<
762 DecorT, Tp, NumRows, NumCols, spv_matrix_use_traits<Use>::value,
763 spv_matrix_layout_traits<Layout>::value>(
764 Ptr, CoordX, CoordY, Src.spvm, spv_matrix_layout_traits<Layout>::value,
765 Height, Width, Stride);
770 std::ignore = Stride;
771 std::ignore = Height;
773 std::ignore = CoordX;
774 std::ignore = CoordY;
776 "joint matrix is not supported on host.");
781 template <
typename Group,
typename S,
typename T,
size_t NumRows,
782 size_t NumCols,
typename PropertyListT,
783 std::enable_if_t<std::is_same<S, std::remove_const_t<T>>::value,
787 joint_matrix<Group, S, use::accumulator, NumRows, NumCols, layout::dynamic>
790 size_t Stride,
layout Layout,
size_t Height,
size_t Width,
size_t CoordX,
792 #if defined(__SYCL_DEVICE_ONLY__)
795 Res.spvm = __spirv_CooperativeMatrixLoadCheckedINTEL<
796 T, S, NumRows, NumCols, spv_matrix_use_traits<use::accumulator>::value,
797 spv_matrix_layout_traits<layout::dynamic>::value>(
799 Height, Width, Stride);
804 std::ignore = Stride;
805 std::ignore = Height;
807 std::ignore = Layout;
808 std::ignore = CoordX;
809 std::ignore = CoordY;
811 "joint matrix is not supported on host.");
816 typename Group,
typename S,
typename T,
use Use,
size_t NumRows,
817 size_t NumCols,
layout Layout,
typename PropertyListT,
818 std::enable_if_t<std::is_same<S, std::remove_const_t<T>>::value ||
819 (std::is_same<S, precision::tf32>::value &&
820 std::is_same<std::remove_const_t<T>,
float>::value),
823 Group sg, joint_matrix<Group, S, Use, NumRows, NumCols, Layout> &Res,
825 size_t Stride,
size_t Height,
size_t Width,
size_t CoordX,
size_t CoordY) {
826 #if defined(__SYCL_DEVICE_ONLY__)
829 Res.spvm = __spirv_CooperativeMatrixLoadCheckedINTEL<
830 T, S, NumRows, NumCols, spv_matrix_use_traits<Use>::value,
831 spv_matrix_layout_traits<Layout>::value>(
832 Ptr, CoordX, CoordY, spv_matrix_layout_traits<Layout>::value, Height,
838 std::ignore = Stride;
839 std::ignore = Height;
841 std::ignore = CoordX;
842 std::ignore = CoordY;
844 "joint matrix is not supported on host.");
848 template <
typename Group,
typename T,
size_t NumRows,
size_t NumCols,
849 typename PropertyListT>
852 joint_matrix<Group, T, use::accumulator, NumRows, NumCols, layout::dynamic>
855 size_t Stride,
layout Layout,
size_t Height,
size_t Width,
size_t CoordX,
857 #if defined(__SYCL_DEVICE_ONLY__)
860 __spirv_CooperativeMatrixStoreCheckedINTEL<
861 T, T, NumRows, NumCols, spv_matrix_use_traits<use::accumulator>::value,
862 spv_matrix_layout_traits<layout::dynamic>::value>(
863 Ptr, CoordX, CoordY, Src.spvm,
869 std::ignore = Stride;
870 std::ignore = Height;
872 std::ignore = Layout;
873 std::ignore = CoordX;
874 std::ignore = CoordY;
876 "joint matrix is not supported on host.");
880 template <
typename Group,
typename T,
typename Tp,
use Use,
size_t NumRows,
881 size_t NumCols,
layout Layout,
typename PropertyListT,
882 std::enable_if_t<Use == use::a || Use == use::b, bool> =
true>
884 Group sg,
const joint_matrix<Group, Tp, Use, NumRows, NumCols, Layout> &Src,
886 size_t Stride,
size_t Height,
size_t Width,
size_t CoordX,
size_t CoordY) {
887 #if defined(__SYCL_DEVICE_ONLY__)
890 __spirv_CooperativeMatrixStoreCheckedINTEL<
891 T, Tp, NumRows, NumCols, spv_matrix_use_traits<Use>::value,
892 spv_matrix_layout_traits<Layout>::value>(
893 Ptr, CoordX, CoordY, Src.spvm, spv_matrix_layout_traits<Layout>::value,
894 Height, Width, Stride);
899 std::ignore = Stride;
900 std::ignore = Height;
902 std::ignore = CoordX;
903 std::ignore = CoordY;
905 "joint matrix is not supported on host.");
wi_element(sycl::ext::oneapi::experimental::matrix::joint_matrix< Group, sycl::ext::oneapi::bfloat16, Use, NumRows, NumCols, Layout > &Mat, std::size_t i)
wi_element & operator=(const wi_element< sycl::ext::oneapi::bfloat16, NumRows, NumCols, Use, Layout, Group > &rhs)
__SYCL_ALWAYS_INLINE std::tuple< uint32_t, uint32_t > get_coord()
wi_element & operator=(const sycl::ext::oneapi::bfloat16 &rhs)
wi_element & operator=(const wi_element< T, NumRows, NumCols, Use, Layout, Group > &rhs)
wi_element(sycl::ext::oneapi::experimental::matrix::joint_matrix< Group, T, Use, NumRows, NumCols, Layout > &Mat, std::size_t i)
wi_element & operator=(const T2 &rhs)
typename oneapi::detail::jm_type_interpretation_helper_trait< T >::storage_element_type storage_element_type
__SYCL_ALWAYS_INLINE std::tuple< size_t, size_t > get_coord()
#define __SYCL_ALWAYS_INLINE
#define SPV_MATRIX_USE_TRAITS(USE, SPV_USE)
#define SPV_MATRIX_LAYOUT_TRAITS(LAYOUT, SPV_LAYOUT)
__SYCL_ALWAYS_INLINE __spv::MatrixLayout joint_matrix_layout_to_spv(sycl::ext::oneapi::experimental::matrix::layout Layout)
constexpr tuple< Ts... > make_tuple(Ts... Args)
sycl::ext::oneapi::bfloat16 bfloat16
__SYCL_ALWAYS_INLINE void joint_matrix_store(Group, const sycl::ext::oneapi::experimental::matrix::joint_matrix< Group, Tp, Use, NumRows, NumCols, Layout > &src, ext::oneapi::experimental::annotated_ptr< T, PropertyListT > dst, size_t stride)
__SYCL_ALWAYS_INLINE void joint_matrix_load_checked(Group sg, joint_matrix< Group, S, Use, NumRows, NumCols, Layout > &Res, ext::oneapi::experimental::annotated_ptr< T, PropertyListT > Src, size_t Stride, size_t Height, size_t Width, size_t CoordX, size_t CoordY)
__SYCL_ALWAYS_INLINE void joint_matrix_apply(Group sg, sycl::ext::oneapi::experimental::matrix::joint_matrix< Group, T, Use, Rows, Cols, Layout > &jm, F &&lambda)
__SYCL_ALWAYS_INLINE void joint_matrix_store_checked(Group sg, const joint_matrix< Group, Tp, Use, NumRows, NumCols, Layout > &Src, ext::oneapi::experimental::annotated_ptr< T, PropertyListT > Dst, size_t Stride, size_t Height, size_t Width, size_t CoordX, size_t CoordY)
__SYCL_ALWAYS_INLINE void joint_matrix_fill_checked(Group, joint_matrix< Group, T, Use, NumRows, NumCols, Layout > &Res, const T2 &Value, size_t Height, size_t Width, size_t CoordX, size_t CoordY)
decltype(auto) __SYCL_ALWAYS_INLINE get_wi_data(Group sg, sycl::ext::oneapi::experimental::matrix::joint_matrix< Group, T, Use, Rows, Cols, Layout > &jm)
std::enable_if_t< detail::is_bf16_storage_type< T >::value, T > fabs(T x)
std::error_code make_error_code(sycl::errc E) noexcept
Constructs an error code using e and sycl_category()
float storage_element_type
static constexpr __spv::MatrixLayout value
static constexpr __spv::MatrixUse value