DPC++ Runtime
Runtime libraries for oneAPI DPC++
bfloat16_math.hpp
Go to the documentation of this file.
//==-------- bfloat16_math.hpp - SYCL bfloat16 math functions --------------==//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #pragma once
10 
#include <sycl/exception.hpp>
#include <sycl/ext/oneapi/bfloat16.hpp>
#include <sycl/marray.hpp>

#include <cstring>
#include <tuple>
#include <type_traits>
19 
20 namespace sycl {
22 namespace ext::oneapi::experimental {
23 
24 namespace detail {
25 template <size_t N>
26 uint32_t to_uint32_t(sycl::marray<bfloat16, N> x, size_t start) {
27  uint32_t res;
28  std::memcpy(&res, &x[start], sizeof(uint32_t));
29  return res;
30 }
31 } // namespace detail
32 
33 // According to bfloat16 format, NAN value's exponent field is 0xFF and
34 // significand has non-zero bits.
35 template <typename T>
36 std::enable_if_t<std::is_same_v<T, bfloat16>, bool> isnan(T x) {
38  return (((XBits & 0x7F80) == 0x7F80) && (XBits & 0x7F)) ? true : false;
39 }
40 
41 template <size_t N> sycl::marray<bool, N> isnan(sycl::marray<bfloat16, N> x) {
42  sycl::marray<bool, N> res;
43  for (size_t i = 0; i < N; i++) {
44  res[i] = isnan(x[i]);
45  }
46  return res;
47 }
48 
49 template <typename T>
50 std::enable_if_t<std::is_same_v<T, bfloat16>, T> fabs(T x) {
51 #if defined(__SYCL_DEVICE_ONLY__) && defined(__NVPTX__)
53  return oneapi::detail::bitsToBfloat16(__clc_fabs(XBits));
54 #else
55  if (!isnan(x)) {
56  const static oneapi::detail::Bfloat16StorageT SignMask = 0x8000;
58  x = ((XBits & SignMask) == SignMask)
59  ? oneapi::detail::bitsToBfloat16(XBits & ~SignMask)
60  : x;
61  }
62  return x;
63 #endif // defined(__SYCL_DEVICE_ONLY__) && defined(__NVPTX__)
64 }
65 
66 template <size_t N>
67 sycl::marray<bfloat16, N> fabs(sycl::marray<bfloat16, N> x) {
68  sycl::marray<bfloat16, N> res;
69 #if defined(__SYCL_DEVICE_ONLY__) && defined(__NVPTX__)
70  for (size_t i = 0; i < N / 2; i++) {
71  auto partial_res = __clc_fabs(detail::to_uint32_t(x, i * 2));
72  std::memcpy(&res[i * 2], &partial_res, sizeof(uint32_t));
73  }
74 
75  if (N % 2) {
78  res[N - 1] = oneapi::detail::bitsToBfloat16(__clc_fabs(XBits));
79  }
80 #else
81  for (size_t i = 0; i < N; i++) {
82  res[i] = fabs(x[i]);
83  }
84 #endif // defined(__SYCL_DEVICE_ONLY__) && defined(__NVPTX__)
85  return res;
86 }
87 
88 template <typename T>
89 std::enable_if_t<std::is_same_v<T, bfloat16>, T> fmin(T x, T y) {
90 #if defined(__SYCL_DEVICE_ONLY__) && defined(__NVPTX__)
93  return oneapi::detail::bitsToBfloat16(__clc_fmin(XBits, YBits));
94 #else
95  static const oneapi::detail::Bfloat16StorageT CanonicalNan = 0x7FC0;
98  if (isnan(x) && isnan(y))
99  return oneapi::detail::bitsToBfloat16(CanonicalNan);
100 
101  if (isnan(x))
102  return y;
103  if (isnan(y))
104  return x;
105  if (((XBits | YBits) ==
106  static_cast<oneapi::detail::Bfloat16StorageT>(0x8000)) &&
107  !(XBits & YBits))
109  static_cast<oneapi::detail::Bfloat16StorageT>(0x8000));
110 
111  return (x < y) ? x : y;
112 #endif // defined(__SYCL_DEVICE_ONLY__) && defined(__NVPTX__)
113 }
114 
115 template <size_t N>
116 sycl::marray<bfloat16, N> fmin(sycl::marray<bfloat16, N> x,
117  sycl::marray<bfloat16, N> y) {
118  sycl::marray<bfloat16, N> res;
119 #if defined(__SYCL_DEVICE_ONLY__) && defined(__NVPTX__)
120  for (size_t i = 0; i < N / 2; i++) {
121  auto partial_res = __clc_fmin(detail::to_uint32_t(x, i * 2),
122  detail::to_uint32_t(y, i * 2));
123  std::memcpy(&res[i * 2], &partial_res, sizeof(uint32_t));
124  }
125 
126  if (N % 2) {
131  res[N - 1] = oneapi::detail::bitsToBfloat16(__clc_fmin(XBits, YBits));
132  }
133 #else
134  for (size_t i = 0; i < N; i++) {
135  res[i] = fmin(x[i], y[i]);
136  }
137 #endif // defined(__SYCL_DEVICE_ONLY__) && defined(__NVPTX__)
138  return res;
139 }
140 
141 template <typename T>
142 std::enable_if_t<std::is_same_v<T, bfloat16>, T> fmax(T x, T y) {
143 #if defined(__SYCL_DEVICE_ONLY__) && defined(__NVPTX__)
146  return oneapi::detail::bitsToBfloat16(__clc_fmax(XBits, YBits));
147 #else
148  static const oneapi::detail::Bfloat16StorageT CanonicalNan = 0x7FC0;
151  if (isnan(x) && isnan(y))
152  return oneapi::detail::bitsToBfloat16(CanonicalNan);
153 
154  if (isnan(x))
155  return y;
156  if (isnan(y))
157  return x;
158  if (((XBits | YBits) ==
159  static_cast<oneapi::detail::Bfloat16StorageT>(0x8000)) &&
160  !(XBits & YBits))
162 
163  return (x > y) ? x : y;
164 #endif // defined(__SYCL_DEVICE_ONLY__) && defined(__NVPTX__)
165 }
166 
167 template <size_t N>
168 sycl::marray<bfloat16, N> fmax(sycl::marray<bfloat16, N> x,
169  sycl::marray<bfloat16, N> y) {
170  sycl::marray<bfloat16, N> res;
171 #if defined(__SYCL_DEVICE_ONLY__) && defined(__NVPTX__)
172  for (size_t i = 0; i < N / 2; i++) {
173  auto partial_res = __clc_fmax(detail::to_uint32_t(x, i * 2),
174  detail::to_uint32_t(y, i * 2));
175  std::memcpy(&res[i * 2], &partial_res, sizeof(uint32_t));
176  }
177 
178  if (N % 2) {
183  res[N - 1] = oneapi::detail::bitsToBfloat16(__clc_fmax(XBits, YBits));
184  }
185 #else
186  for (size_t i = 0; i < N; i++) {
187  res[i] = fmax(x[i], y[i]);
188  }
189 #endif // defined(__SYCL_DEVICE_ONLY__) && defined(__NVPTX__)
190  return res;
191 }
192 
193 template <typename T>
194 std::enable_if_t<std::is_same_v<T, bfloat16>, T> fma(T x, T y, T z) {
195 #if defined(__SYCL_DEVICE_ONLY__) && defined(__NVPTX__)
199  return oneapi::detail::bitsToBfloat16(__clc_fma(XBits, YBits, ZBits));
200 #else
201  return sycl::ext::oneapi::bfloat16{sycl::fma(float{x}, float{y}, float{z})};
202 #endif // defined(__SYCL_DEVICE_ONLY__) && defined(__NVPTX__)
203 }
204 
205 template <size_t N>
206 sycl::marray<bfloat16, N> fma(sycl::marray<bfloat16, N> x,
207  sycl::marray<bfloat16, N> y,
208  sycl::marray<bfloat16, N> z) {
209  sycl::marray<bfloat16, N> res;
210 #if defined(__SYCL_DEVICE_ONLY__) && defined(__NVPTX__)
211  for (size_t i = 0; i < N / 2; i++) {
212  auto partial_res =
213  __clc_fma(detail::to_uint32_t(x, i * 2), detail::to_uint32_t(y, i * 2),
214  detail::to_uint32_t(z, i * 2));
215  std::memcpy(&res[i * 2], &partial_res, sizeof(uint32_t));
216  }
217 
218  if (N % 2) {
225  res[N - 1] = oneapi::detail::bitsToBfloat16(__clc_fma(XBits, YBits, ZBits));
226  }
227 #else
228  for (size_t i = 0; i < N; i++) {
229  res[i] = fma(x[i], y[i], z[i]);
230  }
231 #endif // defined(__SYCL_DEVICE_ONLY__) && defined(__NVPTX__)
232  return res;
233 }
234 
// Defines a scalar bfloat16 overload of `op`: the argument is converted to
// float, passed to sycl::op, and the float result is converted back to
// bfloat16. Only participates in overload resolution when T is bfloat16.
#define BFLOAT16_MATH_FP32_WRAPPERS(op)                                        \
  template <typename T>                                                        \
  std::enable_if_t<std::is_same_v<T, bfloat16>, T> op(T x) {                   \
    return sycl::ext::oneapi::bfloat16{sycl::op(float{x})};                    \
  }
240 
// Defines an marray overload of `op` that applies the scalar bfloat16 `op` to
// every element of the input and collects the results in a new marray.
#define BFLOAT16_MATH_FP32_WRAPPERS_MARRAY(op)                                 \
  template <size_t N>                                                          \
  sycl::marray<bfloat16, N> op(sycl::marray<bfloat16, N> x) {                  \
    sycl::marray<bfloat16, N> out;                                             \
    for (size_t idx = 0; idx != N; ++idx)                                      \
      out[idx] = op(x[idx]);                                                   \
    return out;                                                                \
  }
250 
279 
280 #undef BFLOAT16_MATH_FP32_WRAPPERS
281 #undef BFLOAT16_MATH_FP32_WRAPPERS_MARRAY
282 } // namespace ext::oneapi::experimental
283 } // __SYCL_INLINE_VER_NAMESPACE(_V1)
284 } // namespace sycl
sycl::_V1::ext::intel::esimd::exp2
__ESIMD_API simd< T, N > exp2(simd< T, N > src, Sat sat={})
Exponent base 2.
Definition: math.hpp:385
__SYCL_INLINE_VER_NAMESPACE
#define __SYCL_INLINE_VER_NAMESPACE(X)
Definition: defines_elementary.hpp:11
sycl::_V1::ext::oneapi::fma
std::enable_if_t< detail::is_bf16_storage_type< T >::value, T > fma(T x, T y, T z)
Definition: bf16_storage_builtins.hpp:71
sycl::_V1::detail::memcpy
void memcpy(void *Dst, const void *Src, size_t Size)
Definition: memcpy.hpp:16
sycl::_V1::cos
ESIMD_NODEBUG ESIMD_INLINE sycl::ext::intel::esimd::simd< float, SZ > cos(sycl::ext::intel::esimd::simd< float, SZ > x) __NOEXC
Definition: builtins_esimd.hpp:27
sycl::_V1::exp
ESIMD_NODEBUG ESIMD_INLINE sycl::ext::intel::esimd::simd< float, SZ > exp(sycl::ext::intel::esimd::simd< float, SZ > x) __NOEXC
Definition: builtins_esimd.hpp:49
sycl
Error handling, matching OpenCL plugin semantics.
Definition: access.hpp:14
BFLOAT16_MATH_FP32_WRAPPERS_MARRAY
#define BFLOAT16_MATH_FP32_WRAPPERS_MARRAY(op)
Definition: bfloat16_math.hpp:241
sycl::_V1::ext::oneapi::experimental::detail::to_uint32_t
uint32_t to_uint32_t(sycl::marray< bfloat16, N > x, size_t start)
Definition: bfloat16_math.hpp:26
sycl::_V1::ext::oneapi::fmax
std::enable_if_t< detail::is_bf16_storage_type< T >::value, T > fmax(T x, T y)
Definition: bf16_storage_builtins.hpp:60
sycl::_V1::ext::oneapi::experimental::fabs
sycl::marray< bfloat16, N > fabs(sycl::marray< bfloat16, N > x)
Definition: bfloat16_math.hpp:67
sycl::_V1::ext::intel::esimd::sqrt
__ESIMD_API simd< T, N > sqrt(simd< T, N > src, Sat sat={})
Square root.
Definition: math.hpp:389
sycl::_V1::sin
ESIMD_NODEBUG ESIMD_INLINE sycl::ext::intel::esimd::simd< float, SZ > sin(sycl::ext::intel::esimd::simd< float, SZ > x) __NOEXC
Definition: builtins_esimd.hpp:38
sycl::_V1::ext::oneapi::detail::Bfloat16StorageT
uint16_t Bfloat16StorageT
Definition: bfloat16.hpp:27
BFLOAT16_MATH_FP32_WRAPPERS
#define BFLOAT16_MATH_FP32_WRAPPERS(op)
Definition: bfloat16_math.hpp:235
defines_elementary.hpp
sycl::_V1::ext::oneapi::fmin
std::enable_if_t< detail::is_bf16_storage_type< T >::value, T > fmin(T x, T y)
Definition: bf16_storage_builtins.hpp:49
sycl::_V1::ext::oneapi::detail::bitsToBfloat16
bfloat16 bitsToBfloat16(const Bfloat16StorageT Value)
Definition: bfloat16.hpp:211
bfloat16.hpp
sycl::_V1::ext::intel::esimd::rsqrt
__ESIMD_API simd< T, N > rsqrt(simd< T, N > src, Sat sat={})
Square root reciprocal - calculates 1/sqrt(x).
Definition: math.hpp:397
sycl::_V1::ext::intel::esimd::ceil
ESIMD_INLINE sycl::ext::intel::esimd::simd< RT, SZ > ceil(const sycl::ext::intel::esimd::simd< float, SZ > src0, Sat sat={})
"Ceiling" operation, vector version - alias of rndu.
Definition: math.hpp:594
sycl::_V1::ext::intel::esimd::log2
__ESIMD_API simd< T, N > log2(simd< T, N > src, Sat sat={})
Logarithm base 2.
Definition: math.hpp:381
exception.hpp
marray.hpp
sycl::_V1::log
ESIMD_NODEBUG ESIMD_INLINE sycl::ext::intel::esimd::simd< float, SZ > log(sycl::ext::intel::esimd::simd< float, SZ > x) __NOEXC
Definition: builtins_esimd.hpp:60
sycl::_V1::ext::oneapi::experimental::isnan
sycl::marray< bool, N > isnan(sycl::marray< bfloat16, N > x)
Definition: bfloat16_math.hpp:41
sycl::_V1::ext::intel::math::rint
std::enable_if_t< std::is_same_v< Tp, float >, float > rint(Tp x)
Definition: math.hpp:149
sycl::_V1::ext::oneapi::detail::bfloat16ToBits
Bfloat16StorageT bfloat16ToBits(const bfloat16 &Value)
Definition: bfloat16.hpp:205
sycl::_V1::ext::intel::esimd::floor
ESIMD_INLINE sycl::ext::intel::esimd::simd< RT, SZ > floor(const sycl::ext::intel::esimd::simd< float, SZ > src0, Sat sat={})
"Floor" operation, vector version - alias of rndd.
Definition: math.hpp:581
sycl::_V1::ext::intel::esimd::trunc
__ESIMD_API sycl::ext::intel::esimd::simd< RT, SZ > trunc(const sycl::ext::intel::esimd::simd< float, SZ > &src0, Sat sat={})
Round to integral value using the round to zero rounding mode (vector version).
Definition: math.hpp:614