diff --git a/crates/owlen-core/src/session.rs b/crates/owlen-core/src/session.rs index 86a0de3..0c3cd90 100644 --- a/crates/owlen-core/src/session.rs +++ b/crates/owlen-core/src/session.rs @@ -35,7 +35,7 @@ use crate::{ }; use crate::{Error, Result}; use chrono::{DateTime, Utc}; -use log::{info, warn}; +use log::{debug, info, warn}; use reqwest::Url; use serde_json::{Value, json}; use std::cmp::{max, min}; @@ -1900,11 +1900,43 @@ impl SessionController { let streaming = { self.config.lock().await.general.enable_streaming || parameters.stream }; parameters.stream = streaming; - let tools = if !self.tool_registry.all().is_empty() { + let active_model = self.conversation.active().model.clone(); + let registry_tools = self.tool_registry.all(); + let mut include_tools = !registry_tools.is_empty(); + + if include_tools { + let cached_support = self.model_manager.select(&active_model).await; + let supports_tools = match cached_support { + Some(info) => info.supports_tools, + None => match self.models(false).await { + Ok(models) => models + .iter() + .find(|model| model.id == active_model || model.name == active_model) + .map(|model| model.supports_tools) + .unwrap_or(true), + Err(err) => { + warn!( + "Unable to resolve tool support for model '{}': {}. Assuming tools are supported.", + active_model, err + ); + true + } + }, + }; + + if !supports_tools { + include_tools = false; + debug!( + "Disabling tools for model '{}' because it does not advertise tool support.", + active_model + ); + } + } + + let tools = if include_tools { Some( - self.tool_registry - .all() - .into_iter() + registry_tools + .iter() .map(|tool| crate::mcp::McpToolDescriptor { name: tool.name().to_string(), description: tool.description().to_string(), @@ -1919,7 +1951,7 @@ impl SessionController { }; let mut request = ChatRequest { - model: self.conversation.active().model.clone(), + model: active_model, messages: self.conversation.active().messages.clone(), parameters: parameters.clone(), tools: tools.clone(), diff --git a/crates/owlen-core/tests/agent_tool_flow.rs b/crates/owlen-core/tests/agent_tool_flow.rs index f2f2b1c..37412e0 100644 --- a/crates/owlen-core/tests/agent_tool_flow.rs +++ b/crates/owlen-core/tests/agent_tool_flow.rs @@ -1,5 +1,10 @@ -use std::{any::Any, collections::HashMap, sync::Arc}; +use std::{ + any::Any, + collections::HashMap, + sync::{Arc, Mutex}, +}; +use anyhow::anyhow; use async_trait::async_trait; use futures::StreamExt; use owlen_core::tools::{WEB_SEARCH_TOOL_NAME, tool_name_matches}; @@ -87,6 +92,72 @@ impl Provider for StreamingToolProvider { } } +struct NoToolSupportProvider { + captured: Arc>>, +} + +impl NoToolSupportProvider { + fn new() -> Self { + Self { + captured: Arc::new(Mutex::new(None)), + } + } + + fn take_captured(&self) -> Option { + self.captured.lock().expect("capture mutex").take() + } +} + +#[async_trait] +impl Provider for NoToolSupportProvider { + fn name(&self) -> &str { + "mock-tool-less-provider" + } + + async fn list_models(&self) -> owlen_core::Result> { + Ok(vec![ModelInfo { + id: "tool-less-model".into(), + name: "Toolless Model".into(), + description: Some("A model without tool support.".into()), + provider: self.name().into(), + context_window: Some(4096), + capabilities: vec!["chat".into()], + supports_tools: false, + }]) + } + + async fn send_prompt(&self, request: ChatRequest) -> owlen_core::Result { + { + let mut guard = self.captured.lock().expect("capture mutex"); + *guard = Some(request); + } + + Ok(ChatResponse { + message: Message::assistant("ack".to_string()), + usage: None, + is_streaming: false, + is_final: true, + }) + } + + async fn stream_prompt( + &self, + _request: ChatRequest, + ) -> owlen_core::Result { + Err(Error::Provider(anyhow!( + "streaming disabled for mock provider" + ))) + } + + async fn health_check(&self) -> owlen_core::Result<()> { + Ok(()) + } + + fn as_any(&self) -> &(dyn Any + Send + Sync) { + self + } +} + fn tool_descriptor() -> McpToolDescriptor { McpToolDescriptor { name: WEB_SEARCH_TOOL_NAME.to_string(), @@ -277,6 +348,55 @@ async fn streaming_file_write_consent_denied_returns_resolution() { ); } +#[tokio::test(flavor = "multi_thread")] +async fn disables_tools_when_model_lacks_support() { + let raw_provider = Arc::new(NoToolSupportProvider::new()); + let provider: Arc = raw_provider.clone(); + + let temp_dir = tempdir().expect("temp dir"); + let storage = StorageManager::with_database_path(temp_dir.path().join("owlen-tests.db")) + .await + .expect("storage"); + + let mut config = Config::default(); + config.general.default_provider = "mock-tool-less-provider".into(); + config.general.default_model = Some("tool-less-model".into()); + config.general.enable_streaming = false; + config.privacy.encrypt_local_data = false; + config.privacy.require_consent_per_session = false; + config.mcp.mode = McpMode::LocalOnly; + config + .refresh_mcp_servers(None) + .expect("refresh MCP servers"); + + let ui = Arc::new(NoOpUiController); + let mut session = SessionController::new(provider, config, Arc::new(storage), ui, false, None) + .await + .expect("session controller"); + + let outcome = session + .send_message( + "Please respond without tools.".to_string(), + ChatParameters::default(), + ) + .await + .expect("send_message should succeed"); + + if let SessionOutcome::Complete(response) = outcome { + assert_eq!(response.message.content, "ack"); + } else { + panic!("expected complete outcome when sending prompt"); + } + + let captured = raw_provider + .take_captured() + .expect("provider should capture chat request"); + assert!( + captured.tools.is_none(), + "tools should be disabled when the selected model does not support tools" + ); +} + #[tokio::test] async fn web_tool_timeout_fails_over_to_cached_result() { let primary: Arc = Arc::new(TimeoutClient);