Compare commits

...

4 Commits

15 changed files with 1598 additions and 148 deletions

1
Cargo.lock generated
View File

@@ -251,6 +251,7 @@ dependencies = [
"iana-time-zone",
"js-sys",
"num-traits",
"serde",
"wasm-bindgen",
"windows-link",
]

View File

@@ -7,6 +7,12 @@ members = [
]
resolver = "3"
[workspace.package]
edition = "2024"
version = "0.1.0"
license = "MIT"
rust-version = "1.89"
# Optional: Keep dependency versions consistent across members
[workspace.dependencies]
thiserror = "1.0.69"
@@ -15,7 +21,7 @@ anyhow = "1.0.99"
libc = "0.2.175"
toml = "0.8.23"
serde_json = "1.0.142"
chrono = "0.4.41"
chrono = { version = "0.4.41", features = ["serde"] }
sha2 = "0.10.9"
which = "6.0.3"
tokio = { version = "1.47.1", features = ["rt-multi-thread", "macros"] }
@@ -26,6 +32,14 @@ cliclack = "0.3.6"
clap_complete = "4.5.57"
clap_mangen = "0.2.29"
# Additional shared deps used across members
tracing = "0.1"
tracing-subscriber = { version = "0.3", features = ["fmt", "env-filter"] }
reqwest = { version = "0.12.7", default-features = false, features = ["blocking", "rustls-tls", "gzip", "json"] }
hex = "0.4.3"
tempfile = "3.12.0"
assert_cmd = "2.0.16"
[workspace.lints.rust]
unused_imports = "deny"
dead_code = "warn"

View File

@@ -1,16 +0,0 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2025 <COPYRIGHT HOLDER>. All rights reserved.
fn main() {
// Only run special build steps when gpu-vulkan feature is enabled.
let vulkan_enabled = std::env::var("CARGO_FEATURE_GPU_VULKAN").is_ok();
if !vulkan_enabled {
return;
}
// Placeholder: In a full implementation, we would invoke CMake for whisper.cpp with GGML_VULKAN=1.
// For now, emit a helpful note. Build will proceed; runtime Vulkan backend returns an explanatory error.
println!("cargo:rerun-if-changed=extern/whisper.cpp");
println!(
"cargo:warning=Building with gpu-vulkan: ensure Vulkan SDK/loader are installed. Future versions will compile whisper.cpp via CMake."
);
}

View File

@@ -1,24 +1,24 @@
[package]
name = "polyscribe-cli"
version = "0.1.0"
edition = "2024"
version.workspace = true
edition.workspace = true
[[bin]]
name = "polyscribe"
path = "src/main.rs"
[dependencies]
anyhow = "1.0.99"
clap = { version = "4.5.44", features = ["derive"] }
clap_complete = "4.5.57"
clap_mangen = "0.2.29"
directories = "5.0.1"
serde = { version = "1.0.219", features = ["derive"] }
serde_json = "1.0.142"
tokio = { version = "1.47.1", features = ["rt-multi-thread", "macros", "process", "fs"] }
tracing = "0.1"
tracing-subscriber = { version = "0.3", features = ["fmt", "env-filter"] }
which = "6.0.3"
anyhow = { workspace = true }
clap = { workspace = true, features = ["derive"] }
clap_complete = { workspace = true }
clap_mangen = { workspace = true }
directories = { workspace = true }
serde = { workspace = true, features = ["derive"] }
serde_json = { workspace = true }
tokio = { workspace = true, features = ["rt-multi-thread", "macros", "process", "fs"] }
tracing = { workspace = true }
tracing-subscriber = { workspace = true, features = ["fmt", "env-filter"] }
which = { workspace = true }
polyscribe-core = { path = "../polyscribe-core" }
polyscribe-host = { path = "../polyscribe-host" }
@@ -29,4 +29,4 @@ polyscribe-protocol = { path = "../polyscribe-protocol" }
default = []
[dev-dependencies]
assert_cmd = "2.0.16"
assert_cmd = { workspace = true }

View File

@@ -1,4 +1,4 @@
use clap::{Parser, Subcommand, ValueEnum};
use clap::{Args, Parser, Subcommand, ValueEnum};
use std::path::PathBuf;
#[derive(Debug, Clone, ValueEnum)]
@@ -10,21 +10,33 @@ pub enum GpuBackend {
Vulkan,
}
#[derive(Debug, Clone, Args)]
pub struct OutputOpts {
/// Emit machine-readable JSON to stdout; suppress decorative logs
#[arg(long, global = true, action = clap::ArgAction::SetTrue)]
pub json: bool,
/// Reduce log chatter (errors only unless --json)
#[arg(long, global = true, action = clap::ArgAction::SetTrue)]
pub quiet: bool,
}
#[derive(Debug, Parser)]
#[command(
name = "polyscribe",
version,
about = "PolyScribe local-first transcription and plugins"
about = "PolyScribe local-first transcription and plugins",
propagate_version = true,
arg_required_else_help = true,
)]
pub struct Cli {
/// Global output options
#[command(flatten)]
pub output: OutputOpts,
/// Increase verbosity (-v, -vv)
#[arg(short, long, action = clap::ArgAction::Count)]
pub verbose: u8,
/// Quiet mode (suppresses non-error logs)
#[arg(short, long, default_value_t = false)]
pub quiet: bool,
/// Never prompt for user input (non-interactive mode)
#[arg(long, default_value_t = false)]
pub no_interaction: bool,
@@ -74,7 +86,7 @@ pub enum Commands {
inputs: Vec<PathBuf>,
},
/// Manage Whisper models
/// Manage Whisper GGUF models (Hugging Face)
Models {
#[command(subcommand)]
cmd: ModelsCmd,
@@ -97,14 +109,64 @@ pub enum Commands {
Man,
}
#[derive(Debug, Clone, Parser)]
pub struct ModelCommon {
/// Concurrency for ranged downloads
#[arg(long, default_value_t = 4)]
pub concurrency: usize,
/// Limit download rate in bytes/sec (approximate)
#[arg(long)]
pub limit_rate: Option<u64>,
}
#[derive(Debug, Subcommand)]
pub enum ModelsCmd {
/// Verify or update local models non-interactively
Update,
/// Interactive multi-select downloader
Download,
/// Clear the cached Hugging Face manifest
ClearCache,
/// List installed models (from manifest)
Ls {
#[command(flatten)]
common: ModelCommon,
},
/// Add or update a model
Add {
/// Hugging Face repo, e.g. ggml-org/models
repo: String,
/// File name in repo (e.g., gguf-tiny-q4_0.bin)
file: String,
#[command(flatten)]
common: ModelCommon,
},
/// Remove a model by alias
Rm {
alias: String,
#[command(flatten)]
common: ModelCommon,
},
/// Verify model file integrity by alias
Verify {
alias: String,
#[command(flatten)]
common: ModelCommon,
},
/// Update all models (HEAD + ETag; skip if unchanged)
Update {
#[command(flatten)]
common: ModelCommon,
},
/// Garbage-collect unreferenced files and stale manifest entries
Gc {
#[command(flatten)]
common: ModelCommon,
},
/// Search a repo for GGUF files
Search {
/// Hugging Face repo, e.g. ggml-org/models
repo: String,
/// Optional substring to filter filenames
#[arg(long)]
query: Option<String>,
#[command(flatten)]
common: ModelCommon,
},
}
#[derive(Debug, Subcommand)]

View File

@@ -1,41 +1,84 @@
mod cli;
mod output;
use anyhow::{Context, Result, anyhow};
use clap::{CommandFactory, Parser};
use cli::{Cli, Commands, GpuBackend, ModelsCmd, PluginsCmd};
use polyscribe_core::models;
use cli::{Cli, Commands, GpuBackend, ModelsCmd, ModelCommon, PluginsCmd};
use output::OutputMode;
use polyscribe_core::model_manager::{ModelManager, Settings, ReqwestClient};
use polyscribe_core::ui;
fn normalized_similarity(a: &str, b: &str) -> f64 {
// simple Levenshtein distance; normalized to [0,1]
let a_bytes = a.as_bytes();
let b_bytes = b.as_bytes();
let n = a_bytes.len();
let m = b_bytes.len();
if n == 0 && m == 0 { return 1.0; }
if n == 0 || m == 0 { return 0.0; }
let mut prev: Vec<usize> = (0..=m).collect();
let mut curr: Vec<usize> = vec![0; m + 1];
for i in 1..=n {
curr[0] = i;
for j in 1..=m {
let cost = if a_bytes[i - 1] == b_bytes[j - 1] { 0 } else { 1 };
curr[j] = (prev[j] + 1)
.min(curr[j - 1] + 1)
.min(prev[j - 1] + cost);
}
std::mem::swap(&mut prev, &mut curr);
}
let dist = prev[m] as f64;
let max_len = n.max(m) as f64;
1.0 - (dist / max_len)
}
fn human_size(bytes: Option<u64>) -> String {
match bytes {
Some(n) => {
let x = n as f64;
const KB: f64 = 1024.0;
const MB: f64 = 1024.0 * KB;
const GB: f64 = 1024.0 * MB;
if x >= GB { format!("{:.2} GiB", x / GB) }
else if x >= MB { format!("{:.2} MiB", x / MB) }
else if x >= KB { format!("{:.2} KiB", x / KB) }
else { format!("{} B", n) }
}
None => "?".to_string(),
}
}
use polyscribe_core::ui::progress::ProgressReporter;
use polyscribe_host::PluginManager;
use tokio::io::AsyncWriteExt;
use tracing_subscriber::EnvFilter;
fn init_tracing(quiet: bool, verbose: u8) {
let log_level = if quiet {
"error"
} else {
match verbose {
0 => "info",
1 => "debug",
_ => "trace",
}
};
let filter = EnvFilter::try_from_default_env().unwrap_or_else(|_| EnvFilter::new(log_level));
fn init_tracing(json_mode: bool, quiet: bool, verbose: u8) {
// In JSON mode, suppress human logs; route errors to stderr only.
let level = if json_mode || quiet { "error" } else { match verbose { 0 => "info", 1 => "debug", _ => "trace" } };
let filter = EnvFilter::try_from_default_env().unwrap_or_else(|_| EnvFilter::new(level));
tracing_subscriber::fmt()
.with_env_filter(filter)
.with_target(false)
.with_level(true)
.with_writer(std::io::stderr)
.compact()
.init();
}
#[tokio::main]
async fn main() -> Result<()> {
fn main() -> Result<()> {
let args = Cli::parse();
init_tracing(args.quiet, args.verbose);
// Determine output mode early for logging and UI configuration
let output_mode = if args.output.json {
OutputMode::Json
} else {
OutputMode::Human { quiet: args.output.quiet }
};
polyscribe_core::set_quiet(args.quiet);
init_tracing(matches!(output_mode, OutputMode::Json), args.output.quiet, args.verbose);
// Suppress decorative UI output in JSON mode as well
polyscribe_core::set_quiet(args.output.quiet || matches!(output_mode, OutputMode::Json));
polyscribe_core::set_no_interaction(args.no_interaction);
polyscribe_core::set_verbose(args.verbose);
polyscribe_core::set_no_progress(args.no_progress);
@@ -71,32 +114,274 @@ async fn main() -> Result<()> {
}
Commands::Models { cmd } => {
match cmd {
ModelsCmd::Update => {
polyscribe_core::ui::info("verifying/updating local models");
tokio::task::spawn_blocking(models::update_local_models)
.await
.map_err(|e| anyhow!("blocking task join error: {e}"))?
.context("updating models")?;
// predictable exit codes
const EXIT_OK: i32 = 0;
const EXIT_NOT_FOUND: i32 = 2;
const EXIT_NETWORK: i32 = 3;
const EXIT_VERIFY_FAILED: i32 = 4;
// const EXIT_NO_CHANGE: i32 = 5; // reserved
let handle_common = |c: &ModelCommon| Settings {
concurrency: c.concurrency.max(1),
limit_rate: c.limit_rate,
..Default::default()
};
let exit = match cmd {
ModelsCmd::Ls { common } => {
let mm: ModelManager<ReqwestClient> = ModelManager::new(handle_common(&common))?;
let list = mm.ls()?;
match output_mode {
OutputMode::Json => {
// Always emit JSON array (possibly empty)
output_mode.print_json(&list);
}
OutputMode::Human { quiet } => {
if list.is_empty() {
if !quiet { println!("No models installed."); }
} else {
if !quiet { println!("Model (Repo)"); }
for r in list {
if !quiet { println!("{} ({})", r.file, r.repo); }
}
}
}
}
EXIT_OK
}
ModelsCmd::Download => {
polyscribe_core::ui::info("interactive model selection and download");
tokio::task::spawn_blocking(models::run_interactive_model_downloader)
.await
.map_err(|e| anyhow!("blocking task join error: {e}"))?
.context("running downloader")?;
polyscribe_core::ui::success("Model download complete.");
ModelsCmd::Add { repo, file, common } => {
let settings = handle_common(&common);
let mm: ModelManager<ReqwestClient> = ModelManager::new(settings.clone())?;
// Derive an alias automatically from repo and file
fn derive_alias(repo: &str, file: &str) -> String {
use std::path::Path;
let repo_tail = repo.rsplit('/').next().unwrap_or(repo);
let stem = Path::new(file)
.file_stem()
.and_then(|s| s.to_str())
.unwrap_or(file);
format!("{}-{}", repo_tail, stem)
}
let alias = derive_alias(&repo, &file);
match mm.add_or_update(&alias, &repo, &file) {
Ok(rec) => {
match output_mode {
OutputMode::Json => output_mode.print_json(&rec),
OutputMode::Human { quiet } => {
if !quiet { println!("installed: {} -> {}/{}", alias, repo, rec.file); }
}
}
EXIT_OK
}
Err(e) => {
// On not found or similar errors, try suggesting close matches interactively
if matches!(output_mode, OutputMode::Json) || polyscribe_core::is_no_interaction() {
match output_mode {
OutputMode::Json => {
// Emit error JSON object
#[derive(serde::Serialize)]
struct ErrObj<'a> { error: &'a str }
let eo = ErrObj { error: &e.to_string() };
output_mode.print_json(&eo);
}
_ => { eprintln!("error: {e}"); }
}
EXIT_NOT_FOUND
} else {
ui::warn(format!("{}", e));
ui::info("Searching for similar model filenames…");
match polyscribe_core::model_manager::search_repo(&repo, None) {
Ok(mut files) => {
if files.is_empty() {
ui::warn("No files found in repository.");
EXIT_NOT_FOUND
} else {
// rank by similarity
files.sort_by(|a, b| normalized_similarity(&file, b)
.partial_cmp(&normalized_similarity(&file, a))
.unwrap_or(std::cmp::Ordering::Equal));
let top: Vec<String> = files.into_iter().take(5).collect();
if top.is_empty() {
EXIT_NOT_FOUND
} else if top.len() == 1 {
let cand = &top[0];
// Fetch repo size list once
let size_map: std::collections::HashMap<String, Option<u64>> =
polyscribe_core::model_manager::list_repo_files_with_meta(&repo)
.unwrap_or_default()
.into_iter().collect();
let mut size = size_map.get(cand).cloned().unwrap_or(None);
if size.is_none() {
size = polyscribe_core::model_manager::head_len_for_file(&repo, cand);
}
let local_files: std::collections::HashSet<String> = mm.ls()?.into_iter().map(|r| r.file).collect();
let is_local = local_files.contains(cand);
let label = format!("{} [{}]{}", cand, human_size(size), if is_local { " (local)" } else { "" });
let ok = ui::prompt_confirm(&format!("Did you mean {}?", label), true)
.unwrap_or(false);
if !ok { EXIT_NOT_FOUND } else {
let mm2: ModelManager<ReqwestClient> = ModelManager::new(settings)?;
let alias2 = derive_alias(&repo, cand);
match mm2.add_or_update(&alias2, &repo, cand) {
Ok(rec) => {
match output_mode {
OutputMode::Json => output_mode.print_json(&rec),
OutputMode::Human { quiet } => { if !quiet { println!("installed: {} -> {}/{}", alias2, repo, rec.file); } }
}
EXIT_OK
}
Err(e2) => { eprintln!("error: {e2}"); EXIT_NETWORK }
}
}
} else {
let opts: Vec<String> = top;
let local_files: std::collections::HashSet<String> = mm.ls()?.into_iter().map(|r| r.file).collect();
// Enrich labels with size and local tag using a single API call
let size_map: std::collections::HashMap<String, Option<u64>> =
polyscribe_core::model_manager::list_repo_files_with_meta(&repo)
.unwrap_or_default()
.into_iter().collect();
let mut labels_owned: Vec<String> = Vec::new();
for f in &opts {
let mut size = size_map.get(f).cloned().unwrap_or(None);
if size.is_none() {
size = polyscribe_core::model_manager::head_len_for_file(&repo, f);
}
let is_local = local_files.contains(f);
let suffix = if is_local { " (local)" } else { "" };
labels_owned.push(format!("{} [{}]{}", f, human_size(size), suffix));
}
let labels: Vec<&str> = labels_owned.iter().map(|s| s.as_str()).collect();
match ui::prompt_select("Pick a model", &labels) {
Ok(idx) => {
let chosen = &opts[idx];
let mm2: ModelManager<ReqwestClient> = ModelManager::new(settings)?;
let alias2 = derive_alias(&repo, chosen);
match mm2.add_or_update(&alias2, &repo, chosen) {
Ok(rec) => {
match output_mode {
OutputMode::Json => output_mode.print_json(&rec),
OutputMode::Human { quiet } => { if !quiet { println!("installed: {} -> {}/{}", alias2, repo, rec.file); } }
}
EXIT_OK
}
Err(e2) => { eprintln!("error: {e2}"); EXIT_NETWORK }
}
}
Err(_) => EXIT_NOT_FOUND,
}
}
}
}
Err(e2) => {
eprintln!("error: {}", e2);
EXIT_NETWORK
}
}
}
}
}
}
ModelsCmd::ClearCache => {
polyscribe_core::ui::info("clearing manifest cache");
tokio::task::spawn_blocking(models::clear_manifest_cache)
.await
.map_err(|e| anyhow!("blocking task join error: {e}"))?
.context("clearing cache")?;
polyscribe_core::ui::success("Manifest cache cleared.");
ModelsCmd::Rm { alias, common } => {
let mm: ModelManager<ReqwestClient> = ModelManager::new(handle_common(&common))?;
let ok = mm.rm(&alias)?;
match output_mode {
OutputMode::Json => {
#[derive(serde::Serialize)]
struct R { removed: bool }
output_mode.print_json(&R { removed: ok });
}
OutputMode::Human { quiet } => {
if !quiet { println!("{}", if ok { "removed" } else { "not found" }); }
}
}
if ok { EXIT_OK } else { EXIT_NOT_FOUND }
}
}
Ok(())
ModelsCmd::Verify { alias, common } => {
let mm: ModelManager<ReqwestClient> = ModelManager::new(handle_common(&common))?;
let found = mm.ls()?.into_iter().any(|r| r.alias == alias);
if !found {
match output_mode {
OutputMode::Json => {
#[derive(serde::Serialize)]
struct R<'a> { ok: bool, error: &'a str }
output_mode.print_json(&R { ok: false, error: "not found" });
}
OutputMode::Human { quiet } => { if !quiet { println!("not found"); } }
}
EXIT_NOT_FOUND
} else {
let ok = mm.verify(&alias)?;
match output_mode {
OutputMode::Json => {
#[derive(serde::Serialize)]
struct R { ok: bool }
output_mode.print_json(&R { ok });
}
OutputMode::Human { quiet } => { if !quiet { println!("{}", if ok { "ok" } else { "corrupt" }); } }
}
if ok { EXIT_OK } else { EXIT_VERIFY_FAILED }
}
}
ModelsCmd::Update { common } => {
let mm: ModelManager<ReqwestClient> = ModelManager::new(handle_common(&common))?;
let mut rc = EXIT_OK;
for rec in mm.ls()? {
match mm.add_or_update(&rec.alias, &rec.repo, &rec.file) {
Ok(_) => {}
Err(e) => {
rc = EXIT_NETWORK;
match output_mode {
OutputMode::Json => {
#[derive(serde::Serialize)]
struct R<'a> { alias: &'a str, error: String }
output_mode.print_json(&R { alias: &rec.alias, error: e.to_string() });
}
_ => { eprintln!("update {}: {e}", rec.alias); }
}
}
}
}
rc
}
ModelsCmd::Gc { common } => {
let mm: ModelManager<ReqwestClient> = ModelManager::new(handle_common(&common))?;
let (files_removed, entries_removed) = mm.gc()?;
match output_mode {
OutputMode::Json => {
#[derive(serde::Serialize)]
struct R { files_removed: usize, entries_removed: usize }
output_mode.print_json(&R { files_removed, entries_removed });
}
OutputMode::Human { quiet } => { if !quiet { println!("files_removed={} entries_removed={}", files_removed, entries_removed); } }
}
EXIT_OK
}
ModelsCmd::Search { repo, query, common } => {
let res = polyscribe_core::model_manager::search_repo(&repo, query.as_deref());
match res {
Ok(files) => {
match output_mode {
OutputMode::Json => output_mode.print_json(&files),
OutputMode::Human { quiet } => { for f in files { if !quiet { println!("{}", f); } } }
}
EXIT_OK
}
Err(e) => {
match output_mode {
OutputMode::Json => {
#[derive(serde::Serialize)]
struct R { error: String }
output_mode.print_json(&R { error: e.to_string() });
}
_ => { eprintln!("error: {e}"); }
}
EXIT_NETWORK
}
}
}
};
std::process::exit(exit);
}
Commands::Plugins { cmd } => {
@@ -123,27 +408,35 @@ async fn main() -> Result<()> {
command,
json,
} => {
let payload = json.unwrap_or_else(|| "{}".to_string());
let mut child = plugin_manager
.spawn(&name, &command)
.with_context(|| format!("spawning plugin {name} {command}"))?;
// Use a local Tokio runtime only for this async path
let rt = tokio::runtime::Builder::new_multi_thread()
.enable_all()
.build()
.context("building tokio runtime")?;
if let Some(mut stdin) = child.stdin.take() {
stdin
.write_all(payload.as_bytes())
.await
.context("writing JSON payload to plugin stdin")?;
}
rt.block_on(async {
let payload = json.unwrap_or_else(|| "{}".to_string());
let mut child = plugin_manager
.spawn(&name, &command)
.with_context(|| format!("spawning plugin {name} {command}"))?;
let status = plugin_manager.forward_stdio(&mut child).await?;
if !status.success() {
polyscribe_core::ui::error(format!(
"plugin returned non-zero exit code: {}",
status
));
return Err(anyhow!("plugin failed"));
}
Ok(())
if let Some(mut stdin) = child.stdin.take() {
stdin
.write_all(payload.as_bytes())
.await
.context("writing JSON payload to plugin stdin")?;
}
let status = plugin_manager.forward_stdio(&mut child).await?;
if !status.success() {
polyscribe_core::ui::error(format!(
"plugin returned non-zero exit code: {}",
status
));
return Err(anyhow!("plugin failed"));
}
Ok(())
})
}
}
}

View File

@@ -0,0 +1,36 @@
use std::io::{self, Write};
#[derive(Clone, Debug)]
pub enum OutputMode {
Json,
Human { quiet: bool },
}
impl OutputMode {
pub fn is_quiet(&self) -> bool {
matches!(self, OutputMode::Json) || matches!(self, OutputMode::Human { quiet: true })
}
pub fn print_json<T: serde::Serialize>(&self, v: &T) {
if let OutputMode::Json = self {
// Write compact JSON to stdout without prefixes
// and ensure a trailing newline for CLI ergonomics
let s = serde_json::to_string(v).unwrap_or_else(|e| format!("\"JSON_ERROR:{}\"", e));
println!("{}", s);
}
}
pub fn print_line(&self, s: impl AsRef<str>) {
match self {
OutputMode::Json => {
// Suppress human lines in JSON mode
}
OutputMode::Human { quiet } => {
if !quiet {
let _ = writeln!(io::stdout(), "{}", s.as_ref());
}
}
}
}
}

View File

@@ -0,0 +1,42 @@
use assert_cmd::cargo::cargo_bin;
use std::process::Command;
fn bin() -> std::path::PathBuf { cargo_bin("polyscribe") }
#[test]
fn models_help_shows_global_output_flags() {
let out = Command::new(bin())
.args(["models", "--help"]) // subcommand help
.output()
.expect("failed to run polyscribe models --help");
assert!(out.status.success(), "help exited non-zero: {:?}", out.status);
let stdout = String::from_utf8(out.stdout).expect("stdout not utf-8");
assert!(stdout.contains("--json"), "--json not shown in help: {stdout}");
assert!(stdout.contains("--quiet"), "--quiet not shown in help: {stdout}");
}
#[test]
fn models_version_contains_pkg_version() {
let out = Command::new(bin())
.args(["models", "--version"]) // propagate_version
.output()
.expect("failed to run polyscribe models --version");
assert!(out.status.success(), "version exited non-zero: {:?}", out.status);
let stdout = String::from_utf8(out.stdout).expect("stdout not utf-8");
let want = env!("CARGO_PKG_VERSION");
assert!(stdout.contains(want), "version output missing {want}: {stdout}");
}
#[test]
fn models_ls_json_quiet_emits_pure_json() {
let out = Command::new(bin())
.args(["models", "ls", "--json", "--quiet"]) // global flags
.output()
.expect("failed to run polyscribe models ls --json --quiet");
assert!(out.status.success(), "ls exited non-zero: {:?}", out.status);
let stdout = String::from_utf8(out.stdout).expect("stdout not utf-8");
serde_json::from_str::<serde_json::Value>(stdout.trim()).expect("stdout is not valid JSON");
// Expect no extra logs on stdout; stderr should be empty in success path
assert!(out.stderr.is_empty(), "expected no stderr noise");
}

View File

@@ -1,22 +1,22 @@
[package]
name = "polyscribe-core"
version = "0.1.0"
edition = "2024"
version.workspace = true
edition.workspace = true
[dependencies]
anyhow = "1.0.99"
thiserror = "1.0.69"
serde = { version = "1.0.219", features = ["derive"] }
serde_json = "1.0.142"
toml = "0.8.23"
directories = "5.0.1"
chrono = "0.4.41"
libc = "0.2.175"
whisper-rs = "0.14.3"
anyhow = { workspace = true }
thiserror = { workspace = true }
serde = { workspace = true, features = ["derive"] }
serde_json = { workspace = true }
toml = { workspace = true }
directories = { workspace = true }
chrono = { workspace = true }
libc = { workspace = true }
whisper-rs = { workspace = true }
# UI and progress
cliclack = { workspace = true }
# New: HTTP downloads + hashing
reqwest = { version = "0.12.7", default-features = false, features = ["blocking", "rustls-tls", "gzip", "json"] }
sha2 = "0.10.8"
hex = "0.4.3"
tempfile = "3.12.0"
# HTTP downloads + hashing
reqwest = { workspace = true }
sha2 = { workspace = true }
hex = { workspace = true }
tempfile = { workspace = true }

View File

@@ -214,6 +214,8 @@ pub fn render_srt(entries: &[OutputEntry]) -> String {
srt
}
pub mod model_manager;
pub fn models_dir_path() -> PathBuf {
if let Ok(env_val) = env::var("POLYSCRIBE_MODELS_DIR") {
let env_path = PathBuf::from(env_val);

View File

@@ -0,0 +1,893 @@
// SPDX-License-Identifier: MIT
use crate::prelude::*;
use crate::ui::BytesProgress;
use anyhow::{anyhow, Context};
use chrono::{DateTime, Utc};
use reqwest::blocking::Client;
use reqwest::header::{
ACCEPT_RANGES, AUTHORIZATION, CONTENT_LENGTH, ETAG, IF_NONE_MATCH, LAST_MODIFIED, RANGE,
};
use serde::{Deserialize, Serialize};
use sha2::{Digest, Sha256};
use std::cmp::min;
use std::collections::BTreeMap;
use std::fs::{self, File, OpenOptions};
use std::io::{Read, Seek, SeekFrom, Write};
use std::path::{Path, PathBuf};
use std::sync::{mpsc, Arc, Mutex};
use std::thread;
use std::time::Duration;
const DEFAULT_CHUNK_SIZE: u64 = 8 * 1024 * 1024; // 8 MiB
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
pub struct ModelRecord {
pub alias: String,
pub repo: String,
pub file: String,
pub revision: Option<String>, // ETag or commit hash
pub sha256: Option<String>,
pub size_bytes: Option<u64>,
pub quant: Option<String>,
pub installed_at: Option<DateTime<Utc>>,
pub last_used: Option<DateTime<Utc>>,
}
#[derive(Debug, Clone, Serialize, Deserialize, Default)]
pub struct Manifest {
pub models: BTreeMap<String, ModelRecord>, // key = alias
}
#[derive(Debug, Clone)]
pub struct Settings {
pub concurrency: usize,
pub limit_rate: Option<u64>, // bytes/sec
pub chunk_size: u64,
}
impl Default for Settings {
fn default() -> Self {
Self {
concurrency: 4,
limit_rate: None,
chunk_size: DEFAULT_CHUNK_SIZE,
}
}
}
#[derive(Debug, Clone)]
pub struct Paths {
pub cache_dir: PathBuf, // $XDG_CACHE_HOME/polyscribe/models
pub config_path: PathBuf, // $XDG_CONFIG_HOME/polyscribe/models.json
}
impl Paths {
pub fn resolve() -> Result<Self> {
if let Ok(over) = std::env::var("POLYSCRIBE_CACHE_DIR") {
if !over.is_empty() {
let cache_dir = PathBuf::from(over).join("models");
let config_path = std::env::var("POLYSCRIBE_CONFIG_DIR")
.map(|p| PathBuf::from(p).join("models.json"))
.unwrap_or_else(|_| default_config_path());
return Ok(Self {
cache_dir,
config_path,
});
}
}
let cache_dir = default_cache_dir();
let config_path = default_config_path();
Ok(Self {
cache_dir,
config_path,
})
}
}
fn default_cache_dir() -> PathBuf {
if let Ok(xdg) = std::env::var("XDG_CACHE_HOME") {
if !xdg.is_empty() {
return PathBuf::from(xdg).join("polyscribe").join("models");
}
}
if let Ok(home) = std::env::var("HOME") {
if !home.is_empty() {
return PathBuf::from(home)
.join(".cache")
.join("polyscribe")
.join("models");
}
}
PathBuf::from("models")
}
fn default_config_path() -> PathBuf {
if let Ok(xdg) = std::env::var("XDG_CONFIG_HOME") {
if !xdg.is_empty() {
return PathBuf::from(xdg).join("polyscribe").join("models.json");
}
}
if let Ok(home) = std::env::var("HOME") {
if !home.is_empty() {
return PathBuf::from(home)
.join(".config")
.join("polyscribe")
.join("models.json");
}
}
PathBuf::from("models.json")
}
pub trait HttpClient: Send + Sync {
fn head(&self, url: &str, etag: Option<&str>) -> Result<HeadMeta>;
fn get_range(&self, url: &str, start: u64, end_inclusive: u64) -> Result<Vec<u8>>;
fn get_whole_to(&self, url: &str, writer: &mut dyn Write) -> Result<()>;
fn get_from_to(&self, url: &str, start: u64, writer: &mut dyn Write) -> Result<()>;
}
#[derive(Debug, Clone)]
pub struct ReqwestClient {
client: Client,
token: Option<String>,
}
impl ReqwestClient {
pub fn new() -> Result<Self> {
let token = std::env::var("HF_TOKEN").ok().filter(|s| !s.is_empty());
let client = Client::builder()
.user_agent(crate::config::ConfigService::user_agent())
.build()?;
Ok(Self { client, token })
}
fn auth(&self, mut req: reqwest::blocking::RequestBuilder) -> reqwest::blocking::RequestBuilder {
if let Some(t) = &self.token {
req = req.header(AUTHORIZATION, format!("Bearer {}", t));
}
req
}
}
#[derive(Debug, Clone)]
pub struct HeadMeta {
pub len: Option<u64>,
pub etag: Option<String>,
pub last_modified: Option<String>,
pub accept_ranges: bool,
pub not_modified: bool,
pub status: u16,
}
impl HttpClient for ReqwestClient {
fn head(&self, url: &str, etag: Option<&str>) -> Result<HeadMeta> {
let mut req = self.client.head(url);
if let Some(e) = etag {
req = req.header(IF_NONE_MATCH, format!("\"{}\"", e));
}
let resp = self.auth(req).send()?;
let status = resp.status().as_u16();
if status == 304 {
return Ok(HeadMeta {
len: None,
etag: etag.map(|s| s.to_string()),
last_modified: None,
accept_ranges: true,
not_modified: true,
status,
});
}
let len = resp
.headers()
.get(CONTENT_LENGTH)
.and_then(|v| v.to_str().ok())
.and_then(|s| s.parse::<u64>().ok());
let etag = resp
.headers()
.get(ETAG)
.and_then(|v| v.to_str().ok())
.map(|s| s.trim_matches('"').to_string());
let last_modified = resp
.headers()
.get(LAST_MODIFIED)
.and_then(|v| v.to_str().ok())
.map(|s| s.to_string());
let accept_ranges = resp
.headers()
.get(ACCEPT_RANGES)
.and_then(|v| v.to_str().ok())
.map(|s| s.to_ascii_lowercase().contains("bytes"))
.unwrap_or(false);
Ok(HeadMeta {
len,
etag,
last_modified,
accept_ranges,
not_modified: false,
status,
})
}
fn get_range(&self, url: &str, start: u64, end_inclusive: u64) -> Result<Vec<u8>> {
let range_val = format!("bytes={}-{}", start, end_inclusive);
let resp = self
.auth(self.client.get(url))
.header(RANGE, range_val)
.send()?;
if !resp.status().is_success() && resp.status().as_u16() != 206 {
return Err(anyhow!("HTTP {} for ranged GET", resp.status()).into());
}
let mut buf = Vec::new();
let mut r = resp;
r.copy_to(&mut buf)?;
Ok(buf)
}
fn get_whole_to(&self, url: &str, writer: &mut dyn Write) -> Result<()> {
let resp = self.auth(self.client.get(url)).send()?;
if !resp.status().is_success() {
return Err(anyhow!("HTTP {} for GET", resp.status()).into());
}
let mut r = resp;
r.copy_to(writer)?;
Ok(())
}
fn get_from_to(&self, url: &str, start: u64, writer: &mut dyn Write) -> Result<()> {
let mut req = self.auth(self.client.get(url));
if start > 0 {
req = req.header(RANGE, format!("bytes={}-", start));
}
let resp = req.send()?;
if !resp.status().is_success() && resp.status().as_u16() != 206 {
return Err(anyhow!("HTTP {} for ranged GET from {}", resp.status(), start).into());
}
let mut r = resp;
r.copy_to(writer)?;
Ok(())
}
}
pub struct ModelManager<C: HttpClient = ReqwestClient> {
pub paths: Paths,
pub settings: Settings,
client: Arc<C>,
}
impl<C: HttpClient + 'static> ModelManager<C> {
pub fn new_with_client(client: C, settings: Settings) -> Result<Self> {
Ok(Self {
paths: Paths::resolve()?,
settings,
client: Arc::new(client),
})
}
pub fn new(settings: Settings) -> Result<Self>
where
C: Default,
{
Ok(Self {
paths: Paths::resolve()?,
settings,
client: Arc::new(C::default()),
})
}
fn load_manifest(&self) -> Result<Manifest> {
let p = &self.paths.config_path;
if !p.exists() {
return Ok(Manifest::default());
}
let file = File::open(p).with_context(|| format!("open manifest {}", p.display()))?;
let m: Manifest = serde_json::from_reader(file).context("parse manifest")?;
Ok(m)
}
fn save_manifest(&self, m: &Manifest) -> Result<()> {
let p = &self.paths.config_path;
if let Some(dir) = p.parent() {
fs::create_dir_all(dir)
.with_context(|| format!("create config dir {}", dir.display()))?;
}
let tmp = p.with_extension("json.tmp");
let f = OpenOptions::new()
.create(true)
.write(true)
.truncate(true)
.open(&tmp)?;
serde_json::to_writer_pretty(f, m).context("serialize manifest")?;
fs::rename(&tmp, p).with_context(|| format!("atomic rename {} -> {}", tmp.display(), p.display()))?;
Ok(())
}
fn model_path(&self, file: &str) -> PathBuf {
self.paths.cache_dir.join(file)
}
fn compute_sha256(path: &Path) -> Result<String> {
let mut f = File::open(path)?;
let mut hasher = Sha256::new();
let mut buf = [0u8; 64 * 1024];
loop {
let n = f.read(&mut buf)?;
if n == 0 {
break;
}
hasher.update(&buf[..n]);
}
Ok(format!("{:x}", hasher.finalize()))
}
pub fn ls(&self) -> Result<Vec<ModelRecord>> {
let m = self.load_manifest()?;
Ok(m.models.values().cloned().collect())
}
pub fn rm(&self, alias: &str) -> Result<bool> {
let mut m = self.load_manifest()?;
if let Some(rec) = m.models.remove(alias) {
let p = self.model_path(&rec.file);
let _ = fs::remove_file(&p);
self.save_manifest(&m)?;
return Ok(true);
}
Ok(false)
}
pub fn verify(&self, alias: &str) -> Result<bool> {
let m = self.load_manifest()?;
let Some(rec) = m.models.get(alias) else { return Ok(false) };
let p = self.model_path(&rec.file);
if !p.exists() { return Ok(false); }
if let Some(expected) = &rec.sha256 {
let actual = Self::compute_sha256(&p)?;
return Ok(&actual == expected);
}
Ok(true)
}
pub fn gc(&self) -> Result<(usize, usize)> {
// Remove files not referenced by manifest; also drop manifest entries whose file is missing
fs::create_dir_all(&self.paths.cache_dir).ok();
let mut m = self.load_manifest()?;
let mut referenced = BTreeMap::new();
for (alias, rec) in &m.models {
referenced.insert(rec.file.clone(), alias.clone());
}
let mut removed_files = 0usize;
if let Ok(rd) = fs::read_dir(&self.paths.cache_dir) {
for ent in rd.flatten() {
let p = ent.path();
if p.is_file() {
let fname = p.file_name().and_then(|s| s.to_str()).unwrap_or("");
if !referenced.contains_key(fname) {
let _ = fs::remove_file(&p);
removed_files += 1;
}
}
}
}
m.models.retain(|_, rec| self.model_path(&rec.file).exists());
let removed_entries = referenced
.keys()
.filter(|f| !self.model_path(f).exists())
.count();
self.save_manifest(&m)?;
Ok((removed_files, removed_entries))
}
pub fn add_or_update(
&self,
alias: &str,
repo: &str,
file: &str,
) -> Result<ModelRecord> {
fs::create_dir_all(&self.paths.cache_dir)
.with_context(|| format!("create cache dir {}", self.paths.cache_dir.display()))?;
let url = format!("https://huggingface.co/{}/resolve/main/{}?download=true", repo, file);
let mut manifest = self.load_manifest()?;
let prev = manifest.models.get(alias).cloned();
let prev_etag = prev.as_ref().and_then(|r| r.revision.clone());
// Fetch remote meta (size/hash) when available to verify the download
let (_api_size, api_sha) = hf_fetch_file_meta(repo, file).unwrap_or((None, None));
let head = self.client.head(&url, prev_etag.as_deref())?;
if head.not_modified {
// up-to-date; ensure record present and touch last_used
let mut rec = prev.ok_or_else(|| anyhow!("not installed yet but server says 304"))?;
rec.last_used = Some(Utc::now());
self.save_touch(&mut manifest, rec.clone())?;
return Ok(rec);
}
// Quick check: if HEAD failed (e.g., 404), report a helpful error before attempting download
if head.status >= 400 {
return Err(anyhow!(
"file not found or inaccessible: repo='{}' file='{}' (HTTP {})\nHint: run `polyscribe models search {} --query {}` to list available files",
repo,
file,
head.status,
repo,
file
).into());
}
let total_len = head.len.ok_or_else(|| anyhow!("missing content-length (HEAD)"))?;
let etag = head.etag.clone();
let dest_tmp = self.model_path(&format!("{}.part", file));
// If a previous cancelled download left a .part file, remove it to avoid clutter/resume.
if dest_tmp.exists() { let _ = fs::remove_file(&dest_tmp); }
// Guard to ensure .part is cleaned up on errors
struct TempGuard { path: PathBuf, armed: bool }
impl TempGuard { fn disarm(&mut self) { self.armed = false; } }
impl Drop for TempGuard {
fn drop(&mut self) {
if self.armed {
let _ = fs::remove_file(&self.path);
}
}
}
let mut _tmp_guard = TempGuard { path: dest_tmp.clone(), armed: true };
let dest_final = self.model_path(file);
// Do not resume after cancellation; start fresh to avoid stale .part files
let start_from = 0u64;
// Open tmp for write
let f = OpenOptions::new().create(true).write(true).read(true).open(&dest_tmp)?;
f.set_len(total_len)?; // pre-allocate for random writes
let f = Arc::new(Mutex::new(f));
// Create progress bar
let mut progress = BytesProgress::start(total_len, &format!("Downloading {}", file), start_from);
// Create work chunks
let chunk_size = self.settings.chunk_size;
let mut chunks = Vec::new();
let mut pos = start_from;
while pos < total_len {
let end = min(total_len - 1, pos + chunk_size - 1);
chunks.push((pos, end));
pos = end + 1;
}
// Attempt concurrent ranged download; on failure, fallback to streaming GET
let mut ranged_failed = false;
if head.accept_ranges && self.settings.concurrency > 1 {
let (work_tx, work_rx) = mpsc::channel::<(u64, u64)>();
let (prog_tx, prog_rx) = mpsc::channel::<u64>();
for ch in chunks {
work_tx.send(ch).unwrap();
}
drop(work_tx);
let rx = Arc::new(Mutex::new(work_rx));
let workers = self.settings.concurrency.max(1);
let mut handles = Vec::new();
for _ in 0..workers {
let rx = rx.clone();
let url = url.clone();
let client = self.client.clone();
let f = f.clone();
let limit = self.settings.limit_rate;
let prog_tx = prog_tx.clone();
let handle = thread::spawn(move || -> Result<()> {
loop {
let next = {
let guard = rx.lock().unwrap();
guard.recv().ok()
};
let Some((start, end)) = next else { break; };
let data = client.get_range(&url, start, end)?;
if let Some(max_bps) = limit {
let dur = Duration::from_secs_f64((data.len() as f64) / (max_bps as f64));
if dur > Duration::from_millis(1) {
thread::sleep(dur);
}
}
let mut guard = f.lock().unwrap();
guard.seek(SeekFrom::Start(start))?;
guard.write_all(&data)?;
let _ = prog_tx.send(data.len() as u64);
}
Ok(())
});
handles.push(handle);
}
drop(prog_tx);
for delta in prog_rx {
progress.inc(delta);
}
let mut ranged_err: Option<anyhow::Error> = None;
for h in handles {
match h.join() {
Ok(Ok(())) => {}
Ok(Err(e)) => { ranged_failed = true; if ranged_err.is_none() { ranged_err = Some(e.into()); } }
Err(_) => { ranged_failed = true; if ranged_err.is_none() { ranged_err = Some(anyhow!("worker panicked")); } }
}
}
} else {
ranged_failed = true;
}
if ranged_failed {
// Restart progress if we are abandoning previous partial state
if start_from > 0 {
progress.stop("retrying as streaming");
progress = BytesProgress::start(total_len, &format!("Downloading {}", file), 0);
}
// Fallback to streaming GET; try URL with and without ?download=true
let mut try_urls = vec![url.clone()];
if let Some((base, _qs)) = url.split_once('?') {
try_urls.push(base.to_string());
} else {
try_urls.push(format!("{}?download=true", url));
}
// Fresh temp file for streaming
let mut wf = OpenOptions::new().create(true).write(true).truncate(true).open(&dest_tmp)?;
let mut ok = false;
let mut last_err: Option<anyhow::Error> = None;
// Counting writer to update progress inline
struct CountingWriter<'a, 'b> {
inner: &'a mut File,
progress: &'b mut BytesProgress,
}
impl<'a, 'b> Write for CountingWriter<'a, 'b> {
fn write(&mut self, buf: &[u8]) -> std::io::Result<usize> {
let n = self.inner.write(buf)?;
self.progress.inc(n as u64);
Ok(n)
}
fn flush(&mut self) -> std::io::Result<()> { self.inner.flush() }
}
let mut cw = CountingWriter { inner: &mut wf, progress: &mut progress };
for u in try_urls {
// For fallback, stream from scratch to ensure integrity
let res = self.client.get_whole_to(&u, &mut cw);
match res {
Ok(()) => { ok = true; break; }
Err(e) => { last_err = Some(e.into()); }
}
}
if !ok {
if let Some(e) = last_err { return Err(e.into()); }
return Err(anyhow!("download failed (ranged and streaming)").into());
}
}
progress.stop("download complete");
// Verify integrity
let sha = Self::compute_sha256(&dest_tmp)?;
if let Some(expected) = api_sha.as_ref() {
if &sha != expected {
return Err(anyhow!(
"sha256 mismatch (expected {}, got {})",
expected,
sha
).into());
}
} else if prev.as_ref().map(|r| r.file.eq(file)).unwrap_or(false) {
if let Some(expected) = prev.as_ref().and_then(|r| r.sha256.as_ref()) {
if &sha != expected {
return Err(anyhow!("sha256 mismatch").into());
}
}
}
// Atomic rename
fs::rename(&dest_tmp, &dest_final).with_context(|| format!("rename {} -> {}", dest_tmp.display(), dest_final.display()))?;
// Disarm guard; .part has been moved or cleaned
_tmp_guard.disarm();
let rec = ModelRecord {
alias: alias.to_string(),
repo: repo.to_string(),
file: file.to_string(),
revision: etag,
sha256: Some(sha.clone()),
size_bytes: Some(total_len),
quant: infer_quant(file),
installed_at: Some(Utc::now()),
last_used: Some(Utc::now()),
};
self.save_touch(&mut manifest, rec.clone())?;
Ok(rec)
}
fn save_touch(&self, manifest: &mut Manifest, rec: ModelRecord) -> Result<()> {
manifest.models.insert(rec.alias.clone(), rec);
self.save_manifest(manifest)
}
}
fn infer_quant(file: &str) -> Option<String> {
// Try to extract a Q* token, e.g. Q5_K_M from filename
let lower = file.to_ascii_uppercase();
if let Some(pos) = lower.find('Q') {
let tail = &lower[pos..];
let token: String = tail
.chars()
.take_while(|c| c.is_ascii_uppercase() || c.is_ascii_digit() || *c == '_' || *c == '-')
.collect();
if token.len() >= 2 {
return Some(token);
}
}
None
}
impl Default for ReqwestClient {
fn default() -> Self {
Self::new().expect("reqwest client")
}
}
// Hugging Face API types for file metadata
#[derive(Debug, Deserialize)]
struct ApiHfLfs {
oid: Option<String>,
size: Option<u64>,
sha256: Option<String>,
}
#[derive(Debug, Deserialize)]
struct ApiHfFile {
rfilename: String,
size: Option<u64>,
sha256: Option<String>,
lfs: Option<ApiHfLfs>,
}
#[derive(Debug, Deserialize)]
struct ApiHfModelInfo {
siblings: Option<Vec<ApiHfFile>>,
files: Option<Vec<ApiHfFile>>,
}
fn pick_sha_from_file(f: &ApiHfFile) -> Option<String> {
if let Some(s) = &f.sha256 { return Some(s.to_string()); }
if let Some(l) = &f.lfs {
if let Some(s) = &l.sha256 { return Some(s.to_string()); }
if let Some(oid) = &l.oid { return oid.strip_prefix("sha256:").map(|s| s.to_string()); }
}
None
}
fn hf_fetch_file_meta(repo: &str, target: &str) -> Result<(Option<u64>, Option<String>)> {
let token = std::env::var("HF_TOKEN").ok().filter(|s| !s.is_empty());
let client = Client::builder()
.user_agent(crate::config::ConfigService::user_agent())
.build()?;
let base = crate::config::ConfigService::hf_api_base_for(repo);
let urls = [base.clone(), format!("{}?expand=files", base)];
for url in urls {
let mut req = client.get(&url);
if let Some(t) = &token { req = req.header(AUTHORIZATION, format!("Bearer {}", t)); }
let resp = req.send()?;
if !resp.status().is_success() { continue; }
let info: ApiHfModelInfo = resp.json()?;
let list = info.files.or(info.siblings).unwrap_or_default();
for f in list {
let name = f.rfilename.rsplit('/').next().unwrap_or(&f.rfilename);
if name.eq_ignore_ascii_case(target) {
let sz = f.size.or_else(|| f.lfs.as_ref().and_then(|l| l.size));
let sha = pick_sha_from_file(&f);
return Ok((sz, sha));
}
}
}
Err(anyhow!("file not found in HF API").into())
}
/// Fetch remote metadata (size, sha256) for a single file in a HF repo.
pub fn fetch_file_meta(repo: &str, file: &str) -> Result<(Option<u64>, Option<String>)> {
hf_fetch_file_meta(repo, file)
}
/// Search a Hugging Face repo for GGUF/BIN files via API. Returns file names only.
pub fn search_repo(repo: &str, query: Option<&str>) -> Result<Vec<String>> {
let token = std::env::var("HF_TOKEN").ok().filter(|s| !s.is_empty());
let client = Client::builder()
.user_agent(crate::config::ConfigService::user_agent())
.build()?;
let base = crate::config::ConfigService::hf_api_base_for(repo);
let mut urls = vec![base.clone(), format!("{}?expand=files", base)];
let mut files = Vec::<String>::new();
for url in urls.drain(..) {
let mut req = client.get(&url);
if let Some(t) = &token { req = req.header(AUTHORIZATION, format!("Bearer {}", t)); }
let resp = req.send()?;
if !resp.status().is_success() { continue; }
let info: ApiHfModelInfo = resp.json()?;
let list = info.files.or(info.siblings).unwrap_or_default();
for f in list {
if f.rfilename.ends_with(".gguf") || f.rfilename.ends_with(".bin") {
let name = f.rfilename.rsplit('/').next().unwrap_or(&f.rfilename).to_string();
if !files.contains(&name) { files.push(name); }
}
}
if !files.is_empty() { break; }
}
if let Some(q) = query { let qlc = q.to_ascii_lowercase(); files.retain(|f| f.to_ascii_lowercase().contains(&qlc)); }
files.sort();
Ok(files)
}
/// List repo files with optional size metadata for GGUF/BIN entries.
pub fn list_repo_files_with_meta(repo: &str) -> Result<Vec<(String, Option<u64>)>> {
let token = std::env::var("HF_TOKEN").ok().filter(|s| !s.is_empty());
let client = Client::builder()
.user_agent(crate::config::ConfigService::user_agent())
.build()?;
let base = crate::config::ConfigService::hf_api_base_for(repo);
for url in [base.clone(), format!("{}?expand=files", base)] {
let mut req = client.get(&url);
if let Some(t) = &token { req = req.header(AUTHORIZATION, format!("Bearer {}", t)); }
let resp = req.send()?;
if !resp.status().is_success() { continue; }
let info: ApiHfModelInfo = resp.json()?;
let list = info.files.or(info.siblings).unwrap_or_default();
let mut out = Vec::new();
for f in list {
if !(f.rfilename.ends_with(".gguf") || f.rfilename.ends_with(".bin")) { continue; }
let name = f.rfilename.rsplit('/').next().unwrap_or(&f.rfilename).to_string();
let size = f.size.or_else(|| f.lfs.as_ref().and_then(|l| l.size));
out.push((name, size));
}
if !out.is_empty() { return Ok(out); }
}
Ok(Vec::new())
}
/// Fallback: HEAD request for a single file to retrieve Content-Length (size).
pub fn head_len_for_file(repo: &str, file: &str) -> Option<u64> {
let token = std::env::var("HF_TOKEN").ok().filter(|s| !s.is_empty());
let client = Client::builder()
.user_agent(crate::config::ConfigService::user_agent())
.build().ok()?;
let mut urls = Vec::new();
urls.push(format!("https://huggingface.co/{}/resolve/main/{}?download=true", repo, file));
urls.push(format!("https://huggingface.co/{}/resolve/main/{}", repo, file));
for url in urls {
let mut req = client.head(&url);
if let Some(t) = &token { req = req.header(AUTHORIZATION, format!("Bearer {}", t)); }
if let Ok(resp) = req.send() {
if resp.status().is_success() {
if let Some(len) = resp.headers().get(CONTENT_LENGTH)
.and_then(|v| v.to_str().ok())
.and_then(|s| s.parse::<u64>().ok())
{ return Some(len); }
}
}
}
None
}
#[cfg(test)]
mod tests {
use super::*;
use std::env;
use tempfile::TempDir;
#[derive(Clone)]
struct StubHttp {
data: Arc<Vec<u8>>,
etag: Arc<Option<String>>,
accept_ranges: bool,
}
impl HttpClient for StubHttp {
fn head(&self, _url: &str, etag: Option<&str>) -> Result<HeadMeta> {
let not_modified = etag.is_some() && self.etag.as_ref().as_deref() == etag;
Ok(HeadMeta {
len: Some(self.data.len() as u64),
etag: self.etag.as_ref().clone(),
last_modified: None,
accept_ranges: self.accept_ranges,
not_modified,
status: if not_modified { 304 } else { 200 },
})
}
fn get_range(&self, _url: &str, start: u64, end_inclusive: u64) -> Result<Vec<u8>> {
let s = start as usize;
let e = (end_inclusive as usize) + 1;
Ok(self.data[s..e].to_vec())
}
fn get_whole_to(&self, _url: &str, writer: &mut dyn Write) -> Result<()> {
writer.write_all(&self.data)?;
Ok(())
}
fn get_from_to(&self, _url: &str, start: u64, writer: &mut dyn Write) -> Result<()> {
let s = start as usize;
writer.write_all(&self.data[s..])?;
Ok(())
}
}
fn setup_env(cache: &Path, cfg: &Path) {
unsafe {
env::set_var("POLYSCRIBE_CACHE_DIR", cache.to_string_lossy().to_string());
env::set_var(
"POLYSCRIBE_CONFIG_DIR",
cfg.parent().unwrap().to_string_lossy().to_string(),
);
}
}
#[test]
fn test_manifest_roundtrip() {
let temp = TempDir::new().unwrap();
let cache = temp.path().join("cache");
let cfg = temp.path().join("config").join("models.json");
setup_env(&cache, &cfg);
let client = StubHttp {
data: Arc::new(vec![0u8; 1024]),
etag: Arc::new(Some("etag123".into())),
accept_ranges: true,
};
let mm: ModelManager<StubHttp> = ModelManager::new_with_client(client, Settings::default()).unwrap();
let m = mm.load_manifest().unwrap();
assert!(m.models.is_empty());
let rec = ModelRecord {
alias: "tiny".into(),
repo: "foo/bar".into(),
file: "gguf-tiny.bin".into(),
revision: Some("etag123".into()),
sha256: None,
size_bytes: None,
quant: None,
installed_at: None,
last_used: None,
};
let mut m2 = Manifest::default();
mm.save_touch(&mut m2, rec.clone()).unwrap();
let m3 = mm.load_manifest().unwrap();
assert!(m3.models.contains_key("tiny"));
}
#[test]
fn test_add_verify_update_gc() {
let temp = TempDir::new().unwrap();
let cache = temp.path().join("cache");
let cfg_dir = temp.path().join("config");
let cfg = cfg_dir.join("models.json");
setup_env(&cache, &cfg);
let data = (0..1024 * 1024u32).flat_map(|i| i.to_le_bytes()).collect::<Vec<u8>>();
let etag = Some("abc123".to_string());
let client = StubHttp { data: Arc::new(data), etag: Arc::new(etag), accept_ranges: true };
let mm: ModelManager<StubHttp> = ModelManager::new_with_client(client.clone(), Settings{ concurrency: 3, ..Default::default() }).unwrap();
// add
let rec = mm.add_or_update("tiny", "gguf/models", "gguf-tiny-q4_0.bin").unwrap();
assert_eq!(rec.alias, "tiny");
assert!(mm.verify("tiny").unwrap());
// update (304)
let rec2 = mm.add_or_update("tiny", "gguf/models", "gguf-tiny-q4_0.bin").unwrap();
assert_eq!(rec2.alias, "tiny");
// gc (nothing to remove)
let (files_removed, entries_removed) = mm.gc().unwrap();
assert_eq!(files_removed, 0);
assert_eq!(entries_removed, 0);
// rm
assert!(mm.rm("tiny").unwrap());
assert!(!mm.rm("tiny").unwrap());
}
}

View File

@@ -792,6 +792,13 @@ fn download_with_progress(dest_path: &Path, entry: &ModelEntry) -> Result<()> {
}
let part_path = dest_path.with_extension("part");
// Guard to cleanup .part on errors
struct TempGuard { path: std::path::PathBuf, armed: bool }
impl TempGuard { fn disarm(&mut self) { self.armed = false; } }
impl Drop for TempGuard {
fn drop(&mut self) { if self.armed { let _ = fs::remove_file(&self.path); } }
}
let mut _tmp_guard = TempGuard { path: part_path.clone(), armed: true };
let mut resume_from: u64 = 0;
if part_path.exists() && ranges_ok {
@@ -893,6 +900,7 @@ fn download_with_progress(dest_path: &Path, entry: &ModelEntry) -> Result<()> {
drop(part_file);
fs::rename(&part_path, dest_path)
.with_context(|| format!("renaming {}{}", part_path.display(), dest_path.display()))?;
_tmp_guard.disarm();
let final_size = fs::metadata(dest_path).map(|m| m.len()).ok();
let elapsed = start.elapsed().as_secs_f64();

View File

@@ -4,6 +4,8 @@ pub mod progress;
use std::io;
use std::io::IsTerminal;
use std::io::Write as _;
use std::time::{Duration, Instant};
pub fn info(msg: impl AsRef<str>) {
let m = msg.as_ref();
@@ -170,43 +172,156 @@ impl Spinner {
}
}
pub struct BytesProgress(Option<cliclack::ProgressBar>);
pub struct BytesProgress {
enabled: bool,
total: u64,
current: u64,
started: Instant,
last_msg: Instant,
width: usize,
// Sticky ETA to carry through zero-speed stalls
last_eta_secs: Option<f64>,
}
impl BytesProgress {
pub fn start(total: u64, text: &str, initial: u64) -> Self {
if crate::is_no_progress()
let enabled = !(crate::is_no_progress()
|| crate::is_no_interaction()
|| !std::io::stderr().is_terminal()
|| total == 0
{
|| total == 0);
if !enabled {
let _ = cliclack::log::info(text);
return Self(None);
}
let b = cliclack::progress_bar(total);
b.start(text);
if initial > 0 {
b.inc(initial);
let mut me = Self {
enabled,
total,
current: initial.min(total),
started: Instant::now(),
last_msg: Instant::now(),
width: 40,
last_eta_secs: None,
};
me.draw();
me
}
fn human_bytes(n: u64) -> String {
const KB: f64 = 1024.0;
const MB: f64 = 1024.0 * KB;
const GB: f64 = 1024.0 * MB;
let x = n as f64;
if x >= GB {
format!("{:.2} GiB", x / GB)
} else if x >= MB {
format!("{:.2} MiB", x / MB)
} else if x >= KB {
format!("{:.2} KiB", x / KB)
} else {
format!("{} B", n)
}
Self(Some(b))
}
// Elapsed formatting is used for stable, finite durations. For ETA, we guard
// against zero-speed or unstable estimates separately via `format_eta`.
fn refresh_allowed(&mut self) -> (f64, f64) {
let now = Instant::now();
let since_last = now.duration_since(self.last_msg);
if since_last < Duration::from_millis(100) {
// Too soon to refresh; keep previous ETA if any
let eta = self.last_eta_secs.unwrap_or(f64::INFINITY);
return (0.0, eta);
}
self.last_msg = now;
let elapsed = now.duration_since(self.started).as_secs_f64().max(0.001);
let speed = (self.current as f64) / elapsed;
let remaining = self.total.saturating_sub(self.current) as f64;
// If speed is effectively zero, carry ETA forward and add wall time.
const EPS: f64 = 1e-6;
let eta = if speed <= EPS {
let prev = self.last_eta_secs.unwrap_or(f64::INFINITY);
if prev.is_finite() {
prev + since_last.as_secs_f64()
} else {
prev
}
} else {
remaining / speed
};
// Remember only finite ETAs to use during stalls
if eta.is_finite() {
self.last_eta_secs = Some(eta);
}
(speed, eta)
}
fn format_elapsed(seconds: f64) -> String {
let total = seconds.round() as u64;
let h = total / 3600;
let m = (total % 3600) / 60;
let s = total % 60;
if h > 0 { format!("{:02}:{:02}:{:02}", h, m, s) } else { format!("{:02}:{:02}", m, s) }
}
fn format_eta(seconds: f64) -> String {
// If ETA is not finite (e.g., divide-by-zero speed) or unreasonably large,
// show a placeholder rather than overflowing into huge values.
if !seconds.is_finite() {
return "".to_string();
}
// Cap ETA display to 99:59:59 to avoid silly numbers; beyond that, show placeholder.
const CAP_SECS: f64 = 99.0 * 3600.0 + 59.0 * 60.0 + 59.0;
if seconds > CAP_SECS {
return "".to_string();
}
Self::format_elapsed(seconds)
}
fn draw(&mut self) {
if !self.enabled { return; }
let (speed, eta) = self.refresh_allowed();
let elapsed = Instant::now().duration_since(self.started).as_secs_f64();
// Build bar
let width = self.width.max(10);
let filled = ((self.current as f64 / self.total.max(1) as f64) * width as f64).round() as usize;
let filled = filled.min(width);
let mut bar = String::with_capacity(width);
for _ in 0..filled { bar.push('■'); }
for _ in filled..width { bar.push('□'); }
let line = format!(
"[{}] {} [{}] ({}/{} at {}/s)",
Self::format_elapsed(elapsed),
bar,
Self::format_eta(eta),
Self::human_bytes(self.current),
Self::human_bytes(self.total),
Self::human_bytes(speed.max(0.0) as u64),
);
eprint!("\r{}\x1b[K", line);
let _ = io::stderr().flush();
}
pub fn inc(&mut self, delta: u64) {
if let Some(b) = self.0.as_mut() {
b.inc(delta);
}
self.current = self.current.saturating_add(delta).min(self.total);
self.draw();
}
pub fn stop(mut self, text: &str) {
if let Some(b) = self.0.take() {
b.stop(text);
if self.enabled {
self.draw();
eprintln!();
} else {
let _ = cliclack::log::info(text);
}
}
pub fn error(mut self, text: &str) {
if let Some(b) = self.0.take() {
b.error(text);
if self.enabled {
self.draw();
eprintln!();
let _ = cliclack::log::error(text);
} else {
let _ = cliclack::log::error(text);
}

View File

@@ -1,12 +1,12 @@
[package]
name = "polyscribe-host"
version = "0.1.0"
edition = "2024"
version.workspace = true
edition.workspace = true
[dependencies]
anyhow = "1.0.99"
serde = { version = "1.0.219", features = ["derive"] }
serde_json = "1.0.142"
tokio = { version = "1.47.1", features = ["rt-multi-thread", "process", "io-util"] }
which = "6.0.3"
anyhow = { workspace = true }
serde = { workspace = true, features = ["derive"] }
serde_json = { workspace = true }
tokio = { workspace = true, features = ["rt-multi-thread", "process", "io-util"] }
which = { workspace = true }
directories = { workspace = true }

View File

@@ -1,8 +1,8 @@
[package]
name = "polyscribe-protocol"
version = "0.1.0"
edition = "2024"
version.workspace = true
edition.workspace = true
[dependencies]
serde = { version = "1.0.219", features = ["derive"] }
serde_json = "1.0.142"
serde = { workspace = true, features = ["derive"] }
serde_json = { workspace = true }