[feat] add update_models CLI option and implement local-remote model synchronization logic
@@ -33,6 +33,10 @@ struct Args {
     /// Launch interactive model downloader (list HF models, multi-select and download)
     #[arg(long)]
     download_models: bool,
+
+    /// Update local Whisper models by comparing hashes/sizes with remote manifest
+    #[arg(long)]
+    update_models: bool,
 }
 
 #[derive(Debug, Deserialize)]
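The dispatch for the new flag lives in main() and is not part of this excerpt; below is a minimal sketch of how --update-models would likely be handled, assuming it mirrors the existing --download-models branch (the function body is hypothetical; only Args, run_interactive_model_downloader, and update_local_models come from this commit):

    // Hypothetical wiring, not shown in this diff.
    fn main() -> anyhow::Result<()> {
        let args = Args::parse(); // clap derive, as in the Args struct above
        if args.download_models {
            models::run_interactive_model_downloader()?;
            return Ok(());
        }
        if args.update_models {
            // New in this commit: sync local ggml-*.bin files against the remote manifest.
            models::update_local_models()?;
            return Ok(());
        }
        // ... remaining transcription workflow ...
        Ok(())
    }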
src/models.rs
@@ -3,6 +3,7 @@ use std::io::{self, Read, Write};
 use std::path::Path;
 use std::collections::BTreeMap;
 use std::time::Duration;
+use std::env;
 
 use anyhow::{anyhow, Context, Result};
 use serde::Deserialize;
@@ -41,7 +42,7 @@ struct HFTreeItem {
     lfs: Option<HFLfsMeta>,
 }
 
-#[derive(Clone, Debug)]
+#[derive(Clone, Debug, Deserialize)]
 struct ModelEntry {
     // e.g. "tiny.en-q5_1"
     name: String,
@@ -49,7 +50,7 @@ struct ModelEntry {
     subtype: String,
     size: u64,
     sha256: Option<String>,
-    repo: &'static str, // e.g. "ggerganov/whisper.cpp"
+    repo: String, // e.g. "ggerganov/whisper.cpp"
 }
 
 fn split_model_name(model: &str) -> (String, String) {
@@ -123,7 +124,7 @@ fn size_from_tree(s: &HFTreeItem) -> Option<u64> {
     None
 }
 
-fn fill_meta_via_head(repo: &'static str, name: &str) -> (Option<u64>, Option<String>) {
+fn fill_meta_via_head(repo: &str, name: &str) -> (Option<u64>, Option<String>) {
     let head_client = match Client::builder()
         .user_agent("PolyScribe/0.1 (+https://github.com/)")
         .redirect(Policy::none())
@@ -185,7 +186,7 @@ fn hf_fetch_repo_models(client: &Client, repo: &'static str) -> Result<Vec<Model
                 let (base, subtype) = split_model_name(&model_name);
                 let size = size_from_tree(&it).unwrap_or(0);
                 let sha256 = expected_sha_from_tree(&it);
-                out.push(ModelEntry { name: model_name, base, subtype, size, sha256, repo });
+                out.push(ModelEntry { name: model_name, base, subtype, size, sha256, repo: repo.to_string() });
             }
         }
         Err(_) => { /* fall back below */ }
@@ -212,7 +213,7 @@ fn hf_fetch_repo_models(client: &Client, repo: &'static str) -> Result<Vec<Model
                 let (base, subtype) = split_model_name(&model_name);
                 let size = size_from_sibling(&s).unwrap_or(0);
                 let sha256 = expected_sha_from_sibling(&s);
-                out.push(ModelEntry { name: model_name, base, subtype, size, sha256, repo });
+                out.push(ModelEntry { name: model_name, base, subtype, size, sha256, repo: repo.to_string() });
             }
         }
     }
@@ -223,7 +224,7 @@ fn hf_fetch_repo_models(client: &Client, repo: &'static str) -> Result<Vec<Model
     }
     for m in out.iter_mut() {
         if m.size == 0 || m.sha256.is_none() {
-            let (sz, sha) = fill_meta_via_head(m.repo, &m.name);
+            let (sz, sha) = fill_meta_via_head(&m.repo, &m.name);
            if m.size == 0 {
                if let Some(s) = sz { m.size = s; }
            }
@@ -365,7 +366,7 @@ pub fn run_interactive_model_downloader() -> Result<()> {
     Ok(())
 }
 
-fn download_one_model(client: &Client, models_dir: &Path, entry: &ModelEntry) -> Result<()> {
+pub fn download_one_model(client: &Client, models_dir: &Path, entry: &ModelEntry) -> Result<()> {
     let final_path = models_dir.join(format!("ggml-{}.bin", entry.name));
 
     // If the model already exists, verify against online metadata
@@ -420,7 +421,36 @@ fn download_one_model(client: &Client, models_dir: &Path, entry: &ModelEntry) ->
                 "Model {} exists but remote hash/size not available; will download to verify contents.",
                 final_path.display()
             );
-            // Fall through to download for content comparison
+            // Fall through to download/copy for content comparison
         }
     }
 
+    // Offline/local copy mode for tests: if set, copy from a given base directory instead of HTTP
+    if let Ok(base_dir) = env::var("POLYSCRIBE_MODELS_BASE_COPY_DIR") {
+        let src_path = std::path::Path::new(&base_dir).join(format!("ggml-{}.bin", entry.name));
+        if src_path.exists() {
+            eprintln!("Copying {} from {}...", entry.name, src_path.display());
+            let tmp_path = models_dir.join(format!("ggml-{}.bin.part", entry.name));
+            if tmp_path.exists() { let _ = std::fs::remove_file(&tmp_path); }
+            std::fs::copy(&src_path, &tmp_path)
+                .with_context(|| format!("Failed to copy from {} to {}", src_path.display(), tmp_path.display()))?;
+            // Verify hash if available
+            if let Some(expected) = &entry.sha256 {
+                let got = compute_file_sha256_hex(&tmp_path)?;
+                if !got.eq_ignore_ascii_case(expected) {
+                    let _ = std::fs::remove_file(&tmp_path);
+                    return Err(anyhow!(
+                        "SHA-256 mismatch for {} (copied): expected {}, got {}",
+                        entry.name, expected, got
+                    ));
+                }
+            }
+            // Replace existing file safely
+            if final_path.exists() { let _ = std::fs::remove_file(&final_path); }
+            std::fs::rename(&tmp_path, &final_path)
+                .with_context(|| format!("Failed to move into place: {}", final_path.display()))?;
+            eprintln!("Saved: {}", final_path.display());
+            return Ok(());
+        }
+    }
+
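The POLYSCRIBE_MODELS_BASE_COPY_DIR hook above gives tests an offline path; the following is a minimal sketch of how it could be exercised from an in-module test (the fixture directory, the target path, and the assumption that Client is reqwest's blocking client are illustrative, not part of this commit):

    #[cfg(test)]
    mod offline_copy_sketch {
        use super::*;

        #[test]
        fn copies_model_from_base_dir() {
            // Assumed fixture: tests/fixtures/ggml-tiny.en-q5_1.bin exists on disk.
            std::env::set_var("POLYSCRIBE_MODELS_BASE_COPY_DIR", "tests/fixtures");
            let models_dir = std::path::PathBuf::from("target/test-models");
            std::fs::create_dir_all(&models_dir).unwrap();

            let client = Client::new(); // never used once the copy hook kicks in
            let entry = ModelEntry {
                name: "tiny.en-q5_1".into(),
                base: "tiny".into(),
                subtype: "en-q5_1".into(),
                size: 0,
                sha256: None, // skip hash verification in this sketch
                repo: "ggerganov/whisper.cpp".into(),
            };
            download_one_model(&client, &models_dir, &entry).unwrap();
            assert!(models_dir.join("ggml-tiny.en-q5_1.bin").exists());
        }
    }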
@@ -472,3 +502,92 @@ fn download_one_model(client: &Client, models_dir: &Path, entry: &ModelEntry) ->
     eprintln!("Saved: {}", final_path.display());
     Ok(())
 }
+
+pub fn update_local_models() -> Result<()> {
+    let models_dir = Path::new("models");
+    if !models_dir.exists() {
+        create_dir_all(models_dir).context("Failed to create models directory")?;
+    }
+
+    // Build HTTP client (may be unused in offline copy mode)
+    let client = Client::builder()
+        .user_agent("PolyScribe/0.1 (+https://github.com/)")
+        .timeout(std::time::Duration::from_secs(600))
+        .build()
+        .context("Failed to build HTTP client")?;
+
+    // Obtain manifest: env override or online fetch
+    let models: Vec<ModelEntry> = if let Ok(manifest_path) = env::var("POLYSCRIBE_MODELS_MANIFEST") {
+        let data = std::fs::read_to_string(&manifest_path)
+            .with_context(|| format!("Failed to read manifest at {}", manifest_path))?;
+        let mut list: Vec<ModelEntry> = serde_json::from_str(&data)
+            .with_context(|| format!("Invalid JSON manifest: {}", manifest_path))?;
+        // sort for stability
+        list.sort_by(|a, b| a.name.cmp(&b.name));
+        list
+    } else {
+        fetch_all_models(&client)?
+    };
+
+    // Map name -> entry for fast lookup
+    let mut map: BTreeMap<String, ModelEntry> = BTreeMap::new();
+    for m in models { map.insert(m.name.clone(), m); }
+
+    // Scan local ggml-*.bin models
+    let rd = std::fs::read_dir(models_dir)
+        .with_context(|| format!("Failed to read models dir: {}", models_dir.display()))?;
+    for entry in rd {
+        let entry = entry?;
+        let path = entry.path();
+        if !path.is_file() { continue; }
+        let fname = match path.file_name().and_then(|s| s.to_str()) { Some(s) => s.to_string(), None => continue };
+        if !fname.starts_with("ggml-") || !fname.ends_with(".bin") { continue; }
+        let model_name = fname.trim_start_matches("ggml-").trim_end_matches(".bin").to_string();
+
+        if let Some(remote) = map.get(&model_name) {
+            // If SHA256 available, verify and update if mismatch
+            if let Some(expected) = &remote.sha256 {
+                match compute_file_sha256_hex(&path) {
+                    Ok(local_hash) => {
+                        if local_hash.eq_ignore_ascii_case(expected) {
+                            eprintln!("{} is up-to-date.", fname);
+                            continue;
+                        } else {
+                            eprintln!(
+                                "{} hash differs (local {}.. != remote {}..). Updating...",
+                                fname,
+                                &local_hash[..std::cmp::min(8, local_hash.len())],
+                                &expected[..std::cmp::min(8, expected.len())]
+                            );
+                        }
+                    }
+                    Err(e) => {
+                        eprintln!("Warning: failed hashing {}: {}. Re-downloading.", fname, e);
+                    }
+                }
+                download_one_model(&client, models_dir, remote)?;
+            } else if remote.size > 0 {
+                match std::fs::metadata(&path) {
+                    Ok(md) if md.len() == remote.size => {
+                        eprintln!("{} appears up-to-date by size ({}).", fname, remote.size);
+                        continue;
+                    }
+                    Ok(md) => {
+                        eprintln!("{} size {} differs from remote {}. Updating...", fname, md.len(), remote.size);
+                        download_one_model(&client, models_dir, remote)?;
+                    }
+                    Err(e) => {
+                        eprintln!("Warning: stat failed for {}: {}. Updating...", fname, e);
+                        download_one_model(&client, models_dir, remote)?;
+                    }
+                }
+            } else {
+                eprintln!("No remote hash/size for {}. Skipping.", fname);
+            }
+        } else {
+            eprintln!("No remote metadata for {}. Skipping.", fname);
+        }
+    }
+
+    Ok(())
+}
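update_local_models() also accepts a pre-built manifest via POLYSCRIBE_MODELS_MANIFEST, deserialized into Vec<ModelEntry> now that the struct derives Deserialize; when the variable is unset, it falls back to fetching the model list online. A sketch of the expected JSON shape and how it could be fed in follows; the file name and the size value are placeholders, not real model metadata:

    // Sketch only: each object mirrors ModelEntry (name, base, subtype, size, sha256, repo).
    fn update_from_local_manifest() -> Result<()> {
        let manifest = r#"[
            {
                "name": "tiny.en-q5_1",
                "base": "tiny",
                "subtype": "en-q5_1",
                "size": 12345678,
                "sha256": null,
                "repo": "ggerganov/whisper.cpp"
            }
        ]"#;
        std::fs::write("manifest.json", manifest).context("write manifest")?;
        std::env::set_var("POLYSCRIBE_MODELS_MANIFEST", "manifest.json");
        update_local_models()
    }

With sha256 set to null, the sync falls back to the size comparison branch shown above; supplying the real 64-character hex digest enables the hash check instead.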