feat(session): disable tools for unsupported models

This commit is contained in:
2025-10-26 01:56:43 +02:00
parent 7daa4f4ebe
commit 9aa8722ec3
2 changed files with 159 additions and 7 deletions

View File

@@ -35,7 +35,7 @@ use crate::{
}; };
use crate::{Error, Result}; use crate::{Error, Result};
use chrono::{DateTime, Utc}; use chrono::{DateTime, Utc};
use log::{info, warn}; use log::{debug, info, warn};
use reqwest::Url; use reqwest::Url;
use serde_json::{Value, json}; use serde_json::{Value, json};
use std::cmp::{max, min}; use std::cmp::{max, min};
@@ -1900,11 +1900,43 @@ impl SessionController {
let streaming = { self.config.lock().await.general.enable_streaming || parameters.stream }; let streaming = { self.config.lock().await.general.enable_streaming || parameters.stream };
parameters.stream = streaming; parameters.stream = streaming;
let tools = if !self.tool_registry.all().is_empty() { let active_model = self.conversation.active().model.clone();
let registry_tools = self.tool_registry.all();
let mut include_tools = !registry_tools.is_empty();
if include_tools {
let cached_support = self.model_manager.select(&active_model).await;
let supports_tools = match cached_support {
Some(info) => info.supports_tools,
None => match self.models(false).await {
Ok(models) => models
.iter()
.find(|model| model.id == active_model || model.name == active_model)
.map(|model| model.supports_tools)
.unwrap_or(true),
Err(err) => {
warn!(
"Unable to resolve tool support for model '{}': {}. Assuming tools are supported.",
active_model, err
);
true
}
},
};
if !supports_tools {
include_tools = false;
debug!(
"Disabling tools for model '{}' because it does not advertise tool support.",
active_model
);
}
}
let tools = if include_tools {
Some( Some(
self.tool_registry registry_tools
.all() .iter()
.into_iter()
.map(|tool| crate::mcp::McpToolDescriptor { .map(|tool| crate::mcp::McpToolDescriptor {
name: tool.name().to_string(), name: tool.name().to_string(),
description: tool.description().to_string(), description: tool.description().to_string(),
@@ -1919,7 +1951,7 @@ impl SessionController {
}; };
let mut request = ChatRequest { let mut request = ChatRequest {
model: self.conversation.active().model.clone(), model: active_model,
messages: self.conversation.active().messages.clone(), messages: self.conversation.active().messages.clone(),
parameters: parameters.clone(), parameters: parameters.clone(),
tools: tools.clone(), tools: tools.clone(),

View File

@@ -1,5 +1,10 @@
use std::{any::Any, collections::HashMap, sync::Arc}; use std::{
any::Any,
collections::HashMap,
sync::{Arc, Mutex},
};
use anyhow::anyhow;
use async_trait::async_trait; use async_trait::async_trait;
use futures::StreamExt; use futures::StreamExt;
use owlen_core::tools::{WEB_SEARCH_TOOL_NAME, tool_name_matches}; use owlen_core::tools::{WEB_SEARCH_TOOL_NAME, tool_name_matches};
@@ -87,6 +92,72 @@ impl Provider for StreamingToolProvider {
} }
} }
struct NoToolSupportProvider {
captured: Arc<Mutex<Option<ChatRequest>>>,
}
impl NoToolSupportProvider {
fn new() -> Self {
Self {
captured: Arc::new(Mutex::new(None)),
}
}
fn take_captured(&self) -> Option<ChatRequest> {
self.captured.lock().expect("capture mutex").take()
}
}
#[async_trait]
impl Provider for NoToolSupportProvider {
fn name(&self) -> &str {
"mock-tool-less-provider"
}
async fn list_models(&self) -> owlen_core::Result<Vec<ModelInfo>> {
Ok(vec![ModelInfo {
id: "tool-less-model".into(),
name: "Toolless Model".into(),
description: Some("A model without tool support.".into()),
provider: self.name().into(),
context_window: Some(4096),
capabilities: vec!["chat".into()],
supports_tools: false,
}])
}
async fn send_prompt(&self, request: ChatRequest) -> owlen_core::Result<ChatResponse> {
{
let mut guard = self.captured.lock().expect("capture mutex");
*guard = Some(request);
}
Ok(ChatResponse {
message: Message::assistant("ack".to_string()),
usage: None,
is_streaming: false,
is_final: true,
})
}
async fn stream_prompt(
&self,
_request: ChatRequest,
) -> owlen_core::Result<owlen_core::ChatStream> {
Err(Error::Provider(anyhow!(
"streaming disabled for mock provider"
)))
}
async fn health_check(&self) -> owlen_core::Result<()> {
Ok(())
}
fn as_any(&self) -> &(dyn Any + Send + Sync) {
self
}
}
fn tool_descriptor() -> McpToolDescriptor { fn tool_descriptor() -> McpToolDescriptor {
McpToolDescriptor { McpToolDescriptor {
name: WEB_SEARCH_TOOL_NAME.to_string(), name: WEB_SEARCH_TOOL_NAME.to_string(),
@@ -277,6 +348,55 @@ async fn streaming_file_write_consent_denied_returns_resolution() {
); );
} }
#[tokio::test(flavor = "multi_thread")]
async fn disables_tools_when_model_lacks_support() {
let raw_provider = Arc::new(NoToolSupportProvider::new());
let provider: Arc<dyn Provider> = raw_provider.clone();
let temp_dir = tempdir().expect("temp dir");
let storage = StorageManager::with_database_path(temp_dir.path().join("owlen-tests.db"))
.await
.expect("storage");
let mut config = Config::default();
config.general.default_provider = "mock-tool-less-provider".into();
config.general.default_model = Some("tool-less-model".into());
config.general.enable_streaming = false;
config.privacy.encrypt_local_data = false;
config.privacy.require_consent_per_session = false;
config.mcp.mode = McpMode::LocalOnly;
config
.refresh_mcp_servers(None)
.expect("refresh MCP servers");
let ui = Arc::new(NoOpUiController);
let mut session = SessionController::new(provider, config, Arc::new(storage), ui, false, None)
.await
.expect("session controller");
let outcome = session
.send_message(
"Please respond without tools.".to_string(),
ChatParameters::default(),
)
.await
.expect("send_message should succeed");
if let SessionOutcome::Complete(response) = outcome {
assert_eq!(response.message.content, "ack");
} else {
panic!("expected complete outcome when sending prompt");
}
let captured = raw_provider
.take_captured()
.expect("provider should capture chat request");
assert!(
captured.tools.is_none(),
"tools should be disabled when the selected model does not support tools"
);
}
#[tokio::test] #[tokio::test]
async fn web_tool_timeout_fails_over_to_cached_result() { async fn web_tool_timeout_fails_over_to_cached_result() {
let primary: Arc<dyn McpClient> = Arc::new(TimeoutClient); let primary: Arc<dyn McpClient> = Arc::new(TimeoutClient);