[feat] improve non-interactive model selection and enhance multi-stage TTY-based selection
This commit is contained in:
129
src/models.rs
129
src/models.rs
@@ -2,6 +2,8 @@
|
||||
// Copyright (c) 2025 <COPYRIGHT HOLDER>. All rights reserved.
|
||||
|
||||
//! Model discovery, selection, and downloading logic for PolyScribe.
|
||||
|
||||
use std::any::Any;
|
||||
use std::collections::BTreeMap;
|
||||
use std::env;
|
||||
use std::fs::{File, create_dir_all};
|
||||
@@ -16,7 +18,7 @@ use serde::Deserialize;
|
||||
use sha2::{Digest, Sha256};
|
||||
use indicatif::{ProgressBar, ProgressStyle, MultiProgress};
|
||||
use atty::Stream;
|
||||
|
||||
use clap::builder::Str;
|
||||
// --- Model downloader: list & download ggml models from Hugging Face ---
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
@@ -395,50 +397,116 @@ fn format_model_list(models: &[ModelEntry]) -> String {
|
||||
}
|
||||
|
||||
fn prompt_select_models_two_stage(models: &[ModelEntry]) -> Result<Vec<ModelEntry>> {
|
||||
// Replaced by cliclack-based multiselect; keep function to preserve signature but delegate.
|
||||
prompt_select_models_cliclack(models)
|
||||
}
|
||||
|
||||
fn prompt_select_models_cliclack(models: &[ModelEntry]) -> Result<Vec<ModelEntry>> {
|
||||
// Non-interactive: pick a sensible default or exit cleanly
|
||||
if crate::is_no_interaction() || !crate::stdin_is_tty() {
|
||||
// Non-interactive: do not prompt, return empty selection to skip
|
||||
// Prefer the default English base model (e.g., "base.en")
|
||||
if let Some(default) = models.iter().find(|m| m.base == "base" && m.subtype == "en") {
|
||||
ilog!("Non-Interactive: selecting default model {}", default.name);
|
||||
return Ok(vec![default.clone()]);
|
||||
}
|
||||
// Fallback: any 'base' family model
|
||||
if let Some(fallback) = models.iter().find(|m| m.base == "base") {
|
||||
ilog!("Non-Interactive: selecting default model {}", fallback.name);
|
||||
return Ok(vec![fallback.clone()]);
|
||||
}
|
||||
// Nothing sensible to pick
|
||||
wlog!("No interactive selection possible and no default model found; skipping model selection.");
|
||||
return Ok(Vec::new());
|
||||
}
|
||||
|
||||
// Build grouped, aligned labels for selection items (include base prefix for grouping).
|
||||
let mut item_labels: Vec<String> = Vec::new();
|
||||
let mut item_model_indices: Vec<usize> = Vec::new();
|
||||
|
||||
// Compute widths for alignment
|
||||
let name_width = models.iter().map(|m| m.name.len()).max().unwrap_or(0);
|
||||
let base_width = models.iter().map(|m| m.base.len()).max().unwrap_or(0);
|
||||
// Know Whisper base families in preferred ordering
|
||||
let mut known_order: Vec<&str> = vec!["tiny", "small", "base", "medium", "large"];
|
||||
// Collect available bases from the incoming list
|
||||
use std::collections::{BTreeMap, BTreeSet};
|
||||
|
||||
let mut bases_available: BTreeSet<String> = BTreeSet::new();
|
||||
let mut by_base: BTreeMap<String, Vec<(usize, &ModelEntry)>> = BTreeMap::new();
|
||||
for (i, m) in models.iter().enumerate() {
|
||||
bases_available.insert(m.base.clone());
|
||||
by_base.entry(m.base.clone()).or_default().push((i, m));
|
||||
}
|
||||
|
||||
// Filter known_order by what is available; append any unknown bases at the end (sorted)
|
||||
let mut base_choices: Vec<String> = Vec::new();
|
||||
for base in &known_order {
|
||||
if bases_available.contains(*base) {
|
||||
base_choices.push((*base).to_string());
|
||||
}
|
||||
}
|
||||
for b in &bases_available {
|
||||
if !known_order.iter().any(|k| k == b) {
|
||||
base_choices.push(b.clone());
|
||||
}
|
||||
}
|
||||
if base_choices.is_empty() {
|
||||
wlog!("No models available to select from.");
|
||||
return Ok(Vec::new());
|
||||
}
|
||||
|
||||
// Build select items for bases
|
||||
let base_prompt = "Choose a base model family";
|
||||
let base_items: Vec<(String, String, String)> = base_choices
|
||||
.iter()
|
||||
.map(|b| {
|
||||
let count = by_base.get(b).map(|v| v.len()).unwrap_or(0);
|
||||
let label = format!("{b} ({count} variants)");
|
||||
(b.clone(), label, String::new())
|
||||
})
|
||||
.collect();
|
||||
|
||||
let selected_base = match cliclack::select::<String>(base_prompt).items(&base_items).interact() {
|
||||
Ok(val) => val,
|
||||
Err(e) => {
|
||||
wlog!("Selection canceled or failed: {}", e);
|
||||
return Ok(Vec::new());
|
||||
}
|
||||
};
|
||||
|
||||
// Second stage: multiselect among the chosen base's variants
|
||||
let Some(variants) = by_base.get(&selected_base) else {
|
||||
wlog!("No variants found for base '{}'.", selected_base);
|
||||
return Ok(Vec::new());
|
||||
};
|
||||
|
||||
// Sort variants by subtype then name for stable presentation
|
||||
let mut variants_sorted = variants.clone();
|
||||
variants_sorted.sort_by(|a, b| {
|
||||
let (_, ma) = a;
|
||||
let (_, mb) = b;
|
||||
ma.subtype.cmp(&mb.subtype)
|
||||
.then(ma.name.cmp(&mb.name))
|
||||
.then(ma.repo.cmp(&mb.repo))
|
||||
});
|
||||
|
||||
// Build multiselect items where value is the original index into 'models'
|
||||
let name_width = variants_sorted.iter()
|
||||
.map(|(_, m)| m.name.len())
|
||||
.max()
|
||||
.unwrap_or(0);
|
||||
|
||||
let prompt = format!(
|
||||
"select {base} variant(s) (↑/↓ move, space toggle, enter confirm)", base = selected_base
|
||||
);
|
||||
|
||||
let mut items: Vec<(usize, String, String)> = Vec::with_capacity(variants_sorted.len());
|
||||
for (idx, m) in variants_sorted.iter() {
|
||||
let label = format!(
|
||||
"{base:<bw$}: {name:<nw$} [{repo} | {size}]",
|
||||
base = m.base,
|
||||
bw = base_width,
|
||||
"{name:<nw$} [{repo} | {size}]",
|
||||
name = m.name,
|
||||
nw = name_width,
|
||||
repo = m.repo,
|
||||
size = human_size(m.size)
|
||||
);
|
||||
item_labels.push(label);
|
||||
item_model_indices.push(i);
|
||||
items.push((*idx, label, String::new()));
|
||||
}
|
||||
|
||||
// Use cliclack multiselect builder with (value, label, help) tuples.
|
||||
let prompt = "Select Whisper model(s) to download (↑/↓ move, space toggle, enter confirm)";
|
||||
let mut items: Vec<(usize, String, String)> = Vec::with_capacity(item_labels.len());
|
||||
for (idx, label) in item_labels.iter().cloned().enumerate() {
|
||||
items.push((item_model_indices[idx], label, String::new()));
|
||||
}
|
||||
match cliclack::multiselect::<usize>(prompt)
|
||||
.items(&items)
|
||||
.interact()
|
||||
{
|
||||
match cliclack::multiselect::<usize>(&prompt).items(&items).interact() {
|
||||
Ok(selected_indices) => {
|
||||
let mut chosen: Vec<ModelEntry> = Vec::new();
|
||||
if selected_indices.is_empty() {
|
||||
ilog!("No variants selected; nothing to download.");
|
||||
return Ok(Vec::new());
|
||||
}
|
||||
let mut chosen: Vec<ModelEntry> = Vec::with_capacity(selected_indices.len());
|
||||
for mi in selected_indices {
|
||||
if let Some(m) = models.get(mi) {
|
||||
chosen.push(m.clone());
|
||||
@@ -447,7 +515,6 @@ fn prompt_select_models_cliclack(models: &[ModelEntry]) -> Result<Vec<ModelEntry
|
||||
Ok(chosen)
|
||||
}
|
||||
Err(e) => {
|
||||
// If interaction fails (e.g., not a TTY), return empty to gracefully skip
|
||||
wlog!("Selection canceled or failed: {}", e);
|
||||
Ok(Vec::new())
|
||||
}
|
||||
|
Reference in New Issue
Block a user