From 044a3f3747a96f61abc8b546496eddb2dcdedd0e Mon Sep 17 00:00:00 2001 From: teridax Date: Wed, 24 May 2023 12:11:48 +0200 Subject: [PATCH] replaced mutex around `threads` member of `ThreadPool` with atomic primitive --- src/multithreading/mod.rs | 55 ++++++++++++++++++++++++++++++--------- 1 file changed, 42 insertions(+), 13 deletions(-) diff --git a/src/multithreading/mod.rs b/src/multithreading/mod.rs index 0501061..4cf7e28 100644 --- a/src/multithreading/mod.rs +++ b/src/multithreading/mod.rs @@ -1,9 +1,11 @@ use std::{ any::Any, collections::VecDeque, - num::{NonZeroU32, NonZeroUsize}, - ops::{AddAssign, SubAssign}, - sync::{Arc, Mutex}, + num::NonZeroUsize, + sync::{ + atomic::{AtomicUsize, Ordering}, + Arc, Mutex, + }, thread::{self, JoinHandle}, }; @@ -15,7 +17,7 @@ pub const FALLBACK_THREADS: usize = 1; /// Returns the number of threads to be used by the thread pool by default. /// This function tries to fetch a recommended number by calling [`thread::available_parallelism`]. /// In case this fails [`FALLBACK_THREADS`] will be returned -fn get_default_thread_count() -> u32 { +fn get_default_thread_count() -> usize { // number of threads to fallback to let fallback_threads = NonZeroUsize::new(FALLBACK_THREADS).expect("fallback_threads must be nonzero"); @@ -23,7 +25,7 @@ fn get_default_thread_count() -> u32 { // most of the time this is gonna be the number of cpus thread::available_parallelism() .unwrap_or(fallback_threads) - .get() as u32 + .get() } /// This struct manages a pool of threads with a fixed maximum number. @@ -34,7 +36,7 @@ fn get_default_thread_count() -> u32 { /// The pool will also keep track of every `JoinHandle` created by running every closure on /// its on thread. The closures can be obtained by either calling `join_all` or `get_finished`. 
/// # Example -/// ```rust +/// ```rust ignore /// let mut pool = ThreadPool::new(); /// /// // launch some work in parallel /// @@ -46,6 +48,14 @@ fn get_default_thread_count() -> u32 { /// // wait for threads to finish /// pool.join_all(); /// ``` +/// # Portability +/// This implementation is not fully platform independent. This is due to the usage of [`std::sync::atomic::AtomicUsize`]. +/// This type is used to remove some locks from otherwise used [`std::sync::Mutex`] wrapping a [`usize`]. +/// Note that atomic primitives are not available on all platforms but "can generally be relied upon existing" +/// (see: https://doc.rust-lang.org/std/sync/atomic/index.html). +/// Additionally this implementation relies on using the `load` and `store` operations +/// instead of using more comfortable ones like `fetch_add` in order to avoid unnecessary calls +/// to `unwrap` or `expect` when locking a [`std::sync::Mutex`]. #[allow(dead_code)] #[derive(Debug)] pub struct ThreadPool @@ -53,13 +63,18 @@ where F: Send + FnOnce() -> T + Send + 'static, { /// maximum number of threads to launch at once - max_thread_count: u32, + max_thread_count: usize, /// handles for launched threads handles: Arc<Mutex<Vec<JoinHandle<T>>>>, /// function to be executed when threads are ready queue: Arc<Mutex<VecDeque<F>>>, /// number of currently running threads - threads: Arc<Mutex<u32>>, + /// new implementation relies on atomic primitives to avoid locking and possible + /// guard errors. Note that atomic primitives are not available on all platforms but "can generally be relied upon existing" + /// (see: https://doc.rust-lang.org/std/sync/atomic/index.html). 
+ /// Also this implementation relies on using the `load` and `store` operations + /// instead of using more comfortable ones like `fetch_add` + threads: Arc<AtomicUsize>, } impl Default for ThreadPool @@ -99,7 +114,7 @@ where /// # Overusage /// supplying a number of threads to great may negatively impact performance as the system may not /// be able to full fill the required needs - pub fn with_threads(max_thread_count: NonZeroU32) -> Self { + pub fn with_threads(max_thread_count: NonZeroUsize) -> Self { Self { max_thread_count: max_thread_count.get(), ..Default::default() } @@ -113,10 +128,13 @@ /// If `join_all` is called and the closure hasn't been executed yet, `join_all` will wait for all stalled /// closures be executed. pub fn enqueue(&mut self, closure: F) { + // read the used-thread counter; Acquire pairs with the Release stores below. + // NOTE(review): this load followed by a separate `store` is not an atomic read-modify-write — + // a worker thread decrementing the counter in between can have its update lost; `fetch_add`/`fetch_sub` would avoid that. + let used_threads = self.threads.load(Ordering::Acquire); // test if we can launch a new thread - if self.threads.lock().unwrap().to_owned() < self.max_thread_count { + if used_threads < self.max_thread_count { // we can create a new thread, increment the thread count - self.threads.lock().unwrap().add_assign(1); + self.threads + .store(used_threads.saturating_add(1), Ordering::Release); // run new thread execute( self.queue.clone(), @@ -131,6 +149,14 @@ where } } + /// Removes all closures stalled for execution. + /// All closures still waiting to be executed will be dropped by the pool and + /// won't get executed. Useful if an old set of closures haven't run yet but are outdated + /// and resources are required immediately for updated closures. + pub fn discard_stalled(&mut self) { + self.queue.lock().unwrap().clear(); + } + /// Waits for all currently running threads and all stalled closures to be executed. /// If any closure hasn't been executed yet, `join_all` will wait until the queue holding all /// unexecuted closures is empty. It returns the result every `join` of all threads yields as a vector. 
@@ -188,7 +214,7 @@ where fn execute( queue: Arc<Mutex<VecDeque<F>>>, handles: Arc<Mutex<Vec<JoinHandle<T>>>>, - threads: Arc<Mutex<u32>>, + threads: Arc<AtomicUsize>, closure: F, ) where T: Send + 'static, @@ -208,7 +234,10 @@ fn execute( } else { // nothing to execute this thread will run out without any work to do // decrement the amount of used threads - threads.lock().unwrap().sub_assign(1); + threads.store( threads.load(Ordering::Acquire).saturating_sub(1), Ordering::Release, ) } result