47 inline namespace _V1 {
48 namespace ext::oneapi {
61 #if defined(__SYCL_DEVICE_ONLY__) && (defined(__SPIR__) || defined(__SPIRV__))
62 const uint16_t *src_i16 = sycl::bit_cast<const uint16_t *>(src);
65 else if constexpr (N == 2)
67 else if constexpr (N == 3)
69 else if constexpr (N == 4)
71 else if constexpr (N == 8)
73 else if constexpr (N == 16)
76 for (
int i = 0; i < N; ++i) {
77 dst[i] = (float)src[i];
113 uint32_t roundingBias = ((intStorage >> 16) & 0x1) + 0x00007FFF;
114 return static_cast<uint16_t
>((intStorage + roundingBias) >> 16);
119 #if defined(__SYCL_DEVICE_ONLY__)
120 #if defined(__NVPTX__)
121 #if (__SYCL_CUDA_ARCH__ >= 800)
123 asm(
"cvt.rn.bf16.f32 %0, %1;" :
"=h"(res) :
"f"(a));
126 return from_float_fallback(a);
128 #elif defined(__AMDGCN__)
129 return from_float_fallback(a);
134 return from_float_fallback(a);
138 #if defined(__SYCL_DEVICE_ONLY__) && (defined(__SPIR__) || defined(__SPIRV__))
145 intStorage =
a << 16;
163 value = from_float(rhs);
171 value = from_float(rhs);
176 operator float()
const {
return to_float(
value); }
182 explicit operator bool() {
return to_float(
value) != 0.0f; }
186 #if defined(__SYCL_DEVICE_ONLY__) && defined(__NVPTX__) && \
187 (__SYCL_CUDA_ARCH__ >= 800)
189 asm(
"neg.bf16 %0, %1;" :
"=h"(res) :
"h"(lhs.
value));
191 #elif defined(__SYCL_DEVICE_ONLY__) && (defined(__SPIR__) || defined(__SPIRV__))
200 friend bfloat16 &operator op(bfloat16 & lhs) { \
201 float f = to_float(lhs.value); \
202 lhs.value = from_float(op f); \
205 friend bfloat16 operator op(bfloat16 &lhs, int) { \
206 bfloat16 old = lhs; \
216 friend bfloat16 &operator op(bfloat16 & lhs, const bfloat16 & rhs) { \
217 float f = static_cast<float>(lhs); \
218 f op static_cast<float>(rhs); \
228 #define OP(type, op) \
229 friend type operator op(const bfloat16 &lhs, const bfloat16 &rhs) { \
230 return type{static_cast<float>(lhs) op static_cast<float>(rhs)}; \
232 template <typename T> \
233 friend std::enable_if_t<std::is_convertible_v<T, float>, type> operator op( \
234 const bfloat16 & lhs, const T & rhs) { \
235 return type{static_cast<float>(lhs) op static_cast<float>(rhs)}; \
237 template <typename T> \
238 friend std::enable_if_t<std::is_convertible_v<T, float>, type> operator op( \
239 const T & lhs, const bfloat16 & rhs) { \
240 return type{static_cast<float>(lhs) op static_cast<float>(rhs)}; \
259 O << static_cast<float>(rhs);
264 float ValFloat = 0.0f;
274 #if defined(__SYCL_DEVICE_ONLY__) && (defined(__SPIR__) || defined(__SPIRV__))
275 uint16_t *dst_i16 = sycl::bit_cast<uint16_t *>(dst);
276 if constexpr (N == 1)
278 else if constexpr (N == 2)
280 else if constexpr (N == 3)
282 else if constexpr (N == 4)
284 else if constexpr (N == 8)
286 else if constexpr (N == 16)
289 for (
int i = 0; i < N; ++i) {
// Rounding-mode selector passed as the `roundingMode` argument of the
// getBFloat16From*WithRoundingMode helpers in this file.
//
// The explicit numeric values (0..4) are significant: a template in this file
// takes the mode as an `int rm` template argument and converts it with
// static_cast<SYCLRoundingMode>(rm), so a mode can be selected by number.
//
// NOTE(review): the enumerator names presumably mirror sycl::rounding_mode
// (automatic, round-to-nearest-even, toward zero, toward +inf, toward -inf).
// The switch statements in the conversion helpers test the sign bit and the
// discarded mantissa bits in a way consistent with that reading, and the
// double-precision path only accepts "automatic/RTE" per its static_assert
// message -- confirm against the complete file before relying on it.
315 enum SYCLRoundingMode { automatic = 0, rte = 1, rtz = 2, rtp = 3, rtn = 4 };
318 template <
typename Ty>
static size_t get_msb_pos(
const Ty &
x) {
321 Ty mask = ((Ty)1 << (
sizeof(Ty) * 8 - 1));
322 for (idx = 0; idx < (
sizeof(Ty) * 8); ++idx) {
323 if ((
x & mask) == mask)
328 return (
sizeof(Ty) * 8 - 1 - idx);
335 getBFloat16FromFloatWithRoundingMode(
const float &f,
336 SYCLRoundingMode roundingMode) {
339 roundingMode == SYCLRoundingMode::rte) {
343 uint32_t u32_val = sycl::bit_cast<uint32_t>(f);
344 uint16_t bf16_sign =
static_cast<uint16_t
>((u32_val >> 31) & 0x1);
345 uint16_t bf16_exp =
static_cast<uint16_t
>((u32_val >> 23) & 0x7FF);
346 uint32_t f_mant = u32_val & 0x7F'FFFF;
347 uint16_t bf16_mant =
static_cast<uint16_t
>(f_mant >> 16);
349 if (bf16_exp == 0xFF) {
358 if (!bf16_exp && !f_mant) {
362 uint16_t mant_discard =
static_cast<uint16_t
>(f_mant & 0xFFFF);
363 switch (roundingMode) {
364 case SYCLRoundingMode::rtn:
365 if (bf16_sign && mant_discard)
368 case SYCLRoundingMode::rtz:
370 case SYCLRoundingMode::rtp:
371 if (!bf16_sign && mant_discard)
377 case SYCLRoundingMode::rte:
383 if (bf16_mant == 0x80) {
388 return bitsToBfloat16((bf16_sign << 15) | (bf16_exp << 7) | bf16_mant);
396 template <
typename T>
398 getBFloat16FromUIntegralWithRoundingMode(T &u,
399 SYCLRoundingMode roundingMode) {
401 size_t msb_pos = get_msb_pos(u);
406 T mant = u & ((
static_cast<T
>(1) << msb_pos) - 1);
411 uint16_t b_exp = msb_pos;
417 mant <<= (7 - msb_pos);
418 b_mant =
static_cast<uint16_t
>(mant);
420 b_mant =
static_cast<uint16_t
>(mant >> (msb_pos - 7));
421 T mant_discard = mant & ((
static_cast<T
>(1) << (msb_pos - 7)) - 1);
422 T mid =
static_cast<T
>(1) << (msb_pos - 8);
423 switch (roundingMode) {
425 case SYCLRoundingMode::rte:
426 if ((mant_discard > mid) ||
427 ((mant_discard == mid) && ((b_mant & 0x1) == 0x1)))
430 case SYCLRoundingMode::rtp:
434 case SYCLRoundingMode::rtn:
435 case SYCLRoundingMode::rtz:
439 if (b_mant == 0x80) {
451 template <
typename T>
453 getBFloat16FromSIntegralWithRoundingMode(T &i,
454 SYCLRoundingMode roundingMode) {
456 typedef typename std::make_unsigned_t<T> UTy;
458 uint16_t b_sign = (i >= 0) ? 0 : 0x8000;
459 UTy ui = (i > 0) ?
static_cast<UTy
>(i) :
static_cast<UTy
>(-i);
460 size_t msb_pos = get_msb_pos<UTy>(ui);
463 UTy mant = ui & ((
static_cast<UTy
>(1) << msb_pos) - 1);
465 uint16_t b_exp = msb_pos;
468 mant <<= (7 - msb_pos);
469 b_mant =
static_cast<uint16_t
>(mant);
471 b_mant =
static_cast<uint16_t
>(mant >> (msb_pos - 7));
472 T mant_discard = mant & ((
static_cast<T
>(1) << (msb_pos - 7)) - 1);
473 T mid =
static_cast<T
>(1) << (msb_pos - 8);
474 switch (roundingMode) {
476 case SYCLRoundingMode::rte:
477 if ((mant_discard > mid) ||
478 ((mant_discard == mid) && ((b_mant & 0x1) == 0x1)))
481 case SYCLRoundingMode::rtp:
482 if (mant_discard && !b_sign)
485 case SYCLRoundingMode::rtn:
486 if (mant_discard && b_sign)
488 case SYCLRoundingMode::rtz:
493 if (b_mant == 0x80) {
504 static bfloat16 getBFloat16FromDoubleWithRTE(
const double &d) {
506 uint64_t u64_val = sycl::bit_cast<uint64_t>(d);
507 int16_t bf16_sign = (u64_val >> 63) & 0x1;
508 uint16_t fp64_exp =
static_cast<uint16_t
>((u64_val >> 52) & 0x7FF);
509 uint64_t fp64_mant = (u64_val & 0xF'FFFF'FFFF'FFFF);
512 if (fp64_exp == 0x7FF) {
514 return bf16_sign ? 0xFF80 : 0x7F80;
523 return bf16_sign ? 0x8000 : 0x0;
528 if (
static_cast<int16_t
>(fp64_exp) > 127) {
529 return bf16_sign ? 0xFF80 : 0x7F80;
533 if (
static_cast<int16_t
>(fp64_exp) < -133) {
534 return bf16_sign ? 0x8000 : 0x0;
540 uint64_t discard_bits;
541 if (
static_cast<int16_t
>(fp64_exp) < -126) {
542 fp64_mant |= 0x10'0000'0000'0000;
543 fp64_mant >>= -126 -
static_cast<int16_t
>(fp64_exp) - 1;
544 discard_bits = fp64_mant & 0x3FFF'FFFF'FFFF;
545 bf16_mant =
static_cast<uint16_t
>(fp64_mant >> 46);
546 if (discard_bits > 0x2000'0000'0000 ||
547 ((discard_bits == 0x2000'0000'0000) && ((bf16_mant & 0x1) == 0x1)))
550 if (bf16_mant == 0x80) {
554 return (bf16_sign << 15) | (fp64_exp << 7) | bf16_mant;
558 discard_bits = fp64_mant & 0x1FFF'FFFF'FFFF;
559 bf16_mant =
static_cast<uint16_t
>(fp64_mant >> 45);
560 if (discard_bits > 0x1000'0000'0000 ||
561 ((discard_bits == 0x1000'0000'0000) && ((bf16_mant & 0x1) == 0x1)))
564 if (bf16_mant == 0x80) {
565 if (fp64_exp != 127) {
569 return bf16_sign ? 0xFF80 : 0x7F80;
574 return (bf16_sign << 15) | (fp64_exp << 7) | bf16_mant;
578 template <
typename Ty,
int rm>
584 constexpr SYCLRoundingMode roundingMode =
static_cast<SYCLRoundingMode
>(rm);
587 if constexpr (std::is_same_v<Ty, float>) {
588 return getBFloat16FromFloatWithRoundingMode(
a, roundingMode);
591 else if constexpr (std::is_same_v<Ty, double>) {
594 roundingMode == SYCLRoundingMode::rte,
595 "Only automatic/RTE rounding mode is supported for double type.");
596 return getBFloat16FromDoubleWithRTE(
a);
599 else if constexpr (std::is_same_v<Ty, sycl::half>) {
603 return getBFloat16FromFloatWithRoundingMode(
static_cast<float>(
a),
607 else if constexpr (std::is_integral_v<Ty> && std::is_unsigned_v<Ty>) {
608 return getBFloat16FromUIntegralWithRoundingMode<Ty>(
a, roundingMode);
611 else if constexpr (std::is_integral_v<Ty> && std::is_signed_v<Ty>) {
612 return getBFloat16FromSIntegralWithRoundingMode<Ty>(
a, roundingMode);
614 static_assert(std::is_integral_v<Ty> || std::is_floating_point_v<Ty>,
615 "Only integral and floating point types are supported.");
__DPCPP_SYCL_EXTERNAL void __devicelib_ConvertFToBF16INTELVec4(const float *, uint16_t *) noexcept
__DPCPP_SYCL_EXTERNAL void __devicelib_ConvertFToBF16INTELVec16(const float *, uint16_t *) noexcept
__DPCPP_SYCL_EXTERNAL void __devicelib_ConvertFToBF16INTELVec3(const float *, uint16_t *) noexcept
__DPCPP_SYCL_EXTERNAL void __devicelib_ConvertBF16ToFINTELVec16(const uint16_t *, float *) noexcept
__DPCPP_SYCL_EXTERNAL void __devicelib_ConvertFToBF16INTELVec8(const float *, uint16_t *) noexcept
__DPCPP_SYCL_EXTERNAL void __devicelib_ConvertBF16ToFINTELVec1(const uint16_t *, float *) noexcept
__DPCPP_SYCL_EXTERNAL void __devicelib_ConvertBF16ToFINTELVec2(const uint16_t *, float *) noexcept
__DPCPP_SYCL_EXTERNAL void __devicelib_ConvertFToBF16INTELVec1(const float *, uint16_t *) noexcept
__DPCPP_SYCL_EXTERNAL void __devicelib_ConvertBF16ToFINTELVec3(const uint16_t *, float *) noexcept
__DPCPP_SYCL_EXTERNAL uint16_t __devicelib_ConvertFToBF16INTEL(const float &) noexcept
__DPCPP_SYCL_EXTERNAL void __devicelib_ConvertBF16ToFINTELVec4(const uint16_t *, float *) noexcept
__DPCPP_SYCL_EXTERNAL void __devicelib_ConvertFToBF16INTELVec2(const float *, uint16_t *) noexcept
__DPCPP_SYCL_EXTERNAL float __devicelib_ConvertBF16ToFINTEL(const uint16_t &) noexcept
__DPCPP_SYCL_EXTERNAL void __devicelib_ConvertBF16ToFINTELVec8(const uint16_t *, float *) noexcept
bfloat16(const sycl::half &a)
constexpr bfloat16(const bfloat16 &)=default
friend bfloat16 operator-(bfloat16 &lhs)
constexpr bfloat16(bfloat16 &&)=default
friend std::istream & operator>>(std::istream &I, bfloat16 &rhs)
friend std::ostream & operator<<(std::ostream &O, bfloat16 const &rhs)
bfloat16 & operator=(const sycl::half &rhs)
constexpr bfloat16 & operator=(const bfloat16 &rhs)=default
bfloat16 & operator=(const float &rhs)
detail::Bfloat16StorageT value
static bfloat16 getBfloat16WithRoundingMode(const Ty &a)
#define __DPCPP_SYCL_EXTERNAL
sycl::ext::oneapi::bfloat16 bfloat16
void FloatVecToBF16Vec(float src[N], bfloat16 dst[N])
bfloat16 bitsToBfloat16(const Bfloat16StorageT Value)
void BF16VecToFloatVec(const bfloat16 src[N], float dst[N])
Bfloat16StorageT bfloat16ToBits(const bfloat16 &Value)
uint16_t Bfloat16StorageT
sycl::detail::half_impl::half half
_Abi const simd< _Tp, _Abi > & noexcept