Program Listing for File d3d12-object.h

Return to documentation for file (include\d3d12-state-tracking\d3d12-object.h)

/******************************************************************************
© Intel Corporation.

This software and the related documents are Intel copyrighted materials,
and your use of them is governed by the express license under which they
were provided to you ("License"). Unless the License provides otherwise,
you may not use, modify, copy, publish, distribute, disclose or transmit
this software or the related documents without Intel's prior written
permission.


 This software and the related documents are provided as is, with no express
or implied warranties, other than those that are expressly stated in the
License.

******************************************************************************/

#pragma once

#include "utility/rdtsc.h"
#include <concurrent/critical-section.h>

#include "utility/assert.h"
#include <functional>
#include <list>
#include <d3d12.h>
#include <dxgi1_6.h>
#include <string>
#include <memory>
#include <utility>

namespace gpa {
namespace d3d12_state_tracker {

// Callback signature that receives the pointer of an ID3D12Object that has died.
// The pointer is passed for identification; NOTE(review): presumably it must not
// be dereferenced by the callee since the object is going away - confirm with callers.
using DeathNotificationCallback = std::function<void(ID3D12Object* deadObjectPointer)>;

template<typename T>
class ObjectState : public IUnknown
{
public:
    static constexpr GUID sGUID = {0x69a84fd0, 0xb106, 0x4496, {0xbe, 0x30, 0xd6, 0xac, 0x1e, 0x30, 0x55, 0x2b}};
    using OnDeathCallback = std::function<void(T* deadObjectPointer)>;

private:
    std::list<OnDeathCallback> mDeathCallbacks;
    std::wstring mName;

protected:
    T* mRuntimeObject = nullptr;
    uint64_t mCreationTimestamp = RDTSC();

public:
    ObjectState(T* runtimeObj)
        : mRuntimeObject(runtimeObj)
    {
        GPA_ASSERT(runtimeObj && "runtimeObj must not be null");
    }

    virtual ~ObjectState()
    {
    }

    HRESULT STDMETHODCALLTYPE QueryInterface(REFIID /*riid*/, void** ppvObject) override
    {
        if (ppvObject == nullptr) {
            return E_INVALIDARG;
        }
        return E_NOINTERFACE;
    }

    ULONG STDMETHODCALLTYPE AddRef() override
    {
        // Reference counting is not required
        return 1;
    }

    ULONG STDMETHODCALLTYPE Release() override
    {
        // Just first Release call means that the object is being deleted
        // The AddRef is called multiple times for whatever reason, but there
        // is only one release

        for (auto deathCallback : mDeathCallbacks) {
            deathCallback(mRuntimeObject);
        }

        delete this;
        return 0;
    }

    virtual GUID GetGUID()
    {
        return sGUID;
    }

    void SetName(LPCWSTR name)
    {
        mName = name;
    }

    std::wstring GetName() const
    {
        return mName;
    }

    T* GetRuntimeObject() const
    {
        return mRuntimeObject;
    }

    uint64_t GetCreationTimestamp() const
    {
        return mCreationTimestamp;
    }

    void RegisterDeathNotificationCallback(OnDeathCallback cb)
    {
        mDeathCallbacks.push_back(cb);
    }
};

// Shorthand names for ObjectState instantiated with ID3D12Object and with IDXGIObject.
using D3D12ObjectState = ObjectState<ID3D12Object>;
using DXGIObjectState = ObjectState<IDXGIObject>;

namespace detail {

// This class is designed to control the runtimeObject's lifetime.
// An instance of this class is placed to the runtimeObject's private data,
// but it must nevertheless live longer the runtimeObject
// in order to tell everyone interested that the runtimeObject no longer exists
//
// Therefore, to control the lifetime of the LivetimeTracker itself, shared_ptr was chosed
template<typename RuntimeObjectType>
class D3D12LivetimeTracker : public IUnknown
{
    // {531A73DE-57D3-4782-AC4A-680997FAA7FD}
    static constexpr GUID sLivetimeTrackerGuid = {0x531a73de, 0x57d3, 0x4782, {0xac, 0x4a, 0x68, 0x9, 0x97, 0xfa, 0xa7, 0xfd}};

    RuntimeObjectType* mRuntimeObject = nullptr;
    std::weak_ptr<D3D12LivetimeTracker> mThisPtr;  //< a weak pointer to yourself

    HRESULT STDMETHODCALLTYPE QueryInterface(REFIID, void**) override
    {
        return E_UNEXPECTED;
    }
    ULONG STDMETHODCALLTYPE AddRef() override
    {
        return 0;
    }
    ULONG STDMETHODCALLTYPE Release() override
    {
        mRuntimeObject = nullptr;
        return 0;
    }

    D3D12LivetimeTracker(RuntimeObjectType* runtimeObject)
        : mRuntimeObject(runtimeObject)
    {
    }

    D3D12LivetimeTracker(D3D12LivetimeTracker const&) = delete;
    D3D12LivetimeTracker& operator=(D3D12LivetimeTracker const&) = delete;

public:
    ~D3D12LivetimeTracker()
    {
        if (mRuntimeObject && mThisPtr.expired()) {
            mRuntimeObject->SetPrivateDataInterface(sLivetimeTrackerGuid, nullptr);
        }
    }

    // It should only be called if the runtimeObject exists
    static std::shared_ptr<D3D12LivetimeTracker> GetLivetimeTracker(RuntimeObjectType* runtimeObject)
    {
        static concurrent::CriticalSection sCriticalSection;

        if (!runtimeObject) {
            return nullptr;
        }
        D3D12LivetimeTracker* livetimeTracker = nullptr;
        UINT dataSize = sizeof(livetimeTracker);

        concurrent::CriticalSectionScopedLock guard(sCriticalSection);
        HRESULT hr = runtimeObject->GetPrivateData(sLivetimeTrackerGuid, &dataSize, &livetimeTracker);

        if (FAILED(hr) || !livetimeTracker || livetimeTracker->mThisPtr.expired()) {
            std::shared_ptr<D3D12LivetimeTracker> livetimeTrackerPtr(new D3D12LivetimeTracker(runtimeObject));
            runtimeObject->SetPrivateDataInterface(sLivetimeTrackerGuid, livetimeTrackerPtr.get());
            livetimeTrackerPtr->mThisPtr = livetimeTrackerPtr;
            return livetimeTrackerPtr;
        }
        return livetimeTracker->mThisPtr.lock();
    }

    RuntimeObjectType* GetRuntimeObject()
    {
        return mRuntimeObject;
    }
};
}  // namespace detail

// This class behaves like a RuntimeObjectType, but it always knows if the runtimeObject still exists
template<typename RuntimeObjectType>
class WeakPtr
{
    // Since WeakPtr instances can be copied, so we have to share the LivetimeTracker between them
    std::shared_ptr<detail::D3D12LivetimeTracker<RuntimeObjectType>> mLivetimeTracker;

public:
    WeakPtr() = default;
    WeakPtr(WeakPtr const& other) = default;
    WeakPtr(WeakPtr&& other) = default;
    WeakPtr& operator=(WeakPtr const& other) = default;
    WeakPtr& operator=(WeakPtr&& other) = default;

    WeakPtr(RuntimeObjectType* runtimeObject)
    {
        *this = runtimeObject;
    }

    WeakPtr& operator=(RuntimeObjectType* runtimeObject)
    {
        mLivetimeTracker = detail::D3D12LivetimeTracker<RuntimeObjectType>::GetLivetimeTracker(runtimeObject);
        return *this;
    }

    explicit operator uint64_t() const
    {
        return mLivetimeTracker ? (uint64_t)mLivetimeTracker->GetRuntimeObject() : 0;
    }

    operator RuntimeObjectType*() const
    {
        return mLivetimeTracker ? mLivetimeTracker->GetRuntimeObject() : nullptr;
    }

    RuntimeObjectType* operator->() const
    {
        return mLivetimeTracker->GetRuntimeObject();
    }
};

}  // namespace d3d12_state_tracker
}  // namespace gpa