diff --git a/src/models.rs b/src/models.rs index 8139f19..caad243 100644 --- a/src/models.rs +++ b/src/models.rs @@ -2,6 +2,8 @@ // Copyright (c) 2025 . All rights reserved. //! Model discovery, selection, and downloading logic for PolyScribe. + +use std::any::Any; use std::collections::BTreeMap; use std::env; use std::fs::{File, create_dir_all}; @@ -16,7 +18,7 @@ use serde::Deserialize; use sha2::{Digest, Sha256}; use indicatif::{ProgressBar, ProgressStyle, MultiProgress}; use atty::Stream; - +use clap::builder::Str; // --- Model downloader: list & download ggml models from Hugging Face --- #[derive(Debug, Deserialize)] @@ -395,50 +397,116 @@ fn format_model_list(models: &[ModelEntry]) -> String { } fn prompt_select_models_two_stage(models: &[ModelEntry]) -> Result> { - // Replaced by cliclack-based multiselect; keep function to preserve signature but delegate. - prompt_select_models_cliclack(models) -} - -fn prompt_select_models_cliclack(models: &[ModelEntry]) -> Result> { + // Non-interactive: pick a sensible default or exit cleanly if crate::is_no_interaction() || !crate::stdin_is_tty() { - // Non-interactive: do not prompt, return empty selection to skip + // Prefer the default English base model (e.g., "base.en") + if let Some(default) = models.iter().find(|m| m.base == "base" && m.subtype == "en") { + ilog!("Non-Interactive: selecting default model {}", default.name); + return Ok(vec![default.clone()]); + } + // Fallback: any 'base' family model + if let Some(fallback) = models.iter().find(|m| m.base == "base") { + ilog!("Non-Interactive: selecting default model {}", fallback.name); + return Ok(vec![fallback.clone()]); + } + // Nothing sensible to pick + wlog!("No interactive selection possible and no default model found; skipping model selection."); return Ok(Vec::new()); } - // Build grouped, aligned labels for selection items (include base prefix for grouping). - let mut item_labels: Vec = Vec::new(); - let mut item_model_indices: Vec = Vec::new(); - - // Compute widths for alignment - let name_width = models.iter().map(|m| m.name.len()).max().unwrap_or(0); - let base_width = models.iter().map(|m| m.base.len()).max().unwrap_or(0); + // Know Whisper base families in preferred ordering + let mut known_order: Vec<&str> = vec!["tiny", "small", "base", "medium", "large"]; + // Collect available bases from the incoming list + use std::collections::{BTreeMap, BTreeSet}; + let mut bases_available: BTreeSet = BTreeSet::new(); + let mut by_base: BTreeMap> = BTreeMap::new(); for (i, m) in models.iter().enumerate() { + bases_available.insert(m.base.clone()); + by_base.entry(m.base.clone()).or_default().push((i, m)); + } + + // Filter known_order by what is available; append any unknown bases at the end (sorted) + let mut base_choices: Vec = Vec::new(); + for base in &known_order { + if bases_available.contains(*base) { + base_choices.push((*base).to_string()); + } + } + for b in &bases_available { + if !known_order.iter().any(|k| k == b) { + base_choices.push(b.clone()); + } + } + if base_choices.is_empty() { + wlog!("No models available to select from."); + return Ok(Vec::new()); + } + + // Build select items for bases + let base_prompt = "Choose a base model family"; + let base_items: Vec<(String, String, String)> = base_choices + .iter() + .map(|b| { + let count = by_base.get(b).map(|v| v.len()).unwrap_or(0); + let label = format!("{b} ({count} variants)"); + (b.clone(), label, String::new()) + }) + .collect(); + + let selected_base = match cliclack::select::(base_prompt).items(&base_items).interact() { + Ok(val) => val, + Err(e) => { + wlog!("Selection canceled or failed: {}", e); + return Ok(Vec::new()); + } + }; + + // Second stage: multiselect among the chosen base's variants + let Some(variants) = by_base.get(&selected_base) else { + wlog!("No variants found for base '{}'.", selected_base); + return Ok(Vec::new()); + }; + + // Sort variants by subtype then name for stable presentation + let mut variants_sorted = variants.clone(); + variants_sorted.sort_by(|a, b| { + let (_, ma) = a; + let (_, mb) = b; + ma.subtype.cmp(&mb.subtype) + .then(ma.name.cmp(&mb.name)) + .then(ma.repo.cmp(&mb.repo)) + }); + + // Build multiselect items where value is the original index into 'models' + let name_width = variants_sorted.iter() + .map(|(_, m)| m.name.len()) + .max() + .unwrap_or(0); + + let prompt = format!( + "select {base} variant(s) (↑/↓ move, space toggle, enter confirm)", base = selected_base + ); + + let mut items: Vec<(usize, String, String)> = Vec::with_capacity(variants_sorted.len()); + for (idx, m) in variants_sorted.iter() { let label = format!( - "{base: = Vec::with_capacity(item_labels.len()); - for (idx, label) in item_labels.iter().cloned().enumerate() { - items.push((item_model_indices[idx], label, String::new())); - } - match cliclack::multiselect::(prompt) - .items(&items) - .interact() - { + match cliclack::multiselect::(&prompt).items(&items).interact() { Ok(selected_indices) => { - let mut chosen: Vec = Vec::new(); + if selected_indices.is_empty() { + ilog!("No variants selected; nothing to download."); + return Ok(Vec::new()); + } + let mut chosen: Vec = Vec::with_capacity(selected_indices.len()); for mi in selected_indices { if let Some(m) = models.get(mi) { chosen.push(m.clone()); @@ -447,7 +515,6 @@ fn prompt_select_models_cliclack(models: &[ModelEntry]) -> Result { - // If interaction fails (e.g., not a TTY), return empty to gracefully skip wlog!("Selection canceled or failed: {}", e); Ok(Vec::new()) }