DPC++ Runtime
Runtime libraries for oneAPI DPC++
fixed_size_group.hpp
Go to the documentation of this file.
1 //==--- fixed_size_group.hpp --- SYCL extension for non-uniform groups -----==//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #pragma once
10 
11 #include <sycl/aspects.hpp>
12 #include <sycl/detail/pi.h> // for PI_ERROR_INVALID_DEVICE
13 #include <sycl/detail/type_traits.hpp> // for is_fixed_size_group, is_group
14 #include <sycl/exception.hpp> // for runtime_error
15 #include <sycl/ext/oneapi/sub_group_mask.hpp> // for sub_group_mask
16 #include <sycl/id.hpp> // for id
17 #include <sycl/memory_enums.hpp> // for memory_scope
18 #include <sycl/range.hpp> // for range
19 #include <sycl/sub_group.hpp> // for sub_group
20 
21 #include <stddef.h> // for size_t
22 #include <type_traits> // for enable_if_t, true_type, dec...
23 
24 namespace sycl {
25 inline namespace _V1 {
26 namespace ext::oneapi::experimental {
27 
28 template <size_t PartitionSize, typename ParentGroup> class fixed_size_group;
29 
30 template <size_t PartitionSize, typename Group>
31 #ifdef __SYCL_DEVICE_ONLY__
32 [[__sycl_detail__::__uses_aspects__(sycl::aspect::ext_oneapi_fixed_size_group)]]
33 #endif
34 inline std::enable_if_t<sycl::is_group_v<std::decay_t<Group>> &&
35  std::is_same_v<Group, sycl::sub_group>,
36  fixed_size_group<PartitionSize, Group>>
37 get_fixed_size_group(Group group);
38 
39 template <size_t PartitionSize, typename ParentGroup> class fixed_size_group {
40 public:
41  using id_type = id<1>;
42  using range_type = range<1>;
43  using linear_id_type = typename ParentGroup::linear_id_type;
44  static constexpr int dimensions = 1;
45  static constexpr sycl::memory_scope fence_scope = ParentGroup::fence_scope;
46 
47  id_type get_group_id() const {
48 #ifdef __SYCL_DEVICE_ONLY__
49  return __spirv_SubgroupLocalInvocationId() / PartitionSize;
50 #else
51  throw runtime_error("Non-uniform groups are not supported on host device.",
52  PI_ERROR_INVALID_DEVICE);
53 #endif
54  }
55 
56  id_type get_local_id() const {
57 #ifdef __SYCL_DEVICE_ONLY__
58  return __spirv_SubgroupLocalInvocationId() % PartitionSize;
59 #else
60  throw runtime_error("Non-uniform groups are not supported on host device.",
61  PI_ERROR_INVALID_DEVICE);
62 #endif
63  }
64 
65  range_type get_group_range() const {
66 #ifdef __SYCL_DEVICE_ONLY__
67  return __spirv_SubgroupSize() / PartitionSize;
68 #else
69  throw runtime_error("Non-uniform groups are not supported on host device.",
70  PI_ERROR_INVALID_DEVICE);
71 #endif
72  }
73 
74  range_type get_local_range() const {
75 #ifdef __SYCL_DEVICE_ONLY__
76  return PartitionSize;
77 #else
78  throw runtime_error("Non-uniform groups are not supported on host device.",
79  PI_ERROR_INVALID_DEVICE);
80 #endif
81  }
82 
83  linear_id_type get_group_linear_id() const {
84 #ifdef __SYCL_DEVICE_ONLY__
85  return static_cast<linear_id_type>(get_group_id()[0]);
86 #else
87  throw runtime_error("Non-uniform groups are not supported on host device.",
88  PI_ERROR_INVALID_DEVICE);
89 #endif
90  }
91 
92  linear_id_type get_local_linear_id() const {
93 #ifdef __SYCL_DEVICE_ONLY__
94  return static_cast<linear_id_type>(get_local_id()[0]);
95 #else
96  throw runtime_error("Non-uniform groups are not supported on host device.",
97  PI_ERROR_INVALID_DEVICE);
98 #endif
99  }
100 
101  linear_id_type get_group_linear_range() const {
102 #ifdef __SYCL_DEVICE_ONLY__
103  return static_cast<linear_id_type>(get_group_range()[0]);
104 #else
105  throw runtime_error("Non-uniform groups are not supported on host device.",
106  PI_ERROR_INVALID_DEVICE);
107 #endif
108  }
109 
110  linear_id_type get_local_linear_range() const {
111 #ifdef __SYCL_DEVICE_ONLY__
112  return static_cast<linear_id_type>(get_local_range()[0]);
113 #else
114  throw runtime_error("Non-uniform groups are not supported on host device.",
115  PI_ERROR_INVALID_DEVICE);
116 #endif
117  }
118 
119  bool leader() const {
120 #ifdef __SYCL_DEVICE_ONLY__
121  return get_local_linear_id() == 0;
122 #else
123  throw runtime_error("Non-uniform groups are not supported on host device.",
124  PI_ERROR_INVALID_DEVICE);
125 #endif
126  }
127 
128 protected:
129 #if defined(__SYCL_DEVICE_ONLY__) && defined(__NVPTX__)
130  sub_group_mask Mask;
131 #endif
132 
133 #if defined(__SYCL_DEVICE_ONLY__) && defined(__NVPTX__)
134  fixed_size_group(ext::oneapi::sub_group_mask mask) : Mask(mask) {}
135 #else
136  fixed_size_group() {}
137 #endif
138 
139  friend fixed_size_group<PartitionSize, ParentGroup>
140  get_fixed_size_group<PartitionSize, ParentGroup>(ParentGroup g);
141 
142  friend sub_group_mask
143  sycl::detail::GetMask<fixed_size_group<PartitionSize, ParentGroup>>(
144  fixed_size_group<PartitionSize, ParentGroup> Group);
145 };
146 
147 template <size_t PartitionSize, typename Group>
148 inline std::enable_if_t<sycl::is_group_v<std::decay_t<Group>> &&
149  std::is_same_v<Group, sycl::sub_group>,
150  fixed_size_group<PartitionSize, Group>>
151 get_fixed_size_group(Group group) {
152  (void)group;
153 #ifdef __SYCL_DEVICE_ONLY__
154 #if defined(__NVPTX__)
155  uint32_t loc_id = group.get_local_linear_id();
156  uint32_t loc_size = group.get_local_linear_range();
157  uint32_t bits = PartitionSize == 32
158  ? 0xffffffff
159  : ((1 << PartitionSize) - 1)
160  << ((loc_id / PartitionSize) * PartitionSize);
161 
162  return fixed_size_group<PartitionSize, sycl::sub_group>(
163  sycl::detail::Builder::createSubGroupMask<ext::oneapi::sub_group_mask>(
164  bits, loc_size));
165 #else
166  return fixed_size_group<PartitionSize, sycl::sub_group>();
167 #endif
168 #else
169  throw runtime_error("Non-uniform groups are not supported on host device.",
170  PI_ERROR_INVALID_DEVICE);
171 #endif
172 }
173 
174 template <size_t PartitionSize, typename ParentGroup>
175 struct is_user_constructed_group<fixed_size_group<PartitionSize, ParentGroup>>
176  : std::true_type {};
177 
178 } // namespace ext::oneapi::experimental
179 
180 namespace detail {
181 template <size_t PartitionSize, typename ParentGroup>
182 struct is_fixed_size_group<
183  ext::oneapi::experimental::fixed_size_group<PartitionSize, ParentGroup>>
184  : std::true_type {};
185 } // namespace detail
186 
187 template <size_t PartitionSize, typename ParentGroup>
188 struct is_group<
189  ext::oneapi::experimental::fixed_size_group<PartitionSize, ParentGroup>>
190  : std::true_type {};
191 
192 } // namespace _V1
193 } // namespace sycl
typename ParentGroup::linear_id_type linear_id_type
fence_scope
The scope that fence() operation should apply to.
Definition: common.hpp:350
std::enable_if_t< sycl::is_group_v< std::decay_t< Group > > &&std::is_same_v< Group, sycl::sub_group >, fixed_size_group< PartitionSize, Group > > get_fixed_size_group(Group group)
Definition: access.hpp:18