[feat] add update_models CLI option and implement local-remote model synchronization logic

2025-08-08 11:32:31 +02:00
parent 29b6a2493b
commit aa8ea14407
2 changed files with 131 additions and 8 deletions

View File

@@ -33,6 +33,10 @@ struct Args {
/// Launch interactive model downloader (list HF models, multi-select and download)
#[arg(long)]
download_models: bool,
/// Update local Whisper models by comparing hashes/sizes with remote manifest
#[arg(long)]
update_models: bool,
}
#[derive(Debug, Deserialize)]
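
For reference, a minimal sketch of how the new flag could be wired up in main(). This is not part of the diff; it assumes Args derives clap::Parser and that update_local_models() (added further down) is exposed from a models module. With clap's default kebab-case renaming, the field is reached via --update-models.

// Illustrative wiring only, not part of this commit.
// Assumes `use clap::Parser;`, `#[derive(Parser)]` on Args, and a `models` module
// exposing the update_local_models() function introduced later in this diff.
fn main() -> anyhow::Result<()> {
    let args = Args::parse();
    if args.update_models {
        // Reached via `--update-models` under clap's default flag naming.
        models::update_local_models()?;
        return Ok(());
    }
    // ... remaining dispatch (interactive downloader, transcription, etc.)
    Ok(())
}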

View File

@@ -3,6 +3,7 @@ use std::io::{self, Read, Write};
use std::path::Path;
use std::collections::BTreeMap;
use std::time::Duration;
use std::env;
use anyhow::{anyhow, Context, Result};
use serde::Deserialize;
@@ -41,7 +42,7 @@ struct HFTreeItem {
lfs: Option<HFLfsMeta>,
}
-#[derive(Clone, Debug)]
+#[derive(Clone, Debug, Deserialize)]
struct ModelEntry {
// e.g. "tiny.en-q5_1"
name: String,
@@ -49,7 +50,7 @@ struct ModelEntry {
subtype: String,
size: u64,
sha256: Option<String>,
-repo: &'static str, // e.g. "ggerganov/whisper.cpp"
+repo: String, // e.g. "ggerganov/whisper.cpp"
}
fn split_model_name(model: &str) -> (String, String) {
@@ -123,7 +124,7 @@ fn size_from_tree(s: &HFTreeItem) -> Option<u64> {
None
}
-fn fill_meta_via_head(repo: &'static str, name: &str) -> (Option<u64>, Option<String>) {
+fn fill_meta_via_head(repo: &str, name: &str) -> (Option<u64>, Option<String>) {
let head_client = match Client::builder()
.user_agent("PolyScribe/0.1 (+https://github.com/)")
.redirect(Policy::none())
@@ -185,7 +186,7 @@ fn hf_fetch_repo_models(client: &Client, repo: &'static str) -> Result<Vec<Model
let (base, subtype) = split_model_name(&model_name);
let size = size_from_tree(&it).unwrap_or(0);
let sha256 = expected_sha_from_tree(&it);
-out.push(ModelEntry { name: model_name, base, subtype, size, sha256, repo });
+out.push(ModelEntry { name: model_name, base, subtype, size, sha256, repo: repo.to_string() });
}
}
Err(_) => { /* fall back below */ }
@@ -212,7 +213,7 @@ fn hf_fetch_repo_models(client: &Client, repo: &'static str) -> Result<Vec<Model
let (base, subtype) = split_model_name(&model_name);
let size = size_from_sibling(&s).unwrap_or(0);
let sha256 = expected_sha_from_sibling(&s);
-out.push(ModelEntry { name: model_name, base, subtype, size, sha256, repo });
+out.push(ModelEntry { name: model_name, base, subtype, size, sha256, repo: repo.to_string() });
}
}
}
@@ -223,7 +224,7 @@ fn hf_fetch_repo_models(client: &Client, repo: &'static str) -> Result<Vec<Model
}
for m in out.iter_mut() {
if m.size == 0 || m.sha256.is_none() {
-let (sz, sha) = fill_meta_via_head(m.repo, &m.name);
+let (sz, sha) = fill_meta_via_head(&m.repo, &m.name);
if m.size == 0 {
if let Some(s) = sz { m.size = s; }
}
@@ -365,7 +366,7 @@ pub fn run_interactive_model_downloader() -> Result<()> {
Ok(())
}
-fn download_one_model(client: &Client, models_dir: &Path, entry: &ModelEntry) -> Result<()> {
+pub fn download_one_model(client: &Client, models_dir: &Path, entry: &ModelEntry) -> Result<()> {
let final_path = models_dir.join(format!("ggml-{}.bin", entry.name));
// If the model already exists, verify against online metadata
@@ -420,7 +421,36 @@ fn download_one_model(client: &Client, models_dir: &Path, entry: &ModelEntry) ->
"Model {} exists but remote hash/size not available; will download to verify contents.",
final_path.display()
);
-// Fall through to download for content comparison
+// Fall through to download/copy for content comparison
}
}
// Offline/local copy mode for tests: if set, copy from a given base directory instead of HTTP
if let Ok(base_dir) = env::var("POLYSCRIBE_MODELS_BASE_COPY_DIR") {
let src_path = std::path::Path::new(&base_dir).join(format!("ggml-{}.bin", entry.name));
if src_path.exists() {
eprintln!("Copying {} from {}...", entry.name, src_path.display());
let tmp_path = models_dir.join(format!("ggml-{}.bin.part", entry.name));
if tmp_path.exists() { let _ = std::fs::remove_file(&tmp_path); }
std::fs::copy(&src_path, &tmp_path)
.with_context(|| format!("Failed to copy from {} to {}", src_path.display(), tmp_path.display()))?;
// Verify hash if available
if let Some(expected) = &entry.sha256 {
let got = compute_file_sha256_hex(&tmp_path)?;
if !got.eq_ignore_ascii_case(expected) {
let _ = std::fs::remove_file(&tmp_path);
return Err(anyhow!(
"SHA-256 mismatch for {} (copied): expected {}, got {}",
entry.name, expected, got
));
}
}
// Replace existing file safely
if final_path.exists() { let _ = std::fs::remove_file(&final_path); }
std::fs::rename(&tmp_path, &final_path)
.with_context(|| format!("Failed to move into place: {}", final_path.display()))?;
eprintln!("Saved: {}", final_path.display());
return Ok(());
}
}
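
To illustrate the offline copy hook above: a minimal test sketch, assuming the test sits next to this module (so it can construct a ModelEntry and call download_one_model directly). The fixture path, output directory, and model name are placeholders, not values from this commit.

// Sketch only: exercises the POLYSCRIBE_MODELS_BASE_COPY_DIR path without any HTTP.
#[test]
fn copies_model_from_local_base_dir() -> Result<()> {
    // Safe under edition 2021; needs an unsafe block under the 2024 edition.
    std::env::set_var("POLYSCRIBE_MODELS_BASE_COPY_DIR", "tests/fixtures/models");
    let models_dir = std::path::Path::new("target/test-models");
    std::fs::create_dir_all(models_dir)?;
    let entry = ModelEntry {
        name: "tiny".to_string(),
        base: "tiny".to_string(),
        subtype: String::new(),
        size: 0,       // unknown size: the size comparison is skipped
        sha256: None,  // no expected hash: the copied file is accepted as-is
        repo: "ggerganov/whisper.cpp".to_string(),
    };
    let client = Client::new();
    download_one_model(&client, models_dir, &entry)?;
    assert!(models_dir.join("ggml-tiny.bin").exists());
    Ok(())
}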
@@ -472,3 +502,92 @@ fn download_one_model(client: &Client, models_dir: &Path, entry: &ModelEntry) ->
eprintln!("Saved: {}", final_path.display());
Ok(())
}
pub fn update_local_models() -> Result<()> {
let models_dir = Path::new("models");
if !models_dir.exists() {
create_dir_all(models_dir).context("Failed to create models directory")?;
}
// Build HTTP client (may be unused in offline copy mode)
let client = Client::builder()
.user_agent("PolyScribe/0.1 (+https://github.com/)")
.timeout(std::time::Duration::from_secs(600))
.build()
.context("Failed to build HTTP client")?;
// Obtain manifest: env override or online fetch
let models: Vec<ModelEntry> = if let Ok(manifest_path) = env::var("POLYSCRIBE_MODELS_MANIFEST") {
let data = std::fs::read_to_string(&manifest_path)
.with_context(|| format!("Failed to read manifest at {}", manifest_path))?;
let mut list: Vec<ModelEntry> = serde_json::from_str(&data)
.with_context(|| format!("Invalid JSON manifest: {}", manifest_path))?;
// sort for stability
list.sort_by(|a,b| a.name.cmp(&b.name));
list
} else {
fetch_all_models(&client)?
};
// Map name -> entry for fast lookup
let mut map: BTreeMap<String, ModelEntry> = BTreeMap::new();
for m in models { map.insert(m.name.clone(), m); }
// Scan local ggml-*.bin models
let rd = std::fs::read_dir(models_dir)
.with_context(|| format!("Failed to read models dir: {}", models_dir.display()))?;
for entry in rd {
let entry = entry?;
let path = entry.path();
if !path.is_file() { continue; }
let fname = match path.file_name().and_then(|s| s.to_str()) { Some(s) => s.to_string(), None => continue };
if !fname.starts_with("ggml-") || !fname.ends_with(".bin") { continue; }
let model_name = fname.trim_start_matches("ggml-").trim_end_matches(".bin").to_string();
if let Some(remote) = map.get(&model_name) {
// If SHA256 available, verify and update if mismatch
if let Some(expected) = &remote.sha256 {
match compute_file_sha256_hex(&path) {
Ok(local_hash) => {
if local_hash.eq_ignore_ascii_case(expected) {
eprintln!("{} is up-to-date.", fname);
continue;
} else {
eprintln!(
"{} hash differs (local {}.. != remote {}..). Updating...",
fname,
&local_hash[..std::cmp::min(8, local_hash.len())],
&expected[..std::cmp::min(8, expected.len())]
);
}
}
Err(e) => {
eprintln!("Warning: failed hashing {}: {}. Re-downloading.", fname, e);
}
}
download_one_model(&client, models_dir, remote)?;
} else if remote.size > 0 {
match std::fs::metadata(&path) {
Ok(md) if md.len() == remote.size => {
eprintln!("{} appears up-to-date by size ({}).", fname, remote.size);
continue;
}
Ok(md) => {
eprintln!("{} size {} differs from remote {}. Updating...", fname, md.len(), remote.size);
download_one_model(&client, models_dir, remote)?;
}
Err(e) => {
eprintln!("Warning: stat failed for {}: {}. Updating...", fname, e);
download_one_model(&client, models_dir, remote)?;
}
}
} else {
eprintln!("No remote hash/size for {}. Skipping.", fname);
}
} else {
eprintln!("No remote metadata for {}. Skipping.", fname);
}
}
Ok(())
}
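
For completeness, a sketch of the manifest shape that POLYSCRIBE_MODELS_MANIFEST can point at: since ModelEntry now derives Deserialize with its plain field names, a JSON array of objects using those keys deserializes directly. All values below are placeholders (including the base/subtype split), not data taken from this commit.

// Sketch only: round-trips a placeholder manifest through the same
// serde_json::from_str call used by update_local_models().
fn parse_manifest_example() -> Result<Vec<ModelEntry>> {
    let manifest = r#"[
        {
            "name": "tiny.en-q5_1",
            "base": "tiny.en",
            "subtype": "q5_1",
            "size": 12345678,
            "sha256": null,
            "repo": "ggerganov/whisper.cpp"
        }
    ]"#;
    let entries: Vec<ModelEntry> = serde_json::from_str(manifest)
        .context("Invalid JSON manifest")?;
    Ok(entries)
}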