#ifdef __SYCL_DEVICE_ONLY__
template <typename Group> struct group_scope {};
template <int Dimensions> struct group_scope<group<Dimensions>> {
  static constexpr __spv::Scope::Flag value = __spv::Scope::Flag::Workgroup;
};
template <> struct group_scope<::cl::sycl::ONEAPI::sub_group> {
  static constexpr __spv::Scope::Flag value = __spv::Scope::Flag::Subgroup;
};
// Generic shuffles and broadcasts may require multiple calls to the backend
// intrinsics, and should use the fewest chunks possible:
// - Loop over 64-bit chunks until remaining bytes < 64-bit
// - At most one 32-bit, 16-bit and 8-bit chunk left over
#ifndef __NVPTX__
using ShuffleChunkT = uint64_t;
#else
using ShuffleChunkT = uint32_t;
#endif
template <typename T, typename Functor>
void GenericCall(const Functor &ApplyToBytes) {
  if (sizeof(T) >= sizeof(ShuffleChunkT)) {
    for (size_t Offset = 0; Offset + sizeof(ShuffleChunkT) <= sizeof(T);
         Offset += sizeof(ShuffleChunkT)) {
      ApplyToBytes(Offset, sizeof(ShuffleChunkT));
    }
  }
  if (sizeof(ShuffleChunkT) >= sizeof(uint64_t)) {
    if (sizeof(T) % sizeof(uint64_t) >= sizeof(uint32_t)) {
      size_t Offset = sizeof(T) / sizeof(uint64_t) * sizeof(uint64_t);
      ApplyToBytes(Offset, sizeof(uint32_t));
    }
  }
  if (sizeof(ShuffleChunkT) >= sizeof(uint32_t)) {
    if (sizeof(T) % sizeof(uint32_t) >= sizeof(uint16_t)) {
      size_t Offset = sizeof(T) / sizeof(uint32_t) * sizeof(uint32_t);
      ApplyToBytes(Offset, sizeof(uint16_t));
    }
  }
  if (sizeof(ShuffleChunkT) >= sizeof(uint16_t)) {
    if (sizeof(T) % sizeof(uint16_t) >= sizeof(uint8_t)) {
      size_t Offset = sizeof(T) / sizeof(uint16_t) * sizeof(uint16_t);
      ApplyToBytes(Offset, sizeof(uint8_t));
    }
  }
}
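
// Worked example (illustrative, not part of the original header): for a
// 13-byte trivially copyable T with ShuffleChunkT = uint64_t, GenericCall
// invokes ApplyToBytes(0, 8) from the loop, then ApplyToBytes(8, 4) for the
// 4-byte remainder, then ApplyToBytes(12, 1) for the final byte -- every byte
// is covered using the fewest, largest chunks.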

template <typename Group> bool GroupAll(bool pred) {
  return __spirv_GroupAll(group_scope<Group>::value, pred);
}

template <typename Group> bool GroupAny(bool pred) {
  return __spirv_GroupAny(group_scope<Group>::value, pred);
}
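
// Usage sketch (hypothetical kernel code, not from this header): GroupAll and
// GroupAny evaluate a predicate uniformly across a group, e.g.
//   bool done = GroupAll<sycl::group<1>>(local_error < tolerance);
// Both are convergent: every work-item in the group must reach the call.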

// Native broadcasts map directly to a SPIR-V GroupBroadcast intrinsic
template <typename T>
using is_native_broadcast = bool_constant<detail::is_arithmetic<T>::value &&
                                          !std::is_same<T, half>::value>;

template <typename T, typename IdT = size_t>
using EnableIfNativeBroadcast = detail::enable_if_t<
    is_native_broadcast<T>::value && std::is_integral<IdT>::value, T>;

// Bitcast broadcasts can use a single GroupBroadcast intrinsic, but require
// type-punning via an appropriately sized integer type
template <typename T>
using is_bitcast_broadcast = bool_constant<
    !is_native_broadcast<T>::value && std::is_trivially_copyable<T>::value &&
    (sizeof(T) == 1 || sizeof(T) == 2 || sizeof(T) == 4 || sizeof(T) == 8)>;

template <typename T, typename IdT = size_t>
using EnableIfBitcastBroadcast = detail::enable_if_t<
    is_bitcast_broadcast<T>::value && std::is_integral<IdT>::value, T>;

template <typename T>
using ConvertToNativeBroadcastType_t = select_cl_scalar_integral_unsigned_t<T>;

// Generic broadcasts may require multiple GroupBroadcast calls (see
// GenericCall above)
template <typename T>
using is_generic_broadcast =
    bool_constant<!is_native_broadcast<T>::value &&
                  !is_bitcast_broadcast<T>::value &&
                  std::is_trivially_copyable<T>::value>;

template <typename T, typename IdT = size_t>
using EnableIfGenericBroadcast = detail::enable_if_t<
    is_generic_broadcast<T>::value && std::is_integral<IdT>::value, T>;
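
// Dispatch example (illustrative): float is arithmetic, so it takes the
// native path; half is excluded from the native path but is 2 bytes and
// trivially copyable, so it takes the bitcast path; a 12-byte trivially
// copyable struct matches none of the fixed sizes and falls back to the
// generic, chunked path.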

// Widen 8- and 16-bit OpenCL types to 32-bit for broadcast
template <typename T>
using WidenOpenCLTypeTo32_t = conditional_t<
    std::is_same<T, cl_char>() || std::is_same<T, cl_short>(), cl_int,
    conditional_t<std::is_same<T, cl_uchar>() || std::is_same<T, cl_ushort>(),
                  cl_uint, T>>;

// Broadcast with scalar local index
// Work-group supports any integral index type; sub-group uses uint32_t
template <typename Group> struct GroupId { using type = size_t; };
template <> struct GroupId<::cl::sycl::ONEAPI::sub_group> {
  using type = uint32_t;
};

template <typename Group, typename T, typename IdT>
EnableIfNativeBroadcast<T, IdT> GroupBroadcast(T x, IdT local_id) {
  using GroupIdT = typename GroupId<Group>::type;
  GroupIdT GroupLocalId = static_cast<GroupIdT>(local_id);
  using OCLT = detail::ConvertToOpenCLType_t<T>;
  using WidenedT = WidenOpenCLTypeTo32_t<OCLT>;
  using OCLIdT = detail::ConvertToOpenCLType_t<GroupIdT>;
  WidenedT OCLX = detail::convertDataToType<T, OCLT>(x);
  OCLIdT OCLId = detail::convertDataToType<GroupIdT, OCLIdT>(GroupLocalId);
  return __spirv_GroupBroadcast(group_scope<Group>::value, OCLX, OCLId);
}

template <typename Group, typename T, typename IdT>
EnableIfBitcastBroadcast<T, IdT> GroupBroadcast(T x, IdT local_id) {
  using BroadcastT = ConvertToNativeBroadcastType_t<T>;
  auto BroadcastX = bit_cast<BroadcastT>(x);
  BroadcastT Result = GroupBroadcast<Group>(BroadcastX, local_id);
  return bit_cast<T>(Result);
}
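
// Usage sketch (hypothetical): an 8-byte trivially copyable struct resolves
// to the bitcast overload above, i.e. a single 64-bit GroupBroadcast:
//   struct Pair { int32_t lo, hi; };
//   Pair p = ...;
//   Pair q = GroupBroadcast<GroupT>(p, size_t{0}); // value held by id 0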

template <typename Group, typename T, typename IdT>
EnableIfGenericBroadcast<T, IdT> GroupBroadcast(T x, IdT local_id) {
  T Result;
  char *XBytes = reinterpret_cast<char *>(&x);
  char *ResultBytes = reinterpret_cast<char *>(&Result);
  auto BroadcastBytes = [=](size_t Offset, size_t Size) {
    uint64_t BroadcastX, BroadcastResult;
    std::memcpy(&BroadcastX, XBytes + Offset, Size);
    BroadcastResult = GroupBroadcast<Group>(BroadcastX, local_id);
    std::memcpy(ResultBytes + Offset, &BroadcastResult, Size);
  };
  GenericCall<T>(BroadcastBytes);
  return Result;
}

// Broadcast with vector local index
template <typename Group, typename T, int Dimensions>
EnableIfNativeBroadcast<T> GroupBroadcast(T x, id<Dimensions> local_id) {
  if (Dimensions == 1) {
    return GroupBroadcast<Group>(x, local_id[0]);
  }
  using IdT = vec<size_t, Dimensions>;
  using OCLT = detail::ConvertToOpenCLType_t<T>;
  using WidenedT = WidenOpenCLTypeTo32_t<OCLT>;
  using OCLIdT = detail::ConvertToOpenCLType_t<IdT>;
  // SPIR-V linearizes work-item ids in the opposite order to SYCL
  IdT VecId;
  for (int i = 0; i < Dimensions; ++i) {
    VecId[i] = local_id[Dimensions - i - 1];
  }
  WidenedT OCLX = detail::convertDataToType<T, OCLT>(x);
  OCLIdT OCLId = detail::convertDataToType<IdT, OCLIdT>(VecId);
  return __spirv_GroupBroadcast(group_scope<Group>::value, OCLX, OCLId);
}

template <typename Group, typename T, int Dimensions>
EnableIfBitcastBroadcast<T> GroupBroadcast(T x, id<Dimensions> local_id) {
  using BroadcastT = ConvertToNativeBroadcastType_t<T>;
  auto BroadcastX = bit_cast<BroadcastT>(x);
  BroadcastT Result = GroupBroadcast<Group>(BroadcastX, local_id);
  return bit_cast<T>(Result);
}

template <typename Group, typename T, int Dimensions>
EnableIfGenericBroadcast<T> GroupBroadcast(T x, id<Dimensions> local_id) {
  if (Dimensions == 1) {
    return GroupBroadcast<Group>(x, local_id[0]);
  }
  T Result;
  char *XBytes = reinterpret_cast<char *>(&x);
  char *ResultBytes = reinterpret_cast<char *>(&Result);
  auto BroadcastBytes = [=](size_t Offset, size_t Size) {
    uint64_t BroadcastX, BroadcastResult;
    std::memcpy(&BroadcastX, XBytes + Offset, Size);
    BroadcastResult = GroupBroadcast<Group>(BroadcastX, local_id);
    std::memcpy(ResultBytes + Offset, &BroadcastResult, Size);
  };
  GenericCall<T>(BroadcastBytes);
  return Result;
}
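
// Example (illustrative): for a 2-D broadcast from local id (1, 2), the loop
// in the native overload above builds VecId = {2, 1} before calling
// __spirv_GroupBroadcast, matching SPIR-V's reversed linearization order.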

// Translate a SYCL memory_order into the SPIR-V memory semantics mask,
// applied across all storage classes the atomic may touch
template <typename T>
static inline constexpr
    typename std::enable_if<std::is_same<T, sycl::memory_order>::value,
                            __spv::MemorySemanticsMask::Flag>::type
    getMemorySemanticsMask(T Order) {
  __spv::MemorySemanticsMask::Flag SpvOrder = __spv::MemorySemanticsMask::None;
  switch (Order) {
  case T::relaxed:
    SpvOrder = __spv::MemorySemanticsMask::None;
    break;
  case T::__consume_unsupported:
  case T::acquire:
    SpvOrder = __spv::MemorySemanticsMask::Acquire;
    break;
  case T::release:
    SpvOrder = __spv::MemorySemanticsMask::Release;
    break;
  case T::acq_rel:
    SpvOrder = __spv::MemorySemanticsMask::AcquireRelease;
    break;
  case T::seq_cst:
    SpvOrder = __spv::MemorySemanticsMask::SequentiallyConsistent;
    break;
  }
  return static_cast<__spv::MemorySemanticsMask::Flag>(
      SpvOrder | __spv::MemorySemanticsMask::SubgroupMemory |
      __spv::MemorySemanticsMask::WorkgroupMemory |
      __spv::MemorySemanticsMask::CrossWorkgroupMemory);
}

static inline constexpr __spv::Scope::Flag getScope(memory_scope Scope) {
  switch (Scope) {
  case memory_scope::work_item:
    return __spv::Scope::Invocation;
  case memory_scope::sub_group:
    return __spv::Scope::Subgroup;
  case memory_scope::work_group:
    return __spv::Scope::Workgroup;
  case memory_scope::device:
    return __spv::Scope::Device;
  case memory_scope::system:
    return __spv::Scope::CrossDevice;
  }
}
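
// Worked example (illustrative): getMemorySemanticsMask(memory_order::acquire)
// yields Acquire | SubgroupMemory | WorkgroupMemory | CrossWorkgroupMemory,
// i.e. the ordering constraint applied to every storage class an atomic may
// touch, and getScope(memory_scope::device) yields __spv::Scope::Device.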

template <typename T, access::address_space AddressSpace>
inline typename detail::enable_if_t<std::is_integral<T>::value, T>
AtomicCompareExchange(multi_ptr<T, AddressSpace> MPtr, memory_scope Scope,
                      memory_order Success, memory_order Failure, T Desired,
                      T Expected) {
  auto SPIRVSuccess = getMemorySemanticsMask(Success);
  auto SPIRVFailure = getMemorySemanticsMask(Failure);
  auto SPIRVScope = getScope(Scope);
  auto *Ptr = MPtr.get();
  return __spirv_AtomicCompareExchange(Ptr, SPIRVScope, SPIRVSuccess,
                                       SPIRVFailure, Desired, Expected);
}

template <typename T, access::address_space AddressSpace>
inline typename detail::enable_if_t<std::is_floating_point<T>::value, T>
AtomicCompareExchange(multi_ptr<T, AddressSpace> MPtr, memory_scope Scope,
                      memory_order Success, memory_order Failure, T Desired,
                      T Expected) {
  using I = detail::make_unsinged_integer_t<T>;
  auto SPIRVSuccess = getMemorySemanticsMask(Success);
  auto SPIRVFailure = getMemorySemanticsMask(Failure);
  auto SPIRVScope = getScope(Scope);
  auto *PtrInt =
      reinterpret_cast<typename multi_ptr<I, AddressSpace>::pointer_t>(
          MPtr.get());
  I DesiredInt = bit_cast<I>(Desired);
  I ExpectedInt = bit_cast<I>(Expected);
  I ResultInt = __spirv_AtomicCompareExchange(
      PtrInt, SPIRVScope, SPIRVSuccess, SPIRVFailure, DesiredInt, ExpectedInt);
  return bit_cast<T>(ResultInt);
}
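
// Usage sketch (hypothetical): a compare-exchange loop built on the helpers
// above; FPtr is an assumed multi_ptr<float, global_space> in a kernel:
//   float Old = AtomicLoad(FPtr, memory_scope::device, memory_order::acquire);
//   float Prev;
//   do {
//     Prev = Old;
//     Old = AtomicCompareExchange(FPtr, memory_scope::device,
//                                 memory_order::acq_rel, memory_order::acquire,
//                                 Prev + 1.0f, Prev);
//   } while (Old != Prev); // success when the returned old value matches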

template <typename T, access::address_space AddressSpace>
inline typename detail::enable_if_t<std::is_integral<T>::value, T>
AtomicLoad(multi_ptr<T, AddressSpace> MPtr, memory_scope Scope,
           memory_order Order) {
  auto *Ptr = MPtr.get();
  auto SPIRVOrder = getMemorySemanticsMask(Order);
  auto SPIRVScope = getScope(Scope);
  return __spirv_AtomicLoad(Ptr, SPIRVScope, SPIRVOrder);
}

template <typename T, access::address_space AddressSpace>
inline typename detail::enable_if_t<std::is_floating_point<T>::value, T>
AtomicLoad(multi_ptr<T, AddressSpace> MPtr, memory_scope Scope,
           memory_order Order) {
  using I = detail::make_unsinged_integer_t<T>;
  auto *PtrInt =
      reinterpret_cast<typename multi_ptr<I, AddressSpace>::pointer_t>(
          MPtr.get());
  auto SPIRVOrder = getMemorySemanticsMask(Order);
  auto SPIRVScope = getScope(Scope);
  I ResultInt = __spirv_AtomicLoad(PtrInt, SPIRVScope, SPIRVOrder);
  return bit_cast<T>(ResultInt);
}

template <typename T, access::address_space AddressSpace>
inline typename detail::enable_if_t<std::is_integral<T>::value>
AtomicStore(multi_ptr<T, AddressSpace> MPtr, memory_scope Scope,
            memory_order Order, T Value) {
  auto *Ptr = MPtr.get();
  auto SPIRVOrder = getMemorySemanticsMask(Order);
  auto SPIRVScope = getScope(Scope);
  __spirv_AtomicStore(Ptr, SPIRVScope, SPIRVOrder, Value);
}

template <typename T, access::address_space AddressSpace>
inline typename detail::enable_if_t<std::is_floating_point<T>::value>
AtomicStore(multi_ptr<T, AddressSpace> MPtr, memory_scope Scope,
            memory_order Order, T Value) {
  using I = detail::make_unsinged_integer_t<T>;
  auto *PtrInt =
      reinterpret_cast<typename multi_ptr<I, AddressSpace>::pointer_t>(
          MPtr.get());
  auto SPIRVOrder = getMemorySemanticsMask(Order);
  auto SPIRVScope = getScope(Scope);
  I ValueInt = bit_cast<I>(Value);
  __spirv_AtomicStore(PtrInt, SPIRVScope, SPIRVOrder, ValueInt);
}
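
// The floating-point overloads above never operate on the float
// representation directly: they pun the value to a same-width unsigned
// integer with bit_cast, perform the integer atomic, and pun the result
// back, which round-trips the object representation exactly.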

template <typename T, access::address_space AddressSpace>
inline typename detail::enable_if_t<std::is_integral<T>::value, T>
AtomicExchange(multi_ptr<T, AddressSpace> MPtr, memory_scope Scope,
               memory_order Order, T Value) {
  auto *Ptr = MPtr.get();
  auto SPIRVOrder = getMemorySemanticsMask(Order);
  auto SPIRVScope = getScope(Scope);
  return __spirv_AtomicExchange(Ptr, SPIRVScope, SPIRVOrder, Value);
}

template <typename T, access::address_space AddressSpace>
inline typename detail::enable_if_t<std::is_floating_point<T>::value, T>
AtomicExchange(multi_ptr<T, AddressSpace> MPtr, memory_scope Scope,
               memory_order Order, T Value) {
  using I = detail::make_unsinged_integer_t<T>;
  auto *PtrInt =
      reinterpret_cast<typename multi_ptr<I, AddressSpace>::pointer_t>(
          MPtr.get());
  auto SPIRVOrder = getMemorySemanticsMask(Order);
  auto SPIRVScope = getScope(Scope);
  I ValueInt = bit_cast<I>(Value);
  I ResultInt =
      __spirv_AtomicExchange(PtrInt, SPIRVScope, SPIRVOrder, ValueInt);
  return bit_cast<T>(ResultInt);
}

template <typename T, access::address_space AddressSpace>
inline typename detail::enable_if_t<std::is_integral<T>::value, T>
AtomicIAdd(multi_ptr<T, AddressSpace> MPtr, memory_scope Scope,
           memory_order Order, T Value) {
  auto *Ptr = MPtr.get();
  auto SPIRVOrder = getMemorySemanticsMask(Order);
  auto SPIRVScope = getScope(Scope);
  return __spirv_AtomicIAdd(Ptr, SPIRVScope, SPIRVOrder, Value);
}

template <typename T, access::address_space AddressSpace>
inline typename detail::enable_if_t<std::is_integral<T>::value, T>
AtomicISub(multi_ptr<T, AddressSpace> MPtr, memory_scope Scope,
           memory_order Order, T Value) {
  auto *Ptr = MPtr.get();
  auto SPIRVOrder = getMemorySemanticsMask(Order);
  auto SPIRVScope = getScope(Scope);
  return __spirv_AtomicISub(Ptr, SPIRVScope, SPIRVOrder, Value);
}

template <typename T, access::address_space AddressSpace>
inline typename detail::enable_if_t<std::is_floating_point<T>::value, T>
AtomicFAdd(multi_ptr<T, AddressSpace> MPtr, memory_scope Scope,
           memory_order Order, T Value) {
  auto *Ptr = MPtr.get();
  auto SPIRVOrder = getMemorySemanticsMask(Order);
  auto SPIRVScope = getScope(Scope);
  return __spirv_AtomicFAddEXT(Ptr, SPIRVScope, SPIRVOrder, Value);
}
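
// AtomicFAdd maps to __spirv_AtomicFAddEXT (the SPV_EXT_shader_atomic_float_add
// extension), so no integer punning is needed. Usage sketch (hypothetical FPtr
// as above):
//   float Old = AtomicFAdd(FPtr, memory_scope::device,
//                          memory_order::relaxed, 2.0f); // returns prior value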

template <typename T, access::address_space AddressSpace>
inline typename detail::enable_if_t<std::is_integral<T>::value, T>
AtomicAnd(multi_ptr<T, AddressSpace> MPtr, memory_scope Scope,
          memory_order Order, T Value) {
  auto *Ptr = MPtr.get();
  auto SPIRVOrder = getMemorySemanticsMask(Order);
  auto SPIRVScope = getScope(Scope);
  return __spirv_AtomicAnd(Ptr, SPIRVScope, SPIRVOrder, Value);
}

template <typename T, access::address_space AddressSpace>
inline typename detail::enable_if_t<std::is_integral<T>::value, T>
AtomicOr(multi_ptr<T, AddressSpace> MPtr, memory_scope Scope,
         memory_order Order, T Value) {
  auto *Ptr = MPtr.get();
  auto SPIRVOrder = getMemorySemanticsMask(Order);
  auto SPIRVScope = getScope(Scope);
  return __spirv_AtomicOr(Ptr, SPIRVScope, SPIRVOrder, Value);
}

template <typename T, access::address_space AddressSpace>
inline typename detail::enable_if_t<std::is_integral<T>::value, T>
AtomicXor(multi_ptr<T, AddressSpace> MPtr, memory_scope Scope,
          memory_order Order, T Value) {
  auto *Ptr = MPtr.get();
  auto SPIRVOrder = getMemorySemanticsMask(Order);
  auto SPIRVScope = getScope(Scope);
  return __spirv_AtomicXor(Ptr, SPIRVScope, SPIRVOrder, Value);
}

template <typename T, access::address_space AddressSpace>
inline typename detail::enable_if_t<std::is_integral<T>::value, T>
AtomicMin(multi_ptr<T, AddressSpace> MPtr, memory_scope Scope,
          memory_order Order, T Value) {
  auto *Ptr = MPtr.get();
  auto SPIRVOrder = getMemorySemanticsMask(Order);
  auto SPIRVScope = getScope(Scope);
  return __spirv_AtomicMin(Ptr, SPIRVScope, SPIRVOrder, Value);
}

template <typename T, access::address_space AddressSpace>
inline typename detail::enable_if_t<std::is_floating_point<T>::value, T>
AtomicMin(multi_ptr<T, AddressSpace> MPtr, memory_scope Scope,
          memory_order Order, T Value) {
  auto *Ptr = MPtr.get();
  auto SPIRVOrder = getMemorySemanticsMask(Order);
  auto SPIRVScope = getScope(Scope);
  return __spirv_AtomicFMinEXT(Ptr, SPIRVScope, SPIRVOrder, Value);
}

template <typename T, access::address_space AddressSpace>
inline typename detail::enable_if_t<std::is_integral<T>::value, T>
AtomicMax(multi_ptr<T, AddressSpace> MPtr, memory_scope Scope,
          memory_order Order, T Value) {
  auto *Ptr = MPtr.get();
  auto SPIRVOrder = getMemorySemanticsMask(Order);
  auto SPIRVScope = getScope(Scope);
  return __spirv_AtomicMax(Ptr, SPIRVScope, SPIRVOrder, Value);
}

template <typename T, access::address_space AddressSpace>
inline typename detail::enable_if_t<std::is_floating_point<T>::value, T>
AtomicMax(multi_ptr<T, AddressSpace> MPtr, memory_scope Scope,
          memory_order Order, T Value) {
  auto *Ptr = MPtr.get();
  auto SPIRVOrder = getMemorySemanticsMask(Order);
  auto SPIRVScope = getScope(Scope);
  return __spirv_AtomicFMaxEXT(Ptr, SPIRVScope, SPIRVOrder, Value);
}

// Native shuffles map directly to a shuffle intrinsic:
// - The Intel SPIR-V extension natively supports all arithmetic types
// - The CUDA shfl intrinsics do not support vectors, and are limited to int32
#ifndef __NVPTX__
template <typename T>
using EnableIfNativeShuffle =
    detail::enable_if_t<detail::is_arithmetic<T>::value, T>;
#else
template <typename T>
using EnableIfNativeShuffle = detail::enable_if_t<
    std::is_integral<T>::value && (sizeof(T) <= sizeof(int32_t)), T>;

template <typename T>
using EnableIfVectorShuffle =
    detail::enable_if_t<detail::is_vector_arithmetic<T>::value, T>;
#endif

#ifdef __NVPTX__
inline uint32_t membermask() {
  // Shrink the full warp mask to the current sub-group size so that exited
  // lanes are excluded from the sync
  uint32_t FULL_MASK = 0xFFFFFFFF;
  uint32_t max_size = __spirv_SubgroupMaxSize();
  uint32_t sg_size = __spirv_SubgroupSize();
  return FULL_MASK >> (max_size - sg_size);
}
#endif

template <typename T>
EnableIfNativeShuffle<T> SubgroupShuffle(T x, id<1> local_id) {
#ifndef __NVPTX__
  using OCLT = detail::ConvertToOpenCLType_t<T>;
  return __spirv_SubgroupShuffleINTEL(OCLT(x),
                                      static_cast<uint32_t>(local_id.get(0)));
#else
  return __nvvm_shfl_sync_idx_i32(membermask(), x, local_id.get(0), 0x1f);
#endif
}

template <typename T>
EnableIfNativeShuffle<T> SubgroupShuffleXor(T x, id<1> local_id) {
#ifndef __NVPTX__
  using OCLT = detail::ConvertToOpenCLType_t<T>;
  return __spirv_SubgroupShuffleXorINTEL(
      OCLT(x), static_cast<uint32_t>(local_id.get(0)));
#else
  return __nvvm_shfl_sync_bfly_i32(membermask(), x, local_id.get(0), 0x1f);
#endif
}

template <typename T>
EnableIfNativeShuffle<T> SubgroupShuffleDown(T x, uint32_t delta) {
#ifndef __NVPTX__
  using OCLT = detail::ConvertToOpenCLType_t<T>;
  return __spirv_SubgroupShuffleDownINTEL(OCLT(x), OCLT(x), delta);
#else
  return __nvvm_shfl_sync_down_i32(membermask(), x, delta, 0x1f);
#endif
}

template <typename T>
EnableIfNativeShuffle<T> SubgroupShuffleUp(T x, uint32_t delta) {
#ifndef __NVPTX__
  using OCLT = detail::ConvertToOpenCLType_t<T>;
  return __spirv_SubgroupShuffleUpINTEL(OCLT(x), OCLT(x), delta);
#else
  return __nvvm_shfl_sync_up_i32(membermask(), x, delta, 0);
#endif
}
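
// Usage sketch (hypothetical device code): the XOR shuffle implements the
// butterfly exchange at the heart of a sub-group reduction:
//   int partial = ...;
//   for (uint32_t mask = 16; mask > 0; mask /= 2) // assumes sub-group size 32
//     partial += SubgroupShuffleXor(partial, id<1>{mask});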

// Vector shuffles (CUDA backend only): implemented element-wise on top of the
// scalar shuffles above, since the shfl intrinsics do not support vectors
#ifdef __NVPTX__
template <typename T>
EnableIfVectorShuffle<T> SubgroupShuffle(T x, id<1> local_id) {
  T result;
  for (int s = 0; s < x.get_size(); ++s) {
    result[s] = SubgroupShuffle(x[s], local_id);
  }
  return result;
}

template <typename T>
EnableIfVectorShuffle<T> SubgroupShuffleXor(T x, id<1> local_id) {
  T result;
  for (int s = 0; s < x.get_size(); ++s) {
    result[s] = SubgroupShuffleXor(x[s], local_id);
  }
  return result;
}

template <typename T>
EnableIfVectorShuffle<T> SubgroupShuffleDown(T x, uint32_t delta) {
  T result;
  for (int s = 0; s < x.get_size(); ++s) {
    result[s] = SubgroupShuffleDown(x[s], delta);
  }
  return result;
}

template <typename T>
EnableIfVectorShuffle<T> SubgroupShuffleUp(T x, uint32_t delta) {
  T result;
  for (int s = 0; s < x.get_size(); ++s) {
    result[s] = SubgroupShuffleUp(x[s], delta);
  }
  return result;
}
#endif

// Bitcast shuffles can use a single shuffle intrinsic, but require
// type-punning via an appropriately sized integer type
#ifndef __NVPTX__
template <typename T>
using EnableIfBitcastShuffle =
    detail::enable_if_t<!detail::is_arithmetic<T>::value &&
                            (std::is_trivially_copyable<T>::value &&
                             (sizeof(T) == 1 || sizeof(T) == 2 ||
                              sizeof(T) == 4 || sizeof(T) == 8)),
                        T>;
#else
template <typename T>
using EnableIfBitcastShuffle = detail::enable_if_t<
    !(std::is_integral<T>::value && (sizeof(T) <= sizeof(int32_t))) &&
        !detail::is_vector_arithmetic<T>::value &&
        (std::is_trivially_copyable<T>::value &&
         (sizeof(T) == 1 || sizeof(T) == 2 || sizeof(T) == 4)),
    T>;
#endif

template <typename T>
using ConvertToNativeShuffleType_t = select_cl_scalar_integral_unsigned_t<T>;

template <typename T>
EnableIfBitcastShuffle<T> SubgroupShuffle(T x, id<1> local_id) {
  using ShuffleT = ConvertToNativeShuffleType_t<T>;
  auto ShuffleX = bit_cast<ShuffleT>(x);
#ifndef __NVPTX__
  ShuffleT Result = __spirv_SubgroupShuffleINTEL(
      ShuffleX, static_cast<uint32_t>(local_id.get(0)));
#else
  ShuffleT Result =
      __nvvm_shfl_sync_idx_i32(membermask(), ShuffleX, local_id.get(0), 0x1f);
#endif
  return bit_cast<T>(Result);
}

template <typename T>
EnableIfBitcastShuffle<T> SubgroupShuffleXor(T x, id<1> local_id) {
  using ShuffleT = ConvertToNativeShuffleType_t<T>;
  auto ShuffleX = bit_cast<ShuffleT>(x);
#ifndef __NVPTX__
  ShuffleT Result = __spirv_SubgroupShuffleXorINTEL(
      ShuffleX, static_cast<uint32_t>(local_id.get(0)));
#else
  ShuffleT Result =
      __nvvm_shfl_sync_bfly_i32(membermask(), ShuffleX, local_id.get(0), 0x1f);
#endif
  return bit_cast<T>(Result);
}

template <typename T>
EnableIfBitcastShuffle<T> SubgroupShuffleDown(T x, uint32_t delta) {
  using ShuffleT = ConvertToNativeShuffleType_t<T>;
  auto ShuffleX = bit_cast<ShuffleT>(x);
#ifndef __NVPTX__
  ShuffleT Result = __spirv_SubgroupShuffleDownINTEL(ShuffleX, ShuffleX, delta);
#else
  ShuffleT Result =
      __nvvm_shfl_sync_down_i32(membermask(), ShuffleX, delta, 0x1f);
#endif
  return bit_cast<T>(Result);
}

template <typename T>
EnableIfBitcastShuffle<T> SubgroupShuffleUp(T x, uint32_t delta) {
  using ShuffleT = ConvertToNativeShuffleType_t<T>;
  auto ShuffleX = bit_cast<ShuffleT>(x);
#ifndef __NVPTX__
  ShuffleT Result = __spirv_SubgroupShuffleUpINTEL(ShuffleX, ShuffleX, delta);
#else
  ShuffleT Result = __nvvm_shfl_sync_up_i32(membermask(), ShuffleX, delta, 0);
#endif
  return bit_cast<T>(Result);
}
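
// Dispatch example (illustrative): a 4-byte trivially copyable struct such as
//   struct RGBA { uint8_t r, g, b, a; };
// is not arithmetic, so it cannot use the native path, but it matches the
// bitcast path above: it is punned to one 32-bit unsigned integer, moved with
// a single shuffle, and punned back.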

// Generic shuffles may require multiple calls to the intrinsics, and should
// use the fewest shuffles possible (see GenericCall above)
#ifndef __NVPTX__
template <typename T>
using EnableIfGenericShuffle =
    detail::enable_if_t<!detail::is_arithmetic<T>::value &&
                            !(std::is_trivially_copyable<T>::value &&
                              (sizeof(T) == 1 || sizeof(T) == 2 ||
                               sizeof(T) == 4 || sizeof(T) == 8)),
                        T>;
#else
template <typename T>
using EnableIfGenericShuffle = detail::enable_if_t<
    !(std::is_integral<T>::value && (sizeof(T) <= sizeof(int32_t))) &&
        !detail::is_vector_arithmetic<T>::value &&
        !(std::is_trivially_copyable<T>::value &&
          (sizeof(T) == 1 || sizeof(T) == 2 || sizeof(T) == 4)),
    T>;
#endif

template <typename T>
EnableIfGenericShuffle<T> SubgroupShuffle(T x, id<1> local_id) {
  T Result;
  char *XBytes = reinterpret_cast<char *>(&x);
  char *ResultBytes = reinterpret_cast<char *>(&Result);
  auto ShuffleBytes = [=](size_t Offset, size_t Size) {
    ShuffleChunkT ShuffleX, ShuffleResult;
    std::memcpy(&ShuffleX, XBytes + Offset, Size);
    ShuffleResult = SubgroupShuffle(ShuffleX, local_id);
    std::memcpy(ResultBytes + Offset, &ShuffleResult, Size);
  };
  GenericCall<T>(ShuffleBytes);
  return Result;
}

template <typename T>
EnableIfGenericShuffle<T> SubgroupShuffleXor(T x, id<1> local_id) {
  T Result;
  char *XBytes = reinterpret_cast<char *>(&x);
  char *ResultBytes = reinterpret_cast<char *>(&Result);
  auto ShuffleBytes = [=](size_t Offset, size_t Size) {
    ShuffleChunkT ShuffleX, ShuffleResult;
    std::memcpy(&ShuffleX, XBytes + Offset, Size);
    ShuffleResult = SubgroupShuffleXor(ShuffleX, local_id);
    std::memcpy(ResultBytes + Offset, &ShuffleResult, Size);
  };
  GenericCall<T>(ShuffleBytes);
  return Result;
}

template <typename T>
EnableIfGenericShuffle<T> SubgroupShuffleDown(T x, uint32_t delta) {
  T Result;
  char *XBytes = reinterpret_cast<char *>(&x);
  char *ResultBytes = reinterpret_cast<char *>(&Result);
  auto ShuffleBytes = [=](size_t Offset, size_t Size) {
    ShuffleChunkT ShuffleX, ShuffleResult;
    std::memcpy(&ShuffleX, XBytes + Offset, Size);
    ShuffleResult = SubgroupShuffleDown(ShuffleX, delta);
    std::memcpy(ResultBytes + Offset, &ShuffleResult, Size);
  };
  GenericCall<T>(ShuffleBytes);
  return Result;
}

template <typename T>
EnableIfGenericShuffle<T> SubgroupShuffleUp(T x, uint32_t delta) {
  T Result;
  char *XBytes = reinterpret_cast<char *>(&x);
  char *ResultBytes = reinterpret_cast<char *>(&Result);
  auto ShuffleBytes = [=](size_t Offset, size_t Size) {
    ShuffleChunkT ShuffleX, ShuffleResult;
    std::memcpy(&ShuffleX, XBytes + Offset, Size);
    ShuffleResult = SubgroupShuffleUp(ShuffleX, delta);
    std::memcpy(ResultBytes + Offset, &ShuffleResult, Size);
  };
  GenericCall<T>(ShuffleBytes);
  return Result;
}
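
// Worked example (illustrative): on the CUDA backend a 12-byte trivially
// copyable type is too large for the bitcast path, so the generic path above
// moves it as three 32-bit chunks (ShuffleChunkT = uint32_t):
// ShuffleBytes(0, 4), ShuffleBytes(4, 4), ShuffleBytes(8, 4).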

#endif // __SYCL_DEVICE_ONLY__