DPC++ Runtime
Runtime libraries for oneAPI DPC++
dpas.hpp
Go to the documentation of this file.
1 //==----------------- xmx/dpas.hpp - DPC++ Explicit SIMD API ---------------==//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 // Explicit SIMD API for DPAS Intel Xe Matrix eXtension.
9 //===----------------------------------------------------------------------===//
10 
11 #pragma once
12 
18 
19 namespace sycl {
20 inline namespace _V1 {
21 
22 namespace ext::intel::esimd::xmx {
23 
24 namespace detail {
25 
26 template <typename T> constexpr dpas_argument_type dpas_precision_from_type() {
27  if constexpr (std::is_same_v<T,
30  else if constexpr (std::is_same_v<T, sycl::half>)
32  else if constexpr (std::is_same_v<T, sycl::ext::oneapi::bfloat16>)
34  else if constexpr (std::is_same_v<T, unsigned char>)
36  else if constexpr (__ESIMD_DNS::is_type<T, char, signed char>())
38  else
40 }
41 
42 template <dpas_argument_type T> constexpr int dpas_bitsize_from_precision() {
43  if constexpr (T == dpas_argument_type::u2 || T == dpas_argument_type::s2)
44  return 2;
45  else if constexpr (T == dpas_argument_type::u4 || T == dpas_argument_type::s4)
46  return 4;
47  else if constexpr (T == dpas_argument_type::u8 || T == dpas_argument_type::s8)
48  return 8;
49  else if constexpr (T == dpas_argument_type::bf16 ||
51  return 16;
52  else if constexpr (T == dpas_argument_type::tf32)
53  return 32;
54  else
55  return -1;
56 }
57 
/// Compile-time check of the repeat count used by a DPAS/DPASW operation.
///
/// The function has no runtime effect: instantiating it either compiles
/// cleanly or triggers one of the static_asserts below.
///
/// @tparam RepeatCount Number of rows of matrix A; must be in [1, 8].
/// @tparam AElemBitSize Size in bits of one element of matrix A.
/// @tparam BElemBitSize Size in bits of one element of matrix B.
/// @tparam IsDPASW true when the check is done for DPASW, which applies extra
/// restrictions whenever the repeat count is not 8.
template <int RepeatCount, int AElemBitSize, int BElemBitSize, bool IsDPASW>
constexpr void verify_repeat_count() {
  static_assert(RepeatCount >= 1 && RepeatCount <= 8,
                "Repeat count must be within 1 to 8 range");

  // The additional restrictions only apply to DPASW with a partial repeat
  // count; RepeatCount == 8 is accepted for every element-size combination.
  if constexpr (IsDPASW && RepeatCount != 8) {
    // 2-bit A elements cannot be combined with B elements wider than 4 bits.
    static_assert(!(AElemBitSize == 2 && BElemBitSize > 4),
                  "Unsupported repeat count for DPASW operation");

    // Any repeat count other than 4 additionally rejects 2-bit A elements and
    // 4-bit A elements paired with B elements wider than 4 bits.
    static_assert(
        RepeatCount == 4 ||
            (AElemBitSize != 2 && (AElemBitSize != 4 || BElemBitSize <= 4)),
        "Unsupported repeat count for DPASW operation");
  }
}
73 
74 template <int SystolicDepth, int RepeatCount, typename T, typename CT,
75  typename BT, typename AT, dpas_argument_type BPrecision,
76  dpas_argument_type APrecision, int BN, int AN, bool IsDPASW = false>
78 
79  static_assert(SystolicDepth == 8, "Systolic depth must be equal to 8");
80  static_assert(
81  APrecision != dpas_argument_type::Invalid &&
82  BPrecision != dpas_argument_type::Invalid,
83  "The types of dpas arguments are either incorrect or cannot be deduced."
84  "Fix the types and/or explicitly specify them.");
85 
86  constexpr int AElemBitSize = dpas_bitsize_from_precision<APrecision>();
87  constexpr int BElemBitSize = dpas_bitsize_from_precision<BPrecision>();
88  static_assert(AElemBitSize != -1 && BElemBitSize != -1,
89  "Cannot deduce element size of input arguments");
90  verify_repeat_count<RepeatCount, AElemBitSize, BElemBitSize, IsDPASW>();
91 
92  constexpr int MaxElemBitSize =
93  AElemBitSize > BElemBitSize ? AElemBitSize : BElemBitSize;
94  constexpr int MaxElemsInDword = 32 / MaxElemBitSize;
95  constexpr int OpsPerChannel =
96  MaxElemsInDword > 8 ? 8 : (MaxElemsInDword < 1 ? 1 : MaxElemsInDword);
97 
98  // A(_Mx_K) * B(_Kx_N) + C(_Mx_N)
99  // where:
100  // _M = RepeatCount;
101  // _K = SystolicDepth * OpsPerChannel;
102  // _N = ExecutionSize (unknown, but deducible), must be 8 or 16.
103  constexpr int _M = RepeatCount;
104  constexpr int _K = SystolicDepth * OpsPerChannel;
105 
106  // Compute _N (aka ExecutionSize) from the matrix B.
107  // It has _K*_N elements of BPrecision type, and BN elements of BT type
108  // hold those _K*_N*BPrecision bits, which let's us compute _N.
109  constexpr int BMatrixBitSize = sizeof(BT) * BN * 8;
110  constexpr int BNumElems = BMatrixBitSize / BElemBitSize;
111  constexpr int _N = BNumElems / _K;
112  static_assert(_K * _N == BNumElems, "Cannot deduce the execution size.");
113 
114  // Now verify that AN elements of AT type hold exactly _M*_K elements
115  // of APrecision type/size. Similarly for B: BN elements of BT type must
116  // hold _K*_N elements of BPrecision type/size.
117  // DPASW accepts 2x less expected AN elements than regular DPAS.
118  constexpr int AFactorForDPASW = IsDPASW ? 2 : 1;
119  static_assert(_M * _K * AElemBitSize == AN * sizeof(AT) * 8 * AFactorForDPASW,
120  "The first matrix multiplier has wrong size.");
121  static_assert(_K * _N * BElemBitSize == BN * sizeof(BT) * 8,
122  "The second matrix multiplier has wrong size.");
123 
124  // Execution size may be 8 or 16 depending on the target device.
125  // User must check if used execution size is supported before calling DPAS.
126  constexpr int ExecutionSize = _N;
127 
128  static_assert(ExecutionSize == 8 || (!IsDPASW && ExecutionSize == 16),
129  "Execution size must be 8 or 16 for DPAS and 8 for DPASW.");
130 
131  if constexpr (APrecision == dpas_argument_type::fp16 ||
132  BPrecision == dpas_argument_type::fp16) {
133  if constexpr (ExecutionSize == 8) {
134  static_assert(APrecision == BPrecision &&
135  __ESIMD_DNS::is_type<T, float>() &&
136  __ESIMD_DNS::is_type<CT, float>(),
137  "Unsupported DPAS types! The supported types are:\n"
138  " Result | C | B | A \n"
139  " f | f | hf | hf \n");
140  } else {
141  static_assert(APrecision == BPrecision &&
142  __ESIMD_DNS::is_type<T, float, sycl::half>() &&
143  __ESIMD_DNS::is_type<CT, float, sycl::half>(),
144  "Unsupported DPAS types! The supported types are:\n"
145  " Result | C | B | A \n"
146  " f, hf | f, hf | hf | hf \n");
147  }
148  } else if constexpr (APrecision == dpas_argument_type::bf16 ||
149  BPrecision == dpas_argument_type::bf16) {
150  using bfloat16 = sycl::ext::oneapi::bfloat16;
151  if constexpr (ExecutionSize == 8) {
152  static_assert(APrecision == BPrecision &&
153  __ESIMD_DNS::is_type<T, float, bfloat16>() &&
154  __ESIMD_DNS::is_type<CT, float, bfloat16>(),
155  "Unsupported DPAS types! The supported types are:\n"
156  " Result | C | B | A \n"
157  " f | f | bf | bf \n");
158  } else {
159  static_assert(APrecision == BPrecision &&
160  __ESIMD_DNS::is_type<T, float, bfloat16>() &&
161  __ESIMD_DNS::is_type<CT, float, bfloat16>(),
162  "Unsupported DPAS types! The supported types are:\n"
163  " Result | C | B | A \n"
164  " f, bf | f, bf | bf | bf \n");
165  }
166  } else if constexpr (APrecision == dpas_argument_type::tf32 ||
167  BPrecision == dpas_argument_type::tf32) {
168  static_assert(ExecutionSize == 16,
169  "tf32 type can be used only with ExecutionSize=16");
170  static_assert(APrecision == BPrecision && std::is_same_v<T, float> &&
171  std::is_same_v<CT, float>,
172  "Unsupported DPAS types! The supported types are:\n"
173  " Result | C | B | A \n"
174  " f | f | tf32 | tf32 \n");
175  } else {
176  static_assert((APrecision == dpas_argument_type::u2 ||
177  APrecision == dpas_argument_type::s2 ||
178  APrecision == dpas_argument_type::u4 ||
179  APrecision == dpas_argument_type::s4 ||
180  APrecision == dpas_argument_type::u8 ||
181  APrecision == dpas_argument_type::s8) &&
182  (BPrecision == dpas_argument_type::u2 ||
183  BPrecision == dpas_argument_type::s2 ||
184  BPrecision == dpas_argument_type::u4 ||
185  BPrecision == dpas_argument_type::s4 ||
186  BPrecision == dpas_argument_type::u8 ||
187  BPrecision == dpas_argument_type::s8),
188  "Unsupported DPAS types! The supported types are:\n"
189  " Result | C | B | A \n"
190  " ud, d | ud, d | ub,b,u4,s4,u2,s2 | ub,b,u4,s4,u2,s2 \n");
191  }
192  return ExecutionSize;
193 }
194 
195 } // namespace detail
196 
200 
210 template <
211  int SystolicDepth, int RepeatCount, typename T, typename CT, typename BT,
212  typename AT,
213  dpas_argument_type BPrecision = detail::dpas_precision_from_type<BT>(),
214  dpas_argument_type APrecision = detail::dpas_precision_from_type<AT>(),
215  int N, int BN, int AN>
216 __ESIMD_NS::simd<T, N> dpas(__ESIMD_NS::simd<CT, N> C,
217  __ESIMD_NS::simd<BT, BN> B,
218  __ESIMD_NS::simd<AT, AN> A) {
220  SystolicDepth, RepeatCount, T, CT, BT, AT, BPrecision, APrecision, BN,
221  AN>();
222 
223  using MsgT = int;
224  constexpr int ANCasted = AN * sizeof(AT) / sizeof(MsgT);
225  constexpr int BNCasted = BN * sizeof(BT) / sizeof(MsgT);
226  __ESIMD_NS::simd<MsgT, ANCasted> ACasted = A.template bit_cast_view<MsgT>();
227  __ESIMD_NS::simd<MsgT, BNCasted> BCasted = B.template bit_cast_view<MsgT>();
228  using CRawT = typename __ESIMD_NS::simd<CT, N>::raw_element_type;
229  using RawT = typename __ESIMD_NS::simd<T, N>::raw_element_type;
230  return __esimd_dpas2<BPrecision, APrecision, SystolicDepth, RepeatCount, RawT,
231  CRawT, MsgT, MsgT, N, BNCasted, ANCasted>(
232  C.data(), BCasted.data(), ACasted.data());
233 }
234 
241 template <
242  int SystolicDepth, int RepeatCount, typename T, typename BT, typename AT,
243  dpas_argument_type BPrecision = detail::dpas_precision_from_type<BT>(),
244  dpas_argument_type APrecision = detail::dpas_precision_from_type<AT>(),
245  int BN, int AN>
246 auto dpas(__ESIMD_NS::simd<BT, BN> B, __ESIMD_NS::simd<AT, AN> A) {
247 
248  constexpr int ExecutionSize =
249  detail::verify_parameters_and_deduce_exec_size<SystolicDepth, RepeatCount,
250  T, T, BT, AT, BPrecision,
251  APrecision, BN, AN>();
252  // Result(_Mx_N) = A(_Mx_K) * B(_Kx_N)
253  // where:
254  // _M = RepeatCount;
255  // _K = SystolicDepth * OpsPerChannel;
256  // _N = ExecutionSize (unknown, but deducible), must be 8 or 16.
257  constexpr int ResultN = RepeatCount * ExecutionSize;
258 
259  using MsgT = int;
260  constexpr int ANCasted = AN * sizeof(AT) / sizeof(MsgT);
261  constexpr int BNCasted = BN * sizeof(BT) / sizeof(MsgT);
262  __ESIMD_NS::simd<MsgT, ANCasted> ACasted = A.template bit_cast_view<MsgT>();
263  __ESIMD_NS::simd<MsgT, BNCasted> BCasted = B.template bit_cast_view<MsgT>();
264 
265  constexpr int Info = (RepeatCount << 24) + (SystolicDepth << 16) +
266  ((int)APrecision << 8) + (int)BPrecision;
267  using RawT = typename __ESIMD_NS::simd<T, ResultN>::raw_element_type;
268  __ESIMD_NS::simd<T, ResultN> Result =
269  __esimd_dpas_nosrc0<Info, RawT, MsgT, MsgT, ResultN, BNCasted, ANCasted>(
270  BCasted.data(), ACasted.data());
271  return Result;
272 }
273 
281 template <
282  int SystolicDepth, int RepeatCount, typename T, typename BT, typename AT,
283  dpas_argument_type BPrecision = detail::dpas_precision_from_type<BT>(),
284  dpas_argument_type APrecision = detail::dpas_precision_from_type<AT>(),
285  int N, int BN, int AN>
286 __ESIMD_NS::simd<T, N> dpasw(__ESIMD_NS::simd<T, N> C,
287  __ESIMD_NS::simd<BT, BN> B,
288  __ESIMD_NS::simd<AT, AN> A) {
289 
290  constexpr bool IsDPASW = true;
292  SystolicDepth, RepeatCount, T, T, BT, AT, BPrecision, APrecision, BN, AN,
293  IsDPASW>();
294 
295  constexpr int ANCasted = AN * sizeof(AT) / sizeof(int);
296  constexpr int BNCasted = BN * sizeof(BT) / sizeof(int);
297  __ESIMD_NS::simd<int, ANCasted> ACasted = A.template bit_cast_view<int>();
298  __ESIMD_NS::simd<int, BNCasted> BCasted = B.template bit_cast_view<int>();
299 
300  using RawT = typename __ESIMD_NS::simd<T, N>::raw_element_type;
301  constexpr int Info = (RepeatCount << 24) + (SystolicDepth << 16) +
302  ((int)APrecision << 8) + (int)BPrecision;
303  return __esimd_dpasw<Info, RawT, int, int, N, BNCasted, ANCasted>(
304  C.data(), BCasted.data(), ACasted.data());
305 }
306 
313 template <
314  int SystolicDepth, int RepeatCount, typename T, typename BT, typename AT,
315  dpas_argument_type BPrecision = detail::dpas_precision_from_type<BT>(),
316  dpas_argument_type APrecision = detail::dpas_precision_from_type<AT>(),
317  int BN, int AN>
318 auto dpasw(__ESIMD_NS::simd<BT, BN> B, __ESIMD_NS::simd<AT, AN> A) {
319 
320  constexpr bool IsDPASW = true;
321  constexpr int ExecutionSize = detail::verify_parameters_and_deduce_exec_size<
322  SystolicDepth, RepeatCount, T, T, BT, AT, BPrecision, APrecision, BN, AN,
323  IsDPASW>();
324 
325  // Result(_Mx_N) = A(_Mx_K) * B(_Kx_N)
326  // where:
327  // _M = RepeatCount;
328  // _K = SystolicDepth * OpsPerChannel;
329  // _N = ExecutionSize (unknown, but deducible), must be 8 or 16.
330  constexpr int ResultN = RepeatCount * ExecutionSize;
331 
332  constexpr int ANCasted = AN * sizeof(AT) / sizeof(int);
333  constexpr int BNCasted = BN * sizeof(BT) / sizeof(int);
334  __ESIMD_NS::simd<int, ANCasted> ACasted = A.template bit_cast_view<int>();
335  __ESIMD_NS::simd<int, BNCasted> BCasted = B.template bit_cast_view<int>();
336 
337  using RawT = typename __ESIMD_NS::simd<T, ResultN>::raw_element_type;
338  constexpr int Info = (RepeatCount << 24) + (SystolicDepth << 16) +
339  ((int)APrecision << 8) + (int)BPrecision;
340  __ESIMD_NS::simd<T, ResultN> Result =
341  __esimd_dpasw_nosrc0<Info, RawT, int, int, ResultN, BNCasted, ANCasted>(
342  BCasted.data(), ACasted.data());
343  return Result;
344 }
345 
347 
348 } // namespace ext::intel::esimd::xmx
349 } // namespace _V1
350 } // namespace sycl
sycl::ext::intel::esimd::simd< T, N > dpasw(sycl::ext::intel::esimd::simd< T, N > C, sycl::ext::intel::esimd::simd< BT, BN > B, sycl::ext::intel::esimd::simd< AT, AN > A)
DPAS (Dot Product Accumulate Systolic): computes the result of the matrix operation Result = C + A x B.
Definition: dpas.hpp:286
sycl::ext::intel::esimd::simd< T, N > dpas(sycl::ext::intel::esimd::simd< CT, N > C, sycl::ext::intel::esimd::simd< BT, BN > B, sycl::ext::intel::esimd::simd< AT, AN > A)
Definition: dpas.hpp:216
constexpr int verify_parameters_and_deduce_exec_size()
Definition: dpas.hpp:77
constexpr int dpas_bitsize_from_precision()
Definition: dpas.hpp:42
constexpr dpas_argument_type dpas_precision_from_type()
Definition: dpas.hpp:26
dpas_argument_type
Describes the element types in the input matrices.
Definition: common.hpp:22
Definition: access.hpp:18