[feat] improve model selection prompt with two-stage filtering and formatted output; add corresponding tests
This commit is contained in:
2
TODO.md
2
TODO.md
@@ -8,7 +8,7 @@
|
||||
- [x] for merging (command line flag) -> if not present, treat each file as separate output (--merge | -m)
|
||||
- [x] for merge + separate output -> if present, treat each file as separate output and also output a merged version (--merge-and-separate)
|
||||
- [x] set speaker-names per input-file -> prompt user for each file if flag is set (--set-speaker-names)
|
||||
- fix cli output for model display
|
||||
- [x] fix cli output for model display
|
||||
- refactor into proper cli app
|
||||
- add support for video files -> use ffmpeg to extract audio
|
||||
- detect gpus and use them
|
||||
|
152
src/models.rs
152
src/models.rs
@@ -270,40 +270,107 @@ fn fetch_all_models(client: &Client) -> Result<Vec<ModelEntry>> {
|
||||
}
|
||||
|
||||
|
||||
fn prompt_select_models(models: &[ModelEntry]) -> Result<Vec<ModelEntry>> {
|
||||
// Build a flat list but show group headers; indices count only models
|
||||
println!("Available ggml Whisper models:");
|
||||
fn format_model_list(models: &[ModelEntry]) -> String {
|
||||
let mut out = String::new();
|
||||
out.push_str("Available ggml Whisper models:\n");
|
||||
|
||||
// Compute alignment widths
|
||||
let idx_width = std::cmp::max(2, models.len().to_string().len());
|
||||
let name_width = models.iter().map(|m| m.name.len()).max().unwrap_or(0);
|
||||
|
||||
let mut idx = 1usize;
|
||||
let mut current = "".to_string();
|
||||
// We'll record mapping from index -> position in models
|
||||
let mut index_map: Vec<usize> = Vec::with_capacity(models.len());
|
||||
for (pos, m) in models.iter().enumerate() {
|
||||
let mut current = String::new();
|
||||
for m in models.iter() {
|
||||
if m.base != current {
|
||||
current = m.base.clone();
|
||||
println!("\n{}:", current);
|
||||
out.push_str("\n");
|
||||
out.push_str(&format!("{}:\n", current));
|
||||
}
|
||||
let short_hash = m
|
||||
.sha256
|
||||
.as_ref()
|
||||
.map(|h| h.chars().take(8).collect::<String>())
|
||||
.unwrap_or_else(|| "-".to_string());
|
||||
println!(" {}) {} [{} | {} | {}]", idx, m.name, m.repo, human_size(m.size), short_hash);
|
||||
// Format without hash and with aligned columns
|
||||
out.push_str(&format!(
|
||||
" {i:>iw$}) {name:<nw$} [{repo} | {size}]\n",
|
||||
i = idx,
|
||||
iw = idx_width,
|
||||
name = m.name,
|
||||
nw = name_width,
|
||||
repo = m.repo,
|
||||
size = human_size(m.size),
|
||||
));
|
||||
idx += 1;
|
||||
}
|
||||
out.push_str("\nEnter selection by indices (e.g., 1 3 5-7), or 'all', '*' for all, 'q' to cancel.\n");
|
||||
out
|
||||
}
|
||||
|
||||
fn prompt_select_models_two_stage(models: &[ModelEntry]) -> Result<Vec<ModelEntry>> {
|
||||
// 1) Choose base (tiny, small, medium, etc.)
|
||||
let mut bases: Vec<String> = Vec::new();
|
||||
let mut last = String::new();
|
||||
for m in models.iter() {
|
||||
if m.base != last {
|
||||
// models are sorted by base; avoid duplicates while preserving order
|
||||
if bases.last().map(|b| b == &m.base).unwrap_or(false) == false {
|
||||
bases.push(m.base.clone());
|
||||
}
|
||||
last = m.base.clone();
|
||||
}
|
||||
}
|
||||
if bases.is_empty() {
|
||||
return Ok(Vec::new());
|
||||
}
|
||||
|
||||
// Print base selection on stderr
|
||||
eprintln!("Available base model families:");
|
||||
for (i, b) in bases.iter().enumerate() {
|
||||
eprintln!(" {}) {}", i + 1, b);
|
||||
}
|
||||
loop {
|
||||
eprint!("Select base (number or name, 'q' to cancel): ");
|
||||
io::stderr().flush().ok();
|
||||
let mut line = String::new();
|
||||
io::stdin().read_line(&mut line).context("Failed to read base selection")?;
|
||||
let s = line.trim();
|
||||
if s.eq_ignore_ascii_case("q") || s.eq_ignore_ascii_case("quit") || s.eq_ignore_ascii_case("exit") {
|
||||
return Ok(Vec::new());
|
||||
}
|
||||
let chosen_base = if let Ok(i) = s.parse::<usize>() {
|
||||
if i >= 1 && i <= bases.len() { Some(bases[i - 1].clone()) } else { None }
|
||||
} else if !s.is_empty() {
|
||||
// accept exact name match (case-insensitive)
|
||||
bases.iter().find(|b| b.eq_ignore_ascii_case(s)).cloned()
|
||||
} else { None };
|
||||
|
||||
if let Some(base) = chosen_base {
|
||||
// 2) Choose sub-type(s) within that base
|
||||
let filtered: Vec<ModelEntry> = models.iter().filter(|m| m.base == base).cloned().collect();
|
||||
if filtered.is_empty() {
|
||||
eprintln!("No models found for base '{}'.", base);
|
||||
continue;
|
||||
}
|
||||
// Reuse the formatter but only for the chosen base list
|
||||
let listing = format_model_list(&filtered);
|
||||
eprint!("{}", listing);
|
||||
|
||||
// Build index map for filtered list
|
||||
let mut index_map: Vec<usize> = Vec::with_capacity(filtered.len());
|
||||
let mut idx = 1usize;
|
||||
for (pos, _m) in filtered.iter().enumerate() {
|
||||
index_map.push(pos);
|
||||
idx += 1;
|
||||
}
|
||||
println!("\nEnter selection by indices (e.g., 1 3 5-7), or 'all', '*' for all, 'q' to cancel.");
|
||||
// Second prompt: sub-type selection
|
||||
loop {
|
||||
eprint!("Selection: ");
|
||||
io::stderr().flush().ok();
|
||||
let mut line = String::new();
|
||||
io::stdin().read_line(&mut line).context("Failed to read selection")?;
|
||||
let s = line.trim().to_lowercase();
|
||||
if s == "q" || s == "quit" || s == "exit" { return Ok(Vec::new()); }
|
||||
let mut line2 = String::new();
|
||||
io::stdin().read_line(&mut line2).context("Failed to read selection")?;
|
||||
let s2 = line2.trim().to_lowercase();
|
||||
if s2 == "q" || s2 == "quit" || s2 == "exit" { return Ok(Vec::new()); }
|
||||
let mut selected: Vec<usize> = Vec::new();
|
||||
if s == "all" || s == "*" {
|
||||
if s2 == "all" || s2 == "*" {
|
||||
selected = (1..idx).collect();
|
||||
} else if !s.is_empty() {
|
||||
for part in s.split(|c| c == ',' || c == ' ' || c == ';') {
|
||||
} else if !s2.is_empty() {
|
||||
for part in s2.split(|c| c == ',' || c == ' ' || c == ';') {
|
||||
let part = part.trim();
|
||||
if part.is_empty() { continue; }
|
||||
if let Some((a, b)) = part.split_once('-') {
|
||||
@@ -321,9 +388,14 @@ fn prompt_select_models(models: &[ModelEntry]) -> Result<Vec<ModelEntry>> {
|
||||
eprintln!("No valid selection. Please try again or 'q' to cancel.");
|
||||
continue;
|
||||
}
|
||||
let chosen: Vec<ModelEntry> = selected.into_iter().map(|i| models[index_map[i - 1]].clone()).collect();
|
||||
let chosen: Vec<ModelEntry> = selected.into_iter().map(|i| filtered[index_map[i - 1]].clone()).collect();
|
||||
return Ok(chosen);
|
||||
}
|
||||
} else {
|
||||
eprintln!("Invalid base selection. Please enter a number from 1-{} or a base name.", bases.len());
|
||||
continue;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn compute_file_sha256_hex(path: &Path) -> Result<String> {
|
||||
@@ -385,7 +457,7 @@ pub fn run_interactive_model_downloader() -> Result<()> {
|
||||
eprintln!("No models found on Hugging Face listing. Please try again later.");
|
||||
return Ok(());
|
||||
}
|
||||
let selected = prompt_select_models(&models)?;
|
||||
let selected = prompt_select_models_two_stage(&models)?;
|
||||
if selected.is_empty() {
|
||||
eprintln!("No selection. Aborting download.");
|
||||
return Ok(());
|
||||
@@ -631,6 +703,38 @@ mod tests {
|
||||
use std::fs;
|
||||
use std::io::Write;
|
||||
|
||||
#[test]
|
||||
fn test_format_model_list_spacing_and_structure() {
|
||||
let models = vec![
|
||||
ModelEntry { name: "tiny.en-q5_1".to_string(), base: "tiny".to_string(), subtype: "en-q5_1".to_string(), size: 1024*1024, sha256: Some("0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdef".to_string()), repo: "ggerganov/whisper.cpp".to_string() },
|
||||
ModelEntry { name: "tiny-q5_1".to_string(), base: "tiny".to_string(), subtype: "q5_1".to_string(), size: 2048, sha256: None, repo: "ggerganov/whisper.cpp".to_string() },
|
||||
ModelEntry { name: "base.en-q5_1".to_string(), base: "base".to_string(), subtype: "en-q5_1".to_string(), size: 10, sha256: Some("aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa".to_string()), repo: "akashmjn/tinydiarize-whisper.cpp".to_string() },
|
||||
];
|
||||
let s = format_model_list(&models);
|
||||
// Header present
|
||||
assert!(s.starts_with("Available ggml Whisper models:\n"));
|
||||
// Group headers and blank line before header
|
||||
assert!(s.contains("\ntiny:\n"));
|
||||
assert!(s.contains("\nbase:\n"));
|
||||
// No immediate double space before a bracket after parenthesis
|
||||
assert!(!s.contains(") ["), "should not have double space immediately before bracket");
|
||||
// Lines contain normalized spacing around pipes and no hash
|
||||
assert!(s.contains("[ggerganov/whisper.cpp | 1.00 MiB]"));
|
||||
assert!(s.contains("[ggerganov/whisper.cpp | 2.00 KiB]"));
|
||||
// Verify alignment: the '[' position should match across multiple lines
|
||||
let bracket_positions: Vec<usize> = s
|
||||
.lines()
|
||||
.filter(|l| l.contains("ggerganov/whisper.cpp"))
|
||||
.map(|l| l.find('[').unwrap())
|
||||
.collect();
|
||||
assert!(bracket_positions.len() >= 2);
|
||||
for w in bracket_positions.windows(2) {
|
||||
assert_eq!(w[0], w[1], "bracket columns should align");
|
||||
}
|
||||
// Footer instruction present
|
||||
assert!(s.contains("Enter selection by indices"));
|
||||
}
|
||||
|
||||
fn sha256_hex(data: &[u8]) -> String {
|
||||
use sha2::{Digest, Sha256};
|
||||
let mut hasher = Sha256::new();
|
||||
|
Reference in New Issue
Block a user