30 #include <shared_mutex>
33 inline namespace _V1 {
41 namespace experimental {
45 using sycl::detail::CG;
48 case sycl::detail::CGType::None:
50 case sycl::detail::CGType::Kernel:
52 case sycl::detail::CGType::CopyAccToPtr:
53 case sycl::detail::CGType::CopyPtrToAcc:
54 case sycl::detail::CGType::CopyAccToAcc:
55 case sycl::detail::CGType::CopyUSM:
57 case sycl::detail::CGType::Memset2DUSM:
59 case sycl::detail::CGType::Fill:
60 case sycl::detail::CGType::FillUSM:
62 case sycl::detail::CGType::PrefetchUSM:
64 case sycl::detail::CGType::AdviseUSM:
66 case sycl::detail::CGType::Barrier:
67 case sycl::detail::CGType::BarrierWaitlist:
69 case sycl::detail::CGType::CodeplayHostTask:
71 case sycl::detail::CGType::ExecCommandBuffer:
74 assert(
false &&
"Invalid Graph Node Type");
120 const std::shared_ptr<node_impl> &Prev) {
122 [Node](
const std::weak_ptr<node_impl> &Ptr) {
123 return Ptr.lock() == Node;
128 Node->registerPredecessor(Prev);
135 [&Node](
const std::weak_ptr<node_impl> &Ptr) {
136 return Ptr.lock() == Node;
151 std::unique_ptr<sycl::detail::CG> &&CommandGroup)
156 static_cast<sycl::detail::CGExecCommandBuffer *
>(
MCommandGroup.get())
171 if (
this != &Other) {
199 case access_mode::discard_read_write:
200 case access_mode::discard_write:
204 for (sycl::detail::AccessorImplHost *CurrentReq :
206 if (IncomingReq->MSYCLMemObj == CurrentReq->MSYCLMemObj) {
224 return ((
MCGType == sycl::detail::CGType::None) ||
225 (
MCGType == sycl::detail::CGType::Barrier));
232 case sycl::detail::CGType::Kernel: {
233 auto CGCopy = createCGCopy<sycl::detail::CGExecKernel>();
234 rebuildArgStorage(CGCopy->MArgs,
MCommandGroup->getArgsStorage(),
235 CGCopy->getArgsStorage());
236 return std::move(CGCopy);
238 case sycl::detail::CGType::CopyAccToPtr:
239 case sycl::detail::CGType::CopyPtrToAcc:
240 case sycl::detail::CGType::CopyAccToAcc:
241 return createCGCopy<sycl::detail::CGCopy>();
242 case sycl::detail::CGType::Fill:
243 return createCGCopy<sycl::detail::CGFill>();
244 case sycl::detail::CGType::UpdateHost:
245 return createCGCopy<sycl::detail::CGUpdateHost>();
246 case sycl::detail::CGType::CopyUSM:
247 return createCGCopy<sycl::detail::CGCopyUSM>();
248 case sycl::detail::CGType::FillUSM:
249 return createCGCopy<sycl::detail::CGFillUSM>();
250 case sycl::detail::CGType::PrefetchUSM:
251 return createCGCopy<sycl::detail::CGPrefetchUSM>();
252 case sycl::detail::CGType::AdviseUSM:
253 return createCGCopy<sycl::detail::CGAdviseUSM>();
254 case sycl::detail::CGType::Copy2DUSM:
255 return createCGCopy<sycl::detail::CGCopy2DUSM>();
256 case sycl::detail::CGType::Fill2DUSM:
257 return createCGCopy<sycl::detail::CGFill2DUSM>();
258 case sycl::detail::CGType::Memset2DUSM:
259 return createCGCopy<sycl::detail::CGMemset2DUSM>();
260 case sycl::detail::CGType::EnqueueNativeCommand:
261 case sycl::detail::CGType::CodeplayHostTask: {
266 auto CommandGroupPtr =
267 static_cast<sycl::detail::CGHostTask *
>(
MCommandGroup.get());
268 sycl::detail::HostTask HostTask = *CommandGroupPtr->MHostTask.get();
269 auto HostTaskSPtr = std::make_shared<sycl::detail::HostTask>(HostTask);
272 CommandGroupPtr->getArgsStorage(), CommandGroupPtr->getAccStorage(),
273 CommandGroupPtr->getSharedPtrStorage(),
274 CommandGroupPtr->getRequirements(), CommandGroupPtr->getEvents());
276 std::vector<sycl::detail::ArgDesc> NewArgs = CommandGroupPtr->MArgs;
278 rebuildArgStorage(NewArgs, CommandGroupPtr->getArgsStorage(),
281 sycl::detail::code_location Loc(CommandGroupPtr->MFileName.data(),
282 CommandGroupPtr->MFunctionName.data(),
283 CommandGroupPtr->MLine,
284 CommandGroupPtr->MColumn);
286 return std::make_unique<sycl::detail::CGHostTask>(
287 sycl::detail::CGHostTask(
288 std::move(HostTaskSPtr), CommandGroupPtr->MQueue,
289 CommandGroupPtr->MContext, std::move(NewArgs), std::move(Data),
290 CommandGroupPtr->getType(), Loc));
292 case sycl::detail::CGType::Barrier:
293 case sycl::detail::CGType::BarrierWaitlist:
296 return createCGCopy<sycl::detail::CG>();
297 case sycl::detail::CGType::CopyToDeviceGlobal:
298 return createCGCopy<sycl::detail::CGCopyToDeviceGlobal>();
299 case sycl::detail::CGType::CopyFromDeviceGlobal:
300 return createCGCopy<sycl::detail::CGCopyFromDeviceGlobal>();
301 case sycl::detail::CGType::ReadWriteHostPipe:
302 return createCGCopy<sycl::detail::CGReadWriteHostPipe>();
303 case sycl::detail::CGType::CopyImage:
304 return createCGCopy<sycl::detail::CGCopyImage>();
305 case sycl::detail::CGType::SemaphoreSignal:
306 return createCGCopy<sycl::detail::CGSemaphoreSignal>();
307 case sycl::detail::CGType::SemaphoreWait:
308 return createCGCopy<sycl::detail::CGSemaphoreWait>();
309 case sycl::detail::CGType::ProfilingTag:
310 return createCGCopy<sycl::detail::CGProfilingTag>();
311 case sycl::detail::CGType::ExecCommandBuffer:
312 return createCGCopy<sycl::detail::CGExecCommandBuffer>();
313 case sycl::detail::CGType::None:
325 bool CompareContentOnly =
false)
const {
326 if (!CompareContentOnly) {
327 if (
MSuccessors.size() != Node->MSuccessors.size())
337 case sycl::detail::CGType::Kernel: {
338 sycl::detail::CGExecKernel *ExecKernelA =
339 static_cast<sycl::detail::CGExecKernel *
>(
MCommandGroup.get());
340 sycl::detail::CGExecKernel *ExecKernelB =
341 static_cast<sycl::detail::CGExecKernel *
>(Node->MCommandGroup.get());
342 return ExecKernelA->MKernelName.compare(ExecKernelB->MKernelName) == 0;
344 case sycl::detail::CGType::CopyUSM: {
345 sycl::detail::CGCopyUSM *CopyA =
347 sycl::detail::CGCopyUSM *CopyB =
348 static_cast<sycl::detail::CGCopyUSM *
>(Node->MCommandGroup.get());
349 return (CopyA->getSrc() == CopyB->getSrc()) &&
350 (CopyA->getDst() == CopyB->getDst()) &&
351 (CopyA->getLength() == CopyB->getLength());
353 case sycl::detail::CGType::CopyAccToAcc:
354 case sycl::detail::CGType::CopyAccToPtr:
355 case sycl::detail::CGType::CopyPtrToAcc: {
356 sycl::detail::CGCopy *CopyA =
358 sycl::detail::CGCopy *CopyB =
359 static_cast<sycl::detail::CGCopy *
>(Node->MCommandGroup.get());
360 return (CopyA->getSrc() == CopyB->getSrc()) &&
361 (CopyA->getDst() == CopyB->getDst());
364 assert(
false &&
"Unexpected command group type!");
376 std::vector<node_impl *> &Visited,
bool Verbose) {
378 if (std::find(Visited.begin(), Visited.end(),
this) != Visited.end())
381 Visited.push_back(
this);
383 printDotCG(Stream, Verbose);
385 auto NodeDep = Dep.lock();
386 Stream <<
" \"" << NodeDep.get() <<
"\" -> \"" <<
this <<
"\""
390 for (std::weak_ptr<node_impl> Succ :
MSuccessors) {
392 Succ.lock()->printDotRecursive(Stream, Visited, Verbose);
399 if ((
MCGType != sycl::detail::CGType::CopyAccToAcc) &&
400 (
MCGType != sycl::detail::CGType::CopyAccToPtr) &&
401 (
MCGType != sycl::detail::CGType::CopyPtrToAcc)) {
405 auto Copy =
static_cast<sycl::detail::CGCopy *
>(
MCommandGroup.get());
408 return (ReqSrc->MDims > 1) || (ReqDst->MDims > 1);
417 static_cast<sycl::detail::CGExecKernel *
>(
MCommandGroup.get())->MArgs;
419 for (
auto &Arg : Args) {
420 if (Arg.MIndex != ArgIndex) {
423 assert(Arg.MType == sycl::detail::kernel_param_kind_t::kind_accessor);
426 if (
static_cast<sycl::detail::SYCLMemObjT *
>(NewAccImpl->MSYCLMemObj)
427 ->needsWriteBack()) {
430 "Accessors to buffers which have write_back enabled "
431 "are not allowed to be used in command graphs.");
438 if (NewReq->MAccessMode != sycl::access_mode::read) {
440 static_cast<sycl::detail::SYCLMemObjT *
>(NewReq->MSYCLMemObj);
441 SYCLMemObj->handleWriteAccessorCreation();
446 static_cast<sycl::detail::AccessorImplHost *
>(Arg.MPtr);
447 Acc.get() == OldAcc) {
454 static_cast<sycl::detail::AccessorImplHost *
>(Arg.MPtr);
459 Arg.MPtr = NewAccImpl.get();
467 static_cast<sycl::detail::CGExecKernel *
>(
MCommandGroup.get())->MArgs;
468 for (
auto &Arg : Args) {
469 if (Arg.MIndex != ArgIndex) {
472 assert(Arg.MSize ==
static_cast<int>(Size));
475 std::memcpy(Arg.MPtr, NewValue, Size);
480 template <
int Dimensions>
482 if (
MCGType != sycl::detail::CGType::Kernel) {
485 "Cannot update execution range of nodes which are not kernel nodes");
489 "Cannot update node which was created with a "
490 "sycl::range with a sycl::nd_range");
494 static_cast<sycl::detail::CGExecKernel *
>(
MCommandGroup.get())
499 "Cannot update execution range of a node with an "
500 "execution range of different dimensions than what "
501 "the node was originall created with.");
504 NDRDesc = sycl::detail::NDRDescT{ExecutionRange};
508 if (
MCGType != sycl::detail::CGType::Kernel) {
511 "Cannot update execution range of nodes which are not kernel nodes");
515 "Cannot update node which was created with a "
516 "sycl::nd_range with a sycl::range");
520 static_cast<sycl::detail::CGExecKernel *
>(
MCommandGroup.get())
525 "Cannot update execution range of a node with an "
526 "execution range of different dimensions than what "
527 "the node was originall created with.");
530 NDRDesc = sycl::detail::NDRDescT{ExecutionRange};
535 static_cast<sycl::detail::CGExecKernel *
>(
MCommandGroup.get());
537 static_cast<sycl::detail::CGExecKernel *
>(Other->MCommandGroup.get());
539 ExecCG->MArgs = OtherExecCG->MArgs;
540 ExecCG->MNDRDesc = OtherExecCG->MNDRDesc;
541 ExecCG->getAccStorage() = OtherExecCG->getAccStorage();
542 ExecCG->getRequirements() = OtherExecCG->getRequirements();
544 auto &OldArgStorage = OtherExecCG->getArgsStorage();
545 auto &NewArgStorage = ExecCG->getArgsStorage();
547 rebuildArgStorage(ExecCG->MArgs, OldArgStorage, NewArgStorage);
553 void rebuildArgStorage(std::vector<sycl::detail::ArgDesc> &Args,
554 const std::vector<std::vector<char>> &OldArgStorage,
555 std::vector<std::vector<char>> &NewArgStorage)
const {
557 NewArgStorage.clear();
561 for (
auto &Arg : Args) {
562 if (Arg.MType != sycl::detail::kernel_param_kind_t::kind_std_layout) {
566 for (
auto &ArgStorage : OldArgStorage) {
567 if (ArgStorage.data() != Arg.MPtr) {
570 NewArgStorage.emplace_back(Arg.MSize);
572 std::memcpy(NewArgStorage.back().data(), ArgStorage.data(), Arg.MSize);
574 Arg.MPtr = NewArgStorage.back().data();
582 static id_type getNextNodeID() {
593 void printDotCG(std::ostream &Stream,
bool Verbose) {
594 Stream <<
"\"" <<
this <<
"\" [style=bold, label=\"";
596 Stream <<
"ID = " <<
this <<
"\\n";
600 case sycl::detail::CGType::None:
601 Stream <<
"None \\n";
603 case sycl::detail::CGType::Kernel: {
604 Stream <<
"CGExecKernel \\n";
605 sycl::detail::CGExecKernel *Kernel =
606 static_cast<sycl::detail::CGExecKernel *
>(
MCommandGroup.get());
607 Stream <<
"NAME = " << Kernel->MKernelName <<
"\\n";
609 Stream <<
"ARGS = \\n";
610 for (
size_t i = 0; i < Kernel->MArgs.size(); i++) {
611 auto Arg = Kernel->MArgs[i];
612 std::string Type =
"Undefined";
613 if (Arg.MType == sycl::detail::kernel_param_kind_t::kind_accessor) {
615 }
else if (Arg.MType ==
616 sycl::detail::kernel_param_kind_t::kind_std_layout) {
618 }
else if (Arg.MType ==
619 sycl::detail::kernel_param_kind_t::kind_sampler) {
621 }
else if (Arg.MType ==
622 sycl::detail::kernel_param_kind_t::kind_pointer) {
624 auto Fill = Stream.fill();
625 Stream << i <<
") Type: " << Type <<
" Ptr: " << Arg.MPtr <<
"(0x"
626 << std::hex << std::setfill(
'0');
627 for (
int i = Arg.MSize - 1; i >= 0; --i) {
629 <<
static_cast<int16_t
>(
630 (
static_cast<unsigned char *
>(Arg.MPtr))[i]);
633 Stream << std::dec <<
")\\n";
636 kind_specialization_constants_buffer) {
637 Type =
"Specialization Constants Buffer";
638 }
else if (Arg.MType ==
639 sycl::detail::kernel_param_kind_t::kind_stream) {
641 }
else if (Arg.MType ==
642 sycl::detail::kernel_param_kind_t::kind_invalid) {
645 Stream << i <<
") Type: " << Type <<
" Ptr: " << Arg.MPtr <<
"\\n";
650 case sycl::detail::CGType::CopyAccToPtr:
651 Stream <<
"CGCopy Device-to-Host \\n";
653 sycl::detail::CGCopy *Copy =
655 Stream <<
"Src: " << Copy->getSrc() <<
" Dst: " << Copy->getDst()
659 case sycl::detail::CGType::CopyPtrToAcc:
660 Stream <<
"CGCopy Host-to-Device \\n";
662 sycl::detail::CGCopy *Copy =
664 Stream <<
"Src: " << Copy->getSrc() <<
" Dst: " << Copy->getDst()
668 case sycl::detail::CGType::CopyAccToAcc:
669 Stream <<
"CGCopy Device-to-Device \\n";
671 sycl::detail::CGCopy *Copy =
673 Stream <<
"Src: " << Copy->getSrc() <<
" Dst: " << Copy->getDst()
677 case sycl::detail::CGType::Fill:
678 Stream <<
"CGFill \\n";
680 sycl::detail::CGFill *
Fill =
682 Stream <<
"Ptr: " <<
Fill->MPtr <<
"\\n";
685 case sycl::detail::CGType::UpdateHost:
686 Stream <<
"CGCUpdateHost \\n";
688 sycl::detail::CGUpdateHost *Host =
689 static_cast<sycl::detail::CGUpdateHost *
>(
MCommandGroup.get());
690 Stream <<
"Ptr: " << Host->getReqToUpdate() <<
"\\n";
693 case sycl::detail::CGType::CopyUSM:
694 Stream <<
"CGCopyUSM \\n";
696 sycl::detail::CGCopyUSM *
CopyUSM =
698 Stream <<
"Src: " <<
CopyUSM->getSrc() <<
" Dst: " <<
CopyUSM->getDst()
699 <<
" Length: " <<
CopyUSM->getLength() <<
"\\n";
702 case sycl::detail::CGType::FillUSM:
703 Stream <<
"CGFillUSM \\n";
705 sycl::detail::CGFillUSM *
FillUSM =
707 Stream <<
"Dst: " <<
FillUSM->getDst()
708 <<
" Length: " <<
FillUSM->getLength() <<
" Pattern: ";
709 for (
auto byte :
FillUSM->getPattern())
714 case sycl::detail::CGType::PrefetchUSM:
715 Stream <<
"CGPrefetchUSM \\n";
717 sycl::detail::CGPrefetchUSM *Prefetch =
718 static_cast<sycl::detail::CGPrefetchUSM *
>(
MCommandGroup.get());
719 Stream <<
"Dst: " << Prefetch->getDst()
720 <<
" Length: " << Prefetch->getLength() <<
"\\n";
723 case sycl::detail::CGType::AdviseUSM:
724 Stream <<
"CGAdviseUSM \\n";
727 static_cast<sycl::detail::CGAdviseUSM *
>(
MCommandGroup.get());
729 <<
" Length: " <<
AdviseUSM->getLength() <<
"\\n";
732 case sycl::detail::CGType::CodeplayHostTask:
733 Stream <<
"CGHostTask \\n";
735 case sycl::detail::CGType::Barrier:
736 Stream <<
"CGBarrier \\n";
738 case sycl::detail::CGType::Copy2DUSM:
739 Stream <<
"CGCopy2DUSM \\n";
742 static_cast<sycl::detail::CGCopy2DUSM *
>(
MCommandGroup.get());
744 <<
" Dst: " <<
Copy2DUSM->getDst() <<
"\\n";
747 case sycl::detail::CGType::Fill2DUSM:
748 Stream <<
"CGFill2DUSM \\n";
751 static_cast<sycl::detail::CGFill2DUSM *
>(
MCommandGroup.get());
752 Stream <<
"Dst: " <<
Fill2DUSM->getDst() <<
"\\n";
755 case sycl::detail::CGType::Memset2DUSM:
756 Stream <<
"CGMemset2DUSM \\n";
759 static_cast<sycl::detail::CGMemset2DUSM *
>(
MCommandGroup.get());
760 Stream <<
"Dst: " <<
Memset2DUSM->getDst() <<
"\\n";
763 case sycl::detail::CGType::ReadWriteHostPipe:
764 Stream <<
"CGReadWriteHostPipe \\n";
766 case sycl::detail::CGType::CopyToDeviceGlobal:
767 Stream <<
"CGCopyToDeviceGlobal \\n";
770 static_cast<sycl::detail::CGCopyToDeviceGlobal *
>(
776 case sycl::detail::CGType::CopyFromDeviceGlobal:
777 Stream <<
"CGCopyFromDeviceGlobal \\n";
780 static_cast<sycl::detail::CGCopyFromDeviceGlobal *
>(
786 case sycl::detail::CGType::ExecCommandBuffer:
787 Stream <<
"CGExecCommandBuffer \\n";
790 Stream <<
"Other \\n";
793 Stream <<
"\"];" << std::endl;
800 template <
typename CGT> std::unique_ptr<CGT> createCGCopy()
const {
801 return std::make_unique<CGT>(*
static_cast<CGT *
>(
MCommandGroup.get()));
811 std::set<std::weak_ptr<node_impl>, std::owner_less<std::weak_ptr<node_impl>>>
816 std::unordered_map<sycl::device, ur_exp_command_buffer_handle_t>
826 return (
MRoots.size() && ((*
MRoots.begin()).lock()->MCGType ==
827 sycl::detail::CGType::CodeplayHostTask));
840 if ((Node->MSuccessors.size() > 1) || (Node->isNDCopyNode())) {
855 using ReadLock = std::shared_lock<std::shared_mutex>;
867 : MContext(SyclContext), MDevice(SyclDevice), MRecordingQueues(),
868 MEventsMap(), MInorderQueueMap() {
870 MSkipCycleChecks =
true;
873 .has_property<property::graph::assume_buffer_outlives_graph>()) {
874 MAllowBuffers =
true;
877 if (!SyclDevice.
has(aspect::ext_oneapi_limited_graph) &&
878 !SyclDevice.
has(aspect::ext_oneapi_graph)) {
879 std::stringstream Stream;
881 std::string BackendString = Stream.str();
884 BackendString +
" backend is not supported by SYCL Graph extension.");
892 void removeRoot(
const std::shared_ptr<node_impl> &Root);
899 std::shared_ptr<node_impl>
900 add(
node_type NodeType, std::unique_ptr<sycl::detail::CG> CommandGroup,
901 const std::vector<std::shared_ptr<node_impl>> &Dep = {});
909 std::shared_ptr<node_impl>
910 add(
const std::shared_ptr<graph_impl> &Impl,
911 std::function<
void(
handler &)> CGF,
912 const std::vector<sycl::detail::ArgDesc> &Args,
913 const std::vector<std::shared_ptr<node_impl>> &Dep = {});
919 std::shared_ptr<node_impl>
920 add(
const std::shared_ptr<graph_impl> &Impl,
921 const std::vector<std::shared_ptr<node_impl>> &Dep = {});
927 std::shared_ptr<node_impl>
928 add(
const std::shared_ptr<graph_impl> &Impl,
929 const std::vector<sycl::detail::EventImplPtr> Events);
935 addQueue(
const std::shared_ptr<sycl::detail::queue_impl> &RecordingQueue) {
936 MRecordingQueues.insert(RecordingQueue);
943 removeQueue(
const std::shared_ptr<sycl::detail::queue_impl> &RecordingQueue) {
944 MRecordingQueues.erase(RecordingQueue);
959 std::shared_ptr<sycl::detail::event_impl> EventImpl,
960 std::shared_ptr<node_impl> NodeImpl) {
961 if (!(EventImpl->getCommandGraph()))
962 EventImpl->setCommandGraph(GraphImpl);
963 MEventsMap[EventImpl] = NodeImpl;
969 std::shared_ptr<sycl::detail::event_impl>
972 if (
auto EventImpl = std::find_if(
973 MEventsMap.begin(), MEventsMap.end(),
974 [NodeImpl](
auto &it) { return it.second == NodeImpl; });
975 EventImpl != MEventsMap.end()) {
976 return EventImpl->first;
981 "No event has been recorded for the specified graph node");
988 std::shared_ptr<node_impl>
992 if (
auto NodeFound = MEventsMap.find(EventImpl);
993 NodeFound != std::end(MEventsMap)) {
994 return NodeFound->second;
999 "No node in this graph is associated with this event");
1011 std::set<std::weak_ptr<node_impl>, std::owner_less<std::weak_ptr<node_impl>>>
1024 std::shared_ptr<node_impl>
1026 std::weak_ptr<sycl::detail::queue_impl> QueueWeakPtr(Queue);
1027 if (0 == MInorderQueueMap.count(QueueWeakPtr)) {
1030 return MInorderQueueMap[QueueWeakPtr];
1037 std::shared_ptr<node_impl> Node) {
1038 std::weak_ptr<sycl::detail::queue_impl> QueueWeakPtr(Queue);
1039 MInorderQueueMap[QueueWeakPtr] = Node;
1048 std::vector<node_impl *> VisitedNodes;
1050 std::fstream Stream(FilePath, std::ios::out);
1051 Stream <<
"digraph dot {" << std::endl;
1053 for (std::weak_ptr<node_impl> Node :
MRoots)
1054 Node.lock()->printDotRecursive(Stream, VisitedNodes, Verbose);
1056 Stream <<
"}" << std::endl;
1066 void makeEdge(std::shared_ptr<node_impl> Src,
1067 std::shared_ptr<node_impl> Dest);
1073 if (MRecordingQueues.size()) {
1076 " cannot be called when a queue "
1077 "is currently recording commands to a graph.");
1086 const std::shared_ptr<node_impl> &NodeB) {
1087 size_t FoundCnt = 0;
1088 for (std::weak_ptr<node_impl> &SuccA : NodeA->MSuccessors) {
1089 for (std::weak_ptr<node_impl> &SuccB : NodeB->MSuccessors) {
1090 if (NodeA->isSimilar(NodeB) &&
1097 if (FoundCnt != NodeA->MSuccessors.size()) {
1116 bool DebugPrint =
false)
const {
1117 if (
this == Graph.get())
1120 if (MContext != Graph->MContext) {
1123 "MContext are not the same.");
1128 if (MDevice != Graph->MDevice) {
1131 "MDevice are not the same.");
1136 if (MEventsMap.size() != Graph->MEventsMap.size()) {
1139 "MEventsMap sizes are not the same.");
1144 if (MInorderQueueMap.size() != Graph->MInorderQueueMap.size()) {
1147 "MInorderQueueMap sizes are not the same.");
1152 if (
MRoots.size() != Graph->MRoots.size()) {
1155 "MRoots sizes are not the same.");
1160 size_t RootsFound = 0;
1161 for (std::weak_ptr<node_impl> NodeA :
MRoots) {
1162 for (std::weak_ptr<node_impl> NodeB : Graph->MRoots) {
1163 auto NodeALocked = NodeA.lock();
1164 auto NodeBLocked = NodeB.lock();
1166 if (NodeALocked->isSimilar(NodeBLocked)) {
1175 if (RootsFound !=
MRoots.size()) {
1178 "Root Nodes do NOT match.");
1194 std::vector<sycl::detail::EventImplPtr>
1201 std::shared_ptr<node_impl> BarrierNodeImpl) {
1202 MBarrierDependencyMap[Queue] = BarrierNodeImpl;
1208 std::shared_ptr<node_impl>
1210 return MBarrierDependencyMap[Queue];
1221 searchDepthFirst(std::function<
bool(std::shared_ptr<node_impl> &,
1222 std::deque<std::shared_ptr<node_impl>> &)>
1229 bool checkForCycles();
1233 void addRoot(
const std::shared_ptr<node_impl> &Root);
1239 std::shared_ptr<node_impl>
1240 addNodesToExits(
const std::shared_ptr<graph_impl> &Impl,
1241 const std::list<std::shared_ptr<node_impl>> &NodeList);
1247 void addDepsToNode(std::shared_ptr<node_impl> Node,
1248 const std::vector<std::shared_ptr<node_impl>> &Deps) {
1249 if (!Deps.empty()) {
1250 for (
auto &N : Deps) {
1251 N->registerSuccessor(Node, N);
1255 this->addRoot(Node);
1265 std::set<std::weak_ptr<sycl::detail::queue_impl>,
1266 std::owner_less<std::weak_ptr<sycl::detail::queue_impl>>>
1269 std::unordered_map<std::shared_ptr<sycl::detail::event_impl>,
1270 std::shared_ptr<node_impl>>
1275 std::map<std::weak_ptr<sycl::detail::queue_impl>, std::shared_ptr<node_impl>,
1276 std::owner_less<std::weak_ptr<sycl::detail::queue_impl>>>
1280 bool MSkipCycleChecks =
false;
1282 std::set<sycl::detail::SYCLMemObjT *> MMemObjs;
1286 bool MAllowBuffers =
false;
1290 std::map<std::weak_ptr<sycl::detail::queue_impl>, std::shared_ptr<node_impl>,
1291 std::owner_less<std::weak_ptr<sycl::detail::queue_impl>>>
1292 MBarrierDependencyMap;
1313 const std::shared_ptr<graph_impl> &GraphImpl,
1341 std::shared_ptr<partition> &Partition);
1359 const std::shared_ptr<graph_impl> &
getGraphImpl()
const {
return MGraphImpl; }
1373 for (
auto Event : MExecutionEvents) {
1374 if (!Event->isCompleted()) {
1383 return MRequirements;
1386 void update(std::shared_ptr<graph_impl> GraphImpl);
1387 void update(std::shared_ptr<node_impl> Node);
1388 void update(
const std::vector<std::shared_ptr<node_impl>> Nodes);
1390 void updateImpl(std::shared_ptr<node_impl> NodeImpl);
1400 ur_exp_command_buffer_sync_point_t
1402 ur_exp_command_buffer_handle_t CommandBuffer,
1403 std::shared_ptr<node_impl> Node);
1412 ur_exp_command_buffer_sync_point_t
1414 ur_exp_command_buffer_handle_t CommandBuffer,
1415 std::shared_ptr<node_impl> Node);
1423 void findRealDeps(std::vector<ur_exp_command_buffer_sync_point_t> &Deps,
1424 std::shared_ptr<node_impl> CurrentNode,
1425 int ReferencePartitionNum);
1430 void duplicateNodes();
1436 void printGraphAsDot(
const std::string FilePath,
bool Verbose)
const {
1438 std::vector<node_impl *> VisitedNodes;
1440 std::fstream Stream(FilePath, std::ios::out);
1441 Stream <<
"digraph dot {" << std::endl;
1443 std::vector<std::shared_ptr<node_impl>> Roots;
1444 for (
auto &Node : MNodeStorage) {
1445 if (Node->MPredecessors.size() == 0) {
1446 Roots.push_back(Node);
1450 for (std::shared_ptr<node_impl> Node : Roots)
1451 Node->printDotRecursive(Stream, VisitedNodes, Verbose);
1453 Stream <<
"}" << std::endl;
1459 std::list<std::shared_ptr<node_impl>> MSchedule;
1466 std::shared_ptr<graph_impl> MGraphImpl;
1469 std::unordered_map<std::shared_ptr<node_impl>,
1470 ur_exp_command_buffer_sync_point_t>
1474 std::unordered_map<std::shared_ptr<node_impl>,
int> MPartitionNodes;
1481 std::vector<sycl::detail::AccessorImplHost *> MRequirements;
1484 std::vector<sycl::detail::AccessorImplPtr> MAccessors;
1486 std::vector<sycl::detail::EventImplPtr> MExecutionEvents;
1488 std::vector<std::shared_ptr<partition>> MPartitions;
1490 std::vector<std::shared_ptr<node_impl>> MNodeStorage;
1492 std::unordered_map<std::shared_ptr<node_impl>,
1493 ur_exp_command_buffer_command_handle_t>
1498 bool MEnableProfiling;
1502 std::multimap<node_impl::id_type, std::shared_ptr<node_impl>> MIDCache;
1508 size_t ParamSize,
const void *Data)
1518 MNodes.emplace_back(NodeImpl, ArgIndex);
1529 for (
auto &[NodeWeak, ArgIndex] :
MNodes) {
1530 auto NodeShared = NodeWeak.lock();
1532 NodeShared->updateArgValue(ArgIndex, NewValue, Size);
1543 for (
auto &[NodeWeak, ArgIndex] :
MNodes) {
1544 auto NodeShared = NodeWeak.lock();
1547 NodeShared->updateAccessor(ArgIndex, Acc);
1551 sizeof(sycl::detail::AccessorBaseHost));
1555 std::vector<std::pair<std::weak_ptr<node_impl>,
int>>
MNodes;
The context class represents a SYCL context on which kernel functions may be executed.
The SYCL device class encapsulates a single SYCL device on which kernels may be executed.
backend get_backend() const noexcept
Returns the backend associated with this device.
bool has(aspect Aspect) const __SYCL_WARN_IMAGE_ASPECT(Aspect)
Indicates if the SYCL device has the given feature.
An event object can be used to synchronize memory transfers, enqueues of kernels and signaling barrie...
std::vector< std::byte > MValueStorage
dynamic_parameter_impl(std::shared_ptr< graph_impl > GraphImpl, size_t ParamSize, const void *Data)
std::vector< std::pair< std::weak_ptr< node_impl >, int > > MNodes
void * getValue()
Get a pointer to the internal value of this dynamic parameter.
std::shared_ptr< graph_impl > MGraph
void registerNode(std::shared_ptr< node_impl > NodeImpl, int ArgIndex)
Register a node with this dynamic parameter.
void updateValue(const void *NewValue, size_t Size)
Update the internal value of this dynamic parameter as well as the value of this parameter in all reg...
void updateAccessor(const sycl::detail::AccessorBaseHost *Acc)
Update the internal value of this dynamic parameter as well as the value of this parameter in all reg...
Class representing the implementation of command_graph<executable>.
exec_graph_impl(sycl::context Context, const std::shared_ptr< graph_impl > &GraphImpl, const property_list &PropList)
Constructor.
void createCommandBuffers(sycl::device Device, std::shared_ptr< partition > &Partition)
Turns the internal graph representation into UR command-buffers for a device.
std::unique_lock< std::shared_mutex > WriteLock
std::vector< sycl::detail::AccessorImplHost * > getRequirements() const
Returns a list of all the accessor requirements for this graph.
const std::vector< std::shared_ptr< partition > > & getPartitions() const
Query the vector of the partitions composing the exec_graph.
std::shared_lock< std::shared_mutex > ReadLock
void update(std::shared_ptr< graph_impl > GraphImpl)
sycl::context getContext() const
Query for the context tied to this graph.
~exec_graph_impl()
Destructor.
void makePartitions()
Partition the graph nodes and put the partition in MPartitions.
void updateImpl(std::shared_ptr< node_impl > NodeImpl)
sycl::device getDevice() const
Query for the device tied to this graph.
const std::list< std::shared_ptr< node_impl > > & getSchedule() const
Query the scheduling of node execution.
const std::shared_ptr< graph_impl > & getGraphImpl() const
Query the graph_impl.
std::shared_mutex MMutex
Protects all the fields that can be changed by class' methods.
bool previousSubmissionCompleted() const
Checks if the previous submissions of this graph have been completed This function checks the status ...
sycl::event enqueue(const std::shared_ptr< sycl::detail::queue_impl > &Queue, sycl::detail::CG::StorageInitHelper CGData)
Called by handler::ext_oneapi_command_graph() to schedule graph for execution.
Implementation details of command_graph<modifiable>.
std::vector< std::shared_ptr< node_impl > > MNodeStorage
Storage for all nodes contained within a graph.
void addQueue(const std::shared_ptr< sycl::detail::queue_impl > &RecordingQueue)
Add a queue to the set of queues which are currently recording to this graph.
void removeRoot(const std::shared_ptr< node_impl > &Root)
Remove node from list of root nodes.
std::unique_lock< std::shared_mutex > WriteLock
void makeEdge(std::shared_ptr< node_impl > Src, std::shared_ptr< node_impl > Dest)
Make an edge between two nodes in the graph.
std::shared_mutex MMutex
Protects all the fields that can be changed by class' methods.
void printGraphAsDot(const std::string FilePath, bool Verbose) const
Prints the contents of the graph to a text file in DOT format.
void throwIfGraphRecordingQueue(const std::string ExceptionMsg) const
Throws an invalid exception if this function is called while a queue is recording commands to the gra...
std::shared_ptr< sycl::detail::event_impl > getEventForNode(std::shared_ptr< node_impl > NodeImpl) const
Find the sycl event associated with a node.
void setBarrierDep(std::weak_ptr< sycl::detail::queue_impl > Queue, std::shared_ptr< node_impl > BarrierNodeImpl)
Store the last barrier node that was submitted to the queue.
void setLastInorderNode(std::shared_ptr< sycl::detail::queue_impl > Queue, std::shared_ptr< node_impl > Node)
Track the last node added to this graph from an in-order queue.
std::set< std::weak_ptr< node_impl >, std::owner_less< std::weak_ptr< node_impl > > > MRoots
List of root nodes.
bool hasSimilarStructure(std::shared_ptr< detail::graph_impl > Graph, bool DebugPrint=false) const
Checks if the graph_impl of Graph has a similar structure to the graph_impl of the caller.
std::vector< sycl::detail::EventImplPtr > getExitNodesEvents(std::weak_ptr< sycl::detail::queue_impl > Queue)
Traverse the graph recursively to get the events associated with the output nodes of this graph assoc...
bool clearQueues()
Remove all queues which are recording to this graph, also sets all queues cleared back to the executi...
sycl::device getDevice() const
Query for the device tied to this graph.
std::shared_ptr< node_impl > getNodeForEvent(std::shared_ptr< sycl::detail::event_impl > EventImpl)
Find the node associated with a SYCL event.
static bool checkNodeRecursive(const std::shared_ptr< node_impl > &NodeA, const std::shared_ptr< node_impl > &NodeB)
Recursively check successors of NodeA and NodeB to check they are similar.
graph_impl(const sycl::context &SyclContext, const sycl::device &SyclDevice, const sycl::property_list &PropList={})
Constructor.
sycl::context getContext() const
Query for the context tied to this graph.
std::shared_ptr< node_impl > add(node_type NodeType, std::unique_ptr< sycl::detail::CG > CommandGroup, const std::vector< std::shared_ptr< node_impl >> &Dep={})
Create a kernel node in the graph.
void addEventForNode(std::shared_ptr< graph_impl > GraphImpl, std::shared_ptr< sycl::detail::event_impl > EventImpl, std::shared_ptr< node_impl > NodeImpl)
Associate a sycl event with a node in the graph.
std::shared_lock< std::shared_mutex > ReadLock
size_t getNumberOfNodes() const
Returns the number of nodes in the Graph.
std::shared_ptr< node_impl > getLastInorderNode(std::shared_ptr< sycl::detail::queue_impl > Queue)
Find the last node added to this graph from an in-order queue.
std::shared_ptr< node_impl > getBarrierDep(std::weak_ptr< sycl::detail::queue_impl > Queue)
Get the last barrier node that was submitted to the queue.
void removeQueue(const std::shared_ptr< sycl::detail::queue_impl > &RecordingQueue)
Remove a queue from the set of queues which are currently recording to this graph.
Implementation of node class from SYCL_EXT_ONEAPI_GRAPH.
bool isNDCopyNode() const
Test if the node contains a N-D copy.
void updateAccessor(int ArgIndex, const sycl::detail::AccessorBaseHost *Acc)
Update the value of an accessor inside this node.
void registerSuccessor(const std::shared_ptr< node_impl > &Node, const std::shared_ptr< node_impl > &Prev)
Add successor to the node.
std::vector< std::weak_ptr< node_impl > > MPredecessors
List of predecessors to this node.
node_impl(node_impl &Other)
Construct a node from another node.
std::vector< std::weak_ptr< node_impl > > MSuccessors
List of successors to this node.
void updateFromOtherNode(const std::shared_ptr< node_impl > &Other)
sycl::detail::CGType MCGType
Type of the command-group for the node.
std::shared_ptr< exec_graph_impl > MSubGraphImpl
Stores the executable graph impl associated with this node if it is a subgraph node.
node_impl(node_type NodeType, std::unique_ptr< sycl::detail::CG > &&CommandGroup)
Construct a node representing a command-group.
bool MNDRangeUsed
Track whether an ND-Range was used for kernel nodes.
bool hasRequirementDependency(sycl::detail::AccessorImplHost *IncomingReq)
Checks if this node should be a dependency of another node based on accessor requirements.
int MPartitionNum
Partition number needed to assign a Node to a partition.
void updateRange(range< Dimensions > ExecutionRange)
void updateNDRange(nd_range< Dimensions > ExecutionRange)
bool isSimilar(const std::shared_ptr< node_impl > &Node, bool CompareContentOnly=false) const
Tests if the caller is similar to Node, this is only used for testing.
void printDotRecursive(std::fstream &Stream, std::vector< node_impl * > &Visited, bool Verbose)
Recursive Depth first traversal of linked nodes.
node_impl()
Construct an empty node.
node_type MNodeType
User facing type of the node.
std::unique_ptr< sycl::detail::CG > getCGCopy() const
Get a deep copy of this node's command group.
void updateArgValue(int ArgIndex, const void *NewValue, size_t Size)
bool isEmpty() const
Query if this is an empty node.
void registerPredecessor(const std::shared_ptr< node_impl > &Node)
Add predecessor to the node.
bool MVisited
Used for tracking visited status during cycle checks.
std::unique_ptr< sycl::detail::CG > MCommandGroup
Command group object which stores all args etc needed to enqueue the node.
id_type MID
Unique identifier for this node.
node_impl & operator=(node_impl &Other)
Copy-assignment operator.
void schedule()
Add nodes to MSchedule.
bool checkIfGraphIsSinglePath()
Checks if the graph is single path, i.e.
std::unordered_map< sycl::device, ur_exp_command_buffer_handle_t > MCommandBuffers
Map of devices to command buffers.
std::set< std::weak_ptr< node_impl >, std::owner_less< std::weak_ptr< node_impl > > > MRoots
List of root nodes.
std::vector< std::shared_ptr< partition > > MPredecessors
List of predecessors to this partition.
bool MIsInOrderGraph
True if the graph of this partition is a single path graph and in-order optimization can be applied on...
std::list< std::shared_ptr< node_impl > > MSchedule
Execution schedule of nodes in the graph.
Property passed to command_graph constructor to disable checking for cycles.
Command group handler class.
Defines the iteration domain of both the work-groups and the overall dispatch.
Objects of the property_list class are containers for the SYCL properties.
decltype(Obj::impl) const & getSyclObjImpl(const Obj &SyclObject)
AccessorImplHost Requirement
std::shared_ptr< device_impl > DeviceImplPtr
CGType
Type of the command group.
node_type getNodeTypeFromCG(sycl::detail::CGType CGType)
class __SYCL_EBO __SYCL_SPECIAL_CLASS Dimensions
constexpr mode_tag_t< access_mode::read_write > read_write
__width_manipulator__ setw(int Width)
std::error_code make_error_code(sycl::errc E) noexcept
Constructs an error code using e and sycl_category()
std::vector< std::vector< char > > MArgsStorage
Storage for standard layout arguments.