Program Listing for File interval-tree.h

Return to documentation for file (include\playback\resource-info\detail\interval-tree.h)

/******************************************************************************

© Intel Corporation.

This software and the related documents are Intel copyrighted materials,
and your use of them is governed by the express license under which they
were provided to you ("License"). Unless the License provides otherwise,
you may not use, modify, copy, publish, distribute, disclose or transmit
this software or the related documents without Intel's prior written
permission.

 This software and the related documents are provided as is, with no express
or implied warranties, other than those that are expressly stated in the
License.

******************************************************************************/

#pragma once

#include <algorithm>
#include <cassert>
#include <cstdint>
#include <memory>
#include <tuple>
#include <utility>
#include <vector>

namespace gpa {
namespace playback {
namespace detail {

/// An interval described by its two end points: `first` is the start and
/// `second` is the end. Both end points are treated as inclusive by Overlap().
template<typename EndPointType>
using Interval = std::pair<EndPointType, EndPointType>;

/// Returns true when @p lhs and @p rhs share at least one point.
/// Intervals that merely touch (lhs.second == rhs.first) count as overlapping.
template<typename EndPointType>
bool Overlap(Interval<EndPointType> const& lhs, Interval<EndPointType> const& rhs)
{
    return lhs.first <= rhs.second && lhs.second >= rhs.first;
}

/// Associative container mapping intervals to values.
///
/// Nodes are kept in a binary search tree ordered by the std::pair comparison
/// of their interval and rebalanced on insertion with AVL-style balance
/// factors. Each node additionally caches the largest end point stored in its
/// subtree so overlap queries can skip subtrees that cannot match.
///
/// The tree is move-only and offers no removal of individual intervals.
template<typename EndPointType, typename ValueType>
class IntervalTree final
{
public:
    IntervalTree() = default;

    IntervalTree(IntervalTree<EndPointType, ValueType> const&) = delete;
    IntervalTree& operator=(IntervalTree<EndPointType, ValueType> const&) = delete;

    /// Takes over the nodes of @p other, leaving it empty.
    IntervalTree(IntervalTree<EndPointType, ValueType>&& other) noexcept
        : mRoot(std::exchange(other.mRoot, nullptr))
    {
    }

    /// Releases the nodes currently owned by this tree and takes over the
    /// nodes of @p other, leaving it empty. Self-assignment is a no-op.
    IntervalTree<EndPointType, ValueType>& operator=(IntervalTree<EndPointType, ValueType>&& other) noexcept
    {
        if (this != &other) {
            // Free the tree we already own; otherwise it would be leaked.
            delete mRoot;
            mRoot = std::exchange(other.mRoot, nullptr);
        }
        return *this;
    }

    ~IntervalTree()
    {
        delete mRoot;
    }

    /// Same as operator[] for the degenerate interval [point, point].
    ValueType& operator[](EndPointType const& point)
    {
        return operator[](Interval<EndPointType>{point, point});
    }

    /// Returns the value stored for @p interval, inserting a
    /// value-initialized ValueType first when the interval is not present.
    /// Nodes are individually heap allocated and only relinked by rotations,
    /// so the returned reference is not invalidated by later insertions.
    ValueType& operator[](Interval<EndPointType> const& interval)
    {
        bool added = false;
        Node* itr = nullptr;
        // Add() keeps mMax up to date for every node on the insertion path,
        // including the (possibly new) root it returns.
        mRoot = Node::Add(mRoot, interval, added, &itr);
        assert(itr && "itr must not be null after call to Add()");
        return itr->mValue;
    }

    /// Same as the interval overload for the degenerate interval
    /// [point, point].
    template<typename EnumerationFunctionType>
    void EnumerateOverlappingIntervals(
        EndPointType point,
        EnumerationFunctionType enumerationFunction) const
    {
        if (mRoot) {
            Interval<EndPointType> const interval{point, point};
            Node::EnumerateOverlappingIntervals(mRoot, interval, enumerationFunction);
        }
    }

    /// Calls `enumerationFunction(storedInterval, storedValue)` for every
    /// stored interval that overlaps @p interval (see Overlap()), in ascending
    /// order of the stored intervals. Both arguments are passed as const
    /// references.
    template<typename EnumerationFunctionType>
    void EnumerateOverlappingIntervals(
        Interval<EndPointType> const& interval,
        EnumerationFunctionType enumerationFunction) const
    {
        if (mRoot) {
            Node::EnumerateOverlappingIntervals(mRoot, interval, enumerationFunction);
        }
    }

private:
    /// One stored interval together with its value and balancing metadata.
    /// A node owns its two children and deletes them on destruction.
    class Node final
    {
    public:
        /// Creates a leaf holding @p interval and a value-initialized value.
        explicit Node(Interval<EndPointType> const& interval)
            : mInterval(interval)
            , mMax(interval.second)
        {
        }

        Node(Node const&) = delete;
        Node& operator=(Node const&) = delete;

        ~Node()
        {
            delete mLeft;
            delete mRight;
        }

        /// In-order walk of the subtree rooted at @p node that reports every
        /// stored interval overlapping @p interval. The callable is taken by
        /// reference so the recursion never copies it.
        template<typename EnumerationFunctionType>
        static void EnumerateOverlappingIntervals(
            Node const* node,
            Interval<EndPointType> const& interval,
            EnumerationFunctionType& enumerationFunction)
        {
            assert(node && "EnumerateOverlappingIntervals() must not be called with nullptr for node");
            // Every interval in this subtree ends before the query starts.
            if (interval.first > node->mMax) {
                return;
            }
            if (node->mLeft) {
                EnumerateOverlappingIntervals(node->mLeft, interval, enumerationFunction);
            }
            // This node and everything to its right (which sorts after it, so
            // starts no earlier) begins after the query ends.
            if (interval.second < node->mInterval.first) {
                return;
            }
            if (Overlap(node->mInterval, interval)) {
                enumerationFunction(node->mInterval, node->mValue);
            }
            if (node->mRight) {
                EnumerateOverlappingIntervals(node->mRight, interval, enumerationFunction);
            }
        }

        /// Inserts @p interval into the subtree rooted at @p node (which may
        /// be nullptr) unless it is already present.
        ///
        /// @param added  Set to true when a node was created; on the way back
        ///               up it stays true only while the height of the
        ///               subtree being returned has grown.
        /// @param itr    Receives the node holding @p interval.
        /// @return The root of the subtree after insertion and rebalancing.
        static Node* Add(
            Node* node,
            Interval<EndPointType> const& interval,
            bool& added,
            Node** itr)
        {
            assert(itr && "Add() must not be called with nullptr for itr");
            if (!node) {
                node = new Node(interval);
                added = true;
                *itr = node;
            } else {
                if (interval < node->mInterval) {
                    node->mLeft = Add(node->mLeft, interval, added, itr);
                    if (added) {
                        --node->mBalance;
                        if (node->mBalance == 0) {
                            // The shorter side caught up; height is unchanged.
                            added = false;
                        } else if (node->mBalance == -2) {
                            if (node->mLeft->mBalance == 1) {
                                // Left-right case: double rotation.
                                int const leftRightBalance = node->mLeft->mRight->mBalance;
                                node->mLeft = RotateLeft(node->mLeft);
                                node = RotateRight(node);
                                node->mBalance = 0;
                                node->mLeft->mBalance = leftRightBalance == 1 ? -1 : 0;
                                node->mRight->mBalance = leftRightBalance == -1 ? 1 : 0;
                            } else if (node->mLeft->mBalance == -1) {
                                // Left-left case: single rotation.
                                node = RotateRight(node);
                                node->mBalance = 0;
                                node->mRight->mBalance = 0;
                            }
                            added = false;
                        }
                    }
                } else if (interval == node->mInterval) {
                    *itr = node;
                } else if (interval > node->mInterval) {
                    node->mRight = Add(node->mRight, interval, added, itr);
                    if (added) {
                        ++node->mBalance;
                        if (node->mBalance == 0) {
                            // The shorter side caught up; height is unchanged.
                            added = false;
                        } else if (node->mBalance == 2) {
                            if (node->mRight->mBalance == -1) {
                                // Right-left case: double rotation.
                                int const rightLeftBalance = node->mRight->mLeft->mBalance;
                                node->mRight = RotateRight(node->mRight);
                                node = RotateLeft(node);
                                node->mBalance = 0;
                                node->mLeft->mBalance = rightLeftBalance == 1 ? -1 : 0;
                                node->mRight->mBalance = rightLeftBalance == -1 ? 1 : 0;
                            } else if (node->mRight->mBalance == 1) {
                                // Right-right case: single rotation.
                                node = RotateLeft(node);
                                node->mBalance = 0;
                                node->mLeft->mBalance = 0;
                            }
                            added = false;
                        }
                    }
                }
                ComputeMax(node);
            }
            assert(node && "Add() must not return nullptr");
            return node;
        }

        /// Recomputes node->mMax as the largest end point among the node's
        /// own interval and its children's cached maxima. The children must
        /// already be up to date.
        static void ComputeMax(Node* node)
        {
            assert(node && "ComputeMax() must not be called with nullptr for node");
            // Start from the node's own end point rather than its previous
            // mMax, which may still include a subtree that a rotation moved
            // elsewhere.
            EndPointType maxEnd = node->mInterval.second;
            if (node->mLeft) {
                maxEnd = std::max(maxEnd, node->mLeft->mMax);
            }
            if (node->mRight) {
                maxEnd = std::max(maxEnd, node->mRight->mMax);
            }
            node->mMax = maxEnd;
        }

        /// Rotates the subtree rooted at @p node to the left and returns the
        /// new subtree root. Balance factors are left for the caller to fix.
        static Node* RotateLeft(Node* node)
        {
            /*
                  A (node)       B
                   \            /
                    B      ->  A
                   /            \
                  C              C
            */
            assert(node && "RotateLeft() must not be called with nullptr for node");
            Node* const right = node->mRight;
            assert(right && "RotateLeft() requires node to have a right child");
            node->mRight = right->mLeft;
            // The demoted node first: the new root's maximum depends on it.
            ComputeMax(node);
            right->mLeft = node;
            ComputeMax(right);
            return right;
        }

        /// Rotates the subtree rooted at @p node to the right and returns the
        /// new subtree root. Balance factors are left for the caller to fix.
        static Node* RotateRight(Node* node)
        {
            /*      A (node)         B
                   /                  \
                  B          ->        A
                   \                  /
                    C                C
            */
            assert(node && "RotateRight() must not be called with nullptr for node");
            Node* const left = node->mLeft;
            assert(left && "RotateRight() requires node to have a left child");
            node->mLeft = left->mRight;
            // The demoted node first: the new root's maximum depends on it.
            ComputeMax(node);
            left->mRight = node;
            ComputeMax(left);
            return left;
        }

        Interval<EndPointType> mInterval;  // key of this node
        ValueType mValue{};                // payload returned by operator[]
        EndPointType mMax;                 // largest end point in this subtree
        Node* mLeft{nullptr};              // owned; intervals ordered before mInterval
        Node* mRight{nullptr};             // owned; intervals ordered after mInterval
        int mBalance{0};                   // height(right) - height(left)
    };

    Node* mRoot{nullptr};  // owned root of the tree, nullptr when empty
};

}  // namespace detail
}  // namespace playback
}  // namespace gpa