Skip to content

Commit

Permalink
Various improvements
Browse files Browse the repository at this point in the history
  • Loading branch information
ivarflakstad committed Jun 25, 2024
1 parent a54e16c commit 3dc0fd6
Show file tree
Hide file tree
Showing 13 changed files with 120 additions and 105 deletions.
10 changes: 5 additions & 5 deletions server/src/auth/models.rs
Original file line number Diff line number Diff line change
Expand Up @@ -256,6 +256,11 @@ pub struct AccessTokenCredentials {
// Note that we've supplied our concrete backend here.
pub type AuthSession = axum_login::AuthSession<PostgresBackend>;

pub struct WhitelistedEmail {
pub email: String,
pub approved: bool,
}

#[cfg(test)]
mod tests {
use crate::auth::utils::dummy_verify_password;
Expand All @@ -266,8 +271,3 @@ mod tests {
assert!(dummy_verify_password(Secret::new("password")).is_ok());
}
}

pub struct WhitelistedEmail {
pub email: String,
pub approved: bool,
}
2 changes: 1 addition & 1 deletion server/src/llms/query_rephraser.rs
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ pub async fn rephrase_query(
headers.insert(
HeaderName::from_bytes(b"Authorization")
.map_err(|e| eyre!("Failed to create header: {e}"))?,
HeaderValue::from_str(&settings.api_key.expose())
HeaderValue::from_str(settings.api_key.expose())
.map_err(|e| eyre!("Failed to create header: {e}"))?,
);

Expand Down
8 changes: 4 additions & 4 deletions server/src/llms/summarizer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ fn prepare_llm_context_string(
Question: {}\n\nSolution draft: {}\n\nAnswer:", summarizer_input.query, summarizer_input.retrieved_result),
parameters: SummarizerParams {
model: Some(settings.model.clone()),
max_new_tokens: Some(settings.max_new_tokens.clone()),
max_new_tokens: Some(settings.max_new_tokens),
temperature: Some(settings.temperature),
top_p: Some(settings.top_p),
},
Expand Down Expand Up @@ -97,7 +97,7 @@ pub async fn generate_text_with_llm(
let chunk = chunk.map_err(|e| eyre!("Failed to read chunk: {e}"))?;
let chunk = &chunk[5..chunk.len() - 2];

let summarizer_api_response = serde_json::from_slice::<SummarizerStreamOutput>(&chunk)
let summarizer_api_response = serde_json::from_slice::<SummarizerStreamOutput>(chunk)
.map_err(|e| eyre!("Failed to parse summarizer response: {e}"))?;

if !summarizer_api_response.token.special {
Expand All @@ -115,7 +115,7 @@ pub async fn generate_text_with_llm(
})
.await;

if let Ok(_) = tx_response {
if tx_response.is_ok() {
buffer.clear();
}
}
Expand Down Expand Up @@ -230,7 +230,7 @@ pub async fn generate_text_with_openai(
})
.await;

if let Ok(_) = tx_response {
if tx_response.is_ok() {
buffer.clear();
}
}
Expand Down
4 changes: 2 additions & 2 deletions server/src/llms/toxicity.rs
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ pub async fn predict_toxicity(
headers.insert(
HeaderName::from_bytes(b"Authorization")
.map_err(|e| eyre!("Failed to create header: {e}"))?,
HeaderValue::from_str(&llm_settings.toxicity_auth_token.expose())
HeaderValue::from_str(llm_settings.toxicity_auth_token.expose())
.map_err(|e| eyre!("Failed to create header: {e}"))?,
);
let client = Client::new();
Expand All @@ -47,7 +47,7 @@ pub async fn predict_toxicity(

let toxicity_score = toxicity_api_response
.into_iter()
.find(|x| x.label == String::from("toxic"))
.find(|x| x.label == "toxic")
.unwrap_or(ToxicityScore {
score: 0.0,
label: String::from(""),
Expand Down
114 changes: 55 additions & 59 deletions server/src/rag/brave_search.rs
Original file line number Diff line number Diff line change
Expand Up @@ -46,50 +46,52 @@ struct BraveAPIResponse {
pub web: BraveWebAPIResponse,
}

pub fn prepare_brave_api_config(brave_settings: &BraveSettings) -> BraveAPIConfig {
let queries = vec![
(String::from("count"), brave_settings.count.to_string()),
(
String::from("goggles_id"),
brave_settings.goggles_id.clone(),
),
(
String::from("result_filter"),
brave_settings.result_filter.clone(),
),
(
String::from("search_lang"),
brave_settings.search_lang.clone(),
),
(
String::from("extra_snippets"),
brave_settings.extra_snippets.to_string(),
),
(
String::from("safesearch"),
brave_settings.safesearch.clone(),
),
];

let headers = HeaderMap::from_iter(
vec![
("Accept", "application/json"),
("Accept-Encoding", "gzip"),
impl From<BraveSettings> for BraveAPIConfig {
fn from(brave_settings: BraveSettings) -> Self {
let queries = vec![
(String::from("count"), brave_settings.count.to_string()),
(
"X-Subscription-Token",
brave_settings.subscription_key.expose(),
String::from("goggles_id"),
brave_settings.goggles_id.clone(),
),
]
.into_iter()
.map(|(k, v)| {
(
HeaderName::from_bytes(k.as_bytes()).unwrap(),
HeaderValue::from_str(v).unwrap(),
)
}),
);

BraveAPIConfig { queries, headers }
String::from("result_filter"),
brave_settings.result_filter.clone(),
),
(
String::from("search_lang"),
brave_settings.search_lang.clone(),
),
(
String::from("extra_snippets"),
brave_settings.extra_snippets.to_string(),
),
(
String::from("safesearch"),
brave_settings.safesearch.clone(),
),
];

let headers = HeaderMap::from_iter(
vec![
("Accept", "application/json"),
("Accept-Encoding", "gzip"),
(
"X-Subscription-Token",
brave_settings.subscription_key.expose(),
),
]
.into_iter()
.map(|(k, v)| {
(
HeaderName::from_bytes(k.as_bytes()).unwrap(),
HeaderValue::from_str(v).unwrap(),
)
}),
);

BraveAPIConfig { queries, headers }
}
}

#[tracing::instrument(level = "debug", ret, err)]
Expand Down Expand Up @@ -133,10 +135,7 @@ pub async fn web_search(
}

fn convert_to_retrieved_result(result: BraveWebSearchResult) -> RetrievedResult {
let extra_snippets = match result.extra_snippets {
Some(snippets) => snippets,
None => vec![],
};
let extra_snippets = result.extra_snippets.unwrap_or_default();

RetrievedResult {
text: result.description.clone() + "\n\n" + extra_snippets.join("\n\n").as_str(),
Expand All @@ -145,20 +144,17 @@ fn convert_to_retrieved_result(result: BraveWebSearchResult) -> RetrievedResult
url: result.url,
description: result.description,
source_type: SourceType::Url,
metadata: HashMap::from_iter(
vec![
(
"page_age".to_string(),
result.page_age.unwrap_or("".to_string()),
),
("age".to_string(), result.age.unwrap_or("".to_string())),
(
"language".to_string(),
result.language.unwrap_or("".to_string()),
),
]
.into_iter(),
),
metadata: HashMap::from_iter(vec![
(
"page_age".to_string(),
result.page_age.unwrap_or("".to_string()),
),
("age".to_string(), result.age.unwrap_or("".to_string())),
(
"language".to_string(),
result.language.unwrap_or("".to_string()),
),
]),
},
}
}
2 changes: 1 addition & 1 deletion server/src/rag/pre_process.rs
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ pub async fn rephrase_query(
let last_n_searches = match search_query_request.thread_id {
Some(thread_id) => {
search_services::get_last_n_searches(
&pool,
pool,
settings.search.max_search_context,
&thread_id,
)
Expand Down
10 changes: 5 additions & 5 deletions server/src/rag/search.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,15 +14,15 @@ pub async fn search(
brave_api_config: &brave_search::BraveAPIConfig,
cache: &CachePool,
agency_service: &mut AgencyServiceClient<Channel>,
search_query: &String,
search_query: &str,
) -> crate::Result<rag::SearchResponse> {
if let Some(response) = cache.get(&search_query).await {
if let Some(response) = cache.get(search_query).await {
return Ok(response);
}

let (agency_results, fallback_results) = tokio::join!(
retrieve_result_from_agency(settings, agency_service, search_query),
brave_search::web_search(&settings.brave, brave_api_config, &search_query),
brave_search::web_search(&settings.brave, brave_api_config, search_query),
);

let mut retrieved_results = Vec::new();
Expand All @@ -47,7 +47,7 @@ pub async fn search(
let compressed_results = prompt_compression::compress(
&settings.llm,
prompt_compression::PromptCompressionInput {
query: search_query.clone(),
query: search_query.to_string(),
target_token: 300,
context_texts_list: retrieved_results.iter().map(|r| r.text.clone()).collect(),
},
Expand All @@ -67,7 +67,7 @@ pub async fn search(
async fn retrieve_result_from_agency(
settings: &Settings,
agency_service: &mut AgencyServiceClient<Channel>,
search_query: &String,
search_query: &str,
) -> crate::Result<Vec<rag::RetrievedResult>> {
let agency_service = Arc::new(agency_service.clone());
let query_embeddings =
Expand Down
4 changes: 2 additions & 2 deletions server/src/rag/utils.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
pub fn cosine_similarity(v1: &Vec<f64>, v2: &Vec<f64>) -> f64 {
pub fn cosine_similarity(v1: &[f64], v2: &[f64]) -> f64 {
if v1.len() != v2.len() {
return 0.0;
}
Expand All @@ -11,5 +11,5 @@ pub fn cosine_similarity(v1: &Vec<f64>, v2: &Vec<f64>) -> f64 {
return 0.0;
}

return dot_product / magnitude_product;
dot_product / magnitude_product
}
8 changes: 4 additions & 4 deletions server/src/search/services.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ pub async fn insert_new_search(
pool: &PgPool,
user_id: &Uuid,
search_query_request: &api_models::SearchQueryRequest,
rephrased_query: &String,
rephrased_query: &str,
) -> crate::Result<data_models::Search> {
let thread = match search_query_request.thread_id {
Some(thread_id) => {
Expand Down Expand Up @@ -41,7 +41,7 @@ pub async fn insert_new_search(
&thread.thread_id,
search_query_request.query,
rephrased_query,
&String::from(""),
"",
)
.fetch_one(pool)
.await?;
Expand Down Expand Up @@ -74,14 +74,14 @@ pub async fn add_search_sources(
search: &data_models::Search,
sources: &Vec<Source>,
) -> crate::Result<Vec<data_models::Source>> {
if sources.len() == 0 {
if sources.is_empty() {
return Err(eyre!("No sources to add").into());
}

// remove duplicates with same url
let mut hash_set: HashSet<&String> = sources.iter().map(|s| &s.url).collect();
let sources = sources
.into_iter()
.iter()
.filter(|s| match hash_set.contains(&s.url) {
true => {
hash_set.remove(&s.url);
Expand Down
6 changes: 3 additions & 3 deletions server/src/startup.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ use color_eyre::eyre::eyre;
use log::info;
use oauth2::reqwest::async_http_client;
use openidconnect::core::{CoreClient, CoreProviderMetadata};
use regex::Regex;
use regex;
use sentry::{self, ClientInitGuard, ClientOptions};
use sqlx::postgres::PgPoolOptions;
use sqlx::PgPool;
Expand Down Expand Up @@ -93,9 +93,9 @@ impl AppState {
cache: CachePool::new(&settings.cache).await?,
agency_service: agency_service_connect(settings.agency_api.expose()).await?,
oidc_clients: initialize_oidc_clients(settings.oidc.clone()).await?,
brave_config: brave_search::prepare_brave_api_config(&settings.brave),
brave_config: settings.brave.clone().into(),
settings,
openai_stream_regex: Regex::new(r#"\"content\":\"(.*?)\"}"#)
openai_stream_regex: regex::Regex::new(r#""content":"(.*?)"}"#)
.map_err(|e| eyre!("Failed to compile OpenAI stream regex: {}", e))?,
})
}
Expand Down
15 changes: 12 additions & 3 deletions server/tests/health_check.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,18 @@ async fn health_check_works(pool: PgPool) {
let agency_service = agency_service_connect(settings.agency_api.expose())
.await
.unwrap();
let state = AppState::new(pool, cache, agency_service, vec![], settings)
.await
.unwrap();
let brave_api_config = settings.brave.clone().into();
let state = AppState::new(
pool,
cache,
agency_service,
vec![],
settings,
brave_api_config,
regex::Regex::new("").unwrap(),
)
.await
.unwrap();
let router = router(state).unwrap();
let request = Request::builder()
.uri("/health")
Expand Down
Loading

0 comments on commit 3dc0fd6

Please sign in to comment.