// Feature-test macro for the oneAPI invoke_simd extension. Code that includes
// this header can check it with #ifdef / #if to detect the extension.
16 #define SYCL_EXT_ONEAPI_INVOKE_SIMD 1
22 #include <sycl/detail/boost/mp11.hpp>
40 template <
bool IsFunc,
class SpmdRet,
class SimdCallee,
class... SpmdArgs,
41 class = std::enable_if_t<!IsFunc>>
44 #ifdef __SYCL_DEVICE_ONLY__
49 throw sycl::exception(sycl::errc::feature_not_supported,
50 "__builtin_invoke_simd is not supported on host");
52 #endif // __SYCL_DEVICE_ONLY__
54 template <
bool IsFunc,
class SpmdRet,
class SimdCallee,
class... SpmdArgs,
55 class = std::enable_if_t<IsFunc>>
58 #ifdef __SYCL_DEVICE_ONLY__
63 throw sycl::exception(sycl::errc::feature_not_supported,
64 "__builtin_invoke_simd is not supported on host");
66 #endif // __SYCL_DEVICE_ONLY__
71 namespace experimental {
77 template <
class T,
int N>
79 std::experimental::_StorageKind::_VecExt, N>;
// Alias for std::experimental::simd holding N elements of type T, using the
// simd_abi::native_fixed_size<T, N> ABI tag (a vector-extension storage kind,
// per the native_fixed_size alias in this header).
84 template <
class T,
int N>
85 using simd = std::experimental::simd<T, simd_abi::native_fixed_size<T, N>>;
88 template <
class T,
int N>
90 std::experimental::simd_mask<T, simd_abi::native_fixed_size<T, N>>;
// Short alias for the Boost.Mp11 metaprogramming library that is vendored under
// sycl::detail (included via <sycl/detail/boost/mp11.hpp>). It is used for the
// type-list utilities (mp_list, mp_all_of, mp_copy_if, ...) in this header.
95 namespace __MP11_NS = sycl::detail::boost::mp11;
// Primary template, declared but never defined: maps an SPMD (per-work-item)
// argument type T to its SIMD counterpart via a nested `type` member. Only the
// specializations provide `type`, so any unsupported T fails to compile.
// NOTE(review): N is presumably the sub-group size (the number of SIMD lanes);
// confirm. The trailing defaulted `class = void` parameter is presumably a
// SFINAE hook for enable_if-constrained partial specializations.
99 template <
class T,
int N,
class =
void>
struct spmd2simd;
106 using type = std::tuple<typename spmd2simd<T, N>::type...>;
111 template <
class T,
int N>
129 using type = std::tuple<typename simd2spmd<T>::type...>;
145 template <
class T,
int N>
147 template <
class T,
int N>
153 constexpr
operator bool() {
154 using ArgTypeList = __MP11_NS::mp_list<SpmdArgs...>;
156 if constexpr (__MP11_NS::mp_all_of<ArgTypeList, is_uniform_type>::value) {
157 using SimdRet = std::invoke_result_t<SimdCallable, SpmdArgs...>;
// Identity case: returns `val` unchanged, by value.
// NOTE(review): the enclosing template that declares T is not visible here.
// This is presumably the fallback for argument types that need no unwrapping;
// confirm against the enclosing struct.
170 static auto impl(T val) {
return val; }
184 template <
class SimdCallable,
class... SpmdArgs>
struct sg_size {
189 constexpr
operator int() {
190 using SupportedSgSizes = __MP11_NS::mp_list_c<int, 1, 2, 4, 8, 16, 32>;
191 using InvocableSgSizes =
192 __MP11_NS::mp_copy_if<SupportedSgSizes, IsInvocableSgSize>;
193 static_assert((__MP11_NS::mp_size<InvocableSgSizes>::value == 1) &&
194 "no or multiple invoke_simd targets found");
195 return __MP11_NS::mp_front<InvocableSgSizes>::value;
200 template <
int N,
class SimdCallable,
class... SpmdArgs>
202 std::invoke_result_t<SimdCallable,
206 template <
int N,
class SimdCallable,
class... SpmdArgs>
210 template <
class SimdCallable,
class... SpmdArgs>
215 return sg_size<SimdCallable, SpmdArgs...>();
222 template <
int N,
class Callable,
class... T>
227 *
reinterpret_cast<const std::remove_reference_t<Callable> *
>(obj_ptr);
228 return f(simd_args...);
// Enables a workaround when building against libstdc++ older than release 10.
// _GLIBCXX_RELEASE is libstdc++'s own macro. If it is not defined (another
// standard library is in use), the workaround stays disabled.
// When __INVOKE_SIMD_USE_STD_IS_FUNCTION_WA is defined, the header declares the
// is_regcall_function_ptr_or_ref trait and ORs it into is_function_ptr_or_ref_v,
// in addition to std::is_function_v.
// NOTE(review): presumably std::is_function in those libstdc++ versions does not
// recognize __regcall-qualified function types; confirm before removing.
231 #ifdef _GLIBCXX_RELEASE
232 #if _GLIBCXX_RELEASE < 10
233 #define __INVOKE_SIMD_USE_STD_IS_FUNCTION_WA
234 #endif // _GLIBCXX_RELEASE < 10
235 #endif // _GLIBCXX_RELEASE
237 #ifdef __INVOKE_SIMD_USE_STD_IS_FUNCTION_WA
245 template <
class F>
struct is_regcall_function_ptr_or_ref : std::false_type {};
247 template <
class Ret,
class... Args>
248 struct is_regcall_function_ptr_or_ref<Ret(__regcall &)(Args...)>
251 template <
class Ret,
class... Args>
252 struct is_regcall_function_ptr_or_ref<Ret(__regcall *)(Args...)>
255 template <
class Ret,
class... Args>
256 struct is_regcall_function_ptr_or_ref<Ret(__regcall *&)(Args...)>
260 static constexpr
bool is_regcall_function_ptr_or_ref_v =
261 is_regcall_function_ptr_or_ref<F>::value;
262 #endif // __INVOKE_SIMD_USE_STD_IS_FUNCTION_WA
264 template <
class Callable>
266 std::is_function_v<std::remove_pointer_t<std::remove_reference_t<Callable>>>
267 #ifdef __INVOKE_SIMD_USE_STD_IS_FUNCTION_WA
268 || is_regcall_function_ptr_or_ref_v<Callable>
269 #endif // __INVOKE_SIMD_USE_STD_IS_FUNCTION_WA
292 template <
class Callable,
class... T>
294 Callable &&f, T... args) {
301 constexpr
bool is_function = detail::is_function_ptr_or_ref_v<Callable>;
303 if constexpr (is_function) {
304 return __builtin_invoke_simd<is_function, RetSpmd>(
311 return __builtin_invoke_simd<is_function, RetSpmd>(
312 detail::simd_call_helper<N, Callable, T...>, &f,
317 #ifndef __INVOKE_SIMD_ENABLE_ALL_CALLABLES
318 static_assert(is_function &&
319 "invoke_simd does not support functors or lambdas yet");
320 #endif // __INVOKE_SIMD_ENABLE_ALL_CALLABLES