DPC++ Runtime
Runtime libraries for oneAPI DPC++
invoke_simd.hpp
Go to the documentation of this file.
1 //==------ invoke_simd.hpp - SYCL invoke_simd extension --*- C++ -*---------==//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 // ===--------------------------------------------------------------------=== //
8 // Implementation of the sycl_ext_oneapi_invoke_simd extension.
9 // https://github.com/intel/llvm/blob/sycl/sycl/doc/extensions/proposed/sycl_ext_oneapi_invoke_simd.asciidoc
10 // ===--------------------------------------------------------------------=== //
11 
12 #pragma once
13 
16 
17 #include <sycl/detail/boost/mp11.hpp>
18 #include <sycl/sub_group.hpp>
19 
20 #include <functional>
21 
22 // TODOs:
23 // * (a) TODO enforce constness of a functor/lambda's () operator
24 // * (b) TODO support lambdas and functors in BE
25 
37 template <bool IsFunc, class SpmdRet, class HelperFunc,
38  class... UserSimdFuncAndSpmdArgs, class = std::enable_if_t<!IsFunc>>
39 __DPCPP_SYCL_EXTERNAL __regcall SpmdRet
40 __builtin_invoke_simd(HelperFunc helper, const void *obj,
41  UserSimdFuncAndSpmdArgs... args)
42 #ifdef __SYCL_DEVICE_ONLY__
43  ;
44 #else
45 {
46  // __builtin_invoke_simd is not supported on the host device yet
47  throw sycl::exception(sycl::errc::feature_not_supported,
48  "__builtin_invoke_simd is not supported on host");
49 }
50 #endif // __SYCL_DEVICE_ONLY__
51 
52 template <bool IsFunc, class SpmdRet, class HelperFunc,
53  class... UserSimdFuncAndSpmdArgs, class = std::enable_if_t<IsFunc>>
54 __DPCPP_SYCL_EXTERNAL __regcall SpmdRet
55 __builtin_invoke_simd(HelperFunc helper, UserSimdFuncAndSpmdArgs... args)
56 #ifdef __SYCL_DEVICE_ONLY__
57  ;
58 #else
59 {
60  // __builtin_invoke_simd is not supported on the host device yet
61  throw sycl::exception(sycl::errc::feature_not_supported,
62  "__builtin_invoke_simd is not supported on host");
63 }
64 #endif // __SYCL_DEVICE_ONLY__
65 
66 namespace sycl {
67 inline namespace _V1 {
68 
69 namespace ext::oneapi::experimental {
70 
71 // --- Helpers
72 namespace detail {
73 
74 namespace __MP11_NS = sycl::detail::boost::mp11;
75 
76 // This structure performs the SPMD-to-SIMD parameter type conversion as defined
77 // by the spec.
78 template <class T, int N, class = void> struct spmd2simd;
79 // * `uniform<T>` converts to `T`
80 template <class T, int N> struct spmd2simd<uniform<T>, N> {
81  using type = T;
82 };
83 // * tuple of types converts to tuple of converted tuple element types.
84 template <class... T, int N> struct spmd2simd<std::tuple<T...>, N> {
85  using type = std::tuple<typename spmd2simd<T, N>::type...>;
86 };
87 // * arithmetic type converts to a simd vector with this element type and the
88 // width equal to caller's subgroup size and passed as the `N` template
89 // argument.
90 template <class T, int N>
91 struct spmd2simd<T, N, std::enable_if_t<std::is_arithmetic_v<T>>> {
92  using type = simd<T, N>;
93 };
94 
95 // * bool converts to `simd_mask` with a user specified element type.
96 // Arbitrarily use unsigned char for the element type for subgroup size
97 // deduction and rely on the implicit conversion operator for the actual
98 // user type.
99 template <int N> struct spmd2simd<bool, N> {
101 };
102 
103 // This structure performs the SIMD-to-SPMD return type conversion as defined
104 // by the spec.
105 template <class, class = void> struct simd2spmd;
106 // * `uniform<T>` stays the same
107 template <class T> struct simd2spmd<uniform<T>> {
108  using type = uniform<T>;
109 };
110 // * `simd<T, N>` converts to T
111 template <class T, int N> struct simd2spmd<simd<T, N>> {
112  using type = T;
113 };
114 // * tuple of types converts to tuple of converted tuple element types.
115 template <class... T> struct simd2spmd<std::tuple<T...>> {
116  using type = std::tuple<typename simd2spmd<T>::type...>;
117 };
118 // * arithmetic type T converts to `uniform<T>`
119 template <class T>
120 struct simd2spmd<T, std::enable_if_t<std::is_arithmetic_v<T>>> {
121  using type = uniform<T>;
122 };
123 
124 // * `simd_mask` converts to bool
125 template <class T, int N> struct simd2spmd<simd_mask<T, N>> {
126  using type = bool;
127 };
128 
129 template <> struct simd2spmd<void> { using type = void; };
130 
131 // Determine number of elements in a simd type.
132 template <class T> struct simd_size {
133  static constexpr int value = 1; // 1 element in any type by default
134 };
135 
136 // * Specialization for the simd type.
137 template <class T, int N> struct simd_size<simd<T, N>> {
138  static constexpr int value = N;
139 };
140 
141 // Check if given type is uniform.
142 template <class T> struct is_uniform_type : std::false_type {};
143 template <class T> struct is_uniform_type<uniform<T>> : std::true_type {
144  using type = T;
145 };
146 
147 // Check if given type is simd or simd_mask.
148 template <class T> struct is_simd_or_mask_type : std::false_type {};
149 template <class T, int N>
150 struct is_simd_or_mask_type<simd<T, N>> : std::true_type {};
151 template <class T, int N>
152 struct is_simd_or_mask_type<simd_mask<T, N>> : std::true_type {};
153 
154 // Checks if all the types in the parameter pack are uniform<T>.
155 template <class... SpmdArgs> struct all_uniform_types {
156  constexpr operator bool() {
157  using TypeList = __MP11_NS::mp_list<SpmdArgs...>;
158  return __MP11_NS::mp_all_of<TypeList, is_uniform_type>::value;
159  }
160 };
161 
162 // "Unwraps" a value of the `uniform` type (used before passing to SPMD
163 // arguments to the __builtin_invoke_simd):
164 // - the case when there is nothing to unwrap
165 template <typename T> struct unwrap_uniform {
166  static auto impl(T val) { return val; }
167  using type = T;
168 };
169 
170 // - the real unwrapping case
171 template <typename T> struct unwrap_uniform<uniform<T>> {
172  static T impl(uniform<T> val) { return val; }
173  using type = T;
174 };
175 
176 // Verify the callee return type matches the subgroup size as is required by the
177 // spec. For example: simd<int, 8> foo(simd<int,16>); The return type vector
178 // length (8) does not match the subgroup size (16).
179 template <auto SgSize, typename SimdRet>
181  if constexpr (is_simd_or_mask_type<SimdRet>::value) {
182  constexpr auto RetVecLength = SimdRet::size();
183  static_assert(RetVecLength == SgSize,
184  "invoke_simd callee return type vector length must match "
185  "kernel subgroup size");
186  }
187 }
188 
189 // Deduces subgroup size of the caller based on given SIMD callable and
190 // corresponding SPMD arguments it is being invoke with via invoke_simd.
191 // Basically, for each supported subgroup size, this meta-function finds out if
192 // the callable can be invoked by C++ rules given the SPMD arguments transformed
193 // as prescribed by the spec assuming this subgroup size. One and only one
194 // subgroup size should conform.
195 template <class SimdCallable, class... SpmdArgs> struct sg_size {
196  template <class N>
197  using IsInvocableSgSize = __MP11_NS::mp_bool<std::is_invocable_v<
198  SimdCallable, typename spmd2simd<SpmdArgs, N::value>::type...>>;
199 
200  __DPCPP_SYCL_EXTERNAL constexpr operator int() {
201  using SupportedSgSizes = __MP11_NS::mp_list_c<int, 1, 2, 4, 8, 16, 32>;
202  using InvocableSgSizes =
203  __MP11_NS::mp_copy_if<SupportedSgSizes, IsInvocableSgSize>;
204  constexpr auto found_invoke_simd_target =
205  __MP11_NS::mp_empty<InvocableSgSizes>::value != 1;
206  if constexpr (found_invoke_simd_target) {
207  static_assert((__MP11_NS::mp_size<InvocableSgSizes>::value == 1) &&
208  "multiple invoke_simd targets found");
209  return __MP11_NS::mp_front<InvocableSgSizes>::value;
210  }
211  static_assert(
212  found_invoke_simd_target,
213  "No callable invoke_simd target found. Confirm the "
214  "invoke_simd invocation argument types are convertible to the "
215  "invoke_simd target argument types");
216  }
217 };
218 
219 // Determine the return type of a SIMD callable.
220 template <int N, class SimdCallable, class... SpmdArgs>
221 using SimdRetType =
222  std::invoke_result_t<SimdCallable,
223  typename spmd2simd<SpmdArgs, N>::type...>;
224 // Determine the return type of an invoke_simd based on the return type of a
225 // SIMD callable.
226 template <int N, class SimdCallable, class... SpmdArgs>
227 using SpmdRetType =
228  typename simd2spmd<SimdRetType<N, SimdCallable, SpmdArgs...>>::type;
229 
230 template <class SimdCallable, class... SpmdArgs>
231 static constexpr int get_sg_size() {
232  if constexpr (all_uniform_types<SpmdArgs...>()) {
233  using SimdRet = std::invoke_result_t<SimdCallable, SpmdArgs...>;
234 
235  if constexpr (is_simd_or_mask_type<SimdRet>::value) {
237  } else {
238  // fully uniform function - subgroup size does not matter
239  return 0;
240  }
241  } else {
242  return sg_size<SimdCallable, SpmdArgs...>();
243  }
244 }
245 
246 // This function is a wrapper around a call to a functor with field or a lambda
247 // with captures. Note __regcall - this is needed for efficient argument
248 // forwarding.
249 template <int N, class Callable, class... T>
250 [[intel::device_indirectly_callable]] __DPCPP_SYCL_EXTERNAL __regcall detail::
251  SimdRetType<N, Callable, T...>
252  simd_obj_call_helper(const void *obj_ptr,
253  typename detail::spmd2simd<T, N>::type... simd_args) {
254  auto f =
255  *reinterpret_cast<const std::remove_reference_t<Callable> *>(obj_ptr);
256  return f(simd_args...);
257 }
258 
259 // This function is a wrapper around a call to a function.
260 template <int N, class Callable, class... T>
261 [[intel::device_indirectly_callable]] __DPCPP_SYCL_EXTERNAL __regcall detail::
262  SimdRetType<N, Callable, T...>
264  typename detail::spmd2simd<T, N>::type... simd_args) {
265  return f(simd_args...);
266 }
267 
// Enable the std::is_function workaround only for libstdc++ older than
// version 10.
#if defined(_GLIBCXX_RELEASE) && _GLIBCXX_RELEASE < 10
#define __INVOKE_SIMD_USE_STD_IS_FUNCTION_WA
#endif // _GLIBCXX_RELEASE < 10

#ifdef __INVOKE_SIMD_USE_STD_IS_FUNCTION_WA
// TODO This is a workaround for libstdc++ version 9 buggy behavior which
// returns false in the code below. Version 10 works fine. Once required
// minimum libstdc++ version is bumped to 10, this w/a should be removed.
//   template <class F> bool foo(F &&f) {
//     return std::is_function_v<std::remove_reference_t<F>>;
//   }
// where F is a function type with __regcall.
//
// The trait is true for a reference to, a pointer to, and a reference to a
// pointer to a __regcall function; false for everything else.
template <class F> struct is_regcall_function_ptr_or_ref : std::false_type {};

template <class Ret, class... Args>
struct is_regcall_function_ptr_or_ref<Ret(__regcall &)(Args...)>
    : std::true_type {};

template <class Ret, class... Args>
struct is_regcall_function_ptr_or_ref<Ret(__regcall *)(Args...)>
    : std::true_type {};

template <class Ret, class... Args>
struct is_regcall_function_ptr_or_ref<Ret(__regcall *&)(Args...)>
    : std::true_type {};

// `inline` (rather than `static`) gives the header-defined variable template a
// single definition across translation units.
template <class F>
inline constexpr bool is_regcall_function_ptr_or_ref_v =
    is_regcall_function_ptr_or_ref<F>::value;
#endif // __INVOKE_SIMD_USE_STD_IS_FUNCTION_WA
300 
// True when `Callable`, after removing a reference and then a pointer, is a
// function type - i.e. for functions, references to functions, pointers to
// functions and references to pointers to functions. When the libstdc++
// workaround macro is defined, the __regcall-specific trait is consulted as
// well.
//
// `inline` (rather than `static`) gives the header-defined variable template a
// single definition across translation units.
template <class Callable>
inline constexpr bool is_function_ptr_or_ref_v =
    std::is_function_v<std::remove_pointer_t<std::remove_reference_t<Callable>>>
#ifdef __INVOKE_SIMD_USE_STD_IS_FUNCTION_WA
    || is_regcall_function_ptr_or_ref_v<Callable>
#endif // __INVOKE_SIMD_USE_STD_IS_FUNCTION_WA
    ;
308 
309 template <typename Callable> struct remove_ref_from_func_ptr_ref_type {
310  using type = Callable;
311 };
312 
313 template <typename Ret, typename... Args>
314 struct remove_ref_from_func_ptr_ref_type<Ret (*&)(Args...)> {
315  using type = Ret (*)(Args...);
316 };
317 
318 template <typename Ret, typename... Args>
319 struct remove_ref_from_func_ptr_ref_type<Ret(__regcall *&)(Args...)> {
320  using type = Ret(__regcall *)(Args...);
321 };
322 
323 template <typename T>
326 
327 template <typename T> struct strip_regcall_from_function_ptr;
328 
329 template <typename Ret, typename... Args>
330 struct strip_regcall_from_function_ptr<Ret (*)(Args...)> {
331  using type = Ret (*)(Args...);
332 };
333 
334 template <typename Ret, typename... Args>
335 struct strip_regcall_from_function_ptr<Ret(__regcall *)(Args...)> {
336  using type = Ret (*)(Args...);
337 };
338 
339 template <typename T>
342 
343 template <typename T> struct is_non_trivially_copyable_uniform {
344  static constexpr bool value =
346  !std::is_trivially_copyable_v<typename unwrap_uniform<T>::type>;
347 };
348 
349 template <> struct is_non_trivially_copyable_uniform<void> {
350  static constexpr bool value = false;
351 };
352 
353 template <typename T>
354 inline constexpr bool is_non_trivially_copyable_uniform_v =
356 
// Returns true if at least one parameter type of the given function pointer
// type is a reference. The pointer value itself is never read - the argument
// only drives template deduction - so a null pointer is fine. A function with
// no parameters yields false.
template <typename Ret, typename... Args>
constexpr bool has_ref_arg(Ret (*)(Args...)) noexcept {
  return (... || std::is_reference_v<Args>);
}
361 
// Returns true if the return type of the given function pointer type is a
// reference. The pointer value itself is never read - the argument only
// drives template deduction - so a null pointer is fine.
template <typename Ret, typename... Args>
constexpr bool has_ref_ret(Ret (*)(Args...)) noexcept {
  return std::is_reference_v<Ret>;
}
366 
367 template <typename Ret, typename... Args>
368 constexpr bool has_non_uniform_struct_ret(Ret (*)(Args...)) {
369  return std::is_class_v<Ret> && !is_simd_or_mask_type<Ret>::value &&
371 }
372 
373 template <typename Ret, typename... Args>
374 constexpr bool has_non_trivially_copyable_uniform_ret(Ret (*)(Args...)) {
375  return is_non_trivially_copyable_uniform_v<Ret>;
376 }
377 
378 template <class Callable> constexpr void verify_callable() {
379  if constexpr (is_function_ptr_or_ref_v<Callable>) {
380  using RemoveRef =
382  using FuncPtrType =
383  std::conditional_t<std::is_pointer_v<RemoveRef>, RemoveRef,
384  std::add_pointer_t<RemoveRef>>;
386  constexpr FuncPtrNoCC obj = {};
387  constexpr bool callable_has_ref_ret = has_ref_ret(obj);
388  static_assert(
389  !callable_has_ref_ret,
390  "invoke_simd does not support callables returning references");
391  constexpr bool callable_has_ref_arg = has_ref_arg(obj);
392  static_assert(
393  !callable_has_ref_arg,
394  "invoke_simd does not support callables with reference arguments");
395 #ifndef __INVOKE_SIMD_ENABLE_STRUCTS
396  constexpr bool callable_has_non_uniform_struct_ret =
398  static_assert(!callable_has_non_uniform_struct_ret,
399  "invoke_simd does not support callables returning "
400  "non-uniform structures");
401 #endif
402 #ifdef __SYCL_DEVICE_ONLY__
403  constexpr bool callable_has_uniform_non_trivially_copyable_ret =
405  static_assert(!callable_has_uniform_non_trivially_copyable_ret,
406  "invoke_simd does not support callables returning uniforms "
407  "that are not trivially copyable");
408 #endif
409  }
410 }
411 
// Compile-time check that none of `Ts...` satisfies
// is_non_trivially_copyable_uniform_v. The check is only active under
// __SYCL_DEVICE_ONLY__; otherwise the function is empty.
template <class... Ts>
constexpr void verify_no_uniform_non_trivially_copyable_args() {
#ifdef __SYCL_DEVICE_ONLY__
  constexpr bool has_non_trivially_copyable_uniform_arg =
      (... || is_non_trivially_copyable_uniform_v<Ts>);
  static_assert(!has_non_trivially_copyable_uniform_arg,
                "Uniform arguments must be trivially copyable");
#endif
}
421 
422 template <class... Ts> constexpr void verify_no_non_uniform_struct_args() {
423 #if defined(__SYCL_DEVICE_ONLY__) && !defined(__INVOKE_SIMD_ENABLE_STRUCTS)
424  constexpr bool has_non_uniform_struct_arg =
425  (... || (std::is_class_v<Ts> && !is_simd_or_mask_type<Ts>::value &&
427  static_assert(!has_non_uniform_struct_arg,
428  "Structure arguments must be uniform");
429 #endif
430 }
431 
432 template <class Callable, class... Ts>
433 constexpr void verify_valid_args_and_ret() {
435 
437 
438  verify_callable<Callable>();
439 }
440 
441 } // namespace detail
442 
443 // --- The main API
444 
461 // TODO works only for functions and pointers to functions now,
462 // enable for lambda functions and functors.
463 template <class Callable, class... T>
464 __attribute__((always_inline)) auto invoke_simd(sycl::sub_group sg,
465  Callable &&f, T... args) {
466  // If the invoke_simd call site is fully uniform, then it does not matter
467  // what the subgroup size is and arguments don't need widening and return
468  // value does not need shrinking by this library or SPMD compiler, so 0
469  // is fine in this case.
470  detail::verify_valid_args_and_ret<Callable, T...>();
471  constexpr int N = detail::get_sg_size<Callable, T...>();
472  using RetSpmd = detail::SpmdRetType<N, Callable, T...>;
474  N, detail::SimdRetType<N, Callable, T...>>();
475  constexpr bool is_function = detail::is_function_ptr_or_ref_v<Callable>;
476 
477  if constexpr (is_function) {
478  // The variables typed as pointer to a function become lvalue-reference
479  // when passed to invoke_simd() as forwarding (universal) references. That
480  // creates an additional indirection, which is resolved automatically by
481  // the compiler for the caller side of __builtin_invoke_simd, but which
482  // must be resolved manually during the creation of simd_func_call_helper.
483  // The class remove_ref_from_func_ptr_ref_type removes that unwanted
484  // indirection.
485  return __builtin_invoke_simd<true /*function*/, RetSpmd>(
488  f, detail::unwrap_uniform<T>::impl(args)...);
489  } else {
490  // TODO support functors and lambdas which are handled in this branch.
491  // The limiting factor for now is that the LLVM IR data flow analysis
492  // implemented in LowerInvokeSimd.cpp, which finds the actual invoke_simd
493  // target function, can't handle this case yet.
494  return __builtin_invoke_simd<false /*functor/lambda*/, RetSpmd>(
495  detail::simd_obj_call_helper<N, Callable, T...>, &f,
497  }
498 // TODO Temporary macro and assert to enable API compilation testing.
499 // LowerInvokeSimd.cpp does not support this case yet.
500 #ifndef __INVOKE_SIMD_ENABLE_ALL_CALLABLES
501  static_assert(is_function &&
502  "invoke_simd does not support functors or lambdas yet");
503 #endif // __INVOKE_SIMD_ENABLE_ALL_CALLABLES
504 }
505 
506 } // namespace ext::oneapi::experimental
507 } // namespace _V1
508 } // namespace sycl
Definition: simd.hpp:1387
#define __DPCPP_SYCL_EXTERNAL
__DPCPP_SYCL_EXTERNAL __regcall SpmdRet __builtin_invoke_simd(HelperFunc helper, const void *obj, UserSimdFuncAndSpmdArgs... args)
Middle End - to - Back End interface to invoke explicit SIMD functions from SPMD SYCL context.
Definition: invoke_simd.hpp:40
typename strip_regcall_from_function_ptr< T >::type strip_regcall_from_function_ptr_t
constexpr bool has_non_uniform_struct_ret(Ret(*)(Args...))
std::invoke_result_t< SimdCallable, typename spmd2simd< SpmdArgs, N >::type... > SimdRetType
constexpr void verify_no_uniform_non_trivially_copyable_args()
__DPCPP_SYCL_EXTERNAL __regcall detail::SimdRetType< N, Callable, T... > simd_obj_call_helper(const void *obj_ptr, typename detail::spmd2simd< T, N >::type... simd_args)
constexpr bool has_ref_arg(Ret(*)(Args...))
__DPCPP_SYCL_EXTERNAL __regcall detail::SimdRetType< N, Callable, T... > simd_func_call_helper(Callable f, typename detail::spmd2simd< T, N >::type... simd_args)
constexpr bool has_ref_ret(Ret(*)(Args...))
constexpr bool has_non_trivially_copyable_uniform_ret(Ret(*)(Args...))
typename simd2spmd< SimdRetType< N, SimdCallable, SpmdArgs... > >::type SpmdRetType
typename remove_ref_from_func_ptr_ref_type< T >::type remove_ref_from_func_ptr_ref_type_t
__attribute__((always_inline)) auto invoke_simd(sycl
The invoke_simd free function invokes a SIMD function using all work-items in a sub_group.
Definition: access.hpp:18
__MP11_NS::mp_bool< std::is_invocable_v< SimdCallable, typename spmd2simd< SpmdArgs, N::value >::type... > > IsInvocableSgSize