DPC++ Runtime
Runtime libraries for oneAPI DPC++
bfloat16_type_traits.hpp
Go to the documentation of this file.
1 //==-------------- bfloat16_type_traits.hpp - DPC++ Explicit SIMD API ------==//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 // Implementation of SIMD element type traits for the bfloat16 type.
9 //===----------------------------------------------------------------------===//
10 
11 #pragma once
12 
15 
17 
19 
20 namespace sycl {
21 inline namespace _V1 {
22 namespace ext::intel::esimd::detail {
23 
24 using bfloat16 = sycl::ext::oneapi::bfloat16;
25 
26 template <> struct element_type_traits<bfloat16> {
27  // TODO map the raw type to __bf16 once SPIRV target supports it:
28  using RawT = uint_type_t<sizeof(bfloat16)>;
29  // Nearest standard enclosing C++ type to delegate natively unsupported
30  // operations to:
31  using EnclosingCppT = float;
32  // Can't map bfloat16 operations to opertations on RawT:
33  static constexpr bool use_native_cpp_ops = false;
34  static constexpr bool is_floating_point = true;
35 };
36 
#if defined(__SYCL_DEVICE_ONLY__)
// Workaround specific to the VC backend: the @llvm.genx.bf.cvt intrinsic
// uses half (_Float16) as the bit representation of bfloat16. The device-side
// conversion code below therefore passes vectors of this type to
// __esimd_bf_cvt.
using vc_be_bfloat16_raw_t = _Float16;
#endif // defined(__SYCL_DEVICE_ONLY__)
42 
43 // ------------------- Type conversion traits
44 
45 template <int N> struct vector_conversion_traits<bfloat16, N> {
46  using StdT = __cpp_t<bfloat16>;
47  using StdVecT = vector_type_t<StdT, N>;
48  using RawT = __raw_t<bfloat16>;
49 
50  static ESIMD_INLINE vector_type_t<RawT, N>
51  convert_to_raw(vector_type_t<StdT, N> Val) {
52 #ifdef __SYCL_DEVICE_ONLY__
53  using RawVecT = vector_type_t<vc_be_bfloat16_raw_t, N>;
54  RawVecT ConvVal = __esimd_bf_cvt<vc_be_bfloat16_raw_t, StdT, N>(Val);
55  // cast from _Float16 to int16_t:
56  return sycl::bit_cast<vector_type_t<RawT, N>>(ConvVal);
57 #else
58  __ESIMD_UNSUPPORTED_ON_HOST;
59 #endif // __SYCL_DEVICE_ONLY__
60  }
61 
62  static ESIMD_INLINE vector_type_t<StdT, N>
63  convert_to_cpp(vector_type_t<RawT, N> Val) {
64 #ifdef __SYCL_DEVICE_ONLY__
65  using RawVecT = vector_type_t<vc_be_bfloat16_raw_t, N>;
66  RawVecT Bits = sycl::bit_cast<RawVecT>(Val);
67  return __esimd_bf_cvt<StdT, vc_be_bfloat16_raw_t, N>(Bits);
68 #else
69  __ESIMD_UNSUPPORTED_ON_HOST;
70 #endif // __SYCL_DEVICE_ONLY__
71  }
72 };
73 
74 // TODO: remove bitcasts from the scalar_conversion_traits, and replace with
75 // sycl::bit_cast directly
76 template <> struct scalar_conversion_traits<bfloat16> {
77  using RawT = __raw_t<bfloat16>;
78 
79  static ESIMD_INLINE RawT bitcast_to_raw(bfloat16 Val) {
80  return sycl::bit_cast<RawT>(Val);
81  }
82 
83  static ESIMD_INLINE bfloat16 bitcast_to_wrapper(RawT Val) {
84  return sycl::bit_cast<bfloat16>(Val);
85  }
86 };
87 
88 // bfloat16 uses default inefficient implementations of std C++ operations,
89 // hence no specializations of other traits.
90 
91 // Misc
92 inline std::ostream &operator<<(std::ostream &O, bfloat16 const &rhs) {
93  O << static_cast<float>(rhs);
94  return O;
95 }
96 
97 template <> struct is_esimd_arithmetic_type<bfloat16, void> : std::true_type {};
98 
99 } // namespace ext::intel::esimd::detail
100 } // namespace _V1
101 } // namespace sycl
102 
auto operator<<(const __ESIMD_DNS::simd_obj_impl< __raw_t< T1 >, N, SimdT< T1, N >> &LHS, const __ESIMD_DNS::simd_obj_impl< __raw_t< T2 >, N, SimdT< T2, N >> &RHS)
Definition: operators.hpp:178
Definition: access.hpp:18