feat(provider): enrich model metadata with provider tags and display names, add canonical provider ID handling, and update UI to use new display names and handle provider errors

This commit is contained in:
2025-10-18 06:57:58 +02:00
parent 3308b483f7
commit 4820a6706f
3 changed files with 440 additions and 66 deletions

View File

@@ -3,12 +3,15 @@ use std::sync::Arc;
use futures::stream::{FuturesUnordered, StreamExt};
use log::{debug, warn};
use serde_json::Value;
use tokio::sync::RwLock;
use crate::config::Config;
use crate::{Error, Result};
use super::{GenerateRequest, GenerateStream, ModelInfo, ModelProvider, ProviderStatus};
use super::{
GenerateRequest, GenerateStream, ModelInfo, ModelProvider, ProviderStatus, ProviderType,
};
/// Model information annotated with the originating provider metadata.
#[derive(Debug, Clone)]
@@ -149,6 +152,7 @@ impl ProviderManager {
}
}
enrich_model_metadata(&mut annotated);
Ok(annotated)
}
@@ -217,6 +221,255 @@ impl ProviderManager {
}
}
fn enrich_model_metadata(models: &mut [AnnotatedModelInfo]) {
let mut name_counts: HashMap<String, usize> = HashMap::new();
for info in models.iter() {
*name_counts.entry(info.model.name.clone()).or_default() += 1;
}
for info in models.iter_mut() {
let provider_tag = provider_tag_for(&info.provider_id);
info.model
.metadata
.insert("provider_tag".into(), Value::String(provider_tag.clone()));
let scope_label = provider_scope_label(info.model.provider.provider_type);
info.model.metadata.insert(
"provider_scope".into(),
Value::String(scope_label.to_string()),
);
info.model.metadata.insert(
"provider_display_name".into(),
Value::String(info.model.provider.name.clone()),
);
let display_name = if name_counts
.get(&info.model.name)
.is_some_and(|count| *count > 1)
{
let suffix = scope_label;
let base = info.model.name.trim();
if base.ends_with(&format!("· {}", suffix)) {
base.to_string()
} else {
format!("{base} · {suffix}")
}
} else {
info.model.name.clone()
};
info.model
.metadata
.insert("display_name".into(), Value::String(display_name));
}
}
fn provider_tag_for(provider_id: &str) -> String {
let normalized = provider_id.trim().to_ascii_lowercase().replace('-', "_");
match normalized.as_str() {
"ollama" | "ollama_local" => "ollama".to_string(),
"ollama_cloud" => "ollama-cloud".to_string(),
other => other.replace('_', "-"),
}
}
fn provider_scope_label(provider_type: ProviderType) -> &'static str {
match provider_type {
ProviderType::Local => "local",
ProviderType::Cloud => "cloud",
}
}
#[cfg(test)]
mod tests {
use super::*;
use async_trait::async_trait;
use std::sync::Arc;
use crate::{Error, provider::ProviderMetadata};
#[derive(Clone)]
struct StaticProvider {
metadata: ProviderMetadata,
models: Vec<ModelInfo>,
status: ProviderStatus,
}
impl StaticProvider {
fn new(
id: &str,
name: &str,
provider_type: ProviderType,
status: ProviderStatus,
models: Vec<ModelInfo>,
) -> Self {
let metadata = ProviderMetadata::new(id, name, provider_type, false);
let mut models = models;
for model in &mut models {
model.provider = metadata.clone();
}
let mut metadata = metadata;
metadata
.metadata
.insert("test".into(), Value::String("true".into()));
Self {
metadata,
models,
status,
}
}
}
#[async_trait]
impl ModelProvider for StaticProvider {
fn metadata(&self) -> &ProviderMetadata {
&self.metadata
}
async fn health_check(&self) -> Result<ProviderStatus> {
Ok(self.status)
}
async fn list_models(&self) -> Result<Vec<ModelInfo>> {
Ok(self.models.clone())
}
async fn generate_stream(&self, _request: GenerateRequest) -> Result<GenerateStream> {
Err(Error::NotImplemented(
"streaming not implemented in StaticProvider".to_string(),
))
}
}
fn model(name: &str) -> ModelInfo {
ModelInfo {
name: name.to_string(),
size_bytes: None,
capabilities: Vec::new(),
description: None,
provider: ProviderMetadata::new("unused", "Unused", ProviderType::Local, false),
metadata: HashMap::new(),
}
}
#[tokio::test]
async fn aggregates_local_provider_models() {
let manager = ProviderManager::default();
let provider = StaticProvider::new(
"ollama_local",
"Ollama Local",
ProviderType::Local,
ProviderStatus::Available,
vec![model("qwen3:8b")],
);
manager.register_provider(Arc::new(provider)).await;
let models = manager.list_all_models().await.unwrap();
assert_eq!(models.len(), 1);
let entry = &models[0];
assert_eq!(entry.provider_id, "ollama_local");
assert_eq!(entry.provider_status, ProviderStatus::Available);
assert_eq!(
entry
.model
.metadata
.get("provider_tag")
.and_then(Value::as_str),
Some("ollama")
);
assert_eq!(
entry
.model
.metadata
.get("display_name")
.and_then(Value::as_str),
Some("qwen3:8b")
);
}
#[tokio::test]
async fn aggregates_cloud_provider_models() {
let manager = ProviderManager::default();
let provider = StaticProvider::new(
"ollama_cloud",
"Ollama Cloud",
ProviderType::Cloud,
ProviderStatus::Available,
vec![model("qwen3:0.5b-cloud")],
);
manager.register_provider(Arc::new(provider)).await;
let models = manager.list_all_models().await.unwrap();
assert_eq!(models.len(), 1);
let entry = &models[0];
assert_eq!(
entry
.model
.metadata
.get("provider_tag")
.and_then(Value::as_str),
Some("ollama-cloud")
);
assert_eq!(
entry
.model
.metadata
.get("display_name")
.and_then(Value::as_str),
Some("qwen3:0.5b-cloud")
);
}
#[tokio::test]
async fn deduplicates_model_names_with_provider_suffix() {
let manager = ProviderManager::default();
let local = StaticProvider::new(
"ollama_local",
"Ollama Local",
ProviderType::Local,
ProviderStatus::Available,
vec![model("qwen3:8b")],
);
let cloud = StaticProvider::new(
"ollama_cloud",
"Ollama Cloud",
ProviderType::Cloud,
ProviderStatus::Available,
vec![model("qwen3:8b")],
);
manager.register_provider(Arc::new(local)).await;
manager.register_provider(Arc::new(cloud)).await;
let models = manager.list_all_models().await.unwrap();
let local_entry = models
.iter()
.find(|entry| entry.provider_id == "ollama_local")
.expect("local provider entry");
let cloud_entry = models
.iter()
.find(|entry| entry.provider_id == "ollama_cloud")
.expect("cloud provider entry");
assert_eq!(
local_entry
.model
.metadata
.get("display_name")
.and_then(Value::as_str),
Some("qwen3:8b · local")
);
assert_eq!(
cloud_entry
.model
.metadata
.get("display_name")
.and_then(Value::as_str),
Some("qwen3:8b · cloud")
);
}
}
impl Default for ProviderManager {
fn default() -> Self {
Self {

View File

@@ -5,6 +5,7 @@ use crossterm::{
event::{KeyEvent, MouseButton, MouseEvent, MouseEventKind},
terminal::{disable_raw_mode, enable_raw_mode},
};
use owlen_core::Error as CoreError;
use owlen_core::consent::ConsentScope;
use owlen_core::facade::llm_client::LlmClient;
use owlen_core::mcp::remote_client::RemoteMcpClient;
@@ -757,7 +758,7 @@ impl ChatApp {
// Load theme and provider based on config before moving `controller`.
let config_guard = controller.config_async().await;
let theme_name = config_guard.ui.theme.clone();
let current_provider = config_guard.general.default_provider.clone();
let current_provider = Self::canonical_provider_id(&config_guard.general.default_provider);
let show_onboarding = config_guard.ui.show_onboarding;
let show_cursor_outside_insert = config_guard.ui.show_cursor_outside_insert;
let syntax_highlighting = config_guard.ui.syntax_highlighting;
@@ -797,7 +798,7 @@ impl ChatApp {
annotated_models: Vec::new(),
provider_scope_status: HashMap::new(),
available_providers: Vec::new(),
selected_provider: "ollama_local".to_string(), // Default, will be updated in initialize_models
selected_provider: current_provider.clone(),
selected_provider_index: 0,
selected_model_item: None,
model_selector_items: Vec::new(),
@@ -2493,11 +2494,32 @@ impl ChatApp {
references
}
fn display_name_for_model(model: &ModelInfo) -> String {
if model.name.trim().is_empty() {
model.id.clone()
pub(crate) fn display_name_for_model(model: &ModelInfo) -> String {
let base = {
let trimmed = model.name.trim();
if trimmed.is_empty() {
model.id.as_str()
} else {
model.name.clone()
trimmed
}
};
let scope = Self::model_scope_from_capabilities(model);
let scope_suffix = match &scope {
ModelScope::Local => "local".to_string(),
ModelScope::Cloud => "cloud".to_string(),
ModelScope::Other(other) => other.trim().to_ascii_lowercase(),
};
if scope_suffix.is_empty() {
base.to_string()
} else {
let lower = base.to_ascii_lowercase();
if lower.contains(&format!("· {}", scope_suffix)) {
base.to_string()
} else {
format!("{base} · {scope_suffix}")
}
}
}
@@ -5086,7 +5108,7 @@ impl ChatApp {
}
if !config_model_provider.is_empty() {
self.selected_provider = config_model_provider.clone();
self.selected_provider = Self::canonical_provider_id(&config_model_provider);
} else {
self.selected_provider = self.available_providers[0].clone();
}
@@ -8358,11 +8380,11 @@ impl ChatApp {
continue;
}
let canonical_name = match name.trim().to_ascii_lowercase().as_str() {
"ollama" | "ollama_local" => "ollama_local".to_string(),
"ollama-cloud" | "ollama_cloud" => "ollama_cloud".to_string(),
other => other.to_string(),
};
if !provider_cfg.enabled {
continue;
}
let canonical_name = Self::canonical_provider_id(&name);
// All providers communicate via MCP LLM server (Phase 10).
// Select provider by name via OWLEN_PROVIDER so per-provider settings apply.
@@ -8692,8 +8714,14 @@ impl ChatApp {
}
fn recompute_available_providers(&mut self) {
let mut providers: BTreeSet<String> =
self.controller.config().providers.keys().cloned().collect();
let mut providers: BTreeSet<String> = self
.controller
.config()
.providers
.iter()
.filter(|(_, cfg)| cfg.enabled)
.map(|(name, _)| Self::canonical_provider_id(name))
.collect();
providers.extend(self.models.iter().map(|m| m.provider.clone()));
@@ -8708,6 +8736,19 @@ impl ChatApp {
self.available_providers = providers.into_iter().collect();
}
fn canonical_provider_id(provider: &str) -> String {
let normalized = provider.trim().to_ascii_lowercase();
if normalized.is_empty() {
return "ollama_local".to_string();
}
match normalized.replace('-', "_").as_str() {
"ollama" => "ollama_local".to_string(),
"ollama_local" => "ollama_local".to_string(),
"ollama_cloud" => "ollama_cloud".to_string(),
other => other.to_string(),
}
}
fn with_temp_env_vars<T, F>(env_vars: &HashMap<String, String>, action: F) -> T
where
F: FnOnce() -> T,
@@ -8759,17 +8800,23 @@ impl ChatApp {
"scope".to_string(),
Value::String(Self::scope_display_name(&scope)),
);
provider_metadata.metadata.insert(
"provider_tag".to_string(),
Value::String(Self::provider_tag(&provider_id)),
);
let mut model_metadata = HashMap::new();
if !model.name.trim().is_empty() && model.name != model.id {
model_metadata.insert(
"display_name".to_string(),
Value::String(model.name.clone()),
Value::String(Self::display_name_for_model(model)),
);
}
if let Some(ctx) = model.context_window {
model_metadata.insert("context_window".to_string(), Value::from(ctx));
}
model_metadata.insert(
"provider_tag".to_string(),
Value::String(Self::provider_tag(&provider_id)),
);
let provider_model = ProviderModelInfo {
name: model.id.clone(),
@@ -8814,10 +8861,11 @@ impl ChatApp {
for provider in &self.available_providers {
let provider_lower = provider.to_ascii_lowercase();
let provider_display = Self::provider_display_name(provider);
let provider_status = self.provider_overall_status(provider);
let provider_type = self.provider_type_for(provider);
let provider_highlight = if search_active {
search_candidate(provider, &search_query).map(|(_, mask)| mask)
search_candidate(provider_display.as_str(), &search_query).map(|(_, mask)| mask)
} else {
None
};
@@ -9059,16 +9107,11 @@ impl ChatApp {
}
};
consider(
if model.name.trim().is_empty() {
None
} else {
Some(model.name.as_str())
},
&mut info.name,
);
let display_name = Self::display_name_for_model(model);
consider(Some(display_name.as_str()), &mut info.name);
consider(Some(model.id.as_str()), &mut info.id);
consider(Some(provider), &mut info.provider);
let provider_display = Self::provider_display_name(provider);
consider(Some(provider_display.as_str()), &mut info.provider);
if let Some(desc) = model.description.as_deref() {
consider(Some(desc), &mut info.description);
}
@@ -9188,7 +9231,7 @@ impl ChatApp {
}
}
fn provider_display_name(provider: &str) -> String {
pub(crate) fn provider_display_name(provider: &str) -> String {
if provider.trim().is_empty() {
return "Provider".to_string();
}
@@ -9196,6 +9239,14 @@ impl ChatApp {
capitalize_first(normalized.as_str())
}
fn provider_tag(provider: &str) -> String {
match provider.trim().to_ascii_lowercase().as_str() {
"ollama" | "ollama_local" => "ollama".to_string(),
"ollama-cloud" | "ollama_cloud" => "ollama-cloud".to_string(),
other => other.to_string(),
}
}
fn infer_provider_type(provider: &str, scope: &ModelScope) -> ProviderType {
match scope {
ModelScope::Local => ProviderType::Local,
@@ -9425,22 +9476,17 @@ impl ChatApp {
}
async fn switch_to_provider(&mut self, provider_name: &str) -> Result<()> {
if self.current_provider == provider_name {
let canonical_name = Self::canonical_provider_id(provider_name);
if Self::canonical_provider_id(&self.current_provider) == canonical_name {
return Ok(());
}
use owlen_core::config::McpServerConfig;
use std::collections::HashMap;
let canonical_name = if provider_name.eq_ignore_ascii_case("ollama-cloud") {
"ollama"
} else {
provider_name
};
if self.controller.config().provider(canonical_name).is_none() {
if self.controller.config().provider(&canonical_name).is_none() {
let mut guard = self.controller.config_mut();
config::ensure_provider_config(&mut guard, canonical_name);
config::ensure_provider_config(&mut guard, &canonical_name);
}
let workspace_root = std::path::Path::new(env!("CARGO_MANIFEST_DIR"))
@@ -9458,11 +9504,11 @@ impl ChatApp {
});
let mut env_vars = HashMap::new();
env_vars.insert("OWLEN_PROVIDER".to_string(), canonical_name.to_string());
env_vars.insert("OWLEN_PROVIDER".to_string(), canonical_name.clone());
let provider: Arc<dyn owlen_core::Provider> = if let Some(path) = server_binary {
let config = McpServerConfig {
name: canonical_name.to_string(),
name: canonical_name.clone(),
command: path.to_string_lossy().into_owned(),
args: Vec::new(),
transport: "stdio".to_string(),
@@ -9475,7 +9521,7 @@ impl ChatApp {
};
self.controller.switch_provider(provider).await?;
self.current_provider = provider_name.to_string();
self.current_provider = canonical_name;
self.model_details_cache.clear();
self.model_info_panel.clear();
self.set_model_info_visible(false);
@@ -10146,6 +10192,10 @@ impl ChatApp {
Ok(())
}
Ok(Err(err)) => {
self.stop_loading_animation();
if self.handle_provider_error(&err).await? {
return Ok(());
}
let message = err.to_string();
if message.to_lowercase().contains("not found") {
self.error = Some(
@@ -10159,7 +10209,6 @@ impl ChatApp {
self.error = Some(message);
self.status = "Request failed".to_string();
}
self.stop_loading_animation();
Ok(())
}
Err(_) => {
@@ -10171,6 +10220,87 @@ impl ChatApp {
}
}
async fn handle_provider_error(&mut self, err: &CoreError) -> Result<bool> {
let current_provider = Self::canonical_provider_id(&self.current_provider);
if current_provider != "ollama_cloud" {
return Ok(false);
}
match err {
CoreError::Auth(message) => {
self.push_toast(
ToastLevel::Error,
"Cloud key invalid; using local provider.",
);
let switch_result = self.switch_to_provider("ollama_local").await;
if let Err(switch_err) = switch_result {
let detail = format!(
"Cloud key invalid and local fallback failed: {}",
switch_err
);
self.error = Some(detail.clone());
self.status = "Cloud authentication failed".to_string();
self.push_toast(ToastLevel::Error, detail);
} else {
self.selected_provider = "ollama_local".to_string();
self.expanded_provider = Some("ollama_local".to_string());
self.update_selected_provider_index();
{
let mut cfg = self.controller.config_mut();
cfg.general.default_provider = "ollama_local".to_string();
}
let save_result = {
let cfg = self.controller.config();
config::save_config(&cfg)
};
if let Err(save_err) = save_result {
self.push_toast(
ToastLevel::Warning,
format!(
"Fell back to local provider, but failed to save config: {}",
save_err
),
);
}
if let Err(refresh_err) = self.refresh_models().await {
self.push_toast(
ToastLevel::Warning,
format!("Failed to refresh local models: {}", refresh_err),
);
}
self.status =
"Cloud authentication failed; using local provider instead.".to_string();
self.error = Some(format!(
"Cloud key invalid: {}. Update your credentials and reselect the cloud provider.",
message
));
self.push_toast(ToastLevel::Info, "Switched back to local provider.");
}
Ok(true)
}
CoreError::Network(message) => {
let lower = message.to_ascii_lowercase();
if message.contains("429")
|| lower.contains("too many requests")
|| lower.contains("rate limit")
{
self.error = Some("Cloud rate limit hit; retry later.".to_string());
self.status = "Cloud rate limit hit".to_string();
self.push_toast(ToastLevel::Warning, "Cloud rate limit hit; retry later.");
return Ok(true);
}
Ok(false)
}
_ => Ok(false),
}
}
async fn process_agent_request(&mut self) -> Result<()> {
use owlen_core::agent::{AgentConfig, AgentExecutor};
use owlen_core::mcp::remote_client::RemoteMcpClient;

View File

@@ -211,8 +211,9 @@ pub fn render_model_picker(frame: &mut Frame<'_>, app: &ChatApp) {
let mut spans = Vec::new();
spans.push(status_icon(*status, theme));
spans.push(Span::raw(" "));
let display_name = ChatApp::provider_display_name(provider);
let header_spans = render_highlighted_text(
provider,
display_name.as_str(),
if search_active {
app.provider_search_highlight(provider)
} else {
@@ -509,32 +510,15 @@ fn build_model_selector_lines<'a>(
spans.push(Span::raw(" "));
let name_style = Style::default().fg(theme.text).add_modifier(Modifier::BOLD);
let id_style = Style::default()
.fg(theme.placeholder)
.add_modifier(Modifier::DIM);
let name_trimmed = model.name.trim();
if !name_trimmed.is_empty() {
let display_name = ChatApp::display_name_for_model(model);
if !display_name.trim().is_empty() {
let name_spans = render_highlighted_text(
name_trimmed,
display_name.as_str(),
search.info.and_then(|info| info.name.as_ref()),
name_style,
search.highlight_style,
);
spans.extend(name_spans);
if !model.id.eq_ignore_ascii_case(name_trimmed) {
spans.push(Span::raw(" "));
spans.push(Span::styled("·", Style::default().fg(theme.placeholder)));
spans.push(Span::raw(" "));
let id_spans = render_highlighted_text(
model.id.as_str(),
search.info.and_then(|info| info.id.as_ref()),
id_style,
search.highlight_style,
);
spans.extend(id_spans);
}
} else {
let id_spans = render_highlighted_text(
model.id.as_str(),
@@ -580,6 +564,13 @@ fn build_model_selector_lines<'a>(
push_meta(scope_label.clone());
}
let provider_label = ChatApp::provider_display_name(&model.provider);
push_meta(format!("provider {}", provider_label));
if !display_name.trim().eq_ignore_ascii_case(model.id.trim()) {
push_meta(format!("id {}", model.id));
}
if let Some(detail) = detail {
if let Some(ctx) = detail.context_length {
push_meta(format!("max tokens {}", ctx));