Program Listing for File sync-tracker.h

Return to documentation for file (include\playback\sync-tracker.h)

/******************************************************************************
© Intel Corporation.

This software and the related documents are Intel copyrighted materials,
and your use of them is governed by the express license under which they
were provided to you ("License"). Unless the License provides otherwise,
you may not use, modify, copy, publish, distribute, disclose or transmit
this software or the related documents without Intel's prior written
permission.


 This software and the related documents are provided as is, with no express
or implied warranties, other than those that are expressly stated in the
License.

******************************************************************************/
#pragma once

#include "playback/sync-handler-common.h"

#include <cassert>
#include <cstdint>
#include <deque>
#include <functional>
#include <iterator>
#include <map>
#include <memory>
#include <string>
#include <utility>
#include <vector>

namespace gpa {
namespace playback {
namespace sync_handler {

class FenceTracker
{
public:
    struct FenceState
    {
        uint64_t value = 0;
    };

    FenceState& GetState(uint64_t fence);
    void CreateSharedHandle(uint64_t object, wchar_t const* name, uint64_t handle);
    void CreateSharedHandleByName(wchar_t const* name, uint64_t handle);
    void OpenSharedHandle(uint64_t handle, uint64_t object);

private:
    std::deque<FenceState> mStates;
    std::map<uint64_t, FenceState*> mFenceToState;
    std::map<uint64_t, FenceState*> mHandleToState;
    std::map<std::wstring, FenceState*> mNameToState;
};

template<typename CallableType>
class SyncTracker
{
    struct QueuePendingCall
    {
        CallableType callable;
        SyncCommand syncCommand;
    };

    struct QueueWaitState
    {
        uint64_t fenceUid;
        uint64_t fenceWaitValue;
        std::vector<QueuePendingCall> pendingCalls;
    };

    struct QueueState
    {
        std::deque<QueueWaitState> waits;
    };

    struct SyncStateUpdate
    {
        bool shouldPlay;
        std::vector<QueuePendingCall> unlockedCallables;
    };

public:
    using CallableCallback = std::function<void(CallableType& callable)>;

    SyncTracker(){};

    bool IsQueueLocked(uint64_t queueId) const
    {
        auto it = mQueueStates.find(queueId);
        if (it == mQueueStates.end()) {
            return false;
        }
        return !it->second.waits.empty();
    }

    void EnumerateLockedCalls(CallableCallback callback)
    {
        for (auto& queueState : mQueueStates) {
            for (auto& wait : queueState.second.waits) {
                for (auto& call : wait.pendingCalls) {
                    callback(call.callable);
                }
            }
        }
    }

    void EnumerateUnlockedCalls(CallableType& originalCall, SyncCommand& originalSyncCommand, CallableCallback callback = [](CallableType&) {})
    {
        std::deque<QueuePendingCall> unlockedCalls;
        auto processCall = [&](CallableType& callable, SyncCommand& syncCommand) {
            SyncStateUpdate state = this->UpdateState(callable, syncCommand);
            if (state.shouldPlay) {
                callback(callable);
                unlockedCalls.insert(unlockedCalls.begin(), state.unlockedCallables.begin(), state.unlockedCallables.end());
            } else {
                assert(state.unlockedCallables.empty());
            }
        };
        processCall(originalCall, originalSyncCommand);
        while (!unlockedCalls.empty()) {
            QueuePendingCall unlockedCall = unlockedCalls.front();
            unlockedCalls.pop_front();
            processCall(unlockedCall.callable, unlockedCall.syncCommand);
        }
    }

    SyncStateUpdate UpdateState(CallableType& call, SyncCommand& syncCommand)
    {
        auto updateQueues = [&](std::vector<QueuePendingCall>& output) {
            for (auto& queueState : mQueueStates) {
                auto& waits = queueState.second.waits;
                if (waits.empty()) {
                    continue;
                }
                auto& firstWait = waits.front();
                auto const& fenceState = mFenceTracker.GetState(firstWait.fenceUid);
                if (fenceState.value >= firstWait.fenceWaitValue) {
                    output.insert(output.end(), firstWait.pendingCalls.begin(), firstWait.pendingCalls.end());
                    waits.pop_front();
                }
            }
        };

        SyncStateUpdate result{true, {}};
        switch (syncCommand.type) {
        case SyncCommandType::kImmidiateSignal: {
            SyncImmidiateSignal& data = syncCommand.data.immidiateSignal;
            auto& fenceState = mFenceTracker.GetState(data.fenceUid);
            fenceState.value = data.signalValue;
            updateQueues(result.unlockedCallables);
            break;
        }
        case SyncCommandType::kQueueSignal: {
            SyncQueueSignal signal = syncCommand.data.queueSignal;
            QueueState& queueState = mQueueStates[signal.queueUid];
            if (queueState.waits.empty()) {
                auto& fenceState = mFenceTracker.GetState(signal.fenceUid);
                fenceState.value = signal.signalValue;
                updateQueues(result.unlockedCallables);
            } else {
                queueState.waits.back().pendingCalls.push_back({call, syncCommand});
                result.shouldPlay = false;
            }
            break;
        }
        case SyncCommandType::kQueueWait: {
            SyncQueueWait data = syncCommand.data.queueWait;
            QueueState& queueState = mQueueStates[data.queueUid];
            if (!queueState.waits.empty()) {
                queueState.waits.back().pendingCalls.push_back({call, syncCommand});
                result.shouldPlay = false;
            } else {
                auto const& fenceState = mFenceTracker.GetState(data.fenceUid);
                if (fenceState.value < data.waitValue) {
                    queueState.waits.push_back(QueueWaitState{data.fenceUid, data.waitValue, {{call, syncCommand}}});
                    result.shouldPlay = false;
                }
            }
            break;
        }
        case SyncCommandType::kCreateSharedHandle: {
            CreateSharedHandle const& data = syncCommand.data.createSharedHandle;
            if (data.objectUid == 0) {
                mFenceTracker.CreateSharedHandleByName(data.name, data.outHandle);
            } else {
                mFenceTracker.CreateSharedHandle(data.objectUid, data.name, data.outHandle);
            }
            break;
        }
        case SyncCommandType::kOpenSharedHandle: {
            auto const& data = syncCommand.data.openSharedHandle;
            mFenceTracker.OpenSharedHandle(data.handle, data.outObjectUid);
            break;
        }
        case SyncCommandType::kQueueOther: {
            SyncQueueOther data = syncCommand.data.queueOther;
            QueueState& queueState = mQueueStates[data.queueUid];
            if (!queueState.waits.empty()) {
                queueState.waits.back().pendingCalls.push_back({call, syncCommand});
                result.shouldPlay = false;
            }
            break;
        }
        default:
            break;
        }
        return result;
    }

private:
    FenceTracker mFenceTracker;
    std::map<uint64_t, QueueState> mQueueStates;
};

}  // namespace sync_handler
}  // namespace playback
}  // namespace gpa