DPC++ Runtime
Runtime libraries for oneAPI DPC++
bfloat16_math.hpp
Go to the documentation of this file.
1 //==-------- bfloat16_math.hpp - SYCL bfloat16 math functions --------------==//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #pragma once
10 
11 #include <sycl/builtins.hpp> // for ceil, cos, exp, exp10, exp2
12 #include <sycl/builtins_utils_vec.hpp> // For simplify_if_swizzle, is_swizzle
13 #include <sycl/detail/memcpy.hpp> // sycl::detail::memcpy
14 #include <sycl/ext/oneapi/bfloat16.hpp> // for bfloat16, bfloat16ToBits
15 #include <sycl/marray.hpp> // for marray
16 
17 #include <cstring> // for size_t
18 #include <stdint.h> // for uint32_t
19 #include <type_traits> // for enable_if_t, is_same
20 
21 namespace sycl {
22 inline namespace _V1 {
23 namespace ext::oneapi::experimental {
24 
25 namespace detail {
26 template <size_t N>
27 uint32_t to_uint32_t(sycl::marray<bfloat16, N> x, size_t start) {
28  uint32_t res;
29  sycl::detail::memcpy(&res, &x[start], sizeof(uint32_t));
30  return res;
31 }
32 } // namespace detail
33 
34 // Trait to check if the type is a vector or swizzle of bfloat16.
35 template <typename T>
36 constexpr bool is_vec_or_swizzle_bf16_v =
37  sycl::detail::is_vec_or_swizzle_v<T> &&
38  sycl::detail::is_valid_elem_type_v<T, bfloat16>;
39 
40 template <typename T>
41 constexpr int num_elements_v = sycl::detail::num_elements<T>::value;
42 
43 /******************* isnan ********************/
44 
45 // According to bfloat16 format, NAN value's exponent field is 0xFF and
46 // significand has non-zero bits.
47 template <typename T>
48 std::enable_if_t<std::is_same_v<T, bfloat16>, bool> isnan(T x) {
50  return (((XBits & 0x7F80) == 0x7F80) && (XBits & 0x7F)) ? true : false;
51 }
52 
55  for (size_t i = 0; i < N; i++) {
56  res[i] = isnan(x[i]);
57  }
58  return res;
59 }
60 
61 // Overload for BF16 vec and swizzles.
62 template <typename T, int N = num_elements_v<T>>
63 std::enable_if_t<is_vec_or_swizzle_bf16_v<T>, sycl::vec<int16_t, N>>
64 isnan(T x) {
66  for (size_t i = 0; i < N; i++) {
67  // The result of isnan is 0 or 1 but SPEC requires
68  // isnan() of vec/swizzle to return -1 or 0.
69  res[i] = isnan(x[i]) ? -1 : 0;
70  }
71  return res;
72 }
73 
74 /******************* fabs ********************/
75 
76 template <typename T>
77 std::enable_if_t<std::is_same_v<T, bfloat16>, T> fabs(T x) {
78 #if defined(__SYCL_DEVICE_ONLY__) && defined(__NVPTX__) && \
79  (__SYCL_CUDA_ARCH__ >= 800)
81  return oneapi::detail::bitsToBfloat16(__clc_fabs(XBits));
82 #else
83  if (!isnan(x)) {
84  const static oneapi::detail::Bfloat16StorageT SignMask = 0x8000;
86  x = ((XBits & SignMask) == SignMask)
87  ? oneapi::detail::bitsToBfloat16(XBits & ~SignMask)
88  : x;
89  }
90  return x;
91 #endif // defined(__SYCL_DEVICE_ONLY__) && defined(__NVPTX__) &&
92  // (__SYCL_CUDA_ARCH__ >= 800)
93 }
94 
95 template <size_t N>
98 #if defined(__SYCL_DEVICE_ONLY__) && defined(__NVPTX__) && \
99  (__SYCL_CUDA_ARCH__ >= 800)
100  for (size_t i = 0; i < N / 2; i++) {
101  auto partial_res = __clc_fabs(detail::to_uint32_t(x, i * 2));
102  sycl::detail::memcpy(&res[i * 2], &partial_res, sizeof(uint32_t));
103  }
104 
105  if (N % 2) {
108  res[N - 1] = oneapi::detail::bitsToBfloat16(__clc_fabs(XBits));
109  }
110 #else
111  for (size_t i = 0; i < N; i++) {
112  res[i] = fabs(x[i]);
113  }
114 #endif // defined(__SYCL_DEVICE_ONLY__) && defined(__NVPTX__) &&
115  // (__SYCL_CUDA_ARCH__ >= 800)
116  return res;
117 }
118 
119 // Overload for BF16 vec and swizzles.
120 template <typename T, int N = num_elements_v<T>>
121 std::enable_if_t<is_vec_or_swizzle_bf16_v<T>, sycl::vec<bfloat16, N>>
122 fabs(T x) {
124  for (size_t i = 0; i < N; i++) {
125  res[i] = fabs(x[i]);
126  }
127  return res;
128 }
129 
130 /******************* fmin ********************/
131 
132 template <typename T>
133 std::enable_if_t<std::is_same_v<T, bfloat16>, T> fmin(T x, T y) {
134 #if defined(__SYCL_DEVICE_ONLY__) && defined(__NVPTX__) && \
135  (__SYCL_CUDA_ARCH__ >= 800)
138  return oneapi::detail::bitsToBfloat16(__clc_fmin(XBits, YBits));
139 #else
140  static const oneapi::detail::Bfloat16StorageT CanonicalNan = 0x7FC0;
141  if (isnan(x) && isnan(y))
142  return oneapi::detail::bitsToBfloat16(CanonicalNan);
143 
144  if (isnan(x))
145  return y;
146  if (isnan(y))
147  return x;
150  if (((XBits | YBits) ==
151  static_cast<oneapi::detail::Bfloat16StorageT>(0x8000)) &&
152  !(XBits & YBits))
154  static_cast<oneapi::detail::Bfloat16StorageT>(0x8000));
155 
156  return (x < y) ? x : y;
157 #endif // defined(__SYCL_DEVICE_ONLY__) && defined(__NVPTX__) &&
158  // (__SYCL_CUDA_ARCH__ >= 800)
159 }
160 
161 template <size_t N>
165 #if defined(__SYCL_DEVICE_ONLY__) && defined(__NVPTX__) && \
166  (__SYCL_CUDA_ARCH__ >= 800)
167  for (size_t i = 0; i < N / 2; i++) {
168  auto partial_res = __clc_fmin(detail::to_uint32_t(x, i * 2),
169  detail::to_uint32_t(y, i * 2));
170  sycl::detail::memcpy(&res[i * 2], &partial_res, sizeof(uint32_t));
171  }
172 
173  if (N % 2) {
178  res[N - 1] = oneapi::detail::bitsToBfloat16(__clc_fmin(XBits, YBits));
179  }
180 #else
181  for (size_t i = 0; i < N; i++) {
182  res[i] = fmin(x[i], y[i]);
183  }
184 #endif // defined(__SYCL_DEVICE_ONLY__) && defined(__NVPTX__) &&
185  // (__SYCL_CUDA_ARCH__ >= 800)
186  return res;
187 }
188 
189 // Overload for different combination of BF16 vec and swizzles.
190 template <typename T1, typename T2, int N1 = num_elements_v<T1>,
191  int N2 = num_elements_v<T2>>
192 std::enable_if_t<is_vec_or_swizzle_bf16_v<T1> && is_vec_or_swizzle_bf16_v<T2> &&
193  N1 == N2,
195 fmin(T1 x, T2 y) {
197  for (size_t i = 0; i < N1; i++) {
198  res[i] = fmin(x[i], y[i]);
199  }
200  return res;
201 }
202 
203 /******************* fmax ********************/
204 
205 template <typename T>
206 std::enable_if_t<std::is_same_v<T, bfloat16>, T> fmax(T x, T y) {
207 #if defined(__SYCL_DEVICE_ONLY__) && defined(__NVPTX__) && \
208  (__SYCL_CUDA_ARCH__ >= 800)
211  return oneapi::detail::bitsToBfloat16(__clc_fmax(XBits, YBits));
212 #else
213  static const oneapi::detail::Bfloat16StorageT CanonicalNan = 0x7FC0;
214  if (isnan(x) && isnan(y))
215  return oneapi::detail::bitsToBfloat16(CanonicalNan);
216 
217  if (isnan(x))
218  return y;
219  if (isnan(y))
220  return x;
223  if (((XBits | YBits) ==
224  static_cast<oneapi::detail::Bfloat16StorageT>(0x8000)) &&
225  !(XBits & YBits))
227 
228  return (x > y) ? x : y;
229 #endif // defined(__SYCL_DEVICE_ONLY__) && defined(__NVPTX__) &&
230  // (__SYCL_CUDA_ARCH__ >= 800)
231 }
232 
233 template <size_t N>
237 #if defined(__SYCL_DEVICE_ONLY__) && defined(__NVPTX__) && \
238  (__SYCL_CUDA_ARCH__ >= 800)
239  for (size_t i = 0; i < N / 2; i++) {
240  auto partial_res = __clc_fmax(detail::to_uint32_t(x, i * 2),
241  detail::to_uint32_t(y, i * 2));
242  sycl::detail::memcpy(&res[i * 2], &partial_res, sizeof(uint32_t));
243  }
244 
245  if (N % 2) {
250  res[N - 1] = oneapi::detail::bitsToBfloat16(__clc_fmax(XBits, YBits));
251  }
252 #else
253  for (size_t i = 0; i < N; i++) {
254  res[i] = fmax(x[i], y[i]);
255  }
256 #endif // defined(__SYCL_DEVICE_ONLY__) && defined(__NVPTX__) &&
257  // (__SYCL_CUDA_ARCH__ >= 800)
258  return res;
259 }
260 
261 // Overload for different combination of BF16 vec and swizzles.
262 template <typename T1, typename T2, int N1 = num_elements_v<T1>,
263  int N2 = num_elements_v<T2>>
264 std::enable_if_t<is_vec_or_swizzle_bf16_v<T1> && is_vec_or_swizzle_bf16_v<T2> &&
265  N1 == N2,
267 fmax(T1 x, T2 y) {
269  for (size_t i = 0; i < N1; i++) {
270  res[i] = fmax(x[i], y[i]);
271  }
272  return res;
273 }
274 
275 /******************* fma *********************/
276 
277 template <typename T>
278 std::enable_if_t<std::is_same_v<T, bfloat16>, T> fma(T x, T y, T z) {
279 #if defined(__SYCL_DEVICE_ONLY__) && defined(__NVPTX__) && \
280  (__SYCL_CUDA_ARCH__ >= 800)
284  return oneapi::detail::bitsToBfloat16(__clc_fma(XBits, YBits, ZBits));
285 #else
286  return sycl::ext::oneapi::bfloat16{sycl::fma(float{x}, float{y}, float{z})};
287 #endif // defined(__SYCL_DEVICE_ONLY__) && defined(__NVPTX__) &&
288  // (__SYCL_CUDA_ARCH__ >= 800)
289 }
290 
291 template <size_t N>
296 #if defined(__SYCL_DEVICE_ONLY__) && defined(__NVPTX__) && \
297  (__SYCL_CUDA_ARCH__ >= 800)
298  for (size_t i = 0; i < N / 2; i++) {
299  auto partial_res =
300  __clc_fma(detail::to_uint32_t(x, i * 2), detail::to_uint32_t(y, i * 2),
301  detail::to_uint32_t(z, i * 2));
302  sycl::detail::memcpy(&res[i * 2], &partial_res, sizeof(uint32_t));
303  }
304 
305  if (N % 2) {
312  res[N - 1] = oneapi::detail::bitsToBfloat16(__clc_fma(XBits, YBits, ZBits));
313  }
314 #else
315  for (size_t i = 0; i < N; i++) {
316  res[i] = fma(x[i], y[i], z[i]);
317  }
318 #endif // defined(__SYCL_DEVICE_ONLY__) && defined(__NVPTX__) &&
319  // (__SYCL_CUDA_ARCH__ >= 800)
320  return res;
321 }
322 
323 // Overload for different combination of BF16 vec and swizzles.
324 template <typename T1, typename T2, typename T3, int N1 = num_elements_v<T1>,
325  int N2 = num_elements_v<T2>, int N3 = num_elements_v<T3>>
326 std::enable_if_t<is_vec_or_swizzle_bf16_v<T1> && is_vec_or_swizzle_bf16_v<T2> &&
327  is_vec_or_swizzle_bf16_v<T3> && N1 == N2 && N2 == N3,
329 fma(T1 x, T2 y, T3 z) {
331  for (size_t i = 0; i < N1; i++) {
332  res[i] = fma(x[i], y[i], z[i]);
333  }
334  return res;
335 }
336 
337 /******************* unary math operations ********************/
338 
// Defines the scalar bfloat16 overload of the unary function `op`: the
// argument is converted to float, sycl::op is evaluated in float, and the
// result is converted back to bfloat16. Uses std::is_same_v like the other
// overloads in this header.
#define BFLOAT16_MATH_FP32_WRAPPERS(op)                                        \
  template <typename T>                                                        \
  std::enable_if_t<std::is_same_v<T, bfloat16>, T> op(T x) {                   \
    return sycl::ext::oneapi::bfloat16{sycl::op(float{x})};                    \
  }
344 
// Defines the sycl::marray<bfloat16, N> overload of the unary function `op`,
// which applies the scalar bfloat16 overload to every element. The argument
// is received by value, so it is rewritten in place and handed back.
#define BFLOAT16_MATH_FP32_WRAPPERS_MARRAY(op)                                 \
  template <size_t N>                                                          \
  sycl::marray<bfloat16, N> op(sycl::marray<bfloat16, N> x) {                  \
    for (bfloat16 &elem : x)                                                   \
      elem = op(elem);                                                         \
    return x;                                                                  \
  }
354 
// Defines the sycl::vec<bfloat16, N> / swizzle overload of the unary function
// `op`, which applies the scalar bfloat16 overload element by element.
#define BFLOAT16_MATH_FP32_WRAPPERS_VEC(op)                                    \
  /* Overload for BF16 vec and swizzles. */                                    \
  template <typename T, int N = num_elements_v<T>>                             \
  std::enable_if_t<is_vec_or_swizzle_bf16_v<T>, sycl::vec<bfloat16, N>> op(    \
      T x) {                                                                   \
    sycl::vec<bfloat16, N> res;                                                \
    /* N is an int template parameter; an int index keeps the loop */          \
    /* condition free of signed/unsigned mixing. */                            \
    for (int i = 0; i < N; ++i) {                                              \
      res[i] = op(x[i]);                                                       \
    }                                                                          \
    return res;                                                                \
  }
366 
370 
374 
378 
382 
386 
390 
394 
398 
402 
406 
410 
414 
418 
422 
423 #undef BFLOAT16_MATH_FP32_WRAPPERS
424 #undef BFLOAT16_MATH_FP32_WRAPPERS_MARRAY
425 #undef BFLOAT16_MATH_FP32_WRAPPERS_VEC
426 } // namespace ext::oneapi::experimental
427 } // namespace _V1
428 } // namespace sycl
#define BFLOAT16_MATH_FP32_WRAPPERS_MARRAY(op)
#define BFLOAT16_MATH_FP32_WRAPPERS(op)
#define BFLOAT16_MATH_FP32_WRAPPERS_VEC(op)
Provides a cross-platform math array class template that works on SYCL devices as well as in host C++...
Definition: marray.hpp:49
class sycl::vec ///////////////////////// Provides a cross-platform vector class template that works e...
__ESIMD_API simd< T, N > rsqrt(simd< T, N > src, Sat sat={})
Square root reciprocal - calculates 1/sqrt(x).
Definition: math.hpp:432
__ESIMD_API simd< T, N > log2(simd< T, N > src, Sat sat={})
Logarithm base 2.
Definition: math.hpp:398
__ESIMD_API simd< T, N > exp2(simd< T, N > src, Sat sat={})
Exponent base 2.
Definition: math.hpp:402
bfloat16 bitsToBfloat16(const Bfloat16StorageT Value)
Definition: bfloat16.hpp:323
Bfloat16StorageT bfloat16ToBits(const bfloat16 &Value)
Definition: bfloat16.hpp:317
uint32_t to_uint32_t(sycl::marray< bfloat16, N > x, size_t start)
__DPCPP_SYCL_EXTERNAL _SYCL_EXT_CPLX_INLINE_VISIBILITY std::enable_if_t< is_genfloat< _Tp >::value, complex< _Tp > > sin(const complex< _Tp > &__x)
__DPCPP_SYCL_EXTERNAL _SYCL_EXT_CPLX_INLINE_VISIBILITY std::enable_if_t< is_genfloat< _Tp >::value, complex< _Tp > > cos(const complex< _Tp > &__x)
__DPCPP_SYCL_EXTERNAL _SYCL_EXT_CPLX_INLINE_VISIBILITY std::enable_if_t< is_genfloat< _Tp >::value, complex< _Tp > > sqrt(const complex< _Tp > &__x)
std::enable_if_t< std::is_same_v< T, bfloat16 >, bool > isnan(T x)
std::enable_if_t< std::is_same_v< T, bfloat16 >, T > fabs(T x)
__DPCPP_SYCL_EXTERNAL _SYCL_EXT_CPLX_INLINE_VISIBILITY std::enable_if_t< is_genfloat< _Tp >::value, complex< _Tp > > exp(const complex< _Tp > &__x)
__DPCPP_SYCL_EXTERNAL _SYCL_EXT_CPLX_INLINE_VISIBILITY std::enable_if_t< is_genfloat< _Tp >::value, complex< _Tp > > log(const complex< _Tp > &__x)
std::enable_if_t< std::is_same_v< T, bfloat16 >, T > fmin(T x, T y)
std::enable_if_t< std::is_same_v< T, bfloat16 >, T > fma(T x, T y, T z)
__DPCPP_SYCL_EXTERNAL _SYCL_EXT_CPLX_INLINE_VISIBILITY std::enable_if_t< is_genfloat< _Tp >::value, complex< _Tp > > log10(const complex< _Tp > &__x)
std::enable_if_t< std::is_same_v< T, bfloat16 >, T > fmax(T x, T y)
float ceil(float)
auto auto autodecltype(x) z
float floor(float)
float rint(float)
autodecltype(x) x
float trunc(float)
Definition: access.hpp:18