13 #if defined(__SYCL_DEVICE_ONLY__)
14 #if defined(__NVPTX__)
16 #elif defined(__gfx90a__)
33 #include <type_traits>
36 inline namespace _V1 {
39 namespace experimental {
// The joint_matrix class template: a group-shared 2D matrix fragment
// parameterized on element type T, its role (use::a / use::b / use::accumulator),
// and compile-time dimensions Rows x Cols.
// NOTE(review): this excerpt is elided — interior lines (the per-backend impl
// members selected by the #if chain below, and the #endif lines) are not visible.
42 template <
typename Group,
typename T,
use Use,
size_t Rows,
size_t Cols,
// Per-backend storage: CUDA, HIP (gfx90a MFMA), or SPIR-V each provide their
// own implementation type; the bodies are elided from this view.
46 #if defined(__SYCL_DEVICE_ONLY__)
47 #if defined(__NVPTX__)
50 #elif defined(__HIP_PLATFORM_AMD_MFMA__)
53 #elif defined(__SPIR__) || defined(__SPIRV__)
// Any other device backend is rejected at compile time.
58 static_assert(
false,
"The joint_matrix API is only supported by the Intel, "
59 "CUDA and HIP (GFX90A) backends");
// Device compilation attaches IR attributes so the compiler/runtime can
// introspect the matrix type/use/shape of this kernel.
63 #if defined(__SYCL_DEVICE_ONLY__)
64 [[__sycl_detail__::add_ir_attributes_function(
65 "sycl-joint-matrix-type",
"sycl-joint-matrix-use",
66 "sycl-joint-matrix-rows",
"sycl-joint-matrix-cols",
67 sycl::detail::convertTypeToMatrixTypeString<T>(),
// Host fallback: joint_matrix has no host implementation.
71 #ifndef __SYCL_DEVICE_ONLY__
72 throw runtime_error(
"joint matrix is not supported on host device.",
73 PI_ERROR_INVALID_DEVICE);
76 #ifdef __SYCL_DEVICE_ONLY__
77 #if defined(__SPIR__) || defined(__SPIRV__)
// joint_matrix_apply (unary form): invokes `lambda` element-wise on each
// work-item-local element of matrix `jm`, writing results back in place.
// NOTE(review): elided excerpt — loop closing braces and the SPIR-V-path
// surrounding lines are not visible here.
84 template <
typename Group,
typename T,
use Use,
size_t M,
size_t N,
89 #if defined(__SYCL_DEVICE_ONLY__)
// CUDA / HIP path: iterate the work-item's marray slice directly.
90 #if defined(__NVPTX__) || defined(__HIP_PLATFORM_AMD_MFMA__)
92 for (
int i = 0; i < jm.matrix_impl.wi_marray.size(); i++) {
93 lambda(jm.matrix_impl.wi_marray[i]);
// SPIR-V path: elements are accessed through wi_data with a per-type
// storage_element_type (e.g. for tf32-style storage — see detail traits).
96 using storage_element_type =
98 T>::storage_element_type;
100 for (
int i = 0; i < wi_data_c.length(); i++) {
101 storage_element_type element = wi_data_c[i];
// Write the (lambda-modified, presumably — call site elided) element back.
103 wi_data_c[i] = element;
// Host fallback: not supported.
109 std::ignore = lambda;
110 throw runtime_error(
"joint matrix is not supported on host device.",
111 PI_ERROR_INVALID_DEVICE);
// joint_matrix_apply (binary form): invokes `lambda(src_elem, dest_elem)` for
// each element pair of `jmsrc`/`jmdest`, writing the mutated dest element back.
// NOTE(review): elided excerpt — signature line and loop closers not visible.
116 template <
typename Group,
typename T,
use Use,
size_t M,
size_t N,
117 layout Layout,
typename F>
122 #if defined(__SYCL_DEVICE_ONLY__)
// CUDA / HIP path: element-wise over the work-item marrays of both matrices.
123 #if defined(__NVPTX__) || defined(__HIP_PLATFORM_AMD_MFMA__)
125 for (
int i = 0; i < jmsrc.matrix_impl.wi_marray.size(); i++) {
126 lambda(jmsrc.matrix_impl.wi_marray[i], jmdest.matrix_impl.wi_marray[i]);
// SPIR-V path: go through wi_data and the storage element type.
129 using storage_element_type =
131 T>::storage_element_type;
134 for (
int i = 0; i < wi_data_c.length(); i++) {
135 storage_element_type elementsrc = wi_data_c[i];
136 storage_element_type elementdest = wi_data_d[i];
137 lambda(elementsrc, elementdest);
// Only the destination is written back; the source is read-only here.
138 wi_data_d[i] = elementdest;
// Host fallback: not supported.
144 std::ignore = jmdest;
145 std::ignore = lambda;
146 throw runtime_error(
"joint matrix is not supported on host device.",
147 PI_ERROR_INVALID_DEVICE);
// joint_matrix_fill: broadcasts scalar `v` (of type T2, converted to the
// matrix storage element type) into every element of `res`.
152 template <
typename Group,
typename T,
size_t NumRows,
size_t NumCols,
use Use,
153 layout Layout,
typename T2>
158 #if defined(__SYCL_DEVICE_ONLY__)
// CUDA / HIP path: marray assignment broadcasts the scalar.
159 #if defined(__NVPTX__) || defined(__HIP_PLATFORM_AMD_MFMA__)
160 res.matrix_impl.wi_marray = v;
// SPIR-V path: construct the matrix via __spirv_CompositeConstruct from the
// converted scalar (surrounding lines elided in this view).
162 using storage_element_type =
164 T>::storage_element_type;
166 __spirv_CompositeConstruct<storage_element_type, T, NumRows, NumCols,
169 static_cast<storage_element_type
>(v));
// Host fallback: not supported.
174 throw runtime_error(
"joint matrix is not supported on host device.",
175 PI_ERROR_INVALID_DEVICE);
// joint_matrix_load (accumulator overload, multi_ptr source): loads an
// accumulator matrix with dynamic layout from `src` with the given `stride`.
// Enabled only when S matches remove_const_t<T>.
// NOTE(review): elided excerpt — the leading `template <` line and the
// static_assert condition on Space are outside this view.
180 typename Group,
typename S,
typename T,
size_t NumRows,
size_t NumCols,
182 std::enable_if_t<std::is_same<S, std::remove_const_t<T>>::value,
bool> =
187 sycl::ext::oneapi::experimental::matrix::layout::dynamic> &res,
190 #if defined(__SYCL_DEVICE_ONLY__)
// Private address space is rejected (message text visible below).
192 "Joint Matrix doesn't support load from private memory!");
193 #if defined(__NVPTX__)
195 sycl::ext::oneapi::detail::load_accumulator_cuda(res.matrix_impl, src, stride,
197 #elif defined(__HIP_PLATFORM_AMD_MFMA__)
// SPIR-V path: decorate the pointer for its address space, then call the
// JointMatrixLoadINTEL intrinsic (argument tail elided).
202 using DecorT =
typename sycl::detail::DecoratedType<T, Space>::type;
203 DecorT *Ptr = sycl::detail::getDecorated<DecorT>(src);
204 res.spvm = __spirv_JointMatrixLoadINTEL<
205 DecorT, S, NumRows, NumCols,
// Host fallback: not supported.
215 std::ignore = stride;
216 std::ignore = Layout;
217 throw runtime_error(
"joint matrix is not supported on host device.",
218 PI_ERROR_INVALID_DEVICE);
// joint_matrix_load (multiplicand overload, multi_ptr source): loads an
// A/B-use matrix with a compile-time layout. Enabled when S == T (modulo
// const), or when S is precision::tf32 and T is float (tf32 is loaded from
// float storage).
223 typename Group,
typename S,
typename T,
use Use,
size_t NumRows,
226 std::enable_if_t<std::is_same<S, std::remove_const_t<T>>::value ||
227 (std::is_same<S, precision::tf32>::value &&
228 std::is_same<std::remove_const_t<T>,
float>::value),
234 #if defined(__SYCL_DEVICE_ONLY__)
// Private address space is rejected.
236 "Joint Matrix doesn't support load from private memory!");
237 #if defined(__NVPTX__)
239 sycl::ext::oneapi::detail::load_multiplicand_cuda<S, T, NumRows, NumCols, Use,
241 res.matrix_impl, src, stride);
242 #elif defined(__HIP_PLATFORM_AMD_MFMA__)
244 NumCols, Use, Layout, Space>(
245 res.matrix_impl, src, stride, sg);
// SPIR-V path: decorated pointer + JointMatrixLoadINTEL (tail elided).
248 using DecorT =
typename sycl::detail::DecoratedType<T, Space>::type;
249 DecorT *Ptr = sycl::detail::getDecorated<DecorT>(src);
251 __spirv_JointMatrixLoadINTEL<DecorT, S, NumRows, NumCols,
// Host fallback: not supported.
261 std::ignore = stride;
262 throw runtime_error(
"joint matrix is not supported on host device.",
263 PI_ERROR_INVALID_DEVICE);
// joint_matrix_load (accumulator overload, annotated/raw-pointer source with a
// compile-time property list). Only the SPIR-V (Intel) backend supports this
// pointer form; CUDA and HIP callers are told to use the multi_ptr overload.
267 template <
typename Group,
typename S,
typename T,
size_t NumRows,
268 size_t NumCols,
typename PropertyListT,
269 std::enable_if_t<std::is_same<S, std::remove_const_t<T>>::value,
274 sycl::ext::oneapi::experimental::matrix::layout::dynamic> &res,
277 #if defined(__SYCL_DEVICE_ONLY__)
278 #if defined(__NVPTX__)
280 throw runtime_error(
"Use joint_matrix_load on multi_ptr on Nvidia device.",
281 PI_ERROR_INVALID_DEVICE);
282 #elif defined(__HIP_PLATFORM_AMD_MFMA__)
283 throw runtime_error(
"Use joint_matrix_load on multi_ptr on AMD device.",
284 PI_ERROR_INVALID_DEVICE);
// SPIR-V path (intrinsic argument list elided in this view).
288 res.spvm = __spirv_JointMatrixLoadINTEL<
// Host fallback: not supported.
298 std::ignore = stride;
299 std::ignore = Layout;
300 throw runtime_error(
"joint matrix is not supported on host device.",
301 PI_ERROR_INVALID_DEVICE);
// joint_matrix_load (multiplicand overload, annotated/raw-pointer source).
// Same enable_if as the multi_ptr multiplicand overload (S == T, or tf32 from
// float); only the SPIR-V backend supports this pointer form.
// NOTE(review): elided excerpt — the leading `template <` line is not visible.
306 typename Group,
typename S,
typename T,
use Use,
size_t NumRows,
308 std::enable_if_t<std::is_same<S, std::remove_const_t<T>>::value ||
309 (std::is_same<S, precision::tf32>::value &&
310 std::is_same<std::remove_const_t<T>,
float>::value),
316 #if defined(__SYCL_DEVICE_ONLY__)
317 #if defined(__NVPTX__)
319 throw runtime_error(
"Use joint_matrix_load on multi_ptr on Nvidia device.",
320 PI_ERROR_INVALID_DEVICE);
321 #elif defined(__HIP_PLATFORM_AMD_MFMA__)
322 throw runtime_error(
"Use joint_matrix_load on multi_ptr on AMD device.",
323 PI_ERROR_INVALID_DEVICE);
// SPIR-V path: note T is used directly here (no DecorT) — the raw-pointer
// form presumably carries its address space already; tail elided.
328 __spirv_JointMatrixLoadINTEL<T, S, NumRows, NumCols,
// Host fallback: not supported.
338 std::ignore = stride;
339 throw runtime_error(
"joint matrix is not supported on host device.",
340 PI_ERROR_INVALID_DEVICE);
// joint_matrix_store (multi_ptr destination): stores an accumulator matrix
// (dynamic layout) to `dst` with the given `stride` and runtime `Layout`.
344 template <
typename Group,
typename T,
size_t NumRows,
size_t NumCols,
349 sycl::ext::oneapi::experimental::matrix::layout::dynamic>
353 #if defined(__SYCL_DEVICE_ONLY__)
// Private address space is rejected.
355 "Joint Matrix doesn't support store to private memory!");
356 #if defined(__NVPTX__)
358 sycl::ext::oneapi::detail::joint_matrix_store_cuda<T, NumRows, NumCols,
360 src.matrix_impl, dst, stride, Layout);
361 #elif defined(__HIP_PLATFORM_AMD_MFMA__)
363 Space>(src.matrix_impl, dst,
// SPIR-V path: decorated pointer + JointMatrixStoreINTEL (tail elided).
367 using DecorT =
typename sycl::detail::DecoratedType<T, Space>::type;
368 DecorT *Ptr = sycl::detail::getDecorated<DecorT>(dst);
369 __spirv_JointMatrixStoreINTEL<
370 DecorT, T, NumRows, NumCols,
// Host fallback: not supported.
380 std::ignore = stride;
381 std::ignore = Layout;
382 throw runtime_error(
"joint matrix is not supported on host device.",
383 PI_ERROR_INVALID_DEVICE);
// joint_matrix_store (annotated/raw-pointer destination with property list).
// Only the SPIR-V backend supports this form; CUDA and HIP callers are
// directed to the multi_ptr overload, mirroring the load overloads above.
387 template <
typename Group,
typename T,
size_t NumRows,
size_t NumCols,
388 typename PropertyListT>
392 sycl::ext::oneapi::experimental::matrix::layout::dynamic>
396 #if defined(__SYCL_DEVICE_ONLY__)
397 #if defined(__NVPTX__)
399 throw runtime_error(
"Use joint_matrix_store on multi_ptr on Nvidia device.",
400 PI_ERROR_INVALID_DEVICE);
401 #elif defined(__HIP_PLATFORM_AMD_MFMA__)
402 throw runtime_error(
"Use joint_matrix_store on multi_ptr on AMD device.",
403 PI_ERROR_INVALID_DEVICE);
// SPIR-V path (intrinsic argument list elided in this view).
407 __spirv_JointMatrixStoreINTEL<
// Host fallback: not supported.
417 std::ignore = stride;
418 std::ignore = Layout;
419 throw runtime_error(
"joint matrix is not supported on host device.",
420 PI_ERROR_INVALID_DEVICE);
// joint_matrix_mad: D = A * B + C for MxK (A), KxN (B), MxN (C, D).
// Device IR attributes expose the mad operand types and shape for
// compiler/runtime introspection.
424 template <
typename Group,
typename Ta,
typename Tb,
typename Tc,
typename Td,
425 std::size_t M, std::size_t K, std::size_t N,
layout LayoutA,
427 #if defined(__SYCL_DEVICE_ONLY__)
428 [[__sycl_detail__::add_ir_attributes_function(
429 "sycl-joint-matrix-mad-type-A",
"sycl-joint-matrix-mad-type-B",
430 "sycl-joint-matrix-mad-type-C",
"sycl-joint-matrix-mad-type-D",
431 "sycl-joint-matrix-mad-size-M",
"sycl-joint-matrix-mad-size-K",
432 "sycl-joint-matrix-mad-size-N",
433 sycl::detail::convertTypeToMatrixTypeString<Ta>(),
434 sycl::detail::convertTypeToMatrixTypeString<Tb>(),
435 sycl::detail::convertTypeToMatrixTypeString<Tc>(),
436 sycl::detail::convertTypeToMatrixTypeString<Td>(), M, K, N)]]
442 sycl::ext::oneapi::experimental::matrix::layout::dynamic> &D,
446 sycl::ext::oneapi::experimental::matrix::layout::dynamic>
448 #if defined(__SYCL_DEVICE_ONLY__)
// CUDA path: requires Ta == Tb (checked via if constexpr + runtime assert).
449 #if defined(__NVPTX__)
450 if constexpr (std::is_same<Ta, Tb>::value) {
451 sycl::ext::oneapi::detail::joint_matrix_mad_cuda<Ta, Tc, Td, M, K, N,
453 D.matrix_impl, A.matrix_impl, B.matrix_impl, C.matrix_impl);
455 assert(
false &&
"Ta != Tb : In the CUDA backend joint_matrix_mad "
456 "requires that joint_matrix data types Ta and Tb match");
// HIP path: same Ta == Tb requirement.
458 #elif defined(__HIP_PLATFORM_AMD_MFMA__)
459 if constexpr (std::is_same<Ta, Tb>::value) {
462 D.matrix_impl, A.matrix_impl, B.matrix_impl, C.matrix_impl);
464 assert(
false &&
"Ta != Tb : In the HIP backend joint_matrix_mad "
465 "requires that joint_matrix data types Ta and Tb match");
// SPIR-V path: dispatch on operand signedness to the matching Mad intrinsic.
// bf16-as-uint16 with float accumulator uses the plain Mad intrinsic —
// checked FIRST so it is not misrouted to the unsigned (UU) variant below.
468 if constexpr (std::is_same<Ta, uint16_t>::value &&
469 std::is_same<Tb, uint16_t>::value &&
470 std::is_same<Tc, float>::value)
471 D.spvm = __spirv_JointMatrixMadINTEL(A.spvm, B.spvm, C.spvm);
472 else if constexpr (std::is_unsigned<Ta>::value && std::is_unsigned<Tb>::value)
473 D.spvm = __spirv_JointMatrixUUMadINTEL(A.spvm, B.spvm, C.spvm);
474 else if constexpr (std::is_signed<Ta>::value && std::is_unsigned<Tb>::value)
475 D.spvm = __spirv_JointMatrixSUMadINTEL(A.spvm, B.spvm, C.spvm);
476 else if constexpr (std::is_unsigned<Ta>::value && std::is_signed<Tb>::value)
477 D.spvm = __spirv_JointMatrixUSMadINTEL(A.spvm, B.spvm, C.spvm);
// Default (signed/signed and floating-point types): plain Mad.
479 D.spvm = __spirv_JointMatrixMadINTEL(A.spvm, B.spvm, C.spvm);
// Host fallback: not supported.
486 throw runtime_error(
"joint matrix is not supported on host device.",
487 PI_ERROR_INVALID_DEVICE);
// joint_matrix_copy: element-wise copy from `src` (type T1) to `dst` (type
// T2), converting each element through dst's storage element type.
491 template <
typename Group,
typename T1,
typename T2,
size_t Rows,
size_t Cols,
496 #if defined(__SYCL_DEVICE_ONLY__)
// CUDA / HIP path: whole-marray assignment (implicit per-element conversion).
497 #if defined(__NVPTX__) || defined(__HIP_PLATFORM_AMD_MFMA__)
499 dst.matrix_impl.wi_marray = src.matrix_impl.wi_marray;
// SPIR-V path: per-element static_cast through the storage element type.
501 using storage_element_type =
503 T2>::storage_element_type;
506 for (
int i = 0; i < wi_data_c.length(); i++) {
507 wi_data_dst[i] =
static_cast<storage_element_type
>(wi_data_c[i]);
// Host fallback: not supported.
514 throw runtime_error(
"joint matrix is not supported on host device.",
515 PI_ERROR_INVALID_DEVICE);
// round_to_tf32 body: rounds a float to tf32 precision (19 significant
// mantissa bits dropped to 10) while keeping the float representation.
// NOTE(review): the function signature line is elided from this view.
522 #if defined(__SYCL_DEVICE_ONLY__)
// CUDA: hardware round-to-nearest tf32 conversion, bitcast back to float.
523 #if defined(__NVPTX__)
524 int32_t tmp_int = __nvvm_f2tf32_rna(
a);
525 return __nvvm_bitcast_i2f(tmp_int);
// SPIR-V: dedicated Intel rounding intrinsic.
527 return __spirv_RoundFToTF32INTEL(
a);
// Host/other: emulate by masking off the low 13 mantissa bits (truncation,
// not round-to-nearest — behavior may differ slightly from the device paths).
530 uint32_t tmp_uint =
reinterpret_cast<const uint32_t &
>(
a);
532 tmp_uint &= 0xFFFFE000u;
534 std::memcpy(&ret, &tmp_uint,
sizeof(
float));
// joint_matrix_prefetch: hints the hardware to prefetch an NumRows x NumCols
// tile at Ptr. SPIR-V (Intel) only; Nvidia and AMD report unsupported.
539 template <
size_t NumRows,
size_t NumCols,
typename Group,
typename T,
545 #if defined(__SYCL_DEVICE_ONLY__)
546 #if defined(__NVPTX__)
550 "joint_matrix_prefetch is not supported on Nvidia device.",
551 PI_ERROR_INVALID_DEVICE);
552 #elif defined(__HIP_PLATFORM_AMD_MFMA__)
555 throw runtime_error(
"joint_matrix_prefetch is not supported on AMD device.",
556 PI_ERROR_INVALID_DEVICE);
// SPIR-V path: read the prefetch_hint property (cache level, presumably —
// usage elided) and issue the cooperative-matrix prefetch intrinsic.
559 auto prop =
properties.template get_property<prefetch_hint_key>();
563 __spirv_CooperativeMatrixPrefetchINTEL<T>(
564 Ptr, coordX, coordY, NumRows, NumCols,
// Host fallback: not supported.
571 std::ignore = stride;
572 std::ignore = Layout;
574 throw runtime_error(
"joint matrix is not supported on host device.",
575 PI_ERROR_INVALID_DEVICE);
// NOTE(review): the lines below appear to be an auto-extracted index of
// declarations/signatures (doxygen-style dump), not compilable code — they
// lack bodies and trailing semicolons. Left untouched; verify against the
// original header before treating them as source.
#define __SYCL_ALWAYS_INLINE
__SYCL_ALWAYS_INLINE __spv::MatrixLayout joint_matrix_layout_to_spv(sycl::ext::oneapi::experimental::matrix::layout Layout)
constexpr const char * convertMatrixUseEnumToString(ext::oneapi::experimental::matrix::use Use)
void joint_matrix_store_hip(const joint_matrix_hip< T, sycl::ext::oneapi::experimental::matrix::use::accumulator, M, N, sycl::ext::oneapi::experimental::matrix::layout::dynamic > &src, multi_ptr< T, Space, IsDecorated > dst, size_t stride, sycl::ext::oneapi::experimental::matrix::layout layout, Group &sg)
void joint_matrix_mad_hip(joint_matrix_hip< Tc, sycl::ext::oneapi::experimental::matrix::use::accumulator, M, N, sycl::ext::oneapi::experimental::matrix::layout::dynamic > &D, const joint_matrix_hip< Tm, sycl::ext::oneapi::experimental::matrix::use::a, M, K, LayoutA > &A, const joint_matrix_hip< Tm, sycl::ext::oneapi::experimental::matrix::use::b, K, N, LayoutB > &B, const joint_matrix_hip< Tc, sycl::ext::oneapi::experimental::matrix::use::accumulator, M, N, sycl::ext::oneapi::experimental::matrix::layout::dynamic > &C)
void load_accumulator_hip(joint_matrix_hip< S, sycl::ext::oneapi::experimental::matrix::use::accumulator, M, N, sycl::ext::oneapi::experimental::matrix::layout::dynamic > &res, multi_ptr< T, Space, IsDecorated > src, size_t stride, sycl::ext::oneapi::experimental::matrix::layout layout, Group &sg)
void load_multiplicand_hip(joint_matrix_hip< S, Use, M, N, Layout > &res, multi_ptr< T, Space, IsDecorated > src, size_t stride, Group &sg)
decltype(auto) __SYCL_ALWAYS_INLINE get_wi_data(Group sg, sycl::ext::oneapi::experimental::matrix::joint_matrix< Group, T, Use, Rows, Cols, Layout > &jm)
__SYCL_ALWAYS_INLINE float round_to_tf32(const float &a)
void joint_matrix_copy(Group sg, joint_matrix< Group, T1, Use1, Rows, Cols, Layout1 > &src, joint_matrix< Group, T2, Use2, Rows, Cols, Layout2 > &dst)
__SYCL_ALWAYS_INLINE void joint_matrix_load(Group sg, joint_matrix< Group, S, use::accumulator, NumRows, NumCols, sycl::ext::oneapi::experimental::matrix::layout::dynamic > &res, multi_ptr< T, Space, IsDecorated > src, size_t stride, sycl::ext::oneapi::experimental::matrix::layout Layout)
__SYCL_ALWAYS_INLINE void joint_matrix_fill(Group, joint_matrix< Group, T, Use, NumRows, NumCols, Layout > &res, const T2 &v)
__SYCL_ALWAYS_INLINE void joint_matrix_store(Group sg, const joint_matrix< Group, T, use::accumulator, NumRows, NumCols, sycl::ext::oneapi::experimental::matrix::layout::dynamic > &src, multi_ptr< T, Space, IsDecorated > dst, size_t stride, sycl::ext::oneapi::experimental::matrix::layout Layout)
__SYCL_ALWAYS_INLINE void joint_matrix_mad(Group, joint_matrix< Group, Td, use::accumulator, M, N, sycl::ext::oneapi::experimental::matrix::layout::dynamic > &D, const joint_matrix< Group, Ta, use::a, M, K, LayoutA > &A, const joint_matrix< Group, Tb, use::b, K, N, LayoutB > &B, const joint_matrix< Group, Tc, use::accumulator, M, N, sycl::ext::oneapi::experimental::matrix::layout::dynamic > &C)
__SYCL_ALWAYS_INLINE void joint_matrix_prefetch(Group sg, T *Ptr, size_t stride, sycl::ext::oneapi::experimental::matrix::layout Layout, Properties properties={})
__SYCL_ALWAYS_INLINE void joint_matrix_apply(Group sg, joint_matrix< Group, T, Use, M, N, Layout > &jm, F &&lambda)
annotated_arg & operator=(annotated_arg &)=default
decltype(properties{}) empty_properties_t