DPC++ Runtime
Runtime libraries for oneAPI DPC++
bfloat16.hpp
Go to the documentation of this file.
1 //==--------- bfloat16.hpp ------- SYCL bfloat16 conversion ----------------==//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #pragma once
10 
11 #include <CL/__spirv/spirv_ops.hpp>
12 #include <sycl/half_type.hpp>
13 
14 #if !defined(__SYCL_DEVICE_ONLY__)
15 #include <cmath>
16 #endif
17 
// Device-library conversion entry points, declared with C linkage and
// noexcept. Both take their argument by const reference. Judging by the names
// and signatures, the first converts a float to a 16-bit bfloat16 bit pattern
// and the second converts such a pattern back to float.
// NOTE(review): the definitions are not part of this file -- confirm there.
18 extern "C" __DPCPP_SYCL_EXTERNAL uint16_t
19 __devicelib_ConvertFToBF16INTEL(const float &) noexcept;
20 extern "C" __DPCPP_SYCL_EXTERNAL float
21 __devicelib_ConvertBF16ToFINTEL(const uint16_t &) noexcept;
22 
23 namespace sycl {
25 namespace ext::oneapi {
26 
27 class bfloat16;
28 
29 namespace detail {
30 using Bfloat16StorageT = uint16_t;
33 } // namespace detail
34 
35 class bfloat16 {
37 
38  friend inline detail::Bfloat16StorageT
39  detail::bfloat16ToBits(const bfloat16 &Value);
40  friend inline bfloat16
42 
43 public:
// Default construction, copy construction and destruction are all left to
// the compiler-generated implementations.
44  bfloat16() = default;
45  bfloat16(const bfloat16 &) = default;
46  ~bfloat16() = default;
47 
48 private:
49  // Explicit conversion functions
// Converts a float to the 16-bit storage representation of bfloat16.
// The implementation is selected at preprocessing time:
// - device build, NVPTX, __SYCL_CUDA_ARCH__ >= 800: delegates to
//   __nvvm_f2bf16_rn;
// - device build, NVPTX, older architectures, and host build: a NaN input
//   yields 0xffc1; any other input is rounded to nearest-even at bit 16 of its
//   32-bit pattern and the upper 16 bits are returned;
// - device build, non-NVPTX: NOTE(review) the body of this branch (source
//   line 69) is absent from this listing.
50  static detail::Bfloat16StorageT from_float(const float &a) {
51 #if defined(__SYCL_DEVICE_ONLY__)
52 #if defined(__NVPTX__)
53 #if (__SYCL_CUDA_ARCH__ >= 800)
54  return __nvvm_f2bf16_rn(a);
55 #else
56  // TODO find a better way to check for NaN
// Self-comparison is used as the NaN test here instead of std::isnan.
57  if (a != a)
58  return 0xffc1;
// The anonymous union gives access to the float's 32-bit pattern.
59  union {
60  uint32_t intStorage;
61  float floatValue;
62  };
63  floatValue = a;
64  // Do RNE and truncate
// Bias is 0x7FFF plus the lowest bit that will be kept (bit 16), so an exact
// tie carries into the result only when that bit is 1.
65  uint32_t roundingBias = ((intStorage >> 16) & 0x1) + 0x00007FFF;
66  return static_cast<uint16_t>((intStorage + roundingBias) >> 16);
67 #endif
68 #else
70 #endif
71 #else
72  // In case float value is nan - propagate bfloat16's qnan
73  if (std::isnan(a))
74  return 0xffc1;
// The anonymous union gives access to the float's 32-bit pattern.
75  union {
76  uint32_t intStorage;
77  float floatValue;
78  };
79  floatValue = a;
80  // Do RNE and truncate
// Same bias computation as the NVPTX fallback above.
81  uint32_t roundingBias = ((intStorage >> 16) & 0x1) + 0x00007FFF;
82  return static_cast<uint16_t>((intStorage + roundingBias) >> 16);
83 #endif
84  }
85 
// Converts a 16-bit bfloat16 storage value to float.
// - Device build with __SPIR__ defined: NOTE(review) the body of this branch
//   (source line 88) is absent from this listing.
// - Otherwise: places the 16 bits in the upper half of a 32-bit pattern (lower
//   half zero) and reads that pattern back as a float through an anonymous
//   union.
86  static float to_float(const detail::Bfloat16StorageT &a) {
87 #if defined(__SYCL_DEVICE_ONLY__) && defined(__SPIR__)
89 #else
90  union {
91  uint32_t intStorage;
92  float floatValue;
93  };
// Widen before shifting: a bare `a << 16` promotes the 16-bit value to signed
// int and shifts into the sign bit whenever the top bit of a is set.
94  intStorage = static_cast<uint32_t>(a) << 16;
95  return floatValue;
96 #endif
97  }
98 
99 public:
100  // Implicit conversion from float to bfloat16
// Stores the bits produced by from_float(a).
101  bfloat16(const float &a) { value = from_float(a); }
102 
// Assignment from float: re-encodes rhs with from_float and returns *this.
103  bfloat16 &operator=(const float &rhs) {
104  value = from_float(rhs);
105  return *this;
106  }
107 
108  // Implicit conversion from sycl::half to bfloat16
// The half argument is handed to from_float, whose parameter is const float &,
// so the value is converted to float on the way in.
109  bfloat16(const sycl::half &a) { value = from_float(a); }
110 
112  value = from_float(rhs);
113  return *this;
114  }
115 
116  // Implicit conversion from bfloat16 to float
// Decodes the stored bits with to_float.
117  operator float() const { return to_float(value); }
118 
119  // Implicit conversion from bfloat16 to sycl::half
// Decodes to float with to_float; that float is then converted to the
// sycl::half return type.
120  operator sycl::half() const { return to_float(value); }
121 
122  // Logical operators (!,||,&&) are covered if we can cast to bool
// Yields true when the value decoded by to_float differs from 0.0f.
// Const-qualified because the conversion only reads `value`; this matches the
// other conversion operators and makes it usable on const bfloat16 objects.
123  explicit operator bool() const { return to_float(value) != 0.0f; }
124 
125  // Unary minus operator overloading
// Returns the negation of lhs as a new bfloat16. The operand is never
// modified, so it is taken by const reference; this also lets const objects
// and temporaries bind to this overload.
// - Device build, NVPTX, __SYCL_CUDA_ARCH__ >= 800: negates the stored bits
//   with __nvvm_neg_bf16 and rewraps them via detail::bitsToBfloat16.
// - Device build, older NVPTX architectures, and host build: negates the
//   to_float() value; the float result is converted back through the float
//   constructor.
// - Device build, non-NVPTX: negates the result of
//   __devicelib_ConvertBF16ToFINTEL and constructs a bfloat16 from it.
126  friend bfloat16 operator-(const bfloat16 &lhs) {
127 #if defined(__SYCL_DEVICE_ONLY__)
128 #if defined(__NVPTX__)
129 #if (__SYCL_CUDA_ARCH__ >= 800)
130  return detail::bitsToBfloat16(__nvvm_neg_bf16(lhs.value));
131 #else
132  return -to_float(lhs.value);
133 #endif
134 #else
135  return bfloat16{-__devicelib_ConvertBF16ToFINTEL(lhs.value)};
136 #endif
137 #else
138  return -to_float(lhs.value);
139 #endif
140  }
141 
142 // Increment and decrement operators overloading
// OP(op) generates two friend functions for the given operator token:
// - the prefix form decodes lhs to float, applies `op` to that float, stores
//   the re-encoded result in lhs.value and returns lhs by reference;
// - the postfix form (extra int parameter) copies lhs, invokes the prefix
//   form on lhs and returns the copy taken beforehand.
// The macro is instantiated for ++ and -- and then undefined.
143 #define OP(op) \
144  friend bfloat16 &operator op(bfloat16 &lhs) { \
145  float f = to_float(lhs.value); \
146  lhs.value = from_float(op f); \
147  return lhs; \
148  } \
149  friend bfloat16 operator op(bfloat16 &lhs, int) { \
150  bfloat16 old = lhs; \
151  operator op(lhs); \
152  return old; \
153  }
154  OP(++)
155  OP(--)
156 #undef OP
157 
158  // Assignment operators overloading
// OP(op) generates three friend overloads of the compound-assignment operator:
// - bfloat16 op bfloat16,
// - bfloat16 op T (any T that static_cast<float> accepts),
// - T op bfloat16.
// Each one casts lhs to float, applies `op` with rhs cast to float, assigns
// the float result back to lhs and returns the result of that assignment.
// The macro is instantiated for +=, -=, *= and /= and then undefined.
159 #define OP(op) \
160  friend bfloat16 &operator op(bfloat16 &lhs, const bfloat16 &rhs) { \
161  float f = static_cast<float>(lhs); \
162  f op static_cast<float>(rhs); \
163  return lhs = f; \
164  } \
165  template <typename T> \
166  friend bfloat16 &operator op(bfloat16 &lhs, const T &rhs) { \
167  float f = static_cast<float>(lhs); \
168  f op static_cast<float>(rhs); \
169  return lhs = f; \
170  } \
171  template <typename T> friend T &operator op(T &lhs, const bfloat16 &rhs) { \
172  float f = static_cast<float>(lhs); \
173  f op static_cast<float>(rhs); \
174  return lhs = f; \
175  }
176  OP(+=)
177  OP(-=)
178  OP(*=)
179  OP(/=)
180 #undef OP
181 
182 // Binary operators overloading
// OP(type, op) generates three friend overloads of the binary operator `op`
// returning `type`: bfloat16 op bfloat16, bfloat16 op T, and T op bfloat16.
// Each one casts both operands to float, applies `op` in float and
// brace-initializes the return type from the result.
// Arithmetic operators (+, -, *, /) are instantiated with type = bfloat16;
// comparisons (==, !=, <, >, <=, >=) with type = bool.
183 #define OP(type, op) \
184  friend type operator op(const bfloat16 &lhs, const bfloat16 &rhs) { \
185  return type{static_cast<float>(lhs) op static_cast<float>(rhs)}; \
186  } \
187  template <typename T> \
188  friend type operator op(const bfloat16 &lhs, const T &rhs) { \
189  return type{static_cast<float>(lhs) op static_cast<float>(rhs)}; \
190  } \
191  template <typename T> \
192  friend type operator op(const T &lhs, const bfloat16 &rhs) { \
193  return type{static_cast<float>(lhs) op static_cast<float>(rhs)}; \
194  }
195  OP(bfloat16, +)
196  OP(bfloat16, -)
197  OP(bfloat16, *)
198  OP(bfloat16, /)
199  OP(bool, ==)
200  OP(bool, !=)
201  OP(bool, <)
202  OP(bool, >)
203  OP(bool, <=)
204  OP(bool, >=)
205 #undef OP
206 
207  // Bitwise(|,&,~,^), modulo(%) and shift(<<,>>) operations are not supported
208  // for floating-point types.
209 };
210 
211 namespace detail {
212 
213 // Helper function for getting the internal representation of a bfloat16.
215  return Value.value;
216 }
217 
218 // Helper function for creating a bfloat16 from a value with the same type as the
219 // internal representation.
221  bfloat16 res;
222  res.value = Value;
223  return res;
224 }
225 
226 } // namespace detail
227 
228 } // namespace ext::oneapi
229 
230 } // __SYCL_INLINE_VER_NAMESPACE(_V1)
231 } // namespace sycl
spirv_ops.hpp
sycl::_V1::ext::oneapi::bfloat16
Definition: bfloat16.hpp:35
sycl::_V1::ext::oneapi::bfloat16::operator=
bfloat16 & operator=(const float &rhs)
Definition: bfloat16.hpp:103
__SYCL_INLINE_VER_NAMESPACE
#define __SYCL_INLINE_VER_NAMESPACE(X)
Definition: defines_elementary.hpp:11
sycl
Error handling, matching OpenCL plugin semantics.
Definition: access.hpp:14
__devicelib_ConvertBF16ToFINTEL
__DPCPP_SYCL_EXTERNAL float __devicelib_ConvertBF16ToFINTEL(const uint16_t &) noexcept
sycl::_V1::ext::oneapi::detail::Bfloat16StorageT
uint16_t Bfloat16StorageT
Definition: bfloat16.hpp:30
sycl::_V1::ext::oneapi::experimental::isnan
std::enable_if_t< std::is_same< T, bfloat16 >::value, bool > isnan(T x)
Definition: bfloat16_math.hpp:36
sycl::_V1::ext::oneapi::detail::bitsToBfloat16
bfloat16 bitsToBfloat16(const Bfloat16StorageT Value)
Definition: bfloat16.hpp:220
sycl::_V1::ext::oneapi::bfloat16::bfloat16
bfloat16(const sycl::half &a)
Definition: bfloat16.hpp:109
sycl::_V1::half
sycl::detail::half_impl::half half
Definition: aliases.hpp:103
__devicelib_ConvertFToBF16INTEL
__DPCPP_SYCL_EXTERNAL uint16_t __devicelib_ConvertFToBF16INTEL(const float &) noexcept
half_type.hpp
sycl::_V1::ext::oneapi::bfloat16::operator=
bfloat16 & operator=(const sycl::half &rhs)
Definition: bfloat16.hpp:111
OP
#define OP(op)
Definition: bfloat16.hpp:183
sycl::_V1::ext::oneapi::bfloat16::operator-
friend bfloat16 operator-(bfloat16 &lhs)
Definition: bfloat16.hpp:126
sycl::_V1::ext::oneapi::detail::bfloat16ToBits
Bfloat16StorageT bfloat16ToBits(const bfloat16 &Value)
Definition: bfloat16.hpp:214
sycl::_V1::ext::oneapi::bfloat16::bfloat16
bfloat16(const float &a)
Definition: bfloat16.hpp:101
__DPCPP_SYCL_EXTERNAL
#define __DPCPP_SYCL_EXTERNAL
Definition: defines_elementary.hpp:35