DPC++ Runtime
Runtime libraries for oneAPI DPC++
sub_group_mask.hpp
Go to the documentation of this file.
1 //==------------ sub_group_mask.hpp --- SYCL sub-group mask ----------------==//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 #pragma once
9 
10 #include <CL/__spirv/spirv_ops.hpp>
12 #include <sycl/detail/helpers.hpp>
13 #include <sycl/exception.hpp>
14 #include <sycl/id.hpp>
15 #include <sycl/marray.hpp>
16 
17 namespace sycl {
19 namespace detail {
20 class Builder;
21 } // namespace detail
22 
23 namespace ext {
24 namespace oneapi {
25 
27  friend class detail::Builder;
28  static constexpr size_t max_bits = 32 /* implementation-defined */;
29  static constexpr size_t word_size = sizeof(uint32_t) * CHAR_BIT;
30 
31  // enable reference to individual bit
32  struct reference {
33  reference &operator=(bool x) {
34  if (x) {
35  Ref |= RefBit;
36  } else {
37  Ref &= ~RefBit;
38  }
39  return *this;
40  }
42  operator=((bool)x);
43  return *this;
44  }
45  bool operator~() const { return !(Ref & RefBit); }
46  operator bool() const { return Ref & RefBit; }
48  operator=(!(bool)*this);
49  return *this;
50  }
51 
52  reference(sub_group_mask &gmask, size_t pos) : Ref(gmask.Bits) {
53  RefBit = (pos < gmask.bits_num) ? (1UL << pos) : 0;
54  }
55 
56  private:
57  // Reference to the word containing the bit
58  uint32_t &Ref;
59  // Bit mask where only referenced bit is set
60  uint32_t RefBit;
61  };
62 
63  bool operator[](id<1> id) const {
64  return (Bits & ((id.get(0) < bits_num) ? (1UL << id.get(0)) : 0));
65  }
66 
67  reference operator[](id<1> id) { return {*this, id.get(0)}; }
68  bool test(id<1> id) const { return operator[](id); }
69  bool all() const { return count() == bits_num; }
70  bool any() const { return count() != 0; }
71  bool none() const { return count() == 0; }
72  uint32_t count() const {
73  unsigned int count = 0;
74  auto word = (Bits & valuable_bits(bits_num));
75  while (word) {
76  word &= (word - 1);
77  count++;
78  }
79  return count;
80  }
81  uint32_t size() const { return bits_num; }
82  id<1> find_low() const {
83  size_t i = 0;
84  while (i < size() && !operator[](i))
85  i++;
86  return {i};
87  }
88  id<1> find_high() const {
89  size_t i = size() - 1;
90  while (i > 0 && !operator[](i))
91  i--;
92  return {operator[](i) ? i : size()};
93  }
94 
95  template <typename Type,
96  typename = sycl::detail::enable_if_t<std::is_integral<Type>::value>>
97  void insert_bits(Type bits, id<1> pos = 0) {
98  size_t insert_size = sizeof(Type) * CHAR_BIT;
99  uint32_t insert_data = (uint32_t)bits;
100  insert_data <<= pos.get(0);
101  uint32_t mask = 0;
102  if (pos.get(0) + insert_size < size())
103  mask |= (valuable_bits(bits_num) << (pos.get(0) + insert_size));
104  if (pos.get(0) < size() && pos.get(0))
105  mask |= (valuable_bits(max_bits) >> (max_bits - pos.get(0)));
106  Bits &= mask;
107  Bits += insert_data;
108  }
109 
110  /* The bits are stored in the memory in the following way:
111  marray id | 0 | 1 | 2 | 3 |
112  bit id |7 .. 0|15 .. 8|23 .. 16|31 .. 24|
113  */
114  template <typename Type, size_t Size,
115  typename = sycl::detail::enable_if_t<std::is_integral<Type>::value>>
116  void insert_bits(const marray<Type, Size> &bits, id<1> pos = 0) {
117  size_t cur_pos = pos.get(0);
118  for (auto elem : bits) {
119  if (cur_pos < size()) {
120  this->insert_bits(elem, cur_pos);
121  cur_pos += sizeof(Type) * CHAR_BIT;
122  }
123  }
124  }
125 
126  template <typename Type,
127  typename = sycl::detail::enable_if_t<std::is_integral<Type>::value>>
128  void extract_bits(Type &bits, id<1> pos = 0) const {
129  auto Res = Bits;
130  Res &= valuable_bits(bits_num);
131  if (pos.get(0) < size()) {
132  if (pos.get(0) > 0) {
133  Res >>= pos.get(0);
134  }
135 
136  if (sizeof(Type) * CHAR_BIT < max_bits) {
137  Res &= valuable_bits(sizeof(Type) * CHAR_BIT);
138  }
139  bits = (Type)Res;
140  } else {
141  bits = 0;
142  }
143  }
144 
145  template <typename Type, size_t Size,
146  typename = sycl::detail::enable_if_t<std::is_integral<Type>::value>>
147  void extract_bits(marray<Type, Size> &bits, id<1> pos = 0) const {
148  size_t cur_pos = pos.get(0);
149  for (auto &elem : bits) {
150  if (cur_pos < size()) {
151  this->extract_bits(elem, cur_pos);
152  cur_pos += sizeof(Type) * CHAR_BIT;
153  } else {
154  elem = 0;
155  }
156  }
157  }
158 
159  void set() { Bits = valuable_bits(bits_num); }
160  void set(id<1> id, bool value = true) { operator[](id) = value; }
161  void reset() { Bits = uint32_t{0}; }
162  void reset(id<1> id) { operator[](id) = 0; }
163  void reset_low() { reset(find_low()); }
164  void reset_high() { reset(find_high()); }
165  void flip() { Bits = (~Bits & valuable_bits(bits_num)); }
166  void flip(id<1> id) { operator[](id).flip(); }
167 
168  bool operator==(const sub_group_mask &rhs) const { return Bits == rhs.Bits; }
169  bool operator!=(const sub_group_mask &rhs) const { return !(*this == rhs); }
170 
172  Bits &= rhs.Bits;
173  return *this;
174  }
176  Bits |= rhs.Bits;
177  return *this;
178  }
179 
181  Bits ^= rhs.Bits;
182  Bits &= valuable_bits(bits_num);
183  return *this;
184  }
185 
187  Bits <<= pos;
188  Bits &= valuable_bits(bits_num);
189  return *this;
190  }
191 
193  Bits >>= pos;
194  return *this;
195  }
196 
198  auto Tmp = *this;
199  Tmp.flip();
200  return Tmp;
201  }
202  sub_group_mask operator<<(size_t pos) const {
203  auto Tmp = *this;
204  Tmp <<= pos;
205  return Tmp;
206  }
207  sub_group_mask operator>>(size_t pos) const {
208  auto Tmp = *this;
209  Tmp >>= pos;
210  return Tmp;
211  }
212 
214  : Bits(rhs.Bits), bits_num(rhs.bits_num) {}
215 
216  template <typename Group>
217  friend detail::enable_if_t<
218  std::is_same<std::decay_t<Group>, sub_group>::value, sub_group_mask>
219  group_ballot(Group g, bool predicate);
220 
222  const sub_group_mask &rhs) {
223  auto Res = lhs;
224  Res &= rhs;
225  return Res;
226  }
227 
229  const sub_group_mask &rhs) {
230  auto Res = lhs;
231  Res |= rhs;
232  return Res;
233  }
234 
236  const sub_group_mask &rhs) {
237  auto Res = lhs;
238  Res ^= rhs;
239  return Res;
240  }
241 
242 private:
243  sub_group_mask(uint32_t rhs, size_t bn) : Bits(rhs), bits_num(bn) {
244  assert(bits_num <= max_bits);
245  }
246  inline uint32_t valuable_bits(size_t bn) const {
247  return static_cast<uint32_t>((1ULL << bn) - 1ULL);
248  }
249  uint32_t Bits;
250  // Number of valuable bits
251  size_t bits_num;
252 };
253 
254 template <typename Group>
255 detail::enable_if_t<std::is_same<std::decay_t<Group>, sub_group>::value,
256  sub_group_mask>
257 group_ballot(Group g, bool predicate) {
258  (void)g;
259 #ifdef __SYCL_DEVICE_ONLY__
260  auto res = __spirv_GroupNonUniformBallot(
261  detail::spirv::group_scope<Group>::value, predicate);
262  return detail::Builder::createSubGroupMask<sub_group_mask>(
263  res[0], g.get_max_local_range()[0]);
264 #else
265  (void)predicate;
266  throw exception{errc::feature_not_supported,
267  "Sub-group mask is not supported on host device"};
268 #endif
269 }
270 
271 } // namespace oneapi
272 } // namespace ext
273 } // __SYCL_INLINE_VER_NAMESPACE(_V1)
274 } // namespace sycl
Provides a cross-platform math array class template that works on SYCL devices as well as in host C++ ...
Definition: marray.hpp:24
#define __SYCL_INLINE_VER_NAMESPACE(X)
constexpr tuple_element< I, tuple< Types... > >::type & get(sycl::detail::tuple< Types... > &Arg) noexcept
Definition: tuple.hpp:199
typename std::enable_if< B, T >::type enable_if_t
detail::enable_if_t< std::is_same< std::decay_t< Group >, sub_group >::value, sub_group_mask > group_ballot(Group g, bool predicate)
Error handling, matching OpenCL plugin semantics.
Definition: access.hpp:14
reference(sub_group_mask &gmask, size_t pos)
sub_group_mask(const sub_group_mask &rhs)
sub_group_mask operator>>(size_t pos) const
sub_group_mask & operator^=(const sub_group_mask &rhs)
bool operator!=(const sub_group_mask &rhs) const
sub_group_mask & operator>>=(size_t pos)
sub_group_mask & operator|=(const sub_group_mask &rhs)
friend sub_group_mask operator^(const sub_group_mask &lhs, const sub_group_mask &rhs)
sub_group_mask & operator<<=(size_t pos)
void extract_bits(Type &bits, id< 1 > pos=0) const
bool operator==(const sub_group_mask &rhs) const
void insert_bits(const marray< Type, Size > &bits, id< 1 > pos=0)
void extract_bits(marray< Type, Size > &bits, id< 1 > pos=0) const
friend sub_group_mask operator&(const sub_group_mask &lhs, const sub_group_mask &rhs)
void set(id< 1 > id, bool value=true)
friend sub_group_mask operator|(const sub_group_mask &lhs, const sub_group_mask &rhs)
sub_group_mask & operator&=(const sub_group_mask &rhs)
void insert_bits(Type bits, id< 1 > pos=0)
sub_group_mask operator<<(size_t pos) const