17 #include <sycl/detail/boost/mp11.hpp>
37 template <
bool IsFunc,
class SpmdRet,
class HelperFunc,
38 class... UserSimdFuncAndSpmdArgs,
class = std::enable_if_t<!IsFunc>>
41 UserSimdFuncAndSpmdArgs... args)
42 #ifdef __SYCL_DEVICE_ONLY__
48 "__builtin_invoke_simd is not supported on host");
52 template <
bool IsFunc,
class SpmdRet,
class HelperFunc,
53 class... UserSimdFuncAndSpmdArgs,
class = std::enable_if_t<IsFunc>>
56 #ifdef __SYCL_DEVICE_ONLY__
62 "__builtin_invoke_simd is not supported on host");
67 inline namespace _V1 {
69 namespace ext::oneapi::experimental {
74 namespace __MP11_NS = sycl::detail::boost::mp11;
// Primary spmd2simd template: declared here without a definition. The third,
// defaulted type parameter (void) is the slot that the partial specialization
// for arithmetic T fills with std::enable_if_t.
78 template <
class T,
int N,
class =
void>
struct spmd2simd;
84 template <
class... T,
int N>
struct spmd2simd<
std::tuple<T...>, N> {
85 using type = std::tuple<typename spmd2simd<T, N>::type...>;
90 template <
class T,
int N>
91 struct spmd2simd<T, N,
std::enable_if_t<std::is_arithmetic_v<T>>> {
116 using type = std::tuple<typename simd2spmd<T>::type...>;
149 template <
class T,
int N>
151 template <
class T,
int N>
156 constexpr
operator bool() {
157 using TypeList = __MP11_NS::mp_list<SpmdArgs...>;
158 return __MP11_NS::mp_all_of<TypeList, is_uniform_type>::value;
// Identity case: hands the argument back unchanged, by value.
166 static auto impl(T val) {
return val; }
179 template <auto SgSize,
typename SimdRet>
182 constexpr
auto RetVecLength = SimdRet::size();
183 static_assert(RetVecLength == SgSize,
184 "invoke_simd callee return type vector length must match "
185 "kernel subgroup size");
195 template <
class SimdCallable,
class... SpmdArgs>
struct sg_size {
201 using SupportedSgSizes = __MP11_NS::mp_list_c<int, 1, 2, 4, 8, 16, 32>;
202 using InvocableSgSizes =
203 __MP11_NS::mp_copy_if<SupportedSgSizes, IsInvocableSgSize>;
204 constexpr
auto found_invoke_simd_target =
205 __MP11_NS::mp_empty<InvocableSgSizes>::value != 1;
206 if constexpr (found_invoke_simd_target) {
207 static_assert((__MP11_NS::mp_size<InvocableSgSizes>::value == 1) &&
208 "multiple invoke_simd targets found");
209 return __MP11_NS::mp_front<InvocableSgSizes>::value;
212 found_invoke_simd_target,
213 "No callable invoke_simd target found. Confirm the "
214 "invoke_simd invocation argument types are convertible to the "
215 "invoke_simd target argument types");
220 template <
int N,
class SimdCallable,
class... SpmdArgs>
222 std::invoke_result_t<SimdCallable,
226 template <
int N,
class SimdCallable,
class... SpmdArgs>
230 template <
class SimdCallable,
class... SpmdArgs>
233 using SimdRet = std::invoke_result_t<SimdCallable, SpmdArgs...>;
242 return sg_size<SimdCallable, SpmdArgs...>();
249 template <
int N,
class Callable,
class... T>
251 SimdRetType<N, Callable, T...>
255 *
reinterpret_cast<const std::remove_reference_t<Callable> *
>(obj_ptr);
256 return f(simd_args...);
260 template <
int N,
class Callable,
class... T>
262 SimdRetType<N, Callable, T...>
265 return f(simd_args...);
268 #ifdef _GLIBCXX_RELEASE
269 #if _GLIBCXX_RELEASE < 10
270 #define __INVOKE_SIMD_USE_STD_IS_FUNCTION_WA
274 #ifdef __INVOKE_SIMD_USE_STD_IS_FUNCTION_WA
// Primary template: a type F is not classified as a __regcall function
// pointer or reference unless one of the partial specializations matches.
// NOTE(review): the specializations for Ret(__regcall &)(Args...),
// Ret(__regcall *)(Args...) and Ret(__regcall *&)(Args...) presumably derive
// from std::true_type; their base classes are not visible in this listing.
282 template <
class F>
struct is_regcall_function_ptr_or_ref : std::false_type {};
284 template <
class Ret,
class... Args>
285 struct is_regcall_function_ptr_or_ref<Ret(__regcall &)(Args...)>
288 template <
class Ret,
class... Args>
289 struct is_regcall_function_ptr_or_ref<Ret(__regcall *)(Args...)>
292 template <
class Ret,
class... Args>
293 struct is_regcall_function_ptr_or_ref<Ret(__regcall *&)(Args...)>
297 static constexpr
bool is_regcall_function_ptr_or_ref_v =
298 is_regcall_function_ptr_or_ref<F>::value;
301 template <
class Callable>
303 std::is_function_v<std::remove_pointer_t<std::remove_reference_t<Callable>>>
304 #ifdef __INVOKE_SIMD_USE_STD_IS_FUNCTION_WA
305 || is_regcall_function_ptr_or_ref_v<Callable>
313 template <
typename Ret,
typename... Args>
318 template <
typename Ret,
typename... Args>
320 using type = Ret(__regcall *)(Args...);
323 template <
typename T>
329 template <
typename Ret,
typename... Args>
334 template <
typename Ret,
typename... Args>
339 template <
typename T>
346 !std::is_trivially_copyable_v<typename unwrap_uniform<T>::type>;
350 static constexpr
bool value =
false;
353 template <
typename T>
357 template <
typename Ret,
typename... Args>
359 return (... || std::is_reference_v<Args>);
362 template <
typename Ret,
typename... Args>
364 return std::is_reference_v<Ret>;
367 template <
typename Ret,
typename... Args>
373 template <
typename Ret,
typename... Args>
375 return is_non_trivially_copyable_uniform_v<Ret>;
379 if constexpr (is_function_ptr_or_ref_v<Callable>) {
383 std::conditional_t<std::is_pointer_v<RemoveRef>, RemoveRef,
384 std::add_pointer_t<RemoveRef>>;
386 constexpr FuncPtrNoCC
obj = {};
389 !callable_has_ref_ret,
390 "invoke_simd does not support callables returning references");
393 !callable_has_ref_arg,
394 "invoke_simd does not support callables with reference arguments");
395 #ifndef __INVOKE_SIMD_ENABLE_STRUCTS
396 constexpr
bool callable_has_non_uniform_struct_ret =
398 static_assert(!callable_has_non_uniform_struct_ret,
399 "invoke_simd does not support callables returning "
400 "non-uniform structures");
402 #ifdef __SYCL_DEVICE_ONLY__
403 constexpr
bool callable_has_uniform_non_trivially_copyable_ret =
405 static_assert(!callable_has_uniform_non_trivially_copyable_ret,
406 "invoke_simd does not support callables returning uniforms "
407 "that are not trivially copyable");
412 template <
class... Ts>
414 #ifdef __SYCL_DEVICE_ONLY__
415 constexpr
bool has_non_trivially_copyable_uniform_arg =
416 (... || is_non_trivially_copyable_uniform_v<Ts>);
417 static_assert(!has_non_trivially_copyable_uniform_arg,
418 "Uniform arguments must be trivially copyable");
423 #if defined(__SYCL_DEVICE_ONLY__) && !defined(__INVOKE_SIMD_ENABLE_STRUCTS)
424 constexpr
bool has_non_uniform_struct_arg =
427 static_assert(!has_non_uniform_struct_arg,
428 "Structure arguments must be uniform");
432 template <
class Callable,
class... Ts>
438 verify_callable<Callable>();
463 template <
class Callable,
class... T>
465 Callable &&f, T... args) {
475 constexpr
bool is_function = detail::is_function_ptr_or_ref_v<Callable>;
477 if constexpr (is_function) {
500 #ifndef __INVOKE_SIMD_ENABLE_ALL_CALLABLES
501 static_assert(is_function &&
502 "invoke_simd does not support functors or lambdas yet");
#define __DPCPP_SYCL_EXTERNAL
__DPCPP_SYCL_EXTERNAL __regcall SpmdRet __builtin_invoke_simd(HelperFunc helper, const void *obj, UserSimdFuncAndSpmdArgs... args)
Middle-end-to-back-end interface for invoking explicit SIMD functions from an SPMD SYCL context.
typename strip_regcall_from_function_ptr< T >::type strip_regcall_from_function_ptr_t
constexpr bool has_non_uniform_struct_ret(Ret(*)(Args...))
constexpr void verify_valid_args_and_ret()
constexpr void verify_return_type_matches_sg_size()
std::invoke_result_t< SimdCallable, typename spmd2simd< SpmdArgs, N >::type... > SimdRetType
constexpr void verify_no_uniform_non_trivially_copyable_args()
constexpr void verify_no_non_uniform_struct_args()
constexpr bool is_non_trivially_copyable_uniform_v
__DPCPP_SYCL_EXTERNAL __regcall detail::SimdRetType< N, Callable, T... > simd_obj_call_helper(const void *obj_ptr, typename detail::spmd2simd< T, N >::type... simd_args)
constexpr bool has_ref_arg(Ret(*)(Args...))
__DPCPP_SYCL_EXTERNAL __regcall detail::SimdRetType< N, Callable, T... > simd_func_call_helper(Callable f, typename detail::spmd2simd< T, N >::type... simd_args)
constexpr bool has_ref_ret(Ret(*)(Args...))
constexpr void verify_callable()
constexpr bool has_non_trivially_copyable_uniform_ret(Ret(*)(Args...))
typename simd2spmd< SimdRetType< N, SimdCallable, SpmdArgs... > >::type SpmdRetType
static constexpr bool is_function_ptr_or_ref_v
typename remove_ref_from_func_ptr_ref_type< T >::type remove_ref_from_func_ptr_ref_type_t
static constexpr int get_sg_size()
__attribute__((always_inline)) auto invoke_simd(sycl
The invoke_simd free function invokes a SIMD function using all work-items in a sub_group.
Ret(__regcall *)(Args...) type
__MP11_NS::mp_bool< std::is_invocable_v< SimdCallable, typename spmd2simd< SpmdArgs, N::value >::type... > > IsInvocableSgSize
std::tuple< typename simd2spmd< T >::type... > type
static constexpr int value
simd_mask< unsigned char, N > type
std::tuple< typename spmd2simd< T, N >::type... > type