DPC++ Runtime
Runtime libraries for oneAPI DPC++
group_helpers_sorters.hpp
Go to the documentation of this file.
1 //==------- group_helpers_sorters.hpp - SYCL sorters and group helpers -----==//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #pragma once
10 
11 #if (!defined(_HAS_STD_BYTE) || _HAS_STD_BYTE != 0)
12 
13 #include <sycl/aliases.hpp> // for half
14 #include <sycl/detail/pi.h> // for PI_ERROR_INVALID_DEVICE
15 #include <sycl/exception.hpp> // for sycl_category, exception
16 #include <sycl/ext/oneapi/bfloat16.hpp> // for bfloat16
17 #include <sycl/memory_enums.hpp> // for memory_scope
18 #include <sycl/range.hpp> // for range
19 #include <sycl/sycl_span.hpp> // for span
20 
21 #ifdef __SYCL_DEVICE_ONLY__
23 #endif
24 
25 #include <bitset> // for bitset
26 #include <cstddef> // for size_t, byte
27 #include <functional> // for less, greater
28 #include <limits.h> // for CHAR_BIT
29 #include <limits> // for numeric_limits
30 #include <stdint.h> // for uint32_t
31 #include <system_error> // for error_code
32 #include <type_traits> // for is_same, is_arithmetic
33 
34 namespace sycl {
35 inline namespace _V1 {
36 namespace ext::oneapi::experimental {
37 
38 // ---- group helpers
39 template <typename Group, size_t Extent> class group_with_scratchpad {
40  Group g;
42 
43 public:
45  : g(g_), scratch(scratch_) {}
46  Group get_group() const { return g; }
47  sycl::span<std::byte, Extent> get_memory() const { return scratch; }
48 };
49 
50 // ---- sorters
51 template <typename Compare = std::less<>> class default_sorter {
52  Compare comp;
53  std::byte *scratch;
54  size_t scratch_size;
55 
56 public:
57  template <size_t Extent>
59  Compare comp_ = Compare())
60  : comp(comp_), scratch(scratch_.data()), scratch_size(scratch_.size()) {}
61 
62  template <typename Group, typename Ptr>
63  void operator()(Group g, Ptr first, Ptr last) {
64 #ifdef __SYCL_DEVICE_ONLY__
65  using T = typename sycl::detail::GetValueType<Ptr>::type;
66  if (scratch_size >= memory_required<T>(Group::fence_scope, last - first))
67  sycl::detail::merge_sort(g, first, last - first, comp, scratch);
68  // TODO: it's better to add else branch
69 #else
70  (void)g;
71  (void)first;
72  (void)last;
73  throw sycl::exception(
74  std::error_code(PI_ERROR_INVALID_DEVICE, sycl::sycl_category()),
75  "default_sorter constructor is not supported on host device.");
76 #endif
77  }
78 
79  template <typename Group, typename T> T operator()(Group g, T val) {
80 #ifdef __SYCL_DEVICE_ONLY__
81  auto range_size = g.get_local_range().size();
82  if (scratch_size >= memory_required<T>(Group::fence_scope, range_size)) {
83  size_t local_id = g.get_local_linear_id();
84  T *temp = reinterpret_cast<T *>(scratch);
85  ::new (temp + local_id) T(val);
86  sycl::detail::merge_sort(g, temp, range_size, comp,
87  scratch + range_size * sizeof(T));
88  val = temp[local_id];
89  }
90  // TODO: it's better to add else branch
91 #else
92  (void)g;
93  throw sycl::exception(
94  std::error_code(PI_ERROR_INVALID_DEVICE, sycl::sycl_category()),
95  "default_sorter operator() is not supported on host device.");
96 #endif
97  return val;
98  }
99 
100  template <typename T>
101  static constexpr size_t memory_required(sycl::memory_scope,
102  size_t range_size) {
103  return range_size * sizeof(T) + alignof(T);
104  }
105 
106  template <typename T, int dim = 1>
107  static constexpr size_t memory_required(sycl::memory_scope scope,
108  sycl::range<dim> r) {
109  return 2 * memory_required<T>(scope, r.size());
110  }
111 };
112 
/// Sort direction, used as a template parameter of radix_sorter.
enum class sorting_order { ascending, descending };

namespace detail {

/// Maps a sorting_order to the corresponding standard comparator type:
/// ascending -> std::less<T> (the default), descending -> std::greater<T>.
template <typename T, sorting_order = sorting_order::ascending>
struct ConvertToComp {
  using Type = std::less<T>;
};

template <typename T> struct ConvertToComp<T, sorting_order::descending> {
  using Type = std::greater<T>;
};
} // namespace detail
126 
127 template <typename ValT, sorting_order OrderT = sorting_order::ascending,
128  unsigned int BitsPerPass = 4>
130 
131  std::byte *scratch = nullptr;
132  uint32_t first_bit = 0;
133  uint32_t last_bit = 0;
134  size_t scratch_size = 0;
135 
136  static constexpr uint32_t bits = BitsPerPass;
137 
138 public:
139  template <size_t Extent>
141  const std::bitset<sizeof(ValT) *CHAR_BIT> mask =
142  std::bitset<sizeof(ValT) * CHAR_BIT>(
144  : scratch(scratch_.data()), scratch_size(scratch_.size()) {
145  static_assert((std::is_arithmetic<ValT>::value ||
146  std::is_same<ValT, sycl::half>::value ||
147  std::is_same<ValT, sycl::ext::oneapi::bfloat16>::value),
148  "radix sort is not usable");
149 
150  first_bit = 0;
151  while (first_bit < mask.size() && !mask[first_bit])
152  ++first_bit;
153 
154  last_bit = first_bit;
155  while (last_bit < mask.size() && mask[last_bit])
156  ++last_bit;
157  }
158 
159  template <typename GroupT, typename PtrT>
160  void operator()(GroupT g, PtrT first, PtrT last) {
161  (void)g;
162  (void)first;
163  (void)last;
164 #ifdef __SYCL_DEVICE_ONLY__
165  sycl::detail::privateDynamicSort</*is_key_value=*/false,
166  OrderT == sorting_order::ascending,
167  /*empty*/ 1, BitsPerPass>(
168  g, first, /*empty*/ first, (last - first) > 0 ? (last - first) : 0,
169  scratch, first_bit, last_bit);
170 #else
171  throw sycl::exception(
172  std::error_code(PI_ERROR_INVALID_DEVICE, sycl::sycl_category()),
173  "radix_sorter is not supported on host device.");
174 #endif
175  }
176 
177  template <typename GroupT> ValT operator()(GroupT g, ValT val) {
178  (void)g;
179  (void)val;
180 #ifdef __SYCL_DEVICE_ONLY__
181  ValT result[]{val};
182  sycl::detail::privateStaticSort</*is_key_value=*/false,
183  /*is_blocked=*/true,
184  OrderT == sorting_order::ascending,
185  /*items_per_work_item=*/1, bits>(
186  g, result, /*empty*/ result, scratch, first_bit, last_bit);
187  return result[0];
188 #else
189  throw sycl::exception(
190  std::error_code(PI_ERROR_INVALID_DEVICE, sycl::sycl_category()),
191  "radix_sorter is not supported on host device.");
192 #endif
193  }
194 
195  static constexpr size_t memory_required(sycl::memory_scope scope,
196  size_t range_size) {
197  // Scope is not important so far
198  (void)scope;
199  return range_size * sizeof(ValT) +
200  (1 << bits) * range_size * sizeof(uint32_t) + alignof(uint32_t);
201  }
202 
203  // memory_helpers
204  template <int dimensions = 1>
205  static constexpr size_t memory_required(sycl::memory_scope scope,
206  sycl::range<dimensions> local_range) {
207  // Scope is not important so far
208  (void)scope;
209  return (std::max)(local_range.size() * sizeof(ValT),
210  local_range.size() * (1 << bits) * sizeof(uint32_t));
211  }
212 };
213 
214 } // namespace ext::oneapi::experimental
215 } // namespace _V1
216 } // namespace sycl
217 #endif
default_sorter(sycl::span< std::byte, Extent > scratch_, Compare comp_=Compare())
static constexpr size_t memory_required(sycl::memory_scope scope, sycl::range< dim > r)
static constexpr size_t memory_required(sycl::memory_scope, size_t range_size)
group_with_scratchpad(Group g_, sycl::span< std::byte, Extent > scratch_)
radix_sorter(sycl::span< std::byte, Extent > scratch_, const std::bitset< sizeof(ValT) *CHAR_BIT > mask=std::bitset< sizeof(ValT) *CHAR_BIT >((std::numeric_limits< unsigned long long >::max)()))
static constexpr size_t memory_required(sycl::memory_scope scope, sycl::range< dimensions > local_range)
static constexpr size_t memory_required(sycl::memory_scope scope, size_t range_size)
Defines the iteration domain of either a single work-group in a parallel dispatch,...
Definition: range.hpp:26
size_t size() const
Definition: range.hpp:56
fence_scope
The scope that fence() operation should apply to.
Definition: common.hpp:350
const std::error_category & sycl_category() noexcept
Definition: exception.cpp:82
Definition: access.hpp:18
error_code
Definition: defs.hpp:59