diff --git a/Cargo.lock b/Cargo.lock index 55e92d1..d3fdbdf 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1061,6 +1061,7 @@ dependencies = [ "serde", "serde_json", "sha2", + "tempfile", "toml", "whisper-rs", ] diff --git a/Cargo.toml b/Cargo.toml index cd840d8..4f1c8ab 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -13,3 +13,6 @@ chrono = { version = "0.4", features = ["clock"] } reqwest = { version = "0.12", features = ["blocking", "json"] } sha2 = "0.10" whisper-rs = { git = "https://github.com/tazz4843/whisper-rs" } + +[dev-dependencies] +tempfile = "3" diff --git a/TODO.md b/TODO.md index f8e653d..e9959a6 100644 --- a/TODO.md +++ b/TODO.md @@ -2,7 +2,7 @@ - [x] update last_model to be only used during one run - [x] rename project to "PolyScribe" - [x] add tests -- update local models using hashes (--update-models) +- [x] update local models using hashes (--update-models) - create folder models/ if not present -> use /usr/share/polyscribe/models/ for release version, use ./models/ for development version - create missing folders for output files - for merging (command line flag) -> if not present, treat each file as separate output (--merge | -m) diff --git a/src/main.rs b/src/main.rs index bb7dfd4..7edfcd0 100644 --- a/src/main.rs +++ b/src/main.rs @@ -16,6 +16,16 @@ mod models; static LAST_MODEL_WRITTEN: AtomicBool = AtomicBool::new(false); +fn models_dir_path() -> PathBuf { + if let Ok(p) = env::var("POLYSCRIBE_MODELS_DIR") { + let pb = PathBuf::from(p); + if !pb.as_os_str().is_empty() { + return pb; + } + } + PathBuf::from("models") +} + #[derive(Parser, Debug)] #[command(name = "PolyScribe", version, about = "Merge multiple JSON transcripts into one or transcribe audio using native whisper")] struct Args { @@ -145,7 +155,8 @@ fn normalize_lang_code(input: &str) -> Option { fn find_model_file() -> Result { - let models_dir = Path::new("models"); + let models_dir_buf = models_dir_path(); + let models_dir = models_dir_buf.as_path(); if !models_dir.exists() { create_dir_all(models_dir).with_context(|| format!("Failed to create models directory: {}", models_dir.display()))?; } @@ -349,18 +360,17 @@ struct LastModelCleanup { } impl Drop for LastModelCleanup { fn drop(&mut self) { - if LAST_MODEL_WRITTEN.load(Ordering::Relaxed) { - let _ = std::fs::remove_file(&self.path); - } + // Ensure .last_model does not persist across program runs + let _ = std::fs::remove_file(&self.path); } } fn main() -> Result<()> { let args = Args::parse(); - // Defer cleanup of .last_model until program exit (after all runs within this process) - let models_dir = Path::new("models"); - let last_model_path = models_dir.join(".last_model"); + // Defer cleanup of .last_model until program exit + let models_dir_buf = models_dir_path(); + let last_model_path = models_dir_buf.join(".last_model"); // Ensure cleanup at end of program, regardless of exit path let _last_model_cleanup = LastModelCleanup { path: last_model_path.clone() }; @@ -374,6 +384,18 @@ fn main() -> Result<()> { } } + // If requested, update local models and exit unless inputs provided to continue + if args.update_models { + if let Err(e) = models::update_local_models() { + eprintln!("Model update failed: {:#}", e); + return Err(e); + } + // if only updating models and no inputs, exit + if args.inputs.is_empty() { + return Ok(()); + } + } + // Determine inputs and optional output path let mut inputs = args.inputs; let mut output_path = args.output; @@ -520,6 +542,28 @@ fn main() -> Result<()> { #[cfg(test)] mod tests { + use super::*; + use std::fs; + use std::io::Write; + use std::env as std_env; + use clap::CommandFactory; + + #[test] + fn test_cli_name_polyscribe() { + let cmd = Args::command(); + assert_eq!(cmd.get_name(), "PolyScribe"); + } + + #[test] + fn test_last_model_cleanup_removes_file() { + let tmp = tempfile::tempdir().unwrap(); + let last = tmp.path().join(".last_model"); + fs::write(&last, "dummy").unwrap(); + { + let _cleanup = LastModelCleanup { path: last.clone() }; + } + assert!(!last.exists(), ".last_model should be removed on drop"); + } use super::*; use std::path::Path; diff --git a/src/models.rs b/src/models.rs index 798fa00..a7cf1f5 100644 --- a/src/models.rs +++ b/src/models.rs @@ -340,8 +340,17 @@ fn compute_file_sha256_hex(path: &Path) -> Result { Ok(to_hex_lower(&hasher.finalize())) } +fn models_dir_path() -> std::path::PathBuf { + if let Ok(p) = env::var("POLYSCRIBE_MODELS_DIR") { + let pb = std::path::PathBuf::from(p); + if !pb.as_os_str().is_empty() { return pb; } + } + std::path::PathBuf::from("models") +} + pub fn run_interactive_model_downloader() -> Result<()> { - let models_dir = Path::new("models"); + let models_dir_buf = models_dir_path(); + let models_dir = models_dir_buf.as_path(); if !models_dir.exists() { create_dir_all(models_dir).context("Failed to create models directory")?; } let client = Client::builder() .user_agent("PolyScribe/0.1 (+https://github.com/)") @@ -504,7 +513,8 @@ pub fn download_one_model(client: &Client, models_dir: &Path, entry: &ModelEntry } pub fn update_local_models() -> Result<()> { - let models_dir = Path::new("models"); + let models_dir_buf = models_dir_path(); + let models_dir = models_dir_buf.as_path(); if !models_dir.exists() { create_dir_all(models_dir).context("Failed to create models directory")?; } @@ -591,3 +601,69 @@ pub fn update_local_models() -> Result<()> { Ok(()) } + + +#[cfg(test)] +mod tests { + use super::*; + use tempfile::tempdir; + use std::fs; + use std::io::Write; + + fn sha256_hex(data: &[u8]) -> String { + use sha2::{Digest, Sha256}; + let mut hasher = Sha256::new(); + hasher.update(data); + let out = hasher.finalize(); + let mut s = String::new(); + for b in out { s.push_str(&format!("{:02x}", b)); } + s + } + + #[test] + fn test_update_local_models_offline_copy_and_manifest() { + let tmp_models = tempdir().unwrap(); + let tmp_base = tempdir().unwrap(); + let tmp_manifest = tempdir().unwrap(); + + // Prepare source model file content and hash + let model_name = "tiny.en-q5_1"; + let src_path = tmp_base.path().join(format!("ggml-{}.bin", model_name)); + let new_content = b"new model content"; + fs::write(&src_path, new_content).unwrap(); + let expected_sha = sha256_hex(new_content); + let expected_size = new_content.len() as u64; + + // Write a wrong existing local file to trigger update + let local_path = tmp_models.path().join(format!("ggml-{}.bin", model_name)); + fs::write(&local_path, b"old content").unwrap(); + + // Write manifest JSON + let manifest_path = tmp_manifest.path().join("manifest.json"); + let manifest = serde_json::json!([ + { + "name": model_name, + "base": "tiny", + "subtype": "en-q5_1", + "size": expected_size, + "sha256": expected_sha, + "repo": "ggerganov/whisper.cpp" + } + ]); + fs::write(&manifest_path, serde_json::to_string_pretty(&manifest).unwrap()).unwrap(); + + // Set env vars to force offline behavior and directories + unsafe { + std::env::set_var("POLYSCRIBE_MODELS_MANIFEST", &manifest_path); + std::env::set_var("POLYSCRIBE_MODELS_BASE_COPY_DIR", tmp_base.path()); + std::env::set_var("POLYSCRIBE_MODELS_DIR", tmp_models.path()); + } + + // Run update + update_local_models().unwrap(); + + // Verify local file equals source content + let got = fs::read(&local_path).unwrap(); + assert_eq!(got, new_content); + } +}