90 namespace ext::intel::esimd::detail {
/// Comparison operation selector. Used as the `Op` template argument of
/// comparison_op_default_impl, vector_comparison_op(_default) and the
/// scalar/vector comparison customization traits.
enum class CmpOp : int {
  lt,
  lte,
  gte,
  gt,
  eq,
  ne
};
/// Unary operation selector. Used as the `Op` template argument of
/// unary_op_default_impl, unary_op, vector_unary_op(_default) and the
/// scalar/vector unary customization traits.
enum class UnaryOp : int {
  minus,
  plus,
  bit_not,
  log_not
};
// Tag type (declared only in this section) used as the RawT of the primary
// element_type_traits template, i.e. for types without a traits
// specialization. is_wrapper_elem_type_v compares RawT against it.
118 struct invalid_raw_element_type;
// Primary element type traits template. A type with no matching
// specialization gets RawT = invalid_raw_element_type, which makes
// is_wrapper_elem_type_v false for it.
128 template <
class T,
class SFINAE =
void>
struct element_type_traits {
// Element type of the raw storage vector (see __raw_vec_t).
131 using RawT = invalid_raw_element_type;
// Standard C++ type the element is converted to/from when an operation is
// computed generically (see the out-of-line *_op_traits::impl definitions).
135 using EnclosingCppT = void;
// When true, binary_op / unary_op / vector_*_op apply the built-in operators
// to raw values directly; otherwise they delegate to the *_op_traits
// customization points.
138 static inline constexpr
bool use_native_cpp_ops =
true;
// Explicit floating-point flag; read by is_generic_floating_point_v.
142 static inline constexpr
bool is_floating_point =
false;
// Element type traits for types satisfying is_vectorizable_v: the type is its
// own EnclosingCppT, built-in C++ operators are used on it, and
// is_floating_point mirrors std::is_floating_point_v<T>.
147 struct element_type_traits<T,
std::
enable_if_t<is_vectorizable_v<T>>> {
149 using EnclosingCppT = T;
150 static inline constexpr
bool use_native_cpp_ops =
true;
151 static inline constexpr
bool is_floating_point = std::is_floating_point_v<T>;
156 template <
class T>
using __raw_t =
typename element_type_traits<T>::RawT;
// Standard C++ type that element type T is converted to/from:
// element_type_traits<T>::EnclosingCppT.
158 using __cpp_t =
typename element_type_traits<T>::EnclosingCppT;
160 template <
class T,
int N>
161 using __raw_vec_t = vector_type_t<typename element_type_traits<T>::RawT, N>;
165 template <
class T,
int N>
166 using __cmp_t = decltype(std::declval<__raw_vec_t<T, N>>() <
167 std::declval<__raw_vec_t<T, N>>());
// True for "wrapper" element types: T's traits provide a valid RawT
// (not invalid_raw_element_type) that is a different type from T itself.
171 static inline constexpr
bool is_wrapper_elem_type_v =
172 !std::is_same_v<__raw_t<T>, invalid_raw_element_type> &&
173 !std::is_same_v<__raw_t<T>, T>;
// True when T may be used as a simd element type: it is either vectorizable
// (is_vectorizable_v) or a wrapper element type. Used to constrain the
// scalar op entry points below.
176 static inline constexpr
bool is_valid_simd_elem_type_v =
177 (is_vectorizable_v<T> || is_wrapper_elem_type_v<T>);
// Customization point: whole-vector conversion between a wrapper element
// type's raw storage and its enclosing C++ type. The members are only
// declared here; wrapper_type_converter calls them when the wrapper type does
// not use native C++ ops, so such wrapper types presumably supply
// definitions/specializations elsewhere (confirm).
181 template <
class WrapperT,
int N>
struct vector_conversion_traits {
// Only wrapper element types may use this trait.
182 static_assert(is_wrapper_elem_type_v<WrapperT>,
"");
183 using StdT = __cpp_t<WrapperT>;
184 using RawT = __raw_t<WrapperT>;
// Enclosing C++ vector -> raw storage vector.
186 static vector_type_t<RawT, N> convert_to_raw(vector_type_t<StdT, N>);
// Raw storage vector -> enclosing C++ vector.
187 static vector_type_t<StdT, N> convert_to_cpp(vector_type_t<RawT, N>);
// Customization point: reinterpret a single wrapper value as its raw storage
// value and back. Declared only; used by bitcast_to_raw_type and
// bitcast_to_wrapper_type for wrapper element types.
190 template <
class WrapperT>
struct scalar_conversion_traits {
// Only wrapper element types may use this trait.
191 static_assert(is_wrapper_elem_type_v<WrapperT>,
"");
192 using RawT = __raw_t<WrapperT>;
194 static RawT bitcast_to_raw(WrapperT);
195 static WrapperT bitcast_to_wrapper(RawT);
// Customization point: scalar binary operation Op on a wrapper element type.
// binary_op uses it when use_native_cpp_ops is false. A generic out-of-line
// definition that computes in __cpp_t<WrapperT> appears later in this file.
200 template <BinOp Op,
class WrapperT>
struct scalar_binary_op_traits {
201 static_assert(is_wrapper_elem_type_v<WrapperT>,
"");
203 static WrapperT impl(WrapperT X, WrapperT Y);
// Customization point: element-wise binary operation Op on raw storage
// vectors of a wrapper element type. vector_binary_op uses it when
// use_native_cpp_ops is false. A generic out-of-line definition appears later
// in this file.
206 template <BinOp Op,
class WrapperT,
int N>
struct vector_binary_op_traits {
207 static_assert(is_wrapper_elem_type_v<WrapperT>,
"");
208 using RawVecT = __raw_vec_t<WrapperT, N>;
210 static RawVecT impl(RawVecT X, RawVecT Y);
// Customization point: scalar comparison Op on a wrapper element type,
// returning bool. Declared only in this section.
215 template <CmpOp Op,
class WrapperT>
struct scalar_comparison_op_traits {
216 static_assert(is_wrapper_elem_type_v<WrapperT>,
"");
218 static bool impl(WrapperT X, WrapperT Y);
// Customization point: element-wise comparison Op on raw storage vectors of a
// wrapper element type, producing a __cmp_t<WrapperT, N> result.
// vector_comparison_op uses it when use_native_cpp_ops is false. A generic
// out-of-line definition appears later in this file.
221 template <CmpOp Op,
class WrapperT,
int N>
struct vector_comparison_op_traits {
222 static_assert(is_wrapper_elem_type_v<WrapperT>,
"");
223 using RawVecT = __raw_vec_t<WrapperT, N>;
225 static __cmp_t<WrapperT, N> impl(RawVecT X, RawVecT Y);
// Customization point: scalar unary operation Op on a wrapper element type.
// unary_op uses it when use_native_cpp_ops is false. A generic out-of-line
// definition appears later in this file.
230 template <UnaryOp Op,
class WrapperT>
struct scalar_unary_op_traits {
231 static_assert(is_wrapper_elem_type_v<WrapperT>,
"");
233 static WrapperT impl(WrapperT X);
// Customization point: element-wise unary operation Op on raw storage vectors
// of a wrapper element type. vector_unary_op uses it when use_native_cpp_ops
// is false. A generic out-of-line definition appears later in this file.
236 template <UnaryOp Op,
class WrapperT,
int N>
struct vector_unary_op_traits {
237 static_assert(is_wrapper_elem_type_v<WrapperT>,
"");
238 using RawVecT = __raw_vec_t<WrapperT, N>;
240 static RawVecT impl(RawVecT X);
// Converts whole vectors between an element type's enclosing C++ type (StdT)
// and its raw storage type (RawT). When use_native_cpp_ops is set the
// conversion is a __builtin_convertvector; otherwise it is delegated to
// vector_conversion_traits<WrapperT, N>.
248 template <
class WrapperT>
struct wrapper_type_converter {
249 using StdT = __cpp_t<WrapperT>;
250 using RawT = __raw_t<WrapperT>;
// Enclosing C++ vector -> raw storage vector.
253 ESIMD_INLINE
static vector_type_t<RawT, N>
254 to_vector(vector_type_t<StdT, N> Val) {
255 if constexpr (element_type_traits<WrapperT>::use_native_cpp_ops) {
256 return __builtin_convertvector(Val, vector_type_t<RawT, N>);
258 return vector_conversion_traits<WrapperT, N>::convert_to_raw(Val);
// Raw storage vector -> enclosing C++ vector.
263 ESIMD_INLINE
static vector_type_t<StdT, N>
264 from_vector(vector_type_t<RawT, N> Val) {
265 if constexpr (element_type_traits<WrapperT>::use_native_cpp_ops) {
266 return __builtin_convertvector(Val, vector_type_t<StdT, N>);
268 return vector_conversion_traits<WrapperT, N>::convert_to_cpp(Val);
// Converts a raw storage vector of SrcWrapperTy elements into a raw storage
// vector of DstWrapperTy elements.
// - Same element type: handled by the first branch.
// - Neither type is a wrapper: one __builtin_convertvector.
// - Otherwise: raw Src -> enclosing C++ Src (SrcConv::from_vector) ->
//   enclosing C++ Dst (__builtin_convertvector) -> raw Dst
//   (DstConv::to_vector). Steps are skipped when a type already is its own
//   enclosing C++ type.
276 template <
class DstWrapperTy,
class SrcWrapperTy,
int N,
277 class DstRawVecTy = vector_type_t<__raw_t<DstWrapperTy>, N>,
278 class SrcRawVecTy = vector_type_t<__raw_t<SrcWrapperTy>, N>>
279 ESIMD_INLINE DstRawVecTy convert_vector(SrcRawVecTy Val) {
280 if constexpr (std::is_same_v<SrcWrapperTy, DstWrapperTy>) {
282 }
else if constexpr (!is_wrapper_elem_type_v<SrcWrapperTy> &&
283 !is_wrapper_elem_type_v<DstWrapperTy>) {
284 return __builtin_convertvector(Val, DstRawVecTy);
298 using SrcConv = wrapper_type_converter<SrcWrapperTy>;
299 using DstConv = wrapper_type_converter<DstWrapperTy>;
300 using SrcStdT =
typename SrcConv::StdT;
301 using DstStdT =
typename DstConv::StdT;
302 using DstStdVecT = vector_type_t<DstStdT, N>;
303 using SrcStdVecT = vector_type_t<SrcStdT, N>;
// Step 1: bring the source into its enclosing C++ vector type.
304 SrcStdVecT TmpSrcVal;
306 if constexpr (std::is_same_v<SrcStdT, SrcWrapperTy>) {
307 TmpSrcVal = std::move(Val);
309 TmpSrcVal = SrcConv::template from_vector<N>(Val);
311 if constexpr (std::is_same_v<SrcStdT, DstWrapperTy>) {
// Step 2: convert between the enclosing C++ vector types.
314 DstStdVecT TmpDstVal;
// NOTE(review): this compares the scalar type SrcStdT with the vector type
// DstStdVecT. Unless vector_type_t<..., N> can collapse to a scalar, it is
// never true and the __builtin_convertvector branch is always taken. The
// intended check is presumably std::is_same_v<SrcStdT, DstStdT>; confirm
// before changing.
316 if constexpr (std::is_same_v<SrcStdT, DstStdVecT>) {
317 TmpDstVal = std::move(TmpSrcVal);
319 TmpDstVal = __builtin_convertvector(TmpSrcVal, DstStdVecT);
// Step 3: bring the result into the destination's raw storage type.
321 if constexpr (std::is_same_v<DstStdT, DstWrapperTy>) {
324 return DstConv::template to_vector<N>(TmpDstVal);
// Returns the raw storage representation of a single value. Non-wrapper types
// are handled by the first branch; wrapper element types delegate to
// scalar_conversion_traits<Ty>::bitcast_to_raw.
335 template <
class Ty> ESIMD_INLINE __raw_t<Ty> bitcast_to_raw_type(Ty Val) {
336 if constexpr (!is_wrapper_elem_type_v<Ty>) {
339 return scalar_conversion_traits<Ty>::bitcast_to_raw(Val);
// Inverse of bitcast_to_raw_type: rebuilds a Ty value from its raw storage
// value. Non-wrapper types are handled by the first branch; wrapper element
// types delegate to scalar_conversion_traits<Ty>::bitcast_to_wrapper.
343 template <
class Ty> ESIMD_INLINE Ty bitcast_to_wrapper_type(__raw_t<Ty> Val) {
344 if constexpr (!is_wrapper_elem_type_v<Ty>) {
347 return scalar_conversion_traits<Ty>::bitcast_to_wrapper(Val);
// Converts a single value from SrcWrapperTy to DstWrapperTy.
// - Same type: handled by the first branch.
// - Neither type is a wrapper: a plain static_cast.
// - Otherwise: the value is bitcast to its raw type, placed in a 1-element
//   vector, converted with convert_vector, and element 0 is bitcast back.
356 template <
class DstWrapperTy,
class SrcWrapperTy,
357 class DstRawTy = __raw_t<DstWrapperTy>,
358 class SrcRawTy = __raw_t<SrcWrapperTy>>
359 ESIMD_INLINE DstWrapperTy convert_scalar(SrcWrapperTy Val) {
360 if constexpr (std::is_same_v<SrcWrapperTy, DstWrapperTy>) {
362 }
else if constexpr (!is_wrapper_elem_type_v<SrcWrapperTy> &&
363 !is_wrapper_elem_type_v<DstWrapperTy>) {
364 return static_cast<DstRawTy
>(Val);
// Reuse the vector conversion path with a single-element vector.
366 vector_type_t<SrcRawTy, 1> V0 = bitcast_to_raw_type<SrcWrapperTy>(Val);
367 vector_type_t<DstRawTy, 1> V1 =
368 convert_vector<DstWrapperTy, SrcWrapperTy, 1>(V0);
369 return bitcast_to_wrapper_type<DstWrapperTy>(V1[0]);
// Computes binary operation Op on X and Y using T's own operators, selected
// with an if-constexpr chain over BinOp. It is called with raw scalars
// (binary_op_default) and with raw storage vectors (vector_binary_op_default).
375 template <BinOp Op,
class T> T binary_op_default_impl(T X, T Y) {
379 else if constexpr (Op == BinOp::sub)
381 else if constexpr (Op == BinOp::mul)
385 else if constexpr (Op == BinOp::rem)
397 else if constexpr (Op == BinOp::log_or)
399 else if constexpr (Op == BinOp::log_and)
// Computes comparison Op on X and Y with T's own operators. The result type is
// decltype(X < Y): bool for scalars, and for raw storage vectors the same
// type __cmp_t computes.
406 template <CmpOp Op,
class T>
auto comparison_op_default_impl(T X, T Y) {
// Value-initialized result; each branch below fills it for its CmpOp.
407 decltype(X < Y) Res{};
408 if constexpr (Op == CmpOp::lt)
410 else if constexpr (Op == CmpOp::lte)
412 else if constexpr (Op == CmpOp::eq)
414 else if constexpr (Op == CmpOp::ne)
416 else if constexpr (Op == CmpOp::gte)
418 else if constexpr (Op == CmpOp::gt)
// Computes unary operation Op on X with T's own operators, selected with an
// if-constexpr chain over UnaryOp. It is called with raw scalars
// (unary_op_default) and with raw storage vectors (vector_unary_op_default).
425 template <UnaryOp Op,
class T>
auto unary_op_default_impl(T X) {
426 if constexpr (Op == UnaryOp::minus)
430 else if constexpr (Op == UnaryOp::bit_not)
432 else if constexpr (Op == UnaryOp::log_not)
// Scalar binary operation for element types that use native C++ ops (enforced
// by the static_assert). Operands are bitcast to the raw type, the operation
// is applied there, and the result is bitcast back to T.
438 template <BinOp Op,
class T,
439 class = std::enable_if_t<is_valid_simd_elem_type_v<T>>>
440 ESIMD_INLINE T binary_op_default(T X, T Y) {
441 static_assert(element_type_traits<T>::use_native_cpp_ops);
442 using T1 = __raw_t<T>;
443 T1 X1 = bitcast_to_raw_type(X);
444 T1 Y1 = bitcast_to_raw_type(Y);
445 T1 Res = binary_op_default_impl<Op>(X1, Y1);
446 return bitcast_to_wrapper_type<T>(Res);
// Scalar binary operation entry point. Types with use_native_cpp_ops go
// through binary_op_default; all others delegate to the
// scalar_binary_op_traits<Op, T> customization point.
449 template <BinOp Op,
class T,
450 class = std::enable_if_t<is_valid_simd_elem_type_v<T>>>
451 ESIMD_INLINE T binary_op(T X, T Y) {
452 if constexpr (element_type_traits<T>::use_native_cpp_ops) {
453 return binary_op_default<Op>(X, Y);
455 return scalar_binary_op_traits<Op, T>::impl(X, Y);
// Element-wise binary operation on raw storage vectors of ElemT, applied
// directly to the vectors. Only valid for element types with
// use_native_cpp_ops (enforced by the static_assert).
461 template <BinOp Op,
class ElemT,
int N,
class RawVecT = __raw_vec_t<ElemT, N>>
462 ESIMD_INLINE RawVecT vector_binary_op_default(RawVecT X, RawVecT Y) {
463 static_assert(element_type_traits<ElemT>::use_native_cpp_ops);
464 return binary_op_default_impl<Op, RawVecT>(X, Y);
// Element-wise binary operation entry point on raw storage vectors. Native-op
// element types use vector_binary_op_default; all others delegate to
// vector_binary_op_traits<Op, ElemT, N>.
467 template <BinOp Op,
class ElemT,
int N,
class RawVecT = __raw_vec_t<ElemT, N>>
468 ESIMD_INLINE RawVecT vector_binary_op(RawVecT X, RawVecT Y) {
469 if constexpr (element_type_traits<ElemT>::use_native_cpp_ops) {
470 return vector_binary_op_default<Op, ElemT, N>(X, Y);
472 return vector_binary_op_traits<Op, ElemT, N>::impl(X, Y);
// Scalar unary operation for element types that use native C++ ops (enforced
// by the static_assert). The operand is bitcast to the raw type, the operation
// is applied there, and the result is bitcast back to T.
478 template <UnaryOp Op,
class T,
479 class = std::enable_if_t<is_valid_simd_elem_type_v<T>>>
480 ESIMD_INLINE T unary_op_default(T X) {
481 static_assert(element_type_traits<T>::use_native_cpp_ops);
482 using T1 = __raw_t<T>;
483 T1 X1 = bitcast_to_raw_type(X);
484 T1 Res = unary_op_default_impl<Op>(X1);
485 return bitcast_to_wrapper_type<T>(Res);
// Scalar unary operation entry point. Types with use_native_cpp_ops go through
// unary_op_default; all others delegate to scalar_unary_op_traits<Op, T>.
488 template <UnaryOp Op,
class T,
489 class = std::enable_if_t<is_valid_simd_elem_type_v<T>>>
490 ESIMD_INLINE T unary_op(T X) {
491 if constexpr (element_type_traits<T>::use_native_cpp_ops) {
492 return unary_op_default<Op>(X);
494 return scalar_unary_op_traits<Op, T>::impl(X);
// Element-wise unary operation on raw storage vectors of ElemT, applied
// directly to the vector. Only valid for element types with
// use_native_cpp_ops (enforced by the static_assert).
500 template <UnaryOp Op,
class ElemT,
int N,
class RawVecT = __raw_vec_t<ElemT, N>>
501 ESIMD_INLINE RawVecT vector_unary_op_default(RawVecT X) {
502 static_assert(element_type_traits<ElemT>::use_native_cpp_ops);
503 return unary_op_default_impl<Op, RawVecT>(X);
// Element-wise unary operation entry point on raw storage vectors. Native-op
// element types use vector_unary_op_default; all others delegate to
// vector_unary_op_traits<Op, ElemT, N>.
506 template <UnaryOp Op,
class ElemT,
int N,
class RawVecT = __raw_vec_t<ElemT, N>>
507 ESIMD_INLINE RawVecT vector_unary_op(RawVecT X) {
508 if constexpr (element_type_traits<ElemT>::use_native_cpp_ops) {
509 return vector_unary_op_default<Op, ElemT, N>(X);
511 return vector_unary_op_traits<Op, ElemT, N>::impl(X);
// Element-wise comparison on raw storage vectors of ElemT, applied directly to
// the vectors and returning RetT (__cmp_t<ElemT, N> by default). Only valid
// for element types with use_native_cpp_ops (enforced by the static_assert).
517 template <CmpOp Op,
class ElemT,
int N,
class RetT = __cmp_t<ElemT, N>,
518 class RawVecT = __raw_vec_t<ElemT, N>>
519 ESIMD_INLINE RetT vector_comparison_op_default(RawVecT X, RawVecT Y) {
520 static_assert(element_type_traits<ElemT>::use_native_cpp_ops);
521 return comparison_op_default_impl<Op, RawVecT>(X, Y);
// Element-wise comparison entry point on raw storage vectors. Native-op element
// types use vector_comparison_op_default; all others delegate to
// vector_comparison_op_traits<Op, ElemT, N>.
524 template <CmpOp Op,
class ElemT,
int N,
class RetT = __cmp_t<ElemT, N>,
525 class RawVecT = __raw_vec_t<ElemT, N>>
526 ESIMD_INLINE RetT vector_comparison_op(RawVecT X, RawVecT Y) {
527 if constexpr (element_type_traits<ElemT>::use_native_cpp_ops) {
528 return vector_comparison_op_default<Op, ElemT, N>(X, Y);
530 return vector_comparison_op_traits<Op, ElemT, N>::impl(X, Y);
// Generic scalar binary op for wrapper types: both operands are converted to
// the enclosing C++ type __cpp_t<WrapperT>, the operation is computed there
// with binary_op_default, and the result is converted back to WrapperT.
541 template <BinOp Op,
class WrapperT>
542 ESIMD_INLINE WrapperT scalar_binary_op_traits<Op, WrapperT>::impl(WrapperT X,
544 using T1 = __cpp_t<WrapperT>;
545 T1 X1 = convert_scalar<T1, WrapperT>(X);
546 T1 Y1 = convert_scalar<T1, WrapperT>(Y);
547 return convert_scalar<WrapperT>(binary_op_default<Op, T1>(X1, Y1));
// Generic element-wise binary op for wrapper types: both raw vectors are
// converted to vectors of the enclosing C++ type, the operation is computed
// with vector_binary_op_default, and the result is converted back to the
// wrapper's raw storage vector.
553 template <BinOp Op,
class WrapperT,
int N>
554 ESIMD_INLINE __raw_vec_t<WrapperT, N>
555 vector_binary_op_traits<Op, WrapperT, N>::impl(__raw_vec_t<WrapperT, N> X,
556 __raw_vec_t<WrapperT, N> Y) {
557 using T1 = __cpp_t<WrapperT>;
558 using VecT1 = vector_type_t<T1, N>;
559 VecT1 X1 = convert_vector<T1, WrapperT, N>(X);
560 VecT1 Y1 = convert_vector<T1, WrapperT, N>(Y);
561 return convert_vector<WrapperT, T1, N>(
562 vector_binary_op_default<Op, T1, N>(X1, Y1));
// Generic scalar unary op for wrapper types: the operand is converted to
// __cpp_t<WrapperT>, the operation is computed with unary_op_default, and the
// result is converted back to WrapperT.
568 template <UnaryOp Op,
class WrapperT>
569 ESIMD_INLINE WrapperT scalar_unary_op_traits<Op, WrapperT>::impl(WrapperT X) {
570 using T1 = __cpp_t<WrapperT>;
571 T1 X1 = convert_scalar<T1, WrapperT>(X);
572 return convert_scalar<WrapperT>(unary_op_default<Op, T1>(X1));
// Generic element-wise unary op for wrapper types: the raw vector is converted
// to a vector of the enclosing C++ type, the operation is computed with
// vector_unary_op_default, and the result is converted back to the wrapper's
// raw storage vector.
578 template <UnaryOp Op,
class WrapperT,
int N>
579 ESIMD_INLINE __raw_vec_t<WrapperT, N>
580 vector_unary_op_traits<Op, WrapperT, N>::impl(__raw_vec_t<WrapperT, N> X) {
581 using T1 = __cpp_t<WrapperT>;
582 using VecT1 = vector_type_t<T1, N>;
583 VecT1 X1 = convert_vector<T1, WrapperT, N>(X);
584 return convert_vector<WrapperT, T1, N>(
585 vector_unary_op_default<Op, T1, N>(X1));
// Generic element-wise comparison for wrapper types: both raw vectors are
// converted to vectors of the enclosing C++ type and compared with
// vector_comparison_op_default. The resulting mask is then converted to a
// vector whose element type is that of __cmp_t<WrapperT, N>.
591 template <CmpOp Op,
class WrapperT,
int N>
592 ESIMD_INLINE __cmp_t<WrapperT, N>
593 vector_comparison_op_traits<Op, WrapperT, N>::impl(__raw_vec_t<WrapperT, N> X,
594 __raw_vec_t<WrapperT, N> Y) {
595 using T1 = __cpp_t<WrapperT>;
596 using VecT1 = vector_type_t<T1, N>;
597 VecT1 X1 = convert_vector<T1, WrapperT, N>(X);
598 VecT1 Y1 = convert_vector<T1, WrapperT, N>(Y);
// NOTE(review): the argument below has type __cmp_t<T1, N>, but with
// SrcWrapperTy = T1, convert_vector's SrcRawVecTy defaults to
// vector_type_t<__raw_t<T1>, N>. This relies on an implicit conversion between
// those vector types; confirm it is intended and preserves the mask values.
599 return convert_vector<vector_element_type_t<__cmp_t<WrapperT, N>>, T1, N>(
600 vector_comparison_op_default<Op, T1, N>(X1, Y1));
605 template <
typename T>
606 static inline constexpr
bool is_generic_floating_point_v =
607 element_type_traits<T>::is_floating_point;