13 #if defined(__SYCL_DEVICE_ONLY__)
14 #if defined(__NVPTX__)
16 #elif defined(__gfx90a__)
32 #include <type_traits>
35 inline namespace _V1 {
38 namespace experimental {
41 template <
typename Group,
typename T,
use Use,
size_t Rows,
size_t Cols,
45 #if defined(__SYCL_DEVICE_ONLY__)
46 #if defined(__NVPTX__)
49 #elif defined(__HIP_PLATFORM_AMD_MFMA__)
52 #elif defined(__SPIR__) || defined(__SPIRV__)
57 static_assert(
false,
"The joint_matrix API is only supported by the Intel, "
58 "CUDA and HIP (GFX90A) backends");
62 #if defined(__SYCL_DEVICE_ONLY__)
63 [[__sycl_detail__::add_ir_attributes_function(
64 "sycl-joint-matrix-type",
"sycl-joint-matrix-use",
65 "sycl-joint-matrix-rows",
"sycl-joint-matrix-cols",
66 sycl::detail::convertTypeToMatrixTypeString<T>(),
70 #ifndef __SYCL_DEVICE_ONLY__
72 "joint matrix is not supported on host.");
75 #ifdef __SYCL_DEVICE_ONLY__
76 #if defined(__SPIR__) || defined(__SPIRV__)
83 template <
typename Group,
typename T,
use Use,
size_t M,
size_t N,
88 #if defined(__SYCL_DEVICE_ONLY__)
89 #if defined(__NVPTX__) || defined(__HIP_PLATFORM_AMD_MFMA__)
91 for (
int i = 0; i < jm.matrix_impl.wi_marray.size(); i++) {
92 lambda(jm.matrix_impl.wi_marray[i]);
95 using storage_element_type =
97 T>::storage_element_type;
99 for (
int i = 0; i < wi_data_c.length(); i++) {
100 storage_element_type element = wi_data_c[i];
102 wi_data_c[i] = element;
108 std::ignore = lambda;
110 "joint matrix is not supported on host.");
115 template <
typename Group,
typename T,
use Use,
size_t M,
size_t N,
116 layout Layout,
typename F>
121 #if defined(__SYCL_DEVICE_ONLY__)
122 #if defined(__NVPTX__) || defined(__HIP_PLATFORM_AMD_MFMA__)
124 for (
int i = 0; i < jmsrc.matrix_impl.wi_marray.size(); i++) {
125 lambda(jmsrc.matrix_impl.wi_marray[i], jmdest.matrix_impl.wi_marray[i]);
128 using storage_element_type =
130 T>::storage_element_type;
133 for (
int i = 0; i < wi_data_c.length(); i++) {
134 storage_element_type elementsrc = wi_data_c[i];
135 storage_element_type elementdest = wi_data_d[i];
136 lambda(elementsrc, elementdest);
137 wi_data_d[i] = elementdest;
143 std::ignore = jmdest;
144 std::ignore = lambda;
146 "joint matrix is not supported on host.");
151 template <
typename Group,
typename T,
size_t NumRows,
size_t NumCols,
use Use,
152 layout Layout,
typename T2>
157 #if defined(__SYCL_DEVICE_ONLY__)
158 #if defined(__NVPTX__) || defined(__HIP_PLATFORM_AMD_MFMA__)
159 res.matrix_impl.wi_marray = v;
161 using storage_element_type =
163 T>::storage_element_type;
165 __spirv_CompositeConstruct<storage_element_type, T, NumRows, NumCols,
168 static_cast<storage_element_type
>(v));
174 "joint matrix is not supported on host.");
179 typename Group,
typename S,
typename T,
size_t NumRows,
size_t NumCols,
181 std::enable_if_t<std::is_same<S, std::remove_const_t<T>>::value,
bool> =
186 sycl::ext::oneapi::experimental::matrix::layout::dynamic> &res,
189 #if defined(__SYCL_DEVICE_ONLY__)
191 "Joint Matrix doesn't support load from private memory!");
192 #if defined(__NVPTX__)
194 sycl::ext::oneapi::detail::load_accumulator_cuda(res.matrix_impl, src, stride,
196 #elif defined(__HIP_PLATFORM_AMD_MFMA__)
201 using DecorT =
typename sycl::detail::DecoratedType<T, Space>::type;
202 DecorT *Ptr = sycl::detail::getDecorated<DecorT>(src);
203 res.spvm = __spirv_JointMatrixLoadINTEL<
204 DecorT, S, NumRows, NumCols,
214 std::ignore = stride;
215 std::ignore = Layout;
217 "joint matrix is not supported on host.");
222 typename Group,
typename S,
typename T,
use Use,
size_t NumRows,
225 std::enable_if_t<std::is_same<S, std::remove_const_t<T>>::value ||
226 (std::is_same<S, precision::tf32>::value &&
227 std::is_same<std::remove_const_t<T>,
float>::value),
233 #if defined(__SYCL_DEVICE_ONLY__)
235 "Joint Matrix doesn't support load from private memory!");
236 #if defined(__NVPTX__)
238 sycl::ext::oneapi::detail::load_multiplicand_cuda<S, T, NumRows, NumCols, Use,
240 res.matrix_impl, src, stride);
241 #elif defined(__HIP_PLATFORM_AMD_MFMA__)
243 NumCols, Use, Layout, Space>(
244 res.matrix_impl, src, stride, sg);
247 using DecorT =
typename sycl::detail::DecoratedType<T, Space>::type;
248 DecorT *Ptr = sycl::detail::getDecorated<DecorT>(src);
250 __spirv_JointMatrixLoadINTEL<DecorT, S, NumRows, NumCols,
260 std::ignore = stride;
262 "joint matrix is not supported on host.");
266 template <
typename Group,
typename S,
typename T,
size_t NumRows,
267 size_t NumCols,
typename PropertyListT,
268 std::enable_if_t<std::is_same<S, std::remove_const_t<T>>::value,
273 sycl::ext::oneapi::experimental::matrix::layout::dynamic> &res,
276 #if defined(__SYCL_DEVICE_ONLY__)
277 #if defined(__NVPTX__)
280 "Use joint_matrix_load on multi_ptr on Nvidia device.");
281 #elif defined(__HIP_PLATFORM_AMD_MFMA__)
283 "Use joint_matrix_load on multi_ptr on AMD device.");
287 res.spvm = __spirv_JointMatrixLoadINTEL<
297 std::ignore = stride;
298 std::ignore = Layout;
300 "joint matrix is not supported on host.");
305 typename Group,
typename S,
typename T,
use Use,
size_t NumRows,
307 std::enable_if_t<std::is_same<S, std::remove_const_t<T>>::value ||
308 (std::is_same<S, precision::tf32>::value &&
309 std::is_same<std::remove_const_t<T>,
float>::value),
315 #if defined(__SYCL_DEVICE_ONLY__)
316 #if defined(__NVPTX__)
319 "Use joint_matrix_load on multi_ptr on Nvidia device.");
320 #elif defined(__HIP_PLATFORM_AMD_MFMA__)
322 "Use joint_matrix_load on multi_ptr on AMD device.");
327 __spirv_JointMatrixLoadINTEL<T, S, NumRows, NumCols,
337 std::ignore = stride;
339 "joint matrix is not supported on host.");
343 template <
typename Group,
typename T,
size_t NumRows,
size_t NumCols,
348 sycl::ext::oneapi::experimental::matrix::layout::dynamic>
352 #if defined(__SYCL_DEVICE_ONLY__)
354 "Joint Matrix doesn't support store to private memory!");
355 #if defined(__NVPTX__)
357 sycl::ext::oneapi::detail::joint_matrix_store_cuda<T, NumRows, NumCols,
359 src.matrix_impl, dst, stride, Layout);
360 #elif defined(__HIP_PLATFORM_AMD_MFMA__)
362 Space>(src.matrix_impl, dst,
366 using DecorT =
typename sycl::detail::DecoratedType<T, Space>::type;
367 DecorT *Ptr = sycl::detail::getDecorated<DecorT>(dst);
368 __spirv_JointMatrixStoreINTEL<
369 DecorT, T, NumRows, NumCols,
379 std::ignore = stride;
380 std::ignore = Layout;
382 "joint matrix is not supported on host.");
386 template <
typename Group,
typename T,
size_t NumRows,
size_t NumCols,
387 typename PropertyListT>
391 sycl::ext::oneapi::experimental::matrix::layout::dynamic>
395 #if defined(__SYCL_DEVICE_ONLY__)
396 #if defined(__NVPTX__)
399 "Use joint_matrix_store on multi_ptr on Nvidia device.");
400 #elif defined(__HIP_PLATFORM_AMD_MFMA__)
402 "Use joint_matrix_store on multi_ptr on AMD device.");
406 __spirv_JointMatrixStoreINTEL<
416 std::ignore = stride;
417 std::ignore = Layout;
419 "joint matrix is not supported on host.");
423 template <
typename Group,
typename Ta,
typename Tb,
typename Tc,
typename Td,
424 std::size_t M, std::size_t K, std::size_t N,
layout LayoutA,
426 #if defined(__SYCL_DEVICE_ONLY__)
427 [[__sycl_detail__::add_ir_attributes_function(
428 "sycl-joint-matrix-mad-type-A",
"sycl-joint-matrix-mad-type-B",
429 "sycl-joint-matrix-mad-type-C",
"sycl-joint-matrix-mad-type-D",
430 "sycl-joint-matrix-mad-size-M",
"sycl-joint-matrix-mad-size-K",
431 "sycl-joint-matrix-mad-size-N",
432 sycl::detail::convertTypeToMatrixTypeString<Ta>(),
433 sycl::detail::convertTypeToMatrixTypeString<Tb>(),
434 sycl::detail::convertTypeToMatrixTypeString<Tc>(),
435 sycl::detail::convertTypeToMatrixTypeString<Td>(), M, K, N)]]
441 sycl::ext::oneapi::experimental::matrix::layout::dynamic> &D,
445 sycl::ext::oneapi::experimental::matrix::layout::dynamic>
447 #if defined(__SYCL_DEVICE_ONLY__)
448 #if defined(__NVPTX__)
449 if constexpr (std::is_same<Ta, Tb>::value) {
450 sycl::ext::oneapi::detail::joint_matrix_mad_cuda<Ta, Tc, Td, M, K, N,
452 D.matrix_impl, A.matrix_impl, B.matrix_impl, C.matrix_impl);
454 assert(
false &&
"Ta != Tb : In the CUDA backend joint_matrix_mad "
455 "requires that joint_matrix data types Ta and Tb match");
457 #elif defined(__HIP_PLATFORM_AMD_MFMA__)
458 if constexpr (std::is_same<Ta, Tb>::value) {
461 D.matrix_impl, A.matrix_impl, B.matrix_impl, C.matrix_impl);
463 assert(
false &&
"Ta != Tb : In the HIP backend joint_matrix_mad "
464 "requires that joint_matrix data types Ta and Tb match");
467 if constexpr (std::is_same<Ta, uint16_t>::value &&
468 std::is_same<Tb, uint16_t>::value &&
469 std::is_same<Tc, float>::value)
470 D.spvm = __spirv_JointMatrixMadINTEL(A.spvm, B.spvm, C.spvm);
471 else if constexpr (std::is_unsigned<Ta>::value && std::is_unsigned<Tb>::value)
472 D.spvm = __spirv_JointMatrixUUMadINTEL(A.spvm, B.spvm, C.spvm);
473 else if constexpr (std::is_signed<Ta>::value && std::is_unsigned<Tb>::value)
474 D.spvm = __spirv_JointMatrixSUMadINTEL(A.spvm, B.spvm, C.spvm);
475 else if constexpr (std::is_unsigned<Ta>::value && std::is_signed<Tb>::value)
476 D.spvm = __spirv_JointMatrixUSMadINTEL(A.spvm, B.spvm, C.spvm);
478 D.spvm = __spirv_JointMatrixMadINTEL(A.spvm, B.spvm, C.spvm);
486 "joint matrix is not supported on host.");
490 template <
typename Group,
typename T1,
typename T2,
size_t Rows,
size_t Cols,
495 #if defined(__SYCL_DEVICE_ONLY__)
496 #if defined(__NVPTX__) || defined(__HIP_PLATFORM_AMD_MFMA__)
498 dst.matrix_impl.wi_marray = src.matrix_impl.wi_marray;
500 using storage_element_type =
502 T2>::storage_element_type;
505 for (
int i = 0; i < wi_data_c.length(); i++) {
506 wi_data_dst[i] =
static_cast<storage_element_type
>(wi_data_c[i]);
514 "joint matrix is not supported on host.");
521 #if defined(__SYCL_DEVICE_ONLY__)
522 #if defined(__NVPTX__)
523 int32_t tmp_int = __nvvm_f2tf32_rna(
a);
524 return __nvvm_bitcast_i2f(tmp_int);
526 return __spirv_RoundFToTF32INTEL(
a);
529 uint32_t tmp_uint =
reinterpret_cast<const uint32_t &
>(
a);
531 tmp_uint &= 0xFFFFE000u;
533 std::memcpy(&ret, &tmp_uint,
sizeof(
float));
538 template <
size_t NumRows,
size_t NumCols,
typename Group,
typename T,
544 #if defined(__SYCL_DEVICE_ONLY__)
545 #if defined(__NVPTX__)
549 "joint_matrix_prefetch is not supported on Nvidia device.");
550 #elif defined(__HIP_PLATFORM_AMD_MFMA__)
554 "joint_matrix_prefetch is not supported on AMD device.");
557 auto prop =
properties.template get_property<prefetch_hint_key>();
558 __spirv_CooperativeMatrixPrefetchINTEL<T>(
565 std::ignore = stride;
566 std::ignore = Layout;
569 "joint matrix is not supported on host.");
#define __SYCL_ALWAYS_INLINE
__SYCL_ALWAYS_INLINE __spv::MatrixLayout joint_matrix_layout_to_spv(sycl::ext::oneapi::experimental::matrix::layout Layout)
constexpr const char * convertMatrixUseEnumToString(ext::oneapi::experimental::matrix::use Use)
void joint_matrix_store_hip(const joint_matrix_hip< T, sycl::ext::oneapi::experimental::matrix::use::accumulator, M, N, sycl::ext::oneapi::experimental::matrix::layout::dynamic > &src, multi_ptr< T, Space, IsDecorated > dst, size_t stride, sycl::ext::oneapi::experimental::matrix::layout layout, Group &sg)
void joint_matrix_mad_hip(joint_matrix_hip< Tc, sycl::ext::oneapi::experimental::matrix::use::accumulator, M, N, sycl::ext::oneapi::experimental::matrix::layout::dynamic > &D, const joint_matrix_hip< Tm, sycl::ext::oneapi::experimental::matrix::use::a, M, K, LayoutA > &A, const joint_matrix_hip< Tm, sycl::ext::oneapi::experimental::matrix::use::b, K, N, LayoutB > &B, const joint_matrix_hip< Tc, sycl::ext::oneapi::experimental::matrix::use::accumulator, M, N, sycl::ext::oneapi::experimental::matrix::layout::dynamic > &C)
void load_accumulator_hip(joint_matrix_hip< S, sycl::ext::oneapi::experimental::matrix::use::accumulator, M, N, sycl::ext::oneapi::experimental::matrix::layout::dynamic > &res, multi_ptr< T, Space, IsDecorated > src, size_t stride, sycl::ext::oneapi::experimental::matrix::layout layout, Group &sg)
void load_multiplicand_hip(joint_matrix_hip< S, Use, M, N, Layout > &res, multi_ptr< T, Space, IsDecorated > src, size_t stride, Group &sg)
decltype(auto) __SYCL_ALWAYS_INLINE get_wi_data(Group sg, sycl::ext::oneapi::experimental::matrix::joint_matrix< Group, T, Use, Rows, Cols, Layout > &jm)
__SYCL_ALWAYS_INLINE float round_to_tf32(const float &a)
void joint_matrix_copy(Group sg, joint_matrix< Group, T1, Use1, Rows, Cols, Layout1 > &src, joint_matrix< Group, T2, Use2, Rows, Cols, Layout2 > &dst)
__SYCL_ALWAYS_INLINE void joint_matrix_load(Group sg, joint_matrix< Group, S, use::accumulator, NumRows, NumCols, sycl::ext::oneapi::experimental::matrix::layout::dynamic > &res, multi_ptr< T, Space, IsDecorated > src, size_t stride, sycl::ext::oneapi::experimental::matrix::layout Layout)
__SYCL_ALWAYS_INLINE void joint_matrix_fill(Group, joint_matrix< Group, T, Use, NumRows, NumCols, Layout > &res, const T2 &v)
__SYCL_ALWAYS_INLINE void joint_matrix_store(Group sg, const joint_matrix< Group, T, use::accumulator, NumRows, NumCols, sycl::ext::oneapi::experimental::matrix::layout::dynamic > &src, multi_ptr< T, Space, IsDecorated > dst, size_t stride, sycl::ext::oneapi::experimental::matrix::layout Layout)
__SYCL_ALWAYS_INLINE void joint_matrix_mad(Group, joint_matrix< Group, Td, use::accumulator, M, N, sycl::ext::oneapi::experimental::matrix::layout::dynamic > &D, const joint_matrix< Group, Ta, use::a, M, K, LayoutA > &A, const joint_matrix< Group, Tb, use::b, K, N, LayoutB > &B, const joint_matrix< Group, Tc, use::accumulator, M, N, sycl::ext::oneapi::experimental::matrix::layout::dynamic > &C)
__SYCL_ALWAYS_INLINE void joint_matrix_prefetch(Group sg, T *Ptr, size_t stride, sycl::ext::oneapi::experimental::matrix::layout Layout, Properties properties={})
__SYCL_ALWAYS_INLINE void joint_matrix_apply(Group sg, joint_matrix< Group, T, Use, M, N, Layout > &jm, F &&lambda)
annotated_arg & operator=(annotated_arg &)=default
std::error_code make_error_code(sycl::errc E) noexcept
Constructs an error code from `E` and `sycl::sycl_category()`.