DPC++ Runtime
Runtime libraries for oneAPI DPC++
graph_impl.hpp
Go to the documentation of this file.
1 //==--------- graph_impl.hpp --- SYCL graph extension ---------------------==//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #pragma once
10 
#include <sycl/detail/cg_types.hpp>
#include <sycl/detail/os_util.hpp>
#include <sycl/handler.hpp>

#include <detail/accessor_impl.hpp>
#include <detail/event_impl.hpp>
#include <detail/kernel_impl.hpp>

#include <algorithm>
#include <atomic>
#include <cstring>
#include <deque>
#include <fstream>
#include <functional>
#include <list>
#include <set>
#include <shared_mutex>
28 
29 namespace sycl {
30 inline namespace _V1 {
31 
32 namespace detail {
33 class SYCLMemObjT;
34 }
35 
36 namespace ext {
37 namespace oneapi {
38 namespace experimental {
39 namespace detail {
40 
42  using sycl::detail::CG;
43 
44  switch (CGType) {
45  case CG::None:
46  return node_type::empty;
47  case CG::Kernel:
48  return node_type::kernel;
49  case CG::CopyAccToPtr:
50  case CG::CopyPtrToAcc:
51  case CG::CopyAccToAcc:
52  case CG::CopyUSM:
53  return node_type::memcpy;
54  case CG::Memset2DUSM:
55  return node_type::memset;
56  case CG::Fill:
57  case CG::FillUSM:
58  return node_type::memfill;
59  case CG::PrefetchUSM:
60  return node_type::prefetch;
61  case CG::AdviseUSM:
62  return node_type::memadvise;
63  case CG::Barrier:
64  case CG::BarrierWaitlist:
66  case CG::CodeplayHostTask:
67  return node_type::host_task;
68  case CG::ExecCommandBuffer:
69  return node_type::subgraph;
70  default:
71  assert(false && "Invalid Graph Node Type");
72  return node_type::empty;
73  }
74 }
75 
77 class node_impl {
78 public:
79  using id_type = uint64_t;
80 
82  id_type MID = getNextNodeID();
84  std::vector<std::weak_ptr<node_impl>> MSuccessors;
88  std::vector<std::weak_ptr<node_impl>> MPredecessors;
90  sycl::detail::CG::CGTYPE MCGType = sycl::detail::CG::None;
94  std::unique_ptr<sycl::detail::CG> MCommandGroup;
97  std::shared_ptr<exec_graph_impl> MSubGraphImpl;
98 
100  bool MVisited = false;
101 
105  int MPartitionNum = -1;
106 
108  bool MNDRangeUsed = false;
109 
116  void registerSuccessor(const std::shared_ptr<node_impl> &Node,
117  const std::shared_ptr<node_impl> &Prev) {
118  if (std::find_if(MSuccessors.begin(), MSuccessors.end(),
119  [Node](const std::weak_ptr<node_impl> &Ptr) {
120  return Ptr.lock() == Node;
121  }) != MSuccessors.end()) {
122  return;
123  }
124  MSuccessors.push_back(Node);
125  Node->registerPredecessor(Prev);
126  }
127 
130  void registerPredecessor(const std::shared_ptr<node_impl> &Node) {
131  if (std::find_if(MPredecessors.begin(), MPredecessors.end(),
132  [&Node](const std::weak_ptr<node_impl> &Ptr) {
133  return Ptr.lock() == Node;
134  }) != MPredecessors.end()) {
135  return;
136  }
137  MPredecessors.push_back(Node);
138  }
139 
142 
148  std::unique_ptr<sycl::detail::CG> &&CommandGroup)
149  : MCGType(CommandGroup->getType()), MNodeType(NodeType),
150  MCommandGroup(std::move(CommandGroup)) {
151  if (NodeType == node_type::subgraph) {
152  MSubGraphImpl =
153  static_cast<sycl::detail::CGExecCommandBuffer *>(MCommandGroup.get())
154  ->MExecGraph;
155  }
156  }
157 
162  MCGType(Other.MCGType), MNodeType(Other.MNodeType),
164 
168  if (this != &Other) {
169  MSuccessors = Other.MSuccessors;
171  MCGType = Other.MCGType;
172  MNodeType = Other.MNodeType;
173  MCommandGroup = Other.getCGCopy();
175  }
176  return *this;
177  }
183  bool hasRequirementDependency(sycl::detail::AccessorImplHost *IncomingReq) {
184  access_mode InMode = IncomingReq->MAccessMode;
185  switch (InMode) {
186  case access_mode::read:
188  case access_mode::atomic:
189  break;
190  // These access modes don't care about existing buffer data, so we don't
191  // need a dependency.
192  case access_mode::write:
193  case access_mode::discard_read_write:
194  case access_mode::discard_write:
195  return false;
196  }
197 
198  for (sycl::detail::AccessorImplHost *CurrentReq :
199  MCommandGroup->getRequirements()) {
200  if (IncomingReq->MSYCLMemObj == CurrentReq->MSYCLMemObj) {
201  access_mode CurrentMode = CurrentReq->MAccessMode;
202  // Since we have an incoming read requirement, we only care
203  // about requirements on this node if they are write
204  if (CurrentMode != access_mode::read) {
205  return true;
206  }
207  }
208  }
209  // No dependency necessary
210  return false;
211  }
212 
217  bool isEmpty() const {
218  return ((MCGType == sycl::detail::CG::None) ||
219  (MCGType == sycl::detail::CG::Barrier));
220  }
221 
224  std::unique_ptr<sycl::detail::CG> getCGCopy() const {
225  switch (MCGType) {
226  case sycl::detail::CG::Kernel: {
227  auto CGCopy = createCGCopy<sycl::detail::CGExecKernel>();
228  rebuildArgStorage(CGCopy->MArgs, MCommandGroup->getArgsStorage(),
229  CGCopy->getArgsStorage());
230  return std::move(CGCopy);
231  }
232  case sycl::detail::CG::CopyAccToPtr:
233  case sycl::detail::CG::CopyPtrToAcc:
234  case sycl::detail::CG::CopyAccToAcc:
235  return createCGCopy<sycl::detail::CGCopy>();
236  case sycl::detail::CG::Fill:
237  return createCGCopy<sycl::detail::CGFill>();
238  case sycl::detail::CG::UpdateHost:
239  return createCGCopy<sycl::detail::CGUpdateHost>();
240  case sycl::detail::CG::CopyUSM:
241  return createCGCopy<sycl::detail::CGCopyUSM>();
242  case sycl::detail::CG::FillUSM:
243  return createCGCopy<sycl::detail::CGFillUSM>();
244  case sycl::detail::CG::PrefetchUSM:
245  return createCGCopy<sycl::detail::CGPrefetchUSM>();
246  case sycl::detail::CG::AdviseUSM:
247  return createCGCopy<sycl::detail::CGAdviseUSM>();
248  case sycl::detail::CG::Copy2DUSM:
249  return createCGCopy<sycl::detail::CGCopy2DUSM>();
250  case sycl::detail::CG::Fill2DUSM:
251  return createCGCopy<sycl::detail::CGFill2DUSM>();
252  case sycl::detail::CG::Memset2DUSM:
253  return createCGCopy<sycl::detail::CGMemset2DUSM>();
254  case sycl::detail::CG::CodeplayHostTask: {
255  // The unique_ptr to the `sycl::detail::HostTask` in the HostTask CG
256  // prevents from copying the CG.
257  // We overcome this restriction by creating a new CG with the same data.
258  auto CommandGroupPtr =
259  static_cast<sycl::detail::CGHostTask *>(MCommandGroup.get());
260  sycl::detail::HostTask HostTask = *CommandGroupPtr->MHostTask.get();
261  auto HostTaskUPtr = std::make_unique<sycl::detail::HostTask>(HostTask);
262 
264  CommandGroupPtr->getArgsStorage(), CommandGroupPtr->getAccStorage(),
265  CommandGroupPtr->getSharedPtrStorage(),
266  CommandGroupPtr->getRequirements(), CommandGroupPtr->getEvents());
267 
268  std::vector<sycl::detail::ArgDesc> NewArgs = CommandGroupPtr->MArgs;
269 
270  rebuildArgStorage(NewArgs, CommandGroupPtr->getArgsStorage(),
271  Data.MArgsStorage);
272 
273  sycl::detail::code_location Loc(CommandGroupPtr->MFileName.data(),
274  CommandGroupPtr->MFunctionName.data(),
275  CommandGroupPtr->MLine,
276  CommandGroupPtr->MColumn);
277 
278  return std::make_unique<sycl::detail::CGHostTask>(
279  sycl::detail::CGHostTask(
280  std::move(HostTaskUPtr), CommandGroupPtr->MQueue,
281  CommandGroupPtr->MContext, std::move(NewArgs), std::move(Data),
282  CommandGroupPtr->getType(), Loc));
283  }
284  case sycl::detail::CG::Barrier:
285  case sycl::detail::CG::BarrierWaitlist:
286  // Barrier nodes are stored in the graph with only the base CG class,
287  // since they are treated internally as empty nodes.
288  return createCGCopy<sycl::detail::CG>();
289  case sycl::detail::CG::CopyToDeviceGlobal:
290  return createCGCopy<sycl::detail::CGCopyToDeviceGlobal>();
291  case sycl::detail::CG::CopyFromDeviceGlobal:
292  return createCGCopy<sycl::detail::CGCopyFromDeviceGlobal>();
293  case sycl::detail::CG::ReadWriteHostPipe:
294  return createCGCopy<sycl::detail::CGReadWriteHostPipe>();
295  case sycl::detail::CG::CopyImage:
296  return createCGCopy<sycl::detail::CGCopyImage>();
297  case sycl::detail::CG::SemaphoreSignal:
298  return createCGCopy<sycl::detail::CGSemaphoreSignal>();
299  case sycl::detail::CG::SemaphoreWait:
300  return createCGCopy<sycl::detail::CGSemaphoreWait>();
301  case sycl::detail::CG::ExecCommandBuffer:
302  return createCGCopy<sycl::detail::CGExecCommandBuffer>();
303  case sycl::detail::CG::None:
304  return nullptr;
305  }
306  return nullptr;
307  }
308 
314  bool isSimilar(const std::shared_ptr<node_impl> &Node,
315  bool CompareContentOnly = false) const {
316  if (!CompareContentOnly) {
317  if (MSuccessors.size() != Node->MSuccessors.size())
318  return false;
319 
320  if (MPredecessors.size() != Node->MPredecessors.size())
321  return false;
322  }
323  if (MCGType != Node->MCGType)
324  return false;
325 
326  switch (MCGType) {
327  case sycl::detail::CG::CGTYPE::Kernel: {
328  sycl::detail::CGExecKernel *ExecKernelA =
329  static_cast<sycl::detail::CGExecKernel *>(MCommandGroup.get());
330  sycl::detail::CGExecKernel *ExecKernelB =
331  static_cast<sycl::detail::CGExecKernel *>(Node->MCommandGroup.get());
332  return ExecKernelA->MKernelName.compare(ExecKernelB->MKernelName) == 0;
333  }
334  case sycl::detail::CG::CGTYPE::CopyUSM: {
335  sycl::detail::CGCopyUSM *CopyA =
336  static_cast<sycl::detail::CGCopyUSM *>(MCommandGroup.get());
337  sycl::detail::CGCopyUSM *CopyB =
338  static_cast<sycl::detail::CGCopyUSM *>(Node->MCommandGroup.get());
339  return (CopyA->getSrc() == CopyB->getSrc()) &&
340  (CopyA->getDst() == CopyB->getDst()) &&
341  (CopyA->getLength() == CopyB->getLength());
342  }
343  case sycl::detail::CG::CGTYPE::CopyAccToAcc:
344  case sycl::detail::CG::CGTYPE::CopyAccToPtr:
345  case sycl::detail::CG::CGTYPE::CopyPtrToAcc: {
346  sycl::detail::CGCopy *CopyA =
347  static_cast<sycl::detail::CGCopy *>(MCommandGroup.get());
348  sycl::detail::CGCopy *CopyB =
349  static_cast<sycl::detail::CGCopy *>(Node->MCommandGroup.get());
350  return (CopyA->getSrc() == CopyB->getSrc()) &&
351  (CopyA->getDst() == CopyB->getDst());
352  }
353  default:
354  assert(false && "Unexpected command group type!");
355  return false;
356  }
357  }
358 
365  void printDotRecursive(std::fstream &Stream,
366  std::vector<node_impl *> &Visited, bool Verbose) {
367  // if Node has been already visited, we skip it
368  if (std::find(Visited.begin(), Visited.end(), this) != Visited.end())
369  return;
370 
371  Visited.push_back(this);
372 
373  printDotCG(Stream, Verbose);
374  for (const auto &Dep : MPredecessors) {
375  auto NodeDep = Dep.lock();
376  Stream << " \"" << NodeDep.get() << "\" -> \"" << this << "\""
377  << std::endl;
378  }
379 
380  for (std::weak_ptr<node_impl> Succ : MSuccessors) {
381  if (MPartitionNum == Succ.lock()->MPartitionNum)
382  Succ.lock()->printDotRecursive(Stream, Visited, Verbose);
383  }
384  }
385 
390  void updateAccessor(int ArgIndex, const sycl::detail::AccessorBaseHost *Acc) {
391  auto &Args =
392  static_cast<sycl::detail::CGExecKernel *>(MCommandGroup.get())->MArgs;
393  auto NewAccImpl = sycl::detail::getSyclObjImpl(*Acc);
394  for (auto &Arg : Args) {
395  if (Arg.MIndex != ArgIndex) {
396  continue;
397  }
398  assert(Arg.MType == sycl::detail::kernel_param_kind_t::kind_accessor);
399 
400  // Find old accessor in accessor storage and replace with new one
401  if (static_cast<sycl::detail::SYCLMemObjT *>(NewAccImpl->MSYCLMemObj)
402  ->needsWriteBack()) {
403  throw sycl::exception(
405  "Accessors to buffers which have write_back enabled "
406  "are not allowed to be used in command graphs.");
407  }
408 
409  // All accessors passed to this function will be placeholders, so we must
410  // perform steps similar to what happens when handler::require() is
411  // called here.
412  sycl::detail::Requirement *NewReq = NewAccImpl.get();
413  if (NewReq->MAccessMode != sycl::access_mode::read) {
414  auto SYCLMemObj =
415  static_cast<sycl::detail::SYCLMemObjT *>(NewReq->MSYCLMemObj);
416  SYCLMemObj->handleWriteAccessorCreation();
417  }
418 
419  for (auto &Acc : MCommandGroup->getAccStorage()) {
420  if (auto OldAcc =
421  static_cast<sycl::detail::AccessorImplHost *>(Arg.MPtr);
422  Acc.get() == OldAcc) {
423  Acc = NewAccImpl;
424  }
425  }
426 
427  for (auto &Req : MCommandGroup->getRequirements()) {
428  if (auto OldReq =
429  static_cast<sycl::detail::AccessorImplHost *>(Arg.MPtr);
430  Req == OldReq) {
431  Req = NewReq;
432  }
433  }
434  Arg.MPtr = NewAccImpl.get();
435  break;
436  }
437  }
438 
439  void updateArgValue(int ArgIndex, const void *NewValue, size_t Size) {
440 
441  auto &Args =
442  static_cast<sycl::detail::CGExecKernel *>(MCommandGroup.get())->MArgs;
443  for (auto &Arg : Args) {
444  if (Arg.MIndex != ArgIndex) {
445  continue;
446  }
447  assert(Arg.MSize == static_cast<int>(Size));
448  // MPtr may be a pointer into arg storage so we memcpy the contents of
449  // NewValue rather than assign it directly
450  std::memcpy(Arg.MPtr, NewValue, Size);
451  break;
452  }
453  }
454 
455  template <int Dimensions>
456  void updateNDRange(nd_range<Dimensions> ExecutionRange) {
457  if (MCGType != sycl::detail::CG::Kernel) {
458  throw sycl::exception(
459  sycl::errc::invalid,
460  "Cannot update execution range of nodes which are not kernel nodes");
461  }
462  if (!MNDRangeUsed) {
463  throw sycl::exception(sycl::errc::invalid,
464  "Cannot update node which was created with a "
465  "sycl::range with a sycl::nd_range");
466  }
467 
468  auto &NDRDesc =
469  static_cast<sycl::detail::CGExecKernel *>(MCommandGroup.get())
470  ->MNDRDesc;
471 
472  if (NDRDesc.Dims != Dimensions) {
473  throw sycl::exception(sycl::errc::invalid,
474  "Cannot update execution range of a node with an "
475  "execution range of different dimensions than what "
476  "the node was originall created with.");
477  }
478 
479  NDRDesc.set(ExecutionRange);
480  }
481 
482  template <int Dimensions> void updateRange(range<Dimensions> ExecutionRange) {
483  if (MCGType != sycl::detail::CG::Kernel) {
484  throw sycl::exception(
485  sycl::errc::invalid,
486  "Cannot update execution range of nodes which are not kernel nodes");
487  }
488  if (MNDRangeUsed) {
489  throw sycl::exception(sycl::errc::invalid,
490  "Cannot update node which was created with a "
491  "sycl::nd_range with a sycl::range");
492  }
493 
494  auto &NDRDesc =
495  static_cast<sycl::detail::CGExecKernel *>(MCommandGroup.get())
496  ->MNDRDesc;
497 
498  if (NDRDesc.Dims != Dimensions) {
499  throw sycl::exception(sycl::errc::invalid,
500  "Cannot update execution range of a node with an "
501  "execution range of different dimensions than what "
502  "the node was originall created with.");
503  }
504 
505  NDRDesc.set(ExecutionRange);
506  }
507 
508  void updateFromOtherNode(const std::shared_ptr<node_impl> &Other) {
509  auto ExecCG =
510  static_cast<sycl::detail::CGExecKernel *>(MCommandGroup.get());
511  auto OtherExecCG =
512  static_cast<sycl::detail::CGExecKernel *>(Other->MCommandGroup.get());
513 
514  ExecCG->MArgs = OtherExecCG->MArgs;
515  ExecCG->MNDRDesc = OtherExecCG->MNDRDesc;
516  ExecCG->getAccStorage() = OtherExecCG->getAccStorage();
517  ExecCG->getRequirements() = OtherExecCG->getRequirements();
518 
519  auto &OldArgStorage = OtherExecCG->getArgsStorage();
520  auto &NewArgStorage = ExecCG->getArgsStorage();
521  // Rebuild the arg storage and update the args
522  rebuildArgStorage(ExecCG->MArgs, OldArgStorage, NewArgStorage);
523  }
524 
525  id_type getID() const { return MID; }
526 
527 private:
528  void rebuildArgStorage(std::vector<sycl::detail::ArgDesc> &Args,
529  const std::vector<std::vector<char>> &OldArgStorage,
530  std::vector<std::vector<char>> &NewArgStorage) const {
531  // Clear the arg storage so we can rebuild it
532  NewArgStorage.clear();
533 
534  // Loop over all the args, any std_layout ones need their pointers updated
535  // to point to the new arg storage.
536  for (auto &Arg : Args) {
537  if (Arg.MType != sycl::detail::kernel_param_kind_t::kind_std_layout) {
538  continue;
539  }
540  // Find which ArgStorage Arg.MPtr is pointing to
541  for (auto &ArgStorage : OldArgStorage) {
542  if (ArgStorage.data() != Arg.MPtr) {
543  continue;
544  }
545  NewArgStorage.emplace_back(Arg.MSize);
546  // Memcpy contents from old storage to new storage
547  std::memcpy(NewArgStorage.back().data(), ArgStorage.data(), Arg.MSize);
548  // Update MPtr to point to the new storage instead of the old
549  Arg.MPtr = NewArgStorage.back().data();
550 
551  break;
552  }
553  }
554  }
555  // Gets the next unique identifier for a node, should only be used when
556  // constructing nodes.
557  static id_type getNextNodeID() {
558  static id_type nextID = 0;
559 
560  // Return the value then increment the next ID
561  return nextID++;
562  }
563 
568  void printDotCG(std::ostream &Stream, bool Verbose) {
569  Stream << "\"" << this << "\" [style=bold, label=\"";
570 
571  Stream << "ID = " << this << "\\n";
572  Stream << "TYPE = ";
573 
574  switch (MCGType) {
575  case sycl::detail::CG::CGTYPE::None:
576  Stream << "None \\n";
577  break;
578  case sycl::detail::CG::CGTYPE::Kernel: {
579  Stream << "CGExecKernel \\n";
580  sycl::detail::CGExecKernel *Kernel =
581  static_cast<sycl::detail::CGExecKernel *>(MCommandGroup.get());
582  Stream << "NAME = " << Kernel->MKernelName << "\\n";
583  if (Verbose) {
584  Stream << "ARGS = \\n";
585  for (size_t i = 0; i < Kernel->MArgs.size(); i++) {
586  auto Arg = Kernel->MArgs[i];
587  std::string Type = "Undefined";
588  if (Arg.MType == sycl::detail::kernel_param_kind_t::kind_accessor) {
589  Type = "Accessor";
590  } else if (Arg.MType ==
591  sycl::detail::kernel_param_kind_t::kind_std_layout) {
592  Type = "STD_Layout";
593  } else if (Arg.MType ==
594  sycl::detail::kernel_param_kind_t::kind_sampler) {
595  Type = "Sampler";
596  } else if (Arg.MType ==
597  sycl::detail::kernel_param_kind_t::kind_pointer) {
598  Type = "Pointer";
599  } else if (Arg.MType == sycl::detail::kernel_param_kind_t::
600  kind_specialization_constants_buffer) {
601  Type = "Specialization Constants Buffer";
602  } else if (Arg.MType ==
603  sycl::detail::kernel_param_kind_t::kind_stream) {
604  Type = "Stream";
605  } else if (Arg.MType ==
606  sycl::detail::kernel_param_kind_t::kind_invalid) {
607  Type = "Invalid";
608  }
609  Stream << i << ") Type: " << Type << " Ptr: " << Arg.MPtr << "\\n";
610  }
611  }
612  break;
613  }
614  case sycl::detail::CG::CGTYPE::CopyAccToPtr:
615  Stream << "CGCopy Device-to-Host \\n";
616  if (Verbose) {
617  sycl::detail::CGCopy *Copy =
618  static_cast<sycl::detail::CGCopy *>(MCommandGroup.get());
619  Stream << "Src: " << Copy->getSrc() << " Dst: " << Copy->getDst()
620  << "\\n";
621  }
622  break;
623  case sycl::detail::CG::CGTYPE::CopyPtrToAcc:
624  Stream << "CGCopy Host-to-Device \\n";
625  if (Verbose) {
626  sycl::detail::CGCopy *Copy =
627  static_cast<sycl::detail::CGCopy *>(MCommandGroup.get());
628  Stream << "Src: " << Copy->getSrc() << " Dst: " << Copy->getDst()
629  << "\\n";
630  }
631  break;
632  case sycl::detail::CG::CGTYPE::CopyAccToAcc:
633  Stream << "CGCopy Device-to-Device \\n";
634  if (Verbose) {
635  sycl::detail::CGCopy *Copy =
636  static_cast<sycl::detail::CGCopy *>(MCommandGroup.get());
637  Stream << "Src: " << Copy->getSrc() << " Dst: " << Copy->getDst()
638  << "\\n";
639  }
640  break;
641  case sycl::detail::CG::CGTYPE::Fill:
642  Stream << "CGFill \\n";
643  if (Verbose) {
644  sycl::detail::CGFill *Fill =
645  static_cast<sycl::detail::CGFill *>(MCommandGroup.get());
646  Stream << "Ptr: " << Fill->MPtr << "\\n";
647  }
648  break;
649  case sycl::detail::CG::CGTYPE::UpdateHost:
650  Stream << "CGCUpdateHost \\n";
651  if (Verbose) {
652  sycl::detail::CGUpdateHost *Host =
653  static_cast<sycl::detail::CGUpdateHost *>(MCommandGroup.get());
654  Stream << "Ptr: " << Host->getReqToUpdate() << "\\n";
655  }
656  break;
657  case sycl::detail::CG::CGTYPE::CopyUSM:
658  Stream << "CGCopyUSM \\n";
659  if (Verbose) {
660  sycl::detail::CGCopyUSM *CopyUSM =
661  static_cast<sycl::detail::CGCopyUSM *>(MCommandGroup.get());
662  Stream << "Src: " << CopyUSM->getSrc() << " Dst: " << CopyUSM->getDst()
663  << " Length: " << CopyUSM->getLength() << "\\n";
664  }
665  break;
666  case sycl::detail::CG::CGTYPE::FillUSM:
667  Stream << "CGFillUSM \\n";
668  if (Verbose) {
669  sycl::detail::CGFillUSM *FillUSM =
670  static_cast<sycl::detail::CGFillUSM *>(MCommandGroup.get());
671  Stream << "Dst: " << FillUSM->getDst()
672  << " Length: " << FillUSM->getLength()
673  << " Pattern: " << FillUSM->getFill() << "\\n";
674  }
675  break;
676  case sycl::detail::CG::CGTYPE::PrefetchUSM:
677  Stream << "CGPrefetchUSM \\n";
678  if (Verbose) {
679  sycl::detail::CGPrefetchUSM *Prefetch =
680  static_cast<sycl::detail::CGPrefetchUSM *>(MCommandGroup.get());
681  Stream << "Dst: " << Prefetch->getDst()
682  << " Length: " << Prefetch->getLength() << "\\n";
683  }
684  break;
685  case sycl::detail::CG::CGTYPE::AdviseUSM:
686  Stream << "CGAdviseUSM \\n";
687  if (Verbose) {
688  sycl::detail::CGAdviseUSM *AdviseUSM =
689  static_cast<sycl::detail::CGAdviseUSM *>(MCommandGroup.get());
690  Stream << "Dst: " << AdviseUSM->getDst()
691  << " Length: " << AdviseUSM->getLength() << "\\n";
692  }
693  break;
694  case sycl::detail::CG::CGTYPE::CodeplayHostTask:
695  Stream << "CGHostTask \\n";
696  break;
697  case sycl::detail::CG::CGTYPE::Barrier:
698  Stream << "CGBarrier \\n";
699  break;
700  case sycl::detail::CG::CGTYPE::Copy2DUSM:
701  Stream << "CGCopy2DUSM \\n";
702  if (Verbose) {
703  sycl::detail::CGCopy2DUSM *Copy2DUSM =
704  static_cast<sycl::detail::CGCopy2DUSM *>(MCommandGroup.get());
705  Stream << "Src:" << Copy2DUSM->getSrc()
706  << " Dst: " << Copy2DUSM->getDst() << "\\n";
707  }
708  break;
709  case sycl::detail::CG::CGTYPE::Fill2DUSM:
710  Stream << "CGFill2DUSM \\n";
711  if (Verbose) {
712  sycl::detail::CGFill2DUSM *Fill2DUSM =
713  static_cast<sycl::detail::CGFill2DUSM *>(MCommandGroup.get());
714  Stream << "Dst: " << Fill2DUSM->getDst() << "\\n";
715  }
716  break;
717  case sycl::detail::CG::CGTYPE::Memset2DUSM:
718  Stream << "CGMemset2DUSM \\n";
719  if (Verbose) {
720  sycl::detail::CGMemset2DUSM *Memset2DUSM =
721  static_cast<sycl::detail::CGMemset2DUSM *>(MCommandGroup.get());
722  Stream << "Dst: " << Memset2DUSM->getDst() << "\\n";
723  }
724  break;
725  case sycl::detail::CG::CGTYPE::ReadWriteHostPipe:
726  Stream << "CGReadWriteHostPipe \\n";
727  break;
728  case sycl::detail::CG::CGTYPE::CopyToDeviceGlobal:
729  Stream << "CGCopyToDeviceGlobal \\n";
730  if (Verbose) {
731  sycl::detail::CGCopyToDeviceGlobal *CopyToDeviceGlobal =
732  static_cast<sycl::detail::CGCopyToDeviceGlobal *>(
733  MCommandGroup.get());
734  Stream << "Src: " << CopyToDeviceGlobal->getSrc()
735  << " Dst: " << CopyToDeviceGlobal->getDeviceGlobalPtr() << "\\n";
736  }
737  break;
738  case sycl::detail::CG::CGTYPE::CopyFromDeviceGlobal:
739  Stream << "CGCopyFromDeviceGlobal \\n";
740  if (Verbose) {
741  sycl::detail::CGCopyFromDeviceGlobal *CopyFromDeviceGlobal =
742  static_cast<sycl::detail::CGCopyFromDeviceGlobal *>(
743  MCommandGroup.get());
744  Stream << "Src: " << CopyFromDeviceGlobal->getDeviceGlobalPtr()
745  << " Dst: " << CopyFromDeviceGlobal->getDest() << "\\n";
746  }
747  break;
748  case sycl::detail::CG::CGTYPE::ExecCommandBuffer:
749  Stream << "CGExecCommandBuffer \\n";
750  break;
751  default:
752  Stream << "Other \\n";
753  break;
754  }
755  Stream << "\"];" << std::endl;
756  }
757 
762  template <typename CGT> std::unique_ptr<CGT> createCGCopy() const {
763  return std::make_unique<CGT>(*static_cast<CGT *>(MCommandGroup.get()));
764  }
765 };
766 
767 class partition {
768 public:
771 
773  std::set<std::weak_ptr<node_impl>, std::owner_less<std::weak_ptr<node_impl>>>
776  std::list<std::shared_ptr<node_impl>> MSchedule;
778  std::unordered_map<sycl::device, sycl::detail::pi::PiExtCommandBuffer>
781  std::vector<std::shared_ptr<partition>> MPredecessors;
782 
784  bool isHostTask() const {
785  return (MRoots.size() && ((*MRoots.begin()).lock()->MCGType ==
786  sycl::detail::CG::CGTYPE::CodeplayHostTask));
787  }
788 
790  void schedule();
791 };
792 
794 class graph_impl {
795 public:
796  using ReadLock = std::shared_lock<std::shared_mutex>;
797  using WriteLock = std::unique_lock<std::shared_mutex>;
798 
800  mutable std::shared_mutex MMutex;
801 
806  graph_impl(const sycl::context &SyclContext, const sycl::device &SyclDevice,
807  const sycl::property_list &PropList = {})
808  : MContext(SyclContext), MDevice(SyclDevice), MRecordingQueues(),
809  MEventsMap(), MInorderQueueMap() {
810  if (PropList.has_property<property::graph::no_cycle_check>()) {
811  MSkipCycleChecks = true;
812  }
813  if (PropList
814  .has_property<property::graph::assume_buffer_outlives_graph>()) {
815  MAllowBuffers = true;
816  }
817 
818  if (!SyclDevice.has(aspect::ext_oneapi_limited_graph) &&
819  !SyclDevice.has(aspect::ext_oneapi_graph)) {
820  std::stringstream Stream;
821  Stream << SyclDevice.get_backend();
822  std::string BackendString = Stream.str();
823  throw sycl::exception(
825  BackendString + " backend is not supported by SYCL Graph extension.");
826  }
827  }
828 
829  ~graph_impl();
830 
833  void removeRoot(const std::shared_ptr<node_impl> &Root);
834 
840  std::shared_ptr<node_impl>
841  add(node_type NodeType, std::unique_ptr<sycl::detail::CG> CommandGroup,
842  const std::vector<std::shared_ptr<node_impl>> &Dep = {});
843 
850  std::shared_ptr<node_impl>
851  add(const std::shared_ptr<graph_impl> &Impl,
852  std::function<void(handler &)> CGF,
853  const std::vector<sycl::detail::ArgDesc> &Args,
854  const std::vector<std::shared_ptr<node_impl>> &Dep = {});
855 
860  std::shared_ptr<node_impl>
861  add(const std::shared_ptr<graph_impl> &Impl,
862  const std::vector<std::shared_ptr<node_impl>> &Dep = {});
863 
868  std::shared_ptr<node_impl>
869  add(const std::shared_ptr<graph_impl> &Impl,
870  const std::vector<sycl::detail::EventImplPtr> Events);
871 
875  void
876  addQueue(const std::shared_ptr<sycl::detail::queue_impl> &RecordingQueue) {
877  MRecordingQueues.insert(RecordingQueue);
878  }
879 
883  void
884  removeQueue(const std::shared_ptr<sycl::detail::queue_impl> &RecordingQueue) {
885  MRecordingQueues.erase(RecordingQueue);
886  }
887 
892  bool clearQueues();
893 
899  void addEventForNode(std::shared_ptr<graph_impl> GraphImpl,
900  std::shared_ptr<sycl::detail::event_impl> EventImpl,
901  std::shared_ptr<node_impl> NodeImpl) {
902  if (EventImpl && !(EventImpl->getCommandGraph()))
903  EventImpl->setCommandGraph(GraphImpl);
904  MEventsMap[EventImpl] = NodeImpl;
905  }
906 
910  std::shared_ptr<sycl::detail::event_impl>
911  getEventForNode(std::shared_ptr<node_impl> NodeImpl) const {
912  ReadLock Lock(MMutex);
913  if (auto EventImpl = std::find_if(
914  MEventsMap.begin(), MEventsMap.end(),
915  [NodeImpl](auto &it) { return it.second == NodeImpl; });
916  EventImpl != MEventsMap.end()) {
917  return EventImpl->first;
918  }
919 
920  throw sycl::exception(
922  "No event has been recorded for the specified graph node");
923  }
924 
929  std::shared_ptr<node_impl>
930  getNodeForEvent(std::shared_ptr<sycl::detail::event_impl> EventImpl) {
931  ReadLock Lock(MMutex);
932 
933  if (auto NodeFound = MEventsMap.find(EventImpl);
934  NodeFound != std::end(MEventsMap)) {
935  return NodeFound->second;
936  }
937 
938  throw sycl::exception(
940  "No node in this graph is associated with this event");
941  }
942 
945  sycl::context getContext() const { return MContext; }
946 
949  sycl::device getDevice() const { return MDevice; }
950 
952  std::set<std::weak_ptr<node_impl>, std::owner_less<std::weak_ptr<node_impl>>>
954 
959  std::vector<std::shared_ptr<node_impl>> MNodeStorage;
960 
965  std::shared_ptr<node_impl>
966  getLastInorderNode(std::shared_ptr<sycl::detail::queue_impl> Queue) {
967  std::weak_ptr<sycl::detail::queue_impl> QueueWeakPtr(Queue);
968  if (0 == MInorderQueueMap.count(QueueWeakPtr)) {
969  return {};
970  }
971  return MInorderQueueMap[QueueWeakPtr];
972  }
973 
977  void setLastInorderNode(std::shared_ptr<sycl::detail::queue_impl> Queue,
978  std::shared_ptr<node_impl> Node) {
979  std::weak_ptr<sycl::detail::queue_impl> QueueWeakPtr(Queue);
980  MInorderQueueMap[QueueWeakPtr] = Node;
981  }
982 
987  void printGraphAsDot(const std::string FilePath, bool Verbose) const {
989  std::vector<node_impl *> VisitedNodes;
990 
991  std::fstream Stream(FilePath, std::ios::out);
992  Stream << "digraph dot {" << std::endl;
993 
994  for (std::weak_ptr<node_impl> Node : MRoots)
995  Node.lock()->printDotRecursive(Stream, VisitedNodes, Verbose);
996 
997  Stream << "}" << std::endl;
998 
999  Stream.close();
1000  }
1001 
1007  void makeEdge(std::shared_ptr<node_impl> Src,
1008  std::shared_ptr<node_impl> Dest);
1009 
1013  void throwIfGraphRecordingQueue(const std::string ExceptionMsg) const {
1014  if (MRecordingQueues.size()) {
1015  throw sycl::exception(make_error_code(sycl::errc::invalid),
1016  ExceptionMsg +
1017  " cannot be called when a queue "
1018  "is currently recording commands to a graph.");
1019  }
1020  }
1021 
1026  static bool checkNodeRecursive(const std::shared_ptr<node_impl> &NodeA,
1027  const std::shared_ptr<node_impl> &NodeB) {
1028  size_t FoundCnt = 0;
1029  for (std::weak_ptr<node_impl> &SuccA : NodeA->MSuccessors) {
1030  for (std::weak_ptr<node_impl> &SuccB : NodeB->MSuccessors) {
1031  if (NodeA->isSimilar(NodeB) &&
1032  checkNodeRecursive(SuccA.lock(), SuccB.lock())) {
1033  FoundCnt++;
1034  break;
1035  }
1036  }
1037  }
1038  if (FoundCnt != NodeA->MSuccessors.size()) {
1039  return false;
1040  }
1041 
1042  return true;
1043  }
1044 
1056  bool hasSimilarStructure(std::shared_ptr<detail::graph_impl> Graph,
1057  bool DebugPrint = false) const {
1058  if (this == Graph.get())
1059  return true;
1060 
1061  if (MContext != Graph->MContext) {
1062  if (DebugPrint) {
1064  "MContext are not the same.");
1065  }
1066  return false;
1067  }
1068 
1069  if (MDevice != Graph->MDevice) {
1070  if (DebugPrint) {
1072  "MDevice are not the same.");
1073  }
1074  return false;
1075  }
1076 
1077  if (MEventsMap.size() != Graph->MEventsMap.size()) {
1078  if (DebugPrint) {
1080  "MEventsMap sizes are not the same.");
1081  }
1082  return false;
1083  }
1084 
1085  if (MInorderQueueMap.size() != Graph->MInorderQueueMap.size()) {
1086  if (DebugPrint) {
1088  "MInorderQueueMap sizes are not the same.");
1089  }
1090  return false;
1091  }
1092 
1093  if (MRoots.size() != Graph->MRoots.size()) {
1094  if (DebugPrint) {
1096  "MRoots sizes are not the same.");
1097  }
1098  return false;
1099  }
1100 
1101  size_t RootsFound = 0;
1102  for (std::weak_ptr<node_impl> NodeA : MRoots) {
1103  for (std::weak_ptr<node_impl> NodeB : Graph->MRoots) {
1104  auto NodeALocked = NodeA.lock();
1105  auto NodeBLocked = NodeB.lock();
1106 
1107  if (NodeALocked->isSimilar(NodeBLocked)) {
1108  if (checkNodeRecursive(NodeALocked, NodeBLocked)) {
1109  RootsFound++;
1110  break;
1111  }
1112  }
1113  }
1114  }
1115 
1116  if (RootsFound != MRoots.size()) {
1117  if (DebugPrint) {
1119  "Root Nodes do NOT match.");
1120  }
1121  return false;
1122  }
1123 
1124  return true;
1125  }
1126 
1129  size_t getNumberOfNodes() const { return MNodeStorage.size(); }
1130 
1134  std::vector<sycl::detail::EventImplPtr> getExitNodesEvents();
1135 
1139  std::vector<sycl::detail::EventImplPtr>
1141  std::vector<sycl::detail::EventImplPtr> Events;
1142  for (auto It = MExtraDependencies.begin();
1143  It != MExtraDependencies.end();) {
1144  if ((*It)->MCGType == sycl::detail::CG::Barrier) {
1145  Events.push_back(getEventForNode(*It));
1146  It = MExtraDependencies.erase(It);
1147  } else {
1148  ++It;
1149  }
1150  }
1151  return Events;
1152  }
1153 
1154 private:
// Declaration only (defined out of line). Performs a depth-first search of
// the graph, calling NodeFunc with a node and a deque of nodes; presumably
// the deque is the current traversal path and a `true` result stops the
// search — confirm against the definition.
1161  void
1162  searchDepthFirst(std::function<bool(std::shared_ptr<node_impl> &,
1163  std::deque<std::shared_ptr<node_impl>> &)>
1164  NodeFunc);
1165 
// Declaration only (defined out of line). Presumably returns true when the
// graph contains a cycle; related to the MSkipCycleChecks member — confirm
// against the definition.
1170  bool checkForCycles();
1171 
// Declaration only (defined out of line). Counterpart of removeRoot (which
// removes a node from the list of root nodes); presumably inserts Root into
// MRoots.
1174  void addRoot(const std::shared_ptr<node_impl> &Root);
1175 
// Declaration only (defined out of line). Takes a graph impl and a list of
// nodes and returns a single node; per its name, presumably attaches
// NodeList after the current exit nodes of the graph — confirm against the
// definition.
1180  std::shared_ptr<node_impl>
1181  addNodesToExits(const std::shared_ptr<graph_impl> &Impl,
1182  const std::list<std::shared_ptr<node_impl>> &NodeList);
1183 
1188  void addDepsToNode(std::shared_ptr<node_impl> Node,
1189  const std::vector<std::shared_ptr<node_impl>> &Deps) {
1190  if (!Deps.empty()) {
1191  for (auto &N : Deps) {
1192  N->registerSuccessor(Node, N);
1193  this->removeRoot(Node);
1194  }
1195  } else {
1196  this->addRoot(Node);
1197  }
1198  }
1199 
// Context the graph was created with (returned by getContext()).
1201  sycl::context MContext;
// Device the graph was created with (returned by getDevice()).
1204  sycl::device MDevice;
// Queues currently recording to this graph (see addQueue / removeQueue).
// Held as weak_ptr, so the set does not extend the queues' lifetimes;
// owner_less orders the entries by control block.
1206  std::set<std::weak_ptr<sycl::detail::queue_impl>,
1207  std::owner_less<std::weak_ptr<sycl::detail::queue_impl>>>
1208  MRecordingQueues;
// Associates SYCL events with graph nodes (see addEventForNode,
// getEventForNode and getNodeForEvent).
1210  std::unordered_map<std::shared_ptr<sycl::detail::event_impl>,
1211  std::shared_ptr<node_impl>>
1212  MEventsMap;
// Last node added from each in-order queue (see setLastInorderNode /
// getLastInorderNode). Keys are weak, so queues are not kept alive.
1216  std::map<std::weak_ptr<sycl::detail::queue_impl>, std::shared_ptr<node_impl>,
1217  std::owner_less<std::weak_ptr<sycl::detail::queue_impl>>>
1218  MInorderQueueMap;
// When true, cycle checks are skipped; presumably set from the
// command_graph property that disables cycle checking — confirm.
1221  bool MSkipCycleChecks = false;
// Non-owning pointers to SYCL memory objects; presumably the buffers used by
// nodes of this graph — confirm where it is populated.
1223  std::set<sycl::detail::SYCLMemObjT *> MMemObjs;
1224 
// Presumably whether buffers are allowed in this graph (defaults to false) —
// confirm where it is set.
1227  bool MAllowBuffers = false;
1228 
// Extra dependency nodes; removeBarriersFromExtraDependencies() erases the
// Barrier nodes from this list.
1233  std::list<std::shared_ptr<node_impl>> MExtraDependencies;
1234 };
1235 
1238 public:
// Lock aliases for MMutex: shared (reader) and exclusive (writer) locking.
1239  using ReadLock = std::shared_lock<std::shared_mutex>;
1240  using WriteLock = std::unique_lock<std::shared_mutex>;
1241 
// Protects all the fields that can be changed by class' methods. Declared
// mutable, so it can also be locked from const member functions.
1243  mutable std::shared_mutex MMutex;
1244 
1254  const std::shared_ptr<graph_impl> &GraphImpl,
1255  const property_list &PropList);
1256 
// Destructor; declared here and defined out of line.
1260  ~exec_graph_impl();
1261 
// Partitions the graph nodes and stores the resulting partitions in
// MPartitions (defined out of line, graph_impl.cpp per the symbol index).
1266  void makePartitions();
1267 
1273  sycl::event enqueue(const std::shared_ptr<sycl::detail::queue_impl> &Queue,
1275 
// Turns the internal graph representation of Partition into command-buffers
// for Device (defined out of line).
1281  void createCommandBuffers(sycl::device Device,
1282  std::shared_ptr<partition> &Partition);
1283 
1286  sycl::device getDevice() const { return MDevice; }
1287 
1290  sycl::context getContext() const { return MContext; }
1291 
1294  const std::list<std::shared_ptr<node_impl>> &getSchedule() const {
1295  return MSchedule;
1296  }
1297 
1300  const std::shared_ptr<graph_impl> &getGraphImpl() const { return MGraphImpl; }
1301 
1304  const std::vector<std::shared_ptr<partition>> &getPartitions() const {
1305  return MPartitions;
1306  }
1307 
1314  for (auto Event : MExecutionEvents) {
1315  if (!Event->isCompleted()) {
1316  return false;
1317  }
1318  }
1319  return true;
1320  }
1321 
1323  std::vector<sycl::detail::AccessorImplHost *> getRequirements() const {
1324  return MRequirements;
1325  }
1326 
// Update entry points, declared here and defined out of line. They update
// from a whole modifiable graph, a single node, or a list of nodes.
// updateImpl also takes a single node.
// NOTE(review): the vector overload takes its parameter by const value,
// which copies the vector of shared_ptrs on every call. Switching it to a
// const reference would require changing the out-of-line definition as well.
1327  void update(std::shared_ptr<graph_impl> GraphImpl);
1328  void update(std::shared_ptr<node_impl> Node);
1329  void update(const std::vector<std::shared_ptr<node_impl>> Nodes);
1330 
1331  void updateImpl(std::shared_ptr<node_impl> NodeImpl);
1332 
1333 private:
// Declaration fragment. Listing lines 1341 (presumably the return type) and
// 1343 (a middle parameter line) are missing from this rendering, so the
// full signature is not visible here. It takes the context, the device and
// the node to enqueue, and is defined out of line.
1342  enqueueNode(sycl::context Ctx, sycl::detail::DeviceImplPtr DeviceImpl,
1344  std::shared_ptr<node_impl> Node);
1345 
// Declaration fragment. Listing lines 1353 (presumably the return type) and
// 1355 (a middle parameter line) are missing from this rendering. It has the
// same visible parameters as enqueueNode and is defined out of line.
1354  enqueueNodeDirect(sycl::context Ctx, sycl::detail::DeviceImplPtr DeviceImpl,
1356  std::shared_ptr<node_impl> Node);
1357 
// Declaration only (defined out of line). Deps is taken by non-const
// reference, so the function presumably appends to it the sync points that
// CurrentNode actually depends on, relative to partition
// ReferencePartitionNum — confirm against the definition.
1364  void findRealDeps(std::vector<sycl::detail::pi::PiExtSyncPoint> &Deps,
1365  std::shared_ptr<node_impl> CurrentNode,
1366  int ReferencePartitionNum);
1367 
// Declaration only (defined out of line). Per its name, this presumably
// copies the modifiable graph's nodes into this executable graph — confirm
// against the definition.
1371  void duplicateNodes();
1372 
1377  void printGraphAsDot(const std::string FilePath, bool Verbose) const {
1379  std::vector<node_impl *> VisitedNodes;
1380 
1381  std::fstream Stream(FilePath, std::ios::out);
1382  Stream << "digraph dot {" << std::endl;
1383 
1384  std::vector<std::shared_ptr<node_impl>> Roots;
1385  for (auto &Node : MNodeStorage) {
1386  if (Node->MPredecessors.size() == 0) {
1387  Roots.push_back(Node);
1388  }
1389  }
1390 
1391  for (std::shared_ptr<node_impl> Node : Roots)
1392  Node->printDotRecursive(Stream, VisitedNodes, Verbose);
1393 
1394  Stream << "}" << std::endl;
1395 
1396  Stream.close();
1397  }
1398 
// Execution schedule of the nodes (returned by getSchedule()).
1400  std::list<std::shared_ptr<node_impl>> MSchedule;
// Graph implementation associated with this executable graph (returned by
// getGraphImpl()).
1407  std::shared_ptr<graph_impl> MGraphImpl;
// Map from nodes to sync points. The mapped type sits on listing line 1411,
// which is missing from this rendering; it is presumably
// sycl::detail::pi::PiExtSyncPoint — confirm.
1410  std::unordered_map<std::shared_ptr<node_impl>,
1412  MPiSyncPoints;
// Map from nodes to an int; presumably the partition number of each node —
// confirm.
1415  std::unordered_map<std::shared_ptr<node_impl>, int> MPartitionNodes;
// Device and context returned by getDevice() / getContext().
1417  sycl::device MDevice;
1419  sycl::context MContext;
// Accessor requirements of this graph (copied out by getRequirements()).
// The pointers are non-owning.
1422  std::vector<sycl::detail::AccessorImplHost *> MRequirements;
// Owning accessor implementation pointers; presumably these keep the objects
// behind MRequirements alive — confirm.
1425  std::vector<sycl::detail::AccessorImplPtr> MAccessors;
// Events checked by previousSubmissionCompleted().
1427  std::vector<sycl::detail::EventImplPtr> MExecutionEvents;
// Partitions composing this executable graph (see makePartitions() and
// getPartitions()).
1429  std::vector<std::shared_ptr<partition>> MPartitions;
// Storage for the nodes of this executable graph. printGraphAsDot() walks
// it to find roots.
1431  std::vector<std::shared_ptr<node_impl>> MNodeStorage;
// Map from nodes to command-buffer commands. The mapped type sits on listing
// line 1434, which is missing from this rendering; it is presumably
// sycl::detail::pi::PiExtCommandBufferCommand — confirm.
1433  std::unordered_map<std::shared_ptr<node_impl>,
1435  MCommandMap;
// NOTE(review): this member has no default initializer, so its value relies
// on the out-of-line constructor setting it — confirm.
1437  bool MIsUpdatable;
1438 
1439  // Stores a cache of node ids from modifiable graph nodes to the companion
1440  // node(s) in this graph. Used for quick access when updating this graph.
1441  std::multimap<node_impl::id_type, std::shared_ptr<node_impl>> MIDCache;
1442 };
1443 
1445 public:
1446  dynamic_parameter_impl(std::shared_ptr<graph_impl> GraphImpl,
1447  size_t ParamSize, const void *Data)
1448  : MGraph(GraphImpl), MValueStorage(ParamSize) {
1449  std::memcpy(MValueStorage.data(), Data, ParamSize);
1450  }
1451 
1456  void registerNode(std::shared_ptr<node_impl> NodeImpl, int ArgIndex) {
1457  MNodes.emplace_back(NodeImpl, ArgIndex);
1458  }
1459 
1461  void *getValue() { return MValueStorage.data(); }
1462 
1467  void updateValue(const void *NewValue, size_t Size) {
1468  for (auto &[NodeWeak, ArgIndex] : MNodes) {
1469  auto NodeShared = NodeWeak.lock();
1470  if (NodeShared) {
1471  NodeShared->updateArgValue(ArgIndex, NewValue, Size);
1472  }
1473  }
1474  std::memcpy(MValueStorage.data(), NewValue, Size);
1475  }
1476 
1481  void updateAccessor(const sycl::detail::AccessorBaseHost *Acc) {
1482  for (auto &[NodeWeak, ArgIndex] : MNodes) {
1483  auto NodeShared = NodeWeak.lock();
1484  // Should we fail here if the node isn't alive anymore?
1485  if (NodeShared) {
1486  NodeShared->updateAccessor(ArgIndex, Acc);
1487  }
1488  }
1489  std::memcpy(MValueStorage.data(), Acc,
1490  sizeof(sycl::detail::AccessorBaseHost));
1491  }
1492 
// Each entry pairs a registered node (weak, so it is not kept alive) with
// the argument index passed to its updateArgValue / updateAccessor.
1493  // Weak ptrs to node_impls which will be updated
1494  std::vector<std::pair<std::weak_ptr<node_impl>, int>> MNodes;
1495 
// Graph this dynamic parameter was created for (set by the constructor).
1496  std::shared_ptr<graph_impl> MGraph;
// Raw bytes of the current parameter value. They are sized by the
// constructor's ParamSize and exposed through getValue().
1497  std::vector<std::byte> MValueStorage;
1498 };
1499 
1500 } // namespace detail
1501 } // namespace experimental
1502 } // namespace oneapi
1503 } // namespace ext
1504 } // namespace _V1
1505 } // namespace sycl
The context class represents a SYCL context on which kernel functions may be executed.
Definition: context.hpp:50
CGTYPE
Type of the command group.
Definition: cg.hpp:56
The SYCL device class encapsulates a single SYCL device on which kernels may be executed.
Definition: device.hpp:64
backend get_backend() const noexcept
Returns the backend associated with this device.
Definition: device.cpp:215
bool has(aspect Aspect) const __SYCL_WARN_IMAGE_ASPECT(Aspect)
Indicates if the SYCL device has the given feature.
Definition: device.cpp:219
An event object can be used to synchronize memory transfers, enqueues of kernels and signaling barrie...
Definition: event.hpp:44
dynamic_parameter_impl(std::shared_ptr< graph_impl > GraphImpl, size_t ParamSize, const void *Data)
std::vector< std::pair< std::weak_ptr< node_impl >, int > > MNodes
void * getValue()
Get a pointer to the internal value of this dynamic parameter.
void registerNode(std::shared_ptr< node_impl > NodeImpl, int ArgIndex)
Register a node with this dynamic parameter.
void updateValue(const void *NewValue, size_t Size)
Update the internal value of this dynamic parameter as well as the value of this parameter in all reg...
void updateAccessor(const sycl::detail::AccessorBaseHost *Acc)
Update the internal value of this dynamic parameter as well as the value of this parameter in all reg...
Class representing the implementation of command_graph<executable>.
exec_graph_impl(sycl::context Context, const std::shared_ptr< graph_impl > &GraphImpl, const property_list &PropList)
Constructor.
Definition: graph_impl.cpp:757
void createCommandBuffers(sycl::device Device, std::shared_ptr< partition > &Partition)
Turns the internal graph representation into UR command-buffers for a device.
Definition: graph_impl.cpp:696
std::vector< sycl::detail::AccessorImplHost * > getRequirements() const
Returns a list of all the accessor requirements for this graph.
const std::vector< std::shared_ptr< partition > > & getPartitions() const
Query the vector of the partitions composing the exec_graph.
void update(std::shared_ptr< graph_impl > GraphImpl)
sycl::context getContext() const
Query for the context tied to this graph.
void makePartitions()
Partition the graph nodes and put the partition in MPartitions.
Definition: graph_impl.cpp:183
void updateImpl(std::shared_ptr< node_impl > NodeImpl)
sycl::device getDevice() const
Query for the device tied to this graph.
const std::list< std::shared_ptr< node_impl > > & getSchedule() const
Query the scheduling of node execution.
const std::shared_ptr< graph_impl > & getGraphImpl() const
Query the graph_impl.
std::shared_mutex MMutex
Protects all the fields that can be changed by class' methods.
bool previousSubmissionCompleted() const
Checks if the previous submissions of this graph have been completed. This function checks the status ...
sycl::event enqueue(const std::shared_ptr< sycl::detail::queue_impl > &Queue, sycl::detail::CG::StorageInitHelper CGData)
Called by handler::ext_oneapi_command_graph() to schedule graph for execution.
Definition: graph_impl.cpp:812
Implementation details of command_graph<modifiable>.
Definition: graph_impl.hpp:794
std::vector< std::shared_ptr< node_impl > > MNodeStorage
Storage for all nodes contained within a graph.
Definition: graph_impl.hpp:959
void addQueue(const std::shared_ptr< sycl::detail::queue_impl > &RecordingQueue)
Add a queue to the set of queues which are currently recording to this graph.
Definition: graph_impl.hpp:876
void removeRoot(const std::shared_ptr< node_impl > &Root)
Remove node from list of root nodes.
Definition: graph_impl.cpp:343
std::unique_lock< std::shared_mutex > WriteLock
Definition: graph_impl.hpp:797
void makeEdge(std::shared_ptr< node_impl > Src, std::shared_ptr< node_impl > Dest)
Make an edge between two nodes in the graph.
Definition: graph_impl.cpp:560
std::shared_mutex MMutex
Protects all the fields that can be changed by class' methods.
Definition: graph_impl.hpp:800
void printGraphAsDot(const std::string FilePath, bool Verbose) const
Prints the contents of the graph to a text file in DOT format.
Definition: graph_impl.hpp:987
std::vector< sycl::detail::EventImplPtr > getExitNodesEvents()
Traverse the graph recursively to get the events associated with the output nodes of this graph.
Definition: graph_impl.cpp:611
void throwIfGraphRecordingQueue(const std::string ExceptionMsg) const
Throws an invalid exception if this function is called while a queue is recording commands to the gra...
std::shared_ptr< sycl::detail::event_impl > getEventForNode(std::shared_ptr< node_impl > NodeImpl) const
Find the sycl event associated with a node.
Definition: graph_impl.hpp:911
void setLastInorderNode(std::shared_ptr< sycl::detail::queue_impl > Queue, std::shared_ptr< node_impl > Node)
Track the last node added to this graph from an in-order queue.
Definition: graph_impl.hpp:977
std::set< std::weak_ptr< node_impl >, std::owner_less< std::weak_ptr< node_impl > > > MRoots
List of root nodes.
Definition: graph_impl.hpp:953
bool hasSimilarStructure(std::shared_ptr< detail::graph_impl > Graph, bool DebugPrint=false) const
Checks if the graph_impl of Graph has a similar structure to the graph_impl of the caller.
std::vector< sycl::detail::EventImplPtr > removeBarriersFromExtraDependencies()
Removes all Barrier nodes from the list of extra dependencies MExtraDependencies.
bool clearQueues()
Remove all queues which are recording to this graph; also sets all cleared queues back to the executi...
Definition: graph_impl.cpp:506
sycl::device getDevice() const
Query for the device tied to this graph.
Definition: graph_impl.hpp:949
std::shared_ptr< node_impl > getNodeForEvent(std::shared_ptr< sycl::detail::event_impl > EventImpl)
Find the node associated with a SYCL event.
Definition: graph_impl.hpp:930
static bool checkNodeRecursive(const std::shared_ptr< node_impl > &NodeA, const std::shared_ptr< node_impl > &NodeB)
Recursively check successors of NodeA and NodeB to check they are similar.
graph_impl(const sycl::context &SyclContext, const sycl::device &SyclDevice, const sycl::property_list &PropList={})
Constructor.
Definition: graph_impl.hpp:806
sycl::context getContext() const
Query for the context tied to this graph.
Definition: graph_impl.hpp:945
std::shared_ptr< node_impl > add(node_type NodeType, std::unique_ptr< sycl::detail::CG > CommandGroup, const std::vector< std::shared_ptr< node_impl >> &Dep={})
Create a kernel node in the graph.
Definition: graph_impl.cpp:433
void addEventForNode(std::shared_ptr< graph_impl > GraphImpl, std::shared_ptr< sycl::detail::event_impl > EventImpl, std::shared_ptr< node_impl > NodeImpl)
Associate a sycl event with a node in the graph.
Definition: graph_impl.hpp:899
std::shared_lock< std::shared_mutex > ReadLock
Definition: graph_impl.hpp:796
size_t getNumberOfNodes() const
Returns the number of nodes in the Graph.
std::shared_ptr< node_impl > getLastInorderNode(std::shared_ptr< sycl::detail::queue_impl > Queue)
Find the last node added to this graph from an in-order queue.
Definition: graph_impl.hpp:966
void removeQueue(const std::shared_ptr< sycl::detail::queue_impl > &RecordingQueue)
Remove a queue from the set of queues which are currently recording to this graph.
Definition: graph_impl.hpp:884
Implementation of node class from SYCL_EXT_ONEAPI_GRAPH.
Definition: graph_impl.hpp:77
void updateAccessor(int ArgIndex, const sycl::detail::AccessorBaseHost *Acc)
Update the value of an accessor inside this node.
Definition: graph_impl.hpp:390
void registerSuccessor(const std::shared_ptr< node_impl > &Node, const std::shared_ptr< node_impl > &Prev)
Add successor to the node.
Definition: graph_impl.hpp:116
std::vector< std::weak_ptr< node_impl > > MPredecessors
List of predecessors to this node.
Definition: graph_impl.hpp:88
node_impl(node_impl &Other)
Construct a node from another node.
Definition: graph_impl.hpp:160
std::vector< std::weak_ptr< node_impl > > MSuccessors
List of successors to this node.
Definition: graph_impl.hpp:84
void updateFromOtherNode(const std::shared_ptr< node_impl > &Other)
Definition: graph_impl.hpp:508
sycl::detail::CG::CGTYPE MCGType
Type of the command-group for the node.
Definition: graph_impl.hpp:90
std::shared_ptr< exec_graph_impl > MSubGraphImpl
Stores the executable graph impl associated with this node if it is a subgraph node.
Definition: graph_impl.hpp:97
node_impl(node_type NodeType, std::unique_ptr< sycl::detail::CG > &&CommandGroup)
Construct a node representing a command-group.
Definition: graph_impl.hpp:147
bool MNDRangeUsed
Track whether an ND-Range was used for kernel nodes.
Definition: graph_impl.hpp:108
bool hasRequirementDependency(sycl::detail::AccessorImplHost *IncomingReq)
Checks if this node should be a dependency of another node based on accessor requirements.
Definition: graph_impl.hpp:183
int MPartitionNum
Partition number needed to assign a Node to a partition.
Definition: graph_impl.hpp:105
void updateRange(range< Dimensions > ExecutionRange)
Definition: graph_impl.hpp:482
void updateNDRange(nd_range< Dimensions > ExecutionRange)
Definition: graph_impl.hpp:456
bool isSimilar(const std::shared_ptr< node_impl > &Node, bool CompareContentOnly=false) const
Tests if the caller is similar to Node; this is only used for testing.
Definition: graph_impl.hpp:314
void printDotRecursive(std::fstream &Stream, std::vector< node_impl * > &Visited, bool Verbose)
Recursive Depth first traversal of linked nodes.
Definition: graph_impl.hpp:365
node_type MNodeType
User facing type of the node.
Definition: graph_impl.hpp:92
std::unique_ptr< sycl::detail::CG > getCGCopy() const
Get a deep copy of this node's command group.
Definition: graph_impl.hpp:224
void updateArgValue(int ArgIndex, const void *NewValue, size_t Size)
Definition: graph_impl.hpp:439
bool isEmpty() const
Query if this is an empty node.
Definition: graph_impl.hpp:217
void registerPredecessor(const std::shared_ptr< node_impl > &Node)
Add predecessor to the node.
Definition: graph_impl.hpp:130
bool MVisited
Used for tracking visited status during cycle checks.
Definition: graph_impl.hpp:100
std::unique_ptr< sycl::detail::CG > MCommandGroup
Command group object which stores all args etc needed to enqueue the node.
Definition: graph_impl.hpp:94
id_type MID
Unique identifier for this node.
Definition: graph_impl.hpp:82
node_impl & operator=(node_impl &Other)
Copy-assignment operator.
Definition: graph_impl.hpp:167
std::unordered_map< sycl::device, sycl::detail::pi::PiExtCommandBuffer > MPiCommandBuffers
Map of devices to command buffers.
Definition: graph_impl.hpp:779
std::set< std::weak_ptr< node_impl >, std::owner_less< std::weak_ptr< node_impl > > > MRoots
List of root nodes.
Definition: graph_impl.hpp:774
std::vector< std::shared_ptr< partition > > MPredecessors
List of predecessors to this partition.
Definition: graph_impl.hpp:781
std::list< std::shared_ptr< node_impl > > MSchedule
Execution schedule of nodes in the graph.
Definition: graph_impl.hpp:776
Property passed to command_graph constructor to disable checking for cycles.
Definition: graph.hpp:147
Command group handler class.
Definition: handler.hpp:458
Defines the iteration domain of both the work-groups and the overall dispatch.
Definition: nd_range.hpp:22
Objects of the property_list class are containers for the SYCL properties.
::pi_ext_sync_point PiExtSyncPoint
Definition: pi.hpp:156
::pi_ext_command_buffer_command PiExtCommandBufferCommand
Definition: pi.hpp:159
decltype(Obj::impl) getSyclObjImpl(const Obj &SyclObject)
Definition: impl_utils.hpp:30
AccessorImplHost Requirement
std::shared_ptr< device_impl > DeviceImplPtr
node_type getNodeTypeFromCG(sycl::detail::CG::CGTYPE CGType)
Definition: graph_impl.hpp:41
class __SYCL_EBO __SYCL_SPECIAL_CLASS Dimensions
constexpr mode_tag_t< access_mode::read_write > read_write
Definition: access.hpp:85
std::error_code make_error_code(sycl::errc E) noexcept
Constructs an error code using E and sycl_category()
Definition: exception.cpp:91
Definition: access.hpp:18
std::vector< std::vector< char > > MArgsStorage
Storage for standard layout arguments.
Definition: cg.hpp:99