diff --git a/TODO.md b/TODO.md index 629196f..9ecadd4 100644 --- a/TODO.md +++ b/TODO.md @@ -8,7 +8,7 @@ - [x] for merging (command line flag) -> if not present, treat each file as separate output (--merge | -m) - [x] for merge + separate output -> if present, treat each file as separate output and also output a merged version (--merge-and-separate) - [x] set speaker-names per input-file -> prompt user for each file if flag is set (--set-speaker-names) -- fix cli output for model display +- [x] fix cli output for model display - refactor into proper cli app - add support for video files -> use ffmpeg to extract audio - detect gpus and use them diff --git a/src/models.rs b/src/models.rs index 47b8081..29fb87c 100644 --- a/src/models.rs +++ b/src/models.rs @@ -270,59 +270,131 @@ fn fetch_all_models(client: &Client) -> Result> { } -fn prompt_select_models(models: &[ModelEntry]) -> Result> { - // Build a flat list but show group headers; indices count only models - println!("Available ggml Whisper models:"); +fn format_model_list(models: &[ModelEntry]) -> String { + let mut out = String::new(); + out.push_str("Available ggml Whisper models:\n"); + + // Compute alignment widths + let idx_width = std::cmp::max(2, models.len().to_string().len()); + let name_width = models.iter().map(|m| m.name.len()).max().unwrap_or(0); + let mut idx = 1usize; - let mut current = "".to_string(); - // We'll record mapping from index -> position in models - let mut index_map: Vec = Vec::with_capacity(models.len()); - for (pos, m) in models.iter().enumerate() { + let mut current = String::new(); + for m in models.iter() { if m.base != current { current = m.base.clone(); - println!("\n{}:", current); + out.push_str("\n"); + out.push_str(&format!("{}:\n", current)); } - let short_hash = m - .sha256 - .as_ref() - .map(|h| h.chars().take(8).collect::()) - .unwrap_or_else(|| "-".to_string()); - println!(" {}) {} [{} | {} | {}]", idx, m.name, m.repo, human_size(m.size), short_hash); - index_map.push(pos); + // Format without hash and with aligned columns + out.push_str(&format!( + " {i:>iw$}) {name: Result> { + // 1) Choose base (tiny, small, medium, etc.) + let mut bases: Vec = Vec::new(); + let mut last = String::new(); + for m in models.iter() { + if m.base != last { + // models are sorted by base; avoid duplicates while preserving order + if bases.last().map(|b| b == &m.base).unwrap_or(false) == false { + bases.push(m.base.clone()); + } + last = m.base.clone(); + } + } + if bases.is_empty() { + return Ok(Vec::new()); + } + + // Print base selection on stderr + eprintln!("Available base model families:"); + for (i, b) in bases.iter().enumerate() { + eprintln!(" {}) {}", i + 1, b); + } loop { - eprint!("Selection: "); + eprint!("Select base (number or name, 'q' to cancel): "); io::stderr().flush().ok(); let mut line = String::new(); - io::stdin().read_line(&mut line).context("Failed to read selection")?; - let s = line.trim().to_lowercase(); - if s == "q" || s == "quit" || s == "exit" { return Ok(Vec::new()); } - let mut selected: Vec = Vec::new(); - if s == "all" || s == "*" { - selected = (1..idx).collect(); - } else if !s.is_empty() { - for part in s.split(|c| c == ',' || c == ' ' || c == ';') { - let part = part.trim(); - if part.is_empty() { continue; } - if let Some((a, b)) = part.split_once('-') { - if let (Ok(ia), Ok(ib)) = (a.parse::(), b.parse::()) { - if ia >= 1 && ib < idx && ia <= ib { selected.extend(ia..=ib); } - } - } else if let Ok(i) = part.parse::() { - if i >= 1 && i < idx { selected.push(i); } - } - } + io::stdin().read_line(&mut line).context("Failed to read base selection")?; + let s = line.trim(); + if s.eq_ignore_ascii_case("q") || s.eq_ignore_ascii_case("quit") || s.eq_ignore_ascii_case("exit") { + return Ok(Vec::new()); } - selected.sort_unstable(); - selected.dedup(); - if selected.is_empty() { - eprintln!("No valid selection. Please try again or 'q' to cancel."); + let chosen_base = if let Ok(i) = s.parse::() { + if i >= 1 && i <= bases.len() { Some(bases[i - 1].clone()) } else { None } + } else if !s.is_empty() { + // accept exact name match (case-insensitive) + bases.iter().find(|b| b.eq_ignore_ascii_case(s)).cloned() + } else { None }; + + if let Some(base) = chosen_base { + // 2) Choose sub-type(s) within that base + let filtered: Vec = models.iter().filter(|m| m.base == base).cloned().collect(); + if filtered.is_empty() { + eprintln!("No models found for base '{}'.", base); + continue; + } + // Reuse the formatter but only for the chosen base list + let listing = format_model_list(&filtered); + eprint!("{}", listing); + + // Build index map for filtered list + let mut index_map: Vec = Vec::with_capacity(filtered.len()); + let mut idx = 1usize; + for (pos, _m) in filtered.iter().enumerate() { + index_map.push(pos); + idx += 1; + } + // Second prompt: sub-type selection + loop { + eprint!("Selection: "); + io::stderr().flush().ok(); + let mut line2 = String::new(); + io::stdin().read_line(&mut line2).context("Failed to read selection")?; + let s2 = line2.trim().to_lowercase(); + if s2 == "q" || s2 == "quit" || s2 == "exit" { return Ok(Vec::new()); } + let mut selected: Vec = Vec::new(); + if s2 == "all" || s2 == "*" { + selected = (1..idx).collect(); + } else if !s2.is_empty() { + for part in s2.split(|c| c == ',' || c == ' ' || c == ';') { + let part = part.trim(); + if part.is_empty() { continue; } + if let Some((a, b)) = part.split_once('-') { + if let (Ok(ia), Ok(ib)) = (a.parse::(), b.parse::()) { + if ia >= 1 && ib < idx && ia <= ib { selected.extend(ia..=ib); } + } + } else if let Ok(i) = part.parse::() { + if i >= 1 && i < idx { selected.push(i); } + } + } + } + selected.sort_unstable(); + selected.dedup(); + if selected.is_empty() { + eprintln!("No valid selection. Please try again or 'q' to cancel."); + continue; + } + let chosen: Vec = selected.into_iter().map(|i| filtered[index_map[i - 1]].clone()).collect(); + return Ok(chosen); + } + } else { + eprintln!("Invalid base selection. Please enter a number from 1-{} or a base name.", bases.len()); continue; } - let chosen: Vec = selected.into_iter().map(|i| models[index_map[i - 1]].clone()).collect(); - return Ok(chosen); } } @@ -385,7 +457,7 @@ pub fn run_interactive_model_downloader() -> Result<()> { eprintln!("No models found on Hugging Face listing. Please try again later."); return Ok(()); } - let selected = prompt_select_models(&models)?; + let selected = prompt_select_models_two_stage(&models)?; if selected.is_empty() { eprintln!("No selection. Aborting download."); return Ok(()); @@ -631,6 +703,38 @@ mod tests { use std::fs; use std::io::Write; + #[test] + fn test_format_model_list_spacing_and_structure() { + let models = vec![ + ModelEntry { name: "tiny.en-q5_1".to_string(), base: "tiny".to_string(), subtype: "en-q5_1".to_string(), size: 1024*1024, sha256: Some("0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdef".to_string()), repo: "ggerganov/whisper.cpp".to_string() }, + ModelEntry { name: "tiny-q5_1".to_string(), base: "tiny".to_string(), subtype: "q5_1".to_string(), size: 2048, sha256: None, repo: "ggerganov/whisper.cpp".to_string() }, + ModelEntry { name: "base.en-q5_1".to_string(), base: "base".to_string(), subtype: "en-q5_1".to_string(), size: 10, sha256: Some("aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa".to_string()), repo: "akashmjn/tinydiarize-whisper.cpp".to_string() }, + ]; + let s = format_model_list(&models); + // Header present + assert!(s.starts_with("Available ggml Whisper models:\n")); + // Group headers and blank line before header + assert!(s.contains("\ntiny:\n")); + assert!(s.contains("\nbase:\n")); + // No immediate double space before a bracket after parenthesis + assert!(!s.contains(") ["), "should not have double space immediately before bracket"); + // Lines contain normalized spacing around pipes and no hash + assert!(s.contains("[ggerganov/whisper.cpp | 1.00 MiB]")); + assert!(s.contains("[ggerganov/whisper.cpp | 2.00 KiB]")); + // Verify alignment: the '[' position should match across multiple lines + let bracket_positions: Vec = s + .lines() + .filter(|l| l.contains("ggerganov/whisper.cpp")) + .map(|l| l.find('[').unwrap()) + .collect(); + assert!(bracket_positions.len() >= 2); + for w in bracket_positions.windows(2) { + assert_eq!(w[0], w[1], "bracket columns should align"); + } + // Footer instruction present + assert!(s.contains("Enter selection by indices")); + } + fn sha256_hex(data: &[u8]) -> String { use sha2::{Digest, Sha256}; let mut hasher = Sha256::new();