19 #include <type_traits>
22 inline namespace _V1 {
23 namespace ext::oneapi::experimental {
29 sycl::detail::memcpy(&res, &
x[start],
sizeof(uint32_t));
37 sycl::detail::is_vec_or_swizzle_v<T> &&
38 sycl::detail::is_valid_elem_type_v<T, bfloat16>;
48 std::enable_if_t<std::is_same_v<T, bfloat16>,
bool>
isnan(T x) {
50 return (((XBits & 0x7F80) == 0x7F80) && (XBits & 0x7F)) ? true :
false;
55 for (
size_t i = 0; i < N; i++) {
62 template <
typename T,
int N = num_elements_v<T>>
66 #if defined(__SYCL_DEVICE_ONLY__) && (defined(__SPIR__) || defined(__SPIRV__))
69 x.template convert<float, sycl::rounding_mode::automatic>();
70 auto Res =
isnan(FVec);
74 return Res.template convert<int16_t>();
78 for (
size_t i = 0; i < N; i++) {
81 res[i] =
isnan(
x[i]) ? -1 : 0;
90 std::enable_if_t<std::is_same_v<T, bfloat16>, T>
fabs(T x) {
91 #if defined(__SYCL_DEVICE_ONLY__) && defined(__NVPTX__) && \
92 (__SYCL_CUDA_ARCH__ >= 800)
99 x = ((XBits & SignMask) == SignMask)
111 #if defined(__SYCL_DEVICE_ONLY__) && defined(__NVPTX__) && \
112 (__SYCL_CUDA_ARCH__ >= 800)
113 for (
size_t i = 0; i < N / 2; i++) {
115 sycl::detail::memcpy(&res[i * 2], &partial_res,
sizeof(uint32_t));
124 for (
size_t i = 0; i < N; i++) {
133 template <
typename T,
int N = num_elements_v<T>>
136 #if defined(__SYCL_DEVICE_ONLY__) && (defined(__SPIR__) || defined(__SPIRV__))
139 x.template convert<float, sycl::rounding_mode::automatic>();
140 auto Res =
fabs(FVec);
141 return Res.template convert<bfloat16>();
144 for (
size_t i = 0; i < N; i++) {
153 template <
typename T>
154 std::enable_if_t<std::is_same_v<T, bfloat16>, T>
fmin(T x, T y) {
155 #if defined(__SYCL_DEVICE_ONLY__) && defined(__NVPTX__) && \
156 (__SYCL_CUDA_ARCH__ >= 800)
171 if (((XBits | YBits) ==
177 return (
x <
y) ?
x :
y;
186 #if defined(__SYCL_DEVICE_ONLY__) && defined(__NVPTX__) && \
187 (__SYCL_CUDA_ARCH__ >= 800)
188 for (
size_t i = 0; i < N / 2; i++) {
191 sycl::detail::memcpy(&res[i * 2], &partial_res,
sizeof(uint32_t));
202 for (
size_t i = 0; i < N; i++) {
203 res[i] =
fmin(
x[i],
y[i]);
211 template <
typename T1,
typename T2,
int N1 = num_elements_v<T1>,
212 int N2 = num_elements_v<T2>>
213 std::enable_if_t<is_vec_or_swizzle_bf16_v<T1> && is_vec_or_swizzle_bf16_v<T2> &&
217 #if defined(__SYCL_DEVICE_ONLY__) && (defined(__SPIR__) || defined(__SPIRV__))
220 x.template convert<float, sycl::rounding_mode::automatic>();
222 y.template convert<float, sycl::rounding_mode::automatic>();
223 auto Res =
fmin(FVecX, FVecY);
224 return Res.template convert<bfloat16>();
227 for (
size_t i = 0; i < N1; i++) {
228 res[i] =
fmin(
x[i],
y[i]);
236 template <
typename T>
237 std::enable_if_t<std::is_same_v<T, bfloat16>, T>
fmax(T x, T y) {
238 #if defined(__SYCL_DEVICE_ONLY__) && defined(__NVPTX__) && \
239 (__SYCL_CUDA_ARCH__ >= 800)
254 if (((XBits | YBits) ==
259 return (
x >
y) ?
x :
y;
268 #if defined(__SYCL_DEVICE_ONLY__) && defined(__NVPTX__) && \
269 (__SYCL_CUDA_ARCH__ >= 800)
270 for (
size_t i = 0; i < N / 2; i++) {
273 sycl::detail::memcpy(&res[i * 2], &partial_res,
sizeof(uint32_t));
284 for (
size_t i = 0; i < N; i++) {
285 res[i] =
fmax(
x[i],
y[i]);
293 template <
typename T1,
typename T2,
int N1 = num_elements_v<T1>,
294 int N2 = num_elements_v<T2>>
295 std::enable_if_t<is_vec_or_swizzle_bf16_v<T1> && is_vec_or_swizzle_bf16_v<T2> &&
299 #if defined(__SYCL_DEVICE_ONLY__) && (defined(__SPIR__) || defined(__SPIRV__))
302 x.template convert<float, sycl::rounding_mode::automatic>();
304 y.template convert<float, sycl::rounding_mode::automatic>();
305 auto Res =
fmax(FVecX, FVecY);
306 return Res.template convert<bfloat16>();
309 for (
size_t i = 0; i < N1; i++) {
310 res[i] =
fmax(
x[i],
y[i]);
318 template <
typename T>
319 std::enable_if_t<std::is_same_v<T, bfloat16>, T>
fma(T x, T y, T z) {
320 #if defined(__SYCL_DEVICE_ONLY__) && defined(__NVPTX__) && \
321 (__SYCL_CUDA_ARCH__ >= 800)
337 #if defined(__SYCL_DEVICE_ONLY__) && defined(__NVPTX__) && \
338 (__SYCL_CUDA_ARCH__ >= 800)
339 for (
size_t i = 0; i < N / 2; i++) {
343 sycl::detail::memcpy(&res[i * 2], &partial_res,
sizeof(uint32_t));
356 for (
size_t i = 0; i < N; i++) {
357 res[i] =
fma(
x[i],
y[i],
z[i]);
365 template <
typename T1,
typename T2,
typename T3,
int N1 = num_elements_v<T1>,
366 int N2 = num_elements_v<T2>,
int N3 = num_elements_v<T3>>
367 std::enable_if_t<is_vec_or_swizzle_bf16_v<T1> && is_vec_or_swizzle_bf16_v<T2> &&
368 is_vec_or_swizzle_bf16_v<T3> && N1 == N2 && N2 == N3,
371 #if defined(__SYCL_DEVICE_ONLY__) && (defined(__SPIR__) || defined(__SPIRV__))
374 x.template convert<float, sycl::rounding_mode::automatic>();
376 y.template convert<float, sycl::rounding_mode::automatic>();
378 z.template convert<float, sycl::rounding_mode::automatic>();
380 auto Res =
fma(FVecX, FVecY, FVecZ);
381 return Res.template convert<bfloat16>();
384 for (
size_t i = 0; i < N1; i++) {
385 res[i] =
fma(
x[i],
y[i],
z[i]);
393 #define BFLOAT16_MATH_FP32_WRAPPERS(op) \
394 template <typename T> \
395 std::enable_if_t<std::is_same<T, bfloat16>::value, T> op(T x) { \
396 return sycl::ext::oneapi::bfloat16{sycl::op(float{x})}; \
399 #define BFLOAT16_MATH_FP32_WRAPPERS_MARRAY(op) \
400 template <size_t N> \
401 sycl::marray<bfloat16, N> op(sycl::marray<bfloat16, N> x) { \
402 sycl::marray<bfloat16, N> res; \
403 for (size_t i = 0; i < N; i++) { \
409 #if defined(__SYCL_DEVICE_ONLY__) && (defined(__SPIR__) || defined(__SPIRV__))
410 #define BFLOAT16_MATH_FP32_WRAPPERS_VEC(op) \
412 template <typename T, int N = num_elements_v<T>> \
413 std::enable_if_t<is_vec_or_swizzle_bf16_v<T>, sycl::vec<bfloat16, N>> op( \
415 sycl::vec<float, N> FVec = \
416 x.template convert<float, sycl::rounding_mode::automatic>(); \
417 auto Res = op(FVec); \
418 return Res.template convert<bfloat16>(); \
421 #define BFLOAT16_MATH_FP32_WRAPPERS_VEC(op) \
423 template <typename T, int N = num_elements_v<T>> \
424 std::enable_if_t<is_vec_or_swizzle_bf16_v<T>, sycl::vec<bfloat16, N>> op( \
426 sycl::vec<bfloat16, N> res; \
427 for (size_t i = 0; i < N; i++) { \
490 #undef BFLOAT16_MATH_FP32_WRAPPERS
491 #undef BFLOAT16_MATH_FP32_WRAPPERS_MARRAY
492 #undef BFLOAT16_MATH_FP32_WRAPPERS_VEC
#define BFLOAT16_MATH_FP32_WRAPPERS_MARRAY(op)
#define BFLOAT16_MATH_FP32_WRAPPERS(op)
#define BFLOAT16_MATH_FP32_WRAPPERS_VEC(op)
Provides a cross-platform math array class template that works on SYCL devices as well as in host C++...
__ESIMD_API simd< T, N > rsqrt(simd< T, N > src, Sat sat={})
Square root reciprocal - calculates 1/sqrt(x).
__ESIMD_API simd< T, N > log2(simd< T, N > src, Sat sat={})
Logarithm base 2.
__ESIMD_API simd< T, N > exp2(simd< T, N > src, Sat sat={})
Exponent base 2.
bfloat16 bitsToBfloat16(const Bfloat16StorageT Value)
Bfloat16StorageT bfloat16ToBits(const bfloat16 &Value)
uint16_t Bfloat16StorageT
uint32_t to_uint32_t(sycl::marray< bfloat16, N > x, size_t start)
__DPCPP_SYCL_EXTERNAL _SYCL_EXT_CPLX_INLINE_VISIBILITY std::enable_if_t< is_genfloat< _Tp >::value, complex< _Tp > > sin(const complex< _Tp > &__x)
__DPCPP_SYCL_EXTERNAL _SYCL_EXT_CPLX_INLINE_VISIBILITY std::enable_if_t< is_genfloat< _Tp >::value, complex< _Tp > > cos(const complex< _Tp > &__x)
constexpr int num_elements_v
__DPCPP_SYCL_EXTERNAL _SYCL_EXT_CPLX_INLINE_VISIBILITY std::enable_if_t< is_genfloat< _Tp >::value, complex< _Tp > > sqrt(const complex< _Tp > &__x)
std::enable_if_t< std::is_same_v< T, bfloat16 >, bool > isnan(T x)
std::enable_if_t< std::is_same_v< T, bfloat16 >, T > fabs(T x)
__DPCPP_SYCL_EXTERNAL _SYCL_EXT_CPLX_INLINE_VISIBILITY std::enable_if_t< is_genfloat< _Tp >::value, complex< _Tp > > exp(const complex< _Tp > &__x)
__DPCPP_SYCL_EXTERNAL _SYCL_EXT_CPLX_INLINE_VISIBILITY std::enable_if_t< is_genfloat< _Tp >::value, complex< _Tp > > log(const complex< _Tp > &__x)
std::enable_if_t< std::is_same_v< T, bfloat16 >, T > fmin(T x, T y)
constexpr bool is_vec_or_swizzle_bf16_v
std::enable_if_t< std::is_same_v< T, bfloat16 >, T > fma(T x, T y, T z)
__DPCPP_SYCL_EXTERNAL _SYCL_EXT_CPLX_INLINE_VISIBILITY std::enable_if_t< is_genfloat< _Tp >::value, complex< _Tp > > log10(const complex< _Tp > &__x)
std::enable_if_t< std::is_same_v< T, bfloat16 >, T > fmax(T x, T y)
auto auto autodecltype(x) z