DPC++ Runtime
Runtime libraries for oneAPI DPC++
user_defined_reductions.hpp
Go to the documentation of this file.
1 //==--- user_defined_reductions.hpp -- SYCL ext header file -=--*- C++ -*---==//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #pragma once
10 
11 #include <sycl/detail/defines.hpp>
13 #include <sycl/group_algorithm.hpp>
14 
15 namespace sycl {
16 inline namespace _V1 {
17 namespace ext::oneapi::experimental {
18 namespace detail {
19 template <typename GroupHelper, typename T, typename BinaryOperation>
20 T reduce_over_group_impl(GroupHelper group_helper, T x, size_t num_elements,
21  BinaryOperation binary_op) {
22 #ifdef __SYCL_DEVICE_ONLY__
23  T *Memory = reinterpret_cast<T *>(group_helper.get_memory().data());
24  auto g = group_helper.get_group();
25  Memory[g.get_local_linear_id()] = x;
26  group_barrier(g);
27  T result = Memory[0];
28  if (g.leader()) {
29  for (int i = 1; i < num_elements; i++) {
30  result = binary_op(result, Memory[i]);
31  }
32  }
33  group_barrier(g);
34  return group_broadcast(g, result);
35 #else
36  std::ignore = group_helper;
37  std::ignore = x;
38  std::ignore = num_elements;
39  std::ignore = binary_op;
41  "Group algorithms are not supported on host.");
42 #endif
43 }
44 } // namespace detail
45 
46 // ---- reduce_over_group
47 template <typename GroupHelper, typename T, typename BinaryOperation>
48 std::enable_if_t<(is_group_helper_v<GroupHelper>), T>
49 reduce_over_group(GroupHelper group_helper, T x, BinaryOperation binary_op) {
50  if constexpr (sycl::detail::is_native_op<T, BinaryOperation>::value) {
51  return sycl::reduce_over_group(group_helper.get_group(), x, binary_op);
52  }
53 #ifdef __SYCL_DEVICE_ONLY__
55  group_helper, x, group_helper.get_group().get_local_linear_range(),
56  binary_op);
57 #else
59  "Group algorithms are not supported on host.");
60 #endif
61 }
62 
63 template <typename GroupHelper, typename V, typename T,
64  typename BinaryOperation>
65 std::enable_if_t<(is_group_helper_v<GroupHelper>), T>
66 reduce_over_group(GroupHelper group_helper, V x, T init,
67  BinaryOperation binary_op) {
68  if constexpr (sycl::detail::is_native_op<V, BinaryOperation>::value &&
69  sycl::detail::is_native_op<T, BinaryOperation>::value) {
70  return sycl::reduce_over_group(group_helper.get_group(), x, init,
71  binary_op);
72  }
73 #ifdef __SYCL_DEVICE_ONLY__
74  return binary_op(init, reduce_over_group(group_helper, x, binary_op));
75 #else
76  std::ignore = group_helper;
78  "Group algorithms are not supported on host.");
79 #endif
80 }
81 
82 // ---- joint_reduce
83 template <typename GroupHelper, typename Ptr, typename BinaryOperation>
84 std::enable_if_t<(is_group_helper_v<GroupHelper> &&
85  sycl::detail::is_pointer_v<Ptr>),
87 joint_reduce(GroupHelper group_helper, Ptr first, Ptr last,
88  BinaryOperation binary_op) {
89  if constexpr (sycl::detail::is_native_op<
91  BinaryOperation>::value) {
92  return sycl::joint_reduce(group_helper.get_group(), first, last, binary_op);
93  }
94 #ifdef __SYCL_DEVICE_ONLY__
95  // TODO: the complexity is linear and not logarithmic. Something like
96  // https://github.com/intel/llvm/blob/8ebd912679f27943d8ef6c33a9775347dce6b80d/sycl/include/sycl/reduction.hpp#L1810-L1818
97  // might be applicable here.
98  using T = typename std::iterator_traits<Ptr>::value_type;
99  auto g = group_helper.get_group();
100  T partial = *(first + g.get_local_linear_id());
101  Ptr second = first + g.get_local_linear_range();
102  sycl::detail::for_each(g, second, last,
103  [&](const T &x) { partial = binary_op(partial, x); });
104  group_barrier(g);
105  size_t num_elements = last - first;
106  num_elements = std::min(num_elements, g.get_local_linear_range());
107  return detail::reduce_over_group_impl(group_helper, partial, num_elements,
108  binary_op);
109 #else
110  std::ignore = group_helper;
111  std::ignore = first;
112  std::ignore = last;
113  std::ignore = binary_op;
115  "Group algorithms are not supported on host.");
116 #endif
117 }
118 
119 template <typename GroupHelper, typename Ptr, typename T,
120  typename BinaryOperation>
121 std::enable_if_t<
122  (is_group_helper_v<GroupHelper> && sycl::detail::is_pointer_v<Ptr>), T>
123 joint_reduce(GroupHelper group_helper, Ptr first, Ptr last, T init,
124  BinaryOperation binary_op) {
125  if constexpr (sycl::detail::is_native_op<T, BinaryOperation>::value) {
126  return sycl::joint_reduce(group_helper.get_group(), first, last, init,
127  binary_op);
128  }
129 #ifdef __SYCL_DEVICE_ONLY__
130  return binary_op(init, joint_reduce(group_helper, first, last, binary_op));
131 #else
132  std::ignore = group_helper;
133  std::ignore = last;
135  "Group algorithms are not supported on host.");
136 #endif
137 }
138 } // namespace ext::oneapi::experimental
139 } // namespace _V1
140 } // namespace sycl
Function for_each(Group g, Ptr first, Ptr last, Function f)
T reduce_over_group_impl(GroupHelper group_helper, T x, size_t num_elements, BinaryOperation binary_op)
std::enable_if_t<(is_group_helper_v< GroupHelper >), T > reduce_over_group(GroupHelper group_helper, T x, BinaryOperation binary_op)
std::enable_if_t<(is_group_helper_v< GroupHelper > &&sycl::detail::is_pointer_v< Ptr >), typename std::iterator_traits< Ptr >::value_type > joint_reduce(GroupHelper group_helper, Ptr first, Ptr last, BinaryOperation binary_op)
std::enable_if_t<(is_group_v< std::decay_t< Group >> &&(std::is_trivially_copyable_v< T >||detail::is_vec< T >::value)), T > group_broadcast(Group g, T x, typename Group::id_type local_id)
std::enable_if_t<(is_group_v< std::decay_t< Group >> &&detail::is_pointer_v< Ptr > &&detail::is_arithmetic_or_complex< typename detail::remove_pointer< Ptr >::type >::value &&detail::is_arithmetic_or_complex< T >::value &&detail::is_plus_or_multiplies_if_complex< T, BinaryOperation >::value &&detail::is_native_op< T, BinaryOperation >::value), T > joint_reduce(Group g, Ptr first, Ptr last, T init, BinaryOperation binary_op)
void group_barrier(ext::oneapi::experimental::root_group< dimensions > G, memory_scope FenceScope=decltype(G)::fence_scope)
Definition: root_group.hpp:100
std::enable_if_t<(is_group_v< std::decay_t< Group >> &&(detail::is_scalar_arithmetic< T >::value||(detail::is_complex< T >::value &&detail::is_multiplies< T, BinaryOperation >::value)) &&detail::is_native_op< T, BinaryOperation >::value), T > reduce_over_group(Group g, T x, BinaryOperation binary_op)
std::error_code make_error_code(sycl::errc E) noexcept
Constructs an error code from E and sycl_category().
Definition: exception.cpp:65
autodecltype(x) x
const void value_type
Definition: multi_ptr.hpp:457
Definition: access.hpp:18