// Thread pool that runs closures on a bounded number of OS threads
// (added as `src/multithreading/mod.rs`).

use std::{
    any::Any,
    collections::VecDeque,
    num::{NonZeroU32, NonZeroUsize},
    panic::{self, AssertUnwindSafe},
    sync::{Arc, Mutex},
    thread::{self, JoinHandle},
};

/// Maximum number of threads to be used by the thread pool in case every
/// method of determining a recommended number failed.
// NOTE(review): the `allow` is kept from the original; presumably the constant
// is not referenced from every build target — confirm before removing.
#[allow(unused)]
pub const FALLBACK_THREADS: usize = 1;

/// Returns the number of threads the pool uses by default.
///
/// The value comes from [`thread::available_parallelism`]; if that call fails,
/// [`FALLBACK_THREADS`] is returned instead. The result is never zero.
///
/// # Panics
/// Panics if [`FALLBACK_THREADS`] is set to zero.
fn get_default_thread_count() -> u32 {
    let fallback =
        NonZeroUsize::new(FALLBACK_THREADS).expect("FALLBACK_THREADS must be nonzero");

    // Most of the time this is the number of logical CPUs.
    let recommended = thread::available_parallelism()
        .unwrap_or(fallback)
        .get();

    // Clamp before narrowing so an (unrealistically) large value saturates at
    // `u32::MAX` instead of being silently truncated by the cast.
    recommended.min(u32::MAX as usize) as u32
}

/// A pool of threads with a fixed maximum number of concurrently running workers.
///
/// Every time a closure is passed to [`ThreadPool::enqueue`] the pool checks
/// whether it may launch another thread. If the maximum is already reached the
/// closure is staged in a queue and executed as soon as a running worker
/// finishes its current closure.
///
/// The pool keeps the [`JoinHandle`] of every thread it spawns (one thread per
/// closure). The results can be collected with [`ThreadPool::join_all`] or
/// [`ThreadPool::get_finished`].
///
/// # Example
/// ```rust,ignore
/// let mut pool = ThreadPool::new();
///
/// // launch some work in parallel
/// for i in 0..10 {
///     pool.enqueue(move || {
///         println!("I am multithreaded and have id: {i}");
///     });
/// }
/// // wait for threads to finish
/// pool.join_all();
/// ```
// NOTE(review): `dead_code` is allowed as in the original; presumably the pool
// is not used by every build target yet — confirm before removing.
#[allow(dead_code)]
#[derive(Debug)]
pub struct ThreadPool<T, F> {
    /// Maximum number of closures that may run at the same time.
    max_thread_count: u32,
    /// Handles of all spawned threads that have not been joined yet.
    handles: Arc<Mutex<Vec<JoinHandle<T>>>>,
    /// Closures waiting for a free slot, in submission order.
    queue: Arc<Mutex<VecDeque<F>>>,
    /// Number of currently occupied slots (running worker chains).
    threads: Arc<Mutex<u32>>,
}

impl<T, F> Default for ThreadPool<T, F> {
    /// Creates an empty pool limited to [`get_default_thread_count`] threads.
    fn default() -> Self {
        Self {
            max_thread_count: get_default_thread_count(),
            handles: Arc::default(),
            queue: Arc::default(),
            // starts at zero: no slot is occupied yet
            threads: Arc::default(),
        }
    }
}

#[allow(dead_code)]
impl<T, F> ThreadPool<T, F>
where
    F: FnOnce() -> T + Send + 'static,
    T: Send + 'static,
{
    /// Creates a new empty thread pool whose maximum number of threads is the
    /// amount recommended by [`std::thread::available_parallelism`] or, if that
    /// function fails, [`FALLBACK_THREADS`].
    ///
    /// # Limitations
    /// The recommended number may be wrong due to the nature of
    /// [`std::thread::available_parallelism`], e.g. when the program runs inside
    /// a container or VM with poorly configured parallelism.
    pub fn new() -> Self {
        Self::default()
    }

    /// Creates a new empty thread pool limited to `max_thread_count` threads.
    ///
    /// # Overusage
    /// Supplying a number of threads that is too large may negatively impact
    /// performance, as the system may not be able to fulfil the demand.
    pub fn with_threads(max_thread_count: NonZeroU32) -> Self {
        // Built field by field so the system is not queried for a default
        // thread count that would be thrown away immediately.
        Self {
            max_thread_count: max_thread_count.get(),
            handles: Arc::default(),
            queue: Arc::default(),
            threads: Arc::default(),
        }
    }

    /// Submits a closure for execution.
    ///
    /// If fewer than the maximum number of threads are running, the closure is
    /// started on a new thread immediately. Otherwise it is queued and started
    /// by the next worker that finishes, at some undetermined point in the
    /// future (provided the program does not exit first).
    ///
    /// [`ThreadPool::join_all`] also waits for closures that are still queued.
    ///
    /// # Panics
    /// Panics if one of the pool's internal mutexes is poisoned.
    pub fn enqueue(&mut self, closure: F) {
        // The counter lock is held for the whole decision. Workers take the
        // same lock before they look at the queue, so a worker can never see
        // an empty queue and give up its slot between our check and our push
        // (which would strand the closure in the queue forever).
        // Lock order everywhere: `threads` first, then `queue`.
        let mut running = self.threads.lock().expect("thread counter mutex poisoned");

        if *running < self.max_thread_count {
            // a slot is free: claim it and start a thread right away
            *running += 1;
            drop(running);
            execute(
                Arc::clone(&self.queue),
                Arc::clone(&self.handles),
                Arc::clone(&self.threads),
                closure,
            );
        } else {
            // all slots are taken: stage the closure for the next free worker
            self.queue
                .lock()
                .expect("closure queue mutex poisoned")
                .push_back(closure);
        }
    }

    /// Waits for all running threads and all queued closures to finish.
    ///
    /// Returns the result of `join` for every thread. Handles are taken from
    /// the back of the internal list, so the results are not in submission
    /// order. An empty vector means there was nothing to join. All handles are
    /// removed from the pool by this call.
    ///
    /// # Panics
    /// Panics if the handle list mutex is poisoned.
    pub fn join_all(&mut self) -> Vec<Result<T, Box<dyn Any + Send>>> {
        let mut results = Vec::new();

        loop {
            // Take one handle in its own statement so the lock is released
            // before joining: a finishing worker needs this lock to register
            // the thread it spawns for the next queued closure.
            let handle = self
                .handles
                .lock()
                .expect("handle list mutex poisoned")
                .pop();

            match handle {
                Some(handle) => results.push(handle.join()),
                // no handles left: every closure has been executed
                None => break,
            }
        }

        results
    }

    /// Returns the results of every thread that has already finished.
    ///
    /// Threads that are still running and closures that are still queued are
    /// not waited for. An empty vector means that nothing has finished yet or
    /// that the pool has nothing to do. The handles of the finished threads are
    /// removed from the pool; the remaining handles keep their order.
    ///
    /// # Panics
    /// Panics if the handle list mutex is poisoned.
    pub fn get_finished(&mut self) -> Vec<Result<T, Box<dyn Any + Send>>> {
        let finished = {
            let mut handles = self.handles.lock().expect("handle list mutex poisoned");

            // Split in a single pass instead of repeated `Vec::remove` calls.
            let (finished, pending): (Vec<_>, Vec<_>) =
                handles.drain(..).partition(|handle| handle.is_finished());
            *handles = pending;
            finished
        };

        // Joining is quick because these threads have already finished; the
        // lock is released first so workers are never held up by it.
        finished.into_iter().map(JoinHandle::join).collect()
    }
}

/// Runs `closure` on a new thread and stores the thread's handle in `handles`.
///
/// The caller must already have claimed a slot in `threads` for this worker.
/// When the closure is done (whether it returned or panicked), the worker takes
/// the next closure from `queue` and hands its slot to a new thread running it.
/// If the queue is empty the slot is released by decrementing `threads`.
///
/// A panic inside `closure` is caught only to perform this bookkeeping and is
/// then resumed, so joining the thread still yields `Err` with the payload.
fn execute<T, F>(
    queue: Arc<Mutex<VecDeque<F>>>,
    handles: Arc<Mutex<Vec<JoinHandle<T>>>>,
    threads: Arc<Mutex<u32>>,
    closure: F,
) where
    T: Send + 'static,
    F: FnOnce() -> T + Send + 'static,
{
    let worker_handles = Arc::clone(&handles);

    // The handle list stays locked while spawning so that handles are stored
    // in spawn order.
    let mut registry = handles.lock().expect("handle list mutex poisoned");

    registry.push(thread::spawn(move || {
        // run the closure (actual work); a panic must not skip the hand-over
        // below, otherwise the slot and every queued closure would be lost
        let outcome = panic::catch_unwind(AssertUnwindSafe(closure));

        // Take the next queued closure, or release the slot. Both happen under
        // the counter lock (same lock order as `enqueue`) so no closure can be
        // queued between "queue is empty" and "slot released".
        let next = {
            let mut running = threads.lock().expect("thread counter mutex poisoned");
            let next = queue
                .lock()
                .expect("closure queue mutex poisoned")
                .pop_front();
            if next.is_none() {
                *running -= 1;
            }
            next
        };

        if let Some(next_closure) = next {
            // pass this worker's slot on to a fresh thread
            execute(queue, worker_handles, threads, next_closure);
        }

        match outcome {
            Ok(value) => value,
            // re-raise the closure's panic so `join` reports it
            Err(payload) => panic::resume_unwind(payload),
        }
    }));
}