diff --git a/.gitignore b/.gitignore index 233fd09..e09e0a5 100644 --- a/.gitignore +++ b/.gitignore @@ -105,4 +105,5 @@ texts image.db images audio.db -audio \ No newline at end of file +audio +*_old.rs \ No newline at end of file diff --git a/Cargo.toml b/Cargo.toml index cacb045..6664780 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -51,9 +51,10 @@ symphonia = "0.5.4" [features] default = [] +default_db = [] accelerate = ["candle-core/accelerate", "candle-examples/accelerate", "candle-nn/accelerate", "candle-transformers/accelerate"] cuda = ["candle-core/cuda", "candle-examples/cuda", "candle-nn/cuda", "candle-transformers/cuda"] mkl = ["candle-core/mkl", "candle-examples/mkl", "candle-nn/mkl", "candle-transformers/mkl"] metal = ["candle-core/metal", "candle-examples/metal", "candle-nn/metal", "candle-transformers/metal"] sixel = ["viuer/sixel"] -cli = ["dep:clap", "dep:ticky", "dep:pretty-duration", "dep:indicatif", "dep:viuer", "dep:rodio"] +cli = ["default_db", "dep:clap", "dep:ticky", "dep:pretty-duration", "dep:indicatif", "dep:viuer", "dep:rodio"] diff --git a/src/audio.rs b/src/audio.rs deleted file mode 100644 index e205cbb..0000000 --- a/src/audio.rs +++ /dev/null @@ -1,23 +0,0 @@ -use crate::db::{Database, DocumentType}; -use crate::distance::{CosineDistance, DefaultAudioMetric}; - -/// A parameter regarding insertion into the HNSW graph. Higher values result in more accurate search results at the expense of slower retrieval speeds. Cannot be changed after database creation. -pub const AUDIO_EF_CONSTRUCTION: usize = 400; - -/// The number of bi-directional links created for each node in the HNSW graph. Cannot be changed after database creation. Increases memory usage and decreases retrieval speed with higher values. -pub const AUDIO_M: usize = 12; - -/// The number of bi-directional links created for each node in the HNSW graph in the first layer. Cannot be changed after database creation. -pub const AUDIO_M0: usize = 24; - -/// A database containing sounds and their embeddings. -pub type AudioDatabase = Database; - -/// Load the audio database from disk, or create it if it does not already exist. -/// -/// # Returns -/// -/// A vector database for audio. -pub fn create_or_load_database() -> Result> { - AudioDatabase::create_or_load_database(CosineDistance, DocumentType::Audio) -} diff --git a/src/db.rs b/src/database/core.rs similarity index 89% rename from src/db.rs rename to src/database/core.rs index 4c2630a..551544e 100644 --- a/src/db.rs +++ b/src/database/core.rs @@ -1,5 +1,5 @@ use crate::distance::DistanceUnit; -use crate::model::DatabaseEmbeddingModel; +use crate::model::core::DatabaseEmbeddingModel; use bytes::Bytes; use fastembed::Embedding; use hnsw::Params; @@ -67,6 +67,7 @@ impl DocumentType { /// * `M0` - The number of bi-directional links created for each node in the HNSW graph in the first layer. Cannot be changed after database creation. pub struct Database< Met: Metric + Serialize, + Model: DatabaseEmbeddingModel + Serialize, const EF_CONSTRUCTION: usize, const M: usize, const M0: usize, @@ -74,17 +75,20 @@ pub struct Database< /// The Hierarchical Navigable Small World (HNSW) graph containing the embeddings. pub hnsw: Hnsw, /// The type of documents stored in the database. - pub document_type: DocumentType, + // pub document_type: DocumentType, + pub model: Model, } impl< Met: Metric + Serialize, + Model: DatabaseEmbeddingModel + Serialize, const EF_CONSTRUCTION: usize, const M: usize, const M0: usize, - > Database + > Database where for<'de> Met: Deserialize<'de>, + for<'de> Model: Deserialize<'de>, { /// Load the database from disk, or create it if it does not already exist. /// @@ -97,10 +101,8 @@ where /// # Returns /// /// A database containing a HNSW graph and the inserted documents. - pub fn create_or_load_database( - metric: Met, - document_type: DocumentType, - ) -> Result> { + pub fn create_or_load_database(metric: Met, model: Model) -> Result> { + let document_type = model.document_type(); let db_bytes = fs::read(document_type.database_name()); match db_bytes { Ok(bytes) => { @@ -109,10 +111,7 @@ where } Err(_) => { let hnsw = Hnsw::new_params(metric, Params::new().ef_construction(EF_CONSTRUCTION)); - let db = Database { - hnsw, - document_type, - }; + let db = Database { hnsw, model }; let db_bytes = bincode::serialize(&db)?; fs::write(document_type.database_name(), db_bytes)?; Ok(db) @@ -123,7 +122,7 @@ where /// Save the database to disk. pub fn save_database(&self) -> Result<(), Box> { let db_bytes = bincode::serialize(&self)?; - fs::write(self.document_type.database_name(), db_bytes)?; + fs::write(self.model.document_type().database_name(), db_bytes)?; Ok(()) } @@ -138,12 +137,11 @@ where /// # Returns /// /// A tuple containing the number of embeddings inserted and the dimension of the embeddings. - pub fn insert_documents( + pub fn insert_documents( &mut self, - model: &Mod, documents: Vec, ) -> Result<(usize, usize), Box> { - let new_embeddings: Vec = model.embed_documents(documents.to_vec())?; + let new_embeddings: Vec = self.model.embed_documents(documents.to_vec())?; let length_and_dimension = (new_embeddings.len(), new_embeddings[0].len()); let records: Vec<_> = new_embeddings .into_par_iter() @@ -186,9 +184,8 @@ where /// # Returns /// /// A vector of documents that are most similar to the queried documents. - pub fn query_documents( + pub fn query_documents( &mut self, - model: &Mod, documents: Vec, number_of_results: usize, ) -> Result>, Box> { @@ -197,7 +194,7 @@ where } let mut searcher: Searcher = Searcher::default(); let mut results = Vec::new(); - let query_embeddings = model.embed_documents(documents)?; + let query_embeddings = self.model.embed_documents(documents)?; for query_embedding in query_embeddings.iter() { let mut neighbours = Vec::new(); self.hnsw.nearest( @@ -232,7 +229,8 @@ where &self, documents: &mut HashMap, ) -> Result<(), Box> { - let document_subdirectory = self.document_type.subdirectory_name(); + let document_type = self.model.document_type(); + let document_subdirectory = document_type.subdirectory_name(); std::fs::create_dir_all(document_subdirectory)?; for document in documents { let mut reader = BufReader::new(document.1.as_ref()); @@ -262,7 +260,8 @@ where &self, documents: &mut Vec, ) -> Result>, Box> { - let document_subdirectory = self.document_type.subdirectory_name(); + let document_type = self.model.document_type(); + let document_subdirectory = document_type.subdirectory_name(); let mut results = HashMap::new(); for document_index in documents { let file = OpenOptions::new() diff --git a/src/database/default/audio.rs b/src/database/default/audio.rs new file mode 100644 index 0000000..b757f77 --- /dev/null +++ b/src/database/default/audio.rs @@ -0,0 +1,36 @@ +use crate::database::core::Database; +use crate::distance::CosineDistance; +use crate::model::audio::VitBasePatch16_224; + +/// The default distance metric for audio embeddings. +pub type DefaultAudioMetric = CosineDistance; + +/// The default embedding model for audio embeddings. +pub type DefaultAudioModel = VitBasePatch16_224; + +/// A parameter regarding insertion into the HNSW graph. Higher values result in more accurate search results at the expense of slower retrieval speeds. Cannot be changed after database creation. +pub const DEFAULT_AUDIO_EF_CONSTRUCTION: usize = 400; + +/// The number of bi-directional links created for each node in the HNSW graph. Cannot be changed after database creation. Increases memory usage and decreases retrieval speed with higher values. +pub const DEFAULT_AUDIO_M: usize = 12; + +/// The number of bi-directional links created for each node in the HNSW graph in the first layer. Cannot be changed after database creation. +pub const DEFAULT_AUDIO_M0: usize = 24; + +/// A database containing sounds and their embeddings. +pub type DefaultAudioDatabase = Database< + DefaultAudioMetric, + DefaultAudioModel, + DEFAULT_AUDIO_EF_CONSTRUCTION, + DEFAULT_AUDIO_M, + DEFAULT_AUDIO_M0, +>; + +/// Load the audio database from disk, or create it if it does not already exist. +/// +/// # Returns +/// +/// A vector database for audio. +pub fn create_or_load_database() -> Result> { + DefaultAudioDatabase::create_or_load_database(DefaultAudioMetric {}, DefaultAudioModel {}) +} diff --git a/src/database/default/image.rs b/src/database/default/image.rs new file mode 100644 index 0000000..5a7b311 --- /dev/null +++ b/src/database/default/image.rs @@ -0,0 +1,36 @@ +use crate::database::core::Database; +use crate::distance::CosineDistance; +use crate::model::image::VitBasePatch16_224; + +/// The default distance metric for image embeddings. +pub type DefaultImageMetric = CosineDistance; + +/// The default embedding model for image embeddings. +pub type DefaultImageModel = VitBasePatch16_224; + +/// A parameter regarding insertion into the HNSW graph. Higher values result in more accurate search results at the expense of slower retrieval speeds. Cannot be changed after database creation. +pub const DEFAULT_IMAGE_EF_CONSTRUCTION: usize = 400; + +/// The number of bi-directional links created for each node in the HNSW graph. Cannot be changed after database creation. Increases memory usage and decreases retrieval speed with higher values. +pub const DEFAULT_IMAGE_M: usize = 12; + +/// The number of bi-directional links created for each node in the HNSW graph in the first layer. Cannot be changed after database creation. +pub const DEFAULT_IMAGE_M0: usize = 24; + +/// A database containing images and their embeddings. +pub type DefaultImageDatabase = Database< + DefaultImageMetric, + DefaultImageModel, + DEFAULT_IMAGE_EF_CONSTRUCTION, + DEFAULT_IMAGE_M, + DEFAULT_IMAGE_M0, +>; + +/// Load the image database from disk, or create it if it does not already exist. +/// +/// # Returns +/// +/// A vector database for images. +pub fn create_or_load_database() -> Result> { + DefaultImageDatabase::create_or_load_database(DefaultImageMetric {}, DefaultImageModel {}) +} diff --git a/src/database/default/mod.rs b/src/database/default/mod.rs new file mode 100644 index 0000000..6427f52 --- /dev/null +++ b/src/database/default/mod.rs @@ -0,0 +1,6 @@ +/// Default configuration for an audio database. +pub mod audio; +/// Default configuration for an image database. +pub mod image; +/// Default configuration for a text database. +pub mod text; diff --git a/src/database/default/text.rs b/src/database/default/text.rs new file mode 100644 index 0000000..168ecbe --- /dev/null +++ b/src/database/default/text.rs @@ -0,0 +1,36 @@ +use crate::database::core::Database; +use crate::distance::L2SquaredDistance; +use crate::model::text::BGESmallEn1_5; + +/// The default distance metric for text embeddings. +pub type DefaultTextMetric = L2SquaredDistance; + +/// The default embedding model for text embeddings. +pub type DefaultTextModel = BGESmallEn1_5; + +/// A parameter regarding insertion into the HNSW graph. Higher values result in more accurate search results at the expense of slower retrieval speeds. Cannot be changed after database creation. +pub const DEFAULT_TEXT_EF_CONSTRUCTION: usize = 400; + +/// The number of bi-directional links created for each node in the HNSW graph. Cannot be changed after database creation. Increases memory usage and decreases retrieval speed with higher values. +pub const DEFAULT_TEXT_M: usize = 12; + +/// The number of bi-directional links created for each node in the HNSW graph in the first layer. Cannot be changed after database creation. +pub const DEFAULT_TEXT_M0: usize = 24; + +/// A database containing texts and their embeddings. +pub type DefaultTextDatabase = Database< + DefaultTextMetric, + DefaultTextModel, + DEFAULT_TEXT_EF_CONSTRUCTION, + DEFAULT_TEXT_M, + DEFAULT_TEXT_M0, +>; + +/// Load the text database from disk, or create it if it does not already exist. +/// +/// # Returns +/// +/// A vector database for text. +pub fn create_or_load_database() -> Result> { + DefaultTextDatabase::create_or_load_database(DefaultTextMetric {}, DefaultTextModel {}) +} diff --git a/src/database/mod.rs b/src/database/mod.rs new file mode 100644 index 0000000..41f155e --- /dev/null +++ b/src/database/mod.rs @@ -0,0 +1,5 @@ +/// Core implementation of a database. +pub mod core; +#[cfg(feature = "default_db")] +/// Default configurations of databases. +pub mod default; diff --git a/src/distance.rs b/src/distance.rs index a3a518e..83f4d27 100644 --- a/src/distance.rs +++ b/src/distance.rs @@ -10,15 +10,6 @@ use space::Metric; /// The data type representing the distance between two embeddings. pub type DistanceUnit = u64; -/// The data type representing the distance metric for text embeddings. -pub type DefaultTextMetric = L2SquaredDistance; - -/// The data type representing the distance metric for image embeddings. -pub type DefaultImageMetric = CosineDistance; - -/// The data type representing the distance metric for audio embeddings. -pub type DefaultAudioMetric = CosineDistance; - #[derive(Debug, Clone, Serialize, Deserialize)] /// The cosine distance metric. pub struct CosineDistance; diff --git a/src/image.rs b/src/image.rs deleted file mode 100644 index 87ce7eb..0000000 --- a/src/image.rs +++ /dev/null @@ -1,59 +0,0 @@ -use crate::db::{Database, DocumentType}; -use crate::distance::{CosineDistance, DefaultImageMetric}; -use bytes::Bytes; -use candle_core::Tensor; -use candle_examples::imagenet::{IMAGENET_MEAN, IMAGENET_STD}; -use image::ImageReader; -use std::error::Error; -use std::io::Cursor; - -/// A parameter regarding insertion into the HNSW graph. Higher values result in more accurate search results at the expense of slower retrieval speeds. Cannot be changed after database creation. -pub const IMAGE_EF_CONSTRUCTION: usize = 400; - -/// The number of bi-directional links created for each node in the HNSW graph. Cannot be changed after database creation. Increases memory usage and decreases retrieval speed with higher values. -pub const IMAGE_M: usize = 12; - -/// The number of bi-directional links created for each node in the HNSW graph in the first layer. Cannot be changed after database creation. -pub const IMAGE_M0: usize = 24; - -/// A database containing images and their embeddings. -pub type ImageDatabase = Database; - -/// Load the image database from disk, or create it if it does not already exist. -/// -/// # Returns -/// -/// A vector database for images. -pub fn create_or_load_database() -> Result> { - ImageDatabase::create_or_load_database(CosineDistance, DocumentType::Image) -} - -/// Loads an image from raw bytes with ImageNet normalisation applied, returning a tensor with the shape [3 224 224]. -/// -/// # Arguments -/// -/// * `bytes` - The raw bytes of an image. -/// -/// # Returns -/// -/// A tensor with the shape [3 224 224]; ImageNet normalisation is applied. -pub fn load_image224(bytes: Bytes) -> Result> { - let res = 224_usize; - let img = ImageReader::new(Cursor::new(bytes)) - .with_guessed_format()? - .decode()? - .resize_to_fill( - res as u32, - res as u32, - image::imageops::FilterType::Triangle, - ) - .to_rgb8(); - let data = img.into_raw(); - let data = - Tensor::from_vec(data, (res, res, 3), &candle_core::Device::Cpu)?.permute((2, 0, 1))?; - let mean = Tensor::new(&IMAGENET_MEAN, &candle_core::Device::Cpu)?.reshape((3, 1, 1))?; - let std = Tensor::new(&IMAGENET_STD, &candle_core::Device::Cpu)?.reshape((3, 1, 1))?; - Ok((data.to_dtype(candle_core::DType::F32)? / 255.)? - .broadcast_sub(&mean)? - .broadcast_div(&std)?) -} diff --git a/src/lib.rs b/src/lib.rs index 5c278c5..13211ea 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -2,15 +2,9 @@ #![feature(doc_auto_cfg)] #![warn(missing_docs)] -/// A module for audio database operations. -pub mod audio; /// A module for database operations regardless of data type. -pub mod db; +pub mod database; /// A module for distance metrics. pub mod distance; -/// A module for image database operations. -pub mod image; /// A module for embedding models. pub mod model; -/// A module for text database operations. -pub mod text; diff --git a/src/main.rs b/src/main.rs index 4579eee..2a86860 100644 --- a/src/main.rs +++ b/src/main.rs @@ -1,7 +1,6 @@ use bytes::Bytes; use clap::{command, Parser, Subcommand}; use fastembed::Embedding; -use fastembed::TextEmbedding; use indicatif::HumanCount; use indicatif::ProgressStyle; use indicatif::{ProgressBar, ProgressDrawTarget}; @@ -18,10 +17,10 @@ use std::io::Write; use std::io::{stdout, BufWriter}; use std::path::PathBuf; use ticky::Stopwatch; -use zebra::db::Database; -use zebra::db::DocumentType; +use zebra::database::core::Database; +use zebra::database::core::DocumentType; use zebra::distance::DistanceUnit; -use zebra::model::{AudioEmbeddingModel, DatabaseEmbeddingModel, ImageEmbeddingModel}; +use zebra::model::core::DatabaseEmbeddingModel; #[derive(Parser)] #[command(version, about, long_about = None, arg_required_else_help(true))] @@ -138,12 +137,11 @@ fn main() -> Result<(), Box> { Commands::Text(text) => match text.text_commands { TextCommands::Insert { texts } => { let mut sw = Stopwatch::start_new(); - let mut db = zebra::text::create_or_load_database()?; + let mut db = zebra::database::default::text::create_or_load_database()?; let mut buffer = BufWriter::new(stdout().lock()); - let model: TextEmbedding = DatabaseEmbeddingModel::new()?; writeln!(buffer, "Inserting {} text(s).", texts.len())?; let texts_bytes: Vec<_> = texts.into_par_iter().map(|x| Bytes::from(x)).collect(); - let insertion_results = db.insert_documents(&model, texts_bytes)?; + let insertion_results = db.insert_documents(texts_bytes)?; sw.stop(); writeln!( buffer, @@ -157,22 +155,20 @@ fn main() -> Result<(), Box> { file_paths, batch_size, } => { - let mut db = zebra::text::create_or_load_database()?; - let model: TextEmbedding = DatabaseEmbeddingModel::new()?; - insert_from_files(&mut db, model, file_paths, batch_size)?; + let mut db = zebra::database::default::text::create_or_load_database()?; + insert_from_files(&mut db, file_paths, batch_size)?; } TextCommands::Query { texts, number_of_results, } => { let mut sw = Stopwatch::start_new(); - let mut db = zebra::text::create_or_load_database()?; + let mut db = zebra::database::default::text::create_or_load_database()?; let mut buffer = BufWriter::new(stdout().lock()); let num_texts = texts.len(); - let model: TextEmbedding = DatabaseEmbeddingModel::new()?; writeln!(buffer, "Querying {} text(s).", num_texts)?; let texts_bytes: Vec<_> = texts.into_par_iter().map(|x| Bytes::from(x)).collect(); - let query_results = db.query_documents(&model, texts_bytes, number_of_results)?; + let query_results = db.query_documents(texts_bytes, number_of_results)?; let result_texts: Vec<_> = query_results .iter() .map(|x| String::from_utf8_lossy(x)) @@ -198,16 +194,15 @@ fn main() -> Result<(), Box> { file_paths, batch_size, } => { - let mut db = zebra::image::create_or_load_database()?; - let model: ImageEmbeddingModel = DatabaseEmbeddingModel::new()?; - insert_from_files(&mut db, model, file_paths, batch_size)?; + let mut db = zebra::database::default::image::create_or_load_database()?; + insert_from_files(&mut db, file_paths, batch_size)?; } ImageCommands::Query { image_path, number_of_results, } => { let mut sw = Stopwatch::start_new(); - let mut db = zebra::image::create_or_load_database()?; + let mut db = zebra::database::default::image::create_or_load_database()?; let mut buffer = BufWriter::new(stdout().lock()); let image_print_config = viuer::Config { transparent: true, @@ -224,11 +219,9 @@ fn main() -> Result<(), Box> { #[cfg(feature = "sixel")] use_sixel: true, }; - let model: ImageEmbeddingModel = DatabaseEmbeddingModel::new()?; writeln!(buffer, "Querying image.")?; let image_bytes = std::fs::read(image_path).unwrap_or_default().into(); - let query_results = - db.query_documents(&model, vec![image_bytes], number_of_results)?; + let query_results = db.query_documents(vec![image_bytes], number_of_results)?; sw.stop(); writeln!( buffer, @@ -250,24 +243,21 @@ fn main() -> Result<(), Box> { file_paths, batch_size, } => { - let mut db = zebra::audio::create_or_load_database()?; - let model: AudioEmbeddingModel = DatabaseEmbeddingModel::new()?; - insert_from_files(&mut db, model, file_paths, batch_size)?; + let mut db = zebra::database::default::audio::create_or_load_database()?; + insert_from_files(&mut db, file_paths, batch_size)?; } AudioCommands::Query { audio_path, number_of_results, } => { let mut sw = Stopwatch::start_new(); - let mut db = zebra::audio::create_or_load_database()?; + let mut db = zebra::database::default::audio::create_or_load_database()?; let (_stream, stream_handle) = OutputStream::try_default()?; let sink = Sink::try_new(&stream_handle)?; let mut buffer = BufWriter::new(stdout().lock()); - let model: AudioEmbeddingModel = DatabaseEmbeddingModel::new()?; writeln!(buffer, "Querying sound.")?; let audio_bytes = std::fs::read(audio_path).unwrap_or_default().into(); - let query_results = - db.query_documents(&model, vec![audio_bytes], number_of_results)?; + let query_results = db.query_documents(vec![audio_bytes], number_of_results)?; sw.stop(); writeln!( buffer, @@ -316,17 +306,18 @@ fn clear_database(document_type: DocumentType) -> Result<(), Box> { fn insert_from_files< Met: Metric + serde::ser::Serialize, + Model: DatabaseEmbeddingModel + serde::ser::Serialize, const EF_CONSTRUCTION: usize, const M: usize, const M0: usize, >( - db: &mut Database, - model: impl DatabaseEmbeddingModel, + db: &mut Database, file_paths: Vec, batch_size: usize, ) -> Result<(), Box> where for<'de> Met: serde::Deserialize<'de>, + for<'de> Model: serde::Deserialize<'de>, { let mut sw = Stopwatch::start_new(); let num_documents = file_paths.len(); @@ -346,7 +337,7 @@ where // Insert documents in batches. for document_batch in documents.chunks(batch_size) { let mut batch_sw = Stopwatch::start_new(); - let insertion_results = db.insert_documents(&model, document_batch.to_vec())?; + let insertion_results = db.insert_documents(document_batch.to_vec())?; batch_sw.stop(); progress_bar.println(format!( "{} embeddings of {} dimensions inserted into the database in {}.", diff --git a/src/model.rs b/src/model/audio.rs similarity index 50% rename from src/model.rs rename to src/model/audio.rs index e178e74..9226c5d 100644 --- a/src/model.rs +++ b/src/model/audio.rs @@ -1,13 +1,16 @@ -use crate::image::load_image224; +use super::core::DatabaseEmbeddingModel; +use super::image::ImageEmbeddingModel; +use crate::database::core::DocumentType; use bytes::Bytes; use candle_core::DType; -use candle_core::Device; use candle_core::Tensor; use candle_nn::VarBuilder; use candle_transformers::models::vit; -use fastembed::{Embedding, EmbeddingModel, InitOptions, TextEmbedding}; +use fastembed::Embedding; use rayon::iter::IntoParallelIterator; use rayon::iter::ParallelIterator; +use serde::Deserialize; +use serde::Serialize; use sonogram::ColourGradient; use sonogram::FrequencyScale; use sonogram::SpecOptionsBuilder; @@ -21,111 +24,8 @@ use symphonia::core::io::MediaSourceStream; use symphonia::core::meta::MetadataOptions; use symphonia::core::probe::Hint; -/// A trait for embedding models that can be used with the database. -pub trait DatabaseEmbeddingModel { - /// Create a new instance of the embedding model. - fn new() -> Result> - where - Self: Sized; - - /// Embed a vector of documents. - /// - /// # Arguments - /// - /// * `documents` - A vector of documents to be embedded. - /// - /// # Returns - /// - /// A vector of embeddings. - fn embed_documents(&self, documents: Vec) -> Result, Box>; - - /// Embed a single document. - /// - /// # Arguments - /// - /// * `document` – A single document to be embedded. - /// - /// # Returns - /// - /// An embedding vector. - fn embed(&self, document: Bytes) -> Result>; -} - -impl DatabaseEmbeddingModel for TextEmbedding { - fn new() -> Result> { - Ok(TextEmbedding::try_new( - InitOptions::new(EmbeddingModel::BGESmallENV15).with_show_download_progress(false), - )?) - } - fn embed_documents(&self, documents: Vec) -> Result, Box> { - Ok(self.embed( - documents - .into_par_iter() - .map(|x| x.to_vec()) - .filter_map(|x| String::from_utf8(x).ok()) - .collect(), - None, - )?) - } - - fn embed(&self, document: Bytes) -> Result> { - let vec_with_document = vec![document] - .into_par_iter() - .map(|x| x.to_vec()) - .filter_map(|x| String::from_utf8(x).ok()) - .collect(); - let vector_of_embeddings = self.embed(vec_with_document, None)?; - Ok(vector_of_embeddings.first().unwrap().to_vec()) - } -} - -/// A model for embedding images. -pub struct ImageEmbeddingModel; - -impl DatabaseEmbeddingModel for ImageEmbeddingModel { - fn new() -> Result> { - Ok(Self) - } - fn embed_documents(&self, documents: Vec) -> Result, Box> { - let mut result = Vec::new(); - let device = candle_examples::device(false)?; - let api = hf_hub::api::sync::Api::new()?; - let api = api.model("google/vit-base-patch16-224".into()); - let model_file = api.get("model.safetensors")?; - let varbuilder = - unsafe { VarBuilder::from_mmaped_safetensors(&[model_file], DType::F32, &device)? }; - let model = vit::Embeddings::new( - &vit::Config::vit_base_patch16_224(), - false, - varbuilder.pp("vit").pp("embeddings"), - )?; - for document in documents { - let image = load_image224(document)?.to_device(&device)?; - let embedding_tensors = model.forward(&image.unsqueeze(0)?, None, false)?; - let embedding_vector = embedding_tensors.flatten_all()?.to_vec1::()?; - result.push(embedding_vector); - } - Ok(result) - } - fn embed(&self, document: Bytes) -> Result> { - let device = candle_examples::device(false)?; - let image = load_image224(document)?.to_device(&device)?; - let api = hf_hub::api::sync::Api::new()?; - let api = api.model("google/vit-base-patch16-224".into()); - let model_file = api.get("model.safetensors")?; - let varbuilder = - unsafe { VarBuilder::from_mmaped_safetensors(&[model_file], DType::F32, &device)? }; - let model = vit::Embeddings::new(&vit::Config::vit_base_patch16_224(), false, varbuilder)?; - let embedding_tensors = model.forward(&image.unsqueeze(0)?, None, false)?; - let embedding_vector = embedding_tensors.to_vec1::()?; - Ok(embedding_vector) - } -} - -/// A model for embedding audio. -pub struct AudioEmbeddingModel; - -impl AudioEmbeddingModel { +/// A trait for audio embedding models; these models are a subset of image embedding models. +pub trait AudioEmbeddingModel: ImageEmbeddingModel { /// Decodes the samples of an audio files. /// /// # Arguments @@ -135,7 +35,7 @@ impl AudioEmbeddingModel { /// # Returns /// /// An `i16` vector of decoded samples, and the sample rate of the audio. - pub fn audio_to_data(audio: Bytes) -> Result<(Vec, u32), Box> { + fn audio_to_data(audio: Bytes) -> Result<(Vec, u32), Box> { let mss = MediaSourceStream::new(Box::new(Cursor::new(audio)), Default::default()); let meta_opts: MetadataOptions = Default::default(); let fmt_opts: FormatOptions = Default::default(); @@ -191,7 +91,7 @@ impl AudioEmbeddingModel { /// # Returns /// /// A spectrogram of the audio as an ImageNet-normalised tensor with shape [3 224 224]. - pub fn audio_to_image_tensor(audio: Bytes) -> Result> { + fn audio_to_image_tensor224(&self, audio: Bytes) -> Result> { let (data, sample_rate) = Self::audio_to_data(audio)?; let mut spectrograph = SpecOptionsBuilder::new(512) .load_data_from_memory(data, sample_rate) @@ -202,23 +102,19 @@ impl AudioEmbeddingModel { let mut gradient = ColourGradient::rainbow_theme(); let png_bytes = spectrogram.to_png_in_memory(FrequencyScale::Log, &mut gradient, 224, 224)?; - let img = image::load_from_memory_with_format(&png_bytes, image::ImageFormat::Png) - .map_err(candle_core::Error::wrap)? - .resize_to_fill(224, 224, image::imageops::FilterType::Triangle); - let img = img.to_rgb8(); - let data = img.into_raw(); - let data = Tensor::from_vec(data, (224, 224, 3), &Device::Cpu)?.permute((2, 0, 1))?; - let mean = Tensor::new(&[0.485f32, 0.456, 0.406], &Device::Cpu)?.reshape((3, 1, 1))?; - let std = Tensor::new(&[0.229f32, 0.224, 0.225], &Device::Cpu)?.reshape((3, 1, 1))?; - Ok((data.to_dtype(DType::F32)? / 255.)? - .broadcast_sub(&mean)? - .broadcast_div(&std)?) + self.load_image224(png_bytes.into()) } } -impl DatabaseEmbeddingModel for AudioEmbeddingModel { - fn new() -> Result> { - Ok(Self) +/// A model for embedding audio. +#[derive(Serialize, Deserialize, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)] +pub struct VitBasePatch16_224; +impl ImageEmbeddingModel for VitBasePatch16_224 {} +impl AudioEmbeddingModel for VitBasePatch16_224 {} + +impl DatabaseEmbeddingModel for VitBasePatch16_224 { + fn document_type(&self) -> DocumentType { + DocumentType::Audio } fn embed_documents(&self, documents: Vec) -> Result, Box> { let mut result = Vec::new(); @@ -234,7 +130,9 @@ impl DatabaseEmbeddingModel for AudioEmbeddingModel { varbuilder.pp("vit").pp("embeddings"), )?; for document in documents { - let image = AudioEmbeddingModel::audio_to_image_tensor(document)?.to_device(&device)?; + let image = self + .audio_to_image_tensor224(document)? + .to_device(&device)?; let embedding_tensors = model.forward(&image.unsqueeze(0)?, None, false)?; let embedding_vector = embedding_tensors.flatten_all()?.to_vec1::()?; result.push(embedding_vector); @@ -243,7 +141,9 @@ impl DatabaseEmbeddingModel for AudioEmbeddingModel { } fn embed(&self, document: Bytes) -> Result> { let device = candle_examples::device(false)?; - let image = AudioEmbeddingModel::audio_to_image_tensor(document)?.to_device(&device)?; + let image = self + .audio_to_image_tensor224(document)? + .to_device(&device)?; let api = hf_hub::api::sync::Api::new()?; let api = api.model("google/vit-base-patch16-224".into()); let model_file = api.get("model.safetensors")?; diff --git a/src/model/core.rs b/src/model/core.rs new file mode 100644 index 0000000..f6d435a --- /dev/null +++ b/src/model/core.rs @@ -0,0 +1,36 @@ +use crate::database::core::DocumentType; +use bytes::Bytes; +use fastembed::Embedding; +use std::error::Error; + +/// A trait for embedding models that can be used with the database. +pub trait DatabaseEmbeddingModel { + /// The type of document that can be embedded by this model. + /// + /// # Returns + /// + /// The document type supported by this database. + fn document_type(&self) -> DocumentType; + + /// Embed a vector of documents. + /// + /// # Arguments + /// + /// * `documents` - A vector of documents to be embedded. + /// + /// # Returns + /// + /// A vector of embeddings. + fn embed_documents(&self, documents: Vec) -> Result, Box>; + + /// Embed a single document. + /// + /// # Arguments + /// + /// * `document` – A single document to be embedded. + /// + /// # Returns + /// + /// An embedding vector. + fn embed(&self, document: Bytes) -> Result>; +} diff --git a/src/model/image.rs b/src/model/image.rs new file mode 100644 index 0000000..31820e7 --- /dev/null +++ b/src/model/image.rs @@ -0,0 +1,93 @@ +use super::core::DatabaseEmbeddingModel; +use crate::database::core::DocumentType; +use bytes::Bytes; +use candle_core::DType; +use candle_core::Tensor; +use candle_examples::imagenet::IMAGENET_MEAN; +use candle_examples::imagenet::IMAGENET_STD; +use candle_nn::VarBuilder; +use candle_transformers::models::vit; +use fastembed::Embedding; +use image::ImageReader; +use serde::Deserialize; +use serde::Serialize; +use std::error::Error; +use std::io::Cursor; + +/// A trait for image embedding models. +pub trait ImageEmbeddingModel { + /// Loads an image from raw bytes with ImageNet normalisation applied, returning a tensor with the shape [3 224 224]. + /// + /// # Arguments + /// + /// * `bytes` - The raw bytes of an image. + /// + /// # Returns + /// + /// A tensor with the shape [3 224 224]; ImageNet normalisation is applied. + fn load_image224(&self, bytes: Bytes) -> Result> { + let res = 224_usize; + let img = ImageReader::new(Cursor::new(bytes)) + .with_guessed_format()? + .decode()? + .resize_to_fill( + res as u32, + res as u32, + image::imageops::FilterType::Triangle, + ) + .to_rgb8(); + let data = img.into_raw(); + let data = + Tensor::from_vec(data, (res, res, 3), &candle_core::Device::Cpu)?.permute((2, 0, 1))?; + let mean = Tensor::new(&IMAGENET_MEAN, &candle_core::Device::Cpu)?.reshape((3, 1, 1))?; + let std = Tensor::new(&IMAGENET_STD, &candle_core::Device::Cpu)?.reshape((3, 1, 1))?; + Ok((data.to_dtype(candle_core::DType::F32)? / 255.)? + .broadcast_sub(&mean)? + .broadcast_div(&std)?) + } +} + +/// A model for embedding images. +#[derive(Serialize, Deserialize, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)] +pub struct VitBasePatch16_224; +impl ImageEmbeddingModel for VitBasePatch16_224 {} + +impl DatabaseEmbeddingModel for VitBasePatch16_224 { + fn document_type(&self) -> DocumentType { + DocumentType::Image + } + fn embed_documents(&self, documents: Vec) -> Result, Box> { + let mut result = Vec::new(); + let device = candle_examples::device(false)?; + let api = hf_hub::api::sync::Api::new()?; + let api = api.model("google/vit-base-patch16-224".into()); + let model_file = api.get("model.safetensors")?; + let varbuilder = + unsafe { VarBuilder::from_mmaped_safetensors(&[model_file], DType::F32, &device)? }; + let model = vit::Embeddings::new( + &vit::Config::vit_base_patch16_224(), + false, + varbuilder.pp("vit").pp("embeddings"), + )?; + for document in documents { + let image = self.load_image224(document)?.to_device(&device)?; + let embedding_tensors = model.forward(&image.unsqueeze(0)?, None, false)?; + let embedding_vector = embedding_tensors.flatten_all()?.to_vec1::()?; + result.push(embedding_vector); + } + Ok(result) + } + fn embed(&self, document: Bytes) -> Result> { + let device = candle_examples::device(false)?; + let image = self.load_image224(document)?.to_device(&device)?; + let api = hf_hub::api::sync::Api::new()?; + let api = api.model("google/vit-base-patch16-224".into()); + let model_file = api.get("model.safetensors")?; + let varbuilder = + unsafe { VarBuilder::from_mmaped_safetensors(&[model_file], DType::F32, &device)? }; + let model = vit::Embeddings::new(&vit::Config::vit_base_patch16_224(), false, varbuilder)?; + let embedding_tensors = model.forward(&image.unsqueeze(0)?, None, false)?; + let embedding_vector = embedding_tensors.to_vec1::()?; + Ok(embedding_vector) + } +} diff --git a/src/model/mod.rs b/src/model/mod.rs new file mode 100644 index 0000000..566138c --- /dev/null +++ b/src/model/mod.rs @@ -0,0 +1,8 @@ +/// Audio embedding models. +pub mod audio; +/// Core embedding implementation. +pub mod core; +/// Image embedding models. +pub mod image; +/// Text embedding models. +pub mod text; diff --git a/src/model/text.rs b/src/model/text.rs new file mode 100644 index 0000000..403db84 --- /dev/null +++ b/src/model/text.rs @@ -0,0 +1,45 @@ +use super::core::DatabaseEmbeddingModel; +use crate::database::core::DocumentType; +use bytes::Bytes; +use fastembed::{Embedding, EmbeddingModel, InitOptions, TextEmbedding}; +use rayon::iter::IntoParallelIterator; +use rayon::iter::ParallelIterator; +use serde::Deserialize; +use serde::Serialize; +use std::error::Error; + +/// A model for embedding images. +#[derive(Serialize, Deserialize, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)] +pub struct BGESmallEn1_5; + +impl DatabaseEmbeddingModel for BGESmallEn1_5 { + fn document_type(&self) -> DocumentType { + DocumentType::Text + } + fn embed_documents(&self, documents: Vec) -> Result, Box> { + let model = TextEmbedding::try_new( + InitOptions::new(EmbeddingModel::BGESmallENV15).with_show_download_progress(false), + )?; + Ok(model.embed( + documents + .into_par_iter() + .map(|x| x.to_vec()) + .filter_map(|x| String::from_utf8(x).ok()) + .collect(), + None, + )?) + } + + fn embed(&self, document: Bytes) -> Result> { + let model = TextEmbedding::try_new( + InitOptions::new(EmbeddingModel::BGESmallENV15).with_show_download_progress(false), + )?; + let vec_with_document = vec![document] + .into_par_iter() + .map(|x| x.to_vec()) + .filter_map(|x| String::from_utf8(x).ok()) + .collect(); + let vector_of_embeddings = model.embed(vec_with_document, None)?; + Ok(vector_of_embeddings.first().unwrap().to_vec()) + } +} diff --git a/src/text.rs b/src/text.rs deleted file mode 100644 index 49e6df6..0000000 --- a/src/text.rs +++ /dev/null @@ -1,23 +0,0 @@ -use crate::db::{Database, DocumentType}; -use crate::distance::{DefaultTextMetric, L2SquaredDistance}; - -/// A parameter regarding insertion into the HNSW graph. Higher values result in more accurate search results at the expense of slower retrieval speeds. Cannot be changed after database creation. -pub const TEXT_EF_CONSTRUCTION: usize = 400; - -/// The number of bi-directional links created for each node in the HNSW graph. Cannot be changed after database creation. Increases memory usage and decreases retrieval speed with higher values. -pub const TEXT_M: usize = 12; - -/// The number of bi-directional links created for each node in the HNSW graph in the first layer. Cannot be changed after database creation. -pub const TEXT_M0: usize = 24; - -/// A database containing texts and their embeddings. -pub type TextDatabase = Database; - -/// Load the text database from disk, or create it if it does not already exist. -/// -/// # Returns -/// -/// A vector database for text. -pub fn create_or_load_database() -> Result> { - TextDatabase::create_or_load_database(L2SquaredDistance, DocumentType::Text) -}