author     Jonathan Bradley <jcb@pikum.xyz>  2025-01-09 14:44:31 -0500
committer  Jonathan Bradley <jcb@pikum.xyz>  2025-01-09 14:44:31 -0500
commit     b76e309166f714b0a66fb4802f02e92a82d09082 (patch)
tree       44244fc0e6f873e1ccf6e1e95e2fec62fcec394a /src/thread-pool.cpp
parent     b04fefe8ee0086bc1404c06b8351ecb4e942f151 (diff)
flatten file structure + rename
Diffstat (limited to 'src/thread-pool.cpp')
-rw-r--r--  src/thread-pool.cpp  217
1 files changed, 217 insertions, 0 deletions
diff --git a/src/thread-pool.cpp b/src/thread-pool.cpp
new file mode 100644
index 0000000..061ae68
--- /dev/null
+++ b/src/thread-pool.cpp
@@ -0,0 +1,217 @@
+
+#include "thread-pool.hpp"
+
+#include "bucketed-array.hpp"
+
+#include <functional>
+#include <future>
+
+struct ThreadPool {
+  std::atomic<bool> isRunning;  // read by workers outside the mutex, so keep it atomic
+  std::atomic<bool> isPaused;   // read by workers outside the mutex, so keep it atomic
+  uint8_t maxJobQueueCount;
+  std::mutex mutex;
+  std::atomic<uint64_t> completedCount;
+  std::condition_variable condition;
+  DynArray<std::packaged_task<void()> *> *jobQueue;
+  DynArray<std::thread> *threads;
+  struct pk_membucket *bkt = nullptr;
+};
+
+const pk_handle_item_index_T MAX_THREADS_PER_BUCKET = 8;
+
+BucketContainer<ThreadPool, ThreadPoolHandle> ThreadPool_BucketContainer{};
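+// Pools are stored by value in ThreadPool_BucketContainer and addressed through a
+// ThreadPoolHandle's bucketIndex/itemIndex pair; the public API below only ever
+// takes handles, never raw ThreadPool pointers.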
+
+// Worker loop: wait for work (or a state change), pop the front of the queue,
+// run the task, then release it back into the pool's bucket.
+void ThreadRun(ThreadPool *tp) {
+  std::packaged_task<void()> *j = nullptr;
+  while (tp->isRunning && !tp->isPaused) {
+    {
+      std::unique_lock<std::mutex> lck(tp->mutex);
+      tp->condition.wait(lck, [tp] {
+        if (!tp->isRunning) return true;
+        if (tp->isPaused) return true;
+        if (tp->jobQueue == nullptr) return true;
+        if (tp->jobQueue == CAFE_BABE(DynArray<std::packaged_task<void()> *>)) return true;
+        return tp->jobQueue->Count() != 0;
+      });
+      if (!tp->isRunning || tp->isPaused || tp->jobQueue == nullptr || tp->jobQueue == CAFE_BABE(DynArray<std::packaged_task<void()> *>)) {
+        return;
+      }
+      if (tp->jobQueue->Count() == 0) {
+        continue;  // spurious wakeup: nothing to do yet
+      }
+      j = (*tp->jobQueue)[0];
+      tp->jobQueue->Remove(0, 1);
+    }
+    assert(j != nullptr);
+    (*j)();
+    pk_delete<std::packaged_task<void()>>(j, tp->bkt);
+    ++tp->completedCount;  // atomic increment; the old load-then-store could drop updates under contention
+  }
+}
+
+inline void PkeThreads_JoinAll_Inner(ThreadPool &tp) {
+  tp.condition.notify_all();
+  long count = tp.threads->Count();
+  for (long l = 0; l < count; ++l) {
+    auto &t = (*tp.threads)[l];
+    if (t.joinable()) {
+      t.join();
+    }
+  }
+}
+
+inline void PkeThreads_DetatchAll_Inner(ThreadPool &tp) {
+  long count = tp.threads->Count();
+  for (long i = 0; i < count; ++i) {
+    auto &t = (*tp.threads)[i];
+    if (t.joinable()) {  // detach() on a non-joinable thread throws
+      t.detach();
+    }
+  }
+  tp.condition.notify_all();
+}
+
+inline void PkeThreads_Reset_Inner(ThreadPool &tp) {
+  std::lock_guard<std::mutex> lck(tp.mutex);
+  tp.maxJobQueueCount = 0;
+  tp.completedCount = 0;
+  tp.jobQueue->Resize(0);
+  tp.threads->Resize(0);
+}
+
+inline bool PkeThreads_Enqueue_Inner(ThreadPool &tp, std::packaged_task<void()> *job) {
+  std::lock_guard<std::mutex> lck(tp.mutex);
+  if (tp.isRunning && tp.jobQueue->Count() < tp.maxJobQueueCount) {
+    tp.jobQueue->Push(job);
+    tp.condition.notify_one();
+    return true;
+  }
+  return false;
+}
+
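+// Pause/Resume semantics as implemented below: pausing sets isPaused and then joins
+// every worker, so each ThreadRun call returns; Resume constructs fresh std::thread
+// objects into the same slots. Pausing therefore tears the workers down rather than
+// parking them on the condition variable.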
+inline void PkeThreads_Pause_Inner(ThreadPool &tp) {
+  {
+    std::lock_guard<std::mutex> lck(tp.mutex);
+    if (tp.isPaused) {
+      return;  // already paused; the guard releases the mutex on this early return
+    }
+    tp.isPaused = true;
+  }
+  PkeThreads_JoinAll_Inner(tp);
+}
+
+inline void PkeThreads_Resume_Inner(ThreadPool &tp) {
+  std::lock_guard<std::mutex> lck(tp.mutex);
+  tp.isPaused = false;
+  long count = tp.threads->Count();
+  for (long i = 0; i < count; i++) {
+    // the previous workers were joined by Pause, so these slots hold joined
+    // threads and can safely be assigned new ones
+    (*tp.threads)[i] = std::thread(std::bind(ThreadRun, &tp));
+  }
+}
+
+inline void PkeThreads_Shutdown_Inner(ThreadPool &tp) {
+  std::lock_guard<std::mutex> lck(tp.mutex);
+  if (!tp.isRunning) {
+    return;  // already shut down; the guard releases the mutex on this early return
+  }
+  tp.isRunning = false;
+  tp.isPaused = false;
+  tp.jobQueue->Resize(0);
+}
+
+void PkeThreads_Init() {
+  Buckets_Init(ThreadPool_BucketContainer, MAX_THREADS_PER_BUCKET);
+}
+
+ThreadPoolHandle PkeThreads_Init(uint8_t threadCount, uint8_t maxQueueCount, struct pk_membucket *bkt) {
+  assert(threadCount > 0);
+  ThreadPoolHandle newHandle{Buckets_NewHandle(ThreadPool_BucketContainer)};
+
+  auto *tp = &ThreadPool_BucketContainer.buckets[newHandle.bucketIndex][newHandle.itemIndex];
+
+  tp->bkt = bkt;
+  tp->isRunning = true;
+  tp->isPaused = false;
+  tp->maxJobQueueCount = maxQueueCount;
+  tp->completedCount = 0;
+  tp->jobQueue = pk_new<DynArray<std::packaged_task<void()> *>>(bkt);
+  tp->threads = pk_new<DynArray<std::thread>>(bkt);
+
+  tp->threads->Resize(threadCount);
+  for (long l = 0; l < threadCount; ++l) {
+    (*tp->threads)[l] = std::thread(std::bind(ThreadRun, tp));
+  }
+
+  return newHandle;
+}
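+// Typical lifecycle (illustrative sketch, not called anywhere in this file; assumes the
+// caller already has a valid struct pk_membucket *bkt):
+//   ThreadPoolHandle h = PkeThreads_Init(/*threadCount*/ 4, /*maxQueueCount*/ 32, bkt);
+//   ... PkeThreads_Enqueue(h, job) as work arrives ...
+//   PkeThreads_Shutdown(h);   // stop accepting work and join the workers
+//   PkeThreads_Teardown(h);   // free the queue/thread arrays and poison the pointers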
+
+void PkeThreads_Reset(ThreadPoolHandle handle) {
+  assert(handle != ThreadPoolHandle_MAX);
+  auto *tp = &ThreadPool_BucketContainer.buckets[handle.bucketIndex][handle.itemIndex];
+  PkeThreads_Reset_Inner(*tp);
+}
+
+bool PkeThreads_Enqueue(ThreadPoolHandle handle, std::packaged_task<void()> *job) {
+  assert(handle != ThreadPoolHandle_MAX);
+  auto *tp = &ThreadPool_BucketContainer.buckets[handle.bucketIndex][handle.itemIndex];
+  if (tp->bkt != nullptr) {
+    /* 2023-12-22 JCB
+     * Note that if this becomes an issue we can change it.
+     * Technically speaking, if we call the right pk_delete
+     * we don't even need to worry about passing the struct pk_membucket
+     */
+    assert(pk_memory_is_in_bucket(job, tp->bkt) == true && "cannot enqueue packaged task from a non-matching struct pk_membucket");
+  }
+
+  return PkeThreads_Enqueue_Inner(*tp, job);
+}
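+// Example of handing a job to the pool (a sketch under the assumption that
+// pk_new<T>(bkt) default-constructs a T inside the bucket, mirroring the
+// pk_delete<std::packaged_task<void()>>(j, tp->bkt) call in ThreadRun):
+//   auto *job = pk_new<std::packaged_task<void()>>(bkt);
+//   *job = std::packaged_task<void()>([] { /* do work */ });
+//   std::future<void> done = job->get_future();
+//   if (PkeThreads_Enqueue(handle, job)) {
+//     done.wait();  // the pool runs the task and pk_deletes it afterwards
+//   } else {
+//     pk_delete<std::packaged_task<void()>>(job, bkt);  // queue full or pool stopped; caller still owns it
+//   }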
+
+int64_t PkeThreads_GetQueueCount(ThreadPoolHandle handle) {
+  assert(handle != ThreadPoolHandle_MAX);
+  auto &threadPool = ThreadPool_BucketContainer.buckets[handle.bucketIndex][handle.itemIndex];
+  return threadPool.jobQueue->Count();
+}
+
+void PkeThreads_Pause(ThreadPoolHandle handle) {
+  assert(handle != ThreadPoolHandle_MAX);
+  auto *tp = &ThreadPool_BucketContainer.buckets[handle.bucketIndex][handle.itemIndex];
+
+  PkeThreads_Pause_Inner(*tp);
+}
+
+void PkeThreads_Resume(ThreadPoolHandle handle) {
+  assert(handle != ThreadPoolHandle_MAX);
+  auto *tp = &ThreadPool_BucketContainer.buckets[handle.bucketIndex][handle.itemIndex];
+
+  PkeThreads_Resume_Inner(*tp);
+}
+
+void PkeThreads_Shutdown(ThreadPoolHandle handle) {
+  assert(handle != ThreadPoolHandle_MAX);
+  auto *tp = &ThreadPool_BucketContainer.buckets[handle.bucketIndex][handle.itemIndex];
+
+  PkeThreads_Shutdown_Inner(*tp);
+  PkeThreads_JoinAll_Inner(*tp);
+}
+
+void PkeThreads_Teardown(ThreadPoolHandle handle) {
+  assert(handle != ThreadPoolHandle_MAX);
+  auto *tp = &ThreadPool_BucketContainer.buckets[handle.bucketIndex][handle.itemIndex];
+
+  PkeThreads_Shutdown_Inner(*tp);
+  PkeThreads_JoinAll_Inner(*tp);
+  PkeThreads_Reset_Inner(*tp);
+  pk_delete<DynArray<std::packaged_task<void()> *>>(tp->jobQueue, tp->bkt);
+  pk_delete<DynArray<std::thread>>(tp->threads, tp->bkt);
+  tp->jobQueue = CAFE_BABE(DynArray<std::packaged_task<void()> *>);
+  tp->threads = CAFE_BABE(DynArray<std::thread>);
+  tp->bkt = CAFE_BABE(struct pk_membucket);
+}
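+// After teardown the pool's pointers are poisoned with CAFE_BABE rather than nulled;
+// ThreadRun treats a CAFE_BABE'd jobQueue the same as a null one and returns, so a
+// worker never dereferences the poisoned pointer.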
+
+void PkeThreads_Teardown() {
+  Buckets_Destroy(ThreadPool_BucketContainer);
+}