DPC++ Runtime
Runtime libraries for oneAPI DPC++
group_helpers_sorters.hpp
Go to the documentation of this file.
1 //==------- group_helpers_sorters.hpp - SYCL sorters and group helpers -----==//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #pragma once
10 
11 #if (!defined(_HAS_STD_BYTE) || _HAS_STD_BYTE != 0)
12 
13 #include <sycl/aliases.hpp> // for half
14 #include <sycl/builtins.hpp> // for min
15 #include <sycl/detail/pi.h> // for PI_ERROR_INVALID_DEVICE
16 #include <sycl/exception.hpp> // for sycl_category, exception
17 #include <sycl/ext/oneapi/bfloat16.hpp> // for bfloat16
19 #include <sycl/memory_enums.hpp> // for memory_scope
20 #include <sycl/range.hpp> // for range
21 #include <sycl/sycl_span.hpp> // for span
22 
23 #ifdef __SYCL_DEVICE_ONLY__
25 #endif
26 
27 #include <bitset> // for bitset
28 #include <cstddef> // for size_t, byte
29 #include <functional> // for less, greater
30 #include <limits.h> // for CHAR_BIT
31 #include <limits> // for numeric_limits
32 #include <stdint.h> // for uint32_t
33 #include <system_error> // for error_code
34 #include <type_traits> // for is_same, is_arithmetic
35 
36 namespace sycl {
37 inline namespace _V1 {
38 namespace ext::oneapi::experimental {
39 
40 enum class group_algorithm_data_placement : std::uint8_t { blocked, striped };
41 
43  : detail::compile_time_property_key<detail::PropKind::InputDataPlacement> {
44  template <group_algorithm_data_placement Placement>
47  std::integral_constant<group_algorithm_data_placement, Placement>>;
48 };
49 
51  : detail::compile_time_property_key<detail::PropKind::OutputDataPlacement> {
52  template <group_algorithm_data_placement Placement>
55  std::integral_constant<group_algorithm_data_placement, Placement>>;
56 };
57 
58 template <group_algorithm_data_placement Placement>
61 
62 template <group_algorithm_data_placement Placement>
65 
66 namespace detail {
67 
68 template <typename Properties>
69 constexpr bool isInputBlocked(Properties properties) {
70  if constexpr (properties.template has_property<input_data_placement_key>())
71  return properties.template get_property<input_data_placement_key>() ==
72  input_data_placement<group_algorithm_data_placement::blocked>;
73  else
74  return true;
75 }
76 
77 template <typename Properties>
78 constexpr bool isOutputBlocked(Properties properties) {
79  if constexpr (properties.template has_property<output_data_placement_key>())
80  return properties.template get_property<output_data_placement_key>() ==
81  output_data_placement<group_algorithm_data_placement::blocked>;
82  else
83  return true;
84 }
85 
86 } // namespace detail
87 
88 // ---- group helpers
89 template <typename Group, size_t Extent> class group_with_scratchpad {
90  Group g;
92 
93 public:
95  : g(g_), scratch(scratch_) {}
96  Group get_group() const { return g; }
97  sycl::span<std::byte, Extent> get_memory() const { return scratch; }
98 };
99 
101 
102 namespace default_sorters {
103 
104 template <typename CompareT = std::less<>> class joint_sorter {
105  CompareT comp;
106  sycl::span<std::byte> scratch;
107 
108 public:
109  template <size_t Extent>
111  CompareT comp_ = CompareT())
112  : comp(comp_), scratch(scratch_) {}
113 
114  template <typename Group, typename Ptr>
115  void operator()([[maybe_unused]] Group g, [[maybe_unused]] Ptr first,
116  [[maybe_unused]] Ptr last) {
117 #ifdef __SYCL_DEVICE_ONLY__
118  using T = typename sycl::detail::GetValueType<Ptr>::type;
119  size_t n = std::distance(first, last);
120  T *scratch_begin = sycl::detail::align_scratch<T>(scratch, g, n);
121  sycl::detail::merge_sort(g, first, n, comp, scratch_begin);
122 #else
123  throw sycl::exception(
124  std::error_code(PI_ERROR_INVALID_DEVICE, sycl::sycl_category()),
125  "default_sorter constructor is not supported on host device.");
126 #endif
127  }
128 
129  template <typename T>
130  static size_t memory_required(sycl::memory_scope, size_t range_size) {
131  return range_size * sizeof(T) + alignof(T);
132  }
133 };
134 
135 template <typename T, typename CompareT = std::less<>,
136  std::size_t ElementsPerWorkItem = 1>
138  CompareT comp;
139  sycl::span<std::byte> scratch;
140 
141 public:
142  template <std::size_t Extent>
144  CompareT comp_ = CompareT{})
145  : comp(comp_), scratch(scratch_) {}
146 
147  template <typename Group> T operator()([[maybe_unused]] Group g, T val) {
148 #ifdef __SYCL_DEVICE_ONLY__
149  std::size_t local_id = g.get_local_linear_id();
150  auto range_size = g.get_local_range().size();
151  T *scratch_begin = sycl::detail::align_scratch<T>(
152  scratch, g, /* output storage and temporary storage */ 2 * range_size);
153  scratch_begin[local_id] = val;
154  sycl::detail::merge_sort(g, scratch_begin, range_size, comp,
155  scratch_begin + range_size);
156  val = scratch_begin[local_id];
157 #else
158  throw sycl::exception(
159  std::error_code(PI_ERROR_INVALID_DEVICE, sycl::sycl_category()),
160  "default_sorter operator() is not supported on host device.");
161 #endif
162  return val;
163  }
164 
165  template <typename Group, typename Properties>
166  void operator()([[maybe_unused]] Group g,
167  [[maybe_unused]] sycl::span<T, ElementsPerWorkItem> values,
168  [[maybe_unused]] Properties properties) {
169 #ifdef __SYCL_DEVICE_ONLY__
170  std::size_t local_id = g.get_local_linear_id();
171  auto wg_size = g.get_local_range().size();
172  auto number_of_elements = wg_size * ElementsPerWorkItem;
173  T *scratch_begin = sycl::detail::align_scratch<T>(
174  scratch, g,
175  /* output storage and temporary storage */ 2 * number_of_elements);
176  for (std::uint32_t i = 0; i < ElementsPerWorkItem; ++i)
177  scratch_begin[local_id * ElementsPerWorkItem + i] = values[i];
178  sycl::detail::merge_sort(g, scratch_begin, number_of_elements, comp,
179  scratch_begin + number_of_elements);
180 
181  std::size_t shift{};
182  for (std::uint32_t i = 0; i < ElementsPerWorkItem; ++i) {
183  if constexpr (detail::isOutputBlocked(properties)) {
184  shift = local_id * ElementsPerWorkItem + i;
185  } else {
186  shift = i * wg_size + local_id;
187  }
188  values[i] = scratch_begin[shift];
189  }
190 #endif
191  }
192 
193  static std::size_t memory_required([[maybe_unused]] sycl::memory_scope scope,
194  size_t range_size) {
195  // We need a space (in bytes) for the buffer of output values and the
196  // temporary buffer. Where number of elements in each buffer is range_size
197  // (group size) multiplied by elements per work item. Also we have to align
198  // these two buffers, so need an additional space of size alignof(T).
199  return 2 * range_size * ElementsPerWorkItem * sizeof(T) + alignof(T);
200  }
201 };
202 
203 template <typename KeyTy, typename ValueTy, typename CompareT = std::less<>,
204  std::size_t ElementsPerWorkItem = 1>
206  CompareT comp;
207  sycl::span<std::byte> scratch;
208 
209 public:
210  template <std::size_t Extent>
212  CompareT comp_ = {})
213  : comp(comp_), scratch(scratch_) {}
214 
215  template <typename Group>
216  std::tuple<KeyTy, ValueTy> operator()([[maybe_unused]] Group g, KeyTy key,
217  ValueTy value) {
218  static_assert(ElementsPerWorkItem == 1,
219  "ElementsPerWorkItem must be equal 1");
220 #ifdef __SYCL_DEVICE_ONLY__
221  auto range_size = g.get_local_linear_range();
222  std::size_t local_id = g.get_local_linear_id();
223  auto [keys_scratch_begin, values_scratch_begin] =
224  sycl::detail::align_key_value_scratch<KeyTy, ValueTy>(scratch, g,
225  2 * range_size);
226 
227  keys_scratch_begin[local_id] = key;
228  values_scratch_begin[local_id] = value;
229 
230  auto scratch_begin = sycl::detail::key_value_iterator(keys_scratch_begin,
231  values_scratch_begin);
232  auto scratch_temp_begin = sycl::detail::key_value_iterator(
233  keys_scratch_begin + range_size, values_scratch_begin + range_size);
234  sycl::detail::merge_sort(
235  g, scratch_begin, range_size,
236  [this](auto x, auto y) { return comp(std::get<0>(x), std::get<0>(y)); },
237  scratch_temp_begin);
238 
239  key = keys_scratch_begin[local_id];
240  value = values_scratch_begin[local_id];
241 #endif
242  return std::make_tuple(key, value);
243  }
244 
245  template <typename Group, typename Properties>
246  void
247  operator()([[maybe_unused]] Group g,
248  [[maybe_unused]] sycl::span<KeyTy, ElementsPerWorkItem> keys,
249  [[maybe_unused]] sycl::span<ValueTy, ElementsPerWorkItem> values,
250  [[maybe_unused]] Properties property) {
251 #ifdef __SYCL_DEVICE_ONLY__
252  auto range_size = g.get_local_linear_range();
253  std::size_t local_id = g.get_local_linear_id();
254  auto number_of_elements = range_size * ElementsPerWorkItem;
255  auto [keys_scratch_begin, values_scratch_begin] =
256  sycl::detail::align_key_value_scratch<KeyTy, ValueTy>(
257  scratch, g, 2 * number_of_elements);
258 
259  std::size_t shift{};
260  for (std::uint32_t i = 0; i < ElementsPerWorkItem; ++i) {
261  if constexpr (detail::isInputBlocked(property)) {
262  shift = local_id * ElementsPerWorkItem + i;
263  } else {
264  shift = i * range_size + local_id;
265  }
266  keys_scratch_begin[shift] = keys[i];
267  values_scratch_begin[shift] = values[i];
268  }
269 
270  // We need a barrier here if input is striped.
271  // When input is blocked, each work item initializes elements which it is
272  // going to process, so we can start bubble sort (the first step in
273  // merge_sort function) without any barrier because it sorts elements within
274  // work item. When input is striped, work item initializes elements which
275  // can possibly be processed by another work items, so we have to put a
276  // barrier.
277  if constexpr (!detail::isInputBlocked(property))
279 
280  auto scratch_begin = sycl::detail::key_value_iterator(keys_scratch_begin,
281  values_scratch_begin);
282  auto scratch_temp_begin = sycl::detail::key_value_iterator(
283  keys_scratch_begin + number_of_elements,
284  values_scratch_begin + number_of_elements);
285  sycl::detail::merge_sort(
286  g, scratch_begin, number_of_elements,
287  [this](auto x, auto y) { return comp(std::get<0>(x), std::get<0>(y)); },
288  scratch_temp_begin);
289 
290  // from temp
291  for (std::uint32_t i = 0; i < ElementsPerWorkItem; ++i) {
292  if constexpr (detail::isOutputBlocked(property)) {
293  shift = local_id * ElementsPerWorkItem + i;
294  } else {
295  shift = i * range_size + local_id;
296  }
297 
298  keys[i] = std::get<0>(scratch_begin[shift]);
299  values[i] = std::get<1>(scratch_begin[shift]);
300  }
301 #endif
302  }
303 
304  static std::size_t memory_required([[maybe_unused]] sycl::memory_scope scope,
305  std::size_t range_size) {
306  // We need a space (in bytes) for the following buffers:
307  // 1. Output buffer for keys and temporary buffer for keys.
308  // 2. Output buffer for values and temporary buffer for values.
309  // Where number of elements in each buffer is range_size (group size)
310  // multiplied by elements per work item. We have to align buffers of keys
311  // and buffers of values, so need an additional space equal to maximum
312  // between alignment requirements of types KeyTy and ValueTy.
313  return 2 * range_size * ElementsPerWorkItem *
314  (sizeof(KeyTy) + sizeof(ValueTy)) +
315  (std::max)(alignof(KeyTy), alignof(ValueTy));
316  }
317 };
318 } // namespace default_sorters
319 
320 namespace radix_sorters {
321 
322 template <typename ValT, sorting_order OrderT = sorting_order::ascending,
323  unsigned int BitsPerPass = 4>
325 
326  sycl::span<std::byte> scratch;
327  uint32_t first_bit = 0;
328  uint32_t last_bit = 0;
329 
330  static constexpr uint32_t bits = BitsPerPass;
331  using bitset_t = std::bitset<sizeof(ValT) * CHAR_BIT>;
332 
333 public:
334  template <std::size_t Extent>
336  const bitset_t mask = bitset_t{}.set())
337  : scratch(scratch_) {
338  static_assert((std::is_arithmetic<ValT>::value ||
339  std::is_same<ValT, sycl::half>::value ||
340  std::is_same<ValT, sycl::ext::oneapi::bfloat16>::value),
341  "radix sort is not supported for the given type");
342 
343  for (first_bit = 0; first_bit < mask.size() && !mask[first_bit];
344  ++first_bit)
345  ;
346  for (last_bit = first_bit; last_bit < mask.size() && mask[last_bit];
347  ++last_bit)
348  ;
349  }
350 
351  template <typename GroupT, typename PtrT>
352  void operator()([[maybe_unused]] GroupT g, [[maybe_unused]] PtrT first,
353  [[maybe_unused]] PtrT last) {
354 #ifdef __SYCL_DEVICE_ONLY__
355  sycl::detail::privateDynamicSort</*is_key_value=*/false,
356  OrderT == sorting_order::ascending,
357  /*empty*/ 1, BitsPerPass>(
358  g, first, /*empty*/ first, std::distance(first, last), scratch.data(),
359  first_bit, last_bit);
360 #else
361  throw sycl::exception(
362  std::error_code(PI_ERROR_INVALID_DEVICE, sycl::sycl_category()),
363  "radix_sorter is not supported on host device.");
364 #endif
365  }
366 
367  static constexpr std::size_t
368  memory_required([[maybe_unused]] sycl::memory_scope scope,
369  std::size_t range_size) {
370  return range_size * sizeof(ValT) +
371  (1 << bits) * range_size * sizeof(uint32_t) + alignof(uint32_t);
372  }
373 };
374 
375 template <typename ValT, sorting_order OrderT = sorting_order::ascending,
376  size_t ElementsPerWorkItem = 1, unsigned int BitsPerPass = 4>
378 
379  sycl::span<std::byte> scratch;
380  uint32_t first_bit = 0;
381  uint32_t last_bit = 0;
382 
383  static constexpr uint32_t bits = BitsPerPass;
384  using bitset_t = std::bitset<sizeof(ValT) * CHAR_BIT>;
385 
386 public:
387  template <std::size_t Extent>
389  const bitset_t mask = bitset_t{}.set())
390  : scratch(scratch_) {
391  static_assert((std::is_arithmetic<ValT>::value ||
392  std::is_same<ValT, sycl::half>::value ||
393  std::is_same<ValT, sycl::ext::oneapi::bfloat16>::value),
394  "radix sort is not usable");
395 
396  for (first_bit = 0; first_bit < mask.size() && !mask[first_bit];
397  ++first_bit)
398  ;
399  for (last_bit = first_bit; last_bit < mask.size() && mask[last_bit];
400  ++last_bit)
401  ;
402  }
403 
404  template <typename GroupT>
405  ValT operator()([[maybe_unused]] GroupT g, [[maybe_unused]] ValT val) {
406 #ifdef __SYCL_DEVICE_ONLY__
407  ValT result[]{val};
408  sycl::detail::privateStaticSort</*is_key_value=*/false,
409  /*is_input_blocked=*/true,
410  /*is_output_blocked=*/true,
411  OrderT == sorting_order::ascending,
412  /*items_per_work_item=*/1, bits>(
413  g, result, /*empty*/ result, scratch.data(), first_bit, last_bit);
414  return result[0];
415 #else
416  throw sycl::exception(
417  std::error_code(PI_ERROR_INVALID_DEVICE, sycl::sycl_category()),
418  "radix_sorter is not supported on host device.");
419 #endif
420  }
421 
422  template <typename Group, typename Properties>
423  void operator()([[maybe_unused]] Group g,
424  [[maybe_unused]] sycl::span<ValT, ElementsPerWorkItem> values,
425  [[maybe_unused]] Properties properties) {
426 #ifdef __SYCL_DEVICE_ONLY__
427  sycl::detail::privateStaticSort<
428  /*is_key_value=*/false, /*is_input_blocked=*/true,
430  ElementsPerWorkItem, bits>(g, values.data(), /*empty*/ values.data(),
431  scratch.data(), first_bit, last_bit);
432 #endif
433  }
434 
435  static constexpr size_t
436  memory_required([[maybe_unused]] sycl::memory_scope scope,
437  size_t range_size) {
438  return (std::max)(range_size * sizeof(ValT),
439  range_size * (1 << bits) * sizeof(uint32_t));
440  }
441 };
442 
443 template <typename KeyTy, typename ValueTy,
445  size_t ElementsPerWorkItem = 1, unsigned int BitsPerPass = 4>
447  sycl::span<std::byte> scratch;
448  uint32_t first_bit;
449  uint32_t last_bit;
450 
451  static constexpr uint32_t bits = BitsPerPass;
452  using bitset_t = std::bitset<sizeof(KeyTy) * CHAR_BIT>;
453 
454 public:
455  template <std::size_t Extent>
457  const bitset_t mask = bitset_t{}.set())
458  : scratch(scratch_) {
459  static_assert((std::is_arithmetic<KeyTy>::value ||
460  std::is_same<KeyTy, sycl::half>::value),
461  "radix sort is not usable");
462  for (first_bit = 0; first_bit < mask.size() && !mask[first_bit];
463  ++first_bit)
464  ;
465  for (last_bit = first_bit; last_bit < mask.size() && mask[last_bit];
466  ++last_bit)
467  ;
468  }
469 
470  template <typename Group>
471  std::tuple<KeyTy, ValueTy> operator()([[maybe_unused]] Group g, KeyTy key,
472  ValueTy val) {
473  static_assert(ElementsPerWorkItem == 1, "ElementsPerWorkItem must be 1");
474  KeyTy key_result[]{key};
475  ValueTy val_result[]{val};
476 #ifdef __SYCL_DEVICE_ONLY__
477  sycl::detail::privateStaticSort<
478  /*is_key_value=*/true,
479  /*is_input_blocked=*/true,
480  /*is_output_blocked=*/true, Order == sorting_order::ascending, 1, bits>(
481  g, key_result, val_result, scratch.data(), first_bit, last_bit);
482 #endif
483  key = key_result[0];
484  val = val_result[0];
485  return {key, val};
486  }
487 
488  template <typename Group, typename Properties>
489  void
490  operator()([[maybe_unused]] Group g,
491  [[maybe_unused]] sycl::span<KeyTy, ElementsPerWorkItem> keys,
492  [[maybe_unused]] sycl::span<ValueTy, ElementsPerWorkItem> vals,
493  [[maybe_unused]] Properties properties) {
494 #ifdef __SYCL_DEVICE_ONLY__
495  sycl::detail::privateStaticSort<
496  /*is_key_value=*/true, detail::isInputBlocked(properties),
498  ElementsPerWorkItem, bits>(g, keys.data(), vals.data(), scratch.data(),
499  first_bit, last_bit);
500 #endif
501  }
502 
503  static constexpr std::size_t memory_required(sycl::memory_scope,
504  std::size_t range_size) {
505  return (std::max)(range_size * ElementsPerWorkItem *
506  (sizeof(KeyTy) + sizeof(ValueTy)),
507  range_size * (1 << bits) * sizeof(uint32_t));
508  }
509 };
510 } // namespace radix_sorters
511 
512 } // namespace ext::oneapi::experimental
513 } // namespace _V1
514 } // namespace sycl
515 #endif
static std::size_t memory_required([[maybe_unused]] sycl::memory_scope scope, std::size_t range_size)
std::tuple< KeyTy, ValueTy > operator()([[maybe_unused]] Group g, KeyTy key, ValueTy value)
void operator()([[maybe_unused]] Group g, [[maybe_unused]] sycl::span< KeyTy, ElementsPerWorkItem > keys, [[maybe_unused]] sycl::span< ValueTy, ElementsPerWorkItem > values, [[maybe_unused]] Properties property)
group_key_value_sorter(sycl::span< std::byte, Extent > scratch_, CompareT comp_={})
static std::size_t memory_required([[maybe_unused]] sycl::memory_scope scope, size_t range_size)
group_sorter(sycl::span< std::byte, Extent > scratch_, CompareT comp_=CompareT{})
void operator()([[maybe_unused]] Group g, [[maybe_unused]] sycl::span< T, ElementsPerWorkItem > values, [[maybe_unused]] Properties properties)
void operator()([[maybe_unused]] Group g, [[maybe_unused]] Ptr first, [[maybe_unused]] Ptr last)
joint_sorter(sycl::span< std::byte, Extent > scratch_, CompareT comp_=CompareT())
static size_t memory_required(sycl::memory_scope, size_t range_size)
group_with_scratchpad(Group g_, sycl::span< std::byte, Extent > scratch_)
std::tuple< KeyTy, ValueTy > operator()([[maybe_unused]] Group g, KeyTy key, ValueTy val)
void operator()([[maybe_unused]] Group g, [[maybe_unused]] sycl::span< KeyTy, ElementsPerWorkItem > keys, [[maybe_unused]] sycl::span< ValueTy, ElementsPerWorkItem > vals, [[maybe_unused]] Properties properties)
group_key_value_sorter(sycl::span< std::byte, Extent > scratch_, const bitset_t mask=bitset_t{}.set())
static constexpr std::size_t memory_required(sycl::memory_scope, std::size_t range_size)
void operator()([[maybe_unused]] Group g, [[maybe_unused]] sycl::span< ValT, ElementsPerWorkItem > values, [[maybe_unused]] Properties properties)
static constexpr size_t memory_required([[maybe_unused]] sycl::memory_scope scope, size_t range_size)
ValT operator()([[maybe_unused]] GroupT g, [[maybe_unused]] ValT val)
group_sorter(sycl::span< std::byte, Extent > scratch_, const bitset_t mask=bitset_t{}.set())
joint_sorter(sycl::span< std::byte, Extent > scratch_, const bitset_t mask=bitset_t{}.set())
static constexpr std::size_t memory_required([[maybe_unused]] sycl::memory_scope scope, std::size_t range_size)
void operator()([[maybe_unused]] GroupT g, [[maybe_unused]] PtrT first, [[maybe_unused]] PtrT last)
constexpr _SYCL_SPAN_INLINE_VISIBILITY pointer data() const noexcept
Definition: sycl_span.hpp:378
constexpr tuple< Ts... > make_tuple(Ts... Args)
Definition: tuple.hpp:35
constexpr bool isInputBlocked(Properties properties)
constexpr bool isOutputBlocked(Properties properties)
constexpr output_data_placement_key::value_t< Placement > output_data_placement
constexpr input_data_placement_key::value_t< Placement > input_data_placement
void group_barrier(ext::oneapi::experimental::root_group< dimensions > G, memory_scope FenceScope=decltype(G)::fence_scope)
Definition: root_group.hpp:100
const std::error_category & sycl_category() noexcept
Definition: exception.cpp:59
autodecltype(x) x
Definition: access.hpp:18
error_code
Definition: defs.hpp:70