Program Listing for File stateobject-parser.h

Return to documentation for file (include\d3d12-state-tracking\stateobject-parser.h)

/******************************************************************************
(C) Intel Corporation.

This software and the related documents are Intel copyrighted materials,
and your use of them is governed by the express license under which they
were provided to you ("License"). Unless the License provides otherwise,
you may not use, modify, copy, publish, distribute, disclose or transmit
this software or the related documents without Intel's prior written
permission.


 This software and the related documents are provided as is, with no express
or implied warranties, other than those that are expressly stated in the
License.

******************************************************************************/

#pragma once

#include <d3d12.h>

#include <atlbase.h>
#include <set>
#include <map>
#include <string>
#include <vector>
#include <limits>
#include <functional>

namespace gpa {
namespace d3d12_state_tracker {

// Parses a D3D12_STATE_OBJECT_DESC into a flattened ParsedData description:
// the DXIL libraries it references, its exports (shaders and hit groups) and
// the root signatures attached to them.
//
// Only declarations are visible in this header; the notes below about what
// the out-of-line members do are based on their names and signatures and are
// marked as assumptions where they cannot be confirmed from here.
class StateObjectParser
{
public:
    // Identifies a shader by a pair of indices.
    // Both indices default to std::numeric_limits<size_t>::max().
    // NOTE(review): presumably libraryIndex indexes ParsedData::libraries and
    // the max() value marks "not set" — confirm against the implementation.
    struct ShaderRef
    {
        size_t libraryIndex = std::numeric_limits<size_t>::max();
        size_t entryPointIndex = std::numeric_limits<size_t>::max();

        // NOTE(review): implicit (non-explicit) conversion; defined out of line.
        // Presumably true when the reference has been set — confirm.
        operator bool() const;
        // Defined out of line. Presumably folds `other` into this reference.
        void Merge(ShaderRef const& other);
    };

    // Owning copy of a hit group description: the export name of the group,
    // its type, and the names of the shaders it imports. The field names
    // follow those of D3D12_HIT_GROUP_DESC, with the strings held by value.
    struct HitGroupDesc
    {
        std::wstring hitGroupExport;
        D3D12_HIT_GROUP_TYPE type{D3D12_HIT_GROUP_TYPE_TRIANGLES};
        std::wstring anyHitShaderImport;
        std::wstring closestHitShaderImport;
        std::wstring intersectionShaderImport;

        // True when at least one of the four name strings is non-empty.
        // `type` is not considered, so a default-constructed value is false.
        // NOTE(review): implicit (non-explicit) conversion; kept as-is so
        // existing callers continue to compile.
        operator bool() const
        {
            return !hitGroupExport.empty() || !anyHitShaderImport.empty() || !closestHitShaderImport.empty() || !intersectionShaderImport.empty();
        }

        // Field-by-field equality over all five members, in declaration order.
        // Written as direct member comparisons so that this header needs no
        // standard header beyond <string> for it.
        bool operator==(const HitGroupDesc& rhs) const noexcept
        {
            return hitGroupExport == rhs.hitGroupExport &&
                   type == rhs.type &&
                   anyHitShaderImport == rhs.anyHitShaderImport &&
                   closestHitShaderImport == rhs.closestHitShaderImport &&
                   intersectionShaderImport == rhs.intersectionShaderImport;
        }
    };

    // One entry of ParsedData::exports: a shader reference, a hit group
    // description and a local root signature pointer.
    // NOTE(review): presumably either `shader` or `hitGroup` is populated
    // depending on what the export is — confirm against the implementation.
    struct Export
    {
        ShaderRef shader = {};
        HitGroupDesc hitGroup;
        // Raw pointer: this struct does not itself hold a COM reference.
        ID3D12RootSignature* localRootSignature = nullptr;

        // NOTE(review): implicit (non-explicit) conversion; defined out of line.
        operator bool() const;
        // Defined out of line. Presumably folds `other` into this export.
        void Merge(Export const& other);
    };

    // Result of parsing a state object description.
    struct ParsedData
    {
        // Shader bytecode descriptors of the libraries that were encountered.
        // D3D12_SHADER_BYTECODE is a pointer/length pair, so the bytecode
        // itself is not copied into this vector.
        std::vector<D3D12_SHADER_BYTECODE> libraries;
        // Root signatures held through CComPtr, i.e. with a COM reference.
        // NOTE(review): by name, these are root signatures extracted from
        // shader bytecode — confirm against the implementation.
        std::vector<CComPtr<ID3D12RootSignature>> shaderExtractedRootSignatures;
        // Exports keyed by a wide string (presumably the export name).
        std::map<std::wstring, Export> exports;
        // Raw pointer: no COM reference is held by this struct.
        ID3D12RootSignature* globalRootSignature = nullptr;

        // Defined out of line.
        // NOTE(review): presumably merges `source` into this object, with
        // `exportDescs[0..numExports)` optionally selecting/renaming exports;
        // the defaults (0, nullptr) would then mean "no export list given".
        // Confirm against the implementation.
        void Add(ParsedData const& source, UINT numExports = 0, D3D12_EXPORT_DESC const* exportDescs = nullptr);
    };

    // Callback that maps an existing state object to its previously parsed
    // data. Returns a pointer to const ParsedData.
    // NOTE(review): whether nullptr is a valid return (unknown collection) is
    // not visible from this header.
    using ExistingCollectionCallbackFunc = std::function<ParsedData const*(ID3D12StateObject*)>;

public:
    // Default construction leaves mDevice null and the callback empty.
    StateObjectParser() = default;
    // Defined out of line; takes the device and the existing-collection callback.
    StateObjectParser(ID3D12Device* device, ExistingCollectionCallbackFunc existingCollectionCallback);
    // Parses `desc` and returns the collected data by value.
    ParsedData Parse(D3D12_STATE_OBJECT_DESC const& desc);

private:
    // Per-subobject handlers, one for each supported subobject description type.
    // NOTE(review): presumably dispatched from Parse() — defined out of line.
    void OnExistingCollection(D3D12_EXISTING_COLLECTION_DESC const& desc);
    void OnDxilLibrary(D3D12_DXIL_LIBRARY_DESC const& desc);
    void OnHitGroup(D3D12_HIT_GROUP_DESC const& desc);
    void OnGlobalRootSignature(D3D12_GLOBAL_ROOT_SIGNATURE const& desc);
    void OnLocalRootSignature(D3D12_LOCAL_ROOT_SIGNATURE const& desc);
    void OnSubobjectToExportsAssociation(D3D12_SUBOBJECT_TO_EXPORTS_ASSOCIATION const& desc);
    void OnDxilSubobjectToExportsAssociation(D3D12_DXIL_SUBOBJECT_TO_EXPORTS_ASSOCIATION const& desc);

    // Internal helpers, defined out of line.
    void Reset();
    void ResolveAssociations();
    void ResolveHitGroup(Export& hitGroupExport, LPCWSTR importName) const;

private:
    // Data accumulated for the description being parsed.
    ParsedData mParsedData;
    // Lookup for already-parsed collections; empty after default construction.
    ExistingCollectionCallbackFunc mExistingCollectionCallback;
    // Local root signatures tracked by raw pointer (no COM reference held).
    // NOTE(review): by name, those not yet tied to an export — confirm.
    std::set<ID3D12RootSignature*> mUnassociatedLocalRootSignatures;
    ID3D12RootSignature* mDefaultLocalRootSignature = nullptr;
    // Raw device pointer; null after default construction.
    ID3D12Device* mDevice = nullptr;
};

}  // namespace d3d12_state_tracker
}  // namespace gpa