//! Tool that executes code snippets within a sandboxed environment.
use std::sync::Arc;
|
|
use std::time::Instant;
|
|
|
|
use crate::Result;
|
|
use anyhow::{Context, anyhow};
|
|
use async_trait::async_trait;
|
|
use serde_json::{Value, json};
|
|
|
|
use super::{Tool, ToolResult};
|
|
use crate::sandbox::{SandboxConfig, SandboxedProcess};
|
|
|
|
/// Tool that executes short code snippets inside a sandboxed process.
///
/// Requests are only honoured for languages present in `allowed_languages`.
pub struct CodeExecTool {
    /// Allow-list of language names; compared exactly against the `language` argument.
    allowed_languages: Arc<Vec<String>>,
}

impl CodeExecTool {
    /// Creates a tool restricted to `allowed_languages`.
    ///
    /// The vector is moved into an `Arc`, so the strings are stored without being copied.
    pub fn new(allowed_languages: Vec<String>) -> Self {
        let shared = Arc::new(allowed_languages);
        CodeExecTool {
            allowed_languages: shared,
        }
    }
}
|
|
|
|
#[async_trait]
|
|
impl Tool for CodeExecTool {
|
|
fn name(&self) -> &'static str {
|
|
"code_exec"
|
|
}
|
|
|
|
fn description(&self) -> &'static str {
|
|
"Execute code snippets within a sandboxed environment"
|
|
}
|
|
|
|
fn schema(&self) -> Value {
|
|
json!({
|
|
"type": "object",
|
|
"properties": {
|
|
"language": {
|
|
"type": "string",
|
|
"enum": self.allowed_languages.as_slice(),
|
|
"description": "Language of the code block"
|
|
},
|
|
"code": {
|
|
"type": "string",
|
|
"minLength": 1,
|
|
"maxLength": 10000,
|
|
"description": "Code to execute"
|
|
},
|
|
"timeout": {
|
|
"type": "integer",
|
|
"minimum": 1,
|
|
"maximum": 300,
|
|
"default": 30,
|
|
"description": "Execution timeout in seconds"
|
|
}
|
|
},
|
|
"required": ["language", "code"],
|
|
"additionalProperties": false
|
|
})
|
|
}
|
|
|
|
async fn execute(&self, args: Value) -> Result<ToolResult> {
|
|
let start = Instant::now();
|
|
|
|
let language = args
|
|
.get("language")
|
|
.and_then(Value::as_str)
|
|
.context("Missing language parameter")?;
|
|
let code = args
|
|
.get("code")
|
|
.and_then(Value::as_str)
|
|
.context("Missing code parameter")?;
|
|
let timeout = args.get("timeout").and_then(Value::as_u64).unwrap_or(30);
|
|
|
|
if !self.allowed_languages.iter().any(|lang| lang == language) {
|
|
return Err(anyhow!("Language '{}' not permitted", language).into());
|
|
}
|
|
|
|
let (command, command_args) = match language {
|
|
"python" => (
|
|
"python3".to_string(),
|
|
vec!["-c".to_string(), code.to_string()],
|
|
),
|
|
"javascript" => ("node".to_string(), vec!["-e".to_string(), code.to_string()]),
|
|
"bash" => ("bash".to_string(), vec!["-c".to_string(), code.to_string()]),
|
|
"rust" => {
|
|
let mut result =
|
|
ToolResult::error("Rust execution is not yet supported in the sandbox");
|
|
result.duration = start.elapsed();
|
|
return Ok(result);
|
|
}
|
|
other => return Err(anyhow!("Unsupported language: {}", other).into()),
|
|
};
|
|
|
|
let sandbox_config = SandboxConfig {
|
|
allow_network: false,
|
|
timeout_seconds: timeout,
|
|
..Default::default()
|
|
};
|
|
|
|
let sandbox_result = tokio::task::spawn_blocking(move || {
|
|
let sandbox = SandboxedProcess::new(sandbox_config)?;
|
|
let arg_refs: Vec<&str> = command_args.iter().map(|s| s.as_str()).collect();
|
|
sandbox.execute(&command, &arg_refs)
|
|
})
|
|
.await
|
|
.context("Sandbox execution task failed")??;
|
|
|
|
let mut result = if sandbox_result.exit_code == 0 {
|
|
ToolResult::success(json!({
|
|
"stdout": sandbox_result.stdout,
|
|
"stderr": sandbox_result.stderr,
|
|
"exit_code": sandbox_result.exit_code,
|
|
"timed_out": sandbox_result.was_timeout,
|
|
}))
|
|
} else {
|
|
let error_msg = if sandbox_result.was_timeout {
|
|
format!(
|
|
"Execution timed out after {} seconds (exit code {}): {}",
|
|
timeout, sandbox_result.exit_code, sandbox_result.stderr
|
|
)
|
|
} else {
|
|
format!(
|
|
"Execution failed with status {}: {}",
|
|
sandbox_result.exit_code, sandbox_result.stderr
|
|
)
|
|
};
|
|
let mut err_result = ToolResult::error(&error_msg);
|
|
err_result.output = json!({
|
|
"stdout": sandbox_result.stdout,
|
|
"stderr": sandbox_result.stderr,
|
|
"exit_code": sandbox_result.exit_code,
|
|
"timed_out": sandbox_result.was_timeout,
|
|
});
|
|
err_result
|
|
};
|
|
|
|
result.duration = start.elapsed();
|
|
result
|
|
.metadata
|
|
.insert("language".to_string(), language.to_string());
|
|
result
|
|
.metadata
|
|
.insert("timeout_seconds".to_string(), timeout.to_string());
|
|
|
|
Ok(result)
|
|
}
|
|
}
|