Skip to content

Commit

Permalink
Interrupt running shell tool commands
Browse files Browse the repository at this point in the history
  • Loading branch information
jsibbison-square committed Nov 28, 2024
1 parent a726cb3 commit a840942
Show file tree
Hide file tree
Showing 2 changed files with 41 additions and 12 deletions.
49 changes: 38 additions & 11 deletions crates/goose-server/src/routes/reply.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,10 @@ use std::{
convert::Infallible,
pin::Pin,
task::{Context, Poll},
time::Duration,
};
use tokio::sync::mpsc;
use tokio::time::timeout;
use tokio_stream::wrappers::ReceiverStream;

// Types matching the incoming JSON structure
Expand Down Expand Up @@ -139,7 +141,6 @@ fn convert_messages(incoming: Vec<IncomingMessage>) -> Vec<Message> {
struct ProtocolFormatter;

impl ProtocolFormatter {

fn format_text(text: &str) -> String {
let encoded_text = serde_json::to_string(text).unwrap_or_else(|_| String::new());
format!("0:{}\n", encoded_text)
Expand Down Expand Up @@ -175,6 +176,10 @@ impl ProtocolFormatter {
});
format!("d:{}\n", finish)
}

fn heartbeat() -> String {
"2:[]\n".to_string()
}
}

async fn stream_message(
Expand Down Expand Up @@ -294,18 +299,40 @@ async fn handler(
}
};

while let Some(response) = stream.next().await {
match response {
Ok(message) => {
if let Err(e) = stream_message(message, &tx).await {
tracing::error!("Error sending message through channel: {}", e);
break;
loop {
tokio::select! {
response = timeout(Duration::from_millis(500), stream.next()) => {
match response {
Ok(Some(Ok(message))) => {
if let Err(e) = stream_message(message, &tx).await {
tracing::error!("Error sending message through channel: {}", e);
break;
}
}
Ok(Some(Err(e))) => {
tracing::error!("Error processing message: {}", e);
break;
}
Ok(None) => {
break;
}
Err(_) => { // Heartbeat, used to detect disconnected clients and then end running tools.
if let Err(e) = tx.try_send(ProtocolFormatter::heartbeat()) {
match e {
mpsc::error::TrySendError::Closed(_) => {
// Client has disconnected, end the stream and close running tools (works by ending this process).
break;
}
mpsc::error::TrySendError::Full(_) => {
tracing::warn!("Error sending heartbeat message through channel: {}", e);
continue;
}
}
}
continue;
}
}
}
Err(e) => {
tracing::error!("Error processing message: {}", e);
break;
}
}
}

Expand Down
4 changes: 3 additions & 1 deletion crates/goose/src/developer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,8 @@ use serde_json::{json, Value};
use std::collections::{HashMap, HashSet};
use std::io::Cursor;
use std::path::{Path, PathBuf};
use std::process::Command;
use std::sync::Mutex;
use tokio::process::Command;
use xcap::Monitor;

use crate::errors::{AgentError, AgentResult};
Expand Down Expand Up @@ -192,9 +192,11 @@ impl DeveloperSystem {

// Execute the command
let output = Command::new("bash")
.kill_on_drop(true) // Critical so that the command is killed when the agent.reply stream is interrupted.
.arg("-c")
.arg(cmd_with_redirect)
.output()
.await
.map_err(|e| AgentError::ExecutionError(e.to_string()))?;

let output_str = format!(
Expand Down

0 comments on commit a840942

Please sign in to comment.