11 #if (!defined(_HAS_STD_BYTE) || _HAS_STD_BYTE != 0)
17 namespace ext::oneapi::experimental {
20 template <
typename Group,
size_t Extent>
class group_with_scratchpad {
22 sycl::span<std::byte, Extent> scratch;
26 : g(g_), scratch(scratch_) {}
28 sycl::span<std::byte, Extent>
get_memory()
const {
return scratch; }
38 template <
size_t Extent>
40 Compare comp_ = Compare())
41 : comp(comp_), scratch(scratch_.data()), scratch_size(scratch_.size()) {}
43 template <
typename Group,
typename Ptr>
45 #ifdef __SYCL_DEVICE_ONLY__
46 using T =
typename sycl::detail::GetValueType<Ptr>::type;
47 if (scratch_size >= memory_required<T>(Group::fence_scope, last - first))
48 sycl::detail::merge_sort(g, first, last - first, comp, scratch);
54 throw sycl::exception(
56 "default_sorter constructor is not supported on host device.");
60 template <
typename Group,
typename T> T
operator()(Group g, T val) {
61 #ifdef __SYCL_DEVICE_ONLY__
62 auto range_size = g.get_local_range().size();
63 if (scratch_size >= memory_required<T>(Group::fence_scope, range_size)) {
64 size_t local_id = g.get_local_linear_id();
65 T *temp =
reinterpret_cast<T *
>(scratch);
66 ::new (temp + local_id) T(val);
67 sycl::detail::merge_sort(g, temp, range_size, comp,
68 scratch + range_size *
sizeof(T));
74 throw sycl::exception(
76 "default_sorter operator() is not supported on host device.");
84 return range_size *
sizeof(T) +
alignof(T);
87 template <
typename T,
int dim = 1>
90 return 2 * memory_required<T>(scope,
r.size());
98 template <
typename T, sorting_order = sorting_order::ascending>
108 template <
typename ValT,
sorting_order OrderT = sorting_order::ascending,
109 unsigned int BitsPerPass = 4>
113 uint32_t first_bit = 0;
114 uint32_t last_bit = 0;
115 size_t scratch_size = 0;
117 static constexpr uint32_t bits = BitsPerPass;
120 template <
size_t Extent>
122 const std::bitset<
sizeof(ValT) *CHAR_BIT> mask =
123 std::bitset<
sizeof(ValT) * CHAR_BIT>(
125 : scratch(scratch_.data()), scratch_size(scratch_.size()) {
126 static_assert((std::is_arithmetic<ValT>::value ||
127 std::is_same<ValT, sycl::half>::value ||
128 std::is_same<ValT, sycl::ext::oneapi::bfloat16>::value),
129 "radix sort is not usable");
132 while (first_bit < mask.size() && !mask[first_bit])
135 last_bit = first_bit;
136 while (last_bit < mask.size() && mask[last_bit])
140 template <
typename GroupT,
typename PtrT>
145 #ifdef __SYCL_DEVICE_ONLY__
146 sycl::detail::privateDynamicSort<
false,
147 OrderT == sorting_order::ascending,
149 g, first, first, (last - first) > 0 ? (last - first) : 0,
150 scratch, first_bit, last_bit);
152 throw sycl::exception(
154 "radix_sorter is not supported on host device.");
158 template <
typename GroupT> ValT
operator()(GroupT g, ValT val) {
161 #ifdef __SYCL_DEVICE_ONLY__
163 sycl::detail::privateStaticSort<
false,
165 OrderT == sorting_order::ascending,
167 g, result, result, scratch, first_bit, last_bit);
170 throw sycl::exception(
172 "radix_sorter is not supported on host device.");
180 return range_size *
sizeof(ValT) +
181 (1 << bits) * range_size *
sizeof(uint32_t) +
alignof(uint32_t);
185 template <
int dimensions = 1>
187 sycl::range<dimensions> local_range) {
190 return (std::max)(local_range.size() *
sizeof(ValT),
191 local_range.size() * (1 << bits) *
sizeof(uint32_t));