DPC++ Runtime
Runtime libraries for oneAPI DPC++
launch.hpp
Go to the documentation of this file.
1 /***************************************************************************
2  *
3  * Copyright (C) Codeplay Software Ltd.
4  *
5  * Part of the LLVM Project, under the Apache License v2.0 with LLVM
6  * Exceptions. See https://llvm.org/LICENSE.txt for license information.
7  * SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
8  *
9  * Unless required by applicable law or agreed to in writing, software
10  * distributed under the License is distributed on an "AS IS" BASIS,
11  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12  * See the License for the specific language governing permissions and
13  * limitations under the License.
14  *
15  * SYCL compatibility extension
16  *
17  * launch.hpp
18  *
19  * Description:
20  * launch functionality for the SYCL compatibility extension
21  **************************************************************************/
22 
23 #pragma once
24 
25 #include <sycl/accessor.hpp>
26 #include <sycl/event.hpp>
27 #include <sycl/nd_range.hpp>
28 #include <sycl/queue.hpp>
29 #include <sycl/range.hpp>
30 #include <sycl/reduction.hpp>
31 
32 #include <syclcompat/device.hpp>
33 #include <syclcompat/dims.hpp>
34 
35 namespace syclcompat {
36 
37 namespace detail {
38 
/// Counts the parameters in the signature of a free function.
///
/// Only the type of the function pointer is inspected; its value is never
/// read, which is why the parameter is left unnamed.
///
/// @tparam R Return type of the function.
/// @tparam Types Parameter types of the function.
/// @return The number of parameters the function declares.
template <typename R, typename... Types>
constexpr size_t getArgumentCount(R (*)(Types...)) noexcept {
  return sizeof...(Types);
}
43 
44 template <int Dim>
46  sycl::range<Dim> global_range = range.get_global_range();
47  sycl::range<Dim> local_range = range.get_local_range();
48  if constexpr (Dim == 3) {
49  return range;
50  } else if constexpr (Dim == 2) {
51  return sycl::nd_range<3>{{1, global_range[0], global_range[1]},
52  {1, local_range[0], local_range[1]}};
53  }
54  return sycl::nd_range<3>{{1, 1, global_range[0]}, {1, 1, local_range[0]}};
55 }
56 
57 template <auto F, typename... Args>
58 std::enable_if_t<std::is_invocable_v<decltype(F), Args...>, sycl::event>
59 launch(const sycl::nd_range<3> &range, sycl::queue q, Args... args) {
60  static_assert(detail::getArgumentCount(F) == sizeof...(args),
61  "Wrong number of arguments to SYCL kernel");
62  static_assert(
63  std::is_same<std::invoke_result_t<decltype(F), Args...>, void>::value,
64  "SYCL kernels should return void");
65 
66  return q.parallel_for(
67  range, [=](sycl::nd_item<3>) { [[clang::always_inline]] F(args...); });
68 }
69 
70 template <auto F, typename... Args>
71 sycl::event launch(const sycl::nd_range<3> &range, size_t mem_size,
72  sycl::queue q, Args... args) {
73  static_assert(detail::getArgumentCount(F) == sizeof...(args) + 1,
74  "Wrong number of arguments to SYCL kernel");
75 
76  using F_t = decltype(F);
77  using f_return_t = typename std::invoke_result_t<F_t, Args..., char *>;
78  static_assert(std::is_same<f_return_t, void>::value,
79  "SYCL kernels should return void");
80 
81  return q.submit([&](sycl::handler &cgh) {
82  auto local_acc = sycl::local_accessor<char, 1>(mem_size, cgh);
83  cgh.parallel_for(range, [=](sycl::nd_item<3>) {
84  auto local_mem = local_acc.get_pointer();
85  [[clang::always_inline]] F(args..., local_mem);
86  });
87  });
88 }
89 
90 } // namespace detail
91 
92 template <int Dim>
95 
96  if (global_size_in.size() == 0 || work_group_size.size() == 0) {
97  throw std::invalid_argument("Global or local size is zero!");
98  }
99  for (size_t i = 0; i < Dim; ++i) {
100  if (global_size_in[i] < work_group_size[i])
101  throw std::invalid_argument("Work group size larger than global size");
102  }
103 
104  auto global_size =
105  ((global_size_in + work_group_size - 1) / work_group_size) *
107  return {global_size, work_group_size};
108 }
109 
110 inline sycl::nd_range<1> compute_nd_range(int global_size_in,
111  int work_group_size) {
112  return compute_nd_range<1>(global_size_in, work_group_size);
113 }
114 
115 template <auto F, int Dim, typename... Args>
116 std::enable_if_t<std::is_invocable_v<decltype(F), Args...>, sycl::event>
117 launch(const sycl::nd_range<Dim> &range, sycl::queue q, Args... args) {
118  return detail::launch<F>(detail::transform_nd_range<Dim>(range), q, args...);
119 }
120 
121 template <auto F, int Dim, typename... Args>
122 std::enable_if_t<std::is_invocable_v<decltype(F), Args...>, sycl::event>
123 launch(const sycl::nd_range<Dim> &range, Args... args) {
124  return launch<F>(range, get_default_queue(), args...);
125 }
126 
127 // Alternative launch through dim3 objects
128 template <auto F, typename... Args>
129 std::enable_if_t<std::is_invocable_v<decltype(F), Args...>, sycl::event>
130 launch(const dim3 &grid, const dim3 &threads, sycl::queue q, Args... args) {
131  return launch<F>(sycl::nd_range<3>{grid * threads, threads}, q, args...);
132 }
133 
134 template <auto F, typename... Args>
135 std::enable_if_t<std::is_invocable_v<decltype(F), Args...>, sycl::event>
136 launch(const dim3 &grid, const dim3 &threads, Args... args) {
137  return launch<F>(grid, threads, get_default_queue(), args...);
138 }
139 
154 template <auto F, int Dim, typename... Args>
155 sycl::event launch(const sycl::nd_range<Dim> &range, size_t mem_size,
156  sycl::queue q, Args... args) {
157  return detail::launch<F>(detail::transform_nd_range<Dim>(range), mem_size, q,
158  args...);
159 }
160 
174 template <auto F, int Dim, typename... Args>
175 sycl::event launch(const sycl::nd_range<Dim> &range, size_t mem_size,
176  Args... args) {
177  return launch<F>(range, mem_size, get_default_queue(), args...);
178 }
179 
195 template <auto F, typename... Args>
196 sycl::event launch(const dim3 &grid, const dim3 &threads, size_t mem_size,
197  sycl::queue q, Args... args) {
198  return launch<F>(sycl::nd_range<3>{grid * threads, threads}, mem_size, q,
199  args...);
200 }
201 
217 template <auto F, typename... Args>
218 sycl::event launch(const dim3 &grid, const dim3 &threads, size_t mem_size,
219  Args... args) {
220  return launch<F>(grid, threads, mem_size, get_default_queue(), args...);
221 }
222 
223 } // namespace syclcompat
The file contains implementations of accessor class.
An event object can be used to synchronize memory transfers, enqueues of kernels and signaling barriers.
Definition: event.hpp:44
Command group handler class.
Definition: handler.hpp:458
void parallel_for(range< 1 > NumWorkItems, _KERNELFUNCPARAM(KernelFunc))
Definition: handler.hpp:2014
Identifies an instance of the function object executing at each point in an nd_range.
Definition: nd_item.hpp:48
Defines the iteration domain of both the work-groups and the overall dispatch.
Definition: nd_range.hpp:22
range< Dimensions > get_global_range() const
Definition: nd_range.hpp:43
range< Dimensions > get_local_range() const
Definition: nd_range.hpp:45
Encapsulates a single SYCL queue which schedules kernels on a SYCL device.
Definition: queue.hpp:111
event parallel_for(range< 1 > Range, RestT &&...Rest)
parallel_for version with a kernel represented as a lambda + range that specifies global size only.
Definition: queue.hpp:2271
std::enable_if_t< std::is_invocable_r_v< void, T, handler & >, event > submit(T CGF, const detail::code_location &CodeLoc=detail::code_location::current())
Submits a command group function object to the queue, in order to be scheduled for execution on the device.
Definition: queue.hpp:346
Defines the iteration domain of either a single work-group in a parallel dispatch, or the overall dimensions of the dispatch.
Definition: range.hpp:26
size_t size() const
Definition: range.hpp:56
constexpr work_group_size_key::value_t< Dim0, Dims... > work_group_size
Definition: properties.hpp:117
sycl::nd_range< 3 > transform_nd_range(const sycl::nd_range< Dim > &range)
Definition: launch.hpp:45
std::enable_if_t< std::is_invocable_v< decltype(F), Args... >, sycl::event > launch(const sycl::nd_range< 3 > &range, sycl::queue q, Args... args)
Definition: launch.hpp:59
constexpr size_t getArgumentCount(R(*f)(Types...))
Definition: launch.hpp:40
static sycl::queue get_default_queue()
Util function to get the default queue of current device in device manager.
Definition: device.hpp:744
auto * local_mem()
Definition: memory.hpp:69
sycl::nd_range< Dim > compute_nd_range(sycl::range< Dim > global_size_in, sycl::range< Dim > work_group_size)
Definition: launch.hpp:93
std::enable_if_t< std::is_invocable_v< decltype(F), Args... >, sycl::event > launch(const sycl::nd_range< Dim > &range, sycl::queue q, Args... args)
Definition: launch.hpp:117