DPC++ Runtime
Runtime libraries for oneAPI DPC++
math.hpp
Go to the documentation of this file.
1 /***************************************************************************
2  *
3  * Copyright (C) Codeplay Software Ltd.
4  *
5  * Part of the LLVM Project, under the Apache License v2.0 with LLVM
6  * Exceptions. See https://llvm.org/LICENSE.txt for license information.
7  * SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
8  *
9  * Unless required by applicable law or agreed to in writing, software
10  * distributed under the License is distributed on an "AS IS" BASIS,
11  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12  * See the License for the specific language governing permissions and
13  * limitations under the License.
14  *
15  * SYCL compatibility extension
16  *
17  * math.hpp
18  *
19  * Description:
20  * math utilities for the SYCL compatibility extension.
21  **************************************************************************/
22 
23 // The original source was under the license below:
24 //==---- math.hpp ---------------------------------*- C++ -*----------------==//
25 //
26 // Copyright (C) Intel Corporation
27 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
28 // See https://llvm.org/LICENSE.txt for license information.
29 //
30 //===----------------------------------------------------------------------===//
31 
32 #pragma once
33 
34 #include <sycl/sycl.hpp>
35 
36 #ifndef SYCL_EXT_ONEAPI_COMPLEX
37 #define SYCL_EXT_ONEAPI_COMPLEX
38 #endif
39 
41 
42 namespace syclcompat {
43 namespace detail {
44 
45 namespace complex_namespace = sycl::ext::oneapi::experimental;
46 
47 template <typename ValueT>
48 using complex_type = detail::complex_namespace::complex<ValueT>;
49 
50 template <typename ValueT>
51 inline ValueT clamp(ValueT val, ValueT min_val, ValueT max_val) {
52  return sycl::clamp(val, min_val, max_val);
53 }
54 
55 #ifdef SYCL_EXT_ONEAPI_BFLOAT16_MATH_FUNCTIONS
56 // TODO: Follow the process to add this to the extension. If added,
57 // remove this functionality from the header.
58 template <>
62  if (val < min_val)
63  return min_val;
64  if (val > max_val)
65  return max_val;
66  return val;
67 }
68 #endif
69 
70 template <typename VecT, class BinaryOperation, class = void>
72 public:
73  inline VecT operator()(VecT a, VecT b, const BinaryOperation binary_op) {
74  VecT v4;
75  for (size_t i = 0; i < v4.size(); ++i) {
76  v4[i] = binary_op(a[i], b[i]);
77  }
78  return v4;
79  }
80 };
81 
/// Extend \p val to \p bit size: zero extend for unsigned types, sign extend
/// (from bit position `bit - 1`) for signed types.
///
/// \param val Value to extend.
/// \param bit Width, in bits, that \p val is treated as having. The shift
///            amount is `64 - bit`, so callers must pass a value in [1, 64].
/// \returns The extended value as a 64-bit signed integer.
template <typename ValueT>
inline constexpr int64_t zero_or_signed_extent(ValueT val, unsigned bit) {
  if constexpr (std::is_signed_v<ValueT>) {
    using unsigned_t = std::make_unsigned_t<int64_t>;
    const unsigned shift = 64 - bit;
    // Do the left shift on the unsigned representation: left-shifting a
    // negative (or overflowing) signed value is undefined before C++20.
    const unsigned_t shifted =
        static_cast<unsigned_t>(static_cast<int64_t>(val)) << shift;
    // The signed right shift replicates the sign bit back down.
    return static_cast<int64_t>(shifted) >> shift;
  } else {
    return static_cast<int64_t>(val);
  }
}
91 
92 template <typename RetT, bool needSat, typename AT, typename BT,
93  typename BinaryOperation>
94 inline constexpr RetT extend_binary(AT a, BT b, BinaryOperation binary_op) {
95  const int64_t extend_a = zero_or_signed_extent(a, 33);
96  const int64_t extend_b = zero_or_signed_extent(b, 33);
97  const int64_t ret = binary_op(extend_a, extend_b);
98  if constexpr (needSat)
99  return detail::clamp<int64_t>(ret, std::numeric_limits<RetT>::min(),
101  return ret;
102 }
103 
104 template <typename RetT, bool needSat, typename AT, typename BT, typename CT,
105  typename BinaryOperation1, typename BinaryOperation2>
106 inline constexpr RetT extend_binary(AT a, BT b, CT c,
107  BinaryOperation1 binary_op,
108  BinaryOperation2 second_op) {
109  const int64_t extend_a = zero_or_signed_extent(a, 33);
110  const int64_t extend_b = zero_or_signed_extent(b, 33);
111  int64_t extend_temp =
112  zero_or_signed_extent(binary_op(extend_a, extend_b), 34);
113  if constexpr (needSat)
114  extend_temp =
115  detail::clamp<int64_t>(extend_temp, std::numeric_limits<RetT>::min(),
117  const int64_t extend_c = zero_or_signed_extent(c, 33);
118  return second_op(extend_temp, extend_c);
119 }
120 
121 template <typename ValueT> inline bool isnan(const ValueT a) {
122  return sycl::isnan(a);
123 }
124 #ifdef SYCL_EXT_ONEAPI_BFLOAT16_MATH_FUNCTIONS
125 inline bool isnan(const sycl::ext::oneapi::bfloat16 a) {
127 }
128 #endif
129 
130 } // namespace detail
131 
136 inline float fast_length(const float *a, int len) {
137  switch (len) {
138  case 1:
139  return sycl::fast_length(a[0]);
140  case 2:
141  return sycl::fast_length(sycl::float2(a[0], a[1]));
142  case 3:
143  return sycl::fast_length(sycl::float3(a[0], a[1], a[2]));
144  case 4:
145  return sycl::fast_length(sycl::float4(a[0], a[1], a[2], a[3]));
146  case 0:
147  return 0;
148  default:
149  float f = 0;
150  for (int i = 0; i < len; ++i)
151  f += a[i] * a[i];
152  return sycl::sqrt(f);
153  }
154 }
155 
160 template <typename ValueT>
161 inline ValueT length(const ValueT *a, const int len) {
162  switch (len) {
163  case 1:
164  return a[0];
165  case 2:
166  return sycl::length(sycl::vec<ValueT, 2>(a[0], a[1]));
167  case 3:
168  return sycl::length(sycl::vec<ValueT, 3>(a[0], a[1], a[2]));
169  case 4:
170  return sycl::length(sycl::vec<ValueT, 4>(a[0], a[1], a[2], a[3]));
171  default:
172  ValueT ret = 0;
173  for (int i = 0; i < len; ++i)
174  ret += a[i] * a[i];
175  return sycl::sqrt(ret);
176  }
177 }
178 
/// Performs comparison.
/// \param a The first value.
/// \param b The second value.
/// \param binary_op Functor called as binary_op(a, b); must return bool.
/// \returns The result of binary_op(a, b).
template <typename ValueT, class BinaryOperation>
inline std::enable_if_t<
    std::is_same_v<std::invoke_result_t<BinaryOperation, ValueT, ValueT>, bool>,
    bool>
compare(const ValueT a, const ValueT b, const BinaryOperation binary_op) {
  const bool outcome = binary_op(a, b);
  return outcome;
}
191 template <typename ValueT>
192 inline std::enable_if_t<
193  std::is_same_v<std::invoke_result_t<std::not_equal_to<>, ValueT, ValueT>,
194  bool>,
195  bool>
196 compare(const ValueT a, const ValueT b, const std::not_equal_to<> binary_op) {
197  return !detail::isnan(a) && !detail::isnan(b) && binary_op(a, b);
198 }
199 
/// Performs an element-wise comparison of two 2-element values.
/// \param a The first value.
/// \param b The second value.
/// \param binary_op Functor applied to each pair of elements.
/// \returns A 2-element value built from the two per-element results.
template <typename ValueT, class BinaryOperation>
inline std::enable_if_t<ValueT::size() == 2, ValueT>
compare(const ValueT a, const ValueT b, const BinaryOperation binary_op) {
  const auto first = compare(a[0], b[0], binary_op);
  const auto second = compare(a[1], b[1], binary_op);
  return {first, second};
}
210 
216 template <typename ValueT, class BinaryOperation>
217 inline std::enable_if_t<
218  std::is_same_v<std::invoke_result_t<BinaryOperation, ValueT, ValueT>, bool>,
219  bool>
220 unordered_compare(const ValueT a, const ValueT b,
221  const BinaryOperation binary_op) {
222  return detail::isnan(a) || detail::isnan(b) || binary_op(a, b);
223 }
224 
/// Performs an element-wise unordered comparison of two 2-element values.
/// \param a The first value.
/// \param b The second value.
/// \param binary_op Functor applied to each pair of elements.
/// \returns A 2-element value built from the two per-element results.
template <typename ValueT, class BinaryOperation>
inline std::enable_if_t<ValueT::size() == 2, ValueT>
unordered_compare(const ValueT a, const ValueT b,
                  const BinaryOperation binary_op) {
  const auto first = unordered_compare(a[0], b[0], binary_op);
  const auto second = unordered_compare(a[1], b[1], binary_op);
  return {first, second};
}
237 
/// Performs 2 element comparison and return true if both results are true.
/// \param a The first value.
/// \param b The second value.
/// \param binary_op Functor applied to each pair of elements.
/// \returns true only if the comparison holds for both elements.
template <typename ValueT, class BinaryOperation>
inline std::enable_if_t<ValueT::size() == 2, bool>
compare_both(const ValueT a, const ValueT b, const BinaryOperation binary_op) {
  // The second element is only compared when the first comparison holds.
  if (!compare(a[0], b[0], binary_op))
    return false;
  return compare(a[1], b[1], binary_op);
}
248 
/// Performs 2 element unordered comparison and return true if both results are
/// true.
/// \param a The first value.
/// \param b The second value.
/// \param binary_op Functor applied to each pair of elements.
/// \returns true only if the unordered comparison holds for both elements.
template <typename ValueT, class BinaryOperation>
inline std::enable_if_t<ValueT::size() == 2, bool>
unordered_compare_both(const ValueT a, const ValueT b,
                       const BinaryOperation binary_op) {
  // The second element is only compared when the first comparison holds.
  if (!unordered_compare(a[0], b[0], binary_op))
    return false;
  return unordered_compare(a[1], b[1], binary_op);
}
262 
270 template <typename ValueT, class BinaryOperation>
271 inline unsigned compare_mask(const sycl::vec<ValueT, 2> a,
272  const sycl::vec<ValueT, 2> b,
273  const BinaryOperation binary_op) {
274  // Since compare returns 0 or 1, -compare will be 0x00000000 or 0xFFFFFFFF
275  return ((-compare(a[0], b[0], binary_op)) << 16) |
276  ((-compare(a[1], b[1], binary_op)) & 0xFFFF);
277 }
278 
286 template <typename ValueT, class BinaryOperation>
288  const sycl::vec<ValueT, 2> b,
289  const BinaryOperation binary_op) {
290  return ((-unordered_compare(a[0], b[0], binary_op)) << 16) |
291  ((-unordered_compare(a[1], b[1], binary_op)) & 0xFFFF);
292 }
293 
301 template <typename S, typename T> inline T vectorized_max(T a, T b) {
302  sycl::vec<T, 1> v0{a}, v1{b};
303  auto v2 = v0.template as<S>();
304  auto v3 = v1.template as<S>();
305  v2 = sycl::max(v2, v3);
306  v0 = v2.template as<sycl::vec<T, 1>>();
307  return v0;
308 }
309 
317 template <typename S, typename T> inline T vectorized_min(T a, T b) {
318  sycl::vec<T, 1> v0{a}, v1{b};
319  auto v2 = v0.template as<S>();
320  auto v3 = v1.template as<S>();
321  v2 = sycl::min(v2, v3);
322  v0 = v2.template as<sycl::vec<T, 1>>();
323  return v0;
324 }
325 
332 template <typename VecT, class UnaryOperation>
333 inline unsigned vectorized_unary(unsigned a, const UnaryOperation unary_op) {
335  auto v1 = v0.as<VecT>();
336  auto v2 = unary_op(v1);
337  v0 = v2.template as<sycl::vec<unsigned, 1>>();
338  return v0;
339 }
340 
347 template <typename VecT>
348 inline unsigned vectorized_sum_abs_diff(unsigned a, unsigned b) {
349  sycl::vec<unsigned, 1> v0{a}, v1{b};
350  auto v2 = v0.as<VecT>();
351  auto v3 = v1.as<VecT>();
352  auto v4 = sycl::abs_diff(v2, v3);
353  unsigned sum = 0;
354  for (size_t i = 0; i < v4.size(); ++i) {
355  sum += v4[i];
356  }
357  return sum;
358 }
359 
367 template <typename S, typename T> inline T vectorized_isgreater(T a, T b) {
368  sycl::vec<T, 1> v0{a}, v1{b};
369  auto v2 = v0.template as<S>();
370  auto v3 = v1.template as<S>();
371  auto v4 = sycl::isgreater(v2, v3);
372  v0 = v4.template as<sycl::vec<T, 1>>();
373  return v0;
374 }
375 
381 template <>
382 inline unsigned vectorized_isgreater<sycl::ushort2, unsigned>(unsigned a,
383  unsigned b) {
384  sycl::vec<unsigned, 1> v0{a}, v1{b};
385  auto v2 = v0.template as<sycl::ushort2>();
386  auto v3 = v1.template as<sycl::ushort2>();
387  sycl::ushort2 v4;
388  v4[0] = v2[0] > v3[0];
389  v4[1] = v2[1] > v3[1];
390  v0 = v4.template as<sycl::vec<unsigned, 1>>();
391  return v0;
392 }
393 
399 template <typename ValueT>
400 inline ValueT clamp(ValueT val, ValueT min_val, ValueT max_val) {
401  return detail::clamp(val, min_val, max_val);
402 }
403 
407 template <typename ValueT>
408 inline std::enable_if_t<ValueT::size() == 2, ValueT> isnan(const ValueT a) {
409  return {detail::isnan(a[0]), detail::isnan(a[1])};
410 }
411 
413 template <typename ValueT>
414 inline std::enable_if_t<std::is_floating_point_v<ValueT> ||
415  std::is_same_v<sycl::half, ValueT>,
416  ValueT>
417 cbrt(ValueT val) {
418  return sycl::cbrt(static_cast<ValueT>(val));
419 }
420 
421 // min/max function overloads.
422 // For floating-point types, `float` or `double` arguments are acceptable.
423 // For integer types, `std::uint32_t`, `std::int32_t`, `std::uint64_t` or
424 // `std::int64_t` type arguments are acceptable.
425 // sycl::half supported as well.
426 template <typename ValueT, typename ValueU>
427 std::enable_if_t<std::is_integral_v<ValueT> && std::is_integral_v<ValueU>,
428  std::common_type_t<ValueT, ValueU>>
429 min(ValueT a, ValueU b) {
430  return sycl::min(static_cast<std::common_type_t<ValueT, ValueU>>(a),
431  static_cast<std::common_type_t<ValueT, ValueU>>(b));
432 }
433 template <typename ValueT, typename ValueU>
434 std::enable_if_t<std::is_floating_point_v<ValueT> &&
435  std::is_floating_point_v<ValueU>,
436  std::common_type_t<ValueT, ValueU>>
437 min(ValueT a, ValueU b) {
438  return sycl::fmin(static_cast<std::common_type_t<ValueT, ValueU>>(a),
439  static_cast<std::common_type_t<ValueT, ValueU>>(b));
440 }
441 sycl::half min(sycl::half a, sycl::half b) { return sycl::fmin(a, b); }
442 
443 template <typename ValueT, typename ValueU>
444 std::enable_if_t<std::is_integral_v<ValueT> && std::is_integral_v<ValueU>,
445  std::common_type_t<ValueT, ValueU>>
446 max(ValueT a, ValueU b) {
447  return sycl::max(static_cast<std::common_type_t<ValueT, ValueU>>(a),
448  static_cast<std::common_type_t<ValueT, ValueU>>(b));
449 }
450 template <typename ValueT, typename ValueU>
451 std::enable_if_t<std::is_floating_point_v<ValueT> &&
452  std::is_floating_point_v<ValueU>,
453  std::common_type_t<ValueT, ValueU>>
454 max(ValueT a, ValueU b) {
455  return sycl::fmax(static_cast<std::common_type_t<ValueT, ValueU>>(a),
456  static_cast<std::common_type_t<ValueT, ValueU>>(b));
457 }
458 sycl::half max(sycl::half a, sycl::half b) { return sycl::fmax(a, b); }
459 
465 template <typename ValueT, typename ValueU>
466 inline std::common_type_t<ValueT, ValueU> fmax_nan(const ValueT a,
467  const ValueU b) {
468  if (detail::isnan(a) || detail::isnan(b))
469  return NAN;
470  return sycl::fmax(static_cast<std::common_type_t<ValueT, ValueU>>(a),
471  static_cast<std::common_type_t<ValueT, ValueU>>(b));
472 }
473 template <typename ValueT, typename ValueU>
476  return {fmax_nan(a[0], b[0]), fmax_nan(a[1], b[1])};
477 }
478 
484 template <typename ValueT, typename ValueU>
485 inline std::common_type_t<ValueT, ValueU> fmin_nan(const ValueT a,
486  const ValueU b) {
487  if (detail::isnan(a) || detail::isnan(b))
488  return NAN;
489  return sycl::fmin(static_cast<std::common_type_t<ValueT, ValueU>>(a),
490  static_cast<std::common_type_t<ValueT, ValueU>>(b));
491 }
492 template <typename ValueT, typename ValueU>
495  return {fmin_nan(a[0], b[0]), fmin_nan(a[1], b[1])};
496 }
497 
498 // pow functions overload.
499 inline float pow(const float a, const int b) { return sycl::pown(a, b); }
500 inline double pow(const double a, const int b) { return sycl::pown(a, b); }
501 
502 template <typename ValueT, typename ValueU>
503 inline typename std::enable_if_t<std::is_floating_point_v<ValueT>, ValueT>
504 pow(const ValueT a, const ValueU b) {
505  return sycl::pow(a, static_cast<ValueT>(b));
506 }
507 
508 // TODO: calling pow with non-floating point values is currently defaulting to
509 // double, which fails on devices without aspect::fp64. This has to be properly
510 // documented, and maybe changed to support all devices.
511 template <typename ValueT, typename ValueU>
512 inline typename std::enable_if_t<!std::is_floating_point_v<ValueT>, double>
513 pow(const ValueT a, const ValueU b) {
514  return sycl::pow(static_cast<double>(a), static_cast<double>(b));
515 }
516 
520 template <typename ValueT>
521 inline std::enable_if_t<std::is_floating_point_v<ValueT> ||
522  std::is_same_v<sycl::half, ValueT>,
523  ValueT>
524 relu(const ValueT a) {
525  if (!detail::isnan(a) && a < ValueT(0))
526  return ValueT(0);
527  return a;
528 }
529 template <class ValueT>
530 inline std::enable_if_t<std::is_floating_point_v<ValueT> ||
531  std::is_same_v<sycl::half, ValueT>,
534  return {relu(a[0]), relu(a[1])};
535 }
536 template <class ValueT>
537 inline std::enable_if_t<std::is_floating_point_v<ValueT> ||
538  std::is_same_v<sycl::half, ValueT>,
541  return {relu(a[0]), relu(a[1])};
542 }
543 
549 template <typename T>
551  sycl::ext::oneapi::experimental::complex<T> t1(x[0], x[1]), t2(y[0], y[1]);
552  t1 = t1 * t2;
553  return sycl::vec<T, 2>(t1.real(), t1.imag());
554 }
555 
561 template <typename T>
563  sycl::ext::oneapi::experimental::complex<T> t1(x[0], x[1]), t2(y[0], y[1]);
564  t1 = t1 / t2;
565  return sycl::vec<T, 2>(t1.real(), t1.imag());
566 }
567 
572 template <typename T> T cabs(sycl::vec<T, 2> x) {
575 }
576 
581 template <typename T> sycl::vec<T, 2> conj(sycl::vec<T, 2> x) {
583  t = conj(t);
584  return sycl::vec<T, 2>(t.real(), t.imag());
585 }
586 
592 template <typename ValueT>
594  const sycl::vec<ValueT, 2> b,
595  const sycl::vec<ValueT, 2> c) {
599  t = t * u + v;
600  return sycl::vec<ValueT, 2>{t.real(), t.imag()};
601 }
602 template <typename ValueT>
604  const sycl::marray<ValueT, 2> b,
605  const sycl::marray<ValueT, 2> c) {
609  t = t * u + v;
610  return sycl::marray<ValueT, 2>{t.real(), t.imag()};
611 }
612 
614 struct abs {
615  template <typename ValueT> auto operator()(const ValueT x) const {
616  return sycl::abs(x);
617  }
618 };
619 
621 struct abs_diff {
622  template <typename ValueT>
623  auto operator()(const ValueT x, const ValueT y) const {
624  return sycl::abs_diff(x, y);
625  }
626 };
627 
629 struct add_sat {
630  template <typename ValueT>
631  auto operator()(const ValueT x, const ValueT y) const {
632  return sycl::add_sat(x, y);
633  }
634 };
635 
637 struct rhadd {
638  template <typename ValueT>
639  auto operator()(const ValueT x, const ValueT y) const {
640  return sycl::rhadd(x, y);
641  }
642 };
643 
645 struct hadd {
646  template <typename ValueT>
647  auto operator()(const ValueT x, const ValueT y) const {
648  return sycl::hadd(x, y);
649  }
650 };
651 
653 struct maximum {
654  template <typename ValueT>
655  auto operator()(const ValueT x, const ValueT y) const {
656  return sycl::max(x, y);
657  }
658 };
659 
661 struct minimum {
662  template <typename ValueT>
663  auto operator()(const ValueT x, const ValueT y) const {
664  return sycl::min(x, y);
665  }
666 };
667 
669 struct sub_sat {
670  template <typename ValueT>
671  auto operator()(const ValueT x, const ValueT y) const {
672  return sycl::sub_sat(x, y);
673  }
674 };
675 
683 template <typename VecT, class BinaryOperation>
684 inline unsigned vectorized_binary(unsigned a, unsigned b,
685  const BinaryOperation binary_op) {
686  sycl::vec<unsigned, 1> v0{a}, v1{b};
687  auto v2 = v0.as<VecT>();
688  auto v3 = v1.as<VecT>();
689  auto v4 =
691  v0 = v4.template as<sycl::vec<unsigned, 1>>();
692  return v0;
693 }
694 
702 template <typename RetT, typename AT, typename BT>
703 inline constexpr RetT extend_add(AT a, BT b) {
704  return detail::extend_binary<RetT, false>(a, b, std::plus());
705 }
706 
718 template <typename RetT, typename AT, typename BT, typename CT,
719  typename BinaryOperation>
720 inline constexpr RetT extend_add(AT a, BT b, CT c, BinaryOperation second_op) {
721  return detail::extend_binary<RetT, false>(a, b, c, std::plus(), second_op);
722 }
723 
731 template <typename RetT, typename AT, typename BT>
732 inline constexpr RetT extend_add_sat(AT a, BT b) {
733  return detail::extend_binary<RetT, true>(a, b, std::plus());
734 }
735 
749 template <typename RetT, typename AT, typename BT, typename CT,
750  typename BinaryOperation>
751 inline constexpr RetT extend_add_sat(AT a, BT b, CT c,
752  BinaryOperation second_op) {
753  return detail::extend_binary<RetT, true>(a, b, c, std::plus(), second_op);
754 }
755 
763 template <typename RetT, typename AT, typename BT>
764 inline constexpr RetT extend_sub(AT a, BT b) {
765  return detail::extend_binary<RetT, false>(a, b, std::minus());
766 }
767 
779 template <typename RetT, typename AT, typename BT, typename CT,
780  typename BinaryOperation>
781 inline constexpr RetT extend_sub(AT a, BT b, CT c, BinaryOperation second_op) {
782  return detail::extend_binary<RetT, false>(a, b, c, std::minus(), second_op);
783 }
784 
792 template <typename RetT, typename AT, typename BT>
793 inline constexpr RetT extend_sub_sat(AT a, BT b) {
794  return detail::extend_binary<RetT, true>(a, b, std::minus());
795 }
796 
810 template <typename RetT, typename AT, typename BT, typename CT,
811  typename BinaryOperation>
812 inline constexpr RetT extend_sub_sat(AT a, BT b, CT c,
813  BinaryOperation second_op) {
814  return detail::extend_binary<RetT, true>(a, b, c, std::minus(), second_op);
815 }
816 
824 template <typename RetT, typename AT, typename BT>
825 inline constexpr RetT extend_absdiff(AT a, BT b) {
826  return detail::extend_binary<RetT, false>(a, b, abs_diff());
827 }
828 
841 template <typename RetT, typename AT, typename BT, typename CT,
842  typename BinaryOperation>
843 inline constexpr RetT extend_absdiff(AT a, BT b, CT c,
844  BinaryOperation second_op) {
845  return detail::extend_binary<RetT, false>(a, b, c, abs_diff(), second_op);
846 }
847 
855 template <typename RetT, typename AT, typename BT>
856 inline constexpr RetT extend_absdiff_sat(AT a, BT b) {
857  return detail::extend_binary<RetT, true>(a, b, abs_diff());
858 }
859 
873 template <typename RetT, typename AT, typename BT, typename CT,
874  typename BinaryOperation>
875 inline constexpr RetT extend_absdiff_sat(AT a, BT b, CT c,
876  BinaryOperation second_op) {
877  return detail::extend_binary<RetT, true>(a, b, c, abs_diff(), second_op);
878 }
879 
887 template <typename RetT, typename AT, typename BT>
888 inline constexpr RetT extend_min(AT a, BT b) {
889  return detail::extend_binary<RetT, false>(a, b, minimum());
890 }
891 
904 template <typename RetT, typename AT, typename BT, typename CT,
905  typename BinaryOperation>
906 inline constexpr RetT extend_min(AT a, BT b, CT c, BinaryOperation second_op) {
907  return detail::extend_binary<RetT, false>(a, b, c, minimum(), second_op);
908 }
909 
917 template <typename RetT, typename AT, typename BT>
918 inline constexpr RetT extend_min_sat(AT a, BT b) {
919  return detail::extend_binary<RetT, true>(a, b, minimum());
920 }
921 
935 template <typename RetT, typename AT, typename BT, typename CT,
936  typename BinaryOperation>
937 inline constexpr RetT extend_min_sat(AT a, BT b, CT c,
938  BinaryOperation second_op) {
939  return detail::extend_binary<RetT, true>(a, b, c, minimum(), second_op);
940 }
941 
949 template <typename RetT, typename AT, typename BT>
950 inline constexpr RetT extend_max(AT a, BT b) {
951  return detail::extend_binary<RetT, false>(a, b, maximum());
952 }
953 
966 template <typename RetT, typename AT, typename BT, typename CT,
967  typename BinaryOperation>
968 inline constexpr RetT extend_max(AT a, BT b, CT c, BinaryOperation second_op) {
969  return detail::extend_binary<RetT, false>(a, b, c, maximum(), second_op);
970 }
971 
979 template <typename RetT, typename AT, typename BT>
980 inline constexpr RetT extend_max_sat(AT a, BT b) {
981  return detail::extend_binary<RetT, true>(a, b, maximum());
982 }
983 
997 template <typename RetT, typename AT, typename BT, typename CT,
998  typename BinaryOperation>
999 inline constexpr RetT extend_max_sat(AT a, BT b, CT c,
1000  BinaryOperation second_op) {
1001  return detail::extend_binary<RetT, true>(a, b, c, maximum(), second_op);
1002 }
1003 
1004 } // namespace syclcompat
Provides a cross-platform math array class template that works on SYCL devices as well as in host C++...
Definition: marray.hpp:49
class sycl::vec: Provides a cross-platform vector class template that works e...
VecT operator()(VecT a, VecT b, const BinaryOperation binary_op)
Definition: math.hpp:73
__ESIMD_API std::enable_if_t< detail::is_esimd_scalar< T1 >::value, std::remove_const_t< T1 > > abs(T1 src0)
Get absolute value (scalar version).
Definition: math.hpp:166
conditional< sizeof(long)==8, long, long long >::type int64_t
Definition: kernel_desc.hpp:35
ESIMD_INLINE ESIMD_NODEBUG T0 sum(simd< T1, SZ > v)
Definition: math.hpp:1009
std::enable_if_t< std::is_same_v< Tp, sycl::half2 >, sycl::half2 > sqrt(Tp x)
Definition: math.hpp:194
std::plus< T > plus
Definition: functional.hpp:20
sycl::minimum< T > minimum
Definition: functional.hpp:26
sycl::maximum< T > maximum
Definition: functional.hpp:25
return std::max(x, y) - std hadd
bool isnan(const ValueT a)
Definition: math.hpp:121
ValueT clamp(ValueT val, ValueT min_val, ValueT max_val)
Definition: math.hpp:51
detail::complex_namespace::complex< ValueT > complex_type
Definition: math.hpp:48
int64_t zero_or_signed_extent(ValueT val, unsigned bit)
Extend the 'val' to 'bit' size, zero extend for unsigned int and signed extend for signed int.
Definition: math.hpp:85
constexpr RetT extend_binary(AT a, BT b, BinaryOperation binary_op)
Definition: math.hpp:94
float pow(const float a, const int b)
Definition: math.hpp:499
constexpr RetT extend_absdiff_sat(AT a, BT b)
Extend a and b to 33 bit and do abs_diff with saturation.
Definition: math.hpp:856
sycl::half min(sycl::half a, sycl::half b)
Definition: math.hpp:441
constexpr RetT extend_add(AT a, BT b)
Extend a and b to 33 bit and add them.
Definition: math.hpp:703
T vectorized_max(T a, T b)
Compute vectorized max for two values, with each value treated as a vector type S.
Definition: math.hpp:301
std::enable_if_t< std::is_same_v< std::invoke_result_t< BinaryOperation, ValueT, ValueT >, bool >, bool > compare(const ValueT a, const ValueT b, const BinaryOperation binary_op)
Performs comparison.
Definition: math.hpp:188
constexpr RetT extend_sub(AT a, BT b)
Extend a and b to 33 bit and subtract b from a.
Definition: math.hpp:764
constexpr RetT extend_max_sat(AT a, BT b)
Extend a and b to 33 bit and return bigger one with saturation.
Definition: math.hpp:980
float fast_length(const float *a, int len)
Compute fast_length for variable-length array.
Definition: math.hpp:136
constexpr RetT extend_add_sat(AT a, BT b)
Extend a and b to 33 bit and add them with saturation.
Definition: math.hpp:732
T cabs(sycl::vec< T, 2 > x)
Computes the magnitude of a complex number.
Definition: math.hpp:572
sycl::half max(sycl::half a, sycl::half b)
Definition: math.hpp:458
constexpr RetT extend_min(AT a, BT b)
Extend a and b to 33 bit and return smaller one.
Definition: math.hpp:888
std::enable_if_t< std::is_integral_v< ValueT > &&std::is_integral_v< ValueU >, std::common_type_t< ValueT, ValueU > > max(ValueT a, ValueU b)
Definition: math.hpp:446
constexpr RetT extend_max(AT a, BT b)
Extend a and b to 33 bit and return bigger one.
Definition: math.hpp:950
std::common_type_t< ValueT, ValueU > fmax_nan(const ValueT a, const ValueU b)
Performs 2 elements comparison and returns the bigger one.
Definition: math.hpp:466
constexpr RetT extend_absdiff(AT a, BT b)
Extend a and b to 33 bit and do abs_diff.
Definition: math.hpp:825
constexpr RetT extend_min_sat(AT a, BT b)
Extend a and b to 33 bit and return smaller one with saturation.
Definition: math.hpp:918
sycl::vec< T, 2 > conj(sycl::vec< T, 2 > x)
Computes the complex conjugate of a complex number.
Definition: math.hpp:581
std::enable_if_t< std::is_integral_v< ValueT > &&std::is_integral_v< ValueU >, std::common_type_t< ValueT, ValueU > > min(ValueT a, ValueU b)
Definition: math.hpp:429
std::enable_if_t< std::is_floating_point_v< ValueT >||std::is_same_v< sycl::half, ValueT >, ValueT > relu(const ValueT a)
Performs relu saturation.
Definition: math.hpp:524
std::enable_if_t< ValueT::size()==2, bool > unordered_compare_both(const ValueT a, const ValueT b, const BinaryOperation binary_op)
Performs 2 element unordered comparison and return true if both results are true.
Definition: math.hpp:257
std::enable_if_t< std::is_floating_point_v< ValueT >||std::is_same_v< sycl::half, ValueT >, ValueT > cbrt(ValueT val)
cbrt function wrapper.
Definition: math.hpp:417
sycl::vec< T, 2 > cdiv(sycl::vec< T, 2 > x, sycl::vec< T, 2 > y)
Computes the division of two complex numbers.
Definition: math.hpp:562
T vectorized_min(T a, T b)
Compute vectorized min for two values, with each value treated as a vector type S.
Definition: math.hpp:317
std::common_type_t< ValueT, ValueU > fmin_nan(const ValueT a, const ValueU b)
Performs 2 elements comparison and returns the smaller one.
Definition: math.hpp:485
sycl::vec< ValueT, 2 > cmul_add(const sycl::vec< ValueT, 2 > a, const sycl::vec< ValueT, 2 > b, const sycl::vec< ValueT, 2 > c)
Performs complex number multiply addition.
Definition: math.hpp:593
constexpr RetT extend_sub_sat(AT a, BT b)
Extend a and b to 33 bit and subtract b from a with saturation.
Definition: math.hpp:793
ValueT clamp(ValueT val, ValueT min_val, ValueT max_val)
Returns min(max(val, min_val), max_val)
Definition: math.hpp:400
unsigned compare_mask(const sycl::vec< ValueT, 2 > a, const sycl::vec< ValueT, 2 > b, const BinaryOperation binary_op)
Performs 2 elements comparison, compare result of each element is 0 (false) or 0xffff (true),...
Definition: math.hpp:271
ValueT length(const ValueT *a, const int len)
Calculate the Euclidean length of the input array (the square root of the sum of the squares of its elements).
Definition: math.hpp:161
unsigned unordered_compare_mask(const sycl::vec< ValueT, 2 > a, const sycl::vec< ValueT, 2 > b, const BinaryOperation binary_op)
Performs 2 elements unordered comparison, compare result of each element is 0 (false) or 0xffff (true...
Definition: math.hpp:287
std::enable_if_t< ValueT::size()==2, bool > compare_both(const ValueT a, const ValueT b, const BinaryOperation binary_op)
Performs 2 element comparison and return true if both results are true.
Definition: math.hpp:245
unsigned vectorized_unary(unsigned a, const UnaryOperation unary_op)
Compute vectorized unary operation for a value, with the value treated as a vector type VecT.
Definition: math.hpp:333
unsigned vectorized_binary(unsigned a, unsigned b, const BinaryOperation binary_op)
Compute vectorized binary operation value for two values, with each value treated as a vector type Ve...
Definition: math.hpp:684
std::enable_if_t< ValueT::size()==2, ValueT > isnan(const ValueT a)
Determine whether 2 element value is NaN.
Definition: math.hpp:408
std::enable_if_t< std::is_same_v< std::invoke_result_t< BinaryOperation, ValueT, ValueT >, bool >, bool > unordered_compare(const ValueT a, const ValueT b, const BinaryOperation binary_op)
Performs unordered comparison.
Definition: math.hpp:220
std::enable_if_t<!std::is_floating_point_v< ValueT >, double > pow(const ValueT a, const ValueU b)
Definition: math.hpp:513
unsigned vectorized_sum_abs_diff(unsigned a, unsigned b)
Compute vectorized absolute difference for two values without modulo overflow, with each value treate...
Definition: math.hpp:348
T vectorized_isgreater(T a, T b)
Compute vectorized isgreater for two values, with each value treated as a vector type S.
Definition: math.hpp:367
sycl::vec< T, 2 > cmul(sycl::vec< T, 2 > x, sycl::vec< T, 2 > y)
Computes the multiplication of two complex numbers.
Definition: math.hpp:550
A sycl::abs_diff wrapper functor.
Definition: math.hpp:621
auto operator()(const ValueT x, const ValueT y) const
Definition: math.hpp:623
A sycl::abs wrapper functor.
Definition: math.hpp:614
auto operator()(const ValueT x) const
Definition: math.hpp:615
A sycl::add_sat wrapper functor.
Definition: math.hpp:629
auto operator()(const ValueT x, const ValueT y) const
Definition: math.hpp:631
A sycl::hadd wrapper functor.
Definition: math.hpp:645
auto operator()(const ValueT x, const ValueT y) const
Definition: math.hpp:647
A sycl::max wrapper functor.
Definition: math.hpp:653
auto operator()(const ValueT x, const ValueT y) const
Definition: math.hpp:655
A sycl::min wrapper functor.
Definition: math.hpp:661
auto operator()(const ValueT x, const ValueT y) const
Definition: math.hpp:663
A sycl::rhadd wrapper functor.
Definition: math.hpp:637
auto operator()(const ValueT x, const ValueT y) const
Definition: math.hpp:639
A sycl::sub_sat wrapper functor.
Definition: math.hpp:669
auto operator()(const ValueT x, const ValueT y) const
Definition: math.hpp:671