21 #define __HIP_PLATFORM_AMD_MFMA__
24 inline namespace _V1 {
32 size_t Rows,
size_t Cols,
34 sycl::ext::oneapi::experimental::matrix::layout::dynamic,
62 #undef __SYCL_JOINT_MATRIX_OVERLOAD_ARR
// Generates the joint_matrix_hip partial specialization for element type TYPE
// and matrix use USE with an M x N shape. The specialization is only enabled
// when Layout is row_major or col_major, and stores SIZE elements of TYPE per
// work-item in `wi_marray`.
#define __SYCL_JOINT_MATRIX_OVERLOAD_ARR(TYPE, USE, M, N, SIZE)                \
  template <sycl::ext::oneapi::experimental::matrix::layout Layout>            \
  struct joint_matrix_hip<                                                     \
      TYPE, sycl::ext::oneapi::experimental::matrix::use::USE, M, N, Layout,   \
      typename std::enable_if_t<                                               \
          Layout ==                                                            \
              sycl::ext::oneapi::experimental::matrix::layout::row_major ||    \
          Layout ==                                                            \
              sycl::ext::oneapi::experimental::matrix::layout::col_major>> {   \
    sycl::marray<TYPE, SIZE> wi_marray;                                        \
  };
94 #undef __SYCL_JOINT_MATRIX_OVERLOAD_ARR
// Generates the full joint_matrix_hip specialization for an M x N accumulator
// of element type TYPE (layout::dynamic). Each work-item stores
// (M * N) / WAVEFRONT_SIZE elements in `wi_marray`.
#define __SYCL_JOINT_MATRIX_OVERLOAD_ARR_ACC(TYPE, M, N)                       \
  template <>                                                                  \
  struct joint_matrix_hip<                                                     \
      TYPE, sycl::ext::oneapi::experimental::matrix::use::accumulator, M, N,   \
      sycl::ext::oneapi::experimental::matrix::layout::dynamic> {              \
    sycl::marray<TYPE, (M * N) / WAVEFRONT_SIZE> wi_marray;                    \
  };
110 #undef __SYCL_JOINT_MATRIX_OVERLOAD_ARR_ACC
117 S, sycl::ext::oneapi::experimental::matrix::use::accumulator, M, N,
118 sycl::ext::oneapi::experimental::matrix::layout::dynamic> &res,
120 const auto idx = sg.get_group_linear_id() * sg.get_local_range()[0] +
121 sg.get_local_linear_id();
123 if constexpr (std::is_same_v<S, double>) {
124 const auto thread_x = idx % N;
125 const auto thread_y = idx / N;
127 if constexpr (Layout ==
128 sycl::ext::oneapi::experimental::matrix::layout::row_major) {
129 for (
int i = 0; i < 4; ++i) {
130 const int s_idx = thread_x + i * 4 * stride + thread_y * stride;
131 res.wi_marray[i] = src[s_idx];
134 for (
int i = 0; i < 4; ++i) {
135 const int s_idx = i * 4 + thread_x * stride + thread_y;
136 res.wi_marray[i] = src[s_idx];
139 }
else if constexpr (std::is_same_v<S, float> || std::is_same_v<S, int32_t>) {
140 if constexpr (M == 16 && N == 16) {
141 const auto thread_x = idx % N;
142 const auto thread_y = idx / N;
146 for (
int i = 0; i < 4; ++i) {
147 const int s_idx = thread_x + i * stride + thread_y * 4 * stride;
148 res.wi_marray[i] = src[s_idx];
151 for (
int i = 0; i < 4; ++i) {
152 const int s_idx = i + thread_x * stride + thread_y * 4;
153 res.wi_marray[i] = src[s_idx];
156 }
else if constexpr (M == 32 && N == 32) {
157 const auto thread_x = idx % N;
158 const auto thread_y = idx / N;
162 for (
int j = 0; j < 4; ++j) {
163 for (
int i = 0; i < 4; ++i) {
165 thread_x + i * stride + thread_y * 4 * stride + j * 8 * N;
166 res.wi_marray[i + 4 * j] = src[s_idx];
170 for (
int j = 0; j < 4; ++j) {
171 for (
int i = 0; i < 4; ++i) {
172 const int s_idx = i + thread_x * stride + thread_y * 4 + j * 8;
173 res.wi_marray[i + 4 * j] = src[s_idx];
182 typename Group,
typename S,
typename T,
size_t M,
size_t N,
184 typename = std::enable_if_t<std::is_same_v<S, std::remove_const_t<T>>>>
187 S, sycl::ext::oneapi::experimental::matrix::use::accumulator, M, N,
188 sycl::ext::oneapi::experimental::matrix::layout::dynamic> &res,
191 if (
layout == sycl::ext::oneapi::experimental::matrix::layout::row_major)
193 sycl::ext::oneapi::experimental::matrix::layout::row_major>(res, src,
197 sycl::ext::oneapi::experimental::matrix::layout::col_major>(res, src,
202 typename Group,
typename S,
typename T,
size_t M,
size_t N,
206 typename =
typename std::enable_if_t<
207 (Layout == sycl::ext::oneapi::experimental::matrix::layout::row_major ||
209 sycl::ext::oneapi::experimental::matrix::layout::col_major) &&
210 std::is_same_v<S, std::remove_const_t<T>>>>
214 const auto idx = sg.get_group_linear_id() * sg.get_local_range()[0] +
215 sg.get_local_linear_id();
217 if constexpr (std::is_same_v<S, double>) {
218 if constexpr (Layout ==
219 sycl::ext::oneapi::experimental::matrix::layout::row_major) {
220 res.wi_marray[0] = src[idx];
222 res.wi_marray[0] = src[(idx % M) * stride + idx / M];
225 constexpr
int Dim = (M == 16) ? 16 : 32;
227 const auto thread_x = idx % Dim;
228 const auto thread_y = idx / Dim;
230 if constexpr (Layout ==
231 sycl::ext::oneapi::experimental::matrix::layout::col_major) {
232 for (
int i = 0; i < 4; ++i) {
233 const int c_idx = thread_x * stride + i + thread_y * 4;
234 res.wi_marray[i] = src[c_idx];
237 for (
int i = 0; i < 4; ++i) {
238 const int r_idx = thread_x + i * stride + thread_y * stride * 4;
239 res.wi_marray[i] = src[r_idx];
245 template <
typename Group,
251 T, sycl::ext::oneapi::experimental::matrix::use::accumulator, M, N,
252 sycl::ext::oneapi::experimental::matrix::layout::dynamic> &src,
254 const auto idx = sg.get_group_linear_id() * sg.get_local_range()[0] +
255 sg.get_local_linear_id();
257 if constexpr (std::is_same_v<T, double>) {
258 const auto thread_x = idx % N;
259 const auto thread_y = idx / N;
261 if constexpr (Layout ==
262 sycl::ext::oneapi::experimental::matrix::layout::row_major) {
263 for (
int i = 0; i < 4; ++i) {
264 const int d_idx = thread_x + i * 4 * stride + thread_y * stride;
265 dst[d_idx] = src.wi_marray[i];
268 for (
int i = 0; i < 4; ++i) {
269 const int d_idx = i * 4 + thread_x * stride + thread_y;
270 dst[d_idx] = src.wi_marray[i];
273 }
else if constexpr (std::is_same_v<T, float> || std::is_same_v<T, int32_t>) {
274 if constexpr (M == 16 && N == 16) {
275 const auto thread_x = idx % N;
276 const auto thread_y = idx / N;
280 for (
int i = 0; i < 4; ++i) {
281 const int d_idx = thread_x + i * stride + thread_y * 4 * stride;
282 dst[d_idx] = src.wi_marray[i];
285 for (
int i = 0; i < 4; ++i) {
286 const int d_idx = i + thread_x * stride + thread_y * 4;
287 dst[d_idx] = src.wi_marray[i];
290 }
else if constexpr (M == 32 && N == 32) {
291 const auto thread_x = idx % N;
292 const auto thread_y = idx / N;
296 for (
int j = 0; j < 4; ++j) {
297 for (
int i = 0; i < 4; ++i) {
299 thread_x + i * stride + thread_y * 4 * stride + j * 8 * stride;
300 dst[d_idx] = src.wi_marray[i + 4 * j];
304 for (
int j = 0; j < 4; ++j) {
305 for (
int i = 0; i < 4; ++i) {
306 const int d_idx = i + thread_x * stride + thread_y * 4 + j * 8;
307 dst[d_idx] = src.wi_marray[i + 4 * j];
315 template <
typename Group,
typename T,
size_t M,
size_t N,
319 T, sycl::ext::oneapi::experimental::matrix::use::accumulator, M, N,
320 sycl::ext::oneapi::experimental::matrix::layout::dynamic> &src,
323 if (sycl::ext::oneapi::experimental::matrix::layout::row_major ==
layout) {
325 sycl::ext::oneapi::experimental::matrix::layout::row_major>(
326 src, dst, stride, sg);
329 sycl::ext::oneapi::experimental::matrix::layout::col_major>(
330 src, dst, stride, sg);
334 template <
typename Tm,
typename Tc, std::size_t M, std::size_t K, std::size_t N,
339 Tc, sycl::ext::oneapi::experimental::matrix::use::accumulator, M, N,
340 sycl::ext::oneapi::experimental::matrix::layout::dynamic> &D,
346 Tc, sycl::ext::oneapi::experimental::matrix::use::accumulator, M, N,
347 sycl::ext::oneapi::experimental::matrix::layout::dynamic> &C) {
349 if constexpr (std::is_same_v<Tm, sycl::half>) {
350 if constexpr (M == 16 && N == 16) {
351 auto result = __builtin_amdgcn_mfma_f32_16x16x16f16(
352 *
reinterpret_cast<const float16x4 *
>(&A.wi_marray),
353 *
reinterpret_cast<const float16x4 *
>(&B.wi_marray),
354 *
reinterpret_cast<const floatx4 *
>(&C.wi_marray), 0, 0, 0);
355 std::memcpy(&D.wi_marray, &result, 4 *
sizeof(
float));
356 }
else if constexpr (M == 32 && N == 32) {
357 auto result = __builtin_amdgcn_mfma_f32_32x32x8f16(
358 *
reinterpret_cast<const float16x4 *
>(&A.wi_marray),
359 *
reinterpret_cast<const float16x4 *
>(&B.wi_marray),
360 *
reinterpret_cast<const floatx16 *
>(&C.wi_marray), 0, 0, 0);
361 std::memcpy(&D.wi_marray, &result, 16 *
sizeof(
float));
363 }
else if constexpr (std::is_same_v<Tm, bfloat16>) {
364 if constexpr (M == 16 && N == 16) {
365 auto result = __builtin_amdgcn_mfma_f32_16x16x16bf16_1k(
366 *
reinterpret_cast<const bfloat16x4 *
>(&A.wi_marray),
367 *
reinterpret_cast<const bfloat16x4 *
>(&B.wi_marray),
368 *
reinterpret_cast<const floatx4 *
>(&C.wi_marray), 0, 0, 0);
369 std::memcpy(&D.wi_marray, &result, 4 *
sizeof(
float));
370 }
else if constexpr (M == 32 && N == 32) {
371 auto result = __builtin_amdgcn_mfma_f32_32x32x8bf16_1k(
372 *
reinterpret_cast<const bfloat16x4 *
>(&A.wi_marray),
373 *
reinterpret_cast<const bfloat16x4 *
>(&B.wi_marray),
374 *
reinterpret_cast<const floatx16 *
>(&C.wi_marray), 0, 0, 0);
375 std::memcpy(&D.wi_marray, &result, 16 *
sizeof(
float));
377 }
else if constexpr (std::is_same_v<Tm, double>) {
378 if constexpr (M == 16 && N == 16) {
379 auto result = __builtin_amdgcn_mfma_f64_16x16x4f64(
380 A.wi_marray[0], B.wi_marray[0],
381 *
reinterpret_cast<const doublex4 *
>(&C.wi_marray), 0, 0, 0);
382 std::memcpy(&D.wi_marray, &result, 4 *
sizeof(
double));
384 }
else if constexpr (std::is_same_v<Tm, int8_t>) {
385 if constexpr (M == 16 && N == 16) {
386 auto result = __builtin_amdgcn_mfma_i32_16x16x16i8(
387 *
reinterpret_cast<const Tc *
>(&A.wi_marray),
388 *
reinterpret_cast<const Tc *
>(&B.wi_marray),
389 *
reinterpret_cast<const int32x4 *
>(&C.wi_marray), 0, 0, 0);
390 std::memcpy(&D.wi_marray, &result, 4 *
sizeof(int32_t));
391 }
else if constexpr (M == 32 && N == 32) {
392 auto result = __builtin_amdgcn_mfma_i32_32x32x8i8(
393 *
reinterpret_cast<const Tc *
>(&A.wi_marray),
394 *
reinterpret_cast<const Tc *
>(&B.wi_marray),
395 *
reinterpret_cast<const int32x16 *
>(&C.wi_marray), 0, 0, 0);
396 std::memcpy(&D.wi_marray, &result, 16 *
sizeof(int32_t));
#define __SYCL_JOINT_MATRIX_OVERLOAD_ARR(TYPE, USE, M, N, SIZE)
#define __SYCL_JOINT_MATRIX_OVERLOAD_ARR_ACC(TYPE, M, N)
sycl::ext::oneapi::bfloat16 bfloat16
void joint_matrix_store_hip(const joint_matrix_hip< T, sycl::ext::oneapi::experimental::matrix::use::accumulator, M, N, sycl::ext::oneapi::experimental::matrix::layout::dynamic > &src, multi_ptr< T, Space, IsDecorated > dst, size_t stride, sycl::ext::oneapi::experimental::matrix::layout layout, Group &sg)
void store_layoutT(const joint_matrix_hip< T, sycl::ext::oneapi::experimental::matrix::use::accumulator, M, N, sycl::ext::oneapi::experimental::matrix::layout::dynamic > &src, multi_ptr< T, Space, IsDecorated > dst, size_t stride, Group &sg)
__attribute__((__vector_size__(4 *sizeof(int32_t)))) int int32x4
void joint_matrix_mad_hip(joint_matrix_hip< Tc, sycl::ext::oneapi::experimental::matrix::use::accumulator, M, N, sycl::ext::oneapi::experimental::matrix::layout::dynamic > &D, const joint_matrix_hip< Tm, sycl::ext::oneapi::experimental::matrix::use::a, M, K, LayoutA > &A, const joint_matrix_hip< Tm, sycl::ext::oneapi::experimental::matrix::use::b, K, N, LayoutB > &B, const joint_matrix_hip< Tc, sycl::ext::oneapi::experimental::matrix::use::accumulator, M, N, sycl::ext::oneapi::experimental::matrix::layout::dynamic > &C)
__attribute__((__vector_size__(4 *sizeof(__bf16)))) __fp16 bfloat16x4
void load_accumulator_hip(joint_matrix_hip< S, sycl::ext::oneapi::experimental::matrix::use::accumulator, M, N, sycl::ext::oneapi::experimental::matrix::layout::dynamic > &res, multi_ptr< T, Space, IsDecorated > src, size_t stride, sycl::ext::oneapi::experimental::matrix::layout layout, Group &sg)
constexpr int WAVEFRONT_SIZE
void load_multiplicand_hip(joint_matrix_hip< S, Use, M, N, Layout > &res, multi_ptr< T, Space, IsDecorated > src, size_t stride, Group &sg)
__attribute__((__vector_size__(4 *sizeof(float)))) float floatx4
__attribute__((__vector_size__(4 *sizeof(double)))) double doublex4
void load_accumulator_layoutT(joint_matrix_hip< S, sycl::ext::oneapi::experimental::matrix::use::accumulator, M, N, sycl::ext::oneapi::experimental::matrix::layout::dynamic > &res, multi_ptr< T, Space, IsDecorated > src, size_t stride, Group &sg)
__attribute__((__vector_size__(16 *sizeof(float)))) float floatx16
__attribute__((__vector_size__(4 *sizeof(__fp16)))) __fp16 float16x4
__attribute__((__vector_size__(16 *sizeof(int32_t)))) int int32x16
__attribute__((always_inline)) auto invoke_simd(sycl
The invoke_simd free function invokes a SIMD function using all work-items in a sub_group.
sycl::detail::half_impl::half half