Files
owlen/crates/owlen-core/src/provider.rs
vikingowl c2f5ccea3b feat(model): add rich model metadata, caching, and UI panel for inspection
Introduce `DetailedModelInfo` and `ModelInfoRetrievalError` structs for richer model data.
Add `ModelDetailsCache` with TTL-based storage and an async get/insert/invalidate API (a rough sketch follows below the commit summary).
Extend `OllamaProvider` to fetch, cache, refresh, and list detailed model info.
Expose model‑detail methods in `Session` for on‑demand and bulk retrieval.
Add `ModelInfoPanel` widget to display detailed info with scrolling support.
Update TUI rendering to show the panel, compute viewport height, and render model selector labels with parameters, size, and context length.
Adjust imports and module re‑exports accordingly.
2025-10-12 09:45:16 +02:00
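The cache itself lives outside this file, so the following is only a rough sketch of the TTL-based, async get/insert/invalidate API the summary describes; every name, field, and the `tokio::sync::RwLock` choice is an assumption, not the actual owlen-core code.

```rust
// Hypothetical sketch only: struct names, fields, and TTL handling are assumptions
// based on the commit message, not the real implementation.
use std::collections::HashMap;
use std::time::{Duration, Instant};
use tokio::sync::RwLock;

/// Hypothetical subset of the rich metadata the commit summary mentions.
#[derive(Clone)]
pub struct DetailedModelInfo {
    pub name: String,
    pub parameter_size: Option<String>,
    pub context_length: Option<u64>,
}

/// TTL-based cache sketch: entries older than `ttl` are treated as missing.
pub struct ModelDetailsCache {
    ttl: Duration,
    entries: RwLock<HashMap<String, (Instant, DetailedModelInfo)>>,
}

impl ModelDetailsCache {
    pub fn new(ttl: Duration) -> Self {
        Self {
            ttl,
            entries: RwLock::new(HashMap::new()),
        }
    }

    /// Return a clone of the cached entry if it has not expired yet.
    pub async fn get(&self, model: &str) -> Option<DetailedModelInfo> {
        let entries = self.entries.read().await;
        entries
            .get(model)
            .filter(|(inserted_at, _)| inserted_at.elapsed() < self.ttl)
            .map(|(_, info)| info.clone())
    }

    /// Insert or refresh an entry, resetting its TTL.
    pub async fn insert(&self, model: String, info: DetailedModelInfo) {
        self.entries
            .write()
            .await
            .insert(model, (Instant::now(), info));
    }

    /// Drop a single entry, forcing the next lookup to refetch.
    pub async fn invalidate(&self, model: &str) {
        self.entries.write().await.remove(model);
    }
}
```

With this shape, a provider would try `get` before fetching details from Ollama and call `insert` or `invalidate` as models change; how the real crate wires this into `OllamaProvider` and `Session` is not shown in this file.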


//! Provider traits and registries.
use crate::{types::*, Error, Result};
use anyhow::anyhow;
use futures::{Stream, StreamExt};
use std::any::Any;
use std::future::Future;
use std::pin::Pin;
use std::sync::Arc;
/// A stream of chat responses
pub type ChatStream = Pin<Box<dyn Stream<Item = Result<ChatResponse>> + Send>>;
/// Trait for LLM providers (Ollama, OpenAI, Anthropic, etc.) with zero-cost static dispatch.
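///
/// Generic callers stay monomorphized, so the returned futures and streams are never
/// boxed. A minimal usage sketch (the `ask` helper is illustrative, not part of this
/// crate), marked `ignore` because it only shows the shape of a call:
///
/// ```ignore
/// async fn ask<P: LLMProvider>(provider: &P, request: ChatRequest) -> Result<ChatResponse> {
///     // Statically dispatched call: P's concrete ChatFuture is awaited directly.
///     provider.chat(request).await
/// }
/// ```
///
/// See `MockProvider` in the test utilities below for a full implementation of the
/// associated types.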
pub trait LLMProvider: Send + Sync + 'static + Any + Sized {
    type Stream: Stream<Item = Result<ChatResponse>> + Send + 'static;
    type ListModelsFuture<'a>: Future<Output = Result<Vec<ModelInfo>>> + Send
    where
        Self: 'a;
    type ChatFuture<'a>: Future<Output = Result<ChatResponse>> + Send
    where
        Self: 'a;
    type ChatStreamFuture<'a>: Future<Output = Result<Self::Stream>> + Send
    where
        Self: 'a;
    type HealthCheckFuture<'a>: Future<Output = Result<()>> + Send
    where
        Self: 'a;
    fn name(&self) -> &str;
    fn list_models(&self) -> Self::ListModelsFuture<'_>;
    fn chat(&self, request: ChatRequest) -> Self::ChatFuture<'_>;
    fn chat_stream(&self, request: ChatRequest) -> Self::ChatStreamFuture<'_>;
    fn health_check(&self) -> Self::HealthCheckFuture<'_>;
    fn config_schema(&self) -> serde_json::Value {
        serde_json::json!({})
    }
    fn as_any(&self) -> &(dyn Any + Send + Sync) {
        self
    }
}
/// Helper that implements [`LLMProvider::chat`] in terms of [`LLMProvider::chat_stream`].
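///
/// Usage sketch (any [`LLMProvider`] works; shown as `ignore` because it needs a
/// live provider and request in scope):
///
/// ```ignore
/// // Treat the first streamed chunk as the complete, non-streaming response.
/// let response = chat_via_stream(&provider, request).await?;
/// ```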
pub async fn chat_via_stream<'a, P>(provider: &'a P, request: ChatRequest) -> Result<ChatResponse>
where
    P: LLMProvider + 'a,
{
    let stream = provider.chat_stream(request).await?;
    let mut boxed: ChatStream = Box::pin(stream);
    match boxed.next().await {
        Some(Ok(response)) => Ok(response),
        Some(Err(err)) => Err(err),
        None => Err(Error::Provider(anyhow!(
            "Empty chat stream from provider {}",
            provider.name()
        ))),
    }
}
/// Object-safe wrapper trait for runtime-configurable provider usage.
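///
/// Every [`LLMProvider`] also implements this trait through the blanket impl below,
/// so a statically dispatched provider can be erased into `Arc<dyn Provider>`. A
/// sketch (`MyProvider` is a placeholder for any concrete [`LLMProvider`]):
///
/// ```ignore
/// let provider: Arc<dyn Provider> = Arc::new(MyProvider::default());
/// let models = provider.list_models().await?;
/// ```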
#[async_trait::async_trait]
pub trait Provider: Send + Sync {
    /// Get the name of this provider.
    fn name(&self) -> &str;
    /// List available models from this provider.
    async fn list_models(&self) -> Result<Vec<ModelInfo>>;
    /// Send a chat completion request.
    async fn chat(&self, request: ChatRequest) -> Result<ChatResponse>;
    /// Send a streaming chat completion request.
    async fn chat_stream(&self, request: ChatRequest) -> Result<ChatStream>;
    /// Check if the provider is available/healthy.
    async fn health_check(&self) -> Result<()>;
    /// Get provider-specific configuration schema.
    fn config_schema(&self) -> serde_json::Value {
        serde_json::json!({})
    }
    fn as_any(&self) -> &(dyn Any + Send + Sync);
}
#[async_trait::async_trait]
impl<T> Provider for T
where
    T: LLMProvider,
{
    fn name(&self) -> &str {
        LLMProvider::name(self)
    }
    async fn list_models(&self) -> Result<Vec<ModelInfo>> {
        LLMProvider::list_models(self).await
    }
    async fn chat(&self, request: ChatRequest) -> Result<ChatResponse> {
        LLMProvider::chat(self, request).await
    }
    async fn chat_stream(&self, request: ChatRequest) -> Result<ChatStream> {
        let stream = LLMProvider::chat_stream(self, request).await?;
        Ok(Box::pin(stream))
    }
    async fn health_check(&self) -> Result<()> {
        LLMProvider::health_check(self).await
    }
    fn config_schema(&self) -> serde_json::Value {
        LLMProvider::config_schema(self)
    }
    fn as_any(&self) -> &(dyn Any + Send + Sync) {
        LLMProvider::as_any(self)
    }
}
/// Configuration for a provider
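///
/// Keys that are not named fields are collected into `extra` through
/// `#[serde(flatten)]`. A deserialization sketch (the `keep_alive` key and the URL
/// are illustrative values only):
///
/// ```ignore
/// let cfg: ProviderConfig = serde_json::from_str(
///     r#"{"provider_type": "ollama", "base_url": "http://localhost:11434", "keep_alive": "5m"}"#,
/// )?;
/// assert_eq!(cfg.provider_type, "ollama");
/// assert_eq!(cfg.extra["keep_alive"], "5m");
/// ```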
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
pub struct ProviderConfig {
    /// Provider type identifier
    pub provider_type: String,
    /// Base URL for API calls
    pub base_url: Option<String>,
    /// API key or token
    pub api_key: Option<String>,
    /// Additional provider-specific configuration
    #[serde(flatten)]
    pub extra: std::collections::HashMap<String, serde_json::Value>,
}
/// A registry of providers
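///
/// Usage sketch (`MockProvider` below stands in for any registered provider):
///
/// ```ignore
/// let mut registry = ProviderRegistry::new();
/// registry.register(MockProvider::default());
/// // Look up a provider by name, or aggregate models across all of them.
/// let mock = registry.get("mock");
/// let models = registry.list_all_models().await?;
/// ```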
pub struct ProviderRegistry {
    providers: std::collections::HashMap<String, Arc<dyn Provider>>,
}
impl ProviderRegistry {
    /// Create a new provider registry
    pub fn new() -> Self {
        Self {
            providers: std::collections::HashMap::new(),
        }
    }
    /// Register a provider using static dispatch.
    pub fn register<P: LLMProvider + 'static>(&mut self, provider: P) {
        self.register_arc(Arc::new(provider));
    }
    /// Register an already wrapped provider
    pub fn register_arc(&mut self, provider: Arc<dyn Provider>) {
        let name = provider.name().to_string();
        self.providers.insert(name, provider);
    }
    /// Get a provider by name
    pub fn get(&self, name: &str) -> Option<Arc<dyn Provider>> {
        self.providers.get(name).cloned()
    }
    /// List all registered provider names
    pub fn list_providers(&self) -> Vec<String> {
        self.providers.keys().cloned().collect()
    }
    /// Get all models from all providers
    pub async fn list_all_models(&self) -> Result<Vec<ModelInfo>> {
        let mut all_models = Vec::new();
        for provider in self.providers.values() {
            match provider.list_models().await {
                Ok(mut models) => all_models.append(&mut models),
                Err(_) => {
                    // Continue with other providers
                }
            }
        }
        Ok(all_models)
    }
}
impl Default for ProviderRegistry {
    fn default() -> Self {
        Self::new()
    }
}
#[cfg(test)]
pub mod test_utils {
    use super::*;
    use crate::types::{ChatRequest, ChatResponse, Message, ModelInfo, Role};
    use futures::stream;
    use std::future::{ready, Ready};
    /// Mock provider for testing
    #[derive(Default)]
    pub struct MockProvider;
    impl LLMProvider for MockProvider {
        type Stream = stream::Iter<std::vec::IntoIter<Result<ChatResponse>>>;
        type ListModelsFuture<'a> = Ready<Result<Vec<ModelInfo>>>;
        type ChatFuture<'a> = Ready<Result<ChatResponse>>;
        type ChatStreamFuture<'a> = Ready<Result<Self::Stream>>;
        type HealthCheckFuture<'a> = Ready<Result<()>>;
        fn name(&self) -> &str {
            "mock"
        }
        fn list_models(&self) -> Self::ListModelsFuture<'_> {
            ready(Ok(vec![ModelInfo {
                id: "mock-model".to_string(),
                provider: "mock".to_string(),
                name: "mock-model".to_string(),
                description: None,
                context_window: None,
                capabilities: vec![],
                supports_tools: false,
            }]))
        }
        fn chat(&self, request: ChatRequest) -> Self::ChatFuture<'_> {
            ready(Ok(self.build_response(&request)))
        }
        fn chat_stream(&self, request: ChatRequest) -> Self::ChatStreamFuture<'_> {
            let response = self.build_response(&request);
            ready(Ok(stream::iter(vec![Ok(response)])))
        }
        fn health_check(&self) -> Self::HealthCheckFuture<'_> {
            ready(Ok(()))
        }
    }
    impl MockProvider {
        fn build_response(&self, request: &ChatRequest) -> ChatResponse {
            let content = format!(
                "Mock response to: {}",
                request
                    .messages
                    .last()
                    .map(|m| m.content.clone())
                    .unwrap_or_default()
            );
            ChatResponse {
                message: Message::new(Role::Assistant, content),
                usage: None,
                is_streaming: false,
                is_final: true,
            }
        }
    }
}
#[cfg(test)]
mod tests {
    use super::test_utils::MockProvider;
    use super::*;
    use crate::types::{ChatParameters, ChatRequest, ChatResponse, Message, ModelInfo, Role};
    use futures::stream;
    use std::future::{ready, Ready};
    use std::sync::Arc;
    struct StreamingProvider;
    impl LLMProvider for StreamingProvider {
        type Stream = stream::Iter<std::vec::IntoIter<Result<ChatResponse>>>;
        type ListModelsFuture<'a> = Ready<Result<Vec<ModelInfo>>>;
        type ChatFuture<'a> = Ready<Result<ChatResponse>>;
        type ChatStreamFuture<'a> = Ready<Result<Self::Stream>>;
        type HealthCheckFuture<'a> = Ready<Result<()>>;
        fn name(&self) -> &str {
            "streaming"
        }
        fn list_models(&self) -> Self::ListModelsFuture<'_> {
            ready(Ok(vec![ModelInfo {
                id: "stream-model".to_string(),
                provider: "streaming".to_string(),
                name: "stream-model".to_string(),
                description: None,
                context_window: None,
                capabilities: vec!["chat".to_string()],
                supports_tools: false,
            }]))
        }
        fn chat(&self, request: ChatRequest) -> Self::ChatFuture<'_> {
            ready(Ok(self.response(&request)))
        }
        fn chat_stream(&self, request: ChatRequest) -> Self::ChatStreamFuture<'_> {
            let response = self.response(&request);
            ready(Ok(stream::iter(vec![Ok(response)])))
        }
        fn health_check(&self) -> Self::HealthCheckFuture<'_> {
            ready(Ok(()))
        }
    }
    impl StreamingProvider {
        fn response(&self, request: &ChatRequest) -> ChatResponse {
            let reply = format!(
                "echo:{}",
                request
                    .messages
                    .last()
                    .map(|m| m.content.clone())
                    .unwrap_or_default()
            );
            ChatResponse {
                message: Message::new(Role::Assistant, reply),
                usage: None,
                is_streaming: true,
                is_final: true,
            }
        }
    }
    #[tokio::test]
    async fn default_chat_reads_from_stream() {
        let provider = StreamingProvider;
        let request = ChatRequest {
            model: "stream-model".to_string(),
            messages: vec![Message::new(Role::User, "ping".to_string())],
            parameters: ChatParameters::default(),
            tools: None,
        };
        let response = LLMProvider::chat(&provider, request)
            .await
            .expect("chat succeeded");
        assert_eq!(response.message.content, "echo:ping");
        assert!(response.is_final);
    }
    #[tokio::test]
    async fn registry_registers_static_provider() {
        let mut registry = ProviderRegistry::new();
        registry.register(StreamingProvider);
        let provider = registry.get("streaming").expect("provider registered");
        let models = provider.list_models().await.expect("models listed");
        assert_eq!(models[0].id, "stream-model");
    }
    #[tokio::test]
    async fn registry_accepts_dynamic_provider() {
        let mut registry = ProviderRegistry::new();
        let provider: Arc<dyn Provider> = Arc::new(MockProvider::default());
        registry.register_arc(provider.clone());
        let fetched = registry.get("mock").expect("mock provider present");
        let request = ChatRequest {
            model: "mock-model".to_string(),
            messages: vec![Message::new(Role::User, "hi".to_string())],
            parameters: ChatParameters::default(),
            tools: None,
        };
        let response = Provider::chat(fetched.as_ref(), request)
            .await
            .expect("chat succeeded");
        assert_eq!(response.message.content, "Mock response to: hi");
    }
}