Program Listing for File d3d12-object.h

Return to documentation for file (include\d3d12-state-tracking\d3d12-object.h)

/******************************************************************************
© Intel Corporation.

This software and the related documents are Intel copyrighted materials,
and your use of them is governed by the express license under which they
were provided to you ("License"). Unless the License provides otherwise,
you may not use, modify, copy, publish, distribute, disclose or transmit
this software or the related documents without Intel's prior written
permission.


 This software and the related documents are provided as is, with no express
or implied warranties, other than those that are expressly stated in the
License.

******************************************************************************/

#pragma once

#include "utility/rdtsc.h"
#include <concurrent/critical-section.h>

#include <functional>
#include <list>
#include <d3d12.h>
#include <string>
#include <memory>

namespace gpa {
namespace d3d12_state_tracker {

// Callable that receives a raw ID3D12Object pointer (`deadObjectPointer`).
// Callbacks of this type are handed to
// D3D12ObjectState::RegisterDeathNotificationCallback; presumably they are invoked
// when the tracked object dies (the invocation site is not in this header), so the
// pointer should be treated as an identity token rather than dereferenced - TODO confirm.
using DeathNotificationCallback = std::function<void(ID3D12Object* deadObjectPointer)>;

// State record associated with a single ID3D12Object.
//
// The class derives from IUnknown and overrides its three methods. Only the
// declarations are visible in this header; the comments below describe what the
// declarations show and hedge anything that depends on the out-of-line definitions.
class D3D12ObjectState : public IUnknown
{
public:
    // GUID constant identifying this class. NOTE(review): presumably used as the
    // private-data key and/or as the value returned by GetGUID() - confirm in the .cpp.
    static constexpr GUID sGUID = {0x69a84fd0, 0xb106, 0x4496, {0xbe, 0x30, 0xd6, 0xac, 0x1e, 0x30, 0x55, 0x2b}};

private:
    // Callbacks collected through RegisterDeathNotificationCallback().
    std::list<DeathNotificationCallback> mDeathCallbacks;
    // Wide-character name; see SetName() / GetName().
    std::wstring mName;

protected:
    // Runtime object this state belongs to (non-owning as far as this header shows).
    ID3D12Object* mRuntimeObject = nullptr;
    // RDTSC() value sampled by the default member initializer, i.e. at construction.
    uint64_t mCreationTimestamp = RDTSC();

public:
    // Note: the single-argument constructor is not `explicit`, so an ID3D12Object*
    // converts implicitly to D3D12ObjectState.
    D3D12ObjectState(ID3D12Object* runtimeObj);
    virtual ~D3D12ObjectState();

    // IUnknown overrides (defined out of line).
    HRESULT STDMETHODCALLTYPE QueryInterface(REFIID riid, void** ppvObject) override;
    ULONG STDMETHODCALLTYPE AddRef() override;
    ULONG STDMETHODCALLTYPE Release() override;

    // Virtual so that derived state classes can report their own GUID.
    virtual GUID GetGUID();
    void SetName(LPCWSTR name);
    // Returns the name by value (a copy of the stored string is made per call).
    std::wstring GetName() const;
    ID3D12Object* GetRuntimeObject() const;
    uint64_t GetCreationTimestamp() const;

    // Takes the callback by value. NOTE(review): presumably appended to
    // mDeathCallbacks and fired on destruction - confirm against the definition.
    void RegisterDeathNotificationCallback(DeathNotificationCallback cb);
};

namespace detail {

// This class is designed to control the runtimeObject's lifetime.
// An instance of this class is placed to the runtimeObject's private data,
// but it must nevertheless live longer the runtimeObject
// in order to tell everyone interested that the runtimeObject no longer exists
//
// Therefore, to control the lifetime of the LivetimeTracker itself, shared_ptr was chosed
template<typename RuntimeObjectType>
class D3D12LivetimeTracker : public IUnknown
{
    // {531A73DE-57D3-4782-AC4A-680997FAA7FD}
    static constexpr GUID sLivetimeTrackerGuid = {0x531a73de, 0x57d3, 0x4782, {0xac, 0x4a, 0x68, 0x9, 0x97, 0xfa, 0xa7, 0xfd}};

    RuntimeObjectType* mRuntimeObject = nullptr;
    std::weak_ptr<D3D12LivetimeTracker> mThisPtr;  //< a weak pointer to yourself

    HRESULT STDMETHODCALLTYPE QueryInterface(REFIID, void**) override
    {
        return E_UNEXPECTED;
    }
    ULONG STDMETHODCALLTYPE AddRef() override
    {
        return 0;
    }
    ULONG STDMETHODCALLTYPE Release() override
    {
        mRuntimeObject = nullptr;
        return 0;
    }

    D3D12LivetimeTracker(RuntimeObjectType* runtimeObject)
        : mRuntimeObject(runtimeObject)
    {
    }

    D3D12LivetimeTracker(D3D12LivetimeTracker const&) = delete;
    D3D12LivetimeTracker& operator=(D3D12LivetimeTracker const&) = delete;

public:
    ~D3D12LivetimeTracker()
    {
        if (mRuntimeObject && mThisPtr.expired()) {
            mRuntimeObject->SetPrivateDataInterface(sLivetimeTrackerGuid, nullptr);
        }
    }

    // It should only be called if the runtimeObject exists
    static std::shared_ptr<D3D12LivetimeTracker> GetLivetimeTracker(RuntimeObjectType* runtimeObject)
    {
        static concurrent::CriticalSection sCriticalSection;

        if (!runtimeObject) {
            return nullptr;
        }
        D3D12LivetimeTracker* livetimeTracker = nullptr;
        UINT dataSize = sizeof(livetimeTracker);

        concurrent::CriticalSectionScopedLock guard(sCriticalSection);
        HRESULT hr = runtimeObject->GetPrivateData(sLivetimeTrackerGuid, &dataSize, &livetimeTracker);

        if (FAILED(hr) || !livetimeTracker || livetimeTracker->mThisPtr.expired()) {
            std::shared_ptr<D3D12LivetimeTracker> livetimeTrackerPtr(new D3D12LivetimeTracker(runtimeObject));
            runtimeObject->SetPrivateDataInterface(sLivetimeTrackerGuid, livetimeTrackerPtr.get());
            livetimeTrackerPtr->mThisPtr = livetimeTrackerPtr;
            return livetimeTrackerPtr;
        }
        return livetimeTracker->mThisPtr.lock();
    }

    RuntimeObjectType* GetRuntimeObject()
    {
        return mRuntimeObject;
    }
};
}  // namespace detail

// This class behaves like a RuntimeObjectType, but it always knows if the runtimeObject still exists
template<typename RuntimeObjectType>
class WeakPtr
{
    // Since WeakPtr instances can be copied, so we have to share the LivetimeTracker between them
    std::shared_ptr<detail::D3D12LivetimeTracker<RuntimeObjectType>> mLivetimeTracker;

public:
    WeakPtr() = default;
    WeakPtr(WeakPtr const& other) = default;
    WeakPtr(WeakPtr&& other) = default;
    WeakPtr& operator=(WeakPtr const& other) = default;
    WeakPtr& operator=(WeakPtr&& other) = default;

    WeakPtr(RuntimeObjectType* runtimeObject)
    {
        *this = runtimeObject;
    }

    WeakPtr& operator=(RuntimeObjectType* runtimeObject)
    {
        mLivetimeTracker = detail::D3D12LivetimeTracker<RuntimeObjectType>::GetLivetimeTracker(runtimeObject);
        return *this;
    }

    explicit operator uint64_t() const
    {
        return mLivetimeTracker ? (uint64_t)mLivetimeTracker->GetRuntimeObject() : 0;
    }

    operator RuntimeObjectType*() const
    {
        return mLivetimeTracker ? mLivetimeTracker->GetRuntimeObject() : nullptr;
    }

    RuntimeObjectType* operator->() const
    {
        return mLivetimeTracker->GetRuntimeObject();
    }
};

}  // namespace d3d12_state_tracker
}  // namespace gpa