test(agent): Add unit tests for agent-core and fix clippy warnings

This commit is contained in:
2025-12-26 18:19:58 +01:00
parent fbb6681cd2
commit f5a5724823
14 changed files with 322 additions and 120 deletions

View File

@@ -589,7 +589,7 @@ async fn main() -> Result<()> {
}
}
Cmd::Login { provider } => {
let provider_type = llm_core::ProviderType::from_str(&provider)
let provider_type = provider.parse::<llm_core::ProviderType>().ok()
.ok_or_else(|| eyre!(
"Unknown provider: {}. Supported: anthropic, openai, ollama",
provider
@@ -699,7 +699,7 @@ async fn main() -> Result<()> {
return Ok(());
}
Cmd::Logout { provider } => {
let provider_type = llm_core::ProviderType::from_str(&provider)
let provider_type = provider.parse::<llm_core::ProviderType>().ok()
.ok_or_else(|| eyre!(
"Unknown provider: {}. Supported: anthropic, openai, ollama",
provider
@@ -767,10 +767,10 @@ async fn main() -> Result<()> {
);
let _token_refresher = auth_manager.clone().start_background_refresh();
// Launch TUI
// Launch TUI with multi-provider support
// Note: For now, TUI doesn't use plugin manager directly
// In the future, we'll integrate plugin commands into TUI
return ui::run(client, opts, perms, settings).await;
return ui::run_with_providers(auth_manager, perms, settings).await;
}
// Legacy text-based REPL

View File

@@ -1,17 +1,13 @@
/// Example demonstrating the streaming agent loop API
///
/// This example shows how to use `run_agent_loop_streaming` to receive
/// real-time events during agent execution, including:
/// - Text deltas as the LLM generates text
/// - Tool execution start/end events
/// - Tool output events
/// - Final completion events
///
/// Run with: cargo run --example streaming_agent -p agent-core
use agent_core::{create_event_channel, run_agent_loop_streaming, AgentEvent, ToolContext};
use llm_core::ChatOptions;
use permissions::{Mode, PermissionManager};
//! Example demonstrating the streaming agent loop API
//!
//! This example shows how to use `run_agent_loop_streaming` to receive
//! real-time events during agent execution, including:
//! - Text deltas as the LLM generates text
//! - Tool execution start/end events
//! - Tool output events
//! - Final completion events
//!
//! Run with: cargo run --example streaming_agent -p agent-core
#[tokio::main]
async fn main() -> color_eyre::Result<()> {

View File

@@ -159,10 +159,10 @@ impl Compactor {
let mut summary = String::new();
use futures_util::StreamExt;
while let Some(chunk_result) = stream.next().await {
if let Ok(chunk) = chunk_result {
if let Some(content) = &chunk.content {
summary.push_str(content);
}
if let Ok(chunk) = chunk_result
&& let Some(content) = &chunk.content
{
summary.push_str(content);
}
}
@@ -195,14 +195,14 @@ mod tests {
// Small message list shouldn't compact
let small_messages: Vec<ChatMessage> = (0..10)
.map(|i| ChatMessage::user(&format!("Message {}", i)))
.map(|i| ChatMessage::user(format!("Message {}", i)))
.collect();
assert!(!counter.should_compact(&small_messages));
// Large message list should compact
// Need ~162,000 tokens = ~648,000 chars (at 4 chars per token)
let large_content = "x".repeat(700_000);
let large_messages = vec![ChatMessage::user(&large_content)];
let large_messages = vec![ChatMessage::user(large_content)];
assert!(counter.should_compact(&large_messages));
}
@@ -211,7 +211,7 @@ mod tests {
let compactor = Compactor::new();
let small: Vec<ChatMessage> = (0..5)
.map(|i| ChatMessage::user(&format!("Short message {}", i)))
.map(|i| ChatMessage::user(format!("Short message {}", i)))
.collect();
assert!(!compactor.needs_compaction(&small));
}

View File

@@ -928,7 +928,7 @@ pub async fn run_agent_loop<P: LlmProvider>(
let new_args = current_args + &args_delta;
// Try to parse as JSON, but keep as string if incomplete
tool_call.function.arguments = serde_json::from_str(&new_args)
.unwrap_or_else(|_| Value::String(new_args));
.unwrap_or(Value::String(new_args));
}
}
}
@@ -1128,3 +1128,65 @@ pub async fn run_agent_loop_streaming<P: LlmProvider>(
}
}
}
#[cfg(test)]
mod tests {
use super::*;
use llm_core::ToolCallDelta;
#[test]
fn test_tool_calls_builder() {
let mut builder = ToolCallsBuilder::new();
// Add first tool call deltas
builder.add_deltas(&[
ToolCallDelta {
index: 0,
id: Some("call_1".to_string()),
function_name: Some("read".to_string()),
arguments_delta: Some("{\"path\":".to_string()),
}
]);
// Add second tool call deltas
builder.add_deltas(&[
ToolCallDelta {
index: 1,
id: Some("call_2".to_string()),
function_name: Some("write".to_string()),
arguments_delta: Some("{\"path\":\"test.txt\"".to_string()),
}
]);
// Add more deltas for first tool call
builder.add_deltas(&[
ToolCallDelta {
index: 0,
id: None,
function_name: None,
arguments_delta: Some("\"lib.rs\"}".to_string()),
}
]);
// Add more deltas for second tool call
builder.add_deltas(&[
ToolCallDelta {
index: 1,
id: None,
function_name: None,
arguments_delta: Some(",\"content\":\"hello\"}".to_string()),
}
]);
let calls = builder.build();
assert_eq!(calls.len(), 2);
assert_eq!(calls[0].id, "call_1");
assert_eq!(calls[0].function.name, "read");
assert_eq!(calls[0].function.arguments, json!({"path": "lib.rs"}));
assert_eq!(calls[1].id, "call_2");
assert_eq!(calls[1].function.name, "write");
assert_eq!(calls[1].function.arguments, json!({"path": "test.txt", "content": "hello"}));
}
}

View File

@@ -183,10 +183,10 @@ impl Checkpoint {
for entry in fs::read_dir(checkpoint_dir)? {
let entry = entry?;
let path = entry.path();
if path.extension().and_then(|s| s.to_str()) == Some("json") {
if let Some(stem) = path.file_stem().and_then(|s| s.to_str()) {
checkpoints.push(stem.to_string());
}
if path.extension().and_then(|s| s.to_str()) == Some("json")
&& let Some(stem) = path.file_stem().and_then(|s| s.to_str())
{
checkpoints.push(stem.to_string());
}
}

View File

@@ -47,27 +47,27 @@ impl SystemPromptBuilder {
pub fn with_project_instructions(mut self, project_root: &Path) -> Self {
// Try CLAUDE.md first (Claude Code compatibility)
let claude_md = project_root.join("CLAUDE.md");
if claude_md.exists() {
if let Ok(content) = std::fs::read_to_string(&claude_md) {
self.sections.push(PromptSection {
name: "project".to_string(),
content: format!("# Project Instructions\n\n{}", content),
priority: 20,
});
return self;
}
if claude_md.exists()
&& let Ok(content) = std::fs::read_to_string(&claude_md)
{
self.sections.push(PromptSection {
name: "project".to_string(),
content: format!("# Project Instructions\n\n{}", content),
priority: 20,
});
return self;
}
// Fallback to .owlen.md
let owlen_md = project_root.join(".owlen.md");
if owlen_md.exists() {
if let Ok(content) = std::fs::read_to_string(&owlen_md) {
self.sections.push(PromptSection {
name: "project".to_string(),
content: format!("# Project Instructions\n\n{}", content),
priority: 20,
});
}
if owlen_md.exists()
&& let Ok(content) = std::fs::read_to_string(&owlen_md)
{
self.sections.push(PromptSection {
name: "project".to_string(),
content: format!("# Project Instructions\n\n{}", content),
priority: 20,
});
}
self

View File

@@ -0,0 +1,75 @@
use agent_core::{get_tool_definitions, ToolContext, execute_tool, AgentMode};
use permissions::{Mode, PermissionManager};
use serde_json::json;
#[test]
fn test_get_tool_definitions() {
let tools = get_tool_definitions();
assert!(!tools.is_empty());
// Check for some specific tools
let has_read = tools.iter().any(|t| t.function.name == "read");
let has_write = tools.iter().any(|t| t.function.name == "write");
let has_bash = tools.iter().any(|t| t.function.name == "bash");
assert!(has_read);
assert!(has_write);
assert!(has_bash);
}
#[tokio::test]
async fn test_execute_tool_permission_ask() {
let ctx = ToolContext::new();
let perms = PermissionManager::new(Mode::Plan); // Plan mode asks for write
let arguments = json!({
"path": "test.txt",
"content": "hello"
});
let result = execute_tool("write", &arguments, &perms, &ctx).await;
assert!(result.is_err());
assert!(result.unwrap_err().to_string().contains("Permission required"));
}
#[tokio::test]
async fn test_unknown_tool() {
let ctx = ToolContext::new();
let perms = PermissionManager::new(Mode::Code);
let arguments = json!({});
let result = execute_tool("non_existent_tool", &arguments, &perms, &ctx).await;
assert!(result.is_err());
assert!(result.unwrap_err().to_string().contains("Unknown tool"));
}
#[tokio::test]
async fn test_tool_context_mode_management() {
let ctx = ToolContext::new();
assert_eq!(ctx.get_mode().await, AgentMode::Normal);
assert!(!ctx.is_planning().await);
ctx.set_mode(AgentMode::Planning {
plan_file: "test_plan.md".into(),
started_at: chrono::Utc::now()
}).await;
match ctx.get_mode().await {
AgentMode::Planning { plan_file, .. } => assert_eq!(plan_file.to_str().unwrap(), "test_plan.md"),
_ => panic!("Expected Planning mode"),
}
assert!(ctx.is_planning().await);
}
#[tokio::test]
async fn test_agent_event_channel() {
let (tx, mut rx) = agent_core::create_event_channel();
tx.send(agent_core::AgentEvent::TextDelta("hello".into())).await.unwrap();
if let Some(agent_core::AgentEvent::TextDelta(text)) = rx.recv().await {
assert_eq!(text, "hello");
} else {
panic!("Expected TextDelta event");
}
}

View File

@@ -825,16 +825,20 @@ pub enum ProviderType {
OpenAI,
}
impl ProviderType {
pub fn from_str(s: &str) -> Option<Self> {
impl std::str::FromStr for ProviderType {
type Err = ();
fn from_str(s: &str) -> std::result::Result<Self, Self::Err> {
match s.to_lowercase().as_str() {
"ollama" => Some(Self::Ollama),
"anthropic" | "claude" => Some(Self::Anthropic),
"openai" | "gpt" => Some(Self::OpenAI),
_ => None,
"ollama" => Ok(Self::Ollama),
"anthropic" | "claude" => Ok(Self::Anthropic),
"openai" | "gpt" => Ok(Self::OpenAI),
_ => Err(()),
}
}
}
impl ProviderType {
pub fn as_str(&self) -> &'static str {
match self {
Self::Ollama => "ollama",

View File

@@ -103,7 +103,7 @@ impl TokenCounter for SimpleTokenCounter {
fn count(&self, text: &str) -> usize {
// Estimate: approximately 4 characters per token for English
// Add 3 before dividing to round up
(text.len() + 3) / 4
text.len().div_ceil(4)
}
fn count_messages(&self, messages: &[ChatMessage]) -> usize {
@@ -224,7 +224,7 @@ impl TokenCounter for ClaudeTokenCounter {
fn count(&self, text: &str) -> usize {
// Claude's tokenization is similar to the 4 chars/token heuristic
// but tends to be slightly more efficient with structured content
(text.len() + 3) / 4
text.len().div_ceil(4)
}
fn count_messages(&self, messages: &[ChatMessage]) -> usize {

View File

@@ -4,6 +4,7 @@ use figment::{
providers::{Env, Format, Serialized, Toml},
};
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::path::PathBuf;
use std::env;
use permissions::{Mode, PermissionManager};
@@ -18,6 +19,11 @@ pub struct Settings {
#[serde(default = "default_model")]
pub model: String,
/// Per-provider model preferences
/// Maps provider name to model ID: {"ollama": "qwen3:8b", "anthropic": "claude-sonnet-4-20250514"}
#[serde(default)]
pub models: HashMap<String, String>,
// Ollama-specific
#[serde(default = "default_ollama_url")]
pub ollama_url: String,
@@ -73,6 +79,7 @@ impl Default for Settings {
Self {
provider: default_provider(),
model: default_model(),
models: HashMap::new(),
ollama_url: default_ollama_url(),
api_key: None,
anthropic_api_key: None,
@@ -111,7 +118,7 @@ impl Settings {
/// Get the ProviderType enum from the provider string
pub fn get_provider(&self) -> Option<ProviderType> {
ProviderType::from_str(&self.provider)
self.provider.parse::<ProviderType>().ok()
}
/// Get the effective model for the current provider
@@ -136,6 +143,66 @@ impl Settings {
ProviderType::OpenAI => self.openai_api_key.clone(),
}
}
/// Get the model for a specific provider
///
/// Checks per-provider models map first, falls back to provider default
pub fn get_model_for_provider(&self, provider: ProviderType) -> String {
let provider_key = provider.to_string().to_lowercase();
// Check per-provider models map
if let Some(model) = self.models.get(&provider_key) {
if !model.is_empty() {
return model.clone();
}
}
// Fall back to provider's default model
provider.default_model().to_string()
}
/// Set the model for a specific provider
pub fn set_model_for_provider(&mut self, provider: ProviderType, model: &str) {
let provider_key = provider.to_string().to_lowercase();
self.models.insert(provider_key, model.to_string());
}
/// Save settings to user config file
///
/// Writes to ~/.config/owlen/config.toml
pub fn save(&self) -> Result<(), std::io::Error> {
if let Some(pd) = ProjectDirs::from("dev", "owlibou", "owlen") {
let config_dir = pd.config_dir();
std::fs::create_dir_all(config_dir)?;
let config_path = config_dir.join("config.toml");
let toml_str = toml::to_string_pretty(self)
.map_err(|e| std::io::Error::new(std::io::ErrorKind::InvalidData, e))?;
std::fs::write(config_path, toml_str)?;
}
Ok(())
}
}
/// Check if this is the first run (no user config exists)
pub fn is_first_run() -> bool {
if let Some(pd) = ProjectDirs::from("dev", "owlibou", "owlen") {
let config_path = pd.config_dir().join("config.toml");
!config_path.exists()
} else {
true
}
}
/// Get the default model for first-time setup
pub fn first_run_model() -> &'static str {
"qwen3:8b"
}
/// Get the default provider for first-time setup
pub fn first_run_provider() -> ProviderType {
ProviderType::Ollama
}
pub fn load_settings(project_root: Option<&str>) -> Result<Settings, figment::Error> {

View File

@@ -29,35 +29,6 @@ pub enum Tool {
}
impl Tool {
/// Parse a tool name from string (case-insensitive)
pub fn from_str(s: &str) -> Option<Self> {
match s.to_lowercase().as_str() {
"read" => Some(Tool::Read),
"write" => Some(Tool::Write),
"edit" => Some(Tool::Edit),
"bash" => Some(Tool::Bash),
"grep" => Some(Tool::Grep),
"glob" => Some(Tool::Glob),
"webfetch" | "web_fetch" => Some(Tool::WebFetch),
"websearch" | "web_search" => Some(Tool::WebSearch),
"notebookread" | "notebook_read" => Some(Tool::NotebookRead),
"notebookedit" | "notebook_edit" => Some(Tool::NotebookEdit),
"slashcommand" | "slash_command" => Some(Tool::SlashCommand),
"task" => Some(Tool::Task),
"todowrite" | "todo_write" | "todo" => Some(Tool::TodoWrite),
"mcp" => Some(Tool::Mcp),
"multiedit" | "multi_edit" => Some(Tool::MultiEdit),
"ls" => Some(Tool::LS),
"askuserquestion" | "ask_user_question" | "ask" => Some(Tool::AskUserQuestion),
"bashoutput" | "bash_output" => Some(Tool::BashOutput),
"killshell" | "kill_shell" => Some(Tool::KillShell),
"enterplanmode" | "enter_plan_mode" => Some(Tool::EnterPlanMode),
"exitplanmode" | "exit_plan_mode" => Some(Tool::ExitPlanMode),
"skill" => Some(Tool::Skill),
_ => None,
}
}
/// Get the string name of this tool
pub fn name(&self) -> &'static str {
match self {
@@ -87,6 +58,38 @@ impl Tool {
}
}
impl std::str::FromStr for Tool {
type Err = ();
fn from_str(s: &str) -> std::result::Result<Self, Self::Err> {
match s.to_lowercase().as_str() {
"read" => Ok(Tool::Read),
"write" => Ok(Tool::Write),
"edit" => Ok(Tool::Edit),
"bash" => Ok(Tool::Bash),
"grep" => Ok(Tool::Grep),
"glob" => Ok(Tool::Glob),
"webfetch" | "web_fetch" => Ok(Tool::WebFetch),
"websearch" | "web_search" => Ok(Tool::WebSearch),
"notebookread" | "notebook_read" => Ok(Tool::NotebookRead),
"notebookedit" | "notebook_edit" => Ok(Tool::NotebookEdit),
"slashcommand" | "slash_command" => Ok(Tool::SlashCommand),
"task" => Ok(Tool::Task),
"todowrite" | "todo_write" | "todo" => Ok(Tool::TodoWrite),
"mcp" => Ok(Tool::Mcp),
"multiedit" | "multi_edit" => Ok(Tool::MultiEdit),
"ls" => Ok(Tool::LS),
"askuserquestion" | "ask_user_question" | "ask" => Ok(Tool::AskUserQuestion),
"bashoutput" | "bash_output" => Ok(Tool::BashOutput),
"killshell" | "kill_shell" => Ok(Tool::KillShell),
"enterplanmode" | "enter_plan_mode" => Ok(Tool::EnterPlanMode),
"exitplanmode" | "exit_plan_mode" => Ok(Tool::ExitPlanMode),
"skill" => Ok(Tool::Skill),
_ => Err(()),
}
}
}
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum Action {
Allow,
@@ -101,13 +104,15 @@ pub enum Mode {
Code, // Full access (all allowed)
}
impl Mode {
pub fn from_str(s: &str) -> Option<Self> {
impl std::str::FromStr for Mode {
type Err = ();
fn from_str(s: &str) -> std::result::Result<Self, Self::Err> {
match s.to_lowercase().as_str() {
"plan" => Some(Mode::Plan),
"acceptedits" | "accept_edits" => Some(Mode::AcceptEdits),
"code" => Some(Mode::Code),
_ => None,
"plan" => Ok(Mode::Plan),
"acceptedits" | "accept_edits" => Ok(Mode::AcceptEdits),
"code" => Ok(Mode::Code),
_ => Err(()),
}
}
}
@@ -268,7 +273,7 @@ impl PermissionManager {
let tool_name = parts[0].trim();
let pattern = parts.get(1).map(|s| s.trim().to_string());
Tool::from_str(tool_name).map(|tool| (tool, pattern))
tool_name.parse::<Tool>().ok().map(|tool| (tool, pattern))
}
}
@@ -355,12 +360,12 @@ mod tests {
#[test]
fn tool_from_str() {
assert_eq!(Tool::from_str("bash"), Some(Tool::Bash));
assert_eq!(Tool::from_str("BASH"), Some(Tool::Bash));
assert_eq!(Tool::from_str("Bash"), Some(Tool::Bash));
assert_eq!(Tool::from_str("web_fetch"), Some(Tool::WebFetch));
assert_eq!(Tool::from_str("webfetch"), Some(Tool::WebFetch));
assert_eq!(Tool::from_str("unknown"), None);
assert_eq!("bash".parse::<Tool>(), Ok(Tool::Bash));
assert_eq!("BASH".parse::<Tool>(), Ok(Tool::Bash));
assert_eq!("Bash".parse::<Tool>(), Ok(Tool::Bash));
assert_eq!("web_fetch".parse::<Tool>(), Ok(Tool::WebFetch));
assert_eq!("webfetch".parse::<Tool>(), Ok(Tool::WebFetch));
assert!("unknown".parse::<Tool>().is_err());
}
#[test]

View File

@@ -64,7 +64,7 @@ pub fn glob_list(pattern: &str) -> Result<Vec<String>> {
// Extract the literal prefix to determine the root directory
// Find the position of the first glob metacharacter
let first_glob = pattern
.find(|c| matches!(c, '*' | '?' | '[' | '{'))
.find(['*', '?', '[', '{'])
.unwrap_or(pattern.len());
// Find the last directory separator before the first glob metacharacter
@@ -85,13 +85,11 @@ pub fn glob_list(pattern: &str) -> Result<Vec<String>> {
.build()
{
let entity = result?;
if entity.file_type().map(|filetype| filetype.is_file()).unwrap_or(false) {
if let Some(path) = entity.path().to_str() {
// Match against the glob pattern
if glob.is_match(path) {
out.push(path.to_string());
}
}
if entity.file_type().is_some_and(|ft| ft.is_file())
&& let Some(path) = entity.path().to_str()
&& glob.is_match(path)
{
out.push(path.to_string());
}
}
Ok(out)
@@ -111,7 +109,7 @@ pub fn grep(root: &str, pattern: &str) -> Result<Vec<(String, usize, String)>> {
.build()
{
let entity = result?;
if !entity.file_type().map(|filetype| filetype.is_file()).unwrap_or(false) { continue; }
if !entity.file_type().is_some_and(|ft| ft.is_file()) { continue; }
let path = entity.path().to_path_buf();
let mut line_hits: Vec<(usize, String)> = Vec::new();
let sink = UTF8(|line_number, line| {

View File

@@ -11,9 +11,10 @@ use chrono::{DateTime, Utc};
use uuid::Uuid;
/// Agent mode - normal execution or planning
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize, Default)]
pub enum AgentMode {
/// Normal mode - all tools available per permission settings
#[default]
Normal,
/// Planning mode - only read-only tools allowed
Planning {
@@ -24,12 +25,6 @@ pub enum AgentMode {
},
}
impl Default for AgentMode {
fn default() -> Self {
Self::Normal
}
}
impl AgentMode {
/// Check if we're in planning mode
pub fn is_planning(&self) -> bool {
@@ -215,7 +210,7 @@ impl PlanManager {
let mut entries = tokio::fs::read_dir(&self.plans_dir).await?;
while let Some(entry) = entries.next_entry().await? {
let path = entry.path();
if path.extension().map_or(false, |ext| ext == "md") {
if path.extension().is_some_and(|ext| ext == "md") {
plans.push(path);
}
}
@@ -260,7 +255,7 @@ mod tests {
let plan_path = manager.create_plan().await.unwrap();
assert!(plan_path.exists());
assert!(plan_path.extension().map_or(false, |ext| ext == "md"));
assert!(plan_path.extension().is_some_and(|ext| ext == "md"));
}
#[tokio::test]

View File

@@ -82,9 +82,10 @@ impl WebFetchClient {
let status = response.status().as_u16();
// Handle redirects manually
if status >= 300 && status < 400 {
if let Some(location) = response.headers().get("location") {
let location_str = location.to_str()?;
if (300..400).contains(&status)
&& let Some(location) = response.headers().get("location")
{
let location_str = location.to_str()?;
// Parse the redirect URL (may be relative)
let redirect_url = if location_str.starts_with("http") {
@@ -111,7 +112,6 @@ impl WebFetchClient {
url,
redirect_url
));
}
}
let content_type = response