diff --git a/Cargo.toml b/Cargo.toml index 290f709..7ee36e9 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,6 +1,7 @@ [workspace] members = [ "crates/app/cli", + "crates/core/agent", "crates/llm/ollama", "crates/platform/config", "crates/platform/hooks", diff --git a/crates/app/cli/Cargo.toml b/crates/app/cli/Cargo.toml index a698672..5c08fee 100644 --- a/crates/app/cli/Cargo.toml +++ b/crates/app/cli/Cargo.toml @@ -11,6 +11,7 @@ tokio = { version = "1.39", features = ["macros", "rt-multi-thread"] } serde = { version = "1", features = ["derive"] } serde_json = "1" color-eyre = "0.6" +agent-core = { path = "../../core/agent" } llm-ollama = { path = "../../llm/ollama" } tools-fs = { path = "../../tools/fs" } tools-bash = { path = "../../tools/bash" } diff --git a/crates/app/cli/src/main.rs b/crates/app/cli/src/main.rs index d4ff1af..3136457 100644 --- a/crates/app/cli/src/main.rs +++ b/crates/app/cli/src/main.rs @@ -461,50 +461,20 @@ async fn main() -> Result<()> { stream: true, }; - let msgs = vec![ChatMessage { - role: "user".into(), - content: prompt.clone(), - }]; - let start_time = SystemTime::now(); // Handle different output formats match output_format { OutputFormat::Text => { - // Text format: stream to stdout as before - let mut stream = client.chat_stream(&msgs, &opts).await?; - while let Some(chunk) = stream.try_next().await? { - if let Some(m) = chunk.message { - if let Some(c) = m.content { - print!("{c}"); - io::stdout().flush()?; - } - } - if matches!(chunk.done, Some(true)) { - break; - } - } - println!(); // Newline after response + // Text format: Use agent orchestrator with tool calling + let response = agent_core::run_agent_loop(&client, &prompt, &opts, &perms).await?; + println!("{}", response); } OutputFormat::Json => { - // JSON format: collect all chunks, then output final JSON - let mut stream = client.chat_stream(&msgs, &opts).await?; - let mut response = String::new(); - - while let Some(chunk) = stream.try_next().await? { - if let Some(m) = chunk.message { - if let Some(c) = m.content { - response.push_str(&c); - } - } - if matches!(chunk.done, Some(true)) { - break; - } - } + // JSON format: Use agent loop and output as JSON + let response = agent_core::run_agent_loop(&client, &prompt, &opts, &perms).await?; let duration_ms = start_time.elapsed().unwrap().as_millis() as u64; - - // Rough token estimate (tokens ~= chars / 4) let estimated_tokens = ((prompt.len() + response.len()) / 4) as u64; let output = SessionOutput { @@ -526,7 +496,7 @@ async fn main() -> Result<()> { println!("{}", serde_json::to_string(&output)?); } OutputFormat::StreamJson => { - // Stream-JSON format: emit session_start, chunks, and session_end + // Stream-JSON format: emit session_start, response, and session_end let session_start = StreamEvent { event_type: "session_start".to_string(), session_id: Some(session_id.clone()), @@ -535,30 +505,17 @@ async fn main() -> Result<()> { }; println!("{}", serde_json::to_string(&session_start)?); - let mut stream = client.chat_stream(&msgs, &opts).await?; - let mut response = String::new(); + let response = agent_core::run_agent_loop(&client, &prompt, &opts, &perms).await?; - while let Some(chunk) = stream.try_next().await? { - if let Some(m) = chunk.message { - if let Some(c) = m.content { - response.push_str(&c); - let chunk_event = StreamEvent { - event_type: "chunk".to_string(), - session_id: None, - content: Some(c), - stats: None, - }; - println!("{}", serde_json::to_string(&chunk_event)?); - } - } - if matches!(chunk.done, Some(true)) { - break; - } - } + let chunk_event = StreamEvent { + event_type: "chunk".to_string(), + session_id: None, + content: Some(response.clone()), + stats: None, + }; + println!("{}", serde_json::to_string(&chunk_event)?); let duration_ms = start_time.elapsed().unwrap().as_millis() as u64; - - // Rough token estimate let estimated_tokens = ((prompt.len() + response.len()) / 4) as u64; let session_end = StreamEvent { diff --git a/crates/core/agent/Cargo.toml b/crates/core/agent/Cargo.toml new file mode 100644 index 0000000..0e13766 --- /dev/null +++ b/crates/core/agent/Cargo.toml @@ -0,0 +1,21 @@ +[package] +name = "agent-core" +version = "0.1.0" +edition.workspace = true +license.workspace = true +rust-version.workspace = true + +[dependencies] +serde = { version = "1", features = ["derive"] } +serde_json = "1" +color-eyre = "0.6" +tokio = { version = "1", features = ["full"] } +futures-util = "0.3" + +# Internal dependencies +llm-ollama = { path = "../../llm/ollama" } +permissions = { path = "../../platform/permissions" } +tools-fs = { path = "../../tools/fs" } +tools-bash = { path = "../../tools/bash" } + +[dev-dependencies] diff --git a/crates/core/agent/src/lib.rs b/crates/core/agent/src/lib.rs new file mode 100644 index 0000000..43ebb64 --- /dev/null +++ b/crates/core/agent/src/lib.rs @@ -0,0 +1,372 @@ +use color_eyre::eyre::{Result, eyre}; +use futures_util::TryStreamExt; +use llm_ollama::{ChatMessage, OllamaClient, OllamaOptions, Tool, ToolFunction, ToolParameters}; +use permissions::{PermissionDecision, PermissionManager, Tool as PermTool}; +use serde_json::{json, Value}; + +/// Define all available tools for the LLM +pub fn get_tool_definitions() -> Vec { + vec![ + Tool { + tool_type: "function".to_string(), + function: ToolFunction { + name: "read".to_string(), + description: "Read the contents of a file".to_string(), + parameters: ToolParameters { + param_type: "object".to_string(), + properties: json!({ + "path": { + "type": "string", + "description": "The path to the file to read" + } + }), + required: vec!["path".to_string()], + }, + }, + }, + Tool { + tool_type: "function".to_string(), + function: ToolFunction { + name: "glob".to_string(), + description: "Find files matching a glob pattern (e.g., '**/*.rs' for all Rust files)".to_string(), + parameters: ToolParameters { + param_type: "object".to_string(), + properties: json!({ + "pattern": { + "type": "string", + "description": "Glob pattern to match files (e.g., '**/*.toml', '*.md')" + } + }), + required: vec!["pattern".to_string()], + }, + }, + }, + Tool { + tool_type: "function".to_string(), + function: ToolFunction { + name: "grep".to_string(), + description: "Search for a pattern in files within a directory".to_string(), + parameters: ToolParameters { + param_type: "object".to_string(), + properties: json!({ + "root": { + "type": "string", + "description": "Root directory to search in" + }, + "pattern": { + "type": "string", + "description": "Pattern to search for" + } + }), + required: vec!["root".to_string(), "pattern".to_string()], + }, + }, + }, + Tool { + tool_type: "function".to_string(), + function: ToolFunction { + name: "write".to_string(), + description: "Write content to a file".to_string(), + parameters: ToolParameters { + param_type: "object".to_string(), + properties: json!({ + "path": { + "type": "string", + "description": "Path where the file should be written" + }, + "content": { + "type": "string", + "description": "Content to write to the file" + } + }), + required: vec!["path".to_string(), "content".to_string()], + }, + }, + }, + Tool { + tool_type: "function".to_string(), + function: ToolFunction { + name: "edit".to_string(), + description: "Edit a file by replacing old text with new text".to_string(), + parameters: ToolParameters { + param_type: "object".to_string(), + properties: json!({ + "path": { + "type": "string", + "description": "Path to the file to edit" + }, + "old_string": { + "type": "string", + "description": "Text to find and replace" + }, + "new_string": { + "type": "string", + "description": "Text to replace with" + } + }), + required: vec!["path".to_string(), "old_string".to_string(), "new_string".to_string()], + }, + }, + }, + Tool { + tool_type: "function".to_string(), + function: ToolFunction { + name: "bash".to_string(), + description: "Execute a bash command. Use carefully and only when necessary.".to_string(), + parameters: ToolParameters { + param_type: "object".to_string(), + properties: json!({ + "command": { + "type": "string", + "description": "The bash command to execute" + } + }), + required: vec!["command".to_string()], + }, + }, + }, + ] +} + +/// Execute a tool call and return the result +pub async fn execute_tool( + tool_name: &str, + arguments: &Value, + perms: &PermissionManager, +) -> Result { + match tool_name { + "read" => { + let path = arguments["path"] + .as_str() + .ok_or_else(|| eyre!("Missing 'path' argument"))?; + + // Check permission + match perms.check(PermTool::Read, Some(path)) { + PermissionDecision::Allow => { + let content = tools_fs::read_file(path)?; + Ok(content) + } + PermissionDecision::Ask => { + Err(eyre!("Permission required: Read operation needs approval")) + } + PermissionDecision::Deny => { + Err(eyre!("Permission denied: Read operation is blocked")) + } + } + } + "glob" => { + let pattern = arguments["pattern"] + .as_str() + .ok_or_else(|| eyre!("Missing 'pattern' argument"))?; + + // Check permission + match perms.check(PermTool::Glob, None) { + PermissionDecision::Allow => { + let files = tools_fs::glob_list(pattern)?; + Ok(files.join("\n")) + } + PermissionDecision::Ask => { + Err(eyre!("Permission required: Glob operation needs approval")) + } + PermissionDecision::Deny => { + Err(eyre!("Permission denied: Glob operation is blocked")) + } + } + } + "grep" => { + let root = arguments["root"] + .as_str() + .ok_or_else(|| eyre!("Missing 'root' argument"))?; + let pattern = arguments["pattern"] + .as_str() + .ok_or_else(|| eyre!("Missing 'pattern' argument"))?; + + // Check permission + match perms.check(PermTool::Grep, None) { + PermissionDecision::Allow => { + let results = tools_fs::grep(root, pattern)?; + let lines: Vec = results + .into_iter() + .map(|(path, line_num, text)| format!("{}:{}:{}", path, line_num, text)) + .collect(); + Ok(lines.join("\n")) + } + PermissionDecision::Ask => { + Err(eyre!("Permission required: Grep operation needs approval")) + } + PermissionDecision::Deny => { + Err(eyre!("Permission denied: Grep operation is blocked")) + } + } + } + "write" => { + let path = arguments["path"] + .as_str() + .ok_or_else(|| eyre!("Missing 'path' argument"))?; + let content = arguments["content"] + .as_str() + .ok_or_else(|| eyre!("Missing 'content' argument"))?; + + // Check permission + match perms.check(PermTool::Write, Some(path)) { + PermissionDecision::Allow => { + tools_fs::write_file(path, content)?; + Ok(format!("File written successfully: {}", path)) + } + PermissionDecision::Ask => { + Err(eyre!("Permission required: Write operation needs approval")) + } + PermissionDecision::Deny => { + Err(eyre!("Permission denied: Write operation is blocked")) + } + } + } + "edit" => { + let path = arguments["path"] + .as_str() + .ok_or_else(|| eyre!("Missing 'path' argument"))?; + let old_string = arguments["old_string"] + .as_str() + .ok_or_else(|| eyre!("Missing 'old_string' argument"))?; + let new_string = arguments["new_string"] + .as_str() + .ok_or_else(|| eyre!("Missing 'new_string' argument"))?; + + // Check permission + match perms.check(PermTool::Edit, Some(path)) { + PermissionDecision::Allow => { + tools_fs::edit_file(path, old_string, new_string)?; + Ok(format!("File edited successfully: {}", path)) + } + PermissionDecision::Ask => { + Err(eyre!("Permission required: Edit operation needs approval")) + } + PermissionDecision::Deny => { + Err(eyre!("Permission denied: Edit operation is blocked")) + } + } + } + "bash" => { + let command = arguments["command"] + .as_str() + .ok_or_else(|| eyre!("Missing 'command' argument"))?; + + // Check permission + match perms.check(PermTool::Bash, Some(command)) { + PermissionDecision::Allow => { + let mut session = tools_bash::BashSession::new().await?; + let output = session.execute(command, None).await?; + let result = if !output.stdout.is_empty() { + output.stdout + } else if !output.stderr.is_empty() { + format!("stderr: {}", output.stderr) + } else { + "Command executed successfully with no output".to_string() + }; + Ok(result) + } + PermissionDecision::Ask => { + Err(eyre!("Permission required: Bash operation needs approval")) + } + PermissionDecision::Deny => { + Err(eyre!("Permission denied: Bash operation is blocked")) + } + } + } + _ => Err(eyre!("Unknown tool: {}", tool_name)), + } +} + +/// Run the agent loop with tool calling +pub async fn run_agent_loop( + client: &OllamaClient, + user_prompt: &str, + opts: &OllamaOptions, + perms: &PermissionManager, +) -> Result { + let tools = get_tool_definitions(); + let mut messages = vec![ChatMessage { + role: "user".to_string(), + content: Some(user_prompt.to_string()), + tool_calls: None, + }]; + + let max_iterations = 10; // Prevent infinite loops + let mut iteration = 0; + + loop { + iteration += 1; + if iteration > max_iterations { + return Err(eyre!("Max iterations reached")); + } + + // Call LLM with messages and tools + let mut stream = client.chat_stream(&messages, opts, Some(&tools)).await?; + let mut response_content = String::new(); + let mut tool_calls = None; + + // Collect the streamed response + while let Some(chunk) = stream.try_next().await? { + if let Some(msg) = chunk.message { + if let Some(content) = msg.content { + response_content.push_str(&content); + } + if let Some(calls) = msg.tool_calls { + tool_calls = Some(calls); + } + } + } + + // Drop the stream to release the borrow on messages + drop(stream); + + // Check if LLM wants to call tools + if let Some(calls) = tool_calls { + // Add assistant message with tool calls + messages.push(ChatMessage { + role: "assistant".to_string(), + content: if response_content.is_empty() { + None + } else { + Some(response_content.clone()) + }, + tool_calls: Some(calls.clone()), + }); + + // Execute each tool call + for call in calls { + let tool_name = &call.function.name; + let arguments = &call.function.arguments; + + println!("\nšŸ”§ Tool call: {} with args: {}", tool_name, arguments); + + match execute_tool(tool_name, arguments, perms).await { + Ok(result) => { + println!("āœ… Tool result: {}", result); + // Add tool result message + messages.push(ChatMessage { + role: "tool".to_string(), + content: Some(result), + tool_calls: None, + }); + } + Err(e) => { + println!("āŒ Tool error: {}", e); + // Add error message as tool result + messages.push(ChatMessage { + role: "tool".to_string(), + content: Some(format!("Error: {}", e)), + tool_calls: None, + }); + } + } + } + + // Continue loop to get next response + continue; + } + + // No tool calls, we're done + return Ok(response_content); + } +} diff --git a/crates/llm/ollama/src/client.rs b/crates/llm/ollama/src/client.rs index d009ece..8f87d98 100644 --- a/crates/llm/ollama/src/client.rs +++ b/crates/llm/ollama/src/client.rs @@ -1,4 +1,4 @@ -use crate::types::{ChatMessage, ChatResponseChunk}; +use crate::types::{ChatMessage, ChatResponseChunk, Tool}; use futures::{Stream, TryStreamExt}; use reqwest::Client; use serde::Serialize; @@ -50,15 +50,18 @@ impl OllamaClient { &self, messages: &[ChatMessage], opts: &OllamaOptions, + tools: Option<&[Tool]>, ) -> Result>, OllamaError> { #[derive(Serialize)] struct Body<'a> { model: &'a str, messages: &'a [ChatMessage], stream: bool, + #[serde(skip_serializing_if = "Option::is_none")] + tools: Option<&'a [Tool]>, } let url = format!("{}/api/chat", self.base_url); - let body = Body {model: &opts.model, messages, stream: true}; + let body = Body {model: &opts.model, messages, stream: true, tools}; let mut req = self.http.post(url).json(&body); // Add Authorization header if API key is present diff --git a/crates/llm/ollama/src/lib.rs b/crates/llm/ollama/src/lib.rs index 1b4af68..0c18e30 100644 --- a/crates/llm/ollama/src/lib.rs +++ b/crates/llm/ollama/src/lib.rs @@ -2,4 +2,4 @@ pub mod client; pub mod types; pub use client::{OllamaClient, OllamaOptions}; -pub use types::{ChatMessage, ChatResponseChunk}; +pub use types::{ChatMessage, ChatResponseChunk, Tool, ToolCall, ToolFunction, ToolParameters, FunctionCall}; diff --git a/crates/llm/ollama/src/types.rs b/crates/llm/ollama/src/types.rs index 10ff880..0fb222c 100644 --- a/crates/llm/ollama/src/types.rs +++ b/crates/llm/ollama/src/types.rs @@ -1,9 +1,50 @@ use serde::{Deserialize, Serialize}; +use serde_json::Value; #[derive(Debug, Clone, Serialize, Deserialize)] pub struct ChatMessage { - pub role: String, // "user", | "assistant" | "system" - pub content: String, + pub role: String, // "user" | "assistant" | "system" | "tool" + #[serde(skip_serializing_if = "Option::is_none")] + pub content: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub tool_calls: Option>, +} + +#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)] +pub struct ToolCall { + #[serde(skip_serializing_if = "Option::is_none")] + pub id: Option, + #[serde(rename = "type", skip_serializing_if = "Option::is_none")] + pub call_type: Option, // "function" + pub function: FunctionCall, +} + +#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)] +pub struct FunctionCall { + pub name: String, + pub arguments: Value, // JSON object with arguments +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct Tool { + #[serde(rename = "type")] + pub tool_type: String, // "function" + pub function: ToolFunction, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct ToolFunction { + pub name: String, + pub description: String, + pub parameters: ToolParameters, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct ToolParameters { + #[serde(rename = "type")] + pub param_type: String, // "object" + pub properties: Value, + pub required: Vec, } #[derive(Debug, Clone, Serialize, Deserialize, PartialEq)] @@ -19,4 +60,6 @@ pub struct ChatResponseChunk { pub struct ChunkMessage { pub role: Option, pub content: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub tool_calls: Option>, }