From aa8ea14407fcea3d6692276866f639563be35575 Mon Sep 17 00:00:00 2001 From: vikingowl Date: Fri, 8 Aug 2025 11:32:31 +0200 Subject: [PATCH] [feat] add `update_models` CLI option and implement local-remote model synchronization logic --- src/main.rs | 4 ++ src/models.rs | 135 +++++++++++++++++++++++++++++++++++++++++++++++--- 2 files changed, 131 insertions(+), 8 deletions(-) diff --git a/src/main.rs b/src/main.rs index 916c11d..bb7dfd4 100644 --- a/src/main.rs +++ b/src/main.rs @@ -33,6 +33,10 @@ struct Args { /// Launch interactive model downloader (list HF models, multi-select and download) #[arg(long)] download_models: bool, + + /// Update local Whisper models by comparing hashes/sizes with remote manifest + #[arg(long)] + update_models: bool, } #[derive(Debug, Deserialize)] diff --git a/src/models.rs b/src/models.rs index 7efa47c..798fa00 100644 --- a/src/models.rs +++ b/src/models.rs @@ -3,6 +3,7 @@ use std::io::{self, Read, Write}; use std::path::Path; use std::collections::BTreeMap; use std::time::Duration; +use std::env; use anyhow::{anyhow, Context, Result}; use serde::Deserialize; @@ -41,7 +42,7 @@ struct HFTreeItem { lfs: Option, } -#[derive(Clone, Debug)] +#[derive(Clone, Debug, Deserialize)] struct ModelEntry { // e.g. "tiny.en-q5_1" name: String, @@ -49,7 +50,7 @@ struct ModelEntry { subtype: String, size: u64, sha256: Option, - repo: &'static str, // e.g. "ggerganov/whisper.cpp" + repo: String, // e.g. "ggerganov/whisper.cpp" } fn split_model_name(model: &str) -> (String, String) { @@ -123,7 +124,7 @@ fn size_from_tree(s: &HFTreeItem) -> Option { None } -fn fill_meta_via_head(repo: &'static str, name: &str) -> (Option, Option) { +fn fill_meta_via_head(repo: &str, name: &str) -> (Option, Option) { let head_client = match Client::builder() .user_agent("PolyScribe/0.1 (+https://github.com/)") .redirect(Policy::none()) @@ -185,7 +186,7 @@ fn hf_fetch_repo_models(client: &Client, repo: &'static str) -> Result { /* fall back below */ } @@ -212,7 +213,7 @@ fn hf_fetch_repo_models(client: &Client, repo: &'static str) -> Result Result Result<()> { Ok(()) } -fn download_one_model(client: &Client, models_dir: &Path, entry: &ModelEntry) -> Result<()> { +pub fn download_one_model(client: &Client, models_dir: &Path, entry: &ModelEntry) -> Result<()> { let final_path = models_dir.join(format!("ggml-{}.bin", entry.name)); // If the model already exists, verify against online metadata @@ -420,7 +421,36 @@ fn download_one_model(client: &Client, models_dir: &Path, entry: &ModelEntry) -> "Model {} exists but remote hash/size not available; will download to verify contents.", final_path.display() ); - // Fall through to download for content comparison + // Fall through to download/copy for content comparison + } + } + + // Offline/local copy mode for tests: if set, copy from a given base directory instead of HTTP + if let Ok(base_dir) = env::var("POLYSCRIBE_MODELS_BASE_COPY_DIR") { + let src_path = std::path::Path::new(&base_dir).join(format!("ggml-{}.bin", entry.name)); + if src_path.exists() { + eprintln!("Copying {} from {}...", entry.name, src_path.display()); + let tmp_path = models_dir.join(format!("ggml-{}.bin.part", entry.name)); + if tmp_path.exists() { let _ = std::fs::remove_file(&tmp_path); } + std::fs::copy(&src_path, &tmp_path) + .with_context(|| format!("Failed to copy from {} to {}", src_path.display(), tmp_path.display()))?; + // Verify hash if available + if let Some(expected) = &entry.sha256 { + let got = compute_file_sha256_hex(&tmp_path)?; + if !got.eq_ignore_ascii_case(expected) { + let _ = std::fs::remove_file(&tmp_path); + return Err(anyhow!( + "SHA-256 mismatch for {} (copied): expected {}, got {}", + entry.name, expected, got + )); + } + } + // Replace existing file safely + if final_path.exists() { let _ = std::fs::remove_file(&final_path); } + std::fs::rename(&tmp_path, &final_path) + .with_context(|| format!("Failed to move into place: {}", final_path.display()))?; + eprintln!("Saved: {}", final_path.display()); + return Ok(()); } } @@ -472,3 +502,92 @@ fn download_one_model(client: &Client, models_dir: &Path, entry: &ModelEntry) -> eprintln!("Saved: {}", final_path.display()); Ok(()) } + +pub fn update_local_models() -> Result<()> { + let models_dir = Path::new("models"); + if !models_dir.exists() { + create_dir_all(models_dir).context("Failed to create models directory")?; + } + + // Build HTTP client (may be unused in offline copy mode) + let client = Client::builder() + .user_agent("PolyScribe/0.1 (+https://github.com/)") + .timeout(std::time::Duration::from_secs(600)) + .build() + .context("Failed to build HTTP client")?; + + // Obtain manifest: env override or online fetch + let models: Vec = if let Ok(manifest_path) = env::var("POLYSCRIBE_MODELS_MANIFEST") { + let data = std::fs::read_to_string(&manifest_path) + .with_context(|| format!("Failed to read manifest at {}", manifest_path))?; + let mut list: Vec = serde_json::from_str(&data) + .with_context(|| format!("Invalid JSON manifest: {}", manifest_path))?; + // sort for stability + list.sort_by(|a,b| a.name.cmp(&b.name)); + list + } else { + fetch_all_models(&client)? + }; + + // Map name -> entry for fast lookup + let mut map: BTreeMap = BTreeMap::new(); + for m in models { map.insert(m.name.clone(), m); } + + // Scan local ggml-*.bin models + let rd = std::fs::read_dir(models_dir) + .with_context(|| format!("Failed to read models dir: {}", models_dir.display()))?; + for entry in rd { + let entry = entry?; + let path = entry.path(); + if !path.is_file() { continue; } + let fname = match path.file_name().and_then(|s| s.to_str()) { Some(s) => s.to_string(), None => continue }; + if !fname.starts_with("ggml-") || !fname.ends_with(".bin") { continue; } + let model_name = fname.trim_start_matches("ggml-").trim_end_matches(".bin").to_string(); + + if let Some(remote) = map.get(&model_name) { + // If SHA256 available, verify and update if mismatch + if let Some(expected) = &remote.sha256 { + match compute_file_sha256_hex(&path) { + Ok(local_hash) => { + if local_hash.eq_ignore_ascii_case(expected) { + eprintln!("{} is up-to-date.", fname); + continue; + } else { + eprintln!( + "{} hash differs (local {}.. != remote {}..). Updating...", + fname, + &local_hash[..std::cmp::min(8, local_hash.len())], + &expected[..std::cmp::min(8, expected.len())] + ); + } + } + Err(e) => { + eprintln!("Warning: failed hashing {}: {}. Re-downloading.", fname, e); + } + } + download_one_model(&client, models_dir, remote)?; + } else if remote.size > 0 { + match std::fs::metadata(&path) { + Ok(md) if md.len() == remote.size => { + eprintln!("{} appears up-to-date by size ({}).", fname, remote.size); + continue; + } + Ok(md) => { + eprintln!("{} size {} differs from remote {}. Updating...", fname, md.len(), remote.size); + download_one_model(&client, models_dir, remote)?; + } + Err(e) => { + eprintln!("Warning: stat failed for {}: {}. Updating...", fname, e); + download_one_model(&client, models_dir, remote)?; + } + } + } else { + eprintln!("No remote hash/size for {}. Skipping.", fname); + } + } else { + eprintln!("No remote metadata for {}. Skipping.", fname); + } + } + + Ok(()) +}