DPC++ Runtime
Runtime libraries for oneAPI DPC++
graph_impl.hpp
Go to the documentation of this file.
1 //==--------- graph_impl.hpp --- SYCL graph extension ---------------------==//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #pragma once
10 
11 #include <sycl/detail/cg_types.hpp>
12 #include <sycl/detail/os_util.hpp>
14 #include <sycl/handler.hpp>
15 
16 #include <detail/accessor_impl.hpp>
17 #include <detail/cg.hpp>
18 #include <detail/event_impl.hpp>
19 #include <detail/host_task.hpp>
20 #include <detail/kernel_impl.hpp>
22 
#include <atomic>
23 #include <cstring>
24 #include <deque>
25 #include <fstream>
26 #include <functional>
27 #include <iomanip>
28 #include <list>
29 #include <set>
30 #include <shared_mutex>
31 
32 namespace sycl {
33 inline namespace _V1 {
34 
35 namespace detail {
36 class SYCLMemObjT;
37 }
38 
39 namespace ext {
40 namespace oneapi {
41 namespace experimental {
42 namespace detail {
43 
45  using sycl::detail::CG;
46 
47  switch (CGType) {
48  case sycl::detail::CGType::None:
49  return node_type::empty;
50  case sycl::detail::CGType::Kernel:
51  return node_type::kernel;
52  case sycl::detail::CGType::CopyAccToPtr:
53  case sycl::detail::CGType::CopyPtrToAcc:
54  case sycl::detail::CGType::CopyAccToAcc:
55  case sycl::detail::CGType::CopyUSM:
56  return node_type::memcpy;
57  case sycl::detail::CGType::Memset2DUSM:
58  return node_type::memset;
59  case sycl::detail::CGType::Fill:
60  case sycl::detail::CGType::FillUSM:
61  return node_type::memfill;
62  case sycl::detail::CGType::PrefetchUSM:
63  return node_type::prefetch;
64  case sycl::detail::CGType::AdviseUSM:
65  return node_type::memadvise;
66  case sycl::detail::CGType::Barrier:
67  case sycl::detail::CGType::BarrierWaitlist:
69  case sycl::detail::CGType::CodeplayHostTask:
70  return node_type::host_task;
71  case sycl::detail::CGType::ExecCommandBuffer:
72  return node_type::subgraph;
73  default:
74  assert(false && "Invalid Graph Node Type");
75  return node_type::empty;
76  }
77 }
78 
80 class node_impl {
81 public:
82  using id_type = uint64_t;
83 
85  id_type MID = getNextNodeID();
87  std::vector<std::weak_ptr<node_impl>> MSuccessors;
91  std::vector<std::weak_ptr<node_impl>> MPredecessors;
93  sycl::detail::CGType MCGType = sycl::detail::CGType::None;
97  std::unique_ptr<sycl::detail::CG> MCommandGroup;
100  std::shared_ptr<exec_graph_impl> MSubGraphImpl;
101 
103  bool MVisited = false;
104 
108  int MPartitionNum = -1;
109 
111  bool MNDRangeUsed = false;
112 
119  void registerSuccessor(const std::shared_ptr<node_impl> &Node,
120  const std::shared_ptr<node_impl> &Prev) {
121  if (std::find_if(MSuccessors.begin(), MSuccessors.end(),
122  [Node](const std::weak_ptr<node_impl> &Ptr) {
123  return Ptr.lock() == Node;
124  }) != MSuccessors.end()) {
125  return;
126  }
127  MSuccessors.push_back(Node);
128  Node->registerPredecessor(Prev);
129  }
130 
133  void registerPredecessor(const std::shared_ptr<node_impl> &Node) {
134  if (std::find_if(MPredecessors.begin(), MPredecessors.end(),
135  [&Node](const std::weak_ptr<node_impl> &Ptr) {
136  return Ptr.lock() == Node;
137  }) != MPredecessors.end()) {
138  return;
139  }
140  MPredecessors.push_back(Node);
141  }
142 
145 
151  std::unique_ptr<sycl::detail::CG> &&CommandGroup)
152  : MCGType(CommandGroup->getType()), MNodeType(NodeType),
153  MCommandGroup(std::move(CommandGroup)) {
154  if (NodeType == node_type::subgraph) {
155  MSubGraphImpl =
156  static_cast<sycl::detail::CGExecCommandBuffer *>(MCommandGroup.get())
157  ->MExecGraph;
158  }
159  }
160 
165  MCGType(Other.MCGType), MNodeType(Other.MNodeType),
167 
171  if (this != &Other) {
172  MSuccessors = Other.MSuccessors;
174  MCGType = Other.MCGType;
175  MNodeType = Other.MNodeType;
176  MCommandGroup = Other.getCGCopy();
178  }
179  return *this;
180  }
186  bool hasRequirementDependency(sycl::detail::AccessorImplHost *IncomingReq) {
187  if (!MCommandGroup)
188  return false;
189 
190  access_mode InMode = IncomingReq->MAccessMode;
191  switch (InMode) {
192  case access_mode::read:
194  case access_mode::atomic:
195  break;
196  // These access modes don't care about existing buffer data, so we don't
197  // need a dependency.
198  case access_mode::write:
199  case access_mode::discard_read_write:
200  case access_mode::discard_write:
201  return false;
202  }
203 
204  for (sycl::detail::AccessorImplHost *CurrentReq :
205  MCommandGroup->getRequirements()) {
206  if (IncomingReq->MSYCLMemObj == CurrentReq->MSYCLMemObj) {
207  access_mode CurrentMode = CurrentReq->MAccessMode;
208  // Since we have an incoming read requirement, we only care
209  // about requirements on this node if they are write
210  if (CurrentMode != access_mode::read) {
211  return true;
212  }
213  }
214  }
215  // No dependency necessary
216  return false;
217  }
218 
223  bool isEmpty() const {
224  return ((MCGType == sycl::detail::CGType::None) ||
225  (MCGType == sycl::detail::CGType::Barrier));
226  }
227 
230  std::unique_ptr<sycl::detail::CG> getCGCopy() const {
231  switch (MCGType) {
232  case sycl::detail::CGType::Kernel: {
233  auto CGCopy = createCGCopy<sycl::detail::CGExecKernel>();
234  rebuildArgStorage(CGCopy->MArgs, MCommandGroup->getArgsStorage(),
235  CGCopy->getArgsStorage());
236  return std::move(CGCopy);
237  }
238  case sycl::detail::CGType::CopyAccToPtr:
239  case sycl::detail::CGType::CopyPtrToAcc:
240  case sycl::detail::CGType::CopyAccToAcc:
241  return createCGCopy<sycl::detail::CGCopy>();
242  case sycl::detail::CGType::Fill:
243  return createCGCopy<sycl::detail::CGFill>();
244  case sycl::detail::CGType::UpdateHost:
245  return createCGCopy<sycl::detail::CGUpdateHost>();
246  case sycl::detail::CGType::CopyUSM:
247  return createCGCopy<sycl::detail::CGCopyUSM>();
248  case sycl::detail::CGType::FillUSM:
249  return createCGCopy<sycl::detail::CGFillUSM>();
250  case sycl::detail::CGType::PrefetchUSM:
251  return createCGCopy<sycl::detail::CGPrefetchUSM>();
252  case sycl::detail::CGType::AdviseUSM:
253  return createCGCopy<sycl::detail::CGAdviseUSM>();
254  case sycl::detail::CGType::Copy2DUSM:
255  return createCGCopy<sycl::detail::CGCopy2DUSM>();
256  case sycl::detail::CGType::Fill2DUSM:
257  return createCGCopy<sycl::detail::CGFill2DUSM>();
258  case sycl::detail::CGType::Memset2DUSM:
259  return createCGCopy<sycl::detail::CGMemset2DUSM>();
260  case sycl::detail::CGType::EnqueueNativeCommand:
261  case sycl::detail::CGType::CodeplayHostTask: {
262  // The unique_ptr to the `sycl::detail::HostTask`, which is also used for
263  // a EnqueueNativeCommand command, in the HostTask CG prevents from
264  // copying the CG. We overcome this restriction by creating a new CG with
265  // the same data.
266  auto CommandGroupPtr =
267  static_cast<sycl::detail::CGHostTask *>(MCommandGroup.get());
268  sycl::detail::HostTask HostTask = *CommandGroupPtr->MHostTask.get();
269  auto HostTaskSPtr = std::make_shared<sycl::detail::HostTask>(HostTask);
270 
272  CommandGroupPtr->getArgsStorage(), CommandGroupPtr->getAccStorage(),
273  CommandGroupPtr->getSharedPtrStorage(),
274  CommandGroupPtr->getRequirements(), CommandGroupPtr->getEvents());
275 
276  std::vector<sycl::detail::ArgDesc> NewArgs = CommandGroupPtr->MArgs;
277 
278  rebuildArgStorage(NewArgs, CommandGroupPtr->getArgsStorage(),
279  Data.MArgsStorage);
280 
281  sycl::detail::code_location Loc(CommandGroupPtr->MFileName.data(),
282  CommandGroupPtr->MFunctionName.data(),
283  CommandGroupPtr->MLine,
284  CommandGroupPtr->MColumn);
285 
286  return std::make_unique<sycl::detail::CGHostTask>(
287  sycl::detail::CGHostTask(
288  std::move(HostTaskSPtr), CommandGroupPtr->MQueue,
289  CommandGroupPtr->MContext, std::move(NewArgs), std::move(Data),
290  CommandGroupPtr->getType(), Loc));
291  }
292  case sycl::detail::CGType::Barrier:
293  case sycl::detail::CGType::BarrierWaitlist:
294  // Barrier nodes are stored in the graph with only the base CG class,
295  // since they are treated internally as empty nodes.
296  return createCGCopy<sycl::detail::CG>();
297  case sycl::detail::CGType::CopyToDeviceGlobal:
298  return createCGCopy<sycl::detail::CGCopyToDeviceGlobal>();
299  case sycl::detail::CGType::CopyFromDeviceGlobal:
300  return createCGCopy<sycl::detail::CGCopyFromDeviceGlobal>();
301  case sycl::detail::CGType::ReadWriteHostPipe:
302  return createCGCopy<sycl::detail::CGReadWriteHostPipe>();
303  case sycl::detail::CGType::CopyImage:
304  return createCGCopy<sycl::detail::CGCopyImage>();
305  case sycl::detail::CGType::SemaphoreSignal:
306  return createCGCopy<sycl::detail::CGSemaphoreSignal>();
307  case sycl::detail::CGType::SemaphoreWait:
308  return createCGCopy<sycl::detail::CGSemaphoreWait>();
309  case sycl::detail::CGType::ProfilingTag:
310  return createCGCopy<sycl::detail::CGProfilingTag>();
311  case sycl::detail::CGType::ExecCommandBuffer:
312  return createCGCopy<sycl::detail::CGExecCommandBuffer>();
313  case sycl::detail::CGType::None:
314  return nullptr;
315  }
316  return nullptr;
317  }
318 
324  bool isSimilar(const std::shared_ptr<node_impl> &Node,
325  bool CompareContentOnly = false) const {
326  if (!CompareContentOnly) {
327  if (MSuccessors.size() != Node->MSuccessors.size())
328  return false;
329 
330  if (MPredecessors.size() != Node->MPredecessors.size())
331  return false;
332  }
333  if (MCGType != Node->MCGType)
334  return false;
335 
336  switch (MCGType) {
337  case sycl::detail::CGType::Kernel: {
338  sycl::detail::CGExecKernel *ExecKernelA =
339  static_cast<sycl::detail::CGExecKernel *>(MCommandGroup.get());
340  sycl::detail::CGExecKernel *ExecKernelB =
341  static_cast<sycl::detail::CGExecKernel *>(Node->MCommandGroup.get());
342  return ExecKernelA->MKernelName.compare(ExecKernelB->MKernelName) == 0;
343  }
344  case sycl::detail::CGType::CopyUSM: {
345  sycl::detail::CGCopyUSM *CopyA =
346  static_cast<sycl::detail::CGCopyUSM *>(MCommandGroup.get());
347  sycl::detail::CGCopyUSM *CopyB =
348  static_cast<sycl::detail::CGCopyUSM *>(Node->MCommandGroup.get());
349  return (CopyA->getSrc() == CopyB->getSrc()) &&
350  (CopyA->getDst() == CopyB->getDst()) &&
351  (CopyA->getLength() == CopyB->getLength());
352  }
353  case sycl::detail::CGType::CopyAccToAcc:
354  case sycl::detail::CGType::CopyAccToPtr:
355  case sycl::detail::CGType::CopyPtrToAcc: {
356  sycl::detail::CGCopy *CopyA =
357  static_cast<sycl::detail::CGCopy *>(MCommandGroup.get());
358  sycl::detail::CGCopy *CopyB =
359  static_cast<sycl::detail::CGCopy *>(Node->MCommandGroup.get());
360  return (CopyA->getSrc() == CopyB->getSrc()) &&
361  (CopyA->getDst() == CopyB->getDst());
362  }
363  default:
364  assert(false && "Unexpected command group type!");
365  return false;
366  }
367  }
368 
375  void printDotRecursive(std::fstream &Stream,
376  std::vector<node_impl *> &Visited, bool Verbose) {
377  // if Node has been already visited, we skip it
378  if (std::find(Visited.begin(), Visited.end(), this) != Visited.end())
379  return;
380 
381  Visited.push_back(this);
382 
383  printDotCG(Stream, Verbose);
384  for (const auto &Dep : MPredecessors) {
385  auto NodeDep = Dep.lock();
386  Stream << " \"" << NodeDep.get() << "\" -> \"" << this << "\""
387  << std::endl;
388  }
389 
390  for (std::weak_ptr<node_impl> Succ : MSuccessors) {
391  if (MPartitionNum == Succ.lock()->MPartitionNum)
392  Succ.lock()->printDotRecursive(Stream, Visited, Verbose);
393  }
394  }
395 
398  bool isNDCopyNode() const {
399  if ((MCGType != sycl::detail::CGType::CopyAccToAcc) &&
400  (MCGType != sycl::detail::CGType::CopyAccToPtr) &&
401  (MCGType != sycl::detail::CGType::CopyPtrToAcc)) {
402  return false;
403  }
404 
405  auto Copy = static_cast<sycl::detail::CGCopy *>(MCommandGroup.get());
406  auto ReqSrc = static_cast<sycl::detail::Requirement *>(Copy->getSrc());
407  auto ReqDst = static_cast<sycl::detail::Requirement *>(Copy->getDst());
408  return (ReqSrc->MDims > 1) || (ReqDst->MDims > 1);
409  }
410 
415  void updateAccessor(int ArgIndex, const sycl::detail::AccessorBaseHost *Acc) {
416  auto &Args =
417  static_cast<sycl::detail::CGExecKernel *>(MCommandGroup.get())->MArgs;
418  auto NewAccImpl = sycl::detail::getSyclObjImpl(*Acc);
419  for (auto &Arg : Args) {
420  if (Arg.MIndex != ArgIndex) {
421  continue;
422  }
423  assert(Arg.MType == sycl::detail::kernel_param_kind_t::kind_accessor);
424 
425  // Find old accessor in accessor storage and replace with new one
426  if (static_cast<sycl::detail::SYCLMemObjT *>(NewAccImpl->MSYCLMemObj)
427  ->needsWriteBack()) {
428  throw sycl::exception(
430  "Accessors to buffers which have write_back enabled "
431  "are not allowed to be used in command graphs.");
432  }
433 
434  // All accessors passed to this function will be placeholders, so we must
435  // perform steps similar to what happens when handler::require() is
436  // called here.
437  sycl::detail::Requirement *NewReq = NewAccImpl.get();
438  if (NewReq->MAccessMode != sycl::access_mode::read) {
439  auto SYCLMemObj =
440  static_cast<sycl::detail::SYCLMemObjT *>(NewReq->MSYCLMemObj);
441  SYCLMemObj->handleWriteAccessorCreation();
442  }
443 
444  for (auto &Acc : MCommandGroup->getAccStorage()) {
445  if (auto OldAcc =
446  static_cast<sycl::detail::AccessorImplHost *>(Arg.MPtr);
447  Acc.get() == OldAcc) {
448  Acc = NewAccImpl;
449  }
450  }
451 
452  for (auto &Req : MCommandGroup->getRequirements()) {
453  if (auto OldReq =
454  static_cast<sycl::detail::AccessorImplHost *>(Arg.MPtr);
455  Req == OldReq) {
456  Req = NewReq;
457  }
458  }
459  Arg.MPtr = NewAccImpl.get();
460  break;
461  }
462  }
463 
464  void updateArgValue(int ArgIndex, const void *NewValue, size_t Size) {
465 
466  auto &Args =
467  static_cast<sycl::detail::CGExecKernel *>(MCommandGroup.get())->MArgs;
468  for (auto &Arg : Args) {
469  if (Arg.MIndex != ArgIndex) {
470  continue;
471  }
472  assert(Arg.MSize == static_cast<int>(Size));
473  // MPtr may be a pointer into arg storage so we memcpy the contents of
474  // NewValue rather than assign it directly
475  std::memcpy(Arg.MPtr, NewValue, Size);
476  break;
477  }
478  }
479 
480  template <int Dimensions>
481  void updateNDRange(nd_range<Dimensions> ExecutionRange) {
482  if (MCGType != sycl::detail::CGType::Kernel) {
483  throw sycl::exception(
484  sycl::errc::invalid,
485  "Cannot update execution range of nodes which are not kernel nodes");
486  }
487  if (!MNDRangeUsed) {
488  throw sycl::exception(sycl::errc::invalid,
489  "Cannot update node which was created with a "
490  "sycl::range with a sycl::nd_range");
491  }
492 
493  auto &NDRDesc =
494  static_cast<sycl::detail::CGExecKernel *>(MCommandGroup.get())
495  ->MNDRDesc;
496 
497  if (NDRDesc.Dims != Dimensions) {
498  throw sycl::exception(sycl::errc::invalid,
499  "Cannot update execution range of a node with an "
500  "execution range of different dimensions than what "
501  "the node was originall created with.");
502  }
503 
504  NDRDesc = sycl::detail::NDRDescT{ExecutionRange};
505  }
506 
507  template <int Dimensions> void updateRange(range<Dimensions> ExecutionRange) {
508  if (MCGType != sycl::detail::CGType::Kernel) {
509  throw sycl::exception(
510  sycl::errc::invalid,
511  "Cannot update execution range of nodes which are not kernel nodes");
512  }
513  if (MNDRangeUsed) {
514  throw sycl::exception(sycl::errc::invalid,
515  "Cannot update node which was created with a "
516  "sycl::nd_range with a sycl::range");
517  }
518 
519  auto &NDRDesc =
520  static_cast<sycl::detail::CGExecKernel *>(MCommandGroup.get())
521  ->MNDRDesc;
522 
523  if (NDRDesc.Dims != Dimensions) {
524  throw sycl::exception(sycl::errc::invalid,
525  "Cannot update execution range of a node with an "
526  "execution range of different dimensions than what "
527  "the node was originall created with.");
528  }
529 
530  NDRDesc = sycl::detail::NDRDescT{ExecutionRange};
531  }
532 
533  void updateFromOtherNode(const std::shared_ptr<node_impl> &Other) {
534  auto ExecCG =
535  static_cast<sycl::detail::CGExecKernel *>(MCommandGroup.get());
536  auto OtherExecCG =
537  static_cast<sycl::detail::CGExecKernel *>(Other->MCommandGroup.get());
538 
539  ExecCG->MArgs = OtherExecCG->MArgs;
540  ExecCG->MNDRDesc = OtherExecCG->MNDRDesc;
541  ExecCG->getAccStorage() = OtherExecCG->getAccStorage();
542  ExecCG->getRequirements() = OtherExecCG->getRequirements();
543 
544  auto &OldArgStorage = OtherExecCG->getArgsStorage();
545  auto &NewArgStorage = ExecCG->getArgsStorage();
546  // Rebuild the arg storage and update the args
547  rebuildArgStorage(ExecCG->MArgs, OldArgStorage, NewArgStorage);
548  }
549 
550  id_type getID() const { return MID; }
551 
552 private:
553  void rebuildArgStorage(std::vector<sycl::detail::ArgDesc> &Args,
554  const std::vector<std::vector<char>> &OldArgStorage,
555  std::vector<std::vector<char>> &NewArgStorage) const {
556  // Clear the arg storage so we can rebuild it
557  NewArgStorage.clear();
558 
559  // Loop over all the args, any std_layout ones need their pointers updated
560  // to point to the new arg storage.
561  for (auto &Arg : Args) {
562  if (Arg.MType != sycl::detail::kernel_param_kind_t::kind_std_layout) {
563  continue;
564  }
565  // Find which ArgStorage Arg.MPtr is pointing to
566  for (auto &ArgStorage : OldArgStorage) {
567  if (ArgStorage.data() != Arg.MPtr) {
568  continue;
569  }
570  NewArgStorage.emplace_back(Arg.MSize);
571  // Memcpy contents from old storage to new storage
572  std::memcpy(NewArgStorage.back().data(), ArgStorage.data(), Arg.MSize);
573  // Update MPtr to point to the new storage instead of the old
574  Arg.MPtr = NewArgStorage.back().data();
575 
576  break;
577  }
578  }
579  }
580  // Gets the next unique identifier for a node, should only be used when
581  // constructing nodes.
582  static id_type getNextNodeID() {
583  static id_type nextID = 0;
584 
585  // Return the value then increment the next ID
586  return nextID++;
587  }
588 
593  void printDotCG(std::ostream &Stream, bool Verbose) {
594  Stream << "\"" << this << "\" [style=bold, label=\"";
595 
596  Stream << "ID = " << this << "\\n";
597  Stream << "TYPE = ";
598 
599  switch (MCGType) {
600  case sycl::detail::CGType::None:
601  Stream << "None \\n";
602  break;
603  case sycl::detail::CGType::Kernel: {
604  Stream << "CGExecKernel \\n";
605  sycl::detail::CGExecKernel *Kernel =
606  static_cast<sycl::detail::CGExecKernel *>(MCommandGroup.get());
607  Stream << "NAME = " << Kernel->MKernelName << "\\n";
608  if (Verbose) {
609  Stream << "ARGS = \\n";
610  for (size_t i = 0; i < Kernel->MArgs.size(); i++) {
611  auto Arg = Kernel->MArgs[i];
612  std::string Type = "Undefined";
613  if (Arg.MType == sycl::detail::kernel_param_kind_t::kind_accessor) {
614  Type = "Accessor";
615  } else if (Arg.MType ==
616  sycl::detail::kernel_param_kind_t::kind_std_layout) {
617  Type = "STD_Layout";
618  } else if (Arg.MType ==
619  sycl::detail::kernel_param_kind_t::kind_sampler) {
620  Type = "Sampler";
621  } else if (Arg.MType ==
622  sycl::detail::kernel_param_kind_t::kind_pointer) {
623  Type = "Pointer";
624  auto Fill = Stream.fill();
625  Stream << i << ") Type: " << Type << " Ptr: " << Arg.MPtr << "(0x"
626  << std::hex << std::setfill('0');
627  for (int i = Arg.MSize - 1; i >= 0; --i) {
628  Stream << std::setw(2)
629  << static_cast<int16_t>(
630  (static_cast<unsigned char *>(Arg.MPtr))[i]);
631  }
632  Stream.fill(Fill);
633  Stream << std::dec << ")\\n";
634  continue;
635  } else if (Arg.MType == sycl::detail::kernel_param_kind_t::
636  kind_specialization_constants_buffer) {
637  Type = "Specialization Constants Buffer";
638  } else if (Arg.MType ==
639  sycl::detail::kernel_param_kind_t::kind_stream) {
640  Type = "Stream";
641  } else if (Arg.MType ==
642  sycl::detail::kernel_param_kind_t::kind_invalid) {
643  Type = "Invalid";
644  }
645  Stream << i << ") Type: " << Type << " Ptr: " << Arg.MPtr << "\\n";
646  }
647  }
648  break;
649  }
650  case sycl::detail::CGType::CopyAccToPtr:
651  Stream << "CGCopy Device-to-Host \\n";
652  if (Verbose) {
653  sycl::detail::CGCopy *Copy =
654  static_cast<sycl::detail::CGCopy *>(MCommandGroup.get());
655  Stream << "Src: " << Copy->getSrc() << " Dst: " << Copy->getDst()
656  << "\\n";
657  }
658  break;
659  case sycl::detail::CGType::CopyPtrToAcc:
660  Stream << "CGCopy Host-to-Device \\n";
661  if (Verbose) {
662  sycl::detail::CGCopy *Copy =
663  static_cast<sycl::detail::CGCopy *>(MCommandGroup.get());
664  Stream << "Src: " << Copy->getSrc() << " Dst: " << Copy->getDst()
665  << "\\n";
666  }
667  break;
668  case sycl::detail::CGType::CopyAccToAcc:
669  Stream << "CGCopy Device-to-Device \\n";
670  if (Verbose) {
671  sycl::detail::CGCopy *Copy =
672  static_cast<sycl::detail::CGCopy *>(MCommandGroup.get());
673  Stream << "Src: " << Copy->getSrc() << " Dst: " << Copy->getDst()
674  << "\\n";
675  }
676  break;
677  case sycl::detail::CGType::Fill:
678  Stream << "CGFill \\n";
679  if (Verbose) {
680  sycl::detail::CGFill *Fill =
681  static_cast<sycl::detail::CGFill *>(MCommandGroup.get());
682  Stream << "Ptr: " << Fill->MPtr << "\\n";
683  }
684  break;
685  case sycl::detail::CGType::UpdateHost:
686  Stream << "CGCUpdateHost \\n";
687  if (Verbose) {
688  sycl::detail::CGUpdateHost *Host =
689  static_cast<sycl::detail::CGUpdateHost *>(MCommandGroup.get());
690  Stream << "Ptr: " << Host->getReqToUpdate() << "\\n";
691  }
692  break;
693  case sycl::detail::CGType::CopyUSM:
694  Stream << "CGCopyUSM \\n";
695  if (Verbose) {
696  sycl::detail::CGCopyUSM *CopyUSM =
697  static_cast<sycl::detail::CGCopyUSM *>(MCommandGroup.get());
698  Stream << "Src: " << CopyUSM->getSrc() << " Dst: " << CopyUSM->getDst()
699  << " Length: " << CopyUSM->getLength() << "\\n";
700  }
701  break;
702  case sycl::detail::CGType::FillUSM:
703  Stream << "CGFillUSM \\n";
704  if (Verbose) {
705  sycl::detail::CGFillUSM *FillUSM =
706  static_cast<sycl::detail::CGFillUSM *>(MCommandGroup.get());
707  Stream << "Dst: " << FillUSM->getDst()
708  << " Length: " << FillUSM->getLength() << " Pattern: ";
709  for (auto byte : FillUSM->getPattern())
710  Stream << byte;
711  Stream << "\\n";
712  }
713  break;
714  case sycl::detail::CGType::PrefetchUSM:
715  Stream << "CGPrefetchUSM \\n";
716  if (Verbose) {
717  sycl::detail::CGPrefetchUSM *Prefetch =
718  static_cast<sycl::detail::CGPrefetchUSM *>(MCommandGroup.get());
719  Stream << "Dst: " << Prefetch->getDst()
720  << " Length: " << Prefetch->getLength() << "\\n";
721  }
722  break;
723  case sycl::detail::CGType::AdviseUSM:
724  Stream << "CGAdviseUSM \\n";
725  if (Verbose) {
726  sycl::detail::CGAdviseUSM *AdviseUSM =
727  static_cast<sycl::detail::CGAdviseUSM *>(MCommandGroup.get());
728  Stream << "Dst: " << AdviseUSM->getDst()
729  << " Length: " << AdviseUSM->getLength() << "\\n";
730  }
731  break;
732  case sycl::detail::CGType::CodeplayHostTask:
733  Stream << "CGHostTask \\n";
734  break;
735  case sycl::detail::CGType::Barrier:
736  Stream << "CGBarrier \\n";
737  break;
738  case sycl::detail::CGType::Copy2DUSM:
739  Stream << "CGCopy2DUSM \\n";
740  if (Verbose) {
741  sycl::detail::CGCopy2DUSM *Copy2DUSM =
742  static_cast<sycl::detail::CGCopy2DUSM *>(MCommandGroup.get());
743  Stream << "Src:" << Copy2DUSM->getSrc()
744  << " Dst: " << Copy2DUSM->getDst() << "\\n";
745  }
746  break;
747  case sycl::detail::CGType::Fill2DUSM:
748  Stream << "CGFill2DUSM \\n";
749  if (Verbose) {
750  sycl::detail::CGFill2DUSM *Fill2DUSM =
751  static_cast<sycl::detail::CGFill2DUSM *>(MCommandGroup.get());
752  Stream << "Dst: " << Fill2DUSM->getDst() << "\\n";
753  }
754  break;
755  case sycl::detail::CGType::Memset2DUSM:
756  Stream << "CGMemset2DUSM \\n";
757  if (Verbose) {
758  sycl::detail::CGMemset2DUSM *Memset2DUSM =
759  static_cast<sycl::detail::CGMemset2DUSM *>(MCommandGroup.get());
760  Stream << "Dst: " << Memset2DUSM->getDst() << "\\n";
761  }
762  break;
763  case sycl::detail::CGType::ReadWriteHostPipe:
764  Stream << "CGReadWriteHostPipe \\n";
765  break;
766  case sycl::detail::CGType::CopyToDeviceGlobal:
767  Stream << "CGCopyToDeviceGlobal \\n";
768  if (Verbose) {
769  sycl::detail::CGCopyToDeviceGlobal *CopyToDeviceGlobal =
770  static_cast<sycl::detail::CGCopyToDeviceGlobal *>(
771  MCommandGroup.get());
772  Stream << "Src: " << CopyToDeviceGlobal->getSrc()
773  << " Dst: " << CopyToDeviceGlobal->getDeviceGlobalPtr() << "\\n";
774  }
775  break;
776  case sycl::detail::CGType::CopyFromDeviceGlobal:
777  Stream << "CGCopyFromDeviceGlobal \\n";
778  if (Verbose) {
779  sycl::detail::CGCopyFromDeviceGlobal *CopyFromDeviceGlobal =
780  static_cast<sycl::detail::CGCopyFromDeviceGlobal *>(
781  MCommandGroup.get());
782  Stream << "Src: " << CopyFromDeviceGlobal->getDeviceGlobalPtr()
783  << " Dst: " << CopyFromDeviceGlobal->getDest() << "\\n";
784  }
785  break;
786  case sycl::detail::CGType::ExecCommandBuffer:
787  Stream << "CGExecCommandBuffer \\n";
788  break;
789  default:
790  Stream << "Other \\n";
791  break;
792  }
793  Stream << "\"];" << std::endl;
794  }
795 
800  template <typename CGT> std::unique_ptr<CGT> createCGCopy() const {
801  return std::make_unique<CGT>(*static_cast<CGT *>(MCommandGroup.get()));
802  }
803 };
804 
805 class partition {
806 public:
809 
811  std::set<std::weak_ptr<node_impl>, std::owner_less<std::weak_ptr<node_impl>>>
814  std::list<std::shared_ptr<node_impl>> MSchedule;
816  std::unordered_map<sycl::device, ur_exp_command_buffer_handle_t>
819  std::vector<std::shared_ptr<partition>> MPredecessors;
822  bool MIsInOrderGraph = false;
823 
825  bool isHostTask() const {
826  return (MRoots.size() && ((*MRoots.begin()).lock()->MCGType ==
827  sycl::detail::CGType::CodeplayHostTask));
828  }
829 
833  if (MRoots.size() > 1) {
834  return false;
835  }
836  for (const auto &Node : MSchedule) {
837  // In version 1.3.28454 of the L0 driver, 2D Copy ops cannot not
838  // be enqueued in an in-order cmd-list (causing execution to stall).
839  // The 2D Copy test should be removed from here when the bug is fixed.
840  if ((Node->MSuccessors.size() > 1) || (Node->isNDCopyNode())) {
841  return false;
842  }
843  }
844 
845  return true;
846  }
847 
849  void schedule();
850 };
851 
853 class graph_impl {
854 public:
855  using ReadLock = std::shared_lock<std::shared_mutex>;
856  using WriteLock = std::unique_lock<std::shared_mutex>;
857 
859  mutable std::shared_mutex MMutex;
860 
865  graph_impl(const sycl::context &SyclContext, const sycl::device &SyclDevice,
866  const sycl::property_list &PropList = {})
867  : MContext(SyclContext), MDevice(SyclDevice), MRecordingQueues(),
868  MEventsMap(), MInorderQueueMap() {
869  if (PropList.has_property<property::graph::no_cycle_check>()) {
870  MSkipCycleChecks = true;
871  }
872  if (PropList
873  .has_property<property::graph::assume_buffer_outlives_graph>()) {
874  MAllowBuffers = true;
875  }
876 
877  if (!SyclDevice.has(aspect::ext_oneapi_limited_graph) &&
878  !SyclDevice.has(aspect::ext_oneapi_graph)) {
879  std::stringstream Stream;
880  Stream << SyclDevice.get_backend();
881  std::string BackendString = Stream.str();
882  throw sycl::exception(
884  BackendString + " backend is not supported by SYCL Graph extension.");
885  }
886  }
887 
888  ~graph_impl();
889 
892  void removeRoot(const std::shared_ptr<node_impl> &Root);
893 
899  std::shared_ptr<node_impl>
900  add(node_type NodeType, std::unique_ptr<sycl::detail::CG> CommandGroup,
901  const std::vector<std::shared_ptr<node_impl>> &Dep = {});
902 
909  std::shared_ptr<node_impl>
910  add(const std::shared_ptr<graph_impl> &Impl,
911  std::function<void(handler &)> CGF,
912  const std::vector<sycl::detail::ArgDesc> &Args,
913  const std::vector<std::shared_ptr<node_impl>> &Dep = {});
914 
919  std::shared_ptr<node_impl>
920  add(const std::shared_ptr<graph_impl> &Impl,
921  const std::vector<std::shared_ptr<node_impl>> &Dep = {});
922 
927  std::shared_ptr<node_impl>
928  add(const std::shared_ptr<graph_impl> &Impl,
929  const std::vector<sycl::detail::EventImplPtr> Events);
930 
934  void
935  addQueue(const std::shared_ptr<sycl::detail::queue_impl> &RecordingQueue) {
936  MRecordingQueues.insert(RecordingQueue);
937  }
938 
942  void
943  removeQueue(const std::shared_ptr<sycl::detail::queue_impl> &RecordingQueue) {
944  MRecordingQueues.erase(RecordingQueue);
945  }
946 
951  bool clearQueues();
952 
958  void addEventForNode(std::shared_ptr<graph_impl> GraphImpl,
959  std::shared_ptr<sycl::detail::event_impl> EventImpl,
960  std::shared_ptr<node_impl> NodeImpl) {
961  if (!(EventImpl->getCommandGraph()))
962  EventImpl->setCommandGraph(GraphImpl);
963  MEventsMap[EventImpl] = NodeImpl;
964  }
965 
969  std::shared_ptr<sycl::detail::event_impl>
970  getEventForNode(std::shared_ptr<node_impl> NodeImpl) const {
971  ReadLock Lock(MMutex);
972  if (auto EventImpl = std::find_if(
973  MEventsMap.begin(), MEventsMap.end(),
974  [NodeImpl](auto &it) { return it.second == NodeImpl; });
975  EventImpl != MEventsMap.end()) {
976  return EventImpl->first;
977  }
978 
979  throw sycl::exception(
981  "No event has been recorded for the specified graph node");
982  }
983 
988  std::shared_ptr<node_impl>
989  getNodeForEvent(std::shared_ptr<sycl::detail::event_impl> EventImpl) {
990  ReadLock Lock(MMutex);
991 
992  if (auto NodeFound = MEventsMap.find(EventImpl);
993  NodeFound != std::end(MEventsMap)) {
994  return NodeFound->second;
995  }
996 
997  throw sycl::exception(
999  "No node in this graph is associated with this event");
1000  }
1001 
1004  sycl::context getContext() const { return MContext; }
1005 
1008  sycl::device getDevice() const { return MDevice; }
1009 
1011  std::set<std::weak_ptr<node_impl>, std::owner_less<std::weak_ptr<node_impl>>>
1013 
1018  std::vector<std::shared_ptr<node_impl>> MNodeStorage;
1019 
1024  std::shared_ptr<node_impl>
1025  getLastInorderNode(std::shared_ptr<sycl::detail::queue_impl> Queue) {
1026  std::weak_ptr<sycl::detail::queue_impl> QueueWeakPtr(Queue);
1027  if (0 == MInorderQueueMap.count(QueueWeakPtr)) {
1028  return {};
1029  }
1030  return MInorderQueueMap[QueueWeakPtr];
1031  }
1032 
1036  void setLastInorderNode(std::shared_ptr<sycl::detail::queue_impl> Queue,
1037  std::shared_ptr<node_impl> Node) {
1038  std::weak_ptr<sycl::detail::queue_impl> QueueWeakPtr(Queue);
1039  MInorderQueueMap[QueueWeakPtr] = Node;
1040  }
1041 
1046  void printGraphAsDot(const std::string FilePath, bool Verbose) const {
1048  std::vector<node_impl *> VisitedNodes;
1049 
1050  std::fstream Stream(FilePath, std::ios::out);
1051  Stream << "digraph dot {" << std::endl;
1052 
1053  for (std::weak_ptr<node_impl> Node : MRoots)
1054  Node.lock()->printDotRecursive(Stream, VisitedNodes, Verbose);
1055 
1056  Stream << "}" << std::endl;
1057 
1058  Stream.close();
1059  }
1060 
1066  void makeEdge(std::shared_ptr<node_impl> Src,
1067  std::shared_ptr<node_impl> Dest);
1068 
1072  void throwIfGraphRecordingQueue(const std::string ExceptionMsg) const {
1073  if (MRecordingQueues.size()) {
1074  throw sycl::exception(make_error_code(sycl::errc::invalid),
1075  ExceptionMsg +
1076  " cannot be called when a queue "
1077  "is currently recording commands to a graph.");
1078  }
1079  }
1080 
1085  static bool checkNodeRecursive(const std::shared_ptr<node_impl> &NodeA,
1086  const std::shared_ptr<node_impl> &NodeB) {
1087  size_t FoundCnt = 0;
1088  for (std::weak_ptr<node_impl> &SuccA : NodeA->MSuccessors) {
1089  for (std::weak_ptr<node_impl> &SuccB : NodeB->MSuccessors) {
1090  if (NodeA->isSimilar(NodeB) &&
1091  checkNodeRecursive(SuccA.lock(), SuccB.lock())) {
1092  FoundCnt++;
1093  break;
1094  }
1095  }
1096  }
1097  if (FoundCnt != NodeA->MSuccessors.size()) {
1098  return false;
1099  }
1100 
1101  return true;
1102  }
1103 
/// Checks if the graph_impl of Graph has a similar structure to the
/// graph_impl of the caller.
///
/// Compares, in order: context, device, the sizes of MEventsMap /
/// MInorderQueueMap / MRoots, and finally the root nodes, which are
/// compared with node_impl::isSimilar() and checkNodeRecursive().
/// @param Graph Graph to compare against.
/// @param DebugPrint If true, a message naming the first mismatch is
/// emitted before returning false. NOTE(review): the printing call itself
/// is not visible in this listing — only its message strings are.
/// @return true if the graphs are considered similar.
1115  bool hasSimilarStructure(std::shared_ptr<detail::graph_impl> Graph,
1116  bool DebugPrint = false) const {
// Comparing a graph with itself is trivially similar.
1117  if (this == Graph.get())
1118  return true;
1119 
1120  if (MContext != Graph->MContext) {
1121  if (DebugPrint) {
1123  "MContext are not the same.");
1124  }
1125  return false;
1126  }
1127 
1128  if (MDevice != Graph->MDevice) {
1129  if (DebugPrint) {
1131  "MDevice are not the same.");
1132  }
1133  return false;
1134  }
1135 
// Only the sizes of the bookkeeping maps are compared, not their contents.
1136  if (MEventsMap.size() != Graph->MEventsMap.size()) {
1137  if (DebugPrint) {
1139  "MEventsMap sizes are not the same.");
1140  }
1141  return false;
1142  }
1143 
1144  if (MInorderQueueMap.size() != Graph->MInorderQueueMap.size()) {
1145  if (DebugPrint) {
1147  "MInorderQueueMap sizes are not the same.");
1148  }
1149  return false;
1150  }
1151 
1152  if (MRoots.size() != Graph->MRoots.size()) {
1153  if (DebugPrint) {
1155  "MRoots sizes are not the same.");
1156  }
1157  return false;
1158  }
1159 
// Every root of this graph must match some root of Graph, both directly
// (isSimilar) and through its successors (checkNodeRecursive).
// NOTE(review): a root of Graph is not removed once matched, so several of
// our roots may match the same one; lock() results are not null-checked.
1160  size_t RootsFound = 0;
1161  for (std::weak_ptr<node_impl> NodeA : MRoots) {
1162  for (std::weak_ptr<node_impl> NodeB : Graph->MRoots) {
1163  auto NodeALocked = NodeA.lock();
1164  auto NodeBLocked = NodeB.lock();
1165 
1166  if (NodeALocked->isSimilar(NodeBLocked)) {
1167  if (checkNodeRecursive(NodeALocked, NodeBLocked)) {
1168  RootsFound++;
1169  break;
1170  }
1171  }
1172  }
1173  }
1174 
1175  if (RootsFound != MRoots.size()) {
1176  if (DebugPrint) {
1178  "Root Nodes do NOT match.");
1179  }
1180  return false;
1181  }
1182 
1183  return true;
1184  }
1185 
1188  size_t getNumberOfNodes() const { return MNodeStorage.size(); }
1189 
/// Traverse the graph recursively to get the events associated with the
/// output (exit) nodes of this graph.
/// @param Queue Queue the events are requested for. NOTE(review): how the
/// result is filtered by Queue is defined out of line — confirm in the .cpp.
/// @return Events of the matching exit nodes.
1194  std::vector<sycl::detail::EventImplPtr>
1195  getExitNodesEvents(std::weak_ptr<sycl::detail::queue_impl> Queue);
1196 
1200  void setBarrierDep(std::weak_ptr<sycl::detail::queue_impl> Queue,
1201  std::shared_ptr<node_impl> BarrierNodeImpl) {
1202  MBarrierDependencyMap[Queue] = BarrierNodeImpl;
1203  }
1204 
1208  std::shared_ptr<node_impl>
1209  getBarrierDep(std::weak_ptr<sycl::detail::queue_impl> Queue) {
1210  return MBarrierDependencyMap[Queue];
1211  }
1212 
1213 private:
/// Traverse the graph, calling NodeFunc for each visited node together with
/// the traversal's deque of nodes.
/// NOTE(review): presumably depth-first (per the name), with NodeFunc's bool
/// result controlling early termination — confirm in the .cpp.
1220  void
1221  searchDepthFirst(std::function<bool(std::shared_ptr<node_impl> &,
1222  std::deque<std::shared_ptr<node_impl>> &)>
1223  NodeFunc);
1224 
/// NOTE(review): presumably returns true if the graph contains a cycle;
/// MSkipCycleChecks suggests callers may bypass it — confirm in the .cpp.
1229  bool checkForCycles();
1230 
/// Add Root to the set of root nodes (counterpart of removeRoot()).
1233  void addRoot(const std::shared_ptr<node_impl> &Root);
1234 
/// NOTE(review): presumably attaches the nodes in NodeList after the current
/// exit nodes of Impl and returns the resulting node — confirm in the .cpp.
1239  std::shared_ptr<node_impl>
1240  addNodesToExits(const std::shared_ptr<graph_impl> &Impl,
1241  const std::list<std::shared_ptr<node_impl>> &NodeList);
1242 
1247  void addDepsToNode(std::shared_ptr<node_impl> Node,
1248  const std::vector<std::shared_ptr<node_impl>> &Deps) {
1249  if (!Deps.empty()) {
1250  for (auto &N : Deps) {
1251  N->registerSuccessor(Node, N);
1252  this->removeRoot(Node);
1253  }
1254  } else {
1255  this->addRoot(Node);
1256  }
1257  }
1258 
/// Context associated with this graph; hasSimilarStructure() requires it to
/// match.
1260  sycl::context MContext;
/// Device associated with this graph; hasSimilarStructure() requires it to
/// match.
1263  sycl::device MDevice;
/// Queues which are currently recording to this graph, held weakly and
/// ordered by owner. throwIfGraphRecordingQueue() throws while this is
/// non-empty.
1265  std::set<std::weak_ptr<sycl::detail::queue_impl>,
1266  std::owner_less<std::weak_ptr<sycl::detail::queue_impl>>>
1267  MRecordingQueues;
/// Map from SYCL events to the graph nodes they are associated with.
1269  std::unordered_map<std::shared_ptr<sycl::detail::event_impl>,
1270  std::shared_ptr<node_impl>>
1271  MEventsMap;
/// Last node added to this graph from each in-order queue; accessed through
/// getLastInorderNode() / setLastInorderNode().
1275  std::map<std::weak_ptr<sycl::detail::queue_impl>, std::shared_ptr<node_impl>,
1276  std::owner_less<std::weak_ptr<sycl::detail::queue_impl>>>
1277  MInorderQueueMap;
/// Whether cycle checks are skipped. NOTE(review): presumably set from the
/// command_graph property that disables cycle checking — confirm in the
/// constructor.
1280  bool MSkipCycleChecks = false;
/// SYCL memory objects used by this graph. NOTE(review): populated outside
/// this listing — confirm what inserts into it and where it is read.
1282  std::set<sycl::detail::SYCLMemObjT *> MMemObjs;
1283 
/// Whether buffers may be used in this graph. NOTE(review): presumably set
/// from a graph property in the constructor — confirm.
1286  bool MAllowBuffers = false;
1287 
/// Last barrier node submitted to each queue; accessed through
/// getBarrierDep() / setBarrierDep().
1290  std::map<std::weak_ptr<sycl::detail::queue_impl>, std::shared_ptr<node_impl>,
1291  std::owner_less<std::weak_ptr<sycl::detail::queue_impl>>>
1292  MBarrierDependencyMap;
1293 };
1294 
1297 public:
/// Shared (reader) lock over a std::shared_mutex such as MMutex.
1298  using ReadLock = std::shared_lock<std::shared_mutex>;
/// Exclusive (writer) lock over a std::shared_mutex such as MMutex.
1299  using WriteLock = std::unique_lock<std::shared_mutex>;
1300 
/// Protects all the fields that can be changed by class' methods.
/// Declared mutable so it can also be locked from const member functions.
1302  mutable std::shared_mutex MMutex;
1303 
1313  const std::shared_ptr<graph_impl> &GraphImpl,
1314  const property_list &PropList);
1315 
/// Destructor, defined out of line.
1319  ~exec_graph_impl();
1320 
/// Partition the graph nodes and put the partitions in MPartitions.
1325  void makePartitions();
1326 
1332  sycl::event enqueue(const std::shared_ptr<sycl::detail::queue_impl> &Queue,
1334 
/// Turns the internal graph representation into UR command-buffers for a
/// device.
/// @param Device Device the command-buffers are created for.
/// @param Partition Partition to build command-buffers for. NOTE(review):
/// taken by non-const reference, so it may be modified — confirm in the .cpp.
1340  void createCommandBuffers(sycl::device Device,
1341  std::shared_ptr<partition> &Partition);
1342 
1345  sycl::device getDevice() const { return MDevice; }
1346 
1349  sycl::context getContext() const { return MContext; }
1350 
1353  const std::list<std::shared_ptr<node_impl>> &getSchedule() const {
1354  return MSchedule;
1355  }
1356 
1359  const std::shared_ptr<graph_impl> &getGraphImpl() const { return MGraphImpl; }
1360 
1363  const std::vector<std::shared_ptr<partition>> &getPartitions() const {
1364  return MPartitions;
1365  }
1366 
1373  for (auto Event : MExecutionEvents) {
1374  if (!Event->isCompleted()) {
1375  return false;
1376  }
1377  }
1378  return true;
1379  }
1380 
1382  std::vector<sycl::detail::AccessorImplHost *> getRequirements() const {
1383  return MRequirements;
1384  }
1385 
/// Update this executable graph from the modifiable graph GraphImpl.
/// NOTE(review): the requirements on GraphImpl are enforced out of line —
/// confirm in the .cpp.
1386  void update(std::shared_ptr<graph_impl> GraphImpl);
/// Update this executable graph from a single modifiable-graph node.
/// NOTE(review): presumably looks up the companion node(s) via MIDCache.
1387  void update(std::shared_ptr<node_impl> Node);
/// Update this executable graph from several modifiable-graph nodes.
/// NOTE(review): the vector is taken by const value (copied); switching to
/// const& requires changing the out-of-line definition as well.
1388  void update(const std::vector<std::shared_ptr<node_impl>> Nodes);
1389 
/// Apply the update for a single node. NOTE(review): presumably shared by
/// the update() overloads — confirm in the .cpp.
1390  void updateImpl(std::shared_ptr<node_impl> NodeImpl);
1391 
1392 private:
/// Enqueue a node into a UR command-buffer.
/// @return Sync point of the enqueued command.
/// NOTE(review): relationship to enqueueNodeDirect() is defined out of
/// line — confirm in the .cpp.
1400  ur_exp_command_buffer_sync_point_t
1401  enqueueNode(sycl::context Ctx, sycl::detail::DeviceImplPtr DeviceImpl,
1402  ur_exp_command_buffer_handle_t CommandBuffer,
1403  std::shared_ptr<node_impl> Node);
1404 
/// Enqueue a node directly into a UR command-buffer.
/// @return Sync point of the enqueued command.
1412  ur_exp_command_buffer_sync_point_t
1413  enqueueNodeDirect(sycl::context Ctx, sycl::detail::DeviceImplPtr DeviceImpl,
1414  ur_exp_command_buffer_handle_t CommandBuffer,
1415  std::shared_ptr<node_impl> Node);
1416 
/// Append to Deps the sync points CurrentNode has to wait on.
/// NOTE(review): presumably walks predecessors and keeps only those in
/// partition ReferencePartitionNum — confirm in the .cpp.
1423  void findRealDeps(std::vector<ur_exp_command_buffer_sync_point_t> &Deps,
1424  std::shared_ptr<node_impl> CurrentNode,
1425  int ReferencePartitionNum);
1426 
/// NOTE(review): presumably copies the modifiable graph's nodes into this
/// executable graph's MNodeStorage — confirm in the .cpp.
1430  void duplicateNodes();
1431 
1436  void printGraphAsDot(const std::string FilePath, bool Verbose) const {
1438  std::vector<node_impl *> VisitedNodes;
1439 
1440  std::fstream Stream(FilePath, std::ios::out);
1441  Stream << "digraph dot {" << std::endl;
1442 
1443  std::vector<std::shared_ptr<node_impl>> Roots;
1444  for (auto &Node : MNodeStorage) {
1445  if (Node->MPredecessors.size() == 0) {
1446  Roots.push_back(Node);
1447  }
1448  }
1449 
1450  for (std::shared_ptr<node_impl> Node : Roots)
1451  Node->printDotRecursive(Stream, VisitedNodes, Verbose);
1452 
1453  Stream << "}" << std::endl;
1454 
1455  Stream.close();
1456  }
1457 
1459  std::list<std::shared_ptr<node_impl>> MSchedule;
1466  std::shared_ptr<graph_impl> MGraphImpl;
1469  std::unordered_map<std::shared_ptr<node_impl>,
1470  ur_exp_command_buffer_sync_point_t>
1471  MSyncPoints;
1474  std::unordered_map<std::shared_ptr<node_impl>, int> MPartitionNodes;
1476  sycl::device MDevice;
1478  sycl::context MContext;
1481  std::vector<sycl::detail::AccessorImplHost *> MRequirements;
1484  std::vector<sycl::detail::AccessorImplPtr> MAccessors;
1486  std::vector<sycl::detail::EventImplPtr> MExecutionEvents;
1488  std::vector<std::shared_ptr<partition>> MPartitions;
1490  std::vector<std::shared_ptr<node_impl>> MNodeStorage;
1492  std::unordered_map<std::shared_ptr<node_impl>,
1493  ur_exp_command_buffer_command_handle_t>
1494  MCommandMap;
1496  bool MIsUpdatable;
1498  bool MEnableProfiling;
1499 
1500  // Stores a cache of node ids from modifiable graph nodes to the companion
1501  // node(s) in this graph. Used for quick access when updating this graph.
1502  std::multimap<node_impl::id_type, std::shared_ptr<node_impl>> MIDCache;
1503 };
1504 
1506 public:
1507  dynamic_parameter_impl(std::shared_ptr<graph_impl> GraphImpl,
1508  size_t ParamSize, const void *Data)
1509  : MGraph(GraphImpl), MValueStorage(ParamSize) {
1510  std::memcpy(MValueStorage.data(), Data, ParamSize);
1511  }
1512 
1517  void registerNode(std::shared_ptr<node_impl> NodeImpl, int ArgIndex) {
1518  MNodes.emplace_back(NodeImpl, ArgIndex);
1519  }
1520 
1522  void *getValue() { return MValueStorage.data(); }
1523 
1528  void updateValue(const void *NewValue, size_t Size) {
1529  for (auto &[NodeWeak, ArgIndex] : MNodes) {
1530  auto NodeShared = NodeWeak.lock();
1531  if (NodeShared) {
1532  NodeShared->updateArgValue(ArgIndex, NewValue, Size);
1533  }
1534  }
1535  std::memcpy(MValueStorage.data(), NewValue, Size);
1536  }
1537 
1542  void updateAccessor(const sycl::detail::AccessorBaseHost *Acc) {
1543  for (auto &[NodeWeak, ArgIndex] : MNodes) {
1544  auto NodeShared = NodeWeak.lock();
1545  // Should we fail here if the node isn't alive anymore?
1546  if (NodeShared) {
1547  NodeShared->updateAccessor(ArgIndex, Acc);
1548  }
1549  }
1550  std::memcpy(MValueStorage.data(), Acc,
1551  sizeof(sycl::detail::AccessorBaseHost));
1552  }
1553 
1554  // Weak ptrs to node_impls which will be updated, each paired with the
// argument index to update in that node. Entries whose node no longer
// exists are skipped by updateValue() and updateAccessor().
1555  std::vector<std::pair<std::weak_ptr<node_impl>, int>> MNodes;
1556 
/// Graph this dynamic parameter belongs to.
1557  std::shared_ptr<graph_impl> MGraph;
/// Raw bytes of the current parameter value (exposed through getValue()).
1558  std::vector<std::byte> MValueStorage;
1559 };
1560 
1561 } // namespace detail
1562 } // namespace experimental
1563 } // namespace oneapi
1564 } // namespace ext
1565 } // namespace _V1
1566 } // namespace sycl
The context class represents a SYCL context on which kernel functions may be executed.
Definition: context.hpp:50
The SYCL device class encapsulates a single SYCL device on which kernels may be executed.
Definition: device.hpp:64
backend get_backend() const noexcept
Returns the backend associated with this device.
Definition: device.cpp:203
bool has(aspect Aspect) const __SYCL_WARN_IMAGE_ASPECT(Aspect)
Indicates if the SYCL device has the given feature.
Definition: device.cpp:207
An event object can be used to synchronize memory transfers, enqueues of kernels and signaling barrie...
Definition: event.hpp:44
dynamic_parameter_impl(std::shared_ptr< graph_impl > GraphImpl, size_t ParamSize, const void *Data)
std::vector< std::pair< std::weak_ptr< node_impl >, int > > MNodes
void * getValue()
Get a pointer to the internal value of this dynamic parameter.
void registerNode(std::shared_ptr< node_impl > NodeImpl, int ArgIndex)
Register a node with this dynamic parameter.
void updateValue(const void *NewValue, size_t Size)
Update the internal value of this dynamic parameter as well as the value of this parameter in all reg...
void updateAccessor(const sycl::detail::AccessorBaseHost *Acc)
Update the internal value of this dynamic parameter as well as the value of this parameter in all reg...
Class representing the implementation of command_graph<executable>.
exec_graph_impl(sycl::context Context, const std::shared_ptr< graph_impl > &GraphImpl, const property_list &PropList)
Constructor.
Definition: graph_impl.cpp:757
void createCommandBuffers(sycl::device Device, std::shared_ptr< partition > &Partition)
Turns the internal graph representation into UR command-buffers for a device.
Definition: graph_impl.cpp:699
std::vector< sycl::detail::AccessorImplHost * > getRequirements() const
Returns a list of all the accessor requirements for this graph.
const std::vector< std::shared_ptr< partition > > & getPartitions() const
Query the vector of the partitions composing the exec_graph.
void update(std::shared_ptr< graph_impl > GraphImpl)
sycl::context getContext() const
Query for the context tied to this graph.
void makePartitions()
Partition the graph nodes and put the partition in MPartitions.
Definition: graph_impl.cpp:183
void updateImpl(std::shared_ptr< node_impl > NodeImpl)
sycl::device getDevice() const
Query for the device tied to this graph.
const std::list< std::shared_ptr< node_impl > > & getSchedule() const
Query the scheduling of node execution.
const std::shared_ptr< graph_impl > & getGraphImpl() const
Query the graph_impl.
std::shared_mutex MMutex
Protects all the fields that can be changed by class' methods.
bool previousSubmissionCompleted() const
Checks if the previous submissions of this graph have been completed. This function checks the status ...
sycl::event enqueue(const std::shared_ptr< sycl::detail::queue_impl > &Queue, sycl::detail::CG::StorageInitHelper CGData)
Called by handler::ext_oneapi_command_graph() to schedule graph for execution.
Definition: graph_impl.cpp:818
Implementation details of command_graph<modifiable>.
Definition: graph_impl.hpp:853
std::vector< std::shared_ptr< node_impl > > MNodeStorage
Storage for all nodes contained within a graph.
void addQueue(const std::shared_ptr< sycl::detail::queue_impl > &RecordingQueue)
Add a queue to the set of queues which are currently recording to this graph.
Definition: graph_impl.hpp:935
void removeRoot(const std::shared_ptr< node_impl > &Root)
Remove node from list of root nodes.
Definition: graph_impl.cpp:348
std::unique_lock< std::shared_mutex > WriteLock
Definition: graph_impl.hpp:856
void makeEdge(std::shared_ptr< node_impl > Src, std::shared_ptr< node_impl > Dest)
Make an edge between two nodes in the graph.
Definition: graph_impl.cpp:556
std::shared_mutex MMutex
Protects all the fields that can be changed by class' methods.
Definition: graph_impl.hpp:859
void printGraphAsDot(const std::string FilePath, bool Verbose) const
Prints the contents of the graph to a text file in DOT format.
void throwIfGraphRecordingQueue(const std::string ExceptionMsg) const
Throws an invalid exception if this function is called while a queue is recording commands to the gra...
std::shared_ptr< sycl::detail::event_impl > getEventForNode(std::shared_ptr< node_impl > NodeImpl) const
Find the sycl event associated with a node.
Definition: graph_impl.hpp:970
void setBarrierDep(std::weak_ptr< sycl::detail::queue_impl > Queue, std::shared_ptr< node_impl > BarrierNodeImpl)
Store the last barrier node that was submitted to the queue.
void setLastInorderNode(std::shared_ptr< sycl::detail::queue_impl > Queue, std::shared_ptr< node_impl > Node)
Track the last node added to this graph from an in-order queue.
std::set< std::weak_ptr< node_impl >, std::owner_less< std::weak_ptr< node_impl > > > MRoots
List of root nodes.
bool hasSimilarStructure(std::shared_ptr< detail::graph_impl > Graph, bool DebugPrint=false) const
Checks if the graph_impl of Graph has a similar structure to the graph_impl of the caller.
std::vector< sycl::detail::EventImplPtr > getExitNodesEvents(std::weak_ptr< sycl::detail::queue_impl > Queue)
Traverse the graph recursively to get the events associated with the output nodes of this graph assoc...
Definition: graph_impl.cpp:607
bool clearQueues()
Remove all queues which are recording to this graph, also sets all queues cleared back to the executi...
Definition: graph_impl.cpp:502
sycl::device getDevice() const
Query for the device tied to this graph.
std::shared_ptr< node_impl > getNodeForEvent(std::shared_ptr< sycl::detail::event_impl > EventImpl)
Find the node associated with a SYCL event.
Definition: graph_impl.hpp:989
static bool checkNodeRecursive(const std::shared_ptr< node_impl > &NodeA, const std::shared_ptr< node_impl > &NodeB)
Recursively check successors of NodeA and NodeB to check they are similar.
graph_impl(const sycl::context &SyclContext, const sycl::device &SyclDevice, const sycl::property_list &PropList={})
Constructor.
Definition: graph_impl.hpp:865
sycl::context getContext() const
Query for the context tied to this graph.
std::shared_ptr< node_impl > add(node_type NodeType, std::unique_ptr< sycl::detail::CG > CommandGroup, const std::vector< std::shared_ptr< node_impl >> &Dep={})
Create a kernel node in the graph.
Definition: graph_impl.cpp:437
void addEventForNode(std::shared_ptr< graph_impl > GraphImpl, std::shared_ptr< sycl::detail::event_impl > EventImpl, std::shared_ptr< node_impl > NodeImpl)
Associate a sycl event with a node in the graph.
Definition: graph_impl.hpp:958
std::shared_lock< std::shared_mutex > ReadLock
Definition: graph_impl.hpp:855
size_t getNumberOfNodes() const
Returns the number of nodes in the Graph.
std::shared_ptr< node_impl > getLastInorderNode(std::shared_ptr< sycl::detail::queue_impl > Queue)
Find the last node added to this graph from an in-order queue.
std::shared_ptr< node_impl > getBarrierDep(std::weak_ptr< sycl::detail::queue_impl > Queue)
Get the last barrier node that was submitted to the queue.
void removeQueue(const std::shared_ptr< sycl::detail::queue_impl > &RecordingQueue)
Remove a queue from the set of queues which are currently recording to this graph.
Definition: graph_impl.hpp:943
Implementation of node class from SYCL_EXT_ONEAPI_GRAPH.
Definition: graph_impl.hpp:80
bool isNDCopyNode() const
Test if the node contains a N-D copy.
Definition: graph_impl.hpp:398
void updateAccessor(int ArgIndex, const sycl::detail::AccessorBaseHost *Acc)
Update the value of an accessor inside this node.
Definition: graph_impl.hpp:415
void registerSuccessor(const std::shared_ptr< node_impl > &Node, const std::shared_ptr< node_impl > &Prev)
Add successor to the node.
Definition: graph_impl.hpp:119
std::vector< std::weak_ptr< node_impl > > MPredecessors
List of predecessors to this node.
Definition: graph_impl.hpp:91
node_impl(node_impl &Other)
Construct a node from another node.
Definition: graph_impl.hpp:163
std::vector< std::weak_ptr< node_impl > > MSuccessors
List of successors to this node.
Definition: graph_impl.hpp:87
void updateFromOtherNode(const std::shared_ptr< node_impl > &Other)
Definition: graph_impl.hpp:533
sycl::detail::CGType MCGType
Type of the command-group for the node.
Definition: graph_impl.hpp:93
std::shared_ptr< exec_graph_impl > MSubGraphImpl
Stores the executable graph impl associated with this node if it is a subgraph node.
Definition: graph_impl.hpp:100
node_impl(node_type NodeType, std::unique_ptr< sycl::detail::CG > &&CommandGroup)
Construct a node representing a command-group.
Definition: graph_impl.hpp:150
bool MNDRangeUsed
Track whether an ND-Range was used for kernel nodes.
Definition: graph_impl.hpp:111
bool hasRequirementDependency(sycl::detail::AccessorImplHost *IncomingReq)
Checks if this node should be a dependency of another node based on accessor requirements.
Definition: graph_impl.hpp:186
int MPartitionNum
Partition number needed to assign a Node to a partition.
Definition: graph_impl.hpp:108
void updateRange(range< Dimensions > ExecutionRange)
Definition: graph_impl.hpp:507
void updateNDRange(nd_range< Dimensions > ExecutionRange)
Definition: graph_impl.hpp:481
bool isSimilar(const std::shared_ptr< node_impl > &Node, bool CompareContentOnly=false) const
Tests if the caller is similar to Node; this is only used for testing.
Definition: graph_impl.hpp:324
void printDotRecursive(std::fstream &Stream, std::vector< node_impl * > &Visited, bool Verbose)
Recursive Depth first traversal of linked nodes.
Definition: graph_impl.hpp:375
node_type MNodeType
User facing type of the node.
Definition: graph_impl.hpp:95
std::unique_ptr< sycl::detail::CG > getCGCopy() const
Get a deep copy of this node's command group.
Definition: graph_impl.hpp:230
void updateArgValue(int ArgIndex, const void *NewValue, size_t Size)
Definition: graph_impl.hpp:464
bool isEmpty() const
Query if this is an empty node.
Definition: graph_impl.hpp:223
void registerPredecessor(const std::shared_ptr< node_impl > &Node)
Add predecessor to the node.
Definition: graph_impl.hpp:133
bool MVisited
Used for tracking visited status during cycle checks.
Definition: graph_impl.hpp:103
std::unique_ptr< sycl::detail::CG > MCommandGroup
Command group object which stores all args etc needed to enqueue the node.
Definition: graph_impl.hpp:97
id_type MID
Unique identifier for this node.
Definition: graph_impl.hpp:85
node_impl & operator=(node_impl &Other)
Copy-assignment operator.
Definition: graph_impl.hpp:170
bool checkIfGraphIsSinglePath()
Checks if the graph is single path, i.e.
Definition: graph_impl.hpp:832
std::unordered_map< sycl::device, ur_exp_command_buffer_handle_t > MCommandBuffers
Map of devices to command buffers.
Definition: graph_impl.hpp:817
std::set< std::weak_ptr< node_impl >, std::owner_less< std::weak_ptr< node_impl > > > MRoots
List of root nodes.
Definition: graph_impl.hpp:812
std::vector< std::shared_ptr< partition > > MPredecessors
List of predecessors to this partition.
Definition: graph_impl.hpp:819
bool MIsInOrderGraph
True if the graph of this partition is a single path graph and in-order optimization can be applied on...
Definition: graph_impl.hpp:822
std::list< std::shared_ptr< node_impl > > MSchedule
Execution schedule of nodes in the graph.
Definition: graph_impl.hpp:814
Property passed to command_graph constructor to disable checking for cycles.
Definition: graph.hpp:153
Command group handler class.
Definition: handler.hpp:467
Defines the iteration domain of both the work-groups and the overall dispatch.
Definition: nd_range.hpp:22
Objects of the property_list class are containers for the SYCL properties.
decltype(Obj::impl) const & getSyclObjImpl(const Obj &SyclObject)
Definition: impl_utils.hpp:31
AccessorImplHost Requirement
std::shared_ptr< device_impl > DeviceImplPtr
CGType
Type of the command group.
Definition: cg_types.hpp:42
node_type getNodeTypeFromCG(sycl::detail::CGType CGType)
Definition: graph_impl.hpp:44
class __SYCL_EBO __SYCL_SPECIAL_CLASS Dimensions
constexpr mode_tag_t< access_mode::read_write > read_write
Definition: access.hpp:85
__width_manipulator__ setw(int Width)
Definition: stream.hpp:841
std::error_code make_error_code(sycl::errc E) noexcept
Constructs an error code using E and sycl_category()
Definition: exception.cpp:65
Definition: access.hpp:18
std::vector< std::vector< char > > MArgsStorage
Storage for standard layout arguments.
Definition: cg.hpp:178