diff --git a/benches/multithreading.rs b/benches/multithreading.rs index 4f9c642..6e2eac9 100644 --- a/benches/multithreading.rs +++ b/benches/multithreading.rs @@ -7,7 +7,7 @@ use std::sync::Arc; use criterion::{black_box, criterion_group, criterion_main, BenchmarkId, Criterion, Throughput}; -use imsearch::multithreading::ThreadPool; +use imsearch::multithreading::{Task, ThreadPool}; /// Amount of elements per vector used to calculate the dot product const VEC_ELEM_COUNT: usize = 1_000_000; @@ -45,15 +45,15 @@ fn dot_parallel(a: Arc>, b: Arc>, threads: usize) { for i in 0..threads { // offset of the first element for the thread local vec let chunk = i * steps; - // create a new strong reference to the vector - let aa = a.clone(); - let bb = b.clone(); // launch a new thread - pool.enqueue(move || { - let a = &aa[chunk..(chunk + steps)]; - let b = &bb[chunk..(chunk + steps)]; - dot(a, b) - }); + pool.enqueue(Task::new( + (chunk, steps, a.clone(), b.clone()), + |(block, inc, a, b)| { + let a = &a[block..(block + inc)]; + let b = &b[block..(block + inc)]; + dot(a, b) + }, + )); } pool.join_all(); @@ -115,15 +115,15 @@ fn pool_overusage(a: Arc>, b: Arc>, threads: usize) { for i in 0..threads { // offset of the first element for the thread local vec let chunk = i * steps; - // create a new strong reference to the vector - let aa = a.clone(); - let bb = b.clone(); // launch a new thread - pool.enqueue(move || { - let a = &aa[chunk..(chunk + steps)]; - let b = &bb[chunk..(chunk + steps)]; - dot(a, b) - }); + pool.enqueue(Task::new( + (chunk, steps, a.clone(), b.clone()), + |(block, inc, a, b)| { + let a = &a[block..(block + inc)]; + let b = &b[block..(block + inc)]; + dot(a, b) + }, + )); } pool.join_all(); diff --git a/src/multithreading/mod.rs b/src/multithreading/mod.rs index d20347a..f099cf5 100644 --- a/src/multithreading/mod.rs +++ b/src/multithreading/mod.rs @@ -8,10 +8,12 @@ //! # Example //! ```rust //! # use imsearch::multithreading::ThreadPool; +//! 
# use imsearch::multithreading::Task; //! let mut pool = ThreadPool::with_limit(2); //! //! for i in 0..10 { -//! pool.enqueue(move || i); +//! pool.enqueue(Task::new(i, |i| i)); +//! // ^^^^^^ closure or static function //! } //! //! pool.join_all(); @@ -19,6 +21,7 @@ //! ``` use std::{ + any::Any, collections::VecDeque, num::NonZeroUsize, sync::{ @@ -46,19 +49,31 @@ pub enum Priority { Low, } -/// Jobs are functions which are executed by the thread pool. They can be stalled when no threads are -/// free to execute them directly. They are meant to be executed only once and be done. -pub trait Job: Send + 'static + FnOnce() -> T +/// Traits a return value has to implement when being given back by a function or closure. +pub trait Sendable: Any + Send + 'static + std::panic::UnwindSafe {} + +impl Sendable for T where T: Any + Send + 'static + std::panic::UnwindSafe {} + +/// A task that will be executed at some point in the future by the thread pool. +/// At the heart of this struct is the function to be executed. This may be a closure. +#[derive(Debug, Copy, Clone)] +pub struct Task where - T: Send, + I: Sendable, + T: Sendable, { + job: fn(I) -> T, + param: I, } -impl Job for U +impl Task where - U: Send + 'static + FnOnce() -> T, - T: Send + 'static, + I: Sendable, + T: Sendable, { + pub fn new(param: I, job: fn(I) -> T) -> Self { + Self { job, param } + } } /// Thread pool which can be used to execute functions or closures in parallel. @@ -71,23 +86,33 @@ where /// # Example /// ```rust /// # use imsearch::multithreading::ThreadPool; +/// # use imsearch::multithreading::Task; /// let mut pool = ThreadPool::with_limit(2); /// /// for i in 0..10 { -/// pool.enqueue(move || i); +/// pool.enqueue(Task::new(i, |i| i)); /// } /// /// pool.join_all(); /// assert_eq!(pool.get_results().iter().sum::(), 45); /// ``` +/// # Drop +/// This struct implements the `Drop` trait. Upon being dropped the pool will wait for all threads +/// to finish. 
This may take up an arbitrary amount of time. +/// # Panics in the thread +/// When a function or closure panics, the executing thread will detect the unwind performed by `panic` causing the +/// thread to print a message on stderr. The thread itself captures panics and won't terminate execution but continue with +/// the next task in the queue. +/// It's not recommended to use this pool with custom panic hooks or special functions which abort the process. +/// Also, panicking code from external programs written in C++ or other languages is undefined behavior according to [`std::panic::catch_unwind`] #[derive(Debug)] -pub struct ThreadPool +pub struct ThreadPool where - T: Send, - F: Job, + I: Sendable, + T: Sendable, { /// queue for storing the jobs to be executed - queue: Arc>>, + queue: Arc>>>, /// handles for all threads currently running and processing jobs handles: Vec>, /// reciver end for channel based communication between threads @@ -98,10 +123,10 @@ where limit: NonZeroUsize, } -impl Default for ThreadPool +impl Default for ThreadPool where - T: Send + 'static, - F: Job, + I: Sendable, + T: Sendable, { fn default() -> Self { let (sender, receiver) = channel::(); @@ -121,10 +146,10 @@ where } } -impl ThreadPool +impl ThreadPool where - T: Send + 'static, - F: Job, + I: Sendable, + T: Sendable, { /// Creates a new thread pool with default thread count determined by either /// [`std::thread::available_parallelism`] or [`DEFAULT_THREAD_POOL_SIZE`] in case it fails. @@ -139,9 +164,13 @@ where /// # Panic /// This function will fails if `max_threads` is zero. pub fn with_limit(max_threads: usize) -> Self { + let (sender, receiver) = channel::(); Self { limit: NonZeroUsize::new(max_threads).expect("Thread limit must be non-zero"), - ..Default::default() + queue: Arc::new(Mutex::new(VecDeque::new())), + handles: Vec::new(), + sender, + receiver, } } @@ -151,7 +180,7 @@ where /// This function will create a new thread if the maximum number of threads in not reached. 
/// In case the maximum number of threads is already used, the job is stalled and will get executed /// when a thread is ready and its at the start of the queue. - pub fn enqueue_priorize(&mut self, func: F, priority: Priority) { + pub fn enqueue_priorize(&mut self, func: Task, priority: Priority) { // put job into queue let mut queue = self.queue.lock().unwrap(); @@ -169,8 +198,14 @@ where let queue = self.queue.clone(); self.handles.push(thread::spawn(move || { - while let Some(job) = queue.lock().unwrap().pop_front() { - tx.send(job()).expect("cannot send result"); + while let Some(task) = queue.lock().unwrap().pop_front() { + // basically try catch + if let Err(e) = std::panic::catch_unwind(|| { + tx.send((task.job)(task.param)) + .expect("unable to send result over channel"); + }) { + eprintln!("thread paniced: {:?}", e); + } } })); } @@ -183,7 +218,7 @@ where /// This function will create a new thread if the maximum number of threads in not reached. /// In case the maximum number of threads is already used, the job is stalled and will get executed /// when a thread is ready and its at the start of the queue. - pub fn enqueue(&mut self, func: F) { + pub fn enqueue(&mut self, func: Task) { self.enqueue_priorize(func, Priority::Low); } @@ -196,15 +231,15 @@ where } } - /// Returns all results that have been returned by the threads until now + /// Returns all results that have been returned by the threads until now /// and haven't been consumed yet. - /// All results retrieved from this call won't be returned on a second call. + /// All results retrieved from this call won't be returned on a second call. /// This function is non blocking. pub fn try_get_results(&mut self) -> Vec { self.receiver.try_iter().collect() } - /// Returns all results that have been returned by the threads until now + /// Returns all results that have been returned by the threads until now /// and haven't been consumed yet. 
The function will also wait for all threads to finish executing (empty the queue). /// All results retrieved from this call won't be returned on a second call. /// This function will block the caller thread. @@ -214,8 +249,20 @@ where } } +impl Drop for ThreadPool +where + I: Sendable, + T: Sendable, +{ + fn drop(&mut self) { + self.join_all(); + } +} + #[cfg(test)] mod test { + use std::panic::UnwindSafe; + use super::*; #[test] @@ -223,7 +270,7 @@ mod test { let mut pool = ThreadPool::default(); for i in 0..10 { - pool.enqueue_priorize(move || i, Priority::High); + pool.enqueue_priorize(Task::new(i, |i| i), Priority::High); } pool.join_all(); @@ -236,7 +283,7 @@ mod test { let mut pool = ThreadPool::with_limit(2); for i in 0..10 { - pool.enqueue(move || i); + pool.enqueue(Task::new(i, |i| i)); } assert_eq!(pool.handles.len(), 2); @@ -247,19 +294,62 @@ mod test { assert_eq!(pool.get_results().iter().sum::(), 45); } + trait Object: Send + UnwindSafe { + fn get(&mut self) -> i32; + } + + #[derive(Default)] + struct Test1 { + _int: i32, + } + + impl Object for Test1 { + fn get(&mut self) -> i32 { + 0 + } + } + + #[derive(Default)] + struct Test2 { + _c: char, + } + + impl Object for Test2 { + fn get(&mut self) -> i32 { + 0 + } + } + + #[derive(Default)] + struct Test3 { + _s: String, + } + + impl Object for Test3 { + fn get(&mut self) -> i32 { + 0 + } + } + #[test] - fn test_multiple() { + fn test_1() { let mut pool = ThreadPool::with_limit(2); - for i in 0..10 { - pool.enqueue(move || i); - } + let feats: Vec> = vec![ + Box::new(Test1::default()), + Box::new(Test2::default()), + Box::new(Test3::default()), + ]; - assert_eq!(pool.handles.len(), 2); - assert_eq!(pool.limit.get(), 2); + for feat in feats { + pool.enqueue(Task::new(feat, |mut i| { + let _ = i.get(); + i + })); + } pool.join_all(); - assert_eq!(pool.get_results().iter().sum::(), 45); + let _feats: Vec> = pool.get_results(); } }