Skip to content

Commit

Permalink
🚧 openai summarizer added
Browse files Browse the repository at this point in the history
  • Loading branch information
rathijitpapon committed Jun 20, 2024
1 parent a3b1aa3 commit b4a4189
Show file tree
Hide file tree
Showing 11 changed files with 152 additions and 9 deletions.
6 changes: 6 additions & 0 deletions server/.env.template
Original file line number Diff line number Diff line change
@@ -1,3 +1,9 @@
DATABASE_URL=
CACHE__URL=
AGENCY_API=
BRAVE__SUBSCRIPTION_KEY=
BRAVE__GOGGLES_ID=
LLM__TOXICITY_AUTH_TOKEN=
QUERY_REPHRASER__API_KEY=
OPENAI__API_KEY=
SENTRY_DSN=
2 changes: 1 addition & 1 deletion server/.gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ debug/
target/

# Remove Cargo.lock from gitignore if creating an executable, leave it for libraries
# More information here https://doc.rust-lang.org/cargo/guide/cargo-toml-vs-cargo-lock.html
# More information here https://doc.rust-lang.org/cargo/guide/cargo-toml-vs-cargo-lock.html
# Cargo.lock

# These are backup files generated by rustfmt
Expand Down
5 changes: 3 additions & 2 deletions server/Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions server/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,7 @@ dashmap = { version = "5.5.3", features = ["inline", "serde"] }
tokio-stream = { version = "0.1.15", features = ["full"] }
sentry = { version = "0.34.0", features = ["tracing"] }
sentry-tower = { version = "0.34.0", features = ["http"] }
regex = "1.10.5"

[dependencies.openssl-sys]
version = "0.9.102"
Expand Down
4 changes: 4 additions & 0 deletions server/config/default.toml
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,10 @@ max_new_tokens = 1024
temperature = 1.0
top_p = 0.7

[openai]
api_url = "https://api.openai.com/v1/chat/completions"
model = "gpt-4o"

[query_rephraser]
model = "mistralai/Mistral-7B-Instruct-v0.2"
max_tokens = 100
Expand Down
7 changes: 7 additions & 0 deletions server/src/llms/models.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,3 +8,10 @@ pub struct LLMSettings {
pub toxicity_threshold: f64,
pub toxicity_auth_token: Secret<String>,
}

/// Configuration for the OpenAI chat-completions summarizer backend.
/// Populated from `config/default.toml` (`[openai]`) plus the
/// `OPENAI__API_KEY` environment variable.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct OpenAISettings {
    /// Full endpoint URL, e.g. "https://api.openai.com/v1/chat/completions".
    pub api_url: String,
    /// Model identifier sent in the request body, e.g. "gpt-4o".
    pub model: String,
    /// API key; wrapped in `Secret` — presumably so it is redacted from
    /// Debug/log output (TODO confirm against the `Secret` type's impl).
    pub api_key: Secret<String>,
}
113 changes: 112 additions & 1 deletion server/src/llms/summarizer.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
use crate::llms::OpenAISettings;
use crate::search::api_models;
use color_eyre::eyre::eyre;
use futures::StreamExt;
use regex::Regex;
use reqwest::header::{HeaderMap, HeaderName, HeaderValue};
use reqwest::Client;
use serde::{Deserialize, Serialize};
use tokio::sync::mpsc::Sender;
Expand Down Expand Up @@ -55,7 +58,7 @@ fn prepare_context_string(
The solution draft follows the format \"Thought, Action, Action Input, Observation\", where the 'Thought' statements describe a reasoning sequence. The rest of the text is information obtained to complement the reasoning sequence, and it is 100% accurate OR you can use a single \"Final Answer\" format.
Your task is to write an answer to the question based on the solution draft, and the following guidelines:
The text should have an educative and assistant-like tone, be accurate, follow the same reasoning sequence than the solution draft and explain how any conclusion is reached.
Question: {}\n\nSolution draft: {}\n\nAnswer:", summarizer_input.retrieved_result, summarizer_input.query),
Question: {}\n\nSolution draft: {}\n\nAnswer:", summarizer_input.query, summarizer_input.retrieved_result),
parameters: SummarizerParams {
model: Some(settings.model.clone()),
max_new_tokens: Some(settings.max_new_tokens.clone()),
Expand Down Expand Up @@ -114,3 +117,111 @@ pub async fn generate_text_stream(

Ok(())
}

/// Builds the JSON request body for OpenAI's chat-completions endpoint.
///
/// The summarizer instructions go in the `system` message; the question and
/// the retrieved solution draft go in the `user` message. `"stream": true`
/// asks OpenAI for a server-sent-event token stream, which
/// `generate_text_with_openai` consumes incrementally.
fn prepare_openai_input(
    settings: &OpenAISettings,
    summarizer_input: SummarizerInput,
) -> serde_json::Value {
    // NOTE: the system prompt must end after the guidelines. The previous
    // version ended with unfilled `format!` placeholders
    // ("Question: {}\n\nSolution draft: {}\n\nAnswer: ") copy-pasted from the
    // non-OpenAI prompt; they were sent to the model verbatim while the real
    // question/draft already arrive via the user message below.
    let system_role = "You are a summarizer AI. In this exercise you will assume the role of a scientific medical assistant. Your task is to answer the provided question as best as you can, based on the provided solution draft.
The solution draft follows the format \"Thought, Action, Action Input, Observation\", where the 'Thought' statements describe a reasoning sequence. The rest of the text is information obtained to complement the reasoning sequence, and it is 100% accurate OR you can use a single \"Final Answer\" format.
Your task is to write an answer to the question based on the solution draft, and the following guidelines:
The text should have an educative and assistant-like tone, be accurate, follow the same reasoning sequence than the solution draft and explain how any conclusion is reached.";

    let user_input = format!(
        "Question: {}\n\nSolution draft: {}",
        summarizer_input.query, summarizer_input.retrieved_result
    );

    serde_json::json!({
        "model": settings.model,
        "stream": true,
        "messages": [
            {
                "role": "system",
                "content": system_role
            },
            {
                "role": "user",
                "content": user_input
            }
        ]
    })
}

/// Streams a summary from OpenAI's chat-completions API.
///
/// Sends the prepared request with `stream: true`, then incrementally parses
/// the SSE byte stream with `stream_regex` (which captures each delta's
/// `"content"` field), pushes every parsed fragment through
/// `update_processor`, and forwards the updated search to `tx`.
///
/// # Errors
/// Returns an error if the auth header cannot be built, the HTTP request
/// fails or returns a non-success status, a stream chunk cannot be read, the
/// result update fails, or the channel send fails.
#[tracing::instrument(level = "debug", ret, err)]
pub async fn generate_text_with_openai(
    settings: OpenAISettings,
    summarizer_input: SummarizerInput,
    update_processor: api_models::UpdateResultProcessor,
    stream_regex: Regex,
    tx: Sender<api_models::SearchByIdResponse>,
) -> crate::Result<()> {
    let mut headers = HeaderMap::new();
    headers.insert(
        HeaderName::from_bytes(b"Authorization")
            .map_err(|e| eyre!("Failed to create header: {e}"))?,
        // OpenAI requires the `Bearer` auth scheme; sending the raw key
        // alone is rejected with 401 Unauthorized.
        HeaderValue::from_str(&format!("Bearer {}", settings.api_key.expose()))
            .map_err(|e| eyre!("Failed to create header: {e}"))?,
    );

    let client = Client::new();
    let summarizer_input = prepare_openai_input(&settings, summarizer_input);

    let response = client
        .post(settings.api_url.as_str())
        .json(&summarizer_input)
        .headers(headers)
        .send()
        .await
        .map_err(|e| eyre!("Request to summarizer failed: {e}"))?;

    if !response.status().is_success() {
        return Err(eyre!("Request failed with status: {:?}", response.status()).into());
    }

    // Accumulate raw bytes across chunk boundaries: an SSE event may be split
    // between two chunks, so unmatched tail data is carried over.
    let mut stream = response.bytes_stream();
    let mut stream_data = String::new();

    while let Some(chunk) = stream.next().await {
        let chunk = chunk.map_err(|e| eyre!("Failed to read chunk: {e}"))?;
        stream_data.push_str(&String::from_utf8_lossy(&chunk));

        // NOTE(review): regex-based extraction assumes `"content"` values
        // never contain escaped quotes or `}` — confirm against the stream
        // format, or switch to parsing each `data:` line as JSON.
        let parsed_chunk = stream_regex
            .captures_iter(&stream_data)
            .map(|c| c[1].to_string())
            .collect::<Vec<String>>()
            .join("");

        // End of the last full match; everything before it has been consumed.
        let last_index = stream_regex
            .captures_iter(&stream_data)
            .last()
            .and_then(|c| c.get(0).map(|m| m.end()));

        if let Some(last_index) = last_index {
            // Keep only the unmatched tail for the next iteration.
            stream_data = stream_data.split_off(last_index);
        }

        // A chunk may contain no complete match yet; skip the useless
        // processor round-trip and empty send in that case.
        if parsed_chunk.is_empty() {
            continue;
        }

        let mut search = update_processor
            .process(parsed_chunk.clone())
            .await
            .map_err(|e| eyre!("Failed to update result: {e}"))?;
        search.result = parsed_chunk;

        tx.send(api_models::SearchByIdResponse {
            search,
            sources: vec![],
        })
        .await
        .map_err(|e| eyre!("Failed to send response: {e}"))?;
    }

    Ok(())
}
11 changes: 7 additions & 4 deletions server/src/rag/post_process.rs
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
use crate::llms::summarizer;
use crate::llms::SummarizerSettings;
use crate::proto::Embeddings;
use crate::rag::utils;
use crate::search::api_models;
use crate::settings::Settings;
use regex::Regex;
use std::cmp::Ordering;
use tokio::sync::mpsc::Sender;

Expand Down Expand Up @@ -36,19 +37,21 @@ pub async fn rerank_search_results(

#[tracing::instrument(level = "debug", ret, err)]
pub async fn summarize_search_results(
settings: SummarizerSettings,
settings: Settings,
search_query_request: api_models::SearchQueryRequest,
search_response: String,
update_processor: api_models::UpdateResultProcessor,
stream_regex: Regex,
tx: Sender<api_models::SearchByIdResponse>,
) -> crate::Result<()> {
summarizer::generate_text_stream(
settings,
summarizer::generate_text_with_openai(
settings.openai,
summarizer::SummarizerInput {
query: search_query_request.query,
retrieved_result: search_response,
},
update_processor,
stream_regex,
tx,
)
.await?;
Expand Down
5 changes: 4 additions & 1 deletion server/src/search/routes.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ use axum::{Json, Router};
use color_eyre::eyre::eyre;
use futures::stream::StreamExt;
use futures::Stream;
use regex::Regex;
use sqlx::PgPool;
use std::convert::Infallible;
use std::sync::Arc;
Expand All @@ -28,6 +29,7 @@ async fn get_search_query_handler(
State(pool): State<PgPool>,
State(cache): State<CachePool>,
State(mut agency_service): State<AgencyServiceClient<Channel>>,
State(openai_stream_regex): State<Regex>,
user: User,
Query(search_query_request): Query<api_models::SearchQueryRequest>,
) -> crate::Result<Sse<impl Stream<Item = Result<Event, Infallible>>>> {
Expand Down Expand Up @@ -82,10 +84,11 @@ async fn get_search_query_handler(
}));

tokio::spawn(post_process::summarize_search_results(
settings.summarizer.clone(),
settings.clone(),
search_query_request,
search_response.result,
update_processor,
openai_stream_regex,
tx,
));

Expand Down
1 change: 1 addition & 0 deletions server/src/settings.rs
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,7 @@ pub struct Settings {
pub summarizer: llms::SummarizerSettings,
pub search: rag::SearchSettings,
pub query_rephraser: llms::QueryRephraserSettings,
pub openai: llms::OpenAISettings,
}

impl Settings {
Expand Down
6 changes: 6 additions & 0 deletions server/src/startup.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ use crate::Result;
use axum::{extract::FromRef, routing::IntoMakeService, serve::Serve, Router};
use color_eyre::eyre::eyre;
use log::info;
use regex::Regex;
use sentry::{self, ClientInitGuard, ClientOptions};
use sqlx::postgres::PgPoolOptions;
use sqlx::PgPool;
Expand Down Expand Up @@ -59,6 +60,7 @@ pub struct AppState {
pub oauth2_clients: Vec<OAuth2Client>,
pub settings: Settings,
pub brave_config: brave_search::BraveAPIConfig,
pub openai_stream_regex: regex::Regex,
}

impl AppState {
Expand All @@ -69,6 +71,7 @@ impl AppState {
oauth2_clients: Vec<OAuth2Client>,
settings: Settings,
brave_config: brave_search::BraveAPIConfig,
openai_stream_regex: regex::Regex,
) -> Result<Self> {
Ok(Self {
db,
Expand All @@ -77,6 +80,7 @@ impl AppState {
oauth2_clients,
settings,
brave_config,
openai_stream_regex,
})
}

Expand All @@ -88,6 +92,8 @@ impl AppState {
oauth2_clients: settings.oauth2_clients.clone(),
brave_config: brave_search::prepare_brave_api_config(&settings.brave),
settings,
openai_stream_regex: Regex::new(r#"\"content\":\"(.*?)\"}"#)
.map_err(|e| eyre!("Failed to compile OpenAI stream regex: {}", e))?,
})
}
}
Expand Down

0 comments on commit b4a4189

Please sign in to comment.