DPC++ Runtime
Runtime libraries for oneAPI DPC++
bfloat16_math.hpp
Go to the documentation of this file.
1 //==-------- bfloat16_math.hpp - SYCL bfloat16 math functions --------------==//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #pragma once
10 
11 #include <sycl/builtins.hpp> // for ceil, cos, exp, exp10, exp2
12 #include <sycl/detail/memcpy.hpp> // sycl::detail::memcpy
13 #include <sycl/ext/oneapi/bfloat16.hpp> // for bfloat16, bfloat16ToBits
14 #include <sycl/marray.hpp> // for marray
15 
16 #include <cstring> // for size_t
17 #include <stdint.h> // for uint32_t
18 #include <type_traits> // for enable_if_t, is_same
19 
20 namespace sycl {
21 inline namespace _V1 {
22 namespace ext::oneapi::experimental {
23 
24 namespace detail {
25 template <size_t N>
26 uint32_t to_uint32_t(sycl::marray<bfloat16, N> x, size_t start) {
27  uint32_t res;
28  sycl::detail::memcpy(&res, &x[start], sizeof(uint32_t));
29  return res;
30 }
31 } // namespace detail
32 
33 // According to bfloat16 format, NAN value's exponent field is 0xFF and
34 // significand has non-zero bits.
35 template <typename T>
36 std::enable_if_t<std::is_same_v<T, bfloat16>, bool> isnan(T x) {
38  return (((XBits & 0x7F80) == 0x7F80) && (XBits & 0x7F)) ? true : false;
39 }
40 
43  for (size_t i = 0; i < N; i++) {
44  res[i] = isnan(x[i]);
45  }
46  return res;
47 }
48 
49 template <typename T>
50 std::enable_if_t<std::is_same_v<T, bfloat16>, T> fabs(T x) {
51 #if defined(__SYCL_DEVICE_ONLY__) && defined(__NVPTX__) && \
52  (__SYCL_CUDA_ARCH__ >= 800)
54  return oneapi::detail::bitsToBfloat16(__clc_fabs(XBits));
55 #else
56  if (!isnan(x)) {
57  const static oneapi::detail::Bfloat16StorageT SignMask = 0x8000;
59  x = ((XBits & SignMask) == SignMask)
60  ? oneapi::detail::bitsToBfloat16(XBits & ~SignMask)
61  : x;
62  }
63  return x;
64 #endif // defined(__SYCL_DEVICE_ONLY__) && defined(__NVPTX__) &&
65  // (__SYCL_CUDA_ARCH__ >= 800)
66 }
67 
68 template <size_t N>
71 #if defined(__SYCL_DEVICE_ONLY__) && defined(__NVPTX__) && \
72  (__SYCL_CUDA_ARCH__ >= 800)
73  for (size_t i = 0; i < N / 2; i++) {
74  auto partial_res = __clc_fabs(detail::to_uint32_t(x, i * 2));
75  sycl::detail::memcpy(&res[i * 2], &partial_res, sizeof(uint32_t));
76  }
77 
78  if (N % 2) {
81  res[N - 1] = oneapi::detail::bitsToBfloat16(__clc_fabs(XBits));
82  }
83 #else
84  for (size_t i = 0; i < N; i++) {
85  res[i] = fabs(x[i]);
86  }
87 #endif // defined(__SYCL_DEVICE_ONLY__) && defined(__NVPTX__) &&
88  // (__SYCL_CUDA_ARCH__ >= 800)
89  return res;
90 }
91 
92 template <typename T>
93 std::enable_if_t<std::is_same_v<T, bfloat16>, T> fmin(T x, T y) {
94 #if defined(__SYCL_DEVICE_ONLY__) && defined(__NVPTX__) && \
95  (__SYCL_CUDA_ARCH__ >= 800)
98  return oneapi::detail::bitsToBfloat16(__clc_fmin(XBits, YBits));
99 #else
100  static const oneapi::detail::Bfloat16StorageT CanonicalNan = 0x7FC0;
101  if (isnan(x) && isnan(y))
102  return oneapi::detail::bitsToBfloat16(CanonicalNan);
103 
104  if (isnan(x))
105  return y;
106  if (isnan(y))
107  return x;
110  if (((XBits | YBits) ==
111  static_cast<oneapi::detail::Bfloat16StorageT>(0x8000)) &&
112  !(XBits & YBits))
114  static_cast<oneapi::detail::Bfloat16StorageT>(0x8000));
115 
116  return (x < y) ? x : y;
117 #endif // defined(__SYCL_DEVICE_ONLY__) && defined(__NVPTX__) &&
118  // (__SYCL_CUDA_ARCH__ >= 800)
119 }
120 
121 template <size_t N>
125 #if defined(__SYCL_DEVICE_ONLY__) && defined(__NVPTX__) && \
126  (__SYCL_CUDA_ARCH__ >= 800)
127  for (size_t i = 0; i < N / 2; i++) {
128  auto partial_res = __clc_fmin(detail::to_uint32_t(x, i * 2),
129  detail::to_uint32_t(y, i * 2));
130  sycl::detail::memcpy(&res[i * 2], &partial_res, sizeof(uint32_t));
131  }
132 
133  if (N % 2) {
138  res[N - 1] = oneapi::detail::bitsToBfloat16(__clc_fmin(XBits, YBits));
139  }
140 #else
141  for (size_t i = 0; i < N; i++) {
142  res[i] = fmin(x[i], y[i]);
143  }
144 #endif // defined(__SYCL_DEVICE_ONLY__) && defined(__NVPTX__) &&
145  // (__SYCL_CUDA_ARCH__ >= 800)
146  return res;
147 }
148 
149 template <typename T>
150 std::enable_if_t<std::is_same_v<T, bfloat16>, T> fmax(T x, T y) {
151 #if defined(__SYCL_DEVICE_ONLY__) && defined(__NVPTX__) && \
152  (__SYCL_CUDA_ARCH__ >= 800)
155  return oneapi::detail::bitsToBfloat16(__clc_fmax(XBits, YBits));
156 #else
157  static const oneapi::detail::Bfloat16StorageT CanonicalNan = 0x7FC0;
158  if (isnan(x) && isnan(y))
159  return oneapi::detail::bitsToBfloat16(CanonicalNan);
160 
161  if (isnan(x))
162  return y;
163  if (isnan(y))
164  return x;
167  if (((XBits | YBits) ==
168  static_cast<oneapi::detail::Bfloat16StorageT>(0x8000)) &&
169  !(XBits & YBits))
171 
172  return (x > y) ? x : y;
173 #endif // defined(__SYCL_DEVICE_ONLY__) && defined(__NVPTX__) &&
174  // (__SYCL_CUDA_ARCH__ >= 800)
175 }
176 
177 template <size_t N>
181 #if defined(__SYCL_DEVICE_ONLY__) && defined(__NVPTX__) && \
182  (__SYCL_CUDA_ARCH__ >= 800)
183  for (size_t i = 0; i < N / 2; i++) {
184  auto partial_res = __clc_fmax(detail::to_uint32_t(x, i * 2),
185  detail::to_uint32_t(y, i * 2));
186  sycl::detail::memcpy(&res[i * 2], &partial_res, sizeof(uint32_t));
187  }
188 
189  if (N % 2) {
194  res[N - 1] = oneapi::detail::bitsToBfloat16(__clc_fmax(XBits, YBits));
195  }
196 #else
197  for (size_t i = 0; i < N; i++) {
198  res[i] = fmax(x[i], y[i]);
199  }
200 #endif // defined(__SYCL_DEVICE_ONLY__) && defined(__NVPTX__) &&
201  // (__SYCL_CUDA_ARCH__ >= 800)
202  return res;
203 }
204 
205 template <typename T>
206 std::enable_if_t<std::is_same_v<T, bfloat16>, T> fma(T x, T y, T z) {
207 #if defined(__SYCL_DEVICE_ONLY__) && defined(__NVPTX__) && \
208  (__SYCL_CUDA_ARCH__ >= 800)
212  return oneapi::detail::bitsToBfloat16(__clc_fma(XBits, YBits, ZBits));
213 #else
214  return sycl::ext::oneapi::bfloat16{sycl::fma(float{x}, float{y}, float{z})};
215 #endif // defined(__SYCL_DEVICE_ONLY__) && defined(__NVPTX__) &&
216  // (__SYCL_CUDA_ARCH__ >= 800)
217 }
218 
219 template <size_t N>
224 #if defined(__SYCL_DEVICE_ONLY__) && defined(__NVPTX__) && \
225  (__SYCL_CUDA_ARCH__ >= 800)
226  for (size_t i = 0; i < N / 2; i++) {
227  auto partial_res =
228  __clc_fma(detail::to_uint32_t(x, i * 2), detail::to_uint32_t(y, i * 2),
229  detail::to_uint32_t(z, i * 2));
230  sycl::detail::memcpy(&res[i * 2], &partial_res, sizeof(uint32_t));
231  }
232 
233  if (N % 2) {
240  res[N - 1] = oneapi::detail::bitsToBfloat16(__clc_fma(XBits, YBits, ZBits));
241  }
242 #else
243  for (size_t i = 0; i < N; i++) {
244  res[i] = fma(x[i], y[i], z[i]);
245  }
246 #endif // defined(__SYCL_DEVICE_ONLY__) && defined(__NVPTX__) &&
247  // (__SYCL_CUDA_ARCH__ >= 800)
248  return res;
249 }
250 
// Generates a scalar bfloat16 overload named `op`. The overload converts its
// argument to float, calls sycl::op on that float, and converts the result
// back to bfloat16. It only participates in overload resolution when T is
// exactly bfloat16.
#define BFLOAT16_MATH_FP32_WRAPPERS(op)                                        \
  template <typename T>                                                        \
  std::enable_if_t<std::is_same_v<T, bfloat16>, T> op(T x) {                   \
    return sycl::ext::oneapi::bfloat16{sycl::op(float{x})};                    \
  }
256 
// Generates an marray<bfloat16, N> overload named `op` that applies the scalar
// `op` overload to each element in turn and collects the results.
#define BFLOAT16_MATH_FP32_WRAPPERS_MARRAY(op)                                 \
  template <size_t N>                                                          \
  sycl::marray<bfloat16, N> op(sycl::marray<bfloat16, N> x) {                  \
    sycl::marray<bfloat16, N> out;                                             \
    for (size_t idx = 0; idx != N; ++idx)                                      \
      out[idx] = op(x[idx]);                                                   \
    return out;                                                                \
  }
266 
295 
296 #undef BFLOAT16_MATH_FP32_WRAPPERS
297 #undef BFLOAT16_MATH_FP32_WRAPPERS_MARRAY
298 } // namespace ext::oneapi::experimental
299 } // namespace _V1
300 } // namespace sycl
#define BFLOAT16_MATH_FP32_WRAPPERS_MARRAY(op)
#define BFLOAT16_MATH_FP32_WRAPPERS(op)
Provides a cross-platform math array class template that works on SYCL devices as well as in host C++...
Definition: marray.hpp:49
__ESIMD_API simd< T, N > rsqrt(simd< T, N > src, Sat sat={})
Square root reciprocal - calculates 1/sqrt(x).
Definition: math.hpp:396
__ESIMD_API simd< T, N > log2(simd< T, N > src, Sat sat={})
Logarithm base 2.
Definition: math.hpp:380
__ESIMD_API simd< T, N > exp2(simd< T, N > src, Sat sat={})
Exponent base 2.
Definition: math.hpp:384
bfloat16 bitsToBfloat16(const Bfloat16StorageT Value)
Definition: bfloat16.hpp:247
Bfloat16StorageT bfloat16ToBits(const bfloat16 &Value)
Definition: bfloat16.hpp:241
uint32_t to_uint32_t(sycl::marray< bfloat16, N > x, size_t start)
__DPCPP_SYCL_EXTERNAL _SYCL_EXT_CPLX_INLINE_VISIBILITY std::enable_if_t< is_genfloat< _Tp >::value, complex< _Tp > > sin(const complex< _Tp > &__x)
__DPCPP_SYCL_EXTERNAL _SYCL_EXT_CPLX_INLINE_VISIBILITY std::enable_if_t< is_genfloat< _Tp >::value, complex< _Tp > > cos(const complex< _Tp > &__x)
__DPCPP_SYCL_EXTERNAL _SYCL_EXT_CPLX_INLINE_VISIBILITY std::enable_if_t< is_genfloat< _Tp >::value, complex< _Tp > > sqrt(const complex< _Tp > &__x)
std::enable_if_t< std::is_same_v< T, bfloat16 >, bool > isnan(T x)
std::enable_if_t< std::is_same_v< T, bfloat16 >, T > fabs(T x)
__DPCPP_SYCL_EXTERNAL _SYCL_EXT_CPLX_INLINE_VISIBILITY std::enable_if_t< is_genfloat< _Tp >::value, complex< _Tp > > exp(const complex< _Tp > &__x)
__DPCPP_SYCL_EXTERNAL _SYCL_EXT_CPLX_INLINE_VISIBILITY std::enable_if_t< is_genfloat< _Tp >::value, complex< _Tp > > log(const complex< _Tp > &__x)
std::enable_if_t< std::is_same_v< T, bfloat16 >, T > fmin(T x, T y)
std::enable_if_t< std::is_same_v< T, bfloat16 >, T > fma(T x, T y, T z)
__DPCPP_SYCL_EXTERNAL _SYCL_EXT_CPLX_INLINE_VISIBILITY std::enable_if_t< is_genfloat< _Tp >::value, complex< _Tp > > log10(const complex< _Tp > &__x)
std::enable_if_t< std::is_same_v< T, bfloat16 >, T > fmax(T x, T y)
float ceil(float)
auto auto autodecltype(x) z
float floor(float)
float rint(float)
autodecltype(x) x
float trunc(float)
Definition: access.hpp:18