Skip to content

Commit

Permalink
🚧 added query rephraser llm integration
Browse files Browse the repository at this point in the history
  • Loading branch information
rathijitpapon committed Jun 19, 2024
1 parent e24b9a7 commit 7dbd00c
Show file tree
Hide file tree
Showing 5 changed files with 90 additions and 5 deletions.
4 changes: 4 additions & 0 deletions server/config/default.toml
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,10 @@ max_new_tokens = 1024
temperature = 1.0
top_p = 0.7

[query_rephraser]
model = "mistralai/Mistral-7B-Instruct-v0.2"
max_tokens = 100

[llm]
toxicity_threshold = 0.75

Expand Down
4 changes: 4 additions & 0 deletions server/config/dev.toml
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,10 @@ toxicity_url = "http://localhost:8082/predict"
[summarizer]
api_url = "http://localhost:8001/generate_stream"

[query_rephraser]
api_url = "https://api.together.xyz/inference"
api_key = "Bearer 92b2f1a8c6f71fdaa546324a453fcd356c112a27dac140ecb1ef090be7a326ae"

[cache]
url = "redis://127.0.0.1/"
max_sorted_size = 100
Expand Down
78 changes: 75 additions & 3 deletions server/src/llms/query_rephraser.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,17 @@
use crate::llms::LLMSettings;
use crate::secrets::Secret;
use color_eyre::eyre::eyre;
use reqwest::header::{HeaderMap, HeaderName, HeaderValue};
use reqwest::Client;
use serde::{Deserialize, Serialize};

#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct QueryRephraserSettings {
pub api_key: Secret<String>,
pub api_url: String,
pub max_tokens: u16,
pub model: String,
}

#[derive(Debug, Serialize, Deserialize)]
pub struct QueryResult {
pub query: String,
Expand All @@ -13,17 +24,78 @@ pub struct QueryRephraserInput {
pub previous_context: Vec<QueryResult>,
}

#[derive(Debug, Serialize, Deserialize)]
struct Choice {
pub text: String,
}

#[derive(Debug, Serialize, Deserialize)]
struct Output {
pub choices: Vec<Choice>,
}

#[derive(Debug, Serialize, Deserialize)]
struct QueryRephraserAPIResponse {
pub output: Output,
}

#[derive(Debug, Serialize, Deserialize)]
pub struct QueryRephraserOutput {
pub rephrased_query: String,
}

fn prepare_prompt(query_rephraser_input: &QueryRephraserInput) -> String {
"
[INST]Answer questions in such a way that history is not needed.\n\n---\n\nFollow the following format.\n\nContext: contains the chat history\n\nQuestion: ${question}\n\nReasoning: Let's think step by step in order to ${produce the answer}. We ...\n\nAnswer: Given a chat history and the latest user question, which might reference the context from the chat history, formulate a standalone question that can be understood from the history without needing the chat history. DO NOT ANSWER THE QUESTION - just reformulate it\n\n---\n\nContext: ".to_string()
+ query_rephraser_input.previous_context.iter().map(|x| format!("{}: {}", x.query, x.result)).collect::<Vec<String>>().join("\n").as_str()
+ "\n\nQuestion: "
+ query_rephraser_input.query.as_str()
+ "\n\nReasoning: Let's think step by step in order to...\n\nAnswer: [/INST]"
}

#[tracing::instrument(level = "debug", ret, err)]
pub async fn rephrase_query(
llm_settings: &LLMSettings,
settings: &QueryRephraserSettings,
query_rephraser_input: &QueryRephraserInput,
) -> crate::Result<QueryRephraserOutput> {
let client = Client::new();
let mut headers = HeaderMap::new();
headers.insert(
HeaderName::from_bytes(b"Authorization")
.map_err(|e| eyre!("Failed to create header: {e}"))?,
HeaderValue::from_str(&settings.api_key.expose())
.map_err(|e| eyre!("Failed to create header: {e}"))?,
);

let prompt = prepare_prompt(query_rephraser_input);

let response = client
.post(&settings.api_url)
.json(&serde_json::json!({
"model": settings.model,
"prompt": prompt,
"max_tokens": settings.max_tokens,
// "temperature": 1.0,
// "top_p": 1.0,
// "top_k": 50,
// "frequency_penalty": 0.0,
// "presence_penalty": 0.0,
// "repetition_penalty": 1.0,
}))
.headers(headers)
.send()
.await
.map_err(|e| eyre!("Request to query rephraser failed: {e}"))?;

let response_body = serde_json::from_slice::<QueryRephraserAPIResponse>(
&response
.bytes()
.await
.map_err(|e| eyre!("Failed to read response: {e}"))?,
)
.map_err(|e| eyre!("Failed to parse response: {e}"))?;

Ok(QueryRephraserOutput {
rephrased_query: query_rephraser_input.query.clone(),
rephrased_query: response_body.output.choices[0].text.trim().to_string(),
})
}
8 changes: 6 additions & 2 deletions server/src/rag/pre_process.rs
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,7 @@ pub async fn rephrase_query(
}

let rephraser_response = query_rephraser::rephrase_query(
&settings.llm,
&settings.query_rephraser,
&query_rephraser::QueryRephraserInput {
query: search_query_request.query.clone(),
previous_context: last_n_searches
Expand All @@ -93,5 +93,9 @@ pub async fn rephrase_query(
)
.await?;

Ok(rephraser_response.rephrased_query)
Ok(rephraser_response
.rephrased_query
.chars()
.take(settings.search.max_query_length as usize)
.collect())
}
1 change: 1 addition & 0 deletions server/src/settings.rs
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,7 @@ pub struct Settings {
pub llm: llms::LLMSettings,
pub summarizer: llms::SummarizerSettings,
pub search: rag::SearchSettings,
pub query_rephraser: llms::QueryRephraserSettings,
}

impl Settings {
Expand Down

0 comments on commit 7dbd00c

Please sign in to comment.