Compare commits
10 Commits
688d1fe58a
...
4a07b97eab
| Author | SHA1 | Date | |
|---|---|---|---|
| 4a07b97eab | |||
| 10c8e2baae | |||
| 09c8c9d83e | |||
| 5caf502009 | |||
| 04a7085007 | |||
| 6022aeb2b0 | |||
| e77e33ce2f | |||
| f87e5d2796 | |||
| 3c436fda54 | |||
| 173403379f |
13
Cargo.toml
13
Cargo.toml
@@ -1,13 +1,26 @@
|
||||
[workspace]
|
||||
members = [
|
||||
"crates/app/cli",
|
||||
"crates/app/ui",
|
||||
"crates/core/agent",
|
||||
"crates/llm/core",
|
||||
"crates/llm/anthropic",
|
||||
"crates/llm/ollama",
|
||||
"crates/llm/openai",
|
||||
"crates/platform/config",
|
||||
"crates/platform/hooks",
|
||||
"crates/platform/permissions",
|
||||
"crates/platform/plugins",
|
||||
"crates/tools/ask",
|
||||
"crates/tools/bash",
|
||||
"crates/tools/fs",
|
||||
"crates/tools/notebook",
|
||||
"crates/tools/plan",
|
||||
"crates/tools/skill",
|
||||
"crates/tools/slash",
|
||||
"crates/tools/task",
|
||||
"crates/tools/todo",
|
||||
"crates/tools/web",
|
||||
"crates/integration/mcp-client",
|
||||
]
|
||||
resolver = "2"
|
||||
|
||||
@@ -11,6 +11,8 @@ tokio = { version = "1.39", features = ["macros", "rt-multi-thread"] }
|
||||
serde = { version = "1", features = ["derive"] }
|
||||
serde_json = "1"
|
||||
color-eyre = "0.6"
|
||||
agent-core = { path = "../../core/agent" }
|
||||
llm-core = { path = "../../llm/core" }
|
||||
llm-ollama = { path = "../../llm/ollama" }
|
||||
tools-fs = { path = "../../tools/fs" }
|
||||
tools-bash = { path = "../../tools/bash" }
|
||||
@@ -18,6 +20,9 @@ tools-slash = { path = "../../tools/slash" }
|
||||
config-agent = { package = "config-agent", path = "../../platform/config" }
|
||||
permissions = { path = "../../platform/permissions" }
|
||||
hooks = { path = "../../platform/hooks" }
|
||||
plugins = { path = "../../platform/plugins" }
|
||||
ui = { path = "../ui" }
|
||||
atty = "0.2"
|
||||
futures-util = "0.3.31"
|
||||
|
||||
[dev-dependencies]
|
||||
|
||||
382
crates/app/cli/src/commands.rs
Normal file
382
crates/app/cli/src/commands.rs
Normal file
@@ -0,0 +1,382 @@
|
||||
//! Built-in commands for CLI and TUI
|
||||
//!
|
||||
//! Provides handlers for /help, /mcp, /hooks, /clear, and other built-in commands.
|
||||
|
||||
use ui::{CommandInfo, CommandOutput, OutputFormat, TreeNode, ListItem};
|
||||
use permissions::PermissionManager;
|
||||
use hooks::HookManager;
|
||||
use plugins::PluginManager;
|
||||
use agent_core::SessionStats;
|
||||
|
||||
/// Result of executing a built-in command
|
||||
pub enum CommandResult {
|
||||
/// Command produced output to display
|
||||
Output(CommandOutput),
|
||||
/// Command was handled but produced no output (e.g., /clear)
|
||||
Handled,
|
||||
/// Command was not recognized
|
||||
NotFound,
|
||||
/// Command needs to exit the session
|
||||
Exit,
|
||||
}
|
||||
|
||||
/// Built-in command handler
|
||||
pub struct BuiltinCommands<'a> {
|
||||
plugin_manager: Option<&'a PluginManager>,
|
||||
hook_manager: Option<&'a HookManager>,
|
||||
permission_manager: Option<&'a PermissionManager>,
|
||||
stats: Option<&'a SessionStats>,
|
||||
}
|
||||
|
||||
impl<'a> BuiltinCommands<'a> {
|
||||
pub fn new() -> Self {
|
||||
Self {
|
||||
plugin_manager: None,
|
||||
hook_manager: None,
|
||||
permission_manager: None,
|
||||
stats: None,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn with_plugins(mut self, pm: &'a PluginManager) -> Self {
|
||||
self.plugin_manager = Some(pm);
|
||||
self
|
||||
}
|
||||
|
||||
pub fn with_hooks(mut self, hm: &'a HookManager) -> Self {
|
||||
self.hook_manager = Some(hm);
|
||||
self
|
||||
}
|
||||
|
||||
pub fn with_permissions(mut self, perms: &'a PermissionManager) -> Self {
|
||||
self.permission_manager = Some(perms);
|
||||
self
|
||||
}
|
||||
|
||||
pub fn with_stats(mut self, stats: &'a SessionStats) -> Self {
|
||||
self.stats = Some(stats);
|
||||
self
|
||||
}
|
||||
|
||||
/// Execute a built-in command
|
||||
pub fn execute(&self, command: &str) -> CommandResult {
|
||||
let parts: Vec<&str> = command.split_whitespace().collect();
|
||||
let cmd = parts.first().map(|s| s.trim_start_matches('/'));
|
||||
|
||||
match cmd {
|
||||
Some("help") | Some("?") => CommandResult::Output(self.help()),
|
||||
Some("mcp") => CommandResult::Output(self.mcp()),
|
||||
Some("hooks") => CommandResult::Output(self.hooks()),
|
||||
Some("plugins") => CommandResult::Output(self.plugins()),
|
||||
Some("status") => CommandResult::Output(self.status()),
|
||||
Some("permissions") | Some("perms") => CommandResult::Output(self.permissions()),
|
||||
Some("clear") => CommandResult::Handled,
|
||||
Some("exit") | Some("quit") | Some("q") => CommandResult::Exit,
|
||||
_ => CommandResult::NotFound,
|
||||
}
|
||||
}
|
||||
|
||||
/// Generate help output
|
||||
fn help(&self) -> CommandOutput {
|
||||
let mut commands = vec![
|
||||
// Built-in commands
|
||||
CommandInfo::new("help", "Show available commands", "builtin"),
|
||||
CommandInfo::new("clear", "Clear the screen", "builtin"),
|
||||
CommandInfo::new("status", "Show session status", "builtin"),
|
||||
CommandInfo::new("permissions", "Show permission settings", "builtin"),
|
||||
CommandInfo::new("mcp", "List MCP servers and tools", "builtin"),
|
||||
CommandInfo::new("hooks", "Show loaded hooks", "builtin"),
|
||||
CommandInfo::new("plugins", "Show loaded plugins", "builtin"),
|
||||
CommandInfo::new("checkpoint", "Save session state", "builtin"),
|
||||
CommandInfo::new("checkpoints", "List saved checkpoints", "builtin"),
|
||||
CommandInfo::new("rewind", "Restore from checkpoint", "builtin"),
|
||||
CommandInfo::new("compact", "Compact conversation context", "builtin"),
|
||||
CommandInfo::new("exit", "Exit the session", "builtin"),
|
||||
];
|
||||
|
||||
// Add plugin commands
|
||||
if let Some(pm) = self.plugin_manager {
|
||||
for plugin in pm.plugins() {
|
||||
for cmd_name in plugin.all_command_names() {
|
||||
commands.push(CommandInfo::new(
|
||||
&cmd_name,
|
||||
&format!("Plugin command from {}", plugin.manifest.name),
|
||||
&format!("plugin:{}", plugin.manifest.name),
|
||||
));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
CommandOutput::help_table(&commands)
|
||||
}
|
||||
|
||||
/// Generate MCP servers output
|
||||
fn mcp(&self) -> CommandOutput {
|
||||
let mut servers: Vec<(String, Vec<String>)> = vec![];
|
||||
|
||||
// Get MCP servers from plugins
|
||||
if let Some(pm) = self.plugin_manager {
|
||||
for plugin in pm.plugins() {
|
||||
// Check for .mcp.json in plugin directory
|
||||
let mcp_path = plugin.base_path.join(".mcp.json");
|
||||
if mcp_path.exists() {
|
||||
if let Ok(content) = std::fs::read_to_string(&mcp_path) {
|
||||
if let Ok(config) = serde_json::from_str::<serde_json::Value>(&content) {
|
||||
if let Some(mcpservers) = config.get("mcpServers").and_then(|v| v.as_object()) {
|
||||
for (name, _) in mcpservers {
|
||||
servers.push((
|
||||
format!("{} ({})", name, plugin.manifest.name),
|
||||
vec!["(connect to discover tools)".to_string()],
|
||||
));
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if servers.is_empty() {
|
||||
CommandOutput::new(OutputFormat::Text {
|
||||
content: "No MCP servers configured.\n\nAdd MCP servers in plugin .mcp.json files.".to_string(),
|
||||
})
|
||||
} else {
|
||||
CommandOutput::mcp_tree(&servers)
|
||||
}
|
||||
}
|
||||
|
||||
/// Generate hooks output
|
||||
fn hooks(&self) -> CommandOutput {
|
||||
let mut hooks_list: Vec<(String, String, bool)> = vec![];
|
||||
|
||||
// Check for file-based hooks in .owlen/hooks/
|
||||
let hook_events = ["PreToolUse", "PostToolUse", "SessionStart", "SessionEnd",
|
||||
"UserPromptSubmit", "PreCompact", "Stop", "SubagentStop"];
|
||||
|
||||
for event in hook_events {
|
||||
let path = format!(".owlen/hooks/{}", event);
|
||||
let exists = std::path::Path::new(&path).exists();
|
||||
if exists {
|
||||
hooks_list.push((event.to_string(), path, true));
|
||||
}
|
||||
}
|
||||
|
||||
// Get hooks from plugins
|
||||
if let Some(pm) = self.plugin_manager {
|
||||
for plugin in pm.plugins() {
|
||||
if let Some(hooks_config) = plugin.load_hooks_config().ok().flatten() {
|
||||
// hooks_config.hooks is HashMap<String, Vec<HookMatcher>>
|
||||
for (event_name, matchers) in &hooks_config.hooks {
|
||||
for matcher in matchers {
|
||||
for hook_def in &matcher.hooks {
|
||||
let cmd = hook_def.command.as_deref()
|
||||
.or(hook_def.prompt.as_deref())
|
||||
.unwrap_or("(no command)");
|
||||
hooks_list.push((
|
||||
event_name.clone(),
|
||||
format!("{}: {}", plugin.manifest.name, cmd),
|
||||
true,
|
||||
));
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if hooks_list.is_empty() {
|
||||
CommandOutput::new(OutputFormat::Text {
|
||||
content: "No hooks configured.\n\nAdd hooks in .owlen/hooks/ or plugin hooks.json files.".to_string(),
|
||||
})
|
||||
} else {
|
||||
CommandOutput::hooks_list(&hooks_list)
|
||||
}
|
||||
}
|
||||
|
||||
/// Generate plugins output
|
||||
fn plugins(&self) -> CommandOutput {
|
||||
if let Some(pm) = self.plugin_manager {
|
||||
let plugins = pm.plugins();
|
||||
if plugins.is_empty() {
|
||||
return CommandOutput::new(OutputFormat::Text {
|
||||
content: "No plugins loaded.\n\nPlace plugins in:\n - ~/.config/owlen/plugins (user)\n - .owlen/plugins (project)".to_string(),
|
||||
});
|
||||
}
|
||||
|
||||
// Build tree of plugins and their components
|
||||
let children: Vec<TreeNode> = plugins.iter().map(|p| {
|
||||
let mut plugin_children = vec![];
|
||||
|
||||
let commands = p.all_command_names();
|
||||
if !commands.is_empty() {
|
||||
plugin_children.push(TreeNode::new("Commands").with_children(
|
||||
commands.iter().map(|c| TreeNode::new(format!("/{}", c))).collect()
|
||||
));
|
||||
}
|
||||
|
||||
let agents = p.all_agent_names();
|
||||
if !agents.is_empty() {
|
||||
plugin_children.push(TreeNode::new("Agents").with_children(
|
||||
agents.iter().map(|a| TreeNode::new(a)).collect()
|
||||
));
|
||||
}
|
||||
|
||||
let skills = p.all_skill_names();
|
||||
if !skills.is_empty() {
|
||||
plugin_children.push(TreeNode::new("Skills").with_children(
|
||||
skills.iter().map(|s| TreeNode::new(s)).collect()
|
||||
));
|
||||
}
|
||||
|
||||
TreeNode::new(format!("{} v{}", p.manifest.name, p.manifest.version))
|
||||
.with_children(plugin_children)
|
||||
}).collect();
|
||||
|
||||
CommandOutput::new(OutputFormat::Tree {
|
||||
root: TreeNode::new("Loaded Plugins").with_children(children),
|
||||
})
|
||||
} else {
|
||||
CommandOutput::new(OutputFormat::Text {
|
||||
content: "Plugin manager not available.".to_string(),
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
/// Generate status output
|
||||
fn status(&self) -> CommandOutput {
|
||||
let mut items = vec![];
|
||||
|
||||
if let Some(stats) = self.stats {
|
||||
items.push(ListItem {
|
||||
text: format!("Messages: {}", stats.total_messages),
|
||||
marker: Some("📊".to_string()),
|
||||
style: None,
|
||||
});
|
||||
items.push(ListItem {
|
||||
text: format!("Tool Calls: {}", stats.total_tool_calls),
|
||||
marker: Some("🔧".to_string()),
|
||||
style: None,
|
||||
});
|
||||
items.push(ListItem {
|
||||
text: format!("Est. Tokens: ~{}", stats.estimated_tokens),
|
||||
marker: Some("📝".to_string()),
|
||||
style: None,
|
||||
});
|
||||
let uptime = stats.start_time.elapsed().unwrap_or_default();
|
||||
items.push(ListItem {
|
||||
text: format!("Uptime: {}", SessionStats::format_duration(uptime)),
|
||||
marker: Some("⏱️".to_string()),
|
||||
style: None,
|
||||
});
|
||||
}
|
||||
|
||||
if let Some(perms) = self.permission_manager {
|
||||
items.push(ListItem {
|
||||
text: format!("Mode: {:?}", perms.mode()),
|
||||
marker: Some("🔒".to_string()),
|
||||
style: None,
|
||||
});
|
||||
}
|
||||
|
||||
if items.is_empty() {
|
||||
CommandOutput::new(OutputFormat::Text {
|
||||
content: "Session status not available.".to_string(),
|
||||
})
|
||||
} else {
|
||||
CommandOutput::new(OutputFormat::List { items })
|
||||
}
|
||||
}
|
||||
|
||||
/// Generate permissions output
|
||||
fn permissions(&self) -> CommandOutput {
|
||||
if let Some(perms) = self.permission_manager {
|
||||
let mode = perms.mode();
|
||||
let mode_str = format!("{:?}", mode);
|
||||
|
||||
let mut items = vec![
|
||||
ListItem {
|
||||
text: format!("Current Mode: {}", mode_str),
|
||||
marker: Some("🔒".to_string()),
|
||||
style: None,
|
||||
},
|
||||
];
|
||||
|
||||
// Add tool permissions summary
|
||||
let (read_status, write_status, bash_status) = match mode {
|
||||
permissions::Mode::Plan => ("✅ Allowed", "❓ Ask", "❓ Ask"),
|
||||
permissions::Mode::AcceptEdits => ("✅ Allowed", "✅ Allowed", "❓ Ask"),
|
||||
permissions::Mode::Code => ("✅ Allowed", "✅ Allowed", "✅ Allowed"),
|
||||
};
|
||||
|
||||
items.push(ListItem {
|
||||
text: format!("Read/Grep/Glob: {}", read_status),
|
||||
marker: None,
|
||||
style: None,
|
||||
});
|
||||
items.push(ListItem {
|
||||
text: format!("Write/Edit: {}", write_status),
|
||||
marker: None,
|
||||
style: None,
|
||||
});
|
||||
items.push(ListItem {
|
||||
text: format!("Bash: {}", bash_status),
|
||||
marker: None,
|
||||
style: None,
|
||||
});
|
||||
|
||||
CommandOutput::new(OutputFormat::List { items })
|
||||
} else {
|
||||
CommandOutput::new(OutputFormat::Text {
|
||||
content: "Permission manager not available.".to_string(),
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for BuiltinCommands<'_> {
|
||||
fn default() -> Self {
|
||||
Self::new()
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn test_help_command() {
|
||||
let handler = BuiltinCommands::new();
|
||||
match handler.execute("/help") {
|
||||
CommandResult::Output(output) => {
|
||||
match output.format {
|
||||
OutputFormat::Table { headers, rows } => {
|
||||
assert!(!headers.is_empty());
|
||||
assert!(!rows.is_empty());
|
||||
}
|
||||
_ => panic!("Expected Table format"),
|
||||
}
|
||||
}
|
||||
_ => panic!("Expected Output result"),
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_exit_command() {
|
||||
let handler = BuiltinCommands::new();
|
||||
assert!(matches!(handler.execute("/exit"), CommandResult::Exit));
|
||||
assert!(matches!(handler.execute("/quit"), CommandResult::Exit));
|
||||
assert!(matches!(handler.execute("/q"), CommandResult::Exit));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_clear_command() {
|
||||
let handler = BuiltinCommands::new();
|
||||
assert!(matches!(handler.execute("/clear"), CommandResult::Handled));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_unknown_command() {
|
||||
let handler = BuiltinCommands::new();
|
||||
assert!(matches!(handler.execute("/unknown"), CommandResult::NotFound));
|
||||
}
|
||||
}
|
||||
@@ -1,14 +1,19 @@
|
||||
mod commands;
|
||||
|
||||
use clap::{Parser, ValueEnum};
|
||||
use color_eyre::eyre::{Result, eyre};
|
||||
use config_agent::load_settings;
|
||||
use futures_util::TryStreamExt;
|
||||
use hooks::{HookEvent, HookManager, HookResult};
|
||||
use llm_ollama::{OllamaClient, OllamaOptions, types::ChatMessage};
|
||||
use llm_core::ChatOptions;
|
||||
use llm_ollama::OllamaClient;
|
||||
use permissions::{PermissionDecision, Tool};
|
||||
use plugins::PluginManager;
|
||||
use serde::Serialize;
|
||||
use std::io::{self, Write};
|
||||
use std::io::Write;
|
||||
use std::time::{SystemTime, UNIX_EPOCH};
|
||||
|
||||
pub use commands::{BuiltinCommands, CommandResult};
|
||||
|
||||
#[derive(Debug, Clone, Copy, ValueEnum)]
|
||||
enum OutputFormat {
|
||||
Text,
|
||||
@@ -49,6 +54,51 @@ struct StreamEvent {
|
||||
stats: Option<Stats>,
|
||||
}
|
||||
|
||||
/// Application context shared across the session
|
||||
pub struct AppContext {
|
||||
pub plugin_manager: PluginManager,
|
||||
pub config: config_agent::Settings,
|
||||
}
|
||||
|
||||
impl AppContext {
|
||||
pub fn new() -> Result<Self> {
|
||||
let config = load_settings(None).unwrap_or_default();
|
||||
|
||||
let mut plugin_manager = PluginManager::new();
|
||||
// Non-fatal: just log warnings, don't fail startup
|
||||
if let Err(e) = plugin_manager.load_all() {
|
||||
eprintln!("Warning: Failed to load some plugins: {}", e);
|
||||
}
|
||||
|
||||
Ok(Self {
|
||||
plugin_manager,
|
||||
config,
|
||||
})
|
||||
}
|
||||
|
||||
/// Print loaded plugins and available commands
|
||||
pub fn print_plugin_info(&self) {
|
||||
let plugins = self.plugin_manager.plugins();
|
||||
if !plugins.is_empty() {
|
||||
println!("\nLoaded {} plugin(s):", plugins.len());
|
||||
for plugin in plugins {
|
||||
println!(" - {} v{}", plugin.manifest.name, plugin.manifest.version);
|
||||
if let Some(desc) = &plugin.manifest.description {
|
||||
println!(" {}", desc);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
let commands = self.plugin_manager.all_commands();
|
||||
if !commands.is_empty() {
|
||||
println!("\nAvailable plugin commands:");
|
||||
for (name, _path) in &commands {
|
||||
println!(" /{}", name);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn generate_session_id() -> String {
|
||||
let timestamp = SystemTime::now()
|
||||
.duration_since(UNIX_EPOCH)
|
||||
@@ -150,6 +200,9 @@ struct Args {
|
||||
/// Output format (text, json, stream-json)
|
||||
#[arg(long, value_enum, default_value = "text")]
|
||||
output_format: OutputFormat,
|
||||
/// Disable TUI and use legacy text-based REPL
|
||||
#[arg(long)]
|
||||
no_tui: bool,
|
||||
#[arg()]
|
||||
prompt: Vec<String>,
|
||||
#[command(subcommand)]
|
||||
@@ -160,7 +213,10 @@ struct Args {
|
||||
async fn main() -> Result<()> {
|
||||
color_eyre::install()?;
|
||||
let args = Args::parse();
|
||||
let mut settings = load_settings(None).unwrap_or_default();
|
||||
|
||||
// Initialize application context with plugins
|
||||
let app_context = AppContext::new()?;
|
||||
let mut settings = app_context.config.clone();
|
||||
|
||||
// Override mode if specified via CLI
|
||||
if let Some(mode) = args.mode {
|
||||
@@ -171,7 +227,16 @@ async fn main() -> Result<()> {
|
||||
let perms = settings.create_permission_manager();
|
||||
|
||||
// Create hook manager
|
||||
let hook_mgr = HookManager::new(".");
|
||||
let mut hook_mgr = HookManager::new(".");
|
||||
|
||||
// Register plugin hooks
|
||||
for plugin in app_context.plugin_manager.plugins() {
|
||||
if let Ok(Some(hooks_config)) = plugin.load_hooks_config() {
|
||||
for (event, command, pattern, timeout) in plugin.register_hooks_with_manager(&hooks_config) {
|
||||
hook_mgr.register_hook(event, command, pattern, timeout);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Generate session ID
|
||||
let session_id = generate_session_id();
|
||||
@@ -395,19 +460,20 @@ async fn main() -> Result<()> {
|
||||
HookResult::Allow => {}
|
||||
}
|
||||
|
||||
// Look for command file in .owlen/commands/
|
||||
let command_path = format!(".owlen/commands/{}.md", command_name);
|
||||
// Look for command file in .owlen/commands/ first
|
||||
let local_command_path = format!(".owlen/commands/{}.md", command_name);
|
||||
|
||||
// Read the command file
|
||||
let content = match tools_fs::read_file(&command_path) {
|
||||
Ok(c) => c,
|
||||
Err(_) => {
|
||||
return Err(eyre!(
|
||||
"Slash command '{}' not found at {}",
|
||||
command_name,
|
||||
command_path
|
||||
));
|
||||
}
|
||||
// Try local commands first, then plugin commands
|
||||
let content = if let Ok(c) = tools_fs::read_file(&local_command_path) {
|
||||
c
|
||||
} else if let Some(plugin_path) = app_context.plugin_manager.all_commands().get(&command_name) {
|
||||
// Found in plugins
|
||||
tools_fs::read_file(&plugin_path.to_string_lossy())?
|
||||
} else {
|
||||
return Err(eyre!(
|
||||
"Slash command '{}' not found in .owlen/commands/ or plugins",
|
||||
command_name
|
||||
));
|
||||
};
|
||||
|
||||
// Parse with arguments
|
||||
@@ -435,76 +501,316 @@ async fn main() -> Result<()> {
|
||||
}
|
||||
}
|
||||
|
||||
let prompt = if args.prompt.is_empty() {
|
||||
"Say hello".to_string()
|
||||
} else {
|
||||
args.prompt.join(" ")
|
||||
};
|
||||
|
||||
let model = args.model.unwrap_or(settings.model);
|
||||
let api_key = args.api_key.or(settings.api_key);
|
||||
let model = args.model.unwrap_or(settings.model.clone());
|
||||
let api_key = args.api_key.or(settings.api_key.clone());
|
||||
|
||||
// Use Ollama Cloud when model has "-cloud" suffix AND API key is set
|
||||
let use_cloud = model.ends_with("-cloud") && api_key.is_some();
|
||||
let client = if use_cloud {
|
||||
OllamaClient::with_cloud().with_api_key(api_key.unwrap())
|
||||
} else {
|
||||
let base_url = args.ollama_url.unwrap_or(settings.ollama_url);
|
||||
let base_url = args.ollama_url.unwrap_or(settings.ollama_url.clone());
|
||||
let mut client = OllamaClient::new(base_url);
|
||||
if let Some(key) = api_key {
|
||||
client = client.with_api_key(key);
|
||||
}
|
||||
client
|
||||
};
|
||||
let opts = OllamaOptions {
|
||||
model,
|
||||
stream: true,
|
||||
};
|
||||
let opts = ChatOptions::new(model);
|
||||
|
||||
let msgs = vec![ChatMessage {
|
||||
role: "user".into(),
|
||||
content: prompt.clone(),
|
||||
}];
|
||||
// Check if interactive mode (no prompt provided)
|
||||
if args.prompt.is_empty() {
|
||||
// Use TUI mode unless --no-tui flag is set or not a TTY
|
||||
if !args.no_tui && atty::is(atty::Stream::Stdout) {
|
||||
// Launch TUI
|
||||
// Note: For now, TUI doesn't use plugin manager directly
|
||||
// In the future, we'll integrate plugin commands into TUI
|
||||
return ui::run(client, opts, perms, settings).await;
|
||||
}
|
||||
|
||||
// Legacy text-based REPL
|
||||
println!("🤖 Owlen Interactive Mode");
|
||||
println!("Model: {}", opts.model);
|
||||
println!("Mode: {:?}", settings.mode);
|
||||
|
||||
// Show loaded plugins
|
||||
let plugins = app_context.plugin_manager.plugins();
|
||||
if !plugins.is_empty() {
|
||||
println!("Plugins: {} loaded", plugins.len());
|
||||
}
|
||||
|
||||
println!("Type your message or /help for commands. Press Ctrl+C to exit.\n");
|
||||
|
||||
use std::io::{stdin, BufRead};
|
||||
let stdin = stdin();
|
||||
let mut lines = stdin.lock().lines();
|
||||
let mut stats = agent_core::SessionStats::new();
|
||||
let mut history = agent_core::SessionHistory::new();
|
||||
let mut checkpoint_mgr = agent_core::CheckpointManager::new(
|
||||
std::path::PathBuf::from(".owlen/checkpoints")
|
||||
);
|
||||
|
||||
loop {
|
||||
print!("> ");
|
||||
std::io::stdout().flush().ok();
|
||||
|
||||
if let Some(Ok(line)) = lines.next() {
|
||||
let input = line.trim();
|
||||
if input.is_empty() {
|
||||
continue;
|
||||
}
|
||||
|
||||
// Handle slash commands
|
||||
if input.starts_with('/') {
|
||||
match input {
|
||||
"/help" => {
|
||||
println!("\n📖 Available Commands:");
|
||||
println!(" /help - Show this help message");
|
||||
println!(" /status - Show session status");
|
||||
println!(" /permissions - Show permission settings");
|
||||
println!(" /cost - Show token usage and timing");
|
||||
println!(" /history - Show conversation history");
|
||||
println!(" /checkpoint - Save current session state");
|
||||
println!(" /checkpoints - List all saved checkpoints");
|
||||
println!(" /rewind <id> - Restore session from checkpoint");
|
||||
println!(" /clear - Clear conversation history");
|
||||
println!(" /plugins - Show loaded plugins and commands");
|
||||
println!(" /exit - Exit interactive mode");
|
||||
|
||||
// Show plugin commands if any are loaded
|
||||
let plugin_commands = app_context.plugin_manager.all_commands();
|
||||
if !plugin_commands.is_empty() {
|
||||
println!("\n📦 Plugin Commands:");
|
||||
for (name, _path) in &plugin_commands {
|
||||
println!(" /{}", name);
|
||||
}
|
||||
}
|
||||
}
|
||||
"/status" => {
|
||||
println!("\n📊 Session Status:");
|
||||
println!(" Model: {}", opts.model);
|
||||
println!(" Mode: {:?}", settings.mode);
|
||||
println!(" Messages: {}", stats.total_messages);
|
||||
println!(" Tools: {} calls", stats.total_tool_calls);
|
||||
let elapsed = stats.start_time.elapsed().unwrap_or_default();
|
||||
println!(" Uptime: {}", agent_core::SessionStats::format_duration(elapsed));
|
||||
}
|
||||
"/permissions" => {
|
||||
println!("\n🔒 Permission Settings:");
|
||||
println!(" Mode: {:?}", perms.mode());
|
||||
println!("\n Read-only tools: Read, Grep, Glob, NotebookRead");
|
||||
match perms.mode() {
|
||||
permissions::Mode::Plan => {
|
||||
println!(" ✅ Allowed (plan mode)");
|
||||
println!("\n Write tools: Write, Edit, NotebookEdit");
|
||||
println!(" ❓ Ask permission");
|
||||
println!("\n System tools: Bash");
|
||||
println!(" ❓ Ask permission");
|
||||
}
|
||||
permissions::Mode::AcceptEdits => {
|
||||
println!(" ✅ Allowed");
|
||||
println!("\n Write tools: Write, Edit, NotebookEdit");
|
||||
println!(" ✅ Allowed (acceptEdits mode)");
|
||||
println!("\n System tools: Bash");
|
||||
println!(" ❓ Ask permission");
|
||||
}
|
||||
permissions::Mode::Code => {
|
||||
println!(" ✅ Allowed");
|
||||
println!("\n Write tools: Write, Edit, NotebookEdit");
|
||||
println!(" ✅ Allowed (code mode)");
|
||||
println!("\n System tools: Bash");
|
||||
println!(" ✅ Allowed (code mode)");
|
||||
}
|
||||
}
|
||||
}
|
||||
"/cost" => {
|
||||
println!("\n💰 Token Usage & Timing:");
|
||||
println!(" Est. Tokens: ~{}", stats.estimated_tokens);
|
||||
println!(" Total Time: {}", agent_core::SessionStats::format_duration(stats.total_duration));
|
||||
if stats.total_messages > 0 {
|
||||
let avg_time = stats.total_duration / stats.total_messages as u32;
|
||||
println!(" Avg/Message: {}", agent_core::SessionStats::format_duration(avg_time));
|
||||
}
|
||||
println!("\n Note: Ollama is free - no cost incurred!");
|
||||
}
|
||||
"/history" => {
|
||||
println!("\n📜 Conversation History:");
|
||||
if history.user_prompts.is_empty() {
|
||||
println!(" (No messages yet)");
|
||||
} else {
|
||||
for (i, (user, assistant)) in history.user_prompts.iter()
|
||||
.zip(history.assistant_responses.iter()).enumerate() {
|
||||
println!("\n [{}] User: {}", i + 1, user);
|
||||
println!(" Assistant: {}...",
|
||||
assistant.chars().take(100).collect::<String>());
|
||||
}
|
||||
}
|
||||
if !history.tool_calls.is_empty() {
|
||||
println!("\n Tool Calls: {}", history.tool_calls.len());
|
||||
}
|
||||
}
|
||||
"/checkpoint" => {
|
||||
let checkpoint_id = format!("checkpoint-{}",
|
||||
SystemTime::now()
|
||||
.duration_since(UNIX_EPOCH)
|
||||
.unwrap()
|
||||
.as_secs()
|
||||
);
|
||||
match checkpoint_mgr.save_checkpoint(
|
||||
checkpoint_id.clone(),
|
||||
stats.clone(),
|
||||
&history,
|
||||
) {
|
||||
Ok(checkpoint) => {
|
||||
println!("\n💾 Checkpoint saved: {}", checkpoint_id);
|
||||
if !checkpoint.file_diffs.is_empty() {
|
||||
println!(" Files tracked: {}", checkpoint.file_diffs.len());
|
||||
}
|
||||
}
|
||||
Err(e) => {
|
||||
eprintln!("\n❌ Failed to save checkpoint: {}", e);
|
||||
}
|
||||
}
|
||||
}
|
||||
"/checkpoints" => {
|
||||
match checkpoint_mgr.list_checkpoints() {
|
||||
Ok(checkpoints) => {
|
||||
if checkpoints.is_empty() {
|
||||
println!("\n📋 No checkpoints saved yet");
|
||||
} else {
|
||||
println!("\n📋 Saved Checkpoints:");
|
||||
for (i, cp_id) in checkpoints.iter().enumerate() {
|
||||
println!(" [{}] {}", i + 1, cp_id);
|
||||
}
|
||||
println!("\n Use /rewind <id> to restore");
|
||||
}
|
||||
}
|
||||
Err(e) => {
|
||||
eprintln!("\n❌ Failed to list checkpoints: {}", e);
|
||||
}
|
||||
}
|
||||
}
|
||||
"/clear" => {
|
||||
history.clear();
|
||||
stats = agent_core::SessionStats::new();
|
||||
println!("\n🗑️ Session history cleared!");
|
||||
}
|
||||
"/plugins" => {
|
||||
let plugins = app_context.plugin_manager.plugins();
|
||||
if plugins.is_empty() {
|
||||
println!("\n📦 No plugins loaded");
|
||||
println!(" Place plugins in:");
|
||||
println!(" - ~/.config/owlen/plugins (user plugins)");
|
||||
println!(" - .owlen/plugins (project plugins)");
|
||||
} else {
|
||||
println!("\n📦 Loaded Plugins:");
|
||||
for plugin in plugins {
|
||||
println!("\n {} v{}", plugin.manifest.name, plugin.manifest.version);
|
||||
if let Some(desc) = &plugin.manifest.description {
|
||||
println!(" {}", desc);
|
||||
}
|
||||
if let Some(author) = &plugin.manifest.author {
|
||||
println!(" Author: {}", author);
|
||||
}
|
||||
|
||||
let commands = plugin.all_command_names();
|
||||
if !commands.is_empty() {
|
||||
println!(" Commands: {}", commands.join(", "));
|
||||
}
|
||||
|
||||
let agents = plugin.all_agent_names();
|
||||
if !agents.is_empty() {
|
||||
println!(" Agents: {}", agents.join(", "));
|
||||
}
|
||||
|
||||
let skills = plugin.all_skill_names();
|
||||
if !skills.is_empty() {
|
||||
println!(" Skills: {}", skills.join(", "));
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
"/exit" => {
|
||||
println!("\n👋 Goodbye!");
|
||||
break;
|
||||
}
|
||||
cmd if cmd.starts_with("/rewind ") => {
|
||||
let checkpoint_id = cmd.strip_prefix("/rewind ").unwrap().trim();
|
||||
match checkpoint_mgr.rewind_to(checkpoint_id) {
|
||||
Ok(restored_files) => {
|
||||
println!("\n⏪ Rewound to checkpoint: {}", checkpoint_id);
|
||||
if !restored_files.is_empty() {
|
||||
println!(" Restored files:");
|
||||
for file in restored_files {
|
||||
println!(" - {}", file.display());
|
||||
}
|
||||
}
|
||||
// Load the checkpoint to restore history and stats
|
||||
if let Ok(checkpoint) = checkpoint_mgr.load_checkpoint(checkpoint_id) {
|
||||
stats = checkpoint.stats;
|
||||
history.user_prompts = checkpoint.user_prompts;
|
||||
history.assistant_responses = checkpoint.assistant_responses;
|
||||
history.tool_calls = checkpoint.tool_calls;
|
||||
println!(" Session state restored");
|
||||
}
|
||||
}
|
||||
Err(e) => {
|
||||
eprintln!("\n❌ Failed to rewind: {}", e);
|
||||
}
|
||||
}
|
||||
}
|
||||
_ => {
|
||||
println!("\n❌ Unknown command: {}", input);
|
||||
println!(" Type /help for available commands");
|
||||
}
|
||||
}
|
||||
continue;
|
||||
}
|
||||
|
||||
// Regular message - run through agent loop
|
||||
history.add_user_message(input.to_string());
|
||||
let start = SystemTime::now();
|
||||
|
||||
let ctx = agent_core::ToolContext::new();
|
||||
match agent_core::run_agent_loop(&client, input, &opts, &perms, &ctx).await {
|
||||
Ok(response) => {
|
||||
println!("\n{}", response);
|
||||
history.add_assistant_message(response.clone());
|
||||
|
||||
// Update stats
|
||||
let duration = start.elapsed().unwrap_or_default();
|
||||
let tokens = (input.len() + response.len()) / 4; // Rough estimate
|
||||
stats.record_message(tokens, duration);
|
||||
}
|
||||
Err(e) => {
|
||||
eprintln!("\n❌ Error: {}", e);
|
||||
}
|
||||
}
|
||||
} else {
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
// Non-interactive mode - process single prompt
|
||||
let prompt = args.prompt.join(" ");
|
||||
let start_time = SystemTime::now();
|
||||
|
||||
// Handle different output formats
|
||||
let ctx = agent_core::ToolContext::new();
|
||||
match output_format {
|
||||
OutputFormat::Text => {
|
||||
// Text format: stream to stdout as before
|
||||
let mut stream = client.chat_stream(&msgs, &opts).await?;
|
||||
while let Some(chunk) = stream.try_next().await? {
|
||||
if let Some(m) = chunk.message {
|
||||
if let Some(c) = m.content {
|
||||
print!("{c}");
|
||||
io::stdout().flush()?;
|
||||
}
|
||||
}
|
||||
if matches!(chunk.done, Some(true)) {
|
||||
break;
|
||||
}
|
||||
}
|
||||
println!(); // Newline after response
|
||||
// Text format: Use agent orchestrator with tool calling
|
||||
let response = agent_core::run_agent_loop(&client, &prompt, &opts, &perms, &ctx).await?;
|
||||
println!("{}", response);
|
||||
}
|
||||
OutputFormat::Json => {
|
||||
// JSON format: collect all chunks, then output final JSON
|
||||
let mut stream = client.chat_stream(&msgs, &opts).await?;
|
||||
let mut response = String::new();
|
||||
|
||||
while let Some(chunk) = stream.try_next().await? {
|
||||
if let Some(m) = chunk.message {
|
||||
if let Some(c) = m.content {
|
||||
response.push_str(&c);
|
||||
}
|
||||
}
|
||||
if matches!(chunk.done, Some(true)) {
|
||||
break;
|
||||
}
|
||||
}
|
||||
// JSON format: Use agent loop and output as JSON
|
||||
let response = agent_core::run_agent_loop(&client, &prompt, &opts, &perms, &ctx).await?;
|
||||
|
||||
let duration_ms = start_time.elapsed().unwrap().as_millis() as u64;
|
||||
|
||||
// Rough token estimate (tokens ~= chars / 4)
|
||||
let estimated_tokens = ((prompt.len() + response.len()) / 4) as u64;
|
||||
|
||||
let output = SessionOutput {
|
||||
@@ -526,7 +832,7 @@ async fn main() -> Result<()> {
|
||||
println!("{}", serde_json::to_string(&output)?);
|
||||
}
|
||||
OutputFormat::StreamJson => {
|
||||
// Stream-JSON format: emit session_start, chunks, and session_end
|
||||
// Stream-JSON format: emit session_start, response, and session_end
|
||||
let session_start = StreamEvent {
|
||||
event_type: "session_start".to_string(),
|
||||
session_id: Some(session_id.clone()),
|
||||
@@ -535,30 +841,17 @@ async fn main() -> Result<()> {
|
||||
};
|
||||
println!("{}", serde_json::to_string(&session_start)?);
|
||||
|
||||
let mut stream = client.chat_stream(&msgs, &opts).await?;
|
||||
let mut response = String::new();
|
||||
let response = agent_core::run_agent_loop(&client, &prompt, &opts, &perms, &ctx).await?;
|
||||
|
||||
while let Some(chunk) = stream.try_next().await? {
|
||||
if let Some(m) = chunk.message {
|
||||
if let Some(c) = m.content {
|
||||
response.push_str(&c);
|
||||
let chunk_event = StreamEvent {
|
||||
event_type: "chunk".to_string(),
|
||||
session_id: None,
|
||||
content: Some(c),
|
||||
stats: None,
|
||||
};
|
||||
println!("{}", serde_json::to_string(&chunk_event)?);
|
||||
}
|
||||
}
|
||||
if matches!(chunk.done, Some(true)) {
|
||||
break;
|
||||
}
|
||||
}
|
||||
let chunk_event = StreamEvent {
|
||||
event_type: "chunk".to_string(),
|
||||
session_id: None,
|
||||
content: Some(response.clone()),
|
||||
stats: None,
|
||||
};
|
||||
println!("{}", serde_json::to_string(&chunk_event)?);
|
||||
|
||||
let duration_ms = start_time.elapsed().unwrap().as_millis() as u64;
|
||||
|
||||
// Rough token estimate
|
||||
let estimated_tokens = ((prompt.len() + response.len()) / 4) as u64;
|
||||
|
||||
let session_end = StreamEvent {
|
||||
|
||||
@@ -5,12 +5,6 @@ use predicates::prelude::PredicateBooleanExt;
|
||||
#[tokio::test]
|
||||
async fn headless_streams_ndjson() {
|
||||
let server = MockServer::start_async().await;
|
||||
// Mock /api/chat with NDJSON lines
|
||||
let body = serde_json::json!({
|
||||
"model": "qwen2.5",
|
||||
"messages": [{"role": "user", "content": "hello"}],
|
||||
"stream": true
|
||||
});
|
||||
|
||||
let response = concat!(
|
||||
r#"{"message":{"role":"assistant","content":"Hel"}}"#,"\n",
|
||||
@@ -18,10 +12,11 @@ async fn headless_streams_ndjson() {
|
||||
r#"{"done":true}"#,"\n",
|
||||
);
|
||||
|
||||
// The CLI includes tools in the request, so we need to match any request to /api/chat
|
||||
// instead of matching exact body (which includes tool definitions)
|
||||
let _m = server.mock(|when, then| {
|
||||
when.method(POST)
|
||||
.path("/api/chat")
|
||||
.json_body(body.clone());
|
||||
.path("/api/chat");
|
||||
then.status(200)
|
||||
.header("content-type", "application/x-ndjson")
|
||||
.body(response);
|
||||
|
||||
27
crates/app/ui/Cargo.toml
Normal file
27
crates/app/ui/Cargo.toml
Normal file
@@ -0,0 +1,27 @@
|
||||
[package]
|
||||
name = "ui"
|
||||
version = "0.1.0"
|
||||
edition.workspace = true
|
||||
license.workspace = true
|
||||
rust-version.workspace = true
|
||||
|
||||
[dependencies]
|
||||
color-eyre = "0.6"
|
||||
crossterm = { version = "0.28", features = ["event-stream"] }
|
||||
ratatui = "0.28"
|
||||
tokio = { version = "1", features = ["full"] }
|
||||
futures = "0.3"
|
||||
serde = { version = "1.0", features = ["derive"] }
|
||||
serde_json = "1.0"
|
||||
unicode-width = "0.2"
|
||||
textwrap = "0.16"
|
||||
syntect = { version = "5.0", default-features = false, features = ["default-syntaxes", "default-themes", "regex-onig"] }
|
||||
pulldown-cmark = "0.11"
|
||||
|
||||
# Internal dependencies
|
||||
agent-core = { path = "../../core/agent" }
|
||||
permissions = { path = "../../platform/permissions" }
|
||||
llm-core = { path = "../../llm/core" }
|
||||
llm-ollama = { path = "../../llm/ollama" }
|
||||
config-agent = { path = "../../platform/config" }
|
||||
tools-todo = { path = "../../tools/todo" }
|
||||
1101
crates/app/ui/src/app.rs
Normal file
1101
crates/app/ui/src/app.rs
Normal file
File diff suppressed because it is too large
Load Diff
226
crates/app/ui/src/completions.rs
Normal file
226
crates/app/ui/src/completions.rs
Normal file
@@ -0,0 +1,226 @@
|
||||
//! Command completion engine for the TUI
|
||||
//!
|
||||
//! Provides Tab-completion for slash commands, file paths, and tool names.
|
||||
|
||||
use std::path::Path;
|
||||
|
||||
/// A single completion suggestion
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct Completion {
|
||||
/// The text to insert
|
||||
pub text: String,
|
||||
/// Description of what this completion does
|
||||
pub description: String,
|
||||
/// Source of the completion (e.g., "builtin", "plugin:name")
|
||||
pub source: String,
|
||||
}
|
||||
|
||||
/// Information about a command for completion purposes
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct CommandInfo {
|
||||
/// Command name (without leading /)
|
||||
pub name: String,
|
||||
/// Command description
|
||||
pub description: String,
|
||||
/// Source of the command
|
||||
pub source: String,
|
||||
}
|
||||
|
||||
impl CommandInfo {
|
||||
pub fn new(name: &str, description: &str, source: &str) -> Self {
|
||||
Self {
|
||||
name: name.to_string(),
|
||||
description: description.to_string(),
|
||||
source: source.to_string(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Completion engine for the TUI
|
||||
pub struct CompletionEngine {
|
||||
/// Available commands
|
||||
commands: Vec<CommandInfo>,
|
||||
}
|
||||
|
||||
impl Default for CompletionEngine {
|
||||
fn default() -> Self {
|
||||
Self::new()
|
||||
}
|
||||
}
|
||||
|
||||
impl CompletionEngine {
|
||||
pub fn new() -> Self {
|
||||
Self {
|
||||
commands: Self::builtin_commands(),
|
||||
}
|
||||
}
|
||||
|
||||
/// Get built-in commands
|
||||
fn builtin_commands() -> Vec<CommandInfo> {
|
||||
vec![
|
||||
CommandInfo::new("help", "Show available commands and help", "builtin"),
|
||||
CommandInfo::new("clear", "Clear the screen", "builtin"),
|
||||
CommandInfo::new("mcp", "List MCP servers and their tools", "builtin"),
|
||||
CommandInfo::new("hooks", "Show loaded hooks", "builtin"),
|
||||
CommandInfo::new("compact", "Compact conversation context", "builtin"),
|
||||
CommandInfo::new("mode", "Switch permission mode (plan/edit/code)", "builtin"),
|
||||
CommandInfo::new("provider", "Switch LLM provider", "builtin"),
|
||||
CommandInfo::new("model", "Switch LLM model", "builtin"),
|
||||
CommandInfo::new("checkpoint", "Create a checkpoint", "builtin"),
|
||||
CommandInfo::new("rewind", "Rewind to a checkpoint", "builtin"),
|
||||
]
|
||||
}
|
||||
|
||||
/// Add commands from plugins
|
||||
pub fn add_plugin_commands(&mut self, plugin_name: &str, commands: Vec<CommandInfo>) {
|
||||
for mut cmd in commands {
|
||||
cmd.source = format!("plugin:{}", plugin_name);
|
||||
self.commands.push(cmd);
|
||||
}
|
||||
}
|
||||
|
||||
/// Add a single command
|
||||
pub fn add_command(&mut self, command: CommandInfo) {
|
||||
self.commands.push(command);
|
||||
}
|
||||
|
||||
/// Get completions for the given input
|
||||
pub fn complete(&self, input: &str) -> Vec<Completion> {
|
||||
if input.starts_with('/') {
|
||||
self.complete_command(&input[1..])
|
||||
} else if input.starts_with('@') {
|
||||
self.complete_file_path(&input[1..])
|
||||
} else {
|
||||
vec![]
|
||||
}
|
||||
}
|
||||
|
||||
/// Complete a slash command
|
||||
fn complete_command(&self, partial: &str) -> Vec<Completion> {
|
||||
let partial_lower = partial.to_lowercase();
|
||||
|
||||
self.commands
|
||||
.iter()
|
||||
.filter(|cmd| {
|
||||
// Match if name starts with partial, or contains partial (fuzzy)
|
||||
cmd.name.to_lowercase().starts_with(&partial_lower)
|
||||
|| (partial.len() >= 2 && cmd.name.to_lowercase().contains(&partial_lower))
|
||||
})
|
||||
.map(|cmd| Completion {
|
||||
text: format!("/{}", cmd.name),
|
||||
description: cmd.description.clone(),
|
||||
source: cmd.source.clone(),
|
||||
})
|
||||
.collect()
|
||||
}
|
||||
|
||||
/// Complete a file path
|
||||
fn complete_file_path(&self, partial: &str) -> Vec<Completion> {
|
||||
let path = Path::new(partial);
|
||||
|
||||
// Get the directory to search and the prefix to match
|
||||
let (dir, prefix) = if partial.ends_with('/') || partial.is_empty() {
|
||||
(partial, "")
|
||||
} else {
|
||||
let parent = path.parent().map(|p| p.to_str().unwrap_or("")).unwrap_or("");
|
||||
let file_name = path.file_name().and_then(|f| f.to_str()).unwrap_or("");
|
||||
(parent, file_name)
|
||||
};
|
||||
|
||||
// Search directory
|
||||
let search_dir = if dir.is_empty() { "." } else { dir };
|
||||
|
||||
match std::fs::read_dir(search_dir) {
|
||||
Ok(entries) => {
|
||||
entries
|
||||
.filter_map(|entry| entry.ok())
|
||||
.filter(|entry| {
|
||||
let name = entry.file_name();
|
||||
let name_str = name.to_string_lossy();
|
||||
// Skip hidden files unless user started typing with .
|
||||
if !prefix.starts_with('.') && name_str.starts_with('.') {
|
||||
return false;
|
||||
}
|
||||
name_str.to_lowercase().starts_with(&prefix.to_lowercase())
|
||||
})
|
||||
.map(|entry| {
|
||||
let name = entry.file_name();
|
||||
let name_str = name.to_string_lossy();
|
||||
let is_dir = entry.file_type().map(|t| t.is_dir()).unwrap_or(false);
|
||||
|
||||
let full_path = if dir.is_empty() {
|
||||
name_str.to_string()
|
||||
} else if dir.ends_with('/') {
|
||||
format!("{}{}", dir, name_str)
|
||||
} else {
|
||||
format!("{}/{}", dir, name_str)
|
||||
};
|
||||
|
||||
Completion {
|
||||
text: format!("@{}{}", full_path, if is_dir { "/" } else { "" }),
|
||||
description: if is_dir { "Directory".to_string() } else { "File".to_string() },
|
||||
source: "filesystem".to_string(),
|
||||
}
|
||||
})
|
||||
.collect()
|
||||
}
|
||||
Err(_) => vec![],
|
||||
}
|
||||
}
|
||||
|
||||
/// Get all commands (for /help display)
|
||||
pub fn all_commands(&self) -> &[CommandInfo] {
|
||||
&self.commands
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn test_command_completion_exact() {
|
||||
let engine = CompletionEngine::new();
|
||||
let completions = engine.complete("/help");
|
||||
assert!(!completions.is_empty());
|
||||
assert!(completions.iter().any(|c| c.text == "/help"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_command_completion_partial() {
|
||||
let engine = CompletionEngine::new();
|
||||
let completions = engine.complete("/hel");
|
||||
assert!(!completions.is_empty());
|
||||
assert!(completions.iter().any(|c| c.text == "/help"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_command_completion_fuzzy() {
|
||||
let engine = CompletionEngine::new();
|
||||
// "cle" should match "clear"
|
||||
let completions = engine.complete("/cle");
|
||||
assert!(!completions.is_empty());
|
||||
assert!(completions.iter().any(|c| c.text == "/clear"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_command_info() {
|
||||
let info = CommandInfo::new("test", "A test command", "builtin");
|
||||
assert_eq!(info.name, "test");
|
||||
assert_eq!(info.description, "A test command");
|
||||
assert_eq!(info.source, "builtin");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_add_plugin_commands() {
|
||||
let mut engine = CompletionEngine::new();
|
||||
let plugin_cmds = vec![
|
||||
CommandInfo::new("custom", "A custom command", ""),
|
||||
];
|
||||
engine.add_plugin_commands("my-plugin", plugin_cmds);
|
||||
|
||||
let completions = engine.complete("/custom");
|
||||
assert!(!completions.is_empty());
|
||||
assert!(completions.iter().any(|c| c.source == "plugin:my-plugin"));
|
||||
}
|
||||
}
|
||||
377
crates/app/ui/src/components/autocomplete.rs
Normal file
377
crates/app/ui/src/components/autocomplete.rs
Normal file
@@ -0,0 +1,377 @@
|
||||
//! Command autocomplete dropdown component
|
||||
//!
|
||||
//! Displays inline autocomplete suggestions when user types `/`.
|
||||
//! Supports fuzzy filtering as user types.
|
||||
|
||||
use crate::theme::Theme;
|
||||
use crossterm::event::{KeyCode, KeyEvent};
|
||||
use ratatui::{
|
||||
layout::Rect,
|
||||
style::Style,
|
||||
text::{Line, Span},
|
||||
widgets::{Block, Borders, Clear, Paragraph},
|
||||
Frame,
|
||||
};
|
||||
|
||||
/// An autocomplete option
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct AutocompleteOption {
|
||||
/// The trigger text (command name without /)
|
||||
pub trigger: String,
|
||||
/// Display text (e.g., "/model [name]")
|
||||
pub display: String,
|
||||
/// Short description
|
||||
pub description: String,
|
||||
/// Has submenu/subcommands
|
||||
pub has_submenu: bool,
|
||||
}
|
||||
|
||||
impl AutocompleteOption {
|
||||
pub fn new(trigger: &str, description: &str) -> Self {
|
||||
Self {
|
||||
trigger: trigger.to_string(),
|
||||
display: format!("/{}", trigger),
|
||||
description: description.to_string(),
|
||||
has_submenu: false,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn with_args(trigger: &str, args: &str, description: &str) -> Self {
|
||||
Self {
|
||||
trigger: trigger.to_string(),
|
||||
display: format!("/{} {}", trigger, args),
|
||||
description: description.to_string(),
|
||||
has_submenu: false,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn with_submenu(trigger: &str, description: &str) -> Self {
|
||||
Self {
|
||||
trigger: trigger.to_string(),
|
||||
display: format!("/{}", trigger),
|
||||
description: description.to_string(),
|
||||
has_submenu: true,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Default command options
|
||||
fn default_options() -> Vec<AutocompleteOption> {
|
||||
vec![
|
||||
AutocompleteOption::new("help", "Show help"),
|
||||
AutocompleteOption::new("status", "Session info"),
|
||||
AutocompleteOption::with_args("model", "[name]", "Switch model"),
|
||||
AutocompleteOption::with_args("provider", "[name]", "Switch provider"),
|
||||
AutocompleteOption::new("history", "View history"),
|
||||
AutocompleteOption::new("checkpoint", "Save state"),
|
||||
AutocompleteOption::new("checkpoints", "List checkpoints"),
|
||||
AutocompleteOption::with_args("rewind", "[id]", "Restore"),
|
||||
AutocompleteOption::new("cost", "Token usage"),
|
||||
AutocompleteOption::new("clear", "Clear chat"),
|
||||
AutocompleteOption::new("compact", "Compact context"),
|
||||
AutocompleteOption::new("permissions", "Permission mode"),
|
||||
AutocompleteOption::new("themes", "List themes"),
|
||||
AutocompleteOption::with_args("theme", "[name]", "Switch theme"),
|
||||
AutocompleteOption::new("exit", "Exit"),
|
||||
]
|
||||
}
|
||||
|
||||
/// Autocomplete dropdown component
|
||||
pub struct Autocomplete {
|
||||
options: Vec<AutocompleteOption>,
|
||||
filtered: Vec<usize>, // indices into options
|
||||
selected: usize,
|
||||
visible: bool,
|
||||
theme: Theme,
|
||||
}
|
||||
|
||||
impl Autocomplete {
|
||||
pub fn new(theme: Theme) -> Self {
|
||||
let options = default_options();
|
||||
let filtered: Vec<usize> = (0..options.len()).collect();
|
||||
|
||||
Self {
|
||||
options,
|
||||
filtered,
|
||||
selected: 0,
|
||||
visible: false,
|
||||
theme,
|
||||
}
|
||||
}
|
||||
|
||||
/// Show autocomplete and reset filter
|
||||
pub fn show(&mut self) {
|
||||
self.visible = true;
|
||||
self.filtered = (0..self.options.len()).collect();
|
||||
self.selected = 0;
|
||||
}
|
||||
|
||||
/// Hide autocomplete
|
||||
pub fn hide(&mut self) {
|
||||
self.visible = false;
|
||||
}
|
||||
|
||||
/// Check if visible
|
||||
pub fn is_visible(&self) -> bool {
|
||||
self.visible
|
||||
}
|
||||
|
||||
/// Update filter based on current input (text after /)
|
||||
pub fn update_filter(&mut self, query: &str) {
|
||||
if query.is_empty() {
|
||||
self.filtered = (0..self.options.len()).collect();
|
||||
} else {
|
||||
let query_lower = query.to_lowercase();
|
||||
self.filtered = self.options
|
||||
.iter()
|
||||
.enumerate()
|
||||
.filter(|(_, opt)| {
|
||||
// Fuzzy match: check if query chars appear in order
|
||||
fuzzy_match(&opt.trigger.to_lowercase(), &query_lower)
|
||||
})
|
||||
.map(|(i, _)| i)
|
||||
.collect();
|
||||
}
|
||||
|
||||
// Reset selection if it's out of bounds
|
||||
if self.selected >= self.filtered.len() {
|
||||
self.selected = 0;
|
||||
}
|
||||
}
|
||||
|
||||
/// Select next option
|
||||
pub fn select_next(&mut self) {
|
||||
if !self.filtered.is_empty() {
|
||||
self.selected = (self.selected + 1) % self.filtered.len();
|
||||
}
|
||||
}
|
||||
|
||||
/// Select previous option
|
||||
pub fn select_prev(&mut self) {
|
||||
if !self.filtered.is_empty() {
|
||||
self.selected = if self.selected == 0 {
|
||||
self.filtered.len() - 1
|
||||
} else {
|
||||
self.selected - 1
|
||||
};
|
||||
}
|
||||
}
|
||||
|
||||
/// Get the currently selected option's trigger
|
||||
pub fn confirm(&self) -> Option<String> {
|
||||
if self.filtered.is_empty() {
|
||||
return None;
|
||||
}
|
||||
|
||||
let idx = self.filtered[self.selected];
|
||||
Some(format!("/{}", self.options[idx].trigger))
|
||||
}
|
||||
|
||||
/// Handle key input, returns Some(command) if confirmed
|
||||
///
|
||||
/// Key behavior:
|
||||
/// - Tab: Confirm selection and insert into input
|
||||
/// - Down/Up: Navigate options
|
||||
/// - Enter: Pass through to submit (NotHandled)
|
||||
/// - Esc: Cancel autocomplete
|
||||
pub fn handle_key(&mut self, key: KeyEvent) -> AutocompleteResult {
|
||||
if !self.visible {
|
||||
return AutocompleteResult::NotHandled;
|
||||
}
|
||||
|
||||
match key.code {
|
||||
KeyCode::Tab => {
|
||||
// Tab confirms and inserts the selected command
|
||||
if let Some(cmd) = self.confirm() {
|
||||
self.hide();
|
||||
AutocompleteResult::Confirmed(cmd)
|
||||
} else {
|
||||
AutocompleteResult::Handled
|
||||
}
|
||||
}
|
||||
KeyCode::Down => {
|
||||
self.select_next();
|
||||
AutocompleteResult::Handled
|
||||
}
|
||||
KeyCode::BackTab | KeyCode::Up => {
|
||||
self.select_prev();
|
||||
AutocompleteResult::Handled
|
||||
}
|
||||
KeyCode::Enter => {
|
||||
// Enter should submit the message, not confirm autocomplete
|
||||
// Hide autocomplete and let Enter pass through
|
||||
self.hide();
|
||||
AutocompleteResult::NotHandled
|
||||
}
|
||||
KeyCode::Esc => {
|
||||
self.hide();
|
||||
AutocompleteResult::Cancelled
|
||||
}
|
||||
_ => AutocompleteResult::NotHandled,
|
||||
}
|
||||
}
|
||||
|
||||
/// Update theme
|
||||
pub fn set_theme(&mut self, theme: Theme) {
|
||||
self.theme = theme;
|
||||
}
|
||||
|
||||
/// Add custom options (from plugins)
|
||||
pub fn add_options(&mut self, options: Vec<AutocompleteOption>) {
|
||||
self.options.extend(options);
|
||||
// Re-filter with all options
|
||||
self.filtered = (0..self.options.len()).collect();
|
||||
}
|
||||
|
||||
/// Render the autocomplete dropdown above the input line
|
||||
pub fn render(&self, frame: &mut Frame, input_area: Rect) {
|
||||
if !self.visible || self.filtered.is_empty() {
|
||||
return;
|
||||
}
|
||||
|
||||
// Calculate dropdown dimensions
|
||||
let max_visible = 8.min(self.filtered.len());
|
||||
let width = 40.min(input_area.width.saturating_sub(4));
|
||||
let height = (max_visible + 2) as u16; // +2 for borders
|
||||
|
||||
// Position above input, left-aligned with some padding
|
||||
let x = input_area.x + 2;
|
||||
let y = input_area.y.saturating_sub(height);
|
||||
|
||||
let dropdown_area = Rect::new(x, y, width, height);
|
||||
|
||||
// Clear area behind dropdown
|
||||
frame.render_widget(Clear, dropdown_area);
|
||||
|
||||
// Build option lines
|
||||
let mut lines: Vec<Line> = Vec::new();
|
||||
|
||||
for (display_idx, &opt_idx) in self.filtered.iter().take(max_visible).enumerate() {
|
||||
let opt = &self.options[opt_idx];
|
||||
let is_selected = display_idx == self.selected;
|
||||
|
||||
let style = if is_selected {
|
||||
self.theme.selected
|
||||
} else {
|
||||
Style::default()
|
||||
};
|
||||
|
||||
let mut spans = vec![
|
||||
Span::styled(" ", style),
|
||||
Span::styled("/", if is_selected { style } else { self.theme.cmd_slash }),
|
||||
Span::styled(&opt.trigger, if is_selected { style } else { self.theme.cmd_name }),
|
||||
];
|
||||
|
||||
// Submenu indicator
|
||||
if opt.has_submenu {
|
||||
spans.push(Span::styled(" >", if is_selected { style } else { self.theme.cmd_desc }));
|
||||
}
|
||||
|
||||
// Pad to fixed width for consistent selection highlighting
|
||||
let current_len: usize = spans.iter().map(|s| s.content.len()).sum();
|
||||
let padding = (width as usize).saturating_sub(current_len + 1);
|
||||
spans.push(Span::styled(" ".repeat(padding), style));
|
||||
|
||||
lines.push(Line::from(spans));
|
||||
}
|
||||
|
||||
// Show overflow indicator if needed
|
||||
if self.filtered.len() > max_visible {
|
||||
lines.push(Line::from(Span::styled(
|
||||
format!(" ... +{} more", self.filtered.len() - max_visible),
|
||||
self.theme.cmd_desc,
|
||||
)));
|
||||
}
|
||||
|
||||
let block = Block::default()
|
||||
.borders(Borders::ALL)
|
||||
.border_style(Style::default().fg(self.theme.palette.border))
|
||||
.style(self.theme.overlay_bg);
|
||||
|
||||
let paragraph = Paragraph::new(lines).block(block);
|
||||
|
||||
frame.render_widget(paragraph, dropdown_area);
|
||||
}
|
||||
}
|
||||
|
||||
/// Result of handling autocomplete key
|
||||
#[derive(Debug, Clone, PartialEq, Eq)]
|
||||
pub enum AutocompleteResult {
|
||||
/// Key was not handled by autocomplete
|
||||
NotHandled,
|
||||
/// Key was handled, no action needed
|
||||
Handled,
|
||||
/// User confirmed selection, returns command string
|
||||
Confirmed(String),
|
||||
/// User cancelled autocomplete
|
||||
Cancelled,
|
||||
}
|
||||
|
||||
/// Simple fuzzy match: check if query chars appear in order in text
|
||||
fn fuzzy_match(text: &str, query: &str) -> bool {
|
||||
let mut text_chars = text.chars().peekable();
|
||||
|
||||
for query_char in query.chars() {
|
||||
loop {
|
||||
match text_chars.next() {
|
||||
Some(c) if c == query_char => break,
|
||||
Some(_) => continue,
|
||||
None => return false,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
true
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn test_fuzzy_match() {
|
||||
assert!(fuzzy_match("help", "h"));
|
||||
assert!(fuzzy_match("help", "he"));
|
||||
assert!(fuzzy_match("help", "hel"));
|
||||
assert!(fuzzy_match("help", "help"));
|
||||
assert!(fuzzy_match("help", "hp")); // fuzzy: h...p
|
||||
assert!(!fuzzy_match("help", "x"));
|
||||
assert!(!fuzzy_match("help", "helping")); // query longer than text
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_autocomplete_filter() {
|
||||
let theme = Theme::default();
|
||||
let mut ac = Autocomplete::new(theme);
|
||||
|
||||
ac.update_filter("he");
|
||||
assert!(ac.filtered.len() < ac.options.len());
|
||||
|
||||
// Should match "help"
|
||||
assert!(ac.filtered.iter().any(|&i| ac.options[i].trigger == "help"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_autocomplete_navigation() {
|
||||
let theme = Theme::default();
|
||||
let mut ac = Autocomplete::new(theme);
|
||||
ac.show();
|
||||
|
||||
assert_eq!(ac.selected, 0);
|
||||
ac.select_next();
|
||||
assert_eq!(ac.selected, 1);
|
||||
ac.select_prev();
|
||||
assert_eq!(ac.selected, 0);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_autocomplete_confirm() {
|
||||
let theme = Theme::default();
|
||||
let mut ac = Autocomplete::new(theme);
|
||||
ac.show();
|
||||
|
||||
let cmd = ac.confirm();
|
||||
assert!(cmd.is_some());
|
||||
assert!(cmd.unwrap().starts_with("/"));
|
||||
}
|
||||
}
|
||||
468
crates/app/ui/src/components/chat_panel.rs
Normal file
468
crates/app/ui/src/components/chat_panel.rs
Normal file
@@ -0,0 +1,468 @@
|
||||
//! Borderless chat panel component
|
||||
//!
|
||||
//! Displays chat messages with proper indentation, timestamps,
|
||||
//! and streaming indicators. Uses whitespace instead of borders.
|
||||
|
||||
use crate::theme::Theme;
|
||||
use ratatui::{
|
||||
layout::Rect,
|
||||
style::{Modifier, Style},
|
||||
text::{Line, Span, Text},
|
||||
widgets::{Paragraph, Scrollbar, ScrollbarOrientation, ScrollbarState},
|
||||
Frame,
|
||||
};
|
||||
use std::time::SystemTime;
|
||||
|
||||
/// Chat message types
|
||||
#[derive(Debug, Clone)]
|
||||
pub enum ChatMessage {
|
||||
User(String),
|
||||
Assistant(String),
|
||||
ToolCall { name: String, args: String },
|
||||
ToolResult { success: bool, output: String },
|
||||
System(String),
|
||||
}
|
||||
|
||||
impl ChatMessage {
|
||||
/// Get a timestamp for when the message was created (for display)
|
||||
pub fn timestamp_display() -> String {
|
||||
let now = SystemTime::now();
|
||||
let secs = now
|
||||
.duration_since(SystemTime::UNIX_EPOCH)
|
||||
.map(|d| d.as_secs())
|
||||
.unwrap_or(0);
|
||||
let hours = (secs / 3600) % 24;
|
||||
let mins = (secs / 60) % 60;
|
||||
format!("{:02}:{:02}", hours, mins)
|
||||
}
|
||||
}
|
||||
|
||||
/// Message with metadata for display
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct DisplayMessage {
|
||||
pub message: ChatMessage,
|
||||
pub timestamp: String,
|
||||
pub focused: bool,
|
||||
}
|
||||
|
||||
impl DisplayMessage {
|
||||
pub fn new(message: ChatMessage) -> Self {
|
||||
Self {
|
||||
message,
|
||||
timestamp: ChatMessage::timestamp_display(),
|
||||
focused: false,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Borderless chat panel
|
||||
pub struct ChatPanel {
|
||||
messages: Vec<DisplayMessage>,
|
||||
scroll_offset: usize,
|
||||
auto_scroll: bool,
|
||||
total_lines: usize,
|
||||
focused_index: Option<usize>,
|
||||
is_streaming: bool,
|
||||
theme: Theme,
|
||||
}
|
||||
|
||||
impl ChatPanel {
|
||||
/// Create new borderless chat panel
|
||||
pub fn new(theme: Theme) -> Self {
|
||||
Self {
|
||||
messages: Vec::new(),
|
||||
scroll_offset: 0,
|
||||
auto_scroll: true,
|
||||
total_lines: 0,
|
||||
focused_index: None,
|
||||
is_streaming: false,
|
||||
theme,
|
||||
}
|
||||
}
|
||||
|
||||
/// Add a new message
|
||||
pub fn add_message(&mut self, message: ChatMessage) {
|
||||
self.messages.push(DisplayMessage::new(message));
|
||||
self.auto_scroll = true;
|
||||
self.is_streaming = false;
|
||||
}
|
||||
|
||||
/// Append content to the last assistant message, or create a new one
|
||||
pub fn append_to_assistant(&mut self, content: &str) {
|
||||
if let Some(DisplayMessage {
|
||||
message: ChatMessage::Assistant(last_content),
|
||||
..
|
||||
}) = self.messages.last_mut()
|
||||
{
|
||||
last_content.push_str(content);
|
||||
} else {
|
||||
self.messages.push(DisplayMessage::new(ChatMessage::Assistant(
|
||||
content.to_string(),
|
||||
)));
|
||||
}
|
||||
self.auto_scroll = true;
|
||||
self.is_streaming = true;
|
||||
}
|
||||
|
||||
/// Set streaming state
|
||||
pub fn set_streaming(&mut self, streaming: bool) {
|
||||
self.is_streaming = streaming;
|
||||
}
|
||||
|
||||
/// Scroll up
|
||||
pub fn scroll_up(&mut self, amount: usize) {
|
||||
self.scroll_offset = self.scroll_offset.saturating_sub(amount);
|
||||
self.auto_scroll = false;
|
||||
}
|
||||
|
||||
/// Scroll down
|
||||
pub fn scroll_down(&mut self, amount: usize) {
|
||||
self.scroll_offset = self.scroll_offset.saturating_add(amount);
|
||||
let near_bottom_threshold = 5;
|
||||
if self.total_lines > 0 {
|
||||
let max_scroll = self.total_lines.saturating_sub(1);
|
||||
if self.scroll_offset.saturating_add(near_bottom_threshold) >= max_scroll {
|
||||
self.auto_scroll = true;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Scroll to bottom
|
||||
pub fn scroll_to_bottom(&mut self) {
|
||||
self.scroll_offset = self.total_lines.saturating_sub(1);
|
||||
self.auto_scroll = true;
|
||||
}
|
||||
|
||||
/// Page up
|
||||
pub fn page_up(&mut self, page_size: usize) {
|
||||
self.scroll_up(page_size.saturating_sub(2));
|
||||
}
|
||||
|
||||
/// Page down
|
||||
pub fn page_down(&mut self, page_size: usize) {
|
||||
self.scroll_down(page_size.saturating_sub(2));
|
||||
}
|
||||
|
||||
/// Focus next message
|
||||
pub fn focus_next(&mut self) {
|
||||
if self.messages.is_empty() {
|
||||
return;
|
||||
}
|
||||
self.focused_index = Some(match self.focused_index {
|
||||
Some(i) if i + 1 < self.messages.len() => i + 1,
|
||||
Some(_) => 0,
|
||||
None => 0,
|
||||
});
|
||||
}
|
||||
|
||||
/// Focus previous message
|
||||
pub fn focus_previous(&mut self) {
|
||||
if self.messages.is_empty() {
|
||||
return;
|
||||
}
|
||||
self.focused_index = Some(match self.focused_index {
|
||||
Some(0) => self.messages.len() - 1,
|
||||
Some(i) => i - 1,
|
||||
None => self.messages.len() - 1,
|
||||
});
|
||||
}
|
||||
|
||||
/// Clear focus
|
||||
pub fn clear_focus(&mut self) {
|
||||
self.focused_index = None;
|
||||
}
|
||||
|
||||
/// Get focused message index
|
||||
pub fn focused_index(&self) -> Option<usize> {
|
||||
self.focused_index
|
||||
}
|
||||
|
||||
/// Get focused message
|
||||
pub fn focused_message(&self) -> Option<&ChatMessage> {
|
||||
self.focused_index
|
||||
.and_then(|i| self.messages.get(i))
|
||||
.map(|m| &m.message)
|
||||
}
|
||||
|
||||
/// Update scroll position before rendering
|
||||
pub fn update_scroll(&mut self, area: Rect) {
|
||||
self.total_lines = self.count_total_lines(area);
|
||||
|
||||
if self.auto_scroll {
|
||||
let visible_height = area.height as usize;
|
||||
let max_scroll = self.total_lines.saturating_sub(visible_height);
|
||||
self.scroll_offset = max_scroll;
|
||||
} else {
|
||||
let visible_height = area.height as usize;
|
||||
let max_scroll = self.total_lines.saturating_sub(visible_height);
|
||||
self.scroll_offset = self.scroll_offset.min(max_scroll);
|
||||
}
|
||||
}
|
||||
|
||||
/// Count total lines for scroll calculation
|
||||
fn count_total_lines(&self, area: Rect) -> usize {
|
||||
let mut line_count = 0;
|
||||
let wrap_width = area.width.saturating_sub(4) as usize;
|
||||
|
||||
for msg in &self.messages {
|
||||
line_count += match &msg.message {
|
||||
ChatMessage::User(content) => {
|
||||
let wrapped = textwrap::wrap(content, wrap_width);
|
||||
wrapped.len() + 1 // +1 for spacing
|
||||
}
|
||||
ChatMessage::Assistant(content) => {
|
||||
let wrapped = textwrap::wrap(content, wrap_width);
|
||||
wrapped.len() + 1
|
||||
}
|
||||
ChatMessage::ToolCall { .. } => 2,
|
||||
ChatMessage::ToolResult { .. } => 2,
|
||||
ChatMessage::System(_) => 1,
|
||||
};
|
||||
}
|
||||
|
||||
line_count
|
||||
}
|
||||
|
||||
/// Render the borderless chat panel
|
||||
///
|
||||
/// Message display format (no symbols, clean typography):
|
||||
/// - Role: bold, appropriate color
|
||||
/// - Timestamp: dim, same line as role
|
||||
/// - Content: 2-space indent, normal weight
|
||||
/// - Blank line between messages
|
||||
pub fn render(&self, frame: &mut Frame, area: Rect) {
|
||||
let mut text_lines = Vec::new();
|
||||
let wrap_width = area.width.saturating_sub(4) as usize;
|
||||
|
||||
for (idx, display_msg) in self.messages.iter().enumerate() {
|
||||
let is_focused = self.focused_index == Some(idx);
|
||||
let is_last = idx == self.messages.len() - 1;
|
||||
|
||||
match &display_msg.message {
|
||||
ChatMessage::User(content) => {
|
||||
// Role line: "You" bold + timestamp dim
|
||||
text_lines.push(Line::from(vec![
|
||||
Span::styled(" ", Style::default()),
|
||||
Span::styled("You", self.theme.user_message),
|
||||
Span::styled(
|
||||
format!(" {}", display_msg.timestamp),
|
||||
self.theme.timestamp,
|
||||
),
|
||||
]));
|
||||
|
||||
// Message content with 2-space indent
|
||||
let wrapped = textwrap::wrap(content, wrap_width);
|
||||
for line in wrapped {
|
||||
let style = if is_focused {
|
||||
self.theme.user_message.add_modifier(Modifier::REVERSED)
|
||||
} else {
|
||||
self.theme.user_message.remove_modifier(Modifier::BOLD)
|
||||
};
|
||||
text_lines.push(Line::from(Span::styled(
|
||||
format!(" {}", line),
|
||||
style,
|
||||
)));
|
||||
}
|
||||
|
||||
// Focus hints
|
||||
if is_focused {
|
||||
text_lines.push(Line::from(Span::styled(
|
||||
" [y]copy [e]edit [r]retry",
|
||||
self.theme.status_dim,
|
||||
)));
|
||||
}
|
||||
|
||||
text_lines.push(Line::from(""));
|
||||
}
|
||||
|
||||
ChatMessage::Assistant(content) => {
|
||||
// Role line: streaming indicator (if active) + "Assistant" bold + timestamp
|
||||
let mut role_spans = vec![Span::styled(" ", Style::default())];
|
||||
|
||||
// Streaming indicator (subtle, no symbol)
|
||||
if is_last && self.is_streaming {
|
||||
role_spans.push(Span::styled(
|
||||
"... ",
|
||||
Style::default().fg(self.theme.palette.success),
|
||||
));
|
||||
}
|
||||
|
||||
role_spans.push(Span::styled(
|
||||
"Assistant",
|
||||
self.theme.assistant_message.add_modifier(Modifier::BOLD),
|
||||
));
|
||||
|
||||
role_spans.push(Span::styled(
|
||||
format!(" {}", display_msg.timestamp),
|
||||
self.theme.timestamp,
|
||||
));
|
||||
|
||||
text_lines.push(Line::from(role_spans));
|
||||
|
||||
// Content
|
||||
let wrapped = textwrap::wrap(content, wrap_width);
|
||||
for line in wrapped {
|
||||
let style = if is_focused {
|
||||
self.theme.assistant_message.add_modifier(Modifier::REVERSED)
|
||||
} else {
|
||||
self.theme.assistant_message
|
||||
};
|
||||
text_lines.push(Line::from(Span::styled(
|
||||
format!(" {}", line),
|
||||
style,
|
||||
)));
|
||||
}
|
||||
|
||||
// Focus hints
|
||||
if is_focused {
|
||||
text_lines.push(Line::from(Span::styled(
|
||||
" [y]copy [r]retry",
|
||||
self.theme.status_dim,
|
||||
)));
|
||||
}
|
||||
|
||||
text_lines.push(Line::from(""));
|
||||
}
|
||||
|
||||
ChatMessage::ToolCall { name, args } => {
|
||||
// Tool calls: name in tool color, args dimmed
|
||||
text_lines.push(Line::from(vec![
|
||||
Span::styled(" ", Style::default()),
|
||||
Span::styled(format!("{} ", name), self.theme.tool_call),
|
||||
Span::styled(
|
||||
truncate_str(args, 60),
|
||||
self.theme.tool_call.add_modifier(Modifier::DIM),
|
||||
),
|
||||
]));
|
||||
text_lines.push(Line::from(""));
|
||||
}
|
||||
|
||||
ChatMessage::ToolResult { success, output } => {
|
||||
// Tool results: status prefix + output
|
||||
let (prefix, style) = if *success {
|
||||
("ok ", self.theme.tool_result_success)
|
||||
} else {
|
||||
("err ", self.theme.tool_result_error)
|
||||
};
|
||||
|
||||
text_lines.push(Line::from(vec![
|
||||
Span::styled(" ", Style::default()),
|
||||
Span::styled(prefix, style),
|
||||
Span::styled(
|
||||
truncate_str(output, 100),
|
||||
style.remove_modifier(Modifier::BOLD),
|
||||
),
|
||||
]));
|
||||
text_lines.push(Line::from(""));
|
||||
}
|
||||
|
||||
ChatMessage::System(content) => {
|
||||
// System messages: just dim text, no prefix
|
||||
text_lines.push(Line::from(vec![
|
||||
Span::styled(" ", Style::default()),
|
||||
Span::styled(content.to_string(), self.theme.system_message),
|
||||
]));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
let text = Text::from(text_lines);
|
||||
let paragraph = Paragraph::new(text).scroll((self.scroll_offset as u16, 0));
|
||||
|
||||
frame.render_widget(paragraph, area);
|
||||
|
||||
// Render scrollbar if needed
|
||||
if self.total_lines > area.height as usize {
|
||||
let scrollbar = Scrollbar::default()
|
||||
.orientation(ScrollbarOrientation::VerticalRight)
|
||||
.begin_symbol(None)
|
||||
.end_symbol(None)
|
||||
.track_symbol(Some(" "))
|
||||
.thumb_symbol("│")
|
||||
.style(self.theme.status_dim);
|
||||
|
||||
let mut scrollbar_state = ScrollbarState::default()
|
||||
.content_length(self.total_lines)
|
||||
.position(self.scroll_offset);
|
||||
|
||||
frame.render_stateful_widget(scrollbar, area, &mut scrollbar_state);
|
||||
}
|
||||
}
|
||||
|
||||
/// Get messages
|
||||
pub fn messages(&self) -> &[DisplayMessage] {
|
||||
&self.messages
|
||||
}
|
||||
|
||||
/// Clear all messages
|
||||
pub fn clear(&mut self) {
|
||||
self.messages.clear();
|
||||
self.scroll_offset = 0;
|
||||
self.focused_index = None;
|
||||
}
|
||||
|
||||
/// Update theme
|
||||
pub fn set_theme(&mut self, theme: Theme) {
|
||||
self.theme = theme;
|
||||
}
|
||||
}
|
||||
|
||||
/// Truncate a string to max length with ellipsis
|
||||
fn truncate_str(s: &str, max_len: usize) -> String {
|
||||
if s.len() <= max_len {
|
||||
s.to_string()
|
||||
} else {
|
||||
format!("{}...", &s[..max_len.saturating_sub(3)])
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn test_chat_panel_add_message() {
|
||||
let theme = Theme::default();
|
||||
let mut panel = ChatPanel::new(theme);
|
||||
|
||||
panel.add_message(ChatMessage::User("Hello".to_string()));
|
||||
panel.add_message(ChatMessage::Assistant("Hi there!".to_string()));
|
||||
|
||||
assert_eq!(panel.messages().len(), 2);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_append_to_assistant() {
|
||||
let theme = Theme::default();
|
||||
let mut panel = ChatPanel::new(theme);
|
||||
|
||||
panel.append_to_assistant("Hello");
|
||||
panel.append_to_assistant(" world");
|
||||
|
||||
assert_eq!(panel.messages().len(), 1);
|
||||
if let ChatMessage::Assistant(content) = &panel.messages()[0].message {
|
||||
assert_eq!(content, "Hello world");
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_focus_navigation() {
|
||||
let theme = Theme::default();
|
||||
let mut panel = ChatPanel::new(theme);
|
||||
|
||||
panel.add_message(ChatMessage::User("1".to_string()));
|
||||
panel.add_message(ChatMessage::User("2".to_string()));
|
||||
panel.add_message(ChatMessage::User("3".to_string()));
|
||||
|
||||
assert_eq!(panel.focused_index(), None);
|
||||
|
||||
panel.focus_next();
|
||||
assert_eq!(panel.focused_index(), Some(0));
|
||||
|
||||
panel.focus_next();
|
||||
assert_eq!(panel.focused_index(), Some(1));
|
||||
|
||||
panel.focus_previous();
|
||||
assert_eq!(panel.focused_index(), Some(0));
|
||||
}
|
||||
}
|
||||
322
crates/app/ui/src/components/command_help.rs
Normal file
322
crates/app/ui/src/components/command_help.rs
Normal file
@@ -0,0 +1,322 @@
|
||||
//! Command help overlay component
|
||||
//!
|
||||
//! Modal overlay that displays available commands in a structured format.
|
||||
//! Shown when user types `/help` or `?`. Supports scrolling with j/k or arrows.
|
||||
|
||||
use crate::theme::Theme;
|
||||
use crossterm::event::{KeyCode, KeyEvent};
|
||||
use ratatui::{
|
||||
layout::Rect,
|
||||
style::Style,
|
||||
text::{Line, Span},
|
||||
widgets::{Block, Borders, Clear, Paragraph, Scrollbar, ScrollbarOrientation, ScrollbarState},
|
||||
Frame,
|
||||
};
|
||||
|
||||
/// A single command definition
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct Command {
|
||||
pub name: &'static str,
|
||||
pub args: Option<&'static str>,
|
||||
pub description: &'static str,
|
||||
}
|
||||
|
||||
impl Command {
|
||||
pub const fn new(name: &'static str, description: &'static str) -> Self {
|
||||
Self {
|
||||
name,
|
||||
args: None,
|
||||
description,
|
||||
}
|
||||
}
|
||||
|
||||
pub const fn with_args(name: &'static str, args: &'static str, description: &'static str) -> Self {
|
||||
Self {
|
||||
name,
|
||||
args: Some(args),
|
||||
description,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Built-in commands
|
||||
pub fn builtin_commands() -> Vec<Command> {
|
||||
vec![
|
||||
Command::new("help", "Show this help"),
|
||||
Command::new("status", "Current session info"),
|
||||
Command::with_args("model", "[name]", "Switch model"),
|
||||
Command::with_args("provider", "[name]", "Switch provider (ollama, anthropic, openai)"),
|
||||
Command::new("history", "Browse conversation history"),
|
||||
Command::new("checkpoint", "Save conversation state"),
|
||||
Command::new("checkpoints", "List saved checkpoints"),
|
||||
Command::with_args("rewind", "[id]", "Restore checkpoint"),
|
||||
Command::new("cost", "Show token usage"),
|
||||
Command::new("clear", "Clear conversation"),
|
||||
Command::new("compact", "Compact conversation context"),
|
||||
Command::new("permissions", "Show permission mode"),
|
||||
Command::new("themes", "List available themes"),
|
||||
Command::with_args("theme", "[name]", "Switch theme"),
|
||||
Command::new("exit", "Exit OWLEN"),
|
||||
]
|
||||
}
|
||||
|
||||
/// Command help overlay
|
||||
pub struct CommandHelp {
|
||||
commands: Vec<Command>,
|
||||
visible: bool,
|
||||
scroll_offset: usize,
|
||||
theme: Theme,
|
||||
}
|
||||
|
||||
impl CommandHelp {
|
||||
pub fn new(theme: Theme) -> Self {
|
||||
Self {
|
||||
commands: builtin_commands(),
|
||||
visible: false,
|
||||
scroll_offset: 0,
|
||||
theme,
|
||||
}
|
||||
}
|
||||
|
||||
/// Show the help overlay
|
||||
pub fn show(&mut self) {
|
||||
self.visible = true;
|
||||
self.scroll_offset = 0; // Reset scroll when showing
|
||||
}
|
||||
|
||||
/// Hide the help overlay
|
||||
pub fn hide(&mut self) {
|
||||
self.visible = false;
|
||||
}
|
||||
|
||||
/// Check if visible
|
||||
pub fn is_visible(&self) -> bool {
|
||||
self.visible
|
||||
}
|
||||
|
||||
/// Toggle visibility
|
||||
pub fn toggle(&mut self) {
|
||||
self.visible = !self.visible;
|
||||
if self.visible {
|
||||
self.scroll_offset = 0;
|
||||
}
|
||||
}
|
||||
|
||||
/// Scroll up by amount
|
||||
fn scroll_up(&mut self, amount: usize) {
|
||||
self.scroll_offset = self.scroll_offset.saturating_sub(amount);
|
||||
}
|
||||
|
||||
/// Scroll down by amount, respecting max
|
||||
fn scroll_down(&mut self, amount: usize, max_scroll: usize) {
|
||||
self.scroll_offset = (self.scroll_offset + amount).min(max_scroll);
|
||||
}
|
||||
|
||||
/// Handle key input, returns true if overlay handled the key
|
||||
pub fn handle_key(&mut self, key: KeyEvent) -> bool {
|
||||
if !self.visible {
|
||||
return false;
|
||||
}
|
||||
|
||||
// Calculate max scroll (commands + padding lines - visible area)
|
||||
let total_lines = self.commands.len() + 3; // +3 for padding and footer
|
||||
let max_scroll = total_lines.saturating_sub(10); // Assume ~10 visible lines
|
||||
|
||||
match key.code {
|
||||
KeyCode::Esc | KeyCode::Char('q') | KeyCode::Char('?') => {
|
||||
self.hide();
|
||||
true
|
||||
}
|
||||
// Scroll navigation
|
||||
KeyCode::Up | KeyCode::Char('k') => {
|
||||
self.scroll_up(1);
|
||||
true
|
||||
}
|
||||
KeyCode::Down | KeyCode::Char('j') => {
|
||||
self.scroll_down(1, max_scroll);
|
||||
true
|
||||
}
|
||||
KeyCode::PageUp | KeyCode::Char('u') => {
|
||||
self.scroll_up(5);
|
||||
true
|
||||
}
|
||||
KeyCode::PageDown | KeyCode::Char('d') => {
|
||||
self.scroll_down(5, max_scroll);
|
||||
true
|
||||
}
|
||||
KeyCode::Home | KeyCode::Char('g') => {
|
||||
self.scroll_offset = 0;
|
||||
true
|
||||
}
|
||||
KeyCode::End | KeyCode::Char('G') => {
|
||||
self.scroll_offset = max_scroll;
|
||||
true
|
||||
}
|
||||
_ => true, // Consume all other keys while visible
|
||||
}
|
||||
}
|
||||
|
||||
/// Update theme
|
||||
pub fn set_theme(&mut self, theme: Theme) {
|
||||
self.theme = theme;
|
||||
}
|
||||
|
||||
/// Add plugin commands
|
||||
pub fn add_commands(&mut self, commands: Vec<Command>) {
|
||||
self.commands.extend(commands);
|
||||
}
|
||||
|
||||
/// Render the help overlay
|
||||
pub fn render(&self, frame: &mut Frame, area: Rect) {
|
||||
if !self.visible {
|
||||
return;
|
||||
}
|
||||
|
||||
// Calculate overlay dimensions
|
||||
let width = (area.width as f32 * 0.7).min(65.0) as u16;
|
||||
let max_height = area.height.saturating_sub(4);
|
||||
let content_height = self.commands.len() as u16 + 4; // +4 for padding and footer
|
||||
let height = content_height.min(max_height).max(8);
|
||||
|
||||
// Center the overlay
|
||||
let x = (area.width.saturating_sub(width)) / 2;
|
||||
let y = (area.height.saturating_sub(height)) / 2;
|
||||
|
||||
let overlay_area = Rect::new(x, y, width, height);
|
||||
|
||||
// Clear the area behind the overlay
|
||||
frame.render_widget(Clear, overlay_area);
|
||||
|
||||
// Build content lines
|
||||
let mut lines: Vec<Line> = Vec::new();
|
||||
|
||||
// Empty line for padding
|
||||
lines.push(Line::from(""));
|
||||
|
||||
// Command list
|
||||
for cmd in &self.commands {
|
||||
let name_with_args = if let Some(args) = cmd.args {
|
||||
format!("/{} {}", cmd.name, args)
|
||||
} else {
|
||||
format!("/{}", cmd.name)
|
||||
};
|
||||
|
||||
// Calculate padding for alignment
|
||||
let name_width: usize = 22;
|
||||
let padding = name_width.saturating_sub(name_with_args.len());
|
||||
|
||||
lines.push(Line::from(vec![
|
||||
Span::styled(" ", Style::default()),
|
||||
Span::styled("/", self.theme.cmd_slash),
|
||||
Span::styled(
|
||||
if let Some(args) = cmd.args {
|
||||
format!("{} {}", cmd.name, args)
|
||||
} else {
|
||||
cmd.name.to_string()
|
||||
},
|
||||
self.theme.cmd_name,
|
||||
),
|
||||
Span::raw(" ".repeat(padding)),
|
||||
Span::styled(cmd.description, self.theme.cmd_desc),
|
||||
]));
|
||||
}
|
||||
|
||||
// Empty line for padding
|
||||
lines.push(Line::from(""));
|
||||
|
||||
// Footer hint with scroll info
|
||||
let scroll_hint = if self.commands.len() > (height as usize - 4) {
|
||||
format!(" (scroll: j/k or ↑/↓)")
|
||||
} else {
|
||||
String::new()
|
||||
};
|
||||
|
||||
lines.push(Line::from(vec![
|
||||
Span::styled(" Press ", self.theme.cmd_desc),
|
||||
Span::styled("Esc", self.theme.cmd_name),
|
||||
Span::styled(" to close", self.theme.cmd_desc),
|
||||
Span::styled(scroll_hint, self.theme.cmd_desc),
|
||||
]));
|
||||
|
||||
// Create the block with border
|
||||
let block = Block::default()
|
||||
.title(" Commands ")
|
||||
.title_style(self.theme.popup_title)
|
||||
.borders(Borders::ALL)
|
||||
.border_style(self.theme.popup_border)
|
||||
.style(self.theme.overlay_bg);
|
||||
|
||||
let paragraph = Paragraph::new(lines)
|
||||
.block(block)
|
||||
.scroll((self.scroll_offset as u16, 0));
|
||||
|
||||
frame.render_widget(paragraph, overlay_area);
|
||||
|
||||
// Render scrollbar if content exceeds visible area
|
||||
let visible_height = height.saturating_sub(2) as usize; // -2 for borders
|
||||
let total_lines = self.commands.len() + 3;
|
||||
if total_lines > visible_height {
|
||||
let scrollbar = Scrollbar::default()
|
||||
.orientation(ScrollbarOrientation::VerticalRight)
|
||||
.begin_symbol(None)
|
||||
.end_symbol(None)
|
||||
.track_symbol(Some(" "))
|
||||
.thumb_symbol("│")
|
||||
.style(self.theme.status_dim);
|
||||
|
||||
let mut scrollbar_state = ScrollbarState::default()
|
||||
.content_length(total_lines)
|
||||
.position(self.scroll_offset);
|
||||
|
||||
// Adjust scrollbar area to be inside the border
|
||||
let scrollbar_area = Rect::new(
|
||||
overlay_area.x + overlay_area.width - 2,
|
||||
overlay_area.y + 1,
|
||||
1,
|
||||
overlay_area.height.saturating_sub(2),
|
||||
);
|
||||
|
||||
frame.render_stateful_widget(scrollbar, scrollbar_area, &mut scrollbar_state);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn test_command_help_visibility() {
|
||||
let theme = Theme::default();
|
||||
let mut help = CommandHelp::new(theme);
|
||||
|
||||
assert!(!help.is_visible());
|
||||
help.show();
|
||||
assert!(help.is_visible());
|
||||
help.hide();
|
||||
assert!(!help.is_visible());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_builtin_commands() {
|
||||
let commands = builtin_commands();
|
||||
assert!(!commands.is_empty());
|
||||
assert!(commands.iter().any(|c| c.name == "help"));
|
||||
assert!(commands.iter().any(|c| c.name == "provider"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_scroll_navigation() {
|
||||
let theme = Theme::default();
|
||||
let mut help = CommandHelp::new(theme);
|
||||
help.show();
|
||||
|
||||
assert_eq!(help.scroll_offset, 0);
|
||||
help.scroll_down(3, 10);
|
||||
assert_eq!(help.scroll_offset, 3);
|
||||
help.scroll_up(1);
|
||||
assert_eq!(help.scroll_offset, 2);
|
||||
help.scroll_up(10); // Should clamp to 0
|
||||
assert_eq!(help.scroll_offset, 0);
|
||||
}
|
||||
}
|
||||
507
crates/app/ui/src/components/input_box.rs
Normal file
507
crates/app/ui/src/components/input_box.rs
Normal file
@@ -0,0 +1,507 @@
|
||||
//! Vim-modal input component
|
||||
//!
|
||||
//! Borderless input with vim-like modes (Normal, Insert, Command).
|
||||
//! Uses mode prefix instead of borders for visual indication.
|
||||
|
||||
use crate::theme::{Theme, VimMode};
|
||||
use crossterm::event::{KeyCode, KeyEvent, KeyModifiers};
|
||||
use ratatui::{
|
||||
layout::Rect,
|
||||
style::Style,
|
||||
text::{Line, Span},
|
||||
widgets::Paragraph,
|
||||
Frame,
|
||||
};
|
||||
|
||||
/// Input event from the input box
|
||||
#[derive(Debug, Clone)]
|
||||
pub enum InputEvent {
|
||||
/// User submitted a message
|
||||
Message(String),
|
||||
/// User submitted a command (without / prefix)
|
||||
Command(String),
|
||||
/// Mode changed
|
||||
ModeChange(VimMode),
|
||||
/// Request to cancel current operation
|
||||
Cancel,
|
||||
/// Request to expand input (multiline)
|
||||
Expand,
|
||||
}
|
||||
|
||||
/// Vim-modal input box
|
||||
pub struct InputBox {
|
||||
input: String,
|
||||
cursor_position: usize,
|
||||
history: Vec<String>,
|
||||
history_index: usize,
|
||||
mode: VimMode,
|
||||
theme: Theme,
|
||||
}
|
||||
|
||||
impl InputBox {
|
||||
pub fn new(theme: Theme) -> Self {
|
||||
Self {
|
||||
input: String::new(),
|
||||
cursor_position: 0,
|
||||
history: Vec::new(),
|
||||
history_index: 0,
|
||||
mode: VimMode::Insert, // Start in insert mode for familiarity
|
||||
theme,
|
||||
}
|
||||
}
|
||||
|
||||
/// Get current vim mode
|
||||
pub fn mode(&self) -> VimMode {
|
||||
self.mode
|
||||
}
|
||||
|
||||
/// Set vim mode
|
||||
pub fn set_mode(&mut self, mode: VimMode) {
|
||||
self.mode = mode;
|
||||
}
|
||||
|
||||
/// Handle key event, returns input event if action is needed
|
||||
pub fn handle_key(&mut self, key: KeyEvent) -> Option<InputEvent> {
|
||||
match self.mode {
|
||||
VimMode::Normal => self.handle_normal_mode(key),
|
||||
VimMode::Insert => self.handle_insert_mode(key),
|
||||
VimMode::Command => self.handle_command_mode(key),
|
||||
VimMode::Visual => self.handle_visual_mode(key),
|
||||
}
|
||||
}
|
||||
|
||||
/// Handle keys in normal mode
|
||||
fn handle_normal_mode(&mut self, key: KeyEvent) -> Option<InputEvent> {
|
||||
match key.code {
|
||||
// Enter insert mode
|
||||
KeyCode::Char('i') => {
|
||||
self.mode = VimMode::Insert;
|
||||
Some(InputEvent::ModeChange(VimMode::Insert))
|
||||
}
|
||||
KeyCode::Char('a') => {
|
||||
self.mode = VimMode::Insert;
|
||||
if self.cursor_position < self.input.len() {
|
||||
self.cursor_position += 1;
|
||||
}
|
||||
Some(InputEvent::ModeChange(VimMode::Insert))
|
||||
}
|
||||
KeyCode::Char('I') => {
|
||||
self.mode = VimMode::Insert;
|
||||
self.cursor_position = 0;
|
||||
Some(InputEvent::ModeChange(VimMode::Insert))
|
||||
}
|
||||
KeyCode::Char('A') => {
|
||||
self.mode = VimMode::Insert;
|
||||
self.cursor_position = self.input.len();
|
||||
Some(InputEvent::ModeChange(VimMode::Insert))
|
||||
}
|
||||
// Enter command mode
|
||||
KeyCode::Char(':') => {
|
||||
self.mode = VimMode::Command;
|
||||
self.input.clear();
|
||||
self.cursor_position = 0;
|
||||
Some(InputEvent::ModeChange(VimMode::Command))
|
||||
}
|
||||
// Navigation
|
||||
KeyCode::Char('h') | KeyCode::Left => {
|
||||
self.cursor_position = self.cursor_position.saturating_sub(1);
|
||||
None
|
||||
}
|
||||
KeyCode::Char('l') | KeyCode::Right => {
|
||||
if self.cursor_position < self.input.len() {
|
||||
self.cursor_position += 1;
|
||||
}
|
||||
None
|
||||
}
|
||||
KeyCode::Char('0') | KeyCode::Home => {
|
||||
self.cursor_position = 0;
|
||||
None
|
||||
}
|
||||
KeyCode::Char('$') | KeyCode::End => {
|
||||
self.cursor_position = self.input.len();
|
||||
None
|
||||
}
|
||||
KeyCode::Char('w') => {
|
||||
// Jump to next word
|
||||
self.cursor_position = self.next_word_position();
|
||||
None
|
||||
}
|
||||
KeyCode::Char('b') => {
|
||||
// Jump to previous word
|
||||
self.cursor_position = self.prev_word_position();
|
||||
None
|
||||
}
|
||||
// Editing
|
||||
KeyCode::Char('x') => {
|
||||
if self.cursor_position < self.input.len() {
|
||||
self.input.remove(self.cursor_position);
|
||||
}
|
||||
None
|
||||
}
|
||||
KeyCode::Char('d') => {
|
||||
// Delete line (dd would require tracking, simplify to clear)
|
||||
self.input.clear();
|
||||
self.cursor_position = 0;
|
||||
None
|
||||
}
|
||||
// History
|
||||
KeyCode::Char('k') | KeyCode::Up => {
|
||||
self.history_prev();
|
||||
None
|
||||
}
|
||||
KeyCode::Char('j') | KeyCode::Down => {
|
||||
self.history_next();
|
||||
None
|
||||
}
|
||||
_ => None,
|
||||
}
|
||||
}
|
||||
|
||||
/// Handle keys in insert mode
|
||||
fn handle_insert_mode(&mut self, key: KeyEvent) -> Option<InputEvent> {
|
||||
match key.code {
|
||||
KeyCode::Esc => {
|
||||
self.mode = VimMode::Normal;
|
||||
// Move cursor back when exiting insert mode (vim behavior)
|
||||
if self.cursor_position > 0 {
|
||||
self.cursor_position -= 1;
|
||||
}
|
||||
Some(InputEvent::ModeChange(VimMode::Normal))
|
||||
}
|
||||
KeyCode::Enter => {
|
||||
let message = self.input.clone();
|
||||
if !message.trim().is_empty() {
|
||||
self.history.push(message.clone());
|
||||
self.history_index = self.history.len();
|
||||
self.input.clear();
|
||||
self.cursor_position = 0;
|
||||
return Some(InputEvent::Message(message));
|
||||
}
|
||||
None
|
||||
}
|
||||
KeyCode::Char('e') if key.modifiers.contains(KeyModifiers::CONTROL) => {
|
||||
Some(InputEvent::Expand)
|
||||
}
|
||||
KeyCode::Char('c') if key.modifiers.contains(KeyModifiers::CONTROL) => {
|
||||
Some(InputEvent::Cancel)
|
||||
}
|
||||
KeyCode::Char(c) => {
|
||||
self.input.insert(self.cursor_position, c);
|
||||
self.cursor_position += 1;
|
||||
None
|
||||
}
|
||||
KeyCode::Backspace => {
|
||||
if self.cursor_position > 0 {
|
||||
self.input.remove(self.cursor_position - 1);
|
||||
self.cursor_position -= 1;
|
||||
}
|
||||
None
|
||||
}
|
||||
KeyCode::Delete => {
|
||||
if self.cursor_position < self.input.len() {
|
||||
self.input.remove(self.cursor_position);
|
||||
}
|
||||
None
|
||||
}
|
||||
KeyCode::Left => {
|
||||
self.cursor_position = self.cursor_position.saturating_sub(1);
|
||||
None
|
||||
}
|
||||
KeyCode::Right => {
|
||||
if self.cursor_position < self.input.len() {
|
||||
self.cursor_position += 1;
|
||||
}
|
||||
None
|
||||
}
|
||||
KeyCode::Home => {
|
||||
self.cursor_position = 0;
|
||||
None
|
||||
}
|
||||
KeyCode::End => {
|
||||
self.cursor_position = self.input.len();
|
||||
None
|
||||
}
|
||||
KeyCode::Up => {
|
||||
self.history_prev();
|
||||
None
|
||||
}
|
||||
KeyCode::Down => {
|
||||
self.history_next();
|
||||
None
|
||||
}
|
||||
_ => None,
|
||||
}
|
||||
}
|
||||
|
||||
/// Handle keys in command mode
|
||||
fn handle_command_mode(&mut self, key: KeyEvent) -> Option<InputEvent> {
|
||||
match key.code {
|
||||
KeyCode::Esc => {
|
||||
self.mode = VimMode::Normal;
|
||||
self.input.clear();
|
||||
self.cursor_position = 0;
|
||||
Some(InputEvent::ModeChange(VimMode::Normal))
|
||||
}
|
||||
KeyCode::Enter => {
|
||||
let command = self.input.clone();
|
||||
self.mode = VimMode::Normal;
|
||||
self.input.clear();
|
||||
self.cursor_position = 0;
|
||||
if !command.trim().is_empty() {
|
||||
return Some(InputEvent::Command(command));
|
||||
}
|
||||
Some(InputEvent::ModeChange(VimMode::Normal))
|
||||
}
|
||||
KeyCode::Char(c) => {
|
||||
self.input.insert(self.cursor_position, c);
|
||||
self.cursor_position += 1;
|
||||
None
|
||||
}
|
||||
KeyCode::Backspace => {
|
||||
if self.cursor_position > 0 {
|
||||
self.input.remove(self.cursor_position - 1);
|
||||
self.cursor_position -= 1;
|
||||
} else {
|
||||
// Empty command, exit to normal mode
|
||||
self.mode = VimMode::Normal;
|
||||
return Some(InputEvent::ModeChange(VimMode::Normal));
|
||||
}
|
||||
None
|
||||
}
|
||||
KeyCode::Left => {
|
||||
self.cursor_position = self.cursor_position.saturating_sub(1);
|
||||
None
|
||||
}
|
||||
KeyCode::Right => {
|
||||
if self.cursor_position < self.input.len() {
|
||||
self.cursor_position += 1;
|
||||
}
|
||||
None
|
||||
}
|
||||
_ => None,
|
||||
}
|
||||
}
|
||||
|
||||
/// Handle keys in visual mode (simplified)
|
||||
fn handle_visual_mode(&mut self, key: KeyEvent) -> Option<InputEvent> {
|
||||
match key.code {
|
||||
KeyCode::Esc => {
|
||||
self.mode = VimMode::Normal;
|
||||
Some(InputEvent::ModeChange(VimMode::Normal))
|
||||
}
|
||||
_ => None,
|
||||
}
|
||||
}
|
||||
|
||||
/// History navigation - previous
|
||||
fn history_prev(&mut self) {
|
||||
if !self.history.is_empty() && self.history_index > 0 {
|
||||
self.history_index -= 1;
|
||||
self.input = self.history[self.history_index].clone();
|
||||
self.cursor_position = self.input.len();
|
||||
}
|
||||
}
|
||||
|
||||
/// History navigation - next
|
||||
fn history_next(&mut self) {
|
||||
if self.history_index < self.history.len().saturating_sub(1) {
|
||||
self.history_index += 1;
|
||||
self.input = self.history[self.history_index].clone();
|
||||
self.cursor_position = self.input.len();
|
||||
} else if self.history_index < self.history.len() {
|
||||
self.history_index = self.history.len();
|
||||
self.input.clear();
|
||||
self.cursor_position = 0;
|
||||
}
|
||||
}
|
||||
|
||||
/// Find next word position
|
||||
fn next_word_position(&self) -> usize {
|
||||
let bytes = self.input.as_bytes();
|
||||
let mut pos = self.cursor_position;
|
||||
|
||||
// Skip current word
|
||||
while pos < bytes.len() && !bytes[pos].is_ascii_whitespace() {
|
||||
pos += 1;
|
||||
}
|
||||
// Skip whitespace
|
||||
while pos < bytes.len() && bytes[pos].is_ascii_whitespace() {
|
||||
pos += 1;
|
||||
}
|
||||
pos
|
||||
}
|
||||
|
||||
/// Find previous word position
|
||||
fn prev_word_position(&self) -> usize {
|
||||
let bytes = self.input.as_bytes();
|
||||
let mut pos = self.cursor_position.saturating_sub(1);
|
||||
|
||||
// Skip whitespace
|
||||
while pos > 0 && bytes[pos].is_ascii_whitespace() {
|
||||
pos -= 1;
|
||||
}
|
||||
// Skip to start of word
|
||||
while pos > 0 && !bytes[pos - 1].is_ascii_whitespace() {
|
||||
pos -= 1;
|
||||
}
|
||||
pos
|
||||
}
|
||||
|
||||
/// Render the borderless input (single line)
|
||||
pub fn render(&self, frame: &mut Frame, area: Rect) {
|
||||
let is_empty = self.input.is_empty();
|
||||
let symbols = &self.theme.symbols;
|
||||
|
||||
// Mode-specific prefix
|
||||
let prefix = match self.mode {
|
||||
VimMode::Normal => Span::styled(
|
||||
format!("{} ", symbols.mode_normal),
|
||||
self.theme.status_dim,
|
||||
),
|
||||
VimMode::Insert => Span::styled(
|
||||
format!("{} ", symbols.user_prefix),
|
||||
self.theme.input_prefix,
|
||||
),
|
||||
VimMode::Command => Span::styled(
|
||||
": ",
|
||||
self.theme.input_prefix,
|
||||
),
|
||||
VimMode::Visual => Span::styled(
|
||||
format!("{} ", symbols.mode_visual),
|
||||
self.theme.status_accent,
|
||||
),
|
||||
};
|
||||
|
||||
// Cursor position handling
|
||||
let (text_before, cursor_char, text_after) = if self.cursor_position < self.input.len() {
|
||||
let before = &self.input[..self.cursor_position];
|
||||
let cursor = &self.input[self.cursor_position..self.cursor_position + 1];
|
||||
let after = &self.input[self.cursor_position + 1..];
|
||||
(before, cursor, after)
|
||||
} else {
|
||||
(&self.input[..], " ", "")
|
||||
};
|
||||
|
||||
let line = if is_empty && self.mode == VimMode::Insert {
|
||||
Line::from(vec![
|
||||
Span::raw(" "),
|
||||
prefix,
|
||||
Span::styled("▊", self.theme.input_prefix),
|
||||
Span::styled(" Type message...", self.theme.input_placeholder),
|
||||
])
|
||||
} else if is_empty && self.mode == VimMode::Command {
|
||||
Line::from(vec![
|
||||
Span::raw(" "),
|
||||
prefix,
|
||||
Span::styled("▊", self.theme.input_prefix),
|
||||
])
|
||||
} else {
|
||||
// Build cursor span with appropriate styling
|
||||
let cursor_style = if self.mode == VimMode::Normal {
|
||||
Style::default()
|
||||
.bg(self.theme.palette.fg)
|
||||
.fg(self.theme.palette.bg)
|
||||
} else {
|
||||
self.theme.input_prefix
|
||||
};
|
||||
|
||||
let cursor_span = if self.mode == VimMode::Normal && !is_empty {
|
||||
Span::styled(cursor_char.to_string(), cursor_style)
|
||||
} else {
|
||||
Span::styled("▊", self.theme.input_prefix)
|
||||
};
|
||||
|
||||
Line::from(vec![
|
||||
Span::raw(" "),
|
||||
prefix,
|
||||
Span::styled(text_before.to_string(), self.theme.input_text),
|
||||
cursor_span,
|
||||
Span::styled(text_after.to_string(), self.theme.input_text),
|
||||
])
|
||||
};
|
||||
|
||||
let paragraph = Paragraph::new(line);
|
||||
frame.render_widget(paragraph, area);
|
||||
}
|
||||
|
||||
/// Clear input
|
||||
pub fn clear(&mut self) {
|
||||
self.input.clear();
|
||||
self.cursor_position = 0;
|
||||
}
|
||||
|
||||
/// Get current input text
|
||||
pub fn text(&self) -> &str {
|
||||
&self.input
|
||||
}
|
||||
|
||||
/// Set input text
|
||||
pub fn set_text(&mut self, text: String) {
|
||||
self.input = text;
|
||||
self.cursor_position = self.input.len();
|
||||
}
|
||||
|
||||
/// Update theme
|
||||
pub fn set_theme(&mut self, theme: Theme) {
|
||||
self.theme = theme;
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn test_mode_transitions() {
|
||||
let theme = Theme::default();
|
||||
let mut input = InputBox::new(theme);
|
||||
|
||||
// Start in insert mode
|
||||
assert_eq!(input.mode(), VimMode::Insert);
|
||||
|
||||
// Escape to normal mode
|
||||
let event = input.handle_key(KeyEvent::from(KeyCode::Esc));
|
||||
assert!(matches!(event, Some(InputEvent::ModeChange(VimMode::Normal))));
|
||||
assert_eq!(input.mode(), VimMode::Normal);
|
||||
|
||||
// 'i' to insert mode
|
||||
let event = input.handle_key(KeyEvent::from(KeyCode::Char('i')));
|
||||
assert!(matches!(event, Some(InputEvent::ModeChange(VimMode::Insert))));
|
||||
assert_eq!(input.mode(), VimMode::Insert);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_insert_text() {
|
||||
let theme = Theme::default();
|
||||
let mut input = InputBox::new(theme);
|
||||
|
||||
input.handle_key(KeyEvent::from(KeyCode::Char('h')));
|
||||
input.handle_key(KeyEvent::from(KeyCode::Char('i')));
|
||||
|
||||
assert_eq!(input.text(), "hi");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_command_mode() {
|
||||
let theme = Theme::default();
|
||||
let mut input = InputBox::new(theme);
|
||||
|
||||
// Escape to normal, then : to command
|
||||
input.handle_key(KeyEvent::from(KeyCode::Esc));
|
||||
input.handle_key(KeyEvent::from(KeyCode::Char(':')));
|
||||
|
||||
assert_eq!(input.mode(), VimMode::Command);
|
||||
|
||||
// Type command
|
||||
input.handle_key(KeyEvent::from(KeyCode::Char('q')));
|
||||
input.handle_key(KeyEvent::from(KeyCode::Char('u')));
|
||||
input.handle_key(KeyEvent::from(KeyCode::Char('i')));
|
||||
input.handle_key(KeyEvent::from(KeyCode::Char('t')));
|
||||
|
||||
assert_eq!(input.text(), "quit");
|
||||
|
||||
// Submit command
|
||||
let event = input.handle_key(KeyEvent::from(KeyCode::Enter));
|
||||
assert!(matches!(event, Some(InputEvent::Command(cmd)) if cmd == "quit"));
|
||||
}
|
||||
}
|
||||
19
crates/app/ui/src/components/mod.rs
Normal file
19
crates/app/ui/src/components/mod.rs
Normal file
@@ -0,0 +1,19 @@
|
||||
//! TUI components for the borderless multi-provider design
|
||||
|
||||
mod autocomplete;
|
||||
mod chat_panel;
|
||||
mod command_help;
|
||||
mod input_box;
|
||||
mod permission_popup;
|
||||
mod provider_tabs;
|
||||
mod status_bar;
|
||||
mod todo_panel;
|
||||
|
||||
pub use autocomplete::{Autocomplete, AutocompleteOption, AutocompleteResult};
|
||||
pub use chat_panel::{ChatMessage, ChatPanel, DisplayMessage};
|
||||
pub use command_help::{Command, CommandHelp};
|
||||
pub use input_box::{InputBox, InputEvent};
|
||||
pub use permission_popup::{PermissionOption, PermissionPopup};
|
||||
pub use provider_tabs::ProviderTabs;
|
||||
pub use status_bar::{AppState, StatusBar};
|
||||
pub use todo_panel::TodoPanel;
|
||||
196
crates/app/ui/src/components/permission_popup.rs
Normal file
196
crates/app/ui/src/components/permission_popup.rs
Normal file
@@ -0,0 +1,196 @@
|
||||
use crate::theme::Theme;
|
||||
use crossterm::event::{KeyCode, KeyEvent};
|
||||
use permissions::PermissionDecision;
|
||||
use ratatui::{
|
||||
layout::{Constraint, Direction, Layout, Rect},
|
||||
style::{Modifier, Style},
|
||||
text::{Line, Span},
|
||||
widgets::{Block, Borders, Clear, Paragraph},
|
||||
Frame,
|
||||
};
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub enum PermissionOption {
|
||||
AllowOnce,
|
||||
AlwaysAllow,
|
||||
Deny,
|
||||
Explain,
|
||||
}
|
||||
|
||||
pub struct PermissionPopup {
|
||||
tool: String,
|
||||
context: Option<String>,
|
||||
selected: usize,
|
||||
theme: Theme,
|
||||
}
|
||||
|
||||
impl PermissionPopup {
|
||||
pub fn new(tool: String, context: Option<String>, theme: Theme) -> Self {
|
||||
Self {
|
||||
tool,
|
||||
context,
|
||||
selected: 0,
|
||||
theme,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn handle_key(&mut self, key: KeyEvent) -> Option<PermissionOption> {
|
||||
match key.code {
|
||||
KeyCode::Char('a') => Some(PermissionOption::AllowOnce),
|
||||
KeyCode::Char('A') => Some(PermissionOption::AlwaysAllow),
|
||||
KeyCode::Char('d') => Some(PermissionOption::Deny),
|
||||
KeyCode::Char('?') => Some(PermissionOption::Explain),
|
||||
KeyCode::Up => {
|
||||
self.selected = self.selected.saturating_sub(1);
|
||||
None
|
||||
}
|
||||
KeyCode::Down => {
|
||||
if self.selected < 3 {
|
||||
self.selected += 1;
|
||||
}
|
||||
None
|
||||
}
|
||||
KeyCode::Enter => match self.selected {
|
||||
0 => Some(PermissionOption::AllowOnce),
|
||||
1 => Some(PermissionOption::AlwaysAllow),
|
||||
2 => Some(PermissionOption::Deny),
|
||||
3 => Some(PermissionOption::Explain),
|
||||
_ => None,
|
||||
},
|
||||
KeyCode::Esc => Some(PermissionOption::Deny),
|
||||
_ => None,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn render(&self, frame: &mut Frame, area: Rect) {
|
||||
// Center the popup
|
||||
let popup_area = crate::layout::AppLayout::center_popup(area, 64, 14);
|
||||
|
||||
// Clear the area behind the popup
|
||||
frame.render_widget(Clear, popup_area);
|
||||
|
||||
// Render popup with styled border
|
||||
let block = Block::default()
|
||||
.borders(Borders::ALL)
|
||||
.border_style(self.theme.popup_border)
|
||||
.style(self.theme.popup_bg)
|
||||
.title(Line::from(vec![
|
||||
Span::raw(" "),
|
||||
Span::styled("🔒", self.theme.popup_title),
|
||||
Span::raw(" "),
|
||||
Span::styled("Permission Required", self.theme.popup_title),
|
||||
Span::raw(" "),
|
||||
]));
|
||||
|
||||
frame.render_widget(block, popup_area);
|
||||
|
||||
// Split popup into sections
|
||||
let inner = popup_area.inner(ratatui::layout::Margin {
|
||||
vertical: 1,
|
||||
horizontal: 2,
|
||||
});
|
||||
|
||||
let sections = Layout::default()
|
||||
.direction(Direction::Vertical)
|
||||
.constraints([
|
||||
Constraint::Length(2), // Tool name with box
|
||||
Constraint::Length(3), // Context (if any)
|
||||
Constraint::Length(1), // Separator
|
||||
Constraint::Length(1), // Option 1
|
||||
Constraint::Length(1), // Option 2
|
||||
Constraint::Length(1), // Option 3
|
||||
Constraint::Length(1), // Option 4
|
||||
Constraint::Length(1), // Help text
|
||||
])
|
||||
.split(inner);
|
||||
|
||||
// Tool name with highlight
|
||||
let tool_line = Line::from(vec![
|
||||
Span::styled("⚡ Tool: ", Style::default().fg(self.theme.palette.warning)),
|
||||
Span::styled(&self.tool, self.theme.popup_title),
|
||||
]);
|
||||
frame.render_widget(Paragraph::new(tool_line), sections[0]);
|
||||
|
||||
// Context with wrapping
|
||||
if let Some(ctx) = &self.context {
|
||||
let context_text = if ctx.len() > 100 {
|
||||
format!("{}...", &ctx[..100])
|
||||
} else {
|
||||
ctx.clone()
|
||||
};
|
||||
let context_lines = textwrap::wrap(&context_text, (sections[1].width - 2) as usize);
|
||||
let mut lines = vec![
|
||||
Line::from(vec![
|
||||
Span::styled("📝 Context: ", Style::default().fg(self.theme.palette.info)),
|
||||
])
|
||||
];
|
||||
for line in context_lines.iter().take(2) {
|
||||
lines.push(Line::from(vec![
|
||||
Span::raw(" "),
|
||||
Span::styled(line.to_string(), Style::default().fg(self.theme.palette.fg_dim)),
|
||||
]));
|
||||
}
|
||||
frame.render_widget(Paragraph::new(lines), sections[1]);
|
||||
}
|
||||
|
||||
// Separator
|
||||
let separator = Line::styled(
|
||||
"─".repeat(sections[2].width as usize),
|
||||
Style::default().fg(self.theme.palette.divider_fg),
|
||||
);
|
||||
frame.render_widget(Paragraph::new(separator), sections[2]);
|
||||
|
||||
// Options with icons and colors
|
||||
let options = [
|
||||
("✓", " [a] Allow once", self.theme.palette.success, 0),
|
||||
("✓✓", " [A] Always allow", self.theme.palette.primary, 1),
|
||||
("✗", " [d] Deny", self.theme.palette.error, 2),
|
||||
("?", " [?] Explain", self.theme.palette.info, 3),
|
||||
];
|
||||
|
||||
for (icon, text, color, idx) in options.iter() {
|
||||
let (style, prefix) = if self.selected == *idx {
|
||||
(
|
||||
self.theme.selected,
|
||||
"▶ "
|
||||
)
|
||||
} else {
|
||||
(
|
||||
Style::default().fg(*color),
|
||||
" "
|
||||
)
|
||||
};
|
||||
|
||||
let line = Line::from(vec![
|
||||
Span::styled(prefix, style),
|
||||
Span::styled(*icon, style),
|
||||
Span::styled(*text, style),
|
||||
]);
|
||||
frame.render_widget(Paragraph::new(line), sections[3 + idx]);
|
||||
}
|
||||
|
||||
// Help text at bottom
|
||||
let help_line = Line::from(vec![
|
||||
Span::styled(
|
||||
"↑↓ Navigate Enter to select Esc to deny",
|
||||
Style::default().fg(self.theme.palette.fg_dim).add_modifier(Modifier::ITALIC),
|
||||
),
|
||||
]);
|
||||
frame.render_widget(Paragraph::new(help_line), sections[7]);
|
||||
}
|
||||
}
|
||||
|
||||
impl PermissionOption {
|
||||
pub fn to_decision(&self) -> Option<PermissionDecision> {
|
||||
match self {
|
||||
PermissionOption::AllowOnce => Some(PermissionDecision::Allow),
|
||||
PermissionOption::AlwaysAllow => Some(PermissionDecision::Allow),
|
||||
PermissionOption::Deny => Some(PermissionDecision::Deny),
|
||||
PermissionOption::Explain => None, // Special handling needed
|
||||
}
|
||||
}
|
||||
|
||||
pub fn should_persist(&self) -> bool {
|
||||
matches!(self, PermissionOption::AlwaysAllow)
|
||||
}
|
||||
}
|
||||
189
crates/app/ui/src/components/provider_tabs.rs
Normal file
189
crates/app/ui/src/components/provider_tabs.rs
Normal file
@@ -0,0 +1,189 @@
|
||||
//! Provider tabs component for multi-LLM support
|
||||
//!
|
||||
//! Displays horizontal tabs for switching between providers (Claude, Ollama, OpenAI)
|
||||
//! with icons and keybind hints.
|
||||
|
||||
use crate::theme::{Provider, Theme};
|
||||
use ratatui::{
|
||||
layout::Rect,
|
||||
style::Style,
|
||||
text::{Line, Span},
|
||||
widgets::Paragraph,
|
||||
Frame,
|
||||
};
|
||||
|
||||
/// Provider tab state and rendering
|
||||
pub struct ProviderTabs {
|
||||
active: Provider,
|
||||
theme: Theme,
|
||||
}
|
||||
|
||||
impl ProviderTabs {
|
||||
/// Create new provider tabs with default provider
|
||||
pub fn new(theme: Theme) -> Self {
|
||||
Self {
|
||||
active: Provider::Ollama, // Default to Ollama (local)
|
||||
theme,
|
||||
}
|
||||
}
|
||||
|
||||
/// Create with specific active provider
|
||||
pub fn with_provider(provider: Provider, theme: Theme) -> Self {
|
||||
Self {
|
||||
active: provider,
|
||||
theme,
|
||||
}
|
||||
}
|
||||
|
||||
/// Get the currently active provider
|
||||
pub fn active(&self) -> Provider {
|
||||
self.active
|
||||
}
|
||||
|
||||
/// Set the active provider
|
||||
pub fn set_active(&mut self, provider: Provider) {
|
||||
self.active = provider;
|
||||
}
|
||||
|
||||
/// Cycle to the next provider
|
||||
pub fn next(&mut self) {
|
||||
self.active = match self.active {
|
||||
Provider::Claude => Provider::Ollama,
|
||||
Provider::Ollama => Provider::OpenAI,
|
||||
Provider::OpenAI => Provider::Claude,
|
||||
};
|
||||
}
|
||||
|
||||
/// Cycle to the previous provider
|
||||
pub fn previous(&mut self) {
|
||||
self.active = match self.active {
|
||||
Provider::Claude => Provider::OpenAI,
|
||||
Provider::Ollama => Provider::Claude,
|
||||
Provider::OpenAI => Provider::Ollama,
|
||||
};
|
||||
}
|
||||
|
||||
/// Select provider by number (1, 2, 3)
|
||||
pub fn select_by_number(&mut self, num: u8) {
|
||||
self.active = match num {
|
||||
1 => Provider::Claude,
|
||||
2 => Provider::Ollama,
|
||||
3 => Provider::OpenAI,
|
||||
_ => self.active,
|
||||
};
|
||||
}
|
||||
|
||||
/// Update the theme
|
||||
pub fn set_theme(&mut self, theme: Theme) {
|
||||
self.theme = theme;
|
||||
}
|
||||
|
||||
/// Render the provider tabs (borderless)
|
||||
pub fn render(&self, frame: &mut Frame, area: Rect) {
|
||||
let mut spans = Vec::new();
|
||||
|
||||
// Add spacing at start
|
||||
spans.push(Span::raw(" "));
|
||||
|
||||
for (i, provider) in Provider::all().iter().enumerate() {
|
||||
let is_active = *provider == self.active;
|
||||
let icon = self.theme.provider_icon(*provider);
|
||||
let name = provider.name();
|
||||
let number = (i + 1).to_string();
|
||||
|
||||
// Keybind hint
|
||||
spans.push(Span::styled(
|
||||
format!("[{}] ", number),
|
||||
self.theme.status_dim,
|
||||
));
|
||||
|
||||
// Icon and name
|
||||
let style = if is_active {
|
||||
Style::default()
|
||||
.fg(self.theme.provider_color(*provider))
|
||||
.add_modifier(ratatui::style::Modifier::BOLD)
|
||||
} else {
|
||||
self.theme.tab_inactive
|
||||
};
|
||||
|
||||
spans.push(Span::styled(format!("{} ", icon), style));
|
||||
spans.push(Span::styled(name.to_string(), style));
|
||||
|
||||
// Separator between tabs (not after last)
|
||||
if i < Provider::all().len() - 1 {
|
||||
spans.push(Span::styled(
|
||||
format!(" {} ", self.theme.symbols.vertical_separator),
|
||||
self.theme.status_dim,
|
||||
));
|
||||
}
|
||||
}
|
||||
|
||||
// Tab cycling hint on the right
|
||||
spans.push(Span::raw(" "));
|
||||
spans.push(Span::styled("[Tab] cycle", self.theme.status_dim));
|
||||
|
||||
let line = Line::from(spans);
|
||||
let paragraph = Paragraph::new(line);
|
||||
frame.render_widget(paragraph, area);
|
||||
}
|
||||
|
||||
/// Render a compact version (just active provider)
|
||||
pub fn render_compact(&self, frame: &mut Frame, area: Rect) {
|
||||
let icon = self.theme.provider_icon(self.active);
|
||||
let name = self.active.name();
|
||||
|
||||
let line = Line::from(vec![
|
||||
Span::raw(" "),
|
||||
Span::styled(
|
||||
format!("{} {}", icon, name),
|
||||
Style::default()
|
||||
.fg(self.theme.provider_color(self.active))
|
||||
.add_modifier(ratatui::style::Modifier::BOLD),
|
||||
),
|
||||
]);
|
||||
|
||||
let paragraph = Paragraph::new(line);
|
||||
frame.render_widget(paragraph, area);
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn test_provider_cycling() {
|
||||
let theme = Theme::default();
|
||||
let mut tabs = ProviderTabs::new(theme);
|
||||
|
||||
assert_eq!(tabs.active(), Provider::Ollama);
|
||||
|
||||
tabs.next();
|
||||
assert_eq!(tabs.active(), Provider::OpenAI);
|
||||
|
||||
tabs.next();
|
||||
assert_eq!(tabs.active(), Provider::Claude);
|
||||
|
||||
tabs.next();
|
||||
assert_eq!(tabs.active(), Provider::Ollama);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_select_by_number() {
|
||||
let theme = Theme::default();
|
||||
let mut tabs = ProviderTabs::new(theme);
|
||||
|
||||
tabs.select_by_number(1);
|
||||
assert_eq!(tabs.active(), Provider::Claude);
|
||||
|
||||
tabs.select_by_number(2);
|
||||
assert_eq!(tabs.active(), Provider::Ollama);
|
||||
|
||||
tabs.select_by_number(3);
|
||||
assert_eq!(tabs.active(), Provider::OpenAI);
|
||||
|
||||
// Invalid number should not change
|
||||
tabs.select_by_number(4);
|
||||
assert_eq!(tabs.active(), Provider::OpenAI);
|
||||
}
|
||||
}
|
||||
188
crates/app/ui/src/components/status_bar.rs
Normal file
188
crates/app/ui/src/components/status_bar.rs
Normal file
@@ -0,0 +1,188 @@
|
||||
//! Minimal status bar component
|
||||
//!
|
||||
//! Clean, readable status bar with essential info only.
|
||||
//! Format: ` Mode │ N msgs │ ~Nk tok │ state`
|
||||
|
||||
use crate::theme::{Provider, Theme, VimMode};
|
||||
use agent_core::SessionStats;
|
||||
use permissions::Mode;
|
||||
use ratatui::{
|
||||
layout::Rect,
|
||||
style::Style,
|
||||
text::{Line, Span},
|
||||
widgets::Paragraph,
|
||||
Frame,
|
||||
};
|
||||
|
||||
/// Application state for status display
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
|
||||
pub enum AppState {
|
||||
Idle,
|
||||
Streaming,
|
||||
WaitingPermission,
|
||||
Error,
|
||||
}
|
||||
|
||||
impl AppState {
|
||||
pub fn label(&self) -> &'static str {
|
||||
match self {
|
||||
AppState::Idle => "idle",
|
||||
AppState::Streaming => "streaming...",
|
||||
AppState::WaitingPermission => "waiting",
|
||||
AppState::Error => "error",
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub struct StatusBar {
|
||||
provider: Provider,
|
||||
model: String,
|
||||
mode: Mode,
|
||||
vim_mode: VimMode,
|
||||
stats: SessionStats,
|
||||
last_tool: Option<String>,
|
||||
state: AppState,
|
||||
estimated_cost: f64,
|
||||
planning_mode: bool,
|
||||
theme: Theme,
|
||||
}
|
||||
|
||||
impl StatusBar {
|
||||
pub fn new(model: String, mode: Mode, theme: Theme) -> Self {
|
||||
Self {
|
||||
provider: Provider::Ollama, // Default provider
|
||||
model,
|
||||
mode,
|
||||
vim_mode: VimMode::Insert,
|
||||
stats: SessionStats::new(),
|
||||
last_tool: None,
|
||||
state: AppState::Idle,
|
||||
estimated_cost: 0.0,
|
||||
planning_mode: false,
|
||||
theme,
|
||||
}
|
||||
}
|
||||
|
||||
/// Set the active provider
|
||||
pub fn set_provider(&mut self, provider: Provider) {
|
||||
self.provider = provider;
|
||||
}
|
||||
|
||||
/// Set the current model
|
||||
pub fn set_model(&mut self, model: String) {
|
||||
self.model = model;
|
||||
}
|
||||
|
||||
/// Update session stats
|
||||
pub fn update_stats(&mut self, stats: SessionStats) {
|
||||
self.stats = stats;
|
||||
}
|
||||
|
||||
/// Set the last used tool
|
||||
pub fn set_last_tool(&mut self, tool: String) {
|
||||
self.last_tool = Some(tool);
|
||||
}
|
||||
|
||||
/// Set application state
|
||||
pub fn set_state(&mut self, state: AppState) {
|
||||
self.state = state;
|
||||
}
|
||||
|
||||
/// Set vim mode for display
|
||||
pub fn set_vim_mode(&mut self, mode: VimMode) {
|
||||
self.vim_mode = mode;
|
||||
}
|
||||
|
||||
/// Add to estimated cost
|
||||
pub fn add_cost(&mut self, cost: f64) {
|
||||
self.estimated_cost += cost;
|
||||
}
|
||||
|
||||
/// Reset cost
|
||||
pub fn reset_cost(&mut self) {
|
||||
self.estimated_cost = 0.0;
|
||||
}
|
||||
|
||||
/// Update theme
|
||||
pub fn set_theme(&mut self, theme: Theme) {
|
||||
self.theme = theme;
|
||||
}
|
||||
|
||||
/// Set planning mode status
|
||||
pub fn set_planning_mode(&mut self, active: bool) {
|
||||
self.planning_mode = active;
|
||||
}
|
||||
|
||||
/// Render the minimal status bar
|
||||
///
|
||||
/// Format: ` Mode │ N msgs │ ~Nk tok │ state`
|
||||
pub fn render(&self, frame: &mut Frame, area: Rect) {
|
||||
let sep = self.theme.symbols.vertical_separator;
|
||||
let sep_style = Style::default().fg(self.theme.palette.border);
|
||||
|
||||
// Permission mode
|
||||
let mode_str = if self.planning_mode {
|
||||
"PLAN"
|
||||
} else {
|
||||
match self.mode {
|
||||
Mode::Plan => "Plan",
|
||||
Mode::AcceptEdits => "Edit",
|
||||
Mode::Code => "Code",
|
||||
}
|
||||
};
|
||||
|
||||
// Format token count
|
||||
let tokens_str = if self.stats.estimated_tokens >= 1000 {
|
||||
format!("~{}k tok", self.stats.estimated_tokens / 1000)
|
||||
} else {
|
||||
format!("~{} tok", self.stats.estimated_tokens)
|
||||
};
|
||||
|
||||
// State style - only highlight non-idle states
|
||||
let state_style = match self.state {
|
||||
AppState::Idle => self.theme.status_dim,
|
||||
AppState::Streaming => Style::default().fg(self.theme.palette.success),
|
||||
AppState::WaitingPermission => Style::default().fg(self.theme.palette.warning),
|
||||
AppState::Error => Style::default().fg(self.theme.palette.error),
|
||||
};
|
||||
|
||||
// Build minimal status line
|
||||
let spans = vec![
|
||||
Span::styled(" ", self.theme.status_dim),
|
||||
// Mode
|
||||
Span::styled(mode_str, self.theme.status_dim),
|
||||
Span::styled(format!(" {} ", sep), sep_style),
|
||||
// Message count
|
||||
Span::styled(format!("{} msgs", self.stats.total_messages), self.theme.status_dim),
|
||||
Span::styled(format!(" {} ", sep), sep_style),
|
||||
// Token count
|
||||
Span::styled(&tokens_str, self.theme.status_dim),
|
||||
Span::styled(format!(" {} ", sep), sep_style),
|
||||
// State
|
||||
Span::styled(self.state.label(), state_style),
|
||||
];
|
||||
|
||||
let line = Line::from(spans);
|
||||
let paragraph = Paragraph::new(line);
|
||||
frame.render_widget(paragraph, area);
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn test_status_bar_creation() {
|
||||
let theme = Theme::default();
|
||||
let status_bar = StatusBar::new("gpt-4".to_string(), Mode::Plan, theme);
|
||||
assert_eq!(status_bar.model, "gpt-4");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_app_state_display() {
|
||||
assert_eq!(AppState::Idle.label(), "idle");
|
||||
assert_eq!(AppState::Streaming.label(), "streaming...");
|
||||
assert_eq!(AppState::Error.label(), "error");
|
||||
}
|
||||
}
|
||||
200
crates/app/ui/src/components/todo_panel.rs
Normal file
200
crates/app/ui/src/components/todo_panel.rs
Normal file
@@ -0,0 +1,200 @@
|
||||
//! Todo panel component for displaying task list
|
||||
//!
|
||||
//! Shows the current todo list with status indicators and progress.
|
||||
|
||||
use ratatui::{
|
||||
layout::Rect,
|
||||
style::{Color, Modifier, Style},
|
||||
text::{Line, Span},
|
||||
widgets::{Block, Borders, Paragraph},
|
||||
Frame,
|
||||
};
|
||||
use tools_todo::{Todo, TodoList, TodoStatus};
|
||||
|
||||
use crate::theme::Theme;
|
||||
|
||||
/// Todo panel component
|
||||
pub struct TodoPanel {
|
||||
theme: Theme,
|
||||
collapsed: bool,
|
||||
}
|
||||
|
||||
impl TodoPanel {
|
||||
pub fn new(theme: Theme) -> Self {
|
||||
Self {
|
||||
theme,
|
||||
collapsed: false,
|
||||
}
|
||||
}
|
||||
|
||||
/// Toggle collapsed state
|
||||
pub fn toggle(&mut self) {
|
||||
self.collapsed = !self.collapsed;
|
||||
}
|
||||
|
||||
/// Check if collapsed
|
||||
pub fn is_collapsed(&self) -> bool {
|
||||
self.collapsed
|
||||
}
|
||||
|
||||
/// Update theme
|
||||
pub fn set_theme(&mut self, theme: Theme) {
|
||||
self.theme = theme;
|
||||
}
|
||||
|
||||
/// Get the minimum height needed for the panel
|
||||
pub fn min_height(&self) -> u16 {
|
||||
if self.collapsed {
|
||||
1
|
||||
} else {
|
||||
5
|
||||
}
|
||||
}
|
||||
|
||||
/// Render the todo panel
|
||||
pub fn render(&self, frame: &mut Frame, area: Rect, todos: &TodoList) {
|
||||
if self.collapsed {
|
||||
self.render_collapsed(frame, area, todos);
|
||||
} else {
|
||||
self.render_expanded(frame, area, todos);
|
||||
}
|
||||
}
|
||||
|
||||
/// Render collapsed view (single line summary)
|
||||
fn render_collapsed(&self, frame: &mut Frame, area: Rect, todos: &TodoList) {
|
||||
let items = todos.read();
|
||||
let completed = items.iter().filter(|t| t.status == TodoStatus::Completed).count();
|
||||
let in_progress = items.iter().filter(|t| t.status == TodoStatus::InProgress).count();
|
||||
let pending = items.iter().filter(|t| t.status == TodoStatus::Pending).count();
|
||||
|
||||
let summary = if items.is_empty() {
|
||||
"No tasks".to_string()
|
||||
} else {
|
||||
format!(
|
||||
"{} {} / {} {} / {} {}",
|
||||
self.theme.symbols.check, completed,
|
||||
self.theme.symbols.streaming, in_progress,
|
||||
self.theme.symbols.bullet, pending
|
||||
)
|
||||
};
|
||||
|
||||
let line = Line::from(vec![
|
||||
Span::styled("Tasks: ", self.theme.status_bar),
|
||||
Span::styled(summary, self.theme.status_dim),
|
||||
Span::styled(" [t to expand]", self.theme.status_dim),
|
||||
]);
|
||||
|
||||
let paragraph = Paragraph::new(line);
|
||||
frame.render_widget(paragraph, area);
|
||||
}
|
||||
|
||||
/// Render expanded view with task list
|
||||
fn render_expanded(&self, frame: &mut Frame, area: Rect, todos: &TodoList) {
|
||||
let items = todos.read();
|
||||
|
||||
let mut lines: Vec<Line> = Vec::new();
|
||||
|
||||
// Header
|
||||
lines.push(Line::from(vec![
|
||||
Span::styled("Tasks", Style::default().add_modifier(Modifier::BOLD)),
|
||||
Span::styled(" [t to collapse]", self.theme.status_dim),
|
||||
]));
|
||||
|
||||
if items.is_empty() {
|
||||
lines.push(Line::from(Span::styled(
|
||||
" No active tasks",
|
||||
self.theme.status_dim,
|
||||
)));
|
||||
} else {
|
||||
// Show tasks (limit to available space)
|
||||
let max_items = (area.height as usize).saturating_sub(2);
|
||||
let display_items: Vec<&Todo> = items.iter().take(max_items).collect();
|
||||
|
||||
for item in display_items {
|
||||
let (icon, style) = match item.status {
|
||||
TodoStatus::Completed => (
|
||||
self.theme.symbols.check,
|
||||
Style::default().fg(Color::Green),
|
||||
),
|
||||
TodoStatus::InProgress => (
|
||||
self.theme.symbols.streaming,
|
||||
Style::default().fg(Color::Yellow),
|
||||
),
|
||||
TodoStatus::Pending => (
|
||||
self.theme.symbols.bullet,
|
||||
self.theme.status_dim,
|
||||
),
|
||||
};
|
||||
|
||||
// Use active form for in-progress, content for others
|
||||
let text = if item.status == TodoStatus::InProgress {
|
||||
&item.active_form
|
||||
} else {
|
||||
&item.content
|
||||
};
|
||||
|
||||
// Truncate if too long
|
||||
let max_width = area.width.saturating_sub(6) as usize;
|
||||
let display_text = if text.len() > max_width {
|
||||
format!("{}...", &text[..max_width.saturating_sub(3)])
|
||||
} else {
|
||||
text.clone()
|
||||
};
|
||||
|
||||
lines.push(Line::from(vec![
|
||||
Span::styled(format!(" {} ", icon), style),
|
||||
Span::styled(display_text, style),
|
||||
]));
|
||||
}
|
||||
|
||||
// Show overflow indicator if needed
|
||||
if items.len() > max_items {
|
||||
lines.push(Line::from(Span::styled(
|
||||
format!(" ... and {} more", items.len() - max_items),
|
||||
self.theme.status_dim,
|
||||
)));
|
||||
}
|
||||
}
|
||||
|
||||
let block = Block::default()
|
||||
.borders(Borders::TOP)
|
||||
.border_style(self.theme.status_dim);
|
||||
|
||||
let paragraph = Paragraph::new(lines).block(block);
|
||||
frame.render_widget(paragraph, area);
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn test_todo_panel_creation() {
|
||||
let theme = Theme::default();
|
||||
let panel = TodoPanel::new(theme);
|
||||
assert!(!panel.is_collapsed());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_todo_panel_toggle() {
|
||||
let theme = Theme::default();
|
||||
let mut panel = TodoPanel::new(theme);
|
||||
|
||||
assert!(!panel.is_collapsed());
|
||||
panel.toggle();
|
||||
assert!(panel.is_collapsed());
|
||||
panel.toggle();
|
||||
assert!(!panel.is_collapsed());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_min_height() {
|
||||
let theme = Theme::default();
|
||||
let mut panel = TodoPanel::new(theme);
|
||||
|
||||
assert_eq!(panel.min_height(), 5);
|
||||
panel.toggle();
|
||||
assert_eq!(panel.min_height(), 1);
|
||||
}
|
||||
}
|
||||
53
crates/app/ui/src/events.rs
Normal file
53
crates/app/ui/src/events.rs
Normal file
@@ -0,0 +1,53 @@
|
||||
use crossterm::event::{KeyCode, KeyEvent, KeyModifiers};
|
||||
use serde_json::Value;
|
||||
|
||||
/// Application events that drive the TUI
|
||||
#[derive(Debug, Clone)]
|
||||
pub enum AppEvent {
|
||||
/// User input from keyboard
|
||||
Input(KeyEvent),
|
||||
/// User submitted a message
|
||||
UserMessage(String),
|
||||
/// LLM streaming started
|
||||
StreamStart,
|
||||
/// LLM response chunk (streaming)
|
||||
LlmChunk(String),
|
||||
/// LLM streaming completed
|
||||
StreamEnd { response: String },
|
||||
/// LLM streaming error
|
||||
StreamError(String),
|
||||
/// Tool call started
|
||||
ToolCall { name: String, args: Value },
|
||||
/// Tool execution result
|
||||
ToolResult { success: bool, output: String },
|
||||
/// Permission request from agent
|
||||
PermissionRequest {
|
||||
tool: String,
|
||||
context: Option<String>,
|
||||
},
|
||||
/// Session statistics updated
|
||||
StatusUpdate(agent_core::SessionStats),
|
||||
/// Terminal was resized
|
||||
Resize { width: u16, height: u16 },
|
||||
/// Mouse scroll up
|
||||
ScrollUp,
|
||||
/// Mouse scroll down
|
||||
ScrollDown,
|
||||
/// Toggle the todo panel
|
||||
ToggleTodo,
|
||||
/// Application should quit
|
||||
Quit,
|
||||
}
|
||||
|
||||
/// Process keyboard input into app events
|
||||
pub fn handle_key_event(key: KeyEvent) -> Option<AppEvent> {
|
||||
match key.code {
|
||||
KeyCode::Char('c') if key.modifiers.contains(KeyModifiers::CONTROL) => {
|
||||
Some(AppEvent::Quit)
|
||||
}
|
||||
KeyCode::Char('t') if key.modifiers.contains(KeyModifiers::CONTROL) => {
|
||||
Some(AppEvent::ToggleTodo)
|
||||
}
|
||||
_ => Some(AppEvent::Input(key)),
|
||||
}
|
||||
}
|
||||
532
crates/app/ui/src/formatting.rs
Normal file
532
crates/app/ui/src/formatting.rs
Normal file
@@ -0,0 +1,532 @@
|
||||
//! Output formatting with markdown parsing and syntax highlighting
|
||||
//!
|
||||
//! This module provides rich text rendering for the TUI, converting markdown
|
||||
//! content into styled ratatui spans with proper syntax highlighting for code blocks.
|
||||
|
||||
use pulldown_cmark::{CodeBlockKind, Event, Parser, Tag, TagEnd};
|
||||
use ratatui::style::{Color, Modifier, Style};
|
||||
use ratatui::text::{Line, Span};
|
||||
use syntect::easy::HighlightLines;
|
||||
use syntect::highlighting::{Theme, ThemeSet};
|
||||
use syntect::parsing::SyntaxSet;
|
||||
use syntect::util::LinesWithEndings;
|
||||
|
||||
/// Highlighter for syntax highlighting code blocks
|
||||
pub struct SyntaxHighlighter {
|
||||
syntax_set: SyntaxSet,
|
||||
theme: Theme,
|
||||
}
|
||||
|
||||
impl SyntaxHighlighter {
|
||||
/// Create a new syntax highlighter with default theme
|
||||
pub fn new() -> Self {
|
||||
let syntax_set = SyntaxSet::load_defaults_newlines();
|
||||
let theme_set = ThemeSet::load_defaults();
|
||||
// Use a dark theme that works well in terminals
|
||||
let theme = theme_set.themes["base16-ocean.dark"].clone();
|
||||
|
||||
Self { syntax_set, theme }
|
||||
}
|
||||
|
||||
/// Create highlighter with a specific theme name
|
||||
pub fn with_theme(theme_name: &str) -> Self {
|
||||
let syntax_set = SyntaxSet::load_defaults_newlines();
|
||||
let theme_set = ThemeSet::load_defaults();
|
||||
let theme = theme_set
|
||||
.themes
|
||||
.get(theme_name)
|
||||
.cloned()
|
||||
.unwrap_or_else(|| theme_set.themes["base16-ocean.dark"].clone());
|
||||
|
||||
Self { syntax_set, theme }
|
||||
}
|
||||
|
||||
/// Get available theme names
|
||||
pub fn available_themes() -> Vec<&'static str> {
|
||||
vec![
|
||||
"base16-ocean.dark",
|
||||
"base16-eighties.dark",
|
||||
"base16-mocha.dark",
|
||||
"base16-ocean.light",
|
||||
"InspiredGitHub",
|
||||
"Solarized (dark)",
|
||||
"Solarized (light)",
|
||||
]
|
||||
}
|
||||
|
||||
/// Highlight a code block and return styled lines
|
||||
pub fn highlight_code(&self, code: &str, language: &str) -> Vec<Line<'static>> {
|
||||
// Find syntax for the language
|
||||
let syntax = self
|
||||
.syntax_set
|
||||
.find_syntax_by_token(language)
|
||||
.or_else(|| self.syntax_set.find_syntax_by_extension(language))
|
||||
.unwrap_or_else(|| self.syntax_set.find_syntax_plain_text());
|
||||
|
||||
let mut highlighter = HighlightLines::new(syntax, &self.theme);
|
||||
let mut lines = Vec::new();
|
||||
|
||||
for line in LinesWithEndings::from(code) {
|
||||
let Ok(ranges) = highlighter.highlight_line(line, &self.syntax_set) else {
|
||||
// Fallback to plain text if highlighting fails
|
||||
lines.push(Line::from(Span::raw(line.trim_end().to_string())));
|
||||
continue;
|
||||
};
|
||||
|
||||
let spans: Vec<Span<'static>> = ranges
|
||||
.into_iter()
|
||||
.map(|(style, text)| {
|
||||
let fg = syntect_to_ratatui_color(style.foreground);
|
||||
let ratatui_style = Style::default().fg(fg);
|
||||
Span::styled(text.trim_end_matches('\n').to_string(), ratatui_style)
|
||||
})
|
||||
.collect();
|
||||
|
||||
lines.push(Line::from(spans));
|
||||
}
|
||||
|
||||
lines
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for SyntaxHighlighter {
|
||||
fn default() -> Self {
|
||||
Self::new()
|
||||
}
|
||||
}
|
||||
|
||||
/// Convert syntect color to ratatui color
|
||||
fn syntect_to_ratatui_color(color: syntect::highlighting::Color) -> Color {
|
||||
Color::Rgb(color.r, color.g, color.b)
|
||||
}
|
||||
|
||||
/// Parsed markdown content ready for rendering
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct FormattedContent {
|
||||
pub lines: Vec<Line<'static>>,
|
||||
}
|
||||
|
||||
impl FormattedContent {
|
||||
/// Create empty formatted content
|
||||
pub fn empty() -> Self {
|
||||
Self { lines: Vec::new() }
|
||||
}
|
||||
|
||||
/// Get the number of lines
|
||||
pub fn len(&self) -> usize {
|
||||
self.lines.len()
|
||||
}
|
||||
|
||||
/// Check if content is empty
|
||||
pub fn is_empty(&self) -> bool {
|
||||
self.lines.is_empty()
|
||||
}
|
||||
}
|
||||
|
||||
/// Markdown parser that converts markdown to styled ratatui lines
|
||||
pub struct MarkdownRenderer {
|
||||
highlighter: SyntaxHighlighter,
|
||||
}
|
||||
|
||||
impl MarkdownRenderer {
|
||||
/// Create a new markdown renderer
|
||||
pub fn new() -> Self {
|
||||
Self {
|
||||
highlighter: SyntaxHighlighter::new(),
|
||||
}
|
||||
}
|
||||
|
||||
/// Create renderer with custom highlighter
|
||||
pub fn with_highlighter(highlighter: SyntaxHighlighter) -> Self {
|
||||
Self { highlighter }
|
||||
}
|
||||
|
||||
/// Render markdown text to formatted content
|
||||
pub fn render(&self, markdown: &str) -> FormattedContent {
|
||||
let parser = Parser::new(markdown);
|
||||
let mut lines: Vec<Line<'static>> = Vec::new();
|
||||
let mut current_line_spans: Vec<Span<'static>> = Vec::new();
|
||||
|
||||
// State tracking
|
||||
let mut in_code_block = false;
|
||||
let mut code_block_lang = String::new();
|
||||
let mut code_block_content = String::new();
|
||||
let mut current_style = Style::default();
|
||||
let mut list_depth: usize = 0;
|
||||
let mut ordered_list_index: Option<u64> = None;
|
||||
|
||||
for event in parser {
|
||||
match event {
|
||||
Event::Start(tag) => match tag {
|
||||
Tag::Heading { level, .. } => {
|
||||
// Flush current line
|
||||
if !current_line_spans.is_empty() {
|
||||
lines.push(Line::from(std::mem::take(&mut current_line_spans)));
|
||||
}
|
||||
// Style for headings
|
||||
current_style = match level {
|
||||
pulldown_cmark::HeadingLevel::H1 => Style::default()
|
||||
.fg(Color::Cyan)
|
||||
.add_modifier(Modifier::BOLD),
|
||||
pulldown_cmark::HeadingLevel::H2 => Style::default()
|
||||
.fg(Color::Blue)
|
||||
.add_modifier(Modifier::BOLD),
|
||||
pulldown_cmark::HeadingLevel::H3 => Style::default()
|
||||
.fg(Color::Green)
|
||||
.add_modifier(Modifier::BOLD),
|
||||
_ => Style::default().add_modifier(Modifier::BOLD),
|
||||
};
|
||||
// Add heading prefix
|
||||
let prefix = "#".repeat(level as usize);
|
||||
current_line_spans.push(Span::styled(
|
||||
format!("{} ", prefix),
|
||||
Style::default().fg(Color::DarkGray),
|
||||
));
|
||||
}
|
||||
Tag::Paragraph => {
|
||||
// Start a new paragraph
|
||||
if !current_line_spans.is_empty() {
|
||||
lines.push(Line::from(std::mem::take(&mut current_line_spans)));
|
||||
}
|
||||
}
|
||||
Tag::CodeBlock(kind) => {
|
||||
in_code_block = true;
|
||||
code_block_content.clear();
|
||||
code_block_lang = match kind {
|
||||
CodeBlockKind::Fenced(lang) => lang.to_string(),
|
||||
CodeBlockKind::Indented => String::new(),
|
||||
};
|
||||
// Flush current line and add code block header
|
||||
if !current_line_spans.is_empty() {
|
||||
lines.push(Line::from(std::mem::take(&mut current_line_spans)));
|
||||
}
|
||||
// Add code fence line
|
||||
let fence_line = if code_block_lang.is_empty() {
|
||||
"```".to_string()
|
||||
} else {
|
||||
format!("```{}", code_block_lang)
|
||||
};
|
||||
lines.push(Line::from(Span::styled(
|
||||
fence_line,
|
||||
Style::default().fg(Color::DarkGray),
|
||||
)));
|
||||
}
|
||||
Tag::List(start) => {
|
||||
list_depth += 1;
|
||||
ordered_list_index = start;
|
||||
}
|
||||
Tag::Item => {
|
||||
// Flush current line
|
||||
if !current_line_spans.is_empty() {
|
||||
lines.push(Line::from(std::mem::take(&mut current_line_spans)));
|
||||
}
|
||||
// Add list marker
|
||||
let indent = " ".repeat(list_depth.saturating_sub(1));
|
||||
let marker = if let Some(idx) = ordered_list_index {
|
||||
ordered_list_index = Some(idx + 1);
|
||||
format!("{}{}. ", indent, idx)
|
||||
} else {
|
||||
format!("{}- ", indent)
|
||||
};
|
||||
current_line_spans.push(Span::styled(
|
||||
marker,
|
||||
Style::default().fg(Color::Yellow),
|
||||
));
|
||||
}
|
||||
Tag::Emphasis => {
|
||||
current_style = current_style.add_modifier(Modifier::ITALIC);
|
||||
}
|
||||
Tag::Strong => {
|
||||
current_style = current_style.add_modifier(Modifier::BOLD);
|
||||
}
|
||||
Tag::Strikethrough => {
|
||||
current_style = current_style.add_modifier(Modifier::CROSSED_OUT);
|
||||
}
|
||||
Tag::Link { dest_url, .. } => {
|
||||
current_style = Style::default()
|
||||
.fg(Color::Blue)
|
||||
.add_modifier(Modifier::UNDERLINED);
|
||||
// Store URL for later
|
||||
current_line_spans.push(Span::styled(
|
||||
"[",
|
||||
Style::default().fg(Color::DarkGray),
|
||||
));
|
||||
// URL will be shown after link text
|
||||
code_block_content = dest_url.to_string();
|
||||
}
|
||||
Tag::BlockQuote(_) => {
|
||||
if !current_line_spans.is_empty() {
|
||||
lines.push(Line::from(std::mem::take(&mut current_line_spans)));
|
||||
}
|
||||
current_line_spans.push(Span::styled(
|
||||
"│ ",
|
||||
Style::default().fg(Color::DarkGray),
|
||||
));
|
||||
current_style = Style::default().fg(Color::Gray).add_modifier(Modifier::ITALIC);
|
||||
}
|
||||
_ => {}
|
||||
},
|
||||
Event::End(tag_end) => match tag_end {
|
||||
TagEnd::Heading(_) => {
|
||||
current_style = Style::default();
|
||||
lines.push(Line::from(std::mem::take(&mut current_line_spans)));
|
||||
}
|
||||
TagEnd::Paragraph => {
|
||||
lines.push(Line::from(std::mem::take(&mut current_line_spans)));
|
||||
lines.push(Line::from("")); // Empty line after paragraph
|
||||
}
|
||||
TagEnd::CodeBlock => {
|
||||
in_code_block = false;
|
||||
// Highlight and add code content
|
||||
let highlighted =
|
||||
self.highlighter.highlight_code(&code_block_content, &code_block_lang);
|
||||
lines.extend(highlighted);
|
||||
// Add closing fence
|
||||
lines.push(Line::from(Span::styled(
|
||||
"```",
|
||||
Style::default().fg(Color::DarkGray),
|
||||
)));
|
||||
code_block_content.clear();
|
||||
code_block_lang.clear();
|
||||
}
|
||||
TagEnd::List(_) => {
|
||||
list_depth = list_depth.saturating_sub(1);
|
||||
if list_depth == 0 {
|
||||
ordered_list_index = None;
|
||||
}
|
||||
}
|
||||
TagEnd::Item => {
|
||||
if !current_line_spans.is_empty() {
|
||||
lines.push(Line::from(std::mem::take(&mut current_line_spans)));
|
||||
}
|
||||
}
|
||||
TagEnd::Emphasis | TagEnd::Strong | TagEnd::Strikethrough => {
|
||||
current_style = Style::default();
|
||||
}
|
||||
TagEnd::Link => {
|
||||
current_line_spans.push(Span::styled(
|
||||
"]",
|
||||
Style::default().fg(Color::DarkGray),
|
||||
));
|
||||
current_line_spans.push(Span::styled(
|
||||
format!("({})", code_block_content),
|
||||
Style::default().fg(Color::DarkGray),
|
||||
));
|
||||
code_block_content.clear();
|
||||
current_style = Style::default();
|
||||
}
|
||||
TagEnd::BlockQuote => {
|
||||
current_style = Style::default();
|
||||
if !current_line_spans.is_empty() {
|
||||
lines.push(Line::from(std::mem::take(&mut current_line_spans)));
|
||||
}
|
||||
}
|
||||
_ => {}
|
||||
},
|
||||
Event::Text(text) => {
|
||||
if in_code_block {
|
||||
code_block_content.push_str(&text);
|
||||
} else {
|
||||
current_line_spans.push(Span::styled(text.to_string(), current_style));
|
||||
}
|
||||
}
|
||||
Event::Code(code) => {
|
||||
// Inline code
|
||||
current_line_spans.push(Span::styled(
|
||||
format!("`{}`", code),
|
||||
Style::default().fg(Color::Magenta),
|
||||
));
|
||||
}
|
||||
Event::SoftBreak => {
|
||||
current_line_spans.push(Span::raw(" "));
|
||||
}
|
||||
Event::HardBreak => {
|
||||
lines.push(Line::from(std::mem::take(&mut current_line_spans)));
|
||||
}
|
||||
Event::Rule => {
|
||||
if !current_line_spans.is_empty() {
|
||||
lines.push(Line::from(std::mem::take(&mut current_line_spans)));
|
||||
}
|
||||
lines.push(Line::from(Span::styled(
|
||||
"─".repeat(40),
|
||||
Style::default().fg(Color::DarkGray),
|
||||
)));
|
||||
}
|
||||
_ => {}
|
||||
}
|
||||
}
|
||||
|
||||
// Flush any remaining content
|
||||
if !current_line_spans.is_empty() {
|
||||
lines.push(Line::from(current_line_spans));
|
||||
}
|
||||
|
||||
FormattedContent { lines }
|
||||
}
|
||||
|
||||
/// Render plain text (no markdown parsing)
|
||||
pub fn render_plain(&self, text: &str) -> FormattedContent {
|
||||
let lines = text
|
||||
.lines()
|
||||
.map(|line| Line::from(Span::raw(line.to_string())))
|
||||
.collect();
|
||||
FormattedContent { lines }
|
||||
}
|
||||
|
||||
/// Render a diff with +/- highlighting
|
||||
pub fn render_diff(&self, diff: &str) -> FormattedContent {
|
||||
let lines = diff
|
||||
.lines()
|
||||
.map(|line| {
|
||||
let style = if line.starts_with('+') && !line.starts_with("+++") {
|
||||
Style::default().fg(Color::Green)
|
||||
} else if line.starts_with('-') && !line.starts_with("---") {
|
||||
Style::default().fg(Color::Red)
|
||||
} else if line.starts_with("@@") {
|
||||
Style::default().fg(Color::Cyan)
|
||||
} else if line.starts_with("diff ") || line.starts_with("index ") {
|
||||
Style::default().fg(Color::Yellow)
|
||||
} else {
|
||||
Style::default()
|
||||
};
|
||||
Line::from(Span::styled(line.to_string(), style))
|
||||
})
|
||||
.collect();
|
||||
FormattedContent { lines }
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for MarkdownRenderer {
|
||||
fn default() -> Self {
|
||||
Self::new()
|
||||
}
|
||||
}
|
||||
|
||||
/// Format a file path with syntax highlighting based on extension
|
||||
pub fn format_file_path(path: &str) -> Span<'static> {
|
||||
let color = if path.ends_with(".rs") {
|
||||
Color::Rgb(222, 165, 132) // Rust orange
|
||||
} else if path.ends_with(".toml") {
|
||||
Color::Rgb(156, 220, 254) // Light blue
|
||||
} else if path.ends_with(".md") {
|
||||
Color::Rgb(86, 156, 214) // Blue
|
||||
} else if path.ends_with(".json") {
|
||||
Color::Rgb(206, 145, 120) // Brown
|
||||
} else if path.ends_with(".ts") || path.ends_with(".tsx") {
|
||||
Color::Rgb(49, 120, 198) // TypeScript blue
|
||||
} else if path.ends_with(".js") || path.ends_with(".jsx") {
|
||||
Color::Rgb(241, 224, 90) // JavaScript yellow
|
||||
} else if path.ends_with(".py") {
|
||||
Color::Rgb(55, 118, 171) // Python blue
|
||||
} else if path.ends_with(".go") {
|
||||
Color::Rgb(0, 173, 216) // Go cyan
|
||||
} else if path.ends_with(".sh") || path.ends_with(".bash") {
|
||||
Color::Rgb(137, 224, 81) // Shell green
|
||||
} else {
|
||||
Color::White
|
||||
};
|
||||
|
||||
Span::styled(path.to_string(), Style::default().fg(color))
|
||||
}
|
||||
|
||||
/// Format a tool name with appropriate styling
|
||||
pub fn format_tool_name(name: &str) -> Span<'static> {
|
||||
let style = Style::default()
|
||||
.fg(Color::Yellow)
|
||||
.add_modifier(Modifier::BOLD);
|
||||
Span::styled(name.to_string(), style)
|
||||
}
|
||||
|
||||
/// Format an error message
|
||||
pub fn format_error(message: &str) -> Line<'static> {
|
||||
Line::from(vec![
|
||||
Span::styled("Error: ", Style::default().fg(Color::Red).add_modifier(Modifier::BOLD)),
|
||||
Span::styled(message.to_string(), Style::default().fg(Color::Red)),
|
||||
])
|
||||
}
|
||||
|
||||
/// Format a success message
|
||||
pub fn format_success(message: &str) -> Line<'static> {
|
||||
Line::from(vec![
|
||||
Span::styled("✓ ", Style::default().fg(Color::Green)),
|
||||
Span::styled(message.to_string(), Style::default().fg(Color::Green)),
|
||||
])
|
||||
}
|
||||
|
||||
/// Format a warning message
|
||||
pub fn format_warning(message: &str) -> Line<'static> {
|
||||
Line::from(vec![
|
||||
Span::styled("⚠ ", Style::default().fg(Color::Yellow)),
|
||||
Span::styled(message.to_string(), Style::default().fg(Color::Yellow)),
|
||||
])
|
||||
}
|
||||
|
||||
/// Format an info message
|
||||
pub fn format_info(message: &str) -> Line<'static> {
|
||||
Line::from(vec![
|
||||
Span::styled("ℹ ", Style::default().fg(Color::Blue)),
|
||||
Span::styled(message.to_string(), Style::default().fg(Color::Blue)),
|
||||
])
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn test_syntax_highlighter_creation() {
|
||||
let highlighter = SyntaxHighlighter::new();
|
||||
let lines = highlighter.highlight_code("fn main() {}", "rust");
|
||||
assert!(!lines.is_empty());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_markdown_render_heading() {
|
||||
let renderer = MarkdownRenderer::new();
|
||||
let content = renderer.render("# Hello World");
|
||||
assert!(!content.is_empty());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_markdown_render_code_block() {
|
||||
let renderer = MarkdownRenderer::new();
|
||||
let content = renderer.render("```rust\nfn main() {}\n```");
|
||||
assert!(content.len() >= 3); // Opening fence, code, closing fence
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_markdown_render_list() {
|
||||
let renderer = MarkdownRenderer::new();
|
||||
let content = renderer.render("- Item 1\n- Item 2\n- Item 3");
|
||||
assert!(content.len() >= 3);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_diff_rendering() {
|
||||
let renderer = MarkdownRenderer::new();
|
||||
let diff = "+added line\n-removed line\n unchanged";
|
||||
let content = renderer.render_diff(diff);
|
||||
assert_eq!(content.len(), 3);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_format_file_path() {
|
||||
let span = format_file_path("src/main.rs");
|
||||
assert!(span.content.contains("main.rs"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_format_messages() {
|
||||
let error = format_error("Something went wrong");
|
||||
assert!(!error.spans.is_empty());
|
||||
|
||||
let success = format_success("Operation completed");
|
||||
assert!(!success.spans.is_empty());
|
||||
|
||||
let warning = format_warning("Be careful");
|
||||
assert!(!warning.spans.is_empty());
|
||||
|
||||
let info = format_info("FYI");
|
||||
assert!(!info.spans.is_empty());
|
||||
}
|
||||
}
|
||||
218
crates/app/ui/src/layout.rs
Normal file
218
crates/app/ui/src/layout.rs
Normal file
@@ -0,0 +1,218 @@
|
||||
//! Layout calculation for the borderless TUI
|
||||
//!
|
||||
//! Uses vertical layout with whitespace for visual hierarchy instead of borders:
|
||||
//! - Header row (app name, mode, model, help)
|
||||
//! - Provider tabs
|
||||
//! - Horizontal divider
|
||||
//! - Chat area (scrollable)
|
||||
//! - Horizontal divider
|
||||
//! - Input area
|
||||
//! - Status bar
|
||||
|
||||
use ratatui::layout::{Constraint, Direction, Layout, Rect};
|
||||
|
||||
/// Calculated layout areas for the borderless TUI
|
||||
#[derive(Debug, Clone, Copy)]
|
||||
pub struct AppLayout {
|
||||
/// Header row: app name, mode indicator, model, help hint
|
||||
pub header_area: Rect,
|
||||
/// Provider tabs row
|
||||
pub tabs_area: Rect,
|
||||
/// Top divider (horizontal rule)
|
||||
pub top_divider: Rect,
|
||||
/// Main chat/message area
|
||||
pub chat_area: Rect,
|
||||
/// Todo panel area (optional, between chat and input)
|
||||
pub todo_area: Rect,
|
||||
/// Bottom divider (horizontal rule)
|
||||
pub bottom_divider: Rect,
|
||||
/// Input area for user text
|
||||
pub input_area: Rect,
|
||||
/// Status bar at the bottom
|
||||
pub status_area: Rect,
|
||||
}
|
||||
|
||||
impl AppLayout {
|
||||
/// Calculate layout for the given terminal size
|
||||
pub fn calculate(area: Rect) -> Self {
|
||||
Self::calculate_with_todo(area, 0)
|
||||
}
|
||||
|
||||
/// Calculate layout with todo panel of specified height
|
||||
///
|
||||
/// Simplified layout without provider tabs:
|
||||
/// - Header (1 line)
|
||||
/// - Top divider (1 line)
|
||||
/// - Chat area (flexible)
|
||||
/// - Todo panel (optional)
|
||||
/// - Bottom divider (1 line)
|
||||
/// - Input (1 line)
|
||||
/// - Status bar (1 line)
|
||||
pub fn calculate_with_todo(area: Rect, todo_height: u16) -> Self {
|
||||
let chunks = if todo_height > 0 {
|
||||
Layout::default()
|
||||
.direction(Direction::Vertical)
|
||||
.constraints([
|
||||
Constraint::Length(1), // Header
|
||||
Constraint::Length(1), // Top divider
|
||||
Constraint::Min(5), // Chat area (flexible)
|
||||
Constraint::Length(todo_height), // Todo panel
|
||||
Constraint::Length(1), // Bottom divider
|
||||
Constraint::Length(1), // Input
|
||||
Constraint::Length(1), // Status bar
|
||||
])
|
||||
.split(area)
|
||||
} else {
|
||||
Layout::default()
|
||||
.direction(Direction::Vertical)
|
||||
.constraints([
|
||||
Constraint::Length(1), // Header
|
||||
Constraint::Length(1), // Top divider
|
||||
Constraint::Min(5), // Chat area (flexible)
|
||||
Constraint::Length(0), // No todo panel
|
||||
Constraint::Length(1), // Bottom divider
|
||||
Constraint::Length(1), // Input
|
||||
Constraint::Length(1), // Status bar
|
||||
])
|
||||
.split(area)
|
||||
};
|
||||
|
||||
Self {
|
||||
header_area: chunks[0],
|
||||
tabs_area: Rect::default(), // Not used in simplified layout
|
||||
top_divider: chunks[1],
|
||||
chat_area: chunks[2],
|
||||
todo_area: chunks[3],
|
||||
bottom_divider: chunks[4],
|
||||
input_area: chunks[5],
|
||||
status_area: chunks[6],
|
||||
}
|
||||
}
|
||||
|
||||
/// Calculate layout with expanded input (multiline)
|
||||
pub fn calculate_expanded_input(area: Rect, input_lines: u16) -> Self {
|
||||
let input_height = input_lines.min(10).max(1); // Cap at 10 lines
|
||||
|
||||
let chunks = Layout::default()
|
||||
.direction(Direction::Vertical)
|
||||
.constraints([
|
||||
Constraint::Length(1), // Header
|
||||
Constraint::Length(1), // Top divider
|
||||
Constraint::Min(5), // Chat area (flexible)
|
||||
Constraint::Length(0), // No todo panel
|
||||
Constraint::Length(1), // Bottom divider
|
||||
Constraint::Length(input_height), // Expanded input
|
||||
Constraint::Length(1), // Status bar
|
||||
])
|
||||
.split(area);
|
||||
|
||||
Self {
|
||||
header_area: chunks[0],
|
||||
tabs_area: Rect::default(),
|
||||
top_divider: chunks[1],
|
||||
chat_area: chunks[2],
|
||||
todo_area: chunks[3],
|
||||
bottom_divider: chunks[4],
|
||||
input_area: chunks[5],
|
||||
status_area: chunks[6],
|
||||
}
|
||||
}
|
||||
|
||||
/// Calculate layout without tabs (compact mode)
|
||||
pub fn calculate_compact(area: Rect) -> Self {
|
||||
let chunks = Layout::default()
|
||||
.direction(Direction::Vertical)
|
||||
.constraints([
|
||||
Constraint::Length(1), // Header (includes compact provider indicator)
|
||||
Constraint::Length(1), // Top divider
|
||||
Constraint::Min(5), // Chat area (flexible)
|
||||
Constraint::Length(0), // No todo panel
|
||||
Constraint::Length(1), // Bottom divider
|
||||
Constraint::Length(1), // Input
|
||||
Constraint::Length(1), // Status bar
|
||||
])
|
||||
.split(area);
|
||||
|
||||
Self {
|
||||
header_area: chunks[0],
|
||||
tabs_area: Rect::default(), // No tabs area in compact mode
|
||||
top_divider: chunks[1],
|
||||
chat_area: chunks[2],
|
||||
todo_area: chunks[3],
|
||||
bottom_divider: chunks[4],
|
||||
input_area: chunks[5],
|
||||
status_area: chunks[6],
|
||||
}
|
||||
}
|
||||
|
||||
/// Center a popup in the given area
|
||||
pub fn center_popup(area: Rect, width: u16, height: u16) -> Rect {
|
||||
let popup_layout = Layout::default()
|
||||
.direction(Direction::Vertical)
|
||||
.constraints([
|
||||
Constraint::Length((area.height.saturating_sub(height)) / 2),
|
||||
Constraint::Length(height),
|
||||
Constraint::Length((area.height.saturating_sub(height)) / 2),
|
||||
])
|
||||
.split(area);
|
||||
|
||||
Layout::default()
|
||||
.direction(Direction::Horizontal)
|
||||
.constraints([
|
||||
Constraint::Length((area.width.saturating_sub(width)) / 2),
|
||||
Constraint::Length(width),
|
||||
Constraint::Length((area.width.saturating_sub(width)) / 2),
|
||||
])
|
||||
.split(popup_layout[1])[1]
|
||||
}
|
||||
}
|
||||
|
||||
/// Layout mode based on terminal width
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
|
||||
pub enum LayoutMode {
|
||||
/// Full layout with provider tabs (>= 80 cols)
|
||||
Full,
|
||||
/// Compact layout without tabs (< 80 cols)
|
||||
Compact,
|
||||
}
|
||||
|
||||
impl LayoutMode {
|
||||
/// Determine layout mode based on terminal width
|
||||
pub fn for_width(width: u16) -> Self {
|
||||
if width >= 80 {
|
||||
LayoutMode::Full
|
||||
} else {
|
||||
LayoutMode::Compact
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn test_layout_calculation() {
|
||||
let area = Rect::new(0, 0, 120, 40);
|
||||
let layout = AppLayout::calculate(area);
|
||||
|
||||
// Header should be at top
|
||||
assert_eq!(layout.header_area.y, 0);
|
||||
assert_eq!(layout.header_area.height, 1);
|
||||
|
||||
// Status should be at bottom
|
||||
assert_eq!(layout.status_area.y, 39);
|
||||
assert_eq!(layout.status_area.height, 1);
|
||||
|
||||
// Chat area should have most of the space
|
||||
assert!(layout.chat_area.height > 20);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_layout_mode() {
|
||||
assert_eq!(LayoutMode::for_width(80), LayoutMode::Full);
|
||||
assert_eq!(LayoutMode::for_width(120), LayoutMode::Full);
|
||||
assert_eq!(LayoutMode::for_width(79), LayoutMode::Compact);
|
||||
assert_eq!(LayoutMode::for_width(60), LayoutMode::Compact);
|
||||
}
|
||||
}
|
||||
30
crates/app/ui/src/lib.rs
Normal file
30
crates/app/ui/src/lib.rs
Normal file
@@ -0,0 +1,30 @@
|
||||
pub mod app;
|
||||
pub mod completions;
|
||||
pub mod components;
|
||||
pub mod events;
|
||||
pub mod formatting;
|
||||
pub mod layout;
|
||||
pub mod output;
|
||||
pub mod theme;
|
||||
|
||||
pub use app::TuiApp;
|
||||
pub use completions::{CompletionEngine, Completion, CommandInfo};
|
||||
pub use events::AppEvent;
|
||||
pub use output::{CommandOutput, OutputFormat, TreeNode, ListItem};
|
||||
pub use formatting::{
|
||||
FormattedContent, MarkdownRenderer, SyntaxHighlighter,
|
||||
format_file_path, format_tool_name, format_error, format_success, format_warning, format_info,
|
||||
};
|
||||
|
||||
use color_eyre::eyre::Result;
|
||||
|
||||
/// Run the TUI application
|
||||
pub async fn run(
|
||||
client: llm_ollama::OllamaClient,
|
||||
opts: llm_core::ChatOptions,
|
||||
perms: permissions::PermissionManager,
|
||||
settings: config_agent::Settings,
|
||||
) -> Result<()> {
|
||||
let mut app = TuiApp::new(client, opts, perms, settings)?;
|
||||
app.run().await
|
||||
}
|
||||
388
crates/app/ui/src/output.rs
Normal file
388
crates/app/ui/src/output.rs
Normal file
@@ -0,0 +1,388 @@
|
||||
//! Rich command output formatting
|
||||
//!
|
||||
//! Provides formatted output for commands like /help, /mcp, /hooks
|
||||
//! with tables, trees, and syntax highlighting.
|
||||
|
||||
use ratatui::text::{Line, Span};
|
||||
use ratatui::style::{Color, Modifier, Style};
|
||||
|
||||
use crate::completions::CommandInfo;
|
||||
use crate::theme::Theme;
|
||||
|
||||
/// A tree node for hierarchical display
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct TreeNode {
|
||||
pub label: String,
|
||||
pub children: Vec<TreeNode>,
|
||||
}
|
||||
|
||||
impl TreeNode {
|
||||
pub fn new(label: impl Into<String>) -> Self {
|
||||
Self {
|
||||
label: label.into(),
|
||||
children: vec![],
|
||||
}
|
||||
}
|
||||
|
||||
pub fn with_children(mut self, children: Vec<TreeNode>) -> Self {
|
||||
self.children = children;
|
||||
self
|
||||
}
|
||||
}
|
||||
|
||||
/// A list item with optional icon/marker
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct ListItem {
|
||||
pub text: String,
|
||||
pub marker: Option<String>,
|
||||
pub style: Option<Style>,
|
||||
}
|
||||
|
||||
/// Different output formats
|
||||
#[derive(Debug, Clone)]
|
||||
pub enum OutputFormat {
|
||||
/// Formatted table with headers and rows
|
||||
Table {
|
||||
headers: Vec<String>,
|
||||
rows: Vec<Vec<String>>,
|
||||
},
|
||||
/// Hierarchical tree view
|
||||
Tree {
|
||||
root: TreeNode,
|
||||
},
|
||||
/// Syntax-highlighted code block
|
||||
Code {
|
||||
language: String,
|
||||
content: String,
|
||||
},
|
||||
/// Side-by-side diff view
|
||||
Diff {
|
||||
old: String,
|
||||
new: String,
|
||||
},
|
||||
/// Simple list with markers
|
||||
List {
|
||||
items: Vec<ListItem>,
|
||||
},
|
||||
/// Plain text
|
||||
Text {
|
||||
content: String,
|
||||
},
|
||||
}
|
||||
|
||||
/// Rich command output renderer
|
||||
pub struct CommandOutput {
|
||||
pub format: OutputFormat,
|
||||
}
|
||||
|
||||
impl CommandOutput {
|
||||
pub fn new(format: OutputFormat) -> Self {
|
||||
Self { format }
|
||||
}
|
||||
|
||||
/// Create a help table output
|
||||
pub fn help_table(commands: &[CommandInfo]) -> Self {
|
||||
let headers = vec![
|
||||
"Command".to_string(),
|
||||
"Description".to_string(),
|
||||
"Source".to_string(),
|
||||
];
|
||||
|
||||
let rows: Vec<Vec<String>> = commands
|
||||
.iter()
|
||||
.map(|c| vec![
|
||||
format!("/{}", c.name),
|
||||
c.description.clone(),
|
||||
c.source.clone(),
|
||||
])
|
||||
.collect();
|
||||
|
||||
Self {
|
||||
format: OutputFormat::Table { headers, rows },
|
||||
}
|
||||
}
|
||||
|
||||
/// Create an MCP servers tree view
|
||||
pub fn mcp_tree(servers: &[(String, Vec<String>)]) -> Self {
|
||||
let children: Vec<TreeNode> = servers
|
||||
.iter()
|
||||
.map(|(name, tools)| {
|
||||
TreeNode {
|
||||
label: name.clone(),
|
||||
children: tools.iter().map(|t| TreeNode::new(t)).collect(),
|
||||
}
|
||||
})
|
||||
.collect();
|
||||
|
||||
Self {
|
||||
format: OutputFormat::Tree {
|
||||
root: TreeNode {
|
||||
label: "MCP Servers".to_string(),
|
||||
children,
|
||||
},
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
/// Create a hooks list output
|
||||
pub fn hooks_list(hooks: &[(String, String, bool)]) -> Self {
|
||||
let items: Vec<ListItem> = hooks
|
||||
.iter()
|
||||
.map(|(event, path, enabled)| {
|
||||
let marker = if *enabled { "✓" } else { "✗" };
|
||||
let style = if *enabled {
|
||||
Some(Style::default().fg(Color::Green))
|
||||
} else {
|
||||
Some(Style::default().fg(Color::Red))
|
||||
};
|
||||
ListItem {
|
||||
text: format!("{}: {}", event, path),
|
||||
marker: Some(marker.to_string()),
|
||||
style,
|
||||
}
|
||||
})
|
||||
.collect();
|
||||
|
||||
Self {
|
||||
format: OutputFormat::List { items },
|
||||
}
|
||||
}
|
||||
|
||||
/// Render to TUI Lines
|
||||
pub fn render(&self, theme: &Theme) -> Vec<Line<'static>> {
|
||||
match &self.format {
|
||||
OutputFormat::Table { headers, rows } => {
|
||||
self.render_table(headers, rows, theme)
|
||||
}
|
||||
OutputFormat::Tree { root } => {
|
||||
self.render_tree(root, 0, theme)
|
||||
}
|
||||
OutputFormat::List { items } => {
|
||||
self.render_list(items, theme)
|
||||
}
|
||||
OutputFormat::Code { content, .. } => {
|
||||
content.lines()
|
||||
.map(|line| Line::from(Span::styled(line.to_string(), theme.tool_call)))
|
||||
.collect()
|
||||
}
|
||||
OutputFormat::Diff { old, new } => {
|
||||
self.render_diff(old, new, theme)
|
||||
}
|
||||
OutputFormat::Text { content } => {
|
||||
content.lines()
|
||||
.map(|line| Line::from(line.to_string()))
|
||||
.collect()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn render_table(&self, headers: &[String], rows: &[Vec<String>], theme: &Theme) -> Vec<Line<'static>> {
|
||||
let mut lines = Vec::new();
|
||||
|
||||
// Calculate column widths
|
||||
let mut widths: Vec<usize> = headers.iter().map(|h| h.len()).collect();
|
||||
for row in rows {
|
||||
for (i, cell) in row.iter().enumerate() {
|
||||
if i < widths.len() {
|
||||
widths[i] = widths[i].max(cell.len());
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Header line
|
||||
let header_spans: Vec<Span> = headers
|
||||
.iter()
|
||||
.enumerate()
|
||||
.flat_map(|(i, h)| {
|
||||
let padded = format!("{:width$}", h, width = widths.get(i).copied().unwrap_or(h.len()));
|
||||
vec![
|
||||
Span::styled(padded, Style::default().add_modifier(Modifier::BOLD)),
|
||||
Span::raw(" "),
|
||||
]
|
||||
})
|
||||
.collect();
|
||||
lines.push(Line::from(header_spans));
|
||||
|
||||
// Separator
|
||||
let sep: String = widths.iter().map(|w| "─".repeat(*w)).collect::<Vec<_>>().join("──");
|
||||
lines.push(Line::from(Span::styled(sep, theme.status_dim)));
|
||||
|
||||
// Rows
|
||||
for row in rows {
|
||||
let row_spans: Vec<Span> = row
|
||||
.iter()
|
||||
.enumerate()
|
||||
.flat_map(|(i, cell)| {
|
||||
let padded = format!("{:width$}", cell, width = widths.get(i).copied().unwrap_or(cell.len()));
|
||||
let style = if i == 0 {
|
||||
theme.status_accent // Command names in accent color
|
||||
} else {
|
||||
theme.status_bar
|
||||
};
|
||||
vec![
|
||||
Span::styled(padded, style),
|
||||
Span::raw(" "),
|
||||
]
|
||||
})
|
||||
.collect();
|
||||
lines.push(Line::from(row_spans));
|
||||
}
|
||||
|
||||
lines
|
||||
}
|
||||
|
||||
fn render_tree(&self, node: &TreeNode, depth: usize, theme: &Theme) -> Vec<Line<'static>> {
|
||||
let mut lines = Vec::new();
|
||||
|
||||
// Render current node
|
||||
let prefix = if depth == 0 {
|
||||
"".to_string()
|
||||
} else {
|
||||
format!("{}├─ ", "│ ".repeat(depth - 1))
|
||||
};
|
||||
|
||||
let style = if depth == 0 {
|
||||
Style::default().add_modifier(Modifier::BOLD)
|
||||
} else if node.children.is_empty() {
|
||||
theme.status_bar
|
||||
} else {
|
||||
theme.status_accent
|
||||
};
|
||||
|
||||
lines.push(Line::from(vec![
|
||||
Span::styled(prefix, theme.status_dim),
|
||||
Span::styled(node.label.clone(), style),
|
||||
]));
|
||||
|
||||
// Render children
|
||||
for child in &node.children {
|
||||
lines.extend(self.render_tree(child, depth + 1, theme));
|
||||
}
|
||||
|
||||
lines
|
||||
}
|
||||
|
||||
fn render_list(&self, items: &[ListItem], theme: &Theme) -> Vec<Line<'static>> {
|
||||
items
|
||||
.iter()
|
||||
.map(|item| {
|
||||
let marker_span = if let Some(marker) = &item.marker {
|
||||
Span::styled(
|
||||
format!("{} ", marker),
|
||||
item.style.unwrap_or(theme.status_bar),
|
||||
)
|
||||
} else {
|
||||
Span::raw("• ")
|
||||
};
|
||||
|
||||
Line::from(vec![
|
||||
marker_span,
|
||||
Span::styled(
|
||||
item.text.clone(),
|
||||
item.style.unwrap_or(theme.status_bar),
|
||||
),
|
||||
])
|
||||
})
|
||||
.collect()
|
||||
}
|
||||
|
||||
fn render_diff(&self, old: &str, new: &str, _theme: &Theme) -> Vec<Line<'static>> {
|
||||
let mut lines = Vec::new();
|
||||
|
||||
// Simple line-by-line diff
|
||||
let old_lines: Vec<&str> = old.lines().collect();
|
||||
let new_lines: Vec<&str> = new.lines().collect();
|
||||
|
||||
let max_len = old_lines.len().max(new_lines.len());
|
||||
|
||||
for i in 0..max_len {
|
||||
let old_line = old_lines.get(i).copied().unwrap_or("");
|
||||
let new_line = new_lines.get(i).copied().unwrap_or("");
|
||||
|
||||
if old_line != new_line {
|
||||
if !old_line.is_empty() {
|
||||
lines.push(Line::from(Span::styled(
|
||||
format!("- {}", old_line),
|
||||
Style::default().fg(Color::Red),
|
||||
)));
|
||||
}
|
||||
if !new_line.is_empty() {
|
||||
lines.push(Line::from(Span::styled(
|
||||
format!("+ {}", new_line),
|
||||
Style::default().fg(Color::Green),
|
||||
)));
|
||||
}
|
||||
} else {
|
||||
lines.push(Line::from(format!(" {}", old_line)));
|
||||
}
|
||||
}
|
||||
|
||||
lines
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn test_help_table() {
|
||||
let commands = vec![
|
||||
CommandInfo::new("help", "Show help", "builtin"),
|
||||
CommandInfo::new("clear", "Clear screen", "builtin"),
|
||||
];
|
||||
let output = CommandOutput::help_table(&commands);
|
||||
|
||||
match output.format {
|
||||
OutputFormat::Table { headers, rows } => {
|
||||
assert_eq!(headers.len(), 3);
|
||||
assert_eq!(rows.len(), 2);
|
||||
}
|
||||
_ => panic!("Expected Table format"),
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_mcp_tree() {
|
||||
let servers = vec![
|
||||
("filesystem".to_string(), vec!["read".to_string(), "write".to_string()]),
|
||||
("database".to_string(), vec!["query".to_string()]),
|
||||
];
|
||||
let output = CommandOutput::mcp_tree(&servers);
|
||||
|
||||
match output.format {
|
||||
OutputFormat::Tree { root } => {
|
||||
assert_eq!(root.label, "MCP Servers");
|
||||
assert_eq!(root.children.len(), 2);
|
||||
}
|
||||
_ => panic!("Expected Tree format"),
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_hooks_list() {
|
||||
let hooks = vec![
|
||||
("PreToolUse".to_string(), "./hooks/pre".to_string(), true),
|
||||
("PostToolUse".to_string(), "./hooks/post".to_string(), false),
|
||||
];
|
||||
let output = CommandOutput::hooks_list(&hooks);
|
||||
|
||||
match output.format {
|
||||
OutputFormat::List { items } => {
|
||||
assert_eq!(items.len(), 2);
|
||||
}
|
||||
_ => panic!("Expected List format"),
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_tree_node() {
|
||||
let node = TreeNode::new("root")
|
||||
.with_children(vec![
|
||||
TreeNode::new("child1"),
|
||||
TreeNode::new("child2"),
|
||||
]);
|
||||
assert_eq!(node.label, "root");
|
||||
assert_eq!(node.children.len(), 2);
|
||||
}
|
||||
}
|
||||
707
crates/app/ui/src/theme.rs
Normal file
707
crates/app/ui/src/theme.rs
Normal file
@@ -0,0 +1,707 @@
|
||||
//! Theme system for the borderless TUI design
|
||||
//!
|
||||
//! Provides color palettes, semantic styling, and terminal capability detection
|
||||
//! for graceful degradation across different terminal emulators.
|
||||
|
||||
use ratatui::style::{Color, Modifier, Style};
|
||||
|
||||
/// Terminal capability detection for graceful degradation
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
|
||||
pub enum TerminalCapability {
|
||||
/// Full Unicode support with true color
|
||||
Full,
|
||||
/// Basic Unicode with 256 colors
|
||||
Unicode256,
|
||||
/// ASCII only with 16 colors
|
||||
Basic,
|
||||
}
|
||||
|
||||
impl TerminalCapability {
|
||||
/// Detect terminal capabilities from environment
|
||||
pub fn detect() -> Self {
|
||||
// Check for true color support
|
||||
let colorterm = std::env::var("COLORTERM").unwrap_or_default();
|
||||
let term = std::env::var("TERM").unwrap_or_default();
|
||||
|
||||
if colorterm == "truecolor" || colorterm == "24bit" {
|
||||
return Self::Full;
|
||||
}
|
||||
|
||||
if term.contains("256color") || term.contains("kitty") || term.contains("alacritty") {
|
||||
return Self::Unicode256;
|
||||
}
|
||||
|
||||
// Check if we're in a linux VT or basic terminal
|
||||
if term == "linux" || term == "vt100" || term == "dumb" {
|
||||
return Self::Basic;
|
||||
}
|
||||
|
||||
// Default to unicode with 256 colors
|
||||
Self::Unicode256
|
||||
}
|
||||
|
||||
/// Check if Unicode box drawing is supported
|
||||
pub fn supports_unicode(&self) -> bool {
|
||||
matches!(self, Self::Full | Self::Unicode256)
|
||||
}
|
||||
|
||||
/// Check if true color (RGB) is supported
|
||||
pub fn supports_truecolor(&self) -> bool {
|
||||
matches!(self, Self::Full)
|
||||
}
|
||||
}
|
||||
|
||||
/// Symbols with fallbacks for different terminal capabilities
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct Symbols {
|
||||
pub horizontal_rule: &'static str,
|
||||
pub vertical_separator: &'static str,
|
||||
pub bullet: &'static str,
|
||||
pub arrow: &'static str,
|
||||
pub check: &'static str,
|
||||
pub cross: &'static str,
|
||||
pub warning: &'static str,
|
||||
pub info: &'static str,
|
||||
pub streaming: &'static str,
|
||||
pub user_prefix: &'static str,
|
||||
pub assistant_prefix: &'static str,
|
||||
pub tool_prefix: &'static str,
|
||||
pub system_prefix: &'static str,
|
||||
// Provider icons
|
||||
pub claude_icon: &'static str,
|
||||
pub ollama_icon: &'static str,
|
||||
pub openai_icon: &'static str,
|
||||
// Vim mode indicators
|
||||
pub mode_normal: &'static str,
|
||||
pub mode_insert: &'static str,
|
||||
pub mode_visual: &'static str,
|
||||
pub mode_command: &'static str,
|
||||
}
|
||||
|
||||
impl Symbols {
|
||||
/// Unicode symbols for capable terminals
|
||||
pub fn unicode() -> Self {
|
||||
Self {
|
||||
horizontal_rule: "─",
|
||||
vertical_separator: "│",
|
||||
bullet: "•",
|
||||
arrow: "→",
|
||||
check: "✓",
|
||||
cross: "✗",
|
||||
warning: "⚠",
|
||||
info: "ℹ",
|
||||
streaming: "●",
|
||||
user_prefix: "❯",
|
||||
assistant_prefix: "◆",
|
||||
tool_prefix: "⚡",
|
||||
system_prefix: "○",
|
||||
claude_icon: "",
|
||||
ollama_icon: "",
|
||||
openai_icon: "",
|
||||
mode_normal: "[N]",
|
||||
mode_insert: "[I]",
|
||||
mode_visual: "[V]",
|
||||
mode_command: "[:]",
|
||||
}
|
||||
}
|
||||
|
||||
/// ASCII fallback symbols
|
||||
pub fn ascii() -> Self {
|
||||
Self {
|
||||
horizontal_rule: "-",
|
||||
vertical_separator: "|",
|
||||
bullet: "*",
|
||||
arrow: "->",
|
||||
check: "+",
|
||||
cross: "x",
|
||||
warning: "!",
|
||||
info: "i",
|
||||
streaming: "*",
|
||||
user_prefix: ">",
|
||||
assistant_prefix: "-",
|
||||
tool_prefix: "#",
|
||||
system_prefix: "-",
|
||||
claude_icon: "C",
|
||||
ollama_icon: "O",
|
||||
openai_icon: "G",
|
||||
mode_normal: "[N]",
|
||||
mode_insert: "[I]",
|
||||
mode_visual: "[V]",
|
||||
mode_command: "[:]",
|
||||
}
|
||||
}
|
||||
|
||||
/// Select symbols based on terminal capability
|
||||
pub fn for_capability(cap: TerminalCapability) -> Self {
|
||||
match cap {
|
||||
TerminalCapability::Full | TerminalCapability::Unicode256 => Self::unicode(),
|
||||
TerminalCapability::Basic => Self::ascii(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Modern color palette inspired by contemporary design systems
|
||||
///
|
||||
/// Color assignment principles:
|
||||
/// - fg (#c0caf5): PRIMARY text - user messages, command names
|
||||
/// - assistant (#9aa5ce): Soft gray-blue for AI responses (distinct from user)
|
||||
/// - accent (#7aa2f7): Interactive elements ONLY (mode, prompt symbol)
|
||||
/// - cmd_slash (#bb9af7): Purple for / prefix (signals "command")
|
||||
/// - fg_dim (#565f89): Timestamps, hints, inactive elements
|
||||
/// - selection (#283457): Highlighted row background
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct ColorPalette {
|
||||
pub primary: Color,
|
||||
pub secondary: Color,
|
||||
pub accent: Color,
|
||||
pub success: Color,
|
||||
pub warning: Color,
|
||||
pub error: Color,
|
||||
pub info: Color,
|
||||
pub bg: Color,
|
||||
pub fg: Color,
|
||||
pub fg_dim: Color,
|
||||
pub fg_muted: Color,
|
||||
pub highlight: Color,
|
||||
pub border: Color, // For horizontal rules (subtle)
|
||||
pub selection: Color, // Highlighted row background
|
||||
// Provider-specific colors
|
||||
pub claude: Color,
|
||||
pub ollama: Color,
|
||||
pub openai: Color,
|
||||
// Semantic colors for messages
|
||||
pub user_fg: Color, // User message text (bright, fg)
|
||||
pub assistant_fg: Color, // Assistant message text (soft gray-blue)
|
||||
pub tool_fg: Color,
|
||||
pub timestamp_fg: Color,
|
||||
pub divider_fg: Color,
|
||||
// Command colors
|
||||
pub cmd_slash: Color, // Purple for / prefix
|
||||
pub cmd_name: Color, // Command name (same as fg)
|
||||
pub cmd_desc: Color, // Command description (dim)
|
||||
// Overlay/modal colors
|
||||
pub overlay_bg: Color, // Slightly lighter than main bg
|
||||
}
|
||||
|
||||
impl ColorPalette {
|
||||
/// Tokyo Night inspired palette - high contrast, readable
|
||||
///
|
||||
/// Key principles:
|
||||
/// - fg (#c0caf5) for user messages and command names
|
||||
/// - assistant (#a9b1d6) brighter gray-blue for AI responses (readable)
|
||||
/// - accent (#7aa2f7) only for interactive elements (mode indicator, prompt symbol)
|
||||
/// - cmd_slash (#bb9af7) purple for / prefix (signals "command")
|
||||
/// - fg_dim (#737aa2) for timestamps, hints, descriptions (brighter than before)
|
||||
/// - border (#3b4261) for horizontal rules
|
||||
pub fn tokyo_night() -> Self {
|
||||
Self {
|
||||
primary: Color::Rgb(122, 162, 247), // #7aa2f7 - Blue accent
|
||||
secondary: Color::Rgb(187, 154, 247), // #bb9af7 - Purple
|
||||
accent: Color::Rgb(122, 162, 247), // #7aa2f7 - Interactive elements ONLY
|
||||
success: Color::Rgb(158, 206, 106), // #9ece6a - Green
|
||||
warning: Color::Rgb(224, 175, 104), // #e0af68 - Yellow
|
||||
error: Color::Rgb(247, 118, 142), // #f7768e - Pink/Red
|
||||
info: Color::Rgb(125, 207, 255), // Cyan (rarely used)
|
||||
bg: Color::Rgb(26, 27, 38), // #1a1b26 - Dark bg
|
||||
fg: Color::Rgb(192, 202, 245), // #c0caf5 - Primary text (HIGH CONTRAST)
|
||||
fg_dim: Color::Rgb(115, 122, 162), // #737aa2 - Secondary text (BRIGHTER)
|
||||
fg_muted: Color::Rgb(86, 95, 137), // #565f89 - Very dim
|
||||
highlight: Color::Rgb(56, 62, 90), // Selection bg (legacy)
|
||||
border: Color::Rgb(73, 82, 115), // #495273 - Horizontal rules (BRIGHTER)
|
||||
selection: Color::Rgb(40, 52, 87), // #283457 - Highlighted row bg
|
||||
// Provider colors
|
||||
claude: Color::Rgb(217, 119, 87), // Claude orange
|
||||
ollama: Color::Rgb(122, 162, 247), // Blue
|
||||
openai: Color::Rgb(16, 163, 127), // OpenAI green
|
||||
// Message colors - user bright, assistant readable
|
||||
user_fg: Color::Rgb(192, 202, 245), // #c0caf5 - Same as fg (bright)
|
||||
assistant_fg: Color::Rgb(169, 177, 214), // #a9b1d6 - Brighter gray-blue (READABLE)
|
||||
tool_fg: Color::Rgb(224, 175, 104), // #e0af68 - Yellow for tools
|
||||
timestamp_fg: Color::Rgb(115, 122, 162), // #737aa2 - Brighter dim
|
||||
divider_fg: Color::Rgb(73, 82, 115), // #495273 - Border color (BRIGHTER)
|
||||
// Command colors
|
||||
cmd_slash: Color::Rgb(187, 154, 247), // #bb9af7 - Purple for / prefix
|
||||
cmd_name: Color::Rgb(192, 202, 245), // #c0caf5 - White for command name
|
||||
cmd_desc: Color::Rgb(115, 122, 162), // #737aa2 - Brighter description
|
||||
// Overlay colors
|
||||
overlay_bg: Color::Rgb(36, 40, 59), // #24283b - Slightly lighter than bg
|
||||
}
|
||||
}
|
||||
|
||||
/// Dracula inspired palette - classic and elegant
|
||||
pub fn dracula() -> Self {
|
||||
Self {
|
||||
primary: Color::Rgb(139, 233, 253), // Cyan
|
||||
secondary: Color::Rgb(189, 147, 249), // Purple
|
||||
accent: Color::Rgb(255, 121, 198), // Pink
|
||||
success: Color::Rgb(80, 250, 123), // Green
|
||||
warning: Color::Rgb(241, 250, 140), // Yellow
|
||||
error: Color::Rgb(255, 85, 85), // Red
|
||||
info: Color::Rgb(139, 233, 253), // Cyan
|
||||
bg: Color::Rgb(40, 42, 54), // Dark bg
|
||||
fg: Color::Rgb(248, 248, 242), // Light text
|
||||
fg_dim: Color::Rgb(98, 114, 164), // Comment
|
||||
fg_muted: Color::Rgb(68, 71, 90), // Very dim
|
||||
highlight: Color::Rgb(68, 71, 90), // Selection
|
||||
border: Color::Rgb(68, 71, 90),
|
||||
selection: Color::Rgb(68, 71, 90),
|
||||
claude: Color::Rgb(255, 121, 198),
|
||||
ollama: Color::Rgb(139, 233, 253),
|
||||
openai: Color::Rgb(80, 250, 123),
|
||||
user_fg: Color::Rgb(248, 248, 242),
|
||||
assistant_fg: Color::Rgb(189, 186, 220), // Softer purple-gray
|
||||
tool_fg: Color::Rgb(241, 250, 140),
|
||||
timestamp_fg: Color::Rgb(68, 71, 90),
|
||||
divider_fg: Color::Rgb(68, 71, 90),
|
||||
cmd_slash: Color::Rgb(189, 147, 249), // Purple
|
||||
cmd_name: Color::Rgb(248, 248, 242),
|
||||
cmd_desc: Color::Rgb(98, 114, 164),
|
||||
overlay_bg: Color::Rgb(50, 52, 64),
|
||||
}
|
||||
}
|
||||
|
||||
/// Catppuccin Mocha - warm and cozy
|
||||
pub fn catppuccin() -> Self {
|
||||
Self {
|
||||
primary: Color::Rgb(137, 180, 250), // Blue
|
||||
secondary: Color::Rgb(203, 166, 247), // Mauve
|
||||
accent: Color::Rgb(245, 194, 231), // Pink
|
||||
success: Color::Rgb(166, 227, 161), // Green
|
||||
warning: Color::Rgb(249, 226, 175), // Yellow
|
||||
error: Color::Rgb(243, 139, 168), // Red
|
||||
info: Color::Rgb(148, 226, 213), // Teal
|
||||
bg: Color::Rgb(30, 30, 46), // Base
|
||||
fg: Color::Rgb(205, 214, 244), // Text
|
||||
fg_dim: Color::Rgb(108, 112, 134), // Overlay
|
||||
fg_muted: Color::Rgb(69, 71, 90), // Surface
|
||||
highlight: Color::Rgb(49, 50, 68), // Surface
|
||||
border: Color::Rgb(69, 71, 90),
|
||||
selection: Color::Rgb(49, 50, 68),
|
||||
claude: Color::Rgb(245, 194, 231),
|
||||
ollama: Color::Rgb(137, 180, 250),
|
||||
openai: Color::Rgb(166, 227, 161),
|
||||
user_fg: Color::Rgb(205, 214, 244),
|
||||
assistant_fg: Color::Rgb(166, 187, 213), // Softer blue-gray
|
||||
tool_fg: Color::Rgb(249, 226, 175),
|
||||
timestamp_fg: Color::Rgb(69, 71, 90),
|
||||
divider_fg: Color::Rgb(69, 71, 90),
|
||||
cmd_slash: Color::Rgb(203, 166, 247), // Mauve
|
||||
cmd_name: Color::Rgb(205, 214, 244),
|
||||
cmd_desc: Color::Rgb(108, 112, 134),
|
||||
overlay_bg: Color::Rgb(40, 40, 56),
|
||||
}
|
||||
}
|
||||
|
||||
/// Nord - minimal and clean
|
||||
pub fn nord() -> Self {
|
||||
Self {
|
||||
primary: Color::Rgb(136, 192, 208), // Frost cyan
|
||||
secondary: Color::Rgb(129, 161, 193), // Frost blue
|
||||
accent: Color::Rgb(180, 142, 173), // Aurora purple
|
||||
success: Color::Rgb(163, 190, 140), // Aurora green
|
||||
warning: Color::Rgb(235, 203, 139), // Aurora yellow
|
||||
error: Color::Rgb(191, 97, 106), // Aurora red
|
||||
info: Color::Rgb(136, 192, 208), // Frost cyan
|
||||
bg: Color::Rgb(46, 52, 64), // Polar night
|
||||
fg: Color::Rgb(236, 239, 244), // Snow storm
|
||||
fg_dim: Color::Rgb(76, 86, 106), // Polar night light
|
||||
fg_muted: Color::Rgb(59, 66, 82),
|
||||
highlight: Color::Rgb(59, 66, 82), // Selection
|
||||
border: Color::Rgb(59, 66, 82),
|
||||
selection: Color::Rgb(59, 66, 82),
|
||||
claude: Color::Rgb(180, 142, 173),
|
||||
ollama: Color::Rgb(136, 192, 208),
|
||||
openai: Color::Rgb(163, 190, 140),
|
||||
user_fg: Color::Rgb(236, 239, 244),
|
||||
assistant_fg: Color::Rgb(180, 195, 210), // Softer blue-gray
|
||||
tool_fg: Color::Rgb(235, 203, 139),
|
||||
timestamp_fg: Color::Rgb(59, 66, 82),
|
||||
divider_fg: Color::Rgb(59, 66, 82),
|
||||
cmd_slash: Color::Rgb(180, 142, 173), // Aurora purple
|
||||
cmd_name: Color::Rgb(236, 239, 244),
|
||||
cmd_desc: Color::Rgb(76, 86, 106),
|
||||
overlay_bg: Color::Rgb(56, 62, 74),
|
||||
}
|
||||
}
|
||||
|
||||
/// Synthwave - vibrant and retro
|
||||
pub fn synthwave() -> Self {
|
||||
Self {
|
||||
primary: Color::Rgb(255, 0, 128), // Hot pink
|
||||
secondary: Color::Rgb(0, 229, 255), // Cyan
|
||||
accent: Color::Rgb(255, 128, 0), // Orange
|
||||
success: Color::Rgb(0, 255, 157), // Neon green
|
||||
warning: Color::Rgb(255, 215, 0), // Gold
|
||||
error: Color::Rgb(255, 64, 64), // Neon red
|
||||
info: Color::Rgb(0, 229, 255), // Cyan
|
||||
bg: Color::Rgb(20, 16, 32), // Dark purple
|
||||
fg: Color::Rgb(242, 233, 255), // Light purple
|
||||
fg_dim: Color::Rgb(127, 90, 180), // Mid purple
|
||||
fg_muted: Color::Rgb(72, 12, 168),
|
||||
highlight: Color::Rgb(72, 12, 168), // Deep purple
|
||||
border: Color::Rgb(72, 12, 168),
|
||||
selection: Color::Rgb(72, 12, 168),
|
||||
claude: Color::Rgb(255, 128, 0),
|
||||
ollama: Color::Rgb(0, 229, 255),
|
||||
openai: Color::Rgb(0, 255, 157),
|
||||
user_fg: Color::Rgb(242, 233, 255),
|
||||
assistant_fg: Color::Rgb(180, 170, 220), // Softer purple
|
||||
tool_fg: Color::Rgb(255, 215, 0),
|
||||
timestamp_fg: Color::Rgb(72, 12, 168),
|
||||
divider_fg: Color::Rgb(72, 12, 168),
|
||||
cmd_slash: Color::Rgb(255, 0, 128), // Hot pink
|
||||
cmd_name: Color::Rgb(242, 233, 255),
|
||||
cmd_desc: Color::Rgb(127, 90, 180),
|
||||
overlay_bg: Color::Rgb(30, 26, 42),
|
||||
}
|
||||
}
|
||||
|
||||
/// Rose Pine - elegant and muted
|
||||
pub fn rose_pine() -> Self {
|
||||
Self {
|
||||
primary: Color::Rgb(156, 207, 216), // Foam
|
||||
secondary: Color::Rgb(235, 188, 186), // Rose
|
||||
accent: Color::Rgb(234, 154, 151), // Love
|
||||
success: Color::Rgb(49, 116, 143), // Pine
|
||||
warning: Color::Rgb(246, 193, 119), // Gold
|
||||
error: Color::Rgb(235, 111, 146), // Love (darker)
|
||||
info: Color::Rgb(156, 207, 216), // Foam
|
||||
bg: Color::Rgb(25, 23, 36), // Base
|
||||
fg: Color::Rgb(224, 222, 244), // Text
|
||||
fg_dim: Color::Rgb(110, 106, 134), // Muted
|
||||
fg_muted: Color::Rgb(42, 39, 63),
|
||||
highlight: Color::Rgb(42, 39, 63), // Highlight
|
||||
border: Color::Rgb(42, 39, 63),
|
||||
selection: Color::Rgb(42, 39, 63),
|
||||
claude: Color::Rgb(234, 154, 151),
|
||||
ollama: Color::Rgb(156, 207, 216),
|
||||
openai: Color::Rgb(49, 116, 143),
|
||||
user_fg: Color::Rgb(224, 222, 244),
|
||||
assistant_fg: Color::Rgb(180, 185, 210), // Softer lavender-gray
|
||||
tool_fg: Color::Rgb(246, 193, 119),
|
||||
timestamp_fg: Color::Rgb(42, 39, 63),
|
||||
divider_fg: Color::Rgb(42, 39, 63),
|
||||
cmd_slash: Color::Rgb(235, 188, 186), // Rose
|
||||
cmd_name: Color::Rgb(224, 222, 244),
|
||||
cmd_desc: Color::Rgb(110, 106, 134),
|
||||
overlay_bg: Color::Rgb(35, 33, 46),
|
||||
}
|
||||
}
|
||||
|
||||
/// Midnight Ocean - deep and serene
|
||||
pub fn midnight_ocean() -> Self {
|
||||
Self {
|
||||
primary: Color::Rgb(102, 217, 239), // Bright cyan
|
||||
secondary: Color::Rgb(130, 170, 255), // Periwinkle
|
||||
accent: Color::Rgb(199, 146, 234), // Purple
|
||||
success: Color::Rgb(163, 190, 140), // Sea green
|
||||
warning: Color::Rgb(229, 200, 144), // Sandy yellow
|
||||
error: Color::Rgb(236, 95, 103), // Coral red
|
||||
info: Color::Rgb(102, 217, 239), // Bright cyan
|
||||
bg: Color::Rgb(1, 22, 39), // Deep ocean
|
||||
fg: Color::Rgb(201, 211, 235), // Light blue-white
|
||||
fg_dim: Color::Rgb(71, 103, 145), // Muted blue
|
||||
fg_muted: Color::Rgb(13, 43, 69),
|
||||
highlight: Color::Rgb(13, 43, 69), // Deep blue
|
||||
border: Color::Rgb(13, 43, 69),
|
||||
selection: Color::Rgb(13, 43, 69),
|
||||
claude: Color::Rgb(199, 146, 234),
|
||||
ollama: Color::Rgb(102, 217, 239),
|
||||
openai: Color::Rgb(163, 190, 140),
|
||||
user_fg: Color::Rgb(201, 211, 235),
|
||||
assistant_fg: Color::Rgb(150, 175, 200), // Softer blue-gray
|
||||
tool_fg: Color::Rgb(229, 200, 144),
|
||||
timestamp_fg: Color::Rgb(13, 43, 69),
|
||||
divider_fg: Color::Rgb(13, 43, 69),
|
||||
cmd_slash: Color::Rgb(199, 146, 234), // Purple
|
||||
cmd_name: Color::Rgb(201, 211, 235),
|
||||
cmd_desc: Color::Rgb(71, 103, 145),
|
||||
overlay_bg: Color::Rgb(11, 32, 49),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// LLM Provider enum
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
|
||||
pub enum Provider {
|
||||
Claude,
|
||||
Ollama,
|
||||
OpenAI,
|
||||
}
|
||||
|
||||
impl Provider {
|
||||
pub fn name(&self) -> &'static str {
|
||||
match self {
|
||||
Provider::Claude => "Claude",
|
||||
Provider::Ollama => "Ollama",
|
||||
Provider::OpenAI => "OpenAI",
|
||||
}
|
||||
}
|
||||
|
||||
pub fn all() -> &'static [Provider] {
|
||||
&[Provider::Claude, Provider::Ollama, Provider::OpenAI]
|
||||
}
|
||||
}
|
||||
|
||||
/// Vim-like editing mode
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
|
||||
pub enum VimMode {
|
||||
#[default]
|
||||
Normal,
|
||||
Insert,
|
||||
Visual,
|
||||
Command,
|
||||
}
|
||||
|
||||
impl VimMode {
|
||||
pub fn indicator(&self, symbols: &Symbols) -> &'static str {
|
||||
match self {
|
||||
VimMode::Normal => symbols.mode_normal,
|
||||
VimMode::Insert => symbols.mode_insert,
|
||||
VimMode::Visual => symbols.mode_visual,
|
||||
VimMode::Command => symbols.mode_command,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Theme configuration for the borderless TUI
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct Theme {
|
||||
pub palette: ColorPalette,
|
||||
pub symbols: Symbols,
|
||||
pub capability: TerminalCapability,
|
||||
// Message styles
|
||||
pub user_message: Style,
|
||||
pub assistant_message: Style,
|
||||
pub tool_call: Style,
|
||||
pub tool_result_success: Style,
|
||||
pub tool_result_error: Style,
|
||||
pub system_message: Style,
|
||||
pub timestamp: Style,
|
||||
// UI element styles
|
||||
pub divider: Style,
|
||||
pub header: Style,
|
||||
pub header_accent: Style,
|
||||
pub tab_active: Style,
|
||||
pub tab_inactive: Style,
|
||||
pub input_prefix: Style,
|
||||
pub input_text: Style,
|
||||
pub input_placeholder: Style,
|
||||
pub status_bar: Style,
|
||||
pub status_accent: Style,
|
||||
pub status_dim: Style,
|
||||
// Command styles
|
||||
pub cmd_slash: Style, // Purple for / prefix
|
||||
pub cmd_name: Style, // White for command name
|
||||
pub cmd_desc: Style, // Dim for description
|
||||
// Overlay/modal styles
|
||||
pub overlay_bg: Style, // Modal background
|
||||
pub selection_bg: Style, // Selected row background
|
||||
// Popup styles (for permission dialogs)
|
||||
pub popup_border: Style,
|
||||
pub popup_bg: Style,
|
||||
pub popup_title: Style,
|
||||
pub selected: Style,
|
||||
// Legacy compatibility
|
||||
pub border: Style,
|
||||
pub border_active: Style,
|
||||
pub status_bar_highlight: Style,
|
||||
pub input_box: Style,
|
||||
pub input_box_active: Style,
|
||||
}
|
||||
|
||||
impl Theme {
|
||||
/// Create theme from color palette with automatic capability detection
|
||||
pub fn from_palette(palette: ColorPalette) -> Self {
|
||||
let capability = TerminalCapability::detect();
|
||||
Self::from_palette_with_capability(palette, capability)
|
||||
}
|
||||
|
||||
/// Create theme with specific terminal capability
|
||||
pub fn from_palette_with_capability(palette: ColorPalette, capability: TerminalCapability) -> Self {
|
||||
let symbols = Symbols::for_capability(capability);
|
||||
|
||||
Self {
|
||||
// Message styles
|
||||
user_message: Style::default()
|
||||
.fg(palette.user_fg)
|
||||
.add_modifier(Modifier::BOLD),
|
||||
assistant_message: Style::default().fg(palette.assistant_fg),
|
||||
tool_call: Style::default()
|
||||
.fg(palette.tool_fg)
|
||||
.add_modifier(Modifier::ITALIC),
|
||||
tool_result_success: Style::default()
|
||||
.fg(palette.success)
|
||||
.add_modifier(Modifier::BOLD),
|
||||
tool_result_error: Style::default()
|
||||
.fg(palette.error)
|
||||
.add_modifier(Modifier::BOLD),
|
||||
system_message: Style::default().fg(palette.fg_dim),
|
||||
timestamp: Style::default().fg(palette.timestamp_fg),
|
||||
// UI elements
|
||||
divider: Style::default().fg(palette.divider_fg),
|
||||
header: Style::default()
|
||||
.fg(palette.fg)
|
||||
.add_modifier(Modifier::BOLD),
|
||||
header_accent: Style::default()
|
||||
.fg(palette.accent)
|
||||
.add_modifier(Modifier::BOLD),
|
||||
tab_active: Style::default()
|
||||
.fg(palette.primary)
|
||||
.add_modifier(Modifier::BOLD | Modifier::UNDERLINED),
|
||||
tab_inactive: Style::default().fg(palette.fg_dim),
|
||||
input_prefix: Style::default()
|
||||
.fg(palette.accent)
|
||||
.add_modifier(Modifier::BOLD),
|
||||
input_text: Style::default().fg(palette.fg),
|
||||
input_placeholder: Style::default().fg(palette.fg_muted),
|
||||
status_bar: Style::default().fg(palette.fg_dim),
|
||||
status_accent: Style::default().fg(palette.accent),
|
||||
status_dim: Style::default().fg(palette.fg_muted),
|
||||
// Command styles
|
||||
cmd_slash: Style::default().fg(palette.cmd_slash),
|
||||
cmd_name: Style::default().fg(palette.cmd_name),
|
||||
cmd_desc: Style::default().fg(palette.cmd_desc),
|
||||
// Overlay/modal styles
|
||||
overlay_bg: Style::default().bg(palette.overlay_bg),
|
||||
selection_bg: Style::default().bg(palette.selection),
|
||||
// Popup styles
|
||||
popup_border: Style::default()
|
||||
.fg(palette.border)
|
||||
.add_modifier(Modifier::BOLD),
|
||||
popup_bg: Style::default().bg(palette.overlay_bg),
|
||||
popup_title: Style::default()
|
||||
.fg(palette.fg)
|
||||
.add_modifier(Modifier::BOLD),
|
||||
selected: Style::default()
|
||||
.fg(palette.fg)
|
||||
.bg(palette.selection)
|
||||
.add_modifier(Modifier::BOLD),
|
||||
// Legacy compatibility
|
||||
border: Style::default().fg(palette.fg_dim),
|
||||
border_active: Style::default()
|
||||
.fg(palette.primary)
|
||||
.add_modifier(Modifier::BOLD),
|
||||
status_bar_highlight: Style::default()
|
||||
.fg(palette.bg)
|
||||
.bg(palette.accent)
|
||||
.add_modifier(Modifier::BOLD),
|
||||
input_box: Style::default().fg(palette.fg),
|
||||
input_box_active: Style::default()
|
||||
.fg(palette.accent)
|
||||
.add_modifier(Modifier::BOLD),
|
||||
symbols,
|
||||
capability,
|
||||
palette,
|
||||
}
|
||||
}
|
||||
|
||||
/// Get provider-specific color
|
||||
pub fn provider_color(&self, provider: Provider) -> Color {
|
||||
match provider {
|
||||
Provider::Claude => self.palette.claude,
|
||||
Provider::Ollama => self.palette.ollama,
|
||||
Provider::OpenAI => self.palette.openai,
|
||||
}
|
||||
}
|
||||
|
||||
/// Get provider icon
|
||||
pub fn provider_icon(&self, provider: Provider) -> &str {
|
||||
match provider {
|
||||
Provider::Claude => self.symbols.claude_icon,
|
||||
Provider::Ollama => self.symbols.ollama_icon,
|
||||
Provider::OpenAI => self.symbols.openai_icon,
|
||||
}
|
||||
}
|
||||
|
||||
/// Create a horizontal rule string of given width
|
||||
pub fn horizontal_rule(&self, width: usize) -> String {
|
||||
self.symbols.horizontal_rule.repeat(width)
|
||||
}
|
||||
|
||||
/// Tokyo Night theme (default) - modern and vibrant
|
||||
pub fn tokyo_night() -> Self {
|
||||
Self::from_palette(ColorPalette::tokyo_night())
|
||||
}
|
||||
|
||||
/// Dracula theme - classic dark theme
|
||||
pub fn dracula() -> Self {
|
||||
Self::from_palette(ColorPalette::dracula())
|
||||
}
|
||||
|
||||
/// Catppuccin Mocha - warm and cozy
|
||||
pub fn catppuccin() -> Self {
|
||||
Self::from_palette(ColorPalette::catppuccin())
|
||||
}
|
||||
|
||||
/// Nord theme - minimal and clean
|
||||
pub fn nord() -> Self {
|
||||
Self::from_palette(ColorPalette::nord())
|
||||
}
|
||||
|
||||
/// Synthwave theme - vibrant retro
|
||||
pub fn synthwave() -> Self {
|
||||
Self::from_palette(ColorPalette::synthwave())
|
||||
}
|
||||
|
||||
/// Rose Pine theme - elegant and muted
|
||||
pub fn rose_pine() -> Self {
|
||||
Self::from_palette(ColorPalette::rose_pine())
|
||||
}
|
||||
|
||||
/// Midnight Ocean theme - deep and serene
|
||||
pub fn midnight_ocean() -> Self {
|
||||
Self::from_palette(ColorPalette::midnight_ocean())
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for Theme {
|
||||
fn default() -> Self {
|
||||
Self::tokyo_night()
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn test_terminal_capability_detection() {
|
||||
let cap = TerminalCapability::detect();
|
||||
// Should return some valid capability
|
||||
assert!(matches!(
|
||||
cap,
|
||||
TerminalCapability::Full | TerminalCapability::Unicode256 | TerminalCapability::Basic
|
||||
));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_symbols_for_capability() {
|
||||
let unicode = Symbols::for_capability(TerminalCapability::Full);
|
||||
assert_eq!(unicode.horizontal_rule, "─");
|
||||
|
||||
let ascii = Symbols::for_capability(TerminalCapability::Basic);
|
||||
assert_eq!(ascii.horizontal_rule, "-");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_theme_from_palette() {
|
||||
let theme = Theme::tokyo_night();
|
||||
assert!(theme.capability.supports_unicode() || !theme.capability.supports_unicode());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_provider_colors() {
|
||||
let theme = Theme::tokyo_night();
|
||||
let claude_color = theme.provider_color(Provider::Claude);
|
||||
let ollama_color = theme.provider_color(Provider::Ollama);
|
||||
assert_ne!(claude_color, ollama_color);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_vim_mode_indicator() {
|
||||
let symbols = Symbols::unicode();
|
||||
assert_eq!(VimMode::Normal.indicator(&symbols), "[N]");
|
||||
assert_eq!(VimMode::Insert.indicator(&symbols), "[I]");
|
||||
}
|
||||
}
|
||||
29
crates/core/agent/Cargo.toml
Normal file
29
crates/core/agent/Cargo.toml
Normal file
@@ -0,0 +1,29 @@
|
||||
[package]
|
||||
name = "agent-core"
|
||||
version = "0.1.0"
|
||||
edition.workspace = true
|
||||
license.workspace = true
|
||||
rust-version.workspace = true
|
||||
|
||||
[dependencies]
|
||||
serde = { version = "1", features = ["derive"] }
|
||||
serde_json = "1"
|
||||
color-eyre = "0.6"
|
||||
tokio = { version = "1", features = ["full"] }
|
||||
futures-util = "0.3"
|
||||
tracing = "0.1"
|
||||
async-trait = "0.1"
|
||||
chrono = "0.4"
|
||||
|
||||
# Internal dependencies
|
||||
llm-core = { path = "../../llm/core" }
|
||||
permissions = { path = "../../platform/permissions" }
|
||||
tools-fs = { path = "../../tools/fs" }
|
||||
tools-bash = { path = "../../tools/bash" }
|
||||
tools-ask = { path = "../../tools/ask" }
|
||||
tools-todo = { path = "../../tools/todo" }
|
||||
tools-web = { path = "../../tools/web" }
|
||||
tools-plan = { path = "../../tools/plan" }
|
||||
|
||||
[dev-dependencies]
|
||||
tempfile = "3.13"
|
||||
74
crates/core/agent/examples/git_demo.rs
Normal file
74
crates/core/agent/examples/git_demo.rs
Normal file
@@ -0,0 +1,74 @@
|
||||
//! Example demonstrating the git integration module
|
||||
//!
|
||||
//! Run with: cargo run -p agent-core --example git_demo
|
||||
|
||||
use agent_core::{detect_git_state, format_git_status, is_safe_git_command, is_destructive_git_command};
|
||||
use std::env;
|
||||
|
||||
fn main() -> color_eyre::Result<()> {
|
||||
color_eyre::install()?;
|
||||
|
||||
// Get current working directory
|
||||
let cwd = env::current_dir()?;
|
||||
println!("Detecting git state in: {}\n", cwd.display());
|
||||
|
||||
// Detect git state
|
||||
let state = detect_git_state(&cwd)?;
|
||||
|
||||
// Display formatted status
|
||||
println!("{}\n", format_git_status(&state));
|
||||
|
||||
// Show detailed file status if there are changes
|
||||
if !state.status.is_empty() {
|
||||
println!("Detailed file status:");
|
||||
for status in &state.status {
|
||||
match status {
|
||||
agent_core::GitFileStatus::Modified { path } => {
|
||||
println!(" M {}", path);
|
||||
}
|
||||
agent_core::GitFileStatus::Added { path } => {
|
||||
println!(" A {}", path);
|
||||
}
|
||||
agent_core::GitFileStatus::Deleted { path } => {
|
||||
println!(" D {}", path);
|
||||
}
|
||||
agent_core::GitFileStatus::Renamed { from, to } => {
|
||||
println!(" R {} -> {}", from, to);
|
||||
}
|
||||
agent_core::GitFileStatus::Untracked { path } => {
|
||||
println!(" ? {}", path);
|
||||
}
|
||||
}
|
||||
}
|
||||
println!();
|
||||
}
|
||||
|
||||
// Test command safety checking
|
||||
println!("Command safety checks:");
|
||||
let test_commands = vec![
|
||||
"git status",
|
||||
"git log --oneline",
|
||||
"git diff HEAD",
|
||||
"git commit -m 'test'",
|
||||
"git push --force origin main",
|
||||
"git reset --hard HEAD~1",
|
||||
"git rebase main",
|
||||
"git branch -D feature",
|
||||
];
|
||||
|
||||
for cmd in test_commands {
|
||||
let is_safe = is_safe_git_command(cmd);
|
||||
let (is_destructive, warning) = is_destructive_git_command(cmd);
|
||||
|
||||
print!(" {} - ", cmd);
|
||||
if is_safe {
|
||||
println!("SAFE (read-only)");
|
||||
} else if is_destructive {
|
||||
println!("DESTRUCTIVE: {}", warning);
|
||||
} else {
|
||||
println!("UNSAFE (modifies state)");
|
||||
}
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
92
crates/core/agent/examples/streaming_agent.rs
Normal file
92
crates/core/agent/examples/streaming_agent.rs
Normal file
@@ -0,0 +1,92 @@
|
||||
/// Example demonstrating the streaming agent loop API
|
||||
///
|
||||
/// This example shows how to use `run_agent_loop_streaming` to receive
|
||||
/// real-time events during agent execution, including:
|
||||
/// - Text deltas as the LLM generates text
|
||||
/// - Tool execution start/end events
|
||||
/// - Tool output events
|
||||
/// - Final completion events
|
||||
///
|
||||
/// Run with: cargo run --example streaming_agent -p agent-core
|
||||
|
||||
use agent_core::{create_event_channel, run_agent_loop_streaming, AgentEvent, ToolContext};
|
||||
use llm_core::ChatOptions;
|
||||
use permissions::{Mode, PermissionManager};
|
||||
|
||||
#[tokio::main]
|
||||
async fn main() -> color_eyre::Result<()> {
|
||||
color_eyre::install()?;
|
||||
|
||||
// Note: This is a minimal example. In a real application, you would:
|
||||
// 1. Initialize a real LLM provider (e.g., OllamaClient)
|
||||
// 2. Configure the ChatOptions with your preferred model
|
||||
// 3. Set up appropriate permissions and tool context
|
||||
|
||||
println!("=== Streaming Agent Example ===\n");
|
||||
println!("This example demonstrates how to use the streaming agent loop API.");
|
||||
println!("To run with a real LLM provider, modify this example to:");
|
||||
println!(" 1. Create an LLM provider instance");
|
||||
println!(" 2. Set up permissions and tool context");
|
||||
println!(" 3. Call run_agent_loop_streaming with your prompt\n");
|
||||
|
||||
// Example code structure:
|
||||
println!("Example code:");
|
||||
println!("```rust");
|
||||
println!("// Create LLM provider");
|
||||
println!("let provider = OllamaClient::new(\"http://localhost:11434\");");
|
||||
println!();
|
||||
println!("// Set up permissions and context");
|
||||
println!("let perms = PermissionManager::new(Mode::Plan);");
|
||||
println!("let ctx = ToolContext::default();");
|
||||
println!();
|
||||
println!("// Create event channel");
|
||||
println!("let (tx, mut rx) = create_event_channel();");
|
||||
println!();
|
||||
println!("// Spawn agent loop");
|
||||
println!("let handle = tokio::spawn(async move {{");
|
||||
println!(" run_agent_loop_streaming(");
|
||||
println!(" &provider,");
|
||||
println!(" \"Your prompt here\",");
|
||||
println!(" &ChatOptions::default(),");
|
||||
println!(" &perms,");
|
||||
println!(" &ctx,");
|
||||
println!(" tx,");
|
||||
println!(" ).await");
|
||||
println!("}});");
|
||||
println!();
|
||||
println!("// Process events");
|
||||
println!("while let Some(event) = rx.recv().await {{");
|
||||
println!(" match event {{");
|
||||
println!(" AgentEvent::TextDelta(text) => {{");
|
||||
println!(" print!(\"{{text}}\");");
|
||||
println!(" }}");
|
||||
println!(" AgentEvent::ToolStart {{ tool_name, .. }} => {{");
|
||||
println!(" println!(\"\\n[Executing tool: {{tool_name}}]\");");
|
||||
println!(" }}");
|
||||
println!(" AgentEvent::ToolOutput {{ content, is_error, .. }} => {{");
|
||||
println!(" if is_error {{");
|
||||
println!(" eprintln!(\"Error: {{content}}\");");
|
||||
println!(" }} else {{");
|
||||
println!(" println!(\"Output: {{content}}\");");
|
||||
println!(" }}");
|
||||
println!(" }}");
|
||||
println!(" AgentEvent::ToolEnd {{ success, .. }} => {{");
|
||||
println!(" println!(\"[Tool finished: {{}}]\", if success {{ \"success\" }} else {{ \"failed\" }});");
|
||||
println!(" }}");
|
||||
println!(" AgentEvent::Done {{ final_response }} => {{");
|
||||
println!(" println!(\"\\n\\nFinal response: {{final_response}}\");");
|
||||
println!(" break;");
|
||||
println!(" }}");
|
||||
println!(" AgentEvent::Error(e) => {{");
|
||||
println!(" eprintln!(\"Error: {{e}}\");");
|
||||
println!(" break;");
|
||||
println!(" }}");
|
||||
println!(" }}");
|
||||
println!("}}");
|
||||
println!();
|
||||
println!("// Wait for completion");
|
||||
println!("let result = handle.await??;");
|
||||
println!("```");
|
||||
|
||||
Ok(())
|
||||
}
|
||||
218
crates/core/agent/src/compact.rs
Normal file
218
crates/core/agent/src/compact.rs
Normal file
@@ -0,0 +1,218 @@
|
||||
//! Context compaction for long conversations
|
||||
//!
|
||||
//! When the conversation context grows too large, this module compacts
|
||||
//! earlier messages into a summary while preserving recent context.
|
||||
|
||||
use color_eyre::eyre::Result;
|
||||
use llm_core::{ChatMessage, ChatOptions, LlmProvider};
|
||||
|
||||
/// Token limit threshold for triggering compaction
|
||||
const CONTEXT_LIMIT: usize = 180_000;
|
||||
|
||||
/// Threshold ratio at which to trigger compaction (90% of limit)
|
||||
const COMPACTION_THRESHOLD: f64 = 0.9;
|
||||
|
||||
/// Number of recent messages to preserve during compaction
|
||||
const PRESERVE_RECENT: usize = 10;
|
||||
|
||||
/// Token counter for estimating context size
|
||||
pub struct TokenCounter {
|
||||
chars_per_token: f64,
|
||||
}
|
||||
|
||||
impl Default for TokenCounter {
|
||||
fn default() -> Self {
|
||||
Self::new()
|
||||
}
|
||||
}
|
||||
|
||||
impl TokenCounter {
|
||||
pub fn new() -> Self {
|
||||
// Rough estimate: ~4 chars per token for English text
|
||||
Self { chars_per_token: 4.0 }
|
||||
}
|
||||
|
||||
/// Estimate token count for a message
|
||||
pub fn count_message(&self, message: &ChatMessage) -> usize {
|
||||
let content_len = message.content.as_ref().map(|c| c.len()).unwrap_or(0);
|
||||
// Add overhead for role, metadata
|
||||
let overhead = 10;
|
||||
((content_len as f64 / self.chars_per_token) as usize) + overhead
|
||||
}
|
||||
|
||||
/// Estimate total token count for all messages
|
||||
pub fn count_messages(&self, messages: &[ChatMessage]) -> usize {
|
||||
messages.iter().map(|m| self.count_message(m)).sum()
|
||||
}
|
||||
|
||||
/// Check if context should be compacted
|
||||
pub fn should_compact(&self, messages: &[ChatMessage]) -> bool {
|
||||
let count = self.count_messages(messages);
|
||||
count > (CONTEXT_LIMIT as f64 * COMPACTION_THRESHOLD) as usize
|
||||
}
|
||||
}
|
||||
|
||||
/// Context compactor that summarizes conversation history
|
||||
pub struct Compactor {
|
||||
token_counter: TokenCounter,
|
||||
}
|
||||
|
||||
impl Default for Compactor {
|
||||
fn default() -> Self {
|
||||
Self::new()
|
||||
}
|
||||
}
|
||||
|
||||
impl Compactor {
|
||||
pub fn new() -> Self {
|
||||
Self {
|
||||
token_counter: TokenCounter::new(),
|
||||
}
|
||||
}
|
||||
|
||||
/// Check if messages need compaction
|
||||
pub fn needs_compaction(&self, messages: &[ChatMessage]) -> bool {
|
||||
self.token_counter.should_compact(messages)
|
||||
}
|
||||
|
||||
/// Compact messages by summarizing earlier conversation
|
||||
///
|
||||
/// Returns compacted messages with:
|
||||
/// - A system message containing the summary of earlier context
|
||||
/// - The most recent N messages preserved in full
|
||||
pub async fn compact<P: LlmProvider>(
|
||||
&self,
|
||||
provider: &P,
|
||||
messages: &[ChatMessage],
|
||||
options: &ChatOptions,
|
||||
) -> Result<Vec<ChatMessage>> {
|
||||
// If not enough messages to compact, return as-is
|
||||
if messages.len() <= PRESERVE_RECENT + 1 {
|
||||
return Ok(messages.to_vec());
|
||||
}
|
||||
|
||||
// Split into messages to summarize and messages to preserve
|
||||
let split_point = messages.len().saturating_sub(PRESERVE_RECENT);
|
||||
let to_summarize = &messages[..split_point];
|
||||
let to_preserve = &messages[split_point..];
|
||||
|
||||
// Generate summary of earlier messages
|
||||
let summary = self.summarize_messages(provider, to_summarize, options).await?;
|
||||
|
||||
// Build compacted message list
|
||||
let mut compacted = Vec::with_capacity(PRESERVE_RECENT + 1);
|
||||
|
||||
// Add system message with summary
|
||||
compacted.push(ChatMessage::system(format!(
|
||||
"## Earlier Conversation Summary\n\n{}\n\n---\n\n\
|
||||
The above summarizes the earlier part of this conversation. \
|
||||
Continue from the recent messages below.",
|
||||
summary
|
||||
)));
|
||||
|
||||
// Add preserved recent messages
|
||||
compacted.extend(to_preserve.iter().cloned());
|
||||
|
||||
Ok(compacted)
|
||||
}
|
||||
|
||||
/// Generate a summary of messages using the LLM
|
||||
async fn summarize_messages<P: LlmProvider>(
|
||||
&self,
|
||||
provider: &P,
|
||||
messages: &[ChatMessage],
|
||||
options: &ChatOptions,
|
||||
) -> Result<String> {
|
||||
// Format messages for summarization
|
||||
let mut context = String::new();
|
||||
for msg in messages {
|
||||
let role = &msg.role;
|
||||
let content = msg.content.as_deref().unwrap_or("");
|
||||
context.push_str(&format!("[{:?}]: {}\n\n", role, content));
|
||||
}
|
||||
|
||||
// Create summarization prompt
|
||||
let summary_prompt = format!(
|
||||
"Please provide a concise summary of the following conversation. \
|
||||
Focus on:\n\
|
||||
1. Key decisions made\n\
|
||||
2. Important files or code mentioned\n\
|
||||
3. Tasks completed and their outcomes\n\
|
||||
4. Any pending items or next steps discussed\n\n\
|
||||
Keep the summary informative but brief (under 500 words).\n\n\
|
||||
Conversation:\n{}\n\n\
|
||||
Summary:",
|
||||
context
|
||||
);
|
||||
|
||||
// Call LLM to generate summary
|
||||
let summary_options = ChatOptions {
|
||||
model: options.model.clone(),
|
||||
max_tokens: Some(1000),
|
||||
temperature: Some(0.3), // Lower temperature for more focused summary
|
||||
..Default::default()
|
||||
};
|
||||
|
||||
let summary_messages = vec![ChatMessage::user(&summary_prompt)];
|
||||
let mut stream = provider.chat_stream(&summary_messages, &summary_options, None).await?;
|
||||
|
||||
let mut summary = String::new();
|
||||
use futures_util::StreamExt;
|
||||
while let Some(chunk_result) = stream.next().await {
|
||||
if let Ok(chunk) = chunk_result {
|
||||
if let Some(content) = &chunk.content {
|
||||
summary.push_str(content);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Ok(summary.trim().to_string())
|
||||
}
|
||||
|
||||
/// Get token counter for external use
|
||||
pub fn token_counter(&self) -> &TokenCounter {
|
||||
&self.token_counter
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn test_token_counter_estimate() {
|
||||
let counter = TokenCounter::new();
|
||||
let msg = ChatMessage::user("Hello, world!");
|
||||
let count = counter.count_message(&msg);
|
||||
// Should be approximately 13/4 + 10 overhead = 13
|
||||
assert!(count > 10);
|
||||
assert!(count < 20);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_should_compact() {
|
||||
let counter = TokenCounter::new();
|
||||
|
||||
// Small message list shouldn't compact
|
||||
let small_messages: Vec<ChatMessage> = (0..10)
|
||||
.map(|i| ChatMessage::user(&format!("Message {}", i)))
|
||||
.collect();
|
||||
assert!(!counter.should_compact(&small_messages));
|
||||
|
||||
// Large message list should compact
|
||||
// Need ~162,000 tokens = ~648,000 chars (at 4 chars per token)
|
||||
let large_content = "x".repeat(700_000);
|
||||
let large_messages = vec![ChatMessage::user(&large_content)];
|
||||
assert!(counter.should_compact(&large_messages));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_compactor_needs_compaction() {
|
||||
let compactor = Compactor::new();
|
||||
|
||||
let small: Vec<ChatMessage> = (0..5)
|
||||
.map(|i| ChatMessage::user(&format!("Short message {}", i)))
|
||||
.collect();
|
||||
assert!(!compactor.needs_compaction(&small));
|
||||
}
|
||||
}
|
||||
557
crates/core/agent/src/git.rs
Normal file
557
crates/core/agent/src/git.rs
Normal file
@@ -0,0 +1,557 @@
|
||||
//! Git integration module for detecting repository state and validating git commands.
|
||||
//!
|
||||
//! This module provides functionality to:
|
||||
//! - Detect if the current directory is a git repository
|
||||
//! - Capture git repository state (branch, status, uncommitted changes)
|
||||
//! - Validate git commands for safety (read-only vs destructive operations)
|
||||
|
||||
use color_eyre::eyre::Result;
|
||||
use std::path::Path;
|
||||
use std::process::Command;
|
||||
|
||||
/// Status of a file in the git working tree
|
||||
#[derive(Debug, Clone, PartialEq, Eq)]
|
||||
pub enum GitFileStatus {
|
||||
/// File has been modified
|
||||
Modified { path: String },
|
||||
/// File has been added (staged)
|
||||
Added { path: String },
|
||||
/// File has been deleted
|
||||
Deleted { path: String },
|
||||
/// File has been renamed
|
||||
Renamed { from: String, to: String },
|
||||
/// File is untracked
|
||||
Untracked { path: String },
|
||||
}
|
||||
|
||||
impl GitFileStatus {
|
||||
/// Get the primary path associated with this status
|
||||
pub fn path(&self) -> &str {
|
||||
match self {
|
||||
Self::Modified { path } => path,
|
||||
Self::Added { path } => path,
|
||||
Self::Deleted { path } => path,
|
||||
Self::Renamed { to, .. } => to,
|
||||
Self::Untracked { path } => path,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Complete state of a git repository
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct GitState {
|
||||
/// Whether the current directory is in a git repository
|
||||
pub is_git_repo: bool,
|
||||
/// Current branch name (None if not in a repo or detached HEAD)
|
||||
pub current_branch: Option<String>,
|
||||
/// Main branch name (main/master, None if not detected)
|
||||
pub main_branch: Option<String>,
|
||||
/// Status of files in the working tree
|
||||
pub status: Vec<GitFileStatus>,
|
||||
/// Whether there are any uncommitted changes
|
||||
pub has_uncommitted_changes: bool,
|
||||
/// Remote URL for the repository (None if no remote configured)
|
||||
pub remote_url: Option<String>,
|
||||
}
|
||||
|
||||
impl GitState {
|
||||
/// Create a default GitState for non-git directories
|
||||
pub fn not_a_repo() -> Self {
|
||||
Self {
|
||||
is_git_repo: false,
|
||||
current_branch: None,
|
||||
main_branch: None,
|
||||
status: Vec::new(),
|
||||
has_uncommitted_changes: false,
|
||||
remote_url: None,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Detect the current git repository state
|
||||
///
|
||||
/// This function runs various git commands to gather information about the repository.
|
||||
/// If git is not available or the directory is not a git repo, returns a default state.
|
||||
pub fn detect_git_state(working_dir: &Path) -> Result<GitState> {
|
||||
// Check if this is a git repository
|
||||
let is_repo = Command::new("git")
|
||||
.arg("rev-parse")
|
||||
.arg("--git-dir")
|
||||
.current_dir(working_dir)
|
||||
.output()
|
||||
.map(|output| output.status.success())
|
||||
.unwrap_or(false);
|
||||
|
||||
if !is_repo {
|
||||
return Ok(GitState::not_a_repo());
|
||||
}
|
||||
|
||||
// Get current branch
|
||||
let current_branch = get_current_branch(working_dir)?;
|
||||
|
||||
// Detect main branch (try main first, then master)
|
||||
let main_branch = detect_main_branch(working_dir)?;
|
||||
|
||||
// Get file status
|
||||
let status = get_git_status(working_dir)?;
|
||||
|
||||
// Check if there are uncommitted changes
|
||||
let has_uncommitted_changes = !status.is_empty();
|
||||
|
||||
// Get remote URL
|
||||
let remote_url = get_remote_url(working_dir)?;
|
||||
|
||||
Ok(GitState {
|
||||
is_git_repo: true,
|
||||
current_branch,
|
||||
main_branch,
|
||||
status,
|
||||
has_uncommitted_changes,
|
||||
remote_url,
|
||||
})
|
||||
}
|
||||
|
||||
/// Get the current branch name
|
||||
fn get_current_branch(working_dir: &Path) -> Result<Option<String>> {
|
||||
let output = Command::new("git")
|
||||
.arg("rev-parse")
|
||||
.arg("--abbrev-ref")
|
||||
.arg("HEAD")
|
||||
.current_dir(working_dir)
|
||||
.output()?;
|
||||
|
||||
if !output.status.success() {
|
||||
return Ok(None);
|
||||
}
|
||||
|
||||
let branch = String::from_utf8_lossy(&output.stdout).trim().to_string();
|
||||
|
||||
// "HEAD" means detached HEAD state
|
||||
if branch == "HEAD" {
|
||||
Ok(None)
|
||||
} else {
|
||||
Ok(Some(branch))
|
||||
}
|
||||
}
|
||||
|
||||
/// Detect the main branch (main or master)
|
||||
fn detect_main_branch(working_dir: &Path) -> Result<Option<String>> {
|
||||
// Try to get all branches
|
||||
let output = Command::new("git")
|
||||
.arg("branch")
|
||||
.arg("-a")
|
||||
.current_dir(working_dir)
|
||||
.output()?;
|
||||
|
||||
if !output.status.success() {
|
||||
return Ok(None);
|
||||
}
|
||||
|
||||
let branches = String::from_utf8_lossy(&output.stdout);
|
||||
|
||||
// Check for main branch first (modern convention)
|
||||
if branches.lines().any(|line| {
|
||||
let trimmed = line.trim_start_matches('*').trim();
|
||||
trimmed == "main" || trimmed.ends_with("/main")
|
||||
}) {
|
||||
return Ok(Some("main".to_string()));
|
||||
}
|
||||
|
||||
// Fall back to master
|
||||
if branches.lines().any(|line| {
|
||||
let trimmed = line.trim_start_matches('*').trim();
|
||||
trimmed == "master" || trimmed.ends_with("/master")
|
||||
}) {
|
||||
return Ok(Some("master".to_string()));
|
||||
}
|
||||
|
||||
Ok(None)
|
||||
}
|
||||
|
||||
/// Get the git status for all files
|
||||
fn get_git_status(working_dir: &Path) -> Result<Vec<GitFileStatus>> {
|
||||
let output = Command::new("git")
|
||||
.arg("status")
|
||||
.arg("--porcelain")
|
||||
.arg("-z") // Null-terminated for better parsing
|
||||
.current_dir(working_dir)
|
||||
.output()?;
|
||||
|
||||
if !output.status.success() {
|
||||
return Ok(Vec::new());
|
||||
}
|
||||
|
||||
let status_text = String::from_utf8_lossy(&output.stdout);
|
||||
let mut statuses = Vec::new();
|
||||
|
||||
// Parse porcelain format with null termination
|
||||
// Format: XY filename\0 (where X is staged status, Y is unstaged status)
|
||||
for entry in status_text.split('\0').filter(|s| !s.is_empty()) {
|
||||
if entry.len() < 3 {
|
||||
continue;
|
||||
}
|
||||
|
||||
let status_code = &entry[0..2];
|
||||
let path = entry[3..].to_string();
|
||||
|
||||
// Parse status codes
|
||||
match status_code {
|
||||
"M " | " M" | "MM" => {
|
||||
statuses.push(GitFileStatus::Modified { path });
|
||||
}
|
||||
"A " | " A" | "AM" => {
|
||||
statuses.push(GitFileStatus::Added { path });
|
||||
}
|
||||
"D " | " D" | "AD" => {
|
||||
statuses.push(GitFileStatus::Deleted { path });
|
||||
}
|
||||
"??" => {
|
||||
statuses.push(GitFileStatus::Untracked { path });
|
||||
}
|
||||
s if s.starts_with('R') => {
|
||||
// Renamed files have format "R old_name -> new_name"
|
||||
if let Some((from, to)) = path.split_once(" -> ") {
|
||||
statuses.push(GitFileStatus::Renamed {
|
||||
from: from.to_string(),
|
||||
to: to.to_string(),
|
||||
});
|
||||
} else {
|
||||
// Fallback if parsing fails
|
||||
statuses.push(GitFileStatus::Modified { path });
|
||||
}
|
||||
}
|
||||
_ => {
|
||||
// Unknown status code, treat as modified
|
||||
statuses.push(GitFileStatus::Modified { path });
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Ok(statuses)
|
||||
}
|
||||
|
||||
/// Get the remote URL for the repository
|
||||
fn get_remote_url(working_dir: &Path) -> Result<Option<String>> {
|
||||
let output = Command::new("git")
|
||||
.arg("remote")
|
||||
.arg("get-url")
|
||||
.arg("origin")
|
||||
.current_dir(working_dir)
|
||||
.output()?;
|
||||
|
||||
if !output.status.success() {
|
||||
return Ok(None);
|
||||
}
|
||||
|
||||
let url = String::from_utf8_lossy(&output.stdout).trim().to_string();
|
||||
|
||||
if url.is_empty() {
|
||||
Ok(None)
|
||||
} else {
|
||||
Ok(Some(url))
|
||||
}
|
||||
}
|
||||
|
||||
/// Check if a git command is safe (read-only)
|
||||
///
|
||||
/// Safe commands include:
|
||||
/// - status, log, show, diff, branch (without -D)
|
||||
/// - remote (without add/remove)
|
||||
/// - config --get
|
||||
/// - rev-parse, ls-files, ls-tree
|
||||
pub fn is_safe_git_command(command: &str) -> bool {
|
||||
let parts: Vec<&str> = command.split_whitespace().collect();
|
||||
|
||||
if parts.is_empty() || parts[0] != "git" {
|
||||
return false;
|
||||
}
|
||||
|
||||
if parts.len() < 2 {
|
||||
return false;
|
||||
}
|
||||
|
||||
let subcommand = parts[1];
|
||||
|
||||
// List of read-only git commands
|
||||
match subcommand {
|
||||
"status" | "log" | "show" | "diff" | "blame" | "reflog" => true,
|
||||
"ls-files" | "ls-tree" | "ls-remote" => true,
|
||||
"rev-parse" | "rev-list" => true,
|
||||
"describe" | "tag" if !command.contains("-d") && !command.contains("--delete") => true,
|
||||
"branch" if !command.contains("-D") && !command.contains("-d") && !command.contains("-m") => true,
|
||||
"remote" if command.contains("get-url") || command.contains("-v") || command.contains("show") => true,
|
||||
"config" if command.contains("--get") || command.contains("--list") => true,
|
||||
"grep" | "shortlog" | "whatchanged" => true,
|
||||
"fetch" if !command.contains("--prune") => true,
|
||||
_ => false,
|
||||
}
|
||||
}
|
||||
|
||||
/// Check if a git command is destructive
|
||||
///
|
||||
/// Returns (is_destructive, warning_message) tuple.
|
||||
/// Destructive commands include:
|
||||
/// - push --force, reset --hard, clean -fd
|
||||
/// - rebase, amend, filter-branch
|
||||
/// - branch -D, tag -d
|
||||
pub fn is_destructive_git_command(command: &str) -> (bool, &'static str) {
|
||||
let cmd_lower = command.to_lowercase();
|
||||
|
||||
// Check for force push
|
||||
if cmd_lower.contains("push") && (cmd_lower.contains("--force") || cmd_lower.contains("-f")) {
|
||||
return (true, "Force push can overwrite remote history and affect other collaborators");
|
||||
}
|
||||
|
||||
// Check for hard reset
|
||||
if cmd_lower.contains("reset") && cmd_lower.contains("--hard") {
|
||||
return (true, "Hard reset will discard uncommitted changes permanently");
|
||||
}
|
||||
|
||||
// Check for git clean
|
||||
if cmd_lower.contains("clean") && (cmd_lower.contains("-f") || cmd_lower.contains("-d")) {
|
||||
return (true, "Git clean will permanently delete untracked files");
|
||||
}
|
||||
|
||||
// Check for rebase
|
||||
if cmd_lower.contains("rebase") {
|
||||
return (true, "Rebase rewrites commit history and can cause conflicts");
|
||||
}
|
||||
|
||||
// Check for amend
|
||||
if cmd_lower.contains("commit") && cmd_lower.contains("--amend") {
|
||||
return (true, "Amending rewrites the last commit and changes its hash");
|
||||
}
|
||||
|
||||
// Check for filter-branch or filter-repo
|
||||
if cmd_lower.contains("filter-branch") || cmd_lower.contains("filter-repo") {
|
||||
return (true, "Filter operations rewrite repository history");
|
||||
}
|
||||
|
||||
// Check for branch/tag deletion
|
||||
if (cmd_lower.contains("branch") && (cmd_lower.contains("-D") || cmd_lower.contains("-d")))
|
||||
|| (cmd_lower.contains("tag") && (cmd_lower.contains("-d") || cmd_lower.contains("--delete")))
|
||||
{
|
||||
return (true, "This will delete a branch or tag");
|
||||
}
|
||||
|
||||
// Check for reflog expire
|
||||
if cmd_lower.contains("reflog") && cmd_lower.contains("expire") {
|
||||
return (true, "Expiring reflog removes recovery points for lost commits");
|
||||
}
|
||||
|
||||
// Check for gc with aggressive or prune
|
||||
if cmd_lower.contains("gc") && (cmd_lower.contains("--aggressive") || cmd_lower.contains("--prune")) {
|
||||
return (true, "Aggressive garbage collection can make recovery difficult");
|
||||
}
|
||||
|
||||
(false, "")
|
||||
}
|
||||
|
||||
/// Format git state for human-readable display
|
||||
///
|
||||
/// Example output:
|
||||
/// ```text
|
||||
/// Git Repository: yes
|
||||
/// Current branch: feature-branch
|
||||
/// Main branch: main
|
||||
/// Status: 3 modified, 1 untracked
|
||||
/// Remote: https://github.com/user/repo.git
|
||||
/// ```
|
||||
pub fn format_git_status(state: &GitState) -> String {
|
||||
if !state.is_git_repo {
|
||||
return "Not a git repository".to_string();
|
||||
}
|
||||
|
||||
let mut lines = Vec::new();
|
||||
|
||||
lines.push("Git Repository: yes".to_string());
|
||||
|
||||
if let Some(branch) = &state.current_branch {
|
||||
lines.push(format!("Current branch: {}", branch));
|
||||
} else {
|
||||
lines.push("Current branch: (detached HEAD)".to_string());
|
||||
}
|
||||
|
||||
if let Some(main) = &state.main_branch {
|
||||
lines.push(format!("Main branch: {}", main));
|
||||
}
|
||||
|
||||
// Summarize status
|
||||
if state.status.is_empty() {
|
||||
lines.push("Status: clean working tree".to_string());
|
||||
} else {
|
||||
let mut modified = 0;
|
||||
let mut added = 0;
|
||||
let mut deleted = 0;
|
||||
let mut renamed = 0;
|
||||
let mut untracked = 0;
|
||||
|
||||
for status in &state.status {
|
||||
match status {
|
||||
GitFileStatus::Modified { .. } => modified += 1,
|
||||
GitFileStatus::Added { .. } => added += 1,
|
||||
GitFileStatus::Deleted { .. } => deleted += 1,
|
||||
GitFileStatus::Renamed { .. } => renamed += 1,
|
||||
GitFileStatus::Untracked { .. } => untracked += 1,
|
||||
}
|
||||
}
|
||||
|
||||
let mut status_parts = Vec::new();
|
||||
if modified > 0 {
|
||||
status_parts.push(format!("{} modified", modified));
|
||||
}
|
||||
if added > 0 {
|
||||
status_parts.push(format!("{} added", added));
|
||||
}
|
||||
if deleted > 0 {
|
||||
status_parts.push(format!("{} deleted", deleted));
|
||||
}
|
||||
if renamed > 0 {
|
||||
status_parts.push(format!("{} renamed", renamed));
|
||||
}
|
||||
if untracked > 0 {
|
||||
status_parts.push(format!("{} untracked", untracked));
|
||||
}
|
||||
|
||||
lines.push(format!("Status: {}", status_parts.join(", ")));
|
||||
}
|
||||
|
||||
if let Some(url) = &state.remote_url {
|
||||
lines.push(format!("Remote: {}", url));
|
||||
} else {
|
||||
lines.push("Remote: (none)".to_string());
|
||||
}
|
||||
|
||||
lines.join("\n")
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn test_is_safe_git_command() {
|
||||
// Safe commands
|
||||
assert!(is_safe_git_command("git status"));
|
||||
assert!(is_safe_git_command("git log --oneline"));
|
||||
assert!(is_safe_git_command("git diff HEAD"));
|
||||
assert!(is_safe_git_command("git branch -v"));
|
||||
assert!(is_safe_git_command("git remote -v"));
|
||||
assert!(is_safe_git_command("git config --get user.name"));
|
||||
|
||||
// Unsafe commands
|
||||
assert!(!is_safe_git_command("git commit -m test"));
|
||||
assert!(!is_safe_git_command("git push origin main"));
|
||||
assert!(!is_safe_git_command("git branch -D feature"));
|
||||
assert!(!is_safe_git_command("git remote add origin url"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_is_destructive_git_command() {
|
||||
// Destructive commands
|
||||
let (is_dest, msg) = is_destructive_git_command("git push --force origin main");
|
||||
assert!(is_dest);
|
||||
assert!(msg.contains("Force push"));
|
||||
|
||||
let (is_dest, msg) = is_destructive_git_command("git reset --hard HEAD~1");
|
||||
assert!(is_dest);
|
||||
assert!(msg.contains("Hard reset"));
|
||||
|
||||
let (is_dest, msg) = is_destructive_git_command("git clean -fd");
|
||||
assert!(is_dest);
|
||||
assert!(msg.contains("clean"));
|
||||
|
||||
let (is_dest, msg) = is_destructive_git_command("git rebase main");
|
||||
assert!(is_dest);
|
||||
assert!(msg.contains("Rebase"));
|
||||
|
||||
let (is_dest, msg) = is_destructive_git_command("git commit --amend");
|
||||
assert!(is_dest);
|
||||
assert!(msg.contains("Amending"));
|
||||
|
||||
// Non-destructive commands
|
||||
let (is_dest, _) = is_destructive_git_command("git status");
|
||||
assert!(!is_dest);
|
||||
|
||||
let (is_dest, _) = is_destructive_git_command("git log");
|
||||
assert!(!is_dest);
|
||||
|
||||
let (is_dest, _) = is_destructive_git_command("git diff");
|
||||
assert!(!is_dest);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_git_state_not_a_repo() {
|
||||
let state = GitState::not_a_repo();
|
||||
assert!(!state.is_git_repo);
|
||||
assert!(state.current_branch.is_none());
|
||||
assert!(state.main_branch.is_none());
|
||||
assert!(state.status.is_empty());
|
||||
assert!(!state.has_uncommitted_changes);
|
||||
assert!(state.remote_url.is_none());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_git_file_status_path() {
|
||||
let status = GitFileStatus::Modified {
|
||||
path: "test.rs".to_string(),
|
||||
};
|
||||
assert_eq!(status.path(), "test.rs");
|
||||
|
||||
let status = GitFileStatus::Renamed {
|
||||
from: "old.rs".to_string(),
|
||||
to: "new.rs".to_string(),
|
||||
};
|
||||
assert_eq!(status.path(), "new.rs");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_format_git_status_not_repo() {
|
||||
let state = GitState::not_a_repo();
|
||||
let formatted = format_git_status(&state);
|
||||
assert_eq!(formatted, "Not a git repository");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_format_git_status_clean() {
|
||||
let state = GitState {
|
||||
is_git_repo: true,
|
||||
current_branch: Some("main".to_string()),
|
||||
main_branch: Some("main".to_string()),
|
||||
status: Vec::new(),
|
||||
has_uncommitted_changes: false,
|
||||
remote_url: Some("https://github.com/user/repo.git".to_string()),
|
||||
};
|
||||
|
||||
let formatted = format_git_status(&state);
|
||||
assert!(formatted.contains("Git Repository: yes"));
|
||||
assert!(formatted.contains("Current branch: main"));
|
||||
assert!(formatted.contains("clean working tree"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_format_git_status_with_changes() {
|
||||
let state = GitState {
|
||||
is_git_repo: true,
|
||||
current_branch: Some("feature".to_string()),
|
||||
main_branch: Some("main".to_string()),
|
||||
status: vec![
|
||||
GitFileStatus::Modified {
|
||||
path: "file1.rs".to_string(),
|
||||
},
|
||||
GitFileStatus::Modified {
|
||||
path: "file2.rs".to_string(),
|
||||
},
|
||||
GitFileStatus::Untracked {
|
||||
path: "new.rs".to_string(),
|
||||
},
|
||||
],
|
||||
has_uncommitted_changes: true,
|
||||
remote_url: None,
|
||||
};
|
||||
|
||||
let formatted = format_git_status(&state);
|
||||
assert!(formatted.contains("2 modified"));
|
||||
assert!(formatted.contains("1 untracked"));
|
||||
}
|
||||
}
|
||||
1130
crates/core/agent/src/lib.rs
Normal file
1130
crates/core/agent/src/lib.rs
Normal file
File diff suppressed because it is too large
Load Diff
295
crates/core/agent/src/session.rs
Normal file
295
crates/core/agent/src/session.rs
Normal file
@@ -0,0 +1,295 @@
|
||||
use color_eyre::eyre::{Result, eyre};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::collections::HashMap;
|
||||
use std::fs;
|
||||
use std::path::{Path, PathBuf};
|
||||
use std::time::{Duration, SystemTime};
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct SessionStats {
|
||||
pub start_time: SystemTime,
|
||||
pub total_messages: usize,
|
||||
pub total_tool_calls: usize,
|
||||
pub total_duration: Duration,
|
||||
pub estimated_tokens: usize,
|
||||
}
|
||||
|
||||
impl SessionStats {
|
||||
pub fn new() -> Self {
|
||||
Self {
|
||||
start_time: SystemTime::now(),
|
||||
total_messages: 0,
|
||||
total_tool_calls: 0,
|
||||
total_duration: Duration::ZERO,
|
||||
estimated_tokens: 0,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn record_message(&mut self, tokens: usize, duration: Duration) {
|
||||
self.total_messages += 1;
|
||||
self.estimated_tokens += tokens;
|
||||
self.total_duration += duration;
|
||||
}
|
||||
|
||||
pub fn record_tool_call(&mut self) {
|
||||
self.total_tool_calls += 1;
|
||||
}
|
||||
|
||||
pub fn format_duration(d: Duration) -> String {
|
||||
let secs = d.as_secs();
|
||||
if secs < 60 {
|
||||
format!("{}s", secs)
|
||||
} else if secs < 3600 {
|
||||
format!("{}m {}s", secs / 60, secs % 60)
|
||||
} else {
|
||||
format!("{}h {}m", secs / 3600, (secs % 3600) / 60)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for SessionStats {
|
||||
fn default() -> Self {
|
||||
Self::new()
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct SessionHistory {
|
||||
pub user_prompts: Vec<String>,
|
||||
pub assistant_responses: Vec<String>,
|
||||
pub tool_calls: Vec<ToolCallRecord>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct ToolCallRecord {
|
||||
pub tool_name: String,
|
||||
pub arguments: String,
|
||||
pub result: String,
|
||||
pub success: bool,
|
||||
}
|
||||
|
||||
impl SessionHistory {
|
||||
pub fn new() -> Self {
|
||||
Self {
|
||||
user_prompts: Vec::new(),
|
||||
assistant_responses: Vec::new(),
|
||||
tool_calls: Vec::new(),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn add_user_message(&mut self, message: String) {
|
||||
self.user_prompts.push(message);
|
||||
}
|
||||
|
||||
pub fn add_assistant_message(&mut self, message: String) {
|
||||
self.assistant_responses.push(message);
|
||||
}
|
||||
|
||||
pub fn add_tool_call(&mut self, record: ToolCallRecord) {
|
||||
self.tool_calls.push(record);
|
||||
}
|
||||
|
||||
pub fn clear(&mut self) {
|
||||
self.user_prompts.clear();
|
||||
self.assistant_responses.clear();
|
||||
self.tool_calls.clear();
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for SessionHistory {
|
||||
fn default() -> Self {
|
||||
Self::new()
|
||||
}
|
||||
}
|
||||
|
||||
/// Represents a file modification with before/after content
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct FileDiff {
|
||||
pub path: PathBuf,
|
||||
pub before: String,
|
||||
pub after: String,
|
||||
pub timestamp: SystemTime,
|
||||
}
|
||||
|
||||
impl FileDiff {
|
||||
/// Create a new file diff
|
||||
pub fn new(path: PathBuf, before: String, after: String) -> Self {
|
||||
Self {
|
||||
path,
|
||||
before,
|
||||
after,
|
||||
timestamp: SystemTime::now(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// A checkpoint captures the state of a session at a point in time
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct Checkpoint {
|
||||
pub id: String,
|
||||
pub timestamp: SystemTime,
|
||||
pub stats: SessionStats,
|
||||
pub user_prompts: Vec<String>,
|
||||
pub assistant_responses: Vec<String>,
|
||||
pub tool_calls: Vec<ToolCallRecord>,
|
||||
pub file_diffs: Vec<FileDiff>,
|
||||
}
|
||||
|
||||
impl Checkpoint {
|
||||
/// Create a new checkpoint from current session state
|
||||
pub fn new(
|
||||
id: String,
|
||||
stats: SessionStats,
|
||||
history: &SessionHistory,
|
||||
file_diffs: Vec<FileDiff>,
|
||||
) -> Self {
|
||||
Self {
|
||||
id,
|
||||
timestamp: SystemTime::now(),
|
||||
stats,
|
||||
user_prompts: history.user_prompts.clone(),
|
||||
assistant_responses: history.assistant_responses.clone(),
|
||||
tool_calls: history.tool_calls.clone(),
|
||||
file_diffs,
|
||||
}
|
||||
}
|
||||
|
||||
/// Save checkpoint to disk
|
||||
pub fn save(&self, checkpoint_dir: &Path) -> Result<()> {
|
||||
fs::create_dir_all(checkpoint_dir)?;
|
||||
let path = checkpoint_dir.join(format!("{}.json", self.id));
|
||||
let content = serde_json::to_string_pretty(self)?;
|
||||
fs::write(path, content)?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Load checkpoint from disk
|
||||
pub fn load(checkpoint_dir: &Path, id: &str) -> Result<Self> {
|
||||
let path = checkpoint_dir.join(format!("{}.json", id));
|
||||
let content = fs::read_to_string(&path)
|
||||
.map_err(|e| eyre!("Failed to read checkpoint: {}", e))?;
|
||||
let checkpoint: Checkpoint = serde_json::from_str(&content)
|
||||
.map_err(|e| eyre!("Failed to parse checkpoint: {}", e))?;
|
||||
Ok(checkpoint)
|
||||
}
|
||||
|
||||
/// List all available checkpoints in a directory
|
||||
pub fn list(checkpoint_dir: &Path) -> Result<Vec<String>> {
|
||||
if !checkpoint_dir.exists() {
|
||||
return Ok(Vec::new());
|
||||
}
|
||||
|
||||
let mut checkpoints = Vec::new();
|
||||
for entry in fs::read_dir(checkpoint_dir)? {
|
||||
let entry = entry?;
|
||||
let path = entry.path();
|
||||
if path.extension().and_then(|s| s.to_str()) == Some("json") {
|
||||
if let Some(stem) = path.file_stem().and_then(|s| s.to_str()) {
|
||||
checkpoints.push(stem.to_string());
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Sort by checkpoint ID (which includes timestamp)
|
||||
checkpoints.sort();
|
||||
Ok(checkpoints)
|
||||
}
|
||||
}
|
||||
|
||||
/// Session checkpoint manager
|
||||
pub struct CheckpointManager {
|
||||
checkpoint_dir: PathBuf,
|
||||
file_snapshots: HashMap<PathBuf, String>,
|
||||
}
|
||||
|
||||
impl CheckpointManager {
|
||||
/// Create a new checkpoint manager
|
||||
pub fn new(checkpoint_dir: PathBuf) -> Self {
|
||||
Self {
|
||||
checkpoint_dir,
|
||||
file_snapshots: HashMap::new(),
|
||||
}
|
||||
}
|
||||
|
||||
/// Snapshot a file's current content before modification
|
||||
pub fn snapshot_file(&mut self, path: &Path) -> Result<()> {
|
||||
if !self.file_snapshots.contains_key(path) {
|
||||
let content = fs::read_to_string(path).unwrap_or_default();
|
||||
self.file_snapshots.insert(path.to_path_buf(), content);
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Create a file diff after modification
|
||||
pub fn create_diff(&self, path: &Path) -> Result<Option<FileDiff>> {
|
||||
if let Some(before) = self.file_snapshots.get(path) {
|
||||
let after = fs::read_to_string(path).unwrap_or_default();
|
||||
if before != &after {
|
||||
Ok(Some(FileDiff::new(
|
||||
path.to_path_buf(),
|
||||
before.clone(),
|
||||
after,
|
||||
)))
|
||||
} else {
|
||||
Ok(None)
|
||||
}
|
||||
} else {
|
||||
Ok(None)
|
||||
}
|
||||
}
|
||||
|
||||
/// Get all file diffs since last checkpoint
|
||||
pub fn get_all_diffs(&self) -> Result<Vec<FileDiff>> {
|
||||
let mut diffs = Vec::new();
|
||||
for (path, before) in &self.file_snapshots {
|
||||
let after = fs::read_to_string(path).unwrap_or_default();
|
||||
if before != &after {
|
||||
diffs.push(FileDiff::new(path.clone(), before.clone(), after));
|
||||
}
|
||||
}
|
||||
Ok(diffs)
|
||||
}
|
||||
|
||||
/// Clear file snapshots
|
||||
pub fn clear_snapshots(&mut self) {
|
||||
self.file_snapshots.clear();
|
||||
}
|
||||
|
||||
/// Save a checkpoint
|
||||
pub fn save_checkpoint(
|
||||
&mut self,
|
||||
id: String,
|
||||
stats: SessionStats,
|
||||
history: &SessionHistory,
|
||||
) -> Result<Checkpoint> {
|
||||
let file_diffs = self.get_all_diffs()?;
|
||||
let checkpoint = Checkpoint::new(id, stats, history, file_diffs);
|
||||
checkpoint.save(&self.checkpoint_dir)?;
|
||||
self.clear_snapshots();
|
||||
Ok(checkpoint)
|
||||
}
|
||||
|
||||
/// Load a checkpoint
|
||||
pub fn load_checkpoint(&self, id: &str) -> Result<Checkpoint> {
|
||||
Checkpoint::load(&self.checkpoint_dir, id)
|
||||
}
|
||||
|
||||
/// List all checkpoints
|
||||
pub fn list_checkpoints(&self) -> Result<Vec<String>> {
|
||||
Checkpoint::list(&self.checkpoint_dir)
|
||||
}
|
||||
|
||||
/// Rewind to a checkpoint by restoring file contents
|
||||
pub fn rewind_to(&self, checkpoint_id: &str) -> Result<Vec<PathBuf>> {
|
||||
let checkpoint = self.load_checkpoint(checkpoint_id)?;
|
||||
let mut restored_files = Vec::new();
|
||||
|
||||
// Restore files from diffs (revert to 'before' state)
|
||||
for diff in &checkpoint.file_diffs {
|
||||
fs::write(&diff.path, &diff.before)?;
|
||||
restored_files.push(diff.path.clone());
|
||||
}
|
||||
|
||||
Ok(restored_files)
|
||||
}
|
||||
}
|
||||
266
crates/core/agent/src/system_prompt.rs
Normal file
266
crates/core/agent/src/system_prompt.rs
Normal file
@@ -0,0 +1,266 @@
|
||||
//! System Prompt Management
|
||||
//!
|
||||
//! Composes system prompts from multiple sources for agent sessions.
|
||||
|
||||
use std::path::Path;
|
||||
|
||||
/// Builder for composing system prompts
|
||||
#[derive(Debug, Clone, Default)]
|
||||
pub struct SystemPromptBuilder {
|
||||
sections: Vec<PromptSection>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
struct PromptSection {
|
||||
name: String,
|
||||
content: String,
|
||||
priority: i32, // Lower = earlier in prompt
|
||||
}
|
||||
|
||||
impl SystemPromptBuilder {
|
||||
pub fn new() -> Self {
|
||||
Self::default()
|
||||
}
|
||||
|
||||
/// Add the base agent prompt
|
||||
pub fn with_base_prompt(mut self, content: impl Into<String>) -> Self {
|
||||
self.sections.push(PromptSection {
|
||||
name: "base".to_string(),
|
||||
content: content.into(),
|
||||
priority: 0,
|
||||
});
|
||||
self
|
||||
}
|
||||
|
||||
/// Add tool usage instructions
|
||||
pub fn with_tool_instructions(mut self, content: impl Into<String>) -> Self {
|
||||
self.sections.push(PromptSection {
|
||||
name: "tools".to_string(),
|
||||
content: content.into(),
|
||||
priority: 10,
|
||||
});
|
||||
self
|
||||
}
|
||||
|
||||
/// Load and add project instructions from CLAUDE.md or .owlen.md
|
||||
pub fn with_project_instructions(mut self, project_root: &Path) -> Self {
|
||||
// Try CLAUDE.md first (Claude Code compatibility)
|
||||
let claude_md = project_root.join("CLAUDE.md");
|
||||
if claude_md.exists() {
|
||||
if let Ok(content) = std::fs::read_to_string(&claude_md) {
|
||||
self.sections.push(PromptSection {
|
||||
name: "project".to_string(),
|
||||
content: format!("# Project Instructions\n\n{}", content),
|
||||
priority: 20,
|
||||
});
|
||||
return self;
|
||||
}
|
||||
}
|
||||
|
||||
// Fallback to .owlen.md
|
||||
let owlen_md = project_root.join(".owlen.md");
|
||||
if owlen_md.exists() {
|
||||
if let Ok(content) = std::fs::read_to_string(&owlen_md) {
|
||||
self.sections.push(PromptSection {
|
||||
name: "project".to_string(),
|
||||
content: format!("# Project Instructions\n\n{}", content),
|
||||
priority: 20,
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
self
|
||||
}
|
||||
|
||||
/// Add skill content
|
||||
pub fn with_skill(mut self, skill_name: &str, content: impl Into<String>) -> Self {
|
||||
self.sections.push(PromptSection {
|
||||
name: format!("skill:{}", skill_name),
|
||||
content: content.into(),
|
||||
priority: 30,
|
||||
});
|
||||
self
|
||||
}
|
||||
|
||||
/// Add hook-injected content (from SessionStart hooks)
|
||||
pub fn with_hook_injection(mut self, content: impl Into<String>) -> Self {
|
||||
self.sections.push(PromptSection {
|
||||
name: "hook".to_string(),
|
||||
content: content.into(),
|
||||
priority: 40,
|
||||
});
|
||||
self
|
||||
}
|
||||
|
||||
/// Add custom section
|
||||
pub fn with_section(mut self, name: impl Into<String>, content: impl Into<String>, priority: i32) -> Self {
|
||||
self.sections.push(PromptSection {
|
||||
name: name.into(),
|
||||
content: content.into(),
|
||||
priority,
|
||||
});
|
||||
self
|
||||
}
|
||||
|
||||
/// Build the final system prompt
|
||||
pub fn build(mut self) -> String {
|
||||
// Sort by priority
|
||||
self.sections.sort_by_key(|s| s.priority);
|
||||
|
||||
// Join sections with separators
|
||||
self.sections
|
||||
.iter()
|
||||
.map(|s| s.content.as_str())
|
||||
.collect::<Vec<_>>()
|
||||
.join("\n\n---\n\n")
|
||||
}
|
||||
|
||||
/// Check if any content has been added
|
||||
pub fn is_empty(&self) -> bool {
|
||||
self.sections.is_empty()
|
||||
}
|
||||
}
|
||||
|
||||
/// Default base prompt for Owlen agent
|
||||
pub fn default_base_prompt() -> &'static str {
|
||||
r#"You are Owlen, an AI assistant that helps with software engineering tasks.
|
||||
|
||||
You have access to tools for reading files, writing code, running commands, and searching the web.
|
||||
|
||||
## Guidelines
|
||||
|
||||
1. Be direct and concise in your responses
|
||||
2. Use tools to gather information before making changes
|
||||
3. Explain your reasoning when making decisions
|
||||
4. Ask for clarification when requirements are unclear
|
||||
5. Prefer editing existing files over creating new ones
|
||||
|
||||
## Tool Usage
|
||||
|
||||
- Use `read` to examine file contents before editing
|
||||
- Use `glob` and `grep` to find relevant files
|
||||
- Use `edit` for precise changes, `write` for new files
|
||||
- Use `bash` for running tests and commands
|
||||
- Use `web_search` for current information"#
|
||||
}
|
||||
|
||||
/// Generate tool instructions based on available tools
|
||||
pub fn generate_tool_instructions(tool_names: &[&str]) -> String {
|
||||
let mut instructions = String::from("## Available Tools\n\n");
|
||||
|
||||
for name in tool_names {
|
||||
let desc = match *name {
|
||||
"read" => "Read file contents",
|
||||
"write" => "Create or overwrite a file",
|
||||
"edit" => "Edit a file by replacing text",
|
||||
"multi_edit" => "Apply multiple edits atomically",
|
||||
"glob" => "Find files by pattern",
|
||||
"grep" => "Search file contents",
|
||||
"ls" => "List directory contents",
|
||||
"bash" => "Execute shell commands",
|
||||
"web_search" => "Search the web",
|
||||
"web_fetch" => "Fetch a URL",
|
||||
"todo_write" => "Update task list",
|
||||
"ask_user" => "Ask user a question",
|
||||
_ => continue,
|
||||
};
|
||||
instructions.push_str(&format!("- `{}`: {}\n", name, desc));
|
||||
}
|
||||
|
||||
instructions
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn test_builder() {
|
||||
let prompt = SystemPromptBuilder::new()
|
||||
.with_base_prompt("You are helpful")
|
||||
.with_tool_instructions("Use tools wisely")
|
||||
.build();
|
||||
|
||||
assert!(prompt.contains("You are helpful"));
|
||||
assert!(prompt.contains("Use tools wisely"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_priority_ordering() {
|
||||
let prompt = SystemPromptBuilder::new()
|
||||
.with_section("last", "Third", 100)
|
||||
.with_section("first", "First", 0)
|
||||
.with_section("middle", "Second", 50)
|
||||
.build();
|
||||
|
||||
let first_pos = prompt.find("First").unwrap();
|
||||
let second_pos = prompt.find("Second").unwrap();
|
||||
let third_pos = prompt.find("Third").unwrap();
|
||||
|
||||
assert!(first_pos < second_pos);
|
||||
assert!(second_pos < third_pos);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_default_base_prompt() {
|
||||
let prompt = default_base_prompt();
|
||||
assert!(prompt.contains("Owlen"));
|
||||
assert!(prompt.contains("Guidelines"));
|
||||
assert!(prompt.contains("Tool Usage"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_generate_tool_instructions() {
|
||||
let tools = vec!["read", "write", "edit", "bash"];
|
||||
let instructions = generate_tool_instructions(&tools);
|
||||
|
||||
assert!(instructions.contains("Available Tools"));
|
||||
assert!(instructions.contains("read"));
|
||||
assert!(instructions.contains("write"));
|
||||
assert!(instructions.contains("edit"));
|
||||
assert!(instructions.contains("bash"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_builder_empty() {
|
||||
let builder = SystemPromptBuilder::new();
|
||||
assert!(builder.is_empty());
|
||||
|
||||
let builder = builder.with_base_prompt("test");
|
||||
assert!(!builder.is_empty());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_skill_section() {
|
||||
let prompt = SystemPromptBuilder::new()
|
||||
.with_base_prompt("Base")
|
||||
.with_skill("rust", "Rust expertise")
|
||||
.build();
|
||||
|
||||
assert!(prompt.contains("Base"));
|
||||
assert!(prompt.contains("Rust expertise"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_hook_injection() {
|
||||
let prompt = SystemPromptBuilder::new()
|
||||
.with_base_prompt("Base")
|
||||
.with_hook_injection("Additional context from hook")
|
||||
.build();
|
||||
|
||||
assert!(prompt.contains("Base"));
|
||||
assert!(prompt.contains("Additional context from hook"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_separator_between_sections() {
|
||||
let prompt = SystemPromptBuilder::new()
|
||||
.with_section("first", "First section", 0)
|
||||
.with_section("second", "Second section", 10)
|
||||
.build();
|
||||
|
||||
assert!(prompt.contains("---"));
|
||||
assert!(prompt.contains("First section"));
|
||||
assert!(prompt.contains("Second section"));
|
||||
}
|
||||
}
|
||||
210
crates/core/agent/tests/checkpoint.rs
Normal file
210
crates/core/agent/tests/checkpoint.rs
Normal file
@@ -0,0 +1,210 @@
|
||||
use agent_core::{Checkpoint, CheckpointManager, FileDiff, SessionHistory, SessionStats};
|
||||
use std::fs;
|
||||
use std::path::PathBuf;
|
||||
use tempfile::TempDir;
|
||||
|
||||
#[test]
|
||||
fn test_checkpoint_save_and_load() {
|
||||
let temp_dir = TempDir::new().unwrap();
|
||||
let checkpoint_dir = temp_dir.path().to_path_buf();
|
||||
|
||||
let stats = SessionStats::new();
|
||||
let mut history = SessionHistory::new();
|
||||
history.add_user_message("Hello".to_string());
|
||||
history.add_assistant_message("Hi there!".to_string());
|
||||
|
||||
let file_diffs = vec![FileDiff::new(
|
||||
PathBuf::from("test.txt"),
|
||||
"before".to_string(),
|
||||
"after".to_string(),
|
||||
)];
|
||||
|
||||
let checkpoint = Checkpoint::new(
|
||||
"test-checkpoint".to_string(),
|
||||
stats.clone(),
|
||||
&history,
|
||||
file_diffs,
|
||||
);
|
||||
|
||||
// Save checkpoint
|
||||
checkpoint.save(&checkpoint_dir).unwrap();
|
||||
|
||||
// Load checkpoint
|
||||
let loaded = Checkpoint::load(&checkpoint_dir, "test-checkpoint").unwrap();
|
||||
|
||||
assert_eq!(loaded.id, "test-checkpoint");
|
||||
assert_eq!(loaded.user_prompts, vec!["Hello"]);
|
||||
assert_eq!(loaded.assistant_responses, vec!["Hi there!"]);
|
||||
assert_eq!(loaded.file_diffs.len(), 1);
|
||||
assert_eq!(loaded.file_diffs[0].path, PathBuf::from("test.txt"));
|
||||
assert_eq!(loaded.file_diffs[0].before, "before");
|
||||
assert_eq!(loaded.file_diffs[0].after, "after");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_checkpoint_list() {
|
||||
let temp_dir = TempDir::new().unwrap();
|
||||
let checkpoint_dir = temp_dir.path().to_path_buf();
|
||||
|
||||
// Create a few checkpoints
|
||||
for i in 1..=3 {
|
||||
let checkpoint = Checkpoint::new(
|
||||
format!("checkpoint-{}", i),
|
||||
SessionStats::new(),
|
||||
&SessionHistory::new(),
|
||||
vec![],
|
||||
);
|
||||
checkpoint.save(&checkpoint_dir).unwrap();
|
||||
}
|
||||
|
||||
let checkpoints = Checkpoint::list(&checkpoint_dir).unwrap();
|
||||
assert_eq!(checkpoints.len(), 3);
|
||||
assert!(checkpoints.contains(&"checkpoint-1".to_string()));
|
||||
assert!(checkpoints.contains(&"checkpoint-2".to_string()));
|
||||
assert!(checkpoints.contains(&"checkpoint-3".to_string()));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_checkpoint_manager_snapshot_and_diff() {
|
||||
let temp_dir = TempDir::new().unwrap();
|
||||
let checkpoint_dir = temp_dir.path().join("checkpoints");
|
||||
let test_file = temp_dir.path().join("test.txt");
|
||||
|
||||
// Create initial file content
|
||||
fs::write(&test_file, "initial content").unwrap();
|
||||
|
||||
let mut manager = CheckpointManager::new(checkpoint_dir.clone());
|
||||
|
||||
// Snapshot the file
|
||||
manager.snapshot_file(&test_file).unwrap();
|
||||
|
||||
// Modify the file
|
||||
fs::write(&test_file, "modified content").unwrap();
|
||||
|
||||
// Create a diff
|
||||
let diff = manager.create_diff(&test_file).unwrap();
|
||||
assert!(diff.is_some());
|
||||
|
||||
let diff = diff.unwrap();
|
||||
assert_eq!(diff.path, test_file);
|
||||
assert_eq!(diff.before, "initial content");
|
||||
assert_eq!(diff.after, "modified content");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_checkpoint_manager_save_and_restore() {
|
||||
let temp_dir = TempDir::new().unwrap();
|
||||
let checkpoint_dir = temp_dir.path().join("checkpoints");
|
||||
let test_file = temp_dir.path().join("test.txt");
|
||||
|
||||
// Create initial file content
|
||||
fs::write(&test_file, "initial content").unwrap();
|
||||
|
||||
let mut manager = CheckpointManager::new(checkpoint_dir.clone());
|
||||
|
||||
// Snapshot the file
|
||||
manager.snapshot_file(&test_file).unwrap();
|
||||
|
||||
// Modify the file
|
||||
fs::write(&test_file, "modified content").unwrap();
|
||||
|
||||
// Save checkpoint
|
||||
let mut history = SessionHistory::new();
|
||||
history.add_user_message("test".to_string());
|
||||
let checkpoint = manager
|
||||
.save_checkpoint("test-checkpoint".to_string(), SessionStats::new(), &history)
|
||||
.unwrap();
|
||||
|
||||
assert_eq!(checkpoint.file_diffs.len(), 1);
|
||||
assert_eq!(checkpoint.file_diffs[0].before, "initial content");
|
||||
assert_eq!(checkpoint.file_diffs[0].after, "modified content");
|
||||
|
||||
// Modify file again
|
||||
fs::write(&test_file, "final content").unwrap();
|
||||
assert_eq!(fs::read_to_string(&test_file).unwrap(), "final content");
|
||||
|
||||
// Rewind to checkpoint
|
||||
let restored_files = manager.rewind_to("test-checkpoint").unwrap();
|
||||
assert_eq!(restored_files.len(), 1);
|
||||
assert_eq!(restored_files[0], test_file);
|
||||
|
||||
// File should be reverted to initial content (before the checkpoint)
|
||||
assert_eq!(fs::read_to_string(&test_file).unwrap(), "initial content");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_checkpoint_manager_multiple_files() {
|
||||
let temp_dir = TempDir::new().unwrap();
|
||||
let checkpoint_dir = temp_dir.path().join("checkpoints");
|
||||
let test_file1 = temp_dir.path().join("file1.txt");
|
||||
let test_file2 = temp_dir.path().join("file2.txt");
|
||||
|
||||
// Create initial files
|
||||
fs::write(&test_file1, "file1 initial").unwrap();
|
||||
fs::write(&test_file2, "file2 initial").unwrap();
|
||||
|
||||
let mut manager = CheckpointManager::new(checkpoint_dir.clone());
|
||||
|
||||
// Snapshot both files
|
||||
manager.snapshot_file(&test_file1).unwrap();
|
||||
manager.snapshot_file(&test_file2).unwrap();
|
||||
|
||||
// Modify both files
|
||||
fs::write(&test_file1, "file1 modified").unwrap();
|
||||
fs::write(&test_file2, "file2 modified").unwrap();
|
||||
|
||||
// Save checkpoint
|
||||
let checkpoint = manager
|
||||
.save_checkpoint(
|
||||
"multi-file-checkpoint".to_string(),
|
||||
SessionStats::new(),
|
||||
&SessionHistory::new(),
|
||||
)
|
||||
.unwrap();
|
||||
|
||||
assert_eq!(checkpoint.file_diffs.len(), 2);
|
||||
|
||||
// Modify files again
|
||||
fs::write(&test_file1, "file1 final").unwrap();
|
||||
fs::write(&test_file2, "file2 final").unwrap();
|
||||
|
||||
// Rewind
|
||||
let restored_files = manager.rewind_to("multi-file-checkpoint").unwrap();
|
||||
assert_eq!(restored_files.len(), 2);
|
||||
|
||||
// Both files should be reverted
|
||||
assert_eq!(fs::read_to_string(&test_file1).unwrap(), "file1 initial");
|
||||
assert_eq!(fs::read_to_string(&test_file2).unwrap(), "file2 initial");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_checkpoint_no_changes() {
|
||||
let temp_dir = TempDir::new().unwrap();
|
||||
let checkpoint_dir = temp_dir.path().join("checkpoints");
|
||||
let test_file = temp_dir.path().join("test.txt");
|
||||
|
||||
// Create file
|
||||
fs::write(&test_file, "content").unwrap();
|
||||
|
||||
let mut manager = CheckpointManager::new(checkpoint_dir.clone());
|
||||
|
||||
// Snapshot the file
|
||||
manager.snapshot_file(&test_file).unwrap();
|
||||
|
||||
// Don't modify the file
|
||||
|
||||
// Create diff - should be None because nothing changed
|
||||
let diff = manager.create_diff(&test_file).unwrap();
|
||||
assert!(diff.is_none());
|
||||
|
||||
// Save checkpoint - should have no diffs
|
||||
let checkpoint = manager
|
||||
.save_checkpoint(
|
||||
"no-change-checkpoint".to_string(),
|
||||
SessionStats::new(),
|
||||
&SessionHistory::new(),
|
||||
)
|
||||
.unwrap();
|
||||
|
||||
assert_eq!(checkpoint.file_diffs.len(), 0);
|
||||
}
|
||||
276
crates/core/agent/tests/streaming.rs
Normal file
276
crates/core/agent/tests/streaming.rs
Normal file
@@ -0,0 +1,276 @@
|
||||
use agent_core::{create_event_channel, run_agent_loop_streaming, AgentEvent, ToolContext};
|
||||
use async_trait::async_trait;
|
||||
use futures_util::stream;
|
||||
use llm_core::{
|
||||
ChatMessage, ChatOptions, LlmError, StreamChunk, LlmProvider, Tool, ToolCallDelta,
|
||||
};
|
||||
use permissions::{Mode, PermissionManager};
|
||||
use std::pin::Pin;
|
||||
|
||||
/// Mock LLM provider for testing streaming
|
||||
struct MockStreamingProvider {
|
||||
responses: Vec<MockResponse>,
|
||||
}
|
||||
|
||||
enum MockResponse {
|
||||
/// Text-only response (no tool calls)
|
||||
Text(Vec<String>), // Chunks of text
|
||||
/// Tool call response
|
||||
ToolCall {
|
||||
text_chunks: Vec<String>,
|
||||
tool_id: String,
|
||||
tool_name: String,
|
||||
tool_args: String,
|
||||
},
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl LlmProvider for MockStreamingProvider {
|
||||
fn name(&self) -> &str {
|
||||
"mock"
|
||||
}
|
||||
|
||||
fn model(&self) -> &str {
|
||||
"mock-model"
|
||||
}
|
||||
|
||||
async fn chat_stream(
|
||||
&self,
|
||||
messages: &[ChatMessage],
|
||||
_options: &ChatOptions,
|
||||
_tools: Option<&[Tool]>,
|
||||
) -> Result<Pin<Box<dyn futures_util::Stream<Item = Result<StreamChunk, LlmError>> + Send>>, LlmError> {
|
||||
// Determine which response to use based on message count
|
||||
let response_idx = (messages.len() / 2).min(self.responses.len() - 1);
|
||||
let response = &self.responses[response_idx];
|
||||
|
||||
let chunks: Vec<Result<StreamChunk, LlmError>> = match response {
|
||||
MockResponse::Text(text_chunks) => text_chunks
|
||||
.iter()
|
||||
.map(|text| {
|
||||
Ok(StreamChunk {
|
||||
content: Some(text.clone()),
|
||||
tool_calls: None,
|
||||
done: false,
|
||||
usage: None,
|
||||
})
|
||||
})
|
||||
.collect(),
|
||||
MockResponse::ToolCall {
|
||||
text_chunks,
|
||||
tool_id,
|
||||
tool_name,
|
||||
tool_args,
|
||||
} => {
|
||||
let mut result = vec![];
|
||||
|
||||
// First emit text chunks
|
||||
for text in text_chunks {
|
||||
result.push(Ok(StreamChunk {
|
||||
content: Some(text.clone()),
|
||||
tool_calls: None,
|
||||
done: false,
|
||||
usage: None,
|
||||
}));
|
||||
}
|
||||
|
||||
// Then emit tool call in chunks
|
||||
result.push(Ok(StreamChunk {
|
||||
content: None,
|
||||
tool_calls: Some(vec![ToolCallDelta {
|
||||
index: 0,
|
||||
id: Some(tool_id.clone()),
|
||||
function_name: Some(tool_name.clone()),
|
||||
arguments_delta: None,
|
||||
}]),
|
||||
done: false,
|
||||
usage: None,
|
||||
}));
|
||||
|
||||
// Emit args in chunks
|
||||
for chunk in tool_args.chars().collect::<Vec<_>>().chunks(5) {
|
||||
result.push(Ok(StreamChunk {
|
||||
content: None,
|
||||
tool_calls: Some(vec![ToolCallDelta {
|
||||
index: 0,
|
||||
id: None,
|
||||
function_name: None,
|
||||
arguments_delta: Some(chunk.iter().collect()),
|
||||
}]),
|
||||
done: false,
|
||||
usage: None,
|
||||
}));
|
||||
}
|
||||
|
||||
result
|
||||
}
|
||||
};
|
||||
|
||||
Ok(Box::pin(stream::iter(chunks)))
|
||||
}
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_streaming_text_only() {
|
||||
let provider = MockStreamingProvider {
|
||||
responses: vec![MockResponse::Text(vec![
|
||||
"Hello".to_string(),
|
||||
" ".to_string(),
|
||||
"world".to_string(),
|
||||
"!".to_string(),
|
||||
])],
|
||||
};
|
||||
|
||||
let perms = PermissionManager::new(Mode::Plan);
|
||||
let ctx = ToolContext::default();
|
||||
let (tx, mut rx) = create_event_channel();
|
||||
|
||||
// Spawn the agent loop
|
||||
let handle = tokio::spawn(async move {
|
||||
run_agent_loop_streaming(
|
||||
&provider,
|
||||
"Say hello",
|
||||
&ChatOptions::default(),
|
||||
&perms,
|
||||
&ctx,
|
||||
tx,
|
||||
)
|
||||
.await
|
||||
});
|
||||
|
||||
// Collect events
|
||||
let mut text_deltas = vec![];
|
||||
let mut done_response = None;
|
||||
|
||||
while let Some(event) = rx.recv().await {
|
||||
match event {
|
||||
AgentEvent::TextDelta(text) => {
|
||||
text_deltas.push(text);
|
||||
}
|
||||
AgentEvent::Done { final_response } => {
|
||||
done_response = Some(final_response);
|
||||
break;
|
||||
}
|
||||
AgentEvent::Error(e) => {
|
||||
panic!("Unexpected error: {}", e);
|
||||
}
|
||||
_ => {}
|
||||
}
|
||||
}
|
||||
|
||||
// Wait for agent loop to complete
|
||||
let result = handle.await.unwrap();
|
||||
assert!(result.is_ok());
|
||||
|
||||
// Verify events
|
||||
assert_eq!(text_deltas, vec!["Hello", " ", "world", "!"]);
|
||||
assert_eq!(done_response, Some("Hello world!".to_string()));
|
||||
assert_eq!(result.unwrap(), "Hello world!");
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_streaming_with_tool_call() {
|
||||
let provider = MockStreamingProvider {
|
||||
responses: vec![
|
||||
MockResponse::ToolCall {
|
||||
text_chunks: vec!["Let me ".to_string(), "check...".to_string()],
|
||||
tool_id: "call_123".to_string(),
|
||||
tool_name: "glob".to_string(),
|
||||
tool_args: r#"{"pattern":"*.rs"}"#.to_string(),
|
||||
},
|
||||
MockResponse::Text(vec!["Found ".to_string(), "the files!".to_string()]),
|
||||
],
|
||||
};
|
||||
|
||||
let perms = PermissionManager::new(Mode::Plan);
|
||||
let ctx = ToolContext::default();
|
||||
let (tx, mut rx) = create_event_channel();
|
||||
|
||||
// Spawn the agent loop
|
||||
let handle = tokio::spawn(async move {
|
||||
run_agent_loop_streaming(
|
||||
&provider,
|
||||
"Find Rust files",
|
||||
&ChatOptions::default(),
|
||||
&perms,
|
||||
&ctx,
|
||||
tx,
|
||||
)
|
||||
.await
|
||||
});
|
||||
|
||||
// Collect events
|
||||
let mut text_deltas = vec![];
|
||||
let mut tool_starts = vec![];
|
||||
let mut tool_outputs = vec![];
|
||||
let mut tool_ends = vec![];
|
||||
|
||||
while let Some(event) = rx.recv().await {
|
||||
match event {
|
||||
AgentEvent::TextDelta(text) => {
|
||||
text_deltas.push(text);
|
||||
}
|
||||
AgentEvent::ToolStart {
|
||||
tool_name,
|
||||
tool_id,
|
||||
} => {
|
||||
tool_starts.push((tool_name, tool_id));
|
||||
}
|
||||
AgentEvent::ToolOutput {
|
||||
tool_id,
|
||||
content,
|
||||
is_error,
|
||||
} => {
|
||||
tool_outputs.push((tool_id, content, is_error));
|
||||
}
|
||||
AgentEvent::ToolEnd { tool_id, success } => {
|
||||
tool_ends.push((tool_id, success));
|
||||
}
|
||||
AgentEvent::Done { .. } => {
|
||||
break;
|
||||
}
|
||||
AgentEvent::Error(e) => {
|
||||
panic!("Unexpected error: {}", e);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Wait for agent loop to complete
|
||||
let result = handle.await.unwrap();
|
||||
assert!(result.is_ok());
|
||||
|
||||
// Verify we got text deltas from both responses
|
||||
assert!(text_deltas.contains(&"Let me ".to_string()));
|
||||
assert!(text_deltas.contains(&"check...".to_string()));
|
||||
assert!(text_deltas.contains(&"Found ".to_string()));
|
||||
assert!(text_deltas.contains(&"the files!".to_string()));
|
||||
|
||||
// Verify tool events
|
||||
assert_eq!(tool_starts.len(), 1);
|
||||
assert_eq!(tool_starts[0].0, "glob");
|
||||
assert_eq!(tool_starts[0].1, "call_123");
|
||||
|
||||
assert_eq!(tool_outputs.len(), 1);
|
||||
assert_eq!(tool_outputs[0].0, "call_123");
|
||||
assert!(!tool_outputs[0].2); // not an error
|
||||
|
||||
assert_eq!(tool_ends.len(), 1);
|
||||
assert_eq!(tool_ends[0].0, "call_123");
|
||||
assert!(tool_ends[0].1); // success
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_channel_creation() {
|
||||
let (tx, mut rx) = create_event_channel();
|
||||
|
||||
// Test that channel works
|
||||
tx.send(AgentEvent::TextDelta("test".to_string()))
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
let event = rx.recv().await.unwrap();
|
||||
match event {
|
||||
AgentEvent::TextDelta(text) => assert_eq!(text, "test"),
|
||||
_ => panic!("Wrong event type"),
|
||||
}
|
||||
}
|
||||
114
crates/core/agent/tests/tool_context.rs
Normal file
114
crates/core/agent/tests/tool_context.rs
Normal file
@@ -0,0 +1,114 @@
|
||||
// Test that ToolContext properly wires up the placeholder tools
|
||||
use agent_core::{ToolContext, execute_tool};
|
||||
use permissions::{Mode, PermissionManager};
|
||||
use tools_todo::{TodoList, TodoStatus};
|
||||
use tools_bash::ShellManager;
|
||||
use serde_json::json;
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_todo_write_with_context() {
|
||||
let todo_list = TodoList::new();
|
||||
let ctx = ToolContext::new().with_todo_list(todo_list.clone());
|
||||
let perms = PermissionManager::new(Mode::Code); // Allow all tools
|
||||
|
||||
let arguments = json!({
|
||||
"todos": [
|
||||
{
|
||||
"content": "First task",
|
||||
"status": "pending",
|
||||
"active_form": "Working on first task"
|
||||
},
|
||||
{
|
||||
"content": "Second task",
|
||||
"status": "in_progress",
|
||||
"active_form": "Working on second task"
|
||||
}
|
||||
]
|
||||
});
|
||||
|
||||
let result = execute_tool("todo_write", &arguments, &perms, &ctx).await;
|
||||
assert!(result.is_ok(), "TodoWrite should succeed: {:?}", result);
|
||||
|
||||
// Verify the todos were written
|
||||
let todos = todo_list.read();
|
||||
assert_eq!(todos.len(), 2);
|
||||
assert_eq!(todos[0].content, "First task");
|
||||
assert_eq!(todos[1].status, TodoStatus::InProgress);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_todo_write_without_context() {
|
||||
let ctx = ToolContext::new(); // No todo_list
|
||||
let perms = PermissionManager::new(Mode::Code);
|
||||
|
||||
let arguments = json!({
|
||||
"todos": []
|
||||
});
|
||||
|
||||
let result = execute_tool("todo_write", &arguments, &perms, &ctx).await;
|
||||
assert!(result.is_err(), "TodoWrite should fail without TodoList");
|
||||
assert!(result.unwrap_err().to_string().contains("not available"));
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_bash_output_with_context() {
|
||||
let manager = ShellManager::new();
|
||||
let ctx = ToolContext::new().with_shell_manager(manager.clone());
|
||||
let perms = PermissionManager::new(Mode::Code);
|
||||
|
||||
// Start a shell and run a command
|
||||
let shell_id = manager.start_shell().await.unwrap();
|
||||
let _ = manager.execute(&shell_id, "echo test", None).await.unwrap();
|
||||
|
||||
let arguments = json!({
|
||||
"shell_id": shell_id
|
||||
});
|
||||
|
||||
let result = execute_tool("bash_output", &arguments, &perms, &ctx).await;
|
||||
assert!(result.is_ok(), "BashOutput should succeed: {:?}", result);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_bash_output_without_context() {
|
||||
let ctx = ToolContext::new(); // No shell_manager
|
||||
let perms = PermissionManager::new(Mode::Code);
|
||||
|
||||
let arguments = json!({
|
||||
"shell_id": "fake-id"
|
||||
});
|
||||
|
||||
let result = execute_tool("bash_output", &arguments, &perms, &ctx).await;
|
||||
assert!(result.is_err(), "BashOutput should fail without ShellManager");
|
||||
assert!(result.unwrap_err().to_string().contains("not available"));
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_kill_shell_with_context() {
|
||||
let manager = ShellManager::new();
|
||||
let ctx = ToolContext::new().with_shell_manager(manager.clone());
|
||||
let perms = PermissionManager::new(Mode::Code);
|
||||
|
||||
// Start a shell
|
||||
let shell_id = manager.start_shell().await.unwrap();
|
||||
|
||||
let arguments = json!({
|
||||
"shell_id": shell_id
|
||||
});
|
||||
|
||||
let result = execute_tool("kill_shell", &arguments, &perms, &ctx).await;
|
||||
assert!(result.is_ok(), "KillShell should succeed: {:?}", result);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_ask_user_without_context() {
|
||||
let ctx = ToolContext::new(); // No ask_sender
|
||||
let perms = PermissionManager::new(Mode::Code);
|
||||
|
||||
let arguments = json!({
|
||||
"questions": []
|
||||
});
|
||||
|
||||
let result = execute_tool("ask_user", &arguments, &perms, &ctx).await;
|
||||
assert!(result.is_err(), "AskUser should fail without AskSender");
|
||||
assert!(result.unwrap_err().to_string().contains("not available"));
|
||||
}
|
||||
18
crates/llm/anthropic/Cargo.toml
Normal file
18
crates/llm/anthropic/Cargo.toml
Normal file
@@ -0,0 +1,18 @@
|
||||
[package]
|
||||
name = "llm-anthropic"
|
||||
version = "0.1.0"
|
||||
edition.workspace = true
|
||||
license.workspace = true
|
||||
description = "Anthropic Claude API client for Owlen"
|
||||
|
||||
[dependencies]
|
||||
llm-core = { path = "../core" }
|
||||
async-trait = "0.1"
|
||||
futures = "0.3"
|
||||
reqwest = { version = "0.12", features = ["json", "stream"] }
|
||||
reqwest-eventsource = "0.6"
|
||||
serde = { version = "1.0", features = ["derive"] }
|
||||
serde_json = "1.0"
|
||||
tokio = { version = "1", features = ["sync", "time"] }
|
||||
tracing = "0.1"
|
||||
uuid = { version = "1.0", features = ["v4"] }
|
||||
285
crates/llm/anthropic/src/auth.rs
Normal file
285
crates/llm/anthropic/src/auth.rs
Normal file
@@ -0,0 +1,285 @@
|
||||
//! Anthropic OAuth Authentication
|
||||
//!
|
||||
//! Implements device code flow for authenticating with Anthropic without API keys.
|
||||
|
||||
use llm_core::{AuthMethod, DeviceAuthResult, DeviceCodeResponse, LlmError, OAuthProvider};
|
||||
use reqwest::Client;
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
/// OAuth client for Anthropic device flow
|
||||
pub struct AnthropicAuth {
|
||||
http: Client,
|
||||
client_id: String,
|
||||
}
|
||||
|
||||
// Anthropic OAuth endpoints (these would be the real endpoints)
|
||||
const AUTH_BASE_URL: &str = "https://console.anthropic.com";
|
||||
const DEVICE_CODE_ENDPOINT: &str = "/oauth/device/code";
|
||||
const TOKEN_ENDPOINT: &str = "/oauth/token";
|
||||
|
||||
// Default client ID for Owlen CLI
|
||||
const DEFAULT_CLIENT_ID: &str = "owlen-cli";
|
||||
|
||||
impl AnthropicAuth {
|
||||
/// Create a new OAuth client with the default CLI client ID
|
||||
pub fn new() -> Self {
|
||||
Self {
|
||||
http: Client::new(),
|
||||
client_id: DEFAULT_CLIENT_ID.to_string(),
|
||||
}
|
||||
}
|
||||
|
||||
/// Create with a custom client ID
|
||||
pub fn with_client_id(client_id: impl Into<String>) -> Self {
|
||||
Self {
|
||||
http: Client::new(),
|
||||
client_id: client_id.into(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for AnthropicAuth {
|
||||
fn default() -> Self {
|
||||
Self::new()
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize)]
|
||||
struct DeviceCodeRequest<'a> {
|
||||
client_id: &'a str,
|
||||
scope: &'a str,
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
struct DeviceCodeApiResponse {
|
||||
device_code: String,
|
||||
user_code: String,
|
||||
verification_uri: String,
|
||||
verification_uri_complete: Option<String>,
|
||||
expires_in: u64,
|
||||
interval: u64,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize)]
|
||||
struct TokenRequest<'a> {
|
||||
client_id: &'a str,
|
||||
device_code: &'a str,
|
||||
grant_type: &'a str,
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
struct TokenApiResponse {
|
||||
access_token: String,
|
||||
#[allow(dead_code)]
|
||||
token_type: String,
|
||||
expires_in: Option<u64>,
|
||||
refresh_token: Option<String>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
struct TokenErrorResponse {
|
||||
error: String,
|
||||
error_description: Option<String>,
|
||||
}
|
||||
|
||||
#[async_trait::async_trait]
|
||||
impl OAuthProvider for AnthropicAuth {
|
||||
async fn start_device_auth(&self) -> Result<DeviceCodeResponse, LlmError> {
|
||||
let url = format!("{}{}", AUTH_BASE_URL, DEVICE_CODE_ENDPOINT);
|
||||
|
||||
let request = DeviceCodeRequest {
|
||||
client_id: &self.client_id,
|
||||
scope: "api:read api:write", // Request API access
|
||||
};
|
||||
|
||||
let response = self
|
||||
.http
|
||||
.post(&url)
|
||||
.json(&request)
|
||||
.send()
|
||||
.await
|
||||
.map_err(|e| LlmError::Http(e.to_string()))?;
|
||||
|
||||
if !response.status().is_success() {
|
||||
let status = response.status();
|
||||
let text = response
|
||||
.text()
|
||||
.await
|
||||
.unwrap_or_else(|_| "Unknown error".to_string());
|
||||
return Err(LlmError::Auth(format!(
|
||||
"Device code request failed ({}): {}",
|
||||
status, text
|
||||
)));
|
||||
}
|
||||
|
||||
let api_response: DeviceCodeApiResponse = response
|
||||
.json()
|
||||
.await
|
||||
.map_err(|e| LlmError::Json(e.to_string()))?;
|
||||
|
||||
Ok(DeviceCodeResponse {
|
||||
device_code: api_response.device_code,
|
||||
user_code: api_response.user_code,
|
||||
verification_uri: api_response.verification_uri,
|
||||
verification_uri_complete: api_response.verification_uri_complete,
|
||||
expires_in: api_response.expires_in,
|
||||
interval: api_response.interval,
|
||||
})
|
||||
}
|
||||
|
||||
async fn poll_device_auth(&self, device_code: &str) -> Result<DeviceAuthResult, LlmError> {
|
||||
let url = format!("{}{}", AUTH_BASE_URL, TOKEN_ENDPOINT);
|
||||
|
||||
let request = TokenRequest {
|
||||
client_id: &self.client_id,
|
||||
device_code,
|
||||
grant_type: "urn:ietf:params:oauth:grant-type:device_code",
|
||||
};
|
||||
|
||||
let response = self
|
||||
.http
|
||||
.post(&url)
|
||||
.json(&request)
|
||||
.send()
|
||||
.await
|
||||
.map_err(|e| LlmError::Http(e.to_string()))?;
|
||||
|
||||
if response.status().is_success() {
|
||||
let token_response: TokenApiResponse = response
|
||||
.json()
|
||||
.await
|
||||
.map_err(|e| LlmError::Json(e.to_string()))?;
|
||||
|
||||
return Ok(DeviceAuthResult::Success {
|
||||
access_token: token_response.access_token,
|
||||
refresh_token: token_response.refresh_token,
|
||||
expires_in: token_response.expires_in,
|
||||
});
|
||||
}
|
||||
|
||||
// Parse error response
|
||||
let error_response: TokenErrorResponse = response
|
||||
.json()
|
||||
.await
|
||||
.map_err(|e| LlmError::Json(e.to_string()))?;
|
||||
|
||||
match error_response.error.as_str() {
|
||||
"authorization_pending" => Ok(DeviceAuthResult::Pending),
|
||||
"slow_down" => Ok(DeviceAuthResult::Pending), // Treat as pending, caller should slow down
|
||||
"access_denied" => Ok(DeviceAuthResult::Denied),
|
||||
"expired_token" => Ok(DeviceAuthResult::Expired),
|
||||
_ => Err(LlmError::Auth(format!(
|
||||
"Token request failed: {} - {}",
|
||||
error_response.error,
|
||||
error_response.error_description.unwrap_or_default()
|
||||
))),
|
||||
}
|
||||
}
|
||||
|
||||
async fn refresh_token(&self, refresh_token: &str) -> Result<AuthMethod, LlmError> {
|
||||
let url = format!("{}{}", AUTH_BASE_URL, TOKEN_ENDPOINT);
|
||||
|
||||
#[derive(Serialize)]
|
||||
struct RefreshRequest<'a> {
|
||||
client_id: &'a str,
|
||||
refresh_token: &'a str,
|
||||
grant_type: &'a str,
|
||||
}
|
||||
|
||||
let request = RefreshRequest {
|
||||
client_id: &self.client_id,
|
||||
refresh_token,
|
||||
grant_type: "refresh_token",
|
||||
};
|
||||
|
||||
let response = self
|
||||
.http
|
||||
.post(&url)
|
||||
.json(&request)
|
||||
.send()
|
||||
.await
|
||||
.map_err(|e| LlmError::Http(e.to_string()))?;
|
||||
|
||||
if !response.status().is_success() {
|
||||
let text = response
|
||||
.text()
|
||||
.await
|
||||
.unwrap_or_else(|_| "Unknown error".to_string());
|
||||
return Err(LlmError::Auth(format!("Token refresh failed: {}", text)));
|
||||
}
|
||||
|
||||
let token_response: TokenApiResponse = response
|
||||
.json()
|
||||
.await
|
||||
.map_err(|e| LlmError::Json(e.to_string()))?;
|
||||
|
||||
let expires_at = token_response.expires_in.map(|secs| {
|
||||
std::time::SystemTime::now()
|
||||
.duration_since(std::time::UNIX_EPOCH)
|
||||
.map(|d| d.as_secs() + secs)
|
||||
.unwrap_or(0)
|
||||
});
|
||||
|
||||
Ok(AuthMethod::OAuth {
|
||||
access_token: token_response.access_token,
|
||||
refresh_token: token_response.refresh_token,
|
||||
expires_at,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
/// Helper to perform the full device auth flow with polling
|
||||
pub async fn perform_device_auth<F>(
|
||||
auth: &AnthropicAuth,
|
||||
on_code: F,
|
||||
) -> Result<AuthMethod, LlmError>
|
||||
where
|
||||
F: FnOnce(&DeviceCodeResponse),
|
||||
{
|
||||
// Start the device flow
|
||||
let device_code = auth.start_device_auth().await?;
|
||||
|
||||
// Let caller display the code to user
|
||||
on_code(&device_code);
|
||||
|
||||
// Poll for completion
|
||||
let poll_interval = std::time::Duration::from_secs(device_code.interval);
|
||||
let deadline =
|
||||
std::time::Instant::now() + std::time::Duration::from_secs(device_code.expires_in);
|
||||
|
||||
loop {
|
||||
if std::time::Instant::now() > deadline {
|
||||
return Err(LlmError::Auth("Device code expired".to_string()));
|
||||
}
|
||||
|
||||
tokio::time::sleep(poll_interval).await;
|
||||
|
||||
match auth.poll_device_auth(&device_code.device_code).await? {
|
||||
DeviceAuthResult::Success {
|
||||
access_token,
|
||||
refresh_token,
|
||||
expires_in,
|
||||
} => {
|
||||
let expires_at = expires_in.map(|secs| {
|
||||
std::time::SystemTime::now()
|
||||
.duration_since(std::time::UNIX_EPOCH)
|
||||
.map(|d| d.as_secs() + secs)
|
||||
.unwrap_or(0)
|
||||
});
|
||||
|
||||
return Ok(AuthMethod::OAuth {
|
||||
access_token,
|
||||
refresh_token,
|
||||
expires_at,
|
||||
});
|
||||
}
|
||||
DeviceAuthResult::Pending => continue,
|
||||
DeviceAuthResult::Denied => {
|
||||
return Err(LlmError::Auth("Authorization denied by user".to_string()));
|
||||
}
|
||||
DeviceAuthResult::Expired => {
|
||||
return Err(LlmError::Auth("Device code expired".to_string()));
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
577
crates/llm/anthropic/src/client.rs
Normal file
577
crates/llm/anthropic/src/client.rs
Normal file
@@ -0,0 +1,577 @@
|
||||
//! Anthropic Claude API Client
|
||||
//!
|
||||
//! Implements the Messages API with streaming support.
|
||||
|
||||
use crate::types::*;
|
||||
use async_trait::async_trait;
|
||||
use futures::StreamExt;
|
||||
use llm_core::{
|
||||
AccountInfo, AuthMethod, ChatMessage, ChatOptions, ChatResponse, ChunkStream, FunctionCall,
|
||||
LlmError, LlmProvider, ModelInfo, ProviderInfo, ProviderStatus, Role, StreamChunk, Tool,
|
||||
ToolCall, ToolCallDelta, Usage, UsageStats,
|
||||
};
|
||||
use reqwest::Client;
|
||||
use reqwest_eventsource::{Event, EventSource};
|
||||
use std::sync::Arc;
|
||||
use tokio::sync::Mutex;
|
||||
|
||||
const API_BASE_URL: &str = "https://api.anthropic.com";
|
||||
const MESSAGES_ENDPOINT: &str = "/v1/messages";
|
||||
const API_VERSION: &str = "2023-06-01";
|
||||
const DEFAULT_MAX_TOKENS: u32 = 8192;
|
||||
|
||||
/// Anthropic Claude API client
|
||||
pub struct AnthropicClient {
|
||||
http: Client,
|
||||
auth: AuthMethod,
|
||||
model: String,
|
||||
}
|
||||
|
||||
impl AnthropicClient {
|
||||
/// Create a new client with API key authentication
|
||||
pub fn new(api_key: impl Into<String>) -> Self {
|
||||
Self {
|
||||
http: Client::new(),
|
||||
auth: AuthMethod::api_key(api_key),
|
||||
model: "claude-sonnet-4-20250514".to_string(),
|
||||
}
|
||||
}
|
||||
|
||||
/// Create a new client with OAuth token
|
||||
pub fn with_oauth(access_token: impl Into<String>) -> Self {
|
||||
Self {
|
||||
http: Client::new(),
|
||||
auth: AuthMethod::oauth(access_token),
|
||||
model: "claude-sonnet-4-20250514".to_string(),
|
||||
}
|
||||
}
|
||||
|
||||
/// Create a new client with full AuthMethod
|
||||
pub fn with_auth(auth: AuthMethod) -> Self {
|
||||
Self {
|
||||
http: Client::new(),
|
||||
auth,
|
||||
model: "claude-sonnet-4-20250514".to_string(),
|
||||
}
|
||||
}
|
||||
|
||||
/// Set the model to use
|
||||
pub fn with_model(mut self, model: impl Into<String>) -> Self {
|
||||
self.model = model.into();
|
||||
self
|
||||
}
|
||||
|
||||
/// Get current auth method (for token refresh)
|
||||
pub fn auth(&self) -> &AuthMethod {
|
||||
&self.auth
|
||||
}
|
||||
|
||||
/// Update the auth method (after refresh)
|
||||
pub fn set_auth(&mut self, auth: AuthMethod) {
|
||||
self.auth = auth;
|
||||
}
|
||||
|
||||
/// Convert messages to Anthropic format, extracting system message
|
||||
fn prepare_messages(messages: &[ChatMessage]) -> (Option<String>, Vec<AnthropicMessage>) {
|
||||
let mut system_content = None;
|
||||
let mut anthropic_messages = Vec::new();
|
||||
|
||||
for msg in messages {
|
||||
if msg.role == Role::System {
|
||||
// Collect system messages
|
||||
if let Some(content) = &msg.content {
|
||||
if let Some(existing) = &mut system_content {
|
||||
*existing = format!("{}\n\n{}", existing, content);
|
||||
} else {
|
||||
system_content = Some(content.clone());
|
||||
}
|
||||
}
|
||||
} else {
|
||||
anthropic_messages.push(AnthropicMessage::from(msg));
|
||||
}
|
||||
}
|
||||
|
||||
(system_content, anthropic_messages)
|
||||
}
|
||||
|
||||
/// Convert tools to Anthropic format
|
||||
fn prepare_tools(tools: Option<&[Tool]>) -> Option<Vec<AnthropicTool>> {
|
||||
tools.map(|t| t.iter().map(AnthropicTool::from).collect())
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl LlmProvider for AnthropicClient {
|
||||
fn name(&self) -> &str {
|
||||
"anthropic"
|
||||
}
|
||||
|
||||
fn model(&self) -> &str {
|
||||
&self.model
|
||||
}
|
||||
|
||||
async fn chat_stream(
|
||||
&self,
|
||||
messages: &[ChatMessage],
|
||||
options: &ChatOptions,
|
||||
tools: Option<&[Tool]>,
|
||||
) -> Result<ChunkStream, LlmError> {
|
||||
let url = format!("{}{}", API_BASE_URL, MESSAGES_ENDPOINT);
|
||||
|
||||
let model = if options.model.is_empty() {
|
||||
&self.model
|
||||
} else {
|
||||
&options.model
|
||||
};
|
||||
|
||||
let (system, anthropic_messages) = Self::prepare_messages(messages);
|
||||
let anthropic_tools = Self::prepare_tools(tools);
|
||||
|
||||
let request = MessagesRequest {
|
||||
model,
|
||||
messages: anthropic_messages,
|
||||
max_tokens: options.max_tokens.unwrap_or(DEFAULT_MAX_TOKENS),
|
||||
system: system.as_deref(),
|
||||
temperature: options.temperature,
|
||||
top_p: options.top_p,
|
||||
stop_sequences: options.stop.as_deref(),
|
||||
tools: anthropic_tools,
|
||||
stream: true,
|
||||
};
|
||||
|
||||
let bearer = self
|
||||
.auth
|
||||
.bearer_token()
|
||||
.ok_or_else(|| LlmError::Auth("No authentication configured".to_string()))?;
|
||||
|
||||
// Build the SSE request
|
||||
let req = self
|
||||
.http
|
||||
.post(&url)
|
||||
.header("x-api-key", bearer)
|
||||
.header("anthropic-version", API_VERSION)
|
||||
.header("content-type", "application/json")
|
||||
.json(&request);
|
||||
|
||||
let es = EventSource::new(req).map_err(|e| LlmError::Http(e.to_string()))?;
|
||||
|
||||
// State for accumulating tool calls across deltas
|
||||
let tool_state: Arc<Mutex<Vec<PartialToolCall>>> = Arc::new(Mutex::new(Vec::new()));
|
||||
|
||||
let stream = es.filter_map(move |event| {
|
||||
let tool_state = Arc::clone(&tool_state);
|
||||
async move {
|
||||
match event {
|
||||
Ok(Event::Open) => None,
|
||||
Ok(Event::Message(msg)) => {
|
||||
// Parse the SSE data as JSON
|
||||
let event: StreamEvent = match serde_json::from_str(&msg.data) {
|
||||
Ok(e) => e,
|
||||
Err(e) => {
|
||||
tracing::warn!("Failed to parse SSE event: {}", e);
|
||||
return None;
|
||||
}
|
||||
};
|
||||
|
||||
convert_stream_event(event, &tool_state).await
|
||||
}
|
||||
Err(reqwest_eventsource::Error::StreamEnded) => None,
|
||||
Err(e) => Some(Err(LlmError::Stream(e.to_string()))),
|
||||
}
|
||||
}
|
||||
});
|
||||
|
||||
Ok(Box::pin(stream))
|
||||
}
|
||||
|
||||
async fn chat(
|
||||
&self,
|
||||
messages: &[ChatMessage],
|
||||
options: &ChatOptions,
|
||||
tools: Option<&[Tool]>,
|
||||
) -> Result<ChatResponse, LlmError> {
|
||||
let url = format!("{}{}", API_BASE_URL, MESSAGES_ENDPOINT);
|
||||
|
||||
let model = if options.model.is_empty() {
|
||||
&self.model
|
||||
} else {
|
||||
&options.model
|
||||
};
|
||||
|
||||
let (system, anthropic_messages) = Self::prepare_messages(messages);
|
||||
let anthropic_tools = Self::prepare_tools(tools);
|
||||
|
||||
let request = MessagesRequest {
|
||||
model,
|
||||
messages: anthropic_messages,
|
||||
max_tokens: options.max_tokens.unwrap_or(DEFAULT_MAX_TOKENS),
|
||||
system: system.as_deref(),
|
||||
temperature: options.temperature,
|
||||
top_p: options.top_p,
|
||||
stop_sequences: options.stop.as_deref(),
|
||||
tools: anthropic_tools,
|
||||
stream: false,
|
||||
};
|
||||
|
||||
let bearer = self
|
||||
.auth
|
||||
.bearer_token()
|
||||
.ok_or_else(|| LlmError::Auth("No authentication configured".to_string()))?;
|
||||
|
||||
let response = self
|
||||
.http
|
||||
.post(&url)
|
||||
.header("x-api-key", bearer)
|
||||
.header("anthropic-version", API_VERSION)
|
||||
.json(&request)
|
||||
.send()
|
||||
.await
|
||||
.map_err(|e| LlmError::Http(e.to_string()))?;
|
||||
|
||||
if !response.status().is_success() {
|
||||
let status = response.status();
|
||||
let text = response
|
||||
.text()
|
||||
.await
|
||||
.unwrap_or_else(|_| "Unknown error".to_string());
|
||||
|
||||
// Check for rate limiting
|
||||
if status == reqwest::StatusCode::TOO_MANY_REQUESTS {
|
||||
return Err(LlmError::RateLimit {
|
||||
retry_after_secs: None,
|
||||
});
|
||||
}
|
||||
|
||||
return Err(LlmError::Api {
|
||||
message: text,
|
||||
code: Some(status.to_string()),
|
||||
});
|
||||
}
|
||||
|
||||
let api_response: MessagesResponse = response
|
||||
.json()
|
||||
.await
|
||||
.map_err(|e| LlmError::Json(e.to_string()))?;
|
||||
|
||||
// Convert response to common format
|
||||
let mut content = String::new();
|
||||
let mut tool_calls = Vec::new();
|
||||
|
||||
for block in api_response.content {
|
||||
match block {
|
||||
ResponseContentBlock::Text { text } => {
|
||||
content.push_str(&text);
|
||||
}
|
||||
ResponseContentBlock::ToolUse { id, name, input } => {
|
||||
tool_calls.push(ToolCall {
|
||||
id,
|
||||
call_type: "function".to_string(),
|
||||
function: FunctionCall {
|
||||
name,
|
||||
arguments: input,
|
||||
},
|
||||
});
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
let usage = api_response.usage.map(|u| Usage {
|
||||
prompt_tokens: u.input_tokens,
|
||||
completion_tokens: u.output_tokens,
|
||||
total_tokens: u.input_tokens + u.output_tokens,
|
||||
});
|
||||
|
||||
Ok(ChatResponse {
|
||||
content: if content.is_empty() {
|
||||
None
|
||||
} else {
|
||||
Some(content)
|
||||
},
|
||||
tool_calls: if tool_calls.is_empty() {
|
||||
None
|
||||
} else {
|
||||
Some(tool_calls)
|
||||
},
|
||||
usage,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
/// Helper struct for accumulating streaming tool calls
|
||||
#[derive(Default)]
|
||||
struct PartialToolCall {
|
||||
#[allow(dead_code)]
|
||||
id: String,
|
||||
#[allow(dead_code)]
|
||||
name: String,
|
||||
input_json: String,
|
||||
}
|
||||
|
||||
/// Convert an Anthropic stream event to our common StreamChunk format
|
||||
async fn convert_stream_event(
|
||||
event: StreamEvent,
|
||||
tool_state: &Arc<Mutex<Vec<PartialToolCall>>>,
|
||||
) -> Option<Result<StreamChunk, LlmError>> {
|
||||
match event {
|
||||
StreamEvent::ContentBlockStart {
|
||||
index,
|
||||
content_block,
|
||||
} => {
|
||||
match content_block {
|
||||
ContentBlockStartInfo::Text { text } => {
|
||||
if text.is_empty() {
|
||||
None
|
||||
} else {
|
||||
Some(Ok(StreamChunk {
|
||||
content: Some(text),
|
||||
tool_calls: None,
|
||||
done: false,
|
||||
usage: None,
|
||||
}))
|
||||
}
|
||||
}
|
||||
ContentBlockStartInfo::ToolUse { id, name } => {
|
||||
// Store the tool call start
|
||||
let mut state = tool_state.lock().await;
|
||||
while state.len() <= index {
|
||||
state.push(PartialToolCall::default());
|
||||
}
|
||||
state[index] = PartialToolCall {
|
||||
id: id.clone(),
|
||||
name: name.clone(),
|
||||
input_json: String::new(),
|
||||
};
|
||||
|
||||
Some(Ok(StreamChunk {
|
||||
content: None,
|
||||
tool_calls: Some(vec![ToolCallDelta {
|
||||
index,
|
||||
id: Some(id),
|
||||
function_name: Some(name),
|
||||
arguments_delta: None,
|
||||
}]),
|
||||
done: false,
|
||||
usage: None,
|
||||
}))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
StreamEvent::ContentBlockDelta { index, delta } => match delta {
|
||||
ContentDelta::TextDelta { text } => Some(Ok(StreamChunk {
|
||||
content: Some(text),
|
||||
tool_calls: None,
|
||||
done: false,
|
||||
usage: None,
|
||||
})),
|
||||
ContentDelta::InputJsonDelta { partial_json } => {
|
||||
// Accumulate the JSON
|
||||
let mut state = tool_state.lock().await;
|
||||
if index < state.len() {
|
||||
state[index].input_json.push_str(&partial_json);
|
||||
}
|
||||
|
||||
Some(Ok(StreamChunk {
|
||||
content: None,
|
||||
tool_calls: Some(vec![ToolCallDelta {
|
||||
index,
|
||||
id: None,
|
||||
function_name: None,
|
||||
arguments_delta: Some(partial_json),
|
||||
}]),
|
||||
done: false,
|
||||
usage: None,
|
||||
}))
|
||||
}
|
||||
},
|
||||
|
||||
StreamEvent::MessageDelta { usage, .. } => {
|
||||
let u = usage.map(|u| Usage {
|
||||
prompt_tokens: u.input_tokens,
|
||||
completion_tokens: u.output_tokens,
|
||||
total_tokens: u.input_tokens + u.output_tokens,
|
||||
});
|
||||
|
||||
Some(Ok(StreamChunk {
|
||||
content: None,
|
||||
tool_calls: None,
|
||||
done: false,
|
||||
usage: u,
|
||||
}))
|
||||
}
|
||||
|
||||
StreamEvent::MessageStop => Some(Ok(StreamChunk {
|
||||
content: None,
|
||||
tool_calls: None,
|
||||
done: true,
|
||||
usage: None,
|
||||
})),
|
||||
|
||||
StreamEvent::Error { error } => Some(Err(LlmError::Api {
|
||||
message: error.message,
|
||||
code: Some(error.error_type),
|
||||
})),
|
||||
|
||||
// Ignore other events
|
||||
StreamEvent::MessageStart { .. }
|
||||
| StreamEvent::ContentBlockStop { .. }
|
||||
| StreamEvent::Ping => None,
|
||||
}
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// ProviderInfo Implementation
|
||||
// ============================================================================
|
||||
|
||||
/// Known Claude models with their specifications
|
||||
fn get_claude_models() -> Vec<ModelInfo> {
|
||||
vec![
|
||||
ModelInfo {
|
||||
id: "claude-opus-4-20250514".to_string(),
|
||||
display_name: Some("Claude Opus 4".to_string()),
|
||||
description: Some("Most capable model for complex tasks".to_string()),
|
||||
context_window: Some(200_000),
|
||||
max_output_tokens: Some(32_000),
|
||||
supports_tools: true,
|
||||
supports_vision: true,
|
||||
input_price_per_mtok: Some(15.0),
|
||||
output_price_per_mtok: Some(75.0),
|
||||
},
|
||||
ModelInfo {
|
||||
id: "claude-sonnet-4-20250514".to_string(),
|
||||
display_name: Some("Claude Sonnet 4".to_string()),
|
||||
description: Some("Best balance of performance and speed".to_string()),
|
||||
context_window: Some(200_000),
|
||||
max_output_tokens: Some(64_000),
|
||||
supports_tools: true,
|
||||
supports_vision: true,
|
||||
input_price_per_mtok: Some(3.0),
|
||||
output_price_per_mtok: Some(15.0),
|
||||
},
|
||||
ModelInfo {
|
||||
id: "claude-haiku-3-5-20241022".to_string(),
|
||||
display_name: Some("Claude 3.5 Haiku".to_string()),
|
||||
description: Some("Fast and affordable for simple tasks".to_string()),
|
||||
context_window: Some(200_000),
|
||||
max_output_tokens: Some(8_192),
|
||||
supports_tools: true,
|
||||
supports_vision: true,
|
||||
input_price_per_mtok: Some(0.80),
|
||||
output_price_per_mtok: Some(4.0),
|
||||
},
|
||||
]
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl ProviderInfo for AnthropicClient {
|
||||
async fn status(&self) -> Result<ProviderStatus, LlmError> {
|
||||
let authenticated = self.auth.bearer_token().is_some();
|
||||
|
||||
// Try to reach the API with a simple request
|
||||
let reachable = if authenticated {
|
||||
// Test with a minimal message to verify auth works
|
||||
let test_messages = vec![ChatMessage::user("Hi")];
|
||||
let test_opts = ChatOptions::new(&self.model).with_max_tokens(1);
|
||||
|
||||
match self.chat(&test_messages, &test_opts, None).await {
|
||||
Ok(_) => true,
|
||||
Err(LlmError::Auth(_)) => false, // Auth failed
|
||||
Err(_) => true, // Other errors mean API is reachable
|
||||
}
|
||||
} else {
|
||||
false
|
||||
};
|
||||
|
||||
let account = if authenticated && reachable {
|
||||
self.account_info().await.ok().flatten()
|
||||
} else {
|
||||
None
|
||||
};
|
||||
|
||||
let message = if !authenticated {
|
||||
Some("Not authenticated - run 'owlen login anthropic' to authenticate".to_string())
|
||||
} else if !reachable {
|
||||
Some("Cannot reach Anthropic API".to_string())
|
||||
} else {
|
||||
Some("Connected".to_string())
|
||||
};
|
||||
|
||||
Ok(ProviderStatus {
|
||||
provider: "anthropic".to_string(),
|
||||
authenticated,
|
||||
account,
|
||||
model: self.model.clone(),
|
||||
endpoint: API_BASE_URL.to_string(),
|
||||
reachable,
|
||||
message,
|
||||
})
|
||||
}
|
||||
|
||||
async fn account_info(&self) -> Result<Option<AccountInfo>, LlmError> {
|
||||
// Anthropic doesn't have a public account info endpoint
|
||||
// Return None - account info would come from OAuth token claims
|
||||
Ok(None)
|
||||
}
|
||||
|
||||
async fn usage_stats(&self) -> Result<Option<UsageStats>, LlmError> {
|
||||
// Anthropic doesn't expose usage stats via API
|
||||
// This would require the admin/billing API with different auth
|
||||
Ok(None)
|
||||
}
|
||||
|
||||
async fn list_models(&self) -> Result<Vec<ModelInfo>, LlmError> {
|
||||
// Return known models - Anthropic doesn't have a models list endpoint
|
||||
Ok(get_claude_models())
|
||||
}
|
||||
|
||||
async fn model_info(&self, model_id: &str) -> Result<Option<ModelInfo>, LlmError> {
|
||||
let models = get_claude_models();
|
||||
Ok(models.into_iter().find(|m| m.id == model_id))
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use llm_core::ToolParameters;
|
||||
use serde_json::json;
|
||||
|
||||
#[test]
|
||||
fn test_message_conversion() {
|
||||
let messages = vec![
|
||||
ChatMessage::system("You are helpful"),
|
||||
ChatMessage::user("Hello"),
|
||||
ChatMessage::assistant("Hi there!"),
|
||||
];
|
||||
|
||||
let (system, anthropic_msgs) = AnthropicClient::prepare_messages(&messages);
|
||||
|
||||
assert_eq!(system, Some("You are helpful".to_string()));
|
||||
assert_eq!(anthropic_msgs.len(), 2);
|
||||
assert_eq!(anthropic_msgs[0].role, "user");
|
||||
assert_eq!(anthropic_msgs[1].role, "assistant");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_tool_conversion() {
|
||||
let tools = vec![Tool::function(
|
||||
"read_file",
|
||||
"Read a file's contents",
|
||||
ToolParameters::object(
|
||||
json!({
|
||||
"path": {
|
||||
"type": "string",
|
||||
"description": "File path"
|
||||
}
|
||||
}),
|
||||
vec!["path".to_string()],
|
||||
),
|
||||
)];
|
||||
|
||||
let anthropic_tools = AnthropicClient::prepare_tools(Some(&tools)).unwrap();
|
||||
|
||||
assert_eq!(anthropic_tools.len(), 1);
|
||||
assert_eq!(anthropic_tools[0].name, "read_file");
|
||||
assert_eq!(anthropic_tools[0].description, "Read a file's contents");
|
||||
}
|
||||
}
|
||||
12
crates/llm/anthropic/src/lib.rs
Normal file
12
crates/llm/anthropic/src/lib.rs
Normal file
@@ -0,0 +1,12 @@
|
||||
//! Anthropic Claude API Client
|
||||
//!
|
||||
//! Implements the LlmProvider trait for Anthropic's Claude models.
|
||||
//! Supports both API key authentication and OAuth device flow.
|
||||
|
||||
mod auth;
|
||||
mod client;
|
||||
mod types;
|
||||
|
||||
pub use auth::*;
|
||||
pub use client::*;
|
||||
pub use types::*;
|
||||
276
crates/llm/anthropic/src/types.rs
Normal file
276
crates/llm/anthropic/src/types.rs
Normal file
@@ -0,0 +1,276 @@
|
||||
//! Anthropic API request/response types
|
||||
|
||||
use serde::{Deserialize, Serialize};
|
||||
use serde_json::Value;
|
||||
|
||||
// ============================================================================
|
||||
// Request Types
|
||||
// ============================================================================
|
||||
|
||||
#[derive(Debug, Serialize)]
|
||||
pub struct MessagesRequest<'a> {
|
||||
pub model: &'a str,
|
||||
pub messages: Vec<AnthropicMessage>,
|
||||
pub max_tokens: u32,
|
||||
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub system: Option<&'a str>,
|
||||
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub temperature: Option<f32>,
|
||||
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub top_p: Option<f32>,
|
||||
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub stop_sequences: Option<&'a [String]>,
|
||||
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub tools: Option<Vec<AnthropicTool>>,
|
||||
|
||||
pub stream: bool,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct AnthropicMessage {
|
||||
pub role: String, // "user" or "assistant"
|
||||
pub content: AnthropicContent,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
#[serde(untagged)]
|
||||
pub enum AnthropicContent {
|
||||
Text(String),
|
||||
Blocks(Vec<ContentBlock>),
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
#[serde(tag = "type")]
|
||||
pub enum ContentBlock {
|
||||
#[serde(rename = "text")]
|
||||
Text { text: String },
|
||||
|
||||
#[serde(rename = "tool_use")]
|
||||
ToolUse {
|
||||
id: String,
|
||||
name: String,
|
||||
input: Value,
|
||||
},
|
||||
|
||||
#[serde(rename = "tool_result")]
|
||||
ToolResult {
|
||||
tool_use_id: String,
|
||||
content: String,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
is_error: Option<bool>,
|
||||
},
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct AnthropicTool {
|
||||
pub name: String,
|
||||
pub description: String,
|
||||
pub input_schema: ToolInputSchema,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct ToolInputSchema {
|
||||
#[serde(rename = "type")]
|
||||
pub schema_type: String,
|
||||
pub properties: Value,
|
||||
pub required: Vec<String>,
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// Response Types
|
||||
// ============================================================================
|
||||
|
||||
#[derive(Debug, Clone, Deserialize)]
|
||||
pub struct MessagesResponse {
|
||||
pub id: String,
|
||||
#[serde(rename = "type")]
|
||||
pub response_type: String,
|
||||
pub role: String,
|
||||
pub content: Vec<ResponseContentBlock>,
|
||||
pub model: String,
|
||||
pub stop_reason: Option<String>,
|
||||
pub usage: Option<UsageInfo>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Deserialize)]
|
||||
#[serde(tag = "type")]
|
||||
pub enum ResponseContentBlock {
|
||||
#[serde(rename = "text")]
|
||||
Text { text: String },
|
||||
|
||||
#[serde(rename = "tool_use")]
|
||||
ToolUse {
|
||||
id: String,
|
||||
name: String,
|
||||
input: Value,
|
||||
},
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Deserialize)]
|
||||
pub struct UsageInfo {
|
||||
pub input_tokens: u32,
|
||||
pub output_tokens: u32,
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// Streaming Event Types
|
||||
// ============================================================================
|
||||
|
||||
#[derive(Debug, Clone, Deserialize)]
|
||||
#[serde(tag = "type")]
|
||||
pub enum StreamEvent {
|
||||
#[serde(rename = "message_start")]
|
||||
MessageStart { message: MessageStartInfo },
|
||||
|
||||
#[serde(rename = "content_block_start")]
|
||||
ContentBlockStart {
|
||||
index: usize,
|
||||
content_block: ContentBlockStartInfo,
|
||||
},
|
||||
|
||||
#[serde(rename = "content_block_delta")]
|
||||
ContentBlockDelta { index: usize, delta: ContentDelta },
|
||||
|
||||
#[serde(rename = "content_block_stop")]
|
||||
ContentBlockStop { index: usize },
|
||||
|
||||
#[serde(rename = "message_delta")]
|
||||
MessageDelta {
|
||||
delta: MessageDeltaInfo,
|
||||
usage: Option<UsageInfo>,
|
||||
},
|
||||
|
||||
#[serde(rename = "message_stop")]
|
||||
MessageStop,
|
||||
|
||||
#[serde(rename = "ping")]
|
||||
Ping,
|
||||
|
||||
#[serde(rename = "error")]
|
||||
Error { error: ApiError },
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Deserialize)]
|
||||
pub struct MessageStartInfo {
|
||||
pub id: String,
|
||||
#[serde(rename = "type")]
|
||||
pub message_type: String,
|
||||
pub role: String,
|
||||
pub model: String,
|
||||
pub usage: Option<UsageInfo>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Deserialize)]
|
||||
#[serde(tag = "type")]
|
||||
pub enum ContentBlockStartInfo {
|
||||
#[serde(rename = "text")]
|
||||
Text { text: String },
|
||||
|
||||
#[serde(rename = "tool_use")]
|
||||
ToolUse { id: String, name: String },
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Deserialize)]
|
||||
#[serde(tag = "type")]
|
||||
pub enum ContentDelta {
|
||||
#[serde(rename = "text_delta")]
|
||||
TextDelta { text: String },
|
||||
|
||||
#[serde(rename = "input_json_delta")]
|
||||
InputJsonDelta { partial_json: String },
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Deserialize)]
|
||||
pub struct MessageDeltaInfo {
|
||||
pub stop_reason: Option<String>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Deserialize)]
|
||||
pub struct ApiError {
|
||||
#[serde(rename = "type")]
|
||||
pub error_type: String,
|
||||
pub message: String,
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// Conversions
|
||||
// ============================================================================
|
||||
|
||||
impl From<&llm_core::Tool> for AnthropicTool {
|
||||
fn from(tool: &llm_core::Tool) -> Self {
|
||||
Self {
|
||||
name: tool.function.name.clone(),
|
||||
description: tool.function.description.clone(),
|
||||
input_schema: ToolInputSchema {
|
||||
schema_type: tool.function.parameters.param_type.clone(),
|
||||
properties: tool.function.parameters.properties.clone(),
|
||||
required: tool.function.parameters.required.clone(),
|
||||
},
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl From<&llm_core::ChatMessage> for AnthropicMessage {
|
||||
fn from(msg: &llm_core::ChatMessage) -> Self {
|
||||
use llm_core::Role;
|
||||
|
||||
let role = match msg.role {
|
||||
Role::User | Role::System => "user",
|
||||
Role::Assistant => "assistant",
|
||||
Role::Tool => "user", // Tool results come as user messages in Anthropic
|
||||
};
|
||||
|
||||
// Handle tool results
|
||||
if msg.role == Role::Tool {
|
||||
if let (Some(tool_call_id), Some(content)) = (&msg.tool_call_id, &msg.content) {
|
||||
return Self {
|
||||
role: "user".to_string(),
|
||||
content: AnthropicContent::Blocks(vec![ContentBlock::ToolResult {
|
||||
tool_use_id: tool_call_id.clone(),
|
||||
content: content.clone(),
|
||||
is_error: None,
|
||||
}]),
|
||||
};
|
||||
}
|
||||
}
|
||||
|
||||
// Handle assistant messages with tool calls
|
||||
if msg.role == Role::Assistant {
|
||||
if let Some(tool_calls) = &msg.tool_calls {
|
||||
let mut blocks: Vec<ContentBlock> = Vec::new();
|
||||
|
||||
// Add text content if present
|
||||
if let Some(text) = &msg.content {
|
||||
if !text.is_empty() {
|
||||
blocks.push(ContentBlock::Text { text: text.clone() });
|
||||
}
|
||||
}
|
||||
|
||||
// Add tool use blocks
|
||||
for call in tool_calls {
|
||||
blocks.push(ContentBlock::ToolUse {
|
||||
id: call.id.clone(),
|
||||
name: call.function.name.clone(),
|
||||
input: call.function.arguments.clone(),
|
||||
});
|
||||
}
|
||||
|
||||
return Self {
|
||||
role: "assistant".to_string(),
|
||||
content: AnthropicContent::Blocks(blocks),
|
||||
};
|
||||
}
|
||||
}
|
||||
|
||||
// Simple text message
|
||||
Self {
|
||||
role: role.to_string(),
|
||||
content: AnthropicContent::Text(msg.content.clone().unwrap_or_default()),
|
||||
}
|
||||
}
|
||||
}
|
||||
18
crates/llm/core/Cargo.toml
Normal file
18
crates/llm/core/Cargo.toml
Normal file
@@ -0,0 +1,18 @@
|
||||
[package]
|
||||
name = "llm-core"
|
||||
version = "0.1.0"
|
||||
edition.workspace = true
|
||||
license.workspace = true
|
||||
description = "LLM provider abstraction layer for Owlen"
|
||||
|
||||
[dependencies]
|
||||
async-trait = "0.1"
|
||||
futures = "0.3"
|
||||
rand = "0.8"
|
||||
serde = { version = "1.0", features = ["derive"] }
|
||||
serde_json = "1.0"
|
||||
thiserror = "2.0"
|
||||
tokio = { version = "1.0", features = ["time"] }
|
||||
|
||||
[dev-dependencies]
|
||||
tokio = { version = "1.0", features = ["macros", "rt"] }
|
||||
195
crates/llm/core/examples/token_counting.rs
Normal file
195
crates/llm/core/examples/token_counting.rs
Normal file
@@ -0,0 +1,195 @@
|
||||
//! Token counting example
|
||||
//!
|
||||
//! This example demonstrates how to use the token counting utilities
|
||||
//! to manage LLM context windows.
|
||||
//!
|
||||
//! Run with: cargo run --example token_counting -p llm-core
|
||||
|
||||
use llm_core::{
|
||||
ChatMessage, ClaudeTokenCounter, ContextWindow, SimpleTokenCounter, TokenCounter,
|
||||
};
|
||||
|
||||
fn main() {
|
||||
println!("=== Token Counting Example ===\n");
|
||||
|
||||
// Example 1: Basic token counting with SimpleTokenCounter
|
||||
println!("1. Basic Token Counting");
|
||||
println!("{}", "-".repeat(50));
|
||||
|
||||
let simple_counter = SimpleTokenCounter::new(8192);
|
||||
let text = "The quick brown fox jumps over the lazy dog.";
|
||||
|
||||
let token_count = simple_counter.count(text);
|
||||
println!("Text: \"{}\"", text);
|
||||
println!("Estimated tokens: {}", token_count);
|
||||
println!("Max context: {}\n", simple_counter.max_context());
|
||||
|
||||
// Example 2: Counting tokens in chat messages
|
||||
println!("2. Counting Tokens in Chat Messages");
|
||||
println!("{}", "-".repeat(50));
|
||||
|
||||
let messages = vec![
|
||||
ChatMessage::system("You are a helpful assistant that provides concise answers."),
|
||||
ChatMessage::user("What is the capital of France?"),
|
||||
ChatMessage::assistant("The capital of France is Paris."),
|
||||
ChatMessage::user("What is its population?"),
|
||||
];
|
||||
|
||||
let total_tokens = simple_counter.count_messages(&messages);
|
||||
println!("Number of messages: {}", messages.len());
|
||||
println!("Total tokens (with overhead): {}\n", total_tokens);
|
||||
|
||||
// Example 3: Using ClaudeTokenCounter for Claude models
|
||||
println!("3. Claude-Specific Token Counting");
|
||||
println!("{}", "-".repeat(50));
|
||||
|
||||
let claude_counter = ClaudeTokenCounter::new();
|
||||
let claude_total = claude_counter.count_messages(&messages);
|
||||
|
||||
println!("Claude counter max context: {}", claude_counter.max_context());
|
||||
println!("Claude estimated tokens: {}\n", claude_total);
|
||||
|
||||
// Example 4: Context window management
|
||||
println!("4. Context Window Management");
|
||||
println!("{}", "-".repeat(50));
|
||||
|
||||
let mut context = ContextWindow::new(8192);
|
||||
println!("Created context window with max: {} tokens", context.max());
|
||||
|
||||
// Simulate adding messages
|
||||
let conversation = vec![
|
||||
ChatMessage::user("Tell me about Rust programming."),
|
||||
ChatMessage::assistant(
|
||||
"Rust is a systems programming language focused on safety, \
|
||||
speed, and concurrency. It prevents common bugs like null pointer \
|
||||
dereferences and data races through its ownership system.",
|
||||
),
|
||||
ChatMessage::user("What are its main features?"),
|
||||
ChatMessage::assistant(
|
||||
"Rust's main features include: 1) Memory safety without garbage collection, \
|
||||
2) Zero-cost abstractions, 3) Fearless concurrency, 4) Pattern matching, \
|
||||
5) Type inference, and 6) A powerful macro system.",
|
||||
),
|
||||
];
|
||||
|
||||
for (i, msg) in conversation.iter().enumerate() {
|
||||
let tokens = simple_counter.count_messages(&[msg.clone()]);
|
||||
context.add_tokens(tokens);
|
||||
|
||||
let role = msg.role.as_str();
|
||||
let preview = msg
|
||||
.content
|
||||
.as_ref()
|
||||
.map(|c| {
|
||||
if c.len() > 50 {
|
||||
format!("{}...", &c[..50])
|
||||
} else {
|
||||
c.clone()
|
||||
}
|
||||
})
|
||||
.unwrap_or_default();
|
||||
|
||||
println!(
|
||||
"Message {}: [{}] \"{}\"",
|
||||
i + 1,
|
||||
role,
|
||||
preview
|
||||
);
|
||||
println!(" Added {} tokens", tokens);
|
||||
println!(" Total used: {} / {}", context.used(), context.max());
|
||||
println!(" Usage: {:.1}%", context.usage_percent() * 100.0);
|
||||
println!(" Progress: {}\n", context.progress_bar(30));
|
||||
}
|
||||
|
||||
// Example 5: Checking context limits
|
||||
println!("5. Checking Context Limits");
|
||||
println!("{}", "-".repeat(50));
|
||||
|
||||
if context.is_near_limit(0.8) {
|
||||
println!("Warning: Context is over 80% full!");
|
||||
} else {
|
||||
println!("Context usage is below 80%");
|
||||
}
|
||||
|
||||
let remaining = context.remaining();
|
||||
println!("Remaining tokens: {}", remaining);
|
||||
|
||||
let new_message_tokens = 500;
|
||||
if context.has_room_for(new_message_tokens) {
|
||||
println!(
|
||||
"Can fit a message of {} tokens",
|
||||
new_message_tokens
|
||||
);
|
||||
} else {
|
||||
println!(
|
||||
"Cannot fit a message of {} tokens - would need to compact or start new context",
|
||||
new_message_tokens
|
||||
);
|
||||
}
|
||||
|
||||
// Example 6: Different counter variants
|
||||
println!("\n6. Using Different Counter Variants");
|
||||
println!("{}", "-".repeat(50));
|
||||
|
||||
let counter_8k = SimpleTokenCounter::default_8k();
|
||||
let counter_32k = SimpleTokenCounter::with_32k();
|
||||
let counter_128k = SimpleTokenCounter::with_128k();
|
||||
|
||||
println!("8k context counter: {} tokens", counter_8k.max_context());
|
||||
println!("32k context counter: {} tokens", counter_32k.max_context());
|
||||
println!("128k context counter: {} tokens", counter_128k.max_context());
|
||||
|
||||
let haiku = ClaudeTokenCounter::haiku();
|
||||
let sonnet = ClaudeTokenCounter::sonnet();
|
||||
let opus = ClaudeTokenCounter::opus();
|
||||
|
||||
println!("\nClaude Haiku: {} tokens", haiku.max_context());
|
||||
println!("Claude Sonnet: {} tokens", sonnet.max_context());
|
||||
println!("Claude Opus: {} tokens", opus.max_context());
|
||||
|
||||
// Example 7: Managing context for a long conversation
|
||||
println!("\n7. Long Conversation Simulation");
|
||||
println!("{}", "-".repeat(50));
|
||||
|
||||
let mut long_context = ContextWindow::new(4096); // Smaller context for demo
|
||||
let counter = SimpleTokenCounter::new(4096);
|
||||
|
||||
let mut message_count = 0;
|
||||
let mut compaction_count = 0;
|
||||
|
||||
// Simulate 20 exchanges
|
||||
for i in 0..20 {
|
||||
let user_msg = ChatMessage::user(format!(
|
||||
"This is user message number {} asking a question.",
|
||||
i + 1
|
||||
));
|
||||
let assistant_msg = ChatMessage::assistant(format!(
|
||||
"This is assistant response number {} providing a detailed answer with multiple sentences to make it longer.",
|
||||
i + 1
|
||||
));
|
||||
|
||||
let tokens_needed = counter.count_messages(&[user_msg, assistant_msg]);
|
||||
|
||||
if !long_context.has_room_for(tokens_needed) {
|
||||
println!(
|
||||
"After {} messages, context is full ({}%). Compacting...",
|
||||
message_count,
|
||||
(long_context.usage_percent() * 100.0) as u32
|
||||
);
|
||||
// In a real scenario, we would compact the conversation
|
||||
// For now, just reset
|
||||
long_context.reset();
|
||||
compaction_count += 1;
|
||||
}
|
||||
|
||||
long_context.add_tokens(tokens_needed);
|
||||
message_count += 2;
|
||||
}
|
||||
|
||||
println!("Total messages: {}", message_count);
|
||||
println!("Compactions needed: {}", compaction_count);
|
||||
println!("Final context usage: {:.1}%", long_context.usage_percent() * 100.0);
|
||||
println!("Final progress: {}", long_context.progress_bar(40));
|
||||
|
||||
println!("\n=== Example Complete ===");
|
||||
}
|
||||
796
crates/llm/core/src/lib.rs
Normal file
796
crates/llm/core/src/lib.rs
Normal file
@@ -0,0 +1,796 @@
|
||||
//! LLM Provider Abstraction Layer
|
||||
//!
|
||||
//! This crate defines the common types and traits for LLM provider integration.
|
||||
//! Providers (Ollama, Anthropic Claude, OpenAI) implement the `LlmProvider` trait
|
||||
//! to enable swapping providers at runtime.
|
||||
|
||||
use async_trait::async_trait;
|
||||
use futures::Stream;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use serde_json::Value;
|
||||
use std::pin::Pin;
|
||||
use thiserror::Error;
|
||||
|
||||
// ============================================================================
|
||||
// Public Modules
|
||||
// ============================================================================
|
||||
|
||||
pub mod retry;
|
||||
pub mod tokens;
|
||||
|
||||
// Re-export token counting types for convenience
|
||||
pub use tokens::{ClaudeTokenCounter, ContextWindow, SimpleTokenCounter, TokenCounter};
|
||||
|
||||
// Re-export retry types for convenience
|
||||
pub use retry::{is_retryable_error, RetryConfig, RetryStrategy};
|
||||
|
||||
// ============================================================================
|
||||
// Error Types
|
||||
// ============================================================================
|
||||
|
||||
#[derive(Error, Debug)]
|
||||
pub enum LlmError {
|
||||
#[error("HTTP error: {0}")]
|
||||
Http(String),
|
||||
|
||||
#[error("JSON parsing error: {0}")]
|
||||
Json(String),
|
||||
|
||||
#[error("Authentication error: {0}")]
|
||||
Auth(String),
|
||||
|
||||
#[error("Rate limit exceeded: retry after {retry_after_secs:?} seconds")]
|
||||
RateLimit { retry_after_secs: Option<u64> },
|
||||
|
||||
#[error("API error: {message}")]
|
||||
Api { message: String, code: Option<String> },
|
||||
|
||||
#[error("Provider error: {0}")]
|
||||
Provider(String),
|
||||
|
||||
#[error("Stream error: {0}")]
|
||||
Stream(String),
|
||||
|
||||
#[error("Request timeout: {0}")]
|
||||
Timeout(String),
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// Message Types
|
||||
// ============================================================================
|
||||
|
||||
/// Role of a message in the conversation
|
||||
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
|
||||
#[serde(rename_all = "lowercase")]
|
||||
pub enum Role {
|
||||
System,
|
||||
User,
|
||||
Assistant,
|
||||
Tool,
|
||||
}
|
||||
|
||||
impl Role {
|
||||
pub fn as_str(&self) -> &'static str {
|
||||
match self {
|
||||
Role::System => "system",
|
||||
Role::User => "user",
|
||||
Role::Assistant => "assistant",
|
||||
Role::Tool => "tool",
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl From<&str> for Role {
|
||||
fn from(s: &str) -> Self {
|
||||
match s.to_lowercase().as_str() {
|
||||
"system" => Role::System,
|
||||
"user" => Role::User,
|
||||
"assistant" => Role::Assistant,
|
||||
"tool" => Role::Tool,
|
||||
_ => Role::User, // Default fallback
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// A message in the conversation
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct ChatMessage {
|
||||
pub role: Role,
|
||||
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub content: Option<String>,
|
||||
|
||||
/// Tool calls made by the assistant
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub tool_calls: Option<Vec<ToolCall>>,
|
||||
|
||||
/// For tool role messages: the ID of the tool call this responds to
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub tool_call_id: Option<String>,
|
||||
|
||||
/// For tool role messages: the name of the tool
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub name: Option<String>,
|
||||
}
|
||||
|
||||
impl ChatMessage {
|
||||
/// Create a system message
|
||||
pub fn system(content: impl Into<String>) -> Self {
|
||||
Self {
|
||||
role: Role::System,
|
||||
content: Some(content.into()),
|
||||
tool_calls: None,
|
||||
tool_call_id: None,
|
||||
name: None,
|
||||
}
|
||||
}
|
||||
|
||||
/// Create a user message
|
||||
pub fn user(content: impl Into<String>) -> Self {
|
||||
Self {
|
||||
role: Role::User,
|
||||
content: Some(content.into()),
|
||||
tool_calls: None,
|
||||
tool_call_id: None,
|
||||
name: None,
|
||||
}
|
||||
}
|
||||
|
||||
/// Create an assistant message
|
||||
pub fn assistant(content: impl Into<String>) -> Self {
|
||||
Self {
|
||||
role: Role::Assistant,
|
||||
content: Some(content.into()),
|
||||
tool_calls: None,
|
||||
tool_call_id: None,
|
||||
name: None,
|
||||
}
|
||||
}
|
||||
|
||||
/// Create an assistant message with tool calls (no text content)
|
||||
pub fn assistant_tool_calls(tool_calls: Vec<ToolCall>) -> Self {
|
||||
Self {
|
||||
role: Role::Assistant,
|
||||
content: None,
|
||||
tool_calls: Some(tool_calls),
|
||||
tool_call_id: None,
|
||||
name: None,
|
||||
}
|
||||
}
|
||||
|
||||
/// Create a tool result message
|
||||
pub fn tool_result(tool_call_id: impl Into<String>, content: impl Into<String>) -> Self {
|
||||
Self {
|
||||
role: Role::Tool,
|
||||
content: Some(content.into()),
|
||||
tool_calls: None,
|
||||
tool_call_id: Some(tool_call_id.into()),
|
||||
name: None,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// Tool Types
|
||||
// ============================================================================
|
||||
|
||||
/// A tool call requested by the LLM
|
||||
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
|
||||
pub struct ToolCall {
|
||||
/// Unique identifier for this tool call
|
||||
pub id: String,
|
||||
|
||||
/// The type of tool call (always "function" for now)
|
||||
#[serde(rename = "type", default = "default_function_type")]
|
||||
pub call_type: String,
|
||||
|
||||
/// The function being called
|
||||
pub function: FunctionCall,
|
||||
}
|
||||
|
||||
fn default_function_type() -> String {
|
||||
"function".to_string()
|
||||
}
|
||||
|
||||
/// Details of a function call
|
||||
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
|
||||
pub struct FunctionCall {
|
||||
/// Name of the function to call
|
||||
pub name: String,
|
||||
|
||||
/// Arguments as a JSON object
|
||||
pub arguments: Value,
|
||||
}
|
||||
|
||||
/// Definition of a tool available to the LLM
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct Tool {
|
||||
#[serde(rename = "type")]
|
||||
pub tool_type: String,
|
||||
|
||||
pub function: ToolFunction,
|
||||
}
|
||||
|
||||
impl Tool {
|
||||
/// Create a new function tool
|
||||
pub fn function(
|
||||
name: impl Into<String>,
|
||||
description: impl Into<String>,
|
||||
parameters: ToolParameters,
|
||||
) -> Self {
|
||||
Self {
|
||||
tool_type: "function".to_string(),
|
||||
function: ToolFunction {
|
||||
name: name.into(),
|
||||
description: description.into(),
|
||||
parameters,
|
||||
},
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Function definition within a tool
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct ToolFunction {
|
||||
pub name: String,
|
||||
pub description: String,
|
||||
pub parameters: ToolParameters,
|
||||
}
|
||||
|
||||
/// Parameters schema for a function
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct ToolParameters {
|
||||
#[serde(rename = "type")]
|
||||
pub param_type: String,
|
||||
|
||||
/// JSON Schema properties object
|
||||
pub properties: Value,
|
||||
|
||||
/// Required parameter names
|
||||
pub required: Vec<String>,
|
||||
}
|
||||
|
||||
impl ToolParameters {
|
||||
/// Create an object parameter schema
|
||||
pub fn object(properties: Value, required: Vec<String>) -> Self {
|
||||
Self {
|
||||
param_type: "object".to_string(),
|
||||
properties,
|
||||
required,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// Streaming Response Types
|
||||
// ============================================================================
|
||||
|
||||
/// A chunk of a streaming response
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct StreamChunk {
|
||||
/// Incremental text content
|
||||
pub content: Option<String>,
|
||||
|
||||
/// Tool calls (may be partial/streaming)
|
||||
pub tool_calls: Option<Vec<ToolCallDelta>>,
|
||||
|
||||
/// Whether this is the final chunk
|
||||
pub done: bool,
|
||||
|
||||
/// Usage statistics (typically only in final chunk)
|
||||
pub usage: Option<Usage>,
|
||||
}
|
||||
|
||||
/// Partial tool call for streaming
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct ToolCallDelta {
|
||||
/// Index of this tool call in the array
|
||||
pub index: usize,
|
||||
|
||||
/// Tool call ID (may only be present in first delta)
|
||||
pub id: Option<String>,
|
||||
|
||||
/// Function name (may only be present in first delta)
|
||||
pub function_name: Option<String>,
|
||||
|
||||
/// Incremental arguments string
|
||||
pub arguments_delta: Option<String>,
|
||||
}
|
||||
|
||||
/// Token usage statistics
|
||||
#[derive(Debug, Clone, Default)]
|
||||
pub struct Usage {
|
||||
pub prompt_tokens: u32,
|
||||
pub completion_tokens: u32,
|
||||
pub total_tokens: u32,
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// Provider Configuration
|
||||
// ============================================================================
|
||||
|
||||
/// Options for a chat request
|
||||
#[derive(Debug, Clone, Default)]
|
||||
pub struct ChatOptions {
|
||||
/// Model to use
|
||||
pub model: String,
|
||||
|
||||
/// Temperature (0.0 - 2.0)
|
||||
pub temperature: Option<f32>,
|
||||
|
||||
/// Maximum tokens to generate
|
||||
pub max_tokens: Option<u32>,
|
||||
|
||||
/// Top-p sampling
|
||||
pub top_p: Option<f32>,
|
||||
|
||||
/// Stop sequences
|
||||
pub stop: Option<Vec<String>>,
|
||||
}
|
||||
|
||||
impl ChatOptions {
|
||||
pub fn new(model: impl Into<String>) -> Self {
|
||||
Self {
|
||||
model: model.into(),
|
||||
..Default::default()
|
||||
}
|
||||
}
|
||||
|
||||
pub fn with_temperature(mut self, temp: f32) -> Self {
|
||||
self.temperature = Some(temp);
|
||||
self
|
||||
}
|
||||
|
||||
pub fn with_max_tokens(mut self, max: u32) -> Self {
|
||||
self.max_tokens = Some(max);
|
||||
self
|
||||
}
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// Provider Trait
|
||||
// ============================================================================
|
||||
|
||||
/// A boxed stream of chunks
|
||||
pub type ChunkStream = Pin<Box<dyn Stream<Item = Result<StreamChunk, LlmError>> + Send>>;
|
||||
|
||||
/// The main trait that all LLM providers must implement
|
||||
#[async_trait]
|
||||
pub trait LlmProvider: Send + Sync {
|
||||
/// Get the provider name (e.g., "ollama", "anthropic", "openai")
|
||||
fn name(&self) -> &str;
|
||||
|
||||
/// Get the current model name
|
||||
fn model(&self) -> &str;
|
||||
|
||||
/// Send a chat request and receive a streaming response
|
||||
///
|
||||
/// # Arguments
|
||||
/// * `messages` - The conversation history
|
||||
/// * `options` - Request options (model, temperature, etc.)
|
||||
/// * `tools` - Optional list of tools the model can use
|
||||
///
|
||||
/// # Returns
|
||||
/// A stream of response chunks
|
||||
async fn chat_stream(
|
||||
&self,
|
||||
messages: &[ChatMessage],
|
||||
options: &ChatOptions,
|
||||
tools: Option<&[Tool]>,
|
||||
) -> Result<ChunkStream, LlmError>;
|
||||
|
||||
/// Send a chat request and receive a complete response (non-streaming)
|
||||
///
|
||||
/// Default implementation collects the stream, but providers may override
|
||||
/// for efficiency.
|
||||
async fn chat(
|
||||
&self,
|
||||
messages: &[ChatMessage],
|
||||
options: &ChatOptions,
|
||||
tools: Option<&[Tool]>,
|
||||
) -> Result<ChatResponse, LlmError> {
|
||||
use futures::StreamExt;
|
||||
|
||||
let mut stream = self.chat_stream(messages, options, tools).await?;
|
||||
let mut content = String::new();
|
||||
let mut tool_calls: Vec<PartialToolCall> = Vec::new();
|
||||
let mut usage = None;
|
||||
|
||||
while let Some(chunk) = stream.next().await {
|
||||
let chunk = chunk?;
|
||||
|
||||
if let Some(text) = chunk.content {
|
||||
content.push_str(&text);
|
||||
}
|
||||
|
||||
if let Some(deltas) = chunk.tool_calls {
|
||||
for delta in deltas {
|
||||
// Grow the tool_calls vec if needed
|
||||
while tool_calls.len() <= delta.index {
|
||||
tool_calls.push(PartialToolCall::default());
|
||||
}
|
||||
|
||||
let partial = &mut tool_calls[delta.index];
|
||||
if let Some(id) = delta.id {
|
||||
partial.id = Some(id);
|
||||
}
|
||||
if let Some(name) = delta.function_name {
|
||||
partial.function_name = Some(name);
|
||||
}
|
||||
if let Some(args) = delta.arguments_delta {
|
||||
partial.arguments.push_str(&args);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if chunk.usage.is_some() {
|
||||
usage = chunk.usage;
|
||||
}
|
||||
}
|
||||
|
||||
// Convert partial tool calls to complete tool calls
|
||||
let final_tool_calls: Vec<ToolCall> = tool_calls
|
||||
.into_iter()
|
||||
.filter_map(|p| p.try_into_tool_call())
|
||||
.collect();
|
||||
|
||||
Ok(ChatResponse {
|
||||
content: if content.is_empty() {
|
||||
None
|
||||
} else {
|
||||
Some(content)
|
||||
},
|
||||
tool_calls: if final_tool_calls.is_empty() {
|
||||
None
|
||||
} else {
|
||||
Some(final_tool_calls)
|
||||
},
|
||||
usage,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
/// A complete chat response (non-streaming)
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct ChatResponse {
|
||||
pub content: Option<String>,
|
||||
pub tool_calls: Option<Vec<ToolCall>>,
|
||||
pub usage: Option<Usage>,
|
||||
}
|
||||
|
||||
/// Helper for accumulating streaming tool calls
|
||||
#[derive(Default)]
|
||||
struct PartialToolCall {
|
||||
id: Option<String>,
|
||||
function_name: Option<String>,
|
||||
arguments: String,
|
||||
}
|
||||
|
||||
impl PartialToolCall {
|
||||
fn try_into_tool_call(self) -> Option<ToolCall> {
|
||||
let id = self.id?;
|
||||
let name = self.function_name?;
|
||||
let arguments: Value = serde_json::from_str(&self.arguments).ok()?;
|
||||
|
||||
Some(ToolCall {
|
||||
id,
|
||||
call_type: "function".to_string(),
|
||||
function: FunctionCall { name, arguments },
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// Authentication
|
||||
// ============================================================================
|
||||
|
||||
/// Authentication method for LLM providers
|
||||
#[derive(Debug, Clone)]
|
||||
pub enum AuthMethod {
|
||||
/// No authentication (for local providers like Ollama)
|
||||
None,
|
||||
|
||||
/// API key authentication
|
||||
ApiKey(String),
|
||||
|
||||
/// OAuth access token (from login flow)
|
||||
OAuth {
|
||||
access_token: String,
|
||||
refresh_token: Option<String>,
|
||||
expires_at: Option<u64>,
|
||||
},
|
||||
}
|
||||
|
||||
impl AuthMethod {
|
||||
/// Create API key auth
|
||||
pub fn api_key(key: impl Into<String>) -> Self {
|
||||
Self::ApiKey(key.into())
|
||||
}
|
||||
|
||||
/// Create OAuth auth from tokens
|
||||
pub fn oauth(access_token: impl Into<String>) -> Self {
|
||||
Self::OAuth {
|
||||
access_token: access_token.into(),
|
||||
refresh_token: None,
|
||||
expires_at: None,
|
||||
}
|
||||
}
|
||||
|
||||
/// Create OAuth auth with refresh token
|
||||
pub fn oauth_with_refresh(
|
||||
access_token: impl Into<String>,
|
||||
refresh_token: impl Into<String>,
|
||||
expires_at: Option<u64>,
|
||||
) -> Self {
|
||||
Self::OAuth {
|
||||
access_token: access_token.into(),
|
||||
refresh_token: Some(refresh_token.into()),
|
||||
expires_at,
|
||||
}
|
||||
}
|
||||
|
||||
/// Get the bearer token for Authorization header
|
||||
pub fn bearer_token(&self) -> Option<&str> {
|
||||
match self {
|
||||
Self::None => None,
|
||||
Self::ApiKey(key) => Some(key),
|
||||
Self::OAuth { access_token, .. } => Some(access_token),
|
||||
}
|
||||
}
|
||||
|
||||
/// Check if token might need refresh
|
||||
pub fn needs_refresh(&self) -> bool {
|
||||
match self {
|
||||
Self::OAuth {
|
||||
expires_at: Some(exp),
|
||||
refresh_token: Some(_),
|
||||
..
|
||||
} => {
|
||||
let now = std::time::SystemTime::now()
|
||||
.duration_since(std::time::UNIX_EPOCH)
|
||||
.map(|d| d.as_secs())
|
||||
.unwrap_or(0);
|
||||
// Refresh if expiring within 5 minutes
|
||||
*exp < now + 300
|
||||
}
|
||||
_ => false,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Device code response for OAuth device flow
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct DeviceCodeResponse {
|
||||
/// Code the user enters on the verification page
|
||||
pub user_code: String,
|
||||
|
||||
/// URL the user visits to authorize
|
||||
pub verification_uri: String,
|
||||
|
||||
/// Full URL with code pre-filled (if supported)
|
||||
pub verification_uri_complete: Option<String>,
|
||||
|
||||
/// Device code for polling (internal use)
|
||||
pub device_code: String,
|
||||
|
||||
/// How often to poll (in seconds)
|
||||
pub interval: u64,
|
||||
|
||||
/// When the codes expire (in seconds)
|
||||
pub expires_in: u64,
|
||||
}
|
||||
|
||||
/// Result of polling for device authorization
|
||||
#[derive(Debug, Clone)]
|
||||
pub enum DeviceAuthResult {
|
||||
/// Still waiting for user to authorize
|
||||
Pending,
|
||||
|
||||
/// User authorized, here are the tokens
|
||||
Success {
|
||||
access_token: String,
|
||||
refresh_token: Option<String>,
|
||||
expires_in: Option<u64>,
|
||||
},
|
||||
|
||||
/// User denied authorization
|
||||
Denied,
|
||||
|
||||
/// Code expired
|
||||
Expired,
|
||||
}
|
||||
|
||||
/// Trait for providers that support OAuth device flow
|
||||
#[async_trait]
|
||||
pub trait OAuthProvider {
|
||||
/// Start the device authorization flow
|
||||
async fn start_device_auth(&self) -> Result<DeviceCodeResponse, LlmError>;
|
||||
|
||||
/// Poll for the authorization result
|
||||
async fn poll_device_auth(&self, device_code: &str) -> Result<DeviceAuthResult, LlmError>;
|
||||
|
||||
/// Refresh an access token using a refresh token
|
||||
async fn refresh_token(&self, refresh_token: &str) -> Result<AuthMethod, LlmError>;
|
||||
}
|
||||
|
||||
/// Stored credentials for a provider
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct StoredCredentials {
|
||||
pub provider: String,
|
||||
pub access_token: String,
|
||||
pub refresh_token: Option<String>,
|
||||
pub expires_at: Option<u64>,
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// Provider Status & Info
|
||||
// ============================================================================
|
||||
|
||||
/// Status information for a provider connection
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct ProviderStatus {
|
||||
/// Provider name
|
||||
pub provider: String,
|
||||
|
||||
/// Whether the connection is authenticated
|
||||
pub authenticated: bool,
|
||||
|
||||
/// Current user/account info if authenticated
|
||||
pub account: Option<AccountInfo>,
|
||||
|
||||
/// Current model being used
|
||||
pub model: String,
|
||||
|
||||
/// API endpoint URL
|
||||
pub endpoint: String,
|
||||
|
||||
/// Whether the provider is reachable
|
||||
pub reachable: bool,
|
||||
|
||||
/// Any status message or error
|
||||
pub message: Option<String>,
|
||||
}
|
||||
|
||||
/// Account/user information from the provider
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct AccountInfo {
|
||||
/// Account/user ID
|
||||
pub id: Option<String>,
|
||||
|
||||
/// Display name or email
|
||||
pub name: Option<String>,
|
||||
|
||||
/// Account email
|
||||
pub email: Option<String>,
|
||||
|
||||
/// Account type (free, pro, team, enterprise)
|
||||
pub account_type: Option<String>,
|
||||
|
||||
/// Organization name if applicable
|
||||
pub organization: Option<String>,
|
||||
}
|
||||
|
||||
/// Usage statistics from the provider
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct UsageStats {
|
||||
/// Total tokens used in current period
|
||||
pub tokens_used: Option<u64>,
|
||||
|
||||
/// Token limit for current period (if applicable)
|
||||
pub token_limit: Option<u64>,
|
||||
|
||||
/// Number of requests made
|
||||
pub requests_made: Option<u64>,
|
||||
|
||||
/// Request limit (if applicable)
|
||||
pub request_limit: Option<u64>,
|
||||
|
||||
/// Cost incurred (if available)
|
||||
pub cost_usd: Option<f64>,
|
||||
|
||||
/// Period start timestamp
|
||||
pub period_start: Option<u64>,
|
||||
|
||||
/// Period end timestamp
|
||||
pub period_end: Option<u64>,
|
||||
}
|
||||
|
||||
/// Available model information
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct ModelInfo {
|
||||
/// Model ID/name
|
||||
pub id: String,
|
||||
|
||||
/// Human-readable display name
|
||||
pub display_name: Option<String>,
|
||||
|
||||
/// Model description
|
||||
pub description: Option<String>,
|
||||
|
||||
/// Context window size (tokens)
|
||||
pub context_window: Option<u32>,
|
||||
|
||||
/// Max output tokens
|
||||
pub max_output_tokens: Option<u32>,
|
||||
|
||||
/// Whether the model supports tool use
|
||||
pub supports_tools: bool,
|
||||
|
||||
/// Whether the model supports vision/images
|
||||
pub supports_vision: bool,
|
||||
|
||||
/// Input token price per 1M tokens (USD)
|
||||
pub input_price_per_mtok: Option<f64>,
|
||||
|
||||
/// Output token price per 1M tokens (USD)
|
||||
pub output_price_per_mtok: Option<f64>,
|
||||
}
|
||||
|
||||
/// Trait for providers that support status/info queries
|
||||
#[async_trait]
|
||||
pub trait ProviderInfo {
|
||||
/// Get the current connection status
|
||||
async fn status(&self) -> Result<ProviderStatus, LlmError>;
|
||||
|
||||
/// Get account information (if authenticated)
|
||||
async fn account_info(&self) -> Result<Option<AccountInfo>, LlmError>;
|
||||
|
||||
/// Get usage statistics (if available)
|
||||
async fn usage_stats(&self) -> Result<Option<UsageStats>, LlmError>;
|
||||
|
||||
/// List available models
|
||||
async fn list_models(&self) -> Result<Vec<ModelInfo>, LlmError>;
|
||||
|
||||
/// Check if a specific model is available
|
||||
async fn model_info(&self, model_id: &str) -> Result<Option<ModelInfo>, LlmError> {
|
||||
let models = self.list_models().await?;
|
||||
Ok(models.into_iter().find(|m| m.id == model_id))
|
||||
}
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// Provider Factory
|
||||
// ============================================================================
|
||||
|
||||
/// Supported LLM providers
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
|
||||
#[serde(rename_all = "lowercase")]
|
||||
pub enum ProviderType {
|
||||
Ollama,
|
||||
Anthropic,
|
||||
OpenAI,
|
||||
}
|
||||
|
||||
impl ProviderType {
|
||||
pub fn from_str(s: &str) -> Option<Self> {
|
||||
match s.to_lowercase().as_str() {
|
||||
"ollama" => Some(Self::Ollama),
|
||||
"anthropic" | "claude" => Some(Self::Anthropic),
|
||||
"openai" | "gpt" => Some(Self::OpenAI),
|
||||
_ => None,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn as_str(&self) -> &'static str {
|
||||
match self {
|
||||
Self::Ollama => "ollama",
|
||||
Self::Anthropic => "anthropic",
|
||||
Self::OpenAI => "openai",
|
||||
}
|
||||
}
|
||||
|
||||
/// Default model for this provider
|
||||
pub fn default_model(&self) -> &'static str {
|
||||
match self {
|
||||
Self::Ollama => "qwen3:8b",
|
||||
Self::Anthropic => "claude-sonnet-4-20250514",
|
||||
Self::OpenAI => "gpt-4o",
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl std::fmt::Display for ProviderType {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
write!(f, "{}", self.as_str())
|
||||
}
|
||||
}
|
||||
386
crates/llm/core/src/retry.rs
Normal file
386
crates/llm/core/src/retry.rs
Normal file
@@ -0,0 +1,386 @@
|
||||
//! Error recovery and retry logic for LLM operations
|
||||
//!
|
||||
//! This module provides configurable retry strategies with exponential backoff
|
||||
//! for handling transient failures when communicating with LLM providers.
|
||||
|
||||
use crate::LlmError;
|
||||
use rand::Rng;
|
||||
use std::time::Duration;
|
||||
|
||||
/// Configuration for retry behavior
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct RetryConfig {
|
||||
/// Maximum number of retry attempts
|
||||
pub max_retries: u32,
|
||||
/// Initial delay before first retry (in milliseconds)
|
||||
pub initial_delay_ms: u64,
|
||||
/// Maximum delay between retries (in milliseconds)
|
||||
pub max_delay_ms: u64,
|
||||
/// Multiplier for exponential backoff
|
||||
pub backoff_multiplier: f32,
|
||||
}
|
||||
|
||||
impl Default for RetryConfig {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
max_retries: 3,
|
||||
initial_delay_ms: 1000,
|
||||
max_delay_ms: 30000,
|
||||
backoff_multiplier: 2.0,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl RetryConfig {
|
||||
/// Create a new retry configuration with custom values
|
||||
pub fn new(
|
||||
max_retries: u32,
|
||||
initial_delay_ms: u64,
|
||||
max_delay_ms: u64,
|
||||
backoff_multiplier: f32,
|
||||
) -> Self {
|
||||
Self {
|
||||
max_retries,
|
||||
initial_delay_ms,
|
||||
max_delay_ms,
|
||||
backoff_multiplier,
|
||||
}
|
||||
}
|
||||
|
||||
/// Create a configuration with no retries
|
||||
pub fn no_retry() -> Self {
|
||||
Self {
|
||||
max_retries: 0,
|
||||
initial_delay_ms: 0,
|
||||
max_delay_ms: 0,
|
||||
backoff_multiplier: 1.0,
|
||||
}
|
||||
}
|
||||
|
||||
/// Create a configuration with aggressive retries for rate-limited scenarios
|
||||
pub fn aggressive() -> Self {
|
||||
Self {
|
||||
max_retries: 5,
|
||||
initial_delay_ms: 2000,
|
||||
max_delay_ms: 60000,
|
||||
backoff_multiplier: 2.5,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Determines whether an error is retryable
|
||||
///
|
||||
/// # Arguments
|
||||
/// * `error` - The error to check
|
||||
///
|
||||
/// # Returns
|
||||
/// `true` if the error is transient and the operation should be retried,
|
||||
/// `false` if the error is permanent and retrying won't help
|
||||
pub fn is_retryable_error(error: &LlmError) -> bool {
|
||||
match error {
|
||||
// Always retry rate limits
|
||||
LlmError::RateLimit { .. } => true,
|
||||
|
||||
// Always retry timeouts
|
||||
LlmError::Timeout(_) => true,
|
||||
|
||||
// Retry HTTP errors that are server-side (5xx)
|
||||
LlmError::Http(msg) => {
|
||||
// Check if the error message contains a 5xx status code
|
||||
msg.contains("500")
|
||||
|| msg.contains("502")
|
||||
|| msg.contains("503")
|
||||
|| msg.contains("504")
|
||||
|| msg.contains("Internal Server Error")
|
||||
|| msg.contains("Bad Gateway")
|
||||
|| msg.contains("Service Unavailable")
|
||||
|| msg.contains("Gateway Timeout")
|
||||
}
|
||||
|
||||
// Don't retry authentication errors - they need user intervention
|
||||
LlmError::Auth(_) => false,
|
||||
|
||||
// Don't retry JSON parsing errors - the data is malformed
|
||||
LlmError::Json(_) => false,
|
||||
|
||||
// Don't retry API errors - these are typically client-side issues
|
||||
LlmError::Api { .. } => false,
|
||||
|
||||
// Provider errors might be transient, but we conservatively don't retry
|
||||
LlmError::Provider(_) => false,
|
||||
|
||||
// Stream errors are typically not retryable
|
||||
LlmError::Stream(_) => false,
|
||||
}
|
||||
}
|
||||
|
||||
/// Strategy for retrying failed operations with exponential backoff
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct RetryStrategy {
|
||||
config: RetryConfig,
|
||||
}
|
||||
|
||||
impl RetryStrategy {
|
||||
/// Create a new retry strategy with the given configuration
|
||||
pub fn new(config: RetryConfig) -> Self {
|
||||
Self { config }
|
||||
}
|
||||
|
||||
/// Create a retry strategy with default configuration
|
||||
pub fn default_config() -> Self {
|
||||
Self::new(RetryConfig::default())
|
||||
}
|
||||
|
||||
/// Execute an async operation with retries
|
||||
///
|
||||
/// # Arguments
|
||||
/// * `operation` - A function that returns a Future producing a Result
|
||||
///
|
||||
/// # Returns
|
||||
/// The result of the operation, or the last error if all retries fail
|
||||
///
|
||||
/// # Example
|
||||
/// ```ignore
|
||||
/// let strategy = RetryStrategy::default_config();
|
||||
/// let result = strategy.execute(|| async {
|
||||
/// // Your LLM API call here
|
||||
/// llm_client.chat(&messages, &options, None).await
|
||||
/// }).await?;
|
||||
/// ```
|
||||
pub async fn execute<F, T, Fut>(&self, operation: F) -> Result<T, LlmError>
|
||||
where
|
||||
F: Fn() -> Fut,
|
||||
Fut: std::future::Future<Output = Result<T, LlmError>>,
|
||||
{
|
||||
let mut attempt = 0;
|
||||
|
||||
loop {
|
||||
// Try the operation
|
||||
match operation().await {
|
||||
Ok(result) => return Ok(result),
|
||||
Err(err) => {
|
||||
// Check if we should retry
|
||||
if !is_retryable_error(&err) {
|
||||
return Err(err);
|
||||
}
|
||||
|
||||
attempt += 1;
|
||||
|
||||
// Check if we've exhausted retries
|
||||
if attempt > self.config.max_retries {
|
||||
return Err(err);
|
||||
}
|
||||
|
||||
// Calculate delay with exponential backoff and jitter
|
||||
let delay = self.delay_for_attempt(attempt);
|
||||
|
||||
// Log retry attempt (in a real implementation, you might use tracing)
|
||||
eprintln!(
|
||||
"Retry attempt {}/{} after {:?}",
|
||||
attempt, self.config.max_retries, delay
|
||||
);
|
||||
|
||||
// Sleep before next attempt
|
||||
tokio::time::sleep(delay).await;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Calculate the delay for a given attempt number with jitter
|
||||
///
|
||||
/// Uses exponential backoff: delay = initial_delay * (backoff_multiplier ^ (attempt - 1))
|
||||
/// Adds random jitter of ±10% to prevent thundering herd problems
|
||||
///
|
||||
/// # Arguments
|
||||
/// * `attempt` - The attempt number (1-indexed)
|
||||
///
|
||||
/// # Returns
|
||||
/// The delay duration to wait before the next retry
|
||||
fn delay_for_attempt(&self, attempt: u32) -> Duration {
|
||||
// Calculate base delay with exponential backoff
|
||||
let base_delay_ms = self.config.initial_delay_ms as f64
|
||||
* self.config.backoff_multiplier.powi((attempt - 1) as i32) as f64;
|
||||
|
||||
// Cap at max_delay_ms
|
||||
let capped_delay_ms = base_delay_ms.min(self.config.max_delay_ms as f64);
|
||||
|
||||
// Add jitter: ±10%
|
||||
let mut rng = rand::thread_rng();
|
||||
let jitter_factor = rng.gen_range(0.9..=1.1);
|
||||
let final_delay_ms = capped_delay_ms * jitter_factor;
|
||||
|
||||
Duration::from_millis(final_delay_ms as u64)
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use std::sync::atomic::{AtomicU32, Ordering};
|
||||
use std::sync::Arc;
|
||||
|
||||
#[test]
|
||||
fn test_default_retry_config() {
|
||||
let config = RetryConfig::default();
|
||||
assert_eq!(config.max_retries, 3);
|
||||
assert_eq!(config.initial_delay_ms, 1000);
|
||||
assert_eq!(config.max_delay_ms, 30000);
|
||||
assert_eq!(config.backoff_multiplier, 2.0);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_no_retry_config() {
|
||||
let config = RetryConfig::no_retry();
|
||||
assert_eq!(config.max_retries, 0);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_is_retryable_error() {
|
||||
// Retryable errors
|
||||
assert!(is_retryable_error(&LlmError::RateLimit {
|
||||
retry_after_secs: Some(60)
|
||||
}));
|
||||
assert!(is_retryable_error(&LlmError::Timeout(
|
||||
"Request timed out".to_string()
|
||||
)));
|
||||
assert!(is_retryable_error(&LlmError::Http(
|
||||
"500 Internal Server Error".to_string()
|
||||
)));
|
||||
assert!(is_retryable_error(&LlmError::Http(
|
||||
"503 Service Unavailable".to_string()
|
||||
)));
|
||||
|
||||
// Non-retryable errors
|
||||
assert!(!is_retryable_error(&LlmError::Auth(
|
||||
"Invalid API key".to_string()
|
||||
)));
|
||||
assert!(!is_retryable_error(&LlmError::Json(
|
||||
"Invalid JSON".to_string()
|
||||
)));
|
||||
assert!(!is_retryable_error(&LlmError::Api {
|
||||
message: "Invalid request".to_string(),
|
||||
code: Some("400".to_string())
|
||||
}));
|
||||
assert!(!is_retryable_error(&LlmError::Http(
|
||||
"400 Bad Request".to_string()
|
||||
)));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_delay_calculation() {
|
||||
let config = RetryConfig::default();
|
||||
let strategy = RetryStrategy::new(config);
|
||||
|
||||
// Test that delays increase exponentially
|
||||
let delay1 = strategy.delay_for_attempt(1);
|
||||
let delay2 = strategy.delay_for_attempt(2);
|
||||
let delay3 = strategy.delay_for_attempt(3);
|
||||
|
||||
// Base delays should be around 1000ms, 2000ms, 4000ms (with jitter)
|
||||
assert!(delay1.as_millis() >= 900 && delay1.as_millis() <= 1100);
|
||||
assert!(delay2.as_millis() >= 1800 && delay2.as_millis() <= 2200);
|
||||
assert!(delay3.as_millis() >= 3600 && delay3.as_millis() <= 4400);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_delay_max_cap() {
|
||||
let config = RetryConfig {
|
||||
max_retries: 10,
|
||||
initial_delay_ms: 1000,
|
||||
max_delay_ms: 5000,
|
||||
backoff_multiplier: 2.0,
|
||||
};
|
||||
let strategy = RetryStrategy::new(config);
|
||||
|
||||
// Even with high attempt numbers, delay should be capped
|
||||
let delay = strategy.delay_for_attempt(10);
|
||||
assert!(delay.as_millis() <= 5500); // max + jitter
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_retry_success_on_first_attempt() {
|
||||
let strategy = RetryStrategy::default_config();
|
||||
let call_count = Arc::new(AtomicU32::new(0));
|
||||
let count_clone = call_count.clone();
|
||||
|
||||
let result = strategy
|
||||
.execute(|| {
|
||||
let count = count_clone.clone();
|
||||
async move {
|
||||
count.fetch_add(1, Ordering::SeqCst);
|
||||
Ok::<_, LlmError>(42)
|
||||
}
|
||||
})
|
||||
.await;
|
||||
|
||||
assert_eq!(result.unwrap(), 42);
|
||||
assert_eq!(call_count.load(Ordering::SeqCst), 1);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_retry_success_after_retries() {
|
||||
let config = RetryConfig::new(3, 10, 100, 2.0); // Fast retries for testing
|
||||
let strategy = RetryStrategy::new(config);
|
||||
let call_count = Arc::new(AtomicU32::new(0));
|
||||
let count_clone = call_count.clone();
|
||||
|
||||
let result = strategy
|
||||
.execute(|| {
|
||||
let count = count_clone.clone();
|
||||
async move {
|
||||
let current = count.fetch_add(1, Ordering::SeqCst) + 1;
|
||||
if current < 3 {
|
||||
Err(LlmError::Timeout("Timeout".to_string()))
|
||||
} else {
|
||||
Ok(42)
|
||||
}
|
||||
}
|
||||
})
|
||||
.await;
|
||||
|
||||
assert_eq!(result.unwrap(), 42);
|
||||
assert_eq!(call_count.load(Ordering::SeqCst), 3);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_retry_exhausted() {
|
||||
let config = RetryConfig::new(2, 10, 100, 2.0); // Fast retries for testing
|
||||
let strategy = RetryStrategy::new(config);
|
||||
let call_count = Arc::new(AtomicU32::new(0));
|
||||
let count_clone = call_count.clone();
|
||||
|
||||
let result = strategy
|
||||
.execute(|| {
|
||||
let count = count_clone.clone();
|
||||
async move {
|
||||
count.fetch_add(1, Ordering::SeqCst);
|
||||
Err::<(), _>(LlmError::Timeout("Always fails".to_string()))
|
||||
}
|
||||
})
|
||||
.await;
|
||||
|
||||
assert!(result.is_err());
|
||||
assert_eq!(call_count.load(Ordering::SeqCst), 3); // Initial attempt + 2 retries
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_non_retryable_error() {
|
||||
let strategy = RetryStrategy::default_config();
|
||||
let call_count = Arc::new(AtomicU32::new(0));
|
||||
let count_clone = call_count.clone();
|
||||
|
||||
let result = strategy
|
||||
.execute(|| {
|
||||
let count = count_clone.clone();
|
||||
async move {
|
||||
count.fetch_add(1, Ordering::SeqCst);
|
||||
Err::<(), _>(LlmError::Auth("Invalid API key".to_string()))
|
||||
}
|
||||
})
|
||||
.await;
|
||||
|
||||
assert!(result.is_err());
|
||||
assert_eq!(call_count.load(Ordering::SeqCst), 1); // Should not retry
|
||||
}
|
||||
}
|
||||
607
crates/llm/core/src/tokens.rs
Normal file
607
crates/llm/core/src/tokens.rs
Normal file
@@ -0,0 +1,607 @@
|
||||
//! Token counting utilities for LLM context management
|
||||
//!
|
||||
//! This module provides token counting abstractions and implementations for
|
||||
//! managing LLM context windows. Token counters estimate token usage without
|
||||
//! requiring external tokenization libraries, using heuristic-based approaches.
|
||||
|
||||
use crate::ChatMessage;
|
||||
|
||||
// ============================================================================
|
||||
// TokenCounter Trait
|
||||
// ============================================================================
|
||||
|
||||
/// Trait for counting tokens in text and chat messages
|
||||
///
|
||||
/// Implementations provide model-specific token counting logic to help
|
||||
/// manage context windows and estimate API costs.
|
||||
pub trait TokenCounter: Send + Sync {
|
||||
/// Count tokens in a string
|
||||
///
|
||||
/// # Arguments
|
||||
/// * `text` - The text to count tokens for
|
||||
///
|
||||
/// # Returns
|
||||
/// Estimated number of tokens
|
||||
fn count(&self, text: &str) -> usize;
|
||||
|
||||
/// Count tokens in chat messages
|
||||
///
|
||||
/// This accounts for both the message content and the overhead
|
||||
/// from the chat message structure (roles, delimiters, etc.).
|
||||
///
|
||||
/// # Arguments
|
||||
/// * `messages` - The messages to count tokens for
|
||||
///
|
||||
/// # Returns
|
||||
/// Estimated total tokens including message structure overhead
|
||||
fn count_messages(&self, messages: &[ChatMessage]) -> usize;
|
||||
|
||||
/// Get the model's max context window size
|
||||
///
|
||||
/// # Returns
|
||||
/// Maximum number of tokens the model can handle
|
||||
fn max_context(&self) -> usize;
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// SimpleTokenCounter
|
||||
// ============================================================================
|
||||
|
||||
/// A basic token counter using simple heuristics
|
||||
///
|
||||
/// This counter uses the rule of thumb that English text averages about
|
||||
/// 4 characters per token. It adds overhead for message structure.
|
||||
///
|
||||
/// # Example
|
||||
/// ```
|
||||
/// use llm_core::tokens::{TokenCounter, SimpleTokenCounter};
|
||||
/// use llm_core::ChatMessage;
|
||||
///
|
||||
/// let counter = SimpleTokenCounter::new(8192);
|
||||
/// let text = "Hello, world!";
|
||||
/// let tokens = counter.count(text);
|
||||
/// assert!(tokens > 0);
|
||||
///
|
||||
/// let messages = vec![
|
||||
/// ChatMessage::user("What is the weather?"),
|
||||
/// ChatMessage::assistant("I don't have access to weather data."),
|
||||
/// ];
|
||||
/// let total = counter.count_messages(&messages);
|
||||
/// assert!(total > 0);
|
||||
/// ```
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct SimpleTokenCounter {
|
||||
max_context: usize,
|
||||
}
|
||||
|
||||
impl SimpleTokenCounter {
|
||||
/// Create a new simple token counter
|
||||
///
|
||||
/// # Arguments
|
||||
/// * `max_context` - Maximum context window size for the model
|
||||
pub fn new(max_context: usize) -> Self {
|
||||
Self { max_context }
|
||||
}
|
||||
|
||||
/// Create a token counter with a default 8192 token context
|
||||
pub fn default_8k() -> Self {
|
||||
Self::new(8192)
|
||||
}
|
||||
|
||||
/// Create a token counter with a 32k token context
|
||||
pub fn with_32k() -> Self {
|
||||
Self::new(32768)
|
||||
}
|
||||
|
||||
/// Create a token counter with a 128k token context
|
||||
pub fn with_128k() -> Self {
|
||||
Self::new(131072)
|
||||
}
|
||||
}
|
||||
|
||||
impl TokenCounter for SimpleTokenCounter {
|
||||
fn count(&self, text: &str) -> usize {
|
||||
// Estimate: approximately 4 characters per token for English
|
||||
// Add 3 before dividing to round up
|
||||
(text.len() + 3) / 4
|
||||
}
|
||||
|
||||
fn count_messages(&self, messages: &[ChatMessage]) -> usize {
|
||||
let mut total = 0;
|
||||
|
||||
// Base overhead for message formatting (estimated)
|
||||
// Each message has role, delimiters, etc.
|
||||
const MESSAGE_OVERHEAD: usize = 4;
|
||||
|
||||
for msg in messages {
|
||||
// Count role
|
||||
total += MESSAGE_OVERHEAD;
|
||||
|
||||
// Count content
|
||||
if let Some(content) = &msg.content {
|
||||
total += self.count(content);
|
||||
}
|
||||
|
||||
// Count tool calls (more expensive due to JSON structure)
|
||||
if let Some(tool_calls) = &msg.tool_calls {
|
||||
for tc in tool_calls {
|
||||
// ID overhead
|
||||
total += self.count(&tc.id);
|
||||
// Function name
|
||||
total += self.count(&tc.function.name);
|
||||
// Arguments (JSON serialized, add 20% overhead for JSON structure)
|
||||
let args_str = tc.function.arguments.to_string();
|
||||
total += (self.count(&args_str) * 12) / 10;
|
||||
}
|
||||
}
|
||||
|
||||
// Count tool call id for tool result messages
|
||||
if let Some(tool_call_id) = &msg.tool_call_id {
|
||||
total += self.count(tool_call_id);
|
||||
}
|
||||
|
||||
// Count tool name for tool result messages
|
||||
if let Some(name) = &msg.name {
|
||||
total += self.count(name);
|
||||
}
|
||||
}
|
||||
|
||||
total
|
||||
}
|
||||
|
||||
fn max_context(&self) -> usize {
|
||||
self.max_context
|
||||
}
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// ClaudeTokenCounter
|
||||
// ============================================================================
|
||||
|
||||
/// Token counter optimized for Anthropic Claude models
|
||||
///
|
||||
/// Claude models have specific tokenization characteristics and overhead.
|
||||
/// This counter adjusts the estimates accordingly.
|
||||
///
|
||||
/// # Example
|
||||
/// ```
|
||||
/// use llm_core::tokens::{TokenCounter, ClaudeTokenCounter};
|
||||
/// use llm_core::ChatMessage;
|
||||
///
|
||||
/// let counter = ClaudeTokenCounter::new();
|
||||
/// let messages = vec![
|
||||
/// ChatMessage::system("You are a helpful assistant."),
|
||||
/// ChatMessage::user("Hello!"),
|
||||
/// ];
|
||||
/// let total = counter.count_messages(&messages);
|
||||
/// ```
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct ClaudeTokenCounter {
|
||||
max_context: usize,
|
||||
}
|
||||
|
||||
impl ClaudeTokenCounter {
|
||||
/// Create a new Claude token counter with default 200k context
|
||||
///
|
||||
/// This is suitable for Claude 3.5 Sonnet, Claude 4 Sonnet, and Claude 4 Opus.
|
||||
pub fn new() -> Self {
|
||||
Self {
|
||||
max_context: 200_000,
|
||||
}
|
||||
}
|
||||
|
||||
/// Create a Claude counter with a custom context window
|
||||
///
|
||||
/// # Arguments
|
||||
/// * `max_context` - Maximum context window size
|
||||
pub fn with_context(max_context: usize) -> Self {
|
||||
Self { max_context }
|
||||
}
|
||||
|
||||
/// Create a counter for Claude 3 Haiku (200k context)
|
||||
pub fn haiku() -> Self {
|
||||
Self::new()
|
||||
}
|
||||
|
||||
/// Create a counter for Claude 3.5 Sonnet (200k context)
|
||||
pub fn sonnet() -> Self {
|
||||
Self::new()
|
||||
}
|
||||
|
||||
/// Create a counter for Claude 4 Opus (200k context)
|
||||
pub fn opus() -> Self {
|
||||
Self::new()
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for ClaudeTokenCounter {
|
||||
fn default() -> Self {
|
||||
Self::new()
|
||||
}
|
||||
}
|
||||
|
||||
impl TokenCounter for ClaudeTokenCounter {
|
||||
fn count(&self, text: &str) -> usize {
|
||||
// Claude's tokenization is similar to the 4 chars/token heuristic
|
||||
// but tends to be slightly more efficient with structured content
|
||||
(text.len() + 3) / 4
|
||||
}
|
||||
|
||||
fn count_messages(&self, messages: &[ChatMessage]) -> usize {
|
||||
let mut total = 0;
|
||||
|
||||
// Claude has specific message formatting overhead
|
||||
const MESSAGE_OVERHEAD: usize = 5;
|
||||
const SYSTEM_MESSAGE_OVERHEAD: usize = 3;
|
||||
|
||||
for msg in messages {
|
||||
// Different overhead for system vs other messages
|
||||
let overhead = if matches!(msg.role, crate::Role::System) {
|
||||
SYSTEM_MESSAGE_OVERHEAD
|
||||
} else {
|
||||
MESSAGE_OVERHEAD
|
||||
};
|
||||
|
||||
total += overhead;
|
||||
|
||||
// Count content
|
||||
if let Some(content) = &msg.content {
|
||||
total += self.count(content);
|
||||
}
|
||||
|
||||
// Count tool calls
|
||||
if let Some(tool_calls) = &msg.tool_calls {
|
||||
// Claude's tool call format has additional overhead
|
||||
const TOOL_CALL_OVERHEAD: usize = 10;
|
||||
|
||||
for tc in tool_calls {
|
||||
total += TOOL_CALL_OVERHEAD;
|
||||
total += self.count(&tc.id);
|
||||
total += self.count(&tc.function.name);
|
||||
|
||||
// Arguments with JSON structure overhead
|
||||
let args_str = tc.function.arguments.to_string();
|
||||
total += (self.count(&args_str) * 12) / 10;
|
||||
}
|
||||
}
|
||||
|
||||
// Tool result overhead
|
||||
if msg.tool_call_id.is_some() {
|
||||
const TOOL_RESULT_OVERHEAD: usize = 8;
|
||||
total += TOOL_RESULT_OVERHEAD;
|
||||
|
||||
if let Some(tool_call_id) = &msg.tool_call_id {
|
||||
total += self.count(tool_call_id);
|
||||
}
|
||||
|
||||
if let Some(name) = &msg.name {
|
||||
total += self.count(name);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
total
|
||||
}
|
||||
|
||||
fn max_context(&self) -> usize {
|
||||
self.max_context
|
||||
}
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// ContextWindow
|
||||
// ============================================================================
|
||||
|
||||
/// Manages context window tracking for a conversation
|
||||
///
|
||||
/// Helps monitor token usage and determine when context limits are approaching.
|
||||
///
|
||||
/// # Example
|
||||
/// ```
|
||||
/// use llm_core::tokens::{ContextWindow, TokenCounter, SimpleTokenCounter};
|
||||
/// use llm_core::ChatMessage;
|
||||
///
|
||||
/// let counter = SimpleTokenCounter::new(8192);
|
||||
/// let mut window = ContextWindow::new(counter.max_context());
|
||||
///
|
||||
/// let messages = vec![
|
||||
/// ChatMessage::user("Hello!"),
|
||||
/// ChatMessage::assistant("Hi there!"),
|
||||
/// ];
|
||||
///
|
||||
/// let tokens = counter.count_messages(&messages);
|
||||
/// window.add_tokens(tokens);
|
||||
///
|
||||
/// println!("Used: {} tokens", window.used());
|
||||
/// println!("Remaining: {} tokens", window.remaining());
|
||||
/// println!("Usage: {:.1}%", window.usage_percent() * 100.0);
|
||||
///
|
||||
/// if window.is_near_limit(0.8) {
|
||||
/// println!("Warning: Context is 80% full!");
|
||||
/// }
|
||||
/// ```
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct ContextWindow {
|
||||
/// Number of tokens currently used
|
||||
used: usize,
|
||||
/// Maximum number of tokens allowed
|
||||
max: usize,
|
||||
}
|
||||
|
||||
impl ContextWindow {
|
||||
/// Create a new context window tracker
|
||||
///
|
||||
/// # Arguments
|
||||
/// * `max` - Maximum context window size in tokens
|
||||
pub fn new(max: usize) -> Self {
|
||||
Self { used: 0, max }
|
||||
}
|
||||
|
||||
/// Create a context window with initial usage
|
||||
///
|
||||
/// # Arguments
|
||||
/// * `max` - Maximum context window size
|
||||
/// * `used` - Initial number of tokens used
|
||||
pub fn with_usage(max: usize, used: usize) -> Self {
|
||||
Self { used, max }
|
||||
}
|
||||
|
||||
/// Get the number of tokens currently used
|
||||
pub fn used(&self) -> usize {
|
||||
self.used
|
||||
}
|
||||
|
||||
/// Get the maximum number of tokens
|
||||
pub fn max(&self) -> usize {
|
||||
self.max
|
||||
}
|
||||
|
||||
/// Get the number of remaining tokens
|
||||
pub fn remaining(&self) -> usize {
|
||||
self.max.saturating_sub(self.used)
|
||||
}
|
||||
|
||||
/// Get the usage as a percentage (0.0 to 1.0)
|
||||
///
|
||||
/// Returns the fraction of the context window that is currently used.
|
||||
pub fn usage_percent(&self) -> f32 {
|
||||
if self.max == 0 {
|
||||
return 0.0;
|
||||
}
|
||||
self.used as f32 / self.max as f32
|
||||
}
|
||||
|
||||
/// Check if usage is near the limit
|
||||
///
|
||||
/// # Arguments
|
||||
/// * `threshold` - Threshold as a fraction (0.0 to 1.0). For example,
|
||||
/// 0.8 means "is usage > 80%?"
|
||||
///
|
||||
/// # Returns
|
||||
/// `true` if the current usage exceeds the threshold percentage
|
||||
pub fn is_near_limit(&self, threshold: f32) -> bool {
|
||||
self.usage_percent() > threshold
|
||||
}
|
||||
|
||||
/// Add tokens to the usage count
|
||||
///
|
||||
/// # Arguments
|
||||
/// * `tokens` - Number of tokens to add
|
||||
pub fn add_tokens(&mut self, tokens: usize) {
|
||||
self.used = self.used.saturating_add(tokens);
|
||||
}
|
||||
|
||||
/// Set the current usage
|
||||
///
|
||||
/// # Arguments
|
||||
/// * `used` - Number of tokens currently used
|
||||
pub fn set_used(&mut self, used: usize) {
|
||||
self.used = used;
|
||||
}
|
||||
|
||||
/// Reset the usage counter to zero
|
||||
pub fn reset(&mut self) {
|
||||
self.used = 0;
|
||||
}
|
||||
|
||||
/// Check if there's enough room for additional tokens
|
||||
///
|
||||
/// # Arguments
|
||||
/// * `tokens` - Number of tokens needed
|
||||
///
|
||||
/// # Returns
|
||||
/// `true` if adding these tokens would stay within the limit
|
||||
pub fn has_room_for(&self, tokens: usize) -> bool {
|
||||
self.used.saturating_add(tokens) <= self.max
|
||||
}
|
||||
|
||||
/// Get a visual progress bar representation
|
||||
///
|
||||
/// # Arguments
|
||||
/// * `width` - Width of the progress bar in characters
|
||||
///
|
||||
/// # Returns
|
||||
/// A string with a simple text-based progress bar
|
||||
pub fn progress_bar(&self, width: usize) -> String {
|
||||
if width == 0 {
|
||||
return String::new();
|
||||
}
|
||||
|
||||
let percent = self.usage_percent();
|
||||
let filled = ((percent * width as f32) as usize).min(width);
|
||||
let empty = width - filled;
|
||||
|
||||
format!(
|
||||
"[{}{}] {:.1}%",
|
||||
"=".repeat(filled),
|
||||
" ".repeat(empty),
|
||||
percent * 100.0
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// Tests
|
||||
// ============================================================================
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use crate::{ChatMessage, FunctionCall, ToolCall};
|
||||
use serde_json::json;
|
||||
|
||||
#[test]
|
||||
fn test_simple_counter_basic() {
|
||||
let counter = SimpleTokenCounter::new(8192);
|
||||
|
||||
// Empty string
|
||||
assert_eq!(counter.count(""), 0);
|
||||
|
||||
// Short string (~4 chars/token)
|
||||
let text = "Hello, world!"; // 13 chars -> ~4 tokens
|
||||
let count = counter.count(text);
|
||||
assert!(count >= 3 && count <= 5);
|
||||
|
||||
// Longer text
|
||||
let text = "The quick brown fox jumps over the lazy dog"; // 44 chars -> ~11 tokens
|
||||
let count = counter.count(text);
|
||||
assert!(count >= 10 && count <= 13);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_simple_counter_messages() {
|
||||
let counter = SimpleTokenCounter::new(8192);
|
||||
|
||||
let messages = vec![
|
||||
ChatMessage::user("Hello!"),
|
||||
ChatMessage::assistant("Hi there! How can I help you today?"),
|
||||
];
|
||||
|
||||
let total = counter.count_messages(&messages);
|
||||
|
||||
// Should be more than just the text due to overhead
|
||||
let text_only = counter.count("Hello!") + counter.count("Hi there! How can I help you today?");
|
||||
assert!(total > text_only);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_simple_counter_with_tool_calls() {
|
||||
let counter = SimpleTokenCounter::new(8192);
|
||||
|
||||
let tool_call = ToolCall {
|
||||
id: "call_123".to_string(),
|
||||
call_type: "function".to_string(),
|
||||
function: FunctionCall {
|
||||
name: "read_file".to_string(),
|
||||
arguments: json!({"path": "/etc/hosts"}),
|
||||
},
|
||||
};
|
||||
|
||||
let messages = vec![ChatMessage::assistant_tool_calls(vec![tool_call])];
|
||||
|
||||
let total = counter.count_messages(&messages);
|
||||
assert!(total > 0);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_claude_counter() {
|
||||
let counter = ClaudeTokenCounter::new();
|
||||
|
||||
assert_eq!(counter.max_context(), 200_000);
|
||||
|
||||
let text = "Hello, Claude!";
|
||||
let count = counter.count(text);
|
||||
assert!(count > 0);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_claude_counter_system_message() {
|
||||
let counter = ClaudeTokenCounter::new();
|
||||
|
||||
let messages = vec![
|
||||
ChatMessage::system("You are a helpful assistant."),
|
||||
ChatMessage::user("Hello!"),
|
||||
];
|
||||
|
||||
let total = counter.count_messages(&messages);
|
||||
assert!(total > 0);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_context_window() {
|
||||
let mut window = ContextWindow::new(1000);
|
||||
|
||||
assert_eq!(window.used(), 0);
|
||||
assert_eq!(window.max(), 1000);
|
||||
assert_eq!(window.remaining(), 1000);
|
||||
assert_eq!(window.usage_percent(), 0.0);
|
||||
|
||||
window.add_tokens(200);
|
||||
assert_eq!(window.used(), 200);
|
||||
assert_eq!(window.remaining(), 800);
|
||||
assert_eq!(window.usage_percent(), 0.2);
|
||||
|
||||
window.add_tokens(600);
|
||||
assert_eq!(window.used(), 800);
|
||||
assert!(window.is_near_limit(0.7));
|
||||
assert!(!window.is_near_limit(0.9));
|
||||
|
||||
assert!(window.has_room_for(200));
|
||||
assert!(!window.has_room_for(300));
|
||||
|
||||
window.reset();
|
||||
assert_eq!(window.used(), 0);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_context_window_progress_bar() {
|
||||
let mut window = ContextWindow::new(100);
|
||||
|
||||
window.add_tokens(50);
|
||||
let bar = window.progress_bar(10);
|
||||
assert!(bar.contains("====="));
|
||||
assert!(bar.contains("50.0%"));
|
||||
|
||||
window.add_tokens(40);
|
||||
let bar = window.progress_bar(10);
|
||||
assert!(bar.contains("========="));
|
||||
assert!(bar.contains("90.0%"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_context_window_saturation() {
|
||||
let mut window = ContextWindow::new(100);
|
||||
|
||||
// Adding more tokens than max should saturate, not overflow
|
||||
window.add_tokens(150);
|
||||
assert_eq!(window.used(), 150);
|
||||
assert_eq!(window.remaining(), 0);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_simple_counter_constructors() {
|
||||
let counter1 = SimpleTokenCounter::default_8k();
|
||||
assert_eq!(counter1.max_context(), 8192);
|
||||
|
||||
let counter2 = SimpleTokenCounter::with_32k();
|
||||
assert_eq!(counter2.max_context(), 32768);
|
||||
|
||||
let counter3 = SimpleTokenCounter::with_128k();
|
||||
assert_eq!(counter3.max_context(), 131072);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_claude_counter_variants() {
|
||||
let haiku = ClaudeTokenCounter::haiku();
|
||||
assert_eq!(haiku.max_context(), 200_000);
|
||||
|
||||
let sonnet = ClaudeTokenCounter::sonnet();
|
||||
assert_eq!(sonnet.max_context(), 200_000);
|
||||
|
||||
let opus = ClaudeTokenCounter::opus();
|
||||
assert_eq!(opus.max_context(), 200_000);
|
||||
|
||||
let custom = ClaudeTokenCounter::with_context(100_000);
|
||||
assert_eq!(custom.max_context(), 100_000);
|
||||
}
|
||||
}
|
||||
@@ -6,11 +6,13 @@ license.workspace = true
|
||||
rust-version.workspace = true
|
||||
|
||||
[dependencies]
|
||||
llm-core = { path = "../core" }
|
||||
reqwest = { version = "0.12", features = ["json", "stream"] }
|
||||
tokio = { version = "1.39", features = ["rt-multi-thread"] }
|
||||
tokio = { version = "1.39", features = ["rt-multi-thread", "macros"] }
|
||||
futures = "0.3"
|
||||
serde = { version = "1", features = ["derive"] }
|
||||
serde_json = "1"
|
||||
thiserror = "1"
|
||||
bytes = "1"
|
||||
tokio-stream = "0.1.17"
|
||||
async-trait = "0.1"
|
||||
|
||||
@@ -1,14 +1,20 @@
|
||||
use crate::types::{ChatMessage, ChatResponseChunk};
|
||||
use futures::{Stream, TryStreamExt};
|
||||
use crate::types::{ChatMessage, ChatResponseChunk, Tool};
|
||||
use futures::{Stream, StreamExt, TryStreamExt};
|
||||
use reqwest::Client;
|
||||
use serde::Serialize;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use thiserror::Error;
|
||||
use async_trait::async_trait;
|
||||
use llm_core::{
|
||||
LlmProvider, ProviderInfo, LlmError, ChatOptions, ChunkStream,
|
||||
ProviderStatus, AccountInfo, UsageStats, ModelInfo,
|
||||
};
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct OllamaClient {
|
||||
http: Client,
|
||||
base_url: String, // e.g. "http://localhost:11434"
|
||||
api_key: Option<String>, // For Ollama Cloud authentication
|
||||
current_model: String, // Default model for this client
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Default)]
|
||||
@@ -27,12 +33,24 @@ pub enum OllamaError {
|
||||
Protocol(String),
|
||||
}
|
||||
|
||||
// Convert OllamaError to LlmError
|
||||
impl From<OllamaError> for LlmError {
|
||||
fn from(err: OllamaError) -> Self {
|
||||
match err {
|
||||
OllamaError::Http(e) => LlmError::Http(e.to_string()),
|
||||
OllamaError::Json(e) => LlmError::Json(e.to_string()),
|
||||
OllamaError::Protocol(msg) => LlmError::Provider(msg),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl OllamaClient {
|
||||
pub fn new(base_url: impl Into<String>) -> Self {
|
||||
Self {
|
||||
http: Client::new(),
|
||||
base_url: base_url.into().trim_end_matches('/').to_string(),
|
||||
api_key: None,
|
||||
current_model: "qwen3:8b".to_string(),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -41,24 +59,32 @@ impl OllamaClient {
|
||||
self
|
||||
}
|
||||
|
||||
pub fn with_model(mut self, model: impl Into<String>) -> Self {
|
||||
self.current_model = model.into();
|
||||
self
|
||||
}
|
||||
|
||||
pub fn with_cloud() -> Self {
|
||||
// Same API, different base
|
||||
Self::new("https://ollama.com")
|
||||
}
|
||||
|
||||
pub async fn chat_stream(
|
||||
pub async fn chat_stream_raw(
|
||||
&self,
|
||||
messages: &[ChatMessage],
|
||||
opts: &OllamaOptions,
|
||||
tools: Option<&[Tool]>,
|
||||
) -> Result<impl Stream<Item = Result<ChatResponseChunk, OllamaError>>, OllamaError> {
|
||||
#[derive(Serialize)]
|
||||
struct Body<'a> {
|
||||
model: &'a str,
|
||||
messages: &'a [ChatMessage],
|
||||
stream: bool,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
tools: Option<&'a [Tool]>,
|
||||
}
|
||||
let url = format!("{}/api/chat", self.base_url);
|
||||
let body = Body {model: &opts.model, messages, stream: true};
|
||||
let body = Body {model: &opts.model, messages, stream: true, tools};
|
||||
let mut req = self.http.post(url).json(&body);
|
||||
|
||||
// Add Authorization header if API key is present
|
||||
@@ -96,3 +122,208 @@ impl OllamaClient {
|
||||
Ok(out)
|
||||
}
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// LlmProvider Trait Implementation
|
||||
// ============================================================================
|
||||
|
||||
#[async_trait]
|
||||
impl LlmProvider for OllamaClient {
|
||||
fn name(&self) -> &str {
|
||||
"ollama"
|
||||
}
|
||||
|
||||
fn model(&self) -> &str {
|
||||
&self.current_model
|
||||
}
|
||||
|
||||
async fn chat_stream(
|
||||
&self,
|
||||
messages: &[llm_core::ChatMessage],
|
||||
options: &ChatOptions,
|
||||
tools: Option<&[llm_core::Tool]>,
|
||||
) -> Result<ChunkStream, LlmError> {
|
||||
// Convert llm_core messages to Ollama messages
|
||||
let ollama_messages: Vec<ChatMessage> = messages.iter().map(|m| m.into()).collect();
|
||||
|
||||
// Convert llm_core tools to Ollama tools if present
|
||||
let ollama_tools: Option<Vec<Tool>> = tools.map(|tools| {
|
||||
tools.iter().map(|t| Tool {
|
||||
tool_type: t.tool_type.clone(),
|
||||
function: crate::types::ToolFunction {
|
||||
name: t.function.name.clone(),
|
||||
description: t.function.description.clone(),
|
||||
parameters: crate::types::ToolParameters {
|
||||
param_type: t.function.parameters.param_type.clone(),
|
||||
properties: t.function.parameters.properties.clone(),
|
||||
required: t.function.parameters.required.clone(),
|
||||
},
|
||||
},
|
||||
}).collect()
|
||||
});
|
||||
|
||||
let opts = OllamaOptions {
|
||||
model: options.model.clone(),
|
||||
stream: true,
|
||||
};
|
||||
|
||||
// Make the request and build the body inline to avoid lifetime issues
|
||||
#[derive(Serialize)]
|
||||
struct Body<'a> {
|
||||
model: &'a str,
|
||||
messages: &'a [ChatMessage],
|
||||
stream: bool,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
tools: Option<&'a [Tool]>,
|
||||
}
|
||||
|
||||
let url = format!("{}/api/chat", self.base_url);
|
||||
let body = Body {
|
||||
model: &opts.model,
|
||||
messages: &ollama_messages,
|
||||
stream: true,
|
||||
tools: ollama_tools.as_deref(),
|
||||
};
|
||||
|
||||
let mut req = self.http.post(url).json(&body);
|
||||
|
||||
// Add Authorization header if API key is present
|
||||
if let Some(ref key) = self.api_key {
|
||||
req = req.header("Authorization", format!("Bearer {}", key));
|
||||
}
|
||||
|
||||
let resp = req.send().await
|
||||
.map_err(|e| LlmError::Http(e.to_string()))?;
|
||||
let bytes_stream = resp.bytes_stream();
|
||||
|
||||
// NDJSON parser: split by '\n', parse each as JSON and stream the results
|
||||
let converted_stream = bytes_stream
|
||||
.map(|result| {
|
||||
result.map_err(|e| LlmError::Http(e.to_string()))
|
||||
})
|
||||
.map_ok(|bytes| {
|
||||
// Convert the chunk to a UTF-8 string and own it
|
||||
let txt = String::from_utf8_lossy(&bytes).into_owned();
|
||||
// Parse each non-empty line into a ChatResponseChunk
|
||||
let results: Vec<Result<llm_core::StreamChunk, LlmError>> = txt
|
||||
.lines()
|
||||
.filter_map(|line| {
|
||||
let trimmed = line.trim();
|
||||
if trimmed.is_empty() {
|
||||
None
|
||||
} else {
|
||||
Some(
|
||||
serde_json::from_str::<ChatResponseChunk>(trimmed)
|
||||
.map(|chunk| llm_core::StreamChunk::from(chunk))
|
||||
.map_err(|e| LlmError::Json(e.to_string())),
|
||||
)
|
||||
}
|
||||
})
|
||||
.collect();
|
||||
futures::stream::iter(results)
|
||||
})
|
||||
.try_flatten();
|
||||
|
||||
Ok(Box::pin(converted_stream))
|
||||
}
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// ProviderInfo Trait Implementation
|
||||
// ============================================================================
|
||||
|
||||
#[derive(Debug, Clone, Deserialize)]
|
||||
struct OllamaModelList {
|
||||
models: Vec<OllamaModel>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Deserialize)]
|
||||
struct OllamaModel {
|
||||
name: String,
|
||||
#[serde(default)]
|
||||
modified_at: Option<String>,
|
||||
#[serde(default)]
|
||||
size: Option<u64>,
|
||||
#[serde(default)]
|
||||
digest: Option<String>,
|
||||
#[serde(default)]
|
||||
details: Option<OllamaModelDetails>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Deserialize)]
|
||||
struct OllamaModelDetails {
|
||||
#[serde(default)]
|
||||
format: Option<String>,
|
||||
#[serde(default)]
|
||||
family: Option<String>,
|
||||
#[serde(default)]
|
||||
parameter_size: Option<String>,
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl ProviderInfo for OllamaClient {
|
||||
async fn status(&self) -> Result<ProviderStatus, LlmError> {
|
||||
// Try to ping the Ollama server
|
||||
let url = format!("{}/api/tags", self.base_url);
|
||||
let reachable = self.http.get(&url).send().await.is_ok();
|
||||
|
||||
Ok(ProviderStatus {
|
||||
provider: "ollama".to_string(),
|
||||
authenticated: self.api_key.is_some(),
|
||||
account: None, // Ollama is local, no account info
|
||||
model: self.current_model.clone(),
|
||||
endpoint: self.base_url.clone(),
|
||||
reachable,
|
||||
message: if reachable {
|
||||
Some("Connected to Ollama".to_string())
|
||||
} else {
|
||||
Some("Cannot reach Ollama server".to_string())
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
async fn account_info(&self) -> Result<Option<AccountInfo>, LlmError> {
|
||||
// Ollama is a local service, no account info
|
||||
Ok(None)
|
||||
}
|
||||
|
||||
async fn usage_stats(&self) -> Result<Option<UsageStats>, LlmError> {
|
||||
// Ollama doesn't track usage statistics
|
||||
Ok(None)
|
||||
}
|
||||
|
||||
async fn list_models(&self) -> Result<Vec<ModelInfo>, LlmError> {
|
||||
let url = format!("{}/api/tags", self.base_url);
|
||||
let mut req = self.http.get(&url);
|
||||
|
||||
// Add Authorization header if API key is present
|
||||
if let Some(ref key) = self.api_key {
|
||||
req = req.header("Authorization", format!("Bearer {}", key));
|
||||
}
|
||||
|
||||
let resp = req.send().await
|
||||
.map_err(|e| LlmError::Http(e.to_string()))?;
|
||||
|
||||
let model_list: OllamaModelList = resp.json().await
|
||||
.map_err(|e| LlmError::Json(e.to_string()))?;
|
||||
|
||||
// Convert Ollama models to ModelInfo
|
||||
let models = model_list.models.into_iter().map(|m| {
|
||||
ModelInfo {
|
||||
id: m.name.clone(),
|
||||
display_name: Some(m.name.clone()),
|
||||
description: m.details.as_ref()
|
||||
.and_then(|d| d.family.as_ref())
|
||||
.map(|f| format!("{} model", f)),
|
||||
context_window: None, // Ollama doesn't provide this in list
|
||||
max_output_tokens: None,
|
||||
supports_tools: true, // Most Ollama models support tools
|
||||
supports_vision: false, // Would need to check model capabilities
|
||||
input_price_per_mtok: None, // Local models are free
|
||||
output_price_per_mtok: None,
|
||||
}
|
||||
}).collect();
|
||||
|
||||
Ok(models)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,5 +1,13 @@
|
||||
pub mod client;
|
||||
pub mod types;
|
||||
|
||||
pub use client::{OllamaClient, OllamaOptions};
|
||||
pub use types::{ChatMessage, ChatResponseChunk};
|
||||
pub use client::{OllamaClient, OllamaOptions, OllamaError};
|
||||
pub use types::{ChatMessage, ChatResponseChunk, Tool, ToolCall, ToolFunction, ToolParameters, FunctionCall};
|
||||
|
||||
// Re-export llm-core traits and types for convenience
|
||||
pub use llm_core::{
|
||||
LlmProvider, ProviderInfo, LlmError,
|
||||
ChatOptions, StreamChunk, ToolCallDelta, Usage,
|
||||
ProviderStatus, AccountInfo, UsageStats, ModelInfo,
|
||||
Role,
|
||||
};
|
||||
|
||||
@@ -1,9 +1,51 @@
|
||||
use serde::{Deserialize, Serialize};
|
||||
use serde_json::Value;
|
||||
use llm_core::{StreamChunk, ToolCallDelta};
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct ChatMessage {
|
||||
pub role: String, // "user", | "assistant" | "system"
|
||||
pub content: String,
|
||||
pub role: String, // "user" | "assistant" | "system" | "tool"
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub content: Option<String>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub tool_calls: Option<Vec<ToolCall>>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
|
||||
pub struct ToolCall {
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub id: Option<String>,
|
||||
#[serde(rename = "type", skip_serializing_if = "Option::is_none")]
|
||||
pub call_type: Option<String>, // "function"
|
||||
pub function: FunctionCall,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
|
||||
pub struct FunctionCall {
|
||||
pub name: String,
|
||||
pub arguments: Value, // JSON object with arguments
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct Tool {
|
||||
#[serde(rename = "type")]
|
||||
pub tool_type: String, // "function"
|
||||
pub function: ToolFunction,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct ToolFunction {
|
||||
pub name: String,
|
||||
pub description: String,
|
||||
pub parameters: ToolParameters,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct ToolParameters {
|
||||
#[serde(rename = "type")]
|
||||
pub param_type: String, // "object"
|
||||
pub properties: Value,
|
||||
pub required: Vec<String>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
|
||||
@@ -19,4 +61,70 @@ pub struct ChatResponseChunk {
|
||||
pub struct ChunkMessage {
|
||||
pub role: Option<String>,
|
||||
pub content: Option<String>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub tool_calls: Option<Vec<ToolCall>>,
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// Conversions to/from llm-core types
|
||||
// ============================================================================
|
||||
|
||||
/// Convert from llm_core::ChatMessage to Ollama's ChatMessage
|
||||
impl From<&llm_core::ChatMessage> for ChatMessage {
|
||||
fn from(msg: &llm_core::ChatMessage) -> Self {
|
||||
let role = msg.role.as_str().to_string();
|
||||
|
||||
// Convert tool_calls if present
|
||||
let tool_calls = msg.tool_calls.as_ref().map(|calls| {
|
||||
calls.iter().map(|tc| ToolCall {
|
||||
id: Some(tc.id.clone()),
|
||||
call_type: Some(tc.call_type.clone()),
|
||||
function: FunctionCall {
|
||||
name: tc.function.name.clone(),
|
||||
arguments: tc.function.arguments.clone(),
|
||||
},
|
||||
}).collect()
|
||||
});
|
||||
|
||||
ChatMessage {
|
||||
role,
|
||||
content: msg.content.clone(),
|
||||
tool_calls,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Convert from Ollama's ChatResponseChunk to llm_core::StreamChunk
|
||||
impl From<ChatResponseChunk> for StreamChunk {
|
||||
fn from(chunk: ChatResponseChunk) -> Self {
|
||||
let done = chunk.done.unwrap_or(false);
|
||||
let content = chunk.message.as_ref().and_then(|m| m.content.clone());
|
||||
|
||||
// Convert tool calls to deltas
|
||||
let tool_calls = chunk.message.as_ref().and_then(|m| {
|
||||
m.tool_calls.as_ref().map(|calls| {
|
||||
calls.iter().enumerate().map(|(index, tc)| {
|
||||
// Serialize arguments back to JSON string for delta
|
||||
let arguments_delta = serde_json::to_string(&tc.function.arguments).ok();
|
||||
|
||||
ToolCallDelta {
|
||||
index,
|
||||
id: tc.id.clone(),
|
||||
function_name: Some(tc.function.name.clone()),
|
||||
arguments_delta,
|
||||
}
|
||||
}).collect()
|
||||
})
|
||||
});
|
||||
|
||||
// Ollama doesn't provide per-chunk usage stats, only in final chunk
|
||||
let usage = None;
|
||||
|
||||
StreamChunk {
|
||||
content,
|
||||
tool_calls,
|
||||
done,
|
||||
usage,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
18
crates/llm/openai/Cargo.toml
Normal file
18
crates/llm/openai/Cargo.toml
Normal file
@@ -0,0 +1,18 @@
|
||||
[package]
|
||||
name = "llm-openai"
|
||||
version = "0.1.0"
|
||||
edition.workspace = true
|
||||
license.workspace = true
|
||||
description = "OpenAI GPT API client for Owlen"
|
||||
|
||||
[dependencies]
|
||||
llm-core = { path = "../core" }
|
||||
async-trait = "0.1"
|
||||
futures = "0.3"
|
||||
reqwest = { version = "0.12", features = ["json", "stream"] }
|
||||
serde = { version = "1.0", features = ["derive"] }
|
||||
serde_json = "1.0"
|
||||
tokio = { version = "1", features = ["sync", "time", "io-util"] }
|
||||
tokio-stream = { version = "0.1", default-features = false, features = ["io-util"] }
|
||||
tokio-util = { version = "0.7", features = ["codec", "io"] }
|
||||
tracing = "0.1"
|
||||
285
crates/llm/openai/src/auth.rs
Normal file
285
crates/llm/openai/src/auth.rs
Normal file
@@ -0,0 +1,285 @@
|
||||
//! OpenAI OAuth Authentication
|
||||
//!
|
||||
//! Implements device code flow for authenticating with OpenAI without API keys.
|
||||
|
||||
use llm_core::{AuthMethod, DeviceAuthResult, DeviceCodeResponse, LlmError, OAuthProvider};
|
||||
use reqwest::Client;
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
/// OAuth client for OpenAI device flow
|
||||
pub struct OpenAIAuth {
|
||||
http: Client,
|
||||
client_id: String,
|
||||
}
|
||||
|
||||
// OpenAI OAuth endpoints
|
||||
const AUTH_BASE_URL: &str = "https://auth.openai.com";
|
||||
const DEVICE_CODE_ENDPOINT: &str = "/oauth/device/code";
|
||||
const TOKEN_ENDPOINT: &str = "/oauth/token";
|
||||
|
||||
// Default client ID for Owlen CLI
|
||||
const DEFAULT_CLIENT_ID: &str = "owlen-cli";
|
||||
|
||||
impl OpenAIAuth {
|
||||
/// Create a new OAuth client with the default CLI client ID
|
||||
pub fn new() -> Self {
|
||||
Self {
|
||||
http: Client::new(),
|
||||
client_id: DEFAULT_CLIENT_ID.to_string(),
|
||||
}
|
||||
}
|
||||
|
||||
/// Create with a custom client ID
|
||||
pub fn with_client_id(client_id: impl Into<String>) -> Self {
|
||||
Self {
|
||||
http: Client::new(),
|
||||
client_id: client_id.into(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for OpenAIAuth {
|
||||
fn default() -> Self {
|
||||
Self::new()
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize)]
|
||||
struct DeviceCodeRequest<'a> {
|
||||
client_id: &'a str,
|
||||
scope: &'a str,
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
struct DeviceCodeApiResponse {
|
||||
device_code: String,
|
||||
user_code: String,
|
||||
verification_uri: String,
|
||||
verification_uri_complete: Option<String>,
|
||||
expires_in: u64,
|
||||
interval: u64,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize)]
|
||||
struct TokenRequest<'a> {
|
||||
client_id: &'a str,
|
||||
device_code: &'a str,
|
||||
grant_type: &'a str,
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
struct TokenApiResponse {
|
||||
access_token: String,
|
||||
#[allow(dead_code)]
|
||||
token_type: String,
|
||||
expires_in: Option<u64>,
|
||||
refresh_token: Option<String>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
struct TokenErrorResponse {
|
||||
error: String,
|
||||
error_description: Option<String>,
|
||||
}
|
||||
|
||||
#[async_trait::async_trait]
|
||||
impl OAuthProvider for OpenAIAuth {
|
||||
async fn start_device_auth(&self) -> Result<DeviceCodeResponse, LlmError> {
|
||||
let url = format!("{}{}", AUTH_BASE_URL, DEVICE_CODE_ENDPOINT);
|
||||
|
||||
let request = DeviceCodeRequest {
|
||||
client_id: &self.client_id,
|
||||
scope: "api.read api.write",
|
||||
};
|
||||
|
||||
let response = self
|
||||
.http
|
||||
.post(&url)
|
||||
.json(&request)
|
||||
.send()
|
||||
.await
|
||||
.map_err(|e| LlmError::Http(e.to_string()))?;
|
||||
|
||||
if !response.status().is_success() {
|
||||
let status = response.status();
|
||||
let text = response
|
||||
.text()
|
||||
.await
|
||||
.unwrap_or_else(|_| "Unknown error".to_string());
|
||||
return Err(LlmError::Auth(format!(
|
||||
"Device code request failed ({}): {}",
|
||||
status, text
|
||||
)));
|
||||
}
|
||||
|
||||
let api_response: DeviceCodeApiResponse = response
|
||||
.json()
|
||||
.await
|
||||
.map_err(|e| LlmError::Json(e.to_string()))?;
|
||||
|
||||
Ok(DeviceCodeResponse {
|
||||
device_code: api_response.device_code,
|
||||
user_code: api_response.user_code,
|
||||
verification_uri: api_response.verification_uri,
|
||||
verification_uri_complete: api_response.verification_uri_complete,
|
||||
expires_in: api_response.expires_in,
|
||||
interval: api_response.interval,
|
||||
})
|
||||
}
|
||||
|
||||
async fn poll_device_auth(&self, device_code: &str) -> Result<DeviceAuthResult, LlmError> {
|
||||
let url = format!("{}{}", AUTH_BASE_URL, TOKEN_ENDPOINT);
|
||||
|
||||
let request = TokenRequest {
|
||||
client_id: &self.client_id,
|
||||
device_code,
|
||||
grant_type: "urn:ietf:params:oauth:grant-type:device_code",
|
||||
};
|
||||
|
||||
let response = self
|
||||
.http
|
||||
.post(&url)
|
||||
.json(&request)
|
||||
.send()
|
||||
.await
|
||||
.map_err(|e| LlmError::Http(e.to_string()))?;
|
||||
|
||||
if response.status().is_success() {
|
||||
let token_response: TokenApiResponse = response
|
||||
.json()
|
||||
.await
|
||||
.map_err(|e| LlmError::Json(e.to_string()))?;
|
||||
|
||||
return Ok(DeviceAuthResult::Success {
|
||||
access_token: token_response.access_token,
|
||||
refresh_token: token_response.refresh_token,
|
||||
expires_in: token_response.expires_in,
|
||||
});
|
||||
}
|
||||
|
||||
// Parse error response
|
||||
let error_response: TokenErrorResponse = response
|
||||
.json()
|
||||
.await
|
||||
.map_err(|e| LlmError::Json(e.to_string()))?;
|
||||
|
||||
match error_response.error.as_str() {
|
||||
"authorization_pending" => Ok(DeviceAuthResult::Pending),
|
||||
"slow_down" => Ok(DeviceAuthResult::Pending),
|
||||
"access_denied" => Ok(DeviceAuthResult::Denied),
|
||||
"expired_token" => Ok(DeviceAuthResult::Expired),
|
||||
_ => Err(LlmError::Auth(format!(
|
||||
"Token request failed: {} - {}",
|
||||
error_response.error,
|
||||
error_response.error_description.unwrap_or_default()
|
||||
))),
|
||||
}
|
||||
}
|
||||
|
||||
async fn refresh_token(&self, refresh_token: &str) -> Result<AuthMethod, LlmError> {
|
||||
let url = format!("{}{}", AUTH_BASE_URL, TOKEN_ENDPOINT);
|
||||
|
||||
#[derive(Serialize)]
|
||||
struct RefreshRequest<'a> {
|
||||
client_id: &'a str,
|
||||
refresh_token: &'a str,
|
||||
grant_type: &'a str,
|
||||
}
|
||||
|
||||
let request = RefreshRequest {
|
||||
client_id: &self.client_id,
|
||||
refresh_token,
|
||||
grant_type: "refresh_token",
|
||||
};
|
||||
|
||||
let response = self
|
||||
.http
|
||||
.post(&url)
|
||||
.json(&request)
|
||||
.send()
|
||||
.await
|
||||
.map_err(|e| LlmError::Http(e.to_string()))?;
|
||||
|
||||
if !response.status().is_success() {
|
||||
let text = response
|
||||
.text()
|
||||
.await
|
||||
.unwrap_or_else(|_| "Unknown error".to_string());
|
||||
return Err(LlmError::Auth(format!("Token refresh failed: {}", text)));
|
||||
}
|
||||
|
||||
let token_response: TokenApiResponse = response
|
||||
.json()
|
||||
.await
|
||||
.map_err(|e| LlmError::Json(e.to_string()))?;
|
||||
|
||||
let expires_at = token_response.expires_in.map(|secs| {
|
||||
std::time::SystemTime::now()
|
||||
.duration_since(std::time::UNIX_EPOCH)
|
||||
.map(|d| d.as_secs() + secs)
|
||||
.unwrap_or(0)
|
||||
});
|
||||
|
||||
Ok(AuthMethod::OAuth {
|
||||
access_token: token_response.access_token,
|
||||
refresh_token: token_response.refresh_token,
|
||||
expires_at,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
/// Helper to perform the full device auth flow with polling
|
||||
pub async fn perform_device_auth<F>(
|
||||
auth: &OpenAIAuth,
|
||||
on_code: F,
|
||||
) -> Result<AuthMethod, LlmError>
|
||||
where
|
||||
F: FnOnce(&DeviceCodeResponse),
|
||||
{
|
||||
// Start the device flow
|
||||
let device_code = auth.start_device_auth().await?;
|
||||
|
||||
// Let caller display the code to user
|
||||
on_code(&device_code);
|
||||
|
||||
// Poll for completion
|
||||
let poll_interval = std::time::Duration::from_secs(device_code.interval);
|
||||
let deadline =
|
||||
std::time::Instant::now() + std::time::Duration::from_secs(device_code.expires_in);
|
||||
|
||||
loop {
|
||||
if std::time::Instant::now() > deadline {
|
||||
return Err(LlmError::Auth("Device code expired".to_string()));
|
||||
}
|
||||
|
||||
tokio::time::sleep(poll_interval).await;
|
||||
|
||||
match auth.poll_device_auth(&device_code.device_code).await? {
|
||||
DeviceAuthResult::Success {
|
||||
access_token,
|
||||
refresh_token,
|
||||
expires_in,
|
||||
} => {
|
||||
let expires_at = expires_in.map(|secs| {
|
||||
std::time::SystemTime::now()
|
||||
.duration_since(std::time::UNIX_EPOCH)
|
||||
.map(|d| d.as_secs() + secs)
|
||||
.unwrap_or(0)
|
||||
});
|
||||
|
||||
return Ok(AuthMethod::OAuth {
|
||||
access_token,
|
||||
refresh_token,
|
||||
expires_at,
|
||||
});
|
||||
}
|
||||
DeviceAuthResult::Pending => continue,
|
||||
DeviceAuthResult::Denied => {
|
||||
return Err(LlmError::Auth("Authorization denied by user".to_string()));
|
||||
}
|
||||
DeviceAuthResult::Expired => {
|
||||
return Err(LlmError::Auth("Device code expired".to_string()));
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
561
crates/llm/openai/src/client.rs
Normal file
561
crates/llm/openai/src/client.rs
Normal file
@@ -0,0 +1,561 @@
|
||||
//! OpenAI GPT API Client
|
||||
//!
|
||||
//! Implements the Chat Completions API with streaming support.
|
||||
|
||||
use crate::types::*;
|
||||
use async_trait::async_trait;
|
||||
use futures::StreamExt;
|
||||
use llm_core::{
|
||||
AccountInfo, AuthMethod, ChatMessage, ChatOptions, ChatResponse, ChunkStream, FunctionCall,
|
||||
LlmError, LlmProvider, ModelInfo, ProviderInfo, ProviderStatus, StreamChunk, Tool, ToolCall,
|
||||
ToolCallDelta, Usage, UsageStats,
|
||||
};
|
||||
use reqwest::Client;
|
||||
use tokio::io::AsyncBufReadExt;
|
||||
use tokio_stream::wrappers::LinesStream;
|
||||
use tokio_util::io::StreamReader;
|
||||
|
||||
const API_BASE_URL: &str = "https://api.openai.com/v1";
|
||||
const CHAT_ENDPOINT: &str = "/chat/completions";
|
||||
const MODELS_ENDPOINT: &str = "/models";
|
||||
|
||||
/// OpenAI GPT API client
|
||||
pub struct OpenAIClient {
|
||||
http: Client,
|
||||
auth: AuthMethod,
|
||||
model: String,
|
||||
}
|
||||
|
||||
impl OpenAIClient {
|
||||
/// Create a new client with API key authentication
|
||||
pub fn new(api_key: impl Into<String>) -> Self {
|
||||
Self {
|
||||
http: Client::new(),
|
||||
auth: AuthMethod::api_key(api_key),
|
||||
model: "gpt-4o".to_string(),
|
||||
}
|
||||
}
|
||||
|
||||
/// Create a new client with OAuth token
|
||||
pub fn with_oauth(access_token: impl Into<String>) -> Self {
|
||||
Self {
|
||||
http: Client::new(),
|
||||
auth: AuthMethod::oauth(access_token),
|
||||
model: "gpt-4o".to_string(),
|
||||
}
|
||||
}
|
||||
|
||||
/// Create a new client with full AuthMethod
|
||||
pub fn with_auth(auth: AuthMethod) -> Self {
|
||||
Self {
|
||||
http: Client::new(),
|
||||
auth,
|
||||
model: "gpt-4o".to_string(),
|
||||
}
|
||||
}
|
||||
|
||||
/// Set the model to use
|
||||
pub fn with_model(mut self, model: impl Into<String>) -> Self {
|
||||
self.model = model.into();
|
||||
self
|
||||
}
|
||||
|
||||
/// Get current auth method (for token refresh)
|
||||
pub fn auth(&self) -> &AuthMethod {
|
||||
&self.auth
|
||||
}
|
||||
|
||||
/// Update the auth method (after refresh)
|
||||
pub fn set_auth(&mut self, auth: AuthMethod) {
|
||||
self.auth = auth;
|
||||
}
|
||||
|
||||
/// Convert messages to OpenAI format
|
||||
fn prepare_messages(messages: &[ChatMessage]) -> Vec<OpenAIMessage> {
|
||||
messages.iter().map(OpenAIMessage::from).collect()
|
||||
}
|
||||
|
||||
/// Convert tools to OpenAI format
|
||||
fn prepare_tools(tools: Option<&[Tool]>) -> Option<Vec<OpenAITool>> {
|
||||
tools.map(|t| t.iter().map(OpenAITool::from).collect())
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl LlmProvider for OpenAIClient {
|
||||
fn name(&self) -> &str {
|
||||
"openai"
|
||||
}
|
||||
|
||||
fn model(&self) -> &str {
|
||||
&self.model
|
||||
}
|
||||
|
||||
async fn chat_stream(
|
||||
&self,
|
||||
messages: &[ChatMessage],
|
||||
options: &ChatOptions,
|
||||
tools: Option<&[Tool]>,
|
||||
) -> Result<ChunkStream, LlmError> {
|
||||
let url = format!("{}{}", API_BASE_URL, CHAT_ENDPOINT);
|
||||
|
||||
let model = if options.model.is_empty() {
|
||||
&self.model
|
||||
} else {
|
||||
&options.model
|
||||
};
|
||||
|
||||
let openai_messages = Self::prepare_messages(messages);
|
||||
let openai_tools = Self::prepare_tools(tools);
|
||||
|
||||
let request = ChatCompletionRequest {
|
||||
model,
|
||||
messages: openai_messages,
|
||||
temperature: options.temperature,
|
||||
max_tokens: options.max_tokens,
|
||||
top_p: options.top_p,
|
||||
stop: options.stop.as_deref(),
|
||||
tools: openai_tools,
|
||||
tool_choice: None,
|
||||
stream: true,
|
||||
};
|
||||
|
||||
let bearer = self
|
||||
.auth
|
||||
.bearer_token()
|
||||
.ok_or_else(|| LlmError::Auth("No authentication configured".to_string()))?;
|
||||
|
||||
let response = self
|
||||
.http
|
||||
.post(&url)
|
||||
.header("Authorization", format!("Bearer {}", bearer))
|
||||
.header("Content-Type", "application/json")
|
||||
.json(&request)
|
||||
.send()
|
||||
.await
|
||||
.map_err(|e| LlmError::Http(e.to_string()))?;
|
||||
|
||||
if !response.status().is_success() {
|
||||
let status = response.status();
|
||||
let text = response
|
||||
.text()
|
||||
.await
|
||||
.unwrap_or_else(|_| "Unknown error".to_string());
|
||||
|
||||
if status == reqwest::StatusCode::TOO_MANY_REQUESTS {
|
||||
return Err(LlmError::RateLimit {
|
||||
retry_after_secs: None,
|
||||
});
|
||||
}
|
||||
|
||||
// Try to parse as error response
|
||||
if let Ok(err_resp) = serde_json::from_str::<ErrorResponse>(&text) {
|
||||
return Err(LlmError::Api {
|
||||
message: err_resp.error.message,
|
||||
code: err_resp.error.code,
|
||||
});
|
||||
}
|
||||
|
||||
return Err(LlmError::Api {
|
||||
message: text,
|
||||
code: Some(status.to_string()),
|
||||
});
|
||||
}
|
||||
|
||||
// Parse SSE stream
|
||||
let byte_stream = response
|
||||
.bytes_stream()
|
||||
.map(|result| result.map_err(std::io::Error::other));
|
||||
|
||||
let reader = StreamReader::new(byte_stream);
|
||||
let buf_reader = tokio::io::BufReader::new(reader);
|
||||
let lines_stream = LinesStream::new(buf_reader.lines());
|
||||
|
||||
let chunk_stream = lines_stream.filter_map(|line_result| async move {
|
||||
match line_result {
|
||||
Ok(line) => parse_sse_line(&line),
|
||||
Err(e) => Some(Err(LlmError::Stream(e.to_string()))),
|
||||
}
|
||||
});
|
||||
|
||||
Ok(Box::pin(chunk_stream))
|
||||
}
|
||||
|
||||
async fn chat(
|
||||
&self,
|
||||
messages: &[ChatMessage],
|
||||
options: &ChatOptions,
|
||||
tools: Option<&[Tool]>,
|
||||
) -> Result<ChatResponse, LlmError> {
|
||||
let url = format!("{}{}", API_BASE_URL, CHAT_ENDPOINT);
|
||||
|
||||
let model = if options.model.is_empty() {
|
||||
&self.model
|
||||
} else {
|
||||
&options.model
|
||||
};
|
||||
|
||||
let openai_messages = Self::prepare_messages(messages);
|
||||
let openai_tools = Self::prepare_tools(tools);
|
||||
|
||||
let request = ChatCompletionRequest {
|
||||
model,
|
||||
messages: openai_messages,
|
||||
temperature: options.temperature,
|
||||
max_tokens: options.max_tokens,
|
||||
top_p: options.top_p,
|
||||
stop: options.stop.as_deref(),
|
||||
tools: openai_tools,
|
||||
tool_choice: None,
|
||||
stream: false,
|
||||
};
|
||||
|
||||
let bearer = self
|
||||
.auth
|
||||
.bearer_token()
|
||||
.ok_or_else(|| LlmError::Auth("No authentication configured".to_string()))?;
|
||||
|
||||
let response = self
|
||||
.http
|
||||
.post(&url)
|
||||
.header("Authorization", format!("Bearer {}", bearer))
|
||||
.json(&request)
|
||||
.send()
|
||||
.await
|
||||
.map_err(|e| LlmError::Http(e.to_string()))?;
|
||||
|
||||
if !response.status().is_success() {
|
||||
let status = response.status();
|
||||
let text = response
|
||||
.text()
|
||||
.await
|
||||
.unwrap_or_else(|_| "Unknown error".to_string());
|
||||
|
||||
if status == reqwest::StatusCode::TOO_MANY_REQUESTS {
|
||||
return Err(LlmError::RateLimit {
|
||||
retry_after_secs: None,
|
||||
});
|
||||
}
|
||||
|
||||
if let Ok(err_resp) = serde_json::from_str::<ErrorResponse>(&text) {
|
||||
return Err(LlmError::Api {
|
||||
message: err_resp.error.message,
|
||||
code: err_resp.error.code,
|
||||
});
|
||||
}
|
||||
|
||||
return Err(LlmError::Api {
|
||||
message: text,
|
||||
code: Some(status.to_string()),
|
||||
});
|
||||
}
|
||||
|
||||
let api_response: ChatCompletionResponse = response
|
||||
.json()
|
||||
.await
|
||||
.map_err(|e| LlmError::Json(e.to_string()))?;
|
||||
|
||||
// Extract the first choice
|
||||
let choice = api_response
|
||||
.choices
|
||||
.first()
|
||||
.ok_or_else(|| LlmError::Api {
|
||||
message: "No choices in response".to_string(),
|
||||
code: None,
|
||||
})?;
|
||||
|
||||
let content = choice.message.content.clone();
|
||||
|
||||
let tool_calls = choice.message.tool_calls.as_ref().map(|calls| {
|
||||
calls
|
||||
.iter()
|
||||
.map(|call| {
|
||||
let arguments: serde_json::Value =
|
||||
serde_json::from_str(&call.function.arguments).unwrap_or_default();
|
||||
|
||||
ToolCall {
|
||||
id: call.id.clone(),
|
||||
call_type: "function".to_string(),
|
||||
function: FunctionCall {
|
||||
name: call.function.name.clone(),
|
||||
arguments,
|
||||
},
|
||||
}
|
||||
})
|
||||
.collect()
|
||||
});
|
||||
|
||||
let usage = api_response.usage.map(|u| Usage {
|
||||
prompt_tokens: u.prompt_tokens,
|
||||
completion_tokens: u.completion_tokens,
|
||||
total_tokens: u.total_tokens,
|
||||
});
|
||||
|
||||
Ok(ChatResponse {
|
||||
content,
|
||||
tool_calls,
|
||||
usage,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
/// Parse a single SSE line into a StreamChunk
|
||||
fn parse_sse_line(line: &str) -> Option<Result<StreamChunk, LlmError>> {
|
||||
let line = line.trim();
|
||||
|
||||
// Skip empty lines and comments
|
||||
if line.is_empty() || line.starts_with(':') {
|
||||
return None;
|
||||
}
|
||||
|
||||
// SSE format: "data: <json>"
|
||||
if let Some(data) = line.strip_prefix("data: ") {
|
||||
// OpenAI sends [DONE] to signal end
|
||||
if data == "[DONE]" {
|
||||
return Some(Ok(StreamChunk {
|
||||
content: None,
|
||||
tool_calls: None,
|
||||
done: true,
|
||||
usage: None,
|
||||
}));
|
||||
}
|
||||
|
||||
// Parse the JSON chunk
|
||||
match serde_json::from_str::<ChatCompletionChunk>(data) {
|
||||
Ok(chunk) => Some(convert_chunk_to_stream_chunk(chunk)),
|
||||
Err(e) => {
|
||||
tracing::warn!("Failed to parse SSE chunk: {}", e);
|
||||
None
|
||||
}
|
||||
}
|
||||
} else {
|
||||
None
|
||||
}
|
||||
}
|
||||
|
||||
/// Convert OpenAI chunk to our common format
|
||||
fn convert_chunk_to_stream_chunk(chunk: ChatCompletionChunk) -> Result<StreamChunk, LlmError> {
|
||||
let choice = chunk.choices.first();
|
||||
|
||||
if let Some(choice) = choice {
|
||||
let content = choice.delta.content.clone();
|
||||
|
||||
let tool_calls = choice.delta.tool_calls.as_ref().map(|deltas| {
|
||||
deltas
|
||||
.iter()
|
||||
.map(|delta| ToolCallDelta {
|
||||
index: delta.index,
|
||||
id: delta.id.clone(),
|
||||
function_name: delta.function.as_ref().and_then(|f| f.name.clone()),
|
||||
arguments_delta: delta.function.as_ref().and_then(|f| f.arguments.clone()),
|
||||
})
|
||||
.collect()
|
||||
});
|
||||
|
||||
let done = choice.finish_reason.is_some();
|
||||
|
||||
Ok(StreamChunk {
|
||||
content,
|
||||
tool_calls,
|
||||
done,
|
||||
usage: None,
|
||||
})
|
||||
} else {
|
||||
// No choices, treat as done
|
||||
Ok(StreamChunk {
|
||||
content: None,
|
||||
tool_calls: None,
|
||||
done: true,
|
||||
usage: None,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// ProviderInfo Implementation
|
||||
// ============================================================================
|
||||
|
||||
/// Known GPT models with their specifications
|
||||
fn get_gpt_models() -> Vec<ModelInfo> {
|
||||
vec![
|
||||
ModelInfo {
|
||||
id: "gpt-4o".to_string(),
|
||||
display_name: Some("GPT-4o".to_string()),
|
||||
description: Some("Most advanced multimodal model with vision".to_string()),
|
||||
context_window: Some(128_000),
|
||||
max_output_tokens: Some(16_384),
|
||||
supports_tools: true,
|
||||
supports_vision: true,
|
||||
input_price_per_mtok: Some(2.50),
|
||||
output_price_per_mtok: Some(10.0),
|
||||
},
|
||||
ModelInfo {
|
||||
id: "gpt-4o-mini".to_string(),
|
||||
display_name: Some("GPT-4o mini".to_string()),
|
||||
description: Some("Affordable and fast model for simple tasks".to_string()),
|
||||
context_window: Some(128_000),
|
||||
max_output_tokens: Some(16_384),
|
||||
supports_tools: true,
|
||||
supports_vision: true,
|
||||
input_price_per_mtok: Some(0.15),
|
||||
output_price_per_mtok: Some(0.60),
|
||||
},
|
||||
ModelInfo {
|
||||
id: "gpt-4-turbo".to_string(),
|
||||
display_name: Some("GPT-4 Turbo".to_string()),
|
||||
description: Some("Previous generation high-performance model".to_string()),
|
||||
context_window: Some(128_000),
|
||||
max_output_tokens: Some(4_096),
|
||||
supports_tools: true,
|
||||
supports_vision: true,
|
||||
input_price_per_mtok: Some(10.0),
|
||||
output_price_per_mtok: Some(30.0),
|
||||
},
|
||||
ModelInfo {
|
||||
id: "gpt-3.5-turbo".to_string(),
|
||||
display_name: Some("GPT-3.5 Turbo".to_string()),
|
||||
description: Some("Fast and affordable for simple tasks".to_string()),
|
||||
context_window: Some(16_385),
|
||||
max_output_tokens: Some(4_096),
|
||||
supports_tools: true,
|
||||
supports_vision: false,
|
||||
input_price_per_mtok: Some(0.50),
|
||||
output_price_per_mtok: Some(1.50),
|
||||
},
|
||||
ModelInfo {
|
||||
id: "o1".to_string(),
|
||||
display_name: Some("OpenAI o1".to_string()),
|
||||
description: Some("Reasoning model optimized for complex problems".to_string()),
|
||||
context_window: Some(200_000),
|
||||
max_output_tokens: Some(100_000),
|
||||
supports_tools: false,
|
||||
supports_vision: true,
|
||||
input_price_per_mtok: Some(15.0),
|
||||
output_price_per_mtok: Some(60.0),
|
||||
},
|
||||
ModelInfo {
|
||||
id: "o1-mini".to_string(),
|
||||
display_name: Some("OpenAI o1-mini".to_string()),
|
||||
description: Some("Faster reasoning model for STEM".to_string()),
|
||||
context_window: Some(128_000),
|
||||
max_output_tokens: Some(65_536),
|
||||
supports_tools: false,
|
||||
supports_vision: true,
|
||||
input_price_per_mtok: Some(3.0),
|
||||
output_price_per_mtok: Some(12.0),
|
||||
},
|
||||
]
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl ProviderInfo for OpenAIClient {
|
||||
async fn status(&self) -> Result<ProviderStatus, LlmError> {
|
||||
let authenticated = self.auth.bearer_token().is_some();
|
||||
|
||||
// Try to reach the API by listing models
|
||||
let reachable = if authenticated {
|
||||
let url = format!("{}{}", API_BASE_URL, MODELS_ENDPOINT);
|
||||
let bearer = self.auth.bearer_token().unwrap();
|
||||
|
||||
match self
|
||||
.http
|
||||
.get(&url)
|
||||
.header("Authorization", format!("Bearer {}", bearer))
|
||||
.send()
|
||||
.await
|
||||
{
|
||||
Ok(resp) => resp.status().is_success(),
|
||||
Err(_) => false,
|
||||
}
|
||||
} else {
|
||||
false
|
||||
};
|
||||
|
||||
let message = if !authenticated {
|
||||
Some("Not authenticated - set OPENAI_API_KEY or run 'owlen login openai'".to_string())
|
||||
} else if !reachable {
|
||||
Some("Cannot reach OpenAI API".to_string())
|
||||
} else {
|
||||
Some("Connected".to_string())
|
||||
};
|
||||
|
||||
Ok(ProviderStatus {
|
||||
provider: "openai".to_string(),
|
||||
authenticated,
|
||||
account: None, // OpenAI doesn't expose account info via API
|
||||
model: self.model.clone(),
|
||||
endpoint: API_BASE_URL.to_string(),
|
||||
reachable,
|
||||
message,
|
||||
})
|
||||
}
|
||||
|
||||
async fn account_info(&self) -> Result<Option<AccountInfo>, LlmError> {
|
||||
// OpenAI doesn't have a public account info endpoint
|
||||
Ok(None)
|
||||
}
|
||||
|
||||
async fn usage_stats(&self) -> Result<Option<UsageStats>, LlmError> {
|
||||
// OpenAI doesn't expose usage stats via the standard API
|
||||
Ok(None)
|
||||
}
|
||||
|
||||
async fn list_models(&self) -> Result<Vec<ModelInfo>, LlmError> {
|
||||
// We can optionally fetch from API, but return known models for now
|
||||
Ok(get_gpt_models())
|
||||
}
|
||||
|
||||
async fn model_info(&self, model_id: &str) -> Result<Option<ModelInfo>, LlmError> {
|
||||
let models = get_gpt_models();
|
||||
Ok(models.into_iter().find(|m| m.id == model_id))
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use llm_core::ToolParameters;
|
||||
use serde_json::json;
|
||||
|
||||
#[test]
|
||||
fn test_message_conversion() {
|
||||
let messages = vec![
|
||||
ChatMessage::system("You are helpful"),
|
||||
ChatMessage::user("Hello"),
|
||||
ChatMessage::assistant("Hi there!"),
|
||||
];
|
||||
|
||||
let openai_msgs = OpenAIClient::prepare_messages(&messages);
|
||||
|
||||
assert_eq!(openai_msgs.len(), 3);
|
||||
assert_eq!(openai_msgs[0].role, "system");
|
||||
assert_eq!(openai_msgs[1].role, "user");
|
||||
assert_eq!(openai_msgs[2].role, "assistant");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_tool_conversion() {
|
||||
let tools = vec![Tool::function(
|
||||
"read_file",
|
||||
"Read a file's contents",
|
||||
ToolParameters::object(
|
||||
json!({
|
||||
"path": {
|
||||
"type": "string",
|
||||
"description": "File path"
|
||||
}
|
||||
}),
|
||||
vec!["path".to_string()],
|
||||
),
|
||||
)];
|
||||
|
||||
let openai_tools = OpenAIClient::prepare_tools(Some(&tools)).unwrap();
|
||||
|
||||
assert_eq!(openai_tools.len(), 1);
|
||||
assert_eq!(openai_tools[0].function.name, "read_file");
|
||||
assert_eq!(
|
||||
openai_tools[0].function.description,
|
||||
"Read a file's contents"
|
||||
);
|
||||
}
|
||||
}
|
||||
12
crates/llm/openai/src/lib.rs
Normal file
12
crates/llm/openai/src/lib.rs
Normal file
@@ -0,0 +1,12 @@
|
||||
//! OpenAI GPT API Client
|
||||
//!
|
||||
//! Implements the LlmProvider trait for OpenAI's GPT models.
|
||||
//! Supports both API key authentication and OAuth device flow.
|
||||
|
||||
mod auth;
|
||||
mod client;
|
||||
mod types;
|
||||
|
||||
pub use auth::*;
|
||||
pub use client::*;
|
||||
pub use types::*;
|
||||
285
crates/llm/openai/src/types.rs
Normal file
285
crates/llm/openai/src/types.rs
Normal file
@@ -0,0 +1,285 @@
|
||||
//! OpenAI API request/response types
|
||||
|
||||
use serde::{Deserialize, Serialize};
|
||||
use serde_json::Value;
|
||||
|
||||
// ============================================================================
|
||||
// Request Types
|
||||
// ============================================================================
|
||||
|
||||
#[derive(Debug, Serialize)]
|
||||
pub struct ChatCompletionRequest<'a> {
|
||||
pub model: &'a str,
|
||||
pub messages: Vec<OpenAIMessage>,
|
||||
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub temperature: Option<f32>,
|
||||
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub max_tokens: Option<u32>,
|
||||
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub top_p: Option<f32>,
|
||||
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub stop: Option<&'a [String]>,
|
||||
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub tools: Option<Vec<OpenAITool>>,
|
||||
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub tool_choice: Option<&'a str>,
|
||||
|
||||
pub stream: bool,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct OpenAIMessage {
|
||||
pub role: String, // "system", "user", "assistant", "tool"
|
||||
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub content: Option<String>,
|
||||
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub tool_calls: Option<Vec<OpenAIToolCall>>,
|
||||
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub tool_call_id: Option<String>,
|
||||
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub name: Option<String>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct OpenAIToolCall {
|
||||
pub id: String,
|
||||
#[serde(rename = "type")]
|
||||
pub call_type: String,
|
||||
pub function: OpenAIFunctionCall,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct OpenAIFunctionCall {
|
||||
pub name: String,
|
||||
pub arguments: String, // JSON string
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct OpenAITool {
|
||||
#[serde(rename = "type")]
|
||||
pub tool_type: String,
|
||||
pub function: OpenAIFunction,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct OpenAIFunction {
|
||||
pub name: String,
|
||||
pub description: String,
|
||||
pub parameters: FunctionParameters,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct FunctionParameters {
|
||||
#[serde(rename = "type")]
|
||||
pub param_type: String,
|
||||
pub properties: Value,
|
||||
pub required: Vec<String>,
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// Response Types
|
||||
// ============================================================================
|
||||
|
||||
#[derive(Debug, Clone, Deserialize)]
|
||||
pub struct ChatCompletionResponse {
|
||||
pub id: String,
|
||||
pub object: String,
|
||||
pub created: u64,
|
||||
pub model: String,
|
||||
pub choices: Vec<Choice>,
|
||||
pub usage: Option<UsageInfo>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Deserialize)]
|
||||
pub struct Choice {
|
||||
pub index: u32,
|
||||
pub message: OpenAIMessage,
|
||||
pub finish_reason: Option<String>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Deserialize)]
|
||||
pub struct UsageInfo {
|
||||
pub prompt_tokens: u32,
|
||||
pub completion_tokens: u32,
|
||||
pub total_tokens: u32,
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// Streaming Response Types
|
||||
// ============================================================================
|
||||
|
||||
#[derive(Debug, Clone, Deserialize)]
|
||||
pub struct ChatCompletionChunk {
|
||||
pub id: String,
|
||||
pub object: String,
|
||||
pub created: u64,
|
||||
pub model: String,
|
||||
pub choices: Vec<ChunkChoice>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Deserialize)]
|
||||
pub struct ChunkChoice {
|
||||
pub index: u32,
|
||||
pub delta: Delta,
|
||||
pub finish_reason: Option<String>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Deserialize)]
|
||||
pub struct Delta {
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub role: Option<String>,
|
||||
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub content: Option<String>,
|
||||
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub tool_calls: Option<Vec<DeltaToolCall>>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Deserialize)]
|
||||
pub struct DeltaToolCall {
|
||||
pub index: usize,
|
||||
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub id: Option<String>,
|
||||
|
||||
#[serde(skip_serializing_if = "Option::is_none", rename = "type")]
|
||||
pub call_type: Option<String>,
|
||||
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub function: Option<DeltaFunction>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Deserialize)]
|
||||
pub struct DeltaFunction {
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub name: Option<String>,
|
||||
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub arguments: Option<String>,
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// Error Response Types
|
||||
// ============================================================================
|
||||
|
||||
#[derive(Debug, Clone, Deserialize)]
|
||||
pub struct ErrorResponse {
|
||||
pub error: ApiError,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Deserialize)]
|
||||
pub struct ApiError {
|
||||
pub message: String,
|
||||
#[serde(rename = "type")]
|
||||
pub error_type: String,
|
||||
pub code: Option<String>,
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// Models List Response
|
||||
// ============================================================================
|
||||
|
||||
#[derive(Debug, Clone, Deserialize)]
|
||||
pub struct ModelsResponse {
|
||||
pub object: String,
|
||||
pub data: Vec<ModelData>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Deserialize)]
|
||||
pub struct ModelData {
|
||||
pub id: String,
|
||||
pub object: String,
|
||||
pub created: u64,
|
||||
pub owned_by: String,
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// Conversions
|
||||
// ============================================================================
|
||||
|
||||
impl From<&llm_core::Tool> for OpenAITool {
|
||||
fn from(tool: &llm_core::Tool) -> Self {
|
||||
Self {
|
||||
tool_type: "function".to_string(),
|
||||
function: OpenAIFunction {
|
||||
name: tool.function.name.clone(),
|
||||
description: tool.function.description.clone(),
|
||||
parameters: FunctionParameters {
|
||||
param_type: tool.function.parameters.param_type.clone(),
|
||||
properties: tool.function.parameters.properties.clone(),
|
||||
required: tool.function.parameters.required.clone(),
|
||||
},
|
||||
},
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl From<&llm_core::ChatMessage> for OpenAIMessage {
|
||||
fn from(msg: &llm_core::ChatMessage) -> Self {
|
||||
use llm_core::Role;
|
||||
|
||||
let role = match msg.role {
|
||||
Role::System => "system",
|
||||
Role::User => "user",
|
||||
Role::Assistant => "assistant",
|
||||
Role::Tool => "tool",
|
||||
};
|
||||
|
||||
// Handle tool result messages
|
||||
if msg.role == Role::Tool {
|
||||
return Self {
|
||||
role: "tool".to_string(),
|
||||
content: msg.content.clone(),
|
||||
tool_calls: None,
|
||||
tool_call_id: msg.tool_call_id.clone(),
|
||||
name: msg.name.clone(),
|
||||
};
|
||||
}
|
||||
|
||||
// Handle assistant messages with tool calls
|
||||
if msg.role == Role::Assistant && msg.tool_calls.is_some() {
|
||||
let tool_calls = msg.tool_calls.as_ref().map(|calls| {
|
||||
calls
|
||||
.iter()
|
||||
.map(|call| OpenAIToolCall {
|
||||
id: call.id.clone(),
|
||||
call_type: "function".to_string(),
|
||||
function: OpenAIFunctionCall {
|
||||
name: call.function.name.clone(),
|
||||
arguments: serde_json::to_string(&call.function.arguments)
|
||||
.unwrap_or_else(|_| "{}".to_string()),
|
||||
},
|
||||
})
|
||||
.collect()
|
||||
});
|
||||
|
||||
return Self {
|
||||
role: "assistant".to_string(),
|
||||
content: msg.content.clone(),
|
||||
tool_calls,
|
||||
tool_call_id: None,
|
||||
name: None,
|
||||
};
|
||||
}
|
||||
|
||||
// Simple text message
|
||||
Self {
|
||||
role: role.to_string(),
|
||||
content: msg.content.clone(),
|
||||
tool_calls: None,
|
||||
tool_call_id: None,
|
||||
name: None,
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -10,6 +10,7 @@ serde = { version = "1", features = ["derive"] }
|
||||
directories = "5"
|
||||
figment = { version = "0.10", features = ["toml", "env"] }
|
||||
permissions = { path = "../permissions" }
|
||||
llm-core = { path = "../../llm/core" }
|
||||
|
||||
[dev-dependencies]
|
||||
tempfile = "3.23.0"
|
||||
|
||||
@@ -5,26 +5,65 @@ use figment::{
|
||||
};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::path::PathBuf;
|
||||
use std::env;
|
||||
use permissions::{Mode, PermissionManager};
|
||||
use llm_core::ProviderType;
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct Settings {
|
||||
#[serde(default = "default_ollama_url")]
|
||||
pub ollama_url: String,
|
||||
// Provider configuration
|
||||
#[serde(default = "default_provider")]
|
||||
pub provider: String, // "ollama" | "anthropic" | "openai"
|
||||
|
||||
#[serde(default = "default_model")]
|
||||
pub model: String,
|
||||
#[serde(default = "default_mode")]
|
||||
pub mode: String, // "plan" (read-only) for now
|
||||
|
||||
// Ollama-specific
|
||||
#[serde(default = "default_ollama_url")]
|
||||
pub ollama_url: String,
|
||||
|
||||
// API keys for different providers
|
||||
#[serde(default)]
|
||||
pub api_key: Option<String>, // For Ollama Cloud or other API authentication
|
||||
pub api_key: Option<String>, // For Ollama Cloud or backwards compatibility
|
||||
|
||||
#[serde(default)]
|
||||
pub anthropic_api_key: Option<String>,
|
||||
|
||||
#[serde(default)]
|
||||
pub openai_api_key: Option<String>,
|
||||
|
||||
// Permission mode
|
||||
#[serde(default = "default_mode")]
|
||||
pub mode: String, // "plan" | "acceptEdits" | "code"
|
||||
|
||||
// Tool permission lists
|
||||
/// Tools that are always allowed without prompting
|
||||
/// Format: "tool_name" or "tool_name:pattern"
|
||||
/// Example: ["bash:npm test:*", "bash:cargo test:*", "mcp:filesystem__*"]
|
||||
#[serde(default)]
|
||||
pub allowed_tools: Vec<String>,
|
||||
|
||||
/// Tools that are always denied (blocked)
|
||||
/// Format: "tool_name" or "tool_name:pattern"
|
||||
/// Example: ["bash:rm -rf*", "bash:sudo*"]
|
||||
#[serde(default)]
|
||||
pub disallowed_tools: Vec<String>,
|
||||
}
|
||||
|
||||
fn default_provider() -> String {
|
||||
"ollama".into()
|
||||
}
|
||||
|
||||
fn default_ollama_url() -> String {
|
||||
"http://localhost:11434".into()
|
||||
}
|
||||
|
||||
fn default_model() -> String {
|
||||
// Default model depends on provider, but we use ollama's default here
|
||||
// Users can override this per-provider or use get_effective_model()
|
||||
"qwen3:8b".into()
|
||||
}
|
||||
|
||||
fn default_mode() -> String {
|
||||
"plan".into()
|
||||
}
|
||||
@@ -32,25 +71,71 @@ fn default_mode() -> String {
|
||||
impl Default for Settings {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
ollama_url: default_ollama_url(),
|
||||
provider: default_provider(),
|
||||
model: default_model(),
|
||||
mode: default_mode(),
|
||||
ollama_url: default_ollama_url(),
|
||||
api_key: None,
|
||||
anthropic_api_key: None,
|
||||
openai_api_key: None,
|
||||
mode: default_mode(),
|
||||
allowed_tools: Vec::new(),
|
||||
disallowed_tools: Vec::new(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Settings {
|
||||
/// Create a PermissionManager based on the configured mode
|
||||
/// Create a PermissionManager based on the configured mode and tool lists
|
||||
///
|
||||
/// Tool lists are applied in order:
|
||||
/// 1. Disallowed tools (highest priority - blocked first)
|
||||
/// 2. Allowed tools
|
||||
/// 3. Mode-based defaults
|
||||
pub fn create_permission_manager(&self) -> PermissionManager {
|
||||
let mode = Mode::from_str(&self.mode).unwrap_or(Mode::Plan);
|
||||
PermissionManager::new(mode)
|
||||
let mut pm = PermissionManager::new(mode);
|
||||
|
||||
// Add disallowed tools first (deny rules take precedence)
|
||||
pm.add_disallowed_tools(&self.disallowed_tools);
|
||||
|
||||
// Then add allowed tools
|
||||
pm.add_allowed_tools(&self.allowed_tools);
|
||||
|
||||
pm
|
||||
}
|
||||
|
||||
/// Get the Mode enum from the mode string
|
||||
pub fn get_mode(&self) -> Mode {
|
||||
Mode::from_str(&self.mode).unwrap_or(Mode::Plan)
|
||||
}
|
||||
|
||||
/// Get the ProviderType enum from the provider string
|
||||
pub fn get_provider(&self) -> Option<ProviderType> {
|
||||
ProviderType::from_str(&self.provider)
|
||||
}
|
||||
|
||||
/// Get the effective model for the current provider
|
||||
/// If no model is explicitly set, returns the provider's default
|
||||
pub fn get_effective_model(&self) -> String {
|
||||
// If model is explicitly set and not the default, use it
|
||||
if self.model != default_model() {
|
||||
return self.model.clone();
|
||||
}
|
||||
|
||||
// Otherwise, use provider-specific default
|
||||
self.get_provider()
|
||||
.map(|p| p.default_model().to_string())
|
||||
.unwrap_or_else(|| self.model.clone())
|
||||
}
|
||||
|
||||
/// Get the API key for the current provider
|
||||
pub fn get_provider_api_key(&self) -> Option<String> {
|
||||
match self.get_provider()? {
|
||||
ProviderType::Ollama => self.api_key.clone(),
|
||||
ProviderType::Anthropic => self.anthropic_api_key.clone(),
|
||||
ProviderType::OpenAI => self.openai_api_key.clone(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub fn load_settings(project_root: Option<&str>) -> Result<Settings, figment::Error> {
|
||||
@@ -68,9 +153,31 @@ pub fn load_settings(project_root: Option<&str>) -> Result<Settings, figment::Er
|
||||
}
|
||||
|
||||
// Environment variables have highest precedence
|
||||
// OWLEN_* prefix (e.g., OWLEN_PROVIDER, OWLEN_MODEL, OWLEN_API_KEY, OWLEN_ANTHROPIC_API_KEY)
|
||||
fig = fig.merge(Env::prefixed("OWLEN_").split("__"));
|
||||
// Support OLLAMA_API_KEY, OLLAMA_MODEL, etc. (without nesting)
|
||||
|
||||
// Support OLLAMA_* prefix for backwards compatibility
|
||||
fig = fig.merge(Env::prefixed("OLLAMA_"));
|
||||
|
||||
fig.extract()
|
||||
// Support PROVIDER env var (without OWLEN_ prefix)
|
||||
fig = fig.merge(Env::raw().only(&["PROVIDER"]));
|
||||
|
||||
// Extract the settings
|
||||
let mut settings: Settings = fig.extract()?;
|
||||
|
||||
// Manually handle standard provider API key env vars (ANTHROPIC_API_KEY, OPENAI_API_KEY)
|
||||
// These override config files but are overridden by OWLEN_* vars
|
||||
if settings.anthropic_api_key.is_none() {
|
||||
if let Ok(key) = env::var("ANTHROPIC_API_KEY") {
|
||||
settings.anthropic_api_key = Some(key);
|
||||
}
|
||||
}
|
||||
|
||||
if settings.openai_api_key.is_none() {
|
||||
if let Ok(key) = env::var("OPENAI_API_KEY") {
|
||||
settings.openai_api_key = Some(key);
|
||||
}
|
||||
}
|
||||
|
||||
Ok(settings)
|
||||
}
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
use config_agent::{load_settings, Settings};
|
||||
use permissions::{Mode, PermissionDecision, Tool};
|
||||
use llm_core::ProviderType;
|
||||
use std::{env, fs};
|
||||
|
||||
#[test]
|
||||
@@ -45,4 +46,189 @@ fn settings_parse_mode_from_config() {
|
||||
// Code mode should allow everything
|
||||
assert_eq!(mgr.check(Tool::Write, None), PermissionDecision::Allow);
|
||||
assert_eq!(mgr.check(Tool::Bash, None), PermissionDecision::Allow);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn default_provider_is_ollama() {
|
||||
let s = Settings::default();
|
||||
assert_eq!(s.provider, "ollama");
|
||||
assert_eq!(s.get_provider(), Some(ProviderType::Ollama));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn provider_from_config_file() {
|
||||
let tmp = tempfile::tempdir().unwrap();
|
||||
let project_file = tmp.path().join(".owlen.toml");
|
||||
fs::write(&project_file, r#"provider="anthropic""#).unwrap();
|
||||
|
||||
let s = load_settings(Some(tmp.path().to_str().unwrap())).unwrap();
|
||||
assert_eq!(s.provider, "anthropic");
|
||||
assert_eq!(s.get_provider(), Some(ProviderType::Anthropic));
|
||||
}
|
||||
|
||||
#[test]
|
||||
#[ignore] // Ignore due to env var interaction in parallel tests
|
||||
fn provider_from_env_var() {
|
||||
let tmp = tempfile::tempdir().unwrap();
|
||||
|
||||
unsafe {
|
||||
env::set_var("OWLEN_PROVIDER", "openai");
|
||||
env::remove_var("PROVIDER");
|
||||
env::remove_var("ANTHROPIC_API_KEY");
|
||||
env::remove_var("OPENAI_API_KEY");
|
||||
}
|
||||
let s = load_settings(Some(tmp.path().to_str().unwrap())).unwrap();
|
||||
assert_eq!(s.provider, "openai");
|
||||
assert_eq!(s.get_provider(), Some(ProviderType::OpenAI));
|
||||
unsafe { env::remove_var("OWLEN_PROVIDER"); }
|
||||
}
|
||||
|
||||
#[test]
|
||||
#[ignore] // Ignore due to env var interaction in parallel tests
|
||||
fn provider_from_provider_env_var() {
|
||||
let tmp = tempfile::tempdir().unwrap();
|
||||
|
||||
unsafe {
|
||||
env::set_var("PROVIDER", "anthropic");
|
||||
env::remove_var("OWLEN_PROVIDER");
|
||||
env::remove_var("ANTHROPIC_API_KEY");
|
||||
env::remove_var("OPENAI_API_KEY");
|
||||
}
|
||||
let s = load_settings(Some(tmp.path().to_str().unwrap())).unwrap();
|
||||
assert_eq!(s.provider, "anthropic");
|
||||
assert_eq!(s.get_provider(), Some(ProviderType::Anthropic));
|
||||
unsafe { env::remove_var("PROVIDER"); }
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn anthropic_api_key_from_owlen_env() {
|
||||
let tmp = tempfile::tempdir().unwrap();
|
||||
let project_file = tmp.path().join(".owlen.toml");
|
||||
fs::write(&project_file, r#"provider="anthropic""#).unwrap();
|
||||
|
||||
unsafe { env::set_var("OWLEN_ANTHROPIC_API_KEY", "sk-ant-test123"); }
|
||||
let s = load_settings(Some(tmp.path().to_str().unwrap())).unwrap();
|
||||
assert_eq!(s.anthropic_api_key, Some("sk-ant-test123".to_string()));
|
||||
assert_eq!(s.get_provider_api_key(), Some("sk-ant-test123".to_string()));
|
||||
unsafe { env::remove_var("OWLEN_ANTHROPIC_API_KEY"); }
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn openai_api_key_from_owlen_env() {
|
||||
let tmp = tempfile::tempdir().unwrap();
|
||||
let project_file = tmp.path().join(".owlen.toml");
|
||||
fs::write(&project_file, r#"provider="openai""#).unwrap();
|
||||
|
||||
unsafe { env::set_var("OWLEN_OPENAI_API_KEY", "sk-test-456"); }
|
||||
let s = load_settings(Some(tmp.path().to_str().unwrap())).unwrap();
|
||||
assert_eq!(s.openai_api_key, Some("sk-test-456".to_string()));
|
||||
assert_eq!(s.get_provider_api_key(), Some("sk-test-456".to_string()));
|
||||
unsafe { env::remove_var("OWLEN_OPENAI_API_KEY"); }
|
||||
}
|
||||
|
||||
#[test]
|
||||
#[ignore] // Ignore due to env var interaction in parallel tests
|
||||
fn api_keys_from_config_file() {
|
||||
let tmp = tempfile::tempdir().unwrap();
|
||||
let project_file = tmp.path().join(".owlen.toml");
|
||||
fs::write(&project_file, r#"
|
||||
provider = "anthropic"
|
||||
anthropic_api_key = "sk-ant-from-file"
|
||||
openai_api_key = "sk-openai-from-file"
|
||||
"#).unwrap();
|
||||
|
||||
// Clear any env vars that might interfere
|
||||
unsafe {
|
||||
env::remove_var("ANTHROPIC_API_KEY");
|
||||
env::remove_var("OPENAI_API_KEY");
|
||||
env::remove_var("OWLEN_ANTHROPIC_API_KEY");
|
||||
env::remove_var("OWLEN_OPENAI_API_KEY");
|
||||
}
|
||||
|
||||
let s = load_settings(Some(tmp.path().to_str().unwrap())).unwrap();
|
||||
assert_eq!(s.anthropic_api_key, Some("sk-ant-from-file".to_string()));
|
||||
assert_eq!(s.openai_api_key, Some("sk-openai-from-file".to_string()));
|
||||
assert_eq!(s.get_provider_api_key(), Some("sk-ant-from-file".to_string()));
|
||||
}
|
||||
|
||||
#[test]
|
||||
#[ignore] // Ignore due to env var interaction in parallel tests
|
||||
fn anthropic_api_key_from_standard_env() {
|
||||
let tmp = tempfile::tempdir().unwrap();
|
||||
let project_file = tmp.path().join(".owlen.toml");
|
||||
fs::write(&project_file, r#"provider="anthropic""#).unwrap();
|
||||
|
||||
unsafe {
|
||||
env::set_var("ANTHROPIC_API_KEY", "sk-ant-std");
|
||||
env::remove_var("OWLEN_ANTHROPIC_API_KEY");
|
||||
env::remove_var("PROVIDER");
|
||||
env::remove_var("OWLEN_PROVIDER");
|
||||
}
|
||||
let s = load_settings(Some(tmp.path().to_str().unwrap())).unwrap();
|
||||
assert_eq!(s.anthropic_api_key, Some("sk-ant-std".to_string()));
|
||||
assert_eq!(s.get_provider_api_key(), Some("sk-ant-std".to_string()));
|
||||
unsafe { env::remove_var("ANTHROPIC_API_KEY"); }
|
||||
}
|
||||
|
||||
#[test]
|
||||
#[ignore] // Ignore due to env var interaction in parallel tests
|
||||
fn openai_api_key_from_standard_env() {
|
||||
let tmp = tempfile::tempdir().unwrap();
|
||||
let project_file = tmp.path().join(".owlen.toml");
|
||||
fs::write(&project_file, r#"provider="openai""#).unwrap();
|
||||
|
||||
unsafe {
|
||||
env::set_var("OPENAI_API_KEY", "sk-openai-std");
|
||||
env::remove_var("OWLEN_OPENAI_API_KEY");
|
||||
env::remove_var("PROVIDER");
|
||||
env::remove_var("OWLEN_PROVIDER");
|
||||
}
|
||||
let s = load_settings(Some(tmp.path().to_str().unwrap())).unwrap();
|
||||
assert_eq!(s.openai_api_key, Some("sk-openai-std".to_string()));
|
||||
assert_eq!(s.get_provider_api_key(), Some("sk-openai-std".to_string()));
|
||||
unsafe { env::remove_var("OPENAI_API_KEY"); }
|
||||
}
|
||||
|
||||
#[test]
|
||||
#[ignore] // Ignore due to env var interaction in parallel tests
|
||||
fn owlen_prefix_overrides_standard_env() {
|
||||
let tmp = tempfile::tempdir().unwrap();
|
||||
|
||||
unsafe {
|
||||
env::set_var("ANTHROPIC_API_KEY", "sk-ant-std");
|
||||
env::set_var("OWLEN_ANTHROPIC_API_KEY", "sk-ant-owlen");
|
||||
}
|
||||
let s = load_settings(Some(tmp.path().to_str().unwrap())).unwrap();
|
||||
// OWLEN_ prefix should take precedence
|
||||
assert_eq!(s.anthropic_api_key, Some("sk-ant-owlen".to_string()));
|
||||
unsafe {
|
||||
env::remove_var("ANTHROPIC_API_KEY");
|
||||
env::remove_var("OWLEN_ANTHROPIC_API_KEY");
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn effective_model_uses_provider_default() {
|
||||
// Test Anthropic provider default
|
||||
let mut s = Settings::default();
|
||||
s.provider = "anthropic".to_string();
|
||||
assert_eq!(s.get_effective_model(), "claude-sonnet-4-20250514");
|
||||
|
||||
// Test OpenAI provider default
|
||||
s.provider = "openai".to_string();
|
||||
assert_eq!(s.get_effective_model(), "gpt-4o");
|
||||
|
||||
// Test Ollama provider default
|
||||
s.provider = "ollama".to_string();
|
||||
assert_eq!(s.get_effective_model(), "qwen3:8b");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn effective_model_respects_explicit_model() {
|
||||
let mut s = Settings::default();
|
||||
s.provider = "anthropic".to_string();
|
||||
s.model = "claude-opus-4-20250514".to_string();
|
||||
|
||||
// Should use explicit model, not provider default
|
||||
assert_eq!(s.get_effective_model(), "claude-opus-4-20250514");
|
||||
}
|
||||
@@ -10,6 +10,7 @@ serde = { version = "1", features = ["derive"] }
|
||||
serde_json = "1"
|
||||
tokio = { version = "1.39", features = ["process", "time", "io-util"] }
|
||||
color-eyre = "0.6"
|
||||
regex = "1.10"
|
||||
|
||||
[dev-dependencies]
|
||||
tempfile = "3.23.0"
|
||||
|
||||
@@ -34,6 +34,34 @@ pub enum HookEvent {
|
||||
prompt: String,
|
||||
},
|
||||
PreCompact,
|
||||
/// Called before the agent stops - allows validation of completion
|
||||
#[serde(rename_all = "camelCase")]
|
||||
Stop {
|
||||
/// Reason for stopping (e.g., "task_complete", "max_iterations", "user_interrupt")
|
||||
reason: String,
|
||||
/// Number of messages in conversation
|
||||
num_messages: usize,
|
||||
/// Number of tool calls made
|
||||
num_tool_calls: usize,
|
||||
},
|
||||
/// Called before a subagent stops
|
||||
#[serde(rename_all = "camelCase")]
|
||||
SubagentStop {
|
||||
/// Unique identifier for the subagent
|
||||
agent_id: String,
|
||||
/// Type of subagent (e.g., "explore", "code-reviewer")
|
||||
agent_type: String,
|
||||
/// Reason for stopping
|
||||
reason: String,
|
||||
},
|
||||
/// Called when a notification is sent to the user
|
||||
#[serde(rename_all = "camelCase")]
|
||||
Notification {
|
||||
/// Notification message
|
||||
message: String,
|
||||
/// Notification type (e.g., "info", "warning", "error")
|
||||
notification_type: String,
|
||||
},
|
||||
}
|
||||
|
||||
impl HookEvent {
|
||||
@@ -46,27 +74,137 @@ impl HookEvent {
|
||||
HookEvent::SessionEnd { .. } => "SessionEnd",
|
||||
HookEvent::UserPromptSubmit { .. } => "UserPromptSubmit",
|
||||
HookEvent::PreCompact => "PreCompact",
|
||||
HookEvent::Stop { .. } => "Stop",
|
||||
HookEvent::SubagentStop { .. } => "SubagentStop",
|
||||
HookEvent::Notification { .. } => "Notification",
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Simple hook result for backwards compatibility
|
||||
#[derive(Debug, Clone, PartialEq, Eq)]
|
||||
pub enum HookResult {
|
||||
Allow,
|
||||
Deny,
|
||||
}
|
||||
|
||||
/// Extended hook output with additional control options
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
pub struct HookOutput {
|
||||
/// Whether to continue execution (default: true if exit code 0)
|
||||
#[serde(default = "default_continue")]
|
||||
pub continue_execution: bool,
|
||||
/// Whether to suppress showing the result to the user
|
||||
#[serde(default)]
|
||||
pub suppress_output: bool,
|
||||
/// System message to inject into the conversation
|
||||
#[serde(default)]
|
||||
pub system_message: Option<String>,
|
||||
/// Permission decision override
|
||||
#[serde(default)]
|
||||
pub permission_decision: Option<HookPermission>,
|
||||
/// Modified input/args for the tool (PreToolUse only)
|
||||
#[serde(default)]
|
||||
pub updated_input: Option<Value>,
|
||||
}
|
||||
|
||||
impl Default for HookOutput {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
continue_execution: true,
|
||||
suppress_output: false,
|
||||
system_message: None,
|
||||
permission_decision: None,
|
||||
updated_input: None,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn default_continue() -> bool {
|
||||
true
|
||||
}
|
||||
|
||||
/// Permission decision from a hook
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
|
||||
#[serde(rename_all = "lowercase")]
|
||||
pub enum HookPermission {
|
||||
Allow,
|
||||
Deny,
|
||||
Ask,
|
||||
}
|
||||
|
||||
impl HookOutput {
|
||||
pub fn new() -> Self {
|
||||
Self::default()
|
||||
}
|
||||
|
||||
pub fn allow() -> Self {
|
||||
Self {
|
||||
continue_execution: true,
|
||||
..Default::default()
|
||||
}
|
||||
}
|
||||
|
||||
pub fn deny() -> Self {
|
||||
Self {
|
||||
continue_execution: false,
|
||||
..Default::default()
|
||||
}
|
||||
}
|
||||
|
||||
pub fn with_system_message(mut self, message: impl Into<String>) -> Self {
|
||||
self.system_message = Some(message.into());
|
||||
self
|
||||
}
|
||||
|
||||
pub fn with_permission(mut self, permission: HookPermission) -> Self {
|
||||
self.permission_decision = Some(permission);
|
||||
self
|
||||
}
|
||||
|
||||
/// Convert to simple HookResult for backwards compatibility
|
||||
pub fn to_result(&self) -> HookResult {
|
||||
if self.continue_execution {
|
||||
HookResult::Allow
|
||||
} else {
|
||||
HookResult::Deny
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// A registered hook that can be executed
|
||||
#[derive(Debug, Clone)]
|
||||
struct Hook {
|
||||
event: String, // Event name like "PreToolUse", "PostToolUse", etc.
|
||||
command: String,
|
||||
pattern: Option<String>, // Optional regex pattern for matching tool names
|
||||
timeout: Option<u64>,
|
||||
}
|
||||
|
||||
pub struct HookManager {
|
||||
project_root: PathBuf,
|
||||
hooks: Vec<Hook>,
|
||||
}
|
||||
|
||||
impl HookManager {
|
||||
pub fn new(project_root: &str) -> Self {
|
||||
Self {
|
||||
project_root: PathBuf::from(project_root),
|
||||
hooks: Vec::new(),
|
||||
}
|
||||
}
|
||||
|
||||
/// Register a single hook
|
||||
pub fn register_hook(&mut self, event: String, command: String, pattern: Option<String>, timeout: Option<u64>) {
|
||||
self.hooks.push(Hook {
|
||||
event,
|
||||
command,
|
||||
pattern,
|
||||
timeout,
|
||||
});
|
||||
}
|
||||
|
||||
/// Execute a hook for the given event
|
||||
///
|
||||
/// Returns:
|
||||
@@ -74,18 +212,66 @@ impl HookManager {
|
||||
/// - Ok(HookResult::Deny) if hook denies (exit code 2)
|
||||
/// - Err if hook fails (other exit codes) or times out
|
||||
pub async fn execute(&self, event: &HookEvent, timeout_ms: Option<u64>) -> Result<HookResult> {
|
||||
// First check for legacy file-based hooks
|
||||
let hook_path = self.get_hook_path(event);
|
||||
let has_file_hook = hook_path.exists();
|
||||
|
||||
// If hook doesn't exist, allow by default
|
||||
if !hook_path.exists() {
|
||||
// Get registered hooks for this event
|
||||
let event_name = event.hook_name();
|
||||
let mut matching_hooks: Vec<&Hook> = self.hooks.iter()
|
||||
.filter(|h| h.event == event_name)
|
||||
.collect();
|
||||
|
||||
// If we need to filter by pattern (for PreToolUse events)
|
||||
if let HookEvent::PreToolUse { tool, .. } = event {
|
||||
matching_hooks.retain(|h| {
|
||||
if let Some(pattern) = &h.pattern {
|
||||
// Use regex to match tool name
|
||||
if let Ok(re) = regex::Regex::new(pattern) {
|
||||
re.is_match(tool)
|
||||
} else {
|
||||
false
|
||||
}
|
||||
} else {
|
||||
true // No pattern means match all
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
// If no hooks at all, allow by default
|
||||
if !has_file_hook && matching_hooks.is_empty() {
|
||||
return Ok(HookResult::Allow);
|
||||
}
|
||||
|
||||
// Execute file-based hook first (if exists)
|
||||
if has_file_hook {
|
||||
let result = self.execute_hook_command(&hook_path.to_string_lossy(), event, timeout_ms).await?;
|
||||
if result == HookResult::Deny {
|
||||
return Ok(HookResult::Deny);
|
||||
}
|
||||
}
|
||||
|
||||
// Execute registered hooks
|
||||
for hook in matching_hooks {
|
||||
let hook_timeout = hook.timeout.or(timeout_ms);
|
||||
let result = self.execute_hook_command(&hook.command, event, hook_timeout).await?;
|
||||
if result == HookResult::Deny {
|
||||
return Ok(HookResult::Deny);
|
||||
}
|
||||
}
|
||||
|
||||
Ok(HookResult::Allow)
|
||||
}
|
||||
|
||||
/// Execute a single hook command
|
||||
async fn execute_hook_command(&self, command: &str, event: &HookEvent, timeout_ms: Option<u64>) -> Result<HookResult> {
|
||||
// Serialize event to JSON
|
||||
let input_json = serde_json::to_string(event)?;
|
||||
|
||||
// Spawn the hook process
|
||||
let mut child = Command::new(&hook_path)
|
||||
let mut child = Command::new("sh")
|
||||
.arg("-c")
|
||||
.arg(command)
|
||||
.stdin(Stdio::piped())
|
||||
.stdout(Stdio::piped())
|
||||
.stderr(Stdio::piped())
|
||||
@@ -126,6 +312,131 @@ impl HookManager {
|
||||
}
|
||||
}
|
||||
|
||||
/// Execute a hook and return extended output
|
||||
///
|
||||
/// This method parses JSON output from stdout if the hook provides it,
|
||||
/// otherwise falls back to exit code interpretation.
|
||||
pub async fn execute_extended(&self, event: &HookEvent, timeout_ms: Option<u64>) -> Result<HookOutput> {
|
||||
// First check for legacy file-based hooks
|
||||
let hook_path = self.get_hook_path(event);
|
||||
let has_file_hook = hook_path.exists();
|
||||
|
||||
// Get registered hooks for this event
|
||||
let event_name = event.hook_name();
|
||||
let mut matching_hooks: Vec<&Hook> = self.hooks.iter()
|
||||
.filter(|h| h.event == event_name)
|
||||
.collect();
|
||||
|
||||
// If we need to filter by pattern (for PreToolUse events)
|
||||
if let HookEvent::PreToolUse { tool, .. } = event {
|
||||
matching_hooks.retain(|h| {
|
||||
if let Some(pattern) = &h.pattern {
|
||||
if let Ok(re) = regex::Regex::new(pattern) {
|
||||
re.is_match(tool)
|
||||
} else {
|
||||
false
|
||||
}
|
||||
} else {
|
||||
true
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
// If no hooks at all, allow by default
|
||||
if !has_file_hook && matching_hooks.is_empty() {
|
||||
return Ok(HookOutput::allow());
|
||||
}
|
||||
|
||||
let mut combined_output = HookOutput::allow();
|
||||
|
||||
// Execute file-based hook first (if exists)
|
||||
if has_file_hook {
|
||||
let output = self.execute_hook_extended(&hook_path.to_string_lossy(), event, timeout_ms).await?;
|
||||
combined_output = Self::merge_outputs(combined_output, output);
|
||||
if !combined_output.continue_execution {
|
||||
return Ok(combined_output);
|
||||
}
|
||||
}
|
||||
|
||||
// Execute registered hooks
|
||||
for hook in matching_hooks {
|
||||
let hook_timeout = hook.timeout.or(timeout_ms);
|
||||
let output = self.execute_hook_extended(&hook.command, event, hook_timeout).await?;
|
||||
combined_output = Self::merge_outputs(combined_output, output);
|
||||
if !combined_output.continue_execution {
|
||||
return Ok(combined_output);
|
||||
}
|
||||
}
|
||||
|
||||
Ok(combined_output)
|
||||
}
|
||||
|
||||
/// Execute a single hook command and return extended output
|
||||
async fn execute_hook_extended(&self, command: &str, event: &HookEvent, timeout_ms: Option<u64>) -> Result<HookOutput> {
|
||||
let input_json = serde_json::to_string(event)?;
|
||||
|
||||
let mut child = Command::new("sh")
|
||||
.arg("-c")
|
||||
.arg(command)
|
||||
.stdin(Stdio::piped())
|
||||
.stdout(Stdio::piped())
|
||||
.stderr(Stdio::piped())
|
||||
.current_dir(&self.project_root)
|
||||
.spawn()?;
|
||||
|
||||
if let Some(mut stdin) = child.stdin.take() {
|
||||
stdin.write_all(input_json.as_bytes()).await?;
|
||||
stdin.flush().await?;
|
||||
drop(stdin);
|
||||
}
|
||||
|
||||
let result = if let Some(ms) = timeout_ms {
|
||||
timeout(Duration::from_millis(ms), child.wait_with_output()).await
|
||||
} else {
|
||||
Ok(child.wait_with_output().await)
|
||||
};
|
||||
|
||||
match result {
|
||||
Ok(Ok(output)) => {
|
||||
let exit_code = output.status.code();
|
||||
let stdout = String::from_utf8_lossy(&output.stdout);
|
||||
|
||||
// Try to parse JSON output from stdout
|
||||
if !stdout.trim().is_empty() {
|
||||
if let Ok(hook_output) = serde_json::from_str::<HookOutput>(stdout.trim()) {
|
||||
return Ok(hook_output);
|
||||
}
|
||||
}
|
||||
|
||||
// Fall back to exit code interpretation
|
||||
match exit_code {
|
||||
Some(0) => Ok(HookOutput::allow()),
|
||||
Some(2) => Ok(HookOutput::deny()),
|
||||
Some(code) => Err(eyre!(
|
||||
"Hook {} failed with exit code {}: {}",
|
||||
event.hook_name(),
|
||||
code,
|
||||
String::from_utf8_lossy(&output.stderr)
|
||||
)),
|
||||
None => Err(eyre!("Hook {} terminated by signal", event.hook_name())),
|
||||
}
|
||||
}
|
||||
Ok(Err(e)) => Err(eyre!("Failed to execute hook {}: {}", event.hook_name(), e)),
|
||||
Err(_) => Err(eyre!("Hook {} timed out", event.hook_name())),
|
||||
}
|
||||
}
|
||||
|
||||
/// Merge two hook outputs, with the second taking precedence
|
||||
fn merge_outputs(base: HookOutput, new: HookOutput) -> HookOutput {
|
||||
HookOutput {
|
||||
continue_execution: base.continue_execution && new.continue_execution,
|
||||
suppress_output: base.suppress_output || new.suppress_output,
|
||||
system_message: new.system_message.or(base.system_message),
|
||||
permission_decision: new.permission_decision.or(base.permission_decision),
|
||||
updated_input: new.updated_input.or(base.updated_input),
|
||||
}
|
||||
}
|
||||
|
||||
fn get_hook_path(&self, event: &HookEvent) -> PathBuf {
|
||||
self.project_root
|
||||
.join(".owlen")
|
||||
@@ -167,5 +478,76 @@ mod tests {
|
||||
.hook_name(),
|
||||
"SessionStart"
|
||||
);
|
||||
assert_eq!(
|
||||
HookEvent::Stop {
|
||||
reason: "task_complete".to_string(),
|
||||
num_messages: 10,
|
||||
num_tool_calls: 5,
|
||||
}
|
||||
.hook_name(),
|
||||
"Stop"
|
||||
);
|
||||
assert_eq!(
|
||||
HookEvent::SubagentStop {
|
||||
agent_id: "abc123".to_string(),
|
||||
agent_type: "explore".to_string(),
|
||||
reason: "completed".to_string(),
|
||||
}
|
||||
.hook_name(),
|
||||
"SubagentStop"
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn stop_event_serializes_correctly() {
|
||||
let event = HookEvent::Stop {
|
||||
reason: "task_complete".to_string(),
|
||||
num_messages: 10,
|
||||
num_tool_calls: 5,
|
||||
};
|
||||
|
||||
let json = serde_json::to_string(&event).unwrap();
|
||||
assert!(json.contains("\"event\":\"stop\""));
|
||||
assert!(json.contains("\"reason\":\"task_complete\""));
|
||||
assert!(json.contains("\"numMessages\":10"));
|
||||
assert!(json.contains("\"numToolCalls\":5"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn hook_output_defaults() {
|
||||
let output = HookOutput::default();
|
||||
assert!(output.continue_execution);
|
||||
assert!(!output.suppress_output);
|
||||
assert!(output.system_message.is_none());
|
||||
assert!(output.permission_decision.is_none());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn hook_output_builders() {
|
||||
let output = HookOutput::allow()
|
||||
.with_system_message("Test message")
|
||||
.with_permission(HookPermission::Allow);
|
||||
|
||||
assert!(output.continue_execution);
|
||||
assert_eq!(output.system_message, Some("Test message".to_string()));
|
||||
assert_eq!(output.permission_decision, Some(HookPermission::Allow));
|
||||
|
||||
let deny = HookOutput::deny();
|
||||
assert!(!deny.continue_execution);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn hook_output_deserializes() {
|
||||
let json = r#"{"continueExecution": true, "suppressOutput": false, "systemMessage": "Hello"}"#;
|
||||
let output: HookOutput = serde_json::from_str(json).unwrap();
|
||||
assert!(output.continue_execution);
|
||||
assert!(!output.suppress_output);
|
||||
assert_eq!(output.system_message, Some("Hello".to_string()));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn hook_output_to_result() {
|
||||
assert_eq!(HookOutput::allow().to_result(), HookResult::Allow);
|
||||
assert_eq!(HookOutput::deny().to_result(), HookResult::Deny);
|
||||
}
|
||||
}
|
||||
|
||||
154
crates/platform/hooks/tests/plugin_hooks.rs
Normal file
154
crates/platform/hooks/tests/plugin_hooks.rs
Normal file
@@ -0,0 +1,154 @@
|
||||
// Integration test for plugin hooks with HookManager
|
||||
use color_eyre::eyre::Result;
|
||||
use hooks::{HookEvent, HookManager, HookResult};
|
||||
use tempfile::TempDir;
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_register_and_execute_plugin_hooks() -> Result<()> {
|
||||
// Create temporary directory to act as project root
|
||||
let temp_dir = TempDir::new()?;
|
||||
|
||||
// Create hook manager
|
||||
let mut hook_mgr = HookManager::new(temp_dir.path().to_str().unwrap());
|
||||
|
||||
// Register a hook that matches Edit|Write tools
|
||||
hook_mgr.register_hook(
|
||||
"PreToolUse".to_string(),
|
||||
"echo 'Hook executed' && exit 0".to_string(),
|
||||
Some("Edit|Write".to_string()),
|
||||
Some(5000),
|
||||
);
|
||||
|
||||
// Test that the hook executes for Edit tool
|
||||
let event = HookEvent::PreToolUse {
|
||||
tool: "Edit".to_string(),
|
||||
args: serde_json::json!({"path": "/tmp/test.txt"}),
|
||||
};
|
||||
|
||||
let result = hook_mgr.execute(&event, Some(5000)).await?;
|
||||
assert_eq!(result, HookResult::Allow);
|
||||
|
||||
// Test that the hook executes for Write tool
|
||||
let event = HookEvent::PreToolUse {
|
||||
tool: "Write".to_string(),
|
||||
args: serde_json::json!({"path": "/tmp/test.txt"}),
|
||||
};
|
||||
|
||||
let result = hook_mgr.execute(&event, Some(5000)).await?;
|
||||
assert_eq!(result, HookResult::Allow);
|
||||
|
||||
// Test that the hook does NOT execute for Read tool (doesn't match pattern)
|
||||
let event = HookEvent::PreToolUse {
|
||||
tool: "Read".to_string(),
|
||||
args: serde_json::json!({"path": "/tmp/test.txt"}),
|
||||
};
|
||||
|
||||
let result = hook_mgr.execute(&event, Some(5000)).await?;
|
||||
assert_eq!(result, HookResult::Allow);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_deny_hook() -> Result<()> {
|
||||
// Create temporary directory to act as project root
|
||||
let temp_dir = TempDir::new()?;
|
||||
|
||||
// Create hook manager
|
||||
let mut hook_mgr = HookManager::new(temp_dir.path().to_str().unwrap());
|
||||
|
||||
// Register a hook that denies Write operations
|
||||
hook_mgr.register_hook(
|
||||
"PreToolUse".to_string(),
|
||||
"exit 2".to_string(), // Exit code 2 means deny
|
||||
Some("Write".to_string()),
|
||||
Some(5000),
|
||||
);
|
||||
|
||||
// Test that the hook denies Write tool
|
||||
let event = HookEvent::PreToolUse {
|
||||
tool: "Write".to_string(),
|
||||
args: serde_json::json!({"path": "/tmp/test.txt"}),
|
||||
};
|
||||
|
||||
let result = hook_mgr.execute(&event, Some(5000)).await?;
|
||||
assert_eq!(result, HookResult::Deny);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_multiple_hooks_same_event() -> Result<()> {
|
||||
// Create temporary directory to act as project root
|
||||
let temp_dir = TempDir::new()?;
|
||||
|
||||
// Create hook manager
|
||||
let mut hook_mgr = HookManager::new(temp_dir.path().to_str().unwrap());
|
||||
|
||||
// Register multiple hooks for the same event
|
||||
hook_mgr.register_hook(
|
||||
"PreToolUse".to_string(),
|
||||
"echo 'Hook 1' && exit 0".to_string(),
|
||||
Some("Edit".to_string()),
|
||||
Some(5000),
|
||||
);
|
||||
|
||||
hook_mgr.register_hook(
|
||||
"PreToolUse".to_string(),
|
||||
"echo 'Hook 2' && exit 0".to_string(),
|
||||
Some("Edit".to_string()),
|
||||
Some(5000),
|
||||
);
|
||||
|
||||
// Test that both hooks execute
|
||||
let event = HookEvent::PreToolUse {
|
||||
tool: "Edit".to_string(),
|
||||
args: serde_json::json!({"path": "/tmp/test.txt"}),
|
||||
};
|
||||
|
||||
let result = hook_mgr.execute(&event, Some(5000)).await?;
|
||||
assert_eq!(result, HookResult::Allow);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_hook_with_no_pattern_matches_all() -> Result<()> {
|
||||
// Create temporary directory to act as project root
|
||||
let temp_dir = TempDir::new()?;
|
||||
|
||||
// Create hook manager
|
||||
let mut hook_mgr = HookManager::new(temp_dir.path().to_str().unwrap());
|
||||
|
||||
// Register a hook with no pattern (matches all tools)
|
||||
hook_mgr.register_hook(
|
||||
"PreToolUse".to_string(),
|
||||
"echo 'Hook for all tools' && exit 0".to_string(),
|
||||
None, // No pattern = match all
|
||||
Some(5000),
|
||||
);
|
||||
|
||||
// Test that the hook executes for any tool
|
||||
let event = HookEvent::PreToolUse {
|
||||
tool: "Read".to_string(),
|
||||
args: serde_json::json!({"path": "/tmp/test.txt"}),
|
||||
};
|
||||
let result = hook_mgr.execute(&event, Some(5000)).await?;
|
||||
assert_eq!(result, HookResult::Allow);
|
||||
|
||||
let event = HookEvent::PreToolUse {
|
||||
tool: "Write".to_string(),
|
||||
args: serde_json::json!({"path": "/tmp/test.txt"}),
|
||||
};
|
||||
let result = hook_mgr.execute(&event, Some(5000)).await?;
|
||||
assert_eq!(result, HookResult::Allow);
|
||||
|
||||
let event = HookEvent::PreToolUse {
|
||||
tool: "Bash".to_string(),
|
||||
args: serde_json::json!({"command": "ls"}),
|
||||
};
|
||||
let result = hook_mgr.execute(&event, Some(5000)).await?;
|
||||
assert_eq!(result, HookResult::Allow);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
@@ -16,6 +16,75 @@ pub enum Tool {
|
||||
Task,
|
||||
TodoWrite,
|
||||
Mcp,
|
||||
// New tools
|
||||
MultiEdit,
|
||||
LS,
|
||||
AskUserQuestion,
|
||||
BashOutput,
|
||||
KillShell,
|
||||
// Planning mode tools
|
||||
EnterPlanMode,
|
||||
ExitPlanMode,
|
||||
Skill,
|
||||
}
|
||||
|
||||
impl Tool {
|
||||
/// Parse a tool name from string (case-insensitive)
|
||||
pub fn from_str(s: &str) -> Option<Self> {
|
||||
match s.to_lowercase().as_str() {
|
||||
"read" => Some(Tool::Read),
|
||||
"write" => Some(Tool::Write),
|
||||
"edit" => Some(Tool::Edit),
|
||||
"bash" => Some(Tool::Bash),
|
||||
"grep" => Some(Tool::Grep),
|
||||
"glob" => Some(Tool::Glob),
|
||||
"webfetch" | "web_fetch" => Some(Tool::WebFetch),
|
||||
"websearch" | "web_search" => Some(Tool::WebSearch),
|
||||
"notebookread" | "notebook_read" => Some(Tool::NotebookRead),
|
||||
"notebookedit" | "notebook_edit" => Some(Tool::NotebookEdit),
|
||||
"slashcommand" | "slash_command" => Some(Tool::SlashCommand),
|
||||
"task" => Some(Tool::Task),
|
||||
"todowrite" | "todo_write" | "todo" => Some(Tool::TodoWrite),
|
||||
"mcp" => Some(Tool::Mcp),
|
||||
"multiedit" | "multi_edit" => Some(Tool::MultiEdit),
|
||||
"ls" => Some(Tool::LS),
|
||||
"askuserquestion" | "ask_user_question" | "ask" => Some(Tool::AskUserQuestion),
|
||||
"bashoutput" | "bash_output" => Some(Tool::BashOutput),
|
||||
"killshell" | "kill_shell" => Some(Tool::KillShell),
|
||||
"enterplanmode" | "enter_plan_mode" => Some(Tool::EnterPlanMode),
|
||||
"exitplanmode" | "exit_plan_mode" => Some(Tool::ExitPlanMode),
|
||||
"skill" => Some(Tool::Skill),
|
||||
_ => None,
|
||||
}
|
||||
}
|
||||
|
||||
/// Get the string name of this tool
|
||||
pub fn name(&self) -> &'static str {
|
||||
match self {
|
||||
Tool::Read => "read",
|
||||
Tool::Write => "write",
|
||||
Tool::Edit => "edit",
|
||||
Tool::Bash => "bash",
|
||||
Tool::Grep => "grep",
|
||||
Tool::Glob => "glob",
|
||||
Tool::WebFetch => "web_fetch",
|
||||
Tool::WebSearch => "web_search",
|
||||
Tool::NotebookRead => "notebook_read",
|
||||
Tool::NotebookEdit => "notebook_edit",
|
||||
Tool::SlashCommand => "slash_command",
|
||||
Tool::Task => "task",
|
||||
Tool::TodoWrite => "todo_write",
|
||||
Tool::Mcp => "mcp",
|
||||
Tool::MultiEdit => "multi_edit",
|
||||
Tool::LS => "ls",
|
||||
Tool::AskUserQuestion => "ask_user_question",
|
||||
Tool::BashOutput => "bash_output",
|
||||
Tool::KillShell => "kill_shell",
|
||||
Tool::EnterPlanMode => "enter_plan_mode",
|
||||
Tool::ExitPlanMode => "exit_plan_mode",
|
||||
Tool::Skill => "skill",
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
|
||||
@@ -123,23 +192,34 @@ impl PermissionManager {
|
||||
match self.mode {
|
||||
Mode::Plan => match tool {
|
||||
// Read-only tools are allowed in plan mode
|
||||
Tool::Read | Tool::Grep | Tool::Glob | Tool::NotebookRead => {
|
||||
Tool::Read | Tool::Grep | Tool::Glob | Tool::NotebookRead | Tool::LS => {
|
||||
PermissionDecision::Allow
|
||||
}
|
||||
// User interaction and session state tools allowed
|
||||
Tool::AskUserQuestion | Tool::TodoWrite => PermissionDecision::Allow,
|
||||
// Planning mode tools - EnterPlanMode asks, ExitPlanMode allowed
|
||||
Tool::EnterPlanMode => PermissionDecision::Ask,
|
||||
Tool::ExitPlanMode => PermissionDecision::Allow,
|
||||
// Skill tool allowed (read-only skill injection)
|
||||
Tool::Skill => PermissionDecision::Allow,
|
||||
// Everything else requires asking
|
||||
_ => PermissionDecision::Ask,
|
||||
},
|
||||
Mode::AcceptEdits => match tool {
|
||||
// Read operations allowed
|
||||
Tool::Read | Tool::Grep | Tool::Glob | Tool::NotebookRead => {
|
||||
Tool::Read | Tool::Grep | Tool::Glob | Tool::NotebookRead | Tool::LS => {
|
||||
PermissionDecision::Allow
|
||||
}
|
||||
// Edit/Write operations allowed
|
||||
Tool::Edit | Tool::Write | Tool::NotebookEdit => PermissionDecision::Allow,
|
||||
Tool::Edit | Tool::Write | Tool::NotebookEdit | Tool::MultiEdit => PermissionDecision::Allow,
|
||||
// Bash and other dangerous operations still require asking
|
||||
Tool::Bash | Tool::WebFetch | Tool::WebSearch | Tool::Mcp => PermissionDecision::Ask,
|
||||
// Background shell operations same as Bash
|
||||
Tool::BashOutput | Tool::KillShell => PermissionDecision::Ask,
|
||||
// Utility tools allowed
|
||||
Tool::TodoWrite | Tool::SlashCommand | Tool::Task => PermissionDecision::Allow,
|
||||
Tool::TodoWrite | Tool::SlashCommand | Tool::Task | Tool::AskUserQuestion => PermissionDecision::Allow,
|
||||
// Planning mode tools allowed
|
||||
Tool::EnterPlanMode | Tool::ExitPlanMode | Tool::Skill => PermissionDecision::Allow,
|
||||
},
|
||||
Mode::Code => {
|
||||
// Everything allowed in code mode
|
||||
@@ -155,6 +235,41 @@ impl PermissionManager {
|
||||
pub fn mode(&self) -> Mode {
|
||||
self.mode
|
||||
}
|
||||
|
||||
/// Add allowed tools from a list of tool names (with optional patterns)
|
||||
///
|
||||
/// Format: "tool_name" or "tool_name:pattern"
|
||||
/// Example: "bash", "bash:npm test:*", "mcp:filesystem__*"
|
||||
pub fn add_allowed_tools(&mut self, tools: &[String]) {
|
||||
for spec in tools {
|
||||
if let Some((tool, pattern)) = Self::parse_tool_spec(spec) {
|
||||
self.add_rule(tool, pattern, Action::Allow);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Add disallowed tools from a list of tool names (with optional patterns)
|
||||
///
|
||||
/// Format: "tool_name" or "tool_name:pattern"
|
||||
/// Example: "bash", "bash:rm -rf*"
|
||||
pub fn add_disallowed_tools(&mut self, tools: &[String]) {
|
||||
for spec in tools {
|
||||
if let Some((tool, pattern)) = Self::parse_tool_spec(spec) {
|
||||
self.add_rule(tool, pattern, Action::Deny);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Parse a tool specification into (Tool, Option<pattern>)
|
||||
///
|
||||
/// Format: "tool_name" or "tool_name:pattern"
|
||||
fn parse_tool_spec(spec: &str) -> Option<(Tool, Option<String>)> {
|
||||
let parts: Vec<&str> = spec.splitn(2, ':').collect();
|
||||
let tool_name = parts[0].trim();
|
||||
let pattern = parts.get(1).map(|s| s.trim().to_string());
|
||||
|
||||
Tool::from_str(tool_name).map(|tool| (tool, pattern))
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
@@ -237,4 +352,78 @@ mod tests {
|
||||
assert!(rule.matches(Tool::Mcp, Some("filesystem__read_file")));
|
||||
assert!(!rule.matches(Tool::Mcp, Some("filesystem__write_file")));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn tool_from_str() {
|
||||
assert_eq!(Tool::from_str("bash"), Some(Tool::Bash));
|
||||
assert_eq!(Tool::from_str("BASH"), Some(Tool::Bash));
|
||||
assert_eq!(Tool::from_str("Bash"), Some(Tool::Bash));
|
||||
assert_eq!(Tool::from_str("web_fetch"), Some(Tool::WebFetch));
|
||||
assert_eq!(Tool::from_str("webfetch"), Some(Tool::WebFetch));
|
||||
assert_eq!(Tool::from_str("unknown"), None);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn parse_tool_spec() {
|
||||
let (tool, pattern) = PermissionManager::parse_tool_spec("bash").unwrap();
|
||||
assert_eq!(tool, Tool::Bash);
|
||||
assert_eq!(pattern, None);
|
||||
|
||||
let (tool, pattern) = PermissionManager::parse_tool_spec("bash:npm test*").unwrap();
|
||||
assert_eq!(tool, Tool::Bash);
|
||||
assert_eq!(pattern, Some("npm test*".to_string()));
|
||||
|
||||
let (tool, pattern) = PermissionManager::parse_tool_spec("mcp:filesystem__*").unwrap();
|
||||
assert_eq!(tool, Tool::Mcp);
|
||||
assert_eq!(pattern, Some("filesystem__*".to_string()));
|
||||
|
||||
assert!(PermissionManager::parse_tool_spec("invalid_tool").is_none());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn allowed_tools_list() {
|
||||
let mut pm = PermissionManager::new(Mode::Plan);
|
||||
|
||||
pm.add_allowed_tools(&[
|
||||
"bash:npm test:*".to_string(),
|
||||
"bash:cargo test".to_string(),
|
||||
]);
|
||||
|
||||
// Allowed by rule
|
||||
assert_eq!(pm.check(Tool::Bash, Some("npm test:unit")), PermissionDecision::Allow);
|
||||
assert_eq!(pm.check(Tool::Bash, Some("cargo test")), PermissionDecision::Allow);
|
||||
|
||||
// Not matched by any rule, falls back to mode default (Ask for bash in plan mode)
|
||||
assert_eq!(pm.check(Tool::Bash, Some("rm -rf")), PermissionDecision::Ask);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn disallowed_tools_list() {
|
||||
let mut pm = PermissionManager::new(Mode::Code);
|
||||
|
||||
pm.add_disallowed_tools(&[
|
||||
"bash:rm -rf*".to_string(),
|
||||
"bash:sudo*".to_string(),
|
||||
]);
|
||||
|
||||
// Denied by rule
|
||||
assert_eq!(pm.check(Tool::Bash, Some("rm -rf /")), PermissionDecision::Deny);
|
||||
assert_eq!(pm.check(Tool::Bash, Some("sudo apt install")), PermissionDecision::Deny);
|
||||
|
||||
// Not matched by deny rule, allowed by Code mode
|
||||
assert_eq!(pm.check(Tool::Bash, Some("npm test")), PermissionDecision::Allow);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn deny_takes_precedence() {
|
||||
let mut pm = PermissionManager::new(Mode::Code);
|
||||
|
||||
// Add both allow and deny for similar patterns
|
||||
pm.add_disallowed_tools(&["bash:rm*".to_string()]);
|
||||
pm.add_allowed_tools(&["bash".to_string()]);
|
||||
|
||||
// Deny rule was added first, so it takes precedence when matched
|
||||
assert_eq!(pm.check(Tool::Bash, Some("rm -rf")), PermissionDecision::Deny);
|
||||
assert_eq!(pm.check(Tool::Bash, Some("ls -la")), PermissionDecision::Allow);
|
||||
}
|
||||
}
|
||||
|
||||
15
crates/platform/plugins/Cargo.toml
Normal file
15
crates/platform/plugins/Cargo.toml
Normal file
@@ -0,0 +1,15 @@
|
||||
[package]
|
||||
name = "plugins"
|
||||
version = "0.1.0"
|
||||
edition = "2024"
|
||||
|
||||
[dependencies]
|
||||
color-eyre = "0.6"
|
||||
dirs = "5.0"
|
||||
serde = { version = "1.0", features = ["derive"] }
|
||||
serde_json = "1.0"
|
||||
serde_yaml = "0.9"
|
||||
walkdir = "2.5"
|
||||
|
||||
[dev-dependencies]
|
||||
tempfile = "3.13"
|
||||
773
crates/platform/plugins/src/lib.rs
Normal file
773
crates/platform/plugins/src/lib.rs
Normal file
@@ -0,0 +1,773 @@
|
||||
use color_eyre::eyre::{Result, eyre};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::collections::HashMap;
|
||||
use std::fs;
|
||||
use std::path::{Path, PathBuf};
|
||||
use walkdir::WalkDir;
|
||||
|
||||
/// Plugin manifest schema (plugin.json)
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct PluginManifest {
|
||||
/// Plugin name
|
||||
pub name: String,
|
||||
/// Plugin version
|
||||
pub version: String,
|
||||
/// Plugin description
|
||||
pub description: Option<String>,
|
||||
/// Plugin author
|
||||
pub author: Option<String>,
|
||||
/// Commands provided by this plugin
|
||||
#[serde(default)]
|
||||
pub commands: Vec<String>,
|
||||
/// Agents provided by this plugin
|
||||
#[serde(default)]
|
||||
pub agents: Vec<String>,
|
||||
/// Skills provided by this plugin
|
||||
#[serde(default)]
|
||||
pub skills: Vec<String>,
|
||||
/// Hooks provided by this plugin
|
||||
#[serde(default)]
|
||||
pub hooks: HashMap<String, String>,
|
||||
/// MCP servers provided by this plugin
|
||||
#[serde(default)]
|
||||
pub mcp_servers: Vec<McpServerConfig>,
|
||||
}
|
||||
|
||||
/// MCP server configuration in plugin manifest
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct McpServerConfig {
|
||||
pub name: String,
|
||||
pub command: String,
|
||||
#[serde(default)]
|
||||
pub args: Vec<String>,
|
||||
#[serde(default)]
|
||||
pub env: HashMap<String, String>,
|
||||
}
|
||||
|
||||
/// Plugin hook configuration from hooks/hooks.json
|
||||
#[derive(Debug, Clone, Deserialize)]
|
||||
pub struct PluginHooksConfig {
|
||||
pub description: Option<String>,
|
||||
pub hooks: HashMap<String, Vec<HookMatcher>>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Deserialize)]
|
||||
pub struct HookMatcher {
|
||||
pub matcher: Option<String>, // Regex pattern for tool names
|
||||
pub hooks: Vec<HookDefinition>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Deserialize)]
|
||||
pub struct HookDefinition {
|
||||
#[serde(rename = "type")]
|
||||
pub hook_type: String, // "command" or "prompt"
|
||||
pub command: Option<String>,
|
||||
pub prompt: Option<String>,
|
||||
pub timeout: Option<u64>,
|
||||
}
|
||||
|
||||
/// Parsed slash command from markdown file
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct SlashCommand {
|
||||
pub name: String,
|
||||
pub description: Option<String>,
|
||||
pub argument_hint: Option<String>,
|
||||
pub allowed_tools: Option<Vec<String>>,
|
||||
pub body: String, // Markdown content after frontmatter
|
||||
pub source_path: PathBuf,
|
||||
}
|
||||
|
||||
/// Parsed agent definition from markdown file
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct AgentDefinition {
|
||||
pub name: String,
|
||||
pub description: String,
|
||||
pub tools: Vec<String>, // Tool whitelist
|
||||
pub model: Option<String>, // haiku, sonnet, opus
|
||||
pub color: Option<String>,
|
||||
pub system_prompt: String, // Markdown body
|
||||
pub source_path: PathBuf,
|
||||
}
|
||||
|
||||
/// Parsed skill definition
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct Skill {
|
||||
pub name: String,
|
||||
pub description: String,
|
||||
pub version: Option<String>,
|
||||
pub content: String, // Core SKILL.md content
|
||||
pub references: Vec<PathBuf>, // Reference files
|
||||
pub examples: Vec<PathBuf>, // Example files
|
||||
pub source_path: PathBuf,
|
||||
}
|
||||
|
||||
/// YAML frontmatter for command files
|
||||
#[derive(Deserialize)]
|
||||
struct CommandFrontmatter {
|
||||
description: Option<String>,
|
||||
#[serde(rename = "argument-hint")]
|
||||
argument_hint: Option<String>,
|
||||
#[serde(rename = "allowed-tools")]
|
||||
allowed_tools: Option<String>,
|
||||
}
|
||||
|
||||
/// YAML frontmatter for agent files
|
||||
#[derive(Deserialize)]
|
||||
struct AgentFrontmatter {
|
||||
name: String,
|
||||
description: String,
|
||||
#[serde(default)]
|
||||
tools: Vec<String>,
|
||||
model: Option<String>,
|
||||
color: Option<String>,
|
||||
}
|
||||
|
||||
/// YAML frontmatter for skill files
|
||||
#[derive(Deserialize)]
|
||||
struct SkillFrontmatter {
|
||||
name: String,
|
||||
description: String,
|
||||
version: Option<String>,
|
||||
}
|
||||
|
||||
/// Parse YAML frontmatter from markdown content
|
||||
fn parse_frontmatter<T: serde::de::DeserializeOwned>(content: &str) -> Result<(T, String)> {
|
||||
if !content.starts_with("---") {
|
||||
return Err(eyre!("No frontmatter found"));
|
||||
}
|
||||
|
||||
let parts: Vec<&str> = content.splitn(3, "---").collect();
|
||||
if parts.len() < 3 {
|
||||
return Err(eyre!("Invalid frontmatter format"));
|
||||
}
|
||||
|
||||
let frontmatter: T = serde_yaml::from_str(parts[1].trim())?;
|
||||
let body = parts[2].trim().to_string();
|
||||
|
||||
Ok((frontmatter, body))
|
||||
}
|
||||
|
||||
/// A loaded plugin with its manifest and base path
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct Plugin {
|
||||
pub manifest: PluginManifest,
|
||||
pub base_path: PathBuf,
|
||||
}
|
||||
|
||||
impl Plugin {
|
||||
/// Get the path to a command file
|
||||
pub fn command_path(&self, command_name: &str) -> PathBuf {
|
||||
self.base_path.join("commands").join(format!("{}.md", command_name))
|
||||
}
|
||||
|
||||
/// Get the path to an agent file
|
||||
pub fn agent_path(&self, agent_name: &str) -> PathBuf {
|
||||
self.base_path.join("agents").join(format!("{}.md", agent_name))
|
||||
}
|
||||
|
||||
/// Get the path to a skill file
|
||||
pub fn skill_path(&self, skill_name: &str) -> PathBuf {
|
||||
self.base_path.join("skills").join(skill_name).join("SKILL.md")
|
||||
}
|
||||
|
||||
/// Get the path to a hook script
|
||||
pub fn hook_path(&self, hook_name: &str) -> Option<PathBuf> {
|
||||
self.manifest.hooks.get(hook_name).map(|path| {
|
||||
self.base_path.join("hooks").join(path)
|
||||
})
|
||||
}
|
||||
|
||||
/// Parse a command file
|
||||
pub fn parse_command(&self, name: &str) -> Result<SlashCommand> {
|
||||
let path = self.command_path(name);
|
||||
let content = fs::read_to_string(&path)?;
|
||||
let (fm, body): (CommandFrontmatter, String) = parse_frontmatter(&content)?;
|
||||
|
||||
let allowed_tools = fm.allowed_tools.map(|s| {
|
||||
s.split(',').map(|t| t.trim().to_string()).collect()
|
||||
});
|
||||
|
||||
Ok(SlashCommand {
|
||||
name: name.to_string(),
|
||||
description: fm.description,
|
||||
argument_hint: fm.argument_hint,
|
||||
allowed_tools,
|
||||
body,
|
||||
source_path: path,
|
||||
})
|
||||
}
|
||||
|
||||
/// Parse an agent file
|
||||
pub fn parse_agent(&self, name: &str) -> Result<AgentDefinition> {
|
||||
let path = self.agent_path(name);
|
||||
let content = fs::read_to_string(&path)?;
|
||||
let (fm, body): (AgentFrontmatter, String) = parse_frontmatter(&content)?;
|
||||
|
||||
Ok(AgentDefinition {
|
||||
name: fm.name,
|
||||
description: fm.description,
|
||||
tools: fm.tools,
|
||||
model: fm.model,
|
||||
color: fm.color,
|
||||
system_prompt: body,
|
||||
source_path: path,
|
||||
})
|
||||
}
|
||||
|
||||
/// Parse a skill file
|
||||
pub fn parse_skill(&self, name: &str) -> Result<Skill> {
|
||||
let path = self.skill_path(name);
|
||||
let content = fs::read_to_string(&path)?;
|
||||
let (fm, body): (SkillFrontmatter, String) = parse_frontmatter(&content)?;
|
||||
|
||||
// Discover reference and example files in the skill directory
|
||||
let skill_dir = self.base_path.join("skills").join(name);
|
||||
let references_dir = skill_dir.join("references");
|
||||
let examples_dir = skill_dir.join("examples");
|
||||
|
||||
let references = if references_dir.exists() {
|
||||
fs::read_dir(&references_dir)
|
||||
.into_iter()
|
||||
.flatten()
|
||||
.filter_map(|e| e.ok())
|
||||
.map(|e| e.path())
|
||||
.collect()
|
||||
} else {
|
||||
Vec::new()
|
||||
};
|
||||
|
||||
let examples = if examples_dir.exists() {
|
||||
fs::read_dir(&examples_dir)
|
||||
.into_iter()
|
||||
.flatten()
|
||||
.filter_map(|e| e.ok())
|
||||
.map(|e| e.path())
|
||||
.collect()
|
||||
} else {
|
||||
Vec::new()
|
||||
};
|
||||
|
||||
Ok(Skill {
|
||||
name: fm.name,
|
||||
description: fm.description,
|
||||
version: fm.version,
|
||||
content: body,
|
||||
references,
|
||||
examples,
|
||||
source_path: path,
|
||||
})
|
||||
}
|
||||
|
||||
/// Auto-discover commands in the commands/ directory
|
||||
pub fn discover_commands(&self) -> Vec<String> {
|
||||
let commands_dir = self.base_path.join("commands");
|
||||
if !commands_dir.exists() {
|
||||
return Vec::new();
|
||||
}
|
||||
|
||||
std::fs::read_dir(&commands_dir)
|
||||
.into_iter()
|
||||
.flatten()
|
||||
.filter_map(|e| e.ok())
|
||||
.filter(|e| e.path().extension().map(|ext| ext == "md").unwrap_or(false))
|
||||
.filter_map(|e| {
|
||||
e.path().file_stem()
|
||||
.map(|s| s.to_string_lossy().to_string())
|
||||
})
|
||||
.collect()
|
||||
}
|
||||
|
||||
/// Auto-discover agents in the agents/ directory
|
||||
pub fn discover_agents(&self) -> Vec<String> {
|
||||
let agents_dir = self.base_path.join("agents");
|
||||
if !agents_dir.exists() {
|
||||
return Vec::new();
|
||||
}
|
||||
|
||||
std::fs::read_dir(&agents_dir)
|
||||
.into_iter()
|
||||
.flatten()
|
||||
.filter_map(|e| e.ok())
|
||||
.filter(|e| e.path().extension().map(|ext| ext == "md").unwrap_or(false))
|
||||
.filter_map(|e| {
|
||||
e.path().file_stem()
|
||||
.map(|s| s.to_string_lossy().to_string())
|
||||
})
|
||||
.collect()
|
||||
}
|
||||
|
||||
/// Auto-discover skills in skills/*/SKILL.md
|
||||
pub fn discover_skills(&self) -> Vec<String> {
|
||||
let skills_dir = self.base_path.join("skills");
|
||||
if !skills_dir.exists() {
|
||||
return Vec::new();
|
||||
}
|
||||
|
||||
std::fs::read_dir(&skills_dir)
|
||||
.into_iter()
|
||||
.flatten()
|
||||
.filter_map(|e| e.ok())
|
||||
.filter(|e| e.path().is_dir())
|
||||
.filter(|e| e.path().join("SKILL.md").exists())
|
||||
.filter_map(|e| {
|
||||
e.path().file_name()
|
||||
.map(|s| s.to_string_lossy().to_string())
|
||||
})
|
||||
.collect()
|
||||
}
|
||||
|
||||
/// Get all commands (manifest + discovered)
|
||||
pub fn all_command_names(&self) -> Vec<String> {
|
||||
let mut names: std::collections::HashSet<String> =
|
||||
self.manifest.commands.iter().cloned().collect();
|
||||
names.extend(self.discover_commands());
|
||||
names.into_iter().collect()
|
||||
}
|
||||
|
||||
/// Get all agent names (manifest + discovered)
|
||||
pub fn all_agent_names(&self) -> Vec<String> {
|
||||
let mut names: std::collections::HashSet<String> =
|
||||
self.manifest.agents.iter().cloned().collect();
|
||||
names.extend(self.discover_agents());
|
||||
names.into_iter().collect()
|
||||
}
|
||||
|
||||
/// Get all skill names (manifest + discovered)
|
||||
pub fn all_skill_names(&self) -> Vec<String> {
|
||||
let mut names: std::collections::HashSet<String> =
|
||||
self.manifest.skills.iter().cloned().collect();
|
||||
names.extend(self.discover_skills());
|
||||
names.into_iter().collect()
|
||||
}
|
||||
|
||||
/// Load hooks configuration from hooks/hooks.json
|
||||
pub fn load_hooks_config(&self) -> Result<Option<PluginHooksConfig>> {
|
||||
let hooks_path = self.base_path.join("hooks").join("hooks.json");
|
||||
if !hooks_path.exists() {
|
||||
return Ok(None);
|
||||
}
|
||||
|
||||
let content = fs::read_to_string(&hooks_path)?;
|
||||
let config: PluginHooksConfig = serde_json::from_str(&content)?;
|
||||
Ok(Some(config))
|
||||
}
|
||||
|
||||
/// Register hooks from this plugin's config into a hook manager
|
||||
/// This requires the hooks crate to be available where this is called
|
||||
pub fn register_hooks_with_manager(&self, config: &PluginHooksConfig) -> Vec<(String, String, Option<String>, Option<u64>)> {
|
||||
let mut hooks_to_register = Vec::new();
|
||||
|
||||
for (event_name, matchers) in &config.hooks {
|
||||
for matcher in matchers {
|
||||
for hook_def in &matcher.hooks {
|
||||
if let Some(command) = &hook_def.command {
|
||||
// Substitute ${CLAUDE_PLUGIN_ROOT}
|
||||
let resolved = command.replace(
|
||||
"${CLAUDE_PLUGIN_ROOT}",
|
||||
&self.base_path.to_string_lossy()
|
||||
);
|
||||
|
||||
hooks_to_register.push((
|
||||
event_name.clone(),
|
||||
resolved,
|
||||
matcher.matcher.clone(),
|
||||
hook_def.timeout,
|
||||
));
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
hooks_to_register
|
||||
}
|
||||
}
|
||||
|
||||
/// Plugin loader and registry
|
||||
pub struct PluginManager {
|
||||
plugins: Vec<Plugin>,
|
||||
plugin_dirs: Vec<PathBuf>,
|
||||
}
|
||||
|
||||
impl PluginManager {
|
||||
/// Create a new plugin manager with default plugin directories
|
||||
pub fn new() -> Self {
|
||||
let mut plugin_dirs = Vec::new();
|
||||
|
||||
// User plugins: ~/.config/owlen/plugins
|
||||
if let Some(config_dir) = dirs::config_dir() {
|
||||
plugin_dirs.push(config_dir.join("owlen").join("plugins"));
|
||||
}
|
||||
|
||||
// Project plugins: .owlen/plugins
|
||||
plugin_dirs.push(PathBuf::from(".owlen/plugins"));
|
||||
|
||||
Self {
|
||||
plugins: Vec::new(),
|
||||
plugin_dirs,
|
||||
}
|
||||
}
|
||||
|
||||
/// Create a plugin manager with custom plugin directories
|
||||
pub fn with_dirs(plugin_dirs: Vec<PathBuf>) -> Self {
|
||||
Self {
|
||||
plugins: Vec::new(),
|
||||
plugin_dirs,
|
||||
}
|
||||
}
|
||||
|
||||
/// Load all plugins from configured directories
|
||||
pub fn load_all(&mut self) -> Result<()> {
|
||||
let plugin_dirs = self.plugin_dirs.clone();
|
||||
for dir in &plugin_dirs {
|
||||
if !dir.exists() {
|
||||
continue;
|
||||
}
|
||||
|
||||
self.load_from_dir(dir)?;
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Load plugins from a specific directory
|
||||
fn load_from_dir(&mut self, dir: &Path) -> Result<()> {
|
||||
// Walk directory looking for plugin.json files
|
||||
for entry in WalkDir::new(dir)
|
||||
.max_depth(2) // Don't recurse too deep
|
||||
.into_iter()
|
||||
.filter_map(|e| e.ok())
|
||||
{
|
||||
if entry.file_name() == "plugin.json" {
|
||||
if let Some(plugin_dir) = entry.path().parent() {
|
||||
match self.load_plugin(plugin_dir) {
|
||||
Ok(plugin) => {
|
||||
self.plugins.push(plugin);
|
||||
}
|
||||
Err(e) => {
|
||||
eprintln!("Warning: Failed to load plugin from {:?}: {}", plugin_dir, e);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Load a single plugin from a directory
|
||||
fn load_plugin(&self, plugin_dir: &Path) -> Result<Plugin> {
|
||||
let manifest_path = plugin_dir.join("plugin.json");
|
||||
let content = fs::read_to_string(&manifest_path)
|
||||
.map_err(|e| eyre!("Failed to read plugin manifest: {}", e))?;
|
||||
|
||||
let manifest: PluginManifest = serde_json::from_str(&content)
|
||||
.map_err(|e| eyre!("Failed to parse plugin manifest: {}", e))?;
|
||||
|
||||
Ok(Plugin {
|
||||
manifest,
|
||||
base_path: plugin_dir.to_path_buf(),
|
||||
})
|
||||
}
|
||||
|
||||
/// Get all loaded plugins
|
||||
pub fn plugins(&self) -> &[Plugin] {
|
||||
&self.plugins
|
||||
}
|
||||
|
||||
/// Find a plugin by name
|
||||
pub fn find_plugin(&self, name: &str) -> Option<&Plugin> {
|
||||
self.plugins.iter().find(|p| p.manifest.name == name)
|
||||
}
|
||||
|
||||
/// Get all available commands from all plugins
|
||||
pub fn all_commands(&self) -> HashMap<String, PathBuf> {
|
||||
let mut commands = HashMap::new();
|
||||
|
||||
for plugin in &self.plugins {
|
||||
for cmd_name in &plugin.manifest.commands {
|
||||
let path = plugin.command_path(cmd_name);
|
||||
if path.exists() {
|
||||
commands.insert(cmd_name.clone(), path);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
commands
|
||||
}
|
||||
|
||||
/// Get all available agents from all plugins
|
||||
pub fn all_agents(&self) -> HashMap<String, PathBuf> {
|
||||
let mut agents = HashMap::new();
|
||||
|
||||
for plugin in &self.plugins {
|
||||
for agent_name in &plugin.manifest.agents {
|
||||
let path = plugin.agent_path(agent_name);
|
||||
if path.exists() {
|
||||
agents.insert(agent_name.clone(), path);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
agents
|
||||
}
|
||||
|
||||
/// Get all available skills from all plugins
|
||||
pub fn all_skills(&self) -> HashMap<String, PathBuf> {
|
||||
let mut skills = HashMap::new();
|
||||
|
||||
for plugin in &self.plugins {
|
||||
for skill_name in &plugin.manifest.skills {
|
||||
let path = plugin.skill_path(skill_name);
|
||||
if path.exists() {
|
||||
skills.insert(skill_name.clone(), path);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
skills
|
||||
}
|
||||
|
||||
/// Get all MCP servers from all plugins
|
||||
pub fn all_mcp_servers(&self) -> Vec<(String, &McpServerConfig)> {
|
||||
let mut servers = Vec::new();
|
||||
|
||||
for plugin in &self.plugins {
|
||||
for server in &plugin.manifest.mcp_servers {
|
||||
servers.push((plugin.manifest.name.clone(), server));
|
||||
}
|
||||
}
|
||||
|
||||
servers
|
||||
}
|
||||
|
||||
/// Get all parsed commands
|
||||
pub fn load_all_commands(&self) -> Vec<SlashCommand> {
|
||||
let mut commands = Vec::new();
|
||||
for plugin in &self.plugins {
|
||||
for cmd_name in &plugin.manifest.commands {
|
||||
if let Ok(cmd) = plugin.parse_command(cmd_name) {
|
||||
commands.push(cmd);
|
||||
}
|
||||
}
|
||||
}
|
||||
commands
|
||||
}
|
||||
|
||||
/// Get all parsed agents
|
||||
pub fn load_all_agents(&self) -> Vec<AgentDefinition> {
|
||||
let mut agents = Vec::new();
|
||||
for plugin in &self.plugins {
|
||||
for agent_name in &plugin.manifest.agents {
|
||||
if let Ok(agent) = plugin.parse_agent(agent_name) {
|
||||
agents.push(agent);
|
||||
}
|
||||
}
|
||||
}
|
||||
agents
|
||||
}
|
||||
|
||||
/// Get all parsed skills
|
||||
pub fn load_all_skills(&self) -> Vec<Skill> {
|
||||
let mut skills = Vec::new();
|
||||
for plugin in &self.plugins {
|
||||
for skill_name in &plugin.manifest.skills {
|
||||
if let Ok(skill) = plugin.parse_skill(skill_name) {
|
||||
skills.push(skill);
|
||||
}
|
||||
}
|
||||
}
|
||||
skills
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for PluginManager {
|
||||
fn default() -> Self {
|
||||
Self::new()
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use std::fs;
|
||||
|
||||
fn create_test_plugin(dir: &Path) -> Result<()> {
|
||||
fs::create_dir_all(dir)?;
|
||||
fs::create_dir_all(dir.join("commands"))?;
|
||||
fs::create_dir_all(dir.join("agents"))?;
|
||||
fs::create_dir_all(dir.join("hooks"))?;
|
||||
|
||||
let manifest = PluginManifest {
|
||||
name: "test-plugin".to_string(),
|
||||
version: "1.0.0".to_string(),
|
||||
description: Some("A test plugin".to_string()),
|
||||
author: Some("Test Author".to_string()),
|
||||
commands: vec!["test-cmd".to_string()],
|
||||
agents: vec!["test-agent".to_string()],
|
||||
skills: vec![],
|
||||
hooks: {
|
||||
let mut h = HashMap::new();
|
||||
h.insert("PreToolUse".to_string(), "pre_tool_use.sh".to_string());
|
||||
h
|
||||
},
|
||||
mcp_servers: vec![],
|
||||
};
|
||||
|
||||
fs::write(
|
||||
dir.join("plugin.json"),
|
||||
serde_json::to_string_pretty(&manifest)?,
|
||||
)?;
|
||||
|
||||
fs::write(
|
||||
dir.join("commands/test-cmd.md"),
|
||||
"---\ndescription: A test command\nargument-hint: <file>\nallowed-tools: read,write\n---\n\nThis is a test command body.",
|
||||
)?;
|
||||
|
||||
fs::write(
|
||||
dir.join("agents/test-agent.md"),
|
||||
"---\nname: test-agent\ndescription: A test agent\ntools:\n - read\n - write\nmodel: sonnet\ncolor: blue\n---\n\nYou are a helpful test agent.",
|
||||
)?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_load_plugin() -> Result<()> {
|
||||
let temp_dir = tempfile::tempdir()?;
|
||||
let plugin_dir = temp_dir.path().join("test-plugin");
|
||||
create_test_plugin(&plugin_dir)?;
|
||||
|
||||
let manager = PluginManager::with_dirs(vec![temp_dir.path().to_path_buf()]);
|
||||
let plugin = manager.load_plugin(&plugin_dir)?;
|
||||
|
||||
assert_eq!(plugin.manifest.name, "test-plugin");
|
||||
assert_eq!(plugin.manifest.version, "1.0.0");
|
||||
assert_eq!(plugin.manifest.commands, vec!["test-cmd"]);
|
||||
assert_eq!(plugin.manifest.agents, vec!["test-agent"]);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_load_all_plugins() -> Result<()> {
|
||||
let temp_dir = tempfile::tempdir()?;
|
||||
let plugin_dir = temp_dir.path().join("test-plugin");
|
||||
create_test_plugin(&plugin_dir)?;
|
||||
|
||||
let mut manager = PluginManager::with_dirs(vec![temp_dir.path().to_path_buf()]);
|
||||
manager.load_all()?;
|
||||
|
||||
assert_eq!(manager.plugins().len(), 1);
|
||||
assert_eq!(manager.plugins()[0].manifest.name, "test-plugin");
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_find_plugin() -> Result<()> {
|
||||
let temp_dir = tempfile::tempdir()?;
|
||||
let plugin_dir = temp_dir.path().join("test-plugin");
|
||||
create_test_plugin(&plugin_dir)?;
|
||||
|
||||
let mut manager = PluginManager::with_dirs(vec![temp_dir.path().to_path_buf()]);
|
||||
manager.load_all()?;
|
||||
|
||||
let plugin = manager.find_plugin("test-plugin");
|
||||
assert!(plugin.is_some());
|
||||
assert_eq!(plugin.unwrap().manifest.name, "test-plugin");
|
||||
|
||||
let not_found = manager.find_plugin("nonexistent");
|
||||
assert!(not_found.is_none());
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_all_commands() -> Result<()> {
|
||||
let temp_dir = tempfile::tempdir()?;
|
||||
let plugin_dir = temp_dir.path().join("test-plugin");
|
||||
create_test_plugin(&plugin_dir)?;
|
||||
|
||||
let mut manager = PluginManager::with_dirs(vec![temp_dir.path().to_path_buf()]);
|
||||
manager.load_all()?;
|
||||
|
||||
let commands = manager.all_commands();
|
||||
assert_eq!(commands.len(), 1);
|
||||
assert!(commands.contains_key("test-cmd"));
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_parse_command() -> Result<()> {
|
||||
let temp_dir = tempfile::tempdir()?;
|
||||
let plugin_dir = temp_dir.path().join("test-plugin");
|
||||
create_test_plugin(&plugin_dir)?;
|
||||
|
||||
let manager = PluginManager::with_dirs(vec![temp_dir.path().to_path_buf()]);
|
||||
let plugin = manager.load_plugin(&plugin_dir)?;
|
||||
|
||||
let cmd = plugin.parse_command("test-cmd")?;
|
||||
assert_eq!(cmd.name, "test-cmd");
|
||||
assert_eq!(cmd.description, Some("A test command".to_string()));
|
||||
assert_eq!(cmd.argument_hint, Some("<file>".to_string()));
|
||||
assert_eq!(cmd.allowed_tools, Some(vec!["read".to_string(), "write".to_string()]));
|
||||
assert_eq!(cmd.body, "This is a test command body.");
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_parse_agent() -> Result<()> {
|
||||
let temp_dir = tempfile::tempdir()?;
|
||||
let plugin_dir = temp_dir.path().join("test-plugin");
|
||||
create_test_plugin(&plugin_dir)?;
|
||||
|
||||
let manager = PluginManager::with_dirs(vec![temp_dir.path().to_path_buf()]);
|
||||
let plugin = manager.load_plugin(&plugin_dir)?;
|
||||
|
||||
let agent = plugin.parse_agent("test-agent")?;
|
||||
assert_eq!(agent.name, "test-agent");
|
||||
assert_eq!(agent.description, "A test agent");
|
||||
assert_eq!(agent.tools, vec!["read", "write"]);
|
||||
assert_eq!(agent.model, Some("sonnet".to_string()));
|
||||
assert_eq!(agent.color, Some("blue".to_string()));
|
||||
assert_eq!(agent.system_prompt, "You are a helpful test agent.");
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_load_all_commands() -> Result<()> {
|
||||
let temp_dir = tempfile::tempdir()?;
|
||||
let plugin_dir = temp_dir.path().join("test-plugin");
|
||||
create_test_plugin(&plugin_dir)?;
|
||||
|
||||
let mut manager = PluginManager::with_dirs(vec![temp_dir.path().to_path_buf()]);
|
||||
manager.load_all()?;
|
||||
|
||||
let commands = manager.load_all_commands();
|
||||
assert_eq!(commands.len(), 1);
|
||||
assert_eq!(commands[0].name, "test-cmd");
|
||||
assert_eq!(commands[0].description, Some("A test command".to_string()));
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_load_all_agents() -> Result<()> {
|
||||
let temp_dir = tempfile::tempdir()?;
|
||||
let plugin_dir = temp_dir.path().join("test-plugin");
|
||||
create_test_plugin(&plugin_dir)?;
|
||||
|
||||
let mut manager = PluginManager::with_dirs(vec![temp_dir.path().to_path_buf()]);
|
||||
manager.load_all()?;
|
||||
|
||||
let agents = manager.load_all_agents();
|
||||
assert_eq!(agents.len(), 1);
|
||||
assert_eq!(agents[0].name, "test-agent");
|
||||
assert_eq!(agents[0].description, "A test agent");
|
||||
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
175
crates/platform/plugins/tests/plugin_hooks_integration.rs
Normal file
175
crates/platform/plugins/tests/plugin_hooks_integration.rs
Normal file
@@ -0,0 +1,175 @@
|
||||
// End-to-end integration test for plugin hooks
|
||||
use color_eyre::eyre::Result;
|
||||
use plugins::PluginManager;
|
||||
use std::fs;
|
||||
use tempfile::TempDir;
|
||||
|
||||
fn create_test_plugin_with_hooks(plugin_dir: &std::path::Path) -> Result<()> {
|
||||
fs::create_dir_all(plugin_dir)?;
|
||||
|
||||
// Create plugin manifest
|
||||
let manifest = serde_json::json!({
|
||||
"name": "test-hook-plugin",
|
||||
"version": "1.0.0",
|
||||
"description": "Test plugin with hooks",
|
||||
"commands": [],
|
||||
"agents": [],
|
||||
"skills": [],
|
||||
"hooks": {},
|
||||
"mcp_servers": []
|
||||
});
|
||||
fs::write(
|
||||
plugin_dir.join("plugin.json"),
|
||||
serde_json::to_string_pretty(&manifest)?,
|
||||
)?;
|
||||
|
||||
// Create hooks directory and hooks.json
|
||||
let hooks_dir = plugin_dir.join("hooks");
|
||||
fs::create_dir_all(&hooks_dir)?;
|
||||
|
||||
let hooks_config = serde_json::json!({
|
||||
"description": "Validate edit and write operations",
|
||||
"hooks": {
|
||||
"PreToolUse": [
|
||||
{
|
||||
"matcher": "Edit|Write",
|
||||
"hooks": [
|
||||
{
|
||||
"type": "command",
|
||||
"command": "python3 ${CLAUDE_PLUGIN_ROOT}/hooks/validate.py",
|
||||
"timeout": 5000
|
||||
}
|
||||
]
|
||||
},
|
||||
{
|
||||
"matcher": "Bash",
|
||||
"hooks": [
|
||||
{
|
||||
"type": "command",
|
||||
"command": "echo 'Bash hook' && exit 0"
|
||||
}
|
||||
]
|
||||
}
|
||||
],
|
||||
"PostToolUse": [
|
||||
{
|
||||
"hooks": [
|
||||
{
|
||||
"type": "command",
|
||||
"command": "echo 'Post-tool hook' && exit 0"
|
||||
}
|
||||
]
|
||||
}
|
||||
]
|
||||
}
|
||||
});
|
||||
fs::write(
|
||||
hooks_dir.join("hooks.json"),
|
||||
serde_json::to_string_pretty(&hooks_config)?,
|
||||
)?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_load_plugin_hooks_config() -> Result<()> {
|
||||
let temp_dir = TempDir::new()?;
|
||||
let plugin_dir = temp_dir.path().join("test-plugin");
|
||||
create_test_plugin_with_hooks(&plugin_dir)?;
|
||||
|
||||
// Load all plugins
|
||||
let mut plugin_manager = PluginManager::with_dirs(vec![temp_dir.path().to_path_buf()]);
|
||||
plugin_manager.load_all()?;
|
||||
|
||||
assert_eq!(plugin_manager.plugins().len(), 1);
|
||||
let plugin = &plugin_manager.plugins()[0];
|
||||
|
||||
// Load hooks config
|
||||
let hooks_config = plugin.load_hooks_config()?;
|
||||
assert!(hooks_config.is_some());
|
||||
|
||||
let config = hooks_config.unwrap();
|
||||
assert_eq!(config.description, Some("Validate edit and write operations".to_string()));
|
||||
assert!(config.hooks.contains_key("PreToolUse"));
|
||||
assert!(config.hooks.contains_key("PostToolUse"));
|
||||
|
||||
// Check PreToolUse hooks
|
||||
let pre_tool_hooks = &config.hooks["PreToolUse"];
|
||||
assert_eq!(pre_tool_hooks.len(), 2);
|
||||
|
||||
// First matcher: Edit|Write
|
||||
assert_eq!(pre_tool_hooks[0].matcher, Some("Edit|Write".to_string()));
|
||||
assert_eq!(pre_tool_hooks[0].hooks.len(), 1);
|
||||
assert_eq!(pre_tool_hooks[0].hooks[0].hook_type, "command");
|
||||
assert!(pre_tool_hooks[0].hooks[0].command.as_ref().unwrap().contains("validate.py"));
|
||||
|
||||
// Second matcher: Bash
|
||||
assert_eq!(pre_tool_hooks[1].matcher, Some("Bash".to_string()));
|
||||
assert_eq!(pre_tool_hooks[1].hooks.len(), 1);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_plugin_hooks_substitution() -> Result<()> {
|
||||
let temp_dir = TempDir::new()?;
|
||||
let plugin_dir = temp_dir.path().join("test-plugin");
|
||||
create_test_plugin_with_hooks(&plugin_dir)?;
|
||||
|
||||
// Load all plugins
|
||||
let mut plugin_manager = PluginManager::with_dirs(vec![temp_dir.path().to_path_buf()]);
|
||||
plugin_manager.load_all()?;
|
||||
|
||||
assert_eq!(plugin_manager.plugins().len(), 1);
|
||||
let plugin = &plugin_manager.plugins()[0];
|
||||
|
||||
// Load hooks config and register
|
||||
let hooks_config = plugin.load_hooks_config()?.unwrap();
|
||||
let hooks_to_register = plugin.register_hooks_with_manager(&hooks_config);
|
||||
|
||||
// Check that ${CLAUDE_PLUGIN_ROOT} was substituted
|
||||
let edit_write_hook = hooks_to_register.iter()
|
||||
.find(|(event, _, pattern, _)| {
|
||||
event == "PreToolUse" && pattern.as_ref().map(|p| p.contains("Edit")).unwrap_or(false)
|
||||
})
|
||||
.unwrap();
|
||||
|
||||
// The command should have the plugin path substituted
|
||||
assert!(edit_write_hook.1.contains(&plugin_dir.to_string_lossy().to_string()));
|
||||
assert!(edit_write_hook.1.contains("validate.py"));
|
||||
assert!(!edit_write_hook.1.contains("${CLAUDE_PLUGIN_ROOT}"));
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_multiple_plugins_with_hooks() -> Result<()> {
|
||||
let temp_dir = TempDir::new()?;
|
||||
|
||||
// Create two plugins with hooks
|
||||
let plugin1_dir = temp_dir.path().join("plugin1");
|
||||
create_test_plugin_with_hooks(&plugin1_dir)?;
|
||||
|
||||
let plugin2_dir = temp_dir.path().join("plugin2");
|
||||
create_test_plugin_with_hooks(&plugin2_dir)?;
|
||||
|
||||
// Load all plugins
|
||||
let mut plugin_manager = PluginManager::with_dirs(vec![temp_dir.path().to_path_buf()]);
|
||||
plugin_manager.load_all()?;
|
||||
|
||||
assert_eq!(plugin_manager.plugins().len(), 2);
|
||||
|
||||
// Collect all hooks from all plugins
|
||||
let mut total_hooks = 0;
|
||||
for plugin in plugin_manager.plugins() {
|
||||
if let Ok(Some(hooks_config)) = plugin.load_hooks_config() {
|
||||
let hooks = plugin.register_hooks_with_manager(&hooks_config);
|
||||
total_hooks += hooks.len();
|
||||
}
|
||||
}
|
||||
|
||||
// Each plugin has 3 hooks (2 PreToolUse + 1 PostToolUse)
|
||||
assert_eq!(total_hooks, 6);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
11
crates/tools/ask/Cargo.toml
Normal file
11
crates/tools/ask/Cargo.toml
Normal file
@@ -0,0 +1,11 @@
|
||||
[package]
|
||||
name = "tools-ask"
|
||||
version = "0.1.0"
|
||||
edition.workspace = true
|
||||
license.workspace = true
|
||||
|
||||
[dependencies]
|
||||
serde = { version = "1.0", features = ["derive"] }
|
||||
serde_json = "1.0"
|
||||
tokio = { version = "1", features = ["sync"] }
|
||||
color-eyre = "0.6"
|
||||
60
crates/tools/ask/src/lib.rs
Normal file
60
crates/tools/ask/src/lib.rs
Normal file
@@ -0,0 +1,60 @@
|
||||
//! AskUserQuestion tool for interactive user input
|
||||
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::collections::HashMap;
|
||||
use tokio::sync::{mpsc, oneshot};
|
||||
|
||||
/// A question option
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct QuestionOption {
|
||||
pub label: String,
|
||||
pub description: String,
|
||||
}
|
||||
|
||||
/// A question to ask the user
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct Question {
|
||||
pub question: String,
|
||||
pub header: String,
|
||||
pub options: Vec<QuestionOption>,
|
||||
pub multi_select: bool,
|
||||
}
|
||||
|
||||
/// Request sent to the UI to ask questions
|
||||
#[derive(Debug)]
|
||||
pub struct AskRequest {
|
||||
pub questions: Vec<Question>,
|
||||
pub response_tx: oneshot::Sender<HashMap<String, String>>,
|
||||
}
|
||||
|
||||
/// Channel for sending ask requests to the UI
|
||||
pub type AskSender = mpsc::Sender<AskRequest>;
|
||||
pub type AskReceiver = mpsc::Receiver<AskRequest>;
|
||||
|
||||
/// Create a channel pair for ask requests
|
||||
pub fn create_ask_channel() -> (AskSender, AskReceiver) {
|
||||
mpsc::channel(1)
|
||||
}
|
||||
|
||||
/// Ask the user questions (called by agent)
|
||||
pub async fn ask_user(
|
||||
sender: &AskSender,
|
||||
questions: Vec<Question>,
|
||||
) -> color_eyre::Result<HashMap<String, String>> {
|
||||
let (response_tx, response_rx) = oneshot::channel();
|
||||
|
||||
sender.send(AskRequest { questions, response_tx }).await
|
||||
.map_err(|_| color_eyre::eyre::eyre!("Failed to send ask request"))?;
|
||||
|
||||
response_rx.await
|
||||
.map_err(|_| color_eyre::eyre::eyre!("Failed to receive ask response"))
|
||||
}
|
||||
|
||||
/// Parse questions from JSON tool input
|
||||
pub fn parse_questions(input: &serde_json::Value) -> color_eyre::Result<Vec<Question>> {
|
||||
let questions = input.get("questions")
|
||||
.ok_or_else(|| color_eyre::eyre::eyre!("Missing 'questions' field"))?;
|
||||
|
||||
serde_json::from_value(questions.clone())
|
||||
.map_err(|e| color_eyre::eyre::eyre!("Invalid questions format: {}", e))
|
||||
}
|
||||
@@ -6,9 +6,11 @@ license.workspace = true
|
||||
rust-version.workspace = true
|
||||
|
||||
[dependencies]
|
||||
tokio = { version = "1.39", features = ["process", "io-util", "time", "sync"] }
|
||||
tokio = { version = "1.39", features = ["process", "io-util", "time", "sync", "rt"] }
|
||||
color-eyre = "0.6"
|
||||
tempfile = "3.23.0"
|
||||
parking_lot = "0.12"
|
||||
uuid = { version = "1.0", features = ["v4"] }
|
||||
|
||||
[dev-dependencies]
|
||||
tokio = { version = "1.39", features = ["macros", "rt-multi-thread"] }
|
||||
|
||||
@@ -1,5 +1,8 @@
|
||||
use color_eyre::eyre::{Result, eyre};
|
||||
use std::collections::HashMap;
|
||||
use std::process::Stdio;
|
||||
use std::sync::Arc;
|
||||
use parking_lot::RwLock;
|
||||
use tokio::io::{AsyncBufReadExt, AsyncWriteExt, BufReader};
|
||||
use tokio::process::{Child, Command};
|
||||
use tokio::sync::Mutex;
|
||||
@@ -19,6 +22,7 @@ pub struct CommandOutput {
|
||||
|
||||
pub struct BashSession {
|
||||
child: Mutex<Child>,
|
||||
last_output: Option<String>,
|
||||
}
|
||||
|
||||
impl BashSession {
|
||||
@@ -40,6 +44,7 @@ impl BashSession {
|
||||
|
||||
Ok(Self {
|
||||
child: Mutex::new(child),
|
||||
last_output: None,
|
||||
})
|
||||
}
|
||||
|
||||
@@ -54,7 +59,13 @@ impl BashSession {
|
||||
let result = timeout(timeout_duration, self.execute_internal(command)).await;
|
||||
|
||||
match result {
|
||||
Ok(output) => output,
|
||||
Ok(output) => {
|
||||
// Store the output for potential retrieval via BashOutput tool
|
||||
let combined = format!("{}{}", output.as_ref().map(|o| o.stdout.as_str()).unwrap_or(""),
|
||||
output.as_ref().map(|o| o.stderr.as_str()).unwrap_or(""));
|
||||
self.last_output = Some(combined);
|
||||
output
|
||||
},
|
||||
Err(_) => Err(eyre!("Command timed out after {}ms", timeout_duration.as_millis())),
|
||||
}
|
||||
}
|
||||
@@ -158,6 +169,106 @@ impl BashSession {
|
||||
}
|
||||
}
|
||||
|
||||
/// Manages background bash shells by ID
|
||||
#[derive(Clone, Default)]
|
||||
pub struct ShellManager {
|
||||
shells: Arc<RwLock<HashMap<String, BashSession>>>,
|
||||
}
|
||||
|
||||
impl ShellManager {
|
||||
pub fn new() -> Self {
|
||||
Self::default()
|
||||
}
|
||||
|
||||
/// Start a new background shell, returns shell ID
|
||||
pub async fn start_shell(&self) -> Result<String> {
|
||||
let id = uuid::Uuid::new_v4().to_string();
|
||||
let session = BashSession::new().await?;
|
||||
self.shells.write().insert(id.clone(), session);
|
||||
Ok(id)
|
||||
}
|
||||
|
||||
/// Execute command in background shell
|
||||
pub async fn execute(&self, shell_id: &str, command: &str, timeout: Option<Duration>) -> Result<CommandOutput> {
|
||||
// We need to handle this carefully to avoid holding the lock across await
|
||||
// First check if the shell exists and clone what we need
|
||||
let exists = self.shells.read().contains_key(shell_id);
|
||||
if !exists {
|
||||
return Err(eyre!("Shell not found: {}", shell_id));
|
||||
}
|
||||
|
||||
// For now, we need to use a more complex approach since BashSession contains async operations
|
||||
// We'll execute and then update in a separate critical section
|
||||
let timeout_ms = timeout.map(|d| d.as_millis() as u64);
|
||||
|
||||
// Take temporary ownership for execution
|
||||
let mut session = {
|
||||
let mut shells = self.shells.write();
|
||||
shells.remove(shell_id)
|
||||
.ok_or_else(|| eyre!("Shell not found: {}", shell_id))?
|
||||
};
|
||||
|
||||
// Execute without holding the lock
|
||||
let result = session.execute(command, timeout_ms).await;
|
||||
|
||||
// Put the session back
|
||||
self.shells.write().insert(shell_id.to_string(), session);
|
||||
|
||||
result
|
||||
}
|
||||
|
||||
/// Get output from a shell (BashOutput tool)
|
||||
pub fn get_output(&self, shell_id: &str) -> Result<Option<String>> {
|
||||
let shells = self.shells.read();
|
||||
let session = shells.get(shell_id)
|
||||
.ok_or_else(|| eyre!("Shell not found: {}", shell_id))?;
|
||||
// Return any buffered output
|
||||
Ok(session.last_output.clone())
|
||||
}
|
||||
|
||||
/// Kill a shell (KillShell tool)
|
||||
pub fn kill_shell(&self, shell_id: &str) -> Result<()> {
|
||||
let mut shells = self.shells.write();
|
||||
if shells.remove(shell_id).is_some() {
|
||||
Ok(())
|
||||
} else {
|
||||
Err(eyre!("Shell not found: {}", shell_id))
|
||||
}
|
||||
}
|
||||
|
||||
/// List active shells
|
||||
pub fn list_shells(&self) -> Vec<String> {
|
||||
self.shells.read().keys().cloned().collect()
|
||||
}
|
||||
}
|
||||
|
||||
/// Start a background bash command, returns shell ID
|
||||
pub async fn run_background(manager: &ShellManager, command: &str) -> Result<String> {
|
||||
let shell_id = manager.start_shell().await?;
|
||||
// Execute in background (non-blocking)
|
||||
tokio::spawn({
|
||||
let manager = manager.clone();
|
||||
let command = command.to_string();
|
||||
let shell_id = shell_id.clone();
|
||||
async move {
|
||||
let _ = manager.execute(&shell_id, &command, None).await;
|
||||
}
|
||||
});
|
||||
Ok(shell_id)
|
||||
}
|
||||
|
||||
/// Get output from background shell (BashOutput tool)
|
||||
pub fn bash_output(manager: &ShellManager, shell_id: &str) -> Result<String> {
|
||||
manager.get_output(shell_id)?
|
||||
.ok_or_else(|| eyre!("No output available"))
|
||||
}
|
||||
|
||||
/// Kill a background shell (KillShell tool)
|
||||
pub fn kill_shell(manager: &ShellManager, shell_id: &str) -> Result<String> {
|
||||
manager.kill_shell(shell_id)?;
|
||||
Ok(format!("Shell {} terminated", shell_id))
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
@@ -13,6 +13,8 @@ grep-regex = "0.1"
|
||||
grep-searcher = "0.1"
|
||||
color-eyre = "0.6"
|
||||
similar = "2.7"
|
||||
serde = { version = "1.0", features = ["derive"] }
|
||||
humantime = "2.1"
|
||||
|
||||
[dev-dependencies]
|
||||
tempfile = "3.23.0"
|
||||
@@ -3,6 +3,7 @@ use ignore::WalkBuilder;
|
||||
use grep_regex::RegexMatcher;
|
||||
use grep_searcher::{sinks::UTF8, SearcherBuilder};
|
||||
use globset::Glob;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::path::Path;
|
||||
|
||||
pub fn read_file(path: &str) -> Result<String> {
|
||||
@@ -127,4 +128,81 @@ pub fn grep(root: &str, pattern: &str) -> Result<Vec<(String, usize, String)>> {
|
||||
}
|
||||
|
||||
Ok(results)
|
||||
}
|
||||
|
||||
/// Edit operation for MultiEdit
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct EditOperation {
|
||||
pub old_string: String,
|
||||
pub new_string: String,
|
||||
}
|
||||
|
||||
/// Perform multiple edits on a file atomically
|
||||
pub fn multi_edit_file(path: &str, edits: Vec<EditOperation>) -> Result<String> {
|
||||
let content = std::fs::read_to_string(path)?;
|
||||
let mut new_content = content.clone();
|
||||
|
||||
// Apply edits in order
|
||||
for edit in &edits {
|
||||
if !new_content.contains(&edit.old_string) {
|
||||
return Err(eyre!("String not found: '{}'", edit.old_string));
|
||||
}
|
||||
new_content = new_content.replacen(&edit.old_string, &edit.new_string, 1);
|
||||
}
|
||||
|
||||
// Create backup and write
|
||||
let backup_path = format!("{}.bak", path);
|
||||
std::fs::copy(path, &backup_path)?;
|
||||
std::fs::write(path, &new_content)?;
|
||||
|
||||
Ok(format!("Applied {} edits to {}", edits.len(), path))
|
||||
}
|
||||
|
||||
/// Entry in directory listing
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct DirEntry {
|
||||
pub name: String,
|
||||
pub is_dir: bool,
|
||||
pub size: Option<u64>,
|
||||
pub modified: Option<String>,
|
||||
}
|
||||
|
||||
/// List contents of a directory
|
||||
pub fn list_directory(path: &str, show_hidden: bool) -> Result<Vec<DirEntry>> {
|
||||
let entries = std::fs::read_dir(path)?;
|
||||
let mut result = Vec::new();
|
||||
|
||||
for entry in entries {
|
||||
let entry = entry?;
|
||||
let name = entry.file_name().to_string_lossy().to_string();
|
||||
|
||||
// Skip hidden files unless requested
|
||||
if !show_hidden && name.starts_with('.') {
|
||||
continue;
|
||||
}
|
||||
|
||||
let metadata = entry.metadata()?;
|
||||
let modified = metadata.modified().ok().map(|t| {
|
||||
// Format as ISO 8601
|
||||
humantime::format_rfc3339(t).to_string()
|
||||
});
|
||||
|
||||
result.push(DirEntry {
|
||||
name,
|
||||
is_dir: metadata.is_dir(),
|
||||
size: if metadata.is_file() { Some(metadata.len()) } else { None },
|
||||
modified,
|
||||
});
|
||||
}
|
||||
|
||||
// Sort directories first, then alphabetically
|
||||
result.sort_by(|a, b| {
|
||||
match (a.is_dir, b.is_dir) {
|
||||
(true, false) => std::cmp::Ordering::Less,
|
||||
(false, true) => std::cmp::Ordering::Greater,
|
||||
_ => a.name.cmp(&b.name),
|
||||
}
|
||||
});
|
||||
|
||||
Ok(result)
|
||||
}
|
||||
@@ -1,4 +1,4 @@
|
||||
use tools_fs::{read_file, glob_list, grep, write_file, edit_file};
|
||||
use tools_fs::{read_file, glob_list, grep, write_file, edit_file, multi_edit_file, list_directory, EditOperation};
|
||||
use std::fs;
|
||||
use tempfile::tempdir;
|
||||
|
||||
@@ -101,4 +101,124 @@ fn edit_file_fails_on_no_match() {
|
||||
assert!(result.is_err());
|
||||
let err_msg = result.unwrap_err().to_string();
|
||||
assert!(err_msg.contains("not found") || err_msg.contains("String to replace"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn multi_edit_file_applies_multiple_edits() {
|
||||
let dir = tempdir().unwrap();
|
||||
let file_path = dir.path().join("test.txt");
|
||||
let original = "line 1\nline 2\nline 3\n";
|
||||
fs::write(&file_path, original).unwrap();
|
||||
|
||||
let edits = vec![
|
||||
EditOperation {
|
||||
old_string: "line 1".to_string(),
|
||||
new_string: "modified 1".to_string(),
|
||||
},
|
||||
EditOperation {
|
||||
old_string: "line 2".to_string(),
|
||||
new_string: "modified 2".to_string(),
|
||||
},
|
||||
];
|
||||
|
||||
let result = multi_edit_file(file_path.to_str().unwrap(), edits).unwrap();
|
||||
assert!(result.contains("Applied 2 edits"));
|
||||
|
||||
let content = read_file(file_path.to_str().unwrap()).unwrap();
|
||||
assert_eq!(content, "modified 1\nmodified 2\nline 3\n");
|
||||
|
||||
// Backup file should exist
|
||||
let backup_path = format!("{}.bak", file_path.display());
|
||||
assert!(std::path::Path::new(&backup_path).exists());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn multi_edit_file_fails_on_missing_string() {
|
||||
let dir = tempdir().unwrap();
|
||||
let file_path = dir.path().join("test.txt");
|
||||
let original = "line 1\nline 2\n";
|
||||
fs::write(&file_path, original).unwrap();
|
||||
|
||||
let edits = vec![
|
||||
EditOperation {
|
||||
old_string: "line 1".to_string(),
|
||||
new_string: "modified 1".to_string(),
|
||||
},
|
||||
EditOperation {
|
||||
old_string: "nonexistent".to_string(),
|
||||
new_string: "modified".to_string(),
|
||||
},
|
||||
];
|
||||
|
||||
let result = multi_edit_file(file_path.to_str().unwrap(), edits);
|
||||
assert!(result.is_err());
|
||||
let err_msg = result.unwrap_err().to_string();
|
||||
assert!(err_msg.contains("String not found"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn list_directory_shows_files_and_dirs() {
|
||||
let dir = tempdir().unwrap();
|
||||
let root = dir.path();
|
||||
|
||||
// Create test structure
|
||||
fs::write(root.join("file1.txt"), "content").unwrap();
|
||||
fs::write(root.join("file2.txt"), "content").unwrap();
|
||||
fs::create_dir(root.join("subdir")).unwrap();
|
||||
fs::write(root.join(".hidden"), "hidden content").unwrap();
|
||||
|
||||
let entries = list_directory(root.to_str().unwrap(), false).unwrap();
|
||||
|
||||
// Should find 2 files and 1 directory (hidden file excluded)
|
||||
assert_eq!(entries.len(), 3);
|
||||
|
||||
// Verify directory appears first (sorted)
|
||||
assert_eq!(entries[0].name, "subdir");
|
||||
assert!(entries[0].is_dir);
|
||||
|
||||
// Verify files
|
||||
let file_names: Vec<_> = entries.iter().skip(1).map(|e| e.name.as_str()).collect();
|
||||
assert!(file_names.contains(&"file1.txt"));
|
||||
assert!(file_names.contains(&"file2.txt"));
|
||||
|
||||
// Hidden file should not be present
|
||||
assert!(!entries.iter().any(|e| e.name == ".hidden"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn list_directory_shows_hidden_when_requested() {
|
||||
let dir = tempdir().unwrap();
|
||||
let root = dir.path();
|
||||
|
||||
fs::write(root.join("visible.txt"), "content").unwrap();
|
||||
fs::write(root.join(".hidden"), "hidden content").unwrap();
|
||||
|
||||
let entries = list_directory(root.to_str().unwrap(), true).unwrap();
|
||||
|
||||
// Should find both files
|
||||
assert!(entries.iter().any(|e| e.name == "visible.txt"));
|
||||
assert!(entries.iter().any(|e| e.name == ".hidden"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn list_directory_includes_metadata() {
|
||||
let dir = tempdir().unwrap();
|
||||
let root = dir.path();
|
||||
|
||||
fs::write(root.join("test.txt"), "hello world").unwrap();
|
||||
fs::create_dir(root.join("testdir")).unwrap();
|
||||
|
||||
let entries = list_directory(root.to_str().unwrap(), false).unwrap();
|
||||
|
||||
// Directory entry should have no size
|
||||
let dir_entry = entries.iter().find(|e| e.name == "testdir").unwrap();
|
||||
assert!(dir_entry.is_dir);
|
||||
assert!(dir_entry.size.is_none());
|
||||
assert!(dir_entry.modified.is_some());
|
||||
|
||||
// File entry should have size
|
||||
let file_entry = entries.iter().find(|e| e.name == "test.txt").unwrap();
|
||||
assert!(!file_entry.is_dir);
|
||||
assert_eq!(file_entry.size, Some(11)); // "hello world" is 11 bytes
|
||||
assert!(file_entry.modified.is_some());
|
||||
}
|
||||
14
crates/tools/notebook/Cargo.toml
Normal file
14
crates/tools/notebook/Cargo.toml
Normal file
@@ -0,0 +1,14 @@
|
||||
[package]
|
||||
name = "tools-notebook"
|
||||
version = "0.1.0"
|
||||
edition.workspace = true
|
||||
license.workspace = true
|
||||
rust-version.workspace = true
|
||||
|
||||
[dependencies]
|
||||
serde = { version = "1", features = ["derive"] }
|
||||
serde_json = "1"
|
||||
color-eyre = "0.6"
|
||||
|
||||
[dev-dependencies]
|
||||
tempfile = "3.23.0"
|
||||
175
crates/tools/notebook/src/lib.rs
Normal file
175
crates/tools/notebook/src/lib.rs
Normal file
@@ -0,0 +1,175 @@
|
||||
use color_eyre::eyre::{Result, eyre};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use serde_json::Value;
|
||||
use std::collections::HashMap;
|
||||
use std::fs;
|
||||
use std::path::Path;
|
||||
|
||||
/// Jupyter notebook structure
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct Notebook {
|
||||
pub cells: Vec<Cell>,
|
||||
pub metadata: NotebookMetadata,
|
||||
pub nbformat: i32,
|
||||
pub nbformat_minor: i32,
|
||||
}
|
||||
|
||||
/// Notebook cell
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct Cell {
|
||||
pub cell_type: String,
|
||||
#[serde(default, skip_serializing_if = "Option::is_none")]
|
||||
pub execution_count: Option<i32>,
|
||||
pub metadata: HashMap<String, Value>,
|
||||
pub source: Vec<String>,
|
||||
#[serde(default, skip_serializing_if = "Vec::is_empty")]
|
||||
pub outputs: Vec<Output>,
|
||||
#[serde(default, skip_serializing_if = "Option::is_none")]
|
||||
pub id: Option<String>,
|
||||
}
|
||||
|
||||
/// Cell output
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct Output {
|
||||
pub output_type: String,
|
||||
#[serde(flatten)]
|
||||
pub data: HashMap<String, Value>,
|
||||
}
|
||||
|
||||
/// Notebook metadata
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct NotebookMetadata {
|
||||
#[serde(default, skip_serializing_if = "Option::is_none")]
|
||||
pub kernelspec: Option<KernelSpec>,
|
||||
#[serde(default, skip_serializing_if = "Option::is_none")]
|
||||
pub language_info: Option<LanguageInfo>,
|
||||
#[serde(flatten)]
|
||||
pub extra: HashMap<String, Value>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct KernelSpec {
|
||||
pub display_name: String,
|
||||
pub language: String,
|
||||
pub name: String,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct LanguageInfo {
|
||||
pub name: String,
|
||||
#[serde(default, skip_serializing_if = "Option::is_none")]
|
||||
pub version: Option<String>,
|
||||
#[serde(flatten)]
|
||||
pub extra: HashMap<String, Value>,
|
||||
}
|
||||
|
||||
/// Read a Jupyter notebook from a file
|
||||
pub fn read_notebook<P: AsRef<Path>>(path: P) -> Result<Notebook> {
|
||||
let content = fs::read_to_string(path)?;
|
||||
let notebook: Notebook = serde_json::from_str(&content)?;
|
||||
Ok(notebook)
|
||||
}
|
||||
|
||||
/// Write a Jupyter notebook to a file
|
||||
pub fn write_notebook<P: AsRef<Path>>(path: P, notebook: &Notebook) -> Result<()> {
|
||||
let content = serde_json::to_string_pretty(notebook)?;
|
||||
fs::write(path, content)?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Edit operations for notebooks
|
||||
pub enum NotebookEdit {
|
||||
/// Replace cell at index with new source
|
||||
EditCell { index: usize, source: Vec<String> },
|
||||
/// Add a new cell at index
|
||||
AddCell { index: usize, cell: Cell },
|
||||
/// Delete cell at index
|
||||
DeleteCell { index: usize },
|
||||
}
|
||||
|
||||
/// Apply an edit to a notebook
|
||||
pub fn edit_notebook(notebook: &mut Notebook, edit: NotebookEdit) -> Result<()> {
|
||||
match edit {
|
||||
NotebookEdit::EditCell { index, source } => {
|
||||
if index >= notebook.cells.len() {
|
||||
return Err(eyre!("Cell index {} out of bounds (notebook has {} cells)", index, notebook.cells.len()));
|
||||
}
|
||||
notebook.cells[index].source = source;
|
||||
}
|
||||
NotebookEdit::AddCell { index, cell } => {
|
||||
if index > notebook.cells.len() {
|
||||
return Err(eyre!("Cell index {} out of bounds (notebook has {} cells)", index, notebook.cells.len()));
|
||||
}
|
||||
notebook.cells.insert(index, cell);
|
||||
}
|
||||
NotebookEdit::DeleteCell { index } => {
|
||||
if index >= notebook.cells.len() {
|
||||
return Err(eyre!("Cell index {} out of bounds (notebook has {} cells)", index, notebook.cells.len()));
|
||||
}
|
||||
notebook.cells.remove(index);
|
||||
}
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Create a new code cell
|
||||
pub fn new_code_cell(source: Vec<String>) -> Cell {
|
||||
Cell {
|
||||
cell_type: "code".to_string(),
|
||||
execution_count: None,
|
||||
metadata: HashMap::new(),
|
||||
source,
|
||||
outputs: Vec::new(),
|
||||
id: None,
|
||||
}
|
||||
}
|
||||
|
||||
/// Create a new markdown cell
|
||||
pub fn new_markdown_cell(source: Vec<String>) -> Cell {
|
||||
Cell {
|
||||
cell_type: "markdown".to_string(),
|
||||
execution_count: None,
|
||||
metadata: HashMap::new(),
|
||||
source,
|
||||
outputs: Vec::new(),
|
||||
id: None,
|
||||
}
|
||||
}
|
||||
|
||||
/// Get cell source as a single string
|
||||
pub fn cell_source_as_string(cell: &Cell) -> String {
|
||||
cell.source.join("")
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn cell_source_concatenation() {
|
||||
let cell = Cell {
|
||||
cell_type: "code".to_string(),
|
||||
execution_count: None,
|
||||
metadata: HashMap::new(),
|
||||
source: vec!["import pandas as pd\n".to_string(), "df = pd.DataFrame()\n".to_string()],
|
||||
outputs: Vec::new(),
|
||||
id: None,
|
||||
};
|
||||
|
||||
let source = cell_source_as_string(&cell);
|
||||
assert_eq!(source, "import pandas as pd\ndf = pd.DataFrame()\n");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn new_code_cell_creation() {
|
||||
let cell = new_code_cell(vec!["print('hello')\n".to_string()]);
|
||||
assert_eq!(cell.cell_type, "code");
|
||||
assert!(cell.outputs.is_empty());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn new_markdown_cell_creation() {
|
||||
let cell = new_markdown_cell(vec!["# Title\n".to_string()]);
|
||||
assert_eq!(cell.cell_type, "markdown");
|
||||
}
|
||||
}
|
||||
280
crates/tools/notebook/tests/notebook_tests.rs
Normal file
280
crates/tools/notebook/tests/notebook_tests.rs
Normal file
@@ -0,0 +1,280 @@
|
||||
use tools_notebook::*;
|
||||
use std::fs;
|
||||
use tempfile::tempdir;
|
||||
|
||||
#[test]
|
||||
fn notebook_round_trip_preserves_metadata() {
|
||||
let dir = tempdir().unwrap();
|
||||
let notebook_path = dir.path().join("test.ipynb");
|
||||
|
||||
// Create a sample notebook with metadata
|
||||
let notebook_json = r##"{
|
||||
"cells": [
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 1,
|
||||
"metadata": {},
|
||||
"source": ["print('hello world')"],
|
||||
"outputs": []
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": ["# Test Notebook", "This is a test."]
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
"kernelspec": {
|
||||
"display_name": "Python 3",
|
||||
"language": "python",
|
||||
"name": "python3"
|
||||
},
|
||||
"language_info": {
|
||||
"name": "python",
|
||||
"version": "3.9.0"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 5
|
||||
}"##;
|
||||
|
||||
fs::write(¬ebook_path, notebook_json).unwrap();
|
||||
|
||||
// Read the notebook
|
||||
let notebook = read_notebook(¬ebook_path).unwrap();
|
||||
|
||||
// Verify structure
|
||||
assert_eq!(notebook.cells.len(), 2);
|
||||
assert_eq!(notebook.nbformat, 4);
|
||||
assert_eq!(notebook.nbformat_minor, 5);
|
||||
|
||||
// Verify metadata
|
||||
assert!(notebook.metadata.kernelspec.is_some());
|
||||
let kernelspec = notebook.metadata.kernelspec.as_ref().unwrap();
|
||||
assert_eq!(kernelspec.language, "python");
|
||||
assert_eq!(kernelspec.name, "python3");
|
||||
|
||||
assert!(notebook.metadata.language_info.is_some());
|
||||
let lang_info = notebook.metadata.language_info.as_ref().unwrap();
|
||||
assert_eq!(lang_info.name, "python");
|
||||
assert_eq!(lang_info.version, Some("3.9.0".to_string()));
|
||||
|
||||
// Write it back
|
||||
let output_path = dir.path().join("output.ipynb");
|
||||
write_notebook(&output_path, ¬ebook).unwrap();
|
||||
|
||||
// Read it again
|
||||
let notebook2 = read_notebook(&output_path).unwrap();
|
||||
|
||||
// Verify metadata is preserved
|
||||
assert_eq!(notebook2.nbformat, 4);
|
||||
assert_eq!(notebook2.nbformat_minor, 5);
|
||||
assert!(notebook2.metadata.kernelspec.is_some());
|
||||
assert_eq!(
|
||||
notebook2.metadata.kernelspec.as_ref().unwrap().language,
|
||||
"python"
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn notebook_edit_cell_content() {
|
||||
let dir = tempdir().unwrap();
|
||||
let notebook_path = dir.path().join("test.ipynb");
|
||||
|
||||
let notebook_json = r##"{
|
||||
"cells": [
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"source": ["x = 1"],
|
||||
"outputs": []
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"source": ["y = 2"],
|
||||
"outputs": []
|
||||
}
|
||||
],
|
||||
"metadata": {},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 5
|
||||
}"##;
|
||||
|
||||
fs::write(¬ebook_path, notebook_json).unwrap();
|
||||
|
||||
let mut notebook = read_notebook(¬ebook_path).unwrap();
|
||||
|
||||
// Edit the first cell
|
||||
edit_notebook(
|
||||
&mut notebook,
|
||||
NotebookEdit::EditCell {
|
||||
index: 0,
|
||||
source: vec!["x = 10\n".to_string(), "print(x)\n".to_string()],
|
||||
},
|
||||
)
|
||||
.unwrap();
|
||||
|
||||
// Verify the edit
|
||||
assert_eq!(notebook.cells[0].source.len(), 2);
|
||||
assert_eq!(notebook.cells[0].source[0], "x = 10\n");
|
||||
assert_eq!(notebook.cells[0].source[1], "print(x)\n");
|
||||
|
||||
// Second cell should be unchanged
|
||||
assert_eq!(notebook.cells[1].source[0], "y = 2");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn notebook_add_delete_cells() {
|
||||
let dir = tempdir().unwrap();
|
||||
let notebook_path = dir.path().join("test.ipynb");
|
||||
|
||||
let notebook_json = r##"{
|
||||
"cells": [
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"source": ["x = 1"],
|
||||
"outputs": []
|
||||
}
|
||||
],
|
||||
"metadata": {},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 5
|
||||
}"##;
|
||||
|
||||
fs::write(¬ebook_path, notebook_json).unwrap();
|
||||
|
||||
let mut notebook = read_notebook(¬ebook_path).unwrap();
|
||||
assert_eq!(notebook.cells.len(), 1);
|
||||
|
||||
// Add a cell at the end
|
||||
let new_cell = new_code_cell(vec!["y = 2\n".to_string()]);
|
||||
edit_notebook(
|
||||
&mut notebook,
|
||||
NotebookEdit::AddCell {
|
||||
index: 1,
|
||||
cell: new_cell,
|
||||
},
|
||||
)
|
||||
.unwrap();
|
||||
|
||||
assert_eq!(notebook.cells.len(), 2);
|
||||
assert_eq!(notebook.cells[1].source[0], "y = 2\n");
|
||||
|
||||
// Add a cell at the beginning
|
||||
let first_cell = new_markdown_cell(vec!["# Header\n".to_string()]);
|
||||
edit_notebook(
|
||||
&mut notebook,
|
||||
NotebookEdit::AddCell {
|
||||
index: 0,
|
||||
cell: first_cell,
|
||||
},
|
||||
)
|
||||
.unwrap();
|
||||
|
||||
assert_eq!(notebook.cells.len(), 3);
|
||||
assert_eq!(notebook.cells[0].cell_type, "markdown");
|
||||
assert_eq!(notebook.cells[0].source[0], "# Header\n");
|
||||
assert_eq!(notebook.cells[1].source[0], "x = 1"); // Original first cell is now second
|
||||
|
||||
// Delete the middle cell
|
||||
edit_notebook(&mut notebook, NotebookEdit::DeleteCell { index: 1 }).unwrap();
|
||||
|
||||
assert_eq!(notebook.cells.len(), 2);
|
||||
assert_eq!(notebook.cells[0].cell_type, "markdown");
|
||||
assert_eq!(notebook.cells[1].source[0], "y = 2\n");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn notebook_edit_out_of_bounds() {
|
||||
let dir = tempdir().unwrap();
|
||||
let notebook_path = dir.path().join("test.ipynb");
|
||||
|
||||
let notebook_json = r##"{
|
||||
"cells": [
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"source": ["x = 1\n"],
|
||||
"outputs": []
|
||||
}
|
||||
],
|
||||
"metadata": {},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 5
|
||||
}"##;
|
||||
|
||||
fs::write(¬ebook_path, notebook_json).unwrap();
|
||||
|
||||
let mut notebook = read_notebook(¬ebook_path).unwrap();
|
||||
|
||||
// Try to edit non-existent cell
|
||||
let result = edit_notebook(
|
||||
&mut notebook,
|
||||
NotebookEdit::EditCell {
|
||||
index: 5,
|
||||
source: vec!["bad\n".to_string()],
|
||||
},
|
||||
);
|
||||
|
||||
assert!(result.is_err());
|
||||
assert!(result.unwrap_err().to_string().contains("out of bounds"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn notebook_with_outputs_preserved() {
|
||||
let dir = tempdir().unwrap();
|
||||
let notebook_path = dir.path().join("test.ipynb");
|
||||
|
||||
let notebook_json = r##"{
|
||||
"cells": [
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 1,
|
||||
"metadata": {},
|
||||
"source": ["print('hello')\n"],
|
||||
"outputs": [
|
||||
{
|
||||
"output_type": "stream",
|
||||
"name": "stdout",
|
||||
"text": ["hello\n"]
|
||||
}
|
||||
]
|
||||
}
|
||||
],
|
||||
"metadata": {},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 5
|
||||
}"##;
|
||||
|
||||
fs::write(¬ebook_path, notebook_json).unwrap();
|
||||
|
||||
let notebook = read_notebook(¬ebook_path).unwrap();
|
||||
|
||||
assert_eq!(notebook.cells[0].outputs.len(), 1);
|
||||
assert_eq!(notebook.cells[0].outputs[0].output_type, "stream");
|
||||
|
||||
// Write and read back
|
||||
let output_path = dir.path().join("output.ipynb");
|
||||
write_notebook(&output_path, ¬ebook).unwrap();
|
||||
|
||||
let notebook2 = read_notebook(&output_path).unwrap();
|
||||
assert_eq!(notebook2.cells[0].outputs.len(), 1);
|
||||
assert_eq!(notebook2.cells[0].outputs[0].output_type, "stream");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn cell_source_as_string_concatenates() {
|
||||
let cell = new_code_cell(vec![
|
||||
"import numpy as np\n".to_string(),
|
||||
"arr = np.array([1, 2, 3])\n".to_string(),
|
||||
]);
|
||||
|
||||
let source = cell_source_as_string(&cell);
|
||||
assert_eq!(source, "import numpy as np\narr = np.array([1, 2, 3])\n");
|
||||
}
|
||||
18
crates/tools/plan/Cargo.toml
Normal file
18
crates/tools/plan/Cargo.toml
Normal file
@@ -0,0 +1,18 @@
|
||||
[package]
|
||||
name = "tools-plan"
|
||||
version = "0.1.0"
|
||||
edition = "2024"
|
||||
license = "AGPL-3.0"
|
||||
description = "Planning mode tools for the Owlen agent"
|
||||
|
||||
[dependencies]
|
||||
color-eyre = "0.6"
|
||||
serde = { version = "1.0", features = ["derive"] }
|
||||
serde_json = "1.0"
|
||||
chrono = { version = "0.4", features = ["serde"] }
|
||||
uuid = { version = "1.0", features = ["v4"] }
|
||||
tokio = { version = "1", features = ["fs"] }
|
||||
|
||||
[dev-dependencies]
|
||||
tempfile = "3.13"
|
||||
tokio = { version = "1", features = ["rt", "macros"] }
|
||||
296
crates/tools/plan/src/lib.rs
Normal file
296
crates/tools/plan/src/lib.rs
Normal file
@@ -0,0 +1,296 @@
|
||||
//! Planning mode tools for the Owlen agent
|
||||
//!
|
||||
//! Provides EnterPlanMode and ExitPlanMode tools that allow the agent
|
||||
//! to enter a planning phase where only read-only operations are allowed,
|
||||
//! and then present a plan for user approval.
|
||||
|
||||
use color_eyre::eyre::Result;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::path::PathBuf;
|
||||
use chrono::{DateTime, Utc};
|
||||
use uuid::Uuid;
|
||||
|
||||
/// Agent mode - normal execution or planning
|
||||
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
|
||||
pub enum AgentMode {
|
||||
/// Normal mode - all tools available per permission settings
|
||||
Normal,
|
||||
/// Planning mode - only read-only tools allowed
|
||||
Planning {
|
||||
/// Path to the plan file being written
|
||||
plan_file: PathBuf,
|
||||
/// When planning mode was entered
|
||||
started_at: DateTime<Utc>,
|
||||
},
|
||||
}
|
||||
|
||||
impl Default for AgentMode {
|
||||
fn default() -> Self {
|
||||
Self::Normal
|
||||
}
|
||||
}
|
||||
|
||||
impl AgentMode {
|
||||
/// Check if we're in planning mode
|
||||
pub fn is_planning(&self) -> bool {
|
||||
matches!(self, AgentMode::Planning { .. })
|
||||
}
|
||||
|
||||
/// Get the plan file path if in planning mode
|
||||
pub fn plan_file(&self) -> Option<&PathBuf> {
|
||||
match self {
|
||||
AgentMode::Planning { plan_file, .. } => Some(plan_file),
|
||||
AgentMode::Normal => None,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Plan file metadata
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct PlanMetadata {
|
||||
pub id: String,
|
||||
pub created_at: DateTime<Utc>,
|
||||
pub status: PlanStatus,
|
||||
pub title: Option<String>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
|
||||
pub enum PlanStatus {
|
||||
/// Plan is being written
|
||||
Draft,
|
||||
/// Plan is awaiting user approval
|
||||
PendingApproval,
|
||||
/// Plan was approved by user
|
||||
Approved,
|
||||
/// Plan was rejected by user
|
||||
Rejected,
|
||||
}
|
||||
|
||||
/// Manager for plan files
|
||||
pub struct PlanManager {
|
||||
plans_dir: PathBuf,
|
||||
}
|
||||
|
||||
impl PlanManager {
|
||||
/// Create a new plan manager
|
||||
pub fn new(project_root: PathBuf) -> Self {
|
||||
let plans_dir = project_root.join(".owlen").join("plans");
|
||||
Self { plans_dir }
|
||||
}
|
||||
|
||||
/// Create a new plan manager with custom directory
|
||||
pub fn with_dir(plans_dir: PathBuf) -> Self {
|
||||
Self { plans_dir }
|
||||
}
|
||||
|
||||
/// Get the plans directory
|
||||
pub fn plans_dir(&self) -> &PathBuf {
|
||||
&self.plans_dir
|
||||
}
|
||||
|
||||
/// Ensure the plans directory exists
|
||||
pub async fn ensure_dir(&self) -> Result<()> {
|
||||
tokio::fs::create_dir_all(&self.plans_dir).await?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Generate a unique plan file name
|
||||
/// Uses a format like: <adjective>-<verb>-<noun>.md
|
||||
pub fn generate_plan_name(&self) -> String {
|
||||
// Simple word lists for readable names
|
||||
let adjectives = ["cozy", "swift", "clever", "bright", "calm", "eager", "gentle", "happy"];
|
||||
let verbs = ["dancing", "jumping", "running", "flying", "singing", "coding", "building", "thinking"];
|
||||
let nouns = ["owl", "fox", "bear", "wolf", "hawk", "deer", "lion", "tiger"];
|
||||
|
||||
use std::collections::hash_map::DefaultHasher;
|
||||
use std::hash::{Hash, Hasher};
|
||||
|
||||
let uuid = Uuid::new_v4();
|
||||
let mut hasher = DefaultHasher::new();
|
||||
uuid.hash(&mut hasher);
|
||||
let hash = hasher.finish();
|
||||
|
||||
let adj = adjectives[(hash % adjectives.len() as u64) as usize];
|
||||
let verb = verbs[((hash >> 8) % verbs.len() as u64) as usize];
|
||||
let noun = nouns[((hash >> 16) % nouns.len() as u64) as usize];
|
||||
|
||||
format!("{}-{}-{}.md", adj, verb, noun)
|
||||
}
|
||||
|
||||
/// Create a new plan file and return the path
|
||||
pub async fn create_plan(&self) -> Result<PathBuf> {
|
||||
self.ensure_dir().await?;
|
||||
|
||||
let filename = self.generate_plan_name();
|
||||
let plan_path = self.plans_dir.join(&filename);
|
||||
|
||||
// Create initial plan file with metadata
|
||||
let metadata = PlanMetadata {
|
||||
id: Uuid::new_v4().to_string(),
|
||||
created_at: Utc::now(),
|
||||
status: PlanStatus::Draft,
|
||||
title: None,
|
||||
};
|
||||
|
||||
let initial_content = format!(
|
||||
"<!-- plan-id: {} -->\n<!-- status: draft -->\n\n# Implementation Plan\n\n",
|
||||
metadata.id
|
||||
);
|
||||
|
||||
tokio::fs::write(&plan_path, initial_content).await?;
|
||||
|
||||
Ok(plan_path)
|
||||
}
|
||||
|
||||
/// Write content to a plan file
|
||||
pub async fn write_plan(&self, path: &PathBuf, content: &str) -> Result<()> {
|
||||
// Preserve the metadata header if it exists
|
||||
let existing = tokio::fs::read_to_string(path).await.unwrap_or_default();
|
||||
|
||||
// Extract metadata lines (lines starting with <!--)
|
||||
let metadata_lines: Vec<&str> = existing
|
||||
.lines()
|
||||
.take_while(|line| line.starts_with("<!--"))
|
||||
.collect();
|
||||
|
||||
// Update status to pending approval
|
||||
let mut new_content = String::new();
|
||||
for line in &metadata_lines {
|
||||
if line.contains("status:") {
|
||||
new_content.push_str("<!-- status: pending_approval -->\n");
|
||||
} else {
|
||||
new_content.push_str(line);
|
||||
new_content.push('\n');
|
||||
}
|
||||
}
|
||||
new_content.push('\n');
|
||||
new_content.push_str(content);
|
||||
|
||||
tokio::fs::write(path, new_content).await?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Read a plan file
|
||||
pub async fn read_plan(&self, path: &PathBuf) -> Result<String> {
|
||||
let content = tokio::fs::read_to_string(path).await?;
|
||||
Ok(content)
|
||||
}
|
||||
|
||||
/// Update plan status
|
||||
pub async fn set_status(&self, path: &PathBuf, status: PlanStatus) -> Result<()> {
|
||||
let content = tokio::fs::read_to_string(path).await?;
|
||||
|
||||
let status_str = match status {
|
||||
PlanStatus::Draft => "draft",
|
||||
PlanStatus::PendingApproval => "pending_approval",
|
||||
PlanStatus::Approved => "approved",
|
||||
PlanStatus::Rejected => "rejected",
|
||||
};
|
||||
|
||||
// Replace status line
|
||||
let updated: String = content
|
||||
.lines()
|
||||
.map(|line| {
|
||||
if line.contains("<!-- status:") {
|
||||
format!("<!-- status: {} -->", status_str)
|
||||
} else {
|
||||
line.to_string()
|
||||
}
|
||||
})
|
||||
.collect::<Vec<_>>()
|
||||
.join("\n");
|
||||
|
||||
tokio::fs::write(path, updated).await?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// List all plan files
|
||||
pub async fn list_plans(&self) -> Result<Vec<PathBuf>> {
|
||||
let mut plans = Vec::new();
|
||||
|
||||
if !self.plans_dir.exists() {
|
||||
return Ok(plans);
|
||||
}
|
||||
|
||||
let mut entries = tokio::fs::read_dir(&self.plans_dir).await?;
|
||||
while let Some(entry) = entries.next_entry().await? {
|
||||
let path = entry.path();
|
||||
if path.extension().map_or(false, |ext| ext == "md") {
|
||||
plans.push(path);
|
||||
}
|
||||
}
|
||||
|
||||
plans.sort();
|
||||
Ok(plans)
|
||||
}
|
||||
}
|
||||
|
||||
/// Enter planning mode
|
||||
pub fn enter_plan_mode(plan_file: PathBuf) -> AgentMode {
|
||||
AgentMode::Planning {
|
||||
plan_file,
|
||||
started_at: Utc::now(),
|
||||
}
|
||||
}
|
||||
|
||||
/// Exit planning mode and return to normal
|
||||
pub fn exit_plan_mode() -> AgentMode {
|
||||
AgentMode::Normal
|
||||
}
|
||||
|
||||
/// Check if a tool is allowed in planning mode
|
||||
/// Only read-only tools are allowed
|
||||
pub fn is_tool_allowed_in_plan_mode(tool_name: &str) -> bool {
|
||||
matches!(
|
||||
tool_name,
|
||||
"read" | "glob" | "grep" | "ls" | "web_fetch" | "web_search" |
|
||||
"todo_write" | "ask_user" | "exit_plan_mode"
|
||||
)
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use tempfile::TempDir;
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_create_plan() {
|
||||
let temp_dir = TempDir::new().unwrap();
|
||||
let manager = PlanManager::new(temp_dir.path().to_path_buf());
|
||||
|
||||
let plan_path = manager.create_plan().await.unwrap();
|
||||
assert!(plan_path.exists());
|
||||
assert!(plan_path.extension().map_or(false, |ext| ext == "md"));
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_write_and_read_plan() {
|
||||
let temp_dir = TempDir::new().unwrap();
|
||||
let manager = PlanManager::new(temp_dir.path().to_path_buf());
|
||||
|
||||
let plan_path = manager.create_plan().await.unwrap();
|
||||
|
||||
manager.write_plan(&plan_path, "# My Plan\n\nStep 1: Do something").await.unwrap();
|
||||
|
||||
let content = manager.read_plan(&plan_path).await.unwrap();
|
||||
assert!(content.contains("My Plan"));
|
||||
assert!(content.contains("pending_approval"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_plan_mode_check() {
|
||||
assert!(is_tool_allowed_in_plan_mode("read"));
|
||||
assert!(is_tool_allowed_in_plan_mode("glob"));
|
||||
assert!(is_tool_allowed_in_plan_mode("grep"));
|
||||
assert!(!is_tool_allowed_in_plan_mode("write"));
|
||||
assert!(!is_tool_allowed_in_plan_mode("bash"));
|
||||
assert!(!is_tool_allowed_in_plan_mode("edit"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_agent_mode_default() {
|
||||
let mode = AgentMode::default();
|
||||
assert!(!mode.is_planning());
|
||||
assert!(mode.plan_file().is_none());
|
||||
}
|
||||
}
|
||||
16
crates/tools/skill/Cargo.toml
Normal file
16
crates/tools/skill/Cargo.toml
Normal file
@@ -0,0 +1,16 @@
|
||||
[package]
|
||||
name = "tools-skill"
|
||||
version = "0.1.0"
|
||||
edition.workspace = true
|
||||
license.workspace = true
|
||||
rust-version.workspace = true
|
||||
description = "Skill invocation tool for the Owlen agent"
|
||||
|
||||
[dependencies]
|
||||
color-eyre = "0.6"
|
||||
serde = { version = "1.0", features = ["derive"] }
|
||||
serde_json = "1.0"
|
||||
plugins = { path = "../../platform/plugins" }
|
||||
|
||||
[dev-dependencies]
|
||||
tempfile = "3.13"
|
||||
275
crates/tools/skill/src/lib.rs
Normal file
275
crates/tools/skill/src/lib.rs
Normal file
@@ -0,0 +1,275 @@
|
||||
//! Skill invocation tool for the Owlen agent
|
||||
//!
|
||||
//! Provides the Skill tool that allows the agent to invoke skills
|
||||
//! from plugins programmatically during a conversation.
|
||||
|
||||
use color_eyre::eyre::{Result, eyre};
|
||||
use plugins::PluginManager;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::fs;
|
||||
use std::path::PathBuf;
|
||||
|
||||
/// Parameters for the Skill tool
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct SkillParams {
|
||||
/// Name of the skill to invoke (e.g., "pdf", "xlsx", or "plugin:skill")
|
||||
pub skill: String,
|
||||
}
|
||||
|
||||
/// Result of skill invocation
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct SkillResult {
|
||||
/// The skill name that was invoked
|
||||
pub skill_name: String,
|
||||
/// The skill content (instructions)
|
||||
pub content: String,
|
||||
/// Source of the skill (plugin name)
|
||||
pub source: String,
|
||||
}
|
||||
|
||||
/// Skill registry for looking up and invoking skills
|
||||
pub struct SkillRegistry {
|
||||
/// Local skills directory (e.g., .owlen/skills/)
|
||||
local_skills_dir: PathBuf,
|
||||
/// Plugin manager for finding plugin skills
|
||||
plugin_manager: Option<PluginManager>,
|
||||
}
|
||||
|
||||
impl SkillRegistry {
|
||||
/// Create a new skill registry
|
||||
pub fn new() -> Self {
|
||||
Self {
|
||||
local_skills_dir: PathBuf::from(".owlen/skills"),
|
||||
plugin_manager: None,
|
||||
}
|
||||
}
|
||||
|
||||
/// Create with custom local skills directory
|
||||
pub fn with_local_dir(mut self, dir: PathBuf) -> Self {
|
||||
self.local_skills_dir = dir;
|
||||
self
|
||||
}
|
||||
|
||||
/// Set the plugin manager for discovering plugin skills
|
||||
pub fn with_plugin_manager(mut self, pm: PluginManager) -> Self {
|
||||
self.plugin_manager = Some(pm);
|
||||
self
|
||||
}
|
||||
|
||||
/// Find and load a skill by name
|
||||
///
|
||||
/// Skill names can be:
|
||||
/// - Simple name: "pdf" (searches local, then plugins)
|
||||
/// - Fully qualified: "plugin-name:skill-name"
|
||||
pub fn invoke(&self, skill_name: &str) -> Result<SkillResult> {
|
||||
// Check for fully qualified name (plugin:skill)
|
||||
if let Some((plugin_name, skill_id)) = skill_name.split_once(':') {
|
||||
return self.load_plugin_skill(plugin_name, skill_id);
|
||||
}
|
||||
|
||||
// Try local skills first
|
||||
if let Ok(result) = self.load_local_skill(skill_name) {
|
||||
return Ok(result);
|
||||
}
|
||||
|
||||
// Try plugins
|
||||
if let Some(pm) = &self.plugin_manager {
|
||||
for plugin in pm.plugins() {
|
||||
if let Ok(result) = self.load_skill_from_plugin(plugin, skill_name) {
|
||||
return Ok(result);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Err(eyre!(
|
||||
"Skill '{}' not found.\n\nAvailable skills:\n{}",
|
||||
skill_name,
|
||||
self.list_available_skills().join("\n")
|
||||
))
|
||||
}
|
||||
|
||||
/// Load a local skill from .owlen/skills/
|
||||
fn load_local_skill(&self, skill_name: &str) -> Result<SkillResult> {
|
||||
// Try with and without .md extension
|
||||
let skill_file = self.local_skills_dir.join(format!("{}.md", skill_name));
|
||||
let skill_dir = self.local_skills_dir.join(skill_name).join("SKILL.md");
|
||||
|
||||
let content = if skill_file.exists() {
|
||||
fs::read_to_string(&skill_file)?
|
||||
} else if skill_dir.exists() {
|
||||
fs::read_to_string(&skill_dir)?
|
||||
} else {
|
||||
return Err(eyre!("Local skill '{}' not found", skill_name));
|
||||
};
|
||||
|
||||
Ok(SkillResult {
|
||||
skill_name: skill_name.to_string(),
|
||||
content: parse_skill_content(&content),
|
||||
source: "local".to_string(),
|
||||
})
|
||||
}
|
||||
|
||||
/// Load a skill from a specific plugin
|
||||
fn load_plugin_skill(&self, plugin_name: &str, skill_name: &str) -> Result<SkillResult> {
|
||||
let pm = self.plugin_manager.as_ref()
|
||||
.ok_or_else(|| eyre!("Plugin manager not available"))?;
|
||||
|
||||
for plugin in pm.plugins() {
|
||||
if plugin.manifest.name == plugin_name {
|
||||
return self.load_skill_from_plugin(plugin, skill_name);
|
||||
}
|
||||
}
|
||||
|
||||
Err(eyre!("Plugin '{}' not found", plugin_name))
|
||||
}
|
||||
|
||||
/// Load a skill from a plugin
|
||||
fn load_skill_from_plugin(&self, plugin: &plugins::Plugin, skill_name: &str) -> Result<SkillResult> {
|
||||
let skill_names = plugin.all_skill_names();
|
||||
|
||||
if !skill_names.contains(&skill_name.to_string()) {
|
||||
return Err(eyre!("Skill '{}' not found in plugin '{}'", skill_name, plugin.manifest.name));
|
||||
}
|
||||
|
||||
// Skills are in skills/<name>/SKILL.md
|
||||
let skill_path = plugin.base_path.join("skills").join(skill_name).join("SKILL.md");
|
||||
|
||||
if !skill_path.exists() {
|
||||
return Err(eyre!("Skill file not found: {:?}", skill_path));
|
||||
}
|
||||
|
||||
let content = fs::read_to_string(&skill_path)?;
|
||||
|
||||
Ok(SkillResult {
|
||||
skill_name: skill_name.to_string(),
|
||||
content: parse_skill_content(&content),
|
||||
source: format!("plugin:{}", plugin.manifest.name),
|
||||
})
|
||||
}
|
||||
|
||||
/// List all available skills
|
||||
pub fn list_available_skills(&self) -> Vec<String> {
|
||||
let mut skills = Vec::new();
|
||||
|
||||
// Local skills
|
||||
if self.local_skills_dir.exists() {
|
||||
if let Ok(entries) = fs::read_dir(&self.local_skills_dir) {
|
||||
for entry in entries.filter_map(|e| e.ok()) {
|
||||
let path = entry.path();
|
||||
if path.is_file() && path.extension().map_or(false, |e| e == "md") {
|
||||
if let Some(name) = path.file_stem().and_then(|s| s.to_str()) {
|
||||
skills.push(format!(" - {} (local)", name));
|
||||
}
|
||||
} else if path.is_dir() && path.join("SKILL.md").exists() {
|
||||
if let Some(name) = path.file_name().and_then(|s| s.to_str()) {
|
||||
skills.push(format!(" - {} (local)", name));
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Plugin skills
|
||||
if let Some(pm) = &self.plugin_manager {
|
||||
for plugin in pm.plugins() {
|
||||
for skill_name in plugin.all_skill_names() {
|
||||
skills.push(format!(" - {} (plugin:{})", skill_name, plugin.manifest.name));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if skills.is_empty() {
|
||||
skills.push(" (no skills available)".to_string());
|
||||
}
|
||||
|
||||
skills
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for SkillRegistry {
|
||||
fn default() -> Self {
|
||||
Self::new()
|
||||
}
|
||||
}
|
||||
|
||||
/// Parse skill content, extracting the body (stripping YAML frontmatter)
|
||||
fn parse_skill_content(content: &str) -> String {
|
||||
// Check for YAML frontmatter
|
||||
if content.starts_with("---") {
|
||||
// Find the end of frontmatter
|
||||
if let Some(end_idx) = content[3..].find("---") {
|
||||
let body_start = end_idx + 6; // Skip past the closing ---
|
||||
if body_start < content.len() {
|
||||
return content[body_start..].trim().to_string();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
content.trim().to_string()
|
||||
}
|
||||
|
||||
/// Execute the Skill tool
|
||||
pub fn execute_skill(params: &SkillParams, registry: &SkillRegistry) -> Result<String> {
|
||||
let result = registry.invoke(¶ms.skill)?;
|
||||
|
||||
// Format output for injection into conversation
|
||||
Ok(format!(
|
||||
"## Skill: {} ({})\n\n{}",
|
||||
result.skill_name,
|
||||
result.source,
|
||||
result.content
|
||||
))
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use std::fs;
|
||||
use tempfile::TempDir;
|
||||
|
||||
#[test]
|
||||
fn test_parse_skill_content_with_frontmatter() {
|
||||
let content = r#"---
|
||||
name: test-skill
|
||||
description: A test skill
|
||||
---
|
||||
|
||||
# Test Skill
|
||||
|
||||
This is the skill content."#;
|
||||
|
||||
let parsed = parse_skill_content(content);
|
||||
assert!(parsed.starts_with("# Test Skill"));
|
||||
assert!(!parsed.contains("name: test-skill"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_parse_skill_content_without_frontmatter() {
|
||||
let content = "# Just Content\n\nNo frontmatter here.";
|
||||
let parsed = parse_skill_content(content);
|
||||
assert_eq!(parsed, content.trim());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_skill_registry_local() {
|
||||
let temp_dir = TempDir::new().unwrap();
|
||||
let skills_dir = temp_dir.path().join(".owlen/skills");
|
||||
fs::create_dir_all(&skills_dir).unwrap();
|
||||
|
||||
// Create a test skill
|
||||
fs::write(skills_dir.join("test.md"), "# Test Skill\n\nTest content.").unwrap();
|
||||
|
||||
let registry = SkillRegistry::new().with_local_dir(skills_dir);
|
||||
let result = registry.invoke("test").unwrap();
|
||||
|
||||
assert_eq!(result.skill_name, "test");
|
||||
assert_eq!(result.source, "local");
|
||||
assert!(result.content.contains("Test Skill"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_skill_not_found() {
|
||||
let registry = SkillRegistry::new();
|
||||
assert!(registry.invoke("nonexistent").is_err());
|
||||
}
|
||||
}
|
||||
16
crates/tools/task/Cargo.toml
Normal file
16
crates/tools/task/Cargo.toml
Normal file
@@ -0,0 +1,16 @@
|
||||
[package]
|
||||
name = "tools-task"
|
||||
version = "0.1.0"
|
||||
edition.workspace = true
|
||||
license.workspace = true
|
||||
rust-version.workspace = true
|
||||
|
||||
[dependencies]
|
||||
serde = { version = "1", features = ["derive"] }
|
||||
serde_json = "1"
|
||||
color-eyre = "0.6"
|
||||
permissions = { path = "../../platform/permissions" }
|
||||
plugins = { path = "../../platform/plugins" }
|
||||
parking_lot = "0.12"
|
||||
|
||||
[dev-dependencies]
|
||||
335
crates/tools/task/src/lib.rs
Normal file
335
crates/tools/task/src/lib.rs
Normal file
@@ -0,0 +1,335 @@
|
||||
// Note: Result and eyre will be used by spawn_subagent when implemented
|
||||
#[allow(unused_imports)]
|
||||
use color_eyre::eyre::{Result, eyre};
|
||||
use parking_lot::RwLock;
|
||||
use permissions::Tool;
|
||||
use plugins::AgentDefinition;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::collections::HashMap;
|
||||
use std::path::PathBuf;
|
||||
use std::sync::Arc;
|
||||
|
||||
/// Configuration for spawning a subagent
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct SubagentConfig {
|
||||
/// Agent type/name (e.g., "code-reviewer", "explore")
|
||||
pub agent_type: String,
|
||||
|
||||
/// Task prompt for the agent
|
||||
pub prompt: String,
|
||||
|
||||
/// Optional model override
|
||||
pub model: Option<String>,
|
||||
|
||||
/// Tool whitelist (if None, uses agent's default)
|
||||
pub tools: Option<Vec<String>>,
|
||||
|
||||
/// Parsed agent definition (if from plugin)
|
||||
pub definition: Option<AgentDefinition>,
|
||||
}
|
||||
|
||||
impl SubagentConfig {
|
||||
/// Create a new subagent config with just type and prompt
|
||||
pub fn new(agent_type: String, prompt: String) -> Self {
|
||||
Self {
|
||||
agent_type,
|
||||
prompt,
|
||||
model: None,
|
||||
tools: None,
|
||||
definition: None,
|
||||
}
|
||||
}
|
||||
|
||||
/// Builder method to set model override
|
||||
pub fn with_model(mut self, model: String) -> Self {
|
||||
self.model = Some(model);
|
||||
self
|
||||
}
|
||||
|
||||
/// Builder method to set tool whitelist
|
||||
pub fn with_tools(mut self, tools: Vec<String>) -> Self {
|
||||
self.tools = Some(tools);
|
||||
self
|
||||
}
|
||||
|
||||
/// Builder method to set agent definition
|
||||
pub fn with_definition(mut self, definition: AgentDefinition) -> Self {
|
||||
self.definition = Some(definition);
|
||||
self
|
||||
}
|
||||
}
|
||||
|
||||
/// Registry of available subagents
|
||||
#[derive(Clone, Default)]
|
||||
pub struct SubagentRegistry {
|
||||
agents: Arc<RwLock<HashMap<String, AgentDefinition>>>,
|
||||
}
|
||||
|
||||
impl SubagentRegistry {
|
||||
/// Create a new empty registry
|
||||
pub fn new() -> Self {
|
||||
Self::default()
|
||||
}
|
||||
|
||||
/// Register agents from plugin manager
|
||||
pub fn register_from_plugins(&self, agents: Vec<AgentDefinition>) {
|
||||
let mut map = self.agents.write();
|
||||
for agent in agents {
|
||||
map.insert(agent.name.clone(), agent);
|
||||
}
|
||||
}
|
||||
|
||||
/// Register built-in agents
|
||||
pub fn register_builtin(&self) {
|
||||
let mut map = self.agents.write();
|
||||
|
||||
// Explore agent - for codebase exploration
|
||||
map.insert("explore".to_string(), AgentDefinition {
|
||||
name: "explore".to_string(),
|
||||
description: "Explores codebases to find files and understand structure".to_string(),
|
||||
tools: vec!["read".to_string(), "glob".to_string(), "grep".to_string(), "ls".to_string()],
|
||||
model: None,
|
||||
color: Some("blue".to_string()),
|
||||
system_prompt: "You are an exploration agent. Your purpose is to find relevant files and understand code structure. Use glob to find files by pattern, grep to search for content, ls to list directories, and read to examine files. Be thorough and systematic in your exploration.".to_string(),
|
||||
source_path: PathBuf::new(),
|
||||
});
|
||||
|
||||
// Plan agent - for designing implementations
|
||||
map.insert("plan".to_string(), AgentDefinition {
|
||||
name: "plan".to_string(),
|
||||
description: "Designs implementation plans and architectures".to_string(),
|
||||
tools: vec!["read".to_string(), "glob".to_string(), "grep".to_string()],
|
||||
model: None,
|
||||
color: Some("green".to_string()),
|
||||
system_prompt: "You are a planning agent. Your purpose is to design clear implementation strategies and architectures. Read existing code, understand patterns, and create detailed plans. Focus on the 'why' and 'how' rather than the 'what'.".to_string(),
|
||||
source_path: PathBuf::new(),
|
||||
});
|
||||
|
||||
// Code reviewer - read-only analysis
|
||||
map.insert("code-reviewer".to_string(), AgentDefinition {
|
||||
name: "code-reviewer".to_string(),
|
||||
description: "Reviews code for quality, bugs, and best practices".to_string(),
|
||||
tools: vec!["read".to_string(), "grep".to_string(), "glob".to_string()],
|
||||
model: None,
|
||||
color: Some("yellow".to_string()),
|
||||
system_prompt: "You are a code review agent. Analyze code for quality, potential bugs, performance issues, and adherence to best practices. Provide constructive feedback with specific examples.".to_string(),
|
||||
source_path: PathBuf::new(),
|
||||
});
|
||||
|
||||
// Test writer - can read and write test files
|
||||
map.insert("test-writer".to_string(), AgentDefinition {
|
||||
name: "test-writer".to_string(),
|
||||
description: "Writes and updates test files".to_string(),
|
||||
tools: vec!["read".to_string(), "write".to_string(), "edit".to_string(), "grep".to_string(), "glob".to_string()],
|
||||
model: None,
|
||||
color: Some("cyan".to_string()),
|
||||
system_prompt: "You are a test writing agent. Write comprehensive, well-structured tests that cover edge cases and ensure code correctness. Follow testing best practices and patterns used in the codebase.".to_string(),
|
||||
source_path: PathBuf::new(),
|
||||
});
|
||||
|
||||
// Documentation agent - can read code and write docs
|
||||
map.insert("doc-writer".to_string(), AgentDefinition {
|
||||
name: "doc-writer".to_string(),
|
||||
description: "Writes and maintains documentation".to_string(),
|
||||
tools: vec!["read".to_string(), "write".to_string(), "edit".to_string(), "grep".to_string(), "glob".to_string()],
|
||||
model: None,
|
||||
color: Some("magenta".to_string()),
|
||||
system_prompt: "You are a documentation agent. Write clear, comprehensive documentation that helps users understand the code. Include examples, explain concepts, and maintain consistency with existing documentation style.".to_string(),
|
||||
source_path: PathBuf::new(),
|
||||
});
|
||||
|
||||
// Refactoring agent - full file access but no bash
|
||||
map.insert("refactorer".to_string(), AgentDefinition {
|
||||
name: "refactorer".to_string(),
|
||||
description: "Refactors code while preserving functionality".to_string(),
|
||||
tools: vec!["read".to_string(), "write".to_string(), "edit".to_string(), "grep".to_string(), "glob".to_string()],
|
||||
model: None,
|
||||
color: Some("red".to_string()),
|
||||
system_prompt: "You are a refactoring agent. Improve code structure, readability, and maintainability while preserving functionality. Follow SOLID principles and language idioms. Make small, incremental changes.".to_string(),
|
||||
source_path: PathBuf::new(),
|
||||
});
|
||||
}
|
||||
|
||||
/// Get an agent by name
|
||||
pub fn get(&self, name: &str) -> Option<AgentDefinition> {
|
||||
self.agents.read().get(name).cloned()
|
||||
}
|
||||
|
||||
/// List all available agents with their descriptions
|
||||
pub fn list(&self) -> Vec<(String, String)> {
|
||||
self.agents.read()
|
||||
.iter()
|
||||
.map(|(name, def)| (name.clone(), def.description.clone()))
|
||||
.collect()
|
||||
}
|
||||
|
||||
/// Check if an agent exists
|
||||
pub fn contains(&self, name: &str) -> bool {
|
||||
self.agents.read().contains_key(name)
|
||||
}
|
||||
|
||||
/// Get all agent names
|
||||
pub fn agent_names(&self) -> Vec<String> {
|
||||
self.agents.read().keys().cloned().collect()
|
||||
}
|
||||
}
|
||||
|
||||
/// A specialized subagent with limited tool access (legacy API for backward compatibility)
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct Subagent {
|
||||
/// Unique identifier for the subagent
|
||||
pub name: String,
|
||||
/// Description of subagent's capabilities and purpose
|
||||
pub description: String,
|
||||
/// Keywords that trigger this subagent's selection
|
||||
pub keywords: Vec<String>,
|
||||
/// Tools this subagent is allowed to use
|
||||
pub allowed_tools: Vec<Tool>,
|
||||
}
|
||||
|
||||
impl Subagent {
|
||||
pub fn new(name: String, description: String, keywords: Vec<String>, allowed_tools: Vec<Tool>) -> Self {
|
||||
Self {
|
||||
name,
|
||||
description,
|
||||
keywords,
|
||||
allowed_tools,
|
||||
}
|
||||
}
|
||||
|
||||
/// Check if this subagent can use the specified tool
|
||||
pub fn can_use_tool(&self, tool: Tool) -> bool {
|
||||
self.allowed_tools.contains(&tool)
|
||||
}
|
||||
|
||||
/// Check if this subagent matches the task description
|
||||
pub fn matches_task(&self, task_description: &str) -> bool {
|
||||
let task_lower = task_description.to_lowercase();
|
||||
self.keywords.iter().any(|keyword| {
|
||||
task_lower.contains(&keyword.to_lowercase())
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
/// Task execution request
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct TaskRequest {
|
||||
/// Description of the task to execute
|
||||
pub description: String,
|
||||
/// Optional: specific subagent to use
|
||||
pub agent_name: Option<String>,
|
||||
}
|
||||
|
||||
/// Task execution result
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct TaskResult {
|
||||
/// The subagent that handled the task
|
||||
pub agent_name: String,
|
||||
/// Success or failure
|
||||
pub success: bool,
|
||||
/// Result message
|
||||
pub message: String,
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn test_subagent_registry_builtin() {
|
||||
let registry = SubagentRegistry::new();
|
||||
registry.register_builtin();
|
||||
|
||||
// Check that built-in agents are registered
|
||||
assert!(registry.contains("explore"));
|
||||
assert!(registry.contains("plan"));
|
||||
assert!(registry.contains("code-reviewer"));
|
||||
assert!(registry.contains("test-writer"));
|
||||
assert!(registry.contains("doc-writer"));
|
||||
assert!(registry.contains("refactorer"));
|
||||
|
||||
// Get an agent and verify its properties
|
||||
let explore = registry.get("explore").unwrap();
|
||||
assert_eq!(explore.name, "explore");
|
||||
assert!(explore.tools.contains(&"read".to_string()));
|
||||
assert!(explore.tools.contains(&"glob".to_string()));
|
||||
assert!(explore.tools.contains(&"grep".to_string()));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_subagent_registry_list() {
|
||||
let registry = SubagentRegistry::new();
|
||||
registry.register_builtin();
|
||||
|
||||
let agents = registry.list();
|
||||
assert!(agents.len() >= 6);
|
||||
|
||||
// Check that we have expected agents
|
||||
let names: Vec<String> = agents.iter().map(|(name, _)| name.clone()).collect();
|
||||
assert!(names.contains(&"explore".to_string()));
|
||||
assert!(names.contains(&"plan".to_string()));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_subagent_config_builder() {
|
||||
let config = SubagentConfig::new("explore".to_string(), "Find all Rust files".to_string())
|
||||
.with_model("claude-3-opus".to_string())
|
||||
.with_tools(vec!["read".to_string(), "glob".to_string()]);
|
||||
|
||||
assert_eq!(config.agent_type, "explore");
|
||||
assert_eq!(config.prompt, "Find all Rust files");
|
||||
assert_eq!(config.model, Some("claude-3-opus".to_string()));
|
||||
assert_eq!(config.tools, Some(vec!["read".to_string(), "glob".to_string()]));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_register_from_plugins() {
|
||||
let registry = SubagentRegistry::new();
|
||||
|
||||
let plugin_agent = AgentDefinition {
|
||||
name: "custom-agent".to_string(),
|
||||
description: "A custom agent from plugin".to_string(),
|
||||
tools: vec!["read".to_string()],
|
||||
model: Some("haiku".to_string()),
|
||||
color: Some("purple".to_string()),
|
||||
system_prompt: "Custom prompt".to_string(),
|
||||
source_path: PathBuf::from("/path/to/plugin"),
|
||||
};
|
||||
|
||||
registry.register_from_plugins(vec![plugin_agent]);
|
||||
|
||||
assert!(registry.contains("custom-agent"));
|
||||
let agent = registry.get("custom-agent").unwrap();
|
||||
assert_eq!(agent.model, Some("haiku".to_string()));
|
||||
}
|
||||
|
||||
// Legacy API tests for backward compatibility
|
||||
#[test]
|
||||
fn subagent_tool_whitelist() {
|
||||
let agent = Subagent::new(
|
||||
"reader".to_string(),
|
||||
"Read-only agent".to_string(),
|
||||
vec!["read".to_string()],
|
||||
vec![Tool::Read, Tool::Grep],
|
||||
);
|
||||
|
||||
assert!(agent.can_use_tool(Tool::Read));
|
||||
assert!(agent.can_use_tool(Tool::Grep));
|
||||
assert!(!agent.can_use_tool(Tool::Write));
|
||||
assert!(!agent.can_use_tool(Tool::Bash));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn subagent_keyword_matching() {
|
||||
let agent = Subagent::new(
|
||||
"tester".to_string(),
|
||||
"Test agent".to_string(),
|
||||
vec!["test".to_string(), "unit test".to_string()],
|
||||
vec![Tool::Read, Tool::Write],
|
||||
);
|
||||
|
||||
assert!(agent.matches_task("Write unit tests for the auth module"));
|
||||
assert!(agent.matches_task("Add test coverage"));
|
||||
assert!(!agent.matches_task("Refactor the database layer"));
|
||||
}
|
||||
}
|
||||
12
crates/tools/todo/Cargo.toml
Normal file
12
crates/tools/todo/Cargo.toml
Normal file
@@ -0,0 +1,12 @@
|
||||
[package]
|
||||
name = "tools-todo"
|
||||
version = "0.1.0"
|
||||
edition.workspace = true
|
||||
license.workspace = true
|
||||
rust-version.workspace = true
|
||||
|
||||
[dependencies]
|
||||
serde = { version = "1.0", features = ["derive"] }
|
||||
serde_json = "1.0"
|
||||
parking_lot = "0.12"
|
||||
color-eyre = "0.6"
|
||||
113
crates/tools/todo/src/lib.rs
Normal file
113
crates/tools/todo/src/lib.rs
Normal file
@@ -0,0 +1,113 @@
|
||||
//! TodoWrite tool for task list management
|
||||
|
||||
use parking_lot::RwLock;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::sync::Arc;
|
||||
|
||||
/// Status of a todo item
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
|
||||
#[serde(rename_all = "snake_case")]
|
||||
pub enum TodoStatus {
|
||||
Pending,
|
||||
InProgress,
|
||||
Completed,
|
||||
}
|
||||
|
||||
/// A todo item
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct Todo {
|
||||
pub content: String,
|
||||
pub status: TodoStatus,
|
||||
pub active_form: String, // Present continuous form for display
|
||||
}
|
||||
|
||||
/// Shared todo list state
|
||||
#[derive(Debug, Clone, Default)]
|
||||
pub struct TodoList {
|
||||
inner: Arc<RwLock<Vec<Todo>>>,
|
||||
}
|
||||
|
||||
impl TodoList {
|
||||
pub fn new() -> Self {
|
||||
Self::default()
|
||||
}
|
||||
|
||||
/// Replace all todos with new list
|
||||
pub fn write(&self, todos: Vec<Todo>) {
|
||||
*self.inner.write() = todos;
|
||||
}
|
||||
|
||||
/// Get current todos
|
||||
pub fn read(&self) -> Vec<Todo> {
|
||||
self.inner.read().clone()
|
||||
}
|
||||
|
||||
/// Get the current in-progress task (for status display)
|
||||
pub fn current_task(&self) -> Option<String> {
|
||||
self.inner.read()
|
||||
.iter()
|
||||
.find(|t| t.status == TodoStatus::InProgress)
|
||||
.map(|t| t.active_form.clone())
|
||||
}
|
||||
|
||||
/// Get summary stats
|
||||
pub fn stats(&self) -> (usize, usize, usize) {
|
||||
let todos = self.inner.read();
|
||||
let pending = todos.iter().filter(|t| t.status == TodoStatus::Pending).count();
|
||||
let in_progress = todos.iter().filter(|t| t.status == TodoStatus::InProgress).count();
|
||||
let completed = todos.iter().filter(|t| t.status == TodoStatus::Completed).count();
|
||||
(pending, in_progress, completed)
|
||||
}
|
||||
|
||||
/// Format todos for display
|
||||
pub fn format_display(&self) -> String {
|
||||
let todos = self.inner.read();
|
||||
if todos.is_empty() {
|
||||
return "No tasks".to_string();
|
||||
}
|
||||
|
||||
todos.iter().enumerate().map(|(i, t)| {
|
||||
let status_icon = match t.status {
|
||||
TodoStatus::Pending => "○",
|
||||
TodoStatus::InProgress => "◐",
|
||||
TodoStatus::Completed => "●",
|
||||
};
|
||||
format!("{}. {} {}", i + 1, status_icon, t.content)
|
||||
}).collect::<Vec<_>>().join("\n")
|
||||
}
|
||||
}
|
||||
|
||||
/// Parse todos from JSON tool input
|
||||
pub fn parse_todos(input: &serde_json::Value) -> color_eyre::Result<Vec<Todo>> {
|
||||
let todos = input.get("todos")
|
||||
.ok_or_else(|| color_eyre::eyre::eyre!("Missing 'todos' field"))?;
|
||||
|
||||
serde_json::from_value(todos.clone())
|
||||
.map_err(|e| color_eyre::eyre::eyre!("Invalid todos format: {}", e))
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn test_todo_list() {
|
||||
let list = TodoList::new();
|
||||
|
||||
list.write(vec![
|
||||
Todo {
|
||||
content: "First task".to_string(),
|
||||
status: TodoStatus::Completed,
|
||||
active_form: "Completing first task".to_string(),
|
||||
},
|
||||
Todo {
|
||||
content: "Second task".to_string(),
|
||||
status: TodoStatus::InProgress,
|
||||
active_form: "Working on second task".to_string(),
|
||||
},
|
||||
]);
|
||||
|
||||
assert_eq!(list.current_task(), Some("Working on second task".to_string()));
|
||||
assert_eq!(list.stats(), (0, 1, 1));
|
||||
}
|
||||
}
|
||||
21
crates/tools/web/Cargo.toml
Normal file
21
crates/tools/web/Cargo.toml
Normal file
@@ -0,0 +1,21 @@
|
||||
[package]
|
||||
name = "tools-web"
|
||||
version = "0.1.0"
|
||||
edition.workspace = true
|
||||
license.workspace = true
|
||||
rust-version.workspace = true
|
||||
|
||||
[dependencies]
|
||||
reqwest = { version = "0.12", features = ["json"] }
|
||||
tokio = { version = "1.39", features = ["macros"] }
|
||||
serde = { version = "1", features = ["derive"] }
|
||||
serde_json = "1"
|
||||
color-eyre = "0.6"
|
||||
url = "2.5"
|
||||
async-trait = "0.1"
|
||||
scraper = "0.18"
|
||||
urlencoding = "2.1"
|
||||
|
||||
[dev-dependencies]
|
||||
tokio = { version = "1.39", features = ["macros", "rt-multi-thread"] }
|
||||
wiremock = "0.6"
|
||||
325
crates/tools/web/src/lib.rs
Normal file
325
crates/tools/web/src/lib.rs
Normal file
@@ -0,0 +1,325 @@
|
||||
use color_eyre::eyre::{Result, eyre};
|
||||
use reqwest::redirect::Policy;
|
||||
use scraper::{Html, Selector};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::collections::HashSet;
|
||||
use url::Url;
|
||||
|
||||
/// WebFetch response
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct FetchResponse {
|
||||
pub url: String,
|
||||
pub status: u16,
|
||||
pub content: String,
|
||||
pub content_type: Option<String>,
|
||||
}
|
||||
|
||||
/// WebFetch client with domain filtering
|
||||
pub struct WebFetchClient {
|
||||
allowed_domains: HashSet<String>,
|
||||
blocked_domains: HashSet<String>,
|
||||
client: reqwest::Client,
|
||||
}
|
||||
|
||||
impl WebFetchClient {
|
||||
/// Create a new WebFetch client
|
||||
pub fn new() -> Self {
|
||||
let client = reqwest::Client::builder()
|
||||
.redirect(Policy::none()) // Don't follow redirects automatically
|
||||
.build()
|
||||
.unwrap();
|
||||
|
||||
Self {
|
||||
allowed_domains: HashSet::new(),
|
||||
blocked_domains: HashSet::new(),
|
||||
client,
|
||||
}
|
||||
}
|
||||
|
||||
/// Add an allowed domain
|
||||
pub fn allow_domain(&mut self, domain: &str) {
|
||||
self.allowed_domains.insert(domain.to_lowercase());
|
||||
}
|
||||
|
||||
/// Add a blocked domain
|
||||
pub fn block_domain(&mut self, domain: &str) {
|
||||
self.blocked_domains.insert(domain.to_lowercase());
|
||||
}
|
||||
|
||||
/// Check if a domain is allowed
|
||||
fn is_domain_allowed(&self, domain: &str) -> bool {
|
||||
let domain_lower = domain.to_lowercase();
|
||||
|
||||
// If explicitly blocked, deny
|
||||
if self.blocked_domains.contains(&domain_lower) {
|
||||
return false;
|
||||
}
|
||||
|
||||
// If allowlist is empty, allow all (except blocked)
|
||||
if self.allowed_domains.is_empty() {
|
||||
return true;
|
||||
}
|
||||
|
||||
// Otherwise, must be in allowlist
|
||||
self.allowed_domains.contains(&domain_lower)
|
||||
}
|
||||
|
||||
/// Fetch a URL
|
||||
pub async fn fetch(&self, url: &str) -> Result<FetchResponse> {
|
||||
let parsed_url = Url::parse(url)?;
|
||||
let domain = parsed_url
|
||||
.host_str()
|
||||
.ok_or_else(|| eyre!("No host in URL"))?;
|
||||
|
||||
// Check domain permission
|
||||
if !self.is_domain_allowed(domain) {
|
||||
return Err(eyre!("Domain not allowed: {}", domain));
|
||||
}
|
||||
|
||||
// Make the request
|
||||
let response = self.client.get(url).send().await?;
|
||||
|
||||
let status = response.status().as_u16();
|
||||
|
||||
// Handle redirects manually
|
||||
if status >= 300 && status < 400 {
|
||||
if let Some(location) = response.headers().get("location") {
|
||||
let location_str = location.to_str()?;
|
||||
|
||||
// Parse the redirect URL (may be relative)
|
||||
let redirect_url = if location_str.starts_with("http") {
|
||||
Url::parse(location_str)?
|
||||
} else {
|
||||
parsed_url.join(location_str)?
|
||||
};
|
||||
|
||||
let redirect_domain = redirect_url
|
||||
.host_str()
|
||||
.ok_or_else(|| eyre!("No host in redirect URL"))?;
|
||||
|
||||
// Check if redirect domain is allowed
|
||||
if !self.is_domain_allowed(redirect_domain) {
|
||||
return Err(eyre!(
|
||||
"Redirect to unapproved domain: {} -> {}",
|
||||
domain,
|
||||
redirect_domain
|
||||
));
|
||||
}
|
||||
|
||||
return Err(eyre!(
|
||||
"Redirect detected: {} -> {}. Use the redirect URL directly.",
|
||||
url,
|
||||
redirect_url
|
||||
));
|
||||
}
|
||||
}
|
||||
|
||||
let content_type = response
|
||||
.headers()
|
||||
.get("content-type")
|
||||
.and_then(|v| v.to_str().ok())
|
||||
.map(|s| s.to_string());
|
||||
|
||||
let content = response.text().await?;
|
||||
|
||||
Ok(FetchResponse {
|
||||
url: url.to_string(),
|
||||
status,
|
||||
content,
|
||||
content_type,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for WebFetchClient {
|
||||
fn default() -> Self {
|
||||
Self::new()
|
||||
}
|
||||
}
|
||||
|
||||
/// Search provider trait
|
||||
#[async_trait::async_trait]
|
||||
pub trait SearchProvider: Send + Sync {
|
||||
fn name(&self) -> &str;
|
||||
async fn search(&self, query: &str) -> Result<Vec<SearchResult>>;
|
||||
}
|
||||
|
||||
/// Search result
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct SearchResult {
|
||||
pub title: String,
|
||||
pub url: String,
|
||||
pub snippet: String,
|
||||
}
|
||||
|
||||
/// Stub search provider for testing
|
||||
pub struct StubSearchProvider {
|
||||
results: Vec<SearchResult>,
|
||||
}
|
||||
|
||||
impl StubSearchProvider {
|
||||
pub fn new(results: Vec<SearchResult>) -> Self {
|
||||
Self { results }
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait::async_trait]
|
||||
impl SearchProvider for StubSearchProvider {
|
||||
fn name(&self) -> &str {
|
||||
"stub"
|
||||
}
|
||||
|
||||
async fn search(&self, _query: &str) -> Result<Vec<SearchResult>> {
|
||||
Ok(self.results.clone())
|
||||
}
|
||||
}
|
||||
|
||||
/// DuckDuckGo HTML search provider
|
||||
pub struct DuckDuckGoSearchProvider {
|
||||
client: reqwest::Client,
|
||||
max_results: usize,
|
||||
}
|
||||
|
||||
impl DuckDuckGoSearchProvider {
|
||||
/// Create a new DuckDuckGo search provider with default max results (10)
|
||||
pub fn new() -> Self {
|
||||
Self::with_max_results(10)
|
||||
}
|
||||
|
||||
/// Create a new DuckDuckGo search provider with custom max results
|
||||
pub fn with_max_results(max_results: usize) -> Self {
|
||||
let client = reqwest::Client::builder()
|
||||
.user_agent("Mozilla/5.0 (compatible; Owlen/1.0)")
|
||||
.build()
|
||||
.unwrap();
|
||||
|
||||
Self { client, max_results }
|
||||
}
|
||||
|
||||
/// Parse DuckDuckGo HTML results
|
||||
fn parse_results(html: &str, max_results: usize) -> Result<Vec<SearchResult>> {
|
||||
let document = Html::parse_document(html);
|
||||
|
||||
// DuckDuckGo HTML selectors
|
||||
let result_selector = Selector::parse(".result").map_err(|e| eyre!("Invalid selector: {:?}", e))?;
|
||||
let title_selector = Selector::parse(".result__title a").map_err(|e| eyre!("Invalid selector: {:?}", e))?;
|
||||
let snippet_selector = Selector::parse(".result__snippet").map_err(|e| eyre!("Invalid selector: {:?}", e))?;
|
||||
|
||||
let mut results = Vec::new();
|
||||
|
||||
for result in document.select(&result_selector).take(max_results) {
|
||||
let title = result
|
||||
.select(&title_selector)
|
||||
.next()
|
||||
.map(|e| e.text().collect::<String>().trim().to_string())
|
||||
.unwrap_or_default();
|
||||
|
||||
let url = result
|
||||
.select(&title_selector)
|
||||
.next()
|
||||
.and_then(|e| e.value().attr("href"))
|
||||
.unwrap_or_default()
|
||||
.to_string();
|
||||
|
||||
let snippet = result
|
||||
.select(&snippet_selector)
|
||||
.next()
|
||||
.map(|e| e.text().collect::<String>().trim().to_string())
|
||||
.unwrap_or_default();
|
||||
|
||||
if !title.is_empty() && !url.is_empty() {
|
||||
results.push(SearchResult { title, url, snippet });
|
||||
}
|
||||
}
|
||||
|
||||
Ok(results)
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for DuckDuckGoSearchProvider {
|
||||
fn default() -> Self {
|
||||
Self::new()
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait::async_trait]
|
||||
impl SearchProvider for DuckDuckGoSearchProvider {
|
||||
fn name(&self) -> &str {
|
||||
"duckduckgo"
|
||||
}
|
||||
|
||||
async fn search(&self, query: &str) -> Result<Vec<SearchResult>> {
|
||||
let encoded_query = urlencoding::encode(query);
|
||||
let url = format!("https://html.duckduckgo.com/html/?q={}", encoded_query);
|
||||
|
||||
let response = self.client.get(&url).send().await?;
|
||||
let html = response.text().await?;
|
||||
|
||||
Self::parse_results(&html, self.max_results)
|
||||
}
|
||||
}
|
||||
|
||||
/// WebSearch client with pluggable providers
|
||||
pub struct WebSearchClient {
|
||||
provider: Box<dyn SearchProvider>,
|
||||
}
|
||||
|
||||
impl WebSearchClient {
|
||||
pub fn new(provider: Box<dyn SearchProvider>) -> Self {
|
||||
Self { provider }
|
||||
}
|
||||
|
||||
pub fn provider_name(&self) -> &str {
|
||||
self.provider.name()
|
||||
}
|
||||
|
||||
pub async fn search(&self, query: &str) -> Result<Vec<SearchResult>> {
|
||||
self.provider.search(query).await
|
||||
}
|
||||
}
|
||||
|
||||
/// Format search results for LLM consumption (markdown format)
|
||||
pub fn format_search_results(results: &[SearchResult]) -> String {
|
||||
if results.is_empty() {
|
||||
return "No results found.".to_string();
|
||||
}
|
||||
|
||||
results
|
||||
.iter()
|
||||
.enumerate()
|
||||
.map(|(i, r)| format!("{}. [{}]({})\n {}", i + 1, r.title, r.url, r.snippet))
|
||||
.collect::<Vec<_>>()
|
||||
.join("\n\n")
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn domain_filtering_allowlist() {
|
||||
let mut client = WebFetchClient::new();
|
||||
client.allow_domain("example.com");
|
||||
|
||||
assert!(client.is_domain_allowed("example.com"));
|
||||
assert!(!client.is_domain_allowed("evil.com"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn domain_filtering_blocklist() {
|
||||
let mut client = WebFetchClient::new();
|
||||
client.block_domain("evil.com");
|
||||
|
||||
assert!(client.is_domain_allowed("example.com")); // Empty allowlist = allow all
|
||||
assert!(!client.is_domain_allowed("evil.com"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn domain_filtering_case_insensitive() {
|
||||
let mut client = WebFetchClient::new();
|
||||
client.allow_domain("Example.COM");
|
||||
|
||||
assert!(client.is_domain_allowed("example.com"));
|
||||
assert!(client.is_domain_allowed("EXAMPLE.COM"));
|
||||
}
|
||||
}
|
||||
161
crates/tools/web/tests/web_tools.rs
Normal file
161
crates/tools/web/tests/web_tools.rs
Normal file
@@ -0,0 +1,161 @@
|
||||
use tools_web::{WebFetchClient, WebSearchClient, StubSearchProvider, SearchResult};
|
||||
use wiremock::{MockServer, Mock, ResponseTemplate};
|
||||
use wiremock::matchers::{method, path};
|
||||
|
||||
#[tokio::test]
|
||||
async fn webfetch_domain_whitelist_only() {
|
||||
let mock_server = MockServer::start().await;
|
||||
|
||||
Mock::given(method("GET"))
|
||||
.and(path("/test"))
|
||||
.respond_with(ResponseTemplate::new(200).set_body_string("Hello from allowed domain"))
|
||||
.mount(&mock_server)
|
||||
.await;
|
||||
|
||||
let mut client = WebFetchClient::new();
|
||||
client.allow_domain("localhost");
|
||||
client.allow_domain("127.0.0.1"); // Domain without port
|
||||
|
||||
// Fetch from allowed domain should work
|
||||
let url = format!("{}/test", mock_server.uri());
|
||||
let response = client.fetch(&url).await.unwrap();
|
||||
assert_eq!(response.status, 200);
|
||||
assert!(response.content.contains("Hello from allowed domain"));
|
||||
|
||||
// Create a client with different allowlist
|
||||
let mut strict_client = WebFetchClient::new();
|
||||
strict_client.allow_domain("example.com");
|
||||
|
||||
// Fetch from non-allowed domain should fail
|
||||
let result = strict_client.fetch(&url).await;
|
||||
assert!(result.is_err());
|
||||
assert!(result.unwrap_err().to_string().contains("Domain not allowed"));
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn webfetch_redirect_to_unapproved_domain() {
|
||||
let mock_server = MockServer::start().await;
|
||||
|
||||
// Mock a redirect to a different domain
|
||||
Mock::given(method("GET"))
|
||||
.and(path("/redirect"))
|
||||
.respond_with(
|
||||
ResponseTemplate::new(302)
|
||||
.insert_header("location", "https://evil.com/malware")
|
||||
)
|
||||
.mount(&mock_server)
|
||||
.await;
|
||||
|
||||
let mut client = WebFetchClient::new();
|
||||
client.allow_domain("localhost");
|
||||
client.allow_domain("127.0.0.1"); // Domain without port
|
||||
// evil.com is NOT in the allowlist
|
||||
|
||||
let url = format!("{}/redirect", mock_server.uri());
|
||||
let result = client.fetch(&url).await;
|
||||
|
||||
assert!(result.is_err());
|
||||
let err_msg = result.unwrap_err().to_string();
|
||||
assert!(err_msg.contains("Redirect to unapproved domain") || err_msg.contains("evil.com"));
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn webfetch_redirect_to_approved_domain() {
|
||||
let mock_server = MockServer::start().await;
|
||||
|
||||
let redirect_url = format!("{}/target", mock_server.uri());
|
||||
|
||||
// Mock a redirect to an approved domain
|
||||
Mock::given(method("GET"))
|
||||
.and(path("/redirect"))
|
||||
.respond_with(
|
||||
ResponseTemplate::new(302)
|
||||
.insert_header("location", &redirect_url)
|
||||
)
|
||||
.mount(&mock_server)
|
||||
.await;
|
||||
|
||||
let mut client = WebFetchClient::new();
|
||||
client.allow_domain("localhost");
|
||||
client.allow_domain("127.0.0.1"); // Domain without port
|
||||
|
||||
let url = format!("{}/redirect", mock_server.uri());
|
||||
let result = client.fetch(&url).await;
|
||||
|
||||
// Should fail but with a message about using the redirect URL
|
||||
assert!(result.is_err());
|
||||
let err_msg = result.unwrap_err().to_string();
|
||||
assert!(err_msg.contains("Redirect detected") || err_msg.contains("Use the redirect URL"));
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn webfetch_blocklist_overrides_allowlist() {
|
||||
let mock_server = MockServer::start().await;
|
||||
|
||||
Mock::given(method("GET"))
|
||||
.and(path("/test"))
|
||||
.respond_with(ResponseTemplate::new(200).set_body_string("Hello"))
|
||||
.mount(&mock_server)
|
||||
.await;
|
||||
|
||||
let domain = "127.0.0.1";
|
||||
let mut client = WebFetchClient::new();
|
||||
client.allow_domain(domain);
|
||||
client.block_domain(domain); // Block overrides allow
|
||||
|
||||
let url = format!("{}/test", mock_server.uri());
|
||||
let result = client.fetch(&url).await;
|
||||
|
||||
assert!(result.is_err());
|
||||
assert!(result.unwrap_err().to_string().contains("Domain not allowed"));
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn websearch_pluggable_provider() {
|
||||
let stub_results = vec![
|
||||
SearchResult {
|
||||
title: "Test Result 1".to_string(),
|
||||
url: "https://example.com/1".to_string(),
|
||||
snippet: "This is a test result".to_string(),
|
||||
},
|
||||
SearchResult {
|
||||
title: "Test Result 2".to_string(),
|
||||
url: "https://example.com/2".to_string(),
|
||||
snippet: "Another test result".to_string(),
|
||||
},
|
||||
];
|
||||
|
||||
let provider = StubSearchProvider::new(stub_results.clone());
|
||||
let client = WebSearchClient::new(Box::new(provider));
|
||||
|
||||
assert_eq!(client.provider_name(), "stub");
|
||||
|
||||
let results = client.search("test query").await.unwrap();
|
||||
assert_eq!(results.len(), 2);
|
||||
assert_eq!(results[0].title, "Test Result 1");
|
||||
assert_eq!(results[1].url, "https://example.com/2");
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn webfetch_successful_request() {
|
||||
let mock_server = MockServer::start().await;
|
||||
|
||||
Mock::given(method("GET"))
|
||||
.and(path("/api/data"))
|
||||
.respond_with(
|
||||
ResponseTemplate::new(200)
|
||||
.set_body_string(r#"{"status":"ok"}"#)
|
||||
.insert_header("content-type", "application/json")
|
||||
)
|
||||
.mount(&mock_server)
|
||||
.await;
|
||||
|
||||
let client = WebFetchClient::new(); // Empty allowlist = allow all
|
||||
|
||||
let url = format!("{}/api/data", mock_server.uri());
|
||||
let response = client.fetch(&url).await.unwrap();
|
||||
|
||||
assert_eq!(response.status, 200);
|
||||
assert!(response.content.contains("status"));
|
||||
assert!(response.content_type.is_some()); // Just verify content-type is present
|
||||
}
|
||||
Reference in New Issue
Block a user