Skip to content

Commit

Permalink
docs: update rs docs, bump version (#11)
Browse files Browse the repository at this point in the history
Signed-off-by: Keming <[email protected]>
  • Loading branch information
kemingy authored Dec 4, 2024
1 parent 5b58f37 commit c599799
Show file tree
Hide file tree
Showing 8 changed files with 69 additions and 15 deletions.
2 changes: 1 addition & 1 deletion Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

4 changes: 2 additions & 2 deletions Cargo.toml
Original file line number Diff line number Diff line change
@@ -1,14 +1,14 @@
[package]
name = "gathers"
version = "0.3.0"
version = "0.3.1"
edition = "2021"
authors = ["Keming <[email protected]>"]
license = "Apache-2.0"
readme = "README.md"
repository = "https://github.com/kemingy/gathers"
description = "Clustering algorithms."
documentation = "https://docs.rs/gathers"
keywords = ["clustering"]
keywords = ["cluster", "kmeans", "rabitq", "machine-learning", "vector-search"]
categories = ["algorithms", "science"]

[dependencies]
Expand Down
16 changes: 15 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,25 @@
[![crates.io](https://img.shields.io/crates/v/gathers.svg)](https://crates.io/crates/gathers)
[![docs.rs](https://docs.rs/gathers/badge.svg)](https://docs.rs/gathers)

Clustering algorithm implementation:
Clustering algorithm implementation in Rust and binding to Python.

For Python users, check the [Python README](./python/README.md).

- [x] K-means
- [x] PyO3 binding
- [x] RaBitQ assignment
- [x] Parallel with Rayon
- [x] `x86` & `x86_64` SIMD acceleration
- [ ] mini batch K-means
- [ ] Hierarchical K-means
- [ ] `arm` & `aarch64` SIMD acceleration

## Installation

```sh
cargo add gathers
```

## Usage

Check the [docs](https://docs.rs/gathers) and [main.rs](./src/main.rs).
2 changes: 1 addition & 1 deletion python/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ import numpy as np

gathers = Gathers(verbose=True)
rng = np.random.default_rng()
data = rng.random((1000, 64), dtype=np.float32)
data = rng.random((1000, 64), dtype=np.float32) # only support float32
centroids = gathers.fit(data, 10)
labels = gathers.batch_assign(data, centroids)
print(labels)
Expand Down
30 changes: 24 additions & 6 deletions src/kmeans.rs
Original file line number Diff line number Diff line change
Expand Up @@ -154,6 +154,7 @@ pub struct KMeans {
tolerance: f32,
distance: Distance,
use_residual: bool,
use_default_config: bool,
}

impl Default for KMeans {
Expand All @@ -164,12 +165,21 @@ impl Default for KMeans {
tolerance: 1e-4,
distance: Distance::default(),
use_residual: false,
use_default_config: true,
}
}
}

impl KMeans {
/// Create a new KMeans instance.
///
/// # Arguments
///
/// * `n_cluster` - number of clusters, recommend to be a number in [sqrt(n) * 4, sqrt(n) * 8]
/// * `max_iter` - max number of iterations
/// * `tolerance` - convergence tolerance, stop when the diff is less than this value
/// * `distance` - distance metric
/// * `use_residual` - use residual for more accurate L2 distance computations, only work for L2
pub fn new(
n_cluster: u32,
max_iter: u32,
Expand All @@ -192,17 +202,25 @@ impl KMeans {
tolerance,
distance,
use_residual,
use_default_config: false,
}
}

/// Fit the KMeans configurations to the given vectors and return the centroids.
pub fn fit(&self, mut vecs: Vec<f32>, dim: usize) -> Vec<f32> {
let num = vecs.len() / dim;
debug!("num of points: {}", num);
if num < self.n_cluster as usize {

// auto-config the `n_cluster` if it's initialized with `default()`
let n_cluster = match self.use_default_config {
true => (((num as f32).sqrt() as u32) * 4).min((num / MIN_POINTS_PER_CENTROID) as u32),
false => self.n_cluster,
};
debug!("num of points: {}, num of clusters: {}", num, n_cluster);

if num < n_cluster as usize {
panic!("number of samples must be greater than n_cluster");
}
if num < self.n_cluster as usize * MIN_POINTS_PER_CENTROID {
if num < n_cluster as usize * MIN_POINTS_PER_CENTROID {
panic!("too few samples for n_cluster");
}

Expand All @@ -213,13 +231,13 @@ impl KMeans {
}

// subsample
if num > MAX_POINTS_PER_CENTROID * self.n_cluster as usize {
let n_sample = MAX_POINTS_PER_CENTROID * self.n_cluster as usize;
if num > MAX_POINTS_PER_CENTROID * n_cluster as usize {
let n_sample = MAX_POINTS_PER_CENTROID * n_cluster as usize;
debug!("subsample to {} points", n_sample);
vecs = as_continuous_vec(&subsample(n_sample, &vecs, dim));
}

let mut centroids = as_continuous_vec(&subsample(self.n_cluster as usize, &vecs, dim));
let mut centroids = as_continuous_vec(&subsample(n_cluster as usize, &vecs, dim));
if self.distance == Distance::NegativeDotProduct {
centroids.chunks_mut(dim).for_each(normalize);
}
Expand Down
21 changes: 21 additions & 0 deletions src/lib.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,25 @@
//! Clustering algorithms for Rust.
//!
//! ## Examples
//!
//! ```
//! use gathers::kmeans::{KMeans, rabitq_assign};
//! use gathers::utils::as_continuous_vec;
//! # use rand::Rng;
//! # let mut rng = rand::thread_rng();
//! # let vecs = (0..1000).map(|_| (0..32).map(|_| rng.gen::<f32>()).collect::<Vec<f32>>()).collect::<Vec<Vec<f32>>>();
//!
//!
//! let kmeans = KMeans::default();
//! let num = vecs.len();
//! let dim = vecs[0].len();
//!
//! // fit
//! let centroids = kmeans.fit(as_continuous_vec(&vecs), dim);
//! // predict
//! let mut labels = vec![0; num];
//! rabitq_assign(&as_continuous_vec(&vecs), &centroids, dim, &mut labels);
//! ```
#![deny(missing_docs)]

Expand Down
5 changes: 3 additions & 2 deletions src/simd.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,8 @@
use crate::rabitq::THETA_LOG_DIM;

/// Compute the squared Euclidean distance between two vectors.
/// Code refer to https://github.com/nmslib/hnswlib/blob/master/hnswlib/space_l2.h
///
/// Code refer to <https://github.com/nmslib/hnswlib/blob/master/hnswlib/space_l2.h>
///
/// # Safety
///
Expand Down Expand Up @@ -425,7 +426,7 @@ pub unsafe fn vector_binarize_query(vec: &[u8], binary: &mut [u64]) {

/// Compute the binary dot product of two vectors.
///
/// Refer to: https://github.com/komrad36/popcount
/// Refer to: <https://github.com/komrad36/popcount>
///
/// # Safety
///
Expand Down
4 changes: 2 additions & 2 deletions src/utils.rs
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ where
}
}

/// Convert a 2-D Vec<Vec<T>> to a 1-D continuous vector.
/// Convert a 2-D `Vec<Vec<T>>` to a 1-D continuous vector.
#[inline]
pub fn as_continuous_vec<T>(mat: &[Vec<T>]) -> Vec<T>
where
Expand All @@ -37,7 +37,7 @@ where
mat.iter().flat_map(|v| v.iter().cloned()).collect()
}

/// Convert a 1-D continuous vector to a 2-D Vec<Vec<T>>.
/// Convert a 1-D continuous vector to a 2-D `Vec<Vec<T>>`.
#[inline]
pub fn as_matrix<T>(vecs: &[T], dim: usize) -> Vec<Vec<T>>
where
Expand Down

0 comments on commit c599799

Please sign in to comment.