DPC++ Runtime
Runtime libraries for oneAPI DPC++
masked_shuffles.hpp
Go to the documentation of this file.
1 //==--------- masked_shuffles.hpp - cuda masked shuffle algorithms ---------==//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #pragma once
10 
#if defined(__SYCL_DEVICE_ONLY__) && defined(__NVPTX__)

namespace sycl {
inline namespace _V1 {
namespace detail {

// CUDA_SHFL_SYNC(SHUFFLE_INSTR) defines
//   template <typename T>
//   T cuda_shfl_sync_<SHUFFLE_INSTR>(mask, val, shfl_param, c);
// a typed wrapper around the NVVM builtin __nvvm_shfl_sync_<SHUFFLE_INSTR>.
//
// The builtin exchanges one 32-bit int per call, so the wrapper reinterprets
// `val` as 32-bit words, shuffles each word with identical arguments, and
// reassembles the result as T:
//   - double and every 8-byte integral type: split into two 32-bit halves
//     with `mov.b64`, shuffled separately, then recombined;
//   - half: raw 16-bit pattern moved into a short (via "h" asm operands),
//     widened to int for the shuffle, narrowed back and moved into a half;
//   - float: bit-cast to int and back, so the value is not numerically
//     converted;
//   - any other T (intended for integral types of at most 32 bits): passed
//     straight to the builtin, i.e. implicitly converted to int.
//
// Parameters:
//   mask       - forwarded unchanged to every builtin call.
//   val        - this work-item's value to exchange.
//   shfl_param - forwarded unchanged; presumably the source lane / delta /
//                xor mask for the idx / up,down / bfly variants (confirm
//                against the PTX shfl.sync documentation).
//   c          - forwarded unchanged; presumably the PTX clamp/segment-mask
//                operand (confirm).
// Returns the value delivered by the builtin, reassembled as T.
#define CUDA_SHFL_SYNC(SHUFFLE_INSTR)                                          \
  template <typename T>                                                        \
  inline __SYCL_ALWAYS_INLINE T cuda_shfl_sync_##SHUFFLE_INSTR(                \
      unsigned int mask, T val, unsigned int shfl_param, int c) {              \
    T res;                                                                     \
    if constexpr (std::is_same_v<T, double>) {                                 \
      /* Split the 64-bit pattern into two 32-bit words, shuffle each. */      \
      int x_a, x_b;                                                            \
      asm("mov.b64 {%0,%1},%2;" : "=r"(x_a), "=r"(x_b) : "d"(val));            \
      auto tmp_a = __nvvm_shfl_sync_##SHUFFLE_INSTR(mask, x_a, shfl_param, c); \
      auto tmp_b = __nvvm_shfl_sync_##SHUFFLE_INSTR(mask, x_b, shfl_param, c); \
      asm("mov.b64 %0,{%1,%2};" : "=d"(res) : "r"(tmp_a), "r"(tmp_b));         \
    } else if constexpr (std::is_integral_v<T> && sizeof(T) == 8) {            \
      /* Every 8-byte integer (long, long long and their unsigned forms) */    \
      /* must take this path: the generic branch would keep only 32 bits. */   \
      int x_a, x_b;                                                            \
      asm("mov.b64 {%0,%1},%2;" : "=r"(x_a), "=r"(x_b) : "l"(val));            \
      auto tmp_a = __nvvm_shfl_sync_##SHUFFLE_INSTR(mask, x_a, shfl_param, c); \
      auto tmp_b = __nvvm_shfl_sync_##SHUFFLE_INSTR(mask, x_b, shfl_param, c); \
      asm("mov.b64 %0,{%1,%2};" : "=l"(res) : "r"(tmp_a), "r"(tmp_b));         \
    } else if constexpr (std::is_same_v<T, half>) {                            \
      /* Move the raw 16 bits into an integer, widen for the 32-bit */         \
      /* shuffle, then narrow and move the bits back into a half. */           \
      short tmp_b16;                                                           \
      asm("mov.b16 %0,%1;" : "=h"(tmp_b16) : "h"(val));                        \
      auto tmp_b32 = __nvvm_shfl_sync_##SHUFFLE_INSTR(                         \
          mask, static_cast<int>(tmp_b16), shfl_param, c);                     \
      asm("mov.b16 %0,%1;" : "=h"(res) : "h"(static_cast<short>(tmp_b32)));    \
    } else if constexpr (std::is_same_v<T, float>) {                           \
      /* Bit-cast rather than convert so the exact float bits survive. */      \
      auto tmp_b32 = __nvvm_shfl_sync_##SHUFFLE_INSTR(                         \
          mask, __nvvm_bitcast_f2i(val), shfl_param, c);                       \
      res = __nvvm_bitcast_i2f(tmp_b32);                                       \
    } else {                                                                   \
      /* Remaining types are implicitly converted to the builtin's int. */     \
      res = __nvvm_shfl_sync_##SHUFFLE_INSTR(mask, val, shfl_param, c);        \
    }                                                                          \
    return res;                                                                \
  }

// One wrapper per NVVM 32-bit shuffle builtin used by this header.
CUDA_SHFL_SYNC(bfly_i32)
CUDA_SHFL_SYNC(up_i32)
CUDA_SHFL_SYNC(down_i32)
CUDA_SHFL_SYNC(idx_i32)

#undef CUDA_SHFL_SYNC

} // namespace detail
} // namespace _V1
} // namespace sycl

#endif // defined(__SYCL_DEVICE_ONLY__) && defined(__NVPTX__)
Definition: access.hpp:18