DPC++ Runtime
Runtime libraries for oneAPI DPC++
sub_group_mask.hpp
Go to the documentation of this file.
1 //==------------ sub_group_mask.hpp --- SYCL sub-group mask ----------------==//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 #pragma once
9 
10 #include <sycl/builtins.hpp> // for assert
11 #include <sycl/detail/helpers.hpp> // for Builder
12 #include <sycl/detail/memcpy.hpp> // detail::memcpy
13 #include <sycl/exception.hpp> // for errc, exception
14 #include <sycl/feature_test.hpp> // for SYCL_EXT_ONEAPI_SUB_GROUP_MASK
15 #include <sycl/id.hpp> // for id
16 #include <sycl/marray.hpp> // for marray
17 #include <sycl/types.hpp> // for vec
18 
19 #include <assert.h> // for assert
20 #include <climits> // for CHAR_BIT
21 #include <stddef.h> // for size_t
22 #include <stdint.h> // for uint32_t
23 #include <system_error> // for error_code
24 #include <type_traits> // for enable_if_t, decay_t
25 
26 namespace sycl {
27 inline namespace _V1 {
28 namespace detail {
29 class Builder;
30 
31 namespace spirv {
32 
33 template <typename Group> struct group_scope;
34 
35 } // namespace spirv
36 
37 } // namespace detail
38 
39 // forward declare sycl::sub_group
40 struct sub_group;
41 
42 namespace ext::oneapi {
43 
44 // forward declare sycl::ext::oneapi::sub_group
45 struct sub_group;
46 
47 // defining `group_ballot` here to make predicate default `true`
48 // need to forward declare sub_group_mask first
49 struct sub_group_mask;
50 template <typename Group>
51 std::enable_if_t<std::is_same_v<std::decay_t<Group>, sub_group> ||
52  std::is_same_v<std::decay_t<Group>, sycl::sub_group>,
53  sub_group_mask>
54 group_ballot(Group g, bool predicate = true);
55 
57  friend class sycl::detail::Builder;
58  using BitsType = uint64_t;
59 
60  static constexpr size_t max_bits =
61  sizeof(BitsType) * CHAR_BIT /* implementation-defined */;
62  static constexpr size_t word_size = sizeof(uint32_t) * CHAR_BIT;
63 
64  // enable reference to individual bit
65  struct reference {
66  reference &operator=(bool x) {
67  if (x) {
68  Ref |= RefBit;
69  } else {
70  Ref &= ~RefBit;
71  }
72  return *this;
73  }
75  operator=((bool)x);
76  return *this;
77  }
78  bool operator~() const { return !(Ref & RefBit); }
79  operator bool() const { return Ref & RefBit; }
81  operator=(!(bool)*this);
82  return *this;
83  }
84 
85  reference(sub_group_mask &gmask, size_t pos) : Ref(gmask.Bits) {
86  BitsType one = 1;
87  RefBit = (pos < gmask.bits_num) ? (one << pos) : 0;
88  }
89 
90  private:
91  // Reference to the word containing the bit
92  BitsType &Ref;
93  // Bit mask where only referenced bit is set
94  BitsType RefBit;
95  };
96 
97 #if SYCL_EXT_ONEAPI_SUB_GROUP_MASK >= 2
98  sub_group_mask() : sub_group_mask(0, GetMaxLocalRangeSize()){};
99 
100  sub_group_mask(unsigned long long val)
101  : sub_group_mask(0, GetMaxLocalRangeSize()) {
102  Bits = val;
103  };
104 
105  template <typename T, std::size_t K,
106  typename = std::enable_if_t<std::is_integral_v<T>>>
107  sub_group_mask(const sycl::marray<T, K> &val)
108  : sub_group_mask(0, GetMaxLocalRangeSize()) {
109  for (size_t I = 0, BytesCopied = 0; I < K && BytesCopied < sizeof(Bits);
110  ++I) {
111  size_t RemainingBytes = sizeof(Bits) - BytesCopied;
112  size_t BytesToCopy =
113  RemainingBytes < sizeof(T) ? RemainingBytes : sizeof(T);
114  sycl::detail::memcpy(reinterpret_cast<char *>(&Bits) + BytesCopied,
115  &val[I], BytesToCopy);
116  BytesCopied += BytesToCopy;
117  }
118  }
119 
120  sub_group_mask(const sub_group_mask &other) = default;
121  sub_group_mask &operator=(const sub_group_mask &other) = default;
122 #endif // SYCL_EXT_ONEAPI_SUB_GROUP_MASK
123 
124  bool operator[](id<1> id) const {
125  BitsType one = 1;
126  return (Bits & ((id.get(0) < bits_num) ? (one << id.get(0)) : 0));
127  }
128 
129  reference operator[](id<1> id) { return {*this, id.get(0)}; }
130  bool test(id<1> id) const { return operator[](id); }
131  bool all() const { return count() == bits_num; }
132  bool any() const { return count() != 0; }
133  bool none() const { return count() == 0; }
134  uint32_t count() const {
135 #if defined(__SYCL_DEVICE_ONLY__) && defined(__NVPTX__)
136  sycl::marray<unsigned, 4> TmpMArray;
137  this->extract_bits(TmpMArray);
138  sycl::vec<unsigned, 4> MemberMask;
139  for (int i = 0; i < 4; ++i) {
140  MemberMask[i] = TmpMArray[i];
141  }
142  return __spirv_GroupNonUniformBallotBitCount(
145 #else
146  unsigned int count = 0;
147  auto word = (Bits & valuable_bits(bits_num));
148  while (word) {
149  word &= (word - 1);
150  count++;
151  }
152  return count;
153 #endif
154  }
155  uint32_t size() const { return bits_num; }
156  id<1> find_low() const {
157  size_t i = 0;
158  while (i < size() && !operator[](i))
159  i++;
160  return {i};
161  }
162  id<1> find_high() const {
163  size_t i = size() - 1;
164  while (i > 0 && !operator[](i))
165  i--;
166  return {operator[](i) ? i : size()};
167  }
168 
169  template <typename Type,
170  typename = std::enable_if_t<std::is_integral_v<Type>>>
171  void insert_bits(Type bits, id<1> pos = 0) {
172  size_t insert_size = sizeof(Type) * CHAR_BIT;
173  BitsType insert_data = (BitsType)bits;
174  insert_data <<= pos.get(0);
175  BitsType mask = 0;
176  if (pos.get(0) + insert_size < size())
177  mask |= (valuable_bits(bits_num) << (pos.get(0) + insert_size));
178  if (pos.get(0) < size() && pos.get(0))
179  mask |= (valuable_bits(max_bits) >> (max_bits - pos.get(0)));
180  Bits &= mask;
181  Bits += insert_data;
182  }
183 
184  /* The bits are stored in the memory in the following way:
185  marray id | 0 | 1 | 2 | 3 |...
186  bit id |7 .. 0|15 .. 8|23 .. 16|31 .. 24|...
187  */
188  template <typename Type, size_t Size,
189  typename = std::enable_if_t<std::is_integral_v<Type>>>
190  void insert_bits(const marray<Type, Size> &bits, id<1> pos = 0) {
191  size_t cur_pos = pos.get(0);
192  for (auto elem : bits) {
193  if (cur_pos < size()) {
194  this->insert_bits(elem, cur_pos);
195  cur_pos += sizeof(Type) * CHAR_BIT;
196  }
197  }
198  }
199 
200  template <typename Type,
201  typename = std::enable_if_t<std::is_integral_v<Type>>>
202  void extract_bits(Type &bits, id<1> pos = 0) const {
203  auto Res = Bits;
204  Res &= valuable_bits(bits_num);
205  if (pos.get(0) < size()) {
206  if (pos.get(0) > 0) {
207  Res >>= pos.get(0);
208  }
209 
210  if (sizeof(Type) * CHAR_BIT < max_bits) {
211  Res &= valuable_bits(sizeof(Type) * CHAR_BIT);
212  }
213  bits = (Type)Res;
214  } else {
215  bits = 0;
216  }
217  }
218 
219  template <typename Type, size_t Size,
220  typename = std::enable_if_t<std::is_integral_v<Type>>>
221  void extract_bits(marray<Type, Size> &bits, id<1> pos = 0) const {
222  size_t cur_pos = pos.get(0);
223  for (auto &elem : bits) {
224  if (cur_pos < size()) {
225  this->extract_bits(elem, cur_pos);
226  cur_pos += sizeof(Type) * CHAR_BIT;
227  } else {
228  elem = 0;
229  }
230  }
231  }
232 
233  void set() { Bits = valuable_bits(bits_num); }
234  void set(id<1> id, bool value = true) { operator[](id) = value; }
235  void reset() { Bits = BitsType{0}; }
236  void reset(id<1> id) { operator[](id) = 0; }
237  void reset_low() { reset(find_low()); }
238  void reset_high() { reset(find_high()); }
239  void flip() { Bits = (~Bits & valuable_bits(bits_num)); }
240  void flip(id<1> id) { operator[](id).flip(); }
241 
242  bool operator==(const sub_group_mask &rhs) const { return Bits == rhs.Bits; }
243  bool operator!=(const sub_group_mask &rhs) const { return !(*this == rhs); }
244 
246  Bits &= rhs.Bits;
247  return *this;
248  }
250  Bits |= rhs.Bits;
251  return *this;
252  }
253 
255  Bits ^= rhs.Bits;
256  Bits &= valuable_bits(bits_num);
257  return *this;
258  }
259 
261  Bits <<= pos;
262  Bits &= valuable_bits(bits_num);
263  return *this;
264  }
265 
267  Bits >>= pos;
268  return *this;
269  }
270 
272  auto Tmp = *this;
273  Tmp.flip();
274  return Tmp;
275  }
276  sub_group_mask operator<<(size_t pos) const {
277  auto Tmp = *this;
278  Tmp <<= pos;
279  return Tmp;
280  }
281  sub_group_mask operator>>(size_t pos) const {
282  auto Tmp = *this;
283  Tmp >>= pos;
284  return Tmp;
285  }
286 
287  template <typename Group>
288  friend std::enable_if_t<std::is_same_v<std::decay_t<Group>, sub_group>,
290  group_ballot(Group g, bool predicate);
291 
293  const sub_group_mask &rhs) {
294  auto Res = lhs;
295  Res &= rhs;
296  return Res;
297  }
298 
300  const sub_group_mask &rhs) {
301  auto Res = lhs;
302  Res |= rhs;
303  return Res;
304  }
305 
307  const sub_group_mask &rhs) {
308  auto Res = lhs;
309  Res ^= rhs;
310  return Res;
311  }
312 
313 private:
314  static size_t GetMaxLocalRangeSize() {
315 #ifdef __SYCL_DEVICE_ONLY__
316  return __spirv_SubgroupMaxSize();
317 #else
318  return max_bits;
319 #endif
320  }
321 
322  sub_group_mask(BitsType rhs, size_t bn)
323  : Bits(rhs & valuable_bits(bn)), bits_num(bn) {
324  assert(bits_num <= max_bits);
325  }
326  inline BitsType valuable_bits(size_t bn) const {
327  assert(bn <= max_bits);
328  BitsType one = 1;
329  if (bn == max_bits)
330  return -one;
331  return (one << bn) - one;
332  }
333  BitsType Bits;
334  // Number of valuable bits
335  size_t bits_num;
336 };
337 
338 template <typename Group>
339 std::enable_if_t<std::is_same_v<std::decay_t<Group>, sub_group> ||
340  std::is_same_v<std::decay_t<Group>, sycl::sub_group>,
341  sub_group_mask>
342 group_ballot(Group g, bool predicate) {
343  (void)g;
344 #ifdef __SYCL_DEVICE_ONLY__
345  auto res = __spirv_GroupNonUniformBallot(
347  sub_group_mask::BitsType val = res[0];
348  if constexpr (sizeof(sub_group_mask::BitsType) == 8)
349  val |= ((sub_group_mask::BitsType)res[1]) << 32;
350  return sycl::detail::Builder::createSubGroupMask<sub_group_mask>(
351  val, g.get_max_local_range()[0]);
352 #else
353  (void)predicate;
355  "Sub-group mask is not supported on host device"};
356 #endif
357 }
358 
359 } // namespace ext::oneapi
360 } // namespace _V1
361 } // namespace sycl
362 
363 // We have a cyclic dependency with
364 // sub_group_mask.hpp
365 // detail/spirv.hpp
366 // non_uniform_groups.hpp
367 // "Break" it by including this at the end (instead of beginning). Ideally, we
368 // should refactor this somehow...
369 #include <sycl/detail/spirv.hpp>
Provides a cross-platform math array class template that works on SYCL devices as well as in host C++...
Definition: marray.hpp:49
std::enable_if_t< std::is_same_v< std::decay_t< Group >, sub_group >||std::is_same_v< std::decay_t< Group >, sycl::sub_group >, sub_group_mask > group_ballot(Group g, bool predicate=true)
pointer get() const
Definition: multi_ptr.hpp:544
PropertyListT int access::address_space multi_ptr & operator=(multi_ptr &&)=default
autodecltype(x) x
Definition: access.hpp:18
reference(sub_group_mask &gmask, size_t pos)
sub_group_mask operator>>(size_t pos) const
sub_group_mask & operator^=(const sub_group_mask &rhs)
bool operator!=(const sub_group_mask &rhs) const
sub_group_mask & operator>>=(size_t pos)
sub_group_mask & operator|=(const sub_group_mask &rhs)
friend sub_group_mask operator^(const sub_group_mask &lhs, const sub_group_mask &rhs)
sub_group_mask & operator<<=(size_t pos)
friend class sycl::detail::Builder
void extract_bits(Type &bits, id< 1 > pos=0) const
bool operator==(const sub_group_mask &rhs) const
friend std::enable_if_t< std::is_same_v< std::decay_t< Group >, sub_group >, sub_group_mask > group_ballot(Group g, bool predicate)
void insert_bits(const marray< Type, Size > &bits, id< 1 > pos=0)
void extract_bits(marray< Type, Size > &bits, id< 1 > pos=0) const
friend sub_group_mask operator&(const sub_group_mask &lhs, const sub_group_mask &rhs)
void set(id< 1 > id, bool value=true)
friend sub_group_mask operator|(const sub_group_mask &lhs, const sub_group_mask &rhs)
sub_group_mask & operator&=(const sub_group_mask &rhs)
void insert_bits(Type bits, id< 1 > pos=0)
sub_group_mask operator<<(size_t pos) const