diff --git a/src/multithreading/mod.rs b/src/multithreading/mod.rs index 647ea48..00ddc74 100644 --- a/src/multithreading/mod.rs +++ b/src/multithreading/mod.rs @@ -346,3 +346,54 @@ fn execute( result })); } + +#[cfg(test)] +mod tests { + use std::time::Duration; + + use super::*; + + #[test] + fn test_thread_pool() { + // auto determine the amount of threads to use + let mut pool = ThreadPool::new(); + + // launch 4 jobs to run on our pool + for i in 0..4 { + pool.enqueue(move || (0..=i).sum::()); + } + + // wait for the threads to finish and sum their results + let sum = pool + .join_all() + .into_iter() + .map(|r| r.unwrap()) + .sum::(); + + assert_eq!(sum, 10); + } + + #[test] + fn test_drop_stalled() { + // auto determine the amount of threads to use + let mut pool = ThreadPool::with_threads(NonZeroUsize::new(1).unwrap()); + + // launch 2 jobs: 1 will immediately return, the other one will sleep for 20 seconds + for i in 0..1 { + pool.enqueue(move || { + thread::sleep(Duration::from_secs(i * 20)); + i + }); + } + + // wait 10 secs + thread::sleep(Duration::from_secs(2)); + // discard job that should still run + pool.discard_stalled(); + + // wait for the threads to finish and sum their results + let sum = pool.join_all().into_iter().map(|r| r.unwrap()).sum::(); + + assert_eq!(sum, 0); + } +}