//! This module provides the functionality to create a thread pool to execute tasks in parallel.
//! The maximum number of threads to be used can be set with `ThreadPool::with_limit`.
//! This implementation aims for a low runtime cost with minimal synchronisation due to blocking.
//! Note that no threads will be spawned until jobs are supplied to be executed. For every supplied job
//! a new thread will be launched until the maximum number is reached. From then on every launched thread
//! will be reused to process the remaining elements of the queue. If no jobs are left to be executed,
//! all threads will finish and die. This means that if nothing is done, no threads will idle in the background.
//! # Example
//! ```rust
//! # use imsearch::multithreading::ThreadPool;
//! # use imsearch::multithreading::Task;
//! let mut pool = ThreadPool::with_limit(2);
//!
//! for i in 0..10 {
//!     pool.enqueue(Task::new(i, |i| i));
//!     //                        ^^^^^^ closure or static function
//! }
//!
//! pool.join_all();
//! assert_eq!(pool.get_results().iter().sum::<i32>(), 45);
//! ```

use std::{
    any::Any,
    collections::VecDeque,
    num::NonZeroUsize,
    panic::{catch_unwind, AssertUnwindSafe},
    sync::{
        mpsc::{channel, Receiver, Sender},
        Arc, Mutex,
    },
    thread::{self, JoinHandle},
};

/// Default number of threads to be used in case [`std::thread::available_parallelism`] fails.
pub const DEFAULT_THREAD_POOL_SIZE: usize = 1;

/// Indicates the priority level of functions or closures which get supplied to the pool.
/// Use [`Priority::High`] to ensure that a closure is executed before all closures that were already supplied.
/// Use [`Priority::Low`] to ensure that a closure is executed after all closures that were already supplied.
#[derive(Debug, Copy, Clone, Hash, PartialEq, Eq, PartialOrd, Ord)]
pub enum Priority {
    /// Indicates that the closure or function supplied to the thread pool
    /// has a higher priority than any other given to the pool until now.
    /// The item will be enqueued at the start of the waiting queue.
    High,
    /// Indicates that the closure or function supplied to the thread pool
    /// has a lower priority than the ones already supplied to this pool.
    /// The item will be enqueued at the end of the waiting queue.
    Low,
}

/// Traits a return value has to implement when being given back by a function or closure.
pub trait Sendable: Any + Send + 'static {}

impl<T> Sendable for T where T: Any + Send + 'static {}

/// A task that will be executed at some point in the future by the thread pool.
/// At the heart of this struct is the function to be executed. This may be a closure.
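/// Because the job is stored as a plain function pointer, only named functions and
/// non-capturing closures can be supplied; a closure that captures its environment
/// will not coerce to `fn(I) -> T`.
/// # Example
/// ```rust
/// # use imsearch::multithreading::Task;
/// fn double(x: i32) -> i32 {
///     x * 2
/// }
///
/// // a named function works,
/// let task = Task::new(21, double);
/// // and so does a non-capturing closure
/// let square = Task::new(4, |x: i32| x * x);
/// ```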
#[derive(Debug, Copy, Clone)]
pub struct Task<I, T>
where
    I: Sendable,
    T: Sendable,
{
    job: fn(I) -> T,
    param: I,
}

impl<I, T> Task<I, T>
where
    I: Sendable,
    T: Sendable,
{
    pub fn new(param: I, job: fn(I) -> T) -> Self {
        Self { job, param }
    }
}

/// Thread pool which can be used to execute functions or closures in parallel.
/// The maximum number of threads to be used can be set with `ThreadPool::with_limit`.
/// This implementation aims for a low runtime cost with minimal synchronisation due to blocking.
/// Note that no threads will be spawned until jobs are supplied to be executed. For every supplied job
/// a new thread will be launched until the maximum number is reached. From then on every launched thread
/// will be reused to process the remaining elements of the queue. If no jobs are left to be executed,
/// all threads will finish and die. This means that if nothing is done, no threads will idle in the background.
/// # Example
/// ```rust
/// # use imsearch::multithreading::ThreadPool;
/// # use imsearch::multithreading::Task;
/// let mut pool = ThreadPool::with_limit(2);
///
/// for i in 0..10 {
///     pool.enqueue(Task::new(i, |i| i));
/// }
///
/// pool.join_all();
/// assert_eq!(pool.get_results().iter().sum::<i32>(), 45);
/// ```
/// # Drop
/// This struct implements the `Drop` trait. Upon being dropped, the pool will wait for all threads
/// to finish. This may take an arbitrary amount of time.
/// # Panics in the threads
/// When a function or closure panics, the executing thread will detect the unwind performed by `panic`, causing the
/// thread to print a message on stderr. The thread itself captures panics and won't terminate execution, but continues with
/// the next task in the queue.
/// It's not recommended to use this pool with custom panic hooks or special functions which abort the process.
/// Also, panicking code from external programs written in C++ or other languages is undefined behavior according to [`std::panic::catch_unwind`].
#[derive(Debug)]
pub struct ThreadPool<I, T>
where
    I: Sendable,
    T: Sendable,
{
    /// queue for storing the jobs to be executed
    queue: Arc<Mutex<VecDeque<Task<I, T>>>>,
    /// handles for all threads currently running and processing jobs
    handles: Vec<JoinHandle<()>>,
    /// receiver end for channel-based communication between threads
    receiver: Receiver<T>,
    /// sender end for channel-based communication between threads
    sender: Sender<T>,
    /// maximum number of threads to be used in parallel
    limit: NonZeroUsize,
}

impl<I, T> Default for ThreadPool<I, T>
where
    I: Sendable,
    T: Sendable,
{
    fn default() -> Self {
        let (sender, receiver) = channel::<T>();

        // determine the default thread count to use based on the system
        let default =
            NonZeroUsize::new(DEFAULT_THREAD_POOL_SIZE).expect("Thread limit must be non-zero");
        let limit = thread::available_parallelism().unwrap_or(default);

        Self {
            queue: Arc::new(Mutex::new(VecDeque::new())),
            handles: Vec::new(),
            receiver,
            sender,
            limit,
        }
    }
}

impl<I, T> ThreadPool<I, T>
where
    I: Sendable,
    T: Sendable,
{
    /// Creates a new thread pool with the default thread count determined by either
    /// [`std::thread::available_parallelism`] or [`DEFAULT_THREAD_POOL_SIZE`] in case it fails.
    /// No threads will be launched until jobs are enqueued.
    pub fn new() -> Self {
        Default::default()
    }

    /// Creates a new thread pool with the given thread count. The pool will continue to launch new threads even if
    /// the system does not allow for that degree of parallelism.
    /// No threads will be launched until jobs are enqueued.
    /// # Panics
    /// This function will panic if `max_threads` is zero.
    pub fn with_limit(max_threads: usize) -> Self {
        let (sender, receiver) = channel::<T>();
        Self {
            limit: NonZeroUsize::new(max_threads).expect("Thread limit must be non-zero"),
            queue: Arc::new(Mutex::new(VecDeque::new())),
            handles: Vec::new(),
            sender,
            receiver,
        }
    }

    /// Put a new job into the queue to be executed by a thread in the future.
    /// The priority of the job determines whether the job will be put at the start or the end of the queue.
    /// See [`crate::multithreading::Priority`].
    /// This function will create a new thread if the maximum number of threads has not been reached.
    /// In case the maximum number of threads is already in use, the job is stalled and will get executed
    /// when a thread is ready and the job is at the start of the queue.
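    /// # Example
    /// A job enqueued with [`Priority::High`] overtakes the jobs already waiting:
    /// ```rust
    /// # use imsearch::multithreading::{Priority, Task, ThreadPool};
    /// let mut pool = ThreadPool::with_limit(2);
    ///
    /// pool.enqueue_priorize(Task::new(1, |i| i + 1), Priority::High);
    ///
    /// pool.join_all();
    /// assert_eq!(pool.try_get_results(), vec![2]);
    /// ```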
    pub fn enqueue_priorize(&mut self, func: Task<I, T>, priority: Priority) {
        // insert the new job into the queue depending on its priority,
        // releasing the lock again before potentially spawning a new thread
        {
            let mut queue = self.queue.lock().unwrap();
            match priority {
                Priority::High => queue.push_front(func),
                Priority::Low => queue.push_back(func),
            }
        }

        // drop the handles of threads that have already finished
        // so that they no longer count towards the limit
        self.handles.retain(|h| !h.is_finished());

        if self.handles.len() < self.limit.get() {
            // we can still launch threads to run in parallel

            // clone the sender and the queue for the new thread
            let tx = self.sender.clone();
            let queue = self.queue.clone();

            self.handles.push(thread::spawn(move || loop {
                // take the next task and release the lock immediately, so
                // other workers can access the queue while this job runs
                let task = match queue.lock().unwrap().pop_front() {
                    Some(task) => task,
                    None => break,
                };

                // catch a panicking job so that it cannot kill this worker;
                // the default panic hook will already have printed to stderr
                if let Ok(result) = catch_unwind(AssertUnwindSafe(|| (task.job)(task.param))) {
                    tx.send(result).expect("unable to send result over channel");
                }
            }));
        }
    }

    /// Put a new job into the queue to be executed by a thread in the future.
    /// The priority of the job is automatically set to [`crate::multithreading::Priority::Low`].
    /// This function will create a new thread if the maximum number of threads has not been reached.
    /// In case the maximum number of threads is already in use, the job is stalled and will get executed
    /// when a thread is ready and the job is at the start of the queue.
    pub fn enqueue(&mut self, func: Task<I, T>) {
        self.enqueue_priorize(func, Priority::Low);
    }

    /// Wait for all threads to finish executing. This means that by the time all threads have finished,
    /// every task will have been executed too. In other words, the threads finish when the queue of jobs is empty.
    /// This function will block the caller thread.
    pub fn join_all(&mut self) {
        while let Some(handle) = self.handles.pop() {
            handle.join().unwrap();
        }
    }

    /// Returns all results that have been produced by the threads until now
    /// and haven't been consumed yet.
    /// All results retrieved from this call won't be returned on a second call.
    /// This function is non-blocking.
    pub fn try_get_results(&mut self) -> Vec<T> {
        self.receiver.try_iter().collect()
    }

    /// Returns all results that have been produced by the threads until now
    /// and haven't been consumed yet. The function will also wait for all threads to finish executing (empty the queue).
    /// All results retrieved from this call won't be returned on a second call.
    /// This function will block the caller thread.
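    /// # Example
    /// In contrast to [`ThreadPool::try_get_results`], this call waits for every pending job to finish:
    /// ```rust
    /// # use imsearch::multithreading::{Task, ThreadPool};
    /// let mut pool = ThreadPool::with_limit(2);
    /// pool.enqueue(Task::new(21, |i| i * 2));
    ///
    /// // blocks until the queue is empty and all threads have finished
    /// assert_eq!(pool.get_results(), vec![42]);
    /// ```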
    pub fn get_results(&mut self) -> Vec<T> {
        self.join_all();
        self.try_get_results()
    }
}

impl<I, T> Drop for ThreadPool<I, T>
where
    I: Sendable,
    T: Sendable,
{
    fn drop(&mut self) {
        // block until all pending jobs are done, as documented under `Drop`
        self.join_all();
    }
}

#[cfg(test)]
mod test {
    use std::panic::UnwindSafe;

    use super::*;

    #[test]
    fn test_default() {
        let mut pool = ThreadPool::default();

        for i in 0..10 {
            pool.enqueue_priorize(Task::new(i, |i| i), Priority::High);
        }

        pool.join_all();

        assert_eq!(pool.try_get_results().iter().sum::<i32>(), 45);
    }

    #[test]
    fn test_limit() {
        let mut pool = ThreadPool::with_limit(2);

        for i in 0..10 {
            pool.enqueue(Task::new(i, |i| i));
        }

        assert_eq!(pool.handles.len(), 2);
        assert_eq!(pool.limit.get(), 2);

        pool.join_all();

        assert_eq!(pool.get_results().iter().sum::<i32>(), 45);
    }

    trait Object: Send + UnwindSafe {
        fn get(&mut self) -> i32;
    }

    #[derive(Default)]
    struct Test1 {
        _int: i32,
    }

    impl Object for Test1 {
        fn get(&mut self) -> i32 {
            0
        }
    }

    #[derive(Default)]
    struct Test2 {
        _c: char,
    }

    impl Object for Test2 {
        fn get(&mut self) -> i32 {
            0
        }
    }

    #[derive(Default)]
    struct Test3 {
        _s: String,
    }

    impl Object for Test3 {
        fn get(&mut self) -> i32 {
            0
        }
    }

    #[test]
    fn test_1() {
        let mut pool = ThreadPool::with_limit(2);

        let feats: Vec<Box<dyn Object>> = vec![
            Box::new(Test1::default()),
            Box::new(Test2::default()),
            Box::new(Test3::default()),
        ];

        for feat in feats {
            pool.enqueue(Task::new(feat, |mut i| {
                let _ = i.get();
                i
            }));
        }

        pool.join_all();

        let _feats: Vec<Box<dyn Object>> = pool.get_results();
    }
}