25#pragma clang diagnostic push
26#pragma clang diagnostic ignored "-Wunused-parameter"
// xetla_abs (vector version): element-wise absolute value of src0
// (element type T1) returning an xetla_vector<T0, SZ>; forwards to
// __ESIMD_NS::abs<T0, T1, SZ>. The signature and static_assert condition are
// missing from this listing; the assert message shows xetla internal types
// are rejected.
38template <
typename T0,
typename T1,
int SZ>
41 "The internal types are not yet supported!");
42 return __ESIMD_NS::abs<T0, T1, SZ>(src0);
// xetla_abs overload with result type T0 and source type T1 but no SZ
// parameter. It presumably takes a scalar (signature not visible — confirm).
// Forwards to __ESIMD_NS::abs<T0, T1>; xetla internal types are rejected.
50template <
typename T0,
typename T1>
55 "The internal types are not yet supported!");
56 return __ESIMD_NS::abs<T0, T1>(src0);
// xetla_abs overload with a single element type T1 and length SZ.
// Presumably source and result share T1 (signature not visible — confirm).
// Forwards to __ESIMD_NS::abs<T1, SZ>; xetla internal types are rejected.
66template <
typename T1,
int SZ>
69 "The internal types are not yet supported!");
70 return __ESIMD_NS::abs<T1, SZ>(src0);
// xetla_abs overload whose template header and signature are not visible in
// this listing. Forwards to __ESIMD_NS::abs<T1>; xetla internal types are
// rejected.
82 "The internal types are not yet supported!");
83 return __ESIMD_NS::abs<T1>(src0);
// xetla_max (vector, vector): component-wise maximum of src0 and src1.
// Forwards to __ESIMD_NS::max<T, SZ>, passing Sat::value as the third
// (saturation) argument. Sat defaults to xetla_saturation_off_tag.
// Xetla internal types are rejected.
96template <
typename T,
int SZ,
typename Sat = xetla_saturation_off_tag>
100 "The internal types are not yet supported!");
101 return __ESIMD_NS::max<T, SZ>(src0, src1, Sat::value);
// Another xetla_max overload (element type T, length SZ). Its parameter list
// is not visible here; presumably it mixes vector and scalar operands —
// confirm. Forwards to __ESIMD_NS::max<T, SZ> with Sat::value as the third
// argument.
114template <
typename T,
int SZ,
typename Sat = xetla_saturation_off_tag>
118 "The internal types are not yet supported!");
119 return __ESIMD_NS::max<T, SZ>(src0, src1, Sat::value);
// Another xetla_max overload (element type T, length SZ). Its parameter list
// is not visible here; presumably it mixes vector and scalar operands —
// confirm. Forwards to __ESIMD_NS::max<T, SZ> with Sat::value as the third
// argument.
132template <
typename T,
int SZ,
typename Sat = xetla_saturation_off_tag>
136 "The internal types are not yet supported!");
137 return __ESIMD_NS::max<T, SZ>(src0, src1, Sat::value);
// xetla_max overload with no SZ parameter, presumably the scalar version
// (signature not visible — confirm). Forwards to __ESIMD_NS::max<T> with
// Sat::value as the third argument.
148template <
typename T,
typename Sat = xetla_saturation_off_tag>
151 "The internal types are not yet supported!");
152 return __ESIMD_NS::max<T>(src0, src1, Sat::value);
// xetla_min (vector, vector): component-wise minimum of src0 and src1.
// Forwards to __ESIMD_NS::min<T, SZ>, passing Sat::value as the third
// (saturation) argument. Sat defaults to xetla_saturation_off_tag.
// Xetla internal types are rejected.
165template <
typename T,
int SZ,
typename Sat = xetla_saturation_off_tag>
169 "The internal types are not yet supported!");
170 return __ESIMD_NS::min<T, SZ>(src0, src1, Sat::value);
// Another xetla_min overload (element type T, length SZ). Its parameter list
// is not visible here; presumably it mixes vector and scalar operands —
// confirm. Forwards to __ESIMD_NS::min<T, SZ> with Sat::value as the third
// argument.
183template <
typename T,
int SZ,
typename Sat = xetla_saturation_off_tag>
187 "The internal types are not yet supported!");
188 return __ESIMD_NS::min<T, SZ>(src0, src1, Sat::value);
// Another xetla_min overload (element type T, length SZ). Its parameter list
// is not visible here; presumably it mixes vector and scalar operands —
// confirm. Forwards to __ESIMD_NS::min<T, SZ> with Sat::value as the third
// argument.
201template <
typename T,
int SZ,
typename Sat = xetla_saturation_off_tag>
205 "The internal types are not yet supported!");
206 return __ESIMD_NS::min<T, SZ>(src0, src1, Sat::value);
// xetla_min overload with no SZ parameter, presumably the scalar version
// (signature not visible — confirm). Forwards to __ESIMD_NS::min<T> with
// Sat::value as the third argument.
217template <
typename T,
typename Sat = xetla_saturation_off_tag>
220 "The internal types are not yet supported!");
221 return __ESIMD_NS::min<T>(src0, src1, Sat::value);
// xetla_exp (vector version): element-wise e^x of src via
// __ESIMD_NS::exp<T, SZ>, passing Sat::value as the second argument.
// NOTE(review): only the `float` comparison of the static_assert condition is
// visible. The message also mentions fp16, so an fp16 alternative is
// presumably on a missing line — confirm.
231template <
class T,
int SZ,
typename Sat = xetla_saturation_off_tag>
234 static_assert((std::is_same<remove_const_t<T>,
float>::value)
236 "Only support fp32 and fp16");
237 return __ESIMD_NS::exp<T, SZ>(src, Sat::value);
// xetla_exp overload with no SZ parameter, presumably the scalar version.
// Computes e^src via __ESIMD_NS::exp<T>, passing Sat::value as the second
// argument.
// NOTE(review): the static_assert condition is only partially visible
// (float shown; fp16 is presumably on a missing line).
246template <
class T,
typename Sat = xetla_saturation_off_tag>
248 static_assert((std::is_same<remove_const_t<T>,
float>::value)
250 "Only support fp32 and fp16");
251 return __ESIMD_NS::exp<T>(src, Sat::value);
// xetla_exp2 (vector version): element-wise 2^x of src via
// __ESIMD_NS::exp2<T, SZ>, passing Sat::value as the second argument.
// NOTE(review): the static_assert condition is only partially visible
// (float shown; fp16 is presumably on a missing line).
261template <
class T,
int SZ,
typename Sat = xetla_saturation_off_tag>
264 static_assert((std::is_same<remove_const_t<T>,
float>::value)
266 "Only support fp32 and fp16");
267 return __ESIMD_NS::exp2<T, SZ>(src, Sat::value);
// xetla_exp2 overload with no SZ parameter, presumably the scalar version.
// Computes 2^src via __ESIMD_NS::exp2<T>, passing Sat::value as the second
// argument.
// NOTE(review): the static_assert condition is only partially visible.
276template <
class T,
typename Sat = xetla_saturation_off_tag>
278 static_assert((std::is_same<remove_const_t<T>,
float>::value)
280 "Only support fp32 and fp16");
281 return __ESIMD_NS::exp2<T>(src, Sat::value);
// xetla_inv (vector version): element-wise inversion of src via
// __ESIMD_NS::inv, with template arguments deduced from src and Sat::value
// passed as the second argument.
// NOTE(review): the static_assert condition is only partially visible.
290template <
typename T,
int SZ,
typename Sat = xetla_saturation_off_tag>
293 static_assert((std::is_same<remove_const_t<T>,
float>::value)
295 "Only support fp32 and fp16");
296 return __ESIMD_NS::inv(src, Sat::value);
// xetla_inv overload with no SZ parameter, presumably the scalar version.
// Forwards to __ESIMD_NS::inv(src, Sat::value).
// NOTE(review): the static_assert condition is only partially visible.
304template <
typename T,
typename Sat = xetla_saturation_off_tag>
306 static_assert((std::is_same<remove_const_t<T>,
float>::value)
308 "Only support fp32 and fp16");
309 return __ESIMD_NS::inv(src, Sat::value);
// xetla_sqrt (vector version): element-wise square root of src via
// __ESIMD_NS::sqrt, with Sat::value passed as the second argument.
// NOTE(review): the static_assert condition is only partially visible.
318template <
typename T,
int SZ,
typename Sat = xetla_saturation_off_tag>
321 static_assert((std::is_same<remove_const_t<T>,
float>::value)
323 "Only support fp32 and fp16");
324 return __ESIMD_NS::sqrt(src, Sat::value);
// xetla_sqrt overload with no SZ parameter, presumably the scalar version.
// Forwards to __ESIMD_NS::sqrt(src, Sat::value).
// NOTE(review): the static_assert condition is only partially visible.
332template <
typename T,
typename Sat = xetla_saturation_off_tag>
334 static_assert((std::is_same<remove_const_t<T>,
float>::value)
336 "Only support fp32 and fp16");
337 return __ESIMD_NS::sqrt(src, Sat::value);
// xetla_sqrt_ieee (vector version): element-wise square root of src via
// __ESIMD_NS::sqrt_ieee, with Sat::value passed as the second argument.
// NOTE(review): the static_assert condition is only partially visible.
346template <
typename T,
int SZ,
typename Sat = xetla_saturation_off_tag>
349 static_assert((std::is_same<remove_const_t<T>,
float>::value)
351 "Only support fp32 and fp16");
352 return __ESIMD_NS::sqrt_ieee(src, Sat::value);
// xetla_sqrt_ieee overload with no SZ parameter, presumably the scalar
// version. Forwards to __ESIMD_NS::sqrt_ieee(src, Sat::value).
// NOTE(review): the static_assert condition is only partially visible.
360template <
typename T,
typename Sat = xetla_saturation_off_tag>
362 static_assert((std::is_same<remove_const_t<T>,
float>::value)
364 "Only support fp32 and fp16");
365 return __ESIMD_NS::sqrt_ieee(src, Sat::value);
// xetla_rsqrt (vector version): element-wise inverse square root of src via
// __ESIMD_NS::rsqrt, with Sat::value passed as the second argument.
// NOTE(review): the static_assert condition is only partially visible.
374template <
typename T,
int SZ,
typename Sat = xetla_saturation_off_tag>
377 static_assert((std::is_same<remove_const_t<T>,
float>::value)
379 "Only support fp32 and fp16");
380 return __ESIMD_NS::rsqrt(src, Sat::value);
// xetla_rsqrt overload with no SZ parameter, presumably the scalar version.
// Forwards to __ESIMD_NS::rsqrt(src, Sat::value).
// NOTE(review): the static_assert condition is only partially visible.
388template <
typename T,
typename Sat = xetla_saturation_off_tag>
390 static_assert((std::is_same<remove_const_t<T>,
float>::value)
392 "Only support fp32 and fp16");
393 return __ESIMD_NS::rsqrt(src, Sat::value);
// xetla_tanh (vector version, fp32 only): computes
// tanh(x) = (e^{2x} - 1) / (e^{2x} + 1) element-wise.
// Processing is split into full chunks of flag_elems (= 128) elements,
// followed by one tail chunk of SZ % 128 elements.
// NOTE(review): the declarations of `ret`, `exp2x`, `ones` and `mask` are on
// lines missing from this listing. `mask` presumably selects src_sub >= 10,
// matching the scalar overload — confirm.
401template <
typename T,
int SZ>
403 static_assert(std::is_same<remove_const_t<T>,
float>::value,
404 "Only support fp32! ");
405 constexpr uint32_t flag_elems = 8 * 16;
// Full 128-element chunks.
407 if constexpr (SZ / flag_elems > 0) {
409 for (uint32_t i = 0; i < SZ / flag_elems; i++) {
410 auto src_sub = src.xetla_select<flag_elems, 1>(i * flag_elems);
411 auto ret_sub = ret.xetla_select<flag_elems, 1>(i * flag_elems);
414 = xetla_exp<T, flag_elems>(src_sub * 2.f);
415 ret_sub = (exp2x - 1.f) / (exp2x + 1.f);
// Overwrite masked lanes with `ones`.
417 ret_sub.xetla_merge(ones, mask);
// Tail chunk covering the remaining SZ % 128 elements.
421 if constexpr (SZ % flag_elems != 0) {
422 constexpr uint32_t start_pos = SZ / flag_elems * flag_elems;
423 constexpr uint32_t remain_elems = SZ % flag_elems;
425 auto src_sub = src.xetla_select<remain_elems, 1>(start_pos);
426 auto ret_sub = ret.xetla_select<remain_elems, 1>(start_pos);
429 = xetla_exp<T, remain_elems>(src_sub * 2.f);
430 ret_sub = (exp2x - 1.f) / (exp2x + 1.f);
432 ret_sub.xetla_merge(ones, mask);
// xetla_tanh (scalar version, fp32 only): tanh(src) computed as
// (e^{2x} - 1) / (e^{2x} + 1).
// For src >= 10 it returns exactly 1. In fp32, tanh(10) already rounds to
// 1.0f, and the clamp also avoids the inf/inf (NaN) that e^{2x} overflow
// would produce for large positive src.
445 static_assert(std::is_same<remove_const_t<T>,
float>::value,
446 "Only support fp32! ");
447 T exp2x = xetla_exp<T>(src * 2.f);
448 T ret = (exp2x - 1.f) / (exp2x + 1.f);
449 return (src >= 10) ? 1 : ret;
// xetla_add_c (vector + vector): adds two unsigned integer vectors, returning
// the sum and updating `carry` in place. Only uint32_t is accepted.
// The body is not visible in this listing.
459template <
typename T,
int SZ>
462 static_assert((std::is_same<remove_const_t<T>, uint32_t>::value),
463 "For addc, only uint32_t is supported");
// Another xetla_add_c overload (uint32_t only). Its parameter list and body
// are not visible here; presumably it takes a scalar second operand —
// confirm.
477template <
typename T,
int SZ>
480 static_assert((std::is_same<remove_const_t<T>, uint32_t>::value),
481 "For addc, only uint32_t is supported");
// xetla_imul: multiplies src0 by src1, returning the hi part and updating the
// lo part in place. The visible line passes a temporary `lo_tmp` to
// __ESIMD_ENS::imul<T0, T1, T2, SZ>. How lo_tmp is declared and copied back
// to `lo` is on lines not shown here.
497template <
typename T0,
typename T1,
typename T2,
int SZ>
502 = __ESIMD_ENS::imul<T0, T1, T2, SZ>(lo_tmp, src0, src1);
// xetla_reduce: reduces the elements of v to a single T0 value.
// The visible returns use sum, prod, hmin or hmax. The branch conditions are
// on missing lines; presumably each branch is selected by BinaryOperation
// (reduce_op) — confirm.
519template <
typename T0,
typename T1,
int SZ, reduce_op BinaryOperation>
522 return __ESIMD_NS::detail::sum<T0, T1, SZ>(v);
524 return __ESIMD_NS::detail::prod<T0, T1, SZ>(v);
526 return __ESIMD_NS::hmin<T0, T1, SZ>(v);
528 return __ESIMD_NS::hmax<T0, T1, SZ>(v);
// xetla_rnde: returns the rounded value of each element of src0 via
// __ESIMD_NS::rnde<T, SZ>. Xetla internal types are rejected.
537template <
typename T,
int SZ>
540 "The internal types are not yet supported!");
541 return __ESIMD_NS::rnde<T, SZ>(src0);
// xetla_add: adds src0 and src1 in element type T0 and returns an
// xetla_vector<T1, SZ>.
// With Sat == xetla_saturation_off_tag (the default), the sum is converted
// to T1 by the xetla_vector constructor. Otherwise it is converted through
// __ESIMD_NS::saturate<T1, T0, SZ>.
554template <
typename T1,
typename T0,
int SZ,
555 typename Sat = xetla_saturation_off_tag>
560 "The internal types are not yet supported!");
562 xetla_vector<T0, SZ> temp = src0 + src1;
563 if constexpr (std::is_same_v<Sat, xetla_saturation_off_tag>)
564 return xetla_vector<T1, SZ>(temp);
566 return __ESIMD_NS::saturate<T1, T0, SZ>(temp);
// xetla_sat: saturating conversion of src (element type T0) to element type
// T1 via __ESIMD_NS::saturate<T1, T0, SZ>. Xetla internal types are
// rejected.
573template <
typename T1,
typename T0,
int SZ>
577 "The internal types are not yet supported!");
578 return __ESIMD_NS::saturate<T1, T0, SZ>(src);
584#pragma clang diagnostic pop
typename std::remove_const< T >::type remove_const_t
Definition common.hpp:26
#define __XETLA_API
Definition common.hpp:43
Workaround for ESIMD vector(1D) ref type.
Definition base_types.hpp:187
#define __REF__
Workaround for ESIMD reference usage.
Definition base_types.hpp:177
sycl::half fp16
xetla fp16 data type.
Definition base_types.hpp:43
__ESIMD_NS::simd< native_type_t< Ty >, N > xetla_vector
wrapper for xetla_vector.
Definition base_types.hpp:149
__ESIMD_NS::simd_mask< N > xetla_mask
wrapper for xetla_mask.
Definition base_types.hpp:165
__XETLA_API T0 xetla_reduce(xetla_vector< T1, SZ > v)
Performs reduction over elements of the input vector.
Definition math_general.hpp:520
__XETLA_API xetla_vector< T0, SZ > xetla_abs(xetla_vector< T1, SZ > src0)
Get absolute value (vector version)
Definition math_general.hpp:39
__XETLA_API xetla_vector< T, SZ > xetla_max(xetla_vector< T, SZ > src0, xetla_vector< T, SZ > src1, Sat sat={})
Selects component-wise the maximum of the two vectors.
Definition math_general.hpp:97
__XETLA_API xetla_vector< T, SZ > xetla_exp(xetla_vector< T, SZ > src, Sat sat={})
Calculate exponent value for each element of the input vector, the base is e.
Definition math_general.hpp:232
__XETLA_API xetla_vector< T, SZ > xetla_min(xetla_vector< T, SZ > src0, xetla_vector< T, SZ > src1, Sat sat={})
Selects component-wise the minimum of the two vectors.
Definition math_general.hpp:166
__XETLA_API xetla_vector< T1, SZ > xetla_sat(xetla_vector< T0, SZ > src)
Saturation function.
Definition math_general.hpp:574
__XETLA_API xetla_vector< T, SZ > xetla_sqrt(xetla_vector< T, SZ > src, Sat sat={})
Calculate the square root, i.e. sqrt(x), of each element.
Definition math_general.hpp:319
__XETLA_API xetla_vector< T0, SZ > xetla_imul(xetla_vector_ref< T0, SZ > __REF__ lo, xetla_vector< T1, SZ > src0, T2 src1)
Multiply src0 with src1, return the hi part and in-place update the lo part.
Definition math_general.hpp:498
__XETLA_API xetla_vector< T, SZ > xetla_exp2(xetla_vector< T, SZ > src, Sat sat={})
Calculate exponent value for each element of the input vector, the base is 2.
Definition math_general.hpp:262
__XETLA_API xetla_vector< T, SZ > xetla_tanh(xetla_vector< T, SZ > src)
Calculate the tanh (vector version).
Definition math_general.hpp:402
__XETLA_API xetla_vector< T, SZ > xetla_rsqrt(xetla_vector< T, SZ > src, Sat sat={})
Calculate the inverse of the square root, i.e. 1/sqrt(x), of each element.
Definition math_general.hpp:375
__XETLA_API xetla_vector< T, SZ > xetla_add_c(xetla_vector< T, SZ > src0, xetla_vector< T, SZ > src1, xetla_vector_ref< T, SZ > __REF__ carry)
Add two unsigned integer vectors, return the result and in-place update the carry.
Definition math_general.hpp:460
__XETLA_API xetla_vector< T, SZ > xetla_inv(xetla_vector< T, SZ > src, Sat sat={})
Calculate the inverse, i.e. 1/x, of each element.
Definition math_general.hpp:291
__XETLA_API xetla_vector< T, SZ > xetla_sqrt_ieee(xetla_vector< T, SZ > src, Sat sat={})
Calculate the square root, i.e. sqrt(x), of each element.
Definition math_general.hpp:347
__XETLA_API xetla_vector< T1, SZ > xetla_add(xetla_vector< T0, SZ > src0, xetla_vector< T0, SZ > src1, Sat sat={})
Adds two vectors with optional saturation. The source operands must both be of floating-point type.
Definition math_general.hpp:556
__XETLA_API xetla_vector< T, SZ > xetla_rnde(xetla_vector< T, SZ > src0)
Get rounded value.
Definition math_general.hpp:538
Definition arch_config.hpp:24
Used to check if the type is xetla internal data type.
Definition base_types.hpp:67
static constexpr bool value
Definition base_types.hpp:68