[feat] improve non-interactive model selection and enhance multi-stage TTY-based selection

This commit is contained in:
2025-08-12 11:02:49 +02:00
parent f41f1a4117
commit a987a3fcfb

View File

@@ -2,6 +2,8 @@
// Copyright (c) 2025 <COPYRIGHT HOLDER>. All rights reserved.
//! Model discovery, selection, and downloading logic for PolyScribe.
use std::any::Any;
use std::collections::BTreeMap;
use std::env;
use std::fs::{File, create_dir_all};
@@ -16,7 +18,7 @@ use serde::Deserialize;
use sha2::{Digest, Sha256};
use indicatif::{ProgressBar, ProgressStyle, MultiProgress};
use atty::Stream;
use clap::builder::Str;
// --- Model downloader: list & download ggml models from Hugging Face ---
#[derive(Debug, Deserialize)]
@@ -395,50 +397,116 @@ fn format_model_list(models: &[ModelEntry]) -> String {
}
/// Compatibility entry point for interactive model selection.
///
/// Forwards `models` unchanged to `prompt_select_models_cliclack` and returns
/// that function's result as-is; no selection logic lives here.
fn prompt_select_models_two_stage(models: &[ModelEntry]) -> Result<Vec<ModelEntry>> {
// Kept to preserve the original signature; the cliclack-based prompt does the actual work.
prompt_select_models_cliclack(models)
}
fn prompt_select_models_cliclack(models: &[ModelEntry]) -> Result<Vec<ModelEntry>> {
// Non-interactive: pick a sensible default or exit cleanly
if crate::is_no_interaction() || !crate::stdin_is_tty() {
// Do not prompt; an empty selection is returned only when no default model is found
// Prefer the default English base model (e.g., "base.en")
if let Some(default) = models.iter().find(|m| m.base == "base" && m.subtype == "en") {
ilog!("Non-Interactive: selecting default model {}", default.name);
return Ok(vec![default.clone()]);
}
// Fallback: any 'base' family model
if let Some(fallback) = models.iter().find(|m| m.base == "base") {
ilog!("Non-Interactive: selecting default model {}", fallback.name);
return Ok(vec![fallback.clone()]);
}
// Nothing sensible to pick
wlog!("No interactive selection possible and no default model found; skipping model selection.");
return Ok(Vec::new());
}
// Build grouped, aligned labels for selection items (include base prefix for grouping).
let mut item_labels: Vec<String> = Vec::new();
let mut item_model_indices: Vec<usize> = Vec::new();
// Compute widths for alignment
let name_width = models.iter().map(|m| m.name.len()).max().unwrap_or(0);
let base_width = models.iter().map(|m| m.base.len()).max().unwrap_or(0);
// Known Whisper base families in preferred ordering
let mut known_order: Vec<&str> = vec!["tiny", "small", "base", "medium", "large"];
// Collect available bases from the incoming list
use std::collections::{BTreeMap, BTreeSet};
let mut bases_available: BTreeSet<String> = BTreeSet::new();
let mut by_base: BTreeMap<String, Vec<(usize, &ModelEntry)>> = BTreeMap::new();
for (i, m) in models.iter().enumerate() {
bases_available.insert(m.base.clone());
by_base.entry(m.base.clone()).or_default().push((i, m));
}
// Filter known_order by what is available; append any unknown bases at the end (sorted)
let mut base_choices: Vec<String> = Vec::new();
for base in &known_order {
if bases_available.contains(*base) {
base_choices.push((*base).to_string());
}
}
for b in &bases_available {
if !known_order.iter().any(|k| k == b) {
base_choices.push(b.clone());
}
}
if base_choices.is_empty() {
wlog!("No models available to select from.");
return Ok(Vec::new());
}
// Build select items for bases
let base_prompt = "Choose a base model family";
let base_items: Vec<(String, String, String)> = base_choices
.iter()
.map(|b| {
let count = by_base.get(b).map(|v| v.len()).unwrap_or(0);
let label = format!("{b} ({count} variants)");
(b.clone(), label, String::new())
})
.collect();
let selected_base = match cliclack::select::<String>(base_prompt).items(&base_items).interact() {
Ok(val) => val,
Err(e) => {
wlog!("Selection canceled or failed: {}", e);
return Ok(Vec::new());
}
};
// Second stage: multiselect among the chosen base's variants
let Some(variants) = by_base.get(&selected_base) else {
wlog!("No variants found for base '{}'.", selected_base);
return Ok(Vec::new());
};
// Sort variants by subtype then name for stable presentation
let mut variants_sorted = variants.clone();
variants_sorted.sort_by(|a, b| {
let (_, ma) = a;
let (_, mb) = b;
ma.subtype.cmp(&mb.subtype)
.then(ma.name.cmp(&mb.name))
.then(ma.repo.cmp(&mb.repo))
});
// Build multiselect items where value is the original index into 'models'
let name_width = variants_sorted.iter()
.map(|(_, m)| m.name.len())
.max()
.unwrap_or(0);
let prompt = format!(
"select {base} variant(s) (↑/↓ move, space toggle, enter confirm)", base = selected_base
);
let mut items: Vec<(usize, String, String)> = Vec::with_capacity(variants_sorted.len());
for (idx, m) in variants_sorted.iter() {
let label = format!(
"{base:<bw$}: {name:<nw$} [{repo} | {size}]",
base = m.base,
bw = base_width,
"{name:<nw$} [{repo} | {size}]",
name = m.name,
nw = name_width,
repo = m.repo,
size = human_size(m.size)
);
item_labels.push(label);
item_model_indices.push(i);
items.push((*idx, label, String::new()));
}
// Use cliclack multiselect builder with (value, label, help) tuples.
let prompt = "Select Whisper model(s) to download (↑/↓ move, space toggle, enter confirm)";
let mut items: Vec<(usize, String, String)> = Vec::with_capacity(item_labels.len());
for (idx, label) in item_labels.iter().cloned().enumerate() {
items.push((item_model_indices[idx], label, String::new()));
}
match cliclack::multiselect::<usize>(prompt)
.items(&items)
.interact()
{
match cliclack::multiselect::<usize>(&prompt).items(&items).interact() {
Ok(selected_indices) => {
let mut chosen: Vec<ModelEntry> = Vec::new();
if selected_indices.is_empty() {
ilog!("No variants selected; nothing to download.");
return Ok(Vec::new());
}
let mut chosen: Vec<ModelEntry> = Vec::with_capacity(selected_indices.len());
for mi in selected_indices {
if let Some(m) = models.get(mi) {
chosen.push(m.clone());
@@ -447,7 +515,6 @@ fn prompt_select_models_cliclack(models: &[ModelEntry]) -> Result<Vec<ModelEntry
Ok(chosen)
}
Err(e) => {
// If interaction fails (e.g., not a TTY), return empty to gracefully skip
wlog!("Selection canceled or failed: {}", e);
Ok(Vec::new())
}