1238 lines
36 KiB
Rust
1238 lines
36 KiB
Rust
// SPDX-License-Identifier: MIT
|
||
|
||
use crate::config::ConfigService;
|
||
use crate::prelude::*;
|
||
use anyhow::{Context, anyhow};
|
||
use chrono::{DateTime, Utc};
|
||
use hex::ToHex;
|
||
use reqwest::blocking::Client;
|
||
use reqwest::header::{
|
||
ACCEPT_RANGES, CONTENT_LENGTH, CONTENT_RANGE, ETAG, IF_RANGE, LAST_MODIFIED, RANGE,
|
||
};
|
||
use serde::{Deserialize, Serialize};
|
||
use sha2::{Digest, Sha256};
|
||
use std::collections::BTreeSet;
|
||
use std::fs::{self, File, OpenOptions};
|
||
use std::io::{Read, Write};
|
||
use std::path::{Path, PathBuf};
|
||
use std::time::{Duration, Instant, SystemTime, UNIX_EPOCH};
|
||
|
||
/// Formats a byte count as mebibytes with two decimals (e.g. `"1.50 MiB"`);
/// an unknown size renders as `"? MiB"`.
fn format_size_mb(size: Option<u64>) -> String {
    size.map_or_else(
        || "? MiB".to_string(),
        |bytes| format!("{:.2} MiB", bytes as f64 / (1024.0 * 1024.0)),
    )
}

/// Formats a byte count as gibibytes with two decimals (e.g. `"3.00 GiB"`).
fn format_size_gib(bytes: u64) -> String {
    const BYTES_PER_GIB: f64 = 1024.0 * 1024.0 * 1024.0;
    format!("{:.2} GiB", bytes as f64 / BYTES_PER_GIB)
}
|
||
|
||
fn short_date(s: &str) -> String {
|
||
DateTime::parse_from_rfc3339(s)
|
||
.ok()
|
||
.map(|dt| dt.with_timezone(&Utc).format("%Y-%m-%d").to_string())
|
||
.unwrap_or_else(|| s.to_string())
|
||
}
|
||
|
||
fn free_space_bytes_for_path(path: &Path) -> Result<u64> {
|
||
use libc::statvfs;
|
||
use std::ffi::CString;
|
||
|
||
let dir = if path.is_dir() {
|
||
path
|
||
} else {
|
||
path.parent().unwrap_or(Path::new("."))
|
||
};
|
||
|
||
let cpath = CString::new(dir.as_os_str().to_string_lossy().as_bytes())
|
||
.map_err(|_| anyhow!("invalid path for statvfs"))?;
|
||
unsafe {
|
||
let mut s: libc::statvfs = std::mem::zeroed();
|
||
if statvfs(cpath.as_ptr(), &mut s) != 0 {
|
||
return Err(anyhow!("statvfs failed for {}", dir.display()).into());
|
||
}
|
||
Ok((s.f_bsize as u64) * (s.f_bavail as u64))
|
||
}
|
||
}
|
||
|
||
/// Short human-readable label for where `url` points.
///
/// Plain substring test on the whole URL: `"eu"` is checked before `"us"`,
/// and anything matching neither is labelled `"source"`.
fn mirror_label(url: &str) -> &'static str {
    const REGION_LABELS: [(&str, &str); 2] = [("eu", "EU mirror"), ("us", "US mirror")];
    REGION_LABELS
        .iter()
        .find(|&&(needle, _)| url.contains(needle))
        .map_or("source", |&(_, label)| label)
}
|
||
|
||
type HeadMeta = (Option<u64>, Option<String>, Option<String>, bool);
|
||
|
||
fn head_entry(client: &Client, url: &str) -> Result<HeadMeta> {
|
||
let resp = client.head(url).send()?.error_for_status()?;
|
||
let len = resp
|
||
.headers()
|
||
.get(CONTENT_LENGTH)
|
||
.and_then(|v| v.to_str().ok())
|
||
.and_then(|s| s.parse::<u64>().ok());
|
||
let etag = resp
|
||
.headers()
|
||
.get(ETAG)
|
||
.and_then(|v| v.to_str().ok())
|
||
.map(|s| s.trim_matches('"').to_string());
|
||
let last_mod = resp
|
||
.headers()
|
||
.get(LAST_MODIFIED)
|
||
.and_then(|v| v.to_str().ok())
|
||
.map(|s| s.to_string());
|
||
let ranges_ok = resp
|
||
.headers()
|
||
.get(ACCEPT_RANGES)
|
||
.and_then(|v| v.to_str().ok())
|
||
.map(|s| s.to_ascii_lowercase().contains("bytes"))
|
||
.unwrap_or(false);
|
||
Ok((len, etag, last_mod, ranges_ok))
|
||
}
|
||
|
||
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
|
||
struct ModelEntry {
|
||
name: String,
|
||
file: String,
|
||
url: String,
|
||
size: Option<u64>,
|
||
sha256: Option<String>,
|
||
last_modified: Option<String>,
|
||
base: String,
|
||
variant: String,
|
||
}
|
||
|
||
|
||
#[derive(Debug, Deserialize)]
|
||
struct HfModelInfo {
|
||
siblings: Option<Vec<HfFile>>,
|
||
files: Option<Vec<HfFile>>,
|
||
}
|
||
|
||
#[derive(Debug, Deserialize)]
|
||
struct HfLfsInfo {
|
||
oid: Option<String>,
|
||
size: Option<u64>,
|
||
sha256: Option<String>,
|
||
}
|
||
|
||
#[derive(Debug, Deserialize)]
|
||
struct HfFile {
|
||
rfilename: String,
|
||
size: Option<u64>,
|
||
sha256: Option<String>,
|
||
lfs: Option<HfLfsInfo>,
|
||
#[serde(rename = "lastModified")]
|
||
last_modified: Option<String>,
|
||
}
|
||
|
||
/// Splits a model display name into `(base, variant)`.
///
/// The text before the first `.` is the head and everything after it is the
/// dotted variant. If the head contains a `-`, the base is the part before it
/// and the dashed remainder joins the dotted variant (`"large-v3.en"` →
/// `("large", "v3.en")`). Previously the dotted part was dropped in that case,
/// which made such names collide with their dot-less sibling.
///
/// Examples: `"tiny"` → `("tiny", "default")`, `"tiny.en"` → `("tiny", "en")`,
/// `"large-v3-turbo"` → `("large", "v3-turbo")`.
fn parse_base_variant(display_name: &str) -> (String, String) {
    let (head, dotted) = match display_name.split_once('.') {
        Some((head, rest)) => (head, Some(rest)),
        None => (display_name, None),
    };
    match (head.split_once('-'), dotted) {
        (Some((base, dashed)), Some(rest)) => (base.to_string(), format!("{dashed}.{rest}")),
        (Some((base, dashed)), None) => (base.to_string(), dashed.to_string()),
        (None, Some(rest)) => (head.to_string(), rest.to_string()),
        (None, None) => (head.to_string(), "default".to_string()),
    }
}
|
||
|
||
fn hf_repo_manifest_api(repo: &str) -> Result<Vec<ModelEntry>> {
|
||
let client = Client::builder()
|
||
.user_agent(ConfigService::user_agent())
|
||
.build()?;
|
||
|
||
let base = ConfigService::hf_api_base_for(repo);
|
||
let resp = client.get(&base).send()?;
|
||
let mut entries = if resp.status().is_success() {
|
||
let info: HfModelInfo = resp.json()?;
|
||
hf_info_to_entries(repo, info)?
|
||
} else {
|
||
Vec::new()
|
||
};
|
||
|
||
if entries.is_empty() {
|
||
let url = format!("{base}?expand=files");
|
||
let resp2 = client.get(&url).send()?;
|
||
if !resp2.status().is_success() {
|
||
return Err(anyhow!("HF API {} for {}", resp2.status(), url).into());
|
||
}
|
||
let info: HfModelInfo = resp2.json()?;
|
||
entries = hf_info_to_entries(repo, info)?;
|
||
}
|
||
|
||
if entries.is_empty() {
|
||
return Err(anyhow!("HF API returned no usable .bin files").into());
|
||
}
|
||
Ok(entries)
|
||
}
|
||
|
||
fn hf_info_to_entries(repo: &str, info: HfModelInfo) -> Result<Vec<ModelEntry>> {
|
||
let files = info.files.or(info.siblings).unwrap_or_default();
|
||
let mut out = Vec::new();
|
||
for f in files {
|
||
let fname = f.rfilename;
|
||
if !fname.ends_with(".bin") {
|
||
continue;
|
||
}
|
||
|
||
let stem = fname.strip_suffix(".bin").unwrap_or(&fname).to_string();
|
||
let name_no_prefix = stem
|
||
.strip_prefix("ggml-")
|
||
.or_else(|| stem.strip_prefix("gguf-"))
|
||
.unwrap_or(&stem)
|
||
.to_string();
|
||
|
||
let sha_from_lfs = f.lfs.as_ref().and_then(|l| {
|
||
l.sha256.clone().or_else(|| {
|
||
l.oid
|
||
.as_ref()
|
||
.and_then(|oid| oid.strip_prefix("sha256:").map(|s| s.to_string()))
|
||
})
|
||
});
|
||
|
||
let size = f.size.or_else(|| f.lfs.as_ref().and_then(|l| l.size));
|
||
|
||
let url = format!(
|
||
"https://huggingface.co/{}/resolve/main/{}?download=true",
|
||
repo, fname
|
||
);
|
||
|
||
let (base, variant) = parse_base_variant(&name_no_prefix);
|
||
|
||
out.push(ModelEntry {
|
||
name: name_no_prefix,
|
||
file: fname,
|
||
url,
|
||
size,
|
||
sha256: sha_from_lfs.or(f.sha256),
|
||
last_modified: f.last_modified.clone(),
|
||
base,
|
||
variant,
|
||
});
|
||
}
|
||
Ok(out)
|
||
}
|
||
|
||
|
||
fn scrape_tree_manifest(repo: &str) -> Result<Vec<ModelEntry>> {
|
||
let client = Client::builder()
|
||
.user_agent(ConfigService::user_agent())
|
||
.build()?;
|
||
|
||
let url = format!("https://huggingface.co/{}/tree/main?recursive=1", repo);
|
||
let resp = client.get(&url).send()?;
|
||
if !resp.status().is_success() {
|
||
return Err(anyhow!("tree page HTTP {} for {}", resp.status(), url).into());
|
||
}
|
||
let html = resp.text()?;
|
||
|
||
let mut files = BTreeSet::new();
|
||
for mat in html.match_indices(".bin") {
|
||
let end = mat.0 + 4;
|
||
let start = html[..end]
|
||
.rfind("/blob/main/")
|
||
.or_else(|| html[..end].rfind("/resolve/main/"));
|
||
if let Some(s) = start {
|
||
let slice = &html[s..end];
|
||
if let Some(pos) = slice.find("/blob/main/") {
|
||
let path = slice[pos + "/blob/main/".len()..].to_string();
|
||
files.insert(path);
|
||
} else if let Some(pos) = slice.find("/resolve/main/") {
|
||
let path = slice[pos + "/resolve/main/".len()..].to_string();
|
||
files.insert(path);
|
||
}
|
||
}
|
||
}
|
||
|
||
let mut out = Vec::new();
|
||
for fname in files.into_iter().filter(|f| f.ends_with(".bin")) {
|
||
let stem = fname.strip_suffix(".bin").unwrap_or(&fname).to_string();
|
||
let name = stem
|
||
.rsplit('/')
|
||
.next()
|
||
.unwrap_or(&stem)
|
||
.strip_prefix("ggml-")
|
||
.or_else(|| {
|
||
stem.rsplit('/')
|
||
.next()
|
||
.unwrap_or(&stem)
|
||
.strip_prefix("gguf-")
|
||
})
|
||
.unwrap_or_else(|| stem.rsplit('/').next().unwrap_or(&stem))
|
||
.to_string();
|
||
|
||
let url = format!(
|
||
"https://huggingface.co/{}/resolve/main/{}?download=true",
|
||
repo, fname
|
||
);
|
||
|
||
let (base, variant) = parse_base_variant(&name);
|
||
|
||
out.push(ModelEntry {
|
||
name,
|
||
file: fname.rsplit('/').next().unwrap_or(&fname).to_string(),
|
||
url,
|
||
size: None,
|
||
sha256: None,
|
||
last_modified: None,
|
||
base,
|
||
variant,
|
||
});
|
||
}
|
||
|
||
if out.is_empty() {
|
||
return Err(anyhow!("tree scraper found no .bin files").into());
|
||
}
|
||
Ok(out)
|
||
}
|
||
|
||
|
||
/// Extracts a SHA-256 digest from a header value containing `sha256:<hex>`
/// (case-insensitive). The digest is returned in lowercase.
///
/// Only the first `sha256:` occurrence is considered. It must be followed by
/// exactly 64 hex digits. A shorter or longer run is rejected, because it
/// cannot be a SHA-256 digest and would later surface as a guaranteed
/// checksum mismatch.
fn parse_sha_from_header_value(s: &str) -> Option<String> {
    const PREFIX: &str = "sha256:";
    const SHA256_HEX_LEN: usize = 64;

    let lower = s.to_ascii_lowercase();
    let start = lower.find(PREFIX)? + PREFIX.len();
    let digest: String = lower[start..]
        .chars()
        .take_while(char::is_ascii_hexdigit)
        .collect();
    (digest.len() == SHA256_HEX_LEN).then_some(digest)
}
|
||
|
||
fn enrich_entry_via_head(entry: &mut ModelEntry) -> Result<()> {
|
||
if entry.size.is_some() && entry.sha256.is_some() && entry.last_modified.is_some() {
|
||
return Ok(());
|
||
}
|
||
|
||
let client = Client::builder()
|
||
.user_agent(ConfigService::user_agent())
|
||
.timeout(Duration::from_secs(ConfigService::http_timeout_secs()))
|
||
.build()?;
|
||
|
||
let mut head_url = entry.url.clone();
|
||
if let Some((base, _)) = entry.url.split_once('?') {
|
||
head_url = base.to_string();
|
||
}
|
||
|
||
let started = Instant::now();
|
||
crate::dlog!(1, "HEAD {}", head_url);
|
||
|
||
let resp = client
|
||
.head(&head_url)
|
||
.send()
|
||
.or_else(|_| client.head(&entry.url).send())?;
|
||
|
||
if !resp.status().is_success() {
|
||
crate::dlog!(1, "HEAD {} -> HTTP {}", head_url, resp.status());
|
||
return Ok(());
|
||
}
|
||
|
||
let mut filled_size = false;
|
||
let mut filled_sha = false;
|
||
let mut filled_lm = false;
|
||
|
||
if entry.size.is_none()
|
||
&& let Some(sz) = resp
|
||
.headers()
|
||
.get(CONTENT_LENGTH)
|
||
.and_then(|v| v.to_str().ok())
|
||
.and_then(|s| s.parse::<u64>().ok())
|
||
{
|
||
entry.size = Some(sz);
|
||
filled_size = true;
|
||
}
|
||
|
||
if entry.sha256.is_none() {
|
||
let _ = resp
|
||
.headers()
|
||
.get("x-linked-etag")
|
||
.and_then(|v| v.to_str().ok())
|
||
.and_then(parse_sha_from_header_value)
|
||
.map(|hex| {
|
||
entry.sha256 = Some(hex);
|
||
filled_sha = true;
|
||
});
|
||
if !filled_sha {
|
||
let _ = resp
|
||
.headers()
|
||
.get(ETAG)
|
||
.and_then(|v| v.to_str().ok())
|
||
.and_then(parse_sha_from_header_value)
|
||
.map(|hex| {
|
||
entry.sha256 = Some(hex);
|
||
filled_sha = true;
|
||
});
|
||
}
|
||
}
|
||
|
||
if entry.last_modified.is_none() {
|
||
let _ = resp
|
||
.headers()
|
||
.get(LAST_MODIFIED)
|
||
.and_then(|v| v.to_str().ok())
|
||
.map(|v| {
|
||
entry.last_modified = Some(v.to_string());
|
||
filled_lm = true;
|
||
});
|
||
}
|
||
|
||
let elapsed_ms = started.elapsed().as_millis();
|
||
crate::dlog!(
|
||
1,
|
||
"HEAD ok in {} ms for {} (size: {}, sha256: {}, last-modified: {})",
|
||
elapsed_ms,
|
||
entry.file,
|
||
if filled_size {
|
||
"new"
|
||
} else if entry.size.is_some() {
|
||
"kept"
|
||
} else {
|
||
"missing"
|
||
},
|
||
if filled_sha {
|
||
"new"
|
||
} else if entry.sha256.is_some() {
|
||
"kept"
|
||
} else {
|
||
"missing"
|
||
},
|
||
if filled_lm {
|
||
"new"
|
||
} else if entry.last_modified.is_some() {
|
||
"kept"
|
||
} else {
|
||
"missing"
|
||
},
|
||
);
|
||
|
||
Ok(())
|
||
}
|
||
|
||
#[derive(Debug, Serialize, Deserialize)]
|
||
struct CachedManifest {
|
||
fetched_at: u64,
|
||
etag: Option<String>,
|
||
last_modified: Option<String>,
|
||
entries: Vec<ModelEntry>,
|
||
}
|
||
|
||
fn get_cache_dir() -> Result<PathBuf> {
|
||
Ok(ConfigService::manifest_cache_dir()
|
||
.ok_or_else(|| anyhow!("could not determine platform directories"))?)
|
||
}
|
||
|
||
fn get_cached_manifest_path() -> Result<PathBuf> {
|
||
let cache_dir = get_cache_dir()?;
|
||
Ok(cache_dir.join(ConfigService::manifest_cache_filename()))
|
||
}
|
||
|
||
fn should_bypass_cache() -> bool {
|
||
ConfigService::bypass_manifest_cache()
|
||
}
|
||
|
||
fn get_cache_ttl() -> u64 {
|
||
ConfigService::manifest_cache_ttl_seconds()
|
||
}
|
||
|
||
fn load_cached_manifest() -> Option<CachedManifest> {
|
||
if should_bypass_cache() {
|
||
return None;
|
||
}
|
||
|
||
let cache_path = get_cached_manifest_path().ok()?;
|
||
if !cache_path.exists() {
|
||
return None;
|
||
}
|
||
|
||
let cache_file = File::open(cache_path).ok()?;
|
||
let cached: CachedManifest = serde_json::from_reader(cache_file).ok()?;
|
||
|
||
let now = SystemTime::now().duration_since(UNIX_EPOCH).ok()?.as_secs();
|
||
|
||
let ttl = get_cache_ttl();
|
||
if now.saturating_sub(cached.fetched_at) > ttl {
|
||
crate::dlog!(
|
||
1,
|
||
"Cache expired (age: {}s, TTL: {}s)",
|
||
now.saturating_sub(cached.fetched_at),
|
||
ttl
|
||
);
|
||
return None;
|
||
}
|
||
|
||
crate::dlog!(
|
||
1,
|
||
"Using cached manifest (age: {}s)",
|
||
now.saturating_sub(cached.fetched_at)
|
||
);
|
||
Some(cached)
|
||
}
|
||
|
||
fn save_manifest_to_cache(
|
||
entries: &[ModelEntry],
|
||
etag: Option<&str>,
|
||
last_modified: Option<&str>,
|
||
) -> Result<()> {
|
||
if should_bypass_cache() {
|
||
return Ok(());
|
||
}
|
||
|
||
let cache_dir = get_cache_dir()?;
|
||
fs::create_dir_all(&cache_dir)?;
|
||
|
||
let cache_path = get_cached_manifest_path()?;
|
||
let now = SystemTime::now()
|
||
.duration_since(UNIX_EPOCH)
|
||
.map_err(|_| anyhow!("system time error"))?
|
||
.as_secs();
|
||
|
||
let cached = CachedManifest {
|
||
fetched_at: now,
|
||
etag: etag.map(|s| s.to_string()),
|
||
last_modified: last_modified.map(|s| s.to_string()),
|
||
entries: entries.to_vec(),
|
||
};
|
||
|
||
let cache_file = OpenOptions::new()
|
||
.create(true)
|
||
.write(true)
|
||
.truncate(true)
|
||
.open(&cache_path)
|
||
.with_context(|| format!("opening cache file {}", cache_path.display()))?;
|
||
|
||
serde_json::to_writer_pretty(cache_file, &cached)
|
||
.with_context(|| "serializing cached manifest")?;
|
||
|
||
crate::dlog!(1, "Saved manifest to cache: {} entries", entries.len());
|
||
Ok(())
|
||
}
|
||
|
||
fn fetch_manifest_with_cache() -> Result<Vec<ModelEntry>> {
|
||
let cached = load_cached_manifest();
|
||
|
||
let client = Client::builder()
|
||
.user_agent(ConfigService::user_agent())
|
||
.build()?;
|
||
let repo = ConfigService::hf_repo();
|
||
let base_url = ConfigService::hf_api_base_for(&repo);
|
||
|
||
let mut req = client.get(&base_url);
|
||
if let Some(ref cached) = cached {
|
||
if let Some(ref etag) = cached.etag {
|
||
req = req.header("If-None-Match", format!("\"{}\"", etag));
|
||
} else if let Some(ref last_mod) = cached.last_modified {
|
||
req = req.header("If-Modified-Since", last_mod);
|
||
}
|
||
}
|
||
|
||
let resp = req.send()?;
|
||
|
||
if resp.status().as_u16() == 304 {
|
||
if let Some(cached) = cached {
|
||
crate::dlog!(1, "Manifest not modified, using cache");
|
||
return Ok(cached.entries);
|
||
}
|
||
}
|
||
|
||
if !resp.status().is_success() {
|
||
return Err(anyhow!("HF API {} for {}", resp.status(), base_url).into());
|
||
}
|
||
|
||
let etag = resp
|
||
.headers()
|
||
.get(ETAG)
|
||
.and_then(|v| v.to_str().ok())
|
||
.map(|s| s.trim_matches('"').to_string());
|
||
|
||
let last_modified = resp
|
||
.headers()
|
||
.get(LAST_MODIFIED)
|
||
.and_then(|v| v.to_str().ok())
|
||
.map(|s| s.to_string());
|
||
|
||
let info: HfModelInfo = resp.json()?;
|
||
let mut entries = hf_info_to_entries(&repo, info)?;
|
||
|
||
if entries.is_empty() {
|
||
let url = format!("{}?expand=files", base_url);
|
||
let resp2 = client.get(&url).send()?;
|
||
if !resp2.status().is_success() {
|
||
return Err(anyhow!("HF API {} for {}", resp2.status(), url).into());
|
||
}
|
||
let info: HfModelInfo = resp2.json()?;
|
||
entries = hf_info_to_entries(&repo, info)?;
|
||
}
|
||
|
||
if entries.is_empty() {
|
||
return Err(anyhow!("HF API returned no usable .bin files").into());
|
||
}
|
||
|
||
let _ = save_manifest_to_cache(&entries, etag.as_deref(), last_modified.as_deref());
|
||
|
||
Ok(entries)
|
||
}
|
||
|
||
fn current_manifest() -> Result<Vec<ModelEntry>> {
|
||
let started = Instant::now();
|
||
crate::dlog!(1, "Fetching HF manifest…");
|
||
|
||
let mut list = match fetch_manifest_with_cache() {
|
||
Ok(list) if !list.is_empty() => {
|
||
crate::dlog!(
|
||
1,
|
||
"Manifest loaded from HF API with cache ({} entries)",
|
||
list.len()
|
||
);
|
||
list
|
||
}
|
||
_ => {
|
||
crate::ilog!("Cache failed, falling back to direct API");
|
||
let repo = ConfigService::hf_repo();
|
||
let list = match hf_repo_manifest_api(&repo) {
|
||
Ok(list) if !list.is_empty() => {
|
||
crate::dlog!(1, "Manifest loaded from HF API ({} entries)", list.len());
|
||
list
|
||
}
|
||
_ => {
|
||
crate::ilog!("Falling back to scraping the repository tree page");
|
||
let scraped = scrape_tree_manifest(&repo)?;
|
||
crate::dlog!(1, "Manifest loaded via scrape ({} entries)", scraped.len());
|
||
scraped
|
||
}
|
||
};
|
||
|
||
let _ = save_manifest_to_cache(&list, None, None);
|
||
list
|
||
}
|
||
};
|
||
|
||
let mut need_enrich = 0usize;
|
||
for m in &list {
|
||
if m.size.is_none() || m.sha256.is_none() || m.last_modified.is_none() {
|
||
need_enrich += 1;
|
||
}
|
||
}
|
||
crate::dlog!(1, "Enriching {} entries via HEAD…", need_enrich);
|
||
for m in &mut list {
|
||
if m.size.is_none() || m.sha256.is_none() || m.last_modified.is_none() {
|
||
let _ = enrich_entry_via_head(m);
|
||
}
|
||
}
|
||
|
||
let elapsed_ms = started.elapsed().as_millis();
|
||
let sizes_known = list.iter().filter(|m| m.size.is_some()).count();
|
||
let hashes_known = list.iter().filter(|m| m.sha256.is_some()).count();
|
||
crate::dlog!(
|
||
1,
|
||
"Manifest ready in {} ms (entries: {}, sizes: {}/{}, hashes: {}/{})",
|
||
elapsed_ms,
|
||
list.len(),
|
||
sizes_known,
|
||
list.len(),
|
||
hashes_known,
|
||
list.len()
|
||
);
|
||
|
||
if list.is_empty() {
|
||
return Err(anyhow!("no usable .bin files discovered").into());
|
||
}
|
||
Ok(list)
|
||
}
|
||
|
||
/// Returns the largest regular `.bin` file (extension compared
/// case-insensitively) directly inside `dir`.
///
/// Returns `None` if `dir` cannot be read or holds no such file. Entries whose
/// metadata cannot be read are skipped. On equal sizes, the entry seen last in
/// directory order wins.
pub fn pick_best_local_model(dir: &Path) -> Option<PathBuf> {
    let mut best: Option<(u64, PathBuf)> = None;

    for dir_entry in fs::read_dir(dir).ok()?.flatten() {
        let candidate = dir_entry.path();
        let has_bin_ext = candidate
            .extension()
            .and_then(|ext| ext.to_str())
            .is_some_and(|ext| ext.eq_ignore_ascii_case("bin"));
        if !candidate.is_file() || !has_bin_ext {
            continue;
        }
        let Ok(meta) = fs::metadata(&candidate) else {
            continue;
        };

        let len = meta.len();
        let replace = match &best {
            Some((top, _)) => len >= *top,
            None => true,
        };
        if replace {
            best = Some((len, candidate));
        }
    }

    best.map(|(_, path)| path)
}
|
||
|
||
fn resolve_models_dir() -> Result<PathBuf> {
|
||
Ok(ConfigService::models_dir(None)
|
||
.ok_or_else(|| anyhow!("could not determine models directory"))?)
|
||
}
|
||
|
||
pub fn ensure_model_available_noninteractive(name: &str) -> Result<PathBuf> {
|
||
let entry = find_manifest_entry(name)?.ok_or_else(|| anyhow!("unknown model: {name}"))?;
|
||
|
||
let dir = resolve_models_dir()?;
|
||
fs::create_dir_all(&dir).ok();
|
||
let dest = dir.join(&entry.file);
|
||
|
||
if file_matches(&dest, entry.size, entry.sha256.as_deref())? {
|
||
crate::ui::info(format!("Already up to date: {}", dest.display()));
|
||
return Ok(dest);
|
||
}
|
||
|
||
let base = &entry.base;
|
||
let variant = &entry.variant;
|
||
let size_str = format_size_mb(entry.size);
|
||
crate::ui::println_above_bars(format!("Base: {base} • Type: {variant}"));
|
||
crate::ui::println_above_bars(format!(
|
||
"Source: {} • Size: {}",
|
||
mirror_label(&entry.url),
|
||
size_str
|
||
));
|
||
|
||
download_with_progress(&dest, &entry)?;
|
||
Ok(dest)
|
||
}
|
||
|
||
pub fn clear_manifest_cache() -> Result<()> {
|
||
let cache_path = get_cached_manifest_path()?;
|
||
if cache_path.exists() {
|
||
fs::remove_file(&cache_path)?;
|
||
crate::dlog!(1, "Cleared manifest cache");
|
||
}
|
||
Ok(())
|
||
}
|
||
|
||
fn find_manifest_entry(name: &str) -> Result<Option<ModelEntry>> {
|
||
let wanted_name = name
|
||
.strip_suffix(".bin")
|
||
.unwrap_or(name)
|
||
.to_ascii_lowercase();
|
||
let wanted_file = name.to_ascii_lowercase();
|
||
|
||
for e in current_manifest()? {
|
||
let file_lc = e.file.to_ascii_lowercase();
|
||
let stem_lc = e
|
||
.file
|
||
.strip_suffix(".bin")
|
||
.unwrap_or(&e.file)
|
||
.to_ascii_lowercase();
|
||
if e.name.to_ascii_lowercase() == wanted_name
|
||
|| file_lc == wanted_file
|
||
|| stem_lc == wanted_name
|
||
{
|
||
return Ok(Some(e));
|
||
}
|
||
}
|
||
Ok(None)
|
||
}
|
||
|
||
fn file_matches(path: &Path, size: Option<u64>, sha256_hex: Option<&str>) -> Result<bool> {
|
||
if !path.exists() {
|
||
return Ok(false);
|
||
}
|
||
|
||
if let Some(exp_hash) = sha256_hex {
|
||
let mut f = File::open(path).with_context(|| format!("opening {}", path.display()))?;
|
||
let mut hasher = Sha256::new();
|
||
let mut buf = vec![0u8; 1024 * 1024];
|
||
loop {
|
||
let n = f.read(&mut buf)?;
|
||
if n == 0 {
|
||
break;
|
||
}
|
||
hasher.update(&buf[..n]);
|
||
}
|
||
let actual = hasher.finalize();
|
||
let actual_hex = actual.encode_hex::<String>();
|
||
return Ok(actual_hex.eq_ignore_ascii_case(exp_hash));
|
||
}
|
||
|
||
if let Some(expected) = size {
|
||
let meta = fs::metadata(path).with_context(|| format!("stat {}", path.display()))?;
|
||
return Ok(meta.len() == expected);
|
||
}
|
||
|
||
Ok(false)
|
||
}
|
||
|
||
fn download_with_progress(dest_path: &Path, entry: &ModelEntry) -> Result<()> {
|
||
let url = &entry.url;
|
||
let client = Client::builder()
|
||
.user_agent(ConfigService::downloader_user_agent())
|
||
.build()?;
|
||
|
||
crate::ui::info(format!("Resolving source: {} ({})", mirror_label(url), url));
|
||
|
||
let (mut total_len, remote_etag, _remote_last_mod, ranges_ok) =
|
||
head_entry(&client, url).context("probing remote file")?;
|
||
|
||
if total_len.is_none() {
|
||
total_len = entry.size;
|
||
}
|
||
|
||
if let Some(expected) = total_len {
|
||
let free = free_space_bytes_for_path(dest_path)?;
|
||
let need = expected + (expected / 10) + 16 * 1024 * 1024;
|
||
if free < need {
|
||
return Err(anyhow!(
|
||
"insufficient disk space: need {}, have {}",
|
||
format_size_mb(Some(need)),
|
||
format_size_gib(free)
|
||
)
|
||
.into());
|
||
}
|
||
}
|
||
|
||
if dest_path.exists() && file_matches(dest_path, total_len, entry.sha256.as_deref())? {
|
||
crate::ui::info(format!("Already up to date: {}", dest_path.display()));
|
||
return Ok(());
|
||
}
|
||
|
||
let part_path = dest_path.with_extension("part");
|
||
// Guard to cleanup .part on errors
|
||
struct TempGuard { path: std::path::PathBuf, armed: bool }
|
||
impl TempGuard { fn disarm(&mut self) { self.armed = false; } }
|
||
impl Drop for TempGuard {
|
||
fn drop(&mut self) { if self.armed { let _ = fs::remove_file(&self.path); } }
|
||
}
|
||
let mut _tmp_guard = TempGuard { path: part_path.clone(), armed: true };
|
||
|
||
let mut resume_from: u64 = 0;
|
||
if part_path.exists() && ranges_ok {
|
||
let meta = fs::metadata(&part_path)?;
|
||
resume_from = meta.len();
|
||
}
|
||
|
||
let mut part_file = OpenOptions::new()
|
||
.create(true)
|
||
.read(true)
|
||
.append(true)
|
||
.open(&part_path)
|
||
.with_context(|| format!("opening {}", part_path.display()))?;
|
||
|
||
let mut req = client.get(url);
|
||
if ranges_ok && resume_from > 0 {
|
||
req = req.header(RANGE, format!("bytes={resume_from}-"));
|
||
if let Some(etag) = &remote_etag {
|
||
req = req.header(IF_RANGE, format!("\"{etag}\""));
|
||
}
|
||
}
|
||
|
||
crate::ui::info(format!("Download: {}", part_path.display()));
|
||
|
||
let pb_total = total_len.unwrap_or(0);
|
||
let mut bar = crate::ui::BytesProgress::start(pb_total, "Downloading", resume_from);
|
||
|
||
let start = Instant::now();
|
||
let mut resp = req.send()?.error_for_status()?;
|
||
|
||
if resp.status().as_u16() == 304 && resume_from == 0 {
|
||
let req2 = client.get(url);
|
||
resp = req2.send()?.error_for_status()?;
|
||
}
|
||
|
||
let is_partial_response = resp.headers().get(CONTENT_RANGE).is_some();
|
||
if resume_from > 0 && !is_partial_response {
|
||
drop(part_file);
|
||
fs::remove_file(&part_path).ok();
|
||
|
||
let req2 = client.get(url);
|
||
resp = req2.send()?.error_for_status()?;
|
||
bar.stop("restarting");
|
||
bar = crate::ui::BytesProgress::start(pb_total, "Downloading", 0);
|
||
|
||
part_file = OpenOptions::new()
|
||
.create(true)
|
||
.read(true)
|
||
.append(true)
|
||
.open(&part_path)
|
||
.with_context(|| format!("opening {}", part_path.display()))?;
|
||
}
|
||
|
||
{
|
||
let mut body = resp;
|
||
let mut buf = vec![0u8; 1024 * 64];
|
||
loop {
|
||
let read = body.read(&mut buf)?;
|
||
if read == 0 {
|
||
break;
|
||
}
|
||
part_file.write_all(&buf[..read])?;
|
||
bar.inc(read as u64);
|
||
}
|
||
part_file.flush()?;
|
||
part_file.sync_all()?;
|
||
}
|
||
|
||
bar.stop("done");
|
||
|
||
if let Some(expected_hex) = entry.sha256.as_deref() {
|
||
crate::ui::info("Verify: SHA-256");
|
||
let mut f = File::open(&part_path)?;
|
||
let mut hasher = Sha256::new();
|
||
let mut buf = vec![0u8; 1024 * 1024];
|
||
loop {
|
||
let n = f.read(&mut buf)?;
|
||
if n == 0 {
|
||
break;
|
||
}
|
||
hasher.update(&buf[..n]);
|
||
}
|
||
let actual_hex = hasher.finalize().encode_hex::<String>();
|
||
if !actual_hex.eq_ignore_ascii_case(expected_hex) {
|
||
return Err(anyhow!(
|
||
"checksum mismatch: expected {}, got {}",
|
||
expected_hex,
|
||
actual_hex
|
||
)
|
||
.into());
|
||
}
|
||
} else {
|
||
crate::ui::info("Verify: checksum not provided by source (skipped)");
|
||
}
|
||
|
||
if let Some(parent) = dest_path.parent() {
|
||
fs::create_dir_all(parent).ok();
|
||
}
|
||
drop(part_file);
|
||
fs::rename(&part_path, dest_path)
|
||
.with_context(|| format!("renaming {} → {}", part_path.display(), dest_path.display()))?;
|
||
_tmp_guard.disarm();
|
||
|
||
let final_size = fs::metadata(dest_path).map(|m| m.len()).ok();
|
||
let elapsed = start.elapsed().as_secs_f64();
|
||
|
||
if let Some(sz) = final_size {
|
||
if elapsed > 0.0 {
|
||
let mib = sz as f64 / 1024.0 / 1024.0;
|
||
let rate = mib / elapsed;
|
||
crate::ui::success(format!(
|
||
"Saved: {} ({}) in {:.1}s, {:.1} MiB/s",
|
||
dest_path.display(),
|
||
format_size_mb(Some(sz)),
|
||
elapsed,
|
||
rate
|
||
));
|
||
} else {
|
||
crate::ui::success(format!(
|
||
"Saved: {} ({})",
|
||
dest_path.display(),
|
||
format_size_mb(Some(sz))
|
||
));
|
||
}
|
||
} else {
|
||
crate::ui::success(format!(
|
||
"Saved: {} ({})",
|
||
dest_path.display(),
|
||
format_size_mb(None)
|
||
));
|
||
}
|
||
|
||
Ok(())
|
||
}
|
||
|
||
pub fn run_interactive_model_downloader() -> Result<()> {
|
||
use crate::ui;
|
||
|
||
if crate::is_no_interaction() || !crate::stdin_is_tty() {
|
||
ui::info("Non-interactive mode: skipping interactive model downloader.");
|
||
return Ok(());
|
||
}
|
||
|
||
let available = current_manifest()?;
|
||
|
||
use std::collections::BTreeMap;
|
||
let mut by_base: BTreeMap<String, Vec<ModelEntry>> = BTreeMap::new();
|
||
for m in available.into_iter() {
|
||
by_base.entry(m.base.clone()).or_default().push(m);
|
||
}
|
||
|
||
let pref_order = ["tiny", "small", "base", "medium", "large"];
|
||
let mut ordered_bases: Vec<String> = Vec::new();
|
||
for b in pref_order {
|
||
if by_base.contains_key(b) {
|
||
ordered_bases.push(b.to_string());
|
||
}
|
||
}
|
||
for b in by_base.keys() {
|
||
if !ordered_bases.iter().any(|x| x == b) {
|
||
ordered_bases.push(b.clone());
|
||
}
|
||
}
|
||
|
||
ui::intro("PolyScribe model downloader");
|
||
|
||
let mut base_labels: Vec<String> = Vec::new();
|
||
for base in &ordered_bases {
|
||
let variants = &by_base[base];
|
||
let (min_sz, max_sz) = variants.iter().fold((None, None), |acc, m| {
|
||
let (mut lo, mut hi) = acc;
|
||
if let Some(sz) = m.size {
|
||
lo = Some(lo.map(|v: u64| v.min(sz)).unwrap_or(sz));
|
||
hi = Some(hi.map(|v: u64| v.max(sz)).unwrap_or(sz));
|
||
}
|
||
(lo, hi)
|
||
});
|
||
let size_info = match (min_sz, max_sz) {
|
||
(Some(lo), Some(hi)) if lo != hi => format!(
|
||
" ~{:.2}–{:.2} MB",
|
||
lo as f64 / 1_000_000.0,
|
||
hi as f64 / 1_000_000.0
|
||
),
|
||
(Some(sz), _) => format!(" ~{:.2} MB", sz as f64 / 1_000_000.0),
|
||
_ => String::new(),
|
||
};
|
||
base_labels.push(format!("{} ({} types){}", base, variants.len(), size_info));
|
||
}
|
||
let base_refs: Vec<&str> = base_labels.iter().map(|s| s.as_str()).collect();
|
||
let base_idx = ui::prompt_select("Choose a model base", &base_refs)?;
|
||
let chosen_base = ordered_bases[base_idx].clone();
|
||
|
||
let mut variants = by_base.remove(&chosen_base).unwrap_or_default();
|
||
variants.sort_by(|a, b| {
|
||
let rank = |v: &str| match v {
|
||
"default" => 0,
|
||
"en" => 1,
|
||
_ => 2,
|
||
};
|
||
rank(&a.variant)
|
||
.cmp(&rank(&b.variant))
|
||
.then_with(|| a.variant.cmp(&b.variant))
|
||
});
|
||
|
||
let mut variant_labels: Vec<String> = Vec::new();
|
||
for m in &variants {
|
||
let size = format_size_mb(m.size.as_ref().copied());
|
||
let updated = m
|
||
.last_modified
|
||
.as_deref()
|
||
.map(short_date)
|
||
.map(|d| format!(" • updated {}", d))
|
||
.unwrap_or_default();
|
||
let variant_label = if m.variant == "default" {
|
||
"default"
|
||
} else {
|
||
&m.variant
|
||
};
|
||
variant_labels.push(format!("{} ({}{})", variant_label, size, updated));
|
||
}
|
||
let variant_refs: Vec<&str> = variant_labels.iter().map(|s| s.as_str()).collect();
|
||
let mut defaults = vec![false; variant_refs.len()];
|
||
if !defaults.is_empty() {
|
||
defaults[0] = true;
|
||
}
|
||
let picks = ui::prompt_multi_select(
|
||
&format!("Select types for '{}'", chosen_base),
|
||
&variant_refs,
|
||
Some(&defaults),
|
||
)?;
|
||
|
||
if picks.is_empty() {
|
||
ui::warn("No types selected; aborting.");
|
||
ui::outro("No changes made.");
|
||
return Ok(());
|
||
}
|
||
|
||
ui::println_above_bars("Downloading selected models...");
|
||
|
||
let labels: Vec<String> = picks
|
||
.iter()
|
||
.map(|&i| {
|
||
let m = &variants[i];
|
||
format!("{} ({})", m.name, format_size_mb(m.size))
|
||
})
|
||
.collect();
|
||
let mut pm = ui::progress::FileProgress::default_for_files(labels.len());
|
||
pm.init_files(&labels);
|
||
|
||
for (bar_idx, idx) in picks.into_iter().enumerate() {
|
||
let picked = variants[idx].clone();
|
||
pm.set_file_message(bar_idx, "downloading");
|
||
let _path = ensure_model_available_noninteractive(&picked.name)?;
|
||
pm.mark_file_done(bar_idx);
|
||
ui::success(format!("Ready: {}", picked.name));
|
||
}
|
||
|
||
pm.finish_total("all done");
|
||
ui::outro("Model selection complete.");
|
||
Ok(())
|
||
}
|
||
|
||
pub fn update_local_models() -> Result<()> {
|
||
use crate::ui;
|
||
use std::collections::HashMap;
|
||
|
||
let manifest = current_manifest()?;
|
||
let dir = crate::models_dir_path();
|
||
fs::create_dir_all(&dir).ok();
|
||
|
||
ui::info("Checking locally available models, then verifying against the online manifest…");
|
||
|
||
let mut by_file: HashMap<String, ModelEntry> = HashMap::new();
|
||
let mut by_stem_or_name: HashMap<String, ModelEntry> = HashMap::new();
|
||
for m in manifest {
|
||
by_file.insert(m.file.to_ascii_lowercase(), m.clone());
|
||
let stem = m
|
||
.file
|
||
.strip_suffix(".bin")
|
||
.unwrap_or(&m.file)
|
||
.to_ascii_lowercase();
|
||
by_stem_or_name.insert(stem, m.clone());
|
||
by_stem_or_name.insert(m.name.to_ascii_lowercase(), m);
|
||
}
|
||
|
||
let mut updated = 0usize;
|
||
let mut up_to_date = 0usize;
|
||
|
||
let rd = fs::read_dir(&dir).with_context(|| format!("reading models dir {}", dir.display()))?;
|
||
let entries: Vec<_> = rd.flatten().collect();
|
||
|
||
if entries.is_empty() {
|
||
ui::info("No local models found.");
|
||
} else {
|
||
for ent in entries {
|
||
let path = ent.path();
|
||
if !path.is_file() {
|
||
continue;
|
||
}
|
||
let is_bin = path
|
||
.extension()
|
||
.and_then(|s| s.to_str())
|
||
.is_some_and(|s| s.eq_ignore_ascii_case("bin"));
|
||
if !is_bin {
|
||
continue;
|
||
}
|
||
|
||
let file_name = match path.file_name().and_then(|s| s.to_str()) {
|
||
Some(s) => s.to_string(),
|
||
None => continue,
|
||
};
|
||
let file_lc = file_name.to_ascii_lowercase();
|
||
let stem_lc = file_lc.strip_suffix(".bin").unwrap_or(&file_lc).to_string();
|
||
|
||
let mut manifest_entry = by_file
|
||
.get(&file_lc)
|
||
.or_else(|| by_stem_or_name.get(&stem_lc))
|
||
.cloned();
|
||
|
||
let Some(mut m) = manifest_entry.take() else {
|
||
ui::warn(format!(
|
||
"Skipping unknown local model (not in online manifest): {}",
|
||
path.display()
|
||
));
|
||
continue;
|
||
};
|
||
|
||
let _ = enrich_entry_via_head(&mut m);
|
||
|
||
let target_path = if m.file.eq_ignore_ascii_case(&file_name) {
|
||
path.clone()
|
||
} else {
|
||
dir.join(&m.file)
|
||
};
|
||
|
||
if target_path.exists() && file_matches(&target_path, m.size, m.sha256.as_deref())? {
|
||
crate::dlog!(1, "OK: {}", target_path.display());
|
||
up_to_date += 1;
|
||
continue;
|
||
}
|
||
|
||
if target_path == path && target_path.exists() {
|
||
crate::ilog!("Updating {}", file_name);
|
||
let _ = fs::remove_file(&target_path);
|
||
} else if !target_path.exists() {
|
||
crate::ilog!("Fetching latest for '{}' -> {}", file_name, m.file);
|
||
} else {
|
||
crate::ilog!("Refreshing {}", target_path.display());
|
||
}
|
||
|
||
download_with_progress(&target_path, &m)?;
|
||
updated += 1;
|
||
}
|
||
|
||
if updated == 0 {
|
||
ui::info(format!("All {} local model(s) are up to date.", up_to_date));
|
||
} else {
|
||
ui::info(format!("Updated {updated} local model(s)."));
|
||
}
|
||
}
|
||
|
||
Ok(())
|
||
}
|
||
|
||
#[cfg(test)]
mod tests {
    use super::*;
    use std::env;

    // NOTE(review): `set_var`/`remove_var` are `unsafe` because mutating the
    // process environment races with any concurrent reader. The test harness
    // runs tests on parallel threads. Each test here touches its own variable,
    // but that does not rule out another thread calling `getenv` at the same time.
    fn unset_env(key: impl AsRef<std::ffi::OsStr>) {
        // SAFETY: see the note above; test-only environment mutation.
        unsafe { env::remove_var(key) }
    }

    fn set_env(key: impl AsRef<std::ffi::OsStr>, value: impl AsRef<std::ffi::OsStr>) {
        // SAFETY: see the note above; test-only environment mutation.
        unsafe { env::set_var(key, value) }
    }

    /// A fully populated entry for serialization round-trips.
    fn sample_entry() -> ModelEntry {
        ModelEntry {
            name: "test".to_string(),
            file: "test.bin".to_string(),
            url: "https://example.com/test.bin".to_string(),
            size: Some(1024),
            sha256: Some("abc123".to_string()),
            last_modified: Some("2023-01-01T00:00:00Z".to_string()),
            base: "test".to_string(),
            variant: "default".to_string(),
        }
    }

    #[test]
    fn test_cache_bypass_environment() {
        unset_env(ConfigService::ENV_NO_CACHE_MANIFEST);
        assert!(!should_bypass_cache());

        set_env(ConfigService::ENV_NO_CACHE_MANIFEST, "1");
        assert!(should_bypass_cache());

        unset_env(ConfigService::ENV_NO_CACHE_MANIFEST);
    }

    #[test]
    fn test_cache_ttl_environment() {
        unset_env(ConfigService::ENV_MANIFEST_TTL_SECONDS);
        assert_eq!(
            get_cache_ttl(),
            ConfigService::DEFAULT_MANIFEST_CACHE_TTL_SECONDS
        );

        set_env(ConfigService::ENV_MANIFEST_TTL_SECONDS, "3600");
        assert_eq!(get_cache_ttl(), 3600);

        unset_env(ConfigService::ENV_MANIFEST_TTL_SECONDS);
    }

    #[test]
    fn test_cached_manifest_serialization() {
        let entries = vec![sample_entry()];
        let original = CachedManifest {
            fetched_at: 1234567890,
            etag: Some("etag123".to_string()),
            last_modified: Some("2023-01-01T00:00:00Z".to_string()),
            entries: entries.clone(),
        };

        let json = serde_json::to_string(&original).unwrap();
        let restored: CachedManifest = serde_json::from_str(&json).unwrap();

        assert_eq!(restored.fetched_at, original.fetched_at);
        assert_eq!(restored.etag, original.etag);
        assert_eq!(restored.last_modified, original.last_modified);
        assert_eq!(restored.entries.len(), entries.len());
        assert_eq!(restored.entries[0].name, entries[0].name);
    }
}
|