replaced `threads` memeber from `Threadpool`
mutex with atomic primitive
This commit is contained in:
parent
cbf42d6d57
commit
044a3f3747
|
@ -1,9 +1,11 @@
|
||||||
use std::{
|
use std::{
|
||||||
any::Any,
|
any::Any,
|
||||||
collections::VecDeque,
|
collections::VecDeque,
|
||||||
num::{NonZeroU32, NonZeroUsize},
|
num::NonZeroUsize,
|
||||||
ops::{AddAssign, SubAssign},
|
sync::{
|
||||||
sync::{Arc, Mutex},
|
atomic::{AtomicUsize, Ordering},
|
||||||
|
Arc, Mutex,
|
||||||
|
},
|
||||||
thread::{self, JoinHandle},
|
thread::{self, JoinHandle},
|
||||||
};
|
};
|
||||||
|
|
||||||
|
@ -15,7 +17,7 @@ pub const FALLBACK_THREADS: usize = 1;
|
||||||
/// Returns the number of threads to be used by the thread pool by default.
|
/// Returns the number of threads to be used by the thread pool by default.
|
||||||
/// This function tries to fetch a recommended number by calling [`thread::available_parallelism`].
|
/// This function tries to fetch a recommended number by calling [`thread::available_parallelism`].
|
||||||
/// In case this fails [`FALLBACK_THREADS`] will be returned
|
/// In case this fails [`FALLBACK_THREADS`] will be returned
|
||||||
fn get_default_thread_count() -> u32 {
|
fn get_default_thread_count() -> usize {
|
||||||
// number of threads to fallback to
|
// number of threads to fallback to
|
||||||
let fallback_threads =
|
let fallback_threads =
|
||||||
NonZeroUsize::new(FALLBACK_THREADS).expect("fallback_threads must be nonzero");
|
NonZeroUsize::new(FALLBACK_THREADS).expect("fallback_threads must be nonzero");
|
||||||
|
@ -23,7 +25,7 @@ fn get_default_thread_count() -> u32 {
|
||||||
// most of the time this is gonna be the number of cpus
|
// most of the time this is gonna be the number of cpus
|
||||||
thread::available_parallelism()
|
thread::available_parallelism()
|
||||||
.unwrap_or(fallback_threads)
|
.unwrap_or(fallback_threads)
|
||||||
.get() as u32
|
.get()
|
||||||
}
|
}
|
||||||
|
|
||||||
/// This struct manages a pool of threads with a fixed maximum number.
|
/// This struct manages a pool of threads with a fixed maximum number.
|
||||||
|
@ -34,7 +36,7 @@ fn get_default_thread_count() -> u32 {
|
||||||
/// The pool will also keep track of every `JoinHandle` created by running every closure on
|
/// The pool will also keep track of every `JoinHandle` created by running every closure on
|
||||||
/// its on thread. The closures can be obtained by either calling `join_all` or `get_finished`.
|
/// its on thread. The closures can be obtained by either calling `join_all` or `get_finished`.
|
||||||
/// # Example
|
/// # Example
|
||||||
/// ```rust
|
/// ```rust ignore
|
||||||
/// let mut pool = ThreadPool::new();
|
/// let mut pool = ThreadPool::new();
|
||||||
///
|
///
|
||||||
/// // launch some work in parallel
|
/// // launch some work in parallel
|
||||||
|
@ -46,6 +48,14 @@ fn get_default_thread_count() -> u32 {
|
||||||
/// // wait for threads to finish
|
/// // wait for threads to finish
|
||||||
/// pool.join_all();
|
/// pool.join_all();
|
||||||
/// ```
|
/// ```
|
||||||
|
/// # Portability
|
||||||
|
/// This implementation is not fully platform independent. This is due to the usage of [`std::sync::atomic::AtomicUsize`].
|
||||||
|
/// This type is used to remove some locks from otherwise used [`std::sync::Mutex`] wrapping a [`usize`].
|
||||||
|
/// Note that atomic primitives are not available on all platforms but "can generally be relied upon existing"
|
||||||
|
/// (see: https://doc.rust-lang.org/std/sync/atomic/index.html).
|
||||||
|
/// Additionally this implementation relies on using the `load` and `store` operations
|
||||||
|
/// instead of using more comfortable one like `fetch_add` in order to avoid unnecessary calls
|
||||||
|
/// to `unwrap` or `expected` from [`std::sync::MutexGuard`].
|
||||||
#[allow(dead_code)]
|
#[allow(dead_code)]
|
||||||
#[derive(Debug)]
|
#[derive(Debug)]
|
||||||
pub struct ThreadPool<F, T>
|
pub struct ThreadPool<F, T>
|
||||||
|
@ -53,13 +63,18 @@ where
|
||||||
F: Send + FnOnce() -> T + Send + 'static,
|
F: Send + FnOnce() -> T + Send + 'static,
|
||||||
{
|
{
|
||||||
/// maximum number of threads to launch at once
|
/// maximum number of threads to launch at once
|
||||||
max_thread_count: u32,
|
max_thread_count: usize,
|
||||||
/// handles for launched threads
|
/// handles for launched threads
|
||||||
handles: Arc<Mutex<Vec<JoinHandle<T>>>>,
|
handles: Arc<Mutex<Vec<JoinHandle<T>>>>,
|
||||||
/// function to be executed when threads are ready
|
/// function to be executed when threads are ready
|
||||||
queue: Arc<Mutex<VecDeque<F>>>,
|
queue: Arc<Mutex<VecDeque<F>>>,
|
||||||
/// number of currently running threads
|
/// number of currently running threads
|
||||||
threads: Arc<Mutex<u32>>,
|
/// new implementation relies on atomic primitives to avoid locking and possible
|
||||||
|
/// guard errors. Note that atomic primitives are not available on all platforms "can generally be relied upon existing"
|
||||||
|
/// (see: https://doc.rust-lang.org/std/sync/atomic/index.html).
|
||||||
|
/// Also this implementation relies on using the `load` and `store` operations
|
||||||
|
/// instead of using more comfortable one like `fetch_add`
|
||||||
|
threads: Arc<AtomicUsize>,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<F, T> Default for ThreadPool<F, T>
|
impl<F, T> Default for ThreadPool<F, T>
|
||||||
|
@ -99,7 +114,7 @@ where
|
||||||
/// # Overusage
|
/// # Overusage
|
||||||
/// supplying a number of threads to great may negatively impact performance as the system may not
|
/// supplying a number of threads to great may negatively impact performance as the system may not
|
||||||
/// be able to full fill the required needs
|
/// be able to full fill the required needs
|
||||||
pub fn with_threads(max_thread_count: NonZeroU32) -> Self {
|
pub fn with_threads(max_thread_count: NonZeroUsize) -> Self {
|
||||||
Self {
|
Self {
|
||||||
max_thread_count: max_thread_count.get(),
|
max_thread_count: max_thread_count.get(),
|
||||||
..Default::default()
|
..Default::default()
|
||||||
|
@ -113,10 +128,13 @@ where
|
||||||
/// If `join_all` is called and the closure hasn't been executed yet, `join_all` will wait for all stalled
|
/// If `join_all` is called and the closure hasn't been executed yet, `join_all` will wait for all stalled
|
||||||
/// closures be executed.
|
/// closures be executed.
|
||||||
pub fn enqueue(&mut self, closure: F) {
|
pub fn enqueue(&mut self, closure: F) {
|
||||||
|
// read used thread counter and apply all store operations with Ordering::Release
|
||||||
|
let used_threads = self.threads.load(Ordering::Acquire);
|
||||||
// test if we can launch a new thread
|
// test if we can launch a new thread
|
||||||
if self.threads.lock().unwrap().to_owned() < self.max_thread_count {
|
if used_threads < self.max_thread_count {
|
||||||
// we can create a new thread, increment the thread count
|
// we can create a new thread, increment the thread count
|
||||||
self.threads.lock().unwrap().add_assign(1);
|
self.threads
|
||||||
|
.store(used_threads.saturating_add(1), Ordering::Release);
|
||||||
// run new thread
|
// run new thread
|
||||||
execute(
|
execute(
|
||||||
self.queue.clone(),
|
self.queue.clone(),
|
||||||
|
@ -131,6 +149,14 @@ where
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Removes all closures stalled for execution.
|
||||||
|
/// All closures still waiting to be executed will be dropped by the pool and
|
||||||
|
/// won't get executed. Useful if an old set of closures hasn't run yet but are outdated
|
||||||
|
/// and resources are required immediately for updated closures.
|
||||||
|
pub fn discard_stalled(&mut self) {
|
||||||
|
self.queue.lock().unwrap().clear();
|
||||||
|
}
|
||||||
|
|
||||||
/// Waits for all currently running threads and all stalled closures to be executed.
|
/// Waits for all currently running threads and all stalled closures to be executed.
|
||||||
/// If any closure hasn't been executed yet, `join_all` will wait until the queue holding all
|
/// If any closure hasn't been executed yet, `join_all` will wait until the queue holding all
|
||||||
/// unexecuted closures is empty. It returns the result every `join` of all threads yields as a vector.
|
/// unexecuted closures is empty. It returns the result every `join` of all threads yields as a vector.
|
||||||
|
@ -188,7 +214,7 @@ where
|
||||||
fn execute<F, T>(
|
fn execute<F, T>(
|
||||||
queue: Arc<Mutex<VecDeque<F>>>,
|
queue: Arc<Mutex<VecDeque<F>>>,
|
||||||
handles: Arc<Mutex<Vec<JoinHandle<T>>>>,
|
handles: Arc<Mutex<Vec<JoinHandle<T>>>>,
|
||||||
threads: Arc<Mutex<u32>>,
|
threads: Arc<AtomicUsize>,
|
||||||
closure: F,
|
closure: F,
|
||||||
) where
|
) where
|
||||||
T: Send + 'static,
|
T: Send + 'static,
|
||||||
|
@ -208,7 +234,10 @@ fn execute<F, T>(
|
||||||
} else {
|
} else {
|
||||||
// nothing to execute this thread will run out without any work to do
|
// nothing to execute this thread will run out without any work to do
|
||||||
// decrement the amount of used threads
|
// decrement the amount of used threads
|
||||||
threads.lock().unwrap().sub_assign(1);
|
threads.store(
|
||||||
|
threads.load(Ordering::Acquire).saturating_sub(1),
|
||||||
|
Ordering::Release,
|
||||||
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
result
|
result
|
||||||
|
|
Loading…
Reference in New Issue