14 #include <sycl/feature_test.hpp>
23 #include <system_error>
24 #include <type_traits>
27 inline namespace _V1 {
42 namespace ext::oneapi {
49 struct sub_group_mask;
50 template <
typename Group>
51 std::enable_if_t<std::is_same_v<std::decay_t<Group>,
sub_group> ||
62 static constexpr
size_t word_size =
sizeof(uint32_t) * CHAR_BIT;
79 operator bool()
const {
return Ref & RefBit; }
87 RefBit = (pos < gmask.bits_num) ? (one << pos) : 0;
97 #if SYCL_EXT_ONEAPI_SUB_GROUP_MASK >= 2
100 sub_group_mask(
unsigned long long val)
101 : sub_group_mask(0, GetMaxLocalRangeSize()) {
105 template <
typename T, std::size_t K,
106 typename = std::enable_if_t<std::is_integral_v<T>>>
108 : sub_group_mask(0, GetMaxLocalRangeSize()) {
109 for (
size_t I = 0, BytesCopied = 0; I < K && BytesCopied <
sizeof(Bits);
111 size_t RemainingBytes =
sizeof(Bits) - BytesCopied;
113 RemainingBytes <
sizeof(T) ? RemainingBytes :
sizeof(T);
114 sycl::detail::memcpy(
reinterpret_cast<char *
>(&Bits) + BytesCopied,
115 &val[I], BytesToCopy);
116 BytesCopied += BytesToCopy;
120 sub_group_mask(
const sub_group_mask &other) =
default;
121 sub_group_mask &
operator=(
const sub_group_mask &other) =
default;
126 return (Bits & ((
id.
get(0) < bits_num) ? (one <<
id.
get(0)) : 0));
135 #if defined(__SYCL_DEVICE_ONLY__) && defined(__NVPTX__)
139 for (
int i = 0; i < 4; ++i) {
140 MemberMask[i] = TmpMArray[i];
142 return __spirv_GroupNonUniformBallotBitCount(
146 unsigned int count = 0;
147 auto word = (Bits & valuable_bits(bits_num));
155 uint32_t
size()
const {
return bits_num; }
158 while (i <
size() && !
operator[](i))
163 size_t i =
size() - 1;
164 while (i > 0 && !
operator[](i))
169 template <
typename Type,
170 typename = std::enable_if_t<std::is_integral_v<Type>>>
172 size_t insert_size =
sizeof(Type) * CHAR_BIT;
174 insert_data <<= pos.get(0);
176 if (pos.get(0) + insert_size <
size())
177 mask |= (valuable_bits(bits_num) << (pos.get(0) + insert_size));
178 if (pos.get(0) <
size() && pos.get(0))
188 template <
typename Type,
size_t Size,
189 typename = std::enable_if_t<std::is_integral_v<Type>>>
191 size_t cur_pos = pos.get(0);
192 for (
auto elem : bits) {
193 if (cur_pos <
size()) {
195 cur_pos +=
sizeof(Type) * CHAR_BIT;
200 template <
typename Type,
201 typename = std::enable_if_t<std::is_integral_v<Type>>>
204 Res &= valuable_bits(bits_num);
205 if (pos.get(0) <
size()) {
206 if (pos.get(0) > 0) {
210 if (
sizeof(Type) * CHAR_BIT <
max_bits) {
211 Res &= valuable_bits(
sizeof(Type) * CHAR_BIT);
219 template <
typename Type,
size_t Size,
220 typename = std::enable_if_t<std::is_integral_v<Type>>>
222 size_t cur_pos = pos.get(0);
223 for (
auto &
elem : bits) {
224 if (cur_pos <
size()) {
226 cur_pos +=
sizeof(Type) * CHAR_BIT;
233 void set() { Bits = valuable_bits(bits_num); }
239 void flip() { Bits = (~Bits & valuable_bits(bits_num)); }
256 Bits &= valuable_bits(bits_num);
262 Bits &= valuable_bits(bits_num);
287 template <
typename Group>
288 friend std::enable_if_t<std::is_same_v<std::decay_t<Group>,
sub_group>,
314 static size_t GetMaxLocalRangeSize() {
315 #ifdef __SYCL_DEVICE_ONLY__
316 return __spirv_SubgroupMaxSize();
322 sub_group_mask(
BitsType rhs,
size_t bn)
323 : Bits(rhs & valuable_bits(bn)), bits_num(bn) {
326 inline BitsType valuable_bits(
size_t bn)
const {
331 return (one << bn) - one;
338 template <
typename Group>
339 std::enable_if_t<std::is_same_v<std::decay_t<Group>, sub_group> ||
344 #ifdef __SYCL_DEVICE_ONLY__
345 auto res = __spirv_GroupNonUniformBallot(
350 return sycl::detail::Builder::createSubGroupMask<sub_group_mask>(
351 val, g.get_max_local_range()[0]);
355 "Sub-group mask is not supported on host device"};
Provides a cross-platform math array class template that works on SYCL devices as well as in host C++...
auto convertToOpenCLType(T &&x)
std::enable_if_t< std::is_same_v< std::decay_t< Group >, sub_group >||std::is_same_v< std::decay_t< Group >, sycl::sub_group >, sub_group_mask > group_ballot(Group g, bool predicate=true)
PropertyListT int access::address_space multi_ptr & operator=(multi_ptr &&)=default
reference & operator=(bool x)
reference(sub_group_mask &gmask, size_t pos)
reference & operator=(const reference &x)
sub_group_mask operator>>(size_t pos) const
sub_group_mask & operator^=(const sub_group_mask &rhs)
bool operator!=(const sub_group_mask &rhs) const
sub_group_mask & operator>>=(size_t pos)
bool test(id< 1 > id) const
sub_group_mask & operator|=(const sub_group_mask &rhs)
friend sub_group_mask operator^(const sub_group_mask &lhs, const sub_group_mask &rhs)
sub_group_mask & operator<<=(size_t pos)
friend class sycl::detail::Builder
sub_group_mask operator~() const
void extract_bits(Type &bits, id< 1 > pos=0) const
static constexpr size_t max_bits
bool operator==(const sub_group_mask &rhs) const
friend std::enable_if_t< std::is_same_v< std::decay_t< Group >, sub_group >, sub_group_mask > group_ballot(Group g, bool predicate)
static constexpr size_t word_size
bool operator[](id< 1 > id) const
void insert_bits(const marray< Type, Size > &bits, id< 1 > pos=0)
id< 1 > find_high() const
void extract_bits(marray< Type, Size > &bits, id< 1 > pos=0) const
friend sub_group_mask operator&(const sub_group_mask &lhs, const sub_group_mask &rhs)
void set(id< 1 > id, bool value=true)
friend sub_group_mask operator|(const sub_group_mask &lhs, const sub_group_mask &rhs)
sub_group_mask & operator&=(const sub_group_mask &rhs)
void insert_bits(Type bits, id< 1 > pos=0)
sub_group_mask operator<<(size_t pos) const
reference operator[](id< 1 > id)