Skip to content

Commit

Permalink
Merge pull request #61 from curieo-org/eng-204
Browse files Browse the repository at this point in the history
Context Search Integration
  • Loading branch information
rathijitpapon authored Jun 19, 2024
2 parents 5a0d98e + e04cb57 commit a3b1aa3
Show file tree
Hide file tree
Showing 16 changed files with 303 additions and 91 deletions.
11 changes: 9 additions & 2 deletions server/config/default.toml
Original file line number Diff line number Diff line change
@@ -1,15 +1,22 @@
host = "0.0.0.0"
port = 3030
max_search_query_length = 255

[search]
max_query_length = 300
max_sources = 10
max_search_context = 5

[summarizer]
model = ""
max_new_tokens = 1024
temperature = 1.0
top_p = 0.7

[query_rephraser]
model = "mistralai/Mistral-7B-Instruct-v0.2"
max_tokens = 100

[llm]
top_k_sources = 10
toxicity_threshold = 0.75

[pubmed]
Expand Down
5 changes: 4 additions & 1 deletion server/config/dev.toml
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,15 @@ agency_api = "http://127.0.0.1:50051"
oauth2_clients = []

[llm]
llm_lingua_url = "http://localhost:8000/compress"
prompt_compression_url = "http://localhost:8000/compress"
toxicity_url = "http://localhost:8082/predict"

[summarizer]
api_url = "http://localhost:8001/generate_stream"

[query_rephraser]
api_url = "https://api.together.xyz/inference"

[cache]
url = "redis://127.0.0.1/"
max_sorted_size = 100
Expand Down
17 changes: 9 additions & 8 deletions server/migrations/20240604111752_searches.sql
Original file line number Diff line number Diff line change
@@ -1,14 +1,15 @@
-- Creating a table for searches
CREATE TABLE searches
(
search_id uuid primary key default uuid_generate_v1mc(),
thread_id uuid not null references threads (thread_id),
query varchar(255) not null,
result text not null,
media_urls text[],
reaction boolean,
created_at timestamptz not null default now(),
updated_at timestamptz not null default now()
search_id uuid primary key default uuid_generate_v1mc(),
thread_id uuid not null references threads (thread_id),
query varchar(400) not null,
rephrased_query varchar(400) not null,
result text not null,
media_urls text[],
reaction boolean,
created_at timestamptz not null default now(),
updated_at timestamptz not null default now()
);

-- And applying our `updated_at` trigger is as easy as this.
Expand Down
41 changes: 0 additions & 41 deletions server/src/llms/llm_lingua.rs

This file was deleted.

14 changes: 8 additions & 6 deletions server/src/llms/mod.rs
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
pub use summarizer::*;
pub use llm_lingua::*;
pub use models::*;
pub use toxicity_llm::*;
pub use prompt_compression::*;
pub use query_rephraser::*;
pub use summarizer::*;
pub use toxicity::*;

pub mod summarizer;
pub mod llm_lingua;
pub mod models;
pub mod toxicity_llm;
pub mod prompt_compression;
pub mod query_rephraser;
pub mod summarizer;
pub mod toxicity;
3 changes: 1 addition & 2 deletions server/src/llms/models.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,7 @@ use serde::{Deserialize, Serialize};

#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct LLMSettings {
pub llm_lingua_url: String,
pub top_k_sources: u16,
pub prompt_compression_url: String,
pub toxicity_url: String,
pub toxicity_threshold: f64,
pub toxicity_auth_token: Secret<String>,
Expand Down
41 changes: 41 additions & 0 deletions server/src/llms/prompt_compression.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
use crate::llms::LLMSettings;
use color_eyre::eyre::eyre;
use reqwest::Client;
use serde::{Deserialize, Serialize};

/// Request payload sent to the external prompt-compression service.
#[derive(Debug, Serialize, Deserialize)]
pub struct PromptCompressionInput {
    // The user query the compression is performed for.
    pub query: String,
    // Presumably the token budget for the compressed output — TODO confirm
    // against the prompt-compression service's API contract.
    pub target_token: u16,
    // Context passages submitted for compression alongside the query.
    pub context_texts_list: Vec<String>,
}

/// Successful result of a prompt-compression call: the compressed prompt text.
#[derive(Debug, Serialize, Deserialize)]
pub struct PromptCompressionOutput {
    pub compressed_prompt: String,
}

/// Wire format of the service's JSON reply; the payload of interest is
/// nested under a top-level `response` key. Internal to this module.
#[derive(Debug, Serialize, Deserialize)]
struct PromptCompressionAPIResponse {
    pub response: PromptCompressionOutput,
}

pub async fn compress(
llm_settings: &LLMSettings,
prompt_compression_input: PromptCompressionInput,
) -> crate::Result<PromptCompressionOutput> {
let client = Client::new();
let response = client
.post(llm_settings.prompt_compression_url.as_str())
.json(&prompt_compression_input)
.send()
.await
.map_err(|e| eyre!("Request to prompt compression failed: {e}"))?;

let prompt_compression_response = response
.json::<PromptCompressionAPIResponse>()
.await
.map_err(|e| eyre!("Failed to parse prompt compression response: {e}"))?;

Ok(prompt_compression_response.response)
}
100 changes: 100 additions & 0 deletions server/src/llms/query_rephraser.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,100 @@
use crate::secrets::Secret;
use color_eyre::eyre::eyre;
use reqwest::header::{HeaderMap, HeaderName, HeaderValue};
use reqwest::Client;
use serde::{Deserialize, Serialize};

/// Configuration for the query-rephrasing completion API
/// (loaded from the `[query_rephraser]` config section).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct QueryRephraserSettings {
    // Sent verbatim in the `Authorization` header.
    pub api_key: Secret<String>,
    // Endpoint the completion request is POSTed to.
    pub api_url: String,
    // Forwarded as the `max_tokens` completion parameter.
    pub max_tokens: u16,
    // Model identifier forwarded in the request body.
    pub model: String,
}

/// One past search turn: the query that was asked and the result it produced.
/// Used as conversational context when rephrasing a follow-up query.
#[derive(Debug, Serialize, Deserialize)]
pub struct QueryResult {
    pub query: String,
    pub result: String,
}

/// Input to the rephraser: the latest user query plus the preceding
/// query/result turns it may implicitly reference.
#[derive(Debug, Serialize, Deserialize)]
pub struct QueryRephraserInput {
    pub query: String,
    pub previous_context: Vec<QueryResult>,
}

/// A single completion candidate in the API response. Internal wire type.
#[derive(Debug, Serialize, Deserialize)]
struct Choice {
    pub text: String,
}

/// Container for the completion candidates returned by the API.
/// Internal wire type.
#[derive(Debug, Serialize, Deserialize)]
struct Output {
    pub choices: Vec<Choice>,
}

/// Top level of the API's JSON reply; completions live under `output`.
/// Internal wire type.
#[derive(Debug, Serialize, Deserialize)]
struct QueryRephraserAPIResponse {
    pub output: Output,
}

/// Result of a rephrasing call: the standalone, context-free query.
#[derive(Debug, Serialize, Deserialize)]
pub struct QueryRephraserOutput {
    pub rephrased_query: String,
}

/// Assembles the few-shot rephrasing prompt: a fixed instruction header,
/// the prior query/result turns joined one per line, and the new question.
fn prepare_prompt(query_rephraser_input: &QueryRephraserInput) -> String {
    let context = query_rephraser_input
        .previous_context
        .iter()
        .map(|turn| format!("{}: {}", turn.query, turn.result))
        .collect::<Vec<_>>()
        .join("\n");

    format!(
        "Rephrase the input text based on the context and the final sentence. So that it can be understood without the context.\n\n---\n\nFollow the following format.\n\nContext: contains the chat history\n\nQuestion: ${{question}}\n\nReasoning: Let's think step by step in order to ${{produce the answer}}. We ...\n\nAnswer: Given a chat history and the latest user question, which might reference the context from the chat history, formulate a standalone question that can be understood from the history without needing the chat history. DO NOT ANSWER THE QUESTION - just reformulate it\n\n---\n\nContext: {context}\n\nQuestion: {question}\n\nReasoning: Let's think step by step in order to...\n\nAnswer: ",
        context = context,
        question = query_rephraser_input.query,
    )
}

#[tracing::instrument(level = "debug", ret, err)]
pub async fn rephrase_query(
settings: &QueryRephraserSettings,
query_rephraser_input: &QueryRephraserInput,
) -> crate::Result<QueryRephraserOutput> {
let client = Client::new();
let mut headers = HeaderMap::new();
headers.insert(
HeaderName::from_bytes(b"Authorization")
.map_err(|e| eyre!("Failed to create header: {e}"))?,
HeaderValue::from_str(&settings.api_key.expose())
.map_err(|e| eyre!("Failed to create header: {e}"))?,
);

let prompt = prepare_prompt(query_rephraser_input);

let response = client
.post(&settings.api_url)
.json(&serde_json::json!({
"model": settings.model,
"prompt": prompt,
"max_tokens": settings.max_tokens,
// "temperature": 1.0,
// "top_p": 1.0,
// "top_k": 50,
// "frequency_penalty": 0.0,
// "presence_penalty": 0.0,
// "repetition_penalty": 1.0,
}))
.headers(headers)
.send()
.await
.map_err(|e| eyre!("Request to query rephraser failed: {e}"))?;

let response_body = serde_json::from_slice::<QueryRephraserAPIResponse>(
&response
.bytes()
.await
.map_err(|e| eyre!("Failed to read response: {e}"))?,
)
.map_err(|e| eyre!("Failed to parse response: {e}"))?;

Ok(QueryRephraserOutput {
rephrased_query: response_body.output.choices[0].text.trim().to_string(),
})
}
File renamed without changes.
7 changes: 7 additions & 0 deletions server/src/rag/models.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,3 +22,10 @@ pub struct SearchResponse {
pub result: String,
pub sources: Vec<Source>,
}

/// Search limits, loaded from the `[search]` section of the configuration.
#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct SearchSettings {
    // Upper bound on search-query length; also used to truncate rephrased
    // queries.
    pub max_query_length: u16,
    // Presumably caps the number of sources returned per search — TODO
    // confirm against the code that reads it (not visible here).
    pub max_sources: u8,
    // How many previous searches in a thread are used as rephrasing context.
    pub max_search_context: u8,
}
55 changes: 51 additions & 4 deletions server/src/rag/pre_process.rs
Original file line number Diff line number Diff line change
@@ -1,9 +1,12 @@
use crate::llms::toxicity_llm;
use crate::llms::query_rephraser;
use crate::llms::toxicity;
use crate::proto::agency_service_client::AgencyServiceClient;
use crate::proto::{Embeddings, EmbeddingsOutput, SearchInput};
use crate::search::api_models;
use crate::search::services as search_services;
use crate::settings::Settings;
use color_eyre::eyre::eyre;
use sqlx::PgPool;
use std::sync::Arc;
use tonic::transport::Channel;

Expand Down Expand Up @@ -38,17 +41,61 @@ pub async fn check_query_validity(
settings: &Settings,
search_query_request: &api_models::SearchQueryRequest,
) -> crate::Result<bool> {
if search_query_request.query.len() > settings.max_search_query_length as usize {
if search_query_request.query.len() > settings.search.max_query_length as usize {
return Ok(false);
}

let toxicity_prediction = toxicity_llm::predict_toxicity(
let toxicity_prediction = toxicity::predict_toxicity(
&settings.llm,
toxicity_llm::ToxicityInput {
toxicity::ToxicityInput {
inputs: search_query_request.query.to_string(),
},
)
.await?;

Ok(!toxicity_prediction)
}

/// Rewrites the incoming query as a standalone question using the thread's
/// recent search history as context, then truncates the result to the
/// configured maximum query length.
///
/// Returns the query unchanged when there is no thread or no prior searches.
///
/// # Errors
///
/// Propagates failures from the search-history lookup or the rephraser call.
#[tracing::instrument(level = "debug", ret, err)]
pub async fn rephrase_query(
    pool: &PgPool,
    settings: &Settings,
    search_query_request: &api_models::SearchQueryRequest,
) -> crate::Result<String> {
    // Pull the most recent searches in this thread to serve as context;
    // a query outside any thread has no history to draw on.
    let last_n_searches = match search_query_request.thread_id {
        Some(thread_id) => {
            // `pool` is already a reference — no need for the extra `&pool`.
            search_services::get_last_n_searches(
                pool,
                settings.search.max_search_context,
                &thread_id,
            )
            .await?
        }
        None => vec![],
    };

    // Nothing to rephrase against — return the query as-is.
    if last_n_searches.is_empty() {
        return Ok(search_query_request.query.clone());
    }

    let rephraser_response = query_rephraser::rephrase_query(
        &settings.query_rephraser,
        &query_rephraser::QueryRephraserInput {
            query: search_query_request.query.clone(),
            // Prior turns are keyed on the *rephrased* query so the context
            // chain stays standalone.
            previous_context: last_n_searches
                .into_iter()
                .map(|s| query_rephraser::QueryResult {
                    query: s.rephrased_query,
                    result: s.result,
                })
                .collect(),
        },
    )
    .await?;

    // Truncate by characters (not bytes) so a multi-byte UTF-8 character is
    // never split.
    Ok(rephraser_response
        .rephrased_query
        .chars()
        .take(settings.search.max_query_length as usize)
        .collect())
}
Loading

0 comments on commit a3b1aa3

Please sign in to comment.