DPC++ Runtime
Runtime libraries for oneAPI DPC++
sub_group_mask.hpp
Go to the documentation of this file.
1 //==------------ sub_group_mask.hpp --- SYCL sub-group mask ----------------==//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 #pragma once
9 
10 #include <CL/__spirv/spirv_ops.hpp>
12 #include <sycl/detail/helpers.hpp>
13 #include <sycl/exception.hpp>
14 #include <sycl/id.hpp>
15 #include <sycl/marray.hpp>
16 
17 namespace sycl {
19 namespace detail {
20 class Builder;
21 } // namespace detail
22 
23 namespace ext::oneapi {
24 
25 #if defined(__SYCL_DEVICE_ONLY__) && defined(__AMDGCN__) && \
26  (__AMDGCN_WAVEFRONT_SIZE == 64)
27 #define BITS_TYPE uint64_t
28 #else
29 #define BITS_TYPE uint32_t
30 #endif
31 
// `group_ballot` is declared here, ahead of the `sub_group_mask` definition,
// so that the default argument `predicate = true` is attached to this first
// declaration. The friend declaration inside the class and the out-of-class
// definition further down repeat the signature without the default.
// `sub_group_mask` has to be forward-declared first because it appears in the
// return type. The enable_if restricts the template to `Group` types that
// decay to `sub_group`.
struct sub_group_mask;
template <typename Group>
detail::enable_if_t<std::is_same<std::decay_t<Group>, sub_group>::value,
 sub_group_mask>
group_ballot(Group g, bool predicate = true);
39 
41  friend class detail::Builder;
43 
44  static constexpr size_t max_bits =
45  sizeof(BitsType) * CHAR_BIT /* implementation-defined */;
46  static constexpr size_t word_size = sizeof(uint32_t) * CHAR_BIT;
47 
48  // enable reference to individual bit
49  struct reference {
50  reference &operator=(bool x) {
51  if (x) {
52  Ref |= RefBit;
53  } else {
54  Ref &= ~RefBit;
55  }
56  return *this;
57  }
59  operator=((bool)x);
60  return *this;
61  }
62  bool operator~() const { return !(Ref & RefBit); }
63  operator bool() const { return Ref & RefBit; }
65  operator=(!(bool)*this);
66  return *this;
67  }
68 
69  reference(sub_group_mask &gmask, size_t pos) : Ref(gmask.Bits) {
70  RefBit = (pos < gmask.bits_num) ? (1UL << pos) : 0;
71  }
72 
73  private:
74  // Reference to the word containing the bit
75  BitsType &Ref;
76  // Bit mask where only referenced bit is set
77  BitsType RefBit;
78  };
79 
80  bool operator[](id<1> id) const {
81  return (Bits & ((id.get(0) < bits_num) ? (1UL << id.get(0)) : 0));
82  }
83 
84  reference operator[](id<1> id) { return {*this, id.get(0)}; }
85  bool test(id<1> id) const { return operator[](id); }
86  bool all() const { return count() == bits_num; }
87  bool any() const { return count() != 0; }
88  bool none() const { return count() == 0; }
89  uint32_t count() const {
90  unsigned int count = 0;
91  auto word = (Bits & valuable_bits(bits_num));
92  while (word) {
93  word &= (word - 1);
94  count++;
95  }
96  return count;
97  }
98  uint32_t size() const { return bits_num; }
99  id<1> find_low() const {
100  size_t i = 0;
101  while (i < size() && !operator[](i))
102  i++;
103  return {i};
104  }
105  id<1> find_high() const {
106  size_t i = size() - 1;
107  while (i > 0 && !operator[](i))
108  i--;
109  return {operator[](i) ? i : size()};
110  }
111 
112  template <typename Type,
113  typename = sycl::detail::enable_if_t<std::is_integral<Type>::value>>
114  void insert_bits(Type bits, id<1> pos = 0) {
115  size_t insert_size = sizeof(Type) * CHAR_BIT;
116  BitsType insert_data = (BitsType)bits;
117  insert_data <<= pos.get(0);
118  BitsType mask = 0;
119  if (pos.get(0) + insert_size < size())
120  mask |= (valuable_bits(bits_num) << (pos.get(0) + insert_size));
121  if (pos.get(0) < size() && pos.get(0))
122  mask |= (valuable_bits(max_bits) >> (max_bits - pos.get(0)));
123  Bits &= mask;
124  Bits += insert_data;
125  }
126 
127  /* The bits are stored in the memory in the following way:
128  marray id | 0 | 1 | 2 | 3 |...
129  bit id |7 .. 0|15 .. 8|23 .. 16|31 .. 24|...
130  */
131  template <typename Type, size_t Size,
132  typename = sycl::detail::enable_if_t<std::is_integral<Type>::value>>
133  void insert_bits(const marray<Type, Size> &bits, id<1> pos = 0) {
134  size_t cur_pos = pos.get(0);
135  for (auto elem : bits) {
136  if (cur_pos < size()) {
137  this->insert_bits(elem, cur_pos);
138  cur_pos += sizeof(Type) * CHAR_BIT;
139  }
140  }
141  }
142 
143  template <typename Type,
144  typename = sycl::detail::enable_if_t<std::is_integral<Type>::value>>
145  void extract_bits(Type &bits, id<1> pos = 0) const {
146  auto Res = Bits;
147  Res &= valuable_bits(bits_num);
148  if (pos.get(0) < size()) {
149  if (pos.get(0) > 0) {
150  Res >>= pos.get(0);
151  }
152 
153  if (sizeof(Type) * CHAR_BIT < max_bits) {
154  Res &= valuable_bits(sizeof(Type) * CHAR_BIT);
155  }
156  bits = (Type)Res;
157  } else {
158  bits = 0;
159  }
160  }
161 
162  template <typename Type, size_t Size,
163  typename = sycl::detail::enable_if_t<std::is_integral<Type>::value>>
164  void extract_bits(marray<Type, Size> &bits, id<1> pos = 0) const {
165  size_t cur_pos = pos.get(0);
166  for (auto &elem : bits) {
167  if (cur_pos < size()) {
168  this->extract_bits(elem, cur_pos);
169  cur_pos += sizeof(Type) * CHAR_BIT;
170  } else {
171  elem = 0;
172  }
173  }
174  }
175 
176  void set() { Bits = valuable_bits(bits_num); }
177  void set(id<1> id, bool value = true) { operator[](id) = value; }
178  void reset() { Bits = BitsType{0}; }
179  void reset(id<1> id) { operator[](id) = 0; }
180  void reset_low() { reset(find_low()); }
181  void reset_high() { reset(find_high()); }
182  void flip() { Bits = (~Bits & valuable_bits(bits_num)); }
183  void flip(id<1> id) { operator[](id).flip(); }
184 
185  bool operator==(const sub_group_mask &rhs) const { return Bits == rhs.Bits; }
186  bool operator!=(const sub_group_mask &rhs) const { return !(*this == rhs); }
187 
189  Bits &= rhs.Bits;
190  return *this;
191  }
193  Bits |= rhs.Bits;
194  return *this;
195  }
196 
198  Bits ^= rhs.Bits;
199  Bits &= valuable_bits(bits_num);
200  return *this;
201  }
202 
204  Bits <<= pos;
205  Bits &= valuable_bits(bits_num);
206  return *this;
207  }
208 
210  Bits >>= pos;
211  return *this;
212  }
213 
215  auto Tmp = *this;
216  Tmp.flip();
217  return Tmp;
218  }
219  sub_group_mask operator<<(size_t pos) const {
220  auto Tmp = *this;
221  Tmp <<= pos;
222  return Tmp;
223  }
224  sub_group_mask operator>>(size_t pos) const {
225  auto Tmp = *this;
226  Tmp >>= pos;
227  return Tmp;
228  }
229 
231  : Bits(rhs.Bits), bits_num(rhs.bits_num) {}
232 
233  template <typename Group>
234  friend detail::enable_if_t<
235  std::is_same<std::decay_t<Group>, sub_group>::value, sub_group_mask>
236  group_ballot(Group g, bool predicate);
237 
239  const sub_group_mask &rhs) {
240  auto Res = lhs;
241  Res &= rhs;
242  return Res;
243  }
244 
246  const sub_group_mask &rhs) {
247  auto Res = lhs;
248  Res |= rhs;
249  return Res;
250  }
251 
253  const sub_group_mask &rhs) {
254  auto Res = lhs;
255  Res ^= rhs;
256  return Res;
257  }
258 
259 private:
260  sub_group_mask(BitsType rhs, size_t bn)
261  : Bits(rhs & valuable_bits(bn)), bits_num(bn) {
262  assert(bits_num <= max_bits);
263  }
264  inline BitsType valuable_bits(size_t bn) const {
265  assert(bn <= max_bits);
266  BitsType one = 1;
267  if (bn == max_bits)
268  return -one;
269  return (one << bn) - one;
270  }
271  BitsType Bits;
272  // Number of valuable bits
273  size_t bits_num;
274 };
275 
276 template <typename Group>
277 detail::enable_if_t<std::is_same<std::decay_t<Group>, sub_group>::value,
278  sub_group_mask>
279 group_ballot(Group g, bool predicate) {
280  (void)g;
281 #ifdef __SYCL_DEVICE_ONLY__
282  auto res = __spirv_GroupNonUniformBallot(
283  detail::spirv::group_scope<Group>::value, predicate);
284  BITS_TYPE val = res[0];
285  if constexpr (sizeof(BITS_TYPE) == 8)
286  val |= ((BITS_TYPE)res[1]) << 32;
287  return detail::Builder::createSubGroupMask<sub_group_mask>(
288  val, g.get_max_local_range()[0]);
289 #else
290  (void)predicate;
291  throw exception{errc::feature_not_supported,
292  "Sub-group mask is not supported on host device"};
293 #endif
294 }
295 
296 #undef BITS_TYPE
297 
298 } // namespace ext::oneapi
299 } // __SYCL_INLINE_VER_NAMESPACE(_V1)
300 } // namespace sycl
spirv_ops.hpp
sycl::_V1::ext::oneapi::sub_group_mask::insert_bits
void insert_bits(const marray< Type, Size > &bits, id< 1 > pos=0)
Definition: sub_group_mask.hpp:133
sycl::_V1::ext::oneapi::sub_group_mask::insert_bits
void insert_bits(Type bits, id< 1 > pos=0)
Definition: sub_group_mask.hpp:114
sycl::_V1::ext::oneapi::sub_group_mask::operator==
bool operator==(const sub_group_mask &rhs) const
Definition: sub_group_mask.hpp:185
sycl::_V1::detail::Builder
Definition: helpers.hpp:61
sycl::_V1::ext::oneapi::sub_group_mask::operator&=
sub_group_mask & operator&=(const sub_group_mask &rhs)
Definition: sub_group_mask.hpp:188
sycl::_V1::ext::oneapi::sub_group_mask::reset_low
void reset_low()
Definition: sub_group_mask.hpp:180
sycl::_V1::ext::oneapi::sub_group_mask::extract_bits
void extract_bits(marray< Type, Size > &bits, id< 1 > pos=0) const
Definition: sub_group_mask.hpp:164
sycl::_V1::elem
Definition: types.hpp:78
sycl::_V1::ext::oneapi::sub_group_mask::sub_group_mask
sub_group_mask(const sub_group_mask &rhs)
Definition: sub_group_mask.hpp:230
__SYCL_INLINE_VER_NAMESPACE
#define __SYCL_INLINE_VER_NAMESPACE(X)
Definition: defines_elementary.hpp:11
sycl::_V1::ext::oneapi::sub_group_mask::reference::operator=
reference & operator=(bool x)
Definition: sub_group_mask.hpp:50
sycl::_V1::ext::oneapi::sub_group_mask::count
uint32_t count() const
Definition: sub_group_mask.hpp:89
sycl::_V1::ext::oneapi::experimental::operator[]
T & operator[](std::ptrdiff_t idx) const noexcept
Definition: annotated_arg.hpp:160
sycl::_V1::ext::oneapi::sub_group_mask
Definition: sub_group_mask.hpp:40
sycl::_V1::ext::oneapi::sub_group_mask::reference::reference
reference(sub_group_mask &gmask, size_t pos)
Definition: sub_group_mask.hpp:69
helpers.hpp
spirv_vars.hpp
sycl
Error handling, matching OpenCL plugin semantics.
Definition: access.hpp:14
sycl::_V1::ext::oneapi::sub_group_mask::operator~
sub_group_mask operator~() const
Definition: sub_group_mask.hpp:214
sycl::_V1::id< 1 >
sycl::_V1::ext::oneapi::sub_group_mask::reset
void reset(id< 1 > id)
Definition: sub_group_mask.hpp:179
id.hpp
sycl::_V1::detail::enable_if_t
typename std::enable_if< B, T >::type enable_if_t
Definition: stl_type_traits.hpp:24
sycl::_V1::ext::oneapi::sub_group_mask::operator^=
sub_group_mask & operator^=(const sub_group_mask &rhs)
Definition: sub_group_mask.hpp:197
sycl::_V1::ext::oneapi::sub_group_mask::none
bool none() const
Definition: sub_group_mask.hpp:88
sycl::_V1::ext::oneapi::sub_group_mask::all
bool all() const
Definition: sub_group_mask.hpp:86
BITS_TYPE
#define BITS_TYPE
Definition: sub_group_mask.hpp:29
sycl::_V1::ext::oneapi::sub_group_mask::operator!=
bool operator!=(const sub_group_mask &rhs) const
Definition: sub_group_mask.hpp:186
sycl::_V1::ext::oneapi::sub_group_mask::BitsType
BITS_TYPE BitsType
Definition: sub_group_mask.hpp:42
sycl::_V1::ext::oneapi::sub_group_mask::operator<<=
sub_group_mask & operator<<=(size_t pos)
Definition: sub_group_mask.hpp:203
std::get
constexpr tuple_element< I, tuple< Types... > >::type & get(sycl::detail::tuple< Types... > &Arg) noexcept
Definition: tuple.hpp:199
sycl::_V1::ext::oneapi::sub_group_mask::size
uint32_t size() const
Definition: sub_group_mask.hpp:98
sycl::_V1::ext::oneapi::sub_group_mask::operator^
friend sub_group_mask operator^(const sub_group_mask &lhs, const sub_group_mask &rhs)
Definition: sub_group_mask.hpp:252
sycl::_V1::ext::oneapi::group_ballot
detail::enable_if_t< std::is_same< std::decay_t< Group >, sub_group >::value, sub_group_mask > group_ballot(Group g, bool predicate=true)
Definition: sub_group_mask.hpp:279
sycl::_V1::ext::oneapi::experimental::operator=
annotated_arg & operator=(annotated_arg &)=default
sycl::_V1::marray
Provides a cross-platform math array class template that works on SYCL devices as well as in host C++...
Definition: generic_type_lists.hpp:25
sycl::_V1::exception
Definition: exception.hpp:64
sycl::_V1::ext::oneapi::sub_group_mask::operator&
friend sub_group_mask operator&(const sub_group_mask &lhs, const sub_group_mask &rhs)
Definition: sub_group_mask.hpp:238
sycl::_V1::ext::oneapi::sub_group_mask::operator<<
sub_group_mask operator<<(size_t pos) const
Definition: sub_group_mask.hpp:219
sycl::_V1::ext::oneapi::sub_group_mask::set
void set()
Definition: sub_group_mask.hpp:176
sycl::_V1::ext::oneapi::sub_group_mask::set
void set(id< 1 > id, bool value=true)
Definition: sub_group_mask.hpp:177
sycl::_V1::ext::oneapi::sub_group_mask::operator>>
sub_group_mask operator>>(size_t pos) const
Definition: sub_group_mask.hpp:224
sycl::_V1::ext::oneapi::sub_group_mask::reset_high
void reset_high()
Definition: sub_group_mask.hpp:181
exception.hpp
sycl::_V1::ext::oneapi::sub_group_mask::reference::operator=
reference & operator=(const reference &x)
Definition: sub_group_mask.hpp:58
sycl::_V1::ext::oneapi::sub_group_mask::flip
void flip(id< 1 > id)
Definition: sub_group_mask.hpp:183
sycl::_V1::ext::oneapi::sub_group_mask::test
bool test(id< 1 > id) const
Definition: sub_group_mask.hpp:85
marray.hpp
sycl::_V1::ext::oneapi::sub_group_mask::flip
void flip()
Definition: sub_group_mask.hpp:182
sycl::_V1::ext::oneapi::sub_group
Definition: sub_group.hpp:131
sycl::_V1::ext::oneapi::sub_group_mask::reference::flip
reference & flip()
Definition: sub_group_mask.hpp:64
sycl::_V1::ext::oneapi::sub_group_mask::find_high
id< 1 > find_high() const
Definition: sub_group_mask.hpp:105
sycl::_V1::ext::oneapi::sub_group_mask::operator[]
bool operator[](id< 1 > id) const
Definition: sub_group_mask.hpp:80
sycl::_V1::ext::oneapi::sub_group_mask::reference::operator~
bool operator~() const
Definition: sub_group_mask.hpp:62
sycl::_V1::ext::oneapi::sub_group_mask::find_low
id< 1 > find_low() const
Definition: sub_group_mask.hpp:99
sycl::_V1::ext::oneapi::sub_group_mask::operator|
friend sub_group_mask operator|(const sub_group_mask &lhs, const sub_group_mask &rhs)
Definition: sub_group_mask.hpp:245
sycl::_V1::ext::oneapi::sub_group_mask::reference
Definition: sub_group_mask.hpp:49
sycl::_V1::ext::oneapi::sub_group_mask::operator[]
reference operator[](id< 1 > id)
Definition: sub_group_mask.hpp:84
sycl::_V1::ext::oneapi::sub_group_mask::operator>>=
sub_group_mask & operator>>=(size_t pos)
Definition: sub_group_mask.hpp:209
sycl::_V1::ext::oneapi::sub_group_mask::any
bool any() const
Definition: sub_group_mask.hpp:87
sycl::_V1::ext::oneapi::sub_group_mask::reset
void reset()
Definition: sub_group_mask.hpp:178
sycl::_V1::ext::oneapi::sub_group_mask::extract_bits
void extract_bits(Type &bits, id< 1 > pos=0) const
Definition: sub_group_mask.hpp:145
sycl::_V1::ext::oneapi::sub_group_mask::operator|=
sub_group_mask & operator|=(const sub_group_mask &rhs)
Definition: sub_group_mask.hpp:192