diff --git a/Cargo.lock b/Cargo.lock
index d3fdbdf..befb354 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -248,6 +248,15 @@ dependencies = [
  "strsim",
 ]
 
+[[package]]
+name = "clap_complete"
+version = "4.5.56"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "67e4efcbb5da11a92e8a609233aa1e8a7d91e38de0be865f016d14700d45a7fd"
+dependencies = [
+ "clap",
+]
+
 [[package]]
 name = "clap_derive"
 version = "4.5.41"
@@ -266,6 +275,16 @@ version = "0.7.5"
 source = "registry+https://github.com/rust-lang/crates.io-index"
 checksum = "b94f61472cee1439c0b966b47e3aca9ae07e45d070759512cd390ea2bebc6675"
 
+[[package]]
+name = "clap_mangen"
+version = "0.2.29"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "27b4c3c54b30f0d9adcb47f25f61fcce35c4dd8916638c6b82fbd5f4fb4179e2"
+dependencies = [
+ "clap",
+ "roff",
+]
+
 [[package]]
 name = "cmake"
 version = "0.1.54"
@@ -1057,6 +1076,8 @@ dependencies = [
  "anyhow",
  "chrono",
  "clap",
+ "clap_complete",
+ "clap_mangen",
  "reqwest",
  "serde",
  "serde_json",
@@ -1194,6 +1215,12 @@ dependencies = [
  "windows-sys 0.52.0",
 ]
 
+[[package]]
+name = "roff"
+version = "0.2.2"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "88f8660c1ff60292143c98d08fc6e2f654d722db50410e3f3797d40baaf9d8f3"
+
 [[package]]
 name = "rustc-demangle"
 version = "0.1.26"
diff --git a/Cargo.toml b/Cargo.toml
index 4f1c8ab..5fab8c0 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -6,6 +6,8 @@ edition = "2024"
 [dependencies]
 anyhow = "1.0.98"
 clap = { version = "4.5.43", features = ["derive"] }
+clap_complete = "4.5.28"
+clap_mangen = "0.2"
 serde = { version = "1.0.219", features = ["derive"] }
 serde_json = "1.0.142"
 toml = "0.8"
diff --git a/TODO.md b/TODO.md
index 9ecadd4..a9beaa6 100644
--- a/TODO.md
+++ b/TODO.md
@@ -9,7 +9,7 @@
 - [x] for merge + separate output -> if present, treat each file as separate output and also output a merged version (--merge-and-separate)
 - [x] set speaker-names per input-file -> prompt user for each file if flag is set (--set-speaker-names)
 - [x] fix cli output for model display
-- refactor into proper cli app
+- [x] refactor into proper cli app
 - add support for video files -> use ffmpeg to extract audio
 - detect gpus and use them
 - add error handling
diff --git a/src/main.rs b/src/main.rs
index 96f62f1..8fbb483 100644
--- a/src/main.rs
+++ b/src/main.rs
@@ -5,16 +5,38 @@ use std::process::Command;
 use std::env;
 
 use anyhow::{anyhow, Context, Result};
-use clap::Parser;
+use clap::{Parser, Subcommand};
 use serde::{Deserialize, Serialize};
 use chrono::Local;
-use std::sync::atomic::{AtomicBool, Ordering};
+use std::sync::atomic::{AtomicBool, AtomicU8, Ordering};
+use clap_complete::Shell;
 use whisper_rs::{FullParams, SamplingStrategy, WhisperContext, WhisperContextParameters};
 
 mod models;
 
 static LAST_MODEL_WRITTEN: AtomicBool = AtomicBool::new(false);
+static VERBOSE: AtomicU8 = AtomicU8::new(0);
+
+macro_rules! vlog {
+    ($lvl:expr, $($arg:tt)*) => {
+        let v = VERBOSE.load(Ordering::Relaxed);
+        let needed = match $lvl { 0u8 => true, 1u8 => v >= 1, 2u8 => v >= 2, lvl => v >= lvl };
+        if needed { eprintln!("INFO: {}", format!($($arg)*)); }
+    }
+}
+
+macro_rules! warnlog {
+    ($($arg:tt)*) => {
+        eprintln!("WARN: {}", format!($($arg)*));
+    }
+}
+
errorlog { + ($($arg:tt)*) => { + eprintln!("ERROR: {}", format!($($arg)*)); + } +} fn models_dir_path() -> PathBuf { // Highest priority: explicit override @@ -47,9 +69,30 @@ fn models_dir_path() -> PathBuf { PathBuf::from("models") } + +#[derive(Subcommand, Debug, Clone)] +enum AuxCommands { + /// Generate shell completion script to stdout + Completions { + /// Shell to generate completions for + #[arg(value_enum)] + shell: Shell, + }, + /// Generate a man page to stdout + Man, +} + #[derive(Parser, Debug)] -#[command(name = "PolyScribe", version, about = "Merge multiple JSON transcripts into one or transcribe audio using native whisper")] +#[command(name = "PolyScribe", bin_name = "polyscribe", version, about = "Merge JSON transcripts or transcribe audio using native whisper")] struct Args { + /// Increase verbosity (-v, -vv). Logs go to stderr. + #[arg(short = 'v', long = "verbose", action = clap::ArgAction::Count, global = true)] + verbose: u8, + + /// Optional auxiliary subcommands (completions, man) + #[command(subcommand)] + aux: Option, + /// Input .json transcript files or audio files to merge/transcribe inputs: Vec, @@ -243,7 +286,8 @@ fn find_model_file() -> Result { } if candidates.is_empty() { - eprintln!("No Whisper model files (*.bin) found in {}.", models_dir.display()); + // In quiet mode we still prompt for models; suppress only non-essential logs + warnlog!("No Whisper model files (*.bin) found in {}.", models_dir.display()); eprint!("Would you like to download models now? [Y/n]: "); io::stderr().flush().ok(); let mut input = String::new(); @@ -251,7 +295,7 @@ fn find_model_file() -> Result { let ans = input.trim().to_lowercase(); if ans.is_empty() || ans == "y" || ans == "yes" { if let Err(e) = models::run_interactive_model_downloader() { - eprintln!("Downloader failed: {:#}", e); + errorlog!("Downloader failed: {:#}", e); } // Re-scan candidates.clear(); @@ -292,7 +336,7 @@ fn find_model_file() -> Result { if p.is_file() { // Also ensure it's one of the candidates (same dir) if candidates.iter().any(|c| c == &p) { - eprintln!("Using previously selected model: {}", p.display()); + vlog!(0, "Using previously selected model: {}", p.display()); return Ok(p); } } @@ -419,8 +463,34 @@ impl Drop for LastModelCleanup { } } + fn main() -> Result<()> { - let args = Args::parse(); + // Parse CLI + let mut args = Args::parse(); + + // Initialize verbosity + VERBOSE.store(args.verbose, Ordering::Relaxed); + + // Handle auxiliary subcommands that write to stdout and exit early + if let Some(aux) = &args.aux { + use clap::CommandFactory; + match aux { + AuxCommands::Completions { shell } => { + let mut cmd = Args::command(); + let bin_name = cmd.get_name().to_string(); + clap_complete::generate(*shell, &mut cmd, bin_name, &mut io::stdout()); + return Ok(()) + } + AuxCommands::Man => { + let cmd = Args::command(); + let man = clap_mangen::Man::new(cmd); + let mut out = Vec::new(); + man.render(&mut out)?; + io::stdout().write_all(&out)?; + return Ok(()) + } + } + } // Defer cleanup of .last_model until program exit let models_dir_buf = models_dir_path(); @@ -431,7 +501,7 @@ fn main() -> Result<()> { // If requested, run the interactive model downloader first. If no inputs were provided, exit after downloading. 
     if args.download_models {
         if let Err(e) = models::run_interactive_model_downloader() {
-            eprintln!("Model downloader failed: {:#}", e);
+            errorlog!("Model downloader failed: {:#}", e);
         }
         if args.inputs.is_empty() {
             return Ok(());
@@ -441,7 +511,7 @@ fn main() -> Result<()> {
     // If requested, update local models and exit unless inputs provided to continue
     if args.update_models {
         if let Err(e) = models::update_local_models() {
-            eprintln!("Model update failed: {:#}", e);
+            errorlog!("Model update failed: {:#}", e);
             return Err(e);
         }
         // if only updating models and no inputs, exit
@@ -451,6 +521,7 @@ fn main() -> Result<()> {
     }
 
     // Determine inputs and optional output path
+    vlog!(1, "Parsed {} input(s)", args.inputs.len());
     let mut inputs = args.inputs;
     let mut output_path = args.output;
     if output_path.is_none() && inputs.len() >= 2 {
@@ -477,6 +548,7 @@ fn main() -> Result<()> {
     }
 
     if args.merge_and_separate {
+        vlog!(1, "Mode: merge-and-separate; output_dir={:?}", output_path);
         // Combined mode: write separate outputs per input and also a merged output set
         // Require an output directory
         let out_dir = match output_path.as_ref() {
@@ -579,6 +651,7 @@ fn main() -> Result<()> {
             .with_context(|| format!("Failed to create output file: {}", m_srt.display()))?;
         ms.write_all(m_srt_str.as_bytes())?;
     } else if args.merge {
+        vlog!(1, "Mode: merge; output_base={:?}", output_path);
         // MERGED MODE (previous default)
         let mut entries: Vec = Vec::new();
         for input_path in &inputs {
@@ -668,6 +741,7 @@ fn main() -> Result<()> {
             serde_json::to_writer_pretty(&mut handle, &out)?;
             writeln!(&mut handle)?;
         }
     } else {
+        vlog!(1, "Mode: separate; output_dir={:?}", output_path);
         // SEPARATE MODE (default now)
         // If writing to stdout with multiple inputs, not supported
         if output_path.is_none() && inputs.len() > 1 {
diff --git a/src/models.rs b/src/models.rs
index 29fb87c..a303c68 100644
--- a/src/models.rs
+++ b/src/models.rs
@@ -11,6 +11,13 @@ use reqwest::blocking::Client;
 use reqwest::redirect::Policy;
 use sha2::{Digest, Sha256};
 
+// Print progress/log messages to stderr (always prints for now; quiet-mode gating can hook in here later)
qlog { + ($($arg:tt)*) => { + eprintln!($($arg)*); + }; +} + // --- Model downloader: list & download ggml models from Hugging Face --- #[derive(Debug, Deserialize)] @@ -170,7 +177,7 @@ fn fill_meta_via_head(repo: &str, name: &str) -> (Option, Option) { } fn hf_fetch_repo_models(client: &Client, repo: &'static str) -> Result> { - eprintln!("Fetching online data: listing models from {}...", repo); + qlog!("Fetching online data: listing models from {}...", repo); // Prefer the tree endpoint for reliable size/hash metadata, then fall back to model metadata let tree_url = format!("https://huggingface.co/api/models/{}/tree/main?recursive=1", repo); let mut out: Vec = Vec::new(); @@ -220,7 +227,7 @@ fn hf_fetch_repo_models(client: &Client, repo: &'static str) -> Result Result Result> { - eprintln!("Fetching online data: aggregating available models from Hugging Face..."); + qlog!("Fetching online data: aggregating available models from Hugging Face..."); let mut v1 = hf_fetch_repo_models(client, "ggerganov/whisper.cpp")?; // main repo must succeed // Optional tinydiarize repo; ignore errors but log to stderr let mut v2: Vec = match hf_fetch_repo_models(client, "akashmjn/tinydiarize-whisper.cpp") { Ok(v) => v, Err(e) => { - eprintln!("Warning: failed to fetch optional repo akashmjn/tinydiarize-whisper.cpp: {:#}", e); + qlog!("Warning: failed to fetch optional repo akashmjn/tinydiarize-whisper.cpp: {:#}", e); Vec::new() } }; @@ -451,19 +458,19 @@ pub fn run_interactive_model_downloader() -> Result<()> { .build() .context("Failed to build HTTP client")?; - eprintln!("Fetching online data: contacting Hugging Face to retrieve available models (this may take a moment)..."); + qlog!("Fetching online data: contacting Hugging Face to retrieve available models (this may take a moment)..."); let models = fetch_all_models(&client)?; if models.is_empty() { - eprintln!("No models found on Hugging Face listing. Please try again later."); + qlog!("No models found on Hugging Face listing. Please try again later."); return Ok(()); } let selected = prompt_select_models_two_stage(&models)?; if selected.is_empty() { - eprintln!("No selection. Aborting download."); + qlog!("No selection. Aborting download."); return Ok(()); } for m in selected { - if let Err(e) = download_one_model(&client, models_dir, &m) { eprintln!("Error: {:#}", e); } + if let Err(e) = download_one_model(&client, models_dir, &m) { qlog!("Error: {:#}", e); } } Ok(()) } @@ -477,10 +484,10 @@ pub fn download_one_model(client: &Client, models_dir: &Path, entry: &ModelEntry match compute_file_sha256_hex(&final_path) { Ok(local_hash) => { if local_hash.eq_ignore_ascii_case(expected) { - eprintln!("Model {} is up-to-date (hash match).", final_path.display()); + qlog!("Model {} is up-to-date (hash match).", final_path.display()); return Ok(()); } else { - eprintln!( + qlog!( "Local model {} hash differs from online (local {}.., online {}..). Updating...", final_path.display(), &local_hash[..std::cmp::min(8, local_hash.len())], @@ -489,37 +496,37 @@ pub fn download_one_model(client: &Client, models_dir: &Path, entry: &ModelEntry } } Err(e) => { - eprintln!( - "Warning: failed to hash existing {}: {}. Will re-download to ensure correctness.", - final_path.display(), e - ); + qlog!( + "Warning: failed to hash existing {}: {}. 
Will re-download to ensure correctness.", + final_path.display(), e + ); } } } else if entry.size > 0 { match std::fs::metadata(&final_path) { Ok(md) => { if md.len() == entry.size { - eprintln!( + qlog!( "Model {} appears up-to-date by size ({}).", final_path.display(), entry.size ); return Ok(()); } else { - eprintln!( + qlog!( "Local model {} size ({}) differs from online ({}). Updating...", final_path.display(), md.len(), entry.size ); } } Err(e) => { - eprintln!( + qlog!( "Warning: failed to stat existing {}: {}. Will re-download to ensure correctness.", final_path.display(), e ); } } } else { - eprintln!( + qlog!( "Model {} exists but remote hash/size not available; will download to verify contents.", final_path.display() ); @@ -531,7 +538,7 @@ pub fn download_one_model(client: &Client, models_dir: &Path, entry: &ModelEntry if let Ok(base_dir) = env::var("POLYSCRIBE_MODELS_BASE_COPY_DIR") { let src_path = std::path::Path::new(&base_dir).join(format!("ggml-{}.bin", entry.name)); if src_path.exists() { - eprintln!("Copying {} from {}...", entry.name, src_path.display()); + qlog!("Copying {} from {}...", entry.name, src_path.display()); let tmp_path = models_dir.join(format!("ggml-{}.bin.part", entry.name)); if tmp_path.exists() { let _ = std::fs::remove_file(&tmp_path); } std::fs::copy(&src_path, &tmp_path) @@ -551,13 +558,13 @@ pub fn download_one_model(client: &Client, models_dir: &Path, entry: &ModelEntry if final_path.exists() { let _ = std::fs::remove_file(&final_path); } std::fs::rename(&tmp_path, &final_path) .with_context(|| format!("Failed to move into place: {}", final_path.display()))?; - eprintln!("Saved: {}", final_path.display()); + qlog!("Saved: {}", final_path.display()); return Ok(()); } } let url = format!("https://huggingface.co/{}/resolve/main/ggml-{}.bin", entry.repo, entry.name); - eprintln!("Downloading {} ({} | {})...", entry.name, human_size(entry.size), url); + qlog!("Downloading {} ({} | {})...", entry.name, human_size(entry.size), url); let mut resp = client .get(url) .send() @@ -593,7 +600,7 @@ pub fn download_one_model(client: &Client, models_dir: &Path, entry: &ModelEntry )); } } else { - eprintln!("Warning: no SHA-256 available for {}. Skipping verification.", entry.name); + qlog!("Warning: no SHA-256 available for {}. Skipping verification.", entry.name); } // Replace existing file safely if final_path.exists() { @@ -601,7 +608,7 @@ pub fn download_one_model(client: &Client, models_dir: &Path, entry: &ModelEntry } std::fs::rename(&tmp_path, &final_path) .with_context(|| format!("Failed to move into place: {}", final_path.display()))?; - eprintln!("Saved: {}", final_path.display()); + qlog!("Saved: {}", final_path.display()); Ok(()) } @@ -653,42 +660,42 @@ pub fn update_local_models() -> Result<()> { match compute_file_sha256_hex(&path) { Ok(local_hash) => { if local_hash.eq_ignore_ascii_case(expected) { - eprintln!("{} is up-to-date.", fname); + qlog!("{} is up-to-date.", fname); continue; } else { - eprintln!( - "{} hash differs (local {}.. != remote {}..). Updating...", - fname, - &local_hash[..std::cmp::min(8, local_hash.len())], - &expected[..std::cmp::min(8, expected.len())] - ); + qlog!( + "{} hash differs (local {}.. != remote {}..). Updating...", + fname, + &local_hash[..std::cmp::min(8, local_hash.len())], + &expected[..std::cmp::min(8, expected.len())] + ); } } Err(e) => { - eprintln!("Warning: failed hashing {}: {}. Re-downloading.", fname, e); + qlog!("Warning: failed hashing {}: {}. 
Re-downloading.", fname, e); } } download_one_model(&client, models_dir, remote)?; } else if remote.size > 0 { match std::fs::metadata(&path) { Ok(md) if md.len() == remote.size => { - eprintln!("{} appears up-to-date by size ({}).", fname, remote.size); + qlog!("{} appears up-to-date by size ({}).", fname, remote.size); continue; } Ok(md) => { - eprintln!("{} size {} differs from remote {}. Updating...", fname, md.len(), remote.size); + qlog!("{} size {} differs from remote {}. Updating...", fname, md.len(), remote.size); download_one_model(&client, models_dir, remote)?; } Err(e) => { - eprintln!("Warning: stat failed for {}: {}. Updating...", fname, e); + qlog!("Warning: stat failed for {}: {}. Updating...", fname, e); download_one_model(&client, models_dir, remote)?; } } } else { - eprintln!("No remote hash/size for {}. Skipping.", fname); + qlog!("No remote hash/size for {}. Skipping.", fname); } } else { - eprintln!("No remote metadata for {}. Skipping.", fname); + qlog!("No remote metadata for {}. Skipping.", fname); } } diff --git a/tests/integration_aux.rs b/tests/integration_aux.rs new file mode 100644 index 0000000..1703a62 --- /dev/null +++ b/tests/integration_aux.rs @@ -0,0 +1,45 @@ +use std::process::Command; + +fn bin() -> &'static str { env!("CARGO_BIN_EXE_polyscribe") } + +#[test] +fn aux_completions_bash_outputs_script() { + let out = Command::new(bin()) + .arg("completions") + .arg("bash") + .output() + .expect("failed to run polyscribe completions bash"); + assert!(out.status.success(), "completions bash exited with failure: {:?}", out.status); + let stdout = String::from_utf8(out.stdout).expect("stdout not utf-8"); + assert!(!stdout.trim().is_empty(), "completions bash stdout is empty"); + // Heuristic: bash completion scripts often contain 'complete -F' lines + assert!(stdout.contains("complete") || stdout.contains("_polyscribe"), "bash completion script did not contain expected markers"); +} + +#[test] +fn aux_completions_zsh_outputs_script() { + let out = Command::new(bin()) + .arg("completions") + .arg("zsh") + .output() + .expect("failed to run polyscribe completions zsh"); + assert!(out.status.success(), "completions zsh exited with failure: {:?}", out.status); + let stdout = String::from_utf8(out.stdout).expect("stdout not utf-8"); + assert!(!stdout.trim().is_empty(), "completions zsh stdout is empty"); + // Heuristic: zsh completion scripts often start with '#compdef' + assert!(stdout.contains("#compdef") || stdout.contains("#compdef polyscribe"), "zsh completion script did not contain expected markers"); +} + +#[test] +fn aux_man_outputs_roff() { + let out = Command::new(bin()) + .arg("man") + .output() + .expect("failed to run polyscribe man"); + assert!(out.status.success(), "man exited with failure: {:?}", out.status); + let stdout = String::from_utf8(out.stdout).expect("stdout not utf-8"); + assert!(!stdout.trim().is_empty(), "man stdout is empty"); + // clap_mangen typically emits roff with .TH and/or section headers + let looks_like_roff = stdout.contains(".TH ") || stdout.starts_with(".TH") || stdout.contains(".SH NAME") || stdout.contains(".SH SYNOPSIS"); + assert!(looks_like_roff, "man output does not look like a roff manpage; got: {}", &stdout.lines().take(3).collect::>().join(" | ")); +}