diff --git a/Cargo.lock b/Cargo.lock index dc7f2d1..0479eff 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -427,7 +427,7 @@ dependencies = [ [[package]] name = "gathers" -version = "0.2.0" +version = "0.3.0" dependencies = [ "argh", "criterion", @@ -437,6 +437,7 @@ dependencies = [ "num-traits", "rand", "rand_distr", + "rayon", ] [[package]] diff --git a/Cargo.toml b/Cargo.toml index 778b5ab..f114952 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "gathers" -version = "0.2.0" +version = "0.3.0" edition = "2021" authors = ["Keming "] license = "Apache-2.0" @@ -19,6 +19,7 @@ log = "0.4.22" num-traits = "0.2.19" rand = "0.8.5" rand_distr = "0.4.3" +rayon = "1.10.0" [profile.dev.package.faer] opt-level = 3 diff --git a/README.md b/README.md index a03e417..8c02017 100644 --- a/README.md +++ b/README.md @@ -7,8 +7,8 @@ Clustering algorithm implementation: - [x] K-means -- [ ] Parallel -- [ ] Hierarchical K-means +- [x] PyO3 binding - [x] RaBitQ assignment +- [x] Parallel with Rayon - [ ] mini batch K-means -- [x] PyO3 binding +- [ ] Hierarchical K-means diff --git a/python/Cargo.lock b/python/Cargo.lock index 97c3a01..4054c7e 100644 --- a/python/Cargo.lock +++ b/python/Cargo.lock @@ -153,6 +153,31 @@ version = "1.0.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "5b63caa9aa9397e2d9480a9b13673856c78d8ac123288526c37d7839f2a86990" +[[package]] +name = "crossbeam-deque" +version = "0.8.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "613f8cc01fe9cf1a3eb3d7f488fd2fa8388403e97039e2f73692932e291a770d" +dependencies = [ + "crossbeam-epoch", + "crossbeam-utils", +] + +[[package]] +name = "crossbeam-epoch" +version = "0.9.18" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5b82ac4a3c2ca9c3460964f020e1402edd5753411d7737aa39c3714ad1b5420e" +dependencies = [ + "crossbeam-utils", +] + +[[package]] +name = "crossbeam-utils" +version = "0.8.20" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "22ec99545bb0ed0ea7bb9b8e1e9122ea386ff8a48c0922e43f36d45ab09e0e80" + [[package]] name = "crunchy" version = "0.2.2" @@ -175,6 +200,12 @@ dependencies = [ "reborrow", ] +[[package]] +name = "either" +version = "1.13.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "60b1af1c220855b6ceac025d3f6ecdd2b7c4894bfe9cd9bda4fbb4bc7c0d4cf0" + [[package]] name = "enum-as-inner" version = "0.6.1" @@ -290,7 +321,7 @@ dependencies = [ [[package]] name = "gathers" -version = "0.2.0" +version = "0.3.0" dependencies = [ "argh", "env_logger", @@ -299,11 +330,12 @@ dependencies = [ "num-traits", "rand", "rand_distr", + "rayon", ] [[package]] name = "gatherspy" -version = "0.2.0" +version = "0.3.0" dependencies = [ "gathers", "numpy", @@ -845,6 +877,26 @@ version = "0.2.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "60a357793950651c4ed0f3f52338f53b2f809f32d83a07f72909fa13e4c6c1e3" +[[package]] +name = "rayon" +version = "1.10.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b418a60154510ca1a002a752ca9714984e21e4241e804d32555251faf8b78ffa" +dependencies = [ + "either", + "rayon-core", +] + +[[package]] +name = "rayon-core" +version = "1.12.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1465873a3dfdaa8ae7cb14b4383657caab0b3e8a0aa9ae8e04b044854c8dfce2" +dependencies = [ + "crossbeam-deque", + "crossbeam-utils", +] + [[package]] name = "reborrow" version = "0.5.5" diff --git a/python/Cargo.toml b/python/Cargo.toml index af85f77..896fc0a 100644 --- a/python/Cargo.toml +++ b/python/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "gatherspy" -version = "0.2.0" +version = "0.3.0" edition = "2021" # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html diff --git a/src/kmeans.rs b/src/kmeans.rs index d49509e..61e2d85 100644 --- a/src/kmeans.rs +++ b/src/kmeans.rs @@ -5,6 +5,7 @@ use std::time::Instant; use log::debug; use rand::Rng; +use rayon::prelude::*; use crate::distance::{argmin, neg_dot_product, squared_euclidean, Distance}; use crate::rabitq::RaBitQ; @@ -15,6 +16,7 @@ const EPS: f32 = 1.0 / 1024.0; const MIN_POINTS_PER_CENTROID: usize = 39; const MAX_POINTS_PER_CENTROID: usize = 256; const LARGE_CLUSTER_THRESHOLD: usize = 1 << 20; +const RAYON_BLOCK_SIZE: usize = 1024 * 32; /// Assign vectors to centroids. pub fn assign(vecs: &[f32], centroids: &[f32], dim: usize, distance: Distance, labels: &mut [u32]) { @@ -36,14 +38,22 @@ pub fn assign(vecs: &[f32], centroids: &[f32], dim: usize, distance: Distance, l // let squared_x: Vec = vecs.chunks(dim).map(l2_norm).collect(); // let squared_y: Vec = centroids.chunks(dim).map(l2_norm).collect(); - for (i, vec) in vecs.chunks(dim).enumerate() { - for (j, centroid) in centroids.chunks(dim).enumerate() { - distances[j] = - // squared_x[i] + squared_y[j] + 2.0 * neg_dot_product(vec, centroid); - squared_euclidean(vec, centroid); - } - labels[i] = argmin(&distances) as u32; - } + labels.copy_from_slice( + &vecs + .par_chunks(dim * RAYON_BLOCK_SIZE) + .flat_map(|vec| { + let mut par_labels = vec![0; vec.len() / dim]; + let mut par_distances = vec![f32::MAX; centroids.len() / dim]; + for (i, v) in vec.chunks(dim).enumerate() { + for (j, centroid) in centroids.chunks(dim).enumerate() { + par_distances[j] = squared_euclidean(v, centroid); + } + par_labels[i] = argmin(&par_distances) as u32; + } + par_labels + }) + .collect::>(), + ); } } } @@ -53,9 +63,18 @@ pub fn assign(vecs: &[f32], centroids: &[f32], dim: usize, distance: Distance, l /// TODO: support dot product distance pub fn rabitq_assign(vecs: &[f32], centroids: &[f32], dim: usize, labels: &mut [u32]) { let rabitq = RaBitQ::new(centroids, dim); - for (i, vec) in vecs.chunks(dim).enumerate() { - labels[i] = rabitq.retrieve_top_one(vec) as u32; - } + + labels.copy_from_slice( + &vecs + .par_chunks(dim * RAYON_BLOCK_SIZE) + .flat_map(|vec| { + vec.chunks(dim) + .map(|v| rabitq.retrieve_top_one(v) as u32) + .collect::>() + }) + .collect::>(), + ); + let (rough, precise) = rabitq.get_metrics(); debug!( "RaBitQ: rough {}, precise {}, ratio: {}", @@ -197,9 +216,7 @@ impl KMeans { if num > MAX_POINTS_PER_CENTROID * self.n_cluster as usize { let n_sample = MAX_POINTS_PER_CENTROID * self.n_cluster as usize; debug!("subsample to {} points", n_sample); - let subsampled = as_continuous_vec(&subsample(n_sample, &vecs, dim)); - vecs.shrink_to(subsampled.len()); - vecs.copy_from_slice(&subsampled); + vecs = as_continuous_vec(&subsample(n_sample, &vecs, dim)); } let mut centroids = as_continuous_vec(&subsample(self.n_cluster as usize, &vecs, dim));