summaryrefslogtreecommitdiff
path: root/src
diff options
context:
space:
mode:
authorJonathan Bradley <jcb@pikum.xyz>2023-11-29 20:58:54 -0500
committerJonathan Bradley <jcb@pikum.xyz>2023-11-29 20:58:54 -0500
commit67e40a5d4d3b4e7ed6823e9a768cea48cfd6cfe6 (patch)
tree366aafc44b5c79ef7f8f0898e50f539293c9ffba /src
parentca78d2bdc7ac42577d14cc199ca082dbe7cd1388 (diff)
add thread_pool
Diffstat (limited to 'src')
-rw-r--r--src/thread_pool.cpp202
-rw-r--r--src/thread_pool.hpp22
2 files changed, 224 insertions, 0 deletions
diff --git a/src/thread_pool.cpp b/src/thread_pool.cpp
new file mode 100644
index 0000000..d3073dc
--- /dev/null
+++ b/src/thread_pool.cpp
@@ -0,0 +1,202 @@
+
+#include "thread_pool.hpp"
+
+#include <functional>
+#include <future>
+
+// Out-of-line definition of the ThreadPoolHandle strong-int type
+// (declared via TypeSafeInt_H in thread_pool.hpp).
+TypeSafeInt_B(ThreadPoolHandle);
+
+// One pool instance: a bounded queue of heap-allocated packaged_tasks
+// serviced by a fixed set of worker threads. All heap memory (queue,
+// thread array, individual jobs) is allocated from `bkt`.
+struct ThreadPool {
+ bool isRunning; // set false once shutdown begins; guarded by `mutex` (workers also read it unlocked — NOTE(review): technically a data race, confirm intent)
+ bool isPaused; // true while paused; workers exit and are re-spawned on resume
+ uint8_t maxJobQueueCount; // queue capacity; Enqueue fails beyond this
+ std::mutex mutex; // guards jobQueue and the two flags above
+ std::atomic<uint64_t> completedCount; // number of jobs finished so far
+ std::condition_variable condition; // signals queued work / state changes
+ DynArray<std::packaged_task<void()> *> *jobQueue; // owned; elements allocated from bkt
+ DynArray<std::thread> *threads; // owned worker threads
+ MemBucket *bkt = nullptr; // allocator bucket; nullptr presumably means the default bucket — TODO confirm
+};
+
+// Fixed-size storage slab for the bucket container: 8 pools per bucket.
+struct ThreadBucket {
+ ThreadPool threadPools[8];
+};
+
+// Global registry mapping ThreadPoolHandle values to pool slots.
+BucketContainer<ThreadBucket, ThreadPoolHandle_T> ThreadPool_BucketContainer{};
+
+// Worker loop executed by each pool thread: pops jobs off the shared
+// queue and runs them until the pool is shut down or paused. Jobs are
+// executed outside the lock so other workers can dequeue concurrently.
+void ThreadRun(ThreadPool *tp) {
+  while (tp->isRunning && !tp->isPaused) {
+    std::packaged_task<void()> *j = nullptr; // scoped per iteration; never dangles across loops
+    {
+      std::unique_lock<std::mutex> lck(tp->mutex);
+      tp->condition.wait(lck, [tp] {
+        return tp->jobQueue->Count() != 0 || !tp->isRunning || tp->isPaused;
+      });
+      if (!tp->isRunning || tp->isPaused) {
+        return; // shutdown/pause requested: exit so the pool can join us
+      }
+      if (tp->jobQueue->Count() == 0) {
+        continue; // another worker took the job between notify and wake
+      }
+      j = (*tp->jobQueue)[0];
+      tp->jobQueue->Remove(0, 1);
+    }
+    assert(j != nullptr);
+    (*j)();
+    Pke_Delete<std::packaged_task<void()>>(j, tp->bkt);
+    // BUG FIX: the original `completedCount = completedCount + 1` was an
+    // atomic load followed by an atomic store — not an atomic increment —
+    // so concurrent workers could drop counts. fetch_add is a single RMW;
+    // relaxed order suffices for a pure statistics counter.
+    tp->completedCount.fetch_add(1, std::memory_order_relaxed);
+  }
+}
+
+// Wake every worker (so any thread blocked on the condition variable
+// re-checks its predicate) and join all joinable threads in the pool.
+void inline PkeThreads_JoinAll_Inner(ThreadPool &tp) {
+  tp.condition.notify_all();
+  const long total = tp.threads->Count();
+  for (long idx = 0; idx < total; ++idx) {
+    std::thread &worker = (*tp.threads)[idx];
+    if (!worker.joinable()) {
+      continue; // already joined (or never started); skip
+    }
+    worker.join();
+  }
+}
+
+// Clear all pool bookkeeping: capacity, completion counter, queue and
+// thread list. Assumes the pool is already shut down (no worker threads
+// touching the queue).
+void inline PkeThreads_Reset_Inner(ThreadPool &tp) {
+  tp.maxJobQueueCount = 0;
+  tp.completedCount = 0;
+  // BUG FIX: the original called Resize(0) directly, dropping the raw
+  // packaged_task pointers still in the queue and leaking their
+  // allocations. Free each pending job first.
+  const long pending = tp.jobQueue->Count();
+  for (long i = 0; i < pending; ++i) {
+    Pke_Delete<std::packaged_task<void()>>((*tp.jobQueue)[i], tp.bkt);
+  }
+  tp.jobQueue->Resize(0);
+  tp.threads->Resize(0);
+}
+
+// Try to append `job` to the queue; wakes one worker on success.
+// Returns false if the pool is not running or the queue is at capacity.
+// Ownership of `job` transfers to the pool only on success.
+bool inline PkeThreads_Enqueue_Inner(ThreadPool &tp, std::packaged_task<void()> *job) {
+  // RAII lock replaces the original manual lock()/unlock() pairs: the
+  // mutex is now released on every path, including if Push() throws.
+  std::lock_guard<std::mutex> lck(tp.mutex);
+  if (tp.isRunning && tp.jobQueue->Count() < tp.maxJobQueueCount) {
+    tp.jobQueue->Push(job);
+    tp.condition.notify_one();
+    return true;
+  }
+  return false;
+}
+
+// Pause the pool: set the flag under the lock, then wake and join every
+// worker (ThreadRun exits when it observes isPaused). Resume re-spawns
+// the threads. Safe to call repeatedly.
+void inline PkeThreads_Pause_Inner(ThreadPool &tp) {
+  {
+    // BUG FIX: the original early return on "already paused" exited
+    // while still holding tp.mutex, deadlocking the next caller to lock
+    // it. The scoped lock_guard releases on every path.
+    std::lock_guard<std::mutex> lck(tp.mutex);
+    if (tp.isPaused) {
+      return; // called more than once
+    }
+    tp.isPaused = true;
+  }
+  // Join outside the lock so workers can acquire it to observe the flag.
+  PkeThreads_JoinAll_Inner(tp);
+}
+
+// Restart the pool after a pause: clear the flag and spawn one worker
+// per slot. Pause joined the old threads, so assigning over the slots is
+// safe; new workers block on tp.mutex until we release it.
+void inline PkeThreads_Resume_Inner(ThreadPool &tp) {
+  std::lock_guard<std::mutex> lck(tp.mutex); // RAII instead of manual lock/unlock
+  tp.isPaused = false;
+  // FIX: loop index is `long` to match Count(); the original compared a
+  // size_t against a long (signed/unsigned mismatch).
+  const long count = tp.threads->Count();
+  for (long i = 0; i < count; ++i) {
+    (*tp.threads)[i] = std::thread([&tp] { ThreadRun(&tp); });
+  }
+}
+
+// Stop the pool: mark it not running, discard pending jobs, then wake
+// and join every worker. Safe to call repeatedly.
+void inline PkeThreads_Shutdown_Inner(ThreadPool &tp) {
+  {
+    // BUG FIX: the original early return on "already stopped" exited
+    // while still holding tp.mutex (deadlock on the next lock). The
+    // scoped lock_guard releases on every path.
+    std::lock_guard<std::mutex> lck(tp.mutex);
+    if (!tp.isRunning) {
+      return;
+    }
+    tp.isRunning = false;
+    tp.isPaused = false;
+    // BUG FIX: free queued-but-unrun jobs; Resize(0) alone leaked the
+    // packaged_task allocations.
+    const long pending = tp.jobQueue->Count();
+    for (long i = 0; i < pending; ++i) {
+      Pke_Delete<std::packaged_task<void()>>((*tp.jobQueue)[i], tp.bkt);
+    }
+    tp.jobQueue->Resize(0);
+  }
+  // Join outside the lock so workers can acquire it to observe isRunning.
+  PkeThreads_JoinAll_Inner(tp);
+}
+
+// Create a pool with `threadCount` workers and a queue bounded at
+// `maxQueueCount` jobs, allocating all pool memory from `bkt`.
+// Returns a handle into the global bucket container.
+ThreadPoolHandle PkeThreads_Init(uint8_t threadCount, uint8_t maxQueueCount, MemBucket *bkt) {
+  if (!ThreadPool_BucketContainer.buckets) {
+    Buckets_Init(ThreadPool_BucketContainer); // lazily set up the registry
+  }
+  assert(threadCount > 0);
+  ThreadPoolHandle_T newHandle{Buckets_NewHandle(255, ThreadPool_BucketContainer)};
+
+  const auto bucketIdx = Buckets_GetBucketIndex(newHandle);
+  const auto itemIdx = Buckets_GetItemIndex(newHandle);
+  ThreadPool &pool = ThreadPool_BucketContainer.buckets[bucketIdx].threadPools[itemIdx];
+
+  pool.bkt = bkt;
+  pool.isRunning = true;
+  pool.isPaused = false;
+  pool.maxJobQueueCount = maxQueueCount;
+  pool.completedCount = 0;
+  pool.jobQueue = Pke_New<DynArray<std::packaged_task<void()> *>>(bkt);
+  pool.threads = Pke_New<DynArray<std::thread>>(bkt);
+
+  // Spawn the workers; the slot addresses are stable, so handing the
+  // pool's address to each thread is safe.
+  pool.threads->Resize(threadCount);
+  for (long idx = 0; idx < threadCount; ++idx) {
+    (*pool.threads)[idx] = std::thread([&pool] { ThreadRun(&pool); });
+  }
+
+  return ThreadPoolHandle{newHandle};
+}
+
+// Public wrapper: resolve `handle` to its pool slot and reset it.
+void PkeThreads_Reset(ThreadPoolHandle handle) {
+  const ThreadPoolHandle_T raw{static_cast<ThreadPoolHandle_T>(handle)};
+  ThreadPool &pool = ThreadPool_BucketContainer
+                         .buckets[Buckets_GetBucketIndex(raw)]
+                         .threadPools[Buckets_GetItemIndex(raw)];
+  PkeThreads_Reset_Inner(pool);
+}
+
+// Move `job` into a bucket-allocated packaged_task and hand it to the
+// pool. Returns true if accepted; returns false (and frees the
+// allocation) if the pool is stopped or the queue is full.
+bool PkeThreads_Enqueue(ThreadPoolHandle handle, std::packaged_task<void()> job) {
+  ThreadPoolHandle_T handle_T{static_cast<ThreadPoolHandle_T>(handle)};
+  auto b = Buckets_GetBucketIndex(handle_T);
+  auto i = Buckets_GetItemIndex(handle_T);
+  auto *tp = &ThreadPool_BucketContainer.buckets[b].threadPools[i];
+
+  auto jobPtr = Pke_New<std::packaged_task<void()>>(tp->bkt);
+  *jobPtr = std::move(job);
+
+  if (PkeThreads_Enqueue_Inner(*tp, jobPtr)) {
+    return true; // pool took ownership
+  }
+  // BUG FIX: on rejection the original returned false without freeing
+  // jobPtr, leaking one packaged_task per failed enqueue.
+  Pke_Delete<std::packaged_task<void()>>(jobPtr, tp->bkt);
+  return false;
+}
+
+// Public wrapper: resolve `handle` to its pool slot and pause it.
+void PkeThreads_Pause(ThreadPoolHandle handle) {
+  const ThreadPoolHandle_T raw{static_cast<ThreadPoolHandle_T>(handle)};
+  ThreadPool &pool = ThreadPool_BucketContainer
+                         .buckets[Buckets_GetBucketIndex(raw)]
+                         .threadPools[Buckets_GetItemIndex(raw)];
+  PkeThreads_Pause_Inner(pool);
+}
+
+// Public wrapper: resolve `handle` to its pool slot and resume it.
+void PkeThreads_Resume(ThreadPoolHandle handle) {
+  const ThreadPoolHandle_T raw{static_cast<ThreadPoolHandle_T>(handle)};
+  ThreadPool &pool = ThreadPool_BucketContainer
+                         .buckets[Buckets_GetBucketIndex(raw)]
+                         .threadPools[Buckets_GetItemIndex(raw)];
+  PkeThreads_Resume_Inner(pool);
+}
+
+// Public wrapper: resolve `handle` to its pool slot and shut it down.
+void PkeThreads_Shutdown(ThreadPoolHandle handle) {
+  const ThreadPoolHandle_T raw{static_cast<ThreadPoolHandle_T>(handle)};
+  ThreadPool &pool = ThreadPool_BucketContainer
+                         .buckets[Buckets_GetBucketIndex(raw)]
+                         .threadPools[Buckets_GetItemIndex(raw)];
+  PkeThreads_Shutdown_Inner(pool);
+}
+
+// Fully destroy the pool behind `handle`: stop the workers, reset the
+// bookkeeping, free the owned containers, then poison the pointers
+// (CAFE_BABE) to catch use-after-teardown. The slot itself stays in the
+// bucket container.
+void PkeThreads_Teardown(ThreadPoolHandle handle) {
+  const ThreadPoolHandle_T raw{static_cast<ThreadPoolHandle_T>(handle)};
+  ThreadPool &pool = ThreadPool_BucketContainer
+                         .buckets[Buckets_GetBucketIndex(raw)]
+                         .threadPools[Buckets_GetItemIndex(raw)];
+
+  PkeThreads_Shutdown_Inner(pool);
+  PkeThreads_Reset_Inner(pool);
+  Pke_Delete<DynArray<std::packaged_task<void()> *>>(pool.jobQueue, pool.bkt);
+  Pke_Delete<DynArray<std::thread>>(pool.threads, pool.bkt);
+  pool.jobQueue = CAFE_BABE(DynArray<std::packaged_task<void()> *>);
+  pool.threads = CAFE_BABE(DynArray<std::thread>);
+  pool.bkt = CAFE_BABE(MemBucket);
+}
diff --git a/src/thread_pool.hpp b/src/thread_pool.hpp
new file mode 100644
index 0000000..12962dd
--- /dev/null
+++ b/src/thread_pool.hpp
@@ -0,0 +1,22 @@
+#ifndef PKE_THREADING_HPP
+#define PKE_THREADING_HPP
+
+#include "dynamic-array.hpp"
+#include "macros.hpp"
+
+#include <atomic>
+#include <condition_variable>
+#include <cstdint>
+#include <future>
+
+// Strong-int handle identifying a thread pool in the global registry;
+// sentinel value is all-ones (invalid handle).
+TypeSafeInt_H(ThreadPoolHandle, uint64_t, 0xFFFFFFFFFFFFFFFF);
+
+// Create a pool with `threadCount` workers (must be > 0) and a job queue
+// bounded at `maxQueueCount`; memory comes from `bkt`.
+ThreadPoolHandle PkeThreads_Init (uint8_t threadCount, uint8_t maxQueueCount, MemBucket *bkt = nullptr);
+// Clear the pool's counters, queue, and thread list (call after Shutdown).
+void PkeThreads_Reset (ThreadPoolHandle handle);
+// Submit a job by value; returns false if the pool is stopped or full.
+bool PkeThreads_Enqueue (ThreadPoolHandle handle, std::packaged_task<void()> job);
+// Stop and join the workers without discarding pool state; Resume restarts them.
+void PkeThreads_Pause (ThreadPoolHandle handle);
+void PkeThreads_Resume (ThreadPoolHandle handle);
+// Permanently stop the pool and discard any pending jobs.
+void PkeThreads_Shutdown (ThreadPoolHandle handle);
+// Shutdown + Reset + free the pool's owned containers; handle becomes unusable.
+void PkeThreads_Teardown (ThreadPoolHandle handle);
+
+#endif /* PKE_THREADING_HPP */