13 #if __cplusplus >= 201703L
16 #ifdef __SYCL_DEVICE_ONLY__
26 template <
typename Acc,
typename Value,
typename Compare>
27 std::size_t lower_bound(Acc acc, std::size_t first, std::size_t last,
28 const Value &value, Compare comp) {
29 std::size_t n = last - first;
36 if (comp(acc[it], value)) {
37 n -= cur + 1, first = ++it;
44 template <
typename Acc,
typename Value,
typename Compare>
45 std::size_t upper_bound(Acc acc,
const std::size_t first,
46 const std::size_t last,
const Value &value,
48 return detail::lower_bound(acc, first, last, value,
49 [comp](
auto x,
auto y) {
return !comp(y, x); });
// Exchanges two lvalues of the same type; this is the general-case overload
// (callers such as bubble_sort use it to exchange two elements in place).
// Two-element tuple-like rvalues are handled by the separate TupleLike overload,
// which swaps std::get<0> and std::get<1> individually.
//
// Uses the "using std::swap; swap(a, b);" two-step so that a swap() overload
// declared in T's own namespace (found via argument-dependent lookup) is
// preferred, falling back to the generic std::swap otherwise.
template <typename T> void swap_tuples(T &a, T &b) {
  using std::swap;
  swap(a, b);
}
55 template <
template <
typename...>
class TupleLike,
typename T1,
typename T2>
56 void swap_tuples(TupleLike<T1, T2> &&a, TupleLike<T1, T2> &&b) {
57 std::swap(std::get<0>(a), std::get<0>(b));
58 std::swap(std::get<1>(a), std::get<1>(b));
61 template <
typename Iter>
struct GetValueType {
62 using type =
typename std::iterator_traits<Iter>::value_type;
65 template <
typename ElementType, access::address_space Space>
66 struct GetValueType<
sycl::multi_ptr<ElementType, Space>> {
67 using type = ElementType;
72 template <
typename Acc,
typename T>
73 void set_value(Acc ptr,
const std::size_t idx,
const T &val,
bool is_first) {
75 ::new (ptr + idx)
T(val);
81 template <
typename InAcc,
typename OutAcc,
typename Compare>
82 void merge(
const std::size_t offset, InAcc &in_acc1, OutAcc &out_acc1,
83 const std::size_t start_1,
const std::size_t end_1,
84 const std::size_t end_2,
const std::size_t start_out, Compare comp,
85 const std::size_t chunk,
bool is_first) {
86 const std::size_t start_2 = end_1;
88 const std::size_t local_start_1 =
89 sycl::min(
static_cast<std::size_t
>(offset + start_1), end_1);
90 const std::size_t local_end_1 =
91 sycl::min(
static_cast<std::size_t
>(local_start_1 + chunk), end_1);
92 const std::size_t local_start_2 =
93 sycl::min(
static_cast<std::size_t
>(offset + start_2), end_2);
94 const std::size_t local_end_2 =
95 sycl::min(
static_cast<std::size_t
>(local_start_2 + chunk), end_2);
97 const std::size_t local_size_1 = local_end_1 - local_start_1;
98 const std::size_t local_size_2 = local_end_2 - local_start_2;
104 if (local_start_1 < local_end_1) {
107 const auto local_l_item_1 = in_acc1[local_start_1];
108 std::size_t l_search_bound_2 =
109 detail::lower_bound(in_acc1, start_2, end_2, local_l_item_1, comp);
110 const std::size_t l_shift_1 = local_start_1 - start_1;
111 const std::size_t l_shift_2 = l_search_bound_2 - start_2;
113 set_value(out_acc1, start_out + l_shift_1 + l_shift_2, local_l_item_1,
116 std::size_t r_search_bound_2{};
118 if (local_size_1 > 1) {
119 const auto local_r_item_1 = in_acc1[local_end_1 - 1];
120 r_search_bound_2 = detail::lower_bound(in_acc1, l_search_bound_2, end_2,
121 local_r_item_1, comp);
122 const auto r_shift_1 = local_end_1 - 1 - start_1;
123 const auto r_shift_2 = r_search_bound_2 - start_2;
125 set_value(out_acc1, start_out + r_shift_1 + r_shift_2, local_r_item_1,
130 for (std::size_t idx = local_start_1 + 1; idx < local_end_1 - 1; ++idx) {
131 const auto intermediate_item_1 = in_acc1[idx];
135 detail::lower_bound(in_acc1, l_search_bound_2, r_search_bound_2,
136 intermediate_item_1, comp);
137 const std::size_t shift_1 = idx - start_1;
138 const std::size_t shift_2 = l_search_bound_2 - start_2;
140 set_value(out_acc1, start_out + shift_1 + shift_2, intermediate_item_1,
145 if (local_start_2 < local_end_2) {
148 const auto local_l_item_2 = in_acc1[local_start_2];
149 std::size_t l_search_bound_1 =
150 detail::upper_bound(in_acc1, start_1, end_1, local_l_item_2, comp);
151 const std::size_t l_shift_1 = l_search_bound_1 - start_1;
152 const std::size_t l_shift_2 = local_start_2 - start_2;
154 set_value(out_acc1, start_out + l_shift_1 + l_shift_2, local_l_item_2,
157 std::size_t r_search_bound_1{};
159 if (local_size_2 > 1) {
160 const auto local_r_item_2 = in_acc1[local_end_2 - 1];
161 r_search_bound_1 = detail::upper_bound(in_acc1, l_search_bound_1, end_1,
162 local_r_item_2, comp);
163 const std::size_t r_shift_1 = r_search_bound_1 - start_1;
164 const std::size_t r_shift_2 = local_end_2 - 1 - start_2;
166 set_value(out_acc1, start_out + r_shift_1 + r_shift_2, local_r_item_2,
171 for (
auto idx = local_start_2 + 1; idx < local_end_2 - 1; ++idx) {
172 const auto intermediate_item_2 = in_acc1[idx];
176 detail::upper_bound(in_acc1, l_search_bound_1, r_search_bound_1,
177 intermediate_item_2, comp);
178 const std::size_t shift_1 = l_search_bound_1 - start_1;
179 const std::size_t shift_2 = idx - start_2;
181 set_value(out_acc1, start_out + shift_1 + shift_2, intermediate_item_2,
187 template <
typename Iter,
typename Compare>
188 void bubble_sort(Iter first,
const std::size_t begin,
const std::size_t end,
191 for (std::size_t i = begin; i < end; ++i) {
193 for (std::size_t idx = i + 1; idx < end; ++idx) {
194 if (comp(first[idx], first[i])) {
195 detail::swap_tuples(first[i], first[idx]);
202 template <
typename Group,
typename Iter,
typename Compare>
203 void merge_sort(Group group, Iter first,
const std::size_t n, Compare comp,
205 using T =
typename GetValueType<Iter>::type;
206 auto id = sycl::detail::Builder::getNDItem<Group::dimensions>();
207 const std::size_t idx =
id.get_local_linear_id();
208 const std::size_t
local =
group.get_local_range().size();
209 const std::size_t chunk = (n - 1) / local + 1;
212 bubble_sort(first, idx * chunk, sycl::min((idx + 1) * chunk, n), comp);
215 T *temp =
reinterpret_cast<T *
>(scratch);
216 bool data_in_temp =
false;
217 bool is_first =
true;
218 std::size_t sorted_size = 1;
219 while (sorted_size * chunk < n) {
220 const std::size_t start_1 =
221 sycl::min(2 * sorted_size * chunk * (idx / sorted_size), n);
222 const std::size_t end_1 = sycl::min(start_1 + sorted_size * chunk, n);
223 const std::size_t end_2 = sycl::min(end_1 + sorted_size * chunk, n);
224 const std::size_t offset = chunk * (idx % sorted_size);
227 merge(offset, first, temp, start_1, end_1, end_2, start_1, comp, chunk,
230 merge(offset, temp, first, start_1, end_1, end_2, start_1, comp, chunk,
235 data_in_temp = !data_in_temp;
243 for (std::size_t i = 0; i < chunk; ++i) {
244 if (idx * chunk + i < n) {
245 first[idx * chunk + i] = temp[idx * chunk + i];
256 #endif // __cplusplus >=201703L