From 53119cd0ab435b819d7042d04288ee372eaa86c7 Mon Sep 17 00:00:00 2001 From: vikingowl Date: Wed, 13 Aug 2025 22:44:51 +0200 Subject: [PATCH] [refactor] enhance model management with metadata enrichment, new API integration, and manifest resolution --- crates/polyscribe-cli/src/main.rs | 1 - crates/polyscribe-core/Cargo.toml | 2 +- crates/polyscribe-core/src/models.rs | 1317 ++++++++++++++++++++------ 3 files changed, 1033 insertions(+), 287 deletions(-) diff --git a/crates/polyscribe-cli/src/main.rs b/crates/polyscribe-cli/src/main.rs index c2fcecc..43d784c 100644 --- a/crates/polyscribe-cli/src/main.rs +++ b/crates/polyscribe-cli/src/main.rs @@ -85,7 +85,6 @@ async fn main() -> Result<()> { .await .map_err(|e| anyhow!("blocking task join error: {e}"))? .context("updating models")?; - println!("Models updated."); } ModelsCmd::Download => { info!("interactive model selection and download"); diff --git a/crates/polyscribe-core/Cargo.toml b/crates/polyscribe-core/Cargo.toml index 11f4a34..bbeb518 100644 --- a/crates/polyscribe-core/Cargo.toml +++ b/crates/polyscribe-core/Cargo.toml @@ -15,7 +15,7 @@ libc = "0.2.175" whisper-rs = "0.14.3" indicatif = "0.17.11" # New: HTTP downloads + hashing -reqwest = { version = "0.12.7", default-features = false, features = ["blocking", "rustls-tls", "gzip"] } +reqwest = { version = "0.12.7", default-features = false, features = ["blocking", "rustls-tls", "gzip", "json"] } sha2 = "0.10.8" hex = "0.4.3" tempfile = "3.12.0" diff --git a/crates/polyscribe-core/src/models.rs b/crates/polyscribe-core/src/models.rs index 3fc16f9..9e2a46d 100644 --- a/crates/polyscribe-core/src/models.rs +++ b/crates/polyscribe-core/src/models.rs @@ -1,114 +1,560 @@ // SPDX-License-Identifier: MIT //! Model management for PolyScribe: discovery, download, and verification. +//! Fetches the live file table from Hugging Face, using size and sha256 +//! data for verification. Falls back to scraping the repository tree page +//! if the JSON API is unavailable or incomplete. No built-in manifest. use anyhow::{anyhow, Context, Result}; +use chrono::{DateTime, Utc}; +use hex::ToHex; use indicatif::{ProgressBar, ProgressStyle}; +use reqwest::blocking::Client; +use reqwest::header::{ + ACCEPT_RANGES, CONTENT_LENGTH, CONTENT_RANGE, ETAG, IF_RANGE, LAST_MODIFIED, RANGE, +}; +use serde::Deserialize; use sha2::{Digest, Sha256}; -use std::fs::{self, File}; +use std::collections::BTreeSet; +use std::fs::{self, File, OpenOptions}; use std::io::{Read, Write}; use std::path::{Path, PathBuf}; -use tempfile::NamedTempFile; +use std::thread; +use std::time::{Duration, Instant}; + +fn format_size_mb(size: Option) -> String { + match size { + Some(bytes) => { + let mib = bytes as f64 / 1024.0 / 1024.0; + format!("{mib:.2} MiB") + } + None => "? MiB".to_string(), + } +} + +fn format_size_gib(bytes: u64) -> String { + let gib = bytes as f64 / 1024.0 / 1024.0 / 1024.0; + format!("{gib:.2} GiB") +} + + +// Short date formatter (RFC -> yyyy-mm-dd) +fn short_date(s: &str) -> String { + DateTime::parse_from_rfc3339(s) + .ok() + .map(|dt| dt.with_timezone(&Utc).format("%Y-%m-%d").to_string()) + .unwrap_or_else(|| s.to_string()) +} + +// Free disk space using libc::statvfs (already in Cargo) +fn free_space_bytes_for_path(path: &Path) -> Result { + use libc::{statvfs, statvfs as statvfs_t}; + use std::ffi::CString; + + // use parent dir or current dir if none + let dir = if path.is_dir() { + path + } else { + path.parent().unwrap_or(Path::new(".")) + }; + + let cpath = CString::new(dir.as_os_str().to_string_lossy().as_bytes()) + .map_err(|_| anyhow!("invalid path for statvfs"))?; + unsafe { + let mut s: statvfs_t = std::mem::zeroed(); + if statvfs(cpath.as_ptr(), &mut s) != 0 { + return Err(anyhow!("statvfs failed for {}", dir.display())); + } + Ok((s.f_bsize as u64) * (s.f_bavail as u64)) + } +} + +// Minimal mirror note shown in single-line style +fn mirror_label(url: &str) -> &'static str { + // Very light heuristic; replace with your actual mirror selection if you have it + if url.contains("eu") { + "EU mirror" + } else if url.contains("us") { + "US mirror" + } else { + "source" + } +} + +// Helper: build a single progress bar with desired format +fn new_progress_bar(total: Option) -> ProgressBar { + let zero_or_none = matches!(total, None | Some(0)); + + let pb = if zero_or_none { + ProgressBar::new_spinner() + } else { + ProgressBar::new(total.unwrap()) + }; + pb.enable_steady_tick(Duration::from_millis(120)); + + // Use built-in byte and time placeholders + let style = if zero_or_none { + // No total known: show spinner, bytes so far, speed, and elapsed + ProgressStyle::with_template( + // Panel-ish spinner + "{spinner} {bytes:>9} @ {bytes_per_sec} | {elapsed} | {msg}" + ).unwrap_or_else(|_| ProgressStyle::default_spinner()) + } else { + // Total known: show bar, percent, bytes progress, speed, and ETA + ProgressStyle::with_template( + // Railcar + numeric focus + "{spinner:.green} [{elapsed_precise}] [{wide_bar:.cyan/blue}] {bytes}/{total_bytes} ({bytes_per_sec}, {eta})" + ) + .unwrap_or_else(|_| ProgressStyle::default_bar()).tick_strings(&[ + "▹▹▹▹▹", + "▸▹▹▹▹", + "▹▸▹▹▹", + "▹▹▸▹▹", + "▹▹▹▸▹", + "▹▹▹▹▸", + "▪▪▪▪▪", + ]) + .progress_chars("=>-") + }; + + pb.set_style(style); + pb +} + +// Perform a HEAD to get size/etag/last-modified and fill what we can +fn head_entry(client: &Client, url: &str) -> Result<(Option, Option, Option, bool)> { + let resp = client.head(url).send()?.error_for_status()?; + let len = resp + .headers() + .get(CONTENT_LENGTH) + .and_then(|v| v.to_str().ok()) + .and_then(|s| s.parse::().ok()); + let etag = resp + .headers() + .get(ETAG) + .and_then(|v| v.to_str().ok()) + .map(|s| s.trim_matches('"').to_string()); + let last_mod = resp + .headers() + .get(LAST_MODIFIED) + .and_then(|v| v.to_str().ok()) + .map(|s| s.to_string()); + let ranges_ok = resp + .headers() + .get(ACCEPT_RANGES) + .and_then(|v| v.to_str().ok()) + .map(|s| s.to_ascii_lowercase().contains("bytes")) + .unwrap_or(false); + Ok((len, etag, last_mod, ranges_ok)) +} /// Represents a downloadable Whisper model artifact. #[derive(Debug, Clone)] struct ModelEntry { - /// Display name and local short name (without extension if using default naming) - name: &'static str, + /// Display name and local short name (informational; may equal stem of file) + name: String, /// Remote file name (with extension) - file: &'static str, + file: String, /// Remote URL - url: &'static str, + url: String, /// Expected file size (optional) size: Option, /// Expected SHA-256 in hex (optional) - sha256: Option<&'static str>, + sha256: Option, + /// New: last modified timestamp string if available + last_modified: Option, + /// New: parsed base and variant for 2-step UI + base: String, + variant: String, } -/// Minimal built-in manifest. -/// You can extend this list or replace URLs to match your preferred source. -/// Large sizes/hashes are optional; leave None to skip checks. -fn builtin_manifest() -> Vec { - // Example URLs (Hugging Face). Replace or extend as needed. - // The filenames are typical GGUF/GGML whisper distributions. - vec![ - ModelEntry { - name: "tiny.en", - file: "ggml-tiny.en.bin", - url: "https://huggingface.co/ggerganov/whisper.cpp/resolve/main/ggml-tiny.en.bin?download=true", +// -------- Hugging Face API integration -------- + +#[derive(Debug, Deserialize)] +struct HfModelInfo { + // Returned sometimes at /api/models/{repo} + siblings: Option>, + // Returned when using `?expand=files` + files: Option>, +} + +#[derive(Debug, Deserialize)] +struct HfLfsInfo { + // Sometimes an "oid" like "sha256:" + oid: Option, + size: Option, + sha256: Option, +} + +#[derive(Debug, Deserialize)] +struct HfFile { + // Relative filename within repo (e.g., "ggml-tiny.bin") + rfilename: String, + // Size reported at top-level for non-LFS files; often present + size: Option, + // Some entries include sha256 at top level + sha256: Option, + // LFS metadata with size and possibly sha256 embedded + lfs: Option, + // New: last modified timestamp provided by HF API on expanded files + #[serde(rename = "lastModified")] + last_modified: Option, +} + +fn parse_base_variant(display_name: &str) -> (String, String) { + // display_name is name without ggml-/gguf- and without .bin + // Examples: + // - "tiny" -> base=tiny, variant=default + // - "tiny.en" -> base=tiny, variant=en + // - "base" -> base=base, variant=default + // - "large-v2" -> base=large, variant=v2 + // - "large-v3" -> base=large, variant=v3 + // - "medium" -> base=medium, variant=default + let mut variant = "default".to_string(); + + // Split off dot-based suffix (e.g., ".en") + let mut head = display_name; + if let Some((h, rest)) = display_name.split_once('.') { + head = h; + // if there is more than one dot, just keep everything after first as variant + variant = rest.to_string(); + } + + // Handle hyphenated versions like large-v2 + if let Some((b, v)) = head.split_once('-') { + return (b.to_string(), v.to_string()); + } + + (head.to_string(), variant) +} + +/// Build a manifest by calling the Hugging Face API for a repo. +/// Prefers the plain API URL, then retries with `?expand=files` if needed. +fn hf_repo_manifest_api(repo: &str) -> Result> { + let client = Client::builder() + .user_agent("polyscribe/0.1") + .build()?; + + // 1) Try the plain API you specified + let base = format!("https://huggingface.co/api/models/{}", repo); + let resp = client.get(&base).send()?; + let mut entries = if resp.status().is_success() { + let info: HfModelInfo = resp.json()?; + hf_info_to_entries(repo, info)? + } else { + Vec::new() + }; + + // 2) If empty, try with expand=files (some repos require this for full file listing) + if entries.is_empty() { + let url = format!("{base}?expand=files"); + let resp2 = client.get(&url).send()?; + if !resp2.status().is_success() { + return Err(anyhow!("HF API {} for {}", resp2.status(), url)); + } + let info: HfModelInfo = resp2.json()?; + entries = hf_info_to_entries(repo, info)?; + } + + if entries.is_empty() { + return Err(anyhow!("HF API returned no usable .bin files")); + } + Ok(entries) +} + +fn hf_info_to_entries(repo: &str, info: HfModelInfo) -> Result> { + let files = info.files.or(info.siblings).unwrap_or_default(); + let mut out = Vec::new(); + for f in files { + let fname = f.rfilename; + if !fname.ends_with(".bin") { + continue; + } + + // Derive a simple display name from the file stem + let stem = fname.strip_suffix(".bin").unwrap_or(&fname).to_string(); + let name_no_prefix = stem + .strip_prefix("ggml-") + .or_else(|| stem.strip_prefix("gguf-")) + .unwrap_or(&stem) + .to_string(); + + // Prefer explicit sha256; else try to parse from LFS oid "sha256:" + let sha_from_lfs = f.lfs.as_ref().and_then(|l| { + l.sha256.clone().or_else(|| { + l.oid + .as_ref() + .and_then(|oid| oid.strip_prefix("sha256:").map(|s| s.to_string())) + }) + }); + + let size = f.size.or_else(|| f.lfs.as_ref().and_then(|l| l.size)); + + let url = format!( + "https://huggingface.co/{}/resolve/main/{}?download=true", + repo, fname + ); + + let (base, variant) = parse_base_variant(&name_no_prefix); + + out.push(ModelEntry { + name: name_no_prefix, + file: fname, + url, + size, + sha256: sha_from_lfs.or(f.sha256), + last_modified: f.last_modified.clone(), + base, + variant, + }); + } + Ok(out) +} + +// -------- HTML scraping fallback (tree view) -------- + +/// Scrape the repository tree page when the API doesn't return a usable list. +/// Note: sizes and hashes are generally unavailable in this path. +fn scrape_tree_manifest(repo: &str) -> Result> { + let client = Client::builder() + .user_agent("polyscribe/0.1") + .build()?; + + let url = format!("https://huggingface.co/{}/tree/main?recursive=1", repo); + let resp = client.get(&url).send()?; + if !resp.status().is_success() { + return Err(anyhow!("tree page HTTP {} for {}", resp.status(), url)); + } + let html = resp.text()?; + + // Extract .bin paths from links. Match both blob/main and resolve/main. + // Example matches: + // - /{repo}/blob/main/ggml-base.en.bin + // - /{repo}/resolve/main/ggml-base.en.bin + let mut files = BTreeSet::new(); + for mat in html.match_indices(".bin") { + let end = mat.0 + 4; + let start = html[..end] + .rfind("/blob/main/") + .or_else(|| html[..end].rfind("/resolve/main/")); + if let Some(s) = start { + let slice = &html[s..end]; + if let Some(pos) = slice.find("/blob/main/") { + let path = slice[pos + "/blob/main/".len()..].to_string(); + files.insert(path); + } else if let Some(pos) = slice.find("/resolve/main/") { + let path = slice[pos + "/resolve/main/".len()..].to_string(); + files.insert(path); + } + } + } + + let mut out = Vec::new(); + for fname in files.into_iter().filter(|f| f.ends_with(".bin")) { + let stem = fname.strip_suffix(".bin").unwrap_or(&fname).to_string(); + let name = stem + .rsplit('/') + .next() + .unwrap_or(&stem) + .strip_prefix("ggml-") + .or_else(|| { + stem.rsplit('/') + .next() + .unwrap_or(&stem) + .strip_prefix("gguf-") + }) + .unwrap_or_else(|| stem.rsplit('/').next().unwrap_or(&stem)) + .to_string(); + + let url = format!( + "https://huggingface.co/{}/resolve/main/{}?download=true", + repo, fname + ); + + let (base, variant) = parse_base_variant(&name); + + out.push(ModelEntry { + name, + file: fname.rsplit('/').next().unwrap_or(&fname).to_string(), + url, size: None, sha256: None, - }, - ModelEntry { - name: "tiny", - file: "ggml-tiny.bin", - url: "https://huggingface.co/ggerganov/whisper.cpp/resolve/main/ggml-tiny.bin?download=true", - size: None, - sha256: None, - }, - ModelEntry { - name: "base.en", - file: "ggml-base.en.bin", - url: "https://huggingface.co/ggerganov/whisper.cpp/resolve/main/ggml-base.en.bin?download=true", - size: None, - sha256: None, - }, - ModelEntry { - name: "base", - file: "ggml-base.bin", - url: "https://huggingface.co/ggerganov/whisper.cpp/resolve/main/ggml-base.bin?download=true", - size: None, - sha256: None, - }, - ModelEntry { - name: "small.en", - file: "ggml-small.en.bin", - url: "https://huggingface.co/ggerganov/whisper.cpp/resolve/main/ggml-small.en.bin?download=true", - size: None, - sha256: None, - }, - ModelEntry { - name: "small", - file: "ggml-small.bin", - url: "https://huggingface.co/ggerganov/whisper.cpp/resolve/main/ggml-small.bin?download=true", - size: None, - sha256: None, - }, - ModelEntry { - name: "medium.en", - file: "ggml-medium.en.bin", - url: "https://huggingface.co/ggerganov/whisper.cpp/resolve/main/ggml-medium.en.bin?download=true", - size: None, - sha256: None, - }, - ModelEntry { - name: "medium", - file: "ggml-medium.bin", - url: "https://huggingface.co/ggerganov/whisper.cpp/resolve/main/ggml-medium.bin?download=true", - size: None, - sha256: None, - }, - ModelEntry { - name: "large-v2", - file: "ggml-large-v2.bin", - url: "https://huggingface.co/ggerganov/whisper.cpp/resolve/main/ggml-large-v2.bin?download=true", - size: None, - sha256: None, - }, - ModelEntry { - name: "large-v3", - file: "ggml-large-v3.bin", - url: "https://huggingface.co/ggerganov/whisper.cpp/resolve/main/ggml-large-v3.bin?download=true", - size: None, - sha256: None, - }, - ModelEntry { - name: "large-v3-turbo", - file: "ggml-large-v3-turbo.bin", - url: "https://huggingface.co/ggerganov/whisper.cpp/resolve/main/ggml-large-v3-turbo.bin?download=true", - size: None, - sha256: None, - }, - ] + last_modified: None, + base, + variant, + }); + } + + if out.is_empty() { + return Err(anyhow!("tree scraper found no .bin files")); + } + Ok(out) +} + +// -------- Metadata enrichment via HEAD (size/hash/last-modified) -------- + +fn parse_sha_from_header_value(s: &str) -> Option { + // Common HF patterns: + // - ETag: "SHA256:" + // - X-Linked-ETag: "SHA256:" + // - Sometimes weak etags: W/"SHA256:" + let lower = s.to_ascii_lowercase(); + if let Some(idx) = lower.find("sha256:") { + let tail = &lower[idx + "sha256:".len()..]; + let hex: String = tail.chars().take_while(|c| c.is_ascii_hexdigit()).collect(); + if !hex.is_empty() { + return Some(hex); + } + } + None +} + +fn enrich_entry_via_head(entry: &mut ModelEntry) -> Result<()> { + // If we already have everything, nothing to do + if entry.size.is_some() && entry.sha256.is_some() && entry.last_modified.is_some() { + return Ok(()); + } + + let client = Client::builder() + .user_agent("polyscribe/0.1") + .timeout(Duration::from_secs(8)) + .build()?; + + let mut head_url = entry.url.clone(); + if let Some((base, _)) = entry.url.split_once('?') { + head_url = base.to_string(); + } + + let started = Instant::now(); + crate::dlog!(1, "HEAD {}", head_url); + + let resp = client + .head(&head_url) + .send() + .or_else(|_| client.head(&entry.url).send())?; + + if !resp.status().is_success() { + crate::dlog!(1, "HEAD {} -> HTTP {}", head_url, resp.status()); + return Ok(()); + } + + let mut filled_size = false; + let mut filled_sha = false; + let mut filled_lm = false; + + // Content-Length + if entry.size.is_none() { + if let Some(sz) = resp + .headers() + .get(CONTENT_LENGTH) + .and_then(|v| v.to_str().ok()) + .and_then(|s| s.parse::().ok()) + { + entry.size = Some(sz); + filled_size = true; + } + } + + // SHA256 from headers if available + if entry.sha256.is_none() { + if let Some(v) = resp.headers().get("x-linked-etag").and_then(|v| v.to_str().ok()) { + if let Some(hex) = parse_sha_from_header_value(v) { + entry.sha256 = Some(hex); + filled_sha = true; + } + } + if !filled_sha { + if let Some(v) = resp + .headers() + .get(ETAG) + .and_then(|v| v.to_str().ok()) + { + if let Some(hex) = parse_sha_from_header_value(v) { + entry.sha256 = Some(hex); + filled_sha = true; + } + } + } + } + + // Last-Modified + if entry.last_modified.is_none() { + if let Some(v) = resp + .headers() + .get(LAST_MODIFIED) + .and_then(|v| v.to_str().ok()) + { + entry.last_modified = Some(v.to_string()); + filled_lm = true; + } + } + + let elapsed_ms = started.elapsed().as_millis(); + crate::dlog!( + 1, + "HEAD ok in {} ms for {} (size: {}, sha256: {}, last-modified: {})", + elapsed_ms, + entry.file, + if filled_size { "new" } else { if entry.size.is_some() { "kept" } else { "missing" } }, + if filled_sha { "new" } else { if entry.sha256.is_some() { "kept" } else { "missing" } }, + if filled_lm { "new" } else { if entry.last_modified.is_some() { "kept" } else { "missing" } }, + ); + + Ok(()) +} + +// -------- Online manifest (API first, then scrape) -------- + +/// Returns the current manifest (online only). +fn current_manifest() -> Result> { + let started = Instant::now(); + crate::dlog!(1, "Fetching HF manifest…"); + + // 1) Load from API, else scrape + let mut list = match hf_repo_manifest_api("ggerganov/whisper.cpp") { + Ok(list) if !list.is_empty() => { + crate::dlog!(1, "Manifest loaded from HF API ({} entries)", list.len()); + list + } + _ => { + crate::ilog!("Falling back to scraping the repository tree page"); + let scraped = scrape_tree_manifest("ggerganov/whisper.cpp")?; + crate::dlog!(1, "Manifest loaded via scrape ({} entries)", scraped.len()); + scraped + } + }; + + // 2) Enrich missing metadata so the UI can show sizes and hashes + let mut need_enrich = 0usize; + for m in &list { + if m.size.is_none() || m.sha256.is_none() || m.last_modified.is_none() { + need_enrich += 1; + } + } + crate::dlog!(1, "Enriching {} entries via HEAD…", need_enrich); + for m in &mut list { + if m.size.is_none() || m.sha256.is_none() || m.last_modified.is_none() { + let _ = enrich_entry_via_head(m); + } + } + + let elapsed_ms = started.elapsed().as_millis(); + let sizes_known = list.iter().filter(|m| m.size.is_some()).count(); + let hashes_known = list.iter().filter(|m| m.sha256.is_some()).count(); + crate::dlog!( + 1, + "Manifest ready in {} ms (entries: {}, sizes: {}/{}, hashes: {}/{})", + elapsed_ms, + list.len(), + sizes_known, + list.len(), + hashes_known, + list.len() + ); + + if list.is_empty() { + return Err(anyhow!("no usable .bin files discovered")); + } + Ok(list) } /// Pick the best local Whisper model in the given directory. @@ -128,72 +574,93 @@ pub fn pick_best_local_model(dir: &Path) -> Option { .map(|(_, p)| p) } -/// Ensure a model file with the given short name exists locally (non-interactive). -/// It uses the built-in manifest to find URL and optionally verify size/hash. +/// Returns the directory where models should be stored based on platform conventions. +fn resolve_models_dir() -> Result { + let dirs = directories::ProjectDirs::from("org", "polyscribe", "polyscribe") + .ok_or_else(|| anyhow!("could not determine platform directories"))?; + let data_dir = dirs.data_dir().join("models"); + Ok(data_dir) +} + +// Example of a non-interactive path ensuring a given model by name exists, with improved copy. +// Wire this into CLI flags as needed. +/// Ensures a model is available by name, downloading it if necessary. +/// This is a non-interactive version that doesn't prompt the user. +/// +/// # Arguments +/// * `name` - Name of the model to ensure is available +/// +/// # Returns +/// * `Result` - Path to the downloaded model file on success pub fn ensure_model_available_noninteractive(name: &str) -> Result { - let Some(entry) = find_manifest_entry(name) else { - return Err(anyhow!("unknown model name: {name}")); - }; + let entry = find_manifest_entry(name)? + .ok_or_else(|| anyhow!("unknown model: {name}"))?; - let models_dir = crate::models_dir_path(); - if !models_dir.exists() { - fs::create_dir_all(&models_dir).with_context(|| { - format!("Failed to create models dir: {}", models_dir.display()) - })?; - } - let path = models_dir.join(entry.file); + // Resolve destination file path; ensure XDG path (or your existing logic) + let dir = resolve_models_dir()?; // implement or reuse your existing directory resolver + fs::create_dir_all(&dir).ok(); + let dest = dir.join(&entry.file); - // If exists and passes checks, return early - if path.exists() { - if file_matches(&path, entry.size, entry.sha256)? { - return Ok(path); - } - // Otherwise redownload - crate::ilog!( - "Existing model '{}' did not match expected checks; re-downloading.", - entry.name - ); - fs::remove_file(&path).ok(); + // If already matches, early return + if file_matches(&dest, entry.size, entry.sha256.as_deref())? { + println!("Already up to date: {}", dest.display()); + return Ok(dest); } - // Download with progress to a temp file then atomically move. - download_with_progress(&path, &entry) - .with_context(|| format!("downloading {} from {}", entry.file, entry.url))?; + // Single-line header + let base = &entry.base; + let variant = &entry.variant; + let size_str = format_size_mb(entry.size); + println!("Base: {base} • Type: {variant}"); + println!( + "Source: {} • Size: {}", + mirror_label(&entry.url), + size_str + ); - // Final verification - if !file_matches(&path, entry.size, entry.sha256)? { - return Err(anyhow!( - "downloaded file failed verification: {}", - path.display() - )); - } - - Ok(path) + download_with_progress(&dest, &entry)?; + Ok(dest) } -fn find_manifest_entry(name: &str) -> Option { - // Accept either the short names in `name` field or a direct file name - // For unknown suffixes, attempt stripping ".bin" - let name_no_ext = name.strip_suffix(".bin").unwrap_or(name); - for e in builtin_manifest() { - if e.name.eq_ignore_ascii_case(name_no_ext) || e.file.eq_ignore_ascii_case(name) { - return Some(e); +fn find_manifest_entry(name: &str) -> Result> { + // Accept either manifest display name, file stem, or direct file name. + // Normalize: strip ".bin" for comparisons and also handle input that already includes it. + let wanted_name = name + .strip_suffix(".bin") + .unwrap_or(name) + .to_ascii_lowercase(); + let wanted_file = name.to_ascii_lowercase(); + + for e in current_manifest()? { + let file_lc = e.file.to_ascii_lowercase(); + let stem_lc = e + .file + .strip_suffix(".bin") + .unwrap_or(&e.file) + .to_ascii_lowercase(); + if e.name.to_ascii_lowercase() == wanted_name + || file_lc == wanted_file + || stem_lc == wanted_name + { + return Ok(Some(e)); } } - None + Ok(None) } +// Return true if the file at `path` matches expected size and/or sha256 (when provided). +// - If sha256 is provided, verify it (preferred). +// - Else if size is provided, check size. +// - If neither provided, return false (cannot verify). fn file_matches(path: &Path, size: Option, sha256_hex: Option<&str>) -> Result { - let md = fs::metadata(path).with_context(|| format!("stat {}", path.display()))?; - if let Some(sz) = size { - if md.len() != sz { - return Ok(false); - } + if !path.exists() { + return Ok(false); } - if let Some(expected_hex) = sha256_hex { - let mut f = File::open(path)?; + + if let Some(exp_hash) = sha256_hex { + let mut f = File::open(path).with_context(|| format!("opening {}", path.display()))?; let mut hasher = Sha256::new(); - let mut buf = [0u8; 128 * 1024]; + let mut buf = vec![0u8; 1024 * 1024]; loop { let n = f.read(&mut buf)?; if n == 0 { @@ -201,110 +668,220 @@ fn file_matches(path: &Path, size: Option, sha256_hex: Option<&str>) -> Res } hasher.update(&buf[..n]); } - let got = hasher.finalize(); - let got_hex = hex::encode(got); - if !got_hex.eq_ignore_ascii_case(expected_hex) { - return Ok(false); - } + let actual = hasher.finalize(); + let actual_hex = actual.encode_hex::(); + return Ok(actual_hex.eq_ignore_ascii_case(exp_hash)); } - Ok(true) + + if let Some(expected) = size { + let meta = fs::metadata(path).with_context(|| format!("stat {}", path.display()))?; + return Ok(meta.len() == expected); + } + + Ok(false) } +// Download with: +// - Free-space preflight (size * 1.1 overhead). +// - Resume via Range if .part exists and server supports it. +// - Atomic write: download to .part (temp) then rename. +// - Checksum verification when available. +// - Single-line progress UI. fn download_with_progress(dest_path: &Path, entry: &ModelEntry) -> Result<()> { - let client = reqwest::blocking::Client::builder() - .user_agent("polyscribe/0.1") + let url = &entry.url; + let client = Client::builder() + .user_agent("polyscribe-model-downloader/1") .build()?; - crate::ilog!("Downloading {} …", entry.file); - let mut resp = client.get(entry.url).send()?; - if !resp.status().is_success() { - return Err(anyhow!("HTTP {} for {}", resp.status(), entry.url)); + println!("Resolving source: {} ({})", mirror_label(url), url); + + // HEAD for size/etag/ranges + let (mut total_len, remote_etag, _remote_last_mod, ranges_ok) = head_entry(&client, url) + .context("probing remote file")?; + + if total_len.is_none() { + total_len = entry.size; } - let total_len = resp - .headers() - .get(reqwest::header::CONTENT_LENGTH) - .and_then(|v| v.to_str().ok()) - .and_then(|s| s.parse::().ok()) - .or(entry.size); - - // TTY-aware progress - let pb = if !crate::is_quiet() && !crate::is_no_progress() && crate::stdin_is_tty() { - let bar = ProgressBar::new(total_len.unwrap_or(0)); - bar.set_style( - ProgressStyle::with_template("{bar:40.cyan/blue} {bytes}/{total_bytes} {msg}") - .unwrap() - .progress_chars("##-"), - ); - if let Some(t) = total_len { - bar.set_length(t); - } - Some(bar) - } else { - None - }; - - let mut out_tmp = NamedTempFile::new_in( - dest_path - .parent() - .ok_or_else(|| anyhow!("invalid destination path"))?, - )?; - let mut hasher = Sha256::new(); - let mut written: u64 = 0; - - // Read response body in chunks using a buffer - let mut buffer = [0u8; 8192]; // 8KB buffer for reading chunks - loop { - let bytes_read = resp.read(&mut buffer)?; - if bytes_read == 0 { - break; - } - let chunk = &buffer[..bytes_read]; - out_tmp.write_all(chunk)?; - if entry.sha256.is_some() { - hasher.update(chunk); - } - written += bytes_read as u64; - if let Some(ref bar) = pb { - if let Some(total) = total_len { - bar.set_position(written.min(total)); - } else { - bar.set_message(format!("{:.1} MB", (written as f64) / 1_000_000.0)); - } - } - } - - if let Some(sz) = entry.size { - if written != sz { + if let Some(expected) = total_len { + let free = free_space_bytes_for_path(dest_path)?; + let need = expected + (expected / 10) + 16 * 1024 * 1024; + if free < need { return Err(anyhow!( - "incomplete download: expected {} bytes, got {}", - sz, - written + "insufficient disk space: need {}, have {}", + format_size_mb(Some(need)), + format_size_gib(free) )); } } - if let Some(expected_hex) = entry.sha256 { - let got_hex = hex::encode(hasher.finalize()); - if !got_hex.eq_ignore_ascii_case(expected_hex) { - return Err(anyhow!("SHA-256 mismatch for {}", entry.file)); + if dest_path.exists() { + if file_matches(dest_path, total_len, entry.sha256.as_deref())? { + println!("Already up to date: {}", dest_path.display()); + return Ok(()); } } - out_tmp - .persist(dest_path) - .with_context(|| format!("persist {}", dest_path.display()))?; + let part_path = dest_path.with_extension("part"); - if let Some(bar) = pb { - bar.finish_with_message("done"); + let mut resume_from: u64 = 0; + if part_path.exists() && ranges_ok { + let meta = fs::metadata(&part_path)?; + resume_from = meta.len(); } + + let mut part_file = OpenOptions::new() + .create(true) + .write(true) + .read(true) + .append(true) + .open(&part_path) + .with_context(|| format!("opening {}", part_path.display()))?; + + // Build request: + // - Fresh download: plain GET (no If-None-Match). + // - Resume: Range + optional If-Range with ETag. + let mut req = client.get(url); + if ranges_ok && resume_from > 0 { + req = req.header(RANGE, format!("bytes={resume_from}-")); + if let Some(etag) = &remote_etag { + req = req.header(IF_RANGE, format!("\"{etag}\"")); + } + } + + println!("Download: {}", part_path.display()); + + let pb_total = total_len.unwrap_or(0); + let pb = if pb_total > 0 { + let pb = new_progress_bar(Some(pb_total)); + pb.set_position(resume_from); + pb + } else { + new_progress_bar(None) + }; + + let start = Instant::now(); + let mut resp = req.send()?.error_for_status()?; + + // Defensive: if server returns 304 but we don't have a valid cached copy, retry without conditionals. + if resp.status().as_u16() == 304 && resume_from == 0 { + // Fresh download must not be conditional; redo as plain GET + let mut req2 = client.get(url); + resp = req2.send()?.error_for_status()?; + } + + // If server ignored RANGE and returned full body, reset partial + let is_partial_response = resp.headers().get(CONTENT_RANGE).is_some(); + if resume_from > 0 && !is_partial_response { + // Server did not honor range → start over + drop(part_file); + fs::remove_file(&part_path).ok(); + resume_from = 0; + + // Plain GET without conditional headers + let mut req2 = client.get(url); + resp = req2.send()?.error_for_status()?; + pb.set_position(0); + + // Reopen the part file since we dropped it + part_file = OpenOptions::new() + .create(true) + .write(true) + .read(true) + .append(true) + .open(&part_path) + .with_context(|| format!("opening {}", part_path.display()))?; + } + + { + let mut body = resp; + let mut buf = vec![0u8; 1024 * 64]; + loop { + let read = body.read(&mut buf)?; + if read == 0 { + break; + } + part_file.write_all(&buf[..read])?; + if pb_total > 0 { + let pos = part_file.metadata()?.len(); + pb.set_position(pos); + } else { + pb.inc(read as u64); + } + } + part_file.flush()?; + part_file.sync_all()?; + } + + pb.finish_and_clear(); + + if let Some(expected_hex) = entry.sha256.as_deref() { + println!("Verify: SHA-256"); + let mut f = File::open(&part_path)?; + let mut hasher = Sha256::new(); + let mut buf = vec![0u8; 1024 * 1024]; + loop { + let n = f.read(&mut buf)?; + if n == 0 { + break; + } + hasher.update(&buf[..n]); + } + let actual_hex = hasher.finalize().encode_hex::(); + if !actual_hex.eq_ignore_ascii_case(expected_hex) { + return Err(anyhow!( + "checksum mismatch: expected {}, got {}", + expected_hex, + actual_hex + )); + } + } else { + println!("Verify: checksum not provided by source (skipped)"); + } + + if let Some(parent) = dest_path.parent() { + fs::create_dir_all(parent).ok(); + } + drop(part_file); + fs::rename(&part_path, dest_path) + .with_context(|| format!("renaming {} → {}", part_path.display(), dest_path.display()))?; + + let final_size = fs::metadata(dest_path).map(|m| m.len()).ok(); + let elapsed = start.elapsed().as_secs_f64(); + + if let Some(sz) = final_size { + if elapsed > 0.0 { + let mib = sz as f64 / 1024.0 / 1024.0; + let rate = mib / elapsed; + println!( + "✔ Saved: {} ({}) in {:.1}s, {:.1} MiB/s", + dest_path.display(), + format_size_mb(Some(sz)), + elapsed, + rate + ); + } else { + println!( + "✔ Saved: {} ({})", + dest_path.display(), + format_size_mb(Some(sz)) + ); + } + } else { + println!( + "✔ Saved: {} ({})", + dest_path.display(), + format_size_mb(None) + ); + } + Ok(()) } -/// Run an interactive model downloader UI. -/// - Lists models from the built-in manifest -/// - Prompts for selection -/// - Downloads selected models with verification +/// Run an interactive model downloader UI (2-step): +/// 1) Choose model base (tiny, small, base, medium, large) +/// 2) Choose model type/variant specific to that base +/// Displays meta info (size and last updated). Does not show raw ggml filenames. pub fn run_interactive_model_downloader() -> Result<()> { use crate::ui; @@ -313,60 +890,151 @@ pub fn run_interactive_model_downloader() -> Result<()> { return Ok(()); } - let available = builtin_manifest(); + let available = current_manifest()?; + + use std::collections::BTreeMap; + let mut by_base: BTreeMap> = BTreeMap::new(); + for m in available.into_iter() { + by_base.entry(m.base.clone()).or_default().push(m); + } + + let pref_order = ["tiny", "small", "base", "medium", "large"]; + let mut ordered_bases: Vec = Vec::new(); + for b in pref_order { + if by_base.contains_key(b) { + ordered_bases.push(b.to_string()); + } + } + for b in by_base.keys() { + if !ordered_bases.iter().any(|x| x == b) { + ordered_bases.push(b.clone()); + } + } ui::intro("PolyScribe model downloader"); - ui::info("Select one or more models to download. Enter comma-separated numbers (e.g., 1,3,4). Press Enter to accept default [1]."); - ui::println_above_bars("Available models:"); - for (i, m) in available.iter().enumerate() { - ui::println_above_bars(format!(" {}. {:<18} – {}", i + 1, m.name, m.file)); - } - - let answer = ui::prompt_input("Your selection", Some("1"))?; - let selection_raw = if answer.trim().is_empty() { - "1".to_string() - } else { - answer.trim().to_string() - }; - let selection = if selection_raw.is_empty() { "1" } else { &selection_raw }; - - use std::collections::BTreeSet; - let mut picked_set: BTreeSet = BTreeSet::new(); - for part in selection.split([',', ' ', ';']) { - let t = part.trim(); - if t.is_empty() { - continue; - } - match t.parse::() { - Ok(n) if (1..=available.len()).contains(&n) => { - picked_set.insert(n - 1); + ui::println_above_bars("Select a model base:"); + for (i, base) in ordered_bases.iter().enumerate() { + let variants = &by_base[base]; + let (min_sz, max_sz) = variants.iter().fold((None, None), |acc, m| { + let (mut lo, mut hi) = acc; + if let Some(sz) = m.size { + lo = Some(lo.map(|v: u64| v.min(sz)).unwrap_or(sz)); + hi = Some(hi.map(|v: u64| v.max(sz)).unwrap_or(sz)); } - _ => ui::warn(format!("Ignoring invalid selection: '{t}'")), - } - } - let mut picked_indices: Vec = picked_set.into_iter().collect(); - if picked_indices.is_empty() { - picked_indices.push(0); + (lo, hi) + }); + let size_info = match (min_sz, max_sz) { + (Some(lo), Some(hi)) if lo != hi => format!( + " ~{:.2}–{:.2} MB", + lo as f64 / 1_000_000.0, + hi as f64 / 1_000_000.0 + ), + (Some(sz), _) => format!(" ~{:.2} MB", sz as f64 / 1_000_000.0), + _ => "".to_string(), + }; + ui::println_above_bars(format!( + " {}. {} ({:>2} types){}", + i + 1, + base, + variants.len(), + size_info + )); } - // Progress display (per-file style from UI) - let labels: Vec = picked_indices - .iter() - .map(|&i| available[i].name.to_string()) - .collect(); + let base_ans = ui::prompt_input("Base [1]", Some("1"))?; + // Robust: accept either index (1-based) or base name + let base_idx = base_ans.trim().parse::().ok(); + let chosen_base = if let Some(idx) = base_idx { + if idx == 0 || idx > ordered_bases.len() { + return Err(anyhow!("invalid base selection")); + } + ordered_bases[idx - 1].clone() + } else { + // Match by name, case-insensitive + let ans = base_ans.trim().to_ascii_lowercase(); + let pos = ordered_bases + .iter() + .position(|b| b.eq_ignore_ascii_case(&ans)) + .ok_or_else(|| anyhow!("invalid base selection"))?; + ordered_bases[pos].clone() + }; + + let mut variants = by_base.remove(&chosen_base).unwrap_or_default(); + + // Sort variants by a friendly order: default, en, then others alphabetically + variants.sort_by(|a, b| { + let rank = |v: &str| match v { + "default" => 0, + "en" => 1, + _ => 2, + }; + rank(&a.variant) + .cmp(&rank(&b.variant)) + .then_with(|| a.variant.cmp(&b.variant)) + }); + + ui::println_above_bars(format!("Select a type for '{}':", chosen_base)); + for (i, m) in variants.iter().enumerate() { + let size = format_size_mb(m.size.as_ref().copied()); + let updated = m + .last_modified + .as_deref() + .map(short_date) + .map(|d| format!(" • updated {}", d)) + .unwrap_or_default(); + let variant_label = if m.variant == "default" { + "default" + } else { + &m.variant + }; + ui::println_above_bars(format!( + " {}. {} ({}{})", + i + 1, + variant_label, + size, + updated + )); + } + + let type_ans = ui::prompt_input("Type [1]", Some("1"))?; + let type_idx = type_ans + .trim() + .parse::() + .ok() + .filter(|n| *n >= 1 && *n <= variants.len()) + .or_else(|| { + // Optional: allow typing the variant name + let ans = type_ans.trim().to_ascii_lowercase(); + variants + .iter() + .position(|m| { + let v = if m.variant == "default" { + "default" + } else { + &m.variant + }; + v.eq_ignore_ascii_case(&ans) + }) + .map(|i| i + 1) + }) + .ok_or_else(|| anyhow!("invalid type selection"))?; + + let picked = variants[type_idx - 1].clone(); + + fn entry_label(entry: &ModelEntry) -> String { + format!("{} ({})", entry.name, format_size_mb(entry.size)) + } + + let labels = vec![entry_label(&picked)]; let mut pm = ui::progress::ProgressManager::default_for_files(labels.len()); pm.init_files(&labels); - - for (i, idx) in picked_indices.iter().enumerate() { - let model = &available[*idx]; - if let Some(pb) = pm.per_bar(i) { - pb.set_message("downloading"); - } - let path = ensure_model_available_noninteractive(model.name)?; - ui::println_above_bars(format!("Ready: {}", path.display())); - pm.mark_file_done(i); + if let Some(pb) = pm.per_bar(0) { + pb.set_message("downloading"); } + let path = ensure_model_available_noninteractive(&picked.name)?; + ui::println_above_bars(format!("Ready: {}", path.display())); + pm.mark_file_done(0); if let Some(total) = pm.total_bar() { total.finish_with_message("all done"); } @@ -374,35 +1042,114 @@ pub fn run_interactive_model_downloader() -> Result<()> { Ok(()) } -/// Verify/update local models by comparing with the built-in manifest. +/// Verify/update local models by comparing with the online manifest. /// - If a model file exists and matches expected size/hash (when provided), it is kept. /// - If missing or mismatched, it will be downloaded. pub fn update_local_models() -> Result<()> { use crate::ui; + use std::collections::HashMap; - let manifest = builtin_manifest(); + let manifest = current_manifest()?; let dir = crate::models_dir_path(); fs::create_dir_all(&dir).ok(); - ui::info("Checking local models against manifest…"); - let mut fixed = 0usize; + ui::info("Checking locally available models, then verifying against the online manifest…"); + // Index manifest by filename and by stem/display name for matching. + let mut by_file: HashMap = HashMap::new(); + let mut by_stem_or_name: HashMap = HashMap::new(); for m in manifest { - let path = dir.join(m.file); - let ok = path.exists() && file_matches(&path, m.size, m.sha256)?; - if ok { - crate::dlog!(1, "OK: {}", path.display()); - continue; - } - crate::ilog!("Updating {}", m.name); - download_with_progress(&path, &m.clone())?; - fixed += 1; + by_file.insert(m.file.to_ascii_lowercase(), m.clone()); + let stem = m + .file + .strip_suffix(".bin") + .unwrap_or(&m.file) + .to_ascii_lowercase(); + by_stem_or_name.insert(stem, m.clone()); + by_stem_or_name.insert(m.name.to_ascii_lowercase(), m); } - if fixed == 0 { - ui::info("All models are up to date."); + let mut updated = 0usize; + let mut up_to_date = 0usize; + + // Enumerate only local .bin files. + let rd = fs::read_dir(&dir).with_context(|| format!("reading models dir {}", dir.display()))?; + let entries: Vec<_> = rd.flatten().collect(); + + if entries.len() == 0 { + ui::info("No local models found.".to_string()); } else { - ui::info(format!("Updated {fixed} model(s).")); + for ent in entries { + let path = ent.path(); + if !path.is_file() { + continue; + } + let is_bin = path + .extension() + .and_then(|s| s.to_str()) + .is_some_and(|s| s.eq_ignore_ascii_case("bin")); + if !is_bin { + continue; + } + + let file_name = match path.file_name().and_then(|s| s.to_str()) { + Some(s) => s.to_string(), + None => continue, + }; + let file_lc = file_name.to_ascii_lowercase(); + let stem_lc = file_lc.strip_suffix(".bin").unwrap_or(&file_lc).to_string(); + + // Try to find a matching manifest entry for this local file. + let mut manifest_entry = by_file + .get(&file_lc) + .or_else(|| by_stem_or_name.get(&stem_lc)) + .cloned(); + + let Some(mut m) = manifest_entry.take() else { + ui::warn(format!( + "Skipping unknown local model (not in online manifest): {}", + path.display() + )); + continue; + }; + + // Enrich metadata before verification (helps when API lacked size/hash) + let _ = enrich_entry_via_head(&mut m); + + // Determine target filename from manifest; if different, download to the canonical name. + let target_path = if m.file.eq_ignore_ascii_case(&file_name) { + path.clone() + } else { + dir.join(&m.file) + }; + + // If the target already exists and matches (size/hash when available), it is up-to-date. + if target_path.exists() && file_matches(&target_path, m.size, m.sha256.as_deref())? { + crate::dlog!(1, "OK: {}", target_path.display()); + up_to_date += 1; + continue; + } + + // If the current file is the same as the target and mismatched, remove before re-download. + if target_path == path && target_path.exists() { + crate::ilog!("Updating {}", file_name); + let _ = fs::remove_file(&target_path); + } else if !target_path.exists() { + crate::ilog!("Fetching latest for '{}' -> {}", file_name, m.file); + } else { + crate::ilog!("Refreshing {}", target_path.display()); + } + + download_with_progress(&target_path, &m)?; + updated += 1; + } + + if updated == 0 { + ui::info(format!("All {} local model(s) are up to date.", up_to_date)); + } else { + ui::info(format!("Updated {updated} local model(s).")); + } } + Ok(()) }