implemented proper linear loop for dot product

This commit is contained in:
Sven Vogel 2023-05-03 08:46:13 +02:00
parent 9b7d91ad5b
commit 1f5a530c60
2 changed files with 45 additions and 17 deletions

View File

@ -11,3 +11,6 @@ futures = "0.3.28"
jemalloc-ctl = "0.5.0" jemalloc-ctl = "0.5.0"
jemallocator = "0.5.0" jemallocator = "0.5.0"
bytesize = "1.2.0" bytesize = "1.2.0"
[features]
binary_search = []

View File

@ -1,7 +1,7 @@
use std::time::Instant;
use bytesize::ByteSize; use bytesize::ByteSize;
use jemalloc_ctl::{epoch, stats};
use rand::Rng; use rand::Rng;
use jemalloc_ctl::{stats, epoch}; use std::time::Instant;
#[global_allocator] #[global_allocator]
static ALLOC: jemallocator::Jemalloc = jemallocator::Jemalloc; static ALLOC: jemallocator::Jemalloc = jemallocator::Jemalloc;
@ -13,13 +13,35 @@ pub struct SparseVec {
} }
impl SparseVec { impl SparseVec {
pub fn dot(&self, other: &SparseVec) -> f64 { pub fn dot(&self, other: &SparseVec) -> f64 {
let mut sum = 0.0; let mut sum = 0.0;
for index in 0..other.indices.len() { #[cfg(not(feature="binary_search"))]
// exponential search for an element in the second vector to have the same index {
sum += binary_search(self.indices[index], &other.indices, &other.values) * self.values[index]; let mut x = 0;
let mut y = 0;
while x < self.indices.len() && y < other.indices.len() {
if self.indices[x] == other.indices[y] {
sum += self.values[x] * other.values[y];
x += 1;
y += 1;
} else if self.indices[x] > other.indices[y] {
y += 1;
} else {
x += 1;
}
}
}
#[cfg(feature="binary_search")]
{
for index in 0..other.indices.len() {
// binary search for an element in the second vector to have the same index
sum += binary_search(self.indices[index], &other.indices, &other.values)
* self.values[index];
}
} }
sum sum
@ -36,14 +58,12 @@ impl SparseVec {
for i in 0..non_zero_elements { for i in 0..non_zero_elements {
values.push(0.5); values.push(0.5);
let idx = i as f32 / non_zero_elements as f32 * (elements as f32 - 4.0) + rng.gen_range(0.0..3.0); let idx = i as f32 / non_zero_elements as f32 * (elements as f32 - 4.0)
+ rng.gen_range(0.0..3.0);
indices.push(idx as usize); indices.push(idx as usize);
} }
Self { Self { values, indices }
values,
indices
}
} }
} }
@ -76,11 +96,10 @@ macro_rules! time {
let start = Instant::now(); let start = Instant::now();
$block; $block;
println!("{} took {}s", $name, start.elapsed().as_secs_f64()); println!("{} took {}s", $name, start.elapsed().as_secs_f64());
}} }};
} }
fn main() { fn main() {
/// Theoretical size of the vector in elements /// Theoretical size of the vector in elements
/// This would mean that we would require 10 GBs of memory to store a single vector /// This would mean that we would require 10 GBs of memory to store a single vector
const VECTOR_SIZE: usize = 10_000_000_000; const VECTOR_SIZE: usize = 10_000_000_000;
@ -92,7 +111,10 @@ fn main() {
let non_zero_elements = (VECTOR_SIZE as f64 * NULL_NON_NULL_RATIO) as usize; let non_zero_elements = (VECTOR_SIZE as f64 * NULL_NON_NULL_RATIO) as usize;
let heap_element_size = std::mem::size_of::<f64>() + std::mem::size_of::<usize>(); let heap_element_size = std::mem::size_of::<f64>() + std::mem::size_of::<usize>();
println!("Estimated size on heap: {}", ByteSize::b((non_zero_elements * heap_element_size) as u64)); println!(
"Estimated size on heap: {}",
ByteSize::b((non_zero_elements * heap_element_size) as u64)
);
println!("Size on stack: {} B", std::mem::size_of::<SparseVec>()); println!("Size on stack: {} B", std::mem::size_of::<SparseVec>());
let vec: SparseVec; let vec: SparseVec;
@ -104,7 +126,10 @@ fn main() {
// many statistics are cached and only updated when the epoch is advanced. // many statistics are cached and only updated when the epoch is advanced.
epoch::advance().unwrap(); epoch::advance().unwrap();
println!("Heap allocated bytes (total): {}", ByteSize::b(stats::allocated::read().unwrap() as u64)); println!(
"Heap allocated bytes (total): {}",
ByteSize::b(stats::allocated::read().unwrap() as u64)
);
time!("Sparse vector dot product", { time!("Sparse vector dot product", {
vec.dot(&vec); vec.dot(&vec);