replaced `threads` member of `ThreadPool`

mutex with atomic primitive
This commit is contained in:
Sven Vogel 2023-05-24 12:11:48 +02:00
parent cbf42d6d57
commit 044a3f3747
1 changed file with 42 additions and 13 deletions


@@ -1,9 +1,11 @@
 use std::{
     any::Any,
     collections::VecDeque,
-    num::{NonZeroU32, NonZeroUsize},
-    ops::{AddAssign, SubAssign},
-    sync::{Arc, Mutex},
+    num::NonZeroUsize,
+    sync::{
+        atomic::{AtomicUsize, Ordering},
+        Arc, Mutex,
+    },
     thread::{self, JoinHandle},
 };
@@ -15,7 +17,7 @@ pub const FALLBACK_THREADS: usize = 1;
 /// Returns the number of threads to be used by the thread pool by default.
 /// This function tries to fetch a recommended number by calling [`thread::available_parallelism`].
 /// In case this fails, [`FALLBACK_THREADS`] will be returned.
-fn get_default_thread_count() -> u32 {
+fn get_default_thread_count() -> usize {
     // number of threads to fall back to
     let fallback_threads =
         NonZeroUsize::new(FALLBACK_THREADS).expect("fallback_threads must be nonzero");
@@ -23,7 +25,7 @@ fn get_default_thread_count() -> u32 {
     // most of the time this is gonna be the number of cpus
     thread::available_parallelism()
         .unwrap_or(fallback_threads)
-        .get() as u32
+        .get()
 }
 
 /// This struct manages a pool of threads with a fixed maximum number.
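The `Result` returned by [`thread::available_parallelism`] can also be collapsed without the intermediate `fallback_threads` binding, which avoids the `expect` on the fallback constant entirely. A minimal sketch of that variant; the name `default_thread_count` is illustrative and not part of this commit:

```rust
use std::num::NonZeroUsize;
use std::thread;

pub const FALLBACK_THREADS: usize = 1;

// condensed variant: map the recommended parallelism to a plain usize
// and fall back to FALLBACK_THREADS if the query fails
fn default_thread_count() -> usize {
    thread::available_parallelism()
        .map(NonZeroUsize::get)
        .unwrap_or(FALLBACK_THREADS)
}
```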
@@ -34,7 +36,7 @@ fn get_default_thread_count() -> u32 {
 /// The pool will also keep track of every `JoinHandle` created by running every closure on
 /// its own thread. The results can be obtained by either calling `join_all` or `get_finished`.
 /// # Example
-/// ```rust
+/// ```rust ignore
 /// let mut pool = ThreadPool::new();
 ///
 /// // launch some work in parallel
@@ -46,6 +48,14 @@ fn get_default_thread_count() -> u32 {
 /// // wait for threads to finish
 /// pool.join_all();
 /// ```
+/// # Portability
+/// This implementation is not fully platform independent. This is due to the usage of [`std::sync::atomic::AtomicUsize`],
+/// which replaces a [`std::sync::Mutex`] wrapping a [`usize`] and removes some locks.
+/// Note that atomic primitives are not available on all platforms but "can generally be relied upon existing"
+/// (see: https://doc.rust-lang.org/std/sync/atomic/index.html).
+/// Replacing the mutex also avoids unnecessary calls to `unwrap` or `expect` when locking.
+/// Note that this implementation relies on the `load` and `store` operations
+/// instead of more comfortable ones like `fetch_add`.
 #[allow(dead_code)]
 #[derive(Debug)]
 pub struct ThreadPool<F, T>
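For context on the portability note above, a standalone sketch (not from this repository) of the trade-off it describes: an atomic counter needs no `lock().unwrap()` because the read-modify-write is a single operation. This uses `fetch_add`, the "more comfortable" operation the doc comment deliberately avoids:

```rust
use std::sync::{
    atomic::{AtomicUsize, Ordering},
    Arc,
};
use std::thread;

fn main() {
    let threads = Arc::new(AtomicUsize::new(0));

    let handles: Vec<_> = (0..4)
        .map(|_| {
            let threads = Arc::clone(&threads);
            thread::spawn(move || {
                // a single atomic read-modify-write: no lock is taken and
                // no unwrap/expect on a LockResult is needed
                threads.fetch_add(1, Ordering::AcqRel);
            })
        })
        .collect();

    for handle in handles {
        handle.join().expect("worker thread panicked");
    }

    // this Acquire load pairs with the AcqRel writes above
    assert_eq!(threads.load(Ordering::Acquire), 4);
}
```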
@@ -53,13 +63,18 @@ where
     F: Send + FnOnce() -> T + Send + 'static,
 {
     /// maximum number of threads to launch at once
-    max_thread_count: u32,
+    max_thread_count: usize,
     /// handles for launched threads
     handles: Arc<Mutex<Vec<JoinHandle<T>>>>,
     /// function to be executed when threads are ready
     queue: Arc<Mutex<VecDeque<F>>>,
     /// number of currently running threads
-    threads: Arc<Mutex<u32>>,
+    /// The new implementation relies on atomic primitives to avoid locking and possible
+    /// guard errors. Note that atomic primitives are not available on all platforms but "can generally be relied upon existing"
+    /// (see: https://doc.rust-lang.org/std/sync/atomic/index.html).
+    /// Also this implementation relies on the `load` and `store` operations
+    /// instead of more comfortable ones like `fetch_add`.
+    threads: Arc<AtomicUsize>,
 }
 
 impl<F, T> Default for ThreadPool<F, T>
@@ -99,7 +114,7 @@ where
     /// # Overusage
     /// supplying a number of threads too great may negatively impact performance as the system may not
     /// be able to fulfill the required needs
-    pub fn with_threads(max_thread_count: NonZeroU32) -> Self {
+    pub fn with_threads(max_thread_count: NonZeroUsize) -> Self {
         Self {
             max_thread_count: max_thread_count.get(),
             ..Default::default()
@@ -113,10 +128,13 @@
     /// If `join_all` is called and the closure hasn't been executed yet, `join_all` will wait for all stalled
     /// closures to be executed.
     pub fn enqueue(&mut self, closure: F) {
+        // read the used thread counter; this Acquire load pairs with the Release stores below
+        let used_threads = self.threads.load(Ordering::Acquire);
         // test if we can launch a new thread
-        if self.threads.lock().unwrap().to_owned() < self.max_thread_count {
+        if used_threads < self.max_thread_count {
             // we can create a new thread, increment the thread count
-            self.threads.lock().unwrap().add_assign(1);
+            self.threads
+                .store(used_threads.saturating_add(1), Ordering::Release);
             // run new thread
             execute(
                 self.queue.clone(),
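Worth noting for reviewers: the Acquire load followed by a separate Release store is a check-then-act sequence. Two concurrent callers could both observe `used_threads < max_thread_count` and both store, losing one increment, so the counter can undercount running threads. A sketch of how the slot reservation could instead be made atomic with a `compare_exchange` loop; `try_reserve_slot` is a hypothetical helper, not code from this commit:

```rust
use std::sync::atomic::{AtomicUsize, Ordering};

/// Hypothetical helper: atomically reserve one thread slot.
/// Returns true if the counter was incremented, false if the pool is full.
fn try_reserve_slot(threads: &AtomicUsize, max_thread_count: usize) -> bool {
    let mut current = threads.load(Ordering::Acquire);
    loop {
        if current >= max_thread_count {
            return false;
        }
        // the store only happens if no other thread touched the counter
        // in between, which closes the check-then-act window
        match threads.compare_exchange(
            current,
            current + 1,
            Ordering::AcqRel,
            Ordering::Acquire,
        ) {
            Ok(_) => return true,
            Err(observed) => current = observed,
        }
    }
}
```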
@@ -131,6 +149,14 @@
         }
     }
 
+    /// Removes all closures stalled for execution.
+    /// All closures still waiting to be executed will be dropped by the pool and
+    /// won't get executed. Useful if an old set of closures hasn't run yet but is outdated
+    /// and resources are required immediately for updated closures.
+    pub fn discard_stalled(&mut self) {
+        self.queue.lock().unwrap().clear();
+    }
+
     /// Waits for all currently running threads and all stalled closures to be executed.
     /// If any closure hasn't been executed yet, `join_all` will wait until the queue holding all
     /// unexecuted closures is empty. It returns the result every `join` of all threads yields as a vector.
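A short usage sketch for the new `discard_stalled`, written in the same `ignore`-style as the struct's doc example; the closure body is a stand-in workload, not code from this commit:

```rust ignore
let mut pool = ThreadPool::default();

// queue far more closures than threads are allowed to run at once
for i in 0..64 {
    pool.enqueue(move || i * 2);
}

// the queued work became outdated: drop every closure not yet started;
// closures already running on a thread are unaffected
pool.discard_stalled();

// join_all now only waits for the closures that actually started
let results = pool.join_all();
```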
@@ -188,7 +214,7 @@ where
 fn execute<F, T>(
     queue: Arc<Mutex<VecDeque<F>>>,
     handles: Arc<Mutex<Vec<JoinHandle<T>>>>,
-    threads: Arc<Mutex<u32>>,
+    threads: Arc<AtomicUsize>,
     closure: F,
 ) where
     T: Send + 'static,
@@ -208,7 +234,10 @@ fn execute<F, T>(
     } else {
         // nothing to execute; this thread will run out without any work to do
         // decrement the amount of used threads
-        threads.lock().unwrap().sub_assign(1);
+        threads.store(
+            threads.load(Ordering::Acquire).saturating_sub(1),
+            Ordering::Release,
+        )
     }
 
     result
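The decrement mirrors the increment in `enqueue` and shares the same window between `load` and `store`. For completeness, the single-operation alternative, assuming the counter is only ever decremented by a thread that previously incremented it (so it cannot underflow); `release_slot` is a hypothetical name, not part of this commit:

```rust
use std::sync::atomic::{AtomicUsize, Ordering};

fn release_slot(threads: &AtomicUsize) {
    // one atomic read-modify-write: no gap between the read and the write
    threads.fetch_sub(1, Ordering::AcqRel);
}
```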