20 inline namespace _V1 {
22 namespace ext::intel::esimd::xmx {
// Compile-time dispatch on the element type T. The visible branches test for
// sycl::half, sycl::ext::oneapi::bfloat16, unsigned char, and char/signed
// char (the latter through __ESIMD_DNS::is_type).
// NOTE(review): the function header, the remainder of the first condition and
// every branch body are missing from this listing. Presumably this is
// dpas_precision_from_type(), mapping T to a dpas_argument_type value --
// confirm against the complete header.
27 if constexpr (std::is_same_v<T,
30 else if constexpr (std::is_same_v<T, sycl::half>)
32 else if constexpr (std::is_same_v<T, sycl::ext::oneapi::bfloat16>)
34 else if constexpr (std::is_same_v<T, unsigned char>)
36 else if constexpr (__ESIMD_DNS::is_type<T, char, signed char>())
// Compile-time validation of the repeat count for a DPAS/DPASW operation.
// All checks are static_asserts over the template parameters:
//   RepeatCount  - repeat count being validated
//   AElemBitSize - bit size of one element of operand A
//   BElemBitSize - bit size of one element of operand B
//   IsDPASW      - true when validating for DPASW
// NOTE(review): the signature line is missing from this listing; the name
// verify_repeat_count is inferred from the call that passes these same four
// template arguments -- confirm against the complete header.
58 template <
int RepeatCount,
int AElemBitSize,
int BElemBitSize,
bool IsDPASW>
// The repeat count is always limited to the inclusive range [1, 8].
60 static_assert(RepeatCount >= 1 && RepeatCount <= 8,
61 "Repeat count must be within 1 to 8 range");
// Additional restrictions apply only to DPASW with a repeat count other than 8.
63 if constexpr (IsDPASW && RepeatCount != 8) {
// Reject 2-bit A elements combined with B elements wider than 4 bits.
64 static_assert(!(AElemBitSize == 2 && BElemBitSize > 4),
65 "Unsupported repeat count for DPASW operation");
// NOTE(review): the opening of this second static_assert is missing from the
// listing. The visible part of its condition requires that A is not 2-bit
// and that, when A is 4-bit, B is at most 4 bits wide.
69 (AElemBitSize != 2 && (AElemBitSize != 4 || BElemBitSize <= 4)),
70 "Unsupported repeat count for DPASW operation");
// Validates the template parameters of a DPAS/DPASW call at compile time and
// returns the deduced execution size (see the final `return ExecutionSize;`).
// NOTE(review): the rest of the template parameter list, the function
// signature, the type-dispatch branch headers and several static_assert
// openers are missing from this listing; comments below describe only what
// the visible lines establish.
74 template <
int SystolicDepth,
int RepeatCount,
typename T,
typename CT,
// Only a systolic depth of 8 is accepted.
79 static_assert(SystolicDepth == 8,
"Systolic depth must be equal to 8");
// NOTE(review): tail of an assertion whose condition is not visible; the
// message indicates it fires when the argument precision types are wrong or
// could not be deduced.
83 "The types of dpas arguments are either incorrect or cannot be deduced."
84 "Fix the types and/or explicitly specify them.");
// Element bit sizes of A and B are derived from their precision values;
// -1 is treated as "could not be deduced" by the assertion that follows.
86 constexpr
int AElemBitSize = dpas_bitsize_from_precision<APrecision>();
87 constexpr
int BElemBitSize = dpas_bitsize_from_precision<BPrecision>();
88 static_assert(AElemBitSize != -1 && BElemBitSize != -1,
89 "Cannot deduce element size of input arguments");
90 verify_repeat_count<RepeatCount, AElemBitSize, BElemBitSize, IsDPASW>();
// OpsPerChannel is the number of elements of the wider operand type that fit
// into 32 bits, clamped to the range [1, 8].
92 constexpr
int MaxElemBitSize =
93 AElemBitSize > BElemBitSize ? AElemBitSize : BElemBitSize;
94 constexpr
int MaxElemsInDword = 32 / MaxElemBitSize;
95 constexpr
int OpsPerChannel =
96 MaxElemsInDword > 8 ? 8 : (MaxElemsInDword < 1 ? 1 : MaxElemsInDword);
// Dimensions used by the size checks below: _M is the repeat count and _K is
// the systolic depth multiplied by OpsPerChannel.
103 constexpr
int _M = RepeatCount;
104 constexpr
int _K = SystolicDepth * OpsPerChannel;
// _N is deduced from operand B: its total size in bits divided by the element
// bit size gives the element count, which must be an exact multiple of _K.
109 constexpr
int BMatrixBitSize =
sizeof(BT) * BN * 8;
110 constexpr
int BNumElems = BMatrixBitSize / BElemBitSize;
111 constexpr
int _N = BNumElems / _K;
112 static_assert(_K * _N == BNumElems,
"Cannot deduce the execution size.");
// Size checks. For DPASW the bit size of the supplied A operand is doubled
// before being compared with _M * _K * AElemBitSize, i.e. A is expected to
// carry half as many bits as in the DPAS case.
118 constexpr
int AFactorForDPASW = IsDPASW ? 2 : 1;
119 static_assert(_M * _K * AElemBitSize == AN *
sizeof(AT) * 8 * AFactorForDPASW,
120 "The first matrix multiplier has wrong size.");
121 static_assert(_K * _N * BElemBitSize == BN *
sizeof(BT) * 8,
122 "The second matrix multiplier has wrong size.");
// The execution size equals _N. It must be 8, or 16 when this is not DPASW.
126 constexpr
int ExecutionSize = _N;
128 static_assert(ExecutionSize == 8 || (!IsDPASW && ExecutionSize == 16),
129 "Execution size must be 8 or 16 for DPAS and 8 for DPASW.");
// Half-precision (hf) operands, judging by the diagnostic tables.
// NOTE(review): the enclosing precision-dispatch condition is missing from
// this listing. With ExecutionSize == 8 both T and CT must be float.
133 if constexpr (ExecutionSize == 8) {
134 static_assert(APrecision == BPrecision &&
135 __ESIMD_DNS::is_type<T, float>() &&
136 __ESIMD_DNS::is_type<CT, float>(),
137 "Unsupported DPAS types! The supported types are:\n"
138 " Result | C | B | A \n"
139 " f | f | hf | hf \n");
// NOTE(review): the branch header preceding this assertion is missing
// (presumably the other execution size). Here T and CT may each be float or
// sycl::half.
141 static_assert(APrecision == BPrecision &&
142 __ESIMD_DNS::is_type<T, float, sycl::half>() &&
143 __ESIMD_DNS::is_type<CT, float, sycl::half>(),
144 "Unsupported DPAS types! The supported types are:\n"
145 " Result | C | B | A \n"
146 " f, hf | f, hf | hf | hf \n");
// bfloat16 (bf) operands, judging by the diagnostic tables.
// NOTE(review): in the ExecutionSize == 8 branch the condition accepts float
// or bfloat16 for T and CT, while the message table lists only "f" for
// Result and C -- confirm which of the two is intended.
151 if constexpr (ExecutionSize == 8) {
152 static_assert(APrecision == BPrecision &&
153 __ESIMD_DNS::is_type<T, float, bfloat16>() &&
154 __ESIMD_DNS::is_type<CT, float, bfloat16>(),
155 "Unsupported DPAS types! The supported types are:\n"
156 " Result | C | B | A \n"
157 " f | f | bf | bf \n");
159 static_assert(APrecision == BPrecision &&
160 __ESIMD_DNS::is_type<T, float, bfloat16>() &&
161 __ESIMD_DNS::is_type<CT, float, bfloat16>(),
162 "Unsupported DPAS types! The supported types are:\n"
163 " Result | C | B | A \n"
164 " f, bf | f, bf | bf | bf \n");
// tf32 operands: only ExecutionSize == 16 is accepted, and both T and CT must
// be float.
168 static_assert(ExecutionSize == 16,
169 "tf32 type can be used only with ExecutionSize=16");
170 static_assert(APrecision == BPrecision && std::is_same_v<T, float> &&
171 std::is_same_v<CT, float>,
172 "Unsupported DPAS types! The supported types are:\n"
173 " Result | C | B | A \n"
174 " f | f | tf32 | tf32 \n");
// NOTE(review): tail of an assertion whose condition is not visible; the
// message table lists integer result/accumulator types with 8-, 4- and 2-bit
// integer operands.
188 "Unsupported DPAS types! The supported types are:\n"
189 " Result | C | B | A \n"
190 " ud, d | ud, d | ub,b,u4,s4,u2,s2 | ub,b,u4,s4,u2,s2 \n");
// All checks passed: hand the deduced execution size back to the caller.
192 return ExecutionSize;
// dpas overload taking an accumulator operand C.
// Parameters: C (simd<CT, N>), B (simd<BT, BN>), A (simd<AT, AN>).
// Returns simd<T, N>, produced by the __esimd_dpas2 intrinsic.
// NOTE(review): the `template <` opener, the precision template parameters,
// the MsgT alias and the head of the parameter-verification call are missing
// from this listing.
211 int SystolicDepth,
int RepeatCount,
typename T,
typename CT,
typename BT,
215 int N,
int BN,
int AN>
216 __ESIMD_NS::simd<T, N>
dpas(__ESIMD_NS::simd<CT, N> C,
217 __ESIMD_NS::simd<BT, BN> B,
218 __ESIMD_NS::simd<AT, AN> A) {
// NOTE(review): argument list of a call whose start is not visible;
// presumably verify_parameters_and_deduce_exec_size -- confirm.
220 SystolicDepth, RepeatCount, T, CT, BT, AT, BPrecision, APrecision, BN,
// Element counts of A and B when the same bytes are expressed in MsgT units.
224 constexpr
int ANCasted = AN *
sizeof(AT) /
sizeof(MsgT);
225 constexpr
int BNCasted = BN *
sizeof(BT) /
sizeof(MsgT);
// A and B are reinterpreted as MsgT vectors through bit_cast_view and used to
// initialize simd objects of the recomputed lengths.
226 __ESIMD_NS::simd<MsgT, ANCasted> ACasted =
A.template bit_cast_view<MsgT>();
227 __ESIMD_NS::simd<MsgT, BNCasted> BCasted =
B.template bit_cast_view<MsgT>();
// Raw element types of the accumulator and result vectors, passed to the
// intrinsic as template arguments.
228 using CRawT =
typename __ESIMD_NS::simd<CT, N>::raw_element_type;
229 using RawT =
typename __ESIMD_NS::simd<T, N>::raw_element_type;
// The intrinsic receives the operands in the order C, B, A.
230 return __esimd_dpas2<BPrecision, APrecision, SystolicDepth, RepeatCount, RawT,
231 CRawT, MsgT, MsgT, N, BNCasted, ANCasted>(
232 C.data(), BCasted.data(), ACasted.data());
// dpas overload without an accumulator operand: takes only B and A and
// forwards them to the __esimd_dpas_nosrc0 intrinsic.
// NOTE(review): the `template <` opener, the remaining template parameters,
// the MsgT alias, the head of the execution-size call and the return
// statement are missing from this listing.
242 int SystolicDepth,
int RepeatCount,
typename T,
typename BT,
typename AT,
246 auto dpas(__ESIMD_NS::simd<BT, BN> B, __ESIMD_NS::simd<AT, AN> A) {
// ExecutionSize is taken from a compile-time call whose name is not visible
// here (presumably verify_parameters_and_deduce_exec_size -- confirm). T is
// passed for both the result and the accumulator type.
248 constexpr
int ExecutionSize =
250 T, T, BT, AT, BPrecision,
251 APrecision, BN, AN>();
// The result holds RepeatCount * ExecutionSize elements.
257 constexpr
int ResultN = RepeatCount * ExecutionSize;
// Element counts of A and B when the same bytes are expressed in MsgT units.
260 constexpr
int ANCasted = AN *
sizeof(AT) /
sizeof(MsgT);
261 constexpr
int BNCasted = BN *
sizeof(BT) /
sizeof(MsgT);
// A and B are reinterpreted as MsgT vectors through bit_cast_view.
262 __ESIMD_NS::simd<MsgT, ANCasted> ACasted =
A.template bit_cast_view<MsgT>();
263 __ESIMD_NS::simd<MsgT, BNCasted> BCasted =
B.template bit_cast_view<MsgT>();
// Info packs the operation parameters into one int: RepeatCount shifted left
// by 24, SystolicDepth by 16, APrecision by 8, and BPrecision unshifted.
265 constexpr
int Info = (RepeatCount << 24) + (SystolicDepth << 16) +
266 ((int)APrecision << 8) + (int)BPrecision;
267 using RawT =
typename __ESIMD_NS::simd<T, ResultN>::raw_element_type;
// The intrinsic receives the operands in the order B, A.
268 __ESIMD_NS::simd<T, ResultN> Result =
269 __esimd_dpas_nosrc0<Info, RawT, MsgT, MsgT, ResultN, BNCasted, ANCasted>(
270 BCasted.data(), ACasted.data());
// dpasw overload taking an accumulator operand C.
// Parameters: C (simd<T, N>), B (simd<BT, BN>), A (simd<AT, AN>).
// Returns simd<T, N>, produced by the __esimd_dpasw intrinsic.
// NOTE(review): the `template <` opener, the precision template parameters
// and the head of the parameter-verification call are missing from this
// listing.
282 int SystolicDepth,
int RepeatCount,
typename T,
typename BT,
typename AT,
285 int N,
int BN,
int AN>
286 __ESIMD_NS::simd<T, N>
dpasw(__ESIMD_NS::simd<T, N> C,
287 __ESIMD_NS::simd<BT, BN> B,
288 __ESIMD_NS::simd<AT, AN> A) {
// Marks this call as DPASW for the parameter verification.
290 constexpr
bool IsDPASW =
true;
// NOTE(review): argument list of a call whose start is not visible;
// presumably verify_parameters_and_deduce_exec_size -- confirm. T is passed
// for both the result and the accumulator type.
292 SystolicDepth, RepeatCount, T, T, BT, AT, BPrecision, APrecision, BN, AN,
// Element counts of A and B when the same bytes are expressed in int units.
295 constexpr
int ANCasted = AN *
sizeof(AT) /
sizeof(
int);
296 constexpr
int BNCasted = BN *
sizeof(BT) /
sizeof(
int);
// A and B are reinterpreted as int vectors through bit_cast_view.
297 __ESIMD_NS::simd<int, ANCasted> ACasted =
A.template bit_cast_view<int>();
298 __ESIMD_NS::simd<int, BNCasted> BCasted =
B.template bit_cast_view<int>();
300 using RawT =
typename __ESIMD_NS::simd<T, N>::raw_element_type;
// Info packs the operation parameters into one int: RepeatCount shifted left
// by 24, SystolicDepth by 16, APrecision by 8, and BPrecision unshifted.
301 constexpr
int Info = (RepeatCount << 24) + (SystolicDepth << 16) +
302 ((int)APrecision << 8) + (int)BPrecision;
// The intrinsic receives the operands in the order C, B, A.
303 return __esimd_dpasw<Info, RawT, int, int, N, BNCasted, ANCasted>(
304 C.data(), BCasted.data(), ACasted.data());
// dpasw overload without an accumulator operand: takes only B and A and
// forwards them to the __esimd_dpasw_nosrc0 intrinsic.
// NOTE(review): the `template <` opener, the remaining template parameters,
// the declaration that receives the execution size and the return statement
// are missing from this listing.
314 int SystolicDepth,
int RepeatCount,
typename T,
typename BT,
typename AT,
318 auto dpasw(__ESIMD_NS::simd<BT, BN> B, __ESIMD_NS::simd<AT, AN> A) {
// Marks this call as DPASW for the parameter verification.
320 constexpr
bool IsDPASW =
true;
// NOTE(review): argument list of a call whose start is not visible;
// presumably the verify_parameters_and_deduce_exec_size call that yields
// ExecutionSize -- confirm.
322 SystolicDepth, RepeatCount, T, T, BT, AT, BPrecision, APrecision, BN, AN,
// The result holds RepeatCount * ExecutionSize elements.
330 constexpr
int ResultN = RepeatCount * ExecutionSize;
// Element counts of A and B when the same bytes are expressed in int units.
332 constexpr
int ANCasted = AN *
sizeof(AT) /
sizeof(
int);
333 constexpr
int BNCasted = BN *
sizeof(BT) /
sizeof(
int);
// A and B are reinterpreted as int vectors through bit_cast_view.
334 __ESIMD_NS::simd<int, ANCasted> ACasted =
A.template bit_cast_view<int>();
335 __ESIMD_NS::simd<int, BNCasted> BCasted =
B.template bit_cast_view<int>();
337 using RawT =
typename __ESIMD_NS::simd<T, ResultN>::raw_element_type;
// Info packs the operation parameters into one int: RepeatCount shifted left
// by 24, SystolicDepth by 16, APrecision by 8, and BPrecision unshifted.
338 constexpr
int Info = (RepeatCount << 24) + (SystolicDepth << 16) +
339 ((int)APrecision << 8) + (int)BPrecision;
// The intrinsic receives the operands in the order B, A.
340 __ESIMD_NS::simd<T, ResultN> Result =
341 __esimd_dpasw_nosrc0<Info, RawT, int, int, ResultN, BNCasted, ANCasted>(
342 BCasted.data(), ACasted.data());
sycl::ext::intel::esimd::simd< T, N > dpasw(sycl::ext::intel::esimd::simd< T, N > C, sycl::ext::intel::esimd::simd< BT, BN > B, sycl::ext::intel::esimd::simd< AT, AN > A)
DPAS (Dot Product Accumulate Systolic): computes the result of the matrix operation Result = C + A x B.
sycl::ext::intel::esimd::simd< T, N > dpas(sycl::ext::intel::esimd::simd< CT, N > C, sycl::ext::intel::esimd::simd< BT, BN > B, sycl::ext::intel::esimd::simd< AT, AN > A)
constexpr int verify_parameters_and_deduce_exec_size()
constexpr void verify_repeat_count()
constexpr int dpas_bitsize_from_precision()
constexpr dpas_argument_type dpas_precision_from_type()
dpas_argument_type
Describes the element types in the input matrices.