10 Commits

Author SHA1 Message Date
4a07b97eab feat(ui): add autocomplete, command help, and streaming improvements
TUI Enhancements:
- Add autocomplete dropdown with fuzzy filtering for slash commands
- Fix autocomplete: Tab confirms selection, Enter submits message
- Add command help overlay with scroll support (j/k, arrows, Page Up/Down)
- Brighten Tokyo Night theme colors for better readability
- Add todo panel component for task display
- Add rich command output formatting (tables, trees, lists)

Streaming Fixes:
- Refactor to non-blocking background streaming with channel events
- Add StreamStart/StreamEnd/StreamError events
- Fix LlmChunk to append instead of creating new messages
- Display user message immediately before LLM call
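
A minimal sketch of what the channel-based flow above implies — the `AppEvent` name and exact variant payloads are assumptions, not the crate's actual types:

```rust
// Hypothetical sketch: a background task owns the LLM stream and forwards
// events over a channel so the TUI redraw loop never blocks.
use tokio::sync::mpsc::UnboundedSender;

enum AppEvent {
    StreamStart,
    LlmChunk(String), // appended to the last assistant message, not a new one
    StreamEnd,
    StreamError(String),
}

fn spawn_streaming(tx: UnboundedSender<AppEvent>) {
    tokio::spawn(async move {
        let _ = tx.send(AppEvent::StreamStart);
        // ...poll the provider stream, forwarding each delta as it arrives...
        let _ = tx.send(AppEvent::LlmChunk("Hel".to_string()));
        let _ = tx.send(AppEvent::LlmChunk("lo".to_string()));
        let _ = tx.send(AppEvent::StreamEnd);
    });
}
```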

New Components:
- completions.rs: Command completion engine with fuzzy matching
- autocomplete.rs: Inline autocomplete dropdown
- command_help.rs: Modal help overlay with scrolling
- todo_panel.rs: Todo list display panel
- output.rs: Rich formatted output (tables, trees, code blocks)
- commands.rs: Built-in command implementations

Planning Mode Groundwork:
- Add EnterPlanMode/ExitPlanMode tools scaffolding
- Add Skill tool for plugin skill invocation
- Extend permissions with planning mode support
- Add compact.rs stub for context compaction

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude <noreply@anthropic.com>
2025-12-02 19:03:33 +01:00
10c8e2baae feat(v2): complete multi-LLM providers, TUI redesign, and advanced agent features
Multi-LLM Provider Support:
- Add llm-core crate with LlmProvider trait abstraction
- Implement Anthropic Claude API client with streaming
- Implement OpenAI API client with streaming
- Add token counting with SimpleTokenCounter and ClaudeTokenCounter
- Add retry logic with exponential backoff and jitter
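
A hedged sketch of the provider abstraction and retry delay described above; the trait shape, message type, and constants are assumptions rather than the real llm-core API (assumes the `async-trait` and `rand` crates):

```rust
use async_trait::async_trait;
use std::time::Duration;

pub struct ChatMessage {
    pub role: String,
    pub content: String,
}

#[async_trait]
pub trait LlmProvider: Send + Sync {
    /// Provider identifier, e.g. "anthropic" or "openai".
    fn name(&self) -> &str;
    /// Send a conversation and return the assistant reply.
    async fn chat(&self, messages: &[ChatMessage]) -> Result<String, String>;
}

/// Exponential backoff with jitter: base * 2^attempt, plus up to 50% random extra.
pub fn retry_delay(attempt: u32, base_ms: u64) -> Duration {
    let exp = base_ms.saturating_mul(1u64 << attempt.min(10));
    let jitter = (exp as f64 * 0.5 * rand::random::<f64>()) as u64;
    Duration::from_millis(exp + jitter)
}
```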

Borderless TUI Redesign:
- Rewrite theme system with terminal capability detection (Full/Unicode256/Basic)
- Add provider tabs component with keybind switching [1]/[2]/[3]
- Implement vim-modal input (Normal/Insert/Visual/Command modes)
- Redesign chat panel with timestamps and streaming indicators
- Add multi-provider status bar with cost tracking
- Add Nerd Font icons with graceful ASCII fallbacks
- Add syntax highlighting (syntect) and markdown rendering (pulldown-cmark)
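
The capability tiers might be detected roughly like the sketch below; the environment variables checked and the fallback order are guesses, not the crate's actual detection logic:

```rust
/// Assumed sketch of the Full/Unicode256/Basic tiers mentioned above.
#[derive(Clone, Copy, PartialEq, Eq)]
pub enum TermCaps {
    Full,       // truecolor plus Nerd Font glyphs
    Unicode256, // 256 colors, Unicode box drawing, no Nerd Font
    Basic,      // 16 colors, ASCII fallbacks
}

fn detect_caps() -> TermCaps {
    let colorterm = std::env::var("COLORTERM").unwrap_or_default();
    if colorterm.contains("truecolor") || colorterm.contains("24bit") {
        TermCaps::Full
    } else if std::env::var("TERM")
        .map(|t| t.contains("256color"))
        .unwrap_or(false)
    {
        TermCaps::Unicode256
    } else {
        TermCaps::Basic
    }
}
```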

Advanced Agent Features:
- Add system prompt builder with configurable components
- Enhance subagent orchestration with parallel execution
- Add git integration module for safe command detection
- Add streaming tool results via channels
- Expand tool set: AskUserQuestion, TodoWrite, LS, MultiEdit, BashOutput, KillShell
- Add WebSearch with provider abstraction

Plugin System Enhancement:
- Add full agent definition parsing from YAML frontmatter
- Add skill system with progressive disclosure
- Wire plugin hooks into HookManager

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude <noreply@anthropic.com>
2025-12-02 17:24:14 +01:00
09c8c9d83e feat(ui): add TUI with streaming agent integration and theming
Add a new terminal UI crate (crates/app/ui) built with ratatui providing an
interactive chat interface with real-time LLM streaming and tool visualization.

Features:
- Chat panel with horizontal padding for improved readability
- Input box with cursor navigation and command history
- Status bar with session statistics and uniform background styling
- 7 theme presets: Tokyo Night (default), Dracula, Catppuccin, Nord,
  Synthwave, Rose Pine, and Midnight Ocean
- Theme switching via /theme <name> and /themes commands
- Streaming LLM responses that accumulate into single messages
- Real-time tool call visualization with success/error states
- Session tracking (messages, tokens, tool calls, duration)
- REPL commands: /help, /status, /cost, /checkpoint, /rewind, /clear, /exit

Integration:
- CLI automatically launches TUI mode when running interactively (no prompt)
- Falls back to legacy text REPL with --no-tui flag
- Uses existing agent loop with streaming support
- Supports all existing tools (read, write, edit, glob, grep, bash)

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude <noreply@anthropic.com>
2025-11-01 22:57:25 +01:00
5caf502009 feat(M12): complete milestone with plugins, checkpointing, and rewind
Implements the remaining M12 features from AGENTS.md:

**Plugin System (crates/platform/plugins)**
- Plugin manifest schema with plugin.json support
- Plugin loader for commands, agents, skills, hooks, and MCP servers
- Discovers plugins from ~/.config/owlen/plugins and .owlen/plugins
- Includes comprehensive tests (4 passing)

**Session Checkpointing (crates/core/agent)**
- Checkpoint struct capturing session state and file diffs
- CheckpointManager with snapshot, diff, save, load, and rewind capabilities
- File diff tracking with before/after content
- Checkpoint persistence to .owlen/checkpoints/
- Includes comprehensive tests (6 passing)
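
A rough sketch of the shape this implies; field and method names mirror the REPL call sites further down this page (`list_checkpoints`, `file_diffs`, rewind), but the bodies and the `rewind_files` helper are illustrative only:

```rust
use std::{fs, path::PathBuf};

pub struct FileDiff {
    pub path: PathBuf,
    pub before: String,
    pub after: String,
}

pub struct Checkpoint {
    pub id: String,
    pub user_prompts: Vec<String>,
    pub assistant_responses: Vec<String>,
    pub file_diffs: Vec<FileDiff>, // before/after snapshot per tracked file
}

pub struct CheckpointManager {
    dir: PathBuf, // e.g. .owlen/checkpoints/
}

impl CheckpointManager {
    pub fn new(dir: PathBuf) -> Self {
        Self { dir }
    }

    /// List checkpoint ids by reading the checkpoint directory.
    pub fn list_checkpoints(&self) -> std::io::Result<Vec<String>> {
        let mut ids = Vec::new();
        for entry in fs::read_dir(&self.dir)? {
            ids.push(entry?.file_name().to_string_lossy().into_owned());
        }
        Ok(ids)
    }

    /// Rewind restores each tracked file to its `before` content.
    pub fn rewind_files(&self, cp: &Checkpoint) -> std::io::Result<Vec<PathBuf>> {
        let mut restored = Vec::new();
        for diff in &cp.file_diffs {
            fs::write(&diff.path, &diff.before)?;
            restored.push(diff.path.clone());
        }
        Ok(restored)
    }
}
```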

**REPL Commands (crates/app/cli)**
- /checkpoint - Save current session with file diffs
- /checkpoints - List all saved checkpoints
- /rewind <id> - Restore session and files from checkpoint
- Updated /help documentation

M12 milestone now fully complete:
- /permissions, /status, /cost (previously implemented)
- Checkpointing and /rewind
- Plugin loader with manifest schema

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude <noreply@anthropic.com>
2025-11-01 21:59:08 +01:00
04a7085007 feat(repl): implement M12 REPL commands and session tracking
Add comprehensive REPL commands for session management and introspection:

**Session Tracking** (`crates/core/agent/src/session.rs`):
- SessionStats: Track messages, tool calls, tokens, timing
- SessionHistory: Store conversation history and tool call records
- Auto-formatting for durations (seconds, minutes, hours)

**REPL Commands** (in interactive mode):
- `/help`        - List all available commands
- `/status`      - Show session stats (messages, tools, uptime)
- `/permissions` - Display permission mode and tool access
- `/cost`        - Show token usage and timing (free with Ollama!)
- `/history`     - View conversation history
- `/clear`       - Reset session state
- `/exit`        - Exit interactive mode gracefully

**Stats Tracking**:
- Automatic message counting
- Token estimation (chars / 4)
- Duration tracking per message
- Tool call counting (foundation for future work)
- Session uptime from start
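
A sketch of the stats tracking described above; field names follow the call sites visible elsewhere in this log, while the method bodies are illustrative:

```rust
use std::time::{Duration, SystemTime};

pub struct SessionStats {
    pub total_messages: usize,
    pub total_tool_calls: usize,
    pub estimated_tokens: usize,
    pub total_duration: Duration,
    pub start_time: SystemTime,
}

impl SessionStats {
    pub fn new() -> Self {
        Self {
            total_messages: 0,
            total_tool_calls: 0,
            estimated_tokens: 0,
            total_duration: Duration::ZERO,
            start_time: SystemTime::now(),
        }
    }

    /// Record one exchange; tokens are estimated upstream as chars / 4.
    pub fn record_message(&mut self, estimated_tokens: usize, duration: Duration) {
        self.total_messages += 1;
        self.estimated_tokens += estimated_tokens;
        self.total_duration += duration;
    }

    /// Format 95s as "1m 35s", 3700s as "1h 1m", and so on.
    pub fn format_duration(d: Duration) -> String {
        let secs = d.as_secs();
        match secs {
            0..=59 => format!("{}s", secs),
            60..=3599 => format!("{}m {}s", secs / 60, secs % 60),
            _ => format!("{}h {}m", secs / 3600, (secs % 3600) / 60),
        }
    }
}
```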

**Permission Display**:
- Shows current mode (Plan/AcceptEdits/Code)
- Lists tools by category (read-only, write, system)
- Indicates which tools are allowed/ask/deny

**UX Improvements**:
- Welcome message shows model and mode
- Clean command output with emoji indicators
- Helpful error messages for unknown commands
- Session stats persist across messages

**Example Session**:
```
🤖 Owlen Interactive Mode
Model: qwen3:8b
Mode: Plan

> /help
📖 Available Commands: [list]

> Find all Cargo.toml files
🔧 Tool call: glob...
 Tool result: 14 files

> /status
📊 Session Status:
  Messages: 1
  Tools: 1 calls
  Uptime: 15s

> /cost
💰 Token Usage: ~234 tokens

> /exit
👋 Goodbye!
```

Implements core M12 requirements for REPL commands and session management.
Future: Checkpointing/rewind functionality can build on this foundation.

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude <noreply@anthropic.com>
2025-11-01 21:05:29 +01:00
6022aeb2b0 feat(cli): add interactive REPL mode with agent loop
Add proper interactive mode when no prompt is provided:

**Interactive REPL Features**:
- Starts when running `cargo run` with no arguments
- Shows welcome message with model name
- Prompts with `> ` for user input
- Each input runs through the full agent loop with tools
- Continues until Ctrl+C or EOF
- Displays tool calls and results in real-time

**Changes**:
- Detect empty prompt and enter interactive loop
- Use stdin.lines() for reading user input
- Call agent_core::run_agent_loop for each message
- Handle errors gracefully and continue
- Clean up unused imports

**Usage**:
```bash
# Interactive mode
cargo run

# Single prompt mode
cargo run -- --print "Find all Cargo.toml files"

# Tool subcommands
cargo run -- glob "**/*.rs"
```

Example session:
```
🤖 Owlen Interactive Mode
Model: qwen3:8b

> Find all markdown files
🔧 Tool call: glob with args: {"pattern":"**/*.md"}
 Tool result: ./README.md ./CLAUDE.md ./AGENTS.md
...

> exit
```

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude <noreply@anthropic.com>
2025-11-01 21:00:56 +01:00
e77e33ce2f feat(agent): implement Agent Orchestrator with LLM tool calling
Add complete agent orchestration system that enables LLM to call tools:

**Core Agent System** (`crates/core/agent`):
- Agent execution loop with tool call/result cycle
- Tool definitions in Ollama-compatible format (6 tools)
- Tool execution with permission checking
- Multi-iteration support with max iteration safety

**Tool Definitions**:
- read: Read file contents
- glob: Find files by pattern
- grep: Search for patterns in files
- write: Write content to files
- edit: Edit files with find/replace
- bash: Execute bash commands
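
The loop repeatedly sends these definitions with the chat request, executes any tool calls the model returns, and feeds the results back until the model answers in plain text or the iteration cap is hit. One definition in the Ollama-style function format might look like the sketch below (the description strings and parameter schema wording are illustrative, not copied from the crate):

```rust
use serde_json::json;

/// Hypothetical rendering of the `glob` tool definition sent to Ollama.
fn glob_tool_definition() -> serde_json::Value {
    json!({
        "type": "function",
        "function": {
            "name": "glob",
            "description": "Find files matching a glob pattern",
            "parameters": {
                "type": "object",
                "properties": {
                    "pattern": {
                        "type": "string",
                        "description": "Glob pattern, e.g. **/*.rs"
                    }
                },
                "required": ["pattern"]
            }
        }
    })
}
```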

**Ollama Integration Updates**:
- Extended ChatMessage to support tool_calls
- Added Tool, ToolCall, ToolFunction types
- Updated chat_stream to accept tools parameter
- Made tool call fields optional for Ollama compatibility

**CLI Integration**:
- Wired agent loop into all output formats (Text, JSON, StreamJSON)
- Tool calls displayed with a 🔧 icon, results with a status icon
- Replaced simple chat with agent orchestrator

**Permission Integration**:
- All tool executions check permissions before running
- Respects plan/acceptEdits/code modes
- Returns clear error messages for denied operations

**Example**:
User: "Find all Cargo.toml files in the workspace"
LLM: Calls glob("**/Cargo.toml")
Agent: Executes and returns 14 files
LLM: Formats human-readable response

This transforms owlen from a passive chatbot into an active agent that
can autonomously use tools to accomplish user goals.

Tested with: qwen3:8b successfully calling glob tool

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude <noreply@anthropic.com>
2025-11-01 20:56:56 +01:00
f87e5d2796 feat(tools): implement M11 subagent system with task routing
Add tools-task crate with subagent registry and tool whitelist system:

Core Features:
- Subagent struct with name, description, keywords, and allowed tools
- SubagentRegistry for managing and selecting subagents
- Tool whitelist validation per subagent
- Keyword-based task matching and agent selection
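
A minimal sketch of keyword-based selection; the struct layout follows the bullets above, while the scoring rule is an assumption:

```rust
pub struct Subagent {
    pub name: String,
    pub description: String,
    pub keywords: Vec<String>,
    pub allowed_tools: Vec<String>,
}

impl Subagent {
    /// Enforce the per-agent tool whitelist.
    pub fn is_tool_allowed(&self, tool: &str) -> bool {
        self.allowed_tools.iter().any(|t| t == tool)
    }
}

pub struct SubagentRegistry {
    agents: Vec<Subagent>,
}

impl SubagentRegistry {
    /// Pick the agent whose keywords appear most often in the task description.
    pub fn select(&self, task: &str) -> Option<&Subagent> {
        let task_lower = task.to_lowercase();
        self.agents
            .iter()
            .map(|a| {
                let hits = a
                    .keywords
                    .iter()
                    .filter(|k| task_lower.contains(&k.to_lowercase()))
                    .count();
                (hits, a)
            })
            .filter(|(hits, _)| *hits > 0)
            .max_by_key(|(hits, _)| *hits)
            .map(|(_, a)| a)
    }
}
```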

Built-in Subagents:
- code-reviewer: Read-only code analysis (Read, Grep, Glob)
- test-writer: Test file creation (Read, Write, Edit, Grep, Glob)
- doc-writer: Documentation management (Read, Write, Edit, Grep, Glob)
- refactorer: Code restructuring (Read, Write, Edit, Grep, Glob)

Test Coverage:
- Subagent tool whitelist enforcement
- Keyword matching for task descriptions
- Registry selection based on task description
- Tool validation for specific agents
- Error handling for nonexistent agents

Implements M11 from AGENTS.md for specialized agents with limited tool access.

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude <noreply@anthropic.com>
2025-11-01 20:37:37 +01:00
3c436fda54 feat(tools): implement M10 Jupyter notebook support
Add tools-notebook crate with full Jupyter notebook (.ipynb) support:

- Core data structures: Notebook, Cell, NotebookMetadata, Output
- Read/write operations with metadata preservation
- Edit operations: EditCell, AddCell, DeleteCell
- Helper functions: new_code_cell, new_markdown_cell, cell_source_as_string
- Comprehensive test suite: 9 tests covering round-trip, editing, and error handling
- Permission integration: NotebookRead (plan mode), NotebookEdit (acceptEdits mode)
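
An assumed sketch of the .ipynb mapping, collapsing NotebookMetadata and Output into raw JSON values for brevity; the real crate's types may differ:

```rust
use serde::{Deserialize, Serialize};

#[derive(Serialize, Deserialize)]
pub struct Notebook {
    pub cells: Vec<Cell>,
    pub metadata: serde_json::Value, // preserved verbatim on round-trip
    pub nbformat: u32,
    pub nbformat_minor: u32,
}

#[derive(Serialize, Deserialize)]
pub struct Cell {
    pub cell_type: String,   // "code" | "markdown"
    pub source: Vec<String>, // .ipynb stores source as a list of lines
    #[serde(default)]
    pub outputs: Vec<serde_json::Value>, // code cells only
    pub metadata: serde_json::Value,
}
```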

Implements M10 from AGENTS.md for LLM-driven notebook editing.

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude <noreply@anthropic.com>
2025-11-01 20:33:28 +01:00
173403379f feat(M9): implement WebFetch and WebSearch with domain filtering and pluggable providers
Milestone M9 implementation adds web access tools with security controls.

New crate: crates/tools/web

WebFetch Features:
- HTTP client using reqwest
- Domain allowlist/blocklist filtering
  * Empty allowlist = allow all domains (except blocked)
  * Non-empty allowlist = only allow specified domains
  * Blocklist always takes precedence
- Redirect detection and blocking
  * Redirects to unapproved domains are blocked
  * Manual redirect policy (no automatic following)
  * Returns error message with redirect URL
- Response capture with metadata
  * Status code, content, content-type
  * Original URL preserved
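
A sketch of the allow/block rules spelled out above; the struct and method names are illustrative, not the crate's API:

```rust
pub struct DomainFilter {
    pub allowlist: Vec<String>,
    pub blocklist: Vec<String>,
}

impl DomainFilter {
    /// Blocklist always wins; an empty allowlist means "allow everything else".
    pub fn is_allowed(&self, host: &str) -> bool {
        let host = host.to_lowercase();
        if self.blocklist.iter().any(|d| d.to_lowercase() == host) {
            return false;
        }
        self.allowlist.is_empty()
            || self.allowlist.iter().any(|d| d.to_lowercase() == host)
    }
}
```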

WebSearch Features:
- Pluggable provider trait using async-trait
- SearchProvider trait for implementing search APIs
- StubSearchProvider for testing
- SearchResult structure with title, URL, snippet
- Provider name identification
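
The pluggable provider trait presumably looks something like this sketch (assumes the `async-trait` crate; exact method signatures are guesses):

```rust
use async_trait::async_trait;

pub struct SearchResult {
    pub title: String,
    pub url: String,
    pub snippet: String,
}

#[async_trait]
pub trait SearchProvider: Send + Sync {
    /// Provider identifier, e.g. "stub".
    fn name(&self) -> &str;
    /// Run a query and return ranked results.
    async fn search(&self, query: &str, max_results: usize) -> Result<Vec<SearchResult>, String>;
}
```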

Security Features:
- Case-insensitive domain matching
- Host extraction from URLs
- Relative redirect URL resolution
- Domain validation before requests
- Explicit approval required for cross-domain redirects

Tests added (9 new tests):
Unit tests:
1. domain_filtering_allowlist - Verifies allowlist-only mode
2. domain_filtering_blocklist - Verifies blocklist takes precedence
3. domain_filtering_case_insensitive - Verifies case handling

Integration tests with wiremock:
4. webfetch_domain_whitelist_only - Tests allowlist enforcement
5. webfetch_redirect_to_unapproved_domain - Blocks bad redirects
6. webfetch_redirect_to_approved_domain - Detects good redirects
7. webfetch_blocklist_overrides_allowlist - Blocklist priority
8. websearch_pluggable_provider - Provider pattern works
9. webfetch_successful_request - Basic fetch operation

All 84 tests passing (up from 75).

Note: CLI integration deferred - infrastructure is complete and tested.
Future work will add CLI commands for web-fetch and web-search with
domain configuration.

Dependencies: reqwest 0.12, async-trait 0.1, wiremock 0.6 (test)

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude <noreply@anthropic.com>
2025-11-01 20:23:29 +01:00
84 changed files with 18926 additions and 125 deletions

View File

@@ -1,13 +1,26 @@
[workspace]
members = [
"crates/app/cli",
"crates/app/ui",
"crates/core/agent",
"crates/llm/core",
"crates/llm/anthropic",
"crates/llm/ollama",
"crates/llm/openai",
"crates/platform/config",
"crates/platform/hooks",
"crates/platform/permissions",
"crates/platform/plugins",
"crates/tools/ask",
"crates/tools/bash",
"crates/tools/fs",
"crates/tools/notebook",
"crates/tools/plan",
"crates/tools/skill",
"crates/tools/slash",
"crates/tools/task",
"crates/tools/todo",
"crates/tools/web",
"crates/integration/mcp-client",
]
resolver = "2"

View File

@@ -11,6 +11,8 @@ tokio = { version = "1.39", features = ["macros", "rt-multi-thread"] }
serde = { version = "1", features = ["derive"] }
serde_json = "1"
color-eyre = "0.6"
agent-core = { path = "../../core/agent" }
llm-core = { path = "../../llm/core" }
llm-ollama = { path = "../../llm/ollama" }
tools-fs = { path = "../../tools/fs" }
tools-bash = { path = "../../tools/bash" }
@@ -18,6 +20,9 @@ tools-slash = { path = "../../tools/slash" }
config-agent = { package = "config-agent", path = "../../platform/config" }
permissions = { path = "../../platform/permissions" }
hooks = { path = "../../platform/hooks" }
plugins = { path = "../../platform/plugins" }
ui = { path = "../ui" }
atty = "0.2"
futures-util = "0.3.31"
[dev-dependencies]

View File

@@ -0,0 +1,382 @@
//! Built-in commands for CLI and TUI
//!
//! Provides handlers for /help, /mcp, /hooks, /clear, and other built-in commands.
use ui::{CommandInfo, CommandOutput, OutputFormat, TreeNode, ListItem};
use permissions::PermissionManager;
use hooks::HookManager;
use plugins::PluginManager;
use agent_core::SessionStats;
/// Result of executing a built-in command
pub enum CommandResult {
/// Command produced output to display
Output(CommandOutput),
/// Command was handled but produced no output (e.g., /clear)
Handled,
/// Command was not recognized
NotFound,
/// Command needs to exit the session
Exit,
}
/// Built-in command handler
pub struct BuiltinCommands<'a> {
plugin_manager: Option<&'a PluginManager>,
hook_manager: Option<&'a HookManager>,
permission_manager: Option<&'a PermissionManager>,
stats: Option<&'a SessionStats>,
}
impl<'a> BuiltinCommands<'a> {
pub fn new() -> Self {
Self {
plugin_manager: None,
hook_manager: None,
permission_manager: None,
stats: None,
}
}
pub fn with_plugins(mut self, pm: &'a PluginManager) -> Self {
self.plugin_manager = Some(pm);
self
}
pub fn with_hooks(mut self, hm: &'a HookManager) -> Self {
self.hook_manager = Some(hm);
self
}
pub fn with_permissions(mut self, perms: &'a PermissionManager) -> Self {
self.permission_manager = Some(perms);
self
}
pub fn with_stats(mut self, stats: &'a SessionStats) -> Self {
self.stats = Some(stats);
self
}
/// Execute a built-in command
pub fn execute(&self, command: &str) -> CommandResult {
let parts: Vec<&str> = command.split_whitespace().collect();
let cmd = parts.first().map(|s| s.trim_start_matches('/'));
match cmd {
Some("help") | Some("?") => CommandResult::Output(self.help()),
Some("mcp") => CommandResult::Output(self.mcp()),
Some("hooks") => CommandResult::Output(self.hooks()),
Some("plugins") => CommandResult::Output(self.plugins()),
Some("status") => CommandResult::Output(self.status()),
Some("permissions") | Some("perms") => CommandResult::Output(self.permissions()),
Some("clear") => CommandResult::Handled,
Some("exit") | Some("quit") | Some("q") => CommandResult::Exit,
_ => CommandResult::NotFound,
}
}
/// Generate help output
fn help(&self) -> CommandOutput {
let mut commands = vec![
// Built-in commands
CommandInfo::new("help", "Show available commands", "builtin"),
CommandInfo::new("clear", "Clear the screen", "builtin"),
CommandInfo::new("status", "Show session status", "builtin"),
CommandInfo::new("permissions", "Show permission settings", "builtin"),
CommandInfo::new("mcp", "List MCP servers and tools", "builtin"),
CommandInfo::new("hooks", "Show loaded hooks", "builtin"),
CommandInfo::new("plugins", "Show loaded plugins", "builtin"),
CommandInfo::new("checkpoint", "Save session state", "builtin"),
CommandInfo::new("checkpoints", "List saved checkpoints", "builtin"),
CommandInfo::new("rewind", "Restore from checkpoint", "builtin"),
CommandInfo::new("compact", "Compact conversation context", "builtin"),
CommandInfo::new("exit", "Exit the session", "builtin"),
];
// Add plugin commands
if let Some(pm) = self.plugin_manager {
for plugin in pm.plugins() {
for cmd_name in plugin.all_command_names() {
commands.push(CommandInfo::new(
&cmd_name,
&format!("Plugin command from {}", plugin.manifest.name),
&format!("plugin:{}", plugin.manifest.name),
));
}
}
}
CommandOutput::help_table(&commands)
}
/// Generate MCP servers output
fn mcp(&self) -> CommandOutput {
let mut servers: Vec<(String, Vec<String>)> = vec![];
// Get MCP servers from plugins
if let Some(pm) = self.plugin_manager {
for plugin in pm.plugins() {
// Check for .mcp.json in plugin directory
let mcp_path = plugin.base_path.join(".mcp.json");
if mcp_path.exists() {
if let Ok(content) = std::fs::read_to_string(&mcp_path) {
if let Ok(config) = serde_json::from_str::<serde_json::Value>(&content) {
if let Some(mcpservers) = config.get("mcpServers").and_then(|v| v.as_object()) {
for (name, _) in mcpservers {
servers.push((
format!("{} ({})", name, plugin.manifest.name),
vec!["(connect to discover tools)".to_string()],
));
}
}
}
}
}
}
}
if servers.is_empty() {
CommandOutput::new(OutputFormat::Text {
content: "No MCP servers configured.\n\nAdd MCP servers in plugin .mcp.json files.".to_string(),
})
} else {
CommandOutput::mcp_tree(&servers)
}
}
/// Generate hooks output
fn hooks(&self) -> CommandOutput {
let mut hooks_list: Vec<(String, String, bool)> = vec![];
// Check for file-based hooks in .owlen/hooks/
let hook_events = ["PreToolUse", "PostToolUse", "SessionStart", "SessionEnd",
"UserPromptSubmit", "PreCompact", "Stop", "SubagentStop"];
for event in hook_events {
let path = format!(".owlen/hooks/{}", event);
let exists = std::path::Path::new(&path).exists();
if exists {
hooks_list.push((event.to_string(), path, true));
}
}
// Get hooks from plugins
if let Some(pm) = self.plugin_manager {
for plugin in pm.plugins() {
if let Some(hooks_config) = plugin.load_hooks_config().ok().flatten() {
// hooks_config.hooks is HashMap<String, Vec<HookMatcher>>
for (event_name, matchers) in &hooks_config.hooks {
for matcher in matchers {
for hook_def in &matcher.hooks {
let cmd = hook_def.command.as_deref()
.or(hook_def.prompt.as_deref())
.unwrap_or("(no command)");
hooks_list.push((
event_name.clone(),
format!("{}: {}", plugin.manifest.name, cmd),
true,
));
}
}
}
}
}
}
if hooks_list.is_empty() {
CommandOutput::new(OutputFormat::Text {
content: "No hooks configured.\n\nAdd hooks in .owlen/hooks/ or plugin hooks.json files.".to_string(),
})
} else {
CommandOutput::hooks_list(&hooks_list)
}
}
/// Generate plugins output
fn plugins(&self) -> CommandOutput {
if let Some(pm) = self.plugin_manager {
let plugins = pm.plugins();
if plugins.is_empty() {
return CommandOutput::new(OutputFormat::Text {
content: "No plugins loaded.\n\nPlace plugins in:\n - ~/.config/owlen/plugins (user)\n - .owlen/plugins (project)".to_string(),
});
}
// Build tree of plugins and their components
let children: Vec<TreeNode> = plugins.iter().map(|p| {
let mut plugin_children = vec![];
let commands = p.all_command_names();
if !commands.is_empty() {
plugin_children.push(TreeNode::new("Commands").with_children(
commands.iter().map(|c| TreeNode::new(format!("/{}", c))).collect()
));
}
let agents = p.all_agent_names();
if !agents.is_empty() {
plugin_children.push(TreeNode::new("Agents").with_children(
agents.iter().map(|a| TreeNode::new(a)).collect()
));
}
let skills = p.all_skill_names();
if !skills.is_empty() {
plugin_children.push(TreeNode::new("Skills").with_children(
skills.iter().map(|s| TreeNode::new(s)).collect()
));
}
TreeNode::new(format!("{} v{}", p.manifest.name, p.manifest.version))
.with_children(plugin_children)
}).collect();
CommandOutput::new(OutputFormat::Tree {
root: TreeNode::new("Loaded Plugins").with_children(children),
})
} else {
CommandOutput::new(OutputFormat::Text {
content: "Plugin manager not available.".to_string(),
})
}
}
/// Generate status output
fn status(&self) -> CommandOutput {
let mut items = vec![];
if let Some(stats) = self.stats {
items.push(ListItem {
text: format!("Messages: {}", stats.total_messages),
marker: Some("📊".to_string()),
style: None,
});
items.push(ListItem {
text: format!("Tool Calls: {}", stats.total_tool_calls),
marker: Some("🔧".to_string()),
style: None,
});
items.push(ListItem {
text: format!("Est. Tokens: ~{}", stats.estimated_tokens),
marker: Some("📝".to_string()),
style: None,
});
let uptime = stats.start_time.elapsed().unwrap_or_default();
items.push(ListItem {
text: format!("Uptime: {}", SessionStats::format_duration(uptime)),
marker: Some("⏱️".to_string()),
style: None,
});
}
if let Some(perms) = self.permission_manager {
items.push(ListItem {
text: format!("Mode: {:?}", perms.mode()),
marker: Some("🔒".to_string()),
style: None,
});
}
if items.is_empty() {
CommandOutput::new(OutputFormat::Text {
content: "Session status not available.".to_string(),
})
} else {
CommandOutput::new(OutputFormat::List { items })
}
}
/// Generate permissions output
fn permissions(&self) -> CommandOutput {
if let Some(perms) = self.permission_manager {
let mode = perms.mode();
let mode_str = format!("{:?}", mode);
let mut items = vec![
ListItem {
text: format!("Current Mode: {}", mode_str),
marker: Some("🔒".to_string()),
style: None,
},
];
// Add tool permissions summary
let (read_status, write_status, bash_status) = match mode {
permissions::Mode::Plan => ("✅ Allowed", "❓ Ask", "❓ Ask"),
permissions::Mode::AcceptEdits => ("✅ Allowed", "✅ Allowed", "❓ Ask"),
permissions::Mode::Code => ("✅ Allowed", "✅ Allowed", "✅ Allowed"),
};
items.push(ListItem {
text: format!("Read/Grep/Glob: {}", read_status),
marker: None,
style: None,
});
items.push(ListItem {
text: format!("Write/Edit: {}", write_status),
marker: None,
style: None,
});
items.push(ListItem {
text: format!("Bash: {}", bash_status),
marker: None,
style: None,
});
CommandOutput::new(OutputFormat::List { items })
} else {
CommandOutput::new(OutputFormat::Text {
content: "Permission manager not available.".to_string(),
})
}
}
}
impl Default for BuiltinCommands<'_> {
fn default() -> Self {
Self::new()
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_help_command() {
let handler = BuiltinCommands::new();
match handler.execute("/help") {
CommandResult::Output(output) => {
match output.format {
OutputFormat::Table { headers, rows } => {
assert!(!headers.is_empty());
assert!(!rows.is_empty());
}
_ => panic!("Expected Table format"),
}
}
_ => panic!("Expected Output result"),
}
}
#[test]
fn test_exit_command() {
let handler = BuiltinCommands::new();
assert!(matches!(handler.execute("/exit"), CommandResult::Exit));
assert!(matches!(handler.execute("/quit"), CommandResult::Exit));
assert!(matches!(handler.execute("/q"), CommandResult::Exit));
}
#[test]
fn test_clear_command() {
let handler = BuiltinCommands::new();
assert!(matches!(handler.execute("/clear"), CommandResult::Handled));
}
#[test]
fn test_unknown_command() {
let handler = BuiltinCommands::new();
assert!(matches!(handler.execute("/unknown"), CommandResult::NotFound));
}
}

View File

@@ -1,14 +1,19 @@
mod commands;
use clap::{Parser, ValueEnum};
use color_eyre::eyre::{Result, eyre};
use config_agent::load_settings;
use futures_util::TryStreamExt;
use hooks::{HookEvent, HookManager, HookResult};
use llm_ollama::{OllamaClient, OllamaOptions, types::ChatMessage};
use llm_core::ChatOptions;
use llm_ollama::OllamaClient;
use permissions::{PermissionDecision, Tool};
use plugins::PluginManager;
use serde::Serialize;
use std::io::{self, Write};
use std::io::Write;
use std::time::{SystemTime, UNIX_EPOCH};
pub use commands::{BuiltinCommands, CommandResult};
#[derive(Debug, Clone, Copy, ValueEnum)]
enum OutputFormat {
Text,
@@ -49,6 +54,51 @@ struct StreamEvent {
stats: Option<Stats>,
}
/// Application context shared across the session
pub struct AppContext {
pub plugin_manager: PluginManager,
pub config: config_agent::Settings,
}
impl AppContext {
pub fn new() -> Result<Self> {
let config = load_settings(None).unwrap_or_default();
let mut plugin_manager = PluginManager::new();
// Non-fatal: just log warnings, don't fail startup
if let Err(e) = plugin_manager.load_all() {
eprintln!("Warning: Failed to load some plugins: {}", e);
}
Ok(Self {
plugin_manager,
config,
})
}
/// Print loaded plugins and available commands
pub fn print_plugin_info(&self) {
let plugins = self.plugin_manager.plugins();
if !plugins.is_empty() {
println!("\nLoaded {} plugin(s):", plugins.len());
for plugin in plugins {
println!(" - {} v{}", plugin.manifest.name, plugin.manifest.version);
if let Some(desc) = &plugin.manifest.description {
println!(" {}", desc);
}
}
}
let commands = self.plugin_manager.all_commands();
if !commands.is_empty() {
println!("\nAvailable plugin commands:");
for (name, _path) in &commands {
println!(" /{}", name);
}
}
}
}
fn generate_session_id() -> String {
let timestamp = SystemTime::now()
.duration_since(UNIX_EPOCH)
@@ -150,6 +200,9 @@ struct Args {
/// Output format (text, json, stream-json)
#[arg(long, value_enum, default_value = "text")]
output_format: OutputFormat,
/// Disable TUI and use legacy text-based REPL
#[arg(long)]
no_tui: bool,
#[arg()]
prompt: Vec<String>,
#[command(subcommand)]
@@ -160,7 +213,10 @@ struct Args {
async fn main() -> Result<()> {
color_eyre::install()?;
let args = Args::parse();
let mut settings = load_settings(None).unwrap_or_default();
// Initialize application context with plugins
let app_context = AppContext::new()?;
let mut settings = app_context.config.clone();
// Override mode if specified via CLI
if let Some(mode) = args.mode {
@@ -171,7 +227,16 @@ async fn main() -> Result<()> {
let perms = settings.create_permission_manager();
// Create hook manager
let hook_mgr = HookManager::new(".");
let mut hook_mgr = HookManager::new(".");
// Register plugin hooks
for plugin in app_context.plugin_manager.plugins() {
if let Ok(Some(hooks_config)) = plugin.load_hooks_config() {
for (event, command, pattern, timeout) in plugin.register_hooks_with_manager(&hooks_config) {
hook_mgr.register_hook(event, command, pattern, timeout);
}
}
}
// Generate session ID
let session_id = generate_session_id();
@@ -395,19 +460,20 @@ async fn main() -> Result<()> {
HookResult::Allow => {}
}
// Look for command file in .owlen/commands/
let command_path = format!(".owlen/commands/{}.md", command_name);
// Look for command file in .owlen/commands/ first
let local_command_path = format!(".owlen/commands/{}.md", command_name);
// Read the command file
let content = match tools_fs::read_file(&command_path) {
Ok(c) => c,
Err(_) => {
return Err(eyre!(
"Slash command '{}' not found at {}",
command_name,
command_path
));
}
// Try local commands first, then plugin commands
let content = if let Ok(c) = tools_fs::read_file(&local_command_path) {
c
} else if let Some(plugin_path) = app_context.plugin_manager.all_commands().get(&command_name) {
// Found in plugins
tools_fs::read_file(&plugin_path.to_string_lossy())?
} else {
return Err(eyre!(
"Slash command '{}' not found in .owlen/commands/ or plugins",
command_name
));
};
// Parse with arguments
@@ -435,76 +501,316 @@ async fn main() -> Result<()> {
}
}
let prompt = if args.prompt.is_empty() {
"Say hello".to_string()
} else {
args.prompt.join(" ")
};
let model = args.model.unwrap_or(settings.model);
let api_key = args.api_key.or(settings.api_key);
let model = args.model.unwrap_or(settings.model.clone());
let api_key = args.api_key.or(settings.api_key.clone());
// Use Ollama Cloud when model has "-cloud" suffix AND API key is set
let use_cloud = model.ends_with("-cloud") && api_key.is_some();
let client = if use_cloud {
OllamaClient::with_cloud().with_api_key(api_key.unwrap())
} else {
let base_url = args.ollama_url.unwrap_or(settings.ollama_url);
let base_url = args.ollama_url.unwrap_or(settings.ollama_url.clone());
let mut client = OllamaClient::new(base_url);
if let Some(key) = api_key {
client = client.with_api_key(key);
}
client
};
let opts = OllamaOptions {
model,
stream: true,
};
let opts = ChatOptions::new(model);
let msgs = vec![ChatMessage {
role: "user".into(),
content: prompt.clone(),
}];
// Check if interactive mode (no prompt provided)
if args.prompt.is_empty() {
// Use TUI mode unless --no-tui flag is set or not a TTY
if !args.no_tui && atty::is(atty::Stream::Stdout) {
// Launch TUI
// Note: For now, TUI doesn't use plugin manager directly
// In the future, we'll integrate plugin commands into TUI
return ui::run(client, opts, perms, settings).await;
}
// Legacy text-based REPL
println!("🤖 Owlen Interactive Mode");
println!("Model: {}", opts.model);
println!("Mode: {:?}", settings.mode);
// Show loaded plugins
let plugins = app_context.plugin_manager.plugins();
if !plugins.is_empty() {
println!("Plugins: {} loaded", plugins.len());
}
println!("Type your message or /help for commands. Press Ctrl+C to exit.\n");
use std::io::{stdin, BufRead};
let stdin = stdin();
let mut lines = stdin.lock().lines();
let mut stats = agent_core::SessionStats::new();
let mut history = agent_core::SessionHistory::new();
let mut checkpoint_mgr = agent_core::CheckpointManager::new(
std::path::PathBuf::from(".owlen/checkpoints")
);
loop {
print!("> ");
std::io::stdout().flush().ok();
if let Some(Ok(line)) = lines.next() {
let input = line.trim();
if input.is_empty() {
continue;
}
// Handle slash commands
if input.starts_with('/') {
match input {
"/help" => {
println!("\n📖 Available Commands:");
println!(" /help - Show this help message");
println!(" /status - Show session status");
println!(" /permissions - Show permission settings");
println!(" /cost - Show token usage and timing");
println!(" /history - Show conversation history");
println!(" /checkpoint - Save current session state");
println!(" /checkpoints - List all saved checkpoints");
println!(" /rewind <id> - Restore session from checkpoint");
println!(" /clear - Clear conversation history");
println!(" /plugins - Show loaded plugins and commands");
println!(" /exit - Exit interactive mode");
// Show plugin commands if any are loaded
let plugin_commands = app_context.plugin_manager.all_commands();
if !plugin_commands.is_empty() {
println!("\n📦 Plugin Commands:");
for (name, _path) in &plugin_commands {
println!(" /{}", name);
}
}
}
"/status" => {
println!("\n📊 Session Status:");
println!(" Model: {}", opts.model);
println!(" Mode: {:?}", settings.mode);
println!(" Messages: {}", stats.total_messages);
println!(" Tools: {} calls", stats.total_tool_calls);
let elapsed = stats.start_time.elapsed().unwrap_or_default();
println!(" Uptime: {}", agent_core::SessionStats::format_duration(elapsed));
}
"/permissions" => {
println!("\n🔒 Permission Settings:");
println!(" Mode: {:?}", perms.mode());
println!("\n Read-only tools: Read, Grep, Glob, NotebookRead");
match perms.mode() {
permissions::Mode::Plan => {
println!(" ✅ Allowed (plan mode)");
println!("\n Write tools: Write, Edit, NotebookEdit");
println!(" ❓ Ask permission");
println!("\n System tools: Bash");
println!(" ❓ Ask permission");
}
permissions::Mode::AcceptEdits => {
println!(" ✅ Allowed");
println!("\n Write tools: Write, Edit, NotebookEdit");
println!(" ✅ Allowed (acceptEdits mode)");
println!("\n System tools: Bash");
println!(" ❓ Ask permission");
}
permissions::Mode::Code => {
println!(" ✅ Allowed");
println!("\n Write tools: Write, Edit, NotebookEdit");
println!(" ✅ Allowed (code mode)");
println!("\n System tools: Bash");
println!(" ✅ Allowed (code mode)");
}
}
}
"/cost" => {
println!("\n💰 Token Usage & Timing:");
println!(" Est. Tokens: ~{}", stats.estimated_tokens);
println!(" Total Time: {}", agent_core::SessionStats::format_duration(stats.total_duration));
if stats.total_messages > 0 {
let avg_time = stats.total_duration / stats.total_messages as u32;
println!(" Avg/Message: {}", agent_core::SessionStats::format_duration(avg_time));
}
println!("\n Note: Ollama is free - no cost incurred!");
}
"/history" => {
println!("\n📜 Conversation History:");
if history.user_prompts.is_empty() {
println!(" (No messages yet)");
} else {
for (i, (user, assistant)) in history.user_prompts.iter()
.zip(history.assistant_responses.iter()).enumerate() {
println!("\n [{}] User: {}", i + 1, user);
println!(" Assistant: {}...",
assistant.chars().take(100).collect::<String>());
}
}
if !history.tool_calls.is_empty() {
println!("\n Tool Calls: {}", history.tool_calls.len());
}
}
"/checkpoint" => {
let checkpoint_id = format!("checkpoint-{}",
SystemTime::now()
.duration_since(UNIX_EPOCH)
.unwrap()
.as_secs()
);
match checkpoint_mgr.save_checkpoint(
checkpoint_id.clone(),
stats.clone(),
&history,
) {
Ok(checkpoint) => {
println!("\n💾 Checkpoint saved: {}", checkpoint_id);
if !checkpoint.file_diffs.is_empty() {
println!(" Files tracked: {}", checkpoint.file_diffs.len());
}
}
Err(e) => {
eprintln!("\n❌ Failed to save checkpoint: {}", e);
}
}
}
"/checkpoints" => {
match checkpoint_mgr.list_checkpoints() {
Ok(checkpoints) => {
if checkpoints.is_empty() {
println!("\n📋 No checkpoints saved yet");
} else {
println!("\n📋 Saved Checkpoints:");
for (i, cp_id) in checkpoints.iter().enumerate() {
println!(" [{}] {}", i + 1, cp_id);
}
println!("\n Use /rewind <id> to restore");
}
}
Err(e) => {
eprintln!("\n❌ Failed to list checkpoints: {}", e);
}
}
}
"/clear" => {
history.clear();
stats = agent_core::SessionStats::new();
println!("\n🗑️ Session history cleared!");
}
"/plugins" => {
let plugins = app_context.plugin_manager.plugins();
if plugins.is_empty() {
println!("\n📦 No plugins loaded");
println!(" Place plugins in:");
println!(" - ~/.config/owlen/plugins (user plugins)");
println!(" - .owlen/plugins (project plugins)");
} else {
println!("\n📦 Loaded Plugins:");
for plugin in plugins {
println!("\n {} v{}", plugin.manifest.name, plugin.manifest.version);
if let Some(desc) = &plugin.manifest.description {
println!(" {}", desc);
}
if let Some(author) = &plugin.manifest.author {
println!(" Author: {}", author);
}
let commands = plugin.all_command_names();
if !commands.is_empty() {
println!(" Commands: {}", commands.join(", "));
}
let agents = plugin.all_agent_names();
if !agents.is_empty() {
println!(" Agents: {}", agents.join(", "));
}
let skills = plugin.all_skill_names();
if !skills.is_empty() {
println!(" Skills: {}", skills.join(", "));
}
}
}
}
"/exit" => {
println!("\n👋 Goodbye!");
break;
}
cmd if cmd.starts_with("/rewind ") => {
let checkpoint_id = cmd.strip_prefix("/rewind ").unwrap().trim();
match checkpoint_mgr.rewind_to(checkpoint_id) {
Ok(restored_files) => {
println!("\n⏪ Rewound to checkpoint: {}", checkpoint_id);
if !restored_files.is_empty() {
println!(" Restored files:");
for file in restored_files {
println!(" - {}", file.display());
}
}
// Load the checkpoint to restore history and stats
if let Ok(checkpoint) = checkpoint_mgr.load_checkpoint(checkpoint_id) {
stats = checkpoint.stats;
history.user_prompts = checkpoint.user_prompts;
history.assistant_responses = checkpoint.assistant_responses;
history.tool_calls = checkpoint.tool_calls;
println!(" Session state restored");
}
}
Err(e) => {
eprintln!("\n❌ Failed to rewind: {}", e);
}
}
}
_ => {
println!("\n❌ Unknown command: {}", input);
println!(" Type /help for available commands");
}
}
continue;
}
// Regular message - run through agent loop
history.add_user_message(input.to_string());
let start = SystemTime::now();
let ctx = agent_core::ToolContext::new();
match agent_core::run_agent_loop(&client, input, &opts, &perms, &ctx).await {
Ok(response) => {
println!("\n{}", response);
history.add_assistant_message(response.clone());
// Update stats
let duration = start.elapsed().unwrap_or_default();
let tokens = (input.len() + response.len()) / 4; // Rough estimate
stats.record_message(tokens, duration);
}
Err(e) => {
eprintln!("\n❌ Error: {}", e);
}
}
} else {
break;
}
}
return Ok(());
}
// Non-interactive mode - process single prompt
let prompt = args.prompt.join(" ");
let start_time = SystemTime::now();
// Handle different output formats
let ctx = agent_core::ToolContext::new();
match output_format {
OutputFormat::Text => {
// Text format: stream to stdout as before
let mut stream = client.chat_stream(&msgs, &opts).await?;
while let Some(chunk) = stream.try_next().await? {
if let Some(m) = chunk.message {
if let Some(c) = m.content {
print!("{c}");
io::stdout().flush()?;
}
}
if matches!(chunk.done, Some(true)) {
break;
}
}
println!(); // Newline after response
// Text format: Use agent orchestrator with tool calling
let response = agent_core::run_agent_loop(&client, &prompt, &opts, &perms, &ctx).await?;
println!("{}", response);
}
OutputFormat::Json => {
// JSON format: collect all chunks, then output final JSON
let mut stream = client.chat_stream(&msgs, &opts).await?;
let mut response = String::new();
while let Some(chunk) = stream.try_next().await? {
if let Some(m) = chunk.message {
if let Some(c) = m.content {
response.push_str(&c);
}
}
if matches!(chunk.done, Some(true)) {
break;
}
}
// JSON format: Use agent loop and output as JSON
let response = agent_core::run_agent_loop(&client, &prompt, &opts, &perms, &ctx).await?;
let duration_ms = start_time.elapsed().unwrap().as_millis() as u64;
// Rough token estimate (tokens ~= chars / 4)
let estimated_tokens = ((prompt.len() + response.len()) / 4) as u64;
let output = SessionOutput {
@@ -526,7 +832,7 @@ async fn main() -> Result<()> {
println!("{}", serde_json::to_string(&output)?);
}
OutputFormat::StreamJson => {
// Stream-JSON format: emit session_start, chunks, and session_end
// Stream-JSON format: emit session_start, response, and session_end
let session_start = StreamEvent {
event_type: "session_start".to_string(),
session_id: Some(session_id.clone()),
@@ -535,30 +841,17 @@ async fn main() -> Result<()> {
};
println!("{}", serde_json::to_string(&session_start)?);
let mut stream = client.chat_stream(&msgs, &opts).await?;
let mut response = String::new();
let response = agent_core::run_agent_loop(&client, &prompt, &opts, &perms, &ctx).await?;
while let Some(chunk) = stream.try_next().await? {
if let Some(m) = chunk.message {
if let Some(c) = m.content {
response.push_str(&c);
let chunk_event = StreamEvent {
event_type: "chunk".to_string(),
session_id: None,
content: Some(c),
stats: None,
};
println!("{}", serde_json::to_string(&chunk_event)?);
}
}
if matches!(chunk.done, Some(true)) {
break;
}
}
let chunk_event = StreamEvent {
event_type: "chunk".to_string(),
session_id: None,
content: Some(response.clone()),
stats: None,
};
println!("{}", serde_json::to_string(&chunk_event)?);
let duration_ms = start_time.elapsed().unwrap().as_millis() as u64;
// Rough token estimate
let estimated_tokens = ((prompt.len() + response.len()) / 4) as u64;
let session_end = StreamEvent {

View File

@@ -5,12 +5,6 @@ use predicates::prelude::PredicateBooleanExt;
#[tokio::test]
async fn headless_streams_ndjson() {
let server = MockServer::start_async().await;
// Mock /api/chat with NDJSON lines
let body = serde_json::json!({
"model": "qwen2.5",
"messages": [{"role": "user", "content": "hello"}],
"stream": true
});
let response = concat!(
r#"{"message":{"role":"assistant","content":"Hel"}}"#,"\n",
@@ -18,10 +12,11 @@ async fn headless_streams_ndjson() {
r#"{"done":true}"#,"\n",
);
// The CLI includes tools in the request, so we need to match any request to /api/chat
// instead of matching exact body (which includes tool definitions)
let _m = server.mock(|when, then| {
when.method(POST)
.path("/api/chat")
.json_body(body.clone());
.path("/api/chat");
then.status(200)
.header("content-type", "application/x-ndjson")
.body(response);

crates/app/ui/Cargo.toml (new file, 27 lines)
View File

@@ -0,0 +1,27 @@
[package]
name = "ui"
version = "0.1.0"
edition.workspace = true
license.workspace = true
rust-version.workspace = true
[dependencies]
color-eyre = "0.6"
crossterm = { version = "0.28", features = ["event-stream"] }
ratatui = "0.28"
tokio = { version = "1", features = ["full"] }
futures = "0.3"
serde = { version = "1.0", features = ["derive"] }
serde_json = "1.0"
unicode-width = "0.2"
textwrap = "0.16"
syntect = { version = "5.0", default-features = false, features = ["default-syntaxes", "default-themes", "regex-onig"] }
pulldown-cmark = "0.11"
# Internal dependencies
agent-core = { path = "../../core/agent" }
permissions = { path = "../../platform/permissions" }
llm-core = { path = "../../llm/core" }
llm-ollama = { path = "../../llm/ollama" }
config-agent = { path = "../../platform/config" }
tools-todo = { path = "../../tools/todo" }

crates/app/ui/src/app.rs (new file, 1101 lines)

File diff suppressed because it is too large.

View File

@@ -0,0 +1,226 @@
//! Command completion engine for the TUI
//!
//! Provides Tab-completion for slash commands, file paths, and tool names.
use std::path::Path;
/// A single completion suggestion
#[derive(Debug, Clone)]
pub struct Completion {
/// The text to insert
pub text: String,
/// Description of what this completion does
pub description: String,
/// Source of the completion (e.g., "builtin", "plugin:name")
pub source: String,
}
/// Information about a command for completion purposes
#[derive(Debug, Clone)]
pub struct CommandInfo {
/// Command name (without leading /)
pub name: String,
/// Command description
pub description: String,
/// Source of the command
pub source: String,
}
impl CommandInfo {
pub fn new(name: &str, description: &str, source: &str) -> Self {
Self {
name: name.to_string(),
description: description.to_string(),
source: source.to_string(),
}
}
}
/// Completion engine for the TUI
pub struct CompletionEngine {
/// Available commands
commands: Vec<CommandInfo>,
}
impl Default for CompletionEngine {
fn default() -> Self {
Self::new()
}
}
impl CompletionEngine {
pub fn new() -> Self {
Self {
commands: Self::builtin_commands(),
}
}
/// Get built-in commands
fn builtin_commands() -> Vec<CommandInfo> {
vec![
CommandInfo::new("help", "Show available commands and help", "builtin"),
CommandInfo::new("clear", "Clear the screen", "builtin"),
CommandInfo::new("mcp", "List MCP servers and their tools", "builtin"),
CommandInfo::new("hooks", "Show loaded hooks", "builtin"),
CommandInfo::new("compact", "Compact conversation context", "builtin"),
CommandInfo::new("mode", "Switch permission mode (plan/edit/code)", "builtin"),
CommandInfo::new("provider", "Switch LLM provider", "builtin"),
CommandInfo::new("model", "Switch LLM model", "builtin"),
CommandInfo::new("checkpoint", "Create a checkpoint", "builtin"),
CommandInfo::new("rewind", "Rewind to a checkpoint", "builtin"),
]
}
/// Add commands from plugins
pub fn add_plugin_commands(&mut self, plugin_name: &str, commands: Vec<CommandInfo>) {
for mut cmd in commands {
cmd.source = format!("plugin:{}", plugin_name);
self.commands.push(cmd);
}
}
/// Add a single command
pub fn add_command(&mut self, command: CommandInfo) {
self.commands.push(command);
}
/// Get completions for the given input
pub fn complete(&self, input: &str) -> Vec<Completion> {
if input.starts_with('/') {
self.complete_command(&input[1..])
} else if input.starts_with('@') {
self.complete_file_path(&input[1..])
} else {
vec![]
}
}
/// Complete a slash command
fn complete_command(&self, partial: &str) -> Vec<Completion> {
let partial_lower = partial.to_lowercase();
self.commands
.iter()
.filter(|cmd| {
// Match if name starts with partial, or contains partial (fuzzy)
cmd.name.to_lowercase().starts_with(&partial_lower)
|| (partial.len() >= 2 && cmd.name.to_lowercase().contains(&partial_lower))
})
.map(|cmd| Completion {
text: format!("/{}", cmd.name),
description: cmd.description.clone(),
source: cmd.source.clone(),
})
.collect()
}
/// Complete a file path
fn complete_file_path(&self, partial: &str) -> Vec<Completion> {
let path = Path::new(partial);
// Get the directory to search and the prefix to match
let (dir, prefix) = if partial.ends_with('/') || partial.is_empty() {
(partial, "")
} else {
let parent = path.parent().map(|p| p.to_str().unwrap_or("")).unwrap_or("");
let file_name = path.file_name().and_then(|f| f.to_str()).unwrap_or("");
(parent, file_name)
};
// Search directory
let search_dir = if dir.is_empty() { "." } else { dir };
match std::fs::read_dir(search_dir) {
Ok(entries) => {
entries
.filter_map(|entry| entry.ok())
.filter(|entry| {
let name = entry.file_name();
let name_str = name.to_string_lossy();
// Skip hidden files unless user started typing with .
if !prefix.starts_with('.') && name_str.starts_with('.') {
return false;
}
name_str.to_lowercase().starts_with(&prefix.to_lowercase())
})
.map(|entry| {
let name = entry.file_name();
let name_str = name.to_string_lossy();
let is_dir = entry.file_type().map(|t| t.is_dir()).unwrap_or(false);
let full_path = if dir.is_empty() {
name_str.to_string()
} else if dir.ends_with('/') {
format!("{}{}", dir, name_str)
} else {
format!("{}/{}", dir, name_str)
};
Completion {
text: format!("@{}{}", full_path, if is_dir { "/" } else { "" }),
description: if is_dir { "Directory".to_string() } else { "File".to_string() },
source: "filesystem".to_string(),
}
})
.collect()
}
Err(_) => vec![],
}
}
/// Get all commands (for /help display)
pub fn all_commands(&self) -> &[CommandInfo] {
&self.commands
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_command_completion_exact() {
let engine = CompletionEngine::new();
let completions = engine.complete("/help");
assert!(!completions.is_empty());
assert!(completions.iter().any(|c| c.text == "/help"));
}
#[test]
fn test_command_completion_partial() {
let engine = CompletionEngine::new();
let completions = engine.complete("/hel");
assert!(!completions.is_empty());
assert!(completions.iter().any(|c| c.text == "/help"));
}
#[test]
fn test_command_completion_fuzzy() {
let engine = CompletionEngine::new();
// "cle" should match "clear"
let completions = engine.complete("/cle");
assert!(!completions.is_empty());
assert!(completions.iter().any(|c| c.text == "/clear"));
}
#[test]
fn test_command_info() {
let info = CommandInfo::new("test", "A test command", "builtin");
assert_eq!(info.name, "test");
assert_eq!(info.description, "A test command");
assert_eq!(info.source, "builtin");
}
#[test]
fn test_add_plugin_commands() {
let mut engine = CompletionEngine::new();
let plugin_cmds = vec![
CommandInfo::new("custom", "A custom command", ""),
];
engine.add_plugin_commands("my-plugin", plugin_cmds);
let completions = engine.complete("/custom");
assert!(!completions.is_empty());
assert!(completions.iter().any(|c| c.source == "plugin:my-plugin"));
}
}

View File

@@ -0,0 +1,377 @@
//! Command autocomplete dropdown component
//!
//! Displays inline autocomplete suggestions when user types `/`.
//! Supports fuzzy filtering as user types.
use crate::theme::Theme;
use crossterm::event::{KeyCode, KeyEvent};
use ratatui::{
layout::Rect,
style::Style,
text::{Line, Span},
widgets::{Block, Borders, Clear, Paragraph},
Frame,
};
/// An autocomplete option
#[derive(Debug, Clone)]
pub struct AutocompleteOption {
/// The trigger text (command name without /)
pub trigger: String,
/// Display text (e.g., "/model [name]")
pub display: String,
/// Short description
pub description: String,
/// Has submenu/subcommands
pub has_submenu: bool,
}
impl AutocompleteOption {
pub fn new(trigger: &str, description: &str) -> Self {
Self {
trigger: trigger.to_string(),
display: format!("/{}", trigger),
description: description.to_string(),
has_submenu: false,
}
}
pub fn with_args(trigger: &str, args: &str, description: &str) -> Self {
Self {
trigger: trigger.to_string(),
display: format!("/{} {}", trigger, args),
description: description.to_string(),
has_submenu: false,
}
}
pub fn with_submenu(trigger: &str, description: &str) -> Self {
Self {
trigger: trigger.to_string(),
display: format!("/{}", trigger),
description: description.to_string(),
has_submenu: true,
}
}
}
/// Default command options
fn default_options() -> Vec<AutocompleteOption> {
vec![
AutocompleteOption::new("help", "Show help"),
AutocompleteOption::new("status", "Session info"),
AutocompleteOption::with_args("model", "[name]", "Switch model"),
AutocompleteOption::with_args("provider", "[name]", "Switch provider"),
AutocompleteOption::new("history", "View history"),
AutocompleteOption::new("checkpoint", "Save state"),
AutocompleteOption::new("checkpoints", "List checkpoints"),
AutocompleteOption::with_args("rewind", "[id]", "Restore"),
AutocompleteOption::new("cost", "Token usage"),
AutocompleteOption::new("clear", "Clear chat"),
AutocompleteOption::new("compact", "Compact context"),
AutocompleteOption::new("permissions", "Permission mode"),
AutocompleteOption::new("themes", "List themes"),
AutocompleteOption::with_args("theme", "[name]", "Switch theme"),
AutocompleteOption::new("exit", "Exit"),
]
}
/// Autocomplete dropdown component
pub struct Autocomplete {
options: Vec<AutocompleteOption>,
filtered: Vec<usize>, // indices into options
selected: usize,
visible: bool,
theme: Theme,
}
impl Autocomplete {
pub fn new(theme: Theme) -> Self {
let options = default_options();
let filtered: Vec<usize> = (0..options.len()).collect();
Self {
options,
filtered,
selected: 0,
visible: false,
theme,
}
}
/// Show autocomplete and reset filter
pub fn show(&mut self) {
self.visible = true;
self.filtered = (0..self.options.len()).collect();
self.selected = 0;
}
/// Hide autocomplete
pub fn hide(&mut self) {
self.visible = false;
}
/// Check if visible
pub fn is_visible(&self) -> bool {
self.visible
}
/// Update filter based on current input (text after /)
pub fn update_filter(&mut self, query: &str) {
if query.is_empty() {
self.filtered = (0..self.options.len()).collect();
} else {
let query_lower = query.to_lowercase();
self.filtered = self.options
.iter()
.enumerate()
.filter(|(_, opt)| {
// Fuzzy match: check if query chars appear in order
fuzzy_match(&opt.trigger.to_lowercase(), &query_lower)
})
.map(|(i, _)| i)
.collect();
}
// Reset selection if it's out of bounds
if self.selected >= self.filtered.len() {
self.selected = 0;
}
}
/// Select next option
pub fn select_next(&mut self) {
if !self.filtered.is_empty() {
self.selected = (self.selected + 1) % self.filtered.len();
}
}
/// Select previous option
pub fn select_prev(&mut self) {
if !self.filtered.is_empty() {
self.selected = if self.selected == 0 {
self.filtered.len() - 1
} else {
self.selected - 1
};
}
}
/// Get the currently selected option's trigger
pub fn confirm(&self) -> Option<String> {
if self.filtered.is_empty() {
return None;
}
let idx = self.filtered[self.selected];
Some(format!("/{}", self.options[idx].trigger))
}
/// Handle key input, returns Some(command) if confirmed
///
/// Key behavior:
/// - Tab: Confirm selection and insert into input
/// - Down/Up: Navigate options
/// - Enter: Pass through to submit (NotHandled)
/// - Esc: Cancel autocomplete
pub fn handle_key(&mut self, key: KeyEvent) -> AutocompleteResult {
if !self.visible {
return AutocompleteResult::NotHandled;
}
match key.code {
KeyCode::Tab => {
// Tab confirms and inserts the selected command
if let Some(cmd) = self.confirm() {
self.hide();
AutocompleteResult::Confirmed(cmd)
} else {
AutocompleteResult::Handled
}
}
KeyCode::Down => {
self.select_next();
AutocompleteResult::Handled
}
KeyCode::BackTab | KeyCode::Up => {
self.select_prev();
AutocompleteResult::Handled
}
KeyCode::Enter => {
// Enter should submit the message, not confirm autocomplete
// Hide autocomplete and let Enter pass through
self.hide();
AutocompleteResult::NotHandled
}
KeyCode::Esc => {
self.hide();
AutocompleteResult::Cancelled
}
_ => AutocompleteResult::NotHandled,
}
}
/// Update theme
pub fn set_theme(&mut self, theme: Theme) {
self.theme = theme;
}
/// Add custom options (from plugins)
pub fn add_options(&mut self, options: Vec<AutocompleteOption>) {
self.options.extend(options);
// Re-filter with all options
self.filtered = (0..self.options.len()).collect();
}
/// Render the autocomplete dropdown above the input line
pub fn render(&self, frame: &mut Frame, input_area: Rect) {
if !self.visible || self.filtered.is_empty() {
return;
}
// Calculate dropdown dimensions
let max_visible = 8.min(self.filtered.len());
let width = 40.min(input_area.width.saturating_sub(4));
let height = (max_visible + 2) as u16; // +2 for borders
// Position above input, left-aligned with some padding
let x = input_area.x + 2;
let y = input_area.y.saturating_sub(height);
let dropdown_area = Rect::new(x, y, width, height);
// Clear area behind dropdown
frame.render_widget(Clear, dropdown_area);
// Build option lines
let mut lines: Vec<Line> = Vec::new();
for (display_idx, &opt_idx) in self.filtered.iter().take(max_visible).enumerate() {
let opt = &self.options[opt_idx];
let is_selected = display_idx == self.selected;
let style = if is_selected {
self.theme.selected
} else {
Style::default()
};
let mut spans = vec![
Span::styled(" ", style),
Span::styled("/", if is_selected { style } else { self.theme.cmd_slash }),
Span::styled(&opt.trigger, if is_selected { style } else { self.theme.cmd_name }),
];
// Submenu indicator
if opt.has_submenu {
spans.push(Span::styled(" >", if is_selected { style } else { self.theme.cmd_desc }));
}
// Pad to fixed width for consistent selection highlighting
let current_len: usize = spans.iter().map(|s| s.content.len()).sum();
let padding = (width as usize).saturating_sub(current_len + 1);
spans.push(Span::styled(" ".repeat(padding), style));
lines.push(Line::from(spans));
}
// Show overflow indicator if needed
if self.filtered.len() > max_visible {
lines.push(Line::from(Span::styled(
format!(" ... +{} more", self.filtered.len() - max_visible),
self.theme.cmd_desc,
)));
}
let block = Block::default()
.borders(Borders::ALL)
.border_style(Style::default().fg(self.theme.palette.border))
.style(self.theme.overlay_bg);
let paragraph = Paragraph::new(lines).block(block);
frame.render_widget(paragraph, dropdown_area);
}
}
/// Result of handling autocomplete key
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum AutocompleteResult {
/// Key was not handled by autocomplete
NotHandled,
/// Key was handled, no action needed
Handled,
/// User confirmed selection, returns command string
Confirmed(String),
/// User cancelled autocomplete
Cancelled,
}
/// Simple fuzzy match: check if query chars appear in order in text
fn fuzzy_match(text: &str, query: &str) -> bool {
let mut text_chars = text.chars().peekable();
for query_char in query.chars() {
loop {
match text_chars.next() {
Some(c) if c == query_char => break,
Some(_) => continue,
None => return false,
}
}
}
true
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_fuzzy_match() {
assert!(fuzzy_match("help", "h"));
assert!(fuzzy_match("help", "he"));
assert!(fuzzy_match("help", "hel"));
assert!(fuzzy_match("help", "help"));
assert!(fuzzy_match("help", "hp")); // fuzzy: h...p
assert!(!fuzzy_match("help", "x"));
assert!(!fuzzy_match("help", "helping")); // query longer than text
}
#[test]
fn test_autocomplete_filter() {
let theme = Theme::default();
let mut ac = Autocomplete::new(theme);
ac.update_filter("he");
assert!(ac.filtered.len() < ac.options.len());
// Should match "help"
assert!(ac.filtered.iter().any(|&i| ac.options[i].trigger == "help"));
}
#[test]
fn test_autocomplete_navigation() {
let theme = Theme::default();
let mut ac = Autocomplete::new(theme);
ac.show();
assert_eq!(ac.selected, 0);
ac.select_next();
assert_eq!(ac.selected, 1);
ac.select_prev();
assert_eq!(ac.selected, 0);
}
#[test]
fn test_autocomplete_confirm() {
let theme = Theme::default();
let mut ac = Autocomplete::new(theme);
ac.show();
let cmd = ac.confirm();
assert!(cmd.is_some());
assert!(cmd.unwrap().starts_with("/"));
}
}

View File

@@ -0,0 +1,468 @@
//! Borderless chat panel component
//!
//! Displays chat messages with proper indentation, timestamps,
//! and streaming indicators. Uses whitespace instead of borders.
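//!
//! A minimal usage sketch (assumes a `Theme::default()` plus a ratatui `Frame`
//! and `Rect` supplied by the caller; not compiled here):
//!
//! ```rust,ignore
//! let mut panel = ChatPanel::new(Theme::default());
//! panel.add_message(ChatMessage::User("Hello".to_string()));
//! panel.append_to_assistant("Hi ");
//! panel.append_to_assistant("there!"); // streamed chunks accumulate into one Assistant message
//! panel.update_scroll(area);           // clamp / auto-scroll before drawing
//! panel.render(frame, area);
//! ```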
use crate::theme::Theme;
use ratatui::{
layout::Rect,
style::{Modifier, Style},
text::{Line, Span, Text},
widgets::{Paragraph, Scrollbar, ScrollbarOrientation, ScrollbarState},
Frame,
};
use std::time::SystemTime;
/// Chat message types
#[derive(Debug, Clone)]
pub enum ChatMessage {
User(String),
Assistant(String),
ToolCall { name: String, args: String },
ToolResult { success: bool, output: String },
System(String),
}
impl ChatMessage {
/// Current wall-clock time (UTC, HH:MM) used as the message's display timestamp
pub fn timestamp_display() -> String {
let now = SystemTime::now();
let secs = now
.duration_since(SystemTime::UNIX_EPOCH)
.map(|d| d.as_secs())
.unwrap_or(0);
let hours = (secs / 3600) % 24;
let mins = (secs / 60) % 60;
format!("{:02}:{:02}", hours, mins)
}
}
/// Message with metadata for display
#[derive(Debug, Clone)]
pub struct DisplayMessage {
pub message: ChatMessage,
pub timestamp: String,
pub focused: bool,
}
impl DisplayMessage {
pub fn new(message: ChatMessage) -> Self {
Self {
message,
timestamp: ChatMessage::timestamp_display(),
focused: false,
}
}
}
/// Borderless chat panel
pub struct ChatPanel {
messages: Vec<DisplayMessage>,
scroll_offset: usize,
auto_scroll: bool,
total_lines: usize,
focused_index: Option<usize>,
is_streaming: bool,
theme: Theme,
}
impl ChatPanel {
/// Create new borderless chat panel
pub fn new(theme: Theme) -> Self {
Self {
messages: Vec::new(),
scroll_offset: 0,
auto_scroll: true,
total_lines: 0,
focused_index: None,
is_streaming: false,
theme,
}
}
/// Add a new message
pub fn add_message(&mut self, message: ChatMessage) {
self.messages.push(DisplayMessage::new(message));
self.auto_scroll = true;
self.is_streaming = false;
}
/// Append content to the last assistant message, or create a new one
pub fn append_to_assistant(&mut self, content: &str) {
if let Some(DisplayMessage {
message: ChatMessage::Assistant(last_content),
..
}) = self.messages.last_mut()
{
last_content.push_str(content);
} else {
self.messages.push(DisplayMessage::new(ChatMessage::Assistant(
content.to_string(),
)));
}
self.auto_scroll = true;
self.is_streaming = true;
}
/// Set streaming state
pub fn set_streaming(&mut self, streaming: bool) {
self.is_streaming = streaming;
}
/// Scroll up
pub fn scroll_up(&mut self, amount: usize) {
self.scroll_offset = self.scroll_offset.saturating_sub(amount);
self.auto_scroll = false;
}
/// Scroll down
pub fn scroll_down(&mut self, amount: usize) {
self.scroll_offset = self.scroll_offset.saturating_add(amount);
let near_bottom_threshold = 5;
if self.total_lines > 0 {
let max_scroll = self.total_lines.saturating_sub(1);
if self.scroll_offset.saturating_add(near_bottom_threshold) >= max_scroll {
self.auto_scroll = true;
}
}
}
/// Scroll to bottom
pub fn scroll_to_bottom(&mut self) {
self.scroll_offset = self.total_lines.saturating_sub(1);
self.auto_scroll = true;
}
/// Page up
pub fn page_up(&mut self, page_size: usize) {
self.scroll_up(page_size.saturating_sub(2));
}
/// Page down
pub fn page_down(&mut self, page_size: usize) {
self.scroll_down(page_size.saturating_sub(2));
}
/// Focus next message
pub fn focus_next(&mut self) {
if self.messages.is_empty() {
return;
}
self.focused_index = Some(match self.focused_index {
Some(i) if i + 1 < self.messages.len() => i + 1,
Some(_) => 0,
None => 0,
});
}
/// Focus previous message
pub fn focus_previous(&mut self) {
if self.messages.is_empty() {
return;
}
self.focused_index = Some(match self.focused_index {
Some(0) => self.messages.len() - 1,
Some(i) => i - 1,
None => self.messages.len() - 1,
});
}
/// Clear focus
pub fn clear_focus(&mut self) {
self.focused_index = None;
}
/// Get focused message index
pub fn focused_index(&self) -> Option<usize> {
self.focused_index
}
/// Get focused message
pub fn focused_message(&self) -> Option<&ChatMessage> {
self.focused_index
.and_then(|i| self.messages.get(i))
.map(|m| &m.message)
}
/// Update scroll position before rendering
pub fn update_scroll(&mut self, area: Rect) {
self.total_lines = self.count_total_lines(area);
if self.auto_scroll {
let visible_height = area.height as usize;
let max_scroll = self.total_lines.saturating_sub(visible_height);
self.scroll_offset = max_scroll;
} else {
let visible_height = area.height as usize;
let max_scroll = self.total_lines.saturating_sub(visible_height);
self.scroll_offset = self.scroll_offset.min(max_scroll);
}
}
/// Count total lines for scroll calculation
fn count_total_lines(&self, area: Rect) -> usize {
let mut line_count = 0;
let wrap_width = area.width.saturating_sub(4) as usize;
for msg in &self.messages {
line_count += match &msg.message {
ChatMessage::User(content) => {
let wrapped = textwrap::wrap(content, wrap_width);
wrapped.len() + 1 // +1 for spacing
}
ChatMessage::Assistant(content) => {
let wrapped = textwrap::wrap(content, wrap_width);
wrapped.len() + 1
}
ChatMessage::ToolCall { .. } => 2,
ChatMessage::ToolResult { .. } => 2,
ChatMessage::System(_) => 1,
};
}
line_count
}
/// Render the borderless chat panel
///
/// Message display format (no symbols, clean typography):
/// - Role: bold, appropriate color
/// - Timestamp: dim, same line as role
/// - Content: 2-space indent, normal weight
/// - Blank line between messages
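///
/// Illustrative output sketch (timestamps and spacing approximate):
///
/// ```text
///  You 14:32
///    How do I switch themes?
///
///  Assistant 14:32
///    Use /theme <name>.
/// ```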
pub fn render(&self, frame: &mut Frame, area: Rect) {
let mut text_lines = Vec::new();
let wrap_width = area.width.saturating_sub(4) as usize;
for (idx, display_msg) in self.messages.iter().enumerate() {
let is_focused = self.focused_index == Some(idx);
let is_last = idx == self.messages.len() - 1;
match &display_msg.message {
ChatMessage::User(content) => {
// Role line: "You" bold + timestamp dim
text_lines.push(Line::from(vec![
Span::styled(" ", Style::default()),
Span::styled("You", self.theme.user_message),
Span::styled(
format!(" {}", display_msg.timestamp),
self.theme.timestamp,
),
]));
// Message content with 2-space indent
let wrapped = textwrap::wrap(content, wrap_width);
for line in wrapped {
let style = if is_focused {
self.theme.user_message.add_modifier(Modifier::REVERSED)
} else {
self.theme.user_message.remove_modifier(Modifier::BOLD)
};
text_lines.push(Line::from(Span::styled(
format!(" {}", line),
style,
)));
}
// Focus hints
if is_focused {
text_lines.push(Line::from(Span::styled(
" [y]copy [e]edit [r]retry",
self.theme.status_dim,
)));
}
text_lines.push(Line::from(""));
}
ChatMessage::Assistant(content) => {
// Role line: streaming indicator (if active) + "Assistant" bold + timestamp
let mut role_spans = vec![Span::styled(" ", Style::default())];
// Streaming indicator (subtle, no symbol)
if is_last && self.is_streaming {
role_spans.push(Span::styled(
"... ",
Style::default().fg(self.theme.palette.success),
));
}
role_spans.push(Span::styled(
"Assistant",
self.theme.assistant_message.add_modifier(Modifier::BOLD),
));
role_spans.push(Span::styled(
format!(" {}", display_msg.timestamp),
self.theme.timestamp,
));
text_lines.push(Line::from(role_spans));
// Content
let wrapped = textwrap::wrap(content, wrap_width);
for line in wrapped {
let style = if is_focused {
self.theme.assistant_message.add_modifier(Modifier::REVERSED)
} else {
self.theme.assistant_message
};
text_lines.push(Line::from(Span::styled(
format!(" {}", line),
style,
)));
}
// Focus hints
if is_focused {
text_lines.push(Line::from(Span::styled(
" [y]copy [r]retry",
self.theme.status_dim,
)));
}
text_lines.push(Line::from(""));
}
ChatMessage::ToolCall { name, args } => {
// Tool calls: name in tool color, args dimmed
text_lines.push(Line::from(vec![
Span::styled(" ", Style::default()),
Span::styled(format!("{} ", name), self.theme.tool_call),
Span::styled(
truncate_str(args, 60),
self.theme.tool_call.add_modifier(Modifier::DIM),
),
]));
text_lines.push(Line::from(""));
}
ChatMessage::ToolResult { success, output } => {
// Tool results: status prefix + output
let (prefix, style) = if *success {
("ok ", self.theme.tool_result_success)
} else {
("err ", self.theme.tool_result_error)
};
text_lines.push(Line::from(vec![
Span::styled(" ", Style::default()),
Span::styled(prefix, style),
Span::styled(
truncate_str(output, 100),
style.remove_modifier(Modifier::BOLD),
),
]));
text_lines.push(Line::from(""));
}
ChatMessage::System(content) => {
// System messages: just dim text, no prefix
text_lines.push(Line::from(vec![
Span::styled(" ", Style::default()),
Span::styled(content.to_string(), self.theme.system_message),
]));
}
}
}
let text = Text::from(text_lines);
let paragraph = Paragraph::new(text).scroll((self.scroll_offset as u16, 0));
frame.render_widget(paragraph, area);
// Render scrollbar if needed
if self.total_lines > area.height as usize {
let scrollbar = Scrollbar::default()
.orientation(ScrollbarOrientation::VerticalRight)
.begin_symbol(None)
.end_symbol(None)
.track_symbol(Some(" "))
.thumb_symbol("█")
.style(self.theme.status_dim);
let mut scrollbar_state = ScrollbarState::default()
.content_length(self.total_lines)
.position(self.scroll_offset);
frame.render_stateful_widget(scrollbar, area, &mut scrollbar_state);
}
}
/// Get messages
pub fn messages(&self) -> &[DisplayMessage] {
&self.messages
}
/// Clear all messages
pub fn clear(&mut self) {
self.messages.clear();
self.scroll_offset = 0;
self.focused_index = None;
}
/// Update theme
pub fn set_theme(&mut self, theme: Theme) {
self.theme = theme;
}
}
/// Truncate a string to max length with ellipsis
fn truncate_str(s: &str, max_len: usize) -> String {
if s.len() <= max_len {
s.to_string()
} else {
format!("{}...", &s[..max_len.saturating_sub(3)])
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_chat_panel_add_message() {
let theme = Theme::default();
let mut panel = ChatPanel::new(theme);
panel.add_message(ChatMessage::User("Hello".to_string()));
panel.add_message(ChatMessage::Assistant("Hi there!".to_string()));
assert_eq!(panel.messages().len(), 2);
}
#[test]
fn test_append_to_assistant() {
let theme = Theme::default();
let mut panel = ChatPanel::new(theme);
panel.append_to_assistant("Hello");
panel.append_to_assistant(" world");
assert_eq!(panel.messages().len(), 1);
if let ChatMessage::Assistant(content) = &panel.messages()[0].message {
assert_eq!(content, "Hello world");
}
}
#[test]
fn test_focus_navigation() {
let theme = Theme::default();
let mut panel = ChatPanel::new(theme);
panel.add_message(ChatMessage::User("1".to_string()));
panel.add_message(ChatMessage::User("2".to_string()));
panel.add_message(ChatMessage::User("3".to_string()));
assert_eq!(panel.focused_index(), None);
panel.focus_next();
assert_eq!(panel.focused_index(), Some(0));
panel.focus_next();
assert_eq!(panel.focused_index(), Some(1));
panel.focus_previous();
assert_eq!(panel.focused_index(), Some(0));
}
}

View File

@@ -0,0 +1,322 @@
//! Command help overlay component
//!
//! Modal overlay that displays available commands in a structured format.
//! Shown when user types `/help` or `?`. Supports scrolling with j/k or arrows.
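//!
//! A minimal wiring sketch (event routing is up to the caller; not compiled
//! here):
//!
//! ```rust,ignore
//! let mut help = CommandHelp::new(Theme::default());
//! help.toggle();                 // open on `?` or `/help`
//! if help.handle_key(key) {
//!     // the overlay consumed the key (scroll, close, ...)
//! }
//! help.render(frame, area);
//! ```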
use crate::theme::Theme;
use crossterm::event::{KeyCode, KeyEvent};
use ratatui::{
layout::Rect,
style::Style,
text::{Line, Span},
widgets::{Block, Borders, Clear, Paragraph, Scrollbar, ScrollbarOrientation, ScrollbarState},
Frame,
};
/// A single command definition
#[derive(Debug, Clone)]
pub struct Command {
pub name: &'static str,
pub args: Option<&'static str>,
pub description: &'static str,
}
impl Command {
pub const fn new(name: &'static str, description: &'static str) -> Self {
Self {
name,
args: None,
description,
}
}
pub const fn with_args(name: &'static str, args: &'static str, description: &'static str) -> Self {
Self {
name,
args: Some(args),
description,
}
}
}
/// Built-in commands
pub fn builtin_commands() -> Vec<Command> {
vec![
Command::new("help", "Show this help"),
Command::new("status", "Current session info"),
Command::with_args("model", "[name]", "Switch model"),
Command::with_args("provider", "[name]", "Switch provider (ollama, anthropic, openai)"),
Command::new("history", "Browse conversation history"),
Command::new("checkpoint", "Save conversation state"),
Command::new("checkpoints", "List saved checkpoints"),
Command::with_args("rewind", "[id]", "Restore checkpoint"),
Command::new("cost", "Show token usage"),
Command::new("clear", "Clear conversation"),
Command::new("compact", "Compact conversation context"),
Command::new("permissions", "Show permission mode"),
Command::new("themes", "List available themes"),
Command::with_args("theme", "[name]", "Switch theme"),
Command::new("exit", "Exit OWLEN"),
]
}
/// Command help overlay
pub struct CommandHelp {
commands: Vec<Command>,
visible: bool,
scroll_offset: usize,
theme: Theme,
}
impl CommandHelp {
pub fn new(theme: Theme) -> Self {
Self {
commands: builtin_commands(),
visible: false,
scroll_offset: 0,
theme,
}
}
/// Show the help overlay
pub fn show(&mut self) {
self.visible = true;
self.scroll_offset = 0; // Reset scroll when showing
}
/// Hide the help overlay
pub fn hide(&mut self) {
self.visible = false;
}
/// Check if visible
pub fn is_visible(&self) -> bool {
self.visible
}
/// Toggle visibility
pub fn toggle(&mut self) {
self.visible = !self.visible;
if self.visible {
self.scroll_offset = 0;
}
}
/// Scroll up by amount
fn scroll_up(&mut self, amount: usize) {
self.scroll_offset = self.scroll_offset.saturating_sub(amount);
}
/// Scroll down by amount, respecting max
fn scroll_down(&mut self, amount: usize, max_scroll: usize) {
self.scroll_offset = (self.scroll_offset + amount).min(max_scroll);
}
/// Handle key input, returns true if overlay handled the key
pub fn handle_key(&mut self, key: KeyEvent) -> bool {
if !self.visible {
return false;
}
// Calculate max scroll (commands + padding lines - visible area)
let total_lines = self.commands.len() + 3; // +3 for padding and footer
let max_scroll = total_lines.saturating_sub(10); // Assume ~10 visible lines
match key.code {
KeyCode::Esc | KeyCode::Char('q') | KeyCode::Char('?') => {
self.hide();
true
}
// Scroll navigation
KeyCode::Up | KeyCode::Char('k') => {
self.scroll_up(1);
true
}
KeyCode::Down | KeyCode::Char('j') => {
self.scroll_down(1, max_scroll);
true
}
KeyCode::PageUp | KeyCode::Char('u') => {
self.scroll_up(5);
true
}
KeyCode::PageDown | KeyCode::Char('d') => {
self.scroll_down(5, max_scroll);
true
}
KeyCode::Home | KeyCode::Char('g') => {
self.scroll_offset = 0;
true
}
KeyCode::End | KeyCode::Char('G') => {
self.scroll_offset = max_scroll;
true
}
_ => true, // Consume all other keys while visible
}
}
/// Update theme
pub fn set_theme(&mut self, theme: Theme) {
self.theme = theme;
}
/// Add plugin commands
pub fn add_commands(&mut self, commands: Vec<Command>) {
self.commands.extend(commands);
}
/// Render the help overlay
pub fn render(&self, frame: &mut Frame, area: Rect) {
if !self.visible {
return;
}
// Calculate overlay dimensions
let width = (area.width as f32 * 0.7).min(65.0) as u16;
let max_height = area.height.saturating_sub(4);
let content_height = self.commands.len() as u16 + 4; // +4 for padding and footer
let height = content_height.min(max_height).max(8);
// Center the overlay
let x = (area.width.saturating_sub(width)) / 2;
let y = (area.height.saturating_sub(height)) / 2;
let overlay_area = Rect::new(x, y, width, height);
// Clear the area behind the overlay
frame.render_widget(Clear, overlay_area);
// Build content lines
let mut lines: Vec<Line> = Vec::new();
// Empty line for padding
lines.push(Line::from(""));
// Command list
for cmd in &self.commands {
let name_with_args = if let Some(args) = cmd.args {
format!("/{} {}", cmd.name, args)
} else {
format!("/{}", cmd.name)
};
// Calculate padding for alignment
let name_width: usize = 22;
let padding = name_width.saturating_sub(name_with_args.len());
lines.push(Line::from(vec![
Span::styled(" ", Style::default()),
Span::styled("/", self.theme.cmd_slash),
Span::styled(
if let Some(args) = cmd.args {
format!("{} {}", cmd.name, args)
} else {
cmd.name.to_string()
},
self.theme.cmd_name,
),
Span::raw(" ".repeat(padding)),
Span::styled(cmd.description, self.theme.cmd_desc),
]));
}
// Empty line for padding
lines.push(Line::from(""));
// Footer hint with scroll info
let scroll_hint = if self.commands.len() > (height as usize - 4) {
" (scroll: j/k or ↑/↓)".to_string()
} else {
String::new()
};
lines.push(Line::from(vec![
Span::styled(" Press ", self.theme.cmd_desc),
Span::styled("Esc", self.theme.cmd_name),
Span::styled(" to close", self.theme.cmd_desc),
Span::styled(scroll_hint, self.theme.cmd_desc),
]));
// Create the block with border
let block = Block::default()
.title(" Commands ")
.title_style(self.theme.popup_title)
.borders(Borders::ALL)
.border_style(self.theme.popup_border)
.style(self.theme.overlay_bg);
let paragraph = Paragraph::new(lines)
.block(block)
.scroll((self.scroll_offset as u16, 0));
frame.render_widget(paragraph, overlay_area);
// Render scrollbar if content exceeds visible area
let visible_height = height.saturating_sub(2) as usize; // -2 for borders
let total_lines = self.commands.len() + 3;
if total_lines > visible_height {
let scrollbar = Scrollbar::default()
.orientation(ScrollbarOrientation::VerticalRight)
.begin_symbol(None)
.end_symbol(None)
.track_symbol(Some(" "))
.thumb_symbol("█")
.style(self.theme.status_dim);
let mut scrollbar_state = ScrollbarState::default()
.content_length(total_lines)
.position(self.scroll_offset);
// Adjust scrollbar area to be inside the border
let scrollbar_area = Rect::new(
overlay_area.x + overlay_area.width - 2,
overlay_area.y + 1,
1,
overlay_area.height.saturating_sub(2),
);
frame.render_stateful_widget(scrollbar, scrollbar_area, &mut scrollbar_state);
}
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_command_help_visibility() {
let theme = Theme::default();
let mut help = CommandHelp::new(theme);
assert!(!help.is_visible());
help.show();
assert!(help.is_visible());
help.hide();
assert!(!help.is_visible());
}
#[test]
fn test_builtin_commands() {
let commands = builtin_commands();
assert!(!commands.is_empty());
assert!(commands.iter().any(|c| c.name == "help"));
assert!(commands.iter().any(|c| c.name == "provider"));
}
#[test]
fn test_scroll_navigation() {
let theme = Theme::default();
let mut help = CommandHelp::new(theme);
help.show();
assert_eq!(help.scroll_offset, 0);
help.scroll_down(3, 10);
assert_eq!(help.scroll_offset, 3);
help.scroll_up(1);
assert_eq!(help.scroll_offset, 2);
help.scroll_up(10); // Should clamp to 0
assert_eq!(help.scroll_offset, 0);
}
}

View File

@@ -0,0 +1,507 @@
//! Vim-modal input component
//!
//! Borderless input with vim-like modes (Normal, Insert, Visual, Command).
//! Uses mode prefix instead of borders for visual indication.
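//!
//! A minimal handling sketch (the event loop and the `send_message` /
//! `run_command` helpers are illustrative, not part of this module;
//! `status_bar` is a `StatusBar` from this crate; not compiled here):
//!
//! ```rust,ignore
//! let mut input = InputBox::new(Theme::default()); // starts in Insert mode
//! match input.handle_key(key) {
//!     Some(InputEvent::Message(text)) => send_message(text),
//!     Some(InputEvent::Command(cmd)) => run_command(cmd), // e.g. "quit"
//!     Some(InputEvent::ModeChange(mode)) => status_bar.set_vim_mode(mode),
//!     _ => {}
//! }
//! input.render(frame, input_area);
//! ```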
use crate::theme::{Theme, VimMode};
use crossterm::event::{KeyCode, KeyEvent, KeyModifiers};
use ratatui::{
layout::Rect,
style::Style,
text::{Line, Span},
widgets::Paragraph,
Frame,
};
/// Input event from the input box
#[derive(Debug, Clone)]
pub enum InputEvent {
/// User submitted a message
Message(String),
/// User submitted a command (without / prefix)
Command(String),
/// Mode changed
ModeChange(VimMode),
/// Request to cancel current operation
Cancel,
/// Request to expand input (multiline)
Expand,
}
/// Vim-modal input box
pub struct InputBox {
input: String,
cursor_position: usize,
history: Vec<String>,
history_index: usize,
mode: VimMode,
theme: Theme,
}
impl InputBox {
pub fn new(theme: Theme) -> Self {
Self {
input: String::new(),
cursor_position: 0,
history: Vec::new(),
history_index: 0,
mode: VimMode::Insert, // Start in insert mode for familiarity
theme,
}
}
/// Get current vim mode
pub fn mode(&self) -> VimMode {
self.mode
}
/// Set vim mode
pub fn set_mode(&mut self, mode: VimMode) {
self.mode = mode;
}
/// Handle key event, returns input event if action is needed
pub fn handle_key(&mut self, key: KeyEvent) -> Option<InputEvent> {
match self.mode {
VimMode::Normal => self.handle_normal_mode(key),
VimMode::Insert => self.handle_insert_mode(key),
VimMode::Command => self.handle_command_mode(key),
VimMode::Visual => self.handle_visual_mode(key),
}
}
/// Handle keys in normal mode
fn handle_normal_mode(&mut self, key: KeyEvent) -> Option<InputEvent> {
match key.code {
// Enter insert mode
KeyCode::Char('i') => {
self.mode = VimMode::Insert;
Some(InputEvent::ModeChange(VimMode::Insert))
}
KeyCode::Char('a') => {
self.mode = VimMode::Insert;
if self.cursor_position < self.input.len() {
self.cursor_position += 1;
}
Some(InputEvent::ModeChange(VimMode::Insert))
}
KeyCode::Char('I') => {
self.mode = VimMode::Insert;
self.cursor_position = 0;
Some(InputEvent::ModeChange(VimMode::Insert))
}
KeyCode::Char('A') => {
self.mode = VimMode::Insert;
self.cursor_position = self.input.len();
Some(InputEvent::ModeChange(VimMode::Insert))
}
// Enter command mode
KeyCode::Char(':') => {
self.mode = VimMode::Command;
self.input.clear();
self.cursor_position = 0;
Some(InputEvent::ModeChange(VimMode::Command))
}
// Navigation
KeyCode::Char('h') | KeyCode::Left => {
self.cursor_position = self.cursor_position.saturating_sub(1);
None
}
KeyCode::Char('l') | KeyCode::Right => {
if self.cursor_position < self.input.len() {
self.cursor_position += 1;
}
None
}
KeyCode::Char('0') | KeyCode::Home => {
self.cursor_position = 0;
None
}
KeyCode::Char('$') | KeyCode::End => {
self.cursor_position = self.input.len();
None
}
KeyCode::Char('w') => {
// Jump to next word
self.cursor_position = self.next_word_position();
None
}
KeyCode::Char('b') => {
// Jump to previous word
self.cursor_position = self.prev_word_position();
None
}
// Editing
KeyCode::Char('x') => {
if self.cursor_position < self.input.len() {
self.input.remove(self.cursor_position);
}
None
}
KeyCode::Char('d') => {
// Delete line (dd would require tracking, simplify to clear)
self.input.clear();
self.cursor_position = 0;
None
}
// History
KeyCode::Char('k') | KeyCode::Up => {
self.history_prev();
None
}
KeyCode::Char('j') | KeyCode::Down => {
self.history_next();
None
}
_ => None,
}
}
/// Handle keys in insert mode
fn handle_insert_mode(&mut self, key: KeyEvent) -> Option<InputEvent> {
match key.code {
KeyCode::Esc => {
self.mode = VimMode::Normal;
// Move cursor back when exiting insert mode (vim behavior)
if self.cursor_position > 0 {
self.cursor_position -= 1;
}
Some(InputEvent::ModeChange(VimMode::Normal))
}
KeyCode::Enter => {
let message = self.input.clone();
if !message.trim().is_empty() {
self.history.push(message.clone());
self.history_index = self.history.len();
self.input.clear();
self.cursor_position = 0;
return Some(InputEvent::Message(message));
}
None
}
KeyCode::Char('e') if key.modifiers.contains(KeyModifiers::CONTROL) => {
Some(InputEvent::Expand)
}
KeyCode::Char('c') if key.modifiers.contains(KeyModifiers::CONTROL) => {
Some(InputEvent::Cancel)
}
KeyCode::Char(c) => {
self.input.insert(self.cursor_position, c);
self.cursor_position += 1;
None
}
KeyCode::Backspace => {
if self.cursor_position > 0 {
self.input.remove(self.cursor_position - 1);
self.cursor_position -= 1;
}
None
}
KeyCode::Delete => {
if self.cursor_position < self.input.len() {
self.input.remove(self.cursor_position);
}
None
}
KeyCode::Left => {
self.cursor_position = self.cursor_position.saturating_sub(1);
None
}
KeyCode::Right => {
if self.cursor_position < self.input.len() {
self.cursor_position += 1;
}
None
}
KeyCode::Home => {
self.cursor_position = 0;
None
}
KeyCode::End => {
self.cursor_position = self.input.len();
None
}
KeyCode::Up => {
self.history_prev();
None
}
KeyCode::Down => {
self.history_next();
None
}
_ => None,
}
}
/// Handle keys in command mode
fn handle_command_mode(&mut self, key: KeyEvent) -> Option<InputEvent> {
match key.code {
KeyCode::Esc => {
self.mode = VimMode::Normal;
self.input.clear();
self.cursor_position = 0;
Some(InputEvent::ModeChange(VimMode::Normal))
}
KeyCode::Enter => {
let command = self.input.clone();
self.mode = VimMode::Normal;
self.input.clear();
self.cursor_position = 0;
if !command.trim().is_empty() {
return Some(InputEvent::Command(command));
}
Some(InputEvent::ModeChange(VimMode::Normal))
}
KeyCode::Char(c) => {
self.input.insert(self.cursor_position, c);
self.cursor_position += 1;
None
}
KeyCode::Backspace => {
if self.cursor_position > 0 {
self.input.remove(self.cursor_position - 1);
self.cursor_position -= 1;
} else {
// Empty command, exit to normal mode
self.mode = VimMode::Normal;
return Some(InputEvent::ModeChange(VimMode::Normal));
}
None
}
KeyCode::Left => {
self.cursor_position = self.cursor_position.saturating_sub(1);
None
}
KeyCode::Right => {
if self.cursor_position < self.input.len() {
self.cursor_position += 1;
}
None
}
_ => None,
}
}
/// Handle keys in visual mode (simplified)
fn handle_visual_mode(&mut self, key: KeyEvent) -> Option<InputEvent> {
match key.code {
KeyCode::Esc => {
self.mode = VimMode::Normal;
Some(InputEvent::ModeChange(VimMode::Normal))
}
_ => None,
}
}
/// History navigation - previous
fn history_prev(&mut self) {
if !self.history.is_empty() && self.history_index > 0 {
self.history_index -= 1;
self.input = self.history[self.history_index].clone();
self.cursor_position = self.input.len();
}
}
/// History navigation - next
fn history_next(&mut self) {
if self.history_index < self.history.len().saturating_sub(1) {
self.history_index += 1;
self.input = self.history[self.history_index].clone();
self.cursor_position = self.input.len();
} else if self.history_index < self.history.len() {
self.history_index = self.history.len();
self.input.clear();
self.cursor_position = 0;
}
}
/// Find next word position
fn next_word_position(&self) -> usize {
let bytes = self.input.as_bytes();
let mut pos = self.cursor_position;
// Skip current word
while pos < bytes.len() && !bytes[pos].is_ascii_whitespace() {
pos += 1;
}
// Skip whitespace
while pos < bytes.len() && bytes[pos].is_ascii_whitespace() {
pos += 1;
}
pos
}
/// Find previous word position
fn prev_word_position(&self) -> usize {
let bytes = self.input.as_bytes();
let mut pos = self.cursor_position.saturating_sub(1);
// Skip whitespace
while pos > 0 && bytes[pos].is_ascii_whitespace() {
pos -= 1;
}
// Skip to start of word
while pos > 0 && !bytes[pos - 1].is_ascii_whitespace() {
pos -= 1;
}
pos
}
/// Render the borderless input (single line)
pub fn render(&self, frame: &mut Frame, area: Rect) {
let is_empty = self.input.is_empty();
let symbols = &self.theme.symbols;
// Mode-specific prefix
let prefix = match self.mode {
VimMode::Normal => Span::styled(
format!("{} ", symbols.mode_normal),
self.theme.status_dim,
),
VimMode::Insert => Span::styled(
format!("{} ", symbols.user_prefix),
self.theme.input_prefix,
),
VimMode::Command => Span::styled(
": ",
self.theme.input_prefix,
),
VimMode::Visual => Span::styled(
format!("{} ", symbols.mode_visual),
self.theme.status_accent,
),
};
// Cursor position handling
let (text_before, cursor_char, text_after) = if self.cursor_position < self.input.len() {
let before = &self.input[..self.cursor_position];
let cursor = &self.input[self.cursor_position..self.cursor_position + 1];
let after = &self.input[self.cursor_position + 1..];
(before, cursor, after)
} else {
(&self.input[..], " ", "")
};
let line = if is_empty && self.mode == VimMode::Insert {
Line::from(vec![
Span::raw(" "),
prefix,
Span::styled("█", self.theme.input_prefix),
Span::styled(" Type message...", self.theme.input_placeholder),
])
} else if is_empty && self.mode == VimMode::Command {
Line::from(vec![
Span::raw(" "),
prefix,
Span::styled("█", self.theme.input_prefix),
])
} else {
// Build cursor span with appropriate styling
let cursor_style = if self.mode == VimMode::Normal {
Style::default()
.bg(self.theme.palette.fg)
.fg(self.theme.palette.bg)
} else {
self.theme.input_prefix
};
let cursor_span = if self.mode == VimMode::Normal && !is_empty {
Span::styled(cursor_char.to_string(), cursor_style)
} else {
Span::styled("█", self.theme.input_prefix)
};
Line::from(vec![
Span::raw(" "),
prefix,
Span::styled(text_before.to_string(), self.theme.input_text),
cursor_span,
Span::styled(text_after.to_string(), self.theme.input_text),
])
};
let paragraph = Paragraph::new(line);
frame.render_widget(paragraph, area);
}
/// Clear input
pub fn clear(&mut self) {
self.input.clear();
self.cursor_position = 0;
}
/// Get current input text
pub fn text(&self) -> &str {
&self.input
}
/// Set input text
pub fn set_text(&mut self, text: String) {
self.input = text;
self.cursor_position = self.input.len();
}
/// Update theme
pub fn set_theme(&mut self, theme: Theme) {
self.theme = theme;
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_mode_transitions() {
let theme = Theme::default();
let mut input = InputBox::new(theme);
// Start in insert mode
assert_eq!(input.mode(), VimMode::Insert);
// Escape to normal mode
let event = input.handle_key(KeyEvent::from(KeyCode::Esc));
assert!(matches!(event, Some(InputEvent::ModeChange(VimMode::Normal))));
assert_eq!(input.mode(), VimMode::Normal);
// 'i' to insert mode
let event = input.handle_key(KeyEvent::from(KeyCode::Char('i')));
assert!(matches!(event, Some(InputEvent::ModeChange(VimMode::Insert))));
assert_eq!(input.mode(), VimMode::Insert);
}
#[test]
fn test_insert_text() {
let theme = Theme::default();
let mut input = InputBox::new(theme);
input.handle_key(KeyEvent::from(KeyCode::Char('h')));
input.handle_key(KeyEvent::from(KeyCode::Char('i')));
assert_eq!(input.text(), "hi");
}
#[test]
fn test_command_mode() {
let theme = Theme::default();
let mut input = InputBox::new(theme);
// Escape to normal, then : to command
input.handle_key(KeyEvent::from(KeyCode::Esc));
input.handle_key(KeyEvent::from(KeyCode::Char(':')));
assert_eq!(input.mode(), VimMode::Command);
// Type command
input.handle_key(KeyEvent::from(KeyCode::Char('q')));
input.handle_key(KeyEvent::from(KeyCode::Char('u')));
input.handle_key(KeyEvent::from(KeyCode::Char('i')));
input.handle_key(KeyEvent::from(KeyCode::Char('t')));
assert_eq!(input.text(), "quit");
// Submit command
let event = input.handle_key(KeyEvent::from(KeyCode::Enter));
assert!(matches!(event, Some(InputEvent::Command(cmd)) if cmd == "quit"));
}
}

View File

@@ -0,0 +1,19 @@
//! TUI components for the borderless multi-provider design
mod autocomplete;
mod chat_panel;
mod command_help;
mod input_box;
mod permission_popup;
mod provider_tabs;
mod status_bar;
mod todo_panel;
pub use autocomplete::{Autocomplete, AutocompleteOption, AutocompleteResult};
pub use chat_panel::{ChatMessage, ChatPanel, DisplayMessage};
pub use command_help::{Command, CommandHelp};
pub use input_box::{InputBox, InputEvent};
pub use permission_popup::{PermissionOption, PermissionPopup};
pub use provider_tabs::ProviderTabs;
pub use status_bar::{AppState, StatusBar};
pub use todo_panel::TodoPanel;

View File

@@ -0,0 +1,196 @@
use crate::theme::Theme;
use crossterm::event::{KeyCode, KeyEvent};
use permissions::PermissionDecision;
use ratatui::{
layout::{Constraint, Direction, Layout, Rect},
style::{Modifier, Style},
text::{Line, Span},
widgets::{Block, Borders, Clear, Paragraph},
Frame,
};
#[derive(Debug, Clone)]
pub enum PermissionOption {
AllowOnce,
AlwaysAllow,
Deny,
Explain,
}
pub struct PermissionPopup {
tool: String,
context: Option<String>,
selected: usize,
theme: Theme,
}
impl PermissionPopup {
pub fn new(tool: String, context: Option<String>, theme: Theme) -> Self {
Self {
tool,
context,
selected: 0,
theme,
}
}
pub fn handle_key(&mut self, key: KeyEvent) -> Option<PermissionOption> {
match key.code {
KeyCode::Char('a') => Some(PermissionOption::AllowOnce),
KeyCode::Char('A') => Some(PermissionOption::AlwaysAllow),
KeyCode::Char('d') => Some(PermissionOption::Deny),
KeyCode::Char('?') => Some(PermissionOption::Explain),
KeyCode::Up => {
self.selected = self.selected.saturating_sub(1);
None
}
KeyCode::Down => {
if self.selected < 3 {
self.selected += 1;
}
None
}
KeyCode::Enter => match self.selected {
0 => Some(PermissionOption::AllowOnce),
1 => Some(PermissionOption::AlwaysAllow),
2 => Some(PermissionOption::Deny),
3 => Some(PermissionOption::Explain),
_ => None,
},
KeyCode::Esc => Some(PermissionOption::Deny),
_ => None,
}
}
pub fn render(&self, frame: &mut Frame, area: Rect) {
// Center the popup
let popup_area = crate::layout::AppLayout::center_popup(area, 64, 14);
// Clear the area behind the popup
frame.render_widget(Clear, popup_area);
// Render popup with styled border
let block = Block::default()
.borders(Borders::ALL)
.border_style(self.theme.popup_border)
.style(self.theme.popup_bg)
.title(Line::from(vec![
Span::raw(" "),
Span::styled("🔒", self.theme.popup_title),
Span::raw(" "),
Span::styled("Permission Required", self.theme.popup_title),
Span::raw(" "),
]));
frame.render_widget(block, popup_area);
// Split popup into sections
let inner = popup_area.inner(ratatui::layout::Margin {
vertical: 1,
horizontal: 2,
});
let sections = Layout::default()
.direction(Direction::Vertical)
.constraints([
Constraint::Length(2), // Tool name with box
Constraint::Length(3), // Context (if any)
Constraint::Length(1), // Separator
Constraint::Length(1), // Option 1
Constraint::Length(1), // Option 2
Constraint::Length(1), // Option 3
Constraint::Length(1), // Option 4
Constraint::Length(1), // Help text
])
.split(inner);
// Tool name with highlight
let tool_line = Line::from(vec![
Span::styled("⚡ Tool: ", Style::default().fg(self.theme.palette.warning)),
Span::styled(&self.tool, self.theme.popup_title),
]);
frame.render_widget(Paragraph::new(tool_line), sections[0]);
// Context with wrapping
if let Some(ctx) = &self.context {
let context_text = if ctx.chars().count() > 100 {
// Character-based truncation avoids panicking on multi-byte UTF-8
let head: String = ctx.chars().take(100).collect();
format!("{}...", head)
} else {
ctx.clone()
};
let context_lines = textwrap::wrap(&context_text, sections[1].width.saturating_sub(2) as usize);
let mut lines = vec![
Line::from(vec![
Span::styled("📝 Context: ", Style::default().fg(self.theme.palette.info)),
])
];
for line in context_lines.iter().take(2) {
lines.push(Line::from(vec![
Span::raw(" "),
Span::styled(line.to_string(), Style::default().fg(self.theme.palette.fg_dim)),
]));
}
frame.render_widget(Paragraph::new(lines), sections[1]);
}
// Separator
let separator = Line::styled(
"".repeat(sections[2].width as usize),
Style::default().fg(self.theme.palette.divider_fg),
);
frame.render_widget(Paragraph::new(separator), sections[2]);
// Options with icons and colors
let options = [
("✓", " [a] Allow once", self.theme.palette.success, 0),
("✓✓", " [A] Always allow", self.theme.palette.primary, 1),
("✗", " [d] Deny", self.theme.palette.error, 2),
("?", " [?] Explain", self.theme.palette.info, 3),
];
for (icon, text, color, idx) in options.iter() {
let (style, prefix) = if self.selected == *idx {
(
self.theme.selected,
"▸"
)
} else {
(
Style::default().fg(*color),
" "
)
};
let line = Line::from(vec![
Span::styled(prefix, style),
Span::styled(*icon, style),
Span::styled(*text, style),
]);
frame.render_widget(Paragraph::new(line), sections[3 + idx]);
}
// Help text at bottom
let help_line = Line::from(vec![
Span::styled(
"↑↓ Navigate · Enter to select · Esc to deny",
Style::default().fg(self.theme.palette.fg_dim).add_modifier(Modifier::ITALIC),
),
]);
frame.render_widget(Paragraph::new(help_line), sections[7]);
}
}
impl PermissionOption {
pub fn to_decision(&self) -> Option<PermissionDecision> {
match self {
PermissionOption::AllowOnce => Some(PermissionDecision::Allow),
PermissionOption::AlwaysAllow => Some(PermissionDecision::Allow),
PermissionOption::Deny => Some(PermissionDecision::Deny),
PermissionOption::Explain => None, // Special handling needed
}
}
pub fn should_persist(&self) -> bool {
matches!(self, PermissionOption::AlwaysAllow)
}
}

View File

@@ -0,0 +1,189 @@
//! Provider tabs component for multi-LLM support
//!
//! Displays horizontal tabs for switching between providers (Claude, Ollama, OpenAI)
//! with icons and keybind hints.
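//!
//! A minimal usage sketch (not compiled here):
//!
//! ```rust,ignore
//! let mut tabs = ProviderTabs::new(Theme::default()); // defaults to Ollama
//! tabs.select_by_number(1);                           // [1] -> Claude
//! tabs.next();                                        // cycle forward
//! tabs.render(frame, tabs_area);
//! ```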
use crate::theme::{Provider, Theme};
use ratatui::{
layout::Rect,
style::Style,
text::{Line, Span},
widgets::Paragraph,
Frame,
};
/// Provider tab state and rendering
pub struct ProviderTabs {
active: Provider,
theme: Theme,
}
impl ProviderTabs {
/// Create new provider tabs with default provider
pub fn new(theme: Theme) -> Self {
Self {
active: Provider::Ollama, // Default to Ollama (local)
theme,
}
}
/// Create with specific active provider
pub fn with_provider(provider: Provider, theme: Theme) -> Self {
Self {
active: provider,
theme,
}
}
/// Get the currently active provider
pub fn active(&self) -> Provider {
self.active
}
/// Set the active provider
pub fn set_active(&mut self, provider: Provider) {
self.active = provider;
}
/// Cycle to the next provider
pub fn next(&mut self) {
self.active = match self.active {
Provider::Claude => Provider::Ollama,
Provider::Ollama => Provider::OpenAI,
Provider::OpenAI => Provider::Claude,
};
}
/// Cycle to the previous provider
pub fn previous(&mut self) {
self.active = match self.active {
Provider::Claude => Provider::OpenAI,
Provider::Ollama => Provider::Claude,
Provider::OpenAI => Provider::Ollama,
};
}
/// Select provider by number (1, 2, 3)
pub fn select_by_number(&mut self, num: u8) {
self.active = match num {
1 => Provider::Claude,
2 => Provider::Ollama,
3 => Provider::OpenAI,
_ => self.active,
};
}
/// Update the theme
pub fn set_theme(&mut self, theme: Theme) {
self.theme = theme;
}
/// Render the provider tabs (borderless)
pub fn render(&self, frame: &mut Frame, area: Rect) {
let mut spans = Vec::new();
// Add spacing at start
spans.push(Span::raw(" "));
for (i, provider) in Provider::all().iter().enumerate() {
let is_active = *provider == self.active;
let icon = self.theme.provider_icon(*provider);
let name = provider.name();
let number = (i + 1).to_string();
// Keybind hint
spans.push(Span::styled(
format!("[{}] ", number),
self.theme.status_dim,
));
// Icon and name
let style = if is_active {
Style::default()
.fg(self.theme.provider_color(*provider))
.add_modifier(ratatui::style::Modifier::BOLD)
} else {
self.theme.tab_inactive
};
spans.push(Span::styled(format!("{} ", icon), style));
spans.push(Span::styled(name.to_string(), style));
// Separator between tabs (not after last)
if i < Provider::all().len() - 1 {
spans.push(Span::styled(
format!(" {} ", self.theme.symbols.vertical_separator),
self.theme.status_dim,
));
}
}
// Tab cycling hint on the right
spans.push(Span::raw(" "));
spans.push(Span::styled("[Tab] cycle", self.theme.status_dim));
let line = Line::from(spans);
let paragraph = Paragraph::new(line);
frame.render_widget(paragraph, area);
}
/// Render a compact version (just active provider)
pub fn render_compact(&self, frame: &mut Frame, area: Rect) {
let icon = self.theme.provider_icon(self.active);
let name = self.active.name();
let line = Line::from(vec![
Span::raw(" "),
Span::styled(
format!("{} {}", icon, name),
Style::default()
.fg(self.theme.provider_color(self.active))
.add_modifier(ratatui::style::Modifier::BOLD),
),
]);
let paragraph = Paragraph::new(line);
frame.render_widget(paragraph, area);
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_provider_cycling() {
let theme = Theme::default();
let mut tabs = ProviderTabs::new(theme);
assert_eq!(tabs.active(), Provider::Ollama);
tabs.next();
assert_eq!(tabs.active(), Provider::OpenAI);
tabs.next();
assert_eq!(tabs.active(), Provider::Claude);
tabs.next();
assert_eq!(tabs.active(), Provider::Ollama);
}
#[test]
fn test_select_by_number() {
let theme = Theme::default();
let mut tabs = ProviderTabs::new(theme);
tabs.select_by_number(1);
assert_eq!(tabs.active(), Provider::Claude);
tabs.select_by_number(2);
assert_eq!(tabs.active(), Provider::Ollama);
tabs.select_by_number(3);
assert_eq!(tabs.active(), Provider::OpenAI);
// Invalid number should not change
tabs.select_by_number(4);
assert_eq!(tabs.active(), Provider::OpenAI);
}
}

View File

@@ -0,0 +1,188 @@
//! Minimal status bar component
//!
//! Clean, readable status bar with essential info only.
//! Format: ` Mode │ N msgs │ ~Nk tok │ state`
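//!
//! A minimal usage sketch (`stats` is an `agent_core::SessionStats` supplied
//! by the session; not compiled here):
//!
//! ```rust,ignore
//! let mut status = StatusBar::new("gpt-4".to_string(), Mode::Plan, Theme::default());
//! status.set_state(AppState::Streaming);
//! status.update_stats(stats);        // latest stats from the session
//! status.render(frame, status_area);
//! ```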
use crate::theme::{Provider, Theme, VimMode};
use agent_core::SessionStats;
use permissions::Mode;
use ratatui::{
layout::Rect,
style::Style,
text::{Line, Span},
widgets::Paragraph,
Frame,
};
/// Application state for status display
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum AppState {
Idle,
Streaming,
WaitingPermission,
Error,
}
impl AppState {
pub fn label(&self) -> &'static str {
match self {
AppState::Idle => "idle",
AppState::Streaming => "streaming...",
AppState::WaitingPermission => "waiting",
AppState::Error => "error",
}
}
}
pub struct StatusBar {
provider: Provider,
model: String,
mode: Mode,
vim_mode: VimMode,
stats: SessionStats,
last_tool: Option<String>,
state: AppState,
estimated_cost: f64,
planning_mode: bool,
theme: Theme,
}
impl StatusBar {
pub fn new(model: String, mode: Mode, theme: Theme) -> Self {
Self {
provider: Provider::Ollama, // Default provider
model,
mode,
vim_mode: VimMode::Insert,
stats: SessionStats::new(),
last_tool: None,
state: AppState::Idle,
estimated_cost: 0.0,
planning_mode: false,
theme,
}
}
/// Set the active provider
pub fn set_provider(&mut self, provider: Provider) {
self.provider = provider;
}
/// Set the current model
pub fn set_model(&mut self, model: String) {
self.model = model;
}
/// Update session stats
pub fn update_stats(&mut self, stats: SessionStats) {
self.stats = stats;
}
/// Set the last used tool
pub fn set_last_tool(&mut self, tool: String) {
self.last_tool = Some(tool);
}
/// Set application state
pub fn set_state(&mut self, state: AppState) {
self.state = state;
}
/// Set vim mode for display
pub fn set_vim_mode(&mut self, mode: VimMode) {
self.vim_mode = mode;
}
/// Add to estimated cost
pub fn add_cost(&mut self, cost: f64) {
self.estimated_cost += cost;
}
/// Reset cost
pub fn reset_cost(&mut self) {
self.estimated_cost = 0.0;
}
/// Update theme
pub fn set_theme(&mut self, theme: Theme) {
self.theme = theme;
}
/// Set planning mode status
pub fn set_planning_mode(&mut self, active: bool) {
self.planning_mode = active;
}
/// Render the minimal status bar
///
/// Format: ` Mode │ N msgs │ ~Nk tok │ state`
pub fn render(&self, frame: &mut Frame, area: Rect) {
let sep = self.theme.symbols.vertical_separator;
let sep_style = Style::default().fg(self.theme.palette.border);
// Permission mode
let mode_str = if self.planning_mode {
"PLAN"
} else {
match self.mode {
Mode::Plan => "Plan",
Mode::AcceptEdits => "Edit",
Mode::Code => "Code",
}
};
// Format token count
let tokens_str = if self.stats.estimated_tokens >= 1000 {
format!("~{}k tok", self.stats.estimated_tokens / 1000)
} else {
format!("~{} tok", self.stats.estimated_tokens)
};
// State style - only highlight non-idle states
let state_style = match self.state {
AppState::Idle => self.theme.status_dim,
AppState::Streaming => Style::default().fg(self.theme.palette.success),
AppState::WaitingPermission => Style::default().fg(self.theme.palette.warning),
AppState::Error => Style::default().fg(self.theme.palette.error),
};
// Build minimal status line
let spans = vec![
Span::styled(" ", self.theme.status_dim),
// Mode
Span::styled(mode_str, self.theme.status_dim),
Span::styled(format!(" {} ", sep), sep_style),
// Message count
Span::styled(format!("{} msgs", self.stats.total_messages), self.theme.status_dim),
Span::styled(format!(" {} ", sep), sep_style),
// Token count
Span::styled(&tokens_str, self.theme.status_dim),
Span::styled(format!(" {} ", sep), sep_style),
// State
Span::styled(self.state.label(), state_style),
];
let line = Line::from(spans);
let paragraph = Paragraph::new(line);
frame.render_widget(paragraph, area);
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_status_bar_creation() {
let theme = Theme::default();
let status_bar = StatusBar::new("gpt-4".to_string(), Mode::Plan, theme);
assert_eq!(status_bar.model, "gpt-4");
}
#[test]
fn test_app_state_display() {
assert_eq!(AppState::Idle.label(), "idle");
assert_eq!(AppState::Streaming.label(), "streaming...");
assert_eq!(AppState::Error.label(), "error");
}
}

View File

@@ -0,0 +1,200 @@
//! Todo panel component for displaying task list
//!
//! Shows the current todo list with status indicators and progress.
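//!
//! A minimal usage sketch (`todos` is a `tools_todo::TodoList`; not compiled
//! here):
//!
//! ```rust,ignore
//! let mut panel = TodoPanel::new(Theme::default());
//! let height = panel.min_height();      // reserve layout space (5 expanded, 1 collapsed)
//! panel.render(frame, todo_area, &todos);
//! panel.toggle();                       // collapse to the one-line summary
//! ```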
use ratatui::{
layout::Rect,
style::{Color, Modifier, Style},
text::{Line, Span},
widgets::{Block, Borders, Paragraph},
Frame,
};
use tools_todo::{Todo, TodoList, TodoStatus};
use crate::theme::Theme;
/// Todo panel component
pub struct TodoPanel {
theme: Theme,
collapsed: bool,
}
impl TodoPanel {
pub fn new(theme: Theme) -> Self {
Self {
theme,
collapsed: false,
}
}
/// Toggle collapsed state
pub fn toggle(&mut self) {
self.collapsed = !self.collapsed;
}
/// Check if collapsed
pub fn is_collapsed(&self) -> bool {
self.collapsed
}
/// Update theme
pub fn set_theme(&mut self, theme: Theme) {
self.theme = theme;
}
/// Get the minimum height needed for the panel
pub fn min_height(&self) -> u16 {
if self.collapsed {
1
} else {
5
}
}
/// Render the todo panel
pub fn render(&self, frame: &mut Frame, area: Rect, todos: &TodoList) {
if self.collapsed {
self.render_collapsed(frame, area, todos);
} else {
self.render_expanded(frame, area, todos);
}
}
/// Render collapsed view (single line summary)
fn render_collapsed(&self, frame: &mut Frame, area: Rect, todos: &TodoList) {
let items = todos.read();
let completed = items.iter().filter(|t| t.status == TodoStatus::Completed).count();
let in_progress = items.iter().filter(|t| t.status == TodoStatus::InProgress).count();
let pending = items.iter().filter(|t| t.status == TodoStatus::Pending).count();
let summary = if items.is_empty() {
"No tasks".to_string()
} else {
format!(
"{} {} / {} {} / {} {}",
self.theme.symbols.check, completed,
self.theme.symbols.streaming, in_progress,
self.theme.symbols.bullet, pending
)
};
let line = Line::from(vec![
Span::styled("Tasks: ", self.theme.status_bar),
Span::styled(summary, self.theme.status_dim),
Span::styled(" [t to expand]", self.theme.status_dim),
]);
let paragraph = Paragraph::new(line);
frame.render_widget(paragraph, area);
}
/// Render expanded view with task list
fn render_expanded(&self, frame: &mut Frame, area: Rect, todos: &TodoList) {
let items = todos.read();
let mut lines: Vec<Line> = Vec::new();
// Header
lines.push(Line::from(vec![
Span::styled("Tasks", Style::default().add_modifier(Modifier::BOLD)),
Span::styled(" [t to collapse]", self.theme.status_dim),
]));
if items.is_empty() {
lines.push(Line::from(Span::styled(
" No active tasks",
self.theme.status_dim,
)));
} else {
// Show tasks (limit to available space)
let max_items = (area.height as usize).saturating_sub(2);
let display_items: Vec<&Todo> = items.iter().take(max_items).collect();
for item in display_items {
let (icon, style) = match item.status {
TodoStatus::Completed => (
self.theme.symbols.check,
Style::default().fg(Color::Green),
),
TodoStatus::InProgress => (
self.theme.symbols.streaming,
Style::default().fg(Color::Yellow),
),
TodoStatus::Pending => (
self.theme.symbols.bullet,
self.theme.status_dim,
),
};
// Use active form for in-progress, content for others
let text = if item.status == TodoStatus::InProgress {
&item.active_form
} else {
&item.content
};
// Truncate if too long
let max_width = area.width.saturating_sub(6) as usize;
let display_text = if text.chars().count() > max_width {
// Character-based truncation avoids panicking on multi-byte UTF-8
let head: String = text.chars().take(max_width.saturating_sub(3)).collect();
format!("{}...", head)
} else {
text.clone()
};
lines.push(Line::from(vec![
Span::styled(format!(" {} ", icon), style),
Span::styled(display_text, style),
]));
}
// Show overflow indicator if needed
if items.len() > max_items {
lines.push(Line::from(Span::styled(
format!(" ... and {} more", items.len() - max_items),
self.theme.status_dim,
)));
}
}
let block = Block::default()
.borders(Borders::TOP)
.border_style(self.theme.status_dim);
let paragraph = Paragraph::new(lines).block(block);
frame.render_widget(paragraph, area);
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_todo_panel_creation() {
let theme = Theme::default();
let panel = TodoPanel::new(theme);
assert!(!panel.is_collapsed());
}
#[test]
fn test_todo_panel_toggle() {
let theme = Theme::default();
let mut panel = TodoPanel::new(theme);
assert!(!panel.is_collapsed());
panel.toggle();
assert!(panel.is_collapsed());
panel.toggle();
assert!(!panel.is_collapsed());
}
#[test]
fn test_min_height() {
let theme = Theme::default();
let mut panel = TodoPanel::new(theme);
assert_eq!(panel.min_height(), 5);
panel.toggle();
assert_eq!(panel.min_height(), 1);
}
}

View File

@@ -0,0 +1,53 @@
use crossterm::event::{KeyCode, KeyEvent, KeyModifiers};
use serde_json::Value;
/// Application events that drive the TUI
#[derive(Debug, Clone)]
pub enum AppEvent {
/// User input from keyboard
Input(KeyEvent),
/// User submitted a message
UserMessage(String),
/// LLM streaming started
StreamStart,
/// LLM response chunk (streaming)
LlmChunk(String),
/// LLM streaming completed
StreamEnd { response: String },
/// LLM streaming error
StreamError(String),
/// Tool call started
ToolCall { name: String, args: Value },
/// Tool execution result
ToolResult { success: bool, output: String },
/// Permission request from agent
PermissionRequest {
tool: String,
context: Option<String>,
},
/// Session statistics updated
StatusUpdate(agent_core::SessionStats),
/// Terminal was resized
Resize { width: u16, height: u16 },
/// Mouse scroll up
ScrollUp,
/// Mouse scroll down
ScrollDown,
/// Toggle the todo panel
ToggleTodo,
/// Application should quit
Quit,
}
/// Process keyboard input into app events
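///
/// A minimal sketch of the mapping (doc example, not compiled here):
///
/// ```rust,ignore
/// let ev = handle_key_event(KeyEvent::new(KeyCode::Char('c'), KeyModifiers::CONTROL));
/// assert!(matches!(ev, Some(AppEvent::Quit)));
/// ```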
pub fn handle_key_event(key: KeyEvent) -> Option<AppEvent> {
match key.code {
KeyCode::Char('c') if key.modifiers.contains(KeyModifiers::CONTROL) => {
Some(AppEvent::Quit)
}
KeyCode::Char('t') if key.modifiers.contains(KeyModifiers::CONTROL) => {
Some(AppEvent::ToggleTodo)
}
_ => Some(AppEvent::Input(key)),
}
}

View File

@@ -0,0 +1,532 @@
//! Output formatting with markdown parsing and syntax highlighting
//!
//! This module provides rich text rendering for the TUI, converting markdown
//! content into styled ratatui spans with proper syntax highlighting for code blocks.
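//!
//! A minimal usage sketch (not compiled here):
//!
//! ```rust,ignore
//! let renderer = MarkdownRenderer::new();
//! let formatted = renderer.render("# Heading\n\nSome *emphasis* and `inline code`.");
//! assert!(!formatted.is_empty()); // styled `Line`s ready for a ratatui Paragraph
//! ```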
use pulldown_cmark::{CodeBlockKind, Event, Parser, Tag, TagEnd};
use ratatui::style::{Color, Modifier, Style};
use ratatui::text::{Line, Span};
use syntect::easy::HighlightLines;
use syntect::highlighting::{Theme, ThemeSet};
use syntect::parsing::SyntaxSet;
use syntect::util::LinesWithEndings;
/// Highlighter for syntax highlighting code blocks
pub struct SyntaxHighlighter {
syntax_set: SyntaxSet,
theme: Theme,
}
impl SyntaxHighlighter {
/// Create a new syntax highlighter with default theme
pub fn new() -> Self {
let syntax_set = SyntaxSet::load_defaults_newlines();
let theme_set = ThemeSet::load_defaults();
// Use a dark theme that works well in terminals
let theme = theme_set.themes["base16-ocean.dark"].clone();
Self { syntax_set, theme }
}
/// Create highlighter with a specific theme name
pub fn with_theme(theme_name: &str) -> Self {
let syntax_set = SyntaxSet::load_defaults_newlines();
let theme_set = ThemeSet::load_defaults();
let theme = theme_set
.themes
.get(theme_name)
.cloned()
.unwrap_or_else(|| theme_set.themes["base16-ocean.dark"].clone());
Self { syntax_set, theme }
}
/// Get available theme names
pub fn available_themes() -> Vec<&'static str> {
vec![
"base16-ocean.dark",
"base16-eighties.dark",
"base16-mocha.dark",
"base16-ocean.light",
"InspiredGitHub",
"Solarized (dark)",
"Solarized (light)",
]
}
/// Highlight a code block and return styled lines
pub fn highlight_code(&self, code: &str, language: &str) -> Vec<Line<'static>> {
// Find syntax for the language
let syntax = self
.syntax_set
.find_syntax_by_token(language)
.or_else(|| self.syntax_set.find_syntax_by_extension(language))
.unwrap_or_else(|| self.syntax_set.find_syntax_plain_text());
let mut highlighter = HighlightLines::new(syntax, &self.theme);
let mut lines = Vec::new();
for line in LinesWithEndings::from(code) {
let Ok(ranges) = highlighter.highlight_line(line, &self.syntax_set) else {
// Fallback to plain text if highlighting fails
lines.push(Line::from(Span::raw(line.trim_end().to_string())));
continue;
};
let spans: Vec<Span<'static>> = ranges
.into_iter()
.map(|(style, text)| {
let fg = syntect_to_ratatui_color(style.foreground);
let ratatui_style = Style::default().fg(fg);
Span::styled(text.trim_end_matches('\n').to_string(), ratatui_style)
})
.collect();
lines.push(Line::from(spans));
}
lines
}
}
impl Default for SyntaxHighlighter {
fn default() -> Self {
Self::new()
}
}
/// Convert syntect color to ratatui color
fn syntect_to_ratatui_color(color: syntect::highlighting::Color) -> Color {
Color::Rgb(color.r, color.g, color.b)
}
/// Parsed markdown content ready for rendering
#[derive(Debug, Clone)]
pub struct FormattedContent {
pub lines: Vec<Line<'static>>,
}
impl FormattedContent {
/// Create empty formatted content
pub fn empty() -> Self {
Self { lines: Vec::new() }
}
/// Get the number of lines
pub fn len(&self) -> usize {
self.lines.len()
}
/// Check if content is empty
pub fn is_empty(&self) -> bool {
self.lines.is_empty()
}
}
/// Markdown parser that converts markdown to styled ratatui lines
pub struct MarkdownRenderer {
highlighter: SyntaxHighlighter,
}
impl MarkdownRenderer {
/// Create a new markdown renderer
pub fn new() -> Self {
Self {
highlighter: SyntaxHighlighter::new(),
}
}
/// Create renderer with custom highlighter
pub fn with_highlighter(highlighter: SyntaxHighlighter) -> Self {
Self { highlighter }
}
/// Render markdown text to formatted content
pub fn render(&self, markdown: &str) -> FormattedContent {
let parser = Parser::new(markdown);
let mut lines: Vec<Line<'static>> = Vec::new();
let mut current_line_spans: Vec<Span<'static>> = Vec::new();
// State tracking
let mut in_code_block = false;
let mut code_block_lang = String::new();
let mut code_block_content = String::new();
let mut current_style = Style::default();
let mut list_depth: usize = 0;
let mut ordered_list_index: Option<u64> = None;
for event in parser {
match event {
Event::Start(tag) => match tag {
Tag::Heading { level, .. } => {
// Flush current line
if !current_line_spans.is_empty() {
lines.push(Line::from(std::mem::take(&mut current_line_spans)));
}
// Style for headings
current_style = match level {
pulldown_cmark::HeadingLevel::H1 => Style::default()
.fg(Color::Cyan)
.add_modifier(Modifier::BOLD),
pulldown_cmark::HeadingLevel::H2 => Style::default()
.fg(Color::Blue)
.add_modifier(Modifier::BOLD),
pulldown_cmark::HeadingLevel::H3 => Style::default()
.fg(Color::Green)
.add_modifier(Modifier::BOLD),
_ => Style::default().add_modifier(Modifier::BOLD),
};
// Add heading prefix
let prefix = "#".repeat(level as usize);
current_line_spans.push(Span::styled(
format!("{} ", prefix),
Style::default().fg(Color::DarkGray),
));
}
Tag::Paragraph => {
// Start a new paragraph
if !current_line_spans.is_empty() {
lines.push(Line::from(std::mem::take(&mut current_line_spans)));
}
}
Tag::CodeBlock(kind) => {
in_code_block = true;
code_block_content.clear();
code_block_lang = match kind {
CodeBlockKind::Fenced(lang) => lang.to_string(),
CodeBlockKind::Indented => String::new(),
};
// Flush current line and add code block header
if !current_line_spans.is_empty() {
lines.push(Line::from(std::mem::take(&mut current_line_spans)));
}
// Add code fence line
let fence_line = if code_block_lang.is_empty() {
"```".to_string()
} else {
format!("```{}", code_block_lang)
};
lines.push(Line::from(Span::styled(
fence_line,
Style::default().fg(Color::DarkGray),
)));
}
Tag::List(start) => {
list_depth += 1;
ordered_list_index = start;
}
Tag::Item => {
// Flush current line
if !current_line_spans.is_empty() {
lines.push(Line::from(std::mem::take(&mut current_line_spans)));
}
// Add list marker
let indent = " ".repeat(list_depth.saturating_sub(1));
let marker = if let Some(idx) = ordered_list_index {
ordered_list_index = Some(idx + 1);
format!("{}{}. ", indent, idx)
} else {
format!("{}- ", indent)
};
current_line_spans.push(Span::styled(
marker,
Style::default().fg(Color::Yellow),
));
}
Tag::Emphasis => {
current_style = current_style.add_modifier(Modifier::ITALIC);
}
Tag::Strong => {
current_style = current_style.add_modifier(Modifier::BOLD);
}
Tag::Strikethrough => {
current_style = current_style.add_modifier(Modifier::CROSSED_OUT);
}
Tag::Link { dest_url, .. } => {
current_style = Style::default()
.fg(Color::Blue)
.add_modifier(Modifier::UNDERLINED);
// Store URL for later
current_line_spans.push(Span::styled(
"[",
Style::default().fg(Color::DarkGray),
));
// URL will be shown after link text
code_block_content = dest_url.to_string();
}
Tag::BlockQuote(_) => {
if !current_line_spans.is_empty() {
lines.push(Line::from(std::mem::take(&mut current_line_spans)));
}
current_line_spans.push(Span::styled(
"",
Style::default().fg(Color::DarkGray),
));
current_style = Style::default().fg(Color::Gray).add_modifier(Modifier::ITALIC);
}
_ => {}
},
Event::End(tag_end) => match tag_end {
TagEnd::Heading(_) => {
current_style = Style::default();
lines.push(Line::from(std::mem::take(&mut current_line_spans)));
}
TagEnd::Paragraph => {
lines.push(Line::from(std::mem::take(&mut current_line_spans)));
lines.push(Line::from("")); // Empty line after paragraph
}
TagEnd::CodeBlock => {
in_code_block = false;
// Highlight and add code content
let highlighted =
self.highlighter.highlight_code(&code_block_content, &code_block_lang);
lines.extend(highlighted);
// Add closing fence
lines.push(Line::from(Span::styled(
"```",
Style::default().fg(Color::DarkGray),
)));
code_block_content.clear();
code_block_lang.clear();
}
TagEnd::List(_) => {
list_depth = list_depth.saturating_sub(1);
if list_depth == 0 {
ordered_list_index = None;
}
}
TagEnd::Item => {
if !current_line_spans.is_empty() {
lines.push(Line::from(std::mem::take(&mut current_line_spans)));
}
}
TagEnd::Emphasis | TagEnd::Strong | TagEnd::Strikethrough => {
current_style = Style::default();
}
TagEnd::Link => {
current_line_spans.push(Span::styled(
"]",
Style::default().fg(Color::DarkGray),
));
current_line_spans.push(Span::styled(
format!("({})", code_block_content),
Style::default().fg(Color::DarkGray),
));
code_block_content.clear();
current_style = Style::default();
}
TagEnd::BlockQuote => {
current_style = Style::default();
if !current_line_spans.is_empty() {
lines.push(Line::from(std::mem::take(&mut current_line_spans)));
}
}
_ => {}
},
Event::Text(text) => {
if in_code_block {
code_block_content.push_str(&text);
} else {
current_line_spans.push(Span::styled(text.to_string(), current_style));
}
}
Event::Code(code) => {
// Inline code
current_line_spans.push(Span::styled(
format!("`{}`", code),
Style::default().fg(Color::Magenta),
));
}
Event::SoftBreak => {
current_line_spans.push(Span::raw(" "));
}
Event::HardBreak => {
lines.push(Line::from(std::mem::take(&mut current_line_spans)));
}
Event::Rule => {
if !current_line_spans.is_empty() {
lines.push(Line::from(std::mem::take(&mut current_line_spans)));
}
lines.push(Line::from(Span::styled(
"".repeat(40),
Style::default().fg(Color::DarkGray),
)));
}
_ => {}
}
}
// Flush any remaining content
if !current_line_spans.is_empty() {
lines.push(Line::from(current_line_spans));
}
FormattedContent { lines }
}
/// Render plain text (no markdown parsing)
pub fn render_plain(&self, text: &str) -> FormattedContent {
let lines = text
.lines()
.map(|line| Line::from(Span::raw(line.to_string())))
.collect();
FormattedContent { lines }
}
/// Render a diff with +/- highlighting
pub fn render_diff(&self, diff: &str) -> FormattedContent {
let lines = diff
.lines()
.map(|line| {
let style = if line.starts_with('+') && !line.starts_with("+++") {
Style::default().fg(Color::Green)
} else if line.starts_with('-') && !line.starts_with("---") {
Style::default().fg(Color::Red)
} else if line.starts_with("@@") {
Style::default().fg(Color::Cyan)
} else if line.starts_with("diff ") || line.starts_with("index ") {
Style::default().fg(Color::Yellow)
} else {
Style::default()
};
Line::from(Span::styled(line.to_string(), style))
})
.collect();
FormattedContent { lines }
}
}
impl Default for MarkdownRenderer {
fn default() -> Self {
Self::new()
}
}
/// Format a file path with syntax highlighting based on extension
pub fn format_file_path(path: &str) -> Span<'static> {
let color = if path.ends_with(".rs") {
Color::Rgb(222, 165, 132) // Rust orange
} else if path.ends_with(".toml") {
Color::Rgb(156, 220, 254) // Light blue
} else if path.ends_with(".md") {
Color::Rgb(86, 156, 214) // Blue
} else if path.ends_with(".json") {
Color::Rgb(206, 145, 120) // Brown
} else if path.ends_with(".ts") || path.ends_with(".tsx") {
Color::Rgb(49, 120, 198) // TypeScript blue
} else if path.ends_with(".js") || path.ends_with(".jsx") {
Color::Rgb(241, 224, 90) // JavaScript yellow
} else if path.ends_with(".py") {
Color::Rgb(55, 118, 171) // Python blue
} else if path.ends_with(".go") {
Color::Rgb(0, 173, 216) // Go cyan
} else if path.ends_with(".sh") || path.ends_with(".bash") {
Color::Rgb(137, 224, 81) // Shell green
} else {
Color::White
};
Span::styled(path.to_string(), Style::default().fg(color))
}
/// Format a tool name with appropriate styling
pub fn format_tool_name(name: &str) -> Span<'static> {
let style = Style::default()
.fg(Color::Yellow)
.add_modifier(Modifier::BOLD);
Span::styled(name.to_string(), style)
}
/// Format an error message
pub fn format_error(message: &str) -> Line<'static> {
Line::from(vec![
Span::styled("Error: ", Style::default().fg(Color::Red).add_modifier(Modifier::BOLD)),
Span::styled(message.to_string(), Style::default().fg(Color::Red)),
])
}
/// Format a success message
pub fn format_success(message: &str) -> Line<'static> {
Line::from(vec![
Span::styled("", Style::default().fg(Color::Green)),
Span::styled(message.to_string(), Style::default().fg(Color::Green)),
])
}
/// Format a warning message
pub fn format_warning(message: &str) -> Line<'static> {
Line::from(vec![
Span::styled("", Style::default().fg(Color::Yellow)),
Span::styled(message.to_string(), Style::default().fg(Color::Yellow)),
])
}
/// Format an info message
pub fn format_info(message: &str) -> Line<'static> {
Line::from(vec![
Span::styled(" ", Style::default().fg(Color::Blue)),
Span::styled(message.to_string(), Style::default().fg(Color::Blue)),
])
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_syntax_highlighter_creation() {
let highlighter = SyntaxHighlighter::new();
let lines = highlighter.highlight_code("fn main() {}", "rust");
assert!(!lines.is_empty());
}
#[test]
fn test_markdown_render_heading() {
let renderer = MarkdownRenderer::new();
let content = renderer.render("# Hello World");
assert!(!content.is_empty());
}
#[test]
fn test_markdown_render_code_block() {
let renderer = MarkdownRenderer::new();
let content = renderer.render("```rust\nfn main() {}\n```");
assert!(content.len() >= 3); // Opening fence, code, closing fence
}
#[test]
fn test_markdown_render_list() {
let renderer = MarkdownRenderer::new();
let content = renderer.render("- Item 1\n- Item 2\n- Item 3");
assert!(content.len() >= 3);
}
#[test]
fn test_diff_rendering() {
let renderer = MarkdownRenderer::new();
let diff = "+added line\n-removed line\n unchanged";
let content = renderer.render_diff(diff);
assert_eq!(content.len(), 3);
}
#[test]
fn test_format_file_path() {
let span = format_file_path("src/main.rs");
assert!(span.content.contains("main.rs"));
}
#[test]
fn test_format_messages() {
let error = format_error("Something went wrong");
assert!(!error.spans.is_empty());
let success = format_success("Operation completed");
assert!(!success.spans.is_empty());
let warning = format_warning("Be careful");
assert!(!warning.spans.is_empty());
let info = format_info("FYI");
assert!(!info.spans.is_empty());
}
}
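A minimal sketch of how the chat view might consume this module, assuming a ratatui `Paragraph` is the drawing widget (the wrapping widget is an assumption, not part of this diff):

```rust
use ratatui::widgets::Paragraph;

// Illustrative helper: turn an assistant reply into a drawable widget.
// `MarkdownRenderer` and `FormattedContent` come from formatting.rs above.
fn reply_to_paragraph(renderer: &MarkdownRenderer, markdown: &str) -> Paragraph<'static> {
    // Headings, lists, inline code, and fenced blocks become styled Lines;
    // fenced blocks are syntax-highlighted via the embedded SyntaxHighlighter.
    let content = renderer.render(markdown);
    Paragraph::new(content.lines)
}
```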

crates/app/ui/src/layout.rs Normal file

@@ -0,0 +1,218 @@
//! Layout calculation for the borderless TUI
//!
//! Uses vertical layout with whitespace for visual hierarchy instead of borders:
//! - Header row (app name, mode, model, help)
//! - Provider tabs (reserved; collapsed in the current simplified layout)
//! - Horizontal divider
//! - Chat area (scrollable)
//! - Horizontal divider
//! - Input area
//! - Status bar
use ratatui::layout::{Constraint, Direction, Layout, Rect};
/// Calculated layout areas for the borderless TUI
#[derive(Debug, Clone, Copy)]
pub struct AppLayout {
/// Header row: app name, mode indicator, model, help hint
pub header_area: Rect,
/// Provider tabs row
pub tabs_area: Rect,
/// Top divider (horizontal rule)
pub top_divider: Rect,
/// Main chat/message area
pub chat_area: Rect,
/// Todo panel area (optional, between chat and input)
pub todo_area: Rect,
/// Bottom divider (horizontal rule)
pub bottom_divider: Rect,
/// Input area for user text
pub input_area: Rect,
/// Status bar at the bottom
pub status_area: Rect,
}
impl AppLayout {
/// Calculate layout for the given terminal size
pub fn calculate(area: Rect) -> Self {
Self::calculate_with_todo(area, 0)
}
/// Calculate layout with todo panel of specified height
///
/// Simplified layout without provider tabs:
/// - Header (1 line)
/// - Top divider (1 line)
/// - Chat area (flexible)
/// - Todo panel (optional)
/// - Bottom divider (1 line)
/// - Input (1 line)
/// - Status bar (1 line)
pub fn calculate_with_todo(area: Rect, todo_height: u16) -> Self {
let chunks = if todo_height > 0 {
Layout::default()
.direction(Direction::Vertical)
.constraints([
Constraint::Length(1), // Header
Constraint::Length(1), // Top divider
Constraint::Min(5), // Chat area (flexible)
Constraint::Length(todo_height), // Todo panel
Constraint::Length(1), // Bottom divider
Constraint::Length(1), // Input
Constraint::Length(1), // Status bar
])
.split(area)
} else {
Layout::default()
.direction(Direction::Vertical)
.constraints([
Constraint::Length(1), // Header
Constraint::Length(1), // Top divider
Constraint::Min(5), // Chat area (flexible)
Constraint::Length(0), // No todo panel
Constraint::Length(1), // Bottom divider
Constraint::Length(1), // Input
Constraint::Length(1), // Status bar
])
.split(area)
};
Self {
header_area: chunks[0],
tabs_area: Rect::default(), // Not used in simplified layout
top_divider: chunks[1],
chat_area: chunks[2],
todo_area: chunks[3],
bottom_divider: chunks[4],
input_area: chunks[5],
status_area: chunks[6],
}
}
/// Calculate layout with expanded input (multiline)
pub fn calculate_expanded_input(area: Rect, input_lines: u16) -> Self {
let input_height = input_lines.clamp(1, 10); // Clamp to 1..=10 lines
let chunks = Layout::default()
.direction(Direction::Vertical)
.constraints([
Constraint::Length(1), // Header
Constraint::Length(1), // Top divider
Constraint::Min(5), // Chat area (flexible)
Constraint::Length(0), // No todo panel
Constraint::Length(1), // Bottom divider
Constraint::Length(input_height), // Expanded input
Constraint::Length(1), // Status bar
])
.split(area);
Self {
header_area: chunks[0],
tabs_area: Rect::default(),
top_divider: chunks[1],
chat_area: chunks[2],
todo_area: chunks[3],
bottom_divider: chunks[4],
input_area: chunks[5],
status_area: chunks[6],
}
}
/// Calculate layout without tabs (compact mode)
pub fn calculate_compact(area: Rect) -> Self {
let chunks = Layout::default()
.direction(Direction::Vertical)
.constraints([
Constraint::Length(1), // Header (includes compact provider indicator)
Constraint::Length(1), // Top divider
Constraint::Min(5), // Chat area (flexible)
Constraint::Length(0), // No todo panel
Constraint::Length(1), // Bottom divider
Constraint::Length(1), // Input
Constraint::Length(1), // Status bar
])
.split(area);
Self {
header_area: chunks[0],
tabs_area: Rect::default(), // No tabs area in compact mode
top_divider: chunks[1],
chat_area: chunks[2],
todo_area: chunks[3],
bottom_divider: chunks[4],
input_area: chunks[5],
status_area: chunks[6],
}
}
/// Center a popup in the given area
pub fn center_popup(area: Rect, width: u16, height: u16) -> Rect {
let popup_layout = Layout::default()
.direction(Direction::Vertical)
.constraints([
Constraint::Length((area.height.saturating_sub(height)) / 2),
Constraint::Length(height),
Constraint::Length((area.height.saturating_sub(height)) / 2),
])
.split(area);
Layout::default()
.direction(Direction::Horizontal)
.constraints([
Constraint::Length((area.width.saturating_sub(width)) / 2),
Constraint::Length(width),
Constraint::Length((area.width.saturating_sub(width)) / 2),
])
.split(popup_layout[1])[1]
}
}
/// Layout mode based on terminal width
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum LayoutMode {
/// Full layout with provider tabs (>= 80 cols)
Full,
/// Compact layout without tabs (< 80 cols)
Compact,
}
impl LayoutMode {
/// Determine layout mode based on terminal width
pub fn for_width(width: u16) -> Self {
if width >= 80 {
LayoutMode::Full
} else {
LayoutMode::Compact
}
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_layout_calculation() {
let area = Rect::new(0, 0, 120, 40);
let layout = AppLayout::calculate(area);
// Header should be at top
assert_eq!(layout.header_area.y, 0);
assert_eq!(layout.header_area.height, 1);
// Status should be at bottom
assert_eq!(layout.status_area.y, 39);
assert_eq!(layout.status_area.height, 1);
// Chat area should have most of the space
assert!(layout.chat_area.height > 20);
}
#[test]
fn test_layout_mode() {
assert_eq!(LayoutMode::for_width(80), LayoutMode::Full);
assert_eq!(LayoutMode::for_width(120), LayoutMode::Full);
assert_eq!(LayoutMode::for_width(79), LayoutMode::Compact);
assert_eq!(LayoutMode::for_width(60), LayoutMode::Compact);
}
}
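A sketch of how a draw pass might choose between the layout variants, assuming the todo-panel height is computed elsewhere:

```rust
use ratatui::layout::Rect;

// Illustrative only: pick a layout for the current frame based on width.
fn layout_for_frame(area: Rect, todo_height: u16) -> AppLayout {
    match LayoutMode::for_width(area.width) {
        // Wide terminals keep the full layout (with an optional todo panel).
        LayoutMode::Full => AppLayout::calculate_with_todo(area, todo_height),
        // Narrow terminals drop to the compact variant.
        LayoutMode::Compact => AppLayout::calculate_compact(area),
    }
}
```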

crates/app/ui/src/lib.rs Normal file

@@ -0,0 +1,30 @@
pub mod app;
pub mod completions;
pub mod components;
pub mod events;
pub mod formatting;
pub mod layout;
pub mod output;
pub mod theme;
pub use app::TuiApp;
pub use completions::{CompletionEngine, Completion, CommandInfo};
pub use events::AppEvent;
pub use output::{CommandOutput, OutputFormat, TreeNode, ListItem};
pub use formatting::{
FormattedContent, MarkdownRenderer, SyntaxHighlighter,
format_file_path, format_tool_name, format_error, format_success, format_warning, format_info,
};
use color_eyre::eyre::Result;
/// Run the TUI application
pub async fn run(
client: llm_ollama::OllamaClient,
opts: llm_core::ChatOptions,
perms: permissions::PermissionManager,
settings: config_agent::Settings,
) -> Result<()> {
let mut app = TuiApp::new(client, opts, perms, settings)?;
app.run().await
}

crates/app/ui/src/output.rs Normal file

@@ -0,0 +1,388 @@
//! Rich command output formatting
//!
//! Provides formatted output for commands like /help, /mcp, /hooks
//! with tables, trees, and syntax highlighting.
use ratatui::text::{Line, Span};
use ratatui::style::{Color, Modifier, Style};
use crate::completions::CommandInfo;
use crate::theme::Theme;
/// A tree node for hierarchical display
#[derive(Debug, Clone)]
pub struct TreeNode {
pub label: String,
pub children: Vec<TreeNode>,
}
impl TreeNode {
pub fn new(label: impl Into<String>) -> Self {
Self {
label: label.into(),
children: vec![],
}
}
pub fn with_children(mut self, children: Vec<TreeNode>) -> Self {
self.children = children;
self
}
}
/// A list item with optional icon/marker
#[derive(Debug, Clone)]
pub struct ListItem {
pub text: String,
pub marker: Option<String>,
pub style: Option<Style>,
}
/// Different output formats
#[derive(Debug, Clone)]
pub enum OutputFormat {
/// Formatted table with headers and rows
Table {
headers: Vec<String>,
rows: Vec<Vec<String>>,
},
/// Hierarchical tree view
Tree {
root: TreeNode,
},
/// Syntax-highlighted code block
Code {
language: String,
content: String,
},
/// Side-by-side diff view
Diff {
old: String,
new: String,
},
/// Simple list with markers
List {
items: Vec<ListItem>,
},
/// Plain text
Text {
content: String,
},
}
/// Rich command output renderer
pub struct CommandOutput {
pub format: OutputFormat,
}
impl CommandOutput {
pub fn new(format: OutputFormat) -> Self {
Self { format }
}
/// Create a help table output
pub fn help_table(commands: &[CommandInfo]) -> Self {
let headers = vec![
"Command".to_string(),
"Description".to_string(),
"Source".to_string(),
];
let rows: Vec<Vec<String>> = commands
.iter()
.map(|c| vec![
format!("/{}", c.name),
c.description.clone(),
c.source.clone(),
])
.collect();
Self {
format: OutputFormat::Table { headers, rows },
}
}
/// Create an MCP servers tree view
pub fn mcp_tree(servers: &[(String, Vec<String>)]) -> Self {
let children: Vec<TreeNode> = servers
.iter()
.map(|(name, tools)| {
TreeNode {
label: name.clone(),
children: tools.iter().map(|t| TreeNode::new(t)).collect(),
}
})
.collect();
Self {
format: OutputFormat::Tree {
root: TreeNode {
label: "MCP Servers".to_string(),
children,
},
},
}
}
/// Create a hooks list output
pub fn hooks_list(hooks: &[(String, String, bool)]) -> Self {
let items: Vec<ListItem> = hooks
.iter()
.map(|(event, path, enabled)| {
let marker = if *enabled { "" } else { "" };
let style = if *enabled {
Some(Style::default().fg(Color::Green))
} else {
Some(Style::default().fg(Color::Red))
};
ListItem {
text: format!("{}: {}", event, path),
marker: Some(marker.to_string()),
style,
}
})
.collect();
Self {
format: OutputFormat::List { items },
}
}
/// Render to TUI Lines
pub fn render(&self, theme: &Theme) -> Vec<Line<'static>> {
match &self.format {
OutputFormat::Table { headers, rows } => {
self.render_table(headers, rows, theme)
}
OutputFormat::Tree { root } => {
self.render_tree(root, 0, theme)
}
OutputFormat::List { items } => {
self.render_list(items, theme)
}
OutputFormat::Code { content, .. } => {
content.lines()
.map(|line| Line::from(Span::styled(line.to_string(), theme.tool_call)))
.collect()
}
OutputFormat::Diff { old, new } => {
self.render_diff(old, new, theme)
}
OutputFormat::Text { content } => {
content.lines()
.map(|line| Line::from(line.to_string()))
.collect()
}
}
}
fn render_table(&self, headers: &[String], rows: &[Vec<String>], theme: &Theme) -> Vec<Line<'static>> {
let mut lines = Vec::new();
// Calculate column widths
let mut widths: Vec<usize> = headers.iter().map(|h| h.len()).collect();
for row in rows {
for (i, cell) in row.iter().enumerate() {
if i < widths.len() {
widths[i] = widths[i].max(cell.len());
}
}
}
// Header line
let header_spans: Vec<Span> = headers
.iter()
.enumerate()
.flat_map(|(i, h)| {
let padded = format!("{:width$}", h, width = widths.get(i).copied().unwrap_or(h.len()));
vec![
Span::styled(padded, Style::default().add_modifier(Modifier::BOLD)),
Span::raw(" "),
]
})
.collect();
lines.push(Line::from(header_spans));
// Separator
let sep: String = widths.iter().map(|w| "─".repeat(*w)).collect::<Vec<_>>().join("──");
lines.push(Line::from(Span::styled(sep, theme.status_dim)));
// Rows
for row in rows {
let row_spans: Vec<Span> = row
.iter()
.enumerate()
.flat_map(|(i, cell)| {
let padded = format!("{:width$}", cell, width = widths.get(i).copied().unwrap_or(cell.len()));
let style = if i == 0 {
theme.status_accent // Command names in accent color
} else {
theme.status_bar
};
vec![
Span::styled(padded, style),
Span::raw(" "),
]
})
.collect();
lines.push(Line::from(row_spans));
}
lines
}
fn render_tree(&self, node: &TreeNode, depth: usize, theme: &Theme) -> Vec<Line<'static>> {
let mut lines = Vec::new();
// Render current node
let prefix = if depth == 0 {
"".to_string()
} else {
format!("{}├─ ", "".repeat(depth - 1))
};
let style = if depth == 0 {
Style::default().add_modifier(Modifier::BOLD)
} else if node.children.is_empty() {
theme.status_bar
} else {
theme.status_accent
};
lines.push(Line::from(vec![
Span::styled(prefix, theme.status_dim),
Span::styled(node.label.clone(), style),
]));
// Render children
for child in &node.children {
lines.extend(self.render_tree(child, depth + 1, theme));
}
lines
}
fn render_list(&self, items: &[ListItem], theme: &Theme) -> Vec<Line<'static>> {
items
.iter()
.map(|item| {
let marker_span = if let Some(marker) = &item.marker {
Span::styled(
format!("{} ", marker),
item.style.unwrap_or(theme.status_bar),
)
} else {
Span::raw("")
};
Line::from(vec![
marker_span,
Span::styled(
item.text.clone(),
item.style.unwrap_or(theme.status_bar),
),
])
})
.collect()
}
fn render_diff(&self, old: &str, new: &str, _theme: &Theme) -> Vec<Line<'static>> {
let mut lines = Vec::new();
// Simple line-by-line diff
let old_lines: Vec<&str> = old.lines().collect();
let new_lines: Vec<&str> = new.lines().collect();
let max_len = old_lines.len().max(new_lines.len());
for i in 0..max_len {
let old_line = old_lines.get(i).copied().unwrap_or("");
let new_line = new_lines.get(i).copied().unwrap_or("");
if old_line != new_line {
if !old_line.is_empty() {
lines.push(Line::from(Span::styled(
format!("- {}", old_line),
Style::default().fg(Color::Red),
)));
}
if !new_line.is_empty() {
lines.push(Line::from(Span::styled(
format!("+ {}", new_line),
Style::default().fg(Color::Green),
)));
}
} else {
lines.push(Line::from(format!(" {}", old_line)));
}
}
lines
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_help_table() {
let commands = vec![
CommandInfo::new("help", "Show help", "builtin"),
CommandInfo::new("clear", "Clear screen", "builtin"),
];
let output = CommandOutput::help_table(&commands);
match output.format {
OutputFormat::Table { headers, rows } => {
assert_eq!(headers.len(), 3);
assert_eq!(rows.len(), 2);
}
_ => panic!("Expected Table format"),
}
}
#[test]
fn test_mcp_tree() {
let servers = vec![
("filesystem".to_string(), vec!["read".to_string(), "write".to_string()]),
("database".to_string(), vec!["query".to_string()]),
];
let output = CommandOutput::mcp_tree(&servers);
match output.format {
OutputFormat::Tree { root } => {
assert_eq!(root.label, "MCP Servers");
assert_eq!(root.children.len(), 2);
}
_ => panic!("Expected Tree format"),
}
}
#[test]
fn test_hooks_list() {
let hooks = vec![
("PreToolUse".to_string(), "./hooks/pre".to_string(), true),
("PostToolUse".to_string(), "./hooks/post".to_string(), false),
];
let output = CommandOutput::hooks_list(&hooks);
match output.format {
OutputFormat::List { items } => {
assert_eq!(items.len(), 2);
}
_ => panic!("Expected List format"),
}
}
#[test]
fn test_tree_node() {
let node = TreeNode::new("root")
.with_children(vec![
TreeNode::new("child1"),
TreeNode::new("child2"),
]);
assert_eq!(node.label, "root");
assert_eq!(node.children.len(), 2);
}
}
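A sketch of how a slash-command handler could use these builders; the command entries are made up for illustration, and `CommandInfo::new` is the constructor used in the tests above:

```rust
use ratatui::text::Line;

// Illustrative only: render a /help table into chat lines for the TUI.
fn render_help(theme: &Theme) -> Vec<Line<'static>> {
    let commands = vec![
        CommandInfo::new("help", "Show available commands", "builtin"),
        CommandInfo::new("theme", "Switch the color theme", "builtin"),
    ];
    CommandOutput::help_table(&commands).render(theme)
}
```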

crates/app/ui/src/theme.rs Normal file

@@ -0,0 +1,707 @@
//! Theme system for the borderless TUI design
//!
//! Provides color palettes, semantic styling, and terminal capability detection
//! for graceful degradation across different terminal emulators.
use ratatui::style::{Color, Modifier, Style};
/// Terminal capability detection for graceful degradation
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum TerminalCapability {
/// Full Unicode support with true color
Full,
/// Basic Unicode with 256 colors
Unicode256,
/// ASCII only with 16 colors
Basic,
}
impl TerminalCapability {
/// Detect terminal capabilities from environment
pub fn detect() -> Self {
// Check for true color support
let colorterm = std::env::var("COLORTERM").unwrap_or_default();
let term = std::env::var("TERM").unwrap_or_default();
if colorterm == "truecolor" || colorterm == "24bit" {
return Self::Full;
}
if term.contains("256color") || term.contains("kitty") || term.contains("alacritty") {
return Self::Unicode256;
}
// Check if we're in a linux VT or basic terminal
if term == "linux" || term == "vt100" || term == "dumb" {
return Self::Basic;
}
// Default to unicode with 256 colors
Self::Unicode256
}
/// Check if Unicode box drawing is supported
pub fn supports_unicode(&self) -> bool {
matches!(self, Self::Full | Self::Unicode256)
}
/// Check if true color (RGB) is supported
pub fn supports_truecolor(&self) -> bool {
matches!(self, Self::Full)
}
}
/// Symbols with fallbacks for different terminal capabilities
#[derive(Debug, Clone)]
pub struct Symbols {
pub horizontal_rule: &'static str,
pub vertical_separator: &'static str,
pub bullet: &'static str,
pub arrow: &'static str,
pub check: &'static str,
pub cross: &'static str,
pub warning: &'static str,
pub info: &'static str,
pub streaming: &'static str,
pub user_prefix: &'static str,
pub assistant_prefix: &'static str,
pub tool_prefix: &'static str,
pub system_prefix: &'static str,
// Provider icons
pub claude_icon: &'static str,
pub ollama_icon: &'static str,
pub openai_icon: &'static str,
// Vim mode indicators
pub mode_normal: &'static str,
pub mode_insert: &'static str,
pub mode_visual: &'static str,
pub mode_command: &'static str,
}
impl Symbols {
/// Unicode symbols for capable terminals
pub fn unicode() -> Self {
Self {
horizontal_rule: "",
vertical_separator: "",
bullet: "",
arrow: "",
check: "",
cross: "",
warning: "",
info: "",
streaming: "",
user_prefix: "",
assistant_prefix: "",
tool_prefix: "",
system_prefix: "",
claude_icon: "󰚩",
ollama_icon: "󰫢",
openai_icon: "󰊤",
mode_normal: "[N]",
mode_insert: "[I]",
mode_visual: "[V]",
mode_command: "[:]",
}
}
/// ASCII fallback symbols
pub fn ascii() -> Self {
Self {
horizontal_rule: "-",
vertical_separator: "|",
bullet: "*",
arrow: "->",
check: "+",
cross: "x",
warning: "!",
info: "i",
streaming: "*",
user_prefix: ">",
assistant_prefix: "-",
tool_prefix: "#",
system_prefix: "-",
claude_icon: "C",
ollama_icon: "O",
openai_icon: "G",
mode_normal: "[N]",
mode_insert: "[I]",
mode_visual: "[V]",
mode_command: "[:]",
}
}
/// Select symbols based on terminal capability
pub fn for_capability(cap: TerminalCapability) -> Self {
match cap {
TerminalCapability::Full | TerminalCapability::Unicode256 => Self::unicode(),
TerminalCapability::Basic => Self::ascii(),
}
}
}
/// Modern color palette inspired by contemporary design systems
///
/// Color assignment principles:
/// - fg (#c0caf5): PRIMARY text - user messages, command names
/// - assistant (#a9b1d6): Soft gray-blue for AI responses (distinct from user)
/// - accent (#7aa2f7): Interactive elements ONLY (mode, prompt symbol)
/// - cmd_slash (#bb9af7): Purple for / prefix (signals "command")
/// - fg_dim (#737aa2): Timestamps, hints, inactive elements
/// - selection (#283457): Highlighted row background
#[derive(Debug, Clone)]
pub struct ColorPalette {
pub primary: Color,
pub secondary: Color,
pub accent: Color,
pub success: Color,
pub warning: Color,
pub error: Color,
pub info: Color,
pub bg: Color,
pub fg: Color,
pub fg_dim: Color,
pub fg_muted: Color,
pub highlight: Color,
pub border: Color, // For horizontal rules (subtle)
pub selection: Color, // Highlighted row background
// Provider-specific colors
pub claude: Color,
pub ollama: Color,
pub openai: Color,
// Semantic colors for messages
pub user_fg: Color, // User message text (bright, fg)
pub assistant_fg: Color, // Assistant message text (soft gray-blue)
pub tool_fg: Color,
pub timestamp_fg: Color,
pub divider_fg: Color,
// Command colors
pub cmd_slash: Color, // Purple for / prefix
pub cmd_name: Color, // Command name (same as fg)
pub cmd_desc: Color, // Command description (dim)
// Overlay/modal colors
pub overlay_bg: Color, // Slightly lighter than main bg
}
impl ColorPalette {
/// Tokyo Night inspired palette - high contrast, readable
///
/// Key principles:
/// - fg (#c0caf5) for user messages and command names
/// - assistant (#a9b1d6) brighter gray-blue for AI responses (readable)
/// - accent (#7aa2f7) only for interactive elements (mode indicator, prompt symbol)
/// - cmd_slash (#bb9af7) purple for / prefix (signals "command")
/// - fg_dim (#737aa2) for timestamps, hints, descriptions (brighter than before)
/// - border (#3b4261) for horizontal rules
pub fn tokyo_night() -> Self {
Self {
primary: Color::Rgb(122, 162, 247), // #7aa2f7 - Blue accent
secondary: Color::Rgb(187, 154, 247), // #bb9af7 - Purple
accent: Color::Rgb(122, 162, 247), // #7aa2f7 - Interactive elements ONLY
success: Color::Rgb(158, 206, 106), // #9ece6a - Green
warning: Color::Rgb(224, 175, 104), // #e0af68 - Yellow
error: Color::Rgb(247, 118, 142), // #f7768e - Pink/Red
info: Color::Rgb(125, 207, 255), // Cyan (rarely used)
bg: Color::Rgb(26, 27, 38), // #1a1b26 - Dark bg
fg: Color::Rgb(192, 202, 245), // #c0caf5 - Primary text (HIGH CONTRAST)
fg_dim: Color::Rgb(115, 122, 162), // #737aa2 - Secondary text (BRIGHTER)
fg_muted: Color::Rgb(86, 95, 137), // #565f89 - Very dim
highlight: Color::Rgb(56, 62, 90), // Selection bg (legacy)
border: Color::Rgb(73, 82, 115), // #495273 - Horizontal rules (BRIGHTER)
selection: Color::Rgb(40, 52, 87), // #283457 - Highlighted row bg
// Provider colors
claude: Color::Rgb(217, 119, 87), // Claude orange
ollama: Color::Rgb(122, 162, 247), // Blue
openai: Color::Rgb(16, 163, 127), // OpenAI green
// Message colors - user bright, assistant readable
user_fg: Color::Rgb(192, 202, 245), // #c0caf5 - Same as fg (bright)
assistant_fg: Color::Rgb(169, 177, 214), // #a9b1d6 - Brighter gray-blue (READABLE)
tool_fg: Color::Rgb(224, 175, 104), // #e0af68 - Yellow for tools
timestamp_fg: Color::Rgb(115, 122, 162), // #737aa2 - Brighter dim
divider_fg: Color::Rgb(73, 82, 115), // #495273 - Border color (BRIGHTER)
// Command colors
cmd_slash: Color::Rgb(187, 154, 247), // #bb9af7 - Purple for / prefix
cmd_name: Color::Rgb(192, 202, 245), // #c0caf5 - White for command name
cmd_desc: Color::Rgb(115, 122, 162), // #737aa2 - Brighter description
// Overlay colors
overlay_bg: Color::Rgb(36, 40, 59), // #24283b - Slightly lighter than bg
}
}
/// Dracula inspired palette - classic and elegant
pub fn dracula() -> Self {
Self {
primary: Color::Rgb(139, 233, 253), // Cyan
secondary: Color::Rgb(189, 147, 249), // Purple
accent: Color::Rgb(255, 121, 198), // Pink
success: Color::Rgb(80, 250, 123), // Green
warning: Color::Rgb(241, 250, 140), // Yellow
error: Color::Rgb(255, 85, 85), // Red
info: Color::Rgb(139, 233, 253), // Cyan
bg: Color::Rgb(40, 42, 54), // Dark bg
fg: Color::Rgb(248, 248, 242), // Light text
fg_dim: Color::Rgb(98, 114, 164), // Comment
fg_muted: Color::Rgb(68, 71, 90), // Very dim
highlight: Color::Rgb(68, 71, 90), // Selection
border: Color::Rgb(68, 71, 90),
selection: Color::Rgb(68, 71, 90),
claude: Color::Rgb(255, 121, 198),
ollama: Color::Rgb(139, 233, 253),
openai: Color::Rgb(80, 250, 123),
user_fg: Color::Rgb(248, 248, 242),
assistant_fg: Color::Rgb(189, 186, 220), // Softer purple-gray
tool_fg: Color::Rgb(241, 250, 140),
timestamp_fg: Color::Rgb(68, 71, 90),
divider_fg: Color::Rgb(68, 71, 90),
cmd_slash: Color::Rgb(189, 147, 249), // Purple
cmd_name: Color::Rgb(248, 248, 242),
cmd_desc: Color::Rgb(98, 114, 164),
overlay_bg: Color::Rgb(50, 52, 64),
}
}
/// Catppuccin Mocha - warm and cozy
pub fn catppuccin() -> Self {
Self {
primary: Color::Rgb(137, 180, 250), // Blue
secondary: Color::Rgb(203, 166, 247), // Mauve
accent: Color::Rgb(245, 194, 231), // Pink
success: Color::Rgb(166, 227, 161), // Green
warning: Color::Rgb(249, 226, 175), // Yellow
error: Color::Rgb(243, 139, 168), // Red
info: Color::Rgb(148, 226, 213), // Teal
bg: Color::Rgb(30, 30, 46), // Base
fg: Color::Rgb(205, 214, 244), // Text
fg_dim: Color::Rgb(108, 112, 134), // Overlay
fg_muted: Color::Rgb(69, 71, 90), // Surface
highlight: Color::Rgb(49, 50, 68), // Surface
border: Color::Rgb(69, 71, 90),
selection: Color::Rgb(49, 50, 68),
claude: Color::Rgb(245, 194, 231),
ollama: Color::Rgb(137, 180, 250),
openai: Color::Rgb(166, 227, 161),
user_fg: Color::Rgb(205, 214, 244),
assistant_fg: Color::Rgb(166, 187, 213), // Softer blue-gray
tool_fg: Color::Rgb(249, 226, 175),
timestamp_fg: Color::Rgb(69, 71, 90),
divider_fg: Color::Rgb(69, 71, 90),
cmd_slash: Color::Rgb(203, 166, 247), // Mauve
cmd_name: Color::Rgb(205, 214, 244),
cmd_desc: Color::Rgb(108, 112, 134),
overlay_bg: Color::Rgb(40, 40, 56),
}
}
/// Nord - minimal and clean
pub fn nord() -> Self {
Self {
primary: Color::Rgb(136, 192, 208), // Frost cyan
secondary: Color::Rgb(129, 161, 193), // Frost blue
accent: Color::Rgb(180, 142, 173), // Aurora purple
success: Color::Rgb(163, 190, 140), // Aurora green
warning: Color::Rgb(235, 203, 139), // Aurora yellow
error: Color::Rgb(191, 97, 106), // Aurora red
info: Color::Rgb(136, 192, 208), // Frost cyan
bg: Color::Rgb(46, 52, 64), // Polar night
fg: Color::Rgb(236, 239, 244), // Snow storm
fg_dim: Color::Rgb(76, 86, 106), // Polar night light
fg_muted: Color::Rgb(59, 66, 82),
highlight: Color::Rgb(59, 66, 82), // Selection
border: Color::Rgb(59, 66, 82),
selection: Color::Rgb(59, 66, 82),
claude: Color::Rgb(180, 142, 173),
ollama: Color::Rgb(136, 192, 208),
openai: Color::Rgb(163, 190, 140),
user_fg: Color::Rgb(236, 239, 244),
assistant_fg: Color::Rgb(180, 195, 210), // Softer blue-gray
tool_fg: Color::Rgb(235, 203, 139),
timestamp_fg: Color::Rgb(59, 66, 82),
divider_fg: Color::Rgb(59, 66, 82),
cmd_slash: Color::Rgb(180, 142, 173), // Aurora purple
cmd_name: Color::Rgb(236, 239, 244),
cmd_desc: Color::Rgb(76, 86, 106),
overlay_bg: Color::Rgb(56, 62, 74),
}
}
/// Synthwave - vibrant and retro
pub fn synthwave() -> Self {
Self {
primary: Color::Rgb(255, 0, 128), // Hot pink
secondary: Color::Rgb(0, 229, 255), // Cyan
accent: Color::Rgb(255, 128, 0), // Orange
success: Color::Rgb(0, 255, 157), // Neon green
warning: Color::Rgb(255, 215, 0), // Gold
error: Color::Rgb(255, 64, 64), // Neon red
info: Color::Rgb(0, 229, 255), // Cyan
bg: Color::Rgb(20, 16, 32), // Dark purple
fg: Color::Rgb(242, 233, 255), // Light purple
fg_dim: Color::Rgb(127, 90, 180), // Mid purple
fg_muted: Color::Rgb(72, 12, 168),
highlight: Color::Rgb(72, 12, 168), // Deep purple
border: Color::Rgb(72, 12, 168),
selection: Color::Rgb(72, 12, 168),
claude: Color::Rgb(255, 128, 0),
ollama: Color::Rgb(0, 229, 255),
openai: Color::Rgb(0, 255, 157),
user_fg: Color::Rgb(242, 233, 255),
assistant_fg: Color::Rgb(180, 170, 220), // Softer purple
tool_fg: Color::Rgb(255, 215, 0),
timestamp_fg: Color::Rgb(72, 12, 168),
divider_fg: Color::Rgb(72, 12, 168),
cmd_slash: Color::Rgb(255, 0, 128), // Hot pink
cmd_name: Color::Rgb(242, 233, 255),
cmd_desc: Color::Rgb(127, 90, 180),
overlay_bg: Color::Rgb(30, 26, 42),
}
}
/// Rose Pine - elegant and muted
pub fn rose_pine() -> Self {
Self {
primary: Color::Rgb(156, 207, 216), // Foam
secondary: Color::Rgb(235, 188, 186), // Rose
accent: Color::Rgb(234, 154, 151), // Love
success: Color::Rgb(49, 116, 143), // Pine
warning: Color::Rgb(246, 193, 119), // Gold
error: Color::Rgb(235, 111, 146), // Love (darker)
info: Color::Rgb(156, 207, 216), // Foam
bg: Color::Rgb(25, 23, 36), // Base
fg: Color::Rgb(224, 222, 244), // Text
fg_dim: Color::Rgb(110, 106, 134), // Muted
fg_muted: Color::Rgb(42, 39, 63),
highlight: Color::Rgb(42, 39, 63), // Highlight
border: Color::Rgb(42, 39, 63),
selection: Color::Rgb(42, 39, 63),
claude: Color::Rgb(234, 154, 151),
ollama: Color::Rgb(156, 207, 216),
openai: Color::Rgb(49, 116, 143),
user_fg: Color::Rgb(224, 222, 244),
assistant_fg: Color::Rgb(180, 185, 210), // Softer lavender-gray
tool_fg: Color::Rgb(246, 193, 119),
timestamp_fg: Color::Rgb(42, 39, 63),
divider_fg: Color::Rgb(42, 39, 63),
cmd_slash: Color::Rgb(235, 188, 186), // Rose
cmd_name: Color::Rgb(224, 222, 244),
cmd_desc: Color::Rgb(110, 106, 134),
overlay_bg: Color::Rgb(35, 33, 46),
}
}
/// Midnight Ocean - deep and serene
pub fn midnight_ocean() -> Self {
Self {
primary: Color::Rgb(102, 217, 239), // Bright cyan
secondary: Color::Rgb(130, 170, 255), // Periwinkle
accent: Color::Rgb(199, 146, 234), // Purple
success: Color::Rgb(163, 190, 140), // Sea green
warning: Color::Rgb(229, 200, 144), // Sandy yellow
error: Color::Rgb(236, 95, 103), // Coral red
info: Color::Rgb(102, 217, 239), // Bright cyan
bg: Color::Rgb(1, 22, 39), // Deep ocean
fg: Color::Rgb(201, 211, 235), // Light blue-white
fg_dim: Color::Rgb(71, 103, 145), // Muted blue
fg_muted: Color::Rgb(13, 43, 69),
highlight: Color::Rgb(13, 43, 69), // Deep blue
border: Color::Rgb(13, 43, 69),
selection: Color::Rgb(13, 43, 69),
claude: Color::Rgb(199, 146, 234),
ollama: Color::Rgb(102, 217, 239),
openai: Color::Rgb(163, 190, 140),
user_fg: Color::Rgb(201, 211, 235),
assistant_fg: Color::Rgb(150, 175, 200), // Softer blue-gray
tool_fg: Color::Rgb(229, 200, 144),
timestamp_fg: Color::Rgb(13, 43, 69),
divider_fg: Color::Rgb(13, 43, 69),
cmd_slash: Color::Rgb(199, 146, 234), // Purple
cmd_name: Color::Rgb(201, 211, 235),
cmd_desc: Color::Rgb(71, 103, 145),
overlay_bg: Color::Rgb(11, 32, 49),
}
}
}
/// LLM Provider enum
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Provider {
Claude,
Ollama,
OpenAI,
}
impl Provider {
pub fn name(&self) -> &'static str {
match self {
Provider::Claude => "Claude",
Provider::Ollama => "Ollama",
Provider::OpenAI => "OpenAI",
}
}
pub fn all() -> &'static [Provider] {
&[Provider::Claude, Provider::Ollama, Provider::OpenAI]
}
}
/// Vim-like editing mode
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub enum VimMode {
#[default]
Normal,
Insert,
Visual,
Command,
}
impl VimMode {
pub fn indicator(&self, symbols: &Symbols) -> &'static str {
match self {
VimMode::Normal => symbols.mode_normal,
VimMode::Insert => symbols.mode_insert,
VimMode::Visual => symbols.mode_visual,
VimMode::Command => symbols.mode_command,
}
}
}
/// Theme configuration for the borderless TUI
#[derive(Debug, Clone)]
pub struct Theme {
pub palette: ColorPalette,
pub symbols: Symbols,
pub capability: TerminalCapability,
// Message styles
pub user_message: Style,
pub assistant_message: Style,
pub tool_call: Style,
pub tool_result_success: Style,
pub tool_result_error: Style,
pub system_message: Style,
pub timestamp: Style,
// UI element styles
pub divider: Style,
pub header: Style,
pub header_accent: Style,
pub tab_active: Style,
pub tab_inactive: Style,
pub input_prefix: Style,
pub input_text: Style,
pub input_placeholder: Style,
pub status_bar: Style,
pub status_accent: Style,
pub status_dim: Style,
// Command styles
pub cmd_slash: Style, // Purple for / prefix
pub cmd_name: Style, // White for command name
pub cmd_desc: Style, // Dim for description
// Overlay/modal styles
pub overlay_bg: Style, // Modal background
pub selection_bg: Style, // Selected row background
// Popup styles (for permission dialogs)
pub popup_border: Style,
pub popup_bg: Style,
pub popup_title: Style,
pub selected: Style,
// Legacy compatibility
pub border: Style,
pub border_active: Style,
pub status_bar_highlight: Style,
pub input_box: Style,
pub input_box_active: Style,
}
impl Theme {
/// Create theme from color palette with automatic capability detection
pub fn from_palette(palette: ColorPalette) -> Self {
let capability = TerminalCapability::detect();
Self::from_palette_with_capability(palette, capability)
}
/// Create theme with specific terminal capability
pub fn from_palette_with_capability(palette: ColorPalette, capability: TerminalCapability) -> Self {
let symbols = Symbols::for_capability(capability);
Self {
// Message styles
user_message: Style::default()
.fg(palette.user_fg)
.add_modifier(Modifier::BOLD),
assistant_message: Style::default().fg(palette.assistant_fg),
tool_call: Style::default()
.fg(palette.tool_fg)
.add_modifier(Modifier::ITALIC),
tool_result_success: Style::default()
.fg(palette.success)
.add_modifier(Modifier::BOLD),
tool_result_error: Style::default()
.fg(palette.error)
.add_modifier(Modifier::BOLD),
system_message: Style::default().fg(palette.fg_dim),
timestamp: Style::default().fg(palette.timestamp_fg),
// UI elements
divider: Style::default().fg(palette.divider_fg),
header: Style::default()
.fg(palette.fg)
.add_modifier(Modifier::BOLD),
header_accent: Style::default()
.fg(palette.accent)
.add_modifier(Modifier::BOLD),
tab_active: Style::default()
.fg(palette.primary)
.add_modifier(Modifier::BOLD | Modifier::UNDERLINED),
tab_inactive: Style::default().fg(palette.fg_dim),
input_prefix: Style::default()
.fg(palette.accent)
.add_modifier(Modifier::BOLD),
input_text: Style::default().fg(palette.fg),
input_placeholder: Style::default().fg(palette.fg_muted),
status_bar: Style::default().fg(palette.fg_dim),
status_accent: Style::default().fg(palette.accent),
status_dim: Style::default().fg(palette.fg_muted),
// Command styles
cmd_slash: Style::default().fg(palette.cmd_slash),
cmd_name: Style::default().fg(palette.cmd_name),
cmd_desc: Style::default().fg(palette.cmd_desc),
// Overlay/modal styles
overlay_bg: Style::default().bg(palette.overlay_bg),
selection_bg: Style::default().bg(palette.selection),
// Popup styles
popup_border: Style::default()
.fg(palette.border)
.add_modifier(Modifier::BOLD),
popup_bg: Style::default().bg(palette.overlay_bg),
popup_title: Style::default()
.fg(palette.fg)
.add_modifier(Modifier::BOLD),
selected: Style::default()
.fg(palette.fg)
.bg(palette.selection)
.add_modifier(Modifier::BOLD),
// Legacy compatibility
border: Style::default().fg(palette.fg_dim),
border_active: Style::default()
.fg(palette.primary)
.add_modifier(Modifier::BOLD),
status_bar_highlight: Style::default()
.fg(palette.bg)
.bg(palette.accent)
.add_modifier(Modifier::BOLD),
input_box: Style::default().fg(palette.fg),
input_box_active: Style::default()
.fg(palette.accent)
.add_modifier(Modifier::BOLD),
symbols,
capability,
palette,
}
}
/// Get provider-specific color
pub fn provider_color(&self, provider: Provider) -> Color {
match provider {
Provider::Claude => self.palette.claude,
Provider::Ollama => self.palette.ollama,
Provider::OpenAI => self.palette.openai,
}
}
/// Get provider icon
pub fn provider_icon(&self, provider: Provider) -> &str {
match provider {
Provider::Claude => self.symbols.claude_icon,
Provider::Ollama => self.symbols.ollama_icon,
Provider::OpenAI => self.symbols.openai_icon,
}
}
/// Create a horizontal rule string of given width
pub fn horizontal_rule(&self, width: usize) -> String {
self.symbols.horizontal_rule.repeat(width)
}
/// Tokyo Night theme (default) - modern and vibrant
pub fn tokyo_night() -> Self {
Self::from_palette(ColorPalette::tokyo_night())
}
/// Dracula theme - classic dark theme
pub fn dracula() -> Self {
Self::from_palette(ColorPalette::dracula())
}
/// Catppuccin Mocha - warm and cozy
pub fn catppuccin() -> Self {
Self::from_palette(ColorPalette::catppuccin())
}
/// Nord theme - minimal and clean
pub fn nord() -> Self {
Self::from_palette(ColorPalette::nord())
}
/// Synthwave theme - vibrant retro
pub fn synthwave() -> Self {
Self::from_palette(ColorPalette::synthwave())
}
/// Rose Pine theme - elegant and muted
pub fn rose_pine() -> Self {
Self::from_palette(ColorPalette::rose_pine())
}
/// Midnight Ocean theme - deep and serene
pub fn midnight_ocean() -> Self {
Self::from_palette(ColorPalette::midnight_ocean())
}
}
impl Default for Theme {
fn default() -> Self {
Self::tokyo_night()
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_terminal_capability_detection() {
let cap = TerminalCapability::detect();
// Should return some valid capability
assert!(matches!(
cap,
TerminalCapability::Full | TerminalCapability::Unicode256 | TerminalCapability::Basic
));
}
#[test]
fn test_symbols_for_capability() {
let unicode = Symbols::for_capability(TerminalCapability::Full);
assert_eq!(unicode.horizontal_rule, "");
let ascii = Symbols::for_capability(TerminalCapability::Basic);
assert_eq!(ascii.horizontal_rule, "-");
}
#[test]
fn test_theme_from_palette() {
let theme = Theme::tokyo_night();
// The constructed theme should expose symbols for the detected capability.
assert_eq!(theme.symbols.mode_normal, "[N]");
}
#[test]
fn test_provider_colors() {
let theme = Theme::tokyo_night();
let claude_color = theme.provider_color(Provider::Claude);
let ollama_color = theme.provider_color(Provider::Ollama);
assert_ne!(claude_color, ollama_color);
}
#[test]
fn test_vim_mode_indicator() {
let symbols = Symbols::unicode();
assert_eq!(VimMode::Normal.indicator(&symbols), "[N]");
assert_eq!(VimMode::Insert.indicator(&symbols), "[I]");
}
}
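A sketch of combining the theme's provider color, icon, and capability-aware symbols into a status-bar span (the badge layout itself is an assumption):

```rust
use ratatui::style::Style;
use ratatui::text::Span;

// Illustrative only: a colored "<icon> <name>" badge for the active provider.
// The icon degrades to ASCII automatically on Basic terminals.
fn provider_badge(theme: &Theme, provider: Provider) -> Span<'static> {
    let label = format!("{} {}", theme.provider_icon(provider), provider.name());
    Span::styled(label, Style::default().fg(theme.provider_color(provider)))
}
```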


@@ -0,0 +1,29 @@
[package]
name = "agent-core"
version = "0.1.0"
edition.workspace = true
license.workspace = true
rust-version.workspace = true
[dependencies]
serde = { version = "1", features = ["derive"] }
serde_json = "1"
color-eyre = "0.6"
tokio = { version = "1", features = ["full"] }
futures-util = "0.3"
tracing = "0.1"
async-trait = "0.1"
chrono = "0.4"
# Internal dependencies
llm-core = { path = "../../llm/core" }
permissions = { path = "../../platform/permissions" }
tools-fs = { path = "../../tools/fs" }
tools-bash = { path = "../../tools/bash" }
tools-ask = { path = "../../tools/ask" }
tools-todo = { path = "../../tools/todo" }
tools-web = { path = "../../tools/web" }
tools-plan = { path = "../../tools/plan" }
[dev-dependencies]
tempfile = "3.13"


@@ -0,0 +1,74 @@
//! Example demonstrating the git integration module
//!
//! Run with: cargo run -p agent-core --example git_demo
use agent_core::{detect_git_state, format_git_status, is_safe_git_command, is_destructive_git_command};
use std::env;
fn main() -> color_eyre::Result<()> {
color_eyre::install()?;
// Get current working directory
let cwd = env::current_dir()?;
println!("Detecting git state in: {}\n", cwd.display());
// Detect git state
let state = detect_git_state(&cwd)?;
// Display formatted status
println!("{}\n", format_git_status(&state));
// Show detailed file status if there are changes
if !state.status.is_empty() {
println!("Detailed file status:");
for status in &state.status {
match status {
agent_core::GitFileStatus::Modified { path } => {
println!(" M {}", path);
}
agent_core::GitFileStatus::Added { path } => {
println!(" A {}", path);
}
agent_core::GitFileStatus::Deleted { path } => {
println!(" D {}", path);
}
agent_core::GitFileStatus::Renamed { from, to } => {
println!(" R {} -> {}", from, to);
}
agent_core::GitFileStatus::Untracked { path } => {
println!(" ? {}", path);
}
}
}
println!();
}
// Test command safety checking
println!("Command safety checks:");
let test_commands = vec![
"git status",
"git log --oneline",
"git diff HEAD",
"git commit -m 'test'",
"git push --force origin main",
"git reset --hard HEAD~1",
"git rebase main",
"git branch -D feature",
];
for cmd in test_commands {
let is_safe = is_safe_git_command(cmd);
let (is_destructive, warning) = is_destructive_git_command(cmd);
print!(" {} - ", cmd);
if is_safe {
println!("SAFE (read-only)");
} else if is_destructive {
println!("DESTRUCTIVE: {}", warning);
} else {
println!("UNSAFE (modifies state)");
}
}
Ok(())
}


@@ -0,0 +1,92 @@
/// Example demonstrating the streaming agent loop API
///
/// This example shows how to use `run_agent_loop_streaming` to receive
/// real-time events during agent execution, including:
/// - Text deltas as the LLM generates text
/// - Tool execution start/end events
/// - Tool output events
/// - Final completion events
///
/// Run with: cargo run --example streaming_agent -p agent-core
use agent_core::{create_event_channel, run_agent_loop_streaming, AgentEvent, ToolContext};
use llm_core::ChatOptions;
use permissions::{Mode, PermissionManager};
#[tokio::main]
async fn main() -> color_eyre::Result<()> {
color_eyre::install()?;
// Note: This is a minimal example. In a real application, you would:
// 1. Initialize a real LLM provider (e.g., OllamaClient)
// 2. Configure the ChatOptions with your preferred model
// 3. Set up appropriate permissions and tool context
println!("=== Streaming Agent Example ===\n");
println!("This example demonstrates how to use the streaming agent loop API.");
println!("To run with a real LLM provider, modify this example to:");
println!(" 1. Create an LLM provider instance");
println!(" 2. Set up permissions and tool context");
println!(" 3. Call run_agent_loop_streaming with your prompt\n");
// Example code structure:
println!("Example code:");
println!("```rust");
println!("// Create LLM provider");
println!("let provider = OllamaClient::new(\"http://localhost:11434\");");
println!();
println!("// Set up permissions and context");
println!("let perms = PermissionManager::new(Mode::Plan);");
println!("let ctx = ToolContext::default();");
println!();
println!("// Create event channel");
println!("let (tx, mut rx) = create_event_channel();");
println!();
println!("// Spawn agent loop");
println!("let handle = tokio::spawn(async move {{");
println!(" run_agent_loop_streaming(");
println!(" &provider,");
println!(" \"Your prompt here\",");
println!(" &ChatOptions::default(),");
println!(" &perms,");
println!(" &ctx,");
println!(" tx,");
println!(" ).await");
println!("}});");
println!();
println!("// Process events");
println!("while let Some(event) = rx.recv().await {{");
println!(" match event {{");
println!(" AgentEvent::TextDelta(text) => {{");
println!(" print!(\"{{text}}\");");
println!(" }}");
println!(" AgentEvent::ToolStart {{ tool_name, .. }} => {{");
println!(" println!(\"\\n[Executing tool: {{tool_name}}]\");");
println!(" }}");
println!(" AgentEvent::ToolOutput {{ content, is_error, .. }} => {{");
println!(" if is_error {{");
println!(" eprintln!(\"Error: {{content}}\");");
println!(" }} else {{");
println!(" println!(\"Output: {{content}}\");");
println!(" }}");
println!(" }}");
println!(" AgentEvent::ToolEnd {{ success, .. }} => {{");
println!(" println!(\"[Tool finished: {{}}]\", if success {{ \"success\" }} else {{ \"failed\" }});");
println!(" }}");
println!(" AgentEvent::Done {{ final_response }} => {{");
println!(" println!(\"\\n\\nFinal response: {{final_response}}\");");
println!(" break;");
println!(" }}");
println!(" AgentEvent::Error(e) => {{");
println!(" eprintln!(\"Error: {{e}}\");");
println!(" break;");
println!(" }}");
println!(" }}");
println!("}}");
println!();
println!("// Wait for completion");
println!("let result = handle.await??;");
println!("```");
Ok(())
}


@@ -0,0 +1,218 @@
//! Context compaction for long conversations
//!
//! When the conversation context grows too large, this module compacts
//! earlier messages into a summary while preserving recent context.
use color_eyre::eyre::Result;
use llm_core::{ChatMessage, ChatOptions, LlmProvider};
/// Token limit threshold for triggering compaction
const CONTEXT_LIMIT: usize = 180_000;
/// Threshold ratio at which to trigger compaction (90% of limit)
const COMPACTION_THRESHOLD: f64 = 0.9;
/// Number of recent messages to preserve during compaction
const PRESERVE_RECENT: usize = 10;
/// Token counter for estimating context size
pub struct TokenCounter {
chars_per_token: f64,
}
impl Default for TokenCounter {
fn default() -> Self {
Self::new()
}
}
impl TokenCounter {
pub fn new() -> Self {
// Rough estimate: ~4 chars per token for English text
Self { chars_per_token: 4.0 }
}
/// Estimate token count for a message
pub fn count_message(&self, message: &ChatMessage) -> usize {
let content_len = message.content.as_ref().map(|c| c.len()).unwrap_or(0);
// Add overhead for role, metadata
let overhead = 10;
((content_len as f64 / self.chars_per_token) as usize) + overhead
}
/// Estimate total token count for all messages
pub fn count_messages(&self, messages: &[ChatMessage]) -> usize {
messages.iter().map(|m| self.count_message(m)).sum()
}
/// Check if context should be compacted
pub fn should_compact(&self, messages: &[ChatMessage]) -> bool {
let count = self.count_messages(messages);
count > (CONTEXT_LIMIT as f64 * COMPACTION_THRESHOLD) as usize
}
}
/// Context compactor that summarizes conversation history
pub struct Compactor {
token_counter: TokenCounter,
}
impl Default for Compactor {
fn default() -> Self {
Self::new()
}
}
impl Compactor {
pub fn new() -> Self {
Self {
token_counter: TokenCounter::new(),
}
}
/// Check if messages need compaction
pub fn needs_compaction(&self, messages: &[ChatMessage]) -> bool {
self.token_counter.should_compact(messages)
}
/// Compact messages by summarizing earlier conversation
///
/// Returns compacted messages with:
/// - A system message containing the summary of earlier context
/// - The most recent N messages preserved in full
pub async fn compact<P: LlmProvider>(
&self,
provider: &P,
messages: &[ChatMessage],
options: &ChatOptions,
) -> Result<Vec<ChatMessage>> {
// If not enough messages to compact, return as-is
if messages.len() <= PRESERVE_RECENT + 1 {
return Ok(messages.to_vec());
}
// Split into messages to summarize and messages to preserve
let split_point = messages.len().saturating_sub(PRESERVE_RECENT);
let to_summarize = &messages[..split_point];
let to_preserve = &messages[split_point..];
// Generate summary of earlier messages
let summary = self.summarize_messages(provider, to_summarize, options).await?;
// Build compacted message list
let mut compacted = Vec::with_capacity(PRESERVE_RECENT + 1);
// Add system message with summary
compacted.push(ChatMessage::system(format!(
"## Earlier Conversation Summary\n\n{}\n\n---\n\n\
The above summarizes the earlier part of this conversation. \
Continue from the recent messages below.",
summary
)));
// Add preserved recent messages
compacted.extend(to_preserve.iter().cloned());
Ok(compacted)
}
/// Generate a summary of messages using the LLM
async fn summarize_messages<P: LlmProvider>(
&self,
provider: &P,
messages: &[ChatMessage],
options: &ChatOptions,
) -> Result<String> {
// Format messages for summarization
let mut context = String::new();
for msg in messages {
let role = &msg.role;
let content = msg.content.as_deref().unwrap_or("");
context.push_str(&format!("[{:?}]: {}\n\n", role, content));
}
// Create summarization prompt
let summary_prompt = format!(
"Please provide a concise summary of the following conversation. \
Focus on:\n\
1. Key decisions made\n\
2. Important files or code mentioned\n\
3. Tasks completed and their outcomes\n\
4. Any pending items or next steps discussed\n\n\
Keep the summary informative but brief (under 500 words).\n\n\
Conversation:\n{}\n\n\
Summary:",
context
);
// Call LLM to generate summary
let summary_options = ChatOptions {
model: options.model.clone(),
max_tokens: Some(1000),
temperature: Some(0.3), // Lower temperature for more focused summary
..Default::default()
};
let summary_messages = vec![ChatMessage::user(&summary_prompt)];
let mut stream = provider.chat_stream(&summary_messages, &summary_options, None).await?;
let mut summary = String::new();
use futures_util::StreamExt;
while let Some(chunk_result) = stream.next().await {
if let Ok(chunk) = chunk_result {
if let Some(content) = &chunk.content {
summary.push_str(content);
}
}
}
Ok(summary.trim().to_string())
}
/// Get token counter for external use
pub fn token_counter(&self) -> &TokenCounter {
&self.token_counter
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_token_counter_estimate() {
let counter = TokenCounter::new();
let msg = ChatMessage::user("Hello, world!");
let count = counter.count_message(&msg);
// Should be approximately 13/4 + 10 overhead = 13
assert!(count > 10);
assert!(count < 20);
}
#[test]
fn test_should_compact() {
let counter = TokenCounter::new();
// Small message list shouldn't compact
let small_messages: Vec<ChatMessage> = (0..10)
.map(|i| ChatMessage::user(&format!("Message {}", i)))
.collect();
assert!(!counter.should_compact(&small_messages));
// Large message list should compact
// Need ~162,000 tokens = ~648,000 chars (at 4 chars per token)
let large_content = "x".repeat(700_000);
let large_messages = vec![ChatMessage::user(&large_content)];
assert!(counter.should_compact(&large_messages));
}
#[test]
fn test_compactor_needs_compaction() {
let compactor = Compactor::new();
let small: Vec<ChatMessage> = (0..5)
.map(|i| ChatMessage::user(&format!("Short message {}", i)))
.collect();
assert!(!compactor.needs_compaction(&small));
}
}
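A sketch of wiring the compactor into an agent loop before each LLM call; the surrounding loop, provider, and options are assumed to exist elsewhere:

```rust
use color_eyre::eyre::Result;
use llm_core::{ChatMessage, ChatOptions, LlmProvider};

// Illustrative only: summarize older history once the estimated token count
// crosses the 90% threshold, otherwise pass the messages through unchanged.
async fn maybe_compact<P: LlmProvider>(
    compactor: &Compactor,
    provider: &P,
    messages: Vec<ChatMessage>,
    options: &ChatOptions,
) -> Result<Vec<ChatMessage>> {
    if compactor.needs_compaction(&messages) {
        compactor.compact(provider, &messages, options).await
    } else {
        Ok(messages)
    }
}
```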


@@ -0,0 +1,557 @@
//! Git integration module for detecting repository state and validating git commands.
//!
//! This module provides functionality to:
//! - Detect if the current directory is a git repository
//! - Capture git repository state (branch, status, uncommitted changes)
//! - Validate git commands for safety (read-only vs destructive operations)
use color_eyre::eyre::Result;
use std::path::Path;
use std::process::Command;
/// Status of a file in the git working tree
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum GitFileStatus {
/// File has been modified
Modified { path: String },
/// File has been added (staged)
Added { path: String },
/// File has been deleted
Deleted { path: String },
/// File has been renamed
Renamed { from: String, to: String },
/// File is untracked
Untracked { path: String },
}
impl GitFileStatus {
/// Get the primary path associated with this status
pub fn path(&self) -> &str {
match self {
Self::Modified { path } => path,
Self::Added { path } => path,
Self::Deleted { path } => path,
Self::Renamed { to, .. } => to,
Self::Untracked { path } => path,
}
}
}
/// Complete state of a git repository
#[derive(Debug, Clone)]
pub struct GitState {
/// Whether the current directory is in a git repository
pub is_git_repo: bool,
/// Current branch name (None if not in a repo or detached HEAD)
pub current_branch: Option<String>,
/// Main branch name (main/master, None if not detected)
pub main_branch: Option<String>,
/// Status of files in the working tree
pub status: Vec<GitFileStatus>,
/// Whether there are any uncommitted changes
pub has_uncommitted_changes: bool,
/// Remote URL for the repository (None if no remote configured)
pub remote_url: Option<String>,
}
impl GitState {
/// Create a default GitState for non-git directories
pub fn not_a_repo() -> Self {
Self {
is_git_repo: false,
current_branch: None,
main_branch: None,
status: Vec::new(),
has_uncommitted_changes: false,
remote_url: None,
}
}
}
/// Detect the current git repository state
///
/// This function runs various git commands to gather information about the repository.
/// If git is not available or the directory is not a git repo, returns a default state.
pub fn detect_git_state(working_dir: &Path) -> Result<GitState> {
// Check if this is a git repository
let is_repo = Command::new("git")
.arg("rev-parse")
.arg("--git-dir")
.current_dir(working_dir)
.output()
.map(|output| output.status.success())
.unwrap_or(false);
if !is_repo {
return Ok(GitState::not_a_repo());
}
// Get current branch
let current_branch = get_current_branch(working_dir)?;
// Detect main branch (try main first, then master)
let main_branch = detect_main_branch(working_dir)?;
// Get file status
let status = get_git_status(working_dir)?;
// Check if there are uncommitted changes
let has_uncommitted_changes = !status.is_empty();
// Get remote URL
let remote_url = get_remote_url(working_dir)?;
Ok(GitState {
is_git_repo: true,
current_branch,
main_branch,
status,
has_uncommitted_changes,
remote_url,
})
}
/// Get the current branch name
fn get_current_branch(working_dir: &Path) -> Result<Option<String>> {
let output = Command::new("git")
.arg("rev-parse")
.arg("--abbrev-ref")
.arg("HEAD")
.current_dir(working_dir)
.output()?;
if !output.status.success() {
return Ok(None);
}
let branch = String::from_utf8_lossy(&output.stdout).trim().to_string();
// "HEAD" means detached HEAD state
if branch == "HEAD" {
Ok(None)
} else {
Ok(Some(branch))
}
}
/// Detect the main branch (main or master)
fn detect_main_branch(working_dir: &Path) -> Result<Option<String>> {
// Try to get all branches
let output = Command::new("git")
.arg("branch")
.arg("-a")
.current_dir(working_dir)
.output()?;
if !output.status.success() {
return Ok(None);
}
let branches = String::from_utf8_lossy(&output.stdout);
// Check for main branch first (modern convention)
if branches.lines().any(|line| {
let trimmed = line.trim_start_matches('*').trim();
trimmed == "main" || trimmed.ends_with("/main")
}) {
return Ok(Some("main".to_string()));
}
// Fall back to master
if branches.lines().any(|line| {
let trimmed = line.trim_start_matches('*').trim();
trimmed == "master" || trimmed.ends_with("/master")
}) {
return Ok(Some("master".to_string()));
}
Ok(None)
}
/// Get the git status for all files
fn get_git_status(working_dir: &Path) -> Result<Vec<GitFileStatus>> {
let output = Command::new("git")
.arg("status")
.arg("--porcelain")
.arg("-z") // Null-terminated for better parsing
.current_dir(working_dir)
.output()?;
if !output.status.success() {
return Ok(Vec::new());
}
let status_text = String::from_utf8_lossy(&output.stdout);
let mut statuses = Vec::new();
// Parse porcelain -z output: each entry is "XY path" (X = staged status, Y = unstaged status).
// For renames, -z emits the new path in the entry and the original path as the
// following NUL-terminated field.
let mut entries = status_text.split('\0').filter(|s| !s.is_empty());
while let Some(entry) = entries.next() {
if entry.len() < 3 {
continue;
}
let status_code = &entry[0..2];
let path = entry[3..].to_string();
// Parse status codes
match status_code {
"M " | " M" | "MM" => {
statuses.push(GitFileStatus::Modified { path });
}
"A " | " A" | "AM" => {
statuses.push(GitFileStatus::Added { path });
}
"D " | " D" | "AD" => {
statuses.push(GitFileStatus::Deleted { path });
}
"??" => {
statuses.push(GitFileStatus::Untracked { path });
}
s if s.starts_with('R') => {
// With -z, the original path arrives as the next field after the new path
if let Some(from) = entries.next() {
statuses.push(GitFileStatus::Renamed {
from: from.to_string(),
to: path,
});
} else {
// Fallback if the original path is missing
statuses.push(GitFileStatus::Modified { path });
}
}
_ => {
// Unknown status code, treat as modified
statuses.push(GitFileStatus::Modified { path });
}
}
}
Ok(statuses)
}
/// Get the remote URL for the repository
fn get_remote_url(working_dir: &Path) -> Result<Option<String>> {
let output = Command::new("git")
.arg("remote")
.arg("get-url")
.arg("origin")
.current_dir(working_dir)
.output()?;
if !output.status.success() {
return Ok(None);
}
let url = String::from_utf8_lossy(&output.stdout).trim().to_string();
if url.is_empty() {
Ok(None)
} else {
Ok(Some(url))
}
}
/// Check if a git command is safe (read-only)
///
/// Safe commands include:
/// - status, log, show, diff, blame, reflog, grep
/// - branch (listing only, without -d/-D/-m)
/// - remote (get-url/-v/show only)
/// - config (--get/--list only)
/// - rev-parse, rev-list, ls-files, ls-tree, ls-remote
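///
/// For example (mirroring the unit tests below):
///
/// ```ignore
/// assert!(is_safe_git_command("git log --oneline"));
/// assert!(!is_safe_git_command("git push origin main"));
/// ```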
pub fn is_safe_git_command(command: &str) -> bool {
let parts: Vec<&str> = command.split_whitespace().collect();
if parts.is_empty() || parts[0] != "git" {
return false;
}
if parts.len() < 2 {
return false;
}
let subcommand = parts[1];
// List of read-only git commands
match subcommand {
"status" | "log" | "show" | "diff" | "blame" | "reflog" => true,
"ls-files" | "ls-tree" | "ls-remote" => true,
"rev-parse" | "rev-list" => true,
"describe" | "tag" if !command.contains("-d") && !command.contains("--delete") => true,
"branch" if !command.contains("-D") && !command.contains("-d") && !command.contains("-m") => true,
"remote" if command.contains("get-url") || command.contains("-v") || command.contains("show") => true,
"config" if command.contains("--get") || command.contains("--list") => true,
"grep" | "shortlog" | "whatchanged" => true,
"fetch" if !command.contains("--prune") => true,
_ => false,
}
}
/// Check if a git command is destructive
///
/// Returns (is_destructive, warning_message) tuple.
/// Destructive commands include:
/// - push --force, reset --hard, clean -fd
/// - rebase, amend, filter-branch
/// - branch -D, tag -d
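///
/// For example (mirroring the unit tests below):
///
/// ```ignore
/// let (destructive, warning) = is_destructive_git_command("git reset --hard HEAD~1");
/// assert!(destructive && warning.contains("Hard reset"));
/// ```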
pub fn is_destructive_git_command(command: &str) -> (bool, &'static str) {
let cmd_lower = command.to_lowercase();
// Check for force push
if cmd_lower.contains("push") && (cmd_lower.contains("--force") || cmd_lower.contains("-f")) {
return (true, "Force push can overwrite remote history and affect other collaborators");
}
// Check for hard reset
if cmd_lower.contains("reset") && cmd_lower.contains("--hard") {
return (true, "Hard reset will discard uncommitted changes permanently");
}
// Check for git clean
if cmd_lower.contains("clean") && (cmd_lower.contains("-f") || cmd_lower.contains("-d")) {
return (true, "Git clean will permanently delete untracked files");
}
// Check for rebase
if cmd_lower.contains("rebase") {
return (true, "Rebase rewrites commit history and can cause conflicts");
}
// Check for amend
if cmd_lower.contains("commit") && cmd_lower.contains("--amend") {
return (true, "Amending rewrites the last commit and changes its hash");
}
// Check for filter-branch or filter-repo
if cmd_lower.contains("filter-branch") || cmd_lower.contains("filter-repo") {
return (true, "Filter operations rewrite repository history");
}
// Check for branch/tag deletion (the command is lowercased, so "-d" also matches "-D")
if (cmd_lower.contains("branch") && cmd_lower.contains("-d"))
|| (cmd_lower.contains("tag") && (cmd_lower.contains("-d") || cmd_lower.contains("--delete")))
{
return (true, "This will delete a branch or tag");
}
// Check for reflog expire
if cmd_lower.contains("reflog") && cmd_lower.contains("expire") {
return (true, "Expiring reflog removes recovery points for lost commits");
}
// Check for gc with aggressive or prune
if cmd_lower.contains("gc") && (cmd_lower.contains("--aggressive") || cmd_lower.contains("--prune")) {
return (true, "Aggressive garbage collection can make recovery difficult");
}
(false, "")
}
/// Format git state for human-readable display
///
/// Example output:
/// ```text
/// Git Repository: yes
/// Current branch: feature-branch
/// Main branch: main
/// Status: 3 modified, 1 untracked
/// Remote: https://github.com/user/repo.git
/// ```
pub fn format_git_status(state: &GitState) -> String {
if !state.is_git_repo {
return "Not a git repository".to_string();
}
let mut lines = Vec::new();
lines.push("Git Repository: yes".to_string());
if let Some(branch) = &state.current_branch {
lines.push(format!("Current branch: {}", branch));
} else {
lines.push("Current branch: (detached HEAD)".to_string());
}
if let Some(main) = &state.main_branch {
lines.push(format!("Main branch: {}", main));
}
// Summarize status
if state.status.is_empty() {
lines.push("Status: clean working tree".to_string());
} else {
let mut modified = 0;
let mut added = 0;
let mut deleted = 0;
let mut renamed = 0;
let mut untracked = 0;
for status in &state.status {
match status {
GitFileStatus::Modified { .. } => modified += 1,
GitFileStatus::Added { .. } => added += 1,
GitFileStatus::Deleted { .. } => deleted += 1,
GitFileStatus::Renamed { .. } => renamed += 1,
GitFileStatus::Untracked { .. } => untracked += 1,
}
}
let mut status_parts = Vec::new();
if modified > 0 {
status_parts.push(format!("{} modified", modified));
}
if added > 0 {
status_parts.push(format!("{} added", added));
}
if deleted > 0 {
status_parts.push(format!("{} deleted", deleted));
}
if renamed > 0 {
status_parts.push(format!("{} renamed", renamed));
}
if untracked > 0 {
status_parts.push(format!("{} untracked", untracked));
}
lines.push(format!("Status: {}", status_parts.join(", ")));
}
if let Some(url) = &state.remote_url {
lines.push(format!("Remote: {}", url));
} else {
lines.push("Remote: (none)".to_string());
}
lines.join("\n")
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_is_safe_git_command() {
// Safe commands
assert!(is_safe_git_command("git status"));
assert!(is_safe_git_command("git log --oneline"));
assert!(is_safe_git_command("git diff HEAD"));
assert!(is_safe_git_command("git branch -v"));
assert!(is_safe_git_command("git remote -v"));
assert!(is_safe_git_command("git config --get user.name"));
// Unsafe commands
assert!(!is_safe_git_command("git commit -m test"));
assert!(!is_safe_git_command("git push origin main"));
assert!(!is_safe_git_command("git branch -D feature"));
assert!(!is_safe_git_command("git remote add origin url"));
}
#[test]
fn test_is_destructive_git_command() {
// Destructive commands
let (is_dest, msg) = is_destructive_git_command("git push --force origin main");
assert!(is_dest);
assert!(msg.contains("Force push"));
let (is_dest, msg) = is_destructive_git_command("git reset --hard HEAD~1");
assert!(is_dest);
assert!(msg.contains("Hard reset"));
let (is_dest, msg) = is_destructive_git_command("git clean -fd");
assert!(is_dest);
assert!(msg.contains("clean"));
let (is_dest, msg) = is_destructive_git_command("git rebase main");
assert!(is_dest);
assert!(msg.contains("Rebase"));
let (is_dest, msg) = is_destructive_git_command("git commit --amend");
assert!(is_dest);
assert!(msg.contains("Amending"));
// Non-destructive commands
let (is_dest, _) = is_destructive_git_command("git status");
assert!(!is_dest);
let (is_dest, _) = is_destructive_git_command("git log");
assert!(!is_dest);
let (is_dest, _) = is_destructive_git_command("git diff");
assert!(!is_dest);
}
#[test]
fn test_git_state_not_a_repo() {
let state = GitState::not_a_repo();
assert!(!state.is_git_repo);
assert!(state.current_branch.is_none());
assert!(state.main_branch.is_none());
assert!(state.status.is_empty());
assert!(!state.has_uncommitted_changes);
assert!(state.remote_url.is_none());
}
#[test]
fn test_git_file_status_path() {
let status = GitFileStatus::Modified {
path: "test.rs".to_string(),
};
assert_eq!(status.path(), "test.rs");
let status = GitFileStatus::Renamed {
from: "old.rs".to_string(),
to: "new.rs".to_string(),
};
assert_eq!(status.path(), "new.rs");
}
#[test]
fn test_format_git_status_not_repo() {
let state = GitState::not_a_repo();
let formatted = format_git_status(&state);
assert_eq!(formatted, "Not a git repository");
}
#[test]
fn test_format_git_status_clean() {
let state = GitState {
is_git_repo: true,
current_branch: Some("main".to_string()),
main_branch: Some("main".to_string()),
status: Vec::new(),
has_uncommitted_changes: false,
remote_url: Some("https://github.com/user/repo.git".to_string()),
};
let formatted = format_git_status(&state);
assert!(formatted.contains("Git Repository: yes"));
assert!(formatted.contains("Current branch: main"));
assert!(formatted.contains("clean working tree"));
}
#[test]
fn test_format_git_status_with_changes() {
let state = GitState {
is_git_repo: true,
current_branch: Some("feature".to_string()),
main_branch: Some("main".to_string()),
status: vec![
GitFileStatus::Modified {
path: "file1.rs".to_string(),
},
GitFileStatus::Modified {
path: "file2.rs".to_string(),
},
GitFileStatus::Untracked {
path: "new.rs".to_string(),
},
],
has_uncommitted_changes: true,
remote_url: None,
};
let formatted = format_git_status(&state);
assert!(formatted.contains("2 modified"));
assert!(formatted.contains("1 untracked"));
}
}

crates/core/agent/src/lib.rs: 1130 lines (diff suppressed because it is too large)

View File

@@ -0,0 +1,295 @@
use color_eyre::eyre::{Result, eyre};
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::fs;
use std::path::{Path, PathBuf};
use std::time::{Duration, SystemTime};
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SessionStats {
pub start_time: SystemTime,
pub total_messages: usize,
pub total_tool_calls: usize,
pub total_duration: Duration,
pub estimated_tokens: usize,
}
impl SessionStats {
pub fn new() -> Self {
Self {
start_time: SystemTime::now(),
total_messages: 0,
total_tool_calls: 0,
total_duration: Duration::ZERO,
estimated_tokens: 0,
}
}
pub fn record_message(&mut self, tokens: usize, duration: Duration) {
self.total_messages += 1;
self.estimated_tokens += tokens;
self.total_duration += duration;
}
pub fn record_tool_call(&mut self) {
self.total_tool_calls += 1;
}
pub fn format_duration(d: Duration) -> String {
let secs = d.as_secs();
if secs < 60 {
format!("{}s", secs)
} else if secs < 3600 {
format!("{}m {}s", secs / 60, secs % 60)
} else {
format!("{}h {}m", secs / 3600, (secs % 3600) / 60)
}
}
}
impl Default for SessionStats {
fn default() -> Self {
Self::new()
}
}
#[derive(Debug, Clone)]
pub struct SessionHistory {
pub user_prompts: Vec<String>,
pub assistant_responses: Vec<String>,
pub tool_calls: Vec<ToolCallRecord>,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ToolCallRecord {
pub tool_name: String,
pub arguments: String,
pub result: String,
pub success: bool,
}
impl SessionHistory {
pub fn new() -> Self {
Self {
user_prompts: Vec::new(),
assistant_responses: Vec::new(),
tool_calls: Vec::new(),
}
}
pub fn add_user_message(&mut self, message: String) {
self.user_prompts.push(message);
}
pub fn add_assistant_message(&mut self, message: String) {
self.assistant_responses.push(message);
}
pub fn add_tool_call(&mut self, record: ToolCallRecord) {
self.tool_calls.push(record);
}
pub fn clear(&mut self) {
self.user_prompts.clear();
self.assistant_responses.clear();
self.tool_calls.clear();
}
}
impl Default for SessionHistory {
fn default() -> Self {
Self::new()
}
}
/// Represents a file modification with before/after content
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct FileDiff {
pub path: PathBuf,
pub before: String,
pub after: String,
pub timestamp: SystemTime,
}
impl FileDiff {
/// Create a new file diff
pub fn new(path: PathBuf, before: String, after: String) -> Self {
Self {
path,
before,
after,
timestamp: SystemTime::now(),
}
}
}
/// A checkpoint captures the state of a session at a point in time
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Checkpoint {
pub id: String,
pub timestamp: SystemTime,
pub stats: SessionStats,
pub user_prompts: Vec<String>,
pub assistant_responses: Vec<String>,
pub tool_calls: Vec<ToolCallRecord>,
pub file_diffs: Vec<FileDiff>,
}
impl Checkpoint {
/// Create a new checkpoint from current session state
pub fn new(
id: String,
stats: SessionStats,
history: &SessionHistory,
file_diffs: Vec<FileDiff>,
) -> Self {
Self {
id,
timestamp: SystemTime::now(),
stats,
user_prompts: history.user_prompts.clone(),
assistant_responses: history.assistant_responses.clone(),
tool_calls: history.tool_calls.clone(),
file_diffs,
}
}
/// Save checkpoint to disk
pub fn save(&self, checkpoint_dir: &Path) -> Result<()> {
fs::create_dir_all(checkpoint_dir)?;
let path = checkpoint_dir.join(format!("{}.json", self.id));
let content = serde_json::to_string_pretty(self)?;
fs::write(path, content)?;
Ok(())
}
/// Load checkpoint from disk
pub fn load(checkpoint_dir: &Path, id: &str) -> Result<Self> {
let path = checkpoint_dir.join(format!("{}.json", id));
let content = fs::read_to_string(&path)
.map_err(|e| eyre!("Failed to read checkpoint: {}", e))?;
let checkpoint: Checkpoint = serde_json::from_str(&content)
.map_err(|e| eyre!("Failed to parse checkpoint: {}", e))?;
Ok(checkpoint)
}
/// List all available checkpoints in a directory
pub fn list(checkpoint_dir: &Path) -> Result<Vec<String>> {
if !checkpoint_dir.exists() {
return Ok(Vec::new());
}
let mut checkpoints = Vec::new();
for entry in fs::read_dir(checkpoint_dir)? {
let entry = entry?;
let path = entry.path();
if path.extension().and_then(|s| s.to_str()) == Some("json") {
if let Some(stem) = path.file_stem().and_then(|s| s.to_str()) {
checkpoints.push(stem.to_string());
}
}
}
// Sort by checkpoint ID (which includes timestamp)
checkpoints.sort();
Ok(checkpoints)
}
}
/// Session checkpoint manager
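///
/// Typical flow, sketched with an illustrative checkpoint directory and id;
/// `stats` and `history` are assumed to be in scope (see the integration tests
/// for a complete, runnable version):
///
/// ```ignore
/// let mut manager = CheckpointManager::new(PathBuf::from(".owlen/checkpoints"));
/// manager.snapshot_file(Path::new("src/main.rs"))?;   // capture state before editing
/// // ... the file is modified by a tool ...
/// let checkpoint = manager.save_checkpoint("cp-001".to_string(), stats, &history)?;
/// let restored = manager.rewind_to("cp-001")?;        // restores the snapshotted content
/// ```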
pub struct CheckpointManager {
checkpoint_dir: PathBuf,
file_snapshots: HashMap<PathBuf, String>,
}
impl CheckpointManager {
/// Create a new checkpoint manager
pub fn new(checkpoint_dir: PathBuf) -> Self {
Self {
checkpoint_dir,
file_snapshots: HashMap::new(),
}
}
/// Snapshot a file's current content before modification
pub fn snapshot_file(&mut self, path: &Path) -> Result<()> {
if !self.file_snapshots.contains_key(path) {
let content = fs::read_to_string(path).unwrap_or_default();
self.file_snapshots.insert(path.to_path_buf(), content);
}
Ok(())
}
/// Create a file diff after modification
pub fn create_diff(&self, path: &Path) -> Result<Option<FileDiff>> {
if let Some(before) = self.file_snapshots.get(path) {
let after = fs::read_to_string(path).unwrap_or_default();
if before != &after {
Ok(Some(FileDiff::new(
path.to_path_buf(),
before.clone(),
after,
)))
} else {
Ok(None)
}
} else {
Ok(None)
}
}
/// Get all file diffs since last checkpoint
pub fn get_all_diffs(&self) -> Result<Vec<FileDiff>> {
let mut diffs = Vec::new();
for (path, before) in &self.file_snapshots {
let after = fs::read_to_string(path).unwrap_or_default();
if before != &after {
diffs.push(FileDiff::new(path.clone(), before.clone(), after));
}
}
Ok(diffs)
}
/// Clear file snapshots
pub fn clear_snapshots(&mut self) {
self.file_snapshots.clear();
}
/// Save a checkpoint
pub fn save_checkpoint(
&mut self,
id: String,
stats: SessionStats,
history: &SessionHistory,
) -> Result<Checkpoint> {
let file_diffs = self.get_all_diffs()?;
let checkpoint = Checkpoint::new(id, stats, history, file_diffs);
checkpoint.save(&self.checkpoint_dir)?;
self.clear_snapshots();
Ok(checkpoint)
}
/// Load a checkpoint
pub fn load_checkpoint(&self, id: &str) -> Result<Checkpoint> {
Checkpoint::load(&self.checkpoint_dir, id)
}
/// List all checkpoints
pub fn list_checkpoints(&self) -> Result<Vec<String>> {
Checkpoint::list(&self.checkpoint_dir)
}
/// Rewind to a checkpoint by restoring file contents
pub fn rewind_to(&self, checkpoint_id: &str) -> Result<Vec<PathBuf>> {
let checkpoint = self.load_checkpoint(checkpoint_id)?;
let mut restored_files = Vec::new();
// Restore files from diffs (revert to 'before' state)
for diff in &checkpoint.file_diffs {
fs::write(&diff.path, &diff.before)?;
restored_files.push(diff.path.clone());
}
Ok(restored_files)
}
}

View File

@@ -0,0 +1,266 @@
//! System Prompt Management
//!
//! Composes system prompts from multiple sources for agent sessions.
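//!
//! A minimal sketch of composing a prompt (the tool names are illustrative):
//!
//! ```ignore
//! let prompt = SystemPromptBuilder::new()
//!     .with_base_prompt(default_base_prompt())
//!     .with_tool_instructions(generate_tool_instructions(&["read", "edit", "bash"]))
//!     .with_project_instructions(Path::new("."))
//!     .build();
//! ```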
use std::path::Path;
/// Builder for composing system prompts
#[derive(Debug, Clone, Default)]
pub struct SystemPromptBuilder {
sections: Vec<PromptSection>,
}
#[derive(Debug, Clone)]
struct PromptSection {
name: String,
content: String,
priority: i32, // Lower = earlier in prompt
}
impl SystemPromptBuilder {
pub fn new() -> Self {
Self::default()
}
/// Add the base agent prompt
pub fn with_base_prompt(mut self, content: impl Into<String>) -> Self {
self.sections.push(PromptSection {
name: "base".to_string(),
content: content.into(),
priority: 0,
});
self
}
/// Add tool usage instructions
pub fn with_tool_instructions(mut self, content: impl Into<String>) -> Self {
self.sections.push(PromptSection {
name: "tools".to_string(),
content: content.into(),
priority: 10,
});
self
}
/// Load and add project instructions from CLAUDE.md or .owlen.md
pub fn with_project_instructions(mut self, project_root: &Path) -> Self {
// Try CLAUDE.md first (Claude Code compatibility)
let claude_md = project_root.join("CLAUDE.md");
if claude_md.exists() {
if let Ok(content) = std::fs::read_to_string(&claude_md) {
self.sections.push(PromptSection {
name: "project".to_string(),
content: format!("# Project Instructions\n\n{}", content),
priority: 20,
});
return self;
}
}
// Fallback to .owlen.md
let owlen_md = project_root.join(".owlen.md");
if owlen_md.exists() {
if let Ok(content) = std::fs::read_to_string(&owlen_md) {
self.sections.push(PromptSection {
name: "project".to_string(),
content: format!("# Project Instructions\n\n{}", content),
priority: 20,
});
}
}
self
}
/// Add skill content
pub fn with_skill(mut self, skill_name: &str, content: impl Into<String>) -> Self {
self.sections.push(PromptSection {
name: format!("skill:{}", skill_name),
content: content.into(),
priority: 30,
});
self
}
/// Add hook-injected content (from SessionStart hooks)
pub fn with_hook_injection(mut self, content: impl Into<String>) -> Self {
self.sections.push(PromptSection {
name: "hook".to_string(),
content: content.into(),
priority: 40,
});
self
}
/// Add custom section
pub fn with_section(mut self, name: impl Into<String>, content: impl Into<String>, priority: i32) -> Self {
self.sections.push(PromptSection {
name: name.into(),
content: content.into(),
priority,
});
self
}
/// Build the final system prompt
pub fn build(mut self) -> String {
// Sort by priority
self.sections.sort_by_key(|s| s.priority);
// Join sections with separators
self.sections
.iter()
.map(|s| s.content.as_str())
.collect::<Vec<_>>()
.join("\n\n---\n\n")
}
/// Check if any content has been added
pub fn is_empty(&self) -> bool {
self.sections.is_empty()
}
}
/// Default base prompt for Owlen agent
pub fn default_base_prompt() -> &'static str {
r#"You are Owlen, an AI assistant that helps with software engineering tasks.
You have access to tools for reading files, writing code, running commands, and searching the web.
## Guidelines
1. Be direct and concise in your responses
2. Use tools to gather information before making changes
3. Explain your reasoning when making decisions
4. Ask for clarification when requirements are unclear
5. Prefer editing existing files over creating new ones
## Tool Usage
- Use `read` to examine file contents before editing
- Use `glob` and `grep` to find relevant files
- Use `edit` for precise changes, `write` for new files
- Use `bash` for running tests and commands
- Use `web_search` for current information"#
}
/// Generate tool instructions based on available tools
pub fn generate_tool_instructions(tool_names: &[&str]) -> String {
let mut instructions = String::from("## Available Tools\n\n");
for name in tool_names {
let desc = match *name {
"read" => "Read file contents",
"write" => "Create or overwrite a file",
"edit" => "Edit a file by replacing text",
"multi_edit" => "Apply multiple edits atomically",
"glob" => "Find files by pattern",
"grep" => "Search file contents",
"ls" => "List directory contents",
"bash" => "Execute shell commands",
"web_search" => "Search the web",
"web_fetch" => "Fetch a URL",
"todo_write" => "Update task list",
"ask_user" => "Ask user a question",
_ => continue,
};
instructions.push_str(&format!("- `{}`: {}\n", name, desc));
}
instructions
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_builder() {
let prompt = SystemPromptBuilder::new()
.with_base_prompt("You are helpful")
.with_tool_instructions("Use tools wisely")
.build();
assert!(prompt.contains("You are helpful"));
assert!(prompt.contains("Use tools wisely"));
}
#[test]
fn test_priority_ordering() {
let prompt = SystemPromptBuilder::new()
.with_section("last", "Third", 100)
.with_section("first", "First", 0)
.with_section("middle", "Second", 50)
.build();
let first_pos = prompt.find("First").unwrap();
let second_pos = prompt.find("Second").unwrap();
let third_pos = prompt.find("Third").unwrap();
assert!(first_pos < second_pos);
assert!(second_pos < third_pos);
}
#[test]
fn test_default_base_prompt() {
let prompt = default_base_prompt();
assert!(prompt.contains("Owlen"));
assert!(prompt.contains("Guidelines"));
assert!(prompt.contains("Tool Usage"));
}
#[test]
fn test_generate_tool_instructions() {
let tools = vec!["read", "write", "edit", "bash"];
let instructions = generate_tool_instructions(&tools);
assert!(instructions.contains("Available Tools"));
assert!(instructions.contains("read"));
assert!(instructions.contains("write"));
assert!(instructions.contains("edit"));
assert!(instructions.contains("bash"));
}
#[test]
fn test_builder_empty() {
let builder = SystemPromptBuilder::new();
assert!(builder.is_empty());
let builder = builder.with_base_prompt("test");
assert!(!builder.is_empty());
}
#[test]
fn test_skill_section() {
let prompt = SystemPromptBuilder::new()
.with_base_prompt("Base")
.with_skill("rust", "Rust expertise")
.build();
assert!(prompt.contains("Base"));
assert!(prompt.contains("Rust expertise"));
}
#[test]
fn test_hook_injection() {
let prompt = SystemPromptBuilder::new()
.with_base_prompt("Base")
.with_hook_injection("Additional context from hook")
.build();
assert!(prompt.contains("Base"));
assert!(prompt.contains("Additional context from hook"));
}
#[test]
fn test_separator_between_sections() {
let prompt = SystemPromptBuilder::new()
.with_section("first", "First section", 0)
.with_section("second", "Second section", 10)
.build();
assert!(prompt.contains("---"));
assert!(prompt.contains("First section"));
assert!(prompt.contains("Second section"));
}
}

View File

@@ -0,0 +1,210 @@
use agent_core::{Checkpoint, CheckpointManager, FileDiff, SessionHistory, SessionStats};
use std::fs;
use std::path::PathBuf;
use tempfile::TempDir;
#[test]
fn test_checkpoint_save_and_load() {
let temp_dir = TempDir::new().unwrap();
let checkpoint_dir = temp_dir.path().to_path_buf();
let stats = SessionStats::new();
let mut history = SessionHistory::new();
history.add_user_message("Hello".to_string());
history.add_assistant_message("Hi there!".to_string());
let file_diffs = vec![FileDiff::new(
PathBuf::from("test.txt"),
"before".to_string(),
"after".to_string(),
)];
let checkpoint = Checkpoint::new(
"test-checkpoint".to_string(),
stats.clone(),
&history,
file_diffs,
);
// Save checkpoint
checkpoint.save(&checkpoint_dir).unwrap();
// Load checkpoint
let loaded = Checkpoint::load(&checkpoint_dir, "test-checkpoint").unwrap();
assert_eq!(loaded.id, "test-checkpoint");
assert_eq!(loaded.user_prompts, vec!["Hello"]);
assert_eq!(loaded.assistant_responses, vec!["Hi there!"]);
assert_eq!(loaded.file_diffs.len(), 1);
assert_eq!(loaded.file_diffs[0].path, PathBuf::from("test.txt"));
assert_eq!(loaded.file_diffs[0].before, "before");
assert_eq!(loaded.file_diffs[0].after, "after");
}
#[test]
fn test_checkpoint_list() {
let temp_dir = TempDir::new().unwrap();
let checkpoint_dir = temp_dir.path().to_path_buf();
// Create a few checkpoints
for i in 1..=3 {
let checkpoint = Checkpoint::new(
format!("checkpoint-{}", i),
SessionStats::new(),
&SessionHistory::new(),
vec![],
);
checkpoint.save(&checkpoint_dir).unwrap();
}
let checkpoints = Checkpoint::list(&checkpoint_dir).unwrap();
assert_eq!(checkpoints.len(), 3);
assert!(checkpoints.contains(&"checkpoint-1".to_string()));
assert!(checkpoints.contains(&"checkpoint-2".to_string()));
assert!(checkpoints.contains(&"checkpoint-3".to_string()));
}
#[test]
fn test_checkpoint_manager_snapshot_and_diff() {
let temp_dir = TempDir::new().unwrap();
let checkpoint_dir = temp_dir.path().join("checkpoints");
let test_file = temp_dir.path().join("test.txt");
// Create initial file content
fs::write(&test_file, "initial content").unwrap();
let mut manager = CheckpointManager::new(checkpoint_dir.clone());
// Snapshot the file
manager.snapshot_file(&test_file).unwrap();
// Modify the file
fs::write(&test_file, "modified content").unwrap();
// Create a diff
let diff = manager.create_diff(&test_file).unwrap();
assert!(diff.is_some());
let diff = diff.unwrap();
assert_eq!(diff.path, test_file);
assert_eq!(diff.before, "initial content");
assert_eq!(diff.after, "modified content");
}
#[test]
fn test_checkpoint_manager_save_and_restore() {
let temp_dir = TempDir::new().unwrap();
let checkpoint_dir = temp_dir.path().join("checkpoints");
let test_file = temp_dir.path().join("test.txt");
// Create initial file content
fs::write(&test_file, "initial content").unwrap();
let mut manager = CheckpointManager::new(checkpoint_dir.clone());
// Snapshot the file
manager.snapshot_file(&test_file).unwrap();
// Modify the file
fs::write(&test_file, "modified content").unwrap();
// Save checkpoint
let mut history = SessionHistory::new();
history.add_user_message("test".to_string());
let checkpoint = manager
.save_checkpoint("test-checkpoint".to_string(), SessionStats::new(), &history)
.unwrap();
assert_eq!(checkpoint.file_diffs.len(), 1);
assert_eq!(checkpoint.file_diffs[0].before, "initial content");
assert_eq!(checkpoint.file_diffs[0].after, "modified content");
// Modify file again
fs::write(&test_file, "final content").unwrap();
assert_eq!(fs::read_to_string(&test_file).unwrap(), "final content");
// Rewind to checkpoint
let restored_files = manager.rewind_to("test-checkpoint").unwrap();
assert_eq!(restored_files.len(), 1);
assert_eq!(restored_files[0], test_file);
// File should be reverted to initial content (before the checkpoint)
assert_eq!(fs::read_to_string(&test_file).unwrap(), "initial content");
}
#[test]
fn test_checkpoint_manager_multiple_files() {
let temp_dir = TempDir::new().unwrap();
let checkpoint_dir = temp_dir.path().join("checkpoints");
let test_file1 = temp_dir.path().join("file1.txt");
let test_file2 = temp_dir.path().join("file2.txt");
// Create initial files
fs::write(&test_file1, "file1 initial").unwrap();
fs::write(&test_file2, "file2 initial").unwrap();
let mut manager = CheckpointManager::new(checkpoint_dir.clone());
// Snapshot both files
manager.snapshot_file(&test_file1).unwrap();
manager.snapshot_file(&test_file2).unwrap();
// Modify both files
fs::write(&test_file1, "file1 modified").unwrap();
fs::write(&test_file2, "file2 modified").unwrap();
// Save checkpoint
let checkpoint = manager
.save_checkpoint(
"multi-file-checkpoint".to_string(),
SessionStats::new(),
&SessionHistory::new(),
)
.unwrap();
assert_eq!(checkpoint.file_diffs.len(), 2);
// Modify files again
fs::write(&test_file1, "file1 final").unwrap();
fs::write(&test_file2, "file2 final").unwrap();
// Rewind
let restored_files = manager.rewind_to("multi-file-checkpoint").unwrap();
assert_eq!(restored_files.len(), 2);
// Both files should be reverted
assert_eq!(fs::read_to_string(&test_file1).unwrap(), "file1 initial");
assert_eq!(fs::read_to_string(&test_file2).unwrap(), "file2 initial");
}
#[test]
fn test_checkpoint_no_changes() {
let temp_dir = TempDir::new().unwrap();
let checkpoint_dir = temp_dir.path().join("checkpoints");
let test_file = temp_dir.path().join("test.txt");
// Create file
fs::write(&test_file, "content").unwrap();
let mut manager = CheckpointManager::new(checkpoint_dir.clone());
// Snapshot the file
manager.snapshot_file(&test_file).unwrap();
// Don't modify the file
// Create diff - should be None because nothing changed
let diff = manager.create_diff(&test_file).unwrap();
assert!(diff.is_none());
// Save checkpoint - should have no diffs
let checkpoint = manager
.save_checkpoint(
"no-change-checkpoint".to_string(),
SessionStats::new(),
&SessionHistory::new(),
)
.unwrap();
assert_eq!(checkpoint.file_diffs.len(), 0);
}

View File

@@ -0,0 +1,276 @@
use agent_core::{create_event_channel, run_agent_loop_streaming, AgentEvent, ToolContext};
use async_trait::async_trait;
use futures_util::stream;
use llm_core::{
ChatMessage, ChatOptions, LlmError, StreamChunk, LlmProvider, Tool, ToolCallDelta,
};
use permissions::{Mode, PermissionManager};
use std::pin::Pin;
/// Mock LLM provider for testing streaming
struct MockStreamingProvider {
responses: Vec<MockResponse>,
}
enum MockResponse {
/// Text-only response (no tool calls)
Text(Vec<String>), // Chunks of text
/// Tool call response
ToolCall {
text_chunks: Vec<String>,
tool_id: String,
tool_name: String,
tool_args: String,
},
}
#[async_trait]
impl LlmProvider for MockStreamingProvider {
fn name(&self) -> &str {
"mock"
}
fn model(&self) -> &str {
"mock-model"
}
async fn chat_stream(
&self,
messages: &[ChatMessage],
_options: &ChatOptions,
_tools: Option<&[Tool]>,
) -> Result<Pin<Box<dyn futures_util::Stream<Item = Result<StreamChunk, LlmError>> + Send>>, LlmError> {
// Determine which response to use based on message count
let response_idx = (messages.len() / 2).min(self.responses.len() - 1);
let response = &self.responses[response_idx];
let chunks: Vec<Result<StreamChunk, LlmError>> = match response {
MockResponse::Text(text_chunks) => text_chunks
.iter()
.map(|text| {
Ok(StreamChunk {
content: Some(text.clone()),
tool_calls: None,
done: false,
usage: None,
})
})
.collect(),
MockResponse::ToolCall {
text_chunks,
tool_id,
tool_name,
tool_args,
} => {
let mut result = vec![];
// First emit text chunks
for text in text_chunks {
result.push(Ok(StreamChunk {
content: Some(text.clone()),
tool_calls: None,
done: false,
usage: None,
}));
}
// Then emit tool call in chunks
result.push(Ok(StreamChunk {
content: None,
tool_calls: Some(vec![ToolCallDelta {
index: 0,
id: Some(tool_id.clone()),
function_name: Some(tool_name.clone()),
arguments_delta: None,
}]),
done: false,
usage: None,
}));
// Emit args in chunks
for chunk in tool_args.chars().collect::<Vec<_>>().chunks(5) {
result.push(Ok(StreamChunk {
content: None,
tool_calls: Some(vec![ToolCallDelta {
index: 0,
id: None,
function_name: None,
arguments_delta: Some(chunk.iter().collect()),
}]),
done: false,
usage: None,
}));
}
result
}
};
Ok(Box::pin(stream::iter(chunks)))
}
}
#[tokio::test]
async fn test_streaming_text_only() {
let provider = MockStreamingProvider {
responses: vec![MockResponse::Text(vec![
"Hello".to_string(),
" ".to_string(),
"world".to_string(),
"!".to_string(),
])],
};
let perms = PermissionManager::new(Mode::Plan);
let ctx = ToolContext::default();
let (tx, mut rx) = create_event_channel();
// Spawn the agent loop
let handle = tokio::spawn(async move {
run_agent_loop_streaming(
&provider,
"Say hello",
&ChatOptions::default(),
&perms,
&ctx,
tx,
)
.await
});
// Collect events
let mut text_deltas = vec![];
let mut done_response = None;
while let Some(event) = rx.recv().await {
match event {
AgentEvent::TextDelta(text) => {
text_deltas.push(text);
}
AgentEvent::Done { final_response } => {
done_response = Some(final_response);
break;
}
AgentEvent::Error(e) => {
panic!("Unexpected error: {}", e);
}
_ => {}
}
}
// Wait for agent loop to complete
let result = handle.await.unwrap();
assert!(result.is_ok());
// Verify events
assert_eq!(text_deltas, vec!["Hello", " ", "world", "!"]);
assert_eq!(done_response, Some("Hello world!".to_string()));
assert_eq!(result.unwrap(), "Hello world!");
}
#[tokio::test]
async fn test_streaming_with_tool_call() {
let provider = MockStreamingProvider {
responses: vec![
MockResponse::ToolCall {
text_chunks: vec!["Let me ".to_string(), "check...".to_string()],
tool_id: "call_123".to_string(),
tool_name: "glob".to_string(),
tool_args: r#"{"pattern":"*.rs"}"#.to_string(),
},
MockResponse::Text(vec!["Found ".to_string(), "the files!".to_string()]),
],
};
let perms = PermissionManager::new(Mode::Plan);
let ctx = ToolContext::default();
let (tx, mut rx) = create_event_channel();
// Spawn the agent loop
let handle = tokio::spawn(async move {
run_agent_loop_streaming(
&provider,
"Find Rust files",
&ChatOptions::default(),
&perms,
&ctx,
tx,
)
.await
});
// Collect events
let mut text_deltas = vec![];
let mut tool_starts = vec![];
let mut tool_outputs = vec![];
let mut tool_ends = vec![];
while let Some(event) = rx.recv().await {
match event {
AgentEvent::TextDelta(text) => {
text_deltas.push(text);
}
AgentEvent::ToolStart {
tool_name,
tool_id,
} => {
tool_starts.push((tool_name, tool_id));
}
AgentEvent::ToolOutput {
tool_id,
content,
is_error,
} => {
tool_outputs.push((tool_id, content, is_error));
}
AgentEvent::ToolEnd { tool_id, success } => {
tool_ends.push((tool_id, success));
}
AgentEvent::Done { .. } => {
break;
}
AgentEvent::Error(e) => {
panic!("Unexpected error: {}", e);
}
}
}
// Wait for agent loop to complete
let result = handle.await.unwrap();
assert!(result.is_ok());
// Verify we got text deltas from both responses
assert!(text_deltas.contains(&"Let me ".to_string()));
assert!(text_deltas.contains(&"check...".to_string()));
assert!(text_deltas.contains(&"Found ".to_string()));
assert!(text_deltas.contains(&"the files!".to_string()));
// Verify tool events
assert_eq!(tool_starts.len(), 1);
assert_eq!(tool_starts[0].0, "glob");
assert_eq!(tool_starts[0].1, "call_123");
assert_eq!(tool_outputs.len(), 1);
assert_eq!(tool_outputs[0].0, "call_123");
assert!(!tool_outputs[0].2); // not an error
assert_eq!(tool_ends.len(), 1);
assert_eq!(tool_ends[0].0, "call_123");
assert!(tool_ends[0].1); // success
}
#[tokio::test]
async fn test_channel_creation() {
let (tx, mut rx) = create_event_channel();
// Test that channel works
tx.send(AgentEvent::TextDelta("test".to_string()))
.await
.unwrap();
let event = rx.recv().await.unwrap();
match event {
AgentEvent::TextDelta(text) => assert_eq!(text, "test"),
_ => panic!("Wrong event type"),
}
}

View File

@@ -0,0 +1,114 @@
// Test that ToolContext properly wires up the placeholder tools
use agent_core::{ToolContext, execute_tool};
use permissions::{Mode, PermissionManager};
use tools_todo::{TodoList, TodoStatus};
use tools_bash::ShellManager;
use serde_json::json;
#[tokio::test]
async fn test_todo_write_with_context() {
let todo_list = TodoList::new();
let ctx = ToolContext::new().with_todo_list(todo_list.clone());
let perms = PermissionManager::new(Mode::Code); // Allow all tools
let arguments = json!({
"todos": [
{
"content": "First task",
"status": "pending",
"active_form": "Working on first task"
},
{
"content": "Second task",
"status": "in_progress",
"active_form": "Working on second task"
}
]
});
let result = execute_tool("todo_write", &arguments, &perms, &ctx).await;
assert!(result.is_ok(), "TodoWrite should succeed: {:?}", result);
// Verify the todos were written
let todos = todo_list.read();
assert_eq!(todos.len(), 2);
assert_eq!(todos[0].content, "First task");
assert_eq!(todos[1].status, TodoStatus::InProgress);
}
#[tokio::test]
async fn test_todo_write_without_context() {
let ctx = ToolContext::new(); // No todo_list
let perms = PermissionManager::new(Mode::Code);
let arguments = json!({
"todos": []
});
let result = execute_tool("todo_write", &arguments, &perms, &ctx).await;
assert!(result.is_err(), "TodoWrite should fail without TodoList");
assert!(result.unwrap_err().to_string().contains("not available"));
}
#[tokio::test]
async fn test_bash_output_with_context() {
let manager = ShellManager::new();
let ctx = ToolContext::new().with_shell_manager(manager.clone());
let perms = PermissionManager::new(Mode::Code);
// Start a shell and run a command
let shell_id = manager.start_shell().await.unwrap();
let _ = manager.execute(&shell_id, "echo test", None).await.unwrap();
let arguments = json!({
"shell_id": shell_id
});
let result = execute_tool("bash_output", &arguments, &perms, &ctx).await;
assert!(result.is_ok(), "BashOutput should succeed: {:?}", result);
}
#[tokio::test]
async fn test_bash_output_without_context() {
let ctx = ToolContext::new(); // No shell_manager
let perms = PermissionManager::new(Mode::Code);
let arguments = json!({
"shell_id": "fake-id"
});
let result = execute_tool("bash_output", &arguments, &perms, &ctx).await;
assert!(result.is_err(), "BashOutput should fail without ShellManager");
assert!(result.unwrap_err().to_string().contains("not available"));
}
#[tokio::test]
async fn test_kill_shell_with_context() {
let manager = ShellManager::new();
let ctx = ToolContext::new().with_shell_manager(manager.clone());
let perms = PermissionManager::new(Mode::Code);
// Start a shell
let shell_id = manager.start_shell().await.unwrap();
let arguments = json!({
"shell_id": shell_id
});
let result = execute_tool("kill_shell", &arguments, &perms, &ctx).await;
assert!(result.is_ok(), "KillShell should succeed: {:?}", result);
}
#[tokio::test]
async fn test_ask_user_without_context() {
let ctx = ToolContext::new(); // No ask_sender
let perms = PermissionManager::new(Mode::Code);
let arguments = json!({
"questions": []
});
let result = execute_tool("ask_user", &arguments, &perms, &ctx).await;
assert!(result.is_err(), "AskUser should fail without AskSender");
assert!(result.unwrap_err().to_string().contains("not available"));
}

View File

@@ -0,0 +1,18 @@
[package]
name = "llm-anthropic"
version = "0.1.0"
edition.workspace = true
license.workspace = true
description = "Anthropic Claude API client for Owlen"
[dependencies]
llm-core = { path = "../core" }
async-trait = "0.1"
futures = "0.3"
reqwest = { version = "0.12", features = ["json", "stream"] }
reqwest-eventsource = "0.6"
serde = { version = "1.0", features = ["derive"] }
serde_json = "1.0"
tokio = { version = "1", features = ["sync", "time"] }
tracing = "0.1"
uuid = { version = "1.0", features = ["v4"] }

View File

@@ -0,0 +1,285 @@
//! Anthropic OAuth Authentication
//!
//! Implements device code flow for authenticating with Anthropic without API keys.
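//!
//! Rough usage sketch (async context assumed; the endpoints in this module are
//! placeholders, so treat this as illustrative rather than working code):
//!
//! ```ignore
//! let auth = AnthropicAuth::new();
//! let method = perform_device_auth(&auth, |code| {
//!     println!("Visit {} and enter code {}", code.verification_uri, code.user_code);
//! })
//! .await?;
//! ```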
use llm_core::{AuthMethod, DeviceAuthResult, DeviceCodeResponse, LlmError, OAuthProvider};
use reqwest::Client;
use serde::{Deserialize, Serialize};
/// OAuth client for Anthropic device flow
pub struct AnthropicAuth {
http: Client,
client_id: String,
}
// Anthropic OAuth endpoints (placeholder paths standing in for the real endpoints)
const AUTH_BASE_URL: &str = "https://console.anthropic.com";
const DEVICE_CODE_ENDPOINT: &str = "/oauth/device/code";
const TOKEN_ENDPOINT: &str = "/oauth/token";
// Default client ID for Owlen CLI
const DEFAULT_CLIENT_ID: &str = "owlen-cli";
impl AnthropicAuth {
/// Create a new OAuth client with the default CLI client ID
pub fn new() -> Self {
Self {
http: Client::new(),
client_id: DEFAULT_CLIENT_ID.to_string(),
}
}
/// Create with a custom client ID
pub fn with_client_id(client_id: impl Into<String>) -> Self {
Self {
http: Client::new(),
client_id: client_id.into(),
}
}
}
impl Default for AnthropicAuth {
fn default() -> Self {
Self::new()
}
}
#[derive(Debug, Serialize)]
struct DeviceCodeRequest<'a> {
client_id: &'a str,
scope: &'a str,
}
#[derive(Debug, Deserialize)]
struct DeviceCodeApiResponse {
device_code: String,
user_code: String,
verification_uri: String,
verification_uri_complete: Option<String>,
expires_in: u64,
interval: u64,
}
#[derive(Debug, Serialize)]
struct TokenRequest<'a> {
client_id: &'a str,
device_code: &'a str,
grant_type: &'a str,
}
#[derive(Debug, Deserialize)]
struct TokenApiResponse {
access_token: String,
#[allow(dead_code)]
token_type: String,
expires_in: Option<u64>,
refresh_token: Option<String>,
}
#[derive(Debug, Deserialize)]
struct TokenErrorResponse {
error: String,
error_description: Option<String>,
}
#[async_trait::async_trait]
impl OAuthProvider for AnthropicAuth {
async fn start_device_auth(&self) -> Result<DeviceCodeResponse, LlmError> {
let url = format!("{}{}", AUTH_BASE_URL, DEVICE_CODE_ENDPOINT);
let request = DeviceCodeRequest {
client_id: &self.client_id,
scope: "api:read api:write", // Request API access
};
let response = self
.http
.post(&url)
.json(&request)
.send()
.await
.map_err(|e| LlmError::Http(e.to_string()))?;
if !response.status().is_success() {
let status = response.status();
let text = response
.text()
.await
.unwrap_or_else(|_| "Unknown error".to_string());
return Err(LlmError::Auth(format!(
"Device code request failed ({}): {}",
status, text
)));
}
let api_response: DeviceCodeApiResponse = response
.json()
.await
.map_err(|e| LlmError::Json(e.to_string()))?;
Ok(DeviceCodeResponse {
device_code: api_response.device_code,
user_code: api_response.user_code,
verification_uri: api_response.verification_uri,
verification_uri_complete: api_response.verification_uri_complete,
expires_in: api_response.expires_in,
interval: api_response.interval,
})
}
async fn poll_device_auth(&self, device_code: &str) -> Result<DeviceAuthResult, LlmError> {
let url = format!("{}{}", AUTH_BASE_URL, TOKEN_ENDPOINT);
let request = TokenRequest {
client_id: &self.client_id,
device_code,
grant_type: "urn:ietf:params:oauth:grant-type:device_code",
};
let response = self
.http
.post(&url)
.json(&request)
.send()
.await
.map_err(|e| LlmError::Http(e.to_string()))?;
if response.status().is_success() {
let token_response: TokenApiResponse = response
.json()
.await
.map_err(|e| LlmError::Json(e.to_string()))?;
return Ok(DeviceAuthResult::Success {
access_token: token_response.access_token,
refresh_token: token_response.refresh_token,
expires_in: token_response.expires_in,
});
}
// Parse error response
let error_response: TokenErrorResponse = response
.json()
.await
.map_err(|e| LlmError::Json(e.to_string()))?;
match error_response.error.as_str() {
"authorization_pending" => Ok(DeviceAuthResult::Pending),
"slow_down" => Ok(DeviceAuthResult::Pending), // Treat as pending, caller should slow down
"access_denied" => Ok(DeviceAuthResult::Denied),
"expired_token" => Ok(DeviceAuthResult::Expired),
_ => Err(LlmError::Auth(format!(
"Token request failed: {} - {}",
error_response.error,
error_response.error_description.unwrap_or_default()
))),
}
}
async fn refresh_token(&self, refresh_token: &str) -> Result<AuthMethod, LlmError> {
let url = format!("{}{}", AUTH_BASE_URL, TOKEN_ENDPOINT);
#[derive(Serialize)]
struct RefreshRequest<'a> {
client_id: &'a str,
refresh_token: &'a str,
grant_type: &'a str,
}
let request = RefreshRequest {
client_id: &self.client_id,
refresh_token,
grant_type: "refresh_token",
};
let response = self
.http
.post(&url)
.json(&request)
.send()
.await
.map_err(|e| LlmError::Http(e.to_string()))?;
if !response.status().is_success() {
let text = response
.text()
.await
.unwrap_or_else(|_| "Unknown error".to_string());
return Err(LlmError::Auth(format!("Token refresh failed: {}", text)));
}
let token_response: TokenApiResponse = response
.json()
.await
.map_err(|e| LlmError::Json(e.to_string()))?;
let expires_at = token_response.expires_in.map(|secs| {
std::time::SystemTime::now()
.duration_since(std::time::UNIX_EPOCH)
.map(|d| d.as_secs() + secs)
.unwrap_or(0)
});
Ok(AuthMethod::OAuth {
access_token: token_response.access_token,
refresh_token: token_response.refresh_token,
expires_at,
})
}
}
/// Helper to perform the full device auth flow with polling
pub async fn perform_device_auth<F>(
auth: &AnthropicAuth,
on_code: F,
) -> Result<AuthMethod, LlmError>
where
F: FnOnce(&DeviceCodeResponse),
{
// Start the device flow
let device_code = auth.start_device_auth().await?;
// Let caller display the code to user
on_code(&device_code);
// Poll for completion
let poll_interval = std::time::Duration::from_secs(device_code.interval);
let deadline =
std::time::Instant::now() + std::time::Duration::from_secs(device_code.expires_in);
loop {
if std::time::Instant::now() > deadline {
return Err(LlmError::Auth("Device code expired".to_string()));
}
tokio::time::sleep(poll_interval).await;
match auth.poll_device_auth(&device_code.device_code).await? {
DeviceAuthResult::Success {
access_token,
refresh_token,
expires_in,
} => {
let expires_at = expires_in.map(|secs| {
std::time::SystemTime::now()
.duration_since(std::time::UNIX_EPOCH)
.map(|d| d.as_secs() + secs)
.unwrap_or(0)
});
return Ok(AuthMethod::OAuth {
access_token,
refresh_token,
expires_at,
});
}
DeviceAuthResult::Pending => continue,
DeviceAuthResult::Denied => {
return Err(LlmError::Auth("Authorization denied by user".to_string()));
}
DeviceAuthResult::Expired => {
return Err(LlmError::Auth("Device code expired".to_string()));
}
}
}
}

View File

@@ -0,0 +1,577 @@
//! Anthropic Claude API Client
//!
//! Implements the Messages API with streaming support.
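//!
//! Minimal streaming sketch (the API key is a placeholder; chunk handling is simplified):
//!
//! ```ignore
//! use futures::StreamExt;
//!
//! let client = AnthropicClient::new("sk-ant-...").with_model("claude-sonnet-4-20250514");
//! let mut stream = client
//!     .chat_stream(&[ChatMessage::user("Hello")], &ChatOptions::default(), None)
//!     .await?;
//! while let Some(chunk) = stream.next().await {
//!     if let Some(text) = chunk?.content {
//!         print!("{text}");
//!     }
//! }
//! ```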
use crate::types::*;
use async_trait::async_trait;
use futures::StreamExt;
use llm_core::{
AccountInfo, AuthMethod, ChatMessage, ChatOptions, ChatResponse, ChunkStream, FunctionCall,
LlmError, LlmProvider, ModelInfo, ProviderInfo, ProviderStatus, Role, StreamChunk, Tool,
ToolCall, ToolCallDelta, Usage, UsageStats,
};
use reqwest::Client;
use reqwest_eventsource::{Event, EventSource};
use std::sync::Arc;
use tokio::sync::Mutex;
const API_BASE_URL: &str = "https://api.anthropic.com";
const MESSAGES_ENDPOINT: &str = "/v1/messages";
const API_VERSION: &str = "2023-06-01";
const DEFAULT_MAX_TOKENS: u32 = 8192;
/// Anthropic Claude API client
pub struct AnthropicClient {
http: Client,
auth: AuthMethod,
model: String,
}
impl AnthropicClient {
/// Create a new client with API key authentication
pub fn new(api_key: impl Into<String>) -> Self {
Self {
http: Client::new(),
auth: AuthMethod::api_key(api_key),
model: "claude-sonnet-4-20250514".to_string(),
}
}
/// Create a new client with OAuth token
pub fn with_oauth(access_token: impl Into<String>) -> Self {
Self {
http: Client::new(),
auth: AuthMethod::oauth(access_token),
model: "claude-sonnet-4-20250514".to_string(),
}
}
/// Create a new client with full AuthMethod
pub fn with_auth(auth: AuthMethod) -> Self {
Self {
http: Client::new(),
auth,
model: "claude-sonnet-4-20250514".to_string(),
}
}
/// Set the model to use
pub fn with_model(mut self, model: impl Into<String>) -> Self {
self.model = model.into();
self
}
/// Get current auth method (for token refresh)
pub fn auth(&self) -> &AuthMethod {
&self.auth
}
/// Update the auth method (after refresh)
pub fn set_auth(&mut self, auth: AuthMethod) {
self.auth = auth;
}
/// Convert messages to Anthropic format, extracting system message
fn prepare_messages(messages: &[ChatMessage]) -> (Option<String>, Vec<AnthropicMessage>) {
let mut system_content = None;
let mut anthropic_messages = Vec::new();
for msg in messages {
if msg.role == Role::System {
// Collect system messages
if let Some(content) = &msg.content {
if let Some(existing) = &mut system_content {
*existing = format!("{}\n\n{}", existing, content);
} else {
system_content = Some(content.clone());
}
}
} else {
anthropic_messages.push(AnthropicMessage::from(msg));
}
}
(system_content, anthropic_messages)
}
/// Convert tools to Anthropic format
fn prepare_tools(tools: Option<&[Tool]>) -> Option<Vec<AnthropicTool>> {
tools.map(|t| t.iter().map(AnthropicTool::from).collect())
}
}
#[async_trait]
impl LlmProvider for AnthropicClient {
fn name(&self) -> &str {
"anthropic"
}
fn model(&self) -> &str {
&self.model
}
async fn chat_stream(
&self,
messages: &[ChatMessage],
options: &ChatOptions,
tools: Option<&[Tool]>,
) -> Result<ChunkStream, LlmError> {
let url = format!("{}{}", API_BASE_URL, MESSAGES_ENDPOINT);
let model = if options.model.is_empty() {
&self.model
} else {
&options.model
};
let (system, anthropic_messages) = Self::prepare_messages(messages);
let anthropic_tools = Self::prepare_tools(tools);
let request = MessagesRequest {
model,
messages: anthropic_messages,
max_tokens: options.max_tokens.unwrap_or(DEFAULT_MAX_TOKENS),
system: system.as_deref(),
temperature: options.temperature,
top_p: options.top_p,
stop_sequences: options.stop.as_deref(),
tools: anthropic_tools,
stream: true,
};
let bearer = self
.auth
.bearer_token()
.ok_or_else(|| LlmError::Auth("No authentication configured".to_string()))?;
// Build the SSE request
let req = self
.http
.post(&url)
.header("x-api-key", bearer)
.header("anthropic-version", API_VERSION)
.header("content-type", "application/json")
.json(&request);
let es = EventSource::new(req).map_err(|e| LlmError::Http(e.to_string()))?;
// State for accumulating tool calls across deltas
let tool_state: Arc<Mutex<Vec<PartialToolCall>>> = Arc::new(Mutex::new(Vec::new()));
let stream = es.filter_map(move |event| {
let tool_state = Arc::clone(&tool_state);
async move {
match event {
Ok(Event::Open) => None,
Ok(Event::Message(msg)) => {
// Parse the SSE data as JSON
let event: StreamEvent = match serde_json::from_str(&msg.data) {
Ok(e) => e,
Err(e) => {
tracing::warn!("Failed to parse SSE event: {}", e);
return None;
}
};
convert_stream_event(event, &tool_state).await
}
Err(reqwest_eventsource::Error::StreamEnded) => None,
Err(e) => Some(Err(LlmError::Stream(e.to_string()))),
}
}
});
Ok(Box::pin(stream))
}
async fn chat(
&self,
messages: &[ChatMessage],
options: &ChatOptions,
tools: Option<&[Tool]>,
) -> Result<ChatResponse, LlmError> {
let url = format!("{}{}", API_BASE_URL, MESSAGES_ENDPOINT);
let model = if options.model.is_empty() {
&self.model
} else {
&options.model
};
let (system, anthropic_messages) = Self::prepare_messages(messages);
let anthropic_tools = Self::prepare_tools(tools);
let request = MessagesRequest {
model,
messages: anthropic_messages,
max_tokens: options.max_tokens.unwrap_or(DEFAULT_MAX_TOKENS),
system: system.as_deref(),
temperature: options.temperature,
top_p: options.top_p,
stop_sequences: options.stop.as_deref(),
tools: anthropic_tools,
stream: false,
};
let bearer = self
.auth
.bearer_token()
.ok_or_else(|| LlmError::Auth("No authentication configured".to_string()))?;
let response = self
.http
.post(&url)
.header("x-api-key", bearer)
.header("anthropic-version", API_VERSION)
.json(&request)
.send()
.await
.map_err(|e| LlmError::Http(e.to_string()))?;
if !response.status().is_success() {
let status = response.status();
let text = response
.text()
.await
.unwrap_or_else(|_| "Unknown error".to_string());
// Check for rate limiting
if status == reqwest::StatusCode::TOO_MANY_REQUESTS {
return Err(LlmError::RateLimit {
retry_after_secs: None,
});
}
return Err(LlmError::Api {
message: text,
code: Some(status.to_string()),
});
}
let api_response: MessagesResponse = response
.json()
.await
.map_err(|e| LlmError::Json(e.to_string()))?;
// Convert response to common format
let mut content = String::new();
let mut tool_calls = Vec::new();
for block in api_response.content {
match block {
ResponseContentBlock::Text { text } => {
content.push_str(&text);
}
ResponseContentBlock::ToolUse { id, name, input } => {
tool_calls.push(ToolCall {
id,
call_type: "function".to_string(),
function: FunctionCall {
name,
arguments: input,
},
});
}
}
}
let usage = api_response.usage.map(|u| Usage {
prompt_tokens: u.input_tokens,
completion_tokens: u.output_tokens,
total_tokens: u.input_tokens + u.output_tokens,
});
Ok(ChatResponse {
content: if content.is_empty() {
None
} else {
Some(content)
},
tool_calls: if tool_calls.is_empty() {
None
} else {
Some(tool_calls)
},
usage,
})
}
}
/// Helper struct for accumulating streaming tool calls
#[derive(Default)]
struct PartialToolCall {
#[allow(dead_code)]
id: String,
#[allow(dead_code)]
name: String,
input_json: String,
}
/// Convert an Anthropic stream event to our common StreamChunk format
async fn convert_stream_event(
event: StreamEvent,
tool_state: &Arc<Mutex<Vec<PartialToolCall>>>,
) -> Option<Result<StreamChunk, LlmError>> {
match event {
StreamEvent::ContentBlockStart {
index,
content_block,
} => {
match content_block {
ContentBlockStartInfo::Text { text } => {
if text.is_empty() {
None
} else {
Some(Ok(StreamChunk {
content: Some(text),
tool_calls: None,
done: false,
usage: None,
}))
}
}
ContentBlockStartInfo::ToolUse { id, name } => {
// Store the tool call start
let mut state = tool_state.lock().await;
while state.len() <= index {
state.push(PartialToolCall::default());
}
state[index] = PartialToolCall {
id: id.clone(),
name: name.clone(),
input_json: String::new(),
};
Some(Ok(StreamChunk {
content: None,
tool_calls: Some(vec![ToolCallDelta {
index,
id: Some(id),
function_name: Some(name),
arguments_delta: None,
}]),
done: false,
usage: None,
}))
}
}
}
StreamEvent::ContentBlockDelta { index, delta } => match delta {
ContentDelta::TextDelta { text } => Some(Ok(StreamChunk {
content: Some(text),
tool_calls: None,
done: false,
usage: None,
})),
ContentDelta::InputJsonDelta { partial_json } => {
// Accumulate the JSON
let mut state = tool_state.lock().await;
if index < state.len() {
state[index].input_json.push_str(&partial_json);
}
Some(Ok(StreamChunk {
content: None,
tool_calls: Some(vec![ToolCallDelta {
index,
id: None,
function_name: None,
arguments_delta: Some(partial_json),
}]),
done: false,
usage: None,
}))
}
},
StreamEvent::MessageDelta { usage, .. } => {
let u = usage.map(|u| Usage {
prompt_tokens: u.input_tokens,
completion_tokens: u.output_tokens,
total_tokens: u.input_tokens + u.output_tokens,
});
Some(Ok(StreamChunk {
content: None,
tool_calls: None,
done: false,
usage: u,
}))
}
StreamEvent::MessageStop => Some(Ok(StreamChunk {
content: None,
tool_calls: None,
done: true,
usage: None,
})),
StreamEvent::Error { error } => Some(Err(LlmError::Api {
message: error.message,
code: Some(error.error_type),
})),
// Ignore other events
StreamEvent::MessageStart { .. }
| StreamEvent::ContentBlockStop { .. }
| StreamEvent::Ping => None,
}
}
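// Illustrative sketch (not part of the original file): feeding a single
// `content_block_delta` event through `convert_stream_event` yields a
// `StreamChunk` carrying only the incremental text. In the real client this
// conversion runs inside the SSE decoding loop of `chat_stream`.
#[allow(dead_code)]
async fn convert_text_delta_sketch() {
    let tool_state = Arc::new(Mutex::new(Vec::new()));
    let event = StreamEvent::ContentBlockDelta {
        index: 0,
        delta: ContentDelta::TextDelta {
            text: "Hello".to_string(),
        },
    };
    if let Some(Ok(chunk)) = convert_stream_event(event, &tool_state).await {
        assert_eq!(chunk.content.as_deref(), Some("Hello"));
        assert!(!chunk.done);
    }
}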
// ============================================================================
// ProviderInfo Implementation
// ============================================================================
/// Known Claude models with their specifications
fn get_claude_models() -> Vec<ModelInfo> {
vec![
ModelInfo {
id: "claude-opus-4-20250514".to_string(),
display_name: Some("Claude Opus 4".to_string()),
description: Some("Most capable model for complex tasks".to_string()),
context_window: Some(200_000),
max_output_tokens: Some(32_000),
supports_tools: true,
supports_vision: true,
input_price_per_mtok: Some(15.0),
output_price_per_mtok: Some(75.0),
},
ModelInfo {
id: "claude-sonnet-4-20250514".to_string(),
display_name: Some("Claude Sonnet 4".to_string()),
description: Some("Best balance of performance and speed".to_string()),
context_window: Some(200_000),
max_output_tokens: Some(64_000),
supports_tools: true,
supports_vision: true,
input_price_per_mtok: Some(3.0),
output_price_per_mtok: Some(15.0),
},
ModelInfo {
id: "claude-haiku-3-5-20241022".to_string(),
display_name: Some("Claude 3.5 Haiku".to_string()),
description: Some("Fast and affordable for simple tasks".to_string()),
context_window: Some(200_000),
max_output_tokens: Some(8_192),
supports_tools: true,
supports_vision: true,
input_price_per_mtok: Some(0.80),
output_price_per_mtok: Some(4.0),
},
]
}
#[async_trait]
impl ProviderInfo for AnthropicClient {
async fn status(&self) -> Result<ProviderStatus, LlmError> {
let authenticated = self.auth.bearer_token().is_some();
// Try to reach the API with a simple request
let reachable = if authenticated {
// Test with a minimal message to verify auth works
let test_messages = vec![ChatMessage::user("Hi")];
let test_opts = ChatOptions::new(&self.model).with_max_tokens(1);
match self.chat(&test_messages, &test_opts, None).await {
Ok(_) => true,
Err(LlmError::Auth(_)) => false, // Auth failed
Err(_) => true, // Other errors mean API is reachable
}
} else {
false
};
let account = if authenticated && reachable {
self.account_info().await.ok().flatten()
} else {
None
};
let message = if !authenticated {
Some("Not authenticated - run 'owlen login anthropic' to authenticate".to_string())
} else if !reachable {
Some("Cannot reach Anthropic API".to_string())
} else {
Some("Connected".to_string())
};
Ok(ProviderStatus {
provider: "anthropic".to_string(),
authenticated,
account,
model: self.model.clone(),
endpoint: API_BASE_URL.to_string(),
reachable,
message,
})
}
async fn account_info(&self) -> Result<Option<AccountInfo>, LlmError> {
// Anthropic doesn't have a public account info endpoint
// Return None - account info would come from OAuth token claims
Ok(None)
}
async fn usage_stats(&self) -> Result<Option<UsageStats>, LlmError> {
// Anthropic doesn't expose usage stats via API
// This would require the admin/billing API with different auth
Ok(None)
}
async fn list_models(&self) -> Result<Vec<ModelInfo>, LlmError> {
// Return known models - Anthropic doesn't have a models list endpoint
Ok(get_claude_models())
}
async fn model_info(&self, model_id: &str) -> Result<Option<ModelInfo>, LlmError> {
let models = get_claude_models();
Ok(models.into_iter().find(|m| m.id == model_id))
}
}
#[cfg(test)]
mod tests {
use super::*;
use llm_core::ToolParameters;
use serde_json::json;
#[test]
fn test_message_conversion() {
let messages = vec![
ChatMessage::system("You are helpful"),
ChatMessage::user("Hello"),
ChatMessage::assistant("Hi there!"),
];
let (system, anthropic_msgs) = AnthropicClient::prepare_messages(&messages);
assert_eq!(system, Some("You are helpful".to_string()));
assert_eq!(anthropic_msgs.len(), 2);
assert_eq!(anthropic_msgs[0].role, "user");
assert_eq!(anthropic_msgs[1].role, "assistant");
}
#[test]
fn test_tool_conversion() {
let tools = vec![Tool::function(
"read_file",
"Read a file's contents",
ToolParameters::object(
json!({
"path": {
"type": "string",
"description": "File path"
}
}),
vec!["path".to_string()],
),
)];
let anthropic_tools = AnthropicClient::prepare_tools(Some(&tools)).unwrap();
assert_eq!(anthropic_tools.len(), 1);
assert_eq!(anthropic_tools[0].name, "read_file");
assert_eq!(anthropic_tools[0].description, "Read a file's contents");
}
}


@@ -0,0 +1,12 @@
//! Anthropic Claude API Client
//!
//! Implements the LlmProvider trait for Anthropic's Claude models.
//! Supports both API key authentication and OAuth device flow.
mod auth;
mod client;
mod types;
pub use auth::*;
pub use client::*;
pub use types::*;
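// Illustrative usage sketch (not part of the original file). The concrete
// constructor lives in `client.rs`; `AnthropicClient::new` and its parameters
// below are hypothetical stand-ins, shown only to indicate the intended flow:
//
//     let auth = llm_core::AuthMethod::api_key(std::env::var("ANTHROPIC_API_KEY")?);
//     let client = AnthropicClient::new(auth, "claude-sonnet-4-20250514"); // hypothetical ctor
//     let opts = llm_core::ChatOptions::new("claude-sonnet-4-20250514");
//     let reply = client
//         .chat(&[llm_core::ChatMessage::user("Hello")], &opts, None)
//         .await?;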


@@ -0,0 +1,276 @@
//! Anthropic API request/response types
use serde::{Deserialize, Serialize};
use serde_json::Value;
// ============================================================================
// Request Types
// ============================================================================
#[derive(Debug, Serialize)]
pub struct MessagesRequest<'a> {
pub model: &'a str,
pub messages: Vec<AnthropicMessage>,
pub max_tokens: u32,
#[serde(skip_serializing_if = "Option::is_none")]
pub system: Option<&'a str>,
#[serde(skip_serializing_if = "Option::is_none")]
pub temperature: Option<f32>,
#[serde(skip_serializing_if = "Option::is_none")]
pub top_p: Option<f32>,
#[serde(skip_serializing_if = "Option::is_none")]
pub stop_sequences: Option<&'a [String]>,
#[serde(skip_serializing_if = "Option::is_none")]
pub tools: Option<Vec<AnthropicTool>>,
pub stream: bool,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AnthropicMessage {
pub role: String, // "user" or "assistant"
pub content: AnthropicContent,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(untagged)]
pub enum AnthropicContent {
Text(String),
Blocks(Vec<ContentBlock>),
}
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(tag = "type")]
pub enum ContentBlock {
#[serde(rename = "text")]
Text { text: String },
#[serde(rename = "tool_use")]
ToolUse {
id: String,
name: String,
input: Value,
},
#[serde(rename = "tool_result")]
ToolResult {
tool_use_id: String,
content: String,
#[serde(skip_serializing_if = "Option::is_none")]
is_error: Option<bool>,
},
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AnthropicTool {
pub name: String,
pub description: String,
pub input_schema: ToolInputSchema,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ToolInputSchema {
#[serde(rename = "type")]
pub schema_type: String,
pub properties: Value,
pub required: Vec<String>,
}
// ============================================================================
// Response Types
// ============================================================================
#[derive(Debug, Clone, Deserialize)]
pub struct MessagesResponse {
pub id: String,
#[serde(rename = "type")]
pub response_type: String,
pub role: String,
pub content: Vec<ResponseContentBlock>,
pub model: String,
pub stop_reason: Option<String>,
pub usage: Option<UsageInfo>,
}
#[derive(Debug, Clone, Deserialize)]
#[serde(tag = "type")]
pub enum ResponseContentBlock {
#[serde(rename = "text")]
Text { text: String },
#[serde(rename = "tool_use")]
ToolUse {
id: String,
name: String,
input: Value,
},
}
#[derive(Debug, Clone, Deserialize)]
pub struct UsageInfo {
pub input_tokens: u32,
pub output_tokens: u32,
}
// ============================================================================
// Streaming Event Types
// ============================================================================
#[derive(Debug, Clone, Deserialize)]
#[serde(tag = "type")]
pub enum StreamEvent {
#[serde(rename = "message_start")]
MessageStart { message: MessageStartInfo },
#[serde(rename = "content_block_start")]
ContentBlockStart {
index: usize,
content_block: ContentBlockStartInfo,
},
#[serde(rename = "content_block_delta")]
ContentBlockDelta { index: usize, delta: ContentDelta },
#[serde(rename = "content_block_stop")]
ContentBlockStop { index: usize },
#[serde(rename = "message_delta")]
MessageDelta {
delta: MessageDeltaInfo,
usage: Option<UsageInfo>,
},
#[serde(rename = "message_stop")]
MessageStop,
#[serde(rename = "ping")]
Ping,
#[serde(rename = "error")]
Error { error: ApiError },
}
#[derive(Debug, Clone, Deserialize)]
pub struct MessageStartInfo {
pub id: String,
#[serde(rename = "type")]
pub message_type: String,
pub role: String,
pub model: String,
pub usage: Option<UsageInfo>,
}
#[derive(Debug, Clone, Deserialize)]
#[serde(tag = "type")]
pub enum ContentBlockStartInfo {
#[serde(rename = "text")]
Text { text: String },
#[serde(rename = "tool_use")]
ToolUse { id: String, name: String },
}
#[derive(Debug, Clone, Deserialize)]
#[serde(tag = "type")]
pub enum ContentDelta {
#[serde(rename = "text_delta")]
TextDelta { text: String },
#[serde(rename = "input_json_delta")]
InputJsonDelta { partial_json: String },
}
#[derive(Debug, Clone, Deserialize)]
pub struct MessageDeltaInfo {
pub stop_reason: Option<String>,
}
#[derive(Debug, Clone, Deserialize)]
pub struct ApiError {
#[serde(rename = "type")]
pub error_type: String,
pub message: String,
}
// ============================================================================
// Conversions
// ============================================================================
impl From<&llm_core::Tool> for AnthropicTool {
fn from(tool: &llm_core::Tool) -> Self {
Self {
name: tool.function.name.clone(),
description: tool.function.description.clone(),
input_schema: ToolInputSchema {
schema_type: tool.function.parameters.param_type.clone(),
properties: tool.function.parameters.properties.clone(),
required: tool.function.parameters.required.clone(),
},
}
}
}
impl From<&llm_core::ChatMessage> for AnthropicMessage {
fn from(msg: &llm_core::ChatMessage) -> Self {
use llm_core::Role;
let role = match msg.role {
Role::User | Role::System => "user",
Role::Assistant => "assistant",
Role::Tool => "user", // Tool results come as user messages in Anthropic
};
// Handle tool results
if msg.role == Role::Tool {
if let (Some(tool_call_id), Some(content)) = (&msg.tool_call_id, &msg.content) {
return Self {
role: "user".to_string(),
content: AnthropicContent::Blocks(vec![ContentBlock::ToolResult {
tool_use_id: tool_call_id.clone(),
content: content.clone(),
is_error: None,
}]),
};
}
}
// Handle assistant messages with tool calls
if msg.role == Role::Assistant {
if let Some(tool_calls) = &msg.tool_calls {
let mut blocks: Vec<ContentBlock> = Vec::new();
// Add text content if present
if let Some(text) = &msg.content {
if !text.is_empty() {
blocks.push(ContentBlock::Text { text: text.clone() });
}
}
// Add tool use blocks
for call in tool_calls {
blocks.push(ContentBlock::ToolUse {
id: call.id.clone(),
name: call.function.name.clone(),
input: call.function.arguments.clone(),
});
}
return Self {
role: "assistant".to_string(),
content: AnthropicContent::Blocks(blocks),
};
}
}
// Simple text message
Self {
role: role.to_string(),
content: AnthropicContent::Text(msg.content.clone().unwrap_or_default()),
}
}
}
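// Illustrative sketch (not part of the original file): a tool-result message
// converts into a single `tool_result` content block on a "user" turn, keyed
// by the originating tool call id, per the `From` impl above.
#[allow(dead_code)]
fn tool_result_conversion_sketch() {
    let msg = llm_core::ChatMessage::tool_result("call_123", "file contents here");
    let converted: AnthropicMessage = (&msg).into();
    assert_eq!(converted.role, "user");
    assert!(matches!(
        converted.content,
        AnthropicContent::Blocks(ref blocks) if blocks.len() == 1
    ));
}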


@@ -0,0 +1,18 @@
[package]
name = "llm-core"
version = "0.1.0"
edition.workspace = true
license.workspace = true
description = "LLM provider abstraction layer for Owlen"
[dependencies]
async-trait = "0.1"
futures = "0.3"
rand = "0.8"
serde = { version = "1.0", features = ["derive"] }
serde_json = "1.0"
thiserror = "2.0"
tokio = { version = "1.0", features = ["time"] }
[dev-dependencies]
tokio = { version = "1.0", features = ["macros", "rt"] }


@@ -0,0 +1,195 @@
//! Token counting example
//!
//! This example demonstrates how to use the token counting utilities
//! to manage LLM context windows.
//!
//! Run with: cargo run --example token_counting -p llm-core
use llm_core::{
ChatMessage, ClaudeTokenCounter, ContextWindow, SimpleTokenCounter, TokenCounter,
};
fn main() {
println!("=== Token Counting Example ===\n");
// Example 1: Basic token counting with SimpleTokenCounter
println!("1. Basic Token Counting");
println!("{}", "-".repeat(50));
let simple_counter = SimpleTokenCounter::new(8192);
let text = "The quick brown fox jumps over the lazy dog.";
let token_count = simple_counter.count(text);
println!("Text: \"{}\"", text);
println!("Estimated tokens: {}", token_count);
println!("Max context: {}\n", simple_counter.max_context());
// Example 2: Counting tokens in chat messages
println!("2. Counting Tokens in Chat Messages");
println!("{}", "-".repeat(50));
let messages = vec![
ChatMessage::system("You are a helpful assistant that provides concise answers."),
ChatMessage::user("What is the capital of France?"),
ChatMessage::assistant("The capital of France is Paris."),
ChatMessage::user("What is its population?"),
];
let total_tokens = simple_counter.count_messages(&messages);
println!("Number of messages: {}", messages.len());
println!("Total tokens (with overhead): {}\n", total_tokens);
// Example 3: Using ClaudeTokenCounter for Claude models
println!("3. Claude-Specific Token Counting");
println!("{}", "-".repeat(50));
let claude_counter = ClaudeTokenCounter::new();
let claude_total = claude_counter.count_messages(&messages);
println!("Claude counter max context: {}", claude_counter.max_context());
println!("Claude estimated tokens: {}\n", claude_total);
// Example 4: Context window management
println!("4. Context Window Management");
println!("{}", "-".repeat(50));
let mut context = ContextWindow::new(8192);
println!("Created context window with max: {} tokens", context.max());
// Simulate adding messages
let conversation = vec![
ChatMessage::user("Tell me about Rust programming."),
ChatMessage::assistant(
"Rust is a systems programming language focused on safety, \
speed, and concurrency. It prevents common bugs like null pointer \
dereferences and data races through its ownership system.",
),
ChatMessage::user("What are its main features?"),
ChatMessage::assistant(
"Rust's main features include: 1) Memory safety without garbage collection, \
2) Zero-cost abstractions, 3) Fearless concurrency, 4) Pattern matching, \
5) Type inference, and 6) A powerful macro system.",
),
];
for (i, msg) in conversation.iter().enumerate() {
let tokens = simple_counter.count_messages(&[msg.clone()]);
context.add_tokens(tokens);
let role = msg.role.as_str();
let preview = msg
.content
.as_ref()
.map(|c| {
if c.len() > 50 {
format!("{}...", &c[..50])
} else {
c.clone()
}
})
.unwrap_or_default();
println!(
"Message {}: [{}] \"{}\"",
i + 1,
role,
preview
);
println!(" Added {} tokens", tokens);
println!(" Total used: {} / {}", context.used(), context.max());
println!(" Usage: {:.1}%", context.usage_percent() * 100.0);
println!(" Progress: {}\n", context.progress_bar(30));
}
// Example 5: Checking context limits
println!("5. Checking Context Limits");
println!("{}", "-".repeat(50));
if context.is_near_limit(0.8) {
println!("Warning: Context is over 80% full!");
} else {
println!("Context usage is below 80%");
}
let remaining = context.remaining();
println!("Remaining tokens: {}", remaining);
let new_message_tokens = 500;
if context.has_room_for(new_message_tokens) {
println!(
"Can fit a message of {} tokens",
new_message_tokens
);
} else {
println!(
"Cannot fit a message of {} tokens - would need to compact or start new context",
new_message_tokens
);
}
// Example 6: Different counter variants
println!("\n6. Using Different Counter Variants");
println!("{}", "-".repeat(50));
let counter_8k = SimpleTokenCounter::default_8k();
let counter_32k = SimpleTokenCounter::with_32k();
let counter_128k = SimpleTokenCounter::with_128k();
println!("8k context counter: {} tokens", counter_8k.max_context());
println!("32k context counter: {} tokens", counter_32k.max_context());
println!("128k context counter: {} tokens", counter_128k.max_context());
let haiku = ClaudeTokenCounter::haiku();
let sonnet = ClaudeTokenCounter::sonnet();
let opus = ClaudeTokenCounter::opus();
println!("\nClaude Haiku: {} tokens", haiku.max_context());
println!("Claude Sonnet: {} tokens", sonnet.max_context());
println!("Claude Opus: {} tokens", opus.max_context());
// Example 7: Managing context for a long conversation
println!("\n7. Long Conversation Simulation");
println!("{}", "-".repeat(50));
let mut long_context = ContextWindow::new(4096); // Smaller context for demo
let counter = SimpleTokenCounter::new(4096);
let mut message_count = 0;
let mut compaction_count = 0;
// Simulate 20 exchanges
for i in 0..20 {
let user_msg = ChatMessage::user(format!(
"This is user message number {} asking a question.",
i + 1
));
let assistant_msg = ChatMessage::assistant(format!(
"This is assistant response number {} providing a detailed answer with multiple sentences to make it longer.",
i + 1
));
let tokens_needed = counter.count_messages(&[user_msg, assistant_msg]);
if !long_context.has_room_for(tokens_needed) {
println!(
"After {} messages, context is full ({}%). Compacting...",
message_count,
(long_context.usage_percent() * 100.0) as u32
);
// In a real scenario, we would compact the conversation
// For now, just reset
long_context.reset();
compaction_count += 1;
}
long_context.add_tokens(tokens_needed);
message_count += 2;
}
println!("Total messages: {}", message_count);
println!("Compactions needed: {}", compaction_count);
println!("Final context usage: {:.1}%", long_context.usage_percent() * 100.0);
println!("Final progress: {}", long_context.progress_bar(40));
println!("\n=== Example Complete ===");
}

crates/llm/core/src/lib.rs

@@ -0,0 +1,796 @@
//! LLM Provider Abstraction Layer
//!
//! This crate defines the common types and traits for LLM provider integration.
//! Providers (Ollama, Anthropic Claude, OpenAI) implement the `LlmProvider` trait
//! to enable swapping providers at runtime.
use async_trait::async_trait;
use futures::Stream;
use serde::{Deserialize, Serialize};
use serde_json::Value;
use std::pin::Pin;
use thiserror::Error;
// ============================================================================
// Public Modules
// ============================================================================
pub mod retry;
pub mod tokens;
// Re-export token counting types for convenience
pub use tokens::{ClaudeTokenCounter, ContextWindow, SimpleTokenCounter, TokenCounter};
// Re-export retry types for convenience
pub use retry::{is_retryable_error, RetryConfig, RetryStrategy};
// ============================================================================
// Error Types
// ============================================================================
#[derive(Error, Debug)]
pub enum LlmError {
#[error("HTTP error: {0}")]
Http(String),
#[error("JSON parsing error: {0}")]
Json(String),
#[error("Authentication error: {0}")]
Auth(String),
#[error("Rate limit exceeded: retry after {retry_after_secs:?} seconds")]
RateLimit { retry_after_secs: Option<u64> },
#[error("API error: {message}")]
Api { message: String, code: Option<String> },
#[error("Provider error: {0}")]
Provider(String),
#[error("Stream error: {0}")]
Stream(String),
#[error("Request timeout: {0}")]
Timeout(String),
}
// ============================================================================
// Message Types
// ============================================================================
/// Role of a message in the conversation
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "lowercase")]
pub enum Role {
System,
User,
Assistant,
Tool,
}
impl Role {
pub fn as_str(&self) -> &'static str {
match self {
Role::System => "system",
Role::User => "user",
Role::Assistant => "assistant",
Role::Tool => "tool",
}
}
}
impl From<&str> for Role {
fn from(s: &str) -> Self {
match s.to_lowercase().as_str() {
"system" => Role::System,
"user" => Role::User,
"assistant" => Role::Assistant,
"tool" => Role::Tool,
_ => Role::User, // Default fallback
}
}
}
/// A message in the conversation
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ChatMessage {
pub role: Role,
#[serde(skip_serializing_if = "Option::is_none")]
pub content: Option<String>,
/// Tool calls made by the assistant
#[serde(skip_serializing_if = "Option::is_none")]
pub tool_calls: Option<Vec<ToolCall>>,
/// For tool role messages: the ID of the tool call this responds to
#[serde(skip_serializing_if = "Option::is_none")]
pub tool_call_id: Option<String>,
/// For tool role messages: the name of the tool
#[serde(skip_serializing_if = "Option::is_none")]
pub name: Option<String>,
}
impl ChatMessage {
/// Create a system message
pub fn system(content: impl Into<String>) -> Self {
Self {
role: Role::System,
content: Some(content.into()),
tool_calls: None,
tool_call_id: None,
name: None,
}
}
/// Create a user message
pub fn user(content: impl Into<String>) -> Self {
Self {
role: Role::User,
content: Some(content.into()),
tool_calls: None,
tool_call_id: None,
name: None,
}
}
/// Create an assistant message
pub fn assistant(content: impl Into<String>) -> Self {
Self {
role: Role::Assistant,
content: Some(content.into()),
tool_calls: None,
tool_call_id: None,
name: None,
}
}
/// Create an assistant message with tool calls (no text content)
pub fn assistant_tool_calls(tool_calls: Vec<ToolCall>) -> Self {
Self {
role: Role::Assistant,
content: None,
tool_calls: Some(tool_calls),
tool_call_id: None,
name: None,
}
}
/// Create a tool result message
pub fn tool_result(tool_call_id: impl Into<String>, content: impl Into<String>) -> Self {
Self {
role: Role::Tool,
content: Some(content.into()),
tool_calls: None,
tool_call_id: Some(tool_call_id.into()),
name: None,
}
}
}
// ============================================================================
// Tool Types
// ============================================================================
/// A tool call requested by the LLM
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct ToolCall {
/// Unique identifier for this tool call
pub id: String,
/// The type of tool call (always "function" for now)
#[serde(rename = "type", default = "default_function_type")]
pub call_type: String,
/// The function being called
pub function: FunctionCall,
}
fn default_function_type() -> String {
"function".to_string()
}
/// Details of a function call
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct FunctionCall {
/// Name of the function to call
pub name: String,
/// Arguments as a JSON object
pub arguments: Value,
}
/// Definition of a tool available to the LLM
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Tool {
#[serde(rename = "type")]
pub tool_type: String,
pub function: ToolFunction,
}
impl Tool {
/// Create a new function tool
pub fn function(
name: impl Into<String>,
description: impl Into<String>,
parameters: ToolParameters,
) -> Self {
Self {
tool_type: "function".to_string(),
function: ToolFunction {
name: name.into(),
description: description.into(),
parameters,
},
}
}
}
/// Function definition within a tool
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ToolFunction {
pub name: String,
pub description: String,
pub parameters: ToolParameters,
}
/// Parameters schema for a function
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ToolParameters {
#[serde(rename = "type")]
pub param_type: String,
/// JSON Schema properties object
pub properties: Value,
/// Required parameter names
pub required: Vec<String>,
}
impl ToolParameters {
/// Create an object parameter schema
pub fn object(properties: Value, required: Vec<String>) -> Self {
Self {
param_type: "object".to_string(),
properties,
required,
}
}
}
// ============================================================================
// Streaming Response Types
// ============================================================================
/// A chunk of a streaming response
#[derive(Debug, Clone)]
pub struct StreamChunk {
/// Incremental text content
pub content: Option<String>,
/// Tool calls (may be partial/streaming)
pub tool_calls: Option<Vec<ToolCallDelta>>,
/// Whether this is the final chunk
pub done: bool,
/// Usage statistics (typically only in final chunk)
pub usage: Option<Usage>,
}
/// Partial tool call for streaming
#[derive(Debug, Clone)]
pub struct ToolCallDelta {
/// Index of this tool call in the array
pub index: usize,
/// Tool call ID (may only be present in first delta)
pub id: Option<String>,
/// Function name (may only be present in first delta)
pub function_name: Option<String>,
/// Incremental arguments string
pub arguments_delta: Option<String>,
}
/// Token usage statistics
#[derive(Debug, Clone, Default)]
pub struct Usage {
pub prompt_tokens: u32,
pub completion_tokens: u32,
pub total_tokens: u32,
}
// ============================================================================
// Provider Configuration
// ============================================================================
/// Options for a chat request
#[derive(Debug, Clone, Default)]
pub struct ChatOptions {
/// Model to use
pub model: String,
/// Temperature (0.0 - 2.0)
pub temperature: Option<f32>,
/// Maximum tokens to generate
pub max_tokens: Option<u32>,
/// Top-p sampling
pub top_p: Option<f32>,
/// Stop sequences
pub stop: Option<Vec<String>>,
}
impl ChatOptions {
pub fn new(model: impl Into<String>) -> Self {
Self {
model: model.into(),
..Default::default()
}
}
pub fn with_temperature(mut self, temp: f32) -> Self {
self.temperature = Some(temp);
self
}
pub fn with_max_tokens(mut self, max: u32) -> Self {
self.max_tokens = Some(max);
self
}
}
// ============================================================================
// Provider Trait
// ============================================================================
/// A boxed stream of chunks
pub type ChunkStream = Pin<Box<dyn Stream<Item = Result<StreamChunk, LlmError>> + Send>>;
/// The main trait that all LLM providers must implement
#[async_trait]
pub trait LlmProvider: Send + Sync {
/// Get the provider name (e.g., "ollama", "anthropic", "openai")
fn name(&self) -> &str;
/// Get the current model name
fn model(&self) -> &str;
/// Send a chat request and receive a streaming response
///
/// # Arguments
/// * `messages` - The conversation history
/// * `options` - Request options (model, temperature, etc.)
/// * `tools` - Optional list of tools the model can use
///
/// # Returns
/// A stream of response chunks
async fn chat_stream(
&self,
messages: &[ChatMessage],
options: &ChatOptions,
tools: Option<&[Tool]>,
) -> Result<ChunkStream, LlmError>;
/// Send a chat request and receive a complete response (non-streaming)
///
/// Default implementation collects the stream, but providers may override
/// for efficiency.
async fn chat(
&self,
messages: &[ChatMessage],
options: &ChatOptions,
tools: Option<&[Tool]>,
) -> Result<ChatResponse, LlmError> {
use futures::StreamExt;
let mut stream = self.chat_stream(messages, options, tools).await?;
let mut content = String::new();
let mut tool_calls: Vec<PartialToolCall> = Vec::new();
let mut usage = None;
while let Some(chunk) = stream.next().await {
let chunk = chunk?;
if let Some(text) = chunk.content {
content.push_str(&text);
}
if let Some(deltas) = chunk.tool_calls {
for delta in deltas {
// Grow the tool_calls vec if needed
while tool_calls.len() <= delta.index {
tool_calls.push(PartialToolCall::default());
}
let partial = &mut tool_calls[delta.index];
if let Some(id) = delta.id {
partial.id = Some(id);
}
if let Some(name) = delta.function_name {
partial.function_name = Some(name);
}
if let Some(args) = delta.arguments_delta {
partial.arguments.push_str(&args);
}
}
}
if chunk.usage.is_some() {
usage = chunk.usage;
}
}
// Convert partial tool calls to complete tool calls
let final_tool_calls: Vec<ToolCall> = tool_calls
.into_iter()
.filter_map(|p| p.try_into_tool_call())
.collect();
Ok(ChatResponse {
content: if content.is_empty() {
None
} else {
Some(content)
},
tool_calls: if final_tool_calls.is_empty() {
None
} else {
Some(final_tool_calls)
},
usage,
})
}
}
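// Illustrative sketch (not part of the original file): how a caller might
// drain `chat_stream`, printing text deltas as they arrive. `provider` can be
// any `LlmProvider` implementation; tool-call deltas are ignored here.
#[allow(dead_code)]
async fn stream_to_stdout(
    provider: &dyn LlmProvider,
    messages: &[ChatMessage],
    options: &ChatOptions,
) -> Result<(), LlmError> {
    use futures::StreamExt;
    let mut stream = provider.chat_stream(messages, options, None).await?;
    while let Some(chunk) = stream.next().await {
        let chunk = chunk?;
        if let Some(text) = chunk.content {
            print!("{}", text);
        }
        if chunk.done {
            break;
        }
    }
    Ok(())
}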
/// A complete chat response (non-streaming)
#[derive(Debug, Clone)]
pub struct ChatResponse {
pub content: Option<String>,
pub tool_calls: Option<Vec<ToolCall>>,
pub usage: Option<Usage>,
}
/// Helper for accumulating streaming tool calls
#[derive(Default)]
struct PartialToolCall {
id: Option<String>,
function_name: Option<String>,
arguments: String,
}
impl PartialToolCall {
fn try_into_tool_call(self) -> Option<ToolCall> {
let id = self.id?;
let name = self.function_name?;
let arguments: Value = serde_json::from_str(&self.arguments).ok()?;
Some(ToolCall {
id,
call_type: "function".to_string(),
function: FunctionCall { name, arguments },
})
}
}
// ============================================================================
// Authentication
// ============================================================================
/// Authentication method for LLM providers
#[derive(Debug, Clone)]
pub enum AuthMethod {
/// No authentication (for local providers like Ollama)
None,
/// API key authentication
ApiKey(String),
/// OAuth access token (from login flow)
OAuth {
access_token: String,
refresh_token: Option<String>,
expires_at: Option<u64>,
},
}
impl AuthMethod {
/// Create API key auth
pub fn api_key(key: impl Into<String>) -> Self {
Self::ApiKey(key.into())
}
/// Create OAuth auth from tokens
pub fn oauth(access_token: impl Into<String>) -> Self {
Self::OAuth {
access_token: access_token.into(),
refresh_token: None,
expires_at: None,
}
}
/// Create OAuth auth with refresh token
pub fn oauth_with_refresh(
access_token: impl Into<String>,
refresh_token: impl Into<String>,
expires_at: Option<u64>,
) -> Self {
Self::OAuth {
access_token: access_token.into(),
refresh_token: Some(refresh_token.into()),
expires_at,
}
}
/// Get the bearer token for Authorization header
pub fn bearer_token(&self) -> Option<&str> {
match self {
Self::None => None,
Self::ApiKey(key) => Some(key),
Self::OAuth { access_token, .. } => Some(access_token),
}
}
/// Check if token might need refresh
pub fn needs_refresh(&self) -> bool {
match self {
Self::OAuth {
expires_at: Some(exp),
refresh_token: Some(_),
..
} => {
let now = std::time::SystemTime::now()
.duration_since(std::time::UNIX_EPOCH)
.map(|d| d.as_secs())
.unwrap_or(0);
// Refresh if expiring within 5 minutes
*exp < now + 300
}
_ => false,
}
}
}
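// Illustrative sketch (not part of the original file): both API keys and OAuth
// access tokens surface through `bearer_token`, so providers can build their
// auth header without caring which method is configured.
#[allow(dead_code)]
fn auth_method_sketch() {
    let key_auth = AuthMethod::api_key("sk-test");
    assert_eq!(key_auth.bearer_token(), Some("sk-test"));
    let oauth = AuthMethod::oauth_with_refresh("access-123", "refresh-456", Some(0));
    assert_eq!(oauth.bearer_token(), Some("access-123"));
    // An already-expired token that carries a refresh token reports needing refresh.
    assert!(oauth.needs_refresh());
}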
/// Device code response for OAuth device flow
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DeviceCodeResponse {
/// Code the user enters on the verification page
pub user_code: String,
/// URL the user visits to authorize
pub verification_uri: String,
/// Full URL with code pre-filled (if supported)
pub verification_uri_complete: Option<String>,
/// Device code for polling (internal use)
pub device_code: String,
/// How often to poll (in seconds)
pub interval: u64,
/// When the codes expire (in seconds)
pub expires_in: u64,
}
/// Result of polling for device authorization
#[derive(Debug, Clone)]
pub enum DeviceAuthResult {
/// Still waiting for user to authorize
Pending,
/// User authorized, here are the tokens
Success {
access_token: String,
refresh_token: Option<String>,
expires_in: Option<u64>,
},
/// User denied authorization
Denied,
/// Code expired
Expired,
}
/// Trait for providers that support OAuth device flow
#[async_trait]
pub trait OAuthProvider {
/// Start the device authorization flow
async fn start_device_auth(&self) -> Result<DeviceCodeResponse, LlmError>;
/// Poll for the authorization result
async fn poll_device_auth(&self, device_code: &str) -> Result<DeviceAuthResult, LlmError>;
/// Refresh an access token using a refresh token
async fn refresh_token(&self, refresh_token: &str) -> Result<AuthMethod, LlmError>;
}
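// Illustrative sketch (not part of the original file): driving the device flow
// end to end. The caller shows the verification URL and user code, then polls
// at the interval returned by `start_device_auth` until a terminal result.
#[allow(dead_code)]
async fn device_login_sketch(provider: &dyn OAuthProvider) -> Result<AuthMethod, LlmError> {
    let device = provider.start_device_auth().await?;
    println!(
        "Visit {} and enter code {}",
        device.verification_uri, device.user_code
    );
    loop {
        tokio::time::sleep(std::time::Duration::from_secs(device.interval)).await;
        match provider.poll_device_auth(&device.device_code).await? {
            DeviceAuthResult::Pending => continue,
            DeviceAuthResult::Success {
                access_token,
                refresh_token,
                expires_in,
            } => {
                let expires_at = expires_in.map(|secs| {
                    std::time::SystemTime::now()
                        .duration_since(std::time::UNIX_EPOCH)
                        .map(|d| d.as_secs())
                        .unwrap_or(0)
                        + secs
                });
                return Ok(match refresh_token {
                    Some(rt) => AuthMethod::oauth_with_refresh(access_token, rt, expires_at),
                    None => AuthMethod::oauth(access_token),
                });
            }
            DeviceAuthResult::Denied => {
                return Err(LlmError::Auth("authorization denied by user".to_string()))
            }
            DeviceAuthResult::Expired => {
                return Err(LlmError::Auth("device code expired".to_string()))
            }
        }
    }
}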
/// Stored credentials for a provider
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct StoredCredentials {
pub provider: String,
pub access_token: String,
pub refresh_token: Option<String>,
pub expires_at: Option<u64>,
}
// ============================================================================
// Provider Status & Info
// ============================================================================
/// Status information for a provider connection
#[derive(Debug, Clone)]
pub struct ProviderStatus {
/// Provider name
pub provider: String,
/// Whether the connection is authenticated
pub authenticated: bool,
/// Current user/account info if authenticated
pub account: Option<AccountInfo>,
/// Current model being used
pub model: String,
/// API endpoint URL
pub endpoint: String,
/// Whether the provider is reachable
pub reachable: bool,
/// Any status message or error
pub message: Option<String>,
}
/// Account/user information from the provider
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AccountInfo {
/// Account/user ID
pub id: Option<String>,
/// Display name or email
pub name: Option<String>,
/// Account email
pub email: Option<String>,
/// Account type (free, pro, team, enterprise)
pub account_type: Option<String>,
/// Organization name if applicable
pub organization: Option<String>,
}
/// Usage statistics from the provider
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct UsageStats {
/// Total tokens used in current period
pub tokens_used: Option<u64>,
/// Token limit for current period (if applicable)
pub token_limit: Option<u64>,
/// Number of requests made
pub requests_made: Option<u64>,
/// Request limit (if applicable)
pub request_limit: Option<u64>,
/// Cost incurred (if available)
pub cost_usd: Option<f64>,
/// Period start timestamp
pub period_start: Option<u64>,
/// Period end timestamp
pub period_end: Option<u64>,
}
/// Available model information
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ModelInfo {
/// Model ID/name
pub id: String,
/// Human-readable display name
pub display_name: Option<String>,
/// Model description
pub description: Option<String>,
/// Context window size (tokens)
pub context_window: Option<u32>,
/// Max output tokens
pub max_output_tokens: Option<u32>,
/// Whether the model supports tool use
pub supports_tools: bool,
/// Whether the model supports vision/images
pub supports_vision: bool,
/// Input token price per 1M tokens (USD)
pub input_price_per_mtok: Option<f64>,
/// Output token price per 1M tokens (USD)
pub output_price_per_mtok: Option<f64>,
}
/// Trait for providers that support status/info queries
#[async_trait]
pub trait ProviderInfo {
/// Get the current connection status
async fn status(&self) -> Result<ProviderStatus, LlmError>;
/// Get account information (if authenticated)
async fn account_info(&self) -> Result<Option<AccountInfo>, LlmError>;
/// Get usage statistics (if available)
async fn usage_stats(&self) -> Result<Option<UsageStats>, LlmError>;
/// List available models
async fn list_models(&self) -> Result<Vec<ModelInfo>, LlmError>;
/// Get information about a specific model, if it is available
async fn model_info(&self, model_id: &str) -> Result<Option<ModelInfo>, LlmError> {
let models = self.list_models().await?;
Ok(models.into_iter().find(|m| m.id == model_id))
}
}
// ============================================================================
// Provider Factory
// ============================================================================
/// Supported LLM providers
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "lowercase")]
pub enum ProviderType {
Ollama,
Anthropic,
OpenAI,
}
impl ProviderType {
pub fn from_str(s: &str) -> Option<Self> {
match s.to_lowercase().as_str() {
"ollama" => Some(Self::Ollama),
"anthropic" | "claude" => Some(Self::Anthropic),
"openai" | "gpt" => Some(Self::OpenAI),
_ => None,
}
}
pub fn as_str(&self) -> &'static str {
match self {
Self::Ollama => "ollama",
Self::Anthropic => "anthropic",
Self::OpenAI => "openai",
}
}
/// Default model for this provider
pub fn default_model(&self) -> &'static str {
match self {
Self::Ollama => "qwen3:8b",
Self::Anthropic => "claude-sonnet-4-20250514",
Self::OpenAI => "gpt-4o",
}
}
}
impl std::fmt::Display for ProviderType {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(f, "{}", self.as_str())
}
}
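// Illustrative sketch (not part of the original file): resolving a provider
// from user input (accepting aliases like "claude") and falling back to its
// default model.
#[allow(dead_code)]
fn provider_type_sketch() {
    let provider = ProviderType::from_str("claude").unwrap_or(ProviderType::Ollama);
    assert_eq!(provider, ProviderType::Anthropic);
    assert_eq!(provider.default_model(), "claude-sonnet-4-20250514");
    assert_eq!(provider.to_string(), "anthropic");
}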


@@ -0,0 +1,386 @@
//! Error recovery and retry logic for LLM operations
//!
//! This module provides configurable retry strategies with exponential backoff
//! for handling transient failures when communicating with LLM providers.
use crate::LlmError;
use rand::Rng;
use std::time::Duration;
/// Configuration for retry behavior
#[derive(Debug, Clone)]
pub struct RetryConfig {
/// Maximum number of retry attempts
pub max_retries: u32,
/// Initial delay before first retry (in milliseconds)
pub initial_delay_ms: u64,
/// Maximum delay between retries (in milliseconds)
pub max_delay_ms: u64,
/// Multiplier for exponential backoff
pub backoff_multiplier: f32,
}
impl Default for RetryConfig {
fn default() -> Self {
Self {
max_retries: 3,
initial_delay_ms: 1000,
max_delay_ms: 30000,
backoff_multiplier: 2.0,
}
}
}
impl RetryConfig {
/// Create a new retry configuration with custom values
pub fn new(
max_retries: u32,
initial_delay_ms: u64,
max_delay_ms: u64,
backoff_multiplier: f32,
) -> Self {
Self {
max_retries,
initial_delay_ms,
max_delay_ms,
backoff_multiplier,
}
}
/// Create a configuration with no retries
pub fn no_retry() -> Self {
Self {
max_retries: 0,
initial_delay_ms: 0,
max_delay_ms: 0,
backoff_multiplier: 1.0,
}
}
/// Create a configuration with aggressive retries for rate-limited scenarios
pub fn aggressive() -> Self {
Self {
max_retries: 5,
initial_delay_ms: 2000,
max_delay_ms: 60000,
backoff_multiplier: 2.5,
}
}
}
/// Determines whether an error is retryable
///
/// # Arguments
/// * `error` - The error to check
///
/// # Returns
/// `true` if the error is transient and the operation should be retried,
/// `false` if the error is permanent and retrying won't help
pub fn is_retryable_error(error: &LlmError) -> bool {
match error {
// Always retry rate limits
LlmError::RateLimit { .. } => true,
// Always retry timeouts
LlmError::Timeout(_) => true,
// Retry HTTP errors that are server-side (5xx)
LlmError::Http(msg) => {
// Check if the error message contains a 5xx status code
msg.contains("500")
|| msg.contains("502")
|| msg.contains("503")
|| msg.contains("504")
|| msg.contains("Internal Server Error")
|| msg.contains("Bad Gateway")
|| msg.contains("Service Unavailable")
|| msg.contains("Gateway Timeout")
}
// Don't retry authentication errors - they need user intervention
LlmError::Auth(_) => false,
// Don't retry JSON parsing errors - the data is malformed
LlmError::Json(_) => false,
// Don't retry API errors - these are typically client-side issues
LlmError::Api { .. } => false,
// Provider errors might be transient, but we conservatively don't retry
LlmError::Provider(_) => false,
// Stream errors are typically not retryable
LlmError::Stream(_) => false,
}
}
/// Strategy for retrying failed operations with exponential backoff
#[derive(Debug, Clone)]
pub struct RetryStrategy {
config: RetryConfig,
}
impl RetryStrategy {
/// Create a new retry strategy with the given configuration
pub fn new(config: RetryConfig) -> Self {
Self { config }
}
/// Create a retry strategy with default configuration
pub fn default_config() -> Self {
Self::new(RetryConfig::default())
}
/// Execute an async operation with retries
///
/// # Arguments
/// * `operation` - A function that returns a Future producing a Result
///
/// # Returns
/// The result of the operation, or the last error if all retries fail
///
/// # Example
/// ```ignore
/// let strategy = RetryStrategy::default_config();
/// let result = strategy.execute(|| async {
/// // Your LLM API call here
/// llm_client.chat(&messages, &options, None).await
/// }).await?;
/// ```
pub async fn execute<F, T, Fut>(&self, operation: F) -> Result<T, LlmError>
where
F: Fn() -> Fut,
Fut: std::future::Future<Output = Result<T, LlmError>>,
{
let mut attempt = 0;
loop {
// Try the operation
match operation().await {
Ok(result) => return Ok(result),
Err(err) => {
// Check if we should retry
if !is_retryable_error(&err) {
return Err(err);
}
attempt += 1;
// Check if we've exhausted retries
if attempt > self.config.max_retries {
return Err(err);
}
// Calculate delay with exponential backoff and jitter
let delay = self.delay_for_attempt(attempt);
// Log retry attempt (in a real implementation, you might use tracing)
eprintln!(
"Retry attempt {}/{} after {:?}",
attempt, self.config.max_retries, delay
);
// Sleep before next attempt
tokio::time::sleep(delay).await;
}
}
}
}
/// Calculate the delay for a given attempt number with jitter
///
/// Uses exponential backoff: delay = initial_delay * (backoff_multiplier ^ (attempt - 1))
/// Adds random jitter of ±10% to prevent thundering herd problems
///
/// # Arguments
/// * `attempt` - The attempt number (1-indexed)
///
/// # Returns
/// The delay duration to wait before the next retry
fn delay_for_attempt(&self, attempt: u32) -> Duration {
// Calculate base delay with exponential backoff
let base_delay_ms = self.config.initial_delay_ms as f64
* self.config.backoff_multiplier.powi((attempt - 1) as i32) as f64;
// Cap at max_delay_ms
let capped_delay_ms = base_delay_ms.min(self.config.max_delay_ms as f64);
// Add jitter: ±10%
let mut rng = rand::thread_rng();
let jitter_factor = rng.gen_range(0.9..=1.1);
let final_delay_ms = capped_delay_ms * jitter_factor;
Duration::from_millis(final_delay_ms as u64)
}
}
#[cfg(test)]
mod tests {
use super::*;
use std::sync::atomic::{AtomicU32, Ordering};
use std::sync::Arc;
#[test]
fn test_default_retry_config() {
let config = RetryConfig::default();
assert_eq!(config.max_retries, 3);
assert_eq!(config.initial_delay_ms, 1000);
assert_eq!(config.max_delay_ms, 30000);
assert_eq!(config.backoff_multiplier, 2.0);
}
#[test]
fn test_no_retry_config() {
let config = RetryConfig::no_retry();
assert_eq!(config.max_retries, 0);
}
#[test]
fn test_is_retryable_error() {
// Retryable errors
assert!(is_retryable_error(&LlmError::RateLimit {
retry_after_secs: Some(60)
}));
assert!(is_retryable_error(&LlmError::Timeout(
"Request timed out".to_string()
)));
assert!(is_retryable_error(&LlmError::Http(
"500 Internal Server Error".to_string()
)));
assert!(is_retryable_error(&LlmError::Http(
"503 Service Unavailable".to_string()
)));
// Non-retryable errors
assert!(!is_retryable_error(&LlmError::Auth(
"Invalid API key".to_string()
)));
assert!(!is_retryable_error(&LlmError::Json(
"Invalid JSON".to_string()
)));
assert!(!is_retryable_error(&LlmError::Api {
message: "Invalid request".to_string(),
code: Some("400".to_string())
}));
assert!(!is_retryable_error(&LlmError::Http(
"400 Bad Request".to_string()
)));
}
#[test]
fn test_delay_calculation() {
let config = RetryConfig::default();
let strategy = RetryStrategy::new(config);
// Test that delays increase exponentially
let delay1 = strategy.delay_for_attempt(1);
let delay2 = strategy.delay_for_attempt(2);
let delay3 = strategy.delay_for_attempt(3);
// Base delays should be around 1000ms, 2000ms, 4000ms (with jitter)
assert!(delay1.as_millis() >= 900 && delay1.as_millis() <= 1100);
assert!(delay2.as_millis() >= 1800 && delay2.as_millis() <= 2200);
assert!(delay3.as_millis() >= 3600 && delay3.as_millis() <= 4400);
}
#[test]
fn test_delay_max_cap() {
let config = RetryConfig {
max_retries: 10,
initial_delay_ms: 1000,
max_delay_ms: 5000,
backoff_multiplier: 2.0,
};
let strategy = RetryStrategy::new(config);
// Even with high attempt numbers, delay should be capped
let delay = strategy.delay_for_attempt(10);
assert!(delay.as_millis() <= 5500); // max + jitter
}
#[tokio::test]
async fn test_retry_success_on_first_attempt() {
let strategy = RetryStrategy::default_config();
let call_count = Arc::new(AtomicU32::new(0));
let count_clone = call_count.clone();
let result = strategy
.execute(|| {
let count = count_clone.clone();
async move {
count.fetch_add(1, Ordering::SeqCst);
Ok::<_, LlmError>(42)
}
})
.await;
assert_eq!(result.unwrap(), 42);
assert_eq!(call_count.load(Ordering::SeqCst), 1);
}
#[tokio::test]
async fn test_retry_success_after_retries() {
let config = RetryConfig::new(3, 10, 100, 2.0); // Fast retries for testing
let strategy = RetryStrategy::new(config);
let call_count = Arc::new(AtomicU32::new(0));
let count_clone = call_count.clone();
let result = strategy
.execute(|| {
let count = count_clone.clone();
async move {
let current = count.fetch_add(1, Ordering::SeqCst) + 1;
if current < 3 {
Err(LlmError::Timeout("Timeout".to_string()))
} else {
Ok(42)
}
}
})
.await;
assert_eq!(result.unwrap(), 42);
assert_eq!(call_count.load(Ordering::SeqCst), 3);
}
#[tokio::test]
async fn test_retry_exhausted() {
let config = RetryConfig::new(2, 10, 100, 2.0); // Fast retries for testing
let strategy = RetryStrategy::new(config);
let call_count = Arc::new(AtomicU32::new(0));
let count_clone = call_count.clone();
let result = strategy
.execute(|| {
let count = count_clone.clone();
async move {
count.fetch_add(1, Ordering::SeqCst);
Err::<(), _>(LlmError::Timeout("Always fails".to_string()))
}
})
.await;
assert!(result.is_err());
assert_eq!(call_count.load(Ordering::SeqCst), 3); // Initial attempt + 2 retries
}
#[tokio::test]
async fn test_non_retryable_error() {
let strategy = RetryStrategy::default_config();
let call_count = Arc::new(AtomicU32::new(0));
let count_clone = call_count.clone();
let result = strategy
.execute(|| {
let count = count_clone.clone();
async move {
count.fetch_add(1, Ordering::SeqCst);
Err::<(), _>(LlmError::Auth("Invalid API key".to_string()))
}
})
.await;
assert!(result.is_err());
assert_eq!(call_count.load(Ordering::SeqCst), 1); // Should not retry
}
}


@@ -0,0 +1,607 @@
//! Token counting utilities for LLM context management
//!
//! This module provides token counting abstractions and implementations for
//! managing LLM context windows. Token counters estimate token usage without
//! requiring external tokenization libraries, using heuristic-based approaches.
use crate::ChatMessage;
// ============================================================================
// TokenCounter Trait
// ============================================================================
/// Trait for counting tokens in text and chat messages
///
/// Implementations provide model-specific token counting logic to help
/// manage context windows and estimate API costs.
pub trait TokenCounter: Send + Sync {
/// Count tokens in a string
///
/// # Arguments
/// * `text` - The text to count tokens for
///
/// # Returns
/// Estimated number of tokens
fn count(&self, text: &str) -> usize;
/// Count tokens in chat messages
///
/// This accounts for both the message content and the overhead
/// from the chat message structure (roles, delimiters, etc.).
///
/// # Arguments
/// * `messages` - The messages to count tokens for
///
/// # Returns
/// Estimated total tokens including message structure overhead
fn count_messages(&self, messages: &[ChatMessage]) -> usize;
/// Get the model's max context window size
///
/// # Returns
/// Maximum number of tokens the model can handle
fn max_context(&self) -> usize;
}
// ============================================================================
// SimpleTokenCounter
// ============================================================================
/// A basic token counter using simple heuristics
///
/// This counter uses the rule of thumb that English text averages about
/// 4 characters per token. It adds overhead for message structure.
///
/// # Example
/// ```
/// use llm_core::tokens::{TokenCounter, SimpleTokenCounter};
/// use llm_core::ChatMessage;
///
/// let counter = SimpleTokenCounter::new(8192);
/// let text = "Hello, world!";
/// let tokens = counter.count(text);
/// assert!(tokens > 0);
///
/// let messages = vec![
/// ChatMessage::user("What is the weather?"),
/// ChatMessage::assistant("I don't have access to weather data."),
/// ];
/// let total = counter.count_messages(&messages);
/// assert!(total > 0);
/// ```
#[derive(Debug, Clone)]
pub struct SimpleTokenCounter {
max_context: usize,
}
impl SimpleTokenCounter {
/// Create a new simple token counter
///
/// # Arguments
/// * `max_context` - Maximum context window size for the model
pub fn new(max_context: usize) -> Self {
Self { max_context }
}
/// Create a token counter with a default 8192 token context
pub fn default_8k() -> Self {
Self::new(8192)
}
/// Create a token counter with a 32k token context
pub fn with_32k() -> Self {
Self::new(32768)
}
/// Create a token counter with a 128k token context
pub fn with_128k() -> Self {
Self::new(131072)
}
}
impl TokenCounter for SimpleTokenCounter {
fn count(&self, text: &str) -> usize {
// Estimate: approximately 4 characters per token for English
// Add 3 before dividing to round up
(text.len() + 3) / 4
}
fn count_messages(&self, messages: &[ChatMessage]) -> usize {
let mut total = 0;
// Base overhead for message formatting (estimated)
// Each message has role, delimiters, etc.
const MESSAGE_OVERHEAD: usize = 4;
for msg in messages {
// Count role
total += MESSAGE_OVERHEAD;
// Count content
if let Some(content) = &msg.content {
total += self.count(content);
}
// Count tool calls (more expensive due to JSON structure)
if let Some(tool_calls) = &msg.tool_calls {
for tc in tool_calls {
// ID overhead
total += self.count(&tc.id);
// Function name
total += self.count(&tc.function.name);
// Arguments (JSON serialized, add 20% overhead for JSON structure)
let args_str = tc.function.arguments.to_string();
total += (self.count(&args_str) * 12) / 10;
}
}
// Count tool call id for tool result messages
if let Some(tool_call_id) = &msg.tool_call_id {
total += self.count(tool_call_id);
}
// Count tool name for tool result messages
if let Some(name) = &msg.name {
total += self.count(name);
}
}
total
}
fn max_context(&self) -> usize {
self.max_context
}
}
// ============================================================================
// ClaudeTokenCounter
// ============================================================================
/// Token counter optimized for Anthropic Claude models
///
/// Claude models have specific tokenization characteristics and overhead.
/// This counter adjusts the estimates accordingly.
///
/// # Example
/// ```
/// use llm_core::tokens::{TokenCounter, ClaudeTokenCounter};
/// use llm_core::ChatMessage;
///
/// let counter = ClaudeTokenCounter::new();
/// let messages = vec![
/// ChatMessage::system("You are a helpful assistant."),
/// ChatMessage::user("Hello!"),
/// ];
/// let total = counter.count_messages(&messages);
/// ```
#[derive(Debug, Clone)]
pub struct ClaudeTokenCounter {
max_context: usize,
}
impl ClaudeTokenCounter {
/// Create a new Claude token counter with default 200k context
///
/// This is suitable for Claude 3.5 Sonnet, Claude 4 Sonnet, and Claude 4 Opus.
pub fn new() -> Self {
Self {
max_context: 200_000,
}
}
/// Create a Claude counter with a custom context window
///
/// # Arguments
/// * `max_context` - Maximum context window size
pub fn with_context(max_context: usize) -> Self {
Self { max_context }
}
/// Create a counter for Claude 3 Haiku (200k context)
pub fn haiku() -> Self {
Self::new()
}
/// Create a counter for Claude 3.5 Sonnet (200k context)
pub fn sonnet() -> Self {
Self::new()
}
/// Create a counter for Claude 4 Opus (200k context)
pub fn opus() -> Self {
Self::new()
}
}
impl Default for ClaudeTokenCounter {
fn default() -> Self {
Self::new()
}
}
impl TokenCounter for ClaudeTokenCounter {
fn count(&self, text: &str) -> usize {
// Claude's tokenization is similar to the 4 chars/token heuristic
// but tends to be slightly more efficient with structured content
(text.len() + 3) / 4
}
fn count_messages(&self, messages: &[ChatMessage]) -> usize {
let mut total = 0;
// Claude has specific message formatting overhead
const MESSAGE_OVERHEAD: usize = 5;
const SYSTEM_MESSAGE_OVERHEAD: usize = 3;
for msg in messages {
// Different overhead for system vs other messages
let overhead = if matches!(msg.role, crate::Role::System) {
SYSTEM_MESSAGE_OVERHEAD
} else {
MESSAGE_OVERHEAD
};
total += overhead;
// Count content
if let Some(content) = &msg.content {
total += self.count(content);
}
// Count tool calls
if let Some(tool_calls) = &msg.tool_calls {
// Claude's tool call format has additional overhead
const TOOL_CALL_OVERHEAD: usize = 10;
for tc in tool_calls {
total += TOOL_CALL_OVERHEAD;
total += self.count(&tc.id);
total += self.count(&tc.function.name);
// Arguments with JSON structure overhead
let args_str = tc.function.arguments.to_string();
total += (self.count(&args_str) * 12) / 10;
}
}
// Tool result overhead
if msg.tool_call_id.is_some() {
const TOOL_RESULT_OVERHEAD: usize = 8;
total += TOOL_RESULT_OVERHEAD;
if let Some(tool_call_id) = &msg.tool_call_id {
total += self.count(tool_call_id);
}
if let Some(name) = &msg.name {
total += self.count(name);
}
}
}
total
}
fn max_context(&self) -> usize {
self.max_context
}
}
// ============================================================================
// ContextWindow
// ============================================================================
/// Manages context window tracking for a conversation
///
/// Helps monitor token usage and determine when context limits are approaching.
///
/// # Example
/// ```
/// use llm_core::tokens::{ContextWindow, TokenCounter, SimpleTokenCounter};
/// use llm_core::ChatMessage;
///
/// let counter = SimpleTokenCounter::new(8192);
/// let mut window = ContextWindow::new(counter.max_context());
///
/// let messages = vec![
/// ChatMessage::user("Hello!"),
/// ChatMessage::assistant("Hi there!"),
/// ];
///
/// let tokens = counter.count_messages(&messages);
/// window.add_tokens(tokens);
///
/// println!("Used: {} tokens", window.used());
/// println!("Remaining: {} tokens", window.remaining());
/// println!("Usage: {:.1}%", window.usage_percent() * 100.0);
///
/// if window.is_near_limit(0.8) {
/// println!("Warning: Context is 80% full!");
/// }
/// ```
#[derive(Debug, Clone)]
pub struct ContextWindow {
/// Number of tokens currently used
used: usize,
/// Maximum number of tokens allowed
max: usize,
}
impl ContextWindow {
/// Create a new context window tracker
///
/// # Arguments
/// * `max` - Maximum context window size in tokens
pub fn new(max: usize) -> Self {
Self { used: 0, max }
}
/// Create a context window with initial usage
///
/// # Arguments
/// * `max` - Maximum context window size
/// * `used` - Initial number of tokens used
pub fn with_usage(max: usize, used: usize) -> Self {
Self { used, max }
}
/// Get the number of tokens currently used
pub fn used(&self) -> usize {
self.used
}
/// Get the maximum number of tokens
pub fn max(&self) -> usize {
self.max
}
/// Get the number of remaining tokens
pub fn remaining(&self) -> usize {
self.max.saturating_sub(self.used)
}
/// Get the usage as a percentage (0.0 to 1.0)
///
/// Returns the fraction of the context window that is currently used.
pub fn usage_percent(&self) -> f32 {
if self.max == 0 {
return 0.0;
}
self.used as f32 / self.max as f32
}
/// Check if usage is near the limit
///
/// # Arguments
/// * `threshold` - Threshold as a fraction (0.0 to 1.0). For example,
/// 0.8 means "is usage > 80%?"
///
/// # Returns
/// `true` if the current usage exceeds the threshold percentage
pub fn is_near_limit(&self, threshold: f32) -> bool {
self.usage_percent() > threshold
}
/// Add tokens to the usage count
///
/// # Arguments
/// * `tokens` - Number of tokens to add
pub fn add_tokens(&mut self, tokens: usize) {
self.used = self.used.saturating_add(tokens);
}
/// Set the current usage
///
/// # Arguments
/// * `used` - Number of tokens currently used
pub fn set_used(&mut self, used: usize) {
self.used = used;
}
/// Reset the usage counter to zero
pub fn reset(&mut self) {
self.used = 0;
}
/// Check if there's enough room for additional tokens
///
/// # Arguments
/// * `tokens` - Number of tokens needed
///
/// # Returns
/// `true` if adding these tokens would stay within the limit
pub fn has_room_for(&self, tokens: usize) -> bool {
self.used.saturating_add(tokens) <= self.max
}
/// Get a visual progress bar representation
///
/// # Arguments
/// * `width` - Width of the progress bar in characters
///
/// # Returns
/// A string with a simple text-based progress bar
pub fn progress_bar(&self, width: usize) -> String {
if width == 0 {
return String::new();
}
let percent = self.usage_percent();
let filled = ((percent * width as f32) as usize).min(width);
let empty = width - filled;
format!(
"[{}{}] {:.1}%",
"=".repeat(filled),
" ".repeat(empty),
percent * 100.0
)
}
}
// ============================================================================
// Tests
// ============================================================================
#[cfg(test)]
mod tests {
use super::*;
use crate::{ChatMessage, FunctionCall, ToolCall};
use serde_json::json;
#[test]
fn test_simple_counter_basic() {
let counter = SimpleTokenCounter::new(8192);
// Empty string
assert_eq!(counter.count(""), 0);
// Short string (~4 chars/token)
let text = "Hello, world!"; // 13 chars -> ~4 tokens
let count = counter.count(text);
assert!(count >= 3 && count <= 5);
// Longer text
let text = "The quick brown fox jumps over the lazy dog"; // 44 chars -> ~11 tokens
let count = counter.count(text);
assert!(count >= 10 && count <= 13);
}
#[test]
fn test_simple_counter_messages() {
let counter = SimpleTokenCounter::new(8192);
let messages = vec![
ChatMessage::user("Hello!"),
ChatMessage::assistant("Hi there! How can I help you today?"),
];
let total = counter.count_messages(&messages);
// Should be more than just the text due to overhead
let text_only = counter.count("Hello!") + counter.count("Hi there! How can I help you today?");
assert!(total > text_only);
}
#[test]
fn test_simple_counter_with_tool_calls() {
let counter = SimpleTokenCounter::new(8192);
let tool_call = ToolCall {
id: "call_123".to_string(),
call_type: "function".to_string(),
function: FunctionCall {
name: "read_file".to_string(),
arguments: json!({"path": "/etc/hosts"}),
},
};
let messages = vec![ChatMessage::assistant_tool_calls(vec![tool_call])];
let total = counter.count_messages(&messages);
assert!(total > 0);
}
#[test]
fn test_claude_counter() {
let counter = ClaudeTokenCounter::new();
assert_eq!(counter.max_context(), 200_000);
let text = "Hello, Claude!";
let count = counter.count(text);
assert!(count > 0);
}
#[test]
fn test_claude_counter_system_message() {
let counter = ClaudeTokenCounter::new();
let messages = vec![
ChatMessage::system("You are a helpful assistant."),
ChatMessage::user("Hello!"),
];
let total = counter.count_messages(&messages);
assert!(total > 0);
}
#[test]
fn test_context_window() {
let mut window = ContextWindow::new(1000);
assert_eq!(window.used(), 0);
assert_eq!(window.max(), 1000);
assert_eq!(window.remaining(), 1000);
assert_eq!(window.usage_percent(), 0.0);
window.add_tokens(200);
assert_eq!(window.used(), 200);
assert_eq!(window.remaining(), 800);
assert_eq!(window.usage_percent(), 0.2);
window.add_tokens(600);
assert_eq!(window.used(), 800);
assert!(window.is_near_limit(0.7));
assert!(!window.is_near_limit(0.9));
assert!(window.has_room_for(200));
assert!(!window.has_room_for(300));
window.reset();
assert_eq!(window.used(), 0);
}
#[test]
fn test_context_window_progress_bar() {
let mut window = ContextWindow::new(100);
window.add_tokens(50);
let bar = window.progress_bar(10);
assert!(bar.contains("====="));
assert!(bar.contains("50.0%"));
window.add_tokens(40);
let bar = window.progress_bar(10);
assert!(bar.contains("========="));
assert!(bar.contains("90.0%"));
}
#[test]
fn test_context_window_saturation() {
let mut window = ContextWindow::new(100);
// Usage can exceed max without panicking; remaining() saturates at zero
window.add_tokens(150);
assert_eq!(window.used(), 150);
assert_eq!(window.remaining(), 0);
}
#[test]
fn test_simple_counter_constructors() {
let counter1 = SimpleTokenCounter::default_8k();
assert_eq!(counter1.max_context(), 8192);
let counter2 = SimpleTokenCounter::with_32k();
assert_eq!(counter2.max_context(), 32768);
let counter3 = SimpleTokenCounter::with_128k();
assert_eq!(counter3.max_context(), 131072);
}
#[test]
fn test_claude_counter_variants() {
let haiku = ClaudeTokenCounter::haiku();
assert_eq!(haiku.max_context(), 200_000);
let sonnet = ClaudeTokenCounter::sonnet();
assert_eq!(sonnet.max_context(), 200_000);
let opus = ClaudeTokenCounter::opus();
assert_eq!(opus.max_context(), 200_000);
let custom = ClaudeTokenCounter::with_context(100_000);
assert_eq!(custom.max_context(), 100_000);
}
}

View File

@@ -6,11 +6,13 @@ license.workspace = true
rust-version.workspace = true
[dependencies]
llm-core = { path = "../core" }
reqwest = { version = "0.12", features = ["json", "stream"] }
tokio = { version = "1.39", features = ["rt-multi-thread"] }
tokio = { version = "1.39", features = ["rt-multi-thread", "macros"] }
futures = "0.3"
serde = { version = "1", features = ["derive"] }
serde_json = "1"
thiserror = "1"
bytes = "1"
tokio-stream = "0.1.17"
async-trait = "0.1"

View File

@@ -1,14 +1,20 @@
use crate::types::{ChatMessage, ChatResponseChunk};
use futures::{Stream, TryStreamExt};
use crate::types::{ChatMessage, ChatResponseChunk, Tool};
use futures::{Stream, StreamExt, TryStreamExt};
use reqwest::Client;
use serde::Serialize;
use serde::{Deserialize, Serialize};
use thiserror::Error;
use async_trait::async_trait;
use llm_core::{
LlmProvider, ProviderInfo, LlmError, ChatOptions, ChunkStream,
ProviderStatus, AccountInfo, UsageStats, ModelInfo,
};
#[derive(Debug, Clone)]
pub struct OllamaClient {
http: Client,
base_url: String, // e.g. "http://localhost:11434"
api_key: Option<String>, // For Ollama Cloud authentication
current_model: String, // Default model for this client
}
#[derive(Debug, Clone, Default)]
@@ -27,12 +33,24 @@ pub enum OllamaError {
Protocol(String),
}
// Convert OllamaError to LlmError
impl From<OllamaError> for LlmError {
fn from(err: OllamaError) -> Self {
match err {
OllamaError::Http(e) => LlmError::Http(e.to_string()),
OllamaError::Json(e) => LlmError::Json(e.to_string()),
OllamaError::Protocol(msg) => LlmError::Provider(msg),
}
}
}
impl OllamaClient {
pub fn new(base_url: impl Into<String>) -> Self {
Self {
http: Client::new(),
base_url: base_url.into().trim_end_matches('/').to_string(),
api_key: None,
current_model: "qwen3:8b".to_string(),
}
}
@@ -41,24 +59,32 @@ impl OllamaClient {
self
}
pub fn with_model(mut self, model: impl Into<String>) -> Self {
self.current_model = model.into();
self
}
pub fn with_cloud() -> Self {
// Same API, different base
Self::new("https://ollama.com")
}
pub async fn chat_stream(
pub async fn chat_stream_raw(
&self,
messages: &[ChatMessage],
opts: &OllamaOptions,
tools: Option<&[Tool]>,
) -> Result<impl Stream<Item = Result<ChatResponseChunk, OllamaError>>, OllamaError> {
#[derive(Serialize)]
struct Body<'a> {
model: &'a str,
messages: &'a [ChatMessage],
stream: bool,
#[serde(skip_serializing_if = "Option::is_none")]
tools: Option<&'a [Tool]>,
}
let url = format!("{}/api/chat", self.base_url);
let body = Body {model: &opts.model, messages, stream: true};
let body = Body {model: &opts.model, messages, stream: true, tools};
let mut req = self.http.post(url).json(&body);
// Add Authorization header if API key is present
@@ -96,3 +122,208 @@ impl OllamaClient {
Ok(out)
}
}
// ============================================================================
// LlmProvider Trait Implementation
// ============================================================================
#[async_trait]
impl LlmProvider for OllamaClient {
fn name(&self) -> &str {
"ollama"
}
fn model(&self) -> &str {
&self.current_model
}
async fn chat_stream(
&self,
messages: &[llm_core::ChatMessage],
options: &ChatOptions,
tools: Option<&[llm_core::Tool]>,
) -> Result<ChunkStream, LlmError> {
// Convert llm_core messages to Ollama messages
let ollama_messages: Vec<ChatMessage> = messages.iter().map(|m| m.into()).collect();
// Convert llm_core tools to Ollama tools if present
let ollama_tools: Option<Vec<Tool>> = tools.map(|tools| {
tools.iter().map(|t| Tool {
tool_type: t.tool_type.clone(),
function: crate::types::ToolFunction {
name: t.function.name.clone(),
description: t.function.description.clone(),
parameters: crate::types::ToolParameters {
param_type: t.function.parameters.param_type.clone(),
properties: t.function.parameters.properties.clone(),
required: t.function.parameters.required.clone(),
},
},
}).collect()
});
let opts = OllamaOptions {
model: options.model.clone(),
stream: true,
};
// Make the request and build the body inline to avoid lifetime issues
#[derive(Serialize)]
struct Body<'a> {
model: &'a str,
messages: &'a [ChatMessage],
stream: bool,
#[serde(skip_serializing_if = "Option::is_none")]
tools: Option<&'a [Tool]>,
}
let url = format!("{}/api/chat", self.base_url);
let body = Body {
model: &opts.model,
messages: &ollama_messages,
stream: true,
tools: ollama_tools.as_deref(),
};
let mut req = self.http.post(url).json(&body);
// Add Authorization header if API key is present
if let Some(ref key) = self.api_key {
req = req.header("Authorization", format!("Bearer {}", key));
}
let resp = req.send().await
.map_err(|e| LlmError::Http(e.to_string()))?;
let bytes_stream = resp.bytes_stream();
// NDJSON parser: split by '\n', parse each as JSON and stream the results
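// Illustrative shape only (not part of the diff): one NDJSON line from `/api/chat`
// looks roughly like `{"message":{"role":"assistant","content":"Hi"},"done":false}`;
// each such line is parsed into one `StreamChunk` below.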
let converted_stream = bytes_stream
.map(|result| {
result.map_err(|e| LlmError::Http(e.to_string()))
})
.map_ok(|bytes| {
// Convert the chunk to a UTF-8 string and own it
let txt = String::from_utf8_lossy(&bytes).into_owned();
// Parse each non-empty line into a ChatResponseChunk
let results: Vec<Result<llm_core::StreamChunk, LlmError>> = txt
.lines()
.filter_map(|line| {
let trimmed = line.trim();
if trimmed.is_empty() {
None
} else {
Some(
serde_json::from_str::<ChatResponseChunk>(trimmed)
.map(|chunk| llm_core::StreamChunk::from(chunk))
.map_err(|e| LlmError::Json(e.to_string())),
)
}
})
.collect();
futures::stream::iter(results)
})
.try_flatten();
Ok(Box::pin(converted_stream))
}
}
// ============================================================================
// ProviderInfo Trait Implementation
// ============================================================================
#[derive(Debug, Clone, Deserialize)]
struct OllamaModelList {
models: Vec<OllamaModel>,
}
#[derive(Debug, Clone, Deserialize)]
struct OllamaModel {
name: String,
#[serde(default)]
modified_at: Option<String>,
#[serde(default)]
size: Option<u64>,
#[serde(default)]
digest: Option<String>,
#[serde(default)]
details: Option<OllamaModelDetails>,
}
#[derive(Debug, Clone, Deserialize)]
struct OllamaModelDetails {
#[serde(default)]
format: Option<String>,
#[serde(default)]
family: Option<String>,
#[serde(default)]
parameter_size: Option<String>,
}
#[async_trait]
impl ProviderInfo for OllamaClient {
async fn status(&self) -> Result<ProviderStatus, LlmError> {
// Try to ping the Ollama server
let url = format!("{}/api/tags", self.base_url);
let reachable = self.http.get(&url).send().await.is_ok();
Ok(ProviderStatus {
provider: "ollama".to_string(),
authenticated: self.api_key.is_some(),
account: None, // Ollama is local, no account info
model: self.current_model.clone(),
endpoint: self.base_url.clone(),
reachable,
message: if reachable {
Some("Connected to Ollama".to_string())
} else {
Some("Cannot reach Ollama server".to_string())
},
})
}
async fn account_info(&self) -> Result<Option<AccountInfo>, LlmError> {
// Ollama is a local service, no account info
Ok(None)
}
async fn usage_stats(&self) -> Result<Option<UsageStats>, LlmError> {
// Ollama doesn't track usage statistics
Ok(None)
}
async fn list_models(&self) -> Result<Vec<ModelInfo>, LlmError> {
let url = format!("{}/api/tags", self.base_url);
let mut req = self.http.get(&url);
// Add Authorization header if API key is present
if let Some(ref key) = self.api_key {
req = req.header("Authorization", format!("Bearer {}", key));
}
let resp = req.send().await
.map_err(|e| LlmError::Http(e.to_string()))?;
let model_list: OllamaModelList = resp.json().await
.map_err(|e| LlmError::Json(e.to_string()))?;
// Convert Ollama models to ModelInfo
let models = model_list.models.into_iter().map(|m| {
ModelInfo {
id: m.name.clone(),
display_name: Some(m.name.clone()),
description: m.details.as_ref()
.and_then(|d| d.family.as_ref())
.map(|f| format!("{} model", f)),
context_window: None, // Ollama doesn't provide this in list
max_output_tokens: None,
supports_tools: true, // Most Ollama models support tools
supports_vision: false, // Would need to check model capabilities
input_price_per_mtok: None, // Local models are free
output_price_per_mtok: None,
}
}).collect();
Ok(models)
}
}

View File

@@ -1,5 +1,13 @@
pub mod client;
pub mod types;
pub use client::{OllamaClient, OllamaOptions};
pub use types::{ChatMessage, ChatResponseChunk};
pub use client::{OllamaClient, OllamaOptions, OllamaError};
pub use types::{ChatMessage, ChatResponseChunk, Tool, ToolCall, ToolFunction, ToolParameters, FunctionCall};
// Re-export llm-core traits and types for convenience
pub use llm_core::{
LlmProvider, ProviderInfo, LlmError,
ChatOptions, StreamChunk, ToolCallDelta, Usage,
ProviderStatus, AccountInfo, UsageStats, ModelInfo,
Role,
};

View File

@@ -1,9 +1,51 @@
use serde::{Deserialize, Serialize};
use serde_json::Value;
use llm_core::{StreamChunk, ToolCallDelta};
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ChatMessage {
pub role: String, // "user", | "assistant" | "system"
pub content: String,
pub role: String, // "user" | "assistant" | "system" | "tool"
#[serde(skip_serializing_if = "Option::is_none")]
pub content: Option<String>,
#[serde(skip_serializing_if = "Option::is_none")]
pub tool_calls: Option<Vec<ToolCall>>,
}
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct ToolCall {
#[serde(skip_serializing_if = "Option::is_none")]
pub id: Option<String>,
#[serde(rename = "type", skip_serializing_if = "Option::is_none")]
pub call_type: Option<String>, // "function"
pub function: FunctionCall,
}
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct FunctionCall {
pub name: String,
pub arguments: Value, // JSON object with arguments
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Tool {
#[serde(rename = "type")]
pub tool_type: String, // "function"
pub function: ToolFunction,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ToolFunction {
pub name: String,
pub description: String,
pub parameters: ToolParameters,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ToolParameters {
#[serde(rename = "type")]
pub param_type: String, // "object"
pub properties: Value,
pub required: Vec<String>,
}
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
@@ -19,4 +61,70 @@ pub struct ChatResponseChunk {
pub struct ChunkMessage {
pub role: Option<String>,
pub content: Option<String>,
#[serde(skip_serializing_if = "Option::is_none")]
pub tool_calls: Option<Vec<ToolCall>>,
}
// ============================================================================
// Conversions to/from llm-core types
// ============================================================================
/// Convert from llm_core::ChatMessage to Ollama's ChatMessage
impl From<&llm_core::ChatMessage> for ChatMessage {
fn from(msg: &llm_core::ChatMessage) -> Self {
let role = msg.role.as_str().to_string();
// Convert tool_calls if present
let tool_calls = msg.tool_calls.as_ref().map(|calls| {
calls.iter().map(|tc| ToolCall {
id: Some(tc.id.clone()),
call_type: Some(tc.call_type.clone()),
function: FunctionCall {
name: tc.function.name.clone(),
arguments: tc.function.arguments.clone(),
},
}).collect()
});
ChatMessage {
role,
content: msg.content.clone(),
tool_calls,
}
}
}
/// Convert from Ollama's ChatResponseChunk to llm_core::StreamChunk
impl From<ChatResponseChunk> for StreamChunk {
fn from(chunk: ChatResponseChunk) -> Self {
let done = chunk.done.unwrap_or(false);
let content = chunk.message.as_ref().and_then(|m| m.content.clone());
// Convert tool calls to deltas
let tool_calls = chunk.message.as_ref().and_then(|m| {
m.tool_calls.as_ref().map(|calls| {
calls.iter().enumerate().map(|(index, tc)| {
// Serialize arguments back to JSON string for delta
let arguments_delta = serde_json::to_string(&tc.function.arguments).ok();
ToolCallDelta {
index,
id: tc.id.clone(),
function_name: Some(tc.function.name.clone()),
arguments_delta,
}
}).collect()
})
});
// Ollama doesn't provide per-chunk usage stats, only in final chunk
let usage = None;
StreamChunk {
content,
tool_calls,
done,
usage,
}
}
}

View File

@@ -0,0 +1,18 @@
[package]
name = "llm-openai"
version = "0.1.0"
edition.workspace = true
license.workspace = true
description = "OpenAI GPT API client for Owlen"
[dependencies]
llm-core = { path = "../core" }
async-trait = "0.1"
futures = "0.3"
reqwest = { version = "0.12", features = ["json", "stream"] }
serde = { version = "1.0", features = ["derive"] }
serde_json = "1.0"
tokio = { version = "1", features = ["sync", "time", "io-util"] }
tokio-stream = { version = "0.1", default-features = false, features = ["io-util"] }
tokio-util = { version = "0.7", features = ["codec", "io"] }
tracing = "0.1"

View File

@@ -0,0 +1,285 @@
//! OpenAI OAuth Authentication
//!
//! Implements device code flow for authenticating with OpenAI without API keys.
use llm_core::{AuthMethod, DeviceAuthResult, DeviceCodeResponse, LlmError, OAuthProvider};
use reqwest::Client;
use serde::{Deserialize, Serialize};
/// OAuth client for OpenAI device flow
pub struct OpenAIAuth {
http: Client,
client_id: String,
}
// OpenAI OAuth endpoints
const AUTH_BASE_URL: &str = "https://auth.openai.com";
const DEVICE_CODE_ENDPOINT: &str = "/oauth/device/code";
const TOKEN_ENDPOINT: &str = "/oauth/token";
// Default client ID for Owlen CLI
const DEFAULT_CLIENT_ID: &str = "owlen-cli";
impl OpenAIAuth {
/// Create a new OAuth client with the default CLI client ID
pub fn new() -> Self {
Self {
http: Client::new(),
client_id: DEFAULT_CLIENT_ID.to_string(),
}
}
/// Create with a custom client ID
pub fn with_client_id(client_id: impl Into<String>) -> Self {
Self {
http: Client::new(),
client_id: client_id.into(),
}
}
}
impl Default for OpenAIAuth {
fn default() -> Self {
Self::new()
}
}
#[derive(Debug, Serialize)]
struct DeviceCodeRequest<'a> {
client_id: &'a str,
scope: &'a str,
}
#[derive(Debug, Deserialize)]
struct DeviceCodeApiResponse {
device_code: String,
user_code: String,
verification_uri: String,
verification_uri_complete: Option<String>,
expires_in: u64,
interval: u64,
}
#[derive(Debug, Serialize)]
struct TokenRequest<'a> {
client_id: &'a str,
device_code: &'a str,
grant_type: &'a str,
}
#[derive(Debug, Deserialize)]
struct TokenApiResponse {
access_token: String,
#[allow(dead_code)]
token_type: String,
expires_in: Option<u64>,
refresh_token: Option<String>,
}
#[derive(Debug, Deserialize)]
struct TokenErrorResponse {
error: String,
error_description: Option<String>,
}
#[async_trait::async_trait]
impl OAuthProvider for OpenAIAuth {
async fn start_device_auth(&self) -> Result<DeviceCodeResponse, LlmError> {
let url = format!("{}{}", AUTH_BASE_URL, DEVICE_CODE_ENDPOINT);
let request = DeviceCodeRequest {
client_id: &self.client_id,
scope: "api.read api.write",
};
let response = self
.http
.post(&url)
.json(&request)
.send()
.await
.map_err(|e| LlmError::Http(e.to_string()))?;
if !response.status().is_success() {
let status = response.status();
let text = response
.text()
.await
.unwrap_or_else(|_| "Unknown error".to_string());
return Err(LlmError::Auth(format!(
"Device code request failed ({}): {}",
status, text
)));
}
let api_response: DeviceCodeApiResponse = response
.json()
.await
.map_err(|e| LlmError::Json(e.to_string()))?;
Ok(DeviceCodeResponse {
device_code: api_response.device_code,
user_code: api_response.user_code,
verification_uri: api_response.verification_uri,
verification_uri_complete: api_response.verification_uri_complete,
expires_in: api_response.expires_in,
interval: api_response.interval,
})
}
async fn poll_device_auth(&self, device_code: &str) -> Result<DeviceAuthResult, LlmError> {
let url = format!("{}{}", AUTH_BASE_URL, TOKEN_ENDPOINT);
let request = TokenRequest {
client_id: &self.client_id,
device_code,
grant_type: "urn:ietf:params:oauth:grant-type:device_code",
};
let response = self
.http
.post(&url)
.json(&request)
.send()
.await
.map_err(|e| LlmError::Http(e.to_string()))?;
if response.status().is_success() {
let token_response: TokenApiResponse = response
.json()
.await
.map_err(|e| LlmError::Json(e.to_string()))?;
return Ok(DeviceAuthResult::Success {
access_token: token_response.access_token,
refresh_token: token_response.refresh_token,
expires_in: token_response.expires_in,
});
}
// Parse error response
let error_response: TokenErrorResponse = response
.json()
.await
.map_err(|e| LlmError::Json(e.to_string()))?;
match error_response.error.as_str() {
"authorization_pending" => Ok(DeviceAuthResult::Pending),
"slow_down" => Ok(DeviceAuthResult::Pending),
"access_denied" => Ok(DeviceAuthResult::Denied),
"expired_token" => Ok(DeviceAuthResult::Expired),
_ => Err(LlmError::Auth(format!(
"Token request failed: {} - {}",
error_response.error,
error_response.error_description.unwrap_or_default()
))),
}
}
async fn refresh_token(&self, refresh_token: &str) -> Result<AuthMethod, LlmError> {
let url = format!("{}{}", AUTH_BASE_URL, TOKEN_ENDPOINT);
#[derive(Serialize)]
struct RefreshRequest<'a> {
client_id: &'a str,
refresh_token: &'a str,
grant_type: &'a str,
}
let request = RefreshRequest {
client_id: &self.client_id,
refresh_token,
grant_type: "refresh_token",
};
let response = self
.http
.post(&url)
.json(&request)
.send()
.await
.map_err(|e| LlmError::Http(e.to_string()))?;
if !response.status().is_success() {
let text = response
.text()
.await
.unwrap_or_else(|_| "Unknown error".to_string());
return Err(LlmError::Auth(format!("Token refresh failed: {}", text)));
}
let token_response: TokenApiResponse = response
.json()
.await
.map_err(|e| LlmError::Json(e.to_string()))?;
let expires_at = token_response.expires_in.map(|secs| {
std::time::SystemTime::now()
.duration_since(std::time::UNIX_EPOCH)
.map(|d| d.as_secs() + secs)
.unwrap_or(0)
});
Ok(AuthMethod::OAuth {
access_token: token_response.access_token,
refresh_token: token_response.refresh_token,
expires_at,
})
}
}
/// Helper to perform the full device auth flow with polling
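///
/// Illustrative usage (a sketch, assuming a Tokio runtime and network access):
/// ```no_run
/// # async fn demo() -> Result<(), llm_core::LlmError> {
/// use llm_openai::{perform_device_auth, OpenAIAuth};
///
/// let auth = OpenAIAuth::new();
/// let auth_method = perform_device_auth(&auth, |code| {
///     println!("Visit {} and enter code {}", code.verification_uri, code.user_code);
/// })
/// .await?;
/// # let _ = auth_method;
/// # Ok(()) }
/// ```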
pub async fn perform_device_auth<F>(
auth: &OpenAIAuth,
on_code: F,
) -> Result<AuthMethod, LlmError>
where
F: FnOnce(&DeviceCodeResponse),
{
// Start the device flow
let device_code = auth.start_device_auth().await?;
// Let caller display the code to user
on_code(&device_code);
// Poll for completion
let poll_interval = std::time::Duration::from_secs(device_code.interval);
let deadline =
std::time::Instant::now() + std::time::Duration::from_secs(device_code.expires_in);
loop {
if std::time::Instant::now() > deadline {
return Err(LlmError::Auth("Device code expired".to_string()));
}
tokio::time::sleep(poll_interval).await;
match auth.poll_device_auth(&device_code.device_code).await? {
DeviceAuthResult::Success {
access_token,
refresh_token,
expires_in,
} => {
let expires_at = expires_in.map(|secs| {
std::time::SystemTime::now()
.duration_since(std::time::UNIX_EPOCH)
.map(|d| d.as_secs() + secs)
.unwrap_or(0)
});
return Ok(AuthMethod::OAuth {
access_token,
refresh_token,
expires_at,
});
}
DeviceAuthResult::Pending => continue,
DeviceAuthResult::Denied => {
return Err(LlmError::Auth("Authorization denied by user".to_string()));
}
DeviceAuthResult::Expired => {
return Err(LlmError::Auth("Device code expired".to_string()));
}
}
}
}

View File

@@ -0,0 +1,561 @@
//! OpenAI GPT API Client
//!
//! Implements the Chat Completions API with streaming support.
use crate::types::*;
use async_trait::async_trait;
use futures::StreamExt;
use llm_core::{
AccountInfo, AuthMethod, ChatMessage, ChatOptions, ChatResponse, ChunkStream, FunctionCall,
LlmError, LlmProvider, ModelInfo, ProviderInfo, ProviderStatus, StreamChunk, Tool, ToolCall,
ToolCallDelta, Usage, UsageStats,
};
use reqwest::Client;
use tokio::io::AsyncBufReadExt;
use tokio_stream::wrappers::LinesStream;
use tokio_util::io::StreamReader;
const API_BASE_URL: &str = "https://api.openai.com/v1";
const CHAT_ENDPOINT: &str = "/chat/completions";
const MODELS_ENDPOINT: &str = "/models";
/// OpenAI GPT API client
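///
/// Illustrative construction (a sketch; the key value is a placeholder):
/// ```
/// use llm_openai::OpenAIClient;
///
/// let client = OpenAIClient::new("sk-...").with_model("gpt-4o-mini");
/// ```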
pub struct OpenAIClient {
http: Client,
auth: AuthMethod,
model: String,
}
impl OpenAIClient {
/// Create a new client with API key authentication
pub fn new(api_key: impl Into<String>) -> Self {
Self {
http: Client::new(),
auth: AuthMethod::api_key(api_key),
model: "gpt-4o".to_string(),
}
}
/// Create a new client with OAuth token
pub fn with_oauth(access_token: impl Into<String>) -> Self {
Self {
http: Client::new(),
auth: AuthMethod::oauth(access_token),
model: "gpt-4o".to_string(),
}
}
/// Create a new client with full AuthMethod
pub fn with_auth(auth: AuthMethod) -> Self {
Self {
http: Client::new(),
auth,
model: "gpt-4o".to_string(),
}
}
/// Set the model to use
pub fn with_model(mut self, model: impl Into<String>) -> Self {
self.model = model.into();
self
}
/// Get current auth method (for token refresh)
pub fn auth(&self) -> &AuthMethod {
&self.auth
}
/// Update the auth method (after refresh)
pub fn set_auth(&mut self, auth: AuthMethod) {
self.auth = auth;
}
/// Convert messages to OpenAI format
fn prepare_messages(messages: &[ChatMessage]) -> Vec<OpenAIMessage> {
messages.iter().map(OpenAIMessage::from).collect()
}
/// Convert tools to OpenAI format
fn prepare_tools(tools: Option<&[Tool]>) -> Option<Vec<OpenAITool>> {
tools.map(|t| t.iter().map(OpenAITool::from).collect())
}
}
#[async_trait]
impl LlmProvider for OpenAIClient {
fn name(&self) -> &str {
"openai"
}
fn model(&self) -> &str {
&self.model
}
async fn chat_stream(
&self,
messages: &[ChatMessage],
options: &ChatOptions,
tools: Option<&[Tool]>,
) -> Result<ChunkStream, LlmError> {
let url = format!("{}{}", API_BASE_URL, CHAT_ENDPOINT);
let model = if options.model.is_empty() {
&self.model
} else {
&options.model
};
let openai_messages = Self::prepare_messages(messages);
let openai_tools = Self::prepare_tools(tools);
let request = ChatCompletionRequest {
model,
messages: openai_messages,
temperature: options.temperature,
max_tokens: options.max_tokens,
top_p: options.top_p,
stop: options.stop.as_deref(),
tools: openai_tools,
tool_choice: None,
stream: true,
};
let bearer = self
.auth
.bearer_token()
.ok_or_else(|| LlmError::Auth("No authentication configured".to_string()))?;
let response = self
.http
.post(&url)
.header("Authorization", format!("Bearer {}", bearer))
.header("Content-Type", "application/json")
.json(&request)
.send()
.await
.map_err(|e| LlmError::Http(e.to_string()))?;
if !response.status().is_success() {
let status = response.status();
let text = response
.text()
.await
.unwrap_or_else(|_| "Unknown error".to_string());
if status == reqwest::StatusCode::TOO_MANY_REQUESTS {
return Err(LlmError::RateLimit {
retry_after_secs: None,
});
}
// Try to parse as error response
if let Ok(err_resp) = serde_json::from_str::<ErrorResponse>(&text) {
return Err(LlmError::Api {
message: err_resp.error.message,
code: err_resp.error.code,
});
}
return Err(LlmError::Api {
message: text,
code: Some(status.to_string()),
});
}
// Parse SSE stream
let byte_stream = response
.bytes_stream()
.map(|result| result.map_err(std::io::Error::other));
let reader = StreamReader::new(byte_stream);
let buf_reader = tokio::io::BufReader::new(reader);
let lines_stream = LinesStream::new(buf_reader.lines());
let chunk_stream = lines_stream.filter_map(|line_result| async move {
match line_result {
Ok(line) => parse_sse_line(&line),
Err(e) => Some(Err(LlmError::Stream(e.to_string()))),
}
});
Ok(Box::pin(chunk_stream))
}
async fn chat(
&self,
messages: &[ChatMessage],
options: &ChatOptions,
tools: Option<&[Tool]>,
) -> Result<ChatResponse, LlmError> {
let url = format!("{}{}", API_BASE_URL, CHAT_ENDPOINT);
let model = if options.model.is_empty() {
&self.model
} else {
&options.model
};
let openai_messages = Self::prepare_messages(messages);
let openai_tools = Self::prepare_tools(tools);
let request = ChatCompletionRequest {
model,
messages: openai_messages,
temperature: options.temperature,
max_tokens: options.max_tokens,
top_p: options.top_p,
stop: options.stop.as_deref(),
tools: openai_tools,
tool_choice: None,
stream: false,
};
let bearer = self
.auth
.bearer_token()
.ok_or_else(|| LlmError::Auth("No authentication configured".to_string()))?;
let response = self
.http
.post(&url)
.header("Authorization", format!("Bearer {}", bearer))
.json(&request)
.send()
.await
.map_err(|e| LlmError::Http(e.to_string()))?;
if !response.status().is_success() {
let status = response.status();
let text = response
.text()
.await
.unwrap_or_else(|_| "Unknown error".to_string());
if status == reqwest::StatusCode::TOO_MANY_REQUESTS {
return Err(LlmError::RateLimit {
retry_after_secs: None,
});
}
if let Ok(err_resp) = serde_json::from_str::<ErrorResponse>(&text) {
return Err(LlmError::Api {
message: err_resp.error.message,
code: err_resp.error.code,
});
}
return Err(LlmError::Api {
message: text,
code: Some(status.to_string()),
});
}
let api_response: ChatCompletionResponse = response
.json()
.await
.map_err(|e| LlmError::Json(e.to_string()))?;
// Extract the first choice
let choice = api_response
.choices
.first()
.ok_or_else(|| LlmError::Api {
message: "No choices in response".to_string(),
code: None,
})?;
let content = choice.message.content.clone();
let tool_calls = choice.message.tool_calls.as_ref().map(|calls| {
calls
.iter()
.map(|call| {
let arguments: serde_json::Value =
serde_json::from_str(&call.function.arguments).unwrap_or_default();
ToolCall {
id: call.id.clone(),
call_type: "function".to_string(),
function: FunctionCall {
name: call.function.name.clone(),
arguments,
},
}
})
.collect()
});
let usage = api_response.usage.map(|u| Usage {
prompt_tokens: u.prompt_tokens,
completion_tokens: u.completion_tokens,
total_tokens: u.total_tokens,
});
Ok(ChatResponse {
content,
tool_calls,
usage,
})
}
}
/// Parse a single SSE line into a StreamChunk
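///
/// Illustrative inputs (sketch): a line such as
/// `data: {"id":"x","object":"chat.completion.chunk","created":0,"model":"gpt-4o","choices":[{"index":0,"delta":{"content":"Hi"},"finish_reason":null}]}`
/// maps to `Some(Ok(chunk))` with `content == Some("Hi")`; `data: [DONE]` maps to a
/// terminal chunk with `done == true`; blank lines and `:` comments map to `None`.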
fn parse_sse_line(line: &str) -> Option<Result<StreamChunk, LlmError>> {
let line = line.trim();
// Skip empty lines and comments
if line.is_empty() || line.starts_with(':') {
return None;
}
// SSE format: "data: <json>"
if let Some(data) = line.strip_prefix("data: ") {
// OpenAI sends [DONE] to signal end
if data == "[DONE]" {
return Some(Ok(StreamChunk {
content: None,
tool_calls: None,
done: true,
usage: None,
}));
}
// Parse the JSON chunk
match serde_json::from_str::<ChatCompletionChunk>(data) {
Ok(chunk) => Some(convert_chunk_to_stream_chunk(chunk)),
Err(e) => {
tracing::warn!("Failed to parse SSE chunk: {}", e);
None
}
}
} else {
None
}
}
/// Convert OpenAI chunk to our common format
fn convert_chunk_to_stream_chunk(chunk: ChatCompletionChunk) -> Result<StreamChunk, LlmError> {
let choice = chunk.choices.first();
if let Some(choice) = choice {
let content = choice.delta.content.clone();
let tool_calls = choice.delta.tool_calls.as_ref().map(|deltas| {
deltas
.iter()
.map(|delta| ToolCallDelta {
index: delta.index,
id: delta.id.clone(),
function_name: delta.function.as_ref().and_then(|f| f.name.clone()),
arguments_delta: delta.function.as_ref().and_then(|f| f.arguments.clone()),
})
.collect()
});
let done = choice.finish_reason.is_some();
Ok(StreamChunk {
content,
tool_calls,
done,
usage: None,
})
} else {
// No choices, treat as done
Ok(StreamChunk {
content: None,
tool_calls: None,
done: true,
usage: None,
})
}
}
// ============================================================================
// ProviderInfo Implementation
// ============================================================================
/// Known GPT models with their specifications
fn get_gpt_models() -> Vec<ModelInfo> {
vec![
ModelInfo {
id: "gpt-4o".to_string(),
display_name: Some("GPT-4o".to_string()),
description: Some("Most advanced multimodal model with vision".to_string()),
context_window: Some(128_000),
max_output_tokens: Some(16_384),
supports_tools: true,
supports_vision: true,
input_price_per_mtok: Some(2.50),
output_price_per_mtok: Some(10.0),
},
ModelInfo {
id: "gpt-4o-mini".to_string(),
display_name: Some("GPT-4o mini".to_string()),
description: Some("Affordable and fast model for simple tasks".to_string()),
context_window: Some(128_000),
max_output_tokens: Some(16_384),
supports_tools: true,
supports_vision: true,
input_price_per_mtok: Some(0.15),
output_price_per_mtok: Some(0.60),
},
ModelInfo {
id: "gpt-4-turbo".to_string(),
display_name: Some("GPT-4 Turbo".to_string()),
description: Some("Previous generation high-performance model".to_string()),
context_window: Some(128_000),
max_output_tokens: Some(4_096),
supports_tools: true,
supports_vision: true,
input_price_per_mtok: Some(10.0),
output_price_per_mtok: Some(30.0),
},
ModelInfo {
id: "gpt-3.5-turbo".to_string(),
display_name: Some("GPT-3.5 Turbo".to_string()),
description: Some("Fast and affordable for simple tasks".to_string()),
context_window: Some(16_385),
max_output_tokens: Some(4_096),
supports_tools: true,
supports_vision: false,
input_price_per_mtok: Some(0.50),
output_price_per_mtok: Some(1.50),
},
ModelInfo {
id: "o1".to_string(),
display_name: Some("OpenAI o1".to_string()),
description: Some("Reasoning model optimized for complex problems".to_string()),
context_window: Some(200_000),
max_output_tokens: Some(100_000),
supports_tools: false,
supports_vision: true,
input_price_per_mtok: Some(15.0),
output_price_per_mtok: Some(60.0),
},
ModelInfo {
id: "o1-mini".to_string(),
display_name: Some("OpenAI o1-mini".to_string()),
description: Some("Faster reasoning model for STEM".to_string()),
context_window: Some(128_000),
max_output_tokens: Some(65_536),
supports_tools: false,
supports_vision: true,
input_price_per_mtok: Some(3.0),
output_price_per_mtok: Some(12.0),
},
]
}
#[async_trait]
impl ProviderInfo for OpenAIClient {
async fn status(&self) -> Result<ProviderStatus, LlmError> {
let authenticated = self.auth.bearer_token().is_some();
// Try to reach the API by listing models
let reachable = if authenticated {
let url = format!("{}{}", API_BASE_URL, MODELS_ENDPOINT);
let bearer = self.auth.bearer_token().unwrap();
match self
.http
.get(&url)
.header("Authorization", format!("Bearer {}", bearer))
.send()
.await
{
Ok(resp) => resp.status().is_success(),
Err(_) => false,
}
} else {
false
};
let message = if !authenticated {
Some("Not authenticated - set OPENAI_API_KEY or run 'owlen login openai'".to_string())
} else if !reachable {
Some("Cannot reach OpenAI API".to_string())
} else {
Some("Connected".to_string())
};
Ok(ProviderStatus {
provider: "openai".to_string(),
authenticated,
account: None, // OpenAI doesn't expose account info via API
model: self.model.clone(),
endpoint: API_BASE_URL.to_string(),
reachable,
message,
})
}
async fn account_info(&self) -> Result<Option<AccountInfo>, LlmError> {
// OpenAI doesn't have a public account info endpoint
Ok(None)
}
async fn usage_stats(&self) -> Result<Option<UsageStats>, LlmError> {
// OpenAI doesn't expose usage stats via the standard API
Ok(None)
}
async fn list_models(&self) -> Result<Vec<ModelInfo>, LlmError> {
// We can optionally fetch from API, but return known models for now
Ok(get_gpt_models())
}
async fn model_info(&self, model_id: &str) -> Result<Option<ModelInfo>, LlmError> {
let models = get_gpt_models();
Ok(models.into_iter().find(|m| m.id == model_id))
}
}
#[cfg(test)]
mod tests {
use super::*;
use llm_core::ToolParameters;
use serde_json::json;
#[test]
fn test_message_conversion() {
let messages = vec![
ChatMessage::system("You are helpful"),
ChatMessage::user("Hello"),
ChatMessage::assistant("Hi there!"),
];
let openai_msgs = OpenAIClient::prepare_messages(&messages);
assert_eq!(openai_msgs.len(), 3);
assert_eq!(openai_msgs[0].role, "system");
assert_eq!(openai_msgs[1].role, "user");
assert_eq!(openai_msgs[2].role, "assistant");
}
#[test]
fn test_tool_conversion() {
let tools = vec![Tool::function(
"read_file",
"Read a file's contents",
ToolParameters::object(
json!({
"path": {
"type": "string",
"description": "File path"
}
}),
vec!["path".to_string()],
),
)];
let openai_tools = OpenAIClient::prepare_tools(Some(&tools)).unwrap();
assert_eq!(openai_tools.len(), 1);
assert_eq!(openai_tools[0].function.name, "read_file");
assert_eq!(
openai_tools[0].function.description,
"Read a file's contents"
);
}
}

View File

@@ -0,0 +1,12 @@
//! OpenAI GPT API Client
//!
//! Implements the LlmProvider trait for OpenAI's GPT models.
//! Supports both API key authentication and OAuth device flow.
mod auth;
mod client;
mod types;
pub use auth::*;
pub use client::*;
pub use types::*;

View File

@@ -0,0 +1,285 @@
//! OpenAI API request/response types
use serde::{Deserialize, Serialize};
use serde_json::Value;
// ============================================================================
// Request Types
// ============================================================================
#[derive(Debug, Serialize)]
pub struct ChatCompletionRequest<'a> {
pub model: &'a str,
pub messages: Vec<OpenAIMessage>,
#[serde(skip_serializing_if = "Option::is_none")]
pub temperature: Option<f32>,
#[serde(skip_serializing_if = "Option::is_none")]
pub max_tokens: Option<u32>,
#[serde(skip_serializing_if = "Option::is_none")]
pub top_p: Option<f32>,
#[serde(skip_serializing_if = "Option::is_none")]
pub stop: Option<&'a [String]>,
#[serde(skip_serializing_if = "Option::is_none")]
pub tools: Option<Vec<OpenAITool>>,
#[serde(skip_serializing_if = "Option::is_none")]
pub tool_choice: Option<&'a str>,
pub stream: bool,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct OpenAIMessage {
pub role: String, // "system", "user", "assistant", "tool"
#[serde(skip_serializing_if = "Option::is_none")]
pub content: Option<String>,
#[serde(skip_serializing_if = "Option::is_none")]
pub tool_calls: Option<Vec<OpenAIToolCall>>,
#[serde(skip_serializing_if = "Option::is_none")]
pub tool_call_id: Option<String>,
#[serde(skip_serializing_if = "Option::is_none")]
pub name: Option<String>,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct OpenAIToolCall {
pub id: String,
#[serde(rename = "type")]
pub call_type: String,
pub function: OpenAIFunctionCall,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct OpenAIFunctionCall {
pub name: String,
pub arguments: String, // JSON string
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct OpenAITool {
#[serde(rename = "type")]
pub tool_type: String,
pub function: OpenAIFunction,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct OpenAIFunction {
pub name: String,
pub description: String,
pub parameters: FunctionParameters,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct FunctionParameters {
#[serde(rename = "type")]
pub param_type: String,
pub properties: Value,
pub required: Vec<String>,
}
// ============================================================================
// Response Types
// ============================================================================
#[derive(Debug, Clone, Deserialize)]
pub struct ChatCompletionResponse {
pub id: String,
pub object: String,
pub created: u64,
pub model: String,
pub choices: Vec<Choice>,
pub usage: Option<UsageInfo>,
}
#[derive(Debug, Clone, Deserialize)]
pub struct Choice {
pub index: u32,
pub message: OpenAIMessage,
pub finish_reason: Option<String>,
}
#[derive(Debug, Clone, Deserialize)]
pub struct UsageInfo {
pub prompt_tokens: u32,
pub completion_tokens: u32,
pub total_tokens: u32,
}
// ============================================================================
// Streaming Response Types
// ============================================================================
#[derive(Debug, Clone, Deserialize)]
pub struct ChatCompletionChunk {
pub id: String,
pub object: String,
pub created: u64,
pub model: String,
pub choices: Vec<ChunkChoice>,
}
#[derive(Debug, Clone, Deserialize)]
pub struct ChunkChoice {
pub index: u32,
pub delta: Delta,
pub finish_reason: Option<String>,
}
#[derive(Debug, Clone, Deserialize)]
pub struct Delta {
#[serde(skip_serializing_if = "Option::is_none")]
pub role: Option<String>,
#[serde(skip_serializing_if = "Option::is_none")]
pub content: Option<String>,
#[serde(skip_serializing_if = "Option::is_none")]
pub tool_calls: Option<Vec<DeltaToolCall>>,
}
#[derive(Debug, Clone, Deserialize)]
pub struct DeltaToolCall {
pub index: usize,
#[serde(skip_serializing_if = "Option::is_none")]
pub id: Option<String>,
#[serde(skip_serializing_if = "Option::is_none", rename = "type")]
pub call_type: Option<String>,
#[serde(skip_serializing_if = "Option::is_none")]
pub function: Option<DeltaFunction>,
}
#[derive(Debug, Clone, Deserialize)]
pub struct DeltaFunction {
#[serde(skip_serializing_if = "Option::is_none")]
pub name: Option<String>,
#[serde(skip_serializing_if = "Option::is_none")]
pub arguments: Option<String>,
}
// ============================================================================
// Error Response Types
// ============================================================================
#[derive(Debug, Clone, Deserialize)]
pub struct ErrorResponse {
pub error: ApiError,
}
#[derive(Debug, Clone, Deserialize)]
pub struct ApiError {
pub message: String,
#[serde(rename = "type")]
pub error_type: String,
pub code: Option<String>,
}
// ============================================================================
// Models List Response
// ============================================================================
#[derive(Debug, Clone, Deserialize)]
pub struct ModelsResponse {
pub object: String,
pub data: Vec<ModelData>,
}
#[derive(Debug, Clone, Deserialize)]
pub struct ModelData {
pub id: String,
pub object: String,
pub created: u64,
pub owned_by: String,
}
// ============================================================================
// Conversions
// ============================================================================
impl From<&llm_core::Tool> for OpenAITool {
fn from(tool: &llm_core::Tool) -> Self {
Self {
tool_type: "function".to_string(),
function: OpenAIFunction {
name: tool.function.name.clone(),
description: tool.function.description.clone(),
parameters: FunctionParameters {
param_type: tool.function.parameters.param_type.clone(),
properties: tool.function.parameters.properties.clone(),
required: tool.function.parameters.required.clone(),
},
},
}
}
}
impl From<&llm_core::ChatMessage> for OpenAIMessage {
fn from(msg: &llm_core::ChatMessage) -> Self {
use llm_core::Role;
let role = match msg.role {
Role::System => "system",
Role::User => "user",
Role::Assistant => "assistant",
Role::Tool => "tool",
};
// Handle tool result messages
if msg.role == Role::Tool {
return Self {
role: "tool".to_string(),
content: msg.content.clone(),
tool_calls: None,
tool_call_id: msg.tool_call_id.clone(),
name: msg.name.clone(),
};
}
// Handle assistant messages with tool calls
if msg.role == Role::Assistant && msg.tool_calls.is_some() {
let tool_calls = msg.tool_calls.as_ref().map(|calls| {
calls
.iter()
.map(|call| OpenAIToolCall {
id: call.id.clone(),
call_type: "function".to_string(),
function: OpenAIFunctionCall {
name: call.function.name.clone(),
arguments: serde_json::to_string(&call.function.arguments)
.unwrap_or_else(|_| "{}".to_string()),
},
})
.collect()
});
return Self {
role: "assistant".to_string(),
content: msg.content.clone(),
tool_calls,
tool_call_id: None,
name: None,
};
}
// Simple text message
Self {
role: role.to_string(),
content: msg.content.clone(),
tool_calls: None,
tool_call_id: None,
name: None,
}
}
}

View File

@@ -10,6 +10,7 @@ serde = { version = "1", features = ["derive"] }
directories = "5"
figment = { version = "0.10", features = ["toml", "env"] }
permissions = { path = "../permissions" }
llm-core = { path = "../../llm/core" }
[dev-dependencies]
tempfile = "3.23.0"

View File

@@ -5,26 +5,65 @@ use figment::{
};
use serde::{Deserialize, Serialize};
use std::path::PathBuf;
use std::env;
use permissions::{Mode, PermissionManager};
use llm_core::ProviderType;
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Settings {
#[serde(default = "default_ollama_url")]
pub ollama_url: String,
// Provider configuration
#[serde(default = "default_provider")]
pub provider: String, // "ollama" | "anthropic" | "openai"
#[serde(default = "default_model")]
pub model: String,
#[serde(default = "default_mode")]
pub mode: String, // "plan" (read-only) for now
// Ollama-specific
#[serde(default = "default_ollama_url")]
pub ollama_url: String,
// API keys for different providers
#[serde(default)]
pub api_key: Option<String>, // For Ollama Cloud or other API authentication
pub api_key: Option<String>, // For Ollama Cloud or backwards compatibility
#[serde(default)]
pub anthropic_api_key: Option<String>,
#[serde(default)]
pub openai_api_key: Option<String>,
// Permission mode
#[serde(default = "default_mode")]
pub mode: String, // "plan" | "acceptEdits" | "code"
// Tool permission lists
/// Tools that are always allowed without prompting
/// Format: "tool_name" or "tool_name:pattern"
/// Example: ["bash:npm test:*", "bash:cargo test:*", "mcp:filesystem__*"]
#[serde(default)]
pub allowed_tools: Vec<String>,
/// Tools that are always denied (blocked)
/// Format: "tool_name" or "tool_name:pattern"
/// Example: ["bash:rm -rf*", "bash:sudo*"]
#[serde(default)]
pub disallowed_tools: Vec<String>,
}
fn default_provider() -> String {
"ollama".into()
}
fn default_ollama_url() -> String {
"http://localhost:11434".into()
}
fn default_model() -> String {
// Default model depends on provider, but we use ollama's default here
// Users can override this per-provider or use get_effective_model()
"qwen3:8b".into()
}
fn default_mode() -> String {
"plan".into()
}
@@ -32,25 +71,71 @@ fn default_mode() -> String {
impl Default for Settings {
fn default() -> Self {
Self {
ollama_url: default_ollama_url(),
provider: default_provider(),
model: default_model(),
mode: default_mode(),
ollama_url: default_ollama_url(),
api_key: None,
anthropic_api_key: None,
openai_api_key: None,
mode: default_mode(),
allowed_tools: Vec::new(),
disallowed_tools: Vec::new(),
}
}
}
impl Settings {
/// Create a PermissionManager based on the configured mode
/// Create a PermissionManager based on the configured mode and tool lists
///
/// Tool lists are applied in order (see the sketch below):
/// 1. Disallowed tools (highest priority - blocked first)
/// 2. Allowed tools
/// 3. Mode-based defaults
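///
/// Illustrative sketch using the fields defined above (not taken from the diff):
/// ```
/// use config_agent::Settings;
///
/// let mut s = Settings::default();
/// s.mode = "plan".to_string();
/// s.disallowed_tools = vec!["bash:rm -rf*".to_string()];
/// s.allowed_tools = vec!["bash:cargo test:*".to_string()];
/// let _pm = s.create_permission_manager();
/// ```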
pub fn create_permission_manager(&self) -> PermissionManager {
let mode = Mode::from_str(&self.mode).unwrap_or(Mode::Plan);
PermissionManager::new(mode)
let mut pm = PermissionManager::new(mode);
// Add disallowed tools first (deny rules take precedence)
pm.add_disallowed_tools(&self.disallowed_tools);
// Then add allowed tools
pm.add_allowed_tools(&self.allowed_tools);
pm
}
/// Get the Mode enum from the mode string
pub fn get_mode(&self) -> Mode {
Mode::from_str(&self.mode).unwrap_or(Mode::Plan)
}
/// Get the ProviderType enum from the provider string
pub fn get_provider(&self) -> Option<ProviderType> {
ProviderType::from_str(&self.provider)
}
/// Get the effective model for the current provider
/// If no model is explicitly set, returns the provider's default
pub fn get_effective_model(&self) -> String {
// If model is explicitly set and not the default, use it
if self.model != default_model() {
return self.model.clone();
}
// Otherwise, use provider-specific default
self.get_provider()
.map(|p| p.default_model().to_string())
.unwrap_or_else(|| self.model.clone())
}
/// Get the API key for the current provider
pub fn get_provider_api_key(&self) -> Option<String> {
match self.get_provider()? {
ProviderType::Ollama => self.api_key.clone(),
ProviderType::Anthropic => self.anthropic_api_key.clone(),
ProviderType::OpenAI => self.openai_api_key.clone(),
}
}
}
pub fn load_settings(project_root: Option<&str>) -> Result<Settings, figment::Error> {
@@ -68,9 +153,31 @@ pub fn load_settings(project_root: Option<&str>) -> Result<Settings, figment::Er
}
// Environment variables have highest precedence
// OWLEN_* prefix (e.g., OWLEN_PROVIDER, OWLEN_MODEL, OWLEN_API_KEY, OWLEN_ANTHROPIC_API_KEY)
fig = fig.merge(Env::prefixed("OWLEN_").split("__"));
// Support OLLAMA_API_KEY, OLLAMA_MODEL, etc. (without nesting)
// Support OLLAMA_* prefix for backwards compatibility
fig = fig.merge(Env::prefixed("OLLAMA_"));
fig.extract()
// Support PROVIDER env var (without OWLEN_ prefix)
fig = fig.merge(Env::raw().only(&["PROVIDER"]));
// Extract the settings
let mut settings: Settings = fig.extract()?;
// Manually handle standard provider API key env vars (ANTHROPIC_API_KEY, OPENAI_API_KEY)
// These override config files but are overridden by OWLEN_* vars
if settings.anthropic_api_key.is_none() {
if let Ok(key) = env::var("ANTHROPIC_API_KEY") {
settings.anthropic_api_key = Some(key);
}
}
if settings.openai_api_key.is_none() {
if let Ok(key) = env::var("OPENAI_API_KEY") {
settings.openai_api_key = Some(key);
}
}
Ok(settings)
}

View File

@@ -1,5 +1,6 @@
use config_agent::{load_settings, Settings};
use permissions::{Mode, PermissionDecision, Tool};
use llm_core::ProviderType;
use std::{env, fs};
#[test]
@@ -46,3 +47,188 @@ fn settings_parse_mode_from_config() {
assert_eq!(mgr.check(Tool::Write, None), PermissionDecision::Allow);
assert_eq!(mgr.check(Tool::Bash, None), PermissionDecision::Allow);
}
#[test]
fn default_provider_is_ollama() {
let s = Settings::default();
assert_eq!(s.provider, "ollama");
assert_eq!(s.get_provider(), Some(ProviderType::Ollama));
}
#[test]
fn provider_from_config_file() {
let tmp = tempfile::tempdir().unwrap();
let project_file = tmp.path().join(".owlen.toml");
fs::write(&project_file, r#"provider="anthropic""#).unwrap();
let s = load_settings(Some(tmp.path().to_str().unwrap())).unwrap();
assert_eq!(s.provider, "anthropic");
assert_eq!(s.get_provider(), Some(ProviderType::Anthropic));
}
#[test]
#[ignore] // Ignore due to env var interaction in parallel tests
fn provider_from_env_var() {
let tmp = tempfile::tempdir().unwrap();
unsafe {
env::set_var("OWLEN_PROVIDER", "openai");
env::remove_var("PROVIDER");
env::remove_var("ANTHROPIC_API_KEY");
env::remove_var("OPENAI_API_KEY");
}
let s = load_settings(Some(tmp.path().to_str().unwrap())).unwrap();
assert_eq!(s.provider, "openai");
assert_eq!(s.get_provider(), Some(ProviderType::OpenAI));
unsafe { env::remove_var("OWLEN_PROVIDER"); }
}
#[test]
#[ignore] // Ignore due to env var interaction in parallel tests
fn provider_from_provider_env_var() {
let tmp = tempfile::tempdir().unwrap();
unsafe {
env::set_var("PROVIDER", "anthropic");
env::remove_var("OWLEN_PROVIDER");
env::remove_var("ANTHROPIC_API_KEY");
env::remove_var("OPENAI_API_KEY");
}
let s = load_settings(Some(tmp.path().to_str().unwrap())).unwrap();
assert_eq!(s.provider, "anthropic");
assert_eq!(s.get_provider(), Some(ProviderType::Anthropic));
unsafe { env::remove_var("PROVIDER"); }
}
#[test]
fn anthropic_api_key_from_owlen_env() {
let tmp = tempfile::tempdir().unwrap();
let project_file = tmp.path().join(".owlen.toml");
fs::write(&project_file, r#"provider="anthropic""#).unwrap();
unsafe { env::set_var("OWLEN_ANTHROPIC_API_KEY", "sk-ant-test123"); }
let s = load_settings(Some(tmp.path().to_str().unwrap())).unwrap();
assert_eq!(s.anthropic_api_key, Some("sk-ant-test123".to_string()));
assert_eq!(s.get_provider_api_key(), Some("sk-ant-test123".to_string()));
unsafe { env::remove_var("OWLEN_ANTHROPIC_API_KEY"); }
}
#[test]
fn openai_api_key_from_owlen_env() {
let tmp = tempfile::tempdir().unwrap();
let project_file = tmp.path().join(".owlen.toml");
fs::write(&project_file, r#"provider="openai""#).unwrap();
unsafe { env::set_var("OWLEN_OPENAI_API_KEY", "sk-test-456"); }
let s = load_settings(Some(tmp.path().to_str().unwrap())).unwrap();
assert_eq!(s.openai_api_key, Some("sk-test-456".to_string()));
assert_eq!(s.get_provider_api_key(), Some("sk-test-456".to_string()));
unsafe { env::remove_var("OWLEN_OPENAI_API_KEY"); }
}
#[test]
#[ignore] // Ignore due to env var interaction in parallel tests
fn api_keys_from_config_file() {
let tmp = tempfile::tempdir().unwrap();
let project_file = tmp.path().join(".owlen.toml");
fs::write(&project_file, r#"
provider = "anthropic"
anthropic_api_key = "sk-ant-from-file"
openai_api_key = "sk-openai-from-file"
"#).unwrap();
// Clear any env vars that might interfere
unsafe {
env::remove_var("ANTHROPIC_API_KEY");
env::remove_var("OPENAI_API_KEY");
env::remove_var("OWLEN_ANTHROPIC_API_KEY");
env::remove_var("OWLEN_OPENAI_API_KEY");
}
let s = load_settings(Some(tmp.path().to_str().unwrap())).unwrap();
assert_eq!(s.anthropic_api_key, Some("sk-ant-from-file".to_string()));
assert_eq!(s.openai_api_key, Some("sk-openai-from-file".to_string()));
assert_eq!(s.get_provider_api_key(), Some("sk-ant-from-file".to_string()));
}
#[test]
#[ignore] // Ignore due to env var interaction in parallel tests
fn anthropic_api_key_from_standard_env() {
let tmp = tempfile::tempdir().unwrap();
let project_file = tmp.path().join(".owlen.toml");
fs::write(&project_file, r#"provider="anthropic""#).unwrap();
unsafe {
env::set_var("ANTHROPIC_API_KEY", "sk-ant-std");
env::remove_var("OWLEN_ANTHROPIC_API_KEY");
env::remove_var("PROVIDER");
env::remove_var("OWLEN_PROVIDER");
}
let s = load_settings(Some(tmp.path().to_str().unwrap())).unwrap();
assert_eq!(s.anthropic_api_key, Some("sk-ant-std".to_string()));
assert_eq!(s.get_provider_api_key(), Some("sk-ant-std".to_string()));
unsafe { env::remove_var("ANTHROPIC_API_KEY"); }
}
#[test]
#[ignore] // Ignore due to env var interaction in parallel tests
fn openai_api_key_from_standard_env() {
let tmp = tempfile::tempdir().unwrap();
let project_file = tmp.path().join(".owlen.toml");
fs::write(&project_file, r#"provider="openai""#).unwrap();
unsafe {
env::set_var("OPENAI_API_KEY", "sk-openai-std");
env::remove_var("OWLEN_OPENAI_API_KEY");
env::remove_var("PROVIDER");
env::remove_var("OWLEN_PROVIDER");
}
let s = load_settings(Some(tmp.path().to_str().unwrap())).unwrap();
assert_eq!(s.openai_api_key, Some("sk-openai-std".to_string()));
assert_eq!(s.get_provider_api_key(), Some("sk-openai-std".to_string()));
unsafe { env::remove_var("OPENAI_API_KEY"); }
}
#[test]
#[ignore] // Ignore due to env var interaction in parallel tests
fn owlen_prefix_overrides_standard_env() {
let tmp = tempfile::tempdir().unwrap();
unsafe {
env::set_var("ANTHROPIC_API_KEY", "sk-ant-std");
env::set_var("OWLEN_ANTHROPIC_API_KEY", "sk-ant-owlen");
}
let s = load_settings(Some(tmp.path().to_str().unwrap())).unwrap();
// OWLEN_ prefix should take precedence
assert_eq!(s.anthropic_api_key, Some("sk-ant-owlen".to_string()));
unsafe {
env::remove_var("ANTHROPIC_API_KEY");
env::remove_var("OWLEN_ANTHROPIC_API_KEY");
}
}
#[test]
fn effective_model_uses_provider_default() {
// Test Anthropic provider default
let mut s = Settings::default();
s.provider = "anthropic".to_string();
assert_eq!(s.get_effective_model(), "claude-sonnet-4-20250514");
// Test OpenAI provider default
s.provider = "openai".to_string();
assert_eq!(s.get_effective_model(), "gpt-4o");
// Test Ollama provider default
s.provider = "ollama".to_string();
assert_eq!(s.get_effective_model(), "qwen3:8b");
}
#[test]
fn effective_model_respects_explicit_model() {
let mut s = Settings::default();
s.provider = "anthropic".to_string();
s.model = "claude-opus-4-20250514".to_string();
// Should use explicit model, not provider default
assert_eq!(s.get_effective_model(), "claude-opus-4-20250514");
}

View File

@@ -10,6 +10,7 @@ serde = { version = "1", features = ["derive"] }
serde_json = "1"
tokio = { version = "1.39", features = ["process", "time", "io-util"] }
color-eyre = "0.6"
regex = "1.10"
[dev-dependencies]
tempfile = "3.23.0"

View File

@@ -34,6 +34,34 @@ pub enum HookEvent {
prompt: String,
},
PreCompact,
/// Called before the agent stops - allows validation of completion
#[serde(rename_all = "camelCase")]
Stop {
/// Reason for stopping (e.g., "task_complete", "max_iterations", "user_interrupt")
reason: String,
/// Number of messages in conversation
num_messages: usize,
/// Number of tool calls made
num_tool_calls: usize,
},
/// Called before a subagent stops
#[serde(rename_all = "camelCase")]
SubagentStop {
/// Unique identifier for the subagent
agent_id: String,
/// Type of subagent (e.g., "explore", "code-reviewer")
agent_type: String,
/// Reason for stopping
reason: String,
},
/// Called when a notification is sent to the user
#[serde(rename_all = "camelCase")]
Notification {
/// Notification message
message: String,
/// Notification type (e.g., "info", "warning", "error")
notification_type: String,
},
}
impl HookEvent {
@@ -46,27 +74,137 @@ impl HookEvent {
HookEvent::SessionEnd { .. } => "SessionEnd",
HookEvent::UserPromptSubmit { .. } => "UserPromptSubmit",
HookEvent::PreCompact => "PreCompact",
HookEvent::Stop { .. } => "Stop",
HookEvent::SubagentStop { .. } => "SubagentStop",
HookEvent::Notification { .. } => "Notification",
}
}
}
/// Simple hook result for backwards compatibility
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum HookResult {
Allow,
Deny,
}
/// Extended hook output with additional control options
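///
/// Illustrative JSON a hook script might print (a sketch; field names follow the
/// `camelCase` rename below):
/// `{"continueExecution": false, "systemMessage": "blocked by policy", "permissionDecision": "deny"}`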
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct HookOutput {
/// Whether to continue execution (default: true if exit code 0)
#[serde(default = "default_continue")]
pub continue_execution: bool,
/// Whether to suppress showing the result to the user
#[serde(default)]
pub suppress_output: bool,
/// System message to inject into the conversation
#[serde(default)]
pub system_message: Option<String>,
/// Permission decision override
#[serde(default)]
pub permission_decision: Option<HookPermission>,
/// Modified input/args for the tool (PreToolUse only)
#[serde(default)]
pub updated_input: Option<Value>,
}
impl Default for HookOutput {
fn default() -> Self {
Self {
continue_execution: true,
suppress_output: false,
system_message: None,
permission_decision: None,
updated_input: None,
}
}
}
fn default_continue() -> bool {
true
}
/// Permission decision from a hook
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "lowercase")]
pub enum HookPermission {
Allow,
Deny,
Ask,
}
impl HookOutput {
pub fn new() -> Self {
Self::default()
}
pub fn allow() -> Self {
Self {
continue_execution: true,
..Default::default()
}
}
pub fn deny() -> Self {
Self {
continue_execution: false,
..Default::default()
}
}
pub fn with_system_message(mut self, message: impl Into<String>) -> Self {
self.system_message = Some(message.into());
self
}
pub fn with_permission(mut self, permission: HookPermission) -> Self {
self.permission_decision = Some(permission);
self
}
/// Convert to simple HookResult for backwards compatibility
pub fn to_result(&self) -> HookResult {
if self.continue_execution {
HookResult::Allow
} else {
HookResult::Deny
}
}
}
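// Hooks can also report structured results on stdout: execute_extended() below
// first tries to parse stdout as a HookOutput and only then falls back to exit
// codes. A hook that allows the call but injects a note would print, for example:
//
//   {"continueExecution": true, "suppressOutput": false, "systemMessage": "Hello"}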
/// A registered hook that can be executed
#[derive(Debug, Clone)]
struct Hook {
event: String, // Event name like "PreToolUse", "PostToolUse", etc.
command: String,
pattern: Option<String>, // Optional regex pattern for matching tool names
timeout: Option<u64>,
}
pub struct HookManager {
project_root: PathBuf,
hooks: Vec<Hook>,
}
impl HookManager {
pub fn new(project_root: &str) -> Self {
Self {
project_root: PathBuf::from(project_root),
hooks: Vec::new(),
}
}
/// Register a single hook
pub fn register_hook(&mut self, event: String, command: String, pattern: Option<String>, timeout: Option<u64>) {
self.hooks.push(Hook {
event,
command,
pattern,
timeout,
});
}
/// Execute a hook for the given event
///
/// Returns:
@@ -74,18 +212,66 @@ impl HookManager {
/// - Ok(HookResult::Deny) if hook denies (exit code 2)
/// - Err if hook fails (other exit codes) or times out
pub async fn execute(&self, event: &HookEvent, timeout_ms: Option<u64>) -> Result<HookResult> {
// First check for legacy file-based hooks
let hook_path = self.get_hook_path(event);
let has_file_hook = hook_path.exists();
// Get registered hooks for this event
let event_name = event.hook_name();
let mut matching_hooks: Vec<&Hook> = self.hooks.iter()
.filter(|h| h.event == event_name)
.collect();
// If we need to filter by pattern (for PreToolUse events)
if let HookEvent::PreToolUse { tool, .. } = event {
matching_hooks.retain(|h| {
if let Some(pattern) = &h.pattern {
// Use regex to match tool name
if let Ok(re) = regex::Regex::new(pattern) {
re.is_match(tool)
} else {
false
}
} else {
true // No pattern means match all
}
});
}
// If no hooks at all, allow by default
if !has_file_hook && matching_hooks.is_empty() {
return Ok(HookResult::Allow);
}
// Execute file-based hook first (if exists)
if has_file_hook {
let result = self.execute_hook_command(&hook_path.to_string_lossy(), event, timeout_ms).await?;
if result == HookResult::Deny {
return Ok(HookResult::Deny);
}
}
// Execute registered hooks
for hook in matching_hooks {
let hook_timeout = hook.timeout.or(timeout_ms);
let result = self.execute_hook_command(&hook.command, event, hook_timeout).await?;
if result == HookResult::Deny {
return Ok(HookResult::Deny);
}
}
Ok(HookResult::Allow)
}
/// Execute a single hook command
async fn execute_hook_command(&self, command: &str, event: &HookEvent, timeout_ms: Option<u64>) -> Result<HookResult> {
// Serialize event to JSON
let input_json = serde_json::to_string(event)?;
// Spawn the hook process
let mut child = Command::new("sh")
.arg("-c")
.arg(command)
.stdin(Stdio::piped())
.stdout(Stdio::piped())
.stderr(Stdio::piped())
@@ -126,6 +312,131 @@ impl HookManager {
}
}
/// Execute a hook and return extended output
///
/// This method parses JSON output from stdout if the hook provides it,
/// otherwise falls back to exit code interpretation.
pub async fn execute_extended(&self, event: &HookEvent, timeout_ms: Option<u64>) -> Result<HookOutput> {
// First check for legacy file-based hooks
let hook_path = self.get_hook_path(event);
let has_file_hook = hook_path.exists();
// Get registered hooks for this event
let event_name = event.hook_name();
let mut matching_hooks: Vec<&Hook> = self.hooks.iter()
.filter(|h| h.event == event_name)
.collect();
// If we need to filter by pattern (for PreToolUse events)
if let HookEvent::PreToolUse { tool, .. } = event {
matching_hooks.retain(|h| {
if let Some(pattern) = &h.pattern {
if let Ok(re) = regex::Regex::new(pattern) {
re.is_match(tool)
} else {
false
}
} else {
true
}
});
}
// If no hooks at all, allow by default
if !has_file_hook && matching_hooks.is_empty() {
return Ok(HookOutput::allow());
}
let mut combined_output = HookOutput::allow();
// Execute file-based hook first (if exists)
if has_file_hook {
let output = self.execute_hook_extended(&hook_path.to_string_lossy(), event, timeout_ms).await?;
combined_output = Self::merge_outputs(combined_output, output);
if !combined_output.continue_execution {
return Ok(combined_output);
}
}
// Execute registered hooks
for hook in matching_hooks {
let hook_timeout = hook.timeout.or(timeout_ms);
let output = self.execute_hook_extended(&hook.command, event, hook_timeout).await?;
combined_output = Self::merge_outputs(combined_output, output);
if !combined_output.continue_execution {
return Ok(combined_output);
}
}
Ok(combined_output)
}
/// Execute a single hook command and return extended output
async fn execute_hook_extended(&self, command: &str, event: &HookEvent, timeout_ms: Option<u64>) -> Result<HookOutput> {
let input_json = serde_json::to_string(event)?;
let mut child = Command::new("sh")
.arg("-c")
.arg(command)
.stdin(Stdio::piped())
.stdout(Stdio::piped())
.stderr(Stdio::piped())
.current_dir(&self.project_root)
.spawn()?;
if let Some(mut stdin) = child.stdin.take() {
stdin.write_all(input_json.as_bytes()).await?;
stdin.flush().await?;
drop(stdin);
}
let result = if let Some(ms) = timeout_ms {
timeout(Duration::from_millis(ms), child.wait_with_output()).await
} else {
Ok(child.wait_with_output().await)
};
match result {
Ok(Ok(output)) => {
let exit_code = output.status.code();
let stdout = String::from_utf8_lossy(&output.stdout);
// Try to parse JSON output from stdout
if !stdout.trim().is_empty() {
if let Ok(hook_output) = serde_json::from_str::<HookOutput>(stdout.trim()) {
return Ok(hook_output);
}
}
// Fall back to exit code interpretation
match exit_code {
Some(0) => Ok(HookOutput::allow()),
Some(2) => Ok(HookOutput::deny()),
Some(code) => Err(eyre!(
"Hook {} failed with exit code {}: {}",
event.hook_name(),
code,
String::from_utf8_lossy(&output.stderr)
)),
None => Err(eyre!("Hook {} terminated by signal", event.hook_name())),
}
}
Ok(Err(e)) => Err(eyre!("Failed to execute hook {}: {}", event.hook_name(), e)),
Err(_) => Err(eyre!("Hook {} timed out", event.hook_name())),
}
}
/// Merge two hook outputs, with the second taking precedence
fn merge_outputs(base: HookOutput, new: HookOutput) -> HookOutput {
HookOutput {
continue_execution: base.continue_execution && new.continue_execution,
suppress_output: base.suppress_output || new.suppress_output,
system_message: new.system_message.or(base.system_message),
permission_decision: new.permission_decision.or(base.permission_decision),
updated_input: new.updated_input.or(base.updated_input),
}
}
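// Example of the merge semantics: if the file-based hook allows but sets a system
// message and a later registered hook denies, the combined output has
// continue_execution == false while the earlier system message is kept (a later
// hook only overrides the optional fields it actually sets).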
fn get_hook_path(&self, event: &HookEvent) -> PathBuf {
self.project_root
.join(".owlen")
@@ -167,5 +478,76 @@ mod tests {
.hook_name(),
"SessionStart"
);
assert_eq!(
HookEvent::Stop {
reason: "task_complete".to_string(),
num_messages: 10,
num_tool_calls: 5,
}
.hook_name(),
"Stop"
);
assert_eq!(
HookEvent::SubagentStop {
agent_id: "abc123".to_string(),
agent_type: "explore".to_string(),
reason: "completed".to_string(),
}
.hook_name(),
"SubagentStop"
);
}
#[test]
fn stop_event_serializes_correctly() {
let event = HookEvent::Stop {
reason: "task_complete".to_string(),
num_messages: 10,
num_tool_calls: 5,
};
let json = serde_json::to_string(&event).unwrap();
assert!(json.contains("\"event\":\"stop\""));
assert!(json.contains("\"reason\":\"task_complete\""));
assert!(json.contains("\"numMessages\":10"));
assert!(json.contains("\"numToolCalls\":5"));
}
#[test]
fn hook_output_defaults() {
let output = HookOutput::default();
assert!(output.continue_execution);
assert!(!output.suppress_output);
assert!(output.system_message.is_none());
assert!(output.permission_decision.is_none());
}
#[test]
fn hook_output_builders() {
let output = HookOutput::allow()
.with_system_message("Test message")
.with_permission(HookPermission::Allow);
assert!(output.continue_execution);
assert_eq!(output.system_message, Some("Test message".to_string()));
assert_eq!(output.permission_decision, Some(HookPermission::Allow));
let deny = HookOutput::deny();
assert!(!deny.continue_execution);
}
#[test]
fn hook_output_deserializes() {
let json = r#"{"continueExecution": true, "suppressOutput": false, "systemMessage": "Hello"}"#;
let output: HookOutput = serde_json::from_str(json).unwrap();
assert!(output.continue_execution);
assert!(!output.suppress_output);
assert_eq!(output.system_message, Some("Hello".to_string()));
}
#[test]
fn hook_output_to_result() {
assert_eq!(HookOutput::allow().to_result(), HookResult::Allow);
assert_eq!(HookOutput::deny().to_result(), HookResult::Deny);
}
}

View File

@@ -0,0 +1,154 @@
// Integration test for plugin hooks with HookManager
use color_eyre::eyre::Result;
use hooks::{HookEvent, HookManager, HookResult};
use tempfile::TempDir;
#[tokio::test]
async fn test_register_and_execute_plugin_hooks() -> Result<()> {
// Create temporary directory to act as project root
let temp_dir = TempDir::new()?;
// Create hook manager
let mut hook_mgr = HookManager::new(temp_dir.path().to_str().unwrap());
// Register a hook that matches Edit|Write tools
hook_mgr.register_hook(
"PreToolUse".to_string(),
"echo 'Hook executed' && exit 0".to_string(),
Some("Edit|Write".to_string()),
Some(5000),
);
// Test that the hook executes for Edit tool
let event = HookEvent::PreToolUse {
tool: "Edit".to_string(),
args: serde_json::json!({"path": "/tmp/test.txt"}),
};
let result = hook_mgr.execute(&event, Some(5000)).await?;
assert_eq!(result, HookResult::Allow);
// Test that the hook executes for Write tool
let event = HookEvent::PreToolUse {
tool: "Write".to_string(),
args: serde_json::json!({"path": "/tmp/test.txt"}),
};
let result = hook_mgr.execute(&event, Some(5000)).await?;
assert_eq!(result, HookResult::Allow);
// Test that the hook does NOT execute for Read tool (doesn't match pattern)
let event = HookEvent::PreToolUse {
tool: "Read".to_string(),
args: serde_json::json!({"path": "/tmp/test.txt"}),
};
let result = hook_mgr.execute(&event, Some(5000)).await?;
assert_eq!(result, HookResult::Allow);
Ok(())
}
#[tokio::test]
async fn test_deny_hook() -> Result<()> {
// Create temporary directory to act as project root
let temp_dir = TempDir::new()?;
// Create hook manager
let mut hook_mgr = HookManager::new(temp_dir.path().to_str().unwrap());
// Register a hook that denies Write operations
hook_mgr.register_hook(
"PreToolUse".to_string(),
"exit 2".to_string(), // Exit code 2 means deny
Some("Write".to_string()),
Some(5000),
);
// Test that the hook denies Write tool
let event = HookEvent::PreToolUse {
tool: "Write".to_string(),
args: serde_json::json!({"path": "/tmp/test.txt"}),
};
let result = hook_mgr.execute(&event, Some(5000)).await?;
assert_eq!(result, HookResult::Deny);
Ok(())
}
#[tokio::test]
async fn test_multiple_hooks_same_event() -> Result<()> {
// Create temporary directory to act as project root
let temp_dir = TempDir::new()?;
// Create hook manager
let mut hook_mgr = HookManager::new(temp_dir.path().to_str().unwrap());
// Register multiple hooks for the same event
hook_mgr.register_hook(
"PreToolUse".to_string(),
"echo 'Hook 1' && exit 0".to_string(),
Some("Edit".to_string()),
Some(5000),
);
hook_mgr.register_hook(
"PreToolUse".to_string(),
"echo 'Hook 2' && exit 0".to_string(),
Some("Edit".to_string()),
Some(5000),
);
// Test that both hooks execute
let event = HookEvent::PreToolUse {
tool: "Edit".to_string(),
args: serde_json::json!({"path": "/tmp/test.txt"}),
};
let result = hook_mgr.execute(&event, Some(5000)).await?;
assert_eq!(result, HookResult::Allow);
Ok(())
}
#[tokio::test]
async fn test_hook_with_no_pattern_matches_all() -> Result<()> {
// Create temporary directory to act as project root
let temp_dir = TempDir::new()?;
// Create hook manager
let mut hook_mgr = HookManager::new(temp_dir.path().to_str().unwrap());
// Register a hook with no pattern (matches all tools)
hook_mgr.register_hook(
"PreToolUse".to_string(),
"echo 'Hook for all tools' && exit 0".to_string(),
None, // No pattern = match all
Some(5000),
);
// Test that the hook executes for any tool
let event = HookEvent::PreToolUse {
tool: "Read".to_string(),
args: serde_json::json!({"path": "/tmp/test.txt"}),
};
let result = hook_mgr.execute(&event, Some(5000)).await?;
assert_eq!(result, HookResult::Allow);
let event = HookEvent::PreToolUse {
tool: "Write".to_string(),
args: serde_json::json!({"path": "/tmp/test.txt"}),
};
let result = hook_mgr.execute(&event, Some(5000)).await?;
assert_eq!(result, HookResult::Allow);
let event = HookEvent::PreToolUse {
tool: "Bash".to_string(),
args: serde_json::json!({"command": "ls"}),
};
let result = hook_mgr.execute(&event, Some(5000)).await?;
assert_eq!(result, HookResult::Allow);
Ok(())
}

View File

@@ -16,6 +16,75 @@ pub enum Tool {
Task,
TodoWrite,
Mcp,
// New tools
MultiEdit,
LS,
AskUserQuestion,
BashOutput,
KillShell,
// Planning mode tools
EnterPlanMode,
ExitPlanMode,
Skill,
}
impl Tool {
/// Parse a tool name from string (case-insensitive)
pub fn from_str(s: &str) -> Option<Self> {
match s.to_lowercase().as_str() {
"read" => Some(Tool::Read),
"write" => Some(Tool::Write),
"edit" => Some(Tool::Edit),
"bash" => Some(Tool::Bash),
"grep" => Some(Tool::Grep),
"glob" => Some(Tool::Glob),
"webfetch" | "web_fetch" => Some(Tool::WebFetch),
"websearch" | "web_search" => Some(Tool::WebSearch),
"notebookread" | "notebook_read" => Some(Tool::NotebookRead),
"notebookedit" | "notebook_edit" => Some(Tool::NotebookEdit),
"slashcommand" | "slash_command" => Some(Tool::SlashCommand),
"task" => Some(Tool::Task),
"todowrite" | "todo_write" | "todo" => Some(Tool::TodoWrite),
"mcp" => Some(Tool::Mcp),
"multiedit" | "multi_edit" => Some(Tool::MultiEdit),
"ls" => Some(Tool::LS),
"askuserquestion" | "ask_user_question" | "ask" => Some(Tool::AskUserQuestion),
"bashoutput" | "bash_output" => Some(Tool::BashOutput),
"killshell" | "kill_shell" => Some(Tool::KillShell),
"enterplanmode" | "enter_plan_mode" => Some(Tool::EnterPlanMode),
"exitplanmode" | "exit_plan_mode" => Some(Tool::ExitPlanMode),
"skill" => Some(Tool::Skill),
_ => None,
}
}
/// Get the string name of this tool
pub fn name(&self) -> &'static str {
match self {
Tool::Read => "read",
Tool::Write => "write",
Tool::Edit => "edit",
Tool::Bash => "bash",
Tool::Grep => "grep",
Tool::Glob => "glob",
Tool::WebFetch => "web_fetch",
Tool::WebSearch => "web_search",
Tool::NotebookRead => "notebook_read",
Tool::NotebookEdit => "notebook_edit",
Tool::SlashCommand => "slash_command",
Tool::Task => "task",
Tool::TodoWrite => "todo_write",
Tool::Mcp => "mcp",
Tool::MultiEdit => "multi_edit",
Tool::LS => "ls",
Tool::AskUserQuestion => "ask_user_question",
Tool::BashOutput => "bash_output",
Tool::KillShell => "kill_shell",
Tool::EnterPlanMode => "enter_plan_mode",
Tool::ExitPlanMode => "exit_plan_mode",
Tool::Skill => "skill",
}
}
}
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
@@ -123,23 +192,34 @@ impl PermissionManager {
match self.mode {
Mode::Plan => match tool {
// Read-only tools are allowed in plan mode
Tool::Read | Tool::Grep | Tool::Glob | Tool::NotebookRead | Tool::LS => {
PermissionDecision::Allow
}
// User interaction and session state tools allowed
Tool::AskUserQuestion | Tool::TodoWrite => PermissionDecision::Allow,
// Planning mode tools - EnterPlanMode asks, ExitPlanMode allowed
Tool::EnterPlanMode => PermissionDecision::Ask,
Tool::ExitPlanMode => PermissionDecision::Allow,
// Skill tool allowed (read-only skill injection)
Tool::Skill => PermissionDecision::Allow,
// Everything else requires asking
_ => PermissionDecision::Ask,
},
Mode::AcceptEdits => match tool {
// Read operations allowed
Tool::Read | Tool::Grep | Tool::Glob | Tool::NotebookRead | Tool::LS => {
PermissionDecision::Allow
}
// Edit/Write operations allowed
Tool::Edit | Tool::Write | Tool::NotebookEdit | Tool::MultiEdit => PermissionDecision::Allow,
// Bash and other dangerous operations still require asking
Tool::Bash | Tool::WebFetch | Tool::WebSearch | Tool::Mcp => PermissionDecision::Ask,
// Background shell operations same as Bash
Tool::BashOutput | Tool::KillShell => PermissionDecision::Ask,
// Utility tools allowed
Tool::TodoWrite | Tool::SlashCommand | Tool::Task | Tool::AskUserQuestion => PermissionDecision::Allow,
// Planning mode tools allowed
Tool::EnterPlanMode | Tool::ExitPlanMode | Tool::Skill => PermissionDecision::Allow,
},
Mode::Code => {
// Everything allowed in code mode
@@ -155,6 +235,41 @@ impl PermissionManager {
pub fn mode(&self) -> Mode {
self.mode
}
/// Add allowed tools from a list of tool names (with optional patterns)
///
/// Format: "tool_name" or "tool_name:pattern"
/// Example: "bash", "bash:npm test:*", "mcp:filesystem__*"
pub fn add_allowed_tools(&mut self, tools: &[String]) {
for spec in tools {
if let Some((tool, pattern)) = Self::parse_tool_spec(spec) {
self.add_rule(tool, pattern, Action::Allow);
}
}
}
/// Add disallowed tools from a list of tool names (with optional patterns)
///
/// Format: "tool_name" or "tool_name:pattern"
/// Example: "bash", "bash:rm -rf*"
pub fn add_disallowed_tools(&mut self, tools: &[String]) {
for spec in tools {
if let Some((tool, pattern)) = Self::parse_tool_spec(spec) {
self.add_rule(tool, pattern, Action::Deny);
}
}
}
/// Parse a tool specification into (Tool, Option<pattern>)
///
/// Format: "tool_name" or "tool_name:pattern"
fn parse_tool_spec(spec: &str) -> Option<(Tool, Option<String>)> {
let parts: Vec<&str> = spec.splitn(2, ':').collect();
let tool_name = parts[0].trim();
let pattern = parts.get(1).map(|s| s.trim().to_string());
Tool::from_str(tool_name).map(|tool| (tool, pattern))
}
}
#[cfg(test)]
@@ -237,4 +352,78 @@ mod tests {
assert!(rule.matches(Tool::Mcp, Some("filesystem__read_file")));
assert!(!rule.matches(Tool::Mcp, Some("filesystem__write_file")));
}
#[test]
fn tool_from_str() {
assert_eq!(Tool::from_str("bash"), Some(Tool::Bash));
assert_eq!(Tool::from_str("BASH"), Some(Tool::Bash));
assert_eq!(Tool::from_str("Bash"), Some(Tool::Bash));
assert_eq!(Tool::from_str("web_fetch"), Some(Tool::WebFetch));
assert_eq!(Tool::from_str("webfetch"), Some(Tool::WebFetch));
assert_eq!(Tool::from_str("unknown"), None);
}
#[test]
fn parse_tool_spec() {
let (tool, pattern) = PermissionManager::parse_tool_spec("bash").unwrap();
assert_eq!(tool, Tool::Bash);
assert_eq!(pattern, None);
let (tool, pattern) = PermissionManager::parse_tool_spec("bash:npm test*").unwrap();
assert_eq!(tool, Tool::Bash);
assert_eq!(pattern, Some("npm test*".to_string()));
let (tool, pattern) = PermissionManager::parse_tool_spec("mcp:filesystem__*").unwrap();
assert_eq!(tool, Tool::Mcp);
assert_eq!(pattern, Some("filesystem__*".to_string()));
assert!(PermissionManager::parse_tool_spec("invalid_tool").is_none());
}
#[test]
fn allowed_tools_list() {
let mut pm = PermissionManager::new(Mode::Plan);
pm.add_allowed_tools(&[
"bash:npm test:*".to_string(),
"bash:cargo test".to_string(),
]);
// Allowed by rule
assert_eq!(pm.check(Tool::Bash, Some("npm test:unit")), PermissionDecision::Allow);
assert_eq!(pm.check(Tool::Bash, Some("cargo test")), PermissionDecision::Allow);
// Not matched by any rule, falls back to mode default (Ask for bash in plan mode)
assert_eq!(pm.check(Tool::Bash, Some("rm -rf")), PermissionDecision::Ask);
}
#[test]
fn disallowed_tools_list() {
let mut pm = PermissionManager::new(Mode::Code);
pm.add_disallowed_tools(&[
"bash:rm -rf*".to_string(),
"bash:sudo*".to_string(),
]);
// Denied by rule
assert_eq!(pm.check(Tool::Bash, Some("rm -rf /")), PermissionDecision::Deny);
assert_eq!(pm.check(Tool::Bash, Some("sudo apt install")), PermissionDecision::Deny);
// Not matched by deny rule, allowed by Code mode
assert_eq!(pm.check(Tool::Bash, Some("npm test")), PermissionDecision::Allow);
}
#[test]
fn deny_takes_precedence() {
let mut pm = PermissionManager::new(Mode::Code);
// Add both allow and deny for similar patterns
pm.add_disallowed_tools(&["bash:rm*".to_string()]);
pm.add_allowed_tools(&["bash".to_string()]);
// Deny rule was added first, so it takes precedence when matched
assert_eq!(pm.check(Tool::Bash, Some("rm -rf")), PermissionDecision::Deny);
assert_eq!(pm.check(Tool::Bash, Some("ls -la")), PermissionDecision::Allow);
}
}

View File

@@ -0,0 +1,15 @@
[package]
name = "plugins"
version = "0.1.0"
edition = "2024"
[dependencies]
color-eyre = "0.6"
dirs = "5.0"
serde = { version = "1.0", features = ["derive"] }
serde_json = "1.0"
serde_yaml = "0.9"
walkdir = "2.5"
[dev-dependencies]
tempfile = "3.13"

View File

@@ -0,0 +1,773 @@
use color_eyre::eyre::{Result, eyre};
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::fs;
use std::path::{Path, PathBuf};
use walkdir::WalkDir;
/// Plugin manifest schema (plugin.json)
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PluginManifest {
/// Plugin name
pub name: String,
/// Plugin version
pub version: String,
/// Plugin description
pub description: Option<String>,
/// Plugin author
pub author: Option<String>,
/// Commands provided by this plugin
#[serde(default)]
pub commands: Vec<String>,
/// Agents provided by this plugin
#[serde(default)]
pub agents: Vec<String>,
/// Skills provided by this plugin
#[serde(default)]
pub skills: Vec<String>,
/// Hooks provided by this plugin
#[serde(default)]
pub hooks: HashMap<String, String>,
/// MCP servers provided by this plugin
#[serde(default)]
pub mcp_servers: Vec<McpServerConfig>,
}
/// MCP server configuration in plugin manifest
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct McpServerConfig {
pub name: String,
pub command: String,
#[serde(default)]
pub args: Vec<String>,
#[serde(default)]
pub env: HashMap<String, String>,
}
/// Plugin hook configuration from hooks/hooks.json
#[derive(Debug, Clone, Deserialize)]
pub struct PluginHooksConfig {
pub description: Option<String>,
pub hooks: HashMap<String, Vec<HookMatcher>>,
}
#[derive(Debug, Clone, Deserialize)]
pub struct HookMatcher {
pub matcher: Option<String>, // Regex pattern for tool names
pub hooks: Vec<HookDefinition>,
}
#[derive(Debug, Clone, Deserialize)]
pub struct HookDefinition {
#[serde(rename = "type")]
pub hook_type: String, // "command" or "prompt"
pub command: Option<String>,
pub prompt: Option<String>,
pub timeout: Option<u64>,
}
/// Parsed slash command from markdown file
#[derive(Debug, Clone)]
pub struct SlashCommand {
pub name: String,
pub description: Option<String>,
pub argument_hint: Option<String>,
pub allowed_tools: Option<Vec<String>>,
pub body: String, // Markdown content after frontmatter
pub source_path: PathBuf,
}
/// Parsed agent definition from markdown file
#[derive(Debug, Clone)]
pub struct AgentDefinition {
pub name: String,
pub description: String,
pub tools: Vec<String>, // Tool whitelist
pub model: Option<String>, // haiku, sonnet, opus
pub color: Option<String>,
pub system_prompt: String, // Markdown body
pub source_path: PathBuf,
}
/// Parsed skill definition
#[derive(Debug, Clone)]
pub struct Skill {
pub name: String,
pub description: String,
pub version: Option<String>,
pub content: String, // Core SKILL.md content
pub references: Vec<PathBuf>, // Reference files
pub examples: Vec<PathBuf>, // Example files
pub source_path: PathBuf,
}
/// YAML frontmatter for command files
#[derive(Deserialize)]
struct CommandFrontmatter {
description: Option<String>,
#[serde(rename = "argument-hint")]
argument_hint: Option<String>,
#[serde(rename = "allowed-tools")]
allowed_tools: Option<String>,
}
/// YAML frontmatter for agent files
#[derive(Deserialize)]
struct AgentFrontmatter {
name: String,
description: String,
#[serde(default)]
tools: Vec<String>,
model: Option<String>,
color: Option<String>,
}
/// YAML frontmatter for skill files
#[derive(Deserialize)]
struct SkillFrontmatter {
name: String,
description: String,
version: Option<String>,
}
/// Parse YAML frontmatter from markdown content
fn parse_frontmatter<T: serde::de::DeserializeOwned>(content: &str) -> Result<(T, String)> {
if !content.starts_with("---") {
return Err(eyre!("No frontmatter found"));
}
let parts: Vec<&str> = content.splitn(3, "---").collect();
if parts.len() < 3 {
return Err(eyre!("Invalid frontmatter format"));
}
let frontmatter: T = serde_yaml::from_str(parts[1].trim())?;
let body = parts[2].trim().to_string();
Ok((frontmatter, body))
}
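// Expected shape of a command file, mirroring the fixture used in the tests below:
//
//   ---
//   description: A test command
//   argument-hint: <file>
//   allowed-tools: read,write
//   ---
//
//   This is a test command body.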
/// A loaded plugin with its manifest and base path
#[derive(Debug, Clone)]
pub struct Plugin {
pub manifest: PluginManifest,
pub base_path: PathBuf,
}
impl Plugin {
/// Get the path to a command file
pub fn command_path(&self, command_name: &str) -> PathBuf {
self.base_path.join("commands").join(format!("{}.md", command_name))
}
/// Get the path to an agent file
pub fn agent_path(&self, agent_name: &str) -> PathBuf {
self.base_path.join("agents").join(format!("{}.md", agent_name))
}
/// Get the path to a skill file
pub fn skill_path(&self, skill_name: &str) -> PathBuf {
self.base_path.join("skills").join(skill_name).join("SKILL.md")
}
/// Get the path to a hook script
pub fn hook_path(&self, hook_name: &str) -> Option<PathBuf> {
self.manifest.hooks.get(hook_name).map(|path| {
self.base_path.join("hooks").join(path)
})
}
/// Parse a command file
pub fn parse_command(&self, name: &str) -> Result<SlashCommand> {
let path = self.command_path(name);
let content = fs::read_to_string(&path)?;
let (fm, body): (CommandFrontmatter, String) = parse_frontmatter(&content)?;
let allowed_tools = fm.allowed_tools.map(|s| {
s.split(',').map(|t| t.trim().to_string()).collect()
});
Ok(SlashCommand {
name: name.to_string(),
description: fm.description,
argument_hint: fm.argument_hint,
allowed_tools,
body,
source_path: path,
})
}
/// Parse an agent file
pub fn parse_agent(&self, name: &str) -> Result<AgentDefinition> {
let path = self.agent_path(name);
let content = fs::read_to_string(&path)?;
let (fm, body): (AgentFrontmatter, String) = parse_frontmatter(&content)?;
Ok(AgentDefinition {
name: fm.name,
description: fm.description,
tools: fm.tools,
model: fm.model,
color: fm.color,
system_prompt: body,
source_path: path,
})
}
/// Parse a skill file
pub fn parse_skill(&self, name: &str) -> Result<Skill> {
let path = self.skill_path(name);
let content = fs::read_to_string(&path)?;
let (fm, body): (SkillFrontmatter, String) = parse_frontmatter(&content)?;
// Discover reference and example files in the skill directory
let skill_dir = self.base_path.join("skills").join(name);
let references_dir = skill_dir.join("references");
let examples_dir = skill_dir.join("examples");
let references = if references_dir.exists() {
fs::read_dir(&references_dir)
.into_iter()
.flatten()
.filter_map(|e| e.ok())
.map(|e| e.path())
.collect()
} else {
Vec::new()
};
let examples = if examples_dir.exists() {
fs::read_dir(&examples_dir)
.into_iter()
.flatten()
.filter_map(|e| e.ok())
.map(|e| e.path())
.collect()
} else {
Vec::new()
};
Ok(Skill {
name: fm.name,
description: fm.description,
version: fm.version,
content: body,
references,
examples,
source_path: path,
})
}
/// Auto-discover commands in the commands/ directory
pub fn discover_commands(&self) -> Vec<String> {
let commands_dir = self.base_path.join("commands");
if !commands_dir.exists() {
return Vec::new();
}
std::fs::read_dir(&commands_dir)
.into_iter()
.flatten()
.filter_map(|e| e.ok())
.filter(|e| e.path().extension().map(|ext| ext == "md").unwrap_or(false))
.filter_map(|e| {
e.path().file_stem()
.map(|s| s.to_string_lossy().to_string())
})
.collect()
}
/// Auto-discover agents in the agents/ directory
pub fn discover_agents(&self) -> Vec<String> {
let agents_dir = self.base_path.join("agents");
if !agents_dir.exists() {
return Vec::new();
}
std::fs::read_dir(&agents_dir)
.into_iter()
.flatten()
.filter_map(|e| e.ok())
.filter(|e| e.path().extension().map(|ext| ext == "md").unwrap_or(false))
.filter_map(|e| {
e.path().file_stem()
.map(|s| s.to_string_lossy().to_string())
})
.collect()
}
/// Auto-discover skills in skills/*/SKILL.md
pub fn discover_skills(&self) -> Vec<String> {
let skills_dir = self.base_path.join("skills");
if !skills_dir.exists() {
return Vec::new();
}
std::fs::read_dir(&skills_dir)
.into_iter()
.flatten()
.filter_map(|e| e.ok())
.filter(|e| e.path().is_dir())
.filter(|e| e.path().join("SKILL.md").exists())
.filter_map(|e| {
e.path().file_name()
.map(|s| s.to_string_lossy().to_string())
})
.collect()
}
/// Get all commands (manifest + discovered)
pub fn all_command_names(&self) -> Vec<String> {
let mut names: std::collections::HashSet<String> =
self.manifest.commands.iter().cloned().collect();
names.extend(self.discover_commands());
names.into_iter().collect()
}
/// Get all agent names (manifest + discovered)
pub fn all_agent_names(&self) -> Vec<String> {
let mut names: std::collections::HashSet<String> =
self.manifest.agents.iter().cloned().collect();
names.extend(self.discover_agents());
names.into_iter().collect()
}
/// Get all skill names (manifest + discovered)
pub fn all_skill_names(&self) -> Vec<String> {
let mut names: std::collections::HashSet<String> =
self.manifest.skills.iter().cloned().collect();
names.extend(self.discover_skills());
names.into_iter().collect()
}
/// Load hooks configuration from hooks/hooks.json
pub fn load_hooks_config(&self) -> Result<Option<PluginHooksConfig>> {
let hooks_path = self.base_path.join("hooks").join("hooks.json");
if !hooks_path.exists() {
return Ok(None);
}
let content = fs::read_to_string(&hooks_path)?;
let config: PluginHooksConfig = serde_json::from_str(&content)?;
Ok(Some(config))
}
/// Collect this plugin's hooks for registration with a hook manager.
/// The returned (event, command, pattern, timeout) tuples match
/// HookManager::register_hook; registration happens in the caller, since the
/// hooks crate is not a dependency of this crate.
pub fn register_hooks_with_manager(&self, config: &PluginHooksConfig) -> Vec<(String, String, Option<String>, Option<u64>)> {
let mut hooks_to_register = Vec::new();
for (event_name, matchers) in &config.hooks {
for matcher in matchers {
for hook_def in &matcher.hooks {
if let Some(command) = &hook_def.command {
// Substitute ${CLAUDE_PLUGIN_ROOT}
let resolved = command.replace(
"${CLAUDE_PLUGIN_ROOT}",
&self.base_path.to_string_lossy()
);
hooks_to_register.push((
event_name.clone(),
resolved,
matcher.matcher.clone(),
hook_def.timeout,
));
}
}
}
}
hooks_to_register
}
}
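// Illustrative wiring (a sketch, not part of this file, since the plugins crate
// does not depend on hooks): the tuples returned by register_hooks_with_manager
// line up with HookManager::register_hook, so a caller that has both crates can do:
//
//   if let Some(config) = plugin.load_hooks_config()? {
//       for (event, command, pattern, timeout) in plugin.register_hooks_with_manager(&config) {
//           hook_manager.register_hook(event, command, pattern, timeout);
//       }
//   }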
/// Plugin loader and registry
pub struct PluginManager {
plugins: Vec<Plugin>,
plugin_dirs: Vec<PathBuf>,
}
impl PluginManager {
/// Create a new plugin manager with default plugin directories
pub fn new() -> Self {
let mut plugin_dirs = Vec::new();
// User plugins: ~/.config/owlen/plugins
if let Some(config_dir) = dirs::config_dir() {
plugin_dirs.push(config_dir.join("owlen").join("plugins"));
}
// Project plugins: .owlen/plugins
plugin_dirs.push(PathBuf::from(".owlen/plugins"));
Self {
plugins: Vec::new(),
plugin_dirs,
}
}
/// Create a plugin manager with custom plugin directories
pub fn with_dirs(plugin_dirs: Vec<PathBuf>) -> Self {
Self {
plugins: Vec::new(),
plugin_dirs,
}
}
/// Load all plugins from configured directories
pub fn load_all(&mut self) -> Result<()> {
let plugin_dirs = self.plugin_dirs.clone();
for dir in &plugin_dirs {
if !dir.exists() {
continue;
}
self.load_from_dir(dir)?;
}
Ok(())
}
/// Load plugins from a specific directory
fn load_from_dir(&mut self, dir: &Path) -> Result<()> {
// Walk directory looking for plugin.json files
for entry in WalkDir::new(dir)
.max_depth(2) // Don't recurse too deep
.into_iter()
.filter_map(|e| e.ok())
{
if entry.file_name() == "plugin.json" {
if let Some(plugin_dir) = entry.path().parent() {
match self.load_plugin(plugin_dir) {
Ok(plugin) => {
self.plugins.push(plugin);
}
Err(e) => {
eprintln!("Warning: Failed to load plugin from {:?}: {}", plugin_dir, e);
}
}
}
}
}
Ok(())
}
/// Load a single plugin from a directory
fn load_plugin(&self, plugin_dir: &Path) -> Result<Plugin> {
let manifest_path = plugin_dir.join("plugin.json");
let content = fs::read_to_string(&manifest_path)
.map_err(|e| eyre!("Failed to read plugin manifest: {}", e))?;
let manifest: PluginManifest = serde_json::from_str(&content)
.map_err(|e| eyre!("Failed to parse plugin manifest: {}", e))?;
Ok(Plugin {
manifest,
base_path: plugin_dir.to_path_buf(),
})
}
/// Get all loaded plugins
pub fn plugins(&self) -> &[Plugin] {
&self.plugins
}
/// Find a plugin by name
pub fn find_plugin(&self, name: &str) -> Option<&Plugin> {
self.plugins.iter().find(|p| p.manifest.name == name)
}
/// Get all available commands from all plugins
pub fn all_commands(&self) -> HashMap<String, PathBuf> {
let mut commands = HashMap::new();
for plugin in &self.plugins {
for cmd_name in &plugin.manifest.commands {
let path = plugin.command_path(cmd_name);
if path.exists() {
commands.insert(cmd_name.clone(), path);
}
}
}
commands
}
/// Get all available agents from all plugins
pub fn all_agents(&self) -> HashMap<String, PathBuf> {
let mut agents = HashMap::new();
for plugin in &self.plugins {
for agent_name in &plugin.manifest.agents {
let path = plugin.agent_path(agent_name);
if path.exists() {
agents.insert(agent_name.clone(), path);
}
}
}
agents
}
/// Get all available skills from all plugins
pub fn all_skills(&self) -> HashMap<String, PathBuf> {
let mut skills = HashMap::new();
for plugin in &self.plugins {
for skill_name in &plugin.manifest.skills {
let path = plugin.skill_path(skill_name);
if path.exists() {
skills.insert(skill_name.clone(), path);
}
}
}
skills
}
/// Get all MCP servers from all plugins
pub fn all_mcp_servers(&self) -> Vec<(String, &McpServerConfig)> {
let mut servers = Vec::new();
for plugin in &self.plugins {
for server in &plugin.manifest.mcp_servers {
servers.push((plugin.manifest.name.clone(), server));
}
}
servers
}
/// Get all parsed commands
pub fn load_all_commands(&self) -> Vec<SlashCommand> {
let mut commands = Vec::new();
for plugin in &self.plugins {
for cmd_name in &plugin.manifest.commands {
if let Ok(cmd) = plugin.parse_command(cmd_name) {
commands.push(cmd);
}
}
}
commands
}
/// Get all parsed agents
pub fn load_all_agents(&self) -> Vec<AgentDefinition> {
let mut agents = Vec::new();
for plugin in &self.plugins {
for agent_name in &plugin.manifest.agents {
if let Ok(agent) = plugin.parse_agent(agent_name) {
agents.push(agent);
}
}
}
agents
}
/// Get all parsed skills
pub fn load_all_skills(&self) -> Vec<Skill> {
let mut skills = Vec::new();
for plugin in &self.plugins {
for skill_name in &plugin.manifest.skills {
if let Ok(skill) = plugin.parse_skill(skill_name) {
skills.push(skill);
}
}
}
skills
}
}
impl Default for PluginManager {
fn default() -> Self {
Self::new()
}
}
#[cfg(test)]
mod tests {
use super::*;
use std::fs;
fn create_test_plugin(dir: &Path) -> Result<()> {
fs::create_dir_all(dir)?;
fs::create_dir_all(dir.join("commands"))?;
fs::create_dir_all(dir.join("agents"))?;
fs::create_dir_all(dir.join("hooks"))?;
let manifest = PluginManifest {
name: "test-plugin".to_string(),
version: "1.0.0".to_string(),
description: Some("A test plugin".to_string()),
author: Some("Test Author".to_string()),
commands: vec!["test-cmd".to_string()],
agents: vec!["test-agent".to_string()],
skills: vec![],
hooks: {
let mut h = HashMap::new();
h.insert("PreToolUse".to_string(), "pre_tool_use.sh".to_string());
h
},
mcp_servers: vec![],
};
fs::write(
dir.join("plugin.json"),
serde_json::to_string_pretty(&manifest)?,
)?;
fs::write(
dir.join("commands/test-cmd.md"),
"---\ndescription: A test command\nargument-hint: <file>\nallowed-tools: read,write\n---\n\nThis is a test command body.",
)?;
fs::write(
dir.join("agents/test-agent.md"),
"---\nname: test-agent\ndescription: A test agent\ntools:\n - read\n - write\nmodel: sonnet\ncolor: blue\n---\n\nYou are a helpful test agent.",
)?;
Ok(())
}
#[test]
fn test_load_plugin() -> Result<()> {
let temp_dir = tempfile::tempdir()?;
let plugin_dir = temp_dir.path().join("test-plugin");
create_test_plugin(&plugin_dir)?;
let manager = PluginManager::with_dirs(vec![temp_dir.path().to_path_buf()]);
let plugin = manager.load_plugin(&plugin_dir)?;
assert_eq!(plugin.manifest.name, "test-plugin");
assert_eq!(plugin.manifest.version, "1.0.0");
assert_eq!(plugin.manifest.commands, vec!["test-cmd"]);
assert_eq!(plugin.manifest.agents, vec!["test-agent"]);
Ok(())
}
#[test]
fn test_load_all_plugins() -> Result<()> {
let temp_dir = tempfile::tempdir()?;
let plugin_dir = temp_dir.path().join("test-plugin");
create_test_plugin(&plugin_dir)?;
let mut manager = PluginManager::with_dirs(vec![temp_dir.path().to_path_buf()]);
manager.load_all()?;
assert_eq!(manager.plugins().len(), 1);
assert_eq!(manager.plugins()[0].manifest.name, "test-plugin");
Ok(())
}
#[test]
fn test_find_plugin() -> Result<()> {
let temp_dir = tempfile::tempdir()?;
let plugin_dir = temp_dir.path().join("test-plugin");
create_test_plugin(&plugin_dir)?;
let mut manager = PluginManager::with_dirs(vec![temp_dir.path().to_path_buf()]);
manager.load_all()?;
let plugin = manager.find_plugin("test-plugin");
assert!(plugin.is_some());
assert_eq!(plugin.unwrap().manifest.name, "test-plugin");
let not_found = manager.find_plugin("nonexistent");
assert!(not_found.is_none());
Ok(())
}
#[test]
fn test_all_commands() -> Result<()> {
let temp_dir = tempfile::tempdir()?;
let plugin_dir = temp_dir.path().join("test-plugin");
create_test_plugin(&plugin_dir)?;
let mut manager = PluginManager::with_dirs(vec![temp_dir.path().to_path_buf()]);
manager.load_all()?;
let commands = manager.all_commands();
assert_eq!(commands.len(), 1);
assert!(commands.contains_key("test-cmd"));
Ok(())
}
#[test]
fn test_parse_command() -> Result<()> {
let temp_dir = tempfile::tempdir()?;
let plugin_dir = temp_dir.path().join("test-plugin");
create_test_plugin(&plugin_dir)?;
let manager = PluginManager::with_dirs(vec![temp_dir.path().to_path_buf()]);
let plugin = manager.load_plugin(&plugin_dir)?;
let cmd = plugin.parse_command("test-cmd")?;
assert_eq!(cmd.name, "test-cmd");
assert_eq!(cmd.description, Some("A test command".to_string()));
assert_eq!(cmd.argument_hint, Some("<file>".to_string()));
assert_eq!(cmd.allowed_tools, Some(vec!["read".to_string(), "write".to_string()]));
assert_eq!(cmd.body, "This is a test command body.");
Ok(())
}
#[test]
fn test_parse_agent() -> Result<()> {
let temp_dir = tempfile::tempdir()?;
let plugin_dir = temp_dir.path().join("test-plugin");
create_test_plugin(&plugin_dir)?;
let manager = PluginManager::with_dirs(vec![temp_dir.path().to_path_buf()]);
let plugin = manager.load_plugin(&plugin_dir)?;
let agent = plugin.parse_agent("test-agent")?;
assert_eq!(agent.name, "test-agent");
assert_eq!(agent.description, "A test agent");
assert_eq!(agent.tools, vec!["read", "write"]);
assert_eq!(agent.model, Some("sonnet".to_string()));
assert_eq!(agent.color, Some("blue".to_string()));
assert_eq!(agent.system_prompt, "You are a helpful test agent.");
Ok(())
}
#[test]
fn test_load_all_commands() -> Result<()> {
let temp_dir = tempfile::tempdir()?;
let plugin_dir = temp_dir.path().join("test-plugin");
create_test_plugin(&plugin_dir)?;
let mut manager = PluginManager::with_dirs(vec![temp_dir.path().to_path_buf()]);
manager.load_all()?;
let commands = manager.load_all_commands();
assert_eq!(commands.len(), 1);
assert_eq!(commands[0].name, "test-cmd");
assert_eq!(commands[0].description, Some("A test command".to_string()));
Ok(())
}
#[test]
fn test_load_all_agents() -> Result<()> {
let temp_dir = tempfile::tempdir()?;
let plugin_dir = temp_dir.path().join("test-plugin");
create_test_plugin(&plugin_dir)?;
let mut manager = PluginManager::with_dirs(vec![temp_dir.path().to_path_buf()]);
manager.load_all()?;
let agents = manager.load_all_agents();
assert_eq!(agents.len(), 1);
assert_eq!(agents[0].name, "test-agent");
assert_eq!(agents[0].description, "A test agent");
Ok(())
}
}

View File

@@ -0,0 +1,175 @@
// End-to-end integration test for plugin hooks
use color_eyre::eyre::Result;
use plugins::PluginManager;
use std::fs;
use tempfile::TempDir;
fn create_test_plugin_with_hooks(plugin_dir: &std::path::Path) -> Result<()> {
fs::create_dir_all(plugin_dir)?;
// Create plugin manifest
let manifest = serde_json::json!({
"name": "test-hook-plugin",
"version": "1.0.0",
"description": "Test plugin with hooks",
"commands": [],
"agents": [],
"skills": [],
"hooks": {},
"mcp_servers": []
});
fs::write(
plugin_dir.join("plugin.json"),
serde_json::to_string_pretty(&manifest)?,
)?;
// Create hooks directory and hooks.json
let hooks_dir = plugin_dir.join("hooks");
fs::create_dir_all(&hooks_dir)?;
let hooks_config = serde_json::json!({
"description": "Validate edit and write operations",
"hooks": {
"PreToolUse": [
{
"matcher": "Edit|Write",
"hooks": [
{
"type": "command",
"command": "python3 ${CLAUDE_PLUGIN_ROOT}/hooks/validate.py",
"timeout": 5000
}
]
},
{
"matcher": "Bash",
"hooks": [
{
"type": "command",
"command": "echo 'Bash hook' && exit 0"
}
]
}
],
"PostToolUse": [
{
"hooks": [
{
"type": "command",
"command": "echo 'Post-tool hook' && exit 0"
}
]
}
]
}
});
fs::write(
hooks_dir.join("hooks.json"),
serde_json::to_string_pretty(&hooks_config)?,
)?;
Ok(())
}
#[test]
fn test_load_plugin_hooks_config() -> Result<()> {
let temp_dir = TempDir::new()?;
let plugin_dir = temp_dir.path().join("test-plugin");
create_test_plugin_with_hooks(&plugin_dir)?;
// Load all plugins
let mut plugin_manager = PluginManager::with_dirs(vec![temp_dir.path().to_path_buf()]);
plugin_manager.load_all()?;
assert_eq!(plugin_manager.plugins().len(), 1);
let plugin = &plugin_manager.plugins()[0];
// Load hooks config
let hooks_config = plugin.load_hooks_config()?;
assert!(hooks_config.is_some());
let config = hooks_config.unwrap();
assert_eq!(config.description, Some("Validate edit and write operations".to_string()));
assert!(config.hooks.contains_key("PreToolUse"));
assert!(config.hooks.contains_key("PostToolUse"));
// Check PreToolUse hooks
let pre_tool_hooks = &config.hooks["PreToolUse"];
assert_eq!(pre_tool_hooks.len(), 2);
// First matcher: Edit|Write
assert_eq!(pre_tool_hooks[0].matcher, Some("Edit|Write".to_string()));
assert_eq!(pre_tool_hooks[0].hooks.len(), 1);
assert_eq!(pre_tool_hooks[0].hooks[0].hook_type, "command");
assert!(pre_tool_hooks[0].hooks[0].command.as_ref().unwrap().contains("validate.py"));
// Second matcher: Bash
assert_eq!(pre_tool_hooks[1].matcher, Some("Bash".to_string()));
assert_eq!(pre_tool_hooks[1].hooks.len(), 1);
Ok(())
}
#[test]
fn test_plugin_hooks_substitution() -> Result<()> {
let temp_dir = TempDir::new()?;
let plugin_dir = temp_dir.path().join("test-plugin");
create_test_plugin_with_hooks(&plugin_dir)?;
// Load all plugins
let mut plugin_manager = PluginManager::with_dirs(vec![temp_dir.path().to_path_buf()]);
plugin_manager.load_all()?;
assert_eq!(plugin_manager.plugins().len(), 1);
let plugin = &plugin_manager.plugins()[0];
// Load hooks config and register
let hooks_config = plugin.load_hooks_config()?.unwrap();
let hooks_to_register = plugin.register_hooks_with_manager(&hooks_config);
// Check that ${CLAUDE_PLUGIN_ROOT} was substituted
let edit_write_hook = hooks_to_register.iter()
.find(|(event, _, pattern, _)| {
event == "PreToolUse" && pattern.as_ref().map(|p| p.contains("Edit")).unwrap_or(false)
})
.unwrap();
// The command should have the plugin path substituted
assert!(edit_write_hook.1.contains(&plugin_dir.to_string_lossy().to_string()));
assert!(edit_write_hook.1.contains("validate.py"));
assert!(!edit_write_hook.1.contains("${CLAUDE_PLUGIN_ROOT}"));
Ok(())
}
#[test]
fn test_multiple_plugins_with_hooks() -> Result<()> {
let temp_dir = TempDir::new()?;
// Create two plugins with hooks
let plugin1_dir = temp_dir.path().join("plugin1");
create_test_plugin_with_hooks(&plugin1_dir)?;
let plugin2_dir = temp_dir.path().join("plugin2");
create_test_plugin_with_hooks(&plugin2_dir)?;
// Load all plugins
let mut plugin_manager = PluginManager::with_dirs(vec![temp_dir.path().to_path_buf()]);
plugin_manager.load_all()?;
assert_eq!(plugin_manager.plugins().len(), 2);
// Collect all hooks from all plugins
let mut total_hooks = 0;
for plugin in plugin_manager.plugins() {
if let Ok(Some(hooks_config)) = plugin.load_hooks_config() {
let hooks = plugin.register_hooks_with_manager(&hooks_config);
total_hooks += hooks.len();
}
}
// Each plugin has 3 hooks (2 PreToolUse + 1 PostToolUse)
assert_eq!(total_hooks, 6);
Ok(())
}

View File

@@ -0,0 +1,11 @@
[package]
name = "tools-ask"
version = "0.1.0"
edition.workspace = true
license.workspace = true
[dependencies]
serde = { version = "1.0", features = ["derive"] }
serde_json = "1.0"
tokio = { version = "1", features = ["sync"] }
color-eyre = "0.6"

View File

@@ -0,0 +1,60 @@
//! AskUserQuestion tool for interactive user input
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use tokio::sync::{mpsc, oneshot};
/// A question option
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct QuestionOption {
pub label: String,
pub description: String,
}
/// A question to ask the user
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Question {
pub question: String,
pub header: String,
pub options: Vec<QuestionOption>,
pub multi_select: bool,
}
/// Request sent to the UI to ask questions
#[derive(Debug)]
pub struct AskRequest {
pub questions: Vec<Question>,
pub response_tx: oneshot::Sender<HashMap<String, String>>,
}
/// Channel for sending ask requests to the UI
pub type AskSender = mpsc::Sender<AskRequest>;
pub type AskReceiver = mpsc::Receiver<AskRequest>;
/// Create a channel pair for ask requests
pub fn create_ask_channel() -> (AskSender, AskReceiver) {
mpsc::channel(1)
}
/// Ask the user questions (called by agent)
pub async fn ask_user(
sender: &AskSender,
questions: Vec<Question>,
) -> color_eyre::Result<HashMap<String, String>> {
let (response_tx, response_rx) = oneshot::channel();
sender.send(AskRequest { questions, response_tx }).await
.map_err(|_| color_eyre::eyre::eyre!("Failed to send ask request"))?;
response_rx.await
.map_err(|_| color_eyre::eyre::eyre!("Failed to receive ask response"))
}
/// Parse questions from JSON tool input
pub fn parse_questions(input: &serde_json::Value) -> color_eyre::Result<Vec<Question>> {
let questions = input.get("questions")
.ok_or_else(|| color_eyre::eyre::eyre!("Missing 'questions' field"))?;
serde_json::from_value(questions.clone())
.map_err(|e| color_eyre::eyre::eyre!("Invalid questions format: {}", e))
}
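/// Illustrative UI-side consumer (a sketch, not part of this diff): drain ask
/// requests from the channel and reply through the oneshot sender. Keying the
/// answers by question header is an assumption; the agent side only sees a
/// HashMap<String, String>.
pub async fn handle_ask_requests_sketch(mut receiver: AskReceiver) {
    while let Some(request) = receiver.recv().await {
        let mut answers = HashMap::new();
        for question in &request.questions {
            // A real UI would render question.question and question.options and
            // record the user's choice; the first option stands in here.
            if let Some(option) = question.options.first() {
                answers.insert(question.header.clone(), option.label.clone());
            }
        }
        // Ignore the send error if the agent side has already gone away.
        let _ = request.response_tx.send(answers);
    }
}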

View File

@@ -6,9 +6,11 @@ license.workspace = true
rust-version.workspace = true
[dependencies]
tokio = { version = "1.39", features = ["process", "io-util", "time", "sync"] }
tokio = { version = "1.39", features = ["process", "io-util", "time", "sync", "rt"] }
color-eyre = "0.6"
tempfile = "3.23.0"
parking_lot = "0.12"
uuid = { version = "1.0", features = ["v4"] }
[dev-dependencies]
tokio = { version = "1.39", features = ["macros", "rt-multi-thread"] }

View File

@@ -1,5 +1,8 @@
use color_eyre::eyre::{Result, eyre};
use std::collections::HashMap;
use std::process::Stdio;
use std::sync::Arc;
use parking_lot::RwLock;
use tokio::io::{AsyncBufReadExt, AsyncWriteExt, BufReader};
use tokio::process::{Child, Command};
use tokio::sync::Mutex;
@@ -19,6 +22,7 @@ pub struct CommandOutput {
pub struct BashSession {
child: Mutex<Child>,
last_output: Option<String>,
}
impl BashSession {
@@ -40,6 +44,7 @@ impl BashSession {
Ok(Self {
child: Mutex::new(child),
last_output: None,
})
}
@@ -54,7 +59,13 @@ impl BashSession {
let result = timeout(timeout_duration, self.execute_internal(command)).await;
match result {
Ok(output) => {
// Store the output for potential retrieval via BashOutput tool
let combined = format!("{}{}", output.as_ref().map(|o| o.stdout.as_str()).unwrap_or(""),
output.as_ref().map(|o| o.stderr.as_str()).unwrap_or(""));
self.last_output = Some(combined);
output
},
Err(_) => Err(eyre!("Command timed out after {}ms", timeout_duration.as_millis())),
}
}
@@ -158,6 +169,106 @@ impl BashSession {
}
}
/// Manages background bash shells by ID
#[derive(Clone, Default)]
pub struct ShellManager {
shells: Arc<RwLock<HashMap<String, BashSession>>>,
}
impl ShellManager {
pub fn new() -> Self {
Self::default()
}
/// Start a new background shell, returns shell ID
pub async fn start_shell(&self) -> Result<String> {
let id = uuid::Uuid::new_v4().to_string();
let session = BashSession::new().await?;
self.shells.write().insert(id.clone(), session);
Ok(id)
}
/// Execute command in background shell
pub async fn execute(&self, shell_id: &str, command: &str, timeout: Option<Duration>) -> Result<CommandOutput> {
// The parking_lot guards must not be held across an await point, so check that
// the shell exists first and then temporarily take the session out of the map.
let exists = self.shells.read().contains_key(shell_id);
if !exists {
return Err(eyre!("Shell not found: {}", shell_id));
}
// BashSession::execute is async, so the session is removed from the map, run
// without the lock held, and re-inserted afterwards.
let timeout_ms = timeout.map(|d| d.as_millis() as u64);
// Take temporary ownership for execution
let mut session = {
let mut shells = self.shells.write();
shells.remove(shell_id)
.ok_or_else(|| eyre!("Shell not found: {}", shell_id))?
};
// Execute without holding the lock
let result = session.execute(command, timeout_ms).await;
// Put the session back
self.shells.write().insert(shell_id.to_string(), session);
result
}
/// Get output from a shell (BashOutput tool)
pub fn get_output(&self, shell_id: &str) -> Result<Option<String>> {
let shells = self.shells.read();
let session = shells.get(shell_id)
.ok_or_else(|| eyre!("Shell not found: {}", shell_id))?;
// Return any buffered output
Ok(session.last_output.clone())
}
/// Kill a shell (KillShell tool)
pub fn kill_shell(&self, shell_id: &str) -> Result<()> {
let mut shells = self.shells.write();
if shells.remove(shell_id).is_some() {
Ok(())
} else {
Err(eyre!("Shell not found: {}", shell_id))
}
}
/// List active shells
pub fn list_shells(&self) -> Vec<String> {
self.shells.read().keys().cloned().collect()
}
}
/// Start a background bash command, returns shell ID
pub async fn run_background(manager: &ShellManager, command: &str) -> Result<String> {
let shell_id = manager.start_shell().await?;
// Execute in background (non-blocking)
tokio::spawn({
let manager = manager.clone();
let command = command.to_string();
let shell_id = shell_id.clone();
async move {
let _ = manager.execute(&shell_id, &command, None).await;
}
});
Ok(shell_id)
}
/// Get output from background shell (BashOutput tool)
pub fn bash_output(manager: &ShellManager, shell_id: &str) -> Result<String> {
manager.get_output(shell_id)?
.ok_or_else(|| eyre!("No output available"))
}
/// Kill a background shell (KillShell tool)
pub fn kill_shell(manager: &ShellManager, shell_id: &str) -> Result<String> {
manager.kill_shell(shell_id)?;
Ok(format!("Shell {} terminated", shell_id))
}
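/// Illustrative lifecycle (a sketch, not part of this diff): start a background
/// command, poll its buffered output, then terminate the shell.
pub async fn background_shell_roundtrip_sketch(manager: &ShellManager) -> Result<String> {
    let shell_id = run_background(manager, "echo hello").await?;
    // Output is only buffered once the command finishes, so polling immediately
    // may return an error rather than the captured text.
    let captured = bash_output(manager, &shell_id).unwrap_or_default();
    kill_shell(manager, &shell_id)?;
    Ok(captured)
}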
#[cfg(test)]
mod tests {
use super::*;

View File

@@ -13,6 +13,8 @@ grep-regex = "0.1"
grep-searcher = "0.1"
color-eyre = "0.6"
similar = "2.7"
serde = { version = "1.0", features = ["derive"] }
humantime = "2.1"
[dev-dependencies]
tempfile = "3.23.0"

View File

@@ -3,6 +3,7 @@ use ignore::WalkBuilder;
use grep_regex::RegexMatcher;
use grep_searcher::{sinks::UTF8, SearcherBuilder};
use globset::Glob;
use serde::{Deserialize, Serialize};
use std::path::Path;
pub fn read_file(path: &str) -> Result<String> {
@@ -128,3 +129,80 @@ pub fn grep(root: &str, pattern: &str) -> Result<Vec<(String, usize, String)>> {
Ok(results)
}
/// Edit operation for MultiEdit
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct EditOperation {
pub old_string: String,
pub new_string: String,
}
/// Perform multiple edits on a file atomically
pub fn multi_edit_file(path: &str, edits: Vec<EditOperation>) -> Result<String> {
let content = std::fs::read_to_string(path)?;
let mut new_content = content.clone();
// Apply edits in order
for edit in &edits {
if !new_content.contains(&edit.old_string) {
return Err(eyre!("String not found: '{}'", edit.old_string));
}
new_content = new_content.replacen(&edit.old_string, &edit.new_string, 1);
}
// Create backup and write
let backup_path = format!("{}.bak", path);
std::fs::copy(path, &backup_path)?;
std::fs::write(path, &new_content)?;
Ok(format!("Applied {} edits to {}", edits.len(), path))
}
/// Entry in directory listing
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DirEntry {
pub name: String,
pub is_dir: bool,
pub size: Option<u64>,
pub modified: Option<String>,
}
/// List contents of a directory
pub fn list_directory(path: &str, show_hidden: bool) -> Result<Vec<DirEntry>> {
let entries = std::fs::read_dir(path)?;
let mut result = Vec::new();
for entry in entries {
let entry = entry?;
let name = entry.file_name().to_string_lossy().to_string();
// Skip hidden files unless requested
if !show_hidden && name.starts_with('.') {
continue;
}
let metadata = entry.metadata()?;
let modified = metadata.modified().ok().map(|t| {
// Format as RFC 3339 (a profile of ISO 8601)
humantime::format_rfc3339(t).to_string()
});
result.push(DirEntry {
name,
is_dir: metadata.is_dir(),
size: if metadata.is_file() { Some(metadata.len()) } else { None },
modified,
});
}
// Sort directories first, then alphabetically
result.sort_by(|a, b| {
match (a.is_dir, b.is_dir) {
(true, false) => std::cmp::Ordering::Less,
(false, true) => std::cmp::Ordering::Greater,
_ => a.name.cmp(&b.name),
}
});
Ok(result)
}

View File

@@ -1,4 +1,4 @@
use tools_fs::{read_file, glob_list, grep, write_file, edit_file, multi_edit_file, list_directory, EditOperation};
use std::fs;
use tempfile::tempdir;
@@ -102,3 +102,123 @@ fn edit_file_fails_on_no_match() {
let err_msg = result.unwrap_err().to_string();
assert!(err_msg.contains("not found") || err_msg.contains("String to replace"));
}
#[test]
fn multi_edit_file_applies_multiple_edits() {
let dir = tempdir().unwrap();
let file_path = dir.path().join("test.txt");
let original = "line 1\nline 2\nline 3\n";
fs::write(&file_path, original).unwrap();
let edits = vec![
EditOperation {
old_string: "line 1".to_string(),
new_string: "modified 1".to_string(),
},
EditOperation {
old_string: "line 2".to_string(),
new_string: "modified 2".to_string(),
},
];
let result = multi_edit_file(file_path.to_str().unwrap(), edits).unwrap();
assert!(result.contains("Applied 2 edits"));
let content = read_file(file_path.to_str().unwrap()).unwrap();
assert_eq!(content, "modified 1\nmodified 2\nline 3\n");
// Backup file should exist
let backup_path = format!("{}.bak", file_path.display());
assert!(std::path::Path::new(&backup_path).exists());
}
#[test]
fn multi_edit_file_fails_on_missing_string() {
let dir = tempdir().unwrap();
let file_path = dir.path().join("test.txt");
let original = "line 1\nline 2\n";
fs::write(&file_path, original).unwrap();
let edits = vec![
EditOperation {
old_string: "line 1".to_string(),
new_string: "modified 1".to_string(),
},
EditOperation {
old_string: "nonexistent".to_string(),
new_string: "modified".to_string(),
},
];
let result = multi_edit_file(file_path.to_str().unwrap(), edits);
assert!(result.is_err());
let err_msg = result.unwrap_err().to_string();
assert!(err_msg.contains("String not found"));
}
#[test]
fn list_directory_shows_files_and_dirs() {
let dir = tempdir().unwrap();
let root = dir.path();
// Create test structure
fs::write(root.join("file1.txt"), "content").unwrap();
fs::write(root.join("file2.txt"), "content").unwrap();
fs::create_dir(root.join("subdir")).unwrap();
fs::write(root.join(".hidden"), "hidden content").unwrap();
let entries = list_directory(root.to_str().unwrap(), false).unwrap();
// Should find 2 files and 1 directory (hidden file excluded)
assert_eq!(entries.len(), 3);
// Verify directory appears first (sorted)
assert_eq!(entries[0].name, "subdir");
assert!(entries[0].is_dir);
// Verify files
let file_names: Vec<_> = entries.iter().skip(1).map(|e| e.name.as_str()).collect();
assert!(file_names.contains(&"file1.txt"));
assert!(file_names.contains(&"file2.txt"));
// Hidden file should not be present
assert!(!entries.iter().any(|e| e.name == ".hidden"));
}
#[test]
fn list_directory_shows_hidden_when_requested() {
let dir = tempdir().unwrap();
let root = dir.path();
fs::write(root.join("visible.txt"), "content").unwrap();
fs::write(root.join(".hidden"), "hidden content").unwrap();
let entries = list_directory(root.to_str().unwrap(), true).unwrap();
// Should find both files
assert!(entries.iter().any(|e| e.name == "visible.txt"));
assert!(entries.iter().any(|e| e.name == ".hidden"));
}
#[test]
fn list_directory_includes_metadata() {
let dir = tempdir().unwrap();
let root = dir.path();
fs::write(root.join("test.txt"), "hello world").unwrap();
fs::create_dir(root.join("testdir")).unwrap();
let entries = list_directory(root.to_str().unwrap(), false).unwrap();
// Directory entry should have no size
let dir_entry = entries.iter().find(|e| e.name == "testdir").unwrap();
assert!(dir_entry.is_dir);
assert!(dir_entry.size.is_none());
assert!(dir_entry.modified.is_some());
// File entry should have size
let file_entry = entries.iter().find(|e| e.name == "test.txt").unwrap();
assert!(!file_entry.is_dir);
assert_eq!(file_entry.size, Some(11)); // "hello world" is 11 bytes
assert!(file_entry.modified.is_some());
}

View File

@@ -0,0 +1,14 @@
[package]
name = "tools-notebook"
version = "0.1.0"
edition.workspace = true
license.workspace = true
rust-version.workspace = true
[dependencies]
serde = { version = "1", features = ["derive"] }
serde_json = "1"
color-eyre = "0.6"
[dev-dependencies]
tempfile = "3.23.0"

View File

@@ -0,0 +1,175 @@
use color_eyre::eyre::{Result, eyre};
use serde::{Deserialize, Serialize};
use serde_json::Value;
use std::collections::HashMap;
use std::fs;
use std::path::Path;
/// Jupyter notebook structure
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Notebook {
pub cells: Vec<Cell>,
pub metadata: NotebookMetadata,
pub nbformat: i32,
pub nbformat_minor: i32,
}
/// Notebook cell
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Cell {
pub cell_type: String,
#[serde(default, skip_serializing_if = "Option::is_none")]
pub execution_count: Option<i32>,
pub metadata: HashMap<String, Value>,
pub source: Vec<String>,
#[serde(default, skip_serializing_if = "Vec::is_empty")]
pub outputs: Vec<Output>,
#[serde(default, skip_serializing_if = "Option::is_none")]
pub id: Option<String>,
}
/// Cell output
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Output {
pub output_type: String,
#[serde(flatten)]
pub data: HashMap<String, Value>,
}
/// Notebook metadata
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct NotebookMetadata {
#[serde(default, skip_serializing_if = "Option::is_none")]
pub kernelspec: Option<KernelSpec>,
#[serde(default, skip_serializing_if = "Option::is_none")]
pub language_info: Option<LanguageInfo>,
#[serde(flatten)]
pub extra: HashMap<String, Value>,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct KernelSpec {
pub display_name: String,
pub language: String,
pub name: String,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct LanguageInfo {
pub name: String,
#[serde(default, skip_serializing_if = "Option::is_none")]
pub version: Option<String>,
#[serde(flatten)]
pub extra: HashMap<String, Value>,
}
/// Read a Jupyter notebook from a file
pub fn read_notebook<P: AsRef<Path>>(path: P) -> Result<Notebook> {
let content = fs::read_to_string(path)?;
let notebook: Notebook = serde_json::from_str(&content)?;
Ok(notebook)
}
/// Write a Jupyter notebook to a file
pub fn write_notebook<P: AsRef<Path>>(path: P, notebook: &Notebook) -> Result<()> {
let content = serde_json::to_string_pretty(notebook)?;
fs::write(path, content)?;
Ok(())
}
/// Edit operations for notebooks
pub enum NotebookEdit {
/// Replace cell at index with new source
EditCell { index: usize, source: Vec<String> },
/// Add a new cell at index
AddCell { index: usize, cell: Cell },
/// Delete cell at index
DeleteCell { index: usize },
}
/// Apply an edit to a notebook
pub fn edit_notebook(notebook: &mut Notebook, edit: NotebookEdit) -> Result<()> {
match edit {
NotebookEdit::EditCell { index, source } => {
if index >= notebook.cells.len() {
return Err(eyre!("Cell index {} out of bounds (notebook has {} cells)", index, notebook.cells.len()));
}
notebook.cells[index].source = source;
}
NotebookEdit::AddCell { index, cell } => {
if index > notebook.cells.len() {
return Err(eyre!("Cell index {} out of bounds (notebook has {} cells)", index, notebook.cells.len()));
}
notebook.cells.insert(index, cell);
}
NotebookEdit::DeleteCell { index } => {
if index >= notebook.cells.len() {
return Err(eyre!("Cell index {} out of bounds (notebook has {} cells)", index, notebook.cells.len()));
}
notebook.cells.remove(index);
}
}
Ok(())
}
/// Create a new code cell
pub fn new_code_cell(source: Vec<String>) -> Cell {
Cell {
cell_type: "code".to_string(),
execution_count: None,
metadata: HashMap::new(),
source,
outputs: Vec::new(),
id: None,
}
}
/// Create a new markdown cell
pub fn new_markdown_cell(source: Vec<String>) -> Cell {
Cell {
cell_type: "markdown".to_string(),
execution_count: None,
metadata: HashMap::new(),
source,
outputs: Vec::new(),
id: None,
}
}
/// Get cell source as a single string
pub fn cell_source_as_string(cell: &Cell) -> String {
cell.source.join("")
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn cell_source_concatenation() {
let cell = Cell {
cell_type: "code".to_string(),
execution_count: None,
metadata: HashMap::new(),
source: vec!["import pandas as pd\n".to_string(), "df = pd.DataFrame()\n".to_string()],
outputs: Vec::new(),
id: None,
};
let source = cell_source_as_string(&cell);
assert_eq!(source, "import pandas as pd\ndf = pd.DataFrame()\n");
}
#[test]
fn new_code_cell_creation() {
let cell = new_code_cell(vec!["print('hello')\n".to_string()]);
assert_eq!(cell.cell_type, "code");
assert!(cell.outputs.is_empty());
}
#[test]
fn new_markdown_cell_creation() {
let cell = new_markdown_cell(vec!["# Title\n".to_string()]);
assert_eq!(cell.cell_type, "markdown");
}
}
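A short round-trip sketch for the notebook API above; the analysis.ipynb path and cell contents are illustrative:

use tools_notebook::{read_notebook, write_notebook, edit_notebook, new_markdown_cell, NotebookEdit};

fn notebook_sketch() -> color_eyre::eyre::Result<()> {
    let mut nb = read_notebook("analysis.ipynb")?;
    // Prepend a title cell, then rewrite what is now the second cell in place.
    edit_notebook(&mut nb, NotebookEdit::AddCell {
        index: 0,
        cell: new_markdown_cell(vec!["# Analysis\n".to_string()]),
    })?;
    edit_notebook(&mut nb, NotebookEdit::EditCell {
        index: 1,
        source: vec!["import pandas as pd\n".to_string()],
    })?;
    write_notebook("analysis.ipynb", &nb)?;
    Ok(())
}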

View File

@@ -0,0 +1,280 @@
use tools_notebook::*;
use std::fs;
use tempfile::tempdir;
#[test]
fn notebook_round_trip_preserves_metadata() {
let dir = tempdir().unwrap();
let notebook_path = dir.path().join("test.ipynb");
// Create a sample notebook with metadata
let notebook_json = r##"{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"source": ["print('hello world')"],
"outputs": []
},
{
"cell_type": "markdown",
"metadata": {},
"source": ["# Test Notebook", "This is a test."]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"name": "python",
"version": "3.9.0"
}
},
"nbformat": 4,
"nbformat_minor": 5
}"##;
fs::write(&notebook_path, notebook_json).unwrap();
// Read the notebook
let notebook = read_notebook(&notebook_path).unwrap();
// Verify structure
assert_eq!(notebook.cells.len(), 2);
assert_eq!(notebook.nbformat, 4);
assert_eq!(notebook.nbformat_minor, 5);
// Verify metadata
assert!(notebook.metadata.kernelspec.is_some());
let kernelspec = notebook.metadata.kernelspec.as_ref().unwrap();
assert_eq!(kernelspec.language, "python");
assert_eq!(kernelspec.name, "python3");
assert!(notebook.metadata.language_info.is_some());
let lang_info = notebook.metadata.language_info.as_ref().unwrap();
assert_eq!(lang_info.name, "python");
assert_eq!(lang_info.version, Some("3.9.0".to_string()));
// Write it back
let output_path = dir.path().join("output.ipynb");
write_notebook(&output_path, &notebook).unwrap();
// Read it again
let notebook2 = read_notebook(&output_path).unwrap();
// Verify metadata is preserved
assert_eq!(notebook2.nbformat, 4);
assert_eq!(notebook2.nbformat_minor, 5);
assert!(notebook2.metadata.kernelspec.is_some());
assert_eq!(
notebook2.metadata.kernelspec.as_ref().unwrap().language,
"python"
);
}
#[test]
fn notebook_edit_cell_content() {
let dir = tempdir().unwrap();
let notebook_path = dir.path().join("test.ipynb");
let notebook_json = r##"{
"cells": [
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"source": ["x = 1"],
"outputs": []
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"source": ["y = 2"],
"outputs": []
}
],
"metadata": {},
"nbformat": 4,
"nbformat_minor": 5
}"##;
fs::write(&notebook_path, notebook_json).unwrap();
let mut notebook = read_notebook(&notebook_path).unwrap();
// Edit the first cell
edit_notebook(
&mut notebook,
NotebookEdit::EditCell {
index: 0,
source: vec!["x = 10\n".to_string(), "print(x)\n".to_string()],
},
)
.unwrap();
// Verify the edit
assert_eq!(notebook.cells[0].source.len(), 2);
assert_eq!(notebook.cells[0].source[0], "x = 10\n");
assert_eq!(notebook.cells[0].source[1], "print(x)\n");
// Second cell should be unchanged
assert_eq!(notebook.cells[1].source[0], "y = 2");
}
#[test]
fn notebook_add_delete_cells() {
let dir = tempdir().unwrap();
let notebook_path = dir.path().join("test.ipynb");
let notebook_json = r##"{
"cells": [
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"source": ["x = 1"],
"outputs": []
}
],
"metadata": {},
"nbformat": 4,
"nbformat_minor": 5
}"##;
fs::write(&notebook_path, notebook_json).unwrap();
let mut notebook = read_notebook(&notebook_path).unwrap();
assert_eq!(notebook.cells.len(), 1);
// Add a cell at the end
let new_cell = new_code_cell(vec!["y = 2\n".to_string()]);
edit_notebook(
&mut notebook,
NotebookEdit::AddCell {
index: 1,
cell: new_cell,
},
)
.unwrap();
assert_eq!(notebook.cells.len(), 2);
assert_eq!(notebook.cells[1].source[0], "y = 2\n");
// Add a cell at the beginning
let first_cell = new_markdown_cell(vec!["# Header\n".to_string()]);
edit_notebook(
&mut notebook,
NotebookEdit::AddCell {
index: 0,
cell: first_cell,
},
)
.unwrap();
assert_eq!(notebook.cells.len(), 3);
assert_eq!(notebook.cells[0].cell_type, "markdown");
assert_eq!(notebook.cells[0].source[0], "# Header\n");
assert_eq!(notebook.cells[1].source[0], "x = 1"); // Original first cell is now second
// Delete the middle cell
edit_notebook(&mut notebook, NotebookEdit::DeleteCell { index: 1 }).unwrap();
assert_eq!(notebook.cells.len(), 2);
assert_eq!(notebook.cells[0].cell_type, "markdown");
assert_eq!(notebook.cells[1].source[0], "y = 2\n");
}
#[test]
fn notebook_edit_out_of_bounds() {
let dir = tempdir().unwrap();
let notebook_path = dir.path().join("test.ipynb");
let notebook_json = r##"{
"cells": [
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"source": ["x = 1\n"],
"outputs": []
}
],
"metadata": {},
"nbformat": 4,
"nbformat_minor": 5
}"##;
fs::write(&notebook_path, notebook_json).unwrap();
let mut notebook = read_notebook(&notebook_path).unwrap();
// Try to edit non-existent cell
let result = edit_notebook(
&mut notebook,
NotebookEdit::EditCell {
index: 5,
source: vec!["bad\n".to_string()],
},
);
assert!(result.is_err());
assert!(result.unwrap_err().to_string().contains("out of bounds"));
}
#[test]
fn notebook_with_outputs_preserved() {
let dir = tempdir().unwrap();
let notebook_path = dir.path().join("test.ipynb");
let notebook_json = r##"{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"source": ["print('hello')\n"],
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": ["hello\n"]
}
]
}
],
"metadata": {},
"nbformat": 4,
"nbformat_minor": 5
}"##;
fs::write(&notebook_path, notebook_json).unwrap();
let notebook = read_notebook(&notebook_path).unwrap();
assert_eq!(notebook.cells[0].outputs.len(), 1);
assert_eq!(notebook.cells[0].outputs[0].output_type, "stream");
// Write and read back
let output_path = dir.path().join("output.ipynb");
write_notebook(&output_path, &notebook).unwrap();
let notebook2 = read_notebook(&output_path).unwrap();
assert_eq!(notebook2.cells[0].outputs.len(), 1);
assert_eq!(notebook2.cells[0].outputs[0].output_type, "stream");
}
#[test]
fn cell_source_as_string_concatenates() {
let cell = new_code_cell(vec![
"import numpy as np\n".to_string(),
"arr = np.array([1, 2, 3])\n".to_string(),
]);
let source = cell_source_as_string(&cell);
assert_eq!(source, "import numpy as np\narr = np.array([1, 2, 3])\n");
}

View File

@@ -0,0 +1,18 @@
[package]
name = "tools-plan"
version = "0.1.0"
edition = "2024"
license = "AGPL-3.0"
description = "Planning mode tools for the Owlen agent"
[dependencies]
color-eyre = "0.6"
serde = { version = "1.0", features = ["derive"] }
serde_json = "1.0"
chrono = { version = "0.4", features = ["serde"] }
uuid = { version = "1.0", features = ["v4"] }
tokio = { version = "1", features = ["fs"] }
[dev-dependencies]
tempfile = "3.13"
tokio = { version = "1", features = ["rt", "macros"] }

View File

@@ -0,0 +1,296 @@
//! Planning mode tools for the Owlen agent
//!
//! Provides EnterPlanMode and ExitPlanMode tools that allow the agent
//! to enter a planning phase where only read-only operations are allowed,
//! and then present a plan for user approval.
use color_eyre::eyre::Result;
use serde::{Deserialize, Serialize};
use std::path::PathBuf;
use chrono::{DateTime, Utc};
use uuid::Uuid;
/// Agent mode - normal execution or planning
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub enum AgentMode {
/// Normal mode - all tools available per permission settings
Normal,
/// Planning mode - only read-only tools allowed
Planning {
/// Path to the plan file being written
plan_file: PathBuf,
/// When planning mode was entered
started_at: DateTime<Utc>,
},
}
impl Default for AgentMode {
fn default() -> Self {
Self::Normal
}
}
impl AgentMode {
/// Check if we're in planning mode
pub fn is_planning(&self) -> bool {
matches!(self, AgentMode::Planning { .. })
}
/// Get the plan file path if in planning mode
pub fn plan_file(&self) -> Option<&PathBuf> {
match self {
AgentMode::Planning { plan_file, .. } => Some(plan_file),
AgentMode::Normal => None,
}
}
}
/// Plan file metadata
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PlanMetadata {
pub id: String,
pub created_at: DateTime<Utc>,
pub status: PlanStatus,
pub title: Option<String>,
}
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum PlanStatus {
/// Plan is being written
Draft,
/// Plan is awaiting user approval
PendingApproval,
/// Plan was approved by user
Approved,
/// Plan was rejected by user
Rejected,
}
/// Manager for plan files
pub struct PlanManager {
plans_dir: PathBuf,
}
impl PlanManager {
/// Create a new plan manager
pub fn new(project_root: PathBuf) -> Self {
let plans_dir = project_root.join(".owlen").join("plans");
Self { plans_dir }
}
/// Create a new plan manager with custom directory
pub fn with_dir(plans_dir: PathBuf) -> Self {
Self { plans_dir }
}
/// Get the plans directory
pub fn plans_dir(&self) -> &PathBuf {
&self.plans_dir
}
/// Ensure the plans directory exists
pub async fn ensure_dir(&self) -> Result<()> {
tokio::fs::create_dir_all(&self.plans_dir).await?;
Ok(())
}
/// Generate a unique plan file name
/// Uses a format like: <adjective>-<verb>-<noun>.md
pub fn generate_plan_name(&self) -> String {
// Simple word lists for readable names
let adjectives = ["cozy", "swift", "clever", "bright", "calm", "eager", "gentle", "happy"];
let verbs = ["dancing", "jumping", "running", "flying", "singing", "coding", "building", "thinking"];
let nouns = ["owl", "fox", "bear", "wolf", "hawk", "deer", "lion", "tiger"];
use std::collections::hash_map::DefaultHasher;
use std::hash::{Hash, Hasher};
let uuid = Uuid::new_v4();
let mut hasher = DefaultHasher::new();
uuid.hash(&mut hasher);
let hash = hasher.finish();
let adj = adjectives[(hash % adjectives.len() as u64) as usize];
let verb = verbs[((hash >> 8) % verbs.len() as u64) as usize];
let noun = nouns[((hash >> 16) % nouns.len() as u64) as usize];
format!("{}-{}-{}.md", adj, verb, noun)
}
/// Create a new plan file and return the path
pub async fn create_plan(&self) -> Result<PathBuf> {
self.ensure_dir().await?;
let filename = self.generate_plan_name();
let plan_path = self.plans_dir.join(&filename);
// Create initial plan file with metadata
let metadata = PlanMetadata {
id: Uuid::new_v4().to_string(),
created_at: Utc::now(),
status: PlanStatus::Draft,
title: None,
};
let initial_content = format!(
"<!-- plan-id: {} -->\n<!-- status: draft -->\n\n# Implementation Plan\n\n",
metadata.id
);
tokio::fs::write(&plan_path, initial_content).await?;
Ok(plan_path)
}
/// Write content to a plan file
pub async fn write_plan(&self, path: &PathBuf, content: &str) -> Result<()> {
// Preserve the metadata header if it exists
let existing = tokio::fs::read_to_string(path).await.unwrap_or_default();
// Extract metadata lines (lines starting with <!--)
let metadata_lines: Vec<&str> = existing
.lines()
.take_while(|line| line.starts_with("<!--"))
.collect();
// Update status to pending approval
let mut new_content = String::new();
for line in &metadata_lines {
if line.contains("status:") {
new_content.push_str("<!-- status: pending_approval -->\n");
} else {
new_content.push_str(line);
new_content.push('\n');
}
}
new_content.push('\n');
new_content.push_str(content);
tokio::fs::write(path, new_content).await?;
Ok(())
}
/// Read a plan file
pub async fn read_plan(&self, path: &PathBuf) -> Result<String> {
let content = tokio::fs::read_to_string(path).await?;
Ok(content)
}
/// Update plan status
pub async fn set_status(&self, path: &PathBuf, status: PlanStatus) -> Result<()> {
let content = tokio::fs::read_to_string(path).await?;
let status_str = match status {
PlanStatus::Draft => "draft",
PlanStatus::PendingApproval => "pending_approval",
PlanStatus::Approved => "approved",
PlanStatus::Rejected => "rejected",
};
// Replace status line
let updated: String = content
.lines()
.map(|line| {
if line.contains("<!-- status:") {
format!("<!-- status: {} -->", status_str)
} else {
line.to_string()
}
})
.collect::<Vec<_>>()
.join("\n");
tokio::fs::write(path, updated).await?;
Ok(())
}
/// List all plan files
pub async fn list_plans(&self) -> Result<Vec<PathBuf>> {
let mut plans = Vec::new();
if !self.plans_dir.exists() {
return Ok(plans);
}
let mut entries = tokio::fs::read_dir(&self.plans_dir).await?;
while let Some(entry) = entries.next_entry().await? {
let path = entry.path();
if path.extension().map_or(false, |ext| ext == "md") {
plans.push(path);
}
}
plans.sort();
Ok(plans)
}
}
/// Enter planning mode
pub fn enter_plan_mode(plan_file: PathBuf) -> AgentMode {
AgentMode::Planning {
plan_file,
started_at: Utc::now(),
}
}
/// Exit planning mode and return to normal
pub fn exit_plan_mode() -> AgentMode {
AgentMode::Normal
}
/// Check if a tool is allowed in planning mode
/// Only read-only tools are allowed
pub fn is_tool_allowed_in_plan_mode(tool_name: &str) -> bool {
matches!(
tool_name,
"read" | "glob" | "grep" | "ls" | "web_fetch" | "web_search" |
"todo_write" | "ask_user" | "exit_plan_mode"
)
}
#[cfg(test)]
mod tests {
use super::*;
use tempfile::TempDir;
#[tokio::test]
async fn test_create_plan() {
let temp_dir = TempDir::new().unwrap();
let manager = PlanManager::new(temp_dir.path().to_path_buf());
let plan_path = manager.create_plan().await.unwrap();
assert!(plan_path.exists());
assert!(plan_path.extension().map_or(false, |ext| ext == "md"));
}
#[tokio::test]
async fn test_write_and_read_plan() {
let temp_dir = TempDir::new().unwrap();
let manager = PlanManager::new(temp_dir.path().to_path_buf());
let plan_path = manager.create_plan().await.unwrap();
manager.write_plan(&plan_path, "# My Plan\n\nStep 1: Do something").await.unwrap();
let content = manager.read_plan(&plan_path).await.unwrap();
assert!(content.contains("My Plan"));
assert!(content.contains("pending_approval"));
}
#[test]
fn test_plan_mode_check() {
assert!(is_tool_allowed_in_plan_mode("read"));
assert!(is_tool_allowed_in_plan_mode("glob"));
assert!(is_tool_allowed_in_plan_mode("grep"));
assert!(!is_tool_allowed_in_plan_mode("write"));
assert!(!is_tool_allowed_in_plan_mode("bash"));
assert!(!is_tool_allowed_in_plan_mode("edit"));
}
#[test]
fn test_agent_mode_default() {
let mode = AgentMode::default();
assert!(!mode.is_planning());
assert!(mode.plan_file().is_none());
}
}
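A sketch of the planning-mode flow built from the types above, assuming a tokio runtime; the project root and plan body are illustrative:

use std::path::PathBuf;
use tools_plan::{PlanManager, PlanStatus, enter_plan_mode, exit_plan_mode, is_tool_allowed_in_plan_mode};

async fn plan_mode_sketch() -> color_eyre::eyre::Result<()> {
    let manager = PlanManager::new(PathBuf::from("."));
    let plan_path = manager.create_plan().await?; // .owlen/plans/<adjective>-<verb>-<noun>.md
    let mode = enter_plan_mode(plan_path.clone());
    assert!(mode.is_planning());
    // Only read-only tools pass the gate while planning.
    assert!(is_tool_allowed_in_plan_mode("read"));
    assert!(!is_tool_allowed_in_plan_mode("bash"));
    manager.write_plan(&plan_path, "# Plan\n\n1. Read config\n2. Propose changes").await?;
    manager.set_status(&plan_path, PlanStatus::Approved).await?;
    let _normal = exit_plan_mode();
    Ok(())
}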

View File

@@ -0,0 +1,16 @@
[package]
name = "tools-skill"
version = "0.1.0"
edition.workspace = true
license.workspace = true
rust-version.workspace = true
description = "Skill invocation tool for the Owlen agent"
[dependencies]
color-eyre = "0.6"
serde = { version = "1.0", features = ["derive"] }
serde_json = "1.0"
plugins = { path = "../../platform/plugins" }
[dev-dependencies]
tempfile = "3.13"

View File

@@ -0,0 +1,275 @@
//! Skill invocation tool for the Owlen agent
//!
//! Provides the Skill tool that allows the agent to invoke skills
//! from plugins programmatically during a conversation.
use color_eyre::eyre::{Result, eyre};
use plugins::PluginManager;
use serde::{Deserialize, Serialize};
use std::fs;
use std::path::PathBuf;
/// Parameters for the Skill tool
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SkillParams {
/// Name of the skill to invoke (e.g., "pdf", "xlsx", or "plugin:skill")
pub skill: String,
}
/// Result of skill invocation
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SkillResult {
/// The skill name that was invoked
pub skill_name: String,
/// The skill content (instructions)
pub content: String,
/// Source of the skill (plugin name)
pub source: String,
}
/// Skill registry for looking up and invoking skills
pub struct SkillRegistry {
/// Local skills directory (e.g., .owlen/skills/)
local_skills_dir: PathBuf,
/// Plugin manager for finding plugin skills
plugin_manager: Option<PluginManager>,
}
impl SkillRegistry {
/// Create a new skill registry
pub fn new() -> Self {
Self {
local_skills_dir: PathBuf::from(".owlen/skills"),
plugin_manager: None,
}
}
/// Create with custom local skills directory
pub fn with_local_dir(mut self, dir: PathBuf) -> Self {
self.local_skills_dir = dir;
self
}
/// Set the plugin manager for discovering plugin skills
pub fn with_plugin_manager(mut self, pm: PluginManager) -> Self {
self.plugin_manager = Some(pm);
self
}
/// Find and load a skill by name
///
/// Skill names can be:
/// - Simple name: "pdf" (searches local, then plugins)
/// - Fully qualified: "plugin-name:skill-name"
pub fn invoke(&self, skill_name: &str) -> Result<SkillResult> {
// Check for fully qualified name (plugin:skill)
if let Some((plugin_name, skill_id)) = skill_name.split_once(':') {
return self.load_plugin_skill(plugin_name, skill_id);
}
// Try local skills first
if let Ok(result) = self.load_local_skill(skill_name) {
return Ok(result);
}
// Try plugins
if let Some(pm) = &self.plugin_manager {
for plugin in pm.plugins() {
if let Ok(result) = self.load_skill_from_plugin(plugin, skill_name) {
return Ok(result);
}
}
}
Err(eyre!(
"Skill '{}' not found.\n\nAvailable skills:\n{}",
skill_name,
self.list_available_skills().join("\n")
))
}
/// Load a local skill from .owlen/skills/
fn load_local_skill(&self, skill_name: &str) -> Result<SkillResult> {
// Try with and without .md extension
let skill_file = self.local_skills_dir.join(format!("{}.md", skill_name));
let skill_dir = self.local_skills_dir.join(skill_name).join("SKILL.md");
let content = if skill_file.exists() {
fs::read_to_string(&skill_file)?
} else if skill_dir.exists() {
fs::read_to_string(&skill_dir)?
} else {
return Err(eyre!("Local skill '{}' not found", skill_name));
};
Ok(SkillResult {
skill_name: skill_name.to_string(),
content: parse_skill_content(&content),
source: "local".to_string(),
})
}
/// Load a skill from a specific plugin
fn load_plugin_skill(&self, plugin_name: &str, skill_name: &str) -> Result<SkillResult> {
let pm = self.plugin_manager.as_ref()
.ok_or_else(|| eyre!("Plugin manager not available"))?;
for plugin in pm.plugins() {
if plugin.manifest.name == plugin_name {
return self.load_skill_from_plugin(plugin, skill_name);
}
}
Err(eyre!("Plugin '{}' not found", plugin_name))
}
/// Load a skill from a plugin
fn load_skill_from_plugin(&self, plugin: &plugins::Plugin, skill_name: &str) -> Result<SkillResult> {
let skill_names = plugin.all_skill_names();
if !skill_names.contains(&skill_name.to_string()) {
return Err(eyre!("Skill '{}' not found in plugin '{}'", skill_name, plugin.manifest.name));
}
// Skills are in skills/<name>/SKILL.md
let skill_path = plugin.base_path.join("skills").join(skill_name).join("SKILL.md");
if !skill_path.exists() {
return Err(eyre!("Skill file not found: {:?}", skill_path));
}
let content = fs::read_to_string(&skill_path)?;
Ok(SkillResult {
skill_name: skill_name.to_string(),
content: parse_skill_content(&content),
source: format!("plugin:{}", plugin.manifest.name),
})
}
/// List all available skills
pub fn list_available_skills(&self) -> Vec<String> {
let mut skills = Vec::new();
// Local skills
if self.local_skills_dir.exists() {
if let Ok(entries) = fs::read_dir(&self.local_skills_dir) {
for entry in entries.filter_map(|e| e.ok()) {
let path = entry.path();
if path.is_file() && path.extension().map_or(false, |e| e == "md") {
if let Some(name) = path.file_stem().and_then(|s| s.to_str()) {
skills.push(format!(" - {} (local)", name));
}
} else if path.is_dir() && path.join("SKILL.md").exists() {
if let Some(name) = path.file_name().and_then(|s| s.to_str()) {
skills.push(format!(" - {} (local)", name));
}
}
}
}
}
// Plugin skills
if let Some(pm) = &self.plugin_manager {
for plugin in pm.plugins() {
for skill_name in plugin.all_skill_names() {
skills.push(format!(" - {} (plugin:{})", skill_name, plugin.manifest.name));
}
}
}
if skills.is_empty() {
skills.push(" (no skills available)".to_string());
}
skills
}
}
impl Default for SkillRegistry {
fn default() -> Self {
Self::new()
}
}
/// Parse skill content, extracting the body (stripping YAML frontmatter)
fn parse_skill_content(content: &str) -> String {
// Check for YAML frontmatter
if content.starts_with("---") {
// Find the end of frontmatter
if let Some(end_idx) = content[3..].find("---") {
let body_start = end_idx + 6; // end_idx is relative to content[3..]; +6 covers the opening and closing "---"
if body_start < content.len() {
return content[body_start..].trim().to_string();
}
}
}
content.trim().to_string()
}
/// Execute the Skill tool
pub fn execute_skill(params: &SkillParams, registry: &SkillRegistry) -> Result<String> {
let result = registry.invoke(&params.skill)?;
// Format output for injection into conversation
Ok(format!(
"## Skill: {} ({})\n\n{}",
result.skill_name,
result.source,
result.content
))
}
#[cfg(test)]
mod tests {
use super::*;
use std::fs;
use tempfile::TempDir;
#[test]
fn test_parse_skill_content_with_frontmatter() {
let content = r#"---
name: test-skill
description: A test skill
---
# Test Skill
This is the skill content."#;
let parsed = parse_skill_content(content);
assert!(parsed.starts_with("# Test Skill"));
assert!(!parsed.contains("name: test-skill"));
}
#[test]
fn test_parse_skill_content_without_frontmatter() {
let content = "# Just Content\n\nNo frontmatter here.";
let parsed = parse_skill_content(content);
assert_eq!(parsed, content.trim());
}
#[test]
fn test_skill_registry_local() {
let temp_dir = TempDir::new().unwrap();
let skills_dir = temp_dir.path().join(".owlen/skills");
fs::create_dir_all(&skills_dir).unwrap();
// Create a test skill
fs::write(skills_dir.join("test.md"), "# Test Skill\n\nTest content.").unwrap();
let registry = SkillRegistry::new().with_local_dir(skills_dir);
let result = registry.invoke("test").unwrap();
assert_eq!(result.skill_name, "test");
assert_eq!(result.source, "local");
assert!(result.content.contains("Test Skill"));
}
#[test]
fn test_skill_not_found() {
let registry = SkillRegistry::new();
assert!(registry.invoke("nonexistent").is_err());
}
}
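A usage sketch for the skill registry above; the skills directory and the pdf skill name are illustrative, and a plugin manager can optionally be attached to resolve plugin:skill names:

use std::path::PathBuf;
use tools_skill::{SkillRegistry, SkillParams, execute_skill};

fn skill_sketch() -> color_eyre::eyre::Result<()> {
    let registry = SkillRegistry::new().with_local_dir(PathBuf::from(".owlen/skills"));
    let params = SkillParams { skill: "pdf".to_string() };
    // Renders as "## Skill: pdf (local)" followed by the skill body with frontmatter stripped.
    let rendered = execute_skill(&params, &registry)?;
    println!("{rendered}");
    Ok(())
}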

View File

@@ -0,0 +1,16 @@
[package]
name = "tools-task"
version = "0.1.0"
edition.workspace = true
license.workspace = true
rust-version.workspace = true
[dependencies]
serde = { version = "1", features = ["derive"] }
serde_json = "1"
color-eyre = "0.6"
permissions = { path = "../../platform/permissions" }
plugins = { path = "../../platform/plugins" }
parking_lot = "0.12"
[dev-dependencies]

View File

@@ -0,0 +1,335 @@
// Note: Result and eyre will be used by spawn_subagent when implemented
#[allow(unused_imports)]
use color_eyre::eyre::{Result, eyre};
use parking_lot::RwLock;
use permissions::Tool;
use plugins::AgentDefinition;
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::path::PathBuf;
use std::sync::Arc;
/// Configuration for spawning a subagent
#[derive(Debug, Clone)]
pub struct SubagentConfig {
/// Agent type/name (e.g., "code-reviewer", "explore")
pub agent_type: String,
/// Task prompt for the agent
pub prompt: String,
/// Optional model override
pub model: Option<String>,
/// Tool whitelist (if None, uses agent's default)
pub tools: Option<Vec<String>>,
/// Parsed agent definition (if from plugin)
pub definition: Option<AgentDefinition>,
}
impl SubagentConfig {
/// Create a new subagent config with just type and prompt
pub fn new(agent_type: String, prompt: String) -> Self {
Self {
agent_type,
prompt,
model: None,
tools: None,
definition: None,
}
}
/// Builder method to set model override
pub fn with_model(mut self, model: String) -> Self {
self.model = Some(model);
self
}
/// Builder method to set tool whitelist
pub fn with_tools(mut self, tools: Vec<String>) -> Self {
self.tools = Some(tools);
self
}
/// Builder method to set agent definition
pub fn with_definition(mut self, definition: AgentDefinition) -> Self {
self.definition = Some(definition);
self
}
}
/// Registry of available subagents
#[derive(Clone, Default)]
pub struct SubagentRegistry {
agents: Arc<RwLock<HashMap<String, AgentDefinition>>>,
}
impl SubagentRegistry {
/// Create a new empty registry
pub fn new() -> Self {
Self::default()
}
/// Register agents from plugin manager
pub fn register_from_plugins(&self, agents: Vec<AgentDefinition>) {
let mut map = self.agents.write();
for agent in agents {
map.insert(agent.name.clone(), agent);
}
}
/// Register built-in agents
pub fn register_builtin(&self) {
let mut map = self.agents.write();
// Explore agent - for codebase exploration
map.insert("explore".to_string(), AgentDefinition {
name: "explore".to_string(),
description: "Explores codebases to find files and understand structure".to_string(),
tools: vec!["read".to_string(), "glob".to_string(), "grep".to_string(), "ls".to_string()],
model: None,
color: Some("blue".to_string()),
system_prompt: "You are an exploration agent. Your purpose is to find relevant files and understand code structure. Use glob to find files by pattern, grep to search for content, ls to list directories, and read to examine files. Be thorough and systematic in your exploration.".to_string(),
source_path: PathBuf::new(),
});
// Plan agent - for designing implementations
map.insert("plan".to_string(), AgentDefinition {
name: "plan".to_string(),
description: "Designs implementation plans and architectures".to_string(),
tools: vec!["read".to_string(), "glob".to_string(), "grep".to_string()],
model: None,
color: Some("green".to_string()),
system_prompt: "You are a planning agent. Your purpose is to design clear implementation strategies and architectures. Read existing code, understand patterns, and create detailed plans. Focus on the 'why' and 'how' rather than the 'what'.".to_string(),
source_path: PathBuf::new(),
});
// Code reviewer - read-only analysis
map.insert("code-reviewer".to_string(), AgentDefinition {
name: "code-reviewer".to_string(),
description: "Reviews code for quality, bugs, and best practices".to_string(),
tools: vec!["read".to_string(), "grep".to_string(), "glob".to_string()],
model: None,
color: Some("yellow".to_string()),
system_prompt: "You are a code review agent. Analyze code for quality, potential bugs, performance issues, and adherence to best practices. Provide constructive feedback with specific examples.".to_string(),
source_path: PathBuf::new(),
});
// Test writer - can read and write test files
map.insert("test-writer".to_string(), AgentDefinition {
name: "test-writer".to_string(),
description: "Writes and updates test files".to_string(),
tools: vec!["read".to_string(), "write".to_string(), "edit".to_string(), "grep".to_string(), "glob".to_string()],
model: None,
color: Some("cyan".to_string()),
system_prompt: "You are a test writing agent. Write comprehensive, well-structured tests that cover edge cases and ensure code correctness. Follow testing best practices and patterns used in the codebase.".to_string(),
source_path: PathBuf::new(),
});
// Documentation agent - can read code and write docs
map.insert("doc-writer".to_string(), AgentDefinition {
name: "doc-writer".to_string(),
description: "Writes and maintains documentation".to_string(),
tools: vec!["read".to_string(), "write".to_string(), "edit".to_string(), "grep".to_string(), "glob".to_string()],
model: None,
color: Some("magenta".to_string()),
system_prompt: "You are a documentation agent. Write clear, comprehensive documentation that helps users understand the code. Include examples, explain concepts, and maintain consistency with existing documentation style.".to_string(),
source_path: PathBuf::new(),
});
// Refactoring agent - full file access but no bash
map.insert("refactorer".to_string(), AgentDefinition {
name: "refactorer".to_string(),
description: "Refactors code while preserving functionality".to_string(),
tools: vec!["read".to_string(), "write".to_string(), "edit".to_string(), "grep".to_string(), "glob".to_string()],
model: None,
color: Some("red".to_string()),
system_prompt: "You are a refactoring agent. Improve code structure, readability, and maintainability while preserving functionality. Follow SOLID principles and language idioms. Make small, incremental changes.".to_string(),
source_path: PathBuf::new(),
});
}
/// Get an agent by name
pub fn get(&self, name: &str) -> Option<AgentDefinition> {
self.agents.read().get(name).cloned()
}
/// List all available agents with their descriptions
pub fn list(&self) -> Vec<(String, String)> {
self.agents.read()
.iter()
.map(|(name, def)| (name.clone(), def.description.clone()))
.collect()
}
/// Check if an agent exists
pub fn contains(&self, name: &str) -> bool {
self.agents.read().contains_key(name)
}
/// Get all agent names
pub fn agent_names(&self) -> Vec<String> {
self.agents.read().keys().cloned().collect()
}
}
/// A specialized subagent with limited tool access (legacy API for backward compatibility)
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Subagent {
/// Unique identifier for the subagent
pub name: String,
/// Description of subagent's capabilities and purpose
pub description: String,
/// Keywords that trigger this subagent's selection
pub keywords: Vec<String>,
/// Tools this subagent is allowed to use
pub allowed_tools: Vec<Tool>,
}
impl Subagent {
pub fn new(name: String, description: String, keywords: Vec<String>, allowed_tools: Vec<Tool>) -> Self {
Self {
name,
description,
keywords,
allowed_tools,
}
}
/// Check if this subagent can use the specified tool
pub fn can_use_tool(&self, tool: Tool) -> bool {
self.allowed_tools.contains(&tool)
}
/// Check if this subagent matches the task description
pub fn matches_task(&self, task_description: &str) -> bool {
let task_lower = task_description.to_lowercase();
self.keywords.iter().any(|keyword| {
task_lower.contains(&keyword.to_lowercase())
})
}
}
/// Task execution request
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TaskRequest {
/// Description of the task to execute
pub description: String,
/// Optional: specific subagent to use
pub agent_name: Option<String>,
}
/// Task execution result
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TaskResult {
/// The subagent that handled the task
pub agent_name: String,
/// Success or failure
pub success: bool,
/// Result message
pub message: String,
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_subagent_registry_builtin() {
let registry = SubagentRegistry::new();
registry.register_builtin();
// Check that built-in agents are registered
assert!(registry.contains("explore"));
assert!(registry.contains("plan"));
assert!(registry.contains("code-reviewer"));
assert!(registry.contains("test-writer"));
assert!(registry.contains("doc-writer"));
assert!(registry.contains("refactorer"));
// Get an agent and verify its properties
let explore = registry.get("explore").unwrap();
assert_eq!(explore.name, "explore");
assert!(explore.tools.contains(&"read".to_string()));
assert!(explore.tools.contains(&"glob".to_string()));
assert!(explore.tools.contains(&"grep".to_string()));
}
#[test]
fn test_subagent_registry_list() {
let registry = SubagentRegistry::new();
registry.register_builtin();
let agents = registry.list();
assert!(agents.len() >= 6);
// Check that we have expected agents
let names: Vec<String> = agents.iter().map(|(name, _)| name.clone()).collect();
assert!(names.contains(&"explore".to_string()));
assert!(names.contains(&"plan".to_string()));
}
#[test]
fn test_subagent_config_builder() {
let config = SubagentConfig::new("explore".to_string(), "Find all Rust files".to_string())
.with_model("claude-3-opus".to_string())
.with_tools(vec!["read".to_string(), "glob".to_string()]);
assert_eq!(config.agent_type, "explore");
assert_eq!(config.prompt, "Find all Rust files");
assert_eq!(config.model, Some("claude-3-opus".to_string()));
assert_eq!(config.tools, Some(vec!["read".to_string(), "glob".to_string()]));
}
#[test]
fn test_register_from_plugins() {
let registry = SubagentRegistry::new();
let plugin_agent = AgentDefinition {
name: "custom-agent".to_string(),
description: "A custom agent from plugin".to_string(),
tools: vec!["read".to_string()],
model: Some("haiku".to_string()),
color: Some("purple".to_string()),
system_prompt: "Custom prompt".to_string(),
source_path: PathBuf::from("/path/to/plugin"),
};
registry.register_from_plugins(vec![plugin_agent]);
assert!(registry.contains("custom-agent"));
let agent = registry.get("custom-agent").unwrap();
assert_eq!(agent.model, Some("haiku".to_string()));
}
// Legacy API tests for backward compatibility
#[test]
fn subagent_tool_whitelist() {
let agent = Subagent::new(
"reader".to_string(),
"Read-only agent".to_string(),
vec!["read".to_string()],
vec![Tool::Read, Tool::Grep],
);
assert!(agent.can_use_tool(Tool::Read));
assert!(agent.can_use_tool(Tool::Grep));
assert!(!agent.can_use_tool(Tool::Write));
assert!(!agent.can_use_tool(Tool::Bash));
}
#[test]
fn subagent_keyword_matching() {
let agent = Subagent::new(
"tester".to_string(),
"Test agent".to_string(),
vec!["test".to_string(), "unit test".to_string()],
vec![Tool::Read, Tool::Write],
);
assert!(agent.matches_task("Write unit tests for the auth module"));
assert!(agent.matches_task("Add test coverage"));
assert!(!agent.matches_task("Refactor the database layer"));
}
}
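A sketch of looking up a built-in subagent and assembling a spawn request from the types above; the prompt is illustrative, and actual spawning waits on the spawn_subagent noted at the top of the file:

use tools_task::{SubagentRegistry, SubagentConfig};

fn subagent_sketch() {
    let registry = SubagentRegistry::new();
    registry.register_builtin();
    // Pick the read-only exploration agent and narrow its tool set further.
    if let Some(explore) = registry.get("explore") {
        let config = SubagentConfig::new("explore".to_string(), "Map the crates/ layout".to_string())
            .with_tools(vec!["read".to_string(), "glob".to_string()])
            .with_definition(explore);
        assert_eq!(config.agent_type, "explore");
    }
}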

View File

@@ -0,0 +1,12 @@
[package]
name = "tools-todo"
version = "0.1.0"
edition.workspace = true
license.workspace = true
rust-version.workspace = true
[dependencies]
serde = { version = "1.0", features = ["derive"] }
serde_json = "1.0"
parking_lot = "0.12"
color-eyre = "0.6"

View File

@@ -0,0 +1,113 @@
//! TodoWrite tool for task list management
use parking_lot::RwLock;
use serde::{Deserialize, Serialize};
use std::sync::Arc;
/// Status of a todo item
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum TodoStatus {
Pending,
InProgress,
Completed,
}
/// A todo item
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Todo {
pub content: String,
pub status: TodoStatus,
pub active_form: String, // Present continuous form for display
}
/// Shared todo list state
#[derive(Debug, Clone, Default)]
pub struct TodoList {
inner: Arc<RwLock<Vec<Todo>>>,
}
impl TodoList {
pub fn new() -> Self {
Self::default()
}
/// Replace all todos with new list
pub fn write(&self, todos: Vec<Todo>) {
*self.inner.write() = todos;
}
/// Get current todos
pub fn read(&self) -> Vec<Todo> {
self.inner.read().clone()
}
/// Get the current in-progress task (for status display)
pub fn current_task(&self) -> Option<String> {
self.inner.read()
.iter()
.find(|t| t.status == TodoStatus::InProgress)
.map(|t| t.active_form.clone())
}
/// Get summary stats
pub fn stats(&self) -> (usize, usize, usize) {
let todos = self.inner.read();
let pending = todos.iter().filter(|t| t.status == TodoStatus::Pending).count();
let in_progress = todos.iter().filter(|t| t.status == TodoStatus::InProgress).count();
let completed = todos.iter().filter(|t| t.status == TodoStatus::Completed).count();
(pending, in_progress, completed)
}
/// Format todos for display
pub fn format_display(&self) -> String {
let todos = self.inner.read();
if todos.is_empty() {
return "No tasks".to_string();
}
todos.iter().enumerate().map(|(i, t)| {
let status_icon = match t.status {
TodoStatus::Pending => "",
TodoStatus::InProgress => "",
TodoStatus::Completed => "",
};
format!("{}. {} {}", i + 1, status_icon, t.content)
}).collect::<Vec<_>>().join("\n")
}
}
/// Parse todos from JSON tool input
pub fn parse_todos(input: &serde_json::Value) -> color_eyre::Result<Vec<Todo>> {
let todos = input.get("todos")
.ok_or_else(|| color_eyre::eyre::eyre!("Missing 'todos' field"))?;
serde_json::from_value(todos.clone())
.map_err(|e| color_eyre::eyre::eyre!("Invalid todos format: {}", e))
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_todo_list() {
let list = TodoList::new();
list.write(vec![
Todo {
content: "First task".to_string(),
status: TodoStatus::Completed,
active_form: "Completing first task".to_string(),
},
Todo {
content: "Second task".to_string(),
status: TodoStatus::InProgress,
active_form: "Working on second task".to_string(),
},
]);
assert_eq!(list.current_task(), Some("Working on second task".to_string()));
assert_eq!(list.stats(), (0, 1, 1));
}
}
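A sketch of the TodoWrite flow using the helpers above; the JSON mirrors the tool-input shape parse_todos expects, with illustrative task text:

use tools_todo::{TodoList, parse_todos};

fn todo_sketch() -> color_eyre::Result<()> {
    let input = serde_json::json!({
        "todos": [
            { "content": "Wire up parser", "status": "in_progress", "active_form": "Wiring up parser" },
            { "content": "Add tests", "status": "pending", "active_form": "Adding tests" }
        ]
    });
    let list = TodoList::new();
    list.write(parse_todos(&input)?);
    assert_eq!(list.current_task().as_deref(), Some("Wiring up parser"));
    let (pending, in_progress, completed) = list.stats();
    println!("{pending} pending / {in_progress} in progress / {completed} completed");
    println!("{}", list.format_display());
    Ok(())
}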

View File

@@ -0,0 +1,21 @@
[package]
name = "tools-web"
version = "0.1.0"
edition.workspace = true
license.workspace = true
rust-version.workspace = true
[dependencies]
reqwest = { version = "0.12", features = ["json"] }
tokio = { version = "1.39", features = ["macros"] }
serde = { version = "1", features = ["derive"] }
serde_json = "1"
color-eyre = "0.6"
url = "2.5"
async-trait = "0.1"
scraper = "0.18"
urlencoding = "2.1"
[dev-dependencies]
tokio = { version = "1.39", features = ["macros", "rt-multi-thread"] }
wiremock = "0.6"

crates/tools/web/src/lib.rs
View File

@@ -0,0 +1,325 @@
use color_eyre::eyre::{Result, eyre};
use reqwest::redirect::Policy;
use scraper::{Html, Selector};
use serde::{Deserialize, Serialize};
use std::collections::HashSet;
use url::Url;
/// WebFetch response
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct FetchResponse {
pub url: String,
pub status: u16,
pub content: String,
pub content_type: Option<String>,
}
/// WebFetch client with domain filtering
pub struct WebFetchClient {
allowed_domains: HashSet<String>,
blocked_domains: HashSet<String>,
client: reqwest::Client,
}
impl WebFetchClient {
/// Create a new WebFetch client
pub fn new() -> Self {
let client = reqwest::Client::builder()
.redirect(Policy::none()) // Don't follow redirects automatically
.build()
.unwrap();
Self {
allowed_domains: HashSet::new(),
blocked_domains: HashSet::new(),
client,
}
}
/// Add an allowed domain
pub fn allow_domain(&mut self, domain: &str) {
self.allowed_domains.insert(domain.to_lowercase());
}
/// Add a blocked domain
pub fn block_domain(&mut self, domain: &str) {
self.blocked_domains.insert(domain.to_lowercase());
}
/// Check if a domain is allowed
fn is_domain_allowed(&self, domain: &str) -> bool {
let domain_lower = domain.to_lowercase();
// If explicitly blocked, deny
if self.blocked_domains.contains(&domain_lower) {
return false;
}
// If allowlist is empty, allow all (except blocked)
if self.allowed_domains.is_empty() {
return true;
}
// Otherwise, must be in allowlist
self.allowed_domains.contains(&domain_lower)
}
/// Fetch a URL
pub async fn fetch(&self, url: &str) -> Result<FetchResponse> {
let parsed_url = Url::parse(url)?;
let domain = parsed_url
.host_str()
.ok_or_else(|| eyre!("No host in URL"))?;
// Check domain permission
if !self.is_domain_allowed(domain) {
return Err(eyre!("Domain not allowed: {}", domain));
}
// Make the request
let response = self.client.get(url).send().await?;
let status = response.status().as_u16();
// Handle redirects manually
if status >= 300 && status < 400 {
if let Some(location) = response.headers().get("location") {
let location_str = location.to_str()?;
// Parse the redirect URL (may be relative)
let redirect_url = if location_str.starts_with("http") {
Url::parse(location_str)?
} else {
parsed_url.join(location_str)?
};
let redirect_domain = redirect_url
.host_str()
.ok_or_else(|| eyre!("No host in redirect URL"))?;
// Check if redirect domain is allowed
if !self.is_domain_allowed(redirect_domain) {
return Err(eyre!(
"Redirect to unapproved domain: {} -> {}",
domain,
redirect_domain
));
}
return Err(eyre!(
"Redirect detected: {} -> {}. Use the redirect URL directly.",
url,
redirect_url
));
}
}
let content_type = response
.headers()
.get("content-type")
.and_then(|v| v.to_str().ok())
.map(|s| s.to_string());
let content = response.text().await?;
Ok(FetchResponse {
url: url.to_string(),
status,
content,
content_type,
})
}
}
impl Default for WebFetchClient {
fn default() -> Self {
Self::new()
}
}
/// Search provider trait
#[async_trait::async_trait]
pub trait SearchProvider: Send + Sync {
fn name(&self) -> &str;
async fn search(&self, query: &str) -> Result<Vec<SearchResult>>;
}
/// Search result
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SearchResult {
pub title: String,
pub url: String,
pub snippet: String,
}
/// Stub search provider for testing
pub struct StubSearchProvider {
results: Vec<SearchResult>,
}
impl StubSearchProvider {
pub fn new(results: Vec<SearchResult>) -> Self {
Self { results }
}
}
#[async_trait::async_trait]
impl SearchProvider for StubSearchProvider {
fn name(&self) -> &str {
"stub"
}
async fn search(&self, _query: &str) -> Result<Vec<SearchResult>> {
Ok(self.results.clone())
}
}
/// DuckDuckGo HTML search provider
pub struct DuckDuckGoSearchProvider {
client: reqwest::Client,
max_results: usize,
}
impl DuckDuckGoSearchProvider {
/// Create a new DuckDuckGo search provider with default max results (10)
pub fn new() -> Self {
Self::with_max_results(10)
}
/// Create a new DuckDuckGo search provider with custom max results
pub fn with_max_results(max_results: usize) -> Self {
let client = reqwest::Client::builder()
.user_agent("Mozilla/5.0 (compatible; Owlen/1.0)")
.build()
.unwrap();
Self { client, max_results }
}
/// Parse DuckDuckGo HTML results
fn parse_results(html: &str, max_results: usize) -> Result<Vec<SearchResult>> {
let document = Html::parse_document(html);
// DuckDuckGo HTML selectors
let result_selector = Selector::parse(".result").map_err(|e| eyre!("Invalid selector: {:?}", e))?;
let title_selector = Selector::parse(".result__title a").map_err(|e| eyre!("Invalid selector: {:?}", e))?;
let snippet_selector = Selector::parse(".result__snippet").map_err(|e| eyre!("Invalid selector: {:?}", e))?;
let mut results = Vec::new();
for result in document.select(&result_selector).take(max_results) {
let title = result
.select(&title_selector)
.next()
.map(|e| e.text().collect::<String>().trim().to_string())
.unwrap_or_default();
let url = result
.select(&title_selector)
.next()
.and_then(|e| e.value().attr("href"))
.unwrap_or_default()
.to_string();
let snippet = result
.select(&snippet_selector)
.next()
.map(|e| e.text().collect::<String>().trim().to_string())
.unwrap_or_default();
if !title.is_empty() && !url.is_empty() {
results.push(SearchResult { title, url, snippet });
}
}
Ok(results)
}
}
impl Default for DuckDuckGoSearchProvider {
fn default() -> Self {
Self::new()
}
}
#[async_trait::async_trait]
impl SearchProvider for DuckDuckGoSearchProvider {
fn name(&self) -> &str {
"duckduckgo"
}
async fn search(&self, query: &str) -> Result<Vec<SearchResult>> {
let encoded_query = urlencoding::encode(query);
let url = format!("https://html.duckduckgo.com/html/?q={}", encoded_query);
let response = self.client.get(&url).send().await?;
let html = response.text().await?;
Self::parse_results(&html, self.max_results)
}
}
/// WebSearch client with pluggable providers
pub struct WebSearchClient {
provider: Box<dyn SearchProvider>,
}
impl WebSearchClient {
pub fn new(provider: Box<dyn SearchProvider>) -> Self {
Self { provider }
}
pub fn provider_name(&self) -> &str {
self.provider.name()
}
pub async fn search(&self, query: &str) -> Result<Vec<SearchResult>> {
self.provider.search(query).await
}
}
/// Format search results for LLM consumption (markdown format)
pub fn format_search_results(results: &[SearchResult]) -> String {
if results.is_empty() {
return "No results found.".to_string();
}
results
.iter()
.enumerate()
.map(|(i, r)| format!("{}. [{}]({})\n {}", i + 1, r.title, r.url, r.snippet))
.collect::<Vec<_>>()
.join("\n\n")
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn domain_filtering_allowlist() {
let mut client = WebFetchClient::new();
client.allow_domain("example.com");
assert!(client.is_domain_allowed("example.com"));
assert!(!client.is_domain_allowed("evil.com"));
}
#[test]
fn domain_filtering_blocklist() {
let mut client = WebFetchClient::new();
client.block_domain("evil.com");
assert!(client.is_domain_allowed("example.com")); // Empty allowlist = allow all
assert!(!client.is_domain_allowed("evil.com"));
}
#[test]
fn domain_filtering_case_insensitive() {
let mut client = WebFetchClient::new();
client.allow_domain("Example.COM");
assert!(client.is_domain_allowed("example.com"));
assert!(client.is_domain_allowed("EXAMPLE.COM"));
}
}
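A sketch of plugging a custom provider into WebSearchClient; MyIndexProvider is hypothetical and exists only to illustrate the SearchProvider trait above:

use color_eyre::eyre::Result;
use tools_web::{SearchProvider, SearchResult, WebSearchClient, format_search_results};

// Hypothetical provider backed by some internal index.
struct MyIndexProvider;

#[async_trait::async_trait]
impl SearchProvider for MyIndexProvider {
    fn name(&self) -> &str {
        "my-index"
    }
    async fn search(&self, query: &str) -> Result<Vec<SearchResult>> {
        Ok(vec![SearchResult {
            title: format!("Hit for '{query}'"),
            url: "https://example.com/hit".to_string(),
            snippet: "Illustrative snippet".to_string(),
        }])
    }
}

async fn search_sketch() -> Result<()> {
    let client = WebSearchClient::new(Box::new(MyIndexProvider));
    let results = client.search("owlen planning mode").await?;
    println!("{}", format_search_results(&results));
    Ok(())
}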

View File

@@ -0,0 +1,161 @@
use tools_web::{WebFetchClient, WebSearchClient, StubSearchProvider, SearchResult};
use wiremock::{MockServer, Mock, ResponseTemplate};
use wiremock::matchers::{method, path};
#[tokio::test]
async fn webfetch_domain_whitelist_only() {
let mock_server = MockServer::start().await;
Mock::given(method("GET"))
.and(path("/test"))
.respond_with(ResponseTemplate::new(200).set_body_string("Hello from allowed domain"))
.mount(&mock_server)
.await;
let mut client = WebFetchClient::new();
client.allow_domain("localhost");
client.allow_domain("127.0.0.1"); // Domain without port
// Fetch from allowed domain should work
let url = format!("{}/test", mock_server.uri());
let response = client.fetch(&url).await.unwrap();
assert_eq!(response.status, 200);
assert!(response.content.contains("Hello from allowed domain"));
// Create a client with different allowlist
let mut strict_client = WebFetchClient::new();
strict_client.allow_domain("example.com");
// Fetch from non-allowed domain should fail
let result = strict_client.fetch(&url).await;
assert!(result.is_err());
assert!(result.unwrap_err().to_string().contains("Domain not allowed"));
}
#[tokio::test]
async fn webfetch_redirect_to_unapproved_domain() {
let mock_server = MockServer::start().await;
// Mock a redirect to a different domain
Mock::given(method("GET"))
.and(path("/redirect"))
.respond_with(
ResponseTemplate::new(302)
.insert_header("location", "https://evil.com/malware")
)
.mount(&mock_server)
.await;
let mut client = WebFetchClient::new();
client.allow_domain("localhost");
client.allow_domain("127.0.0.1"); // Domain without port
// evil.com is NOT in the allowlist
let url = format!("{}/redirect", mock_server.uri());
let result = client.fetch(&url).await;
assert!(result.is_err());
let err_msg = result.unwrap_err().to_string();
assert!(err_msg.contains("Redirect to unapproved domain") || err_msg.contains("evil.com"));
}
#[tokio::test]
async fn webfetch_redirect_to_approved_domain() {
let mock_server = MockServer::start().await;
let redirect_url = format!("{}/target", mock_server.uri());
// Mock a redirect to an approved domain
Mock::given(method("GET"))
.and(path("/redirect"))
.respond_with(
ResponseTemplate::new(302)
.insert_header("location", &redirect_url)
)
.mount(&mock_server)
.await;
let mut client = WebFetchClient::new();
client.allow_domain("localhost");
client.allow_domain("127.0.0.1"); // Domain without port
let url = format!("{}/redirect", mock_server.uri());
let result = client.fetch(&url).await;
// Should fail but with a message about using the redirect URL
assert!(result.is_err());
let err_msg = result.unwrap_err().to_string();
assert!(err_msg.contains("Redirect detected") || err_msg.contains("Use the redirect URL"));
}
#[tokio::test]
async fn webfetch_blocklist_overrides_allowlist() {
let mock_server = MockServer::start().await;
Mock::given(method("GET"))
.and(path("/test"))
.respond_with(ResponseTemplate::new(200).set_body_string("Hello"))
.mount(&mock_server)
.await;
let domain = "127.0.0.1";
let mut client = WebFetchClient::new();
client.allow_domain(domain);
client.block_domain(domain); // Block overrides allow
let url = format!("{}/test", mock_server.uri());
let result = client.fetch(&url).await;
assert!(result.is_err());
assert!(result.unwrap_err().to_string().contains("Domain not allowed"));
}
#[tokio::test]
async fn websearch_pluggable_provider() {
let stub_results = vec![
SearchResult {
title: "Test Result 1".to_string(),
url: "https://example.com/1".to_string(),
snippet: "This is a test result".to_string(),
},
SearchResult {
title: "Test Result 2".to_string(),
url: "https://example.com/2".to_string(),
snippet: "Another test result".to_string(),
},
];
let provider = StubSearchProvider::new(stub_results.clone());
let client = WebSearchClient::new(Box::new(provider));
assert_eq!(client.provider_name(), "stub");
let results = client.search("test query").await.unwrap();
assert_eq!(results.len(), 2);
assert_eq!(results[0].title, "Test Result 1");
assert_eq!(results[1].url, "https://example.com/2");
}
#[tokio::test]
async fn webfetch_successful_request() {
let mock_server = MockServer::start().await;
Mock::given(method("GET"))
.and(path("/api/data"))
.respond_with(
ResponseTemplate::new(200)
.set_body_string(r#"{"status":"ok"}"#)
.insert_header("content-type", "application/json")
)
.mount(&mock_server)
.await;
let client = WebFetchClient::new(); // Empty allowlist = allow all
let url = format!("{}/api/data", mock_server.uri());
let response = client.fetch(&url).await.unwrap();
assert_eq!(response.status, 200);
assert!(response.content.contains("status"));
assert!(response.content_type.is_some()); // Just verify content-type is present
}