DPC++ Runtime
Runtime libraries for oneAPI DPC++
math_intrin.hpp
Go to the documentation of this file.
1 //==------------ math_intrin.hpp - DPC++ Explicit SIMD API -----------------==//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 // Declares experimental Explicit SIMD math intrinsics.
9 //===----------------------------------------------------------------------===//
10 
11 #pragma once
12 
14 
16 
17 #define __ESIMD_raw_vec_t(T, SZ) \
18  __ESIMD_DNS::vector_type_t<__ESIMD_DNS::__raw_t<T>, SZ>
19 #define __ESIMD_cpp_vec_t(T, SZ) \
20  __ESIMD_DNS::vector_type_t<__ESIMD_DNS::__cpp_t<T>, SZ>
21 
22 template <typename T0, typename T1, int SZ>
23 __ESIMD_INTRIN __ESIMD_raw_vec_t(T0, SZ)
24  __esimd_ssshl(__ESIMD_raw_vec_t(T1, SZ) src0,
25  __ESIMD_raw_vec_t(T1, SZ) src1);
26 template <typename T0, typename T1, int SZ>
27 __ESIMD_INTRIN __ESIMD_raw_vec_t(T0, SZ)
28  __esimd_sushl(__ESIMD_raw_vec_t(T1, SZ) src0,
29  __ESIMD_raw_vec_t(T1, SZ) src1);
30 template <typename T0, typename T1, int SZ>
31 __ESIMD_INTRIN __ESIMD_raw_vec_t(T0, SZ)
32  __esimd_usshl(__ESIMD_raw_vec_t(T1, SZ) src0,
33  __ESIMD_raw_vec_t(T1, SZ) src1);
34 template <typename T0, typename T1, int SZ>
35 __ESIMD_INTRIN __ESIMD_raw_vec_t(T0, SZ)
36  __esimd_uushl(__ESIMD_raw_vec_t(T1, SZ) src0,
37  __ESIMD_raw_vec_t(T1, SZ) src1);
38 template <typename T0, typename T1, int SZ>
39 __ESIMD_INTRIN __ESIMD_raw_vec_t(T0, SZ)
40  __esimd_ssshl_sat(__ESIMD_raw_vec_t(T1, SZ) src0,
41  __ESIMD_raw_vec_t(T1, SZ) src1);
42 template <typename T0, typename T1, int SZ>
43 __ESIMD_INTRIN __ESIMD_raw_vec_t(T0, SZ)
44  __esimd_sushl_sat(__ESIMD_raw_vec_t(T1, SZ) src0,
45  __ESIMD_raw_vec_t(T1, SZ) src1);
46 template <typename T0, typename T1, int SZ>
47 __ESIMD_INTRIN __ESIMD_raw_vec_t(T0, SZ)
48  __esimd_usshl_sat(__ESIMD_raw_vec_t(T1, SZ) src0,
49  __ESIMD_raw_vec_t(T1, SZ) src1);
50 template <typename T0, typename T1, int SZ>
51 __ESIMD_INTRIN __ESIMD_raw_vec_t(T0, SZ)
52  __esimd_uushl_sat(__ESIMD_raw_vec_t(T1, SZ) src0,
53  __ESIMD_raw_vec_t(T1, SZ) src1);
54 
55 template <typename T0, typename T1, int SZ>
56 __ESIMD_INTRIN __ESIMD_raw_vec_t(T0, SZ)
57  __esimd_rol(__ESIMD_raw_vec_t(T1, SZ) src0, __ESIMD_raw_vec_t(T1, SZ) src1);
58 template <typename T0, typename T1, int SZ>
59 __ESIMD_INTRIN __ESIMD_raw_vec_t(T0, SZ)
60  __esimd_ror(__ESIMD_raw_vec_t(T1, SZ) src0, __ESIMD_raw_vec_t(T1, SZ) src1);
61 
62 template <typename T, int SZ>
63 __ESIMD_INTRIN __ESIMD_raw_vec_t(T, SZ)
64  __esimd_umulh(__ESIMD_raw_vec_t(T, SZ) src0, __ESIMD_raw_vec_t(T, SZ) src1);
65 template <typename T, int SZ>
66 __ESIMD_INTRIN __ESIMD_raw_vec_t(T, SZ)
67  __esimd_smulh(__ESIMD_raw_vec_t(T, SZ) src0, __ESIMD_raw_vec_t(T, SZ) src1);
68 
69 template <int SZ>
70 __ESIMD_INTRIN __ESIMD_DNS::vector_type_t<float, SZ>
71 __esimd_frc(__ESIMD_DNS::vector_type_t<float, SZ> src0);
72 
73 template <typename T, int SZ>
74 __ESIMD_INTRIN __ESIMD_raw_vec_t(T, SZ)
75  __esimd_lzd(__ESIMD_raw_vec_t(T, SZ) src0);
76 
77 template <typename T0, typename T1, int SZ>
78 __ESIMD_INTRIN __ESIMD_raw_vec_t(T0, SZ)
79  __esimd_bfrev(__ESIMD_raw_vec_t(T1, SZ) src0);
80 
81 template <typename T0, int SZ>
82 __ESIMD_INTRIN __ESIMD_raw_vec_t(T0, SZ)
83  __esimd_bfi(__ESIMD_raw_vec_t(T0, SZ) src0, __ESIMD_raw_vec_t(T0, SZ) src1,
84  __ESIMD_raw_vec_t(T0, SZ) src2, __ESIMD_raw_vec_t(T0, SZ) src3);
85 
86 template <typename T0, int SZ>
87 __ESIMD_INTRIN __ESIMD_raw_vec_t(T0, SZ)
88  __esimd_sbfe(__ESIMD_raw_vec_t(T0, SZ) src0, __ESIMD_raw_vec_t(T0, SZ) src1,
89  __ESIMD_raw_vec_t(T0, SZ) src2);
90 
91 template <typename T, int N>
92 __ESIMD_INTRIN __ESIMD_raw_vec_t(T, N)
93  __esimd_dp4(__ESIMD_raw_vec_t(T, N) v1, __ESIMD_raw_vec_t(T, N) v2)
94 #ifdef __SYCL_DEVICE_ONLY__
95  ;
96 #else
97 {
98  if constexpr (__ESIMD_DNS::is_wrapper_elem_type_v<T>)
99  __ESIMD_UNSUPPORTED_ON_HOST;
100  __ESIMD_raw_vec_t(T, N) retv;
101  for (auto i = 0; i != N; i += 4) {
102  T dp = (v1[i] * v2[i]) + (v1[i + 1] * v2[i + 1]) + (v1[i + 2] * v2[i + 2]) +
103  (v1[i + 3] * v2[i + 3]);
104  retv[i] = dp;
105  retv[i + 1] = dp;
106  retv[i + 2] = dp;
107  retv[i + 3] = dp;
108  }
109  return retv.data();
110 }
111 #endif // __SYCL_DEVICE_ONLY__
112 
113 template <typename T, typename T0, typename T1, typename T2, int N, int N1,
114  int N2>
115 SYCL_EXTERNAL SYCL_ESIMD_FUNCTION __ESIMD_DNS::vector_type_t<T, N>
116 __esimd_dpas(__ESIMD_DNS::vector_type_t<T0, N> src0,
117  __ESIMD_DNS::vector_type_t<T1, N1> src1,
118  __ESIMD_DNS::vector_type_t<T2, N2> src2, int src1_precision,
119  int src2_precision, int depth, int repeat, int sign_res,
120  int sign_acc);
121 
122 template <typename T, typename T1, typename T2, int N, int N1, int N2>
123 SYCL_EXTERNAL SYCL_ESIMD_FUNCTION __ESIMD_DNS::vector_type_t<T, N>
124 __esimd_dpas2(__ESIMD_DNS::vector_type_t<T1, N1> src1,
125  __ESIMD_DNS::vector_type_t<T2, N2> src2, int dpas_info);
126 
127 template <typename T, typename T1, typename T2, int N, int N1, int N2>
128 SYCL_EXTERNAL SYCL_ESIMD_FUNCTION __ESIMD_DNS::vector_type_t<T, N>
129 __esimd_dpasw(__ESIMD_DNS::vector_type_t<T, N> src0,
130  __ESIMD_DNS::vector_type_t<T1, N1> src1,
131  __ESIMD_DNS::vector_type_t<T2, N2> src2, int dpas_info);
132 
133 template <typename T, typename T1, typename T2, int N, int N1, int N2>
134 SYCL_EXTERNAL SYCL_ESIMD_FUNCTION __ESIMD_DNS::vector_type_t<T, N>
135 __esimd_dpasw2(__ESIMD_DNS::vector_type_t<T1, N1> src1,
136  __ESIMD_DNS::vector_type_t<T2, N2> src2, int dpas_info);
137 
138 #ifndef __SYCL_DEVICE_ONLY__
139 
140 template <typename T0, typename T1, int SZ>
141 __ESIMD_INTRIN __ESIMD_raw_vec_t(T0, SZ)
142  __esimd_ssshl(__ESIMD_raw_vec_t(T1, SZ) src0,
143  __ESIMD_raw_vec_t(T1, SZ) src1) {
144  if (__ESIMD_DNS::is_wrapper_elem_type_v<T1>)
145  __ESIMD_UNSUPPORTED_ON_HOST;
146  int i;
147  typename __ESIMD_EMU_DNS::maxtype<T1>::type ret;
148  __ESIMD_raw_vec_t(T0, SZ) retv;
149 
150  for (i = 0; i < SZ; i++) {
151  SIMDCF_ELEMENT_SKIP(i);
152  ret = src0[i] << src1[i];
153  retv[i] = ret;
154  }
155  return retv;
156 }
157 
158 template <typename T0, typename T1, int SZ>
159 __ESIMD_INTRIN __ESIMD_raw_vec_t(T0, SZ)
160  __esimd_sushl(__ESIMD_raw_vec_t(T1, SZ) src0,
161  __ESIMD_raw_vec_t(T1, SZ) src1) {
162  if (__ESIMD_DNS::is_wrapper_elem_type_v<T1>)
163  __ESIMD_UNSUPPORTED_ON_HOST;
164  int i;
165  typename __ESIMD_EMU_DNS::maxtype<T1>::type ret;
166  __ESIMD_raw_vec_t(T0, SZ) retv;
167 
168  for (i = 0; i < SZ; i++) {
169  SIMDCF_ELEMENT_SKIP(i);
170  ret = src0[i] << src1[i];
171  retv[i] = ret;
172  }
173  return retv;
174 }
175 
176 template <typename T0, typename T1, int SZ>
177 __ESIMD_INTRIN __ESIMD_raw_vec_t(T0, SZ)
178  __esimd_usshl(__ESIMD_raw_vec_t(T1, SZ) src0,
179  __ESIMD_raw_vec_t(T1, SZ) src1) {
180  if (__ESIMD_DNS::is_wrapper_elem_type_v<T1>)
181  __ESIMD_UNSUPPORTED_ON_HOST;
182  int i;
183  typename __ESIMD_EMU_DNS::maxtype<T1>::type ret;
184  __ESIMD_raw_vec_t(T0, SZ) retv;
185 
186  for (i = 0; i < SZ; i++) {
187  SIMDCF_ELEMENT_SKIP(i);
188  ret = src0[i] << src1[i];
189  retv[i] = ret;
190  }
191  return retv;
192 }
193 
194 template <typename T0, typename T1, int SZ>
195 __ESIMD_INTRIN __ESIMD_raw_vec_t(T0, SZ)
196  __esimd_uushl(__ESIMD_raw_vec_t(T1, SZ) src0,
197  __ESIMD_raw_vec_t(T1, SZ) src1) {
198  if (__ESIMD_DNS::is_wrapper_elem_type_v<T1>)
199  __ESIMD_UNSUPPORTED_ON_HOST;
200  int i;
201  typename __ESIMD_EMU_DNS::maxtype<T1>::type ret;
202  __ESIMD_raw_vec_t(T0, SZ) retv;
203 
204  for (i = 0; i < SZ; i++) {
205  SIMDCF_ELEMENT_SKIP(i);
206  ret = src0[i] << src1[i];
207  retv[i] = ret;
208  }
209  return retv;
210 }
211 
212 template <typename T0, typename T1, int SZ>
213 __ESIMD_INTRIN __ESIMD_raw_vec_t(T0, SZ)
214  __esimd_ssshl_sat(__ESIMD_raw_vec_t(T1, SZ) src0,
215  __ESIMD_raw_vec_t(T1, SZ) src1) {
216  if (__ESIMD_DNS::is_wrapper_elem_type_v<T1>)
217  __ESIMD_UNSUPPORTED_ON_HOST;
218  int i;
219  typename __ESIMD_EMU_DNS::maxtype<T1>::type ret;
220  __ESIMD_raw_vec_t(T0, SZ) retv;
221 
222  for (i = 0; i < SZ; i++) {
223  SIMDCF_ELEMENT_SKIP(i);
224  ret = src0[i] << src1[i];
225  retv[i] = __ESIMD_EMU_DNS::satur<T0>::template saturate<T1>(ret, 1);
226  }
227  return retv;
228 }
229 
230 template <typename T0, typename T1, int SZ>
231 __ESIMD_INTRIN __ESIMD_raw_vec_t(T0, SZ)
232  __esimd_sushl_sat(__ESIMD_raw_vec_t(T1, SZ) src0,
233  __ESIMD_raw_vec_t(T1, SZ) src1) {
234  if (__ESIMD_DNS::is_wrapper_elem_type_v<T1>)
235  __ESIMD_UNSUPPORTED_ON_HOST;
236  int i;
237  typename __ESIMD_EMU_DNS::maxtype<T1>::type ret;
238  __ESIMD_raw_vec_t(T0, SZ) retv;
239 
240  for (i = 0; i < SZ; i++) {
241  SIMDCF_ELEMENT_SKIP(i);
242  ret = src0[i] << src1[i];
243  retv[i] = __ESIMD_EMU_DNS::satur<T0>::template saturate<T1>(ret, 1);
244  }
245  return retv;
246 }
247 
248 template <typename T0, typename T1, int SZ>
249 __ESIMD_INTRIN __ESIMD_raw_vec_t(T0, SZ)
250  __esimd_usshl_sat(__ESIMD_raw_vec_t(T1, SZ) src0,
251  __ESIMD_raw_vec_t(T1, SZ) src1) {
252  if (__ESIMD_DNS::is_wrapper_elem_type_v<T1>)
253  __ESIMD_UNSUPPORTED_ON_HOST;
254  int i;
255  typename __ESIMD_EMU_DNS::maxtype<T1>::type ret;
256  __ESIMD_raw_vec_t(T0, SZ) retv;
257 
258  for (i = 0; i < SZ; i++) {
259  SIMDCF_ELEMENT_SKIP(i);
260  ret = src0[i] << src1[i];
261  retv[i] = __ESIMD_EMU_DNS::satur<T0>::template saturate<T1>(ret, 1);
262  }
263  return retv;
264 }
265 
266 template <typename T0, typename T1, int SZ>
267 __ESIMD_INTRIN __ESIMD_raw_vec_t(T0, SZ)
268  __esimd_uushl_sat(__ESIMD_raw_vec_t(T1, SZ) src0,
269  __ESIMD_raw_vec_t(T1, SZ) src1) {
270  if (__ESIMD_DNS::is_wrapper_elem_type_v<T1>)
271  __ESIMD_UNSUPPORTED_ON_HOST;
272  int i;
273  typename __ESIMD_EMU_DNS::maxtype<T1>::type ret;
274  __ESIMD_raw_vec_t(T0, SZ) retv;
275 
276  for (i = 0; i < SZ; i++) {
277  SIMDCF_ELEMENT_SKIP(i);
278  ret = src0[i] << src1[i];
279  retv[i] = __ESIMD_EMU_DNS::satur<T0>::template saturate<T1>(ret, 1);
280  }
281  return retv;
282 }
283 
284 template <typename T0, typename T1, int SZ>
285 __ESIMD_INTRIN __ESIMD_raw_vec_t(T0, SZ)
286  __esimd_rol(__ESIMD_raw_vec_t(T1, SZ) src0,
287  __ESIMD_raw_vec_t(T1, SZ) src1) {
288  __ESIMD_UNSUPPORTED_ON_HOST;
289 }
290 
291 template <typename T0, typename T1, int SZ>
292 __ESIMD_INTRIN __ESIMD_raw_vec_t(T0, SZ)
293  __esimd_ror(__ESIMD_raw_vec_t(T1, SZ) src0,
294  __ESIMD_raw_vec_t(T1, SZ) src1) {
295  __ESIMD_UNSUPPORTED_ON_HOST;
296 }
297 
298 template <typename T, int SZ>
299 __ESIMD_INTRIN __ESIMD_raw_vec_t(T, SZ)
300  __esimd_umulh(__ESIMD_raw_vec_t(T, SZ) src0,
301  __ESIMD_raw_vec_t(T, SZ) src1) {
302  if (__ESIMD_DNS::is_wrapper_elem_type_v<T>)
303  __ESIMD_UNSUPPORTED_ON_HOST;
304  int i;
305  __ESIMD_raw_vec_t(T, SZ) retv;
306 
307  for (i = 0; i < SZ; i++) {
308  unsigned long long temp;
309  SIMDCF_ELEMENT_SKIP(i);
310  temp = (long long)src0[i] * (long long)src1[i];
311  retv[i] = temp >> 32;
312  }
313  return retv;
314 }
315 
316 template <typename T, int SZ>
317 __ESIMD_INTRIN __ESIMD_raw_vec_t(T, SZ)
318  __esimd_smulh(__ESIMD_raw_vec_t(T, SZ) src0,
319  __ESIMD_raw_vec_t(T, SZ) src1) {
320  if (__ESIMD_DNS::is_wrapper_elem_type_v<T>)
321  __ESIMD_UNSUPPORTED_ON_HOST;
322  int i;
323  __ESIMD_raw_vec_t(T, SZ) retv;
324 
325  for (i = 0; i < SZ; i++) {
326  long long temp;
327  SIMDCF_ELEMENT_SKIP(i);
328  temp = (long long)src0[i] * (long long)src1[i];
329  retv[i] = temp >> 32;
330  }
331  return retv;
332 }
333 
334 template <int SZ>
335 __ESIMD_INTRIN __ESIMD_DNS::vector_type_t<float, SZ>
336 __esimd_frc(__ESIMD_DNS::vector_type_t<float, SZ> src0) {
337  __ESIMD_DNS::vector_type_t<float, SZ> retv;
338  for (int i = 0; i < SZ; i++) {
339  SIMDCF_ELEMENT_SKIP(i);
340  retv[i] = src0[i] - floor(src0[i]);
341  }
342  return retv;
343 }
344 
345 template <typename T, int SZ>
346 __ESIMD_INTRIN __ESIMD_raw_vec_t(T, SZ)
347  __esimd_lzd(__ESIMD_raw_vec_t(T, SZ) src0) {
348  if (__ESIMD_DNS::is_wrapper_elem_type_v<T>)
349  __ESIMD_UNSUPPORTED_ON_HOST;
350  int i;
351  T ret;
352  __ESIMD_raw_vec_t(T, SZ) retv;
353 
354  for (i = 0; i < SZ; i++) {
355  SIMDCF_ELEMENT_SKIP(i);
356  ret = src0[i];
357  uint32_t cnt = 0;
358  while ((ret & 1u << 31u) == 0 && cnt != 32) {
359  cnt++;
360  ret = ret << 1;
361  }
362  retv[i] = cnt;
363  }
364 
365  return retv;
366 }
367 
368 template <typename T0, typename T1, int SZ>
369 __ESIMD_INTRIN __ESIMD_raw_vec_t(T0, SZ)
370  __esimd_bfrev(__ESIMD_raw_vec_t(T1, SZ) src0) {
371  int i, j;
372  if (__ESIMD_DNS::is_wrapper_elem_type_v<T1>)
373  __ESIMD_UNSUPPORTED_ON_HOST;
374  __ESIMD_raw_vec_t(T0, SZ) retv;
375 
376  for (i = 0; i < SZ; i++) {
377  SIMDCF_ELEMENT_SKIP(i);
378  T0 input = src0[i];
379  T0 output = 0;
380  for (j = 0; j < sizeof(T0) * 8; j++) {
381  output |= input & 0x1;
382 
383  // Don't shift if this was the last one
384  if ((j + 1) < (sizeof(T0) * 8)) {
385  output <<= 1;
386  input >>= 1;
387  }
388  }
389  retv[i] = output;
390  }
391 
392  return retv;
393 }
394 
395 template <typename T, int SZ>
396 __ESIMD_INTRIN __ESIMD_raw_vec_t(T, SZ)
397  __esimd_bfi(__ESIMD_raw_vec_t(T, SZ) width, __ESIMD_raw_vec_t(T, SZ) offset,
398  __ESIMD_raw_vec_t(T, SZ) val, __ESIMD_raw_vec_t(T, SZ) src) {
399  if (__ESIMD_DNS::is_wrapper_elem_type_v<T>)
400  __ESIMD_UNSUPPORTED_ON_HOST;
401  int i;
402  typename __ESIMD_EMU_DNS::maxtype<T>::type ret;
403  __ESIMD_raw_vec_t(T, SZ) retv;
404 
405  for (i = 0; i < SZ; i++) {
406  SIMDCF_ELEMENT_SKIP(i);
407  const uint32_t mask = ((1 << width[i]) - 1) << offset[i];
408  const uint32_t imask = ~mask;
409  ret = (src[i] & imask) | ((val[i] << offset[i] & mask));
410  // Sign extend if signed type
411  if constexpr (std::is_signed<T>::value) {
412  int m = 1U << (width[i] - 1);
413  ret = (ret ^ m) - m;
414  }
415  retv[i] = ret;
416  }
417 
418  return retv;
419 }
420 
421 template <typename T, int SZ>
422 __ESIMD_INTRIN __ESIMD_raw_vec_t(T, SZ)
423  __esimd_sbfe(__ESIMD_raw_vec_t(T, SZ) width,
424  __ESIMD_raw_vec_t(T, SZ) offset,
425  __ESIMD_raw_vec_t(T, SZ) src) {
426  if (__ESIMD_DNS::is_wrapper_elem_type_v<T>)
427  __ESIMD_UNSUPPORTED_ON_HOST;
428  int i;
429  typename __ESIMD_EMU_DNS::maxtype<T>::type ret;
430  __ESIMD_raw_vec_t(T, SZ) retv;
431 
432  for (i = 0; i < SZ; i++) {
433  SIMDCF_ELEMENT_SKIP(i);
434  const uint32_t mask = ((1 << width[i]) - 1) << offset[i];
435  ret = (src[i] & mask) >> offset[i];
436  retv[i] = ret;
437  }
438 
439  return retv;
440 }
441 
442 inline constexpr __ESIMD_NS::uint
443 __esimd_dpas_bits_precision(__ESIMD_ENS::argument_type precisionType) {
444  return precisionType == __ESIMD_ENS::argument_type::TF32 ? 32
445  : precisionType == __ESIMD_ENS::argument_type::BF16 ||
446  precisionType == __ESIMD_ENS::argument_type::FP16
447  ? 16
448  : precisionType == __ESIMD_ENS::argument_type::S8 ||
449  precisionType == __ESIMD_ENS::argument_type::U8
450  ? 8
451  : precisionType == __ESIMD_ENS::argument_type::S4 ||
452  precisionType == __ESIMD_ENS::argument_type::U4
453  ? 4
454  : precisionType == __ESIMD_ENS::argument_type::S2 ||
455  precisionType == __ESIMD_ENS::argument_type::U2
456  ? 2
457  : 1;
458 }
459 
460 template <__ESIMD_ENS::argument_type src1_precision,
461  __ESIMD_ENS::argument_type src2_precision, int systolic_depth,
462  int repeat_count, typename RT, typename T1, typename T2,
464 inline __ESIMD_DNS::vector_type_t<RT, SZ>
465 __esimd_dpas_inner(const __ESIMD_DNS::vector_type_t<RT, SZ> *src0,
466  const __ESIMD_DNS::vector_type_t<T1, N1> &src1,
467  const __ESIMD_DNS::vector_type_t<T2, N2> &src2) {
468  __ESIMD_DNS::vector_type_t<RT, SZ> retv;
469 
470  __ESIMD_NS::uint sat1 =
471  __ESIMD_EMU_DNS::SetSatur<
472  T1, __ESIMD_EMU_DNS::is_inttype<RT>::value>::set() ||
473  __ESIMD_EMU_DNS::SetSatur<T2,
474  __ESIMD_EMU_DNS::is_inttype<RT>::value>::set();
475 
476  constexpr __ESIMD_NS::uint ops_per_chan =
477  src1_precision == __ESIMD_ENS::argument_type::BF16 ||
478  src1_precision == __ESIMD_ENS::argument_type::FP16 ||
479  src2_precision == __ESIMD_ENS::argument_type::BF16 ||
480  src2_precision == __ESIMD_ENS::argument_type::FP16
481  ? 2
482  : src1_precision == __ESIMD_ENS::argument_type::S8 ||
483  src1_precision == __ESIMD_ENS::argument_type::U8 ||
484  src2_precision == __ESIMD_ENS::argument_type::S8 ||
485  src2_precision == __ESIMD_ENS::argument_type::U8
486  ? 4
487  : 8;
488 
489  __ESIMD_NS::uint V = 0, U = 0, k = 0, temp = 0, src1_ops_per_dword = 0, p = 0;
490 
491  constexpr auto src1_el_bits = __esimd_dpas_bits_precision(src1_precision);
492  constexpr auto src2_el_bits = __esimd_dpas_bits_precision(src2_precision);
493 
494  uint32_t src1_signed =
495  src1_precision == __ESIMD_ENS::argument_type::S2 ||
496  src1_precision == __ESIMD_ENS::argument_type::S4 ||
497  src1_precision == __ESIMD_ENS::argument_type::S8
498  ? 1
499  : 0;
500 
501  uint32_t src2_signed =
502  src2_precision == __ESIMD_ENS::argument_type::S2 ||
503  src2_precision == __ESIMD_ENS::argument_type::S4 ||
504  src2_precision == __ESIMD_ENS::argument_type::S8
505  ? 1
506  : 0;
507 
508 #if defined(ESIMD_XE_HPC) || defined(ESIMD_XE_HPG)
509  constexpr bool isPvc = true;
510  constexpr size_t SIMDSize = 16;
511 #else
512  constexpr bool isPvc = false;
513  constexpr size_t SIMDSize = 8;
514 #endif
515 
516  constexpr bool
517  pvcHfDest = isPvc && std::is_same<RT, __ESIMD_EMU_DNS::half>::value,
518  pvcBfDest = isPvc && std::is_same<RT, short>::value,
519  pvcBfOrHfDest = pvcBfDest || pvcHfDest,
520 
521  pvcBfDestChecks = pvcBfDest &&
522  src1_precision == __ESIMD_ENS::argument_type::BF16 &&
523  src2_precision == __ESIMD_ENS::argument_type::BF16,
524 
525  pvcHfDestChecks =
526  pvcHfDest && ((src1_precision == __ESIMD_ENS::argument_type::FP16 &&
527  src2_precision == __ESIMD_ENS::argument_type::FP16) ||
528  (src1_precision == __ESIMD_ENS::argument_type::BF16 &&
529  src2_precision == __ESIMD_ENS::argument_type::BF16)),
530 
531  destTypeChk =
532  (!pvcBfOrHfDest && __ESIMD_EMU_DNS::is_fp_or_dword_type<RT>::value) ||
533  (pvcBfOrHfDest && (pvcBfDestChecks || pvcHfDestChecks)),
534 
535  srcTypeChk = __ESIMD_EMU_DNS::is_dword_type<T1>::value &&
536  __ESIMD_EMU_DNS::is_dword_type<T2>::value,
537 
538  destSizeChk = SZ >= /*TODO: ==*/SIMDSize * repeat_count,
539 
540  systolicDepthAndRepeatCountChk =
541  systolic_depth == 8 && repeat_count >= 1 && repeat_count <= 8,
542 
543  src1CountChk =
544  N1 == ((src1_el_bits * systolic_depth * ops_per_chan * SZ) /
545  (repeat_count * sizeof(T1) * 8)),
546  src2CountChk =
547  N2 >= ((src2_el_bits * systolic_depth * ops_per_chan * repeat_count) /
548  (sizeof(T2) * 8))
549  /*TODO: ==; fix PVCIGEMM24*/
550  ;
551 
552  if constexpr (!isPvc)
553  static_assert(!pvcBfOrHfDest, "dpas: hfloat and bfloat16 destination "
554  "element type is only supported on PVC.");
555  static_assert(destTypeChk, "dpas: unsupported dest and accumulator type.");
556  static_assert(srcTypeChk, "dpas: unsupported src element type.");
557  static_assert(destSizeChk,
558  "dpas: destination size must be SIMDSize x repeat_count.");
559  static_assert(systolicDepthAndRepeatCountChk,
560  "dpas: only systolic_depth = 8 and repeat_count of 1 to 8 are "
561  "supported.");
562  static_assert(src1CountChk, "dpas: invalid size for src1.");
563  static_assert(src2CountChk, "dpas: invalid size for src2.");
564 
565  using TmpAccEl = typename std::conditional<
566  pvcBfOrHfDest, float,
567  typename __ESIMD_EMU_DNS::restype_ex<
568  RT, typename __ESIMD_EMU_DNS::restype_ex<T1, T2>::type>::type>::type;
569 
570  __ESIMD_DNS::vector_type_t<TmpAccEl, SIMDSize> simdAcc;
571 
572  for (uint r = 0; r < repeat_count; r++) {
573  V = r;
574  k = 0;
575 
576  for (uint n = 0; n < SIMDSize; n++) {
577  if (src0 != nullptr) {
578  auto src0El = src0[0][r * SIMDSize + n];
579 
580  if (pvcBfDest) {
581  const auto tmp = (uint32_t)(src0El) << 16;
582  simdAcc[n] = reinterpret_cast<const TmpAccEl &>(tmp);
583  } else
584  simdAcc[n] = src0El;
585  } else
586  simdAcc[n] = 0;
587  }
588 
589  for (uint s = 0; s < systolic_depth; s++) {
590  src1_ops_per_dword = 32 / (ops_per_chan * src1_el_bits);
591  // U = s / src1_ops_per_dword;
592  U = s >> uint(log2(src1_ops_per_dword));
593 
594  for (uint n = 0; n < SIMDSize; n++) {
595  for (uint d = 0; d < ops_per_chan; d++) {
596  p = d + (s % src1_ops_per_dword) * ops_per_chan;
597  uint32_t extension_temp = false;
598 
599  if (src2_precision == __ESIMD_ENS::argument_type::BF16) {
600  const auto s1 =
601  extract<uint32_t>(src1_el_bits, p * src1_el_bits,
602  src1[U * SIMDSize + n], extension_temp)
603  << 16;
604  const auto s2 =
605  extract<uint32_t>(src2_el_bits, d * src2_el_bits,
606  src2[V * 8 + k / ops_per_chan], src2_signed)
607  << 16;
608  simdAcc[n] += reinterpret_cast<const float &>(s2) *
609  reinterpret_cast<const float &>(s1);
610  } else if (src2_precision == __ESIMD_ENS::argument_type::FP16) {
611  const auto s1 =
612  extract<short>(src1_el_bits, p * src1_el_bits,
613  src1[U * SIMDSize + n], extension_temp);
614  const auto s2 =
615  extract<short>(src2_el_bits, d * src2_el_bits,
616  src2[V * 8 + k / ops_per_chan], src2_signed);
617  simdAcc[n] += reinterpret_cast<const __ESIMD_EMU_DNS::half &>(s1) *
618  reinterpret_cast<const __ESIMD_EMU_DNS::half &>(s2);
619  } else {
620  int src = (sizeof(T2) * 8) / (ops_per_chan * src2_el_bits);
621  int off = s % src * (ops_per_chan * src2_el_bits);
622  int src1_tmp = extract<T1>(src1_el_bits, p * src1_el_bits,
623  src1[U * SIMDSize + n], src1_signed);
624  int src2_tmp = extract<T2>(src2_el_bits, d * src2_el_bits + off,
625  src2[(V * 8 + k / ops_per_chan) / src],
626  src2_signed);
627  simdAcc[n] += src1_tmp * src2_tmp;
628  }
629  }
630  }
631 
632  k += ops_per_chan;
633 
634  } // Systolic phase.
635 
636  for (uint n = 0; n < SIMDSize; n++) {
637  if constexpr (pvcBfDest) {
638  // TODO: make abstraction, support saturation, review rounding algo for
639  // corner cases.
640  auto tmpFloat = simdAcc[n];
641  auto tmpUint = reinterpret_cast<uint32_t &>(tmpFloat);
642  if (std::isnormal(tmpFloat) && tmpUint & 1ull << 15 &&
643  (tmpUint & 0x7fff || tmpUint & 1ull << 16)) {
644  tmpUint += 1ull << 16;
645  }
646  retv[r * SIMDSize + n] =
647  static_cast<short>(reinterpret_cast<uint32_t &>(tmpUint) >> 16);
648  } else
649  retv[r * SIMDSize + n] =
650  __ESIMD_EMU_DNS::satur<RT>::saturate(simdAcc[n], sat1);
651  }
652 
653  } // Repeat.
654 
655  return retv;
656 }
657 
658 template <__ESIMD_ENS::argument_type src1_precision,
659  __ESIMD_ENS::argument_type src2_precision, int systolic_depth,
660  int repeat_count, typename T, typename T0, typename T1, typename T2,
661  int N, int N1, int N2>
662 inline __ESIMD_DNS::vector_type_t<T, N>
663 __esimd_dpas(__ESIMD_DNS::vector_type_t<T0, N> src0,
664  __ESIMD_DNS::vector_type_t<T1, N1> src1,
665  __ESIMD_DNS::vector_type_t<T2, N2> src2) {
666 #ifdef __SYCL_EXPLICIT_SIMD_PLUGIN__
667  return __esimd_dpas_inner<src1_precision, src2_precision, systolic_depth,
668  repeat_count, T, T1, T2, N, N1, N2>(
669  std::addressof(src0), src1, src2);
670 #else // __SYCL_EXPLICIT_SIMD_PLUGIN__
671  __ESIMD_UNSUPPORTED_ON_HOST;
672  return __ESIMD_DNS::vector_type_t<T, N>();
673 #endif // __SYCL_EXPLICIT_SIMD_PLUGIN__
674 }
675 
676 template <__ESIMD_ENS::argument_type src1_precision,
677  __ESIMD_ENS::argument_type src2_precision, int systolic_depth,
678  int repeat_count, typename T, typename T1, typename T2, int N, int N1,
679  int N2>
680 inline __ESIMD_DNS::vector_type_t<T, N>
681 __esimd_dpas2(__ESIMD_DNS::vector_type_t<T1, N1> src1,
682  __ESIMD_DNS::vector_type_t<T2, N2> src2) {
683 #ifdef __SYCL_EXPLICIT_SIMD_PLUGIN__
684  return __esimd_dpas_inner<src1_precision, src2_precision, systolic_depth,
685  repeat_count, T, T1, T2, N, N1, N2>(nullptr, src1,
686  src2);
687 #else // __SYCL_EXPLICIT_SIMD_PLUGIN__
688  __ESIMD_UNSUPPORTED_ON_HOST;
689  return __ESIMD_DNS::vector_type_t<T, N>();
690 #endif // __SYCL_EXPLICIT_SIMD_PLUGIN__
691 }
692 
693 template <__ESIMD_ENS::argument_type src1_precision,
694  __ESIMD_ENS::argument_type src2_precision, int systolic_depth,
695  int repeat_count, typename T, typename T1, typename T2, int N, int N1,
696  int N2>
697 inline __ESIMD_DNS::vector_type_t<T, N>
698 __esimd_dpasw(__ESIMD_DNS::vector_type_t<T, N> src0,
699  __ESIMD_DNS::vector_type_t<T1, N1> src1,
700  __ESIMD_DNS::vector_type_t<T2, N2> src2) {
701  __ESIMD_UNSUPPORTED_ON_HOST;
702  return __ESIMD_DNS::vector_type_t<T, N>();
703 }
704 
705 template <__ESIMD_ENS::argument_type src1_precision,
706  __ESIMD_ENS::argument_type src2_precision, int systolic_depth,
707  int repeat_count, typename T, typename T1, typename T2, int N, int N1,
708  int N2>
709 inline __ESIMD_DNS::vector_type_t<T, N>
710 __esimd_dpasw2(__ESIMD_DNS::vector_type_t<T1, N1> src1,
711  __ESIMD_DNS::vector_type_t<T2, N2> src2) {
712  __ESIMD_UNSUPPORTED_ON_HOST;
713  return __ESIMD_DNS::vector_type_t<T, N>();
714 }
715 
716 #endif // #ifdef __SYCL_DEVICE_ONLY__
717 
718 #undef __ESIMD_raw_vec_t
719 #undef __ESIMD_cpp_vec_t
720 
cl::sycl::ext::intel::esimd::saturate
__ESIMD_API std::enable_if_t<!detail::is_generic_floating_point_v< T0 >||std::is_same_v< T1, T0 >, simd< T0, SZ > > saturate(simd< T1, SZ > src)
Conversion of input vector elements of type T1 into vector of elements of type T0 with saturation.
Definition: math.hpp:71
T
SYCL_EXTERNAL
#define SYCL_EXTERNAL
Definition: defines_elementary.hpp:30
cl::sycl::floor
detail::enable_if_t< detail::is_genfloat< T >::value, T > floor(T x) __NOEXC
Definition: builtins.hpp:190
sycl
Definition: invoke_simd.hpp:68
cl::sycl::bundle_state::input
@ input
cl::sycl::ext::intel::experimental::esimd::argument_type
argument_type
Definition: common.hpp:29
cl::sycl::half
sycl::detail::half_impl::half half
Definition: aliases.hpp:77
math_intrin.hpp
cl::sycl::image_channel_order::r
@ r
cl::sycl::log2
detail::enable_if_t< __FAST_MATH_GENFLOAT(T), T > log2(T x) __NOEXC
Definition: builtins.hpp:312
cl::sycl::isnormal
detail::common_rel_ret_t< T > isnormal(T x) __NOEXC
Definition: builtins.hpp:1240
cl::sycl::uint
unsigned int uint
Definition: aliases.hpp:73