From 67e40a5d4d3b4e7ed6823e9a768cea48cfd6cfe6 Mon Sep 17 00:00:00 2001
From: Jonathan Bradley
Date: Wed, 29 Nov 2023 20:58:54 -0500
Subject: add thread_pool

---
 src/thread_pool.cpp | 202 ++++++++++++++++++++++++++++++++++++++++++++++++++++
 src/thread_pool.hpp |  22 ++++
 2 files changed, 224 insertions(+)
 create mode 100644 src/thread_pool.cpp
 create mode 100644 src/thread_pool.hpp

diff --git a/src/thread_pool.cpp b/src/thread_pool.cpp
new file mode 100644
index 0000000..d3073dc
--- /dev/null
+++ b/src/thread_pool.cpp
@@ -0,0 +1,202 @@
+
+#include "thread_pool.hpp"
+
+#include <cassert>
+#include <functional>
+
+TypeSafeInt_B(ThreadPoolHandle);
+
+struct ThreadPool {
+  bool isRunning;
+  bool isPaused;
+  uint8_t maxJobQueueCount;
+  std::mutex mutex;
+  std::atomic<uint64_t> completedCount;
+  std::condition_variable condition;
+  DynArray<std::packaged_task<void()> *> *jobQueue;
+  DynArray<std::thread> *threads;
+  MemBucket *bkt = nullptr;
+};
+
+struct ThreadBucket {
+  ThreadPool threadPools[8];
+};
+
+BucketContainer<ThreadBucket> ThreadPool_BucketContainer{};
+
+// Worker loop: wait for a job, pop it off the queue, run it, and count it as completed.
+void ThreadRun(ThreadPool *tp) {
+  std::packaged_task<void()> *j = nullptr;
+  while (tp->isRunning && !tp->isPaused) {
+    {
+      std::unique_lock<std::mutex> lck(tp->mutex);
+      tp->condition.wait(lck, [tp] {
+        return tp->jobQueue->Count() != 0 || !tp->isRunning || tp->isPaused;
+      });
+      if (!tp->isRunning || tp->isPaused) {
+        return;
+      }
+      if (tp->jobQueue->Count() == 0) {
+        continue;
+      }
+      j = (*tp->jobQueue)[0];
+      tp->jobQueue->Remove(0, 1);
+    }
+    assert(j != nullptr);
+    (*j)();
+    Pke_Delete<std::packaged_task<void()>>(j, tp->bkt);
+    tp->completedCount += 1;
+  }
+}
+
+void inline PkeThreads_JoinAll_Inner(ThreadPool &tp) {
+  tp.condition.notify_all();
+  long count = tp.threads->Count();
+  for (long l = 0; l < count; ++l) {
+    auto &t = (*tp.threads)[l];
+    if (t.joinable()) {
+      t.join();
+    }
+  }
+}
+
+void inline PkeThreads_Reset_Inner(ThreadPool &tp) {
+  tp.maxJobQueueCount = 0;
+  tp.completedCount = 0;
+  tp.jobQueue->Resize(0);
+  tp.threads->Resize(0);
+}
+
+bool inline PkeThreads_Enqueue_Inner(ThreadPool &tp, std::packaged_task<void()> *job) {
+  tp.mutex.lock();
+  if (tp.isRunning == true) {
+    if (tp.jobQueue->Count() < tp.maxJobQueueCount) {
+      tp.jobQueue->Push(job);
+      tp.condition.notify_one();
+      tp.mutex.unlock();
+      return true;
+    }
+  }
+  tp.mutex.unlock();
+  return false;
+}
+
+void inline PkeThreads_Pause_Inner(ThreadPool &tp) {
+  tp.mutex.lock();
+  if (tp.isPaused == true) {
+    tp.mutex.unlock();
+    return; // called more than once
+  }
+  tp.isPaused = true;
+  tp.mutex.unlock();
+  PkeThreads_JoinAll_Inner(tp);
+}
+
+void inline PkeThreads_Resume_Inner(ThreadPool &tp) {
+  tp.mutex.lock();
+  tp.isPaused = false;
+  long count = tp.threads->Count();
+  for (long i = 0; i < count; i++) {
+    (*tp.threads)[i] = std::thread(std::bind(ThreadRun, &tp));
+  }
+  tp.mutex.unlock();
+}
+
+void inline PkeThreads_Shutdown_Inner(ThreadPool &tp) {
+  tp.mutex.lock();
+  if (tp.isRunning == false) {
+    tp.mutex.unlock();
+    return;
+  }
+  tp.isRunning = false;
+  tp.isPaused = false;
+  tp.jobQueue->Resize(0);
+  tp.mutex.unlock();
+  PkeThreads_JoinAll_Inner(tp);
+}
+
+ThreadPoolHandle PkeThreads_Init(uint8_t threadCount, uint8_t maxQueueCount, MemBucket *bkt) {
+  if (!ThreadPool_BucketContainer.buckets) {
+    Buckets_Init(ThreadPool_BucketContainer);
+  }
+  assert(threadCount > 0);
+  ThreadPoolHandle_T newHandle{Buckets_NewHandle(255, ThreadPool_BucketContainer)};
+
+  auto b = Buckets_GetBucketIndex(newHandle);
+  auto i = Buckets_GetItemIndex(newHandle);
+  auto *tp = &ThreadPool_BucketContainer.buckets[b].threadPools[i];
+
+  tp->bkt = bkt;
+  tp->isRunning = true;
+  tp->isPaused = false;
+  tp->maxJobQueueCount = maxQueueCount;
+  tp->completedCount = 0;
+  tp->jobQueue = Pke_New<DynArray<std::packaged_task<void()> *>>(bkt);
+  tp->threads = Pke_New<DynArray<std::thread>>(bkt);
+
+  tp->threads->Resize(threadCount);
+  for (long l = 0; l < threadCount; ++l) {
+    (*tp->threads)[l] = std::thread(std::bind(ThreadRun, tp));
+  }
+
+  return ThreadPoolHandle{newHandle};
+}
+
+void PkeThreads_Reset(ThreadPoolHandle handle) {
+  ThreadPoolHandle_T handle_T{static_cast<uint64_t>(handle)};
+  auto b = Buckets_GetBucketIndex(handle_T);
+  auto i = Buckets_GetItemIndex(handle_T);
+  auto *tp = &ThreadPool_BucketContainer.buckets[b].threadPools[i];
+  PkeThreads_Reset_Inner(*tp);
+}
+
+bool PkeThreads_Enqueue(ThreadPoolHandle handle, std::packaged_task<void()> job) {
+  ThreadPoolHandle_T handle_T{static_cast<uint64_t>(handle)};
+  auto b = Buckets_GetBucketIndex(handle_T);
+  auto i = Buckets_GetItemIndex(handle_T);
+  auto *tp = &ThreadPool_BucketContainer.buckets[b].threadPools[i];
+
+  // The worker frees this heap copy after running it (see ThreadRun).
+  auto jobPtr = Pke_New<std::packaged_task<void()>>(tp->bkt);
+  *jobPtr = std::move(job);
+
+  return PkeThreads_Enqueue_Inner(*tp, jobPtr);
+}
+
+void PkeThreads_Pause(ThreadPoolHandle handle) {
+  ThreadPoolHandle_T handle_T{static_cast<uint64_t>(handle)};
+  auto b = Buckets_GetBucketIndex(handle_T);
+  auto i = Buckets_GetItemIndex(handle_T);
+  auto *tp = &ThreadPool_BucketContainer.buckets[b].threadPools[i];
+
+  PkeThreads_Pause_Inner(*tp);
+}
+
+void PkeThreads_Resume(ThreadPoolHandle handle) {
+  ThreadPoolHandle_T handle_T{static_cast<uint64_t>(handle)};
+  auto b = Buckets_GetBucketIndex(handle_T);
+  auto i = Buckets_GetItemIndex(handle_T);
+  auto *tp = &ThreadPool_BucketContainer.buckets[b].threadPools[i];
+
+  PkeThreads_Resume_Inner(*tp);
+}
+
+void PkeThreads_Shutdown(ThreadPoolHandle handle) {
+  ThreadPoolHandle_T handle_T{static_cast<uint64_t>(handle)};
+  auto b = Buckets_GetBucketIndex(handle_T);
+  auto i = Buckets_GetItemIndex(handle_T);
+  auto *tp = &ThreadPool_BucketContainer.buckets[b].threadPools[i];
+
+  PkeThreads_Shutdown_Inner(*tp);
+}
+
+void PkeThreads_Teardown(ThreadPoolHandle handle) {
+  ThreadPoolHandle_T handle_T{static_cast<uint64_t>(handle)};
+  auto b = Buckets_GetBucketIndex(handle_T);
+  auto i = Buckets_GetItemIndex(handle_T);
+  auto *tp = &ThreadPool_BucketContainer.buckets[b].threadPools[i];
+
+  PkeThreads_Shutdown_Inner(*tp);
+  PkeThreads_Reset_Inner(*tp);
+  Pke_Delete<DynArray<std::packaged_task<void()> *>>(tp->jobQueue, tp->bkt);
+  Pke_Delete<DynArray<std::thread>>(tp->threads, tp->bkt);
+  tp->jobQueue = CAFE_BABE(DynArray<std::packaged_task<void()> *>);
+  tp->threads = CAFE_BABE(DynArray<std::thread>);
+  tp->bkt = CAFE_BABE(MemBucket);
+}
diff --git a/src/thread_pool.hpp b/src/thread_pool.hpp
new file mode 100644
index 0000000..12962dd
--- /dev/null
+++ b/src/thread_pool.hpp
@@ -0,0 +1,22 @@
+#ifndef PKE_THREADING_HPP
+#define PKE_THREADING_HPP
+
+#include "dynamic-array.hpp"
+#include "macros.hpp"
+
+#include <condition_variable>
+#include <future>
+#include <mutex>
+#include <thread>
+
+TypeSafeInt_H(ThreadPoolHandle, uint64_t, 0xFFFFFFFFFFFFFFFF);
+
+ThreadPoolHandle PkeThreads_Init (uint8_t threadCount, uint8_t maxQueueCount, MemBucket *bkt = nullptr);
+void PkeThreads_Reset (ThreadPoolHandle handle);
+bool PkeThreads_Enqueue (ThreadPoolHandle handle, std::packaged_task<void()> job);
+void PkeThreads_Pause (ThreadPoolHandle handle);
+void PkeThreads_Resume (ThreadPoolHandle handle);
+void PkeThreads_Shutdown (ThreadPoolHandle handle);
+void PkeThreads_Teardown (ThreadPoolHandle handle);
+
+#endif /* PKE_THREADING_HPP */
--
cgit v1.2.3
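
For reference, a minimal usage sketch of the API above; it is not part of the commit and assumes the reconstructed std::packaged_task<void()> job signature, the default MemBucket, and that PkeThreads_Enqueue returns false when the queue is full.

#include "thread_pool.hpp"

#include <cstdio>
#include <future>
#include <utility>

int main() {
  // 4 worker threads, at most 16 queued jobs, default allocator bucket.
  ThreadPoolHandle pool = PkeThreads_Init(4, 16);

  std::packaged_task<void()> job([] { std::puts("hello from the pool"); });
  std::future<void> done = job.get_future();  // grab the future before the task is moved in

  if (PkeThreads_Enqueue(pool, std::move(job))) {
    done.wait();                              // block until a worker has run the job
  }

  PkeThreads_Shutdown(pool);                  // stop workers and drop any pending jobs
  PkeThreads_Teardown(pool);                  // free the queue and thread storage
  return 0;
}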