DPC++ Runtime
Runtime libraries for oneAPI DPC++
sub_group_mask.hpp
Go to the documentation of this file.
1 //==------------ sub_group_mask.hpp --- SYCL sub-group mask ----------------==//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 #pragma once
9 
10 #include <sycl/detail/helpers.hpp> // for Builder
11 #include <sycl/detail/memcpy.hpp> // detail::memcpy
12 #include <sycl/exception.hpp> // for errc, exception
13 #include <sycl/feature_test.hpp> // for SYCL_EXT_ONEAPI_SUB_GROUP_MASK
14 #include <sycl/id.hpp> // for id
15 #include <sycl/marray.hpp> // for marray
16 #include <sycl/types.hpp> // for vec
17 
18 #include <assert.h> // for assert
19 #include <climits> // for CHAR_BIT
20 #include <stddef.h> // for size_t
21 #include <stdint.h> // for uint32_t
22 #include <system_error> // for error_code
23 #include <type_traits> // for enable_if_t, decay_t
24 
25 namespace sycl {
26 inline namespace _V1 {
27 namespace detail {
28 class Builder;
29 
30 namespace spirv {
31 
32 template <typename Group> struct group_scope;
33 
34 } // namespace spirv
35 
36 } // namespace detail
37 
38 // forward declare sycl::sub_group
39 struct sub_group;
40 
41 namespace ext::oneapi {
42 
43 // forward declare sycl::ext::oneapi::sub_group
44 struct sub_group;
45 
46 // defining `group_ballot` here to make predicate default `true`
47 // need to forward declare sub_group_mask first
48 struct sub_group_mask;
49 template <typename Group>
50 std::enable_if_t<std::is_same_v<std::decay_t<Group>, sub_group> ||
51  std::is_same_v<std::decay_t<Group>, sycl::sub_group>,
52  sub_group_mask>
53 group_ballot(Group g, bool predicate = true);
54 
56  friend class sycl::detail::Builder;
57  using BitsType = uint64_t;
58 
59  static constexpr size_t max_bits =
60  sizeof(BitsType) * CHAR_BIT /* implementation-defined */;
61  static constexpr size_t word_size = sizeof(uint32_t) * CHAR_BIT;
62 
63  // enable reference to individual bit
64  struct reference {
65  reference &operator=(bool x) {
66  if (x) {
67  Ref |= RefBit;
68  } else {
69  Ref &= ~RefBit;
70  }
71  return *this;
72  }
74  operator=((bool)x);
75  return *this;
76  }
77  bool operator~() const { return !(Ref & RefBit); }
78  operator bool() const { return Ref & RefBit; }
80  operator=(!(bool)*this);
81  return *this;
82  }
83 
84  reference(sub_group_mask &gmask, size_t pos) : Ref(gmask.Bits) {
85  BitsType one = 1;
86  RefBit = (pos < gmask.bits_num) ? (one << pos) : 0;
87  }
88 
89  private:
90  // Reference to the word containing the bit
91  BitsType &Ref;
92  // Bit mask where only referenced bit is set
93  BitsType RefBit;
94  };
95 
#if SYCL_EXT_ONEAPI_SUB_GROUP_MASK >= 2
/// Constructs a mask with every bit cleared. size() is the maximum sub-group
/// size in device code and max_bits on the host.
sub_group_mask() : sub_group_mask(0, GetMaxLocalRangeSize()) {}

/// Constructs a mask whose bit i equals bit i of \p val.
/// Bits of \p val at positions >= size() are discarded. This keeps the
/// invariant (also kept by set(), flip(), operator^= and operator<<=) that no
/// storage bit beyond size() is ever set; operator== compares raw storage and
/// relies on it.
sub_group_mask(unsigned long long val)
    : sub_group_mask(static_cast<BitsType>(val), GetMaxLocalRangeSize()) {}

/// Constructs a mask from the bytes of \p val's elements, copied in element
/// order into the mask storage until the storage is full. Bits that land at
/// positions >= size() are discarded, for the same reason as above.
template <typename T, std::size_t K,
          typename = std::enable_if_t<std::is_integral_v<T>>>
sub_group_mask(const sycl::marray<T, K> &val)
    : sub_group_mask(0, GetMaxLocalRangeSize()) {
  for (size_t I = 0, BytesCopied = 0; I < K && BytesCopied < sizeof(Bits);
       ++I) {
    size_t RemainingBytes = sizeof(Bits) - BytesCopied;
    size_t BytesToCopy =
        RemainingBytes < sizeof(T) ? RemainingBytes : sizeof(T);
    sycl::detail::memcpy(reinterpret_cast<char *>(&Bits) + BytesCopied,
                         &val[I], BytesToCopy);
    BytesCopied += BytesToCopy;
  }
  // The raw copy can set bits beyond size() when size() < max_bits.
  Bits &= valuable_bits(bits_num);
}

sub_group_mask(const sub_group_mask &other) = default;
sub_group_mask &operator=(const sub_group_mask &other) = default;
#endif // SYCL_EXT_ONEAPI_SUB_GROUP_MASK
122 
123  bool operator[](id<1> id) const {
124  BitsType one = 1;
125  return (Bits & ((id.get(0) < bits_num) ? (one << id.get(0)) : 0));
126  }
127 
128  reference operator[](id<1> id) { return {*this, id.get(0)}; }
129  bool test(id<1> id) const { return operator[](id); }
130  bool all() const { return count() == bits_num; }
131  bool any() const { return count() != 0; }
132  bool none() const { return count() == 0; }
133  uint32_t count() const {
134 #if defined(__SYCL_DEVICE_ONLY__) && defined(__NVPTX__)
135  sycl::marray<unsigned, 4> TmpMArray;
136  this->extract_bits(TmpMArray);
137  sycl::vec<unsigned, 4> MemberMask;
138  for (int i = 0; i < 4; ++i) {
139  MemberMask[i] = TmpMArray[i];
140  }
141  return __spirv_GroupNonUniformBallotBitCount(
144 #else
145  unsigned int count = 0;
146  auto word = (Bits & valuable_bits(bits_num));
147  while (word) {
148  word &= (word - 1);
149  count++;
150  }
151  return count;
152 #endif
153  }
154  uint32_t size() const { return bits_num; }
155  id<1> find_low() const {
156  size_t i = 0;
157  while (i < size() && !operator[](i))
158  i++;
159  return {i};
160  }
161  id<1> find_high() const {
162  size_t i = size() - 1;
163  while (i > 0 && !operator[](i))
164  i--;
165  return {operator[](i) ? i : size()};
166  }
167 
168  template <typename Type,
169  typename = std::enable_if_t<std::is_integral_v<Type>>>
170  void insert_bits(Type bits, id<1> pos = 0) {
171  size_t insert_size = sizeof(Type) * CHAR_BIT;
172  BitsType insert_data = (BitsType)bits;
173  insert_data <<= pos.get(0);
174  BitsType mask = 0;
175  if (pos.get(0) + insert_size < size())
176  mask |= (valuable_bits(bits_num) << (pos.get(0) + insert_size));
177  if (pos.get(0) < size() && pos.get(0))
178  mask |= (valuable_bits(max_bits) >> (max_bits - pos.get(0)));
179  Bits &= mask;
180  Bits += insert_data;
181  }
182 
183  /* The bits are stored in the memory in the following way:
184  marray id | 0 | 1 | 2 | 3 |...
185  bit id |7 .. 0|15 .. 8|23 .. 16|31 .. 24|...
186  */
187  template <typename Type, size_t Size,
188  typename = std::enable_if_t<std::is_integral_v<Type>>>
189  void insert_bits(const marray<Type, Size> &bits, id<1> pos = 0) {
190  size_t cur_pos = pos.get(0);
191  for (auto elem : bits) {
192  if (cur_pos < size()) {
193  this->insert_bits(elem, cur_pos);
194  cur_pos += sizeof(Type) * CHAR_BIT;
195  }
196  }
197  }
198 
199  template <typename Type,
200  typename = std::enable_if_t<std::is_integral_v<Type>>>
201  void extract_bits(Type &bits, id<1> pos = 0) const {
202  auto Res = Bits;
203  Res &= valuable_bits(bits_num);
204  if (pos.get(0) < size()) {
205  if (pos.get(0) > 0) {
206  Res >>= pos.get(0);
207  }
208 
209  if (sizeof(Type) * CHAR_BIT < max_bits) {
210  Res &= valuable_bits(sizeof(Type) * CHAR_BIT);
211  }
212  bits = (Type)Res;
213  } else {
214  bits = 0;
215  }
216  }
217 
218  template <typename Type, size_t Size,
219  typename = std::enable_if_t<std::is_integral_v<Type>>>
220  void extract_bits(marray<Type, Size> &bits, id<1> pos = 0) const {
221  size_t cur_pos = pos.get(0);
222  for (auto &elem : bits) {
223  if (cur_pos < size()) {
224  this->extract_bits(elem, cur_pos);
225  cur_pos += sizeof(Type) * CHAR_BIT;
226  } else {
227  elem = 0;
228  }
229  }
230  }
231 
232  void set() { Bits = valuable_bits(bits_num); }
233  void set(id<1> id, bool value = true) { operator[](id) = value; }
234  void reset() { Bits = BitsType{0}; }
235  void reset(id<1> id) { operator[](id) = 0; }
236  void reset_low() { reset(find_low()); }
237  void reset_high() { reset(find_high()); }
238  void flip() { Bits = (~Bits & valuable_bits(bits_num)); }
239  void flip(id<1> id) { operator[](id).flip(); }
240 
241  bool operator==(const sub_group_mask &rhs) const { return Bits == rhs.Bits; }
242  bool operator!=(const sub_group_mask &rhs) const { return !(*this == rhs); }
243 
245  Bits &= rhs.Bits;
246  return *this;
247  }
249  Bits |= rhs.Bits;
250  return *this;
251  }
252 
254  Bits ^= rhs.Bits;
255  Bits &= valuable_bits(bits_num);
256  return *this;
257  }
258 
260  Bits <<= pos;
261  Bits &= valuable_bits(bits_num);
262  return *this;
263  }
264 
266  Bits >>= pos;
267  return *this;
268  }
269 
271  auto Tmp = *this;
272  Tmp.flip();
273  return Tmp;
274  }
275  sub_group_mask operator<<(size_t pos) const {
276  auto Tmp = *this;
277  Tmp <<= pos;
278  return Tmp;
279  }
280  sub_group_mask operator>>(size_t pos) const {
281  auto Tmp = *this;
282  Tmp >>= pos;
283  return Tmp;
284  }
285 
286  template <typename Group>
287  friend std::enable_if_t<std::is_same_v<std::decay_t<Group>, sub_group>,
289  group_ballot(Group g, bool predicate);
290 
292  const sub_group_mask &rhs) {
293  auto Res = lhs;
294  Res &= rhs;
295  return Res;
296  }
297 
299  const sub_group_mask &rhs) {
300  auto Res = lhs;
301  Res |= rhs;
302  return Res;
303  }
304 
306  const sub_group_mask &rhs) {
307  auto Res = lhs;
308  Res ^= rhs;
309  return Res;
310  }
311 
312 private:
313  static size_t GetMaxLocalRangeSize() {
314 #ifdef __SYCL_DEVICE_ONLY__
315  return __spirv_SubgroupMaxSize();
316 #else
317  return max_bits;
318 #endif
319  }
320 
321  sub_group_mask(BitsType rhs, size_t bn)
322  : Bits(rhs & valuable_bits(bn)), bits_num(bn) {
323  assert(bits_num <= max_bits);
324  }
325  inline BitsType valuable_bits(size_t bn) const {
326  assert(bn <= max_bits);
327  BitsType one = 1;
328  if (bn == max_bits)
329  return -one;
330  return (one << bn) - one;
331  }
332  BitsType Bits;
333  // Number of valuable bits
334  size_t bits_num;
335 };
336 
337 template <typename Group>
338 std::enable_if_t<std::is_same_v<std::decay_t<Group>, sub_group> ||
339  std::is_same_v<std::decay_t<Group>, sycl::sub_group>,
340  sub_group_mask>
341 group_ballot(Group g, bool predicate) {
342  (void)g;
343 #ifdef __SYCL_DEVICE_ONLY__
344  auto res = __spirv_GroupNonUniformBallot(
346  sub_group_mask::BitsType val = res[0];
347  if constexpr (sizeof(sub_group_mask::BitsType) == 8)
348  val |= ((sub_group_mask::BitsType)res[1]) << 32;
349  return sycl::detail::Builder::createSubGroupMask<sub_group_mask>(
350  val, g.get_max_local_range()[0]);
351 #else
352  (void)predicate;
354  "Sub-group mask is not supported on host device"};
355 #endif
356 }
357 
358 } // namespace ext::oneapi
359 } // namespace _V1
360 } // namespace sycl
Provides a cross-platform math array class template that works on SYCL devices as well as in host C++...
Definition: marray.hpp:49
class sycl::vec ///////////////////////// Provides a cross-platform vector class template that works e...
Definition: vector.hpp:361
std::enable_if_t< std::is_same_v< std::decay_t< Group >, sub_group >||std::is_same_v< std::decay_t< Group >, sycl::sub_group >, sub_group_mask > group_ballot(Group g, bool predicate=true)
pointer get() const
Definition: multi_ptr.hpp:544
PropertyListT int access::address_space multi_ptr & operator=(multi_ptr &&)=default
autodecltype(x) x
Definition: access.hpp:18
reference(sub_group_mask &gmask, size_t pos)
sub_group_mask operator>>(size_t pos) const
sub_group_mask & operator^=(const sub_group_mask &rhs)
bool operator!=(const sub_group_mask &rhs) const
sub_group_mask & operator>>=(size_t pos)
sub_group_mask & operator|=(const sub_group_mask &rhs)
friend sub_group_mask operator^(const sub_group_mask &lhs, const sub_group_mask &rhs)
sub_group_mask & operator<<=(size_t pos)
friend class sycl::detail::Builder
void extract_bits(Type &bits, id< 1 > pos=0) const
bool operator==(const sub_group_mask &rhs) const
friend std::enable_if_t< std::is_same_v< std::decay_t< Group >, sub_group >, sub_group_mask > group_ballot(Group g, bool predicate)
void insert_bits(const marray< Type, Size > &bits, id< 1 > pos=0)
void extract_bits(marray< Type, Size > &bits, id< 1 > pos=0) const
friend sub_group_mask operator&(const sub_group_mask &lhs, const sub_group_mask &rhs)
void set(id< 1 > id, bool value=true)
friend sub_group_mask operator|(const sub_group_mask &lhs, const sub_group_mask &rhs)
sub_group_mask & operator&=(const sub_group_mask &rhs)
void insert_bits(Type bits, id< 1 > pos=0)
sub_group_mask operator<<(size_t pos) const