feat(provider/ollama): enable tool calls and enrich metadata
Acceptance Criteria:
- tool descriptors from MCP are forwarded to Ollama chat requests
- models advertise tool support when metadata or heuristics imply function calling
- chat responses include provider metadata with final token metrics

Test Notes:
- cargo test -p owlen-core providers::ollama::tests::prepare_chat_request_serializes_tool_descriptors
- cargo test -p owlen-core providers::ollama::tests::convert_model_marks_tool_capability
- cargo test -p owlen-core providers::ollama::tests::convert_response_attaches_provider_metadata
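Implementation sketch (illustrative only; mirrors the convert_tool_descriptors
mapping in the diff below, using a made-up "web.search" descriptor):

    // How one MCP tool descriptor becomes an Ollama tool payload.
    // The descriptor values here are hypothetical; the JSON shape matches
    // the json! literal built by convert_tool_descriptors.
    use serde_json::json;

    let payload = json!({
        "type": "Function",
        "function": {
            "name": "web.search",
            "description": "Perform a web search",
            "parameters": {
                "type": "object",
                "properties": { "query": { "type": "string" } },
                "required": ["query"]
            }
        }
    });
    // serde_json::from_value::<OllamaToolInfo>(payload) then yields the
    // entry pushed onto request.tools.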
@@ -20,7 +20,10 @@ use ollama_rs::{
         ChatMessage as OllamaMessage, ChatMessageResponse as OllamaChatResponse,
         MessageRole as OllamaRole, request::ChatMessageRequest as OllamaChatRequest,
     },
-    generation::tools::{ToolCall as OllamaToolCall, ToolCallFunction as OllamaToolCallFunction},
+    generation::tools::{
+        ToolCall as OllamaToolCall, ToolCallFunction as OllamaToolCallFunction,
+        ToolInfo as OllamaToolInfo,
+    },
     headers::{AUTHORIZATION, HeaderMap, HeaderValue},
     models::{LocalModel, ModelInfo as OllamaModelInfo, ModelOptions},
 };
@@ -843,15 +846,6 @@ impl OllamaProvider {
             );
         }

-        if let Some(descriptors) = &tools
-            && !descriptors.is_empty()
-        {
-            debug!(
-                "Ignoring {} MCP tool descriptors for Ollama request (tool calling unsupported)",
-                descriptors.len()
-            );
-        }
-
         let converted_messages = messages.into_iter().map(convert_message).collect();
         let mut request = OllamaChatRequest::new(model.clone(), converted_messages);

@@ -859,6 +853,13 @@ impl OllamaProvider {
             request.options = Some(options);
         }

+        if let Some(tool_descriptors) = tools.as_ref() {
+            let tool_infos = convert_tool_descriptors(tool_descriptors)?;
+            if !tool_infos.is_empty() {
+                request.tools = tool_infos;
+            }
+        }
+
         Ok((model, request))
     }

@@ -1272,6 +1273,8 @@ impl OllamaProvider {
                 .and_then(|raw| u32::try_from(raw).ok())
         });

+        let supports_tools = model_supports_tools(&name, &capabilities, detail.as_ref());
+
         ModelInfo {
             id: name.clone(),
             name,
@@ -1279,12 +1282,20 @@ impl OllamaProvider {
             provider: self.provider_name.clone(),
             context_window,
             capabilities,
-            supports_tools: false,
+            supports_tools,
         }
     }

     fn convert_ollama_response(response: OllamaChatResponse, streaming: bool) -> ChatResponse {
-        let usage = response.final_data.as_ref().map(|data| {
+        let OllamaChatResponse {
+            model,
+            created_at,
+            message,
+            done,
+            final_data,
+        } = response;
+
+        let usage = final_data.as_ref().map(|data| {
             let prompt = clamp_to_u32(data.prompt_eval_count);
             let completion = clamp_to_u32(data.eval_count);
             TokenUsage {
@@ -1294,11 +1305,27 @@ impl OllamaProvider {
             }
         });

+        let mut message = convert_ollama_message(message);
+
+        let mut provider_meta = JsonMap::new();
+        provider_meta.insert("model".into(), Value::String(model));
+        provider_meta.insert("created_at".into(), Value::String(created_at));
+
+        if let Some(ref final_block) = final_data {
+            if let Ok(value) = serde_json::to_value(final_block) {
+                provider_meta.insert("final_data".into(), value);
+            }
+        }
+
+        message
+            .metadata
+            .insert("ollama".into(), Value::Object(provider_meta));
+
         ChatResponse {
-            message: convert_ollama_message(response.message),
+            message,
             usage,
             is_streaming: streaming,
-            is_final: if streaming { response.done } else { true },
+            is_final: if streaming { done } else { true },
         }
     }

@@ -1509,6 +1536,29 @@ fn build_model_options(parameters: &ChatParameters) -> Result<Option<ModelOption
         .map_err(|err| Error::Config(format!("Invalid Ollama options: {err}")))
 }

+fn convert_tool_descriptors(descriptors: &[McpToolDescriptor]) -> Result<Vec<OllamaToolInfo>> {
+    descriptors
+        .iter()
+        .map(|descriptor| {
+            let payload = json!({
+                "type": "Function",
+                "function": {
+                    "name": descriptor.name,
+                    "description": descriptor.description,
+                    "parameters": descriptor.input_schema
+                }
+            });
+
+            serde_json::from_value(payload).map_err(|err| {
+                Error::Config(format!(
+                    "Invalid tool schema for '{}': {err}",
+                    descriptor.name
+                ))
+            })
+        })
+        .collect()
+}
+
 fn convert_message(message: Message) -> OllamaMessage {
     let Message {
         role,
@@ -1629,6 +1679,42 @@ fn heuristic_capabilities(name: &str) -> Vec<String> {
     detected
 }

+fn capability_implies_tools(label: &str) -> bool {
+    let normalized = label.to_ascii_lowercase();
+    normalized.contains("tool")
+        || normalized.contains("function_call")
+        || normalized.contains("function-call")
+        || normalized.contains("tool_call")
+}
+
+fn model_supports_tools(
+    name: &str,
+    capabilities: &[String],
+    detail: Option<&OllamaModelInfo>,
+) -> bool {
+    if let Some(info) = detail {
+        if info
+            .capabilities
+            .iter()
+            .any(|capability| capability_implies_tools(capability))
+        {
+            return true;
+        }
+    }
+
+    if capabilities
+        .iter()
+        .any(|capability| capability_implies_tools(capability))
+    {
+        return true;
+    }
+
+    let lowered = name.to_ascii_lowercase();
+    ["functioncall", "function-call", "function_call", "tool"]
+        .iter()
+        .any(|needle| lowered.contains(needle))
+}
+
 fn build_model_description(scope: &str, detail: Option<&OllamaModelInfo>) -> String {
     if let Some(info) = detail {
         let mut parts = Vec::new();
@@ -1942,7 +2028,9 @@ fn build_client_for_base(
 #[cfg(test)]
 mod tests {
     use super::*;
-    use serde_json::{Map as JsonMap, Value};
+    use crate::mcp::McpToolDescriptor;
+    use ollama_rs::generation::chat::ChatMessageFinalResponseData;
+    use serde_json::{Map as JsonMap, Value, json};
     use std::collections::HashMap;

     #[test]
@@ -2050,6 +2138,10 @@ mod tests {

     #[test]
     fn cloud_provider_requires_api_key() {
+        let _primary = EnvVarGuard::clear(OLLAMA_API_KEY_ENV);
+        let _legacy_primary = EnvVarGuard::clear(LEGACY_OLLAMA_CLOUD_API_KEY_ENV);
+        let _legacy_secondary = EnvVarGuard::clear(LEGACY_OWLEN_OLLAMA_CLOUD_API_KEY_ENV);
+
         let config = ProviderConfig {
             enabled: true,
             provider_type: "ollama_cloud".to_string(),
@@ -2275,6 +2367,113 @@ mod tests {
         assert_eq!(serialized["num_ctx"], json!(4096));
     }

+    #[test]
+    fn prepare_chat_request_serializes_tool_descriptors() {
+        let provider = OllamaProvider::new("http://localhost:11434").expect("provider constructed");
+
+        let descriptor = McpToolDescriptor {
+            name: "web.search".to_string(),
+            description: "Perform a web search".to_string(),
+            input_schema: json!({
+                "type": "object",
+                "properties": {
+                    "query": {"type": "string"}
+                },
+                "required": ["query"]
+            }),
+            requires_network: true,
+            requires_filesystem: Vec::new(),
+        };
+
+        let (_model_id, request) = provider
+            .prepare_chat_request(
+                "llama3".to_string(),
+                vec![Message::user("Hello".to_string())],
+                ChatParameters::default(),
+                Some(vec![descriptor.clone()]),
+            )
+            .expect("request built");
+
+        assert_eq!(request.tools.len(), 1);
+        let tool = &request.tools[0];
+        assert_eq!(tool.function.name, descriptor.name);
+        assert_eq!(tool.function.description, descriptor.description);
+
+        let serialized = serde_json::to_value(&tool.function.parameters).expect("serialize schema");
+        assert_eq!(serialized, descriptor.input_schema);
+    }
+
+    #[test]
+    fn convert_model_marks_tool_capability() {
+        let provider = OllamaProvider::new("http://localhost:11434").expect("provider constructed");
+
+        let local = LocalModel {
+            name: "llama3-tool".to_string(),
+            modified_at: "2025-10-23T00:00:00Z".to_string(),
+            size: 0,
+        };
+
+        let detail = OllamaModelInfo {
+            license: String::new(),
+            modelfile: String::new(),
+            parameters: String::new(),
+            template: String::new(),
+            model_info: JsonMap::new(),
+            capabilities: vec!["function_call".to_string()],
+        };
+
+        let info = provider.convert_model(OllamaMode::Local, local, Some(detail));
+        assert!(info.supports_tools);
+    }
+
+    #[test]
+    fn convert_response_attaches_provider_metadata() {
+        let final_data = ChatMessageFinalResponseData {
+            total_duration: 10,
+            load_duration: 2,
+            prompt_eval_count: 42,
+            prompt_eval_duration: 4,
+            eval_count: 21,
+            eval_duration: 6,
+        };
+
+        let response = OllamaChatResponse {
+            model: "llama3".to_string(),
+            created_at: "2025-10-23T18:00:00Z".to_string(),
+            message: OllamaMessage {
+                role: OllamaRole::Assistant,
+                content: "Tool output incoming".to_string(),
+                tool_calls: Vec::new(),
+                images: None,
+                thinking: None,
+            },
+            done: true,
+            final_data: Some(final_data),
+        };
+
+        let chunk = OllamaProvider::convert_ollama_response(response, false);
+
+        let metadata = chunk
+            .message
+            .metadata
+            .get("ollama")
+            .and_then(Value::as_object)
+            .expect("ollama metadata present");
+        assert_eq!(
+            metadata.get("model").and_then(Value::as_str),
+            Some("llama3")
+        );
+        assert!(metadata.contains_key("final_data"));
+        assert_eq!(
+            metadata.get("created_at").and_then(Value::as_str).unwrap(),
+            "2025-10-23T18:00:00Z"
+        );
+
+        let usage = chunk.usage.expect("usage populated");
+        assert_eq!(usage.prompt_tokens, 42);
+        assert_eq!(usage.completion_tokens, 21);
+    }
+
     #[test]
     fn heuristic_capabilities_detects_thinking_models() {
         let caps = heuristic_capabilities("deepseek-r1");
@@ -2318,6 +2517,37 @@ impl Drop for ProbeOverrideGuard {
     }
 }

+#[cfg(test)]
+struct EnvVarGuard {
+    key: &'static str,
+    original: Option<String>,
+}
+
+#[cfg(test)]
+impl EnvVarGuard {
+    fn clear(key: &'static str) -> Self {
+        let original = std::env::var(key).ok();
+        unsafe {
+            std::env::remove_var(key);
+        }
+        Self { key, original }
+    }
+}
+
+#[cfg(test)]
+impl Drop for EnvVarGuard {
+    fn drop(&mut self) {
+        match &self.original {
+            Some(value) => unsafe {
+                std::env::set_var(self.key, value);
+            },
+            None => unsafe {
+                std::env::remove_var(self.key);
+            },
+        }
+    }
+}
+
 #[test]
 fn auto_mode_with_api_key_and_successful_probe_prefers_local() {
     let _guard = ProbeOverrideGuard::set(Some(true));