[feat] improve model selection prompt with two-stage filtering and formatted output; add corresponding tests

This commit is contained in:
2025-08-08 13:21:51 +02:00
parent 53a7471b99
commit 1cad6d593d
2 changed files with 147 additions and 43 deletions

View File

@@ -270,59 +270,131 @@ fn fetch_all_models(client: &Client) -> Result<Vec<ModelEntry>> {
}
fn prompt_select_models(models: &[ModelEntry]) -> Result<Vec<ModelEntry>> {
// Build a flat list but show group headers; indices count only models
println!("Available ggml Whisper models:");
fn format_model_list(models: &[ModelEntry]) -> String {
let mut out = String::new();
out.push_str("Available ggml Whisper models:\n");
// Compute alignment widths
let idx_width = std::cmp::max(2, models.len().to_string().len());
let name_width = models.iter().map(|m| m.name.len()).max().unwrap_or(0);
let mut idx = 1usize;
let mut current = "".to_string();
// We'll record mapping from index -> position in models
let mut index_map: Vec<usize> = Vec::with_capacity(models.len());
for (pos, m) in models.iter().enumerate() {
let mut current = String::new();
for m in models.iter() {
if m.base != current {
current = m.base.clone();
println!("\n{}:", current);
out.push_str("\n");
out.push_str(&format!("{}:\n", current));
}
let short_hash = m
.sha256
.as_ref()
.map(|h| h.chars().take(8).collect::<String>())
.unwrap_or_else(|| "-".to_string());
println!(" {}) {} [{} | {} | {}]", idx, m.name, m.repo, human_size(m.size), short_hash);
index_map.push(pos);
// Format without hash and with aligned columns
out.push_str(&format!(
" {i:>iw$}) {name:<nw$} [{repo} | {size}]\n",
i = idx,
iw = idx_width,
name = m.name,
nw = name_width,
repo = m.repo,
size = human_size(m.size),
));
idx += 1;
}
println!("\nEnter selection by indices (e.g., 1 3 5-7), or 'all', '*' for all, 'q' to cancel.");
out.push_str("\nEnter selection by indices (e.g., 1 3 5-7), or 'all', '*' for all, 'q' to cancel.\n");
out
}
fn prompt_select_models_two_stage(models: &[ModelEntry]) -> Result<Vec<ModelEntry>> {
// 1) Choose base (tiny, small, medium, etc.)
let mut bases: Vec<String> = Vec::new();
let mut last = String::new();
for m in models.iter() {
if m.base != last {
// models are sorted by base; avoid duplicates while preserving order
if bases.last().map(|b| b == &m.base).unwrap_or(false) == false {
bases.push(m.base.clone());
}
last = m.base.clone();
}
}
if bases.is_empty() {
return Ok(Vec::new());
}
// Print base selection on stderr
eprintln!("Available base model families:");
for (i, b) in bases.iter().enumerate() {
eprintln!(" {}) {}", i + 1, b);
}
loop {
eprint!("Selection: ");
eprint!("Select base (number or name, 'q' to cancel): ");
io::stderr().flush().ok();
let mut line = String::new();
io::stdin().read_line(&mut line).context("Failed to read selection")?;
let s = line.trim().to_lowercase();
if s == "q" || s == "quit" || s == "exit" { return Ok(Vec::new()); }
let mut selected: Vec<usize> = Vec::new();
if s == "all" || s == "*" {
selected = (1..idx).collect();
} else if !s.is_empty() {
for part in s.split(|c| c == ',' || c == ' ' || c == ';') {
let part = part.trim();
if part.is_empty() { continue; }
if let Some((a, b)) = part.split_once('-') {
if let (Ok(ia), Ok(ib)) = (a.parse::<usize>(), b.parse::<usize>()) {
if ia >= 1 && ib < idx && ia <= ib { selected.extend(ia..=ib); }
}
} else if let Ok(i) = part.parse::<usize>() {
if i >= 1 && i < idx { selected.push(i); }
}
}
io::stdin().read_line(&mut line).context("Failed to read base selection")?;
let s = line.trim();
if s.eq_ignore_ascii_case("q") || s.eq_ignore_ascii_case("quit") || s.eq_ignore_ascii_case("exit") {
return Ok(Vec::new());
}
selected.sort_unstable();
selected.dedup();
if selected.is_empty() {
eprintln!("No valid selection. Please try again or 'q' to cancel.");
let chosen_base = if let Ok(i) = s.parse::<usize>() {
if i >= 1 && i <= bases.len() { Some(bases[i - 1].clone()) } else { None }
} else if !s.is_empty() {
// accept exact name match (case-insensitive)
bases.iter().find(|b| b.eq_ignore_ascii_case(s)).cloned()
} else { None };
if let Some(base) = chosen_base {
// 2) Choose sub-type(s) within that base
let filtered: Vec<ModelEntry> = models.iter().filter(|m| m.base == base).cloned().collect();
if filtered.is_empty() {
eprintln!("No models found for base '{}'.", base);
continue;
}
// Reuse the formatter but only for the chosen base list
let listing = format_model_list(&filtered);
eprint!("{}", listing);
// Build index map for filtered list
let mut index_map: Vec<usize> = Vec::with_capacity(filtered.len());
let mut idx = 1usize;
for (pos, _m) in filtered.iter().enumerate() {
index_map.push(pos);
idx += 1;
}
// Second prompt: sub-type selection
loop {
eprint!("Selection: ");
io::stderr().flush().ok();
let mut line2 = String::new();
io::stdin().read_line(&mut line2).context("Failed to read selection")?;
let s2 = line2.trim().to_lowercase();
if s2 == "q" || s2 == "quit" || s2 == "exit" { return Ok(Vec::new()); }
let mut selected: Vec<usize> = Vec::new();
if s2 == "all" || s2 == "*" {
selected = (1..idx).collect();
} else if !s2.is_empty() {
for part in s2.split(|c| c == ',' || c == ' ' || c == ';') {
let part = part.trim();
if part.is_empty() { continue; }
if let Some((a, b)) = part.split_once('-') {
if let (Ok(ia), Ok(ib)) = (a.parse::<usize>(), b.parse::<usize>()) {
if ia >= 1 && ib < idx && ia <= ib { selected.extend(ia..=ib); }
}
} else if let Ok(i) = part.parse::<usize>() {
if i >= 1 && i < idx { selected.push(i); }
}
}
}
selected.sort_unstable();
selected.dedup();
if selected.is_empty() {
eprintln!("No valid selection. Please try again or 'q' to cancel.");
continue;
}
let chosen: Vec<ModelEntry> = selected.into_iter().map(|i| filtered[index_map[i - 1]].clone()).collect();
return Ok(chosen);
}
} else {
eprintln!("Invalid base selection. Please enter a number from 1-{} or a base name.", bases.len());
continue;
}
let chosen: Vec<ModelEntry> = selected.into_iter().map(|i| models[index_map[i - 1]].clone()).collect();
return Ok(chosen);
}
}
@@ -385,7 +457,7 @@ pub fn run_interactive_model_downloader() -> Result<()> {
eprintln!("No models found on Hugging Face listing. Please try again later.");
return Ok(());
}
let selected = prompt_select_models(&models)?;
let selected = prompt_select_models_two_stage(&models)?;
if selected.is_empty() {
eprintln!("No selection. Aborting download.");
return Ok(());
@@ -631,6 +703,38 @@ mod tests {
use std::fs;
use std::io::Write;
#[test]
fn test_format_model_list_spacing_and_structure() {
let models = vec![
ModelEntry { name: "tiny.en-q5_1".to_string(), base: "tiny".to_string(), subtype: "en-q5_1".to_string(), size: 1024*1024, sha256: Some("0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdef".to_string()), repo: "ggerganov/whisper.cpp".to_string() },
ModelEntry { name: "tiny-q5_1".to_string(), base: "tiny".to_string(), subtype: "q5_1".to_string(), size: 2048, sha256: None, repo: "ggerganov/whisper.cpp".to_string() },
ModelEntry { name: "base.en-q5_1".to_string(), base: "base".to_string(), subtype: "en-q5_1".to_string(), size: 10, sha256: Some("aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa".to_string()), repo: "akashmjn/tinydiarize-whisper.cpp".to_string() },
];
let s = format_model_list(&models);
// Header present
assert!(s.starts_with("Available ggml Whisper models:\n"));
// Group headers and blank line before header
assert!(s.contains("\ntiny:\n"));
assert!(s.contains("\nbase:\n"));
// No immediate double space before a bracket after parenthesis
assert!(!s.contains(") ["), "should not have double space immediately before bracket");
// Lines contain normalized spacing around pipes and no hash
assert!(s.contains("[ggerganov/whisper.cpp | 1.00 MiB]"));
assert!(s.contains("[ggerganov/whisper.cpp | 2.00 KiB]"));
// Verify alignment: the '[' position should match across multiple lines
let bracket_positions: Vec<usize> = s
.lines()
.filter(|l| l.contains("ggerganov/whisper.cpp"))
.map(|l| l.find('[').unwrap())
.collect();
assert!(bracket_positions.len() >= 2);
for w in bracket_positions.windows(2) {
assert_eq!(w[0], w[1], "bracket columns should align");
}
// Footer instruction present
assert!(s.contains("Enter selection by indices"));
}
fn sha256_hex(data: &[u8]) -> String {
use sha2::{Digest, Sha256};
let mut hasher = Sha256::new();