Compare commits
4 Commits: 5ec297397e ... 840383fcf7

| SHA1 |
| --- |
| 840383fcf7 |
| 1982e9b48b |
| 0128bf2eec |
| da5a76d253 |
1 Cargo.lock generated

@@ -251,6 +251,7 @@ dependencies = [
"iana-time-zone",
"js-sys",
"num-traits",
"serde",
"wasm-bindgen",
"windows-link",
]
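Aside: the lockfile's chrono block gains "serde" because the workspace now enables chrono's serde feature (see Cargo.toml below), which the new manifest code needs to serialize DateTime<Utc> fields in ModelRecord. A minimal sketch of why the feature matters (illustrative, not from the diff):

use chrono::{DateTime, Utc};
use serde::Serialize;

#[derive(Serialize)]
struct Stamp {
    installed_at: Option<DateTime<Utc>>,
}

fn main() {
    // With chrono's "serde" feature, DateTime<Utc> serializes as an RFC 3339 string.
    let s = Stamp { installed_at: Some(Utc::now()) };
    println!("{}", serde_json::to_string(&s).unwrap());
}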
16 Cargo.toml

@@ -7,6 +7,12 @@ members = [
]
resolver = "3"

[workspace.package]
edition = "2024"
version = "0.1.0"
license = "MIT"
rust-version = "1.89"

# Optional: Keep dependency versions consistent across members
[workspace.dependencies]
thiserror = "1.0.69"
@@ -15,7 +21,7 @@ anyhow = "1.0.99"
libc = "0.2.175"
toml = "0.8.23"
serde_json = "1.0.142"
chrono = "0.4.41"
chrono = { version = "0.4.41", features = ["serde"] }
sha2 = "0.10.9"
which = "6.0.3"
tokio = { version = "1.47.1", features = ["rt-multi-thread", "macros"] }
@@ -26,6 +32,14 @@ cliclack = "0.3.6"
clap_complete = "4.5.57"
clap_mangen = "0.2.29"

# Additional shared deps used across members
tracing = "0.1"
tracing-subscriber = { version = "0.3", features = ["fmt", "env-filter"] }
reqwest = { version = "0.12.7", default-features = false, features = ["blocking", "rustls-tls", "gzip", "json"] }
hex = "0.4.3"
tempfile = "3.12.0"
assert_cmd = "2.0.16"

[workspace.lints.rust]
unused_imports = "deny"
dead_code = "warn"
16 build.rs

@@ -1,16 +0,0 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2025 <COPYRIGHT HOLDER>. All rights reserved.

fn main() {
    // Only run special build steps when gpu-vulkan feature is enabled.
    let vulkan_enabled = std::env::var("CARGO_FEATURE_GPU_VULKAN").is_ok();
    if !vulkan_enabled {
        return;
    }
    // Placeholder: In a full implementation, we would invoke CMake for whisper.cpp with GGML_VULKAN=1.
    // For now, emit a helpful note. Build will proceed; runtime Vulkan backend returns an explanatory error.
    println!("cargo:rerun-if-changed=extern/whisper.cpp");
    println!(
        "cargo:warning=Building with gpu-vulkan: ensure Vulkan SDK/loader are installed. Future versions will compile whisper.cpp via CMake."
    );
}
crates/polyscribe-cli/Cargo.toml

@@ -1,24 +1,24 @@
[package]
name = "polyscribe-cli"
version = "0.1.0"
edition = "2024"
version.workspace = true
edition.workspace = true

[[bin]]
name = "polyscribe"
path = "src/main.rs"

[dependencies]
anyhow = "1.0.99"
clap = { version = "4.5.44", features = ["derive"] }
clap_complete = "4.5.57"
clap_mangen = "0.2.29"
directories = "5.0.1"
serde = { version = "1.0.219", features = ["derive"] }
serde_json = "1.0.142"
tokio = { version = "1.47.1", features = ["rt-multi-thread", "macros", "process", "fs"] }
tracing = "0.1"
tracing-subscriber = { version = "0.3", features = ["fmt", "env-filter"] }
which = "6.0.3"
anyhow = { workspace = true }
clap = { workspace = true, features = ["derive"] }
clap_complete = { workspace = true }
clap_mangen = { workspace = true }
directories = { workspace = true }
serde = { workspace = true, features = ["derive"] }
serde_json = { workspace = true }
tokio = { workspace = true, features = ["rt-multi-thread", "macros", "process", "fs"] }
tracing = { workspace = true }
tracing-subscriber = { workspace = true, features = ["fmt", "env-filter"] }
which = { workspace = true }

polyscribe-core = { path = "../polyscribe-core" }
polyscribe-host = { path = "../polyscribe-host" }
@@ -29,4 +29,4 @@ polyscribe-protocol = { path = "../polyscribe-protocol" }
default = []

[dev-dependencies]
assert_cmd = "2.0.16"
assert_cmd = { workspace = true }
crates/polyscribe-cli/src/cli.rs

@@ -1,4 +1,4 @@
use clap::{Parser, Subcommand, ValueEnum};
use clap::{Args, Parser, Subcommand, ValueEnum};
use std::path::PathBuf;

#[derive(Debug, Clone, ValueEnum)]
@@ -10,21 +10,33 @@ pub enum GpuBackend {
    Vulkan,
}

#[derive(Debug, Clone, Args)]
pub struct OutputOpts {
    /// Emit machine-readable JSON to stdout; suppress decorative logs
    #[arg(long, global = true, action = clap::ArgAction::SetTrue)]
    pub json: bool,
    /// Reduce log chatter (errors only unless --json)
    #[arg(long, global = true, action = clap::ArgAction::SetTrue)]
    pub quiet: bool,
}

#[derive(Debug, Parser)]
#[command(
    name = "polyscribe",
    version,
    about = "PolyScribe – local-first transcription and plugins"
    about = "PolyScribe – local-first transcription and plugins",
    propagate_version = true,
    arg_required_else_help = true,
)]
pub struct Cli {
    /// Global output options
    #[command(flatten)]
    pub output: OutputOpts,

    /// Increase verbosity (-v, -vv)
    #[arg(short, long, action = clap::ArgAction::Count)]
    pub verbose: u8,

    /// Quiet mode (suppresses non-error logs)
    #[arg(short, long, default_value_t = false)]
    pub quiet: bool,

    /// Never prompt for user input (non-interactive mode)
    #[arg(long, default_value_t = false)]
    pub no_interaction: bool,
@@ -74,7 +86,7 @@ pub enum Commands {
        inputs: Vec<PathBuf>,
    },

    /// Manage Whisper models
    /// Manage Whisper GGUF models (Hugging Face)
    Models {
        #[command(subcommand)]
        cmd: ModelsCmd,
@@ -97,14 +109,64 @@ pub enum Commands {
    Man,
}

#[derive(Debug, Clone, Parser)]
pub struct ModelCommon {
    /// Concurrency for ranged downloads
    #[arg(long, default_value_t = 4)]
    pub concurrency: usize,
    /// Limit download rate in bytes/sec (approximate)
    #[arg(long)]
    pub limit_rate: Option<u64>,
}

#[derive(Debug, Subcommand)]
pub enum ModelsCmd {
    /// Verify or update local models non-interactively
    Update,
    /// Interactive multi-select downloader
    Download,
    /// Clear the cached Hugging Face manifest
    ClearCache,
    /// List installed models (from manifest)
    Ls {
        #[command(flatten)]
        common: ModelCommon,
    },
    /// Add or update a model
    Add {
        /// Hugging Face repo, e.g. ggml-org/models
        repo: String,
        /// File name in repo (e.g., gguf-tiny-q4_0.bin)
        file: String,
        #[command(flatten)]
        common: ModelCommon,
    },
    /// Remove a model by alias
    Rm {
        alias: String,
        #[command(flatten)]
        common: ModelCommon,
    },
    /// Verify model file integrity by alias
    Verify {
        alias: String,
        #[command(flatten)]
        common: ModelCommon,
    },
    /// Update all models (HEAD + ETag; skip if unchanged)
    Update {
        #[command(flatten)]
        common: ModelCommon,
    },
    /// Garbage-collect unreferenced files and stale manifest entries
    Gc {
        #[command(flatten)]
        common: ModelCommon,
    },
    /// Search a repo for GGUF files
    Search {
        /// Hugging Face repo, e.g. ggml-org/models
        repo: String,
        /// Optional substring to filter filenames
        #[arg(long)]
        query: Option<String>,
        #[command(flatten)]
        common: ModelCommon,
    },
}

#[derive(Debug, Subcommand)]
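Review note: because OutputOpts is flattened into Cli with global = true, --json and --quiet are accepted after any subcommand, which the smoke tests below rely on. A minimal sketch of that behavior, assuming clap 4 with the derive feature (struct and enum names here are illustrative, not from the diff):

use clap::{Args, Parser, Subcommand};

#[derive(Debug, Args)]
struct OutputOpts {
    // global = true makes the flag valid anywhere on the command line
    #[arg(long, global = true)]
    json: bool,
    #[arg(long, global = true)]
    quiet: bool,
}

#[derive(Debug, Parser)]
struct Cli {
    #[command(flatten)]
    output: OutputOpts,
    #[command(subcommand)]
    cmd: Cmd,
}

#[derive(Debug, Subcommand)]
enum Cmd {
    Ls,
}

fn main() {
    // Global flags may trail the subcommand: `prog ls --json --quiet`.
    let cli = Cli::parse_from(["prog", "ls", "--json", "--quiet"]);
    assert!(cli.output.json && cli.output.quiet);
    println!("{:?}", cli.cmd);
}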
crates/polyscribe-cli/src/main.rs

@@ -1,41 +1,84 @@
mod cli;
mod output;

use anyhow::{Context, Result, anyhow};
use clap::{CommandFactory, Parser};
use cli::{Cli, Commands, GpuBackend, ModelsCmd, PluginsCmd};
use polyscribe_core::models;
use cli::{Cli, Commands, GpuBackend, ModelsCmd, ModelCommon, PluginsCmd};
use output::OutputMode;
use polyscribe_core::model_manager::{ModelManager, Settings, ReqwestClient};
use polyscribe_core::ui;

fn normalized_similarity(a: &str, b: &str) -> f64 {
    // simple Levenshtein distance; normalized to [0,1]
    let a_bytes = a.as_bytes();
    let b_bytes = b.as_bytes();
    let n = a_bytes.len();
    let m = b_bytes.len();
    if n == 0 && m == 0 { return 1.0; }
    if n == 0 || m == 0 { return 0.0; }
    let mut prev: Vec<usize> = (0..=m).collect();
    let mut curr: Vec<usize> = vec![0; m + 1];
    for i in 1..=n {
        curr[0] = i;
        for j in 1..=m {
            let cost = if a_bytes[i - 1] == b_bytes[j - 1] { 0 } else { 1 };
            curr[j] = (prev[j] + 1)
                .min(curr[j - 1] + 1)
                .min(prev[j - 1] + cost);
        }
        std::mem::swap(&mut prev, &mut curr);
    }
    let dist = prev[m] as f64;
    let max_len = n.max(m) as f64;
    1.0 - (dist / max_len)
}

fn human_size(bytes: Option<u64>) -> String {
    match bytes {
        Some(n) => {
            let x = n as f64;
            const KB: f64 = 1024.0;
            const MB: f64 = 1024.0 * KB;
            const GB: f64 = 1024.0 * MB;
            if x >= GB { format!("{:.2} GiB", x / GB) }
            else if x >= MB { format!("{:.2} MiB", x / MB) }
            else if x >= KB { format!("{:.2} KiB", x / KB) }
            else { format!("{} B", n) }
        }
        None => "?".to_string(),
    }
}
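Hedged spot-checks for the two helpers above (hypothetical test in the same file; both functions are private to main.rs):

fn main() {
    // Levenshtein-based score: identical strings give 1.0; one substitution
    // over the 18 characters of "gguf-tiny-q4_0.bin" gives 1 - 1/18.
    assert!((normalized_similarity("abc", "abc") - 1.0).abs() < 1e-9);
    let s = normalized_similarity("gguf-tiny-q4_0.bin", "gguf-tiny-q4_1.bin");
    assert!((s - (1.0 - 1.0 / 18.0)).abs() < 1e-9);

    // Binary (1024-based) size formatting.
    assert_eq!(human_size(Some(1536)), "1.50 KiB");
    assert_eq!(human_size(None), "?");
}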
use polyscribe_core::ui::progress::ProgressReporter;
use polyscribe_host::PluginManager;
use tokio::io::AsyncWriteExt;
use tracing_subscriber::EnvFilter;

fn init_tracing(quiet: bool, verbose: u8) {
    let log_level = if quiet {
        "error"
    } else {
        match verbose {
            0 => "info",
            1 => "debug",
            _ => "trace",
        }
    };

    let filter = EnvFilter::try_from_default_env().unwrap_or_else(|_| EnvFilter::new(log_level));
fn init_tracing(json_mode: bool, quiet: bool, verbose: u8) {
    // In JSON mode, suppress human logs; route errors to stderr only.
    let level = if json_mode || quiet { "error" } else { match verbose { 0 => "info", 1 => "debug", _ => "trace" } };
    let filter = EnvFilter::try_from_default_env().unwrap_or_else(|_| EnvFilter::new(level));
    tracing_subscriber::fmt()
        .with_env_filter(filter)
        .with_target(false)
        .with_level(true)
        .with_writer(std::io::stderr)
        .compact()
        .init();
}

#[tokio::main]
async fn main() -> Result<()> {
fn main() -> Result<()> {
    let args = Cli::parse();

    init_tracing(args.quiet, args.verbose);
    // Determine output mode early for logging and UI configuration
    let output_mode = if args.output.json {
        OutputMode::Json
    } else {
        OutputMode::Human { quiet: args.output.quiet }
    };

    polyscribe_core::set_quiet(args.quiet);
    init_tracing(matches!(output_mode, OutputMode::Json), args.output.quiet, args.verbose);

    // Suppress decorative UI output in JSON mode as well
    polyscribe_core::set_quiet(args.output.quiet || matches!(output_mode, OutputMode::Json));
    polyscribe_core::set_no_interaction(args.no_interaction);
    polyscribe_core::set_verbose(args.verbose);
    polyscribe_core::set_no_progress(args.no_progress);
@@ -71,32 +114,274 @@ async fn main() -> Result<()> {
        }

        Commands::Models { cmd } => {
            match cmd {
                ModelsCmd::Update => {
                    polyscribe_core::ui::info("verifying/updating local models");
                    tokio::task::spawn_blocking(models::update_local_models)
                        .await
                        .map_err(|e| anyhow!("blocking task join error: {e}"))?
                        .context("updating models")?;
            // predictable exit codes
            const EXIT_OK: i32 = 0;
            const EXIT_NOT_FOUND: i32 = 2;
            const EXIT_NETWORK: i32 = 3;
            const EXIT_VERIFY_FAILED: i32 = 4;
            // const EXIT_NO_CHANGE: i32 = 5; // reserved

            let handle_common = |c: &ModelCommon| Settings {
                concurrency: c.concurrency.max(1),
                limit_rate: c.limit_rate,
                ..Default::default()
            };

            let exit = match cmd {
                ModelsCmd::Ls { common } => {
                    let mm: ModelManager<ReqwestClient> = ModelManager::new(handle_common(&common))?;
                    let list = mm.ls()?;
                    match output_mode {
                        OutputMode::Json => {
                            // Always emit JSON array (possibly empty)
                            output_mode.print_json(&list);
                        }
                        OutputMode::Human { quiet } => {
                            if list.is_empty() {
                                if !quiet { println!("No models installed."); }
                            } else {
                                if !quiet { println!("Model (Repo)"); }
                                for r in list {
                                    if !quiet { println!("{} ({})", r.file, r.repo); }
                                }
                            }
                        }
                    }
                    EXIT_OK
                }
                ModelsCmd::Download => {
                    polyscribe_core::ui::info("interactive model selection and download");
                    tokio::task::spawn_blocking(models::run_interactive_model_downloader)
                        .await
                        .map_err(|e| anyhow!("blocking task join error: {e}"))?
                        .context("running downloader")?;
                    polyscribe_core::ui::success("Model download complete.");
                ModelsCmd::Add { repo, file, common } => {
                    let settings = handle_common(&common);
                    let mm: ModelManager<ReqwestClient> = ModelManager::new(settings.clone())?;
                    // Derive an alias automatically from repo and file
                    fn derive_alias(repo: &str, file: &str) -> String {
                        use std::path::Path;
                        let repo_tail = repo.rsplit('/').next().unwrap_or(repo);
                        let stem = Path::new(file)
                            .file_stem()
                            .and_then(|s| s.to_str())
                            .unwrap_or(file);
                        format!("{}-{}", repo_tail, stem)
                    }
                    let alias = derive_alias(&repo, &file);
                    match mm.add_or_update(&alias, &repo, &file) {
                        Ok(rec) => {
                            match output_mode {
                                OutputMode::Json => output_mode.print_json(&rec),
                                OutputMode::Human { quiet } => {
                                    if !quiet { println!("installed: {} -> {}/{}", alias, repo, rec.file); }
                                }
                            }
                            EXIT_OK
                        }
                        Err(e) => {
                            // On not found or similar errors, try suggesting close matches interactively
                            if matches!(output_mode, OutputMode::Json) || polyscribe_core::is_no_interaction() {
                                match output_mode {
                                    OutputMode::Json => {
                                        // Emit error JSON object
                                        #[derive(serde::Serialize)]
                                        struct ErrObj<'a> { error: &'a str }
                                        let eo = ErrObj { error: &e.to_string() };
                                        output_mode.print_json(&eo);
                                    }
                                    _ => { eprintln!("error: {e}"); }
                                }
                                EXIT_NOT_FOUND
                            } else {
                                ui::warn(format!("{}", e));
                                ui::info("Searching for similar model filenames…");
                                match polyscribe_core::model_manager::search_repo(&repo, None) {
                                    Ok(mut files) => {
                                        if files.is_empty() {
                                            ui::warn("No files found in repository.");
                                            EXIT_NOT_FOUND
                                        } else {
                                            // rank by similarity
                                            files.sort_by(|a, b| normalized_similarity(&file, b)
                                                .partial_cmp(&normalized_similarity(&file, a))
                                                .unwrap_or(std::cmp::Ordering::Equal));
                                            let top: Vec<String> = files.into_iter().take(5).collect();
                                            if top.is_empty() {
                                                EXIT_NOT_FOUND
                                            } else if top.len() == 1 {
                                                let cand = &top[0];
                                                // Fetch repo size list once
                                                let size_map: std::collections::HashMap<String, Option<u64>> =
                                                    polyscribe_core::model_manager::list_repo_files_with_meta(&repo)
                                                        .unwrap_or_default()
                                                        .into_iter().collect();
                                                let mut size = size_map.get(cand).cloned().unwrap_or(None);
                                                if size.is_none() {
                                                    size = polyscribe_core::model_manager::head_len_for_file(&repo, cand);
                                                }
                                                let local_files: std::collections::HashSet<String> = mm.ls()?.into_iter().map(|r| r.file).collect();
                                                let is_local = local_files.contains(cand);
                                                let label = format!("{} [{}]{}", cand, human_size(size), if is_local { " (local)" } else { "" });
                                                let ok = ui::prompt_confirm(&format!("Did you mean {}?", label), true)
                                                    .unwrap_or(false);
                                                if !ok { EXIT_NOT_FOUND } else {
                                                    let mm2: ModelManager<ReqwestClient> = ModelManager::new(settings)?;
                                                    let alias2 = derive_alias(&repo, cand);
                                                    match mm2.add_or_update(&alias2, &repo, cand) {
                                                        Ok(rec) => {
                                                            match output_mode {
                                                                OutputMode::Json => output_mode.print_json(&rec),
                                                                OutputMode::Human { quiet } => { if !quiet { println!("installed: {} -> {}/{}", alias2, repo, rec.file); } }
                                                            }
                                                            EXIT_OK
                                                        }
                                                        Err(e2) => { eprintln!("error: {e2}"); EXIT_NETWORK }
                                                    }
                                                }
                                            } else {
                                                let opts: Vec<String> = top;
                                                let local_files: std::collections::HashSet<String> = mm.ls()?.into_iter().map(|r| r.file).collect();
                                                // Enrich labels with size and local tag using a single API call
                                                let size_map: std::collections::HashMap<String, Option<u64>> =
                                                    polyscribe_core::model_manager::list_repo_files_with_meta(&repo)
                                                        .unwrap_or_default()
                                                        .into_iter().collect();
                                                let mut labels_owned: Vec<String> = Vec::new();
                                                for f in &opts {
                                                    let mut size = size_map.get(f).cloned().unwrap_or(None);
                                                    if size.is_none() {
                                                        size = polyscribe_core::model_manager::head_len_for_file(&repo, f);
                                                    }
                                                    let is_local = local_files.contains(f);
                                                    let suffix = if is_local { " (local)" } else { "" };
                                                    labels_owned.push(format!("{} [{}]{}", f, human_size(size), suffix));
                                                }
                                                let labels: Vec<&str> = labels_owned.iter().map(|s| s.as_str()).collect();
                                                match ui::prompt_select("Pick a model", &labels) {
                                                    Ok(idx) => {
                                                        let chosen = &opts[idx];
                                                        let mm2: ModelManager<ReqwestClient> = ModelManager::new(settings)?;
                                                        let alias2 = derive_alias(&repo, chosen);
                                                        match mm2.add_or_update(&alias2, &repo, chosen) {
                                                            Ok(rec) => {
                                                                match output_mode {
                                                                    OutputMode::Json => output_mode.print_json(&rec),
                                                                    OutputMode::Human { quiet } => { if !quiet { println!("installed: {} -> {}/{}", alias2, repo, rec.file); } }
                                                                }
                                                                EXIT_OK
                                                            }
                                                            Err(e2) => { eprintln!("error: {e2}"); EXIT_NETWORK }
                                                        }
                                                    }
                                                    Err(_) => EXIT_NOT_FOUND,
                                                }
                                            }
                                        }
                                    }
                                    Err(e2) => {
                                        eprintln!("error: {}", e2);
                                        EXIT_NETWORK
                                    }
                                }
                            }
                        }
                    }
                }
                ModelsCmd::ClearCache => {
                    polyscribe_core::ui::info("clearing manifest cache");
                    tokio::task::spawn_blocking(models::clear_manifest_cache)
                        .await
                        .map_err(|e| anyhow!("blocking task join error: {e}"))?
                        .context("clearing cache")?;
                    polyscribe_core::ui::success("Manifest cache cleared.");
                ModelsCmd::Rm { alias, common } => {
                    let mm: ModelManager<ReqwestClient> = ModelManager::new(handle_common(&common))?;
                    let ok = mm.rm(&alias)?;
                    match output_mode {
                        OutputMode::Json => {
                            #[derive(serde::Serialize)]
                            struct R { removed: bool }
                            output_mode.print_json(&R { removed: ok });
                        }
                        OutputMode::Human { quiet } => {
                            if !quiet { println!("{}", if ok { "removed" } else { "not found" }); }
                        }
                    }
                    if ok { EXIT_OK } else { EXIT_NOT_FOUND }
                }
            }
            Ok(())
                ModelsCmd::Verify { alias, common } => {
                    let mm: ModelManager<ReqwestClient> = ModelManager::new(handle_common(&common))?;
                    let found = mm.ls()?.into_iter().any(|r| r.alias == alias);
                    if !found {
                        match output_mode {
                            OutputMode::Json => {
                                #[derive(serde::Serialize)]
                                struct R<'a> { ok: bool, error: &'a str }
                                output_mode.print_json(&R { ok: false, error: "not found" });
                            }
                            OutputMode::Human { quiet } => { if !quiet { println!("not found"); } }
                        }
                        EXIT_NOT_FOUND
                    } else {
                        let ok = mm.verify(&alias)?;
                        match output_mode {
                            OutputMode::Json => {
                                #[derive(serde::Serialize)]
                                struct R { ok: bool }
                                output_mode.print_json(&R { ok });
                            }
                            OutputMode::Human { quiet } => { if !quiet { println!("{}", if ok { "ok" } else { "corrupt" }); } }
                        }
                        if ok { EXIT_OK } else { EXIT_VERIFY_FAILED }
                    }
                }
                ModelsCmd::Update { common } => {
                    let mm: ModelManager<ReqwestClient> = ModelManager::new(handle_common(&common))?;
                    let mut rc = EXIT_OK;
                    for rec in mm.ls()? {
                        match mm.add_or_update(&rec.alias, &rec.repo, &rec.file) {
                            Ok(_) => {}
                            Err(e) => {
                                rc = EXIT_NETWORK;
                                match output_mode {
                                    OutputMode::Json => {
                                        #[derive(serde::Serialize)]
                                        struct R<'a> { alias: &'a str, error: String }
                                        output_mode.print_json(&R { alias: &rec.alias, error: e.to_string() });
                                    }
                                    _ => { eprintln!("update {}: {e}", rec.alias); }
                                }
                            }
                        }
                    }
                    rc
                }
                ModelsCmd::Gc { common } => {
                    let mm: ModelManager<ReqwestClient> = ModelManager::new(handle_common(&common))?;
                    let (files_removed, entries_removed) = mm.gc()?;
                    match output_mode {
                        OutputMode::Json => {
                            #[derive(serde::Serialize)]
                            struct R { files_removed: usize, entries_removed: usize }
                            output_mode.print_json(&R { files_removed, entries_removed });
                        }
                        OutputMode::Human { quiet } => { if !quiet { println!("files_removed={} entries_removed={}", files_removed, entries_removed); } }
                    }
                    EXIT_OK
                }
                ModelsCmd::Search { repo, query, common } => {
                    let res = polyscribe_core::model_manager::search_repo(&repo, query.as_deref());
                    match res {
                        Ok(files) => {
                            match output_mode {
                                OutputMode::Json => output_mode.print_json(&files),
                                OutputMode::Human { quiet } => { for f in files { if !quiet { println!("{}", f); } } }
                            }
                            EXIT_OK
                        }
                        Err(e) => {
                            match output_mode {
                                OutputMode::Json => {
                                    #[derive(serde::Serialize)]
                                    struct R { error: String }
                                    output_mode.print_json(&R { error: e.to_string() });
                                }
                                _ => { eprintln!("error: {e}"); }
                            }
                            EXIT_NETWORK
                        }
                    }
                }
            };
            std::process::exit(exit);
        }

        Commands::Plugins { cmd } => {
@@ -123,27 +408,35 @@ async fn main() -> Result<()> {
                    command,
                    json,
                } => {
                    let payload = json.unwrap_or_else(|| "{}".to_string());
                    let mut child = plugin_manager
                        .spawn(&name, &command)
                        .with_context(|| format!("spawning plugin {name} {command}"))?;
                    // Use a local Tokio runtime only for this async path
                    let rt = tokio::runtime::Builder::new_multi_thread()
                        .enable_all()
                        .build()
                        .context("building tokio runtime")?;

                    if let Some(mut stdin) = child.stdin.take() {
                        stdin
                            .write_all(payload.as_bytes())
                            .await
                            .context("writing JSON payload to plugin stdin")?;
                    }
                    rt.block_on(async {
                        let payload = json.unwrap_or_else(|| "{}".to_string());
                        let mut child = plugin_manager
                            .spawn(&name, &command)
                            .with_context(|| format!("spawning plugin {name} {command}"))?;

                        let status = plugin_manager.forward_stdio(&mut child).await?;
                        if !status.success() {
                            polyscribe_core::ui::error(format!(
                                "plugin returned non-zero exit code: {}",
                                status
                            ));
                            return Err(anyhow!("plugin failed"));
                        }
                        Ok(())
                        if let Some(mut stdin) = child.stdin.take() {
                            stdin
                                .write_all(payload.as_bytes())
                                .await
                                .context("writing JSON payload to plugin stdin")?;
                        }

                        let status = plugin_manager.forward_stdio(&mut child).await?;
                        if !status.success() {
                            polyscribe_core::ui::error(format!(
                                "plugin returned non-zero exit code: {}",
                                status
                            ));
                            return Err(anyhow!("plugin failed"));
                        }
                        Ok(())
                    })
                }
            }
        }
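Spot-check for the alias derivation in the Add arm above (standalone copy of derive_alias; hypothetical test):

use std::path::Path;

fn derive_alias(repo: &str, file: &str) -> String {
    // repo tail after the last '/', plus the file stem without extension
    let repo_tail = repo.rsplit('/').next().unwrap_or(repo);
    let stem = Path::new(file).file_stem().and_then(|s| s.to_str()).unwrap_or(file);
    format!("{}-{}", repo_tail, stem)
}

fn main() {
    assert_eq!(
        derive_alias("ggml-org/models", "gguf-tiny-q4_0.bin"),
        "models-gguf-tiny-q4_0"
    );
}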
36 crates/polyscribe-cli/src/output.rs Normal file

@@ -0,0 +1,36 @@
use std::io::{self, Write};

#[derive(Clone, Debug)]
pub enum OutputMode {
    Json,
    Human { quiet: bool },
}

impl OutputMode {
    pub fn is_quiet(&self) -> bool {
        matches!(self, OutputMode::Json) || matches!(self, OutputMode::Human { quiet: true })
    }

    pub fn print_json<T: serde::Serialize>(&self, v: &T) {
        if let OutputMode::Json = self {
            // Write compact JSON to stdout without prefixes
            // and ensure a trailing newline for CLI ergonomics
            let s = serde_json::to_string(v).unwrap_or_else(|e| format!("\"JSON_ERROR:{}\"", e));
            println!("{}", s);
        }
    }

    pub fn print_line(&self, s: impl AsRef<str>) {
        match self {
            OutputMode::Json => {
                // Suppress human lines in JSON mode
            }
            OutputMode::Human { quiet } => {
                if !quiet {
                    let _ = writeln!(io::stdout(), "{}", s.as_ref());
                }
            }
        }
    }
}
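Hedged usage sketch for OutputMode (assumes the module above is in scope, plus serde and serde_json):

fn main() {
    let human = OutputMode::Human { quiet: false };
    let json = OutputMode::Json;

    human.print_line("No models installed."); // printed
    json.print_line("No models installed.");  // suppressed in JSON mode

    #[derive(serde::Serialize)]
    struct R { removed: bool }
    json.print_json(&R { removed: true });  // emits {"removed":true}
    human.print_json(&R { removed: true }); // no-op outside JSON mode
}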
42 crates/polyscribe-cli/tests/models_smoke.rs Normal file

@@ -0,0 +1,42 @@
use assert_cmd::cargo::cargo_bin;
use std::process::Command;

fn bin() -> std::path::PathBuf { cargo_bin("polyscribe") }

#[test]
fn models_help_shows_global_output_flags() {
    let out = Command::new(bin())
        .args(["models", "--help"]) // subcommand help
        .output()
        .expect("failed to run polyscribe models --help");
    assert!(out.status.success(), "help exited non-zero: {:?}", out.status);
    let stdout = String::from_utf8(out.stdout).expect("stdout not utf-8");
    assert!(stdout.contains("--json"), "--json not shown in help: {stdout}");
    assert!(stdout.contains("--quiet"), "--quiet not shown in help: {stdout}");
}

#[test]
fn models_version_contains_pkg_version() {
    let out = Command::new(bin())
        .args(["models", "--version"]) // propagate_version
        .output()
        .expect("failed to run polyscribe models --version");
    assert!(out.status.success(), "version exited non-zero: {:?}", out.status);
    let stdout = String::from_utf8(out.stdout).expect("stdout not utf-8");
    let want = env!("CARGO_PKG_VERSION");
    assert!(stdout.contains(want), "version output missing {want}: {stdout}");
}

#[test]
fn models_ls_json_quiet_emits_pure_json() {
    let out = Command::new(bin())
        .args(["models", "ls", "--json", "--quiet"]) // global flags
        .output()
        .expect("failed to run polyscribe models ls --json --quiet");
    assert!(out.status.success(), "ls exited non-zero: {:?}", out.status);
    let stdout = String::from_utf8(out.stdout).expect("stdout not utf-8");
    serde_json::from_str::<serde_json::Value>(stdout.trim()).expect("stdout is not valid JSON");
    // Expect no extra logs on stdout; stderr should be empty in success path
    assert!(out.stderr.is_empty(), "expected no stderr noise");
}
crates/polyscribe-core/Cargo.toml

@@ -1,22 +1,22 @@
[package]
name = "polyscribe-core"
version = "0.1.0"
edition = "2024"
version.workspace = true
edition.workspace = true

[dependencies]
anyhow = "1.0.99"
thiserror = "1.0.69"
serde = { version = "1.0.219", features = ["derive"] }
serde_json = "1.0.142"
toml = "0.8.23"
directories = "5.0.1"
chrono = "0.4.41"
libc = "0.2.175"
whisper-rs = "0.14.3"
anyhow = { workspace = true }
thiserror = { workspace = true }
serde = { workspace = true, features = ["derive"] }
serde_json = { workspace = true }
toml = { workspace = true }
directories = { workspace = true }
chrono = { workspace = true }
libc = { workspace = true }
whisper-rs = { workspace = true }
# UI and progress
cliclack = { workspace = true }
# New: HTTP downloads + hashing
reqwest = { version = "0.12.7", default-features = false, features = ["blocking", "rustls-tls", "gzip", "json"] }
sha2 = "0.10.8"
hex = "0.4.3"
tempfile = "3.12.0"
# HTTP downloads + hashing
reqwest = { workspace = true }
sha2 = { workspace = true }
hex = { workspace = true }
tempfile = { workspace = true }
crates/polyscribe-core/src/lib.rs

@@ -214,6 +214,8 @@ pub fn render_srt(entries: &[OutputEntry]) -> String {
    srt
}

pub mod model_manager;

pub fn models_dir_path() -> PathBuf {
    if let Ok(env_val) = env::var("POLYSCRIBE_MODELS_DIR") {
        let env_path = PathBuf::from(env_val);
893 crates/polyscribe-core/src/model_manager.rs Normal file

@@ -0,0 +1,893 @@
// SPDX-License-Identifier: MIT

use crate::prelude::*;
use crate::ui::BytesProgress;
use anyhow::{anyhow, Context};
use chrono::{DateTime, Utc};
use reqwest::blocking::Client;
use reqwest::header::{
    ACCEPT_RANGES, AUTHORIZATION, CONTENT_LENGTH, ETAG, IF_NONE_MATCH, LAST_MODIFIED, RANGE,
};
use serde::{Deserialize, Serialize};
use sha2::{Digest, Sha256};
use std::cmp::min;
use std::collections::BTreeMap;
use std::fs::{self, File, OpenOptions};
use std::io::{Read, Seek, SeekFrom, Write};
use std::path::{Path, PathBuf};
use std::sync::{mpsc, Arc, Mutex};
use std::thread;
use std::time::Duration;

const DEFAULT_CHUNK_SIZE: u64 = 8 * 1024 * 1024; // 8 MiB

#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
pub struct ModelRecord {
    pub alias: String,
    pub repo: String,
    pub file: String,
    pub revision: Option<String>, // ETag or commit hash
    pub sha256: Option<String>,
    pub size_bytes: Option<u64>,
    pub quant: Option<String>,
    pub installed_at: Option<DateTime<Utc>>,
    pub last_used: Option<DateTime<Utc>>,
}

#[derive(Debug, Clone, Serialize, Deserialize, Default)]
pub struct Manifest {
    pub models: BTreeMap<String, ModelRecord>, // key = alias
}

#[derive(Debug, Clone)]
pub struct Settings {
    pub concurrency: usize,
    pub limit_rate: Option<u64>, // bytes/sec
    pub chunk_size: u64,
}

impl Default for Settings {
    fn default() -> Self {
        Self {
            concurrency: 4,
            limit_rate: None,
            chunk_size: DEFAULT_CHUNK_SIZE,
        }
    }
}

#[derive(Debug, Clone)]
pub struct Paths {
    pub cache_dir: PathBuf,   // $XDG_CACHE_HOME/polyscribe/models
    pub config_path: PathBuf, // $XDG_CONFIG_HOME/polyscribe/models.json
}

impl Paths {
    pub fn resolve() -> Result<Self> {
        if let Ok(over) = std::env::var("POLYSCRIBE_CACHE_DIR") {
            if !over.is_empty() {
                let cache_dir = PathBuf::from(over).join("models");
                let config_path = std::env::var("POLYSCRIBE_CONFIG_DIR")
                    .map(|p| PathBuf::from(p).join("models.json"))
                    .unwrap_or_else(|_| default_config_path());
                return Ok(Self {
                    cache_dir,
                    config_path,
                });
            }
        }
        let cache_dir = default_cache_dir();
        let config_path = default_config_path();
        Ok(Self {
            cache_dir,
            config_path,
        })
    }
}

fn default_cache_dir() -> PathBuf {
    if let Ok(xdg) = std::env::var("XDG_CACHE_HOME") {
        if !xdg.is_empty() {
            return PathBuf::from(xdg).join("polyscribe").join("models");
        }
    }
    if let Ok(home) = std::env::var("HOME") {
        if !home.is_empty() {
            return PathBuf::from(home)
                .join(".cache")
                .join("polyscribe")
                .join("models");
        }
    }
    PathBuf::from("models")
}

fn default_config_path() -> PathBuf {
    if let Ok(xdg) = std::env::var("XDG_CONFIG_HOME") {
        if !xdg.is_empty() {
            return PathBuf::from(xdg).join("polyscribe").join("models.json");
        }
    }
    if let Ok(home) = std::env::var("HOME") {
        if !home.is_empty() {
            return PathBuf::from(home)
                .join(".config")
                .join("polyscribe")
                .join("models.json");
        }
    }
    PathBuf::from("models.json")
}
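The resolution order is: explicit POLYSCRIBE_* overrides, then XDG variables, then HOME, then a CWD-relative fallback. A hedged in-module spot-check (hypothetical test; assumes it is safe to mutate the process env, as the tests at the bottom of this file also do):

fn main() {
    // XDG_CONFIG_HOME takes precedence over HOME for the manifest path.
    unsafe { std::env::set_var("XDG_CONFIG_HOME", "/tmp/xdg-config"); }
    assert_eq!(
        default_config_path(),
        std::path::PathBuf::from("/tmp/xdg-config/polyscribe/models.json")
    );
}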

pub trait HttpClient: Send + Sync {
    fn head(&self, url: &str, etag: Option<&str>) -> Result<HeadMeta>;
    fn get_range(&self, url: &str, start: u64, end_inclusive: u64) -> Result<Vec<u8>>;
    fn get_whole_to(&self, url: &str, writer: &mut dyn Write) -> Result<()>;
    fn get_from_to(&self, url: &str, start: u64, writer: &mut dyn Write) -> Result<()>;
}

#[derive(Debug, Clone)]
pub struct ReqwestClient {
    client: Client,
    token: Option<String>,
}

impl ReqwestClient {
    pub fn new() -> Result<Self> {
        let token = std::env::var("HF_TOKEN").ok().filter(|s| !s.is_empty());
        let client = Client::builder()
            .user_agent(crate::config::ConfigService::user_agent())
            .build()?;
        Ok(Self { client, token })
    }

    fn auth(&self, mut req: reqwest::blocking::RequestBuilder) -> reqwest::blocking::RequestBuilder {
        if let Some(t) = &self.token {
            req = req.header(AUTHORIZATION, format!("Bearer {}", t));
        }
        req
    }
}

#[derive(Debug, Clone)]
pub struct HeadMeta {
    pub len: Option<u64>,
    pub etag: Option<String>,
    pub last_modified: Option<String>,
    pub accept_ranges: bool,
    pub not_modified: bool,
    pub status: u16,
}

impl HttpClient for ReqwestClient {
    fn head(&self, url: &str, etag: Option<&str>) -> Result<HeadMeta> {
        let mut req = self.client.head(url);
        if let Some(e) = etag {
            req = req.header(IF_NONE_MATCH, format!("\"{}\"", e));
        }
        let resp = self.auth(req).send()?;
        let status = resp.status().as_u16();
        if status == 304 {
            return Ok(HeadMeta {
                len: None,
                etag: etag.map(|s| s.to_string()),
                last_modified: None,
                accept_ranges: true,
                not_modified: true,
                status,
            });
        }
        let len = resp
            .headers()
            .get(CONTENT_LENGTH)
            .and_then(|v| v.to_str().ok())
            .and_then(|s| s.parse::<u64>().ok());
        let etag = resp
            .headers()
            .get(ETAG)
            .and_then(|v| v.to_str().ok())
            .map(|s| s.trim_matches('"').to_string());
        let last_modified = resp
            .headers()
            .get(LAST_MODIFIED)
            .and_then(|v| v.to_str().ok())
            .map(|s| s.to_string());
        let accept_ranges = resp
            .headers()
            .get(ACCEPT_RANGES)
            .and_then(|v| v.to_str().ok())
            .map(|s| s.to_ascii_lowercase().contains("bytes"))
            .unwrap_or(false);
        Ok(HeadMeta {
            len,
            etag,
            last_modified,
            accept_ranges,
            not_modified: false,
            status,
        })
    }

    fn get_range(&self, url: &str, start: u64, end_inclusive: u64) -> Result<Vec<u8>> {
        let range_val = format!("bytes={}-{}", start, end_inclusive);
        let resp = self
            .auth(self.client.get(url))
            .header(RANGE, range_val)
            .send()?;
        if !resp.status().is_success() && resp.status().as_u16() != 206 {
            return Err(anyhow!("HTTP {} for ranged GET", resp.status()).into());
        }
        let mut buf = Vec::new();
        let mut r = resp;
        r.copy_to(&mut buf)?;
        Ok(buf)
    }

    fn get_whole_to(&self, url: &str, writer: &mut dyn Write) -> Result<()> {
        let resp = self.auth(self.client.get(url)).send()?;
        if !resp.status().is_success() {
            return Err(anyhow!("HTTP {} for GET", resp.status()).into());
        }
        let mut r = resp;
        r.copy_to(writer)?;
        Ok(())
    }

    fn get_from_to(&self, url: &str, start: u64, writer: &mut dyn Write) -> Result<()> {
        let mut req = self.auth(self.client.get(url));
        if start > 0 {
            req = req.header(RANGE, format!("bytes={}-", start));
        }
        let resp = req.send()?;
        if !resp.status().is_success() && resp.status().as_u16() != 206 {
            return Err(anyhow!("HTTP {} for ranged GET from {}", resp.status(), start).into());
        }
        let mut r = resp;
        r.copy_to(writer)?;
        Ok(())
    }
}

pub struct ModelManager<C: HttpClient = ReqwestClient> {
    pub paths: Paths,
    pub settings: Settings,
    client: Arc<C>,
}

impl<C: HttpClient + 'static> ModelManager<C> {
    pub fn new_with_client(client: C, settings: Settings) -> Result<Self> {
        Ok(Self {
            paths: Paths::resolve()?,
            settings,
            client: Arc::new(client),
        })
    }

    pub fn new(settings: Settings) -> Result<Self>
    where
        C: Default,
    {
        Ok(Self {
            paths: Paths::resolve()?,
            settings,
            client: Arc::new(C::default()),
        })
    }

    fn load_manifest(&self) -> Result<Manifest> {
        let p = &self.paths.config_path;
        if !p.exists() {
            return Ok(Manifest::default());
        }
        let file = File::open(p).with_context(|| format!("open manifest {}", p.display()))?;
        let m: Manifest = serde_json::from_reader(file).context("parse manifest")?;
        Ok(m)
    }

    fn save_manifest(&self, m: &Manifest) -> Result<()> {
        let p = &self.paths.config_path;
        if let Some(dir) = p.parent() {
            fs::create_dir_all(dir)
                .with_context(|| format!("create config dir {}", dir.display()))?;
        }
        let tmp = p.with_extension("json.tmp");
        let f = OpenOptions::new()
            .create(true)
            .write(true)
            .truncate(true)
            .open(&tmp)?;
        serde_json::to_writer_pretty(f, m).context("serialize manifest")?;
        fs::rename(&tmp, p).with_context(|| format!("atomic rename {} -> {}", tmp.display(), p.display()))?;
        Ok(())
    }

    fn model_path(&self, file: &str) -> PathBuf {
        self.paths.cache_dir.join(file)
    }

    fn compute_sha256(path: &Path) -> Result<String> {
        let mut f = File::open(path)?;
        let mut hasher = Sha256::new();
        let mut buf = [0u8; 64 * 1024];
        loop {
            let n = f.read(&mut buf)?;
            if n == 0 {
                break;
            }
            hasher.update(&buf[..n]);
        }
        Ok(format!("{:x}", hasher.finalize()))
    }
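A standalone copy of the streaming-hash loop, for illustration (the real compute_sha256 is a private associated fn; sha256_file below is a hypothetical free-function equivalent):

use sha2::{Digest, Sha256};
use std::{fs::File, io::Read, path::Path};

fn sha256_file(path: &Path) -> std::io::Result<String> {
    // Hash in 64 KiB chunks so large model files never load fully into memory.
    let mut f = File::open(path)?;
    let mut hasher = Sha256::new();
    let mut buf = [0u8; 64 * 1024];
    loop {
        let n = f.read(&mut buf)?;
        if n == 0 { break; }
        hasher.update(&buf[..n]);
    }
    Ok(format!("{:x}", hasher.finalize()))
}

fn main() -> std::io::Result<()> {
    // SHA-256 of empty input is the well-known constant below.
    let tmp = tempfile::NamedTempFile::new()?;
    assert_eq!(
        sha256_file(tmp.path())?,
        "e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855"
    );
    Ok(())
}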

    pub fn ls(&self) -> Result<Vec<ModelRecord>> {
        let m = self.load_manifest()?;
        Ok(m.models.values().cloned().collect())
    }

    pub fn rm(&self, alias: &str) -> Result<bool> {
        let mut m = self.load_manifest()?;
        if let Some(rec) = m.models.remove(alias) {
            let p = self.model_path(&rec.file);
            let _ = fs::remove_file(&p);
            self.save_manifest(&m)?;
            return Ok(true);
        }
        Ok(false)
    }

    pub fn verify(&self, alias: &str) -> Result<bool> {
        let m = self.load_manifest()?;
        let Some(rec) = m.models.get(alias) else { return Ok(false) };
        let p = self.model_path(&rec.file);
        if !p.exists() { return Ok(false); }
        if let Some(expected) = &rec.sha256 {
            let actual = Self::compute_sha256(&p)?;
            return Ok(&actual == expected);
        }
        Ok(true)
    }

    pub fn gc(&self) -> Result<(usize, usize)> {
        // Remove files not referenced by manifest; also drop manifest entries whose file is missing
        fs::create_dir_all(&self.paths.cache_dir).ok();
        let mut m = self.load_manifest()?;
        let mut referenced = BTreeMap::new();
        for (alias, rec) in &m.models {
            referenced.insert(rec.file.clone(), alias.clone());
        }
        let mut removed_files = 0usize;
        if let Ok(rd) = fs::read_dir(&self.paths.cache_dir) {
            for ent in rd.flatten() {
                let p = ent.path();
                if p.is_file() {
                    let fname = p.file_name().and_then(|s| s.to_str()).unwrap_or("");
                    if !referenced.contains_key(fname) {
                        let _ = fs::remove_file(&p);
                        removed_files += 1;
                    }
                }
            }
        }

        m.models.retain(|_, rec| self.model_path(&rec.file).exists());
        let removed_entries = referenced
            .keys()
            .filter(|f| !self.model_path(f).exists())
            .count();

        self.save_manifest(&m)?;
        Ok((removed_files, removed_entries))
    }

    pub fn add_or_update(
        &self,
        alias: &str,
        repo: &str,
        file: &str,
    ) -> Result<ModelRecord> {
        fs::create_dir_all(&self.paths.cache_dir)
            .with_context(|| format!("create cache dir {}", self.paths.cache_dir.display()))?;

        let url = format!("https://huggingface.co/{}/resolve/main/{}?download=true", repo, file);

        let mut manifest = self.load_manifest()?;
        let prev = manifest.models.get(alias).cloned();
        let prev_etag = prev.as_ref().and_then(|r| r.revision.clone());
        // Fetch remote meta (size/hash) when available to verify the download
        let (_api_size, api_sha) = hf_fetch_file_meta(repo, file).unwrap_or((None, None));
        let head = self.client.head(&url, prev_etag.as_deref())?;
        if head.not_modified {
            // up-to-date; ensure record present and touch last_used
            let mut rec = prev.ok_or_else(|| anyhow!("not installed yet but server says 304"))?;
            rec.last_used = Some(Utc::now());
            self.save_touch(&mut manifest, rec.clone())?;
            return Ok(rec);
        }

        // Quick check: if HEAD failed (e.g., 404), report a helpful error before attempting download
        if head.status >= 400 {
            return Err(anyhow!(
                "file not found or inaccessible: repo='{}' file='{}' (HTTP {})\nHint: run `polyscribe models search {} --query {}` to list available files",
                repo,
                file,
                head.status,
                repo,
                file
            ).into());
        }

        let total_len = head.len.ok_or_else(|| anyhow!("missing content-length (HEAD)"))?;
        let etag = head.etag.clone();
        let dest_tmp = self.model_path(&format!("{}.part", file));
        // If a previous cancelled download left a .part file, remove it to avoid clutter/resume.
        if dest_tmp.exists() { let _ = fs::remove_file(&dest_tmp); }
        // Guard to ensure .part is cleaned up on errors
        struct TempGuard { path: PathBuf, armed: bool }
        impl TempGuard { fn disarm(&mut self) { self.armed = false; } }
        impl Drop for TempGuard {
            fn drop(&mut self) {
                if self.armed {
                    let _ = fs::remove_file(&self.path);
                }
            }
        }
        let mut _tmp_guard = TempGuard { path: dest_tmp.clone(), armed: true };
        let dest_final = self.model_path(file);

        // Do not resume after cancellation; start fresh to avoid stale .part files
        let start_from = 0u64;

        // Open tmp for write
        let f = OpenOptions::new().create(true).write(true).read(true).open(&dest_tmp)?;
        f.set_len(total_len)?; // pre-allocate for random writes
        let f = Arc::new(Mutex::new(f));

        // Create progress bar
        let mut progress = BytesProgress::start(total_len, &format!("Downloading {}", file), start_from);

        // Create work chunks
        let chunk_size = self.settings.chunk_size;
        let mut chunks = Vec::new();
        let mut pos = start_from;
        while pos < total_len {
            let end = min(total_len - 1, pos + chunk_size - 1);
            chunks.push((pos, end));
            pos = end + 1;
        }

        // Attempt concurrent ranged download; on failure, fall back to streaming GET
        let mut ranged_failed = false;
        if head.accept_ranges && self.settings.concurrency > 1 {
            let (work_tx, work_rx) = mpsc::channel::<(u64, u64)>();
            let (prog_tx, prog_rx) = mpsc::channel::<u64>();
            for ch in chunks {
                work_tx.send(ch).unwrap();
            }
            drop(work_tx);
            let rx = Arc::new(Mutex::new(work_rx));

            let workers = self.settings.concurrency.max(1);
            let mut handles = Vec::new();
            for _ in 0..workers {
                let rx = rx.clone();
                let url = url.clone();
                let client = self.client.clone();
                let f = f.clone();
                let limit = self.settings.limit_rate;
                let prog_tx = prog_tx.clone();
                let handle = thread::spawn(move || -> Result<()> {
                    loop {
                        let next = {
                            let guard = rx.lock().unwrap();
                            guard.recv().ok()
                        };
                        let Some((start, end)) = next else { break; };
                        let data = client.get_range(&url, start, end)?;
                        if let Some(max_bps) = limit {
                            let dur = Duration::from_secs_f64((data.len() as f64) / (max_bps as f64));
                            if dur > Duration::from_millis(1) {
                                thread::sleep(dur);
                            }
                        }
                        let mut guard = f.lock().unwrap();
                        guard.seek(SeekFrom::Start(start))?;
                        guard.write_all(&data)?;
                        let _ = prog_tx.send(data.len() as u64);
                    }
                    Ok(())
                });
                handles.push(handle);
            }
            drop(prog_tx);

            for delta in prog_rx {
                progress.inc(delta);
            }

            let mut ranged_err: Option<anyhow::Error> = None;
            for h in handles {
                match h.join() {
                    Ok(Ok(())) => {}
                    Ok(Err(e)) => { ranged_failed = true; if ranged_err.is_none() { ranged_err = Some(e.into()); } }
                    Err(_) => { ranged_failed = true; if ranged_err.is_none() { ranged_err = Some(anyhow!("worker panicked")); } }
                }
            }
        } else {
            ranged_failed = true;
        }

        if ranged_failed {
            // Restart progress if we are abandoning previous partial state
            if start_from > 0 {
                progress.stop("retrying as streaming");
                progress = BytesProgress::start(total_len, &format!("Downloading {}", file), 0);
            }
            // Fall back to streaming GET; try URL with and without ?download=true
            let mut try_urls = vec![url.clone()];
            if let Some((base, _qs)) = url.split_once('?') {
                try_urls.push(base.to_string());
            } else {
                try_urls.push(format!("{}?download=true", url));
            }
            // Fresh temp file for streaming
            let mut wf = OpenOptions::new().create(true).write(true).truncate(true).open(&dest_tmp)?;
            let mut ok = false;
            let mut last_err: Option<anyhow::Error> = None;
            // Counting writer to update progress inline
            struct CountingWriter<'a, 'b> {
                inner: &'a mut File,
                progress: &'b mut BytesProgress,
            }
            impl<'a, 'b> Write for CountingWriter<'a, 'b> {
                fn write(&mut self, buf: &[u8]) -> std::io::Result<usize> {
                    let n = self.inner.write(buf)?;
                    self.progress.inc(n as u64);
                    Ok(n)
                }
                fn flush(&mut self) -> std::io::Result<()> { self.inner.flush() }
            }
            let mut cw = CountingWriter { inner: &mut wf, progress: &mut progress };
            for u in try_urls {
                // For fallback, stream from scratch to ensure integrity
                let res = self.client.get_whole_to(&u, &mut cw);
                match res {
                    Ok(()) => { ok = true; break; }
                    Err(e) => { last_err = Some(e.into()); }
                }
            }
            if !ok {
                if let Some(e) = last_err { return Err(e.into()); }
                return Err(anyhow!("download failed (ranged and streaming)").into());
            }
        }

        progress.stop("download complete");

        // Verify integrity
        let sha = Self::compute_sha256(&dest_tmp)?;
        if let Some(expected) = api_sha.as_ref() {
            if &sha != expected {
                return Err(anyhow!(
                    "sha256 mismatch (expected {}, got {})",
                    expected,
                    sha
                ).into());
            }
        } else if prev.as_ref().map(|r| r.file.eq(file)).unwrap_or(false) {
            if let Some(expected) = prev.as_ref().and_then(|r| r.sha256.as_ref()) {
                if &sha != expected {
                    return Err(anyhow!("sha256 mismatch").into());
                }
            }
        }
        // Atomic rename
        fs::rename(&dest_tmp, &dest_final).with_context(|| format!("rename {} -> {}", dest_tmp.display(), dest_final.display()))?;
        // Disarm guard; .part has been moved or cleaned
        _tmp_guard.disarm();

        let rec = ModelRecord {
            alias: alias.to_string(),
            repo: repo.to_string(),
            file: file.to_string(),
            revision: etag,
            sha256: Some(sha.clone()),
            size_bytes: Some(total_len),
            quant: infer_quant(file),
            installed_at: Some(Utc::now()),
            last_used: Some(Utc::now()),
        };
        self.save_touch(&mut manifest, rec.clone())?;
        Ok(rec)
    }

    fn save_touch(&self, manifest: &mut Manifest, rec: ModelRecord) -> Result<()> {
        manifest.models.insert(rec.alias.clone(), rec);
        self.save_manifest(manifest)
    }
}
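For intuition on the ranged download, here is the chunk-range computation from add_or_update as a standalone sketch (hypothetical helper mirroring the loop above; ranges are inclusive byte offsets):

fn chunks(total_len: u64, chunk_size: u64) -> Vec<(u64, u64)> {
    let mut out = Vec::new();
    let mut pos = 0u64;
    while pos < total_len {
        // Each chunk is [pos, pos + chunk_size - 1], clamped to the last byte.
        let end = (total_len - 1).min(pos + chunk_size - 1);
        out.push((pos, end));
        pos = end + 1;
    }
    out
}

fn main() {
    // 20 bytes in 8-byte chunks -> [0..=7], [8..=15], [16..=19]
    assert_eq!(chunks(20, 8), vec![(0, 7), (8, 15), (16, 19)]);
}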

fn infer_quant(file: &str) -> Option<String> {
    // Try to extract a Q* token, e.g. Q5_K_M, from the filename
    let upper = file.to_ascii_uppercase();
    if let Some(pos) = upper.find('Q') {
        let tail = &upper[pos..];
        let token: String = tail
            .chars()
            .take_while(|c| c.is_ascii_uppercase() || c.is_ascii_digit() || *c == '_' || *c == '-')
            .collect();
        if token.len() >= 2 {
            return Some(token);
        }
    }
    None
}
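Spot-checks for the quant inference (hypothetical in-module test, assuming infer_quant as written above):

fn main() {
    // The token is taken from the upper-cased name, stopping at the '.'.
    assert_eq!(infer_quant("gguf-tiny-q4_0.bin").as_deref(), Some("Q4_0"));
    // No 'Q' anywhere means no quant token.
    assert_eq!(infer_quant("model.bin"), None);
}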

impl Default for ReqwestClient {
    fn default() -> Self {
        Self::new().expect("reqwest client")
    }
}

// Hugging Face API types for file metadata
#[derive(Debug, Deserialize)]
struct ApiHfLfs {
    oid: Option<String>,
    size: Option<u64>,
    sha256: Option<String>,
}

#[derive(Debug, Deserialize)]
struct ApiHfFile {
    rfilename: String,
    size: Option<u64>,
    sha256: Option<String>,
    lfs: Option<ApiHfLfs>,
}

#[derive(Debug, Deserialize)]
struct ApiHfModelInfo {
    siblings: Option<Vec<ApiHfFile>>,
    files: Option<Vec<ApiHfFile>>,
}

fn pick_sha_from_file(f: &ApiHfFile) -> Option<String> {
    if let Some(s) = &f.sha256 { return Some(s.to_string()); }
    if let Some(l) = &f.lfs {
        if let Some(s) = &l.sha256 { return Some(s.to_string()); }
        if let Some(oid) = &l.oid { return oid.strip_prefix("sha256:").map(|s| s.to_string()); }
    }
    None
}

fn hf_fetch_file_meta(repo: &str, target: &str) -> Result<(Option<u64>, Option<String>)> {
    let token = std::env::var("HF_TOKEN").ok().filter(|s| !s.is_empty());
    let client = Client::builder()
        .user_agent(crate::config::ConfigService::user_agent())
        .build()?;
    let base = crate::config::ConfigService::hf_api_base_for(repo);
    let urls = [base.clone(), format!("{}?expand=files", base)];
    for url in urls {
        let mut req = client.get(&url);
        if let Some(t) = &token { req = req.header(AUTHORIZATION, format!("Bearer {}", t)); }
        let resp = req.send()?;
        if !resp.status().is_success() { continue; }
        let info: ApiHfModelInfo = resp.json()?;
        let list = info.files.or(info.siblings).unwrap_or_default();
        for f in list {
            let name = f.rfilename.rsplit('/').next().unwrap_or(&f.rfilename);
            if name.eq_ignore_ascii_case(target) {
                let sz = f.size.or_else(|| f.lfs.as_ref().and_then(|l| l.size));
                let sha = pick_sha_from_file(&f);
                return Ok((sz, sha));
            }
        }
    }
    Err(anyhow!("file not found in HF API").into())
}

/// Fetch remote metadata (size, sha256) for a single file in a HF repo.
pub fn fetch_file_meta(repo: &str, file: &str) -> Result<(Option<u64>, Option<String>)> {
    hf_fetch_file_meta(repo, file)
}

/// Search a Hugging Face repo for GGUF/BIN files via API. Returns file names only.
pub fn search_repo(repo: &str, query: Option<&str>) -> Result<Vec<String>> {
    let token = std::env::var("HF_TOKEN").ok().filter(|s| !s.is_empty());
    let client = Client::builder()
        .user_agent(crate::config::ConfigService::user_agent())
        .build()?;

    let base = crate::config::ConfigService::hf_api_base_for(repo);
    let mut urls = vec![base.clone(), format!("{}?expand=files", base)];
    let mut files = Vec::<String>::new();

    for url in urls.drain(..) {
        let mut req = client.get(&url);
        if let Some(t) = &token { req = req.header(AUTHORIZATION, format!("Bearer {}", t)); }
        let resp = req.send()?;
        if !resp.status().is_success() { continue; }
        let info: ApiHfModelInfo = resp.json()?;
        let list = info.files.or(info.siblings).unwrap_or_default();
        for f in list {
            if f.rfilename.ends_with(".gguf") || f.rfilename.ends_with(".bin") {
                let name = f.rfilename.rsplit('/').next().unwrap_or(&f.rfilename).to_string();
                if !files.contains(&name) { files.push(name); }
            }
        }
        if !files.is_empty() { break; }
    }

    if let Some(q) = query { let qlc = q.to_ascii_lowercase(); files.retain(|f| f.to_ascii_lowercase().contains(&qlc)); }
    files.sort();
    Ok(files)
}
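Hypothetical caller for search_repo (requires network access to the HF API; illustrative only):

fn main() {
    let files = polyscribe_core::model_manager::search_repo("ggml-org/models", Some("tiny"))
        .expect("HF API reachable");
    for f in files {
        println!("{f}");
    }
}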
|
||||
|
||||
/// List repo files with optional size metadata for GGUF/BIN entries.
|
||||
pub fn list_repo_files_with_meta(repo: &str) -> Result<Vec<(String, Option<u64>)>> {
|
||||
let token = std::env::var("HF_TOKEN").ok().filter(|s| !s.is_empty());
|
||||
let client = Client::builder()
|
||||
.user_agent(crate::config::ConfigService::user_agent())
|
||||
.build()?;
|
||||
|
||||
let base = crate::config::ConfigService::hf_api_base_for(repo);
|
||||
for url in [base.clone(), format!("{}?expand=files", base)] {
|
||||
let mut req = client.get(&url);
|
||||
if let Some(t) = &token { req = req.header(AUTHORIZATION, format!("Bearer {}", t)); }
|
||||
let resp = req.send()?;
|
||||
if !resp.status().is_success() { continue; }
|
||||
let info: ApiHfModelInfo = resp.json()?;
|
||||
let list = info.files.or(info.siblings).unwrap_or_default();
|
||||
let mut out = Vec::new();
|
||||
for f in list {
|
||||
if !(f.rfilename.ends_with(".gguf") || f.rfilename.ends_with(".bin")) { continue; }
|
||||
let name = f.rfilename.rsplit('/').next().unwrap_or(&f.rfilename).to_string();
|
||||
let size = f.size.or_else(|| f.lfs.as_ref().and_then(|l| l.size));
|
||||
out.push((name, size));
|
||||
}
|
||||
if !out.is_empty() { return Ok(out); }
|
||||
}
|
||||
Ok(Vec::new())
|
||||
}
|
||||
|
||||
/// Fallback: HEAD request for a single file to retrieve Content-Length (size).
|
||||
pub fn head_len_for_file(repo: &str, file: &str) -> Option<u64> {
|
||||
let token = std::env::var("HF_TOKEN").ok().filter(|s| !s.is_empty());
|
||||
let client = Client::builder()
|
||||
.user_agent(crate::config::ConfigService::user_agent())
|
||||
.build().ok()?;
|
||||
let mut urls = Vec::new();
|
||||
urls.push(format!("https://huggingface.co/{}/resolve/main/{}?download=true", repo, file));
|
||||
urls.push(format!("https://huggingface.co/{}/resolve/main/{}", repo, file));
|
||||
for url in urls {
|
||||
let mut req = client.head(&url);
|
||||
if let Some(t) = &token { req = req.header(AUTHORIZATION, format!("Bearer {}", t)); }
|
||||
if let Ok(resp) = req.send() {
|
||||
if resp.status().is_success() {
|
||||
if let Some(len) = resp.headers().get(CONTENT_LENGTH)
|
||||
.and_then(|v| v.to_str().ok())
|
||||
.and_then(|s| s.parse::<u64>().ok())
|
||||
{ return Some(len); }
|
||||
}
|
||||
}
|
||||
}
|
||||
None
|
||||
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;
    use std::env;
    use tempfile::TempDir;

    #[derive(Clone)]
    struct StubHttp {
        data: Arc<Vec<u8>>,
        etag: Arc<Option<String>>,
        accept_ranges: bool,
    }
    impl HttpClient for StubHttp {
        fn head(&self, _url: &str, etag: Option<&str>) -> Result<HeadMeta> {
            let not_modified = etag.is_some() && self.etag.as_ref().as_deref() == etag;
            Ok(HeadMeta {
                len: Some(self.data.len() as u64),
                etag: self.etag.as_ref().clone(),
                last_modified: None,
                accept_ranges: self.accept_ranges,
                not_modified,
                status: if not_modified { 304 } else { 200 },
            })
        }
        fn get_range(&self, _url: &str, start: u64, end_inclusive: u64) -> Result<Vec<u8>> {
            let s = start as usize;
            let e = (end_inclusive as usize) + 1;
            Ok(self.data[s..e].to_vec())
        }

        fn get_whole_to(&self, _url: &str, writer: &mut dyn Write) -> Result<()> {
            writer.write_all(&self.data)?;
            Ok(())
        }

        fn get_from_to(&self, _url: &str, start: u64, writer: &mut dyn Write) -> Result<()> {
            let s = start as usize;
            writer.write_all(&self.data[s..])?;
            Ok(())
        }
    }

    fn setup_env(cache: &Path, cfg: &Path) {
        unsafe {
            env::set_var("POLYSCRIBE_CACHE_DIR", cache.to_string_lossy().to_string());
            env::set_var(
                "POLYSCRIBE_CONFIG_DIR",
                cfg.parent().unwrap().to_string_lossy().to_string(),
            );
        }
    }

    #[test]
    fn test_manifest_roundtrip() {
        let temp = TempDir::new().unwrap();
        let cache = temp.path().join("cache");
        let cfg = temp.path().join("config").join("models.json");
        setup_env(&cache, &cfg);

        let client = StubHttp {
            data: Arc::new(vec![0u8; 1024]),
            etag: Arc::new(Some("etag123".into())),
            accept_ranges: true,
        };
        let mm: ModelManager<StubHttp> = ModelManager::new_with_client(client, Settings::default()).unwrap();
        let m = mm.load_manifest().unwrap();
        assert!(m.models.is_empty());

        let rec = ModelRecord {
            alias: "tiny".into(),
            repo: "foo/bar".into(),
            file: "gguf-tiny.bin".into(),
            revision: Some("etag123".into()),
            sha256: None,
            size_bytes: None,
            quant: None,
            installed_at: None,
            last_used: None,
        };

        let mut m2 = Manifest::default();
        mm.save_touch(&mut m2, rec.clone()).unwrap();
        let m3 = mm.load_manifest().unwrap();
        assert!(m3.models.contains_key("tiny"));
    }

    #[test]
    fn test_add_verify_update_gc() {
        let temp = TempDir::new().unwrap();
        let cache = temp.path().join("cache");
        let cfg_dir = temp.path().join("config");
        let cfg = cfg_dir.join("models.json");
        setup_env(&cache, &cfg);

        let data = (0..1024 * 1024u32).flat_map(|i| i.to_le_bytes()).collect::<Vec<u8>>();
        let etag = Some("abc123".to_string());
        let client = StubHttp { data: Arc::new(data), etag: Arc::new(etag), accept_ranges: true };
        let mm: ModelManager<StubHttp> = ModelManager::new_with_client(client.clone(), Settings { concurrency: 3, ..Default::default() }).unwrap();

        // add
        let rec = mm.add_or_update("tiny", "gguf/models", "gguf-tiny-q4_0.bin").unwrap();
        assert_eq!(rec.alias, "tiny");
        assert!(mm.verify("tiny").unwrap());

        // update (304)
        let rec2 = mm.add_or_update("tiny", "gguf/models", "gguf-tiny-q4_0.bin").unwrap();
        assert_eq!(rec2.alias, "tiny");

        // gc (nothing to remove)
        let (files_removed, entries_removed) = mm.gc().unwrap();
        assert_eq!(files_removed, 0);
        assert_eq!(entries_removed, 0);

        // rm
        assert!(mm.rm("tiny").unwrap());
        assert!(!mm.rm("tiny").unwrap());
    }
}

@@ -792,6 +792,13 @@ fn download_with_progress(dest_path: &Path, entry: &ModelEntry) -> Result<()> {
    }

    let part_path = dest_path.with_extension("part");
    // Guard to clean up the .part file on errors
    struct TempGuard { path: std::path::PathBuf, armed: bool }
    impl TempGuard { fn disarm(&mut self) { self.armed = false; } }
    impl Drop for TempGuard {
        fn drop(&mut self) { if self.armed { let _ = fs::remove_file(&self.path); } }
    }
    let mut _tmp_guard = TempGuard { path: part_path.clone(), armed: true };

    let mut resume_from: u64 = 0;
    if part_path.exists() && ranges_ok {
@@ -893,6 +900,7 @@ fn download_with_progress(dest_path: &Path, entry: &ModelEntry) -> Result<()> {
    drop(part_file);
    fs::rename(&part_path, dest_path)
        .with_context(|| format!("renaming {} → {}", part_path.display(), dest_path.display()))?;
    _tmp_guard.disarm();

    let final_size = fs::metadata(dest_path).map(|m| m.len()).ok();
    let elapsed = start.elapsed().as_secs_f64();
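
// A minimal, self-contained sketch of the drop-guard pattern used above
// (hypothetical `download_sketch` helper; only std is assumed). The guard
// removes the partial file on every early-return path, and is disarmed only
// after the atomic rename publishes the final file.
fn download_sketch(dest: &std::path::Path) -> std::io::Result<()> {
    let part = dest.with_extension("part");
    struct Guard { path: std::path::PathBuf, armed: bool }
    impl Drop for Guard {
        // Drop runs on success and on error/`?` returns alike.
        fn drop(&mut self) { if self.armed { let _ = std::fs::remove_file(&self.path); } }
    }
    let mut guard = Guard { path: part.clone(), armed: true };
    std::fs::write(&part, b"payload")?; // any fallible step may bail out here
    std::fs::rename(&part, dest)?;      // atomic publish
    guard.armed = false;                // keep the final file; nothing to clean
    Ok(())
}
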
@@ -4,6 +4,8 @@ pub mod progress;

use std::io;
use std::io::IsTerminal;
use std::io::Write as _;
use std::time::{Duration, Instant};

pub fn info(msg: impl AsRef<str>) {
    let m = msg.as_ref();
@@ -170,43 +172,156 @@ impl Spinner {
    }
}

pub struct BytesProgress(Option<cliclack::ProgressBar>);
pub struct BytesProgress {
    enabled: bool,
    total: u64,
    current: u64,
    started: Instant,
    last_msg: Instant,
    width: usize,
    // Sticky ETA to carry through zero-speed stalls
    last_eta_secs: Option<f64>,
}

impl BytesProgress {
    pub fn start(total: u64, text: &str, initial: u64) -> Self {
        if crate::is_no_progress()
        let enabled = !(crate::is_no_progress()
            || crate::is_no_interaction()
            || !std::io::stderr().is_terminal()
            || total == 0
        {
            || total == 0);
        if !enabled {
            let _ = cliclack::log::info(text);
            return Self(None);
        }
        let b = cliclack::progress_bar(total);
        b.start(text);
        if initial > 0 {
            b.inc(initial);
        let mut me = Self {
            enabled,
            total,
            current: initial.min(total),
            started: Instant::now(),
            last_msg: Instant::now(),
            width: 40,
            last_eta_secs: None,
        };
        me.draw();
        me
    }

    fn human_bytes(n: u64) -> String {
        const KB: f64 = 1024.0;
        const MB: f64 = 1024.0 * KB;
        const GB: f64 = 1024.0 * MB;
        let x = n as f64;
        if x >= GB {
            format!("{:.2} GiB", x / GB)
        } else if x >= MB {
            format!("{:.2} MiB", x / MB)
        } else if x >= KB {
            format!("{:.2} KiB", x / KB)
        } else {
            format!("{} B", n)
        }
        Self(Some(b))
    }

    // Elapsed formatting is used for stable, finite durations. For ETA, we guard
    // against zero-speed or unstable estimates separately via `format_eta`.

    fn refresh_allowed(&mut self) -> (f64, f64) {
        let now = Instant::now();
        let since_last = now.duration_since(self.last_msg);
        if since_last < Duration::from_millis(100) {
            // Too soon to refresh; keep previous ETA if any
            let eta = self.last_eta_secs.unwrap_or(f64::INFINITY);
            return (0.0, eta);
        }
        self.last_msg = now;
        let elapsed = now.duration_since(self.started).as_secs_f64().max(0.001);
        let speed = (self.current as f64) / elapsed;
        let remaining = self.total.saturating_sub(self.current) as f64;

        // If speed is effectively zero, carry ETA forward and add wall time.
        const EPS: f64 = 1e-6;
        let eta = if speed <= EPS {
            let prev = self.last_eta_secs.unwrap_or(f64::INFINITY);
            if prev.is_finite() {
                prev + since_last.as_secs_f64()
            } else {
                prev
            }
        } else {
            remaining / speed
        };
        // Remember only finite ETAs to use during stalls
        if eta.is_finite() {
            self.last_eta_secs = Some(eta);
        }
        (speed, eta)
    }

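    // Worked example of the sticky ETA (illustrative numbers, not from the
    // diff): with total = 10 MiB and current = 2 MiB after 1 s, speed is
    // 2 MiB/s and eta = (10 - 2) / 2 = 4 s. If the transfer then stalls
    // (speed <= EPS), the last finite ETA is reused and grows by the wall
    // time since the previous refresh instead of jumping to infinity.
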
    fn format_elapsed(seconds: f64) -> String {
        let total = seconds.round() as u64;
        let h = total / 3600;
        let m = (total % 3600) / 60;
        let s = total % 60;
        if h > 0 { format!("{:02}:{:02}:{:02}", h, m, s) } else { format!("{:02}:{:02}", m, s) }
    }

    fn format_eta(seconds: f64) -> String {
        // If ETA is not finite (e.g., divide-by-zero speed) or unreasonably large,
        // show a placeholder rather than overflowing into huge values.
        if !seconds.is_finite() {
            return "—".to_string();
        }
        // Cap ETA display to 99:59:59 to avoid silly numbers; beyond that, show placeholder.
        const CAP_SECS: f64 = 99.0 * 3600.0 + 59.0 * 60.0 + 59.0;
        if seconds > CAP_SECS {
            return "—".to_string();
        }
        Self::format_elapsed(seconds)
    }

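    // Illustrative outputs for the two formatters (rounding assumptions only):
    //   format_elapsed(75.0)      -> "01:15"
    //   format_elapsed(3661.0)    -> "01:01:01"
    //   format_eta(f64::INFINITY) -> "—"
    //   format_eta(400_000.0)     -> "—" (above the 99:59:59 cap)
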
    fn draw(&mut self) {
        if !self.enabled { return; }
        let (speed, eta) = self.refresh_allowed();
        let elapsed = Instant::now().duration_since(self.started).as_secs_f64();
        // Build bar
        let width = self.width.max(10);
        let filled = ((self.current as f64 / self.total.max(1) as f64) * width as f64).round() as usize;
        let filled = filled.min(width);
        let mut bar = String::with_capacity(width);
        for _ in 0..filled { bar.push('■'); }
        for _ in filled..width { bar.push('□'); }

        let line = format!(
            "[{}] {} [{}] ({}/{} at {}/s)",
            Self::format_elapsed(elapsed),
            bar,
            Self::format_eta(eta),
            Self::human_bytes(self.current),
            Self::human_bytes(self.total),
            Self::human_bytes(speed.max(0.0) as u64),
        );
        eprint!("\r{}\x1b[K", line);
        let _ = io::stderr().flush();
    }

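    // Example rendered line (illustrative values; 12 MiB done in 12 s with a
    // 40-column bar, so 10 of 40 cells are filled and the ETA is 36 s):
    //   [00:12] [■■■■■■■■■■□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□] [00:36] (12.00 MiB/48.00 MiB at 1.00 MiB/s)
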
    pub fn inc(&mut self, delta: u64) {
        if let Some(b) = self.0.as_mut() {
            b.inc(delta);
        }
        self.current = self.current.saturating_add(delta).min(self.total);
        self.draw();
    }

    pub fn stop(mut self, text: &str) {
        if let Some(b) = self.0.take() {
            b.stop(text);
        if self.enabled {
            self.draw();
            eprintln!();
        } else {
            let _ = cliclack::log::info(text);
        }
    }

    pub fn error(mut self, text: &str) {
        if let Some(b) = self.0.take() {
            b.error(text);
        if self.enabled {
            self.draw();
            eprintln!();
            let _ = cliclack::log::error(text);
        } else {
            let _ = cliclack::log::error(text);
        }

@@ -1,12 +1,12 @@
[package]
name = "polyscribe-host"
version = "0.1.0"
edition = "2024"
version.workspace = true
edition.workspace = true

[dependencies]
anyhow = "1.0.99"
serde = { version = "1.0.219", features = ["derive"] }
serde_json = "1.0.142"
tokio = { version = "1.47.1", features = ["rt-multi-thread", "process", "io-util"] }
which = "6.0.3"
anyhow = { workspace = true }
serde = { workspace = true, features = ["derive"] }
serde_json = { workspace = true }
tokio = { workspace = true, features = ["rt-multi-thread", "process", "io-util"] }
which = { workspace = true }
directories = { workspace = true }

@@ -1,8 +1,8 @@
[package]
name = "polyscribe-protocol"
version = "0.1.0"
edition = "2024"
version.workspace = true
edition.workspace = true

[dependencies]
serde = { version = "1.0.219", features = ["derive"] }
serde_json = "1.0.142"
serde = { workspace = true, features = ["derive"] }
serde_json = { workspace = true }