DPC++ Runtime
Runtime libraries for oneAPI DPC++
matrix-hip.hpp
Go to the documentation of this file.
1 
2 //===-------- matrix-hip.hpp - matrix ext impl ---*- C++ -*-------===//
3 //
4 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
5 // See https://llvm.org/LICENSE.txt for license information.
6 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
7 //
8 // ===-------------------------------------------------------------------=== //
9 
10 #pragma once
11 
12 #include "matrix-unified-utils.hpp"
13 
14 #include <sycl/access/access.hpp>
16 #include <sycl/marray.hpp>
17 #include <sycl/multi_ptr.hpp>
18 
19 #include <cstring>
20 
21 #define __HIP_PLATFORM_AMD_MFMA__
22 
23 namespace sycl {
24 inline namespace _V1 {
25 namespace ext {
26 namespace oneapi {
27 namespace detail {
28 
/// Number of work-items in a wavefront (the HIP sub-group) assumed by this
/// implementation; accumulator fragments give each work-item
/// (M * N) / WAVEFRONT_SIZE elements.
/// Declared `inline constexpr` so this header-scope constant has a single
/// definition shared by every translation unit that includes the header.
inline constexpr int WAVEFRONT_SIZE = 64;
30 
// Primary template of the per-work-item storage behind a HIP joint_matrix.
// Its parameters are:
//  - element type `T` and matrix role `Use`;
//  - logical shape `Rows` x `Cols`;
//  - a layout parameter defaulting to `layout::dynamic`;
//  - a trailing `Cond` type, which the partial specializations below use as
//    their std::enable_if_t hook.
31 template <typename T, sycl::ext::oneapi::experimental::matrix::use Use,
32  size_t Rows, size_t Cols,
34  sycl::ext::oneapi::experimental::matrix::layout::dynamic,
35  typename Cond = void>
37 
// GNU vector types used to reinterpret `wi_marray` storage as the vector
// operands and results of the __builtin_amdgcn_mfma_* calls in
// joint_matrix_mad_hip.
// NOTE(review): bfloat16x4 is sized with sizeof(__bf16) but declared with
// __fp16 elements. The byte size only agrees because both are 16-bit types in
// clang. Confirm this element type is what the bf16_1k MFMA builtins accept.
38 using bfloat16x4 = __attribute__((__vector_size__(4 * sizeof(__bf16)))) __fp16;
39 using float16x4 = __attribute__((__vector_size__(4 * sizeof(__fp16)))) __fp16;
40 using floatx4 = __attribute__((__vector_size__(4 * sizeof(float)))) float;
41 using floatx16 = __attribute__((__vector_size__(16 * sizeof(float)))) float;
42 using int32x4 = __attribute__((__vector_size__(4 * sizeof(int32_t)))) int;
43 using int32x16 = __attribute__((__vector_size__(16 * sizeof(int32_t)))) int;
44 using doublex4 = __attribute__((__vector_size__(4 * sizeof(double)))) double;
45 
46 template <typename T> struct to_hip_type {
47  using type = T;
48 };
49 
50 template <> struct to_hip_type<bfloat16> {
51  using type = __bf16;
52 };
53 
54 template <> struct to_hip_type<half> {
55  using type = __fp16;
56 };
57 
58 template <> struct to_hip_type<int8_t> {
59  using type = int32_t;
60 };
61 
// Generates the joint_matrix_hip storage for a multiplicand fragment:
//   ELEM  - element type,   ROLE - matrix use (a or b),
//   ROWS  - rows,           COLS - columns,
//   COUNT - number of elements each work-item keeps in `wi_marray`.
// The partial specialization is only selected for the row_major and
// col_major layouts.
#undef __SYCL_JOINT_MATRIX_OVERLOAD_ARR

#define __SYCL_JOINT_MATRIX_OVERLOAD_ARR(ELEM, ROLE, ROWS, COLS, COUNT)       \
  template <sycl::ext::oneapi::experimental::matrix::layout Layout>           \
  struct joint_matrix_hip<                                                    \
      ELEM, sycl::ext::oneapi::experimental::matrix::use::ROLE, ROWS, COLS,   \
      Layout,                                                                 \
      std::enable_if_t<                                                       \
          (Layout ==                                                          \
           sycl::ext::oneapi::experimental::matrix::layout::col_major) ||     \
          (Layout ==                                                          \
           sycl::ext::oneapi::experimental::matrix::layout::row_major)>> {    \
    sycl::marray<ELEM, COUNT> wi_marray;                                      \
  };
75 
// Multiplicand (use::a / use::b) storage specializations. The macro arguments
// are (element type, use, rows, cols, elements per work-item). The bfloat16
// and int8_t fragments keep 4 elements per work-item; the double fragments
// keep 1.
76 __SYCL_JOINT_MATRIX_OVERLOAD_ARR(bfloat16, a, 16, 16, 4)
77 __SYCL_JOINT_MATRIX_OVERLOAD_ARR(bfloat16, b, 16, 16, 4)
78 __SYCL_JOINT_MATRIX_OVERLOAD_ARR(bfloat16, a, 32, 8, 4)
79 __SYCL_JOINT_MATRIX_OVERLOAD_ARR(bfloat16, b, 8, 32, 4)
80 
85 
86 __SYCL_JOINT_MATRIX_OVERLOAD_ARR(double, a, 16, 4, 1)
87 __SYCL_JOINT_MATRIX_OVERLOAD_ARR(double, b, 4, 16, 1)
88 
89 __SYCL_JOINT_MATRIX_OVERLOAD_ARR(int8_t, a, 32, 8, 4)
90 __SYCL_JOINT_MATRIX_OVERLOAD_ARR(int8_t, b, 8, 32, 4)
91 __SYCL_JOINT_MATRIX_OVERLOAD_ARR(int8_t, a, 16, 16, 4)
92 __SYCL_JOINT_MATRIX_OVERLOAD_ARR(int8_t, b, 16, 16, 4)
93 
// Drop the helper macro so it does not leak to code including this header.
94 #undef __SYCL_JOINT_MATRIX_OVERLOAD_ARR
95 
// Generates the joint_matrix_hip storage for an accumulator fragment of
// element type TYPE and shape M x N. Accumulators always use
// layout::dynamic.
// The M * N elements are split evenly across the WAVEFRONT_SIZE work-items,
// so each work-item keeps (M * N) / WAVEFRONT_SIZE of them. The static_assert
// rejects shapes that would silently truncate that division.
// Macro arguments are parenthesized so that expression arguments (e.g.
// `8 + 8`) are evaluated before the multiplication.
#define __SYCL_JOINT_MATRIX_OVERLOAD_ARR_ACC(TYPE, M, N)                      \
  template <>                                                                 \
  struct joint_matrix_hip<                                                    \
      TYPE, sycl::ext::oneapi::experimental::matrix::use::accumulator, M, N,  \
      sycl::ext::oneapi::experimental::matrix::layout::dynamic> {             \
    static_assert(((M) * (N)) % WAVEFRONT_SIZE == 0,                          \
                  "accumulator size must be a multiple of WAVEFRONT_SIZE");   \
    sycl::marray<TYPE, ((M) * (N)) / WAVEFRONT_SIZE> wi_marray;               \
  };
103 
109 
110 #undef __SYCL_JOINT_MATRIX_OVERLOAD_ARR_ACC
111 
112 template <sycl::ext::oneapi::experimental::matrix::layout Layout, typename S,
113  typename T, size_t M, size_t N, access::address_space Space,
114  access::decorated IsDecorated, typename Group>
117  S, sycl::ext::oneapi::experimental::matrix::use::accumulator, M, N,
118  sycl::ext::oneapi::experimental::matrix::layout::dynamic> &res,
119  multi_ptr<T, Space, IsDecorated> src, size_t stride, Group &sg) {
120  const auto idx = sg.get_group_linear_id() * sg.get_local_range()[0] +
121  sg.get_local_linear_id();
122 
123  if constexpr (std::is_same_v<S, double>) {
124  const auto thread_x = idx % N;
125  const auto thread_y = idx / N;
126 
127  if constexpr (Layout ==
128  sycl::ext::oneapi::experimental::matrix::layout::row_major) {
129  for (int i = 0; i < 4; ++i) {
130  const int s_idx = thread_x + i * 4 * stride + thread_y * stride;
131  res.wi_marray[i] = src[s_idx];
132  }
133  } else {
134  for (int i = 0; i < 4; ++i) {
135  const int s_idx = i * 4 + thread_x * stride + thread_y;
136  res.wi_marray[i] = src[s_idx];
137  }
138  }
139  } else if constexpr (std::is_same_v<S, float> || std::is_same_v<S, int32_t>) {
140  if constexpr (M == 16 && N == 16) {
141  const auto thread_x = idx % N;
142  const auto thread_y = idx / N;
143 
144  if constexpr (Layout == sycl::ext::oneapi::experimental::matrix::layout::
145  row_major) {
146  for (int i = 0; i < 4; ++i) {
147  const int s_idx = thread_x + i * stride + thread_y * 4 * stride;
148  res.wi_marray[i] = src[s_idx];
149  }
150  } else {
151  for (int i = 0; i < 4; ++i) {
152  const int s_idx = i + thread_x * stride + thread_y * 4;
153  res.wi_marray[i] = src[s_idx];
154  }
155  }
156  } else if constexpr (M == 32 && N == 32) {
157  const auto thread_x = idx % N;
158  const auto thread_y = idx / N;
159 
160  if constexpr (Layout == sycl::ext::oneapi::experimental::matrix::layout::
161  row_major) {
162  for (int j = 0; j < 4; ++j) {
163  for (int i = 0; i < 4; ++i) {
164  const int s_idx =
165  thread_x + i * stride + thread_y * 4 * stride + j * 8 * N;
166  res.wi_marray[i + 4 * j] = src[s_idx];
167  }
168  }
169  } else {
170  for (int j = 0; j < 4; ++j) {
171  for (int i = 0; i < 4; ++i) {
172  const int s_idx = i + thread_x * stride + thread_y * 4 + j * 8;
173  res.wi_marray[i + 4 * j] = src[s_idx];
174  }
175  }
176  }
177  }
178  }
179 }
180 
// Loads an accumulator fragment whose memory layout is a run-time value.
// When `layout == row_major`, it forwards `res`, `src`, `stride` and `sg` to
// the loader instantiated with a compile-time row_major layout. Every other
// value, including layout::dynamic, is forwarded to the col_major
// instantiation.
// Only enabled when S is T with any const qualifier removed, so `src` may
// point to const data.
181 template <
182  typename Group, typename S, typename T, size_t M, size_t N,
183  access::address_space Space, access::decorated IsDecorated,
184  typename = std::enable_if_t<std::is_same_v<S, std::remove_const_t<T>>>>
187  S, sycl::ext::oneapi::experimental::matrix::use::accumulator, M, N,
188  sycl::ext::oneapi::experimental::matrix::layout::dynamic> &res,
189  multi_ptr<T, Space, IsDecorated> src, size_t stride,
191  if (layout == sycl::ext::oneapi::experimental::matrix::layout::row_major)
193  sycl::ext::oneapi::experimental::matrix::layout::row_major>(res, src,
194  stride, sg);
195  else
197  sycl::ext::oneapi::experimental::matrix::layout::col_major>(res, src,
198  stride, sg);
199 }
200 
// Loads an A or B multiplicand fragment into this work-item's `wi_marray`.
// The enable_if below restricts it to a compile-time row_major or col_major
// layout and requires S to be T without const.
//  - double: one element per work-item. row_major reads src[idx]; col_major
//    reads src[(idx % M) * stride + idx / M].
//  - other element types: four elements per work-item (i = 0..3). With
//    Dim = 16 when M == 16 and 32 otherwise, x = idx % Dim, y = idx / Dim:
//      col_major reads src[x * stride + i + 4 * y],
//      row_major reads src[x + (i + 4 * y) * stride].
// NOTE(review): the double row_major path never uses `stride`, so it only
// reads the intended elements when the source is densely packed. Confirm.
// NOTE(review): idx adds sg.get_group_linear_id() * sg.get_local_range()[0].
// For a non-zero group index it points outside the tile, so this presumably
// assumes a single sub-group per work-group. Confirm with the callers.
201 template <
202  typename Group, typename S, typename T, size_t M, size_t N,
205  access::address_space Space, access::decorated IsDecorated,
206  typename = typename std::enable_if_t<
207  (Layout == sycl::ext::oneapi::experimental::matrix::layout::row_major ||
208  Layout ==
209  sycl::ext::oneapi::experimental::matrix::layout::col_major) &&
210  std::is_same_v<S, std::remove_const_t<T>>>>
212  multi_ptr<T, Space, IsDecorated> src, size_t stride,
213  Group &sg) {
214  const auto idx = sg.get_group_linear_id() * sg.get_local_range()[0] +
215  sg.get_local_linear_id();
216 
217  if constexpr (std::is_same_v<S, double>) {
218  if constexpr (Layout ==
219  sycl::ext::oneapi::experimental::matrix::layout::row_major) {
220  res.wi_marray[0] = src[idx];
221  } else {
222  res.wi_marray[0] = src[(idx % M) * stride + idx / M];
223  }
224  } else {
225  constexpr int Dim = (M == 16) ? 16 : 32;
226 
227  const auto thread_x = idx % Dim;
228  const auto thread_y = idx / Dim;
229 
230  if constexpr (Layout ==
231  sycl::ext::oneapi::experimental::matrix::layout::col_major) {
232  for (int i = 0; i < 4; ++i) {
233  const int c_idx = thread_x * stride + i + thread_y * 4;
234  res.wi_marray[i] = src[c_idx];
235  }
236  } else {
237  for (int i = 0; i < 4; ++i) {
238  const int r_idx = thread_x + i * stride + thread_y * stride * 4;
239  res.wi_marray[i] = src[r_idx];
240  }
241  }
242  }
243 }
244 
// Stores an accumulator fragment to `dst` for a compile-time `Layout`
// (row_major; any other value is written as col_major).
// With x = idx % N and y = idx / N:
//  - double: element i goes to row 4 * i + y, column x;
//  - float / int32_t, 16x16: element i goes to row i + 4 * y, column x;
//  - float / int32_t, 32x32: element i + 4 * j goes to row i + 4 * y + 8 * j,
//    column x.
// Element (row, col) is written to row * stride + col for row_major and to
// col * stride + row otherwise. For float/int32_t, shapes other than 16x16 and
// 32x32 write nothing; other element types also write nothing.
// NOTE(review): idx adds sg.get_group_linear_id() * sg.get_local_range()[0],
// so this presumably assumes a single sub-group per work-group. Confirm.
245 template <typename Group,
247  size_t M, size_t N, access::address_space Space,
248  access::decorated IsDecorated>
250  const joint_matrix_hip<
251  T, sycl::ext::oneapi::experimental::matrix::use::accumulator, M, N,
252  sycl::ext::oneapi::experimental::matrix::layout::dynamic> &src,
253  multi_ptr<T, Space, IsDecorated> dst, size_t stride, Group &sg) {
254  const auto idx = sg.get_group_linear_id() * sg.get_local_range()[0] +
255  sg.get_local_linear_id();
256 
257  if constexpr (std::is_same_v<T, double>) {
258  const auto thread_x = idx % N;
259  const auto thread_y = idx / N;
260 
261  if constexpr (Layout ==
262  sycl::ext::oneapi::experimental::matrix::layout::row_major) {
263  for (int i = 0; i < 4; ++i) {
264  const int d_idx = thread_x + i * 4 * stride + thread_y * stride;
265  dst[d_idx] = src.wi_marray[i];
266  }
267  } else {
268  for (int i = 0; i < 4; ++i) {
269  const int d_idx = i * 4 + thread_x * stride + thread_y;
270  dst[d_idx] = src.wi_marray[i];
271  }
272  }
273  } else if constexpr (std::is_same_v<T, float> || std::is_same_v<T, int32_t>) {
274  if constexpr (M == 16 && N == 16) {
275  const auto thread_x = idx % N;
276  const auto thread_y = idx / N;
277 
278  if constexpr (Layout == sycl::ext::oneapi::experimental::matrix::layout::
279  row_major) {
280  for (int i = 0; i < 4; ++i) {
281  const int d_idx = thread_x + i * stride + thread_y * 4 * stride;
282  dst[d_idx] = src.wi_marray[i];
283  }
284  } else {
285  for (int i = 0; i < 4; ++i) {
286  const int d_idx = i + thread_x * stride + thread_y * 4;
287  dst[d_idx] = src.wi_marray[i];
288  }
289  }
290  } else if constexpr (M == 32 && N == 32) {
291  const auto thread_x = idx % N;
292  const auto thread_y = idx / N;
293 
294  if constexpr (Layout == sycl::ext::oneapi::experimental::matrix::layout::
295  row_major) {
296  for (int j = 0; j < 4; ++j) {
297  for (int i = 0; i < 4; ++i) {
298  const int d_idx =
299  thread_x + i * stride + thread_y * 4 * stride + j * 8 * stride;
300  dst[d_idx] = src.wi_marray[i + 4 * j];
301  }
302  }
303  } else {
304  for (int j = 0; j < 4; ++j) {
305  for (int i = 0; i < 4; ++i) {
306  const int d_idx = i + thread_x * stride + thread_y * 4 + j * 8;
307  dst[d_idx] = src.wi_marray[i + 4 * j];
308  }
309  }
310  }
311  }
312  }
313 }
314 
// Stores an accumulator fragment using a layout chosen at run time.
// `layout == row_major` calls store_layoutT<Group, row_major>. Any other value,
// including layout::dynamic, calls store_layoutT<Group, col_major>. Both calls
// receive `src`, `dst`, `stride` and `sg` unchanged.
315 template <typename Group, typename T, size_t M, size_t N,
316  access::address_space Space, access::decorated IsDecorated>
318  const joint_matrix_hip<
319  T, sycl::ext::oneapi::experimental::matrix::use::accumulator, M, N,
320  sycl::ext::oneapi::experimental::matrix::layout::dynamic> &src,
321  multi_ptr<T, Space, IsDecorated> dst, size_t stride,
323  if (sycl::ext::oneapi::experimental::matrix::layout::row_major == layout) {
324  store_layoutT<Group,
325  sycl::ext::oneapi::experimental::matrix::layout::row_major>(
326  src, dst, stride, sg);
327  } else {
328  store_layoutT<Group,
329  sycl::ext::oneapi::experimental::matrix::layout::col_major>(
330  src, dst, stride, sg);
331  }
332 }
333 
// Multiply-accumulates fragments (D = A * B + C) with the gfx90a MFMA
// builtins, then copies the vector result into D.wi_marray with std::memcpy.
// Handled combinations (all require __gfx90a__ to be defined):
//  - Tm = half or bfloat16: 16x16 (floatx4 accumulator) and 32x32 (floatx16).
//  - Tm = double: 16x16 (doublex4 accumulator). A and B pass one element each.
//  - Tm = int8_t: 16x16 (int32x4) and 32x32 (int32x16). The four int8 values
//    of A and B are passed as one packed scalar by reinterpreting wi_marray
//    as `Tc`. NOTE(review): this assumes Tc is a 32-bit integer type; confirm.
// Any other type/shape combination, or a build without __gfx90a__, leaves D
// unmodified.
334 template <typename Tm, typename Tc, std::size_t M, std::size_t K, std::size_t N,
339  Tc, sycl::ext::oneapi::experimental::matrix::use::accumulator, M, N,
340  sycl::ext::oneapi::experimental::matrix::layout::dynamic> &D,
342  M, K, LayoutA> &A,
344  K, N, LayoutB> &B,
345  const joint_matrix_hip<
346  Tc, sycl::ext::oneapi::experimental::matrix::use::accumulator, M, N,
347  sycl::ext::oneapi::experimental::matrix::layout::dynamic> &C) {
348 #ifdef __gfx90a__
349  if constexpr (std::is_same_v<Tm, sycl::half>) {
350  if constexpr (M == 16 && N == 16) {
351  auto result = __builtin_amdgcn_mfma_f32_16x16x16f16(
352  *reinterpret_cast<const float16x4 *>(&A.wi_marray),
353  *reinterpret_cast<const float16x4 *>(&B.wi_marray),
354  *reinterpret_cast<const floatx4 *>(&C.wi_marray), 0, 0, 0);
355  std::memcpy(&D.wi_marray, &result, 4 * sizeof(float));
356  } else if constexpr (M == 32 && N == 32) {
357  auto result = __builtin_amdgcn_mfma_f32_32x32x8f16(
358  *reinterpret_cast<const float16x4 *>(&A.wi_marray),
359  *reinterpret_cast<const float16x4 *>(&B.wi_marray),
360  *reinterpret_cast<const floatx16 *>(&C.wi_marray), 0, 0, 0);
361  std::memcpy(&D.wi_marray, &result, 16 * sizeof(float));
362  }
363  } else if constexpr (std::is_same_v<Tm, bfloat16>) {
364  if constexpr (M == 16 && N == 16) {
365  auto result = __builtin_amdgcn_mfma_f32_16x16x16bf16_1k(
366  *reinterpret_cast<const bfloat16x4 *>(&A.wi_marray),
367  *reinterpret_cast<const bfloat16x4 *>(&B.wi_marray),
368  *reinterpret_cast<const floatx4 *>(&C.wi_marray), 0, 0, 0);
369  std::memcpy(&D.wi_marray, &result, 4 * sizeof(float));
370  } else if constexpr (M == 32 && N == 32) {
371  auto result = __builtin_amdgcn_mfma_f32_32x32x8bf16_1k(
372  *reinterpret_cast<const bfloat16x4 *>(&A.wi_marray),
373  *reinterpret_cast<const bfloat16x4 *>(&B.wi_marray),
374  *reinterpret_cast<const floatx16 *>(&C.wi_marray), 0, 0, 0);
375  std::memcpy(&D.wi_marray, &result, 16 * sizeof(float));
376  }
377  } else if constexpr (std::is_same_v<Tm, double>) {
378  if constexpr (M == 16 && N == 16) {
379  auto result = __builtin_amdgcn_mfma_f64_16x16x4f64(
380  A.wi_marray[0], B.wi_marray[0],
381  *reinterpret_cast<const doublex4 *>(&C.wi_marray), 0, 0, 0);
382  std::memcpy(&D.wi_marray, &result, 4 * sizeof(double));
383  }
384  } else if constexpr (std::is_same_v<Tm, int8_t>) {
385  if constexpr (M == 16 && N == 16) {
386  auto result = __builtin_amdgcn_mfma_i32_16x16x16i8(
387  *reinterpret_cast<const Tc *>(&A.wi_marray),
388  *reinterpret_cast<const Tc *>(&B.wi_marray),
389  *reinterpret_cast<const int32x4 *>(&C.wi_marray), 0, 0, 0);
390  std::memcpy(&D.wi_marray, &result, 4 * sizeof(int32_t));
391  } else if constexpr (M == 32 && N == 32) {
392  auto result = __builtin_amdgcn_mfma_i32_32x32x8i8(
393  *reinterpret_cast<const Tc *>(&A.wi_marray),
394  *reinterpret_cast<const Tc *>(&B.wi_marray),
395  *reinterpret_cast<const int32x16 *>(&C.wi_marray), 0, 0, 0);
396  std::memcpy(&D.wi_marray, &result, 16 * sizeof(int32_t));
397  }
398  }
399 #endif // __gfx90a__
400 }
401 
402 } // namespace detail
403 } // namespace oneapi
404 } // namespace ext
405 } // namespace _V1
406 } // namespace sycl
#define __SYCL_JOINT_MATRIX_OVERLOAD_ARR(TYPE, USE, M, N, SIZE)
Definition: matrix-hip.hpp:64
#define __SYCL_JOINT_MATRIX_OVERLOAD_ARR_ACC(TYPE, M, N)
Definition: matrix-hip.hpp:96
void joint_matrix_store_hip(const joint_matrix_hip< T, sycl::ext::oneapi::experimental::matrix::use::accumulator, M, N, sycl::ext::oneapi::experimental::matrix::layout::dynamic > &src, multi_ptr< T, Space, IsDecorated > dst, size_t stride, sycl::ext::oneapi::experimental::matrix::layout layout, Group &sg)
Definition: matrix-hip.hpp:317
void store_layoutT(const joint_matrix_hip< T, sycl::ext::oneapi::experimental::matrix::use::accumulator, M, N, sycl::ext::oneapi::experimental::matrix::layout::dynamic > &src, multi_ptr< T, Space, IsDecorated > dst, size_t stride, Group &sg)
Definition: matrix-hip.hpp:249
__attribute__((__vector_size__(4 *sizeof(int32_t)))) int int32x4
Definition: matrix-hip.hpp:42
void joint_matrix_mad_hip(joint_matrix_hip< Tc, sycl::ext::oneapi::experimental::matrix::use::accumulator, M, N, sycl::ext::oneapi::experimental::matrix::layout::dynamic > &D, const joint_matrix_hip< Tm, sycl::ext::oneapi::experimental::matrix::use::a, M, K, LayoutA > &A, const joint_matrix_hip< Tm, sycl::ext::oneapi::experimental::matrix::use::b, K, N, LayoutB > &B, const joint_matrix_hip< Tc, sycl::ext::oneapi::experimental::matrix::use::accumulator, M, N, sycl::ext::oneapi::experimental::matrix::layout::dynamic > &C)
Definition: matrix-hip.hpp:337
__attribute__((__vector_size__(4 *sizeof(__bf16)))) __fp16 bfloat16x4
Definition: matrix-hip.hpp:38
void load_accumulator_hip(joint_matrix_hip< S, sycl::ext::oneapi::experimental::matrix::use::accumulator, M, N, sycl::ext::oneapi::experimental::matrix::layout::dynamic > &res, multi_ptr< T, Space, IsDecorated > src, size_t stride, sycl::ext::oneapi::experimental::matrix::layout layout, Group &sg)
Definition: matrix-hip.hpp:185
void load_multiplicand_hip(joint_matrix_hip< S, Use, M, N, Layout > &res, multi_ptr< T, Space, IsDecorated > src, size_t stride, Group &sg)
Definition: matrix-hip.hpp:211
__attribute__((__vector_size__(4 *sizeof(float)))) float floatx4
Definition: matrix-hip.hpp:40
__attribute__((__vector_size__(4 *sizeof(double)))) double doublex4
Definition: matrix-hip.hpp:44
void load_accumulator_layoutT(joint_matrix_hip< S, sycl::ext::oneapi::experimental::matrix::use::accumulator, M, N, sycl::ext::oneapi::experimental::matrix::layout::dynamic > &res, multi_ptr< T, Space, IsDecorated > src, size_t stride, Group &sg)
Definition: matrix-hip.hpp:115
__attribute__((__vector_size__(16 *sizeof(float)))) float floatx16
Definition: matrix-hip.hpp:41
__attribute__((__vector_size__(4 *sizeof(__fp16)))) __fp16 float16x4
Definition: matrix-hip.hpp:39
__attribute__((__vector_size__(16 *sizeof(int32_t)))) int int32x16
Definition: matrix-hip.hpp:43
__attribute__((always_inline)) auto invoke_simd(sycl
The invoke_simd free function invokes a SIMD function using all work-items in a sub_group.
sycl::detail::half_impl::half half
Definition: aliases.hpp:101
Definition: access.hpp:18