DPC++ Runtime
Runtime libraries for oneAPI DPC++
bfloat16_math.hpp
Go to the documentation of this file.
1 //==-------- bfloat16_math.hpp - SYCL bfloat16 math functions --------------==//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #pragma once
10 
12 #include <sycl/exception.hpp>
13 #include <sycl/ext/oneapi/bfloat16.hpp>
14 #include <sycl/marray.hpp>
15 
16 #include <cstring>
17 #include <tuple>
18 #include <type_traits>
19 
20 namespace sycl {
22 namespace ext::oneapi::experimental {
23 
24 namespace detail {
25 template <size_t N>
26 uint32_t to_uint32_t(sycl::marray<bfloat16, N> x, size_t start) {
27  uint32_t res;
28  std::memcpy(&res, &x[start], sizeof(uint32_t));
29  return res;
30 }
31 } // namespace detail
32 
33 // According to bfloat16 format, NAN value's exponent field is 0xFF and
34 // significand has non-zero bits.
35 template <typename T>
36 std::enable_if_t<std::is_same<T, bfloat16>::value, bool> isnan(T x) {
38  return (((XBits & 0x7F80) == 0x7F80) && (XBits & 0x7F)) ? true : false;
39 }
40 
41 template <size_t N> sycl::marray<bool, N> isnan(sycl::marray<bfloat16, N> x) {
42  sycl::marray<bool, N> res;
43  for (size_t i = 0; i < N; i++) {
44  res[i] = isnan(x[i]);
45  }
46  return res;
47 }
48 
49 template <typename T>
50 std::enable_if_t<std::is_same<T, bfloat16>::value, T> fabs(T x) {
51 #if defined(__SYCL_DEVICE_ONLY__) && defined(__NVPTX__)
53  return oneapi::detail::bitsToBfloat16(__clc_fabs(XBits));
54 #else
55  std::ignore = x;
56  throw runtime_error(
57  "bfloat16 math functions are not currently supported on the host device.",
58  PI_ERROR_INVALID_DEVICE);
59 #endif // defined(__SYCL_DEVICE_ONLY__) && defined(__NVPTX__)
60 }
61 
62 template <size_t N>
63 sycl::marray<bfloat16, N> fabs(sycl::marray<bfloat16, N> x) {
64 #if defined(__SYCL_DEVICE_ONLY__) && defined(__NVPTX__)
65  sycl::marray<bfloat16, N> res;
66 
67  for (size_t i = 0; i < N / 2; i++) {
68  auto partial_res = __clc_fabs(detail::to_uint32_t(x, i * 2));
69  std::memcpy(&res[i * 2], &partial_res, sizeof(uint32_t));
70  }
71 
72  if (N % 2) {
75  res[N - 1] = oneapi::detail::bitsToBfloat16(__clc_fabs(XBits));
76  }
77  return res;
78 #else
79  std::ignore = x;
80  throw runtime_error(
81  "bfloat16 math functions are not currently supported on the host device.",
82  PI_ERROR_INVALID_DEVICE);
83 #endif // defined(__SYCL_DEVICE_ONLY__) && defined(__NVPTX__)
84 }
85 
86 template <typename T>
87 std::enable_if_t<std::is_same<T, bfloat16>::value, T> fmin(T x, T y) {
88 #if defined(__SYCL_DEVICE_ONLY__) && defined(__NVPTX__)
91  return oneapi::detail::bitsToBfloat16(__clc_fmin(XBits, YBits));
92 #else
93  static const oneapi::detail::Bfloat16StorageT CanonicalNan = 0x7FC0;
96  if (isnan(x) && isnan(y))
97  return oneapi::detail::bitsToBfloat16(CanonicalNan);
98 
99  if (isnan(x))
100  return y;
101  if (isnan(y))
102  return x;
103  if (((XBits | YBits) ==
104  static_cast<oneapi::detail::Bfloat16StorageT>(0x8000)) &&
105  !(XBits & YBits))
107  static_cast<oneapi::detail::Bfloat16StorageT>(0x8000));
108 
109  return (x < y) ? x : y;
110 #endif // defined(__SYCL_DEVICE_ONLY__) && defined(__NVPTX__)
111 }
112 
113 template <size_t N>
114 sycl::marray<bfloat16, N> fmin(sycl::marray<bfloat16, N> x,
115  sycl::marray<bfloat16, N> y) {
116  sycl::marray<bfloat16, N> res;
117 #if defined(__SYCL_DEVICE_ONLY__) && defined(__NVPTX__)
118  for (size_t i = 0; i < N / 2; i++) {
119  auto partial_res = __clc_fmin(detail::to_uint32_t(x, i * 2),
120  detail::to_uint32_t(y, i * 2));
121  std::memcpy(&res[i * 2], &partial_res, sizeof(uint32_t));
122  }
123 
124  if (N % 2) {
129  res[N - 1] = oneapi::detail::bitsToBfloat16(__clc_fmin(XBits, YBits));
130  }
131 #else
132  for (size_t i = 0; i < N; i++) {
133  res[i] = fmin(x[i], y[i]);
134  }
135 #endif // defined(__SYCL_DEVICE_ONLY__) && defined(__NVPTX__)
136  return res;
137 }
138 
139 template <typename T>
140 std::enable_if_t<std::is_same<T, bfloat16>::value, T> fmax(T x, T y) {
141 #if defined(__SYCL_DEVICE_ONLY__) && defined(__NVPTX__)
144  return oneapi::detail::bitsToBfloat16(__clc_fmax(XBits, YBits));
145 #else
146  static const oneapi::detail::Bfloat16StorageT CanonicalNan = 0x7FC0;
149  if (isnan(x) && isnan(y))
150  return oneapi::detail::bitsToBfloat16(CanonicalNan);
151 
152  if (isnan(x))
153  return y;
154  if (isnan(y))
155  return x;
156  if (((XBits | YBits) ==
157  static_cast<oneapi::detail::Bfloat16StorageT>(0x8000)) &&
158  !(XBits & YBits))
160 
161  return (x > y) ? x : y;
162 #endif // defined(__SYCL_DEVICE_ONLY__) && defined(__NVPTX__)
163 }
164 
165 template <size_t N>
166 sycl::marray<bfloat16, N> fmax(sycl::marray<bfloat16, N> x,
167  sycl::marray<bfloat16, N> y) {
168  sycl::marray<bfloat16, N> res;
169 #if defined(__SYCL_DEVICE_ONLY__) && defined(__NVPTX__)
170  for (size_t i = 0; i < N / 2; i++) {
171  auto partial_res = __clc_fmax(detail::to_uint32_t(x, i * 2),
172  detail::to_uint32_t(y, i * 2));
173  std::memcpy(&res[i * 2], &partial_res, sizeof(uint32_t));
174  }
175 
176  if (N % 2) {
181  res[N - 1] = oneapi::detail::bitsToBfloat16(__clc_fmax(XBits, YBits));
182  }
183 #else
184  for (size_t i = 0; i < N; i++) {
185  res[i] = fmax(x[i], y[i]);
186  }
187 #endif // defined(__SYCL_DEVICE_ONLY__) && defined(__NVPTX__)
188  return res;
189 }
190 
191 template <typename T>
192 std::enable_if_t<std::is_same<T, bfloat16>::value, T> fma(T x, T y, T z) {
193 #if defined(__SYCL_DEVICE_ONLY__) && defined(__NVPTX__)
197  return oneapi::detail::bitsToBfloat16(__clc_fma(XBits, YBits, ZBits));
198 #else
199  return sycl::ext::oneapi::bfloat16{sycl::fma(float{x}, float{y}, float{z})};
200 #endif // defined(__SYCL_DEVICE_ONLY__) && defined(__NVPTX__)
201 }
202 
203 template <size_t N>
204 sycl::marray<bfloat16, N> fma(sycl::marray<bfloat16, N> x,
205  sycl::marray<bfloat16, N> y,
206  sycl::marray<bfloat16, N> z) {
207  sycl::marray<bfloat16, N> res;
208 #if defined(__SYCL_DEVICE_ONLY__) && defined(__NVPTX__)
209  for (size_t i = 0; i < N / 2; i++) {
210  auto partial_res =
211  __clc_fma(detail::to_uint32_t(x, i * 2), detail::to_uint32_t(y, i * 2),
212  detail::to_uint32_t(z, i * 2));
213  std::memcpy(&res[i * 2], &partial_res, sizeof(uint32_t));
214  }
215 
216  if (N % 2) {
223  res[N - 1] = oneapi::detail::bitsToBfloat16(__clc_fma(XBits, YBits, ZBits));
224  }
225 #else
226  for (size_t i = 0; i < N; i++) {
227  res[i] = fma(x[i], y[i], z[i]);
228  }
229 #endif // defined(__SYCL_DEVICE_ONLY__) && defined(__NVPTX__)
230  return res;
231 }
232 
233 } // namespace ext::oneapi::experimental
234 } // __SYCL_INLINE_VER_NAMESPACE(_V1)
235 } // namespace sycl
#define __SYCL_INLINE_VER_NAMESPACE(X)
void memcpy(void *Dst, const void *Src, std::size_t Size)
bfloat16 bitsToBfloat16(const Bfloat16StorageT Value)
Definition: bfloat16.hpp:220
Bfloat16StorageT bfloat16ToBits(const bfloat16 &Value)
Definition: bfloat16.hpp:214
uint32_t to_uint32_t(sycl::marray< bfloat16, N > x, size_t start)
sycl::marray< bfloat16, N > fabs(sycl::marray< bfloat16, N > x)
sycl::marray< bool, N > isnan(sycl::marray< bfloat16, N > x)
std::enable_if_t< detail::is_bf16_storage_type< T >::value, T > fma(T x, T y, T z)
std::enable_if_t< detail::is_bf16_storage_type< T >::value, T > fmax(T x, T y)
std::enable_if_t< detail::is_bf16_storage_type< T >::value, T > fmin(T x, T y)
---— Error handling, matching OpenCL plugin semantics.
Definition: access.hpp:14