DPC++ Runtime
Runtime libraries for oneAPI DPC++
spirv.hpp
1 //===-- spirv.hpp - Helpers to generate SPIR-V instructions ----*- C++ -*--===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #pragma once
10 #include <CL/__spirv/spirv_ops.hpp>
11 #include <CL/__spirv/spirv_types.hpp>
12 #include <CL/__spirv/spirv_vars.hpp>
13 #include <CL/sycl/detail/generic_type_traits.hpp>
14 #include <CL/sycl/detail/helpers.hpp>
15 #include <CL/sycl/detail/type_traits.hpp>
16 #include <CL/sycl/id.hpp>
17 #include <CL/sycl/memory_enums.hpp>
18 #include <cstring>
19 
20 #ifdef __SYCL_DEVICE_ONLY__
21 __SYCL_INLINE_NAMESPACE(cl) {
22 namespace sycl {
23 namespace ext {
24 namespace oneapi {
25 struct sub_group;
26 } // namespace oneapi
27 } // namespace ext
28 
29 namespace detail {
30 namespace spirv {
31 
32 template <typename Group> struct group_scope {};
33 
34 template <int Dimensions> struct group_scope<group<Dimensions>> {
35  static constexpr __spv::Scope::Flag value = __spv::Scope::Flag::Workgroup;
36 };
37 
38 template <> struct group_scope<::cl::sycl::ext::oneapi::sub_group> {
39  static constexpr __spv::Scope::Flag value = __spv::Scope::Flag::Subgroup;
40 };
41 
42 // Generic shuffles and broadcasts may require multiple calls to
43 // intrinsics, and should use the fewest intrinsic calls possible:
44 // - Loop over chunks until remaining bytes < chunk size
45 // - At most one 32-bit, 16-bit and 8-bit chunk left over
46 #ifndef __NVPTX__
47 using ShuffleChunkT = uint64_t;
48 #else
49 using ShuffleChunkT = uint32_t;
50 #endif
51 template <typename T, typename Functor>
52 void GenericCall(const Functor &ApplyToBytes) {
53  if (sizeof(T) >= sizeof(ShuffleChunkT)) {
54 #pragma unroll
55  for (size_t Offset = 0; Offset + sizeof(ShuffleChunkT) <= sizeof(T);
56  Offset += sizeof(ShuffleChunkT)) {
57  ApplyToBytes(Offset, sizeof(ShuffleChunkT));
58  }
59  }
60  if (sizeof(ShuffleChunkT) >= sizeof(uint64_t)) {
61  if (sizeof(T) % sizeof(uint64_t) >= sizeof(uint32_t)) {
62  size_t Offset = sizeof(T) / sizeof(uint64_t) * sizeof(uint64_t);
63  ApplyToBytes(Offset, sizeof(uint32_t));
64  }
65  }
66  if (sizeof(ShuffleChunkT) >= sizeof(uint32_t)) {
67  if (sizeof(T) % sizeof(uint32_t) >= sizeof(uint16_t)) {
68  size_t Offset = sizeof(T) / sizeof(uint32_t) * sizeof(uint32_t);
69  ApplyToBytes(Offset, sizeof(uint16_t));
70  }
71  }
72  if (sizeof(ShuffleChunkT) >= sizeof(uint16_t)) {
73  if (sizeof(T) % sizeof(uint16_t) >= sizeof(uint8_t)) {
74  size_t Offset = sizeof(T) / sizeof(uint16_t) * sizeof(uint16_t);
75  ApplyToBytes(Offset, sizeof(uint8_t));
76  }
77  }
78 }
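// Illustrative note (not part of the original header): for a hypothetical
// 14-byte trivially copyable type T with ShuffleChunkT == uint64_t, the calls
// made by GenericCall<T>(ApplyToBytes) are:
//   ApplyToBytes(0, 8);   // one full 64-bit chunk (bytes 0-7)
//   ApplyToBytes(8, 4);   // 32-bit remainder      (14 % 8 == 6 >= 4)
//   ApplyToBytes(12, 2);  // 16-bit remainder      (14 % 4 == 2 >= 2)
//                         // no 8-bit remainder    (14 % 2 == 0)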
79 
80 template <typename Group> bool GroupAll(bool pred) {
81  return __spirv_GroupAll(group_scope<Group>::value, pred);
82 }
83 
84 template <typename Group> bool GroupAny(bool pred) {
85  return __spirv_GroupAny(group_scope<Group>::value, pred);
86 }
87 
88 // Native broadcasts map directly to a SPIR-V GroupBroadcast intrinsic
89 // FIXME: Do not special-case for half once all backends support all data types.
90 template <typename T>
91 using is_native_broadcast = bool_constant<detail::is_arithmetic<T>::value &&
92  !std::is_same<T, half>::value>;
93 
94 template <typename T, typename IdT = size_t>
95 using EnableIfNativeBroadcast = detail::enable_if_t<
96  is_native_broadcast<T>::value && std::is_integral<IdT>::value, T>;
97 
98 // Bitcast broadcasts can be implemented using a single SPIR-V GroupBroadcast
99 // intrinsic, but require type-punning via an appropriate integer type
100 template <typename T>
101 using is_bitcast_broadcast = bool_constant<
102  !is_native_broadcast<T>::value && std::is_trivially_copyable<T>::value &&
103  (sizeof(T) == 1 || sizeof(T) == 2 || sizeof(T) == 4 || sizeof(T) == 8)>;
104 
105 template <typename T, typename IdT = size_t>
106 using EnableIfBitcastBroadcast = detail::enable_if_t<
107  is_bitcast_broadcast<T>::value && std::is_integral<IdT>::value, T>;
108 
109 template <typename T>
110 using ConvertToNativeBroadcastType_t = select_cl_scalar_integral_unsigned_t<T>;
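// Illustrative note (not part of the original header): half is excluded from
// is_native_broadcast above but is trivially copyable and 2 bytes wide, so it
// takes the bitcast path. Conceptually:
//   half x = ...;
//   auto Bits = bit_cast<uint16_t>(x);                     // reinterpret the bytes
//   auto Broadcast = GroupBroadcast<Group>(Bits, LocalId); // one native call
//   half Result = bit_cast<half>(Broadcast);               // reinterpret back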
111 
112 // Generic broadcasts may require multiple calls to SPIR-V GroupBroadcast
113 // intrinsics, and should use the fewest broadcasts possible
114 // - Loop over 64-bit chunks until remaining bytes < 64-bit
115 // - At most one 32-bit, 16-bit and 8-bit chunk left over
116 template <typename T>
117 using is_generic_broadcast =
118  bool_constant<!is_native_broadcast<T>::value &&
119  !is_bitcast_broadcast<T>::value &&
120  std::is_trivially_copyable<T>::value>;
121 
122 template <typename T, typename IdT = size_t>
123 using EnableIfGenericBroadcast = detail::enable_if_t<
124  is_generic_broadcast<T>::value && std::is_integral<IdT>::value, T>;
125 
126 // FIXME: Disable widening once all backends support all data types.
127 template <typename T>
128 using WidenOpenCLTypeTo32_t = conditional_t<
129  std::is_same<T, cl_char>() || std::is_same<T, cl_short>(), cl_int,
130  conditional_t<std::is_same<T, cl_uchar>() || std::is_same<T, cl_ushort>(),
131  cl_uint, T>>;
132 
133 // Broadcast with scalar local index
134 // Work-group supports any integral type
135 // Sub-group currently supports only uint32_t
136 template <typename Group> struct GroupId { using type = size_t; };
137 template <> struct GroupId<::cl::sycl::ext::oneapi::sub_group> {
138  using type = uint32_t;
139 };
140 template <typename Group, typename T, typename IdT>
141 EnableIfNativeBroadcast<T, IdT> GroupBroadcast(T x, IdT local_id) {
142  using GroupIdT = typename GroupId<Group>::type;
143  GroupIdT GroupLocalId = static_cast<GroupIdT>(local_id);
144  using OCLT = detail::ConvertToOpenCLType_t<T>;
145  using WidenedT = WidenOpenCLTypeTo32_t<OCLT>;
146  using OCLIdT = detail::ConvertToOpenCLType_t<GroupIdT>;
147  WidenedT OCLX = detail::convertDataToType<T, OCLT>(x);
148  OCLIdT OCLId = detail::convertDataToType<GroupIdT, OCLIdT>(GroupLocalId);
149  return __spirv_GroupBroadcast(group_scope<Group>::value, OCLX, OCLId);
150 }
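// Illustrative sketch (not part of the original header) of how a caller such
// as the SYCL group algorithms might use the scalar overload; the variable
// names here are assumptions:
//   int v = ...;                                          // one value per work-item
//   int leader = GroupBroadcast<group<1>>(v, size_t{0});  // all work-items
//                                                         // receive item 0's v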
151 template <typename Group, typename T, typename IdT>
152 EnableIfBitcastBroadcast<T, IdT> GroupBroadcast(T x, IdT local_id) {
153  using BroadcastT = ConvertToNativeBroadcastType_t<T>;
154  auto BroadcastX = bit_cast<BroadcastT>(x);
155  BroadcastT Result = GroupBroadcast<Group>(BroadcastX, local_id);
156  return bit_cast<T>(Result);
157 }
158 template <typename Group, typename T, typename IdT>
159 EnableIfGenericBroadcast<T, IdT> GroupBroadcast(T x, IdT local_id) {
160  T Result;
161  char *XBytes = reinterpret_cast<char *>(&x);
162  char *ResultBytes = reinterpret_cast<char *>(&Result);
163  auto BroadcastBytes = [=](size_t Offset, size_t Size) {
164  uint64_t BroadcastX, BroadcastResult;
165  std::memcpy(&BroadcastX, XBytes + Offset, Size);
166  BroadcastResult = GroupBroadcast<Group>(BroadcastX, local_id);
167  std::memcpy(ResultBytes + Offset, &BroadcastResult, Size);
168  };
169  GenericCall<T>(BroadcastBytes);
170  return Result;
171 }
172 
173 // Broadcast with vector local index
174 template <typename Group, typename T, int Dimensions>
175 EnableIfNativeBroadcast<T> GroupBroadcast(T x, id<Dimensions> local_id) {
176  if (Dimensions == 1) {
177  return GroupBroadcast<Group>(x, local_id[0]);
178  }
179  using IdT = vec<size_t, Dimensions>;
180  using OCLT = detail::ConvertToOpenCLType_t<T>;
181  using WidenedT = WidenOpenCLTypeTo32_t<OCLT>;
182  using OCLIdT = detail::ConvertToOpenCLType_t<IdT>;
183  IdT VecId;
184  for (int i = 0; i < Dimensions; ++i) {
185  VecId[i] = local_id[Dimensions - i - 1];
186  }
187  WidenedT OCLX = detail::convertDataToType<T, OCLT>(x);
188  OCLIdT OCLId = detail::convertDataToType<IdT, OCLIdT>(VecId);
189  return __spirv_GroupBroadcast(group_scope<Group>::value, OCLX, OCLId);
190 }
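// Illustrative note (not part of the original header): the loop above reverses
// the index order because SYCL's id<> lists the slowest-varying dimension
// first, while the SPIR-V local-id vector expects the fastest-varying
// dimension in component 0. For a 2-D group:
//   id<2> local_id{row, col};   // SYCL order: (y, x)
//   VecId == {col, row};        // SPIR-V order: (x, y)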
191 template <typename Group, typename T, int Dimensions>
192 EnableIfBitcastBroadcast<T> GroupBroadcast(T x, id<Dimensions> local_id) {
193  using BroadcastT = ConvertToNativeBroadcastType_t<T>;
194  auto BroadcastX = bit_cast<BroadcastT>(x);
195  BroadcastT Result = GroupBroadcast<Group>(BroadcastX, local_id);
196  return bit_cast<T>(Result);
197 }
198 template <typename Group, typename T, int Dimensions>
199 EnableIfGenericBroadcast<T> GroupBroadcast(T x, id<Dimensions> local_id) {
200  if (Dimensions == 1) {
201  return GroupBroadcast<Group>(x, local_id[0]);
202  }
203  T Result;
204  char *XBytes = reinterpret_cast<char *>(&x);
205  char *ResultBytes = reinterpret_cast<char *>(&Result);
206  auto BroadcastBytes = [=](size_t Offset, size_t Size) {
207  uint64_t BroadcastX, BroadcastResult;
208  std::memcpy(&BroadcastX, XBytes + Offset, Size);
209  BroadcastResult = GroupBroadcast<Group>(BroadcastX, local_id);
210  std::memcpy(ResultBytes + Offset, &BroadcastResult, Size);
211  };
212  GenericCall<T>(BroadcastBytes);
213  return Result;
214 }
215 
216 // Single happens-before means semantics should always apply to all spaces
217 // Although consume is unsupported, forwarding to acquire is valid
218 template <typename T>
219 static inline constexpr
220  typename std::enable_if<std::is_same<T, sycl::memory_order>::value,
221  __spv::MemorySemanticsMask::Flag>::type
222  getMemorySemanticsMask(T Order) {
223  __spv::MemorySemanticsMask::Flag SpvOrder = __spv::MemorySemanticsMask::None;
224  switch (Order) {
225  case T::relaxed:
226  SpvOrder = __spv::MemorySemanticsMask::None;
227  break;
228  case T::__consume_unsupported:
229  case T::acquire:
230  SpvOrder = __spv::MemorySemanticsMask::Acquire;
231  break;
232  case T::release:
233  SpvOrder = __spv::MemorySemanticsMask::Release;
234  break;
235  case T::acq_rel:
236  SpvOrder = __spv::MemorySemanticsMask::AcquireRelease;
237  break;
238  case T::seq_cst:
239  SpvOrder = __spv::MemorySemanticsMask::SequentiallyConsistent;
240  break;
241  }
242  return static_cast<__spv::MemorySemanticsMask::Flag>(
243  SpvOrder | __spv::MemorySemanticsMask::SubgroupMemory |
244  __spv::MemorySemanticsMask::WorkgroupMemory |
245  __spv::MemorySemanticsMask::CrossWorkgroupMemory);
246 }
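// Illustrative note (not part of the original header): per the comment above,
// the returned mask always names all three storage classes. For example,
// getMemorySemanticsMask(memory_order::acquire) yields
//   Acquire | SubgroupMemory | WorkgroupMemory | CrossWorkgroupMemory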
247 
248 static inline constexpr __spv::Scope::Flag getScope(memory_scope Scope) {
249  switch (Scope) {
250  case memory_scope::work_item:
251  return __spv::Scope::Invocation;
252  case memory_scope::sub_group:
253  return __spv::Scope::Subgroup;
254  case memory_scope::work_group:
255  return __spv::Scope::Workgroup;
256  case memory_scope::device:
257  return __spv::Scope::Device;
258  case memory_scope::system:
259  return __spv::Scope::CrossDevice;
260  }
261 }
262 
263 template <typename T, access::address_space AddressSpace>
264 inline typename detail::enable_if_t<std::is_integral<T>::value, T>
265 AtomicCompareExchange(multi_ptr<T, AddressSpace> MPtr, memory_scope Scope,
266  memory_order Success, memory_order Failure, T Desired,
267  T Expected) {
268  auto SPIRVSuccess = getMemorySemanticsMask(Success);
269  auto SPIRVFailure = getMemorySemanticsMask(Failure);
270  auto SPIRVScope = getScope(Scope);
271  auto *Ptr = MPtr.get();
272  return __spirv_AtomicCompareExchange(Ptr, SPIRVScope, SPIRVSuccess,
273  SPIRVFailure, Desired, Expected);
274 }
275 
276 template <typename T, access::address_space AddressSpace>
277 inline typename detail::enable_if_t<std::is_floating_point<T>::value, T>
278 AtomicCompareExchange(multi_ptr<T, AddressSpace> MPtr, memory_scope Scope,
279  memory_order Success, memory_order Failure, T Desired,
280  T Expected) {
281  using I = detail::make_unsinged_integer_t<T>;
282  auto SPIRVSuccess = getMemorySemanticsMask(Success);
283  auto SPIRVFailure = getMemorySemanticsMask(Failure);
284  auto SPIRVScope = getScope(Scope);
285  auto *PtrInt =
286  reinterpret_cast<typename multi_ptr<I, AddressSpace>::pointer_t>(
287  MPtr.get());
288  I DesiredInt = bit_cast<I>(Desired);
289  I ExpectedInt = bit_cast<I>(Expected);
290  I ResultInt = __spirv_AtomicCompareExchange(
291  PtrInt, SPIRVScope, SPIRVSuccess, SPIRVFailure, DesiredInt, ExpectedInt);
292  return bit_cast<T>(ResultInt);
293 }
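// Illustrative note (not part of the original header): SPIR-V's
// OpAtomicCompareExchange operates on integer types, so the floating-point
// overload above compares raw bit patterns rather than values. For example,
// with float:
//   T Expected = +0.0f;   // bit_cast<uint32_t>(+0.0f) == 0x00000000
//   // If *MPtr holds -0.0f (0x80000000), the exchange fails even though
//   // -0.0f == +0.0f compares true arithmetically.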
294 
295 template <typename T, access::address_space AddressSpace>
296 inline typename detail::enable_if_t<std::is_integral<T>::value, T>
297 AtomicLoad(multi_ptr<T, AddressSpace> MPtr, memory_scope Scope,
298  memory_order Order) {
299  auto *Ptr = MPtr.get();
300  auto SPIRVOrder = getMemorySemanticsMask(Order);
301  auto SPIRVScope = getScope(Scope);
302  return __spirv_AtomicLoad(Ptr, SPIRVScope, SPIRVOrder);
303 }
304 
305 template <typename T, access::address_space AddressSpace>
306 inline typename detail::enable_if_t<std::is_floating_point<T>::value, T>
307 AtomicLoad(multi_ptr<T, AddressSpace> MPtr, memory_scope Scope,
308  memory_order Order) {
309  using I = detail::make_unsinged_integer_t<T>;
310  auto *PtrInt =
311  reinterpret_cast<typename multi_ptr<I, AddressSpace>::pointer_t>(
312  MPtr.get());
313  auto SPIRVOrder = getMemorySemanticsMask(Order);
314  auto SPIRVScope = getScope(Scope);
315  I ResultInt = __spirv_AtomicLoad(PtrInt, SPIRVScope, SPIRVOrder);
316  return bit_cast<T>(ResultInt);
317 }
318 
319 template <typename T, access::address_space AddressSpace>
320 inline typename detail::enable_if_t<std::is_integral<T>::value>
321 AtomicStore(multi_ptr<T, AddressSpace> MPtr, memory_scope Scope,
322  memory_order Order, T Value) {
323  auto *Ptr = MPtr.get();
324  auto SPIRVOrder = getMemorySemanticsMask(Order);
325  auto SPIRVScope = getScope(Scope);
326  __spirv_AtomicStore(Ptr, SPIRVScope, SPIRVOrder, Value);
327 }
328 
329 template <typename T, access::address_space AddressSpace>
330 inline typename detail::enable_if_t<std::is_floating_point<T>::value>
331 AtomicStore(multi_ptr<T, AddressSpace> MPtr, memory_scope Scope,
332  memory_order Order, T Value) {
333  using I = detail::make_unsinged_integer_t<T>;
334  auto *PtrInt =
335  reinterpret_cast<typename multi_ptr<I, AddressSpace>::pointer_t>(
336  MPtr.get());
337  auto SPIRVOrder = getMemorySemanticsMask(Order);
338  auto SPIRVScope = getScope(Scope);
339  I ValueInt = bit_cast<I>(Value);
340  __spirv_AtomicStore(PtrInt, SPIRVScope, SPIRVOrder, ValueInt);
341 }
342 
343 template <typename T, access::address_space AddressSpace>
344 inline typename detail::enable_if_t<std::is_integral<T>::value, T>
345 AtomicExchange(multi_ptr<T, AddressSpace> MPtr, memory_scope Scope,
346  memory_order Order, T Value) {
347  auto *Ptr = MPtr.get();
348  auto SPIRVOrder = getMemorySemanticsMask(Order);
349  auto SPIRVScope = getScope(Scope);
350  return __spirv_AtomicExchange(Ptr, SPIRVScope, SPIRVOrder, Value);
351 }
352 
353 template <typename T, access::address_space AddressSpace>
354 inline typename detail::enable_if_t<std::is_floating_point<T>::value, T>
355 AtomicExchange(multi_ptr<T, AddressSpace> MPtr, memory_scope Scope,
356  memory_order Order, T Value) {
357  using I = detail::make_unsinged_integer_t<T>;
358  auto *PtrInt =
359  reinterpret_cast<typename multi_ptr<I, AddressSpace>::pointer_t>(
360  MPtr.get());
361  auto SPIRVOrder = getMemorySemanticsMask(Order);
362  auto SPIRVScope = getScope(Scope);
363  I ValueInt = bit_cast<I>(Value);
364  I ResultInt =
365  __spirv_AtomicExchange(PtrInt, SPIRVScope, SPIRVOrder, ValueInt);
366  return bit_cast<T>(ResultInt);
367 }
368 
369 template <typename T, access::address_space AddressSpace>
370 inline typename detail::enable_if_t<std::is_integral<T>::value, T>
371 AtomicIAdd(multi_ptr<T, AddressSpace> MPtr, memory_scope Scope,
372  memory_order Order, T Value) {
373  auto *Ptr = MPtr.get();
374  auto SPIRVOrder = getMemorySemanticsMask(Order);
375  auto SPIRVScope = getScope(Scope);
376  return __spirv_AtomicIAdd(Ptr, SPIRVScope, SPIRVOrder, Value);
377 }
378 
379 template <typename T, access::address_space AddressSpace>
380 inline typename detail::enable_if_t<std::is_integral<T>::value, T>
381 AtomicISub(multi_ptr<T, AddressSpace> MPtr, memory_scope Scope,
382  memory_order Order, T Value) {
383  auto *Ptr = MPtr.get();
384  auto SPIRVOrder = getMemorySemanticsMask(Order);
385  auto SPIRVScope = getScope(Scope);
386  return __spirv_AtomicISub(Ptr, SPIRVScope, SPIRVOrder, Value);
387 }
388 
389 template <typename T, access::address_space AddressSpace>
390 inline typename detail::enable_if_t<std::is_floating_point<T>::value, T>
391 AtomicFAdd(multi_ptr<T, AddressSpace> MPtr, memory_scope Scope,
392  memory_order Order, T Value) {
393  auto *Ptr = MPtr.get();
394  auto SPIRVOrder = getMemorySemanticsMask(Order);
395  auto SPIRVScope = getScope(Scope);
396  return __spirv_AtomicFAddEXT(Ptr, SPIRVScope, SPIRVOrder, Value);
397 }
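// Illustrative note (not part of the original header): __spirv_AtomicFAddEXT
// is expected to lower to OpAtomicFAddEXT from the
// SPV_EXT_shader_atomic_float_add SPIR-V extension, so native floating-point
// add relies on the target exposing that extension.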
398 
399 template <typename T, access::address_space AddressSpace>
400 inline typename detail::enable_if_t<std::is_integral<T>::value, T>
401 AtomicAnd(multi_ptr<T, AddressSpace> MPtr, memory_scope Scope,
402  memory_order Order, T Value) {
403  auto *Ptr = MPtr.get();
404  auto SPIRVOrder = getMemorySemanticsMask(Order);
405  auto SPIRVScope = getScope(Scope);
406  return __spirv_AtomicAnd(Ptr, SPIRVScope, SPIRVOrder, Value);
407 }
408 
409 template <typename T, access::address_space AddressSpace>
410 inline typename detail::enable_if_t<std::is_integral<T>::value, T>
411 AtomicOr(multi_ptr<T, AddressSpace> MPtr, memory_scope Scope,
412  memory_order Order, T Value) {
413  auto *Ptr = MPtr.get();
414  auto SPIRVOrder = getMemorySemanticsMask(Order);
415  auto SPIRVScope = getScope(Scope);
416  return __spirv_AtomicOr(Ptr, SPIRVScope, SPIRVOrder, Value);
417 }
418 
419 template <typename T, access::address_space AddressSpace>
420 inline typename detail::enable_if_t<std::is_integral<T>::value, T>
421 AtomicXor(multi_ptr<T, AddressSpace> MPtr, memory_scope Scope,
422  memory_order Order, T Value) {
423  auto *Ptr = MPtr.get();
424  auto SPIRVOrder = getMemorySemanticsMask(Order);
425  auto SPIRVScope = getScope(Scope);
426  return __spirv_AtomicXor(Ptr, SPIRVScope, SPIRVOrder, Value);
427 }
428 
429 template <typename T, access::address_space AddressSpace>
430 inline typename detail::enable_if_t<std::is_integral<T>::value, T>
431 AtomicMin(multi_ptr<T, AddressSpace> MPtr, memory_scope Scope,
432  memory_order Order, T Value) {
433  auto *Ptr = MPtr.get();
434  auto SPIRVOrder = getMemorySemanticsMask(Order);
435  auto SPIRVScope = getScope(Scope);
436  return __spirv_AtomicMin(Ptr, SPIRVScope, SPIRVOrder, Value);
437 }
438 
439 template <typename T, access::address_space AddressSpace>
440 inline typename detail::enable_if_t<std::is_floating_point<T>::value, T>
441 AtomicMin(multi_ptr<T, AddressSpace> MPtr, memory_scope Scope,
442  memory_order Order, T Value) {
443  auto *Ptr = MPtr.get();
444  auto SPIRVOrder = getMemorySemanticsMask(Order);
445  auto SPIRVScope = getScope(Scope);
446  return __spirv_AtomicMin(Ptr, SPIRVScope, SPIRVOrder, Value);
447 }
448 
449 template <typename T, access::address_space AddressSpace>
450 inline typename detail::enable_if_t<std::is_integral<T>::value, T>
451 AtomicMax(multi_ptr<T, AddressSpace> MPtr, memory_scope Scope,
452  memory_order Order, T Value) {
453  auto *Ptr = MPtr.get();
454  auto SPIRVOrder = getMemorySemanticsMask(Order);
455  auto SPIRVScope = getScope(Scope);
456  return __spirv_AtomicMax(Ptr, SPIRVScope, SPIRVOrder, Value);
457 }
458 
459 template <typename T, access::address_space AddressSpace>
460 inline typename detail::enable_if_t<std::is_floating_point<T>::value, T>
461 AtomicMax(multi_ptr<T, AddressSpace> MPtr, memory_scope Scope,
462  memory_order Order, T Value) {
463  auto *Ptr = MPtr.get();
464  auto SPIRVOrder = getMemorySemanticsMask(Order);
465  auto SPIRVScope = getScope(Scope);
466  return __spirv_AtomicMax(Ptr, SPIRVScope, SPIRVOrder, Value);
467 }
468 
469 // Native shuffles map directly to a shuffle intrinsic:
470 // - The Intel SPIR-V extension natively supports all arithmetic types
471 // - The CUDA shfl intrinsics do not support vectors, and we use the _i32
472 // variants for all scalar types
473 #ifndef __NVPTX__
474 template <typename T>
475 using EnableIfNativeShuffle =
476  detail::enable_if_t<detail::is_arithmetic<T>::value, T>;
477 #else
478 template <typename T>
479 using EnableIfNativeShuffle = detail::enable_if_t<
480  std::is_integral<T>::value && (sizeof(T) <= sizeof(int32_t)), T>;
481 
482 template <typename T>
483 using EnableIfVectorShuffle =
484  detail::enable_if_t<detail::is_vector_arithmetic<T>::value, T>;
485 #endif
486 
487 #ifdef __NVPTX__
488 inline uint32_t membermask() {
489  // use a full mask as sync operations are required to be convergent and exited
490  // threads can safely be in the mask
491  return 0xFFFFFFFF;
492 }
493 #endif
494 
495 template <typename T>
496 EnableIfNativeShuffle<T> SubgroupShuffle(T x, id<1> local_id) {
497 #ifndef __NVPTX__
498  using OCLT = detail::ConvertToOpenCLType_t<T>;
499  return __spirv_SubgroupShuffleINTEL(OCLT(x),
500  static_cast<uint32_t>(local_id.get(0)));
501 #else
502  return __nvvm_shfl_sync_idx_i32(membermask(), x, local_id.get(0), 0x1f);
503 #endif
504 }
505 
506 template <typename T>
507 EnableIfNativeShuffle<T> SubgroupShuffleXor(T x, id<1> local_id) {
508 #ifndef __NVPTX__
509  using OCLT = detail::ConvertToOpenCLType_t<T>;
510  return __spirv_SubgroupShuffleXorINTEL(
511  OCLT(x), static_cast<uint32_t>(local_id.get(0)));
512 #else
513  return __nvvm_shfl_sync_bfly_i32(membermask(), x, local_id.get(0), 0x1f);
514 #endif
515 }
516 
517 template <typename T>
518 EnableIfNativeShuffle<T> SubgroupShuffleDown(T x, uint32_t delta) {
519 #ifndef __NVPTX__
520  using OCLT = detail::ConvertToOpenCLType_t<T>;
521  return __spirv_SubgroupShuffleDownINTEL(OCLT(x), OCLT(x), delta);
522 #else
523  return __nvvm_shfl_sync_down_i32(membermask(), x, delta, 0x1f);
524 #endif
525 }
526 
527 template <typename T>
528 EnableIfNativeShuffle<T> SubgroupShuffleUp(T x, uint32_t delta) {
529 #ifndef __NVPTX__
530  using OCLT = detail::ConvertToOpenCLType_t<T>;
531  return __spirv_SubgroupShuffleUpINTEL(OCLT(x), OCLT(x), delta);
532 #else
533  return __nvvm_shfl_sync_up_i32(membermask(), x, delta, 0);
534 #endif
535 }
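// Illustrative sketch (not part of the original header): a hypothetical
// sub-group sum reduction built on SubgroupShuffleXor, assuming a
// power-of-two sub-group size. SubgroupReduceSum and SubGroupSize are
// names invented for this sketch only:
//   template <typename T>
//   T SubgroupReduceSum(T x, uint32_t SubGroupSize) {
//     for (uint32_t Mask = SubGroupSize / 2; Mask > 0; Mask /= 2)
//       x += SubgroupShuffleXor(x, id<1>(Mask));
//     return x;
//   }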
536 
537 #ifdef __NVPTX__
538 template <typename T>
539 EnableIfVectorShuffle<T> SubgroupShuffle(T x, id<1> local_id) {
540  T result;
541  for (int s = 0; s < x.get_size(); ++s) {
542  result[s] = SubgroupShuffle(x[s], local_id);
543  }
544  return result;
545 }
546 
547 template <typename T>
548 EnableIfVectorShuffle<T> SubgroupShuffleXor(T x, id<1> local_id) {
549  T result;
550  for (int s = 0; s < x.get_size(); ++s) {
551  result[s] = SubgroupShuffleXor(x[s], local_id);
552  }
553  return result;
554 }
555 
556 template <typename T>
557 EnableIfVectorShuffle<T> SubgroupShuffleDown(T x, uint32_t delta) {
558  T result;
559  for (int s = 0; s < x.get_size(); ++s) {
560  result[s] = SubgroupShuffleDown(x[s], delta);
561  }
562  return result;
563 }
564 
565 template <typename T>
566 EnableIfVectorShuffle<T> SubgroupShuffleUp(T x, uint32_t delta) {
567  T result;
568  for (int s = 0; s < x.get_size(); ++s) {
569  result[s] = SubgroupShuffleUp(x[s], delta);
570  }
571  return result;
572 }
573 #endif
574 
575 // Bitcast shuffles can be implemented using a single SubgroupShuffle
576 // intrinsic, but require type-punning via an appropriate integer type
577 #ifndef __NVPTX__
578 template <typename T>
579 using EnableIfBitcastShuffle =
580  detail::enable_if_t<!detail::is_arithmetic<T>::value &&
581  (std::is_trivially_copyable<T>::value &&
582  (sizeof(T) == 1 || sizeof(T) == 2 ||
583  sizeof(T) == 4 || sizeof(T) == 8)),
584  T>;
585 #else
586 template <typename T>
587 using EnableIfBitcastShuffle = detail::enable_if_t<
588  !(std::is_integral<T>::value && (sizeof(T) <= sizeof(int32_t))) &&
589  !detail::is_vector_arithmetic<T>::value &&
590  (std::is_trivially_copyable<T>::value &&
591  (sizeof(T) == 1 || sizeof(T) == 2 || sizeof(T) == 4)),
592  T>;
593 #endif
594 
595 template <typename T>
596 using ConvertToNativeShuffleType_t = select_cl_scalar_integral_unsigned_t<T>;
597 
598 template <typename T>
599 EnableIfBitcastShuffle<T> SubgroupShuffle(T x, id<1> local_id) {
600  using ShuffleT = ConvertToNativeShuffleType_t<T>;
601  auto ShuffleX = bit_cast<ShuffleT>(x);
602 #ifndef __NVPTX__
603  ShuffleT Result = __spirv_SubgroupShuffleINTEL(
604  ShuffleX, static_cast<uint32_t>(local_id.get(0)));
605 #else
606  ShuffleT Result =
607  __nvvm_shfl_sync_idx_i32(membermask(), ShuffleX, local_id.get(0), 0x1f);
608 #endif
609  return bit_cast<T>(Result);
610 }
611 
612 template <typename T>
613 EnableIfBitcastShuffle<T> SubgroupShuffleXor(T x, id<1> local_id) {
614  using ShuffleT = ConvertToNativeShuffleType_t<T>;
615  auto ShuffleX = bit_cast<ShuffleT>(x);
616 #ifndef __NVPTX__
617  ShuffleT Result = __spirv_SubgroupShuffleXorINTEL(
618  ShuffleX, static_cast<uint32_t>(local_id.get(0)));
619 #else
620  ShuffleT Result =
621  __nvvm_shfl_sync_bfly_i32(membermask(), ShuffleX, local_id.get(0), 0x1f);
622 #endif
623  return bit_cast<T>(Result);
624 }
625 
626 template <typename T>
627 EnableIfBitcastShuffle<T> SubgroupShuffleDown(T x, uint32_t delta) {
628  using ShuffleT = ConvertToNativeShuffleType_t<T>;
629  auto ShuffleX = bit_cast<ShuffleT>(x);
630 #ifndef __NVPTX__
631  ShuffleT Result = __spirv_SubgroupShuffleDownINTEL(ShuffleX, ShuffleX, delta);
632 #else
633  ShuffleT Result =
634  __nvvm_shfl_sync_down_i32(membermask(), ShuffleX, delta, 0x1f);
635 #endif
636  return bit_cast<T>(Result);
637 }
638 
639 template <typename T>
640 EnableIfBitcastShuffle<T> SubgroupShuffleUp(T x, uint32_t delta) {
641  using ShuffleT = ConvertToNativeShuffleType_t<T>;
642  auto ShuffleX = bit_cast<ShuffleT>(x);
643 #ifndef __NVPTX__
644  ShuffleT Result = __spirv_SubgroupShuffleUpINTEL(ShuffleX, ShuffleX, delta);
645 #else
646  ShuffleT Result = __nvvm_shfl_sync_up_i32(membermask(), ShuffleX, delta, 0);
647 #endif
648  return bit_cast<T>(Result);
649 }
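// Illustrative note (not part of the original header): a trivially copyable
// 4-byte aggregate that is not arithmetic takes the bitcast path above, e.g.:
//   struct Pixel { int16_t lo, hi; };        // hypothetical user type
//   Pixel p = ...;
//   Pixel q = SubgroupShuffle(p, id<1>(0));  // one 32-bit shuffle, bit_cast
//                                            // to a 32-bit unsigned and back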
650 
651 // Generic shuffles may require multiple calls to SubgroupShuffle
652 // intrinsics, and should use the fewest shuffles possible:
653 // - Loop over 64-bit chunks until remaining bytes < 64-bit
654 // - At most one 32-bit, 16-bit and 8-bit chunk left over
655 #ifndef __NVPTX__
656 template <typename T>
657 using EnableIfGenericShuffle =
658  detail::enable_if_t<!detail::is_arithmetic<T>::value &&
659  !(std::is_trivially_copyable<T>::value &&
660  (sizeof(T) == 1 || sizeof(T) == 2 ||
661  sizeof(T) == 4 || sizeof(T) == 8)),
662  T>;
663 #else
664 template <typename T>
665 using EnableIfGenericShuffle = detail::enable_if_t<
666  !(std::is_integral<T>::value && (sizeof(T) <= sizeof(int32_t))) &&
667  !detail::is_vector_arithmetic<T>::value &&
668  !(std::is_trivially_copyable<T>::value &&
669  (sizeof(T) == 1 || sizeof(T) == 2 || sizeof(T) == 4)),
670  T>;
671 #endif
672 
673 template <typename T>
674 EnableIfGenericShuffle<T> SubgroupShuffle(T x, id<1> local_id) {
675  T Result;
676  char *XBytes = reinterpret_cast<char *>(&x);
677  char *ResultBytes = reinterpret_cast<char *>(&Result);
678  auto ShuffleBytes = [=](size_t Offset, size_t Size) {
679  ShuffleChunkT ShuffleX, ShuffleResult;
680  std::memcpy(&ShuffleX, XBytes + Offset, Size);
681  ShuffleResult = SubgroupShuffle(ShuffleX, local_id);
682  std::memcpy(ResultBytes + Offset, &ShuffleResult, Size);
683  };
684  GenericCall<T>(ShuffleBytes);
685  return Result;
686 }
687 
688 template <typename T>
689 EnableIfGenericShuffle<T> SubgroupShuffleXor(T x, id<1> local_id) {
690  T Result;
691  char *XBytes = reinterpret_cast<char *>(&x);
692  char *ResultBytes = reinterpret_cast<char *>(&Result);
693  auto ShuffleBytes = [=](size_t Offset, size_t Size) {
694  ShuffleChunkT ShuffleX, ShuffleResult;
695  std::memcpy(&ShuffleX, XBytes + Offset, Size);
696  ShuffleResult = SubgroupShuffleXor(ShuffleX, local_id);
697  std::memcpy(ResultBytes + Offset, &ShuffleResult, Size);
698  };
699  GenericCall<T>(ShuffleBytes);
700  return Result;
701 }
702 
703 template <typename T>
704 EnableIfGenericShuffle<T> SubgroupShuffleDown(T x, uint32_t delta) {
705  T Result;
706  char *XBytes = reinterpret_cast<char *>(&x);
707  char *ResultBytes = reinterpret_cast<char *>(&Result);
708  auto ShuffleBytes = [=](size_t Offset, size_t Size) {
709  ShuffleChunkT ShuffleX, ShuffleResult;
710  std::memcpy(&ShuffleX, XBytes + Offset, Size);
711  ShuffleResult = SubgroupShuffleDown(ShuffleX, delta);
712  std::memcpy(ResultBytes + Offset, &ShuffleResult, Size);
713  };
714  GenericCall<T>(ShuffleBytes);
715  return Result;
716 }
717 
718 template <typename T>
719 EnableIfGenericShuffle<T> SubgroupShuffleUp(T x, uint32_t delta) {
720  T Result;
721  char *XBytes = reinterpret_cast<char *>(&x);
722  char *ResultBytes = reinterpret_cast<char *>(&Result);
723  auto ShuffleBytes = [=](size_t Offset, size_t Size) {
724  ShuffleChunkT ShuffleX, ShuffleResult;
725  std::memcpy(&ShuffleX, XBytes + Offset, Size);
726  ShuffleResult = SubgroupShuffleUp(ShuffleX, delta);
727  std::memcpy(ResultBytes + Offset, &ShuffleResult, Size);
728  };
729  GenericCall<T>(ShuffleBytes);
730  return Result;
731 }
732 
733 } // namespace spirv
734 } // namespace detail
735 } // namespace sycl
736 } // __SYCL_INLINE_NAMESPACE(cl)
737 #endif // __SYCL_DEVICE_ONLY__