11 #if defined(__SYCL_DEVICE_ONLY__) && defined(__NVPTX__)
15 inline namespace _V1 {
18 template <
typename T,
class BinaryOperation>
19 using IsRedux = std::bool_constant<
20 std::is_integral<T>::value && IsBitAND<T, BinaryOperation>::value ||
21 IsBitOR<T, BinaryOperation>::value || IsBitXOR<T, BinaryOperation>::value ||
22 IsPlus<T, BinaryOperation>::value || IsMinimum<T, BinaryOperation>::value ||
23 IsMaximum<T, BinaryOperation>::value>;
// redux.sync-based masked reduction: minimum over the lanes in MemberMask,
// selected for unsigned integer T (is_sugeninteger_v) with a minimum
// operation. Forwards x and MemberMask to __nvvm_redux_sync_umin; g and
// binary_op only drive overload selection and are not used in the body.
// NOTE(review): PTX redux.sync is specified for 32-bit operands (sm_80+);
// confirm callers never reach this overload with 64-bit T.
27 template <
typename Group,
typename T,
class BinaryOperation>
28 std::enable_if_t<is_sugeninteger_v<T> && IsMinimum<T, BinaryOperation>::value,
30 masked_reduction_cuda_sm80(Group g, T x, BinaryOperation binary_op,
31 const uint32_t MemberMask) {
32 return __nvvm_redux_sync_umin(x, MemberMask);
// redux.sync-based masked reduction: minimum over the lanes in MemberMask,
// selected for signed integer T (is_sigeninteger_v) with a minimum operation.
// Forwards x and MemberMask to the signed __nvvm_redux_sync_min; g and
// binary_op only drive overload selection and are not used in the body.
// NOTE(review): PTX redux.sync is specified for 32-bit operands (sm_80+);
// confirm callers never reach this overload with 64-bit T.
35 template <
typename Group,
typename T,
class BinaryOperation>
36 std::enable_if_t<is_sigeninteger_v<T> && IsMinimum<T, BinaryOperation>::value,
38 masked_reduction_cuda_sm80(Group g, T x, BinaryOperation binary_op,
39 const uint32_t MemberMask) {
40 return __nvvm_redux_sync_min(x, MemberMask);
// redux.sync-based masked reduction: maximum over the lanes in MemberMask,
// selected for unsigned integer T (is_sugeninteger_v) with a maximum
// operation. Forwards x and MemberMask to __nvvm_redux_sync_umax; g and
// binary_op only drive overload selection and are not used in the body.
// NOTE(review): PTX redux.sync is specified for 32-bit operands (sm_80+);
// confirm callers never reach this overload with 64-bit T.
43 template <
typename Group,
typename T,
class BinaryOperation>
44 std::enable_if_t<is_sugeninteger_v<T> && IsMaximum<T, BinaryOperation>::value,
46 masked_reduction_cuda_sm80(Group g, T x, BinaryOperation binary_op,
47 const uint32_t MemberMask) {
48 return __nvvm_redux_sync_umax(x, MemberMask);
// redux.sync-based masked reduction: maximum over the lanes in MemberMask,
// selected for signed integer T (is_sigeninteger_v) with a maximum operation.
// Forwards x and MemberMask to the signed __nvvm_redux_sync_max; g and
// binary_op only drive overload selection and are not used in the body.
// NOTE(review): PTX redux.sync is specified for 32-bit operands (sm_80+);
// confirm callers never reach this overload with 64-bit T.
51 template <
typename Group,
typename T,
class BinaryOperation>
52 std::enable_if_t<is_sigeninteger_v<T> && IsMaximum<T, BinaryOperation>::value,
54 masked_reduction_cuda_sm80(Group g, T x, BinaryOperation binary_op,
55 const uint32_t MemberMask) {
56 return __nvvm_redux_sync_max(x, MemberMask);
// redux.sync-based masked reduction: sum over the lanes in MemberMask, for
// signed or unsigned integer T with a plus operation. A single builtin serves
// both signednesses, unlike the min/max overloads. g and binary_op only drive
// overload selection and are not used in the body.
// NOTE(review): PTX redux.sync is specified for 32-bit operands (sm_80+);
// confirm callers never reach this overload with 64-bit T.
59 template <
typename Group,
typename T,
class BinaryOperation>
60 std::enable_if_t<(is_sugeninteger_v<T> ||
61 is_sigeninteger_v<T>)&&IsPlus<T, BinaryOperation>::value,
63 masked_reduction_cuda_sm80(Group g, T x, BinaryOperation binary_op,
64 const uint32_t MemberMask) {
65 return __nvvm_redux_sync_add(x, MemberMask);
// redux.sync-based masked reduction: bitwise AND over the lanes in
// MemberMask, for signed or unsigned integer T with a bit_and operation
// (IsBitAND). g and binary_op only drive overload selection and are not used
// in the body.
// NOTE(review): PTX redux.sync is specified for 32-bit operands (sm_80+);
// confirm callers never reach this overload with 64-bit T.
68 template <
typename Group,
typename T,
class BinaryOperation>
69 std::enable_if_t<(is_sugeninteger_v<T> ||
70 is_sigeninteger_v<T>)&&IsBitAND<T, BinaryOperation>::value,
72 masked_reduction_cuda_sm80(Group g, T x, BinaryOperation binary_op,
73 const uint32_t MemberMask) {
74 return __nvvm_redux_sync_and(x, MemberMask);
// redux.sync-based masked reduction: bitwise OR over the lanes in MemberMask,
// for signed or unsigned integer T with a bit_or operation (IsBitOR). g and
// binary_op only drive overload selection and are not used in the body.
// NOTE(review): PTX redux.sync is specified for 32-bit operands (sm_80+);
// confirm callers never reach this overload with 64-bit T.
77 template <
typename Group,
typename T,
class BinaryOperation>
78 std::enable_if_t<(is_sugeninteger_v<T> ||
79 is_sigeninteger_v<T>)&&IsBitOR<T, BinaryOperation>::value,
81 masked_reduction_cuda_sm80(Group g, T x, BinaryOperation binary_op,
82 const uint32_t MemberMask) {
83 return __nvvm_redux_sync_or(x, MemberMask);
// redux.sync-based masked reduction: bitwise XOR over the lanes in
// MemberMask, for signed or unsigned integer T with a bit_xor operation
// (IsBitXOR). g and binary_op only drive overload selection and are not used
// in the body.
// NOTE(review): PTX redux.sync is specified for 32-bit operands (sm_80+);
// confirm callers never reach this overload with 64-bit T.
86 template <
typename Group,
typename T,
class BinaryOperation>
87 std::enable_if_t<(is_sugeninteger_v<T> ||
88 is_sigeninteger_v<T>)&&IsBitXOR<T, BinaryOperation>::value,
90 masked_reduction_cuda_sm80(Group g, T x, BinaryOperation binary_op,
91 const uint32_t MemberMask) {
92 return __nvvm_redux_sync_xor(x, MemberMask);
// Shuffle-based masked reduction using butterfly exchanges
// (cuda_shfl_sync_bfly_i32). The lane distance i starts at half of
// g.get_local_range()[0] and halves each step; every lane combines its value
// with the partner's value via binary_op.
// NOTE(review): a halving butterfly covers every lane only when the local
// range is a power of two. Presumably this group type guarantees that;
// confirm.
99 template <
typename Group,
typename T,
class BinaryOperation>
101 masked_reduction_cuda_shfls(Group g, T x, BinaryOperation binary_op,
102 const uint32_t MemberMask) {
103 for (
int i = g.get_local_range()[0] / 2; i > 0; i /= 2) {
104 T tmp = cuda_shfl_sync_bfly_i32(MemberMask, x, i, 0x1f);
105 x = binary_op(x, tmp);
// Shuffle-based masked reduction for user-constructed groups that are not
// fixed-size groups (per the enable_if condition).
//
// The participating lanes need not be contiguous. Each work-item refers to
// other members by ordinal among the set bits of MemberMask, and
// __nvvm_fns(MemberMask, 0, ordinal) maps an ordinal back to a lane.
//
// Each round folds the remaining partial results in half. Afterwards the
// result is broadcast from lane broadID to every member.
111 template <
typename Group,
typename T,
class BinaryOperation>
113 ext::oneapi::experimental::is_user_constructed_group_v<Group> &&
114 !is_fixed_size_group_v<Group>,
116 masked_reduction_cuda_shfls(Group g, T x, BinaryOperation binary_op,
117 const uint32_t MemberMask) {
// 1-based ordinal of this work-item; it is the ordinal passed to __nvvm_fns.
119 unsigned localSetBit = g.get_local_id()[0] + 1;
// Number of partial results still to be combined.
122 auto opRange = g.get_local_range()[0];
// Ordinal distance between the two partners of one binary op.
125 unsigned stride = opRange / 2;
126 while (stride >= 1) {
// Non-zero when opRange is odd: one partial has no partner this round.
129 unsigned remainder = opRange % 2;
// Ordinal of the source member; it can exceed the member count.
// NOTE(review): correctness depends on how __nvvm_fns treats an ordinal
// larger than popc(MemberMask). Confirm against the PTX fns specification.
132 int unfoldedSrcSetBit = localSetBit + stride;
136 T tmp = cuda_shfl_sync_idx_i32(
137 MemberMask, x, __nvvm_fns(MemberMask, 0, unfoldedSrcSetBit), 31);
// With an odd opRange, the first member (localSetBit == 1) keeps its value
// this round instead of combining.
139 if (!(localSetBit == 1 && remainder != 0)) {
140 x = binary_op(x, tmp);
// Next round: the paired results plus any unpaired one.
143 opRange = stride + remainder;
144 stride = opRange / 2;
// Inline PTX: brev.b32 bit-reverses MemberMask. The asm presumably goes on to
// derive broadID, the lane holding the final result (e.g. the lowest set bit
// of MemberMask) — TODO confirm.
147 asm volatile(
".reg .u32 rev;\n\t"
148 "brev.b32 rev, %1;\n\t"
// Broadcast the reduced value from lane broadID to all members.
153 x = cuda_shfl_sync_idx_i32(MemberMask, x, broadID, 31);
// Fallback overload for user-constructed groups when (T, BinaryOperation)
// cannot be lowered to redux.sync, i.e. IsRedux<T, BinaryOperation> is
// false. It delegates to the shuffle-based masked_reduction_cuda_shfls with
// the same arguments.
158 template <
typename Group,
typename T,
class BinaryOperation>
160 std::is_same<IsRedux<T, BinaryOperation>, std::false_type>::value &&
161 ext::oneapi::experimental::is_user_constructed_group_v<Group>,
163 masked_reduction_cuda_sm80(Group g, T x, BinaryOperation binary_op,
164 const uint32_t MemberMask) {
165 return masked_reduction_cuda_shfls(g, x, binary_op, MemberMask);
// Overload enabled when BinaryOperation satisfies IsPlus, IsBitOR or IsBitXOR.
// Presumably this is the get_identity overload called by
// masked_scan_cuda_shfls, returning that operation's identity (0).
// TODO confirm.
171 template <
typename T,
class BinaryOperation>
173 std::enable_if_t<IsPlus<T, BinaryOperation>::value ||
174 IsBitOR<T, BinaryOperation>::value ||
175 IsBitXOR<T, BinaryOperation>::value,
// Overload enabled when BinaryOperation satisfies IsMultiplies; returns T.
// Presumably the get_identity overload for multiplication (identity 1).
// TODO confirm.
181 template <
typename T,
class BinaryOperation>
183 std::enable_if_t<IsMultiplies<T, BinaryOperation>::value, T>
// Overload enabled when BinaryOperation satisfies IsBitAND; returns T.
// Presumably the get_identity overload for bitwise AND (identity: all bits
// set). TODO confirm.
188 template <
typename T,
class BinaryOperation>
190 std::enable_if_t<IsBitAND<T, BinaryOperation>::value, T>
// Overload enabled when BinaryOperation satisfies IsMinimum; returns T.
// Presumably the get_identity overload for minimum (identity: the largest
// value of T). TODO confirm.
195 template <
typename T,
class BinaryOperation>
197 std::enable_if_t<IsMinimum<T, BinaryOperation>::value, T>
// Overload enabled when BinaryOperation satisfies IsMaximum; returns T.
// Presumably the get_identity overload for maximum (identity: the lowest
// value of T). TODO confirm.
202 template <
typename T,
class BinaryOperation>
204 std::enable_if_t<IsMaximum<T, BinaryOperation>::value, T>
213 class BinaryOperation>
215 masked_scan_cuda_shfls(Group g, T x, BinaryOperation binary_op,
216 const uint32_t MemberMask) {
// Shuffle-based masked scan. For offsets i = 1, 2, 4, ... below
// g.get_local_range()[0], each lane reads x from the lane i below it
// (cuda_shfl_sync_up_i32) and combines it with binary_op.
//
// Afterwards the values are shifted up by one lane and local id 0 returns
// get_identity<T, BinaryOperation>(): the exclusive-scan form.
// NOTE(review): whether every lane combines in the loop, and whether the
// shift is conditional on an exclusive-scan mode, should be confirmed
// against the complete definition.
217 unsigned localIdVal = g.get_local_id()[0];
218 for (
int i = 1; i < g.get_local_range()[0]; i *= 2) {
219 T tmp = cuda_shfl_sync_up_i32(MemberMask, x, i, 0);
221 x = binary_op(x, tmp);
// Shift by one lane: each lane takes its predecessor's running value.
224 x = cuda_shfl_sync_up_i32(MemberMask, x, 1, 0);
// Lane 0 has no predecessor, so it gets the operation's identity.
225 if (localIdVal == 0) {
226 return get_identity<T, BinaryOperation>();
233 class BinaryOperation>
235 ext::oneapi::experimental::is_user_constructed_group_v<Group> &&
236 !is_fixed_size_group_v<Group>,
238 masked_scan_cuda_shfls(Group g, T x, BinaryOperation binary_op,
239 const uint32_t MemberMask) {
// Shuffle-based masked scan for user-constructed groups that are not
// fixed-size groups (per the enable_if condition). Members are addressed by
// ordinal among the set bits of MemberMask, and __nvvm_fns maps an ordinal to
// a lane, so participating lanes need not be contiguous.
//
// Each member combines the value from the member i ordinals below it, for
// i = 1, 2, 4, .... Afterwards each member takes its predecessor's value,
// and local id 0 returns get_identity<T, BinaryOperation>().
240 unsigned localIdVal = g.get_local_id()[0];
// 1-based ordinal of this work-item; it is the ordinal passed to __nvvm_fns.
241 unsigned localSetBit = localIdVal + 1;
243 for (
int i = 1; i < g.get_local_range()[0]; i *= 2) {
// Source ordinal i members below this one.
// NOTE(review): localSetBit is unsigned, so for i >= localSetBit the
// subtraction wraps before the conversion to int. Confirm that the resulting
// non-positive ordinal is handled (and its tmp discarded) as intended.
244 int unfoldedSrcSetBit = localSetBit - i;
246 T tmp = cuda_shfl_sync_idx_i32(
247 MemberMask, x, __nvvm_fns(MemberMask, 0, unfoldedSrcSetBit), 31);
250 x = binary_op(x, tmp);
// Take the running value of the previous member (ordinal localSetBit - 1).
253 x = cuda_shfl_sync_idx_i32(MemberMask, x,
254 __nvvm_fns(MemberMask, 0, localSetBit - 1), 31);
// The first member has no predecessor, so it gets the operation's identity.
255 if (localIdVal == 0) {
256 return get_identity<T, BinaryOperation>();
#define __SYCL_ALWAYS_INLINE