From 0128bf2eecff726bcc25029d91811bb2b01aa516 Mon Sep 17 00:00:00 2001 From: vikingowl Date: Wed, 27 Aug 2025 20:56:05 +0200 Subject: [PATCH] [feat] add `ModelManager` with caching, manifest management, and Hugging Face API integration --- Cargo.lock | 1 + Cargo.toml | 2 +- crates/polyscribe-cli/src/cli.rs | 67 +- crates/polyscribe-cli/src/main.rs | 296 ++++++- crates/polyscribe-core/src/lib.rs | 2 + crates/polyscribe-core/src/model_manager.rs | 893 ++++++++++++++++++++ crates/polyscribe-core/src/models.rs | 8 + crates/polyscribe-core/src/ui.rs | 149 +++- 8 files changed, 1347 insertions(+), 71 deletions(-) create mode 100644 crates/polyscribe-core/src/model_manager.rs diff --git a/Cargo.lock b/Cargo.lock index 3536986..5033e3f 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -251,6 +251,7 @@ dependencies = [ "iana-time-zone", "js-sys", "num-traits", + "serde", "wasm-bindgen", "windows-link", ] diff --git a/Cargo.toml b/Cargo.toml index dc2e186..693c444 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -21,7 +21,7 @@ anyhow = "1.0.99" libc = "0.2.175" toml = "0.8.23" serde_json = "1.0.142" -chrono = "0.4.41" +chrono = { version = "0.4.41", features = ["serde"] } sha2 = "0.10.9" which = "6.0.3" tokio = { version = "1.47.1", features = ["rt-multi-thread", "macros"] } diff --git a/crates/polyscribe-cli/src/cli.rs b/crates/polyscribe-cli/src/cli.rs index 38f9631..201e0ef 100644 --- a/crates/polyscribe-cli/src/cli.rs +++ b/crates/polyscribe-cli/src/cli.rs @@ -74,7 +74,7 @@ pub enum Commands { inputs: Vec, }, - /// Manage Whisper models + /// Manage Whisper GGUF models (Hugging Face) Models { #[command(subcommand)] cmd: ModelsCmd, @@ -97,14 +97,67 @@ pub enum Commands { Man, } +#[derive(Debug, Clone, Parser)] +pub struct ModelCommon { + /// Concurrency for ranged downloads + #[arg(long, default_value_t = 4)] + pub concurrency: usize, + /// Limit download rate in bytes/sec (approximate) + #[arg(long)] + pub limit_rate: Option, + /// Emit machine JSON output + #[arg(long, default_value_t = false)] + pub json: bool, +} + #[derive(Debug, Subcommand)] pub enum ModelsCmd { - /// Verify or update local models non-interactively - Update, - /// Interactive multi-select downloader - Download, - /// Clear the cached Hugging Face manifest - ClearCache, + /// List installed models (from manifest) + Ls { + #[command(flatten)] + common: ModelCommon, + }, + /// Add or update a model + Add { + /// Hugging Face repo, e.g. ggml-org/models + repo: String, + /// File name in repo (e.g., gguf-tiny-q4_0.bin) + file: String, + #[command(flatten)] + common: ModelCommon, + }, + /// Remove a model by alias + Rm { + alias: String, + #[command(flatten)] + common: ModelCommon, + }, + /// Verify model file integrity by alias + Verify { + alias: String, + #[command(flatten)] + common: ModelCommon, + }, + /// Update all models (HEAD + ETag; skip if unchanged) + Update { + #[command(flatten)] + common: ModelCommon, + }, + /// Garbage-collect unreferenced files and stale manifest entries + Gc { + #[command(flatten)] + common: ModelCommon, + }, + /// Search a repo for GGUF files + Search { + /// Hugging Face repo, e.g. 
ggml-org/models + repo: String, + /// Optional substring to filter filenames + #[arg(long)] + query: Option, + #[command(flatten)] + common: ModelCommon, + }, } #[derive(Debug, Subcommand)] diff --git a/crates/polyscribe-cli/src/main.rs b/crates/polyscribe-cli/src/main.rs index cf236bf..2ee4389 100644 --- a/crates/polyscribe-cli/src/main.rs +++ b/crates/polyscribe-cli/src/main.rs @@ -2,8 +2,49 @@ mod cli; use anyhow::{Context, Result, anyhow}; use clap::{CommandFactory, Parser}; -use cli::{Cli, Commands, GpuBackend, ModelsCmd, PluginsCmd}; -use polyscribe_core::models; +use cli::{Cli, Commands, GpuBackend, ModelsCmd, ModelCommon, PluginsCmd}; +use polyscribe_core::model_manager::{ModelManager, Settings, ReqwestClient}; +use polyscribe_core::ui; +fn normalized_similarity(a: &str, b: &str) -> f64 { + // simple Levenshtein distance; normalized to [0,1] + let a_bytes = a.as_bytes(); + let b_bytes = b.as_bytes(); + let n = a_bytes.len(); + let m = b_bytes.len(); + if n == 0 && m == 0 { return 1.0; } + if n == 0 || m == 0 { return 0.0; } + let mut prev: Vec = (0..=m).collect(); + let mut curr: Vec = vec![0; m + 1]; + for i in 1..=n { + curr[0] = i; + for j in 1..=m { + let cost = if a_bytes[i - 1] == b_bytes[j - 1] { 0 } else { 1 }; + curr[j] = (prev[j] + 1) + .min(curr[j - 1] + 1) + .min(prev[j - 1] + cost); + } + std::mem::swap(&mut prev, &mut curr); + } + let dist = prev[m] as f64; + let max_len = n.max(m) as f64; + 1.0 - (dist / max_len) +} + +fn human_size(bytes: Option) -> String { + match bytes { + Some(n) => { + let x = n as f64; + const KB: f64 = 1024.0; + const MB: f64 = 1024.0 * KB; + const GB: f64 = 1024.0 * MB; + if x >= GB { format!("{:.2} GiB", x / GB) } + else if x >= MB { format!("{:.2} MiB", x / MB) } + else if x >= KB { format!("{:.2} KiB", x / KB) } + else { format!("{} B", n) } + } + None => "?".to_string(), + } +} use polyscribe_core::ui::progress::ProgressReporter; use polyscribe_host::PluginManager; use tokio::io::AsyncWriteExt; @@ -29,8 +70,7 @@ fn init_tracing(quiet: bool, verbose: u8) { .init(); } -#[tokio::main] -async fn main() -> Result<()> { +fn main() -> Result<()> { let args = Cli::parse(); init_tracing(args.quiet, args.verbose); @@ -71,32 +111,188 @@ async fn main() -> Result<()> { } Commands::Models { cmd } => { - match cmd { - ModelsCmd::Update => { - polyscribe_core::ui::info("verifying/updating local models"); - tokio::task::spawn_blocking(models::update_local_models) - .await - .map_err(|e| anyhow!("blocking task join error: {e}"))? - .context("updating models")?; + // predictable exit codes + const EXIT_OK: i32 = 0; + const EXIT_NOT_FOUND: i32 = 2; + const EXIT_NETWORK: i32 = 3; + const EXIT_VERIFY_FAILED: i32 = 4; + // const EXIT_NO_CHANGE: i32 = 5; // reserved + + let handle_common = |c: &ModelCommon| Settings { + concurrency: c.concurrency.max(1), + limit_rate: c.limit_rate, + ..Default::default() + }; + + let exit = match cmd { + ModelsCmd::Ls { common } => { + let mm: ModelManager = ModelManager::new(handle_common(&common))?; + let list = mm.ls()?; + if common.json { + println!("{}", serde_json::to_string_pretty(&list)?); + } else { + println!("Model (Repo)"); + for r in list { + println!("{} ({})", r.file, r.repo); + } + } + EXIT_OK } - ModelsCmd::Download => { - polyscribe_core::ui::info("interactive model selection and download"); - tokio::task::spawn_blocking(models::run_interactive_model_downloader) - .await - .map_err(|e| anyhow!("blocking task join error: {e}"))? 
- .context("running downloader")?; - polyscribe_core::ui::success("Model download complete."); + ModelsCmd::Add { repo, file, common } => { + let settings = handle_common(&common); + let mm: ModelManager = ModelManager::new(settings.clone())?; + // Derive an alias automatically from repo and file + fn derive_alias(repo: &str, file: &str) -> String { + use std::path::Path; + let repo_tail = repo.rsplit('/').next().unwrap_or(repo); + let stem = Path::new(file) + .file_stem() + .and_then(|s| s.to_str()) + .unwrap_or(file); + format!("{}-{}", repo_tail, stem) + } + let alias = derive_alias(&repo, &file); + match mm.add_or_update(&alias, &repo, &file) { + Ok(rec) => { + if common.json { println!("{}", serde_json::to_string_pretty(&rec)?); } + else { println!("installed: {} -> {}/{}", alias, repo, rec.file); } + EXIT_OK + } + Err(e) => { + // On not found or similar errors, try suggesting close matches interactively + if common.json || polyscribe_core::is_no_interaction() { + if common.json { println!("{{\"error\":{}}}", serde_json::to_string(&e.to_string())?); } + else { eprintln!("error: {e}"); } + EXIT_NOT_FOUND + } else { + ui::warn(format!("{}", e)); + ui::info("Searching for similar model filenames…"); + match polyscribe_core::model_manager::search_repo(&repo, None) { + Ok(mut files) => { + if files.is_empty() { + ui::warn("No files found in repository."); + EXIT_NOT_FOUND + } else { + // rank by similarity + files.sort_by(|a, b| normalized_similarity(&file, b) + .partial_cmp(&normalized_similarity(&file, a)) + .unwrap_or(std::cmp::Ordering::Equal)); + let top: Vec = files.into_iter().take(5).collect(); + if top.is_empty() { + EXIT_NOT_FOUND + } else if top.len() == 1 { + let cand = &top[0]; + // Fetch repo size list once + let size_map: std::collections::HashMap> = + polyscribe_core::model_manager::list_repo_files_with_meta(&repo) + .unwrap_or_default() + .into_iter().collect(); + let mut size = size_map.get(cand).cloned().unwrap_or(None); + if size.is_none() { + size = polyscribe_core::model_manager::head_len_for_file(&repo, cand); + } + let local_files: std::collections::HashSet = mm.ls()?.into_iter().map(|r| r.file).collect(); + let is_local = local_files.contains(cand); + let label = format!("{} [{}]{}", cand, human_size(size), if is_local { " (local)" } else { "" }); + let ok = ui::prompt_confirm(&format!("Did you mean {}?", label), true) + .unwrap_or(false); + if !ok { EXIT_NOT_FOUND } else { + let mm2: ModelManager = ModelManager::new(settings)?; + let alias2 = derive_alias(&repo, cand); + match mm2.add_or_update(&alias2, &repo, cand) { + Ok(rec) => { println!("installed: {} -> {}/{}", alias2, repo, rec.file); EXIT_OK } + Err(e2) => { eprintln!("error: {e2}"); EXIT_NETWORK } + } + } + } else { + let opts: Vec = top; + let local_files: std::collections::HashSet = mm.ls()?.into_iter().map(|r| r.file).collect(); + // Enrich labels with size and local tag using a single API call + let size_map: std::collections::HashMap> = + polyscribe_core::model_manager::list_repo_files_with_meta(&repo) + .unwrap_or_default() + .into_iter().collect(); + let mut labels_owned: Vec = Vec::new(); + for f in &opts { + let mut size = size_map.get(f).cloned().unwrap_or(None); + if size.is_none() { + size = polyscribe_core::model_manager::head_len_for_file(&repo, f); + } + let is_local = local_files.contains(f); + let suffix = if is_local { " (local)" } else { "" }; + labels_owned.push(format!("{} [{}]{}", f, human_size(size), suffix)); + } + let labels: Vec<&str> = labels_owned.iter().map(|s| 
s.as_str()).collect(); + match ui::prompt_select("Pick a model", &labels) { + Ok(idx) => { + let chosen = &opts[idx]; + let mm2: ModelManager = ModelManager::new(settings)?; + let alias2 = derive_alias(&repo, chosen); + match mm2.add_or_update(&alias2, &repo, chosen) { + Ok(rec) => { println!("installed: {} -> {}/{}", alias2, repo, rec.file); EXIT_OK } + Err(e2) => { eprintln!("error: {e2}"); EXIT_NETWORK } + } + } + Err(_) => EXIT_NOT_FOUND, + } + } + } + } + Err(e2) => { + eprintln!("error: {}", e2); + EXIT_NETWORK + } + } + } + } + } } - ModelsCmd::ClearCache => { - polyscribe_core::ui::info("clearing manifest cache"); - tokio::task::spawn_blocking(models::clear_manifest_cache) - .await - .map_err(|e| anyhow!("blocking task join error: {e}"))? - .context("clearing cache")?; - polyscribe_core::ui::success("Manifest cache cleared."); + ModelsCmd::Rm { alias, common } => { + let mm: ModelManager = ModelManager::new(handle_common(&common))?; + let ok = mm.rm(&alias)?; + if common.json { println!("{{\"removed\":{}}}", ok); } + else { println!("{}", if ok { "removed" } else { "not found" }); } + if ok { EXIT_OK } else { EXIT_NOT_FOUND } } - } - Ok(()) + ModelsCmd::Verify { alias, common } => { + let mm: ModelManager = ModelManager::new(handle_common(&common))?; + let found = mm.ls()?.into_iter().any(|r| r.alias == alias); + if !found { + if common.json { println!("{{\"ok\":false,\"error\":\"not found\"}}"); } else { println!("not found"); } + EXIT_NOT_FOUND + } else { + let ok = mm.verify(&alias)?; + if common.json { println!("{{\"ok\":{}}}", ok); } else { println!("{}", if ok { "ok" } else { "corrupt" }); } + if ok { EXIT_OK } else { EXIT_VERIFY_FAILED } + } + } + ModelsCmd::Update { common } => { + let mm: ModelManager = ModelManager::new(handle_common(&common))?; + let mut rc = EXIT_OK; + for rec in mm.ls()? 
{ + match mm.add_or_update(&rec.alias, &rec.repo, &rec.file) { + Ok(_) => {} + Err(e) => { rc = EXIT_NETWORK; if common.json { println!("{{\"alias\":\"{}\",\"error\":{}}}", rec.alias, serde_json::to_string(&e.to_string())?); } else { eprintln!("update {}: {e}", rec.alias); } } + } + } + rc + } + ModelsCmd::Gc { common } => { + let mm: ModelManager = ModelManager::new(handle_common(&common))?; + let (files_removed, entries_removed) = mm.gc()?; + if common.json { println!("{{\"files_removed\":{},\"entries_removed\":{}}}", files_removed, entries_removed); } + else { println!("files_removed={} entries_removed={}", files_removed, entries_removed); } + EXIT_OK + } + ModelsCmd::Search { repo, query, common } => { + let res = polyscribe_core::model_manager::search_repo(&repo, query.as_deref()); + match res { + Ok(files) => { if common.json { println!("{}", serde_json::to_string_pretty(&files)?); } else { for f in files { println!("{}", f); } } EXIT_OK } + Err(e) => { if common.json { println!("{{\"error\":{}}}", serde_json::to_string(&e.to_string())?); } else { eprintln!("error: {e}"); } EXIT_NETWORK } + } + } + }; + std::process::exit(exit); } Commands::Plugins { cmd } => { @@ -123,27 +319,35 @@ async fn main() -> Result<()> { command, json, } => { - let payload = json.unwrap_or_else(|| "{}".to_string()); - let mut child = plugin_manager - .spawn(&name, &command) - .with_context(|| format!("spawning plugin {name} {command}"))?; + // Use a local Tokio runtime only for this async path + let rt = tokio::runtime::Builder::new_multi_thread() + .enable_all() + .build() + .context("building tokio runtime")?; - if let Some(mut stdin) = child.stdin.take() { - stdin - .write_all(payload.as_bytes()) - .await - .context("writing JSON payload to plugin stdin")?; - } + rt.block_on(async { + let payload = json.unwrap_or_else(|| "{}".to_string()); + let mut child = plugin_manager + .spawn(&name, &command) + .with_context(|| format!("spawning plugin {name} {command}"))?; - let status = plugin_manager.forward_stdio(&mut child).await?; - if !status.success() { - polyscribe_core::ui::error(format!( - "plugin returned non-zero exit code: {}", - status - )); - return Err(anyhow!("plugin failed")); - } - Ok(()) + if let Some(mut stdin) = child.stdin.take() { + stdin + .write_all(payload.as_bytes()) + .await + .context("writing JSON payload to plugin stdin")?; + } + + let status = plugin_manager.forward_stdio(&mut child).await?; + if !status.success() { + polyscribe_core::ui::error(format!( + "plugin returned non-zero exit code: {}", + status + )); + return Err(anyhow!("plugin failed")); + } + Ok(()) + }) } } } diff --git a/crates/polyscribe-core/src/lib.rs b/crates/polyscribe-core/src/lib.rs index fb40f86..5f339a0 100644 --- a/crates/polyscribe-core/src/lib.rs +++ b/crates/polyscribe-core/src/lib.rs @@ -214,6 +214,8 @@ pub fn render_srt(entries: &[OutputEntry]) -> String { srt } +pub mod model_manager; + pub fn models_dir_path() -> PathBuf { if let Ok(env_val) = env::var("POLYSCRIBE_MODELS_DIR") { let env_path = PathBuf::from(env_val); diff --git a/crates/polyscribe-core/src/model_manager.rs b/crates/polyscribe-core/src/model_manager.rs new file mode 100644 index 0000000..08a3c01 --- /dev/null +++ b/crates/polyscribe-core/src/model_manager.rs @@ -0,0 +1,893 @@ +// SPDX-License-Identifier: MIT + +use crate::prelude::*; +use crate::ui::BytesProgress; +use anyhow::{anyhow, Context}; +use chrono::{DateTime, Utc}; +use reqwest::blocking::Client; +use reqwest::header::{ + ACCEPT_RANGES, AUTHORIZATION, CONTENT_LENGTH, ETAG, 
IF_NONE_MATCH, LAST_MODIFIED, RANGE,
+};
+use serde::{Deserialize, Serialize};
+use sha2::{Digest, Sha256};
+use std::cmp::min;
+use std::collections::BTreeMap;
+use std::fs::{self, File, OpenOptions};
+use std::io::{Read, Seek, SeekFrom, Write};
+use std::path::{Path, PathBuf};
+use std::sync::{mpsc, Arc, Mutex};
+use std::thread;
+use std::time::Duration;
+
+const DEFAULT_CHUNK_SIZE: u64 = 8 * 1024 * 1024; // 8 MiB
+
+#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
+pub struct ModelRecord {
+    pub alias: String,
+    pub repo: String,
+    pub file: String,
+    pub revision: Option<String>, // ETag or commit hash
+    pub sha256: Option<String>,
+    pub size_bytes: Option<u64>,
+    pub quant: Option<String>,
+    pub installed_at: Option<DateTime<Utc>>,
+    pub last_used: Option<DateTime<Utc>>,
+}
+
+#[derive(Debug, Clone, Serialize, Deserialize, Default)]
+pub struct Manifest {
+    pub models: BTreeMap<String, ModelRecord>, // key = alias
+}
+
+#[derive(Debug, Clone)]
+pub struct Settings {
+    pub concurrency: usize,
+    pub limit_rate: Option<u64>, // bytes/sec
+    pub chunk_size: u64,
+}
+
+impl Default for Settings {
+    fn default() -> Self {
+        Self {
+            concurrency: 4,
+            limit_rate: None,
+            chunk_size: DEFAULT_CHUNK_SIZE,
+        }
+    }
+}
+
+#[derive(Debug, Clone)]
+pub struct Paths {
+    pub cache_dir: PathBuf,   // $XDG_CACHE_HOME/polyscribe/models
+    pub config_path: PathBuf, // $XDG_CONFIG_HOME/polyscribe/models.json
+}
+
+impl Paths {
+    pub fn resolve() -> Result<Self> {
+        if let Ok(over) = std::env::var("POLYSCRIBE_CACHE_DIR") {
+            if !over.is_empty() {
+                let cache_dir = PathBuf::from(over).join("models");
+                let config_path = std::env::var("POLYSCRIBE_CONFIG_DIR")
+                    .map(|p| PathBuf::from(p).join("models.json"))
+                    .unwrap_or_else(|_| default_config_path());
+                return Ok(Self {
+                    cache_dir,
+                    config_path,
+                });
+            }
+        }
+        let cache_dir = default_cache_dir();
+        let config_path = default_config_path();
+        Ok(Self {
+            cache_dir,
+            config_path,
+        })
+    }
+}
+
+fn default_cache_dir() -> PathBuf {
+    if let Ok(xdg) = std::env::var("XDG_CACHE_HOME") {
+        if !xdg.is_empty() {
+            return PathBuf::from(xdg).join("polyscribe").join("models");
+        }
+    }
+    if let Ok(home) = std::env::var("HOME") {
+        if !home.is_empty() {
+            return PathBuf::from(home)
+                .join(".cache")
+                .join("polyscribe")
+                .join("models");
+        }
+    }
+    PathBuf::from("models")
+}
+
+fn default_config_path() -> PathBuf {
+    if let Ok(xdg) = std::env::var("XDG_CONFIG_HOME") {
+        if !xdg.is_empty() {
+            return PathBuf::from(xdg).join("polyscribe").join("models.json");
+        }
+    }
+    if let Ok(home) = std::env::var("HOME") {
+        if !home.is_empty() {
+            return PathBuf::from(home)
+                .join(".config")
+                .join("polyscribe")
+                .join("models.json");
+        }
+    }
+    PathBuf::from("models.json")
+}
+
+pub trait HttpClient: Send + Sync {
+    fn head(&self, url: &str, etag: Option<&str>) -> Result<HeadMeta>;
+    fn get_range(&self, url: &str, start: u64, end_inclusive: u64) -> Result<Vec<u8>>;
+    fn get_whole_to(&self, url: &str, writer: &mut dyn Write) -> Result<()>;
+    fn get_from_to(&self, url: &str, start: u64, writer: &mut dyn Write) -> Result<()>;
+}
+
+#[derive(Debug, Clone)]
+pub struct ReqwestClient {
+    client: Client,
+    token: Option<String>,
+}
+
+impl ReqwestClient {
+    pub fn new() -> Result<Self> {
+        let token = std::env::var("HF_TOKEN").ok().filter(|s| !s.is_empty());
+        let client = Client::builder()
+            .user_agent(crate::config::ConfigService::user_agent())
+            .build()?;
+        Ok(Self { client, token })
+    }
+
+    fn auth(&self, mut req: reqwest::blocking::RequestBuilder) -> reqwest::blocking::RequestBuilder {
+        if let Some(t) = &self.token {
+            req = req.header(AUTHORIZATION, format!("Bearer {}", t));
+        }
+        req
+    }
+}
+
+#[derive(Debug, Clone)]
+pub struct HeadMeta {
+    pub len: Option<u64>,
+    pub etag: Option<String>,
+    pub last_modified: Option<String>,
+    pub accept_ranges: bool,
+    pub not_modified: bool,
+    pub status: u16,
+}
+
+impl HttpClient for ReqwestClient {
+    fn head(&self, url: &str, etag: Option<&str>) -> Result<HeadMeta> {
+        let mut req = self.client.head(url);
+        if let Some(e) = etag {
+            req = req.header(IF_NONE_MATCH, format!("\"{}\"", e));
+        }
+        let resp = self.auth(req).send()?;
+        let status = resp.status().as_u16();
+        if status == 304 {
+            // 304 Not Modified: the cached ETag is still current.
+            return Ok(HeadMeta {
+                len: None,
+                etag: etag.map(|s| s.to_string()),
+                last_modified: None,
+                accept_ranges: true,
+                not_modified: true,
+                status,
+            });
+        }
+        let len = resp
+            .headers()
+            .get(CONTENT_LENGTH)
+            .and_then(|v| v.to_str().ok())
+            .and_then(|s| s.parse::<u64>().ok());
+        let etag = resp
+            .headers()
+            .get(ETAG)
+            .and_then(|v| v.to_str().ok())
+            .map(|s| s.trim_matches('"').to_string());
+        let last_modified = resp
+            .headers()
+            .get(LAST_MODIFIED)
+            .and_then(|v| v.to_str().ok())
+            .map(|s| s.to_string());
+        let accept_ranges = resp
+            .headers()
+            .get(ACCEPT_RANGES)
+            .and_then(|v| v.to_str().ok())
+            .map(|s| s.to_ascii_lowercase().contains("bytes"))
+            .unwrap_or(false);
+        Ok(HeadMeta {
+            len,
+            etag,
+            last_modified,
+            accept_ranges,
+            not_modified: false,
+            status,
+        })
+    }
+
+    fn get_range(&self, url: &str, start: u64, end_inclusive: u64) -> Result<Vec<u8>> {
+        let range_val = format!("bytes={}-{}", start, end_inclusive);
+        let resp = self
+            .auth(self.client.get(url))
+            .header(RANGE, range_val)
+            .send()?;
+        // 206 Partial Content is the expected success status for ranged GETs.
+        if !resp.status().is_success() && resp.status().as_u16() != 206 {
+            return Err(anyhow!("HTTP {} for ranged GET", resp.status()).into());
+        }
+        let mut buf = Vec::new();
+        let mut r = resp;
+        r.copy_to(&mut buf)?;
+        Ok(buf)
+    }
+
+    fn get_whole_to(&self, url: &str, writer: &mut dyn Write) -> Result<()> {
+        let resp = self.auth(self.client.get(url)).send()?;
+        if !resp.status().is_success() {
+            return Err(anyhow!("HTTP {} for GET", resp.status()).into());
+        }
+        let mut r = resp;
+        r.copy_to(writer)?;
+        Ok(())
+    }
+
+    fn get_from_to(&self, url: &str, start: u64, writer: &mut dyn Write) -> Result<()> {
+        let mut req = self.auth(self.client.get(url));
+        if start > 0 {
+            req = req.header(RANGE, format!("bytes={}-", start));
+        }
+        let resp = req.send()?;
+        if !resp.status().is_success() && resp.status().as_u16() != 206 {
+            return Err(anyhow!("HTTP {} for ranged GET from {}", resp.status(), start).into());
+        }
+        let mut r = resp;
+        r.copy_to(writer)?;
+        Ok(())
+    }
+}
+
+pub struct ModelManager<C: HttpClient> {
+    pub paths: Paths,
+    pub settings: Settings,
+    client: Arc<C>,
+}
+
+impl<C: HttpClient> ModelManager<C> {
+    pub fn new_with_client(client: C, settings: Settings) -> Result<Self> {
+        Ok(Self {
+            paths: Paths::resolve()?,
+            settings,
+            client: Arc::new(client),
+        })
+    }
+
+    pub fn new(settings: Settings) -> Result<Self>
+    where
+        C: Default,
+    {
+        Ok(Self {
+            paths: Paths::resolve()?,
+            settings,
+            client: Arc::new(C::default()),
+        })
+    }
+
+    fn load_manifest(&self) -> Result<Manifest> {
+        let p = &self.paths.config_path;
+        if !p.exists() {
+            return Ok(Manifest::default());
+        }
+        let file = File::open(p).with_context(|| format!("open manifest {}", p.display()))?;
+        let m: Manifest = serde_json::from_reader(file).context("parse manifest")?;
+        Ok(m)
+    }
+
+    fn save_manifest(&self, m: &Manifest) -> Result<()> {
+        let p = &self.paths.config_path;
+        if let Some(dir) = p.parent() {
+            fs::create_dir_all(dir)
+                .with_context(|| format!("create config dir {}", dir.display()))?;
+        }
+        // Write to a temp file first, then rename, so the manifest is never
+        // left half-written on disk.
+        let tmp = p.with_extension("json.tmp");
+        let f = OpenOptions::new()
            .create(true)
+            .write(true)
+            .truncate(true)
+            .open(&tmp)?;
+        serde_json::to_writer_pretty(f, m).context("serialize manifest")?;
+        fs::rename(&tmp, p).with_context(|| format!("atomic rename {} -> {}", tmp.display(), p.display()))?;
+        Ok(())
+    }
+
+    fn model_path(&self, file: &str) -> PathBuf {
+        self.paths.cache_dir.join(file)
+    }
+
+    fn compute_sha256(path: &Path) -> Result<String> {
+        let mut f = File::open(path)?;
+        let mut hasher = Sha256::new();
+        let mut buf = [0u8; 64 * 1024];
+        loop {
+            let n = f.read(&mut buf)?;
+            if n == 0 {
+                break;
+            }
+            hasher.update(&buf[..n]);
+        }
+        Ok(format!("{:x}", hasher.finalize()))
+    }
+
+    pub fn ls(&self) -> Result<Vec<ModelRecord>> {
+        let m = self.load_manifest()?;
+        Ok(m.models.values().cloned().collect())
+    }
+
+    pub fn rm(&self, alias: &str) -> Result<bool> {
+        let mut m = self.load_manifest()?;
+        if let Some(rec) = m.models.remove(alias) {
+            let p = self.model_path(&rec.file);
+            let _ = fs::remove_file(&p);
+            self.save_manifest(&m)?;
+            return Ok(true);
+        }
+        Ok(false)
+    }
+
+    pub fn verify(&self, alias: &str) -> Result<bool> {
+        let m = self.load_manifest()?;
+        let Some(rec) = m.models.get(alias) else { return Ok(false) };
+        let p = self.model_path(&rec.file);
+        if !p.exists() { return Ok(false); }
+        if let Some(expected) = &rec.sha256 {
+            let actual = Self::compute_sha256(&p)?;
+            return Ok(&actual == expected);
+        }
+        Ok(true)
+    }
+
+    pub fn gc(&self) -> Result<(usize, usize)> {
+        // Remove files not referenced by the manifest; also drop manifest entries whose file is missing
+        fs::create_dir_all(&self.paths.cache_dir).ok();
+        let mut m = self.load_manifest()?;
+        let mut referenced = BTreeMap::new();
+        for (alias, rec) in &m.models {
+            referenced.insert(rec.file.clone(), alias.clone());
+        }
+        let mut removed_files = 0usize;
+        if let Ok(rd) = fs::read_dir(&self.paths.cache_dir) {
+            for ent in rd.flatten() {
+                let p = ent.path();
+                if p.is_file() {
+                    let fname = p.file_name().and_then(|s| s.to_str()).unwrap_or("");
+                    if !referenced.contains_key(fname) {
+                        let _ = fs::remove_file(&p);
+                        removed_files += 1;
+                    }
+                }
+            }
+        }
+
+        m.models.retain(|_, rec| self.model_path(&rec.file).exists());
+        let removed_entries = referenced
+            .keys()
+            .filter(|f| !self.model_path(f).exists())
+            .count();
+
+        self.save_manifest(&m)?;
+        Ok((removed_files, removed_entries))
+    }
+
+    pub fn add_or_update(
+        &self,
+        alias: &str,
+        repo: &str,
+        file: &str,
+    ) -> Result<ModelRecord> {
+        fs::create_dir_all(&self.paths.cache_dir)
+            .with_context(|| format!("create cache dir {}", self.paths.cache_dir.display()))?;
+
+        let url = format!("https://huggingface.co/{}/resolve/main/{}?download=true", repo, file);
+
+        let mut manifest = self.load_manifest()?;
+        let prev = manifest.models.get(alias).cloned();
+        let prev_etag = prev.as_ref().and_then(|r| r.revision.clone());
+        // Fetch remote meta (size/hash) when available to verify the download
+        let (_api_size, api_sha) = hf_fetch_file_meta(repo, file).unwrap_or((None, None));
+        let head = self.client.head(&url, prev_etag.as_deref())?;
+        if head.not_modified {
+            // up-to-date; ensure record present and touch last_used
+            let mut rec = prev.ok_or_else(|| anyhow!("not installed yet but server says 304"))?;
+            rec.last_used = Some(Utc::now());
+            self.save_touch(&mut manifest, rec.clone())?;
+            return Ok(rec);
+        }
+
+        // Quick check: if HEAD failed (e.g., 404), report a helpful error before attempting download
+        if head.status >= 400 {
+            return Err(anyhow!(
+                "file not found or inaccessible: repo='{}' file='{}' (HTTP {})\nHint: run `polyscribe models search {}
--query {}` to list available files", + repo, + file, + head.status, + repo, + file + ).into()); + } + + let total_len = head.len.ok_or_else(|| anyhow!("missing content-length (HEAD)"))?; + let etag = head.etag.clone(); + let dest_tmp = self.model_path(&format!("{}.part", file)); + // If a previous cancelled download left a .part file, remove it to avoid clutter/resume. + if dest_tmp.exists() { let _ = fs::remove_file(&dest_tmp); } + // Guard to ensure .part is cleaned up on errors + struct TempGuard { path: PathBuf, armed: bool } + impl TempGuard { fn disarm(&mut self) { self.armed = false; } } + impl Drop for TempGuard { + fn drop(&mut self) { + if self.armed { + let _ = fs::remove_file(&self.path); + } + } + } + let mut _tmp_guard = TempGuard { path: dest_tmp.clone(), armed: true }; + let dest_final = self.model_path(file); + + // Do not resume after cancellation; start fresh to avoid stale .part files + let start_from = 0u64; + + // Open tmp for write + let f = OpenOptions::new().create(true).write(true).read(true).open(&dest_tmp)?; + f.set_len(total_len)?; // pre-allocate for random writes + let f = Arc::new(Mutex::new(f)); + + // Create progress bar + let mut progress = BytesProgress::start(total_len, &format!("Downloading {}", file), start_from); + + // Create work chunks + let chunk_size = self.settings.chunk_size; + let mut chunks = Vec::new(); + let mut pos = start_from; + while pos < total_len { + let end = min(total_len - 1, pos + chunk_size - 1); + chunks.push((pos, end)); + pos = end + 1; + } + + // Attempt concurrent ranged download; on failure, fallback to streaming GET + let mut ranged_failed = false; + if head.accept_ranges && self.settings.concurrency > 1 { + let (work_tx, work_rx) = mpsc::channel::<(u64, u64)>(); + let (prog_tx, prog_rx) = mpsc::channel::(); + for ch in chunks { + work_tx.send(ch).unwrap(); + } + drop(work_tx); + let rx = Arc::new(Mutex::new(work_rx)); + + let workers = self.settings.concurrency.max(1); + let mut handles = Vec::new(); + for _ in 0..workers { + let rx = rx.clone(); + let url = url.clone(); + let client = self.client.clone(); + let f = f.clone(); + let limit = self.settings.limit_rate; + let prog_tx = prog_tx.clone(); + let handle = thread::spawn(move || -> Result<()> { + loop { + let next = { + let guard = rx.lock().unwrap(); + guard.recv().ok() + }; + let Some((start, end)) = next else { break; }; + let data = client.get_range(&url, start, end)?; + if let Some(max_bps) = limit { + let dur = Duration::from_secs_f64((data.len() as f64) / (max_bps as f64)); + if dur > Duration::from_millis(1) { + thread::sleep(dur); + } + } + let mut guard = f.lock().unwrap(); + guard.seek(SeekFrom::Start(start))?; + guard.write_all(&data)?; + let _ = prog_tx.send(data.len() as u64); + } + Ok(()) + }); + handles.push(handle); + } + drop(prog_tx); + + for delta in prog_rx { + progress.inc(delta); + } + + let mut ranged_err: Option = None; + for h in handles { + match h.join() { + Ok(Ok(())) => {} + Ok(Err(e)) => { ranged_failed = true; if ranged_err.is_none() { ranged_err = Some(e.into()); } } + Err(_) => { ranged_failed = true; if ranged_err.is_none() { ranged_err = Some(anyhow!("worker panicked")); } } + } + } + } else { + ranged_failed = true; + } + + if ranged_failed { + // Restart progress if we are abandoning previous partial state + if start_from > 0 { + progress.stop("retrying as streaming"); + progress = BytesProgress::start(total_len, &format!("Downloading {}", file), 0); + } + // Fallback to streaming GET; try URL with and without ?download=true 
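+            // Both attempts stream the full file from offset zero. The second
+            // URL flips the `?download=true` query (dropping it if present,
+            // adding it otherwise), in case an intermediary accepts only one form.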
+ let mut try_urls = vec![url.clone()]; + if let Some((base, _qs)) = url.split_once('?') { + try_urls.push(base.to_string()); + } else { + try_urls.push(format!("{}?download=true", url)); + } + // Fresh temp file for streaming + let mut wf = OpenOptions::new().create(true).write(true).truncate(true).open(&dest_tmp)?; + let mut ok = false; + let mut last_err: Option = None; + // Counting writer to update progress inline + struct CountingWriter<'a, 'b> { + inner: &'a mut File, + progress: &'b mut BytesProgress, + } + impl<'a, 'b> Write for CountingWriter<'a, 'b> { + fn write(&mut self, buf: &[u8]) -> std::io::Result { + let n = self.inner.write(buf)?; + self.progress.inc(n as u64); + Ok(n) + } + fn flush(&mut self) -> std::io::Result<()> { self.inner.flush() } + } + let mut cw = CountingWriter { inner: &mut wf, progress: &mut progress }; + for u in try_urls { + // For fallback, stream from scratch to ensure integrity + let res = self.client.get_whole_to(&u, &mut cw); + match res { + Ok(()) => { ok = true; break; } + Err(e) => { last_err = Some(e.into()); } + } + } + if !ok { + if let Some(e) = last_err { return Err(e.into()); } + return Err(anyhow!("download failed (ranged and streaming)").into()); + } + } + + progress.stop("download complete"); + + // Verify integrity + let sha = Self::compute_sha256(&dest_tmp)?; + if let Some(expected) = api_sha.as_ref() { + if &sha != expected { + return Err(anyhow!( + "sha256 mismatch (expected {}, got {})", + expected, + sha + ).into()); + } + } else if prev.as_ref().map(|r| r.file.eq(file)).unwrap_or(false) { + if let Some(expected) = prev.as_ref().and_then(|r| r.sha256.as_ref()) { + if &sha != expected { + return Err(anyhow!("sha256 mismatch").into()); + } + } + } + // Atomic rename + fs::rename(&dest_tmp, &dest_final).with_context(|| format!("rename {} -> {}", dest_tmp.display(), dest_final.display()))?; + // Disarm guard; .part has been moved or cleaned + _tmp_guard.disarm(); + + let rec = ModelRecord { + alias: alias.to_string(), + repo: repo.to_string(), + file: file.to_string(), + revision: etag, + sha256: Some(sha.clone()), + size_bytes: Some(total_len), + quant: infer_quant(file), + installed_at: Some(Utc::now()), + last_used: Some(Utc::now()), + }; + self.save_touch(&mut manifest, rec.clone())?; + Ok(rec) + } + + fn save_touch(&self, manifest: &mut Manifest, rec: ModelRecord) -> Result<()> { + manifest.models.insert(rec.alias.clone(), rec); + self.save_manifest(manifest) + } +} + +fn infer_quant(file: &str) -> Option { + // Try to extract a Q* token, e.g. 
Q5_K_M, from the filename
+    let upper = file.to_ascii_uppercase();
+    if let Some(pos) = upper.find('Q') {
+        let tail = &upper[pos..];
+        let token: String = tail
+            .chars()
+            .take_while(|c| c.is_ascii_uppercase() || c.is_ascii_digit() || *c == '_' || *c == '-')
+            .collect();
+        if token.len() >= 2 {
+            return Some(token);
+        }
+    }
+    None
+}
+
+impl Default for ReqwestClient {
+    fn default() -> Self {
+        Self::new().expect("reqwest client")
+    }
+}
+
+// Hugging Face API types for file metadata
+#[derive(Debug, Deserialize)]
+struct ApiHfLfs {
+    oid: Option<String>,
+    size: Option<u64>,
+    sha256: Option<String>,
+}
+
+#[derive(Debug, Deserialize)]
+struct ApiHfFile {
+    rfilename: String,
+    size: Option<u64>,
+    sha256: Option<String>,
+    lfs: Option<ApiHfLfs>,
+}
+
+#[derive(Debug, Deserialize)]
+struct ApiHfModelInfo {
+    siblings: Option<Vec<ApiHfFile>>,
+    files: Option<Vec<ApiHfFile>>,
+}
+
+fn pick_sha_from_file(f: &ApiHfFile) -> Option<String> {
+    if let Some(s) = &f.sha256 { return Some(s.to_string()); }
+    if let Some(l) = &f.lfs {
+        if let Some(s) = &l.sha256 { return Some(s.to_string()); }
+        if let Some(oid) = &l.oid { return oid.strip_prefix("sha256:").map(|s| s.to_string()); }
+    }
+    None
+}
+
+fn hf_fetch_file_meta(repo: &str, target: &str) -> Result<(Option<u64>, Option<String>)> {
+    let token = std::env::var("HF_TOKEN").ok().filter(|s| !s.is_empty());
+    let client = Client::builder()
+        .user_agent(crate::config::ConfigService::user_agent())
+        .build()?;
+    let base = crate::config::ConfigService::hf_api_base_for(repo);
+    let urls = [base.clone(), format!("{}?expand=files", base)];
+    for url in urls {
+        let mut req = client.get(&url);
+        if let Some(t) = &token { req = req.header(AUTHORIZATION, format!("Bearer {}", t)); }
+        let resp = req.send()?;
+        if !resp.status().is_success() { continue; }
+        let info: ApiHfModelInfo = resp.json()?;
+        let list = info.files.or(info.siblings).unwrap_or_default();
+        for f in list {
+            let name = f.rfilename.rsplit('/').next().unwrap_or(&f.rfilename);
+            if name.eq_ignore_ascii_case(target) {
+                let sz = f.size.or_else(|| f.lfs.as_ref().and_then(|l| l.size));
+                let sha = pick_sha_from_file(&f);
+                return Ok((sz, sha));
+            }
+        }
+    }
+    Err(anyhow!("file not found in HF API").into())
+}
+
+/// Fetch remote metadata (size, sha256) for a single file in a HF repo.
+pub fn fetch_file_meta(repo: &str, file: &str) -> Result<(Option<u64>, Option<String>)> {
+    hf_fetch_file_meta(repo, file)
+}
+
+/// Search a Hugging Face repo for GGUF/BIN files via API. Returns file names only.
+pub fn search_repo(repo: &str, query: Option<&str>) -> Result<Vec<String>> {
+    let token = std::env::var("HF_TOKEN").ok().filter(|s| !s.is_empty());
+    let client = Client::builder()
+        .user_agent(crate::config::ConfigService::user_agent())
+        .build()?;
+
+    let base = crate::config::ConfigService::hf_api_base_for(repo);
+    let mut urls = vec![base.clone(), format!("{}?expand=files", base)];
+    let mut files = Vec::<String>::new();
+
+    for url in urls.drain(..) {
+        let mut req = client.get(&url);
+        if let Some(t) = &token { req = req.header(AUTHORIZATION, format!("Bearer {}", t)); }
+        let resp = req.send()?;
+        if !resp.status().is_success() { continue; }
+        let info: ApiHfModelInfo = resp.json()?;
+        let list = info.files.or(info.siblings).unwrap_or_default();
+        for f in list {
+            if f.rfilename.ends_with(".gguf") || f.rfilename.ends_with(".bin") {
+                let name = f.rfilename.rsplit('/').next().unwrap_or(&f.rfilename).to_string();
+                if !files.contains(&name) { files.push(name); }
+            }
+        }
+        if !files.is_empty() { break; }
+    }
+
+    if let Some(q) = query { let qlc = q.to_ascii_lowercase(); files.retain(|f| f.to_ascii_lowercase().contains(&qlc)); }
+    files.sort();
+    Ok(files)
+}
+
+/// List repo files with optional size metadata for GGUF/BIN entries.
+pub fn list_repo_files_with_meta(repo: &str) -> Result<Vec<(String, Option<u64>)>> {
+    let token = std::env::var("HF_TOKEN").ok().filter(|s| !s.is_empty());
+    let client = Client::builder()
+        .user_agent(crate::config::ConfigService::user_agent())
+        .build()?;
+
+    let base = crate::config::ConfigService::hf_api_base_for(repo);
+    for url in [base.clone(), format!("{}?expand=files", base)] {
+        let mut req = client.get(&url);
+        if let Some(t) = &token { req = req.header(AUTHORIZATION, format!("Bearer {}", t)); }
+        let resp = req.send()?;
+        if !resp.status().is_success() { continue; }
+        let info: ApiHfModelInfo = resp.json()?;
+        let list = info.files.or(info.siblings).unwrap_or_default();
+        let mut out = Vec::new();
+        for f in list {
+            if !(f.rfilename.ends_with(".gguf") || f.rfilename.ends_with(".bin")) { continue; }
+            let name = f.rfilename.rsplit('/').next().unwrap_or(&f.rfilename).to_string();
+            let size = f.size.or_else(|| f.lfs.as_ref().and_then(|l| l.size));
+            out.push((name, size));
+        }
+        if !out.is_empty() { return Ok(out); }
+    }
+    Ok(Vec::new())
+}
+
+/// Fallback: HEAD request for a single file to retrieve Content-Length (size).
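+/// Tries the `?download=true` variant first, then the bare `resolve` URL, and
+/// swallows all errors into `None` so callers can render the size as "?".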
pub fn head_len_for_file(repo: &str, file: &str) -> Option<u64> {
+    let token = std::env::var("HF_TOKEN").ok().filter(|s| !s.is_empty());
+    let client = Client::builder()
+        .user_agent(crate::config::ConfigService::user_agent())
+        .build().ok()?;
+    let mut urls = Vec::new();
+    urls.push(format!("https://huggingface.co/{}/resolve/main/{}?download=true", repo, file));
+    urls.push(format!("https://huggingface.co/{}/resolve/main/{}", repo, file));
+    for url in urls {
+        let mut req = client.head(&url);
+        if let Some(t) = &token { req = req.header(AUTHORIZATION, format!("Bearer {}", t)); }
+        if let Ok(resp) = req.send() {
+            if resp.status().is_success() {
+                if let Some(len) = resp.headers().get(CONTENT_LENGTH)
+                    .and_then(|v| v.to_str().ok())
+                    .and_then(|s| s.parse::<u64>().ok())
+                { return Some(len); }
+            }
+        }
+    }
+    None
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+    use std::env;
+    use tempfile::TempDir;
+
+    #[derive(Clone)]
+    struct StubHttp {
+        data: Arc<Vec<u8>>,
+        etag: Arc<Option<String>>,
+        accept_ranges: bool,
+    }
+    impl HttpClient for StubHttp {
+        fn head(&self, _url: &str, etag: Option<&str>) -> Result<HeadMeta> {
+            let not_modified = etag.is_some() && self.etag.as_ref().as_deref() == etag;
+            Ok(HeadMeta {
+                len: Some(self.data.len() as u64),
+                etag: self.etag.as_ref().clone(),
+                last_modified: None,
+                accept_ranges: self.accept_ranges,
+                not_modified,
+                status: if not_modified { 304 } else { 200 },
+            })
+        }
+        fn get_range(&self, _url: &str, start: u64, end_inclusive: u64) -> Result<Vec<u8>> {
+            let s = start as usize;
+            let e = (end_inclusive as usize) + 1;
+            Ok(self.data[s..e].to_vec())
+        }
+
+        fn get_whole_to(&self, _url: &str, writer: &mut dyn Write) -> Result<()> {
+            writer.write_all(&self.data)?;
+            Ok(())
+        }
+
+        fn get_from_to(&self, _url: &str, start: u64, writer: &mut dyn Write) -> Result<()> {
+            let s = start as usize;
+            writer.write_all(&self.data[s..])?;
+            Ok(())
+        }
+    }
+
+    fn setup_env(cache: &Path, cfg: &Path) {
+        unsafe {
+            env::set_var("POLYSCRIBE_CACHE_DIR", cache.to_string_lossy().to_string());
+            env::set_var(
+                "POLYSCRIBE_CONFIG_DIR",
+                cfg.parent().unwrap().to_string_lossy().to_string(),
+            );
+        }
+    }
+
+    #[test]
+    fn test_manifest_roundtrip() {
+        let temp = TempDir::new().unwrap();
+        let cache = temp.path().join("cache");
+        let cfg = temp.path().join("config").join("models.json");
+        setup_env(&cache, &cfg);
+
+        let client = StubHttp {
+            data: Arc::new(vec![0u8; 1024]),
+            etag: Arc::new(Some("etag123".into())),
+            accept_ranges: true,
+        };
+        let mm: ModelManager<StubHttp> = ModelManager::new_with_client(client, Settings::default()).unwrap();
+        let m = mm.load_manifest().unwrap();
+        assert!(m.models.is_empty());
+
+        let rec = ModelRecord {
+            alias: "tiny".into(),
+            repo: "foo/bar".into(),
+            file: "gguf-tiny.bin".into(),
+            revision: Some("etag123".into()),
+            sha256: None,
+            size_bytes: None,
+            quant: None,
+            installed_at: None,
+            last_used: None,
+        };
+
+        let mut m2 = Manifest::default();
+        mm.save_touch(&mut m2, rec.clone()).unwrap();
+        let m3 = mm.load_manifest().unwrap();
+        assert!(m3.models.contains_key("tiny"));
+    }
+
+    #[test]
+    fn test_add_verify_update_gc() {
+        let temp = TempDir::new().unwrap();
+        let cache = temp.path().join("cache");
+        let cfg_dir = temp.path().join("config");
+        let cfg = cfg_dir.join("models.json");
+        setup_env(&cache, &cfg);
+
+        let data = (0..1024 * 1024u32).flat_map(|i| i.to_le_bytes()).collect::<Vec<u8>>();
+        let etag = Some("abc123".to_string());
+        let client = StubHttp { data: Arc::new(data), etag: Arc::new(etag), accept_ranges: true };
+        let mm: ModelManager<StubHttp> =
ModelManager::new_with_client(client.clone(), Settings{ concurrency: 3, ..Default::default() }).unwrap(); + + // add + let rec = mm.add_or_update("tiny", "gguf/models", "gguf-tiny-q4_0.bin").unwrap(); + assert_eq!(rec.alias, "tiny"); + assert!(mm.verify("tiny").unwrap()); + + // update (304) + let rec2 = mm.add_or_update("tiny", "gguf/models", "gguf-tiny-q4_0.bin").unwrap(); + assert_eq!(rec2.alias, "tiny"); + + // gc (nothing to remove) + let (files_removed, entries_removed) = mm.gc().unwrap(); + assert_eq!(files_removed, 0); + assert_eq!(entries_removed, 0); + + // rm + assert!(mm.rm("tiny").unwrap()); + assert!(!mm.rm("tiny").unwrap()); + } +} diff --git a/crates/polyscribe-core/src/models.rs b/crates/polyscribe-core/src/models.rs index c2c71cb..9080db3 100644 --- a/crates/polyscribe-core/src/models.rs +++ b/crates/polyscribe-core/src/models.rs @@ -792,6 +792,13 @@ fn download_with_progress(dest_path: &Path, entry: &ModelEntry) -> Result<()> { } let part_path = dest_path.with_extension("part"); + // Guard to cleanup .part on errors + struct TempGuard { path: std::path::PathBuf, armed: bool } + impl TempGuard { fn disarm(&mut self) { self.armed = false; } } + impl Drop for TempGuard { + fn drop(&mut self) { if self.armed { let _ = fs::remove_file(&self.path); } } + } + let mut _tmp_guard = TempGuard { path: part_path.clone(), armed: true }; let mut resume_from: u64 = 0; if part_path.exists() && ranges_ok { @@ -893,6 +900,7 @@ fn download_with_progress(dest_path: &Path, entry: &ModelEntry) -> Result<()> { drop(part_file); fs::rename(&part_path, dest_path) .with_context(|| format!("renaming {} → {}", part_path.display(), dest_path.display()))?; + _tmp_guard.disarm(); let final_size = fs::metadata(dest_path).map(|m| m.len()).ok(); let elapsed = start.elapsed().as_secs_f64(); diff --git a/crates/polyscribe-core/src/ui.rs b/crates/polyscribe-core/src/ui.rs index b1c780d..d979248 100644 --- a/crates/polyscribe-core/src/ui.rs +++ b/crates/polyscribe-core/src/ui.rs @@ -4,6 +4,8 @@ pub mod progress; use std::io; use std::io::IsTerminal; +use std::io::Write as _; +use std::time::{Duration, Instant}; pub fn info(msg: impl AsRef) { let m = msg.as_ref(); @@ -170,43 +172,156 @@ impl Spinner { } } -pub struct BytesProgress(Option); +pub struct BytesProgress { + enabled: bool, + total: u64, + current: u64, + started: Instant, + last_msg: Instant, + width: usize, + // Sticky ETA to carry through zero-speed stalls + last_eta_secs: Option, +} impl BytesProgress { pub fn start(total: u64, text: &str, initial: u64) -> Self { - if crate::is_no_progress() + let enabled = !(crate::is_no_progress() || crate::is_no_interaction() || !std::io::stderr().is_terminal() - || total == 0 - { + || total == 0); + if !enabled { let _ = cliclack::log::info(text); - return Self(None); } - let b = cliclack::progress_bar(total); - b.start(text); - if initial > 0 { - b.inc(initial); + let mut me = Self { + enabled, + total, + current: initial.min(total), + started: Instant::now(), + last_msg: Instant::now(), + width: 40, + last_eta_secs: None, + }; + me.draw(); + me + } + + fn human_bytes(n: u64) -> String { + const KB: f64 = 1024.0; + const MB: f64 = 1024.0 * KB; + const GB: f64 = 1024.0 * MB; + let x = n as f64; + if x >= GB { + format!("{:.2} GiB", x / GB) + } else if x >= MB { + format!("{:.2} MiB", x / MB) + } else if x >= KB { + format!("{:.2} KiB", x / KB) + } else { + format!("{} B", n) } - Self(Some(b)) + } + + // Elapsed formatting is used for stable, finite durations. 
For ETA, we guard + // against zero-speed or unstable estimates separately via `format_eta`. + + fn refresh_allowed(&mut self) -> (f64, f64) { + let now = Instant::now(); + let since_last = now.duration_since(self.last_msg); + if since_last < Duration::from_millis(100) { + // Too soon to refresh; keep previous ETA if any + let eta = self.last_eta_secs.unwrap_or(f64::INFINITY); + return (0.0, eta); + } + self.last_msg = now; + let elapsed = now.duration_since(self.started).as_secs_f64().max(0.001); + let speed = (self.current as f64) / elapsed; + let remaining = self.total.saturating_sub(self.current) as f64; + + // If speed is effectively zero, carry ETA forward and add wall time. + const EPS: f64 = 1e-6; + let eta = if speed <= EPS { + let prev = self.last_eta_secs.unwrap_or(f64::INFINITY); + if prev.is_finite() { + prev + since_last.as_secs_f64() + } else { + prev + } + } else { + remaining / speed + }; + // Remember only finite ETAs to use during stalls + if eta.is_finite() { + self.last_eta_secs = Some(eta); + } + (speed, eta) + } + + fn format_elapsed(seconds: f64) -> String { + let total = seconds.round() as u64; + let h = total / 3600; + let m = (total % 3600) / 60; + let s = total % 60; + if h > 0 { format!("{:02}:{:02}:{:02}", h, m, s) } else { format!("{:02}:{:02}", m, s) } + } + + fn format_eta(seconds: f64) -> String { + // If ETA is not finite (e.g., divide-by-zero speed) or unreasonably large, + // show a placeholder rather than overflowing into huge values. + if !seconds.is_finite() { + return "—".to_string(); + } + // Cap ETA display to 99:59:59 to avoid silly numbers; beyond that, show placeholder. + const CAP_SECS: f64 = 99.0 * 3600.0 + 59.0 * 60.0 + 59.0; + if seconds > CAP_SECS { + return "—".to_string(); + } + Self::format_elapsed(seconds) + } + + fn draw(&mut self) { + if !self.enabled { return; } + let (speed, eta) = self.refresh_allowed(); + let elapsed = Instant::now().duration_since(self.started).as_secs_f64(); + // Build bar + let width = self.width.max(10); + let filled = ((self.current as f64 / self.total.max(1) as f64) * width as f64).round() as usize; + let filled = filled.min(width); + let mut bar = String::with_capacity(width); + for _ in 0..filled { bar.push('■'); } + for _ in filled..width { bar.push('□'); } + + let line = format!( + "[{}] {} [{}] ({}/{} at {}/s)", + Self::format_elapsed(elapsed), + bar, + Self::format_eta(eta), + Self::human_bytes(self.current), + Self::human_bytes(self.total), + Self::human_bytes(speed.max(0.0) as u64), + ); + eprint!("\r{}\x1b[K", line); + let _ = io::stderr().flush(); } pub fn inc(&mut self, delta: u64) { - if let Some(b) = self.0.as_mut() { - b.inc(delta); - } + self.current = self.current.saturating_add(delta).min(self.total); + self.draw(); } pub fn stop(mut self, text: &str) { - if let Some(b) = self.0.take() { - b.stop(text); + if self.enabled { + self.draw(); + eprintln!(); } else { let _ = cliclack::log::info(text); } } pub fn error(mut self, text: &str) { - if let Some(b) = self.0.take() { - b.error(text); + if self.enabled { + self.draw(); + eprintln!(); + let _ = cliclack::log::error(text); } else { let _ = cliclack::log::error(text); }
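
--
Usage sketch for the new `models` subcommands (illustrative only; commands and
flags as declared in cli.rs above, repo/file taken from the built-in help text):

    # list GGUF files in a repo, optionally filtered by substring
    polyscribe models search ggml-org/models --query tiny

    # download, verify, and register a model; the alias is derived from
    # the repo tail and the file stem (here: models-gguf-tiny-q4_0)
    polyscribe models add ggml-org/models gguf-tiny-q4_0.bin

    # check integrity, refresh via ETag, and prune unreferenced files
    polyscribe models verify models-gguf-tiny-q4_0
    polyscribe models update
    polyscribe models gc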