DPC++ Runtime
Runtime libraries for oneAPI DPC++
fixed_size_group.hpp
Go to the documentation of this file.
1 //==--- fixed_size_group.hpp --- SYCL extension for non-uniform groups -----==//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #pragma once
10 
11 #include <sycl/aspects.hpp>
12 #include <sycl/detail/spirv.hpp>
13 #include <sycl/detail/type_traits.hpp> // for is_fixed_size_group, is_group
14 #include <sycl/exception.hpp>
16 #include <sycl/ext/oneapi/sub_group_mask.hpp> // for sub_group_mask
17 #include <sycl/id.hpp> // for id
18 #include <sycl/memory_enums.hpp> // for memory_scope
19 #include <sycl/range.hpp> // for range
20 #include <sycl/sub_group.hpp> // for sub_group
21 
22 #include <stddef.h> // for size_t
23 #include <type_traits> // for enable_if_t, true_type, dec...
24 
25 namespace sycl {
26 inline namespace _V1 {
27 namespace ext::oneapi::experimental {
28 
// Forward declaration: the get_fixed_size_group declaration that follows
// names fixed_size_group<PartitionSize, Group> in its return type before the
// class itself is defined.
29 template <size_t PartitionSize, typename ParentGroup> class fixed_size_group;
30 
// Splits `group` into fixed_size_group partitions of PartitionSize
// consecutive work-items and returns the partition of the calling work-item.
// The enable_if limits this overload to Group being exactly sycl::sub_group
// (and requires sycl::is_group_v to hold for its decayed type).
// NOTE(review): nothing in this declaration checks that PartitionSize divides
// the sub-group size; presumably the caller must guarantee it. Confirm
// against the extension specification.
31 template <size_t PartitionSize, typename Group>
// Device compilation only: tags the declaration with __uses_aspects__ for
// sycl::aspect::ext_oneapi_fixed_size_group.
32 #ifdef __SYCL_DEVICE_ONLY__
33 [[__sycl_detail__::__uses_aspects__(sycl::aspect::ext_oneapi_fixed_size_group)]]
34 #endif
35 inline std::enable_if_t<sycl::is_group_v<std::decay_t<Group>> &&
36  std::is_same_v<Group, sycl::sub_group>,
37  fixed_size_group<PartitionSize, Group>>
38 get_fixed_size_group(Group group);
39 
// fixed_size_group: a one-dimensional, user-constructed group made by
// splitting a parent group into partitions of PartitionSize consecutive
// work-items. Partitions are indexed by the work-item's sub-group local
// invocation id. get_fixed_size_group only accepts ParentGroup ==
// sycl::sub_group.
// Every accessor is device-only. The host (#else) branches end with the
// message "Non-uniform groups are not supported on host.".
// NOTE(review): this rendered listing omits several source lines: the
// accessor signatures, the first line of each host-side error statement, and
// parts of the constructor and friend declarations. The accessor names used
// below are inferred from calls made inside this class and are marked
// "presumably".
40 template <size_t PartitionSize, typename ParentGroup> class fixed_size_group {
41 public:
// Ids are one-dimensional (id<1>, dimensions == 1).
42  using id_type = id<1>;
// Linear ids reuse the parent group's linear id type.
44  using linear_id_type = typename ParentGroup::linear_id_type;
45  static constexpr int dimensions = 1;
47 
// Presumably get_group_id(): index of the partition containing the calling
// work-item, i.e. sub-group local id / PartitionSize.
49 #ifdef __SYCL_DEVICE_ONLY__
50  return __spirv_SubgroupLocalInvocationId() / PartitionSize;
51 #else
53  "Non-uniform groups are not supported on host.");
54 #endif
55  }
56 
// Presumably get_local_id(): position of the calling work-item inside its
// partition, i.e. sub-group local id % PartitionSize.
58 #ifdef __SYCL_DEVICE_ONLY__
59  return __spirv_SubgroupLocalInvocationId() % PartitionSize;
60 #else
62  "Non-uniform groups are not supported on host.");
63 #endif
64  }
65 
// Presumably get_group_range(): number of partitions in the sub-group.
// NOTE(review): integer division, so if the sub-group size is not a multiple
// of PartitionSize the remainder is silently dropped. Presumably callers
// guarantee divisibility; confirm.
67 #ifdef __SYCL_DEVICE_ONLY__
68  return __spirv_SubgroupSize() / PartitionSize;
69 #else
71  "Non-uniform groups are not supported on host.");
72 #endif
73  }
74 
// Presumably get_local_range(): size of each partition (PartitionSize).
76 #ifdef __SYCL_DEVICE_ONLY__
77  return PartitionSize;
78 #else
80  "Non-uniform groups are not supported on host.");
81 #endif
82  }
83 
// Presumably get_group_linear_id(): component 0 of get_group_id(),
// converted to linear_id_type.
85 #ifdef __SYCL_DEVICE_ONLY__
86  return static_cast<linear_id_type>(get_group_id()[0]);
87 #else
89  "Non-uniform groups are not supported on host.");
90 #endif
91  }
92 
// get_local_linear_id() (called by leader() below): component 0 of
// get_local_id(), converted to linear_id_type.
94 #ifdef __SYCL_DEVICE_ONLY__
95  return static_cast<linear_id_type>(get_local_id()[0]);
96 #else
98  "Non-uniform groups are not supported on host.");
99 #endif
100  }
101 
// Presumably get_group_linear_range(): component 0 of get_group_range(),
// converted to linear_id_type.
103 #ifdef __SYCL_DEVICE_ONLY__
104  return static_cast<linear_id_type>(get_group_range()[0]);
105 #else
107  "Non-uniform groups are not supported on host.");
108 #endif
109  }
110 
// Presumably get_local_linear_range(): component 0 of get_local_range(),
// converted to linear_id_type.
112 #ifdef __SYCL_DEVICE_ONLY__
113  return static_cast<linear_id_type>(get_local_range()[0]);
114 #else
116  "Non-uniform groups are not supported on host.");
117 #endif
118  }
119 
// True for the first work-item (local linear id 0) of each partition.
120  bool leader() const {
121 #ifdef __SYCL_DEVICE_ONLY__
122  return get_local_linear_id() == 0;
123 #else
125  "Non-uniform groups are not supported on host.");
126 #endif
127  }
128 
129 protected:
// NVPTX device builds keep the partition's sub_group_mask, which
// get_fixed_size_group builds from the caller's lane bits. Other builds hold
// no data members here.
130 #if defined(__SYCL_DEVICE_ONLY__) && defined(__NVPTX__)
131  sub_group_mask Mask;
132 #endif
133 
// The constructors are protected. On NVPTX device builds the constructor
// takes the mask; the constructor used by other builds is not shown in this
// listing.
134 #if defined(__SYCL_DEVICE_ONLY__) && defined(__NVPTX__)
135  fixed_size_group(ext::oneapi::sub_group_mask mask) : Mask(mask) {}
136 #else
138 #endif
139 
// Partially shown: presumably a friend declaration of get_fixed_size_group,
// letting it call the protected constructor.
141  get_fixed_size_group<PartitionSize, ParentGroup>(ParentGroup g);
142 
// Partially shown: presumably a friend declaration of sycl::detail::GetMask
// for this group type, letting it read the protected mask.
144  sycl::detail::GetMask<fixed_size_group<PartitionSize, ParentGroup>>(
146 };
147 
// Definition of get_fixed_size_group. Only sycl::sub_group is accepted, as
// the enable_if shows.
// NOTE(review): this listing omits the lines completing the return type and
// naming the function/parameter, the return statements, and the first line
// of the host-side error statement.
148 template <size_t PartitionSize, typename Group>
149 inline std::enable_if_t<sycl::is_group_v<std::decay_t<Group>> &&
150  std::is_same_v<Group, sycl::sub_group>,
// `group` is read only in the NVPTX branch below. This cast presumably
// silences unused-parameter warnings in the other branches.
153  (void)group;
154 #ifdef __SYCL_DEVICE_ONLY__
155 #if defined(__NVPTX__)
// Build a 32-bit lane mask with PartitionSize consecutive bits set, starting
// at bit (loc_id / PartitionSize) * PartitionSize: the lanes of the calling
// work-item's partition.
// PartitionSize == 32 is special-cased to all ones because `1 << 32` would be
// undefined for a 32-bit int.
// NOTE(review): `1` is a signed int, so a PartitionSize of 31 would overflow
// in `(1 << 31) - 1`. Presumably PartitionSize is limited to powers of two;
// confirm. `1u` would keep the arithmetic unsigned.
// NOTE(review): the mask assumes loc_id < 32 (sub-group size <= 32) on NVPTX.
156  uint32_t loc_id = group.get_local_linear_id();
157  uint32_t loc_size = group.get_local_linear_range();
158  uint32_t bits = PartitionSize == 32
159  ? 0xffffffff
160  : ((1 << PartitionSize) - 1)
161  << ((loc_id / PartitionSize) * PartitionSize);
162 
// Wrap the bits, together with the sub-group's local linear range, into a
// sub_group_mask via sycl::detail::Builder::createSubGroupMask. The enclosing
// statement (not shown) presumably returns a fixed_size_group built from it.
164  sycl::detail::Builder::createSubGroupMask<ext::oneapi::sub_group_mask>(
165  bits, loc_size));
166 #else
// Non-NVPTX device path: its statement is not shown in this listing.
168 #endif
169 #else
// Host path: reports "Non-uniform groups are not supported on host.". The
// statement's first line is not shown; presumably it throws sycl::exception.
171  "Non-uniform groups are not supported on host.");
172 #endif
173 }
174 
// Every fixed_size_group instantiation is reported as a user-constructed
// group (is_user_constructed_group derives from std::true_type).
175 template <size_t PartitionSize, typename ParentGroup>
176 struct is_user_constructed_group<fixed_size_group<PartitionSize, ParentGroup>>
177  : std::true_type {};
178 
179 } // namespace ext::oneapi::experimental
180 
181 namespace detail {
// detail trait specialization that yields std::true_type for every
// fixed_size_group instantiation.
// NOTE(review): the line naming the trait is missing from this listing.
// Presumably it is detail::is_fixed_size_group, which the
// <sycl/detail/type_traits.hpp> include comment mentions.
182 template <size_t PartitionSize, typename ParentGroup>
184  ext::oneapi::experimental::fixed_size_group<PartitionSize, ParentGroup>>
185  : std::true_type {};
186 } // namespace detail
187 
// sycl::is_group reports true for every fixed_size_group instantiation, so it
// satisfies is_group checks like the sycl::is_group_v test used by
// get_fixed_size_group's enable_if.
188 template <size_t PartitionSize, typename ParentGroup>
189 struct is_group<
190  ext::oneapi::experimental::fixed_size_group<PartitionSize, ParentGroup>>
191  : std::true_type {};
192 
193 } // namespace _V1
194 } // namespace sycl
typename ParentGroup::linear_id_type linear_id_type
fence_scope
The scope that the fence() operation should apply to.
Definition: common.hpp:345
std::enable_if_t< sycl::is_group_v< std::decay_t< Group > > &&std::is_same_v< Group, sycl::sub_group >, fixed_size_group< PartitionSize, Group > > get_fixed_size_group(Group group)
std::error_code make_error_code(sycl::errc E) noexcept
Constructs an error code using E and sycl_category()
Definition: exception.cpp:65
Definition: access.hpp:18