diff --git a/crates/goose-server/src/routes/reply.rs b/crates/goose-server/src/routes/reply.rs index c6bd4896..7c6a7f2f 100644 --- a/crates/goose-server/src/routes/reply.rs +++ b/crates/goose-server/src/routes/reply.rs @@ -9,6 +9,7 @@ use axum::{ use bytes::Bytes; use futures::{stream::StreamExt, Stream}; use goose::message::{Message, MessageContent}; +use goose::providers::base::{Moderation, ModerationResult}; use mcp_core::{content::Content, role::Role}; use serde::Deserialize; use serde_json::{json, Value}; @@ -406,7 +407,7 @@ mod tests { #[async_trait::async_trait] impl Provider for MockProvider { - async fn complete( + async fn complete_internal( &self, _system_prompt: &str, _messages: &[Message], @@ -427,6 +428,16 @@ mod tests { } } + #[async_trait::async_trait] + impl Moderation for MockProvider { + async fn moderate_content( + &self, + _content: &str, + ) -> Result { + Ok(ModerationResult::new(false, None, None)) + } + } + #[test] fn test_convert_messages_user_only() { let incoming = vec![IncomingMessage { diff --git a/crates/goose/src/providers/anthropic.rs b/crates/goose/src/providers/anthropic.rs index 61f1094b..496d0f95 100644 --- a/crates/goose/src/providers/anthropic.rs +++ b/crates/goose/src/providers/anthropic.rs @@ -6,8 +6,7 @@ use serde_json::{json, Value}; use std::collections::HashSet; use std::time::Duration; -use super::base::ProviderUsage; -use super::base::{Provider, Usage}; +use super::base::{Moderation, ModerationResult, Provider, ProviderUsage, Usage}; use super::configs::{AnthropicProviderConfig, ModelConfig, ProviderModelConfig}; use super::model_pricing::cost; use super::model_pricing::model_pricing_for; @@ -193,7 +192,7 @@ impl Provider for AnthropicProvider { self.config.model_config() } - async fn complete( + async fn complete_internal( &self, system: &str, messages: &[Message], @@ -273,6 +272,13 @@ impl Provider for AnthropicProvider { } } +#[async_trait] +impl Moderation for AnthropicProvider { + async fn moderate_content(&self, content: &str) -> Result { + Ok(ModerationResult::new(false, None, None)) + } +} + #[cfg(test)] mod tests { use crate::providers::configs::ModelConfig; diff --git a/crates/goose/src/providers/base.rs b/crates/goose/src/providers/base.rs index fa52442c..a70ef970 100644 --- a/crates/goose/src/providers/base.rs +++ b/crates/goose/src/providers/base.rs @@ -1,9 +1,12 @@ use anyhow::Result; use rust_decimal::Decimal; use serde::{Deserialize, Serialize}; +use tokio::select; use super::configs::ModelConfig; -use crate::message::Message; +use crate::message::{Message, MessageContent}; +use mcp_core::content::TextContent; +use mcp_core::role::Role; use mcp_core::tool::Tool; #[derive(Debug, Clone, Serialize, Deserialize)] @@ -47,12 +50,51 @@ impl Usage { } } +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct ModerationResult { + /// Whether the content was flagged as inappropriate + pub flagged: bool, + /// Optional categories that were flagged (provider specific) + pub categories: Option>, + /// Optional scores for each category (provider specific) + pub category_scores: Option, +} + +impl ModerationResult { + pub fn new( + flagged: bool, + categories: Option>, + category_scores: Option, + ) -> Self { + Self { + flagged, + categories, + category_scores, + } + } +} + use async_trait::async_trait; use serde_json::Value; +/// Trait for handling content moderation +#[async_trait] +pub trait Moderation: Send + Sync { + /// Moderate the given content + /// + /// # Arguments + /// * `content` - The text content to moderate + /// + /// # Returns + /// A ModerationResult containing the moderation decision and details + async fn moderate_content(&self, content: &str) -> Result { + Ok(ModerationResult::new(false, None, None)) + } +} + /// Base trait for AI providers (OpenAI, Anthropic, etc) #[async_trait] -pub trait Provider: Send + Sync { +pub trait Provider: Send + Sync + Moderation { /// Get the model configuration fn get_model_config(&self) -> &ModelConfig; @@ -70,6 +112,74 @@ pub trait Provider: Send + Sync { system: &str, messages: &[Message], tools: &[Tool], + ) -> Result<(Message, ProviderUsage)> { + // Get the latest user message + let latest_user_msg = messages + .iter() + .rev() + .find(|msg| { + msg.role == Role::User + && msg + .content + .iter() + .any(|content| matches!(content, MessageContent::Text(_))) + }) + .ok_or_else(|| anyhow::anyhow!("No user message with text content found in history"))?; + + // Get the content to moderate + let content = latest_user_msg.content.first().unwrap().as_text().unwrap(); + + // Create futures for both operations + let moderation_fut = self.moderate_content(content); + let completion_fut = self.complete_internal(system, messages, tools); + + // Pin the futures + tokio::pin!(moderation_fut); + tokio::pin!(completion_fut); + + // Use select! to run both concurrently + let result = select! { + moderation = &mut moderation_fut => { + // If moderation completes first, check the result + let moderation_result = moderation?; + if moderation_result.flagged { + let categories = moderation_result.categories + .unwrap_or_else(|| vec!["unknown".to_string()]) + .join(", "); + return Err(anyhow::anyhow!( + "Content was flagged for moderation in categories: {}", + categories + )); + } + // If moderation passes, wait for completion + Ok(completion_fut.await?) + } + completion = &mut completion_fut => { + // If completion finishes first, still check moderation + let completion_result = completion?; + let moderation_result = moderation_fut.await?; + if moderation_result.flagged { + let categories = moderation_result.categories + .unwrap_or_else(|| vec!["unknown".to_string()]) + .join(", "); + return Err(anyhow::anyhow!( + "Content was flagged for moderation in categories: {}", + categories + )); + } + Ok(completion_result) + } + }; + + result + } + + /// Internal completion method to be implemented by providers + async fn complete_internal( + &self, + system: &str, + messages: &[Message], + tools: &[Tool], ) -> Result<(Message, ProviderUsage)>; fn get_usage(&self, data: &Value) -> Result; @@ -79,6 +189,8 @@ pub trait Provider: Send + Sync { mod tests { use super::*; use serde_json::json; + use std::time::Duration; + use tokio::time::sleep; #[test] fn test_usage_creation() { @@ -106,4 +218,270 @@ mod tests { Ok(()) } + + #[test] + fn test_moderation_result_creation() { + let categories = vec!["hate".to_string(), "violence".to_string()]; + let scores = json!({ + "hate": 0.9, + "violence": 0.8 + }); + let result = ModerationResult::new(true, Some(categories.clone()), Some(scores.clone())); + + assert!(result.flagged); + assert_eq!(result.categories.unwrap(), categories); + assert_eq!(result.category_scores.unwrap(), scores); + } + + #[tokio::test] + async fn test_moderation_blocks_completion() { + #[derive(Clone)] + struct TestProvider; + + #[async_trait] + impl Moderation for TestProvider { + async fn moderate_content(&self, _content: &str) -> Result { + // Return quickly with flagged content + Ok(ModerationResult::new( + true, + Some(vec!["test".to_string()]), + None, + )) + } + } + + #[async_trait] + impl Provider for TestProvider { + fn get_usage(&self, _data: &Value) -> Result { + Ok(Usage::new(Some(1), Some(1), Some(2))) + } + + fn get_model_config(&self) -> &ModelConfig { + panic!("Should not be called"); + } + + async fn complete_internal( + &self, + _system: &str, + _messages: &[Message], + _tools: &[Tool], + ) -> Result<(Message, ProviderUsage)> { + // Simulate a slow completion + sleep(Duration::from_secs(1)).await; + panic!("complete_internal should not finish when moderation fails"); + } + } + + let provider = TestProvider; + let test_message = Message { + role: Role::User, + created: chrono::Utc::now().timestamp(), + content: vec![MessageContent::Text(TextContent { + text: "test".to_string(), + annotations: None, + })], + }; + + let result = provider.complete("system", &[test_message], &[]).await; + + assert!(result.is_err()); + assert!(result + .unwrap_err() + .to_string() + .contains("Content was flagged")); + } + + #[tokio::test] + async fn test_moderation_blocks_completion_delayed() { + #[derive(Clone)] + struct TestProvider; + + #[async_trait] + impl Moderation for TestProvider { + async fn moderate_content(&self, _content: &str) -> Result { + sleep(Duration::from_secs(1)).await; + // Return quickly with flagged content + Ok(ModerationResult::new( + true, + Some(vec!["test".to_string()]), + None, + )) + } + } + + #[async_trait] + impl Provider for TestProvider { + fn get_usage(&self, _data: &Value) -> Result { + Ok(Usage::new(Some(1), Some(1), Some(2))) + } + + fn get_model_config(&self) -> &ModelConfig { + panic!("Should not be called"); + } + + async fn complete_internal( + &self, + _system: &str, + _messages: &[Message], + _tools: &[Tool], + ) -> Result<(Message, ProviderUsage)> { + // Simulate a fast completion= + Ok(( + Message { + role: Role::Assistant, + created: chrono::Utc::now().timestamp(), + content: vec![MessageContent::text("test response")], + }, + ProviderUsage::new( + "test-model".to_string(), + Usage::new(Some(1), Some(1), Some(2)), + None, + ), + )) + } + } + + let provider = TestProvider; + let test_message = Message { + role: Role::User, + created: chrono::Utc::now().timestamp(), + content: vec![MessageContent::Text(TextContent { + text: "test".to_string(), + annotations: None, + })], + }; + + let result = provider.complete("system", &[test_message], &[]).await; + + assert!(result.is_err()); + assert!(result + .unwrap_err() + .to_string() + .contains("Content was flagged")); + } + + #[tokio::test] + async fn test_moderation_pass_completion_pass() { + #[derive(Clone)] + struct TestProvider; + + #[async_trait] + impl Moderation for TestProvider { + async fn moderate_content(&self, _content: &str) -> Result { + // Return quickly with flagged content + Ok(ModerationResult::new(false, None, None)) + } + } + + #[async_trait] + impl Provider for TestProvider { + fn get_usage(&self, _data: &Value) -> Result { + Ok(Usage::new(Some(1), Some(1), Some(2))) + } + + fn get_model_config(&self) -> &ModelConfig { + panic!("Should not be called"); + } + + async fn complete_internal( + &self, + _system: &str, + _messages: &[Message], + _tools: &[Tool], + ) -> Result<(Message, ProviderUsage)> { + Ok(( + Message { + role: Role::Assistant, + created: chrono::Utc::now().timestamp(), + content: vec![MessageContent::text("test response")], + }, + ProviderUsage::new( + "test-model".to_string(), + Usage::new(Some(1), Some(1), Some(2)), + None, + ), + )) + } + } + + let provider = TestProvider; + let test_message = Message { + role: Role::User, + created: chrono::Utc::now().timestamp(), + content: vec![MessageContent::Text(TextContent { + text: "test".to_string(), + annotations: None, + })], + }; + + let result = provider.complete("system", &[test_message], &[]).await; + + assert!(result.is_ok()); + let (message, usage) = result.unwrap(); + assert_eq!(message.content[0].as_text().unwrap(), "test response"); + assert_eq!(usage.model, "test-model"); + } + + #[tokio::test] + async fn test_completion_succeeds_when_moderation_passes() { + #[derive(Clone)] + struct TestProvider; + + #[async_trait] + impl Moderation for TestProvider { + async fn moderate_content(&self, _content: &str) -> Result { + // Simulate some processing time + sleep(Duration::from_millis(100)).await; + Ok(ModerationResult::new(false, None, None)) + } + } + + #[async_trait] + impl Provider for TestProvider { + fn get_usage(&self, _data: &Value) -> Result { + Ok(Usage::new(Some(1), Some(1), Some(2))) + } + + fn get_model_config(&self) -> &ModelConfig { + panic!("Should not be called"); + } + + async fn complete_internal( + &self, + _system: &str, + _messages: &[Message], + _tools: &[Tool], + ) -> Result<(Message, ProviderUsage)> { + Ok(( + Message { + role: Role::Assistant, + created: chrono::Utc::now().timestamp(), + content: vec![MessageContent::text("test response")], + }, + ProviderUsage::new( + "test-model".to_string(), + Usage::new(Some(1), Some(1), Some(2)), + None, + ), + )) + } + } + + let provider = TestProvider; + let test_message = Message { + role: Role::User, + created: chrono::Utc::now().timestamp(), + content: vec![MessageContent::Text(TextContent { + text: "test".to_string(), + annotations: None, + })], + }; + + let result = provider.complete("system", &[test_message], &[]).await; + + assert!(result.is_ok()); + let (message, usage) = result.unwrap(); + assert_eq!(message.content[0].as_text().unwrap(), "test response"); + assert_eq!(usage.model, "test-model"); + } } diff --git a/crates/goose/src/providers/databricks.rs b/crates/goose/src/providers/databricks.rs index ab7434bd..f8413960 100644 --- a/crates/goose/src/providers/databricks.rs +++ b/crates/goose/src/providers/databricks.rs @@ -4,7 +4,7 @@ use reqwest::Client; use serde_json::{json, Value}; use std::time::Duration; -use super::base::{Provider, ProviderUsage, Usage}; +use super::base::{Moderation, ModerationResult, Provider, ProviderUsage, Usage}; use super::configs::{DatabricksAuth, DatabricksProviderConfig, ModelConfig, ProviderModelConfig}; use super::model_pricing::{cost, model_pricing_for}; use super::oauth; @@ -74,7 +74,7 @@ impl Provider for DatabricksProvider { self.config.model_config() } - async fn complete( + async fn complete_internal( &self, system: &str, messages: &[Message], @@ -148,6 +148,13 @@ impl Provider for DatabricksProvider { } } +#[async_trait] +impl Moderation for DatabricksProvider { + async fn moderate_content(&self, content: &str) -> Result { + Ok(ModerationResult::new(false, None, None)) + } +} + #[cfg(test)] mod tests { use super::*; diff --git a/crates/goose/src/providers/google.rs b/crates/goose/src/providers/google.rs index 97d9a92c..1ef40dd5 100644 --- a/crates/goose/src/providers/google.rs +++ b/crates/goose/src/providers/google.rs @@ -1,10 +1,11 @@ use crate::errors::AgentError; use crate::message::{Message, MessageContent}; -use crate::providers::base::{Provider, ProviderUsage, Usage}; +use crate::providers::base::{Moderation, ModerationResult, Provider, ProviderUsage, Usage}; use crate::providers::configs::{GoogleProviderConfig, ModelConfig, ProviderModelConfig}; use crate::providers::utils::{ handle_response, is_valid_function_name, sanitize_function_name, unescape_json_values, }; +use anyhow::Result; use async_trait::async_trait; use mcp_core::{Content, Role, Tool, ToolCall}; use reqwest::Client; @@ -278,7 +279,7 @@ impl Provider for GoogleProvider { self.config.model_config() } - async fn complete( + async fn complete_internal( &self, system: &str, messages: &[Message], @@ -345,6 +346,13 @@ impl Provider for GoogleProvider { } } +#[async_trait] +impl Moderation for GoogleProvider { + async fn moderate_content(&self, content: &str) -> Result { + Ok(ModerationResult::new(false, None, None)) + } +} + #[cfg(test)] // Only compiles this module when running tests mod tests { use super::*; diff --git a/crates/goose/src/providers/groq.rs b/crates/goose/src/providers/groq.rs index 52605836..30f87bb4 100644 --- a/crates/goose/src/providers/groq.rs +++ b/crates/goose/src/providers/groq.rs @@ -1,11 +1,12 @@ use crate::message::Message; -use crate::providers::base::{Provider, ProviderUsage, Usage}; +use crate::providers::base::{Moderation, ModerationResult, Provider, ProviderUsage, Usage}; use crate::providers::configs::{GroqProviderConfig, ModelConfig, ProviderModelConfig}; use crate::providers::openai_utils::{ create_openai_request_payload_with_concat_response_content, get_openai_usage, openai_response_to_message, }; use crate::providers::utils::{get_model, handle_response}; +use anyhow::Result; use async_trait::async_trait; use mcp_core::Tool; use reqwest::Client; @@ -52,7 +53,7 @@ impl Provider for GroqProvider { self.config.model_config() } - async fn complete( + async fn complete_internal( &self, system: &str, messages: &[Message], @@ -79,6 +80,13 @@ impl Provider for GroqProvider { } } +#[async_trait] +impl Moderation for GroqProvider { + async fn moderate_content(&self, content: &str) -> Result { + Ok(ModerationResult::new(false, None, None)) + } +} + #[cfg(test)] mod tests { use super::*; diff --git a/crates/goose/src/providers/mock.rs b/crates/goose/src/providers/mock.rs index fa84a63a..7d42c099 100644 --- a/crates/goose/src/providers/mock.rs +++ b/crates/goose/src/providers/mock.rs @@ -1,4 +1,4 @@ -use super::base::ProviderUsage; +use super::base::{Moderation, ModerationResult, ProviderUsage}; use crate::message::Message; use crate::providers::base::{Provider, Usage}; use crate::providers::configs::ModelConfig; @@ -40,7 +40,7 @@ impl Provider for MockProvider { &self.model_config } - async fn complete( + async fn complete_internal( &self, _system_prompt: &str, _messages: &[Message], @@ -66,3 +66,10 @@ impl Provider for MockProvider { Ok(Usage::new(None, None, None)) } } + +#[async_trait] +impl Moderation for MockProvider { + async fn moderate_content(&self, content: &str) -> Result { + Ok(ModerationResult::new(false, None, None)) + } +} diff --git a/crates/goose/src/providers/ollama.rs b/crates/goose/src/providers/ollama.rs index 8b07333a..1e1266aa 100644 --- a/crates/goose/src/providers/ollama.rs +++ b/crates/goose/src/providers/ollama.rs @@ -1,4 +1,4 @@ -use super::base::{Provider, ProviderUsage, Usage}; +use super::base::{Moderation, ModerationResult, Provider, ProviderUsage, Usage}; use super::configs::{ModelConfig, OllamaProviderConfig, ProviderModelConfig}; use super::utils::{get_model, handle_response}; use crate::message::Message; @@ -47,7 +47,7 @@ impl Provider for OllamaProvider { self.config.model_config() } - async fn complete( + async fn complete_internal( &self, system: &str, messages: &[Message], @@ -71,6 +71,13 @@ impl Provider for OllamaProvider { } } +#[async_trait] +impl Moderation for OllamaProvider { + async fn moderate_content(&self, content: &str) -> Result { + Ok(ModerationResult::new(false, None, None)) + } +} + #[cfg(test)] mod tests { use super::*; diff --git a/crates/goose/src/providers/openai.rs b/crates/goose/src/providers/openai.rs index 6f310958..678e999d 100644 --- a/crates/goose/src/providers/openai.rs +++ b/crates/goose/src/providers/openai.rs @@ -4,7 +4,7 @@ use reqwest::Client; use serde_json::Value; use std::time::Duration; -use super::base::ProviderUsage; +use super::base::{Moderation, ModerationResult, ProviderUsage}; use super::base::{Provider, Usage}; use super::configs::OpenAiProviderConfig; use super::configs::{ModelConfig, ProviderModelConfig}; @@ -17,14 +17,28 @@ use crate::providers::openai_utils::{ openai_response_to_message, }; use mcp_core::tool::Tool; +use serde::Serialize; pub const OPEN_AI_DEFAULT_MODEL: &str = "gpt-4o"; +pub const OPEN_AI_MODERATION_MODEL: &str = "omni-moderation-latest"; pub struct OpenAiProvider { client: Client, config: OpenAiProviderConfig, } +#[derive(Serialize)] +struct OpenAiModerationRequest { + input: String, + model: String, +} + +impl OpenAiModerationRequest { + pub fn new(input: String, model: String) -> Self { + Self { input, model } + } +} + impl OpenAiProvider { pub fn new(config: OpenAiProviderConfig) -> Result { let client = Client::builder() @@ -58,7 +72,7 @@ impl Provider for OpenAiProvider { self.config.model_config() } - async fn complete( + async fn complete_internal( &self, system: &str, messages: &[Message], @@ -92,6 +106,49 @@ impl Provider for OpenAiProvider { } } +#[async_trait] +impl Moderation for OpenAiProvider { + async fn moderate_content(&self, content: &str) -> Result { + let url = format!("{}/v1/moderations", self.config.host.trim_end_matches('/')); + + let request = + OpenAiModerationRequest::new(content.to_string(), OPEN_AI_MODERATION_MODEL.to_string()); + + let response = self + .client + .post(&url) + .header("Authorization", format!("Bearer {}", self.config.api_key)) + .json(&request) + .send() + .await?; + + let response_json: serde_json::Value = response.json().await?; + + let flagged = response_json["results"][0]["flagged"] + .as_bool() + .unwrap_or(false); + if flagged { + let categories = response_json["results"][0]["categories"] + .as_object() + .unwrap(); + let category_scores = response_json["results"][0]["category_scores"].clone(); + return Ok(ModerationResult::new( + flagged, + Some( + categories + .iter() + .filter(|(_, value)| value.as_bool().unwrap_or(false)) + .map(|(key, _)| key.to_string()) + .collect(), + ), + Some(category_scores), + )); + } else { + return Ok(ModerationResult::new(flagged, None, None)); + } + } +} + #[cfg(test)] mod tests { use super::*; @@ -133,7 +190,7 @@ mod tests { // Call the complete method let (message, usage) = provider - .complete("You are a helpful assistant.", &messages, &[]) + .complete_internal("You are a helpful assistant.", &messages, &[]) .await?; // Assert the response @@ -165,7 +222,7 @@ mod tests { // Call the complete method let (message, usage) = provider - .complete( + .complete_internal( "You are a helpful assistant.", &messages, &[create_test_tool()],