namespace ext::oneapi::experimental::matrix {

enum class matrix_layout { row_major, col_major, packed_a, packed_b };

// Maps the extension's matrix_layout to the SPIR-V matrix layout enumerator.
template <matrix_layout Layout> struct spv_matrix_layout_traits {
  static constexpr __spv::MatrixLayout value = __spv::MatrixLayout::RowMajor;
};

#define SPV_MATRIX_LAYOUT_TRAITS(LAYOUT, SPV_LAYOUT)                          \
  template <> struct spv_matrix_layout_traits<LAYOUT> {                       \
    static constexpr __spv::MatrixLayout value = SPV_LAYOUT;                  \
  };

SPV_MATRIX_LAYOUT_TRAITS(matrix_layout::row_major, __spv::MatrixLayout::RowMajor)
SPV_MATRIX_LAYOUT_TRAITS(matrix_layout::col_major, __spv::MatrixLayout::ColumnMajor)
SPV_MATRIX_LAYOUT_TRAITS(matrix_layout::packed_a, __spv::MatrixLayout::PackedA)
SPV_MATRIX_LAYOUT_TRAITS(matrix_layout::packed_b, __spv::MatrixLayout::PackedB)
// Maps a SYCL group type to the SPIR-V scope used by the joint matrix builtins.
template <typename G> struct spv_scope_traits {};
template <> struct spv_scope_traits<sycl::sub_group> {
  constexpr static auto value = __spv::Scope::Subgroup;
};
template <int D> struct spv_scope_traits<sycl::group<D>> {
  constexpr static auto value = __spv::Scope::Workgroup;
};
// Forward declaration; the class is defined at the end of this header.
template <typename T, size_t NumRows, size_t NumCols,
          matrix_layout Layout = matrix_layout::row_major,
          typename Group = sycl::sub_group>
class wi_data;
template <typename T, size_t NumRows, size_t NumCols,
          matrix_layout Layout = matrix_layout::row_major,
          typename Group = sycl::sub_group>
struct joint_matrix {
public:
  // Handle to the SPIR-V joint matrix object; only meaningful in device code.
  __spv::__spirv_JointMatrixINTEL<
      T, NumRows, NumCols, spv_matrix_layout_traits<Layout>::value,
      spv_scope_traits<Group>::value> *spvm;

  joint_matrix(Group sg) {
#ifndef __SYCL_DEVICE_ONLY__
    (void)sg;
    throw runtime_error("joint matrix is not supported on host device.",
                        PI_ERROR_INVALID_DEVICE);
#endif // __SYCL_DEVICE_ONLY__
  }

  inline __SYCL_ALWAYS_INLINE wi_data<T, NumRows, NumCols, Layout, Group>
  get_wi_data() {
    return wi_data<T, NumRows, NumCols, Layout, Group>(*this);
  }
};
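// Illustrative sketch (not part of this header): joint_matrix objects are
// created inside a kernel from the sub-group handle; constructing one on the
// host throws. The tile shapes below are assumptions for the example only.
//
//   sycl::sub_group sg = item.get_sub_group();
//   joint_matrix<uint16_t, 8, 16> tA(sg);                           // A tile
//   joint_matrix<uint16_t, 16, 16, matrix_layout::packed_b> tB(sg); // B tile
//   joint_matrix<float, 8, 16> tC(sg);                              // accumulator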
template <typename Group, typename T, size_t NumRows, size_t NumCols,
          matrix_layout Layout = matrix_layout::row_major,
          access::address_space Space, access::decorated IsDecorated>
inline __SYCL_ALWAYS_INLINE void
joint_matrix_load(Group sg, joint_matrix<T, NumRows, NumCols, Layout, Group> &res,
                  multi_ptr<T, Space, IsDecorated> src, size_t stride,
                  matrix_layout MemL) {
#ifdef __SYCL_DEVICE_ONLY__
  T *Ptr = src.get();
  switch (MemL) {
  default:
    assert(false && "Invalid Memory Layout!");
  case matrix_layout::row_major:
    res.spvm = __spirv_JointMatrixLoadINTEL<
        T, NumRows, NumCols, spv_matrix_layout_traits<Layout>::value,
        spv_scope_traits<Group>::value>(Ptr, stride,
                                        __spv::MatrixLayout::RowMajor,
                                        spv_scope_traits<Group>::value);
    break;
  case matrix_layout::col_major:
    res.spvm = __spirv_JointMatrixLoadINTEL<
        T, NumRows, NumCols, spv_matrix_layout_traits<Layout>::value,
        spv_scope_traits<Group>::value>(Ptr, stride,
                                        __spv::MatrixLayout::ColumnMajor,
                                        spv_scope_traits<Group>::value);
    break;
  case matrix_layout::packed_a:
    res.spvm = __spirv_JointMatrixLoadINTEL<
        T, NumRows, NumCols, spv_matrix_layout_traits<Layout>::value,
        spv_scope_traits<Group>::value>(Ptr, stride,
                                        __spv::MatrixLayout::PackedA,
                                        spv_scope_traits<Group>::value);
    break;
  case matrix_layout::packed_b:
    res.spvm = __spirv_JointMatrixLoadINTEL<
        T, NumRows, NumCols, spv_matrix_layout_traits<Layout>::value,
        spv_scope_traits<Group>::value>(Ptr, stride,
                                        __spv::MatrixLayout::PackedB,
                                        spv_scope_traits<Group>::value);
    break;
  }
#else
  throw runtime_error("joint matrix is not supported on host device.",
                      PI_ERROR_INVALID_DEVICE);
#endif // __SYCL_DEVICE_ONLY__
}
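// Illustrative sketch (accessor name, tile origin and leading dimension K are
// assumptions): MemL describes how the tile is laid out in memory, while the
// joint_matrix template Layout describes its register-side layout.
//
//   joint_matrix_load(sg, tA,
//                     accA.get_pointer() + row * K + k, // top-left of the tile
//                     K /*stride in elements*/, matrix_layout::row_major);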
template <typename Group, typename T, size_t NumRows, size_t NumCols,
          matrix_layout MatL = matrix_layout::row_major,
          access::address_space Space, access::decorated IsDecorated>
inline __SYCL_ALWAYS_INLINE void
joint_matrix_store(Group sg, joint_matrix<T, NumRows, NumCols, MatL, Group> &src,
                   multi_ptr<T, Space, IsDecorated> res, size_t stride,
                   matrix_layout MemL) {
#ifdef __SYCL_DEVICE_ONLY__
  T *Ptr = res.get();
  switch (MemL) {
  default:
    assert(false && "Invalid Memory Layout!");
  case matrix_layout::row_major:
    __spirv_JointMatrixStoreINTEL<
        T, NumRows, NumCols, spv_matrix_layout_traits<MatL>::value,
        spv_scope_traits<Group>::value>(Ptr, src.spvm, stride,
                                        __spv::MatrixLayout::RowMajor,
                                        spv_scope_traits<Group>::value);
    break;
  case matrix_layout::col_major:
    __spirv_JointMatrixStoreINTEL<
        T, NumRows, NumCols, spv_matrix_layout_traits<MatL>::value,
        spv_scope_traits<Group>::value>(Ptr, src.spvm, stride,
                                        __spv::MatrixLayout::ColumnMajor,
                                        spv_scope_traits<Group>::value);
    break;
  case matrix_layout::packed_a:
    __spirv_JointMatrixStoreINTEL<
        T, NumRows, NumCols, spv_matrix_layout_traits<MatL>::value,
        spv_scope_traits<Group>::value>(Ptr, src.spvm, stride,
                                        __spv::MatrixLayout::PackedA,
                                        spv_scope_traits<Group>::value);
    break;
  case matrix_layout::packed_b:
    __spirv_JointMatrixStoreINTEL<
        T, NumRows, NumCols, spv_matrix_layout_traits<MatL>::value,
        spv_scope_traits<Group>::value>(Ptr, src.spvm, stride,
                                        __spv::MatrixLayout::PackedB,
                                        spv_scope_traits<Group>::value);
    break;
  }
#else
  throw runtime_error("joint matrix is not supported on host device.",
                      PI_ERROR_INVALID_DEVICE);
#endif // __SYCL_DEVICE_ONLY__
}
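// Illustrative sketch (accessor name and offsets are assumptions): writing the
// accumulator tile back to row-major memory whose leading dimension is N.
//
//   joint_matrix_store(sg, tC,
//                      accC.get_pointer() + row * N + col,
//                      N /*stride in elements*/, matrix_layout::row_major);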
template <typename Group, typename T1, typename T2, typename T3, size_t M,
          size_t K, size_t N, matrix_layout LayoutA, matrix_layout LayoutB,
          matrix_layout LayoutC>
inline __SYCL_ALWAYS_INLINE joint_matrix<T3, M, N, LayoutC, Group>
joint_matrix_mad(Group sg, joint_matrix<T1, M, K, LayoutA, Group> &mA,
                 joint_matrix<T2, K, N, LayoutB, Group> &mB,
                 joint_matrix<T3, M, N, LayoutC, Group> &mC) {
#ifdef __SYCL_DEVICE_ONLY__
  joint_matrix<T3, M, N, LayoutC, Group> res(sg);
  // Dispatch to the bfloat16/unsigned/signed variants of the SPIR-V builtin
  // based on the element types of the A and B operands.
  if constexpr (std::is_same<T1, uint16_t>::value &&
                std::is_same<T2, uint16_t>::value &&
                std::is_same<T3, float>::value)
    res.spvm = __spirv_JointMatrixMadINTEL(mA.spvm, mB.spvm, mC.spvm);
  else if constexpr (std::is_unsigned<T1>::value && std::is_unsigned<T2>::value)
    res.spvm = __spirv_JointMatrixUUMadINTEL(mA.spvm, mB.spvm, mC.spvm);
  else if constexpr (std::is_signed<T1>::value && std::is_unsigned<T2>::value)
    res.spvm = __spirv_JointMatrixSUMadINTEL(mA.spvm, mB.spvm, mC.spvm);
  else if constexpr (std::is_unsigned<T1>::value && std::is_signed<T2>::value)
    res.spvm = __spirv_JointMatrixUSMadINTEL(mA.spvm, mB.spvm, mC.spvm);
  else
    res.spvm = __spirv_JointMatrixMadINTEL(mA.spvm, mB.spvm, mC.spvm);
  return res;
#else
  throw runtime_error("joint matrix is not supported on host device.",
                      PI_ERROR_INVALID_DEVICE);
#endif // __SYCL_DEVICE_ONLY__
}
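// Illustrative sketch of the usual pattern (tile sizes TM/TN/TK, accessor
// names and index arithmetic are assumptions; the exact offsets and stride for
// the packed_b operand depend on its VNNI packing):
//
//   joint_matrix<uint16_t, TM, TK> tA(sg);
//   joint_matrix<uint16_t, TK, TN, matrix_layout::packed_b> tB(sg);
//   joint_matrix<float, TM, TN> tC(sg);
//   joint_matrix_load(sg, tC, accC.get_pointer() + row * N + col, N,
//                     matrix_layout::row_major);
//   for (size_t k = 0; k < K; k += TK) {
//     joint_matrix_load(sg, tA, accA.get_pointer() + row * K + k, K,
//                       matrix_layout::row_major);
//     joint_matrix_load(sg, tB, accB.get_pointer() + k * N + col, N,
//                       matrix_layout::packed_b);
//     tC = joint_matrix_mad(sg, tA, tB, tC);
//   }
//   joint_matrix_store(sg, tC, accC.get_pointer() + row * N + col, N,
//                      matrix_layout::row_major);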
template <typename Group, typename T, size_t NumRows, size_t NumCols,
          matrix_layout Layout = matrix_layout::row_major, typename T2>
inline __SYCL_ALWAYS_INLINE void
joint_matrix_fill(Group sg, joint_matrix<T, NumRows, NumCols, Layout, Group> &res,
                  const T2 v) {
  // The unused sg parameter mirrors the other joint_matrix entry points.
  (void)sg;
#ifdef __SYCL_DEVICE_ONLY__
  res.spvm = __spirv_CompositeConstruct<T, NumRows, NumCols,
                                        spv_matrix_layout_traits<Layout>::value,
                                        spv_scope_traits<Group>::value>(
      static_cast<T>(v));
#else
  (void)res;
  (void)v;
#endif // __SYCL_DEVICE_ONLY__
}
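// Illustrative sketch (names are assumptions): zero-initialise the accumulator
// before the multiply-add loop instead of loading it from memory.
//
//   joint_matrix<float, TM, TN> tC(sg);
//   joint_matrix_fill(sg, tC, 0.0f);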
// Proxy for a single matrix element owned by the calling work item.
template <typename T, size_t NumRows, size_t NumCols,
          matrix_layout Layout = matrix_layout::row_major,
          typename Group = sycl::sub_group>
class wi_element {
  joint_matrix<T, NumRows, NumCols, Layout, Group> &M;
  std::size_t idx;

public:
  wi_element(joint_matrix<T, NumRows, NumCols, Layout, Group> &Mat,
             std::size_t i)
      : M(Mat), idx(i) {}

  operator T() {
#ifdef __SYCL_DEVICE_ONLY__
    return __spirv_VectorExtractDynamic(M.spvm, idx);
#else
    throw runtime_error("joint matrix is not supported on host device.",
                        PI_ERROR_INVALID_DEVICE);
#endif // __SYCL_DEVICE_ONLY__
  }

  explicit operator bool() {
#ifdef __SYCL_DEVICE_ONLY__
    return __spirv_VectorExtractDynamic(M.spvm, idx) != static_cast<T>(0);
#else
    throw runtime_error("joint matrix is not supported on host device.",
                        PI_ERROR_INVALID_DEVICE);
#endif // __SYCL_DEVICE_ONLY__
  }

  template <typename T2> wi_element &operator=(const T2 &rhs) {
#ifdef __SYCL_DEVICE_ONLY__
    M.spvm = __spirv_VectorInsertDynamic(M.spvm, static_cast<T>(rhs), idx);
    return *this;
#else
    (void)rhs;
    throw runtime_error("joint matrix is not supported on host device.",
                        PI_ERROR_INVALID_DEVICE);
#endif // __SYCL_DEVICE_ONLY__
  }

  wi_element &
  operator=(const wi_element<T, NumRows, NumCols, Layout, Group> &rhs) {
#ifdef __SYCL_DEVICE_ONLY__
    M.spvm = __spirv_VectorInsertDynamic(
        M.spvm, __spirv_VectorExtractDynamic(rhs.M.spvm, rhs.idx), idx);
    return *this;
#else
    (void)rhs;
    throw runtime_error("joint matrix is not supported on host device.",
                        PI_ERROR_INVALID_DEVICE);
#endif // __SYCL_DEVICE_ONLY__
  }

#if __SYCL_DEVICE_ONLY__
#define OP(op)                                                                 \
  template <typename T2> wi_element &operator op##=(const T2 &rhs) {          \
    M.spvm = __spirv_VectorInsertDynamic(                                      \
        M.spvm,                                                                \
        static_cast<T>(__spirv_VectorExtractDynamic(M.spvm, idx)               \
                           op static_cast<T>(rhs)),                            \
        idx);                                                                  \
    return *this;                                                              \
  }
#else // __SYCL_DEVICE_ONLY__
#define OP(op)                                                                 \
  template <typename T2> wi_element &operator op##=(const T2 &rhs) {          \
    (void)rhs;                                                                 \
    throw runtime_error("joint matrix is not supported on host device.",      \
                        PI_ERROR_INVALID_DEVICE);                              \
  }
#endif // __SYCL_DEVICE_ONLY__
  OP(+)
  OP(-)
  OP(*)
  OP(/)
#undef OP
};
// Specialization for bf16 stored as uint16_t: element values are converted to
// and from float for arithmetic.
template <size_t NumRows, size_t NumCols, matrix_layout Layout, typename Group>
class wi_element<uint16_t, NumRows, NumCols, Layout, Group> {
  joint_matrix<uint16_t, NumRows, NumCols, Layout, Group> &M;
  std::size_t idx;

public:
  wi_element(joint_matrix<uint16_t, NumRows, NumCols, Layout, Group> &Mat,
             std::size_t i)
      : M(Mat), idx(i) {}

  operator uint16_t() {
#ifdef __SYCL_DEVICE_ONLY__
    return __spirv_VectorExtractDynamic(M.spvm, idx);
#else
    throw runtime_error("joint matrix is not supported on host device.",
                        PI_ERROR_INVALID_DEVICE);
#endif // __SYCL_DEVICE_ONLY__
  }

  explicit operator bool() {
#ifdef __SYCL_DEVICE_ONLY__
    return std::fabs(make_fp32(__spirv_VectorExtractDynamic(M.spvm, idx))) >=
           std::numeric_limits<float>::epsilon();
#else
    throw runtime_error("joint matrix is not supported on host device.",
                        PI_ERROR_INVALID_DEVICE);
#endif // __SYCL_DEVICE_ONLY__
  }

  wi_element &operator=(const uint16_t &rhs) {
#ifdef __SYCL_DEVICE_ONLY__
    M.spvm = __spirv_VectorInsertDynamic(M.spvm, rhs, idx);
    return *this;
#else
    (void)rhs;
    throw runtime_error("joint matrix is not supported on host device.",
                        PI_ERROR_INVALID_DEVICE);
#endif // __SYCL_DEVICE_ONLY__
  }

  wi_element &
  operator=(const wi_element<uint16_t, NumRows, NumCols, Layout, Group> &rhs) {
#ifdef __SYCL_DEVICE_ONLY__
    M.spvm = __spirv_VectorInsertDynamic(
        M.spvm, __spirv_VectorExtractDynamic(rhs.M.spvm, rhs.idx), idx);
    return *this;
#else
    (void)rhs;
    throw runtime_error("joint matrix is not supported on host device.",
                        PI_ERROR_INVALID_DEVICE);
#endif // __SYCL_DEVICE_ONLY__
  }
  // Reinterpret a bf16 value stored in a uint16_t as the corresponding fp32
  // value: the bf16 bits become the upper 16 bits of the float.
  static float make_fp32(uint16_t x) {
    unsigned int y = x;
    y = y << 16;
    float *res = reinterpret_cast<float *>(&y);
    return *res;
  }

  // Truncate an fp32 value to bf16 storage by keeping its upper 16 bits.
  static uint16_t make_bf16(float x) {
    int *res = reinterpret_cast<int *>(&x);
    *res = *res >> 16;
    return (uint16_t)*res;
  }
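  // Worked example of the two helpers (values chosen for illustration):
  //
  //   uint16_t b = make_bf16(1.5f); // 1.5f is 0x3FC00000; top 16 bits -> 0x3FC0
  //   float    f = make_fp32(b);    // 0x3FC0 << 16 -> 0x3FC00000 == 1.5f
  //
  // Bits below the top 16 are truncated, so make_bf16 is not a rounding
  // conversion.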
#if __SYCL_DEVICE_ONLY__
#define OP(op)                                                                 \
  wi_element &operator op##=(const uint16_t &rhs) {                           \
    M.spvm = __spirv_VectorInsertDynamic(                                      \
        M.spvm,                                                                \
        make_bf16(make_fp32(__spirv_VectorExtractDynamic(M.spvm, idx))         \
                      op make_fp32(rhs)),                                      \
        idx);                                                                  \
    return *this;                                                              \
  }
#else // __SYCL_DEVICE_ONLY__
#define OP(op)                                                                 \
  wi_element &operator op##=(const uint16_t &rhs) {                           \
    (void)rhs;                                                                 \
    throw runtime_error("joint matrix is not supported on host device.",      \
                        PI_ERROR_INVALID_DEVICE);                              \
  }
#endif // __SYCL_DEVICE_ONLY__
  OP(+)
  OP(-)
  OP(*)
  OP(/)
#undef OP

  // Converts the float result of a mixed bf16/float expression to the
  // requested result type; the uint16_t specialization re-packs to bf16 bits.
  template <typename T1, typename T2> struct Converter {
    static T2 convert(const T1 &from) { return static_cast<T2>(from); }
  };
  template <typename T> struct Converter<T, uint16_t> {
    static uint16_t convert(const T &from) { return make_bf16(from); }
  };
#if __SYCL_DEVICE_ONLY__
#define OP(input_type, type, op)                                               \
  friend type operator op(                                                     \
      const wi_element<uint16_t, NumRows, NumCols, Layout, Group> &lhs,       \
      const uint16_t &rhs) {                                                   \
    return Converter<input_type, type>::convert(make_fp32(                     \
        __spirv_VectorExtractDynamic(lhs.M.spvm, lhs.idx)) op make_fp32(rhs)); \
  }                                                                            \
  friend type operator op(                                                     \
      const uint16_t &lhs,                                                     \
      const wi_element<uint16_t, NumRows, NumCols, Layout, Group> &rhs) {      \
    return Converter<input_type, type>::convert(make_fp32(                     \
        __spirv_VectorExtractDynamic(rhs.M.spvm, rhs.idx)) op make_fp32(lhs)); \
  }
#else // __SYCL_DEVICE_ONLY__
#define OP(input_type, type, op)                                               \
  friend type operator op(                                                     \
      const wi_element<uint16_t, NumRows, NumCols, Layout, Group> &lhs,       \
      const uint16_t &rhs) {                                                   \
    (void)lhs;                                                                 \
    (void)rhs;                                                                 \
    throw runtime_error("joint matrix is not supported on host device.",      \
                        PI_ERROR_INVALID_DEVICE);                              \
  }                                                                            \
  friend type operator op(                                                     \
      const uint16_t &lhs,                                                     \
      const wi_element<uint16_t, NumRows, NumCols, Layout, Group> &rhs) {      \
    (void)lhs;                                                                 \
    (void)rhs;                                                                 \
    throw runtime_error("joint matrix is not supported on host device.",      \
                        PI_ERROR_INVALID_DEVICE);                              \
  }
#endif // __SYCL_DEVICE_ONLY__
  OP(float, uint16_t, +)
  OP(float, uint16_t, -)
  OP(float, uint16_t, *)
  OP(float, uint16_t, /)
#undef OP
};
// Specialization for the sycl::ext::oneapi::bfloat16 element type.
template <size_t NumRows, size_t NumCols, matrix_layout Layout, typename Group>
class wi_element<sycl::ext::oneapi::bfloat16, NumRows, NumCols, Layout, Group> {
  joint_matrix<sycl::ext::oneapi::bfloat16, NumRows, NumCols, Layout, Group> &M;
  std::size_t idx;

public:
  wi_element(joint_matrix<sycl::ext::oneapi::bfloat16, NumRows, NumCols,
                          Layout, Group> &Mat,
             std::size_t i)
      : M(Mat), idx(i) {}

  operator sycl::ext::oneapi::bfloat16() {
#ifdef __SYCL_DEVICE_ONLY__
    return __spirv_VectorExtractDynamic(M.spvm, idx);
#else
    throw runtime_error("joint matrix is not supported on host device.",
                        PI_ERROR_INVALID_DEVICE);
#endif // __SYCL_DEVICE_ONLY__
  }

  explicit operator bool() {
#ifdef __SYCL_DEVICE_ONLY__
    return std::fabs(static_cast<float>(__spirv_VectorExtractDynamic(
               M.spvm, idx))) >= std::numeric_limits<float>::epsilon();
#else
    throw runtime_error("joint matrix is not supported on host device.",
                        PI_ERROR_INVALID_DEVICE);
#endif // __SYCL_DEVICE_ONLY__
  }

  wi_element &operator=(const sycl::ext::oneapi::bfloat16 &rhs) {
#ifdef __SYCL_DEVICE_ONLY__
    M.spvm = __spirv_VectorInsertDynamic(M.spvm, rhs, idx);
    return *this;
#else
    (void)rhs;
    throw runtime_error("joint matrix is not supported on host device.",
                        PI_ERROR_INVALID_DEVICE);
#endif // __SYCL_DEVICE_ONLY__
  }

  wi_element &operator=(const wi_element<sycl::ext::oneapi::bfloat16, NumRows,
                                         NumCols, Layout, Group> &rhs) {
#ifdef __SYCL_DEVICE_ONLY__
    M.spvm = __spirv_VectorInsertDynamic(
        M.spvm, __spirv_VectorExtractDynamic(rhs.M.spvm, rhs.idx), idx);
    return *this;
#else
    (void)rhs;
    throw runtime_error("joint matrix is not supported on host device.",
                        PI_ERROR_INVALID_DEVICE);
#endif // __SYCL_DEVICE_ONLY__
  }
#if __SYCL_DEVICE_ONLY__
#define OP(opassign, op)                                                       \
  wi_element &operator opassign(const sycl::ext::oneapi::bfloat16 &rhs) {     \
    M.spvm = __spirv_VectorInsertDynamic(                                      \
        M.spvm, __spirv_VectorExtractDynamic(M.spvm, idx) op rhs, idx);        \
    return *this;                                                              \
  }
#else // __SYCL_DEVICE_ONLY__
#define OP(opassign, op)                                                       \
  wi_element &operator opassign(const sycl::ext::oneapi::bfloat16 &rhs) {     \
    (void)rhs;                                                                 \
    throw runtime_error("joint matrix is not supported on host device.",      \
                        PI_ERROR_INVALID_DEVICE);                              \
  }
#endif // __SYCL_DEVICE_ONLY__
  OP(+=, +)
  OP(-=, -)
  OP(*=, *)
  OP(/=, /)
#undef OP
#if __SYCL_DEVICE_ONLY__
#define OP(type, op)                                                           \
  friend type operator op(                                                     \
      const wi_element<sycl::ext::oneapi::bfloat16, NumRows, NumCols, Layout, \
                       Group> &lhs,                                            \
      const sycl::ext::oneapi::bfloat16 &rhs) {                                \
    return __spirv_VectorExtractDynamic(lhs.M.spvm, lhs.idx) op rhs;          \
  }                                                                            \
  friend type operator op(                                                     \
      const sycl::ext::oneapi::bfloat16 &lhs,                                  \
      const wi_element<sycl::ext::oneapi::bfloat16, NumRows, NumCols, Layout, \
                       Group> &rhs) {                                          \
    return __spirv_VectorExtractDynamic(rhs.M.spvm, rhs.idx) op lhs;          \
  }
  OP(sycl::ext::oneapi::bfloat16, +)
  OP(sycl::ext::oneapi::bfloat16, -)
  OP(sycl::ext::oneapi::bfloat16, *)
  OP(sycl::ext::oneapi::bfloat16, /)
#undef OP
#define OP(type, op)                                                           \
  friend type operator op(                                                     \
      const wi_element<sycl::ext::oneapi::bfloat16, NumRows, NumCols, Layout, \
                       Group> &lhs,                                            \
      const sycl::ext::oneapi::bfloat16 &rhs) {                                \
    return type{static_cast<float>(__spirv_VectorExtractDynamic(               \
        lhs.M.spvm, lhs.idx)) op static_cast<float>(rhs)};                     \
  }                                                                            \
  friend type operator op(                                                     \
      const sycl::ext::oneapi::bfloat16 &lhs,                                  \
      const wi_element<sycl::ext::oneapi::bfloat16, NumRows, NumCols, Layout, \
                       Group> &rhs) {                                          \
    return type{static_cast<float>(__spirv_VectorExtractDynamic(               \
        rhs.M.spvm, rhs.idx)) op static_cast<float>(lhs)};                     \
  }
  OP(bool, ==)
  OP(bool, !=)
  OP(bool, <)
  OP(bool, >)
  OP(bool, <=)
  OP(bool, >=)
#undef OP
#else // __SYCL_DEVICE_ONLY__
#define OP(type, op)                                                           \
  friend type operator op(const wi_element<sycl::ext::oneapi::bfloat16,       \
                                           NumRows, NumCols, Layout, Group> &,\
                          const sycl::ext::oneapi::bfloat16 &) {               \
    throw runtime_error("joint matrix is not supported on host device.",      \
                        PI_ERROR_INVALID_DEVICE);                              \
  }                                                                            \
  friend type operator op(                                                     \
      const sycl::ext::oneapi::bfloat16 &,                                     \
      const wi_element<sycl::ext::oneapi::bfloat16, NumRows, NumCols, Layout, \
                       Group> &) {                                             \
    throw runtime_error("joint matrix is not supported on host device.",      \
                        PI_ERROR_INVALID_DEVICE);                              \
  }
  OP(sycl::ext::oneapi::bfloat16, +)
  OP(sycl::ext::oneapi::bfloat16, -)
  OP(sycl::ext::oneapi::bfloat16, *)
  OP(sycl::ext::oneapi::bfloat16, /)
  OP(bool, ==)
  OP(bool, !=)
  OP(bool, <)
  OP(bool, >)
  OP(bool, <=)
  OP(bool, >=)
#undef OP
#endif // __SYCL_DEVICE_ONLY__
};
// Per-work-item view of the elements of a joint_matrix.
template <typename T, size_t NumRows, size_t NumCols, matrix_layout Layout,
          typename Group>
class wi_data {
  joint_matrix<T, NumRows, NumCols, Layout, Group> &M;

public:
  wi_data(joint_matrix<T, NumRows, NumCols, Layout, Group> &Mat) : M(Mat) {}

  // Number of matrix elements owned by the calling work item.
  size_t length() {
#ifdef __SYCL_DEVICE_ONLY__
    return __spirv_JointMatrixWorkItemLengthINTEL(M.spvm);
#else
    throw runtime_error("joint matrix is not supported on host device.",
                        PI_ERROR_INVALID_DEVICE);
#endif // __SYCL_DEVICE_ONLY__
  }

  wi_element<T, NumRows, NumCols, Layout, Group> operator[](size_t i) {
    return wi_element<T, NumRows, NumCols, Layout, Group>(M, i);
  }
};
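// Illustrative sketch (variable names and the alpha factor are assumptions):
// each work item iterates over its slice of the tile via get_wi_data(), for
// example to scale the accumulator after joint_matrix_mad:
//
//   auto data = tC.get_wi_data();
//   for (size_t i = 0; i < data.length(); ++i)
//     data[i] *= alpha;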
#undef SPV_MATRIX_LAYOUT_TRAITS

} // namespace ext::oneapi::experimental::matrix