// SPDX-License-Identifier: MIT use crate::config::ConfigService; use crate::prelude::*; use anyhow::{Context, anyhow}; use chrono::{DateTime, Utc}; use hex::ToHex; use reqwest::blocking::Client; use reqwest::header::{ ACCEPT_RANGES, CONTENT_LENGTH, CONTENT_RANGE, ETAG, IF_RANGE, LAST_MODIFIED, RANGE, }; use serde::{Deserialize, Serialize}; use sha2::{Digest, Sha256}; use std::collections::BTreeSet; use std::fs::{self, File, OpenOptions}; use std::io::{Read, Write}; use std::path::{Path, PathBuf}; use std::time::{Duration, Instant, SystemTime, UNIX_EPOCH}; fn format_size_mb(size: Option) -> String { match size { Some(bytes) => { let mib = bytes as f64 / 1024.0 / 1024.0; format!("{mib:.2} MiB") } None => "? MiB".to_string(), } } fn format_size_gib(bytes: u64) -> String { let gib = bytes as f64 / 1024.0 / 1024.0 / 1024.0; format!("{gib:.2} GiB") } fn short_date(s: &str) -> String { DateTime::parse_from_rfc3339(s) .ok() .map(|dt| dt.with_timezone(&Utc).format("%Y-%m-%d").to_string()) .unwrap_or_else(|| s.to_string()) } fn free_space_bytes_for_path(path: &Path) -> Result { use libc::statvfs; use std::ffi::CString; let dir = if path.is_dir() { path } else { path.parent().unwrap_or(Path::new(".")) }; let cpath = CString::new(dir.as_os_str().to_string_lossy().as_bytes()) .map_err(|_| anyhow!("invalid path for statvfs"))?; unsafe { let mut s: libc::statvfs = std::mem::zeroed(); if statvfs(cpath.as_ptr(), &mut s) != 0 { return Err(anyhow!("statvfs failed for {}", dir.display()).into()); } Ok((s.f_bsize as u64) * (s.f_bavail as u64)) } } fn mirror_label(url: &str) -> &'static str { if url.contains("eu") { "EU mirror" } else if url.contains("us") { "US mirror" } else { "source" } } type HeadMeta = (Option, Option, Option, bool); fn head_entry(client: &Client, url: &str) -> Result { let resp = client.head(url).send()?.error_for_status()?; let len = resp .headers() .get(CONTENT_LENGTH) .and_then(|v| v.to_str().ok()) .and_then(|s| s.parse::().ok()); let etag = resp .headers() .get(ETAG) .and_then(|v| v.to_str().ok()) .map(|s| s.trim_matches('"').to_string()); let last_mod = resp .headers() .get(LAST_MODIFIED) .and_then(|v| v.to_str().ok()) .map(|s| s.to_string()); let ranges_ok = resp .headers() .get(ACCEPT_RANGES) .and_then(|v| v.to_str().ok()) .map(|s| s.to_ascii_lowercase().contains("bytes")) .unwrap_or(false); Ok((len, etag, last_mod, ranges_ok)) } #[derive(Debug, Clone, serde::Serialize, serde::Deserialize)] struct ModelEntry { name: String, file: String, url: String, size: Option, sha256: Option, last_modified: Option, base: String, variant: String, } #[derive(Debug, Deserialize)] struct HfModelInfo { siblings: Option>, files: Option>, } #[derive(Debug, Deserialize)] struct HfLfsInfo { oid: Option, size: Option, sha256: Option, } #[derive(Debug, Deserialize)] struct HfFile { rfilename: String, size: Option, sha256: Option, lfs: Option, #[serde(rename = "lastModified")] last_modified: Option, } fn parse_base_variant(display_name: &str) -> (String, String) { let mut variant = "default".to_string(); let mut head = display_name; if let Some((h, rest)) = display_name.split_once('.') { head = h; variant = rest.to_string(); } if let Some((b, v)) = head.split_once('-') { return (b.to_string(), v.to_string()); } (head.to_string(), variant) } fn hf_repo_manifest_api(repo: &str) -> Result> { let client = Client::builder() .user_agent(ConfigService::user_agent()) .build()?; let base = ConfigService::hf_api_base_for(repo); let resp = client.get(&base).send()?; let mut entries = if resp.status().is_success() { let info: HfModelInfo = resp.json()?; hf_info_to_entries(repo, info)? } else { Vec::new() }; if entries.is_empty() { let url = format!("{base}?expand=files"); let resp2 = client.get(&url).send()?; if !resp2.status().is_success() { return Err(anyhow!("HF API {} for {}", resp2.status(), url).into()); } let info: HfModelInfo = resp2.json()?; entries = hf_info_to_entries(repo, info)?; } if entries.is_empty() { return Err(anyhow!("HF API returned no usable .bin files").into()); } Ok(entries) } fn hf_info_to_entries(repo: &str, info: HfModelInfo) -> Result> { let files = info.files.or(info.siblings).unwrap_or_default(); let mut out = Vec::new(); for f in files { let fname = f.rfilename; if !fname.ends_with(".bin") { continue; } let stem = fname.strip_suffix(".bin").unwrap_or(&fname).to_string(); let name_no_prefix = stem .strip_prefix("ggml-") .or_else(|| stem.strip_prefix("gguf-")) .unwrap_or(&stem) .to_string(); let sha_from_lfs = f.lfs.as_ref().and_then(|l| { l.sha256.clone().or_else(|| { l.oid .as_ref() .and_then(|oid| oid.strip_prefix("sha256:").map(|s| s.to_string())) }) }); let size = f.size.or_else(|| f.lfs.as_ref().and_then(|l| l.size)); let url = format!( "https://huggingface.co/{}/resolve/main/{}?download=true", repo, fname ); let (base, variant) = parse_base_variant(&name_no_prefix); out.push(ModelEntry { name: name_no_prefix, file: fname, url, size, sha256: sha_from_lfs.or(f.sha256), last_modified: f.last_modified.clone(), base, variant, }); } Ok(out) } fn scrape_tree_manifest(repo: &str) -> Result> { let client = Client::builder() .user_agent(ConfigService::user_agent()) .build()?; let url = format!("https://huggingface.co/{}/tree/main?recursive=1", repo); let resp = client.get(&url).send()?; if !resp.status().is_success() { return Err(anyhow!("tree page HTTP {} for {}", resp.status(), url).into()); } let html = resp.text()?; let mut files = BTreeSet::new(); for mat in html.match_indices(".bin") { let end = mat.0 + 4; let start = html[..end] .rfind("/blob/main/") .or_else(|| html[..end].rfind("/resolve/main/")); if let Some(s) = start { let slice = &html[s..end]; if let Some(pos) = slice.find("/blob/main/") { let path = slice[pos + "/blob/main/".len()..].to_string(); files.insert(path); } else if let Some(pos) = slice.find("/resolve/main/") { let path = slice[pos + "/resolve/main/".len()..].to_string(); files.insert(path); } } } let mut out = Vec::new(); for fname in files.into_iter().filter(|f| f.ends_with(".bin")) { let stem = fname.strip_suffix(".bin").unwrap_or(&fname).to_string(); let name = stem .rsplit('/') .next() .unwrap_or(&stem) .strip_prefix("ggml-") .or_else(|| { stem.rsplit('/') .next() .unwrap_or(&stem) .strip_prefix("gguf-") }) .unwrap_or_else(|| stem.rsplit('/').next().unwrap_or(&stem)) .to_string(); let url = format!( "https://huggingface.co/{}/resolve/main/{}?download=true", repo, fname ); let (base, variant) = parse_base_variant(&name); out.push(ModelEntry { name, file: fname.rsplit('/').next().unwrap_or(&fname).to_string(), url, size: None, sha256: None, last_modified: None, base, variant, }); } if out.is_empty() { return Err(anyhow!("tree scraper found no .bin files").into()); } Ok(out) } fn parse_sha_from_header_value(s: &str) -> Option { let lower = s.to_ascii_lowercase(); if let Some(idx) = lower.find("sha256:") { let tail = &lower[idx + "sha256:".len()..]; let hex: String = tail.chars().take_while(|c| c.is_ascii_hexdigit()).collect(); if !hex.is_empty() { return Some(hex); } } None } fn enrich_entry_via_head(entry: &mut ModelEntry) -> Result<()> { if entry.size.is_some() && entry.sha256.is_some() && entry.last_modified.is_some() { return Ok(()); } let client = Client::builder() .user_agent(ConfigService::user_agent()) .timeout(Duration::from_secs(ConfigService::http_timeout_secs())) .build()?; let mut head_url = entry.url.clone(); if let Some((base, _)) = entry.url.split_once('?') { head_url = base.to_string(); } let started = Instant::now(); crate::dlog!(1, "HEAD {}", head_url); let resp = client .head(&head_url) .send() .or_else(|_| client.head(&entry.url).send())?; if !resp.status().is_success() { crate::dlog!(1, "HEAD {} -> HTTP {}", head_url, resp.status()); return Ok(()); } let mut filled_size = false; let mut filled_sha = false; let mut filled_lm = false; if entry.size.is_none() && let Some(sz) = resp .headers() .get(CONTENT_LENGTH) .and_then(|v| v.to_str().ok()) .and_then(|s| s.parse::().ok()) { entry.size = Some(sz); filled_size = true; } if entry.sha256.is_none() { let _ = resp .headers() .get("x-linked-etag") .and_then(|v| v.to_str().ok()) .and_then(parse_sha_from_header_value) .map(|hex| { entry.sha256 = Some(hex); filled_sha = true; }); if !filled_sha { let _ = resp .headers() .get(ETAG) .and_then(|v| v.to_str().ok()) .and_then(parse_sha_from_header_value) .map(|hex| { entry.sha256 = Some(hex); filled_sha = true; }); } } if entry.last_modified.is_none() { let _ = resp .headers() .get(LAST_MODIFIED) .and_then(|v| v.to_str().ok()) .map(|v| { entry.last_modified = Some(v.to_string()); filled_lm = true; }); } let elapsed_ms = started.elapsed().as_millis(); crate::dlog!( 1, "HEAD ok in {} ms for {} (size: {}, sha256: {}, last-modified: {})", elapsed_ms, entry.file, if filled_size { "new" } else if entry.size.is_some() { "kept" } else { "missing" }, if filled_sha { "new" } else if entry.sha256.is_some() { "kept" } else { "missing" }, if filled_lm { "new" } else if entry.last_modified.is_some() { "kept" } else { "missing" }, ); Ok(()) } #[derive(Debug, Serialize, Deserialize)] struct CachedManifest { fetched_at: u64, etag: Option, last_modified: Option, entries: Vec, } fn get_cache_dir() -> Result { Ok(ConfigService::manifest_cache_dir() .ok_or_else(|| anyhow!("could not determine platform directories"))?) } fn get_cached_manifest_path() -> Result { let cache_dir = get_cache_dir()?; Ok(cache_dir.join(ConfigService::manifest_cache_filename())) } fn should_bypass_cache() -> bool { ConfigService::bypass_manifest_cache() } fn get_cache_ttl() -> u64 { ConfigService::manifest_cache_ttl_seconds() } fn load_cached_manifest() -> Option { if should_bypass_cache() { return None; } let cache_path = get_cached_manifest_path().ok()?; if !cache_path.exists() { return None; } let cache_file = File::open(cache_path).ok()?; let cached: CachedManifest = serde_json::from_reader(cache_file).ok()?; let now = SystemTime::now().duration_since(UNIX_EPOCH).ok()?.as_secs(); let ttl = get_cache_ttl(); if now.saturating_sub(cached.fetched_at) > ttl { crate::dlog!( 1, "Cache expired (age: {}s, TTL: {}s)", now.saturating_sub(cached.fetched_at), ttl ); return None; } crate::dlog!( 1, "Using cached manifest (age: {}s)", now.saturating_sub(cached.fetched_at) ); Some(cached) } fn save_manifest_to_cache( entries: &[ModelEntry], etag: Option<&str>, last_modified: Option<&str>, ) -> Result<()> { if should_bypass_cache() { return Ok(()); } let cache_dir = get_cache_dir()?; fs::create_dir_all(&cache_dir)?; let cache_path = get_cached_manifest_path()?; let now = SystemTime::now() .duration_since(UNIX_EPOCH) .map_err(|_| anyhow!("system time error"))? .as_secs(); let cached = CachedManifest { fetched_at: now, etag: etag.map(|s| s.to_string()), last_modified: last_modified.map(|s| s.to_string()), entries: entries.to_vec(), }; let cache_file = OpenOptions::new() .create(true) .write(true) .truncate(true) .open(&cache_path) .with_context(|| format!("opening cache file {}", cache_path.display()))?; serde_json::to_writer_pretty(cache_file, &cached) .with_context(|| "serializing cached manifest")?; crate::dlog!(1, "Saved manifest to cache: {} entries", entries.len()); Ok(()) } fn fetch_manifest_with_cache() -> Result> { let cached = load_cached_manifest(); let client = Client::builder() .user_agent(ConfigService::user_agent()) .build()?; let repo = ConfigService::hf_repo(); let base_url = ConfigService::hf_api_base_for(&repo); let mut req = client.get(&base_url); if let Some(ref cached) = cached { if let Some(ref etag) = cached.etag { req = req.header("If-None-Match", format!("\"{}\"", etag)); } else if let Some(ref last_mod) = cached.last_modified { req = req.header("If-Modified-Since", last_mod); } } let resp = req.send()?; if resp.status().as_u16() == 304 { if let Some(cached) = cached { crate::dlog!(1, "Manifest not modified, using cache"); return Ok(cached.entries); } } if !resp.status().is_success() { return Err(anyhow!("HF API {} for {}", resp.status(), base_url).into()); } let etag = resp .headers() .get(ETAG) .and_then(|v| v.to_str().ok()) .map(|s| s.trim_matches('"').to_string()); let last_modified = resp .headers() .get(LAST_MODIFIED) .and_then(|v| v.to_str().ok()) .map(|s| s.to_string()); let info: HfModelInfo = resp.json()?; let mut entries = hf_info_to_entries(&repo, info)?; if entries.is_empty() { let url = format!("{}?expand=files", base_url); let resp2 = client.get(&url).send()?; if !resp2.status().is_success() { return Err(anyhow!("HF API {} for {}", resp2.status(), url).into()); } let info: HfModelInfo = resp2.json()?; entries = hf_info_to_entries(&repo, info)?; } if entries.is_empty() { return Err(anyhow!("HF API returned no usable .bin files").into()); } let _ = save_manifest_to_cache(&entries, etag.as_deref(), last_modified.as_deref()); Ok(entries) } fn current_manifest() -> Result> { let started = Instant::now(); crate::dlog!(1, "Fetching HF manifest…"); let mut list = match fetch_manifest_with_cache() { Ok(list) if !list.is_empty() => { crate::dlog!( 1, "Manifest loaded from HF API with cache ({} entries)", list.len() ); list } _ => { crate::ilog!("Cache failed, falling back to direct API"); let repo = ConfigService::hf_repo(); let list = match hf_repo_manifest_api(&repo) { Ok(list) if !list.is_empty() => { crate::dlog!(1, "Manifest loaded from HF API ({} entries)", list.len()); list } _ => { crate::ilog!("Falling back to scraping the repository tree page"); let scraped = scrape_tree_manifest(&repo)?; crate::dlog!(1, "Manifest loaded via scrape ({} entries)", scraped.len()); scraped } }; let _ = save_manifest_to_cache(&list, None, None); list } }; let mut need_enrich = 0usize; for m in &list { if m.size.is_none() || m.sha256.is_none() || m.last_modified.is_none() { need_enrich += 1; } } crate::dlog!(1, "Enriching {} entries via HEAD…", need_enrich); for m in &mut list { if m.size.is_none() || m.sha256.is_none() || m.last_modified.is_none() { let _ = enrich_entry_via_head(m); } } let elapsed_ms = started.elapsed().as_millis(); let sizes_known = list.iter().filter(|m| m.size.is_some()).count(); let hashes_known = list.iter().filter(|m| m.sha256.is_some()).count(); crate::dlog!( 1, "Manifest ready in {} ms (entries: {}, sizes: {}/{}, hashes: {}/{})", elapsed_ms, list.len(), sizes_known, list.len(), hashes_known, list.len() ); if list.is_empty() { return Err(anyhow!("no usable .bin files discovered").into()); } Ok(list) } pub fn pick_best_local_model(dir: &Path) -> Option { let rd = fs::read_dir(dir).ok()?; rd.flatten() .map(|e| e.path()) .filter(|p| { p.is_file() && p.extension() .and_then(|s| s.to_str()) .is_some_and(|s| s.eq_ignore_ascii_case("bin")) }) .filter_map(|p| fs::metadata(&p).ok().map(|md| (md.len(), p))) .max_by_key(|(sz, _)| *sz) .map(|(_, p)| p) } fn resolve_models_dir() -> Result { Ok(ConfigService::models_dir(None) .ok_or_else(|| anyhow!("could not determine models directory"))?) } pub fn ensure_model_available_noninteractive(name: &str) -> Result { let entry = find_manifest_entry(name)?.ok_or_else(|| anyhow!("unknown model: {name}"))?; let dir = resolve_models_dir()?; fs::create_dir_all(&dir).ok(); let dest = dir.join(&entry.file); if file_matches(&dest, entry.size, entry.sha256.as_deref())? { crate::ui::info(format!("Already up to date: {}", dest.display())); return Ok(dest); } let base = &entry.base; let variant = &entry.variant; let size_str = format_size_mb(entry.size); crate::ui::println_above_bars(format!("Base: {base} • Type: {variant}")); crate::ui::println_above_bars(format!( "Source: {} • Size: {}", mirror_label(&entry.url), size_str )); download_with_progress(&dest, &entry)?; Ok(dest) } pub fn clear_manifest_cache() -> Result<()> { let cache_path = get_cached_manifest_path()?; if cache_path.exists() { fs::remove_file(&cache_path)?; crate::dlog!(1, "Cleared manifest cache"); } Ok(()) } fn find_manifest_entry(name: &str) -> Result> { let wanted_name = name .strip_suffix(".bin") .unwrap_or(name) .to_ascii_lowercase(); let wanted_file = name.to_ascii_lowercase(); for e in current_manifest()? { let file_lc = e.file.to_ascii_lowercase(); let stem_lc = e .file .strip_suffix(".bin") .unwrap_or(&e.file) .to_ascii_lowercase(); if e.name.to_ascii_lowercase() == wanted_name || file_lc == wanted_file || stem_lc == wanted_name { return Ok(Some(e)); } } Ok(None) } fn file_matches(path: &Path, size: Option, sha256_hex: Option<&str>) -> Result { if !path.exists() { return Ok(false); } if let Some(exp_hash) = sha256_hex { let mut f = File::open(path).with_context(|| format!("opening {}", path.display()))?; let mut hasher = Sha256::new(); let mut buf = vec![0u8; 1024 * 1024]; loop { let n = f.read(&mut buf)?; if n == 0 { break; } hasher.update(&buf[..n]); } let actual = hasher.finalize(); let actual_hex = actual.encode_hex::(); return Ok(actual_hex.eq_ignore_ascii_case(exp_hash)); } if let Some(expected) = size { let meta = fs::metadata(path).with_context(|| format!("stat {}", path.display()))?; return Ok(meta.len() == expected); } Ok(false) } fn download_with_progress(dest_path: &Path, entry: &ModelEntry) -> Result<()> { let url = &entry.url; let client = Client::builder() .user_agent(ConfigService::downloader_user_agent()) .build()?; crate::ui::info(format!("Resolving source: {} ({})", mirror_label(url), url)); let (mut total_len, remote_etag, _remote_last_mod, ranges_ok) = head_entry(&client, url).context("probing remote file")?; if total_len.is_none() { total_len = entry.size; } if let Some(expected) = total_len { let free = free_space_bytes_for_path(dest_path)?; let need = expected + (expected / 10) + 16 * 1024 * 1024; if free < need { return Err(anyhow!( "insufficient disk space: need {}, have {}", format_size_mb(Some(need)), format_size_gib(free) ) .into()); } } if dest_path.exists() && file_matches(dest_path, total_len, entry.sha256.as_deref())? { crate::ui::info(format!("Already up to date: {}", dest_path.display())); return Ok(()); } let part_path = dest_path.with_extension("part"); // Guard to cleanup .part on errors struct TempGuard { path: std::path::PathBuf, armed: bool } impl TempGuard { fn disarm(&mut self) { self.armed = false; } } impl Drop for TempGuard { fn drop(&mut self) { if self.armed { let _ = fs::remove_file(&self.path); } } } let mut _tmp_guard = TempGuard { path: part_path.clone(), armed: true }; let mut resume_from: u64 = 0; if part_path.exists() && ranges_ok { let meta = fs::metadata(&part_path)?; resume_from = meta.len(); } let mut part_file = OpenOptions::new() .create(true) .read(true) .append(true) .open(&part_path) .with_context(|| format!("opening {}", part_path.display()))?; let mut req = client.get(url); if ranges_ok && resume_from > 0 { req = req.header(RANGE, format!("bytes={resume_from}-")); if let Some(etag) = &remote_etag { req = req.header(IF_RANGE, format!("\"{etag}\"")); } } crate::ui::info(format!("Download: {}", part_path.display())); let pb_total = total_len.unwrap_or(0); let mut bar = crate::ui::BytesProgress::start(pb_total, "Downloading", resume_from); let start = Instant::now(); let mut resp = req.send()?.error_for_status()?; if resp.status().as_u16() == 304 && resume_from == 0 { let req2 = client.get(url); resp = req2.send()?.error_for_status()?; } let is_partial_response = resp.headers().get(CONTENT_RANGE).is_some(); if resume_from > 0 && !is_partial_response { drop(part_file); fs::remove_file(&part_path).ok(); let req2 = client.get(url); resp = req2.send()?.error_for_status()?; bar.stop("restarting"); bar = crate::ui::BytesProgress::start(pb_total, "Downloading", 0); part_file = OpenOptions::new() .create(true) .read(true) .append(true) .open(&part_path) .with_context(|| format!("opening {}", part_path.display()))?; } { let mut body = resp; let mut buf = vec![0u8; 1024 * 64]; loop { let read = body.read(&mut buf)?; if read == 0 { break; } part_file.write_all(&buf[..read])?; bar.inc(read as u64); } part_file.flush()?; part_file.sync_all()?; } bar.stop("done"); if let Some(expected_hex) = entry.sha256.as_deref() { crate::ui::info("Verify: SHA-256"); let mut f = File::open(&part_path)?; let mut hasher = Sha256::new(); let mut buf = vec![0u8; 1024 * 1024]; loop { let n = f.read(&mut buf)?; if n == 0 { break; } hasher.update(&buf[..n]); } let actual_hex = hasher.finalize().encode_hex::(); if !actual_hex.eq_ignore_ascii_case(expected_hex) { return Err(anyhow!( "checksum mismatch: expected {}, got {}", expected_hex, actual_hex ) .into()); } } else { crate::ui::info("Verify: checksum not provided by source (skipped)"); } if let Some(parent) = dest_path.parent() { fs::create_dir_all(parent).ok(); } drop(part_file); fs::rename(&part_path, dest_path) .with_context(|| format!("renaming {} → {}", part_path.display(), dest_path.display()))?; _tmp_guard.disarm(); let final_size = fs::metadata(dest_path).map(|m| m.len()).ok(); let elapsed = start.elapsed().as_secs_f64(); if let Some(sz) = final_size { if elapsed > 0.0 { let mib = sz as f64 / 1024.0 / 1024.0; let rate = mib / elapsed; crate::ui::success(format!( "Saved: {} ({}) in {:.1}s, {:.1} MiB/s", dest_path.display(), format_size_mb(Some(sz)), elapsed, rate )); } else { crate::ui::success(format!( "Saved: {} ({})", dest_path.display(), format_size_mb(Some(sz)) )); } } else { crate::ui::success(format!( "Saved: {} ({})", dest_path.display(), format_size_mb(None) )); } Ok(()) } pub fn run_interactive_model_downloader() -> Result<()> { use crate::ui; if crate::is_no_interaction() || !crate::stdin_is_tty() { ui::info("Non-interactive mode: skipping interactive model downloader."); return Ok(()); } let available = current_manifest()?; use std::collections::BTreeMap; let mut by_base: BTreeMap> = BTreeMap::new(); for m in available.into_iter() { by_base.entry(m.base.clone()).or_default().push(m); } let pref_order = ["tiny", "small", "base", "medium", "large"]; let mut ordered_bases: Vec = Vec::new(); for b in pref_order { if by_base.contains_key(b) { ordered_bases.push(b.to_string()); } } for b in by_base.keys() { if !ordered_bases.iter().any(|x| x == b) { ordered_bases.push(b.clone()); } } ui::intro("PolyScribe model downloader"); let mut base_labels: Vec = Vec::new(); for base in &ordered_bases { let variants = &by_base[base]; let (min_sz, max_sz) = variants.iter().fold((None, None), |acc, m| { let (mut lo, mut hi) = acc; if let Some(sz) = m.size { lo = Some(lo.map(|v: u64| v.min(sz)).unwrap_or(sz)); hi = Some(hi.map(|v: u64| v.max(sz)).unwrap_or(sz)); } (lo, hi) }); let size_info = match (min_sz, max_sz) { (Some(lo), Some(hi)) if lo != hi => format!( " ~{:.2}–{:.2} MB", lo as f64 / 1_000_000.0, hi as f64 / 1_000_000.0 ), (Some(sz), _) => format!(" ~{:.2} MB", sz as f64 / 1_000_000.0), _ => String::new(), }; base_labels.push(format!("{} ({} types){}", base, variants.len(), size_info)); } let base_refs: Vec<&str> = base_labels.iter().map(|s| s.as_str()).collect(); let base_idx = ui::prompt_select("Choose a model base", &base_refs)?; let chosen_base = ordered_bases[base_idx].clone(); let mut variants = by_base.remove(&chosen_base).unwrap_or_default(); variants.sort_by(|a, b| { let rank = |v: &str| match v { "default" => 0, "en" => 1, _ => 2, }; rank(&a.variant) .cmp(&rank(&b.variant)) .then_with(|| a.variant.cmp(&b.variant)) }); let mut variant_labels: Vec = Vec::new(); for m in &variants { let size = format_size_mb(m.size.as_ref().copied()); let updated = m .last_modified .as_deref() .map(short_date) .map(|d| format!(" • updated {}", d)) .unwrap_or_default(); let variant_label = if m.variant == "default" { "default" } else { &m.variant }; variant_labels.push(format!("{} ({}{})", variant_label, size, updated)); } let variant_refs: Vec<&str> = variant_labels.iter().map(|s| s.as_str()).collect(); let mut defaults = vec![false; variant_refs.len()]; if !defaults.is_empty() { defaults[0] = true; } let picks = ui::prompt_multi_select( &format!("Select types for '{}'", chosen_base), &variant_refs, Some(&defaults), )?; if picks.is_empty() { ui::warn("No types selected; aborting."); ui::outro("No changes made."); return Ok(()); } ui::println_above_bars("Downloading selected models..."); let labels: Vec = picks .iter() .map(|&i| { let m = &variants[i]; format!("{} ({})", m.name, format_size_mb(m.size)) }) .collect(); let mut pm = ui::progress::FileProgress::default_for_files(labels.len()); pm.init_files(&labels); for (bar_idx, idx) in picks.into_iter().enumerate() { let picked = variants[idx].clone(); pm.set_file_message(bar_idx, "downloading"); let _path = ensure_model_available_noninteractive(&picked.name)?; pm.mark_file_done(bar_idx); ui::success(format!("Ready: {}", picked.name)); } pm.finish_total("all done"); ui::outro("Model selection complete."); Ok(()) } pub fn update_local_models() -> Result<()> { use crate::ui; use std::collections::HashMap; let manifest = current_manifest()?; let dir = crate::models_dir_path(); fs::create_dir_all(&dir).ok(); ui::info("Checking locally available models, then verifying against the online manifest…"); let mut by_file: HashMap = HashMap::new(); let mut by_stem_or_name: HashMap = HashMap::new(); for m in manifest { by_file.insert(m.file.to_ascii_lowercase(), m.clone()); let stem = m .file .strip_suffix(".bin") .unwrap_or(&m.file) .to_ascii_lowercase(); by_stem_or_name.insert(stem, m.clone()); by_stem_or_name.insert(m.name.to_ascii_lowercase(), m); } let mut updated = 0usize; let mut up_to_date = 0usize; let rd = fs::read_dir(&dir).with_context(|| format!("reading models dir {}", dir.display()))?; let entries: Vec<_> = rd.flatten().collect(); if entries.is_empty() { ui::info("No local models found."); } else { for ent in entries { let path = ent.path(); if !path.is_file() { continue; } let is_bin = path .extension() .and_then(|s| s.to_str()) .is_some_and(|s| s.eq_ignore_ascii_case("bin")); if !is_bin { continue; } let file_name = match path.file_name().and_then(|s| s.to_str()) { Some(s) => s.to_string(), None => continue, }; let file_lc = file_name.to_ascii_lowercase(); let stem_lc = file_lc.strip_suffix(".bin").unwrap_or(&file_lc).to_string(); let mut manifest_entry = by_file .get(&file_lc) .or_else(|| by_stem_or_name.get(&stem_lc)) .cloned(); let Some(mut m) = manifest_entry.take() else { ui::warn(format!( "Skipping unknown local model (not in online manifest): {}", path.display() )); continue; }; let _ = enrich_entry_via_head(&mut m); let target_path = if m.file.eq_ignore_ascii_case(&file_name) { path.clone() } else { dir.join(&m.file) }; if target_path.exists() && file_matches(&target_path, m.size, m.sha256.as_deref())? { crate::dlog!(1, "OK: {}", target_path.display()); up_to_date += 1; continue; } if target_path == path && target_path.exists() { crate::ilog!("Updating {}", file_name); let _ = fs::remove_file(&target_path); } else if !target_path.exists() { crate::ilog!("Fetching latest for '{}' -> {}", file_name, m.file); } else { crate::ilog!("Refreshing {}", target_path.display()); } download_with_progress(&target_path, &m)?; updated += 1; } if updated == 0 { ui::info(format!("All {} local model(s) are up to date.", up_to_date)); } else { ui::info(format!("Updated {updated} local model(s).")); } } Ok(()) } #[cfg(test)] mod tests { use super::*; use std::env; #[test] fn test_cache_bypass_environment() { unsafe { env::remove_var(ConfigService::ENV_NO_CACHE_MANIFEST); } assert!(!should_bypass_cache()); unsafe { env::set_var(ConfigService::ENV_NO_CACHE_MANIFEST, "1"); } assert!(should_bypass_cache()); unsafe { env::remove_var(ConfigService::ENV_NO_CACHE_MANIFEST); } } #[test] fn test_cache_ttl_environment() { unsafe { env::remove_var(ConfigService::ENV_MANIFEST_TTL_SECONDS); } assert_eq!( get_cache_ttl(), ConfigService::DEFAULT_MANIFEST_CACHE_TTL_SECONDS ); unsafe { env::set_var(ConfigService::ENV_MANIFEST_TTL_SECONDS, "3600"); } assert_eq!(get_cache_ttl(), 3600); unsafe { env::remove_var(ConfigService::ENV_MANIFEST_TTL_SECONDS); } } #[test] fn test_cached_manifest_serialization() { let entries = vec![ModelEntry { name: "test".to_string(), file: "test.bin".to_string(), url: "https://example.com/test.bin".to_string(), size: Some(1024), sha256: Some("abc123".to_string()), last_modified: Some("2023-01-01T00:00:00Z".to_string()), base: "test".to_string(), variant: "default".to_string(), }]; let cached = CachedManifest { fetched_at: 1234567890, etag: Some("etag123".to_string()), last_modified: Some("2023-01-01T00:00:00Z".to_string()), entries: entries.clone(), }; let json = serde_json::to_string(&cached).unwrap(); let deserialized: CachedManifest = serde_json::from_str(&json).unwrap(); assert_eq!(deserialized.fetched_at, cached.fetched_at); assert_eq!(deserialized.etag, cached.etag); assert_eq!(deserialized.last_modified, cached.last_modified); assert_eq!(deserialized.entries.len(), entries.len()); assert_eq!(deserialized.entries[0].name, entries[0].name); } }