feat(v2): complete multi-LLM providers, TUI redesign, and advanced agent features
Multi-LLM Provider Support: - Add llm-core crate with LlmProvider trait abstraction - Implement Anthropic Claude API client with streaming - Implement OpenAI API client with streaming - Add token counting with SimpleTokenCounter and ClaudeTokenCounter - Add retry logic with exponential backoff and jitter Borderless TUI Redesign: - Rewrite theme system with terminal capability detection (Full/Unicode256/Basic) - Add provider tabs component with keybind switching [1]/[2]/[3] - Implement vim-modal input (Normal/Insert/Visual/Command modes) - Redesign chat panel with timestamps and streaming indicators - Add multi-provider status bar with cost tracking - Add Nerd Font icons with graceful ASCII fallbacks - Add syntax highlighting (syntect) and markdown rendering (pulldown-cmark) Advanced Agent Features: - Add system prompt builder with configurable components - Enhance subagent orchestration with parallel execution - Add git integration module for safe command detection - Add streaming tool results via channels - Expand tool set: AskUserQuestion, TodoWrite, LS, MultiEdit, BashOutput, KillShell - Add WebSearch with provider abstraction Plugin System Enhancement: - Add full agent definition parsing from YAML frontmatter - Add skill system with progressive disclosure - Wire plugin hooks into HookManager 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude <noreply@anthropic.com>
This commit is contained in:
@@ -11,12 +11,17 @@ serde_json = "1"
|
||||
color-eyre = "0.6"
|
||||
tokio = { version = "1", features = ["full"] }
|
||||
futures-util = "0.3"
|
||||
tracing = "0.1"
|
||||
async-trait = "0.1"
|
||||
|
||||
# Internal dependencies
|
||||
llm-ollama = { path = "../../llm/ollama" }
|
||||
llm-core = { path = "../../llm/core" }
|
||||
permissions = { path = "../../platform/permissions" }
|
||||
tools-fs = { path = "../../tools/fs" }
|
||||
tools-bash = { path = "../../tools/bash" }
|
||||
tools-ask = { path = "../../tools/ask" }
|
||||
tools-todo = { path = "../../tools/todo" }
|
||||
tools-web = { path = "../../tools/web" }
|
||||
|
||||
[dev-dependencies]
|
||||
tempfile = "3.13"
|
||||
|
||||
74
crates/core/agent/examples/git_demo.rs
Normal file
74
crates/core/agent/examples/git_demo.rs
Normal file
@@ -0,0 +1,74 @@
|
||||
//! Example demonstrating the git integration module
|
||||
//!
|
||||
//! Run with: cargo run -p agent-core --example git_demo
|
||||
|
||||
use agent_core::{detect_git_state, format_git_status, is_safe_git_command, is_destructive_git_command};
|
||||
use std::env;
|
||||
|
||||
fn main() -> color_eyre::Result<()> {
|
||||
color_eyre::install()?;
|
||||
|
||||
// Get current working directory
|
||||
let cwd = env::current_dir()?;
|
||||
println!("Detecting git state in: {}\n", cwd.display());
|
||||
|
||||
// Detect git state
|
||||
let state = detect_git_state(&cwd)?;
|
||||
|
||||
// Display formatted status
|
||||
println!("{}\n", format_git_status(&state));
|
||||
|
||||
// Show detailed file status if there are changes
|
||||
if !state.status.is_empty() {
|
||||
println!("Detailed file status:");
|
||||
for status in &state.status {
|
||||
match status {
|
||||
agent_core::GitFileStatus::Modified { path } => {
|
||||
println!(" M {}", path);
|
||||
}
|
||||
agent_core::GitFileStatus::Added { path } => {
|
||||
println!(" A {}", path);
|
||||
}
|
||||
agent_core::GitFileStatus::Deleted { path } => {
|
||||
println!(" D {}", path);
|
||||
}
|
||||
agent_core::GitFileStatus::Renamed { from, to } => {
|
||||
println!(" R {} -> {}", from, to);
|
||||
}
|
||||
agent_core::GitFileStatus::Untracked { path } => {
|
||||
println!(" ? {}", path);
|
||||
}
|
||||
}
|
||||
}
|
||||
println!();
|
||||
}
|
||||
|
||||
// Test command safety checking
|
||||
println!("Command safety checks:");
|
||||
let test_commands = vec![
|
||||
"git status",
|
||||
"git log --oneline",
|
||||
"git diff HEAD",
|
||||
"git commit -m 'test'",
|
||||
"git push --force origin main",
|
||||
"git reset --hard HEAD~1",
|
||||
"git rebase main",
|
||||
"git branch -D feature",
|
||||
];
|
||||
|
||||
for cmd in test_commands {
|
||||
let is_safe = is_safe_git_command(cmd);
|
||||
let (is_destructive, warning) = is_destructive_git_command(cmd);
|
||||
|
||||
print!(" {} - ", cmd);
|
||||
if is_safe {
|
||||
println!("SAFE (read-only)");
|
||||
} else if is_destructive {
|
||||
println!("DESTRUCTIVE: {}", warning);
|
||||
} else {
|
||||
println!("UNSAFE (modifies state)");
|
||||
}
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
92
crates/core/agent/examples/streaming_agent.rs
Normal file
92
crates/core/agent/examples/streaming_agent.rs
Normal file
@@ -0,0 +1,92 @@
|
||||
/// Example demonstrating the streaming agent loop API
|
||||
///
|
||||
/// This example shows how to use `run_agent_loop_streaming` to receive
|
||||
/// real-time events during agent execution, including:
|
||||
/// - Text deltas as the LLM generates text
|
||||
/// - Tool execution start/end events
|
||||
/// - Tool output events
|
||||
/// - Final completion events
|
||||
///
|
||||
/// Run with: cargo run --example streaming_agent -p agent-core
|
||||
|
||||
use agent_core::{create_event_channel, run_agent_loop_streaming, AgentEvent, ToolContext};
|
||||
use llm_core::ChatOptions;
|
||||
use permissions::{Mode, PermissionManager};
|
||||
|
||||
#[tokio::main]
|
||||
async fn main() -> color_eyre::Result<()> {
|
||||
color_eyre::install()?;
|
||||
|
||||
// Note: This is a minimal example. In a real application, you would:
|
||||
// 1. Initialize a real LLM provider (e.g., OllamaClient)
|
||||
// 2. Configure the ChatOptions with your preferred model
|
||||
// 3. Set up appropriate permissions and tool context
|
||||
|
||||
println!("=== Streaming Agent Example ===\n");
|
||||
println!("This example demonstrates how to use the streaming agent loop API.");
|
||||
println!("To run with a real LLM provider, modify this example to:");
|
||||
println!(" 1. Create an LLM provider instance");
|
||||
println!(" 2. Set up permissions and tool context");
|
||||
println!(" 3. Call run_agent_loop_streaming with your prompt\n");
|
||||
|
||||
// Example code structure:
|
||||
println!("Example code:");
|
||||
println!("```rust");
|
||||
println!("// Create LLM provider");
|
||||
println!("let provider = OllamaClient::new(\"http://localhost:11434\");");
|
||||
println!();
|
||||
println!("// Set up permissions and context");
|
||||
println!("let perms = PermissionManager::new(Mode::Plan);");
|
||||
println!("let ctx = ToolContext::default();");
|
||||
println!();
|
||||
println!("// Create event channel");
|
||||
println!("let (tx, mut rx) = create_event_channel();");
|
||||
println!();
|
||||
println!("// Spawn agent loop");
|
||||
println!("let handle = tokio::spawn(async move {{");
|
||||
println!(" run_agent_loop_streaming(");
|
||||
println!(" &provider,");
|
||||
println!(" \"Your prompt here\",");
|
||||
println!(" &ChatOptions::default(),");
|
||||
println!(" &perms,");
|
||||
println!(" &ctx,");
|
||||
println!(" tx,");
|
||||
println!(" ).await");
|
||||
println!("}});");
|
||||
println!();
|
||||
println!("// Process events");
|
||||
println!("while let Some(event) = rx.recv().await {{");
|
||||
println!(" match event {{");
|
||||
println!(" AgentEvent::TextDelta(text) => {{");
|
||||
println!(" print!(\"{{text}}\");");
|
||||
println!(" }}");
|
||||
println!(" AgentEvent::ToolStart {{ tool_name, .. }} => {{");
|
||||
println!(" println!(\"\\n[Executing tool: {{tool_name}}]\");");
|
||||
println!(" }}");
|
||||
println!(" AgentEvent::ToolOutput {{ content, is_error, .. }} => {{");
|
||||
println!(" if is_error {{");
|
||||
println!(" eprintln!(\"Error: {{content}}\");");
|
||||
println!(" }} else {{");
|
||||
println!(" println!(\"Output: {{content}}\");");
|
||||
println!(" }}");
|
||||
println!(" }}");
|
||||
println!(" AgentEvent::ToolEnd {{ success, .. }} => {{");
|
||||
println!(" println!(\"[Tool finished: {{}}]\", if success {{ \"success\" }} else {{ \"failed\" }});");
|
||||
println!(" }}");
|
||||
println!(" AgentEvent::Done {{ final_response }} => {{");
|
||||
println!(" println!(\"\\n\\nFinal response: {{final_response}}\");");
|
||||
println!(" break;");
|
||||
println!(" }}");
|
||||
println!(" AgentEvent::Error(e) => {{");
|
||||
println!(" eprintln!(\"Error: {{e}}\");");
|
||||
println!(" break;");
|
||||
println!(" }}");
|
||||
println!(" }}");
|
||||
println!("}}");
|
||||
println!();
|
||||
println!("// Wait for completion");
|
||||
println!("let result = handle.await??;");
|
||||
println!("```");
|
||||
|
||||
Ok(())
|
||||
}
|
||||
557
crates/core/agent/src/git.rs
Normal file
557
crates/core/agent/src/git.rs
Normal file
@@ -0,0 +1,557 @@
|
||||
//! Git integration module for detecting repository state and validating git commands.
|
||||
//!
|
||||
//! This module provides functionality to:
|
||||
//! - Detect if the current directory is a git repository
|
||||
//! - Capture git repository state (branch, status, uncommitted changes)
|
||||
//! - Validate git commands for safety (read-only vs destructive operations)
|
||||
|
||||
use color_eyre::eyre::Result;
|
||||
use std::path::Path;
|
||||
use std::process::Command;
|
||||
|
||||
/// How a single path differs from the last commit, as reported by `git status`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum GitFileStatus {
    /// The file's content changed.
    Modified { path: String },
    /// The file is new and has been added (staged).
    Added { path: String },
    /// The file was removed.
    Deleted { path: String },
    /// The file moved from `from` to `to`.
    Renamed { from: String, to: String },
    /// The file exists on disk but is not tracked by git.
    Untracked { path: String },
}
|
||||
|
||||
impl GitFileStatus {
|
||||
/// Get the primary path associated with this status
|
||||
pub fn path(&self) -> &str {
|
||||
match self {
|
||||
Self::Modified { path } => path,
|
||||
Self::Added { path } => path,
|
||||
Self::Deleted { path } => path,
|
||||
Self::Renamed { to, .. } => to,
|
||||
Self::Untracked { path } => path,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Complete state of a git repository
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct GitState {
|
||||
/// Whether the current directory is in a git repository
|
||||
pub is_git_repo: bool,
|
||||
/// Current branch name (None if not in a repo or detached HEAD)
|
||||
pub current_branch: Option<String>,
|
||||
/// Main branch name (main/master, None if not detected)
|
||||
pub main_branch: Option<String>,
|
||||
/// Status of files in the working tree
|
||||
pub status: Vec<GitFileStatus>,
|
||||
/// Whether there are any uncommitted changes
|
||||
pub has_uncommitted_changes: bool,
|
||||
/// Remote URL for the repository (None if no remote configured)
|
||||
pub remote_url: Option<String>,
|
||||
}
|
||||
|
||||
impl GitState {
|
||||
/// Create a default GitState for non-git directories
|
||||
pub fn not_a_repo() -> Self {
|
||||
Self {
|
||||
is_git_repo: false,
|
||||
current_branch: None,
|
||||
main_branch: None,
|
||||
status: Vec::new(),
|
||||
has_uncommitted_changes: false,
|
||||
remote_url: None,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Detect the current git repository state
|
||||
///
|
||||
/// This function runs various git commands to gather information about the repository.
|
||||
/// If git is not available or the directory is not a git repo, returns a default state.
|
||||
pub fn detect_git_state(working_dir: &Path) -> Result<GitState> {
|
||||
// Check if this is a git repository
|
||||
let is_repo = Command::new("git")
|
||||
.arg("rev-parse")
|
||||
.arg("--git-dir")
|
||||
.current_dir(working_dir)
|
||||
.output()
|
||||
.map(|output| output.status.success())
|
||||
.unwrap_or(false);
|
||||
|
||||
if !is_repo {
|
||||
return Ok(GitState::not_a_repo());
|
||||
}
|
||||
|
||||
// Get current branch
|
||||
let current_branch = get_current_branch(working_dir)?;
|
||||
|
||||
// Detect main branch (try main first, then master)
|
||||
let main_branch = detect_main_branch(working_dir)?;
|
||||
|
||||
// Get file status
|
||||
let status = get_git_status(working_dir)?;
|
||||
|
||||
// Check if there are uncommitted changes
|
||||
let has_uncommitted_changes = !status.is_empty();
|
||||
|
||||
// Get remote URL
|
||||
let remote_url = get_remote_url(working_dir)?;
|
||||
|
||||
Ok(GitState {
|
||||
is_git_repo: true,
|
||||
current_branch,
|
||||
main_branch,
|
||||
status,
|
||||
has_uncommitted_changes,
|
||||
remote_url,
|
||||
})
|
||||
}
|
||||
|
||||
/// Get the current branch name
|
||||
fn get_current_branch(working_dir: &Path) -> Result<Option<String>> {
|
||||
let output = Command::new("git")
|
||||
.arg("rev-parse")
|
||||
.arg("--abbrev-ref")
|
||||
.arg("HEAD")
|
||||
.current_dir(working_dir)
|
||||
.output()?;
|
||||
|
||||
if !output.status.success() {
|
||||
return Ok(None);
|
||||
}
|
||||
|
||||
let branch = String::from_utf8_lossy(&output.stdout).trim().to_string();
|
||||
|
||||
// "HEAD" means detached HEAD state
|
||||
if branch == "HEAD" {
|
||||
Ok(None)
|
||||
} else {
|
||||
Ok(Some(branch))
|
||||
}
|
||||
}
|
||||
|
||||
/// Detect the main branch (main or master)
|
||||
fn detect_main_branch(working_dir: &Path) -> Result<Option<String>> {
|
||||
// Try to get all branches
|
||||
let output = Command::new("git")
|
||||
.arg("branch")
|
||||
.arg("-a")
|
||||
.current_dir(working_dir)
|
||||
.output()?;
|
||||
|
||||
if !output.status.success() {
|
||||
return Ok(None);
|
||||
}
|
||||
|
||||
let branches = String::from_utf8_lossy(&output.stdout);
|
||||
|
||||
// Check for main branch first (modern convention)
|
||||
if branches.lines().any(|line| {
|
||||
let trimmed = line.trim_start_matches('*').trim();
|
||||
trimmed == "main" || trimmed.ends_with("/main")
|
||||
}) {
|
||||
return Ok(Some("main".to_string()));
|
||||
}
|
||||
|
||||
// Fall back to master
|
||||
if branches.lines().any(|line| {
|
||||
let trimmed = line.trim_start_matches('*').trim();
|
||||
trimmed == "master" || trimmed.ends_with("/master")
|
||||
}) {
|
||||
return Ok(Some("master".to_string()));
|
||||
}
|
||||
|
||||
Ok(None)
|
||||
}
|
||||
|
||||
/// Get the git status for all files
|
||||
fn get_git_status(working_dir: &Path) -> Result<Vec<GitFileStatus>> {
|
||||
let output = Command::new("git")
|
||||
.arg("status")
|
||||
.arg("--porcelain")
|
||||
.arg("-z") // Null-terminated for better parsing
|
||||
.current_dir(working_dir)
|
||||
.output()?;
|
||||
|
||||
if !output.status.success() {
|
||||
return Ok(Vec::new());
|
||||
}
|
||||
|
||||
let status_text = String::from_utf8_lossy(&output.stdout);
|
||||
let mut statuses = Vec::new();
|
||||
|
||||
// Parse porcelain format with null termination
|
||||
// Format: XY filename\0 (where X is staged status, Y is unstaged status)
|
||||
for entry in status_text.split('\0').filter(|s| !s.is_empty()) {
|
||||
if entry.len() < 3 {
|
||||
continue;
|
||||
}
|
||||
|
||||
let status_code = &entry[0..2];
|
||||
let path = entry[3..].to_string();
|
||||
|
||||
// Parse status codes
|
||||
match status_code {
|
||||
"M " | " M" | "MM" => {
|
||||
statuses.push(GitFileStatus::Modified { path });
|
||||
}
|
||||
"A " | " A" | "AM" => {
|
||||
statuses.push(GitFileStatus::Added { path });
|
||||
}
|
||||
"D " | " D" | "AD" => {
|
||||
statuses.push(GitFileStatus::Deleted { path });
|
||||
}
|
||||
"??" => {
|
||||
statuses.push(GitFileStatus::Untracked { path });
|
||||
}
|
||||
s if s.starts_with('R') => {
|
||||
// Renamed files have format "R old_name -> new_name"
|
||||
if let Some((from, to)) = path.split_once(" -> ") {
|
||||
statuses.push(GitFileStatus::Renamed {
|
||||
from: from.to_string(),
|
||||
to: to.to_string(),
|
||||
});
|
||||
} else {
|
||||
// Fallback if parsing fails
|
||||
statuses.push(GitFileStatus::Modified { path });
|
||||
}
|
||||
}
|
||||
_ => {
|
||||
// Unknown status code, treat as modified
|
||||
statuses.push(GitFileStatus::Modified { path });
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Ok(statuses)
|
||||
}
|
||||
|
||||
/// Get the remote URL for the repository
|
||||
fn get_remote_url(working_dir: &Path) -> Result<Option<String>> {
|
||||
let output = Command::new("git")
|
||||
.arg("remote")
|
||||
.arg("get-url")
|
||||
.arg("origin")
|
||||
.current_dir(working_dir)
|
||||
.output()?;
|
||||
|
||||
if !output.status.success() {
|
||||
return Ok(None);
|
||||
}
|
||||
|
||||
let url = String::from_utf8_lossy(&output.stdout).trim().to_string();
|
||||
|
||||
if url.is_empty() {
|
||||
Ok(None)
|
||||
} else {
|
||||
Ok(Some(url))
|
||||
}
|
||||
}
|
||||
|
||||
/// Check if a git command is safe (read-only).
///
/// The command is split on whitespace and classified by its subcommand and
/// by its individual arguments, never by substrings of the whole command
/// line, so a URL or branch name that merely *contains* `show` or `-d`
/// cannot change the verdict. This is a conservative heuristic: when in
/// doubt the answer is `false`.
///
/// Safe commands:
/// - `status`, `log`, `show`, `diff`, `blame`, `describe`
/// - `ls-files`, `ls-tree`, `ls-remote`, `rev-parse`, `rev-list`
/// - `shortlog`, `whatchanged`, `grep` (unless it opens files in a pager)
/// - `reflog`, except its `expire` / `delete` / `drop` subcommands
/// - `branch` and `tag`, only when listing (no create, delete, move, copy,
///   force, or upstream change)
/// - `remote -v`, `remote show`, `remote get-url`
/// - `config --get*`, `config --list`
/// - `fetch` without pruning
///
/// Always unsafe:
/// - anything that does not start with `git <subcommand>`
/// - any command containing a shell control or redirection character
///   (`;`, `&`, `|`, `` ` ``, `<`, `>`, `$(`, newline), because a second
///   command could ride along if the string is later handed to a shell
/// - any command using `--output`, which writes to a file
pub fn is_safe_git_command(command: &str) -> bool {
    /// True when `args` contains one of the `long` options (bare or in
    /// `--name=value` form) or a single-dash cluster holding any of the
    /// characters in `short`.
    fn has_flag(args: &[&str], short: &str, long: &[&str]) -> bool {
        args.iter().any(|arg| {
            if let Some(name) = arg.strip_prefix("--") {
                let name = name.split('=').next().unwrap_or(name);
                long.iter().any(|candidate| *candidate == name)
            } else if let Some(cluster) = arg.strip_prefix('-') {
                cluster.chars().any(|c| short.contains(c))
            } else {
                false
            }
        })
    }

    // Whitespace splitting cannot understand quoting, so any shell operator
    // anywhere in the string disqualifies the command.
    let has_shell_operator = command.contains("$(")
        || command
            .chars()
            .any(|c| matches!(c, ';' | '&' | '|' | '`' | '<' | '>' | '\n' | '\r'));
    if has_shell_operator {
        return false;
    }

    let mut tokens = command.split_whitespace();
    if tokens.next() != Some("git") {
        return false;
    }
    let subcommand = match tokens.next() {
        Some(name) => name,
        None => return false,
    };
    let args: Vec<&str> = tokens.collect();

    // `--output=<file>` makes otherwise read-only commands write to disk.
    if has_flag(&args, "", &["output"]) {
        return false;
    }

    let first_positional = args.iter().copied().find(|arg| !arg.starts_with('-'));
    let has_positional = first_positional.is_some();

    match subcommand {
        "status" | "log" | "show" | "diff" | "blame" | "describe" => true,
        "ls-files" | "ls-tree" | "ls-remote" => true,
        "rev-parse" | "rev-list" => true,
        "shortlog" | "whatchanged" => true,
        // `-O` / `--open-files-in-pager` launches an external program.
        "grep" => !has_flag(&args, "O", &["open-files-in-pager"]),
        // Viewing the reflog is harmless; pruning it is not.
        "reflog" => !matches!(
            first_positional,
            Some("expire") | Some("delete") | Some("drop")
        ),
        "tag" => {
            let mutates = has_flag(
                &args,
                "asufdmFe",
                &[
                    "annotate",
                    "sign",
                    "no-sign",
                    "local-user",
                    "force",
                    "delete",
                    "message",
                    "file",
                    "edit",
                    "create-reflog",
                ],
            );
            let read_only_mode = has_flag(
                &args,
                "lnv",
                &[
                    "list",
                    "contains",
                    "no-contains",
                    "merged",
                    "no-merged",
                    "points-at",
                    "verify",
                ],
            );
            // A bare name with no listing option would create a tag.
            !mutates && (!has_positional || read_only_mode)
        }
        "branch" => {
            let mutates = has_flag(
                &args,
                "dDmMcCfut",
                &[
                    "delete",
                    "move",
                    "copy",
                    "force",
                    "set-upstream-to",
                    "unset-upstream",
                    "edit-description",
                    "track",
                    "no-track",
                    "create-reflog",
                ],
            );
            let list_mode = has_flag(
                &args,
                "l",
                &[
                    "list",
                    "contains",
                    "no-contains",
                    "merged",
                    "no-merged",
                    "points-at",
                ],
            );
            // A bare name with no listing option would create a branch.
            !mutates && (!has_positional || list_mode)
        }
        "remote" => match first_positional {
            Some("show") | Some("get-url") => true,
            // add / remove / rename / set-url / prune / ...
            Some(_) => false,
            None => has_flag(&args, "v", &["verbose"]),
        },
        "config" => args
            .iter()
            .any(|arg| arg.starts_with("--get") || *arg == "--list" || *arg == "-l"),
        // Pruning deletes remote-tracking refs (`-p`) or local tags (`-P`).
        "fetch" => !has_flag(&args, "pP", &["prune", "prune-tags"]),
        _ => false,
    }
}
|
||||
|
||||
/// Check if a git command is destructive.
///
/// Returns an `(is_destructive, warning_message)` tuple; the message is
/// empty when the command is not considered destructive.
///
/// The command is lower-cased, split on whitespace, and inspected token by
/// token. Keywords must appear as whole words and options as real flags
/// (a `--long` option or a character inside a `-abc` cluster), so a branch
/// called `my-feature` or the option `--follow-tags` is not mistaken for
/// `-f`. The keywords may appear anywhere in the string, which keeps
/// chained commands such as `cd repo && git push --force` detectable.
///
/// Destructive commands include:
/// - `push` with `-f` / `--force*` or a `+refspec`
/// - `reset --hard`
/// - `clean` with `-f` / `-d` (also inside clusters such as `-xdf`)
/// - `rebase` (including `pull --rebase`), `commit --amend`
/// - `filter-branch`, `filter-repo`
/// - `branch -d` / `-D` / `--delete`, `tag -d` / `--delete`
/// - `reflog expire`
/// - `gc --aggressive` / `gc --prune`
pub fn is_destructive_git_command(command: &str) -> (bool, &'static str) {
    /// True when some token is a `--long` option starting with one of
    /// `long_prefixes`, or a single-dash cluster containing any character
    /// of `short`.
    fn has_flag(tokens: &[&str], short: &str, long_prefixes: &[&str]) -> bool {
        tokens.iter().any(|token| {
            if token.starts_with("--") {
                long_prefixes.iter().any(|prefix| token.starts_with(*prefix))
            } else if let Some(cluster) = token.strip_prefix('-') {
                cluster.chars().any(|c| short.contains(c))
            } else {
                false
            }
        })
    }

    // Lower-casing keeps the check case-insensitive, so `-D` is covered by
    // looking for `d`.
    let lowered = command.to_lowercase();
    let tokens: Vec<&str> = lowered.split_whitespace().collect();

    // Whole-word keyword test; shell punctuation and quotes glued to a word
    // (`push;`, `(git`, `'rebase`) are ignored.
    let has_word = |word: &str| {
        tokens.iter().any(|token| {
            token.trim_matches(|c: char| matches!(c, ';' | '&' | '|' | '(' | ')' | '"' | '\''))
                == word
        })
    };

    // Check for force push (flag form or `+refspec` form)
    if has_word("push")
        && (has_flag(&tokens, "f", &["--force"])
            || tokens.iter().any(|t| t.len() > 1 && t.starts_with('+')))
    {
        return (true, "Force push can overwrite remote history and affect other collaborators");
    }

    // Check for hard reset
    if has_word("reset") && has_flag(&tokens, "", &["--hard"]) {
        return (true, "Hard reset will discard uncommitted changes permanently");
    }

    // Check for git clean
    if has_word("clean") && has_flag(&tokens, "fd", &["--force"]) {
        return (true, "Git clean will permanently delete untracked files");
    }

    // Check for rebase (also `git pull --rebase`)
    if has_word("rebase") || has_flag(&tokens, "", &["--rebase"]) {
        return (true, "Rebase rewrites commit history and can cause conflicts");
    }

    // Check for amend
    if has_word("commit") && has_flag(&tokens, "", &["--amend"]) {
        return (true, "Amending rewrites the last commit and changes its hash");
    }

    // Check for filter-branch or filter-repo
    if has_word("filter-branch") || has_word("filter-repo") {
        return (true, "Filter operations rewrite repository history");
    }

    // Check for branch/tag deletion
    if (has_word("branch") || has_word("tag")) && has_flag(&tokens, "d", &["--delete"]) {
        return (true, "This will delete a branch or tag");
    }

    // Check for reflog expire
    if has_word("reflog") && has_word("expire") {
        return (true, "Expiring reflog removes recovery points for lost commits");
    }

    // Check for gc with aggressive or prune
    if has_word("gc") && has_flag(&tokens, "", &["--aggressive", "--prune"]) {
        return (true, "Aggressive garbage collection can make recovery difficult");
    }

    (false, "")
}
|
||||
|
||||
/// Format git state for human-readable display
|
||||
///
|
||||
/// Example output:
|
||||
/// ```text
|
||||
/// Git Repository: yes
|
||||
/// Current branch: feature-branch
|
||||
/// Main branch: main
|
||||
/// Status: 3 modified, 1 untracked
|
||||
/// Remote: https://github.com/user/repo.git
|
||||
/// ```
|
||||
pub fn format_git_status(state: &GitState) -> String {
|
||||
if !state.is_git_repo {
|
||||
return "Not a git repository".to_string();
|
||||
}
|
||||
|
||||
let mut lines = Vec::new();
|
||||
|
||||
lines.push("Git Repository: yes".to_string());
|
||||
|
||||
if let Some(branch) = &state.current_branch {
|
||||
lines.push(format!("Current branch: {}", branch));
|
||||
} else {
|
||||
lines.push("Current branch: (detached HEAD)".to_string());
|
||||
}
|
||||
|
||||
if let Some(main) = &state.main_branch {
|
||||
lines.push(format!("Main branch: {}", main));
|
||||
}
|
||||
|
||||
// Summarize status
|
||||
if state.status.is_empty() {
|
||||
lines.push("Status: clean working tree".to_string());
|
||||
} else {
|
||||
let mut modified = 0;
|
||||
let mut added = 0;
|
||||
let mut deleted = 0;
|
||||
let mut renamed = 0;
|
||||
let mut untracked = 0;
|
||||
|
||||
for status in &state.status {
|
||||
match status {
|
||||
GitFileStatus::Modified { .. } => modified += 1,
|
||||
GitFileStatus::Added { .. } => added += 1,
|
||||
GitFileStatus::Deleted { .. } => deleted += 1,
|
||||
GitFileStatus::Renamed { .. } => renamed += 1,
|
||||
GitFileStatus::Untracked { .. } => untracked += 1,
|
||||
}
|
||||
}
|
||||
|
||||
let mut status_parts = Vec::new();
|
||||
if modified > 0 {
|
||||
status_parts.push(format!("{} modified", modified));
|
||||
}
|
||||
if added > 0 {
|
||||
status_parts.push(format!("{} added", added));
|
||||
}
|
||||
if deleted > 0 {
|
||||
status_parts.push(format!("{} deleted", deleted));
|
||||
}
|
||||
if renamed > 0 {
|
||||
status_parts.push(format!("{} renamed", renamed));
|
||||
}
|
||||
if untracked > 0 {
|
||||
status_parts.push(format!("{} untracked", untracked));
|
||||
}
|
||||
|
||||
lines.push(format!("Status: {}", status_parts.join(", ")));
|
||||
}
|
||||
|
||||
if let Some(url) = &state.remote_url {
|
||||
lines.push(format!("Remote: {}", url));
|
||||
} else {
|
||||
lines.push("Remote: (none)".to_string());
|
||||
}
|
||||
|
||||
lines.join("\n")
|
||||
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;

    /// Read-only commands are accepted; state-changing ones are rejected.
    #[test]
    fn test_is_safe_git_command() {
        let safe = [
            "git status",
            "git log --oneline",
            "git diff HEAD",
            "git branch -v",
            "git remote -v",
            "git config --get user.name",
        ];
        for cmd in safe.iter() {
            assert!(is_safe_git_command(cmd));
        }

        let unsafe_cmds = [
            "git commit -m test",
            "git push origin main",
            "git branch -D feature",
            "git remote add origin url",
        ];
        for cmd in unsafe_cmds.iter() {
            assert!(!is_safe_git_command(cmd));
        }
    }

    /// Each destructive command is flagged with the matching warning text,
    /// and plain inspection commands are not flagged.
    #[test]
    fn test_is_destructive_git_command() {
        // (command, fragment expected in the warning)
        let destructive = [
            ("git push --force origin main", "Force push"),
            ("git reset --hard HEAD~1", "Hard reset"),
            ("git clean -fd", "clean"),
            ("git rebase main", "Rebase"),
            ("git commit --amend", "Amending"),
        ];
        for (cmd, fragment) in destructive.iter() {
            let (is_dest, msg) = is_destructive_git_command(cmd);
            assert!(is_dest);
            assert!(msg.contains(fragment));
        }

        for cmd in ["git status", "git log", "git diff"].iter() {
            let (is_dest, _) = is_destructive_git_command(cmd);
            assert!(!is_dest);
        }
    }

    /// The non-repository state has every field cleared.
    #[test]
    fn test_git_state_not_a_repo() {
        let state = GitState::not_a_repo();

        assert!(!state.is_git_repo);
        assert!(!state.has_uncommitted_changes);
        assert!(state.status.is_empty());
        assert!(state.current_branch.is_none());
        assert!(state.main_branch.is_none());
        assert!(state.remote_url.is_none());
    }

    /// `path()` returns the path itself, or the destination of a rename.
    #[test]
    fn test_git_file_status_path() {
        let modified = GitFileStatus::Modified {
            path: "test.rs".to_string(),
        };
        assert_eq!(modified.path(), "test.rs");

        let renamed = GitFileStatus::Renamed {
            from: "old.rs".to_string(),
            to: "new.rs".to_string(),
        };
        assert_eq!(renamed.path(), "new.rs");
    }

    /// A non-repository is reported with a single fixed line.
    #[test]
    fn test_format_git_status_not_repo() {
        let formatted = format_git_status(&GitState::not_a_repo());
        assert_eq!(formatted, "Not a git repository");
    }

    /// A clean repository reports its branch and a clean working tree.
    #[test]
    fn test_format_git_status_clean() {
        let state = GitState {
            is_git_repo: true,
            current_branch: Some("main".to_string()),
            main_branch: Some("main".to_string()),
            status: Vec::new(),
            has_uncommitted_changes: false,
            remote_url: Some("https://github.com/user/repo.git".to_string()),
        };

        let formatted = format_git_status(&state);
        for expected in ["Git Repository: yes", "Current branch: main", "clean working tree"].iter() {
            assert!(formatted.contains(expected));
        }
    }

    /// Changes are summarized per category with their counts.
    #[test]
    fn test_format_git_status_with_changes() {
        let modified = |name: &str| GitFileStatus::Modified {
            path: name.to_string(),
        };
        let state = GitState {
            is_git_repo: true,
            current_branch: Some("feature".to_string()),
            main_branch: Some("main".to_string()),
            status: vec![
                modified("file1.rs"),
                modified("file2.rs"),
                GitFileStatus::Untracked {
                    path: "new.rs".to_string(),
                },
            ],
            has_uncommitted_changes: true,
            remote_url: None,
        };

        let formatted = format_git_status(&state);
        assert!(formatted.contains("2 modified"));
        assert!(formatted.contains("1 untracked"));
    }
}
|
||||
@@ -1,145 +1,409 @@
|
||||
pub mod session;
|
||||
pub mod system_prompt;
|
||||
pub mod git;
|
||||
|
||||
use color_eyre::eyre::{Result, eyre};
|
||||
use futures_util::TryStreamExt;
|
||||
use llm_ollama::{ChatMessage, OllamaClient, OllamaOptions, Tool, ToolFunction, ToolParameters};
|
||||
use futures_util::StreamExt;
|
||||
use llm_core::{ChatMessage, ChatOptions, LlmProvider, Tool, ToolParameters};
|
||||
use permissions::{PermissionDecision, PermissionManager, Tool as PermTool};
|
||||
use serde_json::{json, Value};
|
||||
use tokio::sync::mpsc;
|
||||
use tools_ask::AskSender;
|
||||
use tools_bash::ShellManager;
|
||||
use tools_todo::TodoList;
|
||||
|
||||
pub use session::{
|
||||
SessionStats, SessionHistory, ToolCallRecord,
|
||||
Checkpoint, CheckpointManager, FileDiff,
|
||||
};
|
||||
pub use system_prompt::{
|
||||
SystemPromptBuilder, default_base_prompt, generate_tool_instructions,
|
||||
};
|
||||
pub use git::{
|
||||
GitState, GitFileStatus,
|
||||
detect_git_state, is_safe_git_command, is_destructive_git_command,
|
||||
format_git_status,
|
||||
};
|
||||
|
||||
/// Progress notifications sent over the event channel while the agent loop
/// runs.
#[derive(Debug, Clone)]
pub enum AgentEvent {
    /// A fragment of text generated by the LLM.
    TextDelta(String),

    /// A tool invocation is about to run.
    ToolStart {
        tool_name: String,
        tool_id: String,
    },

    /// Output produced by a tool; may be a partial chunk. `is_error` marks
    /// the content as an error rather than regular output.
    ToolOutput {
        tool_id: String,
        content: String,
        is_error: bool,
    },

    /// A tool invocation has finished; `success` reports the outcome.
    ToolEnd {
        tool_id: String,
        success: bool,
    },

    /// The agent loop finished and produced its final response.
    Done {
        final_response: String,
    },

    /// The agent loop hit an error, described by the message.
    Error(String),
}
|
||||
|
||||
pub type AgentEventSender = mpsc::Sender<AgentEvent>;
|
||||
pub type AgentEventReceiver = mpsc::Receiver<AgentEvent>;
|
||||
|
||||
/// Create channel for agent events
|
||||
pub fn create_event_channel() -> (AgentEventSender, AgentEventReceiver) {
|
||||
mpsc::channel(100)
|
||||
}
|
||||
|
||||
/// Optional context for tools that need external dependencies
|
||||
#[derive(Clone, Default)]
|
||||
pub struct ToolContext {
|
||||
/// Todo list for TodoWrite tool
|
||||
pub todo_list: Option<TodoList>,
|
||||
|
||||
/// Channel for asking user questions
|
||||
pub ask_sender: Option<AskSender>,
|
||||
|
||||
/// Shell manager for background shells
|
||||
pub shell_manager: Option<ShellManager>,
|
||||
}
|
||||
|
||||
impl ToolContext {
|
||||
pub fn new() -> Self {
|
||||
Self::default()
|
||||
}
|
||||
|
||||
pub fn with_todo_list(mut self, list: TodoList) -> Self {
|
||||
self.todo_list = Some(list);
|
||||
self
|
||||
}
|
||||
|
||||
pub fn with_ask_sender(mut self, sender: AskSender) -> Self {
|
||||
self.ask_sender = Some(sender);
|
||||
self
|
||||
}
|
||||
|
||||
pub fn with_shell_manager(mut self, manager: ShellManager) -> Self {
|
||||
self.shell_manager = Some(manager);
|
||||
self
|
||||
}
|
||||
}
|
||||
|
||||
/// Define all available tools for the LLM
|
||||
pub fn get_tool_definitions() -> Vec<Tool> {
|
||||
vec![
|
||||
Tool {
|
||||
tool_type: "function".to_string(),
|
||||
function: ToolFunction {
|
||||
name: "read".to_string(),
|
||||
description: "Read the contents of a file".to_string(),
|
||||
parameters: ToolParameters {
|
||||
param_type: "object".to_string(),
|
||||
properties: json!({
|
||||
"path": {
|
||||
"type": "string",
|
||||
"description": "The path to the file to read"
|
||||
}
|
||||
}),
|
||||
required: vec!["path".to_string()],
|
||||
},
|
||||
},
|
||||
},
|
||||
Tool {
|
||||
tool_type: "function".to_string(),
|
||||
function: ToolFunction {
|
||||
name: "glob".to_string(),
|
||||
description: "Find files matching a glob pattern (e.g., '**/*.rs' for all Rust files)".to_string(),
|
||||
parameters: ToolParameters {
|
||||
param_type: "object".to_string(),
|
||||
properties: json!({
|
||||
"pattern": {
|
||||
"type": "string",
|
||||
"description": "Glob pattern to match files (e.g., '**/*.toml', '*.md')"
|
||||
}
|
||||
}),
|
||||
required: vec!["pattern".to_string()],
|
||||
},
|
||||
},
|
||||
},
|
||||
Tool {
|
||||
tool_type: "function".to_string(),
|
||||
function: ToolFunction {
|
||||
name: "grep".to_string(),
|
||||
description: "Search for a pattern in files within a directory".to_string(),
|
||||
parameters: ToolParameters {
|
||||
param_type: "object".to_string(),
|
||||
properties: json!({
|
||||
"root": {
|
||||
"type": "string",
|
||||
"description": "Root directory to search in"
|
||||
Tool::function(
|
||||
"read",
|
||||
"Read the contents of a file",
|
||||
ToolParameters::object(
|
||||
json!({
|
||||
"path": {
|
||||
"type": "string",
|
||||
"description": "The path to the file to read"
|
||||
}
|
||||
}),
|
||||
vec!["path".to_string()],
|
||||
),
|
||||
),
|
||||
Tool::function(
|
||||
"glob",
|
||||
"Find files matching a glob pattern (e.g., '**/*.rs' for all Rust files)",
|
||||
ToolParameters::object(
|
||||
json!({
|
||||
"pattern": {
|
||||
"type": "string",
|
||||
"description": "Glob pattern to match files (e.g., '**/*.toml', '*.md')"
|
||||
}
|
||||
}),
|
||||
vec!["pattern".to_string()],
|
||||
),
|
||||
),
|
||||
Tool::function(
|
||||
"grep",
|
||||
"Search for a pattern in files within a directory",
|
||||
ToolParameters::object(
|
||||
json!({
|
||||
"root": {
|
||||
"type": "string",
|
||||
"description": "Root directory to search in"
|
||||
},
|
||||
"pattern": {
|
||||
"type": "string",
|
||||
"description": "Pattern to search for"
|
||||
}
|
||||
}),
|
||||
vec!["root".to_string(), "pattern".to_string()],
|
||||
),
|
||||
),
|
||||
Tool::function(
|
||||
"write",
|
||||
"Write content to a file",
|
||||
ToolParameters::object(
|
||||
json!({
|
||||
"path": {
|
||||
"type": "string",
|
||||
"description": "Path where the file should be written"
|
||||
},
|
||||
"content": {
|
||||
"type": "string",
|
||||
"description": "Content to write to the file"
|
||||
}
|
||||
}),
|
||||
vec!["path".to_string(), "content".to_string()],
|
||||
),
|
||||
),
|
||||
Tool::function(
|
||||
"edit",
|
||||
"Edit a file by replacing old text with new text",
|
||||
ToolParameters::object(
|
||||
json!({
|
||||
"path": {
|
||||
"type": "string",
|
||||
"description": "Path to the file to edit"
|
||||
},
|
||||
"old_string": {
|
||||
"type": "string",
|
||||
"description": "Text to find and replace"
|
||||
},
|
||||
"new_string": {
|
||||
"type": "string",
|
||||
"description": "Text to replace with"
|
||||
}
|
||||
}),
|
||||
vec!["path".to_string(), "old_string".to_string(), "new_string".to_string()],
|
||||
),
|
||||
),
|
||||
Tool::function(
|
||||
"bash",
|
||||
"Execute a bash command. Use carefully and only when necessary.",
|
||||
ToolParameters::object(
|
||||
json!({
|
||||
"command": {
|
||||
"type": "string",
|
||||
"description": "The bash command to execute"
|
||||
}
|
||||
}),
|
||||
vec!["command".to_string()],
|
||||
),
|
||||
),
|
||||
Tool::function(
|
||||
"multi_edit",
|
||||
"Apply multiple edits to a file atomically",
|
||||
ToolParameters::object(
|
||||
json!({
|
||||
"path": {
|
||||
"type": "string",
|
||||
"description": "Path to the file to edit"
|
||||
},
|
||||
"edits": {
|
||||
"type": "array",
|
||||
"items": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"old_string": { "type": "string" },
|
||||
"new_string": { "type": "string" }
|
||||
},
|
||||
"required": ["old_string", "new_string"]
|
||||
},
|
||||
"pattern": {
|
||||
"type": "string",
|
||||
"description": "Pattern to search for"
|
||||
"description": "List of edit operations"
|
||||
}
|
||||
}),
|
||||
vec!["path".to_string(), "edits".to_string()],
|
||||
),
|
||||
),
|
||||
Tool::function(
|
||||
"ls",
|
||||
"List contents of a directory",
|
||||
ToolParameters::object(
|
||||
json!({
|
||||
"path": {
|
||||
"type": "string",
|
||||
"description": "Directory path to list"
|
||||
},
|
||||
"show_hidden": {
|
||||
"type": "boolean",
|
||||
"description": "Show hidden files",
|
||||
"default": false
|
||||
}
|
||||
}),
|
||||
vec!["path".to_string()],
|
||||
),
|
||||
),
|
||||
Tool::function(
|
||||
"web_search",
|
||||
"Search the web for information",
|
||||
ToolParameters::object(
|
||||
json!({
|
||||
"query": {
|
||||
"type": "string",
|
||||
"description": "Search query"
|
||||
},
|
||||
"max_results": {
|
||||
"type": "integer",
|
||||
"description": "Maximum results",
|
||||
"default": 10
|
||||
}
|
||||
}),
|
||||
vec!["query".to_string()],
|
||||
),
|
||||
),
|
||||
Tool::function(
|
||||
"todo_write",
|
||||
"Update the task list",
|
||||
ToolParameters::object(
|
||||
json!({
|
||||
"todos": {
|
||||
"type": "array",
|
||||
"items": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"content": { "type": "string" },
|
||||
"status": {
|
||||
"type": "string",
|
||||
"enum": ["pending", "in_progress", "completed"]
|
||||
},
|
||||
"active_form": { "type": "string" }
|
||||
},
|
||||
"required": ["content", "status", "active_form"]
|
||||
}
|
||||
}),
|
||||
required: vec!["root".to_string(), "pattern".to_string()],
|
||||
},
|
||||
},
|
||||
},
|
||||
Tool {
|
||||
tool_type: "function".to_string(),
|
||||
function: ToolFunction {
|
||||
name: "write".to_string(),
|
||||
description: "Write content to a file".to_string(),
|
||||
parameters: ToolParameters {
|
||||
param_type: "object".to_string(),
|
||||
properties: json!({
|
||||
"path": {
|
||||
"type": "string",
|
||||
"description": "Path where the file should be written"
|
||||
},
|
||||
"content": {
|
||||
"type": "string",
|
||||
"description": "Content to write to the file"
|
||||
}
|
||||
}),
|
||||
vec!["todos".to_string()],
|
||||
),
|
||||
),
|
||||
Tool::function(
|
||||
"ask_user",
|
||||
"Ask the user a question with options",
|
||||
ToolParameters::object(
|
||||
json!({
|
||||
"questions": {
|
||||
"type": "array",
|
||||
"items": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"question": { "type": "string" },
|
||||
"header": { "type": "string" },
|
||||
"options": {
|
||||
"type": "array",
|
||||
"items": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"label": { "type": "string" },
|
||||
"description": { "type": "string" }
|
||||
}
|
||||
}
|
||||
},
|
||||
"multi_select": { "type": "boolean" }
|
||||
}
|
||||
}
|
||||
}),
|
||||
required: vec!["path".to_string(), "content".to_string()],
|
||||
},
|
||||
},
|
||||
},
|
||||
Tool {
|
||||
tool_type: "function".to_string(),
|
||||
function: ToolFunction {
|
||||
name: "edit".to_string(),
|
||||
description: "Edit a file by replacing old text with new text".to_string(),
|
||||
parameters: ToolParameters {
|
||||
param_type: "object".to_string(),
|
||||
properties: json!({
|
||||
"path": {
|
||||
"type": "string",
|
||||
"description": "Path to the file to edit"
|
||||
},
|
||||
"old_string": {
|
||||
"type": "string",
|
||||
"description": "Text to find and replace"
|
||||
},
|
||||
"new_string": {
|
||||
"type": "string",
|
||||
"description": "Text to replace with"
|
||||
}
|
||||
}),
|
||||
required: vec!["path".to_string(), "old_string".to_string(), "new_string".to_string()],
|
||||
},
|
||||
},
|
||||
},
|
||||
Tool {
|
||||
tool_type: "function".to_string(),
|
||||
function: ToolFunction {
|
||||
name: "bash".to_string(),
|
||||
description: "Execute a bash command. Use carefully and only when necessary.".to_string(),
|
||||
parameters: ToolParameters {
|
||||
param_type: "object".to_string(),
|
||||
properties: json!({
|
||||
"command": {
|
||||
"type": "string",
|
||||
"description": "The bash command to execute"
|
||||
}
|
||||
}),
|
||||
required: vec!["command".to_string()],
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
}),
|
||||
vec!["questions".to_string()],
|
||||
),
|
||||
),
|
||||
Tool::function(
|
||||
"bash_output",
|
||||
"Get output from a background shell",
|
||||
ToolParameters::object(
|
||||
json!({
|
||||
"shell_id": {
|
||||
"type": "string",
|
||||
"description": "ID of the background shell"
|
||||
}
|
||||
}),
|
||||
vec!["shell_id".to_string()],
|
||||
),
|
||||
),
|
||||
Tool::function(
|
||||
"kill_shell",
|
||||
"Terminate a background shell",
|
||||
ToolParameters::object(
|
||||
json!({
|
||||
"shell_id": {
|
||||
"type": "string",
|
||||
"description": "ID of the shell to kill"
|
||||
}
|
||||
}),
|
||||
vec!["shell_id".to_string()],
|
||||
),
|
||||
),
|
||||
]
|
||||
}
|
||||
|
||||
/// Helper to accumulate streaming tool call deltas
|
||||
#[derive(Default)]
|
||||
struct ToolCallsBuilder {
|
||||
calls: Vec<PartialToolCall>,
|
||||
}
|
||||
|
||||
/// A tool call that is still being assembled from streamed deltas.
#[derive(Default)]
struct PartialToolCall {
    /// Call identifier, once a delta has supplied it.
    id: Option<String>,
    /// Function name, once a delta has supplied it.
    name: Option<String>,
    /// Concatenation of all argument fragments received so far (raw JSON text).
    arguments: String,
}
|
||||
|
||||
impl ToolCallsBuilder {
|
||||
fn new() -> Self {
|
||||
Self::default()
|
||||
}
|
||||
|
||||
fn add_deltas(&mut self, deltas: &[llm_core::ToolCallDelta]) {
|
||||
for delta in deltas {
|
||||
while self.calls.len() <= delta.index {
|
||||
self.calls.push(PartialToolCall::default());
|
||||
}
|
||||
let call = &mut self.calls[delta.index];
|
||||
if let Some(id) = &delta.id {
|
||||
call.id = Some(id.clone());
|
||||
}
|
||||
if let Some(name) = &delta.function_name {
|
||||
call.name = Some(name.clone());
|
||||
}
|
||||
if let Some(args) = &delta.arguments_delta {
|
||||
call.arguments.push_str(args);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn build(self) -> Vec<llm_core::ToolCall> {
|
||||
self.calls
|
||||
.into_iter()
|
||||
.filter_map(|p| {
|
||||
let id = p.id?;
|
||||
let name = p.name?;
|
||||
let args: Value = serde_json::from_str(&p.arguments).ok()?;
|
||||
Some(llm_core::ToolCall {
|
||||
id,
|
||||
call_type: "function".to_string(),
|
||||
function: llm_core::FunctionCall {
|
||||
name,
|
||||
arguments: args,
|
||||
},
|
||||
})
|
||||
})
|
||||
.collect()
|
||||
}
|
||||
}
|
||||
|
||||
/// Execute a tool call and return the result
|
||||
pub async fn execute_tool(
|
||||
tool_name: &str,
|
||||
arguments: &Value,
|
||||
perms: &PermissionManager,
|
||||
ctx: &ToolContext,
|
||||
) -> Result<String> {
|
||||
match tool_name {
|
||||
"read" => {
|
||||
@@ -280,23 +544,182 @@ pub async fn execute_tool(
|
||||
}
|
||||
}
|
||||
}
|
||||
"multi_edit" => {
|
||||
let path = arguments["path"]
|
||||
.as_str()
|
||||
.ok_or_else(|| eyre!("Missing 'path' argument"))?;
|
||||
let edits_value = arguments["edits"]
|
||||
.as_array()
|
||||
.ok_or_else(|| eyre!("Missing or invalid 'edits' argument"))?;
|
||||
|
||||
// Parse edits
|
||||
let edits: Vec<tools_fs::EditOperation> = serde_json::from_value(json!(edits_value))?;
|
||||
|
||||
// Check permission
|
||||
match perms.check(PermTool::MultiEdit, Some(path)) {
|
||||
PermissionDecision::Allow => {
|
||||
let result = tools_fs::multi_edit_file(path, edits)?;
|
||||
Ok(result)
|
||||
}
|
||||
PermissionDecision::Ask => {
|
||||
Err(eyre!("Permission required: MultiEdit operation needs approval"))
|
||||
}
|
||||
PermissionDecision::Deny => {
|
||||
Err(eyre!("Permission denied: MultiEdit operation is blocked"))
|
||||
}
|
||||
}
|
||||
}
|
||||
"ls" => {
|
||||
let path = arguments["path"]
|
||||
.as_str()
|
||||
.ok_or_else(|| eyre!("Missing 'path' argument"))?;
|
||||
let show_hidden = arguments.get("show_hidden")
|
||||
.and_then(|v| v.as_bool())
|
||||
.unwrap_or(false);
|
||||
|
||||
// Check permission
|
||||
match perms.check(PermTool::LS, Some(path)) {
|
||||
PermissionDecision::Allow => {
|
||||
let entries = tools_fs::list_directory(path, show_hidden)?;
|
||||
let output = entries
|
||||
.into_iter()
|
||||
.map(|e| {
|
||||
let type_marker = if e.is_dir { "/" } else { "" };
|
||||
let size = e.size.map(|s| format!(" ({}B)", s)).unwrap_or_default();
|
||||
format!("{}{}{}", e.name, type_marker, size)
|
||||
})
|
||||
.collect::<Vec<_>>()
|
||||
.join("\n");
|
||||
Ok(output)
|
||||
}
|
||||
PermissionDecision::Ask => {
|
||||
Err(eyre!("Permission required: LS operation needs approval"))
|
||||
}
|
||||
PermissionDecision::Deny => {
|
||||
Err(eyre!("Permission denied: LS operation is blocked"))
|
||||
}
|
||||
}
|
||||
}
|
||||
"web_search" => {
|
||||
let query = arguments["query"]
|
||||
.as_str()
|
||||
.ok_or_else(|| eyre!("Missing 'query' argument"))?;
|
||||
let max_results = arguments.get("max_results")
|
||||
.and_then(|v| v.as_u64())
|
||||
.unwrap_or(10) as usize;
|
||||
|
||||
// Check permission
|
||||
match perms.check(PermTool::WebSearch, None) {
|
||||
PermissionDecision::Allow => {
|
||||
// Use DuckDuckGo search provider
|
||||
let provider = Box::new(tools_web::DuckDuckGoSearchProvider::with_max_results(max_results));
|
||||
let client = tools_web::WebSearchClient::new(provider);
|
||||
let results = client.search(query).await?;
|
||||
let formatted = tools_web::format_search_results(&results);
|
||||
Ok(formatted)
|
||||
}
|
||||
PermissionDecision::Ask => {
|
||||
Err(eyre!("Permission required: WebSearch operation needs approval"))
|
||||
}
|
||||
PermissionDecision::Deny => {
|
||||
Err(eyre!("Permission denied: WebSearch operation is blocked"))
|
||||
}
|
||||
}
|
||||
}
|
||||
"todo_write" => {
|
||||
let todo_list = ctx.todo_list.as_ref()
|
||||
.ok_or_else(|| eyre!("TodoList not available in this context"))?;
|
||||
|
||||
// Check permission
|
||||
match perms.check(PermTool::TodoWrite, None) {
|
||||
PermissionDecision::Allow => {
|
||||
let todos = tools_todo::parse_todos(arguments)?;
|
||||
todo_list.write(todos);
|
||||
Ok(todo_list.format_display())
|
||||
}
|
||||
PermissionDecision::Ask => {
|
||||
Err(eyre!("Permission required: TodoWrite operation needs approval"))
|
||||
}
|
||||
PermissionDecision::Deny => {
|
||||
Err(eyre!("Permission denied: TodoWrite operation is blocked"))
|
||||
}
|
||||
}
|
||||
}
|
||||
"ask_user" => {
|
||||
let sender = ctx.ask_sender.as_ref()
|
||||
.ok_or_else(|| eyre!("AskUser not available in this context"))?;
|
||||
|
||||
// Check permission
|
||||
match perms.check(PermTool::AskUserQuestion, None) {
|
||||
PermissionDecision::Allow => {
|
||||
let questions = tools_ask::parse_questions(arguments)?;
|
||||
let answers = tools_ask::ask_user(sender, questions).await?;
|
||||
Ok(serde_json::to_string_pretty(&answers)?)
|
||||
}
|
||||
PermissionDecision::Ask => {
|
||||
Err(eyre!("Permission required: AskUser operation needs approval"))
|
||||
}
|
||||
PermissionDecision::Deny => {
|
||||
Err(eyre!("Permission denied: AskUser operation is blocked"))
|
||||
}
|
||||
}
|
||||
}
|
||||
"bash_output" => {
|
||||
let manager = ctx.shell_manager.as_ref()
|
||||
.ok_or_else(|| eyre!("ShellManager not available in this context"))?;
|
||||
|
||||
let shell_id = arguments["shell_id"]
|
||||
.as_str()
|
||||
.ok_or_else(|| eyre!("Missing 'shell_id' argument"))?;
|
||||
|
||||
// Check permission
|
||||
match perms.check(PermTool::BashOutput, Some(shell_id)) {
|
||||
PermissionDecision::Allow => {
|
||||
tools_bash::bash_output(manager, shell_id)
|
||||
}
|
||||
PermissionDecision::Ask => {
|
||||
Err(eyre!("Permission required: BashOutput operation needs approval"))
|
||||
}
|
||||
PermissionDecision::Deny => {
|
||||
Err(eyre!("Permission denied: BashOutput operation is blocked"))
|
||||
}
|
||||
}
|
||||
}
|
||||
"kill_shell" => {
|
||||
let manager = ctx.shell_manager.as_ref()
|
||||
.ok_or_else(|| eyre!("ShellManager not available in this context"))?;
|
||||
|
||||
let shell_id = arguments["shell_id"]
|
||||
.as_str()
|
||||
.ok_or_else(|| eyre!("Missing 'shell_id' argument"))?;
|
||||
|
||||
// Check permission
|
||||
match perms.check(PermTool::KillShell, Some(shell_id)) {
|
||||
PermissionDecision::Allow => {
|
||||
tools_bash::kill_shell(manager, shell_id)
|
||||
}
|
||||
PermissionDecision::Ask => {
|
||||
Err(eyre!("Permission required: KillShell operation needs approval"))
|
||||
}
|
||||
PermissionDecision::Deny => {
|
||||
Err(eyre!("Permission denied: KillShell operation is blocked"))
|
||||
}
|
||||
}
|
||||
}
|
||||
_ => Err(eyre!("Unknown tool: {}", tool_name)),
|
||||
}
|
||||
}
|
||||
|
||||
/// Run the agent loop with tool calling
|
||||
pub async fn run_agent_loop(
|
||||
client: &OllamaClient,
|
||||
pub async fn run_agent_loop<P: LlmProvider>(
|
||||
provider: &P,
|
||||
user_prompt: &str,
|
||||
opts: &OllamaOptions,
|
||||
options: &ChatOptions,
|
||||
perms: &PermissionManager,
|
||||
ctx: &ToolContext,
|
||||
) -> Result<String> {
|
||||
let tools = get_tool_definitions();
|
||||
let mut messages = vec![ChatMessage {
|
||||
role: "user".to_string(),
|
||||
content: Some(user_prompt.to_string()),
|
||||
tool_calls: None,
|
||||
}];
|
||||
let mut messages = vec![ChatMessage::user(user_prompt)];
|
||||
|
||||
let max_iterations = 10; // Prevent infinite loops
|
||||
let mut iteration = 0;
|
||||
@@ -308,18 +731,57 @@ pub async fn run_agent_loop(
|
||||
}
|
||||
|
||||
// Call LLM with messages and tools
|
||||
let mut stream = client.chat_stream(&messages, opts, Some(&tools)).await?;
|
||||
let mut stream = provider
|
||||
.chat_stream(&messages, options, Some(&tools))
|
||||
.await
|
||||
.map_err(|e| eyre!("LLM provider error: {}", e))?;
|
||||
|
||||
let mut response_content = String::new();
|
||||
let mut tool_calls = None;
|
||||
let mut accumulated_tool_calls: Vec<llm_core::ToolCall> = Vec::new();
|
||||
|
||||
// Collect the streamed response
|
||||
while let Some(chunk) = stream.try_next().await? {
|
||||
if let Some(msg) = chunk.message {
|
||||
if let Some(content) = msg.content {
|
||||
response_content.push_str(&content);
|
||||
}
|
||||
if let Some(calls) = msg.tool_calls {
|
||||
tool_calls = Some(calls);
|
||||
while let Some(chunk) = stream.next().await {
|
||||
let chunk = chunk.map_err(|e| eyre!("Stream error: {}", e))?;
|
||||
|
||||
if let Some(content) = chunk.content {
|
||||
response_content.push_str(&content);
|
||||
}
|
||||
|
||||
// Accumulate tool calls from deltas
|
||||
if let Some(deltas) = chunk.tool_calls {
|
||||
for delta in deltas {
|
||||
// Ensure the accumulated_tool_calls vec is large enough
|
||||
while accumulated_tool_calls.len() <= delta.index {
|
||||
accumulated_tool_calls.push(llm_core::ToolCall {
|
||||
id: String::new(),
|
||||
call_type: "function".to_string(),
|
||||
function: llm_core::FunctionCall {
|
||||
name: String::new(),
|
||||
arguments: Value::Null,
|
||||
},
|
||||
});
|
||||
}
|
||||
|
||||
let tool_call = &mut accumulated_tool_calls[delta.index];
|
||||
|
||||
if let Some(id) = delta.id {
|
||||
tool_call.id = id;
|
||||
}
|
||||
if let Some(name) = delta.function_name {
|
||||
tool_call.function.name = name;
|
||||
}
|
||||
if let Some(args_delta) = delta.arguments_delta {
|
||||
// Accumulate the arguments string
|
||||
let current_args = if tool_call.function.arguments.is_null() {
|
||||
String::new()
|
||||
} else {
|
||||
tool_call.function.arguments.to_string()
|
||||
};
|
||||
let new_args = current_args + &args_delta;
|
||||
// Try to parse as JSON, but keep as string if incomplete
|
||||
tool_call.function.arguments = serde_json::from_str(&new_args)
|
||||
.unwrap_or_else(|_| Value::String(new_args));
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -327,44 +789,44 @@ pub async fn run_agent_loop(
|
||||
// Drop the stream to release the borrow on messages
|
||||
drop(stream);
|
||||
|
||||
// Filter out incomplete tool calls and check if we have valid ones
|
||||
let valid_tool_calls: Vec<_> = accumulated_tool_calls
|
||||
.into_iter()
|
||||
.filter(|tc| !tc.id.is_empty() && !tc.function.name.is_empty())
|
||||
.collect();
|
||||
|
||||
// Check if LLM wants to call tools
|
||||
if let Some(calls) = tool_calls {
|
||||
if !valid_tool_calls.is_empty() {
|
||||
// Add assistant message with tool calls
|
||||
messages.push(ChatMessage {
|
||||
role: "assistant".to_string(),
|
||||
role: llm_core::Role::Assistant,
|
||||
content: if response_content.is_empty() {
|
||||
None
|
||||
} else {
|
||||
Some(response_content.clone())
|
||||
},
|
||||
tool_calls: Some(calls.clone()),
|
||||
tool_calls: Some(valid_tool_calls.clone()),
|
||||
tool_call_id: None,
|
||||
name: None,
|
||||
});
|
||||
|
||||
// Execute each tool call
|
||||
for call in calls {
|
||||
for call in valid_tool_calls {
|
||||
let tool_name = &call.function.name;
|
||||
let arguments = &call.function.arguments;
|
||||
|
||||
println!("\n🔧 Tool call: {} with args: {}", tool_name, arguments);
|
||||
tracing::debug!(tool = %tool_name, args = %arguments, "executing tool call");
|
||||
|
||||
match execute_tool(tool_name, arguments, perms).await {
|
||||
match execute_tool(tool_name, arguments, perms, ctx).await {
|
||||
Ok(result) => {
|
||||
println!("✅ Tool result: {}", result);
|
||||
tracing::debug!(tool = %tool_name, result = %result, "tool call succeeded");
|
||||
// Add tool result message
|
||||
messages.push(ChatMessage {
|
||||
role: "tool".to_string(),
|
||||
content: Some(result),
|
||||
tool_calls: None,
|
||||
});
|
||||
messages.push(ChatMessage::tool_result(&call.id, result));
|
||||
}
|
||||
Err(e) => {
|
||||
println!("❌ Tool error: {}", e);
|
||||
tracing::warn!(tool = %tool_name, error = %e, "tool call failed");
|
||||
// Add error message as tool result
|
||||
messages.push(ChatMessage {
|
||||
role: "tool".to_string(),
|
||||
content: Some(format!("Error: {}", e)),
|
||||
tool_calls: None,
|
||||
});
|
||||
messages.push(ChatMessage::tool_result(&call.id, format!("Error: {}", e)));
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -377,3 +839,144 @@ pub async fn run_agent_loop(
|
||||
return Ok(response_content);
|
||||
}
|
||||
}
|
||||
|
||||
/// Run agent loop with event streaming
|
||||
pub async fn run_agent_loop_streaming<P: LlmProvider>(
|
||||
provider: &P,
|
||||
user_prompt: &str,
|
||||
options: &ChatOptions,
|
||||
perms: &PermissionManager,
|
||||
ctx: &ToolContext,
|
||||
events: AgentEventSender,
|
||||
) -> Result<String> {
|
||||
let tools = get_tool_definitions();
|
||||
let mut messages = vec![ChatMessage::user(user_prompt)];
|
||||
|
||||
let max_iterations = 10;
|
||||
let mut iteration = 0;
|
||||
|
||||
loop {
|
||||
iteration += 1;
|
||||
if iteration > max_iterations {
|
||||
let _ = events.send(AgentEvent::Error("Max iterations reached".into())).await;
|
||||
return Err(eyre!("Max iterations reached"));
|
||||
}
|
||||
|
||||
// Stream LLM response
|
||||
let mut stream = provider
|
||||
.chat_stream(&messages, options, Some(&tools))
|
||||
.await
|
||||
.map_err(|e| {
|
||||
let err_msg = format!("LLM provider error: {}", e);
|
||||
let _ = events.try_send(AgentEvent::Error(err_msg.clone()));
|
||||
eyre!(err_msg)
|
||||
})?;
|
||||
|
||||
let mut response_content = String::new();
|
||||
let mut tool_calls_builder = ToolCallsBuilder::new();
|
||||
|
||||
while let Some(chunk) = stream.next().await {
|
||||
let chunk = chunk.map_err(|e| {
|
||||
let err_msg = format!("Stream error: {}", e);
|
||||
let _ = events.try_send(AgentEvent::Error(err_msg.clone()));
|
||||
eyre!(err_msg)
|
||||
})?;
|
||||
|
||||
// Send text deltas
|
||||
if let Some(text) = &chunk.content {
|
||||
let _ = events.send(AgentEvent::TextDelta(text.clone())).await;
|
||||
response_content.push_str(text);
|
||||
}
|
||||
|
||||
// Accumulate tool calls
|
||||
if let Some(deltas) = &chunk.tool_calls {
|
||||
tool_calls_builder.add_deltas(deltas);
|
||||
}
|
||||
}
|
||||
|
||||
// Drop stream to release borrow
|
||||
drop(stream);
|
||||
|
||||
let tool_calls = tool_calls_builder.build();
|
||||
|
||||
if tool_calls.is_empty() {
|
||||
let _ = events
|
||||
.send(AgentEvent::Done {
|
||||
final_response: response_content.clone(),
|
||||
})
|
||||
.await;
|
||||
return Ok(response_content);
|
||||
}
|
||||
|
||||
// Add assistant message
|
||||
messages.push(ChatMessage {
|
||||
role: llm_core::Role::Assistant,
|
||||
content: if response_content.is_empty() {
|
||||
None
|
||||
} else {
|
||||
Some(response_content.clone())
|
||||
},
|
||||
tool_calls: Some(tool_calls.clone()),
|
||||
tool_call_id: None,
|
||||
name: None,
|
||||
});
|
||||
|
||||
// Execute tools with events
|
||||
for call in &tool_calls {
|
||||
let tool_id = call.id.clone();
|
||||
let tool_name = call.function.name.clone();
|
||||
|
||||
let _ = events
|
||||
.send(AgentEvent::ToolStart {
|
||||
tool_name: tool_name.clone(),
|
||||
tool_id: tool_id.clone(),
|
||||
})
|
||||
.await;
|
||||
|
||||
tracing::debug!(tool = %tool_name, args = %call.function.arguments, "executing tool call");
|
||||
|
||||
let result = execute_tool(&tool_name, &call.function.arguments, perms, ctx).await;
|
||||
|
||||
match &result {
|
||||
Ok(output) => {
|
||||
tracing::debug!(tool = %tool_name, result = %output, "tool call succeeded");
|
||||
let _ = events
|
||||
.send(AgentEvent::ToolOutput {
|
||||
tool_id: tool_id.clone(),
|
||||
content: output.clone(),
|
||||
is_error: false,
|
||||
})
|
||||
.await;
|
||||
let _ = events
|
||||
.send(AgentEvent::ToolEnd {
|
||||
tool_id: tool_id.clone(),
|
||||
success: true,
|
||||
})
|
||||
.await;
|
||||
messages.push(ChatMessage::tool_result(&tool_id, output));
|
||||
}
|
||||
Err(e) => {
|
||||
tracing::warn!(tool = %tool_name, error = %e, "tool call failed");
|
||||
let error_msg = e.to_string();
|
||||
let _ = events
|
||||
.send(AgentEvent::ToolOutput {
|
||||
tool_id: tool_id.clone(),
|
||||
content: error_msg.clone(),
|
||||
is_error: true,
|
||||
})
|
||||
.await;
|
||||
let _ = events
|
||||
.send(AgentEvent::ToolEnd {
|
||||
tool_id: tool_id.clone(),
|
||||
success: false,
|
||||
})
|
||||
.await;
|
||||
messages.push(ChatMessage::tool_result(
|
||||
&tool_id,
|
||||
format!("Error: {}", error_msg),
|
||||
));
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
266
crates/core/agent/src/system_prompt.rs
Normal file
266
crates/core/agent/src/system_prompt.rs
Normal file
@@ -0,0 +1,266 @@
|
||||
//! System Prompt Management
|
||||
//!
|
||||
//! Composes system prompts from multiple sources for agent sessions.
|
||||
|
||||
use std::path::Path;
|
||||
|
||||
/// Builder for composing system prompts
///
/// Sections are collected with the `with_*` methods and rendered by
/// [`SystemPromptBuilder::build`], ordered by ascending priority.
#[derive(Debug, Clone, Default)]
pub struct SystemPromptBuilder {
    sections: Vec<PromptSection>,
}

/// One named piece of the system prompt.
#[derive(Debug, Clone)]
struct PromptSection {
    /// Label identifying where the section came from (e.g. "base", "skill:<name>").
    name: String,
    /// Text emitted into the final prompt.
    content: String,
    /// Ordering key. Lower = earlier in prompt.
    priority: i32,
}

impl SystemPromptBuilder {
    /// Create an empty builder.
    pub fn new() -> Self {
        Self::default()
    }

    /// Add the base agent prompt (priority 0).
    pub fn with_base_prompt(self, content: impl Into<String>) -> Self {
        self.with_section("base", content, 0)
    }

    /// Add tool usage instructions (priority 10).
    pub fn with_tool_instructions(self, content: impl Into<String>) -> Self {
        self.with_section("tools", content, 10)
    }

    /// Load and add project instructions from CLAUDE.md or .owlen.md
    ///
    /// `CLAUDE.md` is tried first (Claude Code compatibility), then
    /// `.owlen.md`; the first file under `project_root` that can be read is
    /// added at priority 20 under a "# Project Instructions" heading.
    /// A file that is missing or unreadable is skipped, and when neither can
    /// be read the builder is returned unchanged.
    pub fn with_project_instructions(self, project_root: &Path) -> Self {
        // Ordered by preference.
        const CANDIDATES: [&str; 2] = ["CLAUDE.md", ".owlen.md"];

        // Reading directly (instead of checking `exists()` first) treats a
        // missing file and an unreadable one the same way: try the next.
        let loaded = CANDIDATES
            .iter()
            .find_map(|file_name| std::fs::read_to_string(project_root.join(file_name)).ok());

        match loaded {
            Some(content) => self.with_section(
                "project",
                format!("# Project Instructions\n\n{}", content),
                20,
            ),
            None => self,
        }
    }

    /// Add skill content (priority 30); the section is named `skill:<skill_name>`.
    pub fn with_skill(self, skill_name: &str, content: impl Into<String>) -> Self {
        self.with_section(format!("skill:{}", skill_name), content, 30)
    }

    /// Add hook-injected content (from SessionStart hooks), priority 40.
    pub fn with_hook_injection(self, content: impl Into<String>) -> Self {
        self.with_section("hook", content, 40)
    }

    /// Add custom section with an explicit name and priority.
    pub fn with_section(
        mut self,
        name: impl Into<String>,
        content: impl Into<String>,
        priority: i32,
    ) -> Self {
        self.sections.push(PromptSection {
            name: name.into(),
            content: content.into(),
            priority,
        });
        self
    }

    /// Build the final system prompt
    ///
    /// Sections are sorted by priority (the sort is stable, so sections with
    /// equal priority keep insertion order) and joined with a `---`
    /// separator line. An empty builder yields an empty string.
    pub fn build(mut self) -> String {
        self.sections.sort_by_key(|section| section.priority);

        self.sections
            .iter()
            .map(|section| section.content.as_str())
            .collect::<Vec<_>>()
            .join("\n\n---\n\n")
    }

    /// Check if any content has been added
    pub fn is_empty(&self) -> bool {
        self.sections.is_empty()
    }
}
|
||||
|
||||
/// Default base prompt for Owlen agent
///
/// Returns a static Markdown prompt with an introduction, a `## Guidelines`
/// list, and a `## Tool Usage` list.
pub fn default_base_prompt() -> &'static str {
    r#"You are Owlen, an AI assistant that helps with software engineering tasks.

You have access to tools for reading files, writing code, running commands, and searching the web.

## Guidelines

1. Be direct and concise in your responses
2. Use tools to gather information before making changes
3. Explain your reasoning when making decisions
4. Ask for clarification when requirements are unclear
5. Prefer editing existing files over creating new ones

## Tool Usage

- Use `read` to examine file contents before editing
- Use `glob` and `grep` to find relevant files
- Use `edit` for precise changes, `write` for new files
- Use `bash` for running tests and commands
- Use `web_search` for current information"#
}
|
||||
|
||||
/// Generate tool instructions based on available tools
///
/// Produces a Markdown section headed `## Available Tools` with one bullet
/// per recognized name in `tool_names`, in the order given. Names without a
/// known description are skipped.
pub fn generate_tool_instructions(tool_names: &[&str]) -> String {
    let mut instructions = String::from("## Available Tools\n\n");

    for &name in tool_names {
        if let Some(desc) = tool_description(name) {
            // Emits: - `name`: description
            instructions.push_str("- `");
            instructions.push_str(name);
            instructions.push_str("`: ");
            instructions.push_str(desc);
            instructions.push('\n');
        }
    }

    instructions
}

/// One-line description for a known tool name, or `None` if unrecognized.
fn tool_description(name: &str) -> Option<&'static str> {
    let desc = match name {
        "read" => "Read file contents",
        "write" => "Create or overwrite a file",
        "edit" => "Edit a file by replacing text",
        "multi_edit" => "Apply multiple edits atomically",
        "glob" => "Find files by pattern",
        "grep" => "Search file contents",
        "ls" => "List directory contents",
        "bash" => "Execute shell commands",
        "bash_output" => "Get output from a background shell",
        "kill_shell" => "Terminate a background shell",
        "web_search" => "Search the web",
        "web_fetch" => "Fetch a URL",
        "todo_write" => "Update task list",
        "ask_user" => "Ask user a question",
        _ => return None,
    };
    Some(desc)
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;

    /// Base prompt and tool instructions both end up in the rendered prompt.
    #[test]
    fn test_builder() {
        let rendered = SystemPromptBuilder::new()
            .with_base_prompt("You are helpful")
            .with_tool_instructions("Use tools wisely")
            .build();

        for expected in ["You are helpful", "Use tools wisely"] {
            assert!(rendered.contains(expected));
        }
    }

    /// Sections are emitted in ascending priority order, not insertion order.
    #[test]
    fn test_priority_ordering() {
        let rendered = SystemPromptBuilder::new()
            .with_section("last", "Third", 100)
            .with_section("first", "First", 0)
            .with_section("middle", "Second", 50)
            .build();

        let offset_of = |needle: &str| rendered.find(needle).unwrap();
        let (first_pos, second_pos, third_pos) =
            (offset_of("First"), offset_of("Second"), offset_of("Third"));

        assert!(first_pos < second_pos);
        assert!(second_pos < third_pos);
    }

    /// The default prompt names the agent and carries both headed sections.
    #[test]
    fn test_default_base_prompt() {
        let text = default_base_prompt();

        for expected in ["Owlen", "Guidelines", "Tool Usage"] {
            assert!(text.contains(expected));
        }
    }

    /// Every requested known tool is mentioned under the heading.
    #[test]
    fn test_generate_tool_instructions() {
        let tools = ["read", "write", "edit", "bash"];
        let rendered = generate_tool_instructions(&tools);

        assert!(rendered.contains("Available Tools"));
        for tool in tools {
            assert!(rendered.contains(tool));
        }
    }

    /// A fresh builder is empty until the first section is added.
    #[test]
    fn test_builder_empty() {
        let fresh = SystemPromptBuilder::new();
        assert!(fresh.is_empty());

        let populated = fresh.with_base_prompt("test");
        assert!(!populated.is_empty());
    }

    /// Skill content is rendered alongside the base prompt.
    #[test]
    fn test_skill_section() {
        let rendered = SystemPromptBuilder::new()
            .with_base_prompt("Base")
            .with_skill("rust", "Rust expertise")
            .build();

        assert!(rendered.contains("Base"));
        assert!(rendered.contains("Rust expertise"));
    }

    /// Hook-injected text is rendered alongside the base prompt.
    #[test]
    fn test_hook_injection() {
        let rendered = SystemPromptBuilder::new()
            .with_base_prompt("Base")
            .with_hook_injection("Additional context from hook")
            .build();

        assert!(rendered.contains("Base"));
        assert!(rendered.contains("Additional context from hook"));
    }

    /// Two sections are joined with a `---` separator.
    #[test]
    fn test_separator_between_sections() {
        let rendered = SystemPromptBuilder::new()
            .with_section("first", "First section", 0)
            .with_section("second", "Second section", 10)
            .build();

        for expected in ["---", "First section", "Second section"] {
            assert!(rendered.contains(expected));
        }
    }
}
|
||||
276
crates/core/agent/tests/streaming.rs
Normal file
276
crates/core/agent/tests/streaming.rs
Normal file
@@ -0,0 +1,276 @@
|
||||
use agent_core::{create_event_channel, run_agent_loop_streaming, AgentEvent, ToolContext};
|
||||
use async_trait::async_trait;
|
||||
use futures_util::stream;
|
||||
use llm_core::{
|
||||
ChatMessage, ChatOptions, LlmError, StreamChunk, LlmProvider, Tool, ToolCallDelta,
|
||||
};
|
||||
use permissions::{Mode, PermissionManager};
|
||||
use std::pin::Pin;
|
||||
|
||||
/// Mock LLM provider for testing streaming
|
||||
struct MockStreamingProvider {
|
||||
responses: Vec<MockResponse>,
|
||||
}
|
||||
|
||||
/// One scripted reply served by the mock streaming provider.
enum MockResponse {
    /// Text-only reply with no tool calls; each element is streamed as its
    /// own content chunk.
    Text(Vec<String>),
    /// Reply that streams some text and then requests a single tool call.
    ToolCall {
        /// Text pieces streamed before the tool call.
        text_chunks: Vec<String>,
        /// Identifier attached to the tool call.
        tool_id: String,
        /// Name of the tool being called.
        tool_name: String,
        /// Raw argument text for the call; streamed in small pieces.
        tool_args: String,
    },
}
|
||||
|
||||
#[async_trait]
|
||||
impl LlmProvider for MockStreamingProvider {
|
||||
fn name(&self) -> &str {
|
||||
"mock"
|
||||
}
|
||||
|
||||
fn model(&self) -> &str {
|
||||
"mock-model"
|
||||
}
|
||||
|
||||
async fn chat_stream(
|
||||
&self,
|
||||
messages: &[ChatMessage],
|
||||
_options: &ChatOptions,
|
||||
_tools: Option<&[Tool]>,
|
||||
) -> Result<Pin<Box<dyn futures_util::Stream<Item = Result<StreamChunk, LlmError>> + Send>>, LlmError> {
|
||||
// Determine which response to use based on message count
|
||||
let response_idx = (messages.len() / 2).min(self.responses.len() - 1);
|
||||
let response = &self.responses[response_idx];
|
||||
|
||||
let chunks: Vec<Result<StreamChunk, LlmError>> = match response {
|
||||
MockResponse::Text(text_chunks) => text_chunks
|
||||
.iter()
|
||||
.map(|text| {
|
||||
Ok(StreamChunk {
|
||||
content: Some(text.clone()),
|
||||
tool_calls: None,
|
||||
done: false,
|
||||
usage: None,
|
||||
})
|
||||
})
|
||||
.collect(),
|
||||
MockResponse::ToolCall {
|
||||
text_chunks,
|
||||
tool_id,
|
||||
tool_name,
|
||||
tool_args,
|
||||
} => {
|
||||
let mut result = vec![];
|
||||
|
||||
// First emit text chunks
|
||||
for text in text_chunks {
|
||||
result.push(Ok(StreamChunk {
|
||||
content: Some(text.clone()),
|
||||
tool_calls: None,
|
||||
done: false,
|
||||
usage: None,
|
||||
}));
|
||||
}
|
||||
|
||||
// Then emit tool call in chunks
|
||||
result.push(Ok(StreamChunk {
|
||||
content: None,
|
||||
tool_calls: Some(vec![ToolCallDelta {
|
||||
index: 0,
|
||||
id: Some(tool_id.clone()),
|
||||
function_name: Some(tool_name.clone()),
|
||||
arguments_delta: None,
|
||||
}]),
|
||||
done: false,
|
||||
usage: None,
|
||||
}));
|
||||
|
||||
// Emit args in chunks
|
||||
for chunk in tool_args.chars().collect::<Vec<_>>().chunks(5) {
|
||||
result.push(Ok(StreamChunk {
|
||||
content: None,
|
||||
tool_calls: Some(vec![ToolCallDelta {
|
||||
index: 0,
|
||||
id: None,
|
||||
function_name: None,
|
||||
arguments_delta: Some(chunk.iter().collect()),
|
||||
}]),
|
||||
done: false,
|
||||
usage: None,
|
||||
}));
|
||||
}
|
||||
|
||||
result
|
||||
}
|
||||
};
|
||||
|
||||
Ok(Box::pin(stream::iter(chunks)))
|
||||
}
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_streaming_text_only() {
|
||||
let provider = MockStreamingProvider {
|
||||
responses: vec![MockResponse::Text(vec![
|
||||
"Hello".to_string(),
|
||||
" ".to_string(),
|
||||
"world".to_string(),
|
||||
"!".to_string(),
|
||||
])],
|
||||
};
|
||||
|
||||
let perms = PermissionManager::new(Mode::Plan);
|
||||
let ctx = ToolContext::default();
|
||||
let (tx, mut rx) = create_event_channel();
|
||||
|
||||
// Spawn the agent loop
|
||||
let handle = tokio::spawn(async move {
|
||||
run_agent_loop_streaming(
|
||||
&provider,
|
||||
"Say hello",
|
||||
&ChatOptions::default(),
|
||||
&perms,
|
||||
&ctx,
|
||||
tx,
|
||||
)
|
||||
.await
|
||||
});
|
||||
|
||||
// Collect events
|
||||
let mut text_deltas = vec![];
|
||||
let mut done_response = None;
|
||||
|
||||
while let Some(event) = rx.recv().await {
|
||||
match event {
|
||||
AgentEvent::TextDelta(text) => {
|
||||
text_deltas.push(text);
|
||||
}
|
||||
AgentEvent::Done { final_response } => {
|
||||
done_response = Some(final_response);
|
||||
break;
|
||||
}
|
||||
AgentEvent::Error(e) => {
|
||||
panic!("Unexpected error: {}", e);
|
||||
}
|
||||
_ => {}
|
||||
}
|
||||
}
|
||||
|
||||
// Wait for agent loop to complete
|
||||
let result = handle.await.unwrap();
|
||||
assert!(result.is_ok());
|
||||
|
||||
// Verify events
|
||||
assert_eq!(text_deltas, vec!["Hello", " ", "world", "!"]);
|
||||
assert_eq!(done_response, Some("Hello world!".to_string()));
|
||||
assert_eq!(result.unwrap(), "Hello world!");
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_streaming_with_tool_call() {
|
||||
let provider = MockStreamingProvider {
|
||||
responses: vec![
|
||||
MockResponse::ToolCall {
|
||||
text_chunks: vec!["Let me ".to_string(), "check...".to_string()],
|
||||
tool_id: "call_123".to_string(),
|
||||
tool_name: "glob".to_string(),
|
||||
tool_args: r#"{"pattern":"*.rs"}"#.to_string(),
|
||||
},
|
||||
MockResponse::Text(vec!["Found ".to_string(), "the files!".to_string()]),
|
||||
],
|
||||
};
|
||||
|
||||
let perms = PermissionManager::new(Mode::Plan);
|
||||
let ctx = ToolContext::default();
|
||||
let (tx, mut rx) = create_event_channel();
|
||||
|
||||
// Spawn the agent loop
|
||||
let handle = tokio::spawn(async move {
|
||||
run_agent_loop_streaming(
|
||||
&provider,
|
||||
"Find Rust files",
|
||||
&ChatOptions::default(),
|
||||
&perms,
|
||||
&ctx,
|
||||
tx,
|
||||
)
|
||||
.await
|
||||
});
|
||||
|
||||
// Collect events
|
||||
let mut text_deltas = vec![];
|
||||
let mut tool_starts = vec![];
|
||||
let mut tool_outputs = vec![];
|
||||
let mut tool_ends = vec![];
|
||||
|
||||
while let Some(event) = rx.recv().await {
|
||||
match event {
|
||||
AgentEvent::TextDelta(text) => {
|
||||
text_deltas.push(text);
|
||||
}
|
||||
AgentEvent::ToolStart {
|
||||
tool_name,
|
||||
tool_id,
|
||||
} => {
|
||||
tool_starts.push((tool_name, tool_id));
|
||||
}
|
||||
AgentEvent::ToolOutput {
|
||||
tool_id,
|
||||
content,
|
||||
is_error,
|
||||
} => {
|
||||
tool_outputs.push((tool_id, content, is_error));
|
||||
}
|
||||
AgentEvent::ToolEnd { tool_id, success } => {
|
||||
tool_ends.push((tool_id, success));
|
||||
}
|
||||
AgentEvent::Done { .. } => {
|
||||
break;
|
||||
}
|
||||
AgentEvent::Error(e) => {
|
||||
panic!("Unexpected error: {}", e);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Wait for agent loop to complete
|
||||
let result = handle.await.unwrap();
|
||||
assert!(result.is_ok());
|
||||
|
||||
// Verify we got text deltas from both responses
|
||||
assert!(text_deltas.contains(&"Let me ".to_string()));
|
||||
assert!(text_deltas.contains(&"check...".to_string()));
|
||||
assert!(text_deltas.contains(&"Found ".to_string()));
|
||||
assert!(text_deltas.contains(&"the files!".to_string()));
|
||||
|
||||
// Verify tool events
|
||||
assert_eq!(tool_starts.len(), 1);
|
||||
assert_eq!(tool_starts[0].0, "glob");
|
||||
assert_eq!(tool_starts[0].1, "call_123");
|
||||
|
||||
assert_eq!(tool_outputs.len(), 1);
|
||||
assert_eq!(tool_outputs[0].0, "call_123");
|
||||
assert!(!tool_outputs[0].2); // not an error
|
||||
|
||||
assert_eq!(tool_ends.len(), 1);
|
||||
assert_eq!(tool_ends[0].0, "call_123");
|
||||
assert!(tool_ends[0].1); // success
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_channel_creation() {
|
||||
let (tx, mut rx) = create_event_channel();
|
||||
|
||||
// Test that channel works
|
||||
tx.send(AgentEvent::TextDelta("test".to_string()))
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
let event = rx.recv().await.unwrap();
|
||||
match event {
|
||||
AgentEvent::TextDelta(text) => assert_eq!(text, "test"),
|
||||
_ => panic!("Wrong event type"),
|
||||
}
|
||||
}
|
||||
114
crates/core/agent/tests/tool_context.rs
Normal file
114
crates/core/agent/tests/tool_context.rs
Normal file
@@ -0,0 +1,114 @@
|
||||
// Test that ToolContext properly wires up the placeholder tools
|
||||
use agent_core::{ToolContext, execute_tool};
|
||||
use permissions::{Mode, PermissionManager};
|
||||
use tools_todo::{TodoList, TodoStatus};
|
||||
use tools_bash::ShellManager;
|
||||
use serde_json::json;
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_todo_write_with_context() {
|
||||
let todo_list = TodoList::new();
|
||||
let ctx = ToolContext::new().with_todo_list(todo_list.clone());
|
||||
let perms = PermissionManager::new(Mode::Code); // Allow all tools
|
||||
|
||||
let arguments = json!({
|
||||
"todos": [
|
||||
{
|
||||
"content": "First task",
|
||||
"status": "pending",
|
||||
"active_form": "Working on first task"
|
||||
},
|
||||
{
|
||||
"content": "Second task",
|
||||
"status": "in_progress",
|
||||
"active_form": "Working on second task"
|
||||
}
|
||||
]
|
||||
});
|
||||
|
||||
let result = execute_tool("todo_write", &arguments, &perms, &ctx).await;
|
||||
assert!(result.is_ok(), "TodoWrite should succeed: {:?}", result);
|
||||
|
||||
// Verify the todos were written
|
||||
let todos = todo_list.read();
|
||||
assert_eq!(todos.len(), 2);
|
||||
assert_eq!(todos[0].content, "First task");
|
||||
assert_eq!(todos[1].status, TodoStatus::InProgress);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_todo_write_without_context() {
|
||||
let ctx = ToolContext::new(); // No todo_list
|
||||
let perms = PermissionManager::new(Mode::Code);
|
||||
|
||||
let arguments = json!({
|
||||
"todos": []
|
||||
});
|
||||
|
||||
let result = execute_tool("todo_write", &arguments, &perms, &ctx).await;
|
||||
assert!(result.is_err(), "TodoWrite should fail without TodoList");
|
||||
assert!(result.unwrap_err().to_string().contains("not available"));
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_bash_output_with_context() {
|
||||
let manager = ShellManager::new();
|
||||
let ctx = ToolContext::new().with_shell_manager(manager.clone());
|
||||
let perms = PermissionManager::new(Mode::Code);
|
||||
|
||||
// Start a shell and run a command
|
||||
let shell_id = manager.start_shell().await.unwrap();
|
||||
let _ = manager.execute(&shell_id, "echo test", None).await.unwrap();
|
||||
|
||||
let arguments = json!({
|
||||
"shell_id": shell_id
|
||||
});
|
||||
|
||||
let result = execute_tool("bash_output", &arguments, &perms, &ctx).await;
|
||||
assert!(result.is_ok(), "BashOutput should succeed: {:?}", result);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_bash_output_without_context() {
|
||||
let ctx = ToolContext::new(); // No shell_manager
|
||||
let perms = PermissionManager::new(Mode::Code);
|
||||
|
||||
let arguments = json!({
|
||||
"shell_id": "fake-id"
|
||||
});
|
||||
|
||||
let result = execute_tool("bash_output", &arguments, &perms, &ctx).await;
|
||||
assert!(result.is_err(), "BashOutput should fail without ShellManager");
|
||||
assert!(result.unwrap_err().to_string().contains("not available"));
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_kill_shell_with_context() {
|
||||
let manager = ShellManager::new();
|
||||
let ctx = ToolContext::new().with_shell_manager(manager.clone());
|
||||
let perms = PermissionManager::new(Mode::Code);
|
||||
|
||||
// Start a shell
|
||||
let shell_id = manager.start_shell().await.unwrap();
|
||||
|
||||
let arguments = json!({
|
||||
"shell_id": shell_id
|
||||
});
|
||||
|
||||
let result = execute_tool("kill_shell", &arguments, &perms, &ctx).await;
|
||||
assert!(result.is_ok(), "KillShell should succeed: {:?}", result);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_ask_user_without_context() {
|
||||
let ctx = ToolContext::new(); // No ask_sender
|
||||
let perms = PermissionManager::new(Mode::Code);
|
||||
|
||||
let arguments = json!({
|
||||
"questions": []
|
||||
});
|
||||
|
||||
let result = execute_tool("ask_user", &arguments, &perms, &ctx).await;
|
||||
assert!(result.is_err(), "AskUser should fail without AskSender");
|
||||
assert!(result.unwrap_err().to_string().contains("not available"));
|
||||
}
|
||||
Reference in New Issue
Block a user