18 #include <type_traits>
22 namespace ext::oneapi::experimental {
// Helper that produces a 32-bit unsigned value from the bfloat16 marray `x`
// and the element index `start`. The body is not visible in this view.
// NOTE(review): presumably packs the two adjacent elements x[start] and
// x[start + 1] into one 32-bit word — confirm against the full definition.
26 uint32_t
to_uint32_t(sycl::marray<bfloat16, N> x,
size_t start) {
// Scalar isnan overload: takes part in overload resolution only when T is
// bfloat16, and returns bool.
36 std::enable_if_t<std::is_same_v<T, bfloat16>,
bool>
isnan(T x) {
// True when every bit under mask 0x7F80 is set and at least one bit under
// mask 0x7F is set. In the bfloat16 layout (1 sign, 8 exponent, 7 mantissa
// bits) that is an all-ones exponent with a non-zero mantissa, i.e. a NaN.
// NOTE(review): XBits is defined on a line not visible here; presumably it
// holds the raw 16-bit storage of x — confirm.
38 return (((XBits & 0x7F80) == 0x7F80) && (XBits & 0x7F)) ? true :
false;
// isnan overload for an marray of N bfloat16 values; returns an
// marray<bool, N>.
41 template <
size_t N> sycl::marray<bool, N>
isnan(sycl::marray<bfloat16, N> x) {
42 sycl::marray<bool, N> res;
// Visits every index in [0, N). The loop body is not visible in this view;
// presumably it stores isnan(x[i]) into res[i] — confirm.
43 for (
size_t i = 0; i < N; i++) {
// Scalar fabs overload: takes part in overload resolution only when T is
// bfloat16, and returns T.
50 std::enable_if_t<std::is_same_v<T, bfloat16>, T>
fabs(T x) {
// Code guarded for NVPTX device compilation starts here; an alternative
// branch presumably follows on lines not shown in this view.
51 #if defined(__SYCL_DEVICE_ONLY__) && defined(__NVPTX__)
// Start of an assignment to x that tests whether all SignMask bits are set in
// XBits; the remainder of the statement is not visible.
// NOTE(review): XBits and SignMask are defined on lines not shown, and it is
// not visible which preprocessor branch this statement belongs to — confirm.
58 x = ((XBits & SignMask) == SignMask)
63 #endif // defined(__SYCL_DEVICE_ONLY__) && defined(__NVPTX__)
// fabs overload for an marray of N bfloat16 values; returns an marray of the
// same type and size.
67 sycl::marray<bfloat16, N>
fabs(sycl::marray<bfloat16, N> x) {
68 sycl::marray<bfloat16, N> res;
69 #if defined(__SYCL_DEVICE_ONLY__) && defined(__NVPTX__)
// NVPTX device path: iterates N / 2 times, writing at element index i * 2.
// NOTE(review): for odd N this loop does not reach the last element; handling
// for it is presumably on lines not visible here — confirm.
70 for (
size_t i = 0; i < N / 2; i++) {
// Copies sizeof(uint32_t) bytes from partial_res into res starting at element
// i * 2 (two elements, assuming bfloat16 occupies 2 bytes). partial_res is
// computed on a line not visible in this view.
72 std::memcpy(&res[i * 2], &partial_res,
sizeof(uint32_t));
// Loop over every index in [0, N); its body is not visible.
// NOTE(review): presumably the non-NVPTX fallback applying scalar fabs per
// element — confirm.
81 for (
size_t i = 0; i < N; i++) {
84 #endif // defined(__SYCL_DEVICE_ONLY__) && defined(__NVPTX__)
// Scalar fmin overload: takes part in overload resolution only when T is
// bfloat16, and returns T.
89 std::enable_if_t<std::is_same_v<T, bfloat16>, T>
fmin(T x, T y) {
90 #if defined(__SYCL_DEVICE_ONLY__) && defined(__NVPTX__)
// Start of a condition on the OR-combined bits of both operands; the value it
// is compared against and the branch body are not visible in this view.
// NOTE(review): XBits / YBits are defined on lines not shown — presumably the
// raw storage of x and y; confirm.
105 if (((XBits | YBits) ==
// Final selection: x when x < y holds, otherwise y (so y is also returned
// when the comparison is false for any other reason).
111 return (x < y) ? x : y;
112 #endif // defined(__SYCL_DEVICE_ONLY__) && defined(__NVPTX__)
// fmin overload for two marrays of N bfloat16 values; returns an marray of
// the same type and size.
116 sycl::marray<bfloat16, N>
fmin(sycl::marray<bfloat16, N> x,
117 sycl::marray<bfloat16, N> y) {
118 sycl::marray<bfloat16, N> res;
119 #if defined(__SYCL_DEVICE_ONLY__) && defined(__NVPTX__)
// NVPTX device path: iterates N / 2 times, writing at element index i * 2.
// NOTE(review): for odd N this loop does not reach the last element; handling
// for it is presumably on lines not visible here — confirm.
120 for (
size_t i = 0; i < N / 2; i++) {
// Copies sizeof(uint32_t) bytes from partial_res into res starting at element
// i * 2 (two elements, assuming bfloat16 occupies 2 bytes). partial_res is
// computed on a line not visible in this view.
123 std::memcpy(&res[i * 2], &partial_res,
sizeof(uint32_t));
// Per-element loop over [0, N) that defers to the scalar fmin overload.
// NOTE(review): presumably the non-NVPTX fallback; the #else is not visible.
134 for (
size_t i = 0; i < N; i++) {
135 res[i] =
fmin(x[i], y[i]);
137 #endif // defined(__SYCL_DEVICE_ONLY__) && defined(__NVPTX__)
// Scalar fmax overload: takes part in overload resolution only when T is
// bfloat16, and returns T.
141 template <
typename T>
142 std::enable_if_t<std::is_same_v<T, bfloat16>, T>
fmax(T x, T y) {
143 #if defined(__SYCL_DEVICE_ONLY__) && defined(__NVPTX__)
// Start of a condition on the OR-combined bits of both operands; the value it
// is compared against and the branch body are not visible in this view.
// NOTE(review): XBits / YBits are defined on lines not shown — presumably the
// raw storage of x and y; confirm.
158 if (((XBits | YBits) ==
// Final selection: x when x > y holds, otherwise y (so y is also returned
// when the comparison is false for any other reason).
163 return (x > y) ? x : y;
164 #endif // defined(__SYCL_DEVICE_ONLY__) && defined(__NVPTX__)
// fmax overload for two marrays of N bfloat16 values; returns an marray of
// the same type and size.
168 sycl::marray<bfloat16, N>
fmax(sycl::marray<bfloat16, N> x,
169 sycl::marray<bfloat16, N> y) {
170 sycl::marray<bfloat16, N> res;
171 #if defined(__SYCL_DEVICE_ONLY__) && defined(__NVPTX__)
// NVPTX device path: iterates N / 2 times, writing at element index i * 2.
// NOTE(review): for odd N this loop does not reach the last element; handling
// for it is presumably on lines not visible here — confirm.
172 for (
size_t i = 0; i < N / 2; i++) {
// Copies sizeof(uint32_t) bytes from partial_res into res starting at element
// i * 2 (two elements, assuming bfloat16 occupies 2 bytes). partial_res is
// computed on a line not visible in this view.
175 std::memcpy(&res[i * 2], &partial_res,
sizeof(uint32_t));
// Per-element loop over [0, N) that defers to the scalar fmax overload.
// NOTE(review): presumably the non-NVPTX fallback; the #else is not visible.
186 for (
size_t i = 0; i < N; i++) {
187 res[i] =
fmax(x[i], y[i]);
189 #endif // defined(__SYCL_DEVICE_ONLY__) && defined(__NVPTX__)
// Scalar fma overload: takes part in overload resolution only when T is
// bfloat16, and returns T.
193 template <
typename T>
194 std::enable_if_t<std::is_same_v<T, bfloat16>, T>
fma(T x, T y, T z) {
// The code guarded for NVPTX device compilation is not visible in this view.
195 #if defined(__SYCL_DEVICE_ONLY__) && defined(__NVPTX__)
// Converts each operand to float, evaluates sycl::fma on the floats, and
// constructs a bfloat16 from the float result.
// NOTE(review): this return sits directly before the #endif, so it is
// presumably the non-NVPTX fallback; the #else is not visible — confirm.
201 return sycl::ext::oneapi::bfloat16{sycl::fma(
float{x},
float{y},
float{z})};
202 #endif // defined(__SYCL_DEVICE_ONLY__) && defined(__NVPTX__)
// fma overload for three marrays of N bfloat16 values; returns an marray of
// the same type and size.
206 sycl::marray<bfloat16, N>
fma(sycl::marray<bfloat16, N> x,
207 sycl::marray<bfloat16, N> y,
208 sycl::marray<bfloat16, N> z) {
209 sycl::marray<bfloat16, N> res;
210 #if defined(__SYCL_DEVICE_ONLY__) && defined(__NVPTX__)
// NVPTX device path: iterates N / 2 times, writing at element index i * 2.
// NOTE(review): for odd N this loop does not reach the last element; handling
// for it is presumably on lines not visible here — confirm.
211 for (
size_t i = 0; i < N / 2; i++) {
// Copies sizeof(uint32_t) bytes from partial_res into res starting at element
// i * 2 (two elements, assuming bfloat16 occupies 2 bytes). partial_res is
// computed on lines not visible in this view.
215 std::memcpy(&res[i * 2], &partial_res,
sizeof(uint32_t));
// Per-element loop over [0, N) that defers to the scalar fma overload.
// NOTE(review): presumably the non-NVPTX fallback; the #else is not visible.
228 for (
size_t i = 0; i < N; i++) {
229 res[i] =
fma(x[i], y[i], z[i]);
231 #endif // defined(__SYCL_DEVICE_ONLY__) && defined(__NVPTX__)
// Generates a scalar overload `op(T x)`, enabled only when T is bfloat16,
// that converts x to float, calls sycl::op on it, and constructs a bfloat16
// from the float result. The end of the macro body is not visible in this
// view.
235 #define BFLOAT16_MATH_FP32_WRAPPERS(op) \
236 template <typename T> \
237 std::enable_if_t<std::is_same<T, bfloat16>::value, T> op(T x) { \
238 return sycl::ext::oneapi::bfloat16{sycl::op(float{x})}; \
// Generates an overload `op(sycl::marray<bfloat16, N> x)` that declares a
// result marray of the same type and loops over every index in [0, N).
// The loop body and the end of the macro are not visible in this view.
// NOTE(review): presumably each iteration stores op(x[i]) into res[i] and the
// function returns res — confirm against the full macro.
241 #define BFLOAT16_MATH_FP32_WRAPPERS_MARRAY(op) \
242 template <size_t N> \
243 sycl::marray<bfloat16, N> op(sycl::marray<bfloat16, N> x) { \
244 sycl::marray<bfloat16, N> res; \
245 for (size_t i = 0; i < N; i++) { \
280 #undef BFLOAT16_MATH_FP32_WRAPPERS
281 #undef BFLOAT16_MATH_FP32_WRAPPERS_MARRAY