Skip to content

Commit

Permalink
feat: Read bytes, not paths
Browse files Browse the repository at this point in the history
  • Loading branch information
emmyoh committed Dec 12, 2024
1 parent b5a3e35 commit 8b21275
Show file tree
Hide file tree
Showing 6 changed files with 186 additions and 88 deletions.
3 changes: 3 additions & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,9 @@ viuer = { version = "0.9.1", features = ["print-file"], optional = true }
sonogram = "0.7.1"
image = "0.25.5"
rodio = { version = "0.20.1", optional = true }
rayon = "1.10.0"
bytes = { version = "1.9.0", features = ["serde"] }
symphonia = "0.5.4"

[features]
default = []
Expand Down
48 changes: 20 additions & 28 deletions src/db.rs
Original file line number Diff line number Diff line change
@@ -1,12 +1,12 @@
use crate::distance::DistanceUnit;
use crate::model::DatabaseEmbeddingModel;
use crate::EF;
use bytes::Bytes;
use fastembed::Embedding;
use hnsw::Params;
use hnsw::{Hnsw, Searcher};
use pcg_rand::Pcg64;
use serde::{Deserialize, Serialize};
use space::{Metric, Neighbor};
use space::Metric;
use std::collections::HashMap;
use std::fs::OpenOptions;
use std::io::{self, BufReader, BufWriter};
Expand Down Expand Up @@ -135,20 +135,20 @@ where
/// # Returns
///
/// A tuple containing the number of embeddings inserted and the dimension of the embeddings.
pub fn insert_documents<S: AsRef<str> + Send + Sync + Clone, Mod: DatabaseEmbeddingModel>(
pub fn insert_documents<Mod: DatabaseEmbeddingModel>(
&mut self,
model: &Mod,
documents: &[S],
documents: &[Bytes],
) -> Result<(usize, usize), Box<dyn Error>> {
let new_embeddings: Vec<Embedding> = model.embed_documents(documents.to_vec())?;
let length_and_dimension = (new_embeddings.len(), new_embeddings[0].len());
let mut searcher: Searcher<DistanceUnit> = Searcher::default();
let mut document_map = HashMap::new();
for (document, embedding) in documents.iter().zip(new_embeddings.iter()) {
let embedding_index = self.hnsw.insert(embedding.clone(), &mut searcher);
let mut document_map = HashMap::new();
document_map.insert(embedding_index, document.clone());
self.save_documents_to_disk(&mut document_map)?;
}
self.save_documents_to_disk(&mut document_map)?;
self.save_database()?;
Ok(length_and_dimension)
}
Expand All @@ -161,39 +161,31 @@ where
///
/// * `documents` - A vector of documents to be queried.
///
/// * `number_of_results` - An optional positive integer less than or equal to `EF` specifying the number of query results to return.
/// * `number_of_results` - The candidate list size for the HNSW graph. Higher values result in more accurate search results at the expense of slower retrieval speeds.
///
/// # Returns
///
/// A vector of documents that are most similar to the queried documents.
pub fn query_documents<S: AsRef<str> + Send + Sync, Mod: DatabaseEmbeddingModel>(
pub fn query_documents<Mod: DatabaseEmbeddingModel>(
&mut self,
model: &Mod,
documents: Vec<S>,
number_of_results: Option<usize>,
documents: Vec<Bytes>,
number_of_results: usize,
) -> Result<Vec<Vec<u8>>, Box<dyn Error>> {
if self.hnsw.is_empty() {
return Ok(Vec::new());
}
let number_of_results = match number_of_results {
None => 1,
Some(number_of_results) => std::cmp::min(number_of_results, EF),
};
let mut searcher: Searcher<DistanceUnit> = Searcher::default();
let mut results = Vec::new();
// let model = TextEmbedding::try_new(InitOptions {
// model_name: EmbeddingModel::BGESmallENV15,
// show_download_progress: false,
// ..Default::default()
// })?;
let query_embeddings = model.embed_documents(documents)?;
for query_embedding in query_embeddings.iter() {
let mut neighbours = [Neighbor {
index: !0,
distance: !0,
}; EF];
self.hnsw
.nearest(query_embedding, EF, &mut searcher, &mut neighbours);
let mut neighbours = Vec::new();
self.hnsw.nearest(
query_embedding,
number_of_results,
&mut searcher,
&mut neighbours,
);
if neighbours.is_empty() {
return Ok(Vec::new());
}
Expand All @@ -216,14 +208,14 @@ where
/// # Arguments
///
/// * `documents` - A map of document indices and their corresponding documents.
pub fn save_documents_to_disk<S: AsRef<str> + Send + Sync>(
pub fn save_documents_to_disk(
&self,
documents: &mut HashMap<usize, S>,
documents: &mut HashMap<usize, Bytes>,
) -> Result<(), Box<dyn Error>> {
let document_subdirectory = self.document_type.subdirectory_name();
std::fs::create_dir_all(document_subdirectory)?;
for document in documents {
let mut reader = BufReader::new(document.1.as_ref().as_bytes());
let mut reader = BufReader::new(document.1.as_ref());
let file = OpenOptions::new()
.read(true)
.write(true)
Expand Down
36 changes: 36 additions & 0 deletions src/image.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,11 @@
use crate::db::{Database, DocumentType};
use crate::distance::{CosineDistance, DefaultImageMetric};
use bytes::Bytes;
use candle_core::Tensor;
use candle_examples::imagenet::{IMAGENET_MEAN, IMAGENET_STD};
use image::ImageReader;
use std::error::Error;
use std::io::Cursor;

/// A parameter regarding insertion into the HNSW graph. Higher values result in more accurate search results at the expense of slower retrieval speeds. Cannot be changed after database creation.
pub const IMAGE_EF_CONSTRUCTION: usize = 400;
Expand All @@ -21,3 +27,33 @@ pub type ImageDatabase = Database<DefaultImageMetric, IMAGE_EF_CONSTRUCTION, IMA
/// Convenience constructor for an [`ImageDatabase`]: delegates to
/// `Database::create_or_load_database` with the [`CosineDistance`] metric and
/// [`DocumentType::Image`].
///
/// # Returns
///
/// An [`ImageDatabase`], or an error if the underlying database could not be
/// created or loaded.
pub fn create_or_load_database() -> Result<ImageDatabase, Box<dyn std::error::Error>> {
    ImageDatabase::create_or_load_database(CosineDistance, DocumentType::Image)
}

/// Loads an image from raw bytes with ImageNet normalisation applied, returning a tensor with the shape [3 224 224].
///
/// # Arguments
///
/// * `bytes` - The raw bytes of an image.
///
/// # Returns
///
/// A tensor with the shape [3 224 224]; ImageNet normalisation is applied.
pub fn load_image224(bytes: Bytes) -> Result<Tensor, Box<dyn Error>> {
    /// Target spatial resolution expected by ImageNet-style models.
    const RES: usize = 224;
    let device = candle_core::Device::Cpu;
    // Decode from raw bytes (format sniffed from the data itself), then
    // crop-resize to RES x RES and flatten to raw interleaved RGB bytes.
    let decoded = ImageReader::new(Cursor::new(bytes))
        .with_guessed_format()?
        .decode()?;
    let rgb_bytes = decoded
        .resize_to_fill(
            RES as u32,
            RES as u32,
            image::imageops::FilterType::Triangle,
        )
        .to_rgb8()
        .into_raw();
    // Build an (H, W, C) tensor and transpose to the (C, H, W) layout.
    let pixels =
        Tensor::from_vec(rgb_bytes, (RES, RES, 3), &device)?.permute((2, 0, 1))?;
    // Per-channel ImageNet statistics, shaped (3, 1, 1) for broadcasting.
    let mean = Tensor::new(&IMAGENET_MEAN, &device)?.reshape((3, 1, 1))?;
    let std = Tensor::new(&IMAGENET_STD, &device)?.reshape((3, 1, 1))?;
    // Scale u8 values into [0, 1], then normalise channel-wise.
    let scaled = (pixels.to_dtype(candle_core::DType::F32)? / 255.)?;
    Ok(scaled.broadcast_sub(&mean)?.broadcast_div(&std)?)
}
3 changes: 0 additions & 3 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,3 @@ pub mod image;
pub mod model;
/// A module for text database operations.
pub mod text;

/// The candidate list size for the HNSW graph. Higher values result in more accurate search results at the expense of slower retrieval speeds. Can be changed after database creation.
pub const EF: usize = 24;
48 changes: 26 additions & 22 deletions src/main.rs
Original file line number Diff line number Diff line change
@@ -1,10 +1,14 @@
use bytes::Bytes;
use clap::{command, Parser, Subcommand};
use fastembed::Embedding;
use fastembed::TextEmbedding;
use indicatif::HumanCount;
use indicatif::ProgressStyle;
use indicatif::{ProgressBar, ProgressDrawTarget};
use pretty_duration::pretty_duration;
use rayon::iter::IntoParallelIterator;
use rayon::iter::IntoParallelRefIterator;
use rayon::iter::ParallelIterator;
use rodio::{Decoder, OutputStream, Sink};
use space::Metric;
use std::error::Error;
Expand Down Expand Up @@ -71,7 +75,8 @@ enum TextCommands {
#[command(about = "Query texts from the database.", arg_required_else_help(true))]
Query {
texts: Vec<String>,
number_of_results: Option<usize>,
#[arg(default_value_t = 1)]
number_of_results: usize,
},
#[command(about = "Clear the database.")]
Clear,
Expand All @@ -90,7 +95,8 @@ enum ImageCommands {
)]
Query {
image_path: PathBuf,
number_of_results: Option<usize>,
#[arg(default_value_t = 1)]
number_of_results: usize,
},
#[command(about = "Clear the database.")]
Clear,
Expand All @@ -109,7 +115,8 @@ enum AudioCommands {
)]
Query {
audio_path: PathBuf,
number_of_results: Option<usize>,
#[arg(default_value_t = 1)]
number_of_results: usize,
},
#[command(about = "Clear the database.")]
Clear,
Expand All @@ -119,13 +126,14 @@ fn main() -> Result<(), Box<dyn Error>> {
let cli = Cli::parse();
match cli.commands {
Commands::Text(text) => match text.text_commands {
TextCommands::Insert { mut texts } => {
TextCommands::Insert { texts } => {
let mut sw = Stopwatch::start_new();
let mut db = zebra::text::create_or_load_database()?;
let mut buffer = BufWriter::new(stdout().lock());
let model: TextEmbedding = DatabaseEmbeddingModel::new()?;
writeln!(buffer, "Inserting {} text(s).", texts.len())?;
let insertion_results = db.insert_documents(&model, &mut texts)?;
let texts_bytes: Vec<_> = texts.into_par_iter().map(|x| Bytes::from(x)).collect();
let insertion_results = db.insert_documents(&model, &texts_bytes)?;
sw.stop();
writeln!(
buffer,
Expand All @@ -150,10 +158,11 @@ fn main() -> Result<(), Box<dyn Error>> {
let num_texts = texts.len();
let model: TextEmbedding = DatabaseEmbeddingModel::new()?;
writeln!(buffer, "Querying {} text(s).", num_texts)?;
let query_results = db.query_documents(&model, texts, number_of_results)?;
let result_texts: Vec<String> = query_results
let texts_bytes: Vec<_> = texts.into_par_iter().map(|x| Bytes::from(x)).collect();
let query_results = db.query_documents(&model, texts_bytes, number_of_results)?;
let result_texts: Vec<_> = query_results
.iter()
.map(|x| String::from_utf8(x.to_vec()).unwrap())
.map(|x| String::from_utf8_lossy(x))
.collect();
sw.stop();
writeln!(
Expand Down Expand Up @@ -201,11 +210,9 @@ fn main() -> Result<(), Box<dyn Error>> {
};
let model: ImageEmbeddingModel = DatabaseEmbeddingModel::new()?;
writeln!(buffer, "Querying image.")?;
let query_results = db.query_documents(
&model,
vec![image_path.to_str().unwrap()],
number_of_results,
)?;
let image_bytes = std::fs::read(image_path).unwrap_or_default().into();
let query_results =
db.query_documents(&model, vec![image_bytes], number_of_results)?;
sw.stop();
writeln!(
buffer,
Expand Down Expand Up @@ -239,11 +246,9 @@ fn main() -> Result<(), Box<dyn Error>> {
let mut buffer = BufWriter::new(stdout().lock());
let model: AudioEmbeddingModel = DatabaseEmbeddingModel::new()?;
writeln!(buffer, "Querying sound.")?;
let query_results = db.query_documents(
&model,
vec![audio_path.to_str().unwrap()],
number_of_results,
)?;
let audio_bytes = std::fs::read(audio_path).unwrap_or_default().into();
let query_results =
db.query_documents(&model, vec![audio_bytes], number_of_results)?;
sw.stop();
writeln!(
buffer,
Expand All @@ -264,7 +269,6 @@ fn main() -> Result<(), Box<dyn Error>> {
clear_database(DocumentType::Audio)?;
}
},
// _ => unreachable!(),
}
Ok(())
}
Expand Down Expand Up @@ -315,9 +319,9 @@ where
ProgressDrawTarget::hidden(),
);
progress_bar.set_style(progress_bar_style()?);
let documents: Vec<String> = file_paths
.into_iter()
.map(|x| x.to_str().unwrap().to_string())
let documents: Vec<_> = file_paths
.par_iter()
.filter_map(|x| std::fs::read(x).ok().map(|y| y.into()))
.collect();
// Insert documents in batches of INSERT_BATCH_SIZE.
for document_batch in documents.chunks(INSERT_BATCH_SIZE) {
Expand Down
Loading

0 comments on commit 8b21275

Please sign in to comment.