[feat] add models_dir_path
helper and implement dynamic models directory resolution
This commit is contained in:
1
Cargo.lock
generated
1
Cargo.lock
generated
@@ -1061,6 +1061,7 @@ dependencies = [
|
||||
"serde",
|
||||
"serde_json",
|
||||
"sha2",
|
||||
"tempfile",
|
||||
"toml",
|
||||
"whisper-rs",
|
||||
]
|
||||
|
@@ -13,3 +13,6 @@ chrono = { version = "0.4", features = ["clock"] }
|
||||
reqwest = { version = "0.12", features = ["blocking", "json"] }
|
||||
sha2 = "0.10"
|
||||
whisper-rs = { git = "https://github.com/tazz4843/whisper-rs" }
|
||||
|
||||
[dev-dependencies]
|
||||
tempfile = "3"
|
||||
|
2
TODO.md
2
TODO.md
@@ -2,7 +2,7 @@
|
||||
- [x] update last_model to be only used during one run
|
||||
- [x] rename project to "PolyScribe"
|
||||
- [x] add tests
|
||||
- update local models using hashes (--update-models)
|
||||
- [x] update local models using hashes (--update-models)
|
||||
- create folder models/ if not present -> use /usr/share/polyscribe/models/ for release version, use ./models/ for development version
|
||||
- create missing folders for output files
|
||||
- for merging (command line flag) -> if not present, treat each file as separate output (--merge | -m)
|
||||
|
58
src/main.rs
58
src/main.rs
@@ -16,6 +16,16 @@ mod models;
|
||||
|
||||
static LAST_MODEL_WRITTEN: AtomicBool = AtomicBool::new(false);
|
||||
|
||||
/// Resolves the directory that holds the local model files.
///
/// A non-empty `POLYSCRIBE_MODELS_DIR` environment variable takes precedence;
/// when it is unset or empty the relative path `models` is returned.
///
/// The variable is read with `env::var_os` rather than `env::var`, so a
/// directory whose name is not valid Unicode is still honored instead of
/// silently falling back to the default.
fn models_dir_path() -> PathBuf {
    env::var_os("POLYSCRIBE_MODELS_DIR")
        // An empty value is treated exactly like an unset variable.
        .filter(|raw| !raw.is_empty())
        .map(PathBuf::from)
        .unwrap_or_else(|| PathBuf::from("models"))
}
|
||||
|
||||
#[derive(Parser, Debug)]
|
||||
#[command(name = "PolyScribe", version, about = "Merge multiple JSON transcripts into one or transcribe audio using native whisper")]
|
||||
struct Args {
|
||||
@@ -145,7 +155,8 @@ fn normalize_lang_code(input: &str) -> Option<String> {
|
||||
|
||||
|
||||
fn find_model_file() -> Result<PathBuf> {
|
||||
let models_dir = Path::new("models");
|
||||
let models_dir_buf = models_dir_path();
|
||||
let models_dir = models_dir_buf.as_path();
|
||||
if !models_dir.exists() {
|
||||
create_dir_all(models_dir).with_context(|| format!("Failed to create models directory: {}", models_dir.display()))?;
|
||||
}
|
||||
@@ -349,18 +360,17 @@ struct LastModelCleanup {
|
||||
}
|
||||
impl Drop for LastModelCleanup {
|
||||
fn drop(&mut self) {
|
||||
if LAST_MODEL_WRITTEN.load(Ordering::Relaxed) {
|
||||
let _ = std::fs::remove_file(&self.path);
|
||||
}
|
||||
// Ensure .last_model does not persist across program runs
|
||||
let _ = std::fs::remove_file(&self.path);
|
||||
}
|
||||
}
|
||||
|
||||
fn main() -> Result<()> {
|
||||
let args = Args::parse();
|
||||
|
||||
// Defer cleanup of .last_model until program exit (after all runs within this process)
|
||||
let models_dir = Path::new("models");
|
||||
let last_model_path = models_dir.join(".last_model");
|
||||
// Defer cleanup of .last_model until program exit
|
||||
let models_dir_buf = models_dir_path();
|
||||
let last_model_path = models_dir_buf.join(".last_model");
|
||||
// Ensure cleanup at end of program, regardless of exit path
|
||||
let _last_model_cleanup = LastModelCleanup { path: last_model_path.clone() };
|
||||
|
||||
@@ -374,6 +384,18 @@ fn main() -> Result<()> {
|
||||
}
|
||||
}
|
||||
|
||||
// If requested, update local models and exit unless inputs provided to continue
|
||||
if args.update_models {
|
||||
if let Err(e) = models::update_local_models() {
|
||||
eprintln!("Model update failed: {:#}", e);
|
||||
return Err(e);
|
||||
}
|
||||
// if only updating models and no inputs, exit
|
||||
if args.inputs.is_empty() {
|
||||
return Ok(());
|
||||
}
|
||||
}
|
||||
|
||||
// Determine inputs and optional output path
|
||||
let mut inputs = args.inputs;
|
||||
let mut output_path = args.output;
|
||||
@@ -520,6 +542,28 @@ fn main() -> Result<()> {
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use std::fs;
|
||||
use std::io::Write;
|
||||
use std::env as std_env;
|
||||
use clap::CommandFactory;
|
||||
|
||||
/// The command generated from `Args` must report the name "PolyScribe".
#[test]
fn test_cli_name_polyscribe() {
    assert_eq!(Args::command().get_name(), "PolyScribe");
}
|
||||
|
||||
/// Dropping a `LastModelCleanup` must delete the file its `path` points to.
#[test]
fn test_last_model_cleanup_removes_file() {
    let scratch = tempfile::tempdir().unwrap();
    let marker = scratch.path().join(".last_model");
    fs::write(&marker, "dummy").unwrap();

    // Dropping the guard immediately runs its cleanup.
    drop(LastModelCleanup { path: marker.clone() });

    assert!(!marker.exists(), ".last_model should be removed on drop");
}
|
||||
use super::*;
|
||||
use std::path::Path;
|
||||
|
||||
|
@@ -340,8 +340,17 @@ fn compute_file_sha256_hex(path: &Path) -> Result<String> {
|
||||
Ok(to_hex_lower(&hasher.finalize()))
|
||||
}
|
||||
|
||||
/// Resolves the directory that holds the local model files.
///
/// A non-empty `POLYSCRIBE_MODELS_DIR` environment variable takes precedence;
/// when it is unset or empty the relative path `models` is returned.
///
/// The variable is read with `env::var_os` rather than `env::var`, so a
/// directory whose name is not valid Unicode is still honored instead of
/// silently falling back to the default.
fn models_dir_path() -> std::path::PathBuf {
    env::var_os("POLYSCRIBE_MODELS_DIR")
        // An empty value is treated exactly like an unset variable.
        .filter(|raw| !raw.is_empty())
        .map(std::path::PathBuf::from)
        .unwrap_or_else(|| std::path::PathBuf::from("models"))
}
|
||||
|
||||
pub fn run_interactive_model_downloader() -> Result<()> {
|
||||
let models_dir = Path::new("models");
|
||||
let models_dir_buf = models_dir_path();
|
||||
let models_dir = models_dir_buf.as_path();
|
||||
if !models_dir.exists() { create_dir_all(models_dir).context("Failed to create models directory")?; }
|
||||
let client = Client::builder()
|
||||
.user_agent("PolyScribe/0.1 (+https://github.com/)")
|
||||
@@ -504,7 +513,8 @@ pub fn download_one_model(client: &Client, models_dir: &Path, entry: &ModelEntry
|
||||
}
|
||||
|
||||
pub fn update_local_models() -> Result<()> {
|
||||
let models_dir = Path::new("models");
|
||||
let models_dir_buf = models_dir_path();
|
||||
let models_dir = models_dir_buf.as_path();
|
||||
if !models_dir.exists() {
|
||||
create_dir_all(models_dir).context("Failed to create models directory")?;
|
||||
}
|
||||
@@ -591,3 +601,69 @@ pub fn update_local_models() -> Result<()> {
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;
    use std::ffi::{OsStr, OsString};
    use std::fs;
    use tempfile::tempdir;

    /// Sets an environment variable for the lifetime of the guard and puts the
    /// previous state (value or absence) back when the guard is dropped, so a
    /// test does not leak its overrides into other tests of the same process.
    struct EnvVarGuard {
        key: &'static str,
        previous: Option<OsString>,
    }

    impl EnvVarGuard {
        /// Remembers the current value of `key`, then overrides it with `value`.
        fn set(key: &'static str, value: impl AsRef<OsStr>) -> Self {
            let previous = std::env::var_os(key);
            // SAFETY: mutating the environment is only sound while no other
            // thread reads or writes it concurrently.
            // NOTE(review): the default test harness runs tests in parallel —
            // confirm no other test touches these variables, or run this test
            // single-threaded.
            unsafe {
                std::env::set_var(key, value);
            }
            Self { key, previous }
        }
    }

    impl Drop for EnvVarGuard {
        fn drop(&mut self) {
            // SAFETY: same requirement as in `EnvVarGuard::set`.
            unsafe {
                match self.previous.take() {
                    Some(value) => std::env::set_var(self.key, value),
                    None => std::env::remove_var(self.key),
                }
            }
        }
    }

    /// Lower-case hex SHA-256 digest of `data`, computed independently of the
    /// code under test.
    fn sha256_hex(data: &[u8]) -> String {
        use sha2::{Digest, Sha256};
        let mut hasher = Sha256::new();
        hasher.update(data);
        hasher
            .finalize()
            .iter()
            .map(|byte| format!("{:02x}", byte))
            .collect()
    }

    /// With the manifest and the copy source both redirected to local temp
    /// directories through environment variables, `update_local_models` must
    /// replace a stale local model file with the content of the source file.
    #[test]
    fn test_update_local_models_offline_copy_and_manifest() {
        let tmp_models = tempdir().unwrap();
        let tmp_base = tempdir().unwrap();
        let tmp_manifest = tempdir().unwrap();

        // Prepare source model file content and hash
        let model_name = "tiny.en-q5_1";
        let src_path = tmp_base.path().join(format!("ggml-{}.bin", model_name));
        let new_content = b"new model content";
        fs::write(&src_path, new_content).unwrap();
        let expected_sha = sha256_hex(new_content);
        let expected_size = new_content.len() as u64;

        // Write a wrong existing local file to trigger update
        let local_path = tmp_models.path().join(format!("ggml-{}.bin", model_name));
        fs::write(&local_path, b"old content").unwrap();

        // Write manifest JSON
        let manifest_path = tmp_manifest.path().join("manifest.json");
        let manifest = serde_json::json!([
            {
                "name": model_name,
                "base": "tiny",
                "subtype": "en-q5_1",
                "size": expected_size,
                "sha256": expected_sha,
                "repo": "ggerganov/whisper.cpp"
            }
        ]);
        fs::write(&manifest_path, serde_json::to_string_pretty(&manifest).unwrap()).unwrap();

        // Set env vars to force offline behavior and directories. The guards
        // are declared after the temp dirs, so they are dropped first and the
        // environment never points at an already-deleted directory.
        let _manifest_env = EnvVarGuard::set("POLYSCRIBE_MODELS_MANIFEST", &manifest_path);
        let _base_env = EnvVarGuard::set("POLYSCRIBE_MODELS_BASE_COPY_DIR", tmp_base.path());
        let _models_env = EnvVarGuard::set("POLYSCRIBE_MODELS_DIR", tmp_models.path());

        // Run update
        update_local_models().unwrap();

        // Verify local file equals source content
        let got = fs::read(&local_path).unwrap();
        assert_eq!(got, new_content);
    }
}
|
||||
|
Reference in New Issue
Block a user