13 #ifdef __SYCL_DEVICE_ONLY__
19 inline namespace _V1 {
/// Binary search over acc[first, last): returns the first index whose element
/// does NOT satisfy comp(element, value), or last if every element does.
/// Requires the range to be partitioned with respect to comp (sorted input).
template <typename Acc, typename Value, typename Compare>
size_t lower_bound(Acc acc, size_t first, size_t last, const Value &value,
                   Compare comp) {
  size_t n = last - first;
  while (n > 0) {
    const size_t half = n / 2;
    size_t probe = first + half;
    if (comp(acc[probe], value)) {
      // Probe is strictly "less" than value: answer lies to its right.
      first = probe + 1;
      n -= half + 1;
    } else {
      // Probe is a candidate: narrow the search to the left half.
      n = half;
    }
  }
  return first;
}
44 template <
typename Acc,
typename Value,
typename Compare>
45 size_t upper_bound(Acc acc,
const size_t first,
const size_t last,
46 const Value &value, Compare comp) {
47 return detail::lower_bound(acc, first, last, value,
48 [comp](
auto x,
auto y) {
return !comp(y, x); });
/// Exchange two plain (non-tuple) elements.
/// Thin wrapper over std::swap so the sorting routines can swap both scalar
/// keys and key/value tuples through a single overload set.
template <typename ScalarT> void swap_tuples(ScalarT &a, ScalarT &b) {
  std::swap(a, b);
}
/// Exchange the elements of two two-element tuple-like objects member-wise.
/// Takes rvalue references — presumably so proxy tuples produced by
/// zip-style iterators can be swapped as well (TODO confirm against callers).
template <template <typename...> class TupleLike, typename T1, typename T2>
void swap_tuples(TupleLike<T1, T2> &&lhs, TupleLike<T1, T2> &&rhs) {
  std::swap(std::get<0>(lhs), std::get<0>(rhs));
  std::swap(std::get<1>(lhs), std::get<1>(rhs));
}
60 template <
typename Iter>
struct GetValueType {
66 struct GetValueType<
sycl::
multi_ptr<ElementType, Space, IsDecorated>> {
67 using type = ElementType;
/// Writes val into ptr[idx].
/// When is_first is true the destination slot is raw (uninitialized) storage,
/// so the element is placement-new constructed; otherwise the slot already
/// holds a live object and plain assignment is used.
/// Fix: the previous body placement-new'ed unconditionally, ignoring
/// is_first — for non-trivial T that constructs over a live object instead of
/// assigning to it.
template <typename Acc, typename T>
void set_value(Acc ptr, const size_t idx, const T &val, const bool is_first) {
  if (is_first) {
    ::new (ptr + idx) T(val);
  } else {
    ptr[idx] = val;
  }
}
81 template <
typename InAcc,
typename OutAcc,
typename Compare>
82 void merge(
const size_t offset, InAcc &in_acc1, OutAcc &out_acc1,
83 const size_t start_1,
const size_t end_1,
const size_t end_2,
84 const size_t start_out, Compare comp,
const size_t chunk,
86 const size_t start_2 = end_1;
88 const size_t local_start_1 =
89 sycl::min(
static_cast<size_t>(offset + start_1), end_1);
90 const size_t local_end_1 =
91 sycl::min(
static_cast<size_t>(local_start_1 + chunk), end_1);
92 const size_t local_start_2 =
93 sycl::min(
static_cast<size_t>(offset + start_2), end_2);
94 const size_t local_end_2 =
95 sycl::min(
static_cast<size_t>(local_start_2 + chunk), end_2);
97 const size_t local_size_1 = local_end_1 - local_start_1;
98 const size_t local_size_2 = local_end_2 - local_start_2;
104 if (local_start_1 < local_end_1) {
107 const auto local_l_item_1 = in_acc1[local_start_1];
108 size_t l_search_bound_2 =
109 detail::lower_bound(in_acc1, start_2, end_2, local_l_item_1, comp);
110 const size_t l_shift_1 = local_start_1 - start_1;
111 const size_t l_shift_2 = l_search_bound_2 - start_2;
113 set_value(out_acc1, start_out + l_shift_1 + l_shift_2, local_l_item_1,
116 size_t r_search_bound_2{};
118 if (local_size_1 > 1) {
119 const auto local_r_item_1 = in_acc1[local_end_1 - 1];
120 r_search_bound_2 = detail::lower_bound(in_acc1, l_search_bound_2, end_2,
121 local_r_item_1, comp);
122 const auto r_shift_1 = local_end_1 - 1 - start_1;
123 const auto r_shift_2 = r_search_bound_2 - start_2;
125 set_value(out_acc1, start_out + r_shift_1 + r_shift_2, local_r_item_1,
130 for (
size_t idx = local_start_1 + 1; idx < local_end_1 - 1; ++idx) {
131 const auto intermediate_item_1 = in_acc1[idx];
135 detail::lower_bound(in_acc1, l_search_bound_2, r_search_bound_2,
136 intermediate_item_1, comp);
137 const size_t shift_1 = idx - start_1;
138 const size_t shift_2 = l_search_bound_2 - start_2;
140 set_value(out_acc1, start_out + shift_1 + shift_2, intermediate_item_1,
145 if (local_start_2 < local_end_2) {
148 const auto local_l_item_2 = in_acc1[local_start_2];
149 size_t l_search_bound_1 =
150 detail::upper_bound(in_acc1, start_1, end_1, local_l_item_2, comp);
151 const size_t l_shift_1 = l_search_bound_1 - start_1;
152 const size_t l_shift_2 = local_start_2 - start_2;
154 set_value(out_acc1, start_out + l_shift_1 + l_shift_2, local_l_item_2,
157 size_t r_search_bound_1{};
159 if (local_size_2 > 1) {
160 const auto local_r_item_2 = in_acc1[local_end_2 - 1];
161 r_search_bound_1 = detail::upper_bound(in_acc1, l_search_bound_1, end_1,
162 local_r_item_2, comp);
163 const size_t r_shift_1 = r_search_bound_1 - start_1;
164 const size_t r_shift_2 = local_end_2 - 1 - start_2;
166 set_value(out_acc1, start_out + r_shift_1 + r_shift_2, local_r_item_2,
171 for (
auto idx = local_start_2 + 1; idx < local_end_2 - 1; ++idx) {
172 const auto intermediate_item_2 = in_acc1[idx];
176 detail::upper_bound(in_acc1, l_search_bound_1, r_search_bound_1,
177 intermediate_item_2, comp);
178 const size_t shift_1 = l_search_bound_1 - start_1;
179 const size_t shift_2 = idx - start_2;
181 set_value(out_acc1, start_out + shift_1 + shift_2, intermediate_item_2,
187 template <
typename Iter,
typename Compare>
188 void bubble_sort(Iter first,
const size_t begin,
const size_t end,
191 for (
size_t i = begin; i <
end; ++i) {
193 for (
size_t idx = i + 1; idx <
end; ++idx) {
194 if (comp(first[idx], first[i])) {
195 detail::swap_tuples(first[i], first[idx]);
202 template <
typename Group,
typename Iter,
typename Compare>
203 void merge_sort(Group group, Iter first,
const size_t n, Compare comp,
205 using T =
typename GetValueType<Iter>::type;
206 const size_t idx =
group.get_local_linear_id();
207 const size_t local =
group.get_local_range().size();
208 const size_t chunk = (n - 1) / local + 1;
211 bubble_sort(first, idx * chunk,
sycl::min((idx + 1) * chunk, n), comp);
214 T *temp =
reinterpret_cast<T *
>(scratch);
215 bool data_in_temp =
false;
216 bool is_first =
true;
217 size_t sorted_size = 1;
218 while (sorted_size * chunk < n) {
219 const size_t start_1 =
220 sycl::min(2 * sorted_size * chunk * (idx / sorted_size), n);
221 const size_t end_1 =
sycl::min(start_1 + sorted_size * chunk, n);
222 const size_t end_2 =
sycl::min(end_1 + sorted_size * chunk, n);
223 const size_t offset = chunk * (idx % sorted_size);
226 merge(offset, first, temp, start_1, end_1, end_2, start_1, comp, chunk,
229 merge(offset, temp, first, start_1, end_1, end_2, start_1, comp, chunk,
234 data_in_temp = !data_in_temp;
242 for (
size_t i = 0; i < chunk; ++i) {
243 if (idx * chunk + i < n) {
244 first[idx * chunk + i] = temp[idx * chunk + i];
/// Compile-time detection of an ascending comparator.
/// Only std::less is recognized as ascending; every other comparator
/// (including std::greater) is treated as descending.
template <typename CompT> struct IsCompAscending {
  static constexpr bool value = false;
};

template <typename Type> struct IsCompAscending<std::less<Type>> {
  static constexpr bool value = true;
};
/// Number of distinct radix states (buckets) addressable with radix_bits
/// bits, i.e. 2^radix_bits.
/// Fix: shift an unsigned operand — `1 << radix_bits` shifts a signed int,
/// which is undefined behavior once radix_bits reaches the sign bit.
constexpr uint32_t getStatesInBits(uint32_t radix_bits) {
  return uint32_t{1} << radix_bits;
}
268 template <
size_t type_size,
bool is_
integral_type>
struct GetOrdered {};
270 template <>
struct GetOrdered<1, true> {
271 using Type = uint8_t;
272 constexpr
static int8_t mask = 0x80;
275 template <>
struct GetOrdered<2, true> {
276 using Type = uint16_t;
277 constexpr
static int16_t mask = 0x8000;
280 template <>
struct GetOrdered<4, true> {
281 using Type = uint32_t;
282 constexpr
static int32_t mask = 0x80000000;
285 template <>
struct GetOrdered<8, true> {
286 using Type = uint64_t;
287 constexpr
static int64_t mask = 0x8000000000000000;
290 template <>
struct GetOrdered<2, false> {
291 using Type = uint16_t;
292 constexpr
static uint32_t nmask = 0xFFFF;
293 constexpr
static uint32_t pmask = 0x8000;
296 template <>
struct GetOrdered<4, false> {
297 using Type = uint32_t;
298 constexpr
static uint32_t nmask = 0xFFFFFFFF;
299 constexpr
static uint32_t pmask = 0x80000000;
302 template <>
struct GetOrdered<8, false> {
303 using Type = uint64_t;
304 constexpr
static uint64_t nmask = 0xFFFFFFFFFFFFFFFF;
305 constexpr
static uint64_t pmask = 0x8000000000000000;
313 template <
typename ValT,
typename Enabled =
void>
struct Ordered {};
316 template <
typename ValT>
317 struct Ordered<ValT,
std::enable_if_t<std::is_integral<ValT>::value &&
318 std::is_unsigned<ValT>::value>> {
324 template <
typename ValT>
326 ValT,
std::enable_if_t<
327 (std::is_integral<ValT>::value && std::is_signed<ValT>::value) ||
328 std::is_floating_point<ValT>::value ||
329 std::is_same<ValT, sycl::half>::value ||
330 std::is_same<ValT, sycl::ext::oneapi::bfloat16>::value>> {
332 typename GetOrdered<
sizeof(ValT), std::is_integral<ValT>::value>::Type;
336 template <
typename ValT>
using OrderedT =
typename Ordered<ValT>::Type;
343 template <
typename ValT>
344 std::enable_if_t<std::is_same_v<ValT, OrderedT<ValT>>, OrderedT<ValT>>
345 convertToOrdered(ValT value) {
350 template <
typename ValT>
351 std::enable_if_t<!std::is_same<ValT, OrderedT<ValT>>::value &&
352 std::is_integral<ValT>::value,
354 convertToOrdered(ValT value) {
355 ValT result = value ^ GetOrdered<
sizeof(ValT),
true>::mask;
356 return *
reinterpret_cast<OrderedT<ValT> *
>(&result);
360 template <
typename ValT>
361 std::enable_if_t<!std::is_same<ValT, OrderedT<ValT>>::value &&
362 (std::is_floating_point<ValT>::value ||
363 std::is_same<ValT, sycl::half>::value ||
364 std::is_same<ValT, sycl::ext::oneapi::bfloat16>::value),
366 convertToOrdered(ValT value) {
367 OrderedT<ValT> uvalue = *
reinterpret_cast<OrderedT<ValT> *
>(&value);
369 OrderedT<ValT> is_negative = uvalue >> (
sizeof(ValT) * CHAR_BIT - 1);
372 OrderedT<ValT> ordered_mask =
373 (is_negative * GetOrdered<
sizeof(ValT),
false>::nmask) |
374 GetOrdered<
sizeof(ValT),
false>::pmask;
375 return uvalue ^ ordered_mask;
/// Conditional bitwise inversion used by descending sorts.
/// The flag==false primary template is the identity.
template <bool flag> struct InvertIf {
  template <typename ValT> ValT operator()(ValT value) { return value; }
};

/// flag==true: flip all bits; bool has no meaningful bitwise complement, so
/// it is negated logically instead.
template <> struct InvertIf<true> {
  template <typename ValT> ValT operator()(ValT value) { return ~value; }

  bool operator()(bool value) { return !value; }
};
396 template <u
int32_t radix_bits,
bool is_comp_asc,
typename ValT>
397 uint32_t getBucketValue(ValT value, uint32_t radix_iter) {
399 value = InvertIf<!is_comp_asc>{}(value);
402 uint32_t bucket_offset = radix_iter * radix_bits;
406 OrderedT<ValT> bucket_mask = (1u << radix_bits) - 1u;
409 return (value >> bucket_offset) & bucket_mask;
/// Padding key for out-of-range slots, chosen so padding sorts to the END of
/// the output: the maximum value for ascending sorts, the minimum (lowest)
/// for descending sorts.
/// Fix: the previous body returned lowest() unconditionally and left
/// is_comp_asc unused — for ascending sorts that would place padding keys at
/// the front, ahead of every real element.
template <typename ValT> ValT getDefaultValue(bool is_comp_asc) {
  if (is_comp_asc)
    return std::numeric_limits<ValT>::max();
  else
    return std::numeric_limits<ValT>::lowest();
}
/// Copies the "values" payload alongside keys during key/value sorting.
/// The false specialization turns every operation into a no-op, so key-only
/// sorts pay nothing for the payload plumbing.
template <bool is_key_value_sort> struct ValuesAssigner {
  // Copy one payload element: output[idx_out] = input[idx_in].
  template <typename IterInT, typename IterOutT>
  void operator()(IterOutT output, size_t idx_out, IterInT input,
                  size_t idx_in) {
    output[idx_out] = input[idx_in];
  }

  // Store an immediate payload value at output[idx_out].
  template <typename IterOutT, typename ValT>
  void operator()(IterOutT output, size_t idx_out, ValT value) {
    output[idx_out] = value;
  }
};

/// Key-only sorting: payload operations compile away entirely.
template <> struct ValuesAssigner<false> {
  template <typename IterInT, typename IterOutT>
  void operator()(IterOutT, size_t, IterInT, size_t) {}

  template <typename IterOutT, typename ValT>
  void operator()(IterOutT, size_t, ValT) {}
};
/// Non-owning view over a byte buffer (work-group scratchpad) that hands out
/// typed references at element-sized offsets.
struct ScratchMemory {
  /// Typed reference into the scratch buffer. Holds only a raw byte pointer;
  /// reads and writes dereference it on demand.
  template <typename T> struct ReferenceObj {
    ReferenceObj() : MPtr{nullptr} {};
    ReferenceObj(const ReferenceObj &) = default;
    ReferenceObj(ReferenceObj &&) = default;
    ReferenceObj &operator=(const ReferenceObj &) = default;
    ReferenceObj &operator=(ReferenceObj &&value) {
      MPtr = std::move(value.MPtr);
      return *this;
    }

    /// Read the referenced element.
    operator T() const { return *reinterpret_cast<T *>(MPtr); }

    /// Post-increment the referenced element, returning its prior value.
    T operator++(int) {
      T &value = *reinterpret_cast<T *>(MPtr);
      T value_before = value++;
      return value_before;
    }

    /// Pre-increment the referenced element, returning the new value.
    T operator++() { return ++(*reinterpret_cast<T *>(MPtr)); }

    /// Store a new value into the referenced element.
    void operator=(const T &value) { *reinterpret_cast<T *>(MPtr) = value; }

    /// Copy the referenced element from another reference's storage.
    void copy(const ReferenceObj &value) noexcept {
      *reinterpret_cast<T *>(MPtr) = *reinterpret_cast<T *>(value.MPtr);
    }

  private:
    // Only ScratchMemory may mint references into its buffer.
    ReferenceObj(std::byte *ptr) : MPtr{ptr} {}
    friend struct ScratchMemory;
    std::byte *MPtr;
  };

  /// A view advanced by byte_offset bytes.
  ScratchMemory operator+(size_t byte_offset) const noexcept {
    return {MMemory + byte_offset};
  }

  ScratchMemory(std::byte *memory) : MMemory{memory} {}
  ScratchMemory(const ScratchMemory &) = default;
  ScratchMemory(ScratchMemory &&) = default;
  ScratchMemory &operator=(const ScratchMemory &) = default;
  ScratchMemory &operator=(ScratchMemory &&) = default;

  /// Typed reference to the index-th ValueT-sized slot of the buffer.
  template <typename ValueT>
  ReferenceObj<ValueT> get(size_t index) const noexcept {
    return {MMemory + index * sizeof(ValueT)};
  }

  std::byte *MMemory;
};
518 template <uint32_t radix_bits,
bool is_key_value_sort,
bool is_comp_asc,
519 typename KeysT,
typename ValueT,
typename GroupT>
520 void performRadixIterDynamicSize(
521 GroupT group,
const uint32_t items_per_work_item,
const uint32_t radix_iter,
522 const size_t n,
const ScratchMemory &keys_input,
523 const ScratchMemory &vals_input,
const ScratchMemory &keys_output,
524 const ScratchMemory &vals_output,
const ScratchMemory &memory) {
525 const uint32_t radix_states = getStatesInBits(radix_bits);
526 const size_t wgsize =
group.get_local_linear_range();
527 const size_t idx =
group.get_local_linear_id();
530 for (uint32_t state = 0; state < radix_states; ++state)
531 memory.get<uint32_t>(state * wgsize + idx) = uint32_t{0};
536 for (uint32_t i = 0; i < items_per_work_item; ++i) {
537 const uint32_t val_idx = items_per_work_item * idx + i;
540 convertToOrdered((val_idx < n) ? keys_input.get<KeysT>(val_idx)
541 : getDefaultValue<ValueT>(is_comp_asc));
543 const uint32_t bucket_val =
544 getBucketValue<radix_bits, is_comp_asc>(val, radix_iter);
548 ++memory.get<uint32_t>(bucket_val * wgsize + idx);
554 uint32_t reduced = 0;
555 for (uint32_t i = 0; i < radix_states; ++i)
556 reduced += memory.get<uint32_t>(idx * radix_states + i);
563 for (uint32_t i = 0; i < radix_states; ++i) {
564 auto value_ref = memory.get<uint32_t>(idx * radix_states + i);
565 uint32_t value_before = value_ref;
567 scanned += value_before;
572 uint32_t private_scan_memory[radix_states] = {0};
575 for (uint32_t i = 0; i < items_per_work_item; ++i) {
576 const uint32_t val_idx = items_per_work_item * idx + i;
579 convertToOrdered((val_idx < n) ? keys_input.get<KeysT>(val_idx)
580 : getDefaultValue<ValueT>(is_comp_asc));
582 uint32_t bucket_val =
583 getBucketValue<radix_bits, is_comp_asc>(val, radix_iter);
585 uint32_t new_offset_idx = private_scan_memory[bucket_val]++ +
586 memory.get<uint32_t>(bucket_val * wgsize + idx);
588 keys_output.get<KeysT>(new_offset_idx)
589 .copy(keys_input.get<KeysT>(val_idx));
590 if constexpr (is_key_value_sort)
591 vals_output.get<ValueT>(new_offset_idx)
592 .copy(vals_input.get<ValueT>(val_idx));
598 template <
size_t items_per_work_item, uint32_t radix_bits,
bool is_comp_asc,
599 bool is_key_value_sort,
bool is_blocked,
typename KeysT,
600 typename ValsT,
typename GroupT>
601 void performRadixIterStaticSize(GroupT group,
const uint32_t radix_iter,
602 const uint32_t last_iter, KeysT *keys,
603 ValsT vals,
const ScratchMemory &memory) {
604 const uint32_t radix_states = getStatesInBits(radix_bits);
605 const size_t wgsize =
group.get_local_linear_range();
606 const size_t idx =
group.get_local_linear_id();
609 uint32_t count_arr[items_per_work_item] = {0};
610 uint32_t ranks[items_per_work_item] = {0};
613 for (uint32_t state = 0; state < radix_states; ++state)
614 memory.get<uint32_t>(state * wgsize + idx) = uint32_t{0};
618 ScratchMemory::ReferenceObj<uint32_t> value_refs[items_per_work_item];
620 for (uint32_t i = 0; i < items_per_work_item; ++i) {
622 OrderedT<KeysT> val = convertToOrdered(keys[i]);
624 uint32_t bucket_val =
625 getBucketValue<radix_bits, is_comp_asc>(val, radix_iter);
626 value_refs[i] = memory.get<uint32_t>(bucket_val * wgsize + idx);
627 count_arr[i] = value_refs[i]++;
632 uint32_t reduced = 0;
633 for (uint32_t i = 0; i < radix_states; ++i)
634 reduced += memory.get<uint32_t>(idx * radix_states + i);
641 for (uint32_t i = 0; i < radix_states; ++i) {
642 auto value_ref = memory.get<uint32_t>(idx * radix_states + i);
643 uint32_t value_before = value_ref;
645 scanned += value_before;
651 for (uint32_t i = 0; i < items_per_work_item; ++i)
652 ranks[i] = count_arr[i] + value_refs[i];
657 const ScratchMemory &keys_temp = memory;
658 const ScratchMemory vals_temp =
659 memory + wgsize * items_per_work_item *
sizeof(KeysT);
660 for (uint32_t i = 0; i < items_per_work_item; ++i) {
661 keys_temp.get<KeysT>(ranks[i]) = keys[i];
662 if constexpr (is_key_value_sort)
663 vals_temp.get<ValsT>(ranks[i]) = vals[i];
669 for (uint32_t i = 0; i < items_per_work_item; ++i) {
670 size_t shift = idx * items_per_work_item + i;
671 if constexpr (!is_blocked) {
672 if (radix_iter == last_iter - 1)
673 shift = i * wgsize + idx;
675 keys[i] = keys_temp.get<KeysT>(shift);
676 if constexpr (is_key_value_sort)
677 vals[i] = vals_temp.get<ValsT>(shift);
681 template <
bool is_key_value_sort,
bool is_comp_asc,
682 uint32_t items_per_work_item = 1, uint32_t radix_bits = 4,
683 typename GroupT,
typename KeysT,
typename ValsT>
684 void privateDynamicSort(GroupT group, KeysT *keys, ValsT *values,
686 const uint32_t first_bit,
const uint32_t last_bit) {
687 const size_t wgsize =
group.get_local_linear_range();
688 constexpr uint32_t radix_states = getStatesInBits(radix_bits);
689 const uint32_t first_iter = first_bit / radix_bits;
690 const uint32_t last_iter = last_bit / radix_bits;
692 ScratchMemory keys_input{
reinterpret_cast<std::byte *
>(keys)};
693 ScratchMemory vals_input{
reinterpret_cast<std::byte *
>(values)};
694 const uint32_t runtime_items_per_work_item = (n - 1) / wgsize + 1;
697 ScratchMemory wrapped_scratch{scratch};
699 ScratchMemory keys_output =
700 wrapped_scratch + radix_states * wgsize *
sizeof(uint32_t);
703 ScratchMemory vals_output =
704 keys_output + is_key_value_sort * n *
sizeof(KeysT) +
alignof(uint32_t);
706 for (uint32_t radix_iter = first_iter; radix_iter < last_iter; ++radix_iter) {
707 performRadixIterDynamicSize<radix_bits, is_key_value_sort, is_comp_asc,
709 group, runtime_items_per_work_item, radix_iter, n, keys_input,
710 vals_input, keys_output, vals_output, wrapped_scratch);
714 std::swap(keys_input, keys_output);
715 std::swap(vals_input, vals_output);
719 template <
bool is_key_value_sort,
bool is_blocked,
bool is_comp_asc,
720 size_t items_per_work_item = 1, uint32_t radix_bits = 4,
721 typename GroupT,
typename T,
typename U>
722 void privateStaticSort(GroupT group, T *keys, U *values,
std::byte *scratch,
723 const uint32_t first_bit,
const uint32_t last_bit) {
725 const uint32_t first_iter = first_bit / radix_bits;
726 const uint32_t last_iter = last_bit / radix_bits;
728 for (uint32_t radix_iter = first_iter; radix_iter < last_iter; ++radix_iter) {
729 performRadixIterStaticSize<items_per_work_item, radix_bits, is_comp_asc,
730 is_key_value_sort, is_blocked>(
731 group, radix_iter, last_iter, keys, values, scratch);
__ESIMD_API simd< T, N > merge(simd< T, N > a, simd< T, N > b, simd_mask< N > m)
"Merges" elements of the input simd object according to the merge mask.
conditional< sizeof(long)==8, long, long long >::type int64_t
void memcpy(void *Dst, const void *Src, size_t Size)
auto operator+(const __ESIMD_DNS::simd_obj_impl< __raw_t< T1 >, N, SimdT< T1, N >> &LHS, const __ESIMD_DNS::simd_obj_impl< __raw_t< T2 >, N, SimdT< T2, N >> &RHS)
@ group
Wait until all previous memory transactions from this thread are observed within the local thread-gro...
@ local
Wait until all previous memory transactions from this thread are observed within the local sub-slice.
annotated_ptr & operator++() noexcept
std::enable_if_t<(is_group_v< std::decay_t< Group >> &&(detail::is_scalar_arithmetic< T >::value||(detail::is_complex< T >::value &&detail::is_multiplies< T, BinaryOperation >::value)) &&detail::is_native_op< T, BinaryOperation >::value), T > exclusive_scan_over_group(Group g, T x, BinaryOperation binary_op)
void group_barrier(ext::oneapi::experimental::root_group< dimensions > G, memory_scope FenceScope=decltype(G)::fence_scope)
PropertyListT int access::address_space multi_ptr & operator=(multi_ptr &&)=default
_Abi const simd< _Tp, _Abi > & noexcept