feat(session): disable tools for unsupported models
This commit is contained in:
@@ -35,7 +35,7 @@ use crate::{
|
||||
};
|
||||
use crate::{Error, Result};
|
||||
use chrono::{DateTime, Utc};
|
||||
use log::{info, warn};
|
||||
use log::{debug, info, warn};
|
||||
use reqwest::Url;
|
||||
use serde_json::{Value, json};
|
||||
use std::cmp::{max, min};
|
||||
@@ -1900,11 +1900,43 @@ impl SessionController {
|
||||
let streaming = { self.config.lock().await.general.enable_streaming || parameters.stream };
|
||||
parameters.stream = streaming;
|
||||
|
||||
let tools = if !self.tool_registry.all().is_empty() {
|
||||
let active_model = self.conversation.active().model.clone();
|
||||
let registry_tools = self.tool_registry.all();
|
||||
let mut include_tools = !registry_tools.is_empty();
|
||||
|
||||
if include_tools {
|
||||
let cached_support = self.model_manager.select(&active_model).await;
|
||||
let supports_tools = match cached_support {
|
||||
Some(info) => info.supports_tools,
|
||||
None => match self.models(false).await {
|
||||
Ok(models) => models
|
||||
.iter()
|
||||
.find(|model| model.id == active_model || model.name == active_model)
|
||||
.map(|model| model.supports_tools)
|
||||
.unwrap_or(true),
|
||||
Err(err) => {
|
||||
warn!(
|
||||
"Unable to resolve tool support for model '{}': {}. Assuming tools are supported.",
|
||||
active_model, err
|
||||
);
|
||||
true
|
||||
}
|
||||
},
|
||||
};
|
||||
|
||||
if !supports_tools {
|
||||
include_tools = false;
|
||||
debug!(
|
||||
"Disabling tools for model '{}' because it does not advertise tool support.",
|
||||
active_model
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
let tools = if include_tools {
|
||||
Some(
|
||||
self.tool_registry
|
||||
.all()
|
||||
.into_iter()
|
||||
registry_tools
|
||||
.iter()
|
||||
.map(|tool| crate::mcp::McpToolDescriptor {
|
||||
name: tool.name().to_string(),
|
||||
description: tool.description().to_string(),
|
||||
@@ -1919,7 +1951,7 @@ impl SessionController {
|
||||
};
|
||||
|
||||
let mut request = ChatRequest {
|
||||
model: self.conversation.active().model.clone(),
|
||||
model: active_model,
|
||||
messages: self.conversation.active().messages.clone(),
|
||||
parameters: parameters.clone(),
|
||||
tools: tools.clone(),
|
||||
|
||||
@@ -1,5 +1,10 @@
|
||||
use std::{any::Any, collections::HashMap, sync::Arc};
|
||||
use std::{
|
||||
any::Any,
|
||||
collections::HashMap,
|
||||
sync::{Arc, Mutex},
|
||||
};
|
||||
|
||||
use anyhow::anyhow;
|
||||
use async_trait::async_trait;
|
||||
use futures::StreamExt;
|
||||
use owlen_core::tools::{WEB_SEARCH_TOOL_NAME, tool_name_matches};
|
||||
@@ -87,6 +92,72 @@ impl Provider for StreamingToolProvider {
|
||||
}
|
||||
}
|
||||
|
||||
struct NoToolSupportProvider {
|
||||
captured: Arc<Mutex<Option<ChatRequest>>>,
|
||||
}
|
||||
|
||||
impl NoToolSupportProvider {
|
||||
fn new() -> Self {
|
||||
Self {
|
||||
captured: Arc::new(Mutex::new(None)),
|
||||
}
|
||||
}
|
||||
|
||||
fn take_captured(&self) -> Option<ChatRequest> {
|
||||
self.captured.lock().expect("capture mutex").take()
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl Provider for NoToolSupportProvider {
|
||||
fn name(&self) -> &str {
|
||||
"mock-tool-less-provider"
|
||||
}
|
||||
|
||||
async fn list_models(&self) -> owlen_core::Result<Vec<ModelInfo>> {
|
||||
Ok(vec![ModelInfo {
|
||||
id: "tool-less-model".into(),
|
||||
name: "Toolless Model".into(),
|
||||
description: Some("A model without tool support.".into()),
|
||||
provider: self.name().into(),
|
||||
context_window: Some(4096),
|
||||
capabilities: vec!["chat".into()],
|
||||
supports_tools: false,
|
||||
}])
|
||||
}
|
||||
|
||||
async fn send_prompt(&self, request: ChatRequest) -> owlen_core::Result<ChatResponse> {
|
||||
{
|
||||
let mut guard = self.captured.lock().expect("capture mutex");
|
||||
*guard = Some(request);
|
||||
}
|
||||
|
||||
Ok(ChatResponse {
|
||||
message: Message::assistant("ack".to_string()),
|
||||
usage: None,
|
||||
is_streaming: false,
|
||||
is_final: true,
|
||||
})
|
||||
}
|
||||
|
||||
async fn stream_prompt(
|
||||
&self,
|
||||
_request: ChatRequest,
|
||||
) -> owlen_core::Result<owlen_core::ChatStream> {
|
||||
Err(Error::Provider(anyhow!(
|
||||
"streaming disabled for mock provider"
|
||||
)))
|
||||
}
|
||||
|
||||
async fn health_check(&self) -> owlen_core::Result<()> {
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn as_any(&self) -> &(dyn Any + Send + Sync) {
|
||||
self
|
||||
}
|
||||
}
|
||||
|
||||
fn tool_descriptor() -> McpToolDescriptor {
|
||||
McpToolDescriptor {
|
||||
name: WEB_SEARCH_TOOL_NAME.to_string(),
|
||||
@@ -277,6 +348,55 @@ async fn streaming_file_write_consent_denied_returns_resolution() {
|
||||
);
|
||||
}
|
||||
|
||||
#[tokio::test(flavor = "multi_thread")]
|
||||
async fn disables_tools_when_model_lacks_support() {
|
||||
let raw_provider = Arc::new(NoToolSupportProvider::new());
|
||||
let provider: Arc<dyn Provider> = raw_provider.clone();
|
||||
|
||||
let temp_dir = tempdir().expect("temp dir");
|
||||
let storage = StorageManager::with_database_path(temp_dir.path().join("owlen-tests.db"))
|
||||
.await
|
||||
.expect("storage");
|
||||
|
||||
let mut config = Config::default();
|
||||
config.general.default_provider = "mock-tool-less-provider".into();
|
||||
config.general.default_model = Some("tool-less-model".into());
|
||||
config.general.enable_streaming = false;
|
||||
config.privacy.encrypt_local_data = false;
|
||||
config.privacy.require_consent_per_session = false;
|
||||
config.mcp.mode = McpMode::LocalOnly;
|
||||
config
|
||||
.refresh_mcp_servers(None)
|
||||
.expect("refresh MCP servers");
|
||||
|
||||
let ui = Arc::new(NoOpUiController);
|
||||
let mut session = SessionController::new(provider, config, Arc::new(storage), ui, false, None)
|
||||
.await
|
||||
.expect("session controller");
|
||||
|
||||
let outcome = session
|
||||
.send_message(
|
||||
"Please respond without tools.".to_string(),
|
||||
ChatParameters::default(),
|
||||
)
|
||||
.await
|
||||
.expect("send_message should succeed");
|
||||
|
||||
if let SessionOutcome::Complete(response) = outcome {
|
||||
assert_eq!(response.message.content, "ack");
|
||||
} else {
|
||||
panic!("expected complete outcome when sending prompt");
|
||||
}
|
||||
|
||||
let captured = raw_provider
|
||||
.take_captured()
|
||||
.expect("provider should capture chat request");
|
||||
assert!(
|
||||
captured.tools.is_none(),
|
||||
"tools should be disabled when the selected model does not support tools"
|
||||
);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn web_tool_timeout_fails_over_to_cached_result() {
|
||||
let primary: Arc<dyn McpClient> = Arc::new(TimeoutClient);
|
||||
|
||||
Reference in New Issue
Block a user