#include "thread-pool.hpp" #include "pk.h" #include #include struct ThreadPool { bool isRunning; bool isPaused; uint8_t maxJobQueueCount; std::mutex mutex; std::atomic completedCount; std::condition_variable condition; pk_arr_t *> jobQueue; pk_arr_t threads; struct pk_membucket *bkt = nullptr; }; struct ThreadPoolMaster { pk_membucket *bkt; pk_bkt_arr_t bc{}; }thrdpl_mstr; void ThreadRun(ThreadPool *tp) { std::packaged_task *j = nullptr; while (tp->isRunning && !tp->isPaused) { { std::unique_lock lck(tp->mutex); tp->condition.wait(lck, [tp] { if (!tp->isRunning) return true; if (tp->isPaused) return true; return tp->jobQueue.next != 0; }); if (!tp->isRunning || tp->isPaused) { return; } if (tp->jobQueue.next == 0) { continue; } j = tp->jobQueue[0]; pk_arr_remove_at(&tp->jobQueue, 0); } assert(j != nullptr); (*j)(); j->~packaged_task(); pk_delete>(j, tp->bkt); tp->completedCount = tp->completedCount + 1; } } void inline PkeThreads_JoinAll_Inner(ThreadPool &tp) { tp.condition.notify_all(); uint32_t count = tp.threads.next; for (uint32_t l = 0; l < count; ++l) { auto &t = *tp.threads[l]; if (t.joinable()) { t.join(); } } } void inline PkeThreads_DetatchAll_Inner(ThreadPool &tp) { uint32_t count = tp.threads.next; for (uint32_t i = 0; i < count; ++i) { auto &t = *tp.threads[i]; t.detach(); } tp.condition.notify_all(); } void inline PkeThreads_Reset_Inner(ThreadPool &tp) { tp.mutex.lock(); tp.maxJobQueueCount = 0; tp.completedCount = 0; for (uint32_t i = 0; i < tp.threads.next; ++i) { pk_delete(tp.threads[i]); } pk_arr_clear(&tp.jobQueue); pk_arr_clear(&tp.threads); tp.mutex.unlock(); } bool inline PkeThreads_Enqueue_Inner(ThreadPool &tp, std::packaged_task *job) { tp.mutex.lock(); if (tp.isRunning == true) { if (tp.jobQueue.next < tp.maxJobQueueCount) { pk_arr_append_t(&tp.jobQueue, job); tp.condition.notify_one(); tp.mutex.unlock(); return true; } } tp.mutex.unlock(); return false; } void inline PkeThreads_Pause_Inner(ThreadPool &tp) { tp.mutex.lock(); if (tp.isPaused == true) { return; // called more than once } tp.isPaused = true; tp.mutex.unlock(); PkeThreads_JoinAll_Inner(tp); } void inline PkeThreads_Resume_Inner(ThreadPool &tp) { tp.mutex.lock(); tp.isPaused = false; uint32_t count = tp.threads.next; for (uint32_t i = 0; i < count; i++) { new (tp.threads[i]) std::thread{std::bind(ThreadRun, &tp)}; } tp.mutex.unlock(); } void inline PkeThreads_Shutdown_Inner(ThreadPool &tp) { tp.mutex.lock(); if (tp.isRunning == false) { return; } tp.isRunning = false; tp.isPaused = false; pk_arr_clear(&tp.jobQueue); tp.mutex.unlock(); } void PkeThreads_Init() { thrdpl_mstr.bkt = pk_mem_bucket_create("pk_bkt_arr threads", 1024 * 1024, PK_MEMBUCKET_FLAG_NONE); new (&thrdpl_mstr.bc) pk_bkt_arr_t{ pk_bkt_arr_handle_MAX_constexpr, thrdpl_mstr.bkt, thrdpl_mstr.bkt }; } ThreadPoolHandle PkeThreads_Init(uint8_t threadCount, uint8_t maxQueueCount, struct pk_membucket *bkt) { assert(threadCount > 0); ThreadPoolHandle newHandle{pk_bkt_arr_new_handle(&thrdpl_mstr.bc)}; auto *tp = &thrdpl_mstr.bc[newHandle]; new (tp) ThreadPool{}; tp->bkt = bkt; tp->isRunning = true; tp->isPaused = false; tp->maxJobQueueCount = maxQueueCount; tp->completedCount = 0; tp->jobQueue.bkt = bkt; tp->threads.bkt = bkt; pk_arr_resize(&tp->threads, threadCount); for (uint8_t l = 0; l < threadCount; ++l) { tp->threads[l] = pk_new(bkt); new (tp->threads[l]) std::thread(std::bind(ThreadRun, tp)); } return newHandle; } void PkeThreads_Reset(ThreadPoolHandle handle) { assert(handle != ThreadPoolHandle_MAX); auto *tp = &thrdpl_mstr.bc[handle]; PkeThreads_Reset_Inner(*tp); } bool PkeThreads_Enqueue(ThreadPoolHandle handle, std::packaged_task *job) { assert(handle != ThreadPoolHandle_MAX); auto *tp = &thrdpl_mstr.bc[handle]; if (tp->bkt != nullptr) { /* 2023-12-22 JCB * Note that if this becomes an issue we can change it. * Technically speaking, if we call the right pk_delete * we don't even need to worry about passing the struct pk_membucket */ assert(pk_mem_bucket_ptr_is_in_mem_bucket(job, tp->bkt) == true && "cannot enqueue packaged task from a non-matching struct pk_membucket"); } return PkeThreads_Enqueue_Inner(*tp, job); } uint32_t PkeThreads_GetQueueCount (ThreadPoolHandle handle) { auto &threadPool = thrdpl_mstr.bc[handle]; return threadPool.jobQueue.next; } void PkeThreads_Pause(ThreadPoolHandle handle) { assert(handle != ThreadPoolHandle_MAX); auto *tp = &thrdpl_mstr.bc[handle]; PkeThreads_Pause_Inner(*tp); } void PkeThreads_Resume(ThreadPoolHandle handle) { assert(handle != ThreadPoolHandle_MAX); auto *tp = &thrdpl_mstr.bc[handle]; PkeThreads_Resume_Inner(*tp); } void PkeThreads_Shutdown(ThreadPoolHandle handle) { assert(handle != ThreadPoolHandle_MAX); auto *tp = &thrdpl_mstr.bc[handle]; PkeThreads_Shutdown_Inner(*tp); PkeThreads_JoinAll_Inner(*tp); } void PkeThreads_Teardown(ThreadPoolHandle handle) { assert(handle != ThreadPoolHandle_MAX); auto *tp = &thrdpl_mstr.bc[handle]; PkeThreads_Shutdown_Inner(*tp); PkeThreads_JoinAll_Inner(*tp); PkeThreads_Reset_Inner(*tp); for (uint32_t i = 0; i < tp->threads.next; ++i) { if (tp->threads[i] != nullptr) { pk_delete(tp->threads[i]); } } pk_arr_reset(&tp->jobQueue); pk_arr_reset(&tp->threads); tp->jobQueue.bkt = CAFE_BABE(struct pk_membucket); tp->threads.bkt = CAFE_BABE(struct pk_membucket); tp->bkt = CAFE_BABE(struct pk_membucket); } void PkeThreads_Teardown() { pk_bkt_arr_teardown(&thrdpl_mstr.bc); pk_mem_bucket_destroy(thrdpl_mstr.bkt); }