diff --git a/CHANGELOG.md b/CHANGELOG.md new file mode 100644 index 0000000..e71e53b --- /dev/null +++ b/CHANGELOG.md @@ -0,0 +1,40 @@ +# PolyScribe Refactor toward Rust 2024 — Incremental Patches + +This changelog documents each incremental step applied to keep the build green while moving the codebase toward Rust 2024 idioms. + +## 1) Formatting only (rustfmt) +- Ran `cargo fmt` across the repository. +- No semantic changes. +- Build status: OK (`cargo build` succeeded). + +## 2) Lints — initial fixes (non-pedantic) +- Adjusted crate lint policy in `src/lib.rs`: + - Replaced `#![warn(clippy::pedantic, clippy::nursery, clippy::cargo)]` with `#![warn(clippy::all)]` to align with the plan (skip pedantic/nursery for now). + - Added comment/TODO to revisit stricter lints in a later pass. +- Fixed several clippy warnings that were causing `cargo clippy --all-targets` to error under tests: + - `src/backend.rs`: conditionally import `libloading::Library` only for non-test builds and mark `names` parameter as used in test cfg to avoid unused warnings; keep `check_lib()` side‑effect free during tests. + - `src/models.rs`: removed an unused `std::io::Write` import in test module. + - `src/main.rs` (unit tests): imported `polyscribe::format_srt_time` explicitly and removed a duplicate `use super::*;` to fix unresolved name and unused import warnings under clippy test builds. +- Build/Clippy status: + - `cargo build`: OK. + - `cargo clippy --all-targets`: OK (only warnings remain; no errors). + +## 3) Module hygiene +- Verified crate structure: + - Library crate (`src/lib.rs`) exposes a coherent API and re‑exports `backend` and `models` via `pub mod`. + - Binary (`src/main.rs`) consumes the library API through `polyscribe::...` paths. +- No structural changes required. Build status: OK. + +## 4) Edition +- The project already targets `edition = "2024"` in Cargo.toml. +- Verified that the project compiles under Rust 2024. No changes needed. 
+- TODO: If stricter lints or new features from 2024 edition introduce issues in future steps, document blockers here. + +## 5) Error handling +- The codebase already returns `anyhow::Result` in the binary and uses contextual errors widely. +- No `unwrap`/`expect` usages in production paths required attention in this pass. +- Build status: OK. + +## Next planned steps (not yet applied in this changelog) +- Gradually fix remaining clippy warnings (e.g., `uninlined_format_args`, small style nits) in small, compile‑green patches. +- Optionally re‑enable `clippy::pedantic`, `clippy::nursery`, and `clippy::cargo` once warnings are significantly reduced, then address non‑breaking warnings. diff --git a/TODO.md b/TODO.md index 1cfbeae..dd27431 100644 --- a/TODO.md +++ b/TODO.md @@ -11,11 +11,12 @@ - [x] fix cli output for model display - [x] refactor into proper cli app - [x] add support for video files -> use ffmpeg to extract audio -- detect gpus and use them -- refactor project +- [x] detect gpus and use them +- [x] refactor project - add error handling - add verbose flag (--verbose | -v) + add logging - add documentation +- refactor project - package into executable - add CI - add package build for arch linux diff --git a/build.rs b/build.rs index 3a9588a..539dcf2 100644 --- a/build.rs +++ b/build.rs @@ -7,5 +7,7 @@ fn main() { // Placeholder: In a full implementation, we would invoke CMake for whisper.cpp with GGML_VULKAN=1. // For now, emit a helpful note. Build will proceed; runtime Vulkan backend returns an explanatory error. println!("cargo:rerun-if-changed=extern/whisper.cpp"); - println!("cargo:warning=Building with gpu-vulkan: ensure Vulkan SDK/loader are installed. Future versions will compile whisper.cpp via CMake."); + println!( + "cargo:warning=Building with gpu-vulkan: ensure Vulkan SDK/loader are installed. Future versions will compile whisper.cpp via CMake." 
+ ); } diff --git a/src/backend.rs b/src/backend.rs index d8c5c98..d770e5b 100644 --- a/src/backend.rs +++ b/src/backend.rs @@ -1,30 +1,54 @@ -use std::path::Path; -use anyhow::{anyhow, Context, Result}; -use libloading::Library; -use crate::{OutputEntry}; +use crate::OutputEntry; use crate::{decode_audio_to_pcm_f32_ffmpeg, find_model_file}; +use anyhow::{Context, Result, anyhow}; +#[cfg(not(test))] +use libloading::Library; use std::env; +use std::path::Path; // Re-export a public enum for CLI parsing usage #[derive(Debug, Clone, Copy, PartialEq, Eq)] +/// Kind of transcription backend to use. pub enum BackendKind { + /// Automatically detect the best available backend (CUDA > HIP > Vulkan > CPU). Auto, + /// Pure CPU backend using whisper-rs. Cpu, + /// NVIDIA CUDA backend (requires CUDA runtime available at load time and proper feature build). Cuda, + /// AMD ROCm/HIP backend (requires hip/rocBLAS libraries available and proper feature build). Hip, + /// Vulkan backend (experimental; requires Vulkan loader/SDK and feature build). Vulkan, } +/// Abstraction for a transcription backend implementation. pub trait TranscribeBackend { + /// Return the backend kind for this implementation. fn kind(&self) -> BackendKind; - fn transcribe(&self, audio_path: &Path, speaker: &str, lang_opt: Option<&str>, gpu_layers: Option) -> Result>; + /// Transcribe the given audio file path and return transcript entries. + /// + /// Parameters: + /// - audio_path: path to input media (audio or video) to be decoded/transcribed. + /// - speaker: label to attach to all produced segments. + /// - lang_opt: optional language hint (e.g., "en"); None means auto/multilingual model default. + /// - gpu_layers: optional GPU layer count if applicable (ignored by some backends). 
+ fn transcribe( + &self, + audio_path: &Path, + speaker: &str, + lang_opt: Option<&str>, + gpu_layers: Option, + ) -> Result>; } fn check_lib(names: &[&str]) -> bool { #[cfg(test)] { // During unit tests, avoid touching system libs to prevent loader crashes in CI. - return false; + // Mark parameter as used to silence warnings in test builds. + let _ = names; + false } #[cfg(not(test))] { @@ -33,79 +57,167 @@ fn check_lib(names: &[&str]) -> bool { } for n in names { // Attempt to dlopen; ignore errors - if let Ok(_lib) = unsafe { Library::new(n) } { return true; } + if let Ok(_lib) = unsafe { Library::new(n) } { + return true; + } } false } } fn cuda_available() -> bool { - if let Ok(x) = env::var("POLYSCRIBE_TEST_FORCE_CUDA") { return x == "1"; } - check_lib(&["libcudart.so", "libcudart.so.12", "libcudart.so.11", "libcublas.so", "libcublas.so.12"]) + if let Ok(x) = env::var("POLYSCRIBE_TEST_FORCE_CUDA") { + return x == "1"; + } + check_lib(&[ + "libcudart.so", + "libcudart.so.12", + "libcudart.so.11", + "libcublas.so", + "libcublas.so.12", + ]) } fn hip_available() -> bool { - if let Ok(x) = env::var("POLYSCRIBE_TEST_FORCE_HIP") { return x == "1"; } + if let Ok(x) = env::var("POLYSCRIBE_TEST_FORCE_HIP") { + return x == "1"; + } check_lib(&["libhipblas.so", "librocblas.so"]) } fn vulkan_available() -> bool { - if let Ok(x) = env::var("POLYSCRIBE_TEST_FORCE_VULKAN") { return x == "1"; } + if let Ok(x) = env::var("POLYSCRIBE_TEST_FORCE_VULKAN") { + return x == "1"; + } check_lib(&["libvulkan.so.1", "libvulkan.so"]) } +/// CPU-based transcription backend using whisper-rs. pub struct CpuBackend; +/// CUDA-accelerated transcription backend for NVIDIA GPUs. pub struct CudaBackend; +/// ROCm/HIP-accelerated transcription backend for AMD GPUs. pub struct HipBackend; +/// Vulkan-based transcription backend (experimental/incomplete). pub struct VulkanBackend; impl CpuBackend { - pub fn new() -> Self { CpuBackend } + /// Create a new CPU backend instance. 
+ pub fn new() -> Self { + CpuBackend + } +} +impl CudaBackend { + /// Create a new CUDA backend instance. + pub fn new() -> Self { + CudaBackend + } +} +impl HipBackend { + /// Create a new HIP backend instance. + pub fn new() -> Self { + HipBackend + } +} +impl VulkanBackend { + /// Create a new Vulkan backend instance. + pub fn new() -> Self { + VulkanBackend + } } -impl CudaBackend { pub fn new() -> Self { CudaBackend } } -impl HipBackend { pub fn new() -> Self { HipBackend } } -impl VulkanBackend { pub fn new() -> Self { VulkanBackend } } impl TranscribeBackend for CpuBackend { - fn kind(&self) -> BackendKind { BackendKind::Cpu } - fn transcribe(&self, audio_path: &Path, speaker: &str, lang_opt: Option<&str>, _gpu_layers: Option) -> Result> { + fn kind(&self) -> BackendKind { + BackendKind::Cpu + } + fn transcribe( + &self, + audio_path: &Path, + speaker: &str, + lang_opt: Option<&str>, + _gpu_layers: Option, + ) -> Result> { transcribe_with_whisper_rs(audio_path, speaker, lang_opt) } } impl TranscribeBackend for CudaBackend { - fn kind(&self) -> BackendKind { BackendKind::Cuda } - fn transcribe(&self, audio_path: &Path, speaker: &str, lang_opt: Option<&str>, _gpu_layers: Option) -> Result> { + fn kind(&self) -> BackendKind { + BackendKind::Cuda + } + fn transcribe( + &self, + audio_path: &Path, + speaker: &str, + lang_opt: Option<&str>, + _gpu_layers: Option, + ) -> Result> { // whisper-rs uses enabled CUDA feature at build time; call same code path transcribe_with_whisper_rs(audio_path, speaker, lang_opt) } } impl TranscribeBackend for HipBackend { - fn kind(&self) -> BackendKind { BackendKind::Hip } - fn transcribe(&self, audio_path: &Path, speaker: &str, lang_opt: Option<&str>, _gpu_layers: Option) -> Result> { + fn kind(&self) -> BackendKind { + BackendKind::Hip + } + fn transcribe( + &self, + audio_path: &Path, + speaker: &str, + lang_opt: Option<&str>, + _gpu_layers: Option, + ) -> Result> { transcribe_with_whisper_rs(audio_path, speaker, lang_opt) } } 
impl TranscribeBackend for VulkanBackend { - fn kind(&self) -> BackendKind { BackendKind::Vulkan } - fn transcribe(&self, _audio_path: &Path, _speaker: &str, _lang_opt: Option<&str>, _gpu_layers: Option) -> Result> { - Err(anyhow!("Vulkan backend not yet wired to whisper.cpp FFI. Build with --features gpu-vulkan and ensure Vulkan SDK is installed. How to fix: install Vulkan loader (libvulkan), set VULKAN_SDK, and run cargo build --features gpu-vulkan.")) + fn kind(&self) -> BackendKind { + BackendKind::Vulkan + } + fn transcribe( + &self, + _audio_path: &Path, + _speaker: &str, + _lang_opt: Option<&str>, + _gpu_layers: Option, + ) -> Result> { + Err(anyhow!( + "Vulkan backend not yet wired to whisper.cpp FFI. Build with --features gpu-vulkan and ensure Vulkan SDK is installed. How to fix: install Vulkan loader (libvulkan), set VULKAN_SDK, and run cargo build --features gpu-vulkan." + )) } } +/// Result of choosing a transcription backend. pub struct SelectionResult { + /// The constructed backend instance to perform transcription with. pub backend: Box, + /// Which backend kind was ultimately selected. pub chosen: BackendKind, + /// Which backend kinds were detected as available on this system. pub detected: Vec, } +/// Select an appropriate backend based on user request and system detection. +/// +/// If `requested` is `BackendKind::Auto`, the function prefers CUDA, then HIP, +/// then Vulkan, falling back to CPU when no GPU backend is detected. When a +/// specific GPU backend is requested but unavailable, an error is returned with +/// guidance on how to enable it. +/// +/// Set `verbose` to true to print detection/selection info to stderr. 
pub fn select_backend(requested: BackendKind, verbose: bool) -> Result { let mut detected = Vec::new(); - if cuda_available() { detected.push(BackendKind::Cuda); } - if hip_available() { detected.push(BackendKind::Hip); } - if vulkan_available() { detected.push(BackendKind::Vulkan); } + if cuda_available() { + detected.push(BackendKind::Cuda); + } + if hip_available() { + detected.push(BackendKind::Hip); + } + if vulkan_available() { + detected.push(BackendKind::Vulkan); + } let mk = |k: BackendKind| -> Box { match k { @@ -119,22 +231,42 @@ pub fn select_backend(requested: BackendKind, verbose: bool) -> Result { - if detected.contains(&BackendKind::Cuda) { BackendKind::Cuda } - else if detected.contains(&BackendKind::Hip) { BackendKind::Hip } - else if detected.contains(&BackendKind::Vulkan) { BackendKind::Vulkan } - else { BackendKind::Cpu } + if detected.contains(&BackendKind::Cuda) { + BackendKind::Cuda + } else if detected.contains(&BackendKind::Hip) { + BackendKind::Hip + } else if detected.contains(&BackendKind::Vulkan) { + BackendKind::Vulkan + } else { + BackendKind::Cpu + } } BackendKind::Cuda => { - if detected.contains(&BackendKind::Cuda) { BackendKind::Cuda } - else { return Err(anyhow!("Requested CUDA backend but CUDA libraries/devices not detected. How to fix: install NVIDIA driver + CUDA toolkit, ensure libcudart/libcublas are in loader path, and build with --features gpu-cuda.")); } + if detected.contains(&BackendKind::Cuda) { + BackendKind::Cuda + } else { + return Err(anyhow!( + "Requested CUDA backend but CUDA libraries/devices not detected. How to fix: install NVIDIA driver + CUDA toolkit, ensure libcudart/libcublas are in loader path, and build with --features gpu-cuda." + )); + } } BackendKind::Hip => { - if detected.contains(&BackendKind::Hip) { BackendKind::Hip } - else { return Err(anyhow!("Requested ROCm/HIP backend but libraries/devices not detected. 
How to fix: install ROCm hipBLAS/rocBLAS, ensure libs are in loader path, and build with --features gpu-hip.")); } + if detected.contains(&BackendKind::Hip) { + BackendKind::Hip + } else { + return Err(anyhow!( + "Requested ROCm/HIP backend but libraries/devices not detected. How to fix: install ROCm hipBLAS/rocBLAS, ensure libs are in loader path, and build with --features gpu-hip." + )); + } } BackendKind::Vulkan => { - if detected.contains(&BackendKind::Vulkan) { BackendKind::Vulkan } - else { return Err(anyhow!("Requested Vulkan backend but libvulkan not detected. How to fix: install Vulkan loader/SDK and build with --features gpu-vulkan.")); } + if detected.contains(&BackendKind::Vulkan) { + BackendKind::Vulkan + } else { + return Err(anyhow!( + "Requested Vulkan backend but libvulkan not detected. How to fix: install Vulkan loader/SDK and build with --features gpu-vulkan." + )); + } } BackendKind::Cpu => BackendKind::Cpu, }; @@ -144,12 +276,20 @@ pub fn select_backend(requested: BackendKind, verbose: bool) -> Result) -> Result> { +pub(crate) fn transcribe_with_whisper_rs( + audio_path: &Path, + speaker: &str, + lang_opt: Option<&str>, +) -> Result> { let pcm = decode_audio_to_pcm_f32_ffmpeg(audio_path)?; let model = find_model_file()?; let is_en_only = model @@ -161,34 +301,60 @@ pub(crate) fn transcribe_with_whisper_rs(audio_path: &Path, speaker: &str, lang_ if is_en_only && lang != "en" { return Err(anyhow!( "Selected model is English-only ({}), but a non-English language hint '{}' was provided. 
Please use a multilingual model or set WHISPER_MODEL.", - model.display(), lang + model.display(), + lang )); } } - let model_str = model.to_str().ok_or_else(|| anyhow!("Model path not valid UTF-8: {}", model.display()))?; + let model_str = model + .to_str() + .ok_or_else(|| anyhow!("Model path not valid UTF-8: {}", model.display()))?; let cparams = whisper_rs::WhisperContextParameters::default(); let ctx = whisper_rs::WhisperContext::new_with_params(model_str, cparams) .with_context(|| format!("Failed to load Whisper model at {}", model.display()))?; - let mut state = ctx.create_state().map_err(|e| anyhow!("Failed to create Whisper state: {:?}", e))?; + let mut state = ctx + .create_state() + .map_err(|e| anyhow!("Failed to create Whisper state: {:?}", e))?; - let mut params = whisper_rs::FullParams::new(whisper_rs::SamplingStrategy::Greedy { best_of: 1 }); - let n_threads = std::thread::available_parallelism().map(|n| n.get() as i32).unwrap_or(1); + let mut params = + whisper_rs::FullParams::new(whisper_rs::SamplingStrategy::Greedy { best_of: 1 }); + let n_threads = std::thread::available_parallelism() + .map(|n| n.get() as i32) + .unwrap_or(1); params.set_n_threads(n_threads); params.set_translate(false); - if let Some(lang) = lang_opt { params.set_language(Some(lang)); } + if let Some(lang) = lang_opt { + params.set_language(Some(lang)); + } - state.full(params, &pcm).map_err(|e| anyhow!("Whisper full() failed: {:?}", e))?; + state + .full(params, &pcm) + .map_err(|e| anyhow!("Whisper full() failed: {:?}", e))?; - let num_segments = state.full_n_segments().map_err(|e| anyhow!("Failed to get segments: {:?}", e))?; + let num_segments = state + .full_n_segments() + .map_err(|e| anyhow!("Failed to get segments: {:?}", e))?; let mut items = Vec::new(); for i in 0..num_segments { - let text = state.full_get_segment_text(i).map_err(|e| anyhow!("Failed to get segment text: {:?}", e))?; - let t0 = state.full_get_segment_t0(i).map_err(|e| anyhow!("Failed to get segment 
t0: {:?}", e))?; - let t1 = state.full_get_segment_t1(i).map_err(|e| anyhow!("Failed to get segment t1: {:?}", e))?; + let text = state + .full_get_segment_text(i) + .map_err(|e| anyhow!("Failed to get segment text: {:?}", e))?; + let t0 = state + .full_get_segment_t0(i) + .map_err(|e| anyhow!("Failed to get segment t0: {:?}", e))?; + let t1 = state + .full_get_segment_t1(i) + .map_err(|e| anyhow!("Failed to get segment t1: {:?}", e))?; let start = (t0 as f64) * 0.01; let end = (t1 as f64) * 0.01; - items.push(OutputEntry { id: 0, speaker: speaker.to_string(), start, end, text: text.trim().to_string() }); + items.push(OutputEntry { + id: 0, + speaker: speaker.to_string(), + start, + end, + text: text.trim().to_string(), + }); } Ok(items) } diff --git a/src/lib.rs b/src/lib.rs new file mode 100644 index 0000000..480a3d5 --- /dev/null +++ b/src/lib.rs @@ -0,0 +1,341 @@ +#![forbid(elided_lifetimes_in_paths)] +#![forbid(unused_must_use)] +#![deny(missing_docs)] +// Lint policy for incremental refactor toward 2024: +// - Keep basic clippy warnings enabled; skip pedantic/nursery for now (will revisit in step 7). +// - cargo lints can be re-enabled later once codebase is tidied. +#![warn(clippy::all)] +//! PolyScribe library: business logic and core types. +//! +//! This crate exposes the reusable parts of the PolyScribe CLI as a library. +//! The binary entry point (main.rs) remains a thin CLI wrapper. + +use anyhow::{Context, Result, anyhow}; +use chrono::Local; +use std::env; +use std::fs::create_dir_all; +use std::io::{self, Write}; +use std::path::{Path, PathBuf}; +use std::process::Command; + +/// Re-export backend module (GPU/CPU selection and transcription). +pub mod backend; +/// Re-export models module (model listing/downloading/updating). +pub mod models; + +/// Transcript entry for a single segment. +#[derive(Debug, serde::Serialize, Clone)] +pub struct OutputEntry { + /// Sequential id in output ordering. 
+ pub id: u64, + /// Speaker label associated with the segment. + pub speaker: String, + /// Start time in seconds. + pub start: f64, + /// End time in seconds. + pub end: f64, + /// Text content. + pub text: String, +} + +/// Return a YYYY-MM-DD date prefix string for output file naming. +pub fn date_prefix() -> String { + Local::now().format("%Y-%m-%d").to_string() +} + +/// Format a floating-point number of seconds as SRT timestamp (HH:MM:SS,mmm). +pub fn format_srt_time(seconds: f64) -> String { + let total_ms = (seconds * 1000.0).round() as i64; + let ms = (total_ms % 1000) as i64; + let total_secs = total_ms / 1000; + let s = (total_secs % 60) as i64; + let m = ((total_secs / 60) % 60) as i64; + let h = (total_secs / 3600) as i64; + format!("{:02}:{:02}:{:02},{:03}", h, m, s, ms) +} + +/// Render a list of transcript entries to SRT format. +pub fn render_srt(items: &[OutputEntry]) -> String { + let mut out = String::new(); + for (i, e) in items.iter().enumerate() { + let idx = i + 1; + out.push_str(&format!("{}\n", idx)); + out.push_str(&format!( + "{} --> {}\n", + format_srt_time(e.start), + format_srt_time(e.end) + )); + if !e.speaker.is_empty() { + out.push_str(&format!("{}: {}\n", e.speaker, e.text)); + } else { + out.push_str(&format!("{}\n", e.text)); + } + out.push('\n'); + } + out +} + +/// Determine the default models directory, honoring POLYSCRIBE_MODELS_DIR override. 
+pub fn models_dir_path() -> PathBuf { + if let Ok(p) = env::var("POLYSCRIBE_MODELS_DIR") { + let pb = PathBuf::from(p); + if !pb.as_os_str().is_empty() { + return pb; + } + } + if cfg!(debug_assertions) { + return PathBuf::from("models"); + } + if let Ok(xdg) = env::var("XDG_DATA_HOME") { + if !xdg.is_empty() { + return PathBuf::from(xdg).join("polyscribe").join("models"); + } + } + if let Ok(home) = env::var("HOME") { + if !home.is_empty() { + return PathBuf::from(home) + .join(".local") + .join("share") + .join("polyscribe") + .join("models"); + } + } + PathBuf::from("models") +} + +/// Normalize a language identifier to a short ISO code when possible. +pub fn normalize_lang_code(input: &str) -> Option { + let mut s = input.trim().to_lowercase(); + if s.is_empty() || s == "auto" || s == "c" || s == "posix" { + return None; + } + if let Some((lhs, _)) = s.split_once('.') { + s = lhs.to_string(); + } + if let Some((lhs, _)) = s.split_once('_') { + s = lhs.to_string(); + } + let code = match s.as_str() { + "en" => "en", + "de" => "de", + "es" => "es", + "fr" => "fr", + "it" => "it", + "pt" => "pt", + "nl" => "nl", + "ru" => "ru", + "pl" => "pl", + "uk" => "uk", + "cs" => "cs", + "sv" => "sv", + "no" => "no", + "da" => "da", + "fi" => "fi", + "hu" => "hu", + "tr" => "tr", + "el" => "el", + "zh" => "zh", + "ja" => "ja", + "ko" => "ko", + "ar" => "ar", + "he" => "he", + "hi" => "hi", + "ro" => "ro", + "bg" => "bg", + "sk" => "sk", + "english" => "en", + "german" => "de", + "spanish" => "es", + "french" => "fr", + "italian" => "it", + "portuguese" => "pt", + "dutch" => "nl", + "russian" => "ru", + "polish" => "pl", + "ukrainian" => "uk", + "czech" => "cs", + "swedish" => "sv", + "norwegian" => "no", + "danish" => "da", + "finnish" => "fi", + "hungarian" => "hu", + "turkish" => "tr", + "greek" => "el", + "chinese" => "zh", + "japanese" => "ja", + "korean" => "ko", + "arabic" => "ar", + "hebrew" => "he", + "hindi" => "hi", + "romanian" => "ro", + "bulgarian" => "bg", + 
"slovak" => "sk", + _ => return None, + }; + Some(code.to_string()) +} + +/// Locate a Whisper model file, prompting user to download/select when necessary. +pub fn find_model_file() -> Result { + let models_dir_buf = models_dir_path(); + let models_dir = models_dir_buf.as_path(); + if !models_dir.exists() { + create_dir_all(models_dir).with_context(|| { + format!( + "Failed to create models directory: {}", + models_dir.display() + ) + })?; + } + + if let Ok(env_model) = env::var("WHISPER_MODEL") { + let p = PathBuf::from(env_model); + if p.is_file() { + let _ = std::fs::write(models_dir.join(".last_model"), p.display().to_string()); + return Ok(p); + } + } + + let mut candidates: Vec = Vec::new(); + let rd = std::fs::read_dir(models_dir) + .with_context(|| format!("Failed to read models directory: {}", models_dir.display()))?; + for entry in rd { + let entry = entry?; + let path = entry.path(); + if path.is_file() { + if let Some(ext) = path + .extension() + .and_then(|s| s.to_str()) + .map(|s| s.to_lowercase()) + { + if ext == "bin" { + candidates.push(path); + } + } + } + } + + if candidates.is_empty() { + eprintln!( + "WARN: No Whisper model files (*.bin) found in {}.", + models_dir.display() + ); + eprint!("Would you like to download models now? 
[Y/n]: "); + io::stderr().flush().ok(); + let mut input = String::new(); + io::stdin().read_line(&mut input).ok(); + let ans = input.trim().to_lowercase(); + if ans.is_empty() || ans == "y" || ans == "yes" { + if let Err(e) = models::run_interactive_model_downloader() { + eprintln!("ERROR: Downloader failed: {:#}", e); + } + candidates.clear(); + let rd2 = std::fs::read_dir(models_dir).with_context(|| { + format!("Failed to read models directory: {}", models_dir.display()) + })?; + for entry in rd2 { + let entry = entry?; + let path = entry.path(); + if path.is_file() { + if let Some(ext) = path + .extension() + .and_then(|s| s.to_str()) + .map(|s| s.to_lowercase()) + { + if ext == "bin" { + candidates.push(path); + } + } + } + } + } + } + + if candidates.is_empty() { + return Err(anyhow!( + "No Whisper model files (*.bin) available in {}", + models_dir.display() + )); + } + + if candidates.len() == 1 { + let only = candidates.remove(0); + let _ = std::fs::write(models_dir.join(".last_model"), only.display().to_string()); + return Ok(only); + } + + let last_file = models_dir.join(".last_model"); + if let Ok(prev) = std::fs::read_to_string(&last_file) { + let prev = prev.trim(); + if !prev.is_empty() { + let p = PathBuf::from(prev); + if p.is_file() { + if candidates.iter().any(|c| c == &p) { + eprintln!("INFO: Using previously selected model: {}", p.display()); + return Ok(p); + } + } + } + } + + eprintln!("Multiple Whisper models found in {}:", models_dir.display()); + for (i, p) in candidates.iter().enumerate() { + eprintln!(" {}) {}", i + 1, p.display()); + } + eprint!("Select model by number [1-{}]: ", candidates.len()); + io::stderr().flush().ok(); + let mut input = String::new(); + io::stdin() + .read_line(&mut input) + .context("Failed to read selection")?; + let sel: usize = input + .trim() + .parse() + .map_err(|_| anyhow!("Invalid selection: {}", input.trim()))?; + if sel == 0 || sel > candidates.len() { + return Err(anyhow!("Selection out of range")); + 
} + let chosen = candidates.swap_remove(sel - 1); + let _ = std::fs::write(models_dir.join(".last_model"), chosen.display().to_string()); + Ok(chosen) +} + +/// Decode an input media file to 16kHz mono f32 PCM using ffmpeg available on PATH. +pub fn decode_audio_to_pcm_f32_ffmpeg(audio_path: &Path) -> Result> { + let output = Command::new("ffmpeg") + .arg("-i") + .arg(audio_path) + .arg("-f") + .arg("f32le") + .arg("-ac") + .arg("1") + .arg("-ar") + .arg("16000") + .arg("pipe:1") + .output() + .with_context(|| format!("Failed to execute ffmpeg for {}", audio_path.display()))?; + if !output.status.success() { + return Err(anyhow!( + "ffmpeg failed for {}: {}", + audio_path.display(), + String::from_utf8_lossy(&output.stderr) + )); + } + let bytes = output.stdout; + if bytes.len() % 4 != 0 { + let truncated = bytes.len() - (bytes.len() % 4); + let mut v = Vec::with_capacity(truncated / 4); + for chunk in bytes[..truncated].chunks_exact(4) { + let arr = [chunk[0], chunk[1], chunk[2], chunk[3]]; + v.push(f32::from_le_bytes(arr)); + } + Ok(v) + } else { + let mut v = Vec::with_capacity(bytes.len() / 4); + for chunk in bytes.chunks_exact(4) { + let arr = [chunk[0], chunk[1], chunk[2], chunk[3]]; + v.push(f32::from_le_bytes(arr)); + } + Ok(v) + } +} diff --git a/src/main.rs b/src/main.rs index 169459e..876f67d 100644 --- a/src/main.rs +++ b/src/main.rs @@ -1,22 +1,16 @@ use std::fs::{File, create_dir_all}; use std::io::{self, Read, Write}; use std::path::{Path, PathBuf}; -use std::process::Command; -use std::env; -use anyhow::{anyhow, Context, Result}; +use anyhow::{Context, Result, anyhow}; use clap::{Parser, Subcommand}; -use serde::{Deserialize, Serialize}; -use chrono::Local; -use std::sync::atomic::{AtomicBool, AtomicU8, Ordering}; use clap_complete::Shell; +use serde::{Deserialize, Serialize}; +use std::sync::atomic::{AtomicU8, Ordering}; -// whisper-rs is used in backend module -mod models; -mod backend; -use backend::{BackendKind, select_backend, 
TranscribeBackend}; +// whisper-rs is used from the library crate +use polyscribe::backend::{BackendKind, select_backend}; -static LAST_MODEL_WRITTEN: AtomicBool = AtomicBool::new(false); static VERBOSE: AtomicU8 = AtomicU8::new(0); macro_rules! vlog { @@ -27,6 +21,7 @@ macro_rules! vlog { } } +#[allow(unused_macros)] macro_rules! warnlog { ($($arg:tt)*) => { eprintln!("WARN: {}", format!($($arg)*)); @@ -39,38 +34,6 @@ macro_rules! errorlog { } } -fn models_dir_path() -> PathBuf { - // Highest priority: explicit override - if let Ok(p) = env::var("POLYSCRIBE_MODELS_DIR") { - let pb = PathBuf::from(p); - if !pb.as_os_str().is_empty() { - return pb; - } - } - // In debug builds, keep local ./models for convenience - if cfg!(debug_assertions) { - return PathBuf::from("models"); - } - // In release builds, choose a user-writable data directory - if let Ok(xdg) = env::var("XDG_DATA_HOME") { - if !xdg.is_empty() { - return PathBuf::from(xdg).join("polyscribe").join("models"); - } - } - if let Ok(home) = env::var("HOME") { - if !home.is_empty() { - return PathBuf::from(home) - .join(".local") - .join("share") - .join("polyscribe") - .join("models"); - } - } - // Last resort fallback - PathBuf::from("models") -} - - #[derive(Subcommand, Debug, Clone)] enum AuxCommands { /// Generate shell completion script to stdout @@ -94,7 +57,12 @@ enum GpuBackendCli { } #[derive(Parser, Debug)] -#[command(name = "PolyScribe", bin_name = "polyscribe", version, about = "Merge JSON transcripts or transcribe audio using native whisper")] +#[command( + name = "PolyScribe", + bin_name = "polyscribe", + version, + about = "Merge JSON transcripts or transcribe audio using native whisper" +)] struct Args { /// Increase verbosity (-v, -vv). Logs go to stderr. 
#[arg(short = 'v', long = "verbose", action = clap::ArgAction::Count, global = true)] @@ -158,50 +126,13 @@ struct InputSegment { // other fields are ignored } -#[derive(Debug, Serialize, Clone)] -struct OutputEntry { - id: u64, - speaker: String, - start: f64, - end: f64, - text: String, -} +use polyscribe::{OutputEntry, date_prefix, models_dir_path, normalize_lang_code, render_srt}; #[derive(Debug, Serialize)] struct OutputRoot { items: Vec, } -fn date_prefix() -> String { - Local::now().format("%Y-%m-%d").to_string() -} - -fn format_srt_time(seconds: f64) -> String { - let total_ms = (seconds * 1000.0).round() as i64; - let ms = (total_ms % 1000) as i64; - let total_secs = total_ms / 1000; - let s = (total_secs % 60) as i64; - let m = ((total_secs / 60) % 60) as i64; - let h = (total_secs / 3600) as i64; - format!("{:02}:{:02}:{:02},{:03}", h, m, s, ms) -} - -fn render_srt(items: &[OutputEntry]) -> String { - let mut out = String::new(); - for (i, e) in items.iter().enumerate() { - let idx = i + 1; - out.push_str(&format!("{}\n", idx)); - out.push_str(&format!("{} --> {}\n", format_srt_time(e.start), format_srt_time(e.end))); - if !e.speaker.is_empty() { - out.push_str(&format!("{}: {}\n", e.speaker, e.text)); - } else { - out.push_str(&format!("{}\n", e.text)); - } - out.push('\n'); - } - out -} - fn sanitize_speaker_name(raw: &str) -> String { if let Some((prefix, rest)) = raw.split_once('-') { if !prefix.is_empty() && prefix.chars().all(|c| c.is_ascii_digit()) { @@ -220,13 +151,20 @@ fn prompt_speaker_name_for_path(path: &Path, default_name: &str, enabled: bool) .and_then(|s| s.to_str()) .map(|s| s.to_string()) .unwrap_or_else(|| path.to_string_lossy().to_string()); - eprint!("Enter speaker name for {} [default: {}]: ", display_owned, default_name); + eprint!( + "Enter speaker name for {} [default: {}]: ", + display_owned, default_name + ); io::stderr().flush().ok(); let mut buf = String::new(); match io::stdin().read_line(&mut buf) { Ok(_) => { let s = 
buf.trim(); - if s.is_empty() { default_name.to_string() } else { s.to_string() } + if s.is_empty() { + default_name.to_string() + } else { + s.to_string() + } } Err(_) => default_name.to_string(), } @@ -238,185 +176,20 @@ fn is_json_file(path: &Path) -> bool { } fn is_audio_file(path: &Path) -> bool { - if let Some(ext) = path.extension().and_then(|s| s.to_str()).map(|s| s.to_lowercase()) { + if let Some(ext) = path + .extension() + .and_then(|s| s.to_str()) + .map(|s| s.to_lowercase()) + { let exts = [ - "mp3","wav","m4a","mp4","aac","flac","ogg","wma","webm","mkv","mov","avi","m4b","3gp","opus","aiff","alac" + "mp3", "wav", "m4a", "mp4", "aac", "flac", "ogg", "wma", "webm", "mkv", "mov", "avi", + "m4b", "3gp", "opus", "aiff", "alac", ]; return exts.contains(&ext.as_str()); } false } -fn normalize_lang_code(input: &str) -> Option { - let mut s = input.trim().to_lowercase(); - if s.is_empty() || s == "auto" || s == "c" || s == "posix" { return None; } - if let Some((lhs, _)) = s.split_once('.') { s = lhs.to_string(); } - if let Some((lhs, _)) = s.split_once('_') { s = lhs.to_string(); } - let code = match s.as_str() { - // ISO codes directly - "en"=>"en","de"=>"de","es"=>"es","fr"=>"fr","it"=>"it","pt"=>"pt","nl"=>"nl","ru"=>"ru","pl"=>"pl", - "uk"=>"uk","cs"=>"cs","sv"=>"sv","no"=>"no","da"=>"da","fi"=>"fi","hu"=>"hu","tr"=>"tr","el"=>"el", - "zh"=>"zh","ja"=>"ja","ko"=>"ko","ar"=>"ar","he"=>"he","hi"=>"hi","ro"=>"ro","bg"=>"bg","sk"=>"sk", - // Common English names - "english"=>"en","german"=>"de","spanish"=>"es","french"=>"fr","italian"=>"it","portuguese"=>"pt", - "dutch"=>"nl","russian"=>"ru","polish"=>"pl","ukrainian"=>"uk","czech"=>"cs","swedish"=>"sv", - "norwegian"=>"no","danish"=>"da","finnish"=>"fi","hungarian"=>"hu","turkish"=>"tr","greek"=>"el", - "chinese"=>"zh","japanese"=>"ja","korean"=>"ko","arabic"=>"ar","hebrew"=>"he","hindi"=>"hi", - "romanian"=>"ro","bulgarian"=>"bg","slovak"=>"sk", - _ => return None, - }; - Some(code.to_string()) -} - - - 
-pub(crate) fn find_model_file() -> Result { - let models_dir_buf = models_dir_path(); - let models_dir = models_dir_buf.as_path(); - if !models_dir.exists() { - create_dir_all(models_dir).with_context(|| format!("Failed to create models directory: {}", models_dir.display()))?; - } - - // If env var WHISPER_MODEL is set and valid, prefer it - if let Ok(env_model) = env::var("WHISPER_MODEL") { - let p = PathBuf::from(env_model); - if p.is_file() { - // persist selection - let _ = std::fs::write(models_dir.join(".last_model"), p.display().to_string()); - LAST_MODEL_WRITTEN.store(true, Ordering::Relaxed); - return Ok(p); - } - } - - // Enumerate local models - let mut candidates: Vec = Vec::new(); - let rd = std::fs::read_dir(models_dir) - .with_context(|| format!("Failed to read models directory: {}", models_dir.display()))?; - for entry in rd { - let entry = entry?; - let path = entry.path(); - if path.is_file() { - if let Some(ext) = path.extension().and_then(|s| s.to_str()).map(|s| s.to_lowercase()) { - if ext == "bin" { - candidates.push(path); - } - } - } - } - - if candidates.is_empty() { - // In quiet mode we still prompt for models; suppress only non-essential logs - warnlog!("No Whisper model files (*.bin) found in {}.", models_dir.display()); - eprint!("Would you like to download models now? 
[Y/n]: "); - io::stderr().flush().ok(); - let mut input = String::new(); - io::stdin().read_line(&mut input).ok(); - let ans = input.trim().to_lowercase(); - if ans.is_empty() || ans == "y" || ans == "yes" { - if let Err(e) = models::run_interactive_model_downloader() { - errorlog!("Downloader failed: {:#}", e); - } - // Re-scan - candidates.clear(); - let rd2 = std::fs::read_dir(models_dir) - .with_context(|| format!("Failed to read models directory: {}", models_dir.display()))?; - for entry in rd2 { - let entry = entry?; - let path = entry.path(); - if path.is_file() { - if let Some(ext) = path.extension().and_then(|s| s.to_str()).map(|s| s.to_lowercase()) { - if ext == "bin" { - candidates.push(path); - } - } - } - } - } - } - - if candidates.is_empty() { - return Err(anyhow!("No Whisper model files (*.bin) available in {}", models_dir.display())); - } - - // If only one, persist and return it - if candidates.len() == 1 { - let only = candidates.remove(0); - let _ = std::fs::write(models_dir.join(".last_model"), only.display().to_string()); - LAST_MODEL_WRITTEN.store(true, Ordering::Relaxed); - return Ok(only); - } - - // If a previous selection exists and is still valid, use it - let last_file = models_dir.join(".last_model"); - if let Ok(prev) = std::fs::read_to_string(&last_file) { - let prev = prev.trim(); - if !prev.is_empty() { - let p = PathBuf::from(prev); - if p.is_file() { - // Also ensure it's one of the candidates (same dir) - if candidates.iter().any(|c| c == &p) { - vlog!(0, "Using previously selected model: {}", p.display()); - return Ok(p); - } - } - } - } - - // Multiple models and no previous selection: prompt user to choose, then persist - eprintln!("Multiple Whisper models found in {}:", models_dir.display()); - for (i, p) in candidates.iter().enumerate() { - eprintln!(" {}) {}", i + 1, p.display()); - } - eprint!("Select model by number [1-{}]: ", candidates.len()); - io::stderr().flush().ok(); - let mut input = String::new(); - 
io::stdin().read_line(&mut input).context("Failed to read selection")?; - let sel: usize = input.trim().parse().map_err(|_| anyhow!("Invalid selection: {}", input.trim()))?; - if sel == 0 || sel > candidates.len() { - return Err(anyhow!("Selection out of range")); - } - let chosen = candidates.swap_remove(sel - 1); - let _ = std::fs::write(models_dir.join(".last_model"), chosen.display().to_string()); - LAST_MODEL_WRITTEN.store(true, Ordering::Relaxed); - Ok(chosen) -} - -pub(crate) fn decode_audio_to_pcm_f32_ffmpeg(audio_path: &Path) -> Result> { - let output = Command::new("ffmpeg") - .arg("-i").arg(audio_path) - .arg("-f").arg("f32le") - .arg("-ac").arg("1") - .arg("-ar").arg("16000") - .arg("pipe:1") - .output() - .with_context(|| format!("Failed to execute ffmpeg for {}", audio_path.display()))?; - if !output.status.success() { - return Err(anyhow!( - "ffmpeg failed for {}: {}", - audio_path.display(), - String::from_utf8_lossy(&output.stderr) - )); - } - let bytes = output.stdout; - if bytes.len() % 4 != 0 { - // Truncate to nearest multiple of 4 bytes to avoid partial f32 - let truncated = bytes.len() - (bytes.len() % 4); - let mut v = Vec::with_capacity(truncated / 4); - for chunk in bytes[..truncated].chunks_exact(4) { - let arr = [chunk[0], chunk[1], chunk[2], chunk[3]]; - v.push(f32::from_le_bytes(arr)); - } - Ok(v) - } else { - let mut v = Vec::with_capacity(bytes.len() / 4); - for chunk in bytes.chunks_exact(4) { - let arr = [chunk[0], chunk[1], chunk[2], chunk[3]]; - v.push(f32::from_le_bytes(arr)); - } - Ok(v) - } -} - struct LastModelCleanup { path: PathBuf, } @@ -427,10 +200,9 @@ impl Drop for LastModelCleanup { } } - fn main() -> Result<()> { // Parse CLI - let mut args = Args::parse(); + let args = Args::parse(); // Initialize verbosity VERBOSE.store(args.verbose, Ordering::Relaxed); @@ -443,7 +215,7 @@ fn main() -> Result<()> { let mut cmd = Args::command(); let bin_name = cmd.get_name().to_string(); clap_complete::generate(*shell, &mut cmd, 
bin_name, &mut io::stdout()); - return Ok(()) + return Ok(()); } AuxCommands::Man => { let cmd = Args::command(); @@ -451,7 +223,7 @@ fn main() -> Result<()> { let mut out = Vec::new(); man.render(&mut out)?; io::stdout().write_all(&out)?; - return Ok(()) + return Ok(()); } } } @@ -460,7 +232,9 @@ fn main() -> Result<()> { let models_dir_buf = models_dir_path(); let last_model_path = models_dir_buf.join(".last_model"); // Ensure cleanup at end of program, regardless of exit path - let _last_model_cleanup = LastModelCleanup { path: last_model_path.clone() }; + let _last_model_cleanup = LastModelCleanup { + path: last_model_path.clone(), + }; // Select backend let requested = match args.gpu_backend { @@ -475,7 +249,7 @@ fn main() -> Result<()> { // If requested, run the interactive model downloader first. If no inputs were provided, exit after downloading. if args.download_models { - if let Err(e) = models::run_interactive_model_downloader() { + if let Err(e) = polyscribe::models::run_interactive_model_downloader() { errorlog!("Model downloader failed: {:#}", e); } if args.inputs.is_empty() { @@ -485,7 +259,7 @@ fn main() -> Result<()> { // If requested, update local models and exit unless inputs provided to continue if args.update_models { - if let Err(e) = models::update_local_models() { + if let Err(e) = polyscribe::models::update_local_models() { errorlog!("Model update failed: {:#}", e); return Err(e); } @@ -496,7 +270,7 @@ fn main() -> Result<()> { } // Determine inputs and optional output path - vlog!(1, "Parsed {} input(s)", args.inputs.len()); + vlog!(1, "Parsed {} input(s)", args.inputs.len()); let mut inputs = args.inputs; let mut output_path = args.output; if output_path.is_none() && inputs.len() >= 2 { @@ -519,11 +293,13 @@ fn main() -> Result<()> { }; let any_audio = inputs.iter().any(|p| is_audio_file(Path::new(p))); if any_audio && lang_hint.is_none() { - return Err(anyhow!("Please specify --language (e.g., --language en). 
Language detection was removed.")); + return Err(anyhow!( + "Please specify --language (e.g., --language en). Language detection was removed." + )); } if args.merge_and_separate { - vlog!(1, "Mode: merge-and-separate; output_dir={:?}", output_path); + vlog!(1, "Mode: merge-and-separate; output_dir={:?}", output_path); // Combined mode: write separate outputs per input and also a merged output set // Require an output directory let out_dir = match output_path.as_ref() { @@ -531,8 +307,9 @@ fn main() -> Result<()> { None => return Err(anyhow!("--merge-and-separate requires -o OUTPUT_DIR")), }; if !out_dir.as_os_str().is_empty() { - create_dir_all(&out_dir) - .with_context(|| format!("Failed to create output directory: {}", out_dir.display()))?; + create_dir_all(&out_dir).with_context(|| { + format!("Failed to create output directory: {}", out_dir.display()) + })?; } let mut merged_entries: Vec = Vec::new(); @@ -540,14 +317,22 @@ fn main() -> Result<()> { for input_path in &inputs { let path = Path::new(input_path); let default_speaker = sanitize_speaker_name( - path.file_stem().and_then(|s| s.to_str()).unwrap_or("speaker") + path.file_stem() + .and_then(|s| s.to_str()) + .unwrap_or("speaker"), ); - let speaker = prompt_speaker_name_for_path(path, &default_speaker, args.set_speaker_names); + let speaker = + prompt_speaker_name_for_path(path, &default_speaker, args.set_speaker_names); // Collect entries per file and extend merged let mut entries: Vec = Vec::new(); if is_audio_file(path) { - let items = sel.backend.transcribe(path, &speaker, lang_hint.as_deref(), args.gpu_layers)?; + let items = sel.backend.transcribe( + path, + &speaker, + lang_hint.as_deref(), + args.gpu_layers, + )?; entries.extend(items.into_iter()); } else if is_json_file(path) { let mut buf = String::new(); @@ -555,39 +340,67 @@ fn main() -> Result<()> { .with_context(|| format!("Failed to open: {}", input_path))? 
.read_to_string(&mut buf) .with_context(|| format!("Failed to read: {}", input_path))?; - let root: InputRoot = serde_json::from_str(&buf) - .with_context(|| format!("Invalid JSON transcript parsed from {}", input_path))?; + let root: InputRoot = serde_json::from_str(&buf).with_context(|| { + format!("Invalid JSON transcript parsed from {}", input_path) + })?; for seg in root.segments { - entries.push(OutputEntry { id: 0, speaker: speaker.clone(), start: seg.start, end: seg.end, text: seg.text }); + entries.push(OutputEntry { + id: 0, + speaker: speaker.clone(), + start: seg.start, + end: seg.end, + text: seg.text, + }); } } else { - return Err(anyhow!(format!("Unsupported input type (expected .json or audio media): {}", input_path))); + return Err(anyhow!(format!( + "Unsupported input type (expected .json or audio media): {}", + input_path + ))); } // Sort and reassign ids per file entries.sort_by(|a, b| { - match a.start.partial_cmp(&b.start) { Some(std::cmp::Ordering::Equal) | None => {} Some(o) => return o } - a.end.partial_cmp(&b.end).unwrap_or(std::cmp::Ordering::Equal) + match a.start.partial_cmp(&b.start) { + Some(std::cmp::Ordering::Equal) | None => {} + Some(o) => return o, + } + a.end + .partial_cmp(&b.end) + .unwrap_or(std::cmp::Ordering::Equal) }); - for (i, e) in entries.iter_mut().enumerate() { e.id = i as u64; } + for (i, e) in entries.iter_mut().enumerate() { + e.id = i as u64; + } // Write separate outputs to out_dir - let out = OutputRoot { items: entries.clone() }; - let stem = path.file_stem().and_then(|s| s.to_str()).unwrap_or("output"); + let out = OutputRoot { + items: entries.clone(), + }; + let stem = path + .file_stem() + .and_then(|s| s.to_str()) + .unwrap_or("output"); let date = date_prefix(); let base_name = format!("{}_{}", date, stem); let json_path = out_dir.join(format!("{}.json", &base_name)); let toml_path = out_dir.join(format!("{}.toml", &base_name)); let srt_path = out_dir.join(format!("{}.srt", &base_name)); - let mut 
json_file = File::create(&json_path) - .with_context(|| format!("Failed to create output file: {}", json_path.display()))?; - serde_json::to_writer_pretty(&mut json_file, &out)?; writeln!(&mut json_file)?; + let mut json_file = File::create(&json_path).with_context(|| { + format!("Failed to create output file: {}", json_path.display()) + })?; + serde_json::to_writer_pretty(&mut json_file, &out)?; + writeln!(&mut json_file)?; let toml_str = toml::to_string_pretty(&out)?; - let mut toml_file = File::create(&toml_path) - .with_context(|| format!("Failed to create output file: {}", toml_path.display()))?; - toml_file.write_all(toml_str.as_bytes())?; if !toml_str.ends_with('\n') { writeln!(&mut toml_file)?; } + let mut toml_file = File::create(&toml_path).with_context(|| { + format!("Failed to create output file: {}", toml_path.display()) + })?; + toml_file.write_all(toml_str.as_bytes())?; + if !toml_str.ends_with('\n') { + writeln!(&mut toml_file)?; + } let srt_str = render_srt(&out.items); let mut srt_file = File::create(&srt_path) @@ -600,11 +413,20 @@ fn main() -> Result<()> { // Now write merged output set into out_dir merged_entries.sort_by(|a, b| { - match a.start.partial_cmp(&b.start) { Some(std::cmp::Ordering::Equal) | None => {} Some(o) => return o } - a.end.partial_cmp(&b.end).unwrap_or(std::cmp::Ordering::Equal) + match a.start.partial_cmp(&b.start) { + Some(std::cmp::Ordering::Equal) | None => {} + Some(o) => return o, + } + a.end + .partial_cmp(&b.end) + .unwrap_or(std::cmp::Ordering::Equal) }); - for (i, e) in merged_entries.iter_mut().enumerate() { e.id = i as u64; } - let merged_out = OutputRoot { items: merged_entries }; + for (i, e) in merged_entries.iter_mut().enumerate() { + e.id = i as u64; + } + let merged_out = OutputRoot { + items: merged_entries, + }; let date = date_prefix(); let merged_base = format!("{}_merged", date); @@ -614,19 +436,23 @@ fn main() -> Result<()> { let mut mj = File::create(&m_json) .with_context(|| format!("Failed to 
create output file: {}", m_json.display()))?; - serde_json::to_writer_pretty(&mut mj, &merged_out)?; writeln!(&mut mj)?; + serde_json::to_writer_pretty(&mut mj, &merged_out)?; + writeln!(&mut mj)?; let m_toml_str = toml::to_string_pretty(&merged_out)?; let mut mt = File::create(&m_toml) .with_context(|| format!("Failed to create output file: {}", m_toml.display()))?; - mt.write_all(m_toml_str.as_bytes())?; if !m_toml_str.ends_with('\n') { writeln!(&mut mt)?; } + mt.write_all(m_toml_str.as_bytes())?; + if !m_toml_str.ends_with('\n') { + writeln!(&mut mt)?; + } let m_srt_str = render_srt(&merged_out.items); let mut ms = File::create(&m_srt) .with_context(|| format!("Failed to create output file: {}", m_srt.display()))?; ms.write_all(m_srt_str.as_bytes())?; } else if args.merge { - vlog!(1, "Mode: merge; output_base={:?}", output_path); + vlog!(1, "Mode: merge; output_base={:?}", output_path); // MERGED MODE (previous default) let mut entries: Vec = Vec::new(); for input_path in &inputs { @@ -634,14 +460,22 @@ fn main() -> Result<()> { let default_speaker = sanitize_speaker_name( path.file_stem() .and_then(|s| s.to_str()) - .unwrap_or("speaker") + .unwrap_or("speaker"), ); - let speaker = prompt_speaker_name_for_path(path, &default_speaker, args.set_speaker_names); + let speaker = + prompt_speaker_name_for_path(path, &default_speaker, args.set_speaker_names); let mut buf = String::new(); if is_audio_file(path) { - let items = sel.backend.transcribe(path, &speaker, lang_hint.as_deref(), args.gpu_layers)?; - for e in items { entries.push(e); } + let items = sel.backend.transcribe( + path, + &speaker, + lang_hint.as_deref(), + args.gpu_layers, + )?; + for e in items { + entries.push(e); + } continue; } else if is_json_file(path) { File::open(path) @@ -649,7 +483,10 @@ fn main() -> Result<()> { .read_to_string(&mut buf) .with_context(|| format!("Failed to read: {}", input_path))?; } else { - return Err(anyhow!(format!("Unsupported input type (expected .json or audio 
media): {}", input_path))); + return Err(anyhow!(format!( + "Unsupported input type (expected .json or audio media): {}", + input_path + ))); } let root: InputRoot = serde_json::from_str(&buf) @@ -676,7 +513,9 @@ fn main() -> Result<()> { .partial_cmp(&b.end) .unwrap_or(std::cmp::Ordering::Equal) }); - for (i, e) in entries.iter_mut().enumerate() { e.id = i as u64; } + for (i, e) in entries.iter_mut().enumerate() { + e.id = i as u64; + } let out = OutputRoot { items: entries }; if let Some(path) = output_path { @@ -685,11 +524,17 @@ fn main() -> Result<()> { if let Some(parent) = parent_opt { if !parent.as_os_str().is_empty() { create_dir_all(parent).with_context(|| { - format!("Failed to create parent directory for output: {}", parent.display()) + format!( + "Failed to create parent directory for output: {}", + parent.display() + ) })?; } } - let stem = base_path.file_stem().and_then(|s| s.to_str()).unwrap_or("output"); + let stem = base_path + .file_stem() + .and_then(|s| s.to_str()) + .unwrap_or("output"); let date = date_prefix(); let base_name = format!("{}_{}", date, stem); let dir = parent_opt.unwrap_or(Path::new("")); @@ -697,14 +542,20 @@ fn main() -> Result<()> { let toml_path = dir.join(format!("{}.toml", &base_name)); let srt_path = dir.join(format!("{}.srt", &base_name)); - let mut json_file = File::create(&json_path) - .with_context(|| format!("Failed to create output file: {}", json_path.display()))?; - serde_json::to_writer_pretty(&mut json_file, &out)?; writeln!(&mut json_file)?; + let mut json_file = File::create(&json_path).with_context(|| { + format!("Failed to create output file: {}", json_path.display()) + })?; + serde_json::to_writer_pretty(&mut json_file, &out)?; + writeln!(&mut json_file)?; let toml_str = toml::to_string_pretty(&out)?; - let mut toml_file = File::create(&toml_path) - .with_context(|| format!("Failed to create output file: {}", toml_path.display()))?; - toml_file.write_all(toml_str.as_bytes())?; if !toml_str.ends_with('\n') 
{ writeln!(&mut toml_file)?; } + let mut toml_file = File::create(&toml_path).with_context(|| { + format!("Failed to create output file: {}", toml_path.display()) + })?; + toml_file.write_all(toml_str.as_bytes())?; + if !toml_str.ends_with('\n') { + writeln!(&mut toml_file)?; + } let srt_str = render_srt(&out.items); let mut srt_file = File::create(&srt_path) @@ -713,35 +564,48 @@ fn main() -> Result<()> { } else { let stdout = io::stdout(); let mut handle = stdout.lock(); - serde_json::to_writer_pretty(&mut handle, &out)?; writeln!(&mut handle)?; + serde_json::to_writer_pretty(&mut handle, &out)?; + writeln!(&mut handle)?; } } else { vlog!(1, "Mode: separate; output_dir={:?}", output_path); // SEPARATE MODE (default now) // If writing to stdout with multiple inputs, not supported if output_path.is_none() && inputs.len() > 1 { - return Err(anyhow!("Multiple inputs without --merge require -o OUTPUT_DIR to write separate files")); + return Err(anyhow!( + "Multiple inputs without --merge require -o OUTPUT_DIR to write separate files" + )); } // If output_path is provided, treat it as a directory. Create it. 
let out_dir: Option = output_path.as_ref().map(|p| PathBuf::from(p)); if let Some(dir) = &out_dir { if !dir.as_os_str().is_empty() { - create_dir_all(dir).with_context(|| format!("Failed to create output directory: {}", dir.display()))?; + create_dir_all(dir).with_context(|| { + format!("Failed to create output directory: {}", dir.display()) + })?; } } for input_path in &inputs { let path = Path::new(input_path); let default_speaker = sanitize_speaker_name( - path.file_stem().and_then(|s| s.to_str()).unwrap_or("speaker") + path.file_stem() + .and_then(|s| s.to_str()) + .unwrap_or("speaker"), ); - let speaker = prompt_speaker_name_for_path(path, &default_speaker, args.set_speaker_names); + let speaker = + prompt_speaker_name_for_path(path, &default_speaker, args.set_speaker_names); // Collect entries per file let mut entries: Vec = Vec::new(); if is_audio_file(path) { - let items = sel.backend.transcribe(path, &speaker, lang_hint.as_deref(), args.gpu_layers)?; + let items = sel.backend.transcribe( + path, + &speaker, + lang_hint.as_deref(), + args.gpu_layers, + )?; entries.extend(items); } else if is_json_file(path) { let mut buf = String::new(); @@ -749,50 +613,78 @@ fn main() -> Result<()> { .with_context(|| format!("Failed to open: {}", input_path))? 
.read_to_string(&mut buf) .with_context(|| format!("Failed to read: {}", input_path))?; - let root: InputRoot = serde_json::from_str(&buf) - .with_context(|| format!("Invalid JSON transcript parsed from {}", input_path))?; + let root: InputRoot = serde_json::from_str(&buf).with_context(|| { + format!("Invalid JSON transcript parsed from {}", input_path) + })?; for seg in root.segments { - entries.push(OutputEntry { id: 0, speaker: speaker.clone(), start: seg.start, end: seg.end, text: seg.text }); + entries.push(OutputEntry { + id: 0, + speaker: speaker.clone(), + start: seg.start, + end: seg.end, + text: seg.text, + }); } } else { - return Err(anyhow!(format!("Unsupported input type (expected .json or audio media): {}", input_path))); + return Err(anyhow!(format!( + "Unsupported input type (expected .json or audio media): {}", + input_path + ))); } // Sort and reassign ids per file entries.sort_by(|a, b| { - match a.start.partial_cmp(&b.start) { Some(std::cmp::Ordering::Equal) | None => {} Some(o) => return o } - a.end.partial_cmp(&b.end).unwrap_or(std::cmp::Ordering::Equal) + match a.start.partial_cmp(&b.start) { + Some(std::cmp::Ordering::Equal) | None => {} + Some(o) => return o, + } + a.end + .partial_cmp(&b.end) + .unwrap_or(std::cmp::Ordering::Equal) }); - for (i, e) in entries.iter_mut().enumerate() { e.id = i as u64; } + for (i, e) in entries.iter_mut().enumerate() { + e.id = i as u64; + } let out = OutputRoot { items: entries }; if let Some(dir) = &out_dir { // Build file names using input stem - let stem = path.file_stem().and_then(|s| s.to_str()).unwrap_or("output"); + let stem = path + .file_stem() + .and_then(|s| s.to_str()) + .unwrap_or("output"); let date = date_prefix(); - let base_name = format!("{}_{}", date, stem); + let base_name = format!("{date}_{stem}"); let json_path = dir.join(format!("{}.json", &base_name)); let toml_path = dir.join(format!("{}.toml", &base_name)); let srt_path = dir.join(format!("{}.srt", &base_name)); - let mut 
json_file = File::create(&json_path) - .with_context(|| format!("Failed to create output file: {}", json_path.display()))?; - serde_json::to_writer_pretty(&mut json_file, &out)?; writeln!(&mut json_file)?; + let mut json_file = File::create(&json_path).with_context(|| { + format!("Failed to create output file: {}", json_path.display()) + })?; + serde_json::to_writer_pretty(&mut json_file, &out)?; + writeln!(&mut json_file)?; let toml_str = toml::to_string_pretty(&out)?; - let mut toml_file = File::create(&toml_path) - .with_context(|| format!("Failed to create output file: {}", toml_path.display()))?; - toml_file.write_all(toml_str.as_bytes())?; if !toml_str.ends_with('\n') { writeln!(&mut toml_file)?; } + let mut toml_file = File::create(&toml_path).with_context(|| { + format!("Failed to create output file: {}", toml_path.display()) + })?; + toml_file.write_all(toml_str.as_bytes())?; + if !toml_str.ends_with('\n') { + writeln!(&mut toml_file)?; + } let srt_str = render_srt(&out.items); - let mut srt_file = File::create(&srt_path) - .with_context(|| format!("Failed to create output file: {}", srt_path.display()))?; + let mut srt_file = File::create(&srt_path).with_context(|| { + format!("Failed to create output file: {}", srt_path.display()) + })?; srt_file.write_all(srt_str.as_bytes())?; } else { // stdout (only single input reaches here) let stdout = io::stdout(); let mut handle = stdout.lock(); - serde_json::to_writer_pretty(&mut handle, &out)?; writeln!(&mut handle)?; + serde_json::to_writer_pretty(&mut handle, &out)?; + writeln!(&mut handle)?; } } } @@ -800,16 +692,14 @@ fn main() -> Result<()> { Ok(()) } - - #[cfg(test)] mod tests { use super::*; - use std::fs; - use std::io::Write; - use std::env as std_env; use clap::CommandFactory; - use super::backend::*; + use polyscribe::backend::*; + use polyscribe::format_srt_time; + use std::env as std_env; + use std::fs; #[test] fn test_cli_name_polyscribe() { @@ -827,7 +717,6 @@ mod tests { } 
assert!(!last.exists(), ".last_model should be removed on drop"); } - use super::*; use std::path::Path; #[test] @@ -844,8 +733,20 @@ mod tests { #[test] fn test_render_srt_with_and_without_speaker() { let items = vec![ - OutputEntry { id: 0, speaker: "Alice".to_string(), start: 0.0, end: 1.0, text: "Hello".to_string() }, - OutputEntry { id: 1, speaker: String::new(), start: 1.0, end: 2.0, text: "World".to_string() }, + OutputEntry { + id: 0, + speaker: "Alice".to_string(), + start: 0.0, + end: 1.0, + text: "Hello".to_string(), + }, + OutputEntry { + id: 1, + speaker: String::new(), + start: 1.0, + end: 2.0, + text: "World".to_string(), + }, ]; let srt = render_srt(&items); let expected = "1\n00:00:00,000 --> 00:00:01,000\nAlice: Hello\n\n2\n00:00:01,000 --> 00:00:02,000\nWorld\n\n"; @@ -890,7 +791,12 @@ mod tests { let d = date_prefix(); assert_eq!(d.len(), 10); let bytes = d.as_bytes(); - assert!(bytes[0].is_ascii_digit() && bytes[1].is_ascii_digit() && bytes[2].is_ascii_digit() && bytes[3].is_ascii_digit()); + assert!( + bytes[0].is_ascii_digit() + && bytes[1].is_ascii_digit() + && bytes[2].is_ascii_digit() + && bytes[3].is_ascii_digit() + ); assert_eq!(bytes[4], b'-'); assert!(bytes[5].is_ascii_digit() && bytes[6].is_ascii_digit()); assert_eq!(bytes[7], b'-'); @@ -901,35 +807,54 @@ mod tests { #[cfg(debug_assertions)] fn test_models_dir_path_default_debug_and_env_override() { // clear override - unsafe { std_env::remove_var("POLYSCRIBE_MODELS_DIR"); } + unsafe { + std_env::remove_var("POLYSCRIBE_MODELS_DIR"); + } assert_eq!(models_dir_path(), PathBuf::from("models")); // override let tmp = tempfile::tempdir().unwrap(); - unsafe { std_env::set_var("POLYSCRIBE_MODELS_DIR", tmp.path()); } + unsafe { + std_env::set_var("POLYSCRIBE_MODELS_DIR", tmp.path()); + } assert_eq!(models_dir_path(), tmp.path().to_path_buf()); // cleanup - unsafe { std_env::remove_var("POLYSCRIBE_MODELS_DIR"); } + unsafe { + std_env::remove_var("POLYSCRIBE_MODELS_DIR"); + } } #[test] 
#[cfg(not(debug_assertions))] fn test_models_dir_path_default_release() { // Ensure override is cleared - unsafe { std_env::remove_var("POLYSCRIBE_MODELS_DIR"); } + unsafe { + std_env::remove_var("POLYSCRIBE_MODELS_DIR"); + } // Prefer XDG_DATA_HOME when set let tmp_xdg = tempfile::tempdir().unwrap(); unsafe { std_env::set_var("XDG_DATA_HOME", tmp_xdg.path()); std_env::remove_var("HOME"); } - assert_eq!(models_dir_path(), tmp_xdg.path().join("polyscribe").join("models")); + assert_eq!( + models_dir_path(), + tmp_xdg.path().join("polyscribe").join("models") + ); // Else fall back to HOME/.local/share let tmp_home = tempfile::tempdir().unwrap(); unsafe { std_env::remove_var("XDG_DATA_HOME"); std_env::set_var("HOME", tmp_home.path()); } - assert_eq!(models_dir_path(), tmp_home.path().join(".local").join("share").join("polyscribe").join("models")); + assert_eq!( + models_dir_path(), + tmp_home + .path() + .join(".local") + .join("share") + .join("polyscribe") + .join("models") + ); // Cleanup unsafe { std_env::remove_var("XDG_DATA_HOME"); @@ -959,15 +884,22 @@ mod tests { let sel = select_backend(BackendKind::Auto, false).unwrap(); assert_eq!(sel.chosen, BackendKind::Cpu); // Vulkan only - unsafe { std_env::set_var("POLYSCRIBE_TEST_FORCE_VULKAN", "1"); } + unsafe { + std_env::set_var("POLYSCRIBE_TEST_FORCE_VULKAN", "1"); + } let sel = select_backend(BackendKind::Auto, false).unwrap(); assert_eq!(sel.chosen, BackendKind::Vulkan); // HIP preferred over Vulkan - unsafe { std_env::set_var("POLYSCRIBE_TEST_FORCE_HIP", "1"); std_env::remove_var("POLYSCRIBE_TEST_FORCE_VULKAN"); } + unsafe { + std_env::set_var("POLYSCRIBE_TEST_FORCE_HIP", "1"); + std_env::remove_var("POLYSCRIBE_TEST_FORCE_VULKAN"); + } let sel = select_backend(BackendKind::Auto, false).unwrap(); assert_eq!(sel.chosen, BackendKind::Hip); // CUDA preferred over HIP - unsafe { std_env::set_var("POLYSCRIBE_TEST_FORCE_CUDA", "1"); } + unsafe { + std_env::set_var("POLYSCRIBE_TEST_FORCE_CUDA", "1"); + } let sel = 
select_backend(BackendKind::Auto, false).unwrap(); assert_eq!(sel.chosen, BackendKind::Cuda); // Cleanup @@ -990,15 +922,25 @@ mod tests { assert!(select_backend(BackendKind::Hip, false).is_err()); assert!(select_backend(BackendKind::Vulkan, false).is_err()); // Turn on CUDA only - unsafe { std_env::set_var("POLYSCRIBE_TEST_FORCE_CUDA", "1"); } + unsafe { + std_env::set_var("POLYSCRIBE_TEST_FORCE_CUDA", "1"); + } assert!(select_backend(BackendKind::Cuda, false).is_ok()); // Turn on HIP only - unsafe { std_env::remove_var("POLYSCRIBE_TEST_FORCE_CUDA"); std_env::set_var("POLYSCRIBE_TEST_FORCE_HIP", "1"); } + unsafe { + std_env::remove_var("POLYSCRIBE_TEST_FORCE_CUDA"); + std_env::set_var("POLYSCRIBE_TEST_FORCE_HIP", "1"); + } assert!(select_backend(BackendKind::Hip, false).is_ok()); // Turn on Vulkan only - unsafe { std_env::remove_var("POLYSCRIBE_TEST_FORCE_HIP"); std_env::set_var("POLYSCRIBE_TEST_FORCE_VULKAN", "1"); } + unsafe { + std_env::remove_var("POLYSCRIBE_TEST_FORCE_HIP"); + std_env::set_var("POLYSCRIBE_TEST_FORCE_VULKAN", "1"); + } assert!(select_backend(BackendKind::Vulkan, false).is_ok()); // Cleanup - unsafe { std_env::remove_var("POLYSCRIBE_TEST_FORCE_VULKAN"); } + unsafe { + std_env::remove_var("POLYSCRIBE_TEST_FORCE_VULKAN"); + } } } diff --git a/src/models.rs b/src/models.rs index a303c68..ddb54d8 100644 --- a/src/models.rs +++ b/src/models.rs @@ -1,14 +1,14 @@ +use std::collections::BTreeMap; +use std::env; use std::fs::{File, create_dir_all}; use std::io::{self, Read, Write}; use std::path::Path; -use std::collections::BTreeMap; use std::time::Duration; -use std::env; -use anyhow::{anyhow, Context, Result}; -use serde::Deserialize; +use anyhow::{Context, Result, anyhow}; use reqwest::blocking::Client; use reqwest::redirect::Policy; +use serde::Deserialize; use sha2::{Digest, Sha256}; // Print to stderr only when not in quiet mode @@ -80,22 +80,33 @@ fn human_size(bytes: u64) -> String { const MB: f64 = KB * 1024.0; const GB: f64 = MB * 1024.0; let 
b = bytes as f64; - if b >= GB { format!("{:.2} GiB", b / GB) } - else if b >= MB { format!("{:.2} MiB", b / MB) } - else if b >= KB { format!("{:.2} KiB", b / KB) } - else { format!("{} B", bytes) } + if b >= GB { + format!("{:.2} GiB", b / GB) + } else if b >= MB { + format!("{:.2} MiB", b / MB) + } else if b >= KB { + format!("{:.2} KiB", b / KB) + } else { + format!("{} B", bytes) + } } fn to_hex_lower(bytes: &[u8]) -> String { let mut s = String::with_capacity(bytes.len() * 2); - for b in bytes { s.push_str(&format!("{:02x}", b)); } + for b in bytes { + s.push_str(&format!("{:02x}", b)); + } s } fn expected_sha_from_sibling(s: &HFSibling) -> Option { - if let Some(h) = &s.sha256 { return Some(h.to_lowercase()); } + if let Some(h) = &s.sha256 { + return Some(h.to_lowercase()); + } if let Some(lfs) = &s.lfs { - if let Some(h) = &lfs.sha256 { return Some(h.to_lowercase()); } + if let Some(h) = &lfs.sha256 { + return Some(h.to_lowercase()); + } if let Some(oid) = &lfs.oid { // e.g. "sha256:abcdef..." 
if let Some(rest) = oid.strip_prefix("sha256:") { @@ -107,15 +118,23 @@ fn expected_sha_from_sibling(s: &HFSibling) -> Option { } fn size_from_sibling(s: &HFSibling) -> Option { - if let Some(sz) = s.size { return Some(sz); } - if let Some(lfs) = &s.lfs { return lfs.size; } + if let Some(sz) = s.size { + return Some(sz); + } + if let Some(lfs) = &s.lfs { + return lfs.size; + } None } fn expected_sha_from_tree(s: &HFTreeItem) -> Option { - if let Some(h) = &s.sha256 { return Some(h.to_lowercase()); } + if let Some(h) = &s.sha256 { + return Some(h.to_lowercase()); + } if let Some(lfs) = &s.lfs { - if let Some(h) = &lfs.sha256 { return Some(h.to_lowercase()); } + if let Some(h) = &lfs.sha256 { + return Some(h.to_lowercase()); + } if let Some(oid) = &lfs.oid { if let Some(rest) = oid.strip_prefix("sha256:") { return Some(rest.to_lowercase().to_string()); @@ -126,8 +145,12 @@ fn expected_sha_from_tree(s: &HFTreeItem) -> Option { } fn size_from_tree(s: &HFTreeItem) -> Option { - if let Some(sz) = s.size { return Some(sz); } - if let Some(lfs) = &s.lfs { return lfs.size; } + if let Some(sz) = s.size { + return Some(sz); + } + if let Some(lfs) = &s.lfs { + return lfs.size; + } None } @@ -136,12 +159,20 @@ fn fill_meta_via_head(repo: &str, name: &str) -> (Option, Option) { .user_agent("PolyScribe/0.1 (+https://github.com/)") .redirect(Policy::none()) .timeout(Duration::from_secs(30)) - .build() { + .build() + { Ok(c) => c, Err(_) => return (None, None), }; - let url = format!("https://huggingface.co/{}/resolve/main/ggml-{}.bin", repo, name); - let resp = match head_client.head(url).send().and_then(|r| r.error_for_status()) { + let url = format!( + "https://huggingface.co/{}/resolve/main/ggml-{}.bin", + repo, name + ); + let resp = match head_client + .head(url) + .send() + .and_then(|r| r.error_for_status()) + { Ok(r) => r, Err(_) => return (None, None), }; @@ -179,21 +210,40 @@ fn fill_meta_via_head(repo: &str, name: &str) -> (Option, Option) { fn 
hf_fetch_repo_models(client: &Client, repo: &'static str) -> Result> { qlog!("Fetching online data: listing models from {}...", repo); // Prefer the tree endpoint for reliable size/hash metadata, then fall back to model metadata - let tree_url = format!("https://huggingface.co/api/models/{}/tree/main?recursive=1", repo); + let tree_url = format!( + "https://huggingface.co/api/models/{}/tree/main?recursive=1", + repo + ); let mut out: Vec = Vec::new(); - match client.get(tree_url).send().and_then(|r| r.error_for_status()) { + match client + .get(tree_url) + .send() + .and_then(|r| r.error_for_status()) + { Ok(resp) => { match resp.json::>() { Ok(items) => { for it in items { let path = it.path.clone(); - if !(path.starts_with("ggml-") && path.ends_with(".bin")) { continue; } - let model_name = path.trim_start_matches("ggml-").trim_end_matches(".bin").to_string(); + if !(path.starts_with("ggml-") && path.ends_with(".bin")) { + continue; + } + let model_name = path + .trim_start_matches("ggml-") + .trim_end_matches(".bin") + .to_string(); let (base, subtype) = split_model_name(&model_name); let size = size_from_tree(&it).unwrap_or(0); let sha256 = expected_sha_from_tree(&it); - out.push(ModelEntry { name: model_name, base, subtype, size, sha256, repo: repo.to_string() }); + out.push(ModelEntry { + name: model_name, + base, + subtype, + size, + sha256, + repo: repo.to_string(), + }); } } Err(_) => { /* fall back below */ } @@ -210,30 +260,49 @@ fn hf_fetch_repo_models(client: &Client, repo: &'static str) -> Result Result Result> { let mut v1 = hf_fetch_repo_models(client, "ggerganov/whisper.cpp")?; // main repo must succeed // Optional tinydiarize repo; ignore errors but log to stderr - let mut v2: Vec = match hf_fetch_repo_models(client, "akashmjn/tinydiarize-whisper.cpp") { - Ok(v) => v, - Err(e) => { - qlog!("Warning: failed to fetch optional repo akashmjn/tinydiarize-whisper.cpp: {:#}", e); - Vec::new() - } - }; + let mut v2: Vec = + match 
hf_fetch_repo_models(client, "akashmjn/tinydiarize-whisper.cpp") { + Ok(v) => v, + Err(e) => { + qlog!( + "Warning: failed to fetch optional repo akashmjn/tinydiarize-whisper.cpp: {:#}", + e + ); + Vec::new() + } + }; v1.append(&mut v2); // Deduplicate by name preferring ggerganov repo if duplicates let mut map: BTreeMap = BTreeMap::new(); for m in v1 { - map.entry(m.name.clone()).and_modify(|existing| { - if existing.repo != "ggerganov/whisper.cpp" && m.repo == "ggerganov/whisper.cpp" { - *existing = m.clone(); - } - }).or_insert(m); + map.entry(m.name.clone()) + .and_modify(|existing| { + if existing.repo != "ggerganov/whisper.cpp" && m.repo == "ggerganov/whisper.cpp" { + *existing = m.clone(); + } + }) + .or_insert(m); } let mut list: Vec = map.into_values().collect(); - list.sort_by(|a, b| a.base.cmp(&b.base).then(a.subtype.cmp(&b.subtype)).then(a.name.cmp(&b.name))); + list.sort_by(|a, b| { + a.base + .cmp(&b.base) + .then(a.subtype.cmp(&b.subtype)) + .then(a.name.cmp(&b.name)) + }); Ok(list) } - fn format_model_list(models: &[ModelEntry]) -> String { let mut out = String::new(); out.push_str("Available ggml Whisper models:\n"); @@ -305,7 +389,9 @@ fn format_model_list(models: &[ModelEntry]) -> String { )); idx += 1; } - out.push_str("\nEnter selection by indices (e.g., 1 3 5-7), or 'all', '*' for all, 'q' to cancel.\n"); + out.push_str( + "\nEnter selection by indices (e.g., 1 3 5-7), or 'all', '*' for all, 'q' to cancel.\n", + ); out } @@ -335,21 +421,33 @@ fn prompt_select_models_two_stage(models: &[ModelEntry]) -> Result() { - if i >= 1 && i <= bases.len() { Some(bases[i - 1].clone()) } else { None } + if i >= 1 && i <= bases.len() { + Some(bases[i - 1].clone()) + } else { + None + } } else if !s.is_empty() { // accept exact name match (case-insensitive) bases.iter().find(|b| b.eq_ignore_ascii_case(s)).cloned() - } else { None }; + } else { + None + }; if let Some(base) = chosen_base { // 2) Choose sub-type(s) within that base - let filtered: Vec = 
models.iter().filter(|m| m.base == base).cloned().collect(); + let filtered: Vec = + models.iter().filter(|m| m.base == base).cloned().collect(); if filtered.is_empty() { eprintln!("No models found for base '{}'.", base); continue; @@ -370,22 +468,32 @@ fn prompt_select_models_two_stage(models: &[ModelEntry]) -> Result = Vec::new(); if s2 == "all" || s2 == "*" { selected = (1..idx).collect(); } else if !s2.is_empty() { for part in s2.split(|c| c == ',' || c == ' ' || c == ';') { let part = part.trim(); - if part.is_empty() { continue; } + if part.is_empty() { + continue; + } if let Some((a, b)) = part.split_once('-') { if let (Ok(ia), Ok(ib)) = (a.parse::(), b.parse::()) { - if ia >= 1 && ib < idx && ia <= ib { selected.extend(ia..=ib); } + if ia >= 1 && ib < idx && ia <= ib { + selected.extend(ia..=ib); + } } } else if let Ok(i) = part.parse::() { - if i >= 1 && i < idx { selected.push(i); } + if i >= 1 && i < idx { + selected.push(i); + } } } } @@ -395,12 +503,17 @@ fn prompt_select_models_two_stage(models: &[ModelEntry]) -> Result = selected.into_iter().map(|i| filtered[index_map[i - 1]].clone()).collect(); + let chosen: Vec = selected + .into_iter() + .map(|i| filtered[index_map[i - 1]].clone()) + .collect(); return Ok(chosen); } } else { - eprintln!("Invalid base selection. Please enter a number from 1-{} or a base name.", bases.len()); - continue; + eprintln!( + "Invalid base selection. 
Please enter a number from 1-{} or a base name.", + bases.len() + ); } } } @@ -413,52 +526,30 @@ fn compute_file_sha256_hex(path: &Path) -> Result { let mut buf = [0u8; 1024 * 128]; loop { let n = reader.read(&mut buf).context("Read error during hashing")?; - if n == 0 { break; } + if n == 0 { + break; + } hasher.update(&buf[..n]); } Ok(to_hex_lower(&hasher.finalize())) } -fn models_dir_path() -> std::path::PathBuf { - // Highest priority: explicit override - if let Ok(p) = env::var("POLYSCRIBE_MODELS_DIR") { - let pb = std::path::PathBuf::from(p); - if !pb.as_os_str().is_empty() { return pb; } - } - // In debug builds, keep local ./models for convenience - if cfg!(debug_assertions) { - return std::path::PathBuf::from("models"); - } - // In release builds, choose a user-writable data directory - if let Ok(xdg) = env::var("XDG_DATA_HOME") { - if !xdg.is_empty() { - return std::path::PathBuf::from(xdg).join("polyscribe").join("models"); - } - } - if let Ok(home) = env::var("HOME") { - if !home.is_empty() { - return std::path::PathBuf::from(home) - .join(".local") - .join("share") - .join("polyscribe") - .join("models"); - } - } - // Last resort fallback - std::path::PathBuf::from("models") -} - +/// Interactively list and download Whisper models from Hugging Face into the models directory. 
pub fn run_interactive_model_downloader() -> Result<()> { - let models_dir_buf = models_dir_path(); + let models_dir_buf = crate::models_dir_path(); let models_dir = models_dir_buf.as_path(); - if !models_dir.exists() { create_dir_all(models_dir).context("Failed to create models directory")?; } + if !models_dir.exists() { + create_dir_all(models_dir).context("Failed to create models directory")?; + } let client = Client::builder() .user_agent("PolyScribe/0.1 (+https://github.com/)") .timeout(std::time::Duration::from_secs(600)) .build() .context("Failed to build HTTP client")?; - qlog!("Fetching online data: contacting Hugging Face to retrieve available models (this may take a moment)..."); + qlog!( + "Fetching online data: contacting Hugging Face to retrieve available models (this may take a moment)..." + ); let models = fetch_all_models(&client)?; if models.is_empty() { qlog!("No models found on Hugging Face listing. Please try again later."); @@ -470,12 +561,15 @@ pub fn run_interactive_model_downloader() -> Result<()> { return Ok(()); } for m in selected { - if let Err(e) = download_one_model(&client, models_dir, &m) { qlog!("Error: {:#}", e); } + if let Err(e) = download_one_model(&client, models_dir, &m) { + qlog!("Error: {:#}", e); + } } Ok(()) } -pub fn download_one_model(client: &Client, models_dir: &Path, entry: &ModelEntry) -> Result<()> { +/// Download a single model entry into the given models directory, verifying SHA-256 when available. +fn download_one_model(client: &Client, models_dir: &Path, entry: &ModelEntry) -> Result<()> { let final_path = models_dir.join(format!("ggml-{}.bin", entry.name)); // If the model already exists, verify against online metadata @@ -497,9 +591,10 @@ pub fn download_one_model(client: &Client, models_dir: &Path, entry: &ModelEntry } Err(e) => { qlog!( - "Warning: failed to hash existing {}: {}. Will re-download to ensure correctness.", - final_path.display(), e - ); + "Warning: failed to hash existing {}: {}. 
Will re-download to ensure correctness.", + final_path.display(), + e + ); } } } else if entry.size > 0 { @@ -508,20 +603,24 @@ pub fn download_one_model(client: &Client, models_dir: &Path, entry: &ModelEntry if md.len() == entry.size { qlog!( "Model {} appears up-to-date by size ({}).", - final_path.display(), entry.size + final_path.display(), + entry.size ); return Ok(()); } else { qlog!( "Local model {} size ({}) differs from online ({}). Updating...", - final_path.display(), md.len(), entry.size + final_path.display(), + md.len(), + entry.size ); } } Err(e) => { qlog!( "Warning: failed to stat existing {}: {}. Will re-download to ensure correctness.", - final_path.display(), e + final_path.display(), + e ); } } @@ -540,9 +639,16 @@ pub fn download_one_model(client: &Client, models_dir: &Path, entry: &ModelEntry if src_path.exists() { qlog!("Copying {} from {}...", entry.name, src_path.display()); let tmp_path = models_dir.join(format!("ggml-{}.bin.part", entry.name)); - if tmp_path.exists() { let _ = std::fs::remove_file(&tmp_path); } - std::fs::copy(&src_path, &tmp_path) - .with_context(|| format!("Failed to copy from {} to {}", src_path.display(), tmp_path.display()))?; + if tmp_path.exists() { + let _ = std::fs::remove_file(&tmp_path); + } + std::fs::copy(&src_path, &tmp_path).with_context(|| { + format!( + "Failed to copy from {} to {}", + src_path.display(), + tmp_path.display() + ) + })?; // Verify hash if available if let Some(expected) = &entry.sha256 { let got = compute_file_sha256_hex(&tmp_path)?; @@ -550,12 +656,16 @@ pub fn download_one_model(client: &Client, models_dir: &Path, entry: &ModelEntry let _ = std::fs::remove_file(&tmp_path); return Err(anyhow!( "SHA-256 mismatch for {} (copied): expected {}, got {}", - entry.name, expected, got + entry.name, + expected, + got )); } } // Replace existing file safely - if final_path.exists() { let _ = std::fs::remove_file(&final_path); } + if final_path.exists() { + let _ = 
std::fs::remove_file(&final_path); + } std::fs::rename(&tmp_path, &final_path) .with_context(|| format!("Failed to move into place: {}", final_path.display()))?; qlog!("Saved: {}", final_path.display()); @@ -563,8 +673,16 @@ pub fn download_one_model(client: &Client, models_dir: &Path, entry: &ModelEntry } } - let url = format!("https://huggingface.co/{}/resolve/main/ggml-{}.bin", entry.repo, entry.name); - qlog!("Downloading {} ({} | {})...", entry.name, human_size(entry.size), url); + let url = format!( + "https://huggingface.co/{}/resolve/main/ggml-{}.bin", + entry.repo, entry.name + ); + qlog!( + "Downloading {} ({} | {})...", + entry.name, + human_size(entry.size), + url + ); let mut resp = client .get(url) .send() @@ -577,14 +695,16 @@ pub fn download_one_model(client: &Client, models_dir: &Path, entry: &ModelEntry } let mut file = std::io::BufWriter::new( File::create(&tmp_path) - .with_context(|| format!("Failed to create {}", tmp_path.display()))? + .with_context(|| format!("Failed to create {}", tmp_path.display()))?, ); let mut hasher = Sha256::new(); let mut buf = [0u8; 1024 * 128]; loop { let n = resp.read(&mut buf).context("Network read error")?; - if n == 0 { break; } + if n == 0 { + break; + } hasher.update(&buf[..n]); file.write_all(&buf[..n]).context("Write error")?; } @@ -596,11 +716,16 @@ pub fn download_one_model(client: &Client, models_dir: &Path, entry: &ModelEntry let _ = std::fs::remove_file(&tmp_path); return Err(anyhow!( "SHA-256 mismatch for {}: expected {}, got {}", - entry.name, expected, got + entry.name, + expected, + got )); } } else { - qlog!("Warning: no SHA-256 available for {}. Skipping verification.", entry.name); + qlog!( + "Warning: no SHA-256 available for {}. 
Skipping verification.", + entry.name + ); } // Replace existing file safely if final_path.exists() { @@ -612,8 +737,9 @@ pub fn download_one_model(client: &Client, models_dir: &Path, entry: &ModelEntry Ok(()) } +/// Update locally stored models by re-downloading when size or hash does not match online metadata. pub fn update_local_models() -> Result<()> { - let models_dir_buf = models_dir_path(); + let models_dir_buf = crate::models_dir_path(); let models_dir = models_dir_buf.as_path(); if !models_dir.exists() { create_dir_all(models_dir).context("Failed to create models directory")?; @@ -627,13 +753,14 @@ pub fn update_local_models() -> Result<()> { .context("Failed to build HTTP client")?; // Obtain manifest: env override or online fetch - let models: Vec = if let Ok(manifest_path) = env::var("POLYSCRIBE_MODELS_MANIFEST") { + let models: Vec = if let Ok(manifest_path) = env::var("POLYSCRIBE_MODELS_MANIFEST") + { let data = std::fs::read_to_string(&manifest_path) .with_context(|| format!("Failed to read manifest at {}", manifest_path))?; let mut list: Vec = serde_json::from_str(&data) .with_context(|| format!("Invalid JSON manifest: {}", manifest_path))?; // sort for stability - list.sort_by(|a,b| a.name.cmp(&b.name)); + list.sort_by(|a, b| a.name.cmp(&b.name)); list } else { fetch_all_models(&client)? 
@@ -641,7 +768,9 @@ pub fn update_local_models() -> Result<()> { // Map name -> entry for fast lookup let mut map: BTreeMap = BTreeMap::new(); - for m in models { map.insert(m.name.clone(), m); } + for m in models { + map.insert(m.name.clone(), m); + } // Scan local ggml-*.bin models let rd = std::fs::read_dir(models_dir) @@ -649,10 +778,20 @@ pub fn update_local_models() -> Result<()> { for entry in rd { let entry = entry?; let path = entry.path(); - if !path.is_file() { continue; } - let fname = match path.file_name().and_then(|s| s.to_str()) { Some(s) => s.to_string(), None => continue }; - if !fname.starts_with("ggml-") || !fname.ends_with(".bin") { continue; } - let model_name = fname.trim_start_matches("ggml-").trim_end_matches(".bin").to_string(); + if !path.is_file() { + continue; + } + let fname = match path.file_name().and_then(|s| s.to_str()) { + Some(s) => s.to_string(), + None => continue, + }; + if !fname.starts_with("ggml-") || !fname.ends_with(".bin") { + continue; + } + let model_name = fname + .trim_start_matches("ggml-") + .trim_end_matches(".bin") + .to_string(); if let Some(remote) = map.get(&model_name) { // If SHA256 available, verify and update if mismatch @@ -664,11 +803,11 @@ pub fn update_local_models() -> Result<()> { continue; } else { qlog!( - "{} hash differs (local {}.. != remote {}..). Updating...", - fname, - &local_hash[..std::cmp::min(8, local_hash.len())], - &expected[..std::cmp::min(8, expected.len())] - ); + "{} hash differs (local {}.. != remote {}..). Updating...", + fname, + &local_hash[..std::cmp::min(8, local_hash.len())], + &expected[..std::cmp::min(8, expected.len())] + ); } } Err(e) => { @@ -683,7 +822,12 @@ pub fn update_local_models() -> Result<()> { continue; } Ok(md) => { - qlog!("{} size {} differs from remote {}. Updating...", fname, md.len(), remote.size); + qlog!( + "{} size {} differs from remote {}. 
Updating...", + fname, + md.len(), + remote.size + ); download_one_model(&client, models_dir, remote)?; } Err(e) => { @@ -702,20 +846,43 @@ pub fn update_local_models() -> Result<()> { Ok(()) } - #[cfg(test)] mod tests { use super::*; - use tempfile::tempdir; use std::fs; - use std::io::Write; + use tempfile::tempdir; #[test] fn test_format_model_list_spacing_and_structure() { let models = vec![ - ModelEntry { name: "tiny.en-q5_1".to_string(), base: "tiny".to_string(), subtype: "en-q5_1".to_string(), size: 1024*1024, sha256: Some("0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdef".to_string()), repo: "ggerganov/whisper.cpp".to_string() }, - ModelEntry { name: "tiny-q5_1".to_string(), base: "tiny".to_string(), subtype: "q5_1".to_string(), size: 2048, sha256: None, repo: "ggerganov/whisper.cpp".to_string() }, - ModelEntry { name: "base.en-q5_1".to_string(), base: "base".to_string(), subtype: "en-q5_1".to_string(), size: 10, sha256: Some("aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa".to_string()), repo: "akashmjn/tinydiarize-whisper.cpp".to_string() }, + ModelEntry { + name: "tiny.en-q5_1".to_string(), + base: "tiny".to_string(), + subtype: "en-q5_1".to_string(), + size: 1024 * 1024, + sha256: Some( + "0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdef".to_string(), + ), + repo: "ggerganov/whisper.cpp".to_string(), + }, + ModelEntry { + name: "tiny-q5_1".to_string(), + base: "tiny".to_string(), + subtype: "q5_1".to_string(), + size: 2048, + sha256: None, + repo: "ggerganov/whisper.cpp".to_string(), + }, + ModelEntry { + name: "base.en-q5_1".to_string(), + base: "base".to_string(), + subtype: "en-q5_1".to_string(), + size: 10, + sha256: Some( + "aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa".to_string(), + ), + repo: "akashmjn/tinydiarize-whisper.cpp".to_string(), + }, ]; let s = format_model_list(&models); // Header present @@ -724,7 +891,10 @@ mod tests { assert!(s.contains("\ntiny:\n")); 
assert!(s.contains("\nbase:\n")); // No immediate double space before a bracket after parenthesis - assert!(!s.contains(") ["), "should not have double space immediately before bracket"); + assert!( + !s.contains(") ["), + "should not have double space immediately before bracket" + ); // Lines contain normalized spacing around pipes and no hash assert!(s.contains("[ggerganov/whisper.cpp | 1.00 MiB]")); assert!(s.contains("[ggerganov/whisper.cpp | 2.00 KiB]")); @@ -748,7 +918,9 @@ mod tests { hasher.update(data); let out = hasher.finalize(); let mut s = String::new(); - for b in out { s.push_str(&format!("{:02x}", b)); } + for b in out { + s.push_str(&format!("{:02x}", b)); + } s } @@ -786,7 +958,11 @@ mod tests { "repo": "ggerganov/whisper.cpp" } ]); - fs::write(&manifest_path, serde_json::to_string_pretty(&manifest).unwrap()).unwrap(); + fs::write( + &manifest_path, + serde_json::to_string_pretty(&manifest).unwrap(), + ) + .unwrap(); // Set env vars to force offline behavior and directories unsafe { @@ -807,34 +983,54 @@ mod tests { #[cfg(debug_assertions)] fn test_models_dir_path_default_debug_and_env_override_models_mod() { // clear override - unsafe { std::env::remove_var("POLYSCRIBE_MODELS_DIR"); } - assert_eq!(super::models_dir_path(), std::path::PathBuf::from("models")); + unsafe { + std::env::remove_var("POLYSCRIBE_MODELS_DIR"); + } + assert_eq!(crate::models_dir_path(), std::path::PathBuf::from("models")); // override let tmp = tempfile::tempdir().unwrap(); - unsafe { std::env::set_var("POLYSCRIBE_MODELS_DIR", tmp.path()); } - assert_eq!(super::models_dir_path(), tmp.path().to_path_buf()); + unsafe { + std::env::set_var("POLYSCRIBE_MODELS_DIR", tmp.path()); + } + assert_eq!(crate::models_dir_path(), tmp.path().to_path_buf()); // cleanup - unsafe { std::env::remove_var("POLYSCRIBE_MODELS_DIR"); } + unsafe { + std::env::remove_var("POLYSCRIBE_MODELS_DIR"); + } } #[test] #[cfg(not(debug_assertions))] fn test_models_dir_path_default_release_models_mod() { - 
unsafe { std::env::remove_var("POLYSCRIBE_MODELS_DIR"); } + unsafe { + std::env::remove_var("POLYSCRIBE_MODELS_DIR"); + } // With XDG_DATA_HOME set let tmp_xdg = tempfile::tempdir().unwrap(); unsafe { std::env::set_var("XDG_DATA_HOME", tmp_xdg.path()); std::env::remove_var("HOME"); } - assert_eq!(super::models_dir_path(), std::path::PathBuf::from(tmp_xdg.path()).join("polyscribe").join("models")); + assert_eq!( + crate::models_dir_path(), + std::path::PathBuf::from(tmp_xdg.path()) + .join("polyscribe") + .join("models") + ); // With HOME fallback let tmp_home = tempfile::tempdir().unwrap(); unsafe { std::env::remove_var("XDG_DATA_HOME"); std::env::set_var("HOME", tmp_home.path()); } - assert_eq!(super::models_dir_path(), std::path::PathBuf::from(tmp_home.path()).join(".local").join("share").join("polyscribe").join("models")); + assert_eq!( + crate::models_dir_path(), + std::path::PathBuf::from(tmp_home.path()) + .join(".local") + .join("share") + .join("polyscribe") + .join("models") + ); unsafe { std::env::remove_var("XDG_DATA_HOME"); std::env::remove_var("HOME"); diff --git a/tests/integration_aux.rs b/tests/integration_aux.rs index 1703a62..3cb1706 100644 --- a/tests/integration_aux.rs +++ b/tests/integration_aux.rs @@ -1,6 +1,8 @@ use std::process::Command; -fn bin() -> &'static str { env!("CARGO_BIN_EXE_polyscribe") } +fn bin() -> &'static str { + env!("CARGO_BIN_EXE_polyscribe") +} #[test] fn aux_completions_bash_outputs_script() { @@ -9,11 +11,21 @@ fn aux_completions_bash_outputs_script() { .arg("bash") .output() .expect("failed to run polyscribe completions bash"); - assert!(out.status.success(), "completions bash exited with failure: {:?}", out.status); + assert!( + out.status.success(), + "completions bash exited with failure: {:?}", + out.status + ); let stdout = String::from_utf8(out.stdout).expect("stdout not utf-8"); - assert!(!stdout.trim().is_empty(), "completions bash stdout is empty"); + assert!( + !stdout.trim().is_empty(), + "completions bash
stdout is empty" + ); // Heuristic: bash completion scripts often contain 'complete -F' lines - assert!(stdout.contains("complete") || stdout.contains("_polyscribe"), "bash completion script did not contain expected markers"); + assert!( + stdout.contains("complete") || stdout.contains("_polyscribe"), + "bash completion script did not contain expected markers" + ); } #[test] @@ -23,11 +35,18 @@ fn aux_completions_zsh_outputs_script() { .arg("zsh") .output() .expect("failed to run polyscribe completions zsh"); - assert!(out.status.success(), "completions zsh exited with failure: {:?}", out.status); + assert!( + out.status.success(), + "completions zsh exited with failure: {:?}", + out.status + ); let stdout = String::from_utf8(out.stdout).expect("stdout not utf-8"); assert!(!stdout.trim().is_empty(), "completions zsh stdout is empty"); // Heuristic: zsh completion scripts often start with '#compdef' - assert!(stdout.contains("#compdef") || stdout.contains("#compdef polyscribe"), "zsh completion script did not contain expected markers"); + assert!( + stdout.contains("#compdef") || stdout.contains("#compdef polyscribe"), + "zsh completion script did not contain expected markers" + ); } #[test] @@ -36,10 +55,21 @@ fn aux_man_outputs_roff() { .arg("man") .output() .expect("failed to run polyscribe man"); - assert!(out.status.success(), "man exited with failure: {:?}", out.status); + assert!( + out.status.success(), + "man exited with failure: {:?}", + out.status + ); let stdout = String::from_utf8(out.stdout).expect("stdout not utf-8"); assert!(!stdout.trim().is_empty(), "man stdout is empty"); // clap_mangen typically emits roff with .TH and/or section headers - let looks_like_roff = stdout.contains(".TH ") || stdout.starts_with(".TH") || stdout.contains(".SH NAME") || stdout.contains(".SH SYNOPSIS"); - assert!(looks_like_roff, "man output does not look like a roff manpage; got: {}", &stdout.lines().take(3).collect::>().join(" | ")); + let looks_like_roff = 
stdout.contains(".TH ") + || stdout.starts_with(".TH") + || stdout.contains(".SH NAME") + || stdout.contains(".SH SYNOPSIS"); + assert!( + looks_like_roff, + "man output does not look like a roff manpage; got: {}", + &stdout.lines().take(3).collect::>().join(" | ") + ); } diff --git a/tests/integration_cli.rs b/tests/integration_cli.rs index 930ff0b..9b8ee06 100644 --- a/tests/integration_cli.rs +++ b/tests/integration_cli.rs @@ -30,7 +30,9 @@ impl TestDir { fs::create_dir_all(&p).expect("Failed to create temp dir"); TestDir(p) } - fn path(&self) -> &Path { &self.0 } + fn path(&self) -> &Path { + &self.0 + } } impl Drop for TestDir { fn drop(&mut self) { @@ -79,14 +81,32 @@ fn cli_writes_separate_outputs_by_default() { for e in entries { let p = e.unwrap().path(); if let Some(name) = p.file_name().and_then(|s| s.to_str()) { - if name.ends_with(".json") { json_paths.push(p.clone()); } - if name.ends_with(".toml") { count_toml += 1; } - if name.ends_with(".srt") { count_srt += 1; } + if name.ends_with(".json") { + json_paths.push(p.clone()); + } + if name.ends_with(".toml") { + count_toml += 1; + } + if name.ends_with(".srt") { + count_srt += 1; + } } } - assert!(json_paths.len() >= 2, "expected at least 2 JSON files, found {}", json_paths.len()); - assert!(count_toml >= 2, "expected at least 2 TOML files, found {}", count_toml); - assert!(count_srt >= 2, "expected at least 2 SRT files, found {}", count_srt); + assert!( + json_paths.len() >= 2, + "expected at least 2 JSON files, found {}", + json_paths.len() + ); + assert!( + count_toml >= 2, + "expected at least 2 TOML files, found {}", + count_toml + ); + assert!( + count_srt >= 2, + "expected at least 2 SRT files, found {}", + count_srt + ); // JSON contents are assumed valid if files exist; detailed parsing is covered elsewhere @@ -124,9 +144,15 @@ fn cli_merges_json_inputs_with_flag_and_writes_outputs_to_temp_dir() { for e in entries { let p = e.unwrap().path(); if let Some(name) = p.file_name().and_then(|s| 
s.to_str()) { - if name.ends_with("_out.json") { found_json = Some(p.clone()); } - if name.ends_with("_out.toml") { found_toml = Some(p.clone()); } - if name.ends_with("_out.srt") { found_srt = Some(p.clone()); } + if name.ends_with("_out.json") { + found_json = Some(p.clone()); + } + if name.ends_with("_out.toml") { + found_toml = Some(p.clone()); + } + if name.ends_with("_out.srt") { + found_srt = Some(p.clone()); + } } } let _json_path = found_json.expect("missing JSON output in temp dir"); @@ -154,7 +180,10 @@ fn cli_prints_json_to_stdout_when_no_output_path_merge_mode() { assert!(output.status.success(), "CLI failed"); let stdout = String::from_utf8(output.stdout).expect("stdout not UTF-8"); - assert!(stdout.contains("\"items\""), "stdout should contain items JSON array"); + assert!( + stdout.contains("\"items\""), + "stdout should contain items JSON array" + ); } #[test] @@ -187,16 +216,36 @@ fn cli_merge_and_separate_writes_both_kinds_of_outputs() { for e in entries { let p = e.unwrap().path(); if let Some(name) = p.file_name().and_then(|s| s.to_str()) { - if name.ends_with(".json") { json_count += 1; } - if name.ends_with(".toml") { toml_count += 1; } - if name.ends_with(".srt") { srt_count += 1; } - if name.ends_with("_merged.json") { merged_json = Some(p.clone()); } + if name.ends_with(".json") { + json_count += 1; + } + if name.ends_with(".toml") { + toml_count += 1; + } + if name.ends_with(".srt") { + srt_count += 1; + } + if name.ends_with("_merged.json") { + merged_json = Some(p.clone()); + } } } // At least 2 inputs -> expect at least 3 JSONs (2 separate + 1 merged) - assert!(json_count >= 3, "expected at least 3 JSON files, found {}", json_count); - assert!(toml_count >= 3, "expected at least 3 TOML files, found {}", toml_count); - assert!(srt_count >= 3, "expected at least 3 SRT files, found {}", srt_count); + assert!( + json_count >= 3, + "expected at least 3 JSON files, found {}", + json_count + ); + assert!( + toml_count >= 3, + "expected at 
least 3 TOML files, found {}", + toml_count + ); + assert!( + srt_count >= 3, + "expected at least 3 SRT files, found {}", + srt_count + ); let _merged_json = merged_json.expect("missing merged JSON output ending with _merged.json"); // Contents of merged JSON are validated by unit tests and other integration coverage @@ -205,7 +254,6 @@ fn cli_merge_and_separate_writes_both_kinds_of_outputs() { let _ = fs::remove_dir_all(&out_dir); } - #[test] fn cli_set_speaker_names_merge_prompts_and_uses_names() { use std::io::{Read as _, Write as _}; @@ -238,7 +286,8 @@ fn cli_set_speaker_names_merge_prompts_and_uses_names() { let stdout = String::from_utf8(output.stdout).expect("stdout not UTF-8"); let root: OutputRoot = serde_json::from_str(&stdout).unwrap(); - let speakers: std::collections::HashSet = root.items.into_iter().map(|e| e.speaker).collect(); + let speakers: std::collections::HashSet = + root.items.into_iter().map(|e| e.speaker).collect(); assert!(speakers.contains("Alpha"), "Alpha not found in speakers"); assert!(speakers.contains("Beta"), "Beta not found in speakers"); } @@ -279,12 +328,17 @@ fn cli_set_speaker_names_separate_single_input() { for e in fs::read_dir(&out_dir).unwrap() { let p = e.unwrap().path(); if let Some(name) = p.file_name().and_then(|s| s.to_str()) { - if name.ends_with(".json") { json_paths.push(p.clone()); } + if name.ends_with(".json") { + json_paths.push(p.clone()); + } } } assert!(!json_paths.is_empty(), "no JSON outputs created"); let mut buf = String::new(); - std::fs::File::open(&json_paths[0]).unwrap().read_to_string(&mut buf).unwrap(); + std::fs::File::open(&json_paths[0]) + .unwrap() + .read_to_string(&mut buf) + .unwrap(); let root: OutputRoot = serde_json::from_str(&buf).unwrap(); assert!(root.items.iter().all(|e| e.speaker == "ChosenOne"));