Compare commits

...

11 Commits

Author SHA1 Message Date
840383fcf7 [feat] add JSON and quiet output modes for models subcommands, update UI suppression logic, and enhance CLI test coverage
Some checks failed
CI / build (push) Has been cancelled
2025-08-27 23:58:57 +02:00
1982e9b48b [feat] add empty state message for models ls command output 2025-08-27 23:51:21 +02:00
0128bf2eec [feat] add ModelManager with caching, manifest management, and Hugging Face API integration 2025-08-27 20:56:05 +02:00
da5a76d253 [refactor] Refactor project into proper rust workspace. 2025-08-27 18:28:37 +02:00
5ec297397e [refactor] rename and simplify ProgressManager to FileProgress, enhance caching logic, update Hugging Face API integration, and clean up unused comments
Some checks failed
CI / build (push) Has been cancelled
2025-08-15 11:24:50 +02:00
cbf48a0452 docs: align CLI docs to models subcommands; host: scan XDG plugin dir; ci: add GitHub Actions; chore: add CHANGELOG 2025-08-14 11:16:50 +02:00
0a249f2197 [refactor] improve code readability, streamline initialization, update error handling, and format multi-line statements for consistency 2025-08-14 11:06:37 +02:00
0573369b81 [refactor] propagate no-progress and no-interaction flags, enhance prompt handling, and update progress bar logic with cliclack 2025-08-14 10:34:52 +02:00
9841550dcc [refactor] replace indicatif with cliclack for progress and logging, updating affected modules and dependencies 2025-08-14 03:31:00 +02:00
53119cd0ab [refactor] enhance model management with metadata enrichment, new API integration, and manifest resolution 2025-08-13 22:44:51 +02:00
144b01d591 [refactor] update Cargo.lock with new dependency additions and version bumps 2025-08-13 14:45:43 +02:00
29 changed files with 4592 additions and 749 deletions

33
.github/workflows/ci.yml vendored Normal file
View File

@@ -0,0 +1,33 @@
# CI workflow: check formatting, lint with clippy (warnings treated as
# errors), and run the full workspace test suite for pushes and PRs.
name: CI
on:
  push:
    branches: [ dev, main ]
  pull_request:
    branches: [ dev, main ]
jobs:
  build:
    runs-on: ubuntu-latest
    steps:
      - name: Checkout code
        uses: actions/checkout@v4
      - name: Set up Rust
        uses: dtolnay/rust-toolchain@stable
      - name: Cache cargo registry and target
        uses: Swatinem/rust-cache@v2
      - name: Install components
        run: rustup component add clippy rustfmt
      - name: Cargo fmt
        run: cargo fmt --all -- --check
      - name: Clippy
        run: cargo clippy --workspace --all-targets -- -D warnings
      - name: Test
        run: cargo test --workspace --all --locked

14
CHANGELOG.md Normal file
View File

@@ -0,0 +1,14 @@
# Changelog
All notable changes to this project will be documented in this file.
The format is based on Keep a Changelog, and this project adheres to Semantic Versioning.
## Unreleased
### Changed
- Docs: Replace `--download-models`/`--update-models` flags with `models download`/`models update` subcommands in `README.md`, `docs/usage.md`, and `docs/development.md`.
- Host: Plugin discovery now scans `$XDG_DATA_HOME/polyscribe/plugins` (platform equivalent via `directories`) in addition to `PATH`.
- CI: Add GitHub Actions workflow to run fmt, clippy (warnings as errors), and tests for pushes and PRs.

1240
Cargo.lock generated

File diff suppressed because it is too large Load Diff

View File

@@ -7,6 +7,12 @@ members = [
] ]
resolver = "3" resolver = "3"
[workspace.package]
edition = "2024"
version = "0.1.0"
license = "MIT"
rust-version = "1.89"
# Optional: Keep dependency versions consistent across members # Optional: Keep dependency versions consistent across members
[workspace.dependencies] [workspace.dependencies]
thiserror = "1.0.69" thiserror = "1.0.69"
@@ -15,18 +21,25 @@ anyhow = "1.0.99"
libc = "0.2.175" libc = "0.2.175"
toml = "0.8.23" toml = "0.8.23"
serde_json = "1.0.142" serde_json = "1.0.142"
chrono = "0.4.41" chrono = { version = "0.4.41", features = ["serde"] }
sha2 = "0.10.9" sha2 = "0.10.9"
which = "6.0.3" which = "6.0.3"
tokio = { version = "1.47.1", features = ["rt-multi-thread", "macros"] } tokio = { version = "1.47.1", features = ["rt-multi-thread", "macros"] }
clap = { version = "4.5.44", features = ["derive"] } clap = { version = "4.5.44", features = ["derive"] }
indicatif = "0.17.11"
directories = "5.0.1" directories = "5.0.1"
whisper-rs = "0.14.3" whisper-rs = "0.14.3"
cliclack = "0.3.6" cliclack = "0.3.6"
clap_complete = "4.5.57" clap_complete = "4.5.57"
clap_mangen = "0.2.29" clap_mangen = "0.2.29"
# Additional shared deps used across members
tracing = "0.1"
tracing-subscriber = { version = "0.3", features = ["fmt", "env-filter"] }
reqwest = { version = "0.12.7", default-features = false, features = ["blocking", "rustls-tls", "gzip", "json"] }
hex = "0.4.3"
tempfile = "3.12.0"
assert_cmd = "2.0.16"
[workspace.lints.rust] [workspace.lints.rust]
unused_imports = "deny" unused_imports = "deny"
dead_code = "warn" dead_code = "warn"

149
README.md
View File

@@ -1,121 +1,68 @@
# PolyScribe # PolyScribe
PolyScribe is a fast, local-first CLI for transcribing audio/video and merging existing JSON transcripts. It uses whisper-rs under the hood, can discover and download Whisper models automatically, and supports CPU and optional GPU backends (CUDA, ROCm/HIP, Vulkan). Local-first transcription and plugins.
Key features ## Features
- Transcribe audio and common video files using ffmpeg for audio extraction.
- Merge multiple JSON transcripts, or merge and also keep per-file outputs.
- Model management: interactive downloader and non-interactive updater with hash verification.
- GPU backend selection at runtime; auto-detects available accelerators.
- Clean outputs (JSON and SRT), speaker naming prompts, and useful logging controls.
Prerequisites - **Local-first**: Works offline with downloaded models
- Rust toolchain (rustup recommended) - **Multiple backends**: CPU, CUDA, ROCm/HIP, and Vulkan support
- ffmpeg available on PATH - **Plugin system**: Extensible via JSON-RPC plugins
- Optional for GPU acceleration at runtime: CUDA, ROCm/HIP, or Vulkan drivers (match your build features) - **Model management**: Automatic download and verification of Whisper models
- **Manifest caching**: Local cache for Hugging Face model manifests to reduce network requests
Installation ## Model Management
- Build from source (CPU-only by default):
- rustup install stable
- rustup default stable
- cargo build --release
- Binary path: ./target/release/polyscribe
- GPU builds (optional): build with features
- CUDA: cargo build --release --features gpu-cuda
- HIP: cargo build --release --features gpu-hip
- Vulkan: cargo build --release --features gpu-vulkan
Quickstart PolyScribe automatically manages Whisper models from Hugging Face:
1) Download a model (first run can prompt you):
- ./target/release/polyscribe --download-models
- In the interactive picker, use Up/Down to navigate, Space to toggle selections, and Enter to confirm. Models are grouped by base (e.g., tiny, base, small).
2) Transcribe a file: ```bash
- ./target/release/polyscribe -v -o output my_audio.mp3 # Download models interactively
This writes JSON and SRT into the output directory with a date prefix. polyscribe models download
Shell completions and man page # Update existing models
- Completions: ./target/release/polyscribe completions <bash|zsh|fish|powershell|elvish> > polyscribe.<ext> polyscribe models update
- Then install into your shell's completion directory.
- Man page: ./target/release/polyscribe man > polyscribe.1 (then copy to your manpath)
Model locations # Clear manifest cache (force fresh fetch)
- Development (debug builds): ./models next to the project. polyscribe models clear-cache
- Packaged/release builds: $XDG_DATA_HOME/polyscribe/models or ~/.local/share/polyscribe/models. ```
- Override via env var: POLYSCRIBE_MODELS_DIR=/path/to/models.
- Force a specific model file via env var: WHISPER_MODEL=/path/to/model.bin.
Most-used CLI flags ### Manifest Caching
- -o, --output FILE_OR_DIR: Output path base (date prefix added). If omitted, JSON prints to stdout.
- -m, --merge: Merge all inputs into one output; otherwise one output per input.
- --merge-and-separate: Write both merged output and separate per-input outputs (requires -o dir).
- --set-speaker-names: Prompt for a speaker label per input file.
- --update-models: Verify/update local models by size/hash against the upstream manifest.
- --download-models: Interactive model list + multi-select download.
- --language LANG: Language code hint (e.g., en, de). English-only models reject non-en hints.
- --gpu-backend [auto|cpu|cuda|hip|vulkan]: Select backend (auto by default).
- --gpu-layers N: Offload N layers to GPU when supported.
- -v/--verbose (repeatable): Increase log verbosity. -vv shows very detailed logs.
- -q/--quiet: Suppress non-error logs (stderr); does not silence stdout results.
- --no-interaction: Never prompt; suitable for CI.
Minimal usage examples The Hugging Face model manifest is cached locally to avoid repeated network requests:
- Transcribe an audio file to JSON/SRT:
- ./target/release/polyscribe -o output samples/podcast_clip.mp3
- Merge multiple transcripts into one:
- ./target/release/polyscribe -m -o output merged input/a.json input/b.json
- Update local models non-interactively (good for CI):
- ./target/release/polyscribe --update-models --no-interaction -q
- Download models interactively:
- ./target/release/polyscribe --download-models
Troubleshooting & docs - **Default TTL**: 24 hours
- docs/faq.md common issues and solutions (missing ffmpeg, GPU selection, model paths) - **Cache location**: `$XDG_CACHE_HOME/polyscribe/manifest/` (or platform equivalent)
- docs/usage.md complete CLI reference and workflows - **Environment variables**:
- docs/development.md build, run, and contribute locally - `POLYSCRIBE_NO_CACHE_MANIFEST=1`: Disable caching
- docs/design.md architecture overview and decisions - `POLYSCRIBE_MANIFEST_TTL_SECONDS=3600`: Set custom TTL (in seconds)
- docs/release-packaging.md packaging notes for distributions
- CONTRIBUTING.md PR checklist and CI workflow
CI status: [CI badge placeholder] ## Installation
License ```bash
------- cargo install --path .
This project is licensed under the MIT License — see the LICENSE file for details. ```
--- ## Usage
Workspace layout ```bash
- This repo is a Cargo workspace using resolver = "3". # Merge multiple transcripts
- Members: polyscribe transcribe input.mp4
- crates/polyscribe-core — types, errors, config service, core helpers.
- crates/polyscribe-protocol — PSP/1 serde types for NDJSON over stdio.
- crates/polyscribe-host — plugin discovery/runner, progress forwarding.
- crates/polyscribe-cli — the CLI, using host + core.
- plugins/polyscribe-plugin-tubescribe — stub plugin used for verification.
Build and run # Merge multiple transcripts
- Build all: cargo build --workspace --all-targets polyscribe transcribe --merge input1.json input2.json
- CLI help: cargo run -p polyscribe-cli -- --help
Plugins # Use specific GPU backend
- Build and link the example plugin into your XDG data plugin dir: polyscribe transcribe --gpu-backend cuda input.mp4
- make -C plugins/polyscribe-plugin-tubescribe link ```
- This creates a symlink at: $XDG_DATA_HOME/polyscribe/plugins/polyscribe-plugin-tubescribe (defaults to ~/.local/share on Linux).
- Discover installed plugins:
- cargo run -p polyscribe-cli -- plugins list
- Show a plugin's capabilities:
- cargo run -p polyscribe-cli -- plugins info tubescribe
- Run a plugin command (JSON-RPC over NDJSON via stdio):
- cargo run -p polyscribe-cli -- plugins run tubescribe generate_metadata --json '{"input":{"kind":"text","summary":"hello world"}}'
Verification commands ## Development
- The above commands are used for acceptance; expected behavior:
- plugins list shows "tubescribe" once linked.
- plugins info tubescribe prints JSON capabilities.
- plugins run ... prints progress events and a JSON result.
Notes ```bash
- No absolute paths are hardcoded; config and plugin dirs respect XDG on Linux and platform equivalents via directories. # Build
- Plugins must be non-interactive (no TTY prompts). All interaction stays in the host/CLI. cargo build
- Config files are written atomically and support env overrides: POLYSCRIBE__SECTION__KEY=value.
# Run tests
cargo test
# Run with verbose logging
cargo run -- --verbose transcribe input.mp4
```

View File

@@ -1,16 +0,0 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2025 <COPYRIGHT HOLDER>. All rights reserved.
/// Build-script entry point.
///
/// Does nothing unless the `gpu-vulkan` Cargo feature is enabled; Cargo
/// exposes enabled features to build scripts as `CARGO_FEATURE_*` env vars.
fn main() {
    // Bail out early for non-Vulkan builds: no special build steps needed.
    if std::env::var("CARGO_FEATURE_GPU_VULKAN").is_err() {
        return;
    }

    // Placeholder: In a full implementation, we would invoke CMake for whisper.cpp with GGML_VULKAN=1.
    // For now, emit a helpful note. Build will proceed; runtime Vulkan backend returns an explanatory error.
    println!("cargo:rerun-if-changed=extern/whisper.cpp");
    println!(
        "cargo:warning=Building with gpu-vulkan: ensure Vulkan SDK/loader are installed. Future versions will compile whisper.cpp via CMake."
    );
}

View File

@@ -1,21 +1,24 @@
[package] [package]
name = "polyscribe-cli" name = "polyscribe-cli"
version = "0.1.0" version.workspace = true
edition = "2024" edition.workspace = true
[[bin]]
name = "polyscribe"
path = "src/main.rs"
[dependencies] [dependencies]
anyhow = "1.0.99" anyhow = { workspace = true }
clap = { version = "4.5.44", features = ["derive"] } clap = { workspace = true, features = ["derive"] }
clap_complete = "4.5.57" clap_complete = { workspace = true }
clap_mangen = "0.2.29" clap_mangen = { workspace = true }
directories = "5.0.1" directories = { workspace = true }
indicatif = "0.17.11" serde = { workspace = true, features = ["derive"] }
serde = { version = "1.0.219", features = ["derive"] } serde_json = { workspace = true }
serde_json = "1.0.142" tokio = { workspace = true, features = ["rt-multi-thread", "macros", "process", "fs"] }
tokio = { version = "1.47.1", features = ["rt-multi-thread", "macros", "process", "fs"] } tracing = { workspace = true }
tracing = "0.1" tracing-subscriber = { workspace = true, features = ["fmt", "env-filter"] }
tracing-subscriber = { version = "0.3", features = ["fmt", "env-filter"] } which = { workspace = true }
which = "6.0.3"
polyscribe-core = { path = "../polyscribe-core" } polyscribe-core = { path = "../polyscribe-core" }
polyscribe-host = { path = "../polyscribe-host" } polyscribe-host = { path = "../polyscribe-host" }
@@ -24,3 +27,6 @@ polyscribe-protocol = { path = "../polyscribe-protocol" }
[features] [features]
# Optional GPU-specific flags can be forwarded down to core/host if needed # Optional GPU-specific flags can be forwarded down to core/host if needed
default = [] default = []
[dev-dependencies]
assert_cmd = { workspace = true }

View File

@@ -1,4 +1,4 @@
use clap::{Parser, Subcommand, ValueEnum}; use clap::{Args, Parser, Subcommand, ValueEnum};
use std::path::PathBuf; use std::path::PathBuf;
#[derive(Debug, Clone, ValueEnum)] #[derive(Debug, Clone, ValueEnum)]
@@ -10,21 +10,41 @@ pub enum GpuBackend {
Vulkan, Vulkan,
} }
#[derive(Debug, Clone, Args)]
pub struct OutputOpts {
/// Emit machine-readable JSON to stdout; suppress decorative logs
#[arg(long, global = true, action = clap::ArgAction::SetTrue)]
pub json: bool,
/// Reduce log chatter (errors only unless --json)
#[arg(long, global = true, action = clap::ArgAction::SetTrue)]
pub quiet: bool,
}
#[derive(Debug, Parser)] #[derive(Debug, Parser)]
#[command(name = "polyscribe", version, about = "PolyScribe local-first transcription and plugins")] #[command(
name = "polyscribe",
version,
about = "PolyScribe local-first transcription and plugins",
propagate_version = true,
arg_required_else_help = true,
)]
pub struct Cli { pub struct Cli {
/// Global output options
#[command(flatten)]
pub output: OutputOpts,
/// Increase verbosity (-v, -vv) /// Increase verbosity (-v, -vv)
#[arg(short, long, action = clap::ArgAction::Count)] #[arg(short, long, action = clap::ArgAction::Count)]
pub verbose: u8, pub verbose: u8,
/// Quiet mode (suppresses non-error logs)
#[arg(short, long, default_value_t = false)]
pub quiet: bool,
/// Never prompt for user input (non-interactive mode) /// Never prompt for user input (non-interactive mode)
#[arg(long, default_value_t = false)] #[arg(long, default_value_t = false)]
pub no_interaction: bool, pub no_interaction: bool,
/// Disable progress bars/spinners
#[arg(long, default_value_t = false)]
pub no_progress: bool,
#[command(subcommand)] #[command(subcommand)]
pub command: Commands, pub command: Commands,
} }
@@ -66,7 +86,7 @@ pub enum Commands {
inputs: Vec<PathBuf>, inputs: Vec<PathBuf>,
}, },
/// Manage Whisper models /// Manage Whisper GGUF models (Hugging Face)
Models { Models {
#[command(subcommand)] #[command(subcommand)]
cmd: ModelsCmd, cmd: ModelsCmd,
@@ -89,12 +109,64 @@ pub enum Commands {
Man, Man,
} }
#[derive(Debug, Clone, Parser)]
pub struct ModelCommon {
/// Concurrency for ranged downloads
#[arg(long, default_value_t = 4)]
pub concurrency: usize,
/// Limit download rate in bytes/sec (approximate)
#[arg(long)]
pub limit_rate: Option<u64>,
}
#[derive(Debug, Subcommand)] #[derive(Debug, Subcommand)]
pub enum ModelsCmd { pub enum ModelsCmd {
/// Verify or update local models non-interactively /// List installed models (from manifest)
Update, Ls {
/// Interactive multi-select downloader #[command(flatten)]
Download, common: ModelCommon,
},
/// Add or update a model
Add {
/// Hugging Face repo, e.g. ggml-org/models
repo: String,
/// File name in repo (e.g., gguf-tiny-q4_0.bin)
file: String,
#[command(flatten)]
common: ModelCommon,
},
/// Remove a model by alias
Rm {
alias: String,
#[command(flatten)]
common: ModelCommon,
},
/// Verify model file integrity by alias
Verify {
alias: String,
#[command(flatten)]
common: ModelCommon,
},
/// Update all models (HEAD + ETag; skip if unchanged)
Update {
#[command(flatten)]
common: ModelCommon,
},
/// Garbage-collect unreferenced files and stale manifest entries
Gc {
#[command(flatten)]
common: ModelCommon,
},
/// Search a repo for GGUF files
Search {
/// Hugging Face repo, e.g. ggml-org/models
repo: String,
/// Optional substring to filter filenames
#[arg(long)]
query: Option<String>,
#[command(flatten)]
common: ModelCommon,
},
} }
#[derive(Debug, Subcommand)] #[derive(Debug, Subcommand)]

View File

@@ -1,54 +1,96 @@
mod cli; mod cli;
mod output;
use anyhow::{anyhow, Context, Result}; use anyhow::{Context, Result, anyhow};
use clap::{Parser, CommandFactory}; use clap::{CommandFactory, Parser};
use cli::{Cli, Commands, GpuBackend, ModelsCmd, PluginsCmd}; use cli::{Cli, Commands, GpuBackend, ModelsCmd, ModelCommon, PluginsCmd};
use polyscribe_core::{config::ConfigService, ui::progress::ProgressReporter}; use output::OutputMode;
use polyscribe_core::model_manager::{ModelManager, Settings, ReqwestClient};
use polyscribe_core::ui;
/// Similarity score in `[0.0, 1.0]` between two strings, derived from the
/// byte-level Levenshtein edit distance normalized by the longer length.
///
/// Identical strings score 1.0; two empty strings are considered identical.
/// Operates on raw bytes, so multi-byte UTF-8 characters count per byte.
fn normalized_similarity(a: &str, b: &str) -> f64 {
    let (left, right) = (a.as_bytes(), b.as_bytes());
    let (rows, cols) = (left.len(), right.len());
    // Trivial cases: both empty means identical; exactly one empty means
    // the strings share nothing.
    if rows == 0 && cols == 0 {
        return 1.0;
    }
    if rows == 0 || cols == 0 {
        return 0.0;
    }
    // Classic two-row dynamic programming over the edit-distance matrix;
    // `previous` holds row i-1, `current` is filled in as row i.
    let mut previous: Vec<usize> = (0..=cols).collect();
    let mut current: Vec<usize> = vec![0; cols + 1];
    for (i, &lb) in left.iter().enumerate() {
        current[0] = i + 1;
        for (j, &rb) in right.iter().enumerate() {
            let substitution = previous[j] + usize::from(lb != rb);
            let deletion = previous[j + 1] + 1;
            let insertion = current[j] + 1;
            current[j + 1] = substitution.min(deletion).min(insertion);
        }
        std::mem::swap(&mut previous, &mut current);
    }
    // After the final swap, `previous` holds the last completed row.
    let distance = previous[cols] as f64;
    1.0 - distance / rows.max(cols) as f64
}
/// Render an optional byte count as a short human-readable string.
///
/// Uses binary units (KiB/MiB/GiB) with two decimal places; values under
/// 1 KiB are printed as plain bytes. `None` (size unknown) renders as "?".
fn human_size(bytes: Option<u64>) -> String {
    const KIB: f64 = 1024.0;
    const MIB: f64 = KIB * 1024.0;
    const GIB: f64 = MIB * 1024.0;
    let Some(count) = bytes else {
        return "?".to_string();
    };
    let value = count as f64;
    if value >= GIB {
        format!("{:.2} GiB", value / GIB)
    } else if value >= MIB {
        format!("{:.2} MiB", value / MIB)
    } else if value >= KIB {
        format!("{:.2} KiB", value / KIB)
    } else {
        format!("{count} B")
    }
}
use polyscribe_core::ui::progress::ProgressReporter;
use polyscribe_host::PluginManager; use polyscribe_host::PluginManager;
use tokio::io::AsyncWriteExt; use tokio::io::AsyncWriteExt;
use tracing::{error, info};
use tracing_subscriber::EnvFilter; use tracing_subscriber::EnvFilter;
fn init_tracing(quiet: bool, verbose: u8) { fn init_tracing(json_mode: bool, quiet: bool, verbose: u8) {
let level = if quiet { // In JSON mode, suppress human logs; route errors to stderr only.
"error" let level = if json_mode || quiet { "error" } else { match verbose { 0 => "info", 1 => "debug", _ => "trace" } };
} else {
match verbose {
0 => "info",
1 => "debug",
_ => "trace",
}
};
let filter = EnvFilter::try_from_default_env().unwrap_or_else(|_| EnvFilter::new(level)); let filter = EnvFilter::try_from_default_env().unwrap_or_else(|_| EnvFilter::new(level));
tracing_subscriber::fmt() tracing_subscriber::fmt()
.with_env_filter(filter) .with_env_filter(filter)
.with_target(false) .with_target(false)
.with_level(true) .with_level(true)
.with_writer(std::io::stderr)
.compact() .compact()
.init(); .init();
} }
#[tokio::main] fn main() -> Result<()> {
async fn main() -> Result<()> {
let args = Cli::parse(); let args = Cli::parse();
init_tracing(args.quiet, args.verbose); // Determine output mode early for logging and UI configuration
let output_mode = if args.output.json {
OutputMode::Json
} else {
OutputMode::Human { quiet: args.output.quiet }
};
let _cfg = ConfigService::load_or_default().context("loading configuration")?; init_tracing(matches!(output_mode, OutputMode::Json), args.output.quiet, args.verbose);
// Suppress decorative UI output in JSON mode as well
polyscribe_core::set_quiet(args.output.quiet || matches!(output_mode, OutputMode::Json));
polyscribe_core::set_no_interaction(args.no_interaction);
polyscribe_core::set_verbose(args.verbose);
polyscribe_core::set_no_progress(args.no_progress);
match args.command { match args.command {
Commands::Transcribe { Commands::Transcribe {
output: _output,
merge: _merge,
merge_and_separate: _merge_and_separate,
language: _language,
set_speaker_names: _set_speaker_names,
gpu_backend, gpu_backend,
gpu_layers, gpu_layers,
inputs, inputs,
..
} => { } => {
info!("starting transcription workflow"); polyscribe_core::ui::info("starting transcription workflow");
let mut progress = ProgressReporter::new(args.no_interaction); let mut progress = ProgressReporter::new(args.no_interaction);
progress.step("Validating inputs"); progress.step("Validating inputs");
@@ -72,54 +114,329 @@ async fn main() -> Result<()> {
} }
Commands::Models { cmd } => { Commands::Models { cmd } => {
match cmd { // predictable exit codes
ModelsCmd::Update => { const EXIT_OK: i32 = 0;
info!("verifying/updating local models"); const EXIT_NOT_FOUND: i32 = 2;
println!("Models updated (stub)."); const EXIT_NETWORK: i32 = 3;
const EXIT_VERIFY_FAILED: i32 = 4;
// const EXIT_NO_CHANGE: i32 = 5; // reserved
let handle_common = |c: &ModelCommon| Settings {
concurrency: c.concurrency.max(1),
limit_rate: c.limit_rate,
..Default::default()
};
let exit = match cmd {
ModelsCmd::Ls { common } => {
let mm: ModelManager<ReqwestClient> = ModelManager::new(handle_common(&common))?;
let list = mm.ls()?;
match output_mode {
OutputMode::Json => {
// Always emit JSON array (possibly empty)
output_mode.print_json(&list);
}
OutputMode::Human { quiet } => {
if list.is_empty() {
if !quiet { println!("No models installed."); }
} else {
if !quiet { println!("Model (Repo)"); }
for r in list {
if !quiet { println!("{} ({})", r.file, r.repo); }
}
}
}
}
EXIT_OK
} }
ModelsCmd::Download => { ModelsCmd::Add { repo, file, common } => {
info!("interactive model selection and download"); let settings = handle_common(&common);
println!("Model download complete (stub)."); let mm: ModelManager<ReqwestClient> = ModelManager::new(settings.clone())?;
// Derive an alias automatically from repo and file
fn derive_alias(repo: &str, file: &str) -> String {
use std::path::Path;
let repo_tail = repo.rsplit('/').next().unwrap_or(repo);
let stem = Path::new(file)
.file_stem()
.and_then(|s| s.to_str())
.unwrap_or(file);
format!("{}-{}", repo_tail, stem)
}
let alias = derive_alias(&repo, &file);
match mm.add_or_update(&alias, &repo, &file) {
Ok(rec) => {
match output_mode {
OutputMode::Json => output_mode.print_json(&rec),
OutputMode::Human { quiet } => {
if !quiet { println!("installed: {} -> {}/{}", alias, repo, rec.file); }
}
}
EXIT_OK
}
Err(e) => {
// On not found or similar errors, try suggesting close matches interactively
if matches!(output_mode, OutputMode::Json) || polyscribe_core::is_no_interaction() {
match output_mode {
OutputMode::Json => {
// Emit error JSON object
#[derive(serde::Serialize)]
struct ErrObj<'a> { error: &'a str }
let eo = ErrObj { error: &e.to_string() };
output_mode.print_json(&eo);
}
_ => { eprintln!("error: {e}"); }
}
EXIT_NOT_FOUND
} else {
ui::warn(format!("{}", e));
ui::info("Searching for similar model filenames…");
match polyscribe_core::model_manager::search_repo(&repo, None) {
Ok(mut files) => {
if files.is_empty() {
ui::warn("No files found in repository.");
EXIT_NOT_FOUND
} else {
// rank by similarity
files.sort_by(|a, b| normalized_similarity(&file, b)
.partial_cmp(&normalized_similarity(&file, a))
.unwrap_or(std::cmp::Ordering::Equal));
let top: Vec<String> = files.into_iter().take(5).collect();
if top.is_empty() {
EXIT_NOT_FOUND
} else if top.len() == 1 {
let cand = &top[0];
// Fetch repo size list once
let size_map: std::collections::HashMap<String, Option<u64>> =
polyscribe_core::model_manager::list_repo_files_with_meta(&repo)
.unwrap_or_default()
.into_iter().collect();
let mut size = size_map.get(cand).cloned().unwrap_or(None);
if size.is_none() {
size = polyscribe_core::model_manager::head_len_for_file(&repo, cand);
}
let local_files: std::collections::HashSet<String> = mm.ls()?.into_iter().map(|r| r.file).collect();
let is_local = local_files.contains(cand);
let label = format!("{} [{}]{}", cand, human_size(size), if is_local { " (local)" } else { "" });
let ok = ui::prompt_confirm(&format!("Did you mean {}?", label), true)
.unwrap_or(false);
if !ok { EXIT_NOT_FOUND } else {
let mm2: ModelManager<ReqwestClient> = ModelManager::new(settings)?;
let alias2 = derive_alias(&repo, cand);
match mm2.add_or_update(&alias2, &repo, cand) {
Ok(rec) => {
match output_mode {
OutputMode::Json => output_mode.print_json(&rec),
OutputMode::Human { quiet } => { if !quiet { println!("installed: {} -> {}/{}", alias2, repo, rec.file); } }
}
EXIT_OK
}
Err(e2) => { eprintln!("error: {e2}"); EXIT_NETWORK }
}
}
} else {
let opts: Vec<String> = top;
let local_files: std::collections::HashSet<String> = mm.ls()?.into_iter().map(|r| r.file).collect();
// Enrich labels with size and local tag using a single API call
let size_map: std::collections::HashMap<String, Option<u64>> =
polyscribe_core::model_manager::list_repo_files_with_meta(&repo)
.unwrap_or_default()
.into_iter().collect();
let mut labels_owned: Vec<String> = Vec::new();
for f in &opts {
let mut size = size_map.get(f).cloned().unwrap_or(None);
if size.is_none() {
size = polyscribe_core::model_manager::head_len_for_file(&repo, f);
}
let is_local = local_files.contains(f);
let suffix = if is_local { " (local)" } else { "" };
labels_owned.push(format!("{} [{}]{}", f, human_size(size), suffix));
}
let labels: Vec<&str> = labels_owned.iter().map(|s| s.as_str()).collect();
match ui::prompt_select("Pick a model", &labels) {
Ok(idx) => {
let chosen = &opts[idx];
let mm2: ModelManager<ReqwestClient> = ModelManager::new(settings)?;
let alias2 = derive_alias(&repo, chosen);
match mm2.add_or_update(&alias2, &repo, chosen) {
Ok(rec) => {
match output_mode {
OutputMode::Json => output_mode.print_json(&rec),
OutputMode::Human { quiet } => { if !quiet { println!("installed: {} -> {}/{}", alias2, repo, rec.file); } }
}
EXIT_OK
}
Err(e2) => { eprintln!("error: {e2}"); EXIT_NETWORK }
}
}
Err(_) => EXIT_NOT_FOUND,
}
}
}
}
Err(e2) => {
eprintln!("error: {}", e2);
EXIT_NETWORK
}
}
}
}
}
} }
} ModelsCmd::Rm { alias, common } => {
Ok(()) let mm: ModelManager<ReqwestClient> = ModelManager::new(handle_common(&common))?;
let ok = mm.rm(&alias)?;
match output_mode {
OutputMode::Json => {
#[derive(serde::Serialize)]
struct R { removed: bool }
output_mode.print_json(&R { removed: ok });
}
OutputMode::Human { quiet } => {
if !quiet { println!("{}", if ok { "removed" } else { "not found" }); }
}
}
if ok { EXIT_OK } else { EXIT_NOT_FOUND }
}
ModelsCmd::Verify { alias, common } => {
let mm: ModelManager<ReqwestClient> = ModelManager::new(handle_common(&common))?;
let found = mm.ls()?.into_iter().any(|r| r.alias == alias);
if !found {
match output_mode {
OutputMode::Json => {
#[derive(serde::Serialize)]
struct R<'a> { ok: bool, error: &'a str }
output_mode.print_json(&R { ok: false, error: "not found" });
}
OutputMode::Human { quiet } => { if !quiet { println!("not found"); } }
}
EXIT_NOT_FOUND
} else {
let ok = mm.verify(&alias)?;
match output_mode {
OutputMode::Json => {
#[derive(serde::Serialize)]
struct R { ok: bool }
output_mode.print_json(&R { ok });
}
OutputMode::Human { quiet } => { if !quiet { println!("{}", if ok { "ok" } else { "corrupt" }); } }
}
if ok { EXIT_OK } else { EXIT_VERIFY_FAILED }
}
}
ModelsCmd::Update { common } => {
let mm: ModelManager<ReqwestClient> = ModelManager::new(handle_common(&common))?;
let mut rc = EXIT_OK;
for rec in mm.ls()? {
match mm.add_or_update(&rec.alias, &rec.repo, &rec.file) {
Ok(_) => {}
Err(e) => {
rc = EXIT_NETWORK;
match output_mode {
OutputMode::Json => {
#[derive(serde::Serialize)]
struct R<'a> { alias: &'a str, error: String }
output_mode.print_json(&R { alias: &rec.alias, error: e.to_string() });
}
_ => { eprintln!("update {}: {e}", rec.alias); }
}
}
}
}
rc
}
ModelsCmd::Gc { common } => {
let mm: ModelManager<ReqwestClient> = ModelManager::new(handle_common(&common))?;
let (files_removed, entries_removed) = mm.gc()?;
match output_mode {
OutputMode::Json => {
#[derive(serde::Serialize)]
struct R { files_removed: usize, entries_removed: usize }
output_mode.print_json(&R { files_removed, entries_removed });
}
OutputMode::Human { quiet } => { if !quiet { println!("files_removed={} entries_removed={}", files_removed, entries_removed); } }
}
EXIT_OK
}
ModelsCmd::Search { repo, query, common } => {
let res = polyscribe_core::model_manager::search_repo(&repo, query.as_deref());
match res {
Ok(files) => {
match output_mode {
OutputMode::Json => output_mode.print_json(&files),
OutputMode::Human { quiet } => { for f in files { if !quiet { println!("{}", f); } } }
}
EXIT_OK
}
Err(e) => {
match output_mode {
OutputMode::Json => {
#[derive(serde::Serialize)]
struct R { error: String }
output_mode.print_json(&R { error: e.to_string() });
}
_ => { eprintln!("error: {e}"); }
}
EXIT_NETWORK
}
}
}
};
std::process::exit(exit);
} }
Commands::Plugins { cmd } => { Commands::Plugins { cmd } => {
let pm = PluginManager::default(); let plugin_manager = PluginManager;
match cmd { match cmd {
PluginsCmd::List => { PluginsCmd::List => {
let list = pm.list().context("discovering plugins")?; let list = plugin_manager.list().context("discovering plugins")?;
for item in list { for item in list {
println!("{}", item.name); polyscribe_core::ui::info(item.name);
} }
Ok(()) Ok(())
} }
PluginsCmd::Info { name } => { PluginsCmd::Info { name } => {
let info = pm.info(&name).with_context(|| format!("getting info for {}", name))?; let info = plugin_manager
println!("{}", serde_json::to_string_pretty(&info)?); .info(&name)
.with_context(|| format!("getting info for {}", name))?;
let info_json = serde_json::to_string_pretty(&info)?;
polyscribe_core::ui::info(info_json);
Ok(()) Ok(())
} }
PluginsCmd::Run { name, command, json } => { PluginsCmd::Run {
let payload = json.unwrap_or_else(|| "{}".to_string()); name,
let mut child = pm command,
.spawn(&name, &command) json,
.with_context(|| format!("spawning plugin {name} {command}"))?; } => {
// Use a local Tokio runtime only for this async path
let rt = tokio::runtime::Builder::new_multi_thread()
.enable_all()
.build()
.context("building tokio runtime")?;
if let Some(mut stdin) = child.stdin.take() { rt.block_on(async {
stdin let payload = json.unwrap_or_else(|| "{}".to_string());
.write_all(payload.as_bytes()) let mut child = plugin_manager
.await .spawn(&name, &command)
.context("writing JSON payload to plugin stdin")?; .with_context(|| format!("spawning plugin {name} {command}"))?;
}
let status = pm.forward_stdio(&mut child).await?; if let Some(mut stdin) = child.stdin.take() {
if !status.success() { stdin
error!("plugin returned non-zero exit code: {}", status); .write_all(payload.as_bytes())
return Err(anyhow!("plugin failed")); .await
} .context("writing JSON payload to plugin stdin")?;
Ok(()) }
let status = plugin_manager.forward_stdio(&mut child).await?;
if !status.success() {
polyscribe_core::ui::error(format!(
"plugin returned non-zero exit code: {}",
status
));
return Err(anyhow!("plugin failed"));
}
Ok(())
})
} }
} }
} }

View File

@@ -0,0 +1,36 @@
use std::io::{self, Write};
/// Rendering mode for CLI results: machine-readable JSON or human text.
#[derive(Clone, Debug)]
pub enum OutputMode {
    Json,
    Human { quiet: bool },
}

impl OutputMode {
    /// Whether human-oriented chatter should be suppressed.
    ///
    /// JSON mode is always quiet; human mode is quiet only when the
    /// `quiet` flag was set.
    pub fn is_quiet(&self) -> bool {
        match self {
            OutputMode::Json => true,
            OutputMode::Human { quiet } => *quiet,
        }
    }

    /// Emit `v` as one line of compact JSON on stdout.
    ///
    /// Does nothing outside JSON mode. A serialization failure is reported
    /// inline as a JSON string rather than aborting the process.
    pub fn print_json<T: serde::Serialize>(&self, v: &T) {
        let OutputMode::Json = self else {
            return;
        };
        let rendered =
            serde_json::to_string(v).unwrap_or_else(|e| format!("\"JSON_ERROR:{}\"", e));
        println!("{}", rendered);
    }

    /// Print a human-readable line on stdout.
    ///
    /// Suppressed entirely in JSON mode and in quiet human mode; write
    /// errors are deliberately ignored.
    pub fn print_line(&self, s: impl AsRef<str>) {
        if let OutputMode::Human { quiet: false } = self {
            let _ = writeln!(io::stdout(), "{}", s.as_ref());
        }
    }
}

View File

@@ -1,11 +1,11 @@
// SPDX-License-Identifier: MIT // SPDX-License-Identifier: MIT
// Copyright (c) 2025 <COPYRIGHT HOLDER>. All rights reserved. // Copyright (c) 2025 <COPYRIGHT HOLDER>. All rights reserved.
use assert_cmd::cargo::cargo_bin;
use std::process::Command; use std::process::Command;
fn bin() -> String { fn bin() -> std::path::PathBuf {
std::env::var("CARGO_BIN_EXE_polyscribe") cargo_bin("polyscribe")
.unwrap_or_else(|_| "polyscribe".to_string())
} }
#[test] #[test]

View File

@@ -0,0 +1,42 @@
use assert_cmd::cargo::cargo_bin;
use std::process::Command;
fn bin() -> std::path::PathBuf { cargo_bin("polyscribe") }
#[test]
fn models_help_shows_global_output_flags() {
    // Help text for the `models` subcommand must advertise the global
    // output flags so users can discover --json / --quiet.
    let output = Command::new(bin())
        .args(["models", "--help"])
        .output()
        .expect("failed to run polyscribe models --help");
    assert!(
        output.status.success(),
        "help exited non-zero: {:?}",
        output.status
    );
    let stdout = String::from_utf8(output.stdout).expect("stdout not utf-8");
    assert!(stdout.contains("--json"), "--json not shown in help: {stdout}");
    assert!(stdout.contains("--quiet"), "--quiet not shown in help: {stdout}");
}
#[test]
fn models_version_contains_pkg_version() {
    // `--version` on the subcommand relies on clap's propagate_version;
    // the printed version must match this crate's version.
    let output = Command::new(bin())
        .args(["models", "--version"])
        .output()
        .expect("failed to run polyscribe models --version");
    assert!(
        output.status.success(),
        "version exited non-zero: {:?}",
        output.status
    );
    let stdout = String::from_utf8(output.stdout).expect("stdout not utf-8");
    let want = env!("CARGO_PKG_VERSION");
    assert!(stdout.contains(want), "version output missing {want}: {stdout}");
}
#[test]
fn models_ls_json_quiet_emits_pure_json() {
    // With the global --json --quiet flags, stdout must be a single valid
    // JSON document and nothing else; stderr must stay clean on success.
    let output = Command::new(bin())
        .args(["models", "ls", "--json", "--quiet"])
        .output()
        .expect("failed to run polyscribe models ls --json --quiet");
    assert!(output.status.success(), "ls exited non-zero: {:?}", output.status);
    let stdout = String::from_utf8(output.stdout).expect("stdout not utf-8");
    serde_json::from_str::<serde_json::Value>(stdout.trim()).expect("stdout is not valid JSON");
    assert!(output.stderr.is_empty(), "expected no stderr noise");
}

View File

@@ -1,16 +1,22 @@
[package] [package]
name = "polyscribe-core" name = "polyscribe-core"
version = "0.1.0" version.workspace = true
edition = "2024" edition.workspace = true
[dependencies] [dependencies]
anyhow = "1.0.99" anyhow = { workspace = true }
thiserror = "1.0.69" thiserror = { workspace = true }
serde = { version = "1.0.219", features = ["derive"] } serde = { workspace = true, features = ["derive"] }
serde_json = "1.0.142" serde_json = { workspace = true }
toml = "0.8.23" toml = { workspace = true }
directories = "5.0.1" directories = { workspace = true }
chrono = "0.4.41" chrono = { workspace = true }
libc = "0.2.175" libc = { workspace = true }
whisper-rs = "0.14.3" whisper-rs = { workspace = true }
indicatif = "0.17.11" # UI and progress
cliclack = { workspace = true }
# HTTP downloads + hashing
reqwest = { workspace = true }
sha2 = { workspace = true }
hex = { workspace = true }
tempfile = { workspace = true }

View File

@@ -1,12 +1,14 @@
// SPDX-License-Identifier: MIT // SPDX-License-Identifier: MIT
// Move original build.rs behavior into core crate
fn main() { fn main() {
// Only run special build steps when gpu-vulkan feature is enabled.
let vulkan_enabled = std::env::var("CARGO_FEATURE_GPU_VULKAN").is_ok(); let vulkan_enabled = std::env::var("CARGO_FEATURE_GPU_VULKAN").is_ok();
println!("cargo:rerun-if-changed=extern/whisper.cpp");
if !vulkan_enabled { if !vulkan_enabled {
println!(
"cargo:warning=gpu-vulkan feature is disabled; skipping Vulkan-dependent build steps."
);
return; return;
} }
println!("cargo:rerun-if-changed=extern/whisper.cpp");
println!( println!(
"cargo:warning=Building with gpu-vulkan: ensure Vulkan SDK/loader are installed. Future versions will compile whisper.cpp via CMake." "cargo:warning=Building with gpu-vulkan: ensure Vulkan SDK/loader are installed. Future versions will compile whisper.cpp via CMake."
); );

View File

@@ -1,34 +1,23 @@
// SPDX-License-Identifier: MIT // SPDX-License-Identifier: MIT
// Copyright (c) 2025 <COPYRIGHT HOLDER>. All rights reserved.
//! Transcription backend selection and implementations (CPU/GPU) used by PolyScribe.
use crate::OutputEntry; use crate::OutputEntry;
use crate::prelude::*;
use crate::{decode_audio_to_pcm_f32_ffmpeg, find_model_file}; use crate::{decode_audio_to_pcm_f32_ffmpeg, find_model_file};
use anyhow::{Context, Result, anyhow}; use anyhow::{Context, anyhow};
use std::env; use std::env;
use std::path::Path; use std::path::Path;
// Re-export a public enum for CLI parsing usage
#[derive(Debug, Clone, Copy, PartialEq, Eq)] #[derive(Debug, Clone, Copy, PartialEq, Eq)]
/// Kind of transcription backend to use.
pub enum BackendKind { pub enum BackendKind {
/// Automatically detect the best available backend (CUDA > HIP > Vulkan > CPU).
Auto, Auto,
/// Pure CPU backend using whisper-rs.
Cpu, Cpu,
/// NVIDIA CUDA backend (requires CUDA runtime available at load time and proper feature build).
Cuda, Cuda,
/// AMD ROCm/HIP backend (requires hip/rocBLAS libraries available and proper feature build).
Hip, Hip,
/// Vulkan backend (experimental; requires Vulkan loader/SDK and feature build).
Vulkan, Vulkan,
} }
/// Abstraction for a transcription backend.
pub trait TranscribeBackend { pub trait TranscribeBackend {
/// Backend kind implemented by this type.
fn kind(&self) -> BackendKind; fn kind(&self) -> BackendKind;
/// Transcribe the given audio and return transcript entries.
fn transcribe( fn transcribe(
&self, &self,
audio_path: &Path, audio_path: &Path,
@@ -39,15 +28,13 @@ pub trait TranscribeBackend {
) -> Result<Vec<OutputEntry>>; ) -> Result<Vec<OutputEntry>>;
} }
fn check_lib(_names: &[&str]) -> bool { fn is_library_available(_names: &[&str]) -> bool {
#[cfg(test)] #[cfg(test)]
{ {
// During unit tests, avoid touching system libs to prevent loader crashes in CI.
false false
} }
#[cfg(not(test))] #[cfg(not(test))]
{ {
// Disabled runtime dlopen probing to avoid loader instability; rely on environment overrides.
false false
} }
} }
@@ -56,7 +43,7 @@ fn cuda_available() -> bool {
if let Ok(x) = env::var("POLYSCRIBE_TEST_FORCE_CUDA") { if let Ok(x) = env::var("POLYSCRIBE_TEST_FORCE_CUDA") {
return x == "1"; return x == "1";
} }
check_lib(&[ is_library_available(&[
"libcudart.so", "libcudart.so",
"libcudart.so.12", "libcudart.so.12",
"libcudart.so.11", "libcudart.so.11",
@@ -69,33 +56,31 @@ fn hip_available() -> bool {
if let Ok(x) = env::var("POLYSCRIBE_TEST_FORCE_HIP") { if let Ok(x) = env::var("POLYSCRIBE_TEST_FORCE_HIP") {
return x == "1"; return x == "1";
} }
check_lib(&["libhipblas.so", "librocblas.so"]) is_library_available(&["libhipblas.so", "librocblas.so"])
} }
fn vulkan_available() -> bool { fn vulkan_available() -> bool {
if let Ok(x) = env::var("POLYSCRIBE_TEST_FORCE_VULKAN") { if let Ok(x) = env::var("POLYSCRIBE_TEST_FORCE_VULKAN") {
return x == "1"; return x == "1";
} }
check_lib(&["libvulkan.so.1", "libvulkan.so"]) is_library_available(&["libvulkan.so.1", "libvulkan.so"])
} }
/// CPU-based transcription backend using whisper-rs.
#[derive(Default)] #[derive(Default)]
pub struct CpuBackend; pub struct CpuBackend;
/// CUDA-accelerated transcription backend for NVIDIA GPUs.
#[derive(Default)] #[derive(Default)]
pub struct CudaBackend; pub struct CudaBackend;
/// ROCm/HIP-accelerated transcription backend for AMD GPUs.
#[derive(Default)] #[derive(Default)]
pub struct HipBackend; pub struct HipBackend;
/// Vulkan-based transcription backend (experimental/incomplete).
#[derive(Default)] #[derive(Default)]
pub struct VulkanBackend; pub struct VulkanBackend;
macro_rules! impl_whisper_backend { macro_rules! impl_whisper_backend {
($ty:ty, $kind:expr) => { ($ty:ty, $kind:expr) => {
impl TranscribeBackend for $ty { impl TranscribeBackend for $ty {
fn kind(&self) -> BackendKind { $kind } fn kind(&self) -> BackendKind {
$kind
}
fn transcribe( fn transcribe(
&self, &self,
audio_path: &Path, audio_path: &Path,
@@ -128,29 +113,17 @@ impl TranscribeBackend for VulkanBackend {
) -> Result<Vec<OutputEntry>> { ) -> Result<Vec<OutputEntry>> {
Err(anyhow!( Err(anyhow!(
"Vulkan backend not yet wired to whisper.cpp FFI. Build with --features gpu-vulkan and ensure Vulkan SDK is installed. How to fix: install Vulkan loader (libvulkan), set VULKAN_SDK, and run cargo build --features gpu-vulkan." "Vulkan backend not yet wired to whisper.cpp FFI. Build with --features gpu-vulkan and ensure Vulkan SDK is installed. How to fix: install Vulkan loader (libvulkan), set VULKAN_SDK, and run cargo build --features gpu-vulkan."
)) ).into())
} }
} }
/// Result of choosing a transcription backend. pub struct BackendSelection {
pub struct SelectionResult {
/// The constructed backend instance to perform transcription with.
pub backend: Box<dyn TranscribeBackend + Send + Sync>, pub backend: Box<dyn TranscribeBackend + Send + Sync>,
/// Which backend kind was ultimately selected.
pub chosen: BackendKind, pub chosen: BackendKind,
/// Which backend kinds were detected as available on this system.
pub detected: Vec<BackendKind>, pub detected: Vec<BackendKind>,
} }
/// Select an appropriate backend based on user request and system detection. pub fn select_backend(requested: BackendKind, verbose: bool) -> Result<BackendSelection> {
///
/// If `requested` is `BackendKind::Auto`, the function prefers CUDA, then HIP,
/// then Vulkan, falling back to CPU when no GPU backend is detected. When a
/// specific GPU backend is requested but unavailable, an error is returned with
/// guidance on how to enable it.
///
/// Set `verbose` to true to print detection/selection info to stderr.
pub fn select_backend(requested: BackendKind, verbose: bool) -> Result<SelectionResult> {
let mut detected = Vec::new(); let mut detected = Vec::new();
if cuda_available() { if cuda_available() {
detected.push(BackendKind::Cuda); detected.push(BackendKind::Cuda);
@@ -164,11 +137,11 @@ pub fn select_backend(requested: BackendKind, verbose: bool) -> Result<Selection
let instantiate_backend = |k: BackendKind| -> Box<dyn TranscribeBackend + Send + Sync> { let instantiate_backend = |k: BackendKind| -> Box<dyn TranscribeBackend + Send + Sync> {
match k { match k {
BackendKind::Cpu => Box::new(CpuBackend::default()), BackendKind::Cpu => Box::new(CpuBackend),
BackendKind::Cuda => Box::new(CudaBackend::default()), BackendKind::Cuda => Box::new(CudaBackend),
BackendKind::Hip => Box::new(HipBackend::default()), BackendKind::Hip => Box::new(HipBackend),
BackendKind::Vulkan => Box::new(VulkanBackend::default()), BackendKind::Vulkan => Box::new(VulkanBackend),
BackendKind::Auto => Box::new(CpuBackend::default()), // placeholder for Auto BackendKind::Auto => Box::new(CpuBackend),
} }
}; };
@@ -190,7 +163,7 @@ pub fn select_backend(requested: BackendKind, verbose: bool) -> Result<Selection
} else { } else {
return Err(anyhow!( return Err(anyhow!(
"Requested CUDA backend but CUDA libraries/devices not detected. How to fix: install NVIDIA driver + CUDA toolkit, ensure libcudart/libcublas are in loader path, and build with --features gpu-cuda." "Requested CUDA backend but CUDA libraries/devices not detected. How to fix: install NVIDIA driver + CUDA toolkit, ensure libcudart/libcublas are in loader path, and build with --features gpu-cuda."
)); ).into());
} }
} }
BackendKind::Hip => { BackendKind::Hip => {
@@ -199,7 +172,7 @@ pub fn select_backend(requested: BackendKind, verbose: bool) -> Result<Selection
} else { } else {
return Err(anyhow!( return Err(anyhow!(
"Requested ROCm/HIP backend but libraries/devices not detected. How to fix: install ROCm hipBLAS/rocBLAS, ensure libs are in loader path, and build with --features gpu-hip." "Requested ROCm/HIP backend but libraries/devices not detected. How to fix: install ROCm hipBLAS/rocBLAS, ensure libs are in loader path, and build with --features gpu-hip."
)); ).into());
} }
} }
BackendKind::Vulkan => { BackendKind::Vulkan => {
@@ -208,7 +181,7 @@ pub fn select_backend(requested: BackendKind, verbose: bool) -> Result<Selection
} else { } else {
return Err(anyhow!( return Err(anyhow!(
"Requested Vulkan backend but libvulkan not detected. How to fix: install Vulkan loader/SDK and build with --features gpu-vulkan." "Requested Vulkan backend but libvulkan not detected. How to fix: install Vulkan loader/SDK and build with --features gpu-vulkan."
)); ).into());
} }
} }
BackendKind::Cpu => BackendKind::Cpu, BackendKind::Cpu => BackendKind::Cpu,
@@ -219,14 +192,13 @@ pub fn select_backend(requested: BackendKind, verbose: bool) -> Result<Selection
crate::dlog!(1, "Selected backend: {:?}", chosen); crate::dlog!(1, "Selected backend: {:?}", chosen);
} }
Ok(SelectionResult { Ok(BackendSelection {
backend: instantiate_backend(chosen), backend: instantiate_backend(chosen),
chosen, chosen,
detected, detected,
}) })
} }
// Internal helper: transcription using whisper-rs with CPU/GPU (depending on build features)
#[allow(clippy::too_many_arguments)] #[allow(clippy::too_many_arguments)]
pub(crate) fn transcribe_with_whisper_rs( pub(crate) fn transcribe_with_whisper_rs(
audio_path: &Path, audio_path: &Path,
@@ -235,7 +207,9 @@ pub(crate) fn transcribe_with_whisper_rs(
progress: Option<&(dyn Fn(i32) + Send + Sync)>, progress: Option<&(dyn Fn(i32) + Send + Sync)>,
) -> Result<Vec<OutputEntry>> { ) -> Result<Vec<OutputEntry>> {
let report = |p: i32| { let report = |p: i32| {
if let Some(cb) = progress { cb(p); } if let Some(cb) = progress {
cb(p);
}
}; };
report(0); report(0);
@@ -248,21 +222,21 @@ pub(crate) fn transcribe_with_whisper_rs(
.and_then(|s| s.to_str()) .and_then(|s| s.to_str())
.map(|s| s.contains(".en.") || s.ends_with(".en.bin")) .map(|s| s.contains(".en.") || s.ends_with(".en.bin"))
.unwrap_or(false); .unwrap_or(false);
if let Some(lang) = language { if let Some(lang) = language
if english_only_model && lang != "en" { && english_only_model
return Err(anyhow!( && lang != "en"
"Selected model is English-only ({}), but a non-English language hint '{}' was provided. Please use a multilingual model or set WHISPER_MODEL.", {
model_path.display(), return Err(anyhow!(
lang "Selected model is English-only ({}), but a non-English language hint '{}' was provided. Please use a multilingual model or set WHISPER_MODEL.",
)); model_path.display(),
} lang
).into());
} }
let model_path_str = model_path let model_path_str = model_path
.to_str() .to_str()
.ok_or_else(|| anyhow!("Model path not valid UTF-8: {}", model_path.display()))?; .ok_or_else(|| anyhow!("Model path not valid UTF-8: {}", model_path.display()))?;
if crate::verbose_level() < 2 { if crate::verbose_level() < 2 {
// Some builds of whisper/ggml expect these env vars; harmless if unknown
unsafe { unsafe {
std::env::set_var("GGML_LOG_LEVEL", "0"); std::env::set_var("GGML_LOG_LEVEL", "0");
std::env::set_var("WHISPER_PRINT_PROGRESS", "0"); std::env::set_var("WHISPER_PRINT_PROGRESS", "0");

View File

@@ -1,108 +1,104 @@
use crate::prelude::*; // SPDX-License-Identifier: MIT
use directories::ProjectDirs;
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
use std::{fs, path::PathBuf}; use std::env;
use std::path::PathBuf;
const ENV_PREFIX: &str = "POLYSCRIBE";
/// Configuration for the Polyscribe application
///
/// Contains paths to models and plugins directories that can be customized
/// through configuration files or environment variables.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Config {
/// Directory path where ML models are stored
pub models_dir: Option<PathBuf>,
/// Directory path where plugins are stored
pub plugins_dir: Option<PathBuf>,
}
impl Default for Config {
fn default() -> Self {
Self {
models_dir: None,
plugins_dir: None,
}
}
}
/// Service for managing Polyscribe configuration
///
/// Provides functionality to load, save, and access configuration settings
/// from disk or environment variables.
pub struct ConfigService; pub struct ConfigService;
impl ConfigService { impl ConfigService {
/// Loads configuration from disk or returns default values if not found pub const ENV_NO_CACHE_MANIFEST: &'static str = "POLYSCRIBE_NO_CACHE_MANIFEST";
/// pub const ENV_MANIFEST_TTL_SECONDS: &'static str = "POLYSCRIBE_MANIFEST_TTL_SECONDS";
/// This function attempts to read the configuration file from disk. If the file pub const ENV_MODELS_DIR: &'static str = "POLYSCRIBE_MODELS_DIR";
/// doesn't exist or can't be parsed, it falls back to default values. pub const ENV_USER_AGENT: &'static str = "POLYSCRIBE_USER_AGENT";
/// Environment variable overrides are then applied to the configuration. pub const ENV_HTTP_TIMEOUT_SECS: &'static str = "POLYSCRIBE_HTTP_TIMEOUT_SECS";
pub fn load_or_default() -> Result<Config> { pub const ENV_HF_REPO: &'static str = "POLYSCRIBE_HF_REPO";
let mut cfg = Self::read_disk().unwrap_or_default(); pub const ENV_CACHE_FILENAME: &'static str = "POLYSCRIBE_MANIFEST_CACHE_FILENAME";
Self::apply_env_overrides(&mut cfg)?;
Ok(cfg) pub const DEFAULT_USER_AGENT: &'static str = "polyscribe/0.1";
pub const DEFAULT_DOWNLOADER_UA: &'static str = "polyscribe-model-downloader/1";
pub const DEFAULT_HF_REPO: &'static str = "ggerganov/whisper.cpp";
pub const DEFAULT_CACHE_FILENAME: &'static str = "hf_manifest_whisper_cpp.json";
pub const DEFAULT_HTTP_TIMEOUT_SECS: u64 = 8;
pub const DEFAULT_MANIFEST_CACHE_TTL_SECONDS: u64 = 24 * 60 * 60;
pub fn project_dirs() -> Option<directories::ProjectDirs> {
directories::ProjectDirs::from("dev", "polyscribe", "polyscribe")
} }
/// Saves the configuration to disk
///
/// This function serializes the configuration to TOML format and writes it
/// to the standard configuration directory for the application.
/// Returns an error if writing fails or if project directories cannot be determined.
pub fn save(cfg: &Config) -> Result<()> {
let Some(dirs) = Self::dirs() else {
return Err(Error::Other("unable to get project dirs".into()));
};
let cfg_dir = dirs.config_dir();
fs::create_dir_all(cfg_dir)?;
let path = cfg_dir.join("config.toml");
let s = toml::to_string_pretty(cfg)?;
fs::write(path, s)?;
Ok(())
}
fn read_disk() -> Option<Config> {
let dirs = Self::dirs()?;
let path = dirs.config_dir().join("config.toml");
let s = fs::read_to_string(path).ok()?;
toml::from_str(&s).ok()
}
fn apply_env_overrides(cfg: &mut Config) -> Result<()> {
// POLYSCRIBE__SECTION__KEY format reserved for future nested config.
if let Ok(v) = std::env::var(format!("{ENV_PREFIX}_MODELS_DIR")) {
cfg.models_dir = Some(PathBuf::from(v));
}
if let Ok(v) = std::env::var(format!("{ENV_PREFIX}_PLUGINS_DIR")) {
cfg.plugins_dir = Some(PathBuf::from(v));
}
Ok(())
}
/// Returns the standard project directories for the application
///
/// This function creates a ProjectDirs instance with the appropriate
/// organization and application names for Polyscribe.
/// Returns None if the project directories cannot be determined.
pub fn dirs() -> Option<ProjectDirs> {
ProjectDirs::from("dev", "polyscribe", "polyscribe")
}
/// Returns the default directory path for storing ML models
///
/// This function determines the standard data directory for the application
/// and appends a 'models' subdirectory to it.
/// Returns None if the project directories cannot be determined.
pub fn default_models_dir() -> Option<PathBuf> { pub fn default_models_dir() -> Option<PathBuf> {
Self::dirs().map(|d| d.data_dir().join("models")) Self::project_dirs().map(|d| d.data_dir().join("models"))
} }
/// Returns the default directory path for storing plugins
///
/// This function determines the standard data directory for the application
/// and appends a 'plugins' subdirectory to it.
/// Returns None if the project directories cannot be determined.
pub fn default_plugins_dir() -> Option<PathBuf> { pub fn default_plugins_dir() -> Option<PathBuf> {
Self::dirs().map(|d| d.data_dir().join("plugins")) Self::project_dirs().map(|d| d.data_dir().join("plugins"))
}
/// Directory used to cache the Hugging Face manifest, under the
/// platform cache dir (`None` if project dirs cannot be determined).
pub fn manifest_cache_dir() -> Option<PathBuf> {
    Self::project_dirs().map(|d| d.cache_dir().join("manifest"))
}
/// True when the manifest cache should be bypassed.
/// Triggered by the mere presence of ENV_NO_CACHE_MANIFEST — any value,
/// including an empty string, counts as "set".
pub fn bypass_manifest_cache() -> bool {
    env::var(Self::ENV_NO_CACHE_MANIFEST).is_ok()
}
/// TTL (seconds) for the cached manifest.
///
/// Reads ENV_MANIFEST_TTL_SECONDS; an unset or unparsable value falls
/// back to DEFAULT_MANIFEST_CACHE_TTL_SECONDS.
pub fn manifest_cache_ttl_seconds() -> u64 {
    match env::var(Self::ENV_MANIFEST_TTL_SECONDS) {
        Ok(raw) => raw
            .parse::<u64>()
            .unwrap_or(Self::DEFAULT_MANIFEST_CACHE_TTL_SECONDS),
        Err(_) => Self::DEFAULT_MANIFEST_CACHE_TTL_SECONDS,
    }
}
/// File name of the cached manifest, overridable via ENV_CACHE_FILENAME.
pub fn manifest_cache_filename() -> String {
    env::var(Self::ENV_CACHE_FILENAME)
        .unwrap_or_else(|_| Self::DEFAULT_CACHE_FILENAME.to_string())
}
/// Resolve the models directory.
///
/// Precedence: non-empty ENV_MODELS_DIR override, then the config file's
/// `models_dir` entry, then the platform default. Returns `None` only when
/// no source yields a path.
pub fn models_dir(cfg: Option<&Config>) -> Option<PathBuf> {
    env::var(Self::ENV_MODELS_DIR)
        .ok()
        // An empty env value is treated as unset, not as a valid path.
        .filter(|raw| !raw.is_empty())
        .map(PathBuf::from)
        .or_else(|| cfg.and_then(|c| c.models_dir.clone()))
        .or_else(Self::default_models_dir)
}
/// HTTP User-Agent for API requests; ENV_USER_AGENT overrides the default.
pub fn user_agent() -> String {
    env::var(Self::ENV_USER_AGENT).unwrap_or_else(|_| Self::DEFAULT_USER_AGENT.to_string())
}
/// HTTP User-Agent used when downloading model files.
/// NOTE(review): this reads ENV_USER_AGENT — the same variable as
/// `user_agent` — so one override changes both, and only the fallback
/// default differs. Confirm a dedicated env var was not intended.
pub fn downloader_user_agent() -> String {
    env::var(Self::ENV_USER_AGENT).unwrap_or_else(|_| Self::DEFAULT_DOWNLOADER_UA.to_string())
}
/// HTTP timeout in seconds; ENV_HTTP_TIMEOUT_SECS overrides the default.
/// Unset or unparsable values fall back to DEFAULT_HTTP_TIMEOUT_SECS.
pub fn http_timeout_secs() -> u64 {
    env::var(Self::ENV_HTTP_TIMEOUT_SECS)
        .ok()
        .and_then(|s| s.parse::<u64>().ok())
        .unwrap_or(Self::DEFAULT_HTTP_TIMEOUT_SECS)
}
/// Hugging Face repo slug (`owner/name`) to fetch models from;
/// ENV_HF_REPO overrides the default repo.
pub fn hf_repo() -> String {
    env::var(Self::ENV_HF_REPO).unwrap_or_else(|_| Self::DEFAULT_HF_REPO.to_string())
}
/// Build the Hugging Face model-API base URL for `repo`
/// (a slug such as "owner/name").
pub fn hf_api_base_for(repo: &str) -> String {
    let mut url = String::from("https://huggingface.co/api/models/");
    url.push_str(repo);
    url
}
/// Full path of the cached manifest file, or `None` when the cache
/// directory cannot be determined.
pub fn manifest_cache_path() -> Option<PathBuf> {
    let dir = Self::manifest_cache_dir()?;
    Some(dir.join(Self::manifest_cache_filename()))
}
} }
/// Persisted application configuration (stored as TOML on disk).
#[derive(Debug, Clone, Serialize, Deserialize, Default)]
pub struct Config {
    /// Models directory override; `None` falls back to the platform default.
    pub models_dir: Option<PathBuf>,
    /// Plugins directory override; `None` presumably falls back to the
    /// default plugins dir — confirm against the consumer of this field.
    pub plugins_dir: Option<PathBuf>,
}

View File

@@ -1,34 +1,26 @@
use thiserror::Error; use thiserror::Error;
#[derive(Debug, Error)] #[derive(Debug, Error)]
/// Error types for the polyscribe-core crate.
///
/// This enum represents various error conditions that can occur during
/// operations in this crate, including I/O errors, serialization/deserialization
/// errors, and environment variable access errors.
pub enum Error { pub enum Error {
#[error("I/O error: {0}")] #[error("I/O error: {0}")]
/// Represents an I/O error that occurred during file or stream operations
Io(#[from] std::io::Error), Io(#[from] std::io::Error),
#[error("serde error: {0}")] #[error("serde error: {0}")]
/// Represents a JSON serialization or deserialization error
Serde(#[from] serde_json::Error), Serde(#[from] serde_json::Error),
#[error("toml error: {0}")] #[error("toml error: {0}")]
/// Represents a TOML deserialization error
Toml(#[from] toml::de::Error), Toml(#[from] toml::de::Error),
#[error("toml ser error: {0}")] #[error("toml ser error: {0}")]
/// Represents a TOML serialization error
TomlSer(#[from] toml::ser::Error), TomlSer(#[from] toml::ser::Error),
#[error("env var error: {0}")] #[error("env var error: {0}")]
/// Represents an error that occurred during environment variable access
EnvVar(#[from] std::env::VarError), EnvVar(#[from] std::env::VarError),
#[error("http error: {0}")]
Http(#[from] reqwest::Error),
#[error("other: {0}")] #[error("other: {0}")]
/// Represents a general error condition with a custom message
Other(String), Other(String),
} }

View File

@@ -1,18 +1,13 @@
// SPDX-License-Identifier: MIT // SPDX-License-Identifier: MIT
// Copyright (c) 2025 <COPYRIGHT HOLDER>. All rights reserved.
#![forbid(elided_lifetimes_in_paths)] #![forbid(elided_lifetimes_in_paths)]
#![forbid(unused_must_use)] #![forbid(unused_must_use)]
#![deny(missing_docs)]
#![warn(clippy::all)] #![warn(clippy::all)]
//! PolyScribe library: business logic and core types.
//!
//! This crate exposes the reusable parts of the PolyScribe CLI as a library.
//! The binary entry point (main.rs) remains a thin CLI wrapper.
use std::sync::atomic::{AtomicBool, AtomicU8, Ordering}; use std::sync::atomic::{AtomicBool, AtomicU8, Ordering};
use anyhow::{anyhow, Context, Result}; use crate::prelude::*;
use anyhow::{Context, anyhow};
use chrono::Local; use chrono::Local;
use std::env; use std::env;
use std::path::{Path, PathBuf}; use std::path::{Path, PathBuf};
@@ -21,56 +16,44 @@ use std::process::Command;
#[cfg(unix)] #[cfg(unix)]
use libc::{O_WRONLY, close, dup, dup2, open}; use libc::{O_WRONLY, close, dup, dup2, open};
/// Global runtime flags
static QUIET: AtomicBool = AtomicBool::new(false); static QUIET: AtomicBool = AtomicBool::new(false);
static NO_INTERACTION: AtomicBool = AtomicBool::new(false); static NO_INTERACTION: AtomicBool = AtomicBool::new(false);
static VERBOSE: AtomicU8 = AtomicU8::new(0); static VERBOSE: AtomicU8 = AtomicU8::new(0);
static NO_PROGRESS: AtomicBool = AtomicBool::new(false); static NO_PROGRESS: AtomicBool = AtomicBool::new(false);
/// Set quiet mode: when true, non-interactive logs should be suppressed.
pub fn set_quiet(enabled: bool) { pub fn set_quiet(enabled: bool) {
QUIET.store(enabled, Ordering::Relaxed); QUIET.store(enabled, Ordering::Relaxed);
} }
/// Return current quiet mode state.
pub fn is_quiet() -> bool { pub fn is_quiet() -> bool {
QUIET.load(Ordering::Relaxed) QUIET.load(Ordering::Relaxed)
} }
/// Set non-interactive mode: when true, interactive prompts must be skipped.
pub fn set_no_interaction(enabled: bool) { pub fn set_no_interaction(enabled: bool) {
NO_INTERACTION.store(enabled, Ordering::Relaxed); NO_INTERACTION.store(enabled, Ordering::Relaxed);
} }
/// Return current non-interactive state.
pub fn is_no_interaction() -> bool { pub fn is_no_interaction() -> bool {
NO_INTERACTION.load(Ordering::Relaxed) NO_INTERACTION.load(Ordering::Relaxed)
} }
/// Set verbose level (0 = normal, 1 = verbose, 2 = super-verbose)
pub fn set_verbose(level: u8) { pub fn set_verbose(level: u8) {
VERBOSE.store(level, Ordering::Relaxed); VERBOSE.store(level, Ordering::Relaxed);
} }
/// Get current verbose level.
pub fn verbose_level() -> u8 { pub fn verbose_level() -> u8 {
VERBOSE.load(Ordering::Relaxed) VERBOSE.load(Ordering::Relaxed)
} }
/// Disable interactive progress indicators (bars/spinners)
pub fn set_no_progress(enabled: bool) { pub fn set_no_progress(enabled: bool) {
NO_PROGRESS.store(enabled, Ordering::Relaxed); NO_PROGRESS.store(enabled, Ordering::Relaxed);
} }
/// Return current no-progress state
pub fn is_no_progress() -> bool { pub fn is_no_progress() -> bool {
NO_PROGRESS.load(Ordering::Relaxed) NO_PROGRESS.load(Ordering::Relaxed)
} }
/// Check whether stdin is connected to a TTY. Used to avoid blocking prompts when not interactive.
pub fn stdin_is_tty() -> bool { pub fn stdin_is_tty() -> bool {
use std::io::IsTerminal as _; use std::io::IsTerminal as _;
std::io::stdin().is_terminal() std::io::stdin().is_terminal()
} }
/// A guard that temporarily redirects stderr to /dev/null on Unix when quiet mode is active.
/// No-op on non-Unix or when quiet is disabled. Restores stderr on drop.
pub struct StderrSilencer { pub struct StderrSilencer {
#[cfg(unix)] #[cfg(unix)]
old_stderr_fd: i32, old_stderr_fd: i32,
@@ -80,7 +63,6 @@ pub struct StderrSilencer {
} }
impl StderrSilencer { impl StderrSilencer {
/// Activate stderr silencing if quiet is set and on Unix; otherwise returns a no-op guard.
pub fn activate_if_quiet() -> Self { pub fn activate_if_quiet() -> Self {
if !is_quiet() { if !is_quiet() {
return Self { return Self {
@@ -94,7 +76,6 @@ impl StderrSilencer {
Self::activate() Self::activate()
} }
/// Activate stderr silencing unconditionally (used internally); no-op on non-Unix.
pub fn activate() -> Self { pub fn activate() -> Self {
#[cfg(unix)] #[cfg(unix)]
unsafe { unsafe {
@@ -106,7 +87,6 @@ impl StderrSilencer {
devnull_fd: -1, devnull_fd: -1,
}; };
} }
// Open /dev/null for writing
let devnull_cstr = std::ffi::CString::new("/dev/null").unwrap(); let devnull_cstr = std::ffi::CString::new("/dev/null").unwrap();
let devnull_fd = open(devnull_cstr.as_ptr(), O_WRONLY); let devnull_fd = open(devnull_cstr.as_ptr(), O_WRONLY);
if devnull_fd < 0 { if devnull_fd < 0 {
@@ -153,7 +133,6 @@ impl Drop for StderrSilencer {
} }
} }
/// Run the given closure with stderr temporarily silenced (Unix-only). Returns the closure result.
pub fn with_suppressed_stderr<F, T>(f: F) -> T pub fn with_suppressed_stderr<F, T>(f: F) -> T
where where
F: FnOnce() -> T, F: FnOnce() -> T,
@@ -164,13 +143,11 @@ where
result result
} }
/// Log an error line (always printed).
#[macro_export] #[macro_export]
macro_rules! elog { macro_rules! elog {
($($arg:tt)*) => {{ $crate::ui::error(format!($($arg)*)); }} ($($arg:tt)*) => {{ $crate::ui::error(format!($($arg)*)); }}
} }
/// Log an informational line using the UI helper unless quiet mode is enabled.
#[macro_export] #[macro_export]
macro_rules! ilog { macro_rules! ilog {
($($arg:tt)*) => {{ ($($arg:tt)*) => {{
@@ -178,7 +155,6 @@ macro_rules! ilog {
}} }}
} }
/// Log a debug/trace line when verbose level is at least the given level (u8).
#[macro_export] #[macro_export]
macro_rules! dlog { macro_rules! dlog {
($lvl:expr, $($arg:tt)*) => {{ ($lvl:expr, $($arg:tt)*) => {{
@@ -186,44 +162,28 @@ macro_rules! dlog {
}} }}
} }
/// Backward-compatibility: map old qlog! to ilog!
#[macro_export]
macro_rules! qlog {
($($arg:tt)*) => {{ $crate::ilog!($($arg)*); }}
}
pub mod backend; pub mod backend;
pub mod models;
/// Configuration handling for PolyScribe
pub mod config; pub mod config;
// Use the file-backed ui.rs module, which also declares its own `progress` submodule. pub mod models;
pub mod ui;
/// Error definitions for the PolyScribe library
pub mod error; pub mod error;
pub mod ui;
pub use error::Error; pub use error::Error;
pub mod prelude; pub mod prelude;
/// Transcript entry for a single segment.
#[derive(Debug, serde::Serialize, Clone)] #[derive(Debug, serde::Serialize, Clone)]
pub struct OutputEntry { pub struct OutputEntry {
/// Sequential id in output ordering.
pub id: u64, pub id: u64,
/// Speaker label associated with the segment.
pub speaker: String, pub speaker: String,
/// Start time in seconds.
pub start: f64, pub start: f64,
/// End time in seconds.
pub end: f64, pub end: f64,
/// Text content.
pub text: String, pub text: String,
} }
/// Return a YYYY-MM-DD date prefix string for output file naming.
pub fn date_prefix() -> String { pub fn date_prefix() -> String {
Local::now().format("%Y-%m-%d").to_string() Local::now().format("%Y-%m-%d").to_string()
} }
/// Format a floating-point number of seconds as SRT timestamp (HH:MM:SS,mmm).
pub fn format_srt_time(seconds: f64) -> String { pub fn format_srt_time(seconds: f64) -> String {
let total_ms = (seconds * 1000.0).round() as i64; let total_ms = (seconds * 1000.0).round() as i64;
let ms = total_ms % 1000; let ms = total_ms % 1000;
@@ -234,7 +194,6 @@ pub fn format_srt_time(seconds: f64) -> String {
format!("{hour:02}:{min:02}:{sec:02},{ms:03}") format!("{hour:02}:{min:02}:{sec:02},{ms:03}")
} }
/// Render a list of transcript entries to SRT format.
pub fn render_srt(entries: &[OutputEntry]) -> String { pub fn render_srt(entries: &[OutputEntry]) -> String {
let mut srt = String::new(); let mut srt = String::new();
for (index, entry) in entries.iter().enumerate() { for (index, entry) in entries.iter().enumerate() {
@@ -255,7 +214,8 @@ pub fn render_srt(entries: &[OutputEntry]) -> String {
srt srt
} }
/// Determine the default models directory, honoring POLYSCRIBE_MODELS_DIR override. pub mod model_manager;
pub fn models_dir_path() -> PathBuf { pub fn models_dir_path() -> PathBuf {
if let Ok(env_val) = env::var("POLYSCRIBE_MODELS_DIR") { if let Ok(env_val) = env::var("POLYSCRIBE_MODELS_DIR") {
let env_path = PathBuf::from(env_val); let env_path = PathBuf::from(env_val);
@@ -266,24 +226,23 @@ pub fn models_dir_path() -> PathBuf {
if cfg!(debug_assertions) { if cfg!(debug_assertions) {
return PathBuf::from("models"); return PathBuf::from("models");
} }
if let Ok(xdg) = env::var("XDG_DATA_HOME") { if let Ok(xdg) = env::var("XDG_DATA_HOME")
if !xdg.is_empty() { && !xdg.is_empty()
return PathBuf::from(xdg).join("polyscribe").join("models"); {
} return PathBuf::from(xdg).join("polyscribe").join("models");
} }
if let Ok(home) = env::var("HOME") { if let Ok(home) = env::var("HOME")
if !home.is_empty() { && !home.is_empty()
return PathBuf::from(home) {
.join(".local") return PathBuf::from(home)
.join("share") .join(".local")
.join("polyscribe") .join("share")
.join("models"); .join("polyscribe")
} .join("models");
} }
PathBuf::from("models") PathBuf::from("models")
} }
/// Normalize a language identifier to a short ISO code when possible.
pub fn normalize_lang_code(input: &str) -> Option<String> { pub fn normalize_lang_code(input: &str) -> Option<String> {
let mut lang = input.trim().to_lowercase(); let mut lang = input.trim().to_lowercase();
if lang.is_empty() || lang == "auto" || lang == "c" || lang == "posix" { if lang.is_empty() || lang == "auto" || lang == "c" || lang == "posix" {
@@ -355,47 +314,48 @@ pub fn normalize_lang_code(input: &str) -> Option<String> {
Some(code.to_string()) Some(code.to_string())
} }
/// Find the Whisper model file path to use.
pub fn find_model_file() -> Result<PathBuf> { pub fn find_model_file() -> Result<PathBuf> {
// 1) Explicit override via environment
if let Ok(path) = env::var("WHISPER_MODEL") { if let Ok(path) = env::var("WHISPER_MODEL") {
let p = PathBuf::from(path); let p = PathBuf::from(path);
if !p.exists() { if !p.exists() {
return Err(anyhow!( return Err(anyhow!(
"WHISPER_MODEL points to a non-existing path: {}", "WHISPER_MODEL points to a non-existing path: {}",
p.display() p.display()
)); )
.into());
} }
if !p.is_file() { if !p.is_file() {
return Err(anyhow!( return Err(anyhow!(
"WHISPER_MODEL must point to a file, but is not: {}", "WHISPER_MODEL must point to a file, but is not: {}",
p.display() p.display()
)); )
.into());
} }
return Ok(p); return Ok(p);
} }
// 2) Resolve models directory and ensure it exists and is a directory
let models_dir = models_dir_path(); let models_dir = models_dir_path();
if models_dir.exists() && !models_dir.is_dir() { if models_dir.exists() && !models_dir.is_dir() {
return Err(anyhow!( return Err(anyhow!(
"Models path exists but is not a directory: {}", "Models path exists but is not a directory: {}",
models_dir.display() models_dir.display()
)); )
.into());
} }
std::fs::create_dir_all(&models_dir).with_context(|| { std::fs::create_dir_all(&models_dir).with_context(|| {
format!("Failed to ensure models dir exists: {}", models_dir.display()) format!(
"Failed to ensure models dir exists: {}",
models_dir.display()
)
})?; })?;
// 3) Gather candidate .bin files (regular files only), prefer largest
let mut candidates = Vec::new(); let mut candidates = Vec::new();
for entry in std::fs::read_dir(&models_dir).with_context(|| { for entry in std::fs::read_dir(&models_dir)
format!("Failed to read models dir: {}", models_dir.display()) .with_context(|| format!("Failed to read models dir: {}", models_dir.display()))?
})? { {
let entry = entry?; let entry = entry?;
let path = entry.path(); let path = entry.path();
// Only consider .bin files
let is_bin = path let is_bin = path
.extension() .extension()
.and_then(|s| s.to_str()) .and_then(|s| s.to_str())
@@ -404,7 +364,6 @@ pub fn find_model_file() -> Result<PathBuf> {
continue; continue;
} }
// Only consider regular files
let md = match std::fs::metadata(&path) { let md = match std::fs::metadata(&path) {
Ok(m) if m.is_file() => m, Ok(m) if m.is_file() => m,
_ => continue, _ => continue,
@@ -414,7 +373,6 @@ pub fn find_model_file() -> Result<PathBuf> {
} }
if candidates.is_empty() { if candidates.is_empty() {
// 4) Fallback to known tiny English model if present
let fallback = models_dir.join("ggml-tiny.en.bin"); let fallback = models_dir.join("ggml-tiny.en.bin");
if fallback.is_file() { if fallback.is_file() {
return Ok(fallback); return Ok(fallback);
@@ -423,7 +381,8 @@ pub fn find_model_file() -> Result<PathBuf> {
"No Whisper model files (*.bin) found in {}. \ "No Whisper model files (*.bin) found in {}. \
Please download a model or set WHISPER_MODEL.", Please download a model or set WHISPER_MODEL.",
models_dir.display() models_dir.display()
)); )
.into());
} }
candidates.sort_by_key(|(size, _)| *size); candidates.sort_by_key(|(size, _)| *size);
@@ -431,19 +390,16 @@ pub fn find_model_file() -> Result<PathBuf> {
Ok(path) Ok(path)
} }
/// Decode an audio file into PCM f32 samples using ffmpeg (ffmpeg executable required).
pub fn decode_audio_to_pcm_f32_ffmpeg(audio_path: &Path) -> Result<Vec<f32>> { pub fn decode_audio_to_pcm_f32_ffmpeg(audio_path: &Path) -> Result<Vec<f32>> {
let in_path = audio_path let in_path = audio_path
.to_str() .to_str()
.ok_or_else(|| anyhow!("Audio path must be valid UTF-8: {}", audio_path.display()))?; .ok_or_else(|| anyhow!("Audio path must be valid UTF-8: {}", audio_path.display()))?;
// Use a raw f32le file to match the -f f32le output format.
let tmp_raw = std::env::temp_dir().join("polyscribe_tmp_input.f32le"); let tmp_raw = std::env::temp_dir().join("polyscribe_tmp_input.f32le");
let tmp_raw_str = tmp_raw let tmp_raw_str = tmp_raw
.to_str() .to_str()
.ok_or_else(|| anyhow!("Temp path not valid UTF-8: {}", tmp_raw.display()))?; .ok_or_else(|| anyhow!("Temp path not valid UTF-8: {}", tmp_raw.display()))?;
// ffmpeg -i input -f f32le -ac 1 -ar 16000 -y /tmp/tmp.f32le
let status = Command::new("ffmpeg") let status = Command::new("ffmpeg")
.arg("-hide_banner") .arg("-hide_banner")
.arg("-loglevel") .arg("-loglevel")
@@ -465,21 +421,17 @@ pub fn decode_audio_to_pcm_f32_ffmpeg(audio_path: &Path) -> Result<Vec<f32>> {
return Err(anyhow!( return Err(anyhow!(
"ffmpeg exited with non-zero status when decoding {}", "ffmpeg exited with non-zero status when decoding {}",
in_path in_path
)); )
.into());
} }
let raw = std::fs::read(&tmp_raw) let raw = std::fs::read(&tmp_raw)
.with_context(|| format!("Failed to read temp PCM file: {}", tmp_raw.display()))?; .with_context(|| format!("Failed to read temp PCM file: {}", tmp_raw.display()))?;
// Best-effort cleanup of the temp file
let _ = std::fs::remove_file(&tmp_raw); let _ = std::fs::remove_file(&tmp_raw);
// Interpret raw bytes as f32 little-endian
if raw.len() % 4 != 0 { if raw.len() % 4 != 0 {
return Err(anyhow!( return Err(anyhow!("Decoded PCM file length not multiple of 4: {}", raw.len()).into());
"Decoded PCM file length not multiple of 4: {}",
raw.len()
));
} }
let mut samples = Vec::with_capacity(raw.len() / 4); let mut samples = Vec::with_capacity(raw.len() / 4);
for chunk in raw.chunks_exact(4) { for chunk in raw.chunks_exact(4) {

View File

@@ -0,0 +1,893 @@
// SPDX-License-Identifier: MIT
use crate::prelude::*;
use crate::ui::BytesProgress;
use anyhow::{anyhow, Context};
use chrono::{DateTime, Utc};
use reqwest::blocking::Client;
use reqwest::header::{
ACCEPT_RANGES, AUTHORIZATION, CONTENT_LENGTH, ETAG, IF_NONE_MATCH, LAST_MODIFIED, RANGE,
};
use serde::{Deserialize, Serialize};
use sha2::{Digest, Sha256};
use std::cmp::min;
use std::collections::BTreeMap;
use std::fs::{self, File, OpenOptions};
use std::io::{Read, Seek, SeekFrom, Write};
use std::path::{Path, PathBuf};
use std::sync::{mpsc, Arc, Mutex};
use std::thread;
use std::time::Duration;
/// Default size of each ranged-download chunk handed to a worker thread.
const DEFAULT_CHUNK_SIZE: u64 = 8 * 1024 * 1024; // 8 MiB
/// Manifest entry describing one installed (or previously installed) model file.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
pub struct ModelRecord {
    /// User-chosen short name; also the manifest map key.
    pub alias: String,
    /// Hugging Face repository id (e.g. `owner/repo`).
    pub repo: String,
    /// File name within the repo; also the file name in the local cache dir.
    pub file: String,
    pub revision: Option<String>, // ETag or commit hash
    /// Hex SHA-256 of the downloaded file, when known.
    pub sha256: Option<String>,
    /// Size in bytes as reported at download time.
    pub size_bytes: Option<u64>,
    /// Quantization token inferred from the file name (e.g. `Q4_0`).
    pub quant: Option<String>,
    /// When the file was (last) downloaded.
    pub installed_at: Option<DateTime<Utc>>,
    /// When the record was last touched by an operation.
    pub last_used: Option<DateTime<Utc>>,
}
/// On-disk manifest: alias -> record, persisted as pretty JSON at `Paths::config_path`.
#[derive(Debug, Clone, Serialize, Deserialize, Default)]
pub struct Manifest {
    pub models: BTreeMap<String, ModelRecord>, // key = alias
}
/// Tuning knobs for downloads.
#[derive(Debug, Clone)]
pub struct Settings {
    /// Number of parallel ranged-download worker threads.
    pub concurrency: usize,
    /// Optional throttle applied per fetched chunk (sleep after each range).
    pub limit_rate: Option<u64>, // bytes/sec
    /// Size of each ranged request.
    pub chunk_size: u64,
}
impl Default for Settings {
fn default() -> Self {
Self {
concurrency: 4,
limit_rate: None,
chunk_size: DEFAULT_CHUNK_SIZE,
}
}
}
/// Filesystem locations used by the model manager.
#[derive(Debug, Clone)]
pub struct Paths {
    pub cache_dir: PathBuf, // $XDG_CACHE_HOME/polyscribe/models
    pub config_path: PathBuf, // $XDG_CONFIG_HOME/polyscribe/models.json
}
impl Paths {
    /// Resolve the cache directory and manifest path.
    ///
    /// Precedence per location:
    /// 1. `POLYSCRIBE_CACHE_DIR` / `POLYSCRIBE_CONFIG_DIR` env overrides, each
    ///    honored independently when set and non-empty.
    /// 2. XDG / `$HOME` fallbacks via `default_cache_dir` / `default_config_path`.
    ///
    /// # Errors
    /// Currently infallible; returns `Result` to leave room for stricter validation.
    pub fn resolve() -> Result<Self> {
        // Bug fix: previously the POLYSCRIBE_CONFIG_DIR override was only
        // consulted inside the POLYSCRIBE_CACHE_DIR branch, so a config
        // override without a cache override was silently ignored. The two
        // overrides are now evaluated independently; an empty value is
        // treated as unset, matching the cache-dir handling.
        let cache_dir = match std::env::var("POLYSCRIBE_CACHE_DIR") {
            Ok(dir) if !dir.is_empty() => PathBuf::from(dir).join("models"),
            _ => default_cache_dir(),
        };
        let config_path = match std::env::var("POLYSCRIBE_CONFIG_DIR") {
            Ok(dir) if !dir.is_empty() => PathBuf::from(dir).join("models.json"),
            _ => default_config_path(),
        };
        Ok(Self {
            cache_dir,
            config_path,
        })
    }
}
/// Default cache directory for downloaded models:
/// `$XDG_CACHE_HOME/polyscribe/models`, else `$HOME/.cache/polyscribe/models`,
/// else a relative `models` directory. Empty env values are treated as unset.
fn default_cache_dir() -> PathBuf {
    // Read an env var as a directory, treating empty values as absent.
    fn env_dir(key: &str) -> Option<PathBuf> {
        std::env::var(key).ok().filter(|v| !v.is_empty()).map(PathBuf::from)
    }
    if let Some(base) = env_dir("XDG_CACHE_HOME") {
        return base.join("polyscribe").join("models");
    }
    if let Some(home) = env_dir("HOME") {
        return home.join(".cache").join("polyscribe").join("models");
    }
    PathBuf::from("models")
}
/// Default manifest location:
/// `$XDG_CONFIG_HOME/polyscribe/models.json`, else
/// `$HOME/.config/polyscribe/models.json`, else a relative `models.json`.
/// Empty env values are treated as unset.
fn default_config_path() -> PathBuf {
    // Read an env var as a directory, treating empty values as absent.
    fn env_dir(key: &str) -> Option<PathBuf> {
        std::env::var(key).ok().filter(|v| !v.is_empty()).map(PathBuf::from)
    }
    if let Some(base) = env_dir("XDG_CONFIG_HOME") {
        return base.join("polyscribe").join("models.json");
    }
    if let Some(home) = env_dir("HOME") {
        return home.join(".config").join("polyscribe").join("models.json");
    }
    PathBuf::from("models.json")
}
/// Minimal blocking HTTP abstraction used by `ModelManager`, so tests can
/// substitute an in-memory stub (see `tests::StubHttp`) for `ReqwestClient`.
pub trait HttpClient: Send + Sync {
    /// HEAD `url`; when `etag` is given it is sent as `If-None-Match` for revalidation.
    fn head(&self, url: &str, etag: Option<&str>) -> Result<HeadMeta>;
    /// GET the inclusive byte range `start..=end_inclusive`, returned in memory.
    fn get_range(&self, url: &str, start: u64, end_inclusive: u64) -> Result<Vec<u8>>;
    /// GET the whole resource, streaming the body into `writer`.
    fn get_whole_to(&self, url: &str, writer: &mut dyn Write) -> Result<()>;
    /// GET from byte offset `start` to the end, streaming into `writer`.
    fn get_from_to(&self, url: &str, start: u64, writer: &mut dyn Write) -> Result<()>;
}
/// `HttpClient` implementation backed by `reqwest::blocking`, with optional
/// bearer-token auth taken from the `HF_TOKEN` environment variable.
#[derive(Debug, Clone)]
pub struct ReqwestClient {
    client: Client,
    // HF_TOKEN value, attached as `Authorization: Bearer …` when present.
    token: Option<String>,
}
impl ReqwestClient {
pub fn new() -> Result<Self> {
let token = std::env::var("HF_TOKEN").ok().filter(|s| !s.is_empty());
let client = Client::builder()
.user_agent(crate::config::ConfigService::user_agent())
.build()?;
Ok(Self { client, token })
}
fn auth(&self, mut req: reqwest::blocking::RequestBuilder) -> reqwest::blocking::RequestBuilder {
if let Some(t) = &self.token {
req = req.header(AUTHORIZATION, format!("Bearer {}", t));
}
req
}
}
/// Subset of HEAD-response metadata the downloader cares about.
#[derive(Debug, Clone)]
pub struct HeadMeta {
    /// Content-Length, when the server reported one.
    pub len: Option<u64>,
    /// ETag with surrounding quotes stripped.
    pub etag: Option<String>,
    /// Raw Last-Modified header value.
    pub last_modified: Option<String>,
    /// True when `Accept-Ranges` advertises byte ranges.
    pub accept_ranges: bool,
    /// True when the server answered 304 to our `If-None-Match`.
    pub not_modified: bool,
    /// Raw HTTP status code.
    pub status: u16,
}
impl HttpClient for ReqwestClient {
    /// HEAD with optional `If-None-Match`; HTTP 304 maps to `not_modified = true`.
    fn head(&self, url: &str, etag: Option<&str>) -> Result<HeadMeta> {
        let mut req = self.client.head(url);
        if let Some(e) = etag {
            // Stored ETags are unquoted; the wire format requires quotes.
            req = req.header(IF_NONE_MATCH, format!("\"{}\"", e));
        }
        let resp = self.auth(req).send()?;
        let status = resp.status().as_u16();
        if status == 304 {
            // Not modified: echo the caller's ETag back without parsing headers.
            // accept_ranges is assumed true here since no headers are inspected.
            return Ok(HeadMeta {
                len: None,
                etag: etag.map(|s| s.to_string()),
                last_modified: None,
                accept_ranges: true,
                not_modified: true,
                status,
            });
        }
        // Parse Content-Length / ETag (unquoted) / Last-Modified / Accept-Ranges.
        let len = resp
            .headers()
            .get(CONTENT_LENGTH)
            .and_then(|v| v.to_str().ok())
            .and_then(|s| s.parse::<u64>().ok());
        let etag = resp
            .headers()
            .get(ETAG)
            .and_then(|v| v.to_str().ok())
            .map(|s| s.trim_matches('"').to_string());
        let last_modified = resp
            .headers()
            .get(LAST_MODIFIED)
            .and_then(|v| v.to_str().ok())
            .map(|s| s.to_string());
        let accept_ranges = resp
            .headers()
            .get(ACCEPT_RANGES)
            .and_then(|v| v.to_str().ok())
            .map(|s| s.to_ascii_lowercase().contains("bytes"))
            .unwrap_or(false);
        Ok(HeadMeta {
            len,
            etag,
            last_modified,
            accept_ranges,
            not_modified: false,
            status,
        })
    }

    /// Ranged GET of an inclusive byte span, buffered fully in memory.
    fn get_range(&self, url: &str, start: u64, end_inclusive: u64) -> Result<Vec<u8>> {
        let range_val = format!("bytes={}-{}", start, end_inclusive);
        let resp = self
            .auth(self.client.get(url))
            .header(RANGE, range_val)
            .send()?;
        // Accept any 2xx or 206 Partial Content.
        if !resp.status().is_success() && resp.status().as_u16() != 206 {
            return Err(anyhow!("HTTP {} for ranged GET", resp.status()).into());
        }
        let mut buf = Vec::new();
        let mut r = resp;
        r.copy_to(&mut buf)?;
        Ok(buf)
    }

    /// Plain GET, streaming the full body into `writer`.
    fn get_whole_to(&self, url: &str, writer: &mut dyn Write) -> Result<()> {
        let resp = self.auth(self.client.get(url)).send()?;
        if !resp.status().is_success() {
            return Err(anyhow!("HTTP {} for GET", resp.status()).into());
        }
        let mut r = resp;
        r.copy_to(writer)?;
        Ok(())
    }

    /// GET from `start` to end of resource; omits the Range header when start is 0.
    fn get_from_to(&self, url: &str, start: u64, writer: &mut dyn Write) -> Result<()> {
        let mut req = self.auth(self.client.get(url));
        if start > 0 {
            req = req.header(RANGE, format!("bytes={}-", start));
        }
        let resp = req.send()?;
        if !resp.status().is_success() && resp.status().as_u16() != 206 {
            return Err(anyhow!("HTTP {} for ranged GET from {}", resp.status(), start).into());
        }
        let mut r = resp;
        r.copy_to(writer)?;
        Ok(())
    }
}
/// Downloads and tracks model files in the cache directory, recording them in
/// a JSON manifest. Generic over the HTTP client so tests can stub network I/O.
pub struct ModelManager<C: HttpClient = ReqwestClient> {
    /// Resolved cache/config locations.
    pub paths: Paths,
    /// Download tuning (workers, rate limit, chunk size).
    pub settings: Settings,
    client: Arc<C>,
}
impl<C: HttpClient + 'static> ModelManager<C> {
    /// Build a manager around an explicit HTTP client (tests inject a stub here).
    pub fn new_with_client(client: C, settings: Settings) -> Result<Self> {
        Ok(Self {
            paths: Paths::resolve()?,
            settings,
            client: Arc::new(client),
        })
    }

    /// Build a manager using the client type's `Default` implementation.
    pub fn new(settings: Settings) -> Result<Self>
    where
        C: Default,
    {
        Ok(Self {
            paths: Paths::resolve()?,
            settings,
            client: Arc::new(C::default()),
        })
    }

    /// Read the JSON manifest; a missing file yields an empty manifest.
    fn load_manifest(&self) -> Result<Manifest> {
        let p = &self.paths.config_path;
        if !p.exists() {
            return Ok(Manifest::default());
        }
        let file = File::open(p).with_context(|| format!("open manifest {}", p.display()))?;
        let m: Manifest = serde_json::from_reader(file).context("parse manifest")?;
        Ok(m)
    }

    /// Persist the manifest atomically: write a `.json.tmp` sibling, then rename.
    fn save_manifest(&self, m: &Manifest) -> Result<()> {
        let p = &self.paths.config_path;
        if let Some(dir) = p.parent() {
            fs::create_dir_all(dir)
                .with_context(|| format!("create config dir {}", dir.display()))?;
        }
        let tmp = p.with_extension("json.tmp");
        let f = OpenOptions::new()
            .create(true)
            .write(true)
            .truncate(true)
            .open(&tmp)?;
        serde_json::to_writer_pretty(f, m).context("serialize manifest")?;
        fs::rename(&tmp, p).with_context(|| format!("atomic rename {} -> {}", tmp.display(), p.display()))?;
        Ok(())
    }

    /// Absolute path of a model file inside the cache directory.
    fn model_path(&self, file: &str) -> PathBuf {
        self.paths.cache_dir.join(file)
    }

    /// Stream `path` through SHA-256 in 64 KiB chunks; returns lowercase hex.
    fn compute_sha256(path: &Path) -> Result<String> {
        let mut f = File::open(path)?;
        let mut hasher = Sha256::new();
        let mut buf = [0u8; 64 * 1024];
        loop {
            let n = f.read(&mut buf)?;
            if n == 0 {
                break;
            }
            hasher.update(&buf[..n]);
        }
        Ok(format!("{:x}", hasher.finalize()))
    }

    /// List all manifest records (alias-sorted, since the map is a BTreeMap).
    pub fn ls(&self) -> Result<Vec<ModelRecord>> {
        let m = self.load_manifest()?;
        Ok(m.models.values().cloned().collect())
    }

    /// Remove a model by alias: delete its cached file (best effort) and its
    /// manifest entry. Returns `false` when the alias is unknown.
    pub fn rm(&self, alias: &str) -> Result<bool> {
        let mut m = self.load_manifest()?;
        if let Some(rec) = m.models.remove(alias) {
            let p = self.model_path(&rec.file);
            let _ = fs::remove_file(&p);
            self.save_manifest(&m)?;
            return Ok(true);
        }
        Ok(false)
    }

    /// Verify a model: `false` if the alias is unknown or the file is missing.
    /// With a recorded sha256, the file is re-hashed and compared; without one,
    /// mere presence counts as verified.
    pub fn verify(&self, alias: &str) -> Result<bool> {
        let m = self.load_manifest()?;
        let Some(rec) = m.models.get(alias) else { return Ok(false) };
        let p = self.model_path(&rec.file);
        if !p.exists() { return Ok(false); }
        if let Some(expected) = &rec.sha256 {
            let actual = Self::compute_sha256(&p)?;
            return Ok(&actual == expected);
        }
        Ok(true)
    }

    /// Garbage-collect the cache. Returns `(files_removed, entries_removed)`.
    pub fn gc(&self) -> Result<(usize, usize)> {
        // Remove files not referenced by manifest; also drop manifest entries whose file is missing
        fs::create_dir_all(&self.paths.cache_dir).ok();
        let mut m = self.load_manifest()?;
        // file name -> alias, snapshotted before any mutation.
        let mut referenced = BTreeMap::new();
        for (alias, rec) in &m.models {
            referenced.insert(rec.file.clone(), alias.clone());
        }
        let mut removed_files = 0usize;
        if let Ok(rd) = fs::read_dir(&self.paths.cache_dir) {
            for ent in rd.flatten() {
                let p = ent.path();
                if p.is_file() {
                    let fname = p.file_name().and_then(|s| s.to_str()).unwrap_or("");
                    if !referenced.contains_key(fname) {
                        let _ = fs::remove_file(&p);
                        removed_files += 1;
                    }
                }
            }
        }
        // Drop entries whose backing file vanished; the count is taken from the
        // pre-retain snapshot so it matches what retain removes (only
        // unreferenced files were deleted above, so the snapshot stays valid).
        m.models.retain(|_, rec| self.model_path(&rec.file).exists());
        let removed_entries = referenced
            .keys()
            .filter(|f| !self.model_path(f).exists())
            .count();
        self.save_manifest(&m)?;
        Ok((removed_files, removed_entries))
    }

    /// Download (or revalidate) `file` from the Hugging Face `repo` and record
    /// it under `alias`.
    ///
    /// Flow: HEAD with `If-None-Match` first — on 304 the existing record is
    /// just touched. Otherwise the file is fetched into a `.part` temp file
    /// (concurrent ranged GETs when supported, streaming GET as fallback),
    /// hash-verified against HF API metadata when available, and atomically
    /// renamed into place before the manifest is updated.
    pub fn add_or_update(
        &self,
        alias: &str,
        repo: &str,
        file: &str,
    ) -> Result<ModelRecord> {
        fs::create_dir_all(&self.paths.cache_dir)
            .with_context(|| format!("create cache dir {}", self.paths.cache_dir.display()))?;
        let url = format!("https://huggingface.co/{}/resolve/main/{}?download=true", repo, file);
        let mut manifest = self.load_manifest()?;
        let prev = manifest.models.get(alias).cloned();
        let prev_etag = prev.as_ref().and_then(|r| r.revision.clone());
        // Fetch remote meta (size/hash) when available to verify the download
        let (_api_size, api_sha) = hf_fetch_file_meta(repo, file).unwrap_or((None, None));
        let head = self.client.head(&url, prev_etag.as_deref())?;
        if head.not_modified {
            // up-to-date; ensure record present and touch last_used
            let mut rec = prev.ok_or_else(|| anyhow!("not installed yet but server says 304"))?;
            rec.last_used = Some(Utc::now());
            self.save_touch(&mut manifest, rec.clone())?;
            return Ok(rec);
        }
        // Quick check: if HEAD failed (e.g., 404), report a helpful error before attempting download
        if head.status >= 400 {
            return Err(anyhow!(
                "file not found or inaccessible: repo='{}' file='{}' (HTTP {})\nHint: run `polyscribe models search {} --query {}` to list available files",
                repo,
                file,
                head.status,
                repo,
                file
            ).into());
        }
        let total_len = head.len.ok_or_else(|| anyhow!("missing content-length (HEAD)"))?;
        let etag = head.etag.clone();
        let dest_tmp = self.model_path(&format!("{}.part", file));
        // If a previous cancelled download left a .part file, remove it to avoid clutter/resume.
        if dest_tmp.exists() { let _ = fs::remove_file(&dest_tmp); }
        // Guard to ensure .part is cleaned up on errors
        struct TempGuard { path: PathBuf, armed: bool }
        impl TempGuard { fn disarm(&mut self) { self.armed = false; } }
        impl Drop for TempGuard {
            fn drop(&mut self) {
                if self.armed {
                    let _ = fs::remove_file(&self.path);
                }
            }
        }
        let mut _tmp_guard = TempGuard { path: dest_tmp.clone(), armed: true };
        let dest_final = self.model_path(file);
        // Do not resume after cancellation; start fresh to avoid stale .part files
        let start_from = 0u64;
        // Open tmp for write
        let f = OpenOptions::new().create(true).write(true).read(true).open(&dest_tmp)?;
        f.set_len(total_len)?; // pre-allocate for random writes
        let f = Arc::new(Mutex::new(f));
        // Create progress bar
        let mut progress = BytesProgress::start(total_len, &format!("Downloading {}", file), start_from);
        // Create work chunks
        let chunk_size = self.settings.chunk_size;
        let mut chunks = Vec::new();
        let mut pos = start_from;
        while pos < total_len {
            let end = min(total_len - 1, pos + chunk_size - 1);
            chunks.push((pos, end));
            pos = end + 1;
        }
        // Attempt concurrent ranged download; on failure, fallback to streaming GET
        let mut ranged_failed = false;
        if head.accept_ranges && self.settings.concurrency > 1 {
            // One channel distributes (start, end) chunks to workers; a second
            // funnels completed byte counts back for progress reporting.
            let (work_tx, work_rx) = mpsc::channel::<(u64, u64)>();
            let (prog_tx, prog_rx) = mpsc::channel::<u64>();
            for ch in chunks {
                work_tx.send(ch).unwrap();
            }
            drop(work_tx);
            let rx = Arc::new(Mutex::new(work_rx));
            let workers = self.settings.concurrency.max(1);
            let mut handles = Vec::new();
            for _ in 0..workers {
                let rx = rx.clone();
                let url = url.clone();
                let client = self.client.clone();
                let f = f.clone();
                let limit = self.settings.limit_rate;
                let prog_tx = prog_tx.clone();
                let handle = thread::spawn(move || -> Result<()> {
                    loop {
                        // Take the receiver lock only long enough to claim one chunk.
                        let next = {
                            let guard = rx.lock().unwrap();
                            guard.recv().ok()
                        };
                        let Some((start, end)) = next else { break; };
                        let data = client.get_range(&url, start, end)?;
                        // Crude rate limiting: sleep proportionally to chunk size.
                        if let Some(max_bps) = limit {
                            let dur = Duration::from_secs_f64((data.len() as f64) / (max_bps as f64));
                            if dur > Duration::from_millis(1) {
                                thread::sleep(dur);
                            }
                        }
                        // Random-access write into the pre-allocated temp file.
                        let mut guard = f.lock().unwrap();
                        guard.seek(SeekFrom::Start(start))?;
                        guard.write_all(&data)?;
                        let _ = prog_tx.send(data.len() as u64);
                    }
                    Ok(())
                });
                handles.push(handle);
            }
            // Drop our clone so prog_rx terminates once all workers finish.
            drop(prog_tx);
            for delta in prog_rx {
                progress.inc(delta);
            }
            // NOTE(review): ranged_err is collected but never surfaced; only the
            // streaming fallback's error (if any) reaches the caller — confirm intended.
            let mut ranged_err: Option<anyhow::Error> = None;
            for h in handles {
                match h.join() {
                    Ok(Ok(())) => {}
                    Ok(Err(e)) => { ranged_failed = true; if ranged_err.is_none() { ranged_err = Some(e.into()); } }
                    Err(_) => { ranged_failed = true; if ranged_err.is_none() { ranged_err = Some(anyhow!("worker panicked")); } }
                }
            }
        } else {
            ranged_failed = true;
        }
        if ranged_failed {
            // Restart progress if we are abandoning previous partial state
            // NOTE(review): start_from is always 0 above, so this restart never
            // fires; bytes already counted by a partially-successful ranged pass
            // are not rolled back before the streaming pass re-counts from zero,
            // so the bar can overshoot — confirm whether that is acceptable.
            if start_from > 0 {
                progress.stop("retrying as streaming");
                progress = BytesProgress::start(total_len, &format!("Downloading {}", file), 0);
            }
            // Fallback to streaming GET; try URL with and without ?download=true
            let mut try_urls = vec![url.clone()];
            if let Some((base, _qs)) = url.split_once('?') {
                try_urls.push(base.to_string());
            } else {
                try_urls.push(format!("{}?download=true", url));
            }
            // Fresh temp file for streaming
            let mut wf = OpenOptions::new().create(true).write(true).truncate(true).open(&dest_tmp)?;
            let mut ok = false;
            let mut last_err: Option<anyhow::Error> = None;
            // Counting writer to update progress inline
            struct CountingWriter<'a, 'b> {
                inner: &'a mut File,
                progress: &'b mut BytesProgress,
            }
            impl<'a, 'b> Write for CountingWriter<'a, 'b> {
                fn write(&mut self, buf: &[u8]) -> std::io::Result<usize> {
                    let n = self.inner.write(buf)?;
                    self.progress.inc(n as u64);
                    Ok(n)
                }
                fn flush(&mut self) -> std::io::Result<()> { self.inner.flush() }
            }
            let mut cw = CountingWriter { inner: &mut wf, progress: &mut progress };
            for u in try_urls {
                // For fallback, stream from scratch to ensure integrity
                let res = self.client.get_whole_to(&u, &mut cw);
                match res {
                    Ok(()) => { ok = true; break; }
                    Err(e) => { last_err = Some(e.into()); }
                }
            }
            if !ok {
                if let Some(e) = last_err { return Err(e.into()); }
                return Err(anyhow!("download failed (ranged and streaming)").into());
            }
        }
        progress.stop("download complete");
        // Verify integrity
        let sha = Self::compute_sha256(&dest_tmp)?;
        if let Some(expected) = api_sha.as_ref() {
            // Prefer the hash published by the HF API when available.
            if &sha != expected {
                return Err(anyhow!(
                    "sha256 mismatch (expected {}, got {})",
                    expected,
                    sha
                ).into());
            }
        } else if prev.as_ref().map(|r| r.file.eq(file)).unwrap_or(false) {
            // Otherwise, for a re-download of the same file, compare against the
            // hash recorded at the previous install.
            if let Some(expected) = prev.as_ref().and_then(|r| r.sha256.as_ref()) {
                if &sha != expected {
                    return Err(anyhow!("sha256 mismatch").into());
                }
            }
        }
        // Atomic rename
        fs::rename(&dest_tmp, &dest_final).with_context(|| format!("rename {} -> {}", dest_tmp.display(), dest_final.display()))?;
        // Disarm guard; .part has been moved or cleaned
        _tmp_guard.disarm();
        let rec = ModelRecord {
            alias: alias.to_string(),
            repo: repo.to_string(),
            file: file.to_string(),
            revision: etag,
            sha256: Some(sha.clone()),
            size_bytes: Some(total_len),
            quant: infer_quant(file),
            installed_at: Some(Utc::now()),
            last_used: Some(Utc::now()),
        };
        self.save_touch(&mut manifest, rec.clone())?;
        Ok(rec)
    }

    /// Insert/replace `rec` in the manifest and persist it.
    fn save_touch(&self, manifest: &mut Manifest, rec: ModelRecord) -> Result<()> {
        manifest.models.insert(rec.alias.clone(), rec);
        self.save_manifest(manifest)
    }
}
/// Best-effort extraction of a quantization token (e.g. `Q4_0`, `Q5_K_M`)
/// from a model file name. Matching is case-insensitive and the returned
/// token is upper-cased; returns `None` when no `Q<digit>…` token is found.
fn infer_quant(file: &str) -> Option<String> {
    let upper = file.to_ascii_uppercase();
    let bytes = upper.as_bytes();
    // Bug fix: anchor on a 'Q' that is immediately followed by a digit.
    // The previous version matched the first 'Q' anywhere, so a name like
    // "qwen-7b-q4_k_m.gguf" yielded "QWEN-7B-Q4_K_M" instead of "Q4_K_M".
    for (pos, win) in bytes.windows(2).enumerate() {
        if win[0] == b'Q' && win[1].is_ascii_digit() {
            // 'Q' is ASCII, so `pos` is always a char boundary in `upper`.
            let token: String = upper[pos..]
                .chars()
                .take_while(|c| c.is_ascii_uppercase() || c.is_ascii_digit() || *c == '_' || *c == '-')
                .collect();
            if token.len() >= 2 {
                return Some(token);
            }
        }
    }
    None
}
impl Default for ReqwestClient {
    /// Like [`ReqwestClient::new`], but panics when the client cannot be built;
    /// prefer `new()` where the error can be handled.
    fn default() -> Self {
        let built = Self::new();
        built.expect("reqwest client")
    }
}
// Hugging Face API types for file metadata
/// LFS pointer metadata attached to a file entry.
#[derive(Debug, Deserialize)]
struct ApiHfLfs {
    // Typically "sha256:<hex>" (see `pick_sha_from_file`).
    oid: Option<String>,
    size: Option<u64>,
    sha256: Option<String>,
}
/// One file entry (`siblings`/`files` element) from the HF model-info API.
#[derive(Debug, Deserialize)]
struct ApiHfFile {
    // Repo-relative path; may contain directory components.
    rfilename: String,
    size: Option<u64>,
    sha256: Option<String>,
    lfs: Option<ApiHfLfs>,
}
/// Top-level model-info payload; which list is populated depends on the
/// endpoint variant queried (plain vs `?expand=files`).
#[derive(Debug, Deserialize)]
struct ApiHfModelInfo {
    siblings: Option<Vec<ApiHfFile>>,
    files: Option<Vec<ApiHfFile>>,
}
/// Pick a sha256 digest for a repo file entry: the top-level `sha256` field
/// wins, then the LFS `sha256`, then an LFS `oid` of the form `sha256:<hex>`.
/// Returns `None` when none of these is available.
fn pick_sha_from_file(f: &ApiHfFile) -> Option<String> {
    f.sha256
        .clone()
        .or_else(|| f.lfs.as_ref().and_then(|l| l.sha256.clone()))
        .or_else(|| {
            f.lfs
                .as_ref()
                .and_then(|l| l.oid.as_ref())
                .and_then(|oid| oid.strip_prefix("sha256:"))
                .map(str::to_string)
        })
}
/// Query the Hugging Face model-info API for `repo` and return `(size, sha256)`
/// of the entry whose basename matches `target` (case-insensitive). Tries the
/// plain API URL, then `?expand=files`; errors if neither listing has the file.
fn hf_fetch_file_meta(repo: &str, target: &str) -> Result<(Option<u64>, Option<String>)> {
    let token = std::env::var("HF_TOKEN").ok().filter(|s| !s.is_empty());
    let client = Client::builder()
        .user_agent(crate::config::ConfigService::user_agent())
        .build()?;
    let base = crate::config::ConfigService::hf_api_base_for(repo);
    let urls = [base.clone(), format!("{}?expand=files", base)];
    for url in urls {
        let mut req = client.get(&url);
        if let Some(t) = &token { req = req.header(AUTHORIZATION, format!("Bearer {}", t)); }
        let resp = req.send()?;
        // A non-success status is skipped, not fatal: the next variant may work.
        if !resp.status().is_success() { continue; }
        let info: ApiHfModelInfo = resp.json()?;
        // Prefer `files`; fall back to `siblings` (endpoint-dependent).
        let list = info.files.or(info.siblings).unwrap_or_default();
        for f in list {
            // Compare basenames so nested rfilenames still match.
            let name = f.rfilename.rsplit('/').next().unwrap_or(&f.rfilename);
            if name.eq_ignore_ascii_case(target) {
                let sz = f.size.or_else(|| f.lfs.as_ref().and_then(|l| l.size));
                let sha = pick_sha_from_file(&f);
                return Ok((sz, sha));
            }
        }
    }
    Err(anyhow!("file not found in HF API").into())
}
/// Fetch remote metadata (size, sha256) for a single file in a HF repo.
/// Public wrapper over the private `hf_fetch_file_meta`; errors when the file
/// is not listed by the Hugging Face API. Either tuple element may be `None`
/// when the API omits that field.
pub fn fetch_file_meta(repo: &str, file: &str) -> Result<(Option<u64>, Option<String>)> {
    hf_fetch_file_meta(repo, file)
}
/// Search a Hugging Face repo for GGUF/BIN files via the model-info API,
/// returning deduplicated, sorted basenames. `query` filters the result by
/// case-insensitive substring match.
pub fn search_repo(repo: &str, query: Option<&str>) -> Result<Vec<String>> {
    let token = std::env::var("HF_TOKEN").ok().filter(|s| !s.is_empty());
    let client = Client::builder()
        .user_agent(crate::config::ConfigService::user_agent())
        .build()?;
    let base = crate::config::ConfigService::hf_api_base_for(repo);

    let mut files: Vec<String> = Vec::new();
    // Try the plain endpoint first, then the expanded-files variant; stop at
    // the first one that yields any model files.
    for url in [base.clone(), format!("{}?expand=files", base)] {
        let mut req = client.get(&url);
        if let Some(t) = &token {
            req = req.header(AUTHORIZATION, format!("Bearer {}", t));
        }
        let resp = req.send()?;
        if !resp.status().is_success() {
            continue;
        }
        let info: ApiHfModelInfo = resp.json()?;
        for f in info.files.or(info.siblings).unwrap_or_default() {
            let is_model = f.rfilename.ends_with(".gguf") || f.rfilename.ends_with(".bin");
            if !is_model {
                continue;
            }
            let name = f.rfilename.rsplit('/').next().unwrap_or(&f.rfilename).to_string();
            if !files.contains(&name) {
                files.push(name);
            }
        }
        if !files.is_empty() {
            break;
        }
    }

    if let Some(q) = query {
        let needle = q.to_ascii_lowercase();
        files.retain(|f| f.to_ascii_lowercase().contains(&needle));
    }
    files.sort();
    Ok(files)
}
/// List repo files with optional size metadata for GGUF/BIN entries.
/// Queries the model-info API (plain URL, then `?expand=files`) and returns
/// `(basename, size)` pairs from the first variant that lists any; an empty
/// Vec means nothing usable was found.
pub fn list_repo_files_with_meta(repo: &str) -> Result<Vec<(String, Option<u64>)>> {
    let token = std::env::var("HF_TOKEN").ok().filter(|s| !s.is_empty());
    let client = Client::builder()
        .user_agent(crate::config::ConfigService::user_agent())
        .build()?;
    let base = crate::config::ConfigService::hf_api_base_for(repo);
    for url in [base.clone(), format!("{}?expand=files", base)] {
        let mut req = client.get(&url);
        if let Some(t) = &token { req = req.header(AUTHORIZATION, format!("Bearer {}", t)); }
        let resp = req.send()?;
        // A non-success status is skipped, not fatal: the next variant may work.
        if !resp.status().is_success() { continue; }
        let info: ApiHfModelInfo = resp.json()?;
        // Prefer `files`; fall back to `siblings` (endpoint-dependent).
        let list = info.files.or(info.siblings).unwrap_or_default();
        let mut out = Vec::new();
        for f in list {
            // Only model weight files are of interest here.
            if !(f.rfilename.ends_with(".gguf") || f.rfilename.ends_with(".bin")) { continue; }
            let name = f.rfilename.rsplit('/').next().unwrap_or(&f.rfilename).to_string();
            // Size may live on the entry itself or on its LFS pointer.
            let size = f.size.or_else(|| f.lfs.as_ref().and_then(|l| l.size));
            out.push((name, size));
        }
        if !out.is_empty() { return Ok(out); }
    }
    Ok(Vec::new())
}
/// Fallback: HEAD request for a single file to retrieve Content-Length (size).
/// Tries the `?download=true` resolve URL first, then the plain one; returns
/// `None` on any build/network/HTTP failure or when the header is absent.
pub fn head_len_for_file(repo: &str, file: &str) -> Option<u64> {
    let token = std::env::var("HF_TOKEN").ok().filter(|s| !s.is_empty());
    let client = Client::builder()
        .user_agent(crate::config::ConfigService::user_agent())
        .build()
        .ok()?;
    let candidates = [
        format!("https://huggingface.co/{}/resolve/main/{}?download=true", repo, file),
        format!("https://huggingface.co/{}/resolve/main/{}", repo, file),
    ];
    for url in candidates {
        let mut req = client.head(&url);
        if let Some(t) = &token {
            req = req.header(AUTHORIZATION, format!("Bearer {}", t));
        }
        let Ok(resp) = req.send() else { continue };
        if !resp.status().is_success() {
            continue;
        }
        let len = resp
            .headers()
            .get(CONTENT_LENGTH)
            .and_then(|v| v.to_str().ok())
            .and_then(|s| s.parse::<u64>().ok());
        if len.is_some() {
            return len;
        }
    }
    None
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::env;
    use tempfile::TempDir;

    /// In-memory `HttpClient` stub serving a fixed byte buffer with a fixed ETag.
    #[derive(Clone)]
    struct StubHttp {
        data: Arc<Vec<u8>>,
        etag: Arc<Option<String>>,
        accept_ranges: bool,
    }

    impl HttpClient for StubHttp {
        fn head(&self, _url: &str, etag: Option<&str>) -> Result<HeadMeta> {
            // Simulate revalidation: answer 304 when the caller's ETag matches ours.
            let not_modified = etag.is_some() && self.etag.as_ref().as_deref() == etag;
            Ok(HeadMeta {
                len: Some(self.data.len() as u64),
                etag: self.etag.as_ref().clone(),
                last_modified: None,
                accept_ranges: self.accept_ranges,
                not_modified,
                status: if not_modified { 304 } else { 200 },
            })
        }
        fn get_range(&self, _url: &str, start: u64, end_inclusive: u64) -> Result<Vec<u8>> {
            let s = start as usize;
            // Range requests are inclusive; convert to an exclusive slice end.
            let e = (end_inclusive as usize) + 1;
            Ok(self.data[s..e].to_vec())
        }
        fn get_whole_to(&self, _url: &str, writer: &mut dyn Write) -> Result<()> {
            writer.write_all(&self.data)?;
            Ok(())
        }
        fn get_from_to(&self, _url: &str, start: u64, writer: &mut dyn Write) -> Result<()> {
            let s = start as usize;
            writer.write_all(&self.data[s..])?;
            Ok(())
        }
    }

    /// Point `Paths::resolve` at temp dirs via the env overrides.
    // NOTE(review): env vars are process-global; this assumes these tests do not
    // run concurrently with others reading the same variables — confirm.
    fn setup_env(cache: &Path, cfg: &Path) {
        unsafe {
            env::set_var("POLYSCRIBE_CACHE_DIR", cache.to_string_lossy().to_string());
            env::set_var(
                "POLYSCRIBE_CONFIG_DIR",
                cfg.parent().unwrap().to_string_lossy().to_string(),
            );
        }
    }

    /// A fresh manifest loads as empty, and a saved record round-trips to disk.
    #[test]
    fn test_manifest_roundtrip() {
        let temp = TempDir::new().unwrap();
        let cache = temp.path().join("cache");
        let cfg = temp.path().join("config").join("models.json");
        setup_env(&cache, &cfg);
        let client = StubHttp {
            data: Arc::new(vec![0u8; 1024]),
            etag: Arc::new(Some("etag123".into())),
            accept_ranges: true,
        };
        let mm: ModelManager<StubHttp> = ModelManager::new_with_client(client, Settings::default()).unwrap();
        let m = mm.load_manifest().unwrap();
        assert!(m.models.is_empty());
        let rec = ModelRecord {
            alias: "tiny".into(),
            repo: "foo/bar".into(),
            file: "gguf-tiny.bin".into(),
            revision: Some("etag123".into()),
            sha256: None,
            size_bytes: None,
            quant: None,
            installed_at: None,
            last_used: None,
        };
        let mut m2 = Manifest::default();
        mm.save_touch(&mut m2, rec.clone()).unwrap();
        let m3 = mm.load_manifest().unwrap();
        assert!(m3.models.contains_key("tiny"));
    }

    /// End-to-end against the stub: ranged download, verify, 304 refresh, gc, rm.
    #[test]
    fn test_add_verify_update_gc() {
        let temp = TempDir::new().unwrap();
        let cache = temp.path().join("cache");
        let cfg_dir = temp.path().join("config");
        let cfg = cfg_dir.join("models.json");
        setup_env(&cache, &cfg);
        // 4 MiB of deterministic data so multiple chunks are exercised.
        let data = (0..1024 * 1024u32).flat_map(|i| i.to_le_bytes()).collect::<Vec<u8>>();
        let etag = Some("abc123".to_string());
        let client = StubHttp { data: Arc::new(data), etag: Arc::new(etag), accept_ranges: true };
        let mm: ModelManager<StubHttp> = ModelManager::new_with_client(client.clone(), Settings{ concurrency: 3, ..Default::default() }).unwrap();
        // add
        let rec = mm.add_or_update("tiny", "gguf/models", "gguf-tiny-q4_0.bin").unwrap();
        assert_eq!(rec.alias, "tiny");
        assert!(mm.verify("tiny").unwrap());
        // update (304)
        let rec2 = mm.add_or_update("tiny", "gguf/models", "gguf-tiny-q4_0.bin").unwrap();
        assert_eq!(rec2.alias, "tiny");
        // gc (nothing to remove)
        let (files_removed, entries_removed) = mm.gc().unwrap();
        assert_eq!(files_removed, 0);
        assert_eq!(entries_removed, 0);
        // rm
        assert!(mm.rm("tiny").unwrap());
        assert!(!mm.rm("tiny").unwrap());
    }
}

File diff suppressed because it is too large Load Diff

View File

@@ -1,16 +1,7 @@
// rust
//! Commonly used exports for convenient glob-imports in binaries and tests.
//! Usage: `use polyscribe_core::prelude::*;`
pub use crate::backend::*; pub use crate::backend::*;
pub use crate::config::*; pub use crate::config::*;
pub use crate::error::Error; pub use crate::error::Error;
pub use crate::models::*; pub use crate::models::*;
// If you frequently use UI helpers across binaries/tests, export them too.
// Keep this lean to avoid pulling UI everywhere unintentionally.
#[allow(unused_imports)]
pub use crate::ui::*; pub use crate::ui::*;
/// A convenient alias for `std::result::Result` with the error type defaulting to [`Error`].
pub type Result<T, E = Error> = std::result::Result<T, E>; pub type Result<T, E = Error> = std::result::Result<T, E>;

View File

@@ -1,64 +1,329 @@
// SPDX-License-Identifier: MIT // SPDX-License-Identifier: MIT
// Copyright (c) 2025 <COPYRIGHT HOLDER>. All rights reserved.
//! Minimal UI helpers used across the core crate.
//! This keeps interactive bits centralized and easy to stub in tests.
/// Progress indicators and reporting tools for displaying task completion.
pub mod progress; pub mod progress;
use std::io::{self, Write}; use std::io;
use std::io::IsTerminal;
use std::io::Write as _;
use std::time::{Duration, Instant};
/// Print an informational line to stderr (suppressed when quiet mode is enabled by callers).
pub fn info(msg: impl AsRef<str>) { pub fn info(msg: impl AsRef<str>) {
eprintln!("{}", msg.as_ref()); let m = msg.as_ref();
let _ = cliclack::log::info(m);
} }
/// Print a warning line to stderr.
pub fn warn(msg: impl AsRef<str>) { pub fn warn(msg: impl AsRef<str>) {
eprintln!("WARNING: {}", msg.as_ref()); let m = msg.as_ref();
let _ = cliclack::log::warning(m);
} }
/// Print an error line to stderr.
pub fn error(msg: impl AsRef<str>) { pub fn error(msg: impl AsRef<str>) {
eprintln!("ERROR: {}", msg.as_ref()); let m = msg.as_ref();
let _ = cliclack::log::error(m);
}
pub fn success(msg: impl AsRef<str>) {
let m = msg.as_ref();
let _ = cliclack::log::success(m);
}
pub fn note(prompt: impl AsRef<str>, message: impl AsRef<str>) {
let _ = cliclack::note(prompt.as_ref(), message.as_ref());
} }
/// Print a short intro header (non-fancy).
pub fn intro(title: impl AsRef<str>) { pub fn intro(title: impl AsRef<str>) {
eprintln!("== {} ==", title.as_ref()); let _ = cliclack::intro(title.as_ref());
} }
/// Print a short outro footer (non-fancy).
pub fn outro(msg: impl AsRef<str>) { pub fn outro(msg: impl AsRef<str>) {
eprintln!("{}", msg.as_ref()); let _ = cliclack::outro(msg.as_ref());
} }
/// Print a line that should appear above any progress indicators (plain for now).
pub fn println_above_bars(line: impl AsRef<str>) { pub fn println_above_bars(line: impl AsRef<str>) {
eprintln!("{}", line.as_ref()); let _ = cliclack::log::info(line.as_ref());
} }
/// Prompt for input on stdin. Returns default if provided and user enters empty string.
/// In non-interactive workflows, callers should skip prompt based on their flags.
pub fn prompt_input(prompt: &str, default: Option<&str>) -> io::Result<String> { pub fn prompt_input(prompt: &str, default: Option<&str>) -> io::Result<String> {
let mut stdout = io::stdout(); if crate::is_no_interaction() || !crate::stdin_is_tty() {
match default { return Ok(default.unwrap_or("").to_string());
Some(def) => { }
write!(stdout, "{} [{}]: ", prompt, def)?; let mut q = cliclack::input(prompt);
} if let Some(def) = default {
None => { q = q.default_input(def);
write!(stdout, "{}: ", prompt)?; }
q.interact().map_err(|e| io::Error::other(e.to_string()))
}
pub fn prompt_select(prompt: &str, items: &[&str]) -> io::Result<usize> {
if crate::is_no_interaction() || !crate::stdin_is_tty() {
return Err(io::Error::other("interactive prompt disabled"));
}
let mut sel = cliclack::select::<usize>(prompt);
for (idx, label) in items.iter().enumerate() {
sel = sel.item(idx, *label, "");
}
sel.interact().map_err(|e| io::Error::other(e.to_string()))
}
pub fn prompt_multi_select(
prompt: &str,
items: &[&str],
defaults: Option<&[bool]>,
) -> io::Result<Vec<usize>> {
if crate::is_no_interaction() || !crate::stdin_is_tty() {
return Err(io::Error::other("interactive prompt disabled"));
}
let mut ms = cliclack::multiselect::<usize>(prompt);
for (idx, label) in items.iter().enumerate() {
ms = ms.item(idx, *label, "");
}
if let Some(def) = defaults {
let selected: Vec<usize> = def
.iter()
.enumerate()
.filter_map(|(i, &on)| if on { Some(i) } else { None })
.collect();
if !selected.is_empty() {
ms = ms.initial_values(selected);
} }
} }
stdout.flush()?; ms.interact().map_err(|e| io::Error::other(e.to_string()))
}
let mut buf = String::new(); pub fn prompt_confirm(prompt: &str, default: bool) -> io::Result<bool> {
io::stdin().read_line(&mut buf)?; if crate::is_no_interaction() || !crate::stdin_is_tty() {
let trimmed = buf.trim(); return Ok(default);
if trimmed.is_empty() { }
Ok(default.unwrap_or_default().to_string()) let mut q = cliclack::confirm(prompt);
} else { q.interact().map_err(|e| io::Error::other(e.to_string()))
Ok(trimmed.to_string()) }
pub fn prompt_password(prompt: &str) -> io::Result<String> {
if crate::is_no_interaction() || !crate::stdin_is_tty() {
return Err(io::Error::other(
"password prompt disabled in non-interactive mode",
));
}
let mut q = cliclack::password(prompt);
q.interact().map_err(|e| io::Error::other(e.to_string()))
}
pub fn prompt_input_validated<F>(
prompt: &str,
default: Option<&str>,
validate: F,
) -> io::Result<String>
where
F: Fn(&str) -> Result<(), String> + 'static,
{
if crate::is_no_interaction() || !crate::stdin_is_tty() {
if let Some(def) = default {
return Ok(def.to_string());
}
return Err(io::Error::other("interactive prompt disabled"));
}
let mut q = cliclack::input(prompt);
if let Some(def) = default {
q = q.default_input(def);
}
q.validate(move |s: &String| validate(s))
.interact()
.map_err(|e| io::Error::other(e.to_string()))
}
pub struct Spinner(cliclack::ProgressBar);
impl Spinner {
pub fn start(text: impl AsRef<str>) -> Self {
if crate::is_no_progress() || crate::is_no_interaction() || !std::io::stderr().is_terminal()
{
let _ = cliclack::log::info(text.as_ref());
let s = cliclack::spinner();
Self(s)
} else {
let s = cliclack::spinner();
s.start(text.as_ref());
Self(s)
}
}
pub fn stop(self, text: impl AsRef<str>) {
let s = self.0;
if crate::is_no_progress() {
let _ = cliclack::log::info(text.as_ref());
} else {
s.stop(text.as_ref());
}
}
pub fn success(self, text: impl AsRef<str>) {
let s = self.0;
if crate::is_no_progress() {
let _ = cliclack::log::success(text.as_ref());
} else {
s.stop(text.as_ref());
}
}
pub fn error(self, text: impl AsRef<str>) {
let s = self.0;
if crate::is_no_progress() {
let _ = cliclack::log::error(text.as_ref());
} else {
s.error(text.as_ref());
}
}
}
pub struct BytesProgress {
enabled: bool,
total: u64,
current: u64,
started: Instant,
last_msg: Instant,
width: usize,
// Sticky ETA to carry through zero-speed stalls
last_eta_secs: Option<f64>,
}
impl BytesProgress {
pub fn start(total: u64, text: &str, initial: u64) -> Self {
let enabled = !(crate::is_no_progress()
|| crate::is_no_interaction()
|| !std::io::stderr().is_terminal()
|| total == 0);
if !enabled {
let _ = cliclack::log::info(text);
}
let mut me = Self {
enabled,
total,
current: initial.min(total),
started: Instant::now(),
last_msg: Instant::now(),
width: 40,
last_eta_secs: None,
};
me.draw();
me
}
fn human_bytes(n: u64) -> String {
const KB: f64 = 1024.0;
const MB: f64 = 1024.0 * KB;
const GB: f64 = 1024.0 * MB;
let x = n as f64;
if x >= GB {
format!("{:.2} GiB", x / GB)
} else if x >= MB {
format!("{:.2} MiB", x / MB)
} else if x >= KB {
format!("{:.2} KiB", x / KB)
} else {
format!("{} B", n)
}
}
// Elapsed formatting is used for stable, finite durations. For ETA, we guard
// against zero-speed or unstable estimates separately via `format_eta`.
fn refresh_allowed(&mut self) -> (f64, f64) {
let now = Instant::now();
let since_last = now.duration_since(self.last_msg);
if since_last < Duration::from_millis(100) {
// Too soon to refresh; keep previous ETA if any
let eta = self.last_eta_secs.unwrap_or(f64::INFINITY);
return (0.0, eta);
}
self.last_msg = now;
let elapsed = now.duration_since(self.started).as_secs_f64().max(0.001);
let speed = (self.current as f64) / elapsed;
let remaining = self.total.saturating_sub(self.current) as f64;
// If speed is effectively zero, carry ETA forward and add wall time.
const EPS: f64 = 1e-6;
let eta = if speed <= EPS {
let prev = self.last_eta_secs.unwrap_or(f64::INFINITY);
if prev.is_finite() {
prev + since_last.as_secs_f64()
} else {
prev
}
} else {
remaining / speed
};
// Remember only finite ETAs to use during stalls
if eta.is_finite() {
self.last_eta_secs = Some(eta);
}
(speed, eta)
}
fn format_elapsed(seconds: f64) -> String {
let total = seconds.round() as u64;
let h = total / 3600;
let m = (total % 3600) / 60;
let s = total % 60;
if h > 0 { format!("{:02}:{:02}:{:02}", h, m, s) } else { format!("{:02}:{:02}", m, s) }
}
fn format_eta(seconds: f64) -> String {
// If ETA is not finite (e.g., divide-by-zero speed) or unreasonably large,
// show a placeholder rather than overflowing into huge values.
if !seconds.is_finite() {
return "".to_string();
}
// Cap ETA display to 99:59:59 to avoid silly numbers; beyond that, show placeholder.
const CAP_SECS: f64 = 99.0 * 3600.0 + 59.0 * 60.0 + 59.0;
if seconds > CAP_SECS {
return "".to_string();
}
Self::format_elapsed(seconds)
}
fn draw(&mut self) {
if !self.enabled { return; }
let (speed, eta) = self.refresh_allowed();
let elapsed = Instant::now().duration_since(self.started).as_secs_f64();
// Build bar
let width = self.width.max(10);
let filled = ((self.current as f64 / self.total.max(1) as f64) * width as f64).round() as usize;
let filled = filled.min(width);
let mut bar = String::with_capacity(width);
for _ in 0..filled { bar.push('■'); }
for _ in filled..width { bar.push('□'); }
let line = format!(
"[{}] {} [{}] ({}/{} at {}/s)",
Self::format_elapsed(elapsed),
bar,
Self::format_eta(eta),
Self::human_bytes(self.current),
Self::human_bytes(self.total),
Self::human_bytes(speed.max(0.0) as u64),
);
eprint!("\r{}\x1b[K", line);
let _ = io::stderr().flush();
}
pub fn inc(&mut self, delta: u64) {
self.current = self.current.saturating_add(delta).min(self.total);
self.draw();
}
pub fn stop(mut self, text: &str) {
if self.enabled {
self.draw();
eprintln!();
} else {
let _ = cliclack::log::info(text);
}
}
pub fn error(mut self, text: &str) {
if self.enabled {
self.draw();
eprintln!();
let _ = cliclack::log::error(text);
} else {
let _ = cliclack::log::error(text);
}
} }
} }

View File

@@ -1,125 +1,122 @@
// SPDX-License-Identifier: MIT // SPDX-License-Identifier: MIT
// Copyright (c) 2025 <COPYRIGHT HOLDER>. All rights reserved.
use indicatif::{MultiProgress, ProgressBar, ProgressStyle};
use std::io::IsTerminal as _; use std::io::IsTerminal as _;
/// Manages a set of per-file progress bars plus a top aggregate bar. pub struct FileProgress {
pub struct ProgressManager {
enabled: bool, enabled: bool,
mp: Option<MultiProgress>, file_bars: Vec<cliclack::ProgressBar>,
per: Vec<ProgressBar>, total_bar: Option<cliclack::ProgressBar>,
total: Option<ProgressBar>,
completed: usize, completed: usize,
total_file_count: usize,
} }
impl ProgressManager { impl FileProgress {
/// Create a new manager with the given enabled flag.
pub fn new(enabled: bool) -> Self { pub fn new(enabled: bool) -> Self {
Self { enabled, mp: None, per: Vec::new(), total: None, completed: 0 } Self {
enabled,
file_bars: Vec::new(),
total_bar: None,
completed: 0,
total_file_count: 0,
}
} }
/// Create a manager that enables bars when `n > 1`, stderr is a TTY, and not quiet. pub fn default_for_files(file_count: usize) -> Self {
pub fn default_for_files(n: usize) -> Self { let enabled = file_count > 1
let enabled = n > 1 && std::io::stderr().is_terminal() && !crate::is_quiet() && !crate::is_no_progress(); && std::io::stderr().is_terminal()
&& !crate::is_quiet()
&& !crate::is_no_progress();
Self::new(enabled) Self::new(enabled)
} }
/// Initialize bars for the given file labels. If disabled or single file, no-op.
pub fn init_files(&mut self, labels: &[String]) { pub fn init_files(&mut self, labels: &[String]) {
self.total_file_count = labels.len();
if !self.enabled || labels.len() <= 1 { if !self.enabled || labels.len() <= 1 {
// No bars in single-file mode or when disabled
self.enabled = false; self.enabled = false;
return; return;
} }
let mp = MultiProgress::new(); let total = cliclack::progress_bar(labels.len() as u64);
// Aggregate bar at the top total.start("Total");
let total = mp.add(ProgressBar::new(labels.len() as u64)); self.total_bar = Some(total);
total.set_style(ProgressStyle::with_template("{prefix} [{bar:40.cyan/blue}] {pos}/{len}")
.unwrap()
.progress_chars("=>-"));
total.set_prefix("Total");
self.total = Some(total);
// Per-file bars
for label in labels { for label in labels {
let pb = mp.add(ProgressBar::new(100)); let pb = cliclack::progress_bar(100);
pb.set_style(ProgressStyle::with_template("{prefix} [{bar:40.green/black}] {pos}% {msg}") pb.start(label);
.unwrap() self.file_bars.push(pb);
.progress_chars("=>-"));
pb.set_position(0);
pb.set_prefix(label.clone());
self.per.push(pb);
} }
self.mp = Some(mp);
} }
/// Returns true when bars are enabled (multi-file TTY mode). pub fn is_enabled(&self) -> bool {
pub fn is_enabled(&self) -> bool { self.enabled } self.enabled
/// Get a clone of the per-file progress bar at index, if enabled.
pub fn per_bar(&self, idx: usize) -> Option<ProgressBar> {
if !self.enabled { return None; }
self.per.get(idx).cloned()
} }
/// Get a clone of the aggregate (total) progress bar, if enabled. pub fn set_file_message(&mut self, idx: usize, message: &str) {
pub fn total_bar(&self) -> Option<ProgressBar> { if !self.enabled {
if !self.enabled { return None; } return;
self.total.as_ref().cloned() }
if let Some(pb) = self.file_bars.get_mut(idx) {
pb.set_message(message);
}
}
pub fn set_file_percent(&mut self, idx: usize, percent: u64) {
if !self.enabled {
return;
}
if let Some(pb) = self.file_bars.get_mut(idx) {
let p = percent.min(100);
pb.set_message(format!("{p}%"));
}
} }
/// Mark a file as finished (set to 100% and update total counter).
pub fn mark_file_done(&mut self, idx: usize) { pub fn mark_file_done(&mut self, idx: usize) {
if !self.enabled { return; } if !self.enabled {
if let Some(pb) = self.per.get(idx) { return;
pb.set_position(100); }
pb.finish_with_message("done"); if let Some(pb) = self.file_bars.get_mut(idx) {
pb.stop("done");
} }
self.completed += 1; self.completed += 1;
if let Some(total) = &self.total { total.set_position(self.completed as u64); } if let Some(total) = &mut self.total_bar {
total.inc(1);
if self.completed >= self.total_file_count {
total.stop("all done");
}
}
}
pub fn finish_total(&mut self, message: &str) {
if !self.enabled {
return;
}
if let Some(total) = &mut self.total_bar {
total.stop(message);
}
} }
} }
/// A simple reporter for displaying progress messages in the terminal.
/// Provides different output formatting based on whether the environment is interactive or not.
#[derive(Debug)] #[derive(Debug)]
pub struct ProgressReporter { pub struct ProgressReporter {
non_interactive: bool, non_interactive: bool,
} }
impl ProgressReporter { impl ProgressReporter {
/// Creates a new progress reporter.
///
/// # Arguments
///
/// * `non_interactive` - Whether the output should be formatted for non-interactive environments.
pub fn new(non_interactive: bool) -> Self { pub fn new(non_interactive: bool) -> Self {
Self { non_interactive } Self { non_interactive }
} }
/// Displays a progress step message.
///
/// # Arguments
///
/// * `message` - The message to display for this progress step.
pub fn step(&mut self, message: &str) { pub fn step(&mut self, message: &str) {
if self.non_interactive { if self.non_interactive {
eprintln!("[..] {message}"); let _ = cliclack::log::info(format!("[..] {message}"));
} else { } else {
eprintln!("{message}"); let _ = cliclack::log::info(format!("{message}"));
} }
} }
/// Displays a completion message.
///
/// # Arguments
///
/// * `message` - The message to display when a task is completed.
pub fn finish_with_message(&mut self, message: &str) { pub fn finish_with_message(&mut self, message: &str) {
if self.non_interactive { if self.non_interactive {
eprintln!("[ok] {message}"); let _ = cliclack::log::info(format!("[ok] {message}"));
} else { } else {
eprintln!("{message}"); let _ = cliclack::log::info(format!("{message}"));
} }
} }
} }

View File

@@ -1,11 +1,12 @@
[package] [package]
name = "polyscribe-host" name = "polyscribe-host"
version = "0.1.0" version.workspace = true
edition = "2024" edition.workspace = true
[dependencies] [dependencies]
anyhow = "1.0.99" anyhow = { workspace = true }
serde = { version = "1.0.219", features = ["derive"] } serde = { workspace = true, features = ["derive"] }
serde_json = "1.0.142" serde_json = { workspace = true }
tokio = { version = "1.47.1", features = ["rt-multi-thread", "process", "io-util"] } tokio = { workspace = true, features = ["rt-multi-thread", "process", "io-util"] }
which = "6.0.3" which = { workspace = true }
directories = { workspace = true }

View File

@@ -1,8 +1,7 @@
use anyhow::{Context, Result}; use anyhow::{Context, Result};
use serde::Deserialize; use std::process::Stdio;
use std::{ use std::{
env, env, fs,
fs,
os::unix::fs::PermissionsExt, os::unix::fs::PermissionsExt,
path::Path, path::Path,
}; };
@@ -10,7 +9,6 @@ use tokio::{
io::{AsyncBufReadExt, BufReader}, io::{AsyncBufReadExt, BufReader},
process::{Child as TokioChild, Command}, process::{Child as TokioChild, Command},
}; };
use std::process::Stdio;
#[derive(Debug, Clone)] #[derive(Debug, Clone)]
pub struct PluginInfo { pub struct PluginInfo {
@@ -25,27 +23,19 @@ impl PluginManager {
pub fn list(&self) -> Result<Vec<PluginInfo>> { pub fn list(&self) -> Result<Vec<PluginInfo>> {
let mut plugins = Vec::new(); let mut plugins = Vec::new();
// Scan PATH entries for executables starting with "polyscribe-plugin-"
if let Ok(path) = env::var("PATH") { if let Ok(path) = env::var("PATH") {
for dir in env::split_paths(&path) { for dir in env::split_paths(&path) {
if let Ok(read_dir) = fs::read_dir(&dir) { scan_dir_for_plugins(&dir, &mut plugins);
for entry in read_dir.flatten() {
let path = entry.path();
if let Some(fname) = path.file_name().and_then(|s| s.to_str()) {
if fname.starts_with("polyscribe-plugin-") && is_executable(&path) {
let name = fname.trim_start_matches("polyscribe-plugin-").to_string();
plugins.push(PluginInfo {
name,
path: path.to_string_lossy().to_string(),
});
}
}
}
}
} }
} }
// TODO: also scan XDG data plugins dir for symlinks/binaries if let Some(dirs) = directories::ProjectDirs::from("dev", "polyscribe", "polyscribe") {
let plugin_dir = dirs.data_dir().join("plugins");
scan_dir_for_plugins(&plugin_dir, &mut plugins);
}
plugins.sort_by(|a, b| a.path.cmp(&b.path));
plugins.dedup_by(|a, b| a.path == b.path);
Ok(plugins) Ok(plugins)
} }
@@ -89,7 +79,8 @@ impl PluginManager {
fn resolve(&self, name: &str) -> Result<String> { fn resolve(&self, name: &str) -> Result<String> {
let bin = format!("polyscribe-plugin-{name}"); let bin = format!("polyscribe-plugin-{name}");
let path = which::which(&bin).with_context(|| format!("plugin not found in PATH: {bin}"))?; let path =
which::which(&bin).with_context(|| format!("plugin not found in PATH: {bin}"))?;
Ok(path.to_string_lossy().to_string()) Ok(path.to_string_lossy().to_string())
} }
} }
@@ -102,17 +93,27 @@ fn is_executable(path: &Path) -> bool {
{ {
if let Ok(meta) = fs::metadata(path) { if let Ok(meta) = fs::metadata(path) {
let mode = meta.permissions().mode(); let mode = meta.permissions().mode();
// if any execute bit is set
return mode & 0o111 != 0; return mode & 0o111 != 0;
} }
} }
// Fallback for non-unix (treat files as candidates)
true true
} }
#[allow(dead_code)] fn scan_dir_for_plugins(dir: &Path, out: &mut Vec<PluginInfo>) {
#[derive(Debug, Deserialize)] if let Ok(read_dir) = fs::read_dir(dir) {
struct Capability { for entry in read_dir.flatten() {
command: String, let path = entry.path();
summary: String, if let Some(fname) = path.file_name().and_then(|s| s.to_str())
&& fname.starts_with("polyscribe-plugin-")
&& is_executable(&path)
{
let name = fname.trim_start_matches("polyscribe-plugin-").to_string();
out.push(PluginInfo {
name,
path: path.to_string_lossy().to_string(),
});
}
}
}
} }

View File

@@ -1,8 +1,8 @@
[package] [package]
name = "polyscribe-protocol" name = "polyscribe-protocol"
version = "0.1.0" version.workspace = true
edition = "2024" edition.workspace = true
[dependencies] [dependencies]
serde = { version = "1.0.219", features = ["derive"] } serde = { workspace = true, features = ["derive"] }
serde_json = "1.0.142" serde_json = { workspace = true }

View File

@@ -32,18 +32,20 @@ Run locally
Models during development Models during development
- Interactive downloader: - Interactive downloader:
- cargo run -- --download-models - cargo run -- models download
- Non-interactive update (checks sizes/hashes, downloads if missing): - Non-interactive update (checks sizes/hashes, downloads if missing):
- cargo run -- --update-models --no-interaction -q - cargo run -- models update --no-interaction -q
Tests Tests
- Run all tests: - Run all tests:
- cargo test - cargo test
- The test suite includes CLI-oriented integration tests and unit tests. Some tests simulate GPU detection using env vars (POLYSCRIBE_TEST_FORCE_*). Do not rely on these flags in production code. - The test suite includes CLI-oriented integration tests and unit tests. Some tests simulate GPU detection using env vars (POLYSCRIBE_TEST_FORCE_*). Do not rely on these flags in production code.
Clippy Clippy & formatting
- Run lint checks and treat warnings as errors: - Run lint checks and treat warnings as errors:
- cargo clippy --all-targets -- -D warnings - cargo clippy --all-targets -- -D warnings
- Check formatting:
- cargo fmt --all -- --check
- Common warnings can often be fixed by simplifying code, removing unused imports, and following idiomatic patterns. - Common warnings can often be fixed by simplifying code, removing unused imports, and following idiomatic patterns.
Code layout Code layout
@@ -61,10 +63,10 @@ Adding a feature
Running the model downloader Running the model downloader
- Interactive: - Interactive:
- cargo run -- --download-models - cargo run -- models download
- Non-interactive suggestions for CI: - Non-interactive suggestions for CI:
- POLYSCRIBE_MODELS_DIR=$PWD/models \ - POLYSCRIBE_MODELS_DIR=$PWD/models \
cargo run -- --update-models --no-interaction -q cargo run -- models update --no-interaction -q
Env var examples for local testing Env var examples for local testing
- Use a local copy of models and a specific model file: - Use a local copy of models and a specific model file:

View File

@@ -30,10 +30,10 @@ CLI reference
- Choose runtime backend. Default is auto (prefers CUDA → HIP → Vulkan → CPU), depending on detection. - Choose runtime backend. Default is auto (prefers CUDA → HIP → Vulkan → CPU), depending on detection.
- --gpu-layers N - --gpu-layers N
- Number of layers to offload to the GPU when supported. - Number of layers to offload to the GPU when supported.
- --download-models - models download
- Launch interactive model downloader (lists Hugging Face models; multi-select to download). - Launch interactive model downloader (lists Hugging Face models; multi-select to download).
- Controls: Use Up/Down to navigate, Space to toggle selections, and Enter to confirm. Models are grouped by base (e.g., tiny, base, small). - Controls: Use Up/Down to navigate, Space to toggle selections, and Enter to confirm. Models are grouped by base (e.g., tiny, base, small).
- --update-models - models update
- Verify/update local models by comparing sizes and hashes with the upstream manifest. - Verify/update local models by comparing sizes and hashes with the upstream manifest.
- -v, --verbose (repeatable) - -v, --verbose (repeatable)
- Increase log verbosity; use -vv for very detailed logs. - Increase log verbosity; use -vv for very detailed logs.
@@ -42,6 +42,9 @@ CLI reference
- --no-interaction - --no-interaction
- Disable all interactive prompts (for CI). Combine with env vars to control behavior. - Disable all interactive prompts (for CI). Combine with env vars to control behavior.
- Subcommands: - Subcommands:
- models download: Launch interactive model downloader.
- models update: Verify/update local models (non-interactive).
- plugins list|info|run: Discover and run plugins.
- completions <shell>: Write shell completion script to stdout. - completions <shell>: Write shell completion script to stdout.
- man: Write a man page to stdout. - man: Write a man page to stdout.

View File

@@ -1,5 +1,4 @@
// SPDX-License-Identifier: MIT // SPDX-License-Identifier: MIT
// Stub plugin: tubescribe
use anyhow::{Context, Result}; use anyhow::{Context, Result};
use clap::Parser; use clap::Parser;
@@ -36,7 +35,6 @@ fn main() -> Result<()> {
serve_once()?; serve_once()?;
return Ok(()); return Ok(());
} }
// Default: show capabilities (friendly behavior if run without flags)
let caps = psp::Capabilities { let caps = psp::Capabilities {
name: "tubescribe".to_string(), name: "tubescribe".to_string(),
version: env!("CARGO_PKG_VERSION").to_string(), version: env!("CARGO_PKG_VERSION").to_string(),
@@ -49,14 +47,12 @@ fn main() -> Result<()> {
} }
fn serve_once() -> Result<()> { fn serve_once() -> Result<()> {
// Read exactly one line (one request)
let stdin = std::io::stdin(); let stdin = std::io::stdin();
let mut reader = BufReader::new(stdin.lock()); let mut reader = BufReader::new(stdin.lock());
let mut line = String::new(); let mut line = String::new();
reader.read_line(&mut line).context("failed to read request line")?; reader.read_line(&mut line).context("failed to read request line")?;
let req: psp::JsonRpcRequest = serde_json::from_str(line.trim()).context("invalid JSON-RPC request")?; let req: psp::JsonRpcRequest = serde_json::from_str(line.trim()).context("invalid JSON-RPC request")?;
// Simulate doing some work with progress
emit(&psp::StreamItem::progress(5, Some("start".into()), Some("initializing".into())))?; emit(&psp::StreamItem::progress(5, Some("start".into()), Some("initializing".into())))?;
std::thread::sleep(std::time::Duration::from_millis(50)); std::thread::sleep(std::time::Duration::from_millis(50));
emit(&psp::StreamItem::progress(25, Some("probe".into()), Some("probing sources".into())))?; emit(&psp::StreamItem::progress(25, Some("probe".into()), Some("probing sources".into())))?;
@@ -65,7 +61,6 @@ fn serve_once() -> Result<()> {
std::thread::sleep(std::time::Duration::from_millis(50)); std::thread::sleep(std::time::Duration::from_millis(50));
emit(&psp::StreamItem::progress(90, Some("finalize".into()), Some("finalizing".into())))?; emit(&psp::StreamItem::progress(90, Some("finalize".into()), Some("finalizing".into())))?;
// Handle method and produce result
let result = match req.method.as_str() { let result = match req.method.as_str() {
"generate_metadata" => { "generate_metadata" => {
let title = "Canned title"; let title = "Canned title";
@@ -78,7 +73,6 @@ fn serve_once() -> Result<()> {
}) })
} }
other => { other => {
// Unknown method
let err = psp::StreamItem::err(req.id.clone(), -32601, format!("Method not found: {}", other), None); let err = psp::StreamItem::err(req.id.clone(), -32601, format!("Method not found: {}", other), None);
emit(&err)?; emit(&err)?;
return Ok(()); return Ok(());