#include <cstddef> // size_t
#include <cstdint> // int8_t, int32_t, uint32_t
#include <cstring> // memset, memcpy used by the raw-storage paths below
#include <immintrin.h>
namespace detail {
// Number of matrix elements packed into one 32-bit dword of an AMX tile row.
// Default is one element per dword; packed types specialize below.
template <typename T> struct elems_per_dword {
  static constexpr size_t value = 1;
};
#define ELEMS_PER_DWORD(TYPE, NUM)                                            \
  template <> struct elems_per_dword<TYPE> {                                  \
    static constexpr size_t value = NUM;                                      \
  };

// AMX packs 4 int8 values or 2 bf16 values into each 32-bit dword.
ELEMS_PER_DWORD(int8_t, 4)
ELEMS_PER_DWORD(unsigned short, 2)
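// Illustrative sanity checks (not required for correctness): with the
// specializations above, a 64-byte tile row holds 64 int8 or 32 bf16 values.
static_assert(elems_per_dword<int8_t>::value == 4,
              "int8 packs 4 elements per dword");
static_assert(elems_per_dword<unsigned short>::value == 2,
              "bf16 packs 2 elements per dword");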
} // namespace detail

namespace experimental::matrix {
#ifdef __SYCL_DEVICE_ONLY__
SYCL_EXTERNAL extern "C" _tile1024i
_tileloadd64_internal(short row, short col, char *buf, size_t stride);
SYCL_EXTERNAL extern "C" _tile1024i
_tdpbssd_internal(unsigned short m, unsigned short n, unsigned short k,
                  _tile1024i dst, _tile1024i src1, _tile1024i src2);
SYCL_EXTERNAL extern "C" _tile1024i
_tdpbf16ps_internal(unsigned short m, unsigned short n, unsigned short k,
                    _tile1024i dst, _tile1024i src1, _tile1024i src2);
SYCL_EXTERNAL extern "C" void _tilestored64_internal(short row, short col,
                                                     char *buf, size_t stride,
                                                     _tile1024i tile);
static _tile1024i tileloadd64_internal(short row, short col, char *buf,
                                       size_t stride) {
  return _tileloadd64_internal(row, col, buf, stride);
}
static _tile1024i tdpbssd_internal(unsigned short m, unsigned short n,
                                   unsigned short k, _tile1024i dst,
                                   _tile1024i src1, _tile1024i src2) {
  return _tdpbssd_internal(m, n, k, dst, src1, src2);
}
static _tile1024i tdpbf16ps_internal(unsigned short m, unsigned short n,
                                     unsigned short k, _tile1024i dst,
                                     _tile1024i src1, _tile1024i src2) {
  return _tdpbf16ps_internal(m, n, k, dst, src1, src2);
}
static void tilestored64_internal(short row, short col, char *buf,
                                  size_t stride, _tile1024i tile) {
  return _tilestored64_internal(row, col, buf, stride, tile);
}
#else
static _tile1024i tileloadd64_internal(short row, short col, char *buf,
                                       size_t stride) {
  return __builtin_ia32_tileloadd64_internal(row, col, buf, stride);
}
static _tile1024i tdpbssd_internal(unsigned short m, unsigned short n,
                                   unsigned short k, _tile1024i dst,
                                   _tile1024i src1, _tile1024i src2) {
  return __builtin_ia32_tdpbssd_internal(m, n, k, dst, src1, src2);
}
static _tile1024i tdpbf16ps_internal(unsigned short m, unsigned short n,
                                     unsigned short k, _tile1024i dst,
                                     _tile1024i src1, _tile1024i src2) {
  return __builtin_ia32_tdpbf16ps_internal(m, n, k, dst, src1, src2);
}
static void tilestored64_internal(short row, short col, char *buf,
                                  size_t stride, _tile1024i tile) {
  __builtin_ia32_tilestored64_internal(row, col, buf, stride, tile);
}
#endif
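// Usage sketch (illustrative only; the buffers and shapes are hypothetical):
// load two 16-row x 64-byte tiles, accumulate an int8 dot-product into an
// int32 tile, and store it back. Column/stride arguments are in bytes.
//
//   char bufA[1024], bufB[1024], bufC[1024];
//   _tile1024i tA = tileloadd64_internal(16, 64, bufA, 64);
//   _tile1024i tB = tileloadd64_internal(16, 64, bufB, 64);
//   _tile1024i tC = tileloadd64_internal(16, 64, bufC, 64);
//   tC = tdpbssd_internal(16, 64, 64, tC, tA, tB);
//   tilestored64_internal(16, 64, bufC, 64, tC);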
enum class matrix_layout { row_major, col_major, packed_a, packed_b };

static constexpr size_t tile_size = 16;

// Primary template: AMX matrices must have compile-time shapes.
template <typename Group, typename T, size_t NumRows, size_t NumCols,
          matrix_layout Layout = matrix_layout::row_major,
          typename Enabled = void>
struct joint_matrix {
  joint_matrix(Group sg) {
    static_assert(NumRows != dynamic_extent && NumCols != dynamic_extent,
                  "AMX implementation does not support dynamic allocation");
  }
  joint_matrix(Group sg, size_t Size) {
    static_assert(NumRows != dynamic_extent && NumCols != dynamic_extent,
                  "AMX implementation does not support dynamic allocation");
  }
};
// Specialization for matrices that do not fit in a single AMX tile (or are
// col_major): elements live in a row-major raw byte buffer, tiled on demand.
template <typename Group, typename T, size_t NumRows, size_t NumCols,
          matrix_layout Layout>
struct joint_matrix<
    Group, T, NumRows, NumCols, Layout,
    typename std::enable_if<!((NumRows <= tile_size) &&
                              (NumCols * sizeof(T) / 4 <= tile_size) &&
                              (Layout != matrix_layout::col_major))>::type> {
  // Number of tile rows / tile columns covering the matrix.
  static constexpr size_t trows = (NumRows + tile_size - 1) / tile_size;
  // Each tile column spans tile_size dwords (tile_size * 4 bytes).
  static constexpr size_t tcols =
      (NumCols * sizeof(T) / 4 + tile_size - 1) / tile_size;
  // Total backing-store size in bytes.
  static constexpr size_t size = trows * tcols * tile_size * tile_size * 4;
  // Row stride of the backing store, in elements of T.
  static constexpr size_t stride = tcols * tile_size * 4 / sizeof(T);
  int8_t raw_storage[size];
  static constexpr bool isSmall = false;
  matrix_layout layout;
  joint_matrix(Group sg) { memset(raw_storage, 0x00, size); }
};
// Specialization for matrices that fit in a single AMX tile: the data lives
// directly in a tile register, so no raw backing store is needed.
template <typename Group, typename T, size_t NumRows, size_t NumCols,
          matrix_layout Layout>
struct joint_matrix<
    Group, T, NumRows, NumCols, Layout,
    typename std::enable_if<(NumRows <= tile_size) &&
                            (NumCols * sizeof(T) / 4 <= tile_size)>::type> {
  static constexpr size_t trows = (NumRows + tile_size - 1) / tile_size;
  static constexpr size_t tcols =
      (NumCols * sizeof(T) / 4 + tile_size - 1) / tile_size;
  static constexpr size_t size = trows * tcols * tile_size * tile_size * 4;
  static constexpr size_t stride = tcols * tile_size * 4 / sizeof(T);
  _tile1024i tile;
  static constexpr bool isSmall = true;
  matrix_layout layout;
  joint_matrix(Group sg) {}
};
} // namespace experimental::matrix
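// Dispatch example (illustrative; SG stands for the sub-group type at the
// call site). With tile_size == 16 and int8 elements, at most 64 columns
// (64 bytes per row) fit in one tile:
//
//   joint_matrix<SG, int8_t, 16, 64> small; // single-tile specialization
//   joint_matrix<SG, int8_t, 64, 256> big;  // raw_storage specialization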
namespace detail {
using namespace experimental;

// One AMX tile plus its live dimensions; rows and cols are in the tile's
// native units (rows x bytes-per-row).
template <typename T> struct submatrix {
  _tile1024i tile;
  short rows, cols;
};
template <typename Group, typename T, size_t NumRows, size_t NumCols,
          matrix::matrix_layout Layout>
inline __SYCL_ALWAYS_INLINE static
    typename std::enable_if<(NumRows > matrix::tile_size) ||
                                (NumCols * sizeof(T) / 4 > matrix::tile_size),
                            void>::type
    submatrix_load(detail::submatrix<T> &sub_m,
                   matrix::joint_matrix<Group, T, NumRows, NumCols, Layout> &jm,
                   uint32_t row, uint32_t col, size_t stride,
                   matrix::matrix_layout layout, bool shouldreload) {
  // Load one full tile out of the big matrix's raw backing store.
  uint32_t offset = (row * stride + col);
  T *ptr = reinterpret_cast<T *>(jm.raw_storage);
  ptr += offset;
  stride *= sizeof(T);
  sub_m.rows = matrix::tile_size;
  sub_m.cols = matrix::tile_size * 4;
  sub_m.tile = matrix::tileloadd64_internal(
      sub_m.rows, sub_m.cols, reinterpret_cast<char *>(ptr), stride);
}
template <typename Group, typename T, size_t NumRows, size_t NumCols,
          matrix::matrix_layout Layout>
inline __SYCL_ALWAYS_INLINE static
    typename std::enable_if<(NumRows <= matrix::tile_size) &&
                                (NumCols * sizeof(T) / 4 <= matrix::tile_size),
                            void>::type
    submatrix_load(detail::submatrix<T> &sub_m,
                   matrix::joint_matrix<Group, T, NumRows, NumCols, Layout> &jm,
                   uint32_t row, uint32_t col, size_t stride,
                   matrix::matrix_layout layout, bool shouldreload) {
  if (shouldreload) {
    // Spill the tile to a scratch buffer and reload it with full tile
    // dimensions so it can be combined with large-matrix operands.
    int8_t NewjmC[matrix::tile_size * matrix::tile_size * 4];
    matrix::tilestored64_internal(NumRows, NumCols * sizeof(T),
                                  reinterpret_cast<char *>(NewjmC),
                                  matrix::tile_size * 4, jm.tile);
    sub_m.rows = matrix::tile_size;
    sub_m.cols = matrix::tile_size * 4;
    sub_m.tile = matrix::tileloadd64_internal(sub_m.rows, sub_m.cols,
                                              reinterpret_cast<char *>(NewjmC),
                                              matrix::tile_size * 4);
  } else {
    sub_m.rows = NumRows;
    sub_m.cols = NumCols * sizeof(T);
    sub_m.tile = jm.tile;
  }
}
// int8 x int8 -> int32 tile multiply-add (TDPBSSD).
inline __SYCL_ALWAYS_INLINE static void
submatrix_mad(detail::submatrix<int8_t> &sub_ma,
              detail::submatrix<int8_t> &sub_mb,
              detail::submatrix<int32_t> &sub_mc) {
  sub_mc.tile = matrix::tdpbssd_internal(sub_mc.rows, sub_mc.cols, sub_ma.cols,
                                         sub_mc.tile, sub_ma.tile, sub_mb.tile);
}
// bf16 x bf16 -> float tile multiply-add (TDPBF16PS).
inline __SYCL_ALWAYS_INLINE static void
submatrix_mad(detail::submatrix<unsigned short> &sub_ma,
              detail::submatrix<unsigned short> &sub_mb,
              detail::submatrix<float> &sub_mc) {
  sub_mc.tile =
      matrix::tdpbf16ps_internal(sub_mc.rows, sub_mc.cols, sub_ma.cols,
                                 sub_mc.tile, sub_ma.tile, sub_mb.tile);
}
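// Per-element semantics (hedged restatement of the TDPBSSD/TDPBF16PS ISA):
// each 32-bit dword of A and B carries elems_per_dword packed values, and
//   C[i][j] += sum over p of A_dword(i,k)[p] * B_dword(k,j)[p]
// accumulating into int32 for int8 inputs and float for bf16 inputs.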
template <typename Group, typename T, size_t NumRows, size_t NumCols>
inline __SYCL_ALWAYS_INLINE static
    typename std::enable_if<(NumRows > matrix::tile_size) ||
                                (NumCols * sizeof(T) / 4 > matrix::tile_size),
                            void>::type
    submatrix_store(detail::submatrix<T> &sub_m,
                    matrix::joint_matrix<Group, T, NumRows, NumCols> &jm,
                    uint32_t row, uint32_t col, size_t stride,
                    matrix::matrix_layout layout, bool shouldreload) {
  // Store the tile back into the big matrix's raw backing store.
  uint32_t offset = (row * stride + col);
  T *ptr = reinterpret_cast<T *>(jm.raw_storage);
  ptr += offset;
  stride *= sizeof(T);
  matrix::tilestored64_internal(sub_m.rows, sub_m.cols,
                                reinterpret_cast<char *>(ptr), stride,
                                sub_m.tile);
}
template <typename Group, typename T, size_t NumRows, size_t NumCols>
inline __SYCL_ALWAYS_INLINE static
    typename std::enable_if<(NumRows <= matrix::tile_size) &&
                                (NumCols * sizeof(T) / 4 <= matrix::tile_size),
                            void>::type
    submatrix_store(detail::submatrix<T> &sub_m,
                    matrix::joint_matrix<Group, T, NumRows, NumCols> &jm,
                    uint32_t row, uint32_t col, size_t stride,
                    matrix::matrix_layout layout, bool shouldreload) {
  if (shouldreload) {
    // Spill the full tile, then reload it with the matrix's real dimensions.
    int8_t NewjmC[matrix::tile_size * matrix::tile_size * 4];
    matrix::tilestored64_internal(matrix::tile_size, matrix::tile_size * 4,
                                  reinterpret_cast<char *>(NewjmC),
                                  matrix::tile_size * 4, sub_m.tile);
    jm.tile = matrix::tileloadd64_internal(NumRows, NumCols * sizeof(T),
                                           reinterpret_cast<char *>(NewjmC),
                                           matrix::tile_size * 4);
  } else {
    jm.tile = sub_m.tile;
  }
}
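// The spill/reload round-trips above move a tile through a scratch buffer of
// tile_size * tile_size * 4 = 1024 bytes, i.e. exactly one full AMX tile, so
// a small operand can be re-described with different tile dimensions.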
} // namespace detail

namespace experimental::matrix {
template <typename Group, typename T, size_t NumRows, size_t NumCols,
          matrix_layout Layout, access::address_space Space>
inline __SYCL_ALWAYS_INLINE typename std::enable_if<
    (NumRows > tile_size) || (NumCols * sizeof(T) / 4 > tile_size), void>::type
joint_matrix_load(Group sg,
                  joint_matrix<Group, T, NumRows, NumCols, Layout> &jm,
                  multi_ptr<T, Space> src, size_t stride,
                  matrix_layout layout) {
  T *mem = src.get();
  // Copy the source row by row into the matrix's raw backing store.
  for (int i = 0; i < NumRows; ++i) {
    char *srcptr = reinterpret_cast<char *>(mem) + i * stride * sizeof(T);
    char *dstptr =
        reinterpret_cast<char *>(jm.raw_storage) + i * jm.stride * sizeof(T);
    memcpy(dstptr, srcptr, NumCols * sizeof(T));
  }
}
template <typename Group, typename T, size_t NumRows, size_t NumCols,
          matrix_layout Layout, access::address_space Space>
inline __SYCL_ALWAYS_INLINE
    typename std::enable_if<(NumRows <= tile_size) &&
                                (NumCols * sizeof(T) / 4 <= tile_size),
                            void>::type
    joint_matrix_load(Group sg,
                      joint_matrix<Group, T, NumRows, NumCols, Layout> &jm,
                      multi_ptr<T, Space> src, size_t stride,
                      matrix_layout layout) {
  T *mem = src.get();
  // The whole matrix fits in one tile: load it directly.
  jm.tile = tileloadd64_internal(NumRows, NumCols * sizeof(T),
                                 reinterpret_cast<char *>(mem),
                                 stride * sizeof(T));
}
template <typename Group, typename T, size_t NumRows, size_t NumCols,
          matrix_layout Layout, access::address_space Space>
inline __SYCL_ALWAYS_INLINE typename std::enable_if<
    (NumRows > tile_size) || (NumCols * sizeof(T) / 4 > tile_size), void>::type
joint_matrix_store(Group sg,
                   joint_matrix<Group, T, NumRows, NumCols, Layout> &jm,
                   multi_ptr<T, Space> dst, size_t stride,
                   matrix_layout layout) {
  T *mem = dst.get();
  // Copy the raw backing store back out row by row.
  for (int i = 0; i < NumRows; ++i) {
    char *dstptr = reinterpret_cast<char *>(mem) + i * stride * sizeof(T);
    char *srcptr =
        reinterpret_cast<char *>(jm.raw_storage) + i * jm.stride * sizeof(T);
    memcpy(dstptr, srcptr, NumCols * sizeof(T));
  }
}
template <typename Group, typename T, size_t NumRows, size_t NumCols,
          matrix_layout Layout, access::address_space Space>
inline __SYCL_ALWAYS_INLINE
    typename std::enable_if<(NumRows <= tile_size) &&
                                (NumCols * sizeof(T) / 4 <= tile_size),
                            void>::type
    joint_matrix_store(Group sg,
                       joint_matrix<Group, T, NumRows, NumCols, Layout> &jm,
                       multi_ptr<T, Space> dst, size_t stride,
                       matrix_layout layout) {
  T *mem = dst.get();
  // The whole matrix lives in one tile: store it directly.
  tilestored64_internal(NumRows, NumCols * sizeof(T),
                        reinterpret_cast<char *>(mem), stride * sizeof(T),
                        jm.tile);
}
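// Note on units: `stride` is expressed in elements of T and converted to
// bytes at the tile call. Example (illustrative; sg, jmC, pC are hypothetical
// call-site names): storing a 16x16 int32 accumulator into a row-major buffer
// whose leading dimension is 16 elements:
//
//   joint_matrix_store(sg, jmC, pC, /*stride=*/16, matrix_layout::row_major);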
template <typename Group, typename T1, typename T2, size_t NumRowsA,
          size_t NumColsA, size_t NumRowsB, size_t NumColsB, size_t NumRowsC,
          size_t NumColsC, matrix_layout LayoutA, matrix_layout LayoutB,
          matrix_layout LayoutC>
inline __SYCL_ALWAYS_INLINE typename std::enable_if<
    ((std::is_same<T1, int8_t>::value && std::is_same<T2, int32_t>::value) ||
     (std::is_same<T1, unsigned short>::value &&
      std::is_same<T2, float>::value)) &&
        (LayoutA == matrix_layout::row_major) &&
        (LayoutB == matrix_layout::packed_b) &&
        (LayoutC == matrix_layout::row_major),
    joint_matrix<Group, T2, NumRowsC, NumColsC, LayoutC>>::type
joint_matrix_mad(Group sg,
                 joint_matrix<Group, T1, NumRowsA, NumColsA, LayoutA> &jmA,
                 joint_matrix<Group, T1, NumRowsB, NumColsB, LayoutB> &jmB,
                 joint_matrix<Group, T2, NumRowsC, NumColsC, LayoutC> &jmC) {
  joint_matrix<Group, T2, NumRowsC, NumColsC, LayoutC> res(jmC);
  constexpr size_t epd = detail::elems_per_dword<T1>::value;
  // A small operand mixed with large ones must be spilled and reloaded with
  // full tile dimensions before it can join the tiled loops below.
  bool Cshouldreload = res.isSmall && !jmA.isSmall && !jmB.isSmall;
  bool Ashouldreload = jmA.isSmall && !jmB.isSmall;
  bool Bshouldreload = jmB.isSmall && !jmA.isSmall;

  for (int m = 0; m < res.trows; ++m) {
    for (int n = 0; n < res.tcols; ++n) {
      detail::submatrix<T2> sub_c;
      // Accumulator tile for this (m, n) block.
      submatrix_load(sub_c, res, m * tile_size, n * tile_size, res.stride,
                     matrix_layout::row_major, Cshouldreload);
      for (int k = 0; k < jmA.tcols; ++k) {
        detail::submatrix<T1> sub_a;
        detail::submatrix<T1> sub_b;
        submatrix_load(sub_a, jmA, m * tile_size, k * tile_size * epd,
                       jmA.stride, matrix_layout::packed_a, Ashouldreload);
        // B is VNNI-packed: each stored row carries epd logical K-rows.
        submatrix_load(sub_b, jmB, k * tile_size, n * tile_size * epd,
                       jmB.stride, matrix_layout::packed_b, Bshouldreload);
        submatrix_mad(sub_a, sub_b, sub_c);
      }
      submatrix_store(sub_c, res, m * tile_size, n * tile_size, res.stride,
                      matrix_layout::row_major, Cshouldreload);
    }
  }
  return res;
}
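// End-to-end sketch (illustrative; q, pA, pB, pC are hypothetical call-site
// names for a queue and multi_ptr-wrapped buffers). One sub-group multiplies
// a 16x64 int8 A by a VNNI-packed 64x16 B (declared as 16 x (16*4)) into a
// 16x16 int32 C:
//
//   q.parallel_for(sycl::nd_range<1>{1, 1}, [=](sycl::nd_item<1> it) {
//     auto sg = it.get_sub_group();
//     joint_matrix<decltype(sg), int8_t, 16, 64> A(sg);
//     joint_matrix<decltype(sg), int8_t, 16, 64, matrix_layout::packed_b> B(sg);
//     joint_matrix<decltype(sg), int32_t, 16, 16> C(sg);
//     joint_matrix_load(sg, C, pC, 16, matrix_layout::row_major);
//     joint_matrix_load(sg, A, pA, 64, matrix_layout::row_major);
//     joint_matrix_load(sg, B, pB, 64, matrix_layout::packed_b);
//     C = joint_matrix_mad(sg, A, B, C);
//     joint_matrix_store(sg, C, pC, 16, matrix_layout::row_major);
//   });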