DPC++ Runtime
Runtime libraries for oneAPI DPC++
non_uniform_groups.hpp
Go to the documentation of this file.
1 //==--- non_uniform_groups.hpp --- SYCL extension for non-uniform groups ---==//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #pragma once
10 
11 #include <sycl/ext/oneapi/sub_group_mask.hpp> // for sub_group_mask
12 #include <sycl/marray.hpp> // for marray
13 #include <sycl/types.hpp> // for vec
14 
15 #include <stddef.h> // for size_t
16 #include <stdint.h> // for uint32_t
17 
18 namespace sycl {
19 inline namespace _V1 {
20 
21 namespace detail {
22 
24  sycl::marray<unsigned, 4> TmpMArray;
25  Mask.extract_bits(TmpMArray);
26  sycl::vec<unsigned, 4> MemberMask;
27  for (int i = 0; i < 4; ++i) {
28  MemberMask[i] = TmpMArray[i];
29  }
30  return MemberMask;
31 }
32 
33 #ifdef __SYCL_DEVICE_ONLY__
34 // TODO: This may need to be generalized beyond uint32_t for big masks
// Device-only helper: it sits inside the __SYCL_DEVICE_ONLY__ guard.
// Expands Mask into a 4 x unsigned vector with ExtractMask and returns the
// value produced by __spirv_GroupNonUniformBallotBitCount.
// NOTE(review): the argument list of the call below is absent from this
// listing (the embedded line numbers jump from 37 to 40). Restore it from the
// upstream header before compiling. Presumably the count covers the mask bits
// below the caller's own bit, as the function name suggests — TODO confirm.
35 inline uint32_t CallerPositionInMask(ext::oneapi::sub_group_mask Mask) {
36  sycl::vec<unsigned, 4> MemberMask = ExtractMask(Mask);
37  return __spirv_GroupNonUniformBallotBitCount(
40 }
41 #endif
42 
43 template <typename NonUniformGroup>
44 inline ext::oneapi::sub_group_mask GetMask(NonUniformGroup Group) {
45  return Group.Mask;
46 }
47 
48 template <typename NonUniformGroup>
49 inline uint32_t IdToMaskPosition(NonUniformGroup Group, uint32_t Id) {
50  sycl::vec<unsigned, 4> MemberMask = ExtractMask(GetMask(Group));
51 #if defined(__NVPTX__)
52  return __nvvm_fns(MemberMask[0], 0, Id + 1);
53 #else
54  // TODO: This will need to be optimized
55  uint32_t Count = 0;
56  for (int i = 0; i < 4; ++i) {
57  for (int b = 0; b < 32; ++b) {
58  if (MemberMask[i] & (1 << b)) {
59  if (Count == Id) {
60  return i * 32 + b;
61  }
62  Count++;
63  }
64  }
65  }
66  return Count;
67 #endif
68 }
69 
70 } // namespace detail
71 
72 namespace ext::oneapi::experimental {
73 
74 // Forward declarations of non-uniform group types for algorithm definitions
// Declarations only: no definitions of these types appear in this section.
// ballot_group is a class template with a single type parameter, ParentGroup.
75 template <typename ParentGroup> class ballot_group;
// fixed_size_group takes a compile-time size_t PartitionSize in addition to
// the ParentGroup type parameter.
76 template <size_t PartitionSize, typename ParentGroup> class fixed_size_group;
// tangle_group is a class template with a single type parameter, ParentGroup.
77 template <typename ParentGroup> class tangle_group;
// opportunistic_group is a plain (non-template) class.
78 class opportunistic_group;
79 
80 } // namespace ext::oneapi::experimental
81 
82 } // namespace _V1
83 } // namespace sycl
Provides a cross-platform math array class template that works on SYCL devices as well as in host C++...
Definition: marray.hpp:49
class sycl::vec ///////////////////////// Provides a cross-platform vector class template that works e...
Definition: vector.hpp:361
uint32_t IdToMaskPosition(NonUniformGroup Group, uint32_t Id)
ext::oneapi::sub_group_mask GetMask(NonUniformGroup Group)
sycl::vec< unsigned, 4 > ExtractMask(ext::oneapi::sub_group_mask Mask)
auto autodecltype(a) b
Definition: access.hpp:18
void extract_bits(Type &bits, id< 1 > pos=0) const