test(agent): Add unit tests for agent-core and fix clippy warnings
This commit is contained in:
@@ -589,7 +589,7 @@ async fn main() -> Result<()> {
|
||||
}
|
||||
}
|
||||
Cmd::Login { provider } => {
|
||||
let provider_type = llm_core::ProviderType::from_str(&provider)
|
||||
let provider_type = provider.parse::<llm_core::ProviderType>().ok()
|
||||
.ok_or_else(|| eyre!(
|
||||
"Unknown provider: {}. Supported: anthropic, openai, ollama",
|
||||
provider
|
||||
@@ -699,7 +699,7 @@ async fn main() -> Result<()> {
|
||||
return Ok(());
|
||||
}
|
||||
Cmd::Logout { provider } => {
|
||||
let provider_type = llm_core::ProviderType::from_str(&provider)
|
||||
let provider_type = provider.parse::<llm_core::ProviderType>().ok()
|
||||
.ok_or_else(|| eyre!(
|
||||
"Unknown provider: {}. Supported: anthropic, openai, ollama",
|
||||
provider
|
||||
@@ -767,10 +767,10 @@ async fn main() -> Result<()> {
|
||||
);
|
||||
let _token_refresher = auth_manager.clone().start_background_refresh();
|
||||
|
||||
// Launch TUI
|
||||
// Launch TUI with multi-provider support
|
||||
// Note: For now, TUI doesn't use plugin manager directly
|
||||
// In the future, we'll integrate plugin commands into TUI
|
||||
return ui::run(client, opts, perms, settings).await;
|
||||
return ui::run_with_providers(auth_manager, perms, settings).await;
|
||||
}
|
||||
|
||||
// Legacy text-based REPL
|
||||
|
||||
@@ -1,17 +1,13 @@
|
||||
/// Example demonstrating the streaming agent loop API
|
||||
///
|
||||
/// This example shows how to use `run_agent_loop_streaming` to receive
|
||||
/// real-time events during agent execution, including:
|
||||
/// - Text deltas as the LLM generates text
|
||||
/// - Tool execution start/end events
|
||||
/// - Tool output events
|
||||
/// - Final completion events
|
||||
///
|
||||
/// Run with: cargo run --example streaming_agent -p agent-core
|
||||
|
||||
use agent_core::{create_event_channel, run_agent_loop_streaming, AgentEvent, ToolContext};
|
||||
use llm_core::ChatOptions;
|
||||
use permissions::{Mode, PermissionManager};
|
||||
//! Example demonstrating the streaming agent loop API
|
||||
//!
|
||||
//! This example shows how to use `run_agent_loop_streaming` to receive
|
||||
//! real-time events during agent execution, including:
|
||||
//! - Text deltas as the LLM generates text
|
||||
//! - Tool execution start/end events
|
||||
//! - Tool output events
|
||||
//! - Final completion events
|
||||
//!
|
||||
//! Run with: cargo run --example streaming_agent -p agent-core
|
||||
|
||||
#[tokio::main]
|
||||
async fn main() -> color_eyre::Result<()> {
|
||||
|
||||
@@ -159,10 +159,10 @@ impl Compactor {
|
||||
let mut summary = String::new();
|
||||
use futures_util::StreamExt;
|
||||
while let Some(chunk_result) = stream.next().await {
|
||||
if let Ok(chunk) = chunk_result {
|
||||
if let Some(content) = &chunk.content {
|
||||
summary.push_str(content);
|
||||
}
|
||||
if let Ok(chunk) = chunk_result
|
||||
&& let Some(content) = &chunk.content
|
||||
{
|
||||
summary.push_str(content);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -195,14 +195,14 @@ mod tests {
|
||||
|
||||
// Small message list shouldn't compact
|
||||
let small_messages: Vec<ChatMessage> = (0..10)
|
||||
.map(|i| ChatMessage::user(&format!("Message {}", i)))
|
||||
.map(|i| ChatMessage::user(format!("Message {}", i)))
|
||||
.collect();
|
||||
assert!(!counter.should_compact(&small_messages));
|
||||
|
||||
// Large message list should compact
|
||||
// Need ~162,000 tokens = ~648,000 chars (at 4 chars per token)
|
||||
let large_content = "x".repeat(700_000);
|
||||
let large_messages = vec![ChatMessage::user(&large_content)];
|
||||
let large_messages = vec![ChatMessage::user(large_content)];
|
||||
assert!(counter.should_compact(&large_messages));
|
||||
}
|
||||
|
||||
@@ -211,7 +211,7 @@ mod tests {
|
||||
let compactor = Compactor::new();
|
||||
|
||||
let small: Vec<ChatMessage> = (0..5)
|
||||
.map(|i| ChatMessage::user(&format!("Short message {}", i)))
|
||||
.map(|i| ChatMessage::user(format!("Short message {}", i)))
|
||||
.collect();
|
||||
assert!(!compactor.needs_compaction(&small));
|
||||
}
|
||||
|
||||
@@ -928,7 +928,7 @@ pub async fn run_agent_loop<P: LlmProvider>(
|
||||
let new_args = current_args + &args_delta;
|
||||
// Try to parse as JSON, but keep as string if incomplete
|
||||
tool_call.function.arguments = serde_json::from_str(&new_args)
|
||||
.unwrap_or_else(|_| Value::String(new_args));
|
||||
.unwrap_or(Value::String(new_args));
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -1128,3 +1128,65 @@ pub async fn run_agent_loop_streaming<P: LlmProvider>(
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use llm_core::ToolCallDelta;
|
||||
|
||||
#[test]
|
||||
fn test_tool_calls_builder() {
|
||||
let mut builder = ToolCallsBuilder::new();
|
||||
|
||||
// Add first tool call deltas
|
||||
builder.add_deltas(&[
|
||||
ToolCallDelta {
|
||||
index: 0,
|
||||
id: Some("call_1".to_string()),
|
||||
function_name: Some("read".to_string()),
|
||||
arguments_delta: Some("{\"path\":".to_string()),
|
||||
}
|
||||
]);
|
||||
|
||||
// Add second tool call deltas
|
||||
builder.add_deltas(&[
|
||||
ToolCallDelta {
|
||||
index: 1,
|
||||
id: Some("call_2".to_string()),
|
||||
function_name: Some("write".to_string()),
|
||||
arguments_delta: Some("{\"path\":\"test.txt\"".to_string()),
|
||||
}
|
||||
]);
|
||||
|
||||
// Add more deltas for first tool call
|
||||
builder.add_deltas(&[
|
||||
ToolCallDelta {
|
||||
index: 0,
|
||||
id: None,
|
||||
function_name: None,
|
||||
arguments_delta: Some("\"lib.rs\"}".to_string()),
|
||||
}
|
||||
]);
|
||||
|
||||
// Add more deltas for second tool call
|
||||
builder.add_deltas(&[
|
||||
ToolCallDelta {
|
||||
index: 1,
|
||||
id: None,
|
||||
function_name: None,
|
||||
arguments_delta: Some(",\"content\":\"hello\"}".to_string()),
|
||||
}
|
||||
]);
|
||||
|
||||
let calls = builder.build();
|
||||
assert_eq!(calls.len(), 2);
|
||||
|
||||
assert_eq!(calls[0].id, "call_1");
|
||||
assert_eq!(calls[0].function.name, "read");
|
||||
assert_eq!(calls[0].function.arguments, json!({"path": "lib.rs"}));
|
||||
|
||||
assert_eq!(calls[1].id, "call_2");
|
||||
assert_eq!(calls[1].function.name, "write");
|
||||
assert_eq!(calls[1].function.arguments, json!({"path": "test.txt", "content": "hello"}));
|
||||
}
|
||||
}
|
||||
|
||||
@@ -183,10 +183,10 @@ impl Checkpoint {
|
||||
for entry in fs::read_dir(checkpoint_dir)? {
|
||||
let entry = entry?;
|
||||
let path = entry.path();
|
||||
if path.extension().and_then(|s| s.to_str()) == Some("json") {
|
||||
if let Some(stem) = path.file_stem().and_then(|s| s.to_str()) {
|
||||
checkpoints.push(stem.to_string());
|
||||
}
|
||||
if path.extension().and_then(|s| s.to_str()) == Some("json")
|
||||
&& let Some(stem) = path.file_stem().and_then(|s| s.to_str())
|
||||
{
|
||||
checkpoints.push(stem.to_string());
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -47,27 +47,27 @@ impl SystemPromptBuilder {
|
||||
pub fn with_project_instructions(mut self, project_root: &Path) -> Self {
|
||||
// Try CLAUDE.md first (Claude Code compatibility)
|
||||
let claude_md = project_root.join("CLAUDE.md");
|
||||
if claude_md.exists() {
|
||||
if let Ok(content) = std::fs::read_to_string(&claude_md) {
|
||||
self.sections.push(PromptSection {
|
||||
name: "project".to_string(),
|
||||
content: format!("# Project Instructions\n\n{}", content),
|
||||
priority: 20,
|
||||
});
|
||||
return self;
|
||||
}
|
||||
if claude_md.exists()
|
||||
&& let Ok(content) = std::fs::read_to_string(&claude_md)
|
||||
{
|
||||
self.sections.push(PromptSection {
|
||||
name: "project".to_string(),
|
||||
content: format!("# Project Instructions\n\n{}", content),
|
||||
priority: 20,
|
||||
});
|
||||
return self;
|
||||
}
|
||||
|
||||
// Fallback to .owlen.md
|
||||
let owlen_md = project_root.join(".owlen.md");
|
||||
if owlen_md.exists() {
|
||||
if let Ok(content) = std::fs::read_to_string(&owlen_md) {
|
||||
self.sections.push(PromptSection {
|
||||
name: "project".to_string(),
|
||||
content: format!("# Project Instructions\n\n{}", content),
|
||||
priority: 20,
|
||||
});
|
||||
}
|
||||
if owlen_md.exists()
|
||||
&& let Ok(content) = std::fs::read_to_string(&owlen_md)
|
||||
{
|
||||
self.sections.push(PromptSection {
|
||||
name: "project".to_string(),
|
||||
content: format!("# Project Instructions\n\n{}", content),
|
||||
priority: 20,
|
||||
});
|
||||
}
|
||||
|
||||
self
|
||||
|
||||
75
crates/core/agent/tests/core_logic.rs
Normal file
75
crates/core/agent/tests/core_logic.rs
Normal file
@@ -0,0 +1,75 @@
|
||||
use agent_core::{get_tool_definitions, ToolContext, execute_tool, AgentMode};
|
||||
use permissions::{Mode, PermissionManager};
|
||||
use serde_json::json;
|
||||
|
||||
#[test]
|
||||
fn test_get_tool_definitions() {
|
||||
let tools = get_tool_definitions();
|
||||
assert!(!tools.is_empty());
|
||||
|
||||
// Check for some specific tools
|
||||
let has_read = tools.iter().any(|t| t.function.name == "read");
|
||||
let has_write = tools.iter().any(|t| t.function.name == "write");
|
||||
let has_bash = tools.iter().any(|t| t.function.name == "bash");
|
||||
|
||||
assert!(has_read);
|
||||
assert!(has_write);
|
||||
assert!(has_bash);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_execute_tool_permission_ask() {
|
||||
let ctx = ToolContext::new();
|
||||
let perms = PermissionManager::new(Mode::Plan); // Plan mode asks for write
|
||||
|
||||
let arguments = json!({
|
||||
"path": "test.txt",
|
||||
"content": "hello"
|
||||
});
|
||||
|
||||
let result = execute_tool("write", &arguments, &perms, &ctx).await;
|
||||
assert!(result.is_err());
|
||||
assert!(result.unwrap_err().to_string().contains("Permission required"));
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_unknown_tool() {
|
||||
let ctx = ToolContext::new();
|
||||
let perms = PermissionManager::new(Mode::Code);
|
||||
let arguments = json!({});
|
||||
|
||||
let result = execute_tool("non_existent_tool", &arguments, &perms, &ctx).await;
|
||||
assert!(result.is_err());
|
||||
assert!(result.unwrap_err().to_string().contains("Unknown tool"));
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_tool_context_mode_management() {
|
||||
let ctx = ToolContext::new();
|
||||
assert_eq!(ctx.get_mode().await, AgentMode::Normal);
|
||||
assert!(!ctx.is_planning().await);
|
||||
|
||||
ctx.set_mode(AgentMode::Planning {
|
||||
plan_file: "test_plan.md".into(),
|
||||
started_at: chrono::Utc::now()
|
||||
}).await;
|
||||
|
||||
match ctx.get_mode().await {
|
||||
AgentMode::Planning { plan_file, .. } => assert_eq!(plan_file.to_str().unwrap(), "test_plan.md"),
|
||||
_ => panic!("Expected Planning mode"),
|
||||
}
|
||||
assert!(ctx.is_planning().await);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_agent_event_channel() {
|
||||
let (tx, mut rx) = agent_core::create_event_channel();
|
||||
|
||||
tx.send(agent_core::AgentEvent::TextDelta("hello".into())).await.unwrap();
|
||||
|
||||
if let Some(agent_core::AgentEvent::TextDelta(text)) = rx.recv().await {
|
||||
assert_eq!(text, "hello");
|
||||
} else {
|
||||
panic!("Expected TextDelta event");
|
||||
}
|
||||
}
|
||||
@@ -825,16 +825,20 @@ pub enum ProviderType {
|
||||
OpenAI,
|
||||
}
|
||||
|
||||
impl ProviderType {
|
||||
pub fn from_str(s: &str) -> Option<Self> {
|
||||
impl std::str::FromStr for ProviderType {
|
||||
type Err = ();
|
||||
|
||||
fn from_str(s: &str) -> std::result::Result<Self, Self::Err> {
|
||||
match s.to_lowercase().as_str() {
|
||||
"ollama" => Some(Self::Ollama),
|
||||
"anthropic" | "claude" => Some(Self::Anthropic),
|
||||
"openai" | "gpt" => Some(Self::OpenAI),
|
||||
_ => None,
|
||||
"ollama" => Ok(Self::Ollama),
|
||||
"anthropic" | "claude" => Ok(Self::Anthropic),
|
||||
"openai" | "gpt" => Ok(Self::OpenAI),
|
||||
_ => Err(()),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl ProviderType {
|
||||
pub fn as_str(&self) -> &'static str {
|
||||
match self {
|
||||
Self::Ollama => "ollama",
|
||||
|
||||
@@ -103,7 +103,7 @@ impl TokenCounter for SimpleTokenCounter {
|
||||
fn count(&self, text: &str) -> usize {
|
||||
// Estimate: approximately 4 characters per token for English
|
||||
// Add 3 before dividing to round up
|
||||
(text.len() + 3) / 4
|
||||
text.len().div_ceil(4)
|
||||
}
|
||||
|
||||
fn count_messages(&self, messages: &[ChatMessage]) -> usize {
|
||||
@@ -224,7 +224,7 @@ impl TokenCounter for ClaudeTokenCounter {
|
||||
fn count(&self, text: &str) -> usize {
|
||||
// Claude's tokenization is similar to the 4 chars/token heuristic
|
||||
// but tends to be slightly more efficient with structured content
|
||||
(text.len() + 3) / 4
|
||||
text.len().div_ceil(4)
|
||||
}
|
||||
|
||||
fn count_messages(&self, messages: &[ChatMessage]) -> usize {
|
||||
|
||||
@@ -4,6 +4,7 @@ use figment::{
|
||||
providers::{Env, Format, Serialized, Toml},
|
||||
};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::collections::HashMap;
|
||||
use std::path::PathBuf;
|
||||
use std::env;
|
||||
use permissions::{Mode, PermissionManager};
|
||||
@@ -18,6 +19,11 @@ pub struct Settings {
|
||||
#[serde(default = "default_model")]
|
||||
pub model: String,
|
||||
|
||||
/// Per-provider model preferences
|
||||
/// Maps provider name to model ID: {"ollama": "qwen3:8b", "anthropic": "claude-sonnet-4-20250514"}
|
||||
#[serde(default)]
|
||||
pub models: HashMap<String, String>,
|
||||
|
||||
// Ollama-specific
|
||||
#[serde(default = "default_ollama_url")]
|
||||
pub ollama_url: String,
|
||||
@@ -73,6 +79,7 @@ impl Default for Settings {
|
||||
Self {
|
||||
provider: default_provider(),
|
||||
model: default_model(),
|
||||
models: HashMap::new(),
|
||||
ollama_url: default_ollama_url(),
|
||||
api_key: None,
|
||||
anthropic_api_key: None,
|
||||
@@ -111,7 +118,7 @@ impl Settings {
|
||||
|
||||
/// Get the ProviderType enum from the provider string
|
||||
pub fn get_provider(&self) -> Option<ProviderType> {
|
||||
ProviderType::from_str(&self.provider)
|
||||
self.provider.parse::<ProviderType>().ok()
|
||||
}
|
||||
|
||||
/// Get the effective model for the current provider
|
||||
@@ -136,6 +143,66 @@ impl Settings {
|
||||
ProviderType::OpenAI => self.openai_api_key.clone(),
|
||||
}
|
||||
}
|
||||
|
||||
/// Get the model for a specific provider
|
||||
///
|
||||
/// Checks per-provider models map first, falls back to provider default
|
||||
pub fn get_model_for_provider(&self, provider: ProviderType) -> String {
|
||||
let provider_key = provider.to_string().to_lowercase();
|
||||
|
||||
// Check per-provider models map
|
||||
if let Some(model) = self.models.get(&provider_key) {
|
||||
if !model.is_empty() {
|
||||
return model.clone();
|
||||
}
|
||||
}
|
||||
|
||||
// Fall back to provider's default model
|
||||
provider.default_model().to_string()
|
||||
}
|
||||
|
||||
/// Set the model for a specific provider
|
||||
pub fn set_model_for_provider(&mut self, provider: ProviderType, model: &str) {
|
||||
let provider_key = provider.to_string().to_lowercase();
|
||||
self.models.insert(provider_key, model.to_string());
|
||||
}
|
||||
|
||||
/// Save settings to user config file
|
||||
///
|
||||
/// Writes to ~/.config/owlen/config.toml
|
||||
pub fn save(&self) -> Result<(), std::io::Error> {
|
||||
if let Some(pd) = ProjectDirs::from("dev", "owlibou", "owlen") {
|
||||
let config_dir = pd.config_dir();
|
||||
std::fs::create_dir_all(config_dir)?;
|
||||
|
||||
let config_path = config_dir.join("config.toml");
|
||||
let toml_str = toml::to_string_pretty(self)
|
||||
.map_err(|e| std::io::Error::new(std::io::ErrorKind::InvalidData, e))?;
|
||||
|
||||
std::fs::write(config_path, toml_str)?;
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
/// Check if this is the first run (no user config exists)
|
||||
pub fn is_first_run() -> bool {
|
||||
if let Some(pd) = ProjectDirs::from("dev", "owlibou", "owlen") {
|
||||
let config_path = pd.config_dir().join("config.toml");
|
||||
!config_path.exists()
|
||||
} else {
|
||||
true
|
||||
}
|
||||
}
|
||||
|
||||
/// Get the default model for first-time setup
|
||||
pub fn first_run_model() -> &'static str {
|
||||
"qwen3:8b"
|
||||
}
|
||||
|
||||
/// Get the default provider for first-time setup
|
||||
pub fn first_run_provider() -> ProviderType {
|
||||
ProviderType::Ollama
|
||||
}
|
||||
|
||||
pub fn load_settings(project_root: Option<&str>) -> Result<Settings, figment::Error> {
|
||||
|
||||
@@ -29,35 +29,6 @@ pub enum Tool {
|
||||
}
|
||||
|
||||
impl Tool {
|
||||
/// Parse a tool name from string (case-insensitive)
|
||||
pub fn from_str(s: &str) -> Option<Self> {
|
||||
match s.to_lowercase().as_str() {
|
||||
"read" => Some(Tool::Read),
|
||||
"write" => Some(Tool::Write),
|
||||
"edit" => Some(Tool::Edit),
|
||||
"bash" => Some(Tool::Bash),
|
||||
"grep" => Some(Tool::Grep),
|
||||
"glob" => Some(Tool::Glob),
|
||||
"webfetch" | "web_fetch" => Some(Tool::WebFetch),
|
||||
"websearch" | "web_search" => Some(Tool::WebSearch),
|
||||
"notebookread" | "notebook_read" => Some(Tool::NotebookRead),
|
||||
"notebookedit" | "notebook_edit" => Some(Tool::NotebookEdit),
|
||||
"slashcommand" | "slash_command" => Some(Tool::SlashCommand),
|
||||
"task" => Some(Tool::Task),
|
||||
"todowrite" | "todo_write" | "todo" => Some(Tool::TodoWrite),
|
||||
"mcp" => Some(Tool::Mcp),
|
||||
"multiedit" | "multi_edit" => Some(Tool::MultiEdit),
|
||||
"ls" => Some(Tool::LS),
|
||||
"askuserquestion" | "ask_user_question" | "ask" => Some(Tool::AskUserQuestion),
|
||||
"bashoutput" | "bash_output" => Some(Tool::BashOutput),
|
||||
"killshell" | "kill_shell" => Some(Tool::KillShell),
|
||||
"enterplanmode" | "enter_plan_mode" => Some(Tool::EnterPlanMode),
|
||||
"exitplanmode" | "exit_plan_mode" => Some(Tool::ExitPlanMode),
|
||||
"skill" => Some(Tool::Skill),
|
||||
_ => None,
|
||||
}
|
||||
}
|
||||
|
||||
/// Get the string name of this tool
|
||||
pub fn name(&self) -> &'static str {
|
||||
match self {
|
||||
@@ -87,6 +58,38 @@ impl Tool {
|
||||
}
|
||||
}
|
||||
|
||||
impl std::str::FromStr for Tool {
|
||||
type Err = ();
|
||||
|
||||
fn from_str(s: &str) -> std::result::Result<Self, Self::Err> {
|
||||
match s.to_lowercase().as_str() {
|
||||
"read" => Ok(Tool::Read),
|
||||
"write" => Ok(Tool::Write),
|
||||
"edit" => Ok(Tool::Edit),
|
||||
"bash" => Ok(Tool::Bash),
|
||||
"grep" => Ok(Tool::Grep),
|
||||
"glob" => Ok(Tool::Glob),
|
||||
"webfetch" | "web_fetch" => Ok(Tool::WebFetch),
|
||||
"websearch" | "web_search" => Ok(Tool::WebSearch),
|
||||
"notebookread" | "notebook_read" => Ok(Tool::NotebookRead),
|
||||
"notebookedit" | "notebook_edit" => Ok(Tool::NotebookEdit),
|
||||
"slashcommand" | "slash_command" => Ok(Tool::SlashCommand),
|
||||
"task" => Ok(Tool::Task),
|
||||
"todowrite" | "todo_write" | "todo" => Ok(Tool::TodoWrite),
|
||||
"mcp" => Ok(Tool::Mcp),
|
||||
"multiedit" | "multi_edit" => Ok(Tool::MultiEdit),
|
||||
"ls" => Ok(Tool::LS),
|
||||
"askuserquestion" | "ask_user_question" | "ask" => Ok(Tool::AskUserQuestion),
|
||||
"bashoutput" | "bash_output" => Ok(Tool::BashOutput),
|
||||
"killshell" | "kill_shell" => Ok(Tool::KillShell),
|
||||
"enterplanmode" | "enter_plan_mode" => Ok(Tool::EnterPlanMode),
|
||||
"exitplanmode" | "exit_plan_mode" => Ok(Tool::ExitPlanMode),
|
||||
"skill" => Ok(Tool::Skill),
|
||||
_ => Err(()),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
|
||||
pub enum Action {
|
||||
Allow,
|
||||
@@ -101,13 +104,15 @@ pub enum Mode {
|
||||
Code, // Full access (all allowed)
|
||||
}
|
||||
|
||||
impl Mode {
|
||||
pub fn from_str(s: &str) -> Option<Self> {
|
||||
impl std::str::FromStr for Mode {
|
||||
type Err = ();
|
||||
|
||||
fn from_str(s: &str) -> std::result::Result<Self, Self::Err> {
|
||||
match s.to_lowercase().as_str() {
|
||||
"plan" => Some(Mode::Plan),
|
||||
"acceptedits" | "accept_edits" => Some(Mode::AcceptEdits),
|
||||
"code" => Some(Mode::Code),
|
||||
_ => None,
|
||||
"plan" => Ok(Mode::Plan),
|
||||
"acceptedits" | "accept_edits" => Ok(Mode::AcceptEdits),
|
||||
"code" => Ok(Mode::Code),
|
||||
_ => Err(()),
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -268,7 +273,7 @@ impl PermissionManager {
|
||||
let tool_name = parts[0].trim();
|
||||
let pattern = parts.get(1).map(|s| s.trim().to_string());
|
||||
|
||||
Tool::from_str(tool_name).map(|tool| (tool, pattern))
|
||||
tool_name.parse::<Tool>().ok().map(|tool| (tool, pattern))
|
||||
}
|
||||
}
|
||||
|
||||
@@ -355,12 +360,12 @@ mod tests {
|
||||
|
||||
#[test]
|
||||
fn tool_from_str() {
|
||||
assert_eq!(Tool::from_str("bash"), Some(Tool::Bash));
|
||||
assert_eq!(Tool::from_str("BASH"), Some(Tool::Bash));
|
||||
assert_eq!(Tool::from_str("Bash"), Some(Tool::Bash));
|
||||
assert_eq!(Tool::from_str("web_fetch"), Some(Tool::WebFetch));
|
||||
assert_eq!(Tool::from_str("webfetch"), Some(Tool::WebFetch));
|
||||
assert_eq!(Tool::from_str("unknown"), None);
|
||||
assert_eq!("bash".parse::<Tool>(), Ok(Tool::Bash));
|
||||
assert_eq!("BASH".parse::<Tool>(), Ok(Tool::Bash));
|
||||
assert_eq!("Bash".parse::<Tool>(), Ok(Tool::Bash));
|
||||
assert_eq!("web_fetch".parse::<Tool>(), Ok(Tool::WebFetch));
|
||||
assert_eq!("webfetch".parse::<Tool>(), Ok(Tool::WebFetch));
|
||||
assert!("unknown".parse::<Tool>().is_err());
|
||||
}
|
||||
|
||||
#[test]
|
||||
|
||||
@@ -64,7 +64,7 @@ pub fn glob_list(pattern: &str) -> Result<Vec<String>> {
|
||||
// Extract the literal prefix to determine the root directory
|
||||
// Find the position of the first glob metacharacter
|
||||
let first_glob = pattern
|
||||
.find(|c| matches!(c, '*' | '?' | '[' | '{'))
|
||||
.find(['*', '?', '[', '{'])
|
||||
.unwrap_or(pattern.len());
|
||||
|
||||
// Find the last directory separator before the first glob metacharacter
|
||||
@@ -85,13 +85,11 @@ pub fn glob_list(pattern: &str) -> Result<Vec<String>> {
|
||||
.build()
|
||||
{
|
||||
let entity = result?;
|
||||
if entity.file_type().map(|filetype| filetype.is_file()).unwrap_or(false) {
|
||||
if let Some(path) = entity.path().to_str() {
|
||||
// Match against the glob pattern
|
||||
if glob.is_match(path) {
|
||||
out.push(path.to_string());
|
||||
}
|
||||
}
|
||||
if entity.file_type().is_some_and(|ft| ft.is_file())
|
||||
&& let Some(path) = entity.path().to_str()
|
||||
&& glob.is_match(path)
|
||||
{
|
||||
out.push(path.to_string());
|
||||
}
|
||||
}
|
||||
Ok(out)
|
||||
@@ -111,7 +109,7 @@ pub fn grep(root: &str, pattern: &str) -> Result<Vec<(String, usize, String)>> {
|
||||
.build()
|
||||
{
|
||||
let entity = result?;
|
||||
if !entity.file_type().map(|filetype| filetype.is_file()).unwrap_or(false) { continue; }
|
||||
if !entity.file_type().is_some_and(|ft| ft.is_file()) { continue; }
|
||||
let path = entity.path().to_path_buf();
|
||||
let mut line_hits: Vec<(usize, String)> = Vec::new();
|
||||
let sink = UTF8(|line_number, line| {
|
||||
|
||||
@@ -11,9 +11,10 @@ use chrono::{DateTime, Utc};
|
||||
use uuid::Uuid;
|
||||
|
||||
/// Agent mode - normal execution or planning
|
||||
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
|
||||
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize, Default)]
|
||||
pub enum AgentMode {
|
||||
/// Normal mode - all tools available per permission settings
|
||||
#[default]
|
||||
Normal,
|
||||
/// Planning mode - only read-only tools allowed
|
||||
Planning {
|
||||
@@ -24,12 +25,6 @@ pub enum AgentMode {
|
||||
},
|
||||
}
|
||||
|
||||
impl Default for AgentMode {
|
||||
fn default() -> Self {
|
||||
Self::Normal
|
||||
}
|
||||
}
|
||||
|
||||
impl AgentMode {
|
||||
/// Check if we're in planning mode
|
||||
pub fn is_planning(&self) -> bool {
|
||||
@@ -215,7 +210,7 @@ impl PlanManager {
|
||||
let mut entries = tokio::fs::read_dir(&self.plans_dir).await?;
|
||||
while let Some(entry) = entries.next_entry().await? {
|
||||
let path = entry.path();
|
||||
if path.extension().map_or(false, |ext| ext == "md") {
|
||||
if path.extension().is_some_and(|ext| ext == "md") {
|
||||
plans.push(path);
|
||||
}
|
||||
}
|
||||
@@ -260,7 +255,7 @@ mod tests {
|
||||
|
||||
let plan_path = manager.create_plan().await.unwrap();
|
||||
assert!(plan_path.exists());
|
||||
assert!(plan_path.extension().map_or(false, |ext| ext == "md"));
|
||||
assert!(plan_path.extension().is_some_and(|ext| ext == "md"));
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
|
||||
@@ -82,9 +82,10 @@ impl WebFetchClient {
|
||||
let status = response.status().as_u16();
|
||||
|
||||
// Handle redirects manually
|
||||
if status >= 300 && status < 400 {
|
||||
if let Some(location) = response.headers().get("location") {
|
||||
let location_str = location.to_str()?;
|
||||
if (300..400).contains(&status)
|
||||
&& let Some(location) = response.headers().get("location")
|
||||
{
|
||||
let location_str = location.to_str()?;
|
||||
|
||||
// Parse the redirect URL (may be relative)
|
||||
let redirect_url = if location_str.starts_with("http") {
|
||||
@@ -111,7 +112,6 @@ impl WebFetchClient {
|
||||
url,
|
||||
redirect_url
|
||||
));
|
||||
}
|
||||
}
|
||||
|
||||
let content_type = response
|
||||
|
||||
Reference in New Issue
Block a user