DPC++ Runtime
Runtime libraries for oneAPI DPC++
ballot_group.hpp
Go to the documentation of this file.
1 //==------ ballot_group.hpp --- SYCL extension for non-uniform groups ------==//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #pragma once
10 
11 #include <sycl/aspects.hpp>
12 #include <sycl/detail/pi.h> // for PI_ERROR_INVALID_DEVICE
13 #include <sycl/detail/type_traits.hpp> // for is_group, is_user_cons...
14 #include <sycl/exception.hpp> // for runtime_error
16 #include <sycl/ext/oneapi/sub_group_mask.hpp> // for sub_group_mask
17 #include <sycl/id.hpp> // for id
18 #include <sycl/memory_enums.hpp> // for memory_scope
19 #include <sycl/range.hpp> // for range
20 #include <sycl/sub_group.hpp> // for sub_group
21 
22 #include <type_traits> // for enable_if_t, decay_t
23 
24 namespace sycl {
25 inline namespace _V1 {
26 namespace ext::oneapi::experimental {
27 
// Forward declaration of the ballot_group class template so that the
// get_ballot_group() declaration that follows can name it in its return type.
28 template <typename ParentGroup> class ballot_group;
29 
// Forward declaration of the factory for ballot_group<Group>.
// Takes the parent group by value plus a boolean predicate.
// The enable_if_t constraint admits the overload only when
// is_group_v<decay_t<Group>> holds and Group is exactly sycl::sub_group.
// When compiling for device (__SYCL_DEVICE_ONLY__) the declaration is tagged
// as using the ext_oneapi_ballot_group aspect.
30 template <typename Group>
31 #ifdef __SYCL_DEVICE_ONLY__
32 [[__sycl_detail__::__uses_aspects__(sycl::aspect::ext_oneapi_ballot_group)]]
33 #endif
34 inline std::enable_if_t<sycl::is_group_v<std::decay_t<Group>> &&
35  std::is_same_v<Group, sycl::sub_group>,
36  ballot_group<Group>> get_ballot_group(Group group,
37  bool predicate);
38 
// Group type constructed from a sub_group_mask and a boolean predicate (see
// the protected constructor). Every member body has the same shape: real work
// under __SYCL_DEVICE_ONLY__, and on host a runtime_error carrying
// PI_ERROR_INVALID_DEVICE.
// NOTE(review): this listing has gaps in its embedded line numbers (42, 45,
// 47, 56, 65, 74, 83, 92, 101, 110, 130, 135, 139). The signature line of most
// member functions and the declaration of the Mask member are therefore not
// visible here; the notes below describe only the bodies that remain, and any
// member name given for them is an assumption to confirm against the header.
39 template <typename ParentGroup> class ballot_group {
40 public:
41  using id_type = id<1>;
43  using linear_id_type = typename ParentGroup::linear_id_type;
44  static constexpr int dimensions = 1;
46 
// Signature missing from this listing (presumably get_group_id(), which is
// called further down). Device: 1 when Predicate is true, otherwise 0.
48 #ifdef __SYCL_DEVICE_ONLY__
49  return (Predicate) ? 1 : 0;
50 #else
51  throw runtime_error("Non-uniform groups are not supported on host device.",
52  PI_ERROR_INVALID_DEVICE);
53 #endif
54  }
55 
// Signature missing from this listing (presumably get_local_id()).
// Device: forwards Mask to sycl::detail::CallerPositionInMask.
57 #ifdef __SYCL_DEVICE_ONLY__
58  return sycl::detail::CallerPositionInMask(Mask);
59 #else
60  throw runtime_error("Non-uniform groups are not supported on host device.",
61  PI_ERROR_INVALID_DEVICE);
62 #endif
63  }
64 
// Signature missing from this listing (presumably get_group_range()).
// Device: always 2.
66 #ifdef __SYCL_DEVICE_ONLY__
67  return 2;
68 #else
69  throw runtime_error("Non-uniform groups are not supported on host device.",
70  PI_ERROR_INVALID_DEVICE);
71 #endif
72  }
73 
// Signature missing from this listing (presumably get_local_range()).
// Device: Mask.count().
75 #ifdef __SYCL_DEVICE_ONLY__
76  return Mask.count();
77 #else
78  throw runtime_error("Non-uniform groups are not supported on host device.",
79  PI_ERROR_INVALID_DEVICE);
80 #endif
81  }
82 
// Signature missing from this listing. Device: first component of
// get_group_id(), converted to linear_id_type.
84 #ifdef __SYCL_DEVICE_ONLY__
85  return static_cast<linear_id_type>(get_group_id()[0]);
86 #else
87  throw runtime_error("Non-uniform groups are not supported on host device.",
88  PI_ERROR_INVALID_DEVICE);
89 #endif
90  }
91 
// Signature missing from this listing. Device: first component of
// get_local_id(), converted to linear_id_type.
93 #ifdef __SYCL_DEVICE_ONLY__
94  return static_cast<linear_id_type>(get_local_id()[0]);
95 #else
96  throw runtime_error("Non-uniform groups are not supported on host device.",
97  PI_ERROR_INVALID_DEVICE);
98 #endif
99  }
100 
// Signature missing from this listing. Device: first component of
// get_group_range(), converted to linear_id_type.
102 #ifdef __SYCL_DEVICE_ONLY__
103  return static_cast<linear_id_type>(get_group_range()[0]);
104 #else
105  throw runtime_error("Non-uniform groups are not supported on host device.",
106  PI_ERROR_INVALID_DEVICE);
107 #endif
108  }
109 
// Signature missing from this listing. Device: first component of
// get_local_range(), converted to linear_id_type.
111 #ifdef __SYCL_DEVICE_ONLY__
112  return static_cast<linear_id_type>(get_local_range()[0]);
113 #else
114  throw runtime_error("Non-uniform groups are not supported on host device.",
115  PI_ERROR_INVALID_DEVICE);
116 #endif
117  }
118 
// Device: true when the caller's __spirv_SubgroupLocalInvocationId() equals
// the first component of Mask.find_low() (presumably the lowest-numbered
// member recorded in Mask -- confirm against sub_group_mask).
// Host: throws runtime_error with PI_ERROR_INVALID_DEVICE.
119  bool leader() const {
120 #ifdef __SYCL_DEVICE_ONLY__
121  uint32_t Lowest = static_cast<uint32_t>(Mask.find_low()[0]);
122  return __spirv_SubgroupLocalInvocationId() == Lowest;
123 #else
124  throw runtime_error("Non-uniform groups are not supported on host device.",
125  PI_ERROR_INVALID_DEVICE);
126 #endif
127  }
128 
129 protected:
// NOTE(review): the declaration of Mask (listing line 130) is not visible
// here; the constructor below initializes it from a sub_group_mask.
131  const bool Predicate;
132 
// Stores the membership mask and the predicate value. Declared in the
// protected section, so it is not callable by arbitrary user code.
133  ballot_group(sub_group_mask m, bool p) : Mask(m), Predicate(p) {}
134 
// Tail of a declaration whose first line (listing line 135) is missing;
// presumably a friend declaration for get_ballot_group -- confirm.
136  get_ballot_group<ParentGroup>(ParentGroup g, bool predicate);
137 
// Friend declaration for sycl::detail::GetMask; its parameter list continues
// on listing line 139, which is missing from this view.
138  friend sub_group_mask sycl::detail::GetMask<ballot_group<ParentGroup>>(
140 };
141 
// Definition of get_ballot_group(): builds the ballot_group that the calling
// work-item belongs to, selected by `predicate`.
// Host builds throw runtime_error with PI_ERROR_INVALID_DEVICE.
// NOTE(review): listing lines 145 (end of the return type) and 152 (the
// declaration of `mask`) are missing from this view, so where `mask` comes
// from cannot be confirmed here.
// NOTE(review): when __SYCL_DEVICE_ONLY__ is defined but neither __SPIR__ nor
// __NVPTX__ is, no return statement is compiled into this function.
142 template <typename Group>
143 inline std::enable_if_t<sycl::is_group_v<std::decay_t<Group>> &&
144  std::is_same_v<Group, sycl::sub_group>,
146 get_ballot_group(Group group, bool predicate) {
// `group` is not otherwise referenced on the host path.
147  (void)group;
148 #ifdef __SYCL_DEVICE_ONLY__
149 #if defined(__SPIR__) || defined(__NVPTX__)
150  // ballot_group partitions into two groups using the predicate
151  // Membership mask for one group is negation of the other
153  if (predicate) {
154  return ballot_group<sycl::sub_group>(mask, predicate);
155  } else {
156  // To negate the mask for the false-predicate group, we also need to exclude
157  // all parts of the mask that are not part of the group.
// The filter is an all-ones BitsType shifted right by
// (max_bits - group.get_local_linear_range()); assuming BitsType is max_bits
// wide, that leaves only the low get_local_linear_range() bits set.
158  sub_group_mask::BitsType participant_filter =
159  (~sub_group_mask::BitsType{0}) >>
160  (sub_group_mask::max_bits - group.get_local_linear_range());
161  return ballot_group<sycl::sub_group>((~mask) & participant_filter,
162  predicate);
163  }
164 #endif
165 #else
166  (void)predicate;
167  throw runtime_error("Non-uniform groups are not supported on host device.",
168  PI_ERROR_INVALID_DEVICE);
169 #endif
170 }
171 
// Trait specialization: is_user_constructed_group is std::true_type for every
// ballot_group<ParentGroup>.
172 template <typename ParentGroup>
173 struct is_user_constructed_group<ballot_group<ParentGroup>> : std::true_type {};
174 
175 } // namespace ext::oneapi::experimental
176 
// Trait specialization, written outside the ext::oneapi::experimental
// namespace: is_group is std::true_type for every ballot_group<ParentGroup>.
// The is_group_v check in get_ballot_group's constraint is presumably the
// variable-template form of this trait.
177 template <typename ParentGroup>
178 struct is_group<ext::oneapi::experimental::ballot_group<ParentGroup>>
179  : std::true_type {};
180 
181 } // namespace _V1
182 } // namespace sycl
static constexpr sycl::memory_scope fence_scope
typename ParentGroup::linear_id_type linear_id_type
fence_scope
The scope that the fence() operation should apply to.
Definition: common.hpp:350
std::enable_if_t< sycl::is_group_v< std::decay_t< Group > > &&std::is_same_v< Group, sycl::sub_group >, ballot_group< Group > > get_ballot_group(Group group, bool predicate)
std::enable_if_t< std::is_same_v< std::decay_t< Group >, sub_group >||std::is_same_v< std::decay_t< Group >, sycl::sub_group >, sub_group_mask > group_ballot(Group g, bool predicate=true)
Definition: access.hpp:18