23 namespace ext::oneapi {
25 #if defined(__SYCL_DEVICE_ONLY__) && defined(__AMDGCN__) && \
26 (__AMDGCN_WAVEFRONT_SIZE == 64)
27 #define BITS_TYPE uint64_t
29 #define BITS_TYPE uint32_t
34 struct sub_group_mask;
35 template <
typename Group>
36 detail::enable_if_t<std::is_same<std::decay_t<Group>, sub_group>::value,
44 static constexpr
size_t max_bits =
46 static constexpr
size_t word_size =
sizeof(uint32_t) * CHAR_BIT;
63 operator bool()
const {
return Ref & RefBit; }
70 RefBit = (pos < gmask.bits_num) ? (1UL << pos) : 0;
81 return (Bits & ((
id.
get(0) < bits_num) ? (1UL <<
id.
get(0)) : 0));
86 bool all()
const {
return count() == bits_num; }
87 bool any()
const {
return count() != 0; }
88 bool none()
const {
return count() == 0; }
90 unsigned int count = 0;
91 auto word = (Bits & valuable_bits(bits_num));
98 uint32_t
size()
const {
return bits_num; }
101 while (i < size() && !
operator[](i))
106 size_t i = size() - 1;
107 while (i > 0 && !
operator[](i))
112 template <
typename Type,
113 typename = sycl::detail::enable_if_t<std::is_integral<Type>::value>>
115 size_t insert_size =
sizeof(Type) * CHAR_BIT;
117 insert_data <<= pos.get(0);
119 if (pos.get(0) + insert_size < size())
120 mask |= (valuable_bits(bits_num) << (pos.get(0) + insert_size));
121 if (pos.get(0) < size() && pos.get(0))
122 mask |= (valuable_bits(max_bits) >> (max_bits - pos.get(0)));
131 template <
typename Type,
size_t Size,
132 typename = sycl::detail::enable_if_t<std::is_integral<Type>::value>>
134 size_t cur_pos = pos.get(0);
135 for (
auto elem : bits) {
136 if (cur_pos < size()) {
137 this->insert_bits(
elem, cur_pos);
138 cur_pos +=
sizeof(Type) * CHAR_BIT;
143 template <
typename Type,
144 typename = sycl::detail::enable_if_t<std::is_integral<Type>::value>>
147 Res &= valuable_bits(bits_num);
148 if (pos.get(0) < size()) {
149 if (pos.get(0) > 0) {
153 if (
sizeof(Type) * CHAR_BIT < max_bits) {
154 Res &= valuable_bits(
sizeof(Type) * CHAR_BIT);
162 template <
typename Type,
size_t Size,
163 typename = sycl::detail::enable_if_t<std::is_integral<Type>::value>>
165 size_t cur_pos = pos.get(0);
166 for (
auto &
elem : bits) {
167 if (cur_pos < size()) {
168 this->extract_bits(
elem, cur_pos);
169 cur_pos +=
sizeof(Type) * CHAR_BIT;
176 void set() { Bits = valuable_bits(bits_num); }
182 void flip() { Bits = (~Bits & valuable_bits(bits_num)); }
199 Bits &= valuable_bits(bits_num);
205 Bits &= valuable_bits(bits_num);
231 : Bits(rhs.Bits), bits_num(rhs.bits_num) {}
233 template <
typename Group>
261 : Bits(rhs & valuable_bits(bn)), bits_num(bn) {
262 assert(bits_num <= max_bits);
264 inline BitsType valuable_bits(
size_t bn)
const {
265 assert(bn <= max_bits);
269 return (one << bn) - one;
276 template <
typename Group>
277 detail::enable_if_t<std::is_same<std::decay_t<Group>, sub_group>::value,
281 #ifdef __SYCL_DEVICE_ONLY__
282 auto res = __spirv_GroupNonUniformBallot(
283 detail::spirv::group_scope<Group>::value, predicate);
287 return detail::Builder::createSubGroupMask<sub_group_mask>(
288 val, g.get_max_local_range()[0]);
291 throw exception{errc::feature_not_supported,
292 "Sub-group mask is not supported on host device"};