18 #include <type_traits>
22 namespace ext::oneapi::experimental {
26 uint32_t
to_uint32_t(sycl::marray<bfloat16, N> x,
size_t start) {
36 std::enable_if_t<std::is_same<T, bfloat16>::value,
bool>
isnan(T x) {
38 return (((XBits & 0x7F80) == 0x7F80) && (XBits & 0x7F)) ? true :
false;
41 template <
size_t N> sycl::marray<bool, N>
isnan(sycl::marray<bfloat16, N> x) {
42 sycl::marray<bool, N> res;
43 for (
size_t i = 0; i < N; i++) {
50 std::enable_if_t<std::is_same<T, bfloat16>::value, T>
fabs(T x) {
51 #if defined(__SYCL_DEVICE_ONLY__) && defined(__NVPTX__)
57 "bfloat16 math functions are not currently supported on the host device.",
58 PI_ERROR_INVALID_DEVICE);
63 sycl::marray<bfloat16, N>
fabs(sycl::marray<bfloat16, N> x) {
64 #if defined(__SYCL_DEVICE_ONLY__) && defined(__NVPTX__)
65 sycl::marray<bfloat16, N> res;
67 for (
size_t i = 0; i < N / 2; i++) {
69 std::memcpy(&res[i * 2], &partial_res,
sizeof(uint32_t));
81 "bfloat16 math functions are not currently supported on the host device.",
82 PI_ERROR_INVALID_DEVICE);
87 std::enable_if_t<std::is_same<T, bfloat16>::value, T>
fmin(T x, T y) {
88 #if defined(__SYCL_DEVICE_ONLY__) && defined(__NVPTX__)
103 if (((XBits | YBits) ==
109 return (x < y) ? x : y;
114 sycl::marray<bfloat16, N>
fmin(sycl::marray<bfloat16, N> x,
115 sycl::marray<bfloat16, N> y) {
116 sycl::marray<bfloat16, N> res;
117 #if defined(__SYCL_DEVICE_ONLY__) && defined(__NVPTX__)
118 for (
size_t i = 0; i < N / 2; i++) {
121 std::memcpy(&res[i * 2], &partial_res,
sizeof(uint32_t));
132 for (
size_t i = 0; i < N; i++) {
133 res[i] =
fmin(x[i], y[i]);
139 template <
typename T>
140 std::enable_if_t<std::is_same<T, bfloat16>::value, T>
fmax(T x, T y) {
141 #if defined(__SYCL_DEVICE_ONLY__) && defined(__NVPTX__)
156 if (((XBits | YBits) ==
161 return (x > y) ? x : y;
166 sycl::marray<bfloat16, N>
fmax(sycl::marray<bfloat16, N> x,
167 sycl::marray<bfloat16, N> y) {
168 sycl::marray<bfloat16, N> res;
169 #if defined(__SYCL_DEVICE_ONLY__) && defined(__NVPTX__)
170 for (
size_t i = 0; i < N / 2; i++) {
173 std::memcpy(&res[i * 2], &partial_res,
sizeof(uint32_t));
184 for (
size_t i = 0; i < N; i++) {
185 res[i] =
fmax(x[i], y[i]);
191 template <
typename T>
192 std::enable_if_t<std::is_same<T, bfloat16>::value, T>
fma(T x, T y, T z) {
193 #if defined(__SYCL_DEVICE_ONLY__) && defined(__NVPTX__)
199 return sycl::ext::oneapi::bfloat16{sycl::fma(
float{x},
float{y},
float{z})};
204 sycl::marray<bfloat16, N>
fma(sycl::marray<bfloat16, N> x,
205 sycl::marray<bfloat16, N> y,
206 sycl::marray<bfloat16, N> z) {
207 sycl::marray<bfloat16, N> res;
208 #if defined(__SYCL_DEVICE_ONLY__) && defined(__NVPTX__)
209 for (
size_t i = 0; i < N / 2; i++) {
213 std::memcpy(&res[i * 2], &partial_res,
sizeof(uint32_t));
226 for (
size_t i = 0; i < N; i++) {
227 res[i] =
fma(x[i], y[i], z[i]);
#define __SYCL_INLINE_VER_NAMESPACE(X)
void memcpy(void *Dst, const void *Src, std::size_t Size)
bfloat16 bitsToBfloat16(const Bfloat16StorageT Value)
Bfloat16StorageT bfloat16ToBits(const bfloat16 &Value)
uint16_t Bfloat16StorageT
uint32_t to_uint32_t(sycl::marray< bfloat16, N > x, size_t start)
sycl::marray< bfloat16, N > fabs(sycl::marray< bfloat16, N > x)
sycl::marray< bool, N > isnan(sycl::marray< bfloat16, N > x)
std::enable_if_t< detail::is_bf16_storage_type< T >::value, T > fma(T x, T y, T z)
std::enable_if_t< detail::is_bf16_storage_type< T >::value, T > fmax(T x, T y)
std::enable_if_t< detail::is_bf16_storage_type< T >::value, T > fmin(T x, T y)
---— Error handling, matching OpenCL plugin semantics.