From 53a7471b997fa340c7b00634a302d2b5de6b219c Mon Sep 17 00:00:00 2001 From: vikingowl Date: Fri, 8 Aug 2025 13:06:24 +0200 Subject: [PATCH] [feat] add `--set-speaker-names` CLI flag; implement prompt-based speaker name assignment with tests --- TODO.md | 2 +- src/main.rs | 34 ++++++++++++++-- src/models.rs | 4 ++ tests/integration_cli.rs | 86 ++++++++++++++++++++++++++++++++++++++++ 4 files changed, 122 insertions(+), 4 deletions(-) diff --git a/TODO.md b/TODO.md index eda615d..629196f 100644 --- a/TODO.md +++ b/TODO.md @@ -7,7 +7,7 @@ - [x] create missing folders for output files - [x] for merging (command line flag) -> if not present, treat each file as separate output (--merge | -m) - [x] for merge + separate output -> if present, treat each file as separate output and also output a merged version (--merge-and-separate) -- set speaker-names per input-file -> prompt user for each file if flag is set (--set-speaker-names) +- [x] set speaker-names per input-file -> prompt user for each file if flag is set (--set-speaker-names) - fix cli output for model display - refactor into proper cli app - add support for video files -> use ffmpeg to extract audio diff --git a/src/main.rs b/src/main.rs index 7f234ca..96f62f1 100644 --- a/src/main.rs +++ b/src/main.rs @@ -76,6 +76,10 @@ struct Args { /// Update local Whisper models by comparing hashes/sizes with remote manifest #[arg(long)] update_models: bool, + + /// Prompt for speaker names per input file + #[arg(long = "set-speaker-names")] + set_speaker_names: bool, } #[derive(Debug, Deserialize)] @@ -145,6 +149,27 @@ fn sanitize_speaker_name(raw: &str) -> String { raw.to_string() } +fn prompt_speaker_name_for_path(path: &Path, default_name: &str, enabled: bool) -> String { + if !enabled { + return default_name.to_string(); + } + let display_owned: String = path + .file_name() + .and_then(|s| s.to_str()) + .map(|s| s.to_string()) + .unwrap_or_else(|| path.to_string_lossy().to_string()); + eprint!("Enter speaker name for {} [default: {}]: ", display_owned, default_name); + io::stderr().flush().ok(); + let mut buf = String::new(); + match io::stdin().read_line(&mut buf) { + Ok(_) => { + let s = buf.trim(); + if s.is_empty() { default_name.to_string() } else { s.to_string() } + } + Err(_) => default_name.to_string(), + } +} + // --- Helpers for audio transcription --- fn is_json_file(path: &Path) -> bool { matches!(path.extension().and_then(|s| s.to_str()).map(|s| s.to_lowercase()), Some(ext) if ext == "json") @@ -467,9 +492,10 @@ fn main() -> Result<()> { for input_path in &inputs { let path = Path::new(input_path); - let speaker = sanitize_speaker_name( + let default_speaker = sanitize_speaker_name( path.file_stem().and_then(|s| s.to_str()).unwrap_or("speaker") ); + let speaker = prompt_speaker_name_for_path(path, &default_speaker, args.set_speaker_names); // Collect entries per file and extend merged let mut entries: Vec = Vec::new(); @@ -557,11 +583,12 @@ fn main() -> Result<()> { let mut entries: Vec = Vec::new(); for input_path in &inputs { let path = Path::new(input_path); - let speaker = sanitize_speaker_name( + let default_speaker = sanitize_speaker_name( path.file_stem() .and_then(|s| s.to_str()) .unwrap_or("speaker") ); + let speaker = prompt_speaker_name_for_path(path, &default_speaker, args.set_speaker_names); let mut buf = String::new(); if is_audio_file(path) { @@ -657,9 +684,10 @@ fn main() -> Result<()> { for input_path in &inputs { let path = Path::new(input_path); - let speaker = sanitize_speaker_name( + let default_speaker = sanitize_speaker_name( path.file_stem().and_then(|s| s.to_str()).unwrap_or("speaker") ); + let speaker = prompt_speaker_name_for_path(path, &default_speaker, args.set_speaker_names); // Collect entries per file let mut entries: Vec = Vec::new(); diff --git a/src/models.rs b/src/models.rs index d38508c..47b8081 100644 --- a/src/models.rs +++ b/src/models.rs @@ -643,6 +643,10 @@ mod tests { #[test] fn test_update_local_models_offline_copy_and_manifest() { + use std::sync::{Mutex, OnceLock}; + static ENV_LOCK: OnceLock> = OnceLock::new(); + let _guard = ENV_LOCK.get_or_init(|| Mutex::new(())).lock().unwrap(); + let tmp_models = tempdir().unwrap(); let tmp_base = tempdir().unwrap(); let tmp_manifest = tempdir().unwrap(); diff --git a/tests/integration_cli.rs b/tests/integration_cli.rs index 962ec25..930ff0b 100644 --- a/tests/integration_cli.rs +++ b/tests/integration_cli.rs @@ -204,3 +204,89 @@ fn cli_merge_and_separate_writes_both_kinds_of_outputs() { // Cleanup let _ = fs::remove_dir_all(&out_dir); } + + +#[test] +fn cli_set_speaker_names_merge_prompts_and_uses_names() { + use std::io::{Read as _, Write as _}; + use std::process::Stdio; + + let exe = env!("CARGO_BIN_EXE_polyscribe"); + + let input1 = manifest_path("input/1-s0wlz.json"); + let input2 = manifest_path("input/2-vikingowl.json"); + + let mut child = Command::new(exe) + .arg(input1.as_os_str()) + .arg(input2.as_os_str()) + .arg("-m") + .arg("--set-speaker-names") + .stdin(Stdio::piped()) + .stdout(Stdio::piped()) + .spawn() + .expect("failed to spawn polyscribe"); + + { + let stdin = child.stdin.as_mut().expect("failed to open stdin"); + // Provide two names for two files + writeln!(stdin, "Alpha").unwrap(); + writeln!(stdin, "Beta").unwrap(); + } + + let output = child.wait_with_output().expect("failed to wait on child"); + assert!(output.status.success(), "CLI did not exit successfully"); + + let stdout = String::from_utf8(output.stdout).expect("stdout not UTF-8"); + let root: OutputRoot = serde_json::from_str(&stdout).unwrap(); + let speakers: std::collections::HashSet = root.items.into_iter().map(|e| e.speaker).collect(); + assert!(speakers.contains("Alpha"), "Alpha not found in speakers"); + assert!(speakers.contains("Beta"), "Beta not found in speakers"); +} + +#[test] +fn cli_set_speaker_names_separate_single_input() { + use std::io::Write as _; + use std::process::Stdio; + + let exe = env!("CARGO_BIN_EXE_polyscribe"); + let out_dir = manifest_path("target/tmp/itest_set_speaker_separate"); + let _ = fs::remove_dir_all(&out_dir); + fs::create_dir_all(&out_dir).unwrap(); + + let input1 = manifest_path("input/3-schmendrizzle.json"); + + let mut child = Command::new(exe) + .arg(input1.as_os_str()) + .arg("--set-speaker-names") + .arg("-o") + .arg(out_dir.as_os_str()) + .stdin(Stdio::piped()) + .stdout(Stdio::null()) + .stderr(Stdio::null()) + .spawn() + .expect("failed to spawn polyscribe"); + + { + let stdin = child.stdin.as_mut().expect("failed to open stdin"); + writeln!(stdin, "ChosenOne").unwrap(); + } + + let status = child.wait().expect("failed to wait on child"); + assert!(status.success(), "CLI did not exit successfully"); + + // Find created JSON + let mut json_paths: Vec = Vec::new(); + for e in fs::read_dir(&out_dir).unwrap() { + let p = e.unwrap().path(); + if let Some(name) = p.file_name().and_then(|s| s.to_str()) { + if name.ends_with(".json") { json_paths.push(p.clone()); } + } + } + assert!(!json_paths.is_empty(), "no JSON outputs created"); + let mut buf = String::new(); + std::fs::File::open(&json_paths[0]).unwrap().read_to_string(&mut buf).unwrap(); + let root: OutputRoot = serde_json::from_str(&buf).unwrap(); + assert!(root.items.iter().all(|e| e.speaker == "ChosenOne")); + + let _ = fs::remove_dir_all(&out_dir); +}