finished documetation for thread pool

This commit is contained in:
Sven Vogel 2023-06-06 17:56:34 +02:00
parent 9be7bc18c7
commit 7a6dc389b9
2 changed files with 194 additions and 19 deletions

View File

@ -4,7 +4,7 @@
//! Each thread will calculate a partial dot product of two different vectors composed of 1,000,000 64-bit
//! double precision floating point values.
use std::{sync::Arc};
use std::sync::Arc;
use criterion::{black_box, criterion_group, criterion_main, BenchmarkId, Criterion, Throughput};
use imsearch::multithreading::ThreadPool;
@ -37,8 +37,7 @@ fn dot(a: &[f64], b: &[f64]) -> f64 {
/// sized slices which then get passed ot their own thread to compute the partial dot product. After all threads have
/// finished the partial dot products will be summed to create the final result.
fn dot_parallel(a: Arc<Vec<f64>>, b: Arc<Vec<f64>>, threads: usize) {
let mut pool =
ThreadPool::with_limit(threads);
let mut pool = ThreadPool::with_limit(threads);
// number of elements in each vector for each thread
let steps = a.len() / threads;
@ -56,7 +55,7 @@ fn dot_parallel(a: Arc<Vec<f64>>, b: Arc<Vec<f64>>, threads: usize) {
dot(a, b)
});
}
pool.join_all();
black_box(pool.get_results().iter().sum::<f64>());

View File

@ -1,35 +1,114 @@
use std::{thread::{JoinHandle, self}, sync::{mpsc::{Receiver, channel, Sender}, Mutex, Arc}, num::NonZeroUsize, collections::VecDeque};
//! This module provides the functionality to create thread pool to execute tasks in parallel.
//! The amount of threads to be used at maximum can be regulated by using `ThreadPool::with_limit`.
//! This implementation is aimed to be of low runtime cost with minimal sychronisation due to blocking.
//! Note that no threads will be spawned until jobs are supplied to be executed. For every supplied job
//! a new thread will be launched until the maximum number is reached. By then every launched thread will
//! be reused to process the remaining elements of the queue. If no jobs are left to be executed
//! all threads will finish and die. This means that if nothing is done, no threads will run in idle in the background.
//! # Example
//! ```rust
//! # use imsearch::multithreading::ThreadPool;
//! let mut pool = ThreadPool::with_limit(2);
//!
//! for i in 0..10 {
//! pool.enqueue(move || i);
//! }
//!
//! pool.join_all();
//! assert_eq!(pool.get_results().iter().sum::<i32>(), 45);
//! ```
const DEFAULT_THREAD_POOL_SIZE: usize = 1;
use std::{
collections::VecDeque,
num::NonZeroUsize,
sync::{
mpsc::{channel, Receiver, Sender},
Arc, Mutex,
},
thread::{self, JoinHandle},
};
/// Default number if threads to be used in case [`std::thread::available_parallelism`] fails.
pub const DEFAULT_THREAD_POOL_SIZE: usize = 1;
/// Indicates the priority level of functions or closures which get supplied to the pool.
/// Use [`Priority::High`] to ensure the closue to be executed before all closures that are already supplied
/// Use [`Priority::Low`] to ensure the closue to be executed after all closures that are already supplied
#[derive(Debug, Copy, Clone, Hash, PartialEq, Eq, PartialOrd, Ord)]
pub enum Priority {
/// Indicate that the closure or function supplied to the thread
/// has higher priority than any other given to the pool until now.
/// The item will get enqueued at the start of the waiting-queue.
High,
/// Indicate that the closure or function supplied to the thread pool
/// has lower priority than the already supplied ones in this pool.
/// The item will get enqueued at the end of the waiting-queue.
Low,
}
/// Jobs are functions which are executed by the thread pool. They can be stalled when no threads are
/// free to execute them directly. They are meant to be executed only once and be done.
pub trait Job<T>: Send + 'static + FnOnce() -> T
where T: Send
where
T: Send,
{
}
impl<U, T> Job<T> for U
where U: Send + 'static + FnOnce() -> T, T: Send + 'static
where
U: Send + 'static + FnOnce() -> T,
T: Send + 'static,
{
}
/// Thread pool which can be used to execute functions or closures in parallel.
/// The amount of threads to be used at maximum can be regulated by using `ThreadPool::with_limit`.
/// This implementation is aimed to be of low runtime cost with minimal sychronisation due to blocking.
/// Note that no threads will be spawned until jobs are supplied to be executed. For every supplied job
/// a new thread will be launched until the maximum number is reached. By then every launched thread will
/// be reused to process the remaining elements of the queue. If no jobs are left to be executed
/// all threads will finish and die. This means that if nothing is done, no threads will run in idle in the background.
/// # Example
/// ```rust
/// # use imsearch::multithreading::ThreadPool;
/// let mut pool = ThreadPool::with_limit(2);
///
/// for i in 0..10 {
/// pool.enqueue(move || i);
/// }
///
/// pool.join_all();
/// assert_eq!(pool.get_results().iter().sum::<i32>(), 45);
/// ```
#[derive(Debug)]
pub struct ThreadPool<T, F>
where T: Send, F: Job<T>
pub struct ThreadPool<T, F>
where
T: Send,
F: Job<T>,
{
/// queue for storing the jobs to be executed
queue: Arc<Mutex<VecDeque<F>>>,
/// handles for all threads currently running and processing jobs
handles: Vec<JoinHandle<()>>,
/// reciver end for channel based communication between threads
receiver: Receiver<T>,
/// sender end for channel based communication between threads
sender: Sender<T>,
/// maximum amount of threads to be used in parallel
limit: NonZeroUsize,
}
impl<T, F> Default for ThreadPool<T, F>
where T: Send + 'static, F: Job<T>
where
T: Send + 'static,
F: Job<T>,
{
fn default() -> Self {
let (sender, receiver) = channel::<T>();
let default = NonZeroUsize::new(DEFAULT_THREAD_POOL_SIZE).expect("Thread limit must be non-zero");
// determine default thread count to use based on the system
let default =
NonZeroUsize::new(DEFAULT_THREAD_POOL_SIZE).expect("Thread limit must be non-zero");
let limit = thread::available_parallelism().unwrap_or(default);
Self {
@ -43,12 +122,22 @@ where T: Send + 'static, F: Job<T>
}
impl<T, F> ThreadPool<T, F>
where T: Send + 'static, F: Job<T>
where
T: Send + 'static,
F: Job<T>,
{
/// Creates a new thread pool with default thread count determined by either
/// [`std::thread::available_parallelism`] or [`DEFAULT_THREAD_POOL_SIZE`] in case it fails.
/// No threads will be lauched until jobs are enqueued.
pub fn new() -> Self {
Default::default()
}
/// Creates a new thread pool with the given thread count. The pool will continue to launch new threads even if
/// the system does not allow for that count of parallelism.
/// No threads will be lauched until jobs are enqueued.
/// # Panic
/// This function will fails if `max_threads` is zero.
pub fn with_limit(max_threads: usize) -> Self {
Self {
limit: NonZeroUsize::new(max_threads).expect("Thread limit must be non-zero"),
@ -56,7 +145,22 @@ where T: Send + 'static, F: Job<T>
}
}
pub fn enqueue(&mut self, func: F) {
/// Put a new job into the queue to be executed by a thread in the future.
/// The priority of the job will determine if the job will be put at the start or end of the queue.
/// See [`crate::multithreading::Priority`].
/// This function will create a new thread if the maximum number of threads in not reached.
/// In case the maximum number of threads is already used, the job is stalled and will get executed
/// when a thread is ready and its at the start of the queue.
pub fn enqueue_priorize(&mut self, func: F, priority: Priority) {
// put job into queue
let mut queue = self.queue.lock().unwrap();
// insert new job into queue depending on its priority
match priority {
Priority::High => queue.push_front(func),
Priority::Low => queue.push_back(func),
}
if self.handles.len() < self.limit.get() {
// we can still launch threads to run in parallel
@ -69,21 +173,93 @@ where T: Send + 'static, F: Job<T>
tx.send(job()).expect("cannot send result");
}
}));
} else {
self.queue.lock().unwrap().push_back(func);
}
self.handles.retain(|h| !h.is_finished());
}
/// Put a new job into the queue to be executed by a thread in the future.
/// The priority of the job is automatically set to [`crate::multithreading::Priority::Low`].
/// This function will create a new thread if the maximum number of threads in not reached.
/// In case the maximum number of threads is already used, the job is stalled and will get executed
/// when a thread is ready and its at the start of the queue.
pub fn enqueue(&mut self, func: F) {
self.enqueue_priorize(func, Priority::Low);
}
/// Wait for all threads to finish executing. This means that by the time all threads have finished
/// every task will have been executed too. In other words the threads finsish when the queue of jobs is empty.
/// This function will block the caller thread.
pub fn join_all(&mut self) {
while let Some(handle) = self.handles.pop() {
handle.join().unwrap();
}
}
pub fn get_results(&mut self) -> Vec<T> {
/// Returns all results that have been returned by the threads until now
/// and haven't been consumed yet.
/// All results retrieved from this call won't be returned on a second call.
/// This function is non blocking.
pub fn try_get_results(&mut self) -> Vec<T> {
self.receiver.try_iter().collect()
}
}
/// Returns all results that have been returned by the threads until now
/// and haven't been consumed yet. The function will also wait for all threads to finish executing (empty the queue).
/// All results retrieved from this call won't be returned on a second call.
/// This function will block the caller thread.
pub fn get_results(&mut self) -> Vec<T> {
self.join_all();
self.try_get_results()
}
}
#[cfg(test)]
mod test {
use super::*;
#[test]
fn test_default() {
let mut pool = ThreadPool::default();
for i in 0..10 {
pool.enqueue_priorize(move || i, Priority::High);
}
pool.join_all();
assert_eq!(pool.try_get_results().iter().sum::<i32>(), 45);
}
#[test]
fn test_limit() {
let mut pool = ThreadPool::with_limit(2);
for i in 0..10 {
pool.enqueue(move || i);
}
assert_eq!(pool.handles.len(), 2);
assert_eq!(pool.limit.get(), 2);
pool.join_all();
assert_eq!(pool.get_results().iter().sum::<i32>(), 45);
}
#[test]
fn test_multiple() {
let mut pool = ThreadPool::with_limit(2);
for i in 0..10 {
pool.enqueue(move || i);
}
assert_eq!(pool.handles.len(), 2);
assert_eq!(pool.limit.get(), 2);
pool.join_all();
assert_eq!(pool.get_results().iter().sum::<i32>(), 45);
}
}