DPC++ Runtime
Runtime libraries for oneAPI DPC++
group_helpers_sorters.hpp
Go to the documentation of this file.
1 //==------- group_helpers_sorters.hpp - SYCL sorters and group helpers -----==//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #pragma once
10 
11 #if (!defined(_HAS_STD_BYTE) || _HAS_STD_BYTE != 0)
14 
15 namespace sycl {
17 namespace ext::oneapi::experimental {
18 
19 // ---- group helpers
20 template <typename Group, size_t Extent> class group_with_scratchpad {
21  Group g;
22  sycl::span<std::byte, Extent> scratch;
23 
24 public:
25  group_with_scratchpad(Group g_, sycl::span<std::byte, Extent> scratch_)
26  : g(g_), scratch(scratch_) {}
27  Group get_group() const { return g; }
28  sycl::span<std::byte, Extent> get_memory() const { return scratch; }
29 };
30 
31 // ---- sorters
32 template <typename Compare = std::less<>> class default_sorter {
33  Compare comp;
34  std::byte *scratch;
35  size_t scratch_size;
36 
37 public:
38  template <size_t Extent>
39  default_sorter(sycl::span<std::byte, Extent> scratch_,
40  Compare comp_ = Compare())
41  : comp(comp_), scratch(scratch_.data()), scratch_size(scratch_.size()) {}
42 
43  template <typename Group, typename Ptr>
44  void operator()(Group g, Ptr first, Ptr last) {
45 #ifdef __SYCL_DEVICE_ONLY__
46  using T = typename sycl::detail::GetValueType<Ptr>::type;
47  if (scratch_size >= memory_required<T>(Group::fence_scope, last - first))
48  sycl::detail::merge_sort(g, first, last - first, comp, scratch);
49  // TODO: it's better to add else branch
50 #else
51  (void)g;
52  (void)first;
53  (void)last;
54  throw sycl::exception(
55  std::error_code(PI_ERROR_INVALID_DEVICE, sycl::sycl_category()),
56  "default_sorter constructor is not supported on host device.");
57 #endif
58  }
59 
60  template <typename Group, typename T> T operator()(Group g, T val) {
61 #ifdef __SYCL_DEVICE_ONLY__
62  auto range_size = g.get_local_range().size();
63  if (scratch_size >= memory_required<T>(Group::fence_scope, range_size)) {
64  size_t local_id = g.get_local_linear_id();
65  T *temp = reinterpret_cast<T *>(scratch);
66  ::new (temp + local_id) T(val);
67  sycl::detail::merge_sort(g, temp, range_size, comp,
68  scratch + range_size * sizeof(T));
69  val = temp[local_id];
70  }
71  // TODO: it's better to add else branch
72 #else
73  (void)g;
74  throw sycl::exception(
75  std::error_code(PI_ERROR_INVALID_DEVICE, sycl::sycl_category()),
76  "default_sorter operator() is not supported on host device.");
77 #endif
78  return val;
79  }
80 
81  template <typename T>
82  static constexpr size_t memory_required(sycl::memory_scope,
83  size_t range_size) {
84  return range_size * sizeof(T) + alignof(T);
85  }
86 
87  template <typename T, int dim = 1>
88  static constexpr size_t memory_required(sycl::memory_scope scope,
89  sycl::range<dim> r) {
90  return 2 * memory_required<T>(scope, r.size());
91  }
92 };
93 
95 
96 namespace detail {
97 
98 template <typename T, sorting_order = sorting_order::ascending>
99 struct ConvertToComp {
100  using Type = std::less<T>;
101 };
102 
103 template <typename T> struct ConvertToComp<T, sorting_order::descending> {
104  using Type = std::greater<T>;
105 };
106 } // namespace detail
107 
108 template <typename ValT, sorting_order OrderT = sorting_order::ascending,
109  unsigned int BitsPerPass = 4>
111 
112  std::byte *scratch = nullptr;
113  uint32_t first_bit = 0;
114  uint32_t last_bit = 0;
115  size_t scratch_size = 0;
116 
117  static constexpr uint32_t bits = BitsPerPass;
118 
119 public:
120  template <size_t Extent>
121  radix_sorter(sycl::span<std::byte, Extent> scratch_,
122  const std::bitset<sizeof(ValT) *CHAR_BIT> mask =
123  std::bitset<sizeof(ValT) * CHAR_BIT>(
125  : scratch(scratch_.data()), scratch_size(scratch_.size()) {
126  static_assert((std::is_arithmetic<ValT>::value ||
127  std::is_same<ValT, sycl::half>::value ||
128  std::is_same<ValT, sycl::ext::oneapi::bfloat16>::value),
129  "radix sort is not usable");
130 
131  first_bit = 0;
132  while (first_bit < mask.size() && !mask[first_bit])
133  ++first_bit;
134 
135  last_bit = first_bit;
136  while (last_bit < mask.size() && mask[last_bit])
137  ++last_bit;
138  }
139 
140  template <typename GroupT, typename PtrT>
141  void operator()(GroupT g, PtrT first, PtrT last) {
142  (void)g;
143  (void)first;
144  (void)last;
145 #ifdef __SYCL_DEVICE_ONLY__
146  sycl::detail::privateDynamicSort</*is_key_value=*/false,
147  OrderT == sorting_order::ascending,
148  /*empty*/ 1, BitsPerPass>(
149  g, first, /*empty*/ first, (last - first) > 0 ? (last - first) : 0,
150  scratch, first_bit, last_bit);
151 #else
152  throw sycl::exception(
153  std::error_code(PI_ERROR_INVALID_DEVICE, sycl::sycl_category()),
154  "radix_sorter is not supported on host device.");
155 #endif
156  }
157 
158  template <typename GroupT> ValT operator()(GroupT g, ValT val) {
159  (void)g;
160  (void)val;
161 #ifdef __SYCL_DEVICE_ONLY__
162  ValT result[]{val};
163  sycl::detail::privateStaticSort</*is_key_value=*/false,
164  /*is_blocked=*/true,
165  OrderT == sorting_order::ascending,
166  /*items_per_work_item=*/1, bits>(
167  g, result, /*empty*/ result, scratch, first_bit, last_bit);
168  return result[0];
169 #else
170  throw sycl::exception(
171  std::error_code(PI_ERROR_INVALID_DEVICE, sycl::sycl_category()),
172  "radix_sorter is not supported on host device.");
173 #endif
174  }
175 
176  static constexpr size_t memory_required(sycl::memory_scope scope,
177  size_t range_size) {
178  // Scope is not important so far
179  (void)scope;
180  return range_size * sizeof(ValT) +
181  (1 << bits) * range_size * sizeof(uint32_t) + alignof(uint32_t);
182  }
183 
184  // memory_helpers
185  template <int dimensions = 1>
186  static constexpr size_t memory_required(sycl::memory_scope scope,
187  sycl::range<dimensions> local_range) {
188  // Scope is not important so far
189  (void)scope;
190  return (std::max)(local_range.size() * sizeof(ValT),
191  local_range.size() * (1 << bits) * sizeof(uint32_t));
192  }
193 };
194 
195 } // namespace ext::oneapi::experimental
196 } // __SYCL_INLINE_VER_NAMESPACE(_V1)
197 } // namespace sycl
198 #endif
sycl::_V1::ext::oneapi::experimental::default_sorter::memory_required
static constexpr size_t memory_required(sycl::memory_scope, size_t range_size)
Definition: group_helpers_sorters.hpp:82
sycl::_V1::sycl_category
const std::error_category & sycl_category() noexcept
Definition: exception.cpp:86
sycl::_V1::ext::oneapi::experimental::group_with_scratchpad::group_with_scratchpad
group_with_scratchpad(Group g_, sycl::span< std::byte, Extent > scratch_)
Definition: group_helpers_sorters.hpp:25
sycl::_V1::ext::oneapi::experimental::sorting_order::ascending
@ ascending
__SYCL_INLINE_VER_NAMESPACE
#define __SYCL_INLINE_VER_NAMESPACE(X)
Definition: defines_elementary.hpp:11
sycl::_V1::ext::oneapi::experimental::radix_sorter::radix_sorter
radix_sorter(sycl::span< std::byte, Extent > scratch_, const std::bitset< sizeof(ValT) *CHAR_BIT > mask=std::bitset< sizeof(ValT) *CHAR_BIT >((std::numeric_limits< unsigned long long >::max)()))
Definition: group_helpers_sorters.hpp:121
sycl
Error handling, matching OpenCL plugin semantics.
Definition: access.hpp:14
sycl::_V1::ext::oneapi::experimental::radix_sorter::memory_required
static constexpr size_t memory_required(sycl::memory_scope scope, size_t range_size)
Definition: group_helpers_sorters.hpp:176
max
simd< _Tp, _Abi > max(const simd< _Tp, _Abi > &, const simd< _Tp, _Abi > &) noexcept
sycl::_V1::ext::oneapi::experimental::detail::ConvertToComp::Type
std::less< T > Type
Definition: group_helpers_sorters.hpp:100
sycl::_V1::ext::oneapi::experimental::sorting_order::descending
@ descending
group_sort_impl.hpp
sycl::_V1::ext::oneapi::experimental::default_sorter
Definition: group_helpers_sorters.hpp:32
sycl::_V1::ext::oneapi::experimental::default_sorter::memory_required
static constexpr size_t memory_required(sycl::memory_scope scope, sycl::range< dim > r)
Definition: group_helpers_sorters.hpp:88
sycl::_V1::ext::oneapi::experimental::radix_sorter::operator()
ValT operator()(GroupT g, ValT val)
Definition: group_helpers_sorters.hpp:158
sycl::_V1::ext::oneapi::experimental::sorting_order
sorting_order
Definition: group_helpers_sorters.hpp:94
builtins.hpp
sycl::_V1::ext::oneapi::experimental::radix_sorter::memory_required
static constexpr size_t memory_required(sycl::memory_scope scope, sycl::range< dimensions > local_range)
Definition: group_helpers_sorters.hpp:186
sycl::_V1::ext::oneapi::experimental::radix_sorter::operator()
void operator()(GroupT g, PtrT first, PtrT last)
Definition: group_helpers_sorters.hpp:141
sycl::_V1::ext::oneapi::experimental::group_with_scratchpad::get_memory
sycl::span< std::byte, Extent > get_memory() const
Definition: group_helpers_sorters.hpp:28
sycl::_V1::ext::oneapi::experimental::radix_sorter
Definition: group_helpers_sorters.hpp:110
sycl::_V1::memory_scope
memory_scope
Definition: memory_enums.hpp:26
sycl::_V1::image_channel_order::r
@ r
sycl::_V1::ext::oneapi::experimental::default_sorter::operator()
void operator()(Group g, Ptr first, Ptr last)
Definition: group_helpers_sorters.hpp:44
sycl::_V1::ext::oneapi::experimental::group_with_scratchpad::get_group
Group get_group() const
Definition: group_helpers_sorters.hpp:27
sycl::_V1::ext::oneapi::experimental::detail::ConvertToComp< T, sorting_order::descending >::Type
std::greater< T > Type
Definition: group_helpers_sorters.hpp:104
sycl::_V1::ext::oneapi::experimental::default_sorter::default_sorter
default_sorter(sycl::span< std::byte, Extent > scratch_, Compare comp_=Compare())
Definition: group_helpers_sorters.hpp:39
sycl::_V1::ext::intel::experimental::byte
unsigned char byte
Definition: online_compiler.hpp:22
sycl::_V1::ext::oneapi::experimental::detail::ConvertToComp
Definition: group_helpers_sorters.hpp:99
sycl::_V1::ext::oneapi::experimental::default_sorter::operator()
T operator()(Group g, T val)
Definition: group_helpers_sorters.hpp:60