- implemented `Drop` for thread pool
 - thread pool uses a struct to store a function of specific signature
This commit is contained in:
Sven Vogel 2023-06-07 13:30:22 +02:00
parent 52cb8639ea
commit 57016c1083
2 changed files with 143 additions and 53 deletions

View File

@ -7,7 +7,7 @@
use std::sync::Arc; use std::sync::Arc;
use criterion::{black_box, criterion_group, criterion_main, BenchmarkId, Criterion, Throughput}; use criterion::{black_box, criterion_group, criterion_main, BenchmarkId, Criterion, Throughput};
use imsearch::multithreading::ThreadPool; use imsearch::multithreading::{Task, ThreadPool};
/// Amount of elements per vector used to calculate the dot product /// Amount of elements per vector used to calculate the dot product
const VEC_ELEM_COUNT: usize = 1_000_000; const VEC_ELEM_COUNT: usize = 1_000_000;
@ -45,15 +45,15 @@ fn dot_parallel(a: Arc<Vec<f64>>, b: Arc<Vec<f64>>, threads: usize) {
for i in 0..threads { for i in 0..threads {
// offset of the first element for the thread local vec // offset of the first element for the thread local vec
let chunk = i * steps; let chunk = i * steps;
// create a new strong reference to the vector
let aa = a.clone();
let bb = b.clone();
// launch a new thread // launch a new thread
pool.enqueue(move || { pool.enqueue(Task::new(
let a = &aa[chunk..(chunk + steps)]; (chunk, steps, a.clone(), b.clone()),
let b = &bb[chunk..(chunk + steps)]; |(block, inc, a, b)| {
dot(a, b) let a = &a[block..(block + inc)];
}); let b = &b[block..(block + inc)];
dot(a, b)
},
));
} }
pool.join_all(); pool.join_all();
@ -115,15 +115,15 @@ fn pool_overusage(a: Arc<Vec<f64>>, b: Arc<Vec<f64>>, threads: usize) {
for i in 0..threads { for i in 0..threads {
// offset of the first element for the thread local vec // offset of the first element for the thread local vec
let chunk = i * steps; let chunk = i * steps;
// create a new strong reference to the vector
let aa = a.clone();
let bb = b.clone();
// launch a new thread // launch a new thread
pool.enqueue(move || { pool.enqueue(Task::new(
let a = &aa[chunk..(chunk + steps)]; (chunk, steps, a.clone(), b.clone()),
let b = &bb[chunk..(chunk + steps)]; |(block, inc, a, b)| {
dot(a, b) let a = &a[block..(block + inc)];
}); let b = &b[block..(block + inc)];
dot(a, b)
},
));
} }
pool.join_all(); pool.join_all();

View File

@ -8,10 +8,12 @@
//! # Example //! # Example
//! ```rust //! ```rust
//! # use imsearch::multithreading::ThreadPool; //! # use imsearch::multithreading::ThreadPool;
//! # use imsearch::multithreading::Task;
//! let mut pool = ThreadPool::with_limit(2); //! let mut pool = ThreadPool::with_limit(2);
//! //!
//! for i in 0..10 { //! for i in 0..10 {
//! pool.enqueue(move || i); //! pool.enqueue(Task::new(i, |i| i));
//! // ^^^^^^ closure or static function
//! } //! }
//! //!
//! pool.join_all(); //! pool.join_all();
@ -19,6 +21,7 @@
//! ``` //! ```
use std::{ use std::{
any::Any,
collections::VecDeque, collections::VecDeque,
num::NonZeroUsize, num::NonZeroUsize,
sync::{ sync::{
@ -46,19 +49,31 @@ pub enum Priority {
Low, Low,
} }
/// Jobs are functions which are executed by the thread pool. They can be stalled when no threads are /// Traits a return value has to implement when being given back by a function or closure.
/// free to execute them directly. They are meant to be executed only once and be done. pub trait Sendable: Any + Send + 'static + std::panic::UnwindSafe {}
pub trait Job<T>: Send + 'static + FnOnce() -> T
impl<T> Sendable for T where T: Any + Send + 'static + std::panic::UnwindSafe {}
/// A task that will be executed at some point in the future by the thread pool
/// At the heart of this struct is the function to be executed. This may be a closure.
#[derive(Debug, Copy, Clone)]
pub struct Task<I, T>
where where
T: Send, I: Sendable,
T: Sendable,
{ {
job: fn(I) -> T,
param: I,
} }
impl<U, T> Job<T> for U impl<I, T> Task<I, T>
where where
U: Send + 'static + FnOnce() -> T, I: Sendable,
T: Send + 'static, T: Sendable,
{ {
pub fn new(param: I, job: fn(I) -> T) -> Self {
Self { job, param }
}
} }
/// Thread pool which can be used to execute functions or closures in parallel. /// Thread pool which can be used to execute functions or closures in parallel.
@ -71,23 +86,33 @@ where
/// # Example /// # Example
/// ```rust /// ```rust
/// # use imsearch::multithreading::ThreadPool; /// # use imsearch::multithreading::ThreadPool;
/// # use imsearch::multithreading::Task;
/// let mut pool = ThreadPool::with_limit(2); /// let mut pool = ThreadPool::with_limit(2);
/// ///
/// for i in 0..10 { /// for i in 0..10 {
/// pool.enqueue(move || i); /// pool.enqueue(Task::new(i, |i| i));
/// } /// }
/// ///
/// pool.join_all(); /// pool.join_all();
/// assert_eq!(pool.get_results().iter().sum::<i32>(), 45); /// assert_eq!(pool.get_results().iter().sum::<i32>(), 45);
/// ``` /// ```
/// # Drop
/// This struct implements the `Drop` trait. Upon being dropped the pool will wait for all threads
/// to finsish. This may take up an arbitrary amount of time.
/// # Panics in the thread
/// When a function or closure panics, the executing thread will detect the unwind performed by `panic` causing the
/// thread to print a message on stderr. The thread itself captures panics and won't terminate execution but continue with
/// the next task in the queue.
/// Its not recommend to use this pool with custom panic hooks or special functions which abort the process.
/// Also panicking code from external program written in C++ or others in undefinied behavior according to [`std::panic::catch_unwind`]
#[derive(Debug)] #[derive(Debug)]
pub struct ThreadPool<T, F> pub struct ThreadPool<I, T>
where where
T: Send, I: Sendable,
F: Job<T>, T: Sendable,
{ {
/// queue for storing the jobs to be executed /// queue for storing the jobs to be executed
queue: Arc<Mutex<VecDeque<F>>>, queue: Arc<Mutex<VecDeque<Task<I, T>>>>,
/// handles for all threads currently running and processing jobs /// handles for all threads currently running and processing jobs
handles: Vec<JoinHandle<()>>, handles: Vec<JoinHandle<()>>,
/// reciver end for channel based communication between threads /// reciver end for channel based communication between threads
@ -98,10 +123,10 @@ where
limit: NonZeroUsize, limit: NonZeroUsize,
} }
impl<T, F> Default for ThreadPool<T, F> impl<I, T> Default for ThreadPool<I, T>
where where
T: Send + 'static, I: Sendable,
F: Job<T>, T: Sendable,
{ {
fn default() -> Self { fn default() -> Self {
let (sender, receiver) = channel::<T>(); let (sender, receiver) = channel::<T>();
@ -121,10 +146,10 @@ where
} }
} }
impl<T, F> ThreadPool<T, F> impl<I, T> ThreadPool<I, T>
where where
T: Send + 'static, I: Sendable,
F: Job<T>, T: Sendable,
{ {
/// Creates a new thread pool with default thread count determined by either /// Creates a new thread pool with default thread count determined by either
/// [`std::thread::available_parallelism`] or [`DEFAULT_THREAD_POOL_SIZE`] in case it fails. /// [`std::thread::available_parallelism`] or [`DEFAULT_THREAD_POOL_SIZE`] in case it fails.
@ -139,9 +164,13 @@ where
/// # Panic /// # Panic
/// This function will fails if `max_threads` is zero. /// This function will fails if `max_threads` is zero.
pub fn with_limit(max_threads: usize) -> Self { pub fn with_limit(max_threads: usize) -> Self {
let (sender, receiver) = channel::<T>();
Self { Self {
limit: NonZeroUsize::new(max_threads).expect("Thread limit must be non-zero"), limit: NonZeroUsize::new(max_threads).expect("Thread limit must be non-zero"),
..Default::default() queue: Arc::new(Mutex::new(VecDeque::new())),
handles: Vec::new(),
sender,
receiver,
} }
} }
@ -151,7 +180,7 @@ where
/// This function will create a new thread if the maximum number of threads in not reached. /// This function will create a new thread if the maximum number of threads in not reached.
/// In case the maximum number of threads is already used, the job is stalled and will get executed /// In case the maximum number of threads is already used, the job is stalled and will get executed
/// when a thread is ready and its at the start of the queue. /// when a thread is ready and its at the start of the queue.
pub fn enqueue_priorize(&mut self, func: F, priority: Priority) { pub fn enqueue_priorize(&mut self, func: Task<I, T>, priority: Priority) {
// put job into queue // put job into queue
let mut queue = self.queue.lock().unwrap(); let mut queue = self.queue.lock().unwrap();
@ -169,8 +198,14 @@ where
let queue = self.queue.clone(); let queue = self.queue.clone();
self.handles.push(thread::spawn(move || { self.handles.push(thread::spawn(move || {
while let Some(job) = queue.lock().unwrap().pop_front() { while let Some(task) = queue.lock().unwrap().pop_front() {
tx.send(job()).expect("cannot send result"); // basically try catch
if let Err(e) = std::panic::catch_unwind(|| {
tx.send((task.job)(task.param))
.expect("unable to send result over channel");
}) {
eprintln!("thread paniced: {:?}", e);
}
} }
})); }));
} }
@ -183,7 +218,7 @@ where
/// This function will create a new thread if the maximum number of threads in not reached. /// This function will create a new thread if the maximum number of threads in not reached.
/// In case the maximum number of threads is already used, the job is stalled and will get executed /// In case the maximum number of threads is already used, the job is stalled and will get executed
/// when a thread is ready and its at the start of the queue. /// when a thread is ready and its at the start of the queue.
pub fn enqueue(&mut self, func: F) { pub fn enqueue(&mut self, func: Task<I, T>) {
self.enqueue_priorize(func, Priority::Low); self.enqueue_priorize(func, Priority::Low);
} }
@ -196,15 +231,15 @@ where
} }
} }
/// Returns all results that have been returned by the threads until now /// Sendables all results that have been Sendableed by the threads until now
/// and haven't been consumed yet. /// and haven't been consumed yet.
/// All results retrieved from this call won't be returned on a second call. /// All results retrieved from this call won't be Sendableed on a second call.
/// This function is non blocking. /// This function is non blocking.
pub fn try_get_results(&mut self) -> Vec<T> { pub fn try_get_results(&mut self) -> Vec<T> {
self.receiver.try_iter().collect() self.receiver.try_iter().collect()
} }
/// Returns all results that have been returned by the threads until now /// Sendables all results that have been returned by the threads until now
/// and haven't been consumed yet. The function will also wait for all threads to finish executing (empty the queue). /// and haven't been consumed yet. The function will also wait for all threads to finish executing (empty the queue).
/// All results retrieved from this call won't be returned on a second call. /// All results retrieved from this call won't be returned on a second call.
/// This function will block the caller thread. /// This function will block the caller thread.
@ -214,8 +249,20 @@ where
} }
} }
impl<I, T> Drop for ThreadPool<I, T>
where
I: Sendable,
T: Sendable,
{
fn drop(&mut self) {
self.join_all();
}
}
#[cfg(test)] #[cfg(test)]
mod test { mod test {
use std::panic::UnwindSafe;
use super::*; use super::*;
#[test] #[test]
@ -223,7 +270,7 @@ mod test {
let mut pool = ThreadPool::default(); let mut pool = ThreadPool::default();
for i in 0..10 { for i in 0..10 {
pool.enqueue_priorize(move || i, Priority::High); pool.enqueue_priorize(Task::new(i, |i| i), Priority::High);
} }
pool.join_all(); pool.join_all();
@ -236,7 +283,7 @@ mod test {
let mut pool = ThreadPool::with_limit(2); let mut pool = ThreadPool::with_limit(2);
for i in 0..10 { for i in 0..10 {
pool.enqueue(move || i); pool.enqueue(Task::new(i, |i| i));
} }
assert_eq!(pool.handles.len(), 2); assert_eq!(pool.handles.len(), 2);
@ -247,19 +294,62 @@ mod test {
assert_eq!(pool.get_results().iter().sum::<i32>(), 45); assert_eq!(pool.get_results().iter().sum::<i32>(), 45);
} }
trait Object: Send + UnwindSafe {
fn get(&mut self) -> i32;
}
#[derive(Default)]
struct Test1 {
_int: i32,
}
impl Object for Test1 {
fn get(&mut self) -> i32 {
0
}
}
#[derive(Default)]
struct Test2 {
_c: char,
}
impl Object for Test2 {
fn get(&mut self) -> i32 {
0
}
}
#[derive(Default)]
struct Test3 {
_s: String,
}
impl Object for Test3 {
fn get(&mut self) -> i32 {
0
}
}
#[test] #[test]
fn test_multiple() { fn test_1() {
let mut pool = ThreadPool::with_limit(2); let mut pool = ThreadPool::with_limit(2);
for i in 0..10 { let feats: Vec<Box<dyn Object>> = vec![
pool.enqueue(move || i); Box::new(Test1::default()),
} Box::new(Test2::default()),
Box::new(Test3::default()),
];
assert_eq!(pool.handles.len(), 2); for feat in feats {
assert_eq!(pool.limit.get(), 2); pool.enqueue(Task::new(feat, |mut i| {
let _ = i.get();
i
}));
}
pool.join_all(); pool.join_all();
assert_eq!(pool.get_results().iter().sum::<i32>(), 45); let _feats: Vec<Box<dyn Object>> = pool.get_results();
} }
} }