DPC++ Runtime
Runtime libraries for oneAPI DPC++
group_sort_impl.hpp
Go to the documentation of this file.
1 //==------------ group_sort_impl.hpp ---------------------------------------==//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 // This file includes some functions for group sorting algorithm implementations
9 //
10 
11 #pragma once
12 
13 #include <sycl/detail/helpers.hpp>
14 #include <sycl/group_barrier.hpp>
15 #include <sycl/multi_ptr.hpp>
16 
17 #ifdef __SYCL_DEVICE_ONLY__
18 
19 namespace sycl {
21 namespace detail {
22 
23 // ---- merge sort implementation
24 
25 // following two functions could be useless if std::[lower|upper]_bound worked
26 // well
/// Binary search over the index range [first, last) of \p acc.
///
/// \param acc   indexable sequence (anything supporting `acc[i]`).
/// \param first index of the first element of the searched range.
/// \param last  index one past the last element of the searched range.
/// \param value key to look for.
/// \param comp  ordering predicate, invoked as `comp(element, value)`.
/// \return the smallest index `i` in [first, last] such that
///         `comp(acc[i], value)` is false, assuming the range is partitioned
///         with respect to that predicate; `last` if there is no such index.
template <typename Acc, typename Value, typename Compare>
std::size_t lower_bound(Acc acc, std::size_t first, std::size_t last,
                        const Value &value, Compare comp) {
  // Number of candidate positions still to be examined.
  std::size_t remaining = last - first;
  while (remaining != 0) {
    const std::size_t half = remaining / 2;
    const std::size_t probe = first + half;
    if (comp(acc[probe], value)) {
      // Everything up to and including `probe` orders before `value`.
      first = probe + 1;
      remaining -= half + 1;
    } else {
      // The answer is at `probe` or to the left of it.
      remaining = half;
    }
  }
  return first;
}
44 
45 template <typename Acc, typename Value, typename Compare>
46 std::size_t upper_bound(Acc acc, const std::size_t first,
47  const std::size_t last, const Value &value,
48  Compare comp) {
49  return detail::lower_bound(acc, first, last, value,
50  [comp](auto x, auto y) { return !comp(y, x); });
51 }
52 
53 // swap for all data types including tuple-like types
/// Generic overload: exchanges the values of two lvalues via std::swap.
template <typename T>
void swap_tuples(T &a, T &b) {
  std::swap(a, b);
}
55 
/// Overload for two-element tuple-like objects received as rvalues.
///
/// The objects are swapped element by element through std::get, first
/// element 0 and then element 1.
template <template <typename...> class TupleLike, typename T1, typename T2>
void swap_tuples(TupleLike<T1, T2> &&a, TupleLike<T1, T2> &&b) {
  // First components.
  std::swap(std::get<0>(a), std::get<0>(b));
  // Second components.
  std::swap(std::get<1>(a), std::get<1>(b));
}
61 
/// Maps an iterator type to the type of the elements it refers to.
///
/// The primary template defers to std::iterator_traits.
template <typename Iterator>
struct GetValueType {
  using type = typename std::iterator_traits<Iterator>::value_type;
};
65 
66 template <typename ElementType, access::address_space Space,
67  access::decorated IsDecorated>
68 struct GetValueType<sycl::multi_ptr<ElementType, Space, IsDecorated>> {
69  using type = ElementType;
70 };
71 
/// Stores \p val at position \p idx of \p ptr.
///
/// Plain assignment is not used for the first write because the destination
/// may still be raw memory; in that case the element is copy-constructed in
/// place with placement new instead.
///
/// \param ptr      destination; must support `ptr + idx` and `ptr[idx]`.
/// \param idx      position of the element to write.
/// \param val      value to store.
/// \param is_first true when this is the first write to the destination.
template <typename Acc, typename T>
void set_value(Acc ptr, const std::size_t idx, const T &val, bool is_first) {
  if (!is_first) {
    ptr[idx] = val;
    return;
  }
  ::new (ptr + idx) T(val);
}
82 
83 template <typename InAcc, typename OutAcc, typename Compare>
84 void merge(const std::size_t offset, InAcc &in_acc1, OutAcc &out_acc1,
85  const std::size_t start_1, const std::size_t end_1,
86  const std::size_t end_2, const std::size_t start_out, Compare comp,
87  const std::size_t chunk, bool is_first) {
88  const std::size_t start_2 = end_1;
89  // Borders of the sequences to merge within this call
90  const std::size_t local_start_1 =
91  sycl::min(static_cast<std::size_t>(offset + start_1), end_1);
92  const std::size_t local_end_1 =
93  sycl::min(static_cast<std::size_t>(local_start_1 + chunk), end_1);
94  const std::size_t local_start_2 =
95  sycl::min(static_cast<std::size_t>(offset + start_2), end_2);
96  const std::size_t local_end_2 =
97  sycl::min(static_cast<std::size_t>(local_start_2 + chunk), end_2);
98 
99  const std::size_t local_size_1 = local_end_1 - local_start_1;
100  const std::size_t local_size_2 = local_end_2 - local_start_2;
101 
102  // TODO: process cases where all elements of 1st sequence > 2nd, 2nd > 1st
103  // to improve performance
104 
105  // Process 1st sequence
106  if (local_start_1 < local_end_1) {
107  // Reduce the range for searching within the 2nd sequence and handle bound
108  // items find left border in 2nd sequence
109  const auto local_l_item_1 = in_acc1[local_start_1];
110  std::size_t l_search_bound_2 =
111  detail::lower_bound(in_acc1, start_2, end_2, local_l_item_1, comp);
112  const std::size_t l_shift_1 = local_start_1 - start_1;
113  const std::size_t l_shift_2 = l_search_bound_2 - start_2;
114 
115  set_value(out_acc1, start_out + l_shift_1 + l_shift_2, local_l_item_1,
116  is_first);
117 
118  std::size_t r_search_bound_2{};
119  // find right border in 2nd sequence
120  if (local_size_1 > 1) {
121  const auto local_r_item_1 = in_acc1[local_end_1 - 1];
122  r_search_bound_2 = detail::lower_bound(in_acc1, l_search_bound_2, end_2,
123  local_r_item_1, comp);
124  const auto r_shift_1 = local_end_1 - 1 - start_1;
125  const auto r_shift_2 = r_search_bound_2 - start_2;
126 
127  set_value(out_acc1, start_out + r_shift_1 + r_shift_2, local_r_item_1,
128  is_first);
129  }
130 
131  // Handle intermediate items
132  for (std::size_t idx = local_start_1 + 1; idx < local_end_1 - 1; ++idx) {
133  const auto intermediate_item_1 = in_acc1[idx];
134  // we shouldn't seek in whole 2nd sequence. Just for the part where the
135  // 1st sequence should be
136  l_search_bound_2 =
137  detail::lower_bound(in_acc1, l_search_bound_2, r_search_bound_2,
138  intermediate_item_1, comp);
139  const std::size_t shift_1 = idx - start_1;
140  const std::size_t shift_2 = l_search_bound_2 - start_2;
141 
142  set_value(out_acc1, start_out + shift_1 + shift_2, intermediate_item_1,
143  is_first);
144  }
145  }
146  // Process 2nd sequence
147  if (local_start_2 < local_end_2) {
148  // Reduce the range for searching within the 1st sequence and handle bound
149  // items find left border in 1st sequence
150  const auto local_l_item_2 = in_acc1[local_start_2];
151  std::size_t l_search_bound_1 =
152  detail::upper_bound(in_acc1, start_1, end_1, local_l_item_2, comp);
153  const std::size_t l_shift_1 = l_search_bound_1 - start_1;
154  const std::size_t l_shift_2 = local_start_2 - start_2;
155 
156  set_value(out_acc1, start_out + l_shift_1 + l_shift_2, local_l_item_2,
157  is_first);
158 
159  std::size_t r_search_bound_1{};
160  // find right border in 1st sequence
161  if (local_size_2 > 1) {
162  const auto local_r_item_2 = in_acc1[local_end_2 - 1];
163  r_search_bound_1 = detail::upper_bound(in_acc1, l_search_bound_1, end_1,
164  local_r_item_2, comp);
165  const std::size_t r_shift_1 = r_search_bound_1 - start_1;
166  const std::size_t r_shift_2 = local_end_2 - 1 - start_2;
167 
168  set_value(out_acc1, start_out + r_shift_1 + r_shift_2, local_r_item_2,
169  is_first);
170  }
171 
172  // Handle intermediate items
173  for (auto idx = local_start_2 + 1; idx < local_end_2 - 1; ++idx) {
174  const auto intermediate_item_2 = in_acc1[idx];
175  // we shouldn't seek in whole 1st sequence. Just for the part where the
176  // 2nd sequence should be
177  l_search_bound_1 =
178  detail::upper_bound(in_acc1, l_search_bound_1, r_search_bound_1,
179  intermediate_item_2, comp);
180  const std::size_t shift_1 = l_search_bound_1 - start_1;
181  const std::size_t shift_2 = idx - start_2;
182 
183  set_value(out_acc1, start_out + shift_1 + shift_2, intermediate_item_2,
184  is_first);
185  }
186  }
187 }
188 
189 template <typename Iter, typename Compare>
190 void bubble_sort(Iter first, const std::size_t begin, const std::size_t end,
191  Compare comp) {
192  if (begin < end) {
193  for (std::size_t i = begin; i < end; ++i) {
194  // Handle intermediate items
195  for (std::size_t idx = i + 1; idx < end; ++idx) {
196  if (comp(first[idx], first[i])) {
197  detail::swap_tuples(first[i], first[idx]);
198  }
199  }
200  }
201  }
202 }
203 
204 template <typename Group, typename Iter, typename Compare>
205 void merge_sort(Group group, Iter first, const std::size_t n, Compare comp,
206  std::byte *scratch) {
207  using T = typename GetValueType<Iter>::type;
208  const std::size_t idx = group.get_local_linear_id();
209  const std::size_t local = group.get_local_range().size();
210  const std::size_t chunk = (n - 1) / local + 1;
211 
212  // we need to sort within work item first
213  bubble_sort(first, idx * chunk, sycl::min((idx + 1) * chunk, n), comp);
214  sycl::group_barrier(group);
215 
216  T *temp = reinterpret_cast<T *>(scratch);
217  bool data_in_temp = false;
218  bool is_first = true;
219  std::size_t sorted_size = 1;
220  while (sorted_size * chunk < n) {
221  const std::size_t start_1 =
222  sycl::min(2 * sorted_size * chunk * (idx / sorted_size), n);
223  const std::size_t end_1 = sycl::min(start_1 + sorted_size * chunk, n);
224  const std::size_t end_2 = sycl::min(end_1 + sorted_size * chunk, n);
225  const std::size_t offset = chunk * (idx % sorted_size);
226 
227  if (!data_in_temp) {
228  merge(offset, first, temp, start_1, end_1, end_2, start_1, comp, chunk,
229  is_first);
230  } else {
231  merge(offset, temp, first, start_1, end_1, end_2, start_1, comp, chunk,
232  /*is_first*/ false);
233  }
234  sycl::group_barrier(group);
235 
236  data_in_temp = !data_in_temp;
237  sorted_size *= 2;
238  if (is_first)
239  is_first = false;
240  }
241 
242  // copy back if data is in a temporary storage
243  if (data_in_temp) {
244  for (std::size_t i = 0; i < chunk; ++i) {
245  if (idx * chunk + i < n) {
246  first[idx * chunk + i] = temp[idx * chunk + i];
247  }
248  }
249  sycl::group_barrier(group);
250  }
251 }
252 
253 } // namespace detail
254 } // __SYCL_INLINE_VER_NAMESPACE(_V1)
255 } // namespace sycl
256 #endif
#define __SYCL_INLINE_VER_NAMESPACE(X)
@ local
flush out to the threadgroup's scope
__ESIMD_API simd< T, N > merge(simd< T, N > a, simd< T, N > b, simd_mask< N > m)
"Merges" elements of the input simd object according to the merge mask.
Definition: alt_ui.hpp:28
std::enable_if< is_group_v< Group > >::type group_barrier(Group, memory_scope FenceScope=Group::fence_scope)
Error handling, matching OpenCL plugin semantics.
Definition: access.hpp:14