Skip to content

Commit

Permalink
feat: use rayon for parallel computing (#10)
Browse files Browse the repository at this point in the history
* feat: use rayon for parallel computing

Signed-off-by: Keming <[email protected]>

* bump version

Signed-off-by: Keming <[email protected]>

---------

Signed-off-by: Keming <[email protected]>
  • Loading branch information
kemingy authored Dec 4, 2024
1 parent d56c1db commit 5b58f37
Show file tree
Hide file tree
Showing 6 changed files with 93 additions and 22 deletions.
3 changes: 2 additions & 1 deletion Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

3 changes: 2 additions & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[package]
name = "gathers"
version = "0.2.0"
version = "0.3.0"
edition = "2021"
authors = ["Keming <[email protected]>"]
license = "Apache-2.0"
Expand All @@ -19,6 +19,7 @@ log = "0.4.22"
num-traits = "0.2.19"
rand = "0.8.5"
rand_distr = "0.4.3"
rayon = "1.10.0"

[profile.dev.package.faer]
opt-level = 3
Expand Down
6 changes: 3 additions & 3 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,8 @@
Clustering algorithm implementation:

- [x] K-means
- [ ] Parallel
- [ ] Hierarchical K-means
- [x] PyO3 binding
- [x] RaBitQ assignment
- [x] Parallel with Rayon
- [ ] Mini-batch K-means
- [x] PyO3 binding
- [ ] Hierarchical K-means
56 changes: 54 additions & 2 deletions python/Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion python/Cargo.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[package]
name = "gatherspy"
version = "0.2.0"
version = "0.3.0"
edition = "2021"

# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
Expand Down
45 changes: 31 additions & 14 deletions src/kmeans.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ use std::time::Instant;

use log::debug;
use rand::Rng;
use rayon::prelude::*;

use crate::distance::{argmin, neg_dot_product, squared_euclidean, Distance};
use crate::rabitq::RaBitQ;
Expand All @@ -15,6 +16,7 @@ const EPS: f32 = 1.0 / 1024.0;
const MIN_POINTS_PER_CENTROID: usize = 39;
const MAX_POINTS_PER_CENTROID: usize = 256;
const LARGE_CLUSTER_THRESHOLD: usize = 1 << 20;
const RAYON_BLOCK_SIZE: usize = 1024 * 32;

/// Assign vectors to centroids.
pub fn assign(vecs: &[f32], centroids: &[f32], dim: usize, distance: Distance, labels: &mut [u32]) {
Expand All @@ -36,14 +38,22 @@ pub fn assign(vecs: &[f32], centroids: &[f32], dim: usize, distance: Distance, l
// let squared_x: Vec<f32> = vecs.chunks(dim).map(l2_norm).collect();
// let squared_y: Vec<f32> = centroids.chunks(dim).map(l2_norm).collect();

for (i, vec) in vecs.chunks(dim).enumerate() {
for (j, centroid) in centroids.chunks(dim).enumerate() {
distances[j] =
// squared_x[i] + squared_y[j] + 2.0 * neg_dot_product(vec, centroid);
squared_euclidean(vec, centroid);
}
labels[i] = argmin(&distances) as u32;
}
labels.copy_from_slice(
&vecs
.par_chunks(dim * RAYON_BLOCK_SIZE)
.flat_map(|vec| {
let mut par_labels = vec![0; vec.len() / dim];
let mut par_distances = vec![f32::MAX; centroids.len() / dim];
for (i, v) in vec.chunks(dim).enumerate() {
for (j, centroid) in centroids.chunks(dim).enumerate() {
par_distances[j] = squared_euclidean(v, centroid);
}
par_labels[i] = argmin(&par_distances) as u32;
}
par_labels
})
.collect::<Vec<_>>(),
);
}
}
}
Expand All @@ -53,9 +63,18 @@ pub fn assign(vecs: &[f32], centroids: &[f32], dim: usize, distance: Distance, l
/// TODO: support dot product distance
pub fn rabitq_assign(vecs: &[f32], centroids: &[f32], dim: usize, labels: &mut [u32]) {
let rabitq = RaBitQ::new(centroids, dim);
for (i, vec) in vecs.chunks(dim).enumerate() {
labels[i] = rabitq.retrieve_top_one(vec) as u32;
}

labels.copy_from_slice(
&vecs
.par_chunks(dim * RAYON_BLOCK_SIZE)
.flat_map(|vec| {
vec.chunks(dim)
.map(|v| rabitq.retrieve_top_one(v) as u32)
.collect::<Vec<_>>()
})
.collect::<Vec<_>>(),
);

let (rough, precise) = rabitq.get_metrics();
debug!(
"RaBitQ: rough {}, precise {}, ratio: {}",
Expand Down Expand Up @@ -197,9 +216,7 @@ impl KMeans {
if num > MAX_POINTS_PER_CENTROID * self.n_cluster as usize {
let n_sample = MAX_POINTS_PER_CENTROID * self.n_cluster as usize;
debug!("subsample to {} points", n_sample);
let subsampled = as_continuous_vec(&subsample(n_sample, &vecs, dim));
vecs.shrink_to(subsampled.len());
vecs.copy_from_slice(&subsampled);
vecs = as_continuous_vec(&subsample(n_sample, &vecs, dim));
}

let mut centroids = as_continuous_vec(&subsample(self.n_cluster as usize, &vecs, dim));
Expand Down

0 comments on commit 5b58f37

Please sign in to comment.