DPC++ Runtime
Runtime libraries for oneAPI DPC++
matrix-aot-amx.hpp
//===------------ matrix-aot-amx.hpp - SYCL matrix ------------*- C++ -*---===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#pragma once

#include <CL/sycl/detail/defines_elementary.hpp>
#include <immintrin.h>

#include <cstdint>
#include <cstring>
#include <limits>
#include <type_traits>

__SYCL_INLINE_NAMESPACE(cl) {
namespace sycl {
namespace ext {
namespace intel {
namespace detail {
template <typename T> class submatrix {
public:
  _tile1024i tile;
  short rows, cols;
};

// TODO: we are adding it this way until sycl::dynamic_extent gets implemented.
static constexpr size_t dynamic_extent = std::numeric_limits<size_t>::max();

template <typename T> struct elems_per_dword {
  static constexpr size_t value = 1;
};

#define ELEMS_PER_DWORD(TYPE, NUM)                                             \
  template <> struct elems_per_dword<TYPE> {                                   \
    static constexpr size_t value = NUM;                                       \
  };

ELEMS_PER_DWORD(int8_t, 4)
ELEMS_PER_DWORD(unsigned short, 2)
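// One AMX tile row is 64 bytes, i.e. 16 dwords: int8_t packs 4 elements per
// dword and bfloat16 (stored as unsigned short) packs 2, as encoded above.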

} // namespace detail

namespace experimental::matrix {
#ifdef __SYCL_DEVICE_ONLY__
SYCL_EXTERNAL extern "C" _tile1024i
_tileloadd64_internal(short row, short col, char *buf, size_t stride);
SYCL_EXTERNAL extern "C" _tile1024i
_tdpbssd_internal(unsigned short m, unsigned short n, unsigned short k,
                  _tile1024i dst, _tile1024i src1, _tile1024i src2);
SYCL_EXTERNAL extern "C" _tile1024i
_tdpbf16ps_internal(unsigned short m, unsigned short n, unsigned short k,
                    _tile1024i dst, _tile1024i src1, _tile1024i src2);
SYCL_EXTERNAL extern "C" void _tilestored64_internal(short row, short col,
                                                     char *buf, size_t stride,
                                                     _tile1024i tile);
static _tile1024i tileloadd64_internal(short row, short col, char *buf,
                                       size_t stride) {
  return _tileloadd64_internal(row, col, buf, stride);
}
static _tile1024i tdpbssd_internal(unsigned short m, unsigned short n,
                                   unsigned short k, _tile1024i dst,
                                   _tile1024i src1, _tile1024i src2) {
  return _tdpbssd_internal(m, n, k, dst, src1, src2);
}
static _tile1024i tdpbf16ps_internal(unsigned short m, unsigned short n,
                                     unsigned short k, _tile1024i dst,
                                     _tile1024i src1, _tile1024i src2) {
  return _tdpbf16ps_internal(m, n, k, dst, src1, src2);
}
static void tilestored64_internal(short row, short col, char *buf,
                                  size_t stride, _tile1024i tile) {
  return _tilestored64_internal(row, col, buf, stride, tile);
}
#else
static _tile1024i tileloadd64_internal(short row, short col, char *buf,
                                       size_t stride) {
  return __builtin_ia32_tileloadd64_internal(row, col, buf, stride);
}
static _tile1024i tdpbssd_internal(unsigned short m, unsigned short n,
                                   unsigned short k, _tile1024i dst,
                                   _tile1024i src1, _tile1024i src2) {
  return __builtin_ia32_tdpbssd_internal(m, n, k, dst, src1, src2);
}
static _tile1024i tdpbf16ps_internal(unsigned short m, unsigned short n,
                                     unsigned short k, _tile1024i dst,
                                     _tile1024i src1, _tile1024i src2) {
  return __builtin_ia32_tdpbf16ps_internal(m, n, k, dst, src1, src2);
}
static void tilestored64_internal(short row, short col, char *buf,
                                  size_t stride, _tile1024i tile) {
  __builtin_ia32_tilestored64_internal(row, col, buf, stride, tile);
}
#endif
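// In device (AOT) compilation the wrappers above resolve to SYCL_EXTERNAL
// declarations that the device compiler lowers to AMX tile instructions; on
// the host path they call the x86 AMX builtins directly, so both compilation
// paths share a single interface.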

enum class matrix_layout { row_major, col_major, packed_a, packed_b };

inline constexpr size_t tile_size = 16;

template <typename Group, typename T, size_t NumRows = detail::dynamic_extent,
          size_t NumCols = detail::dynamic_extent,
          matrix_layout Layout = matrix_layout::row_major,
          typename Enabled = void>
struct joint_matrix {
  joint_matrix(Group sg) {}
  joint_matrix(Group sg, size_t Size) {
    static_assert((NumRows != detail::dynamic_extent &&
                   NumCols != detail::dynamic_extent),
                  "AMX implementation does not support dynamic allocation");
  }
  joint_matrix(Group sg, size_t Rows, size_t Cols) {
    static_assert((NumRows != detail::dynamic_extent &&
                   NumCols != detail::dynamic_extent),
                  "AMX implementation does not support dynamic allocation");
  }
};

// This specialization handles matrices that cannot be accommodated by a single
// tile. In this case we create raw_storage for the matrix, whose size is a
// multiple of tile_size * tile_size * 4 bytes.
template <typename Group, typename T, size_t NumRows, size_t NumCols,
          matrix_layout Layout>
struct joint_matrix<
    Group, T, NumRows, NumCols, Layout,
    typename std::enable_if<!((NumRows <= tile_size) &&
                              (NumCols * sizeof(T) / 4 <= tile_size) &&
                              (Layout != matrix_layout::col_major))>::type> {
public:
  // trows: number of tiles along the row dimension.
  // E.g. for T = int8_t and NumRows == 33, trows = (33 + 15) / 16 = 3.
  static constexpr size_t trows = (NumRows + tile_size - 1) / tile_size;
  // tcols: number of tiles along the column dimension.
  static constexpr size_t tcols =
      (NumCols * sizeof(T) / 4 + tile_size - 1) / tile_size;
  // E.g. for T = int8_t, NumRows == 33 and NumCols == 33 * 4, raw_storage has
  // size 48 * 48 * 4.
  // FIXME: greedy register allocation for tiles seems to have limitations, so
  // we currently tileload fixed (16, 16 * 4) shapes instead of varying ones;
  // raw_storage's size is therefore a multiple of 16 * 16 * 4.
  static constexpr size_t size = trows * tcols * tile_size * tile_size * 4;
  // stride is measured in elements of T rather than int8.
  static constexpr size_t stride = tcols * tile_size * 4 / sizeof(T);
  int8_t raw_storage[size];
  static constexpr bool isSmall = false;

public:
  matrix_layout layout;
  // The constructor zero-pads matrices whose size does not fill the tile grid.
  joint_matrix(Group sg) { memset(raw_storage, 0x00, size); }
};
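// Continuing the int8_t example: stride = 3 * 16 * 4 / 1 = 192 elements of T,
// i.e. one logical row of raw_storage spans all three column tiles.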

// This specialization handles matrices that fit into a single tile, where the
// user specifies the packed_a or packed_b layout.
template <typename Group, typename T, size_t NumRows, size_t NumCols,
          matrix_layout Layout>
struct joint_matrix<
    Group, T, NumRows, NumCols, Layout,
    typename std::enable_if<(NumRows <= tile_size) &&
                            (NumCols * sizeof(T) / 4 <= tile_size)>::type> {
public:
  static constexpr size_t trows = (NumRows + tile_size - 1) / tile_size;
  // tcols: number of tiles along the column dimension.
  static constexpr size_t tcols =
      (NumCols * sizeof(T) / 4 + tile_size - 1) / tile_size;
  static constexpr size_t size = trows * tcols * tile_size * tile_size * 4;
  // stride is measured in elements of T rather than int8.
  static constexpr size_t stride = tcols * tile_size * 4 / sizeof(T);
  _tile1024i tile;
  static constexpr bool isSmall = true;
  matrix_layout layout;
  // No zero-padding is needed here: the whole matrix fits into a single tile.
  joint_matrix(Group sg) {}
};

} // namespace experimental::matrix

namespace detail {

using namespace experimental;

template <typename Group, typename T, size_t NumRows, size_t NumCols,
          matrix::matrix_layout Layout>
inline __SYCL_ALWAYS_INLINE static
    typename std::enable_if<(NumRows > matrix::tile_size) ||
                                (NumCols * sizeof(T) / 4 > matrix::tile_size),
                            void>::type
    submatrix_load(detail::submatrix<T> &sub_m,
                   matrix::joint_matrix<Group, T, NumRows, NumCols, Layout> &jm,
                   uint32_t row, uint32_t col, size_t stride,
                   matrix::matrix_layout layout, bool shouldreload) {
  uint32_t offset = (row * stride + col);
  T *ptr = reinterpret_cast<T *>(jm.raw_storage);
  ptr += offset;
  stride *= sizeof(T);
  sub_m.rows = matrix::tile_size;
  sub_m.cols = matrix::tile_size * 4;
  sub_m.tile = matrix::tileloadd64_internal(
      sub_m.rows, sub_m.cols, reinterpret_cast<char *>(ptr), stride);
}

template <typename Group, typename T, size_t NumRows, size_t NumCols,
          matrix::matrix_layout Layout>
inline __SYCL_ALWAYS_INLINE static
    typename std::enable_if<(NumRows <= matrix::tile_size) &&
                                (NumCols * sizeof(T) / 4 <= matrix::tile_size),
                            void>::type
    submatrix_load(detail::submatrix<T> &sub_m,
                   matrix::joint_matrix<Group, T, NumRows, NumCols, Layout> &jm,
                   uint32_t row, uint32_t col, size_t stride,
                   matrix::matrix_layout layout, bool shouldreload) {
  if (shouldreload) {
    // Force sub_m.tile's shape to be tile_size rows by tile_size * 4 columns.
    int8_t NewjmC[matrix::tile_size * matrix::tile_size * 4];
    matrix::tilestored64_internal(NumRows, NumCols * sizeof(T),
                                  reinterpret_cast<char *>(NewjmC),
                                  matrix::tile_size * 4, jm.tile);
    sub_m.rows = matrix::tile_size;
    sub_m.cols = matrix::tile_size * 4;
    sub_m.tile = matrix::tileloadd64_internal(sub_m.rows, sub_m.cols,
                                              reinterpret_cast<char *>(NewjmC),
                                              matrix::tile_size * 4);
    return;
  }
  sub_m.rows = NumRows;
  sub_m.cols = NumCols * sizeof(T);
  sub_m.tile = jm.tile;
}
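// Note: the reshape in the small-matrix overload above round-trips the tile
// through a stack buffer (tilestore + tileload), since an AMX tile's shape is
// fixed at load time and cannot be changed while it sits in a register.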

// This handles the case where T1 is int8 and T2 is int32.
inline __SYCL_ALWAYS_INLINE static void
submatrix_mad(detail::submatrix<int8_t> &sub_ma,
              detail::submatrix<int8_t> &sub_mb,
              detail::submatrix<int32_t> &sub_mc) {
  sub_mc.tile = matrix::tdpbssd_internal(sub_mc.rows, sub_mc.cols, sub_ma.cols,
                                         sub_mc.tile, sub_ma.tile, sub_mb.tile);
}
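// The overload above lowers to AMX TDPBSSD: each result dword accumulates a
// dot-product of four signed int8 pairs into int32.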

// This handles the case where T1 is int16 (bfloat16) and T2 is float.
inline __SYCL_ALWAYS_INLINE static void
submatrix_mad(detail::submatrix<unsigned short> &sub_ma,
              detail::submatrix<unsigned short> &sub_mb,
              detail::submatrix<float> &sub_mc) {
  sub_mc.tile =
      matrix::tdpbf16ps_internal(sub_mc.rows, sub_mc.cols, sub_ma.cols,
                                 sub_mc.tile, sub_ma.tile, sub_mb.tile);
}
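// The overload above lowers to AMX TDPBF16PS: each result dword accumulates a
// dot-product of two bfloat16 pairs into fp32.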

template <typename Group, typename T, size_t NumRows, size_t NumCols>
inline __SYCL_ALWAYS_INLINE static
    typename std::enable_if<(NumRows > matrix::tile_size) ||
                                (NumCols * sizeof(T) / 4 > matrix::tile_size),
                            void>::type
    submatrix_store(detail::submatrix<T> &sub_m,
                    matrix::joint_matrix<Group, T, NumRows, NumCols> &jm,
                    uint32_t row, uint32_t col, size_t stride,
                    matrix::matrix_layout layout, bool shouldreload) {
  uint32_t offset = (row * stride + col);
  T *ptr = reinterpret_cast<T *>(jm.raw_storage);
  ptr += offset;
  stride *= sizeof(T);
  matrix::tilestored64_internal(sub_m.rows, sub_m.cols,
                                reinterpret_cast<char *>(ptr), stride,
                                sub_m.tile);
}

template <typename Group, typename T, size_t NumRows, size_t NumCols>
inline __SYCL_ALWAYS_INLINE static
    typename std::enable_if<(NumRows <= matrix::tile_size) &&
                                (NumCols * sizeof(T) / 4 <= matrix::tile_size),
                            void>::type
    submatrix_store(detail::submatrix<T> &sub_m,
                    matrix::joint_matrix<Group, T, NumRows, NumCols> &jm,
                    uint32_t row, uint32_t col, size_t stride,
                    matrix::matrix_layout layout, bool shouldreload) {
  if (shouldreload) {
    int8_t NewjmC[matrix::tile_size * matrix::tile_size * 4];
    matrix::tilestored64_internal(matrix::tile_size, matrix::tile_size * 4,
                                  reinterpret_cast<char *>(NewjmC),
                                  matrix::tile_size * 4, sub_m.tile);
    jm.tile = matrix::tileloadd64_internal(NumRows, NumCols * sizeof(T),
                                           reinterpret_cast<char *>(NewjmC),
                                           matrix::tile_size * 4);
    return;
  }
  jm.tile = sub_m.tile;
}

} // namespace detail

namespace experimental::matrix {

// This handles matrices that cannot be accommodated by a single tile.
template <typename Group, typename T, size_t NumRows, size_t NumCols,
          matrix_layout Layout, access::address_space Space>
inline __SYCL_ALWAYS_INLINE typename std::enable_if<
    (NumRows > tile_size) || (NumCols * sizeof(T) / 4 > tile_size), void>::type
joint_matrix_load(Group sg,
                  joint_matrix<Group, T, NumRows, NumCols, Layout> &jm,
                  multi_ptr<T, Space> src, size_t stride,
                  matrix_layout layout) {
  T *mem = src.get();
  // memcpy from mem to jm.raw_storage, row by row.
  for (size_t i = 0; i < NumRows; ++i) {
    char *srcptr = reinterpret_cast<char *>(mem) + i * stride * sizeof(T);
    char *dstptr =
        reinterpret_cast<char *>(jm.raw_storage) + i * jm.stride * sizeof(T);
    // TODO: we may need to reformat the layout here.
    memcpy(dstptr, srcptr, NumCols * sizeof(T));
  }
  jm.layout = layout;
}

// This handles matrices that fit into a single tile.
template <typename Group, typename T, size_t NumRows, size_t NumCols,
          matrix_layout Layout, access::address_space Space>
inline __SYCL_ALWAYS_INLINE
    typename std::enable_if<(NumRows <= tile_size) &&
                                (NumCols * sizeof(T) / 4 <= tile_size),
                            void>::type
    joint_matrix_load(Group sg,
                      joint_matrix<Group, T, NumRows, NumCols, Layout> &jm,
                      multi_ptr<T, Space> src, size_t stride,
                      matrix_layout layout) {
  T *mem = src.get();
  // A single tileload suffices here.
  jm.tile =
      tileloadd64_internal(NumRows, NumCols * sizeof(T),
                           reinterpret_cast<char *>(mem), stride * sizeof(T));
  jm.layout = layout;
}

// This handles matrices that cannot be accommodated by a single tile.
template <typename Group, typename T, size_t NumRows, size_t NumCols,
          matrix_layout Layout, access::address_space Space>
inline __SYCL_ALWAYS_INLINE typename std::enable_if<
    (NumRows > tile_size) || (NumCols * sizeof(T) / 4 > tile_size), void>::type
joint_matrix_store(Group sg,
                   joint_matrix<Group, T, NumRows, NumCols, Layout> &jm,
                   multi_ptr<T, Space> dst, size_t stride,
                   matrix_layout layout) {
  T *mem = dst.get();
  for (size_t i = 0; i < NumRows; ++i) {
    char *dstptr = reinterpret_cast<char *>(mem) + i * stride * sizeof(T);
    char *srcptr =
        reinterpret_cast<char *>(jm.raw_storage) + i * jm.stride * sizeof(T);
    // TODO: we may need to reformat the layout here.
    memcpy(dstptr, srcptr, NumCols * sizeof(T));
  }
}

// This handles matrices that fit into a single tile.
template <typename Group, typename T, size_t NumRows, size_t NumCols,
          matrix_layout Layout, access::address_space Space>
inline __SYCL_ALWAYS_INLINE
    typename std::enable_if<(NumRows <= tile_size) &&
                                (NumCols * sizeof(T) / 4 <= tile_size),
                            void>::type
    joint_matrix_store(Group sg,
                       joint_matrix<Group, T, NumRows, NumCols, Layout> &jm,
                       multi_ptr<T, Space> dst, size_t stride,
                       matrix_layout layout) {
  T *mem = dst.get();
  // A single tilestore suffices here.
  tilestored64_internal(NumRows, NumCols * sizeof(T),
                        reinterpret_cast<char *>(mem), stride * sizeof(T),
                        jm.tile);
}

template <typename Group, typename T1, typename T2, size_t NumRowsA,
          size_t NumColsA, size_t NumRowsB, size_t NumColsB, size_t NumRowsC,
          size_t NumColsC, matrix_layout LayoutA, matrix_layout LayoutB,
          matrix_layout LayoutC>
inline __SYCL_ALWAYS_INLINE typename std::enable_if<
    ((std::is_same<T1, int8_t>::value && std::is_same<T2, int32_t>::value) ||
     (std::is_same<T1, unsigned short>::value &&
      std::is_same<T2, float>::value)) &&
        (LayoutA == matrix_layout::row_major) &&
        (LayoutB == matrix_layout::packed_b) &&
        (LayoutC == matrix_layout::row_major),
    joint_matrix<Group, T2, NumRowsC, NumColsC, LayoutC>>::type
joint_matrix_mad(Group sg,
                 joint_matrix<Group, T1, NumRowsA, NumColsA, LayoutA> &jmA,
                 joint_matrix<Group, T1, NumRowsB, NumColsB, LayoutB> &jmB,
                 joint_matrix<Group, T2, NumRowsC, NumColsC, LayoutC> &jmC) {
  joint_matrix<Group, T2, NumRowsC, NumColsC, LayoutC> res(jmC);
  constexpr size_t epd = detail::elems_per_dword<T1>::value;
  // If A is large and C is small, joint_matrix_load does a memcpy for A and a
  // tileload for C whose shape is not tile_size x tile_size * 4. In
  // joint_matrix_mad we tileload A with shape tile_size x tile_size * 4, so C
  // must be reshaped before the dot-product.
  bool Cshouldreload = res.isSmall && !jmA.isSmall && !jmB.isSmall;
  bool Ashouldreload = jmA.isSmall && !jmB.isSmall;
  bool Bshouldreload = jmB.isSmall && !jmA.isSmall;

  for (size_t m = 0; m < res.trows; ++m) {
    for (size_t n = 0; n < res.tcols; ++n) {
      detail::submatrix<T2> sub_c;

      // AMX provides 8 tile registers of 1 KB each; the maximum tile shape is
      // 16 rows x 64 bytes.
      submatrix_load(sub_c, res, m * tile_size, n * tile_size, res.stride,
                     matrix_layout::row_major, Cshouldreload);
      for (size_t k = 0; k < jmA.tcols; ++k) {
        detail::submatrix<T1> sub_a;
        detail::submatrix<T1> sub_b;
        submatrix_load(sub_a, jmA, m * tile_size, k * tile_size * epd,
                       jmA.stride, matrix_layout::packed_a, Ashouldreload);
        // Assume the data is already in VNNI format.
        submatrix_load(sub_b, jmB, k * tile_size, n * tile_size * epd,
                       jmB.stride, matrix_layout::packed_b, Bshouldreload);
        submatrix_mad(sub_a, sub_b, sub_c);
      }
      submatrix_store(sub_c, res, m * tile_size, n * tile_size, res.stride,
                      matrix_layout::row_major, Cshouldreload);
    }
  }
  return res;
}

} // namespace experimental::matrix
} // namespace intel
} // namespace ext
} // namespace sycl
} // __SYCL_INLINE_NAMESPACE(cl)
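For orientation, here is a minimal usage sketch of the joint_matrix API above
(not part of the header): one 16 x 64 int8 tile of A times a VNNI-packed tile
of B, accumulated into a 16 x 16 int32 C. The buffer names bufA, bufB, and
bufC are hypothetical, B is assumed to be pre-packed in VNNI (packed_b)
layout, and the single-work-item launch is only for illustration.

#include <CL/sycl.hpp>
using namespace sycl::ext::intel::experimental::matrix;

constexpr size_t M = 16, N = 16, K = 64; // one AMX int8 tile: 16 rows x 64 cols

void matmul_one_tile(sycl::queue &q, sycl::buffer<int8_t, 1> &bufA,
                     sycl::buffer<int8_t, 1> &bufB,
                     sycl::buffer<int32_t, 1> &bufC) {
  q.submit([&](sycl::handler &h) {
    auto accA = bufA.get_access<sycl::access::mode::read_write>(h);
    auto accB = bufB.get_access<sycl::access::mode::read_write>(h);
    auto accC = bufC.get_access<sycl::access::mode::read_write>(h);
    h.parallel_for(sycl::nd_range<1>{{1}, {1}}, [=](sycl::nd_item<1> it) {
      auto sg = it.get_sub_group();
      // A: M x K int8, row_major; B: (K/4) x (N*4) int8, VNNI-packed;
      // C: M x N int32 accumulator. All three fit into a single tile.
      joint_matrix<decltype(sg), int8_t, M, K> jmA(sg);
      joint_matrix<decltype(sg), int8_t, K / 4, N * 4, matrix_layout::packed_b>
          jmB(sg);
      joint_matrix<decltype(sg), int32_t, M, N> jmC(sg);
      joint_matrix_load(sg, jmA, accA.get_pointer(), K,
                        matrix_layout::row_major);
      joint_matrix_load(sg, jmB, accB.get_pointer(), N * 4,
                        matrix_layout::packed_b);
      joint_matrix_load(sg, jmC, accC.get_pointer(), N,
                        matrix_layout::row_major);
      jmC = joint_matrix_mad(sg, jmA, jmB, jmC);
      joint_matrix_store(sg, jmC, accC.get_pointer(), N,
                         matrix_layout::row_major);
    });
  });
}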