DPC++ Runtime
Runtime libraries for oneAPI DPC++
kernel_bundle_impl.hpp
Go to the documentation of this file.
1 //==------- kernel_bundle_impl.hpp - SYCL kernel_bundle_impl ---------------==//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #pragma once
10 
14 #include <detail/kernel_impl.hpp>
16 #include <sycl/backend_types.hpp>
17 #include <sycl/context.hpp>
18 #include <sycl/detail/common.hpp>
19 #include <sycl/detail/pi.h>
20 #include <sycl/device.hpp>
21 #include <sycl/kernel_bundle.hpp>
22 
23 #include <algorithm>
24 #include <cassert>
25 #include <cstdint>
26 #include <cstring>
27 #include <memory>
28 #include <vector>
29 
30 #include "split_string.hpp"
31 
32 namespace sycl {
33 inline namespace _V1 {
34 namespace detail {
35 
36 static bool checkAllDevicesAreInContext(const std::vector<device> &Devices,
37  const context &Context) {
38  return std::all_of(
39  Devices.begin(), Devices.end(), [&Context](const device &Dev) {
40  return getSyclObjImpl(Context)->isDeviceValid(getSyclObjImpl(Dev));
41  });
42 }
43 
44 static bool checkAllDevicesHaveAspect(const std::vector<device> &Devices,
45  aspect Aspect) {
46  return std::all_of(Devices.begin(), Devices.end(),
47  [&Aspect](const device &Dev) { return Dev.has(Aspect); });
48 }
49 
50 namespace syclex = sycl::ext::oneapi::experimental;
51 
52 class kernel_impl;
53 
55 // It provides access to, and utilities for managing, a set of
56 // sycl::device_image objects.
58 
59  using SpecConstMapT = std::map<std::string, std::vector<unsigned char>>;
60 
61  void common_ctor_checks(bundle_state State) {
62  const bool AllDevicesInTheContext =
63  checkAllDevicesAreInContext(MDevices, MContext);
64  if (MDevices.empty() || !AllDevicesInTheContext)
65  throw sycl::exception(
67  "Not all devices are associated with the context or "
68  "vector of devices is empty");
69 
70  if (bundle_state::input == State &&
71  !checkAllDevicesHaveAspect(MDevices, aspect::online_compiler))
73  "Not all devices have aspect::online_compiler");
74 
75  if (bundle_state::object == State &&
76  !checkAllDevicesHaveAspect(MDevices, aspect::online_linker))
78  "Not all devices have aspect::online_linker");
79  }
80 
81 public:
82  kernel_bundle_impl(context Ctx, std::vector<device> Devs, bundle_state State)
83  : MContext(std::move(Ctx)), MDevices(std::move(Devs)), MState(State) {
84 
85  common_ctor_checks(State);
86 
88  MContext, MDevices, State);
89  }
90 
91  // Interop constructor used by make_kernel
92  kernel_bundle_impl(context Ctx, std::vector<device> Devs)
93  : MContext(Ctx), MDevices(Devs), MState(bundle_state::executable) {
94  if (!checkAllDevicesAreInContext(Devs, Ctx))
95  throw sycl::exception(
97  "Not all devices are associated with the context or "
98  "vector of devices is empty");
99  MIsInterop = true;
100  }
101 
102  // Interop constructor
103  kernel_bundle_impl(context Ctx, std::vector<device> Devs,
104  device_image_plain &DevImage)
105  : kernel_bundle_impl(Ctx, Devs) {
106  MDeviceImages.push_back(DevImage);
107  }
108 
109  // Matches sycl::build and sycl::compile
110  // Have one constructor because sycl::build and sycl::compile have the same
111  // signature
113  std::vector<device> Devs, const property_list &PropList,
114  bundle_state TargetState)
115  : MContext(InputBundle.get_context()), MDevices(std::move(Devs)),
116  MState(TargetState) {
117 
118  MSpecConstValues = getSyclObjImpl(InputBundle)->get_spec_const_map_ref();
119 
120  const std::vector<device> &InputBundleDevices =
121  getSyclObjImpl(InputBundle)->get_devices();
122  const bool AllDevsAssociatedWithInputBundle =
123  std::all_of(MDevices.begin(), MDevices.end(),
124  [&InputBundleDevices](const device &Dev) {
125  return InputBundleDevices.end() !=
126  std::find(InputBundleDevices.begin(),
127  InputBundleDevices.end(), Dev);
128  });
129  if (MDevices.empty() || !AllDevsAssociatedWithInputBundle)
130  throw sycl::exception(
132  "Not all devices are in the set of associated "
133  "devices for input bundle or vector of devices is empty");
134 
135  for (const device_image_plain &DeviceImage : InputBundle) {
136  // Skip images which are not compatible with devices provided
137  if (std::none_of(
138  MDevices.begin(), MDevices.end(),
139  [&DeviceImage](const device &Dev) {
140  return getSyclObjImpl(DeviceImage)->compatible_with_device(Dev);
141  }))
142  continue;
143 
144  switch (TargetState) {
146  MDeviceImages.push_back(detail::ProgramManager::getInstance().compile(
147  DeviceImage, MDevices, PropList));
148  break;
150  MDeviceImages.push_back(detail::ProgramManager::getInstance().build(
151  DeviceImage, MDevices, PropList));
152  break;
153  case bundle_state::input:
156  "Internal error. The target state should not be input "
157  "or ext_oneapi_source");
158  break;
159  }
160  }
161  }
162 
163  // Matches sycl::link
165  const std::vector<kernel_bundle<bundle_state::object>> &ObjectBundles,
166  std::vector<device> Devs, const property_list &PropList)
167  : MDevices(std::move(Devs)), MState(bundle_state::executable) {
168 
169  if (MDevices.empty())
171  "Vector of devices is empty");
172 
173  if (ObjectBundles.empty())
174  return;
175 
176  MContext = ObjectBundles[0].get_context();
177  for (size_t I = 1; I < ObjectBundles.size(); ++I) {
178  if (ObjectBundles[I].get_context() != MContext)
179  throw sycl::exception(
181  "Not all input bundles have the same associated context");
182  }
183 
184  // Check if any of the devices in devs are not in the set of associated
185  // devices for any of the bundles in ObjectBundles
186  const bool AllDevsAssociatedWithInputBundles = std::all_of(
187  MDevices.begin(), MDevices.end(), [&ObjectBundles](const device &Dev) {
188  // Number of devices is expected to be small
189  return std::all_of(
190  ObjectBundles.begin(), ObjectBundles.end(),
191  [&Dev](const kernel_bundle<bundle_state::object> &KernelBundle) {
192  const std::vector<device> &BundleDevices =
193  getSyclObjImpl(KernelBundle)->get_devices();
194  return BundleDevices.end() != std::find(BundleDevices.begin(),
195  BundleDevices.end(),
196  Dev);
197  });
198  });
199  if (!AllDevsAssociatedWithInputBundles)
201  "Not all devices are in the set of associated "
202  "devices for input bundles");
203 
204 // TODO: Unify with c'tor for sycl::compile and sycl::build by calling
205  // sycl::join on vector of kernel_bundles
206 
207  // The loop below just links each device image separately, not linking any
208  // two device images together. This is correct so long as each device image
209  // has no unresolved symbols. That's the case when device images are created
210  // from generic SYCL APIs. There's no way in generic SYCL to create a kernel
211  // which references an undefined symbol. If we decide in the future to allow
212  // a backend interop API to create a "sycl::kernel_bundle" that references
213  // undefined symbols, then the logic in this loop will need to be changed.
214  for (const kernel_bundle<bundle_state::object> &ObjectBundle :
215  ObjectBundles) {
216  for (const device_image_plain &DeviceImage : ObjectBundle) {
217 
218  // Skip images which are not compatible with devices provided
219  if (std::none_of(MDevices.begin(), MDevices.end(),
220  [&DeviceImage](const device &Dev) {
221  return getSyclObjImpl(DeviceImage)
222  ->compatible_with_device(Dev);
223  }))
224  continue;
225 
226  std::vector<device_image_plain> LinkedResults =
227  detail::ProgramManager::getInstance().link(DeviceImage, MDevices,
228  PropList);
229  MDeviceImages.insert(MDeviceImages.end(), LinkedResults.begin(),
230  LinkedResults.end());
231  }
232  }
233 
234  for (const kernel_bundle<bundle_state::object> &Bundle : ObjectBundles) {
235  const KernelBundleImplPtr BundlePtr = getSyclObjImpl(Bundle);
236  for (const std::pair<const std::string, std::vector<unsigned char>>
237  &SpecConst : BundlePtr->MSpecConstValues) {
238  MSpecConstValues[SpecConst.first] = SpecConst.second;
239  }
240  }
241  }
242 
243  kernel_bundle_impl(context Ctx, std::vector<device> Devs,
244  const std::vector<kernel_id> &KernelIDs,
245  bundle_state State)
246  : MContext(std::move(Ctx)), MDevices(std::move(Devs)), MState(State) {
247 
248  common_ctor_checks(State);
249 
251  MContext, MDevices, KernelIDs, State);
252  }
253 
254  kernel_bundle_impl(context Ctx, std::vector<device> Devs,
255  const DevImgSelectorImpl &Selector, bundle_state State)
256  : MContext(std::move(Ctx)), MDevices(std::move(Devs)), MState(State) {
257 
258  common_ctor_checks(State);
259 
261  MContext, MDevices, Selector, State);
262  }
263 
264  // C'tor matches sycl::join API
265  kernel_bundle_impl(const std::vector<detail::KernelBundleImplPtr> &Bundles,
266  bundle_state State)
267  : MState(State) {
268  if (Bundles.empty())
269  return;
270 
271  MContext = Bundles[0]->MContext;
272  MDevices = Bundles[0]->MDevices;
273  for (size_t I = 1; I < Bundles.size(); ++I) {
274  if (Bundles[I]->MContext != MContext)
275  throw sycl::exception(
277  "Not all input bundles have the same associated context.");
278  if (Bundles[I]->MDevices != MDevices)
279  throw sycl::exception(
281  "Not all input bundles have the same set of associated devices.");
282  }
283 
284  for (const detail::KernelBundleImplPtr &Bundle : Bundles) {
285 
286  MDeviceImages.insert(MDeviceImages.end(), Bundle->MDeviceImages.begin(),
287  Bundle->MDeviceImages.end());
288  }
289 
290  std::sort(MDeviceImages.begin(), MDeviceImages.end(),
292 
293  if (get_bundle_state() == bundle_state::input) {
294  // Copy spec constants values from the device images to be removed.
295  auto MergeSpecConstants = [this](const device_image_plain &Img) {
296  const detail::DeviceImageImplPtr &ImgImpl = getSyclObjImpl(Img);
297  const std::map<std::string,
298  std::vector<device_image_impl::SpecConstDescT>>
299  &SpecConsts = ImgImpl->get_spec_const_data_ref();
300  const std::vector<unsigned char> &Blob =
301  ImgImpl->get_spec_const_blob_ref();
302  for (const std::pair<const std::string,
303  std::vector<device_image_impl::SpecConstDescT>>
304  &SpecConst : SpecConsts) {
305  if (SpecConst.second.front().IsSet)
306  set_specialization_constant_raw_value(
307  SpecConst.first.c_str(),
308  Blob.data() + SpecConst.second.front().BlobOffset,
309  SpecConst.second.back().CompositeOffset +
310  SpecConst.second.back().Size);
311  }
312  };
313  std::for_each(MDeviceImages.begin(), MDeviceImages.end(),
314  MergeSpecConstants);
315  }
316 
317  const auto DevImgIt =
318  std::unique(MDeviceImages.begin(), MDeviceImages.end());
319 
320  // Remove duplicate device images.
321  MDeviceImages.erase(DevImgIt, MDeviceImages.end());
322 
323  for (const detail::KernelBundleImplPtr &Bundle : Bundles) {
324  for (const std::pair<const std::string, std::vector<unsigned char>>
325  &SpecConst : Bundle->MSpecConstValues) {
326  set_specialization_constant_raw_value(SpecConst.first.c_str(),
327  SpecConst.second.data(),
328  SpecConst.second.size());
329  }
330  }
331  }
332 
334  std::vector<std::pair<std::string /* name */, std::string /* content */>>;
335  // oneapi_ext_kernel_compiler
336  // construct from source string
338  const std::string &Src, include_pairs_t IncludePairsVec)
339  : MContext(Context), MDevices(Context.get_devices()),
340  MState(bundle_state::ext_oneapi_source), Language(Lang), Source(Src),
341  IncludePairs(IncludePairsVec) {}
342 
343  // oneapi_ext_kernel_compiler
344  // construct from source bytes
346  const std::vector<std::byte> &Bytes)
347  : MContext(Context), MDevices(Context.get_devices()),
348  MState(bundle_state::ext_oneapi_source), Language(Lang), Source(Bytes) {
349  }
350 
351  // oneapi_ext_kernel_compiler
352  // interop constructor
353  kernel_bundle_impl(context Ctx, std::vector<device> Devs,
354  device_image_plain &DevImage,
355  std::vector<std::string> KNames,
357  : kernel_bundle_impl(Ctx, Devs, DevImage) {
358  MState = bundle_state::executable;
359  KernelNames = KNames;
360  Language = Lang;
361  }
362 
363  std::string trimXsFlags(std::string &str) {
364  // Trim first and last quote if they exist, but no others.
365  char EncounteredQuote = '\0';
366  auto Start = std::find_if(str.begin(), str.end(), [&](char c) {
367  if (!EncounteredQuote && (c == '\'' || c == '"')) {
368  EncounteredQuote = c;
369  return false;
370  }
371  return !std::isspace(c);
372  });
373  auto End = std::find_if(str.rbegin(), str.rend(), [&](char c) {
374  if (c == EncounteredQuote) {
375  EncounteredQuote = '\0';
376  return false;
377  }
378  return !std::isspace(c);
379  }).base();
380  if (Start != std::end(str) && End != std::begin(str) && Start < End) {
381  return std::string(Start, End);
382  }
383 
384  return "";
385  }
386 
387  std::string extractXsFlags(const std::vector<std::string> &BuildOptions) {
388  std::stringstream SS;
389  for (std::string Option : BuildOptions) {
390  auto Where = Option.find("-Xs");
391  if (Where != std::string::npos) {
392  Where += 3;
393  std::string Flags = Option.substr(Where);
394  SS << trimXsFlags(Flags) << " ";
395  }
396  }
397  return SS.str();
398  }
399 
400  std::shared_ptr<kernel_bundle_impl>
401  build_from_source(const std::vector<device> Devices,
402  const std::vector<std::string> &BuildOptions,
403  std::string *LogPtr,
404  const std::vector<std::string> &RegisteredKernelNames) {
405  assert(MState == bundle_state::ext_oneapi_source &&
406  "bundle_state::ext_oneapi_source required");
407 
408  using ContextImplPtr = std::shared_ptr<sycl::detail::context_impl>;
409  ContextImplPtr ContextImpl = getSyclObjImpl(MContext);
410  const PluginPtr &Plugin = ContextImpl->getPlugin();
411 
412  std::vector<pi::PiDevice> DeviceVec;
413  DeviceVec.reserve(Devices.size());
414  for (const auto &SyclDev : Devices) {
415  pi::PiDevice Dev = getSyclObjImpl(SyclDev)->getHandleRef();
416  DeviceVec.push_back(Dev);
417  }
418 
419  const auto spirv = [&]() -> std::vector<uint8_t> {
420  if (Language == syclex::source_language::opencl) {
421  // if successful, the log is empty. if failed, throws an error with the
422  // compilation log.
423  const auto &SourceStr = std::get<std::string>(this->Source);
424  std::vector<uint32_t> IPVersionVec(Devices.size());
425  std::transform(DeviceVec.begin(), DeviceVec.end(), IPVersionVec.begin(),
426  [&](pi::PiDevice d) {
427  uint32_t ipVersion = 0;
428  Plugin->call<PiApiKind::piDeviceGetInfo>(
429  d, PI_EXT_ONEAPI_DEVICE_INFO_IP_VERSION,
430  sizeof(uint32_t), &ipVersion, nullptr);
431  return ipVersion;
432  });
433  return syclex::detail::OpenCLC_to_SPIRV(SourceStr, IPVersionVec,
434  BuildOptions, LogPtr);
435  }
436  if (Language == syclex::source_language::spirv) {
437  const auto &SourceBytes =
438  std::get<std::vector<std::byte>>(this->Source);
439  std::vector<uint8_t> Result(SourceBytes.size());
440  std::transform(SourceBytes.cbegin(), SourceBytes.cend(), Result.begin(),
441  [](std::byte B) { return static_cast<uint8_t>(B); });
442  return Result;
443  }
444  if (Language == syclex::source_language::sycl) {
445  const auto &SourceStr = std::get<std::string>(this->Source);
446  return syclex::detail::SYCL_to_SPIRV(SourceStr, IncludePairs,
447  BuildOptions, LogPtr,
449  }
450  throw sycl::exception(
451  make_error_code(errc::invalid),
452  "OpenCL C and SPIR-V are the only supported languages at this time");
453  }();
454 
456  Plugin->call<PiApiKind::piProgramCreate>(
457  ContextImpl->getHandleRef(), spirv.data(), spirv.size(), &PiProgram);
458  // program created by piProgramCreate is implicitly retained.
459 
460  std::string XsFlags = extractXsFlags(BuildOptions);
461  Plugin->call<errc::build, PiApiKind::piProgramBuild>(
462  PiProgram, DeviceVec.size(), DeviceVec.data(), XsFlags.c_str(), nullptr,
463  nullptr);
464 
465  // Get the number of kernels in the program.
466  size_t NumKernels;
467  Plugin->call<PiApiKind::piProgramGetInfo>(
468  PiProgram, PI_PROGRAM_INFO_NUM_KERNELS, sizeof(size_t), &NumKernels,
469  nullptr);
470 
471  // Get the kernel names.
472  size_t KernelNamesSize;
473  Plugin->call<PiApiKind::piProgramGetInfo>(
474  PiProgram, PI_PROGRAM_INFO_KERNEL_NAMES, 0, nullptr, &KernelNamesSize);
475 
476  // semi-colon delimited list of kernel names.
477  std::string KernelNamesStr(KernelNamesSize, ' ');
478  Plugin->call<PiApiKind::piProgramGetInfo>(
479  PiProgram, PI_PROGRAM_INFO_KERNEL_NAMES, KernelNamesStr.size(),
480  &KernelNamesStr[0], nullptr);
481  std::vector<std::string> KernelNames =
482  detail::split_string(KernelNamesStr, ';');
483 
484  // make the device image and the kernel_bundle_impl
485  auto KernelIDs = std::make_shared<std::vector<kernel_id>>();
486  auto DevImgImpl = std::make_shared<device_image_impl>(
487  nullptr, MContext, MDevices, bundle_state::executable, KernelIDs,
488  PiProgram);
489  device_image_plain DevImg{DevImgImpl};
490  return std::make_shared<kernel_bundle_impl>(MContext, MDevices, DevImg,
491  KernelNames, Language);
492  }
493 
494  std::string adjust_kernel_name(const std::string &Name,
496  // Once name demangling support is in, we won't need this.
497  if (Lang != syclex::source_language::sycl)
498  return Name;
499 
500  bool isMangled = Name.find("__sycl_kernel_") != std::string::npos;
501  return isMangled ? Name : "__sycl_kernel_" + Name;
502  }
503 
504  bool ext_oneapi_has_kernel(const std::string &Name) {
505  auto it = std::find(KernelNames.begin(), KernelNames.end(),
506  adjust_kernel_name(Name, Language));
507  return it != KernelNames.end();
508  }
509 
510  kernel
511  ext_oneapi_get_kernel(const std::string &Name,
512  const std::shared_ptr<kernel_bundle_impl> &Self) {
513  if (KernelNames.empty())
514  throw sycl::exception(make_error_code(errc::invalid),
515  "'ext_oneapi_get_kernel' is only available in "
516  "kernel_bundles successfully built from "
517  "kernel_bundle<bundle_state:ext_oneapi_source>.");
518 
519  std::string AdjustedName = adjust_kernel_name(Name, Language);
520  if (!ext_oneapi_has_kernel(Name))
521  throw sycl::exception(make_error_code(errc::invalid),
522  "kernel '" + AdjustedName +
523  "' not found in kernel_bundle");
524 
525  assert(MDeviceImages.size() > 0);
526  const std::shared_ptr<detail::device_image_impl> &DeviceImageImpl =
527  detail::getSyclObjImpl(MDeviceImages[0]);
528  sycl::detail::pi::PiProgram PiProgram = DeviceImageImpl->get_program_ref();
529  ContextImplPtr ContextImpl = getSyclObjImpl(MContext);
530  const PluginPtr &Plugin = ContextImpl->getPlugin();
532  Plugin->call<PiApiKind::piKernelCreate>(PiProgram, AdjustedName.c_str(),
533  &PiKernel);
534  // Kernel created by piKernelCreate is implicitly retained.
535 
536  std::shared_ptr<kernel_impl> KernelImpl = std::make_shared<kernel_impl>(
537  PiKernel, detail::getSyclObjImpl(MContext), Self);
538 
539  return detail::createSyclObjFromImpl<kernel>(KernelImpl);
540  }
541 
542  bool empty() const noexcept { return MDeviceImages.empty(); }
543 
545  return MContext.get_platform().get_backend();
546  }
547 
548  context get_context() const noexcept { return MContext; }
549 
550  const std::vector<device> &get_devices() const noexcept { return MDevices; }
551 
552  std::vector<kernel_id> get_kernel_ids() const {
553  // Collect kernel ids from all device images, then remove duplicates
554 
555  std::vector<kernel_id> Result;
556  for (const device_image_plain &DeviceImage : MDeviceImages) {
557  const std::vector<kernel_id> &KernelIDs =
558  getSyclObjImpl(DeviceImage)->get_kernel_ids();
559 
560  Result.insert(Result.end(), KernelIDs.begin(), KernelIDs.end());
561  }
562  std::sort(Result.begin(), Result.end(), LessByNameComp{});
563 
564  auto NewIt = std::unique(Result.begin(), Result.end(), EqualByNameComp{});
565  Result.erase(NewIt, Result.end());
566 
567  return Result;
568  }
569 
570  kernel
571  get_kernel(const kernel_id &KernelID,
572  const std::shared_ptr<detail::kernel_bundle_impl> &Self) const {
573  using ImageImpl = std::shared_ptr<detail::device_image_impl>;
574  // Selected image.
575  ImageImpl SelectedImage = nullptr;
576  // Image where specialization constants are replaced with default values.
577  ImageImpl ImageWithReplacedSpecConsts = nullptr;
578  // Original image where specialization constants are not replaced with
579  // default values.
580  ImageImpl OriginalImage = nullptr;
581  // Used to track if any of the candidate images has specialization values
582  // set.
583  bool SpecConstsSet = false;
584  for (auto &DeviceImage : MDeviceImages) {
585  if (!DeviceImage.has_kernel(KernelID))
586  continue;
587 
588  const auto DeviceImageImpl = detail::getSyclObjImpl(DeviceImage);
589  SpecConstsSet |= DeviceImageImpl->is_any_specialization_constant_set();
590 
591  // Remember current image in corresponding variable depending on whether
592  // specialization constants are replaced with default value or not.
593  (DeviceImageImpl->specialization_constants_replaced_with_default()
594  ? ImageWithReplacedSpecConsts
595  : OriginalImage) = DeviceImageImpl;
596 
597  if (SpecConstsSet) {
598  // If specialization constant is set in any of the candidate images
599  // then we can't use ReplacedImage, so we select NativeImage if any or
600  // we select OriginalImage and keep iterating in case there is an image
601  // with native support.
602  SelectedImage = OriginalImage;
603  if (SelectedImage &&
604  SelectedImage->all_specialization_constant_native())
605  break;
606  } else {
607  // For now select ReplacedImage but it may be reset if any of the
608  // further device images has specialization constant value set. If after
609  // all iterations specialization constant values are not set in any of
610  // the candidate images then that will be the selected image.
611  // Also we don't want to use ReplacedImage if device image has native
612  // support.
613  if (ImageWithReplacedSpecConsts &&
614  !ImageWithReplacedSpecConsts->all_specialization_constant_native())
615  SelectedImage = ImageWithReplacedSpecConsts;
616  else
617  // In case if we don't have or don't use ReplacedImage.
618  SelectedImage = OriginalImage;
619  }
620  }
621 
622  if (!SelectedImage)
623  throw sycl::exception(make_error_code(errc::invalid),
624  "The kernel bundle does not contain the kernel "
625  "identified by kernelId.");
626 
627  auto [Kernel, CacheMutex, ArgMask] =
628  detail::ProgramManager::getInstance().getOrCreateKernel(
629  MContext, KernelID.get_name(), /*PropList=*/{},
630  SelectedImage->get_program_ref());
631 
632  std::shared_ptr<kernel_impl> KernelImpl = std::make_shared<kernel_impl>(
633  Kernel, detail::getSyclObjImpl(MContext), SelectedImage, Self, ArgMask,
634  SelectedImage->get_program_ref(), CacheMutex);
635 
636  return detail::createSyclObjFromImpl<kernel>(KernelImpl);
637  }
638 
639  bool has_kernel(const kernel_id &KernelID) const noexcept {
640  return std::any_of(MDeviceImages.begin(), MDeviceImages.end(),
641  [&KernelID](const device_image_plain &DeviceImage) {
642  return DeviceImage.has_kernel(KernelID);
643  });
644  }
645 
646  bool has_kernel(const kernel_id &KernelID, const device &Dev) const noexcept {
647  return std::any_of(
648  MDeviceImages.begin(), MDeviceImages.end(),
649  [&KernelID, &Dev](const device_image_plain &DeviceImage) {
650  return DeviceImage.has_kernel(KernelID, Dev);
651  });
652  }
653 
655  return std::any_of(
656  MDeviceImages.begin(), MDeviceImages.end(),
657  [](const device_image_plain &DeviceImage) {
658  return getSyclObjImpl(DeviceImage)->has_specialization_constants();
659  });
660  }
661 
663  return contains_specialization_constants() &&
664  std::all_of(MDeviceImages.begin(), MDeviceImages.end(),
665  [](const device_image_plain &DeviceImage) {
666  return getSyclObjImpl(DeviceImage)
667  ->all_specialization_constant_native();
668  });
669  }
670 
671  bool has_specialization_constant(const char *SpecName) const noexcept {
672  return std::any_of(MDeviceImages.begin(), MDeviceImages.end(),
673  [SpecName](const device_image_plain &DeviceImage) {
674  return getSyclObjImpl(DeviceImage)
675  ->has_specialization_constant(SpecName);
676  });
677  }
678 
679  void set_specialization_constant_raw_value(const char *SpecName,
680  const void *Value,
681  size_t Size) noexcept {
682  if (has_specialization_constant(SpecName))
683  for (const device_image_plain &DeviceImage : MDeviceImages)
684  getSyclObjImpl(DeviceImage)
685  ->set_specialization_constant_raw_value(SpecName, Value);
686  else {
687  std::vector<unsigned char> &Val = MSpecConstValues[std::string{SpecName}];
688  Val.resize(Size);
689  std::memcpy(Val.data(), Value, Size);
690  }
691  }
692 
693  void get_specialization_constant_raw_value(const char *SpecName,
694  void *ValueRet) const noexcept {
695  for (const device_image_plain &DeviceImage : MDeviceImages)
696  if (getSyclObjImpl(DeviceImage)->has_specialization_constant(SpecName)) {
697  getSyclObjImpl(DeviceImage)
698  ->get_specialization_constant_raw_value(SpecName, ValueRet);
699  return;
700  }
701 
702  // Specialization constant wasn't found in any of the device images,
703  // try to fetch value from kernel_bundle.
704  if (MSpecConstValues.count(std::string{SpecName}) != 0) {
705  const std::vector<unsigned char> &Val =
706  MSpecConstValues.at(std::string{SpecName});
707  auto *Dest = static_cast<unsigned char *>(ValueRet);
708  std::uninitialized_copy(Val.begin(), Val.end(), Dest);
709  return;
710  }
711 
712  assert(false &&
713  "get_specialization_constant_raw_value called for missing constant");
714  }
715 
716  bool is_specialization_constant_set(const char *SpecName) const noexcept {
717  bool SetInDevImg =
718  std::any_of(MDeviceImages.begin(), MDeviceImages.end(),
719  [SpecName](const device_image_plain &DeviceImage) {
720  return getSyclObjImpl(DeviceImage)
721  ->is_specialization_constant_set(SpecName);
722  });
723  return SetInDevImg || MSpecConstValues.count(std::string{SpecName}) != 0;
724  }
725 
726  const device_image_plain *begin() const { return MDeviceImages.data(); }
727 
728  const device_image_plain *end() const {
729  return MDeviceImages.data() + MDeviceImages.size();
730  }
731 
732  size_t size() const noexcept { return MDeviceImages.size(); }
733 
734  bundle_state get_bundle_state() const { return MState; }
735 
736  const SpecConstMapT &get_spec_const_map_ref() const noexcept {
737  return MSpecConstValues;
738  }
739 
740  bool isInterop() const { return MIsInterop; }
741 
742  bool add_kernel(const kernel_id &KernelID, const device &Dev) {
743  // Skip if kernel is already there
744  if (has_kernel(KernelID, Dev))
745  return true;
746 
747  // First try and get images in current bundle state
748  const bundle_state BundleState = get_bundle_state();
749  std::vector<device_image_plain> NewDevImgs =
750  detail::ProgramManager::getInstance().getSYCLDeviceImages(
751  MContext, {Dev}, {KernelID}, BundleState);
752 
753  // No images found so we report as not inserted
754  if (NewDevImgs.empty())
755  return false;
756 
757  // Propagate already set specialization constants to the new images
758  for (device_image_plain &DevImg : NewDevImgs)
759  for (auto SpecConst : MSpecConstValues)
760  getSyclObjImpl(DevImg)->set_specialization_constant_raw_value(
761  SpecConst.first.c_str(), SpecConst.second.data());
762 
763  // Add the images to the collection
764  MDeviceImages.insert(MDeviceImages.end(), NewDevImgs.begin(),
765  NewDevImgs.end());
766  return true;
767  }
768 
769 private:
770  context MContext;
771  std::vector<device> MDevices;
772  std::vector<device_image_plain> MDeviceImages;
773  // This map stores values for specialization constants, that are missing
774  // from any device image.
775  SpecConstMapT MSpecConstValues;
776  bool MIsInterop = false;
777  bundle_state MState;
778 
779  // ext_oneapi_kernel_compiler : Source, Language, KernelNames, IncludePairs
780  // Language is for both state::source and state::executable.
781  syclex::source_language Language = syclex::source_language::opencl;
782  const std::variant<std::string, std::vector<std::byte>> Source;
783  // only kernel_bundles created from source have KernelNames member.
784  std::vector<std::string> KernelNames;
785  include_pairs_t IncludePairs;
786 };
787 
788 } // namespace detail
789 } // namespace _V1
790 } // namespace sycl
The context class represents a SYCL context on which kernel functions may be executed.
Definition: context.hpp:50
std::vector< device_image_plain > getSYCLDeviceImages(const context &Ctx, const std::vector< device > &Devs, bundle_state State)
static ProgramManager & getInstance()
std::vector< device_image_plain > link(const device_image_plain &DeviceImages, const std::vector< device > &Devs, const property_list &PropList)
The class is an impl counterpart of the sycl::kernel_bundle.
bool has_kernel(const kernel_id &KernelID) const noexcept
const SpecConstMapT & get_spec_const_map_ref() const noexcept
std::string adjust_kernel_name(const std::string &Name, syclex::source_language Lang)
kernel_bundle_impl(const std::vector< kernel_bundle< bundle_state::object >> &ObjectBundles, std::vector< device > Devs, const property_list &PropList)
kernel_bundle_impl(context Ctx, std::vector< device > Devs)
kernel_bundle_impl(const context &Context, syclex::source_language Lang, const std::string &Src, include_pairs_t IncludePairsVec)
bool add_kernel(const kernel_id &KernelID, const device &Dev)
kernel ext_oneapi_get_kernel(const std::string &Name, const std::shared_ptr< kernel_bundle_impl > &Self)
kernel_bundle_impl(context Ctx, std::vector< device > Devs, device_image_plain &DevImage)
const device_image_plain * end() const
kernel_bundle_impl(const context &Context, syclex::source_language Lang, const std::vector< std::byte > &Bytes)
const std::vector< device > & get_devices() const noexcept
void set_specialization_constant_raw_value(const char *SpecName, const void *Value, size_t Size) noexcept
bool contains_specialization_constants() const noexcept
std::vector< kernel_id > get_kernel_ids() const
std::vector< std::pair< std::string, std::string > > include_pairs_t
bool is_specialization_constant_set(const char *SpecName) const noexcept
bool native_specialization_constant() const noexcept
std::string extractXsFlags(const std::vector< std::string > &BuildOptions)
kernel_bundle_impl(context Ctx, std::vector< device > Devs, const std::vector< kernel_id > &KernelIDs, bundle_state State)
kernel_bundle_impl(const kernel_bundle< bundle_state::input > &InputBundle, std::vector< device > Devs, const property_list &PropList, bundle_state TargetState)
const device_image_plain * begin() const
kernel_bundle_impl(context Ctx, std::vector< device > Devs, bundle_state State)
kernel_bundle_impl(context Ctx, std::vector< device > Devs, device_image_plain &DevImage, std::vector< std::string > KNames, syclex::source_language Lang)
void get_specialization_constant_raw_value(const char *SpecName, void *ValueRet) const noexcept
bool has_specialization_constant(const char *SpecName) const noexcept
std::shared_ptr< kernel_bundle_impl > build_from_source(const std::vector< device > Devices, const std::vector< std::string > &BuildOptions, std::string *LogPtr, const std::vector< std::string > &RegisteredKernelNames)
bool ext_oneapi_has_kernel(const std::string &Name)
bool has_kernel(const kernel_id &KernelID, const device &Dev) const noexcept
std::string trimXsFlags(std::string &str)
kernel get_kernel(const kernel_id &KernelID, const std::shared_ptr< detail::kernel_bundle_impl > &Self) const
kernel_bundle_impl(context Ctx, std::vector< device > Devs, const DevImgSelectorImpl &Selector, bundle_state State)
kernel_bundle_impl(const std::vector< detail::KernelBundleImplPtr > &Bundles, bundle_state State)
The SYCL device class encapsulates a single SYCL device on which kernels may be executed.
Definition: device.hpp:64
The kernel_bundle class represents collection of device images in a particular state.
Objects of the class identify a kernel in some kernel_bundle-related APIs.
const char * get_name() const noexcept
Provides an abstraction of a SYCL kernel.
Definition: kernel.hpp:71
Objects of the property_list class are containers for the SYCL properties.
::pi_kernel PiKernel
Definition: pi.hpp:112
::pi_program PiProgram
Definition: pi.hpp:111
std::shared_ptr< device_image_impl > DeviceImageImplPtr
decltype(Obj::impl) const & getSyclObjImpl(const Obj &SyclObject)
Definition: impl_utils.hpp:31
static bool checkAllDevicesAreInContext(const std::vector< device > &Devices, const context &Context)
std::function< bool(const detail::DeviceImageImplPtr &DevImgImpl)> DevImgSelectorImpl
static bool checkAllDevicesHaveAspect(const std::vector< device > &Devices, aspect Aspect)
std::shared_ptr< sycl::detail::context_impl > ContextImplPtr
Definition: event_impl.hpp:32
std::vector< std::string > split_string(std::string_view str, char delimeter)
std::shared_ptr< plugin > PluginPtr
Definition: pi.hpp:47
std::shared_ptr< detail::kernel_bundle_impl > KernelBundleImplPtr
Function for_each(Group g, Ptr first, Ptr last, Function f)
spirv_vec_t OpenCLC_to_SPIRV(const std::string &Source, const std::vector< uint32_t > &IPVersionVec, const std::vector< std::string > &UserArgs, std::string *LogPtr)
std::vector< std::pair< std::string, std::string > > include_pairs_t
spirv_vec_t SYCL_to_SPIRV(const std::string &SYCLSource, include_pairs_t IncludePairs, const std::vector< std::string > &UserArgs, std::string *LogPtr, const std::vector< std::string > &RegisteredKernelNames)
kernel_bundle< bundle_state::executable > build(const kernel_bundle< bundle_state::input > &InputBundle, const std::vector< device > &Devs, const property_list &PropList={})
kernel_bundle< bundle_state::object > compile(const kernel_bundle< bundle_state::input > &InputBundle, const std::vector< device > &Devs, const property_list &PropList={})
std::error_code make_error_code(sycl::errc E) noexcept
Constructs an error code using e and sycl_category()
Definition: exception.cpp:64
Definition: access.hpp:18
pi_result piKernelCreate(pi_program program, const char *kernel_name, pi_kernel *ret_kernel)
Definition: pi_cuda.cpp:341
pi_result piProgramGetInfo(pi_program program, pi_program_info param_name, size_t param_value_size, void *param_value, size_t *param_value_size_ret)
Definition: pi_cuda.cpp:272
pi_result piProgramBuild(pi_program program, pi_uint32 num_devices, const pi_device *device_list, const char *options, void(*pfn_notify)(pi_program program, void *user_data), void *user_data)
pi_result piProgramCreate(pi_context context, const void *il, size_t length, pi_program *res_program)
Definition: pi_cuda.cpp:248
@ PI_PROGRAM_INFO_KERNEL_NAMES
Definition: pi.h:555
@ PI_PROGRAM_INFO_NUM_KERNELS
Definition: pi.h:554
bool any_of(const simd_mask< _Tp, _Abi > &) noexcept
bool all_of(const simd_mask< _Tp, _Abi > &) noexcept
_Abi const simd< _Tp, _Abi > & noexcept
Definition: simd.hpp:1324
bool none_of(const simd_mask< _Tp, _Abi > &) noexcept