From 4820a6706f18f4d47543422ccf407f56cb64b660 Mon Sep 17 00:00:00 2001 From: vikingowl Date: Sat, 18 Oct 2025 06:57:58 +0200 Subject: [PATCH] feat(provider): enrich model metadata with provider tags and display names, add canonical provider ID handling, and update UI to use new display names and handle provider errors --- crates/owlen-core/src/provider/manager.rs | 255 ++++++++++++++++++- crates/owlen-tui/src/chat_app.rs | 218 ++++++++++++---- crates/owlen-tui/src/widgets/model_picker.rs | 33 +-- 3 files changed, 440 insertions(+), 66 deletions(-) diff --git a/crates/owlen-core/src/provider/manager.rs b/crates/owlen-core/src/provider/manager.rs index 7b197e0..184b895 100644 --- a/crates/owlen-core/src/provider/manager.rs +++ b/crates/owlen-core/src/provider/manager.rs @@ -3,12 +3,15 @@ use std::sync::Arc; use futures::stream::{FuturesUnordered, StreamExt}; use log::{debug, warn}; +use serde_json::Value; use tokio::sync::RwLock; use crate::config::Config; use crate::{Error, Result}; -use super::{GenerateRequest, GenerateStream, ModelInfo, ModelProvider, ProviderStatus}; +use super::{ + GenerateRequest, GenerateStream, ModelInfo, ModelProvider, ProviderStatus, ProviderType, +}; /// Model information annotated with the originating provider metadata. #[derive(Debug, Clone)] @@ -149,6 +152,7 @@ impl ProviderManager { } } + enrich_model_metadata(&mut annotated); Ok(annotated) } @@ -217,6 +221,255 @@ impl ProviderManager { } } +fn enrich_model_metadata(models: &mut [AnnotatedModelInfo]) { + let mut name_counts: HashMap = HashMap::new(); + for info in models.iter() { + *name_counts.entry(info.model.name.clone()).or_default() += 1; + } + + for info in models.iter_mut() { + let provider_tag = provider_tag_for(&info.provider_id); + info.model + .metadata + .insert("provider_tag".into(), Value::String(provider_tag.clone())); + + let scope_label = provider_scope_label(info.model.provider.provider_type); + info.model.metadata.insert( + "provider_scope".into(), + Value::String(scope_label.to_string()), + ); + info.model.metadata.insert( + "provider_display_name".into(), + Value::String(info.model.provider.name.clone()), + ); + + let display_name = if name_counts + .get(&info.model.name) + .is_some_and(|count| *count > 1) + { + let suffix = scope_label; + let base = info.model.name.trim(); + if base.ends_with(&format!("· {}", suffix)) { + base.to_string() + } else { + format!("{base} · {suffix}") + } + } else { + info.model.name.clone() + }; + + info.model + .metadata + .insert("display_name".into(), Value::String(display_name)); + } +} + +fn provider_tag_for(provider_id: &str) -> String { + let normalized = provider_id.trim().to_ascii_lowercase().replace('-', "_"); + match normalized.as_str() { + "ollama" | "ollama_local" => "ollama".to_string(), + "ollama_cloud" => "ollama-cloud".to_string(), + other => other.replace('_', "-"), + } +} + +fn provider_scope_label(provider_type: ProviderType) -> &'static str { + match provider_type { + ProviderType::Local => "local", + ProviderType::Cloud => "cloud", + } +} + +#[cfg(test)] +mod tests { + use super::*; + use async_trait::async_trait; + use std::sync::Arc; + + use crate::{Error, provider::ProviderMetadata}; + + #[derive(Clone)] + struct StaticProvider { + metadata: ProviderMetadata, + models: Vec, + status: ProviderStatus, + } + + impl StaticProvider { + fn new( + id: &str, + name: &str, + provider_type: ProviderType, + status: ProviderStatus, + models: Vec, + ) -> Self { + let metadata = ProviderMetadata::new(id, name, provider_type, false); + let mut models = models; + for model in &mut models { + model.provider = metadata.clone(); + } + let mut metadata = metadata; + metadata + .metadata + .insert("test".into(), Value::String("true".into())); + Self { + metadata, + models, + status, + } + } + } + + #[async_trait] + impl ModelProvider for StaticProvider { + fn metadata(&self) -> &ProviderMetadata { + &self.metadata + } + + async fn health_check(&self) -> Result { + Ok(self.status) + } + + async fn list_models(&self) -> Result> { + Ok(self.models.clone()) + } + + async fn generate_stream(&self, _request: GenerateRequest) -> Result { + Err(Error::NotImplemented( + "streaming not implemented in StaticProvider".to_string(), + )) + } + } + + fn model(name: &str) -> ModelInfo { + ModelInfo { + name: name.to_string(), + size_bytes: None, + capabilities: Vec::new(), + description: None, + provider: ProviderMetadata::new("unused", "Unused", ProviderType::Local, false), + metadata: HashMap::new(), + } + } + + #[tokio::test] + async fn aggregates_local_provider_models() { + let manager = ProviderManager::default(); + let provider = StaticProvider::new( + "ollama_local", + "Ollama Local", + ProviderType::Local, + ProviderStatus::Available, + vec![model("qwen3:8b")], + ); + manager.register_provider(Arc::new(provider)).await; + + let models = manager.list_all_models().await.unwrap(); + assert_eq!(models.len(), 1); + let entry = &models[0]; + assert_eq!(entry.provider_id, "ollama_local"); + assert_eq!(entry.provider_status, ProviderStatus::Available); + assert_eq!( + entry + .model + .metadata + .get("provider_tag") + .and_then(Value::as_str), + Some("ollama") + ); + assert_eq!( + entry + .model + .metadata + .get("display_name") + .and_then(Value::as_str), + Some("qwen3:8b") + ); + } + + #[tokio::test] + async fn aggregates_cloud_provider_models() { + let manager = ProviderManager::default(); + let provider = StaticProvider::new( + "ollama_cloud", + "Ollama Cloud", + ProviderType::Cloud, + ProviderStatus::Available, + vec![model("qwen3:0.5b-cloud")], + ); + manager.register_provider(Arc::new(provider)).await; + + let models = manager.list_all_models().await.unwrap(); + assert_eq!(models.len(), 1); + let entry = &models[0]; + assert_eq!( + entry + .model + .metadata + .get("provider_tag") + .and_then(Value::as_str), + Some("ollama-cloud") + ); + assert_eq!( + entry + .model + .metadata + .get("display_name") + .and_then(Value::as_str), + Some("qwen3:0.5b-cloud") + ); + } + + #[tokio::test] + async fn deduplicates_model_names_with_provider_suffix() { + let manager = ProviderManager::default(); + let local = StaticProvider::new( + "ollama_local", + "Ollama Local", + ProviderType::Local, + ProviderStatus::Available, + vec![model("qwen3:8b")], + ); + let cloud = StaticProvider::new( + "ollama_cloud", + "Ollama Cloud", + ProviderType::Cloud, + ProviderStatus::Available, + vec![model("qwen3:8b")], + ); + manager.register_provider(Arc::new(local)).await; + manager.register_provider(Arc::new(cloud)).await; + + let models = manager.list_all_models().await.unwrap(); + + let local_entry = models + .iter() + .find(|entry| entry.provider_id == "ollama_local") + .expect("local provider entry"); + let cloud_entry = models + .iter() + .find(|entry| entry.provider_id == "ollama_cloud") + .expect("cloud provider entry"); + + assert_eq!( + local_entry + .model + .metadata + .get("display_name") + .and_then(Value::as_str), + Some("qwen3:8b · local") + ); + assert_eq!( + cloud_entry + .model + .metadata + .get("display_name") + .and_then(Value::as_str), + Some("qwen3:8b · cloud") + ); + } +} + impl Default for ProviderManager { fn default() -> Self { Self { diff --git a/crates/owlen-tui/src/chat_app.rs b/crates/owlen-tui/src/chat_app.rs index f679e69..95e6ae0 100644 --- a/crates/owlen-tui/src/chat_app.rs +++ b/crates/owlen-tui/src/chat_app.rs @@ -5,6 +5,7 @@ use crossterm::{ event::{KeyEvent, MouseButton, MouseEvent, MouseEventKind}, terminal::{disable_raw_mode, enable_raw_mode}, }; +use owlen_core::Error as CoreError; use owlen_core::consent::ConsentScope; use owlen_core::facade::llm_client::LlmClient; use owlen_core::mcp::remote_client::RemoteMcpClient; @@ -757,7 +758,7 @@ impl ChatApp { // Load theme and provider based on config before moving `controller`. let config_guard = controller.config_async().await; let theme_name = config_guard.ui.theme.clone(); - let current_provider = config_guard.general.default_provider.clone(); + let current_provider = Self::canonical_provider_id(&config_guard.general.default_provider); let show_onboarding = config_guard.ui.show_onboarding; let show_cursor_outside_insert = config_guard.ui.show_cursor_outside_insert; let syntax_highlighting = config_guard.ui.syntax_highlighting; @@ -797,7 +798,7 @@ impl ChatApp { annotated_models: Vec::new(), provider_scope_status: HashMap::new(), available_providers: Vec::new(), - selected_provider: "ollama_local".to_string(), // Default, will be updated in initialize_models + selected_provider: current_provider.clone(), selected_provider_index: 0, selected_model_item: None, model_selector_items: Vec::new(), @@ -2493,11 +2494,32 @@ impl ChatApp { references } - fn display_name_for_model(model: &ModelInfo) -> String { - if model.name.trim().is_empty() { - model.id.clone() + pub(crate) fn display_name_for_model(model: &ModelInfo) -> String { + let base = { + let trimmed = model.name.trim(); + if trimmed.is_empty() { + model.id.as_str() + } else { + trimmed + } + }; + + let scope = Self::model_scope_from_capabilities(model); + let scope_suffix = match &scope { + ModelScope::Local => "local".to_string(), + ModelScope::Cloud => "cloud".to_string(), + ModelScope::Other(other) => other.trim().to_ascii_lowercase(), + }; + + if scope_suffix.is_empty() { + base.to_string() } else { - model.name.clone() + let lower = base.to_ascii_lowercase(); + if lower.contains(&format!("· {}", scope_suffix)) { + base.to_string() + } else { + format!("{base} · {scope_suffix}") + } } } @@ -5086,7 +5108,7 @@ impl ChatApp { } if !config_model_provider.is_empty() { - self.selected_provider = config_model_provider.clone(); + self.selected_provider = Self::canonical_provider_id(&config_model_provider); } else { self.selected_provider = self.available_providers[0].clone(); } @@ -8358,11 +8380,11 @@ impl ChatApp { continue; } - let canonical_name = match name.trim().to_ascii_lowercase().as_str() { - "ollama" | "ollama_local" => "ollama_local".to_string(), - "ollama-cloud" | "ollama_cloud" => "ollama_cloud".to_string(), - other => other.to_string(), - }; + if !provider_cfg.enabled { + continue; + } + + let canonical_name = Self::canonical_provider_id(&name); // All providers communicate via MCP LLM server (Phase 10). // Select provider by name via OWLEN_PROVIDER so per-provider settings apply. @@ -8692,8 +8714,14 @@ impl ChatApp { } fn recompute_available_providers(&mut self) { - let mut providers: BTreeSet = - self.controller.config().providers.keys().cloned().collect(); + let mut providers: BTreeSet = self + .controller + .config() + .providers + .iter() + .filter(|(_, cfg)| cfg.enabled) + .map(|(name, _)| Self::canonical_provider_id(name)) + .collect(); providers.extend(self.models.iter().map(|m| m.provider.clone())); @@ -8708,6 +8736,19 @@ impl ChatApp { self.available_providers = providers.into_iter().collect(); } + fn canonical_provider_id(provider: &str) -> String { + let normalized = provider.trim().to_ascii_lowercase(); + if normalized.is_empty() { + return "ollama_local".to_string(); + } + match normalized.replace('-', "_").as_str() { + "ollama" => "ollama_local".to_string(), + "ollama_local" => "ollama_local".to_string(), + "ollama_cloud" => "ollama_cloud".to_string(), + other => other.to_string(), + } + } + fn with_temp_env_vars(env_vars: &HashMap, action: F) -> T where F: FnOnce() -> T, @@ -8759,17 +8800,23 @@ impl ChatApp { "scope".to_string(), Value::String(Self::scope_display_name(&scope)), ); + provider_metadata.metadata.insert( + "provider_tag".to_string(), + Value::String(Self::provider_tag(&provider_id)), + ); let mut model_metadata = HashMap::new(); - if !model.name.trim().is_empty() && model.name != model.id { - model_metadata.insert( - "display_name".to_string(), - Value::String(model.name.clone()), - ); - } + model_metadata.insert( + "display_name".to_string(), + Value::String(Self::display_name_for_model(model)), + ); if let Some(ctx) = model.context_window { model_metadata.insert("context_window".to_string(), Value::from(ctx)); } + model_metadata.insert( + "provider_tag".to_string(), + Value::String(Self::provider_tag(&provider_id)), + ); let provider_model = ProviderModelInfo { name: model.id.clone(), @@ -8814,10 +8861,11 @@ impl ChatApp { for provider in &self.available_providers { let provider_lower = provider.to_ascii_lowercase(); + let provider_display = Self::provider_display_name(provider); let provider_status = self.provider_overall_status(provider); let provider_type = self.provider_type_for(provider); let provider_highlight = if search_active { - search_candidate(provider, &search_query).map(|(_, mask)| mask) + search_candidate(provider_display.as_str(), &search_query).map(|(_, mask)| mask) } else { None }; @@ -9059,16 +9107,11 @@ impl ChatApp { } }; - consider( - if model.name.trim().is_empty() { - None - } else { - Some(model.name.as_str()) - }, - &mut info.name, - ); + let display_name = Self::display_name_for_model(model); + consider(Some(display_name.as_str()), &mut info.name); consider(Some(model.id.as_str()), &mut info.id); - consider(Some(provider), &mut info.provider); + let provider_display = Self::provider_display_name(provider); + consider(Some(provider_display.as_str()), &mut info.provider); if let Some(desc) = model.description.as_deref() { consider(Some(desc), &mut info.description); } @@ -9188,7 +9231,7 @@ impl ChatApp { } } - fn provider_display_name(provider: &str) -> String { + pub(crate) fn provider_display_name(provider: &str) -> String { if provider.trim().is_empty() { return "Provider".to_string(); } @@ -9196,6 +9239,14 @@ impl ChatApp { capitalize_first(normalized.as_str()) } + fn provider_tag(provider: &str) -> String { + match provider.trim().to_ascii_lowercase().as_str() { + "ollama" | "ollama_local" => "ollama".to_string(), + "ollama-cloud" | "ollama_cloud" => "ollama-cloud".to_string(), + other => other.to_string(), + } + } + fn infer_provider_type(provider: &str, scope: &ModelScope) -> ProviderType { match scope { ModelScope::Local => ProviderType::Local, @@ -9425,22 +9476,17 @@ impl ChatApp { } async fn switch_to_provider(&mut self, provider_name: &str) -> Result<()> { - if self.current_provider == provider_name { + let canonical_name = Self::canonical_provider_id(provider_name); + if Self::canonical_provider_id(&self.current_provider) == canonical_name { return Ok(()); } use owlen_core::config::McpServerConfig; use std::collections::HashMap; - let canonical_name = if provider_name.eq_ignore_ascii_case("ollama-cloud") { - "ollama" - } else { - provider_name - }; - - if self.controller.config().provider(canonical_name).is_none() { + if self.controller.config().provider(&canonical_name).is_none() { let mut guard = self.controller.config_mut(); - config::ensure_provider_config(&mut guard, canonical_name); + config::ensure_provider_config(&mut guard, &canonical_name); } let workspace_root = std::path::Path::new(env!("CARGO_MANIFEST_DIR")) @@ -9458,11 +9504,11 @@ impl ChatApp { }); let mut env_vars = HashMap::new(); - env_vars.insert("OWLEN_PROVIDER".to_string(), canonical_name.to_string()); + env_vars.insert("OWLEN_PROVIDER".to_string(), canonical_name.clone()); let provider: Arc = if let Some(path) = server_binary { let config = McpServerConfig { - name: canonical_name.to_string(), + name: canonical_name.clone(), command: path.to_string_lossy().into_owned(), args: Vec::new(), transport: "stdio".to_string(), @@ -9475,7 +9521,7 @@ impl ChatApp { }; self.controller.switch_provider(provider).await?; - self.current_provider = provider_name.to_string(); + self.current_provider = canonical_name; self.model_details_cache.clear(); self.model_info_panel.clear(); self.set_model_info_visible(false); @@ -10146,6 +10192,10 @@ impl ChatApp { Ok(()) } Ok(Err(err)) => { + self.stop_loading_animation(); + if self.handle_provider_error(&err).await? { + return Ok(()); + } let message = err.to_string(); if message.to_lowercase().contains("not found") { self.error = Some( @@ -10159,7 +10209,6 @@ impl ChatApp { self.error = Some(message); self.status = "Request failed".to_string(); } - self.stop_loading_animation(); Ok(()) } Err(_) => { @@ -10171,6 +10220,87 @@ impl ChatApp { } } + async fn handle_provider_error(&mut self, err: &CoreError) -> Result { + let current_provider = Self::canonical_provider_id(&self.current_provider); + if current_provider != "ollama_cloud" { + return Ok(false); + } + + match err { + CoreError::Auth(message) => { + self.push_toast( + ToastLevel::Error, + "Cloud key invalid; using local provider.", + ); + + let switch_result = self.switch_to_provider("ollama_local").await; + if let Err(switch_err) = switch_result { + let detail = format!( + "Cloud key invalid and local fallback failed: {}", + switch_err + ); + self.error = Some(detail.clone()); + self.status = "Cloud authentication failed".to_string(); + self.push_toast(ToastLevel::Error, detail); + } else { + self.selected_provider = "ollama_local".to_string(); + self.expanded_provider = Some("ollama_local".to_string()); + self.update_selected_provider_index(); + + { + let mut cfg = self.controller.config_mut(); + cfg.general.default_provider = "ollama_local".to_string(); + } + + let save_result = { + let cfg = self.controller.config(); + config::save_config(&cfg) + }; + if let Err(save_err) = save_result { + self.push_toast( + ToastLevel::Warning, + format!( + "Fell back to local provider, but failed to save config: {}", + save_err + ), + ); + } + + if let Err(refresh_err) = self.refresh_models().await { + self.push_toast( + ToastLevel::Warning, + format!("Failed to refresh local models: {}", refresh_err), + ); + } + + self.status = + "Cloud authentication failed; using local provider instead.".to_string(); + self.error = Some(format!( + "Cloud key invalid: {}. Update your credentials and reselect the cloud provider.", + message + )); + self.push_toast(ToastLevel::Info, "Switched back to local provider."); + } + + Ok(true) + } + CoreError::Network(message) => { + let lower = message.to_ascii_lowercase(); + if message.contains("429") + || lower.contains("too many requests") + || lower.contains("rate limit") + { + self.error = Some("Cloud rate limit hit; retry later.".to_string()); + self.status = "Cloud rate limit hit".to_string(); + self.push_toast(ToastLevel::Warning, "Cloud rate limit hit; retry later."); + return Ok(true); + } + Ok(false) + } + _ => Ok(false), + } + } + async fn process_agent_request(&mut self) -> Result<()> { use owlen_core::agent::{AgentConfig, AgentExecutor}; use owlen_core::mcp::remote_client::RemoteMcpClient; diff --git a/crates/owlen-tui/src/widgets/model_picker.rs b/crates/owlen-tui/src/widgets/model_picker.rs index d3b8e07..1da2e9f 100644 --- a/crates/owlen-tui/src/widgets/model_picker.rs +++ b/crates/owlen-tui/src/widgets/model_picker.rs @@ -211,8 +211,9 @@ pub fn render_model_picker(frame: &mut Frame<'_>, app: &ChatApp) { let mut spans = Vec::new(); spans.push(status_icon(*status, theme)); spans.push(Span::raw(" ")); + let display_name = ChatApp::provider_display_name(provider); let header_spans = render_highlighted_text( - provider, + display_name.as_str(), if search_active { app.provider_search_highlight(provider) } else { @@ -509,32 +510,15 @@ fn build_model_selector_lines<'a>( spans.push(Span::raw(" ")); let name_style = Style::default().fg(theme.text).add_modifier(Modifier::BOLD); - let id_style = Style::default() - .fg(theme.placeholder) - .add_modifier(Modifier::DIM); - - let name_trimmed = model.name.trim(); - if !name_trimmed.is_empty() { + let display_name = ChatApp::display_name_for_model(model); + if !display_name.trim().is_empty() { let name_spans = render_highlighted_text( - name_trimmed, + display_name.as_str(), search.info.and_then(|info| info.name.as_ref()), name_style, search.highlight_style, ); spans.extend(name_spans); - - if !model.id.eq_ignore_ascii_case(name_trimmed) { - spans.push(Span::raw(" ")); - spans.push(Span::styled("·", Style::default().fg(theme.placeholder))); - spans.push(Span::raw(" ")); - let id_spans = render_highlighted_text( - model.id.as_str(), - search.info.and_then(|info| info.id.as_ref()), - id_style, - search.highlight_style, - ); - spans.extend(id_spans); - } } else { let id_spans = render_highlighted_text( model.id.as_str(), @@ -580,6 +564,13 @@ fn build_model_selector_lines<'a>( push_meta(scope_label.clone()); } + let provider_label = ChatApp::provider_display_name(&model.provider); + push_meta(format!("provider {}", provider_label)); + + if !display_name.trim().eq_ignore_ascii_case(model.id.trim()) { + push_meta(format!("id {}", model.id)); + } + if let Some(detail) = detail { if let Some(ctx) = detail.context_length { push_meta(format!("max tokens {}", ctx));