DPC++ Runtime
Runtime libraries for oneAPI DPC++
bfloat16_type_traits.hpp
Go to the documentation of this file.
1 //==-------------- bfloat16_type_traits.hpp - DPC++ Explicit SIMD API ------==//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 // Implementation of SIMD element type traits for the bfloat16 type.
9 //===----------------------------------------------------------------------===//
10 
11 #pragma once
12 
15 
17 
19 
20 namespace sycl {
22 namespace ext::intel::esimd::detail {
23 
24 using bfloat16 = sycl::ext::oneapi::bfloat16;
25 
26 template <> struct element_type_traits<bfloat16> {
27  // TODO map the raw type to __bf16 once SPIRV target supports it:
28  using RawT = uint_type_t<sizeof(bfloat16)>;
29  // Nearest standard enclosing C++ type to delegate natively unsupported
30  // operations to:
31  using EnclosingCppT = float;
32  // Can't map bfloat16 operations to opertations on RawT:
33  static inline constexpr bool use_native_cpp_ops = false;
34  static inline constexpr bool is_floating_point = true;
35 };
36 
#if defined(__SYCL_DEVICE_ONLY__)
// Workaround for a VC back-end quirk: the @llvm.genx.bf.cvt intrinsic models
// the bfloat16 bit pattern as half (_Float16) rather than as an integer type.
// This alias names that representation for the device-side bfloat16
// conversions.
using vc_be_bfloat16_raw_t = _Float16;
#endif // defined(__SYCL_DEVICE_ONLY__)
42 
43 // ------------------- Type conversion traits
44 
45 template <int N> struct vector_conversion_traits<bfloat16, N> {
46  using StdT = __cpp_t<bfloat16>;
47  using StdVecT = vector_type_t<StdT, N>;
48  using RawT = __raw_t<bfloat16>;
49 
50  static ESIMD_INLINE vector_type_t<RawT, N>
51  convert_to_raw(vector_type_t<StdT, N> Val) {
52 #ifdef __SYCL_DEVICE_ONLY__
53  using RawVecT = vector_type_t<vc_be_bfloat16_raw_t, N>;
54  RawVecT ConvVal = __esimd_bf_cvt<vc_be_bfloat16_raw_t, StdT, N>(Val);
55  // cast from _Float16 to int16_t:
56  return sycl::bit_cast<vector_type_t<RawT, N>>(ConvVal);
57 #else
58  vector_type_t<RawT, N> Output = 0;
59 
60  for (int i = 0; i < N; i++) {
61  Output[i] = sycl::bit_cast<RawT>(static_cast<bfloat16>(Val[i]));
62  }
63  return Output;
64 #endif // __SYCL_DEVICE_ONLY__
65  }
66 
67  static ESIMD_INLINE vector_type_t<StdT, N>
68  convert_to_cpp(vector_type_t<RawT, N> Val) {
69 #ifdef __SYCL_DEVICE_ONLY__
70  using RawVecT = vector_type_t<vc_be_bfloat16_raw_t, N>;
71  RawVecT Bits = sycl::bit_cast<RawVecT>(Val);
72  return __esimd_bf_cvt<StdT, vc_be_bfloat16_raw_t, N>(Bits);
73 #else
74  vector_type_t<StdT, N> Output;
75 
76  for (int i = 0; i < N; i++) {
77  Output[i] = sycl::bit_cast<bfloat16>(Val[i]);
78  }
79  return Output;
80 #endif // __SYCL_DEVICE_ONLY__
81  }
82 };
83 
84 // TODO: remove bitcasts from the scalar_conversion_traits, and replace with
85 // sycl::bit_cast directly
86 template <> struct scalar_conversion_traits<bfloat16> {
87  using RawT = __raw_t<bfloat16>;
88 
89  static ESIMD_INLINE RawT bitcast_to_raw(bfloat16 Val) {
90  return sycl::bit_cast<RawT>(Val);
91  }
92 
93  static ESIMD_INLINE bfloat16 bitcast_to_wrapper(RawT Val) {
94  return sycl::bit_cast<bfloat16>(Val);
95  }
96 };
97 
98 // bfloat16 relies on the default (less efficient) implementations of
99 // standard C++ operations, so no other traits are specialized for it here.
100 
101 // Misc
102 inline std::ostream &operator<<(std::ostream &O, bfloat16 const &rhs) {
103  O << static_cast<float>(rhs);
104  return O;
105 }
106 
107 } // namespace ext::intel::esimd::detail
108 } // __SYCL_INLINE_VER_NAMESPACE(_V1)
109 } // namespace sycl
110 
__SYCL_INLINE_VER_NAMESPACE
#define __SYCL_INLINE_VER_NAMESPACE(X)
Definition: defines_elementary.hpp:11
sycl
Error handling, matching OpenCL plugin semantics.
Definition: access.hpp:14
intrin.hpp
elem_type_traits.hpp
bfloat16.hpp
sycl::_V1::operator<<
std::ostream & operator<<(std::ostream &Out, backend be)
Definition: backend_types.hpp:46