DPC++ Runtime
Runtime libraries for oneAPI DPC++
group_algorithm.hpp
Go to the documentation of this file.
1 //==------------------------ group_algorithm.hpp ---------------------------==//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #pragma once
10 #include <complex>
11 
12 #include <CL/__spirv/spirv_ops.hpp>
15 #include <CL/sycl/detail/spirv.hpp>
17 #include <CL/sycl/functional.hpp>
18 #include <CL/sycl/group.hpp>
20 #include <CL/sycl/nd_item.hpp>
21 #include <CL/sycl/sub_group.hpp>
24 
26 namespace sycl {
27 namespace detail {
28 
29 // ---- linear_id_to_id
30 template <int Dimensions>
31 id<Dimensions> linear_id_to_id(range<Dimensions>, size_t linear_id);
32 template <> inline id<1> linear_id_to_id(range<1>, size_t linear_id) {
33  return id<1>(linear_id);
34 }
35 template <> inline id<2> linear_id_to_id(range<2> r, size_t linear_id) {
36  id<2> result;
37  result[0] = linear_id / r[1];
38  result[1] = linear_id % r[1];
39  return result;
40 }
41 template <> inline id<3> linear_id_to_id(range<3> r, size_t linear_id) {
42  id<3> result;
43  result[0] = linear_id / (r[1] * r[2]);
44  result[1] = (linear_id % (r[1] * r[2])) / r[2];
45  result[2] = linear_id % r[2];
46  return result;
47 }
48 
49 // ---- get_local_linear_range
50 template <typename Group> size_t get_local_linear_range(Group g);
51 template <> inline size_t get_local_linear_range<group<1>>(group<1> g) {
52  return g.get_local_range(0);
53 }
54 template <> inline size_t get_local_linear_range<group<2>>(group<2> g) {
55  return g.get_local_range(0) * g.get_local_range(1);
56 }
57 template <> inline size_t get_local_linear_range<group<3>>(group<3> g) {
58  return g.get_local_range(0) * g.get_local_range(1) * g.get_local_range(2);
59 }
60 template <>
61 inline size_t
62 get_local_linear_range<ext::oneapi::sub_group>(ext::oneapi::sub_group g) {
63  return g.get_local_range()[0];
64 }
65 
66 // ---- get_local_linear_id
// Returns the calling work-item's linear id inside the group. Only
// declarations live here; definitions are the specializations that follow.
template <typename Group>
inline typename Group::linear_id_type get_local_linear_id(Group g);

#ifdef __SYCL_DEVICE_ONLY__
// Stamps out the group<D> specialization. The group argument is ignored: the
// id is read from the nd_item that the runtime Builder reports for the
// current work-item.
#define __SYCL_GROUP_GET_LOCAL_LINEAR_ID(D)                                    \
  template <>                                                                  \
  inline group<D>::linear_id_type get_local_linear_id<group<D>>(group<D>) {    \
    nd_item<D> current_item = cl::sycl::detail::Builder::getNDItem<D>();       \
    return current_item.get_local_linear_id();                                 \
  }
__SYCL_GROUP_GET_LOCAL_LINEAR_ID(1);
__SYCL_GROUP_GET_LOCAL_LINEAR_ID(2);
__SYCL_GROUP_GET_LOCAL_LINEAR_ID(3);
#undef __SYCL_GROUP_GET_LOCAL_LINEAR_ID
#endif // __SYCL_DEVICE_ONLY__
82 
83 template <>
85 get_local_linear_id<ext::oneapi::sub_group>(ext::oneapi::sub_group g) {
86  return g.get_local_id()[0];
87 }
88 
89 // ---- is_native_op
// Alias template parameterized on the element type T; presumably names the set
// of binary operation types that is_native_op accepts — confirm upstream.
// NOTE(review): the right-hand side of this alias is absent from this copy
// (the embedded listing numbers jump from 91 to 95). Restore it from the
// upstream header; it cannot be reconstructed from what is visible here.
90 template <typename T>
91 using native_op_list =
95 
// Trait exposing a compile-time boolean `value` for a (T, BinaryOperation)
// pair. The group algorithms below use it as an overload constraint.
// NOTE(review): the initializer of `value` is absent from this copy (the
// embedded listing numbers jump from 97 to 100). Restore it from the upstream
// header before building.
96 template <typename T, typename BinaryOperation> struct is_native_op {
97  static constexpr bool value =
100 };
101 
102 // ---- is_plus
103 template <typename T, typename BinaryOperation>
104 using is_plus = std::integral_constant<
105  bool, std::is_same<BinaryOperation, sycl::plus<T>>::value ||
106  std::is_same<BinaryOperation, sycl::plus<void>>::value>;
107 
108 // ---- is_complex
109 // NOTE: std::complex<long double> not yet supported by group algorithms.
// True only for std::complex<float> and std::complex<double>; every other
// type (including std::complex<long double>) yields false.
template <typename T>
struct is_complex
    : std::integral_constant<bool,
                             std::is_same<T, std::complex<float>>::value ||
                                 std::is_same<T, std::complex<double>>::value> {
};
116 
117 // ---- is_arithmetic_or_complex
// Boolean trait that is true when T satisfies is_complex or a second
// condition that is not visible here (by its name, an arithmetic check).
// NOTE(review): two lines are absent from this copy — the alias name line
// (listing number 119) and the final operand of the `||` (listing number 121).
// Restore both from the upstream header before building.
118 template <typename T>
120  std::integral_constant<bool, sycl::detail::is_complex<T>::value ||
122 // ---- is_plus_if_complex
123 template <typename T, typename BinaryOperation>
124 using is_plus_if_complex =
125  std::integral_constant<bool, (is_complex<T>::value
127  : std::true_type::value)>;
128 
129 // ---- identity_for_ga_op
130 // The group algorithms support std::complex, limited to the sycl::plus
131 // operation. Get the correct identity for a group algorithm operation.
132 // TODO: identity_for_ga_op should be replaced with known_identity once the
133 // other callers of known_identity support complex numbers.
// Overload returning a value brace-initialized from {0, 0} (a zero real and
// zero imaginary part when T is std::complex).
// NOTE(review): the enable_if condition and the function-name line are absent
// from this copy (the embedded listing numbers jump from 135 to 138). Restore
// them from the upstream header before building.
134 template <typename T, class BinaryOperation>
135 constexpr detail::enable_if_t<
138  return {0, 0};
139 }
140 
// Overload that forwards to sycl::known_identity_v for the given operation and
// type.
// NOTE(review): the return-type/name line is absent from this copy (the
// embedded listing numbers jump from 141 to 143). Restore it from the upstream
// header before building.
141 template <typename T, class BinaryOperation>
143  return sycl::known_identity_v<BinaryOperation, T>;
144 }
145 
146 // ---- for_each
147 template <typename Group, typename Ptr, class Function>
148 Function for_each(Group g, Ptr first, Ptr last, Function f) {
149 #ifdef __SYCL_DEVICE_ONLY__
150  ptrdiff_t offset = sycl::detail::get_local_linear_id(g);
151  ptrdiff_t stride = sycl::detail::get_local_linear_range(g);
152  for (Ptr p = first + offset; p < last; p += stride) {
153  f(*p);
154  }
155  return f;
156 #else
157  (void)g;
158  (void)first;
159  (void)last;
160  (void)f;
161  throw runtime_error("Group algorithms are not supported on host device.",
162  PI_ERROR_INVALID_DEVICE);
163 #endif
164 }
165 } // namespace detail
166 
167 // ---- reduce_over_group
168 // three argument variant is specialized thrice:
169 // scalar arithmetic, complex (plus only), and vector arithmetic
170 
171 template <typename Group, typename T, class BinaryOperation>
172 detail::enable_if_t<(is_group_v<std::decay_t<Group>> &&
173  detail::is_scalar_arithmetic<T>::value &&
174  detail::is_native_op<T, BinaryOperation>::value),
175  T>
176 reduce_over_group(Group, T x, BinaryOperation binary_op) {
177  // FIXME: Do not special-case for half precision
178  static_assert(
179  std::is_same<decltype(binary_op(x, x)), T>::value ||
180  (std::is_same<T, half>::value &&
181  std::is_same<decltype(binary_op(x, x)), float>::value),
182  "Result type of binary_op must match reduction accumulation type.");
183 #ifdef __SYCL_DEVICE_ONLY__
184  return sycl::detail::calc<T, __spv::GroupOperation::Reduce,
185  sycl::detail::spirv::group_scope<Group>::value>(
186  typename sycl::detail::GroupOpTag<T>::type(), x, binary_op);
187 #else
188  throw runtime_error("Group algorithms are not supported on host device.",
189  PI_ERROR_INVALID_DEVICE);
190 #endif
191 }
192 
193 // complex specialization. T is std::complex<float> or similar.
194 // binary op is sycl::plus<std::complex<float>>
195 template <typename Group, typename T, class BinaryOperation>
196 detail::enable_if_t<(is_group_v<std::decay_t<Group>> &&
197  detail::is_complex<T>::value &&
198  detail::is_native_op<T, sycl::plus<T>>::value &&
199  detail::is_plus<T, BinaryOperation>::value),
200  T>
201 reduce_over_group(Group g, T x, BinaryOperation binary_op) {
202 #ifdef __SYCL_DEVICE_ONLY__
203  T result;
204  result.real(reduce_over_group(g, x.real(), sycl::plus<>()));
205  result.imag(reduce_over_group(g, x.imag(), sycl::plus<>()));
206  return result;
207 #else
208  (void)g;
209  (void)x;
210  (void)binary_op;
211  throw runtime_error("Group algorithms are not supported on host device.",
212  PI_ERROR_INVALID_DEVICE);
213 #endif
214 }
215 
216 template <typename Group, typename T, class BinaryOperation>
217 detail::enable_if_t<(is_group_v<std::decay_t<Group>> &&
218  detail::is_vector_arithmetic<T>::value &&
219  detail::is_native_op<T, BinaryOperation>::value),
220  T>
221 reduce_over_group(Group g, T x, BinaryOperation binary_op) {
222  // FIXME: Do not special-case for half precision
223  static_assert(
224  std::is_same<decltype(binary_op(x[0], x[0])),
225  typename T::element_type>::value ||
226  (std::is_same<T, half>::value &&
227  std::is_same<decltype(binary_op(x[0], x[0])), float>::value),
228  "Result type of binary_op must match reduction accumulation type.");
229  T result;
230  for (int s = 0; s < x.get_size(); ++s) {
231  result[s] = reduce_over_group(g, x[s], binary_op);
232  }
233  return result;
234 }
235 
236 // four argument variant of reduce_over_group specialized twice
237 // (scalar arithmetic || complex), and vector_arithmetic
238 template <typename Group, typename V, typename T, class BinaryOperation>
240  (is_group_v<std::decay_t<Group>> &&
241  (detail::is_scalar_arithmetic<V>::value || detail::is_complex<V>::value) &&
242  (detail::is_scalar_arithmetic<T>::value || detail::is_complex<T>::value) &&
243  detail::is_native_op<V, BinaryOperation>::value &&
244  detail::is_native_op<T, BinaryOperation>::value &&
245  detail::is_plus_if_complex<T, BinaryOperation>::value &&
246  detail::is_plus_if_complex<V, BinaryOperation>::value),
247  T>
248 reduce_over_group(Group g, V x, T init, BinaryOperation binary_op) {
249  // FIXME: Do not special-case for half precision
250  static_assert(
251  std::is_same<decltype(binary_op(init, x)), T>::value ||
252  (std::is_same<T, half>::value &&
253  std::is_same<decltype(binary_op(init, x)), float>::value),
254  "Result type of binary_op must match reduction accumulation type.");
255 #ifdef __SYCL_DEVICE_ONLY__
256  return binary_op(init, reduce_over_group(g, x, binary_op));
257 #else
258  (void)g;
259  throw runtime_error("Group algorithms are not supported on host device.",
260  PI_ERROR_INVALID_DEVICE);
261 #endif
262 }
263 
264 template <typename Group, typename V, typename T, class BinaryOperation>
265 detail::enable_if_t<(is_group_v<std::decay_t<Group>> &&
266  detail::is_vector_arithmetic<V>::value &&
267  detail::is_vector_arithmetic<T>::value &&
268  detail::is_native_op<V, BinaryOperation>::value &&
269  detail::is_native_op<T, BinaryOperation>::value),
270  T>
271 reduce_over_group(Group g, V x, T init, BinaryOperation binary_op) {
272  // FIXME: Do not special-case for half precision
273  static_assert(
274  std::is_same<decltype(binary_op(init[0], x[0])),
275  typename T::element_type>::value ||
276  (std::is_same<T, half>::value &&
277  std::is_same<decltype(binary_op(init[0], x[0])), float>::value),
278  "Result type of binary_op must match reduction accumulation type.");
279 #ifdef __SYCL_DEVICE_ONLY__
280  T result = init;
281  for (int s = 0; s < x.get_size(); ++s) {
282  result[s] = binary_op(init[s], reduce_over_group(g, x[s], binary_op));
283  }
284  return result;
285 #else
286  (void)g;
287  throw runtime_error("Group algorithms are not supported on host device.",
288  PI_ERROR_INVALID_DEVICE);
289 #endif
290 }
291 
292 // ---- joint_reduce
293 template <typename Group, typename Ptr, class BinaryOperation>
295  (is_group_v<std::decay_t<Group>> && detail::is_pointer<Ptr>::value &&
297  typename detail::remove_pointer<Ptr>::type>::value &&
298  detail::is_plus_if_complex<typename detail::remove_pointer<Ptr>::type,
299  BinaryOperation>::value),
300  typename detail::remove_pointer<Ptr>::type>
301 joint_reduce(Group g, Ptr first, Ptr last, BinaryOperation binary_op) {
302 #ifdef __SYCL_DEVICE_ONLY__
303  using T = typename detail::remove_pointer<Ptr>::type;
304  T init = detail::identity_for_ga_op<T, BinaryOperation>();
305  return joint_reduce(g, first, last, init, binary_op);
306 #else
307  (void)g;
308  (void)first;
309  (void)last;
310  (void)binary_op;
311  throw runtime_error("Group algorithms are not supported on host device.",
312  PI_ERROR_INVALID_DEVICE);
313 #endif
314 }
315 
316 template <typename Group, typename Ptr, typename T, class BinaryOperation>
318  (is_group_v<std::decay_t<Group>> && detail::is_pointer<Ptr>::value &&
320  typename detail::remove_pointer<Ptr>::type>::value &&
321  detail::is_arithmetic_or_complex<T>::value &&
322  detail::is_native_op<typename detail::remove_pointer<Ptr>::type,
323  BinaryOperation>::value &&
324  detail::is_plus_if_complex<typename detail::remove_pointer<Ptr>::type,
325  BinaryOperation>::value &&
326  detail::is_plus_if_complex<T, BinaryOperation>::value &&
327  detail::is_native_op<T, BinaryOperation>::value),
328  T>
329 joint_reduce(Group g, Ptr first, Ptr last, T init, BinaryOperation binary_op) {
330  // FIXME: Do not special-case for half precision
331  static_assert(
332  std::is_same<decltype(binary_op(init, *first)), T>::value ||
333  (std::is_same<T, half>::value &&
334  std::is_same<decltype(binary_op(init, *first)), float>::value),
335  "Result type of binary_op must match reduction accumulation type.");
336 #ifdef __SYCL_DEVICE_ONLY__
337  T partial = detail::identity_for_ga_op<T, BinaryOperation>();
339  g, first, last, [&](const typename detail::remove_pointer<Ptr>::type &x) {
340  partial = binary_op(partial, x);
341  });
342  return reduce_over_group(g, partial, init, binary_op);
343 #else
344  (void)g;
345  (void)last;
346  throw runtime_error("Group algorithms are not supported on host device.",
347  PI_ERROR_INVALID_DEVICE);
348 #endif
349 }
350 
351 // ---- any_of_group
352 template <typename Group>
353 detail::enable_if_t<is_group_v<std::decay_t<Group>>, bool>
354 any_of_group(Group, bool pred) {
355 #ifdef __SYCL_DEVICE_ONLY__
356  return sycl::detail::spirv::GroupAny<Group>(pred);
357 #else
358  (void)pred;
359  throw runtime_error("Group algorithms are not supported on host device.",
360  PI_ERROR_INVALID_DEVICE);
361 #endif
362 }
363 
364 template <typename Group, typename T, class Predicate>
366  Predicate pred) {
367  return any_of_group(g, pred(x));
368 }
369 
370 // ---- joint_any_of
371 template <typename Group, typename Ptr, class Predicate>
373  (is_group_v<std::decay_t<Group>> && detail::is_pointer<Ptr>::value), bool>
374 joint_any_of(Group g, Ptr first, Ptr last, Predicate pred) {
375 #ifdef __SYCL_DEVICE_ONLY__
376  using T = typename detail::remove_pointer<Ptr>::type;
377  bool partial = false;
378  sycl::detail::for_each(g, first, last, [&](T &x) { partial |= pred(x); });
379  return any_of_group(g, partial);
380 #else
381  (void)g;
382  (void)first;
383  (void)last;
384  (void)pred;
385  throw runtime_error("Group algorithms are not supported on host device.",
386  PI_ERROR_INVALID_DEVICE);
387 #endif
388 }
389 
390 // ---- all_of_group
391 template <typename Group>
392 detail::enable_if_t<is_group_v<std::decay_t<Group>>, bool>
393 all_of_group(Group, bool pred) {
394 #ifdef __SYCL_DEVICE_ONLY__
395  return sycl::detail::spirv::GroupAll<Group>(pred);
396 #else
397  (void)pred;
398  throw runtime_error("Group algorithms are not supported on host device.",
399  PI_ERROR_INVALID_DEVICE);
400 #endif
401 }
402 
403 template <typename Group, typename T, class Predicate>
404 detail::enable_if_t<is_group_v<std::decay_t<Group>>, bool>
405 all_of_group(Group g, T x, Predicate pred) {
406  return all_of_group(g, pred(x));
407 }
408 
409 // ---- joint_all_of
410 template <typename Group, typename Ptr, class Predicate>
412  (is_group_v<std::decay_t<Group>> && detail::is_pointer<Ptr>::value), bool>
413 joint_all_of(Group g, Ptr first, Ptr last, Predicate pred) {
414 #ifdef __SYCL_DEVICE_ONLY__
415  using T = typename detail::remove_pointer<Ptr>::type;
416  bool partial = true;
417  sycl::detail::for_each(g, first, last, [&](T &x) { partial &= pred(x); });
418  return all_of_group(g, partial);
419 #else
420  (void)g;
421  (void)first;
422  (void)last;
423  (void)pred;
424  throw runtime_error("Group algorithms are not supported on host device.",
425  PI_ERROR_INVALID_DEVICE);
426 #endif
427 }
428 
429 // ---- none_of_group
430 template <typename Group>
431 detail::enable_if_t<is_group_v<std::decay_t<Group>>, bool>
432 none_of_group(Group, bool pred) {
433 #ifdef __SYCL_DEVICE_ONLY__
434  return sycl::detail::spirv::GroupAll<Group>(!pred);
435 #else
436  (void)pred;
437  throw runtime_error("Group algorithms are not supported on host device.",
438  PI_ERROR_INVALID_DEVICE);
439 #endif
440 }
441 
442 template <typename Group, typename T, class Predicate>
443 detail::enable_if_t<is_group_v<std::decay_t<Group>>, bool>
444 none_of_group(Group g, T x, Predicate pred) {
445  return none_of_group(g, pred(x));
446 }
447 
448 // ---- joint_none_of
449 template <typename Group, typename Ptr, class Predicate>
451  (is_group_v<std::decay_t<Group>> && detail::is_pointer<Ptr>::value), bool>
452 joint_none_of(Group g, Ptr first, Ptr last, Predicate pred) {
453 #ifdef __SYCL_DEVICE_ONLY__
454  return !joint_any_of(g, first, last, pred);
455 #else
456  (void)g;
457  (void)first;
458  (void)last;
459  (void)pred;
460  throw runtime_error("Group algorithms are not supported on host device.",
461  PI_ERROR_INVALID_DEVICE);
462 #endif
463 }
464 
465 // ---- shift_group_left
466 // TODO: remove check for detail::is_vec<T> once sycl::vec is trivially
467 // copyable.
468 template <typename Group, typename T>
469 detail::enable_if_t<(std::is_same<std::decay_t<Group>, sub_group>::value &&
470  (std::is_trivially_copyable<T>::value ||
471  detail::is_vec<T>::value)),
472  T>
473 shift_group_left(Group, T x, typename Group::linear_id_type delta = 1) {
474 #ifdef __SYCL_DEVICE_ONLY__
475  return sycl::detail::spirv::SubgroupShuffleDown(x, delta);
476 #else
477  (void)x;
478  (void)delta;
479  throw runtime_error("Sub-groups are not supported on host device.",
480  PI_ERROR_INVALID_DEVICE);
481 #endif
482 }
483 
484 // ---- shift_group_right
485 // TODO: remove check for detail::is_vec<T> once sycl::vec is trivially
486 // copyable.
487 template <typename Group, typename T>
488 detail::enable_if_t<(std::is_same<std::decay_t<Group>, sub_group>::value &&
489  (std::is_trivially_copyable<T>::value ||
490  detail::is_vec<T>::value)),
491  T>
492 shift_group_right(Group, T x, typename Group::linear_id_type delta = 1) {
493 #ifdef __SYCL_DEVICE_ONLY__
494  return sycl::detail::spirv::SubgroupShuffleUp(x, delta);
495 #else
496  (void)x;
497  (void)delta;
498  throw runtime_error("Sub-groups are not supported on host device.",
499  PI_ERROR_INVALID_DEVICE);
500 #endif
501 }
502 
503 // ---- permute_group_by_xor
504 // TODO: remove check for detail::is_vec<T> once sycl::vec is trivially
505 // copyable.
506 template <typename Group, typename T>
507 detail::enable_if_t<(std::is_same<std::decay_t<Group>, sub_group>::value &&
508  (std::is_trivially_copyable<T>::value ||
509  detail::is_vec<T>::value)),
510  T>
511 permute_group_by_xor(Group, T x, typename Group::linear_id_type mask) {
512 #ifdef __SYCL_DEVICE_ONLY__
513  return sycl::detail::spirv::SubgroupShuffleXor(x, mask);
514 #else
515  (void)x;
516  (void)mask;
517  throw runtime_error("Sub-groups are not supported on host device.",
518  PI_ERROR_INVALID_DEVICE);
519 #endif
520 }
521 
522 // ---- select_from_group
523 // TODO: remove check for detail::is_vec<T> once sycl::vec is trivially
524 // copyable.
525 template <typename Group, typename T>
526 detail::enable_if_t<(std::is_same<std::decay_t<Group>, sub_group>::value &&
527  (std::is_trivially_copyable<T>::value ||
528  detail::is_vec<T>::value)),
529  T>
530 select_from_group(Group, T x, typename Group::id_type local_id) {
531 #ifdef __SYCL_DEVICE_ONLY__
532  return sycl::detail::spirv::SubgroupShuffle(x, local_id);
533 #else
534  (void)x;
535  (void)local_id;
536  throw runtime_error("Sub-groups are not supported on host device.",
537  PI_ERROR_INVALID_DEVICE);
538 #endif
539 }
540 
541 // ---- group_broadcast
542 // TODO: remove check for detail::is_vec<T> once sycl::vec is trivially
543 // copyable.
544 template <typename Group, typename T>
545 detail::enable_if_t<(is_group_v<std::decay_t<Group>> &&
546  (std::is_trivially_copyable<T>::value ||
547  detail::is_vec<T>::value)),
548  T>
549 group_broadcast(Group, T x, typename Group::id_type local_id) {
550 #ifdef __SYCL_DEVICE_ONLY__
551  return sycl::detail::spirv::GroupBroadcast<Group>(x, local_id);
552 #else
553  (void)x;
554  (void)local_id;
555  throw runtime_error("Group algorithms are not supported on host device.",
556  PI_ERROR_INVALID_DEVICE);
557 #endif
558 }
559 
560 template <typename Group, typename T>
561 detail::enable_if_t<(is_group_v<std::decay_t<Group>> &&
562  (std::is_trivially_copyable<T>::value ||
563  detail::is_vec<T>::value)),
564  T>
565 group_broadcast(Group g, T x, typename Group::linear_id_type linear_local_id) {
566 #ifdef __SYCL_DEVICE_ONLY__
567  return group_broadcast(
568  g, x,
569  sycl::detail::linear_id_to_id(g.get_local_range(), linear_local_id));
570 #else
571  (void)g;
572  (void)x;
573  (void)linear_local_id;
574  throw runtime_error("Group algorithms are not supported on host device.",
575  PI_ERROR_INVALID_DEVICE);
576 #endif
577 }
578 
579 template <typename Group, typename T>
580 detail::enable_if_t<(is_group_v<std::decay_t<Group>> &&
581  (std::is_trivially_copyable<T>::value ||
582  detail::is_vec<T>::value)),
583  T>
584 group_broadcast(Group g, T x) {
585 #ifdef __SYCL_DEVICE_ONLY__
586  return group_broadcast(g, x, 0);
587 #else
588  (void)g;
589  (void)x;
590  throw runtime_error("Group algorithms are not supported on host device.",
591  PI_ERROR_INVALID_DEVICE);
592 #endif
593 }
594 
595 // ---- exclusive_scan_over_group
596 // this function has two overloads, one with three arguments and one with four
597 // the three argument version is specialized thrice: scalar, complex, and
598 // vector
599 template <typename Group, typename T, class BinaryOperation>
600 detail::enable_if_t<(is_group_v<std::decay_t<Group>> &&
601  detail::is_scalar_arithmetic<T>::value &&
602  detail::is_native_op<T, BinaryOperation>::value),
603  T>
604 exclusive_scan_over_group(Group, T x, BinaryOperation binary_op) {
605  // FIXME: Do not special-case for half precision
606  static_assert(std::is_same<decltype(binary_op(x, x)), T>::value ||
607  (std::is_same<T, half>::value &&
608  std::is_same<decltype(binary_op(x, x)), float>::value),
609  "Result type of binary_op must match scan accumulation type.");
610 #ifdef __SYCL_DEVICE_ONLY__
611  return sycl::detail::calc<T, __spv::GroupOperation::ExclusiveScan,
612  sycl::detail::spirv::group_scope<Group>::value>(
613  typename sycl::detail::GroupOpTag<T>::type(), x, binary_op);
614 #else
615  throw runtime_error("Group algorithms are not supported on host device.",
616  PI_ERROR_INVALID_DEVICE);
617 #endif
618 }
619 
620 // complex specialization. T is std::complex<float> or similar.
621 // binary op is sycl::plus<std::complex<float>>
622 template <typename Group, typename T, class BinaryOperation>
623 detail::enable_if_t<(is_group_v<std::decay_t<Group>> &&
624  detail::is_complex<T>::value &&
625  detail::is_native_op<T, sycl::plus<T>>::value &&
626  detail::is_plus<T, BinaryOperation>::value),
627  T>
628 exclusive_scan_over_group(Group g, T x, BinaryOperation binary_op) {
629 #ifdef __SYCL_DEVICE_ONLY__
630  T result;
631  result.real(exclusive_scan_over_group(g, x.real(), sycl::plus<>()));
632  result.imag(exclusive_scan_over_group(g, x.imag(), sycl::plus<>()));
633  return result;
634 #else
635  (void)g;
636  (void)x;
637  (void)binary_op;
638  throw runtime_error("Group algorithms are not supported on host device.",
639  PI_ERROR_INVALID_DEVICE);
640 #endif
641 }
642 
643 template <typename Group, typename T, class BinaryOperation>
644 detail::enable_if_t<(is_group_v<std::decay_t<Group>> &&
645  detail::is_vector_arithmetic<T>::value &&
646  detail::is_native_op<T, BinaryOperation>::value),
647  T>
648 exclusive_scan_over_group(Group g, T x, BinaryOperation binary_op) {
649  // FIXME: Do not special-case for half precision
650  static_assert(
651  std::is_same<decltype(binary_op(x[0], x[0])),
652  typename T::element_type>::value ||
653  (std::is_same<T, half>::value &&
654  std::is_same<decltype(binary_op(x[0], x[0])), float>::value),
655  "Result type of binary_op must match scan accumulation type.");
656  T result;
657  for (int s = 0; s < x.get_size(); ++s) {
658  result[s] = exclusive_scan_over_group(g, x[s], binary_op);
659  }
660  return result;
661 }
662 
663 // four argument version of exclusive_scan_over_group is specialized twice
664 // once for vector_arithmetic, once for (scalar_arithmetic || complex)
665 template <typename Group, typename V, typename T, class BinaryOperation>
666 detail::enable_if_t<(is_group_v<std::decay_t<Group>> &&
667  detail::is_vector_arithmetic<V>::value &&
668  detail::is_vector_arithmetic<T>::value &&
669  detail::is_native_op<V, BinaryOperation>::value &&
670  detail::is_native_op<T, BinaryOperation>::value),
671  T>
672 exclusive_scan_over_group(Group g, V x, T init, BinaryOperation binary_op) {
673  // FIXME: Do not special-case for half precision
674  static_assert(
675  std::is_same<decltype(binary_op(init[0], x[0])),
676  typename T::element_type>::value ||
677  (std::is_same<T, half>::value &&
678  std::is_same<decltype(binary_op(init[0], x[0])), float>::value),
679  "Result type of binary_op must match scan accumulation type.");
680  T result;
681  for (int s = 0; s < x.get_size(); ++s) {
682  result[s] = exclusive_scan_over_group(g, x[s], init[s], binary_op);
683  }
684  return result;
685 }
686 
687 template <typename Group, typename V, typename T, class BinaryOperation>
689  (is_group_v<std::decay_t<Group>> &&
690  (detail::is_scalar_arithmetic<V>::value || detail::is_complex<V>::value) &&
691  (detail::is_scalar_arithmetic<T>::value || detail::is_complex<T>::value) &&
692  detail::is_native_op<V, BinaryOperation>::value &&
693  detail::is_native_op<T, BinaryOperation>::value &&
694  detail::is_plus_if_complex<V, BinaryOperation>::value &&
695  detail::is_plus_if_complex<T, BinaryOperation>::value),
696  T>
697 exclusive_scan_over_group(Group g, V x, T init, BinaryOperation binary_op) {
698  // FIXME: Do not special-case for half precision
699  static_assert(std::is_same<decltype(binary_op(init, x)), T>::value ||
700  (std::is_same<T, half>::value &&
701  std::is_same<decltype(binary_op(init, x)), float>::value),
702  "Result type of binary_op must match scan accumulation type.");
703 #ifdef __SYCL_DEVICE_ONLY__
704  typename Group::linear_id_type local_linear_id =
706  if (local_linear_id == 0) {
707  x = binary_op(init, x);
708  }
709  T scan = exclusive_scan_over_group(g, x, binary_op);
710  if (local_linear_id == 0) {
711  scan = init;
712  }
713  return scan;
714 #else
715  (void)g;
716  throw runtime_error("Group algorithms are not supported on host device.",
717  PI_ERROR_INVALID_DEVICE);
718 #endif
719 }
720 
721 // ---- joint_exclusive_scan
722 template <typename Group, typename InPtr, typename OutPtr, typename T,
723  class BinaryOperation>
725  (is_group_v<std::decay_t<Group>> && detail::is_pointer<InPtr>::value &&
726  detail::is_pointer<OutPtr>::value &&
728  typename detail::remove_pointer<InPtr>::type>::value &&
729  detail::is_arithmetic_or_complex<T>::value &&
730  detail::is_native_op<typename detail::remove_pointer<InPtr>::type,
731  BinaryOperation>::value &&
732  detail::is_native_op<T, BinaryOperation>::value &&
733  detail::is_plus_if_complex<typename detail::remove_pointer<InPtr>::type,
734  BinaryOperation>::value &&
735  detail::is_plus_if_complex<T, BinaryOperation>::value),
736  OutPtr>
737 joint_exclusive_scan(Group g, InPtr first, InPtr last, OutPtr result, T init,
738  BinaryOperation binary_op) {
739  // FIXME: Do not special-case for half precision
740  static_assert(
741  std::is_same<decltype(binary_op(*first, *first)), T>::value ||
742  (std::is_same<T, half>::value &&
743  std::is_same<decltype(binary_op(*first, *first)), float>::value),
744  "Result type of binary_op must match scan accumulation type.");
745 #ifdef __SYCL_DEVICE_ONLY__
746  ptrdiff_t offset = sycl::detail::get_local_linear_id(g);
747  ptrdiff_t stride = sycl::detail::get_local_linear_range(g);
748  ptrdiff_t N = last - first;
749  auto roundup = [=](const ptrdiff_t &v,
750  const ptrdiff_t &divisor) -> ptrdiff_t {
751  return ((v + divisor - 1) / divisor) * divisor;
752  };
753  typename std::remove_const<typename detail::remove_pointer<InPtr>::type>::type
754  x;
755  typename detail::remove_pointer<OutPtr>::type carry = init;
756  for (ptrdiff_t chunk = 0; chunk < roundup(N, stride); chunk += stride) {
757  ptrdiff_t i = chunk + offset;
758  if (i < N) {
759  x = first[i];
760  }
762  exclusive_scan_over_group(g, x, carry, binary_op);
763  if (i < N) {
764  result[i] = out;
765  }
766  carry = group_broadcast(g, binary_op(out, x), stride - 1);
767  }
768  return result + N;
769 #else
770  (void)g;
771  (void)last;
772  (void)result;
773  (void)init;
774  throw runtime_error("Group algorithms are not supported on host device.",
775  PI_ERROR_INVALID_DEVICE);
776 #endif
777 }
778 
779 template <typename Group, typename InPtr, typename OutPtr,
780  class BinaryOperation>
782  (is_group_v<std::decay_t<Group>> && detail::is_pointer<InPtr>::value &&
783  detail::is_pointer<OutPtr>::value &&
785  typename detail::remove_pointer<InPtr>::type>::value &&
786  detail::is_native_op<typename detail::remove_pointer<InPtr>::type,
787  BinaryOperation>::value &&
788  detail::is_plus_if_complex<typename detail::remove_pointer<InPtr>::type,
789  BinaryOperation>::value),
790  OutPtr>
791 joint_exclusive_scan(Group g, InPtr first, InPtr last, OutPtr result,
792  BinaryOperation binary_op) {
793  // FIXME: Do not special-case for half precision
794  static_assert(
795  std::is_same<decltype(binary_op(*first, *first)),
796  typename detail::remove_pointer<OutPtr>::type>::value ||
797  (std::is_same<typename detail::remove_pointer<OutPtr>::type,
798  half>::value &&
799  std::is_same<decltype(binary_op(*first, *first)), float>::value),
800  "Result type of binary_op must match scan accumulation type.");
801  using T = typename detail::remove_pointer<InPtr>::type;
802  T init = detail::identity_for_ga_op<T, BinaryOperation>();
803  return joint_exclusive_scan(g, first, last, result, init, binary_op);
804 }
805 
806 // ---- inclusive_scan_over_group
807 // this function has two overloads, one with three arguments and one with four
808 // the three argument version is specialized thrice: vector, scalar, and
809 // complex
810 template <typename Group, typename T, class BinaryOperation>
811 detail::enable_if_t<(is_group_v<std::decay_t<Group>> &&
812  detail::is_vector_arithmetic<T>::value &&
813  detail::is_native_op<T, BinaryOperation>::value),
814  T>
815 inclusive_scan_over_group(Group g, T x, BinaryOperation binary_op) {
816  // FIXME: Do not special-case for half precision
817  static_assert(
818  std::is_same<decltype(binary_op(x[0], x[0])),
819  typename T::element_type>::value ||
820  (std::is_same<T, half>::value &&
821  std::is_same<decltype(binary_op(x[0], x[0])), float>::value),
822  "Result type of binary_op must match scan accumulation type.");
823  T result;
824  for (int s = 0; s < x.get_size(); ++s) {
825  result[s] = inclusive_scan_over_group(g, x[s], binary_op);
826  }
827  return result;
828 }
829 
830 template <typename Group, typename T, class BinaryOperation>
831 detail::enable_if_t<(is_group_v<std::decay_t<Group>> &&
832  detail::is_scalar_arithmetic<T>::value &&
833  detail::is_native_op<T, BinaryOperation>::value),
834  T>
835 inclusive_scan_over_group(Group, T x, BinaryOperation binary_op) {
836  // FIXME: Do not special-case for half precision
837  static_assert(std::is_same<decltype(binary_op(x, x)), T>::value ||
838  (std::is_same<T, half>::value &&
839  std::is_same<decltype(binary_op(x, x)), float>::value),
840  "Result type of binary_op must match scan accumulation type.");
841 #ifdef __SYCL_DEVICE_ONLY__
842  return sycl::detail::calc<T, __spv::GroupOperation::InclusiveScan,
843  sycl::detail::spirv::group_scope<Group>::value>(
844  typename sycl::detail::GroupOpTag<T>::type(), x, binary_op);
845 #else
846  throw runtime_error("Group algorithms are not supported on host device.",
847  PI_ERROR_INVALID_DEVICE);
848 #endif
849 }
850 
851 // complex specialization
852 template <typename Group, typename T, class BinaryOperation>
853 detail::enable_if_t<(is_group_v<std::decay_t<Group>> &&
854  detail::is_complex<T>::value &&
855  detail::is_native_op<T, sycl::plus<T>>::value &&
856  detail::is_plus<T, BinaryOperation>::value),
857  T>
858 inclusive_scan_over_group(Group g, T x, BinaryOperation binary_op) {
859 #ifdef __SYCL_DEVICE_ONLY__
860  T result;
861  result.real(inclusive_scan_over_group(g, x.real(), sycl::plus<>()));
862  result.imag(inclusive_scan_over_group(g, x.imag(), sycl::plus<>()));
863  return result;
864 #else
865  (void)g;
866  (void)x;
867  (void)binary_op;
868  throw runtime_error("Group algorithms are not supported on host device.",
869  PI_ERROR_INVALID_DEVICE);
870 #endif
871 }
872 
873 // four argument version of inclusive_scan_over_group is specialized twice
874 // once for (scalar_arithmetic || complex) and once for vector_arithmetic
875 template <typename Group, typename V, class BinaryOperation, typename T>
877  (is_group_v<std::decay_t<Group>> &&
878  (detail::is_scalar_arithmetic<V>::value || detail::is_complex<V>::value) &&
879  (detail::is_scalar_arithmetic<T>::value || detail::is_complex<T>::value) &&
880  detail::is_native_op<V, BinaryOperation>::value &&
881  detail::is_native_op<T, BinaryOperation>::value &&
882  detail::is_plus_if_complex<T, BinaryOperation>::value &&
883  detail::is_plus_if_complex<V, BinaryOperation>::value),
884  T>
885 inclusive_scan_over_group(Group g, V x, BinaryOperation binary_op, T init) {
886  // FIXME: Do not special-case for half precision
887  static_assert(std::is_same<decltype(binary_op(init, x)), T>::value ||
888  (std::is_same<T, half>::value &&
889  std::is_same<decltype(binary_op(init, x)), float>::value),
890  "Result type of binary_op must match scan accumulation type.");
891 #ifdef __SYCL_DEVICE_ONLY__
892  if (sycl::detail::get_local_linear_id(g) == 0) {
893  x = binary_op(init, x);
894  }
895  return inclusive_scan_over_group(g, x, binary_op);
896 #else
897  (void)g;
898  throw runtime_error("Group algorithms are not supported on host device.",
899  PI_ERROR_INVALID_DEVICE);
900 #endif
901 }
902 
903 template <typename Group, typename V, class BinaryOperation, typename T>
904 detail::enable_if_t<(is_group_v<std::decay_t<Group>> &&
905  detail::is_vector_arithmetic<V>::value &&
906  detail::is_vector_arithmetic<T>::value &&
907  detail::is_native_op<V, BinaryOperation>::value &&
908  detail::is_native_op<T, BinaryOperation>::value),
909  T>
910 inclusive_scan_over_group(Group g, V x, BinaryOperation binary_op, T init) {
911  // FIXME: Do not special-case for half precision
912  static_assert(
913  std::is_same<decltype(binary_op(init[0], x[0])), T>::value ||
914  (std::is_same<T, half>::value &&
915  std::is_same<decltype(binary_op(init[0], x[0])), float>::value),
916  "Result type of binary_op must match scan accumulation type.");
917  T result;
918  for (int s = 0; s < x.get_size(); ++s) {
919  result[s] = inclusive_scan_over_group(g, x[s], binary_op, init[s]);
920  }
921  return result;
922 }
923 
924 // ---- joint_inclusive_scan
925 template <typename Group, typename InPtr, typename OutPtr,
926  class BinaryOperation, typename T>
928  (is_group_v<std::decay_t<Group>> && detail::is_pointer<InPtr>::value &&
929  detail::is_pointer<OutPtr>::value &&
931  typename detail::remove_pointer<InPtr>::type>::value &&
932  detail::is_arithmetic_or_complex<T>::value &&
933  detail::is_native_op<typename detail::remove_pointer<InPtr>::type,
934  BinaryOperation>::value &&
935  detail::is_native_op<T, BinaryOperation>::value &&
936  detail::is_plus_if_complex<typename detail::remove_pointer<InPtr>::type,
937  BinaryOperation>::value &&
938  detail::is_plus_if_complex<T, BinaryOperation>::value),
939  OutPtr>
940 joint_inclusive_scan(Group g, InPtr first, InPtr last, OutPtr result,
941  BinaryOperation binary_op, T init) {
942  // FIXME: Do not special-case for half precision
943  static_assert(
944  std::is_same<decltype(binary_op(init, *first)), T>::value ||
945  (std::is_same<T, half>::value &&
946  std::is_same<decltype(binary_op(init, *first)), float>::value),
947  "Result type of binary_op must match scan accumulation type.");
948 #ifdef __SYCL_DEVICE_ONLY__
949  ptrdiff_t offset = sycl::detail::get_local_linear_id(g);
950  ptrdiff_t stride = sycl::detail::get_local_linear_range(g);
951  ptrdiff_t N = last - first;
952  auto roundup = [=](const ptrdiff_t &v,
953  const ptrdiff_t &divisor) -> ptrdiff_t {
954  return ((v + divisor - 1) / divisor) * divisor;
955  };
956  typename std::remove_const<typename detail::remove_pointer<InPtr>::type>::type
957  x;
958  typename detail::remove_pointer<OutPtr>::type carry = init;
959  for (ptrdiff_t chunk = 0; chunk < roundup(N, stride); chunk += stride) {
960  ptrdiff_t i = chunk + offset;
961  if (i < N) {
962  x = first[i];
963  }
965  inclusive_scan_over_group(g, x, binary_op, carry);
966  if (i < N) {
967  result[i] = out;
968  }
969  carry = group_broadcast(g, out, stride - 1);
970  }
971  return result + N;
972 #else
973  (void)g;
974  (void)last;
975  (void)result;
976  throw runtime_error("Group algorithms are not supported on host device.",
977  PI_ERROR_INVALID_DEVICE);
978 #endif
979 }
980 
981 template <typename Group, typename InPtr, typename OutPtr,
982  class BinaryOperation>
984  (is_group_v<std::decay_t<Group>> && detail::is_pointer<InPtr>::value &&
985  detail::is_pointer<OutPtr>::value &&
987  typename detail::remove_pointer<InPtr>::type>::value &&
988  detail::is_native_op<typename detail::remove_pointer<InPtr>::type,
989  BinaryOperation>::value &&
990  detail::is_plus_if_complex<typename detail::remove_pointer<InPtr>::type,
991  BinaryOperation>::value),
992  OutPtr>
993 joint_inclusive_scan(Group g, InPtr first, InPtr last, OutPtr result,
994  BinaryOperation binary_op) {
995  // FIXME: Do not special-case for half precision
996  static_assert(
997  std::is_same<decltype(binary_op(*first, *first)),
998  typename detail::remove_pointer<OutPtr>::type>::value ||
999  (std::is_same<typename detail::remove_pointer<OutPtr>::type,
1000  half>::value &&
1001  std::is_same<decltype(binary_op(*first, *first)), float>::value),
1002  "Result type of binary_op must match scan accumulation type.");
1003 
1004  using T = typename detail::remove_pointer<InPtr>::type;
1005  T init = detail::identity_for_ga_op<T, BinaryOperation>();
1006  return joint_inclusive_scan(g, first, last, result, binary_op, init);
1007 }
1008 
1009 namespace detail {
1010 template <typename G> struct group_barrier_scope {};
1011 template <> struct group_barrier_scope<sycl::sub_group> {
1012  constexpr static auto Scope = __spv::Scope::Subgroup;
1013 };
1014 template <int D> struct group_barrier_scope<sycl::group<D>> {
1015  constexpr static auto Scope = __spv::Scope::Workgroup;
1016 };
1017 } // namespace detail
1018 
1019 template <typename Group>
1020 typename std::enable_if<is_group_v<Group>>::type
1021 group_barrier(Group, memory_scope FenceScope = Group::fence_scope) {
1022  (void)FenceScope;
1023 #ifdef __SYCL_DEVICE_ONLY__
1024  // Per SYCL spec, group_barrier must perform both control barrier and memory
1025  // fence operations. All work-items execute a release fence prior to
1026  // barrier and acquire fence afterwards. The rest of semantics flags specify
1027  // which type of memory this behavior is applied to.
1029  sycl::detail::spirv::getScope(FenceScope),
1034 #else
1035  throw sycl::runtime_error("Barriers are not supported on host device",
1036  PI_ERROR_INVALID_DEVICE);
1037 #endif
1038 }
1039 
1040 } // namespace sycl
1041 } // __SYCL_INLINE_NAMESPACE(cl)
spirv_ops.hpp
__spirv_ControlBarrier
__SYCL_CONVERGENT__ SYCL_EXTERNAL void __spirv_ControlBarrier(__spv::Scope Execution, __spv::Scope Memory, uint32_t Semantics) noexcept
Definition: spirv_ops.cpp:26
__spv::MemorySemanticsMask::SubgroupMemory
@ SubgroupMemory
Definition: spirv_types.hpp:91
cl::sycl::detail::is_arithmetic
Definition: type_traits.hpp:228
__spv::Scope::Workgroup
@ Workgroup
Definition: spirv_types.hpp:30
cl::sycl::bit_xor
std::bit_xor< T > bit_xor
Definition: functional.hpp:22
type_traits.hpp
sub_group.hpp
cl::sycl::multiplies
std::multiplies< T > multiplies
Definition: functional.hpp:19
cl::sycl::detail::type_list
Definition: type_list.hpp:23
cl::sycl::detail::for_each
Function for_each(Group g, Ptr first, Ptr last, Function f)
Definition: group_algorithm.hpp:148
cl::sycl::joint_all_of
detail::enable_if_t<(is_group_v< std::decay_t< Group >> &&detail::is_pointer< Ptr >::value), bool > joint_all_of(Group g, Ptr first, Ptr last, Predicate pred)
Definition: group_algorithm.hpp:413
cl::sycl::id< 1 >
cl::sycl::group< 1 >
cl::sycl
Definition: access.hpp:14
cl::sycl::detail::is_contained
Definition: type_list.hpp:54
cl::sycl::detail::linear_id_to_id
id< 3 > linear_id_to_id(range< 3 > r, size_t linear_id)
Definition: group_algorithm.hpp:41
cl::sycl::detail::get_local_linear_range
size_t get_local_linear_range(Group g)
cl::sycl::exclusive_scan_over_group
detail::enable_if_t<(is_group_v< std::decay_t< Group >> &&(detail::is_scalar_arithmetic< V >::value||detail::is_complex< V >::value) &&(detail::is_scalar_arithmetic< T >::value||detail::is_complex< T >::value) &&detail::is_native_op< V, BinaryOperation >::value &&detail::is_native_op< T, BinaryOperation >::value &&detail::is_plus_if_complex< V, BinaryOperation >::value &&detail::is_plus_if_complex< T, BinaryOperation >::value), T > exclusive_scan_over_group(Group g, V x, T init, BinaryOperation binary_op)
Definition: group_algorithm.hpp:697
cl::sycl::shift_group_left
detail::enable_if_t<(std::is_same< std::decay_t< Group >, sub_group >::value &&(std::is_trivially_copyable< T >::value||detail::is_vec< T >::value)), T > shift_group_left(Group, T x, typename Group::linear_id_type delta=1)
Definition: group_algorithm.hpp:473
cl::sycl::detail::get_local_linear_id
Group::linear_id_type get_local_linear_id(Group g)
spirv_vars.hpp
sycl
Definition: invoke_simd.hpp:68
cl::sycl::shift_group_right
detail::enable_if_t<(std::is_same< std::decay_t< Group >, sub_group >::value &&(std::is_trivially_copyable< T >::value||detail::is_vec< T >::value)), T > shift_group_right(Group, T x, typename Group::linear_id_type delta=1)
Definition: group_algorithm.hpp:492
__spv::MemorySemanticsMask::CrossWorkgroupMemory
@ CrossWorkgroupMemory
Definition: spirv_types.hpp:93
cl::sycl::reduce_over_group
detail::enable_if_t<(is_group_v< std::decay_t< Group >> &&detail::is_vector_arithmetic< V >::value &&detail::is_vector_arithmetic< T >::value &&detail::is_native_op< V, BinaryOperation >::value &&detail::is_native_op< T, BinaryOperation >::value), T > reduce_over_group(Group g, V x, T init, BinaryOperation binary_op)
Definition: group_algorithm.hpp:271
group.hpp
cl::sycl::detail::is_arithmetic_or_complex
std::integral_constant< bool, sycl::detail::is_complex< T >::value||sycl::detail::is_arithmetic< T >::value > is_arithmetic_or_complex
Definition: group_algorithm.hpp:121
cl::sycl::range< 1 >
cl::sycl::detail::identity_for_ga_op
constexpr detail::enable_if_t<!is_complex< T >::value, T > identity_for_ga_op()
Definition: group_algorithm.hpp:142
cl::sycl::detail::group_barrier_scope
Definition: group_algorithm.hpp:1010
cl::sycl::bit_or
std::bit_or< T > bit_or
Definition: functional.hpp:21
cl::sycl::detail::is_complex
Definition: group_algorithm.hpp:111
cl::sycl::memory_scope
memory_scope
Definition: memory_enums.hpp:26
__spv::GroupOperation::ExclusiveScan
@ ExclusiveScan
cl::sycl::none_of_group
detail::enable_if_t< is_group_v< std::decay_t< Group > >, bool > none_of_group(Group g, T x, Predicate pred)
Definition: group_algorithm.hpp:444
cl::sycl::maximum
Definition: functional.hpp:43
functional.hpp
cl::sycl::detail::half_impl::half
Definition: half_type.hpp:329
__spv::Scope::Subgroup
@ Subgroup
Definition: spirv_types.hpp:31
__spv::GroupOperation::InclusiveScan
@ InclusiveScan
cl::sycl::minimum
Definition: functional.hpp:26
spirv.hpp
cl::sycl::inclusive_scan_over_group
detail::enable_if_t<(is_group_v< std::decay_t< Group >> &&detail::is_vector_arithmetic< V >::value &&detail::is_vector_arithmetic< T >::value &&detail::is_native_op< V, BinaryOperation >::value &&detail::is_native_op< T, BinaryOperation >::value), T > inclusive_scan_over_group(Group g, V x, BinaryOperation binary_op, T init)
Definition: group_algorithm.hpp:910
cl::sycl::group_broadcast
detail::enable_if_t<(is_group_v< std::decay_t< Group >> &&(std::is_trivially_copyable< T >::value||detail::is_vec< T >::value)), T > group_broadcast(Group g, T x)
Definition: group_algorithm.hpp:584
cl
We provide new interfaces for matrix multiply in this patch:
Definition: access.hpp:13
cl::sycl::detail::is_plus
std::integral_constant< bool, std::is_same< BinaryOperation, sycl::plus< T > >::value||std::is_same< BinaryOperation, sycl::plus< void > >::value > is_plus
Definition: group_algorithm.hpp:106
cl::sycl::bit_and
std::bit_and< T > bit_and
Definition: functional.hpp:20
cl::sycl::image_channel_order::r
@ r
__spv::MemorySemanticsMask::SequentiallyConsistent
@ SequentiallyConsistent
Definition: spirv_types.hpp:89
cl::sycl::joint_none_of
detail::enable_if_t<(is_group_v< std::decay_t< Group >> &&detail::is_pointer< Ptr >::value), bool > joint_none_of(Group g, Ptr first, Ptr last, Predicate pred)
Definition: group_algorithm.hpp:452
cl::sycl::ext::oneapi::sub_group
Definition: sub_group.hpp:108
__spv::MemorySemanticsMask::WorkgroupMemory
@ WorkgroupMemory
Definition: spirv_types.hpp:92
cl::sycl::any_of_group
detail::enable_if_t< is_group_v< Group >, bool > any_of_group(Group g, T x, Predicate pred)
Definition: group_algorithm.hpp:365
cl::sycl::joint_reduce
detail::enable_if_t<(is_group_v< std::decay_t< Group >> &&detail::is_pointer< Ptr >::value &&detail::is_arithmetic_or_complex< typename detail::remove_pointer< Ptr >::type >::value &&detail::is_arithmetic_or_complex< T >::value &&detail::is_native_op< typename detail::remove_pointer< Ptr >::type, BinaryOperation >::value &&detail::is_plus_if_complex< typename detail::remove_pointer< Ptr >::type, BinaryOperation >::value &&detail::is_plus_if_complex< T, BinaryOperation >::value &&detail::is_native_op< T, BinaryOperation >::value), T > joint_reduce(Group g, Ptr first, Ptr last, T init, BinaryOperation binary_op)
Definition: group_algorithm.hpp:329
group_sort.hpp
cl::sycl::group_barrier
std::enable_if< is_group_v< Group > >::type group_barrier(Group, memory_scope FenceScope=Group::fence_scope)
Definition: group_algorithm.hpp:1021
cl::sycl::detail::remove_pointer_impl< remove_cv_t< T > >::type
remove_cv_t< T > type
Definition: type_traits.hpp:263
cl::sycl::ext::oneapi::sub_group::linear_id_type
uint32_t linear_id_type
Definition: sub_group.hpp:112
cl::sycl::select_from_group
detail::enable_if_t<(std::is_same< std::decay_t< Group >, sub_group >::value &&(std::is_trivially_copyable< T >::value||detail::is_vec< T >::value)), T > select_from_group(Group, T x, typename Group::id_type local_id)
Definition: group_algorithm.hpp:530
cl::sycl::joint_exclusive_scan
detail::enable_if_t<(is_group_v< std::decay_t< Group >> &&detail::is_pointer< InPtr >::value &&detail::is_pointer< OutPtr >::value &&detail::is_arithmetic_or_complex< typename detail::remove_pointer< InPtr >::type >::value &&detail::is_native_op< typename detail::remove_pointer< InPtr >::type, BinaryOperation >::value &&detail::is_plus_if_complex< typename detail::remove_pointer< InPtr >::type, BinaryOperation >::value), OutPtr > joint_exclusive_scan(Group g, InPtr first, InPtr last, OutPtr result, BinaryOperation binary_op)
Definition: group_algorithm.hpp:791
cl::sycl::ext::oneapi::sub_group::get_local_range
range_type get_local_range() const
Definition: sub_group.hpp:137
cl::sycl::detail::is_plus_if_complex
std::integral_constant< bool,(is_complex< T >::value ? is_plus< T, BinaryOperation >::value :std::true_type::value)> is_plus_if_complex
Definition: group_algorithm.hpp:127
cl::sycl::joint_any_of
detail::enable_if_t<(is_group_v< std::decay_t< Group >> &&detail::is_pointer< Ptr >::value), bool > joint_any_of(Group g, Ptr first, Ptr last, Predicate pred)
Definition: group_algorithm.hpp:374
cl::sycl::ext::oneapi::sub_group::get_local_id
id_type get_local_id() const
Definition: sub_group.hpp:119
nd_item.hpp
cl::sycl::joint_inclusive_scan
detail::enable_if_t<(is_group_v< std::decay_t< Group >> &&detail::is_pointer< InPtr >::value &&detail::is_pointer< OutPtr >::value &&detail::is_arithmetic_or_complex< typename detail::remove_pointer< InPtr >::type >::value &&detail::is_native_op< typename detail::remove_pointer< InPtr >::type, BinaryOperation >::value &&detail::is_plus_if_complex< typename detail::remove_pointer< InPtr >::type, BinaryOperation >::value), OutPtr > joint_inclusive_scan(Group g, InPtr first, InPtr last, OutPtr result, BinaryOperation binary_op)
Definition: group_algorithm.hpp:993
cl::sycl::plus
std::plus< T > plus
Definition: functional.hpp:18
cl::sycl::permute_group_by_xor
detail::enable_if_t<(std::is_same< std::decay_t< Group >, sub_group >::value &&(std::is_trivially_copyable< T >::value||detail::is_vec< T >::value)), T > permute_group_by_xor(Group, T x, typename Group::linear_id_type mask)
Definition: group_algorithm.hpp:511
cl::sycl::detail::is_native_op
Definition: group_algorithm.hpp:96
cl::sycl::all_of_group
detail::enable_if_t< is_group_v< std::decay_t< Group > >, bool > all_of_group(Group g, T x, Predicate pred)
Definition: group_algorithm.hpp:405
cl::sycl::detail::enable_if_t
typename std::enable_if< B, T >::type enable_if_t
Definition: stl_type_traits.hpp:24
known_identity.hpp
__spv::GroupOperation::Reduce
@ Reduce
functional.hpp
spirv_types.hpp
__SYCL_INLINE_NAMESPACE
#define __SYCL_INLINE_NAMESPACE(X)
Definition: defines_elementary.hpp:12