24 #include <unordered_set>
28 inline namespace _V1 {
44 const std::lock_guard<std::mutex> Guard(
47 std::vector<PlatformImplPtr> &PlatformCache =
51 for (
const auto &PlatImpl : PlatformCache) {
57 Result = std::make_shared<platform_impl>(
PiPlatform, Plugin);
58 PlatformCache.emplace_back(Result);
71 sizeof(Plt), &Plt,
nullptr);
87 auto IsMatchingOpenCL = [](
platform Platform,
const std::string_view name) {
91 const bool HasNameMatch = Platform.
get_info<info::platform::name>().find(
92 name) != std::string::npos;
94 const bool IsMatchingOCL = (HasNameMatch && Backend ==
backend::opencl);
97 std::cout <<
"SYCL_PI_TRACE[all]: " << name
98 <<
" OpenCL platform found but is not compatible." << std::endl;
100 return IsMatchingOCL;
102 return IsMatchingOpenCL(Platform,
"NVIDIA CUDA") ||
103 IsMatchingOpenCL(Platform,
"AMD Accelerated Parallel Processing");
111 auto getPluginPlatforms = [](
PluginPtr &Plugin) {
112 std::vector<platform> Platforms;
115 0,
nullptr, &NumPlatforms) != PI_SUCCESS)
119 std::vector<sycl::detail::pi::PiPlatform> PiPlatforms(NumPlatforms);
121 NumPlatforms, PiPlatforms.data(),
nullptr) != PI_SUCCESS)
125 platform Platform = detail::createSyclObjFromImpl<platform>(
136 Platforms.push_back(Platform);
143 static const bool PreferUR = [] {
144 const char *PreferURStr = std::getenv(
"SYCL_PREFER_UR");
145 return (PreferURStr && (std::stoi(PreferURStr) != 0));
151 std::vector<std::pair<platform, PluginPtr>> PlatformsWithPlugin;
155 std::unordered_set<backend> BackendsUR;
165 for (
const auto &P : getPluginPlatforms(*PluginUR)) {
166 PlatformsWithPlugin.push_back({P, *PluginUR});
173 for (
auto &Plugin : Plugins) {
177 const auto &PluginPlatforms = getPluginPlatforms(Plugin);
178 for (
const auto &P : PluginPlatforms) {
182 PlatformsWithPlugin.push_back({P, Plugin});
188 std::vector<platform> Platforms;
189 for (
auto &Platform : PlatformsWithPlugin) {
190 auto &Plugin = Platform.second;
191 std::lock_guard<std::mutex> Guard(*Plugin->getPluginMutex());
192 Plugin->getPlatformId(
getSyclObjImpl(Platform.first)->getHandleRef());
193 Platforms.push_back(Platform.first);
219 template <
typename ListT,
typename FilterT>
220 std::vector<int> platform_impl::filterDeviceFilter(
221 std::vector<sycl::detail::pi::PiDevice> &PiDevices,
222 ListT *FilterList)
const {
224 constexpr
bool is_ods_target = std::is_same_v<FilterT, ods_target>;
229 if constexpr (is_ods_target) {
237 std::sort(FilterList->get().begin(), FilterList->get().end(),
239 return filter1.IsNegativeTarget && !filter2.IsNegativeTarget;
247 std::map<int, bool> Blacklist;
250 std::vector<int> original_indices;
262 std::lock_guard<std::mutex> Guard(*MPlugin->getPluginMutex());
263 int DeviceNum = MPlugin->getStartingDeviceId(MPlatform);
268 &PiDevType,
nullptr);
273 for (
const FilterT &Filter : FilterList->get()) {
276 if (FilterBackend == Backend || FilterBackend ==
backend::all) {
282 if (!Filter.DeviceNum || DeviceNum == Filter.DeviceNum.value()) {
283 if constexpr (is_ods_target) {
284 if (!Blacklist[DeviceNum]) {
285 if (!Filter.IsNegativeTarget) {
286 PiDevices[InsertIDx++] = Device;
287 original_indices.push_back(DeviceNum);
291 Blacklist[DeviceNum] =
true;
295 PiDevices[InsertIDx++] = Device;
296 original_indices.push_back(DeviceNum);
301 }
else if (FilterDevType == DeviceType) {
302 if (!Filter.DeviceNum || DeviceNum == Filter.DeviceNum.value()) {
303 if constexpr (is_ods_target) {
304 if (!Blacklist[DeviceNum]) {
305 if (!Filter.IsNegativeTarget) {
306 PiDevices[InsertIDx++] = Device;
307 original_indices.push_back(DeviceNum);
311 Blacklist[DeviceNum] =
true;
315 PiDevices[InsertIDx++] = Device;
316 original_indices.push_back(DeviceNum);
325 PiDevices.resize(InsertIDx);
329 MPlugin->setLastDeviceId(MPlatform, DeviceNum);
330 return original_indices;
333 std::shared_ptr<device_impl>
335 const std::lock_guard<std::mutex> Guard(MDeviceMapMutex);
336 return getDeviceImplHelper(
PiDevice);
341 const std::shared_ptr<platform_impl> &PlatformImpl) {
342 const std::lock_guard<std::mutex> Guard(MDeviceMapMutex);
344 std::shared_ptr<device_impl> Result = getDeviceImplHelper(
PiDevice);
349 Result = std::make_shared<device_impl>(
PiDevice, PlatformImpl);
350 MDeviceCache.emplace_back(Result);
361 auto supported = dev.
get_info<info::device::partition_affinity_domains>();
362 auto It = std::find(std::begin(supported), std::end(supported), domain);
363 return It != std::end(supported);
368 auto supported = dev.
get_info<info::device::partition_properties>();
370 std::find(std::begin(supported), std::end(supported), partitionProp);
371 return It != std::end(supported);
375 backend PlatformBackend, std::vector<device> &DeviceList,
376 ods_target_list *OdsTargetList,
const std::vector<int> &original_indices,
383 std::vector<device> FinalResult;
388 for (
unsigned i = 0; i < DeviceList.size(); i++) {
391 device &dev = DeviceList[i];
392 bool deviceAdded =
false;
395 if (PlatformBackend == TargetBackend || TargetBackend ==
backend::all) {
396 bool deviceMatch =
target.HasDeviceWildCard;
402 }
else if (
target.DeviceNum) {
403 deviceMatch = (
target.DeviceNum.value() == original_indices[i]);
410 bool supportsSubPartitioning =
413 bool wantSubSubDevice =
414 target.SubSubDeviceNum ||
target.HasSubSubDeviceWildCard;
417 if (!wantSubDevice) {
419 FinalResult.push_back(dev);
423 if (!supportsSubPartitioning) {
435 if (wantSubSubDevice) {
437 auto subDevicesToPartition =
439 if (
target.SubDeviceNum) {
440 if (subDevicesToPartition.size() >
441 target.SubDeviceNum.value()) {
442 subDevicesToPartition[0] =
443 subDevicesToPartition[
target.SubDeviceNum.value()];
444 subDevicesToPartition.resize(1);
451 for (
device subDev : subDevicesToPartition) {
452 bool supportsSubSubPartitioning =
456 if (!supportsSubSubPartitioning) {
457 if (
target.SubDeviceNum) {
467 subDev.create_sub_devices<partitionProperty>(
469 if (
target.HasSubSubDeviceWildCard) {
470 FinalResult.insert(FinalResult.end(), subSubDevices.begin(),
471 subSubDevices.end());
473 if (subSubDevices.size() >
target.SubSubDeviceNum.value()) {
474 FinalResult.push_back(
475 subSubDevices[
target.SubSubDeviceNum.value()]);
478 <<
"sub-sub-device index out of bounds: " <<
target
483 }
else if (wantSubDevice) {
487 if (
target.HasSubDeviceWildCard) {
488 FinalResult.insert(FinalResult.end(), subDevices.begin(),
491 if (subDevices.size() >
target.SubDeviceNum.value()) {
492 FinalResult.push_back(
493 subDevices[
target.SubDeviceNum.value()]);
510 std::vector<device> Res;
529 MPlatform, pi::cast<sycl::detail::pi::PiDeviceType>(DeviceType),
531 pi::cast<sycl::detail::pi::PiDevice *>(
nullptr), &NumDevices);
534 if (NumDevices == 0) {
541 auto It = std::find_if(Plugins.begin(), Plugins.end(),
542 [&Platform = MPlatform](
PluginPtr &Plugin) {
543 return Plugin->containsPiPlatform(Platform);
545 if (It != Plugins.end()) {
547 std::lock_guard<std::mutex> Guard(*Plugin->getPluginMutex());
548 Plugin->adjustLastDeviceId(MPlatform);
553 std::vector<sycl::detail::pi::PiDevice> PiDevices(NumDevices);
557 pi::cast<sycl::detail::pi::PiDeviceType>(
559 NumDevices, PiDevices.data(),
nullptr);
563 std::vector<sycl::detail::pi::PiDevice> PiDevicesToCleanUp = PiDevices;
572 std::vector<int> PlatformDeviceIndices;
576 "ONEAPI_DEVICE_SELECTOR cannot be used in "
577 "conjunction with SYCL_DEVICE_FILTER");
579 PlatformDeviceIndices = filterDeviceFilter<ods_target_list, ods_target>(
580 PiDevices, OdsTargetList);
581 }
else if (FilterList) {
582 PlatformDeviceIndices =
583 filterDeviceFilter<device_filter_list, device_filter>(PiDevices,
591 PiDevices.begin(), PiDevices.end(), std::back_inserter(Res),
593 return detail::createSyclObjFromImpl<device>(
594 PlatformImpl->getOrMakeDeviceImpl(PiDevice, PlatformImpl));
605 if (!OdsTargetList || Res.size() == 0)
611 PlatformDeviceIndices, PlatformImpl);
621 return (AllExtensionNames.find(ExtensionName) != std::string::npos);
632 template <
typename Param>
635 return get_platform_info_host<Param>();
643 if (dev.has(Aspect) ==
false) {
650 std::shared_ptr<device_impl>
652 for (
const std::weak_ptr<device_impl> &DeviceWP : MDeviceCache) {
653 if (std::shared_ptr<device_impl> Device = DeviceWP.lock()) {
654 if (Device->getHandleRef() ==
PiDevice)
661 #define __SYCL_PARAM_TRAITS_SPEC(DescType, Desc, ReturnT, PiCode) \
662 template ReturnT platform_impl::get_info<info::platform::Desc>() const;
664 #include <sycl/info/platform_traits.def>
665 #undef __SYCL_PARAM_TRAITS_SPEC