DPC++ Runtime
Runtime libraries for oneAPI DPC++
vector_arith.hpp
Go to the documentation of this file.
1 //=== vector_arith.hpp --- Implementation of arithmetic ops on sycl::vec ===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #pragma once
10 
11 #include <sycl/aliases.hpp> // for half, cl_char, cl_int
12 #include <sycl/detail/generic_type_traits.hpp> // for is_sigeninteger, is_s...
13 #include <sycl/detail/type_list.hpp> // for is_contained
14 #include <sycl/detail/type_traits.hpp> // for is_floating_point
15 
16 #include <sycl/ext/oneapi/bfloat16.hpp> // bfloat16
17 
18 #include <cstddef>
19 #include <type_traits> // for enable_if_t, is_same
20 
21 namespace sycl {
22 inline namespace _V1 {
23 
24 template <typename DataT, int NumElem> class __SYCL_EBO vec;
25 
26 namespace detail {
27 
28 template <typename VecT> class VecAccess;
29 
// Macros that stamp out the binary operators of sycl::vec.
//
// BINOP_BASE(BINOP, OPASSIGN, CONVERT, COND) defines the vec-BINOP-vec form of
// `operator BINOP`, enabled only when COND holds. OPASSIGN is not used here;
// it is accepted so BINOP_BASE and __SYCL_BINOP share one parameter list.
#if defined(__SYCL_BINOP) || defined(BINOP_BASE)
#error "Undefine __SYCL_BINOP and BINOP_BASE macro"
#endif

#ifdef __SYCL_DEVICE_ONLY__
// Device: apply BINOP to the whole underlying vector_t at once. bfloat16 is
// handled lane by lane instead. When CONVERT is set and DataT is bool, the
// result is renormalized with vec_arith_common<bool, N>::ConvertToDataT.
#define BINOP_BASE(BINOP, OPASSIGN, CONVERT, COND) \
  template <typename T = DataT> \
  friend std::enable_if_t<(COND), vec_t> operator BINOP(const vec_t &Lhs, \
                                                        const vec_t &Rhs) { \
    vec_t Res; \
    if constexpr (!vec_t::IsBfloat16) { \
      using ext_vec_t = typename vec_t::vector_t; \
      const auto LhsNative = sycl::bit_cast<ext_vec_t>(Lhs); \
      const auto RhsNative = sycl::bit_cast<ext_vec_t>(Rhs); \
      Res = vec<DataT, NumElements>(LhsNative BINOP RhsNative); \
      if constexpr (std::is_same_v<DataT, bool> && CONVERT) \
        vec_arith_common<bool, NumElements>::ConvertToDataT(Res); \
    } else { \
      for (size_t Idx = 0; Idx < NumElements; ++Idx) \
        Res[Idx] = Lhs[Idx] BINOP Rhs[Idx]; \
    } \
    return Res; \
  }
#else // __SYCL_DEVICE_ONLY__

// Host: always compute element by element into a value-initialized result.
#define BINOP_BASE(BINOP, OPASSIGN, CONVERT, COND) \
  template <typename T = DataT> \
  friend std::enable_if_t<(COND), vec_t> operator BINOP(const vec_t &Lhs, \
                                                        const vec_t &Rhs) { \
    vec_t Res{}; \
    for (size_t Idx = 0; Idx < NumElements; ++Idx) \
      Res[Idx] = Lhs[Idx] BINOP Rhs[Idx]; \
    return Res; \
  }
#endif // __SYCL_DEVICE_ONLY__
68 
// __SYCL_BINOP(BINOP, OPASSIGN, CONVERT, COND) defines the whole family of one
// binary operator, every overload gated on COND:
//   vec BINOP vec                      (from BINOP_BASE)
//   scalar BINOP vec, vec BINOP scalar (scalar splatted with vec_t(scalar))
//   vec OPASSIGN vec
//   vec OPASSIGN scalar                (only when NumElements != 1)
#define __SYCL_BINOP(BINOP, OPASSIGN, CONVERT, COND) \
  BINOP_BASE(BINOP, OPASSIGN, CONVERT, COND) \
 \
  /* Mixed vec/scalar forms: splat the scalar, reuse the vec-vec form. */ \
  template <typename T = DataT> \
  friend std::enable_if_t<(COND), vec_t> operator BINOP(const DataT &Lhs, \
                                                        const vec_t &Rhs) { \
    const vec_t LhsSplat(Lhs); \
    return LhsSplat BINOP Rhs; \
  } \
  template <typename T = DataT> \
  friend std::enable_if_t<(COND), vec_t> operator BINOP(const vec_t &Lhs, \
                                                        const DataT &Rhs) { \
    const vec_t RhsSplat(Rhs); \
    return Lhs BINOP RhsSplat; \
  } \
 \
  /* Compound assignment, expressed through the binary operator. */ \
  template <typename T = DataT> \
  friend std::enable_if_t<(COND), vec_t> &operator OPASSIGN( \
      vec_t & Lhs, const vec_t & Rhs) { \
    Lhs = Lhs BINOP Rhs; \
    return Lhs; \
  } \
  /* The scalar compound form is only provided when NumElements != 1. */ \
  template <int Num = NumElements, typename T = DataT> \
  friend std::enable_if_t<(Num != 1) && (COND), vec_t &> operator OPASSIGN( \
      vec_t & Lhs, const DataT & Rhs) { \
    const vec_t RhsSplat(Rhs); \
    Lhs = Lhs BINOP RhsSplat; \
    return Lhs; \
  }
94 
95 /****************************************************************
96  * vec_arith_common
97  * / | \
98  * / | \
99  * vec_arith<int> vec_arith<float> ... vec_arith<byte>
100  * \ | /
101  * \ | /
102  * sycl::vec<T>
103  *
104  * vec_arith_common is the base class for vec_arith. It contains
105  * the common math operators of sycl::vec for all types.
106  * vec_arith is the derived class that contains the math operators
107  * specialized for certain types. sycl::vec inherits from vec_arith.
108  * *************************************************************/
109 template <typename DataT, int NumElements> class vec_arith_common;
110 template <typename DataT> struct vec_helper;
111 
112 template <typename DataT, int NumElements>
113 class vec_arith : public vec_arith_common<DataT, NumElements> {
114 protected:
117  template <typename T> using vec_data = vec_helper<T>;
118 
119  // operator!.
121 #ifdef __SYCL_DEVICE_ONLY__
122  if constexpr (!vec_t::IsBfloat16) {
123  auto extVec = sycl::bit_cast<typename vec_t::vector_t>(Rhs);
125  (typename vec<ocl_t, NumElements>::vector_t) !extVec};
126  return Ret;
127  } else
128 #endif // __SYCL_DEVICE_ONLY__
129  {
131  for (size_t I = 0; I < NumElements; ++I) {
132  // static_cast will work here as the output of ! operator is either 0 or
133  // -1.
134  Ret[I] = static_cast<ocl_t>(-1 * (!Rhs[I]));
135  }
136  return Ret;
137  }
138  }
139 
140  // operator +.
141  friend vec_t operator+(const vec_t &Lhs) {
142 #ifdef __SYCL_DEVICE_ONLY__
143  auto extVec = sycl::bit_cast<typename vec_t::vector_t>(Lhs);
144  return vec_t{+extVec};
145 #else
146  vec_t Ret{};
147  for (size_t I = 0; I < NumElements; ++I)
148  Ret[I] = +Lhs[I];
149  return Ret;
150 #endif
151  }
152 
153  // operator -.
154  friend vec_t operator-(const vec_t &Lhs) {
155  vec_t Ret{};
156  if constexpr (vec_t::IsBfloat16) {
157  for (size_t I = 0; I < NumElements; I++)
158  Ret[I] = -Lhs[I];
159  } else {
160 #ifndef __SYCL_DEVICE_ONLY__
161  for (size_t I = 0; I < NumElements; ++I)
162  Ret[I] = -Lhs[I];
163 #else
164  auto extVec = sycl::bit_cast<typename vec_t::vector_t>(Lhs);
165  Ret = vec_t{-extVec};
166  if constexpr (std::is_same_v<DataT, bool>) {
168  }
169 #endif
170  }
171  return Ret;
172  }
173 
174 // Unary operations on sycl::vec
175 // FIXME: Don't allow Unary operators on vec<bool> after
176 // https://github.com/KhronosGroup/SYCL-CTS/issues/896 gets fixed.
177 #ifdef __SYCL_UOP
178 #error "Undefine __SYCL_UOP macro"
179 #endif
180 #define __SYCL_UOP(UOP, OPASSIGN) \
181  friend vec_t &operator UOP(vec_t & Rhs) { \
182  Rhs OPASSIGN DataT{1}; \
183  return Rhs; \
184  } \
185  friend vec_t operator UOP(vec_t &Lhs, int) { \
186  vec_t Ret(Lhs); \
187  Lhs OPASSIGN DataT{1}; \
188  return Ret; \
189  }
190 
191  __SYCL_UOP(++, +=)
192  __SYCL_UOP(--, -=)
193 #undef __SYCL_UOP
194 
195  // The logical operations on scalar types results in 0/1, while for vec<>,
196  // logical operations should result in 0 and -1 (similar to OpenCL vectors).
197  // That's why, for vec<DataT, 1>, we need to invert the result of the logical
198  // operations since we store vec<DataT, 1> as scalar type on the device.
199 #if defined(__SYCL_RELLOGOP) || defined(RELLOGOP_BASE)
200 #error "Undefine __SYCL_RELLOGOP and RELLOGOP_BASE macro."
201 #endif
202 
203 #ifdef __SYCL_DEVICE_ONLY__
204 #define RELLOGOP_BASE(RELLOGOP, COND) \
205  template <typename T = DataT> \
206  friend std::enable_if_t<(COND), vec<ocl_t, NumElements>> operator RELLOGOP( \
207  const vec_t & Lhs, const vec_t & Rhs) { \
208  vec<ocl_t, NumElements> Ret{}; \
209  /* ext_vector_type does not support bfloat16, so for these */ \
210  /* we do element-by-element operation on the underlying std::array. */ \
211  if constexpr (vec_t::IsBfloat16) { \
212  for (size_t I = 0; I < NumElements; ++I) { \
213  Ret[I] = static_cast<ocl_t>(-(Lhs[I] RELLOGOP Rhs[I])); \
214  } \
215  } else { \
216  auto ExtVecLhs = sycl::bit_cast<typename vec_t::vector_t>(Lhs); \
217  auto ExtVecRhs = sycl::bit_cast<typename vec_t::vector_t>(Rhs); \
218  /* Cast required to convert unsigned char ext_vec_type to */ \
219  /* char ext_vec_type. */ \
220  Ret = vec<ocl_t, NumElements>( \
221  (typename vec<ocl_t, NumElements>::vector_t)( \
222  ExtVecLhs RELLOGOP ExtVecRhs)); \
223  /* For NumElements == 1, we use scalar instead of ext_vector_type. */ \
224  if constexpr (NumElements == 1) { \
225  Ret *= -1; \
226  } \
227  } \
228  return Ret; \
229  }
230 #else // __SYCL_DEVICE_ONLY__
231 #define RELLOGOP_BASE(RELLOGOP, COND) \
232  template <typename T = DataT> \
233  friend std::enable_if_t<(COND), vec<ocl_t, NumElements>> operator RELLOGOP( \
234  const vec_t & Lhs, const vec_t & Rhs) { \
235  vec<ocl_t, NumElements> Ret{}; \
236  for (size_t I = 0; I < NumElements; ++I) { \
237  Ret[I] = static_cast<ocl_t>(-(Lhs[I] RELLOGOP Rhs[I])); \
238  } \
239  return Ret; \
240  }
241 #endif
242 
243 #define __SYCL_RELLOGOP(RELLOGOP, COND) \
244  RELLOGOP_BASE(RELLOGOP, COND) \
245  \
246  template <typename T = DataT> \
247  friend std::enable_if_t<(COND), vec<ocl_t, NumElements>> operator RELLOGOP( \
248  const vec_t & Lhs, const DataT & Rhs) { \
249  return Lhs RELLOGOP vec_t(Rhs); \
250  } \
251  template <typename T = DataT> \
252  friend std::enable_if_t<(COND), vec<ocl_t, NumElements>> operator RELLOGOP( \
253  const DataT & Lhs, const vec_t & Rhs) { \
254  return vec_t(Lhs) RELLOGOP Rhs; \
255  }
256 
257  // OP is: ==, !=, <, >, <=, >=, &&, ||
258  // vec<RET, NumElements> operatorOP(const vec<DataT, NumElements> &Rhs) const;
259  // vec<RET, NumElements> operatorOP(const DataT &Rhs) const;
260  __SYCL_RELLOGOP(==, true)
261  __SYCL_RELLOGOP(!=, true)
262  __SYCL_RELLOGOP(>, true)
263  __SYCL_RELLOGOP(<, true)
264  __SYCL_RELLOGOP(>=, true)
265  __SYCL_RELLOGOP(<=, true)
266 
267  // Only available to integral types.
268  __SYCL_RELLOGOP(&&, (!detail::is_vgenfloat_v<T>))
269  __SYCL_RELLOGOP(||, (!detail::is_vgenfloat_v<T>))
270 #undef __SYCL_RELLOGOP
271 #undef RELLOGOP_BASE
272 
273  // Binary operations on sycl::vec<> for all types except std::byte.
274  __SYCL_BINOP(+, +=, true, true)
275  __SYCL_BINOP(-, -=, true, true)
276  __SYCL_BINOP(*, *=, false, true)
277  __SYCL_BINOP(/, /=, false, true)
278 
279  // The following OPs are available only when: DataT != cl_float &&
280  // DataT != cl_double && DataT != cl_half && DataT != BF16.
281  __SYCL_BINOP(%, %=, false, (!detail::is_vgenfloat_v<T>))
282  // Bitwise operations are allowed for std::byte.
283  __SYCL_BINOP(|, |=, false, (!detail::is_vgenfloat_v<DataT>))
284  __SYCL_BINOP(&, &=, false, (!detail::is_vgenfloat_v<DataT>))
285  __SYCL_BINOP(^, ^=, false, (!detail::is_vgenfloat_v<DataT>))
286  __SYCL_BINOP(>>, >>=, false, (!detail::is_vgenfloat_v<DataT>))
287  __SYCL_BINOP(<<, <<=, true, (!detail::is_vgenfloat_v<DataT>))
288 
289  // friends
290  template <typename T1, int T2> friend class __SYCL_EBO vec;
291 }; // class vec_arith<>
292 
293 #if (!defined(_HAS_STD_BYTE) || _HAS_STD_BYTE != 0)
294 template <int NumElements>
295 class vec_arith<std::byte, NumElements>
296  : public vec_arith_common<std::byte, NumElements> {
297 protected:
298  // NumElements can never be zero. Still using the redundant check to avoid
299  // incomplete type errors.
300  using DataT = typename std::conditional_t<NumElements == 0, int, std::byte>;
302  template <typename T> using vec_data = vec_helper<T>;
303 
304  // Special <<, >> operators for std::byte.
305  // std::byte is not an arithmetic type and it only supports the following
306  // overloads of >> and << operators.
307  //
308  // 1 template <class IntegerType>
309  // constexpr std::byte operator<<( std::byte b, IntegerType shift )
310  // noexcept;
311  friend vec_t operator<<(const vec_t &Lhs, int shift) {
312  vec_t Ret;
313  for (size_t I = 0; I < NumElements; ++I) {
314  Ret[I] = Lhs[I] << shift;
315  }
316  return Ret;
317  }
318  friend vec_t &operator<<=(vec_t &Lhs, int shift) {
319  Lhs = Lhs << shift;
320  return Lhs;
321  }
322 
323  // 2 template <class IntegerType>
324  // constexpr std::byte operator>>( std::byte b, IntegerType shift )
325  // noexcept;
326  friend vec_t operator>>(const vec_t &Lhs, int shift) {
327  vec_t Ret;
328  for (size_t I = 0; I < NumElements; ++I) {
329  Ret[I] = Lhs[I] >> shift;
330  }
331  return Ret;
332  }
333  friend vec_t &operator>>=(vec_t &Lhs, int shift) {
334  Lhs = Lhs >> shift;
335  return Lhs;
336  }
337 
338  __SYCL_BINOP(|, |=, false, true)
339  __SYCL_BINOP(&, &=, false, true)
340  __SYCL_BINOP(^, ^=, false, true)
341 
342  // friends
343  template <typename T1, int T2> friend class __SYCL_EBO vec;
344 };
345 #endif // (!defined(_HAS_STD_BYTE) || _HAS_STD_BYTE != 0)
346 
347 template <typename DataT, int NumElements> class vec_arith_common {
348 protected:
350 
351  static constexpr bool IsBfloat16 =
352  std::is_same_v<DataT, sycl::ext::oneapi::bfloat16>;
353 
354  // operator~() available only when: dataT != float && dataT != double
355  // && dataT != half
356  template <typename T = DataT>
357  friend std::enable_if_t<!detail::is_vgenfloat_v<T>, vec_t>
358  operator~(const vec_t &Rhs) {
359 #ifdef __SYCL_DEVICE_ONLY__
360  auto extVec = sycl::bit_cast<typename vec_t::vector_t>(Rhs);
361  vec_t Ret{~extVec};
362  if constexpr (std::is_same_v<DataT, bool>) {
363  ConvertToDataT(Ret);
364  }
365  return Ret;
366 #else
367  vec_t Ret{};
368  for (size_t I = 0; I < NumElements; ++I) {
369  Ret[I] = ~Rhs[I];
370  }
371  return Ret;
372 #endif
373  }
374 
375 #ifdef __SYCL_DEVICE_ONLY__
376  using vec_bool_t = vec<bool, NumElements>;
377  // Required only for std::bool.
378  static void ConvertToDataT(vec_bool_t &Ret) {
379  for (size_t I = 0; I < NumElements; ++I) {
380  Ret[I] = bit_cast<int8_t>(Ret[I]) != 0;
381  }
382  }
383 #endif
384 
385  // friends
386  template <typename T1, int T2> friend class __SYCL_EBO vec;
387 };
388 
389 #undef __SYCL_BINOP
390 #undef BINOP_BASE
391 
392 } // namespace detail
393 } // namespace _V1
394 } // namespace sycl
friend vec_t & operator>>=(vec_t &Lhs, int shift)
typename std::conditional_t< NumElements==0, int, std::byte > DataT
friend vec_t operator>>(const vec_t &Lhs, int shift)
friend vec_t operator<<(const vec_t &Lhs, int shift)
friend vec_t & operator<<=(vec_t &Lhs, int shift)
friend std::enable_if_t<!detail::is_vgenfloat_v< T >, vec_t > operator~(const vec_t &Rhs)
friend vec< ocl_t, NumElements > operator!(const vec_t &Rhs)
detail::select_cl_scalar_integral_signed_t< DataT > ocl_t
friend vec_t operator+(const vec_t &Lhs)
friend vec_t operator-(const vec_t &Lhs)
#define __SYCL_EBO
constexpr bool is_vgenfloat_v
select_apply_cl_scalar_t< T, sycl::opencl::cl_char, sycl::opencl::cl_short, sycl::opencl::cl_int, sycl::opencl::cl_long > select_cl_scalar_integral_signed_t
unsigned char byte
Definition: image.hpp:107
class __SYCL_EBO vec
Definition: aliases.hpp:18
Definition: access.hpp:18
#define __SYCL_UOP(UOP, OPASSIGN)
#define __SYCL_BINOP(BINOP, OPASSIGN, CONVERT, COND)
#define __SYCL_RELLOGOP(RELLOGOP, COND)