Program Listing for File thread-pool.h

Return to documentation for file (include\concurrent\thread-pool.h)

/******************************************************************************
© Intel Corporation.

This software and the related documents are Intel copyrighted materials,
and your use of them is governed by the express license under which they
were provided to you ("License"). Unless the License provides otherwise,
you may not use, modify, copy, publish, distribute, disclose or transmit
this software or the related documents without Intel's prior written
permission.


 This software and the related documents are provided as is, with no express
or implied warranties, other than those that are expressly stated in the
License.

******************************************************************************/
#pragma once

#ifdef _WIN32
#include <Windows.h>
#endif

#include <atomic>
#include <condition_variable>
#include <cstdint>
#include <functional>
#include <future>
#include <memory>
#include <mutex>
#include <thread>
#include <utility>
#include <vector>

#include "concurrent/queue.h"

namespace gpa {
namespace concurrent {

#ifdef _WIN32

// Static facade for running fire-and-forget tasks on the thread pool provided by Windows.
// All entry points are static; the class holds no per-instance state.
class WinNativeThreadPool
{
    static DWORD WINAPI WorkItem(void* ctx);
    static size_t& GetCounter();

public:
    // Base interface for queued work: inherit your tasks from this and implement Run().
    // Task has no virtual destructor, so a derived task must never be deleted through a Task*;
    // destroy it through its own type (for example `delete this` at the end of Run()).
    struct Task
    {
        virtual void Run() = 0;
    };

    // Fastest path. This call does not manage the lifetime of `task`: the caller must keep it
    // alive until Run() has been invoked. A common pattern is to `new` the task and delete it at
    // the end of Run(), as Wrapper below does.
    static HRESULT UnsafeQueue(Task& task, DWORD dwFlags = 0);  // flags come from WT_EXECUTEDEFAULT family

    // Task that owns a copy of a functor, invokes it once, and then deletes itself.
    template<class F>
    class Wrapper : public Task
    {
    protected:
        F mCallable;

        // Protected so that only heirs can change the "new / delete this" policy:
        // https://isocpp.org/wiki/faq/freestore-mgmt#delete-this
        // `clb` is already a private copy, so it is moved (not copied again) into the member.
        Wrapper(F clb)
            : mCallable(std::move(clb))
        {
        }

    public:
        // Heap-allocates a wrapper around `clb`. The returned object deletes itself in Run().
        static Wrapper<F>* Create(F clb)
        {
            return new Wrapper<F>(std::move(clb));
        }

        // Invokes the stored functor, then destroys this wrapper. Must be called exactly once.
        void Run() override
        {
            mCallable();
            delete this;
        }
    };

    // NOTE(review): in both Queue overloads the wrapper is only released from inside Run().
    // If UnsafeQueue fails and Run() is therefore never called, the wrapper presumably leaks --
    // confirm UnsafeQueue's failure contract before adding cleanup here.

    // Fastest overload: takes only a plain function that takes no arguments.
    // Costs one heap allocation / free per call.
    template<class T>
    static HRESULT Queue(T(callable)(), DWORD dwFlags = 0)
    {
        return UnsafeQueue(*Wrapper<decltype(callable)>::Create(callable), dwFlags);
    }

    // Takes any callable (lambda, std::function, std::bind result, or anything with operator())
    // and keeps a copy of it alive throughout execution.
    template<class T>
    static HRESULT Queue(const T& callable, DWORD dwFlags = 0)
    {
        return UnsafeQueue(*Wrapper<T>::Create(callable), dwFlags);
    }

    // Returns the current value of the internal counter. Wait on this counter to make sure all
    // tasks are finished (presumably it tracks tasks still outstanding -- confirm in GetCounter()).
    static size_t Count()
    {
        return GetCounter();
    }
};
#endif

// Type-erased callable that takes no arguments and returns nothing. Used in this header for
// thread initializers and for the jobs stored in the pool's task queue.
using VoidFunction = std::function<void()>;

// Forward declaration: Thread holds a ThreadPool* before ThreadPool is defined further down.
class ThreadPool;

// A single worker belonging to a ThreadPool. Only the interface is declared here; all member
// functions are defined out of line.
class Thread
{
public:
    // States reported by State(). Judging by the names: not yet started, alive and waiting for
    // work, alive and running a job, and terminated -- confirm against the implementation.
    enum ThreadState {
        kUninitialized,
        kAliveIdle,
        kAliveBusy,
        kDead
    };

    // parentPool:  the pool this worker belongs to (presumably kept in mParentPool).
    // initializer: optional callable, empty by default; presumably run on the worker thread
    //              before it starts taking jobs -- TODO confirm in the implementation.
    Thread(ThreadPool* parentPool, VoidFunction initializer = nullptr);

    ~Thread();

    // Presumably asks the worker to stop (see mStayAlive) without blocking -- confirm.
    void RequestShutdown();
    // Presumably blocks until the worker thread has exited -- confirm.
    void WaitForShutdown();
    // Returns the worker's current ThreadState.
    ThreadState State();

private:
    // Presumably the body run by mThread -- confirm in the implementation.
    void Execute();

    ThreadPool* mParentPool;

    // Atomics because they are shared between the owning side and the worker thread.
    std::atomic<bool> mStayAlive;
    std::atomic<ThreadState> mState;
    VoidFunction mInitializer;
    // NOTE(review): keep mThread declared last. Members are initialized in declaration order, so
    // if the constructor launches the thread from its initializer list, everything above must
    // already be constructed.
    std::thread mThread;
};

// One entry in ThreadPool's task queue. Plain aggregate: ThreadPool builds it with brace
// initialization as {job, optional}.
struct ThreadPoolTask
{
    // The work to execute.
    VoidFunction job;
    // true for jobs queued through ThreadPool::SubmitOptionalJob, false for SubmitJob and
    // RunWhenFinished. Presumably what ThreadPool::CancelOptionalJobs filters on -- confirm.
    bool optional;
};

// Pool of worker Thread objects that execute jobs taken from a shared task queue.
// A job is any callable plus its arguments; every submission returns a std::future for the job's
// result. Member functions that are only declared here are defined out of line.
class ThreadPool
{
private:
    // Binds `function` to `args` and wraps the resulting nullary callable in a heap-allocated
    // std::packaged_task. The task is shared so that it can be captured by a copyable
    // VoidFunction while the caller still holds it to fetch the future.
    template<typename Fun, typename... Args>
    static auto PackageJob(Fun&& function, Args&&... args) -> std::shared_ptr<std::packaged_task<decltype(function(args...))()>>
    {
        using Result = decltype(function(args...));
        return std::make_shared<std::packaged_task<Result()>>(
            std::bind(std::forward<Fun>(function), std::forward<Args>(args)...));
    }

    // Shared implementation of SubmitJob / SubmitOptionalJob.
    //
    // optional: stored alongside the job in the queue entry.
    // Returns the future for the job's result, or an invalid (default-constructed) future when
    // the pool no longer accepts jobs.
    template<typename Fun, typename... Args>
    auto SubmitJobImpl(bool optional, Fun&& function, Args&&... args) -> std::future<decltype(function(args...))>
    {
        using Result = decltype(function(args...));

        if (!mAcceptNewJobs) {
            // return invalid future object
            return std::future<Result>();
        }

        auto taskPtr = PackageJob(std::forward<Fun>(function), std::forward<Args>(args)...);

        // Fetch the future BEFORE the task becomes visible to the workers: once it is in the
        // queue another thread may already be invoking it, and get_future() must not run
        // concurrently with operator() on the same packaged_task.
        auto result = taskPtr->get_future();

        // packaged_task stores any exception thrown by the job in the future, so the finished
        // counter is incremented even when the job fails.
        VoidFunction wrappedFunction = [taskPtr, this]() {
            (*taskPtr)();
            this->IncrementFinishedJobsCount();
        };
        mTaskQueue.Push({std::move(wrappedFunction), optional});
        IncrementSubmittedJobsCount();
        mJobAvailable.notify_one();
        return result;
    }

public:
    // threadCount: number of workers; the all-ones default is presumably a sentinel meaning
    //              "choose automatically" -- confirm in the implementation.
    // initializer: optional callable handed to the workers (see Thread).
    ThreadPool(uint32_t threadCount = uint32_t(~0x0), VoidFunction initializer = nullptr);
    ~ThreadPool();

    uint32_t ThreadCount();
    uint32_t ThreadCountWithState(Thread::ThreadState state);
    void AddThreads(uint32_t threadCount, VoidFunction initializer = nullptr);

    // Queues `function(args...)` as a regular (non-optional) job.
    // Arguments are perfectly forwarded, so rvalues are moved into the job rather than copied.
    // Returns a future for the result; the future is invalid if the pool rejects new jobs.
    template<typename Fun, typename... Args>
    auto SubmitJob(Fun&& function, Args&&... args) -> std::future<decltype(function(args...))>
    {
        return SubmitJobImpl(false, std::forward<Fun>(function), std::forward<Args>(args)...);
    }

    // Same as SubmitJob, but the queue entry is flagged as optional.
    template<typename Fun, typename... Args>
    auto SubmitOptionalJob(Fun&& function, Args&&... args) -> std::future<decltype(function(args...))>
    {
        return SubmitJobImpl(true, std::forward<Fun>(function), std::forward<Args>(args)...);
    }

    void WaitTillFinished();

    // Queues a barrier-style job: when a worker picks it up, it snapshots the number of jobs
    // submitted so far and spins (yielding) until at least that many jobs have finished, then
    // runs `function(args...)`.
    // The barrier job itself is not counted as submitted or finished, and it occupies a worker
    // thread while it waits.
    // Returns a future for the result; the future is invalid if the pool rejects new jobs.
    template<typename Fun, typename... Args>
    auto RunWhenFinished(Fun&& function, Args&&... args) -> std::future<decltype(function(args...))>
    {
        using Result = decltype(function(args...));

        if (!mAcceptNewJobs) {
            // return invalid future object
            return std::future<Result>();
        }

        auto taskPtr = PackageJob(std::forward<Fun>(function), std::forward<Args>(args)...);

        // As in SubmitJobImpl: take the future before a worker can start running the task.
        auto result = taskPtr->get_future();

        VoidFunction wrappedFunction = [taskPtr, this]() {
            const uint64_t submittedTillNow = this->mSubmittedJobsCount;
            while (this->mFinishedJobsCount < submittedTillNow) {
                std::this_thread::yield();
            }
            (*taskPtr)();
        };
        mTaskQueue.Push({std::move(wrappedFunction), false});
        mJobAvailable.notify_one();
        return result;
    }

    void Shutdown();

    void CancelOptionalJobs();

private:
    void KillAllThreads();

    SimpleConcurrentQueue<ThreadPoolTask> mTaskQueue;
    std::vector<Thread*> mThreads;
    // Job accounting used by RunWhenFinished (and presumably WaitTillFinished).
    std::atomic<uint64_t> mSubmittedJobsCount = {0};
    std::atomic<uint64_t> mFinishedJobsCount = {0};
    // Signalled once per queued job to wake a waiting worker.
    std::condition_variable mJobAvailable;
    std::mutex mJobMutex;
    // While false, submissions are rejected and return an invalid future.
    std::atomic_bool mAcceptNewJobs = {true};

    friend class Thread;

    void IncrementFinishedJobsCount();
    void IncrementSubmittedJobsCount();
};

}  // namespace concurrent
}  // namespace gpa