Skip to content

Commit

Permalink
feat: Database specified by model
Browse files Browse the repository at this point in the history
  • Loading branch information
emmyoh committed Dec 12, 2024
1 parent 084538e commit 5321e5e
Show file tree
Hide file tree
Showing 19 changed files with 372 additions and 299 deletions.
3 changes: 2 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -105,4 +105,5 @@ texts
image.db
images
audio.db
audio
audio
*_old.rs
3 changes: 2 additions & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -51,9 +51,10 @@ symphonia = "0.5.4"

[features]
default = []
default_db = []
accelerate = ["candle-core/accelerate", "candle-examples/accelerate", "candle-nn/accelerate", "candle-transformers/accelerate"]
cuda = ["candle-core/cuda", "candle-examples/cuda", "candle-nn/cuda", "candle-transformers/cuda"]
mkl = ["candle-core/mkl", "candle-examples/mkl", "candle-nn/mkl", "candle-transformers/mkl"]
metal = ["candle-core/metal", "candle-examples/metal", "candle-nn/metal", "candle-transformers/metal"]
sixel = ["viuer/sixel"]
cli = ["dep:clap", "dep:ticky", "dep:pretty-duration", "dep:indicatif", "dep:viuer", "dep:rodio"]
cli = ["default_db", "dep:clap", "dep:ticky", "dep:pretty-duration", "dep:indicatif", "dep:viuer", "dep:rodio"]
23 changes: 0 additions & 23 deletions src/audio.rs

This file was deleted.

39 changes: 19 additions & 20 deletions src/db.rs → src/database/core.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
use crate::distance::DistanceUnit;
use crate::model::DatabaseEmbeddingModel;
use crate::model::core::DatabaseEmbeddingModel;
use bytes::Bytes;
use fastembed::Embedding;
use hnsw::Params;
Expand Down Expand Up @@ -67,24 +67,28 @@ impl DocumentType {
/// * `M0` - The number of bi-directional links created for each node in the HNSW graph in the first layer. Cannot be changed after database creation.
pub struct Database<
Met: Metric<Embedding, Unit = DistanceUnit> + Serialize,
Model: DatabaseEmbeddingModel + Serialize,
const EF_CONSTRUCTION: usize,
const M: usize,
const M0: usize,
> {
/// The Hierarchical Navigable Small World (HNSW) graph containing the embeddings.
pub hnsw: Hnsw<Met, Embedding, Pcg64, M, M0>,
/// The type of documents stored in the database.
pub document_type: DocumentType,
// pub document_type: DocumentType,
pub model: Model,
}

impl<
Met: Metric<Embedding, Unit = DistanceUnit> + Serialize,
Model: DatabaseEmbeddingModel + Serialize,
const EF_CONSTRUCTION: usize,
const M: usize,
const M0: usize,
> Database<Met, EF_CONSTRUCTION, M, M0>
> Database<Met, Model, EF_CONSTRUCTION, M, M0>
where
for<'de> Met: Deserialize<'de>,
for<'de> Model: Deserialize<'de>,
{
/// Load the database from disk, or create it if it does not already exist.
///
Expand All @@ -97,10 +101,8 @@ where
/// # Returns
///
/// A database containing a HNSW graph and the inserted documents.
pub fn create_or_load_database(
metric: Met,
document_type: DocumentType,
) -> Result<Self, Box<dyn Error>> {
pub fn create_or_load_database(metric: Met, model: Model) -> Result<Self, Box<dyn Error>> {
let document_type = model.document_type();
let db_bytes = fs::read(document_type.database_name());
match db_bytes {
Ok(bytes) => {
Expand All @@ -109,10 +111,7 @@ where
}
Err(_) => {
let hnsw = Hnsw::new_params(metric, Params::new().ef_construction(EF_CONSTRUCTION));
let db = Database {
hnsw,
document_type,
};
let db = Database { hnsw, model };
let db_bytes = bincode::serialize(&db)?;
fs::write(document_type.database_name(), db_bytes)?;
Ok(db)
Expand All @@ -123,7 +122,7 @@ where
/// Save the database to disk.
pub fn save_database(&self) -> Result<(), Box<dyn Error>> {
let db_bytes = bincode::serialize(&self)?;
fs::write(self.document_type.database_name(), db_bytes)?;
fs::write(self.model.document_type().database_name(), db_bytes)?;
Ok(())
}

Expand All @@ -138,12 +137,11 @@ where
/// # Returns
///
/// A tuple containing the number of embeddings inserted and the dimension of the embeddings.
pub fn insert_documents<Mod: DatabaseEmbeddingModel>(
pub fn insert_documents(
&mut self,
model: &Mod,
documents: Vec<Bytes>,
) -> Result<(usize, usize), Box<dyn Error>> {
let new_embeddings: Vec<Embedding> = model.embed_documents(documents.to_vec())?;
let new_embeddings: Vec<Embedding> = self.model.embed_documents(documents.to_vec())?;
let length_and_dimension = (new_embeddings.len(), new_embeddings[0].len());
let records: Vec<_> = new_embeddings
.into_par_iter()
Expand Down Expand Up @@ -186,9 +184,8 @@ where
/// # Returns
///
/// A vector of documents that are most similar to the queried documents.
pub fn query_documents<Mod: DatabaseEmbeddingModel>(
pub fn query_documents(
&mut self,
model: &Mod,
documents: Vec<Bytes>,
number_of_results: usize,
) -> Result<Vec<Vec<u8>>, Box<dyn Error>> {
Expand All @@ -197,7 +194,7 @@ where
}
let mut searcher: Searcher<DistanceUnit> = Searcher::default();
let mut results = Vec::new();
let query_embeddings = model.embed_documents(documents)?;
let query_embeddings = self.model.embed_documents(documents)?;
for query_embedding in query_embeddings.iter() {
let mut neighbours = Vec::new();
self.hnsw.nearest(
Expand Down Expand Up @@ -232,7 +229,8 @@ where
&self,
documents: &mut HashMap<usize, Bytes>,
) -> Result<(), Box<dyn Error>> {
let document_subdirectory = self.document_type.subdirectory_name();
let document_type = self.model.document_type();
let document_subdirectory = document_type.subdirectory_name();
std::fs::create_dir_all(document_subdirectory)?;
for document in documents {
let mut reader = BufReader::new(document.1.as_ref());
Expand Down Expand Up @@ -262,7 +260,8 @@ where
&self,
documents: &mut Vec<usize>,
) -> Result<HashMap<usize, Vec<u8>>, Box<dyn Error>> {
let document_subdirectory = self.document_type.subdirectory_name();
let document_type = self.model.document_type();
let document_subdirectory = document_type.subdirectory_name();
let mut results = HashMap::new();
for document_index in documents {
let file = OpenOptions::new()
Expand Down
36 changes: 36 additions & 0 deletions src/database/default/audio.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
use crate::database::core::Database;
use crate::distance::CosineDistance;
use crate::model::audio::VitBasePatch16_224;

/// The default distance metric for audio embeddings (cosine distance).
pub type DefaultAudioMetric = CosineDistance;

/// The default embedding model for audio embeddings.
pub type DefaultAudioModel = VitBasePatch16_224;

// NOTE(review): the `hnsw` crate describes `ef_construction` as an insertion-time setting, so the
// cost of higher values is presumably slower insertion rather than slower retrieval — confirm and
// reword here and on `Database` if so.
/// The `EF_CONSTRUCTION` parameter of the default audio database. A parameter regarding insertion into the HNSW graph. Higher values result in more accurate search results at the expense of slower retrieval speeds. Cannot be changed after database creation.
pub const DEFAULT_AUDIO_EF_CONSTRUCTION: usize = 400;

/// The `M` parameter of the default audio database: the number of bi-directional links created for each node in the HNSW graph. Cannot be changed after database creation. Increases memory usage and decreases retrieval speed with higher values.
pub const DEFAULT_AUDIO_M: usize = 12;

/// The `M0` parameter of the default audio database: the number of bi-directional links created for each node in the HNSW graph in the first layer. Cannot be changed after database creation.
pub const DEFAULT_AUDIO_M0: usize = 24;

/// A database containing sounds and their embeddings.
///
/// This is [`Database`] specialised with the default audio metric, the default audio embedding model, and the default audio HNSW parameters declared in this module.
pub type DefaultAudioDatabase = Database<
DefaultAudioMetric,
DefaultAudioModel,
DEFAULT_AUDIO_EF_CONSTRUCTION,
DEFAULT_AUDIO_M,
DEFAULT_AUDIO_M0,
>;

/// Open the default audio database from disk, creating it if it does not already exist.
///
/// The database is constructed with the default audio metric and the default audio embedding model.
///
/// # Returns
///
/// A vector database for audio.
///
/// # Errors
///
/// Propagates any error returned by `Database::create_or_load_database`.
pub fn create_or_load_database() -> Result<DefaultAudioDatabase, Box<dyn std::error::Error>> {
    // The aliases name types, not values, so the instances are built with struct-expression syntax.
    let metric = DefaultAudioMetric {};
    let model = DefaultAudioModel {};
    DefaultAudioDatabase::create_or_load_database(metric, model)
}
36 changes: 36 additions & 0 deletions src/database/default/image.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
use crate::database::core::Database;
use crate::distance::CosineDistance;
use crate::model::image::VitBasePatch16_224;

/// The default distance metric for image embeddings (cosine distance).
pub type DefaultImageMetric = CosineDistance;

/// The default embedding model for image embeddings.
pub type DefaultImageModel = VitBasePatch16_224;

// NOTE(review): the `hnsw` crate describes `ef_construction` as an insertion-time setting, so the
// cost of higher values is presumably slower insertion rather than slower retrieval — confirm and
// reword here and on `Database` if so.
/// The `EF_CONSTRUCTION` parameter of the default image database. A parameter regarding insertion into the HNSW graph. Higher values result in more accurate search results at the expense of slower retrieval speeds. Cannot be changed after database creation.
pub const DEFAULT_IMAGE_EF_CONSTRUCTION: usize = 400;

/// The `M` parameter of the default image database: the number of bi-directional links created for each node in the HNSW graph. Cannot be changed after database creation. Increases memory usage and decreases retrieval speed with higher values.
pub const DEFAULT_IMAGE_M: usize = 12;

/// The `M0` parameter of the default image database: the number of bi-directional links created for each node in the HNSW graph in the first layer. Cannot be changed after database creation.
pub const DEFAULT_IMAGE_M0: usize = 24;

/// A database containing images and their embeddings.
///
/// This is [`Database`] specialised with the default image metric, the default image embedding model, and the default image HNSW parameters declared in this module.
pub type DefaultImageDatabase = Database<
DefaultImageMetric,
DefaultImageModel,
DEFAULT_IMAGE_EF_CONSTRUCTION,
DEFAULT_IMAGE_M,
DEFAULT_IMAGE_M0,
>;

/// Open the default image database from disk, creating it if it does not already exist.
///
/// The database is constructed with the default image metric and the default image embedding model.
///
/// # Returns
///
/// A vector database for images.
///
/// # Errors
///
/// Propagates any error returned by `Database::create_or_load_database`.
pub fn create_or_load_database() -> Result<DefaultImageDatabase, Box<dyn std::error::Error>> {
    // The aliases name types, not values, so the instances are built with struct-expression syntax.
    let metric = DefaultImageMetric {};
    let model = DefaultImageModel {};
    DefaultImageDatabase::create_or_load_database(metric, model)
}
6 changes: 6 additions & 0 deletions src/database/default/mod.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
/// Default configuration for an audio database.
pub mod audio;
/// Default configuration for an image database.
pub mod image;
/// Default configuration for a text database.
pub mod text;
36 changes: 36 additions & 0 deletions src/database/default/text.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
use crate::database::core::Database;
use crate::distance::L2SquaredDistance;
use crate::model::text::BGESmallEn1_5;

/// The default distance metric for text embeddings (squared L2 distance).
pub type DefaultTextMetric = L2SquaredDistance;

/// The default embedding model for text embeddings.
pub type DefaultTextModel = BGESmallEn1_5;

// NOTE(review): the `hnsw` crate describes `ef_construction` as an insertion-time setting, so the
// cost of higher values is presumably slower insertion rather than slower retrieval — confirm and
// reword here and on `Database` if so.
/// The `EF_CONSTRUCTION` parameter of the default text database. A parameter regarding insertion into the HNSW graph. Higher values result in more accurate search results at the expense of slower retrieval speeds. Cannot be changed after database creation.
pub const DEFAULT_TEXT_EF_CONSTRUCTION: usize = 400;

/// The `M` parameter of the default text database: the number of bi-directional links created for each node in the HNSW graph. Cannot be changed after database creation. Increases memory usage and decreases retrieval speed with higher values.
pub const DEFAULT_TEXT_M: usize = 12;

/// The `M0` parameter of the default text database: the number of bi-directional links created for each node in the HNSW graph in the first layer. Cannot be changed after database creation.
pub const DEFAULT_TEXT_M0: usize = 24;

/// A database containing texts and their embeddings.
///
/// This is [`Database`] specialised with the default text metric, the default text embedding model, and the default text HNSW parameters declared in this module.
pub type DefaultTextDatabase = Database<
DefaultTextMetric,
DefaultTextModel,
DEFAULT_TEXT_EF_CONSTRUCTION,
DEFAULT_TEXT_M,
DEFAULT_TEXT_M0,
>;

/// Open the default text database from disk, creating it if it does not already exist.
///
/// The database is constructed with the default text metric and the default text embedding model.
///
/// # Returns
///
/// A vector database for text.
///
/// # Errors
///
/// Propagates any error returned by `Database::create_or_load_database`.
pub fn create_or_load_database() -> Result<DefaultTextDatabase, Box<dyn std::error::Error>> {
    // The aliases name types, not values, so the instances are built with struct-expression syntax.
    let metric = DefaultTextMetric {};
    let model = DefaultTextModel {};
    DefaultTextDatabase::create_or_load_database(metric, model)
}
5 changes: 5 additions & 0 deletions src/database/mod.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
/// Core implementation of a database.
pub mod core;
#[cfg(feature = "default_db")]
/// Default configurations of databases.
pub mod default;
9 changes: 0 additions & 9 deletions src/distance.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,15 +10,6 @@ use space::Metric;
/// The data type representing the distance between two embeddings.
pub type DistanceUnit = u64;

/// The data type representing the distance metric for text embeddings.
pub type DefaultTextMetric = L2SquaredDistance;

/// The data type representing the distance metric for image embeddings.
pub type DefaultImageMetric = CosineDistance;

/// The data type representing the distance metric for audio embeddings.
pub type DefaultAudioMetric = CosineDistance;

#[derive(Debug, Clone, Serialize, Deserialize)]
/// The cosine distance metric.
pub struct CosineDistance;
Expand Down
59 changes: 0 additions & 59 deletions src/image.rs

This file was deleted.

8 changes: 1 addition & 7 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,15 +2,9 @@
#![feature(doc_auto_cfg)]
#![warn(missing_docs)]

/// A module for audio database operations.
pub mod audio;
/// A module for database operations regardless of data type.
pub mod db;
pub mod database;
/// A module for distance metrics.
pub mod distance;
/// A module for image database operations.
pub mod image;
/// A module for embedding models.
pub mod model;
/// A module for text database operations.
pub mod text;
Loading

0 comments on commit 5321e5e

Please sign in to comment.