feat(provider/ollama): enable tool calls and enrich metadata

Acceptance Criteria:\n- tool descriptors from MCP are forwarded to Ollama chat requests\n- models advertise tool support when metadata or heuristics imply function calling\n- chat responses include provider metadata with final token metrics

Test Notes:\n- cargo test -p owlen-core providers::ollama::tests::prepare_chat_request_serializes_tool_descriptors\n- cargo test -p owlen-core providers::ollama::tests::convert_model_marks_tool_capability\n- cargo test -p owlen-core providers::ollama::tests::convert_response_attaches_provider_metadata
This commit is contained in:
2025-10-23 20:12:12 +02:00
parent e0b14a42f2
commit 24671f5f2a

View File

@@ -20,7 +20,10 @@ use ollama_rs::{
ChatMessage as OllamaMessage, ChatMessageResponse as OllamaChatResponse,
MessageRole as OllamaRole, request::ChatMessageRequest as OllamaChatRequest,
},
generation::tools::{ToolCall as OllamaToolCall, ToolCallFunction as OllamaToolCallFunction},
generation::tools::{
ToolCall as OllamaToolCall, ToolCallFunction as OllamaToolCallFunction,
ToolInfo as OllamaToolInfo,
},
headers::{AUTHORIZATION, HeaderMap, HeaderValue},
models::{LocalModel, ModelInfo as OllamaModelInfo, ModelOptions},
};
@@ -843,15 +846,6 @@ impl OllamaProvider {
);
}
if let Some(descriptors) = &tools
&& !descriptors.is_empty()
{
debug!(
"Ignoring {} MCP tool descriptors for Ollama request (tool calling unsupported)",
descriptors.len()
);
}
let converted_messages = messages.into_iter().map(convert_message).collect();
let mut request = OllamaChatRequest::new(model.clone(), converted_messages);
@@ -859,6 +853,13 @@ impl OllamaProvider {
request.options = Some(options);
}
if let Some(tool_descriptors) = tools.as_ref() {
let tool_infos = convert_tool_descriptors(tool_descriptors)?;
if !tool_infos.is_empty() {
request.tools = tool_infos;
}
}
Ok((model, request))
}
@@ -1272,6 +1273,8 @@ impl OllamaProvider {
.and_then(|raw| u32::try_from(raw).ok())
});
let supports_tools = model_supports_tools(&name, &capabilities, detail.as_ref());
ModelInfo {
id: name.clone(),
name,
@@ -1279,12 +1282,20 @@ impl OllamaProvider {
provider: self.provider_name.clone(),
context_window,
capabilities,
supports_tools: false,
supports_tools,
}
}
fn convert_ollama_response(response: OllamaChatResponse, streaming: bool) -> ChatResponse {
let usage = response.final_data.as_ref().map(|data| {
let OllamaChatResponse {
model,
created_at,
message,
done,
final_data,
} = response;
let usage = final_data.as_ref().map(|data| {
let prompt = clamp_to_u32(data.prompt_eval_count);
let completion = clamp_to_u32(data.eval_count);
TokenUsage {
@@ -1294,11 +1305,27 @@ impl OllamaProvider {
}
});
let mut message = convert_ollama_message(message);
let mut provider_meta = JsonMap::new();
provider_meta.insert("model".into(), Value::String(model));
provider_meta.insert("created_at".into(), Value::String(created_at));
if let Some(ref final_block) = final_data {
if let Ok(value) = serde_json::to_value(final_block) {
provider_meta.insert("final_data".into(), value);
}
}
message
.metadata
.insert("ollama".into(), Value::Object(provider_meta));
ChatResponse {
message: convert_ollama_message(response.message),
message,
usage,
is_streaming: streaming,
is_final: if streaming { response.done } else { true },
is_final: if streaming { done } else { true },
}
}
@@ -1509,6 +1536,29 @@ fn build_model_options(parameters: &ChatParameters) -> Result<Option<ModelOption
.map_err(|err| Error::Config(format!("Invalid Ollama options: {err}")))
}
fn convert_tool_descriptors(descriptors: &[McpToolDescriptor]) -> Result<Vec<OllamaToolInfo>> {
descriptors
.iter()
.map(|descriptor| {
let payload = json!({
"type": "Function",
"function": {
"name": descriptor.name,
"description": descriptor.description,
"parameters": descriptor.input_schema
}
});
serde_json::from_value(payload).map_err(|err| {
Error::Config(format!(
"Invalid tool schema for '{}': {err}",
descriptor.name
))
})
})
.collect()
}
fn convert_message(message: Message) -> OllamaMessage {
let Message {
role,
@@ -1629,6 +1679,42 @@ fn heuristic_capabilities(name: &str) -> Vec<String> {
detected
}
fn capability_implies_tools(label: &str) -> bool {
let normalized = label.to_ascii_lowercase();
normalized.contains("tool")
|| normalized.contains("function_call")
|| normalized.contains("function-call")
|| normalized.contains("tool_call")
}
fn model_supports_tools(
name: &str,
capabilities: &[String],
detail: Option<&OllamaModelInfo>,
) -> bool {
if let Some(info) = detail {
if info
.capabilities
.iter()
.any(|capability| capability_implies_tools(capability))
{
return true;
}
}
if capabilities
.iter()
.any(|capability| capability_implies_tools(capability))
{
return true;
}
let lowered = name.to_ascii_lowercase();
["functioncall", "function-call", "function_call", "tool"]
.iter()
.any(|needle| lowered.contains(needle))
}
fn build_model_description(scope: &str, detail: Option<&OllamaModelInfo>) -> String {
if let Some(info) = detail {
let mut parts = Vec::new();
@@ -1942,7 +2028,9 @@ fn build_client_for_base(
#[cfg(test)]
mod tests {
use super::*;
use serde_json::{Map as JsonMap, Value};
use crate::mcp::McpToolDescriptor;
use ollama_rs::generation::chat::ChatMessageFinalResponseData;
use serde_json::{Map as JsonMap, Value, json};
use std::collections::HashMap;
#[test]
@@ -2050,6 +2138,10 @@ mod tests {
#[test]
fn cloud_provider_requires_api_key() {
let _primary = EnvVarGuard::clear(OLLAMA_API_KEY_ENV);
let _legacy_primary = EnvVarGuard::clear(LEGACY_OLLAMA_CLOUD_API_KEY_ENV);
let _legacy_secondary = EnvVarGuard::clear(LEGACY_OWLEN_OLLAMA_CLOUD_API_KEY_ENV);
let config = ProviderConfig {
enabled: true,
provider_type: "ollama_cloud".to_string(),
@@ -2275,6 +2367,113 @@ mod tests {
assert_eq!(serialized["num_ctx"], json!(4096));
}
#[test]
fn prepare_chat_request_serializes_tool_descriptors() {
let provider = OllamaProvider::new("http://localhost:11434").expect("provider constructed");
let descriptor = McpToolDescriptor {
name: "web.search".to_string(),
description: "Perform a web search".to_string(),
input_schema: json!({
"type": "object",
"properties": {
"query": {"type": "string"}
},
"required": ["query"]
}),
requires_network: true,
requires_filesystem: Vec::new(),
};
let (_model_id, request) = provider
.prepare_chat_request(
"llama3".to_string(),
vec![Message::user("Hello".to_string())],
ChatParameters::default(),
Some(vec![descriptor.clone()]),
)
.expect("request built");
assert_eq!(request.tools.len(), 1);
let tool = &request.tools[0];
assert_eq!(tool.function.name, descriptor.name);
assert_eq!(tool.function.description, descriptor.description);
let serialized = serde_json::to_value(&tool.function.parameters).expect("serialize schema");
assert_eq!(serialized, descriptor.input_schema);
}
#[test]
fn convert_model_marks_tool_capability() {
let provider = OllamaProvider::new("http://localhost:11434").expect("provider constructed");
let local = LocalModel {
name: "llama3-tool".to_string(),
modified_at: "2025-10-23T00:00:00Z".to_string(),
size: 0,
};
let detail = OllamaModelInfo {
license: String::new(),
modelfile: String::new(),
parameters: String::new(),
template: String::new(),
model_info: JsonMap::new(),
capabilities: vec!["function_call".to_string()],
};
let info = provider.convert_model(OllamaMode::Local, local, Some(detail));
assert!(info.supports_tools);
}
#[test]
fn convert_response_attaches_provider_metadata() {
let final_data = ChatMessageFinalResponseData {
total_duration: 10,
load_duration: 2,
prompt_eval_count: 42,
prompt_eval_duration: 4,
eval_count: 21,
eval_duration: 6,
};
let response = OllamaChatResponse {
model: "llama3".to_string(),
created_at: "2025-10-23T18:00:00Z".to_string(),
message: OllamaMessage {
role: OllamaRole::Assistant,
content: "Tool output incoming".to_string(),
tool_calls: Vec::new(),
images: None,
thinking: None,
},
done: true,
final_data: Some(final_data),
};
let chunk = OllamaProvider::convert_ollama_response(response, false);
let metadata = chunk
.message
.metadata
.get("ollama")
.and_then(Value::as_object)
.expect("ollama metadata present");
assert_eq!(
metadata.get("model").and_then(Value::as_str),
Some("llama3")
);
assert!(metadata.contains_key("final_data"));
assert_eq!(
metadata.get("created_at").and_then(Value::as_str).unwrap(),
"2025-10-23T18:00:00Z"
);
let usage = chunk.usage.expect("usage populated");
assert_eq!(usage.prompt_tokens, 42);
assert_eq!(usage.completion_tokens, 21);
}
#[test]
fn heuristic_capabilities_detects_thinking_models() {
let caps = heuristic_capabilities("deepseek-r1");
@@ -2318,6 +2517,37 @@ impl Drop for ProbeOverrideGuard {
}
}
#[cfg(test)]
struct EnvVarGuard {
key: &'static str,
original: Option<String>,
}
#[cfg(test)]
impl EnvVarGuard {
fn clear(key: &'static str) -> Self {
let original = std::env::var(key).ok();
unsafe {
std::env::remove_var(key);
}
Self { key, original }
}
}
#[cfg(test)]
impl Drop for EnvVarGuard {
fn drop(&mut self) {
match &self.original {
Some(value) => unsafe {
std::env::set_var(self.key, value);
},
None => unsafe {
std::env::remove_var(self.key);
},
}
}
}
#[test]
fn auto_mode_with_api_key_and_successful_probe_prefers_local() {
let _guard = ProbeOverrideGuard::set(Some(true));