[feat] improve model selection prompt with two-stage filtering and formatted output; add corresponding tests

This commit is contained in:
2025-08-08 13:21:51 +02:00
parent 53a7471b99
commit 1cad6d593d
2 changed files with 147 additions and 43 deletions

View File

@@ -8,7 +8,7 @@
- [x] for merging (command line flag) -> if not present, treat each file as separate output (--merge | -m) - [x] for merging (command line flag) -> if not present, treat each file as separate output (--merge | -m)
- [x] for merge + separate output -> if present, treat each file as separate output and also output a merged version (--merge-and-separate) - [x] for merge + separate output -> if present, treat each file as separate output and also output a merged version (--merge-and-separate)
- [x] set speaker-names per input-file -> prompt user for each file if flag is set (--set-speaker-names) - [x] set speaker-names per input-file -> prompt user for each file if flag is set (--set-speaker-names)
- fix cli output for model display - [x] fix cli output for model display
- refactor into proper cli app - refactor into proper cli app
- add support for video files -> use ffmpeg to extract audio - add support for video files -> use ffmpeg to extract audio
- detect gpus and use them - detect gpus and use them

View File

@@ -270,59 +270,131 @@ fn fetch_all_models(client: &Client) -> Result<Vec<ModelEntry>> {
} }
fn prompt_select_models(models: &[ModelEntry]) -> Result<Vec<ModelEntry>> { fn format_model_list(models: &[ModelEntry]) -> String {
// Build a flat list but show group headers; indices count only models let mut out = String::new();
println!("Available ggml Whisper models:"); out.push_str("Available ggml Whisper models:\n");
// Compute alignment widths
let idx_width = std::cmp::max(2, models.len().to_string().len());
let name_width = models.iter().map(|m| m.name.len()).max().unwrap_or(0);
let mut idx = 1usize; let mut idx = 1usize;
let mut current = "".to_string(); let mut current = String::new();
// We'll record mapping from index -> position in models for m in models.iter() {
let mut index_map: Vec<usize> = Vec::with_capacity(models.len());
for (pos, m) in models.iter().enumerate() {
if m.base != current { if m.base != current {
current = m.base.clone(); current = m.base.clone();
println!("\n{}:", current); out.push_str("\n");
out.push_str(&format!("{}:\n", current));
} }
let short_hash = m // Format without hash and with aligned columns
.sha256 out.push_str(&format!(
.as_ref() " {i:>iw$}) {name:<nw$} [{repo} | {size}]\n",
.map(|h| h.chars().take(8).collect::<String>()) i = idx,
.unwrap_or_else(|| "-".to_string()); iw = idx_width,
println!(" {}) {} [{} | {} | {}]", idx, m.name, m.repo, human_size(m.size), short_hash); name = m.name,
index_map.push(pos); nw = name_width,
repo = m.repo,
size = human_size(m.size),
));
idx += 1; idx += 1;
} }
println!("\nEnter selection by indices (e.g., 1 3 5-7), or 'all', '*' for all, 'q' to cancel."); out.push_str("\nEnter selection by indices (e.g., 1 3 5-7), or 'all', '*' for all, 'q' to cancel.\n");
out
}
fn prompt_select_models_two_stage(models: &[ModelEntry]) -> Result<Vec<ModelEntry>> {
// 1) Choose base (tiny, small, medium, etc.)
let mut bases: Vec<String> = Vec::new();
let mut last = String::new();
for m in models.iter() {
if m.base != last {
// models are sorted by base; avoid duplicates while preserving order
if bases.last().map(|b| b == &m.base).unwrap_or(false) == false {
bases.push(m.base.clone());
}
last = m.base.clone();
}
}
if bases.is_empty() {
return Ok(Vec::new());
}
// Print base selection on stderr
eprintln!("Available base model families:");
for (i, b) in bases.iter().enumerate() {
eprintln!(" {}) {}", i + 1, b);
}
loop { loop {
eprint!("Selection: "); eprint!("Select base (number or name, 'q' to cancel): ");
io::stderr().flush().ok(); io::stderr().flush().ok();
let mut line = String::new(); let mut line = String::new();
io::stdin().read_line(&mut line).context("Failed to read selection")?; io::stdin().read_line(&mut line).context("Failed to read base selection")?;
let s = line.trim().to_lowercase(); let s = line.trim();
if s == "q" || s == "quit" || s == "exit" { return Ok(Vec::new()); } if s.eq_ignore_ascii_case("q") || s.eq_ignore_ascii_case("quit") || s.eq_ignore_ascii_case("exit") {
let mut selected: Vec<usize> = Vec::new(); return Ok(Vec::new());
if s == "all" || s == "*" {
selected = (1..idx).collect();
} else if !s.is_empty() {
for part in s.split(|c| c == ',' || c == ' ' || c == ';') {
let part = part.trim();
if part.is_empty() { continue; }
if let Some((a, b)) = part.split_once('-') {
if let (Ok(ia), Ok(ib)) = (a.parse::<usize>(), b.parse::<usize>()) {
if ia >= 1 && ib < idx && ia <= ib { selected.extend(ia..=ib); }
}
} else if let Ok(i) = part.parse::<usize>() {
if i >= 1 && i < idx { selected.push(i); }
}
}
} }
selected.sort_unstable(); let chosen_base = if let Ok(i) = s.parse::<usize>() {
selected.dedup(); if i >= 1 && i <= bases.len() { Some(bases[i - 1].clone()) } else { None }
if selected.is_empty() { } else if !s.is_empty() {
eprintln!("No valid selection. Please try again or 'q' to cancel."); // accept exact name match (case-insensitive)
bases.iter().find(|b| b.eq_ignore_ascii_case(s)).cloned()
} else { None };
if let Some(base) = chosen_base {
// 2) Choose sub-type(s) within that base
let filtered: Vec<ModelEntry> = models.iter().filter(|m| m.base == base).cloned().collect();
if filtered.is_empty() {
eprintln!("No models found for base '{}'.", base);
continue;
}
// Reuse the formatter but only for the chosen base list
let listing = format_model_list(&filtered);
eprint!("{}", listing);
// Build index map for filtered list
let mut index_map: Vec<usize> = Vec::with_capacity(filtered.len());
let mut idx = 1usize;
for (pos, _m) in filtered.iter().enumerate() {
index_map.push(pos);
idx += 1;
}
// Second prompt: sub-type selection
loop {
eprint!("Selection: ");
io::stderr().flush().ok();
let mut line2 = String::new();
io::stdin().read_line(&mut line2).context("Failed to read selection")?;
let s2 = line2.trim().to_lowercase();
if s2 == "q" || s2 == "quit" || s2 == "exit" { return Ok(Vec::new()); }
let mut selected: Vec<usize> = Vec::new();
if s2 == "all" || s2 == "*" {
selected = (1..idx).collect();
} else if !s2.is_empty() {
for part in s2.split(|c| c == ',' || c == ' ' || c == ';') {
let part = part.trim();
if part.is_empty() { continue; }
if let Some((a, b)) = part.split_once('-') {
if let (Ok(ia), Ok(ib)) = (a.parse::<usize>(), b.parse::<usize>()) {
if ia >= 1 && ib < idx && ia <= ib { selected.extend(ia..=ib); }
}
} else if let Ok(i) = part.parse::<usize>() {
if i >= 1 && i < idx { selected.push(i); }
}
}
}
selected.sort_unstable();
selected.dedup();
if selected.is_empty() {
eprintln!("No valid selection. Please try again or 'q' to cancel.");
continue;
}
let chosen: Vec<ModelEntry> = selected.into_iter().map(|i| filtered[index_map[i - 1]].clone()).collect();
return Ok(chosen);
}
} else {
eprintln!("Invalid base selection. Please enter a number from 1-{} or a base name.", bases.len());
continue; continue;
} }
let chosen: Vec<ModelEntry> = selected.into_iter().map(|i| models[index_map[i - 1]].clone()).collect();
return Ok(chosen);
} }
} }
@@ -385,7 +457,7 @@ pub fn run_interactive_model_downloader() -> Result<()> {
eprintln!("No models found on Hugging Face listing. Please try again later."); eprintln!("No models found on Hugging Face listing. Please try again later.");
return Ok(()); return Ok(());
} }
let selected = prompt_select_models(&models)?; let selected = prompt_select_models_two_stage(&models)?;
if selected.is_empty() { if selected.is_empty() {
eprintln!("No selection. Aborting download."); eprintln!("No selection. Aborting download.");
return Ok(()); return Ok(());
@@ -631,6 +703,38 @@ mod tests {
use std::fs; use std::fs;
use std::io::Write; use std::io::Write;
#[test]
fn test_format_model_list_spacing_and_structure() {
let models = vec![
ModelEntry { name: "tiny.en-q5_1".to_string(), base: "tiny".to_string(), subtype: "en-q5_1".to_string(), size: 1024*1024, sha256: Some("0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdef".to_string()), repo: "ggerganov/whisper.cpp".to_string() },
ModelEntry { name: "tiny-q5_1".to_string(), base: "tiny".to_string(), subtype: "q5_1".to_string(), size: 2048, sha256: None, repo: "ggerganov/whisper.cpp".to_string() },
ModelEntry { name: "base.en-q5_1".to_string(), base: "base".to_string(), subtype: "en-q5_1".to_string(), size: 10, sha256: Some("aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa".to_string()), repo: "akashmjn/tinydiarize-whisper.cpp".to_string() },
];
let s = format_model_list(&models);
// Header present
assert!(s.starts_with("Available ggml Whisper models:\n"));
// Group headers and blank line before header
assert!(s.contains("\ntiny:\n"));
assert!(s.contains("\nbase:\n"));
// No immediate double space before a bracket after parenthesis
assert!(!s.contains(") ["), "should not have double space immediately before bracket");
// Lines contain normalized spacing around pipes and no hash
assert!(s.contains("[ggerganov/whisper.cpp | 1.00 MiB]"));
assert!(s.contains("[ggerganov/whisper.cpp | 2.00 KiB]"));
// Verify alignment: the '[' position should match across multiple lines
let bracket_positions: Vec<usize> = s
.lines()
.filter(|l| l.contains("ggerganov/whisper.cpp"))
.map(|l| l.find('[').unwrap())
.collect();
assert!(bracket_positions.len() >= 2);
for w in bracket_positions.windows(2) {
assert_eq!(w[0], w[1], "bracket columns should align");
}
// Footer instruction present
assert!(s.contains("Enter selection by indices"));
}
fn sha256_hex(data: &[u8]) -> String { fn sha256_hex(data: &[u8]) -> String {
use sha2::{Digest, Sha256}; use sha2::{Digest, Sha256};
let mut hasher = Sha256::new(); let mut hasher = Sha256::new();