DPC++ Runtime
Runtime libraries for oneAPI DPC++
group_sort_impl.hpp
//==------------ group_sort_impl.hpp ---------------------------------------==//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
// This file contains helper functions for the group sorting algorithm
// implementations.
//

#pragma once

#ifdef __SYCL_DEVICE_ONLY__

#include <sycl/group_algorithm.hpp>
#include <sycl/group_barrier.hpp>

#include <climits>
#include <cstddef>
#include <cstdint>
#include <functional>
#include <iterator>
#include <limits>
#include <tuple>
#include <type_traits>
#include <utility>

namespace sycl {
inline namespace _V1 {
namespace detail {

// ---- merge sort implementation

// The following two functions would be unnecessary if
// std::[lower|upper]_bound worked reliably in device code.
template <typename Acc, typename Value, typename Compare>
size_t lower_bound(Acc acc, size_t first, size_t last, const Value &value,
                   Compare comp) {
  size_t n = last - first;
  size_t cur = n;
  size_t it;
  while (n > 0) {
    it = first;
    cur = n / 2;
    it += cur;
    if (comp(acc[it], value)) {
      n -= cur + 1, first = ++it;
    } else
      n = cur;
  }
  return first;
}

template <typename Acc, typename Value, typename Compare>
size_t upper_bound(Acc acc, const size_t first, const size_t last,
                   const Value &value, Compare comp) {
  return detail::lower_bound(acc, first, last, value,
                             [comp](auto x, auto y) { return !comp(y, x); });
}
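
// Illustrative only (not part of the algorithm): for acc = {1, 2, 2, 4},
// first = 0, last = 4 and comp = std::less<int>{},
//   lower_bound(acc, 0, 4, 2, comp) == 1  // first position not less than 2
//   upper_bound(acc, 0, 4, 2, comp) == 3  // first position greater than 2
// upper_bound is expressed via lower_bound with the negated, argument-swapped
// comparator, mirroring the usual standard-library equivalence.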

// swap for all data types including tuple-like types
template <typename T> void swap_tuples(T &a, T &b) { std::swap(a, b); }

template <template <typename...> class TupleLike, typename T1, typename T2>
void swap_tuples(TupleLike<T1, T2> &&a, TupleLike<T1, T2> &&b) {
  std::swap(std::get<0>(a), std::get<0>(b));
  std::swap(std::get<1>(a), std::get<1>(b));
}

template <typename Iter> struct GetValueType {
  using type = typename std::iterator_traits<Iter>::value_type;
};

template <typename ElementType, access::address_space Space,
          access::decorated IsDecorated>
struct GetValueType<sycl::multi_ptr<ElementType, Space, IsDecorated>> {
  using type = ElementType;
};

// Since we cannot assign to raw memory before an object has been constructed
// there, the first write uses placement new.
template <typename Acc, typename T>
void set_value(Acc ptr, const size_t idx, const T &val, bool is_first) {
  if (is_first) {
    ::new (ptr + idx) T(val);
  } else {
    ptr[idx] = val;
  }
}
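
// Illustration (not part of the algorithm): for a non-trivial T, assigning
// through raw storage before a T object has been constructed there is
// undefined behaviour, which is why the first write above uses placement new:
//   std::byte storage[sizeof(T)];          // raw bytes, no T object yet
//   T *p = reinterpret_cast<T *>(storage);
//   ::new (p) T(val);                      // is_first == true: construct
//   *p = other;                            // later writes: plain assignment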

template <typename InAcc, typename OutAcc, typename Compare>
void merge(const size_t offset, InAcc &in_acc1, OutAcc &out_acc1,
           const size_t start_1, const size_t end_1, const size_t end_2,
           const size_t start_out, Compare comp, const size_t chunk,
           bool is_first) {
  const size_t start_2 = end_1;
  // Borders of the sequences to merge within this call
  const size_t local_start_1 =
      sycl::min(static_cast<size_t>(offset + start_1), end_1);
  const size_t local_end_1 =
      sycl::min(static_cast<size_t>(local_start_1 + chunk), end_1);
  const size_t local_start_2 =
      sycl::min(static_cast<size_t>(offset + start_2), end_2);
  const size_t local_end_2 =
      sycl::min(static_cast<size_t>(local_start_2 + chunk), end_2);

  const size_t local_size_1 = local_end_1 - local_start_1;
  const size_t local_size_2 = local_end_2 - local_start_2;

  // TODO: process cases where all elements of 1st sequence > 2nd, 2nd > 1st
  // to improve performance

  // Process 1st sequence
  if (local_start_1 < local_end_1) {
    // Reduce the range for searching within the 2nd sequence and handle the
    // bound items: find the left border in the 2nd sequence
    const auto local_l_item_1 = in_acc1[local_start_1];
    size_t l_search_bound_2 =
        detail::lower_bound(in_acc1, start_2, end_2, local_l_item_1, comp);
    const size_t l_shift_1 = local_start_1 - start_1;
    const size_t l_shift_2 = l_search_bound_2 - start_2;

    set_value(out_acc1, start_out + l_shift_1 + l_shift_2, local_l_item_1,
              is_first);

    size_t r_search_bound_2{};
    // find the right border in the 2nd sequence
    if (local_size_1 > 1) {
      const auto local_r_item_1 = in_acc1[local_end_1 - 1];
      r_search_bound_2 = detail::lower_bound(in_acc1, l_search_bound_2, end_2,
                                             local_r_item_1, comp);
      const auto r_shift_1 = local_end_1 - 1 - start_1;
      const auto r_shift_2 = r_search_bound_2 - start_2;

      set_value(out_acc1, start_out + r_shift_1 + r_shift_2, local_r_item_1,
                is_first);
    }

    // Handle intermediate items
    for (size_t idx = local_start_1 + 1; idx < local_end_1 - 1; ++idx) {
      const auto intermediate_item_1 = in_acc1[idx];
      // no need to search the whole 2nd sequence, only the part where the
      // 1st sequence items can land
      l_search_bound_2 =
          detail::lower_bound(in_acc1, l_search_bound_2, r_search_bound_2,
                              intermediate_item_1, comp);
      const size_t shift_1 = idx - start_1;
      const size_t shift_2 = l_search_bound_2 - start_2;

      set_value(out_acc1, start_out + shift_1 + shift_2, intermediate_item_1,
                is_first);
    }
  }
  // Process 2nd sequence
  if (local_start_2 < local_end_2) {
    // Reduce the range for searching within the 1st sequence and handle the
    // bound items: find the left border in the 1st sequence
    const auto local_l_item_2 = in_acc1[local_start_2];
    size_t l_search_bound_1 =
        detail::upper_bound(in_acc1, start_1, end_1, local_l_item_2, comp);
    const size_t l_shift_1 = l_search_bound_1 - start_1;
    const size_t l_shift_2 = local_start_2 - start_2;

    set_value(out_acc1, start_out + l_shift_1 + l_shift_2, local_l_item_2,
              is_first);

    size_t r_search_bound_1{};
    // find the right border in the 1st sequence
    if (local_size_2 > 1) {
      const auto local_r_item_2 = in_acc1[local_end_2 - 1];
      r_search_bound_1 = detail::upper_bound(in_acc1, l_search_bound_1, end_1,
                                             local_r_item_2, comp);
      const size_t r_shift_1 = r_search_bound_1 - start_1;
      const size_t r_shift_2 = local_end_2 - 1 - start_2;

      set_value(out_acc1, start_out + r_shift_1 + r_shift_2, local_r_item_2,
                is_first);
    }

    // Handle intermediate items
    for (auto idx = local_start_2 + 1; idx < local_end_2 - 1; ++idx) {
      const auto intermediate_item_2 = in_acc1[idx];
      // no need to search the whole 1st sequence, only the part where the
      // 2nd sequence items can land
      l_search_bound_1 =
          detail::upper_bound(in_acc1, l_search_bound_1, r_search_bound_1,
                              intermediate_item_2, comp);
      const size_t shift_1 = l_search_bound_1 - start_1;
      const size_t shift_2 = idx - start_2;

      set_value(out_acc1, start_out + shift_1 + shift_2, intermediate_item_2,
                is_first);
    }
  }
}
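
// Worked example (illustrative): merging the sorted runs A = {1, 3} and
// B = {2, 4}. For element 3 of A, shift_1 = 1 (its offset inside A) and
// shift_2 = lower_bound(B, 3) - start_2 = 1 (elements of B that precede it),
// so it is written to output index start_out + 2. Each element's final
// position is its rank within its own run plus its rank within the other run,
// which is exactly what the shift_1 + shift_2 arithmetic above computes.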

template <typename Iter, typename Compare>
void bubble_sort(Iter first, const size_t begin, const size_t end,
                 Compare comp) {
  if (begin < end) {
    for (size_t i = begin; i < end; ++i) {
      // Handle intermediate items
      for (size_t idx = i + 1; idx < end; ++idx) {
        if (comp(first[idx], first[i])) {
          detail::swap_tuples(first[i], first[idx]);
        }
      }
    }
  }
}

template <typename Group, typename Iter, typename Compare>
void merge_sort(Group group, Iter first, const size_t n, Compare comp,
                std::byte *scratch) {
  using T = typename GetValueType<Iter>::type;
  const size_t idx = group.get_local_linear_id();
  const size_t local = group.get_local_range().size();
  const size_t chunk = (n - 1) / local + 1;

  // each work item sorts its own chunk first
  bubble_sort(first, idx * chunk, sycl::min((idx + 1) * chunk, n), comp);
  sycl::group_barrier(group);

  T *temp = reinterpret_cast<T *>(scratch);
  bool data_in_temp = false;
  bool is_first = true;
  size_t sorted_size = 1;
  while (sorted_size * chunk < n) {
    const size_t start_1 =
        sycl::min(2 * sorted_size * chunk * (idx / sorted_size), n);
    const size_t end_1 = sycl::min(start_1 + sorted_size * chunk, n);
    const size_t end_2 = sycl::min(end_1 + sorted_size * chunk, n);
    const size_t offset = chunk * (idx % sorted_size);

    if (!data_in_temp) {
      merge(offset, first, temp, start_1, end_1, end_2, start_1, comp, chunk,
            is_first);
    } else {
      merge(offset, temp, first, start_1, end_1, end_2, start_1, comp, chunk,
            /*is_first*/ false);
    }
    sycl::group_barrier(group);

    data_in_temp = !data_in_temp;
    sorted_size *= 2;
    if (is_first)
      is_first = false;
  }

  // copy back if the data ended up in temporary storage
  if (data_in_temp) {
    for (size_t i = 0; i < chunk; ++i) {
      if (idx * chunk + i < n) {
        first[idx * chunk + i] = temp[idx * chunk + i];
      }
    }
    sycl::group_barrier(group);
  }
}
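
// Hypothetical usage sketch (the queue `q`, the USM pointer `data` and the
// range sizes are assumptions, not part of this header). merge_sort is an
// internal helper that normally sits behind the group-sort extension
// interfaces: every work item of `group` must call it with the same `first`,
// `n` and `comp`, and `scratch` must provide at least n * sizeof(T) bytes
// visible to the whole group (e.g. local memory).
//
//   q.submit([&](sycl::handler &cgh) {
//     sycl::local_accessor<std::byte, 1> scratch{
//         sycl::range<1>{n * sizeof(int)}, cgh};
//     cgh.parallel_for(sycl::nd_range<1>{global_size, wg_size},
//                      [=](sycl::nd_item<1> item) {
//                        sycl::detail::merge_sort(item.get_group(), data, n,
//                                                 std::less<int>{},
//                                                 &scratch[0]);
//                      });
//   });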

// traits for ascending functors
template <typename CompT> struct IsCompAscending {
  static constexpr bool value = false;
};
template <typename Type> struct IsCompAscending<std::less<Type>> {
  static constexpr bool value = true;
};

// get the number of states that radix_bits bits can represent
constexpr uint32_t getStatesInBits(uint32_t radix_bits) {
  return (1 << radix_bits);
}
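
// Example (illustrative): radix_bits = 4 gives 1 << 4 = 16 radix states, i.e.
// one counter per possible 4-bit digit, so sorting 32-bit keys needs
// 32 / 4 = 8 radix iterations.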

//------------------------------------------------------------------------
// Ordered traits for a given size and integral/float flag
//------------------------------------------------------------------------

template <size_t type_size, bool is_integral_type> struct GetOrdered {};

template <> struct GetOrdered<1, true> {
  using Type = uint8_t;
  constexpr static int8_t mask = 0x80;
};

template <> struct GetOrdered<2, true> {
  using Type = uint16_t;
  constexpr static int16_t mask = 0x8000;
};

template <> struct GetOrdered<4, true> {
  using Type = uint32_t;
  constexpr static int32_t mask = 0x80000000;
};

template <> struct GetOrdered<8, true> {
  using Type = uint64_t;
  constexpr static int64_t mask = 0x8000000000000000;
};

template <> struct GetOrdered<2, false> {
  using Type = uint16_t;
  constexpr static uint32_t nmask = 0xFFFF; // for negative numbers
  constexpr static uint32_t pmask = 0x8000; // for positive numbers
};

template <> struct GetOrdered<4, false> {
  using Type = uint32_t;
  constexpr static uint32_t nmask = 0xFFFFFFFF; // for negative numbers
  constexpr static uint32_t pmask = 0x80000000; // for positive numbers
};

template <> struct GetOrdered<8, false> {
  using Type = uint64_t;
  constexpr static uint64_t nmask = 0xFFFFFFFFFFFFFFFF; // for negative numbers
  constexpr static uint64_t pmask = 0x8000000000000000; // for positive numbers
};

//------------------------------------------------------------------------
// Ordered type for a given type
//------------------------------------------------------------------------

// for an unknown/unsupported type we do not provide any trait
template <typename ValT, typename Enabled = void> struct Ordered {};

// for unsigned integrals we use the same type
template <typename ValT>
struct Ordered<ValT, std::enable_if_t<std::is_integral<ValT>::value &&
                                      std::is_unsigned<ValT>::value>> {
  using Type = ValT;
};

// for signed integrals and floating-point types we map: size -> the unsigned
// integral of the same width
template <typename ValT>
struct Ordered<
    ValT, std::enable_if_t<
              (std::is_integral<ValT>::value && std::is_signed<ValT>::value) ||
              std::is_floating_point<ValT>::value ||
              std::is_same<ValT, sycl::half>::value ||
              std::is_same<ValT, sycl::ext::oneapi::bfloat16>::value>> {
  using Type =
      typename GetOrdered<sizeof(ValT), std::is_integral<ValT>::value>::Type;
};

// shorthand
template <typename ValT> using OrderedT = typename Ordered<ValT>::Type;
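
// Examples (illustrative) of the resulting mappings:
//   OrderedT<uint32_t>   == uint32_t  (unsigned integrals are kept as is)
//   OrderedT<int16_t>    == uint16_t
//   OrderedT<float>      == uint32_t
//   OrderedT<sycl::half> == uint16_t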

//------------------------------------------------------------------------
// functions for conversion to the Ordered type
//------------------------------------------------------------------------

// for already Ordered types (any unsigned integrals) we return the value as is
template <typename ValT>
std::enable_if_t<std::is_same_v<ValT, OrderedT<ValT>>, OrderedT<ValT>>
convertToOrdered(ValT value) {
  return value;
}

// converts a signed integral type to the Ordered (in terms of bitness) type
template <typename ValT>
std::enable_if_t<!std::is_same<ValT, OrderedT<ValT>>::value &&
                     std::is_integral<ValT>::value,
                 OrderedT<ValT>>
convertToOrdered(ValT value) {
  ValT result = value ^ GetOrdered<sizeof(ValT), true>::mask;
  return *reinterpret_cast<OrderedT<ValT> *>(&result);
}

// converts a floating-point type to the Ordered (in terms of bitness) type
template <typename ValT>
std::enable_if_t<!std::is_same<ValT, OrderedT<ValT>>::value &&
                     (std::is_floating_point<ValT>::value ||
                      std::is_same<ValT, sycl::half>::value ||
                      std::is_same<ValT, sycl::ext::oneapi::bfloat16>::value),
                 OrderedT<ValT>>
convertToOrdered(ValT value) {
  OrderedT<ValT> uvalue = *reinterpret_cast<OrderedT<ValT> *>(&value);
  // check if the value is negative
  OrderedT<ValT> is_negative = uvalue >> (sizeof(ValT) * CHAR_BIT - 1);
  // for positive: 00..00 -> 00..00 -> 10..00
  // for negative: 00..01 -> 11..11 -> 11..11
  OrderedT<ValT> ordered_mask =
      (is_negative * GetOrdered<sizeof(ValT), false>::nmask) |
      GetOrdered<sizeof(ValT), false>::pmask;
  return uvalue ^ ordered_mask;
}
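
// Worked example (illustrative, IEEE-754 binary32):
//   +1.0f : 0x3F800000 -> is_negative = 0 -> mask = 0x80000000 -> 0xBF800000
//   +0.0f : 0x00000000 -> is_negative = 0 -> mask = 0x80000000 -> 0x80000000
//   -0.0f : 0x80000000 -> is_negative = 1 -> mask = 0xFFFFFFFF -> 0x7FFFFFFF
//   -1.0f : 0xBF800000 -> is_negative = 1 -> mask = 0xFFFFFFFF -> 0x407FFFFF
// Comparing the converted values as unsigned integers reproduces the original
// floating-point order: -1.0f < -0.0f < +0.0f < +1.0f.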

//------------------------------------------------------------------------
// bit pattern functions
//------------------------------------------------------------------------

// required for descending comparator support
template <bool flag> struct InvertIf {
  template <typename ValT> ValT operator()(ValT value) { return value; }
};

// invert the value if a descending comparator is passed
template <> struct InvertIf<true> {
  template <typename ValT> ValT operator()(ValT value) { return ~value; }

  // inversion for bool has to be logical rather than bitwise
  bool operator()(bool value) { return !value; }
};

// get the bit values in a certain bucket of a value
template <uint32_t radix_bits, bool is_comp_asc, typename ValT>
uint32_t getBucketValue(ValT value, uint32_t radix_iter) {
  // invert the value if we need to sort in descending order
  value = InvertIf<!is_comp_asc>{}(value);

  // get the bucket offset counted from the least significant bits
  uint32_t bucket_offset = radix_iter * radix_bits;

  // get the offset mask for one bucket, e.g.
  // radix_bits=2: 0000 0001 -> 0000 0100 -> 0000 0011
  OrderedT<ValT> bucket_mask = (1u << radix_bits) - 1u;

  // get the bits under the bucket mask
  return (value >> bucket_offset) & bucket_mask;
}
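
// Example (illustrative): with radix_bits = 4 and radix_iter = 1 the bucket
// offset is 4, so for an ordered value 0b...0101'0011 the function returns
// (value >> 4) & 0xF == 0b0101 == 5, i.e. the second least significant 4-bit
// digit of the key.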
template <typename ValT> ValT getDefaultValue(bool is_comp_asc) {
  if (is_comp_asc)
    return std::numeric_limits<ValT>::max();
  else
    return std::numeric_limits<ValT>::lowest();
}

template <bool is_key_value_sort> struct ValuesAssigner {
  template <typename IterInT, typename IterOutT>
  void operator()(IterOutT output, size_t idx_out, IterInT input,
                  size_t idx_in) {
    output[idx_out] = input[idx_in];
  }

  template <typename IterOutT, typename ValT>
  void operator()(IterOutT output, size_t idx_out, ValT value) {
    output[idx_out] = value;
  }
};

template <> struct ValuesAssigner<false> {
  template <typename IterInT, typename IterOutT>
  void operator()(IterOutT, size_t, IterInT, size_t) {}

  template <typename IterOutT, typename ValT>
  void operator()(IterOutT, size_t, ValT) {}
};

// Wrapper class for scratchpad memory used by the group-sorting
// implementations. It simplifies accessing the supplied memory as arbitrary
// types without breaking strict aliasing and without running into alignment
// issues.
struct ScratchMemory {
public:
  // "Reference" object for accessing part of the scratch memory as a type T.
  template <typename T> struct ReferenceObj {
  public:
    ReferenceObj() : MPtr{nullptr} {};

    operator T() const {
      T value{0};
      detail::memcpy(&value, MPtr, sizeof(T));
      return value;
    }

    T operator++(int) noexcept {
      T value{0};
      detail::memcpy(&value, MPtr, sizeof(T));
      T value_before = value++;
      detail::memcpy(MPtr, &value, sizeof(T));
      return value_before;
    }

    T operator++() noexcept {
      T value{0};
      detail::memcpy(&value, MPtr, sizeof(T));
      ++value;
      detail::memcpy(MPtr, &value, sizeof(T));
      return value;
    }

    ReferenceObj &operator=(const T &value) noexcept {
      detail::memcpy(MPtr, &value, sizeof(T));
      return *this;
    }

    ReferenceObj &operator=(const ReferenceObj &value) noexcept {
      MPtr = value.MPtr;
      return *this;
    }

    ReferenceObj &operator=(ReferenceObj &&value) noexcept {
      MPtr = std::move(value.MPtr);
      return *this;
    }

    void copy(const ReferenceObj &value) noexcept {
      detail::memcpy(MPtr, value.MPtr, sizeof(T));
    }

  private:
    ReferenceObj(std::byte *ptr) : MPtr{ptr} {}

    friend struct ScratchMemory;

    std::byte *MPtr;
  };

  ScratchMemory operator+(size_t byte_offset) const noexcept {
    return {MMemory + byte_offset};
  }

  ScratchMemory(std::byte *memory) : MMemory{memory} {}

  ScratchMemory(const ScratchMemory &) = default;
  ScratchMemory(ScratchMemory &&) = default;
  ScratchMemory &operator=(const ScratchMemory &) = default;
  ScratchMemory &operator=(ScratchMemory &&) = default;

  template <typename ValueT>
  ReferenceObj<ValueT> get(size_t index) const noexcept {
    return {MMemory + index * sizeof(ValueT)};
  }

  std::byte *MMemory;
};
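
// Illustrative usage (not part of the algorithms below): a byte buffer can be
// treated as an array of uint32_t counters without violating strict aliasing,
// since every access goes through detail::memcpy.
//   std::byte buffer[4 * sizeof(uint32_t)];
//   ScratchMemory mem{buffer};
//   mem.get<uint32_t>(2) = 7u;          // writes counter #2
//   uint32_t v = mem.get<uint32_t>(2);  // reads it back, v == 7
//   ++mem.get<uint32_t>(2);             // counter #2 becomes 8 in memory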

// The iteration of radix sort for unknown number of elements per work item
template <uint32_t radix_bits, bool is_key_value_sort, bool is_comp_asc,
          typename KeysT, typename ValueT, typename GroupT>
void performRadixIterDynamicSize(
    GroupT group, const uint32_t items_per_work_item, const uint32_t radix_iter,
    const size_t n, const ScratchMemory &keys_input,
    const ScratchMemory &vals_input, const ScratchMemory &keys_output,
    const ScratchMemory &vals_output, const ScratchMemory &memory) {
  const uint32_t radix_states = getStatesInBits(radix_bits);
  const size_t wgsize = group.get_local_linear_range();
  const size_t idx = group.get_local_linear_id();

  // 1.1. Zero-initialize the counters in local memory
  for (uint32_t state = 0; state < radix_states; ++state)
    memory.get<uint32_t>(state * wgsize + idx) = uint32_t{0};

  sycl::group_barrier(group);

  // 1.2. Count values and write the result to the counters in local memory
  for (uint32_t i = 0; i < items_per_work_item; ++i) {
    const uint32_t val_idx = items_per_work_item * idx + i;
    // get the value and convert it to the Ordered (in terms of bitness) type
    const auto val =
        convertToOrdered((val_idx < n) ? keys_input.get<KeysT>(val_idx)
                                       : getDefaultValue<KeysT>(is_comp_asc));
    // get the bit values in a certain bucket of the value
    const uint32_t bucket_val =
        getBucketValue<radix_bits, is_comp_asc>(val, radix_iter);

    // increment the counter for this bit bucket
    if (val_idx < n)
      ++memory.get<uint32_t>(bucket_val * wgsize + idx);
  }

  sycl::group_barrier(group);

  // 2.1. Scan. Upsweep: reduce over radix states
  uint32_t reduced = 0;
  for (uint32_t i = 0; i < radix_states; ++i)
    reduced += memory.get<uint32_t>(idx * radix_states + i);

  // 2.2. Exclusive scan: over work items
  uint32_t scanned =
      sycl::exclusive_scan_over_group(group, reduced, std::plus<uint32_t>());

  // 2.3. Exclusive downsweep: exclusive scan over radix states
  for (uint32_t i = 0; i < radix_states; ++i) {
    auto value_ref = memory.get<uint32_t>(idx * radix_states + i);
    uint32_t value_before = value_ref;
    value_ref = scanned;
    scanned += value_before;
  }

  sycl::group_barrier(group);

  uint32_t private_scan_memory[radix_states] = {0};

  // 3. Reorder
  for (uint32_t i = 0; i < items_per_work_item; ++i) {
    const uint32_t val_idx = items_per_work_item * idx + i;
    // get the value and convert it to the Ordered (in terms of bitness) type
    auto val =
        convertToOrdered((val_idx < n) ? keys_input.get<KeysT>(val_idx)
                                       : getDefaultValue<KeysT>(is_comp_asc));
    // get the bit values in a certain bucket of the value
    uint32_t bucket_val =
        getBucketValue<radix_bits, is_comp_asc>(val, radix_iter);

    uint32_t new_offset_idx = private_scan_memory[bucket_val]++ +
                              memory.get<uint32_t>(bucket_val * wgsize + idx);
    if (val_idx < n) {
      keys_output.get<KeysT>(new_offset_idx)
          .copy(keys_input.get<KeysT>(val_idx));
      if constexpr (is_key_value_sort)
        vals_output.get<ValueT>(new_offset_idx)
            .copy(vals_input.get<ValueT>(val_idx));
    }
  }
}
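
// Note (illustrative): steps 2.1-2.3 together form an exclusive prefix sum
// over the flat array of wgsize * radix_states counters written in step 1
// (flat index = bucket * wgsize + work_item). Each work item scans its own
// contiguous chunk of radix_states entries and exclusive_scan_over_group
// stitches the per-chunk sums together. Because the flat order is
// bucket-major, memory[bucket * wgsize + idx] afterwards equals the number of
// keys in smaller buckets plus the keys of the same bucket owned by
// lower-numbered work items. E.g. with wgsize = 2, radix_bits = 1 and
// per-(bucket, work item) counts [1, 0, 1, 2], the scan yields [0, 1, 1, 2],
// which are exactly the stable scatter offsets used in step 3.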

// The iteration of radix sort for known number of elements per work item
template <size_t items_per_work_item, uint32_t radix_bits, bool is_comp_asc,
          bool is_key_value_sort, bool is_blocked, typename KeysT,
          typename ValsT, typename GroupT>
void performRadixIterStaticSize(GroupT group, const uint32_t radix_iter,
                                const uint32_t last_iter, KeysT *keys,
                                ValsT vals, const ScratchMemory &memory) {
  const uint32_t radix_states = getStatesInBits(radix_bits);
  const size_t wgsize = group.get_local_linear_range();
  const size_t idx = group.get_local_linear_id();

  // 1.1. Count per work item: create private arrays for storing count values
  uint32_t count_arr[items_per_work_item] = {0};
  uint32_t ranks[items_per_work_item] = {0};

  // 1.1. Zero-initialize the counters in local memory
  for (uint32_t state = 0; state < radix_states; ++state)
    memory.get<uint32_t>(state * wgsize + idx) = uint32_t{0};

  sycl::group_barrier(group);

  ScratchMemory::ReferenceObj<uint32_t> value_refs[items_per_work_item];
  // 1.2. Count values and write the result to the private count array
  for (uint32_t i = 0; i < items_per_work_item; ++i) {
    // get the value and convert it to the Ordered (in terms of bitness) type
    OrderedT<KeysT> val = convertToOrdered(keys[i]);
    // get the bit values in a certain bucket of the value
    uint32_t bucket_val =
        getBucketValue<radix_bits, is_comp_asc>(val, radix_iter);
    value_refs[i] = memory.get<uint32_t>(bucket_val * wgsize + idx);
    count_arr[i] = value_refs[i]++;
  }
  sycl::group_barrier(group);

  // 2.1. Scan. Upsweep: reduce over radix states
  uint32_t reduced = 0;
  for (uint32_t i = 0; i < radix_states; ++i)
    reduced += memory.get<uint32_t>(idx * radix_states + i);

  // 2.2. Exclusive scan: over work items
  uint32_t scanned =
      sycl::exclusive_scan_over_group(group, reduced, std::plus<uint32_t>());

  // 2.3. Exclusive downsweep: exclusive scan over radix states
  for (uint32_t i = 0; i < radix_states; ++i) {
    auto value_ref = memory.get<uint32_t>(idx * radix_states + i);
    uint32_t value_before = value_ref;
    value_ref = scanned;
    scanned += value_before;
  }

  sycl::group_barrier(group);

  // 2.4. Fill the ranks with offsets
  for (uint32_t i = 0; i < items_per_work_item; ++i)
    ranks[i] = count_arr[i] + value_refs[i];

  sycl::group_barrier(group);

  // 3. Reorder
  const ScratchMemory &keys_temp = memory;
  const ScratchMemory vals_temp =
      memory + wgsize * items_per_work_item * sizeof(KeysT);
  for (uint32_t i = 0; i < items_per_work_item; ++i) {
    keys_temp.get<KeysT>(ranks[i]) = keys[i];
    if constexpr (is_key_value_sort)
      vals_temp.get<ValsT>(ranks[i]) = vals[i];
  }

  sycl::group_barrier(group);

  // 4. Copy back to input
  for (uint32_t i = 0; i < items_per_work_item; ++i) {
    size_t shift = idx * items_per_work_item + i;
    if constexpr (!is_blocked) {
      if (radix_iter == last_iter - 1)
        shift = i * wgsize + idx;
    }
    keys[i] = keys_temp.get<KeysT>(shift);
    if constexpr (is_key_value_sort)
      vals[i] = vals_temp.get<ValsT>(shift);
  }
}

template <bool is_key_value_sort, bool is_comp_asc,
          uint32_t items_per_work_item = 1, uint32_t radix_bits = 4,
          typename GroupT, typename KeysT, typename ValsT>
void privateDynamicSort(GroupT group, KeysT *keys, ValsT *values,
                        const size_t n, std::byte *scratch,
                        const uint32_t first_bit, const uint32_t last_bit) {
  const size_t wgsize = group.get_local_linear_range();
  constexpr uint32_t radix_states = getStatesInBits(radix_bits);
  const uint32_t first_iter = first_bit / radix_bits;
  const uint32_t last_iter = last_bit / radix_bits;

  ScratchMemory keys_input{reinterpret_cast<std::byte *>(keys)};
  ScratchMemory vals_input{reinterpret_cast<std::byte *>(values)};
  const uint32_t runtime_items_per_work_item = (n - 1) / wgsize + 1;

  // Create the scratch wrapper.
  ScratchMemory wrapped_scratch{scratch};
  // Set pointers into the (possibly unaligned) scratch memory.
  ScratchMemory keys_output =
      wrapped_scratch + radix_states * wgsize * sizeof(uint32_t);
  // Adding 4 bytes of extra space for keys due to specifics of some hardware
  // architectures.
  ScratchMemory vals_output =
      keys_output + is_key_value_sort * n * sizeof(KeysT) + alignof(uint32_t);

  for (uint32_t radix_iter = first_iter; radix_iter < last_iter; ++radix_iter) {
    performRadixIterDynamicSize<radix_bits, is_key_value_sort, is_comp_asc,
                                KeysT, ValsT>(
        group, runtime_items_per_work_item, radix_iter, n, keys_input,
        vals_input, keys_output, vals_output, wrapped_scratch);

    sycl::group_barrier(group);

    std::swap(keys_input, keys_output);
    std::swap(vals_input, vals_output);
  }
}
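
// Hypothetical usage sketch (`item`, `keys`, `scratch_ptr` and the sizes are
// assumptions, not part of this header): a keys-only ascending sort over all
// 32 bits of uint32_t keys with the default radix_bits = 4, i.e. 8 radix
// iterations. The scratch must hold the counters
// (radix_states * wgsize * sizeof(uint32_t) bytes) followed by the n-key
// staging buffer (plus the 4 extra bytes mentioned above) and, for key-value
// sorts, an n-value staging buffer.
//
//   detail::privateDynamicSort</*is_key_value_sort=*/false,
//                              /*is_comp_asc=*/true>(
//       item.get_group(), keys, /*values=*/keys, n, scratch_ptr,
//       /*first_bit=*/0, /*last_bit=*/32);
//
// With an even number of iterations, as here, the buffer swaps inside the
// loop leave the sorted keys back in `keys`.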

template <bool is_key_value_sort, bool is_blocked, bool is_comp_asc,
          size_t items_per_work_item = 1, uint32_t radix_bits = 4,
          typename GroupT, typename T, typename U>
void privateStaticSort(GroupT group, T *keys, U *values, std::byte *scratch,
                       const uint32_t first_bit, const uint32_t last_bit) {

  const uint32_t first_iter = first_bit / radix_bits;
  const uint32_t last_iter = last_bit / radix_bits;

  for (uint32_t radix_iter = first_iter; radix_iter < last_iter; ++radix_iter) {
    performRadixIterStaticSize<items_per_work_item, radix_bits, is_comp_asc,
                               is_key_value_sort, is_blocked>(
        group, radix_iter, last_iter, keys, values, scratch);
    sycl::group_barrier(group);
  }
}

} // namespace detail
} // namespace _V1
} // namespace sycl
#endif