DPC++ Runtime
Runtime libraries for oneAPI DPC++
plugin.hpp
Go to the documentation of this file.
1 //==------------------------- plugin.hpp - SYCL platform -------------------==//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #pragma once
10 #include <detail/config.hpp>
11 #include <memory>
12 #include <mutex>
13 #include <sycl/backend_types.hpp>
14 #include <sycl/detail/common.hpp>
16 #include <sycl/detail/ur.hpp>
17 
18 #include <ur_api.h>
19 
20 #ifdef XPTI_ENABLE_INSTRUMENTATION
21 // Include the headers necessary for emitting traces using the trace framework
22 #include "xpti/xpti_trace_framework.h"
23 #endif
24 
26 
// Evaluates the UR call expression `expr` exactly once and, if the result is
// not UR_RESULT_SUCCESS, prints __SYCL_UR_ERROR_REPORT followed by the textual
// error code to std::cerr. Unlike plugin::call, it never throws.
//
// The body is wrapped in do { ... } while (0) so that the macro expands to a
// single statement and composes safely with unbraced if/else. The local uses a
// project-prefixed name so it cannot collide with identifiers used inside
// `expr` (e.g. a caller variable named `code`).
#define __SYCL_REPORT_UR_ERR_TO_STREAM(expr)                                   \
  do {                                                                         \
    auto __sycl_ur_res = (expr);                                               \
    if (__sycl_ur_res != UR_RESULT_SUCCESS) {                                  \
      std::cerr << __SYCL_UR_ERROR_REPORT                                      \
                << sycl::detail::codeToString(__sycl_ur_res) << std::endl;     \
    }                                                                          \
  } while (0)

// Legacy spelling kept for existing callers: reports (does not throw) a
// non-success UR result.
#define __SYCL_CHECK_OCL_CODE_NO_EXC(X) __SYCL_REPORT_UR_ERR_TO_STREAM(X)
37 
38 namespace sycl {
39 inline namespace _V1 {
40 namespace detail {
41 
46 class plugin {
47 public:
48  plugin() = delete;
49 
50  plugin(ur_adapter_handle_t adapter, backend UseBackend)
51  : MAdapter(adapter), MBackend(UseBackend),
52  TracingMutex(std::make_shared<std::mutex>()),
53  MPluginMutex(std::make_shared<std::mutex>()) {}
54 
55  // Disallow accidental copies of plugins
56  plugin &operator=(const plugin &) = delete;
57  plugin(const plugin &) = delete;
58  plugin &operator=(plugin &&other) noexcept = delete;
59  plugin(plugin &&other) noexcept = delete;
60 
61  ~plugin() = default;
62 
// NOTE(review): this listing omits original lines 72 and 82-84 (the
// warning-level condition guarding the std::clog output, and the start of the
// error-reporting expression), so the block below is incomplete as rendered
// and is reproduced unchanged.
//
// checkUrResult: turns a UR result code into an error report.
//   errc      - sycl::errc used for the reported error (default: runtime).
//   ur_result - result returned by a UR entry point.
// For UR_RESULT_ERROR_ADAPTER_SPECIFIC, the adapter's last error is fetched
// through urAdapterGetLastError. That call's result replaces ur_result, and
// the adapter message may be printed to std::clog. If the resulting code is
// UR_RESULT_SUCCESS, the adapter error is treated as a warning and the
// function returns. Any other non-success code is reported via an expression
// on the missing lines (presumably a throw of a sycl::exception built from
// errc). The adapter message is appended when one was obtained.
64  template <sycl::errc errc = sycl::errc::runtime>
65  void checkUrResult(ur_result_t ur_result) const {
66  const char *message = nullptr;
67  if (ur_result == UR_RESULT_ERROR_ADAPTER_SPECIFIC) {
68  int32_t adapter_error = 0;
// NOTE(review): when adapterReleased is set, call_nocheck() skips the call
// and returns UR_RESULT_SUCCESS, leaving message == nullptr. Streaming a null
// const char* into std::clog below is undefined behavior; consider guarding
// on message != nullptr.
69  ur_result = call_nocheck(urAdapterGetLastError, MAdapter, &message, &adapter_error);
70
71  // If the warning level is greater than 2, emit the message
73  std::clog << message << std::endl;
74  }
75
76  // If it is a warning, do not throw
77  if (ur_result == UR_RESULT_SUCCESS) {
78  return;
79  }
80  }
81  if (ur_result != UR_RESULT_SUCCESS) {
85  sycl::detail::codeToString(ur_result) +
86  (message ? "\n" + std::string(message) + "\n"
87  : std::string{})),
88  ur_result);
89  }
90  }
91 
92  std::vector<ur_platform_handle_t> &getUrPlatforms() {
93  std::call_once(PlatformsPopulated, [&]() {
94  uint32_t platformCount = 0;
95  call(urPlatformGet, &MAdapter, 1, 0, nullptr, &platformCount);
96  UrPlatforms.resize(platformCount);
97  call(urPlatformGet, &MAdapter, 1, platformCount, UrPlatforms.data(),
98  nullptr);
99  // We need one entry in this per platform
100  LastDeviceIds.resize(platformCount);
101  });
102  return UrPlatforms;
103  }
104 
105  ur_adapter_handle_t getUrAdapter() { return MAdapter; }
106 
117  template <class UrFunc, typename... ArgsT>
118  ur_result_t call_nocheck(UrFunc F, ArgsT... Args) const {
119  ur_result_t R = UR_RESULT_SUCCESS;
120  if (!adapterReleased) {
121  R = F(Args...);
122  }
123  return R;
124  }
125 
129  template <class UrFunc, typename... ArgsT>
130  void call(UrFunc F, ArgsT... Args) const {
131  auto Err = call_nocheck(F, Args...);
132  checkUrResult(Err);
133  }
134 
136  template <sycl::errc errc, class UrFunc, typename... ArgsT>
137  void call(UrFunc F, ArgsT... Args) const {
138  auto Err = call_nocheck(F, Args...);
139  checkUrResult<errc>(Err);
140  }
141 
145  bool hasBackend(backend Backend) const { return Backend == MBackend; }
146 
147  void release() {
148  call(urAdapterRelease, MAdapter);
149  this->adapterReleased = true;
150  }
151 
152  // Return the index of a UR platform.
153  // If not found, add it and return its index.
154  // The function is expected to be called in a thread safe manner.
155  int getPlatformId(ur_platform_handle_t Platform) {
156  auto It = std::find(UrPlatforms.begin(), UrPlatforms.end(), Platform);
157  if (It != UrPlatforms.end())
158  return It - UrPlatforms.begin();
159 
160  UrPlatforms.push_back(Platform);
161  LastDeviceIds.push_back(0);
162  return UrPlatforms.size() - 1;
163  }
164 
165  // Device ids are consecutive across platforms within a plugin.
166  // We need to return the same starting index for the given platform.
167  // So, instead of returing the last device id of the given platform,
168  // return the last device id of the predecessor platform.
169  // The function is expected to be called in a thread safe manner.
170  int getStartingDeviceId(ur_platform_handle_t Platform) {
171  int PlatformId = getPlatformId(Platform);
172  if (PlatformId == 0)
173  return 0;
174  return LastDeviceIds[PlatformId - 1];
175  }
176 
177  // set the id of the last device for the given platform
178  // The function is expected to be called in a thread safe manner.
179  void setLastDeviceId(ur_platform_handle_t Platform, int Id) {
180  int PlatformId = getPlatformId(Platform);
181  LastDeviceIds[PlatformId] = Id;
182  }
183 
184  // Adjust the id of the last device for the given platform.
185  // Involved when there is no device on that platform at all.
186  // The function is expected to be called in a thread safe manner.
187  void adjustLastDeviceId(ur_platform_handle_t Platform) {
188  int PlatformId = getPlatformId(Platform);
189  if (PlatformId > 0 &&
190  LastDeviceIds[PlatformId] < LastDeviceIds[PlatformId - 1])
191  LastDeviceIds[PlatformId] = LastDeviceIds[PlatformId - 1];
192  }
193 
194  bool containsUrPlatform(ur_platform_handle_t Platform) {
195  auto It = std::find(UrPlatforms.begin(), UrPlatforms.end(), Platform);
196  return It != UrPlatforms.end();
197  }
198 
199  std::shared_ptr<std::mutex> getPluginMutex() { return MPluginMutex; }
200  bool adapterReleased = false;
201 
202 private:
// UR adapter handle given to the constructor; passed to urAdapterRelease,
// urAdapterGetLastError and urPlatformGet by the members of this class.
203  ur_adapter_handle_t MAdapter;
// Backend this plugin serves; compared in hasBackend().
204  backend MBackend;
// NOTE(review): allocated in the constructor but not referenced by any code
// visible in this header; confirm whether it is still needed.
205  std::shared_ptr<std::mutex> TracingMutex;
206  // Mutex to guard UrPlatforms and LastDeviceIds.
207  // Note that this is a temporary solution until we implement the global
208  // Device/Platform cache later.
// The plugin itself never locks it; callers obtain it via getPluginMutex().
209  std::shared_ptr<std::mutex> MPluginMutex;
210  // Ensures getUrPlatforms() queries the adapter for platforms only once.
211  std::once_flag PlatformsPopulated;
// UR platforms that belong to this plugin.
212  std::vector<ur_platform_handle_t> UrPlatforms;
213  // Represents the unique ids of the last device of each platform;
214  // the index of this vector corresponds to the index in UrPlatforms.
215  std::vector<int> LastDeviceIds;
216 }; // class plugin
217 
218 using PluginPtr = std::shared_ptr<plugin>;
219 
220 } // namespace detail
221 } // namespace _V1
222 } // namespace sycl
The plugin class provides a unified interface to the underlying low-level runtimes for the device-agnostic SYCL runtime.
Definition: plugin.hpp:46
bool hasBackend(backend Backend) const
Tells whether this plugin can serve the specified backend.
Definition: plugin.hpp:145
plugin & operator=(const plugin &)=delete
ur_result_t call_nocheck(UrFunc F, ArgsT... Args) const
Calls the UR API, traces the call, and returns the result.
Definition: plugin.hpp:118
int getStartingDeviceId(ur_platform_handle_t Platform)
Definition: plugin.hpp:170
plugin(ur_adapter_handle_t adapter, backend UseBackend)
Definition: plugin.hpp:50
std::vector< ur_platform_handle_t > & getUrPlatforms()
Definition: plugin.hpp:92
plugin(plugin &&other) noexcept=delete
int getPlatformId(ur_platform_handle_t Platform)
Definition: plugin.hpp:155
void setLastDeviceId(ur_platform_handle_t Platform, int Id)
Definition: plugin.hpp:179
bool containsUrPlatform(ur_platform_handle_t Platform)
Definition: plugin.hpp:194
plugin(const plugin &)=delete
plugin & operator=(plugin &&other) noexcept=delete
void checkUrResult(ur_result_t ur_result) const
Definition: plugin.hpp:65
std::shared_ptr< std::mutex > getPluginMutex()
Definition: plugin.hpp:199
void call(UrFunc F, ArgsT... Args) const
Definition: plugin.hpp:137
void adjustLastDeviceId(ur_platform_handle_t Platform)
Definition: plugin.hpp:187
void call(UrFunc F, ArgsT... Args) const
Calls the API, traces the call, checks the result.
Definition: plugin.hpp:130
ur_adapter_handle_t getUrAdapter()
Definition: plugin.hpp:105
#define __SYCL_UR_ERROR_REPORT
Definition: common.hpp:164
__SYCL_EXTERN_STREAM_ATTRS ostream clog
Linked to standard error (buffered)
std::string codeToString(int32_t code)
Definition: exception.hpp:57
std::shared_ptr< plugin > PluginPtr
Definition: ur.hpp:60
exception set_ur_error(exception &&e, int32_t ur_err)
Definition: exception.hpp:157
std::error_code make_error_code(sycl::errc E) noexcept
Constructs an error code using E and sycl_category()
Definition: exception.cpp:65
Definition: access.hpp:18
_Abi const simd< _Tp, _Abi > & noexcept
Definition: simd.hpp:1324
C++ utilities for Unified Runtime integration.