diff --git a/benches/multithreading.rs b/benches/multithreading.rs index c013266..4f9c642 100644 --- a/benches/multithreading.rs +++ b/benches/multithreading.rs @@ -4,7 +4,7 @@ //! Each thread will calculate a partial dot product of two different vectors composed of 1,000,000 64-bit //! double precision floating point values. -use std::{sync::Arc}; +use std::sync::Arc; use criterion::{black_box, criterion_group, criterion_main, BenchmarkId, Criterion, Throughput}; use imsearch::multithreading::ThreadPool; @@ -37,8 +37,7 @@ fn dot(a: &[f64], b: &[f64]) -> f64 { /// sized slices which then get passed to their own thread to compute the partial dot product. After all threads have /// finished the partial dot products will be summed to create the final result. fn dot_parallel(a: Arc>, b: Arc>, threads: usize) { - let mut pool = - ThreadPool::with_limit(threads); + let mut pool = ThreadPool::with_limit(threads); // number of elements in each vector for each thread let steps = a.len() / threads; @@ -56,7 +55,7 @@ fn dot_parallel(a: Arc>, b: Arc>, threads: usize) { dot(a, b) }); } - + pool.join_all(); black_box(pool.get_results().iter().sum::()); diff --git a/src/multithreading/mod.rs b/src/multithreading/mod.rs index e7cf0fe..d20347a 100644 --- a/src/multithreading/mod.rs +++ b/src/multithreading/mod.rs @@ -1,35 +1,114 @@ -use std::{thread::{JoinHandle, self}, sync::{mpsc::{Receiver, channel, Sender}, Mutex, Arc}, num::NonZeroUsize, collections::VecDeque}; +//! This module provides the functionality to create a thread pool to execute tasks in parallel. +//! The amount of threads to be used at maximum can be regulated by using `ThreadPool::with_limit`. +//! This implementation is aimed to be of low runtime cost with minimal synchronisation due to blocking. +//! Note that no threads will be spawned until jobs are supplied to be executed. For every supplied job +//! a new thread will be launched until the maximum number is reached. By then every launched thread will +//! 
be reused to process the remaining elements of the queue. If no jobs are left to be executed +//! all threads will finish and die. This means that if nothing is done, no threads will run in idle in the background. +//! # Example +//! ```rust +//! # use imsearch::multithreading::ThreadPool; +//! let mut pool = ThreadPool::with_limit(2); +//! +//! for i in 0..10 { +//! pool.enqueue(move || i); +//! } +//! +//! pool.join_all(); +//! assert_eq!(pool.get_results().iter().sum::(), 45); +//! ``` -const DEFAULT_THREAD_POOL_SIZE: usize = 1; +use std::{ + collections::VecDeque, + num::NonZeroUsize, + sync::{ + mpsc::{channel, Receiver, Sender}, + Arc, Mutex, + }, + thread::{self, JoinHandle}, +}; +/// Default number of threads to be used in case [`std::thread::available_parallelism`] fails. +pub const DEFAULT_THREAD_POOL_SIZE: usize = 1; + +/// Indicates the priority level of functions or closures which get supplied to the pool. +/// Use [`Priority::High`] to ensure the closure is executed before all closures that are already supplied +/// Use [`Priority::Low`] to ensure the closure is executed after all closures that are already supplied +#[derive(Debug, Copy, Clone, Hash, PartialEq, Eq, PartialOrd, Ord)] +pub enum Priority { + /// Indicate that the closure or function supplied to the thread pool + /// has higher priority than any other given to the pool until now. + /// The item will get enqueued at the start of the waiting-queue. + High, + /// Indicate that the closure or function supplied to the thread pool + /// has lower priority than the already supplied ones in this pool. + /// The item will get enqueued at the end of the waiting-queue. + Low, +} + +/// Jobs are functions which are executed by the thread pool. They can be stalled when no threads are +/// free to execute them directly. They are meant to be executed only once and be done. 
pub trait Job: Send + 'static + FnOnce() -> T -where T: Send +where + T: Send, { } impl Job for U - where U: Send + 'static + FnOnce() -> T, T: Send + 'static +where + U: Send + 'static + FnOnce() -> T, + T: Send + 'static, { } +/// Thread pool which can be used to execute functions or closures in parallel. +/// The amount of threads to be used at maximum can be regulated by using `ThreadPool::with_limit`. +/// This implementation is aimed to be of low runtime cost with minimal synchronisation due to blocking. +/// Note that no threads will be spawned until jobs are supplied to be executed. For every supplied job +/// a new thread will be launched until the maximum number is reached. By then every launched thread will +/// be reused to process the remaining elements of the queue. If no jobs are left to be executed +/// all threads will finish and die. This means that if nothing is done, no threads will run in idle in the background. +/// # Example +/// ```rust +/// # use imsearch::multithreading::ThreadPool; +/// let mut pool = ThreadPool::with_limit(2); +/// +/// for i in 0..10 { +/// pool.enqueue(move || i); +/// } +/// +/// pool.join_all(); +/// assert_eq!(pool.get_results().iter().sum::(), 45); +/// ``` #[derive(Debug)] -pub struct ThreadPool - where T: Send, F: Job +pub struct ThreadPool +where + T: Send, + F: Job, { + /// queue for storing the jobs to be executed queue: Arc>>, + /// handles for all threads currently running and processing jobs handles: Vec>, + /// receiver end for channel based communication between threads receiver: Receiver, + /// sender end for channel based communication between threads sender: Sender, + /// maximum amount of threads to be used in parallel limit: NonZeroUsize, } impl Default for ThreadPool -where T: Send + 'static, F: Job +where + T: Send + 'static, + F: Job, { fn default() -> Self { let (sender, receiver) = channel::(); - let default = NonZeroUsize::new(DEFAULT_THREAD_POOL_SIZE).expect("Thread limit must be non-zero"); + 
// determine default thread count to use based on the system + let default = + NonZeroUsize::new(DEFAULT_THREAD_POOL_SIZE).expect("Thread limit must be non-zero"); let limit = thread::available_parallelism().unwrap_or(default); Self { @@ -43,12 +122,22 @@ where T: Send + 'static, F: Job } impl ThreadPool -where T: Send + 'static, F: Job +where + T: Send + 'static, + F: Job, { + /// Creates a new thread pool with default thread count determined by either + /// [`std::thread::available_parallelism`] or [`DEFAULT_THREAD_POOL_SIZE`] in case it fails. + /// No threads will be launched until jobs are enqueued. pub fn new() -> Self { Default::default() } + /// Creates a new thread pool with the given thread count. The pool will continue to launch new threads even if + /// the system does not allow for that count of parallelism. + /// No threads will be launched until jobs are enqueued. + /// # Panics + /// This function will fail if `max_threads` is zero. pub fn with_limit(max_threads: usize) -> Self { Self { limit: NonZeroUsize::new(max_threads).expect("Thread limit must be non-zero"), @@ -56,7 +145,22 @@ where T: Send + 'static, F: Job } } - pub fn enqueue(&mut self, func: F) { + /// Put a new job into the queue to be executed by a thread in the future. + /// The priority of the job will determine if the job will be put at the start or end of the queue. + /// See [`crate::multithreading::Priority`]. + /// This function will create a new thread if the maximum number of threads is not reached. + /// In case the maximum number of threads is already used, the job is stalled and will get executed + /// when a thread is ready and it's at the start of the queue. 
+ pub fn enqueue_priorize(&mut self, func: F, priority: Priority) { + // put job into queue + let mut queue = self.queue.lock().unwrap(); + + // insert new job into queue depending on its priority + match priority { + Priority::High => queue.push_front(func), + Priority::Low => queue.push_back(func), + } + if self.handles.len() < self.limit.get() { // we can still launch threads to run in parallel @@ -69,21 +173,93 @@ where T: Send + 'static, F: Job tx.send(job()).expect("cannot send result"); } })); - - } else { - self.queue.lock().unwrap().push_back(func); } self.handles.retain(|h| !h.is_finished()); } + /// Put a new job into the queue to be executed by a thread in the future. + /// The priority of the job is automatically set to [`crate::multithreading::Priority::Low`]. + /// This function will create a new thread if the maximum number of threads is not reached. + /// In case the maximum number of threads is already used, the job is stalled and will get executed + /// when a thread is ready and it's at the start of the queue. + pub fn enqueue(&mut self, func: F) { + self.enqueue_priorize(func, Priority::Low); + } + + /// Wait for all threads to finish executing. This means that by the time all threads have finished + /// every task will have been executed too. In other words the threads finish when the queue of jobs is empty. + /// This function will block the caller thread. pub fn join_all(&mut self) { while let Some(handle) = self.handles.pop() { handle.join().unwrap(); } } - pub fn get_results(&mut self) -> Vec { + /// Returns all results that have been returned by the threads until now + /// and haven't been consumed yet. + /// All results retrieved from this call won't be returned on a second call. + /// This function is non-blocking. 
+ pub fn try_get_results(&mut self) -> Vec { self.receiver.try_iter().collect() } -} \ No newline at end of file + + /// Returns all results that have been returned by the threads until now + /// and haven't been consumed yet. The function will also wait for all threads to finish executing (empty the queue). + /// All results retrieved from this call won't be returned on a second call. + /// This function will block the caller thread. + pub fn get_results(&mut self) -> Vec { + self.join_all(); + self.try_get_results() + } +} + +#[cfg(test)] +mod test { + use super::*; + + #[test] + fn test_default() { + let mut pool = ThreadPool::default(); + + for i in 0..10 { + pool.enqueue_priorize(move || i, Priority::High); + } + + pool.join_all(); + + assert_eq!(pool.try_get_results().iter().sum::(), 45); + } + + #[test] + fn test_limit() { + let mut pool = ThreadPool::with_limit(2); + + for i in 0..10 { + pool.enqueue(move || i); + } + + assert_eq!(pool.handles.len(), 2); + assert_eq!(pool.limit.get(), 2); + + pool.join_all(); + + assert_eq!(pool.get_results().iter().sum::(), 45); + } + + #[test] + fn test_multiple() { + let mut pool = ThreadPool::with_limit(2); + + for i in 0..10 { + pool.enqueue(move || i); + } + + assert_eq!(pool.handles.len(), 2); + assert_eq!(pool.limit.get(), 2); + + pool.join_all(); + + assert_eq!(pool.get_results().iter().sum::(), 45); + } +}