replaced `threads` member of `ThreadPool`

mutex replaced with atomic primitive
Sven Vogel 2023-05-24 12:11:48 +02:00
parent cbf42d6d57
commit 044a3f3747
1 changed file with 42 additions and 13 deletions
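At its core, the commit swaps a mutex-guarded counter for an atomic one. A minimal standalone sketch of the before/after representations (an illustration of the idea, not the committed code):

```rust
use std::sync::{
    atomic::{AtomicUsize, Ordering},
    Arc, Mutex,
};

fn main() {
    // before: every read or update takes the lock and unwraps a LockResult
    let counter = Arc::new(Mutex::new(0usize));
    *counter.lock().unwrap() += 1;

    // after: lock-free load/store, no Result to handle
    let counter = Arc::new(AtomicUsize::new(0));
    counter.store(counter.load(Ordering::Acquire) + 1, Ordering::Release);
}
```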

@@ -1,9 +1,11 @@
 use std::{
     any::Any,
     collections::VecDeque,
-    num::{NonZeroU32, NonZeroUsize},
-    ops::{AddAssign, SubAssign},
-    sync::{Arc, Mutex},
+    num::NonZeroUsize,
+    sync::{
+        atomic::{AtomicUsize, Ordering},
+        Arc, Mutex,
+    },
     thread::{self, JoinHandle},
 };
@@ -15,7 +17,7 @@ pub const FALLBACK_THREADS: usize = 1;
 /// Returns the number of threads to be used by the thread pool by default.
 /// This function tries to fetch a recommended number by calling [`thread::available_parallelism`].
 /// In case this fails [`FALLBACK_THREADS`] will be returned.
-fn get_default_thread_count() -> u32 {
+fn get_default_thread_count() -> usize {
     // number of threads to fall back to
     let fallback_threads =
         NonZeroUsize::new(FALLBACK_THREADS).expect("fallback_threads must be nonzero");
@@ -23,7 +25,7 @@ fn get_default_thread_count() -> u32 {
     // most of the time this is going to be the number of CPUs
     thread::available_parallelism()
         .unwrap_or(fallback_threads)
-        .get() as u32
+        .get()
 }
 
 /// This struct manages a pool of threads with a fixed maximum number.
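The fallback logic above boils down to one call chain. A standalone sketch of the same pattern (hypothetical usage, not part of the diff):

```rust
use std::{num::NonZeroUsize, thread};

fn main() {
    // mirrors get_default_thread_count: ask the OS for a recommended
    // degree of parallelism, fall back to 1 if the query fails
    let fallback = NonZeroUsize::new(1).expect("1 is nonzero");
    let threads = thread::available_parallelism().unwrap_or(fallback).get();
    println!("default thread count: {threads}");
}
```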
@@ -34,7 +36,7 @@ fn get_default_thread_count() -> u32 {
 /// The pool will also keep track of every `JoinHandle` created by running every closure on
 /// its own thread. The closures can be obtained by either calling `join_all` or `get_finished`.
 /// # Example
-/// ```rust
+/// ```rust ignore
 /// let mut pool = ThreadPool::new();
 ///
 /// // launch some work in parallel
@@ -46,6 +48,14 @@ fn get_default_thread_count() -> u32 {
 /// // wait for threads to finish
 /// pool.join_all();
 /// ```
+/// # Portability
+/// This implementation is not fully platform independent, due to its use of [`std::sync::atomic::AtomicUsize`].
+/// This type is used to remove some locks from the otherwise used [`std::sync::Mutex`] wrapping a [`usize`].
+/// Note that atomic primitives are not available on all platforms but "can generally be relied upon existing"
+/// (see: https://doc.rust-lang.org/std/sync/atomic/index.html).
+/// Additionally this implementation relies on the `load` and `store` operations
+/// instead of more convenient ones like `fetch_add`, in order to avoid unnecessary calls
+/// to `unwrap` or `expect` when locking a [`std::sync::Mutex`].
 #[allow(dead_code)]
 #[derive(Debug)]
 pub struct ThreadPool<F, T>
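The Portability note deliberately prefers `load`/`store` over read-modify-write helpers. For comparison, a sketch of both styles (an illustration of the trade-off, not code from this commit): `fetch_add` performs the read and the write as one indivisible step, while a separate `load` followed by a `store` can interleave with another thread.

```rust
use std::sync::atomic::{AtomicUsize, Ordering};

fn main() {
    let threads = AtomicUsize::new(0);

    // two separate operations: another thread may run between them,
    // so an increment can be lost under contention
    let current = threads.load(Ordering::Acquire);
    threads.store(current + 1, Ordering::Release);

    // one indivisible read-modify-write: no update can be lost,
    // and it never panics (unlike `Mutex::lock().unwrap()`)
    threads.fetch_add(1, Ordering::AcqRel);
}
```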
@@ -53,13 +63,18 @@ where
     F: Send + FnOnce() -> T + Send + 'static,
 {
     /// maximum number of threads to launch at once
-    max_thread_count: u32,
+    max_thread_count: usize,
     /// handles for launched threads
     handles: Arc<Mutex<Vec<JoinHandle<T>>>>,
     /// function to be executed when threads are ready
     queue: Arc<Mutex<VecDeque<F>>>,
     /// number of currently running threads
-    threads: Arc<Mutex<u32>>,
+    /// the new implementation relies on atomic primitives to avoid locking and possible
+    /// guard errors. Note that atomics are not available on all platforms but "can generally be relied upon existing"
+    /// (see: https://doc.rust-lang.org/std/sync/atomic/index.html).
+    /// Also this implementation relies on the `load` and `store` operations
+    /// instead of more convenient ones like `fetch_add`
+    threads: Arc<AtomicUsize>,
 }
 
 impl<F, T> Default for ThreadPool<F, T>
@@ -99,7 +114,7 @@ where
     /// # Overusage
     /// supplying too great a number of threads may negatively impact performance as the system may not
     /// be able to fulfill the required needs
-    pub fn with_threads(max_thread_count: NonZeroU32) -> Self {
+    pub fn with_threads(max_thread_count: NonZeroUsize) -> Self {
         Self {
             max_thread_count: max_thread_count.get(),
             ..Default::default()
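`NonZeroUsize` moves the "thread count must not be zero" check to construction time. A standalone sketch of how a caller would build the argument (hypothetical usage, not part of the diff):

```rust
use std::num::NonZeroUsize;

fn main() {
    // a zero thread count is unrepresentable at the type level
    let four = NonZeroUsize::new(4).expect("4 is nonzero");
    assert_eq!(four.get(), 4);

    // zero is rejected when the value is built, not inside the pool
    assert!(NonZeroUsize::new(0).is_none());

    // hypothetical call, assuming this crate's ThreadPool is in scope:
    // let pool = ThreadPool::with_threads(four);
}
```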
@@ -113,10 +128,13 @@ where
     /// If `join_all` is called and the closure hasn't been executed yet, `join_all` will wait for all stalled
     /// closures to be executed.
     pub fn enqueue(&mut self, closure: F) {
+        // read the used-thread counter; every update stores with Ordering::Release
+        let used_threads = self.threads.load(Ordering::Acquire);
         // test if we can launch a new thread
-        if self.threads.lock().unwrap().to_owned() < self.max_thread_count {
+        if used_threads < self.max_thread_count {
             // we can create a new thread, increment the thread count
-            self.threads.lock().unwrap().add_assign(1);
+            self.threads
+                .store(used_threads.saturating_add(1), Ordering::Release);
             // run new thread
             execute(
                 self.queue.clone(),
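Note that the new `enqueue` performs a check-then-act: the `load` and the later `store` are each atomic, but another caller may run between them and observe the same counter value. Where that matters, the usual lock-free pattern is a `compare_exchange` loop; a hedged sketch (the name `try_reserve_thread` is illustrative, not from this codebase):

```rust
use std::sync::atomic::{AtomicUsize, Ordering};

// race-free "increment if below the limit": retry until our update
// applies to the value we actually observed
fn try_reserve_thread(threads: &AtomicUsize, max: usize) -> bool {
    let mut current = threads.load(Ordering::Acquire);
    while current < max {
        match threads.compare_exchange(
            current,
            current + 1,
            Ordering::AcqRel,
            Ordering::Acquire,
        ) {
            Ok(_) => return true,            // we won the slot
            Err(actual) => current = actual, // counter changed; retry
        }
    }
    false // pool is already at max_thread_count
}

fn main() {
    let threads = AtomicUsize::new(0);
    assert!(try_reserve_thread(&threads, 1));
    assert!(!try_reserve_thread(&threads, 1));
}
```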
@@ -131,6 +149,14 @@ where
         }
     }
 
+    /// Removes all closures stalled for execution.
+    /// All closures still waiting to be executed will be dropped by the pool and
+    /// won't get executed. Useful if an old set of closures hasn't run yet but is outdated
+    /// and resources are required immediately for updated closures.
+    pub fn discard_stalled(&mut self) {
+        self.queue.lock().unwrap().clear();
+    }
+
     /// Waits for all currently running threads and all stalled closures to be executed.
     /// If any closure hasn't been executed yet, `join_all` will wait until the queue holding all
     /// unexecuted closures is empty. It returns the result every `join` of all threads yields as a vector.
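`discard_stalled` amounts to a single locked `clear` on the shared queue. A standalone sketch of that operation on the pool's queue type (illustrative, not the committed code):

```rust
use std::{
    collections::VecDeque,
    sync::{Arc, Mutex},
};

fn main() {
    // the queue type used by the pool, holding one stalled closure
    let queue: Arc<Mutex<VecDeque<Box<dyn FnOnce()>>>> =
        Arc::new(Mutex::new(VecDeque::new()));
    queue.lock().unwrap().push_back(Box::new(|| println!("stalled")));

    // discarding stalled work is one locked clear; the closures are dropped
    queue.lock().unwrap().clear();
    assert!(queue.lock().unwrap().is_empty());
}
```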
@@ -188,7 +214,7 @@ where
 fn execute<F, T>(
     queue: Arc<Mutex<VecDeque<F>>>,
     handles: Arc<Mutex<Vec<JoinHandle<T>>>>,
-    threads: Arc<Mutex<u32>>,
+    threads: Arc<AtomicUsize>,
     closure: F,
 ) where
     T: Send + 'static,
@@ -208,7 +234,10 @@ fn execute<F, T>(
     } else {
         // nothing to execute; this thread will run out without any work to do
         // decrement the amount of used threads
-        threads.lock().unwrap().sub_assign(1);
+        threads.store(
+            threads.load(Ordering::Acquire).saturating_sub(1),
+            Ordering::Release,
+        )
     }
 
     result
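The decrement has the same load-then-store shape as the increment in `enqueue`. If the non-atomic gap ever became a problem, `fetch_sub` (or `fetch_update` for a saturating variant) performs it in one indivisible step; a sketch for comparison (not part of the commit):

```rust
use std::sync::atomic::{AtomicUsize, Ordering};

fn main() {
    let threads = AtomicUsize::new(1);

    // indivisible decrement; note it wraps below zero, so it is only
    // safe when the counter is known to be positive
    threads.fetch_sub(1, Ordering::AcqRel);

    // saturating alternative: the closure is retried atomically on contention
    let _ = threads.fetch_update(Ordering::AcqRel, Ordering::Acquire, |n| {
        Some(n.saturating_sub(1))
    });
}
```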