#ifndef _THREAD_POOL_H_
#define _THREAD_POOL_H_

#include <atomic>
#include <condition_variable>
#include <cstddef>
#include <functional>
#include <future>
#include <map>
#include <memory>
#include <mutex>
#include <queue>
#include <thread>
#include <type_traits>
#include <utility>
#include <vector>

/// A minimal mutex-protected FIFO queue.
///
/// Every public member function takes the internal mutex, so one queue may be
/// shared between threads.
template <typename T>
class threadsafe_queue {
private:
    mutable std::mutex mut;
    std::queue<T> data_queue;

public:
    threadsafe_queue() {}

    /// Appends @p new_value to the back of the queue.
    /// The value is moved in (not copied), so move-only element types work.
    void push(T &&new_value) {
        std::lock_guard<std::mutex> lk(mut);
        data_queue.push(std::move(new_value));
    }

    /// Moves the front element into @p value if one is available.
    /// @return true if an element was popped; false if the queue was empty,
    ///         in which case @p value is left untouched.
    bool try_pop(T &value) {
        std::lock_guard<std::mutex> lk(mut);
        if (data_queue.empty())
            return false;
        value = std::move(data_queue.front());
        data_queue.pop();
        return true;
    }

    /// @return the number of queued elements at the moment of the call.
    std::size_t size() const {
        std::lock_guard<std::mutex> lk(mut);
        return data_queue.size();
    }
};

/// Fixed-size pool of worker threads fed from a shared task queue.
///
/// Workers stop as soon as the pool starts shutting down. Tasks still queued at
/// that point are never run, and their futures become ready with a
/// std::future_errc::broken_promise error when the queue is destroyed.
class thread_pool {
    std::atomic_bool done;
    int nthreads;
    // Worker std::thread::id -> worker index in [0, nthreads).
    // Written only by the constructor; only read (never modified) afterwards.
    std::map<std::thread::id, int> thread_id;
    threadsafe_queue<std::function<void()>> work_queue;
    std::vector<std::thread> threads;
    // cv/m only park idle workers; the task queue has its own mutex.
    std::condition_variable cv;
    std::mutex m;

    /// Worker loop: run queued tasks until `done` is set, sleeping on cv while
    /// the queue is empty.
    void worker_thread() {
        while (!done) {
            std::function<void()> task;
            if (work_queue.try_pop(task)) {
                task();
            } else {
                std::unique_lock<std::mutex> lock(m);
                cv.wait(lock, [&] { return work_queue.size() != 0 || done; });
            }
        }
    }

    /// Tells every worker to stop, wakes them, and joins every started thread.
    /// Shared by the destructor and the constructor's failure path, so a
    /// joinable std::thread is never destroyed (which would call
    /// std::terminate).
    void shutdown() {
        {
            // Setting the flag under m guarantees a worker is either before its
            // predicate check (and sees done) or already waiting (and is woken).
            std::lock_guard<std::mutex> lock(m);
            done = true;
        }
        cv.notify_all();
        for (std::thread &thread : threads) {
            if (thread.joinable())
                thread.join();
        }
    }

public:
    /// Starts @p nthr worker threads; a value <= 0 starts none.
    /// @throws whatever std::thread construction (or allocation) throws. In that
    ///         case already-started workers are stopped and joined before the
    ///         exception propagates.
    thread_pool(int nthr) : done(false), nthreads(nthr) {
        try {
            if (nthreads > 0)
                threads.reserve(static_cast<std::size_t>(nthreads));
            for (int i = 0; i < nthreads; ++i) {
                threads.emplace_back(&thread_pool::worker_thread, this);
                thread_id[threads.back().get_id()] = i;
            }
        } catch (...) {
            shutdown();
            throw;
        }
    }

    ~thread_pool() { shutdown(); }

    /// Queues the call `f(args...)` for execution on a worker thread.
    ///
    /// std::bind stores decayed copies of @p f and @p args (unless they are
    /// wrapped in std::ref), so the task does not refer to the caller's
    /// argument objects after submit() returns.
    /// @return a future for the call's result. An exception thrown by the call
    ///         is stored in the future and rethrown by get().
    template <typename Function, typename... Args>
    std::future<typename std::result_of<Function(Args...)>::type>
    submit(Function &&f, Args &&...args) {
        using result_type = typename std::result_of<Function(Args...)>::type;
        auto task = std::make_shared<std::packaged_task<result_type()>>(
            std::bind(std::forward<Function>(f), std::forward<Args>(args)...));
        std::future<result_type> res(task->get_future());
        work_queue.push([task] { (*task)(); });
        // Notify while holding m: a worker that has just found the queue empty
        // holds m until it is actually waiting, so this wakeup cannot be lost.
        std::lock_guard<std::mutex> lock(m);
        cv.notify_one();
        return res;
    }

    /// @return the number of worker threads requested at construction.
    int size() const { return nthreads; }

    /// Maps a worker's std::thread::id to its index in [0, size()).
    ///
    /// This is a pure lookup that never modifies the map, so tasks may call it
    /// concurrently. Ids of threads not owned by this pool yield 0, which is
    /// indistinguishable from worker 0.
    int thread_id_to_int(std::thread::id id) const {
        auto it = thread_id.find(id);
        return it != thread_id.end() ? it->second : 0;
    }

    /// Splits [start, end) into contiguous blocks and calls
    /// f(block_start, block_end, stride) for each block on the pool. It then
    /// waits for all blocks to finish.
    ///
    /// Blocks are (end - start) / size() long, plus a shorter trailing block
    /// for any remainder. If the range is shorter than size(), a single block
    /// covers it. An empty range (end <= start) runs nothing. With no worker
    /// threads, f(start, end, stride) is called once on the calling thread.
    /// @throws the first exception (in block order) thrown by any call of f.
    ///         It is rethrown only after every block has finished.
    template <typename T, typename Function>
    void parallel_for(T start, T end, T stride, Function &&f) {
        if (nthreads <= 0) {
            // Without workers, queued tasks would never run (get() would hang),
            // and the block-size division below would divide by zero.
            if (start < end)
                f(start, end, stride);
            return;
        }
        T range = end - start;
        T block_size = range / nthreads;
        T block_start = start;
        T block_end = block_start + block_size;
        if (block_size == 0)
            block_end = end;
        std::vector<decltype(submit(f, start, end, stride))> res;
        while (block_start < end) {
            res.emplace_back(submit(f, block_start, block_end, stride));
            block_start = block_end;
            block_end = block_end + block_size;
            if (block_end >= end)
                block_end = end;
        }
        // Wait for every block before calling get(). Otherwise a throwing block
        // would return control to the caller while other blocks are still
        // running, possibly with f referring to the caller's locals.
        for (auto &r : res)
            r.wait();
        for (auto &r : res)
            r.get();
    }
};

#endif