[refactor] rename and simplify ProgressManager to FileProgress, enhance caching logic, update Hugging Face API integration, and clean up unused comments
Some checks failed
CI / build (push) Has been cancelled
Some checks failed
CI / build (push) Has been cancelled
This commit is contained in:
@@ -1,9 +1,6 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
//! Model management for PolyScribe: discovery, download, and verification.
|
||||
//! Fetches the live file table from Hugging Face, using size and sha256
|
||||
//! data for verification. Falls back to scraping the repository tree page
|
||||
//! if the JSON API is unavailable or incomplete. No built-in manifest.
|
||||
|
||||
use crate::config::ConfigService;
|
||||
use crate::prelude::*;
|
||||
use anyhow::{Context, anyhow};
|
||||
use chrono::{DateTime, Utc};
|
||||
@@ -12,13 +9,13 @@ use reqwest::blocking::Client;
|
||||
use reqwest::header::{
|
||||
ACCEPT_RANGES, CONTENT_LENGTH, CONTENT_RANGE, ETAG, IF_RANGE, LAST_MODIFIED, RANGE,
|
||||
};
|
||||
use serde::Deserialize;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use sha2::{Digest, Sha256};
|
||||
use std::collections::BTreeSet;
|
||||
use std::fs::{self, File, OpenOptions};
|
||||
use std::io::{Read, Write};
|
||||
use std::path::{Path, PathBuf};
|
||||
use std::time::{Duration, Instant};
|
||||
use std::time::{Duration, Instant, SystemTime, UNIX_EPOCH};
|
||||
|
||||
fn format_size_mb(size: Option<u64>) -> String {
|
||||
match size {
|
||||
@@ -35,7 +32,6 @@ fn format_size_gib(bytes: u64) -> String {
|
||||
format!("{gib:.2} GiB")
|
||||
}
|
||||
|
||||
// Short date formatter (RFC -> yyyy-mm-dd)
|
||||
fn short_date(s: &str) -> String {
|
||||
DateTime::parse_from_rfc3339(s)
|
||||
.ok()
|
||||
@@ -43,12 +39,10 @@ fn short_date(s: &str) -> String {
|
||||
.unwrap_or_else(|| s.to_string())
|
||||
}
|
||||
|
||||
// Free disk space using libc::statvfs (already in Cargo)
|
||||
fn free_space_bytes_for_path(path: &Path) -> Result<u64> {
|
||||
use libc::statvfs;
|
||||
use std::ffi::CString;
|
||||
|
||||
// use parent dir or current dir if none
|
||||
let dir = if path.is_dir() {
|
||||
path
|
||||
} else {
|
||||
@@ -66,9 +60,7 @@ fn free_space_bytes_for_path(path: &Path) -> Result<u64> {
|
||||
}
|
||||
}
|
||||
|
||||
// Minimal mirror note shown in single-line style
|
||||
fn mirror_label(url: &str) -> &'static str {
|
||||
// Very light heuristic; replace with your actual mirror selection if you have it
|
||||
if url.contains("eu") {
|
||||
"EU mirror"
|
||||
} else if url.contains("us") {
|
||||
@@ -78,7 +70,6 @@ fn mirror_label(url: &str) -> &'static str {
|
||||
}
|
||||
}
|
||||
|
||||
// Perform a HEAD to get size/etag/last-modified and fill what we can
|
||||
/// Metadata collected by a HEAD probe, in the order `head_entry` returns it:
/// `(len, etag, last_mod, ranges_ok)`.
type HeadMeta = (Option<u64>, Option<String>, Option<String>, bool);
|
||||
|
||||
fn head_entry(client: &Client, url: &str) -> Result<HeadMeta> {
|
||||
@@ -107,39 +98,27 @@ fn head_entry(client: &Client, url: &str) -> Result<HeadMeta> {
|
||||
Ok((len, etag, last_mod, ranges_ok))
|
||||
}
|
||||
|
||||
/// Represents a downloadable Whisper model artifact.
|
||||
#[derive(Debug, Clone)]
|
||||
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
|
||||
struct ModelEntry {
|
||||
/// Display name and local short name (informational; may equal stem of file)
|
||||
name: String,
|
||||
/// Remote file name (with extension)
|
||||
file: String,
|
||||
/// Remote URL
|
||||
url: String,
|
||||
/// Expected file size (optional)
|
||||
size: Option<u64>,
|
||||
/// Expected SHA-256 in hex (optional)
|
||||
sha256: Option<String>,
|
||||
/// New: last modified timestamp string if available
|
||||
last_modified: Option<String>,
|
||||
/// New: parsed base and variant for 2-step UI
|
||||
base: String,
|
||||
variant: String,
|
||||
}
|
||||
|
||||
// -------- Hugging Face API integration --------
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
struct HfModelInfo {
|
||||
// Returned sometimes at /api/models/{repo}
|
||||
siblings: Option<Vec<HfFile>>,
|
||||
// Returned when using `?expand=files`
|
||||
files: Option<Vec<HfFile>>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
struct HfLfsInfo {
|
||||
// Sometimes an "oid" like "sha256:<hex>"
|
||||
oid: Option<String>,
|
||||
size: Option<u64>,
|
||||
sha256: Option<String>,
|
||||
@@ -147,53 +126,33 @@ struct HfLfsInfo {
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
struct HfFile {
|
||||
// Relative filename within repo (e.g., "ggml-tiny.bin")
|
||||
rfilename: String,
|
||||
// Size reported at top-level for non-LFS files; often present
|
||||
size: Option<u64>,
|
||||
// Some entries include sha256 at top level
|
||||
sha256: Option<String>,
|
||||
// LFS metadata with size and possibly sha256 embedded
|
||||
lfs: Option<HfLfsInfo>,
|
||||
// New: last modified timestamp provided by HF API on expanded files
|
||||
#[serde(rename = "lastModified")]
|
||||
last_modified: Option<String>,
|
||||
}
|
||||
|
||||
/// Splits a model display name into `(base, variant)` for the two-step UI.
///
/// `display_name` is the file name without the `ggml-`/`gguf-` prefix and
/// without the `.bin` suffix.
///
/// Rules:
/// - Everything after the first `.` is a dot suffix (`"tiny.en"` -> suffix `en`).
/// - If the part before the dot contains a `-`, the text before the first `-`
///   is the base and the rest is the version (`"large-v2"` -> `large` / `v2`).
/// - When both a version and a dot suffix are present they are joined as
///   `"<version>.<suffix>"`, so names such as `"large-v2"` and `"large-v2.en"`
///   no longer collapse onto the same `(base, variant)` pair.
/// - With neither, the variant is `"default"`.
///
/// Examples:
/// - `"tiny"`        -> `("tiny", "default")`
/// - `"tiny.en"`     -> `("tiny", "en")`
/// - `"large-v3"`    -> `("large", "v3")`
/// - `"large-v2.en"` -> `("large", "v2.en")`
fn parse_base_variant(display_name: &str) -> (String, String) {
    // Separate an optional dot suffix; only the first dot splits.
    let (head, dot_suffix) = match display_name.split_once('.') {
        Some((head, rest)) => (head, Some(rest)),
        None => (display_name, None),
    };

    match (head.split_once('-'), dot_suffix) {
        // Hyphenated version plus dot suffix: keep both pieces of information.
        (Some((base, version)), Some(suffix)) => {
            (base.to_string(), format!("{version}.{suffix}"))
        }
        (Some((base, version)), None) => (base.to_string(), version.to_string()),
        (None, Some(suffix)) => (head.to_string(), suffix.to_string()),
        (None, None) => (head.to_string(), "default".to_string()),
    }
}
|
||||
|
||||
/// Build a manifest by calling the Hugging Face API for a repo.
|
||||
/// Prefers the plain API URL, then retries with `?expand=files` if needed.
|
||||
fn hf_repo_manifest_api(repo: &str) -> Result<Vec<ModelEntry>> {
|
||||
let client = Client::builder().user_agent("polyscribe/0.1").build()?;
|
||||
let client = Client::builder()
|
||||
.user_agent(ConfigService::user_agent())
|
||||
.build()?;
|
||||
|
||||
// 1) Try the plain API you specified
|
||||
let base = format!("https://huggingface.co/api/models/{}", repo);
|
||||
let base = ConfigService::hf_api_base_for(repo);
|
||||
let resp = client.get(&base).send()?;
|
||||
let mut entries = if resp.status().is_success() {
|
||||
let info: HfModelInfo = resp.json()?;
|
||||
@@ -202,7 +161,6 @@ fn hf_repo_manifest_api(repo: &str) -> Result<Vec<ModelEntry>> {
|
||||
Vec::new()
|
||||
};
|
||||
|
||||
// 2) If empty, try with expand=files (some repos require this for full file listing)
|
||||
if entries.is_empty() {
|
||||
let url = format!("{base}?expand=files");
|
||||
let resp2 = client.get(&url).send()?;
|
||||
@@ -228,7 +186,6 @@ fn hf_info_to_entries(repo: &str, info: HfModelInfo) -> Result<Vec<ModelEntry>>
|
||||
continue;
|
||||
}
|
||||
|
||||
// Derive a simple display name from the file stem
|
||||
let stem = fname.strip_suffix(".bin").unwrap_or(&fname).to_string();
|
||||
let name_no_prefix = stem
|
||||
.strip_prefix("ggml-")
|
||||
@@ -236,7 +193,6 @@ fn hf_info_to_entries(repo: &str, info: HfModelInfo) -> Result<Vec<ModelEntry>>
|
||||
.unwrap_or(&stem)
|
||||
.to_string();
|
||||
|
||||
// Prefer explicit sha256; else try to parse from LFS oid "sha256:<hex>"
|
||||
let sha_from_lfs = f.lfs.as_ref().and_then(|l| {
|
||||
l.sha256.clone().or_else(|| {
|
||||
l.oid
|
||||
@@ -268,12 +224,11 @@ fn hf_info_to_entries(repo: &str, info: HfModelInfo) -> Result<Vec<ModelEntry>>
|
||||
Ok(out)
|
||||
}
|
||||
|
||||
// -------- HTML scraping fallback (tree view) --------
|
||||
|
||||
/// Scrape the repository tree page when the API doesn't return a usable list.
|
||||
/// Note: sizes and hashes are generally unavailable in this path.
|
||||
fn scrape_tree_manifest(repo: &str) -> Result<Vec<ModelEntry>> {
|
||||
let client = Client::builder().user_agent("polyscribe/0.1").build()?;
|
||||
let client = Client::builder()
|
||||
.user_agent(ConfigService::user_agent())
|
||||
.build()?;
|
||||
|
||||
let url = format!("https://huggingface.co/{}/tree/main?recursive=1", repo);
|
||||
let resp = client.get(&url).send()?;
|
||||
@@ -282,10 +237,6 @@ fn scrape_tree_manifest(repo: &str) -> Result<Vec<ModelEntry>> {
|
||||
}
|
||||
let html = resp.text()?;
|
||||
|
||||
// Extract .bin paths from links. Match both blob/main and resolve/main.
|
||||
// Example matches:
|
||||
// - /{repo}/blob/main/ggml-base.en.bin
|
||||
// - /{repo}/resolve/main/ggml-base.en.bin
|
||||
let mut files = BTreeSet::new();
|
||||
for mat in html.match_indices(".bin") {
|
||||
let end = mat.0 + 4;
|
||||
@@ -346,13 +297,8 @@ fn scrape_tree_manifest(repo: &str) -> Result<Vec<ModelEntry>> {
|
||||
Ok(out)
|
||||
}
|
||||
|
||||
// -------- Metadata enrichment via HEAD (size/hash/last-modified) --------
|
||||
|
||||
fn parse_sha_from_header_value(s: &str) -> Option<String> {
|
||||
// Common HF patterns:
|
||||
// - ETag: "SHA256:<hex>"
|
||||
// - X-Linked-ETag: "SHA256:<hex>"
|
||||
// - Sometimes weak etags: W/"SHA256:<hex>"
|
||||
let lower = s.to_ascii_lowercase();
|
||||
if let Some(idx) = lower.find("sha256:") {
|
||||
let tail = &lower[idx + "sha256:".len()..];
|
||||
@@ -365,14 +311,13 @@ fn parse_sha_from_header_value(s: &str) -> Option<String> {
|
||||
}
|
||||
|
||||
fn enrich_entry_via_head(entry: &mut ModelEntry) -> Result<()> {
|
||||
// If we already have everything, nothing to do
|
||||
if entry.size.is_some() && entry.sha256.is_some() && entry.last_modified.is_some() {
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
let client = Client::builder()
|
||||
.user_agent("polyscribe/0.1")
|
||||
.timeout(Duration::from_secs(8))
|
||||
.user_agent(ConfigService::user_agent())
|
||||
.timeout(Duration::from_secs(ConfigService::http_timeout_secs()))
|
||||
.build()?;
|
||||
|
||||
let mut head_url = entry.url.clone();
|
||||
@@ -397,7 +342,6 @@ fn enrich_entry_via_head(entry: &mut ModelEntry) -> Result<()> {
|
||||
let mut filled_sha = false;
|
||||
let mut filled_lm = false;
|
||||
|
||||
// Content-Length
|
||||
if entry.size.is_none()
|
||||
&& let Some(sz) = resp
|
||||
.headers()
|
||||
@@ -409,7 +353,6 @@ fn enrich_entry_via_head(entry: &mut ModelEntry) -> Result<()> {
|
||||
filled_size = true;
|
||||
}
|
||||
|
||||
// SHA256 from headers if available
|
||||
if entry.sha256.is_none() {
|
||||
let _ = resp
|
||||
.headers()
|
||||
@@ -433,7 +376,6 @@ fn enrich_entry_via_head(entry: &mut ModelEntry) -> Result<()> {
|
||||
}
|
||||
}
|
||||
|
||||
// Last-Modified
|
||||
if entry.last_modified.is_none() {
|
||||
let _ = resp
|
||||
.headers()
|
||||
@@ -477,28 +419,204 @@ fn enrich_entry_via_head(entry: &mut ModelEntry) -> Result<()> {
|
||||
Ok(())
|
||||
}
|
||||
|
||||
// -------- Online manifest (API first, then scrape) --------
|
||||
#[derive(Debug, Serialize, Deserialize)]
|
||||
struct CachedManifest {
|
||||
fetched_at: u64,
|
||||
etag: Option<String>,
|
||||
last_modified: Option<String>,
|
||||
entries: Vec<ModelEntry>,
|
||||
}
|
||||
|
||||
fn get_cache_dir() -> Result<PathBuf> {
|
||||
Ok(ConfigService::manifest_cache_dir()
|
||||
.ok_or_else(|| anyhow!("could not determine platform directories"))?)
|
||||
}
|
||||
|
||||
fn get_cached_manifest_path() -> Result<PathBuf> {
|
||||
let cache_dir = get_cache_dir()?;
|
||||
Ok(cache_dir.join(ConfigService::manifest_cache_filename()))
|
||||
}
|
||||
|
||||
fn should_bypass_cache() -> bool {
|
||||
ConfigService::bypass_manifest_cache()
|
||||
}
|
||||
|
||||
fn get_cache_ttl() -> u64 {
|
||||
ConfigService::manifest_cache_ttl_seconds()
|
||||
}
|
||||
|
||||
fn load_cached_manifest() -> Option<CachedManifest> {
|
||||
if should_bypass_cache() {
|
||||
return None;
|
||||
}
|
||||
|
||||
let cache_path = get_cached_manifest_path().ok()?;
|
||||
if !cache_path.exists() {
|
||||
return None;
|
||||
}
|
||||
|
||||
let cache_file = File::open(cache_path).ok()?;
|
||||
let cached: CachedManifest = serde_json::from_reader(cache_file).ok()?;
|
||||
|
||||
let now = SystemTime::now().duration_since(UNIX_EPOCH).ok()?.as_secs();
|
||||
|
||||
let ttl = get_cache_ttl();
|
||||
if now.saturating_sub(cached.fetched_at) > ttl {
|
||||
crate::dlog!(
|
||||
1,
|
||||
"Cache expired (age: {}s, TTL: {}s)",
|
||||
now.saturating_sub(cached.fetched_at),
|
||||
ttl
|
||||
);
|
||||
return None;
|
||||
}
|
||||
|
||||
crate::dlog!(
|
||||
1,
|
||||
"Using cached manifest (age: {}s)",
|
||||
now.saturating_sub(cached.fetched_at)
|
||||
);
|
||||
Some(cached)
|
||||
}
|
||||
|
||||
fn save_manifest_to_cache(
|
||||
entries: &[ModelEntry],
|
||||
etag: Option<&str>,
|
||||
last_modified: Option<&str>,
|
||||
) -> Result<()> {
|
||||
if should_bypass_cache() {
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
let cache_dir = get_cache_dir()?;
|
||||
fs::create_dir_all(&cache_dir)?;
|
||||
|
||||
let cache_path = get_cached_manifest_path()?;
|
||||
let now = SystemTime::now()
|
||||
.duration_since(UNIX_EPOCH)
|
||||
.map_err(|_| anyhow!("system time error"))?
|
||||
.as_secs();
|
||||
|
||||
let cached = CachedManifest {
|
||||
fetched_at: now,
|
||||
etag: etag.map(|s| s.to_string()),
|
||||
last_modified: last_modified.map(|s| s.to_string()),
|
||||
entries: entries.to_vec(),
|
||||
};
|
||||
|
||||
let cache_file = OpenOptions::new()
|
||||
.create(true)
|
||||
.write(true)
|
||||
.truncate(true)
|
||||
.open(&cache_path)
|
||||
.with_context(|| format!("opening cache file {}", cache_path.display()))?;
|
||||
|
||||
serde_json::to_writer_pretty(cache_file, &cached)
|
||||
.with_context(|| "serializing cached manifest")?;
|
||||
|
||||
crate::dlog!(1, "Saved manifest to cache: {} entries", entries.len());
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn fetch_manifest_with_cache() -> Result<Vec<ModelEntry>> {
|
||||
let cached = load_cached_manifest();
|
||||
|
||||
let client = Client::builder()
|
||||
.user_agent(ConfigService::user_agent())
|
||||
.build()?;
|
||||
let repo = ConfigService::hf_repo();
|
||||
let base_url = ConfigService::hf_api_base_for(&repo);
|
||||
|
||||
let mut req = client.get(&base_url);
|
||||
if let Some(ref cached) = cached {
|
||||
if let Some(ref etag) = cached.etag {
|
||||
req = req.header("If-None-Match", format!("\"{}\"", etag));
|
||||
} else if let Some(ref last_mod) = cached.last_modified {
|
||||
req = req.header("If-Modified-Since", last_mod);
|
||||
}
|
||||
}
|
||||
|
||||
let resp = req.send()?;
|
||||
|
||||
if resp.status().as_u16() == 304 {
|
||||
if let Some(cached) = cached {
|
||||
crate::dlog!(1, "Manifest not modified, using cache");
|
||||
return Ok(cached.entries);
|
||||
}
|
||||
}
|
||||
|
||||
if !resp.status().is_success() {
|
||||
return Err(anyhow!("HF API {} for {}", resp.status(), base_url).into());
|
||||
}
|
||||
|
||||
let etag = resp
|
||||
.headers()
|
||||
.get(ETAG)
|
||||
.and_then(|v| v.to_str().ok())
|
||||
.map(|s| s.trim_matches('"').to_string());
|
||||
|
||||
let last_modified = resp
|
||||
.headers()
|
||||
.get(LAST_MODIFIED)
|
||||
.and_then(|v| v.to_str().ok())
|
||||
.map(|s| s.to_string());
|
||||
|
||||
let info: HfModelInfo = resp.json()?;
|
||||
let mut entries = hf_info_to_entries(&repo, info)?;
|
||||
|
||||
if entries.is_empty() {
|
||||
let url = format!("{}?expand=files", base_url);
|
||||
let resp2 = client.get(&url).send()?;
|
||||
if !resp2.status().is_success() {
|
||||
return Err(anyhow!("HF API {} for {}", resp2.status(), url).into());
|
||||
}
|
||||
let info: HfModelInfo = resp2.json()?;
|
||||
entries = hf_info_to_entries(&repo, info)?;
|
||||
}
|
||||
|
||||
if entries.is_empty() {
|
||||
return Err(anyhow!("HF API returned no usable .bin files").into());
|
||||
}
|
||||
|
||||
let _ = save_manifest_to_cache(&entries, etag.as_deref(), last_modified.as_deref());
|
||||
|
||||
Ok(entries)
|
||||
}
|
||||
|
||||
/// Returns the current manifest (online only).
|
||||
fn current_manifest() -> Result<Vec<ModelEntry>> {
|
||||
let started = Instant::now();
|
||||
crate::dlog!(1, "Fetching HF manifest…");
|
||||
|
||||
// 1) Load from API, else scrape
|
||||
let mut list = match hf_repo_manifest_api("ggerganov/whisper.cpp") {
|
||||
let mut list = match fetch_manifest_with_cache() {
|
||||
Ok(list) if !list.is_empty() => {
|
||||
crate::dlog!(1, "Manifest loaded from HF API ({} entries)", list.len());
|
||||
crate::dlog!(
|
||||
1,
|
||||
"Manifest loaded from HF API with cache ({} entries)",
|
||||
list.len()
|
||||
);
|
||||
list
|
||||
}
|
||||
_ => {
|
||||
crate::ilog!("Falling back to scraping the repository tree page");
|
||||
let scraped = scrape_tree_manifest("ggerganov/whisper.cpp")?;
|
||||
crate::dlog!(1, "Manifest loaded via scrape ({} entries)", scraped.len());
|
||||
scraped
|
||||
crate::ilog!("Cache failed, falling back to direct API");
|
||||
let repo = ConfigService::hf_repo();
|
||||
let list = match hf_repo_manifest_api(&repo) {
|
||||
Ok(list) if !list.is_empty() => {
|
||||
crate::dlog!(1, "Manifest loaded from HF API ({} entries)", list.len());
|
||||
list
|
||||
}
|
||||
_ => {
|
||||
crate::ilog!("Falling back to scraping the repository tree page");
|
||||
let scraped = scrape_tree_manifest(&repo)?;
|
||||
crate::dlog!(1, "Manifest loaded via scrape ({} entries)", scraped.len());
|
||||
scraped
|
||||
}
|
||||
};
|
||||
|
||||
let _ = save_manifest_to_cache(&list, None, None);
|
||||
list
|
||||
}
|
||||
};
|
||||
|
||||
// 2) Enrich missing metadata so the UI can show sizes and hashes
|
||||
let mut need_enrich = 0usize;
|
||||
for m in &list {
|
||||
if m.size.is_none() || m.sha256.is_none() || m.last_modified.is_none() {
|
||||
@@ -532,8 +650,6 @@ fn current_manifest() -> Result<Vec<ModelEntry>> {
|
||||
Ok(list)
|
||||
}
|
||||
|
||||
/// Pick the best local Whisper model in the given directory.
|
||||
/// Heuristic: choose the largest .bin file by size. Returns None if none found.
|
||||
pub fn pick_best_local_model(dir: &Path) -> Option<PathBuf> {
|
||||
let rd = fs::read_dir(dir).ok()?;
|
||||
rd.flatten()
|
||||
@@ -549,39 +665,23 @@ pub fn pick_best_local_model(dir: &Path) -> Option<PathBuf> {
|
||||
.map(|(_, p)| p)
|
||||
}
|
||||
|
||||
/// Returns the directory where models should be stored based on platform conventions.
|
||||
fn resolve_models_dir() -> Result<PathBuf> {
|
||||
let dirs = directories::ProjectDirs::from("dev", "polyscribe", "polyscribe")
|
||||
.ok_or_else(|| anyhow!("could not determine platform directories"))?;
|
||||
let data_dir = dirs.data_dir().join("models");
|
||||
Ok(data_dir)
|
||||
Ok(ConfigService::models_dir(None)
|
||||
.ok_or_else(|| anyhow!("could not determine models directory"))?)
|
||||
}
|
||||
|
||||
// Example of a non-interactive path ensuring a given model by name exists, with improved copy.
|
||||
// Wire this into CLI flags as needed.
|
||||
/// Ensures a model is available by name, downloading it if necessary.
|
||||
/// This is a non-interactive version that doesn't prompt the user.
|
||||
///
|
||||
/// # Arguments
|
||||
/// * `name` - Name of the model to ensure is available
|
||||
///
|
||||
/// # Returns
|
||||
/// * `Result<PathBuf>` - Path to the downloaded model file on success
|
||||
pub fn ensure_model_available_noninteractive(name: &str) -> Result<PathBuf> {
|
||||
let entry = find_manifest_entry(name)?.ok_or_else(|| anyhow!("unknown model: {name}"))?;
|
||||
|
||||
// Resolve destination file path; ensure XDG path (or your existing logic)
|
||||
let dir = resolve_models_dir()?; // implement or reuse your existing directory resolver
|
||||
let dir = resolve_models_dir()?;
|
||||
fs::create_dir_all(&dir).ok();
|
||||
let dest = dir.join(&entry.file);
|
||||
|
||||
// If already matches, early return
|
||||
if file_matches(&dest, entry.size, entry.sha256.as_deref())? {
|
||||
crate::ui::info(format!("Already up to date: {}", dest.display()));
|
||||
return Ok(dest);
|
||||
}
|
||||
|
||||
// Single-line header
|
||||
let base = &entry.base;
|
||||
let variant = &entry.variant;
|
||||
let size_str = format_size_mb(entry.size);
|
||||
@@ -596,9 +696,16 @@ pub fn ensure_model_available_noninteractive(name: &str) -> Result<PathBuf> {
|
||||
Ok(dest)
|
||||
}
|
||||
|
||||
pub fn clear_manifest_cache() -> Result<()> {
|
||||
let cache_path = get_cached_manifest_path()?;
|
||||
if cache_path.exists() {
|
||||
fs::remove_file(&cache_path)?;
|
||||
crate::dlog!(1, "Cleared manifest cache");
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn find_manifest_entry(name: &str) -> Result<Option<ModelEntry>> {
|
||||
// Accept either manifest display name, file stem, or direct file name.
|
||||
// Normalize: strip ".bin" for comparisons and also handle input that already includes it.
|
||||
let wanted_name = name
|
||||
.strip_suffix(".bin")
|
||||
.unwrap_or(name)
|
||||
@@ -622,10 +729,6 @@ fn find_manifest_entry(name: &str) -> Result<Option<ModelEntry>> {
|
||||
Ok(None)
|
||||
}
|
||||
|
||||
// Return true if the file at `path` matches expected size and/or sha256 (when provided).
|
||||
// - If sha256 is provided, verify it (preferred).
|
||||
// - Else if size is provided, check size.
|
||||
// - If neither provided, return false (cannot verify).
|
||||
fn file_matches(path: &Path, size: Option<u64>, sha256_hex: Option<&str>) -> Result<bool> {
|
||||
if !path.exists() {
|
||||
return Ok(false);
|
||||
@@ -655,21 +758,14 @@ fn file_matches(path: &Path, size: Option<u64>, sha256_hex: Option<&str>) -> Res
|
||||
Ok(false)
|
||||
}
|
||||
|
||||
// Download with:
|
||||
// - Free-space preflight (size * 1.1 overhead).
|
||||
// - Resume via Range if .part exists and server supports it.
|
||||
// - Atomic write: download to .part (temp) then rename.
|
||||
// - Checksum verification when available.
|
||||
// - Single-line progress UI.
|
||||
fn download_with_progress(dest_path: &Path, entry: &ModelEntry) -> Result<()> {
|
||||
let url = &entry.url;
|
||||
let client = Client::builder()
|
||||
.user_agent("polyscribe-model-downloader/1")
|
||||
.user_agent(ConfigService::downloader_user_agent())
|
||||
.build()?;
|
||||
|
||||
crate::ui::info(format!("Resolving source: {} ({})", mirror_label(url), url));
|
||||
|
||||
// HEAD for size/etag/ranges
|
||||
let (mut total_len, remote_etag, _remote_last_mod, ranges_ok) =
|
||||
head_entry(&client, url).context("probing remote file")?;
|
||||
|
||||
@@ -710,9 +806,6 @@ fn download_with_progress(dest_path: &Path, entry: &ModelEntry) -> Result<()> {
|
||||
.open(&part_path)
|
||||
.with_context(|| format!("opening {}", part_path.display()))?;
|
||||
|
||||
// Build request:
|
||||
// - Fresh download: plain GET (no If-None-Match).
|
||||
// - Resume: Range + optional If-Range with ETag.
|
||||
let mut req = client.get(url);
|
||||
if ranges_ok && resume_from > 0 {
|
||||
req = req.header(RANGE, format!("bytes={resume_from}-"));
|
||||
@@ -729,30 +822,21 @@ fn download_with_progress(dest_path: &Path, entry: &ModelEntry) -> Result<()> {
|
||||
let start = Instant::now();
|
||||
let mut resp = req.send()?.error_for_status()?;
|
||||
|
||||
// Defensive: if server returns 304 but we don't have a valid cached copy, retry without conditionals.
|
||||
if resp.status().as_u16() == 304 && resume_from == 0 {
|
||||
// Fresh download must not be conditional; redo as plain GET
|
||||
let req2 = client.get(url);
|
||||
resp = req2.send()?.error_for_status()?;
|
||||
}
|
||||
|
||||
// If server ignored RANGE and returned full body, reset partial
|
||||
let is_partial_response = resp.headers().get(CONTENT_RANGE).is_some();
|
||||
if resume_from > 0 && !is_partial_response {
|
||||
// Server did not honor range → start over
|
||||
drop(part_file);
|
||||
fs::remove_file(&part_path).ok();
|
||||
// Reset local accounting; we also reinitialize the progress bar below
|
||||
// and reopen the part file. No need to re-read this variable afterwards.
|
||||
let _ = 0; // avoid unused-assignment lint for resume_from
|
||||
|
||||
// Plain GET without conditional headers
|
||||
let req2 = client.get(url);
|
||||
resp = req2.send()?.error_for_status()?;
|
||||
bar.stop("restarting");
|
||||
bar = crate::ui::BytesProgress::start(pb_total, "Downloading", 0);
|
||||
|
||||
// Reopen the part file since we dropped it
|
||||
part_file = OpenOptions::new()
|
||||
.create(true)
|
||||
.read(true)
|
||||
@@ -842,10 +926,6 @@ fn download_with_progress(dest_path: &Path, entry: &ModelEntry) -> Result<()> {
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Run an interactive model downloader UI (2-step):
|
||||
/// 1) Choose model base (tiny, small, base, medium, large)
|
||||
/// 2) Choose model type/variant specific to that base
|
||||
/// Displays meta info (size and last updated). Does not show raw ggml filenames.
|
||||
pub fn run_interactive_model_downloader() -> Result<()> {
|
||||
use crate::ui;
|
||||
|
||||
@@ -877,7 +957,6 @@ pub fn run_interactive_model_downloader() -> Result<()> {
|
||||
|
||||
ui::intro("PolyScribe model downloader");
|
||||
|
||||
// Build Select items for bases with counts and size ranges
|
||||
let mut base_labels: Vec<String> = Vec::new();
|
||||
for base in &ordered_bases {
|
||||
let variants = &by_base[base];
|
||||
@@ -904,7 +983,6 @@ pub fn run_interactive_model_downloader() -> Result<()> {
|
||||
let base_idx = ui::prompt_select("Choose a model base", &base_refs)?;
|
||||
let chosen_base = ordered_bases[base_idx].clone();
|
||||
|
||||
// Prepare variant list for chosen base
|
||||
let mut variants = by_base.remove(&chosen_base).unwrap_or_default();
|
||||
variants.sort_by(|a, b| {
|
||||
let rank = |v: &str| match v {
|
||||
@@ -917,7 +995,6 @@ pub fn run_interactive_model_downloader() -> Result<()> {
|
||||
.then_with(|| a.variant.cmp(&b.variant))
|
||||
});
|
||||
|
||||
// Build Multi-Select items for variants
|
||||
let mut variant_labels: Vec<String> = Vec::new();
|
||||
for m in &variants {
|
||||
let size = format_size_mb(m.size.as_ref().copied());
|
||||
@@ -953,7 +1030,6 @@ pub fn run_interactive_model_downloader() -> Result<()> {
|
||||
|
||||
ui::println_above_bars("Downloading selected models...");
|
||||
|
||||
// Setup multi-progress when multiple items are selected
|
||||
let labels: Vec<String> = picks
|
||||
.iter()
|
||||
.map(|&i| {
|
||||
@@ -961,12 +1037,12 @@ pub fn run_interactive_model_downloader() -> Result<()> {
|
||||
format!("{} ({})", m.name, format_size_mb(m.size))
|
||||
})
|
||||
.collect();
|
||||
let mut pm = ui::progress::ProgressManager::default_for_files(labels.len());
|
||||
let mut pm = ui::progress::FileProgress::default_for_files(labels.len());
|
||||
pm.init_files(&labels);
|
||||
|
||||
for (bar_idx, idx) in picks.into_iter().enumerate() {
|
||||
let picked = variants[idx].clone();
|
||||
pm.set_per_message(bar_idx, "downloading");
|
||||
pm.set_file_message(bar_idx, "downloading");
|
||||
let _path = ensure_model_available_noninteractive(&picked.name)?;
|
||||
pm.mark_file_done(bar_idx);
|
||||
ui::success(format!("Ready: {}", picked.name));
|
||||
@@ -977,9 +1053,6 @@ pub fn run_interactive_model_downloader() -> Result<()> {
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Verify/update local models by comparing with the online manifest.
|
||||
/// - If a model file exists and matches expected size/hash (when provided), it is kept.
|
||||
/// - If missing or mismatched, it will be downloaded.
|
||||
pub fn update_local_models() -> Result<()> {
|
||||
use crate::ui;
|
||||
use std::collections::HashMap;
|
||||
@@ -990,7 +1063,6 @@ pub fn update_local_models() -> Result<()> {
|
||||
|
||||
ui::info("Checking locally available models, then verifying against the online manifest…");
|
||||
|
||||
// Index manifest by filename and by stem/display name for matching.
|
||||
let mut by_file: HashMap<String, ModelEntry> = HashMap::new();
|
||||
let mut by_stem_or_name: HashMap<String, ModelEntry> = HashMap::new();
|
||||
for m in manifest {
|
||||
@@ -1007,7 +1079,6 @@ pub fn update_local_models() -> Result<()> {
|
||||
let mut updated = 0usize;
|
||||
let mut up_to_date = 0usize;
|
||||
|
||||
// Enumerate only local .bin files.
|
||||
let rd = fs::read_dir(&dir).with_context(|| format!("reading models dir {}", dir.display()))?;
|
||||
let entries: Vec<_> = rd.flatten().collect();
|
||||
|
||||
@@ -1034,7 +1105,6 @@ pub fn update_local_models() -> Result<()> {
|
||||
let file_lc = file_name.to_ascii_lowercase();
|
||||
let stem_lc = file_lc.strip_suffix(".bin").unwrap_or(&file_lc).to_string();
|
||||
|
||||
// Try to find a matching manifest entry for this local file.
|
||||
let mut manifest_entry = by_file
|
||||
.get(&file_lc)
|
||||
.or_else(|| by_stem_or_name.get(&stem_lc))
|
||||
@@ -1048,24 +1118,20 @@ pub fn update_local_models() -> Result<()> {
|
||||
continue;
|
||||
};
|
||||
|
||||
// Enrich metadata before verification (helps when API lacked size/hash)
|
||||
let _ = enrich_entry_via_head(&mut m);
|
||||
|
||||
// Determine target filename from manifest; if different, download to the canonical name.
|
||||
let target_path = if m.file.eq_ignore_ascii_case(&file_name) {
|
||||
path.clone()
|
||||
} else {
|
||||
dir.join(&m.file)
|
||||
};
|
||||
|
||||
// If the target already exists and matches (size/hash when available), it is up-to-date.
|
||||
if target_path.exists() && file_matches(&target_path, m.size, m.sha256.as_deref())? {
|
||||
crate::dlog!(1, "OK: {}", target_path.display());
|
||||
up_to_date += 1;
|
||||
continue;
|
||||
}
|
||||
|
||||
// If the current file is the same as the target and mismatched, remove before re-download.
|
||||
if target_path == path && target_path.exists() {
|
||||
crate::ilog!("Updating {}", file_name);
|
||||
let _ = fs::remove_file(&target_path);
|
||||
@@ -1088,3 +1154,76 @@ pub fn update_local_models() -> Result<()> {
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;
    use std::env;
    use std::sync::{Mutex, MutexGuard};

    /// Serialises the tests in this module that mutate process environment
    /// variables; the test harness runs tests on several threads.
    static ENV_LOCK: Mutex<()> = Mutex::new(());

    /// Takes `ENV_LOCK`, recovering the guard if an earlier test panicked
    /// while holding it so that one failure does not cascade.
    fn lock_env() -> MutexGuard<'static, ()> {
        ENV_LOCK
            .lock()
            .unwrap_or_else(|poisoned| poisoned.into_inner())
    }

    /// The bypass flag follows the `ENV_NO_CACHE_MANIFEST` variable.
    #[test]
    fn test_cache_bypass_environment() {
        let _guard = lock_env();

        // SAFETY: environment mutation in this module happens only while
        // ENV_LOCK is held. NOTE(review): tests elsewhere in the crate that
        // touch the environment are not covered by this lock.
        unsafe {
            env::remove_var(ConfigService::ENV_NO_CACHE_MANIFEST);
        }
        assert!(!should_bypass_cache());

        // SAFETY: see above; ENV_LOCK is held.
        unsafe {
            env::set_var(ConfigService::ENV_NO_CACHE_MANIFEST, "1");
        }
        assert!(should_bypass_cache());

        // SAFETY: see above; ENV_LOCK is held.
        unsafe {
            env::remove_var(ConfigService::ENV_NO_CACHE_MANIFEST);
        }
    }

    /// The TTL falls back to the default and follows `ENV_MANIFEST_TTL_SECONDS`.
    #[test]
    fn test_cache_ttl_environment() {
        let _guard = lock_env();

        // SAFETY: environment mutation in this module happens only while
        // ENV_LOCK is held. NOTE(review): tests elsewhere in the crate that
        // touch the environment are not covered by this lock.
        unsafe {
            env::remove_var(ConfigService::ENV_MANIFEST_TTL_SECONDS);
        }
        assert_eq!(
            get_cache_ttl(),
            ConfigService::DEFAULT_MANIFEST_CACHE_TTL_SECONDS
        );

        // SAFETY: see above; ENV_LOCK is held.
        unsafe {
            env::set_var(ConfigService::ENV_MANIFEST_TTL_SECONDS, "3600");
        }
        assert_eq!(get_cache_ttl(), 3600);

        // SAFETY: see above; ENV_LOCK is held.
        unsafe {
            env::remove_var(ConfigService::ENV_MANIFEST_TTL_SECONDS);
        }
    }

    /// A `CachedManifest` survives a JSON round trip unchanged.
    #[test]
    fn test_cached_manifest_serialization() {
        let entries = vec![ModelEntry {
            name: "test".to_string(),
            file: "test.bin".to_string(),
            url: "https://example.com/test.bin".to_string(),
            size: Some(1024),
            sha256: Some("abc123".to_string()),
            last_modified: Some("2023-01-01T00:00:00Z".to_string()),
            base: "test".to_string(),
            variant: "default".to_string(),
        }];

        let cached = CachedManifest {
            fetched_at: 1234567890,
            etag: Some("etag123".to_string()),
            last_modified: Some("2023-01-01T00:00:00Z".to_string()),
            entries: entries.clone(),
        };

        let json = serde_json::to_string(&cached).unwrap();
        let deserialized: CachedManifest = serde_json::from_str(&json).unwrap();

        assert_eq!(deserialized.fetched_at, cached.fetched_at);
        assert_eq!(deserialized.etag, cached.etag);
        assert_eq!(deserialized.last_modified, cached.last_modified);
        assert_eq!(deserialized.entries.len(), entries.len());
        assert_eq!(deserialized.entries[0].name, entries[0].name);
    }
}
|
||||
|
||||
Reference in New Issue
Block a user