DPC++ Runtime
Runtime libraries for oneAPI DPC++
graph_impl.hpp
Go to the documentation of this file.
1 //==--------- graph_impl.hpp --- SYCL graph extension ---------------------==//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #pragma once
10 
11 #include <sycl/detail/cg_types.hpp>
12 #include <sycl/detail/os_util.hpp>
14 #include <sycl/handler.hpp>
15 
16 #include <detail/accessor_impl.hpp>
17 #include <detail/event_impl.hpp>
18 #include <detail/kernel_impl.hpp>
19 
20 #include <cstring>
21 #include <deque>
22 #include <fstream>
23 #include <functional>
24 #include <list>
25 #include <set>
26 #include <shared_mutex>
27 
28 namespace sycl {
29 inline namespace _V1 {
30 
31 namespace detail {
32 class SYCLMemObjT;
33 }
34 
35 namespace ext {
36 namespace oneapi {
37 namespace experimental {
38 namespace detail {
39 
41  using sycl::detail::CG;
42 
43  switch (CGType) {
44  case CG::None:
45  return node_type::empty;
46  case CG::Kernel:
47  return node_type::kernel;
48  case CG::CopyAccToPtr:
49  case CG::CopyPtrToAcc:
50  case CG::CopyAccToAcc:
51  case CG::CopyUSM:
52  return node_type::memcpy;
53  case CG::Memset2DUSM:
54  return node_type::memset;
55  case CG::Fill:
56  case CG::FillUSM:
57  return node_type::memfill;
58  case CG::PrefetchUSM:
59  return node_type::prefetch;
60  case CG::AdviseUSM:
61  return node_type::memadvise;
62  case CG::Barrier:
63  case CG::BarrierWaitlist:
65  case CG::CodeplayHostTask:
66  return node_type::host_task;
67  case CG::ExecCommandBuffer:
68  return node_type::subgraph;
69  default:
70  assert(false && "Invalid Graph Node Type");
71  return node_type::empty;
72  }
73 }
74 
76 class node_impl {
77 public:
79  std::vector<std::weak_ptr<node_impl>> MSuccessors;
83  std::vector<std::weak_ptr<node_impl>> MPredecessors;
85  sycl::detail::CG::CGTYPE MCGType = sycl::detail::CG::None;
89  std::unique_ptr<sycl::detail::CG> MCommandGroup;
92  std::shared_ptr<exec_graph_impl> MSubGraphImpl;
93 
95  bool MVisited = false;
96 
100  int MPartitionNum = -1;
101 
108  void registerSuccessor(const std::shared_ptr<node_impl> &Node,
109  const std::shared_ptr<node_impl> &Prev) {
110  if (std::find_if(MSuccessors.begin(), MSuccessors.end(),
111  [Node](const std::weak_ptr<node_impl> &Ptr) {
112  return Ptr.lock() == Node;
113  }) != MSuccessors.end()) {
114  return;
115  }
116  MSuccessors.push_back(Node);
117  Node->registerPredecessor(Prev);
118  }
119 
122  void registerPredecessor(const std::shared_ptr<node_impl> &Node) {
123  if (std::find_if(MPredecessors.begin(), MPredecessors.end(),
124  [&Node](const std::weak_ptr<node_impl> &Ptr) {
125  return Ptr.lock() == Node;
126  }) != MPredecessors.end()) {
127  return;
128  }
129  MPredecessors.push_back(Node);
130  }
131 
134 
140  std::unique_ptr<sycl::detail::CG> &&CommandGroup)
141  : MCGType(CommandGroup->getType()), MNodeType(NodeType),
142  MCommandGroup(std::move(CommandGroup)) {
143  if (NodeType == node_type::subgraph) {
144  MSubGraphImpl =
145  static_cast<sycl::detail::CGExecCommandBuffer *>(MCommandGroup.get())
146  ->MExecGraph;
147  }
148  }
149 
154  MCGType(Other.MCGType), MNodeType(Other.MNodeType),
156 
160  bool hasRequirement(sycl::detail::AccessorImplHost *IncomingReq) {
161  for (sycl::detail::AccessorImplHost *CurrentReq :
162  MCommandGroup->getRequirements()) {
163  if (IncomingReq->MSYCLMemObj == CurrentReq->MSYCLMemObj) {
164  return true;
165  }
166  }
167  return false;
168  }
169 
174  bool isEmpty() const {
175  return ((MCGType == sycl::detail::CG::None) ||
176  (MCGType == sycl::detail::CG::Barrier));
177  }
178 
181  std::unique_ptr<sycl::detail::CG> getCGCopy() const {
182  switch (MCGType) {
183  case sycl::detail::CG::Kernel:
184  return createCGCopy<sycl::detail::CGExecKernel>();
185  case sycl::detail::CG::CopyAccToPtr:
186  case sycl::detail::CG::CopyPtrToAcc:
187  case sycl::detail::CG::CopyAccToAcc:
188  return createCGCopy<sycl::detail::CGCopy>();
189  case sycl::detail::CG::Fill:
190  return createCGCopy<sycl::detail::CGFill>();
191  case sycl::detail::CG::UpdateHost:
192  return createCGCopy<sycl::detail::CGUpdateHost>();
193  case sycl::detail::CG::CopyUSM:
194  return createCGCopy<sycl::detail::CGCopyUSM>();
195  case sycl::detail::CG::FillUSM:
196  return createCGCopy<sycl::detail::CGFillUSM>();
197  case sycl::detail::CG::PrefetchUSM:
198  return createCGCopy<sycl::detail::CGPrefetchUSM>();
199  case sycl::detail::CG::AdviseUSM:
200  return createCGCopy<sycl::detail::CGAdviseUSM>();
201  case sycl::detail::CG::Copy2DUSM:
202  return createCGCopy<sycl::detail::CGCopy2DUSM>();
203  case sycl::detail::CG::Fill2DUSM:
204  return createCGCopy<sycl::detail::CGFill2DUSM>();
205  case sycl::detail::CG::Memset2DUSM:
206  return createCGCopy<sycl::detail::CGMemset2DUSM>();
207  case sycl::detail::CG::CodeplayHostTask: {
208  // The unique_ptr to the `sycl::detail::HostTask` in the HostTask CG
209  // prevents from copying the CG.
210  // We overcome this restriction by creating a new CG with the same data.
211  auto CommandGroupPtr =
212  static_cast<sycl::detail::CGHostTask *>(MCommandGroup.get());
213  sycl::detail::HostTask HostTask = *CommandGroupPtr->MHostTask.get();
214  auto HostTaskUPtr = std::make_unique<sycl::detail::HostTask>(HostTask);
215 
217  CommandGroupPtr->getArgsStorage(), CommandGroupPtr->getAccStorage(),
218  CommandGroupPtr->getSharedPtrStorage(),
219  CommandGroupPtr->getRequirements(), CommandGroupPtr->getEvents());
220 
221  sycl::detail::code_location Loc(CommandGroupPtr->MFileName.data(),
222  CommandGroupPtr->MFunctionName.data(),
223  CommandGroupPtr->MLine,
224  CommandGroupPtr->MColumn);
225 
226  return std::make_unique<sycl::detail::CGHostTask>(
227  sycl::detail::CGHostTask(
228  std::move(HostTaskUPtr), CommandGroupPtr->MQueue,
229  CommandGroupPtr->MContext, CommandGroupPtr->MArgs, Data,
230  CommandGroupPtr->getType(), Loc));
231  }
232  case sycl::detail::CG::Barrier:
233  case sycl::detail::CG::BarrierWaitlist:
234  // Barrier nodes are stored in the graph with only the base CG class,
235  // since they are treated internally as empty nodes.
236  return createCGCopy<sycl::detail::CG>();
237  case sycl::detail::CG::CopyToDeviceGlobal:
238  return createCGCopy<sycl::detail::CGCopyToDeviceGlobal>();
239  case sycl::detail::CG::CopyFromDeviceGlobal:
240  return createCGCopy<sycl::detail::CGCopyFromDeviceGlobal>();
241  case sycl::detail::CG::ReadWriteHostPipe:
242  return createCGCopy<sycl::detail::CGReadWriteHostPipe>();
243  case sycl::detail::CG::CopyImage:
244  return createCGCopy<sycl::detail::CGCopyImage>();
245  case sycl::detail::CG::SemaphoreSignal:
246  return createCGCopy<sycl::detail::CGSemaphoreSignal>();
247  case sycl::detail::CG::SemaphoreWait:
248  return createCGCopy<sycl::detail::CGSemaphoreWait>();
249  case sycl::detail::CG::ExecCommandBuffer:
250  return createCGCopy<sycl::detail::CGExecCommandBuffer>();
251  case sycl::detail::CG::None:
252  return nullptr;
253  }
254  return nullptr;
255  }
256 
262  bool isSimilar(const std::shared_ptr<node_impl> &Node,
263  bool CompareContentOnly = false) const {
264  if (!CompareContentOnly) {
265  if (MSuccessors.size() != Node->MSuccessors.size())
266  return false;
267 
268  if (MPredecessors.size() != Node->MPredecessors.size())
269  return false;
270  }
271  if (MCGType != Node->MCGType)
272  return false;
273 
274  switch (MCGType) {
275  case sycl::detail::CG::CGTYPE::Kernel: {
276  sycl::detail::CGExecKernel *ExecKernelA =
277  static_cast<sycl::detail::CGExecKernel *>(MCommandGroup.get());
278  sycl::detail::CGExecKernel *ExecKernelB =
279  static_cast<sycl::detail::CGExecKernel *>(Node->MCommandGroup.get());
280  return ExecKernelA->MKernelName.compare(ExecKernelB->MKernelName) == 0;
281  }
282  case sycl::detail::CG::CGTYPE::CopyUSM: {
283  sycl::detail::CGCopyUSM *CopyA =
284  static_cast<sycl::detail::CGCopyUSM *>(MCommandGroup.get());
285  sycl::detail::CGCopyUSM *CopyB =
286  static_cast<sycl::detail::CGCopyUSM *>(Node->MCommandGroup.get());
287  return (CopyA->getSrc() == CopyB->getSrc()) &&
288  (CopyA->getDst() == CopyB->getDst()) &&
289  (CopyA->getLength() == CopyB->getLength());
290  }
291  case sycl::detail::CG::CGTYPE::CopyAccToAcc:
292  case sycl::detail::CG::CGTYPE::CopyAccToPtr:
293  case sycl::detail::CG::CGTYPE::CopyPtrToAcc: {
294  sycl::detail::CGCopy *CopyA =
295  static_cast<sycl::detail::CGCopy *>(MCommandGroup.get());
296  sycl::detail::CGCopy *CopyB =
297  static_cast<sycl::detail::CGCopy *>(Node->MCommandGroup.get());
298  return (CopyA->getSrc() == CopyB->getSrc()) &&
299  (CopyA->getDst() == CopyB->getDst());
300  }
301  default:
302  assert(false && "Unexpected command group type!");
303  return false;
304  }
305  }
306 
313  void printDotRecursive(std::fstream &Stream,
314  std::vector<node_impl *> &Visited, bool Verbose) {
315  // if Node has been already visited, we skip it
316  if (std::find(Visited.begin(), Visited.end(), this) != Visited.end())
317  return;
318 
319  Visited.push_back(this);
320 
321  printDotCG(Stream, Verbose);
322  for (const auto &Dep : MPredecessors) {
323  auto NodeDep = Dep.lock();
324  Stream << " \"" << NodeDep.get() << "\" -> \"" << this << "\""
325  << std::endl;
326  }
327 
328  for (std::weak_ptr<node_impl> Succ : MSuccessors) {
329  if (MPartitionNum == Succ.lock()->MPartitionNum)
330  Succ.lock()->printDotRecursive(Stream, Visited, Verbose);
331  }
332  }
333 
334 private:
339  void printDotCG(std::ostream &Stream, bool Verbose) {
340  Stream << "\"" << this << "\" [style=bold, label=\"";
341 
342  Stream << "ID = " << this << "\\n";
343  Stream << "TYPE = ";
344 
345  switch (MCGType) {
346  case sycl::detail::CG::CGTYPE::None:
347  Stream << "None \\n";
348  break;
349  case sycl::detail::CG::CGTYPE::Kernel: {
350  Stream << "CGExecKernel \\n";
351  sycl::detail::CGExecKernel *Kernel =
352  static_cast<sycl::detail::CGExecKernel *>(MCommandGroup.get());
353  Stream << "NAME = " << Kernel->MKernelName << "\\n";
354  if (Verbose) {
355  Stream << "ARGS = \\n";
356  for (size_t i = 0; i < Kernel->MArgs.size(); i++) {
357  auto Arg = Kernel->MArgs[i];
358  std::string Type = "Undefined";
359  if (Arg.MType == sycl::detail::kernel_param_kind_t::kind_accessor) {
360  Type = "Accessor";
361  } else if (Arg.MType ==
362  sycl::detail::kernel_param_kind_t::kind_std_layout) {
363  Type = "STD_Layout";
364  } else if (Arg.MType ==
365  sycl::detail::kernel_param_kind_t::kind_sampler) {
366  Type = "Sampler";
367  } else if (Arg.MType ==
368  sycl::detail::kernel_param_kind_t::kind_pointer) {
369  Type = "Pointer";
370  } else if (Arg.MType == sycl::detail::kernel_param_kind_t::
371  kind_specialization_constants_buffer) {
372  Type = "Specialization Constants Buffer";
373  } else if (Arg.MType ==
374  sycl::detail::kernel_param_kind_t::kind_stream) {
375  Type = "Stream";
376  } else if (Arg.MType ==
377  sycl::detail::kernel_param_kind_t::kind_invalid) {
378  Type = "Invalid";
379  }
380  Stream << i << ") Type: " << Type << " Ptr: " << Arg.MPtr << "\\n";
381  }
382  }
383  break;
384  }
385  case sycl::detail::CG::CGTYPE::CopyAccToPtr:
386  Stream << "CGCopy Device-to-Host \\n";
387  if (Verbose) {
388  sycl::detail::CGCopy *Copy =
389  static_cast<sycl::detail::CGCopy *>(MCommandGroup.get());
390  Stream << "Src: " << Copy->getSrc() << " Dst: " << Copy->getDst()
391  << "\\n";
392  }
393  break;
394  case sycl::detail::CG::CGTYPE::CopyPtrToAcc:
395  Stream << "CGCopy Host-to-Device \\n";
396  if (Verbose) {
397  sycl::detail::CGCopy *Copy =
398  static_cast<sycl::detail::CGCopy *>(MCommandGroup.get());
399  Stream << "Src: " << Copy->getSrc() << " Dst: " << Copy->getDst()
400  << "\\n";
401  }
402  break;
403  case sycl::detail::CG::CGTYPE::CopyAccToAcc:
404  Stream << "CGCopy Device-to-Device \\n";
405  if (Verbose) {
406  sycl::detail::CGCopy *Copy =
407  static_cast<sycl::detail::CGCopy *>(MCommandGroup.get());
408  Stream << "Src: " << Copy->getSrc() << " Dst: " << Copy->getDst()
409  << "\\n";
410  }
411  break;
412  case sycl::detail::CG::CGTYPE::Fill:
413  Stream << "CGFill \\n";
414  if (Verbose) {
415  sycl::detail::CGFill *Fill =
416  static_cast<sycl::detail::CGFill *>(MCommandGroup.get());
417  Stream << "Ptr: " << Fill->MPtr << "\\n";
418  }
419  break;
420  case sycl::detail::CG::CGTYPE::UpdateHost:
421  Stream << "CGCUpdateHost \\n";
422  if (Verbose) {
423  sycl::detail::CGUpdateHost *Host =
424  static_cast<sycl::detail::CGUpdateHost *>(MCommandGroup.get());
425  Stream << "Ptr: " << Host->getReqToUpdate() << "\\n";
426  }
427  break;
428  case sycl::detail::CG::CGTYPE::CopyUSM:
429  Stream << "CGCopyUSM \\n";
430  if (Verbose) {
431  sycl::detail::CGCopyUSM *CopyUSM =
432  static_cast<sycl::detail::CGCopyUSM *>(MCommandGroup.get());
433  Stream << "Src: " << CopyUSM->getSrc() << " Dst: " << CopyUSM->getDst()
434  << " Length: " << CopyUSM->getLength() << "\\n";
435  }
436  break;
437  case sycl::detail::CG::CGTYPE::FillUSM:
438  Stream << "CGFillUSM \\n";
439  if (Verbose) {
440  sycl::detail::CGFillUSM *FillUSM =
441  static_cast<sycl::detail::CGFillUSM *>(MCommandGroup.get());
442  Stream << "Dst: " << FillUSM->getDst()
443  << " Length: " << FillUSM->getLength()
444  << " Pattern: " << FillUSM->getFill() << "\\n";
445  }
446  break;
447  case sycl::detail::CG::CGTYPE::PrefetchUSM:
448  Stream << "CGPrefetchUSM \\n";
449  if (Verbose) {
450  sycl::detail::CGPrefetchUSM *Prefetch =
451  static_cast<sycl::detail::CGPrefetchUSM *>(MCommandGroup.get());
452  Stream << "Dst: " << Prefetch->getDst()
453  << " Length: " << Prefetch->getLength() << "\\n";
454  }
455  break;
456  case sycl::detail::CG::CGTYPE::AdviseUSM:
457  Stream << "CGAdviseUSM \\n";
458  if (Verbose) {
459  sycl::detail::CGAdviseUSM *AdviseUSM =
460  static_cast<sycl::detail::CGAdviseUSM *>(MCommandGroup.get());
461  Stream << "Dst: " << AdviseUSM->getDst()
462  << " Length: " << AdviseUSM->getLength() << "\\n";
463  }
464  break;
465  case sycl::detail::CG::CGTYPE::CodeplayHostTask:
466  Stream << "CGHostTask \\n";
467  break;
468  case sycl::detail::CG::CGTYPE::Barrier:
469  Stream << "CGBarrier \\n";
470  break;
471  case sycl::detail::CG::CGTYPE::Copy2DUSM:
472  Stream << "CGCopy2DUSM \\n";
473  if (Verbose) {
474  sycl::detail::CGCopy2DUSM *Copy2DUSM =
475  static_cast<sycl::detail::CGCopy2DUSM *>(MCommandGroup.get());
476  Stream << "Src:" << Copy2DUSM->getSrc()
477  << " Dst: " << Copy2DUSM->getDst() << "\\n";
478  }
479  break;
480  case sycl::detail::CG::CGTYPE::Fill2DUSM:
481  Stream << "CGFill2DUSM \\n";
482  if (Verbose) {
483  sycl::detail::CGFill2DUSM *Fill2DUSM =
484  static_cast<sycl::detail::CGFill2DUSM *>(MCommandGroup.get());
485  Stream << "Dst: " << Fill2DUSM->getDst() << "\\n";
486  }
487  break;
488  case sycl::detail::CG::CGTYPE::Memset2DUSM:
489  Stream << "CGMemset2DUSM \\n";
490  if (Verbose) {
491  sycl::detail::CGMemset2DUSM *Memset2DUSM =
492  static_cast<sycl::detail::CGMemset2DUSM *>(MCommandGroup.get());
493  Stream << "Dst: " << Memset2DUSM->getDst() << "\\n";
494  }
495  break;
496  case sycl::detail::CG::CGTYPE::ReadWriteHostPipe:
497  Stream << "CGReadWriteHostPipe \\n";
498  break;
499  case sycl::detail::CG::CGTYPE::CopyToDeviceGlobal:
500  Stream << "CGCopyToDeviceGlobal \\n";
501  if (Verbose) {
502  sycl::detail::CGCopyToDeviceGlobal *CopyToDeviceGlobal =
503  static_cast<sycl::detail::CGCopyToDeviceGlobal *>(
504  MCommandGroup.get());
505  Stream << "Src: " << CopyToDeviceGlobal->getSrc()
506  << " Dst: " << CopyToDeviceGlobal->getDeviceGlobalPtr() << "\\n";
507  }
508  break;
509  case sycl::detail::CG::CGTYPE::CopyFromDeviceGlobal:
510  Stream << "CGCopyFromDeviceGlobal \\n";
511  if (Verbose) {
512  sycl::detail::CGCopyFromDeviceGlobal *CopyFromDeviceGlobal =
513  static_cast<sycl::detail::CGCopyFromDeviceGlobal *>(
514  MCommandGroup.get());
515  Stream << "Src: " << CopyFromDeviceGlobal->getDeviceGlobalPtr()
516  << " Dst: " << CopyFromDeviceGlobal->getDest() << "\\n";
517  }
518  break;
519  case sycl::detail::CG::CGTYPE::ExecCommandBuffer:
520  Stream << "CGExecCommandBuffer \\n";
521  break;
522  default:
523  Stream << "Other \\n";
524  break;
525  }
526  Stream << "\"];" << std::endl;
527  }
528 
533  template <typename CGT> std::unique_ptr<CGT> createCGCopy() const {
534  return std::make_unique<CGT>(*static_cast<CGT *>(MCommandGroup.get()));
535  }
536 };
537 
538 class partition {
539 public:
542 
544  std::set<std::weak_ptr<node_impl>, std::owner_less<std::weak_ptr<node_impl>>>
547  std::list<std::shared_ptr<node_impl>> MSchedule;
549  std::unordered_map<sycl::device, sycl::detail::pi::PiExtCommandBuffer>
552  std::vector<std::shared_ptr<partition>> MPredecessors;
553 
555  bool isHostTask() const {
556  return (MRoots.size() && ((*MRoots.begin()).lock()->MCGType ==
557  sycl::detail::CG::CGTYPE::CodeplayHostTask));
558  }
559 
561  void schedule();
562 };
563 
565 class graph_impl {
566 public:
567  using ReadLock = std::shared_lock<std::shared_mutex>;
568  using WriteLock = std::unique_lock<std::shared_mutex>;
569 
571  mutable std::shared_mutex MMutex;
572 
577  graph_impl(const sycl::context &SyclContext, const sycl::device &SyclDevice,
578  const sycl::property_list &PropList = {})
579  : MContext(SyclContext), MDevice(SyclDevice), MRecordingQueues(),
580  MEventsMap(), MInorderQueueMap() {
581  if (PropList.has_property<property::graph::no_cycle_check>()) {
582  MSkipCycleChecks = true;
583  }
584  if (PropList
585  .has_property<property::graph::assume_buffer_outlives_graph>()) {
586  MAllowBuffers = true;
587  }
588 
589  if (SyclDevice.get_info<
590  ext::oneapi::experimental::info::device::graph_support>() ==
592  std::stringstream Stream;
593  Stream << SyclDevice.get_backend();
594  std::string BackendString = Stream.str();
595  throw sycl::exception(
597  BackendString + " backend is not supported by SYCL Graph extension.");
598  }
599  }
600 
601  ~graph_impl();
602 
605  void removeRoot(const std::shared_ptr<node_impl> &Root);
606 
612  std::shared_ptr<node_impl>
613  add(node_type NodeType, std::unique_ptr<sycl::detail::CG> CommandGroup,
614  const std::vector<std::shared_ptr<node_impl>> &Dep = {});
615 
622  std::shared_ptr<node_impl>
623  add(const std::shared_ptr<graph_impl> &Impl,
624  std::function<void(handler &)> CGF,
625  const std::vector<sycl::detail::ArgDesc> &Args,
626  const std::vector<std::shared_ptr<node_impl>> &Dep = {});
627 
632  std::shared_ptr<node_impl>
633  add(const std::shared_ptr<graph_impl> &Impl,
634  const std::vector<std::shared_ptr<node_impl>> &Dep = {});
635 
640  std::shared_ptr<node_impl>
641  add(const std::shared_ptr<graph_impl> &Impl,
642  const std::vector<sycl::detail::EventImplPtr> Events);
643 
647  void
648  addQueue(const std::shared_ptr<sycl::detail::queue_impl> &RecordingQueue) {
649  MRecordingQueues.insert(RecordingQueue);
650  }
651 
655  void
656  removeQueue(const std::shared_ptr<sycl::detail::queue_impl> &RecordingQueue) {
657  MRecordingQueues.erase(RecordingQueue);
658  }
659 
664  bool clearQueues();
665 
671  void addEventForNode(std::shared_ptr<graph_impl> GraphImpl,
672  std::shared_ptr<sycl::detail::event_impl> EventImpl,
673  std::shared_ptr<node_impl> NodeImpl) {
674  if (!EventImpl->getCommandGraph())
675  EventImpl->setCommandGraph(GraphImpl);
676  MEventsMap[EventImpl] = NodeImpl;
677  }
678 
682  std::shared_ptr<sycl::detail::event_impl>
683  getEventForNode(std::shared_ptr<node_impl> NodeImpl) const {
684  ReadLock Lock(MMutex);
685  if (auto EventImpl = std::find_if(
686  MEventsMap.begin(), MEventsMap.end(),
687  [NodeImpl](auto &it) { return it.second == NodeImpl; });
688  EventImpl != MEventsMap.end()) {
689  return EventImpl->first;
690  }
691 
692  throw sycl::exception(
694  "No event has been recorded for the specified graph node");
695  }
696 
701  std::shared_ptr<node_impl>
702  getNodeForEvent(std::shared_ptr<sycl::detail::event_impl> EventImpl) {
703  ReadLock Lock(MMutex);
704 
705  if (auto NodeFound = MEventsMap.find(EventImpl);
706  NodeFound != std::end(MEventsMap)) {
707  return NodeFound->second;
708  }
709 
710  throw sycl::exception(
712  "No node in this graph is associated with this event");
713  }
714 
717  sycl::context getContext() const { return MContext; }
718 
721  sycl::device getDevice() const { return MDevice; }
722 
724  std::set<std::weak_ptr<node_impl>, std::owner_less<std::weak_ptr<node_impl>>>
726 
731  std::vector<std::shared_ptr<node_impl>> MNodeStorage;
732 
737  std::shared_ptr<node_impl>
738  getLastInorderNode(std::shared_ptr<sycl::detail::queue_impl> Queue) {
739  std::weak_ptr<sycl::detail::queue_impl> QueueWeakPtr(Queue);
740  if (0 == MInorderQueueMap.count(QueueWeakPtr)) {
741  return {};
742  }
743  return MInorderQueueMap[QueueWeakPtr];
744  }
745 
749  void setLastInorderNode(std::shared_ptr<sycl::detail::queue_impl> Queue,
750  std::shared_ptr<node_impl> Node) {
751  std::weak_ptr<sycl::detail::queue_impl> QueueWeakPtr(Queue);
752  MInorderQueueMap[QueueWeakPtr] = Node;
753  }
754 
759  void printGraphAsDot(const std::string FilePath, bool Verbose) const {
761  std::vector<node_impl *> VisitedNodes;
762 
763  std::fstream Stream(FilePath, std::ios::out);
764  Stream << "digraph dot {" << std::endl;
765 
766  for (std::weak_ptr<node_impl> Node : MRoots)
767  Node.lock()->printDotRecursive(Stream, VisitedNodes, Verbose);
768 
769  Stream << "}" << std::endl;
770 
771  Stream.close();
772  }
773 
779  void makeEdge(std::shared_ptr<node_impl> Src,
780  std::shared_ptr<node_impl> Dest);
781 
785  void throwIfGraphRecordingQueue(const std::string ExceptionMsg) const {
786  if (MRecordingQueues.size()) {
787  throw sycl::exception(make_error_code(sycl::errc::invalid),
788  ExceptionMsg +
789  " cannot be called when a queue "
790  "is currently recording commands to a graph.");
791  }
792  }
793 
798  static bool checkNodeRecursive(const std::shared_ptr<node_impl> &NodeA,
799  const std::shared_ptr<node_impl> &NodeB) {
800  size_t FoundCnt = 0;
801  for (std::weak_ptr<node_impl> &SuccA : NodeA->MSuccessors) {
802  for (std::weak_ptr<node_impl> &SuccB : NodeB->MSuccessors) {
803  if (NodeA->isSimilar(NodeB) &&
804  checkNodeRecursive(SuccA.lock(), SuccB.lock())) {
805  FoundCnt++;
806  break;
807  }
808  }
809  }
810  if (FoundCnt != NodeA->MSuccessors.size()) {
811  return false;
812  }
813 
814  return true;
815  }
816 
828  bool hasSimilarStructure(std::shared_ptr<detail::graph_impl> Graph,
829  bool DebugPrint = false) const {
830  if (this == Graph.get())
831  return true;
832 
833  if (MContext != Graph->MContext) {
834  if (DebugPrint) {
836  "MContext are not the same.");
837  }
838  return false;
839  }
840 
841  if (MDevice != Graph->MDevice) {
842  if (DebugPrint) {
844  "MDevice are not the same.");
845  }
846  return false;
847  }
848 
849  if (MEventsMap.size() != Graph->MEventsMap.size()) {
850  if (DebugPrint) {
852  "MEventsMap sizes are not the same.");
853  }
854  return false;
855  }
856 
857  if (MInorderQueueMap.size() != Graph->MInorderQueueMap.size()) {
858  if (DebugPrint) {
860  "MInorderQueueMap sizes are not the same.");
861  }
862  return false;
863  }
864 
865  if (MRoots.size() != Graph->MRoots.size()) {
866  if (DebugPrint) {
868  "MRoots sizes are not the same.");
869  }
870  return false;
871  }
872 
873  size_t RootsFound = 0;
874  for (std::weak_ptr<node_impl> NodeA : MRoots) {
875  for (std::weak_ptr<node_impl> NodeB : Graph->MRoots) {
876  auto NodeALocked = NodeA.lock();
877  auto NodeBLocked = NodeB.lock();
878 
879  if (NodeALocked->isSimilar(NodeBLocked)) {
880  if (checkNodeRecursive(NodeALocked, NodeBLocked)) {
881  RootsFound++;
882  break;
883  }
884  }
885  }
886  }
887 
888  if (RootsFound != MRoots.size()) {
889  if (DebugPrint) {
891  "Root Nodes do NOT match.");
892  }
893  return false;
894  }
895 
896  return true;
897  }
898 
901  size_t getNumberOfNodes() const { return MNodeStorage.size(); }
902 
906  std::vector<sycl::detail::EventImplPtr> getExitNodesEvents();
907 
911  std::vector<sycl::detail::EventImplPtr>
913  std::vector<sycl::detail::EventImplPtr> Events;
914  for (auto It = MExtraDependencies.begin();
915  It != MExtraDependencies.end();) {
916  if ((*It)->MCGType == sycl::detail::CG::Barrier) {
917  Events.push_back(getEventForNode(*It));
918  It = MExtraDependencies.erase(It);
919  } else {
920  ++It;
921  }
922  }
923  return Events;
924  }
925 
926 private:
933  void
934  searchDepthFirst(std::function<bool(std::shared_ptr<node_impl> &,
935  std::deque<std::shared_ptr<node_impl>> &)>
936  NodeFunc);
937 
942  bool checkForCycles();
943 
946  void addRoot(const std::shared_ptr<node_impl> &Root);
947 
952  std::shared_ptr<node_impl>
953  addNodesToExits(const std::shared_ptr<graph_impl> &Impl,
954  const std::list<std::shared_ptr<node_impl>> &NodeList);
955 
960  void addDepsToNode(std::shared_ptr<node_impl> Node,
961  const std::vector<std::shared_ptr<node_impl>> &Deps) {
962  if (!Deps.empty()) {
963  for (auto &N : Deps) {
964  N->registerSuccessor(Node, N);
965  this->removeRoot(Node);
966  }
967  } else {
968  this->addRoot(Node);
969  }
970  }
971 
973  sycl::context MContext;
976  sycl::device MDevice;
978  std::set<std::weak_ptr<sycl::detail::queue_impl>,
979  std::owner_less<std::weak_ptr<sycl::detail::queue_impl>>>
980  MRecordingQueues;
982  std::unordered_map<std::shared_ptr<sycl::detail::event_impl>,
983  std::shared_ptr<node_impl>>
984  MEventsMap;
988  std::map<std::weak_ptr<sycl::detail::queue_impl>, std::shared_ptr<node_impl>,
989  std::owner_less<std::weak_ptr<sycl::detail::queue_impl>>>
990  MInorderQueueMap;
993  bool MSkipCycleChecks = false;
995  std::set<sycl::detail::SYCLMemObjT *> MMemObjs;
996 
999  bool MAllowBuffers = false;
1000 
1005  std::list<std::shared_ptr<node_impl>> MExtraDependencies;
1006 };
1007 
1010 public:
1011  using ReadLock = std::shared_lock<std::shared_mutex>;
1012  using WriteLock = std::unique_lock<std::shared_mutex>;
1013 
1015  mutable std::shared_mutex MMutex;
1016 
1025  const std::shared_ptr<graph_impl> &GraphImpl)
1026  : MSchedule(), MGraphImpl(GraphImpl), MPiSyncPoints(), MContext(Context),
1027  MRequirements(), MExecutionEvents() {
1028  // Copy nodes from GraphImpl and merge any subgraph nodes into this graph.
1029  duplicateNodes();
1030  }
1031 
1035  ~exec_graph_impl();
1036 
1041  void makePartitions();
1042 
1048  sycl::event enqueue(const std::shared_ptr<sycl::detail::queue_impl> &Queue,
1050 
1056  void createCommandBuffers(sycl::device Device,
1057  std::shared_ptr<partition> &Partition);
1058 
1061  sycl::context getContext() const { return MContext; }
1062 
1065  const std::list<std::shared_ptr<node_impl>> &getSchedule() const {
1066  return MSchedule;
1067  }
1068 
1071  const std::shared_ptr<graph_impl> &getGraphImpl() const { return MGraphImpl; }
1072 
1075  const std::vector<std::shared_ptr<partition>> &getPartitions() const {
1076  return MPartitions;
1077  }
1078 
1085  for (auto Event : MExecutionEvents) {
1086  if (!Event->isCompleted()) {
1087  return false;
1088  }
1089  }
1090  return true;
1091  }
1092 
1094  std::vector<sycl::detail::AccessorImplHost *> getRequirements() const {
1095  return MRequirements;
1096  }
1097 
1098 private:
1107  enqueueNode(sycl::context Ctx, sycl::detail::DeviceImplPtr DeviceImpl,
1109  std::shared_ptr<node_impl> Node);
1110 
1119  enqueueNodeDirect(sycl::context Ctx, sycl::detail::DeviceImplPtr DeviceImpl,
1121  std::shared_ptr<node_impl> Node);
1122 
1129  void findRealDeps(std::vector<sycl::detail::pi::PiExtSyncPoint> &Deps,
1130  std::shared_ptr<node_impl> CurrentNode,
1131  int ReferencePartitionNum);
1132 
1136  void duplicateNodes();
1137 
1142  void printGraphAsDot(const std::string FilePath, bool Verbose) const {
1144  std::vector<node_impl *> VisitedNodes;
1145 
1146  std::fstream Stream(FilePath, std::ios::out);
1147  Stream << "digraph dot {" << std::endl;
1148 
1149  std::vector<std::shared_ptr<node_impl>> Roots;
1150  for (auto &Node : MNodeStorage) {
1151  if (Node->MPredecessors.size() == 0) {
1152  Roots.push_back(Node);
1153  }
1154  }
1155 
1156  for (std::shared_ptr<node_impl> Node : Roots)
1157  Node->printDotRecursive(Stream, VisitedNodes, Verbose);
1158 
1159  Stream << "}" << std::endl;
1160 
1161  Stream.close();
1162  }
1163 
1165  std::list<std::shared_ptr<node_impl>> MSchedule;
1172  std::shared_ptr<graph_impl> MGraphImpl;
1175  std::unordered_map<std::shared_ptr<node_impl>,
1177  MPiSyncPoints;
1180  std::unordered_map<std::shared_ptr<node_impl>, int> MPartitionNodes;
1182  sycl::context MContext;
1185  std::vector<sycl::detail::AccessorImplHost *> MRequirements;
1188  std::vector<sycl::detail::AccessorImplPtr> MAccessors;
1190  std::vector<sycl::detail::EventImplPtr> MExecutionEvents;
1192  std::vector<std::shared_ptr<partition>> MPartitions;
1194  std::unordered_map<std::shared_ptr<partition>, sycl::detail::EventImplPtr>
1195  MPartitionsExecutionEvents;
1197  std::vector<std::shared_ptr<node_impl>> MNodeStorage;
1198 };
1199 
1200 } // namespace detail
1201 } // namespace experimental
1202 } // namespace oneapi
1203 } // namespace ext
1204 } // namespace _V1
1205 } // namespace sycl
The context class represents a SYCL context on which kernel functions may be executed.
Definition: context.hpp:51
CGTYPE
Type of the command group.
Definition: cg.hpp:56
The SYCL device class encapsulates a single SYCL device on which kernels may be executed.
Definition: device.hpp:59
detail::is_device_info_desc< Param >::return_type get_info() const
Queries this SYCL device for information requested by the template parameter param.
Definition: device.cpp:139
backend get_backend() const noexcept
Returns the backend associated with this device.
Definition: device.cpp:208
An event object can be used to synchronize memory transfers, enqueues of kernels and signaling barriers.
Definition: event.hpp:44
Class representing the implementation of command_graph<executable>.
void createCommandBuffers(sycl::device Device, std::shared_ptr< partition > &Partition)
Turns the internal graph representation into UR command-buffers for a device.
Definition: graph_impl.cpp:680
std::vector< sycl::detail::AccessorImplHost * > getRequirements() const
Returns a list of all the accessor requirements for this graph.
const std::vector< std::shared_ptr< partition > > & getPartitions() const
Query the vector of the partitions composing the exec_graph.
sycl::context getContext() const
Query for the context tied to this graph.
void makePartitions()
Partition the graph nodes and put the partition in MPartitions.
Definition: graph_impl.cpp:187
const std::list< std::shared_ptr< node_impl > > & getSchedule() const
Query the scheduling of node execution.
exec_graph_impl(sycl::context Context, const std::shared_ptr< graph_impl > &GraphImpl)
Constructor.
const std::shared_ptr< graph_impl > & getGraphImpl() const
Query the graph_impl.
std::shared_mutex MMutex
Protects all the fields that can be changed by class' methods.
bool previousSubmissionCompleted() const
Checks if the previous submissions of this graph have been completed. This function checks the status ...
sycl::event enqueue(const std::shared_ptr< sycl::detail::queue_impl > &Queue, sycl::detail::CG::StorageInitHelper CGData)
Called by handler::ext_oneapi_command_graph() to schedule graph for execution.
Definition: graph_impl.cpp:763
Implementation details of command_graph<modifiable>.
Definition: graph_impl.hpp:565
std::vector< std::shared_ptr< node_impl > > MNodeStorage
Storage for all nodes contained within a graph.
Definition: graph_impl.hpp:731
void addQueue(const std::shared_ptr< sycl::detail::queue_impl > &RecordingQueue)
Add a queue to the set of queues which are currently recording to this graph.
Definition: graph_impl.hpp:648
void removeRoot(const std::shared_ptr< node_impl > &Root)
Remove node from list of root nodes.
Definition: graph_impl.cpp:347
std::unique_lock< std::shared_mutex > WriteLock
Definition: graph_impl.hpp:568
void makeEdge(std::shared_ptr< node_impl > Src, std::shared_ptr< node_impl > Dest)
Make an edge between two nodes in the graph.
Definition: graph_impl.cpp:548
std::shared_mutex MMutex
Protects all the fields that can be changed by the class's methods.
Definition: graph_impl.hpp:571
void printGraphAsDot(const std::string FilePath, bool Verbose) const
Prints the contents of the graph to a text file in DOT format.
Definition: graph_impl.hpp:759
std::vector< sycl::detail::EventImplPtr > getExitNodesEvents()
Traverse the graph recursively to get the events associated with the output nodes of this graph.
Definition: graph_impl.cpp:599
void throwIfGraphRecordingQueue(const std::string ExceptionMsg) const
Throws an invalid exception if this function is called while a queue is recording commands to the graph.
Definition: graph_impl.hpp:785
std::shared_ptr< sycl::detail::event_impl > getEventForNode(std::shared_ptr< node_impl > NodeImpl) const
Find the sycl event associated with a node.
Definition: graph_impl.hpp:683
void setLastInorderNode(std::shared_ptr< sycl::detail::queue_impl > Queue, std::shared_ptr< node_impl > Node)
Track the last node added to this graph from an in-order queue.
Definition: graph_impl.hpp:749
std::set< std::weak_ptr< node_impl >, std::owner_less< std::weak_ptr< node_impl > > > MRoots
List of root nodes.
Definition: graph_impl.hpp:725
bool hasSimilarStructure(std::shared_ptr< detail::graph_impl > Graph, bool DebugPrint=false) const
Checks if the graph_impl of Graph has a similar structure to the graph_impl of the caller.
Definition: graph_impl.hpp:828
std::vector< sycl::detail::EventImplPtr > removeBarriersFromExtraDependencies()
Removes all Barrier nodes from the list of extra dependencies MExtraDependencies.
Definition: graph_impl.hpp:912
bool clearQueues()
Remove all queues which are recording to this graph, and set all cleared queues back to the executing state.
Definition: graph_impl.cpp:494
sycl::device getDevice() const
Query for the device tied to this graph.
Definition: graph_impl.hpp:721
std::shared_ptr< node_impl > getNodeForEvent(std::shared_ptr< sycl::detail::event_impl > EventImpl)
Find the node associated with a SYCL event.
Definition: graph_impl.hpp:702
static bool checkNodeRecursive(const std::shared_ptr< node_impl > &NodeA, const std::shared_ptr< node_impl > &NodeB)
Recursively check successors of NodeA and NodeB to check they are similar.
Definition: graph_impl.hpp:798
graph_impl(const sycl::context &SyclContext, const sycl::device &SyclDevice, const sycl::property_list &PropList={})
Constructor.
Definition: graph_impl.hpp:577
sycl::context getContext() const
Query for the context tied to this graph.
Definition: graph_impl.hpp:717
std::shared_ptr< node_impl > add(node_type NodeType, std::unique_ptr< sycl::detail::CG > CommandGroup, const std::vector< std::shared_ptr< node_impl >> &Dep={})
Create a kernel node in the graph.
Definition: graph_impl.cpp:421
void addEventForNode(std::shared_ptr< graph_impl > GraphImpl, std::shared_ptr< sycl::detail::event_impl > EventImpl, std::shared_ptr< node_impl > NodeImpl)
Associate a sycl event with a node in the graph.
Definition: graph_impl.hpp:671
std::shared_lock< std::shared_mutex > ReadLock
Definition: graph_impl.hpp:567
size_t getNumberOfNodes() const
Returns the number of nodes in the Graph.
Definition: graph_impl.hpp:901
std::shared_ptr< node_impl > getLastInorderNode(std::shared_ptr< sycl::detail::queue_impl > Queue)
Find the last node added to this graph from an in-order queue.
Definition: graph_impl.hpp:738
void removeQueue(const std::shared_ptr< sycl::detail::queue_impl > &RecordingQueue)
Remove a queue from the set of queues which are currently recording to this graph.
Definition: graph_impl.hpp:656
Implementation of node class from SYCL_EXT_ONEAPI_GRAPH.
Definition: graph_impl.hpp:76
void registerSuccessor(const std::shared_ptr< node_impl > &Node, const std::shared_ptr< node_impl > &Prev)
Add successor to the node.
Definition: graph_impl.hpp:108
std::vector< std::weak_ptr< node_impl > > MPredecessors
List of predecessors to this node.
Definition: graph_impl.hpp:83
node_impl(node_impl &Other)
Construct a node from another node.
Definition: graph_impl.hpp:152
std::vector< std::weak_ptr< node_impl > > MSuccessors
List of successors to this node.
Definition: graph_impl.hpp:79
sycl::detail::CG::CGTYPE MCGType
Type of the command-group for the node.
Definition: graph_impl.hpp:85
std::shared_ptr< exec_graph_impl > MSubGraphImpl
Stores the executable graph impl associated with this node if it is a subgraph node.
Definition: graph_impl.hpp:92
node_impl(node_type NodeType, std::unique_ptr< sycl::detail::CG > &&CommandGroup)
Construct a node representing a command-group.
Definition: graph_impl.hpp:139
bool hasRequirement(sycl::detail::AccessorImplHost *IncomingReq)
Checks if this node has a given requirement.
Definition: graph_impl.hpp:160
int MPartitionNum
Partition number needed to assign a Node to a partition.
Definition: graph_impl.hpp:100
bool isSimilar(const std::shared_ptr< node_impl > &Node, bool CompareContentOnly=false) const
Tests if the caller is similar to Node; this is only used for testing.
Definition: graph_impl.hpp:262
void printDotRecursive(std::fstream &Stream, std::vector< node_impl * > &Visited, bool Verbose)
Recursive depth-first traversal of linked nodes.
Definition: graph_impl.hpp:313
node_type MNodeType
User facing type of the node.
Definition: graph_impl.hpp:87
std::unique_ptr< sycl::detail::CG > getCGCopy() const
Get a deep copy of this node's command group.
Definition: graph_impl.hpp:181
bool isEmpty() const
Query if this is an empty node.
Definition: graph_impl.hpp:174
void registerPredecessor(const std::shared_ptr< node_impl > &Node)
Add predecessor to the node.
Definition: graph_impl.hpp:122
bool MVisited
Used for tracking visited status during cycle checks.
Definition: graph_impl.hpp:95
std::unique_ptr< sycl::detail::CG > MCommandGroup
Command group object which stores all args etc needed to enqueue the node.
Definition: graph_impl.hpp:89
std::unordered_map< sycl::device, sycl::detail::pi::PiExtCommandBuffer > MPiCommandBuffers
Map of devices to command buffers.
Definition: graph_impl.hpp:550
std::set< std::weak_ptr< node_impl >, std::owner_less< std::weak_ptr< node_impl > > > MRoots
List of root nodes.
Definition: graph_impl.hpp:545
std::vector< std::shared_ptr< partition > > MPredecessors
List of predecessors to this partition.
Definition: graph_impl.hpp:552
std::list< std::shared_ptr< node_impl > > MSchedule
Execution schedule of nodes in the graph.
Definition: graph_impl.hpp:547
Property passed to command_graph constructor to disable checking for cycles.
Definition: graph.hpp:134
Command group handler class.
Definition: handler.hpp:453
Objects of the property_list class are containers for the SYCL properties.
::pi_ext_sync_point PiExtSyncPoint
Definition: pi.hpp:156
std::shared_ptr< event_impl > EventImplPtr
Definition: cg.hpp:43
std::shared_ptr< device_impl > DeviceImplPtr
node_type getNodeTypeFromCG(sycl::detail::CG::CGTYPE CGType)
Definition: graph_impl.hpp:40
std::error_code make_error_code(sycl::errc E) noexcept
Constructs an error code using E and sycl_category().
Definition: exception.cpp:94
Definition: access.hpp:18