feat(session): disable tools for unsupported models
This commit is contained in:
@@ -35,7 +35,7 @@ use crate::{
|
|||||||
};
|
};
|
||||||
use crate::{Error, Result};
|
use crate::{Error, Result};
|
||||||
use chrono::{DateTime, Utc};
|
use chrono::{DateTime, Utc};
|
||||||
use log::{info, warn};
|
use log::{debug, info, warn};
|
||||||
use reqwest::Url;
|
use reqwest::Url;
|
||||||
use serde_json::{Value, json};
|
use serde_json::{Value, json};
|
||||||
use std::cmp::{max, min};
|
use std::cmp::{max, min};
|
||||||
@@ -1900,11 +1900,43 @@ impl SessionController {
|
|||||||
let streaming = { self.config.lock().await.general.enable_streaming || parameters.stream };
|
let streaming = { self.config.lock().await.general.enable_streaming || parameters.stream };
|
||||||
parameters.stream = streaming;
|
parameters.stream = streaming;
|
||||||
|
|
||||||
let tools = if !self.tool_registry.all().is_empty() {
|
let active_model = self.conversation.active().model.clone();
|
||||||
|
let registry_tools = self.tool_registry.all();
|
||||||
|
let mut include_tools = !registry_tools.is_empty();
|
||||||
|
|
||||||
|
if include_tools {
|
||||||
|
let cached_support = self.model_manager.select(&active_model).await;
|
||||||
|
let supports_tools = match cached_support {
|
||||||
|
Some(info) => info.supports_tools,
|
||||||
|
None => match self.models(false).await {
|
||||||
|
Ok(models) => models
|
||||||
|
.iter()
|
||||||
|
.find(|model| model.id == active_model || model.name == active_model)
|
||||||
|
.map(|model| model.supports_tools)
|
||||||
|
.unwrap_or(true),
|
||||||
|
Err(err) => {
|
||||||
|
warn!(
|
||||||
|
"Unable to resolve tool support for model '{}': {}. Assuming tools are supported.",
|
||||||
|
active_model, err
|
||||||
|
);
|
||||||
|
true
|
||||||
|
}
|
||||||
|
},
|
||||||
|
};
|
||||||
|
|
||||||
|
if !supports_tools {
|
||||||
|
include_tools = false;
|
||||||
|
debug!(
|
||||||
|
"Disabling tools for model '{}' because it does not advertise tool support.",
|
||||||
|
active_model
|
||||||
|
);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
let tools = if include_tools {
|
||||||
Some(
|
Some(
|
||||||
self.tool_registry
|
registry_tools
|
||||||
.all()
|
.iter()
|
||||||
.into_iter()
|
|
||||||
.map(|tool| crate::mcp::McpToolDescriptor {
|
.map(|tool| crate::mcp::McpToolDescriptor {
|
||||||
name: tool.name().to_string(),
|
name: tool.name().to_string(),
|
||||||
description: tool.description().to_string(),
|
description: tool.description().to_string(),
|
||||||
@@ -1919,7 +1951,7 @@ impl SessionController {
|
|||||||
};
|
};
|
||||||
|
|
||||||
let mut request = ChatRequest {
|
let mut request = ChatRequest {
|
||||||
model: self.conversation.active().model.clone(),
|
model: active_model,
|
||||||
messages: self.conversation.active().messages.clone(),
|
messages: self.conversation.active().messages.clone(),
|
||||||
parameters: parameters.clone(),
|
parameters: parameters.clone(),
|
||||||
tools: tools.clone(),
|
tools: tools.clone(),
|
||||||
|
|||||||
@@ -1,5 +1,10 @@
|
|||||||
use std::{any::Any, collections::HashMap, sync::Arc};
|
use std::{
|
||||||
|
any::Any,
|
||||||
|
collections::HashMap,
|
||||||
|
sync::{Arc, Mutex},
|
||||||
|
};
|
||||||
|
|
||||||
|
use anyhow::anyhow;
|
||||||
use async_trait::async_trait;
|
use async_trait::async_trait;
|
||||||
use futures::StreamExt;
|
use futures::StreamExt;
|
||||||
use owlen_core::tools::{WEB_SEARCH_TOOL_NAME, tool_name_matches};
|
use owlen_core::tools::{WEB_SEARCH_TOOL_NAME, tool_name_matches};
|
||||||
@@ -87,6 +92,72 @@ impl Provider for StreamingToolProvider {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
struct NoToolSupportProvider {
|
||||||
|
captured: Arc<Mutex<Option<ChatRequest>>>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl NoToolSupportProvider {
|
||||||
|
fn new() -> Self {
|
||||||
|
Self {
|
||||||
|
captured: Arc::new(Mutex::new(None)),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn take_captured(&self) -> Option<ChatRequest> {
|
||||||
|
self.captured.lock().expect("capture mutex").take()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[async_trait]
|
||||||
|
impl Provider for NoToolSupportProvider {
|
||||||
|
fn name(&self) -> &str {
|
||||||
|
"mock-tool-less-provider"
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn list_models(&self) -> owlen_core::Result<Vec<ModelInfo>> {
|
||||||
|
Ok(vec![ModelInfo {
|
||||||
|
id: "tool-less-model".into(),
|
||||||
|
name: "Toolless Model".into(),
|
||||||
|
description: Some("A model without tool support.".into()),
|
||||||
|
provider: self.name().into(),
|
||||||
|
context_window: Some(4096),
|
||||||
|
capabilities: vec!["chat".into()],
|
||||||
|
supports_tools: false,
|
||||||
|
}])
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn send_prompt(&self, request: ChatRequest) -> owlen_core::Result<ChatResponse> {
|
||||||
|
{
|
||||||
|
let mut guard = self.captured.lock().expect("capture mutex");
|
||||||
|
*guard = Some(request);
|
||||||
|
}
|
||||||
|
|
||||||
|
Ok(ChatResponse {
|
||||||
|
message: Message::assistant("ack".to_string()),
|
||||||
|
usage: None,
|
||||||
|
is_streaming: false,
|
||||||
|
is_final: true,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn stream_prompt(
|
||||||
|
&self,
|
||||||
|
_request: ChatRequest,
|
||||||
|
) -> owlen_core::Result<owlen_core::ChatStream> {
|
||||||
|
Err(Error::Provider(anyhow!(
|
||||||
|
"streaming disabled for mock provider"
|
||||||
|
)))
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn health_check(&self) -> owlen_core::Result<()> {
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
fn as_any(&self) -> &(dyn Any + Send + Sync) {
|
||||||
|
self
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
fn tool_descriptor() -> McpToolDescriptor {
|
fn tool_descriptor() -> McpToolDescriptor {
|
||||||
McpToolDescriptor {
|
McpToolDescriptor {
|
||||||
name: WEB_SEARCH_TOOL_NAME.to_string(),
|
name: WEB_SEARCH_TOOL_NAME.to_string(),
|
||||||
@@ -277,6 +348,55 @@ async fn streaming_file_write_consent_denied_returns_resolution() {
|
|||||||
);
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[tokio::test(flavor = "multi_thread")]
|
||||||
|
async fn disables_tools_when_model_lacks_support() {
|
||||||
|
let raw_provider = Arc::new(NoToolSupportProvider::new());
|
||||||
|
let provider: Arc<dyn Provider> = raw_provider.clone();
|
||||||
|
|
||||||
|
let temp_dir = tempdir().expect("temp dir");
|
||||||
|
let storage = StorageManager::with_database_path(temp_dir.path().join("owlen-tests.db"))
|
||||||
|
.await
|
||||||
|
.expect("storage");
|
||||||
|
|
||||||
|
let mut config = Config::default();
|
||||||
|
config.general.default_provider = "mock-tool-less-provider".into();
|
||||||
|
config.general.default_model = Some("tool-less-model".into());
|
||||||
|
config.general.enable_streaming = false;
|
||||||
|
config.privacy.encrypt_local_data = false;
|
||||||
|
config.privacy.require_consent_per_session = false;
|
||||||
|
config.mcp.mode = McpMode::LocalOnly;
|
||||||
|
config
|
||||||
|
.refresh_mcp_servers(None)
|
||||||
|
.expect("refresh MCP servers");
|
||||||
|
|
||||||
|
let ui = Arc::new(NoOpUiController);
|
||||||
|
let mut session = SessionController::new(provider, config, Arc::new(storage), ui, false, None)
|
||||||
|
.await
|
||||||
|
.expect("session controller");
|
||||||
|
|
||||||
|
let outcome = session
|
||||||
|
.send_message(
|
||||||
|
"Please respond without tools.".to_string(),
|
||||||
|
ChatParameters::default(),
|
||||||
|
)
|
||||||
|
.await
|
||||||
|
.expect("send_message should succeed");
|
||||||
|
|
||||||
|
if let SessionOutcome::Complete(response) = outcome {
|
||||||
|
assert_eq!(response.message.content, "ack");
|
||||||
|
} else {
|
||||||
|
panic!("expected complete outcome when sending prompt");
|
||||||
|
}
|
||||||
|
|
||||||
|
let captured = raw_provider
|
||||||
|
.take_captured()
|
||||||
|
.expect("provider should capture chat request");
|
||||||
|
assert!(
|
||||||
|
captured.tools.is_none(),
|
||||||
|
"tools should be disabled when the selected model does not support tools"
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn web_tool_timeout_fails_over_to_cached_result() {
|
async fn web_tool_timeout_fails_over_to_cached_result() {
|
||||||
let primary: Arc<dyn McpClient> = Arc::new(TimeoutClient);
|
let primary: Arc<dyn McpClient> = Arc::new(TimeoutClient);
|
||||||
|
|||||||
Reference in New Issue
Block a user