DPC++ Runtime
Runtime libraries for oneAPI DPC++
bfloat16_math.hpp
Go to the documentation of this file.
1 //==-------- bfloat16_math.hpp - SYCL bfloat16 math functions --------------==//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #pragma once
10 
11 #include <sycl/builtins.hpp> // for ceil, cos, exp, exp10, exp2
12 #include <sycl/builtins_utils_vec.hpp> // For simplify_if_swizzle, is_swizzle
13 #include <sycl/detail/memcpy.hpp> // sycl::detail::memcpy
14 #include <sycl/ext/oneapi/bfloat16.hpp> // for bfloat16, bfloat16ToBits
15 #include <sycl/marray.hpp> // for marray
16 
17 #include <cstring> // for size_t
18 #include <stdint.h> // for uint32_t
19 #include <type_traits> // for enable_if_t, is_same
20 
21 namespace sycl {
22 inline namespace _V1 {
23 namespace ext::oneapi::experimental {
24 
25 namespace detail {
26 template <size_t N>
27 uint32_t to_uint32_t(sycl::marray<bfloat16, N> x, size_t start) {
28  uint32_t res;
29  sycl::detail::memcpy(&res, &x[start], sizeof(uint32_t));
30  return res;
31 }
32 } // namespace detail
33 
34 // Trait to check if the type is a vector or swizzle of bfloat16.
35 template <typename T>
36 constexpr bool is_vec_or_swizzle_bf16_v =
37  sycl::detail::is_vec_or_swizzle_v<T> &&
38  sycl::detail::is_valid_elem_type_v<T, bfloat16>;
39 
40 template <typename T>
41 constexpr int num_elements_v = sycl::detail::num_elements<T>::value;
42 
43 /******************* isnan ********************/
44 
45 // According to bfloat16 format, NAN value's exponent field is 0xFF and
46 // significand has non-zero bits.
47 template <typename T>
48 std::enable_if_t<std::is_same_v<T, bfloat16>, bool> isnan(T x) {
50  return (((XBits & 0x7F80) == 0x7F80) && (XBits & 0x7F)) ? true : false;
51 }
52 
55  for (size_t i = 0; i < N; i++) {
56  res[i] = isnan(x[i]);
57  }
58  return res;
59 }
60 
61 // Overload for BF16 vec and swizzles.
62 template <typename T, int N = num_elements_v<T>>
63 std::enable_if_t<is_vec_or_swizzle_bf16_v<T>, sycl::vec<int16_t, N>>
64 isnan(T x) {
65 
66 #if defined(__SYCL_DEVICE_ONLY__) && (defined(__SPIR__) || defined(__SPIRV__))
67  // Convert BFloat16 vector to float vec and call isnan().
68  sycl::vec<float, N> FVec =
69  x.template convert<float, sycl::rounding_mode::automatic>();
70  auto Res = isnan(FVec);
71 
72  // For vec<float>, the return type of isnan is vec<int32_t> so,
73  // an explicit conversion is required to vec<int16_t>.
74  return Res.template convert<int16_t>();
75 #else
76 
78  for (size_t i = 0; i < N; i++) {
79  // The result of isnan is 0 or 1 but SPEC requires
80  // isnan() of vec/swizzle to return -1 or 0.
81  res[i] = isnan(x[i]) ? -1 : 0;
82  }
83  return res;
84 #endif
85 }
86 
87 /******************* fabs ********************/
88 
89 template <typename T>
90 std::enable_if_t<std::is_same_v<T, bfloat16>, T> fabs(T x) {
91 #if defined(__SYCL_DEVICE_ONLY__) && defined(__NVPTX__) && \
92  (__SYCL_CUDA_ARCH__ >= 800)
94  return oneapi::detail::bitsToBfloat16(__clc_fabs(XBits));
95 #else
96  if (!isnan(x)) {
97  const static oneapi::detail::Bfloat16StorageT SignMask = 0x8000;
99  x = ((XBits & SignMask) == SignMask)
100  ? oneapi::detail::bitsToBfloat16(XBits & ~SignMask)
101  : x;
102  }
103  return x;
104 #endif // defined(__SYCL_DEVICE_ONLY__) && defined(__NVPTX__) &&
105  // (__SYCL_CUDA_ARCH__ >= 800)
106 }
107 
108 template <size_t N>
111 #if defined(__SYCL_DEVICE_ONLY__) && defined(__NVPTX__) && \
112  (__SYCL_CUDA_ARCH__ >= 800)
113  for (size_t i = 0; i < N / 2; i++) {
114  auto partial_res = __clc_fabs(detail::to_uint32_t(x, i * 2));
115  sycl::detail::memcpy(&res[i * 2], &partial_res, sizeof(uint32_t));
116  }
117 
118  if (N % 2) {
121  res[N - 1] = oneapi::detail::bitsToBfloat16(__clc_fabs(XBits));
122  }
123 #else
124  for (size_t i = 0; i < N; i++) {
125  res[i] = fabs(x[i]);
126  }
127 #endif // defined(__SYCL_DEVICE_ONLY__) && defined(__NVPTX__) &&
128  // (__SYCL_CUDA_ARCH__ >= 800)
129  return res;
130 }
131 
132 // Overload for BF16 vec and swizzles.
133 template <typename T, int N = num_elements_v<T>>
134 std::enable_if_t<is_vec_or_swizzle_bf16_v<T>, sycl::vec<bfloat16, N>>
135 fabs(T x) {
136 #if defined(__SYCL_DEVICE_ONLY__) && (defined(__SPIR__) || defined(__SPIRV__))
137  // Convert BFloat16 vector to float vec.
138  sycl::vec<float, N> FVec =
139  x.template convert<float, sycl::rounding_mode::automatic>();
140  auto Res = fabs(FVec);
141  return Res.template convert<bfloat16>();
142 #else
144  for (size_t i = 0; i < N; i++) {
145  res[i] = fabs(x[i]);
146  }
147  return res;
148 #endif
149 }
150 
151 /******************* fmin ********************/
152 
153 template <typename T>
154 std::enable_if_t<std::is_same_v<T, bfloat16>, T> fmin(T x, T y) {
155 #if defined(__SYCL_DEVICE_ONLY__) && defined(__NVPTX__) && \
156  (__SYCL_CUDA_ARCH__ >= 800)
159  return oneapi::detail::bitsToBfloat16(__clc_fmin(XBits, YBits));
160 #else
161  static const oneapi::detail::Bfloat16StorageT CanonicalNan = 0x7FC0;
162  if (isnan(x) && isnan(y))
163  return oneapi::detail::bitsToBfloat16(CanonicalNan);
164 
165  if (isnan(x))
166  return y;
167  if (isnan(y))
168  return x;
171  if (((XBits | YBits) ==
172  static_cast<oneapi::detail::Bfloat16StorageT>(0x8000)) &&
173  !(XBits & YBits))
175  static_cast<oneapi::detail::Bfloat16StorageT>(0x8000));
176 
177  return (x < y) ? x : y;
178 #endif // defined(__SYCL_DEVICE_ONLY__) && defined(__NVPTX__) &&
179  // (__SYCL_CUDA_ARCH__ >= 800)
180 }
181 
182 template <size_t N>
186 #if defined(__SYCL_DEVICE_ONLY__) && defined(__NVPTX__) && \
187  (__SYCL_CUDA_ARCH__ >= 800)
188  for (size_t i = 0; i < N / 2; i++) {
189  auto partial_res = __clc_fmin(detail::to_uint32_t(x, i * 2),
190  detail::to_uint32_t(y, i * 2));
191  sycl::detail::memcpy(&res[i * 2], &partial_res, sizeof(uint32_t));
192  }
193 
194  if (N % 2) {
199  res[N - 1] = oneapi::detail::bitsToBfloat16(__clc_fmin(XBits, YBits));
200  }
201 #else
202  for (size_t i = 0; i < N; i++) {
203  res[i] = fmin(x[i], y[i]);
204  }
205 #endif // defined(__SYCL_DEVICE_ONLY__) && defined(__NVPTX__) &&
206  // (__SYCL_CUDA_ARCH__ >= 800)
207  return res;
208 }
209 
210 // Overload for different combination of BF16 vec and swizzles.
211 template <typename T1, typename T2, int N1 = num_elements_v<T1>,
212  int N2 = num_elements_v<T2>>
213 std::enable_if_t<is_vec_or_swizzle_bf16_v<T1> && is_vec_or_swizzle_bf16_v<T2> &&
214  N1 == N2,
216 fmin(T1 x, T2 y) {
217 #if defined(__SYCL_DEVICE_ONLY__) && (defined(__SPIR__) || defined(__SPIRV__))
218  // Convert BFloat16 vectors to float vecs.
219  sycl::vec<float, N1> FVecX =
220  x.template convert<float, sycl::rounding_mode::automatic>();
221  sycl::vec<float, N1> FVecY =
222  y.template convert<float, sycl::rounding_mode::automatic>();
223  auto Res = fmin(FVecX, FVecY);
224  return Res.template convert<bfloat16>();
225 #else
227  for (size_t i = 0; i < N1; i++) {
228  res[i] = fmin(x[i], y[i]);
229  }
230  return res;
231 #endif
232 }
233 
234 /******************* fmax ********************/
235 
236 template <typename T>
237 std::enable_if_t<std::is_same_v<T, bfloat16>, T> fmax(T x, T y) {
238 #if defined(__SYCL_DEVICE_ONLY__) && defined(__NVPTX__) && \
239  (__SYCL_CUDA_ARCH__ >= 800)
242  return oneapi::detail::bitsToBfloat16(__clc_fmax(XBits, YBits));
243 #else
244  static const oneapi::detail::Bfloat16StorageT CanonicalNan = 0x7FC0;
245  if (isnan(x) && isnan(y))
246  return oneapi::detail::bitsToBfloat16(CanonicalNan);
247 
248  if (isnan(x))
249  return y;
250  if (isnan(y))
251  return x;
254  if (((XBits | YBits) ==
255  static_cast<oneapi::detail::Bfloat16StorageT>(0x8000)) &&
256  !(XBits & YBits))
258 
259  return (x > y) ? x : y;
260 #endif // defined(__SYCL_DEVICE_ONLY__) && defined(__NVPTX__) &&
261  // (__SYCL_CUDA_ARCH__ >= 800)
262 }
263 
264 template <size_t N>
268 #if defined(__SYCL_DEVICE_ONLY__) && defined(__NVPTX__) && \
269  (__SYCL_CUDA_ARCH__ >= 800)
270  for (size_t i = 0; i < N / 2; i++) {
271  auto partial_res = __clc_fmax(detail::to_uint32_t(x, i * 2),
272  detail::to_uint32_t(y, i * 2));
273  sycl::detail::memcpy(&res[i * 2], &partial_res, sizeof(uint32_t));
274  }
275 
276  if (N % 2) {
281  res[N - 1] = oneapi::detail::bitsToBfloat16(__clc_fmax(XBits, YBits));
282  }
283 #else
284  for (size_t i = 0; i < N; i++) {
285  res[i] = fmax(x[i], y[i]);
286  }
287 #endif // defined(__SYCL_DEVICE_ONLY__) && defined(__NVPTX__) &&
288  // (__SYCL_CUDA_ARCH__ >= 800)
289  return res;
290 }
291 
292 // Overload for different combination of BF16 vec and swizzles.
293 template <typename T1, typename T2, int N1 = num_elements_v<T1>,
294  int N2 = num_elements_v<T2>>
295 std::enable_if_t<is_vec_or_swizzle_bf16_v<T1> && is_vec_or_swizzle_bf16_v<T2> &&
296  N1 == N2,
298 fmax(T1 x, T2 y) {
299 #if defined(__SYCL_DEVICE_ONLY__) && (defined(__SPIR__) || defined(__SPIRV__))
300  // Convert BFloat16 vectors to float vecs.
301  sycl::vec<float, N1> FVecX =
302  x.template convert<float, sycl::rounding_mode::automatic>();
303  sycl::vec<float, N1> FVecY =
304  y.template convert<float, sycl::rounding_mode::automatic>();
305  auto Res = fmax(FVecX, FVecY);
306  return Res.template convert<bfloat16>();
307 #else
309  for (size_t i = 0; i < N1; i++) {
310  res[i] = fmax(x[i], y[i]);
311  }
312  return res;
313 #endif
314 }
315 
316 /******************* fma *********************/
317 
318 template <typename T>
319 std::enable_if_t<std::is_same_v<T, bfloat16>, T> fma(T x, T y, T z) {
320 #if defined(__SYCL_DEVICE_ONLY__) && defined(__NVPTX__) && \
321  (__SYCL_CUDA_ARCH__ >= 800)
325  return oneapi::detail::bitsToBfloat16(__clc_fma(XBits, YBits, ZBits));
326 #else
327  return sycl::ext::oneapi::bfloat16{sycl::fma(float{x}, float{y}, float{z})};
328 #endif // defined(__SYCL_DEVICE_ONLY__) && defined(__NVPTX__) &&
329  // (__SYCL_CUDA_ARCH__ >= 800)
330 }
331 
332 template <size_t N>
337 #if defined(__SYCL_DEVICE_ONLY__) && defined(__NVPTX__) && \
338  (__SYCL_CUDA_ARCH__ >= 800)
339  for (size_t i = 0; i < N / 2; i++) {
340  auto partial_res =
341  __clc_fma(detail::to_uint32_t(x, i * 2), detail::to_uint32_t(y, i * 2),
342  detail::to_uint32_t(z, i * 2));
343  sycl::detail::memcpy(&res[i * 2], &partial_res, sizeof(uint32_t));
344  }
345 
346  if (N % 2) {
353  res[N - 1] = oneapi::detail::bitsToBfloat16(__clc_fma(XBits, YBits, ZBits));
354  }
355 #else
356  for (size_t i = 0; i < N; i++) {
357  res[i] = fma(x[i], y[i], z[i]);
358  }
359 #endif // defined(__SYCL_DEVICE_ONLY__) && defined(__NVPTX__) &&
360  // (__SYCL_CUDA_ARCH__ >= 800)
361  return res;
362 }
363 
364 // Overload for different combination of BF16 vec and swizzles.
365 template <typename T1, typename T2, typename T3, int N1 = num_elements_v<T1>,
366  int N2 = num_elements_v<T2>, int N3 = num_elements_v<T3>>
367 std::enable_if_t<is_vec_or_swizzle_bf16_v<T1> && is_vec_or_swizzle_bf16_v<T2> &&
368  is_vec_or_swizzle_bf16_v<T3> && N1 == N2 && N2 == N3,
370 fma(T1 x, T2 y, T3 z) {
371 #if defined(__SYCL_DEVICE_ONLY__) && (defined(__SPIR__) || defined(__SPIRV__))
372  // Convert BFloat16 vectors to float vecs.
373  sycl::vec<float, N1> FVecX =
374  x.template convert<float, sycl::rounding_mode::automatic>();
375  sycl::vec<float, N1> FVecY =
376  y.template convert<float, sycl::rounding_mode::automatic>();
377  sycl::vec<float, N1> FVecZ =
378  z.template convert<float, sycl::rounding_mode::automatic>();
379 
380  auto Res = fma(FVecX, FVecY, FVecZ);
381  return Res.template convert<bfloat16>();
382 #else
384  for (size_t i = 0; i < N1; i++) {
385  res[i] = fma(x[i], y[i], z[i]);
386  }
387  return res;
388 #endif
389 }
390 
391 /******************* unary math operations ********************/
392 
// Defines the scalar bfloat16 overload of `op`. The argument is widened to
// float, sycl::op is evaluated in float, and the result is converted back to
// bfloat16. The overload only participates when T is exactly bfloat16
// (std::is_same_v, matching the other scalar overloads in this file).
#define BFLOAT16_MATH_FP32_WRAPPERS(op)                                        \
  template <typename T>                                                        \
  std::enable_if_t<std::is_same_v<T, bfloat16>, T> op(T x) {                   \
    return sycl::ext::oneapi::bfloat16{sycl::op(float{x})};                    \
  }
398 
// Defines the sycl::marray<bfloat16, N> overload of `op`. It calls `op` on
// each bfloat16 element and returns the element-wise results.
#define BFLOAT16_MATH_FP32_WRAPPERS_MARRAY(op)                                 \
  template <size_t N>                                                          \
  sycl::marray<bfloat16, N> op(sycl::marray<bfloat16, N> x) {                  \
    /* x is a by-value copy, so it is overwritten in place and returned. */    \
    for (size_t Idx = 0; Idx != N; ++Idx)                                      \
      x[Idx] = op(x[Idx]);                                                     \
    return x;                                                                  \
  }
408 
#if defined(__SYCL_DEVICE_ONLY__) && (defined(__SPIR__) || defined(__SPIRV__))
// SPIR/SPIR-V device path: widen the whole vector to float, call the float
// vector overload of `op` once, then convert the result back to bfloat16.
#define BFLOAT16_MATH_FP32_WRAPPERS_VEC(op)                                    \
  /* Overload for BF16 vec and swizzles. */                                    \
  template <typename T, int N = num_elements_v<T>>                             \
  std::enable_if_t<is_vec_or_swizzle_bf16_v<T>, sycl::vec<bfloat16, N>> op(    \
      T x) {                                                                   \
    sycl::vec<float, N> AsFloat =                                              \
        x.template convert<float, sycl::rounding_mode::automatic>();           \
    auto Evaluated = op(AsFloat);                                              \
    return Evaluated.template convert<bfloat16>();                             \
  }
#else
// All other targets: apply `op` to each bfloat16 lane individually.
#define BFLOAT16_MATH_FP32_WRAPPERS_VEC(op)                                    \
  /* Overload for BF16 vec and swizzles. */                                    \
  template <typename T, int N = num_elements_v<T>>                             \
  std::enable_if_t<is_vec_or_swizzle_bf16_v<T>, sycl::vec<bfloat16, N>> op(    \
      T x) {                                                                   \
    sycl::vec<bfloat16, N> Out;                                                \
    for (int Lane = 0; Lane < N; ++Lane)                                       \
      Out[Lane] = op(x[Lane]);                                                 \
    return Out;                                                                \
  }
#endif
433 
437 
441 
445 
449 
453 
457 
461 
465 
469 
473 
477 
481 
485 
489 
490 #undef BFLOAT16_MATH_FP32_WRAPPERS
491 #undef BFLOAT16_MATH_FP32_WRAPPERS_MARRAY
492 #undef BFLOAT16_MATH_FP32_WRAPPERS_VEC
493 } // namespace ext::oneapi::experimental
494 } // namespace _V1
495 } // namespace sycl
#define BFLOAT16_MATH_FP32_WRAPPERS_MARRAY(op)
#define BFLOAT16_MATH_FP32_WRAPPERS(op)
#define BFLOAT16_MATH_FP32_WRAPPERS_VEC(op)
Provides a cross-platform math array class template that works on SYCL devices as well as in host C++...
Definition: marray.hpp:49
__ESIMD_API simd< T, N > rsqrt(simd< T, N > src, Sat sat={})
Square root reciprocal - calculates 1/sqrt(x).
Definition: math.hpp:432
__ESIMD_API simd< T, N > log2(simd< T, N > src, Sat sat={})
Logarithm base 2.
Definition: math.hpp:398
__ESIMD_API simd< T, N > exp2(simd< T, N > src, Sat sat={})
Exponent base 2.
Definition: math.hpp:402
bfloat16 bitsToBfloat16(const Bfloat16StorageT Value)
Definition: bfloat16.hpp:304
Bfloat16StorageT bfloat16ToBits(const bfloat16 &Value)
Definition: bfloat16.hpp:298
uint32_t to_uint32_t(sycl::marray< bfloat16, N > x, size_t start)
__DPCPP_SYCL_EXTERNAL _SYCL_EXT_CPLX_INLINE_VISIBILITY std::enable_if_t< is_genfloat< _Tp >::value, complex< _Tp > > sin(const complex< _Tp > &__x)
__DPCPP_SYCL_EXTERNAL _SYCL_EXT_CPLX_INLINE_VISIBILITY std::enable_if_t< is_genfloat< _Tp >::value, complex< _Tp > > cos(const complex< _Tp > &__x)
__DPCPP_SYCL_EXTERNAL _SYCL_EXT_CPLX_INLINE_VISIBILITY std::enable_if_t< is_genfloat< _Tp >::value, complex< _Tp > > sqrt(const complex< _Tp > &__x)
std::enable_if_t< std::is_same_v< T, bfloat16 >, bool > isnan(T x)
std::enable_if_t< std::is_same_v< T, bfloat16 >, T > fabs(T x)
__DPCPP_SYCL_EXTERNAL _SYCL_EXT_CPLX_INLINE_VISIBILITY std::enable_if_t< is_genfloat< _Tp >::value, complex< _Tp > > exp(const complex< _Tp > &__x)
__DPCPP_SYCL_EXTERNAL _SYCL_EXT_CPLX_INLINE_VISIBILITY std::enable_if_t< is_genfloat< _Tp >::value, complex< _Tp > > log(const complex< _Tp > &__x)
std::enable_if_t< std::is_same_v< T, bfloat16 >, T > fmin(T x, T y)
std::enable_if_t< std::is_same_v< T, bfloat16 >, T > fma(T x, T y, T z)
__DPCPP_SYCL_EXTERNAL _SYCL_EXT_CPLX_INLINE_VISIBILITY std::enable_if_t< is_genfloat< _Tp >::value, complex< _Tp > > log10(const complex< _Tp > &__x)
std::enable_if_t< std::is_same_v< T, bfloat16 >, T > fmax(T x, T y)
float ceil(float)
auto auto autodecltype(x) z
float floor(float)
float rint(float)
autodecltype(x) x
float trunc(float)
Definition: access.hpp:18