added multithreading crate with thread pool
This commit is contained in:
parent
e24028ffc6
commit
e48176707a
|
@ -0,0 +1,216 @@
|
||||||
|
use std::{
|
||||||
|
any::Any,
|
||||||
|
collections::VecDeque,
|
||||||
|
num::{NonZeroU32, NonZeroUsize},
|
||||||
|
ops::{AddAssign, SubAssign},
|
||||||
|
sync::{Arc, Mutex},
|
||||||
|
thread::{self, JoinHandle},
|
||||||
|
};
|
||||||
|
|
||||||
|
/// Maximum number of thread to be used by the thread pool in case all methods
|
||||||
|
/// of determining a recommend number failed
|
||||||
|
#[allow(unused)]
|
||||||
|
pub const FALLBACK_THREADS: usize = 1;
|
||||||
|
|
||||||
|
/// Returns the number of threads to be used by the thread pool by default.
|
||||||
|
/// This function tries to fetch a recommended number by calling [`thread::available_parallelism`].
|
||||||
|
/// In case this fails [`FALLBACK_THREADS`] will be returned
|
||||||
|
fn get_default_thread_count() -> u32 {
|
||||||
|
// number of threads to fallback to
|
||||||
|
let fallback_threads =
|
||||||
|
NonZeroUsize::new(FALLBACK_THREADS).expect("fallback_threads must be nonzero");
|
||||||
|
// determine the maximum recommend number of threads to use
|
||||||
|
// most of the time this is gonna be the number of cpus
|
||||||
|
thread::available_parallelism()
|
||||||
|
.unwrap_or(fallback_threads)
|
||||||
|
.get() as u32
|
||||||
|
}
|
||||||
|
|
||||||
|
/// This struct manages a pool of threads with a fixed maximum number.
|
||||||
|
/// Any time a closure is passed to `enqueue` the pool checks whether it can
|
||||||
|
/// directly launch a new thread to execute the closure. If the maximum number
|
||||||
|
/// of threads is reached the closure is staged and will get executed by next
|
||||||
|
/// thread to be available.
|
||||||
|
/// The pool will also keep track of every `JoinHandle` created by running every closure on
|
||||||
|
/// its on thread. The closures can be obtained by either calling `join_all` or `get_finished`.
|
||||||
|
/// # Example
|
||||||
|
/// ```rust
|
||||||
|
/// let mut pool = ThreadPool::new();
|
||||||
|
///
|
||||||
|
/// // launch some work in parallel
|
||||||
|
/// for i in 0..10 {
|
||||||
|
/// pool.enqueue(move || {
|
||||||
|
/// println!("I am multithreaded and have id: {i}");
|
||||||
|
/// });
|
||||||
|
/// }
|
||||||
|
/// // wait for threads to finish
|
||||||
|
/// pool.join_all();
|
||||||
|
/// ```
|
||||||
|
#[allow(dead_code)]
|
||||||
|
#[derive(Debug)]
|
||||||
|
pub struct ThreadPool<F, T>
|
||||||
|
where
|
||||||
|
F: Send + FnOnce() -> T + Send + 'static,
|
||||||
|
{
|
||||||
|
/// maximum number of threads to launch at once
|
||||||
|
max_thread_count: u32,
|
||||||
|
/// handles for launched threads
|
||||||
|
handles: Arc<Mutex<Vec<JoinHandle<T>>>>,
|
||||||
|
/// function to be executed when threads are ready
|
||||||
|
queue: Arc<Mutex<VecDeque<F>>>,
|
||||||
|
/// number of currently running threads
|
||||||
|
threads: Arc<Mutex<u32>>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<F, T> Default for ThreadPool<F, T>
|
||||||
|
where
|
||||||
|
F: Send + FnOnce() -> T + Send + 'static,
|
||||||
|
{
|
||||||
|
fn default() -> Self {
|
||||||
|
Self {
|
||||||
|
max_thread_count: get_default_thread_count(),
|
||||||
|
handles: Default::default(),
|
||||||
|
queue: Default::default(),
|
||||||
|
// will be initialized to 0
|
||||||
|
threads: Default::default(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[allow(dead_code)]
|
||||||
|
impl<F, T> ThreadPool<F, T>
|
||||||
|
where
|
||||||
|
F: Send + FnOnce() -> T,
|
||||||
|
T: Send + 'static,
|
||||||
|
{
|
||||||
|
/// Create a new empty thread pool with the maximum number of threads set be the recommended amount of threads
|
||||||
|
/// supplied by [`std::thread::available_parallelism`] or in case the function fails [`FALLBACK_THREADS`].
|
||||||
|
/// # Limitations
|
||||||
|
/// This function may assume the wrong number of threads due to the nature of [`std::thread::available_parallelism`].
|
||||||
|
/// That can happen if the program runs inside of a container or vm with poorly configured parallelism.
|
||||||
|
pub fn new() -> Self {
|
||||||
|
Self {
|
||||||
|
max_thread_count: get_default_thread_count(),
|
||||||
|
..Default::default()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Create a new empty thread pool with the maximum number of threads set be the specified number
|
||||||
|
/// # Overusage
|
||||||
|
/// supplying a number of threads to great may negatively impact performance as the system may not
|
||||||
|
/// be able to full fill the required needs
|
||||||
|
pub fn with_threads(max_thread_count: NonZeroU32) -> Self {
|
||||||
|
Self {
|
||||||
|
max_thread_count: max_thread_count.get(),
|
||||||
|
..Default::default()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Pass a new closure to be executed as soon as a thread is available.
|
||||||
|
/// This function will execute the supplied closure immediately when the number of running threads
|
||||||
|
/// is lower than the maximum number of threads. Otherwise the closure will be executed at some undetermined time
|
||||||
|
/// in the future unless program doesn't die before.
|
||||||
|
/// If `join_all` is called and the closure hasn't been executed yet, `join_all` will wait for all stalled
|
||||||
|
/// closures be executed.
|
||||||
|
pub fn enqueue(&mut self, closure: F) {
|
||||||
|
// test if we can launch a new thread
|
||||||
|
if self.threads.lock().unwrap().to_owned() < self.max_thread_count {
|
||||||
|
// we can create a new thread, increment the thread count
|
||||||
|
self.threads.lock().unwrap().add_assign(1);
|
||||||
|
// run new thread
|
||||||
|
execute(
|
||||||
|
self.queue.clone(),
|
||||||
|
self.handles.clone(),
|
||||||
|
self.threads.clone(),
|
||||||
|
closure,
|
||||||
|
);
|
||||||
|
} else {
|
||||||
|
// all threads being used
|
||||||
|
// enqueue closure to be launched when a thread is ready
|
||||||
|
self.queue.lock().unwrap().push_back(closure);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Waits for all currently running threads and all stalled closures to be executed.
|
||||||
|
/// If any closure hasn't been executed yet, `join_all` will wait until the queue holding all
|
||||||
|
/// unexecuted closures is empty. It returns the result every `join` of all threads yields as a vector.
|
||||||
|
/// If the vector is of length zero, no threads were joined and the thread pool didn't do anything.
|
||||||
|
/// All handles of threads will be removed after this call.
|
||||||
|
pub fn join_all(&mut self) -> Vec<Result<T, Box<dyn Any + Send>>> {
|
||||||
|
let mut results = Vec::new();
|
||||||
|
loop {
|
||||||
|
// lock the handles, pop the last one off and unlock handles again
|
||||||
|
// to allow running threads to process
|
||||||
|
let handle = self.handles.lock().unwrap().pop();
|
||||||
|
|
||||||
|
// if we still have a handle join it else no handles are left we abort the loop
|
||||||
|
if let Some(handle) = handle {
|
||||||
|
results.push(handle.join());
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
|
||||||
|
results
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Returns the results of every thread that has already finished until now.
|
||||||
|
/// All other threads currently running won't be waited for nor for any closure stalled for execution in the future.
|
||||||
|
/// /// If the vector is of length zero, no threads were joined and the thread pool either doesn't do anything or is busy.
|
||||||
|
/// All handles of finished threads will be removed after this call.
|
||||||
|
pub fn get_finished(&mut self) -> Vec<Result<T, Box<dyn Any + Send>>> {
|
||||||
|
let mut results = Vec::new();
|
||||||
|
|
||||||
|
let mut handles = self.handles.lock().unwrap();
|
||||||
|
|
||||||
|
// loop through the handles and remove all finished handles
|
||||||
|
// join on the finished handles which will be quick as they are finished!
|
||||||
|
let mut idx = 0;
|
||||||
|
while idx < handles.len() {
|
||||||
|
if handles[idx].is_finished() {
|
||||||
|
// thread is finished, yield result
|
||||||
|
results.push(handles.remove(idx).join());
|
||||||
|
} else {
|
||||||
|
// thread isn't done, continue to the next one
|
||||||
|
idx += 1;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
results
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Execute the supplied closure on a new thread
|
||||||
|
/// and store the threads handle into `handles`. When the thread
|
||||||
|
/// finished executing the closure it will look for any closures left in `queue`
|
||||||
|
/// recursively execute it on a new thread. This method updates threads` in order to
|
||||||
|
/// keep track of the number of active threads.
|
||||||
|
fn execute<F, T>(
|
||||||
|
queue: Arc<Mutex<VecDeque<F>>>,
|
||||||
|
handles: Arc<Mutex<Vec<JoinHandle<T>>>>,
|
||||||
|
threads: Arc<Mutex<u32>>,
|
||||||
|
closure: F,
|
||||||
|
) where
|
||||||
|
T: Send + 'static,
|
||||||
|
F: Send + FnOnce() -> T + Send + 'static,
|
||||||
|
{
|
||||||
|
let handles_copy = handles.clone();
|
||||||
|
|
||||||
|
handles.lock().unwrap().push(thread::spawn(move || {
|
||||||
|
// run closure (actual work)
|
||||||
|
let result = closure();
|
||||||
|
|
||||||
|
// take the next closure stalled for execution
|
||||||
|
let next = queue.lock().unwrap().pop_front();
|
||||||
|
if let Some(next_closure) = next {
|
||||||
|
// if we have sth. to execute, spawn a new thread
|
||||||
|
execute(queue, handles_copy, threads, next_closure);
|
||||||
|
} else {
|
||||||
|
// nothing to execute this thread will run out without any work to do
|
||||||
|
// decrement the amount of used threads
|
||||||
|
threads.lock().unwrap().sub_assign(1);
|
||||||
|
}
|
||||||
|
|
||||||
|
result
|
||||||
|
}));
|
||||||
|
}
|
Loading…
Reference in New Issue