DPC++ Runtime
Runtime libraries for oneAPI DPC++
bfloat16.hpp
Go to the documentation of this file.
1 //==--------- bfloat16.hpp ------- SYCL bfloat16 conversion ----------------==//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #pragma once
10 
11 #include <sycl/aliases.hpp> // for half
12 #include <sycl/detail/defines_elementary.hpp> // for __DPCPP_SYCL_EXTERNAL
13 #include <sycl/half_type.hpp> // for half
14 
15 #include <stdint.h> // for uint16_t, uint32_t
16 
17 extern "C" __DPCPP_SYCL_EXTERNAL uint16_t
19 extern "C" __DPCPP_SYCL_EXTERNAL float
21 
22 namespace sycl {
23 inline namespace _V1 {
24 
25 #ifdef __INTEL_PREVIEW_BREAKING_CHANGES
26 // forward declaration of sycl::isnan built-in.
27 // extern __DPCPP_SYCL_EXTERNAL bool isnan(float a);
28 bool isnan(float a);
29 #endif
30 
31 namespace ext::oneapi {
32 
33 class bfloat16;
34 
35 namespace detail {
// Raw 16-bit storage type for a bfloat16 bit pattern. The bfloat16 class keeps
// its value in a member of this type; the host to_float path reinterprets it
// as the upper 16 bits of an IEEE-754 binary32 float.
36 using Bfloat16StorageT = uint16_t;
39 
40 // sycl::vec support
// Per-width storage types (N = 2, 3, 4, 8, 16) made of Bfloat16StorageT
// elements. Presumably consumed by sycl::vec<bfloat16, N> (per the note above);
// the consumer is not visible here — NOTE(review): confirm.
41 namespace bf16 {
// Device compilation: clang extended vector types of the 16-bit storage type.
42 #ifdef __SYCL_DEVICE_ONLY__
43 using Vec2StorageT = Bfloat16StorageT __attribute__((ext_vector_type(2)));
44 using Vec3StorageT = Bfloat16StorageT __attribute__((ext_vector_type(3)));
45 using Vec4StorageT = Bfloat16StorageT __attribute__((ext_vector_type(4)));
46 using Vec8StorageT = Bfloat16StorageT __attribute__((ext_vector_type(8)));
47 using Vec16StorageT = Bfloat16StorageT __attribute__((ext_vector_type(16)));
48 #else
// Host compilation: plain std::array of the 16-bit storage type.
// NOTE(review): this header does not include <array> itself; it relies on a
// transitive include — confirm or add the include.
49 using Vec2StorageT = std::array<Bfloat16StorageT, 2>;
50 using Vec3StorageT = std::array<Bfloat16StorageT, 3>;
51 using Vec4StorageT = std::array<Bfloat16StorageT, 4>;
52 using Vec8StorageT = std::array<Bfloat16StorageT, 8>;
53 using Vec16StorageT = std::array<Bfloat16StorageT, 16>;
54 #endif
55 } // namespace bf16
56 
#ifndef __INTEL_PREVIEW_BREAKING_CHANGES
// Returns true iff `x` is a NaN (of either sign, quiet or signaling).
//
// The test inspects the IEEE-754 binary32 bit pattern (all-ones exponent and a
// non-zero mantissa) instead of evaluating `x != x`. Under -ffast-math /
// -ffinite-math-only the compiler may assume NaNs never occur and fold
// `x != x` to false. That would silently disable the NaN special case its
// callers rely on (presumably the NaN path of from_float_fallback). Without
// that special case, a NaN goes through the rounding code like an ordinary
// value; e.g. 0x7F800001 rounds to 0x7F80, which is +infinity.
inline bool float_is_nan(float x) {
  // Same union-based reinterpretation the bfloat16 conversion helpers use.
  union {
    uint32_t intStorage;
    float floatValue;
  };
  floatValue = x;
  // With the sign bit cleared, any pattern strictly above +infinity
  // (0x7F800000) has an all-ones exponent and a non-zero mantissa: a NaN.
  return (intStorage & 0x7FFFFFFFu) > 0x7F800000u;
}
#endif
60 } // namespace detail
61 
62 class bfloat16 {
63 protected:
65 
66  friend inline detail::Bfloat16StorageT
68  friend inline bfloat16
70 
71 public:
// Defaulted special member functions.
// The default constructor does no explicit initialization of its own. Whether
// `value` starts out initialized depends on its member declaration, which is
// not visible here — NOTE(review): confirm.
72  bfloat16() = default;
// Copy/move construction and copy assignment are defaulted (memberwise) and
// declared constexpr, so they are usable in constant expressions.
73  constexpr bfloat16(const bfloat16 &) = default;
74  constexpr bfloat16(bfloat16 &&) = default;
// No move assignment operator is declared. Because a copy assignment operator
// and a move constructor are user-declared, none is implicitly generated, so
// assignment from an rvalue bfloat16 uses this copy assignment.
75  constexpr bfloat16 &operator=(const bfloat16 &rhs) = default;
76  ~bfloat16() = default;
77 
78 private:
79  static detail::Bfloat16StorageT from_float_fallback(const float &a) {
80 #ifdef __INTEL_PREVIEW_BREAKING_CHANGES
81  if (sycl::isnan(a))
82  return 0xffc1;
83 #else
85  return 0xffc1;
86 #endif
87 
88  union {
89  uint32_t intStorage;
90  float floatValue;
91  };
92  floatValue = a;
93  // Do RNE and truncate
94  uint32_t roundingBias = ((intStorage >> 16) & 0x1) + 0x00007FFF;
95  return static_cast<uint16_t>((intStorage + roundingBias) >> 16);
96  }
97 
98  // Explicit conversion functions
99  static detail::Bfloat16StorageT from_float(const float &a) {
100 #if defined(__SYCL_DEVICE_ONLY__)
101 #if defined(__NVPTX__)
102 #if (__SYCL_CUDA_ARCH__ >= 800)
104  asm("cvt.rn.bf16.f32 %0, %1;" : "=h"(res) : "f"(a));
105  return res;
106 #else
107  return from_float_fallback(a);
108 #endif
109 #elif defined(__AMDGCN__)
110  return from_float_fallback(a);
111 #else
113 #endif
114 #endif
115  return from_float_fallback(a);
116  }
117 
118  static float to_float(const detail::Bfloat16StorageT &a) {
119 #if defined(__SYCL_DEVICE_ONLY__) && defined(__SPIR__)
121 #else
122  union {
123  uint32_t intStorage;
124  float floatValue;
125  };
126  intStorage = a << 16;
127  return floatValue;
128 #endif
129  }
130 
131 protected:
132  friend class sycl::vec<bfloat16, 1>;
133  friend class sycl::vec<bfloat16, 2>;
134  friend class sycl::vec<bfloat16, 3>;
135  friend class sycl::vec<bfloat16, 4>;
136  friend class sycl::vec<bfloat16, 8>;
137  friend class sycl::vec<bfloat16, 16>;
138 
139 public:
140  // Implicit conversion from float to bfloat16
141  bfloat16(const float &a) { value = from_float(a); }
142 
143  bfloat16 &operator=(const float &rhs) {
144  value = from_float(rhs);
145  return *this;
146  }
147 
148  // Implicit conversion from sycl::half to bfloat16
149  bfloat16(const sycl::half &a) { value = from_float(a); }
150 
152  value = from_float(rhs);
153  return *this;
154  }
155 
156  // Implicit conversion from bfloat16 to float
157  operator float() const { return to_float(value); }
158 
159  // Implicit conversion from bfloat16 to sycl::half
160  operator sycl::half() const { return to_float(value); }
161 
162  // Logical operators (!,||,&&) are covered if we can cast to bool
163  explicit operator bool() { return to_float(value) != 0.0f; }
164 
165  // Unary minus operator overloading
166  friend bfloat16 operator-(bfloat16 &lhs) {
167 #if defined(__SYCL_DEVICE_ONLY__) && defined(__NVPTX__) && \
168  (__SYCL_CUDA_ARCH__ >= 800)
170  asm("neg.bf16 %0, %1;" : "=h"(res) : "h"(lhs.value));
171  return detail::bitsToBfloat16(res);
172 #elif defined(__SYCL_DEVICE_ONLY__) && defined(__SPIR__)
174 #else
175  return bfloat16{-to_float(lhs.value)};
176 #endif
177  }
178 
179 // Increment and decrement operators overloading
180 #define OP(op) \
181  friend bfloat16 &operator op(bfloat16 & lhs) { \
182  float f = to_float(lhs.value); \
183  lhs.value = from_float(op f); \
184  return lhs; \
185  } \
186  friend bfloat16 operator op(bfloat16 &lhs, int) { \
187  bfloat16 old = lhs; \
188  operator op(lhs); \
189  return old; \
190  }
191  OP(++)
192  OP(--)
193 #undef OP
194 
195  // Assignment operators overloading
196 #define OP(op) \
197  friend bfloat16 &operator op(bfloat16 & lhs, const bfloat16 & rhs) { \
198  float f = static_cast<float>(lhs); \
199  f op static_cast<float>(rhs); \
200  return lhs = f; \
201  }
202  OP(+=)
203  OP(-=)
204  OP(*=)
205  OP(/=)
206 #undef OP
207 
208 // Binary operators overloading
209 #define OP(type, op) \
210  friend type operator op(const bfloat16 &lhs, const bfloat16 &rhs) { \
211  return type{static_cast<float>(lhs) op static_cast<float>(rhs)}; \
212  } \
213  template <typename T> \
214  friend std::enable_if_t<std::is_convertible_v<T, float>, type> operator op( \
215  const bfloat16 & lhs, const T & rhs) { \
216  return type{static_cast<float>(lhs) op static_cast<float>(rhs)}; \
217  } \
218  template <typename T> \
219  friend std::enable_if_t<std::is_convertible_v<T, float>, type> operator op( \
220  const T & lhs, const bfloat16 & rhs) { \
221  return type{static_cast<float>(lhs) op static_cast<float>(rhs)}; \
222  }
223  OP(bfloat16, +)
224  OP(bfloat16, -)
225  OP(bfloat16, *)
226  OP(bfloat16, /)
227  OP(bool, ==)
228  OP(bool, !=)
229  OP(bool, <)
230  OP(bool, >)
231  OP(bool, <=)
232  OP(bool, >=)
233 #undef OP
234 
235  // Bitwise(|,&,~,^), modulo(%) and shift(<<,>>) operations are not supported
236  // for floating-point types.
237 
238  // Stream Operator << and >>
239  inline friend std::ostream &operator<<(std::ostream &O, bfloat16 const &rhs) {
240  O << static_cast<float>(rhs);
241  return O;
242  }
243 
244  inline friend std::istream &operator>>(std::istream &I, bfloat16 &rhs) {
245  float ValFloat = 0.0f;
246  I >> ValFloat;
247  rhs = ValFloat;
248  return I;
249  }
250 };
251 
252 namespace detail {
253 
254 // Helper function for getting the internal representation of a bfloat16.
256  return Value.value;
257 }
258 
259 // Helper function for creating a bfloat16 from a value with the same type as the
260 // internal representation.
262  bfloat16 res;
263  res.value = Value;
264  return res;
265 }
266 
267 } // namespace detail
268 
269 } // namespace ext::oneapi
270 
271 } // namespace _V1
272 } // namespace sycl
__DPCPP_SYCL_EXTERNAL uint16_t __devicelib_ConvertFToBF16INTEL(const float &) noexcept
#define OP(op)
Definition: bfloat16.hpp:209
__DPCPP_SYCL_EXTERNAL float __devicelib_ConvertBF16ToFINTEL(const uint16_t &) noexcept
bfloat16(const sycl::half &a)
Definition: bfloat16.hpp:149
constexpr bfloat16(const bfloat16 &)=default
friend bfloat16 operator-(bfloat16 &lhs)
Definition: bfloat16.hpp:166
constexpr bfloat16(bfloat16 &&)=default
friend std::istream & operator>>(std::istream &I, bfloat16 &rhs)
Definition: bfloat16.hpp:244
friend std::ostream & operator<<(std::ostream &O, bfloat16 const &rhs)
Definition: bfloat16.hpp:239
bfloat16 & operator=(const sycl::half &rhs)
Definition: bfloat16.hpp:151
constexpr bfloat16 & operator=(const bfloat16 &rhs)=default
bfloat16 & operator=(const float &rhs)
Definition: bfloat16.hpp:143
detail::Bfloat16StorageT value
Definition: bfloat16.hpp:64
defined(__INTEL_PREVIEW_BREAKING_CHANGES)
Definition: types.hpp:346
#define __DPCPP_SYCL_EXTERNAL
std::array< Bfloat16StorageT, 4 > Vec4StorageT
Definition: bfloat16.hpp:51
std::array< Bfloat16StorageT, 3 > Vec3StorageT
Definition: bfloat16.hpp:50
std::array< Bfloat16StorageT, 16 > Vec16StorageT
Definition: bfloat16.hpp:53
std::array< Bfloat16StorageT, 8 > Vec8StorageT
Definition: bfloat16.hpp:52
std::array< Bfloat16StorageT, 2 > Vec2StorageT
Definition: bfloat16.hpp:49
bfloat16 bitsToBfloat16(const Bfloat16StorageT Value)
Definition: bfloat16.hpp:261
Bfloat16StorageT bfloat16ToBits(const bfloat16 &Value)
Definition: bfloat16.hpp:255
__attribute__((always_inline)) auto invoke_simd(sycl
The invoke_simd free function invokes a SIMD function using all work-items in a sub_group.
T detail::marray_element_t< T > y T T T maxval[i] T T T a
detail::common_rel_ret_t< T > isnan(T x)
sycl::detail::half_impl::half half
Definition: aliases.hpp:101
Definition: access.hpp:18
_Abi const simd< _Tp, _Abi > & noexcept
Definition: simd.hpp:1324