[feat] add ModelManager
with caching, manifest management, and Hugging Face API integration
This commit is contained in:
1
Cargo.lock
generated
1
Cargo.lock
generated
@@ -251,6 +251,7 @@ dependencies = [
|
|||||||
"iana-time-zone",
|
"iana-time-zone",
|
||||||
"js-sys",
|
"js-sys",
|
||||||
"num-traits",
|
"num-traits",
|
||||||
|
"serde",
|
||||||
"wasm-bindgen",
|
"wasm-bindgen",
|
||||||
"windows-link",
|
"windows-link",
|
||||||
]
|
]
|
||||||
|
@@ -21,7 +21,7 @@ anyhow = "1.0.99"
|
|||||||
libc = "0.2.175"
|
libc = "0.2.175"
|
||||||
toml = "0.8.23"
|
toml = "0.8.23"
|
||||||
serde_json = "1.0.142"
|
serde_json = "1.0.142"
|
||||||
chrono = "0.4.41"
|
chrono = { version = "0.4.41", features = ["serde"] }
|
||||||
sha2 = "0.10.9"
|
sha2 = "0.10.9"
|
||||||
which = "6.0.3"
|
which = "6.0.3"
|
||||||
tokio = { version = "1.47.1", features = ["rt-multi-thread", "macros"] }
|
tokio = { version = "1.47.1", features = ["rt-multi-thread", "macros"] }
|
||||||
|
@@ -74,7 +74,7 @@ pub enum Commands {
|
|||||||
inputs: Vec<PathBuf>,
|
inputs: Vec<PathBuf>,
|
||||||
},
|
},
|
||||||
|
|
||||||
/// Manage Whisper models
|
/// Manage Whisper GGUF models (Hugging Face)
|
||||||
Models {
|
Models {
|
||||||
#[command(subcommand)]
|
#[command(subcommand)]
|
||||||
cmd: ModelsCmd,
|
cmd: ModelsCmd,
|
||||||
@@ -97,14 +97,67 @@ pub enum Commands {
|
|||||||
Man,
|
Man,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Clone, Parser)]
|
||||||
|
pub struct ModelCommon {
|
||||||
|
/// Concurrency for ranged downloads
|
||||||
|
#[arg(long, default_value_t = 4)]
|
||||||
|
pub concurrency: usize,
|
||||||
|
/// Limit download rate in bytes/sec (approximate)
|
||||||
|
#[arg(long)]
|
||||||
|
pub limit_rate: Option<u64>,
|
||||||
|
/// Emit machine JSON output
|
||||||
|
#[arg(long, default_value_t = false)]
|
||||||
|
pub json: bool,
|
||||||
|
}
|
||||||
|
|
||||||
#[derive(Debug, Subcommand)]
|
#[derive(Debug, Subcommand)]
|
||||||
pub enum ModelsCmd {
|
pub enum ModelsCmd {
|
||||||
/// Verify or update local models non-interactively
|
/// List installed models (from manifest)
|
||||||
Update,
|
Ls {
|
||||||
/// Interactive multi-select downloader
|
#[command(flatten)]
|
||||||
Download,
|
common: ModelCommon,
|
||||||
/// Clear the cached Hugging Face manifest
|
},
|
||||||
ClearCache,
|
/// Add or update a model
|
||||||
|
Add {
|
||||||
|
/// Hugging Face repo, e.g. ggml-org/models
|
||||||
|
repo: String,
|
||||||
|
/// File name in repo (e.g., gguf-tiny-q4_0.bin)
|
||||||
|
file: String,
|
||||||
|
#[command(flatten)]
|
||||||
|
common: ModelCommon,
|
||||||
|
},
|
||||||
|
/// Remove a model by alias
|
||||||
|
Rm {
|
||||||
|
alias: String,
|
||||||
|
#[command(flatten)]
|
||||||
|
common: ModelCommon,
|
||||||
|
},
|
||||||
|
/// Verify model file integrity by alias
|
||||||
|
Verify {
|
||||||
|
alias: String,
|
||||||
|
#[command(flatten)]
|
||||||
|
common: ModelCommon,
|
||||||
|
},
|
||||||
|
/// Update all models (HEAD + ETag; skip if unchanged)
|
||||||
|
Update {
|
||||||
|
#[command(flatten)]
|
||||||
|
common: ModelCommon,
|
||||||
|
},
|
||||||
|
/// Garbage-collect unreferenced files and stale manifest entries
|
||||||
|
Gc {
|
||||||
|
#[command(flatten)]
|
||||||
|
common: ModelCommon,
|
||||||
|
},
|
||||||
|
/// Search a repo for GGUF files
|
||||||
|
Search {
|
||||||
|
/// Hugging Face repo, e.g. ggml-org/models
|
||||||
|
repo: String,
|
||||||
|
/// Optional substring to filter filenames
|
||||||
|
#[arg(long)]
|
||||||
|
query: Option<String>,
|
||||||
|
#[command(flatten)]
|
||||||
|
common: ModelCommon,
|
||||||
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Debug, Subcommand)]
|
#[derive(Debug, Subcommand)]
|
||||||
|
@@ -2,8 +2,49 @@ mod cli;
|
|||||||
|
|
||||||
use anyhow::{Context, Result, anyhow};
|
use anyhow::{Context, Result, anyhow};
|
||||||
use clap::{CommandFactory, Parser};
|
use clap::{CommandFactory, Parser};
|
||||||
use cli::{Cli, Commands, GpuBackend, ModelsCmd, PluginsCmd};
|
use cli::{Cli, Commands, GpuBackend, ModelsCmd, ModelCommon, PluginsCmd};
|
||||||
use polyscribe_core::models;
|
use polyscribe_core::model_manager::{ModelManager, Settings, ReqwestClient};
|
||||||
|
use polyscribe_core::ui;
|
||||||
|
fn normalized_similarity(a: &str, b: &str) -> f64 {
|
||||||
|
// simple Levenshtein distance; normalized to [0,1]
|
||||||
|
let a_bytes = a.as_bytes();
|
||||||
|
let b_bytes = b.as_bytes();
|
||||||
|
let n = a_bytes.len();
|
||||||
|
let m = b_bytes.len();
|
||||||
|
if n == 0 && m == 0 { return 1.0; }
|
||||||
|
if n == 0 || m == 0 { return 0.0; }
|
||||||
|
let mut prev: Vec<usize> = (0..=m).collect();
|
||||||
|
let mut curr: Vec<usize> = vec![0; m + 1];
|
||||||
|
for i in 1..=n {
|
||||||
|
curr[0] = i;
|
||||||
|
for j in 1..=m {
|
||||||
|
let cost = if a_bytes[i - 1] == b_bytes[j - 1] { 0 } else { 1 };
|
||||||
|
curr[j] = (prev[j] + 1)
|
||||||
|
.min(curr[j - 1] + 1)
|
||||||
|
.min(prev[j - 1] + cost);
|
||||||
|
}
|
||||||
|
std::mem::swap(&mut prev, &mut curr);
|
||||||
|
}
|
||||||
|
let dist = prev[m] as f64;
|
||||||
|
let max_len = n.max(m) as f64;
|
||||||
|
1.0 - (dist / max_len)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn human_size(bytes: Option<u64>) -> String {
|
||||||
|
match bytes {
|
||||||
|
Some(n) => {
|
||||||
|
let x = n as f64;
|
||||||
|
const KB: f64 = 1024.0;
|
||||||
|
const MB: f64 = 1024.0 * KB;
|
||||||
|
const GB: f64 = 1024.0 * MB;
|
||||||
|
if x >= GB { format!("{:.2} GiB", x / GB) }
|
||||||
|
else if x >= MB { format!("{:.2} MiB", x / MB) }
|
||||||
|
else if x >= KB { format!("{:.2} KiB", x / KB) }
|
||||||
|
else { format!("{} B", n) }
|
||||||
|
}
|
||||||
|
None => "?".to_string(),
|
||||||
|
}
|
||||||
|
}
|
||||||
use polyscribe_core::ui::progress::ProgressReporter;
|
use polyscribe_core::ui::progress::ProgressReporter;
|
||||||
use polyscribe_host::PluginManager;
|
use polyscribe_host::PluginManager;
|
||||||
use tokio::io::AsyncWriteExt;
|
use tokio::io::AsyncWriteExt;
|
||||||
@@ -29,8 +70,7 @@ fn init_tracing(quiet: bool, verbose: u8) {
|
|||||||
.init();
|
.init();
|
||||||
}
|
}
|
||||||
|
|
||||||
#[tokio::main]
|
fn main() -> Result<()> {
|
||||||
async fn main() -> Result<()> {
|
|
||||||
let args = Cli::parse();
|
let args = Cli::parse();
|
||||||
|
|
||||||
init_tracing(args.quiet, args.verbose);
|
init_tracing(args.quiet, args.verbose);
|
||||||
@@ -71,32 +111,188 @@ async fn main() -> Result<()> {
|
|||||||
}
|
}
|
||||||
|
|
||||||
Commands::Models { cmd } => {
|
Commands::Models { cmd } => {
|
||||||
match cmd {
|
// predictable exit codes
|
||||||
ModelsCmd::Update => {
|
const EXIT_OK: i32 = 0;
|
||||||
polyscribe_core::ui::info("verifying/updating local models");
|
const EXIT_NOT_FOUND: i32 = 2;
|
||||||
tokio::task::spawn_blocking(models::update_local_models)
|
const EXIT_NETWORK: i32 = 3;
|
||||||
.await
|
const EXIT_VERIFY_FAILED: i32 = 4;
|
||||||
.map_err(|e| anyhow!("blocking task join error: {e}"))?
|
// const EXIT_NO_CHANGE: i32 = 5; // reserved
|
||||||
.context("updating models")?;
|
|
||||||
|
let handle_common = |c: &ModelCommon| Settings {
|
||||||
|
concurrency: c.concurrency.max(1),
|
||||||
|
limit_rate: c.limit_rate,
|
||||||
|
..Default::default()
|
||||||
|
};
|
||||||
|
|
||||||
|
let exit = match cmd {
|
||||||
|
ModelsCmd::Ls { common } => {
|
||||||
|
let mm: ModelManager<ReqwestClient> = ModelManager::new(handle_common(&common))?;
|
||||||
|
let list = mm.ls()?;
|
||||||
|
if common.json {
|
||||||
|
println!("{}", serde_json::to_string_pretty(&list)?);
|
||||||
|
} else {
|
||||||
|
println!("Model (Repo)");
|
||||||
|
for r in list {
|
||||||
|
println!("{} ({})", r.file, r.repo);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
EXIT_OK
|
||||||
}
|
}
|
||||||
ModelsCmd::Download => {
|
ModelsCmd::Add { repo, file, common } => {
|
||||||
polyscribe_core::ui::info("interactive model selection and download");
|
let settings = handle_common(&common);
|
||||||
tokio::task::spawn_blocking(models::run_interactive_model_downloader)
|
let mm: ModelManager<ReqwestClient> = ModelManager::new(settings.clone())?;
|
||||||
.await
|
// Derive an alias automatically from repo and file
|
||||||
.map_err(|e| anyhow!("blocking task join error: {e}"))?
|
fn derive_alias(repo: &str, file: &str) -> String {
|
||||||
.context("running downloader")?;
|
use std::path::Path;
|
||||||
polyscribe_core::ui::success("Model download complete.");
|
let repo_tail = repo.rsplit('/').next().unwrap_or(repo);
|
||||||
|
let stem = Path::new(file)
|
||||||
|
.file_stem()
|
||||||
|
.and_then(|s| s.to_str())
|
||||||
|
.unwrap_or(file);
|
||||||
|
format!("{}-{}", repo_tail, stem)
|
||||||
|
}
|
||||||
|
let alias = derive_alias(&repo, &file);
|
||||||
|
match mm.add_or_update(&alias, &repo, &file) {
|
||||||
|
Ok(rec) => {
|
||||||
|
if common.json { println!("{}", serde_json::to_string_pretty(&rec)?); }
|
||||||
|
else { println!("installed: {} -> {}/{}", alias, repo, rec.file); }
|
||||||
|
EXIT_OK
|
||||||
|
}
|
||||||
|
Err(e) => {
|
||||||
|
// On not found or similar errors, try suggesting close matches interactively
|
||||||
|
if common.json || polyscribe_core::is_no_interaction() {
|
||||||
|
if common.json { println!("{{\"error\":{}}}", serde_json::to_string(&e.to_string())?); }
|
||||||
|
else { eprintln!("error: {e}"); }
|
||||||
|
EXIT_NOT_FOUND
|
||||||
|
} else {
|
||||||
|
ui::warn(format!("{}", e));
|
||||||
|
ui::info("Searching for similar model filenames…");
|
||||||
|
match polyscribe_core::model_manager::search_repo(&repo, None) {
|
||||||
|
Ok(mut files) => {
|
||||||
|
if files.is_empty() {
|
||||||
|
ui::warn("No files found in repository.");
|
||||||
|
EXIT_NOT_FOUND
|
||||||
|
} else {
|
||||||
|
// rank by similarity
|
||||||
|
files.sort_by(|a, b| normalized_similarity(&file, b)
|
||||||
|
.partial_cmp(&normalized_similarity(&file, a))
|
||||||
|
.unwrap_or(std::cmp::Ordering::Equal));
|
||||||
|
let top: Vec<String> = files.into_iter().take(5).collect();
|
||||||
|
if top.is_empty() {
|
||||||
|
EXIT_NOT_FOUND
|
||||||
|
} else if top.len() == 1 {
|
||||||
|
let cand = &top[0];
|
||||||
|
// Fetch repo size list once
|
||||||
|
let size_map: std::collections::HashMap<String, Option<u64>> =
|
||||||
|
polyscribe_core::model_manager::list_repo_files_with_meta(&repo)
|
||||||
|
.unwrap_or_default()
|
||||||
|
.into_iter().collect();
|
||||||
|
let mut size = size_map.get(cand).cloned().unwrap_or(None);
|
||||||
|
if size.is_none() {
|
||||||
|
size = polyscribe_core::model_manager::head_len_for_file(&repo, cand);
|
||||||
|
}
|
||||||
|
let local_files: std::collections::HashSet<String> = mm.ls()?.into_iter().map(|r| r.file).collect();
|
||||||
|
let is_local = local_files.contains(cand);
|
||||||
|
let label = format!("{} [{}]{}", cand, human_size(size), if is_local { " (local)" } else { "" });
|
||||||
|
let ok = ui::prompt_confirm(&format!("Did you mean {}?", label), true)
|
||||||
|
.unwrap_or(false);
|
||||||
|
if !ok { EXIT_NOT_FOUND } else {
|
||||||
|
let mm2: ModelManager<ReqwestClient> = ModelManager::new(settings)?;
|
||||||
|
let alias2 = derive_alias(&repo, cand);
|
||||||
|
match mm2.add_or_update(&alias2, &repo, cand) {
|
||||||
|
Ok(rec) => { println!("installed: {} -> {}/{}", alias2, repo, rec.file); EXIT_OK }
|
||||||
|
Err(e2) => { eprintln!("error: {e2}"); EXIT_NETWORK }
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
let opts: Vec<String> = top;
|
||||||
|
let local_files: std::collections::HashSet<String> = mm.ls()?.into_iter().map(|r| r.file).collect();
|
||||||
|
// Enrich labels with size and local tag using a single API call
|
||||||
|
let size_map: std::collections::HashMap<String, Option<u64>> =
|
||||||
|
polyscribe_core::model_manager::list_repo_files_with_meta(&repo)
|
||||||
|
.unwrap_or_default()
|
||||||
|
.into_iter().collect();
|
||||||
|
let mut labels_owned: Vec<String> = Vec::new();
|
||||||
|
for f in &opts {
|
||||||
|
let mut size = size_map.get(f).cloned().unwrap_or(None);
|
||||||
|
if size.is_none() {
|
||||||
|
size = polyscribe_core::model_manager::head_len_for_file(&repo, f);
|
||||||
|
}
|
||||||
|
let is_local = local_files.contains(f);
|
||||||
|
let suffix = if is_local { " (local)" } else { "" };
|
||||||
|
labels_owned.push(format!("{} [{}]{}", f, human_size(size), suffix));
|
||||||
|
}
|
||||||
|
let labels: Vec<&str> = labels_owned.iter().map(|s| s.as_str()).collect();
|
||||||
|
match ui::prompt_select("Pick a model", &labels) {
|
||||||
|
Ok(idx) => {
|
||||||
|
let chosen = &opts[idx];
|
||||||
|
let mm2: ModelManager<ReqwestClient> = ModelManager::new(settings)?;
|
||||||
|
let alias2 = derive_alias(&repo, chosen);
|
||||||
|
match mm2.add_or_update(&alias2, &repo, chosen) {
|
||||||
|
Ok(rec) => { println!("installed: {} -> {}/{}", alias2, repo, rec.file); EXIT_OK }
|
||||||
|
Err(e2) => { eprintln!("error: {e2}"); EXIT_NETWORK }
|
||||||
|
}
|
||||||
|
}
|
||||||
|
Err(_) => EXIT_NOT_FOUND,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
Err(e2) => {
|
||||||
|
eprintln!("error: {}", e2);
|
||||||
|
EXIT_NETWORK
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
ModelsCmd::ClearCache => {
|
ModelsCmd::Rm { alias, common } => {
|
||||||
polyscribe_core::ui::info("clearing manifest cache");
|
let mm: ModelManager<ReqwestClient> = ModelManager::new(handle_common(&common))?;
|
||||||
tokio::task::spawn_blocking(models::clear_manifest_cache)
|
let ok = mm.rm(&alias)?;
|
||||||
.await
|
if common.json { println!("{{\"removed\":{}}}", ok); }
|
||||||
.map_err(|e| anyhow!("blocking task join error: {e}"))?
|
else { println!("{}", if ok { "removed" } else { "not found" }); }
|
||||||
.context("clearing cache")?;
|
if ok { EXIT_OK } else { EXIT_NOT_FOUND }
|
||||||
polyscribe_core::ui::success("Manifest cache cleared.");
|
|
||||||
}
|
}
|
||||||
}
|
ModelsCmd::Verify { alias, common } => {
|
||||||
Ok(())
|
let mm: ModelManager<ReqwestClient> = ModelManager::new(handle_common(&common))?;
|
||||||
|
let found = mm.ls()?.into_iter().any(|r| r.alias == alias);
|
||||||
|
if !found {
|
||||||
|
if common.json { println!("{{\"ok\":false,\"error\":\"not found\"}}"); } else { println!("not found"); }
|
||||||
|
EXIT_NOT_FOUND
|
||||||
|
} else {
|
||||||
|
let ok = mm.verify(&alias)?;
|
||||||
|
if common.json { println!("{{\"ok\":{}}}", ok); } else { println!("{}", if ok { "ok" } else { "corrupt" }); }
|
||||||
|
if ok { EXIT_OK } else { EXIT_VERIFY_FAILED }
|
||||||
|
}
|
||||||
|
}
|
||||||
|
ModelsCmd::Update { common } => {
|
||||||
|
let mm: ModelManager<ReqwestClient> = ModelManager::new(handle_common(&common))?;
|
||||||
|
let mut rc = EXIT_OK;
|
||||||
|
for rec in mm.ls()? {
|
||||||
|
match mm.add_or_update(&rec.alias, &rec.repo, &rec.file) {
|
||||||
|
Ok(_) => {}
|
||||||
|
Err(e) => { rc = EXIT_NETWORK; if common.json { println!("{{\"alias\":\"{}\",\"error\":{}}}", rec.alias, serde_json::to_string(&e.to_string())?); } else { eprintln!("update {}: {e}", rec.alias); } }
|
||||||
|
}
|
||||||
|
}
|
||||||
|
rc
|
||||||
|
}
|
||||||
|
ModelsCmd::Gc { common } => {
|
||||||
|
let mm: ModelManager<ReqwestClient> = ModelManager::new(handle_common(&common))?;
|
||||||
|
let (files_removed, entries_removed) = mm.gc()?;
|
||||||
|
if common.json { println!("{{\"files_removed\":{},\"entries_removed\":{}}}", files_removed, entries_removed); }
|
||||||
|
else { println!("files_removed={} entries_removed={}", files_removed, entries_removed); }
|
||||||
|
EXIT_OK
|
||||||
|
}
|
||||||
|
ModelsCmd::Search { repo, query, common } => {
|
||||||
|
let res = polyscribe_core::model_manager::search_repo(&repo, query.as_deref());
|
||||||
|
match res {
|
||||||
|
Ok(files) => { if common.json { println!("{}", serde_json::to_string_pretty(&files)?); } else { for f in files { println!("{}", f); } } EXIT_OK }
|
||||||
|
Err(e) => { if common.json { println!("{{\"error\":{}}}", serde_json::to_string(&e.to_string())?); } else { eprintln!("error: {e}"); } EXIT_NETWORK }
|
||||||
|
}
|
||||||
|
}
|
||||||
|
};
|
||||||
|
std::process::exit(exit);
|
||||||
}
|
}
|
||||||
|
|
||||||
Commands::Plugins { cmd } => {
|
Commands::Plugins { cmd } => {
|
||||||
@@ -123,27 +319,35 @@ async fn main() -> Result<()> {
|
|||||||
command,
|
command,
|
||||||
json,
|
json,
|
||||||
} => {
|
} => {
|
||||||
let payload = json.unwrap_or_else(|| "{}".to_string());
|
// Use a local Tokio runtime only for this async path
|
||||||
let mut child = plugin_manager
|
let rt = tokio::runtime::Builder::new_multi_thread()
|
||||||
.spawn(&name, &command)
|
.enable_all()
|
||||||
.with_context(|| format!("spawning plugin {name} {command}"))?;
|
.build()
|
||||||
|
.context("building tokio runtime")?;
|
||||||
|
|
||||||
if let Some(mut stdin) = child.stdin.take() {
|
rt.block_on(async {
|
||||||
stdin
|
let payload = json.unwrap_or_else(|| "{}".to_string());
|
||||||
.write_all(payload.as_bytes())
|
let mut child = plugin_manager
|
||||||
.await
|
.spawn(&name, &command)
|
||||||
.context("writing JSON payload to plugin stdin")?;
|
.with_context(|| format!("spawning plugin {name} {command}"))?;
|
||||||
}
|
|
||||||
|
|
||||||
let status = plugin_manager.forward_stdio(&mut child).await?;
|
if let Some(mut stdin) = child.stdin.take() {
|
||||||
if !status.success() {
|
stdin
|
||||||
polyscribe_core::ui::error(format!(
|
.write_all(payload.as_bytes())
|
||||||
"plugin returned non-zero exit code: {}",
|
.await
|
||||||
status
|
.context("writing JSON payload to plugin stdin")?;
|
||||||
));
|
}
|
||||||
return Err(anyhow!("plugin failed"));
|
|
||||||
}
|
let status = plugin_manager.forward_stdio(&mut child).await?;
|
||||||
Ok(())
|
if !status.success() {
|
||||||
|
polyscribe_core::ui::error(format!(
|
||||||
|
"plugin returned non-zero exit code: {}",
|
||||||
|
status
|
||||||
|
));
|
||||||
|
return Err(anyhow!("plugin failed"));
|
||||||
|
}
|
||||||
|
Ok(())
|
||||||
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@@ -214,6 +214,8 @@ pub fn render_srt(entries: &[OutputEntry]) -> String {
|
|||||||
srt
|
srt
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub mod model_manager;
|
||||||
|
|
||||||
pub fn models_dir_path() -> PathBuf {
|
pub fn models_dir_path() -> PathBuf {
|
||||||
if let Ok(env_val) = env::var("POLYSCRIBE_MODELS_DIR") {
|
if let Ok(env_val) = env::var("POLYSCRIBE_MODELS_DIR") {
|
||||||
let env_path = PathBuf::from(env_val);
|
let env_path = PathBuf::from(env_val);
|
||||||
|
893
crates/polyscribe-core/src/model_manager.rs
Normal file
893
crates/polyscribe-core/src/model_manager.rs
Normal file
@@ -0,0 +1,893 @@
|
|||||||
|
// SPDX-License-Identifier: MIT
|
||||||
|
|
||||||
|
use crate::prelude::*;
|
||||||
|
use crate::ui::BytesProgress;
|
||||||
|
use anyhow::{anyhow, Context};
|
||||||
|
use chrono::{DateTime, Utc};
|
||||||
|
use reqwest::blocking::Client;
|
||||||
|
use reqwest::header::{
|
||||||
|
ACCEPT_RANGES, AUTHORIZATION, CONTENT_LENGTH, ETAG, IF_NONE_MATCH, LAST_MODIFIED, RANGE,
|
||||||
|
};
|
||||||
|
use serde::{Deserialize, Serialize};
|
||||||
|
use sha2::{Digest, Sha256};
|
||||||
|
use std::cmp::min;
|
||||||
|
use std::collections::BTreeMap;
|
||||||
|
use std::fs::{self, File, OpenOptions};
|
||||||
|
use std::io::{Read, Seek, SeekFrom, Write};
|
||||||
|
use std::path::{Path, PathBuf};
|
||||||
|
use std::sync::{mpsc, Arc, Mutex};
|
||||||
|
use std::thread;
|
||||||
|
use std::time::Duration;
|
||||||
|
|
||||||
|
const DEFAULT_CHUNK_SIZE: u64 = 8 * 1024 * 1024; // 8 MiB
|
||||||
|
|
||||||
|
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
|
||||||
|
pub struct ModelRecord {
|
||||||
|
pub alias: String,
|
||||||
|
pub repo: String,
|
||||||
|
pub file: String,
|
||||||
|
pub revision: Option<String>, // ETag or commit hash
|
||||||
|
pub sha256: Option<String>,
|
||||||
|
pub size_bytes: Option<u64>,
|
||||||
|
pub quant: Option<String>,
|
||||||
|
pub installed_at: Option<DateTime<Utc>>,
|
||||||
|
pub last_used: Option<DateTime<Utc>>,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Clone, Serialize, Deserialize, Default)]
|
||||||
|
pub struct Manifest {
|
||||||
|
pub models: BTreeMap<String, ModelRecord>, // key = alias
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Clone)]
|
||||||
|
pub struct Settings {
|
||||||
|
pub concurrency: usize,
|
||||||
|
pub limit_rate: Option<u64>, // bytes/sec
|
||||||
|
pub chunk_size: u64,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Default for Settings {
|
||||||
|
fn default() -> Self {
|
||||||
|
Self {
|
||||||
|
concurrency: 4,
|
||||||
|
limit_rate: None,
|
||||||
|
chunk_size: DEFAULT_CHUNK_SIZE,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Clone)]
|
||||||
|
pub struct Paths {
|
||||||
|
pub cache_dir: PathBuf, // $XDG_CACHE_HOME/polyscribe/models
|
||||||
|
pub config_path: PathBuf, // $XDG_CONFIG_HOME/polyscribe/models.json
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Paths {
|
||||||
|
pub fn resolve() -> Result<Self> {
|
||||||
|
if let Ok(over) = std::env::var("POLYSCRIBE_CACHE_DIR") {
|
||||||
|
if !over.is_empty() {
|
||||||
|
let cache_dir = PathBuf::from(over).join("models");
|
||||||
|
let config_path = std::env::var("POLYSCRIBE_CONFIG_DIR")
|
||||||
|
.map(|p| PathBuf::from(p).join("models.json"))
|
||||||
|
.unwrap_or_else(|_| default_config_path());
|
||||||
|
return Ok(Self {
|
||||||
|
cache_dir,
|
||||||
|
config_path,
|
||||||
|
});
|
||||||
|
}
|
||||||
|
}
|
||||||
|
let cache_dir = default_cache_dir();
|
||||||
|
let config_path = default_config_path();
|
||||||
|
Ok(Self {
|
||||||
|
cache_dir,
|
||||||
|
config_path,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn default_cache_dir() -> PathBuf {
|
||||||
|
if let Ok(xdg) = std::env::var("XDG_CACHE_HOME") {
|
||||||
|
if !xdg.is_empty() {
|
||||||
|
return PathBuf::from(xdg).join("polyscribe").join("models");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if let Ok(home) = std::env::var("HOME") {
|
||||||
|
if !home.is_empty() {
|
||||||
|
return PathBuf::from(home)
|
||||||
|
.join(".cache")
|
||||||
|
.join("polyscribe")
|
||||||
|
.join("models");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
PathBuf::from("models")
|
||||||
|
}
|
||||||
|
|
||||||
|
fn default_config_path() -> PathBuf {
|
||||||
|
if let Ok(xdg) = std::env::var("XDG_CONFIG_HOME") {
|
||||||
|
if !xdg.is_empty() {
|
||||||
|
return PathBuf::from(xdg).join("polyscribe").join("models.json");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if let Ok(home) = std::env::var("HOME") {
|
||||||
|
if !home.is_empty() {
|
||||||
|
return PathBuf::from(home)
|
||||||
|
.join(".config")
|
||||||
|
.join("polyscribe")
|
||||||
|
.join("models.json");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
PathBuf::from("models.json")
|
||||||
|
}
|
||||||
|
|
||||||
|
pub trait HttpClient: Send + Sync {
|
||||||
|
fn head(&self, url: &str, etag: Option<&str>) -> Result<HeadMeta>;
|
||||||
|
fn get_range(&self, url: &str, start: u64, end_inclusive: u64) -> Result<Vec<u8>>;
|
||||||
|
fn get_whole_to(&self, url: &str, writer: &mut dyn Write) -> Result<()>;
|
||||||
|
fn get_from_to(&self, url: &str, start: u64, writer: &mut dyn Write) -> Result<()>;
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Clone)]
|
||||||
|
pub struct ReqwestClient {
|
||||||
|
client: Client,
|
||||||
|
token: Option<String>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl ReqwestClient {
|
||||||
|
pub fn new() -> Result<Self> {
|
||||||
|
let token = std::env::var("HF_TOKEN").ok().filter(|s| !s.is_empty());
|
||||||
|
let client = Client::builder()
|
||||||
|
.user_agent(crate::config::ConfigService::user_agent())
|
||||||
|
.build()?;
|
||||||
|
Ok(Self { client, token })
|
||||||
|
}
|
||||||
|
|
||||||
|
fn auth(&self, mut req: reqwest::blocking::RequestBuilder) -> reqwest::blocking::RequestBuilder {
|
||||||
|
if let Some(t) = &self.token {
|
||||||
|
req = req.header(AUTHORIZATION, format!("Bearer {}", t));
|
||||||
|
}
|
||||||
|
req
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Clone)]
|
||||||
|
pub struct HeadMeta {
|
||||||
|
pub len: Option<u64>,
|
||||||
|
pub etag: Option<String>,
|
||||||
|
pub last_modified: Option<String>,
|
||||||
|
pub accept_ranges: bool,
|
||||||
|
pub not_modified: bool,
|
||||||
|
pub status: u16,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl HttpClient for ReqwestClient {
|
||||||
|
fn head(&self, url: &str, etag: Option<&str>) -> Result<HeadMeta> {
|
||||||
|
let mut req = self.client.head(url);
|
||||||
|
if let Some(e) = etag {
|
||||||
|
req = req.header(IF_NONE_MATCH, format!("\"{}\"", e));
|
||||||
|
}
|
||||||
|
let resp = self.auth(req).send()?;
|
||||||
|
let status = resp.status().as_u16();
|
||||||
|
if status == 304 {
|
||||||
|
return Ok(HeadMeta {
|
||||||
|
len: None,
|
||||||
|
etag: etag.map(|s| s.to_string()),
|
||||||
|
last_modified: None,
|
||||||
|
accept_ranges: true,
|
||||||
|
not_modified: true,
|
||||||
|
status,
|
||||||
|
});
|
||||||
|
}
|
||||||
|
let len = resp
|
||||||
|
.headers()
|
||||||
|
.get(CONTENT_LENGTH)
|
||||||
|
.and_then(|v| v.to_str().ok())
|
||||||
|
.and_then(|s| s.parse::<u64>().ok());
|
||||||
|
let etag = resp
|
||||||
|
.headers()
|
||||||
|
.get(ETAG)
|
||||||
|
.and_then(|v| v.to_str().ok())
|
||||||
|
.map(|s| s.trim_matches('"').to_string());
|
||||||
|
let last_modified = resp
|
||||||
|
.headers()
|
||||||
|
.get(LAST_MODIFIED)
|
||||||
|
.and_then(|v| v.to_str().ok())
|
||||||
|
.map(|s| s.to_string());
|
||||||
|
let accept_ranges = resp
|
||||||
|
.headers()
|
||||||
|
.get(ACCEPT_RANGES)
|
||||||
|
.and_then(|v| v.to_str().ok())
|
||||||
|
.map(|s| s.to_ascii_lowercase().contains("bytes"))
|
||||||
|
.unwrap_or(false);
|
||||||
|
Ok(HeadMeta {
|
||||||
|
len,
|
||||||
|
etag,
|
||||||
|
last_modified,
|
||||||
|
accept_ranges,
|
||||||
|
not_modified: false,
|
||||||
|
status,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
fn get_range(&self, url: &str, start: u64, end_inclusive: u64) -> Result<Vec<u8>> {
|
||||||
|
let range_val = format!("bytes={}-{}", start, end_inclusive);
|
||||||
|
let resp = self
|
||||||
|
.auth(self.client.get(url))
|
||||||
|
.header(RANGE, range_val)
|
||||||
|
.send()?;
|
||||||
|
if !resp.status().is_success() && resp.status().as_u16() != 206 {
|
||||||
|
return Err(anyhow!("HTTP {} for ranged GET", resp.status()).into());
|
||||||
|
}
|
||||||
|
let mut buf = Vec::new();
|
||||||
|
let mut r = resp;
|
||||||
|
r.copy_to(&mut buf)?;
|
||||||
|
Ok(buf)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn get_whole_to(&self, url: &str, writer: &mut dyn Write) -> Result<()> {
|
||||||
|
let resp = self.auth(self.client.get(url)).send()?;
|
||||||
|
if !resp.status().is_success() {
|
||||||
|
return Err(anyhow!("HTTP {} for GET", resp.status()).into());
|
||||||
|
}
|
||||||
|
let mut r = resp;
|
||||||
|
r.copy_to(writer)?;
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
fn get_from_to(&self, url: &str, start: u64, writer: &mut dyn Write) -> Result<()> {
|
||||||
|
let mut req = self.auth(self.client.get(url));
|
||||||
|
if start > 0 {
|
||||||
|
req = req.header(RANGE, format!("bytes={}-", start));
|
||||||
|
}
|
||||||
|
let resp = req.send()?;
|
||||||
|
if !resp.status().is_success() && resp.status().as_u16() != 206 {
|
||||||
|
return Err(anyhow!("HTTP {} for ranged GET from {}", resp.status(), start).into());
|
||||||
|
}
|
||||||
|
let mut r = resp;
|
||||||
|
r.copy_to(writer)?;
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub struct ModelManager<C: HttpClient = ReqwestClient> {
|
||||||
|
pub paths: Paths,
|
||||||
|
pub settings: Settings,
|
||||||
|
client: Arc<C>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<C: HttpClient + 'static> ModelManager<C> {
|
||||||
|
pub fn new_with_client(client: C, settings: Settings) -> Result<Self> {
|
||||||
|
Ok(Self {
|
||||||
|
paths: Paths::resolve()?,
|
||||||
|
settings,
|
||||||
|
client: Arc::new(client),
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn new(settings: Settings) -> Result<Self>
|
||||||
|
where
|
||||||
|
C: Default,
|
||||||
|
{
|
||||||
|
Ok(Self {
|
||||||
|
paths: Paths::resolve()?,
|
||||||
|
settings,
|
||||||
|
client: Arc::new(C::default()),
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
fn load_manifest(&self) -> Result<Manifest> {
|
||||||
|
let p = &self.paths.config_path;
|
||||||
|
if !p.exists() {
|
||||||
|
return Ok(Manifest::default());
|
||||||
|
}
|
||||||
|
let file = File::open(p).with_context(|| format!("open manifest {}", p.display()))?;
|
||||||
|
let m: Manifest = serde_json::from_reader(file).context("parse manifest")?;
|
||||||
|
Ok(m)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn save_manifest(&self, m: &Manifest) -> Result<()> {
|
||||||
|
let p = &self.paths.config_path;
|
||||||
|
if let Some(dir) = p.parent() {
|
||||||
|
fs::create_dir_all(dir)
|
||||||
|
.with_context(|| format!("create config dir {}", dir.display()))?;
|
||||||
|
}
|
||||||
|
let tmp = p.with_extension("json.tmp");
|
||||||
|
let f = OpenOptions::new()
|
||||||
|
.create(true)
|
||||||
|
.write(true)
|
||||||
|
.truncate(true)
|
||||||
|
.open(&tmp)?;
|
||||||
|
serde_json::to_writer_pretty(f, m).context("serialize manifest")?;
|
||||||
|
fs::rename(&tmp, p).with_context(|| format!("atomic rename {} -> {}", tmp.display(), p.display()))?;
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
fn model_path(&self, file: &str) -> PathBuf {
|
||||||
|
self.paths.cache_dir.join(file)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn compute_sha256(path: &Path) -> Result<String> {
|
||||||
|
let mut f = File::open(path)?;
|
||||||
|
let mut hasher = Sha256::new();
|
||||||
|
let mut buf = [0u8; 64 * 1024];
|
||||||
|
loop {
|
||||||
|
let n = f.read(&mut buf)?;
|
||||||
|
if n == 0 {
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
hasher.update(&buf[..n]);
|
||||||
|
}
|
||||||
|
Ok(format!("{:x}", hasher.finalize()))
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn ls(&self) -> Result<Vec<ModelRecord>> {
|
||||||
|
let m = self.load_manifest()?;
|
||||||
|
Ok(m.models.values().cloned().collect())
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn rm(&self, alias: &str) -> Result<bool> {
|
||||||
|
let mut m = self.load_manifest()?;
|
||||||
|
if let Some(rec) = m.models.remove(alias) {
|
||||||
|
let p = self.model_path(&rec.file);
|
||||||
|
let _ = fs::remove_file(&p);
|
||||||
|
self.save_manifest(&m)?;
|
||||||
|
return Ok(true);
|
||||||
|
}
|
||||||
|
Ok(false)
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn verify(&self, alias: &str) -> Result<bool> {
|
||||||
|
let m = self.load_manifest()?;
|
||||||
|
let Some(rec) = m.models.get(alias) else { return Ok(false) };
|
||||||
|
let p = self.model_path(&rec.file);
|
||||||
|
if !p.exists() { return Ok(false); }
|
||||||
|
if let Some(expected) = &rec.sha256 {
|
||||||
|
let actual = Self::compute_sha256(&p)?;
|
||||||
|
return Ok(&actual == expected);
|
||||||
|
}
|
||||||
|
Ok(true)
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn gc(&self) -> Result<(usize, usize)> {
|
||||||
|
// Remove files not referenced by manifest; also drop manifest entries whose file is missing
|
||||||
|
fs::create_dir_all(&self.paths.cache_dir).ok();
|
||||||
|
let mut m = self.load_manifest()?;
|
||||||
|
let mut referenced = BTreeMap::new();
|
||||||
|
for (alias, rec) in &m.models {
|
||||||
|
referenced.insert(rec.file.clone(), alias.clone());
|
||||||
|
}
|
||||||
|
let mut removed_files = 0usize;
|
||||||
|
if let Ok(rd) = fs::read_dir(&self.paths.cache_dir) {
|
||||||
|
for ent in rd.flatten() {
|
||||||
|
let p = ent.path();
|
||||||
|
if p.is_file() {
|
||||||
|
let fname = p.file_name().and_then(|s| s.to_str()).unwrap_or("");
|
||||||
|
if !referenced.contains_key(fname) {
|
||||||
|
let _ = fs::remove_file(&p);
|
||||||
|
removed_files += 1;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
m.models.retain(|_, rec| self.model_path(&rec.file).exists());
|
||||||
|
let removed_entries = referenced
|
||||||
|
.keys()
|
||||||
|
.filter(|f| !self.model_path(f).exists())
|
||||||
|
.count();
|
||||||
|
|
||||||
|
self.save_manifest(&m)?;
|
||||||
|
Ok((removed_files, removed_entries))
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Download (or refresh) a model file from Hugging Face and record it.
///
/// Sends a conditional HEAD using the previously stored ETag; on 304 the
/// existing record is kept and only `last_used` is touched. Otherwise the
/// file is downloaded to `<file>.part` — concurrently via byte ranges when
/// the server advertises support and `concurrency > 1`, else via a single
/// streaming GET — its sha256 is verified against HF API metadata when
/// available, and the temp file is atomically renamed into place before the
/// manifest is updated.
///
/// # Errors
/// Fails on HTTP errors (including 4xx from HEAD), a missing content length,
/// I/O failures, or a sha256 mismatch. A drop guard removes the `.part`
/// file on every error path.
pub fn add_or_update(
    &self,
    alias: &str,
    repo: &str,
    file: &str,
) -> Result<ModelRecord> {
    fs::create_dir_all(&self.paths.cache_dir)
        .with_context(|| format!("create cache dir {}", self.paths.cache_dir.display()))?;

    let url = format!("https://huggingface.co/{}/resolve/main/{}?download=true", repo, file);

    let mut manifest = self.load_manifest()?;
    let prev = manifest.models.get(alias).cloned();
    // `revision` stores the ETag from the previous download; used for the
    // conditional HEAD below.
    let prev_etag = prev.as_ref().and_then(|r| r.revision.clone());
    // Fetch remote meta (size/hash) when available to verify the download.
    // Best-effort: API failures degrade to no verification data.
    let (_api_size, api_sha) = hf_fetch_file_meta(repo, file).unwrap_or((None, None));
    let head = self.client.head(&url, prev_etag.as_deref())?;
    if head.not_modified {
        // up-to-date; ensure record present and touch last_used
        let mut rec = prev.ok_or_else(|| anyhow!("not installed yet but server says 304"))?;
        rec.last_used = Some(Utc::now());
        self.save_touch(&mut manifest, rec.clone())?;
        return Ok(rec);
    }

    // Quick check: if HEAD failed (e.g., 404), report a helpful error before attempting download
    if head.status >= 400 {
        return Err(anyhow!(
            "file not found or inaccessible: repo='{}' file='{}' (HTTP {})\nHint: run `polyscribe models search {} --query {}` to list available files",
            repo,
            file,
            head.status,
            repo,
            file
        ).into());
    }

    // Content length is required: it sizes the preallocation and the ranged
    // chunk plan below.
    let total_len = head.len.ok_or_else(|| anyhow!("missing content-length (HEAD)"))?;
    let etag = head.etag.clone();
    let dest_tmp = self.model_path(&format!("{}.part", file));
    // If a previous cancelled download left a .part file, remove it to avoid clutter/resume.
    if dest_tmp.exists() { let _ = fs::remove_file(&dest_tmp); }
    // Guard to ensure .part is cleaned up on errors (disarmed after the
    // successful rename at the end).
    struct TempGuard { path: PathBuf, armed: bool }
    impl TempGuard { fn disarm(&mut self) { self.armed = false; } }
    impl Drop for TempGuard {
        fn drop(&mut self) {
            if self.armed {
                let _ = fs::remove_file(&self.path);
            }
        }
    }
    let mut _tmp_guard = TempGuard { path: dest_tmp.clone(), armed: true };
    let dest_final = self.model_path(file);

    // Do not resume after cancellation; start fresh to avoid stale .part files
    let start_from = 0u64;

    // Open tmp for write
    let f = OpenOptions::new().create(true).write(true).read(true).open(&dest_tmp)?;
    f.set_len(total_len)?; // pre-allocate for random writes
    // Shared across worker threads: each worker seeks and writes its own range.
    let f = Arc::new(Mutex::new(f));

    // Create progress bar
    let mut progress = BytesProgress::start(total_len, &format!("Downloading {}", file), start_from);

    // Create work chunks: inclusive (start, end) byte ranges of chunk_size.
    let chunk_size = self.settings.chunk_size;
    let mut chunks = Vec::new();
    let mut pos = start_from;
    while pos < total_len {
        let end = min(total_len - 1, pos + chunk_size - 1);
        chunks.push((pos, end));
        pos = end + 1;
    }

    // Attempt concurrent ranged download; on failure, fallback to streaming GET
    let mut ranged_failed = false;
    if head.accept_ranges && self.settings.concurrency > 1 {
        // work channel feeds ranges to workers; prog channel reports bytes back.
        let (work_tx, work_rx) = mpsc::channel::<(u64, u64)>();
        let (prog_tx, prog_rx) = mpsc::channel::<u64>();
        for ch in chunks {
            work_tx.send(ch).unwrap();
        }
        // Close the work queue so workers exit once it drains.
        drop(work_tx);
        let rx = Arc::new(Mutex::new(work_rx));

        let workers = self.settings.concurrency.max(1);
        let mut handles = Vec::new();
        for _ in 0..workers {
            let rx = rx.clone();
            let url = url.clone();
            let client = self.client.clone();
            let f = f.clone();
            let limit = self.settings.limit_rate;
            let prog_tx = prog_tx.clone();
            let handle = thread::spawn(move || -> Result<()> {
                loop {
                    // Lock only around recv so workers don't serialize on the
                    // network fetch.
                    let next = {
                        let guard = rx.lock().unwrap();
                        guard.recv().ok()
                    };
                    let Some((start, end)) = next else { break; };
                    let data = client.get_range(&url, start, end)?;
                    // Approximate rate limiting: sleep in proportion to the
                    // bytes just fetched.
                    if let Some(max_bps) = limit {
                        let dur = Duration::from_secs_f64((data.len() as f64) / (max_bps as f64));
                        if dur > Duration::from_millis(1) {
                            thread::sleep(dur);
                        }
                    }
                    let mut guard = f.lock().unwrap();
                    guard.seek(SeekFrom::Start(start))?;
                    guard.write_all(&data)?;
                    let _ = prog_tx.send(data.len() as u64);
                }
                Ok(())
            });
            handles.push(handle);
        }
        // Drop the main thread's sender so the prog_rx loop below terminates
        // once all workers finish.
        drop(prog_tx);

        for delta in prog_rx {
            progress.inc(delta);
        }

        // Join workers; remember the first failure but drain all handles.
        let mut ranged_err: Option<anyhow::Error> = None;
        for h in handles {
            match h.join() {
                Ok(Ok(())) => {}
                Ok(Err(e)) => { ranged_failed = true; if ranged_err.is_none() { ranged_err = Some(e.into()); } }
                Err(_) => { ranged_failed = true; if ranged_err.is_none() { ranged_err = Some(anyhow!("worker panicked")); } }
            }
        }
    } else {
        // No range support (or single-threaded config): go straight to the
        // streaming fallback.
        ranged_failed = true;
    }

    if ranged_failed {
        // Restart progress if we are abandoning previous partial state
        if start_from > 0 {
            progress.stop("retrying as streaming");
            progress = BytesProgress::start(total_len, &format!("Downloading {}", file), 0);
        }
        // Fallback to streaming GET; try URL with and without ?download=true
        let mut try_urls = vec![url.clone()];
        if let Some((base, _qs)) = url.split_once('?') {
            try_urls.push(base.to_string());
        } else {
            try_urls.push(format!("{}?download=true", url));
        }
        // Fresh temp file for streaming (truncate discards any ranged partial).
        let mut wf = OpenOptions::new().create(true).write(true).truncate(true).open(&dest_tmp)?;
        let mut ok = false;
        let mut last_err: Option<anyhow::Error> = None;
        // Counting writer to update progress inline
        struct CountingWriter<'a, 'b> {
            inner: &'a mut File,
            progress: &'b mut BytesProgress,
        }
        impl<'a, 'b> Write for CountingWriter<'a, 'b> {
            fn write(&mut self, buf: &[u8]) -> std::io::Result<usize> {
                let n = self.inner.write(buf)?;
                self.progress.inc(n as u64);
                Ok(n)
            }
            fn flush(&mut self) -> std::io::Result<()> { self.inner.flush() }
        }
        let mut cw = CountingWriter { inner: &mut wf, progress: &mut progress };
        for u in try_urls {
            // For fallback, stream from scratch to ensure integrity
            let res = self.client.get_whole_to(&u, &mut cw);
            match res {
                Ok(()) => { ok = true; break; }
                Err(e) => { last_err = Some(e.into()); }
            }
        }
        if !ok {
            if let Some(e) = last_err { return Err(e.into()); }
            return Err(anyhow!("download failed (ranged and streaming)").into());
        }
    }

    progress.stop("download complete");

    // Verify integrity: prefer the hash published by the HF API; otherwise
    // fall back to the hash recorded for the same file on a prior install.
    let sha = Self::compute_sha256(&dest_tmp)?;
    if let Some(expected) = api_sha.as_ref() {
        if &sha != expected {
            return Err(anyhow!(
                "sha256 mismatch (expected {}, got {})",
                expected,
                sha
            ).into());
        }
    } else if prev.as_ref().map(|r| r.file.eq(file)).unwrap_or(false) {
        if let Some(expected) = prev.as_ref().and_then(|r| r.sha256.as_ref()) {
            if &sha != expected {
                return Err(anyhow!("sha256 mismatch").into());
            }
        }
    }
    // Atomic rename
    fs::rename(&dest_tmp, &dest_final).with_context(|| format!("rename {} -> {}", dest_tmp.display(), dest_final.display()))?;
    // Disarm guard; .part has been moved or cleaned
    _tmp_guard.disarm();

    let rec = ModelRecord {
        alias: alias.to_string(),
        repo: repo.to_string(),
        file: file.to_string(),
        revision: etag,
        sha256: Some(sha.clone()),
        size_bytes: Some(total_len),
        quant: infer_quant(file),
        installed_at: Some(Utc::now()),
        last_used: Some(Utc::now()),
    };
    self.save_touch(&mut manifest, rec.clone())?;
    Ok(rec)
}
|
||||||
|
|
||||||
|
/// Insert (or replace) `rec` under its alias and persist the manifest.
///
/// # Errors
/// Fails when the manifest cannot be saved.
fn save_touch(&self, manifest: &mut Manifest, rec: ModelRecord) -> Result<()> {
    let key = rec.alias.clone();
    manifest.models.insert(key, rec);
    self.save_manifest(manifest)
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Infer the quantization label from a model file name.
///
/// Extracts a `Q…` token such as `Q5_K_M` or `Q4_0`, matching the first `Q`
/// that is immediately followed by an ASCII digit. The previous version
/// grabbed the first `Q` regardless of what followed, so names like
/// `QUANT-model.bin` or `Qwen-q4_0.gguf` produced bogus tokens. Returns the
/// upper-cased token, or `None` when no quant tag is present.
fn infer_quant(file: &str) -> Option<String> {
    let upper = file.to_ascii_uppercase();
    let bytes = upper.as_bytes();
    for (pos, &b) in bytes.iter().enumerate() {
        // A quant tag starts with 'Q' followed by a digit (Q4_0, Q5_K_M, ...).
        if b != b'Q' || !bytes.get(pos + 1).is_some_and(|c| c.is_ascii_digit()) {
            continue;
        }
        // Same token charset as before: uppercase letters, digits, '_', '-'.
        let token: String = upper[pos..]
            .chars()
            .take_while(|c| c.is_ascii_uppercase() || c.is_ascii_digit() || *c == '_' || *c == '-')
            .collect();
        if token.len() >= 2 {
            return Some(token);
        }
    }
    None
}
|
||||||
|
|
||||||
|
impl Default for ReqwestClient {
|
||||||
|
fn default() -> Self {
|
||||||
|
Self::new().expect("reqwest client")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Hugging Face API types for file metadata
|
||||||
|
/// Git LFS metadata for a repo file, as returned by the Hugging Face API.
#[derive(Debug, Deserialize)]
struct ApiHfLfs {
    /// LFS object id; callers expect the form "sha256:<hex>" (see
    /// `pick_sha_from_file`, which strips that prefix).
    oid: Option<String>,
    /// Object size in bytes, when reported.
    size: Option<u64>,
    /// Plain sha256 digest, when the API exposes it directly.
    sha256: Option<String>,
}
|
||||||
|
|
||||||
|
/// One file entry in a Hugging Face model-info response.
#[derive(Debug, Deserialize)]
struct ApiHfFile {
    /// Repo-relative file name (may contain directory components).
    rfilename: String,
    /// File size in bytes, when reported at the top level.
    size: Option<u64>,
    /// sha256 digest, when reported at the top level.
    sha256: Option<String>,
    /// LFS metadata; large model files typically carry size/hash here instead.
    lfs: Option<ApiHfLfs>,
}
|
||||||
|
|
||||||
|
/// Top-level Hugging Face model-info payload (only the fields we read).
#[derive(Debug, Deserialize)]
struct ApiHfModelInfo {
    /// File list from the default model-info endpoint.
    siblings: Option<Vec<ApiHfFile>>,
    /// File list when the endpoint is queried with `?expand=files`;
    /// callers prefer this over `siblings` when both are present.
    files: Option<Vec<ApiHfFile>>,
}
|
||||||
|
|
||||||
|
fn pick_sha_from_file(f: &ApiHfFile) -> Option<String> {
|
||||||
|
if let Some(s) = &f.sha256 { return Some(s.to_string()); }
|
||||||
|
if let Some(l) = &f.lfs {
|
||||||
|
if let Some(s) = &l.sha256 { return Some(s.to_string()); }
|
||||||
|
if let Some(oid) = &l.oid { return oid.strip_prefix("sha256:").map(|s| s.to_string()); }
|
||||||
|
}
|
||||||
|
None
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Query the Hugging Face API for a single file's metadata.
///
/// Returns `(size_bytes, sha256)` for `target` in `repo`; either component may
/// be `None` when the API does not expose it. Tries the base model-info URL
/// first, then the `?expand=files` variant, and authenticates with `HF_TOKEN`
/// when that env var is set and non-empty.
///
/// # Errors
/// Fails when the HTTP client cannot be built, a request or its JSON
/// deserialization fails, or the file is absent from every successful
/// response.
fn hf_fetch_file_meta(repo: &str, target: &str) -> Result<(Option<u64>, Option<String>)> {
    // Optional bearer token, e.g. for private or gated repos.
    let token = std::env::var("HF_TOKEN").ok().filter(|s| !s.is_empty());
    let client = Client::builder()
        .user_agent(crate::config::ConfigService::user_agent())
        .build()?;
    let base = crate::config::ConfigService::hf_api_base_for(repo);
    let urls = [base.clone(), format!("{}?expand=files", base)];
    for url in urls {
        let mut req = client.get(&url);
        if let Some(t) = &token { req = req.header(AUTHORIZATION, format!("Bearer {}", t)); }
        let resp = req.send()?;
        // Non-2xx (e.g. variant unsupported): try the next URL form.
        if !resp.status().is_success() { continue; }
        let info: ApiHfModelInfo = resp.json()?;
        // Expanded `files` takes precedence over the default `siblings` list.
        let list = info.files.or(info.siblings).unwrap_or_default();
        for f in list {
            // Compare by basename so nested repo paths still match `target`.
            let name = f.rfilename.rsplit('/').next().unwrap_or(&f.rfilename);
            if name.eq_ignore_ascii_case(target) {
                // Top-level size wins; fall back to the LFS-reported size.
                let sz = f.size.or_else(|| f.lfs.as_ref().and_then(|l| l.size));
                let sha = pick_sha_from_file(&f);
                return Ok((sz, sha));
            }
        }
    }
    Err(anyhow!("file not found in HF API").into())
}
|
||||||
|
|
||||||
|
/// Fetch remote metadata (size, sha256) for a single file in a HF repo.
///
/// Public wrapper around [`hf_fetch_file_meta`]; see that function for the
/// lookup strategy and error conditions.
pub fn fetch_file_meta(repo: &str, file: &str) -> Result<(Option<u64>, Option<String>)> {
    hf_fetch_file_meta(repo, file)
}
|
||||||
|
|
||||||
|
/// Search a Hugging Face repo for GGUF/BIN files via API. Returns file names only.
|
||||||
|
pub fn search_repo(repo: &str, query: Option<&str>) -> Result<Vec<String>> {
|
||||||
|
let token = std::env::var("HF_TOKEN").ok().filter(|s| !s.is_empty());
|
||||||
|
let client = Client::builder()
|
||||||
|
.user_agent(crate::config::ConfigService::user_agent())
|
||||||
|
.build()?;
|
||||||
|
|
||||||
|
let base = crate::config::ConfigService::hf_api_base_for(repo);
|
||||||
|
let mut urls = vec![base.clone(), format!("{}?expand=files", base)];
|
||||||
|
let mut files = Vec::<String>::new();
|
||||||
|
|
||||||
|
for url in urls.drain(..) {
|
||||||
|
let mut req = client.get(&url);
|
||||||
|
if let Some(t) = &token { req = req.header(AUTHORIZATION, format!("Bearer {}", t)); }
|
||||||
|
let resp = req.send()?;
|
||||||
|
if !resp.status().is_success() { continue; }
|
||||||
|
let info: ApiHfModelInfo = resp.json()?;
|
||||||
|
let list = info.files.or(info.siblings).unwrap_or_default();
|
||||||
|
for f in list {
|
||||||
|
if f.rfilename.ends_with(".gguf") || f.rfilename.ends_with(".bin") {
|
||||||
|
let name = f.rfilename.rsplit('/').next().unwrap_or(&f.rfilename).to_string();
|
||||||
|
if !files.contains(&name) { files.push(name); }
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if !files.is_empty() { break; }
|
||||||
|
}
|
||||||
|
|
||||||
|
if let Some(q) = query { let qlc = q.to_ascii_lowercase(); files.retain(|f| f.to_ascii_lowercase().contains(&qlc)); }
|
||||||
|
files.sort();
|
||||||
|
Ok(files)
|
||||||
|
}
|
||||||
|
|
||||||
|
/// List repo files with optional size metadata for GGUF/BIN entries.
///
/// Queries the HF API (base URL, then `?expand=files`) and returns
/// `(basename, size_bytes)` pairs for files ending in `.gguf` or `.bin`,
/// taken from the first non-empty listing. Returns an empty `Vec` when no
/// endpoint yields matching files.
///
/// # Errors
/// Fails when the HTTP client cannot be built or a request/deserialization
/// fails.
pub fn list_repo_files_with_meta(repo: &str) -> Result<Vec<(String, Option<u64>)>> {
    // Optional bearer token, e.g. for private or gated repos.
    let token = std::env::var("HF_TOKEN").ok().filter(|s| !s.is_empty());
    let client = Client::builder()
        .user_agent(crate::config::ConfigService::user_agent())
        .build()?;

    let base = crate::config::ConfigService::hf_api_base_for(repo);
    for url in [base.clone(), format!("{}?expand=files", base)] {
        let mut req = client.get(&url);
        if let Some(t) = &token { req = req.header(AUTHORIZATION, format!("Bearer {}", t)); }
        let resp = req.send()?;
        // Non-2xx: try the next endpoint variant.
        if !resp.status().is_success() { continue; }
        let info: ApiHfModelInfo = resp.json()?;
        // Expanded `files` takes precedence over the default `siblings` list.
        let list = info.files.or(info.siblings).unwrap_or_default();
        let mut out = Vec::new();
        for f in list {
            // Only model payloads are interesting here.
            if !(f.rfilename.ends_with(".gguf") || f.rfilename.ends_with(".bin")) { continue; }
            let name = f.rfilename.rsplit('/').next().unwrap_or(&f.rfilename).to_string();
            // Prefer top-level size; fall back to the LFS-reported size.
            let size = f.size.or_else(|| f.lfs.as_ref().and_then(|l| l.size));
            out.push((name, size));
        }
        if !out.is_empty() { return Ok(out); }
    }
    Ok(Vec::new())
}
|
||||||
|
|
||||||
|
/// Fallback: HEAD request for a single file to retrieve Content-Length (size).
|
||||||
|
pub fn head_len_for_file(repo: &str, file: &str) -> Option<u64> {
|
||||||
|
let token = std::env::var("HF_TOKEN").ok().filter(|s| !s.is_empty());
|
||||||
|
let client = Client::builder()
|
||||||
|
.user_agent(crate::config::ConfigService::user_agent())
|
||||||
|
.build().ok()?;
|
||||||
|
let mut urls = Vec::new();
|
||||||
|
urls.push(format!("https://huggingface.co/{}/resolve/main/{}?download=true", repo, file));
|
||||||
|
urls.push(format!("https://huggingface.co/{}/resolve/main/{}", repo, file));
|
||||||
|
for url in urls {
|
||||||
|
let mut req = client.head(&url);
|
||||||
|
if let Some(t) = &token { req = req.header(AUTHORIZATION, format!("Bearer {}", t)); }
|
||||||
|
if let Ok(resp) = req.send() {
|
||||||
|
if resp.status().is_success() {
|
||||||
|
if let Some(len) = resp.headers().get(CONTENT_LENGTH)
|
||||||
|
.and_then(|v| v.to_str().ok())
|
||||||
|
.and_then(|s| s.parse::<u64>().ok())
|
||||||
|
{ return Some(len); }
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
None
|
||||||
|
}
|
||||||
|
|
||||||
|
#[cfg(test)]
mod tests {
    use super::*;
    use std::env;
    use tempfile::TempDir;

    /// In-memory `HttpClient` stub backed by a fixed byte buffer; serves the
    /// same payload for every URL so ModelManager can be exercised offline.
    #[derive(Clone)]
    struct StubHttp {
        // Payload returned by every GET/range request.
        data: Arc<Vec<u8>>,
        // ETag reported by HEAD; matching it simulates a 304.
        etag: Arc<Option<String>>,
        // Whether HEAD advertises byte-range support.
        accept_ranges: bool,
    }
    impl HttpClient for StubHttp {
        fn head(&self, _url: &str, etag: Option<&str>) -> Result<HeadMeta> {
            // 304 when the caller presents the same ETag we would serve.
            let not_modified = etag.is_some() && self.etag.as_ref().as_deref() == etag;
            Ok(HeadMeta {
                len: Some(self.data.len() as u64),
                etag: self.etag.as_ref().clone(),
                last_modified: None,
                accept_ranges: self.accept_ranges,
                not_modified,
                status: if not_modified { 304 } else { 200 },
            })
        }
        fn get_range(&self, _url: &str, start: u64, end_inclusive: u64) -> Result<Vec<u8>> {
            let s = start as usize;
            // Wire range end is inclusive; slice end is exclusive.
            let e = (end_inclusive as usize) + 1;
            Ok(self.data[s..e].to_vec())
        }

        fn get_whole_to(&self, _url: &str, writer: &mut dyn Write) -> Result<()> {
            writer.write_all(&self.data)?;
            Ok(())
        }

        fn get_from_to(&self, _url: &str, start: u64, writer: &mut dyn Write) -> Result<()> {
            let s = start as usize;
            writer.write_all(&self.data[s..])?;
            Ok(())
        }
    }

    /// Point the cache/config env overrides at per-test temp directories.
    fn setup_env(cache: &Path, cfg: &Path) {
        // NOTE(review): env mutation is process-global; tests that read these
        // vars should not run concurrently with others that set them.
        unsafe {
            env::set_var("POLYSCRIBE_CACHE_DIR", cache.to_string_lossy().to_string());
            env::set_var(
                "POLYSCRIBE_CONFIG_DIR",
                cfg.parent().unwrap().to_string_lossy().to_string(),
            );
        }
    }

    /// Manifest starts empty, and save_touch -> load_manifest round-trips a
    /// record keyed by its alias.
    #[test]
    fn test_manifest_roundtrip() {
        let temp = TempDir::new().unwrap();
        let cache = temp.path().join("cache");
        let cfg = temp.path().join("config").join("models.json");
        setup_env(&cache, &cfg);

        let client = StubHttp {
            data: Arc::new(vec![0u8; 1024]),
            etag: Arc::new(Some("etag123".into())),
            accept_ranges: true,
        };
        let mm: ModelManager<StubHttp> = ModelManager::new_with_client(client, Settings::default()).unwrap();
        let m = mm.load_manifest().unwrap();
        assert!(m.models.is_empty());

        let rec = ModelRecord {
            alias: "tiny".into(),
            repo: "foo/bar".into(),
            file: "gguf-tiny.bin".into(),
            revision: Some("etag123".into()),
            sha256: None,
            size_bytes: None,
            quant: None,
            installed_at: None,
            last_used: None,
        };

        let mut m2 = Manifest::default();
        mm.save_touch(&mut m2, rec.clone()).unwrap();
        let m3 = mm.load_manifest().unwrap();
        assert!(m3.models.contains_key("tiny"));
    }

    /// End-to-end against the stub: install, verify, 304 refresh, gc, remove.
    #[test]
    fn test_add_verify_update_gc() {
        let temp = TempDir::new().unwrap();
        let cache = temp.path().join("cache");
        let cfg_dir = temp.path().join("config");
        let cfg = cfg_dir.join("models.json");
        setup_env(&cache, &cfg);

        // 4 MiB of deterministic bytes so ranged download has several chunks.
        let data = (0..1024 * 1024u32).flat_map(|i| i.to_le_bytes()).collect::<Vec<u8>>();
        let etag = Some("abc123".to_string());
        let client = StubHttp { data: Arc::new(data), etag: Arc::new(etag), accept_ranges: true };
        let mm: ModelManager<StubHttp> = ModelManager::new_with_client(client.clone(), Settings{ concurrency: 3, ..Default::default() }).unwrap();

        // add
        let rec = mm.add_or_update("tiny", "gguf/models", "gguf-tiny-q4_0.bin").unwrap();
        assert_eq!(rec.alias, "tiny");
        assert!(mm.verify("tiny").unwrap());

        // update (304)
        let rec2 = mm.add_or_update("tiny", "gguf/models", "gguf-tiny-q4_0.bin").unwrap();
        assert_eq!(rec2.alias, "tiny");

        // gc (nothing to remove)
        let (files_removed, entries_removed) = mm.gc().unwrap();
        assert_eq!(files_removed, 0);
        assert_eq!(entries_removed, 0);

        // rm: first call removes, second reports nothing left to remove.
        assert!(mm.rm("tiny").unwrap());
        assert!(!mm.rm("tiny").unwrap());
    }
}
|
@@ -792,6 +792,13 @@ fn download_with_progress(dest_path: &Path, entry: &ModelEntry) -> Result<()> {
|
|||||||
}
|
}
|
||||||
|
|
||||||
let part_path = dest_path.with_extension("part");
|
let part_path = dest_path.with_extension("part");
|
||||||
|
// Guard to cleanup .part on errors
|
||||||
|
struct TempGuard { path: std::path::PathBuf, armed: bool }
|
||||||
|
impl TempGuard { fn disarm(&mut self) { self.armed = false; } }
|
||||||
|
impl Drop for TempGuard {
|
||||||
|
fn drop(&mut self) { if self.armed { let _ = fs::remove_file(&self.path); } }
|
||||||
|
}
|
||||||
|
let mut _tmp_guard = TempGuard { path: part_path.clone(), armed: true };
|
||||||
|
|
||||||
let mut resume_from: u64 = 0;
|
let mut resume_from: u64 = 0;
|
||||||
if part_path.exists() && ranges_ok {
|
if part_path.exists() && ranges_ok {
|
||||||
@@ -893,6 +900,7 @@ fn download_with_progress(dest_path: &Path, entry: &ModelEntry) -> Result<()> {
|
|||||||
drop(part_file);
|
drop(part_file);
|
||||||
fs::rename(&part_path, dest_path)
|
fs::rename(&part_path, dest_path)
|
||||||
.with_context(|| format!("renaming {} → {}", part_path.display(), dest_path.display()))?;
|
.with_context(|| format!("renaming {} → {}", part_path.display(), dest_path.display()))?;
|
||||||
|
_tmp_guard.disarm();
|
||||||
|
|
||||||
let final_size = fs::metadata(dest_path).map(|m| m.len()).ok();
|
let final_size = fs::metadata(dest_path).map(|m| m.len()).ok();
|
||||||
let elapsed = start.elapsed().as_secs_f64();
|
let elapsed = start.elapsed().as_secs_f64();
|
||||||
|
@@ -4,6 +4,8 @@ pub mod progress;
|
|||||||
|
|
||||||
use std::io;
|
use std::io;
|
||||||
use std::io::IsTerminal;
|
use std::io::IsTerminal;
|
||||||
|
use std::io::Write as _;
|
||||||
|
use std::time::{Duration, Instant};
|
||||||
|
|
||||||
pub fn info(msg: impl AsRef<str>) {
|
pub fn info(msg: impl AsRef<str>) {
|
||||||
let m = msg.as_ref();
|
let m = msg.as_ref();
|
||||||
@@ -170,43 +172,156 @@ impl Spinner {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
pub struct BytesProgress(Option<cliclack::ProgressBar>);
|
pub struct BytesProgress {
|
||||||
|
enabled: bool,
|
||||||
|
total: u64,
|
||||||
|
current: u64,
|
||||||
|
started: Instant,
|
||||||
|
last_msg: Instant,
|
||||||
|
width: usize,
|
||||||
|
// Sticky ETA to carry through zero-speed stalls
|
||||||
|
last_eta_secs: Option<f64>,
|
||||||
|
}
|
||||||
|
|
||||||
impl BytesProgress {
|
impl BytesProgress {
|
||||||
pub fn start(total: u64, text: &str, initial: u64) -> Self {
|
pub fn start(total: u64, text: &str, initial: u64) -> Self {
|
||||||
if crate::is_no_progress()
|
let enabled = !(crate::is_no_progress()
|
||||||
|| crate::is_no_interaction()
|
|| crate::is_no_interaction()
|
||||||
|| !std::io::stderr().is_terminal()
|
|| !std::io::stderr().is_terminal()
|
||||||
|| total == 0
|
|| total == 0);
|
||||||
{
|
if !enabled {
|
||||||
let _ = cliclack::log::info(text);
|
let _ = cliclack::log::info(text);
|
||||||
return Self(None);
|
|
||||||
}
|
}
|
||||||
let b = cliclack::progress_bar(total);
|
let mut me = Self {
|
||||||
b.start(text);
|
enabled,
|
||||||
if initial > 0 {
|
total,
|
||||||
b.inc(initial);
|
current: initial.min(total),
|
||||||
|
started: Instant::now(),
|
||||||
|
last_msg: Instant::now(),
|
||||||
|
width: 40,
|
||||||
|
last_eta_secs: None,
|
||||||
|
};
|
||||||
|
me.draw();
|
||||||
|
me
|
||||||
|
}
|
||||||
|
|
||||||
|
fn human_bytes(n: u64) -> String {
|
||||||
|
const KB: f64 = 1024.0;
|
||||||
|
const MB: f64 = 1024.0 * KB;
|
||||||
|
const GB: f64 = 1024.0 * MB;
|
||||||
|
let x = n as f64;
|
||||||
|
if x >= GB {
|
||||||
|
format!("{:.2} GiB", x / GB)
|
||||||
|
} else if x >= MB {
|
||||||
|
format!("{:.2} MiB", x / MB)
|
||||||
|
} else if x >= KB {
|
||||||
|
format!("{:.2} KiB", x / KB)
|
||||||
|
} else {
|
||||||
|
format!("{} B", n)
|
||||||
}
|
}
|
||||||
Self(Some(b))
|
}
|
||||||
|
|
||||||
|
// Elapsed formatting is used for stable, finite durations. For ETA, we guard
|
||||||
|
// against zero-speed or unstable estimates separately via `format_eta`.
|
||||||
|
|
||||||
|
fn refresh_allowed(&mut self) -> (f64, f64) {
|
||||||
|
let now = Instant::now();
|
||||||
|
let since_last = now.duration_since(self.last_msg);
|
||||||
|
if since_last < Duration::from_millis(100) {
|
||||||
|
// Too soon to refresh; keep previous ETA if any
|
||||||
|
let eta = self.last_eta_secs.unwrap_or(f64::INFINITY);
|
||||||
|
return (0.0, eta);
|
||||||
|
}
|
||||||
|
self.last_msg = now;
|
||||||
|
let elapsed = now.duration_since(self.started).as_secs_f64().max(0.001);
|
||||||
|
let speed = (self.current as f64) / elapsed;
|
||||||
|
let remaining = self.total.saturating_sub(self.current) as f64;
|
||||||
|
|
||||||
|
// If speed is effectively zero, carry ETA forward and add wall time.
|
||||||
|
const EPS: f64 = 1e-6;
|
||||||
|
let eta = if speed <= EPS {
|
||||||
|
let prev = self.last_eta_secs.unwrap_or(f64::INFINITY);
|
||||||
|
if prev.is_finite() {
|
||||||
|
prev + since_last.as_secs_f64()
|
||||||
|
} else {
|
||||||
|
prev
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
remaining / speed
|
||||||
|
};
|
||||||
|
// Remember only finite ETAs to use during stalls
|
||||||
|
if eta.is_finite() {
|
||||||
|
self.last_eta_secs = Some(eta);
|
||||||
|
}
|
||||||
|
(speed, eta)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn format_elapsed(seconds: f64) -> String {
|
||||||
|
let total = seconds.round() as u64;
|
||||||
|
let h = total / 3600;
|
||||||
|
let m = (total % 3600) / 60;
|
||||||
|
let s = total % 60;
|
||||||
|
if h > 0 { format!("{:02}:{:02}:{:02}", h, m, s) } else { format!("{:02}:{:02}", m, s) }
|
||||||
|
}
|
||||||
|
|
||||||
|
fn format_eta(seconds: f64) -> String {
|
||||||
|
// If ETA is not finite (e.g., divide-by-zero speed) or unreasonably large,
|
||||||
|
// show a placeholder rather than overflowing into huge values.
|
||||||
|
if !seconds.is_finite() {
|
||||||
|
return "—".to_string();
|
||||||
|
}
|
||||||
|
// Cap ETA display to 99:59:59 to avoid silly numbers; beyond that, show placeholder.
|
||||||
|
const CAP_SECS: f64 = 99.0 * 3600.0 + 59.0 * 60.0 + 59.0;
|
||||||
|
if seconds > CAP_SECS {
|
||||||
|
return "—".to_string();
|
||||||
|
}
|
||||||
|
Self::format_elapsed(seconds)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn draw(&mut self) {
|
||||||
|
if !self.enabled { return; }
|
||||||
|
let (speed, eta) = self.refresh_allowed();
|
||||||
|
let elapsed = Instant::now().duration_since(self.started).as_secs_f64();
|
||||||
|
// Build bar
|
||||||
|
let width = self.width.max(10);
|
||||||
|
let filled = ((self.current as f64 / self.total.max(1) as f64) * width as f64).round() as usize;
|
||||||
|
let filled = filled.min(width);
|
||||||
|
let mut bar = String::with_capacity(width);
|
||||||
|
for _ in 0..filled { bar.push('■'); }
|
||||||
|
for _ in filled..width { bar.push('□'); }
|
||||||
|
|
||||||
|
let line = format!(
|
||||||
|
"[{}] {} [{}] ({}/{} at {}/s)",
|
||||||
|
Self::format_elapsed(elapsed),
|
||||||
|
bar,
|
||||||
|
Self::format_eta(eta),
|
||||||
|
Self::human_bytes(self.current),
|
||||||
|
Self::human_bytes(self.total),
|
||||||
|
Self::human_bytes(speed.max(0.0) as u64),
|
||||||
|
);
|
||||||
|
eprint!("\r{}\x1b[K", line);
|
||||||
|
let _ = io::stderr().flush();
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn inc(&mut self, delta: u64) {
|
pub fn inc(&mut self, delta: u64) {
|
||||||
if let Some(b) = self.0.as_mut() {
|
self.current = self.current.saturating_add(delta).min(self.total);
|
||||||
b.inc(delta);
|
self.draw();
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn stop(mut self, text: &str) {
|
pub fn stop(mut self, text: &str) {
|
||||||
if let Some(b) = self.0.take() {
|
if self.enabled {
|
||||||
b.stop(text);
|
self.draw();
|
||||||
|
eprintln!();
|
||||||
} else {
|
} else {
|
||||||
let _ = cliclack::log::info(text);
|
let _ = cliclack::log::info(text);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn error(mut self, text: &str) {
|
pub fn error(mut self, text: &str) {
|
||||||
if let Some(b) = self.0.take() {
|
if self.enabled {
|
||||||
b.error(text);
|
self.draw();
|
||||||
|
eprintln!();
|
||||||
|
let _ = cliclack::log::error(text);
|
||||||
} else {
|
} else {
|
||||||
let _ = cliclack::log::error(text);
|
let _ = cliclack::log::error(text);
|
||||||
}
|
}
|
||||||
|
Reference in New Issue
Block a user