[refactor] rename and simplify ProgressManager to FileProgress, enhance caching logic, update Hugging Face API integration, and clean up unused comments
Some checks failed
CI / build (push) Has been cancelled

This commit is contained in:
2025-08-15 11:24:50 +02:00
parent cbf48a0452
commit 5ec297397e
14 changed files with 487 additions and 571 deletions

View File

@@ -1,9 +1,6 @@
// SPDX-License-Identifier: MIT
//! Model management for PolyScribe: discovery, download, and verification.
//! Fetches the live file table from Hugging Face, using size and sha256
//! data for verification. Falls back to scraping the repository tree page
//! if the JSON API is unavailable or incomplete. No built-in manifest.
use crate::config::ConfigService;
use crate::prelude::*;
use anyhow::{Context, anyhow};
use chrono::{DateTime, Utc};
@@ -12,13 +9,13 @@ use reqwest::blocking::Client;
use reqwest::header::{
ACCEPT_RANGES, CONTENT_LENGTH, CONTENT_RANGE, ETAG, IF_RANGE, LAST_MODIFIED, RANGE,
};
use serde::Deserialize;
use serde::{Deserialize, Serialize};
use sha2::{Digest, Sha256};
use std::collections::BTreeSet;
use std::fs::{self, File, OpenOptions};
use std::io::{Read, Write};
use std::path::{Path, PathBuf};
use std::time::{Duration, Instant};
use std::time::{Duration, Instant, SystemTime, UNIX_EPOCH};
fn format_size_mb(size: Option<u64>) -> String {
match size {
@@ -35,7 +32,6 @@ fn format_size_gib(bytes: u64) -> String {
format!("{gib:.2} GiB")
}
// Short date formatter (RFC -> yyyy-mm-dd)
fn short_date(s: &str) -> String {
DateTime::parse_from_rfc3339(s)
.ok()
@@ -43,12 +39,10 @@ fn short_date(s: &str) -> String {
.unwrap_or_else(|| s.to_string())
}
// Free disk space using libc::statvfs (already in Cargo)
fn free_space_bytes_for_path(path: &Path) -> Result<u64> {
use libc::statvfs;
use std::ffi::CString;
// use parent dir or current dir if none
let dir = if path.is_dir() {
path
} else {
@@ -66,9 +60,7 @@ fn free_space_bytes_for_path(path: &Path) -> Result<u64> {
}
}
// Minimal mirror note shown in single-line style
fn mirror_label(url: &str) -> &'static str {
// Very light heuristic; replace with your actual mirror selection if you have it
if url.contains("eu") {
"EU mirror"
} else if url.contains("us") {
@@ -78,7 +70,6 @@ fn mirror_label(url: &str) -> &'static str {
}
}
// Perform a HEAD to get size/etag/last-modified and fill what we can
/// Tuple returned by `head_entry`, in the order `(len, etag, last_mod, ranges_ok)`.
type HeadMeta = (Option<u64>, Option<String>, Option<String>, bool);
fn head_entry(client: &Client, url: &str) -> Result<HeadMeta> {
@@ -107,39 +98,27 @@ fn head_entry(client: &Client, url: &str) -> Result<HeadMeta> {
Ok((len, etag, last_mod, ranges_ok))
}
/// Represents a downloadable Whisper model artifact.
#[derive(Debug, Clone)]
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
struct ModelEntry {
/// Display name and local short name (informational; may equal stem of file)
name: String,
/// Remote file name (with extension)
file: String,
/// Remote URL
url: String,
/// Expected file size (optional)
size: Option<u64>,
/// Expected SHA-256 in hex (optional)
sha256: Option<String>,
/// New: last modified timestamp string if available
last_modified: Option<String>,
/// New: parsed base and variant for 2-step UI
base: String,
variant: String,
}
// -------- Hugging Face API integration --------
/// Subset of the Hugging Face model-info JSON that this module deserializes.
///
/// Both file lists are optional, so a response that carries only one of them
/// (or neither) still deserializes.
#[derive(Debug, Deserialize)]
struct HfModelInfo {
/// File list that is sometimes returned by the plain `/api/models/{repo}` endpoint.
siblings: Option<Vec<HfFile>>,
/// File list returned when the request uses `?expand=files`.
files: Option<Vec<HfFile>>,
}
#[derive(Debug, Deserialize)]
struct HfLfsInfo {
// Sometimes an "oid" like "sha256:<hex>"
oid: Option<String>,
size: Option<u64>,
sha256: Option<String>,
@@ -147,53 +126,33 @@ struct HfLfsInfo {
/// One file record from a Hugging Face repository listing.
#[derive(Debug, Deserialize)]
struct HfFile {
/// Relative filename within the repo (e.g., "ggml-tiny.bin").
rfilename: String,
/// Size reported at top level for non-LFS files; often present.
size: Option<u64>,
/// SHA-256 reported at top level; only some entries include it.
sha256: Option<String>,
/// LFS metadata carrying a size and possibly an embedded sha256.
lfs: Option<HfLfsInfo>,
/// Last-modified timestamp provided by the HF API on expanded file listings
/// (JSON key `lastModified`).
#[serde(rename = "lastModified")]
last_modified: Option<String>,
}
/// Splits a model display name into `(base, variant)` for the two-step picker.
///
/// `display_name` is the file name without the `ggml-`/`gguf-` prefix and
/// without the `.bin` extension. Examples:
///
/// - `"tiny"`        -> `("tiny", "default")`
/// - `"tiny.en"`     -> `("tiny", "en")`
/// - `"large-v2"`    -> `("large", "v2")`
/// - `"large-v3.en"` -> `("large", "v3.en")`
///
/// Everything after the first `.` is kept as a suffix; the part before it is
/// split at the first `-` into base and version. When both a version and a
/// dot suffix are present they are joined with `.` so that neither is lost
/// (previously the dot suffix was silently dropped in that case, making
/// distinct files indistinguishable).
fn parse_base_variant(display_name: &str) -> (String, String) {
    // Split off the dot-based suffix (e.g. ".en"); with several dots,
    // everything after the first one stays together.
    let (head, dot_suffix) = match display_name.split_once('.') {
        Some((head, rest)) => (head, Some(rest)),
        None => (display_name, None),
    };

    // Hyphenated heads such as "large-v2" carry the version after the hyphen.
    match (head.split_once('-'), dot_suffix) {
        (Some((base, version)), Some(suffix)) => {
            (base.to_string(), format!("{version}.{suffix}"))
        }
        (Some((base, version)), None) => (base.to_string(), version.to_string()),
        (None, Some(suffix)) => (head.to_string(), suffix.to_string()),
        (None, None) => (head.to_string(), "default".to_string()),
    }
}
/// Build a manifest by calling the Hugging Face API for a repo.
/// Prefers the plain API URL, then retries with `?expand=files` if needed.
fn hf_repo_manifest_api(repo: &str) -> Result<Vec<ModelEntry>> {
let client = Client::builder().user_agent("polyscribe/0.1").build()?;
let client = Client::builder()
.user_agent(ConfigService::user_agent())
.build()?;
// 1) Try the plain API you specified
let base = format!("https://huggingface.co/api/models/{}", repo);
let base = ConfigService::hf_api_base_for(repo);
let resp = client.get(&base).send()?;
let mut entries = if resp.status().is_success() {
let info: HfModelInfo = resp.json()?;
@@ -202,7 +161,6 @@ fn hf_repo_manifest_api(repo: &str) -> Result<Vec<ModelEntry>> {
Vec::new()
};
// 2) If empty, try with expand=files (some repos require this for full file listing)
if entries.is_empty() {
let url = format!("{base}?expand=files");
let resp2 = client.get(&url).send()?;
@@ -228,7 +186,6 @@ fn hf_info_to_entries(repo: &str, info: HfModelInfo) -> Result<Vec<ModelEntry>>
continue;
}
// Derive a simple display name from the file stem
let stem = fname.strip_suffix(".bin").unwrap_or(&fname).to_string();
let name_no_prefix = stem
.strip_prefix("ggml-")
@@ -236,7 +193,6 @@ fn hf_info_to_entries(repo: &str, info: HfModelInfo) -> Result<Vec<ModelEntry>>
.unwrap_or(&stem)
.to_string();
// Prefer explicit sha256; else try to parse from LFS oid "sha256:<hex>"
let sha_from_lfs = f.lfs.as_ref().and_then(|l| {
l.sha256.clone().or_else(|| {
l.oid
@@ -268,12 +224,11 @@ fn hf_info_to_entries(repo: &str, info: HfModelInfo) -> Result<Vec<ModelEntry>>
Ok(out)
}
// -------- HTML scraping fallback (tree view) --------
/// Scrape the repository tree page when the API doesn't return a usable list.
/// Note: sizes and hashes are generally unavailable in this path.
fn scrape_tree_manifest(repo: &str) -> Result<Vec<ModelEntry>> {
let client = Client::builder().user_agent("polyscribe/0.1").build()?;
let client = Client::builder()
.user_agent(ConfigService::user_agent())
.build()?;
let url = format!("https://huggingface.co/{}/tree/main?recursive=1", repo);
let resp = client.get(&url).send()?;
@@ -282,10 +237,6 @@ fn scrape_tree_manifest(repo: &str) -> Result<Vec<ModelEntry>> {
}
let html = resp.text()?;
// Extract .bin paths from links. Match both blob/main and resolve/main.
// Example matches:
// - /{repo}/blob/main/ggml-base.en.bin
// - /{repo}/resolve/main/ggml-base.en.bin
let mut files = BTreeSet::new();
for mat in html.match_indices(".bin") {
let end = mat.0 + 4;
@@ -346,13 +297,8 @@ fn scrape_tree_manifest(repo: &str) -> Result<Vec<ModelEntry>> {
Ok(out)
}
// -------- Metadata enrichment via HEAD (size/hash/last-modified) --------
fn parse_sha_from_header_value(s: &str) -> Option<String> {
// Common HF patterns:
// - ETag: "SHA256:<hex>"
// - X-Linked-ETag: "SHA256:<hex>"
// - Sometimes weak etags: W/"SHA256:<hex>"
let lower = s.to_ascii_lowercase();
if let Some(idx) = lower.find("sha256:") {
let tail = &lower[idx + "sha256:".len()..];
@@ -365,14 +311,13 @@ fn parse_sha_from_header_value(s: &str) -> Option<String> {
}
fn enrich_entry_via_head(entry: &mut ModelEntry) -> Result<()> {
// If we already have everything, nothing to do
if entry.size.is_some() && entry.sha256.is_some() && entry.last_modified.is_some() {
return Ok(());
}
let client = Client::builder()
.user_agent("polyscribe/0.1")
.timeout(Duration::from_secs(8))
.user_agent(ConfigService::user_agent())
.timeout(Duration::from_secs(ConfigService::http_timeout_secs()))
.build()?;
let mut head_url = entry.url.clone();
@@ -397,7 +342,6 @@ fn enrich_entry_via_head(entry: &mut ModelEntry) -> Result<()> {
let mut filled_sha = false;
let mut filled_lm = false;
// Content-Length
if entry.size.is_none()
&& let Some(sz) = resp
.headers()
@@ -409,7 +353,6 @@ fn enrich_entry_via_head(entry: &mut ModelEntry) -> Result<()> {
filled_size = true;
}
// SHA256 from headers if available
if entry.sha256.is_none() {
let _ = resp
.headers()
@@ -433,7 +376,6 @@ fn enrich_entry_via_head(entry: &mut ModelEntry) -> Result<()> {
}
}
// Last-Modified
if entry.last_modified.is_none() {
let _ = resp
.headers()
@@ -477,28 +419,204 @@ fn enrich_entry_via_head(entry: &mut ModelEntry) -> Result<()> {
Ok(())
}
// -------- Online manifest (API first, then scrape) --------
/// On-disk representation of a cached manifest, stored as JSON.
#[derive(Debug, Serialize, Deserialize)]
struct CachedManifest {
/// When the manifest was stored, in seconds since the UNIX epoch.
fetched_at: u64,
/// `ETag` validator of the API response the entries came from, if any.
etag: Option<String>,
/// `Last-Modified` header of that response, if any.
last_modified: Option<String>,
/// The cached model entries.
entries: Vec<ModelEntry>,
}
/// Resolves the directory that holds the cached manifest.
///
/// # Errors
/// Fails when `ConfigService::manifest_cache_dir()` yields `None`.
fn get_cache_dir() -> Result<PathBuf> {
    match ConfigService::manifest_cache_dir() {
        Some(dir) => Ok(dir),
        None => Err(anyhow!("could not determine platform directories").into()),
    }
}
/// Returns the full path of the manifest cache file inside the cache directory.
///
/// # Errors
/// Propagates the failure of `get_cache_dir`.
fn get_cached_manifest_path() -> Result<PathBuf> {
    get_cache_dir().map(|dir| dir.join(ConfigService::manifest_cache_filename()))
}
/// Reports whether the manifest cache must be ignored for both reads and writes.
///
/// Delegates to `ConfigService::bypass_manifest_cache()`; the tests in this file
/// drive it through the `ConfigService::ENV_NO_CACHE_MANIFEST` environment variable.
fn should_bypass_cache() -> bool {
ConfigService::bypass_manifest_cache()
}
/// Returns the maximum age, in seconds, at which a cached manifest is still used.
///
/// Delegates to `ConfigService::manifest_cache_ttl_seconds()`; the tests in this file
/// override it through the `ConfigService::ENV_MANIFEST_TTL_SECONDS` environment variable.
fn get_cache_ttl() -> u64 {
ConfigService::manifest_cache_ttl_seconds()
}
fn load_cached_manifest() -> Option<CachedManifest> {
if should_bypass_cache() {
return None;
}
let cache_path = get_cached_manifest_path().ok()?;
if !cache_path.exists() {
return None;
}
let cache_file = File::open(cache_path).ok()?;
let cached: CachedManifest = serde_json::from_reader(cache_file).ok()?;
let now = SystemTime::now().duration_since(UNIX_EPOCH).ok()?.as_secs();
let ttl = get_cache_ttl();
if now.saturating_sub(cached.fetched_at) > ttl {
crate::dlog!(
1,
"Cache expired (age: {}s, TTL: {}s)",
now.saturating_sub(cached.fetched_at),
ttl
);
return None;
}
crate::dlog!(
1,
"Using cached manifest (age: {}s)",
now.saturating_sub(cached.fetched_at)
);
Some(cached)
}
fn save_manifest_to_cache(
entries: &[ModelEntry],
etag: Option<&str>,
last_modified: Option<&str>,
) -> Result<()> {
if should_bypass_cache() {
return Ok(());
}
let cache_dir = get_cache_dir()?;
fs::create_dir_all(&cache_dir)?;
let cache_path = get_cached_manifest_path()?;
let now = SystemTime::now()
.duration_since(UNIX_EPOCH)
.map_err(|_| anyhow!("system time error"))?
.as_secs();
let cached = CachedManifest {
fetched_at: now,
etag: etag.map(|s| s.to_string()),
last_modified: last_modified.map(|s| s.to_string()),
entries: entries.to_vec(),
};
let cache_file = OpenOptions::new()
.create(true)
.write(true)
.truncate(true)
.open(&cache_path)
.with_context(|| format!("opening cache file {}", cache_path.display()))?;
serde_json::to_writer_pretty(cache_file, &cached)
.with_context(|| "serializing cached manifest")?;
crate::dlog!(1, "Saved manifest to cache: {} entries", entries.len());
Ok(())
}
fn fetch_manifest_with_cache() -> Result<Vec<ModelEntry>> {
let cached = load_cached_manifest();
let client = Client::builder()
.user_agent(ConfigService::user_agent())
.build()?;
let repo = ConfigService::hf_repo();
let base_url = ConfigService::hf_api_base_for(&repo);
let mut req = client.get(&base_url);
if let Some(ref cached) = cached {
if let Some(ref etag) = cached.etag {
req = req.header("If-None-Match", format!("\"{}\"", etag));
} else if let Some(ref last_mod) = cached.last_modified {
req = req.header("If-Modified-Since", last_mod);
}
}
let resp = req.send()?;
if resp.status().as_u16() == 304 {
if let Some(cached) = cached {
crate::dlog!(1, "Manifest not modified, using cache");
return Ok(cached.entries);
}
}
if !resp.status().is_success() {
return Err(anyhow!("HF API {} for {}", resp.status(), base_url).into());
}
let etag = resp
.headers()
.get(ETAG)
.and_then(|v| v.to_str().ok())
.map(|s| s.trim_matches('"').to_string());
let last_modified = resp
.headers()
.get(LAST_MODIFIED)
.and_then(|v| v.to_str().ok())
.map(|s| s.to_string());
let info: HfModelInfo = resp.json()?;
let mut entries = hf_info_to_entries(&repo, info)?;
if entries.is_empty() {
let url = format!("{}?expand=files", base_url);
let resp2 = client.get(&url).send()?;
if !resp2.status().is_success() {
return Err(anyhow!("HF API {} for {}", resp2.status(), url).into());
}
let info: HfModelInfo = resp2.json()?;
entries = hf_info_to_entries(&repo, info)?;
}
if entries.is_empty() {
return Err(anyhow!("HF API returned no usable .bin files").into());
}
let _ = save_manifest_to_cache(&entries, etag.as_deref(), last_modified.as_deref());
Ok(entries)
}
/// Returns the current manifest (online only).
fn current_manifest() -> Result<Vec<ModelEntry>> {
let started = Instant::now();
crate::dlog!(1, "Fetching HF manifest…");
// 1) Load from API, else scrape
let mut list = match hf_repo_manifest_api("ggerganov/whisper.cpp") {
let mut list = match fetch_manifest_with_cache() {
Ok(list) if !list.is_empty() => {
crate::dlog!(1, "Manifest loaded from HF API ({} entries)", list.len());
crate::dlog!(
1,
"Manifest loaded from HF API with cache ({} entries)",
list.len()
);
list
}
_ => {
crate::ilog!("Falling back to scraping the repository tree page");
let scraped = scrape_tree_manifest("ggerganov/whisper.cpp")?;
crate::dlog!(1, "Manifest loaded via scrape ({} entries)", scraped.len());
scraped
crate::ilog!("Cache failed, falling back to direct API");
let repo = ConfigService::hf_repo();
let list = match hf_repo_manifest_api(&repo) {
Ok(list) if !list.is_empty() => {
crate::dlog!(1, "Manifest loaded from HF API ({} entries)", list.len());
list
}
_ => {
crate::ilog!("Falling back to scraping the repository tree page");
let scraped = scrape_tree_manifest(&repo)?;
crate::dlog!(1, "Manifest loaded via scrape ({} entries)", scraped.len());
scraped
}
};
let _ = save_manifest_to_cache(&list, None, None);
list
}
};
// 2) Enrich missing metadata so the UI can show sizes and hashes
let mut need_enrich = 0usize;
for m in &list {
if m.size.is_none() || m.sha256.is_none() || m.last_modified.is_none() {
@@ -532,8 +650,6 @@ fn current_manifest() -> Result<Vec<ModelEntry>> {
Ok(list)
}
/// Pick the best local Whisper model in the given directory.
/// Heuristic: choose the largest .bin file by size. Returns None if none found.
pub fn pick_best_local_model(dir: &Path) -> Option<PathBuf> {
let rd = fs::read_dir(dir).ok()?;
rd.flatten()
@@ -549,39 +665,23 @@ pub fn pick_best_local_model(dir: &Path) -> Option<PathBuf> {
.map(|(_, p)| p)
}
/// Returns the directory where models should be stored based on platform conventions.
fn resolve_models_dir() -> Result<PathBuf> {
let dirs = directories::ProjectDirs::from("dev", "polyscribe", "polyscribe")
.ok_or_else(|| anyhow!("could not determine platform directories"))?;
let data_dir = dirs.data_dir().join("models");
Ok(data_dir)
Ok(ConfigService::models_dir(None)
.ok_or_else(|| anyhow!("could not determine models directory"))?)
}
// Example of a non-interactive path ensuring a given model by name exists, with improved copy.
// Wire this into CLI flags as needed.
/// Ensures a model is available by name, downloading it if necessary.
/// This is a non-interactive version that doesn't prompt the user.
///
/// # Arguments
/// * `name` - Name of the model to ensure is available
///
/// # Returns
/// * `Result<PathBuf>` - Path to the downloaded model file on success
pub fn ensure_model_available_noninteractive(name: &str) -> Result<PathBuf> {
let entry = find_manifest_entry(name)?.ok_or_else(|| anyhow!("unknown model: {name}"))?;
// Resolve destination file path; ensure XDG path (or your existing logic)
let dir = resolve_models_dir()?; // implement or reuse your existing directory resolver
let dir = resolve_models_dir()?;
fs::create_dir_all(&dir).ok();
let dest = dir.join(&entry.file);
// If already matches, early return
if file_matches(&dest, entry.size, entry.sha256.as_deref())? {
crate::ui::info(format!("Already up to date: {}", dest.display()));
return Ok(dest);
}
// Single-line header
let base = &entry.base;
let variant = &entry.variant;
let size_str = format_size_mb(entry.size);
@@ -596,9 +696,16 @@ pub fn ensure_model_available_noninteractive(name: &str) -> Result<PathBuf> {
Ok(dest)
}
/// Deletes the on-disk manifest cache file.
///
/// A cache file that does not exist is not an error, so the call is safe to
/// repeat.
///
/// # Errors
/// Fails when the cache path cannot be resolved or when the file exists but
/// cannot be removed.
pub fn clear_manifest_cache() -> Result<()> {
    let cache_path = get_cached_manifest_path()?;
    // Remove directly instead of probing with `exists()` first: the file may
    // disappear between the probe and the removal.
    match fs::remove_file(&cache_path) {
        Ok(()) => {
            crate::dlog!(1, "Cleared manifest cache");
        }
        Err(err) if err.kind() == std::io::ErrorKind::NotFound => {}
        Err(err) => return Err(err.into()),
    }
    Ok(())
}
fn find_manifest_entry(name: &str) -> Result<Option<ModelEntry>> {
// Accept either manifest display name, file stem, or direct file name.
// Normalize: strip ".bin" for comparisons and also handle input that already includes it.
let wanted_name = name
.strip_suffix(".bin")
.unwrap_or(name)
@@ -622,10 +729,6 @@ fn find_manifest_entry(name: &str) -> Result<Option<ModelEntry>> {
Ok(None)
}
// Return true if the file at `path` matches expected size and/or sha256 (when provided).
// - If sha256 is provided, verify it (preferred).
// - Else if size is provided, check size.
// - If neither provided, return false (cannot verify).
fn file_matches(path: &Path, size: Option<u64>, sha256_hex: Option<&str>) -> Result<bool> {
if !path.exists() {
return Ok(false);
@@ -655,21 +758,14 @@ fn file_matches(path: &Path, size: Option<u64>, sha256_hex: Option<&str>) -> Res
Ok(false)
}
// Download with:
// - Free-space preflight (size * 1.1 overhead).
// - Resume via Range if .part exists and server supports it.
// - Atomic write: download to .part (temp) then rename.
// - Checksum verification when available.
// - Single-line progress UI.
fn download_with_progress(dest_path: &Path, entry: &ModelEntry) -> Result<()> {
let url = &entry.url;
let client = Client::builder()
.user_agent("polyscribe-model-downloader/1")
.user_agent(ConfigService::downloader_user_agent())
.build()?;
crate::ui::info(format!("Resolving source: {} ({})", mirror_label(url), url));
// HEAD for size/etag/ranges
let (mut total_len, remote_etag, _remote_last_mod, ranges_ok) =
head_entry(&client, url).context("probing remote file")?;
@@ -710,9 +806,6 @@ fn download_with_progress(dest_path: &Path, entry: &ModelEntry) -> Result<()> {
.open(&part_path)
.with_context(|| format!("opening {}", part_path.display()))?;
// Build request:
// - Fresh download: plain GET (no If-None-Match).
// - Resume: Range + optional If-Range with ETag.
let mut req = client.get(url);
if ranges_ok && resume_from > 0 {
req = req.header(RANGE, format!("bytes={resume_from}-"));
@@ -729,30 +822,21 @@ fn download_with_progress(dest_path: &Path, entry: &ModelEntry) -> Result<()> {
let start = Instant::now();
let mut resp = req.send()?.error_for_status()?;
// Defensive: if server returns 304 but we don't have a valid cached copy, retry without conditionals.
if resp.status().as_u16() == 304 && resume_from == 0 {
// Fresh download must not be conditional; redo as plain GET
let req2 = client.get(url);
resp = req2.send()?.error_for_status()?;
}
// If server ignored RANGE and returned full body, reset partial
let is_partial_response = resp.headers().get(CONTENT_RANGE).is_some();
if resume_from > 0 && !is_partial_response {
// Server did not honor range → start over
drop(part_file);
fs::remove_file(&part_path).ok();
// Reset local accounting; we also reinitialize the progress bar below
// and reopen the part file. No need to re-read this variable afterwards.
let _ = 0; // avoid unused-assignment lint for resume_from
// Plain GET without conditional headers
let req2 = client.get(url);
resp = req2.send()?.error_for_status()?;
bar.stop("restarting");
bar = crate::ui::BytesProgress::start(pb_total, "Downloading", 0);
// Reopen the part file since we dropped it
part_file = OpenOptions::new()
.create(true)
.read(true)
@@ -842,10 +926,6 @@ fn download_with_progress(dest_path: &Path, entry: &ModelEntry) -> Result<()> {
Ok(())
}
/// Run an interactive model downloader UI (2-step):
/// 1) Choose model base (tiny, small, base, medium, large)
/// 2) Choose model type/variant specific to that base
/// Displays meta info (size and last updated). Does not show raw ggml filenames.
pub fn run_interactive_model_downloader() -> Result<()> {
use crate::ui;
@@ -877,7 +957,6 @@ pub fn run_interactive_model_downloader() -> Result<()> {
ui::intro("PolyScribe model downloader");
// Build Select items for bases with counts and size ranges
let mut base_labels: Vec<String> = Vec::new();
for base in &ordered_bases {
let variants = &by_base[base];
@@ -904,7 +983,6 @@ pub fn run_interactive_model_downloader() -> Result<()> {
let base_idx = ui::prompt_select("Choose a model base", &base_refs)?;
let chosen_base = ordered_bases[base_idx].clone();
// Prepare variant list for chosen base
let mut variants = by_base.remove(&chosen_base).unwrap_or_default();
variants.sort_by(|a, b| {
let rank = |v: &str| match v {
@@ -917,7 +995,6 @@ pub fn run_interactive_model_downloader() -> Result<()> {
.then_with(|| a.variant.cmp(&b.variant))
});
// Build Multi-Select items for variants
let mut variant_labels: Vec<String> = Vec::new();
for m in &variants {
let size = format_size_mb(m.size.as_ref().copied());
@@ -953,7 +1030,6 @@ pub fn run_interactive_model_downloader() -> Result<()> {
ui::println_above_bars("Downloading selected models...");
// Setup multi-progress when multiple items are selected
let labels: Vec<String> = picks
.iter()
.map(|&i| {
@@ -961,12 +1037,12 @@ pub fn run_interactive_model_downloader() -> Result<()> {
format!("{} ({})", m.name, format_size_mb(m.size))
})
.collect();
let mut pm = ui::progress::ProgressManager::default_for_files(labels.len());
let mut pm = ui::progress::FileProgress::default_for_files(labels.len());
pm.init_files(&labels);
for (bar_idx, idx) in picks.into_iter().enumerate() {
let picked = variants[idx].clone();
pm.set_per_message(bar_idx, "downloading");
pm.set_file_message(bar_idx, "downloading");
let _path = ensure_model_available_noninteractive(&picked.name)?;
pm.mark_file_done(bar_idx);
ui::success(format!("Ready: {}", picked.name));
@@ -977,9 +1053,6 @@ pub fn run_interactive_model_downloader() -> Result<()> {
Ok(())
}
/// Verify/update local models by comparing with the online manifest.
/// - If a model file exists and matches expected size/hash (when provided), it is kept.
/// - If missing or mismatched, it will be downloaded.
pub fn update_local_models() -> Result<()> {
use crate::ui;
use std::collections::HashMap;
@@ -990,7 +1063,6 @@ pub fn update_local_models() -> Result<()> {
ui::info("Checking locally available models, then verifying against the online manifest…");
// Index manifest by filename and by stem/display name for matching.
let mut by_file: HashMap<String, ModelEntry> = HashMap::new();
let mut by_stem_or_name: HashMap<String, ModelEntry> = HashMap::new();
for m in manifest {
@@ -1007,7 +1079,6 @@ pub fn update_local_models() -> Result<()> {
let mut updated = 0usize;
let mut up_to_date = 0usize;
// Enumerate only local .bin files.
let rd = fs::read_dir(&dir).with_context(|| format!("reading models dir {}", dir.display()))?;
let entries: Vec<_> = rd.flatten().collect();
@@ -1034,7 +1105,6 @@ pub fn update_local_models() -> Result<()> {
let file_lc = file_name.to_ascii_lowercase();
let stem_lc = file_lc.strip_suffix(".bin").unwrap_or(&file_lc).to_string();
// Try to find a matching manifest entry for this local file.
let mut manifest_entry = by_file
.get(&file_lc)
.or_else(|| by_stem_or_name.get(&stem_lc))
@@ -1048,24 +1118,20 @@ pub fn update_local_models() -> Result<()> {
continue;
};
// Enrich metadata before verification (helps when API lacked size/hash)
let _ = enrich_entry_via_head(&mut m);
// Determine target filename from manifest; if different, download to the canonical name.
let target_path = if m.file.eq_ignore_ascii_case(&file_name) {
path.clone()
} else {
dir.join(&m.file)
};
// If the target already exists and matches (size/hash when available), it is up-to-date.
if target_path.exists() && file_matches(&target_path, m.size, m.sha256.as_deref())? {
crate::dlog!(1, "OK: {}", target_path.display());
up_to_date += 1;
continue;
}
// If the current file is the same as the target and mismatched, remove before re-download.
if target_path == path && target_path.exists() {
crate::ilog!("Updating {}", file_name);
let _ = fs::remove_file(&target_path);
@@ -1088,3 +1154,76 @@ pub fn update_local_models() -> Result<()> {
Ok(())
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::env;
    use std::sync::{Mutex, MutexGuard};

    /// Serializes the tests below that mutate process environment variables.
    /// The test harness runs tests on parallel threads, and concurrent
    /// environment mutation is exactly what makes `set_var`/`remove_var` unsafe.
    static ENV_LOCK: Mutex<()> = Mutex::new(());

    /// Acquires [`ENV_LOCK`], recovering it if an earlier test panicked while
    /// holding it so that one failure does not cascade into the others.
    fn lock_env() -> MutexGuard<'static, ()> {
        ENV_LOCK
            .lock()
            .unwrap_or_else(|poisoned| poisoned.into_inner())
    }

    #[test]
    fn test_cache_bypass_environment() {
        let _env = lock_env();

        // SAFETY: `ENV_LOCK` is held, so no other environment-mutating test in
        // this module runs concurrently. Assumes tests elsewhere in the crate
        // do not touch this variable.
        unsafe {
            env::remove_var(ConfigService::ENV_NO_CACHE_MANIFEST);
        }
        assert!(!should_bypass_cache());

        // SAFETY: see above; `ENV_LOCK` is still held.
        unsafe {
            env::set_var(ConfigService::ENV_NO_CACHE_MANIFEST, "1");
        }
        assert!(should_bypass_cache());

        // SAFETY: see above; `ENV_LOCK` is still held.
        unsafe {
            env::remove_var(ConfigService::ENV_NO_CACHE_MANIFEST);
        }
    }

    #[test]
    fn test_cache_ttl_environment() {
        let _env = lock_env();

        // SAFETY: `ENV_LOCK` is held, so no other environment-mutating test in
        // this module runs concurrently. Assumes tests elsewhere in the crate
        // do not touch this variable.
        unsafe {
            env::remove_var(ConfigService::ENV_MANIFEST_TTL_SECONDS);
        }
        assert_eq!(
            get_cache_ttl(),
            ConfigService::DEFAULT_MANIFEST_CACHE_TTL_SECONDS
        );

        // SAFETY: see above; `ENV_LOCK` is still held.
        unsafe {
            env::set_var(ConfigService::ENV_MANIFEST_TTL_SECONDS, "3600");
        }
        assert_eq!(get_cache_ttl(), 3600);

        // SAFETY: see above; `ENV_LOCK` is still held.
        unsafe {
            env::remove_var(ConfigService::ENV_MANIFEST_TTL_SECONDS);
        }
    }

    #[test]
    fn test_cached_manifest_serialization() {
        let entries = vec![ModelEntry {
            name: "test".to_string(),
            file: "test.bin".to_string(),
            url: "https://example.com/test.bin".to_string(),
            size: Some(1024),
            sha256: Some("abc123".to_string()),
            last_modified: Some("2023-01-01T00:00:00Z".to_string()),
            base: "test".to_string(),
            variant: "default".to_string(),
        }];

        let cached = CachedManifest {
            fetched_at: 1234567890,
            etag: Some("etag123".to_string()),
            last_modified: Some("2023-01-01T00:00:00Z".to_string()),
            entries: entries.clone(),
        };

        let json = serde_json::to_string(&cached).unwrap();
        let deserialized: CachedManifest = serde_json::from_str(&json).unwrap();

        assert_eq!(deserialized.fetched_at, cached.fetched_at);
        assert_eq!(deserialized.etag, cached.etag);
        assert_eq!(deserialized.last_modified, cached.last_modified);
        assert_eq!(deserialized.entries.len(), entries.len());

        // Every field of the entry must survive the round trip, not just the name.
        let (got, want) = (&deserialized.entries[0], &entries[0]);
        assert_eq!(got.name, want.name);
        assert_eq!(got.file, want.file);
        assert_eq!(got.url, want.url);
        assert_eq!(got.size, want.size);
        assert_eq!(got.sha256, want.sha256);
        assert_eq!(got.last_modified, want.last_modified);
        assert_eq!(got.base, want.base);
        assert_eq!(got.variant, want.variant);
    }
}