13 #include <sycl/feature_test.hpp>
22 #include <system_error>
23 #include <type_traits>
26 inline namespace _V1 {
41 namespace ext::oneapi {
48 struct sub_group_mask;
49 template <
typename Group>
50 std::enable_if_t<std::is_same_v<std::decay_t<Group>,
sub_group> ||
/// Number of bits in one 32-bit word (CHAR_BIT bits per byte times the
/// byte size of uint32_t).
static constexpr size_t word_size = CHAR_BIT * sizeof(uint32_t);
78 operator bool()
const {
return Ref & RefBit; }
86 RefBit = (pos < gmask.bits_num) ? (one << pos) : 0;
96 #if SYCL_EXT_ONEAPI_SUB_GROUP_MASK >= 2
99 sub_group_mask(
unsigned long long val)
100 : sub_group_mask(0, GetMaxLocalRangeSize()) {
104 template <
typename T, std::size_t K,
105 typename = std::enable_if_t<std::is_integral_v<T>>>
107 : sub_group_mask(0, GetMaxLocalRangeSize()) {
108 for (
size_t I = 0, BytesCopied = 0; I < K && BytesCopied <
sizeof(Bits);
110 size_t RemainingBytes =
sizeof(Bits) - BytesCopied;
112 RemainingBytes <
sizeof(T) ? RemainingBytes :
sizeof(T);
113 sycl::detail::memcpy(
reinterpret_cast<char *
>(&Bits) + BytesCopied,
114 &val[I], BytesToCopy);
115 BytesCopied += BytesToCopy;
119 sub_group_mask(
const sub_group_mask &other) =
default;
120 sub_group_mask &
operator=(
const sub_group_mask &other) =
default;
125 return (Bits & ((
id.
get(0) < bits_num) ? (one <<
id.
get(0)) : 0));
134 #if defined(__SYCL_DEVICE_ONLY__) && defined(__NVPTX__)
138 for (
int i = 0; i < 4; ++i) {
139 MemberMask[i] = TmpMArray[i];
141 return __spirv_GroupNonUniformBallotBitCount(
145 unsigned int count = 0;
146 auto word = (Bits & valuable_bits(bits_num));
154 uint32_t
size()
const {
return bits_num; }
157 while (i <
size() && !
operator[](i))
162 size_t i =
size() - 1;
163 while (i > 0 && !
operator[](i))
168 template <
typename Type,
169 typename = std::enable_if_t<std::is_integral_v<Type>>>
171 size_t insert_size =
sizeof(Type) * CHAR_BIT;
173 insert_data <<= pos.get(0);
175 if (pos.get(0) + insert_size <
size())
176 mask |= (valuable_bits(bits_num) << (pos.get(0) + insert_size));
177 if (pos.get(0) <
size() && pos.get(0))
187 template <
typename Type,
size_t Size,
188 typename = std::enable_if_t<std::is_integral_v<Type>>>
190 size_t cur_pos = pos.get(0);
191 for (
auto elem : bits) {
192 if (cur_pos <
size()) {
194 cur_pos +=
sizeof(Type) * CHAR_BIT;
199 template <
typename Type,
200 typename = std::enable_if_t<std::is_integral_v<Type>>>
203 Res &= valuable_bits(bits_num);
204 if (pos.get(0) <
size()) {
205 if (pos.get(0) > 0) {
209 if (
sizeof(Type) * CHAR_BIT <
max_bits) {
210 Res &= valuable_bits(
sizeof(Type) * CHAR_BIT);
218 template <
typename Type,
size_t Size,
219 typename = std::enable_if_t<std::is_integral_v<Type>>>
221 size_t cur_pos = pos.get(0);
222 for (
auto &
elem : bits) {
223 if (cur_pos <
size()) {
225 cur_pos +=
sizeof(Type) * CHAR_BIT;
232 void set() { Bits = valuable_bits(bits_num); }
238 void flip() { Bits = (~Bits & valuable_bits(bits_num)); }
255 Bits &= valuable_bits(bits_num);
261 Bits &= valuable_bits(bits_num);
286 template <
typename Group>
287 friend std::enable_if_t<std::is_same_v<std::decay_t<Group>,
sub_group>,
313 static size_t GetMaxLocalRangeSize() {
314 #ifdef __SYCL_DEVICE_ONLY__
315 return __spirv_SubgroupMaxSize();
321 sub_group_mask(
BitsType rhs,
size_t bn)
322 : Bits(rhs & valuable_bits(bn)), bits_num(bn) {
325 inline BitsType valuable_bits(
size_t bn)
const {
330 return (one << bn) - one;
337 template <
typename Group>
338 std::enable_if_t<std::is_same_v<std::decay_t<Group>, sub_group> ||
343 #ifdef __SYCL_DEVICE_ONLY__
344 auto res = __spirv_GroupNonUniformBallot(
349 return sycl::detail::Builder::createSubGroupMask<sub_group_mask>(
350 val, g.get_max_local_range()[0]);
354 "Sub-group mask is not supported on host device"};
Provides a cross-platform math array class template that works on SYCL devices as well as in host C++...
class sycl::vec ///////////////////////// Provides a cross-platform vector class template that works e...
auto convertToOpenCLType(T &&x)
std::enable_if_t< std::is_same_v< std::decay_t< Group >, sub_group >||std::is_same_v< std::decay_t< Group >, sycl::sub_group >, sub_group_mask > group_ballot(Group g, bool predicate=true)
PropertyListT int access::address_space multi_ptr & operator=(multi_ptr &&)=default
reference & operator=(bool x)
reference(sub_group_mask &gmask, size_t pos)
reference & operator=(const reference &x)
sub_group_mask operator>>(size_t pos) const
sub_group_mask & operator^=(const sub_group_mask &rhs)
bool operator!=(const sub_group_mask &rhs) const
sub_group_mask & operator>>=(size_t pos)
bool test(id< 1 > id) const
sub_group_mask & operator|=(const sub_group_mask &rhs)
friend sub_group_mask operator^(const sub_group_mask &lhs, const sub_group_mask &rhs)
sub_group_mask & operator<<=(size_t pos)
friend class sycl::detail::Builder
sub_group_mask operator~() const
void extract_bits(Type &bits, id< 1 > pos=0) const
static constexpr size_t max_bits
bool operator==(const sub_group_mask &rhs) const
friend std::enable_if_t< std::is_same_v< std::decay_t< Group >, sub_group >, sub_group_mask > group_ballot(Group g, bool predicate)
static constexpr size_t word_size
bool operator[](id< 1 > id) const
void insert_bits(const marray< Type, Size > &bits, id< 1 > pos=0)
id< 1 > find_high() const
void extract_bits(marray< Type, Size > &bits, id< 1 > pos=0) const
friend sub_group_mask operator&(const sub_group_mask &lhs, const sub_group_mask &rhs)
void set(id< 1 > id, bool value=true)
friend sub_group_mask operator|(const sub_group_mask &lhs, const sub_group_mask &rhs)
sub_group_mask & operator&=(const sub_group_mask &rhs)
void insert_bits(Type bits, id< 1 > pos=0)
sub_group_mask operator<<(size_t pos) const
reference operator[](id< 1 > id)