DPC++ Runtime
Runtime libraries for oneAPI DPC++
spirv.hpp
1 //===-- spirv.hpp - Helpers to generate SPIR-V instructions ----*- C++ -*--===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #pragma once
10 #include <CL/__spirv/spirv_ops.hpp>
11 #include <CL/__spirv/spirv_types.hpp>
12 #include <CL/__spirv/spirv_vars.hpp>
13 #include <cstring>
14 #include <sycl/detail/generic_type_traits.hpp>
15 #include <sycl/detail/helpers.hpp>
16 #include <sycl/detail/type_traits.hpp>
17 #include <sycl/id.hpp>
18 #include <sycl/memory_enums.hpp>
19 
20 #ifdef __SYCL_DEVICE_ONLY__
21 namespace sycl {
22 __SYCL_INLINE_VER_NAMESPACE(_V1) {
23 namespace ext {
24 namespace oneapi {
25 struct sub_group;
26 } // namespace oneapi
27 } // namespace ext
28 
29 namespace detail {
30 
31 // Helper for reinterpret casting the decorated pointer inside a multi_ptr
32 // without losing the decorations.
33 template <typename ToT, typename FromT, access::address_space Space,
34  access::decorated IsDecorated>
35 inline typename multi_ptr<ToT, Space, access::decorated::yes>::pointer
36 GetMultiPtrDecoratedAs(multi_ptr<FromT, Space, IsDecorated> MPtr) {
37  if constexpr (IsDecorated == access::decorated::legacy)
38  return reinterpret_cast<
39  typename multi_ptr<ToT, Space, access::decorated::yes>::pointer>(
40  MPtr.get());
41  else
42  return reinterpret_cast<
43  typename multi_ptr<ToT, Space, access::decorated::yes>::pointer>(
44  MPtr.get_decorated());
45 }
46 
47 namespace spirv {
48 
49 template <typename Group> struct group_scope {};
50 
51 template <int Dimensions> struct group_scope<group<Dimensions>> {
52  static constexpr __spv::Scope::Flag value = __spv::Scope::Flag::Workgroup;
53 };
54 
55 template <> struct group_scope<::sycl::ext::oneapi::sub_group> {
56  static constexpr __spv::Scope::Flag value = __spv::Scope::Flag::Subgroup;
57 };
58 
59 // Generic shuffles and broadcasts may require multiple calls to
60 // intrinsics, and should use the fewest calls possible:
61 // - Loop over chunks until remaining bytes < chunk size
62 // - At most one 32-bit, 16-bit and 8-bit chunk left over
63 #ifndef __NVPTX__
64 using ShuffleChunkT = uint64_t;
65 #else
66 using ShuffleChunkT = uint32_t;
67 #endif
68 template <typename T, typename Functor>
69 void GenericCall(const Functor &ApplyToBytes) {
70  if (sizeof(T) >= sizeof(ShuffleChunkT)) {
71 #pragma unroll
72  for (size_t Offset = 0; Offset + sizeof(ShuffleChunkT) <= sizeof(T);
73  Offset += sizeof(ShuffleChunkT)) {
74  ApplyToBytes(Offset, sizeof(ShuffleChunkT));
75  }
76  }
77  if (sizeof(ShuffleChunkT) >= sizeof(uint64_t)) {
78  if (sizeof(T) % sizeof(uint64_t) >= sizeof(uint32_t)) {
79  size_t Offset = sizeof(T) / sizeof(uint64_t) * sizeof(uint64_t);
80  ApplyToBytes(Offset, sizeof(uint32_t));
81  }
82  }
83  if (sizeof(ShuffleChunkT) >= sizeof(uint32_t)) {
84  if (sizeof(T) % sizeof(uint32_t) >= sizeof(uint16_t)) {
85  size_t Offset = sizeof(T) / sizeof(uint32_t) * sizeof(uint32_t);
86  ApplyToBytes(Offset, sizeof(uint16_t));
87  }
88  }
89  if (sizeof(ShuffleChunkT) >= sizeof(uint16_t)) {
90  if (sizeof(T) % sizeof(uint16_t) >= sizeof(uint8_t)) {
91  size_t Offset = sizeof(T) / sizeof(uint16_t) * sizeof(uint16_t);
92  ApplyToBytes(Offset, sizeof(uint8_t));
93  }
94  }
95 }
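GenericCall above decomposes an arbitrarily sized trivially copyable object into the largest chunks the shuffle and broadcast intrinsics accept. Below is a minimal host-side sketch of the same decomposition; the ChunkT alias, Payload struct and generic_call_sketch name are illustrative stand-ins, not part of this header. For a hypothetical 15-byte payload it reports the (offset, size) pairs (0, 8), (8, 4), (12, 2), (14, 1).

#include <cstddef>
#include <cstdint>
#include <cstdio>

// Illustrative stand-ins; the real code uses ShuffleChunkT and the type T
// passed to GenericCall.
using ChunkT = std::uint64_t;
struct Payload { char bytes[15]; }; // hypothetical 15-byte type

template <typename T, typename Functor>
void generic_call_sketch(const Functor &ApplyToBytes) {
  // Whole 64-bit chunks first.
  if (sizeof(T) >= sizeof(ChunkT)) {
    for (std::size_t Offset = 0; Offset + sizeof(ChunkT) <= sizeof(T);
         Offset += sizeof(ChunkT))
      ApplyToBytes(Offset, sizeof(ChunkT));
  }
  // Then at most one 32-bit, one 16-bit and one 8-bit tail chunk.
  if (sizeof(T) % sizeof(std::uint64_t) >= sizeof(std::uint32_t))
    ApplyToBytes(sizeof(T) / sizeof(std::uint64_t) * sizeof(std::uint64_t),
                 sizeof(std::uint32_t));
  if (sizeof(T) % sizeof(std::uint32_t) >= sizeof(std::uint16_t))
    ApplyToBytes(sizeof(T) / sizeof(std::uint32_t) * sizeof(std::uint32_t),
                 sizeof(std::uint16_t));
  if (sizeof(T) % sizeof(std::uint16_t) >= sizeof(std::uint8_t))
    ApplyToBytes(sizeof(T) / sizeof(std::uint16_t) * sizeof(std::uint16_t),
                 sizeof(std::uint8_t));
}

int main() {
  // For the 15-byte Payload this prints (0, 8), (8, 4), (12, 2), (14, 1).
  generic_call_sketch<Payload>([](std::size_t Offset, std::size_t Size) {
    std::printf("(%zu, %zu)\n", Offset, Size);
  });
  return 0;
}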
96 
97 template <typename Group> bool GroupAll(bool pred) {
98  return __spirv_GroupAll(group_scope<Group>::value, pred);
99 }
100 
101 template <typename Group> bool GroupAny(bool pred) {
102  return __spirv_GroupAny(group_scope<Group>::value, pred);
103 }
104 
105 // Native broadcasts map directly to a SPIR-V GroupBroadcast intrinsic
106 // FIXME: Do not special-case for half once all backends support all data types.
107 template <typename T>
108 using is_native_broadcast = bool_constant<detail::is_arithmetic<T>::value &&
109  !std::is_same<T, half>::value>;
110 
111 template <typename T, typename IdT = size_t>
112 using EnableIfNativeBroadcast = detail::enable_if_t<
113  is_native_broadcast<T>::value && std::is_integral<IdT>::value, T>;
114 
115 // Bitcast broadcasts can be implemented using a single SPIR-V GroupBroadcast
116 // intrinsic, but require type-punning via an appropriate integer type
117 template <typename T>
118 using is_bitcast_broadcast = bool_constant<
119  !is_native_broadcast<T>::value && std::is_trivially_copyable<T>::value &&
120  (sizeof(T) == 1 || sizeof(T) == 2 || sizeof(T) == 4 || sizeof(T) == 8)>;
121 
122 template <typename T, typename IdT = size_t>
123 using EnableIfBitcastBroadcast = detail::enable_if_t<
124  is_bitcast_broadcast<T>::value && std::is_integral<IdT>::value, T>;
125 
126 template <typename T>
127 using ConvertToNativeBroadcastType_t = select_cl_scalar_integral_unsigned_t<T>;
128 
129 // Generic broadcasts may require multiple calls to SPIR-V GroupBroadcast
130 // intrinsics, and should use the fewest broadcasts possible
131 // - Loop over 64-bit chunks until remaining bytes < 64-bit
132 // - At most one 32-bit, 16-bit and 8-bit chunk left over
133 template <typename T>
134 using is_generic_broadcast =
135  bool_constant<!is_native_broadcast<T>::value &&
136  !is_bitcast_broadcast<T>::value &&
137  std::is_trivially_copyable<T>::value>;
138 
139 template <typename T, typename IdT = size_t>
140 using EnableIfGenericBroadcast = detail::enable_if_t<
141  is_generic_broadcast<T>::value && std::is_integral<IdT>::value, T>;
142 
143 // FIXME: Disable widening once all backends support all data types.
144 template <typename T>
145 using WidenOpenCLTypeTo32_t = conditional_t<
146  std::is_same<T, cl_char>() || std::is_same<T, cl_short>(), cl_int,
147  conditional_t<std::is_same<T, cl_uchar>() || std::is_same<T, cl_ushort>(),
148  cl_uint, T>>;
149 
150 // Broadcast with scalar local index
151 // Work-group supports any integral type
152 // Sub-group currently supports only uint32_t
153 template <typename Group> struct GroupId {
154  using type = size_t;
155 };
156 template <> struct GroupId<::sycl::ext::oneapi::sub_group> {
157  using type = uint32_t;
158 };
159 template <typename Group, typename T, typename IdT>
160 EnableIfNativeBroadcast<T, IdT> GroupBroadcast(T x, IdT local_id) {
161  using GroupIdT = typename GroupId<Group>::type;
162  GroupIdT GroupLocalId = static_cast<GroupIdT>(local_id);
163  using OCLT = detail::ConvertToOpenCLType_t<T>;
164  using WidenedT = WidenOpenCLTypeTo32_t<OCLT>;
165  using OCLIdT = detail::ConvertToOpenCLType_t<GroupIdT>;
166  WidenedT OCLX = detail::convertDataToType<T, OCLT>(x);
167  OCLIdT OCLId = detail::convertDataToType<GroupIdT, OCLIdT>(GroupLocalId);
168  return __spirv_GroupBroadcast(group_scope<Group>::value, OCLX, OCLId);
169 }
170 template <typename Group, typename T, typename IdT>
171 EnableIfBitcastBroadcast<T, IdT> GroupBroadcast(T x, IdT local_id) {
172  using BroadcastT = ConvertToNativeBroadcastType_t<T>;
173  auto BroadcastX = bit_cast<BroadcastT>(x);
174  BroadcastT Result = GroupBroadcast<Group>(BroadcastX, local_id);
175  return bit_cast<T>(Result);
176 }
177 template <typename Group, typename T, typename IdT>
178 EnableIfGenericBroadcast<T, IdT> GroupBroadcast(T x, IdT local_id) {
179  // Initialize with x to support type T without default constructor
180  T Result = x;
181  char *XBytes = reinterpret_cast<char *>(&x);
182  char *ResultBytes = reinterpret_cast<char *>(&Result);
183  auto BroadcastBytes = [=](size_t Offset, size_t Size) {
184  uint64_t BroadcastX, BroadcastResult;
185  std::memcpy(&BroadcastX, XBytes + Offset, Size);
186  BroadcastResult = GroupBroadcast<Group>(BroadcastX, local_id);
187  std::memcpy(ResultBytes + Offset, &BroadcastResult, Size);
188  };
189  GenericCall<T>(BroadcastBytes);
190  return Result;
191 }
192 
193 // Broadcast with vector local index
194 template <typename Group, typename T, int Dimensions>
195 EnableIfNativeBroadcast<T> GroupBroadcast(T x, id<Dimensions> local_id) {
196  if (Dimensions == 1) {
197  return GroupBroadcast<Group>(x, local_id[0]);
198  }
199  using IdT = vec<size_t, Dimensions>;
200  using OCLT = detail::ConvertToOpenCLType_t<T>;
201  using WidenedT = WidenOpenCLTypeTo32_t<OCLT>;
202  using OCLIdT = detail::ConvertToOpenCLType_t<IdT>;
203  IdT VecId;
204  for (int i = 0; i < Dimensions; ++i) {
205  VecId[i] = local_id[Dimensions - i - 1];
206  }
207  WidenedT OCLX = detail::convertDataToType<T, OCLT>(x);
208  OCLIdT OCLId = detail::convertDataToType<IdT, OCLIdT>(VecId);
209  return __spirv_GroupBroadcast(group_scope<Group>::value, OCLX, OCLId);
210 }
211 template <typename Group, typename T, int Dimensions>
212 EnableIfBitcastBroadcast<T> GroupBroadcast(T x, id<Dimensions> local_id) {
213  using BroadcastT = ConvertToNativeBroadcastType_t<T>;
214  auto BroadcastX = bit_cast<BroadcastT>(x);
215  BroadcastT Result = GroupBroadcast<Group>(BroadcastX, local_id);
216  return bit_cast<T>(Result);
217 }
218 template <typename Group, typename T, int Dimensions>
219 EnableIfGenericBroadcast<T> GroupBroadcast(T x, id<Dimensions> local_id) {
220  if (Dimensions == 1) {
221  return GroupBroadcast<Group>(x, local_id[0]);
222  }
223  // Initialize with x to support type T without default constructor
224  T Result = x;
225  char *XBytes = reinterpret_cast<char *>(&x);
226  char *ResultBytes = reinterpret_cast<char *>(&Result);
227  auto BroadcastBytes = [=](size_t Offset, size_t Size) {
228  uint64_t BroadcastX, BroadcastResult;
229  std::memcpy(&BroadcastX, XBytes + Offset, Size);
230  BroadcastResult = GroupBroadcast<Group>(BroadcastX, local_id);
231  std::memcpy(ResultBytes + Offset, &BroadcastResult, Size);
232  };
233  GenericCall<T>(BroadcastBytes);
234  return Result;
235 }
236 
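The GroupBroadcast overloads above back the SYCL 2020 group_broadcast free function. The following is a minimal device-side sketch of how they are reached from user code; the queue, data pointer and work-group size are assumptions for illustration, data is assumed to be USM memory, and n a multiple of the work-group size.

#include <sycl/sycl.hpp>

// Broadcast the value held by local linear id 0 to every work-item in its
// work-group. Depending on T, the library dispatches to the native,
// bitcast or generic GroupBroadcast path defined above.
void broadcast_leader(sycl::queue &q, float *data, size_t n) {
  q.parallel_for(sycl::nd_range<1>{n, 64}, [=](sycl::nd_item<1> it) {
     auto g = it.get_group();
     float v = data[it.get_global_id(0)];
     data[it.get_global_id(0)] = sycl::group_broadcast(g, v, 0);
   }).wait();
}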
237 // A single happens-before relation means the semantics should always apply to all spaces
238 // Although consume is unsupported, forwarding it to acquire is valid
239 template <typename T>
240 static inline constexpr
241  typename std::enable_if<std::is_same<T, sycl::memory_order>::value,
242  __spv::MemorySemanticsMask::Flag>::type
243  getMemorySemanticsMask(T Order) {
244  __spv::MemorySemanticsMask::Flag SpvOrder = __spv::MemorySemanticsMask::None;
245  switch (Order) {
246  case T::relaxed:
247  SpvOrder = __spv::MemorySemanticsMask::None;
248  break;
249  case T::__consume_unsupported:
250  case T::acquire:
251  SpvOrder = __spv::MemorySemanticsMask::Acquire;
252  break;
253  case T::release:
254  SpvOrder = __spv::MemorySemanticsMask::Release;
255  break;
256  case T::acq_rel:
257  SpvOrder = __spv::MemorySemanticsMask::AcquireRelease;
258  break;
259  case T::seq_cst:
260  SpvOrder = __spv::MemorySemanticsMask::SequentiallyConsistent;
261  break;
262  }
263  return static_cast<__spv::MemorySemanticsMask::Flag>(
264  SpvOrder | __spv::MemorySemanticsMask::SubgroupMemory |
265  __spv::MemorySemanticsMask::WorkgroupMemory |
266  __spv::MemorySemanticsMask::CrossWorkgroupMemory);
267 }
268 
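As a worked example of the "semantics apply to all spaces" policy: for memory_order::acquire the returned mask combines the Acquire ordering bit with the Subgroup, Workgroup and CrossWorkgroup memory bits. The sketch below uses bit values quoted from the SPIR-V specification (Acquire = 0x2, SubgroupMemory = 0x80, WorkgroupMemory = 0x100, CrossWorkgroupMemory = 0x200), not from this header.

#include <cstdint>

int main() {
  // SPIR-V MemorySemantics bits, per the SPIR-V specification.
  constexpr std::uint32_t Acquire = 0x2;
  constexpr std::uint32_t SubgroupMemory = 0x80;
  constexpr std::uint32_t WorkgroupMemory = 0x100;
  constexpr std::uint32_t CrossWorkgroupMemory = 0x200;

  // memory_order::acquire -> ordering bit plus all storage-class bits.
  constexpr std::uint32_t Mask =
      Acquire | SubgroupMemory | WorkgroupMemory | CrossWorkgroupMemory;
  static_assert(Mask == 0x382, "acquire semantics across all address spaces");
  return 0;
}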
269 static inline constexpr __spv::Scope::Flag getScope(memory_scope Scope) {
270  switch (Scope) {
271  case memory_scope::work_item:
272  return __spv::Scope::Invocation;
273  case memory_scope::sub_group:
274  return __spv::Scope::Subgroup;
275  case memory_scope::work_group:
276  return __spv::Scope::Workgroup;
277  case memory_scope::device:
278  return __spv::Scope::Device;
279  case memory_scope::system:
280  return __spv::Scope::CrossDevice;
281  }
282 }
283 
284 template <typename T, access::address_space AddressSpace,
285  access::decorated IsDecorated>
286 inline typename detail::enable_if_t<std::is_integral<T>::value, T>
287 AtomicCompareExchange(multi_ptr<T, AddressSpace, IsDecorated> MPtr,
288  memory_scope Scope, memory_order Success,
289  memory_order Failure, T Desired, T Expected) {
290  auto SPIRVSuccess = getMemorySemanticsMask(Success);
291  auto SPIRVFailure = getMemorySemanticsMask(Failure);
292  auto SPIRVScope = getScope(Scope);
293  auto *Ptr = GetMultiPtrDecoratedAs<T>(MPtr);
294  return __spirv_AtomicCompareExchange(Ptr, SPIRVScope, SPIRVSuccess,
295  SPIRVFailure, Desired, Expected);
296 }
297 
298 template <typename T, access::address_space AddressSpace,
299  access::decorated IsDecorated>
300 inline typename detail::enable_if_t<std::is_floating_point<T>::value, T>
301 AtomicCompareExchange(multi_ptr<T, AddressSpace, IsDecorated> MPtr,
302  memory_scope Scope, memory_order Success,
303  memory_order Failure, T Desired, T Expected) {
304  using I = detail::make_unsinged_integer_t<T>;
305  auto SPIRVSuccess = getMemorySemanticsMask(Success);
306  auto SPIRVFailure = getMemorySemanticsMask(Failure);
307  auto SPIRVScope = getScope(Scope);
308  auto *PtrInt = GetMultiPtrDecoratedAs<I>(MPtr);
309  I DesiredInt = bit_cast<I>(Desired);
310  I ExpectedInt = bit_cast<I>(Expected);
311  I ResultInt = __spirv_AtomicCompareExchange(
312  PtrInt, SPIRVScope, SPIRVSuccess, SPIRVFailure, DesiredInt, ExpectedInt);
313  return bit_cast<T>(ResultInt);
314 }
315 
316 template <typename T, access::address_space AddressSpace,
317  access::decorated IsDecorated>
318 inline typename detail::enable_if_t<std::is_integral<T>::value, T>
319 AtomicLoad(multi_ptr<T, AddressSpace, IsDecorated> MPtr, memory_scope Scope,
320  memory_order Order) {
321  auto *Ptr = GetMultiPtrDecoratedAs<T>(MPtr);
322  auto SPIRVOrder = getMemorySemanticsMask(Order);
323  auto SPIRVScope = getScope(Scope);
324  return __spirv_AtomicLoad(Ptr, SPIRVScope, SPIRVOrder);
325 }
326 
327 template <typename T, access::address_space AddressSpace,
328  access::decorated IsDecorated>
329 inline typename detail::enable_if_t<std::is_floating_point<T>::value, T>
330 AtomicLoad(multi_ptr<T, AddressSpace, IsDecorated> MPtr, memory_scope Scope,
331  memory_order Order) {
332  using I = detail::make_unsinged_integer_t<T>;
333  auto *PtrInt = GetMultiPtrDecoratedAs<I>(MPtr);
334  auto SPIRVOrder = getMemorySemanticsMask(Order);
335  auto SPIRVScope = getScope(Scope);
336  I ResultInt = __spirv_AtomicLoad(PtrInt, SPIRVScope, SPIRVOrder);
337  return bit_cast<T>(ResultInt);
338 }
339 
340 template <typename T, access::address_space AddressSpace,
341  access::decorated IsDecorated>
342 inline typename detail::enable_if_t<std::is_integral<T>::value>
343 AtomicStore(multi_ptr<T, AddressSpace, IsDecorated> MPtr, memory_scope Scope,
344  memory_order Order, T Value) {
345  auto *Ptr = GetMultiPtrDecoratedAs<T>(MPtr);
346  auto SPIRVOrder = getMemorySemanticsMask(Order);
347  auto SPIRVScope = getScope(Scope);
348  __spirv_AtomicStore(Ptr, SPIRVScope, SPIRVOrder, Value);
349 }
350 
351 template <typename T, access::address_space AddressSpace,
352  access::decorated IsDecorated>
353 inline typename detail::enable_if_t<std::is_floating_point<T>::value>
354 AtomicStore(multi_ptr<T, AddressSpace, IsDecorated> MPtr, memory_scope Scope,
355  memory_order Order, T Value) {
356  using I = detail::make_unsinged_integer_t<T>;
357  auto *PtrInt = GetMultiPtrDecoratedAs<I>(MPtr);
358  auto SPIRVOrder = getMemorySemanticsMask(Order);
359  auto SPIRVScope = getScope(Scope);
360  I ValueInt = bit_cast<I>(Value);
361  __spirv_AtomicStore(PtrInt, SPIRVScope, SPIRVOrder, ValueInt);
362 }
363 
364 template <typename T, access::address_space AddressSpace,
365  access::decorated IsDecorated>
366 inline typename detail::enable_if_t<std::is_integral<T>::value, T>
367 AtomicExchange(multi_ptr<T, AddressSpace, IsDecorated> MPtr, memory_scope Scope,
368  memory_order Order, T Value) {
369  auto *Ptr = GetMultiPtrDecoratedAs<T>(MPtr);
370  auto SPIRVOrder = getMemorySemanticsMask(Order);
371  auto SPIRVScope = getScope(Scope);
372  return __spirv_AtomicExchange(Ptr, SPIRVScope, SPIRVOrder, Value);
373 }
374 
375 template <typename T, access::address_space AddressSpace,
376  access::decorated IsDecorated>
377 inline typename detail::enable_if_t<std::is_floating_point<T>::value, T>
378 AtomicExchange(multi_ptr<T, AddressSpace, IsDecorated> MPtr, memory_scope Scope,
379  memory_order Order, T Value) {
380  using I = detail::make_unsinged_integer_t<T>;
381  auto *PtrInt = GetMultiPtrDecoratedAs<I>(MPtr);
382  auto SPIRVOrder = getMemorySemanticsMask(Order);
383  auto SPIRVScope = getScope(Scope);
384  I ValueInt = bit_cast<I>(Value);
385  I ResultInt =
386  __spirv_AtomicExchange(PtrInt, SPIRVScope, SPIRVOrder, ValueInt);
387  return bit_cast<T>(ResultInt);
388 }
389 
390 template <typename T, access::address_space AddressSpace,
391  access::decorated IsDecorated>
392 inline typename detail::enable_if_t<std::is_integral<T>::value, T>
393 AtomicIAdd(multi_ptr<T, AddressSpace, IsDecorated> MPtr, memory_scope Scope,
394  memory_order Order, T Value) {
395  auto *Ptr = GetMultiPtrDecoratedAs<T>(MPtr);
396  auto SPIRVOrder = getMemorySemanticsMask(Order);
397  auto SPIRVScope = getScope(Scope);
398  return __spirv_AtomicIAdd(Ptr, SPIRVScope, SPIRVOrder, Value);
399 }
400 
401 template <typename T, access::address_space AddressSpace,
402  access::decorated IsDecorated>
403 inline typename detail::enable_if_t<std::is_integral<T>::value, T>
404 AtomicISub(multi_ptr<T, AddressSpace, IsDecorated> MPtr, memory_scope Scope,
405  memory_order Order, T Value) {
406  auto *Ptr = GetMultiPtrDecoratedAs<T>(MPtr);
407  auto SPIRVOrder = getMemorySemanticsMask(Order);
408  auto SPIRVScope = getScope(Scope);
409  return __spirv_AtomicISub(Ptr, SPIRVScope, SPIRVOrder, Value);
410 }
411 
412 template <typename T, access::address_space AddressSpace,
413  access::decorated IsDecorated>
414 inline typename detail::enable_if_t<std::is_floating_point<T>::value, T>
415 AtomicFAdd(multi_ptr<T, AddressSpace, IsDecorated> MPtr, memory_scope Scope,
416  memory_order Order, T Value) {
417  auto *Ptr = GetMultiPtrDecoratedAs<T>(MPtr);
418  auto SPIRVOrder = getMemorySemanticsMask(Order);
419  auto SPIRVScope = getScope(Scope);
420  return __spirv_AtomicFAddEXT(Ptr, SPIRVScope, SPIRVOrder, Value);
421 }
422 
423 template <typename T, access::address_space AddressSpace,
424  access::decorated IsDecorated>
425 inline typename detail::enable_if_t<std::is_integral<T>::value, T>
426 AtomicAnd(multi_ptr<T, AddressSpace, IsDecorated> MPtr, memory_scope Scope,
427  memory_order Order, T Value) {
428  auto *Ptr = GetMultiPtrDecoratedAs<T>(MPtr);
429  auto SPIRVOrder = getMemorySemanticsMask(Order);
430  auto SPIRVScope = getScope(Scope);
431  return __spirv_AtomicAnd(Ptr, SPIRVScope, SPIRVOrder, Value);
432 }
433 
434 template <typename T, access::address_space AddressSpace,
435  access::decorated IsDecorated>
436 inline typename detail::enable_if_t<std::is_integral<T>::value, T>
437 AtomicOr(multi_ptr<T, AddressSpace, IsDecorated> MPtr, memory_scope Scope,
438  memory_order Order, T Value) {
439  auto *Ptr = GetMultiPtrDecoratedAs<T>(MPtr);
440  auto SPIRVOrder = getMemorySemanticsMask(Order);
441  auto SPIRVScope = getScope(Scope);
442  return __spirv_AtomicOr(Ptr, SPIRVScope, SPIRVOrder, Value);
443 }
444 
445 template <typename T, access::address_space AddressSpace,
446  access::decorated IsDecorated>
447 inline typename detail::enable_if_t<std::is_integral<T>::value, T>
448 AtomicXor(multi_ptr<T, AddressSpace, IsDecorated> MPtr, memory_scope Scope,
449  memory_order Order, T Value) {
450  auto *Ptr = GetMultiPtrDecoratedAs<T>(MPtr);
451  auto SPIRVOrder = getMemorySemanticsMask(Order);
452  auto SPIRVScope = getScope(Scope);
453  return __spirv_AtomicXor(Ptr, SPIRVScope, SPIRVOrder, Value);
454 }
455 
456 template <typename T, access::address_space AddressSpace,
457  access::decorated IsDecorated>
458 inline typename detail::enable_if_t<std::is_integral<T>::value, T>
459 AtomicMin(multi_ptr<T, AddressSpace, IsDecorated> MPtr, memory_scope Scope,
460  memory_order Order, T Value) {
461  auto *Ptr = GetMultiPtrDecoratedAs<T>(MPtr);
462  auto SPIRVOrder = getMemorySemanticsMask(Order);
463  auto SPIRVScope = getScope(Scope);
464  return __spirv_AtomicMin(Ptr, SPIRVScope, SPIRVOrder, Value);
465 }
466 
467 template <typename T, access::address_space AddressSpace,
468  access::decorated IsDecorated>
469 inline typename detail::enable_if_t<std::is_floating_point<T>::value, T>
470 AtomicMin(multi_ptr<T, AddressSpace, IsDecorated> MPtr, memory_scope Scope,
471  memory_order Order, T Value) {
472  auto *Ptr = GetMultiPtrDecoratedAs<T>(MPtr);
473  auto SPIRVOrder = getMemorySemanticsMask(Order);
474  auto SPIRVScope = getScope(Scope);
475  return __spirv_AtomicMin(Ptr, SPIRVScope, SPIRVOrder, Value);
476 }
477 
478 template <typename T, access::address_space AddressSpace,
479  access::decorated IsDecorated>
480 inline typename detail::enable_if_t<std::is_integral<T>::value, T>
481 AtomicMax(multi_ptr<T, AddressSpace, IsDecorated> MPtr, memory_scope Scope,
482  memory_order Order, T Value) {
483  auto *Ptr = GetMultiPtrDecoratedAs<T>(MPtr);
484  auto SPIRVOrder = getMemorySemanticsMask(Order);
485  auto SPIRVScope = getScope(Scope);
486  return __spirv_AtomicMax(Ptr, SPIRVScope, SPIRVOrder, Value);
487 }
488 
489 template <typename T, access::address_space AddressSpace,
490  access::decorated IsDecorated>
491 inline typename detail::enable_if_t<std::is_floating_point<T>::value, T>
492 AtomicMax(multi_ptr<T, AddressSpace, IsDecorated> MPtr, memory_scope Scope,
493  memory_order Order, T Value) {
494  auto *Ptr = GetMultiPtrDecoratedAs<T>(MPtr);
495  auto SPIRVOrder = getMemorySemanticsMask(Order);
496  auto SPIRVScope = getScope(Scope);
497  return __spirv_AtomicMax(Ptr, SPIRVScope, SPIRVOrder, Value);
498 }
499 
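The Atomic* helpers above are what sycl::atomic_ref lowers to when compiling for a SPIR-V device. A short usage sketch of that public API follows; the queue and counter pointer are illustrative, and counter is assumed to point at USM memory.

#include <sycl/sycl.hpp>

// Every work-item increments a shared counter. fetch_add on this
// atomic_ref resolves to the integral AtomicIAdd helper above.
void count_items(sycl::queue &q, unsigned *counter, size_t n) {
  q.parallel_for(sycl::range<1>{n}, [=](sycl::id<1>) {
     sycl::atomic_ref<unsigned, sycl::memory_order::relaxed,
                      sycl::memory_scope::device,
                      sycl::access::address_space::global_space>
         ref(*counter);
     ref.fetch_add(1u);
   }).wait();
}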
500 // Native shuffles map directly to a shuffle intrinsic:
501 // - The Intel SPIR-V extension natively supports all arithmetic types.
502 //  However, the OpenCL extension natively supports only float vectors,
503 //  integer vectors, half scalars and double scalars.
504 //  For double vectors we emulate the shuffle with the scalar version.
505 // - The CUDA shfl intrinsics do not support vectors, and we use the _i32
506 //  variants for all scalar types
507 #ifndef __NVPTX__
508 
509 template <typename T>
510 struct TypeIsProhibitedForShuffleEmulation
511  : bool_constant<std::is_same_v<vector_element_t<T>, double>> {};
512 
513 template <typename T>
514 struct VecTypeIsProhibitedForShuffleEmulation
515  : bool_constant<
516  (detail::get_vec_size<T>::size > 1) &&
517  TypeIsProhibitedForShuffleEmulation<vector_element_t<T>>::value> {};
518 
519 template <typename T>
520 using EnableIfNativeShuffle =
521  std::enable_if_t<detail::is_arithmetic<T>::value &&
522  !VecTypeIsProhibitedForShuffleEmulation<T>::value,
523  T>;
524 
525 template <typename T>
526 using EnableIfVectorShuffle =
527  std::enable_if_t<VecTypeIsProhibitedForShuffleEmulation<T>::value, T>;
528 
529 #else // ifndef __NVPTX__
530 
531 template <typename T>
532 using EnableIfNativeShuffle = std::enable_if_t<
533  std::is_integral<T>::value && (sizeof(T) <= sizeof(int32_t)), T>;
534 
535 template <typename T>
536 using EnableIfVectorShuffle =
537  std::enable_if_t<detail::is_vector_arithmetic<T>::value, T>;
538 #endif // ifndef __NVPTX__
539 
540 // Bitcast shuffles can be implemented using a single SubgroupShuffle
541 // intrinsic, but require type-punning via an appropriate integer type
542 #ifndef __NVPTX__
543 template <typename T>
544 using EnableIfBitcastShuffle =
545  std::enable_if_t<!detail::is_arithmetic<T>::value &&
546  (std::is_trivially_copyable_v<T> &&
547  (sizeof(T) == 1 || sizeof(T) == 2 || sizeof(T) == 4 ||
548  sizeof(T) == 8)),
549  T>;
550 #else
551 template <typename T>
552 using EnableIfBitcastShuffle =
553  std::enable_if_t<!(std::is_integral_v<T> &&
554  (sizeof(T) <= sizeof(int32_t))) &&
555  !detail::is_vector_arithmetic<T>::value &&
556  (std::is_trivially_copyable_v<T> &&
557  (sizeof(T) == 1 || sizeof(T) == 2 || sizeof(T) == 4)),
558  T>;
559 #endif // ifndef __NVPTX__
560 
561 // Generic shuffles may require multiple calls to SubgroupShuffle
562 // intrinsics, and should use the fewest shuffles possible:
563 // - Loop over 64-bit chunks until remaining bytes < 64-bit
564 // - At most one 32-bit, 16-bit and 8-bit chunk left over
565 #ifndef __NVPTX__
566 template <typename T>
567 using EnableIfGenericShuffle =
568  std::enable_if_t<!detail::is_arithmetic<T>::value &&
569  !(std::is_trivially_copyable_v<T> &&
570  (sizeof(T) == 1 || sizeof(T) == 2 ||
571  sizeof(T) == 4 || sizeof(T) == 8)),
572  T>;
573 #else
574 template <typename T>
575 using EnableIfGenericShuffle = std::enable_if_t<
576  !(std::is_integral<T>::value && (sizeof(T) <= sizeof(int32_t))) &&
577  !detail::is_vector_arithmetic<T>::value &&
578  !(std::is_trivially_copyable_v<T> &&
579  (sizeof(T) == 1 || sizeof(T) == 2 || sizeof(T) == 4)),
580  T>;
581 #endif
582 
583 #ifdef __NVPTX__
584 inline uint32_t membermask() {
585  // use a full mask as sync operations are required to be convergent and exited
586  // threads can safely be in the mask
587  return 0xFFFFFFFF;
588 }
589 #endif
590 
591 // Forward declarations for template overloadings
592 template <typename T>
593 EnableIfBitcastShuffle<T> SubgroupShuffle(T x, id<1> local_id);
594 
595 template <typename T>
596 EnableIfBitcastShuffle<T> SubgroupShuffleXor(T x, id<1> local_id);
597 
598 template <typename T>
599 EnableIfBitcastShuffle<T> SubgroupShuffleDown(T x, id<1> local_id);
600 
601 template <typename T>
602 EnableIfBitcastShuffle<T> SubgroupShuffleUp(T x, id<1> local_id);
603 
604 template <typename T>
605 EnableIfGenericShuffle<T> SubgroupShuffle(T x, id<1> local_id);
606 
607 template <typename T>
608 EnableIfGenericShuffle<T> SubgroupShuffleXor(T x, id<1> local_id);
609 
610 template <typename T>
611 EnableIfGenericShuffle<T> SubgroupShuffleDown(T x, id<1> local_id);
612 
613 template <typename T>
614 EnableIfGenericShuffle<T> SubgroupShuffleUp(T x, id<1> local_id);
615 
616 template <typename T>
617 EnableIfNativeShuffle<T> SubgroupShuffle(T x, id<1> local_id) {
618 #ifndef __NVPTX__
619  using OCLT = detail::ConvertToOpenCLType_t<T>;
620  return __spirv_SubgroupShuffleINTEL(OCLT(x),
621  static_cast<uint32_t>(local_id.get(0)));
622 #else
623  return __nvvm_shfl_sync_idx_i32(membermask(), x, local_id.get(0), 0x1f);
624 #endif
625 }
626 
627 template <typename T>
628 EnableIfNativeShuffle<T> SubgroupShuffleXor(T x, id<1> local_id) {
629 #ifndef __NVPTX__
630  using OCLT = detail::ConvertToOpenCLType_t<T>;
631  return __spirv_SubgroupShuffleXorINTEL(
632  OCLT(x), static_cast<uint32_t>(local_id.get(0)));
633 #else
634  return __nvvm_shfl_sync_bfly_i32(membermask(), x, local_id.get(0), 0x1f);
635 #endif
636 }
637 
638 template <typename T>
639 EnableIfNativeShuffle<T> SubgroupShuffleDown(T x, uint32_t delta) {
640 #ifndef __NVPTX__
641  using OCLT = detail::ConvertToOpenCLType_t<T>;
642  return __spirv_SubgroupShuffleDownINTEL(OCLT(x), OCLT(x), delta);
643 #else
644  return __nvvm_shfl_sync_down_i32(membermask(), x, delta, 0x1f);
645 #endif
646 }
647 
648 template <typename T>
649 EnableIfNativeShuffle<T> SubgroupShuffleUp(T x, uint32_t delta) {
650 #ifndef __NVPTX__
651  using OCLT = detail::ConvertToOpenCLType_t<T>;
652  return __spirv_SubgroupShuffleUpINTEL(OCLT(x), OCLT(x), delta);
653 #else
654  return __nvvm_shfl_sync_up_i32(membermask(), x, delta, 0);
655 #endif
656 }
657 
658 template <typename T>
659 EnableIfVectorShuffle<T> SubgroupShuffle(T x, id<1> local_id) {
660  T result;
661  for (int s = 0; s < x.size(); ++s) {
662  result[s] = SubgroupShuffle(x[s], local_id);
663  }
664  return result;
665 }
666 
667 template <typename T>
668 EnableIfVectorShuffle<T> SubgroupShuffleXor(T x, id<1> local_id) {
669  T result;
670  for (int s = 0; s < x.size(); ++s) {
671  result[s] = SubgroupShuffleXor(x[s], local_id);
672  }
673  return result;
674 }
675 
676 template <typename T>
677 EnableIfVectorShuffle<T> SubgroupShuffleDown(T x, uint32_t delta) {
678  T result;
679  for (int s = 0; s < x.size(); ++s) {
680  result[s] = SubgroupShuffleDown(x[s], delta);
681  }
682  return result;
683 }
684 
685 template <typename T>
686 EnableIfVectorShuffle<T> SubgroupShuffleUp(T x, uint32_t delta) {
687  T result;
688  for (int s = 0; s < x.size(); ++s) {
689  result[s] = SubgroupShuffleUp(x[s], delta);
690  }
691  return result;
692 }
693 
694 template <typename T>
695 using ConvertToNativeShuffleType_t = select_cl_scalar_integral_unsigned_t<T>;
696 
697 template <typename T>
698 EnableIfBitcastShuffle<T> SubgroupShuffle(T x, id<1> local_id) {
699  using ShuffleT = ConvertToNativeShuffleType_t<T>;
700  auto ShuffleX = bit_cast<ShuffleT>(x);
701 #ifndef __NVPTX__
702  ShuffleT Result = __spirv_SubgroupShuffleINTEL(
703  ShuffleX, static_cast<uint32_t>(local_id.get(0)));
704 #else
705  ShuffleT Result =
706  __nvvm_shfl_sync_idx_i32(membermask(), ShuffleX, local_id.get(0), 0x1f);
707 #endif
708  return bit_cast<T>(Result);
709 }
710 
711 template <typename T>
712 EnableIfBitcastShuffle<T> SubgroupShuffleXor(T x, id<1> local_id) {
713  using ShuffleT = ConvertToNativeShuffleType_t<T>;
714  auto ShuffleX = bit_cast<ShuffleT>(x);
715 #ifndef __NVPTX__
716  ShuffleT Result = __spirv_SubgroupShuffleXorINTEL(
717  ShuffleX, static_cast<uint32_t>(local_id.get(0)));
718 #else
719  ShuffleT Result =
720  __nvvm_shfl_sync_bfly_i32(membermask(), ShuffleX, local_id.get(0), 0x1f);
721 #endif
722  return bit_cast<T>(Result);
723 }
724 
725 template <typename T>
726 EnableIfBitcastShuffle<T> SubgroupShuffleDown(T x, uint32_t delta) {
727  using ShuffleT = ConvertToNativeShuffleType_t<T>;
728  auto ShuffleX = bit_cast<ShuffleT>(x);
729 #ifndef __NVPTX__
730  ShuffleT Result = __spirv_SubgroupShuffleDownINTEL(ShuffleX, ShuffleX, delta);
731 #else
732  ShuffleT Result =
733  __nvvm_shfl_sync_down_i32(membermask(), ShuffleX, delta, 0x1f);
734 #endif
735  return bit_cast<T>(Result);
736 }
737 
738 template <typename T>
739 EnableIfBitcastShuffle<T> SubgroupShuffleUp(T x, uint32_t delta) {
740  using ShuffleT = ConvertToNativeShuffleType_t<T>;
741  auto ShuffleX = bit_cast<ShuffleT>(x);
742 #ifndef __NVPTX__
743  ShuffleT Result = __spirv_SubgroupShuffleUpINTEL(ShuffleX, ShuffleX, delta);
744 #else
745  ShuffleT Result = __nvvm_shfl_sync_up_i32(membermask(), ShuffleX, delta, 0);
746 #endif
747  return bit_cast<T>(Result);
748 }
749 
750 template <typename T>
751 EnableIfGenericShuffle<T> SubgroupShuffle(T x, id<1> local_id) {
752  T Result;
753  char *XBytes = reinterpret_cast<char *>(&x);
754  char *ResultBytes = reinterpret_cast<char *>(&Result);
755  auto ShuffleBytes = [=](size_t Offset, size_t Size) {
756  ShuffleChunkT ShuffleX, ShuffleResult;
757  std::memcpy(&ShuffleX, XBytes + Offset, Size);
758  ShuffleResult = SubgroupShuffle(ShuffleX, local_id);
759  std::memcpy(ResultBytes + Offset, &ShuffleResult, Size);
760  };
761  GenericCall<T>(ShuffleBytes);
762  return Result;
763 }
764 
765 template <typename T>
766 EnableIfGenericShuffle<T> SubgroupShuffleXor(T x, id<1> local_id) {
767  T Result;
768  char *XBytes = reinterpret_cast<char *>(&x);
769  char *ResultBytes = reinterpret_cast<char *>(&Result);
770  auto ShuffleBytes = [=](size_t Offset, size_t Size) {
771  ShuffleChunkT ShuffleX, ShuffleResult;
772  std::memcpy(&ShuffleX, XBytes + Offset, Size);
773  ShuffleResult = SubgroupShuffleXor(ShuffleX, local_id);
774  std::memcpy(ResultBytes + Offset, &ShuffleResult, Size);
775  };
776  GenericCall<T>(ShuffleBytes);
777  return Result;
778 }
779 
780 template <typename T>
781 EnableIfGenericShuffle<T> SubgroupShuffleDown(T x, uint32_t delta) {
782  T Result;
783  char *XBytes = reinterpret_cast<char *>(&x);
784  char *ResultBytes = reinterpret_cast<char *>(&Result);
785  auto ShuffleBytes = [=](size_t Offset, size_t Size) {
786  ShuffleChunkT ShuffleX, ShuffleResult;
787  std::memcpy(&ShuffleX, XBytes + Offset, Size);
788  ShuffleResult = SubgroupShuffleDown(ShuffleX, delta);
789  std::memcpy(ResultBytes + Offset, &ShuffleResult, Size);
790  };
791  GenericCall<T>(ShuffleBytes);
792  return Result;
793 }
794 
795 template <typename T>
796 EnableIfGenericShuffle<T> SubgroupShuffleUp(T x, uint32_t delta) {
797  T Result;
798  char *XBytes = reinterpret_cast<char *>(&x);
799  char *ResultBytes = reinterpret_cast<char *>(&Result);
800  auto ShuffleBytes = [=](size_t Offset, size_t Size) {
801  ShuffleChunkT ShuffleX, ShuffleResult;
802  std::memcpy(&ShuffleX, XBytes + Offset, Size);
803  ShuffleResult = SubgroupShuffleUp(ShuffleX, delta);
804  std::memcpy(ResultBytes + Offset, &ShuffleResult, Size);
805  };
806  GenericCall<T>(ShuffleBytes);
807  return Result;
808 }
809 
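At the user level, the SubgroupShuffle* helpers above back SYCL 2020 group algorithms such as select_from_group, shift_group_left/right and permute_group_by_xor. A short illustrative kernel is below; the queue, data pointer and sub-group-sized work-group are assumptions, with data assumed to be USM memory and n a multiple of the work-group size.

#include <sycl/sycl.hpp>

// Shift values one lane towards lower ids within each sub-group (the last
// lane's result is unspecified). shift_group_left maps onto the
// SubgroupShuffleDown helpers above on SPIR-V targets and onto the
// shfl.sync.down path on CUDA.
void shift_left_by_one(sycl::queue &q, int *data, size_t n) {
  q.parallel_for(sycl::nd_range<1>{n, 32}, [=](sycl::nd_item<1> it) {
     auto sg = it.get_sub_group();
     int v = data[it.get_global_id(0)];
     data[it.get_global_id(0)] = sycl::shift_group_left(sg, v, 1);
   }).wait();
}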
810 } // namespace spirv
811 } // namespace detail
812 } // __SYCL_INLINE_VER_NAMESPACE(_V1)
813 } // namespace sycl
814 #endif // __SYCL_DEVICE_ONLY__