DPC++ Runtime
Runtime libraries for oneAPI DPC++
bfloat16.hpp
Go to the documentation of this file.
1 //==--------- bfloat16.hpp ------- SYCL bfloat16 conversion ----------------==//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #pragma once
10 
11 #include <sycl/aliases.hpp> // for half
12 #include <sycl/detail/defines_elementary.hpp> // for __DPCPP_SYCL_EXTERNAL
13 #include <sycl/half_type.hpp> // for half
14 
15 #include <stdint.h> // for uint16_t, uint32_t
16 
17 extern "C" __DPCPP_SYCL_EXTERNAL uint16_t
19 extern "C" __DPCPP_SYCL_EXTERNAL float
21 
22 namespace sycl {
23 inline namespace _V1 {
24 namespace ext::oneapi {
25 
26 class bfloat16;
27 
28 namespace detail {
29 using Bfloat16StorageT = uint16_t;
32 
33 // sycl::vec support
34 namespace bf16 {
35 #ifdef __SYCL_DEVICE_ONLY__
36 using Vec2StorageT = Bfloat16StorageT __attribute__((ext_vector_type(2)));
37 using Vec3StorageT = Bfloat16StorageT __attribute__((ext_vector_type(3)));
38 using Vec4StorageT = Bfloat16StorageT __attribute__((ext_vector_type(4)));
39 using Vec8StorageT = Bfloat16StorageT __attribute__((ext_vector_type(8)));
40 using Vec16StorageT = Bfloat16StorageT __attribute__((ext_vector_type(16)));
41 #else
42 using Vec2StorageT = std::array<Bfloat16StorageT, 2>;
43 using Vec3StorageT = std::array<Bfloat16StorageT, 3>;
44 using Vec4StorageT = std::array<Bfloat16StorageT, 4>;
45 using Vec8StorageT = std::array<Bfloat16StorageT, 8>;
46 using Vec16StorageT = std::array<Bfloat16StorageT, 16>;
47 #endif
48 } // namespace bf16
49 } // namespace detail
50 
51 class bfloat16 {
52 protected:
54 
55  friend inline detail::Bfloat16StorageT
57  friend inline bfloat16
59 
60 public:
61  bfloat16() = default;
62  constexpr bfloat16(const bfloat16 &) = default;
63  constexpr bfloat16(bfloat16 &&) = default;
64  constexpr bfloat16 &operator=(const bfloat16 &rhs) = default;
65  ~bfloat16() = default;
66 
67 private:
68  static detail::Bfloat16StorageT from_float_fallback(const float &a) {
69  // Detect NaN by self-comparison rather than sycl::isnan, so that this data
70  // type stays independent of the builtins.
71  if (!(a == a))
72  return 0xffc1;
73 
74  union {
75  uint32_t rawBits;
76  float asFloat;
77  };
78  asFloat = a;
79  // Round to nearest even: add 0x7FFF plus the lowest kept bit, then truncate
80  const uint32_t lowestKeptBit = (rawBits >> 16) & 0x1;
81  return static_cast<uint16_t>((rawBits + lowestKeptBit + 0x00007FFF) >> 16);
82  }
83 
84  // Explicit conversion functions
85  static detail::Bfloat16StorageT from_float(const float &a) {
86 #if defined(__SYCL_DEVICE_ONLY__)
87 #if defined(__NVPTX__)
88 #if (__SYCL_CUDA_ARCH__ >= 800)
90  asm("cvt.rn.bf16.f32 %0, %1;" : "=h"(res) : "f"(a));
91  return res;
92 #else
93  return from_float_fallback(a);
94 #endif
95 #elif defined(__AMDGCN__)
96  return from_float_fallback(a);
97 #else
99 #endif
100 #endif
101  return from_float_fallback(a);
102  }
103 
104  static float to_float(const detail::Bfloat16StorageT &a) {
105 #if defined(__SYCL_DEVICE_ONLY__) && (defined(__SPIR__) || defined(__SPIRV__))
107 #else
108  union {
109  uint32_t intStorage;
110  float floatValue;
111  };
112  intStorage = a << 16;
113  return floatValue;
114 #endif
115  }
116 
117 protected:
118  friend class sycl::vec<bfloat16, 1>;
119  friend class sycl::vec<bfloat16, 2>;
120  friend class sycl::vec<bfloat16, 3>;
121  friend class sycl::vec<bfloat16, 4>;
122  friend class sycl::vec<bfloat16, 8>;
123  friend class sycl::vec<bfloat16, 16>;
124 
125 public:
126  // Converting constructor: a float implicitly becomes a bfloat16 via from_float
127  bfloat16(const float &a) : value(from_float(a)) {}
128 
129  bfloat16 &operator=(const float &rhs) {
130  // Convert through the float constructor, then reuse the copy assignment.
131  return *this = bfloat16(rhs);
132  }
133 
134  // Converting constructor: a sycl::half is passed to from_float, which takes a float
135  bfloat16(const sycl::half &a) : value(from_float(a)) {}
136 
138  value = from_float(rhs);
139  return *this;
140  }
141 
142  // Implicit conversion from bfloat16 to float
143  operator float() const { return to_float(value); }
144 
145  // Implicit conversion from bfloat16 to sycl::half (the float from to_float is converted to half)
146  operator sycl::half() const { return to_float(value); }
147 
148  // Logical operators (!,||,&&) are covered if we can cast to bool: true when the float form is non-zero. Const-qualified so const objects use it directly.
149  explicit operator bool() const { return to_float(value) != 0.0f; }
150 
151  // Unary minus operator overloading
152  friend bfloat16 operator-(bfloat16 &lhs) {
153 #if defined(__SYCL_DEVICE_ONLY__) && defined(__NVPTX__) && \
154  (__SYCL_CUDA_ARCH__ >= 800)
156  asm("neg.bf16 %0, %1;" : "=h"(res) : "h"(lhs.value));
157  return detail::bitsToBfloat16(res);
158 #elif defined(__SYCL_DEVICE_ONLY__) && (defined(__SPIR__) || defined(__SPIRV__))
160 #else
161  return bfloat16{-to_float(lhs.value)};
162 #endif
163  }
164 
165 // Prefix and postfix ++/--: the step is applied to the float form and stored back; postfix returns the prior value
166 #define OP(op) \
167  friend bfloat16 &operator op(bfloat16 & lhs) { \
168  float asFloat = to_float(lhs.value); \
169  lhs.value = from_float(op asFloat); \
170  return lhs; \
171  } \
172  friend bfloat16 operator op(bfloat16 &lhs, int) { \
173  bfloat16 previous = lhs; \
174  op lhs; \
175  return previous; \
176  }
177  OP(++)
178  OP(--)
179 #undef OP
180 
181  // Compound assignment (+=, -=, *=, /=): computed in float, then assigned back to lhs
182 #define OP(op) \
183  friend bfloat16 &operator op(bfloat16 & lhs, const bfloat16 & rhs) { \
184  float acc = to_float(lhs.value); \
185  acc op to_float(rhs.value); \
186  return lhs = acc; \
187  }
188  OP(+=)
189  OP(-=)
190  OP(*=)
191  OP(/=)
192 #undef OP
193 
194 // Binary operators overloading. Each OP(type, op) defines three friends: bfloat16 op bfloat16, bfloat16 op T and T op bfloat16, where T is any type convertible to float. Both operands are cast to float, combined with op, and the result is used to construct `type` (bfloat16 for arithmetic, bool for comparisons).
195 #define OP(type, op) \
196  friend type operator op(const bfloat16 &lhs, const bfloat16 &rhs) { \
197  return type{static_cast<float>(lhs) op static_cast<float>(rhs)}; \
198  } \
199  template <typename T> \
200  friend std::enable_if_t<std::is_convertible_v<T, float>, type> operator op( \
201  const bfloat16 & lhs, const T & rhs) { \
202  return type{static_cast<float>(lhs) op static_cast<float>(rhs)}; \
203  } \
204  template <typename T> \
205  friend std::enable_if_t<std::is_convertible_v<T, float>, type> operator op( \
206  const T & lhs, const bfloat16 & rhs) { \
207  return type{static_cast<float>(lhs) op static_cast<float>(rhs)}; \
208  }
209  OP(bfloat16, +)
210  OP(bfloat16, -)
211  OP(bfloat16, *)
212  OP(bfloat16, /)
213  OP(bool, ==)
214  OP(bool, !=)
215  OP(bool, <)
216  OP(bool, >)
217  OP(bool, <=)
218  OP(bool, >=)
219 #undef OP
220 
221  // Bitwise(|,&,~,^), modulo(%) and shift(<<,>>) operations are not supported
222  // for floating-point types.
223 
224  // Stream Operator << and >>
225  inline friend std::ostream &operator<<(std::ostream &O, bfloat16 const &rhs) {
226  // The value is streamed in its float form.
227  return O << static_cast<float>(rhs);
228  }
229 
230  inline friend std::istream &operator>>(std::istream &I, bfloat16 &rhs) {
231  float Parsed = 0.0f;
232  I >> Parsed; // the stream state is not checked here; rhs is assigned regardless
233  rhs = Parsed;
234  return I;
235  }
236 };
237 
238 namespace detail {
239 
240 // Helper function for getting the internal representation of a bfloat16.
242  return Value.value;
243 }
244 
245 // Helper function for creating a bfloat16 from a value with the same type as
246 // the internal representation.
248  bfloat16 res;
249  res.value = Value;
250  return res;
251 }
252 
253 } // namespace detail
254 
255 } // namespace ext::oneapi
256 
257 } // namespace _V1
258 } // namespace sycl
__DPCPP_SYCL_EXTERNAL uint16_t __devicelib_ConvertFToBF16INTEL(const float &) noexcept
#define OP(op)
Definition: bfloat16.hpp:195
__DPCPP_SYCL_EXTERNAL float __devicelib_ConvertBF16ToFINTEL(const uint16_t &) noexcept
bfloat16(const sycl::half &a)
Definition: bfloat16.hpp:135
constexpr bfloat16(const bfloat16 &)=default
friend bfloat16 operator-(bfloat16 &lhs)
Definition: bfloat16.hpp:152
constexpr bfloat16(bfloat16 &&)=default
friend std::istream & operator>>(std::istream &I, bfloat16 &rhs)
Definition: bfloat16.hpp:230
friend std::ostream & operator<<(std::ostream &O, bfloat16 const &rhs)
Definition: bfloat16.hpp:225
bfloat16 & operator=(const sycl::half &rhs)
Definition: bfloat16.hpp:137
constexpr bfloat16 & operator=(const bfloat16 &rhs)=default
bfloat16 & operator=(const float &rhs)
Definition: bfloat16.hpp:129
detail::Bfloat16StorageT value
Definition: bfloat16.hpp:53
Provides a cross-platform vector class template that works efficiently on SYCL devices as well as in h...
Definition: types.hpp:284
#define __DPCPP_SYCL_EXTERNAL
std::array< Bfloat16StorageT, 4 > Vec4StorageT
Definition: bfloat16.hpp:44
std::array< Bfloat16StorageT, 3 > Vec3StorageT
Definition: bfloat16.hpp:43
std::array< Bfloat16StorageT, 16 > Vec16StorageT
Definition: bfloat16.hpp:46
std::array< Bfloat16StorageT, 8 > Vec8StorageT
Definition: bfloat16.hpp:45
std::array< Bfloat16StorageT, 2 > Vec2StorageT
Definition: bfloat16.hpp:42
bfloat16 bitsToBfloat16(const Bfloat16StorageT Value)
Definition: bfloat16.hpp:247
Bfloat16StorageT bfloat16ToBits(const bfloat16 &Value)
Definition: bfloat16.hpp:241
__attribute__((always_inline)) auto invoke_simd(sycl
The invoke_simd free function invokes a SIMD function using all work-items in a sub_group.
sycl::detail::half_impl::half half
Definition: aliases.hpp:101
Definition: access.hpp:18
_Abi const simd< _Tp, _Abi > & noexcept
Definition: simd.hpp:1324