DPC++ Runtime
Runtime libraries for oneAPI DPC++
user_defined_reductions.hpp
Go to the documentation of this file.
1 //==--- user_defined_reductions.hpp -- SYCL ext header file -=--*- C++ -*---==//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #pragma once
10 
11 #include <sycl/detail/defines.hpp>
12 #include <sycl/group_algorithm.hpp>
13 
14 namespace sycl {
15 inline namespace _V1 {
16 namespace ext::oneapi::experimental {
17 namespace detail {
18 template <typename GroupHelper, typename T, typename BinaryOperation>
19 T reduce_over_group_impl(GroupHelper group_helper, T x, size_t num_elements,
20  BinaryOperation binary_op) {
21 #ifdef __SYCL_DEVICE_ONLY__
22  T *Memory = reinterpret_cast<T *>(group_helper.get_memory().data());
23  auto g = group_helper.get_group();
24  Memory[g.get_local_linear_id()] = x;
25  group_barrier(g);
26  T result = Memory[0];
27  if (g.leader()) {
28  for (int i = 1; i < num_elements; i++) {
29  result = binary_op(result, Memory[i]);
30  }
31  }
32  group_barrier(g);
33  return group_broadcast(g, result);
34 #else
35  std::ignore = group_helper;
36  std::ignore = x;
37  std::ignore = num_elements;
38  std::ignore = binary_op;
39  throw runtime_error("Group algorithms are not supported on host.",
40  PI_ERROR_INVALID_DEVICE);
41 #endif
42 }
43 } // namespace detail
44 
45 // ---- reduce_over_group
46 template <typename GroupHelper, typename T, typename BinaryOperation>
47 std::enable_if_t<(is_group_helper_v<GroupHelper>), T>
48 reduce_over_group(GroupHelper group_helper, T x, BinaryOperation binary_op) {
49  if constexpr (sycl::detail::is_native_op<T, BinaryOperation>::value) {
50  return sycl::reduce_over_group(group_helper.get_group(), x, binary_op);
51  }
52 #ifdef __SYCL_DEVICE_ONLY__
54  group_helper, x, group_helper.get_group().get_local_linear_range(),
55  binary_op);
56 #else
57  throw runtime_error("Group algorithms are not supported on host.",
58  PI_ERROR_INVALID_DEVICE);
59 #endif
60 }
61 
62 template <typename GroupHelper, typename V, typename T,
63  typename BinaryOperation>
64 std::enable_if_t<(is_group_helper_v<GroupHelper>), T>
65 reduce_over_group(GroupHelper group_helper, V x, T init,
66  BinaryOperation binary_op) {
67  if constexpr (sycl::detail::is_native_op<V, BinaryOperation>::value &&
68  sycl::detail::is_native_op<T, BinaryOperation>::value) {
69  return sycl::reduce_over_group(group_helper.get_group(), x, init,
70  binary_op);
71  }
72 #ifdef __SYCL_DEVICE_ONLY__
73  return binary_op(init, reduce_over_group(group_helper, x, binary_op));
74 #else
75  std::ignore = group_helper;
76  throw runtime_error("Group algorithms are not supported on host.",
77  PI_ERROR_INVALID_DEVICE);
78 #endif
79 }
80 
81 // ---- joint_reduce
82 template <typename GroupHelper, typename Ptr, typename BinaryOperation>
83 std::enable_if_t<(is_group_helper_v<GroupHelper> &&
84  sycl::detail::is_pointer_v<Ptr>),
86 joint_reduce(GroupHelper group_helper, Ptr first, Ptr last,
87  BinaryOperation binary_op) {
88  if constexpr (sycl::detail::is_native_op<
90  BinaryOperation>::value) {
91  return sycl::joint_reduce(group_helper.get_group(), first, last, binary_op);
92  }
93 #ifdef __SYCL_DEVICE_ONLY__
94  // TODO: the complexity is linear and not logarithmic. Something like
95  // https://github.com/intel/llvm/blob/8ebd912679f27943d8ef6c33a9775347dce6b80d/sycl/include/sycl/reduction.hpp#L1810-L1818
96  // might be applicable here.
97  using T = typename std::iterator_traits<Ptr>::value_type;
98  auto g = group_helper.get_group();
99  T partial = *(first + g.get_local_linear_id());
100  Ptr second = first + g.get_local_linear_range();
101  sycl::detail::for_each(g, second, last,
102  [&](const T &x) { partial = binary_op(partial, x); });
103  group_barrier(g);
104  size_t num_elements = last - first;
105  num_elements = std::min(num_elements, g.get_local_linear_range());
106  return detail::reduce_over_group_impl(group_helper, partial, num_elements,
107  binary_op);
108 #else
109  std::ignore = group_helper;
110  std::ignore = first;
111  std::ignore = last;
112  std::ignore = binary_op;
113  throw runtime_error("Group algorithms are not supported on host.",
114  PI_ERROR_INVALID_DEVICE);
115 #endif
116 }
117 
118 template <typename GroupHelper, typename Ptr, typename T,
119  typename BinaryOperation>
120 std::enable_if_t<
121  (is_group_helper_v<GroupHelper> && sycl::detail::is_pointer_v<Ptr>), T>
122 joint_reduce(GroupHelper group_helper, Ptr first, Ptr last, T init,
123  BinaryOperation binary_op) {
124  if constexpr (sycl::detail::is_native_op<T, BinaryOperation>::value) {
125  return sycl::joint_reduce(group_helper.get_group(), first, last, init,
126  binary_op);
127  }
128 #ifdef __SYCL_DEVICE_ONLY__
129  return binary_op(init, joint_reduce(group_helper, first, last, binary_op));
130 #else
131  std::ignore = group_helper;
132  std::ignore = last;
133  throw runtime_error("Group algorithms are not supported on host.",
134  PI_ERROR_INVALID_DEVICE);
135 #endif
136 }
137 } // namespace ext::oneapi::experimental
138 } // namespace _V1
139 } // namespace sycl
Function for_each(Group g, Ptr first, Ptr last, Function f)
T reduce_over_group_impl(GroupHelper group_helper, T x, size_t num_elements, BinaryOperation binary_op)
std::enable_if_t<(is_group_helper_v< GroupHelper >), T > reduce_over_group(GroupHelper group_helper, T x, BinaryOperation binary_op)
std::enable_if_t<(is_group_helper_v< GroupHelper > &&sycl::detail::is_pointer_v< Ptr >), typename std::iterator_traits< Ptr >::value_type > joint_reduce(GroupHelper group_helper, Ptr first, Ptr last, BinaryOperation binary_op)
std::enable_if_t<(is_group_v< std::decay_t< Group >> &&(std::is_trivially_copyable_v< T >||detail::is_vec< T >::value)), T > group_broadcast(Group g, T x, typename Group::id_type local_id)
std::enable_if_t<(is_group_v< std::decay_t< Group >> &&detail::is_pointer_v< Ptr > &&detail::is_arithmetic_or_complex< typename detail::remove_pointer< Ptr >::type >::value &&detail::is_arithmetic_or_complex< T >::value &&detail::is_plus_or_multiplies_if_complex< T, BinaryOperation >::value &&detail::is_native_op< T, BinaryOperation >::value), T > joint_reduce(Group g, Ptr first, Ptr last, T init, BinaryOperation binary_op)
void group_barrier(ext::oneapi::experimental::root_group< dimensions > G, memory_scope FenceScope=decltype(G)::fence_scope)
Definition: root_group.hpp:112
std::enable_if_t<(is_group_v< std::decay_t< Group >> &&(detail::is_scalar_arithmetic< T >::value||(detail::is_complex< T >::value &&detail::is_multiplies< T, BinaryOperation >::value)) &&detail::is_native_op< T, BinaryOperation >::value), T > reduce_over_group(Group g, T x, BinaryOperation binary_op)
autodecltype(x) x
const void value_type
Definition: multi_ptr.hpp:457
Definition: access.hpp:18