DPC++ Runtime
Runtime libraries for oneAPI Data Parallel C++
group_algorithm.hpp
Go to the documentation of this file.
1 //==------------------------ group_algorithm.hpp ---------------------------==//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #pragma once
10 #include <CL/__spirv/spirv_ops.hpp>
13 #include <CL/sycl/detail/spirv.hpp>
15 #include <CL/sycl/functional.hpp>
16 #include <CL/sycl/group.hpp>
18 #include <CL/sycl/nd_item.hpp>
19 #include <CL/sycl/sub_group.hpp>
22 
24 namespace sycl {
25 namespace detail {
26 
27 // ---- linear_id_to_id
28 template <int Dimensions>
29 id<Dimensions> linear_id_to_id(range<Dimensions>, size_t linear_id);
30 template <> inline id<1> linear_id_to_id(range<1>, size_t linear_id) {
31  return id<1>(linear_id);
32 }
33 template <> inline id<2> linear_id_to_id(range<2> r, size_t linear_id) {
34  id<2> result;
35  result[0] = linear_id / r[1];
36  result[1] = linear_id % r[1];
37  return result;
38 }
39 template <> inline id<3> linear_id_to_id(range<3> r, size_t linear_id) {
40  id<3> result;
41  result[0] = linear_id / (r[1] * r[2]);
42  result[1] = (linear_id % (r[1] * r[2])) / r[2];
43  result[2] = linear_id % r[2];
44  return result;
45 }
46 
47 // ---- get_local_linear_range
48 template <typename Group> size_t get_local_linear_range(Group g);
49 template <> inline size_t get_local_linear_range<group<1>>(group<1> g) {
50  return g.get_local_range(0);
51 }
52 template <> inline size_t get_local_linear_range<group<2>>(group<2> g) {
53  return g.get_local_range(0) * g.get_local_range(1);
54 }
55 template <> inline size_t get_local_linear_range<group<3>>(group<3> g) {
56  return g.get_local_range(0) * g.get_local_range(1) * g.get_local_range(2);
57 }
58 template <>
59 inline size_t
60 get_local_linear_range<ext::oneapi::sub_group>(ext::oneapi::sub_group g) {
61  return g.get_local_range()[0];
62 }
63 
64 // ---- get_local_linear_id
/// Returns the linear id of the calling work-item within group \p g.
/// Only declared here; each supported group type provides a specialization.
template <typename Group>
typename Group::linear_id_type get_local_linear_id(Group g);

#ifdef __SYCL_DEVICE_ONLY__
// Generates the group<D> specialization. The group argument itself is not
// inspected: the id is read from the nd_item returned by
// Builder::getNDItem<D>(). The specialization is marked `inline`, like the
// other explicit specializations in this header, because an explicit
// specialization of a function template is an ordinary function and would
// otherwise be defined once per translation unit that includes this file.
#define __SYCL_GROUP_GET_LOCAL_LINEAR_ID(D)                                    \
  template <>                                                                  \
  inline group<D>::linear_id_type get_local_linear_id<group<D>>(group<D>) {    \
    nd_item<D> it = cl::sycl::detail::Builder::getNDItem<D>();                 \
    return it.get_local_linear_id();                                           \
  }
__SYCL_GROUP_GET_LOCAL_LINEAR_ID(1);
__SYCL_GROUP_GET_LOCAL_LINEAR_ID(2);
__SYCL_GROUP_GET_LOCAL_LINEAR_ID(3);
#undef __SYCL_GROUP_GET_LOCAL_LINEAR_ID
#endif // __SYCL_DEVICE_ONLY__
80 
81 template <>
83 get_local_linear_id<ext::oneapi::sub_group>(ext::oneapi::sub_group g) {
84  return g.get_local_id()[0];
85 }
86 
87 // ---- is_native_op
88 template <typename T>
89 using native_op_list =
93 
94 template <typename T, typename BinaryOperation> struct is_native_op {
95  static constexpr bool value =
98 };
99 
100 // ---- for_each
101 template <typename Group, typename Ptr, class Function>
102 Function for_each(Group g, Ptr first, Ptr last, Function f) {
103 #ifdef __SYCL_DEVICE_ONLY__
104  ptrdiff_t offset = sycl::detail::get_local_linear_id(g);
105  ptrdiff_t stride = sycl::detail::get_local_linear_range(g);
106  for (Ptr p = first + offset; p < last; p += stride) {
107  f(*p);
108  }
109  return f;
110 #else
111  (void)g;
112  (void)first;
113  (void)last;
114  (void)f;
115  throw runtime_error("Group algorithms are not supported on host device.",
117 #endif
118 }
119 } // namespace detail
120 
121 // ---- reduce_over_group
122 template <typename Group, typename T, class BinaryOperation>
123 detail::enable_if_t<(is_group_v<std::decay_t<Group>> &&
124  detail::is_scalar_arithmetic<T>::value &&
125  detail::is_native_op<T, BinaryOperation>::value),
126  T>
127 reduce_over_group(Group, T x, BinaryOperation binary_op) {
128  // FIXME: Do not special-case for half precision
129  static_assert(
130  std::is_same<decltype(binary_op(x, x)), T>::value ||
131  (std::is_same<T, half>::value &&
132  std::is_same<decltype(binary_op(x, x)), float>::value),
133  "Result type of binary_op must match reduction accumulation type.");
134 #ifdef __SYCL_DEVICE_ONLY__
135  return sycl::detail::calc<T, __spv::GroupOperation::Reduce,
136  sycl::detail::spirv::group_scope<Group>::value>(
138 #else
139  throw runtime_error("Group algorithms are not supported on host device.",
141 #endif
142 }
143 
144 template <typename Group, typename T, class BinaryOperation>
145 detail::enable_if_t<(is_group_v<std::decay_t<Group>> &&
146  detail::is_vector_arithmetic<T>::value &&
147  detail::is_native_op<T, BinaryOperation>::value),
148  T>
149 reduce_over_group(Group g, T x, BinaryOperation binary_op) {
150  // FIXME: Do not special-case for half precision
151  static_assert(
152  std::is_same<decltype(binary_op(x[0], x[0])),
153  typename T::element_type>::value ||
154  (std::is_same<T, half>::value &&
155  std::is_same<decltype(binary_op(x[0], x[0])), float>::value),
156  "Result type of binary_op must match reduction accumulation type.");
157  T result;
158  for (int s = 0; s < x.get_size(); ++s) {
159  result[s] = reduce_over_group(g, x[s], binary_op);
160  }
161  return result;
162 }
163 
164 template <typename Group, typename V, typename T, class BinaryOperation>
165 detail::enable_if_t<(is_group_v<std::decay_t<Group>> &&
166  detail::is_scalar_arithmetic<V>::value &&
167  detail::is_scalar_arithmetic<T>::value &&
168  detail::is_native_op<V, BinaryOperation>::value &&
169  detail::is_native_op<T, BinaryOperation>::value),
170  T>
171 reduce_over_group(Group g, V x, T init, BinaryOperation binary_op) {
172  // FIXME: Do not special-case for half precision
173  static_assert(
174  std::is_same<decltype(binary_op(init, x)), T>::value ||
175  (std::is_same<T, half>::value &&
176  std::is_same<decltype(binary_op(init, x)), float>::value),
177  "Result type of binary_op must match reduction accumulation type.");
178 #ifdef __SYCL_DEVICE_ONLY__
179  return binary_op(init, reduce_over_group(g, x, binary_op));
180 #else
181  (void)g;
182  throw runtime_error("Group algorithms are not supported on host device.",
184 #endif
185 }
186 
187 template <typename Group, typename V, typename T, class BinaryOperation>
188 detail::enable_if_t<(is_group_v<std::decay_t<Group>> &&
189  detail::is_vector_arithmetic<V>::value &&
190  detail::is_vector_arithmetic<T>::value &&
191  detail::is_native_op<V, BinaryOperation>::value &&
192  detail::is_native_op<T, BinaryOperation>::value),
193  T>
194 reduce_over_group(Group g, V x, T init, BinaryOperation binary_op) {
195  // FIXME: Do not special-case for half precision
196  static_assert(
197  std::is_same<decltype(binary_op(init[0], x[0])),
198  typename T::element_type>::value ||
199  (std::is_same<T, half>::value &&
200  std::is_same<decltype(binary_op(init[0], x[0])), float>::value),
201  "Result type of binary_op must match reduction accumulation type.");
202 #ifdef __SYCL_DEVICE_ONLY__
203  T result = init;
204  for (int s = 0; s < x.get_size(); ++s) {
205  result[s] = binary_op(init[s], reduce_over_group(g, x[s], binary_op));
206  }
207  return result;
208 #else
209  (void)g;
210  throw runtime_error("Group algorithms are not supported on host device.",
212 #endif
213 }
214 
215 // ---- joint_reduce
216 template <typename Group, typename Ptr, class BinaryOperation>
218  (is_group_v<std::decay_t<Group>> && detail::is_pointer<Ptr>::value &&
221 joint_reduce(Group g, Ptr first, Ptr last, BinaryOperation binary_op) {
222  using T = typename detail::remove_pointer<Ptr>::type;
223  // FIXME: Do not special-case for half precision
224  static_assert(
225  std::is_same<decltype(binary_op(*first, *first)), T>::value ||
226  (std::is_same<T, half>::value &&
227  std::is_same<decltype(binary_op(*first, *first)), float>::value),
228  "Result type of binary_op must match reduction accumulation type.");
229 #ifdef __SYCL_DEVICE_ONLY__
230  T partial = sycl::known_identity_v<BinaryOperation, T>;
231  sycl::detail::for_each(g, first, last,
232  [&](const T &x) { partial = binary_op(partial, x); });
233  return reduce_over_group(g, partial, binary_op);
234 #else
235  (void)g;
236  (void)last;
237  (void)binary_op;
238  throw runtime_error("Group algorithms are not supported on host device.",
240 #endif
241 }
242 
243 template <typename Group, typename Ptr, typename T, class BinaryOperation>
245  (is_group_v<std::decay_t<Group>> && detail::is_pointer<Ptr>::value &&
247  detail::is_arithmetic<T>::value &&
249  BinaryOperation>::value &&
250  detail::is_native_op<T, BinaryOperation>::value),
251  T>
252 joint_reduce(Group g, Ptr first, Ptr last, T init, BinaryOperation binary_op) {
253  // FIXME: Do not special-case for half precision
254  static_assert(
255  std::is_same<decltype(binary_op(init, *first)), T>::value ||
256  (std::is_same<T, half>::value &&
257  std::is_same<decltype(binary_op(init, *first)), float>::value),
258  "Result type of binary_op must match reduction accumulation type.");
259 #ifdef __SYCL_DEVICE_ONLY__
260  T partial = sycl::known_identity_v<BinaryOperation, T>;
262  g, first, last, [&](const typename detail::remove_pointer<Ptr>::type &x) {
263  partial = binary_op(partial, x);
264  });
265  return reduce_over_group(g, partial, init, binary_op);
266 #else
267  (void)g;
268  (void)last;
269  throw runtime_error("Group algorithms are not supported on host device.",
271 #endif
272 }
273 
274 // ---- any_of_group
275 template <typename Group>
276 detail::enable_if_t<is_group_v<std::decay_t<Group>>, bool>
277 any_of_group(Group, bool pred) {
278 #ifdef __SYCL_DEVICE_ONLY__
279  return sycl::detail::spirv::GroupAny<Group>(pred);
280 #else
281  (void)pred;
282  throw runtime_error("Group algorithms are not supported on host device.",
284 #endif
285 }
286 
287 template <typename Group, typename T, class Predicate>
289  Predicate pred) {
290  return any_of_group(g, pred(x));
291 }
292 
293 // ---- joint_any_of
294 template <typename Group, typename Ptr, class Predicate>
296  (is_group_v<std::decay_t<Group>> && detail::is_pointer<Ptr>::value), bool>
297 joint_any_of(Group g, Ptr first, Ptr last, Predicate pred) {
298 #ifdef __SYCL_DEVICE_ONLY__
299  using T = typename detail::remove_pointer<Ptr>::type;
300  bool partial = false;
301  sycl::detail::for_each(g, first, last, [&](T &x) { partial |= pred(x); });
302  return any_of_group(g, partial);
303 #else
304  (void)g;
305  (void)first;
306  (void)last;
307  (void)pred;
308  throw runtime_error("Group algorithms are not supported on host device.",
310 #endif
311 }
312 
313 // ---- all_of_group
314 template <typename Group>
315 detail::enable_if_t<is_group_v<std::decay_t<Group>>, bool>
316 all_of_group(Group, bool pred) {
317 #ifdef __SYCL_DEVICE_ONLY__
318  return sycl::detail::spirv::GroupAll<Group>(pred);
319 #else
320  (void)pred;
321  throw runtime_error("Group algorithms are not supported on host device.",
323 #endif
324 }
325 
326 template <typename Group, typename T, class Predicate>
327 detail::enable_if_t<is_group_v<std::decay_t<Group>>, bool>
328 all_of_group(Group g, T x, Predicate pred) {
329  return all_of_group(g, pred(x));
330 }
331 
332 // ---- joint_all_of
333 template <typename Group, typename Ptr, class Predicate>
335  (is_group_v<std::decay_t<Group>> && detail::is_pointer<Ptr>::value), bool>
336 joint_all_of(Group g, Ptr first, Ptr last, Predicate pred) {
337 #ifdef __SYCL_DEVICE_ONLY__
338  using T = typename detail::remove_pointer<Ptr>::type;
339  bool partial = true;
340  sycl::detail::for_each(g, first, last, [&](T &x) { partial &= pred(x); });
341  return all_of_group(g, partial);
342 #else
343  (void)g;
344  (void)first;
345  (void)last;
346  (void)pred;
347  throw runtime_error("Group algorithms are not supported on host device.",
349 #endif
350 }
351 
352 // ---- none_of_group
353 template <typename Group>
354 detail::enable_if_t<is_group_v<std::decay_t<Group>>, bool>
355 none_of_group(Group, bool pred) {
356 #ifdef __SYCL_DEVICE_ONLY__
357  return sycl::detail::spirv::GroupAll<Group>(!pred);
358 #else
359  (void)pred;
360  throw runtime_error("Group algorithms are not supported on host device.",
362 #endif
363 }
364 
365 template <typename Group, typename T, class Predicate>
366 detail::enable_if_t<is_group_v<std::decay_t<Group>>, bool>
367 none_of_group(Group g, T x, Predicate pred) {
368  return none_of_group(g, pred(x));
369 }
370 
371 // ---- joint_none_of
372 template <typename Group, typename Ptr, class Predicate>
374  (is_group_v<std::decay_t<Group>> && detail::is_pointer<Ptr>::value), bool>
375 joint_none_of(Group g, Ptr first, Ptr last, Predicate pred) {
376 #ifdef __SYCL_DEVICE_ONLY__
377  return !joint_any_of(g, first, last, pred);
378 #else
379  (void)g;
380  (void)first;
381  (void)last;
382  (void)pred;
383  throw runtime_error("Group algorithms are not supported on host device.",
385 #endif
386 }
387 
388 // ---- shift_group_left
389 // TODO: remove check for detail::is_vec<T> once sycl::vec is trivially
390 // copyable.
391 template <typename Group, typename T>
392 detail::enable_if_t<(std::is_same<std::decay_t<Group>, sub_group>::value &&
393  (std::is_trivially_copyable<T>::value ||
394  detail::is_vec<T>::value)),
395  T>
396 shift_group_left(Group, T x, typename Group::linear_id_type delta = 1) {
397 #ifdef __SYCL_DEVICE_ONLY__
398  return sycl::detail::spirv::SubgroupShuffleDown(x, delta);
399 #else
400  (void)x;
401  (void)delta;
402  throw runtime_error("Sub-groups are not supported on host device.",
404 #endif
405 }
406 
407 // ---- shift_group_right
408 // TODO: remove check for detail::is_vec<T> once sycl::vec is trivially
409 // copyable.
410 template <typename Group, typename T>
411 detail::enable_if_t<(std::is_same<std::decay_t<Group>, sub_group>::value &&
412  (std::is_trivially_copyable<T>::value ||
413  detail::is_vec<T>::value)),
414  T>
415 shift_group_right(Group, T x, typename Group::linear_id_type delta = 1) {
416 #ifdef __SYCL_DEVICE_ONLY__
417  return sycl::detail::spirv::SubgroupShuffleUp(x, delta);
418 #else
419  (void)x;
420  (void)delta;
421  throw runtime_error("Sub-groups are not supported on host device.",
423 #endif
424 }
425 
426 // ---- permute_group_by_xor
427 // TODO: remove check for detail::is_vec<T> once sycl::vec is trivially
428 // copyable.
429 template <typename Group, typename T>
430 detail::enable_if_t<(std::is_same<std::decay_t<Group>, sub_group>::value &&
431  (std::is_trivially_copyable<T>::value ||
432  detail::is_vec<T>::value)),
433  T>
434 permute_group_by_xor(Group, T x, typename Group::linear_id_type mask) {
435 #ifdef __SYCL_DEVICE_ONLY__
436  return sycl::detail::spirv::SubgroupShuffleXor(x, mask);
437 #else
438  (void)x;
439  (void)mask;
440  throw runtime_error("Sub-groups are not supported on host device.",
442 #endif
443 }
444 
445 // ---- select_from_group
446 // TODO: remove check for detail::is_vec<T> once sycl::vec is trivially
447 // copyable.
448 template <typename Group, typename T>
449 detail::enable_if_t<(std::is_same<std::decay_t<Group>, sub_group>::value &&
450  (std::is_trivially_copyable<T>::value ||
451  detail::is_vec<T>::value)),
452  T>
453 select_from_group(Group, T x, typename Group::id_type local_id) {
454 #ifdef __SYCL_DEVICE_ONLY__
455  return sycl::detail::spirv::SubgroupShuffle(x, local_id);
456 #else
457  (void)x;
458  (void)local_id;
459  throw runtime_error("Sub-groups are not supported on host device.",
461 #endif
462 }
463 
464 // ---- group_broadcast
465 // TODO: remove check for detail::is_vec<T> once sycl::vec is trivially
466 // copyable.
467 template <typename Group, typename T>
468 detail::enable_if_t<(is_group_v<std::decay_t<Group>> &&
469  (std::is_trivially_copyable<T>::value ||
470  detail::is_vec<T>::value)),
471  T>
472 group_broadcast(Group, T x, typename Group::id_type local_id) {
473 #ifdef __SYCL_DEVICE_ONLY__
474  return sycl::detail::spirv::GroupBroadcast<Group>(x, local_id);
475 #else
476  (void)x;
477  (void)local_id;
478  throw runtime_error("Group algorithms are not supported on host device.",
480 #endif
481 }
482 
483 template <typename Group, typename T>
484 detail::enable_if_t<(is_group_v<std::decay_t<Group>> &&
485  (std::is_trivially_copyable<T>::value ||
486  detail::is_vec<T>::value)),
487  T>
488 group_broadcast(Group g, T x, typename Group::linear_id_type linear_local_id) {
489 #ifdef __SYCL_DEVICE_ONLY__
490  return group_broadcast(
491  g, x,
492  sycl::detail::linear_id_to_id(g.get_local_range(), linear_local_id));
493 #else
494  (void)g;
495  (void)x;
496  (void)linear_local_id;
497  throw runtime_error("Group algorithms are not supported on host device.",
499 #endif
500 }
501 
502 template <typename Group, typename T>
503 detail::enable_if_t<(is_group_v<std::decay_t<Group>> &&
504  (std::is_trivially_copyable<T>::value ||
505  detail::is_vec<T>::value)),
506  T>
507 group_broadcast(Group g, T x) {
508 #ifdef __SYCL_DEVICE_ONLY__
509  return group_broadcast(g, x, 0);
510 #else
511  (void)g;
512  (void)x;
513  throw runtime_error("Group algorithms are not supported on host device.",
515 #endif
516 }
517 
518 // ---- exclusive_scan_over_group
519 template <typename Group, typename T, class BinaryOperation>
520 detail::enable_if_t<(is_group_v<std::decay_t<Group>> &&
521  detail::is_scalar_arithmetic<T>::value &&
522  detail::is_native_op<T, BinaryOperation>::value),
523  T>
524 exclusive_scan_over_group(Group, T x, BinaryOperation binary_op) {
525  // FIXME: Do not special-case for half precision
526  static_assert(std::is_same<decltype(binary_op(x, x)), T>::value ||
527  (std::is_same<T, half>::value &&
528  std::is_same<decltype(binary_op(x, x)), float>::value),
529  "Result type of binary_op must match scan accumulation type.");
530 #ifdef __SYCL_DEVICE_ONLY__
531  return sycl::detail::calc<T, __spv::GroupOperation::ExclusiveScan,
532  sycl::detail::spirv::group_scope<Group>::value>(
534 #else
535  throw runtime_error("Group algorithms are not supported on host device.",
537 #endif
538 }
539 
540 template <typename Group, typename T, class BinaryOperation>
541 detail::enable_if_t<(is_group_v<std::decay_t<Group>> &&
542  detail::is_vector_arithmetic<T>::value &&
543  detail::is_native_op<T, BinaryOperation>::value),
544  T>
545 exclusive_scan_over_group(Group g, T x, BinaryOperation binary_op) {
546  // FIXME: Do not special-case for half precision
547  static_assert(
548  std::is_same<decltype(binary_op(x[0], x[0])),
549  typename T::element_type>::value ||
550  (std::is_same<T, half>::value &&
551  std::is_same<decltype(binary_op(x[0], x[0])), float>::value),
552  "Result type of binary_op must match scan accumulation type.");
553  T result;
554  for (int s = 0; s < x.get_size(); ++s) {
555  result[s] = exclusive_scan_over_group(g, x[s], binary_op);
556  }
557  return result;
558 }
559 
560 template <typename Group, typename V, typename T, class BinaryOperation>
561 detail::enable_if_t<(is_group_v<std::decay_t<Group>> &&
562  detail::is_vector_arithmetic<V>::value &&
563  detail::is_vector_arithmetic<T>::value &&
564  detail::is_native_op<V, BinaryOperation>::value &&
565  detail::is_native_op<T, BinaryOperation>::value),
566  T>
567 exclusive_scan_over_group(Group g, V x, T init, BinaryOperation binary_op) {
568  // FIXME: Do not special-case for half precision
569  static_assert(
570  std::is_same<decltype(binary_op(init[0], x[0])),
571  typename T::element_type>::value ||
572  (std::is_same<T, half>::value &&
573  std::is_same<decltype(binary_op(init[0], x[0])), float>::value),
574  "Result type of binary_op must match scan accumulation type.");
575  T result;
576  for (int s = 0; s < x.get_size(); ++s) {
577  result[s] = exclusive_scan_over_group(g, x[s], init[s], binary_op);
578  }
579  return result;
580 }
581 
582 template <typename Group, typename V, typename T, class BinaryOperation>
583 detail::enable_if_t<(is_group_v<std::decay_t<Group>> &&
584  detail::is_scalar_arithmetic<V>::value &&
585  detail::is_scalar_arithmetic<T>::value &&
586  detail::is_native_op<V, BinaryOperation>::value &&
587  detail::is_native_op<T, BinaryOperation>::value),
588  T>
589 exclusive_scan_over_group(Group g, V x, T init, BinaryOperation binary_op) {
590  // FIXME: Do not special-case for half precision
591  static_assert(std::is_same<decltype(binary_op(init, x)), T>::value ||
592  (std::is_same<T, half>::value &&
593  std::is_same<decltype(binary_op(init, x)), float>::value),
594  "Result type of binary_op must match scan accumulation type.");
595 #ifdef __SYCL_DEVICE_ONLY__
596  typename Group::linear_id_type local_linear_id =
598  if (local_linear_id == 0) {
599  x = binary_op(init, x);
600  }
601  T scan = exclusive_scan_over_group(g, x, binary_op);
602  if (local_linear_id == 0) {
603  scan = init;
604  }
605  return scan;
606 #else
607  (void)g;
608  throw runtime_error("Group algorithms are not supported on host device.",
610 #endif
611 }
612 
613 // ---- joint_exclusive_scan
614 template <typename Group, typename InPtr, typename OutPtr, typename T,
615  class BinaryOperation>
617  (is_group_v<std::decay_t<Group>> && detail::is_pointer<InPtr>::value &&
618  detail::is_pointer<OutPtr>::value &&
619  detail::is_arithmetic<
620  typename detail::remove_pointer<InPtr>::type>::value &&
621  detail::is_arithmetic<T>::value &&
623  BinaryOperation>::value &&
624  detail::is_native_op<T, BinaryOperation>::value),
625  OutPtr>
626 joint_exclusive_scan(Group g, InPtr first, InPtr last, OutPtr result, T init,
627  BinaryOperation binary_op) {
628  // FIXME: Do not special-case for half precision
629  static_assert(
630  std::is_same<decltype(binary_op(*first, *first)), T>::value ||
631  (std::is_same<T, half>::value &&
632  std::is_same<decltype(binary_op(*first, *first)), float>::value),
633  "Result type of binary_op must match scan accumulation type.");
634 #ifdef __SYCL_DEVICE_ONLY__
635  ptrdiff_t offset = sycl::detail::get_local_linear_id(g);
636  ptrdiff_t stride = sycl::detail::get_local_linear_range(g);
637  ptrdiff_t N = last - first;
638  auto roundup = [=](const ptrdiff_t &v,
639  const ptrdiff_t &divisor) -> ptrdiff_t {
640  return ((v + divisor - 1) / divisor) * divisor;
641  };
643  x;
644  typename detail::remove_pointer<OutPtr>::type carry = init;
645  for (ptrdiff_t chunk = 0; chunk < roundup(N, stride); chunk += stride) {
646  ptrdiff_t i = chunk + offset;
647  if (i < N) {
648  x = first[i];
649  }
651  exclusive_scan_over_group(g, x, carry, binary_op);
652  if (i < N) {
653  result[i] = out;
654  }
655  carry = group_broadcast(g, binary_op(out, x), stride - 1);
656  }
657  return result + N;
658 #else
659  (void)g;
660  (void)last;
661  (void)result;
662  (void)init;
663  throw runtime_error("Group algorithms are not supported on host device.",
665 #endif
666 }
667 
668 template <typename Group, typename InPtr, typename OutPtr,
669  class BinaryOperation>
671  (is_group_v<std::decay_t<Group>> && detail::is_pointer<InPtr>::value &&
672  detail::is_pointer<OutPtr>::value &&
673  detail::is_arithmetic<
674  typename detail::remove_pointer<InPtr>::type>::value &&
676  BinaryOperation>::value),
677  OutPtr>
678 joint_exclusive_scan(Group g, InPtr first, InPtr last, OutPtr result,
679  BinaryOperation binary_op) {
680  // FIXME: Do not special-case for half precision
681  static_assert(
682  std::is_same<decltype(binary_op(*first, *first)),
683  typename detail::remove_pointer<OutPtr>::type>::value ||
684  (std::is_same<typename detail::remove_pointer<OutPtr>::type,
685  half>::value &&
686  std::is_same<decltype(binary_op(*first, *first)), float>::value),
687  "Result type of binary_op must match scan accumulation type.");
688  return joint_exclusive_scan(
689  g, first, last, result,
690  sycl::known_identity_v<BinaryOperation,
692  binary_op);
693 }
694 
695 // ---- inclusive_scan_over_group
696 template <typename Group, typename T, class BinaryOperation>
697 detail::enable_if_t<(is_group_v<std::decay_t<Group>> &&
698  detail::is_vector_arithmetic<T>::value &&
699  detail::is_native_op<T, BinaryOperation>::value),
700  T>
701 inclusive_scan_over_group(Group g, T x, BinaryOperation binary_op) {
702  // FIXME: Do not special-case for half precision
703  static_assert(
704  std::is_same<decltype(binary_op(x[0], x[0])),
705  typename T::element_type>::value ||
706  (std::is_same<T, half>::value &&
707  std::is_same<decltype(binary_op(x[0], x[0])), float>::value),
708  "Result type of binary_op must match scan accumulation type.");
709  T result;
710  for (int s = 0; s < x.get_size(); ++s) {
711  result[s] = inclusive_scan_over_group(g, x[s], binary_op);
712  }
713  return result;
714 }
715 
716 template <typename Group, typename T, class BinaryOperation>
717 detail::enable_if_t<(is_group_v<std::decay_t<Group>> &&
718  detail::is_scalar_arithmetic<T>::value &&
719  detail::is_native_op<T, BinaryOperation>::value),
720  T>
721 inclusive_scan_over_group(Group, T x, BinaryOperation binary_op) {
722  // FIXME: Do not special-case for half precision
723  static_assert(std::is_same<decltype(binary_op(x, x)), T>::value ||
724  (std::is_same<T, half>::value &&
725  std::is_same<decltype(binary_op(x, x)), float>::value),
726  "Result type of binary_op must match scan accumulation type.");
727 #ifdef __SYCL_DEVICE_ONLY__
728  return sycl::detail::calc<T, __spv::GroupOperation::InclusiveScan,
729  sycl::detail::spirv::group_scope<Group>::value>(
731 #else
732  throw runtime_error("Group algorithms are not supported on host device.",
734 #endif
735 }
736 
737 template <typename Group, typename V, class BinaryOperation, typename T>
738 detail::enable_if_t<(is_group_v<std::decay_t<Group>> &&
739  detail::is_scalar_arithmetic<V>::value &&
740  detail::is_scalar_arithmetic<T>::value &&
741  detail::is_native_op<V, BinaryOperation>::value &&
742  detail::is_native_op<T, BinaryOperation>::value),
743  T>
744 inclusive_scan_over_group(Group g, V x, BinaryOperation binary_op, T init) {
745  // FIXME: Do not special-case for half precision
746  static_assert(std::is_same<decltype(binary_op(init, x)), T>::value ||
747  (std::is_same<T, half>::value &&
748  std::is_same<decltype(binary_op(init, x)), float>::value),
749  "Result type of binary_op must match scan accumulation type.");
750 #ifdef __SYCL_DEVICE_ONLY__
751  if (sycl::detail::get_local_linear_id(g) == 0) {
752  x = binary_op(init, x);
753  }
754  return inclusive_scan_over_group(g, x, binary_op);
755 #else
756  (void)g;
757  throw runtime_error("Group algorithms are not supported on host device.",
759 #endif
760 }
761 
762 template <typename Group, typename V, class BinaryOperation, typename T>
763 detail::enable_if_t<(is_group_v<std::decay_t<Group>> &&
764  detail::is_vector_arithmetic<V>::value &&
765  detail::is_vector_arithmetic<T>::value &&
766  detail::is_native_op<V, BinaryOperation>::value &&
767  detail::is_native_op<T, BinaryOperation>::value),
768  T>
769 inclusive_scan_over_group(Group g, V x, BinaryOperation binary_op, T init) {
770  // FIXME: Do not special-case for half precision
771  static_assert(
772  std::is_same<decltype(binary_op(init[0], x[0])), T>::value ||
773  (std::is_same<T, half>::value &&
774  std::is_same<decltype(binary_op(init[0], x[0])), float>::value),
775  "Result type of binary_op must match scan accumulation type.");
776  T result;
777  for (int s = 0; s < x.get_size(); ++s) {
778  result[s] = inclusive_scan_over_group(g, x[s], binary_op, init[s]);
779  }
780  return result;
781 }
782 
783 // ---- joint_inclusive_scan
784 template <typename Group, typename InPtr, typename OutPtr,
785  class BinaryOperation, typename T>
787  (is_group_v<std::decay_t<Group>> && detail::is_pointer<InPtr>::value &&
788  detail::is_pointer<OutPtr>::value &&
789  detail::is_arithmetic<
790  typename detail::remove_pointer<InPtr>::type>::value &&
791  detail::is_arithmetic<T>::value &&
793  BinaryOperation>::value &&
794  detail::is_native_op<T, BinaryOperation>::value),
795  OutPtr>
796 joint_inclusive_scan(Group g, InPtr first, InPtr last, OutPtr result,
797  BinaryOperation binary_op, T init) {
798  // FIXME: Do not special-case for half precision
799  static_assert(
800  std::is_same<decltype(binary_op(init, *first)), T>::value ||
801  (std::is_same<T, half>::value &&
802  std::is_same<decltype(binary_op(init, *first)), float>::value),
803  "Result type of binary_op must match scan accumulation type.");
804 #ifdef __SYCL_DEVICE_ONLY__
805  ptrdiff_t offset = sycl::detail::get_local_linear_id(g);
806  ptrdiff_t stride = sycl::detail::get_local_linear_range(g);
807  ptrdiff_t N = last - first;
808  auto roundup = [=](const ptrdiff_t &v,
809  const ptrdiff_t &divisor) -> ptrdiff_t {
810  return ((v + divisor - 1) / divisor) * divisor;
811  };
813  x;
814  typename detail::remove_pointer<OutPtr>::type carry = init;
815  for (ptrdiff_t chunk = 0; chunk < roundup(N, stride); chunk += stride) {
816  ptrdiff_t i = chunk + offset;
817  if (i < N) {
818  x = first[i];
819  }
821  inclusive_scan_over_group(g, x, binary_op, carry);
822  if (i < N) {
823  result[i] = out;
824  }
825  carry = group_broadcast(g, out, stride - 1);
826  }
827  return result + N;
828 #else
829  (void)g;
830  (void)last;
831  (void)result;
832  throw runtime_error("Group algorithms are not supported on host device.",
834 #endif
835 }
836 
837 template <typename Group, typename InPtr, typename OutPtr,
838  class BinaryOperation>
840  (is_group_v<std::decay_t<Group>> && detail::is_pointer<InPtr>::value &&
841  detail::is_pointer<OutPtr>::value &&
842  detail::is_arithmetic<
843  typename detail::remove_pointer<InPtr>::type>::value &&
845  BinaryOperation>::value),
846  OutPtr>
847 joint_inclusive_scan(Group g, InPtr first, InPtr last, OutPtr result,
848  BinaryOperation binary_op) {
849  // FIXME: Do not special-case for half precision
850  static_assert(
851  std::is_same<decltype(binary_op(*first, *first)),
852  typename detail::remove_pointer<OutPtr>::type>::value ||
853  (std::is_same<typename detail::remove_pointer<OutPtr>::type,
854  half>::value &&
855  std::is_same<decltype(binary_op(*first, *first)), float>::value),
856  "Result type of binary_op must match scan accumulation type.");
857  return joint_inclusive_scan(
858  g, first, last, result, binary_op,
859  sycl::known_identity_v<BinaryOperation,
861 }
862 
863 namespace detail {
864 template <typename G> struct group_barrier_scope {};
865 template <> struct group_barrier_scope<sycl::sub_group> {
866  constexpr static auto Scope = __spv::Scope::Subgroup;
867 };
868 template <int D> struct group_barrier_scope<sycl::group<D>> {
869  constexpr static auto Scope = __spv::Scope::Workgroup;
870 };
871 } // namespace detail
872 
873 template <typename Group>
874 typename std::enable_if<is_group_v<Group>>::type
875 group_barrier(Group, memory_scope FenceScope = Group::fence_scope) {
876  (void)FenceScope;
877 #ifdef __SYCL_DEVICE_ONLY__
878  // Per SYCL spec, group_barrier must perform both control barrier and memory
879  // fence operations. All work-items execute a release fence prior to
880  // barrier and acquire fence afterwards. The rest of semantics flags specify
881  // which type of memory this behavior is applied to.
883  sycl::detail::spirv::getScope(FenceScope),
888 #else
889  throw sycl::runtime_error("Barriers are not supported on host device",
891 #endif
892 }
893 
894 } // namespace sycl
895 } // __SYCL_INLINE_NAMESPACE(cl)
spirv_ops.hpp
__spirv_ControlBarrier
__SYCL_CONVERGENT__ SYCL_EXTERNAL void __spirv_ControlBarrier(__spv::Scope Execution, __spv::Scope Memory, uint32_t Semantics) noexcept
Definition: spirv_ops.cpp:26
__spv::MemorySemanticsMask::SubgroupMemory
@ SubgroupMemory
Definition: spirv_types.hpp:91
cl::sycl::joint_reduce
detail::enable_if_t<(is_group_v< std::decay_t< Group >> &&detail::is_pointer< Ptr >::value &&detail::is_arithmetic< typename detail::remove_pointer< Ptr >::type >::value &&detail::is_arithmetic< T >::value &&detail::is_native_op< typename detail::remove_pointer< Ptr >::type, BinaryOperation >::value &&detail::is_native_op< T, BinaryOperation >::value), T > joint_reduce(Group g, Ptr first, Ptr last, T init, BinaryOperation binary_op)
Definition: group_algorithm.hpp:252
cl::sycl::joint_exclusive_scan
detail::enable_if_t<(is_group_v< std::decay_t< Group >> &&detail::is_pointer< InPtr >::value &&detail::is_pointer< OutPtr >::value &&detail::is_arithmetic< typename detail::remove_pointer< InPtr >::type >::value &&detail::is_native_op< typename detail::remove_pointer< InPtr >::type, BinaryOperation >::value), OutPtr > joint_exclusive_scan(Group g, InPtr first, InPtr last, OutPtr result, BinaryOperation binary_op)
Definition: group_algorithm.hpp:678
type
__spv::Scope::Workgroup
@ Workgroup
Definition: spirv_types.hpp:30
cl::sycl::bit_xor
std::bit_xor< T > bit_xor
Definition: functional.hpp:22
T
type_traits.hpp
sub_group.hpp
cl::sycl::multiplies
std::multiplies< T > multiplies
Definition: functional.hpp:19
cl::sycl::detail::type_list
Definition: type_list.hpp:23
cl::sycl::detail::for_each
Function for_each(Group g, Ptr first, Ptr last, Function f)
Definition: group_algorithm.hpp:102
cl::sycl::ext::intel::experimental::type
type
Definition: fpga_utils.hpp:22
cl::sycl::joint_all_of
detail::enable_if_t<(is_group_v< std::decay_t< Group >> &&detail::is_pointer< Ptr >::value), bool > joint_all_of(Group g, Ptr first, Ptr last, Predicate pred)
Definition: group_algorithm.hpp:336
cl::sycl::id< 1 >
cl::sycl::group< 1 >
cl::sycl
Definition: access.hpp:14
cl::sycl::detail::is_contained
Definition: type_list.hpp:54
cl::sycl::detail::linear_id_to_id
id< 3 > linear_id_to_id(range< 3 > r, size_t linear_id)
Definition: group_algorithm.hpp:39
cl::sycl::detail::get_local_linear_range
size_t get_local_linear_range(Group g)
cl::sycl::shift_group_left
detail::enable_if_t<(std::is_same< std::decay_t< Group >, sub_group >::value &&(std::is_trivially_copyable< T >::value||detail::is_vec< T >::value)), T > shift_group_left(Group, T x, typename Group::linear_id_type delta=1)
Definition: group_algorithm.hpp:396
detail
Definition: pi_opencl.cpp:86
cl::sycl::detail::get_local_linear_id
Group::linear_id_type get_local_linear_id(Group g)
cl::sycl::exclusive_scan_over_group
detail::enable_if_t<(is_group_v< std::decay_t< Group >> &&detail::is_scalar_arithmetic< V >::value &&detail::is_scalar_arithmetic< T >::value &&detail::is_native_op< V, BinaryOperation >::value &&detail::is_native_op< T, BinaryOperation >::value), T > exclusive_scan_over_group(Group g, V x, T init, BinaryOperation binary_op)
Definition: group_algorithm.hpp:589
spirv_vars.hpp
cl::sycl::shift_group_right
detail::enable_if_t<(std::is_same< std::decay_t< Group >, sub_group >::value &&(std::is_trivially_copyable< T >::value||detail::is_vec< T >::value)), T > shift_group_right(Group, T x, typename Group::linear_id_type delta=1)
Definition: group_algorithm.hpp:415
__spv::MemorySemanticsMask::CrossWorkgroupMemory
@ CrossWorkgroupMemory
Definition: spirv_types.hpp:93
cl::sycl::reduce_over_group
detail::enable_if_t<(is_group_v< std::decay_t< Group >> &&detail::is_vector_arithmetic< V >::value &&detail::is_vector_arithmetic< T >::value &&detail::is_native_op< V, BinaryOperation >::value &&detail::is_native_op< T, BinaryOperation >::value), T > reduce_over_group(Group g, V x, T init, BinaryOperation binary_op)
Definition: group_algorithm.hpp:194
group.hpp
cl::sycl::range< 1 >
cl::sycl::detail::group_barrier_scope
Definition: group_algorithm.hpp:864
cl::sycl::bit_or
std::bit_or< T > bit_or
Definition: functional.hpp:21
cl::__SEIEED::binary_op
ESIMD_INLINE T binary_op(T X, T Y)
Definition: elem_type_traits.hpp:415
cl::sycl::memory_scope
memory_scope
Definition: memory_enums.hpp:26
__spv::GroupOperation::ExclusiveScan
@ ExclusiveScan
cl::sycl::none_of_group
detail::enable_if_t< is_group_v< std::decay_t< Group > >, bool > none_of_group(Group g, T x, Predicate pred)
Definition: group_algorithm.hpp:367
cl::sycl::maximum
Definition: functional.hpp:43
functional.hpp
cl::sycl::detail::half_impl::half
Definition: half_type.hpp:335
__spv::Scope::Subgroup
@ Subgroup
Definition: spirv_types.hpp:31
cl::sycl::known_identity_v
__SYCL_INLINE_CONSTEXPR AccumulatorT known_identity_v
Definition: known_identity.hpp:394
__spv::GroupOperation::InclusiveScan
@ InclusiveScan
cl::sycl::joint_inclusive_scan
detail::enable_if_t<(is_group_v< std::decay_t< Group >> &&detail::is_pointer< InPtr >::value &&detail::is_pointer< OutPtr >::value &&detail::is_arithmetic< typename detail::remove_pointer< InPtr >::type >::value &&detail::is_native_op< typename detail::remove_pointer< InPtr >::type, BinaryOperation >::value), OutPtr > joint_inclusive_scan(Group g, InPtr first, InPtr last, OutPtr result, BinaryOperation binary_op)
Definition: group_algorithm.hpp:847
cl::sycl::minimum
Definition: functional.hpp:26
spirv.hpp
cl::sycl::inclusive_scan_over_group
detail::enable_if_t<(is_group_v< std::decay_t< Group >> &&detail::is_vector_arithmetic< V >::value &&detail::is_vector_arithmetic< T >::value &&detail::is_native_op< V, BinaryOperation >::value &&detail::is_native_op< T, BinaryOperation >::value), T > inclusive_scan_over_group(Group g, V x, BinaryOperation binary_op, T init)
Definition: group_algorithm.hpp:769
cl::sycl::group_broadcast
detail::enable_if_t<(is_group_v< std::decay_t< Group >> &&(std::is_trivially_copyable< T >::value||detail::is_vec< T >::value)), T > group_broadcast(Group g, T x)
Definition: group_algorithm.hpp:507
cl
We provide new interfaces for matrix multiply in this patch:
Definition: access.hpp:13
cl::sycl::bit_and
std::bit_and< T > bit_and
Definition: functional.hpp:20
cl::sycl::image_channel_order::r
@ r
i
int i
Definition: math_intrin.hpp:399
__spv::MemorySemanticsMask::SequentiallyConsistent
@ SequentiallyConsistent
Definition: spirv_types.hpp:89
cl::sycl::joint_none_of
detail::enable_if_t<(is_group_v< std::decay_t< Group >> &&detail::is_pointer< Ptr >::value), bool > joint_none_of(Group g, Ptr first, Ptr last, Predicate pred)
Definition: group_algorithm.hpp:375
cl::sycl::ext::oneapi::sub_group
Definition: sub_group.hpp:108
__spv::MemorySemanticsMask::WorkgroupMemory
@ WorkgroupMemory
Definition: spirv_types.hpp:92
cl::sycl::any_of_group
detail::enable_if_t< is_group_v< Group >, bool > any_of_group(Group g, T x, Predicate pred)
Definition: group_algorithm.hpp:288
group_sort.hpp
cl::sycl::group_barrier
std::enable_if< is_group_v< Group > >::type group_barrier(Group, memory_scope FenceScope=Group::fence_scope)
Definition: group_algorithm.hpp:875
cl::sycl::detail::remove_pointer_impl< remove_cv_t< T > >::type
remove_cv_t< T > type
Definition: type_traits.hpp:263
cl::sycl::ext::oneapi::sub_group::linear_id_type
uint32_t linear_id_type
Definition: sub_group.hpp:112
cl::sycl::select_from_group
detail::enable_if_t<(std::is_same< std::decay_t< Group >, sub_group >::value &&(std::is_trivially_copyable< T >::value||detail::is_vec< T >::value)), T > select_from_group(Group, T x, typename Group::id_type local_id)
Definition: group_algorithm.hpp:453
cl::sycl::ext::oneapi::sub_group::get_local_range
range_type get_local_range() const
Definition: sub_group.hpp:137
cl::sycl::joint_any_of
detail::enable_if_t<(is_group_v< std::decay_t< Group >> &&detail::is_pointer< Ptr >::value), bool > joint_any_of(Group g, Ptr first, Ptr last, Predicate pred)
Definition: group_algorithm.hpp:297
cl::sycl::ext::oneapi::sub_group::get_local_id
id_type get_local_id() const
Definition: sub_group.hpp:119
nd_item.hpp
cl::sycl::permute_group_by_xor
detail::enable_if_t<(std::is_same< std::decay_t< Group >, sub_group >::value &&(std::is_trivially_copyable< T >::value||detail::is_vec< T >::value)), T > permute_group_by_xor(Group, T x, typename Group::linear_id_type mask)
Definition: group_algorithm.hpp:434
cl::sycl::detail::is_native_op
Definition: group_algorithm.hpp:94
cl::sycl::all_of_group
detail::enable_if_t< is_group_v< std::decay_t< Group > >, bool > all_of_group(Group g, T x, Predicate pred)
Definition: group_algorithm.hpp:328
cl::sycl::detail::enable_if_t
typename std::enable_if< B, T >::type enable_if_t
Definition: stl_type_traits.hpp:24
known_identity.hpp
__spv::GroupOperation::Reduce
@ Reduce
PI_INVALID_DEVICE
@ PI_INVALID_DEVICE
Definition: pi.h:91
functional.hpp
spirv_types.hpp
__SYCL_INLINE_NAMESPACE
#define __SYCL_INLINE_NAMESPACE(X)
Definition: defines_elementary.hpp:12