#include <unordered_set>

inline namespace _V1 {
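
// Look up the platform_impl for a UR platform handle in the global platform
// cache, creating and caching a new one if it is not present.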
  const std::lock_guard<std::mutex> Guard(
      GlobalHandler::instance().getPlatformMapMutex());
  std::vector<PlatformImplPtr> &PlatformCache =
      GlobalHandler::instance().getPlatformCache();
  for (const auto &PlatImpl : PlatformCache) {
    if (PlatImpl->getHandleRef() == UrPlatform)
      return PlatImpl;
  }
  Result = std::make_shared<platform_impl>(UrPlatform, Plugin);
  PlatformCache.emplace_back(Result);
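
// Map a UR device handle to its owning UR platform via UR_DEVICE_INFO_PLATFORM.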
  ur_platform_handle_t Plt = nullptr;
  Plugin->call<UrApiKind::urDeviceGetInfo>(UrDevice, UR_DEVICE_INFO_PLATFORM,
                                           sizeof(Plt), &Plt, nullptr);
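
// IsBannedPlatform helper: the "NVIDIA CUDA" and "AMD Accelerated Parallel
// Processing" OpenCL platforms are treated as incompatible and filtered out.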
  auto IsMatchingOpenCL = [](platform Platform, const std::string_view name) {
    const bool HasNameMatch = Platform.get_info<info::platform::name>().find(
                                  name) != std::string::npos;
    const auto Backend = detail::getSyclObjImpl(Platform)->getBackend();
    const bool IsMatchingOCL = (HasNameMatch && Backend == backend::opencl);
    if (detail::ur::trace(detail::ur::TraceLevel::TRACE_ALL) && IsMatchingOCL)
      std::cout << name
                << " OpenCL platform found but is not compatible." << std::endl;
    return IsMatchingOCL;
  };
  return IsMatchingOpenCL(Platform, "NVIDIA CUDA") ||
         IsMatchingOpenCL(Platform, "AMD Accelerated Parallel Processing");
}
std::vector<platform> platform_impl::getPluginPlatforms(PluginPtr &Plugin,
                                                        bool Supported) {
  std::vector<platform> Platforms;

  auto UrPlatforms = Plugin->getUrPlatforms();
  if (UrPlatforms.empty())
    return Platforms;

  for (const auto &UrPlatform : UrPlatforms) {
    platform Platform = detail::createSyclObjFromImpl<platform>(
        getOrMakePlatformImpl(UrPlatform, Plugin));
    const bool IsBanned = IsBannedPlatform(Platform);
    const bool HasAnyDevices =
        !Platform.get_devices(info::device_type::all).empty();

    if (!Supported) {
      if (IsBanned || !HasAnyDevices)
        Platforms.push_back(Platform);
    } else if (!IsBanned && HasAnyDevices) {
      Platforms.push_back(Platform);
    }
  }
  return Platforms;
}
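
// Collect, across all plugins, the platforms that are filtered out as
// unsupported (banned or without devices).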
  std::vector<platform> UnsupportedPlatforms;

  std::vector<PluginPtr> &Plugins = ur::initializeUr();
  for (auto &Plugin : Plugins) {
    std::vector<platform> PluginPlatforms =
        getPluginPlatforms(Plugin, /*Supported=*/false);
    std::copy(PluginPlatforms.begin(), PluginPlatforms.end(),
              std::back_inserter(UnsupportedPlatforms));
  }

  return UnsupportedPlatforms;
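
// Collect the supported platforms from every plugin and assign each one a
// platform id under its plugin's mutex.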
  std::vector<std::pair<platform, PluginPtr>> PlatformsWithPlugin;

  std::vector<PluginPtr> &Plugins = ur::initializeUr();
  for (auto &Plugin : Plugins) {
    const auto &PluginPlatforms = getPluginPlatforms(Plugin);
    for (const auto &P : PluginPlatforms) {
      PlatformsWithPlugin.push_back({P, Plugin});
    }
  }

  std::vector<platform> Platforms;
  for (auto &Platform : PlatformsWithPlugin) {
    auto &Plugin = Platform.second;
    std::lock_guard<std::mutex> Guard(*Plugin->getPluginMutex());
    Plugin->getPlatformId(getSyclObjImpl(Platform.first)->getHandleRef());
    Platforms.push_back(Platform.first);
  }

  return Platforms;
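
// filterDeviceFilter applies an ods_target (ONEAPI_DEVICE_SELECTOR) or
// device-filter list to this platform's UR devices, compacting UrDevices in
// place and returning the original indices of the devices that are kept.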
template <typename ListT, typename FilterT>
std::vector<int>
platform_impl::filterDeviceFilter(std::vector<ur_device_handle_t> &UrDevices,
                                  ListT *FilterList) const {
  constexpr bool is_ods_target = std::is_same_v<FilterT, ods_target>;

  if constexpr (is_ods_target) {
    // Sort the filter list so that negative (discard) targets are processed
    // before positive ones.
    std::sort(FilterList->get().begin(), FilterList->get().end(),
              [](const FilterT &filter1, const FilterT &filter2) {
                return filter1.IsNegativeTarget && !filter2.IsNegativeTarget;
              });
  }
  int InsertIDx = 0;
  // Device numbers that a negative ods_target has rejected.
  std::map<int, bool> Blacklist;
  // Original (unfiltered) indices of the devices that pass the filters.
  std::vector<int> original_indices;

  ur_platform_backend_t UrBackend = UR_PLATFORM_BACKEND_UNKNOWN;
  MPlugin->call<UrApiKind::urPlatformGetInfo>(
      MPlatform, UR_PLATFORM_INFO_BACKEND, sizeof(ur_platform_backend_t),
      &UrBackend, nullptr);
  backend Backend = convertUrBackend(UrBackend);

  std::lock_guard<std::mutex> Guard(*MPlugin->getPluginMutex());
  int DeviceNum = MPlugin->getStartingDeviceId(MPlatform);
  for (ur_device_handle_t Device : UrDevices) {
    ur_device_type_t UrDevType = UR_DEVICE_TYPE_ALL;
    MPlugin->call<UrApiKind::urDeviceGetInfo>(Device, UR_DEVICE_INFO_TYPE,
                                              sizeof(ur_device_type_t),
                                              &UrDevType, nullptr);
    info::device_type DeviceType = info::device_type::all;
    switch (UrDevType) {
    case UR_DEVICE_TYPE_ALL: DeviceType = info::device_type::all; break;
    case UR_DEVICE_TYPE_GPU: DeviceType = info::device_type::gpu; break;
    case UR_DEVICE_TYPE_CPU: DeviceType = info::device_type::cpu; break;
    case UR_DEVICE_TYPE_FPGA: DeviceType = info::device_type::accelerator; break;
    }
    for (const FilterT &Filter : FilterList->get()) {
      backend FilterBackend = Filter.Backend.value_or(backend::all);
      if (FilterBackend != Backend && FilterBackend != backend::all)
        continue;
      if (Filter.DeviceNum && DeviceNum != Filter.DeviceNum.value())
        continue;
      info::device_type FilterDevType =
          Filter.DeviceType.value_or(info::device_type::all);
      if (FilterDevType != info::device_type::all &&
          FilterDevType != DeviceType)
        continue;

      if constexpr (is_ods_target) {
        // A device rejected by a negative target stays rejected.
        if (Blacklist[DeviceNum])
          continue;
        if (Filter.IsNegativeTarget) {
          Blacklist[DeviceNum] = true;
          continue;
        }
      }

      UrDevices[InsertIDx++] = Device;
      original_indices.push_back(DeviceNum);
      break;
    }
    ++DeviceNum;
  }
  UrDevices.resize(InsertIDx);
  MPlugin->setLastDeviceId(MPlatform, DeviceNum);
  return original_indices;
}
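
// device_impl objects are cached per platform; lookups and insertions below
// are guarded by MDeviceMapMutex.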
std::shared_ptr<device_impl>
platform_impl::getDeviceImpl(ur_device_handle_t UrDevice) {
  const std::lock_guard<std::mutex> Guard(MDeviceMapMutex);
  return getDeviceImplHelper(UrDevice);
}

std::shared_ptr<device_impl> platform_impl::getOrMakeDeviceImpl(
    ur_device_handle_t UrDevice,
    const std::shared_ptr<platform_impl> &PlatformImpl) {
  const std::lock_guard<std::mutex> Guard(MDeviceMapMutex);
  std::shared_ptr<device_impl> Result = getDeviceImplHelper(UrDevice);
  if (Result)
    return Result;
  Result = std::make_shared<device_impl>(UrDevice, PlatformImpl);
  MDeviceCache.emplace_back(Result);
  return Result;
}
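
// Helpers that check whether a device supports a given partition property and
// affinity domain before any sub-device expansion is attempted.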
static bool supportsAffinityDomain(const device &dev,
                                   info::partition_property partitionProp,
                                   info::partition_affinity_domain domain) {
  auto supported = dev.get_info<info::device::partition_affinity_domains>();
  auto It = std::find(std::begin(supported), std::end(supported), domain);
  return It != std::end(supported);
}

static bool supportsPartitionProperty(const device &dev,
                                      info::partition_property partitionProp) {
  auto supported = dev.get_info<info::device::partition_properties>();
  auto It =
      std::find(std::begin(supported), std::end(supported), partitionProp);
  return It != std::end(supported);
}
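
// amendDeviceAndSubDevices walks the ONEAPI_DEVICE_SELECTOR targets and,
// where a target requests sub-devices or sub-sub-devices, replaces the
// matched root devices with the requested partitions.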
static std::vector<device> amendDeviceAndSubDevices(
    backend PlatformBackend, std::vector<device> &DeviceList,
    ods_target_list *OdsTargetList, const std::vector<int> &original_indices,
    PlatformImplPtr PlatformImpl) {
  constexpr info::partition_property partitionProperty =
      info::partition_property::partition_by_affinity_domain;
  constexpr info::partition_affinity_domain affinityDomain =
      info::partition_affinity_domain::next_partitionable;

  std::vector<device> FinalResult;
  for (unsigned i = 0; i < DeviceList.size(); i++) {
    device &dev = DeviceList[i];
    bool deviceAdded = false;
    for (ods_target target : OdsTargetList->get()) {
      backend TargetBackend = target.Backend.value_or(backend::all);
      // The target must apply to this platform's backend.
      if (PlatformBackend != TargetBackend && TargetBackend != backend::all)
        continue;

      bool deviceMatch = target.HasDeviceWildCard;
      if (target.DeviceType) {
        deviceMatch = ((target.DeviceType == info::device_type::all) ||
                       (dev.get_info<info::device::device_type>() ==
                        target.DeviceType));
      } else if (target.DeviceNum) {
        deviceMatch = (target.DeviceNum.value() == original_indices[i]);
      }
      if (!deviceMatch)
        continue;
      bool wantSubDevice = target.SubDeviceNum || target.HasSubDeviceWildCard;
      bool supportsSubPartitioning =
          (supportsPartitionProperty(dev, partitionProperty) &&
           supportsAffinityDomain(dev, partitionProperty, affinityDomain));
      bool wantSubSubDevice =
          target.SubSubDeviceNum || target.HasSubSubDeviceWildCard;

      if (!wantSubDevice) {
        // -- Add the root device as is.
        FinalResult.push_back(dev);
        deviceAdded = true;
        continue;
      }

      // -- A sub-device (or sub-sub-device) was requested.
      if (!supportsSubPartitioning) {
        continue;
      }
      auto subDevices =
          dev.create_sub_devices<partitionProperty>(affinityDomain);
      if (target.SubDeviceNum) {
        if (subDevices.size() <= target.SubDeviceNum.value()) {
          // The requested sub-device index does not exist on this device.
          continue;
        }
        subDevices[0] = subDevices[target.SubDeviceNum.value()];
        subDevices.resize(1);
      }

      if (!wantSubSubDevice) {
        // -- Add the sub-devices.
        FinalResult.insert(FinalResult.end(), subDevices.begin(),
                           subDevices.end());
        deviceAdded = true;
        continue;
      }
      // -- Sub-sub-devices were requested.
      for (device subDev : subDevices) {
        bool supportsSubSubPartitioning =
            (supportsPartitionProperty(subDev, partitionProperty) &&
             supportsAffinityDomain(subDev, partitionProperty, affinityDomain));
        if (!supportsSubSubPartitioning) {
          if (target.SubDeviceNum) {
            // ...
          }
          continue;
        }
        auto subSubDevices =
            subDev.create_sub_devices<partitionProperty>(affinityDomain);
        if (target.SubSubDeviceNum) {
          if (subSubDevices.size() <= target.SubSubDeviceNum.value()) {
            continue;
          }
          subSubDevices[0] = subSubDevices[target.SubSubDeviceNum.value()];
          subSubDevices.resize(1);
        }
        FinalResult.insert(FinalResult.end(), subSubDevices.begin(),
                           subSubDevices.end());
        deviceAdded = true;
      }
    }
  }
  return FinalResult;
}
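
// get_devices: enumerate this platform's UR devices of the requested type,
// apply the configured device filters, wrap the surviving handles in SYCL
// device objects, and expand any requested sub-devices.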
std::vector<device>
platform_impl::get_devices(info::device_type DeviceType) const {
  std::vector<device> Res;

  ur_device_type_t UrDeviceType = UR_DEVICE_TYPE_ALL;
  switch (DeviceType) {
  default:
  case info::device_type::all: UrDeviceType = UR_DEVICE_TYPE_ALL; break;
  case info::device_type::gpu: UrDeviceType = UR_DEVICE_TYPE_GPU; break;
  case info::device_type::cpu: UrDeviceType = UR_DEVICE_TYPE_CPU; break;
  case info::device_type::accelerator: UrDeviceType = UR_DEVICE_TYPE_FPGA; break;
  }

  uint32_t NumDevices = 0;
  MPlugin->call<UrApiKind::urDeviceGet>(MPlatform, UrDeviceType, 0, nullptr,
                                        &NumDevices);

  if (NumDevices == 0) {
    std::vector<PluginPtr> &Plugins = ur::initializeUr();
    auto It = std::find_if(Plugins.begin(), Plugins.end(),
                           [&Platform = MPlatform](PluginPtr &Plugin) {
                             return Plugin->containsUrPlatform(Platform);
                           });
    if (It != Plugins.end()) {
      PluginPtr &Plugin = *It;
      std::lock_guard<std::mutex> Guard(*Plugin->getPluginMutex());
      Plugin->adjustLastDeviceId(MPlatform);
    }
    return Res;
  }
  std::vector<ur_device_handle_t> UrDevices(NumDevices);
  MPlugin->call<UrApiKind::urDeviceGet>(MPlatform, UrDeviceType, NumDevices,
                                        UrDevices.data(), nullptr);

  // Keep the original handles so they can be released once the device impls
  // have taken their own references.
  std::vector<ur_device_handle_t> UrDevicesToCleanUp = UrDevices;

  // Filter the handles according to ONEAPI_DEVICE_SELECTOR, remembering the
  // original device indices for sub-device amendment.
  std::vector<int> PlatformDeviceIndices;
  if (OdsTargetList) {
    PlatformDeviceIndices = filterDeviceFilter<ods_target_list, ods_target>(
        UrDevices, OdsTargetList);
  }

  // Wrap the surviving handles in SYCL device objects backed by cached impls.
  PlatformImplPtr PlatformImpl = getOrMakePlatformImpl(MPlatform, MPlugin);
  std::transform(
      UrDevices.begin(), UrDevices.end(), std::back_inserter(Res),
      [PlatformImpl](const ur_device_handle_t UrDevice) -> device {
        return detail::createSyclObjFromImpl<device>(
            PlatformImpl->getOrMakeDeviceImpl(UrDevice, PlatformImpl));
      });

  // Release the handles obtained from urDeviceGet; the cached device impls
  // keep their own references.
  for (ur_device_handle_t &UrDev : UrDevicesToCleanUp)
    MPlugin->call<UrApiKind::urDeviceRelease>(UrDev);

  if (!OdsTargetList || Res.size() == 0)
    return Res;

  return amendDeviceAndSubDevices(getBackend(), Res, OdsTargetList,
                                  PlatformDeviceIndices, PlatformImpl);
}
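
// has_extension does a simple substring search over the platform's extension
// string.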
  return (AllExtensionNames.find(ExtensionName) != std::string::npos);
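
// getNative: return the backend-native handle underlying this platform.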
  ur_native_handle_t Handle = 0;
  Plugin->call<UrApiKind::urPlatformGetNativeHandle>(getHandleRef(), &Handle);
  return Handle;
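
// get_backend_info specializations: each descriptor below is only legal for a
// particular backend, so the platform's backend is validated before the query
// is forwarded.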
template <typename Param>
typename Param::return_type platform_impl::get_backend_info() const;

template <>
typename info::platform::version::return_type
platform_impl::get_backend_info<info::platform::version>() const {
  if (getBackend() != backend::opencl) {
    throw sycl::exception(errc::backend_mismatch,
                          "the info::platform::version info descriptor can "
                          "only be queried with an OpenCL backend");
  }
  return get_info<info::platform::version>();
}
device select_device(const DSelectorInvocableType &DeviceSelectorInvocable,
                     std::vector<device> &Devices);

template <>
typename info::device::version::return_type
platform_impl::get_backend_info<info::device::version>() const {
  if (getBackend() != backend::opencl) {
    throw sycl::exception(errc::backend_mismatch,
                          "the info::device::version info descriptor can only "
                          "be queried with an OpenCL backend");
  }
  auto Devices = get_devices();
  if (Devices.empty()) {
    return "No available device";
  }
  return select_device(default_selector_v, Devices)
      .get_info<info::device::version>();
}
template <>
typename info::device::backend_version::return_type
platform_impl::get_backend_info<info::device::backend_version>() const {
  if (getBackend() != backend::ext_oneapi_level_zero) {
    throw sycl::exception(errc::backend_mismatch,
                          "the info::device::backend_version info descriptor "
                          "can only be queried with a Level Zero backend");
  }
  return "";
}
  if (dev.has(Aspect) == false) {
    return false;
  }
std::shared_ptr<device_impl>
platform_impl::getDeviceImplHelper(ur_device_handle_t UrDevice) {
  for (const std::weak_ptr<device_impl> &DeviceWP : MDeviceCache) {
    if (std::shared_ptr<device_impl> Device = DeviceWP.lock()) {
      if (Device->getHandleRef() == UrDevice)
        return Device;
    }
  }
  return nullptr;
}
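
// Explicitly instantiate platform_impl::get_info for every platform info
// descriptor listed in platform_traits.def.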
#define __SYCL_PARAM_TRAITS_SPEC(DescType, Desc, ReturnT, PiCode)             \
  template ReturnT platform_impl::get_info<info::platform::Desc>() const;

#include <sycl/info/platform_traits.def>
#undef __SYCL_PARAM_TRAITS_SPEC