27 #include <shared_mutex>
30 inline namespace _V1 {
38 namespace experimental {
42 using sycl::detail::CG;
49 case CG::CopyAccToPtr:
50 case CG::CopyPtrToAcc:
51 case CG::CopyAccToAcc:
64 case CG::BarrierWaitlist:
66 case CG::CodeplayHostTask:
68 case CG::ExecCommandBuffer:
71 assert(
false &&
"Invalid Graph Node Type");
117 const std::shared_ptr<node_impl> &Prev) {
119 [Node](
const std::weak_ptr<node_impl> &Ptr) {
120 return Ptr.lock() == Node;
125 Node->registerPredecessor(Prev);
132 [&Node](
const std::weak_ptr<node_impl> &Ptr) {
133 return Ptr.lock() == Node;
148 std::unique_ptr<sycl::detail::CG> &&CommandGroup)
153 static_cast<sycl::detail::CGExecCommandBuffer *
>(
MCommandGroup.get())
168 if (
this != &Other) {
193 case access_mode::discard_read_write:
194 case access_mode::discard_write:
198 for (sycl::detail::AccessorImplHost *CurrentReq :
200 if (IncomingReq->MSYCLMemObj == CurrentReq->MSYCLMemObj) {
218 return ((
MCGType == sycl::detail::CG::None) ||
219 (
MCGType == sycl::detail::CG::Barrier));
226 case sycl::detail::CG::Kernel: {
227 auto CGCopy = createCGCopy<sycl::detail::CGExecKernel>();
228 rebuildArgStorage(CGCopy->MArgs,
MCommandGroup->getArgsStorage(),
229 CGCopy->getArgsStorage());
230 return std::move(CGCopy);
232 case sycl::detail::CG::CopyAccToPtr:
233 case sycl::detail::CG::CopyPtrToAcc:
234 case sycl::detail::CG::CopyAccToAcc:
235 return createCGCopy<sycl::detail::CGCopy>();
236 case sycl::detail::CG::Fill:
237 return createCGCopy<sycl::detail::CGFill>();
238 case sycl::detail::CG::UpdateHost:
239 return createCGCopy<sycl::detail::CGUpdateHost>();
240 case sycl::detail::CG::CopyUSM:
241 return createCGCopy<sycl::detail::CGCopyUSM>();
242 case sycl::detail::CG::FillUSM:
243 return createCGCopy<sycl::detail::CGFillUSM>();
244 case sycl::detail::CG::PrefetchUSM:
245 return createCGCopy<sycl::detail::CGPrefetchUSM>();
246 case sycl::detail::CG::AdviseUSM:
247 return createCGCopy<sycl::detail::CGAdviseUSM>();
248 case sycl::detail::CG::Copy2DUSM:
249 return createCGCopy<sycl::detail::CGCopy2DUSM>();
250 case sycl::detail::CG::Fill2DUSM:
251 return createCGCopy<sycl::detail::CGFill2DUSM>();
252 case sycl::detail::CG::Memset2DUSM:
253 return createCGCopy<sycl::detail::CGMemset2DUSM>();
254 case sycl::detail::CG::CodeplayHostTask: {
258 auto CommandGroupPtr =
259 static_cast<sycl::detail::CGHostTask *
>(
MCommandGroup.get());
260 sycl::detail::HostTask HostTask = *CommandGroupPtr->MHostTask.get();
261 auto HostTaskUPtr = std::make_unique<sycl::detail::HostTask>(HostTask);
264 CommandGroupPtr->getArgsStorage(), CommandGroupPtr->getAccStorage(),
265 CommandGroupPtr->getSharedPtrStorage(),
266 CommandGroupPtr->getRequirements(), CommandGroupPtr->getEvents());
268 std::vector<sycl::detail::ArgDesc> NewArgs = CommandGroupPtr->MArgs;
270 rebuildArgStorage(NewArgs, CommandGroupPtr->getArgsStorage(),
273 sycl::detail::code_location Loc(CommandGroupPtr->MFileName.data(),
274 CommandGroupPtr->MFunctionName.data(),
275 CommandGroupPtr->MLine,
276 CommandGroupPtr->MColumn);
278 return std::make_unique<sycl::detail::CGHostTask>(
279 sycl::detail::CGHostTask(
280 std::move(HostTaskUPtr), CommandGroupPtr->MQueue,
281 CommandGroupPtr->MContext, std::move(NewArgs), std::move(Data),
282 CommandGroupPtr->getType(), Loc));
284 case sycl::detail::CG::Barrier:
285 case sycl::detail::CG::BarrierWaitlist:
288 return createCGCopy<sycl::detail::CG>();
289 case sycl::detail::CG::CopyToDeviceGlobal:
290 return createCGCopy<sycl::detail::CGCopyToDeviceGlobal>();
291 case sycl::detail::CG::CopyFromDeviceGlobal:
292 return createCGCopy<sycl::detail::CGCopyFromDeviceGlobal>();
293 case sycl::detail::CG::ReadWriteHostPipe:
294 return createCGCopy<sycl::detail::CGReadWriteHostPipe>();
295 case sycl::detail::CG::CopyImage:
296 return createCGCopy<sycl::detail::CGCopyImage>();
297 case sycl::detail::CG::SemaphoreSignal:
298 return createCGCopy<sycl::detail::CGSemaphoreSignal>();
299 case sycl::detail::CG::SemaphoreWait:
300 return createCGCopy<sycl::detail::CGSemaphoreWait>();
301 case sycl::detail::CG::ExecCommandBuffer:
302 return createCGCopy<sycl::detail::CGExecCommandBuffer>();
303 case sycl::detail::CG::None:
315 bool CompareContentOnly =
false)
const {
316 if (!CompareContentOnly) {
317 if (
MSuccessors.size() != Node->MSuccessors.size())
327 case sycl::detail::CG::CGTYPE::Kernel: {
328 sycl::detail::CGExecKernel *ExecKernelA =
329 static_cast<sycl::detail::CGExecKernel *
>(
MCommandGroup.get());
330 sycl::detail::CGExecKernel *ExecKernelB =
331 static_cast<sycl::detail::CGExecKernel *
>(Node->MCommandGroup.get());
332 return ExecKernelA->MKernelName.compare(ExecKernelB->MKernelName) == 0;
334 case sycl::detail::CG::CGTYPE::CopyUSM: {
335 sycl::detail::CGCopyUSM *CopyA =
337 sycl::detail::CGCopyUSM *CopyB =
338 static_cast<sycl::detail::CGCopyUSM *
>(Node->MCommandGroup.get());
339 return (CopyA->getSrc() == CopyB->getSrc()) &&
340 (CopyA->getDst() == CopyB->getDst()) &&
341 (CopyA->getLength() == CopyB->getLength());
343 case sycl::detail::CG::CGTYPE::CopyAccToAcc:
344 case sycl::detail::CG::CGTYPE::CopyAccToPtr:
345 case sycl::detail::CG::CGTYPE::CopyPtrToAcc: {
346 sycl::detail::CGCopy *CopyA =
348 sycl::detail::CGCopy *CopyB =
349 static_cast<sycl::detail::CGCopy *
>(Node->MCommandGroup.get());
350 return (CopyA->getSrc() == CopyB->getSrc()) &&
351 (CopyA->getDst() == CopyB->getDst());
354 assert(
false &&
"Unexpected command group type!");
366 std::vector<node_impl *> &Visited,
bool Verbose) {
368 if (std::find(Visited.begin(), Visited.end(),
this) != Visited.end())
371 Visited.push_back(
this);
373 printDotCG(Stream, Verbose);
375 auto NodeDep = Dep.lock();
376 Stream <<
" \"" << NodeDep.get() <<
"\" -> \"" <<
this <<
"\""
380 for (std::weak_ptr<node_impl> Succ :
MSuccessors) {
382 Succ.lock()->printDotRecursive(Stream, Visited, Verbose);
392 static_cast<sycl::detail::CGExecKernel *
>(
MCommandGroup.get())->MArgs;
394 for (
auto &Arg : Args) {
395 if (Arg.MIndex != ArgIndex) {
398 assert(Arg.MType == sycl::detail::kernel_param_kind_t::kind_accessor);
401 if (
static_cast<sycl::detail::SYCLMemObjT *
>(NewAccImpl->MSYCLMemObj)
402 ->needsWriteBack()) {
405 "Accessors to buffers which have write_back enabled "
406 "are not allowed to be used in command graphs.");
413 if (NewReq->MAccessMode != sycl::access_mode::read) {
415 static_cast<sycl::detail::SYCLMemObjT *
>(NewReq->MSYCLMemObj);
416 SYCLMemObj->handleWriteAccessorCreation();
421 static_cast<sycl::detail::AccessorImplHost *
>(Arg.MPtr);
422 Acc.get() == OldAcc) {
429 static_cast<sycl::detail::AccessorImplHost *
>(Arg.MPtr);
434 Arg.MPtr = NewAccImpl.get();
442 static_cast<sycl::detail::CGExecKernel *
>(
MCommandGroup.get())->MArgs;
443 for (
auto &Arg : Args) {
444 if (Arg.MIndex != ArgIndex) {
447 assert(Arg.MSize ==
static_cast<int>(Size));
450 std::memcpy(Arg.MPtr, NewValue, Size);
455 template <
int Dimensions>
457 if (
MCGType != sycl::detail::CG::Kernel) {
460 "Cannot update execution range of nodes which are not kernel nodes");
464 "Cannot update node which was created with a "
465 "sycl::range with a sycl::nd_range");
469 static_cast<sycl::detail::CGExecKernel *
>(
MCommandGroup.get())
474 "Cannot update execution range of a node with an "
475 "execution range of different dimensions than what "
476 "the node was originall created with.");
479 NDRDesc.set(ExecutionRange);
483 if (
MCGType != sycl::detail::CG::Kernel) {
486 "Cannot update execution range of nodes which are not kernel nodes");
490 "Cannot update node which was created with a "
491 "sycl::nd_range with a sycl::range");
495 static_cast<sycl::detail::CGExecKernel *
>(
MCommandGroup.get())
500 "Cannot update execution range of a node with an "
501 "execution range of different dimensions than what "
502 "the node was originall created with.");
505 NDRDesc.set(ExecutionRange);
510 static_cast<sycl::detail::CGExecKernel *
>(
MCommandGroup.get());
512 static_cast<sycl::detail::CGExecKernel *
>(Other->MCommandGroup.get());
514 ExecCG->MArgs = OtherExecCG->MArgs;
515 ExecCG->MNDRDesc = OtherExecCG->MNDRDesc;
516 ExecCG->getAccStorage() = OtherExecCG->getAccStorage();
517 ExecCG->getRequirements() = OtherExecCG->getRequirements();
519 auto &OldArgStorage = OtherExecCG->getArgsStorage();
520 auto &NewArgStorage = ExecCG->getArgsStorage();
522 rebuildArgStorage(ExecCG->MArgs, OldArgStorage, NewArgStorage);
528 void rebuildArgStorage(std::vector<sycl::detail::ArgDesc> &Args,
529 const std::vector<std::vector<char>> &OldArgStorage,
530 std::vector<std::vector<char>> &NewArgStorage)
const {
532 NewArgStorage.clear();
536 for (
auto &Arg : Args) {
537 if (Arg.MType != sycl::detail::kernel_param_kind_t::kind_std_layout) {
541 for (
auto &ArgStorage : OldArgStorage) {
542 if (ArgStorage.data() != Arg.MPtr) {
545 NewArgStorage.emplace_back(Arg.MSize);
547 std::memcpy(NewArgStorage.back().data(), ArgStorage.data(), Arg.MSize);
549 Arg.MPtr = NewArgStorage.back().data();
557 static id_type getNextNodeID() {
568 void printDotCG(std::ostream &Stream,
bool Verbose) {
569 Stream <<
"\"" <<
this <<
"\" [style=bold, label=\"";
571 Stream <<
"ID = " <<
this <<
"\\n";
575 case sycl::detail::CG::CGTYPE::None:
576 Stream <<
"None \\n";
578 case sycl::detail::CG::CGTYPE::Kernel: {
579 Stream <<
"CGExecKernel \\n";
580 sycl::detail::CGExecKernel *Kernel =
581 static_cast<sycl::detail::CGExecKernel *
>(
MCommandGroup.get());
582 Stream <<
"NAME = " << Kernel->MKernelName <<
"\\n";
584 Stream <<
"ARGS = \\n";
585 for (
size_t i = 0; i < Kernel->MArgs.size(); i++) {
586 auto Arg = Kernel->MArgs[i];
587 std::string Type =
"Undefined";
588 if (Arg.MType == sycl::detail::kernel_param_kind_t::kind_accessor) {
590 }
else if (Arg.MType ==
591 sycl::detail::kernel_param_kind_t::kind_std_layout) {
593 }
else if (Arg.MType ==
594 sycl::detail::kernel_param_kind_t::kind_sampler) {
596 }
else if (Arg.MType ==
597 sycl::detail::kernel_param_kind_t::kind_pointer) {
600 kind_specialization_constants_buffer) {
601 Type =
"Specialization Constants Buffer";
602 }
else if (Arg.MType ==
603 sycl::detail::kernel_param_kind_t::kind_stream) {
605 }
else if (Arg.MType ==
606 sycl::detail::kernel_param_kind_t::kind_invalid) {
609 Stream << i <<
") Type: " << Type <<
" Ptr: " << Arg.MPtr <<
"\\n";
614 case sycl::detail::CG::CGTYPE::CopyAccToPtr:
615 Stream <<
"CGCopy Device-to-Host \\n";
617 sycl::detail::CGCopy *Copy =
619 Stream <<
"Src: " << Copy->getSrc() <<
" Dst: " << Copy->getDst()
623 case sycl::detail::CG::CGTYPE::CopyPtrToAcc:
624 Stream <<
"CGCopy Host-to-Device \\n";
626 sycl::detail::CGCopy *Copy =
628 Stream <<
"Src: " << Copy->getSrc() <<
" Dst: " << Copy->getDst()
632 case sycl::detail::CG::CGTYPE::CopyAccToAcc:
633 Stream <<
"CGCopy Device-to-Device \\n";
635 sycl::detail::CGCopy *Copy =
637 Stream <<
"Src: " << Copy->getSrc() <<
" Dst: " << Copy->getDst()
641 case sycl::detail::CG::CGTYPE::Fill:
642 Stream <<
"CGFill \\n";
644 sycl::detail::CGFill *Fill =
646 Stream <<
"Ptr: " << Fill->MPtr <<
"\\n";
649 case sycl::detail::CG::CGTYPE::UpdateHost:
650 Stream <<
"CGCUpdateHost \\n";
652 sycl::detail::CGUpdateHost *Host =
653 static_cast<sycl::detail::CGUpdateHost *
>(
MCommandGroup.get());
654 Stream <<
"Ptr: " << Host->getReqToUpdate() <<
"\\n";
657 case sycl::detail::CG::CGTYPE::CopyUSM:
658 Stream <<
"CGCopyUSM \\n";
660 sycl::detail::CGCopyUSM *CopyUSM =
662 Stream <<
"Src: " << CopyUSM->getSrc() <<
" Dst: " << CopyUSM->getDst()
663 <<
" Length: " << CopyUSM->getLength() <<
"\\n";
666 case sycl::detail::CG::CGTYPE::FillUSM:
667 Stream <<
"CGFillUSM \\n";
669 sycl::detail::CGFillUSM *FillUSM =
671 Stream <<
"Dst: " << FillUSM->getDst()
672 <<
" Length: " << FillUSM->getLength()
673 <<
" Pattern: " << FillUSM->getFill() <<
"\\n";
676 case sycl::detail::CG::CGTYPE::PrefetchUSM:
677 Stream <<
"CGPrefetchUSM \\n";
679 sycl::detail::CGPrefetchUSM *Prefetch =
680 static_cast<sycl::detail::CGPrefetchUSM *
>(
MCommandGroup.get());
681 Stream <<
"Dst: " << Prefetch->getDst()
682 <<
" Length: " << Prefetch->getLength() <<
"\\n";
685 case sycl::detail::CG::CGTYPE::AdviseUSM:
686 Stream <<
"CGAdviseUSM \\n";
688 sycl::detail::CGAdviseUSM *AdviseUSM =
689 static_cast<sycl::detail::CGAdviseUSM *
>(
MCommandGroup.get());
690 Stream <<
"Dst: " << AdviseUSM->getDst()
691 <<
" Length: " << AdviseUSM->getLength() <<
"\\n";
694 case sycl::detail::CG::CGTYPE::CodeplayHostTask:
695 Stream <<
"CGHostTask \\n";
697 case sycl::detail::CG::CGTYPE::Barrier:
698 Stream <<
"CGBarrier \\n";
700 case sycl::detail::CG::CGTYPE::Copy2DUSM:
701 Stream <<
"CGCopy2DUSM \\n";
703 sycl::detail::CGCopy2DUSM *Copy2DUSM =
704 static_cast<sycl::detail::CGCopy2DUSM *
>(
MCommandGroup.get());
705 Stream <<
"Src:" << Copy2DUSM->getSrc()
706 <<
" Dst: " << Copy2DUSM->getDst() <<
"\\n";
709 case sycl::detail::CG::CGTYPE::Fill2DUSM:
710 Stream <<
"CGFill2DUSM \\n";
712 sycl::detail::CGFill2DUSM *Fill2DUSM =
713 static_cast<sycl::detail::CGFill2DUSM *
>(
MCommandGroup.get());
714 Stream <<
"Dst: " << Fill2DUSM->getDst() <<
"\\n";
717 case sycl::detail::CG::CGTYPE::Memset2DUSM:
718 Stream <<
"CGMemset2DUSM \\n";
720 sycl::detail::CGMemset2DUSM *Memset2DUSM =
721 static_cast<sycl::detail::CGMemset2DUSM *
>(
MCommandGroup.get());
722 Stream <<
"Dst: " << Memset2DUSM->getDst() <<
"\\n";
725 case sycl::detail::CG::CGTYPE::ReadWriteHostPipe:
726 Stream <<
"CGReadWriteHostPipe \\n";
728 case sycl::detail::CG::CGTYPE::CopyToDeviceGlobal:
729 Stream <<
"CGCopyToDeviceGlobal \\n";
731 sycl::detail::CGCopyToDeviceGlobal *CopyToDeviceGlobal =
732 static_cast<sycl::detail::CGCopyToDeviceGlobal *
>(
734 Stream <<
"Src: " << CopyToDeviceGlobal->getSrc()
735 <<
" Dst: " << CopyToDeviceGlobal->getDeviceGlobalPtr() <<
"\\n";
738 case sycl::detail::CG::CGTYPE::CopyFromDeviceGlobal:
739 Stream <<
"CGCopyFromDeviceGlobal \\n";
741 sycl::detail::CGCopyFromDeviceGlobal *CopyFromDeviceGlobal =
742 static_cast<sycl::detail::CGCopyFromDeviceGlobal *
>(
744 Stream <<
"Src: " << CopyFromDeviceGlobal->getDeviceGlobalPtr()
745 <<
" Dst: " << CopyFromDeviceGlobal->getDest() <<
"\\n";
748 case sycl::detail::CG::CGTYPE::ExecCommandBuffer:
749 Stream <<
"CGExecCommandBuffer \\n";
752 Stream <<
"Other \\n";
755 Stream <<
"\"];" << std::endl;
762 template <
typename CGT> std::unique_ptr<CGT> createCGCopy()
const {
763 return std::make_unique<CGT>(*
static_cast<CGT *
>(
MCommandGroup.get()));
773 std::set<std::weak_ptr<node_impl>, std::owner_less<std::weak_ptr<node_impl>>>
778 std::unordered_map<sycl::device, sycl::detail::pi::PiExtCommandBuffer>
785 return (
MRoots.size() && ((*
MRoots.begin()).lock()->MCGType ==
786 sycl::detail::CG::CGTYPE::CodeplayHostTask));
796 using ReadLock = std::shared_lock<std::shared_mutex>;
808 : MContext(SyclContext), MDevice(SyclDevice), MRecordingQueues(),
809 MEventsMap(), MInorderQueueMap() {
811 MSkipCycleChecks =
true;
814 .has_property<property::graph::assume_buffer_outlives_graph>()) {
815 MAllowBuffers =
true;
818 if (!SyclDevice.
has(aspect::ext_oneapi_limited_graph) &&
819 !SyclDevice.
has(aspect::ext_oneapi_graph)) {
820 std::stringstream Stream;
822 std::string BackendString = Stream.str();
825 BackendString +
" backend is not supported by SYCL Graph extension.");
833 void removeRoot(
const std::shared_ptr<node_impl> &Root);
840 std::shared_ptr<node_impl>
841 add(
node_type NodeType, std::unique_ptr<sycl::detail::CG> CommandGroup,
842 const std::vector<std::shared_ptr<node_impl>> &Dep = {});
850 std::shared_ptr<node_impl>
851 add(
const std::shared_ptr<graph_impl> &Impl,
852 std::function<
void(
handler &)> CGF,
853 const std::vector<sycl::detail::ArgDesc> &Args,
854 const std::vector<std::shared_ptr<node_impl>> &Dep = {});
860 std::shared_ptr<node_impl>
861 add(
const std::shared_ptr<graph_impl> &Impl,
862 const std::vector<std::shared_ptr<node_impl>> &Dep = {});
868 std::shared_ptr<node_impl>
869 add(
const std::shared_ptr<graph_impl> &Impl,
870 const std::vector<sycl::detail::EventImplPtr> Events);
876 addQueue(
const std::shared_ptr<sycl::detail::queue_impl> &RecordingQueue) {
877 MRecordingQueues.insert(RecordingQueue);
884 removeQueue(
const std::shared_ptr<sycl::detail::queue_impl> &RecordingQueue) {
885 MRecordingQueues.erase(RecordingQueue);
900 std::shared_ptr<sycl::detail::event_impl> EventImpl,
901 std::shared_ptr<node_impl> NodeImpl) {
902 if (EventImpl && !(EventImpl->getCommandGraph()))
903 EventImpl->setCommandGraph(GraphImpl);
904 MEventsMap[EventImpl] = NodeImpl;
910 std::shared_ptr<sycl::detail::event_impl>
913 if (
auto EventImpl = std::find_if(
914 MEventsMap.begin(), MEventsMap.end(),
915 [NodeImpl](
auto &it) { return it.second == NodeImpl; });
916 EventImpl != MEventsMap.end()) {
917 return EventImpl->first;
922 "No event has been recorded for the specified graph node");
929 std::shared_ptr<node_impl>
933 if (
auto NodeFound = MEventsMap.find(EventImpl);
934 NodeFound != std::end(MEventsMap)) {
935 return NodeFound->second;
940 "No node in this graph is associated with this event");
952 std::set<std::weak_ptr<node_impl>, std::owner_less<std::weak_ptr<node_impl>>>
965 std::shared_ptr<node_impl>
967 std::weak_ptr<sycl::detail::queue_impl> QueueWeakPtr(Queue);
968 if (0 == MInorderQueueMap.count(QueueWeakPtr)) {
971 return MInorderQueueMap[QueueWeakPtr];
978 std::shared_ptr<node_impl> Node) {
979 std::weak_ptr<sycl::detail::queue_impl> QueueWeakPtr(Queue);
980 MInorderQueueMap[QueueWeakPtr] = Node;
989 std::vector<node_impl *> VisitedNodes;
991 std::fstream Stream(FilePath, std::ios::out);
992 Stream <<
"digraph dot {" << std::endl;
994 for (std::weak_ptr<node_impl> Node :
MRoots)
995 Node.lock()->printDotRecursive(Stream, VisitedNodes, Verbose);
997 Stream <<
"}" << std::endl;
1007 void makeEdge(std::shared_ptr<node_impl> Src,
1008 std::shared_ptr<node_impl> Dest);
1014 if (MRecordingQueues.size()) {
1017 " cannot be called when a queue "
1018 "is currently recording commands to a graph.");
1027 const std::shared_ptr<node_impl> &NodeB) {
1028 size_t FoundCnt = 0;
1029 for (std::weak_ptr<node_impl> &SuccA : NodeA->MSuccessors) {
1030 for (std::weak_ptr<node_impl> &SuccB : NodeB->MSuccessors) {
1031 if (NodeA->isSimilar(NodeB) &&
1038 if (FoundCnt != NodeA->MSuccessors.size()) {
1057 bool DebugPrint =
false)
const {
1058 if (
this == Graph.get())
1061 if (MContext != Graph->MContext) {
1064 "MContext are not the same.");
1069 if (MDevice != Graph->MDevice) {
1072 "MDevice are not the same.");
1077 if (MEventsMap.size() != Graph->MEventsMap.size()) {
1080 "MEventsMap sizes are not the same.");
1085 if (MInorderQueueMap.size() != Graph->MInorderQueueMap.size()) {
1088 "MInorderQueueMap sizes are not the same.");
1093 if (
MRoots.size() != Graph->MRoots.size()) {
1096 "MRoots sizes are not the same.");
1101 size_t RootsFound = 0;
1102 for (std::weak_ptr<node_impl> NodeA :
MRoots) {
1103 for (std::weak_ptr<node_impl> NodeB : Graph->MRoots) {
1104 auto NodeALocked = NodeA.lock();
1105 auto NodeBLocked = NodeB.lock();
1107 if (NodeALocked->isSimilar(NodeBLocked)) {
1116 if (RootsFound !=
MRoots.size()) {
1119 "Root Nodes do NOT match.");
1139 std::vector<sycl::detail::EventImplPtr>
1141 std::vector<sycl::detail::EventImplPtr> Events;
1142 for (
auto It = MExtraDependencies.begin();
1143 It != MExtraDependencies.end();) {
1144 if ((*It)->MCGType == sycl::detail::CG::Barrier) {
1146 It = MExtraDependencies.erase(It);
1162 searchDepthFirst(std::function<
bool(std::shared_ptr<node_impl> &,
1163 std::deque<std::shared_ptr<node_impl>> &)>
1170 bool checkForCycles();
1174 void addRoot(
const std::shared_ptr<node_impl> &Root);
1180 std::shared_ptr<node_impl>
1181 addNodesToExits(
const std::shared_ptr<graph_impl> &Impl,
1182 const std::list<std::shared_ptr<node_impl>> &NodeList);
1188 void addDepsToNode(std::shared_ptr<node_impl> Node,
1189 const std::vector<std::shared_ptr<node_impl>> &Deps) {
1190 if (!Deps.empty()) {
1191 for (
auto &N : Deps) {
1192 N->registerSuccessor(Node, N);
1196 this->addRoot(Node);
1206 std::set<std::weak_ptr<sycl::detail::queue_impl>,
1207 std::owner_less<std::weak_ptr<sycl::detail::queue_impl>>>
1210 std::unordered_map<std::shared_ptr<sycl::detail::event_impl>,
1211 std::shared_ptr<node_impl>>
1216 std::map<std::weak_ptr<sycl::detail::queue_impl>, std::shared_ptr<node_impl>,
1217 std::owner_less<std::weak_ptr<sycl::detail::queue_impl>>>
1221 bool MSkipCycleChecks =
false;
1223 std::set<sycl::detail::SYCLMemObjT *> MMemObjs;
1227 bool MAllowBuffers =
false;
1233 std::list<std::shared_ptr<node_impl>> MExtraDependencies;
1254 const std::shared_ptr<graph_impl> &GraphImpl,
1282 std::shared_ptr<partition> &Partition);
1300 const std::shared_ptr<graph_impl> &
getGraphImpl()
const {
return MGraphImpl; }
1314 for (
auto Event : MExecutionEvents) {
1315 if (!Event->isCompleted()) {
1324 return MRequirements;
1327 void update(std::shared_ptr<graph_impl> GraphImpl);
1328 void update(std::shared_ptr<node_impl> Node);
1329 void update(
const std::vector<std::shared_ptr<node_impl>> Nodes);
1331 void updateImpl(std::shared_ptr<node_impl> NodeImpl);
1344 std::shared_ptr<node_impl> Node);
1356 std::shared_ptr<node_impl> Node);
1364 void findRealDeps(std::vector<sycl::detail::pi::PiExtSyncPoint> &Deps,
1365 std::shared_ptr<node_impl> CurrentNode,
1366 int ReferencePartitionNum);
1371 void duplicateNodes();
1377 void printGraphAsDot(
const std::string FilePath,
bool Verbose)
const {
1379 std::vector<node_impl *> VisitedNodes;
1381 std::fstream Stream(FilePath, std::ios::out);
1382 Stream <<
"digraph dot {" << std::endl;
1384 std::vector<std::shared_ptr<node_impl>> Roots;
1385 for (
auto &Node : MNodeStorage) {
1386 if (Node->MPredecessors.size() == 0) {
1387 Roots.push_back(Node);
1391 for (std::shared_ptr<node_impl> Node : Roots)
1392 Node->printDotRecursive(Stream, VisitedNodes, Verbose);
1394 Stream <<
"}" << std::endl;
1400 std::list<std::shared_ptr<node_impl>> MSchedule;
1407 std::shared_ptr<graph_impl> MGraphImpl;
1410 std::unordered_map<std::shared_ptr<node_impl>,
1415 std::unordered_map<std::shared_ptr<node_impl>,
int> MPartitionNodes;
1422 std::vector<sycl::detail::AccessorImplHost *> MRequirements;
1425 std::vector<sycl::detail::AccessorImplPtr> MAccessors;
1427 std::vector<sycl::detail::EventImplPtr> MExecutionEvents;
1429 std::vector<std::shared_ptr<partition>> MPartitions;
1431 std::vector<std::shared_ptr<node_impl>> MNodeStorage;
1433 std::unordered_map<std::shared_ptr<node_impl>,
1441 std::multimap<node_impl::id_type, std::shared_ptr<node_impl>> MIDCache;
1447 size_t ParamSize,
const void *Data)
1457 MNodes.emplace_back(NodeImpl, ArgIndex);
1468 for (
auto &[NodeWeak, ArgIndex] :
MNodes) {
1469 auto NodeShared = NodeWeak.lock();
1471 NodeShared->updateArgValue(ArgIndex, NewValue, Size);
1482 for (
auto &[NodeWeak, ArgIndex] :
MNodes) {
1483 auto NodeShared = NodeWeak.lock();
1486 NodeShared->updateAccessor(ArgIndex, Acc);
1490 sizeof(sycl::detail::AccessorBaseHost));
1494 std::vector<std::pair<std::weak_ptr<node_impl>,
int>>
MNodes;
The context class represents a SYCL context on which kernel functions may be executed.
CGTYPE
Type of the command group.
The SYCL device class encapsulates a single SYCL device on which kernels may be executed.
backend get_backend() const noexcept
Returns the backend associated with this device.
bool has(aspect Aspect) const __SYCL_WARN_IMAGE_ASPECT(Aspect)
Indicates if the SYCL device has the given feature.
An event object can be used to synchronize memory transfers, enqueues of kernels and signaling barriers.
std::vector< std::byte > MValueStorage
dynamic_parameter_impl(std::shared_ptr< graph_impl > GraphImpl, size_t ParamSize, const void *Data)
std::vector< std::pair< std::weak_ptr< node_impl >, int > > MNodes
void * getValue()
Get a pointer to the internal value of this dynamic parameter.
std::shared_ptr< graph_impl > MGraph
void registerNode(std::shared_ptr< node_impl > NodeImpl, int ArgIndex)
Register a node with this dynamic parameter.
void updateValue(const void *NewValue, size_t Size)
Update the internal value of this dynamic parameter as well as the value of this parameter in all registered nodes.
void updateAccessor(const sycl::detail::AccessorBaseHost *Acc)
Update the internal value of this dynamic parameter as well as the value of this parameter in all registered nodes.
Class representing the implementation of command_graph<executable>.
exec_graph_impl(sycl::context Context, const std::shared_ptr< graph_impl > &GraphImpl, const property_list &PropList)
Constructor.
void createCommandBuffers(sycl::device Device, std::shared_ptr< partition > &Partition)
Turns the internal graph representation into UR command-buffers for a device.
std::unique_lock< std::shared_mutex > WriteLock
std::vector< sycl::detail::AccessorImplHost * > getRequirements() const
Returns a list of all the accessor requirements for this graph.
const std::vector< std::shared_ptr< partition > > & getPartitions() const
Query the vector of the partitions composing the exec_graph.
std::shared_lock< std::shared_mutex > ReadLock
void update(std::shared_ptr< graph_impl > GraphImpl)
sycl::context getContext() const
Query for the context tied to this graph.
~exec_graph_impl()
Destructor.
void makePartitions()
Partition the graph nodes and put the partition in MPartitions.
void updateImpl(std::shared_ptr< node_impl > NodeImpl)
sycl::device getDevice() const
Query for the device tied to this graph.
const std::list< std::shared_ptr< node_impl > > & getSchedule() const
Query the scheduling of node execution.
const std::shared_ptr< graph_impl > & getGraphImpl() const
Query the graph_impl.
std::shared_mutex MMutex
Protects all the fields that can be changed by class' methods.
bool previousSubmissionCompleted() const
Checks if the previous submissions of this graph have been completed. This function checks the status of the events associated with the previous graph submissions.
sycl::event enqueue(const std::shared_ptr< sycl::detail::queue_impl > &Queue, sycl::detail::CG::StorageInitHelper CGData)
Called by handler::ext_oneapi_command_graph() to schedule graph for execution.
Implementation details of command_graph<modifiable>.
std::vector< std::shared_ptr< node_impl > > MNodeStorage
Storage for all nodes contained within a graph.
void addQueue(const std::shared_ptr< sycl::detail::queue_impl > &RecordingQueue)
Add a queue to the set of queues which are currently recording to this graph.
void removeRoot(const std::shared_ptr< node_impl > &Root)
Remove node from list of root nodes.
std::unique_lock< std::shared_mutex > WriteLock
void makeEdge(std::shared_ptr< node_impl > Src, std::shared_ptr< node_impl > Dest)
Make an edge between two nodes in the graph.
std::shared_mutex MMutex
Protects all the fields that can be changed by class' methods.
void printGraphAsDot(const std::string FilePath, bool Verbose) const
Prints the contents of the graph to a text file in DOT format.
std::vector< sycl::detail::EventImplPtr > getExitNodesEvents()
Traverse the graph recursively to get the events associated with the output nodes of this graph.
void throwIfGraphRecordingQueue(const std::string ExceptionMsg) const
Throws an invalid exception if this function is called while a queue is recording commands to the graph.
std::shared_ptr< sycl::detail::event_impl > getEventForNode(std::shared_ptr< node_impl > NodeImpl) const
Find the sycl event associated with a node.
void setLastInorderNode(std::shared_ptr< sycl::detail::queue_impl > Queue, std::shared_ptr< node_impl > Node)
Track the last node added to this graph from an in-order queue.
std::set< std::weak_ptr< node_impl >, std::owner_less< std::weak_ptr< node_impl > > > MRoots
List of root nodes.
bool hasSimilarStructure(std::shared_ptr< detail::graph_impl > Graph, bool DebugPrint=false) const
Checks if the graph_impl of Graph has a similar structure to the graph_impl of the caller.
std::vector< sycl::detail::EventImplPtr > removeBarriersFromExtraDependencies()
Removes all Barrier nodes from the list of extra dependencies MExtraDependencies.
bool clearQueues()
Removes all queues which are recording to this graph and sets all cleared queues back to the executing state.
sycl::device getDevice() const
Query for the device tied to this graph.
std::shared_ptr< node_impl > getNodeForEvent(std::shared_ptr< sycl::detail::event_impl > EventImpl)
Find the node associated with a SYCL event.
static bool checkNodeRecursive(const std::shared_ptr< node_impl > &NodeA, const std::shared_ptr< node_impl > &NodeB)
Recursively check successors of NodeA and NodeB to check they are similar.
graph_impl(const sycl::context &SyclContext, const sycl::device &SyclDevice, const sycl::property_list &PropList={})
Constructor.
sycl::context getContext() const
Query for the context tied to this graph.
std::shared_ptr< node_impl > add(node_type NodeType, std::unique_ptr< sycl::detail::CG > CommandGroup, const std::vector< std::shared_ptr< node_impl >> &Dep={})
Create a kernel node in the graph.
void addEventForNode(std::shared_ptr< graph_impl > GraphImpl, std::shared_ptr< sycl::detail::event_impl > EventImpl, std::shared_ptr< node_impl > NodeImpl)
Associate a sycl event with a node in the graph.
std::shared_lock< std::shared_mutex > ReadLock
size_t getNumberOfNodes() const
Returns the number of nodes in the Graph.
std::shared_ptr< node_impl > getLastInorderNode(std::shared_ptr< sycl::detail::queue_impl > Queue)
Find the last node added to this graph from an in-order queue.
void removeQueue(const std::shared_ptr< sycl::detail::queue_impl > &RecordingQueue)
Remove a queue from the set of queues which are currently recording to this graph.
Implementation of node class from SYCL_EXT_ONEAPI_GRAPH.
void updateAccessor(int ArgIndex, const sycl::detail::AccessorBaseHost *Acc)
Update the value of an accessor inside this node.
void registerSuccessor(const std::shared_ptr< node_impl > &Node, const std::shared_ptr< node_impl > &Prev)
Add successor to the node.
std::vector< std::weak_ptr< node_impl > > MPredecessors
List of predecessors to this node.
node_impl(node_impl &Other)
Construct a node from another node.
std::vector< std::weak_ptr< node_impl > > MSuccessors
List of successors to this node.
void updateFromOtherNode(const std::shared_ptr< node_impl > &Other)
sycl::detail::CG::CGTYPE MCGType
Type of the command-group for the node.
std::shared_ptr< exec_graph_impl > MSubGraphImpl
Stores the executable graph impl associated with this node if it is a subgraph node.
node_impl(node_type NodeType, std::unique_ptr< sycl::detail::CG > &&CommandGroup)
Construct a node representing a command-group.
bool MNDRangeUsed
Track whether an ND-Range was used for kernel nodes.
bool hasRequirementDependency(sycl::detail::AccessorImplHost *IncomingReq)
Checks if this node should be a dependency of another node based on accessor requirements.
int MPartitionNum
Partition number needed to assign a Node to a partition.
void updateRange(range< Dimensions > ExecutionRange)
void updateNDRange(nd_range< Dimensions > ExecutionRange)
bool isSimilar(const std::shared_ptr< node_impl > &Node, bool CompareContentOnly=false) const
Tests if the caller is similar to Node; this is only used for testing.
void printDotRecursive(std::fstream &Stream, std::vector< node_impl * > &Visited, bool Verbose)
Recursive Depth first traversal of linked nodes.
node_impl()
Construct an empty node.
node_type MNodeType
User facing type of the node.
std::unique_ptr< sycl::detail::CG > getCGCopy() const
Get a deep copy of this node's command group.
void updateArgValue(int ArgIndex, const void *NewValue, size_t Size)
bool isEmpty() const
Query if this is an empty node.
void registerPredecessor(const std::shared_ptr< node_impl > &Node)
Add predecessor to the node.
bool MVisited
Used for tracking visited status during cycle checks.
std::unique_ptr< sycl::detail::CG > MCommandGroup
Command group object which stores all args etc needed to enqueue the node.
id_type MID
Unique identifier for this node.
node_impl & operator=(node_impl &Other)
Copy-assignment operator.
void schedule()
Add nodes to MSchedule.
std::unordered_map< sycl::device, sycl::detail::pi::PiExtCommandBuffer > MPiCommandBuffers
Map of devices to command buffers.
std::set< std::weak_ptr< node_impl >, std::owner_less< std::weak_ptr< node_impl > > > MRoots
List of root nodes.
std::vector< std::shared_ptr< partition > > MPredecessors
List of predecessors to this partition.
std::list< std::shared_ptr< node_impl > > MSchedule
Execution schedule of nodes in the graph.
Property passed to command_graph constructor to disable checking for cycles.
Command group handler class.
Defines the iteration domain of both the work-groups and the overall dispatch.
Objects of the property_list class are containers for the SYCL properties.
::pi_ext_sync_point PiExtSyncPoint
::pi_ext_command_buffer_command PiExtCommandBufferCommand
decltype(Obj::impl) getSyclObjImpl(const Obj &SyclObject)
AccessorImplHost Requirement
std::shared_ptr< device_impl > DeviceImplPtr
node_type getNodeTypeFromCG(sycl::detail::CG::CGTYPE CGType)
class __SYCL_EBO __SYCL_SPECIAL_CLASS Dimensions
constexpr mode_tag_t< access_mode::read_write > read_write
std::error_code make_error_code(sycl::errc E) noexcept
Constructs an error code using e and sycl_category()
std::vector< std::vector< char > > MArgsStorage
Storage for standard layout arguments.