[feat] add update_models CLI option and implement local-remote model synchronization logic

@@ -33,6 +33,10 @@ struct Args {
     /// Launch interactive model downloader (list HF models, multi-select and download)
     #[arg(long)]
     download_models: bool,
+
+    /// Update local Whisper models by comparing hashes/sizes with remote manifest
+    #[arg(long)]
+    update_models: bool,
 }
 
 #[derive(Debug, Deserialize)]
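
Note: this hunk only declares the flag; the code that consumes it in main() is not part of the excerpt shown here. A minimal sketch of the expected dispatch (the `args` binding, module path, and early return are assumptions, not lines from this commit):

    // Hypothetical wiring in main(): clap's derive maps the field `update_models`
    // to the command-line flag `--update-models`.
    if args.update_models {
        models::update_local_models()?; // sync local ggml-*.bin files against remote metadata
        return Ok(());
    }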

src/models.rs | 135
@@ -3,6 +3,7 @@ use std::io::{self, Read, Write};
 use std::path::Path;
 use std::collections::BTreeMap;
 use std::time::Duration;
+use std::env;
 
 use anyhow::{anyhow, Context, Result};
 use serde::Deserialize;
@@ -41,7 +42,7 @@ struct HFTreeItem {
     lfs: Option<HFLfsMeta>,
 }
 
-#[derive(Clone, Debug)]
+#[derive(Clone, Debug, Deserialize)]
 struct ModelEntry {
     // e.g. "tiny.en-q5_1"
     name: String,
@@ -49,7 +50,7 @@ struct ModelEntry {
     subtype: String,
     size: u64,
     sha256: Option<String>,
-    repo: &'static str, // e.g. "ggerganov/whisper.cpp"
+    repo: String, // e.g. "ggerganov/whisper.cpp"
 }
 
 fn split_model_name(model: &str) -> (String, String) {
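
Deriving Deserialize and switching `repo` from &'static str to an owned String is what allows entries to be read from a JSON manifest later in this commit (see POLYSCRIBE_MODELS_MANIFEST in update_local_models). A small illustration, assuming it sits inside src/models.rs where ModelEntry and serde_json are in scope; the size and sha256 values are placeholders, not real metadata:

    // Illustrative only: parse one manifest entry with serde_json.
    let json = r#"{
        "name": "tiny.en-q5_1",
        "base": "tiny.en",
        "subtype": "q5_1",
        "size": 12345678,
        "sha256": "0000000000000000000000000000000000000000000000000000000000000000",
        "repo": "ggerganov/whisper.cpp"
    }"#;
    let entry: ModelEntry = serde_json::from_str(json)?;
    assert_eq!(entry.repo, "ggerganov/whisper.cpp");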
@@ -123,7 +124,7 @@ fn size_from_tree(s: &HFTreeItem) -> Option<u64> {
     None
 }
 
-fn fill_meta_via_head(repo: &'static str, name: &str) -> (Option<u64>, Option<String>) {
+fn fill_meta_via_head(repo: &str, name: &str) -> (Option<u64>, Option<String>) {
     let head_client = match Client::builder()
         .user_agent("PolyScribe/0.1 (+https://github.com/)")
         .redirect(Policy::none())
@@ -185,7 +186,7 @@ fn hf_fetch_repo_models(client: &Client, repo: &'static str) -> Result<Vec<Model
                 let (base, subtype) = split_model_name(&model_name);
                 let size = size_from_tree(&it).unwrap_or(0);
                 let sha256 = expected_sha_from_tree(&it);
-                out.push(ModelEntry { name: model_name, base, subtype, size, sha256, repo });
+                out.push(ModelEntry { name: model_name, base, subtype, size, sha256, repo: repo.to_string() });
             }
         }
         Err(_) => { /* fall back below */ }
@@ -212,7 +213,7 @@ fn hf_fetch_repo_models(client: &Client, repo: &'static str) -> Result<Vec<Model
                 let (base, subtype) = split_model_name(&model_name);
                 let size = size_from_sibling(&s).unwrap_or(0);
                 let sha256 = expected_sha_from_sibling(&s);
-                out.push(ModelEntry { name: model_name, base, subtype, size, sha256, repo });
+                out.push(ModelEntry { name: model_name, base, subtype, size, sha256, repo: repo.to_string() });
             }
         }
     }
@@ -223,7 +224,7 @@ fn hf_fetch_repo_models(client: &Client, repo: &'static str) -> Result<Vec<Model
     }
     for m in out.iter_mut() {
         if m.size == 0 || m.sha256.is_none() {
-            let (sz, sha) = fill_meta_via_head(m.repo, &m.name);
+            let (sz, sha) = fill_meta_via_head(&m.repo, &m.name);
             if m.size == 0 {
                 if let Some(s) = sz { m.size = s; }
             }
@@ -365,7 +366,7 @@ pub fn run_interactive_model_downloader() -> Result<()> {
     Ok(())
 }
 
-fn download_one_model(client: &Client, models_dir: &Path, entry: &ModelEntry) -> Result<()> {
+pub fn download_one_model(client: &Client, models_dir: &Path, entry: &ModelEntry) -> Result<()> {
     let final_path = models_dir.join(format!("ggml-{}.bin", entry.name));
 
     // If the model already exists, verify against online metadata
@@ -420,7 +421,36 @@ fn download_one_model(client: &Client, models_dir: &Path, entry: &ModelEntry) ->
                 "Model {} exists but remote hash/size not available; will download to verify contents.",
                 final_path.display()
             );
-            // Fall through to download for content comparison
+            // Fall through to download/copy for content comparison
+        }
+    }
+
+    // Offline/local copy mode for tests: if set, copy from a given base directory instead of HTTP
+    if let Ok(base_dir) = env::var("POLYSCRIBE_MODELS_BASE_COPY_DIR") {
+        let src_path = std::path::Path::new(&base_dir).join(format!("ggml-{}.bin", entry.name));
+        if src_path.exists() {
+            eprintln!("Copying {} from {}...", entry.name, src_path.display());
+            let tmp_path = models_dir.join(format!("ggml-{}.bin.part", entry.name));
+            if tmp_path.exists() { let _ = std::fs::remove_file(&tmp_path); }
+            std::fs::copy(&src_path, &tmp_path)
+                .with_context(|| format!("Failed to copy from {} to {}", src_path.display(), tmp_path.display()))?;
+            // Verify hash if available
+            if let Some(expected) = &entry.sha256 {
+                let got = compute_file_sha256_hex(&tmp_path)?;
+                if !got.eq_ignore_ascii_case(expected) {
+                    let _ = std::fs::remove_file(&tmp_path);
+                    return Err(anyhow!(
+                        "SHA-256 mismatch for {} (copied): expected {}, got {}",
+                        entry.name, expected, got
+                    ));
+                }
+            }
+            // Replace existing file safely
+            if final_path.exists() { let _ = std::fs::remove_file(&final_path); }
+            std::fs::rename(&tmp_path, &final_path)
+                .with_context(|| format!("Failed to move into place: {}", final_path.display()))?;
+            eprintln!("Saved: {}", final_path.display());
+            return Ok(());
         }
     }
 
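
The POLYSCRIBE_MODELS_BASE_COPY_DIR branch, together with making download_one_model public, gives tests a network-free path. A sketch of such a test, assuming it lives next to the module (so Client, ModelEntry, create_dir_all and anyhow's Result are already imported) and that a fixture directory holds a ggml-tiny.en-q5_1.bin file; every path and value below is illustrative:

    #[test]
    fn copies_model_from_local_base_dir() -> Result<()> {
        // std::env::set_var needs an unsafe block on edition 2024; shown here in its pre-2024 form.
        std::env::set_var("POLYSCRIBE_MODELS_BASE_COPY_DIR", "tests/fixtures/models");
        let client = Client::builder().build()?; // required by the signature, unused in copy mode
        let entry = ModelEntry {
            name: "tiny.en-q5_1".to_string(),
            base: "tiny.en".to_string(),
            subtype: "q5_1".to_string(),
            size: 0,      // unknown size: size comparison is skipped
            sha256: None, // unknown digest: hash check is skipped
            repo: "ggerganov/whisper.cpp".to_string(),
        };
        let out_dir = Path::new("target/test-models");
        create_dir_all(out_dir)?;
        download_one_model(&client, out_dir, &entry)?;
        assert!(out_dir.join("ggml-tiny.en-q5_1.bin").exists());
        Ok(())
    }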
@@ -472,3 +502,92 @@ fn download_one_model(client: &Client, models_dir: &Path, entry: &ModelEntry) ->
     eprintln!("Saved: {}", final_path.display());
     Ok(())
 }
+
+pub fn update_local_models() -> Result<()> {
+    let models_dir = Path::new("models");
+    if !models_dir.exists() {
+        create_dir_all(models_dir).context("Failed to create models directory")?;
+    }
+
+    // Build HTTP client (may be unused in offline copy mode)
+    let client = Client::builder()
+        .user_agent("PolyScribe/0.1 (+https://github.com/)")
+        .timeout(std::time::Duration::from_secs(600))
+        .build()
+        .context("Failed to build HTTP client")?;
+
+    // Obtain manifest: env override or online fetch
+    let models: Vec<ModelEntry> = if let Ok(manifest_path) = env::var("POLYSCRIBE_MODELS_MANIFEST") {
+        let data = std::fs::read_to_string(&manifest_path)
+            .with_context(|| format!("Failed to read manifest at {}", manifest_path))?;
+        let mut list: Vec<ModelEntry> = serde_json::from_str(&data)
+            .with_context(|| format!("Invalid JSON manifest: {}", manifest_path))?;
+        // sort for stability
+        list.sort_by(|a,b| a.name.cmp(&b.name));
+        list
+    } else {
+        fetch_all_models(&client)?
+    };
+
+    // Map name -> entry for fast lookup
+    let mut map: BTreeMap<String, ModelEntry> = BTreeMap::new();
+    for m in models { map.insert(m.name.clone(), m); }
+
+    // Scan local ggml-*.bin models
+    let rd = std::fs::read_dir(models_dir)
+        .with_context(|| format!("Failed to read models dir: {}", models_dir.display()))?;
+    for entry in rd {
+        let entry = entry?;
+        let path = entry.path();
+        if !path.is_file() { continue; }
+        let fname = match path.file_name().and_then(|s| s.to_str()) { Some(s) => s.to_string(), None => continue };
+        if !fname.starts_with("ggml-") || !fname.ends_with(".bin") { continue; }
+        let model_name = fname.trim_start_matches("ggml-").trim_end_matches(".bin").to_string();
+
+        if let Some(remote) = map.get(&model_name) {
+            // If SHA256 available, verify and update if mismatch
+            if let Some(expected) = &remote.sha256 {
+                match compute_file_sha256_hex(&path) {
+                    Ok(local_hash) => {
+                        if local_hash.eq_ignore_ascii_case(expected) {
+                            eprintln!("{} is up-to-date.", fname);
+                            continue;
+                        } else {
+                            eprintln!(
+                                "{} hash differs (local {}.. != remote {}..). Updating...",
+                                fname,
+                                &local_hash[..std::cmp::min(8, local_hash.len())],
+                                &expected[..std::cmp::min(8, expected.len())]
+                            );
+                        }
+                    }
+                    Err(e) => {
+                        eprintln!("Warning: failed hashing {}: {}. Re-downloading.", fname, e);
+                    }
+                }
+                download_one_model(&client, models_dir, remote)?;
+            } else if remote.size > 0 {
+                match std::fs::metadata(&path) {
+                    Ok(md) if md.len() == remote.size => {
+                        eprintln!("{} appears up-to-date by size ({}).", fname, remote.size);
+                        continue;
+                    }
+                    Ok(md) => {
+                        eprintln!("{} size {} differs from remote {}. Updating...", fname, md.len(), remote.size);
+                        download_one_model(&client, models_dir, remote)?;
+                    }
+                    Err(e) => {
+                        eprintln!("Warning: stat failed for {}: {}. Updating...", fname, e);
+                        download_one_model(&client, models_dir, remote)?;
+                    }
+                }
+            } else {
+                eprintln!("No remote hash/size for {}. Skipping.", fname);
+            }
+        } else {
+            eprintln!("No remote metadata for {}. Skipping.", fname);
+        }
+    }
+
+    Ok(())
+}
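
update_local_models() reads POLYSCRIBE_MODELS_MANIFEST as a JSON array of ModelEntry objects and otherwise falls back to fetch_all_models over HTTP. A minimal driver sketch (the manifest path is an example; with clap's default renaming the same logic would presumably be reachable as --update-models on the CLI):

    // Pin model metadata to a local manifest instead of querying the Hugging Face API.
    std::env::set_var("POLYSCRIBE_MODELS_MANIFEST", "tests/fixtures/manifest.json");
    update_local_models()?; // hashes ./models/ggml-*.bin and refreshes any stale entries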