36 #ifndef SYCL_EXT_ONEAPI_COMPLEX
37 #define SYCL_EXT_ONEAPI_COMPLEX
45 namespace complex_namespace = sycl::ext::oneapi::experimental;
47 template <
typename ValueT>
50 template <
typename ValueT>
51 inline ValueT
clamp(ValueT val, ValueT min_val, ValueT max_val) {
55 #ifdef SYCL_EXT_ONEAPI_BFLOAT16_MATH_FUNCTIONS
70 template <
typename VecT,
class BinaryOperation,
class =
void>
73 inline VecT
operator()(VecT a, VecT b,
const BinaryOperation binary_op) {
75 for (
size_t i = 0; i < v4.size(); ++i) {
76 v4[i] = binary_op(a[i], b[i]);
84 template <
typename ValueT>
86 if constexpr (std::is_signed_v<ValueT>) {
87 return int64_t(val) << (64 - bit) >> (64 - bit);
92 template <
typename RetT,
bool needSat,
typename AT,
typename BT,
93 typename BinaryOperation>
94 inline constexpr RetT
extend_binary(AT a, BT b, BinaryOperation binary_op) {
97 const int64_t ret = binary_op(extend_a, extend_b);
98 if constexpr (needSat)
104 template <
typename RetT,
bool needSat,
typename AT,
typename BT,
typename CT,
105 typename BinaryOperation1,
typename BinaryOperation2>
107 BinaryOperation1 binary_op,
108 BinaryOperation2 second_op) {
113 if constexpr (needSat)
118 return second_op(extend_temp, extend_c);
121 template <
typename ValueT>
inline bool isnan(
const ValueT a) {
124 #ifdef SYCL_EXT_ONEAPI_BFLOAT16_MATH_FUNCTIONS
150 for (
int i = 0; i < len; ++i)
160 template <
typename ValueT>
161 inline ValueT
length(
const ValueT *a,
const int len) {
173 for (
int i = 0; i < len; ++i)
184 template <
typename ValueT,
class BinaryOperation>
185 inline std::enable_if_t<
186 std::is_same_v<std::invoke_result_t<BinaryOperation, ValueT, ValueT>,
bool>,
188 compare(
const ValueT a,
const ValueT b,
const BinaryOperation binary_op) {
189 return binary_op(a, b);
191 template <
typename ValueT>
192 inline std::enable_if_t<
193 std::is_same_v<std::invoke_result_t<std::not_equal_to<>, ValueT, ValueT>,
196 compare(
const ValueT a,
const ValueT b,
const std::not_equal_to<> binary_op) {
205 template <
typename ValueT,
class BinaryOperation>
206 inline std::enable_if_t<ValueT::size() == 2, ValueT>
207 compare(
const ValueT a,
const ValueT b,
const BinaryOperation binary_op) {
208 return {
compare(a[0], b[0], binary_op),
compare(a[1], b[1], binary_op)};
216 template <
typename ValueT,
class BinaryOperation>
217 inline std::enable_if_t<
218 std::is_same_v<std::invoke_result_t<BinaryOperation, ValueT, ValueT>,
bool>,
221 const BinaryOperation binary_op) {
230 template <
typename ValueT,
class BinaryOperation>
231 inline std::enable_if_t<ValueT::size() == 2, ValueT>
233 const BinaryOperation binary_op) {
243 template <
typename ValueT,
class BinaryOperation>
244 inline std::enable_if_t<ValueT::size() == 2,
bool>
245 compare_both(
const ValueT a,
const ValueT b,
const BinaryOperation binary_op) {
246 return compare(a[0], b[0], binary_op) &&
compare(a[1], b[1], binary_op);
255 template <
typename ValueT,
class BinaryOperation>
256 inline std::enable_if_t<ValueT::size() == 2,
bool>
258 const BinaryOperation binary_op) {
270 template <
typename ValueT,
class BinaryOperation>
273 const BinaryOperation binary_op) {
275 return ((-
compare(a[0], b[0], binary_op)) << 16) |
276 ((-
compare(a[1], b[1], binary_op)) & 0xFFFF);
286 template <
typename ValueT,
class BinaryOperation>
289 const BinaryOperation binary_op) {
303 auto v2 = v0.template as<S>();
304 auto v3 = v1.template as<S>();
306 v0 = v2.template as<sycl::vec<T, 1>>();
319 auto v2 = v0.template as<S>();
320 auto v3 = v1.template as<S>();
322 v0 = v2.template as<sycl::vec<T, 1>>();
332 template <
typename VecT,
class UnaryOperation>
335 auto v1 = v0.as<VecT>();
336 auto v2 = unary_op(v1);
337 v0 = v2.template as<sycl::vec<unsigned, 1>>();
347 template <
typename VecT>
350 auto v2 = v0.as<VecT>();
351 auto v3 = v1.as<VecT>();
354 for (
size_t i = 0; i < v4.size(); ++i) {
369 auto v2 = v0.template as<S>();
370 auto v3 = v1.template as<S>();
371 auto v4 = sycl::isgreater(v2, v3);
372 v0 = v4.template as<sycl::vec<T, 1>>();
382 inline unsigned vectorized_isgreater<sycl::ushort2, unsigned>(
unsigned a,
385 auto v2 = v0.template as<sycl::ushort2>();
386 auto v3 = v1.template as<sycl::ushort2>();
388 v4[0] = v2[0] > v3[0];
389 v4[1] = v2[1] > v3[1];
390 v0 = v4.template as<sycl::vec<unsigned, 1>>();
399 template <
typename ValueT>
400 inline ValueT
clamp(ValueT val, ValueT min_val, ValueT max_val) {
407 template <
typename ValueT>
408 inline std::enable_if_t<ValueT::size() == 2, ValueT>
isnan(
const ValueT a) {
413 template <
typename ValueT>
414 inline std::enable_if_t<std::is_floating_point_v<ValueT> ||
415 std::is_same_v<sycl::half, ValueT>,
426 template <
typename ValueT,
typename ValueU>
427 std::enable_if_t<std::is_integral_v<ValueT> && std::is_integral_v<ValueU>,
428 std::common_type_t<ValueT, ValueU>>
430 return sycl::min(
static_cast<std::common_type_t<ValueT, ValueU>
>(a),
431 static_cast<std::common_type_t<ValueT, ValueU>
>(b));
433 template <
typename ValueT,
typename ValueU>
434 std::enable_if_t<std::is_floating_point_v<ValueT> &&
435 std::is_floating_point_v<ValueU>,
436 std::common_type_t<ValueT, ValueU>>
438 return sycl::fmin(
static_cast<std::common_type_t<ValueT, ValueU>
>(a),
439 static_cast<std::common_type_t<ValueT, ValueU>
>(b));
443 template <
typename ValueT,
typename ValueU>
444 std::enable_if_t<std::is_integral_v<ValueT> && std::is_integral_v<ValueU>,
445 std::common_type_t<ValueT, ValueU>>
447 return sycl::max(
static_cast<std::common_type_t<ValueT, ValueU>
>(a),
448 static_cast<std::common_type_t<ValueT, ValueU>
>(b));
450 template <
typename ValueT,
typename ValueU>
451 std::enable_if_t<std::is_floating_point_v<ValueT> &&
452 std::is_floating_point_v<ValueU>,
453 std::common_type_t<ValueT, ValueU>>
455 return sycl::fmax(
static_cast<std::common_type_t<ValueT, ValueU>
>(a),
456 static_cast<std::common_type_t<ValueT, ValueU>
>(b));
465 template <
typename ValueT,
typename ValueU>
466 inline std::common_type_t<ValueT, ValueU>
fmax_nan(
const ValueT a,
470 return sycl::fmax(
static_cast<std::common_type_t<ValueT, ValueU>
>(a),
471 static_cast<std::common_type_t<ValueT, ValueU>
>(b));
473 template <
typename ValueT,
typename ValueU>
484 template <
typename ValueT,
typename ValueU>
485 inline std::common_type_t<ValueT, ValueU>
fmin_nan(
const ValueT a,
489 return sycl::fmin(
static_cast<std::common_type_t<ValueT, ValueU>
>(a),
490 static_cast<std::common_type_t<ValueT, ValueU>
>(b));
492 template <
typename ValueT,
typename ValueU>
/// pow overload for a float base and an int exponent.
///
/// Forwards both arguments unchanged to sycl::pown.
/// \param a The base value.
/// \param b The integer exponent.
/// \returns The result of sycl::pown(a, b).
499 inline float pow(
const float a,
const int b) {
return sycl::pown(a, b); }
/// pow overload for a double base and an int exponent.
///
/// Forwards both arguments unchanged to sycl::pown.
/// \param a The base value.
/// \param b The integer exponent.
/// \returns The result of sycl::pown(a, b).
500 inline double pow(
const double a,
const int b) {
return sycl::pown(a, b); }
502 template <
typename ValueT,
typename ValueU>
503 inline typename std::enable_if_t<std::is_floating_point_v<ValueT>, ValueT>
504 pow(
const ValueT a,
const ValueU b) {
505 return sycl::pow(a,
static_cast<ValueT
>(b));
511 template <
typename ValueT,
typename ValueU>
512 inline typename std::enable_if_t<!std::is_floating_point_v<ValueT>,
double>
513 pow(
const ValueT a,
const ValueU b) {
514 return sycl::pow(
static_cast<double>(a),
static_cast<double>(b));
520 template <
typename ValueT>
521 inline std::enable_if_t<std::is_floating_point_v<ValueT> ||
522 std::is_same_v<sycl::half, ValueT>,
529 template <
class ValueT>
530 inline std::enable_if_t<std::is_floating_point_v<ValueT> ||
531 std::is_same_v<sycl::half, ValueT>,
536 template <
class ValueT>
537 inline std::enable_if_t<std::is_floating_point_v<ValueT> ||
538 std::is_same_v<sycl::half, ValueT>,
549 template <
typename T>
561 template <
typename T>
592 template <
typename ValueT>
602 template <
typename ValueT>
615 template <
typename ValueT>
auto operator()(
const ValueT x)
const {
622 template <
typename ValueT>
630 template <
typename ValueT>
632 return sycl::add_sat(x, y);
638 template <
typename ValueT>
640 return sycl::rhadd(x, y);
646 template <
typename ValueT>
654 template <
typename ValueT>
662 template <
typename ValueT>
670 template <
typename ValueT>
672 return sycl::sub_sat(x, y);
683 template <
typename VecT,
class BinaryOperation>
685 const BinaryOperation binary_op) {
687 auto v2 = v0.as<VecT>();
688 auto v3 = v1.as<VecT>();
691 v0 = v4.template as<sycl::vec<unsigned, 1>>();
702 template <
typename RetT,
typename AT,
typename BT>
704 return detail::extend_binary<RetT, false>(a, b,
std::plus());
718 template <
typename RetT,
typename AT,
typename BT,
typename CT,
719 typename BinaryOperation>
720 inline constexpr RetT
extend_add(AT a, BT b, CT c, BinaryOperation second_op) {
721 return detail::extend_binary<RetT, false>(a, b, c,
std::plus(), second_op);
731 template <
typename RetT,
typename AT,
typename BT>
733 return detail::extend_binary<RetT, true>(a, b,
std::plus());
749 template <
typename RetT,
typename AT,
typename BT,
typename CT,
750 typename BinaryOperation>
752 BinaryOperation second_op) {
753 return detail::extend_binary<RetT, true>(a, b, c,
std::plus(), second_op);
763 template <
typename RetT,
typename AT,
typename BT>
765 return detail::extend_binary<RetT, false>(a, b, std::minus());
779 template <
typename RetT,
typename AT,
typename BT,
typename CT,
780 typename BinaryOperation>
781 inline constexpr RetT
extend_sub(AT a, BT b, CT c, BinaryOperation second_op) {
782 return detail::extend_binary<RetT, false>(a, b, c, std::minus(), second_op);
792 template <
typename RetT,
typename AT,
typename BT>
794 return detail::extend_binary<RetT, true>(a, b, std::minus());
810 template <
typename RetT,
typename AT,
typename BT,
typename CT,
811 typename BinaryOperation>
813 BinaryOperation second_op) {
814 return detail::extend_binary<RetT, true>(a, b, c, std::minus(), second_op);
824 template <
typename RetT,
typename AT,
typename BT>
826 return detail::extend_binary<RetT, false>(a, b,
abs_diff());
841 template <
typename RetT,
typename AT,
typename BT,
typename CT,
842 typename BinaryOperation>
844 BinaryOperation second_op) {
845 return detail::extend_binary<RetT, false>(a, b, c,
abs_diff(), second_op);
855 template <
typename RetT,
typename AT,
typename BT>
857 return detail::extend_binary<RetT, true>(a, b,
abs_diff());
873 template <
typename RetT,
typename AT,
typename BT,
typename CT,
874 typename BinaryOperation>
876 BinaryOperation second_op) {
877 return detail::extend_binary<RetT, true>(a, b, c,
abs_diff(), second_op);
887 template <
typename RetT,
typename AT,
typename BT>
889 return detail::extend_binary<RetT, false>(a, b,
minimum());
904 template <
typename RetT,
typename AT,
typename BT,
typename CT,
905 typename BinaryOperation>
906 inline constexpr RetT
extend_min(AT a, BT b, CT c, BinaryOperation second_op) {
907 return detail::extend_binary<RetT, false>(a, b, c,
minimum(), second_op);
917 template <
typename RetT,
typename AT,
typename BT>
919 return detail::extend_binary<RetT, true>(a, b,
minimum());
935 template <
typename RetT,
typename AT,
typename BT,
typename CT,
936 typename BinaryOperation>
938 BinaryOperation second_op) {
939 return detail::extend_binary<RetT, true>(a, b, c,
minimum(), second_op);
949 template <
typename RetT,
typename AT,
typename BT>
951 return detail::extend_binary<RetT, false>(a, b,
maximum());
966 template <
typename RetT,
typename AT,
typename BT,
typename CT,
967 typename BinaryOperation>
968 inline constexpr RetT
extend_max(AT a, BT b, CT c, BinaryOperation second_op) {
969 return detail::extend_binary<RetT, false>(a, b, c,
maximum(), second_op);
979 template <
typename RetT,
typename AT,
typename BT>
981 return detail::extend_binary<RetT, true>(a, b,
maximum());
997 template <
typename RetT,
typename AT,
typename BT,
typename CT,
998 typename BinaryOperation>
1000 BinaryOperation second_op) {
1001 return detail::extend_binary<RetT, true>(a, b, c,
maximum(), second_op);
Provides a cross-platform math array class template that works on SYCL devices as well as in host C++...
class sycl::vec ///////////////////////// Provides a cross-platform vector class template that works e...
VecT operator()(VecT a, VecT b, const BinaryOperation binary_op)
__ESIMD_API std::enable_if_t< detail::is_esimd_scalar< T1 >::value, std::remove_const_t< T1 > > abs(T1 src0)
Get absolute value (scalar version).
conditional< sizeof(long)==8, long, long long >::type int64_t
ESIMD_INLINE ESIMD_NODEBUG T0 sum(simd< T1, SZ > v)
std::enable_if_t< std::is_same_v< Tp, sycl::half2 >, sycl::half2 > sqrt(Tp x)
sycl::minimum< T > minimum
sycl::maximum< T > maximum
return std::max(x, y) - std hadd
bool isnan(const ValueT a)
ValueT clamp(ValueT val, ValueT min_val, ValueT max_val)
detail::complex_namespace::complex< ValueT > complex_type
int64_t zero_or_signed_extent(ValueT val, unsigned bit)
Extend 'val' to 'bit' bits: zero-extend for unsigned int and sign-extend for signed int.
constexpr RetT extend_binary(AT a, BT b, BinaryOperation binary_op)
float pow(const float a, const int b)
constexpr RetT extend_absdiff_sat(AT a, BT b)
Extend a and b to 33 bit and do abs_diff with saturation.
sycl::half min(sycl::half a, sycl::half b)
constexpr RetT extend_add(AT a, BT b)
Extend a and b to 33 bit and add them.
T vectorized_max(T a, T b)
Compute vectorized max for two values, with each value treated as a vector type S.
std::enable_if_t< std::is_same_v< std::invoke_result_t< BinaryOperation, ValueT, ValueT >, bool >, bool > compare(const ValueT a, const ValueT b, const BinaryOperation binary_op)
Performs comparison.
constexpr RetT extend_sub(AT a, BT b)
Extend a and b to 33 bit and subtract them.
constexpr RetT extend_max_sat(AT a, BT b)
Extend a and b to 33 bit and return bigger one with saturation.
float fast_length(const float *a, int len)
Compute fast_length for variable-length array.
constexpr RetT extend_add_sat(AT a, BT b)
Extend a and b to 33 bit and add them with saturation.
T cabs(sycl::vec< T, 2 > x)
Computes the magnitude of a complex number.
sycl::half max(sycl::half a, sycl::half b)
constexpr RetT extend_min(AT a, BT b)
Extend a and b to 33 bit and return smaller one.
std::enable_if_t< std::is_integral_v< ValueT > &&std::is_integral_v< ValueU >, std::common_type_t< ValueT, ValueU > > max(ValueT a, ValueU b)
constexpr RetT extend_max(AT a, BT b)
Extend a and b to 33 bit and return bigger one.
std::common_type_t< ValueT, ValueU > fmax_nan(const ValueT a, const ValueU b)
Performs 2 elements comparison and returns the bigger one.
constexpr RetT extend_absdiff(AT a, BT b)
Extend a and b to 33 bit and do abs_diff.
constexpr RetT extend_min_sat(AT a, BT b)
Extend a and b to 33 bit and return smaller one with saturation.
sycl::vec< T, 2 > conj(sycl::vec< T, 2 > x)
Computes the complex conjugate of a complex number.
std::enable_if_t< std::is_integral_v< ValueT > &&std::is_integral_v< ValueU >, std::common_type_t< ValueT, ValueU > > min(ValueT a, ValueU b)
std::enable_if_t< std::is_floating_point_v< ValueT >||std::is_same_v< sycl::half, ValueT >, ValueT > relu(const ValueT a)
Performs relu saturation.
std::enable_if_t< ValueT::size()==2, bool > unordered_compare_both(const ValueT a, const ValueT b, const BinaryOperation binary_op)
Performs 2 element unordered comparison and returns true if both results are true.
std::enable_if_t< std::is_floating_point_v< ValueT >||std::is_same_v< sycl::half, ValueT >, ValueT > cbrt(ValueT val)
cbrt function wrapper.
sycl::vec< T, 2 > cdiv(sycl::vec< T, 2 > x, sycl::vec< T, 2 > y)
Computes the division of two complex numbers.
T vectorized_min(T a, T b)
Compute vectorized min for two values, with each value treated as a vector type S.
std::common_type_t< ValueT, ValueU > fmin_nan(const ValueT a, const ValueU b)
Performs 2 elements comparison and returns the smaller one.
sycl::vec< ValueT, 2 > cmul_add(const sycl::vec< ValueT, 2 > a, const sycl::vec< ValueT, 2 > b, const sycl::vec< ValueT, 2 > c)
Performs complex number multiply addition.
constexpr RetT extend_sub_sat(AT a, BT b)
Extend a and b to 33 bit and subtract them with saturation.
ValueT clamp(ValueT val, ValueT min_val, ValueT max_val)
Returns min(max(val, min_val), max_val)
unsigned compare_mask(const sycl::vec< ValueT, 2 > a, const sycl::vec< ValueT, 2 > b, const BinaryOperation binary_op)
Performs 2 elements comparison, compare result of each element is 0 (false) or 0xffff (true),...
ValueT length(const ValueT *a, const int len)
Calculate the square root of the input array.
unsigned unordered_compare_mask(const sycl::vec< ValueT, 2 > a, const sycl::vec< ValueT, 2 > b, const BinaryOperation binary_op)
Performs 2 elements unordered comparison, compare result of each element is 0 (false) or 0xffff (true...
std::enable_if_t< ValueT::size()==2, bool > compare_both(const ValueT a, const ValueT b, const BinaryOperation binary_op)
Performs 2 element comparison and returns true if both results are true.
unsigned vectorized_unary(unsigned a, const UnaryOperation unary_op)
Compute vectorized unary operation for a value, with the value treated as a vector type VecT.
unsigned vectorized_binary(unsigned a, unsigned b, const BinaryOperation binary_op)
Compute vectorized binary operation value for two values, with each value treated as a vector type Ve...
std::enable_if_t< ValueT::size()==2, ValueT > isnan(const ValueT a)
Determine whether 2 element value is NaN.
std::enable_if_t< std::is_same_v< std::invoke_result_t< BinaryOperation, ValueT, ValueT >, bool >, bool > unordered_compare(const ValueT a, const ValueT b, const BinaryOperation binary_op)
Performs unordered comparison.
std::enable_if_t<!std::is_floating_point_v< ValueT >, double > pow(const ValueT a, const ValueU b)
unsigned vectorized_sum_abs_diff(unsigned a, unsigned b)
Compute vectorized absolute difference for two values without modulo overflow, with each value treate...
T vectorized_isgreater(T a, T b)
Compute vectorized isgreater for two values, with each value treated as a vector type S.
sycl::vec< T, 2 > cmul(sycl::vec< T, 2 > x, sycl::vec< T, 2 > y)
Computes the multiplication of two complex numbers.
A sycl::abs_diff wrapper functor.
auto operator()(const ValueT x, const ValueT y) const
A sycl::abs wrapper functor.
auto operator()(const ValueT x) const
A sycl::add_sat wrapper functor.
auto operator()(const ValueT x, const ValueT y) const
A sycl::hadd wrapper functor.
auto operator()(const ValueT x, const ValueT y) const
A sycl::max wrapper functor.
auto operator()(const ValueT x, const ValueT y) const
A sycl::min wrapper functor.
auto operator()(const ValueT x, const ValueT y) const
A sycl::rhadd wrapper functor.
auto operator()(const ValueT x, const ValueT y) const
A sycl::sub_sat wrapper functor.
auto operator()(const ValueT x, const ValueT y) const