21 #include <type_traits>
24 inline namespace _V1 {
27 namespace experimental {
30 template <
typename Group,
typename T,
use Use,
size_t Rows,
size_t Cols,
40 size_t Rows,
size_t Cols,
42 sycl::ext::oneapi::experimental::matrix::layout::dynamic,
46 #define __SYCL_JOINT_MATRIX_OVERLOAD_ARR(TYPE, USE, M, N, SIZE) \
47 template <sycl::ext::oneapi::experimental::matrix::layout Layout> \
48 struct joint_matrix_cuda< \
49 TYPE, sycl::ext::oneapi::experimental::matrix::use::USE, M, N, Layout, \
50 typename std::enable_if_t< \
52 sycl::ext::oneapi::experimental::matrix::layout::row_major || \
54 sycl::ext::oneapi::experimental::matrix::layout::col_major>> { \
55 marray<TYPE, SIZE> wi_marray; \
92 #undef __SYCL_JOINT_MATRIX_OVERLOAD_ARR
94 #define __SYCL_JOINT_MATRIX_OVERLOAD_ARR_ACC(TYPE, M, N, SIZE) \
96 struct joint_matrix_cuda< \
97 TYPE, sycl::ext::oneapi::experimental::matrix::use::accumulator, M, N, \
98 sycl::ext::oneapi::experimental::matrix::layout::dynamic> { \
99 marray<TYPE, SIZE> wi_marray; \
113 #undef __SYCL_JOINT_MATRIX_OVERLOAD_ARR_ACC
115 #define __SYCL_JOINT_MATRIX_OVERLOAD_ARR_PRECISION(PRECISION, USE, M, N, TYPE, \
117 template <sycl::ext::oneapi::experimental::matrix::layout Layout> \
118 struct joint_matrix_cuda< \
119 PRECISION, sycl::ext::oneapi::experimental::matrix::use::USE, M, N, \
121 typename std::enable_if_t< \
123 sycl::ext::oneapi::experimental::matrix::layout::row_major || \
125 sycl::ext::oneapi::experimental::matrix::layout::col_major>> { \
126 marray<TYPE, SIZE> wi_marray; \
133 sycl::ext::oneapi::experimental::matrix::precision::tf32,
b, 8, 16,
float,
136 #undef __SYCL_JOINT_MATRIX_OVERLOAD_ARR_PRECISION
137 #if defined(__SYCL_DEVICE_ONLY__) && defined(__NVPTX__)
138 template <sycl::ext::oneapi::experimental::matrix::layout Layout>
139 constexpr
int get_layout_id();
143 get_layout_id<sycl::ext::oneapi::experimental::matrix::layout::row_major>() {
149 get_layout_id<sycl::ext::oneapi::experimental::matrix::layout::col_major>() {
154 typename T,
size_t NumRows,
size_t NumCols,
158 S, sycl::ext::oneapi::experimental::matrix::use::accumulator, NumRows,
159 NumCols, sycl::ext::oneapi::experimental::matrix::layout::dynamic> &res,
160 multi_ptr<T, Space, IsDecorated> src,
size_t stride) {
161 if constexpr (std::is_same_v<S, int32_t>) {
162 auto destptr =
reinterpret_cast<int32_t *
>(&res.wi_marray);
163 if constexpr (NumRows == 16 && NumCols == 16) {
164 __imma_m16n16k16_ld_c(destptr, src.get(), stride,
165 get_layout_id<Layout>());
166 }
else if constexpr (NumRows == 8 && NumCols == 32) {
167 __imma_m8n32k16_ld_c(destptr, src.get(), stride, get_layout_id<Layout>());
168 }
else if constexpr (NumRows == 32 && NumCols == 8) {
169 __imma_m32n8k16_ld_c(destptr, src.get(), stride, get_layout_id<Layout>());
171 }
else if constexpr (std::is_same_v<S, float>) {
172 auto dstptr =
reinterpret_cast<float *
>(&res.wi_marray);
173 if constexpr (NumRows == 16 && NumCols == 16) {
174 __hmma_m16n16k16_ld_c_f32(dstptr, src.get(), stride,
175 get_layout_id<Layout>());
176 }
else if constexpr (NumRows == 8 && NumCols == 32) {
177 __hmma_m8n32k16_ld_c_f32(dstptr, src.get(), stride,
178 get_layout_id<Layout>());
179 }
else if constexpr (NumRows == 32 && NumCols == 8) {
180 __hmma_m32n8k16_ld_c_f32(dstptr, src.get(), stride,
181 get_layout_id<Layout>());
183 }
else if constexpr (std::is_same_v<S, half>) {
184 auto tileptr =
reinterpret_cast<const int32_t *
>(src.get());
185 auto dstptr =
reinterpret_cast<int32_t *
>(&res.wi_marray);
186 if constexpr (NumRows == 32 && NumCols == 8) {
187 __hmma_m32n8k16_ld_c_f16(dstptr, tileptr, stride,
188 get_layout_id<Layout>());
189 }
else if constexpr (NumRows == 8 && NumCols == 32) {
190 __hmma_m8n32k16_ld_c_f16(dstptr, tileptr, stride,
191 get_layout_id<Layout>());
192 }
else if constexpr (NumRows == 16 && NumCols == 16) {
193 __hmma_m16n16k16_ld_c_f16(dstptr, tileptr, stride,
194 get_layout_id<Layout>());
196 }
else if constexpr (std::is_same_v<S, double>) {
197 __dmma_m8n8k4_ld_c(
reinterpret_cast<double *
>(&res.wi_marray), src.get(),
198 stride, get_layout_id<Layout>());
202 template <
typename S,
typename T,
size_t NumRows,
size_t NumCols,
204 void load_accumulator_cuda(
206 S, sycl::ext::oneapi::experimental::matrix::use::accumulator, NumRows,
207 NumCols, sycl::ext::oneapi::experimental::matrix::layout::dynamic> &res,
208 multi_ptr<T, Space, IsDecorated> src,
size_t stride,
211 case sycl::ext::oneapi::experimental::matrix::layout::row_major:
213 sycl::ext::oneapi::experimental::matrix::layout::row_major>(res, src,
216 case sycl::ext::oneapi::experimental::matrix::layout::col_major:
218 sycl::ext::oneapi::experimental::matrix::layout::col_major>(res, src,
222 assert(
false &&
"Invalid layout specified!");
227 typename S,
typename T,
size_t NumRows,
size_t NumCols,
232 Layout == sycl::ext::oneapi::experimental::matrix::layout::row_major ||
234 sycl::ext::oneapi::experimental::matrix::layout::col_major,
236 void load_multiplicand_cuda(
237 joint_matrix_cuda<S, Use, NumRows, NumCols, Layout> &res,
238 multi_ptr<T, Space, IsDecorated> src,
size_t stride) {
239 if constexpr (std::is_same_v<S, sycl::ext::oneapi::bfloat16>) {
240 auto tileptr =
reinterpret_cast<const int32_t *
>(src.get());
241 auto destptr =
reinterpret_cast<int32_t *
>(&res.wi_marray);
242 if constexpr (NumRows == 16 && NumCols == 16) {
244 __mma_bf16_m16n16k16_ld_a(destptr, tileptr, stride,
245 get_layout_id<Layout>());
246 }
else if constexpr (Use ==
248 __mma_bf16_m16n16k16_ld_b(destptr, tileptr, stride,
249 get_layout_id<Layout>());
251 }
else if constexpr (NumRows == 8 && NumCols == 16) {
252 __mma_bf16_m8n32k16_ld_a(destptr, tileptr, stride,
253 get_layout_id<Layout>());
254 }
else if constexpr (NumRows == 16 && NumCols == 32) {
255 __mma_bf16_m8n32k16_ld_b(destptr, tileptr, stride,
256 get_layout_id<Layout>());
257 }
else if constexpr (NumRows == 32 && NumCols == 16) {
258 __mma_bf16_m32n8k16_ld_a(destptr, tileptr, stride,
259 get_layout_id<Layout>());
260 }
else if constexpr (NumRows == 16 && NumCols == 8) {
261 __mma_bf16_m32n8k16_ld_b(destptr, tileptr, stride,
262 get_layout_id<Layout>());
264 }
else if constexpr (std::is_same_v<S, uint8_t>) {
265 auto tileptr =
reinterpret_cast<const int32_t *
>(src.get());
266 auto destptr =
reinterpret_cast<int32_t *
>(&res.wi_marray);
267 if constexpr (NumRows == 16 && NumCols == 16) {
269 __imma_m16n16k16_ld_a_u8(destptr, tileptr, stride,
270 get_layout_id<Layout>());
271 }
else if constexpr (Use ==
273 __imma_m16n16k16_ld_b_u8(destptr, tileptr, stride,
274 get_layout_id<Layout>());
276 }
else if constexpr (NumRows == 8 && NumCols == 16) {
277 __imma_m8n32k16_ld_a_u8(destptr, tileptr, stride,
278 get_layout_id<Layout>());
279 }
else if constexpr (NumRows == 16 && NumCols == 32) {
280 __imma_m8n32k16_ld_b_u8(destptr, tileptr, stride,
281 get_layout_id<Layout>());
282 }
else if constexpr (NumRows == 32 && NumCols == 16) {
283 __imma_m32n8k16_ld_a_u8(destptr, tileptr, stride,
284 get_layout_id<Layout>());
285 }
else if constexpr (NumRows == 16 && NumCols == 8) {
286 __imma_m32n8k16_ld_b_u8(destptr, tileptr, stride,
287 get_layout_id<Layout>());
289 }
else if constexpr (std::is_same_v<S, int8_t>) {
290 auto tileptr =
reinterpret_cast<const int32_t *
>(src.get());
291 auto destptr =
reinterpret_cast<int32_t *
>(&res.wi_marray);
292 if constexpr (NumRows == 16 && NumCols == 16) {
294 __imma_m16n16k16_ld_a_s8(destptr, tileptr, stride,
295 get_layout_id<Layout>());
296 }
else if constexpr (Use ==
298 __imma_m16n16k16_ld_b_s8(destptr, tileptr, stride,
299 get_layout_id<Layout>());
301 }
else if constexpr (NumRows == 8 && NumCols == 16) {
302 __imma_m8n32k16_ld_a_s8(destptr, tileptr, stride,
303 get_layout_id<Layout>());
304 }
else if constexpr (NumRows == 16 && NumCols == 32) {
305 __imma_m8n32k16_ld_b_s8(destptr, tileptr, stride,
306 get_layout_id<Layout>());
307 }
else if constexpr (NumRows == 32 && NumCols == 16) {
308 __imma_m32n8k16_ld_a_s8(destptr, tileptr, stride,
309 get_layout_id<Layout>());
310 }
else if constexpr (NumRows == 16 && NumCols == 8) {
311 __imma_m32n8k16_ld_b_s8(destptr, tileptr, stride,
312 get_layout_id<Layout>());
314 }
else if constexpr (std::is_same_v<S, half>) {
315 auto tileptr =
reinterpret_cast<const int32_t *
>(src.get());
316 auto dstptr =
reinterpret_cast<int32_t *
>(&res.wi_marray);
317 if constexpr (NumRows == 16 && NumCols == 16) {
319 __hmma_m16n16k16_ld_a(dstptr, tileptr, stride, get_layout_id<Layout>());
320 }
else if constexpr (Use ==
322 __hmma_m16n16k16_ld_b(dstptr, tileptr, stride, get_layout_id<Layout>());
324 }
else if constexpr (NumRows == 8 && NumCols == 16) {
325 __hmma_m8n32k16_ld_a(dstptr, tileptr, stride, get_layout_id<Layout>());
326 }
else if constexpr (NumRows == 16 && NumCols == 32) {
327 __hmma_m8n32k16_ld_b(dstptr, tileptr, stride, get_layout_id<Layout>());
328 }
else if constexpr (NumRows == 32 && NumCols == 16) {
329 __hmma_m32n8k16_ld_a(dstptr, tileptr, stride, get_layout_id<Layout>());
330 }
else if constexpr (NumRows == 16 && NumCols == 8) {
331 __hmma_m32n8k16_ld_b(dstptr, tileptr, stride, get_layout_id<Layout>());
334 }
else if constexpr (std::is_same_v<S, sycl::ext::oneapi::experimental::
335 matrix::precision::tf32>) {
336 auto tileptr =
reinterpret_cast<const int32_t *
>(src.get());
337 auto dstptr =
reinterpret_cast<int32_t *
>(&res.wi_marray);
338 if constexpr (NumRows == 16 && NumCols == 8) {
339 __mma_tf32_m16n16k8_ld_a(dstptr, tileptr, stride,
340 get_layout_id<Layout>());
341 }
else if constexpr (NumRows == 8 && NumCols == 16) {
342 __mma_tf32_m16n16k8_ld_b(dstptr, tileptr, stride,
343 get_layout_id<Layout>());
345 }
else if constexpr (std::is_same_v<S, double>) {
346 auto dstptr =
reinterpret_cast<double *
>(&res.wi_marray);
348 __dmma_m8n8k4_ld_a(dstptr, src.get(), stride, get_layout_id<Layout>());
349 }
else if constexpr (Use ==
351 __dmma_m8n8k4_ld_b(dstptr, src.get(), stride, get_layout_id<Layout>());
360 const joint_matrix_cuda<
361 T, sycl::ext::oneapi::experimental::matrix::use::accumulator, NumRows,
362 NumCols, sycl::ext::oneapi::experimental::matrix::layout::dynamic> &src,
363 multi_ptr<T, Space, IsDecorated> dst,
size_t stride) {
364 if constexpr (NumRows == 16 && NumCols == 16) {
365 if constexpr (std::is_same_v<T, float>) {
366 __hmma_m16n16k16_st_c_f32(dst.get(), &src.wi_marray[0], stride,
367 get_layout_id<Layout>());
368 }
else if constexpr (std::is_same_v<T, int32_t>) {
369 __imma_m16n16k16_st_c_i32(dst.get(), &src.wi_marray[0], stride,
370 get_layout_id<Layout>());
371 }
else if constexpr (std::is_same_v<T, half>) {
372 __hmma_m16n16k16_st_c_f16(
373 reinterpret_cast<int32_t *
>(dst.get()),
374 reinterpret_cast<const int32_t *
>(&src.wi_marray[0]), stride,
375 get_layout_id<Layout>());
377 }
else if constexpr (NumRows == 8 && NumCols == 32) {
378 if constexpr (std::is_same_v<T, float>) {
379 __hmma_m8n32k16_st_c_f32(dst.get(), &src.wi_marray[0], stride,
380 get_layout_id<Layout>());
381 }
else if constexpr (std::is_same_v<T, int32_t>) {
382 __imma_m8n32k16_st_c_i32(dst.get(), &src.wi_marray[0], stride,
383 get_layout_id<Layout>());
384 }
else if constexpr (std::is_same_v<T, half>) {
385 __hmma_m8n32k16_st_c_f16(
386 reinterpret_cast<int32_t *
>(dst.get()),
387 reinterpret_cast<const int32_t *
>(&src.wi_marray[0]), stride,
388 get_layout_id<Layout>());
390 }
else if constexpr (NumRows == 32 && NumCols == 8) {
391 if constexpr (std::is_same_v<T, float>) {
392 __hmma_m32n8k16_st_c_f32(dst.get(), &src.wi_marray[0], stride,
393 get_layout_id<Layout>());
394 }
else if constexpr (std::is_same_v<T, int32_t>) {
395 __imma_m32n8k16_st_c_i32(dst.get(), &src.wi_marray[0], stride,
396 get_layout_id<Layout>());
397 }
else if constexpr (std::is_same_v<T, half>) {
398 __hmma_m32n8k16_st_c_f16(
399 reinterpret_cast<int32_t *
>(dst.get()),
400 reinterpret_cast<const int32_t *
>(&src.wi_marray[0]), stride,
401 get_layout_id<Layout>());
403 }
else if constexpr (std::is_same_v<T, double>) {
404 __dmma_m8n8k4_st_c_f64(dst.get(), &src.wi_marray[0], stride,
405 get_layout_id<Layout>());
409 template <
typename T,
size_t NumRows,
size_t NumCols,
411 void joint_matrix_store_cuda(
412 const joint_matrix_cuda<
413 T, sycl::ext::oneapi::experimental::matrix::use::accumulator, NumRows,
414 NumCols, sycl::ext::oneapi::experimental::matrix::layout::dynamic> &src,
415 multi_ptr<T, Space, IsDecorated> dst,
size_t stride,
418 case sycl::ext::oneapi::experimental::matrix::layout::row_major:
419 store_layoutT<sycl::ext::oneapi::experimental::matrix::layout::row_major>(
422 case sycl::ext::oneapi::experimental::matrix::layout::col_major:
423 store_layoutT<sycl::ext::oneapi::experimental::matrix::layout::col_major>(
427 assert(
false &&
"Invalid layout specified!");
433 constexpr
int get_layout_pair_id();
436 constexpr
int get_layout_pair_id<
437 sycl::ext::oneapi::experimental::matrix::layout::row_major,
438 sycl::ext::oneapi::experimental::matrix::layout::row_major>() {
443 constexpr
int get_layout_pair_id<
444 sycl::ext::oneapi::experimental::matrix::layout::row_major,
445 sycl::ext::oneapi::experimental::matrix::layout::col_major>() {
450 constexpr
int get_layout_pair_id<
451 sycl::ext::oneapi::experimental::matrix::layout::col_major,
452 sycl::ext::oneapi::experimental::matrix::layout::row_major>() {
457 constexpr
int get_layout_pair_id<
458 sycl::ext::oneapi::experimental::matrix::layout::col_major,
459 sycl::ext::oneapi::experimental::matrix::layout::col_major>() {
464 typename Tm,
typename Tc,
typename Td, std::size_t M, std::size_t K,
469 sycl::ext::oneapi::experimental::matrix::layout::row_major ||
471 sycl::ext::oneapi::experimental::matrix::layout::col_major) &&
473 sycl::ext::oneapi::experimental::matrix::layout::row_major ||
475 sycl::ext::oneapi::experimental::matrix::layout::col_major),
477 void joint_matrix_mad_cuda(
479 Td, sycl::ext::oneapi::experimental::matrix::use::accumulator, M, N,
480 sycl::ext::oneapi::experimental::matrix::layout::dynamic> &D,
485 const joint_matrix_cuda<
486 Tc, sycl::ext::oneapi::experimental::matrix::use::accumulator, M, N,
487 sycl::ext::oneapi::experimental::matrix::layout::dynamic> &C) {
488 if constexpr (M == 16 && N == 16 && K == 16) {
489 if constexpr (std::is_same_v<Tc, int32_t>) {
490 auto ptrA =
reinterpret_cast<const int32_t *
>(&
A.wi_marray);
491 auto ptrB =
reinterpret_cast<const int32_t *
>(&
B.wi_marray);
492 auto ptrC =
reinterpret_cast<const int32_t *
>(&C.wi_marray);
493 auto ptrD =
reinterpret_cast<int32_t *
>(&D.wi_marray);
494 if constexpr (std::is_same_v<Tm, int8_t>) {
495 __imma_m16n16k16_mma_s8(ptrD, ptrA, ptrB, ptrC,
496 get_layout_pair_id<LayoutA, LayoutB>(), 0);
497 }
else if constexpr (std::is_same_v<Tm, uint8_t>) {
498 __imma_m16n16k16_mma_u8(ptrD, ptrA, ptrB, ptrC,
499 get_layout_pair_id<LayoutA, LayoutB>(), 0);
501 }
else if constexpr (std::is_same_v<Tm, half>) {
502 auto ptrA =
reinterpret_cast<const int32_t *
>(&
A.wi_marray);
503 auto ptrB =
reinterpret_cast<const int32_t *
>(&
B.wi_marray);
504 if constexpr (std::is_same_v<Tc, float>) {
505 if constexpr (std::is_same<Td, float>::value) {
506 __hmma_m16n16k16_mma_f32f32(
507 reinterpret_cast<float *
>(&D.wi_marray), ptrA, ptrB,
508 reinterpret_cast<const float *
>(&C.wi_marray),
509 get_layout_pair_id<LayoutA, LayoutB>(), 0);
511 __hmma_m16n16k16_mma_f16f32(
512 reinterpret_cast<int32_t *
>(&D.wi_marray), ptrA, ptrB,
513 reinterpret_cast<const float *
>(&C.wi_marray),
514 get_layout_pair_id<LayoutA, LayoutB>(), 0);
516 }
else if constexpr (std::is_same_v<Tc, half>) {
517 if constexpr (std::is_same<Td, float>::value) {
518 __hmma_m16n16k16_mma_f32f16(
519 reinterpret_cast<float *
>(&D.wi_marray), ptrA, ptrB,
520 reinterpret_cast<const int32_t *
>(&C.wi_marray),
521 get_layout_pair_id<LayoutA, LayoutB>(), 0);
523 __hmma_m16n16k16_mma_f16f16(
524 reinterpret_cast<int32_t *
>(&D.wi_marray), ptrA, ptrB,
525 reinterpret_cast<const int32_t *
>(&C.wi_marray),
526 get_layout_pair_id<LayoutA, LayoutB>(), 0);
529 }
else if constexpr (std::is_same_v<Tm, sycl::ext::oneapi::bfloat16>) {
530 __mma_bf16_m16n16k16_mma_f32(
531 reinterpret_cast<float *
>(&D.wi_marray),
532 reinterpret_cast<const int32_t *
>(&
A.wi_marray),
533 reinterpret_cast<const int32_t *
>(&
B.wi_marray),
534 reinterpret_cast<const float *
>(&C.wi_marray),
535 get_layout_pair_id<LayoutA, LayoutB>(), 0);
537 }
else if constexpr (M == 8 && N == 32 && K == 16) {
538 if constexpr (std::is_same_v<Tc, int32_t>) {
539 auto ptrA =
reinterpret_cast<const int32_t *
>(&
A.wi_marray);
540 auto ptrB =
reinterpret_cast<const int32_t *
>(&
B.wi_marray);
541 auto ptrC =
reinterpret_cast<const int32_t *
>(&C.wi_marray);
542 auto ptrD =
reinterpret_cast<int32_t *
>(&D.wi_marray);
543 if constexpr (std::is_same_v<Tm, int8_t>) {
544 __imma_m8n32k16_mma_s8(ptrD, ptrA, ptrB, ptrC,
545 get_layout_pair_id<LayoutA, LayoutB>(), 0);
546 }
else if constexpr (std::is_same_v<Tm, uint8_t>) {
547 __imma_m8n32k16_mma_u8(ptrD, ptrA, ptrB, ptrC,
548 get_layout_pair_id<LayoutA, LayoutB>(), 0);
550 }
else if constexpr (std::is_same_v<Tm, half>) {
551 auto ptrA =
reinterpret_cast<const int32_t *
>(&
A.wi_marray);
552 auto ptrB =
reinterpret_cast<const int32_t *
>(&
B.wi_marray);
553 if constexpr (std::is_same_v<Tc, float>) {
554 if constexpr (std::is_same<Td, float>::value) {
555 __hmma_m8n32k16_mma_f32f32(
556 reinterpret_cast<float *
>(&D.wi_marray), ptrA, ptrB,
557 reinterpret_cast<const float *
>(&C.wi_marray),
558 get_layout_pair_id<LayoutA, LayoutB>(), 0);
560 __hmma_m8n32k16_mma_f16f32(
561 reinterpret_cast<int32_t *
>(&D.wi_marray), ptrA, ptrB,
562 reinterpret_cast<const float *
>(&C.wi_marray),
563 get_layout_pair_id<LayoutA, LayoutB>(), 0);
565 }
else if constexpr (std::is_same_v<Tc, half>) {
566 if constexpr (std::is_same<Td, float>::value) {
567 __hmma_m8n32k16_mma_f32f16(
568 reinterpret_cast<float *
>(&D.wi_marray), ptrA, ptrB,
569 reinterpret_cast<const int32_t *
>(&C.wi_marray),
570 get_layout_pair_id<LayoutA, LayoutB>(), 0);
572 __hmma_m8n32k16_mma_f16f16(
573 reinterpret_cast<int32_t *
>(&D.wi_marray), ptrA, ptrB,
574 reinterpret_cast<const int32_t *
>(&C.wi_marray),
575 get_layout_pair_id<LayoutA, LayoutB>(), 0);
578 }
else if constexpr (std::is_same_v<Tm, sycl::ext::oneapi::bfloat16>) {
579 __mma_bf16_m8n32k16_mma_f32(
580 reinterpret_cast<float *
>(&D.wi_marray),
581 reinterpret_cast<const int32_t *
>(&
A.wi_marray),
582 reinterpret_cast<const int32_t *
>(&
B.wi_marray),
583 reinterpret_cast<const float *
>(&C.wi_marray),
584 get_layout_pair_id<LayoutA, LayoutB>(), 0);
586 }
else if constexpr (M == 32 && N == 8 && K == 16) {
587 if constexpr (std::is_same_v<Tc, int32_t>) {
588 auto ptrA =
reinterpret_cast<const int32_t *
>(&
A.wi_marray);
589 auto ptrB =
reinterpret_cast<const int32_t *
>(&
B.wi_marray);
590 auto ptrC =
reinterpret_cast<const int32_t *
>(&C.wi_marray);
591 auto ptrD =
reinterpret_cast<int32_t *
>(&D.wi_marray);
592 if constexpr (std::is_same_v<Tm, int8_t>) {
593 __imma_m32n8k16_mma_s8(ptrD, ptrA, ptrB, ptrC,
594 get_layout_pair_id<LayoutA, LayoutB>(), 0);
595 }
else if constexpr (std::is_same_v<Tm, uint8_t>) {
596 __imma_m32n8k16_mma_u8(ptrD, ptrA, ptrB, ptrC,
597 get_layout_pair_id<LayoutA, LayoutB>(), 0);
599 }
else if constexpr (std::is_same_v<Tm, sycl::ext::oneapi::bfloat16>) {
600 __mma_bf16_m32n8k16_mma_f32(
601 reinterpret_cast<float *
>(&D.wi_marray),
602 reinterpret_cast<const int32_t *
>(&
A.wi_marray),
603 reinterpret_cast<const int32_t *
>(&
B.wi_marray),
604 reinterpret_cast<const float *
>(&C.wi_marray),
605 get_layout_pair_id<LayoutA, LayoutB>(), 0);
606 }
else if constexpr (std::is_same_v<Tm, half>) {
608 auto ptrA =
reinterpret_cast<const int32_t *
>(&
A.wi_marray);
609 auto ptrB =
reinterpret_cast<const int32_t *
>(&
B.wi_marray);
610 if constexpr (std::is_same_v<Tc, float>) {
611 if constexpr (std::is_same<Td, float>::value) {
612 __hmma_m32n8k16_mma_f32f32(
613 reinterpret_cast<float *
>(&D.wi_marray), ptrA, ptrB,
614 reinterpret_cast<const float *
>(&C.wi_marray),
615 get_layout_pair_id<LayoutA, LayoutB>(), 0);
617 __hmma_m32n8k16_mma_f16f32(
618 reinterpret_cast<int32_t *
>(&D.wi_marray), ptrA, ptrB,
619 reinterpret_cast<const float *
>(&C.wi_marray),
620 get_layout_pair_id<LayoutA, LayoutB>(), 0);
622 }
else if constexpr (std::is_same_v<Tc, half>) {
623 if constexpr (std::is_same<Td, float>::value) {
624 __hmma_m32n8k16_mma_f32f16(
625 reinterpret_cast<float *
>(&D.wi_marray), ptrA, ptrB,
626 reinterpret_cast<const int32_t *
>(&C.wi_marray),
627 get_layout_pair_id<LayoutA, LayoutB>(), 0);
629 __hmma_m32n8k16_mma_f16f16(
630 reinterpret_cast<int32_t *
>(&D.wi_marray), ptrA, ptrB,
631 reinterpret_cast<const int32_t *
>(&C.wi_marray),
632 get_layout_pair_id<LayoutA, LayoutB>(), 0);
636 }
else if constexpr (M == 16 && N == 16 && K == 8) {
637 __mma_tf32_m16n16k8_mma_f32(
reinterpret_cast<float *
>(&D.wi_marray),
638 reinterpret_cast<const int32_t *
>(&
A.wi_marray),
639 reinterpret_cast<const int32_t *
>(&
B.wi_marray),
640 reinterpret_cast<const float *
>(&C.wi_marray),
641 get_layout_pair_id<LayoutA, LayoutB>(), 0);
642 }
else if constexpr (std::is_same_v<Tm, double>) {
643 __dmma_m8n8k4_mma_f64(
reinterpret_cast<double *
>(&D.wi_marray),
644 reinterpret_cast<const double *
>(&
A.wi_marray),
645 reinterpret_cast<const double *
>(&
B.wi_marray),
646 reinterpret_cast<const double *
>(&C.wi_marray),
647 get_layout_pair_id<LayoutA, LayoutB>(), 0);
#define __SYCL_JOINT_MATRIX_OVERLOAD_ARR_PRECISION(PRECISION, USE, M, N, TYPE, SIZE)
#define __SYCL_JOINT_MATRIX_OVERLOAD_ARR_ACC(TYPE, M, N, SIZE)
#define __SYCL_JOINT_MATRIX_OVERLOAD_ARR(TYPE, USE, M, N, SIZE)
void store_layoutT(const joint_matrix_hip< T, sycl::ext::oneapi::experimental::matrix::use::accumulator, M, N, sycl::ext::oneapi::experimental::matrix::layout::dynamic > &src, multi_ptr< T, Space, IsDecorated > dst, size_t stride, Group &sg)
void load_accumulator_layoutT(joint_matrix_hip< S, sycl::ext::oneapi::experimental::matrix::use::accumulator, M, N, sycl::ext::oneapi::experimental::matrix::layout::dynamic > &res, multi_ptr< T, Space, IsDecorated > src, size_t stride, Group &sg)
sycl::detail::half_impl::half half