20 #ifdef __SYCL_DEVICE_ONLY__
// GetMultiPtrDecoratedAs: reinterprets the pointer held by a
// multi_ptr<FromT, Space, IsDecorated> as the decorated pointer type of
// multi_ptr<ToT, Space, access::decorated::yes>. The legacy branch presumably
// casts MPtr.get() while the other branch casts MPtr.get_decorated() -- TODO
// confirm against the original file.
// NOTE(review): corrupted extraction. The leading decimals are the original
// file's line numbers; the template header (orig. 33-34), the legacy branch's
// cast argument plus `else` (orig. 40-41) and the closing brace are missing.
// Kept byte-identical; restore from the original before compiling.
35 inline typename multi_ptr<ToT, Space, access::decorated::yes>::pointer
36 GetMultiPtrDecoratedAs(multi_ptr<FromT, Space, IsDecorated> MPtr) {
37 if constexpr (IsDecorated == access::decorated::legacy)
38 return reinterpret_cast<
39 typename multi_ptr<ToT, Space, access::decorated::yes>::pointer
>(
42 return reinterpret_cast<
43 typename multi_ptr<ToT, Space, access::decorated::yes>::pointer
>(
44 MPtr.get_decorated());
// Primary template mapping a SYCL group type to its SPIR-V scope flag.
// Intentionally empty: only the specializations (e.g. for sub_group, below in
// this file) provide the `value` member, so using an unsupported group type
// fails at compile time.
// NOTE(review): restored from a corrupted listing -- the stray leading line
// number ("49") was an extraction artifact, not code.
template <typename Group> struct group_scope {};
// Specialization of group_scope for sub-groups; its body (presumably a
// `static constexpr ... value` -- TODO confirm) and closing `};` were lost
// in extraction (orig. lines 56-63 missing).
// The two conflicting ShuffleChunkT aliases below were almost certainly the
// two arms of an elided `#if` (orig. lines 63, 65 missing): 64-bit chunks on
// one target, 32-bit on the other. Kept byte-identical.
55 template <>
struct group_scope<::sycl::ext::oneapi::sub_group> {
64 using ShuffleChunkT = uint64_t;
66 using ShuffleChunkT = uint32_t;
68 template <
typename T,
typename Functor>
69 void GenericCall(
const Functor &ApplyToBytes) {
70 if (
sizeof(T) >=
sizeof(ShuffleChunkT)) {
72 for (
size_t Offset = 0; Offset +
sizeof(ShuffleChunkT) <=
sizeof(T);
73 Offset +=
sizeof(ShuffleChunkT)) {
74 ApplyToBytes(Offset,
sizeof(ShuffleChunkT));
77 if (
sizeof(ShuffleChunkT) >=
sizeof(uint64_t)) {
78 if (
sizeof(T) %
sizeof(uint64_t) >=
sizeof(uint32_t)) {
79 size_t Offset =
sizeof(T) /
sizeof(uint64_t) *
sizeof(uint64_t);
80 ApplyToBytes(Offset,
sizeof(uint32_t));
83 if (
sizeof(ShuffleChunkT) >=
sizeof(uint32_t)) {
84 if (
sizeof(T) %
sizeof(uint32_t) >=
sizeof(uint16_t)) {
85 size_t Offset =
sizeof(T) /
sizeof(uint32_t) *
sizeof(uint32_t);
86 ApplyToBytes(Offset,
sizeof(uint16_t));
89 if (
sizeof(ShuffleChunkT) >=
sizeof(uint16_t)) {
90 if (
sizeof(T) %
sizeof(uint16_t) >=
sizeof(uint8_t)) {
91 size_t Offset =
sizeof(T) /
sizeof(uint16_t) *
sizeof(uint16_t);
92 ApplyToBytes(Offset,
sizeof(uint8_t));
97 template <
typename Group>
bool GroupAll(
bool pred) {
98 return __spirv_GroupAll(group_scope<Group>::value, pred);
101 template <
typename Group>
bool GroupAny(
bool pred) {
102 return __spirv_GroupAny(group_scope<Group>::value, pred);
107 template <
typename T>
108 using is_native_broadcast = bool_constant<detail::is_arithmetic<T>::value &&
109 !std::is_same<T, half>::value>;
// SFINAE gates for the three GroupBroadcast strategies. The opening
// `using EnableIfNativeBroadcast = detail::enable_if_t<` line (orig. 112),
// the `using is_bitcast_broadcast = bool_constant<` line (orig. 118) and the
// `using EnableIfBitcastBroadcast = ...` line (orig. 123) were lost in
// extraction; only the condition tails survive. From what is visible:
// bitcast broadcasting requires a non-native, trivially-copyable T of size
// 1/2/4/8, and both EnableIf aliases additionally require an integral IdT.
// Kept byte-identical.
111 template <
typename T,
typename IdT =
size_t>
113 is_native_broadcast<T>::value && std::is_integral<IdT>::value, T>;
117 template <
typename T>
119 !is_native_broadcast<T>::value && std::is_trivially_copyable<T>::value &&
120 (
sizeof(T) == 1 ||
sizeof(T) == 2 ||
sizeof(T) == 4 ||
sizeof(T) == 8)>;
122 template <
typename T,
typename IdT =
size_t>
124 is_bitcast_broadcast<T>::value && std::is_integral<IdT>::value, T>;
126 template <
typename T>
127 using ConvertToNativeBroadcastType_t = select_cl_scalar_integral_unsigned_t<T>;
133 template <
typename T>
134 using is_generic_broadcast =
135 bool_constant<!is_native_broadcast<T>::value &&
136 !is_bitcast_broadcast<T>::value &&
137 std::is_trivially_copyable<T>::value>;
// EnableIfGenericBroadcast (opening `using ... = detail::enable_if_t<` line,
// orig. 140, lost) gates the memcpy-based broadcast on is_generic_broadcast
// plus an integral IdT. WidenOpenCLTypeTo32_t (opening line orig. 145 and
// closing arm orig. 148 lost) visibly widens cl_char/cl_short to cl_int and,
// presumably, cl_uchar/cl_ushort to cl_uint -- TODO confirm. GroupId's primary
// body (presumably `using type = size_t;`, orig. 154) and both closing `};`
// are missing; the sub_group specialization uses uint32_t ids.
// Kept byte-identical.
139 template <
typename T,
typename IdT =
size_t>
141 is_generic_broadcast<T>::value && std::is_integral<IdT>::value, T>;
144 template <
typename T>
146 std::is_same<T, cl_char>() || std::is_same<T, cl_short>(),
cl_int,
147 conditional_t<std::is_same<T, cl_uchar>() || std::is_same<T, cl_ushort>(),
153 template <
typename Group>
struct GroupId {
156 template <>
struct GroupId<::sycl::ext::oneapi::sub_group> {
157 using type = uint32_t;
159 template <
typename Group,
typename T,
typename IdT>
160 EnableIfNativeBroadcast<T, IdT> GroupBroadcast(T x, IdT local_id) {
161 using GroupIdT =
typename GroupId<Group>::type;
162 GroupIdT GroupLocalId =
static_cast<GroupIdT
>(local_id);
163 using OCLT = detail::ConvertToOpenCLType_t<T>;
164 using WidenedT = WidenOpenCLTypeTo32_t<OCLT>;
165 using OCLIdT = detail::ConvertToOpenCLType_t<GroupIdT>;
166 WidenedT OCLX = detail::convertDataToType<T, OCLT>(x);
167 OCLIdT OCLId = detail::convertDataToType<GroupIdT, OCLIdT>(GroupLocalId);
168 return __spirv_GroupBroadcast(group_scope<Group>::value, OCLX, OCLId);
170 template <
typename Group,
typename T,
typename IdT>
171 EnableIfBitcastBroadcast<T, IdT> GroupBroadcast(T x, IdT local_id) {
172 using BroadcastT = ConvertToNativeBroadcastType_t<T>;
173 auto BroadcastX = bit_cast<BroadcastT>(x);
174 BroadcastT Result = GroupBroadcast<Group>(BroadcastX, local_id);
175 return bit_cast<T>(Result);
// Generic GroupBroadcast: broadcasts an arbitrary trivially-copyable T by
// memcpy-ing power-of-two chunks through the uint64_t native broadcast via
// GenericCall. NOTE(review): corrupted extraction -- the declaration of
// `Result` (orig. 179-180), the memcpy filling BroadcastX (orig. 185), the
// lambda's closing `};` (orig. 188) and the `return Result;` / closing brace
// are missing. Kept byte-identical.
177 template <
typename Group,
typename T,
typename IdT>
178 EnableIfGenericBroadcast<T, IdT> GroupBroadcast(T x, IdT local_id) {
181 char *XBytes =
reinterpret_cast<char *
>(&x);
182 char *ResultBytes =
reinterpret_cast<char *
>(&Result);
183 auto BroadcastBytes = [=](
size_t Offset,
size_t Size) {
184 uint64_t BroadcastX, BroadcastResult;
186 BroadcastResult = GroupBroadcast<Group>(BroadcastX, local_id);
187 std::memcpy(ResultBytes + Offset, &BroadcastResult, Size);
189 GenericCall<T>(BroadcastBytes);
// id<Dimensions> native overload: visibly forwards the 1-D case to the
// scalar-id overload, otherwise packs the id into a vec<size_t, Dimensions>
// for OpGroupBroadcast. The Dimensions==1 guard (orig. 196/198) and the loop
// building VecId (orig. 203-206) are missing from this listing.
194 template <
typename Group,
typename T,
int Dimensions>
195 EnableIfNativeBroadcast<T> GroupBroadcast(T x, id<Dimensions> local_id) {
197 return GroupBroadcast<Group>(x, local_id[0]);
199 using IdT = vec<size_t, Dimensions>;
200 using OCLT = detail::ConvertToOpenCLType_t<T>;
201 using WidenedT = WidenOpenCLTypeTo32_t<OCLT>;
202 using OCLIdT = detail::ConvertToOpenCLType_t<IdT>;
207 WidenedT OCLX = detail::convertDataToType<T, OCLT>(x);
208 OCLIdT OCLId = detail::convertDataToType<IdT, OCLIdT>(VecId);
209 return __spirv_GroupBroadcast(group_scope<Group>::value, OCLX, OCLId);
211 template <
typename Group,
typename T,
int Dimensions>
212 EnableIfBitcastBroadcast<T> GroupBroadcast(T x, id<Dimensions> local_id) {
213 using BroadcastT = ConvertToNativeBroadcastType_t<T>;
214 auto BroadcastX = bit_cast<BroadcastT>(x);
215 BroadcastT Result = GroupBroadcast<Group>(BroadcastX, local_id);
216 return bit_cast<T>(Result);
// Generic GroupBroadcast, id<Dimensions> overload: visibly forwards to the
// scalar-id overload (presumably only when Dimensions == 1 -- the guard,
// orig. 220/222-224, is missing) and otherwise memcpy-chunks through the
// uint64_t broadcast via GenericCall. NOTE(review): corrupted extraction --
// the declaration of `Result`, the memcpy filling BroadcastX (orig. 229),
// the lambda's `};` (orig. 232), `return Result;` and the closing brace are
// missing. Kept byte-identical.
218 template <
typename Group,
typename T,
int Dimensions>
219 EnableIfGenericBroadcast<T> GroupBroadcast(T x, id<Dimensions> local_id) {
221 return GroupBroadcast<Group>(x, local_id[0]);
225 char *XBytes =
reinterpret_cast<char *
>(&x);
226 char *ResultBytes =
reinterpret_cast<char *
>(&Result);
227 auto BroadcastBytes = [=](
size_t Offset,
size_t Size) {
228 uint64_t BroadcastX, BroadcastResult;
230 BroadcastResult = GroupBroadcast<Group>(BroadcastX, local_id);
231 std::memcpy(ResultBytes + Offset, &BroadcastResult, Size);
233 GenericCall<T>(BroadcastBytes);
// getMemorySemanticsMask maps sycl::memory_order values to SPIR-V memory
// semantics flags; only the signature fragment and one case label survive
// here (the switch body, orig. 244-259+, is missing). The case labels below
// (orig. 271-279) belong to a separate getScope(memory_scope) helper whose
// signature and return statements were lost -- TODO restore from the
// original file. Kept byte-identical.
239 template <
typename T>
240 static inline constexpr
241 typename std::enable_if<std::is_same<T, sycl::memory_order>::value,
243 getMemorySemanticsMask(T Order) {
249 case T::__consume_unsupported:
271 case memory_scope::work_item:
273 case memory_scope::sub_group:
275 case memory_scope::work_group:
277 case memory_scope::device:
279 case memory_scope::system:
// ---------------------------------------------------------------------------
// Atomic operation wrappers. Each overload translates SYCL memory order and
// scope arguments via getMemorySemanticsMask/getScope, re-decorates the
// multi_ptr via GetMultiPtrDecoratedAs, and calls the matching __spirv_Atomic*
// builtin. Integral overloads pass T directly; floating-point overloads
// without a native builtin bit_cast the value through a same-size unsigned
// integer (detail::make_unsinged_integer_t -- sic, spelling preserved as it
// appears to be the trait's actual project name). AtomicFAdd uses the
// __spirv_AtomicFAddEXT extension builtin directly.
// NOTE(review): corrupted extraction. Gaps in the embedded original line
// numbers show that the template headers (template <typename T,
// access::address_space AddressSpace, ...>), remaining parameter lists,
// most __spirv_Atomic* call lines and all closing braces were dropped.
// Everything below is kept byte-identical; restore from the original file
// before compiling.
// ---------------------------------------------------------------------------
286 inline typename detail::enable_if_t<std::is_integral<T>::value, T>
287 AtomicCompareExchange(multi_ptr<T, AddressSpace, IsDecorated> MPtr,
290 auto SPIRVSuccess = getMemorySemanticsMask(Success);
291 auto SPIRVFailure = getMemorySemanticsMask(Failure);
292 auto SPIRVScope = getScope(Scope);
293 auto *Ptr = GetMultiPtrDecoratedAs<T>(MPtr);
294 return __spirv_AtomicCompareExchange(Ptr, SPIRVScope, SPIRVSuccess,
295 SPIRVFailure, Desired, Expected);
300 inline typename detail::enable_if_t<std::is_floating_point<T>::value, T>
301 AtomicCompareExchange(multi_ptr<T, AddressSpace, IsDecorated> MPtr,
304 using I = detail::make_unsinged_integer_t<T>;
305 auto SPIRVSuccess = getMemorySemanticsMask(Success);
306 auto SPIRVFailure = getMemorySemanticsMask(Failure);
307 auto SPIRVScope = getScope(Scope);
308 auto *PtrInt = GetMultiPtrDecoratedAs<I>(MPtr);
309 I DesiredInt = bit_cast<I>(Desired);
310 I ExpectedInt = bit_cast<I>(Expected);
311 I ResultInt = __spirv_AtomicCompareExchange(
312 PtrInt, SPIRVScope, SPIRVSuccess, SPIRVFailure, DesiredInt, ExpectedInt);
313 return bit_cast<T>(ResultInt);
318 inline typename detail::enable_if_t<std::is_integral<T>::value, T>
319 AtomicLoad(multi_ptr<T, AddressSpace, IsDecorated> MPtr,
memory_scope Scope,
321 auto *Ptr = GetMultiPtrDecoratedAs<T>(MPtr);
322 auto SPIRVOrder = getMemorySemanticsMask(Order);
323 auto SPIRVScope = getScope(Scope);
329 inline typename detail::enable_if_t<std::is_floating_point<T>::value, T>
330 AtomicLoad(multi_ptr<T, AddressSpace, IsDecorated> MPtr,
memory_scope Scope,
332 using I = detail::make_unsinged_integer_t<T>;
333 auto *PtrInt = GetMultiPtrDecoratedAs<I>(MPtr);
334 auto SPIRVOrder = getMemorySemanticsMask(Order);
335 auto SPIRVScope = getScope(Scope);
337 return bit_cast<T>(ResultInt);
342 inline typename detail::enable_if_t<std::is_integral<T>::value>
343 AtomicStore(multi_ptr<T, AddressSpace, IsDecorated> MPtr,
memory_scope Scope,
345 auto *Ptr = GetMultiPtrDecoratedAs<T>(MPtr);
346 auto SPIRVOrder = getMemorySemanticsMask(Order);
347 auto SPIRVScope = getScope(Scope);
353 inline typename detail::enable_if_t<std::is_floating_point<T>::value>
354 AtomicStore(multi_ptr<T, AddressSpace, IsDecorated> MPtr,
memory_scope Scope,
356 using I = detail::make_unsinged_integer_t<T>;
357 auto *PtrInt = GetMultiPtrDecoratedAs<I>(MPtr);
358 auto SPIRVOrder = getMemorySemanticsMask(Order);
359 auto SPIRVScope = getScope(Scope);
360 I ValueInt = bit_cast<I>(Value);
366 inline typename detail::enable_if_t<std::is_integral<T>::value, T>
367 AtomicExchange(multi_ptr<T, AddressSpace, IsDecorated> MPtr,
memory_scope Scope,
369 auto *Ptr = GetMultiPtrDecoratedAs<T>(MPtr);
370 auto SPIRVOrder = getMemorySemanticsMask(Order);
371 auto SPIRVScope = getScope(Scope);
377 inline typename detail::enable_if_t<std::is_floating_point<T>::value, T>
378 AtomicExchange(multi_ptr<T, AddressSpace, IsDecorated> MPtr,
memory_scope Scope,
380 using I = detail::make_unsinged_integer_t<T>;
381 auto *PtrInt = GetMultiPtrDecoratedAs<I>(MPtr);
382 auto SPIRVOrder = getMemorySemanticsMask(Order);
383 auto SPIRVScope = getScope(Scope);
384 I ValueInt = bit_cast<I>(Value);
387 return bit_cast<T>(ResultInt);
392 inline typename detail::enable_if_t<std::is_integral<T>::value, T>
393 AtomicIAdd(multi_ptr<T, AddressSpace, IsDecorated> MPtr,
memory_scope Scope,
395 auto *Ptr = GetMultiPtrDecoratedAs<T>(MPtr);
396 auto SPIRVOrder = getMemorySemanticsMask(Order);
397 auto SPIRVScope = getScope(Scope);
403 inline typename detail::enable_if_t<std::is_integral<T>::value, T>
404 AtomicISub(multi_ptr<T, AddressSpace, IsDecorated> MPtr,
memory_scope Scope,
406 auto *Ptr = GetMultiPtrDecoratedAs<T>(MPtr);
407 auto SPIRVOrder = getMemorySemanticsMask(Order);
408 auto SPIRVScope = getScope(Scope);
414 inline typename detail::enable_if_t<std::is_floating_point<T>::value, T>
415 AtomicFAdd(multi_ptr<T, AddressSpace, IsDecorated> MPtr,
memory_scope Scope,
417 auto *Ptr = GetMultiPtrDecoratedAs<T>(MPtr);
418 auto SPIRVOrder = getMemorySemanticsMask(Order);
419 auto SPIRVScope = getScope(Scope);
420 return __spirv_AtomicFAddEXT(Ptr, SPIRVScope, SPIRVOrder, Value);
425 inline typename detail::enable_if_t<std::is_integral<T>::value, T>
426 AtomicAnd(multi_ptr<T, AddressSpace, IsDecorated> MPtr,
memory_scope Scope,
428 auto *Ptr = GetMultiPtrDecoratedAs<T>(MPtr);
429 auto SPIRVOrder = getMemorySemanticsMask(Order);
430 auto SPIRVScope = getScope(Scope);
436 inline typename detail::enable_if_t<std::is_integral<T>::value, T>
437 AtomicOr(multi_ptr<T, AddressSpace, IsDecorated> MPtr,
memory_scope Scope,
439 auto *Ptr = GetMultiPtrDecoratedAs<T>(MPtr);
440 auto SPIRVOrder = getMemorySemanticsMask(Order);
441 auto SPIRVScope = getScope(Scope);
447 inline typename detail::enable_if_t<std::is_integral<T>::value, T>
448 AtomicXor(multi_ptr<T, AddressSpace, IsDecorated> MPtr,
memory_scope Scope,
450 auto *Ptr = GetMultiPtrDecoratedAs<T>(MPtr);
451 auto SPIRVOrder = getMemorySemanticsMask(Order);
452 auto SPIRVScope = getScope(Scope);
458 inline typename detail::enable_if_t<std::is_integral<T>::value, T>
459 AtomicMin(multi_ptr<T, AddressSpace, IsDecorated> MPtr,
memory_scope Scope,
461 auto *Ptr = GetMultiPtrDecoratedAs<T>(MPtr);
462 auto SPIRVOrder = getMemorySemanticsMask(Order);
463 auto SPIRVScope = getScope(Scope);
469 inline typename detail::enable_if_t<std::is_floating_point<T>::value, T>
470 AtomicMin(multi_ptr<T, AddressSpace, IsDecorated> MPtr,
memory_scope Scope,
472 auto *Ptr = GetMultiPtrDecoratedAs<T>(MPtr);
473 auto SPIRVOrder = getMemorySemanticsMask(Order);
474 auto SPIRVScope = getScope(Scope);
480 inline typename detail::enable_if_t<std::is_integral<T>::value, T>
481 AtomicMax(multi_ptr<T, AddressSpace, IsDecorated> MPtr,
memory_scope Scope,
483 auto *Ptr = GetMultiPtrDecoratedAs<T>(MPtr);
484 auto SPIRVOrder = getMemorySemanticsMask(Order);
485 auto SPIRVScope = getScope(Scope);
491 inline typename detail::enable_if_t<std::is_floating_point<T>::value, T>
492 AtomicMax(multi_ptr<T, AddressSpace, IsDecorated> MPtr,
memory_scope Scope,
494 auto *Ptr = GetMultiPtrDecoratedAs<T>(MPtr);
495 auto SPIRVOrder = getMemorySemanticsMask(Order);
496 auto SPIRVScope = getScope(Scope);
509 template <
typename T>
510 struct TypeIsProhibitedForShuffleEmulation
511 :
bool_constant<std::is_same_v<vector_element_t<T>, double>> {};
513 template <
typename T>
514 struct VecTypeIsProhibitedForShuffleEmulation
516 (detail::get_vec_size<T>::size > 1) &&
517 TypeIsProhibitedForShuffleEmulation<vector_element_t<T>>::value> {};
// SFINAE gates selecting the SubgroupShuffle* implementation strategy, with
// separate definitions for the SPIR-V and __NVPTX__ branches (the `#ifndef
// __NVPTX__` openers survive only as `#else`/`#endif` markers here).
// NOTE(review): corrupted extraction -- several `using ... = std::enable_if_t<`
// opening lines (orig. 532, 553, 575) and some `T>;` closers (orig. 523,
// 548-549, 558, 572, 580) are missing; the surviving condition tails show:
// native = arithmetic minus prohibited vec types (SPIR-V) or small integrals
// (NVPTX); bitcast = trivially copyable of size 1/2/4(/8); generic = the
// remainder. Kept byte-identical.
519 template <
typename T>
520 using EnableIfNativeShuffle =
521 std::enable_if_t<detail::is_arithmetic<T>::value &&
522 !VecTypeIsProhibitedForShuffleEmulation<T>::value,
525 template <
typename T>
526 using EnableIfVectorShuffle =
527 std::enable_if_t<VecTypeIsProhibitedForShuffleEmulation<T>::value, T>;
529 #else // ifndef __NVPTX__
531 template <
typename T>
533 std::is_integral<T>::value && (
sizeof(T) <=
sizeof(int32_t)), T>;
535 template <
typename T>
536 using EnableIfVectorShuffle =
537 std::enable_if_t<detail::is_vector_arithmetic<T>::value, T>;
538 #endif // ifndef __NVPTX__
543 template <
typename T>
544 using EnableIfBitcastShuffle =
545 std::enable_if_t<!detail::is_arithmetic<T>::value &&
546 (std::is_trivially_copyable_v<T> &&
547 (
sizeof(T) == 1 ||
sizeof(T) == 2 ||
sizeof(T) == 4 ||
551 template <
typename T>
552 using EnableIfBitcastShuffle =
554 (
sizeof(T) <=
sizeof(int32_t))) &&
555 !detail::is_vector_arithmetic<T>::value &&
556 (std::is_trivially_copyable_v<T> &&
557 (
sizeof(T) == 1 ||
sizeof(T) == 2 ||
sizeof(T) == 4)),
559 #endif // ifndef __NVPTX__
566 template <
typename T>
567 using EnableIfGenericShuffle =
568 std::enable_if_t<!detail::is_arithmetic<T>::value &&
569 !(std::is_trivially_copyable_v<T> &&
570 (
sizeof(T) == 1 ||
sizeof(T) == 2 ||
571 sizeof(T) == 4 ||
sizeof(T) == 8)),
574 template <
typename T>
576 !(std::is_integral<T>::value && (
sizeof(T) <=
sizeof(int32_t))) &&
577 !detail::is_vector_arithmetic<T>::value &&
578 !(std::is_trivially_copyable_v<T> &&
579 (
sizeof(T) == 1 ||
sizeof(T) == 2 ||
sizeof(T) == 4)),
// membermask: NVPTX helper returning the 32-bit lane mask passed to the
// __nvvm_shfl_sync_* builtins below; its body (orig. 585+) was lost in
// extraction -- presumably a full-warp mask, TODO confirm.
584 inline uint32_t membermask() {
592 template <
typename T>
593 EnableIfBitcastShuffle<T> SubgroupShuffle(T x, id<1> local_id);
595 template <
typename T>
596 EnableIfBitcastShuffle<T> SubgroupShuffleXor(T x, id<1> local_id);
598 template <
typename T>
599 EnableIfBitcastShuffle<T> SubgroupShuffleDown(T x, id<1> local_id);
601 template <
typename T>
602 EnableIfBitcastShuffle<T> SubgroupShuffleUp(T x, id<1> local_id);
604 template <
typename T>
605 EnableIfGenericShuffle<T> SubgroupShuffle(T x, id<1> local_id);
607 template <
typename T>
608 EnableIfGenericShuffle<T> SubgroupShuffleXor(T x, id<1> local_id);
610 template <
typename T>
611 EnableIfGenericShuffle<T> SubgroupShuffleDown(T x, id<1> local_id);
613 template <
typename T>
614 EnableIfGenericShuffle<T> SubgroupShuffleUp(T x, id<1> local_id);
// Native SubgroupShuffle/Xor/Down/Up: SPIR-V branch calls the
// __spirv_SubgroupShuffle*INTEL builtins on the OpenCL-converted value;
// NVPTX branch calls __nvvm_shfl_sync_* with membermask().
// NOTE(review): corrupted extraction -- the `#if defined(__NVPTX__)` /
// `#else` / `#endif` lines separating the two return statements in each
// function (orig. 618/622/625, 629/633/636, 640/643/646, 650/653/656) and
// all closing braces are missing, so each body shows two bare returns.
// Kept byte-identical.
616 template <
typename T>
617 EnableIfNativeShuffle<T> SubgroupShuffle(T x, id<1> local_id) {
619 using OCLT = detail::ConvertToOpenCLType_t<T>;
620 return __spirv_SubgroupShuffleINTEL(OCLT(x),
621 static_cast<uint32_t
>(local_id.get(0)));
623 return __nvvm_shfl_sync_idx_i32(membermask(), x, local_id.get(0), 0x1f);
627 template <
typename T>
628 EnableIfNativeShuffle<T> SubgroupShuffleXor(T x, id<1> local_id) {
630 using OCLT = detail::ConvertToOpenCLType_t<T>;
631 return __spirv_SubgroupShuffleXorINTEL(
632 OCLT(x),
static_cast<uint32_t
>(local_id.get(0)));
634 return __nvvm_shfl_sync_bfly_i32(membermask(), x, local_id.get(0), 0x1f);
638 template <
typename T>
639 EnableIfNativeShuffle<T> SubgroupShuffleDown(T x, uint32_t delta) {
641 using OCLT = detail::ConvertToOpenCLType_t<T>;
642 return __spirv_SubgroupShuffleDownINTEL(OCLT(x), OCLT(x), delta);
644 return __nvvm_shfl_sync_down_i32(membermask(), x, delta, 0x1f);
648 template <
typename T>
649 EnableIfNativeShuffle<T> SubgroupShuffleUp(T x, uint32_t delta) {
651 using OCLT = detail::ConvertToOpenCLType_t<T>;
652 return __spirv_SubgroupShuffleUpINTEL(OCLT(x), OCLT(x), delta);
654 return __nvvm_shfl_sync_up_i32(membermask(), x, delta, 0);
// Vector overloads: shuffle each element independently. NOTE(review): the
// `T result;` declarations, loop/function closing braces and
// `return result;` lines (orig. 660/663-665 and siblings) are missing.
658 template <
typename T>
659 EnableIfVectorShuffle<T> SubgroupShuffle(T x, id<1> local_id) {
661 for (
int s = 0;
s < x.size(); ++
s) {
662 result[
s] = SubgroupShuffle(x[
s], local_id);
667 template <
typename T>
668 EnableIfVectorShuffle<T> SubgroupShuffleXor(T x, id<1> local_id) {
670 for (
int s = 0;
s < x.size(); ++
s) {
671 result[
s] = SubgroupShuffleXor(x[
s], local_id);
676 template <
typename T>
677 EnableIfVectorShuffle<T> SubgroupShuffleDown(T x, uint32_t delta) {
679 for (
int s = 0;
s < x.size(); ++
s) {
680 result[
s] = SubgroupShuffleDown(x[
s], delta);
685 template <
typename T>
686 EnableIfVectorShuffle<T> SubgroupShuffleUp(T x, uint32_t delta) {
688 for (
int s = 0;
s < x.size(); ++
s) {
689 result[
s] = SubgroupShuffleUp(x[
s], delta);
694 template <
typename T>
695 using ConvertToNativeShuffleType_t = select_cl_scalar_integral_unsigned_t<T>;
// Bitcast SubgroupShuffle/Xor/Down/Up: reinterpret T as its unsigned carrier,
// shuffle natively (SPIR-V builtin or NVPTX shfl.sync), reinterpret back.
// NOTE(review): corrupted extraction -- the `#if defined(__NVPTX__)` /
// `#else` / `#endif` guards between the SPIR-V and NVPTX assignments
// (orig. 701/704-705/707 and siblings) and the closing braces are missing;
// in the Up overload both branches visibly declare `ShuffleT Result`, which
// only compiles because they sat in mutually exclusive #if arms.
// Kept byte-identical.
697 template <
typename T>
698 EnableIfBitcastShuffle<T> SubgroupShuffle(T x, id<1> local_id) {
699 using ShuffleT = ConvertToNativeShuffleType_t<T>;
700 auto ShuffleX = bit_cast<ShuffleT>(x);
702 ShuffleT Result = __spirv_SubgroupShuffleINTEL(
703 ShuffleX,
static_cast<uint32_t
>(local_id.get(0)));
706 __nvvm_shfl_sync_idx_i32(membermask(), ShuffleX, local_id.get(0), 0x1f);
708 return bit_cast<T>(Result);
711 template <
typename T>
712 EnableIfBitcastShuffle<T> SubgroupShuffleXor(T x, id<1> local_id) {
713 using ShuffleT = ConvertToNativeShuffleType_t<T>;
714 auto ShuffleX = bit_cast<ShuffleT>(x);
716 ShuffleT Result = __spirv_SubgroupShuffleXorINTEL(
717 ShuffleX,
static_cast<uint32_t
>(local_id.get(0)));
720 __nvvm_shfl_sync_bfly_i32(membermask(), ShuffleX, local_id.get(0), 0x1f);
722 return bit_cast<T>(Result);
725 template <
typename T>
726 EnableIfBitcastShuffle<T> SubgroupShuffleDown(T x, uint32_t delta) {
727 using ShuffleT = ConvertToNativeShuffleType_t<T>;
728 auto ShuffleX = bit_cast<ShuffleT>(x);
730 ShuffleT Result = __spirv_SubgroupShuffleDownINTEL(ShuffleX, ShuffleX, delta);
733 __nvvm_shfl_sync_down_i32(membermask(), ShuffleX, delta, 0x1f);
735 return bit_cast<T>(Result);
738 template <
typename T>
739 EnableIfBitcastShuffle<T> SubgroupShuffleUp(T x, uint32_t delta) {
740 using ShuffleT = ConvertToNativeShuffleType_t<T>;
741 auto ShuffleX = bit_cast<ShuffleT>(x);
743 ShuffleT Result = __spirv_SubgroupShuffleUpINTEL(ShuffleX, ShuffleX, delta);
745 ShuffleT Result = __nvvm_shfl_sync_up_i32(membermask(), ShuffleX, delta, 0);
747 return bit_cast<T>(Result);
// Generic SubgroupShuffle/Xor/Down/Up: shuffle arbitrary trivially-copyable
// types by memcpy-chunking through the ShuffleChunkT-wide native shuffle via
// GenericCall. NOTE(review): the `Result` declarations, the memcpy filling
// ShuffleX (orig. 757 and siblings), the lambdas' `};`, the `return Result;`
// lines and closing braces are missing.
750 template <
typename T>
751 EnableIfGenericShuffle<T> SubgroupShuffle(T x, id<1> local_id) {
753 char *XBytes =
reinterpret_cast<char *
>(&x);
754 char *ResultBytes =
reinterpret_cast<char *
>(&Result);
755 auto ShuffleBytes = [=](
size_t Offset,
size_t Size) {
756 ShuffleChunkT ShuffleX, ShuffleResult;
758 ShuffleResult = SubgroupShuffle(ShuffleX, local_id);
759 std::memcpy(ResultBytes + Offset, &ShuffleResult, Size);
761 GenericCall<T>(ShuffleBytes);
765 template <
typename T>
766 EnableIfGenericShuffle<T> SubgroupShuffleXor(T x, id<1> local_id) {
768 char *XBytes =
reinterpret_cast<char *
>(&x);
769 char *ResultBytes =
reinterpret_cast<char *
>(&Result);
770 auto ShuffleBytes = [=](
size_t Offset,
size_t Size) {
771 ShuffleChunkT ShuffleX, ShuffleResult;
773 ShuffleResult = SubgroupShuffleXor(ShuffleX, local_id);
774 std::memcpy(ResultBytes + Offset, &ShuffleResult, Size);
776 GenericCall<T>(ShuffleBytes);
780 template <
typename T>
781 EnableIfGenericShuffle<T> SubgroupShuffleDown(T x, uint32_t delta) {
783 char *XBytes =
reinterpret_cast<char *
>(&x);
784 char *ResultBytes =
reinterpret_cast<char *
>(&Result);
785 auto ShuffleBytes = [=](
size_t Offset,
size_t Size) {
786 ShuffleChunkT ShuffleX, ShuffleResult;
788 ShuffleResult = SubgroupShuffleDown(ShuffleX, delta);
789 std::memcpy(ResultBytes + Offset, &ShuffleResult, Size);
791 GenericCall<T>(ShuffleBytes);
795 template <
typename T>
796 EnableIfGenericShuffle<T> SubgroupShuffleUp(T x, uint32_t delta) {
798 char *XBytes =
reinterpret_cast<char *
>(&x);
799 char *ResultBytes =
reinterpret_cast<char *
>(&Result);
800 auto ShuffleBytes = [=](
size_t Offset,
size_t Size) {
801 ShuffleChunkT ShuffleX, ShuffleResult;
803 ShuffleResult = SubgroupShuffleUp(ShuffleX, delta);
804 std::memcpy(ResultBytes + Offset, &ShuffleResult, Size);
806 GenericCall<T>(ShuffleBytes);
814 #endif // __SYCL_DEVICE_ONLY__