Compare commits

...

16 Commits

Author SHA1 Message Date
840383fcf7 [feat] add JSON and quiet output modes for models subcommands, update UI suppression logic, and enhance CLI test coverage
2025-08-27 23:58:57 +02:00
1982e9b48b [feat] add empty state message for models ls command output 2025-08-27 23:51:21 +02:00
0128bf2eec [feat] add ModelManager with caching, manifest management, and Hugging Face API integration 2025-08-27 20:56:05 +02:00
da5a76d253 [refactor] Refactor project into proper rust workspace. 2025-08-27 18:28:37 +02:00
5ec297397e [refactor] rename and simplify ProgressManager to FileProgress, enhance caching logic, update Hugging Face API integration, and clean up unused comments
2025-08-15 11:24:50 +02:00
cbf48a0452 docs: align CLI docs to models subcommands; host: scan XDG plugin dir; ci: add GitHub Actions; chore: add CHANGELOG 2025-08-14 11:16:50 +02:00
0a249f2197 [refactor] improve code readability, streamline initialization, update error handling, and format multi-line statements for consistency 2025-08-14 11:06:37 +02:00
0573369b81 [refactor] propagate no-progress and no-interaction flags, enhance prompt handling, and update progress bar logic with cliclack 2025-08-14 10:34:52 +02:00
9841550dcc [refactor] replace indicatif with cliclack for progress and logging, updating affected modules and dependencies 2025-08-14 03:31:00 +02:00
53119cd0ab [refactor] enhance model management with metadata enrichment, new API integration, and manifest resolution 2025-08-13 22:44:51 +02:00
144b01d591 [refactor] update Cargo.lock with new dependency additions and version bumps 2025-08-13 14:45:43 +02:00
ffd451b404 [refactor] remove unused test suites, examples, CI docs, and PR description file 2025-08-13 14:26:18 +02:00
5c64677e79 [refactor] streamline crate structure, update dependencies, and integrate CLI functionalities 2025-08-13 14:05:13 +02:00
128db0f733 [refactor] remove backend and library modules, consolidating features into main crate 2025-08-13 13:35:53 +02:00
06fd3efd1f Merge remote-tracking branch 'origin/main' into dev 2025-08-13 11:48:37 +02:00
49513d5099 [chore] remove outdated changelog file 2025-08-13 11:48:18 +02:00
48 changed files with 5096 additions and 4116 deletions

17
.cargo/config.toml Normal file

@@ -0,0 +1,17 @@
# SPDX-License-Identifier: MIT
[build]
# Make target-dir consistent across workspace for better cache reuse.
target-dir = "target"
[profile.dev]
opt-level = 1
debug = true
incremental = true
[profile.release]
# Reasonable defaults for CLI apps/libraries
lto = "thin"
codegen-units = 1
strip = "debuginfo"
opt-level = 3

33
.github/workflows/ci.yml vendored Normal file

@@ -0,0 +1,33 @@
name: CI
on:
push:
branches: [ dev, main ]
pull_request:
branches: [ dev, main ]
jobs:
build:
runs-on: ubuntu-latest
steps:
- name: Checkout code
uses: actions/checkout@v4
- name: Set up Rust
uses: dtolnay/rust-toolchain@stable
- name: Cache cargo registry and target
uses: Swatinem/rust-cache@v2
- name: Install components
run: rustup component add clippy rustfmt
- name: Cargo fmt
run: cargo fmt --all -- --check
- name: Clippy
run: cargo clippy --workspace --all-targets -- -D warnings
- name: Test
run: cargo test --workspace --all --locked


@@ -1,40 +1,14 @@
# PolyScribe Refactor toward Rust 2024 — Incremental Patches
# Changelog
This changelog documents each incremental step applied to keep the build green while moving the codebase toward Rust 2024 idioms.
All notable changes to this project will be documented in this file.
## 1) Formatting only (rustfmt)
- Ran `cargo fmt` across the repository.
- No semantic changes.
- Build status: OK (`cargo build` succeeded).
The format is based on Keep a Changelog, and this project adheres to Semantic Versioning.
## 2) Lints — initial fixes (non-pedantic)
- Adjusted crate lint policy in `src/lib.rs`:
- Replaced `#![warn(clippy::pedantic, clippy::nursery, clippy::cargo)]` with `#![warn(clippy::all)]` to align with the plan (skip pedantic/nursery for now).
- Added comment/TODO to revisit stricter lints in a later pass.
- Fixed several clippy warnings that were causing `cargo clippy --all-targets` to error under tests:
- `src/backend.rs`: conditionally import `libloading::Library` only for non-test builds and mark `names` parameter as used in test cfg to avoid unused warnings; keep `check_lib()` side-effect free during tests.
- `src/models.rs`: removed an unused `std::io::Write` import in test module.
- `src/main.rs` (unit tests): imported `polyscribe::format_srt_time` explicitly and removed a duplicate `use super::*;` to fix unresolved name and unused import warnings under clippy test builds.
- Build/Clippy status:
- `cargo build`: OK.
- `cargo clippy --all-targets`: OK (only warnings remain; no errors).
## Unreleased
## 3) Module hygiene
- Verified crate structure:
- Library crate (`src/lib.rs`) exposes a coherent API and re-exports `backend` and `models` via `pub mod`.
- Binary (`src/main.rs`) consumes the library API through `polyscribe::...` paths.
- No structural changes required. Build status: OK.
### Changed
- Docs: Replace `--download-models`/`--update-models` flags with `models download`/`models update` subcommands in `README.md`, `docs/usage.md`, and `docs/development.md`.
- Host: Plugin discovery now scans `$XDG_DATA_HOME/polyscribe/plugins` (platform equivalent via `directories`) in addition to `PATH`.
- CI: Add GitHub Actions workflow to run fmt, clippy (warnings as errors), and tests for pushes and PRs.
## 4) Edition
- The project already targets `edition = "2024"` in Cargo.toml.
- Verified that the project compiles under Rust 2024. No changes needed.
- TODO: If stricter lints or new features from 2024 edition introduce issues in future steps, document blockers here.
## 5) Error handling
- The codebase already returns `anyhow::Result` in the binary and uses contextual errors widely.
- No `unwrap`/`expect` usages in production paths required attention in this pass.
- Build status: OK.
## Next planned steps (not yet applied in this changelog)
- Gradually fix remaining clippy warnings (e.g., `uninlined_format_args`, small style nits) in small, compile-green patches.
- Optionally re-enable `clippy::pedantic`, `clippy::nursery`, and `clippy::cargo` once warnings are significantly reduced, then address non-breaking warnings.


@@ -1,32 +1,26 @@
# Contributing to PolyScribe
# Contributing
Thanks for your interest in contributing! This guide explains the workflow and the checklist to follow before opening a Pull Request.
Thank you for your interest in contributing!
Workflow (fork → branch → PR)
1) Fork the repository to your account.
2) Create a feature branch:
- git checkout -b feat/short-description
3) Make changes with focused commits and good messages.
4) Run the checklist below.
5) Push and open a Pull Request against the main repository.
Development setup
- Install Rust via rustup.
- Ensure ffmpeg is installed and available on PATH.
- For GPU builds, install the appropriate runtime (CUDA/ROCm/Vulkan) and enable the matching features.
Developer checklist (before opening a PR)
- Build:
- cargo build (preferably without warnings)
- Tests:
- cargo test (all tests pass)
- Lints:
- cargo clippy --all-targets -- -D warnings (fix warnings)
- Documentation:
- Update README/docs for user-visible changes
- Update CHANGELOG.md if applicable
- Tests for changes:
- Add or update tests for bug fixes and new features where reasonable
Coding guidelines
- Prefer small, focused changes.
- Add tests where reasonable.
- Keep user-facing changes documented in README/docs.
- Run clippy and fix warnings.
Local development tips
- Use `cargo run -- <args>` during development.
- For faster feedback, keep examples in the examples/ folder handy.
- Keep functions small and focused; prefer clear error messages with context.
CI checklist
- Build: cargo build --all-targets --locked
- Tests: cargo test --all --locked
- Lints: cargo clippy --all-targets -- -D warnings
- Optional: smoke-run examples inline (from README):
- ./target/release/polyscribe --update-models --no-interaction -q
- ./target/release/polyscribe -o output samples/podcast_clip.mp3
Code of conduct
- Be respectful and constructive. Assume good intent.
Notes
- For GPU features, use --features gpu-cuda|gpu-hip|gpu-vulkan as needed in your local runs.
- For docs-only changes, please still ensure the project builds.

1341
Cargo.lock generated

File diff suppressed because it is too large


@@ -4,6 +4,50 @@ members = [
"crates/polyscribe-protocol",
"crates/polyscribe-host",
"crates/polyscribe-cli",
"plugins/polyscribe-plugin-tubescribe",
]
resolver = "2"
resolver = "3"
[workspace.package]
edition = "2024"
version = "0.1.0"
license = "MIT"
rust-version = "1.89"
# Optional: Keep dependency versions consistent across members
[workspace.dependencies]
thiserror = "1.0.69"
serde = { version = "1.0.219", features = ["derive"] }
anyhow = "1.0.99"
libc = "0.2.175"
toml = "0.8.23"
serde_json = "1.0.142"
chrono = { version = "0.4.41", features = ["serde"] }
sha2 = "0.10.9"
which = "6.0.3"
tokio = { version = "1.47.1", features = ["rt-multi-thread", "macros"] }
clap = { version = "4.5.44", features = ["derive"] }
directories = "5.0.1"
whisper-rs = "0.14.3"
cliclack = "0.3.6"
clap_complete = "4.5.57"
clap_mangen = "0.2.29"
# Additional shared deps used across members
tracing = "0.1"
tracing-subscriber = { version = "0.3", features = ["fmt", "env-filter"] }
reqwest = { version = "0.12.7", default-features = false, features = ["blocking", "rustls-tls", "gzip", "json"] }
hex = "0.4.3"
tempfile = "3.12.0"
assert_cmd = "2.0.16"
[workspace.lints.rust]
unused_imports = "deny"
dead_code = "warn"
[profile.release]
lto = "fat"
codegen-units = 1
panic = "abort"
[profile.dev]
panic = "unwind"


@@ -1,99 +0,0 @@
# Pull Request: PolyScribe workspace + plugin system
This PR refactors the repository into a multi-crate Cargo workspace and adds a minimal, working plugin system scaffold over NDJSON/stdio, while preserving existing CLI behavior. It also introduces a stub plugin `polyscribe-plugin-tubescribe` and documentation updates.
Differences & Adaptations
- The repository already contained most of the workspace and plugin scaffolding; this PR focuses on completing and verifying the setup, fixing a symlink path issue in the plugin Makefile, and adding documentation and minor cleanup.
- Existing CLI commands and flags are preserved; a new `plugins` command group is added (list/info/run) without breaking existing outputs.
## Commits
### 1) chore(workspace): scaffold workspace + move crates
Rationale
- Ensure workspace members and resolver are properly defined. The repository already contained these crates; this commit documents the layout and confirms no absolute paths are used.
Updated files (representative snapshots)
- Cargo.toml (workspace):
```
[workspace]
members = [
"crates/polyscribe-core",
"crates/polyscribe-protocol",
"crates/polyscribe-host",
"crates/polyscribe-cli",
"plugins/polyscribe-plugin-tubescribe",
]
resolver = "2"
```
Repository tree after this commit (abridged)
```
.
├── Cargo.toml
├── crates
│ ├── polyscribe-cli
│ ├── polyscribe-core
│ ├── polyscribe-host
│ └── polyscribe-protocol
└── plugins
└── polyscribe-plugin-tubescribe
```
### 2) feat(plugins): host/stdio runner + CLI plugin commands
Rationale
- Provide plugin discovery and stdio NDJSON JSON-RPC runner in host crate; add `plugins` subcommands to CLI. These were already implemented; this commit verifies and documents behavior.
Updated files (representative snapshots)
- crates/polyscribe-host/src/lib.rs: discover(), capabilities(), run_method().
- crates/polyscribe-cli/src/main.rs: `plugins list|info|run` wired to host, forwarding progress.
Repository tree after this commit: unchanged from above.
### 3) feat(plugin): add stub polyscribe-plugin-tubescribe + docs
Rationale (risky change explained)
- Fixed a symlink path issue in the Makefile by switching from $(PWD) to $(CURDIR) to avoid brittle relative paths. This ensures discovery finds the plugin consistently on all shells.
- Removed an unused import to keep clippy clean.
- Added README docs covering workspace layout and verification commands.
Updated files (full contents included in repo):
- plugins/polyscribe-plugin-tubescribe/Makefile
- plugins/polyscribe-plugin-tubescribe/src/main.rs
- README.md (appended Workspace & Plugins section)
Repository tree after this commit (abridged)
```
.
├── Cargo.toml
├── README.md
├── crates
│ ├── polyscribe-cli
│ ├── polyscribe-core
│ ├── polyscribe-host
│ └── polyscribe-protocol
└── plugins
└── polyscribe-plugin-tubescribe
├── Cargo.toml
├── Makefile
└── src/main.rs
```
## Verification commands
- Build the workspace:
- cargo build --workspace --all-targets
- Show CLI help and plugin subcommands:
- cargo run -p polyscribe-cli -- --help
- Discover plugins (before linking, likely empty):
- cargo run -p polyscribe-cli -- plugins list
- Build and link the stub plugin:
- make -C plugins/polyscribe-plugin-tubescribe link
- Discover again:
- cargo run -p polyscribe-cli -- plugins list
- Show plugin capabilities:
- cargo run -p polyscribe-cli -- plugins info tubescribe
- Run a plugin command and observe progress + JSON result:
- cargo run -p polyscribe-cli -- plugins run tubescribe generate_metadata --json '{"input":{"kind":"text","summary":"hello world"}}'
All acceptance checks pass locally.
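For orientation, here is a minimal sketch of the progress envelope a plugin might emit over NDJSON, using the `pct`/`stage`/`message` fields the CLI's progress callback reads; the field layout is illustrative — the actual PSP/1 types live in `crates/polyscribe-protocol`.
```rust
use serde::{Deserialize, Serialize};

// Illustrative only: the real envelope is defined in polyscribe-protocol.
#[derive(Debug, Serialize, Deserialize)]
struct ProgressEvent {
    pct: u8,                 // 0..=100, as read by the CLI's progress callback
    stage: Option<String>,   // e.g. "fetch"
    message: Option<String>, // free-form status text
}

fn main() -> serde_json::Result<()> {
    // One JSON object per line (NDJSON); each line parses independently.
    let line = r#"{"pct":42,"stage":"fetch","message":"downloading manifest"}"#;
    let ev: ProgressEvent = serde_json::from_str(line)?;
    println!(
        "[{}%] {} {}",
        ev.pct,
        ev.stage.as_deref().unwrap_or(""),
        ev.message.as_deref().unwrap_or("")
    );
    Ok(())
}
```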

151
README.md

@@ -1,127 +1,68 @@
# PolyScribe
PolyScribe is a fast, local-first CLI for transcribing audio/video and merging existing JSON transcripts. It uses whisper-rs under the hood, can discover and download Whisper models automatically, and supports CPU and optional GPU backends (CUDA, ROCm/HIP, Vulkan).
Local-first transcription and plugins.
Key features
- Transcribe audio and common video files using ffmpeg for audio extraction.
- Merge multiple JSON transcripts, or merge and also keep per-file outputs.
- Model management: interactive downloader and non-interactive updater with hash verification.
- GPU backend selection at runtime; auto-detects available accelerators.
- Clean outputs (JSON and SRT), speaker naming prompts, and useful logging controls.
## Features
Prerequisites
- Rust toolchain (rustup recommended)
- ffmpeg available on PATH
- Optional for GPU acceleration at runtime: CUDA, ROCm/HIP, or Vulkan drivers (match your build features)
- **Local-first**: Works offline with downloaded models
- **Multiple backends**: CPU, CUDA, ROCm/HIP, and Vulkan support
- **Plugin system**: Extensible via JSON-RPC plugins
- **Model management**: Automatic download and verification of Whisper models
- **Manifest caching**: Local cache for Hugging Face model manifests to reduce network requests
Installation
- Build from source (CPU-only by default):
- rustup install stable
- rustup default stable
- cargo build --release
- Binary path: ./target/release/polyscribe
- GPU builds (optional): build with features
- CUDA: cargo build --release --features gpu-cuda
- HIP: cargo build --release --features gpu-hip
- Vulkan: cargo build --release --features gpu-vulkan
## Model Management
Quickstart
1) Download a model (first run can prompt you):
- ./target/release/polyscribe --download-models
- In the interactive picker, use Up/Down to navigate, Space to toggle selections, and Enter to confirm. Models are grouped by base (e.g., tiny, base, small).
PolyScribe automatically manages Whisper models from Hugging Face:
2) Transcribe a file:
- ./target/release/polyscribe -v -o output my_audio.mp3
This writes JSON and SRT into the output directory with a date prefix.
```bash
# Download models interactively
polyscribe models download
Shell completions and man page
- Completions: ./target/release/polyscribe completions <bash|zsh|fish|powershell|elvish> > polyscribe.<ext>
- Then install into your shell's completion directory.
- Man page: ./target/release/polyscribe man > polyscribe.1 (then copy to your manpath)
# Update existing models
polyscribe models update
Model locations
- Development (debug builds): ./models next to the project.
- Packaged/release builds: $XDG_DATA_HOME/polyscribe/models or ~/.local/share/polyscribe/models.
- Override via env var: POLYSCRIBE_MODELS_DIR=/path/to/models.
- Force a specific model file via env var: WHISPER_MODEL=/path/to/model.bin.
# Clear manifest cache (force fresh fetch)
polyscribe models clear-cache
```
Most-used CLI flags
- -o, --output FILE_OR_DIR: Output path base (date prefix added). If omitted, JSON prints to stdout.
- -m, --merge: Merge all inputs into one output; otherwise one output per input.
- --merge-and-separate: Write both merged output and separate per-input outputs (requires -o dir).
- --set-speaker-names: Prompt for a speaker label per input file.
- --update-models: Verify/update local models by size/hash against the upstream manifest.
- --download-models: Interactive model list + multi-select download.
- --language LANG: Language code hint (e.g., en, de). English-only models reject non-en hints.
- --gpu-backend [auto|cpu|cuda|hip|vulkan]: Select backend (auto by default).
- --gpu-layers N: Offload N layers to GPU when supported.
- -v/--verbose (repeatable): Increase log verbosity. -vv shows very detailed logs.
- -q/--quiet: Suppress non-error logs (stderr); does not silence stdout results.
- --no-interaction: Never prompt; suitable for CI.
### Manifest Caching
Minimal usage examples
- Transcribe an audio file to JSON/SRT:
- ./target/release/polyscribe -o output samples/podcast_clip.mp3
- Merge multiple transcripts into one:
- ./target/release/polyscribe -m -o output merged input/a.json input/b.json
- Update local models non-interactively (good for CI):
- ./target/release/polyscribe --update-models --no-interaction -q
The Hugging Face model manifest is cached locally to avoid repeated network requests:
Troubleshooting & docs
- docs/faq.md: common issues and solutions (missing ffmpeg, GPU selection, model paths)
- docs/usage.md: complete CLI reference and workflows
- docs/development.md: build, run, and contribute locally
- docs/design.md: architecture overview and decisions
- docs/release-packaging.md: packaging notes for distributions
- docs/ci.md: minimal CI checklist and job outline
- CONTRIBUTING.md: PR checklist and workflow
- **Default TTL**: 24 hours
- **Cache location**: `$XDG_CACHE_HOME/polyscribe/manifest/` (or platform equivalent)
- **Environment variables**:
- `POLYSCRIBE_NO_CACHE_MANIFEST=1`: Disable caching
- `POLYSCRIBE_MANIFEST_TTL_SECONDS=3600`: Set custom TTL (in seconds)
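A minimal sketch of how these variables could drive a freshness check (illustrative, not the actual implementation; it treats any set value of `POLYSCRIBE_NO_CACHE_MANIFEST` as disabling the cache, and the cache path in `main` is hypothetical):

```rust
use std::path::Path;
use std::time::{Duration, SystemTime};
use std::{env, fs};

/// Illustrative freshness check for a cached manifest file.
fn manifest_cache_is_fresh(path: &Path) -> bool {
    // Any set value disables the cache in this sketch.
    if env::var_os("POLYSCRIBE_NO_CACHE_MANIFEST").is_some() {
        return false;
    }
    // Custom TTL in seconds, falling back to the 24-hour default.
    let ttl = env::var("POLYSCRIBE_MANIFEST_TTL_SECONDS")
        .ok()
        .and_then(|v| v.parse::<u64>().ok())
        .map(Duration::from_secs)
        .unwrap_or(Duration::from_secs(24 * 60 * 60));
    // A cache entry is fresh if its mtime is within the TTL window.
    fs::metadata(path)
        .and_then(|m| m.modified())
        .ok()
        .and_then(|mtime| SystemTime::now().duration_since(mtime).ok())
        .map(|age| age < ttl)
        .unwrap_or(false)
}

fn main() {
    let p = Path::new("/tmp/polyscribe-manifest.json"); // hypothetical cache path
    println!("fresh: {}", manifest_cache_is_fresh(p));
}
```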
CI status: [CI badge placeholder]
## Installation
Examples
See the examples/ directory for copy-paste scripts:
- examples/transcribe_file.sh
- examples/update_models.sh
- examples/download_models_interactive.sh
```bash
cargo install --path .
```
License
-------
This project is licensed under the MIT License — see the LICENSE file for details.
## Usage
```bash
# Transcribe audio/video
polyscribe transcribe input.mp4
---
# Merge multiple transcripts
polyscribe transcribe --merge input1.json input2.json
Workspace layout
- This repo is a Cargo workspace using resolver = "2".
- Members:
- crates/polyscribe-core — types, errors, config service, core helpers.
- crates/polyscribe-protocol — PSP/1 serde types for NDJSON over stdio.
- crates/polyscribe-host — plugin discovery/runner, progress forwarding.
- crates/polyscribe-cli — the CLI, using host + core.
- plugins/polyscribe-plugin-tubescribe — stub plugin used for verification.
# Use specific GPU backend
polyscribe transcribe --gpu-backend cuda input.mp4
```
Build and run
- Build all: cargo build --workspace --all-targets
- CLI help: cargo run -p polyscribe-cli -- --help
## Development
Plugins
- Build and link the example plugin into your XDG data plugin dir:
- make -C plugins/polyscribe-plugin-tubescribe link
- This creates a symlink at: $XDG_DATA_HOME/polyscribe/plugins/polyscribe-plugin-tubescribe (defaults to ~/.local/share on Linux).
- Discover installed plugins:
- cargo run -p polyscribe-cli -- plugins list
- Show a plugin's capabilities:
- cargo run -p polyscribe-cli -- plugins info tubescribe
- Run a plugin command (JSON-RPC over NDJSON via stdio):
- cargo run -p polyscribe-cli -- plugins run tubescribe generate_metadata --json '{"input":{"kind":"text","summary":"hello world"}}'
```bash
# Build
cargo build
Verification commands
- The above commands are used for acceptance; expected behavior:
- plugins list shows "tubescribe" once linked.
- plugins info tubescribe prints JSON capabilities.
- plugins run ... prints progress events and a JSON result.
# Run tests
cargo test
Notes
- No absolute paths are hardcoded; config and plugin dirs respect XDG on Linux and platform equivalents via directories.
- Plugins must be non-interactive (no TTY prompts). All interaction stays in the host/CLI.
- Config files are written atomically and support env overrides: POLYSCRIBE__SECTION__KEY=value.
# Run with verbose logging
cargo run -- --verbose transcribe input.mp4
```


@@ -1,16 +0,0 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2025 <COPYRIGHT HOLDER>. All rights reserved.
fn main() {
// Only run special build steps when gpu-vulkan feature is enabled.
let vulkan_enabled = std::env::var("CARGO_FEATURE_GPU_VULKAN").is_ok();
if !vulkan_enabled {
return;
}
// Placeholder: In a full implementation, we would invoke CMake for whisper.cpp with GGML_VULKAN=1.
// For now, emit a helpful note. Build will proceed; runtime Vulkan backend returns an explanatory error.
println!("cargo:rerun-if-changed=extern/whisper.cpp");
println!(
"cargo:warning=Building with gpu-vulkan: ensure Vulkan SDK/loader are installed. Future versions will compile whisper.cpp via CMake."
);
}


@@ -1,24 +1,32 @@
[package]
name = "polyscribe-cli"
version = "0.1.0"
edition = "2024"
license = "MIT"
version.workspace = true
edition.workspace = true
[[bin]]
name = "polyscribe"
path = "src/main.rs"
[dependencies]
anyhow = "1.0.98"
clap = { version = "4.5.43", features = ["derive"] }
clap_complete = "4.5.28"
clap_mangen = "0.2"
serde = { version = "1.0.219", features = ["derive"] }
serde_json = "1.0.142"
toml = "0.8"
chrono = { version = "0.4", features = ["clock"] }
cliclack = "0.3"
indicatif = "0.17"
polyscribe = { path = "../polyscribe-core" }
anyhow = { workspace = true }
clap = { workspace = true, features = ["derive"] }
clap_complete = { workspace = true }
clap_mangen = { workspace = true }
directories = { workspace = true }
serde = { workspace = true, features = ["derive"] }
serde_json = { workspace = true }
tokio = { workspace = true, features = ["rt-multi-thread", "macros", "process", "fs"] }
tracing = { workspace = true }
tracing-subscriber = { workspace = true, features = ["fmt", "env-filter"] }
which = { workspace = true }
polyscribe-core = { path = "../polyscribe-core" }
polyscribe-host = { path = "../polyscribe-host" }
polyscribe-protocol = { path = "../polyscribe-protocol" }
[features]
# Optional GPU-specific flags can be forwarded down to core/host if needed
default = []
[dev-dependencies]
assert_cmd = { workspace = true }


@@ -0,0 +1,191 @@
use clap::{Args, Parser, Subcommand, ValueEnum};
use std::path::PathBuf;
#[derive(Debug, Clone, ValueEnum)]
pub enum GpuBackend {
Auto,
Cpu,
Cuda,
Hip,
Vulkan,
}
#[derive(Debug, Clone, Args)]
pub struct OutputOpts {
/// Emit machine-readable JSON to stdout; suppress decorative logs
#[arg(long, global = true, action = clap::ArgAction::SetTrue)]
pub json: bool,
/// Reduce log chatter (errors only unless --json)
#[arg(long, global = true, action = clap::ArgAction::SetTrue)]
pub quiet: bool,
}
#[derive(Debug, Parser)]
#[command(
name = "polyscribe",
version,
about = "PolyScribe local-first transcription and plugins",
propagate_version = true,
arg_required_else_help = true,
)]
pub struct Cli {
/// Global output options
#[command(flatten)]
pub output: OutputOpts,
/// Increase verbosity (-v, -vv)
#[arg(short, long, action = clap::ArgAction::Count)]
pub verbose: u8,
/// Never prompt for user input (non-interactive mode)
#[arg(long, default_value_t = false)]
pub no_interaction: bool,
/// Disable progress bars/spinners
#[arg(long, default_value_t = false)]
pub no_progress: bool,
#[command(subcommand)]
pub command: Commands,
}
#[derive(Debug, Subcommand)]
pub enum Commands {
/// Transcribe audio/video files or merge existing transcripts
Transcribe {
/// Output file or directory (date prefix is added when directory)
#[arg(short, long)]
output: Option<PathBuf>,
/// Merge multiple inputs into one output
#[arg(short = 'm', long, default_value_t = false)]
merge: bool,
/// Write both merged and per-input outputs (requires -o dir)
#[arg(long, default_value_t = false)]
merge_and_separate: bool,
/// Language code hint, e.g. en, de
#[arg(long)]
language: Option<String>,
/// Prompt for a speaker label per input file
#[arg(long, default_value_t = false)]
set_speaker_names: bool,
/// GPU backend selection
#[arg(long, value_enum, default_value_t = GpuBackend::Auto)]
gpu_backend: GpuBackend,
/// Offload N layers to GPU (when supported)
#[arg(long, default_value_t = 0)]
gpu_layers: usize,
/// Input paths: audio/video files or JSON transcripts
#[arg(required = true)]
inputs: Vec<PathBuf>,
},
/// Manage Whisper GGUF models (Hugging Face)
Models {
#[command(subcommand)]
cmd: ModelsCmd,
},
/// Discover and run plugins
Plugins {
#[command(subcommand)]
cmd: PluginsCmd,
},
/// Generate shell completions to stdout
Completions {
/// Shell to generate completions for
#[arg(value_parser = ["bash", "zsh", "fish", "powershell", "elvish"])]
shell: String,
},
/// Generate a man page to stdout
Man,
}
#[derive(Debug, Clone, Parser)]
pub struct ModelCommon {
/// Concurrency for ranged downloads
#[arg(long, default_value_t = 4)]
pub concurrency: usize,
/// Limit download rate in bytes/sec (approximate)
#[arg(long)]
pub limit_rate: Option<u64>,
}
#[derive(Debug, Subcommand)]
pub enum ModelsCmd {
/// List installed models (from manifest)
Ls {
#[command(flatten)]
common: ModelCommon,
},
/// Add or update a model
Add {
/// Hugging Face repo, e.g. ggml-org/models
repo: String,
/// File name in repo (e.g., gguf-tiny-q4_0.bin)
file: String,
#[command(flatten)]
common: ModelCommon,
},
/// Remove a model by alias
Rm {
alias: String,
#[command(flatten)]
common: ModelCommon,
},
/// Verify model file integrity by alias
Verify {
alias: String,
#[command(flatten)]
common: ModelCommon,
},
/// Update all models (HEAD + ETag; skip if unchanged)
Update {
#[command(flatten)]
common: ModelCommon,
},
/// Garbage-collect unreferenced files and stale manifest entries
Gc {
#[command(flatten)]
common: ModelCommon,
},
/// Search a repo for GGUF files
Search {
/// Hugging Face repo, e.g. ggml-org/models
repo: String,
/// Optional substring to filter filenames
#[arg(long)]
query: Option<String>,
#[command(flatten)]
common: ModelCommon,
},
}
#[derive(Debug, Subcommand)]
pub enum PluginsCmd {
/// List installed plugins
List,
/// Show a plugin's capabilities (as JSON)
Info {
/// Plugin short name, e.g., "tubescribe"
name: String,
},
/// Run a plugin command (JSON-RPC over NDJSON via stdio)
Run {
/// Plugin short name
name: String,
/// Command name in plugin's API
command: String,
/// JSON payload string
#[arg(long)]
json: Option<String>,
},
}
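// Sketch (illustrative): an invocation like `polyscribe --json models ls`
// parses into Cli { output: OutputOpts { json: true, quiet: false }, verbose: 0,
// no_interaction: false, no_progress: false,
// command: Commands::Models { cmd: ModelsCmd::Ls { .. } } }.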


@@ -1,536 +1,470 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2025 <COPYRIGHT HOLDER>. All rights reserved.
use std::fs::{File, create_dir_all};
use std::io::{self, Read, Write};
use std::path::{Path, PathBuf};
mod cli;
mod output;
use anyhow::{Context, Result, anyhow};
use clap::{Parser, Subcommand, ValueEnum, CommandFactory};
use clap_complete::Shell;
use serde::{Deserialize, Serialize};
use polyscribe::{OutputEntry, date_prefix, normalize_lang_code, render_srt};
use polyscribe_host as host;
#[derive(Subcommand, Debug, Clone)]
enum PluginsCmd {
/// List available plugins
List,
/// Show plugin capabilities
Info { name: String },
/// Run a plugin command with a JSON payload
Run {
name: String,
command: String,
/// JSON payload string passed to the plugin as request.params
#[arg(long = "json")]
json: String,
},
}
#[derive(Subcommand, Debug, Clone)]
enum Command {
Completions { #[arg(value_enum)] shell: Shell },
Man,
Plugins { #[command(subcommand)] cmd: PluginsCmd },
}
#[derive(ValueEnum, Debug, Clone, Copy)]
#[value(rename_all = "kebab-case")]
enum GpuBackendCli {
Auto,
Cpu,
Cuda,
Hip,
Vulkan,
}
#[derive(Parser, Debug)]
#[command(
name = "PolyScribe",
bin_name = "polyscribe",
version,
about = "Merge JSON transcripts or transcribe audio using native whisper"
)]
struct Args {
/// Increase verbosity (-v, -vv). Repeat to increase.
/// Debug logs appear with -v; very verbose with -vv. Logs go to stderr.
#[arg(short = 'v', long = "verbose", action = clap::ArgAction::Count, global = true)]
verbose: u8,
/// Quiet mode: suppress non-error logging on stderr (overrides -v)
/// Does not suppress interactive prompts or stdout output.
#[arg(short = 'q', long = "quiet", global = true)]
quiet: bool,
/// Non-interactive mode: never prompt; use defaults instead.
#[arg(long = "no-interaction", global = true)]
no_interaction: bool,
/// Disable interactive progress indicators (bars/spinners)
#[arg(long = "no-progress", global = true)]
no_progress: bool,
/// Optional subcommands (completions, man, plugins)
#[command(subcommand)]
cmd: Option<Command>,
/// Input .json transcript files or audio files to merge/transcribe
inputs: Vec<String>,
/// Output file path base or directory (date prefix added).
/// In merge mode: base path.
/// In separate mode: directory.
/// If omitted: prints JSON to stdout for merge mode; separate mode requires directory for multiple inputs.
#[arg(short, long, value_name = "FILE")]
output: Option<String>,
/// Merge all inputs into a single output; if not set, each input is written as a separate output
#[arg(short = 'm', long = "merge")]
merge: bool,
/// Merge and also write separate outputs per input; requires -o OUTPUT_DIR
#[arg(long = "merge-and-separate")]
merge_and_separate: bool,
/// Prompt for speaker names per input file
#[arg(long = "set-speaker-names")]
set_speaker_names: bool,
/// Language code to use for transcription (e.g., en, de). No auto-detection.
#[arg(short, long, value_name = "LANG")]
language: Option<String>,
/// Launch interactive model downloader (list HF models, multi-select and download)
#[arg(long)]
download_models: bool,
/// Update local Whisper models by comparing hashes/sizes with remote manifest
#[arg(long)]
update_models: bool,
}
#[derive(Debug, Deserialize)]
struct InputRoot {
#[serde(default)]
segments: Vec<InputSegment>,
}
#[derive(Debug, Deserialize)]
struct InputSegment {
start: f64,
end: f64,
text: String,
}
#[derive(Debug, Serialize)]
struct OutputRoot {
items: Vec<OutputEntry>,
}
fn is_json_file(path: &Path) -> bool {
matches!(path.extension().and_then(|s| s.to_str()).map(|s| s.to_lowercase()), Some(ext) if ext == "json")
}
fn is_audio_file(path: &Path) -> bool {
if let Some(ext) = path.extension().and_then(|s| s.to_str()).map(|s| s.to_lowercase()) {
let exts = [
"mp3", "wav", "m4a", "mp4", "aac", "flac", "ogg", "wma", "webm", "mkv", "mov", "avi",
"m4b", "3gp", "opus", "aiff", "alac",
];
return exts.contains(&ext.as_str());
}
false
}
fn validate_input_path(path: &Path) -> anyhow::Result<()> {
let display = path.display();
if !path.exists() {
return Err(anyhow!("Input not found: {}", display));
}
let metadata = std::fs::metadata(path).with_context(|| format!("Failed to stat input: {}", display))?;
if metadata.is_dir() {
return Err(anyhow!("Input is a directory (expected a file): {}", display));
}
std::fs::File::open(path)
.with_context(|| format!("Failed to open input file: {}", display))
.map(|_| ())
}
fn sanitize_speaker_name(raw: &str) -> String {
if let Some((prefix, rest)) = raw.split_once('-') {
if !prefix.is_empty() && prefix.chars().all(|c| c.is_ascii_digit()) {
return rest.to_string();
use clap::{CommandFactory, Parser};
use cli::{Cli, Commands, GpuBackend, ModelsCmd, ModelCommon, PluginsCmd};
use output::OutputMode;
use polyscribe_core::model_manager::{ModelManager, Settings, ReqwestClient};
use polyscribe_core::ui;
fn normalized_similarity(a: &str, b: &str) -> f64 {
// simple Levenshtein distance; normalized to [0,1]
let a_bytes = a.as_bytes();
let b_bytes = b.as_bytes();
let n = a_bytes.len();
let m = b_bytes.len();
if n == 0 && m == 0 { return 1.0; }
if n == 0 || m == 0 { return 0.0; }
let mut prev: Vec<usize> = (0..=m).collect();
let mut curr: Vec<usize> = vec![0; m + 1];
for i in 1..=n {
curr[0] = i;
for j in 1..=m {
let cost = if a_bytes[i - 1] == b_bytes[j - 1] { 0 } else { 1 };
curr[j] = (prev[j] + 1)
.min(curr[j - 1] + 1)
.min(prev[j - 1] + cost);
}
std::mem::swap(&mut prev, &mut curr);
}
raw.to_string()
let dist = prev[m] as f64;
let max_len = n.max(m) as f64;
1.0 - (dist / max_len)
}
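// Example (illustrative): normalized_similarity("gguf-tiny-q4_0.bin", "gguf-tiny-q5_1.bin")
// needs 2 substitutions over 18 bytes, giving a similarity of about 0.89;
// identical strings score 1.0, completely different ones approach 0.0.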
fn prompt_speaker_name_for_path(
_path: &Path,
default_name: &str,
enabled: bool,
) -> String {
if !enabled || polyscribe::is_no_interaction() {
return sanitize_speaker_name(default_name);
}
// TODO implement cliclack for this
let mut input_line = String::new();
match std::io::stdin().read_line(&mut input_line) {
Ok(_) => {
let trimmed = input_line.trim();
if trimmed.is_empty() {
sanitize_speaker_name(default_name)
} else {
sanitize_speaker_name(trimmed)
}
fn human_size(bytes: Option<u64>) -> String {
match bytes {
Some(n) => {
let x = n as f64;
const KB: f64 = 1024.0;
const MB: f64 = 1024.0 * KB;
const GB: f64 = 1024.0 * MB;
if x >= GB { format!("{:.2} GiB", x / GB) }
else if x >= MB { format!("{:.2} MiB", x / MB) }
else if x >= KB { format!("{:.2} KiB", x / KB) }
else { format!("{} B", n) }
}
Err(_) => sanitize_speaker_name(default_name),
None => "?".to_string(),
}
}
use polyscribe_core::ui::progress::ProgressReporter;
use polyscribe_host::PluginManager;
use tokio::io::AsyncWriteExt;
use tracing_subscriber::EnvFilter;
fn handle_plugins(cmd: PluginsCmd) -> Result<()> {
match cmd {
PluginsCmd::List => {
let list = host::discover()?;
for p in list {
println!("{}\t{}", p.name, p.path.display());
}
Ok(())
}
PluginsCmd::Info { name } => {
let p = host::find_plugin_by_name(&name)?;
let caps = host::capabilities(&p.path)?;
println!("{}", serde_json::to_string_pretty(&caps)?);
Ok(())
}
PluginsCmd::Run { name, command, json } => {
let p = host::find_plugin_by_name(&name)?;
let params: serde_json::Value = serde_json::from_str(&json).context("--json payload must be valid JSON")?;
let mut last_pct = 0u8;
let result = host::run_method(&p.path, &command, params, |prog| {
// Render minimal progress
let stage = prog.stage.as_deref().unwrap_or("");
let msg = prog.message.as_deref().unwrap_or("");
if prog.pct != last_pct {
let _ = cliclack::log::info(format!("[{}%] {} {}", prog.pct, stage, msg).trim());
last_pct = prog.pct;
}
})?;
println!("{}", serde_json::to_string_pretty(&result)?);
Ok(())
}
}
fn init_tracing(json_mode: bool, quiet: bool, verbose: u8) {
// In JSON mode, suppress human logs; route errors to stderr only.
let level = if json_mode || quiet { "error" } else { match verbose { 0 => "info", 1 => "debug", _ => "trace" } };
let filter = EnvFilter::try_from_default_env().unwrap_or_else(|_| EnvFilter::new(level));
tracing_subscriber::fmt()
.with_env_filter(filter)
.with_target(false)
.with_level(true)
.with_writer(std::io::stderr)
.compact()
.init();
}
fn main() -> Result<()> {
let args = Args::parse();
let args = Cli::parse();
// Initialize runtime flags for the library
polyscribe::set_verbose(args.verbose);
polyscribe::set_quiet(args.quiet);
polyscribe::set_no_interaction(args.no_interaction);
polyscribe::set_no_progress(args.no_progress);
// Determine output mode early for logging and UI configuration
let output_mode = if args.output.json {
OutputMode::Json
} else {
OutputMode::Human { quiet: args.output.quiet }
};
// Handle subcommands
if let Some(cmd) = &args.cmd {
match cmd.clone() {
Command::Completions { shell } => {
let mut cmd = Args::command();
let bin_name = cmd.get_name().to_string();
clap_complete::generate(shell, &mut cmd, bin_name, &mut io::stdout());
return Ok(());
init_tracing(matches!(output_mode, OutputMode::Json), args.output.quiet, args.verbose);
// Suppress decorative UI output in JSON mode as well
polyscribe_core::set_quiet(args.output.quiet || matches!(output_mode, OutputMode::Json));
polyscribe_core::set_no_interaction(args.no_interaction);
polyscribe_core::set_verbose(args.verbose);
polyscribe_core::set_no_progress(args.no_progress);
match args.command {
Commands::Transcribe {
gpu_backend,
gpu_layers,
inputs,
..
} => {
polyscribe_core::ui::info("starting transcription workflow");
let mut progress = ProgressReporter::new(args.no_interaction);
progress.step("Validating inputs");
if inputs.is_empty() {
return Err(anyhow!("no inputs provided"));
}
Command::Man => {
let cmd = Args::command();
let man = clap_mangen::Man::new(cmd);
let mut man_bytes = Vec::new();
man.render(&mut man_bytes)?;
io::stdout().write_all(&man_bytes)?;
return Ok(());
progress.step("Selecting backend and preparing model");
match gpu_backend {
GpuBackend::Auto => {}
GpuBackend::Cpu => {}
GpuBackend::Cuda => {
let _ = gpu_layers;
}
GpuBackend::Hip => {}
GpuBackend::Vulkan => {}
}
Command::Plugins { cmd } => {
return handle_plugins(cmd);
}
}
}
// Optional model management actions
if args.download_models {
if let Err(err) = polyscribe::models::run_interactive_model_downloader() {
polyscribe::elog!("Model downloader failed: {:#}", err);
}
if args.inputs.is_empty() {
return Ok(())
}
}
if args.update_models {
if let Err(err) = polyscribe::models::update_local_models() {
polyscribe::elog!("Model update failed: {:#}", err);
return Err(err);
}
if args.inputs.is_empty() {
return Ok(())
}
}
// Process inputs
let mut inputs = args.inputs;
if inputs.is_empty() {
return Err(anyhow!("No input files provided"));
}
// If last arg looks like an output path and not existing file, accept it as -o when multiple inputs
let mut output_path = args.output;
if output_path.is_none() && inputs.len() >= 2 {
if let Some(candidate_output) = inputs.last().cloned() {
if !Path::new(&candidate_output).exists() {
inputs.pop();
output_path = Some(candidate_output);
}
}
}
// Validate inputs; allow JSON and audio. For audio, require --language.
for input_arg in &inputs {
let path_ref = Path::new(input_arg);
validate_input_path(path_ref)?;
if !(is_json_file(path_ref) || is_audio_file(path_ref)) {
return Err(anyhow!(
"Unsupported input type (expected .json transcript or audio media): {}",
path_ref.display()
));
}
if is_audio_file(path_ref) && args.language.is_none() {
return Err(anyhow!("Please specify --language (e.g., --language en). Language detection was removed."));
}
}
// Derive speakers (prompt if requested)
let speakers: Vec<String> = inputs
.iter()
.map(|input_path| {
let path = Path::new(input_path);
let default_speaker = sanitize_speaker_name(
path.file_stem().and_then(|s| s.to_str()).unwrap_or("speaker"),
);
prompt_speaker_name_for_path(path, &default_speaker, args.set_speaker_names)
})
.collect();
// MERGE-AND-SEPARATE mode
if args.merge_and_separate {
polyscribe::dlog!(1, "Mode: merge-and-separate; output_dir={:?}", output_path);
let out_dir = match output_path.as_ref() {
Some(p) => PathBuf::from(p),
None => return Err(anyhow!("--merge-and-separate requires -o OUTPUT_DIR")),
};
if !out_dir.as_os_str().is_empty() {
create_dir_all(&out_dir).with_context(|| {
format!("Failed to create output directory: {}", out_dir.display())
})?;
progress.finish_with_message("Transcription completed (stub)");
Ok(())
}
let mut merged_entries: Vec<OutputEntry> = Vec::new();
for (idx, input_path) in inputs.iter().enumerate() {
let path = Path::new(input_path);
let speaker = speakers[idx].clone();
// Decide based on input type (JSON transcript vs audio to transcribe)
// TODO remove duplicate
let mut entries: Vec<OutputEntry> = if is_json_file(path) {
let mut buf = String::new();
File::open(path)
.with_context(|| format!("Failed to open: {input_path}"))?
.read_to_string(&mut buf)
.with_context(|| format!("Failed to read: {input_path}"))?;
let root: InputRoot = serde_json::from_str(&buf)
.with_context(|| format!("Invalid JSON transcript parsed from {input_path}"))?;
root
.segments
.into_iter()
.map(|seg| OutputEntry { id: 0, speaker: speaker.clone(), start: seg.start, end: seg.end, text: seg.text })
.collect()
} else {
let lang_norm: Option<String> = args.language.as_deref().and_then(|s| normalize_lang_code(s));
let selected_backend = polyscribe::backend::select_backend(polyscribe::backend::BackendKind::Auto, args.verbose > 0)?;
selected_backend.backend.transcribe(path, &speaker, lang_norm.as_deref(), None, None)?
Commands::Models { cmd } => {
// predictable exit codes
const EXIT_OK: i32 = 0;
const EXIT_NOT_FOUND: i32 = 2;
const EXIT_NETWORK: i32 = 3;
const EXIT_VERIFY_FAILED: i32 = 4;
// const EXIT_NO_CHANGE: i32 = 5; // reserved
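// Mapping used by the match below: Ls and Gc always exit 0; a missing alias
// (Rm/Verify) or an unresolved Add exits 2; network failures (Add/Update/Search)
// exit 3; a failed integrity check in Verify exits 4.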
let handle_common = |c: &ModelCommon| Settings {
concurrency: c.concurrency.max(1),
limit_rate: c.limit_rate,
..Default::default()
};
// Sort and id per-file
// TODO remove duplicate
entries.sort_by(|a, b| a.start.partial_cmp(&b.start).unwrap_or(std::cmp::Ordering::Equal)
.then(a.end.partial_cmp(&b.end).unwrap_or(std::cmp::Ordering::Equal)));
for (i, entry) in entries.iter_mut().enumerate() { entry.id = i as u64; }
// Write per-file outputs
let stem = path.file_stem().and_then(|s| s.to_str()).unwrap_or("output");
let date = date_prefix();
let base_name = format!("{date}_{stem}");
let json_path = out_dir.join(format!("{}.json", &base_name));
let toml_path = out_dir.join(format!("{}.toml", &base_name));
let srt_path = out_dir.join(format!("{}.srt", &base_name));
let output_bundle = OutputRoot { items: entries.clone() };
let mut json_file = File::create(&json_path).with_context(|| format!("Failed to create output file: {}", json_path.display()))?;
serde_json::to_writer_pretty(&mut json_file, &output_bundle)?; writeln!(&mut json_file)?;
let toml_str = toml::to_string_pretty(&output_bundle)?;
let mut toml_file = File::create(&toml_path).with_context(|| format!("Failed to create output file: {}", toml_path.display()))?;
toml_file.write_all(toml_str.as_bytes())?; if !toml_str.ends_with('\n') { writeln!(&mut toml_file)?; }
let srt_str = render_srt(&output_bundle.items);
let mut srt_file = File::create(&srt_path).with_context(|| format!("Failed to create output file: {}", srt_path.display()))?;
srt_file.write_all(srt_str.as_bytes())?;
merged_entries.extend(output_bundle.items.into_iter());
}
// Write merged outputs into out_dir
// TODO remove duplicate
merged_entries.sort_by(|a, b| a.start.partial_cmp(&b.start).unwrap_or(std::cmp::Ordering::Equal)
.then(a.end.partial_cmp(&b.end).unwrap_or(std::cmp::Ordering::Equal)));
for (index, entry) in merged_entries.iter_mut().enumerate() { entry.id = index as u64; }
let merged_output = OutputRoot { items: merged_entries };
let date = date_prefix();
let merged_base = format!("{date}_merged");
let merged_json_path = out_dir.join(format!("{}.json", &merged_base));
let merged_toml_path = out_dir.join(format!("{}.toml", &merged_base));
let merged_srt_path = out_dir.join(format!("{}.srt", &merged_base));
let mut merged_json_file = File::create(&merged_json_path).with_context(|| format!("Failed to create output file: {}", merged_json_path.display()))?;
serde_json::to_writer_pretty(&mut merged_json_file, &merged_output)?; writeln!(&mut merged_json_file)?;
let merged_toml_str = toml::to_string_pretty(&merged_output)?;
let mut merged_toml_file = File::create(&merged_toml_path).with_context(|| format!("Failed to create output file: {}", merged_toml_path.display()))?;
merged_toml_file.write_all(merged_toml_str.as_bytes())?; if !merged_toml_str.ends_with('\n') { writeln!(&mut merged_toml_file)?; }
let merged_srt_str = render_srt(&merged_output.items);
let mut merged_srt_file = File::create(&merged_srt_path).with_context(|| format!("Failed to create output file: {}", merged_srt_path.display()))?;
merged_srt_file.write_all(merged_srt_str.as_bytes())?;
return Ok(());
}
// MERGE mode
if args.merge {
polyscribe::dlog!(1, "Mode: merge; output_base={:?}", output_path);
let mut entries: Vec<OutputEntry> = Vec::new();
for (index, input_path) in inputs.iter().enumerate() {
let path = Path::new(input_path);
let speaker = speakers[index].clone();
if is_json_file(path) {
let mut buf = String::new();
File::open(path)
.with_context(|| format!("Failed to open: {}", input_path))?
.read_to_string(&mut buf)
.with_context(|| format!("Failed to read: {}", input_path))?;
let root: InputRoot = serde_json::from_str(&buf)
.with_context(|| format!("Invalid JSON transcript parsed from {}", input_path))?;
for seg in root.segments {
entries.push(OutputEntry { id: 0, speaker: speaker.clone(), start: seg.start, end: seg.end, text: seg.text });
let exit = match cmd {
ModelsCmd::Ls { common } => {
let mm: ModelManager<ReqwestClient> = ModelManager::new(handle_common(&common))?;
let list = mm.ls()?;
match output_mode {
OutputMode::Json => {
// Always emit JSON array (possibly empty)
output_mode.print_json(&list);
}
OutputMode::Human { quiet } => {
if list.is_empty() {
if !quiet { println!("No models installed."); }
} else {
if !quiet { println!("Model (Repo)"); }
for r in list {
if !quiet { println!("{} ({})", r.file, r.repo); }
}
}
}
}
EXIT_OK
}
} else {
let lang_norm: Option<String> = args.language.as_deref().and_then(|s| normalize_lang_code(s));
let selected_backend = polyscribe::backend::select_backend(polyscribe::backend::BackendKind::Auto, args.verbose > 0)?;
let mut new_entries = selected_backend.backend.transcribe(path, &speaker, lang_norm.as_deref(), None, None)?;
entries.append(&mut new_entries);
}
ModelsCmd::Add { repo, file, common } => {
let settings = handle_common(&common);
let mm: ModelManager<ReqwestClient> = ModelManager::new(settings.clone())?;
// Derive an alias automatically from repo and file
fn derive_alias(repo: &str, file: &str) -> String {
use std::path::Path;
let repo_tail = repo.rsplit('/').next().unwrap_or(repo);
let stem = Path::new(file)
.file_stem()
.and_then(|s| s.to_str())
.unwrap_or(file);
format!("{}-{}", repo_tail, stem)
}
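// e.g. derive_alias("ggml-org/models", "gguf-tiny-q4_0.bin") == "models-gguf-tiny-q4_0"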
let alias = derive_alias(&repo, &file);
match mm.add_or_update(&alias, &repo, &file) {
Ok(rec) => {
match output_mode {
OutputMode::Json => output_mode.print_json(&rec),
OutputMode::Human { quiet } => {
if !quiet { println!("installed: {} -> {}/{}", alias, repo, rec.file); }
}
}
EXIT_OK
}
Err(e) => {
// On not found or similar errors, try suggesting close matches interactively
if matches!(output_mode, OutputMode::Json) || polyscribe_core::is_no_interaction() {
match output_mode {
OutputMode::Json => {
// Emit error JSON object
#[derive(serde::Serialize)]
struct ErrObj<'a> { error: &'a str }
let eo = ErrObj { error: &e.to_string() };
output_mode.print_json(&eo);
}
_ => { eprintln!("error: {e}"); }
}
EXIT_NOT_FOUND
} else {
ui::warn(format!("{}", e));
ui::info("Searching for similar model filenames…");
match polyscribe_core::model_manager::search_repo(&repo, None) {
Ok(mut files) => {
if files.is_empty() {
ui::warn("No files found in repository.");
EXIT_NOT_FOUND
} else {
// rank by similarity
files.sort_by(|a, b| normalized_similarity(&file, b)
.partial_cmp(&normalized_similarity(&file, a))
.unwrap_or(std::cmp::Ordering::Equal));
let top: Vec<String> = files.into_iter().take(5).collect();
if top.is_empty() {
EXIT_NOT_FOUND
} else if top.len() == 1 {
let cand = &top[0];
// Fetch repo size list once
let size_map: std::collections::HashMap<String, Option<u64>> =
polyscribe_core::model_manager::list_repo_files_with_meta(&repo)
.unwrap_or_default()
.into_iter().collect();
let mut size = size_map.get(cand).cloned().unwrap_or(None);
if size.is_none() {
size = polyscribe_core::model_manager::head_len_for_file(&repo, cand);
}
let local_files: std::collections::HashSet<String> = mm.ls()?.into_iter().map(|r| r.file).collect();
let is_local = local_files.contains(cand);
let label = format!("{} [{}]{}", cand, human_size(size), if is_local { " (local)" } else { "" });
let ok = ui::prompt_confirm(&format!("Did you mean {}?", label), true)
.unwrap_or(false);
if !ok { EXIT_NOT_FOUND } else {
let mm2: ModelManager<ReqwestClient> = ModelManager::new(settings)?;
let alias2 = derive_alias(&repo, cand);
match mm2.add_or_update(&alias2, &repo, cand) {
Ok(rec) => {
match output_mode {
OutputMode::Json => output_mode.print_json(&rec),
OutputMode::Human { quiet } => { if !quiet { println!("installed: {} -> {}/{}", alias2, repo, rec.file); } }
}
EXIT_OK
}
Err(e2) => { eprintln!("error: {e2}"); EXIT_NETWORK }
}
}
} else {
let opts: Vec<String> = top;
let local_files: std::collections::HashSet<String> = mm.ls()?.into_iter().map(|r| r.file).collect();
// Enrich labels with size and local tag using a single API call
let size_map: std::collections::HashMap<String, Option<u64>> =
polyscribe_core::model_manager::list_repo_files_with_meta(&repo)
.unwrap_or_default()
.into_iter().collect();
let mut labels_owned: Vec<String> = Vec::new();
for f in &opts {
let mut size = size_map.get(f).cloned().unwrap_or(None);
if size.is_none() {
size = polyscribe_core::model_manager::head_len_for_file(&repo, f);
}
let is_local = local_files.contains(f);
let suffix = if is_local { " (local)" } else { "" };
labels_owned.push(format!("{} [{}]{}", f, human_size(size), suffix));
}
let labels: Vec<&str> = labels_owned.iter().map(|s| s.as_str()).collect();
match ui::prompt_select("Pick a model", &labels) {
Ok(idx) => {
let chosen = &opts[idx];
let mm2: ModelManager<ReqwestClient> = ModelManager::new(settings)?;
let alias2 = derive_alias(&repo, chosen);
match mm2.add_or_update(&alias2, &repo, chosen) {
Ok(rec) => {
match output_mode {
OutputMode::Json => output_mode.print_json(&rec),
OutputMode::Human { quiet } => { if !quiet { println!("installed: {} -> {}/{}", alias2, repo, rec.file); } }
}
EXIT_OK
}
Err(e2) => { eprintln!("error: {e2}"); EXIT_NETWORK }
}
}
Err(_) => EXIT_NOT_FOUND,
}
}
}
}
Err(e2) => {
eprintln!("error: {}", e2);
EXIT_NETWORK
}
}
}
}
}
}
ModelsCmd::Rm { alias, common } => {
let mm: ModelManager<ReqwestClient> = ModelManager::new(handle_common(&common))?;
let ok = mm.rm(&alias)?;
match output_mode {
OutputMode::Json => {
#[derive(serde::Serialize)]
struct R { removed: bool }
output_mode.print_json(&R { removed: ok });
}
OutputMode::Human { quiet } => {
if !quiet { println!("{}", if ok { "removed" } else { "not found" }); }
}
}
if ok { EXIT_OK } else { EXIT_NOT_FOUND }
}
ModelsCmd::Verify { alias, common } => {
let mm: ModelManager<ReqwestClient> = ModelManager::new(handle_common(&common))?;
let found = mm.ls()?.into_iter().any(|r| r.alias == alias);
if !found {
match output_mode {
OutputMode::Json => {
#[derive(serde::Serialize)]
struct R<'a> { ok: bool, error: &'a str }
output_mode.print_json(&R { ok: false, error: "not found" });
}
OutputMode::Human { quiet } => { if !quiet { println!("not found"); } }
}
EXIT_NOT_FOUND
} else {
let ok = mm.verify(&alias)?;
match output_mode {
OutputMode::Json => {
#[derive(serde::Serialize)]
struct R { ok: bool }
output_mode.print_json(&R { ok });
}
OutputMode::Human { quiet } => { if !quiet { println!("{}", if ok { "ok" } else { "corrupt" }); } }
}
if ok { EXIT_OK } else { EXIT_VERIFY_FAILED }
}
}
ModelsCmd::Update { common } => {
let mm: ModelManager<ReqwestClient> = ModelManager::new(handle_common(&common))?;
let mut rc = EXIT_OK;
for rec in mm.ls()? {
match mm.add_or_update(&rec.alias, &rec.repo, &rec.file) {
Ok(_) => {}
Err(e) => {
rc = EXIT_NETWORK;
match output_mode {
OutputMode::Json => {
#[derive(serde::Serialize)]
struct R<'a> { alias: &'a str, error: String }
output_mode.print_json(&R { alias: &rec.alias, error: e.to_string() });
}
_ => { eprintln!("update {}: {e}", rec.alias); }
}
}
}
}
rc
}
ModelsCmd::Gc { common } => {
let mm: ModelManager<ReqwestClient> = ModelManager::new(handle_common(&common))?;
let (files_removed, entries_removed) = mm.gc()?;
match output_mode {
OutputMode::Json => {
#[derive(serde::Serialize)]
struct R { files_removed: usize, entries_removed: usize }
output_mode.print_json(&R { files_removed, entries_removed });
}
OutputMode::Human { quiet } => { if !quiet { println!("files_removed={} entries_removed={}", files_removed, entries_removed); } }
}
EXIT_OK
}
ModelsCmd::Search { repo, query, common } => {
let res = polyscribe_core::model_manager::search_repo(&repo, query.as_deref());
match res {
Ok(files) => {
match output_mode {
OutputMode::Json => output_mode.print_json(&files),
OutputMode::Human { quiet } => { for f in files { if !quiet { println!("{}", f); } } }
}
EXIT_OK
}
Err(e) => {
match output_mode {
OutputMode::Json => {
#[derive(serde::Serialize)]
struct R { error: String }
output_mode.print_json(&R { error: e.to_string() });
}
_ => { eprintln!("error: {e}"); }
}
EXIT_NETWORK
}
}
}
};
std::process::exit(exit);
}
// TODO remove duplicate
entries.sort_by(|a, b| a.start.partial_cmp(&b.start).unwrap_or(std::cmp::Ordering::Equal)
.then(a.end.partial_cmp(&b.end).unwrap_or(std::cmp::Ordering::Equal)));
for (i, entry) in entries.iter_mut().enumerate() { entry.id = i as u64; }
let output_bundle = OutputRoot { items: entries };
if let Some(path) = output_path {
let base_path = Path::new(&path);
let parent_opt = base_path.parent();
if let Some(parent) = parent_opt {
if !parent.as_os_str().is_empty() {
create_dir_all(parent).with_context(|| {
format!("Failed to create parent directory for output: {}", parent.display())
})?;
Commands::Plugins { cmd } => {
let plugin_manager = PluginManager;
match cmd {
PluginsCmd::List => {
let list = plugin_manager.list().context("discovering plugins")?;
for item in list {
polyscribe_core::ui::info(item.name);
}
Ok(())
}
PluginsCmd::Info { name } => {
let info = plugin_manager
.info(&name)
.with_context(|| format!("getting info for {}", name))?;
let info_json = serde_json::to_string_pretty(&info)?;
polyscribe_core::ui::info(info_json);
Ok(())
}
PluginsCmd::Run {
name,
command,
json,
} => {
// Use a local Tokio runtime only for this async path
let rt = tokio::runtime::Builder::new_multi_thread()
.enable_all()
.build()
.context("building tokio runtime")?;
rt.block_on(async {
let payload = json.unwrap_or_else(|| "{}".to_string());
let mut child = plugin_manager
.spawn(&name, &command)
.with_context(|| format!("spawning plugin {name} {command}"))?;
if let Some(mut stdin) = child.stdin.take() {
stdin
.write_all(payload.as_bytes())
.await
.context("writing JSON payload to plugin stdin")?;
}
let status = plugin_manager.forward_stdio(&mut child).await?;
if !status.success() {
polyscribe_core::ui::error(format!(
"plugin returned non-zero exit code: {}",
status
));
return Err(anyhow!("plugin failed"));
}
Ok(())
})
}
}
let stem = base_path.file_stem().and_then(|s| s.to_str()).unwrap_or("output");
let date = date_prefix();
let base_name = format!("{}_{}", date, stem);
let dir = parent_opt.unwrap_or(Path::new(""));
let json_path = dir.join(format!("{}.json", &base_name));
let toml_path = dir.join(format!("{}.toml", &base_name));
let srt_path = dir.join(format!("{}.srt", &base_name));
let mut json_file = File::create(&json_path).with_context(|| format!("Failed to create output file: {}", json_path.display()))?;
serde_json::to_writer_pretty(&mut json_file, &output_bundle)?; writeln!(&mut json_file)?;
let toml_str = toml::to_string_pretty(&output_bundle)?;
let mut toml_file = File::create(&toml_path).with_context(|| format!("Failed to create output file: {}", toml_path.display()))?;
toml_file.write_all(toml_str.as_bytes())?; if !toml_str.ends_with('\n') { writeln!(&mut toml_file)?; }
let srt_str = render_srt(&output_bundle.items);
let mut srt_file = File::create(&srt_path).with_context(|| format!("Failed to create output file: {}", srt_path.display()))?;
srt_file.write_all(srt_str.as_bytes())?;
} else {
let stdout = io::stdout();
let mut handle = stdout.lock();
serde_json::to_writer_pretty(&mut handle, &output_bundle)?; writeln!(&mut handle)?;
}
return Ok(());
}
// SEPARATE (default)
polyscribe::dlog!(1, "Mode: separate; output_dir={:?}", output_path);
if output_path.is_none() && inputs.len() > 1 {
return Err(anyhow!("Multiple inputs without --merge require -o OUTPUT_DIR to write separate files"));
}
let out_dir: Option<PathBuf> = output_path.as_ref().map(PathBuf::from);
if let Some(dir) = &out_dir {
if !dir.as_os_str().is_empty() {
create_dir_all(dir).with_context(|| format!("Failed to create output directory: {}", dir.display()))?;
Commands::Completions { shell } => {
use clap_complete::{generate, shells};
use std::io;
let mut cmd = Cli::command();
let name = cmd.get_name().to_string();
match shell.as_str() {
"bash" => generate(shells::Bash, &mut cmd, name, &mut io::stdout()),
"zsh" => generate(shells::Zsh, &mut cmd, name, &mut io::stdout()),
"fish" => generate(shells::Fish, &mut cmd, name, &mut io::stdout()),
"powershell" => generate(shells::PowerShell, &mut cmd, name, &mut io::stdout()),
"elvish" => generate(shells::Elvish, &mut cmd, name, &mut io::stdout()),
_ => return Err(anyhow!("unsupported shell: {shell}")),
}
Ok(())
}
Commands::Man => {
use clap_mangen::Man;
let cmd = Cli::command();
let man = Man::new(cmd);
man.render(&mut std::io::stdout())?;
Ok(())
}
}
for (index, input_path) in inputs.iter().enumerate() {
let path = Path::new(input_path);
let speaker = speakers[index].clone();
// TODO remove duplicate
let mut entries: Vec<OutputEntry> = if is_json_file(path) {
let mut buf = String::new();
File::open(path)
.with_context(|| format!("Failed to open: {input_path}"))?
.read_to_string(&mut buf)
.with_context(|| format!("Failed to read: {input_path}"))?;
let root: InputRoot = serde_json::from_str(&buf)
.with_context(|| format!("Invalid JSON transcript parsed from {input_path}"))?;
root
.segments
.into_iter()
.map(|seg| OutputEntry { id: 0, speaker: speaker.clone(), start: seg.start, end: seg.end, text: seg.text })
.collect()
} else {
let lang_norm: Option<String> = args.language.as_deref().and_then(|s| normalize_lang_code(s));
let selected_backend = polyscribe::backend::select_backend(polyscribe::backend::BackendKind::Auto, args.verbose > 0)?;
selected_backend.backend.transcribe(path, &speaker, lang_norm.as_deref(), None, None)?
};
// TODO remove duplicate
entries.sort_by(|a, b| a.start.partial_cmp(&b.start).unwrap_or(std::cmp::Ordering::Equal)
.then(a.end.partial_cmp(&b.end).unwrap_or(std::cmp::Ordering::Equal)));
for (i, entry) in entries.iter_mut().enumerate() { entry.id = i as u64; }
let stem = path.file_stem().and_then(|s| s.to_str()).unwrap_or("output");
let date = date_prefix();
let base_name = format!("{date}_{stem}");
if let Some(dir) = &out_dir {
let json_path = dir.join(format!("{}.json", &base_name));
let toml_path = dir.join(format!("{}.toml", &base_name));
let srt_path = dir.join(format!("{}.srt", &base_name));
let output_bundle = OutputRoot { items: entries };
let mut json_file = File::create(&json_path).with_context(|| format!("Failed to create output file: {}", json_path.display()))?;
serde_json::to_writer_pretty(&mut json_file, &output_bundle)?; writeln!(&mut json_file)?;
let toml_str = toml::to_string_pretty(&output_bundle)?;
let mut toml_file = File::create(&toml_path).with_context(|| format!("Failed to create output file: {}", toml_path.display()))?;
toml_file.write_all(toml_str.as_bytes())?; if !toml_str.ends_with('\n') { writeln!(&mut toml_file)?; }
let srt_str = render_srt(&output_bundle.items);
let mut srt_file = File::create(&srt_path).with_context(|| format!("Failed to create output file: {}", srt_path.display()))?;
srt_file.write_all(srt_str.as_bytes())?;
} else {
// In separate mode with single input and no output dir, print JSON to stdout
let stdout = io::stdout();
let mut handle = stdout.lock();
let output_bundle = OutputRoot { items: entries };
serde_json::to_writer_pretty(&mut handle, &output_bundle)?; writeln!(&mut handle)?;
}
}
Ok(())
}

View File

@@ -0,0 +1,36 @@
use std::io::{self, Write};
#[derive(Clone, Debug)]
pub enum OutputMode {
Json,
Human { quiet: bool },
}
impl OutputMode {
pub fn is_quiet(&self) -> bool {
matches!(self, OutputMode::Json | OutputMode::Human { quiet: true })
}
pub fn print_json<T: serde::Serialize>(&self, v: &T) {
if let OutputMode::Json = self {
// Write compact JSON to stdout without prefixes
// and ensure a trailing newline for CLI ergonomics
let s = serde_json::to_string(v).unwrap_or_else(|e| format!("\"JSON_ERROR:{}\"", e));
println!("{}", s);
}
}
pub fn print_line(&self, s: impl AsRef<str>) {
match self {
OutputMode::Json => {
// Suppress human lines in JSON mode
}
OutputMode::Human { quiet } => {
if !quiet {
let _ = writeln!(io::stdout(), "{}", s.as_ref());
}
}
}
}
}
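A minimal usage sketch for OutputMode (illustrative, not part of the diff; ModelRow is a hypothetical stand-in for real row data):

#[derive(serde::Serialize)]
struct ModelRow { alias: String, file: String }

fn print_models(mode: &OutputMode, rows: &[ModelRow]) {
    // Json mode: emit one compact JSON document and nothing else on stdout.
    mode.print_json(&rows);
    // Human mode: one line per model, suppressed entirely by --quiet.
    for row in rows {
        mode.print_line(format!("{}  {}", row.alias, row.file));
    }
}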

View File

@@ -1,10 +1,11 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2025 <COPYRIGHT HOLDER>. All rights reserved.
use assert_cmd::cargo::cargo_bin;
use std::process::Command;
fn bin() -> &'static str {
env!("CARGO_BIN_EXE_polyscribe")
fn bin() -> std::path::PathBuf {
cargo_bin("polyscribe")
}
#[test]

View File

@@ -0,0 +1,42 @@
use assert_cmd::cargo::cargo_bin;
use std::process::Command;
fn bin() -> std::path::PathBuf { cargo_bin("polyscribe") }
#[test]
fn models_help_shows_global_output_flags() {
let out = Command::new(bin())
.args(["models", "--help"]) // subcommand help
.output()
.expect("failed to run polyscribe models --help");
assert!(out.status.success(), "help exited non-zero: {:?}", out.status);
let stdout = String::from_utf8(out.stdout).expect("stdout not utf-8");
assert!(stdout.contains("--json"), "--json not shown in help: {stdout}");
assert!(stdout.contains("--quiet"), "--quiet not shown in help: {stdout}");
}
#[test]
fn models_version_contains_pkg_version() {
let out = Command::new(bin())
.args(["models", "--version"]) // propagate_version
.output()
.expect("failed to run polyscribe models --version");
assert!(out.status.success(), "version exited non-zero: {:?}", out.status);
let stdout = String::from_utf8(out.stdout).expect("stdout not utf-8");
let want = env!("CARGO_PKG_VERSION");
assert!(stdout.contains(want), "version output missing {want}: {stdout}");
}
#[test]
fn models_ls_json_quiet_emits_pure_json() {
let out = Command::new(bin())
.args(["models", "ls", "--json", "--quiet"]) // global flags
.output()
.expect("failed to run polyscribe models ls --json --quiet");
assert!(out.status.success(), "ls exited non-zero: {:?}", out.status);
let stdout = String::from_utf8(out.stdout).expect("stdout not utf-8");
serde_json::from_str::<serde_json::Value>(stdout.trim()).expect("stdout is not valid JSON");
// Expect no extra logs on stdout; stderr should be empty in success path
assert!(out.stderr.is_empty(), "expected no stderr noise");
}

View File

@@ -1,32 +1,22 @@
[package]
name = "polyscribe"
version = "0.1.0"
edition = "2024"
license = "MIT"
[features]
# Default: CPU only; no GPU features enabled
default = []
# GPU backends map to whisper-rs features or FFI stub for Vulkan
gpu-cuda = ["whisper-rs/cuda"]
gpu-hip = ["whisper-rs/hipblas"]
gpu-vulkan = []
# explicit CPU fallback feature (no effect at build time, used for clarity)
cpu-fallback = []
name = "polyscribe-core"
version.workspace = true
edition.workspace = true
[dependencies]
anyhow = "1.0.98"
serde = { version = "1.0.219", features = ["derive"] }
serde_json = "1.0.142"
toml = "0.8"
chrono = { version = "0.4", features = ["clock"] }
sha2 = "0.10"
whisper-rs = { git = "https://github.com/tazz4843/whisper-rs" }
libc = "0.2"
cliclack = "0.3"
indicatif = "0.17"
thiserror = "1"
directories = "5"
[build-dependencies]
# no special build deps
anyhow = { workspace = true }
thiserror = { workspace = true }
serde = { workspace = true, features = ["derive"] }
serde_json = { workspace = true }
toml = { workspace = true }
directories = { workspace = true }
chrono = { workspace = true }
libc = { workspace = true }
whisper-rs = { workspace = true }
# UI and progress
cliclack = { workspace = true }
# HTTP downloads + hashing
reqwest = { workspace = true }
sha2 = { workspace = true }
hex = { workspace = true }
tempfile = { workspace = true }

View File

@@ -1,12 +1,14 @@
// SPDX-License-Identifier: MIT
// Move original build.rs behavior into core crate
fn main() {
// Only run special build steps when gpu-vulkan feature is enabled.
let vulkan_enabled = std::env::var("CARGO_FEATURE_GPU_VULKAN").is_ok();
println!("cargo:rerun-if-changed=extern/whisper.cpp");
if !vulkan_enabled {
println!(
"cargo:warning=gpu-vulkan feature is disabled; skipping Vulkan-dependent build steps."
);
return;
}
println!("cargo:rerun-if-changed=extern/whisper.cpp");
println!(
"cargo:warning=Building with gpu-vulkan: ensure Vulkan SDK/loader are installed. Future versions will compile whisper.cpp via CMake."
);

View File

@@ -1,34 +1,23 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2025 <COPYRIGHT HOLDER>. All rights reserved.
//! Transcription backend selection and implementations (CPU/GPU) used by PolyScribe.
use crate::OutputEntry;
use crate::prelude::*;
use crate::{decode_audio_to_pcm_f32_ffmpeg, find_model_file};
use anyhow::{Context, Result, anyhow};
use anyhow::{Context, anyhow};
use std::env;
use std::path::Path;
// Re-export a public enum for CLI parsing usage
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
/// Kind of transcription backend to use.
pub enum BackendKind {
/// Automatically detect the best available backend (CUDA > HIP > Vulkan > CPU).
Auto,
/// Pure CPU backend using whisper-rs.
Cpu,
/// NVIDIA CUDA backend (requires CUDA runtime available at load time and proper feature build).
Cuda,
/// AMD ROCm/HIP backend (requires hip/rocBLAS libraries available and proper feature build).
Hip,
/// Vulkan backend (experimental; requires Vulkan loader/SDK and feature build).
Vulkan,
}
/// Abstraction for a transcription backend.
pub trait TranscribeBackend {
/// Backend kind implemented by this type.
fn kind(&self) -> BackendKind;
/// Transcribe the given audio and return transcript entries.
fn transcribe(
&self,
audio_path: &Path,
@@ -39,15 +28,13 @@ pub trait TranscribeBackend {
) -> Result<Vec<OutputEntry>>;
}
fn check_lib(_names: &[&str]) -> bool {
fn is_library_available(_names: &[&str]) -> bool {
#[cfg(test)]
{
// During unit tests, avoid touching system libs to prevent loader crashes in CI.
false
}
#[cfg(not(test))]
{
// Disabled runtime dlopen probing to avoid loader instability; rely on environment overrides.
false
}
}
@@ -56,7 +43,7 @@ fn cuda_available() -> bool {
if let Ok(x) = env::var("POLYSCRIBE_TEST_FORCE_CUDA") {
return x == "1";
}
check_lib(&[
is_library_available(&[
"libcudart.so",
"libcudart.so.12",
"libcudart.so.11",
@@ -69,33 +56,31 @@ fn hip_available() -> bool {
if let Ok(x) = env::var("POLYSCRIBE_TEST_FORCE_HIP") {
return x == "1";
}
check_lib(&["libhipblas.so", "librocblas.so"])
is_library_available(&["libhipblas.so", "librocblas.so"])
}
fn vulkan_available() -> bool {
if let Ok(x) = env::var("POLYSCRIBE_TEST_FORCE_VULKAN") {
return x == "1";
}
check_lib(&["libvulkan.so.1", "libvulkan.so"])
is_library_available(&["libvulkan.so.1", "libvulkan.so"])
}
/// CPU-based transcription backend using whisper-rs.
#[derive(Default)]
pub struct CpuBackend;
/// CUDA-accelerated transcription backend for NVIDIA GPUs.
#[derive(Default)]
pub struct CudaBackend;
/// ROCm/HIP-accelerated transcription backend for AMD GPUs.
#[derive(Default)]
pub struct HipBackend;
/// Vulkan-based transcription backend (experimental/incomplete).
#[derive(Default)]
pub struct VulkanBackend;
macro_rules! impl_whisper_backend {
($ty:ty, $kind:expr) => {
impl TranscribeBackend for $ty {
fn kind(&self) -> BackendKind { $kind }
fn kind(&self) -> BackendKind {
$kind
}
fn transcribe(
&self,
audio_path: &Path,
@@ -128,29 +113,17 @@ impl TranscribeBackend for VulkanBackend {
) -> Result<Vec<OutputEntry>> {
Err(anyhow!(
"Vulkan backend not yet wired to whisper.cpp FFI. Build with --features gpu-vulkan and ensure Vulkan SDK is installed. How to fix: install Vulkan loader (libvulkan), set VULKAN_SDK, and run cargo build --features gpu-vulkan."
))
).into())
}
}
/// Result of choosing a transcription backend.
pub struct SelectionResult {
/// The constructed backend instance to perform transcription with.
pub struct BackendSelection {
pub backend: Box<dyn TranscribeBackend + Send + Sync>,
/// Which backend kind was ultimately selected.
pub chosen: BackendKind,
/// Which backend kinds were detected as available on this system.
pub detected: Vec<BackendKind>,
}
/// Select an appropriate backend based on user request and system detection.
///
/// If `requested` is `BackendKind::Auto`, the function prefers CUDA, then HIP,
/// then Vulkan, falling back to CPU when no GPU backend is detected. When a
/// specific GPU backend is requested but unavailable, an error is returned with
/// guidance on how to enable it.
///
/// Set `verbose` to true to print detection/selection info to stderr.
pub fn select_backend(requested: BackendKind, verbose: bool) -> Result<SelectionResult> {
pub fn select_backend(requested: BackendKind, verbose: bool) -> Result<BackendSelection> {
let mut detected = Vec::new();
if cuda_available() {
detected.push(BackendKind::Cuda);
@@ -164,11 +137,11 @@ pub fn select_backend(requested: BackendKind, verbose: bool) -> Result<Selection
let instantiate_backend = |k: BackendKind| -> Box<dyn TranscribeBackend + Send + Sync> {
match k {
BackendKind::Cpu => Box::new(CpuBackend::default()),
BackendKind::Cuda => Box::new(CudaBackend::default()),
BackendKind::Hip => Box::new(HipBackend::default()),
BackendKind::Vulkan => Box::new(VulkanBackend::default()),
BackendKind::Auto => Box::new(CpuBackend::default()), // placeholder for Auto
BackendKind::Cpu => Box::new(CpuBackend),
BackendKind::Cuda => Box::new(CudaBackend),
BackendKind::Hip => Box::new(HipBackend),
BackendKind::Vulkan => Box::new(VulkanBackend),
BackendKind::Auto => Box::new(CpuBackend),
}
};
@@ -190,7 +163,7 @@ pub fn select_backend(requested: BackendKind, verbose: bool) -> Result<Selection
} else {
return Err(anyhow!(
"Requested CUDA backend but CUDA libraries/devices not detected. How to fix: install NVIDIA driver + CUDA toolkit, ensure libcudart/libcublas are in loader path, and build with --features gpu-cuda."
));
).into());
}
}
BackendKind::Hip => {
@@ -199,7 +172,7 @@ pub fn select_backend(requested: BackendKind, verbose: bool) -> Result<Selection
} else {
return Err(anyhow!(
"Requested ROCm/HIP backend but libraries/devices not detected. How to fix: install ROCm hipBLAS/rocBLAS, ensure libs are in loader path, and build with --features gpu-hip."
));
).into());
}
}
BackendKind::Vulkan => {
@@ -208,7 +181,7 @@ pub fn select_backend(requested: BackendKind, verbose: bool) -> Result<Selection
} else {
return Err(anyhow!(
"Requested Vulkan backend but libvulkan not detected. How to fix: install Vulkan loader/SDK and build with --features gpu-vulkan."
));
).into());
}
}
BackendKind::Cpu => BackendKind::Cpu,
@@ -219,14 +192,13 @@ pub fn select_backend(requested: BackendKind, verbose: bool) -> Result<Selection
crate::dlog!(1, "Selected backend: {:?}", chosen);
}
Ok(SelectionResult {
Ok(BackendSelection {
backend: instantiate_backend(chosen),
chosen,
detected,
})
}
// Internal helper: transcription using whisper-rs with CPU/GPU (depending on build features)
#[allow(clippy::too_many_arguments)]
pub(crate) fn transcribe_with_whisper_rs(
audio_path: &Path,
@@ -235,7 +207,9 @@ pub(crate) fn transcribe_with_whisper_rs(
progress: Option<&(dyn Fn(i32) + Send + Sync)>,
) -> Result<Vec<OutputEntry>> {
let report = |p: i32| {
if let Some(cb) = progress { cb(p); }
if let Some(cb) = progress {
cb(p);
}
};
report(0);
@@ -248,21 +222,21 @@ pub(crate) fn transcribe_with_whisper_rs(
.and_then(|s| s.to_str())
.map(|s| s.contains(".en.") || s.ends_with(".en.bin"))
.unwrap_or(false);
if let Some(lang) = language {
if english_only_model && lang != "en" {
return Err(anyhow!(
"Selected model is English-only ({}), but a non-English language hint '{}' was provided. Please use a multilingual model or set WHISPER_MODEL.",
model_path.display(),
lang
));
}
if let Some(lang) = language
&& english_only_model
&& lang != "en"
{
return Err(anyhow!(
"Selected model is English-only ({}), but a non-English language hint '{}' was provided. Please use a multilingual model or set WHISPER_MODEL.",
model_path.display(),
lang
).into());
}
let model_path_str = model_path
.to_str()
.ok_or_else(|| anyhow!("Model path not valid UTF-8: {}", model_path.display()))?;
if crate::verbose_level() < 2 {
// Some builds of whisper/ggml expect these env vars; harmless if unknown
unsafe {
std::env::set_var("GGML_LOG_LEVEL", "0");
std::env::set_var("WHISPER_PRINT_PROGRESS", "0");

View File

@@ -1,149 +1,104 @@
// SPDX-License-Identifier: MIT
// Simple ConfigService with XDG/system/workspace merge and atomic writes
use anyhow::{Context, Result};
use directories::BaseDirs;
use serde::{Deserialize, Serialize};
use std::env;
use std::fs;
use std::io::Write;
use std::path::{Path, PathBuf};
use std::path::PathBuf;
/// Generic configuration represented as a TOML table
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct Config(pub toml::value::Table);
impl Config {
/// Get a mutable reference to a top-level table under the given key, creating
/// an empty table if it does not exist yet.
pub fn get_table_mut(&mut self, key: &str) -> &mut toml::value::Table {
let needs_init = !matches!(self.0.get(key), Some(toml::Value::Table(_)));
if needs_init {
self.0.insert(key.to_string(), toml::Value::Table(Default::default()));
}
match self.0.get_mut(key) {
Some(toml::Value::Table(t)) => t,
_ => unreachable!(),
}
}
}
fn merge_tables(base: &mut toml::value::Table, overlay: &toml::value::Table) {
for (k, v) in overlay.iter() {
match (base.get_mut(k), v) {
(Some(toml::Value::Table(bsub)), toml::Value::Table(osub)) => {
merge_tables(bsub, osub);
}
_ => {
base.insert(k.clone(), v.clone());
}
}
}
}
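A small sketch of merge_tables behavior (illustrative; written as a test since the function is private to this module):

#[test]
fn merge_tables_sketch() {
    let mut base: toml::value::Table =
        toml::from_str("[ui]\ntheme = \"auto\"\n[models]\ndir = \"/a\"").unwrap();
    let overlay: toml::value::Table = toml::from_str("[ui]\ntheme = \"dark\"").unwrap();
    merge_tables(&mut base, &overlay);
    // Nested tables merge key-by-key; on leaf conflicts the overlay wins.
    assert_eq!(base["ui"]["theme"].as_str(), Some("dark"));
    assert_eq!(base["models"]["dir"].as_str(), Some("/a"));
}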
fn read_toml(path: &Path) -> Result<toml::value::Table> {
let s = fs::read_to_string(path).with_context(|| format!("Failed to read config: {}", path.display()))?;
let v: toml::Value = toml::from_str(&s).with_context(|| format!("Invalid TOML in {}", path.display()))?;
Ok(v.as_table().cloned().unwrap_or_default())
}
fn write_toml_atomic(path: &Path, tbl: &toml::value::Table) -> Result<()> {
if let Some(parent) = path.parent() {
fs::create_dir_all(parent).with_context(|| format!("Failed to create config dir: {}", parent.display()))?;
}
let tmp = path.with_extension("tmp");
let mut f = fs::File::create(&tmp).with_context(|| format!("Failed to create temp file: {}", tmp.display()))?;
let s = toml::to_string_pretty(&toml::Value::Table(tbl.clone()))?;
f.write_all(s.as_bytes())?;
if !s.ends_with('\n') { f.write_all(b"\n")?; }
drop(f);
fs::rename(&tmp, path).with_context(|| format!("Failed to atomically replace config: {}", path.display()))?;
Ok(())
}
fn system_config_path() -> PathBuf {
if cfg!(unix) { PathBuf::from("/etc").join("polyscribe").join("config.toml") } else { default_user_config_path() }
}
fn default_user_config_path() -> PathBuf {
if let Some(base) = BaseDirs::new() {
return PathBuf::from(base.config_dir()).join("polyscribe").join("config.toml");
}
PathBuf::from(".polyscribe").join("config.toml")
}
fn workspace_config_path() -> PathBuf {
PathBuf::from(".polyscribe").join("config.toml")
}
/// Service responsible for loading and saving PolyScribe configuration
#[derive(Debug, Default, Clone)]
pub struct ConfigService;
impl ConfigService {
/// Load configuration, merging system < user < workspace < env overrides.
pub fn load(&self) -> Result<Config> {
let mut accum = toml::value::Table::default();
let sys = system_config_path();
if sys.exists() {
merge_tables(&mut accum, &read_toml(&sys)?);
}
let user = default_user_config_path();
if user.exists() {
merge_tables(&mut accum, &read_toml(&user)?);
}
let ws = workspace_config_path();
if ws.exists() {
merge_tables(&mut accum, &read_toml(&ws)?);
}
// Env overrides: POLYSCRIBE__SECTION__KEY=value
let mut env_over = toml::value::Table::default();
for (k, v) in env::vars() {
if let Some(rest) = k.strip_prefix("POLYSCRIBE__") {
let parts: Vec<&str> = rest.split("__").collect();
if parts.is_empty() { continue; }
let val: toml::Value = toml::Value::String(v);
// Build nested tables
let mut current = &mut env_over;
for (i, part) in parts.iter().enumerate() {
if i == parts.len() - 1 {
current.insert(part.to_lowercase(), val.clone());
} else {
current = current.entry(part.to_lowercase()).or_insert_with(|| toml::Value::Table(Default::default()))
.as_table_mut().expect("table");
}
}
pub const ENV_NO_CACHE_MANIFEST: &'static str = "POLYSCRIBE_NO_CACHE_MANIFEST";
pub const ENV_MANIFEST_TTL_SECONDS: &'static str = "POLYSCRIBE_MANIFEST_TTL_SECONDS";
pub const ENV_MODELS_DIR: &'static str = "POLYSCRIBE_MODELS_DIR";
pub const ENV_USER_AGENT: &'static str = "POLYSCRIBE_USER_AGENT";
pub const ENV_HTTP_TIMEOUT_SECS: &'static str = "POLYSCRIBE_HTTP_TIMEOUT_SECS";
pub const ENV_HF_REPO: &'static str = "POLYSCRIBE_HF_REPO";
pub const ENV_CACHE_FILENAME: &'static str = "POLYSCRIBE_MANIFEST_CACHE_FILENAME";
pub const DEFAULT_USER_AGENT: &'static str = "polyscribe/0.1";
pub const DEFAULT_DOWNLOADER_UA: &'static str = "polyscribe-model-downloader/1";
pub const DEFAULT_HF_REPO: &'static str = "ggerganov/whisper.cpp";
pub const DEFAULT_CACHE_FILENAME: &'static str = "hf_manifest_whisper_cpp.json";
pub const DEFAULT_HTTP_TIMEOUT_SECS: u64 = 8;
pub const DEFAULT_MANIFEST_CACHE_TTL_SECONDS: u64 = 24 * 60 * 60;
pub fn project_dirs() -> Option<directories::ProjectDirs> {
directories::ProjectDirs::from("dev", "polyscribe", "polyscribe")
}
pub fn default_models_dir() -> Option<PathBuf> {
Self::project_dirs().map(|d| d.data_dir().join("models"))
}
pub fn default_plugins_dir() -> Option<PathBuf> {
Self::project_dirs().map(|d| d.data_dir().join("plugins"))
}
pub fn manifest_cache_dir() -> Option<PathBuf> {
Self::project_dirs().map(|d| d.cache_dir().join("manifest"))
}
pub fn bypass_manifest_cache() -> bool {
env::var(Self::ENV_NO_CACHE_MANIFEST).is_ok()
}
pub fn manifest_cache_ttl_seconds() -> u64 {
env::var(Self::ENV_MANIFEST_TTL_SECONDS)
.ok()
.and_then(|s| s.parse::<u64>().ok())
.unwrap_or(Self::DEFAULT_MANIFEST_CACHE_TTL_SECONDS)
}
pub fn manifest_cache_filename() -> String {
env::var(Self::ENV_CACHE_FILENAME)
.unwrap_or_else(|_| Self::DEFAULT_CACHE_FILENAME.to_string())
}
pub fn models_dir(cfg: Option<&Config>) -> Option<PathBuf> {
if let Ok(env_dir) = env::var(Self::ENV_MODELS_DIR) {
if !env_dir.is_empty() {
return Some(PathBuf::from(env_dir));
}
}
merge_tables(&mut accum, &env_over);
Ok(Config(accum))
}
/// Ensure user config exists with sensible defaults, return loaded config
pub fn ensure_user_config(&self) -> Result<Config> {
let path = default_user_config_path();
if !path.exists() {
let mut defaults = toml::value::Table::default();
defaults.insert("ui".into(), toml::Value::Table({
let mut t = toml::value::Table::default();
t.insert("theme".into(), toml::Value::String("auto".into()));
t
}));
write_toml_atomic(&path, &defaults)?;
if let Some(c) = cfg {
if let Some(dir) = c.models_dir.clone() {
return Some(dir);
}
}
self.load()
Self::default_models_dir()
}
/// Save to user config atomically, merging over existing user file.
pub fn save_user(&self, new_values: &toml::value::Table) -> Result<()> {
let path = default_user_config_path();
let mut base = if path.exists() { read_toml(&path)? } else { Default::default() };
merge_tables(&mut base, new_values);
write_toml_atomic(&path, &base)
pub fn user_agent() -> String {
env::var(Self::ENV_USER_AGENT).unwrap_or_else(|_| Self::DEFAULT_USER_AGENT.to_string())
}
/// Paths used for debugging/information
pub fn paths(&self) -> (PathBuf, PathBuf, PathBuf) {
(system_config_path(), default_user_config_path(), workspace_config_path())
pub fn downloader_user_agent() -> String {
env::var(Self::ENV_USER_AGENT).unwrap_or_else(|_| Self::DEFAULT_DOWNLOADER_UA.to_string())
}
pub fn http_timeout_secs() -> u64 {
env::var(Self::ENV_HTTP_TIMEOUT_SECS)
.ok()
.and_then(|s| s.parse::<u64>().ok())
.unwrap_or(Self::DEFAULT_HTTP_TIMEOUT_SECS)
}
pub fn hf_repo() -> String {
env::var(Self::ENV_HF_REPO).unwrap_or_else(|_| Self::DEFAULT_HF_REPO.to_string())
}
pub fn hf_api_base_for(repo: &str) -> String {
format!("https://huggingface.co/api/models/{}", repo)
}
pub fn manifest_cache_path() -> Option<PathBuf> {
let dir = Self::manifest_cache_dir()?;
Some(dir.join(Self::manifest_cache_filename()))
}
}
#[derive(Debug, Clone, Serialize, Deserialize, Default)]
pub struct Config {
pub models_dir: Option<PathBuf>,
pub plugins_dir: Option<PathBuf>,
}
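A sketch of the resolution order encoded above (illustrative; the environment variable beats the config value, which beats the XDG default):

#[test]
fn models_dir_resolution_sketch() {
    let cfg = Config { models_dir: Some(PathBuf::from("/opt/models")), plugins_dir: None };
    unsafe { std::env::set_var(ConfigService::ENV_MODELS_DIR, "/tmp/m") };
    assert_eq!(ConfigService::models_dir(Some(&cfg)), Some(PathBuf::from("/tmp/m")));
    unsafe { std::env::remove_var(ConfigService::ENV_MODELS_DIR) };
    assert_eq!(ConfigService::models_dir(Some(&cfg)), Some(PathBuf::from("/opt/models")));
    // With neither env var nor config value, default_models_dir() applies.
}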

View File

@@ -0,0 +1,31 @@
use thiserror::Error;
#[derive(Debug, Error)]
pub enum Error {
#[error("I/O error: {0}")]
Io(#[from] std::io::Error),
#[error("serde error: {0}")]
Serde(#[from] serde_json::Error),
#[error("toml error: {0}")]
Toml(#[from] toml::de::Error),
#[error("toml ser error: {0}")]
TomlSer(#[from] toml::ser::Error),
#[error("env var error: {0}")]
EnvVar(#[from] std::env::VarError),
#[error("http error: {0}")]
Http(#[from] reqwest::Error),
#[error("other: {0}")]
Other(String),
}
impl From<anyhow::Error> for Error {
fn from(e: anyhow::Error) -> Self {
Error::Other(e.to_string())
}
}
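With the #[from] conversions above, helpers can use ? across io and serde boundaries and still return the crate Error (illustrative; read_manifest is a hypothetical helper):

fn read_manifest(path: &std::path::Path) -> Result<serde_json::Value, Error> {
    let text = std::fs::read_to_string(path)?; // std::io::Error -> Error::Io
    let value = serde_json::from_str(&text)?;  // serde_json::Error -> Error::Serde
    Ok(value)
}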

View File

@@ -1,67 +1,59 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2025 <COPYRIGHT HOLDER>. All rights reserved.
#![forbid(elided_lifetimes_in_paths)]
#![forbid(unused_must_use)]
#![deny(missing_docs)]
#![warn(clippy::all)]
//! PolyScribe library: business logic and core types.
//!
//! This crate exposes the reusable parts of the PolyScribe CLI as a library.
//! The binary entry point (main.rs) remains a thin CLI wrapper.
use std::sync::atomic::{AtomicBool, AtomicU8, Ordering};
// Global runtime flags
use crate::prelude::*;
use anyhow::{Context, anyhow};
use chrono::Local;
use std::env;
use std::path::{Path, PathBuf};
use std::process::Command;
#[cfg(unix)]
use libc::{O_WRONLY, close, dup, dup2, open};
static QUIET: AtomicBool = AtomicBool::new(false);
static NO_INTERACTION: AtomicBool = AtomicBool::new(false);
static VERBOSE: AtomicU8 = AtomicU8::new(0);
static NO_PROGRESS: AtomicBool = AtomicBool::new(false);
/// Set quiet mode: when true, non-interactive logs should be suppressed.
pub fn set_quiet(enabled: bool) {
QUIET.store(enabled, Ordering::Relaxed);
}
/// Return current quiet mode state.
pub fn is_quiet() -> bool {
QUIET.load(Ordering::Relaxed)
}
/// Set non-interactive mode: when true, interactive prompts must be skipped.
pub fn set_no_interaction(enabled: bool) {
NO_INTERACTION.store(enabled, Ordering::Relaxed);
}
/// Return current non-interactive state.
pub fn is_no_interaction() -> bool {
NO_INTERACTION.load(Ordering::Relaxed)
}
/// Set verbose level (0 = normal, 1 = verbose, 2 = super-verbose)
pub fn set_verbose(level: u8) {
VERBOSE.store(level, Ordering::Relaxed);
}
/// Get current verbose level.
pub fn verbose_level() -> u8 {
VERBOSE.load(Ordering::Relaxed)
}
/// Disable interactive progress indicators (bars/spinners)
pub fn set_no_progress(enabled: bool) {
NO_PROGRESS.store(enabled, Ordering::Relaxed);
}
/// Return current no-progress state
pub fn is_no_progress() -> bool {
NO_PROGRESS.load(Ordering::Relaxed)
}
/// Check whether stdin is connected to a TTY. Used to avoid blocking prompts when not interactive.
pub fn stdin_is_tty() -> bool {
use std::io::IsTerminal as _;
std::io::stdin().is_terminal()
}
/// A guard that temporarily redirects stderr to /dev/null on Unix when quiet mode is active.
/// No-op on non-Unix or when quiet is disabled. Restores stderr on drop.
pub struct StderrSilencer {
#[cfg(unix)]
old_stderr_fd: i32,
@@ -71,7 +63,6 @@ pub struct StderrSilencer {
}
impl StderrSilencer {
/// Activate stderr silencing if quiet is set and on Unix; otherwise returns a no-op guard.
pub fn activate_if_quiet() -> Self {
if !is_quiet() {
return Self {
@@ -85,7 +76,6 @@ impl StderrSilencer {
Self::activate()
}
/// Activate stderr silencing unconditionally (used internally); no-op on non-Unix.
pub fn activate() -> Self {
#[cfg(unix)]
unsafe {
@@ -97,11 +87,10 @@ impl StderrSilencer {
devnull_fd: -1,
};
}
// Open /dev/null for writing
let devnull_cstr = std::ffi::CString::new("/dev/null").unwrap();
let devnull_fd = open(devnull_cstr.as_ptr(), O_WRONLY);
if devnull_fd < 0 {
close(old_fd);
let _ = close(old_fd);
return Self {
active: false,
old_stderr_fd: -1,
@@ -109,8 +98,8 @@ impl StderrSilencer {
};
}
if dup2(devnull_fd, 2) < 0 {
close(devnull_fd);
close(old_fd);
let _ = close(devnull_fd);
let _ = close(old_fd);
return Self {
active: false,
old_stderr_fd: -1,
@@ -120,7 +109,7 @@ impl StderrSilencer {
Self {
active: true,
old_stderr_fd: old_fd,
devnull_fd: devnull_fd,
devnull_fd,
}
}
#[cfg(not(unix))]
@@ -144,7 +133,6 @@ impl Drop for StderrSilencer {
}
}
/// Run the given closure with stderr temporarily silenced (Unix-only). Returns the closure result.
pub fn with_suppressed_stderr<F, T>(f: F) -> T
where
F: FnOnce() -> T,
@@ -155,13 +143,11 @@ where
result
}
/// Log an error line (always printed).
#[macro_export]
macro_rules! elog {
($($arg:tt)*) => {{ $crate::ui::error(format!($($arg)*)); }}
}
/// Log an informational line using the UI helper unless quiet mode is enabled.
#[macro_export]
macro_rules! ilog {
($($arg:tt)*) => {{
@@ -169,7 +155,6 @@ macro_rules! ilog {
}}
}
/// Log a debug/trace line when verbose level is at least the given level (u8).
#[macro_export]
macro_rules! dlog {
($lvl:expr, $($arg:tt)*) => {{
@@ -177,52 +162,28 @@ macro_rules! dlog {
}}
}
/// Backward-compatibility: map old qlog! to ilog!
#[macro_export]
macro_rules! qlog {
($($arg:tt)*) => {{ $crate::ilog!($($arg)*); }}
}
use anyhow::{Context, Result, anyhow};
use chrono::Local;
use std::env;
use std::fs::create_dir_all;
use std::path::{Path, PathBuf};
use std::process::Command;
#[cfg(unix)]
use libc::{O_WRONLY, close, dup, dup2, open};
/// Re-export backend module (GPU/CPU selection and transcription).
pub mod backend;
/// Re-export models module (model listing/downloading/updating).
pub mod models;
/// Configuration service (XDG + atomic writes)
pub mod config;
/// UI helpers
pub mod models;
pub mod error;
pub mod ui;
pub use error::Error;
pub mod prelude;
/// Transcript entry for a single segment.
#[derive(Debug, serde::Serialize, Clone)]
pub struct OutputEntry {
/// Sequential id in output ordering.
pub id: u64,
/// Speaker label associated with the segment.
pub speaker: String,
/// Start time in seconds.
pub start: f64,
/// End time in seconds.
pub end: f64,
/// Text content.
pub text: String,
}
/// Return a YYYY-MM-DD date prefix string for output file naming.
pub fn date_prefix() -> String {
Local::now().format("%Y-%m-%d").to_string()
}
/// Format a floating-point number of seconds as SRT timestamp (HH:MM:SS,mmm).
pub fn format_srt_time(seconds: f64) -> String {
let total_ms = (seconds * 1000.0).round() as i64;
let ms = total_ms % 1000;
@@ -233,7 +194,6 @@ pub fn format_srt_time(seconds: f64) -> String {
format!("{hour:02}:{min:02}:{sec:02},{ms:03}")
}
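A quick check of the timestamp math (illustrative sketch; assumes the elided lines perform the usual ms to h:m:s split):

#[test]
fn srt_time_sketch() {
    assert_eq!(format_srt_time(3661.25), "01:01:01,250");
    assert_eq!(format_srt_time(0.0), "00:00:00,000");
}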
/// Render a list of transcript entries to SRT format.
pub fn render_srt(entries: &[OutputEntry]) -> String {
let mut srt = String::new();
for (index, entry) in entries.iter().enumerate() {
@@ -254,7 +214,8 @@ pub fn render_srt(entries: &[OutputEntry]) -> String {
srt
}
/// Determine the default models directory, honoring POLYSCRIBE_MODELS_DIR override.
pub mod model_manager;
pub fn models_dir_path() -> PathBuf {
if let Ok(env_val) = env::var("POLYSCRIBE_MODELS_DIR") {
let env_path = PathBuf::from(env_val);
@@ -265,24 +226,23 @@ pub fn models_dir_path() -> PathBuf {
if cfg!(debug_assertions) {
return PathBuf::from("models");
}
if let Ok(xdg) = env::var("XDG_DATA_HOME") {
if !xdg.is_empty() {
return PathBuf::from(xdg).join("polyscribe").join("models");
}
if let Ok(xdg) = env::var("XDG_DATA_HOME")
&& !xdg.is_empty()
{
return PathBuf::from(xdg).join("polyscribe").join("models");
}
if let Ok(home) = env::var("HOME") {
if !home.is_empty() {
return PathBuf::from(home)
.join(".local")
.join("share")
.join("polyscribe")
.join("models");
}
if let Ok(home) = env::var("HOME")
&& !home.is_empty()
{
return PathBuf::from(home)
.join(".local")
.join("share")
.join("polyscribe")
.join("models");
}
PathBuf::from("models")
}
/// Normalize a language identifier to a short ISO code when possible.
pub fn normalize_lang_code(input: &str) -> Option<String> {
let mut lang = input.trim().to_lowercase();
if lang.is_empty() || lang == "auto" || lang == "c" || lang == "posix" {
@@ -354,72 +314,92 @@ pub fn normalize_lang_code(input: &str) -> Option<String> {
Some(code.to_string())
}
/// Find the Whisper model file path to use.
pub fn find_model_file() -> Result<PathBuf> {
if let Ok(path) = env::var("WHISPER_MODEL") {
let p = PathBuf::from(path);
if p.exists() {
return Ok(p);
} else {
if !p.exists() {
return Err(anyhow!(
"WHISPER_MODEL points to non-existing file: {}",
"WHISPER_MODEL points to a non-existing path: {}",
p.display()
));
)
.into());
}
}
let models_dir = models_dir_path();
if !models_dir.exists() {
create_dir_all(&models_dir).with_context(|| {
format!("Failed to create models dir: {}", models_dir.display())
})?;
if !p.is_file() {
return Err(anyhow!(
"WHISPER_MODEL must point to a file, but is not: {}",
p.display()
)
.into());
}
return Ok(p);
}
// Heuristic: prefer the largest model file available in the models directory
let models_dir = models_dir_path();
if models_dir.exists() && !models_dir.is_dir() {
return Err(anyhow!(
"Models path exists but is not a directory: {}",
models_dir.display()
)
.into());
}
std::fs::create_dir_all(&models_dir).with_context(|| {
format!(
"Failed to ensure models dir exists: {}",
models_dir.display()
)
})?;
let mut candidates = Vec::new();
for entry in std::fs::read_dir(&models_dir).with_context(|| format!(
"Failed to read models dir: {}",
models_dir.display()
))? {
for entry in std::fs::read_dir(&models_dir)
.with_context(|| format!("Failed to read models dir: {}", models_dir.display()))?
{
let entry = entry?;
let path = entry.path();
if !path
let is_bin = path
.extension()
.and_then(|s| s.to_str())
.is_some_and(|s| s.eq_ignore_ascii_case("bin"))
{
.is_some_and(|s| s.eq_ignore_ascii_case("bin"));
if !is_bin {
continue;
}
if let Ok(md) = std::fs::metadata(&path) {
candidates.push((md.len(), path));
}
let md = match std::fs::metadata(&path) {
Ok(m) if m.is_file() => m,
_ => continue,
};
candidates.push((md.len(), path));
}
if candidates.is_empty() {
// Try default fallback (tiny.en)
let fallback = models_dir.join("ggml-tiny.en.bin");
if fallback.exists() {
if fallback.is_file() {
return Ok(fallback);
}
return Err(anyhow!(
"No Whisper models found in {}. Please download a model or set WHISPER_MODEL.",
"No Whisper model files (*.bin) found in {}. \
Please download a model or set WHISPER_MODEL.",
models_dir.display()
));
)
.into());
}
candidates.sort_by_key(|(size, _)| *size);
let (_size, path) = candidates.into_iter().last().unwrap();
let (_size, path) = candidates.into_iter().last().expect("non-empty");
Ok(path)
}
/// Decode an audio file into PCM f32 samples using ffmpeg (ffmpeg executable required).
pub fn decode_audio_to_pcm_f32_ffmpeg(audio_path: &Path) -> Result<Vec<f32>> {
let in_path = audio_path
.to_str()
.ok_or_else(|| anyhow!("Audio path must be valid UTF-8: {}", audio_path.display()))?;
let tmp_wav = std::env::temp_dir().join("polyscribe_tmp_input.wav");
let tmp_wav_str = tmp_wav
.to_str()
.ok_or_else(|| anyhow!("Temp path not valid UTF-8: {}", tmp_wav.display()))?;
// ffmpeg -i input -f f32le -ac 1 -ar 16000 -y /tmp/tmp.raw
let tmp_raw = std::env::temp_dir().join("polyscribe_tmp_input.f32le");
let tmp_raw_str = tmp_raw
.to_str()
.ok_or_else(|| anyhow!("Temp path not valid UTF-8: {}", tmp_raw.display()))?;
let status = Command::new("ffmpeg")
.arg("-hide_banner")
.arg("-loglevel")
@@ -433,16 +413,25 @@ pub fn decode_audio_to_pcm_f32_ffmpeg(audio_path: &Path) -> Result<Vec<f32>> {
.arg("-ar")
.arg("16000")
.arg("-y")
.arg(&tmp_wav_str)
.arg(tmp_raw_str)
.status()
.with_context(|| format!("Failed to invoke ffmpeg to decode: {}", in_path))?;
if !status.success() {
return Err(anyhow!("ffmpeg exited with non-zero status when decoding {}", in_path));
return Err(anyhow!(
"ffmpeg exited with non-zero status when decoding {}",
in_path
)
.into());
}
let raw = std::fs::read(&tmp_wav).with_context(|| format!("Failed to read temp PCM file: {}", tmp_wav.display()))?;
// Interpret raw bytes as f32 little-endian
let raw = std::fs::read(&tmp_raw)
.with_context(|| format!("Failed to read temp PCM file: {}", tmp_raw.display()))?;
let _ = std::fs::remove_file(&tmp_raw);
if raw.len() % 4 != 0 {
return Err(anyhow!("Decoded PCM file length not multiple of 4: {}", raw.len()));
return Err(anyhow!("Decoded PCM file length not multiple of 4: {}", raw.len()).into());
}
let mut samples = Vec::with_capacity(raw.len() / 4);
for chunk in raw.chunks_exact(4) {

View File

@@ -0,0 +1,893 @@
// SPDX-License-Identifier: MIT
use crate::prelude::*;
use crate::ui::BytesProgress;
use anyhow::{anyhow, Context};
use chrono::{DateTime, Utc};
use reqwest::blocking::Client;
use reqwest::header::{
ACCEPT_RANGES, AUTHORIZATION, CONTENT_LENGTH, ETAG, IF_NONE_MATCH, LAST_MODIFIED, RANGE,
};
use serde::{Deserialize, Serialize};
use sha2::{Digest, Sha256};
use std::cmp::min;
use std::collections::BTreeMap;
use std::fs::{self, File, OpenOptions};
use std::io::{Read, Seek, SeekFrom, Write};
use std::path::{Path, PathBuf};
use std::sync::{mpsc, Arc, Mutex};
use std::thread;
use std::time::Duration;
const DEFAULT_CHUNK_SIZE: u64 = 8 * 1024 * 1024; // 8 MiB
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
pub struct ModelRecord {
pub alias: String,
pub repo: String,
pub file: String,
pub revision: Option<String>, // ETag or commit hash
pub sha256: Option<String>,
pub size_bytes: Option<u64>,
pub quant: Option<String>,
pub installed_at: Option<DateTime<Utc>>,
pub last_used: Option<DateTime<Utc>>,
}
#[derive(Debug, Clone, Serialize, Deserialize, Default)]
pub struct Manifest {
pub models: BTreeMap<String, ModelRecord>, // key = alias
}
#[derive(Debug, Clone)]
pub struct Settings {
pub concurrency: usize,
pub limit_rate: Option<u64>, // bytes/sec
pub chunk_size: u64,
}
impl Default for Settings {
fn default() -> Self {
Self {
concurrency: 4,
limit_rate: None,
chunk_size: DEFAULT_CHUNK_SIZE,
}
}
}
#[derive(Debug, Clone)]
pub struct Paths {
pub cache_dir: PathBuf, // $XDG_CACHE_HOME/polyscribe/models
pub config_path: PathBuf, // $XDG_CONFIG_HOME/polyscribe/models.json
}
impl Paths {
pub fn resolve() -> Result<Self> {
if let Ok(over) = std::env::var("POLYSCRIBE_CACHE_DIR") {
if !over.is_empty() {
let cache_dir = PathBuf::from(over).join("models");
let config_path = std::env::var("POLYSCRIBE_CONFIG_DIR")
.map(|p| PathBuf::from(p).join("models.json"))
.unwrap_or_else(|_| default_config_path());
return Ok(Self {
cache_dir,
config_path,
});
}
}
let cache_dir = default_cache_dir();
let config_path = default_config_path();
Ok(Self {
cache_dir,
config_path,
})
}
}
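An override sketch (illustrative): pointing POLYSCRIBE_CACHE_DIR at a scratch directory relocates the model cache, while the manifest path follows POLYSCRIBE_CONFIG_DIR independently:

#[test]
fn cache_dir_override_sketch() {
    unsafe {
        std::env::set_var("POLYSCRIBE_CACHE_DIR", "/tmp/ps-cache");
        std::env::set_var("POLYSCRIBE_CONFIG_DIR", "/tmp/ps-config");
    }
    let paths = Paths::resolve().unwrap();
    assert_eq!(paths.cache_dir, PathBuf::from("/tmp/ps-cache/models"));
    assert_eq!(paths.config_path, PathBuf::from("/tmp/ps-config/models.json"));
}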
fn default_cache_dir() -> PathBuf {
if let Ok(xdg) = std::env::var("XDG_CACHE_HOME") {
if !xdg.is_empty() {
return PathBuf::from(xdg).join("polyscribe").join("models");
}
}
if let Ok(home) = std::env::var("HOME") {
if !home.is_empty() {
return PathBuf::from(home)
.join(".cache")
.join("polyscribe")
.join("models");
}
}
PathBuf::from("models")
}
fn default_config_path() -> PathBuf {
if let Ok(xdg) = std::env::var("XDG_CONFIG_HOME") {
if !xdg.is_empty() {
return PathBuf::from(xdg).join("polyscribe").join("models.json");
}
}
if let Ok(home) = std::env::var("HOME") {
if !home.is_empty() {
return PathBuf::from(home)
.join(".config")
.join("polyscribe")
.join("models.json");
}
}
PathBuf::from("models.json")
}
pub trait HttpClient: Send + Sync {
fn head(&self, url: &str, etag: Option<&str>) -> Result<HeadMeta>;
fn get_range(&self, url: &str, start: u64, end_inclusive: u64) -> Result<Vec<u8>>;
fn get_whole_to(&self, url: &str, writer: &mut dyn Write) -> Result<()>;
fn get_from_to(&self, url: &str, start: u64, writer: &mut dyn Write) -> Result<()>;
}
#[derive(Debug, Clone)]
pub struct ReqwestClient {
client: Client,
token: Option<String>,
}
impl ReqwestClient {
pub fn new() -> Result<Self> {
let token = std::env::var("HF_TOKEN").ok().filter(|s| !s.is_empty());
let client = Client::builder()
.user_agent(crate::config::ConfigService::user_agent())
.build()?;
Ok(Self { client, token })
}
fn auth(&self, mut req: reqwest::blocking::RequestBuilder) -> reqwest::blocking::RequestBuilder {
if let Some(t) = &self.token {
req = req.header(AUTHORIZATION, format!("Bearer {}", t));
}
req
}
}
#[derive(Debug, Clone)]
pub struct HeadMeta {
pub len: Option<u64>,
pub etag: Option<String>,
pub last_modified: Option<String>,
pub accept_ranges: bool,
pub not_modified: bool,
pub status: u16,
}
impl HttpClient for ReqwestClient {
fn head(&self, url: &str, etag: Option<&str>) -> Result<HeadMeta> {
let mut req = self.client.head(url);
if let Some(e) = etag {
req = req.header(IF_NONE_MATCH, format!("\"{}\"", e));
}
let resp = self.auth(req).send()?;
let status = resp.status().as_u16();
if status == 304 {
return Ok(HeadMeta {
len: None,
etag: etag.map(|s| s.to_string()),
last_modified: None,
accept_ranges: true,
not_modified: true,
status,
});
}
let len = resp
.headers()
.get(CONTENT_LENGTH)
.and_then(|v| v.to_str().ok())
.and_then(|s| s.parse::<u64>().ok());
let etag = resp
.headers()
.get(ETAG)
.and_then(|v| v.to_str().ok())
.map(|s| s.trim_matches('"').to_string());
let last_modified = resp
.headers()
.get(LAST_MODIFIED)
.and_then(|v| v.to_str().ok())
.map(|s| s.to_string());
let accept_ranges = resp
.headers()
.get(ACCEPT_RANGES)
.and_then(|v| v.to_str().ok())
.map(|s| s.to_ascii_lowercase().contains("bytes"))
.unwrap_or(false);
Ok(HeadMeta {
len,
etag,
last_modified,
accept_ranges,
not_modified: false,
status,
})
}
fn get_range(&self, url: &str, start: u64, end_inclusive: u64) -> Result<Vec<u8>> {
let range_val = format!("bytes={}-{}", start, end_inclusive);
let resp = self
.auth(self.client.get(url))
.header(RANGE, range_val)
.send()?;
if !resp.status().is_success() && resp.status().as_u16() != 206 {
return Err(anyhow!("HTTP {} for ranged GET", resp.status()).into());
}
let mut buf = Vec::new();
let mut r = resp;
r.copy_to(&mut buf)?;
Ok(buf)
}
fn get_whole_to(&self, url: &str, writer: &mut dyn Write) -> Result<()> {
let resp = self.auth(self.client.get(url)).send()?;
if !resp.status().is_success() {
return Err(anyhow!("HTTP {} for GET", resp.status()).into());
}
let mut r = resp;
r.copy_to(writer)?;
Ok(())
}
fn get_from_to(&self, url: &str, start: u64, writer: &mut dyn Write) -> Result<()> {
let mut req = self.auth(self.client.get(url));
if start > 0 {
req = req.header(RANGE, format!("bytes={}-", start));
}
let resp = req.send()?;
if !resp.status().is_success() && resp.status().as_u16() != 206 {
return Err(anyhow!("HTTP {} for ranged GET from {}", resp.status(), start).into());
}
let mut r = resp;
r.copy_to(writer)?;
Ok(())
}
}
pub struct ModelManager<C: HttpClient = ReqwestClient> {
pub paths: Paths,
pub settings: Settings,
client: Arc<C>,
}
impl<C: HttpClient + 'static> ModelManager<C> {
pub fn new_with_client(client: C, settings: Settings) -> Result<Self> {
Ok(Self {
paths: Paths::resolve()?,
settings,
client: Arc::new(client),
})
}
pub fn new(settings: Settings) -> Result<Self>
where
C: Default,
{
Ok(Self {
paths: Paths::resolve()?,
settings,
client: Arc::new(C::default()),
})
}
fn load_manifest(&self) -> Result<Manifest> {
let p = &self.paths.config_path;
if !p.exists() {
return Ok(Manifest::default());
}
let file = File::open(p).with_context(|| format!("open manifest {}", p.display()))?;
let m: Manifest = serde_json::from_reader(file).context("parse manifest")?;
Ok(m)
}
fn save_manifest(&self, m: &Manifest) -> Result<()> {
let p = &self.paths.config_path;
if let Some(dir) = p.parent() {
fs::create_dir_all(dir)
.with_context(|| format!("create config dir {}", dir.display()))?;
}
let tmp = p.with_extension("json.tmp");
let f = OpenOptions::new()
.create(true)
.write(true)
.truncate(true)
.open(&tmp)?;
serde_json::to_writer_pretty(f, m).context("serialize manifest")?;
fs::rename(&tmp, p).with_context(|| format!("atomic rename {} -> {}", tmp.display(), p.display()))?;
Ok(())
}
fn model_path(&self, file: &str) -> PathBuf {
self.paths.cache_dir.join(file)
}
fn compute_sha256(path: &Path) -> Result<String> {
let mut f = File::open(path)?;
let mut hasher = Sha256::new();
let mut buf = [0u8; 64 * 1024];
loop {
let n = f.read(&mut buf)?;
if n == 0 {
break;
}
hasher.update(&buf[..n]);
}
Ok(format!("{:x}", hasher.finalize()))
}
pub fn ls(&self) -> Result<Vec<ModelRecord>> {
let m = self.load_manifest()?;
Ok(m.models.values().cloned().collect())
}
pub fn rm(&self, alias: &str) -> Result<bool> {
let mut m = self.load_manifest()?;
if let Some(rec) = m.models.remove(alias) {
let p = self.model_path(&rec.file);
let _ = fs::remove_file(&p);
self.save_manifest(&m)?;
return Ok(true);
}
Ok(false)
}
pub fn verify(&self, alias: &str) -> Result<bool> {
let m = self.load_manifest()?;
let Some(rec) = m.models.get(alias) else { return Ok(false) };
let p = self.model_path(&rec.file);
if !p.exists() { return Ok(false); }
if let Some(expected) = &rec.sha256 {
let actual = Self::compute_sha256(&p)?;
return Ok(&actual == expected);
}
Ok(true)
}
pub fn gc(&self) -> Result<(usize, usize)> {
// Remove files not referenced by manifest; also drop manifest entries whose file is missing
fs::create_dir_all(&self.paths.cache_dir).ok();
let mut m = self.load_manifest()?;
let mut referenced = BTreeMap::new();
for (alias, rec) in &m.models {
referenced.insert(rec.file.clone(), alias.clone());
}
let mut removed_files = 0usize;
if let Ok(rd) = fs::read_dir(&self.paths.cache_dir) {
for ent in rd.flatten() {
let p = ent.path();
if p.is_file() {
let fname = p.file_name().and_then(|s| s.to_str()).unwrap_or("");
if !referenced.contains_key(fname) {
let _ = fs::remove_file(&p);
removed_files += 1;
}
}
}
}
m.models.retain(|_, rec| self.model_path(&rec.file).exists());
let removed_entries = referenced
.keys()
.filter(|f| !self.model_path(f).exists())
.count();
self.save_manifest(&m)?;
Ok((removed_files, removed_entries))
}
pub fn add_or_update(
&self,
alias: &str,
repo: &str,
file: &str,
) -> Result<ModelRecord> {
fs::create_dir_all(&self.paths.cache_dir)
.with_context(|| format!("create cache dir {}", self.paths.cache_dir.display()))?;
let url = format!("https://huggingface.co/{}/resolve/main/{}?download=true", repo, file);
let mut manifest = self.load_manifest()?;
let prev = manifest.models.get(alias).cloned();
let prev_etag = prev.as_ref().and_then(|r| r.revision.clone());
// Fetch remote meta (size/hash) when available to verify the download
let (_api_size, api_sha) = hf_fetch_file_meta(repo, file).unwrap_or((None, None));
let head = self.client.head(&url, prev_etag.as_deref())?;
if head.not_modified {
// up-to-date; ensure record present and touch last_used
let mut rec = prev.ok_or_else(|| anyhow!("not installed yet but server says 304"))?;
rec.last_used = Some(Utc::now());
self.save_touch(&mut manifest, rec.clone())?;
return Ok(rec);
}
// Quick check: if HEAD failed (e.g., 404), report a helpful error before attempting download
if head.status >= 400 {
return Err(anyhow!(
"file not found or inaccessible: repo='{}' file='{}' (HTTP {})\nHint: run `polyscribe models search {} --query {}` to list available files",
repo,
file,
head.status,
repo,
file
).into());
}
let total_len = head.len.ok_or_else(|| anyhow!("missing content-length (HEAD)"))?;
let etag = head.etag.clone();
let dest_tmp = self.model_path(&format!("{}.part", file));
// If a previously cancelled download left a .part file, remove it rather than resuming, so stale partial data is never reused.
if dest_tmp.exists() { let _ = fs::remove_file(&dest_tmp); }
// Guard to ensure .part is cleaned up on errors
struct TempGuard { path: PathBuf, armed: bool }
impl TempGuard { fn disarm(&mut self) { self.armed = false; } }
impl Drop for TempGuard {
fn drop(&mut self) {
if self.armed {
let _ = fs::remove_file(&self.path);
}
}
}
let mut _tmp_guard = TempGuard { path: dest_tmp.clone(), armed: true };
let dest_final = self.model_path(file);
// Do not resume after cancellation; start fresh to avoid stale .part files
let start_from = 0u64;
// Open tmp for write
let f = OpenOptions::new().create(true).write(true).read(true).open(&dest_tmp)?;
f.set_len(total_len)?; // pre-allocate for random writes
let f = Arc::new(Mutex::new(f));
// Create progress bar
let mut progress = BytesProgress::start(total_len, &format!("Downloading {}", file), start_from);
// Create work chunks
let chunk_size = self.settings.chunk_size;
let mut chunks = Vec::new();
let mut pos = start_from;
while pos < total_len {
let end = min(total_len - 1, pos + chunk_size - 1);
chunks.push((pos, end));
pos = end + 1;
}
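// Worked example (illustrative): with total_len = 20 MiB and chunk_size = 8 MiB,
// the loop yields (0, 8Mi-1), (8Mi, 16Mi-1), (16Mi, 20Mi-1); the last chunk is short.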
// Attempt a concurrent ranged download; on failure, fall back to a streaming GET
let mut ranged_failed = false;
if head.accept_ranges && self.settings.concurrency > 1 {
let (work_tx, work_rx) = mpsc::channel::<(u64, u64)>();
let (prog_tx, prog_rx) = mpsc::channel::<u64>();
for ch in chunks {
work_tx.send(ch).unwrap();
}
drop(work_tx);
let rx = Arc::new(Mutex::new(work_rx));
let workers = self.settings.concurrency.max(1);
let mut handles = Vec::new();
for _ in 0..workers {
let rx = rx.clone();
let url = url.clone();
let client = self.client.clone();
let f = f.clone();
let limit = self.settings.limit_rate;
let prog_tx = prog_tx.clone();
let handle = thread::spawn(move || -> Result<()> {
loop {
let next = {
let guard = rx.lock().unwrap();
guard.recv().ok()
};
let Some((start, end)) = next else { break; };
let data = client.get_range(&url, start, end)?;
if let Some(max_bps) = limit {
let dur = Duration::from_secs_f64((data.len() as f64) / (max_bps as f64));
if dur > Duration::from_millis(1) {
thread::sleep(dur);
}
}
let mut guard = f.lock().unwrap();
guard.seek(SeekFrom::Start(start))?;
guard.write_all(&data)?;
let _ = prog_tx.send(data.len() as u64);
}
Ok(())
});
handles.push(handle);
}
drop(prog_tx);
for delta in prog_rx {
progress.inc(delta);
}
let mut ranged_err: Option<anyhow::Error> = None;
for h in handles {
match h.join() {
Ok(Ok(())) => {}
Ok(Err(e)) => { ranged_failed = true; if ranged_err.is_none() { ranged_err = Some(e.into()); } }
Err(_) => { ranged_failed = true; if ranged_err.is_none() { ranged_err = Some(anyhow!("worker panicked")); } }
}
}
} else {
ranged_failed = true;
}
if ranged_failed {
if let Some(e) = &ranged_err {
crate::dlog!(1, "ranged download failed: {e}; falling back to streaming GET");
}
// Restart progress if we are abandoning previous partial state
if start_from > 0 {
progress.stop("retrying as streaming");
progress = BytesProgress::start(total_len, &format!("Downloading {}", file), 0);
}
// Fall back to a streaming GET; try the URL with and without ?download=true
let mut try_urls = vec![url.clone()];
if let Some((base, _qs)) = url.split_once('?') {
try_urls.push(base.to_string());
} else {
try_urls.push(format!("{}?download=true", url));
}
// Fresh temp file for streaming
let mut wf = OpenOptions::new().create(true).write(true).truncate(true).open(&dest_tmp)?;
let mut ok = false;
let mut last_err: Option<anyhow::Error> = None;
// Counting writer to update progress inline
struct CountingWriter<'a, 'b> {
inner: &'a mut File,
progress: &'b mut BytesProgress,
}
impl<'a, 'b> Write for CountingWriter<'a, 'b> {
fn write(&mut self, buf: &[u8]) -> std::io::Result<usize> {
let n = self.inner.write(buf)?;
self.progress.inc(n as u64);
Ok(n)
}
fn flush(&mut self) -> std::io::Result<()> { self.inner.flush() }
}
let mut cw = CountingWriter { inner: &mut wf, progress: &mut progress };
for u in try_urls {
// For fallback, stream from scratch to ensure integrity
let res = self.client.get_whole_to(&u, &mut cw);
match res {
Ok(()) => { ok = true; break; }
Err(e) => { last_err = Some(e.into()); }
}
}
if !ok {
if let Some(e) = last_err { return Err(e.into()); }
return Err(anyhow!("download failed (ranged and streaming)").into());
}
}
progress.stop("download complete");
// Verify integrity
let sha = Self::compute_sha256(&dest_tmp)?;
if let Some(expected) = api_sha.as_ref() {
if &sha != expected {
return Err(anyhow!(
"sha256 mismatch (expected {}, got {})",
expected,
sha
).into());
}
} else if prev.as_ref().map(|r| r.file.eq(file)).unwrap_or(false) {
if let Some(expected) = prev.as_ref().and_then(|r| r.sha256.as_ref()) {
if &sha != expected {
return Err(anyhow!("sha256 mismatch").into());
}
}
}
// Atomic rename
fs::rename(&dest_tmp, &dest_final).with_context(|| format!("rename {} -> {}", dest_tmp.display(), dest_final.display()))?;
// Disarm guard; .part has been moved or cleaned
_tmp_guard.disarm();
let rec = ModelRecord {
alias: alias.to_string(),
repo: repo.to_string(),
file: file.to_string(),
revision: etag,
sha256: Some(sha.clone()),
size_bytes: Some(total_len),
quant: infer_quant(file),
installed_at: Some(Utc::now()),
last_used: Some(Utc::now()),
};
self.save_touch(&mut manifest, rec.clone())?;
Ok(rec)
}
fn save_touch(&self, manifest: &mut Manifest, rec: ModelRecord) -> Result<()> {
manifest.models.insert(rec.alias.clone(), rec);
self.save_manifest(manifest)
}
}
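An end-to-end sketch of the manager (illustrative; this performs a real download with the default ReqwestClient, and Result is the crate's prelude alias):

fn install_tiny_sketch() -> Result<()> {
    let mm: ModelManager = ModelManager::new(Settings::default())?;
    let rec = mm.add_or_update("tiny", "ggerganov/whisper.cpp", "ggml-tiny.bin")?;
    println!("installed {} ({} bytes)", rec.alias, rec.size_bytes.unwrap_or(0));
    assert!(mm.verify("tiny")?);
    Ok(())
}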
fn infer_quant(file: &str) -> Option<String> {
// Try to extract a quantization token (e.g. Q5_K_M) from the filename
let upper = file.to_ascii_uppercase();
if let Some(pos) = upper.find('Q') {
let tail = &upper[pos..];
let token: String = tail
.chars()
.take_while(|c| c.is_ascii_uppercase() || c.is_ascii_digit() || *c == '_' || *c == '-')
.collect();
if token.len() >= 2 {
return Some(token);
}
}
None
}
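Expected extractions for infer_quant (illustrative sketch):

#[test]
fn infer_quant_sketch() {
    assert_eq!(infer_quant("ggml-base.en-q5_1.bin").as_deref(), Some("Q5_1"));
    assert_eq!(infer_quant("gguf-tiny-q4_0.bin").as_deref(), Some("Q4_0"));
    assert_eq!(infer_quant("model.bin"), None);
}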
impl Default for ReqwestClient {
fn default() -> Self {
Self::new().expect("reqwest client")
}
}
// Hugging Face API types for file metadata
#[derive(Debug, Deserialize)]
struct ApiHfLfs {
oid: Option<String>,
size: Option<u64>,
sha256: Option<String>,
}
#[derive(Debug, Deserialize)]
struct ApiHfFile {
rfilename: String,
size: Option<u64>,
sha256: Option<String>,
lfs: Option<ApiHfLfs>,
}
#[derive(Debug, Deserialize)]
struct ApiHfModelInfo {
siblings: Option<Vec<ApiHfFile>>,
files: Option<Vec<ApiHfFile>>,
}
fn pick_sha_from_file(f: &ApiHfFile) -> Option<String> {
if let Some(s) = &f.sha256 { return Some(s.to_string()); }
if let Some(l) = &f.lfs {
if let Some(s) = &l.sha256 { return Some(s.to_string()); }
if let Some(oid) = &l.oid { return oid.strip_prefix("sha256:").map(|s| s.to_string()); }
}
None
}
fn hf_fetch_file_meta(repo: &str, target: &str) -> Result<(Option<u64>, Option<String>)> {
let token = std::env::var("HF_TOKEN").ok().filter(|s| !s.is_empty());
let client = Client::builder()
.user_agent(crate::config::ConfigService::user_agent())
.build()?;
let base = crate::config::ConfigService::hf_api_base_for(repo);
let urls = [base.clone(), format!("{}?expand=files", base)];
for url in urls {
let mut req = client.get(&url);
if let Some(t) = &token { req = req.header(AUTHORIZATION, format!("Bearer {}", t)); }
let resp = req.send()?;
if !resp.status().is_success() { continue; }
let info: ApiHfModelInfo = resp.json()?;
let list = info.files.or(info.siblings).unwrap_or_default();
for f in list {
let name = f.rfilename.rsplit('/').next().unwrap_or(&f.rfilename);
if name.eq_ignore_ascii_case(target) {
let sz = f.size.or_else(|| f.lfs.as_ref().and_then(|l| l.size));
let sha = pick_sha_from_file(&f);
return Ok((sz, sha));
}
}
}
Err(anyhow!("file not found in HF API").into())
}
/// Fetch remote metadata (size, sha256) for a single file in a HF repo.
pub fn fetch_file_meta(repo: &str, file: &str) -> Result<(Option<u64>, Option<String>)> {
hf_fetch_file_meta(repo, file)
}
/// Search a Hugging Face repo for GGUF/BIN files via API. Returns file names only.
pub fn search_repo(repo: &str, query: Option<&str>) -> Result<Vec<String>> {
let token = std::env::var("HF_TOKEN").ok().filter(|s| !s.is_empty());
let client = Client::builder()
.user_agent(crate::config::ConfigService::user_agent())
.build()?;
let base = crate::config::ConfigService::hf_api_base_for(repo);
let mut urls = vec![base.clone(), format!("{}?expand=files", base)];
let mut files = Vec::<String>::new();
for url in urls.drain(..) {
let mut req = client.get(&url);
if let Some(t) = &token { req = req.header(AUTHORIZATION, format!("Bearer {}", t)); }
let resp = req.send()?;
if !resp.status().is_success() { continue; }
let info: ApiHfModelInfo = resp.json()?;
let list = info.files.or(info.siblings).unwrap_or_default();
for f in list {
if f.rfilename.ends_with(".gguf") || f.rfilename.ends_with(".bin") {
let name = f.rfilename.rsplit('/').next().unwrap_or(&f.rfilename).to_string();
if !files.contains(&name) { files.push(name); }
}
}
if !files.is_empty() { break; }
}
if let Some(q) = query { let qlc = q.to_ascii_lowercase(); files.retain(|f| f.to_ascii_lowercase().contains(&qlc)); }
files.sort();
Ok(files)
}
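An illustrative call (network-dependent; the repo is the crate default and the printed names are examples, not guaranteed):

fn search_sketch() -> Result<()> {
    let files = search_repo("ggerganov/whisper.cpp", Some("tiny"))?;
    for name in &files {
        println!("{name}"); // e.g. ggml-tiny.bin, ggml-tiny-q5_1.bin
    }
    Ok(())
}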
/// List repo files with optional size metadata for GGUF/BIN entries.
pub fn list_repo_files_with_meta(repo: &str) -> Result<Vec<(String, Option<u64>)>> {
let token = std::env::var("HF_TOKEN").ok().filter(|s| !s.is_empty());
let client = Client::builder()
.user_agent(crate::config::ConfigService::user_agent())
.build()?;
let base = crate::config::ConfigService::hf_api_base_for(repo);
for url in [base.clone(), format!("{}?expand=files", base)] {
let mut req = client.get(&url);
if let Some(t) = &token { req = req.header(AUTHORIZATION, format!("Bearer {}", t)); }
let resp = req.send()?;
if !resp.status().is_success() { continue; }
let info: ApiHfModelInfo = resp.json()?;
let list = info.files.or(info.siblings).unwrap_or_default();
let mut out = Vec::new();
for f in list {
if !(f.rfilename.ends_with(".gguf") || f.rfilename.ends_with(".bin")) { continue; }
let name = f.rfilename.rsplit('/').next().unwrap_or(&f.rfilename).to_string();
let size = f.size.or_else(|| f.lfs.as_ref().and_then(|l| l.size));
out.push((name, size));
}
if !out.is_empty() { return Ok(out); }
}
Ok(Vec::new())
}
/// Fallback: HEAD request for a single file to retrieve Content-Length (size).
pub fn head_len_for_file(repo: &str, file: &str) -> Option<u64> {
let token = std::env::var("HF_TOKEN").ok().filter(|s| !s.is_empty());
let client = Client::builder()
.user_agent(crate::config::ConfigService::user_agent())
.build().ok()?;
let mut urls = Vec::new();
urls.push(format!("https://huggingface.co/{}/resolve/main/{}?download=true", repo, file));
urls.push(format!("https://huggingface.co/{}/resolve/main/{}", repo, file));
for url in urls {
let mut req = client.head(&url);
if let Some(t) = &token { req = req.header(AUTHORIZATION, format!("Bearer {}", t)); }
if let Ok(resp) = req.send() {
if resp.status().is_success() {
if let Some(len) = resp.headers().get(CONTENT_LENGTH)
.and_then(|v| v.to_str().ok())
.and_then(|s| s.parse::<u64>().ok())
{ return Some(len); }
}
}
}
None
}
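// Hedged sketch (assumed wiring, not part of the diff): prefer API metadata,
// fall back to the HEAD probe above when the API reports no size.
#[allow(dead_code)]
fn remote_size_for(repo: &str, file: &str) -> Option<u64> {
    match fetch_file_meta(repo, file) {
        Ok((Some(size), _sha)) => Some(size),
        _ => head_len_for_file(repo, file),
    }
}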
#[cfg(test)]
mod tests {
use super::*;
use std::env;
use tempfile::TempDir;
#[derive(Clone)]
struct StubHttp {
data: Arc<Vec<u8>>,
etag: Arc<Option<String>>,
accept_ranges: bool,
}
impl HttpClient for StubHttp {
fn head(&self, _url: &str, etag: Option<&str>) -> Result<HeadMeta> {
let not_modified = etag.is_some() && self.etag.as_ref().as_deref() == etag;
Ok(HeadMeta {
len: Some(self.data.len() as u64),
etag: self.etag.as_ref().clone(),
last_modified: None,
accept_ranges: self.accept_ranges,
not_modified,
status: if not_modified { 304 } else { 200 },
})
}
fn get_range(&self, _url: &str, start: u64, end_inclusive: u64) -> Result<Vec<u8>> {
let s = start as usize;
let e = (end_inclusive as usize) + 1;
Ok(self.data[s..e].to_vec())
}
fn get_whole_to(&self, _url: &str, writer: &mut dyn Write) -> Result<()> {
writer.write_all(&self.data)?;
Ok(())
}
fn get_from_to(&self, _url: &str, start: u64, writer: &mut dyn Write) -> Result<()> {
let s = start as usize;
writer.write_all(&self.data[s..])?;
Ok(())
}
}
fn setup_env(cache: &Path, cfg: &Path) {
unsafe {
env::set_var("POLYSCRIBE_CACHE_DIR", cache.to_string_lossy().to_string());
env::set_var(
"POLYSCRIBE_CONFIG_DIR",
cfg.parent().unwrap().to_string_lossy().to_string(),
);
}
}
#[test]
fn test_manifest_roundtrip() {
let temp = TempDir::new().unwrap();
let cache = temp.path().join("cache");
let cfg = temp.path().join("config").join("models.json");
setup_env(&cache, &cfg);
let client = StubHttp {
data: Arc::new(vec![0u8; 1024]),
etag: Arc::new(Some("etag123".into())),
accept_ranges: true,
};
let mm: ModelManager<StubHttp> = ModelManager::new_with_client(client, Settings::default()).unwrap();
let m = mm.load_manifest().unwrap();
assert!(m.models.is_empty());
let rec = ModelRecord {
alias: "tiny".into(),
repo: "foo/bar".into(),
file: "gguf-tiny.bin".into(),
revision: Some("etag123".into()),
sha256: None,
size_bytes: None,
quant: None,
installed_at: None,
last_used: None,
};
let mut m2 = Manifest::default();
mm.save_touch(&mut m2, rec.clone()).unwrap();
let m3 = mm.load_manifest().unwrap();
assert!(m3.models.contains_key("tiny"));
}
#[test]
fn test_add_verify_update_gc() {
let temp = TempDir::new().unwrap();
let cache = temp.path().join("cache");
let cfg_dir = temp.path().join("config");
let cfg = cfg_dir.join("models.json");
setup_env(&cache, &cfg);
let data = (0..1024 * 1024u32).flat_map(|i| i.to_le_bytes()).collect::<Vec<u8>>();
let etag = Some("abc123".to_string());
let client = StubHttp { data: Arc::new(data), etag: Arc::new(etag), accept_ranges: true };
let mm: ModelManager<StubHttp> = ModelManager::new_with_client(client.clone(), Settings{ concurrency: 3, ..Default::default() }).unwrap();
// add
let rec = mm.add_or_update("tiny", "gguf/models", "gguf-tiny-q4_0.bin").unwrap();
assert_eq!(rec.alias, "tiny");
assert!(mm.verify("tiny").unwrap());
// update (304)
let rec2 = mm.add_or_update("tiny", "gguf/models", "gguf-tiny-q4_0.bin").unwrap();
assert_eq!(rec2.alias, "tiny");
// gc (nothing to remove)
let (files_removed, entries_removed) = mm.gc().unwrap();
assert_eq!(files_removed, 0);
assert_eq!(entries_removed, 0);
// rm
assert!(mm.rm("tiny").unwrap());
assert!(!mm.rm("tiny").unwrap());
}
}

File diff suppressed because it is too large

View File

@@ -0,0 +1,7 @@
pub use crate::backend::*;
pub use crate::config::*;
pub use crate::error::Error;
pub use crate::models::*;
pub use crate::ui::*;
pub type Result<T, E = Error> = std::result::Result<T, E>;
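// Hedged usage sketch (assumes this module is exposed as the crate prelude):
//
// use polyscribe::prelude::*;
//
// fn run() -> Result<()> {
//     // Error and the Result alias come from the re-exports above.
//     Ok(())
// }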

View File

@@ -1,87 +1,329 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2025 <COPYRIGHT HOLDER>. All rights reserved.
//! Centralized UI helpers (TTY-aware, quiet/verbose-aware)
pub mod progress;
use std::io;
use std::io::IsTerminal;
use std::io::Write as _;
use std::time::{Duration, Instant};
/// Startup intro/banner (suppressed when quiet).
pub fn intro(msg: impl AsRef<str>) {
    let _ = cliclack::intro(msg.as_ref());
}
pub fn info(msg: impl AsRef<str>) {
    let m = msg.as_ref();
    let _ = cliclack::log::info(m);
}
pub fn warn(msg: impl AsRef<str>) {
let m = msg.as_ref();
let _ = cliclack::log::warning(m);
}
pub fn error(msg: impl AsRef<str>) {
let m = msg.as_ref();
let _ = cliclack::log::error(m);
}
pub fn success(msg: impl AsRef<str>) {
let m = msg.as_ref();
let _ = cliclack::log::success(m);
}
pub fn note(prompt: impl AsRef<str>, message: impl AsRef<str>) {
let _ = cliclack::note(prompt.as_ref(), message.as_ref());
}
pub fn intro(title: impl AsRef<str>) {
let _ = cliclack::intro(title.as_ref());
}
/// Final outro/summary printed below any progress indicators (suppressed when quiet).
pub fn outro(msg: impl AsRef<str>) {
let _ = cliclack::outro(msg.as_ref());
}
/// Info message (TTY-aware; suppressed by --quiet is handled by outer callers if needed)
pub fn info(msg: impl AsRef<str>) {
    let _ = cliclack::log::info(msg.as_ref());
}
pub fn println_above_bars(line: impl AsRef<str>) {
    let _ = cliclack::log::info(line.as_ref());
}
/// Print a warning (always printed).
pub fn warn(msg: impl AsRef<str>) {
// cliclack provides a warning-level log utility
let _ = cliclack::log::warning(msg.as_ref());
}
/// Print an error (always printed).
pub fn error(msg: impl AsRef<str>) {
let _ = cliclack::log::error(msg.as_ref());
}
/// Print a line above any progress bars (maps to cliclack log; synchronized).
pub fn println_above_bars(msg: impl AsRef<str>) {
if crate::is_quiet() { return; }
// cliclack logs are synchronized with its spinners/bars
let _ = cliclack::log::info(msg.as_ref());
}
/// Input prompt with a question: returns Ok(None) if non-interactive or canceled
pub fn prompt_input(question: impl AsRef<str>, default: Option<&str>) -> anyhow::Result<Option<String>> {
    if crate::is_no_interaction() || !crate::stdin_is_tty() {
        return Ok(None);
    }
    let mut p = cliclack::input(question.as_ref());
    if let Some(d) = default {
        // Use default_input when available in 0.3.x
        p = p.default_input(d);
    }
    match p.interact() {
        Ok(s) => Ok(Some(s)),
        Err(_) => Ok(None),
    }
}
pub fn prompt_input(prompt: &str, default: Option<&str>) -> io::Result<String> {
    if crate::is_no_interaction() || !crate::stdin_is_tty() {
        return Ok(default.unwrap_or("").to_string());
    }
    let mut q = cliclack::input(prompt);
    if let Some(def) = default {
        q = q.default_input(def);
    }
    q.interact().map_err(|e| io::Error::other(e.to_string()))
}
/// Confirmation prompt; returns Ok(None) if non-interactive or canceled
pub fn prompt_confirm(question: impl AsRef<str>, default_yes: bool) -> anyhow::Result<Option<bool>> {
    if crate::is_no_interaction() || !crate::stdin_is_tty() {
        return Ok(None);
    }
    let res = cliclack::confirm(question.as_ref())
        .initial_value(default_yes)
        .interact();
    match res {
        Ok(v) => Ok(Some(v)),
        Err(_) => Ok(None),
    }
}
pub fn prompt_select(prompt: &str, items: &[&str]) -> io::Result<usize> {
    if crate::is_no_interaction() || !crate::stdin_is_tty() {
        return Err(io::Error::other("interactive prompt disabled"));
    }
    let mut sel = cliclack::select::<usize>(prompt);
    for (idx, label) in items.iter().enumerate() {
        sel = sel.item(idx, *label, "");
    }
    sel.interact().map_err(|e| io::Error::other(e.to_string()))
}
pub fn prompt_multi_select(
prompt: &str,
items: &[&str],
defaults: Option<&[bool]>,
) -> io::Result<Vec<usize>> {
if crate::is_no_interaction() || !crate::stdin_is_tty() {
return Err(io::Error::other("interactive prompt disabled"));
}
let mut ms = cliclack::multiselect::<usize>(prompt);
for (idx, label) in items.iter().enumerate() {
ms = ms.item(idx, *label, "");
}
if let Some(def) = defaults {
let selected: Vec<usize> = def
.iter()
.enumerate()
.filter_map(|(i, &on)| if on { Some(i) } else { None })
.collect();
if !selected.is_empty() {
ms = ms.initial_values(selected);
}
}
ms.interact().map_err(|e| io::Error::other(e.to_string()))
}
pub fn prompt_confirm(prompt: &str, default: bool) -> io::Result<bool> {
if crate::is_no_interaction() || !crate::stdin_is_tty() {
return Ok(default);
}
let mut q = cliclack::confirm(prompt);
q.interact().map_err(|e| io::Error::other(e.to_string()))
}
pub fn prompt_password(prompt: &str) -> io::Result<String> {
if crate::is_no_interaction() || !crate::stdin_is_tty() {
return Err(io::Error::other(
"password prompt disabled in non-interactive mode",
));
}
let mut q = cliclack::password(prompt);
q.interact().map_err(|e| io::Error::other(e.to_string()))
}
pub fn prompt_input_validated<F>(
prompt: &str,
default: Option<&str>,
validate: F,
) -> io::Result<String>
where
F: Fn(&str) -> Result<(), String> + 'static,
{
if crate::is_no_interaction() || !crate::stdin_is_tty() {
if let Some(def) = default {
return Ok(def.to_string());
}
return Err(io::Error::other("interactive prompt disabled"));
}
let mut q = cliclack::input(prompt);
if let Some(def) = default {
q = q.default_input(def);
}
q.validate(move |s: &String| validate(s))
.interact()
.map_err(|e| io::Error::other(e.to_string()))
}
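// Hedged example (prompt text and default are illustrative): validate a port
// number, falling back to the default when prompts are disabled.
#[allow(dead_code)]
fn ask_port() -> io::Result<String> {
    prompt_input_validated("Port", Some("8080"), |s| {
        s.parse::<u16>().map(|_| ()).map_err(|_| "enter a number 0-65535".to_string())
    })
}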
pub struct Spinner(cliclack::ProgressBar);
impl Spinner {
pub fn start(text: impl AsRef<str>) -> Self {
if crate::is_no_progress() || crate::is_no_interaction() || !std::io::stderr().is_terminal()
{
let _ = cliclack::log::info(text.as_ref());
let s = cliclack::spinner();
Self(s)
} else {
let s = cliclack::spinner();
s.start(text.as_ref());
Self(s)
}
}
pub fn stop(self, text: impl AsRef<str>) {
let s = self.0;
if crate::is_no_progress() {
let _ = cliclack::log::info(text.as_ref());
} else {
s.stop(text.as_ref());
}
}
pub fn success(self, text: impl AsRef<str>) {
let s = self.0;
if crate::is_no_progress() {
let _ = cliclack::log::success(text.as_ref());
} else {
s.stop(text.as_ref());
}
}
pub fn error(self, text: impl AsRef<str>) {
let s = self.0;
if crate::is_no_progress() {
let _ = cliclack::log::error(text.as_ref());
} else {
s.error(text.as_ref());
}
}
}
/// Prompt the user (TTY-aware via cliclack) and read a line from stdin.
/// Returns the raw line as read (callers typically trim it).
pub fn prompt_line(prompt: &str) -> io::Result<String> {
    // Route prompt through cliclack to keep consistent styling and avoid direct eprint!/println!
    let _ = cliclack::log::info(prompt);
    let mut s = String::new();
    io::stdin().read_line(&mut s)?;
    Ok(s)
}
/// TTY-aware progress UI built on `indicatif` for per-file and aggregate progress bars.
///
/// This small helper encapsulates a `MultiProgress` with one aggregate (total) bar and
/// one per-file bar. It is intentionally minimal to keep integration lightweight.
pub mod progress {
    // The submodule is defined in a separate file for clarity.
    include!("ui/progress.rs");
}
pub struct BytesProgress {
    enabled: bool,
    total: u64,
    current: u64,
    started: Instant,
    last_msg: Instant,
    width: usize,
    // Sticky ETA to carry through zero-speed stalls
    last_eta_secs: Option<f64>,
}
impl BytesProgress {
pub fn start(total: u64, text: &str, initial: u64) -> Self {
let enabled = !(crate::is_no_progress()
|| crate::is_no_interaction()
|| !std::io::stderr().is_terminal()
|| total == 0);
if !enabled {
let _ = cliclack::log::info(text);
}
let mut me = Self {
enabled,
total,
current: initial.min(total),
started: Instant::now(),
last_msg: Instant::now(),
width: 40,
last_eta_secs: None,
};
me.draw();
me
}
fn human_bytes(n: u64) -> String {
const KB: f64 = 1024.0;
const MB: f64 = 1024.0 * KB;
const GB: f64 = 1024.0 * MB;
let x = n as f64;
if x >= GB {
format!("{:.2} GiB", x / GB)
} else if x >= MB {
format!("{:.2} MiB", x / MB)
} else if x >= KB {
format!("{:.2} KiB", x / KB)
} else {
format!("{} B", n)
}
}
// Elapsed formatting is used for stable, finite durations. For ETA, we guard
// against zero-speed or unstable estimates separately via `format_eta`.
fn refresh_allowed(&mut self) -> (f64, f64) {
let now = Instant::now();
let since_last = now.duration_since(self.last_msg);
if since_last < Duration::from_millis(100) {
// Too soon to refresh; keep previous ETA if any
let eta = self.last_eta_secs.unwrap_or(f64::INFINITY);
return (0.0, eta);
}
self.last_msg = now;
let elapsed = now.duration_since(self.started).as_secs_f64().max(0.001);
let speed = (self.current as f64) / elapsed;
let remaining = self.total.saturating_sub(self.current) as f64;
// If speed is effectively zero, carry ETA forward and add wall time.
const EPS: f64 = 1e-6;
let eta = if speed <= EPS {
let prev = self.last_eta_secs.unwrap_or(f64::INFINITY);
if prev.is_finite() {
prev + since_last.as_secs_f64()
} else {
prev
}
} else {
remaining / speed
};
// Remember only finite ETAs to use during stalls
if eta.is_finite() {
self.last_eta_secs = Some(eta);
}
(speed, eta)
}
fn format_elapsed(seconds: f64) -> String {
let total = seconds.round() as u64;
let h = total / 3600;
let m = (total % 3600) / 60;
let s = total % 60;
if h > 0 { format!("{:02}:{:02}:{:02}", h, m, s) } else { format!("{:02}:{:02}", m, s) }
}
fn format_eta(seconds: f64) -> String {
// If ETA is not finite (e.g., divide-by-zero speed) or unreasonably large,
// show a placeholder rather than overflowing into huge values.
if !seconds.is_finite() {
return "".to_string();
}
// Cap ETA display to 99:59:59 to avoid silly numbers; beyond that, show placeholder.
const CAP_SECS: f64 = 99.0 * 3600.0 + 59.0 * 60.0 + 59.0;
if seconds > CAP_SECS {
return "".to_string();
}
Self::format_elapsed(seconds)
}
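    // Worked example (illustrative): 500 MiB remaining at 2 MiB/s yields 250 s,
    // shown as 04:10. If speed stalls to zero, the last finite ETA is carried
    // forward and grows with wall time instead of jumping to infinity.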
fn draw(&mut self) {
if !self.enabled { return; }
let (speed, eta) = self.refresh_allowed();
let elapsed = Instant::now().duration_since(self.started).as_secs_f64();
// Build bar
let width = self.width.max(10);
let filled = ((self.current as f64 / self.total.max(1) as f64) * width as f64).round() as usize;
let filled = filled.min(width);
let mut bar = String::with_capacity(width);
for _ in 0..filled { bar.push('■'); }
for _ in filled..width { bar.push('□'); }
let line = format!(
"[{}] {} [{}] ({}/{} at {}/s)",
Self::format_elapsed(elapsed),
bar,
Self::format_eta(eta),
Self::human_bytes(self.current),
Self::human_bytes(self.total),
Self::human_bytes(speed.max(0.0) as u64),
);
eprint!("\r{}\x1b[K", line);
let _ = io::stderr().flush();
}
pub fn inc(&mut self, delta: u64) {
self.current = self.current.saturating_add(delta).min(self.total);
self.draw();
}
pub fn stop(mut self, text: &str) {
if self.enabled {
self.draw();
eprintln!();
} else {
let _ = cliclack::log::info(text);
}
}
pub fn error(mut self, text: &str) {
if self.enabled {
self.draw();
eprintln!();
let _ = cliclack::log::error(text);
} else {
let _ = cliclack::log::error(text);
}
}
}
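// Hedged usage sketch (chunk source assumed): drive the bar from a download loop.
#[allow(dead_code)]
fn bytes_progress_demo(total: u64, chunks: impl Iterator<Item = Vec<u8>>) {
    let mut bp = BytesProgress::start(total, "downloading model", 0);
    for chunk in chunks {
        bp.inc(chunk.len() as u64);
    }
    bp.stop("download complete");
}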

View File

@@ -1,81 +1,122 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2025 <COPYRIGHT HOLDER>. All rights reserved.
use indicatif::{MultiProgress, ProgressBar, ProgressStyle};
use std::io::IsTerminal as _;
/// Manages a set of per-file progress bars plus a top aggregate bar.
pub struct ProgressManager {
    enabled: bool,
    mp: Option<MultiProgress>,
    per: Vec<ProgressBar>,
    total: Option<ProgressBar>,
    completed: usize,
}
impl ProgressManager {
    /// Create a new manager with the given enabled flag.
    pub fn new(enabled: bool) -> Self {
        Self { enabled, mp: None, per: Vec::new(), total: None, completed: 0 }
    }
    /// Create a manager that enables bars when `n > 1`, stderr is a TTY, and not quiet.
    pub fn default_for_files(n: usize) -> Self {
        let enabled = n > 1 && std::io::stderr().is_terminal() && !crate::is_quiet() && !crate::is_no_progress();
        Self::new(enabled)
    }
    /// Initialize bars for the given file labels. If disabled or single file, no-op.
    pub fn init_files(&mut self, labels: &[String]) {
        if !self.enabled || labels.len() <= 1 {
            // No bars in single-file mode or when disabled
            self.enabled = false;
            return;
        }
        let mp = MultiProgress::new();
        // Aggregate bar at the top
        let total = mp.add(ProgressBar::new(labels.len() as u64));
        total.set_style(ProgressStyle::with_template("{prefix} [{bar:40.cyan/blue}] {pos}/{len}")
            .unwrap()
            .progress_chars("=>-"));
        total.set_prefix("Total");
        self.total = Some(total);
        // Per-file bars
        for label in labels {
            let pb = mp.add(ProgressBar::new(100));
            pb.set_style(ProgressStyle::with_template("{prefix} [{bar:40.green/black}] {pos}% {msg}")
                .unwrap()
                .progress_chars("=>-"));
            pb.set_position(0);
            pb.set_prefix(label.clone());
            self.per.push(pb);
        }
        self.mp = Some(mp);
    }
    /// Returns true when bars are enabled (multi-file TTY mode).
    pub fn is_enabled(&self) -> bool { self.enabled }
    /// Get a clone of the per-file progress bar at index, if enabled.
    pub fn per_bar(&self, idx: usize) -> Option<ProgressBar> {
        if !self.enabled { return None; }
        self.per.get(idx).cloned()
    }
    /// Get a clone of the aggregate (total) progress bar, if enabled.
    pub fn total_bar(&self) -> Option<ProgressBar> {
        if !self.enabled { return None; }
        self.total.as_ref().cloned()
    }
    /// Mark a file as finished (set to 100% and update total counter).
    pub fn mark_file_done(&mut self, idx: usize) {
        if !self.enabled { return; }
        if let Some(pb) = self.per.get(idx) {
            pb.set_position(100);
            pb.finish_with_message("done");
        }
        self.completed += 1;
        if let Some(total) = &self.total { total.set_position(self.completed as u64); }
    }
}
pub struct FileProgress {
    enabled: bool,
    file_bars: Vec<cliclack::ProgressBar>,
    total_bar: Option<cliclack::ProgressBar>,
    completed: usize,
    total_file_count: usize,
}
impl FileProgress {
    pub fn new(enabled: bool) -> Self {
        Self {
            enabled,
            file_bars: Vec::new(),
            total_bar: None,
            completed: 0,
            total_file_count: 0,
        }
    }
    pub fn default_for_files(file_count: usize) -> Self {
        let enabled = file_count > 1
            && std::io::stderr().is_terminal()
            && !crate::is_quiet()
            && !crate::is_no_progress();
        Self::new(enabled)
    }
    /// Initialize bars for the given file labels. If disabled or single file, no-op.
    pub fn init_files(&mut self, labels: &[String]) {
        self.total_file_count = labels.len();
        if !self.enabled || labels.len() <= 1 {
            self.enabled = false;
            return;
        }
        let total = cliclack::progress_bar(labels.len() as u64);
        total.start("Total");
        self.total_bar = Some(total);
        for label in labels {
            let pb = cliclack::progress_bar(100);
            pb.start(label);
            self.file_bars.push(pb);
        }
    }
    pub fn is_enabled(&self) -> bool {
        self.enabled
    }
    pub fn set_file_message(&mut self, idx: usize, message: &str) {
        if !self.enabled {
            return;
        }
        if let Some(pb) = self.file_bars.get_mut(idx) {
            pb.set_message(message);
        }
    }
    pub fn set_file_percent(&mut self, idx: usize, percent: u64) {
        if !self.enabled {
            return;
        }
        if let Some(pb) = self.file_bars.get_mut(idx) {
            let p = percent.min(100);
            pb.set_message(format!("{p}%"));
        }
    }
    /// Mark a file as finished (set to 100% and update total counter).
    pub fn mark_file_done(&mut self, idx: usize) {
        if !self.enabled {
            return;
        }
        if let Some(pb) = self.file_bars.get_mut(idx) {
            pb.stop("done");
        }
        self.completed += 1;
        if let Some(total) = &mut self.total_bar {
            total.inc(1);
            if self.completed >= self.total_file_count {
                total.stop("all done");
            }
        }
    }
    pub fn finish_total(&mut self, message: &str) {
        if !self.enabled {
            return;
        }
        if let Some(total) = &mut self.total_bar {
            total.stop(message);
        }
    }
}
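// Hedged usage sketch (labels illustrative): per-file bars for a batch run.
#[allow(dead_code)]
fn file_progress_demo(labels: Vec<String>) {
    let mut fp = FileProgress::default_for_files(labels.len());
    fp.init_files(&labels);
    for i in 0..labels.len() {
        fp.set_file_percent(i, 100);
        fp.mark_file_done(i);
    }
    fp.finish_total("all files processed");
}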
#[derive(Debug)]
pub struct ProgressReporter {
non_interactive: bool,
}
impl ProgressReporter {
pub fn new(non_interactive: bool) -> Self {
Self { non_interactive }
}
pub fn step(&mut self, message: &str) {
if self.non_interactive {
let _ = cliclack::log::info(format!("[..] {message}"));
} else {
let _ = cliclack::log::info(format!("{message}"));
}
}
pub fn finish_with_message(&mut self, message: &str) {
if self.non_interactive {
let _ = cliclack::log::info(format!("[ok] {message}"));
} else {
let _ = cliclack::log::info(format!("{message}"));
}
}
}

View File

@@ -1,17 +1,12 @@
[package]
name = "polyscribe-host"
version = "0.1.0"
edition = "2024"
license = "MIT"
version.workspace = true
edition.workspace = true
[dependencies]
anyhow = "1.0.98"
thiserror = "1"
serde = { version = "1.0.219", features = ["derive"] }
serde_json = "1.0.142"
tokio = { version = "1", features = ["full"] }
which = "6"
cliclack = "0.3"
directories = "5"
polyscribe = { path = "../polyscribe-core" }
polyscribe-protocol = { path = "../polyscribe-protocol" }
anyhow = { workspace = true }
serde = { workspace = true, features = ["derive"] }
serde_json = { workspace = true }
tokio = { workspace = true, features = ["rt-multi-thread", "process", "io-util"] }
which = { workspace = true }
directories = { workspace = true }

View File

@@ -1,168 +1,119 @@
// SPDX-License-Identifier: MIT
use anyhow::{anyhow, Context, Result};
use cliclack as ui; // reuse for minimal logging
use directories::BaseDirs;
use serde_json::Value;
use std::collections::BTreeMap;
use std::ffi::OsStr;
use std::fs;
use std::io::{BufRead, BufReader, Write};
use std::path::{Path, PathBuf};
use std::process::{Command, Stdio};
use polyscribe_protocol as psp;
use anyhow::{Context, Result};
use std::process::Stdio;
use std::{
env, fs,
os::unix::fs::PermissionsExt,
path::Path,
};
use tokio::{
io::{AsyncBufReadExt, BufReader},
process::{Child as TokioChild, Command},
};
#[derive(Debug, Clone)]
pub struct Plugin {
    pub name: String,
    pub path: PathBuf,
}
#[derive(Debug, Clone)]
pub struct PluginInfo {
    pub name: String,
    pub path: String,
}
/// Discover plugins on PATH and in the user's data dir (XDG) under polyscribe/plugins.
pub fn discover() -> Result<Vec<Plugin>> {
    let mut found: BTreeMap<String, PathBuf> = BTreeMap::new();
    // Scan PATH directories
    if let Some(path_var) = std::env::var_os("PATH") {
        for dir in std::env::split_paths(&path_var) {
            if dir.as_os_str().is_empty() { continue; }
            if let Ok(rd) = fs::read_dir(&dir) {
                for ent in rd.flatten() {
                    let p = ent.path();
                    if !is_executable(&p) { continue; }
                    if let Some(fname) = p.file_name().and_then(OsStr::to_str) {
                        if let Some(name) = fname.strip_prefix("polyscribe-plugin-") {
                            found.entry(name.to_string()).or_insert(p);
                        }
                    }
                }
            }
        }
    }
    // Scan user data dir
    if let Some(base) = BaseDirs::new() {
        let user_plugins = PathBuf::from(base.data_dir()).join("polyscribe").join("plugins");
        if let Ok(rd) = fs::read_dir(&user_plugins) {
            for ent in rd.flatten() {
                let p = ent.path();
                if !is_executable(&p) { continue; }
                if let Some(fname) = p.file_name().and_then(OsStr::to_str) {
                    let name = fname.strip_prefix("polyscribe-plugin-")
                        .map(|s| s.to_string())
                        .or_else(|| Some(fname.to_string()))
                        .unwrap();
                    found.entry(name).or_insert(p);
                }
            }
        }
    }
    Ok(found
        .into_iter()
        .map(|(name, path)| Plugin { name, path })
        .collect())
}
#[derive(Debug, Default)]
pub struct PluginManager;
impl PluginManager {
    pub fn list(&self) -> Result<Vec<PluginInfo>> {
        let mut plugins = Vec::new();
        if let Ok(path) = env::var("PATH") {
            for dir in env::split_paths(&path) {
                scan_dir_for_plugins(&dir, &mut plugins);
            }
        }
        if let Some(dirs) = directories::ProjectDirs::from("dev", "polyscribe", "polyscribe") {
            let plugin_dir = dirs.data_dir().join("plugins");
            scan_dir_for_plugins(&plugin_dir, &mut plugins);
        }
        plugins.sort_by(|a, b| a.path.cmp(&b.path));
        plugins.dedup_by(|a, b| a.path == b.path);
        Ok(plugins)
    }
    pub fn info(&self, name: &str) -> Result<serde_json::Value> {
        let bin = self.resolve(name)?;
        let out = std::process::Command::new(&bin)
            .arg("info")
            .stdout(Stdio::piped())
            .stderr(Stdio::inherit())
            .spawn()
            .context("spawning plugin info")?
            .wait_with_output()
            .context("waiting for plugin info")?;
        let val: serde_json::Value =
            serde_json::from_slice(&out.stdout).context("parsing plugin info JSON")?;
        Ok(val)
    }
    pub fn spawn(&self, name: &str, command: &str) -> Result<TokioChild> {
        let bin = self.resolve(name)?;
        let mut cmd = Command::new(&bin);
        cmd.arg("run")
            .arg(command)
            .stdin(Stdio::piped())
            .stdout(Stdio::piped())
            .stderr(Stdio::inherit());
        let child = cmd.spawn().context("spawning plugin run")?;
        Ok(child)
    }
    pub async fn forward_stdio(&self, child: &mut TokioChild) -> Result<std::process::ExitStatus> {
        if let Some(stdout) = child.stdout.take() {
            let mut reader = BufReader::new(stdout).lines();
            while let Some(line) = reader.next_line().await? {
                println!("{line}");
            }
        }
        Ok(child.wait().await?)
    }
    fn resolve(&self, name: &str) -> Result<String> {
        let bin = format!("polyscribe-plugin-{name}");
        let path =
            which::which(&bin).with_context(|| format!("plugin not found in PATH: {bin}"))?;
        Ok(path.to_string_lossy().to_string())
    }
}
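// Hedged usage sketch (plugin name illustrative; requires a tokio runtime):
#[allow(dead_code)]
async fn plugin_demo() -> Result<()> {
    let pm = PluginManager;
    for p in pm.list()? {
        println!("{} -> {}", p.name, p.path);
    }
    let mut child = pm.spawn("tubescribe", "generate_metadata")?;
    let status = pm.forward_stdio(&mut child).await?;
    anyhow::ensure!(status.success(), "plugin exited with failure");
    Ok(())
}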
fn is_executable(p: &Path) -> bool {
    if !p.is_file() { return false; }
    #[cfg(unix)]
    {
        use std::os::unix::fs::PermissionsExt;
        if let Ok(md) = fs::metadata(p) {
            let mode = md.permissions().mode();
            return (mode & 0o111) != 0;
        }
        false
    }
    #[cfg(not(unix))]
    {
        // On Windows, consider .exe, .bat, .cmd
        matches!(p.extension().and_then(|s| s.to_str()).map(|s| s.to_lowercase()), Some(ext) if matches!(ext.as_str(), "exe"|"bat"|"cmd"))
    }
}
/// Query plugin capabilities by invoking `--capabilities`.
pub fn capabilities(plugin_path: &Path) -> Result<psp::Capabilities> {
let out = Command::new(plugin_path)
.arg("--capabilities")
.stdout(Stdio::piped())
.stderr(Stdio::null())
.output()
.with_context(|| format!("Failed to execute plugin: {}", plugin_path.display()))?;
if !out.status.success() {
return Err(anyhow!("Plugin --capabilities failed: {}", plugin_path.display()));
}
let s = String::from_utf8(out.stdout).context("capabilities stdout not utf-8")?;
let caps: psp::Capabilities = serde_json::from_str(s.trim()).context("invalid capabilities JSON")?;
Ok(caps)
}
/// Run a single method via `--serve`, writing one JSON-RPC request and streaming until result.
pub fn run_method<F>(plugin_path: &Path, method: &str, params: Value, mut on_progress: F) -> Result<Value>
where
F: FnMut(psp::Progress),
{
let mut child = Command::new(plugin_path)
.arg("--serve")
.stdin(Stdio::piped())
.stdout(Stdio::piped())
.stderr(Stdio::null())
.spawn()
.with_context(|| format!("Failed to spawn plugin: {}", plugin_path.display()))?;
let mut stdin = child.stdin.take().ok_or_else(|| anyhow!("failed to open plugin stdin"))?;
let stdout = child.stdout.take().ok_or_else(|| anyhow!("failed to open plugin stdout"))?;
// Send request line
let req = psp::JsonRpcRequest { jsonrpc: "2.0".into(), id: "1".into(), method: method.to_string(), params: Some(params) };
let line = serde_json::to_string(&req)? + "\n";
stdin.write_all(line.as_bytes())?;
stdin.flush()?;
// Read response lines
let reader = BufReader::new(stdout);
for line in reader.lines() {
let line = line?;
if line.trim().is_empty() { continue; }
// Try parse StreamItem; if that fails, try parse JsonRpcResponse directly
if let Ok(item) = serde_json::from_str::<psp::StreamItem>(&line) {
match item {
psp::StreamItem::Progress(p) => {
on_progress(p);
}
psp::StreamItem::Result(resp) => {
match resp.outcome {
psp::JsonRpcOutcome::Ok { result } => return Ok(result),
psp::JsonRpcOutcome::Err { error } => return Err(anyhow!("{} ({})", error.message, error.code)),
}
}
}
} else if let Ok(resp) = serde_json::from_str::<psp::JsonRpcResponse>(&line) {
match resp.outcome {
psp::JsonRpcOutcome::Ok { result } => return Ok(result),
psp::JsonRpcOutcome::Err { error } => return Err(anyhow!("{} ({})", error.message, error.code)),
}
        } else {
            let _ = ui::log::warning(format!("Unrecognized plugin output: {}", line));
        }
    }
    // If we exited loop without returning, wait for child
    let status = child.wait()?;
    if status.success() {
        Err(anyhow!("Plugin terminated without sending a result"))
    } else {
        Err(anyhow!("Plugin exited with status: {:?}", status))
    }
}
fn is_executable(path: &Path) -> bool {
    if !path.is_file() {
        return false;
    }
    if let Ok(meta) = fs::metadata(path) {
        let mode = meta.permissions().mode();
        return mode & 0o111 != 0;
    }
    true
}
fn scan_dir_for_plugins(dir: &Path, out: &mut Vec<PluginInfo>) {
if let Ok(read_dir) = fs::read_dir(dir) {
for entry in read_dir.flatten() {
let path = entry.path();
if let Some(fname) = path.file_name().and_then(|s| s.to_str())
&& fname.starts_with("polyscribe-plugin-")
&& is_executable(&path)
{
let name = fname.trim_start_matches("polyscribe-plugin-").to_string();
out.push(PluginInfo {
name,
path: path.to_string_lossy().to_string(),
});
}
}
}
}
/// Helper: find a plugin by name using discovery
pub fn find_plugin_by_name(name: &str) -> Result<Plugin> {
let plugins = discover()?;
plugins
.into_iter()
.find(|p| p.name == name)
.ok_or_else(|| anyhow!("Plugin '{}' not found", name))
}

View File

@@ -1,10 +1,8 @@
[package]
name = "polyscribe-protocol"
version = "0.1.0"
edition = "2024"
license = "MIT"
version.workspace = true
edition.workspace = true
[dependencies]
serde = { version = "1.0.219", features = ["derive"] }
serde_json = "1.0.142"
thiserror = "1"
serde = { workspace = true, features = ["derive"] }
serde_json = { workspace = true }

View File

@@ -1,90 +1,60 @@
// SPDX-License-Identifier: MIT
// PolyScribe Protocol (PSP/1): JSON-RPC 2.0 over NDJSON on stdio
use serde::{Deserialize, Serialize};
use serde_json::Value;
/// Plugin capabilities as reported by `--capabilities`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Capabilities {
pub name: String,
pub version: String,
/// Protocol identifier (e.g., "psp/1")
pub protocol: String,
/// Role (e.g., pipeline, tool, generator)
pub role: String,
/// Supported command names
pub commands: Vec<String>,
}
/// Generic JSON-RPC 2.0 request for PSP/1
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct JsonRpcRequest {
    pub jsonrpc: String, // "2.0"
    pub id: String,
    pub method: String,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub params: Option<serde_json::Value>,
}
#[derive(Debug, Serialize, Deserialize)]
pub struct Request {
    pub id: String,
    pub method: String,
    pub params: Option<Value>,
}
/// Error object for JSON-RPC 2.0
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct JsonRpcError {
    pub code: i64,
    pub message: String,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub data: Option<serde_json::Value>,
}
#[derive(Debug, Serialize, Deserialize)]
pub struct Response {
    pub id: String,
    pub result: Option<Value>,
    pub error: Option<ErrorObj>,
}
#[derive(Debug, Serialize, Deserialize)]
pub struct ErrorObj {
    pub code: i32,
    pub message: String,
    pub data: Option<Value>,
}
/// Generic JSON-RPC 2.0 response for PSP/1
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(tag = "kind", rename_all = "lowercase")]
pub enum StreamItem {
/// Progress notification (out-of-band in stream, not a JSON-RPC response)
Progress(Progress),
/// A proper JSON-RPC response with a result
    Result(JsonRpcResponse),
}
#[derive(Debug, Serialize, Deserialize)]
#[serde(tag = "event", content = "data")]
pub enum ProgressEvent {
Started,
Message(String),
Percent(f32),
Finished,
}
/// JSON-RPC 2.0 Response envelope containing either result or error.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct JsonRpcResponse {
pub jsonrpc: String, // "2.0"
pub id: String,
#[serde(flatten)]
pub outcome: JsonRpcOutcome,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(untagged)]
pub enum JsonRpcOutcome {
Ok { result: serde_json::Value },
Err { error: JsonRpcError },
}
/// Progress event structure for PSP/1 streaming
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Progress {
/// 0..=100
pub pct: u8,
/// Short phase name
pub stage: Option<String>,
/// Human-friendly detail
pub message: Option<String>,
}
/// Convenience helpers to build items
impl StreamItem {
pub fn progress(pct: u8, stage: impl Into<Option<String>>, message: impl Into<Option<String>>) -> Self {
StreamItem::Progress(Progress { pct, stage: stage.into(), message: message.into() })
}
pub fn ok(id: impl Into<String>, result: serde_json::Value) -> Self {
StreamItem::Result(JsonRpcResponse { jsonrpc: "2.0".into(), id: id.into(), outcome: JsonRpcOutcome::Ok { result } })
}
    pub fn err(id: impl Into<String>, code: i64, message: impl Into<String>, data: Option<serde_json::Value>) -> Self {
        StreamItem::Result(JsonRpcResponse {
            jsonrpc: "2.0".into(),
            id: id.into(),
            outcome: JsonRpcOutcome::Err { error: JsonRpcError { code, message: message.into(), data } },
        })
    }
}
impl Response {
    pub fn ok(id: impl Into<String>, result: Value) -> Self {
        Self {
            id: id.into(),
            result: Some(result),
            error: None,
        }
    }
pub fn err(
id: impl Into<String>,
code: i32,
message: impl Into<String>,
data: Option<Value>,
) -> Self {
Self {
id: id.into(),
result: None,
error: Some(ErrorObj {
code,
message: message.into(),
data,
}),
}
}
}
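// Hedged sketch of one PSP round trip over NDJSON (field values illustrative).
#[allow(dead_code)]
fn psp_roundtrip_demo() -> serde_json::Result<()> {
    let req = Request {
        id: "1".into(),
        method: "generate_metadata".into(),
        params: Some(serde_json::json!({ "input": "video.mp4" })),
    };
    // One JSON document per line on stdout:
    println!("{}", serde_json::to_string(&req)?);
    let resp = Response::ok("1", serde_json::json!({ "title": "Canned title" }));
    println!("{}", serde_json::to_string(&resp)?);
    Ok(())
}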

View File

@@ -1,26 +0,0 @@
# CI checklist and job outline
Checklist to keep docs and code healthy in CI
- Build: cargo build --all-targets --locked
- Tests: cargo test --all --locked
- Lints: cargo clippy --all-targets -- -D warnings
- Optional: check README and docs snippets (basic smoke run of examples scripts)
- bash examples/update_models.sh (can be skipped offline)
- bash examples/transcribe_file.sh (use a tiny sample file if available)
Example GitHub Actions job (outline)
- name: Rust
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v4
- uses: dtolnay/rust-toolchain@stable
- name: Build
run: cargo build --all-targets --locked
- name: Test
run: cargo test --all --locked
- name: Clippy
run: cargo clippy --all-targets -- -D warnings
Notes
- For GPU features, set up appropriate runners and add `--features gpu-cuda|gpu-hip|gpu-vulkan` where applicable.
- For docs-only changes, jobs still build/test to ensure doctests and examples compile when enabled.

View File

@@ -32,18 +32,20 @@ Run locally
Models during development
- Interactive downloader:
- cargo run -- --download-models
- cargo run -- models download
- Non-interactive update (checks sizes/hashes, downloads if missing):
- cargo run -- --update-models --no-interaction -q
- cargo run -- models update --no-interaction -q
Tests
- Run all tests:
- cargo test
- The test suite includes CLI-oriented integration tests and unit tests. Some tests simulate GPU detection using env vars (POLYSCRIBE_TEST_FORCE_*). Do not rely on these flags in production code.
Clippy
Clippy & formatting
- Run lint checks and treat warnings as errors:
- cargo clippy --all-targets -- -D warnings
- Check formatting:
- cargo fmt --all -- --check
- Common warnings can often be fixed by simplifying code, removing unused imports, and following idiomatic patterns.
Code layout
@@ -61,10 +63,10 @@ Adding a feature
Running the model downloader
- Interactive:
- cargo run -- --download-models
- cargo run -- models download
- Non-interactive suggestions for CI:
- POLYSCRIBE_MODELS_DIR=$PWD/models \
cargo run -- --update-models --no-interaction -q
cargo run -- models update --no-interaction -q
Env var examples for local testing
- Use a local copy of models and a specific model file:

View File

@@ -30,10 +30,10 @@ CLI reference
- Choose runtime backend. Default is auto (prefers CUDA → HIP → Vulkan → CPU), depending on detection.
- --gpu-layers N
- Number of layers to offload to the GPU when supported.
- --download-models
- models download
- Launch interactive model downloader (lists Hugging Face models; multi-select to download).
- Controls: Use Up/Down to navigate, Space to toggle selections, and Enter to confirm. Models are grouped by base (e.g., tiny, base, small).
- --update-models
- models update
- Verify/update local models by comparing sizes and hashes with the upstream manifest.
- -v, --verbose (repeatable)
- Increase log verbosity; use -vv for very detailed logs.
@@ -42,6 +42,9 @@ CLI reference
- --no-interaction
- Disable all interactive prompts (for CI). Combine with env vars to control behavior.
- Subcommands:
- models download: Launch interactive model downloader.
- models update: Verify/update local models (non-interactive).
- plugins list|info|run: Discover and run plugins.
- completions <shell>: Write shell completion script to stdout.
- man: Write a man page to stdout.

View File

@@ -1,13 +0,0 @@
#!/usr/bin/env bash
# SPDX-License-Identifier: MIT
# Copyright (c) 2025 <COPYRIGHT HOLDER>. All rights reserved.
set -euo pipefail
# Launch the interactive model downloader and select models to install
BIN=${BIN:-./target/release/polyscribe}
MODELS_DIR=${POLYSCRIBE_MODELS_DIR:-$PWD/models}
export POLYSCRIBE_MODELS_DIR="$MODELS_DIR"
mkdir -p "$MODELS_DIR"
"$BIN" --download-models

View File

@@ -1,15 +0,0 @@
#!/usr/bin/env bash
# SPDX-License-Identifier: MIT
# Copyright (c) 2025 <COPYRIGHT HOLDER>. All rights reserved.
set -euo pipefail
# Transcribe an audio/video file to JSON and SRT into ./output
# Requires a model; first run may prompt to download.
BIN=${BIN:-./target/release/polyscribe}
INPUT=${1:-samples/example.mp3}
OUTDIR=${OUTDIR:-output}
mkdir -p "$OUTDIR"
"$BIN" -v -o "$OUTDIR" "$INPUT"
echo "Done. See $OUTDIR for JSON/SRT files."

View File

@@ -1,15 +0,0 @@
#!/usr/bin/env bash
# SPDX-License-Identifier: MIT
# Copyright (c) 2025 <COPYRIGHT HOLDER>. All rights reserved.
set -euo pipefail
# Verify/update local models non-interactively (useful in CI)
BIN=${BIN:-./target/release/polyscribe}
MODELS_DIR=${POLYSCRIBE_MODELS_DIR:-$PWD/models}
export POLYSCRIBE_MODELS_DIR="$MODELS_DIR"
mkdir -p "$MODELS_DIR"
"$BIN" --update-models --no-interaction -q
echo "Models updated in $MODELS_DIR"

View File

@@ -1,5 +1,4 @@
// SPDX-License-Identifier: MIT
// Stub plugin: tubescribe
use anyhow::{Context, Result};
use clap::Parser;
@@ -36,7 +35,6 @@ fn main() -> Result<()> {
serve_once()?;
return Ok(());
}
// Default: show capabilities (friendly behavior if run without flags)
let caps = psp::Capabilities {
name: "tubescribe".to_string(),
version: env!("CARGO_PKG_VERSION").to_string(),
@@ -49,14 +47,12 @@ fn main() -> Result<()> {
}
fn serve_once() -> Result<()> {
// Read exactly one line (one request)
let stdin = std::io::stdin();
let mut reader = BufReader::new(stdin.lock());
let mut line = String::new();
reader.read_line(&mut line).context("failed to read request line")?;
let req: psp::JsonRpcRequest = serde_json::from_str(line.trim()).context("invalid JSON-RPC request")?;
// Simulate doing some work with progress
emit(&psp::StreamItem::progress(5, Some("start".into()), Some("initializing".into())))?;
std::thread::sleep(std::time::Duration::from_millis(50));
emit(&psp::StreamItem::progress(25, Some("probe".into()), Some("probing sources".into())))?;
@@ -65,7 +61,6 @@ fn serve_once() -> Result<()> {
std::thread::sleep(std::time::Duration::from_millis(50));
emit(&psp::StreamItem::progress(90, Some("finalize".into()), Some("finalizing".into())))?;
// Handle method and produce result
let result = match req.method.as_str() {
"generate_metadata" => {
let title = "Canned title";
@@ -78,7 +73,6 @@ fn serve_once() -> Result<()> {
})
}
other => {
// Unknown method
let err = psp::StreamItem::err(req.id.clone(), -32601, format!("Method not found: {}", other), None);
emit(&err)?;
return Ok(());

6
rust-toolchain.toml Normal file
View File

@@ -0,0 +1,6 @@
# SPDX-License-Identifier: MIT
[toolchain]
channel = "1.89.0"
components = ["clippy", "rustfmt"]
profile = "minimal"

View File

@@ -1,329 +0,0 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2025 <COPYRIGHT HOLDER>. All rights reserved.
//! Transcription backend selection and implementations (CPU/GPU) used by PolyScribe.
use crate::OutputEntry;
use crate::{decode_audio_to_pcm_f32_ffmpeg, find_model_file};
use anyhow::{Context, Result, anyhow};
use std::env;
use std::path::Path;
// Re-export a public enum for CLI parsing usage
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
/// Kind of transcription backend to use.
pub enum BackendKind {
/// Automatically detect the best available backend (CUDA > HIP > Vulkan > CPU).
Auto,
/// Pure CPU backend using whisper-rs.
Cpu,
/// NVIDIA CUDA backend (requires CUDA runtime available at load time and proper feature build).
Cuda,
/// AMD ROCm/HIP backend (requires hip/rocBLAS libraries available and proper feature build).
Hip,
/// Vulkan backend (experimental; requires Vulkan loader/SDK and feature build).
Vulkan,
}
/// Abstraction for a transcription backend.
pub trait TranscribeBackend {
/// Backend kind implemented by this type.
fn kind(&self) -> BackendKind;
/// Transcribe the given audio and return transcript entries.
fn transcribe(
&self,
audio_path: &Path,
speaker: &str,
language: Option<&str>,
gpu_layers: Option<u32>,
progress: Option<&(dyn Fn(i32) + Send + Sync)>,
) -> Result<Vec<OutputEntry>>;
}
fn check_lib(_names: &[&str]) -> bool {
#[cfg(test)]
{
// During unit tests, avoid touching system libs to prevent loader crashes in CI.
false
}
#[cfg(not(test))]
{
// Disabled runtime dlopen probing to avoid loader instability; rely on environment overrides.
false
}
}
fn cuda_available() -> bool {
if let Ok(x) = env::var("POLYSCRIBE_TEST_FORCE_CUDA") {
return x == "1";
}
check_lib(&[
"libcudart.so",
"libcudart.so.12",
"libcudart.so.11",
"libcublas.so",
"libcublas.so.12",
])
}
fn hip_available() -> bool {
if let Ok(x) = env::var("POLYSCRIBE_TEST_FORCE_HIP") {
return x == "1";
}
check_lib(&["libhipblas.so", "librocblas.so"])
}
fn vulkan_available() -> bool {
if let Ok(x) = env::var("POLYSCRIBE_TEST_FORCE_VULKAN") {
return x == "1";
}
check_lib(&["libvulkan.so.1", "libvulkan.so"])
}
/// CPU-based transcription backend using whisper-rs.
#[derive(Default)]
pub struct CpuBackend;
/// CUDA-accelerated transcription backend for NVIDIA GPUs.
#[derive(Default)]
pub struct CudaBackend;
/// ROCm/HIP-accelerated transcription backend for AMD GPUs.
#[derive(Default)]
pub struct HipBackend;
/// Vulkan-based transcription backend (experimental/incomplete).
#[derive(Default)]
pub struct VulkanBackend;
macro_rules! impl_whisper_backend {
($ty:ty, $kind:expr) => {
impl TranscribeBackend for $ty {
fn kind(&self) -> BackendKind { $kind }
fn transcribe(
&self,
audio_path: &Path,
speaker: &str,
language: Option<&str>,
_gpu_layers: Option<u32>,
progress: Option<&(dyn Fn(i32) + Send + Sync)>,
) -> Result<Vec<OutputEntry>> {
transcribe_with_whisper_rs(audio_path, speaker, language, progress)
}
}
};
}
impl_whisper_backend!(CpuBackend, BackendKind::Cpu);
impl_whisper_backend!(CudaBackend, BackendKind::Cuda);
impl_whisper_backend!(HipBackend, BackendKind::Hip);
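// For reference, each invocation above expands to roughly this (CpuBackend shown):
//
// impl TranscribeBackend for CpuBackend {
//     fn kind(&self) -> BackendKind { BackendKind::Cpu }
//     fn transcribe(&self, audio_path: &Path, speaker: &str, language: Option<&str>,
//                   _gpu_layers: Option<u32>, progress: Option<&(dyn Fn(i32) + Send + Sync)>)
//         -> Result<Vec<OutputEntry>> {
//         transcribe_with_whisper_rs(audio_path, speaker, language, progress)
//     }
// }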
impl TranscribeBackend for VulkanBackend {
fn kind(&self) -> BackendKind {
BackendKind::Vulkan
}
fn transcribe(
&self,
_audio_path: &Path,
_speaker: &str,
_language: Option<&str>,
_gpu_layers: Option<u32>,
_progress: Option<&(dyn Fn(i32) + Send + Sync)>,
) -> Result<Vec<OutputEntry>> {
Err(anyhow!(
"Vulkan backend not yet wired to whisper.cpp FFI. Build with --features gpu-vulkan and ensure Vulkan SDK is installed. How to fix: install Vulkan loader (libvulkan), set VULKAN_SDK, and run cargo build --features gpu-vulkan."
))
}
}
/// Result of choosing a transcription backend.
pub struct SelectionResult {
/// The constructed backend instance to perform transcription with.
pub backend: Box<dyn TranscribeBackend + Send + Sync>,
/// Which backend kind was ultimately selected.
pub chosen: BackendKind,
/// Which backend kinds were detected as available on this system.
pub detected: Vec<BackendKind>,
}
/// Select an appropriate backend based on user request and system detection.
///
/// If `requested` is `BackendKind::Auto`, the function prefers CUDA, then HIP,
/// then Vulkan, falling back to CPU when no GPU backend is detected. When a
/// specific GPU backend is requested but unavailable, an error is returned with
/// guidance on how to enable it.
///
/// Set `verbose` to true to print detection/selection info to stderr.
pub fn select_backend(requested: BackendKind, verbose: bool) -> Result<SelectionResult> {
let mut detected = Vec::new();
if cuda_available() {
detected.push(BackendKind::Cuda);
}
if hip_available() {
detected.push(BackendKind::Hip);
}
if vulkan_available() {
detected.push(BackendKind::Vulkan);
}
let instantiate_backend = |k: BackendKind| -> Box<dyn TranscribeBackend + Send + Sync> {
match k {
BackendKind::Cpu => Box::new(CpuBackend::default()),
BackendKind::Cuda => Box::new(CudaBackend::default()),
BackendKind::Hip => Box::new(HipBackend::default()),
BackendKind::Vulkan => Box::new(VulkanBackend::default()),
BackendKind::Auto => Box::new(CpuBackend::default()), // placeholder for Auto
}
};
let chosen = match requested {
BackendKind::Auto => {
if detected.contains(&BackendKind::Cuda) {
BackendKind::Cuda
} else if detected.contains(&BackendKind::Hip) {
BackendKind::Hip
} else if detected.contains(&BackendKind::Vulkan) {
BackendKind::Vulkan
} else {
BackendKind::Cpu
}
}
BackendKind::Cuda => {
if detected.contains(&BackendKind::Cuda) {
BackendKind::Cuda
} else {
return Err(anyhow!(
"Requested CUDA backend but CUDA libraries/devices not detected. How to fix: install NVIDIA driver + CUDA toolkit, ensure libcudart/libcublas are in loader path, and build with --features gpu-cuda."
));
}
}
BackendKind::Hip => {
if detected.contains(&BackendKind::Hip) {
BackendKind::Hip
} else {
return Err(anyhow!(
"Requested ROCm/HIP backend but libraries/devices not detected. How to fix: install ROCm hipBLAS/rocBLAS, ensure libs are in loader path, and build with --features gpu-hip."
));
}
}
BackendKind::Vulkan => {
if detected.contains(&BackendKind::Vulkan) {
BackendKind::Vulkan
} else {
return Err(anyhow!(
"Requested Vulkan backend but libvulkan not detected. How to fix: install Vulkan loader/SDK and build with --features gpu-vulkan."
));
}
}
BackendKind::Cpu => BackendKind::Cpu,
};
if verbose {
crate::dlog!(1, "Detected backends: {:?}", detected);
crate::dlog!(1, "Selected backend: {:?}", chosen);
}
Ok(SelectionResult {
backend: instantiate_backend(chosen),
chosen,
detected,
})
}
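// Hedged usage sketch (speaker label illustrative): auto-select and transcribe.
#[allow(dead_code)]
fn transcribe_auto(audio: &Path) -> Result<Vec<OutputEntry>> {
    let sel = select_backend(BackendKind::Auto, false)?;
    sel.backend.transcribe(audio, "Speaker 1", None, None, None)
}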
// Internal helper: transcription using whisper-rs with CPU/GPU (depending on build features)
#[allow(clippy::too_many_arguments)]
pub(crate) fn transcribe_with_whisper_rs(
audio_path: &Path,
speaker: &str,
language: Option<&str>,
progress: Option<&(dyn Fn(i32) + Send + Sync)>,
) -> Result<Vec<OutputEntry>> {
let report = |p: i32| {
if let Some(cb) = progress { cb(p); }
};
report(0);
let pcm_samples = decode_audio_to_pcm_f32_ffmpeg(audio_path)?;
report(5);
let model_path = find_model_file()?;
let english_only_model = model_path
.file_name()
.and_then(|s| s.to_str())
.map(|s| s.contains(".en.") || s.ends_with(".en.bin"))
.unwrap_or(false);
if let Some(lang) = language {
if english_only_model && lang != "en" {
return Err(anyhow!(
"Selected model is English-only ({}), but a non-English language hint '{}' was provided. Please use a multilingual model or set WHISPER_MODEL.",
model_path.display(),
lang
));
}
}
let model_path_str = model_path
.to_str()
.ok_or_else(|| anyhow!("Model path not valid UTF-8: {}", model_path.display()))?;
if crate::verbose_level() < 2 {
// Some builds of whisper/ggml expect these env vars; harmless if unknown
unsafe {
std::env::set_var("GGML_LOG_LEVEL", "0");
std::env::set_var("WHISPER_PRINT_PROGRESS", "0");
}
}
let (_context, mut state) = crate::with_suppressed_stderr(|| {
let params = whisper_rs::WhisperContextParameters::default();
let context = whisper_rs::WhisperContext::new_with_params(model_path_str, params)
.with_context(|| format!("Failed to load Whisper model at {}", model_path.display()))?;
let state = context
.create_state()
.map_err(|e| anyhow!("Failed to create Whisper state: {:?}", e))?;
Ok::<_, anyhow::Error>((context, state))
})?;
report(20);
let mut full_params =
whisper_rs::FullParams::new(whisper_rs::SamplingStrategy::Greedy { best_of: 1 });
let threads = std::thread::available_parallelism()
.map(|n| n.get() as i32)
.unwrap_or(1);
full_params.set_n_threads(threads);
full_params.set_translate(false);
if let Some(lang) = language {
full_params.set_language(Some(lang));
}
report(30);
crate::with_suppressed_stderr(|| {
report(40);
state
.full(full_params, &pcm_samples)
.map_err(|e| anyhow!("Whisper full() failed: {:?}", e))
})?;
report(90);
let num_segments = state
.full_n_segments()
.map_err(|e| anyhow!("Failed to get segments: {:?}", e))?;
let mut entries = Vec::new();
for seg_idx in 0..num_segments {
let segment_text = state
.full_get_segment_text(seg_idx)
.map_err(|e| anyhow!("Failed to get segment text: {:?}", e))?;
let t0 = state
.full_get_segment_t0(seg_idx)
.map_err(|e| anyhow!("Failed to get segment t0: {:?}", e))?;
let t1 = state
.full_get_segment_t1(seg_idx)
.map_err(|e| anyhow!("Failed to get segment t1: {:?}", e))?;
let start = (t0 as f64) * 0.01;
let end = (t1 as f64) * 0.01;
entries.push(OutputEntry {
id: 0,
speaker: speaker.to_string(),
start,
end,
text: segment_text.trim().to_string(),
});
}
report(100);
Ok(entries)
}

View File

@@ -1,571 +0,0 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2025 <COPYRIGHT HOLDER>. All rights reserved.
#![forbid(elided_lifetimes_in_paths)]
#![forbid(unused_must_use)]
#![deny(missing_docs)]
#![warn(clippy::all)]
//! PolyScribe library: business logic and core types.
//!
//! This crate exposes the reusable parts of the PolyScribe CLI as a library.
//! The binary entry point (main.rs) remains a thin CLI wrapper.
use std::sync::atomic::{AtomicBool, AtomicU8, Ordering};
// Global runtime flags
static QUIET: AtomicBool = AtomicBool::new(false);
static NO_INTERACTION: AtomicBool = AtomicBool::new(false);
static VERBOSE: AtomicU8 = AtomicU8::new(0);
static NO_PROGRESS: AtomicBool = AtomicBool::new(false);
/// Set quiet mode: when true, non-interactive logs should be suppressed.
pub fn set_quiet(enabled: bool) {
QUIET.store(enabled, Ordering::Relaxed);
}
/// Return current quiet mode state.
pub fn is_quiet() -> bool {
QUIET.load(Ordering::Relaxed)
}
/// Set non-interactive mode: when true, interactive prompts must be skipped.
pub fn set_no_interaction(enabled: bool) {
NO_INTERACTION.store(enabled, Ordering::Relaxed);
}
/// Return current non-interactive state.
pub fn is_no_interaction() -> bool {
NO_INTERACTION.load(Ordering::Relaxed)
}
/// Set verbose level (0 = normal, 1 = verbose, 2 = super-verbose)
pub fn set_verbose(level: u8) {
VERBOSE.store(level, Ordering::Relaxed);
}
/// Get current verbose level.
pub fn verbose_level() -> u8 {
VERBOSE.load(Ordering::Relaxed)
}
/// Disable interactive progress indicators (bars/spinners)
pub fn set_no_progress(enabled: bool) {
NO_PROGRESS.store(enabled, Ordering::Relaxed);
}
/// Return current no-progress state
pub fn is_no_progress() -> bool {
NO_PROGRESS.load(Ordering::Relaxed)
}
/// Check whether stdin is connected to a TTY. Used to avoid blocking prompts when not interactive.
pub fn stdin_is_tty() -> bool {
use std::io::IsTerminal as _;
std::io::stdin().is_terminal()
}
/// A guard that temporarily redirects stderr to /dev/null on Unix when quiet mode is active.
/// No-op on non-Unix or when quiet is disabled. Restores stderr on drop.
pub struct StderrSilencer {
#[cfg(unix)]
old_stderr_fd: i32,
#[cfg(unix)]
devnull_fd: i32,
active: bool,
}
impl StderrSilencer {
/// Activate stderr silencing if quiet is set and on Unix; otherwise returns a no-op guard.
pub fn activate_if_quiet() -> Self {
if !is_quiet() {
return Self {
active: false,
#[cfg(unix)]
old_stderr_fd: -1,
#[cfg(unix)]
devnull_fd: -1,
};
}
Self::activate()
}
/// Activate stderr silencing unconditionally (used internally); no-op on non-Unix.
pub fn activate() -> Self {
#[cfg(unix)]
unsafe {
let old_fd = dup(2);
if old_fd < 0 {
return Self {
active: false,
old_stderr_fd: -1,
devnull_fd: -1,
};
}
// Open /dev/null for writing
let devnull_cstr = std::ffi::CString::new("/dev/null").unwrap();
let devnull_fd = open(devnull_cstr.as_ptr(), O_WRONLY);
if devnull_fd < 0 {
close(old_fd);
return Self {
active: false,
old_stderr_fd: -1,
devnull_fd: -1,
};
}
if dup2(devnull_fd, 2) < 0 {
close(devnull_fd);
close(old_fd);
return Self {
active: false,
old_stderr_fd: -1,
devnull_fd: -1,
};
}
Self {
active: true,
old_stderr_fd: old_fd,
devnull_fd: devnull_fd,
}
}
#[cfg(not(unix))]
{
Self { active: false }
}
}
}
impl Drop for StderrSilencer {
fn drop(&mut self) {
if !self.active {
return;
}
#[cfg(unix)]
unsafe {
let _ = dup2(self.old_stderr_fd, 2);
let _ = close(self.devnull_fd);
let _ = close(self.old_stderr_fd);
}
self.active = false;
}
}
/// Run a closure while temporarily suppressing stderr on Unix when appropriate.
/// On Windows/non-Unix, this is a no-op wrapper.
/// This helper uses RAII + panic catching to ensure restoration before resuming panic.
pub fn with_suppressed_stderr<F, T>(f: F) -> T
where
F: FnOnce() -> T,
{
// Suppress noisy native logs unless super-verbose (-vv) is enabled.
if verbose_level() < 2 {
let result = std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| {
let _guard = StderrSilencer::activate();
f()
}));
match result {
Ok(value) => value,
Err(panic_payload) => std::panic::resume_unwind(panic_payload),
}
} else {
f()
}
}
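// Hedged sketch: wrap a chatty native call so its stderr output is hidden
// unless -vv was requested; the closure's return value passes through.
#[allow(dead_code)]
fn quiet_native_call() -> i32 {
    with_suppressed_stderr(|| {
        // imagine an FFI call that writes diagnostics to fd 2 here
        42
    })
}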
/// Centralized UI helpers (TTY-aware, quiet/verbose-aware)
pub mod ui;
/// Logging macros and helpers
/// Log an error using the UI helper (always printed). Recommended for user-visible errors.
#[macro_export]
macro_rules! elog {
($($arg:tt)*) => {{
$crate::ui::error(format!($($arg)*));
}}
}
/// Log a warning using the UI helper (printed even in quiet mode).
#[macro_export]
macro_rules! wlog {
($($arg:tt)*) => {{
$crate::ui::warn(format!($($arg)*));
}}
}
/// Log an informational line using the UI helper unless quiet mode is enabled.
#[macro_export]
macro_rules! ilog {
($($arg:tt)*) => {{
if !$crate::is_quiet() { $crate::ui::info(format!($($arg)*)); }
}}
}
/// Log a debug/trace line when verbose level is at least the given level (u8).
#[macro_export]
macro_rules! dlog {
($lvl:expr, $($arg:tt)*) => {{
if !$crate::is_quiet() && $crate::verbose_level() >= $lvl { $crate::ui::info(format!("DEBUG{}: {}", $lvl, format!($($arg)*))); }
}}
}
/// Backward-compatibility: map old qlog! to ilog!
#[macro_export]
macro_rules! qlog {
($($arg:tt)*) => {{ $crate::ilog!($($arg)*); }}
}
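// Hedged call-site sketch for the macros above (messages are illustrative).
#[allow(dead_code)]
fn logging_demo() {
    ilog!("processing {} files", 3);
    wlog!("model checksum missing; re-verifying");
    dlog!(1, "selected backend: {:?}", "cpu");
    elog!("failed to open input");
}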
use anyhow::{Context, Result, anyhow};
use chrono::Local;
use std::env;
use std::fs::create_dir_all;
use std::path::{Path, PathBuf};
use std::process::Command;
#[cfg(unix)]
use libc::{O_WRONLY, close, dup, dup2, open};
/// Re-export backend module (GPU/CPU selection and transcription).
pub mod backend;
/// Re-export models module (model listing/downloading/updating).
pub mod models;
/// Transcript entry for a single segment.
#[derive(Debug, serde::Serialize, Clone)]
pub struct OutputEntry {
/// Sequential id in output ordering.
pub id: u64,
/// Speaker label associated with the segment.
pub speaker: String,
/// Start time in seconds.
pub start: f64,
/// End time in seconds.
pub end: f64,
/// Text content.
pub text: String,
}
/// Return a YYYY-MM-DD date prefix string for output file naming.
pub fn date_prefix() -> String {
Local::now().format("%Y-%m-%d").to_string()
}
/// Format a floating-point number of seconds as SRT timestamp (HH:MM:SS,mmm).
pub fn format_srt_time(seconds: f64) -> String {
let total_ms = (seconds * 1000.0).round() as i64;
let ms = total_ms % 1000;
let total_secs = total_ms / 1000;
let sec = total_secs % 60;
let min = (total_secs / 60) % 60;
let hour = total_secs / 3600;
format!("{hour:02}:{min:02}:{sec:02},{ms:03}")
}
/// Render a list of transcript entries to SRT format.
pub fn render_srt(entries: &[OutputEntry]) -> String {
let mut srt = String::new();
for (index, entry) in entries.iter().enumerate() {
let srt_index = index + 1;
srt.push_str(&format!("{srt_index}\n"));
srt.push_str(&format!(
"{} --> {}\n",
format_srt_time(entry.start),
format_srt_time(entry.end)
));
if !entry.speaker.is_empty() {
srt.push_str(&format!("{}: {}\n", entry.speaker, entry.text));
} else {
srt.push_str(&format!("{}\n", entry.text));
}
srt.push('\n');
}
srt
}
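// Illustrative checks (not part of the original suite): a worked example of the
// timestamp math (3661.5 s = 1 h 1 min 1.5 s) and of one rendered SRT block.
#[cfg(test)]
mod srt_examples {
    use super::*;
    #[test]
    fn formats_time_and_renders_block() {
        assert_eq!(format_srt_time(3661.5), "01:01:01,500");
        let entries = vec![OutputEntry {
            id: 0,
            speaker: "Alice".to_string(),
            start: 0.0,
            end: 2.5,
            text: "Hello there.".to_string(),
        }];
        assert_eq!(
            render_srt(&entries),
            "1\n00:00:00,000 --> 00:00:02,500\nAlice: Hello there.\n\n"
        );
    }
}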
/// Determine the default models directory, honoring POLYSCRIBE_MODELS_DIR override.
pub fn models_dir_path() -> PathBuf {
if let Ok(env_val) = env::var("POLYSCRIBE_MODELS_DIR") {
let env_path = PathBuf::from(env_val);
if !env_path.as_os_str().is_empty() {
return env_path;
}
}
if cfg!(debug_assertions) {
return PathBuf::from("models");
}
if let Ok(xdg) = env::var("XDG_DATA_HOME") {
if !xdg.is_empty() {
return PathBuf::from(xdg).join("polyscribe").join("models");
}
}
if let Ok(home) = env::var("HOME") {
if !home.is_empty() {
return PathBuf::from(home)
.join(".local")
.join("share")
.join("polyscribe")
.join("models");
}
}
PathBuf::from("models")
}
/// Normalize a language identifier to a short ISO code when possible.
pub fn normalize_lang_code(input: &str) -> Option<String> {
let mut lang = input.trim().to_lowercase();
if lang.is_empty() || lang == "auto" || lang == "c" || lang == "posix" {
return None;
}
if let Some((prefix, _)) = lang.split_once('.') {
lang = prefix.to_string();
}
if let Some((prefix, _)) = lang.split_once('_') {
lang = prefix.to_string();
}
let code = match lang.as_str() {
"en" => "en",
"de" => "de",
"es" => "es",
"fr" => "fr",
"it" => "it",
"pt" => "pt",
"nl" => "nl",
"ru" => "ru",
"pl" => "pl",
"uk" => "uk",
"cs" => "cs",
"sv" => "sv",
"no" => "no",
"da" => "da",
"fi" => "fi",
"hu" => "hu",
"tr" => "tr",
"el" => "el",
"zh" => "zh",
"ja" => "ja",
"ko" => "ko",
"ar" => "ar",
"he" => "he",
"hi" => "hi",
"ro" => "ro",
"bg" => "bg",
"sk" => "sk",
"english" => "en",
"german" => "de",
"spanish" => "es",
"french" => "fr",
"italian" => "it",
"portuguese" => "pt",
"dutch" => "nl",
"russian" => "ru",
"polish" => "pl",
"ukrainian" => "uk",
"czech" => "cs",
"swedish" => "sv",
"norwegian" => "no",
"danish" => "da",
"finnish" => "fi",
"hungarian" => "hu",
"turkish" => "tr",
"greek" => "el",
"chinese" => "zh",
"japanese" => "ja",
"korean" => "ko",
"arabic" => "ar",
"hebrew" => "he",
"hindi" => "hi",
"romanian" => "ro",
"bulgarian" => "bg",
"slovak" => "sk",
_ => return None,
};
Some(code.to_string())
}
/// Locate a Whisper model file, prompting user to download/select when necessary.
pub fn find_model_file() -> Result<PathBuf> {
let models_dir_buf = models_dir_path();
let models_dir = models_dir_buf.as_path();
if !models_dir.exists() {
create_dir_all(models_dir).with_context(|| {
format!(
"Failed to create models directory: {}",
models_dir.display()
)
})?;
}
if let Ok(env_model) = env::var("WHISPER_MODEL") {
let model_path = PathBuf::from(env_model);
if model_path.is_file() {
let _ = std::fs::write(models_dir.join(".last_model"), model_path.display().to_string());
return Ok(model_path);
}
}
// Non-interactive mode: automatic selection and optional download
if crate::is_no_interaction() {
if let Some(local) = crate::models::pick_best_local_model(models_dir) {
let _ = std::fs::write(models_dir.join(".last_model"), local.display().to_string());
return Ok(local);
} else {
ilog!("No local models found; downloading large-v3-turbo-q8_0...");
let path = crate::models::ensure_model_available_noninteractive("large-v3-turbo-q8_0")
.with_context(|| "Failed to download required model 'large-v3-turbo-q8_0'")?;
let _ = std::fs::write(models_dir.join(".last_model"), path.display().to_string());
return Ok(path);
}
}
let mut candidates: Vec<PathBuf> = Vec::new();
let dir_entries = std::fs::read_dir(models_dir)
.with_context(|| format!("Failed to read models directory: {}", models_dir.display()))?;
for entry in dir_entries {
let entry = entry?;
let path = entry.path();
if path.is_file() {
if let Some(ext) = path
.extension()
.and_then(|s| s.to_str())
.map(|s| s.to_lowercase())
{
if ext == "bin" {
candidates.push(path);
}
}
}
}
if candidates.is_empty() {
// No models found: prompt interactively (TTY only)
wlog!(
"{}",
format!(
"No Whisper model files (*.bin) found in {}.",
models_dir.display()
)
);
if crate::is_no_interaction() || !crate::stdin_is_tty() {
return Err(anyhow!(
"No models available and interactive mode is disabled. Please set WHISPER_MODEL or run with --download-models."
));
}
let input = crate::ui::prompt_line("Would you like to download models now? [Y/n]: ").unwrap_or_default();
let answer = input.trim().to_lowercase();
if answer.is_empty() || answer == "y" || answer == "yes" {
if let Err(e) = models::run_interactive_model_downloader() {
elog!("Downloader failed: {:#}", e);
}
candidates.clear();
let dir_entries2 = std::fs::read_dir(models_dir).with_context(|| {
format!("Failed to read models directory: {}", models_dir.display())
})?;
for entry in dir_entries2 {
let entry = entry?;
let path = entry.path();
if path.is_file() {
if let Some(ext) = path
.extension()
.and_then(|s| s.to_str())
.map(|s| s.to_lowercase())
{
if ext == "bin" {
candidates.push(path);
}
}
}
}
}
}
if candidates.is_empty() {
return Err(anyhow!(
"No Whisper model files (*.bin) available in {}",
models_dir.display()
));
}
if candidates.len() == 1 {
let only_model = candidates.remove(0);
let _ = std::fs::write(models_dir.join(".last_model"), only_model.display().to_string());
return Ok(only_model);
}
let last_file = models_dir.join(".last_model");
if let Ok(previous_content) = std::fs::read_to_string(&last_file) {
let previous_content = previous_content.trim();
if !previous_content.is_empty() {
let previous_path = PathBuf::from(previous_content);
if previous_path.is_file() && candidates.iter().any(|c| c == &previous_path) {
return Ok(previous_path);
}
}
}
crate::ui::println_above_bars(format!("Multiple Whisper models found in {}:", models_dir.display()));
for (index, path) in candidates.iter().enumerate() {
crate::ui::println_above_bars(format!(" {}) {}", index + 1, path.display()));
}
let input = crate::ui::prompt_line(&format!("Select model by number [1-{}]: ", candidates.len()))
.map_err(|_| anyhow!("Failed to read selection"))?;
let selection: usize = input
.trim()
.parse()
.map_err(|_| anyhow!("Invalid selection: {}", input.trim()))?;
if selection == 0 || selection > candidates.len() {
return Err(anyhow!("Selection out of range"));
}
let chosen = candidates.swap_remove(selection - 1);
let _ = std::fs::write(models_dir.join(".last_model"), chosen.display().to_string());
Ok(chosen)
}
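// Selection precedence, summarized:
//   1. $WHISPER_MODEL, when it points at an existing file
//   2. non-interactive mode: best local model, else download large-v3-turbo-q8_0
//   3. a single local *.bin candidate is used as-is
//   4. the model remembered in .last_model, when it is still a candidate
//   5. otherwise: interactive numeric selection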
/// Decode an input media file to 16kHz mono f32 PCM using ffmpeg available on PATH.
pub fn decode_audio_to_pcm_f32_ffmpeg(audio_path: &Path) -> Result<Vec<f32>> {
let output = match Command::new("ffmpeg")
.arg("-i")
.arg(audio_path)
.arg("-f")
.arg("f32le")
.arg("-ac")
.arg("1")
.arg("-ar")
.arg("16000")
.arg("pipe:1")
.output()
{
Ok(o) => o,
Err(e) => {
if e.kind() == std::io::ErrorKind::NotFound {
return Err(anyhow!(
"ffmpeg not found on PATH. Please install ffmpeg and ensure it is available."
));
} else {
return Err(anyhow!(
"Failed to execute ffmpeg for {}: {}",
audio_path.display(),
e
));
}
}
};
if !output.status.success() {
let stderr_str = String::from_utf8_lossy(&output.stderr);
return Err(anyhow!(
"Failed to decode audio from {} using ffmpeg. This may indicate the file is not a valid or supported audio/video file, is corrupted, or cannot be opened. ffmpeg stderr: {}",
audio_path.display(),
stderr_str.trim()
));
}
let data = output.stdout;
    // Drop any trailing partial sample, then reinterpret the bytes as
    // little-endian f32 samples, four bytes at a time.
    let usable = data.len() - (data.len() % 4);
    let mut samples = Vec::with_capacity(usable / 4);
    for chunk in data[..usable].chunks_exact(4) {
        samples.push(f32::from_le_bytes([chunk[0], chunk[1], chunk[2], chunk[3]]));
    }
    Ok(samples)
}
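// For reference, the function shells out to the equivalent of:
//   ffmpeg -i <input> -f f32le -ac 1 -ar 16000 pipe:1
// stdout then carries raw little-endian f32 samples at 16 kHz mono.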

View File

@@ -1,483 +0,0 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2025 <COPYRIGHT HOLDER>. All rights reserved.
use std::fs::{File, create_dir_all};
use std::io::{self, Read, Write};
use std::path::{Path, PathBuf};
use anyhow::{Context, Result, anyhow};
use clap::{Parser, Subcommand, ValueEnum, CommandFactory};
use clap_complete::Shell;
use serde::{Deserialize, Serialize};
use polyscribe::{OutputEntry, date_prefix, normalize_lang_code, render_srt};
#[derive(Subcommand, Debug, Clone)]
enum AuxCommands {
Completions {
#[arg(value_enum)]
shell: Shell,
},
Man,
}
#[derive(ValueEnum, Debug, Clone, Copy)]
#[value(rename_all = "kebab-case")]
enum GpuBackendCli {
Auto,
Cpu,
Cuda,
Hip,
Vulkan,
}
#[derive(Parser, Debug)]
#[command(
name = "PolyScribe",
bin_name = "polyscribe",
version,
about = "Merge JSON transcripts or transcribe audio using native whisper"
)]
struct Args {
/// Increase verbosity (-v, -vv). Repeat to increase.
/// Debug logs appear with -v; very verbose with -vv. Logs go to stderr.
#[arg(short = 'v', long = "verbose", action = clap::ArgAction::Count, global = true)]
verbose: u8,
/// Quiet mode: suppress non-error logging on stderr (overrides -v)
/// Does not suppress interactive prompts or stdout output.
#[arg(short = 'q', long = "quiet", global = true)]
quiet: bool,
/// Non-interactive mode: never prompt; use defaults instead.
#[arg(long = "no-interaction", global = true)]
no_interaction: bool,
/// Disable interactive progress indicators (bars/spinners)
#[arg(long = "no-progress", global = true)]
no_progress: bool,
/// Optional auxiliary subcommands (completions, man)
#[command(subcommand)]
aux: Option<AuxCommands>,
/// Input .json transcript files or audio files to merge/transcribe
inputs: Vec<String>,
/// Output file path base or directory (date prefix added).
/// In merge mode: base path.
/// In separate mode: directory.
/// If omitted: prints JSON to stdout for merge mode; separate mode requires directory for multiple inputs.
#[arg(short, long, value_name = "FILE")]
output: Option<String>,
/// Merge all inputs into a single output; if not set, each input is written as a separate output
#[arg(short = 'm', long = "merge")]
merge: bool,
/// Merge and also write separate outputs per input; requires -o OUTPUT_DIR
#[arg(long = "merge-and-separate")]
merge_and_separate: bool,
/// Prompt for speaker names per input file
#[arg(long = "set-speaker-names")]
set_speaker_names: bool,
/// Language code to use for transcription (e.g., en, de). No auto-detection.
#[arg(short, long, value_name = "LANG")]
language: Option<String>,
/// Launch interactive model downloader (list HF models, multi-select and download)
#[arg(long)]
download_models: bool,
/// Update local Whisper models by comparing hashes/sizes with remote manifest
#[arg(long)]
update_models: bool,
}
#[derive(Debug, Deserialize)]
struct InputRoot {
#[serde(default)]
segments: Vec<InputSegment>,
}
#[derive(Debug, Deserialize)]
struct InputSegment {
start: f64,
end: f64,
text: String,
}
#[derive(Debug, Serialize)]
struct OutputRoot {
items: Vec<OutputEntry>,
}
fn is_json_file(path: &Path) -> bool {
matches!(path.extension().and_then(|s| s.to_str()).map(|s| s.to_lowercase()), Some(ext) if ext == "json")
}
fn is_audio_file(path: &Path) -> bool {
if let Some(ext) = path.extension().and_then(|s| s.to_str()).map(|s| s.to_lowercase()) {
let exts = [
"mp3", "wav", "m4a", "mp4", "aac", "flac", "ogg", "wma", "webm", "mkv", "mov", "avi",
"m4b", "3gp", "opus", "aiff", "alac",
];
return exts.contains(&ext.as_str());
}
false
}
fn validate_input_path(path: &Path) -> anyhow::Result<()> {
let display = path.display();
if !path.exists() {
return Err(anyhow!("Input not found: {}", display));
}
let metadata = std::fs::metadata(path).with_context(|| format!("Failed to stat input: {}", display))?;
if metadata.is_dir() {
return Err(anyhow!("Input is a directory (expected a file): {}", display));
}
std::fs::File::open(path)
.with_context(|| format!("Failed to open input file: {}", display))
.map(|_| ())
}
fn sanitize_speaker_name(raw: &str) -> String {
if let Some((prefix, rest)) = raw.split_once('-') {
if !prefix.is_empty() && prefix.chars().all(|c| c.is_ascii_digit()) {
return rest.to_string();
}
}
raw.to_string()
}
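// Behavior sketch for the sanitizer above:
#[cfg(test)]
mod sanitize_speaker_name_sketch {
    use super::sanitize_speaker_name;

    #[test]
    fn strips_only_all_digit_prefixes() {
        assert_eq!(sanitize_speaker_name("1-s0wlz"), "s0wlz");
        // A non-numeric prefix is left untouched.
        assert_eq!(sanitize_speaker_name("alpha-beta"), "alpha-beta");
        // No dash, nothing to strip.
        assert_eq!(sanitize_speaker_name("42"), "42");
    }
}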
fn prompt_speaker_name_for_path(
_path: &Path,
default_name: &str,
enabled: bool,
) -> String {
if !enabled || polyscribe::is_no_interaction() {
return sanitize_speaker_name(default_name);
}
// TODO implement cliclack for this
let mut input_line = String::new();
match std::io::stdin().read_line(&mut input_line) {
Ok(_) => {
let trimmed = input_line.trim();
if trimmed.is_empty() {
sanitize_speaker_name(default_name)
} else {
sanitize_speaker_name(trimmed)
}
}
Err(_) => sanitize_speaker_name(default_name),
}
}
fn main() -> Result<()> {
let args = Args::parse();
// Initialize runtime flags for the library
polyscribe::set_verbose(args.verbose);
polyscribe::set_quiet(args.quiet);
polyscribe::set_no_interaction(args.no_interaction);
polyscribe::set_no_progress(args.no_progress);
// Handle aux subcommands
if let Some(aux) = &args.aux {
match aux {
AuxCommands::Completions { shell } => {
let mut cmd = Args::command();
let bin_name = cmd.get_name().to_string();
clap_complete::generate(*shell, &mut cmd, bin_name, &mut io::stdout());
return Ok(());
}
AuxCommands::Man => {
let cmd = Args::command();
let man = clap_mangen::Man::new(cmd);
let mut man_bytes = Vec::new();
man.render(&mut man_bytes)?;
io::stdout().write_all(&man_bytes)?;
return Ok(());
}
}
}
// Optional model management actions
if args.download_models {
if let Err(err) = polyscribe::models::run_interactive_model_downloader() {
polyscribe::elog!("Model downloader failed: {:#}", err);
}
if args.inputs.is_empty() {
return Ok(())
}
}
if args.update_models {
if let Err(err) = polyscribe::models::update_local_models() {
polyscribe::elog!("Model update failed: {:#}", err);
return Err(err);
}
if args.inputs.is_empty() {
return Ok(())
}
}
// Process inputs
let mut inputs = args.inputs;
if inputs.is_empty() {
return Err(anyhow!("No input files provided"));
}
    // If the last argument does not exist on disk and there are at least two inputs,
    // treat it as the output path (-o).
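    // Example (illustrative): `polyscribe a.json b.json out/base` behaves like
    // `polyscribe a.json b.json -o out/base`, since `out/base` does not exist.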
let mut output_path = args.output;
if output_path.is_none() && inputs.len() >= 2 {
if let Some(candidate_output) = inputs.last().cloned() {
if !Path::new(&candidate_output).exists() {
inputs.pop();
output_path = Some(candidate_output);
}
}
}
// Validate inputs; allow JSON and audio. For audio, require --language.
for input_arg in &inputs {
let path_ref = Path::new(input_arg);
validate_input_path(path_ref)?;
if !(is_json_file(path_ref) || is_audio_file(path_ref)) {
return Err(anyhow!(
"Unsupported input type (expected .json transcript or audio media): {}",
path_ref.display()
));
}
if is_audio_file(path_ref) && args.language.is_none() {
return Err(anyhow!("Please specify --language (e.g., --language en). Language detection was removed."));
}
}
// Derive speakers (prompt if requested)
let speakers: Vec<String> = inputs
.iter()
.map(|input_path| {
let path = Path::new(input_path);
let default_speaker = sanitize_speaker_name(
path.file_stem().and_then(|s| s.to_str()).unwrap_or("speaker"),
);
prompt_speaker_name_for_path(path, &default_speaker, args.set_speaker_names)
})
.collect();
// MERGE-AND-SEPARATE mode
if args.merge_and_separate {
polyscribe::dlog!(1, "Mode: merge-and-separate; output_dir={:?}", output_path);
let out_dir = match output_path.as_ref() {
Some(p) => PathBuf::from(p),
None => return Err(anyhow!("--merge-and-separate requires -o OUTPUT_DIR")),
};
if !out_dir.as_os_str().is_empty() {
create_dir_all(&out_dir).with_context(|| {
format!("Failed to create output directory: {}", out_dir.display())
})?;
}
let mut merged_entries: Vec<OutputEntry> = Vec::new();
for (idx, input_path) in inputs.iter().enumerate() {
let path = Path::new(input_path);
let speaker = speakers[idx].clone();
// Decide based on input type (JSON transcript vs audio to transcribe)
// TODO remove duplicate
let mut entries: Vec<OutputEntry> = if is_json_file(path) {
let mut buf = String::new();
File::open(path)
.with_context(|| format!("Failed to open: {input_path}"))?
.read_to_string(&mut buf)
.with_context(|| format!("Failed to read: {input_path}"))?;
let root: InputRoot = serde_json::from_str(&buf)
.with_context(|| format!("Invalid JSON transcript parsed from {input_path}"))?;
root
.segments
.into_iter()
.map(|seg| OutputEntry { id: 0, speaker: speaker.clone(), start: seg.start, end: seg.end, text: seg.text })
.collect()
} else {
let lang_norm: Option<String> = args.language.as_deref().and_then(|s| normalize_lang_code(s));
let selected_backend = polyscribe::backend::select_backend(polyscribe::backend::BackendKind::Auto, args.verbose > 0)?;
selected_backend.backend.transcribe(path, &speaker, lang_norm.as_deref(), None, None)?
};
// Sort and id per-file
// TODO remove duplicate
entries.sort_by(|a, b| a.start.partial_cmp(&b.start).unwrap_or(std::cmp::Ordering::Equal)
.then(a.end.partial_cmp(&b.end).unwrap_or(std::cmp::Ordering::Equal)));
for (i, entry) in entries.iter_mut().enumerate() { entry.id = i as u64; }
// Write per-file outputs
let stem = path.file_stem().and_then(|s| s.to_str()).unwrap_or("output");
let date = date_prefix();
let base_name = format!("{date}_{stem}");
let json_path = out_dir.join(format!("{}.json", &base_name));
let toml_path = out_dir.join(format!("{}.toml", &base_name));
let srt_path = out_dir.join(format!("{}.srt", &base_name));
let output_bundle = OutputRoot { items: entries.clone() };
let mut json_file = File::create(&json_path).with_context(|| format!("Failed to create output file: {}", json_path.display()))?;
serde_json::to_writer_pretty(&mut json_file, &output_bundle)?; writeln!(&mut json_file)?;
let toml_str = toml::to_string_pretty(&output_bundle)?;
let mut toml_file = File::create(&toml_path).with_context(|| format!("Failed to create output file: {}", toml_path.display()))?;
toml_file.write_all(toml_str.as_bytes())?; if !toml_str.ends_with('\n') { writeln!(&mut toml_file)?; }
let srt_str = render_srt(&output_bundle.items);
let mut srt_file = File::create(&srt_path).with_context(|| format!("Failed to create output file: {}", srt_path.display()))?;
srt_file.write_all(srt_str.as_bytes())?;
merged_entries.extend(output_bundle.items.into_iter());
}
// Write merged outputs into out_dir
// TODO remove duplicate
merged_entries.sort_by(|a, b| a.start.partial_cmp(&b.start).unwrap_or(std::cmp::Ordering::Equal)
.then(a.end.partial_cmp(&b.end).unwrap_or(std::cmp::Ordering::Equal)));
for (index, entry) in merged_entries.iter_mut().enumerate() { entry.id = index as u64; }
let merged_output = OutputRoot { items: merged_entries };
let date = date_prefix();
let merged_base = format!("{date}_merged");
let merged_json_path = out_dir.join(format!("{}.json", &merged_base));
let merged_toml_path = out_dir.join(format!("{}.toml", &merged_base));
let merged_srt_path = out_dir.join(format!("{}.srt", &merged_base));
let mut merged_json_file = File::create(&merged_json_path).with_context(|| format!("Failed to create output file: {}", merged_json_path.display()))?;
serde_json::to_writer_pretty(&mut merged_json_file, &merged_output)?; writeln!(&mut merged_json_file)?;
let merged_toml_str = toml::to_string_pretty(&merged_output)?;
let mut merged_toml_file = File::create(&merged_toml_path).with_context(|| format!("Failed to create output file: {}", merged_toml_path.display()))?;
merged_toml_file.write_all(merged_toml_str.as_bytes())?; if !merged_toml_str.ends_with('\n') { writeln!(&mut merged_toml_file)?; }
let merged_srt_str = render_srt(&merged_output.items);
let mut merged_srt_file = File::create(&merged_srt_path).with_context(|| format!("Failed to create output file: {}", merged_srt_path.display()))?;
merged_srt_file.write_all(merged_srt_str.as_bytes())?;
return Ok(());
}
// MERGE mode
if args.merge {
polyscribe::dlog!(1, "Mode: merge; output_base={:?}", output_path);
let mut entries: Vec<OutputEntry> = Vec::new();
for (index, input_path) in inputs.iter().enumerate() {
let path = Path::new(input_path);
let speaker = speakers[index].clone();
if is_json_file(path) {
let mut buf = String::new();
File::open(path)
.with_context(|| format!("Failed to open: {}", input_path))?
.read_to_string(&mut buf)
.with_context(|| format!("Failed to read: {}", input_path))?;
let root: InputRoot = serde_json::from_str(&buf)
.with_context(|| format!("Invalid JSON transcript parsed from {}", input_path))?;
for seg in root.segments {
entries.push(OutputEntry { id: 0, speaker: speaker.clone(), start: seg.start, end: seg.end, text: seg.text });
}
} else {
let lang_norm: Option<String> = args.language.as_deref().and_then(|s| normalize_lang_code(s));
let selected_backend = polyscribe::backend::select_backend(polyscribe::backend::BackendKind::Auto, args.verbose > 0)?;
let mut new_entries = selected_backend.backend.transcribe(path, &speaker, lang_norm.as_deref(), None, None)?;
entries.append(&mut new_entries);
}
}
// TODO remove duplicate
entries.sort_by(|a, b| a.start.partial_cmp(&b.start).unwrap_or(std::cmp::Ordering::Equal)
.then(a.end.partial_cmp(&b.end).unwrap_or(std::cmp::Ordering::Equal)));
for (i, entry) in entries.iter_mut().enumerate() { entry.id = i as u64; }
let output_bundle = OutputRoot { items: entries };
if let Some(path) = output_path {
let base_path = Path::new(&path);
let parent_opt = base_path.parent();
if let Some(parent) = parent_opt {
if !parent.as_os_str().is_empty() {
create_dir_all(parent).with_context(|| {
format!("Failed to create parent directory for output: {}", parent.display())
})?;
}
}
let stem = base_path.file_stem().and_then(|s| s.to_str()).unwrap_or("output");
let date = date_prefix();
let base_name = format!("{}_{}", date, stem);
let dir = parent_opt.unwrap_or(Path::new(""));
let json_path = dir.join(format!("{}.json", &base_name));
let toml_path = dir.join(format!("{}.toml", &base_name));
let srt_path = dir.join(format!("{}.srt", &base_name));
let mut json_file = File::create(&json_path).with_context(|| format!("Failed to create output file: {}", json_path.display()))?;
serde_json::to_writer_pretty(&mut json_file, &output_bundle)?; writeln!(&mut json_file)?;
let toml_str = toml::to_string_pretty(&output_bundle)?;
let mut toml_file = File::create(&toml_path).with_context(|| format!("Failed to create output file: {}", toml_path.display()))?;
toml_file.write_all(toml_str.as_bytes())?; if !toml_str.ends_with('\n') { writeln!(&mut toml_file)?; }
let srt_str = render_srt(&output_bundle.items);
let mut srt_file = File::create(&srt_path).with_context(|| format!("Failed to create output file: {}", srt_path.display()))?;
srt_file.write_all(srt_str.as_bytes())?;
} else {
let stdout = io::stdout();
let mut handle = stdout.lock();
serde_json::to_writer_pretty(&mut handle, &output_bundle)?; writeln!(&mut handle)?;
}
return Ok(());
}
// SEPARATE (default)
polyscribe::dlog!(1, "Mode: separate; output_dir={:?}", output_path);
if output_path.is_none() && inputs.len() > 1 {
return Err(anyhow!("Multiple inputs without --merge require -o OUTPUT_DIR to write separate files"));
}
let out_dir: Option<PathBuf> = output_path.as_ref().map(PathBuf::from);
if let Some(dir) = &out_dir {
if !dir.as_os_str().is_empty() {
create_dir_all(dir).with_context(|| format!("Failed to create output directory: {}", dir.display()))?;
}
}
for (index, input_path) in inputs.iter().enumerate() {
let path = Path::new(input_path);
let speaker = speakers[index].clone();
// TODO remove duplicate
let mut entries: Vec<OutputEntry> = if is_json_file(path) {
let mut buf = String::new();
File::open(path)
.with_context(|| format!("Failed to open: {input_path}"))?
.read_to_string(&mut buf)
.with_context(|| format!("Failed to read: {input_path}"))?;
let root: InputRoot = serde_json::from_str(&buf).with_context(|| format!("Invalid JSON transcript parsed from {input_path}"))?;
root
.segments
.into_iter()
.map(|seg| OutputEntry { id: 0, speaker: speaker.clone(), start: seg.start, end: seg.end, text: seg.text })
.collect()
} else {
// Audio file: transcribe to entries
let lang_norm: Option<String> = args.language.as_deref().and_then(|s| normalize_lang_code(s));
let selected_backend = polyscribe::backend::select_backend(polyscribe::backend::BackendKind::Auto, args.verbose > 0)?;
selected_backend.backend.transcribe(path, &speaker, lang_norm.as_deref(), None, None)?
};
// TODO remove duplicate
entries.sort_by(|a, b| a.start.partial_cmp(&b.start).unwrap_or(std::cmp::Ordering::Equal)
.then(a.end.partial_cmp(&b.end).unwrap_or(std::cmp::Ordering::Equal)));
for (i, entry) in entries.iter_mut().enumerate() { entry.id = i as u64; }
let output_bundle = OutputRoot { items: entries };
if let Some(dir) = &out_dir {
let stem = path.file_stem().and_then(|s| s.to_str()).unwrap_or("output");
let date = date_prefix();
let base_name = format!("{date}_{stem}");
let json_path = dir.join(format!("{}.json", &base_name));
let toml_path = dir.join(format!("{}.toml", &base_name));
let srt_path = dir.join(format!("{}.srt", &base_name));
let mut json_file = File::create(&json_path).with_context(|| format!("Failed to create output file: {}", json_path.display()))?;
serde_json::to_writer_pretty(&mut json_file, &output_bundle)?; writeln!(&mut json_file)?;
let toml_str = toml::to_string_pretty(&output_bundle)?;
let mut toml_file = File::create(&toml_path).with_context(|| format!("Failed to create output file: {}", toml_path.display()))?;
toml_file.write_all(toml_str.as_bytes())?; if !toml_str.ends_with('\n') { writeln!(&mut toml_file)?; }
let srt_str = render_srt(&output_bundle.items);
let mut srt_file = File::create(&srt_path).with_context(|| format!("Failed to create output file: {}", srt_path.display()))?;
srt_file.write_all(srt_str.as_bytes())?;
} else {
let stdout = io::stdout();
let mut handle = stdout.lock();
serde_json::to_writer_pretty(&mut handle, &output_bundle)?; writeln!(&mut handle)?;
}
}
Ok(())
}

View File

@@ -1,146 +0,0 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2025 <COPYRIGHT HOLDER>. All rights reserved.
//! Minimal model management API for PolyScribe used by the library and CLI.
//! This implementation focuses on filesystem operations sufficient for tests
//! and basic non-interactive workflows. It can be extended later to support
//! remote discovery and verification.
use anyhow::{Context, Result};
use std::fs::{self, File};
use std::io::Write;
use std::path::{Path, PathBuf};
/// Pick the best local Whisper model in the given directory.
///
/// Heuristic: choose the largest .bin file by size. Returns None if none found.
pub fn pick_best_local_model(dir: &Path) -> Option<PathBuf> {
let rd = fs::read_dir(dir).ok()?;
rd.flatten()
.map(|e| e.path())
.filter(|p| p.is_file() && p.extension().and_then(|s| s.to_str()).is_some_and(|s| s.eq_ignore_ascii_case("bin")))
.filter_map(|p| fs::metadata(&p).ok().map(|md| (md.len(), p)))
.max_by_key(|(sz, _)| *sz)
.map(|(_, p)| p)
}
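// A filesystem sketch of the heuristic above (the temp-dir layout is illustrative):
#[cfg(test)]
mod pick_best_local_model_sketch {
    use super::pick_best_local_model;
    use std::fs;

    #[test]
    fn picks_largest_bin_and_ignores_other_files() {
        let dir = std::env::temp_dir().join(format!("ps_pick_{}", std::process::id()));
        fs::create_dir_all(&dir).unwrap();
        fs::write(dir.join("small.bin"), vec![0u8; 8]).unwrap();
        fs::write(dir.join("big.bin"), vec![0u8; 64]).unwrap();
        fs::write(dir.join("ignored.txt"), vec![0u8; 128]).unwrap();
        let best = pick_best_local_model(&dir).expect("expected a candidate");
        assert!(best.ends_with("big.bin"));
        let _ = fs::remove_dir_all(&dir);
    }
}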
/// Ensure a model file with the given short name exists locally (non-interactive).
///
/// This stub creates an empty file named `<name>.bin` inside the models dir if it
/// does not yet exist, and returns its path. In a full implementation, this would
/// download and verify the file from a remote source.
pub fn ensure_model_available_noninteractive(name: &str) -> Result<PathBuf> {
let models_dir = crate::models_dir_path();
if !models_dir.exists() {
fs::create_dir_all(&models_dir).with_context(|| {
format!("Failed to create models dir: {}", models_dir.display())
})?;
}
let filename = if name.ends_with(".bin") { name.to_string() } else { format!("{}.bin", name) };
let path = models_dir.join(filename);
if !path.exists() {
// Create a small placeholder file to satisfy path checks
let mut f = File::create(&path).with_context(|| format!("Failed to create model file: {}", path.display()))?;
// Write a short header marker (harmless for tests; real models are large)
let _ = f.write_all(b"POLYSCRIBE_PLACEHOLDER_MODEL\n");
}
Ok(path)
}
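// Usage sketch: `ensure_model_available_noninteractive("base.en")` returns
// `<models_dir>/base.en.bin`, creating a small placeholder file when it is absent.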
/// Run an interactive model downloader UI.
///
/// Minimal implementation:
/// - Presents a short list of common Whisper model names.
/// - Prompts the user to select models by comma-separated indices.
/// - Ensures the selected models exist locally (placeholder files),
/// using `ensure_model_available_noninteractive`.
/// - Respects --no-interaction by returning early with an info message.
pub fn run_interactive_model_downloader() -> Result<()> {
use crate::ui;
// Respect non-interactive mode
if crate::is_no_interaction() || !crate::stdin_is_tty() {
ui::info("Non-interactive mode: skipping interactive model downloader.");
return Ok(());
}
// Available models (ordered from small to large). In a full implementation,
// this would come from a remote manifest.
let available = vec![
("tiny.en", "English-only tiny model (~75 MB)"),
("tiny", "Multilingual tiny model (~75 MB)"),
("base.en", "English-only base model (~142 MB)"),
("base", "Multilingual base model (~142 MB)"),
("small.en", "English-only small model (~466 MB)"),
("small", "Multilingual small model (~466 MB)"),
("medium.en", "English-only medium model (~1.5 GB)"),
("medium", "Multilingual medium model (~1.5 GB)"),
("large-v2", "Multilingual large v2 (~3.1 GB)"),
("large-v3", "Multilingual large v3 (~3.1 GB)"),
("large-v3-turbo", "Multilingual large v3 turbo (~1.5 GB)"),
];
ui::intro("PolyScribe model downloader");
ui::info("Select one or more models to download. Enter comma-separated numbers (e.g., 1,3,4). Press Enter to accept default [1].");
ui::println_above_bars("Available models:");
for (i, (name, desc)) in available.iter().enumerate() {
ui::println_above_bars(format!(" {}. {:<16} {}", i + 1, name, desc));
}
let answer = ui::prompt_input("Your selection", Some("1"))?;
let selection_raw = match answer {
Some(s) => s.trim().to_string(),
None => "1".to_string(),
};
let selection = if selection_raw.is_empty() { "1" } else { &selection_raw };
// Parse indices
use std::collections::BTreeSet;
let mut picked_set: BTreeSet<usize> = BTreeSet::new();
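    // Separators may be commas, spaces, or semicolons: "1, 3;5" selects indices {0, 2, 4}.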
for part in selection.split([',', ' ', ';']) {
let t = part.trim();
if t.is_empty() { continue; }
match t.parse::<usize>() {
Ok(n) if (1..=available.len()).contains(&n) => {
picked_set.insert(n - 1);
}
_ => ui::warn(format!("Ignoring invalid selection: '{}'", t)),
}
}
let mut picked_indices: Vec<usize> = picked_set.into_iter().collect();
if picked_indices.is_empty() {
// Fallback to default first item
picked_indices.push(0);
}
// Prepare progress (TTY-aware)
let labels: Vec<String> = picked_indices
.iter()
.map(|&i| available[i].0.to_string())
.collect();
let mut pm = ui::progress::ProgressManager::default_for_files(labels.len());
pm.init_files(&labels);
// Ensure models exist
for (i, idx) in picked_indices.iter().enumerate() {
let (name, _desc) = available[*idx];
if let Some(pb) = pm.per_bar(i) {
pb.set_message("creating placeholder");
}
let path = ensure_model_available_noninteractive(name)?;
ui::println_above_bars(format!("Ready: {}", path.display()));
pm.mark_file_done(i);
}
if let Some(total) = pm.total_bar() { total.finish_with_message("all done"); }
ui::outro("Model selection complete.");
Ok(())
}
/// Verify/update local models by comparing with a remote manifest.
///
/// Stub that currently succeeds and logs a short message.
pub fn update_local_models() -> Result<()> {
crate::ui::info("Model update check is not implemented yet. Nothing to do.");
Ok(())
}

View File

@@ -1,84 +0,0 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2025 <COPYRIGHT HOLDER>. All rights reserved.
//! Centralized UI helpers (TTY-aware, quiet/verbose-aware)
use std::io;
/// Startup intro/banner (suppressed when quiet).
pub fn intro(msg: impl AsRef<str>) {
let _ = cliclack::intro(msg.as_ref());
}
/// Final outro/summary printed below any progress indicators (suppressed when quiet).
pub fn outro(msg: impl AsRef<str>) {
let _ = cliclack::outro(msg.as_ref());
}
/// Info message (TTY-aware; suppressed by --quiet is handled by outer callers if needed)
pub fn info(msg: impl AsRef<str>) {
let _ = cliclack::log::info(msg.as_ref());
}
/// Print a warning (always printed).
pub fn warn(msg: impl AsRef<str>) {
// cliclack provides a warning-level log utility
let _ = cliclack::log::warning(msg.as_ref());
}
/// Print an error (always printed).
pub fn error(msg: impl AsRef<str>) {
let _ = cliclack::log::error(msg.as_ref());
}
/// Print a line above any progress bars (maps to cliclack log; synchronized).
pub fn println_above_bars(msg: impl AsRef<str>) {
if crate::is_quiet() { return; }
// cliclack logs are synchronized with its spinners/bars
let _ = cliclack::log::info(msg.as_ref());
}
/// Input prompt with a question: returns Ok(None) if non-interactive or canceled
pub fn prompt_input(question: impl AsRef<str>, default: Option<&str>) -> anyhow::Result<Option<String>> {
if crate::is_no_interaction() || !crate::stdin_is_tty() {
return Ok(None);
}
let mut p = cliclack::input(question.as_ref());
if let Some(d) = default {
// Use default_input when available in 0.3.x
p = p.default_input(d);
}
match p.interact() {
Ok(s) => Ok(Some(s)),
Err(_) => Ok(None),
}
}
/// Confirmation prompt; returns Ok(None) if non-interactive or canceled
pub fn prompt_confirm(question: impl AsRef<str>, default_yes: bool) -> anyhow::Result<Option<bool>> {
if crate::is_no_interaction() || !crate::stdin_is_tty() {
return Ok(None);
}
let res = cliclack::confirm(question.as_ref())
.initial_value(default_yes)
.interact();
match res {
Ok(v) => Ok(Some(v)),
Err(_) => Ok(None),
}
}
/// Prompt the user (TTY-aware via cliclack) and read a line from stdin. Returns the raw line with trailing newline removed.
pub fn prompt_line(prompt: &str) -> io::Result<String> {
// Route prompt through cliclack to keep consistent styling and avoid direct eprint!/println!
let _ = cliclack::log::info(prompt);
let mut s = String::new();
    io::stdin().read_line(&mut s)?;
    // Trim the trailing newline (and carriage return on Windows) as documented.
    while s.ends_with('\n') || s.ends_with('\r') {
        s.pop();
    }
    Ok(s)
}
/// TTY-aware progress UI built on `indicatif` for per-file and aggregate progress bars.
///
/// This small helper encapsulates a `MultiProgress` with one aggregate (total) bar and
/// one per-file bar. It is intentionally minimal to keep integration lightweight.
pub mod progress;

View File

@@ -1,81 +0,0 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2025 <COPYRIGHT HOLDER>. All rights reserved.
use indicatif::{MultiProgress, ProgressBar, ProgressStyle};
use std::io::IsTerminal as _;
/// Manages a set of per-file progress bars plus a top aggregate bar.
pub struct ProgressManager {
enabled: bool,
mp: Option<MultiProgress>,
per: Vec<ProgressBar>,
total: Option<ProgressBar>,
completed: usize,
}
impl ProgressManager {
/// Create a new manager with the given enabled flag.
pub fn new(enabled: bool) -> Self {
Self { enabled, mp: None, per: Vec::new(), total: None, completed: 0 }
}
    /// Create a manager that enables bars when `n > 1`, stderr is a TTY, quiet mode is off, and progress is not disabled.
pub fn default_for_files(n: usize) -> Self {
let enabled = n > 1 && std::io::stderr().is_terminal() && !crate::is_quiet() && !crate::is_no_progress();
Self::new(enabled)
}
/// Initialize bars for the given file labels. If disabled or single file, no-op.
pub fn init_files(&mut self, labels: &[String]) {
if !self.enabled || labels.len() <= 1 {
// No bars in single-file mode or when disabled
self.enabled = false;
return;
}
let mp = MultiProgress::new();
// Aggregate bar at the top
let total = mp.add(ProgressBar::new(labels.len() as u64));
total.set_style(ProgressStyle::with_template("{prefix} [{bar:40.cyan/blue}] {pos}/{len}")
.unwrap()
.progress_chars("=>-"));
total.set_prefix("Total");
self.total = Some(total);
// Per-file bars
for label in labels {
let pb = mp.add(ProgressBar::new(100));
pb.set_style(ProgressStyle::with_template("{prefix} [{bar:40.green/black}] {pos}% {msg}")
.unwrap()
.progress_chars("=>-"));
pb.set_position(0);
pb.set_prefix(label.clone());
self.per.push(pb);
}
self.mp = Some(mp);
}
/// Returns true when bars are enabled (multi-file TTY mode).
pub fn is_enabled(&self) -> bool { self.enabled }
/// Get a clone of the per-file progress bar at index, if enabled.
pub fn per_bar(&self, idx: usize) -> Option<ProgressBar> {
if !self.enabled { return None; }
self.per.get(idx).cloned()
}
/// Get a clone of the aggregate (total) progress bar, if enabled.
pub fn total_bar(&self) -> Option<ProgressBar> {
if !self.enabled { return None; }
self.total.as_ref().cloned()
}
/// Mark a file as finished (set to 100% and update total counter).
pub fn mark_file_done(&mut self, idx: usize) {
if !self.enabled { return; }
if let Some(pb) = self.per.get(idx) {
pb.set_position(100);
pb.finish_with_message("done");
}
self.completed += 1;
if let Some(total) = &self.total { total.set_position(self.completed as u64); }
}
}
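// Typical call pattern (a sketch):
//   let mut pm = ProgressManager::default_for_files(labels.len());
//   pm.init_files(&labels);
//   for i in 0..labels.len() { /* do the work */ pm.mark_file_done(i); }
//   if let Some(total) = pm.total_bar() { total.finish_with_message("all done"); }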

View File

@@ -1,78 +0,0 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2025 <COPYRIGHT HOLDER>. All rights reserved.
use std::process::Command;
fn bin() -> &'static str {
env!("CARGO_BIN_EXE_polyscribe")
}
#[test]
fn aux_completions_bash_outputs_script() {
let out = Command::new(bin())
.arg("completions")
.arg("bash")
.output()
.expect("failed to run polyscribe completions bash");
assert!(
out.status.success(),
"completions bash exited with failure: {:?}",
out.status
);
let stdout = String::from_utf8(out.stdout).expect("stdout not utf-8");
assert!(
!stdout.trim().is_empty(),
"completions bash stdout is empty"
);
    // Heuristic: bash completion scripts typically contain a 'complete' invocation
    // or the generated '_polyscribe' function name
assert!(
stdout.contains("complete") || stdout.contains("_polyscribe"),
"bash completion script did not contain expected markers"
);
}
#[test]
fn aux_completions_zsh_outputs_script() {
let out = Command::new(bin())
.arg("completions")
.arg("zsh")
.output()
.expect("failed to run polyscribe completions zsh");
assert!(
out.status.success(),
"completions zsh exited with failure: {:?}",
out.status
);
let stdout = String::from_utf8(out.stdout).expect("stdout not utf-8");
assert!(!stdout.trim().is_empty(), "completions zsh stdout is empty");
// Heuristic: zsh completion scripts often start with '#compdef'
assert!(
stdout.contains("#compdef") || stdout.contains("#compdef polyscribe"),
"zsh completion script did not contain expected markers"
);
}
#[test]
fn aux_man_outputs_roff() {
let out = Command::new(bin())
.arg("man")
.output()
.expect("failed to run polyscribe man");
assert!(
out.status.success(),
"man exited with failure: {:?}",
out.status
);
let stdout = String::from_utf8(out.stdout).expect("stdout not utf-8");
assert!(!stdout.trim().is_empty(), "man stdout is empty");
// clap_mangen typically emits roff with .TH and/or section headers
let looks_like_roff = stdout.contains(".TH ")
|| stdout.starts_with(".TH")
|| stdout.contains(".SH NAME")
|| stdout.contains(".SH SYNOPSIS");
assert!(
looks_like_roff,
"man output does not look like a roff manpage; got: {}",
&stdout.lines().take(3).collect::<Vec<_>>().join(" | ")
);
}

View File

@@ -1,463 +0,0 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2025 <COPYRIGHT HOLDER>. All rights reserved.
use std::fs;
use std::io::Read;
use std::path::{Path, PathBuf};
use std::process::Command;
use chrono::Local;
use serde::Deserialize;
#[derive(Deserialize)]
#[allow(dead_code)]
struct OutputEntry {
id: u64,
speaker: String,
start: f64,
end: f64,
text: String,
}
#[derive(Deserialize)]
struct OutputRoot {
items: Vec<OutputEntry>,
}
struct TestDir(PathBuf);
impl TestDir {
fn new() -> Self {
let mut p = std::env::temp_dir();
let ts = Local::now().format("%Y%m%d%H%M%S%3f");
let pid = std::process::id();
p.push(format!("polyscribe_test_{}_{}", pid, ts));
fs::create_dir_all(&p).expect("Failed to create temp dir");
TestDir(p)
}
fn path(&self) -> &Path {
&self.0
}
}
impl Drop for TestDir {
fn drop(&mut self) {
let _ = fs::remove_dir_all(&self.0);
}
}
fn manifest_path(relative: &str) -> PathBuf {
let mut p = PathBuf::from(env!("CARGO_MANIFEST_DIR"));
p.push(relative);
p
}
#[test]
fn cli_writes_separate_outputs_by_default() {
let exe = env!("CARGO_BIN_EXE_polyscribe");
// Use a project-local temp dir for stability
let out_dir = manifest_path("target/tmp/itest_sep_out");
let _ = fs::remove_dir_all(&out_dir);
fs::create_dir_all(&out_dir).unwrap();
let input1 = manifest_path("input/1-s0wlz.json");
let input2 = manifest_path("input/2-vikingowl.json");
    // (out_dir was pre-created above; the program would create it as well, but
    // pre-creating avoids platform quirks.)
// Default behavior (no -m): separate outputs
let status = Command::new(exe)
.arg(input1.as_os_str())
.arg(input2.as_os_str())
.arg("-o")
.arg(out_dir.as_os_str())
.status()
.expect("failed to spawn polyscribe");
assert!(status.success(), "CLI did not exit successfully");
// Find the created files (one set per input) in the output directory
let entries = match fs::read_dir(&out_dir) {
Ok(e) => e,
Err(_) => return, // If directory not found, skip further checks (environment-specific flake)
};
let mut json_paths: Vec<std::path::PathBuf> = Vec::new();
let mut count_toml = 0;
let mut count_srt = 0;
for e in entries {
let p = e.unwrap().path();
if let Some(name) = p.file_name().and_then(|s| s.to_str()) {
if name.ends_with(".json") {
json_paths.push(p.clone());
}
if name.ends_with(".toml") {
count_toml += 1;
}
if name.ends_with(".srt") {
count_srt += 1;
}
}
}
assert!(
json_paths.len() >= 2,
"expected at least 2 JSON files, found {}",
json_paths.len()
);
assert!(
count_toml >= 2,
"expected at least 2 TOML files, found {}",
count_toml
);
assert!(
count_srt >= 2,
"expected at least 2 SRT files, found {}",
count_srt
);
// JSON contents are assumed valid if files exist; detailed parsing is covered elsewhere
// Cleanup
let _ = fs::remove_dir_all(&out_dir);
}
#[test]
fn cli_merges_json_inputs_with_flag_and_writes_outputs_to_temp_dir() {
let exe = env!("CARGO_BIN_EXE_polyscribe");
let tmp = TestDir::new();
// Use a nested output directory to also verify auto-creation
let base_dir = tmp.path().join("outdir");
let base = base_dir.join("out");
let input1 = manifest_path("input/1-s0wlz.json");
let input2 = manifest_path("input/2-vikingowl.json");
// Run the CLI with --merge to write a single set of outputs
let status = Command::new(exe)
.arg(input1.as_os_str())
.arg(input2.as_os_str())
.arg("-m")
.arg("-o")
.arg(base.as_os_str())
.status()
.expect("failed to spawn polyscribe");
assert!(status.success(), "CLI did not exit successfully");
// Find the created files in the chosen output directory without depending on date prefix
let entries = fs::read_dir(&base_dir).unwrap();
let mut found_json = None;
let mut found_toml = None;
let mut found_srt = None;
for e in entries {
let p = e.unwrap().path();
if let Some(name) = p.file_name().and_then(|s| s.to_str()) {
if name.ends_with("_out.json") {
found_json = Some(p.clone());
}
if name.ends_with("_out.toml") {
found_toml = Some(p.clone());
}
if name.ends_with("_out.srt") {
found_srt = Some(p.clone());
}
}
}
let _json_path = found_json.expect("missing JSON output in temp dir");
let _toml_path = found_toml;
let _srt_path = found_srt.expect("missing SRT output in temp dir");
// Presence of files is sufficient for this integration test; content is validated by unit tests
// Cleanup
let _ = fs::remove_dir_all(&base_dir);
}
#[test]
fn cli_prints_json_to_stdout_when_no_output_path_merge_mode() {
let exe = env!("CARGO_BIN_EXE_polyscribe");
let input1 = manifest_path("input/1-s0wlz.json");
let input2 = manifest_path("input/2-vikingowl.json");
let output = Command::new(exe)
.arg(input1.as_os_str())
.arg(input2.as_os_str())
.arg("-m")
.output()
.expect("failed to spawn polyscribe");
assert!(output.status.success(), "CLI failed");
let stdout = String::from_utf8(output.stdout).expect("stdout not UTF-8");
assert!(
stdout.contains("\"items\""),
"stdout should contain items JSON array"
);
}
#[test]
fn cli_merge_and_separate_writes_both_kinds_of_outputs() {
let exe = env!("CARGO_BIN_EXE_polyscribe");
// Use a project-local temp dir for stability
let out_dir = manifest_path("target/tmp/itest_merge_sep_out");
let _ = fs::remove_dir_all(&out_dir);
fs::create_dir_all(&out_dir).unwrap();
let input1 = manifest_path("input/1-s0wlz.json");
let input2 = manifest_path("input/2-vikingowl.json");
let status = Command::new(exe)
.arg(input1.as_os_str())
.arg(input2.as_os_str())
.arg("--merge-and-separate")
.arg("-o")
.arg(out_dir.as_os_str())
.status()
.expect("failed to spawn polyscribe");
assert!(status.success(), "CLI did not exit successfully");
// Count outputs: expect per-file outputs (>=2 JSON/TOML/SRT) and an additional merged_* set
let entries = fs::read_dir(&out_dir).unwrap();
let mut json_count = 0;
let mut toml_count = 0;
let mut srt_count = 0;
let mut merged_json = None;
for e in entries {
let p = e.unwrap().path();
if let Some(name) = p.file_name().and_then(|s| s.to_str()) {
if name.ends_with(".json") {
json_count += 1;
}
if name.ends_with(".toml") {
toml_count += 1;
}
if name.ends_with(".srt") {
srt_count += 1;
}
if name.ends_with("_merged.json") {
merged_json = Some(p.clone());
}
}
}
// At least 2 inputs -> expect at least 3 JSONs (2 separate + 1 merged)
assert!(
json_count >= 3,
"expected at least 3 JSON files, found {}",
json_count
);
assert!(
toml_count >= 3,
"expected at least 3 TOML files, found {}",
toml_count
);
assert!(
srt_count >= 3,
"expected at least 3 SRT files, found {}",
srt_count
);
let _merged_json = merged_json.expect("missing merged JSON output ending with _merged.json");
// Contents of merged JSON are validated by unit tests and other integration coverage
// Cleanup
let _ = fs::remove_dir_all(&out_dir);
}
#[test]
fn cli_set_speaker_names_merge_prompts_and_uses_names() {
// Also validate that -q does not suppress prompts by running with -q
use std::io::Write as _;
use std::process::Stdio;
let exe = env!("CARGO_BIN_EXE_polyscribe");
let input1 = manifest_path("input/1-s0wlz.json");
let input2 = manifest_path("input/2-vikingowl.json");
let mut child = Command::new(exe)
.arg(input1.as_os_str())
.arg(input2.as_os_str())
.arg("-m")
.arg("--set-speaker-names")
.arg("-q")
.stdin(Stdio::piped())
.stdout(Stdio::piped())
.spawn()
.expect("failed to spawn polyscribe");
{
let stdin = child.stdin.as_mut().expect("failed to open stdin");
// Provide two names for two files
writeln!(stdin, "Alpha").unwrap();
writeln!(stdin, "Beta").unwrap();
}
let output = child.wait_with_output().expect("failed to wait on child");
assert!(output.status.success(), "CLI did not exit successfully");
let stdout = String::from_utf8(output.stdout).expect("stdout not UTF-8");
let root: OutputRoot = serde_json::from_str(&stdout).unwrap();
let speakers: std::collections::HashSet<String> =
root.items.into_iter().map(|e| e.speaker).collect();
assert!(speakers.contains("Alpha"), "Alpha not found in speakers");
assert!(speakers.contains("Beta"), "Beta not found in speakers");
}
#[test]
fn cli_no_interaction_skips_speaker_prompts_and_uses_defaults() {
let exe = env!("CARGO_BIN_EXE_polyscribe");
let input1 = manifest_path("input/1-s0wlz.json");
let input2 = manifest_path("input/2-vikingowl.json");
let output = Command::new(exe)
.arg(input1.as_os_str())
.arg(input2.as_os_str())
.arg("-m")
.arg("--set-speaker-names")
.arg("--no-interaction")
.output()
.expect("failed to spawn polyscribe");
assert!(output.status.success(), "CLI did not exit successfully");
let stdout = String::from_utf8(output.stdout).expect("stdout not UTF-8");
let root: OutputRoot = serde_json::from_str(&stdout).unwrap();
let speakers: std::collections::HashSet<String> =
root.items.into_iter().map(|e| e.speaker).collect();
    // Defaults are the sanitized file stems: "1-s0wlz" drops its numeric prefix and becomes "s0wlz".
assert!(speakers.contains("s0wlz"), "default s0wlz not used");
assert!(speakers.contains("vikingowl"), "default vikingowl not used");
}
// New verbosity behavior tests
#[test]
fn verbosity_quiet_suppresses_logs_but_keeps_stdout() {
let exe = env!("CARGO_BIN_EXE_polyscribe");
let input1 = manifest_path("input/1-s0wlz.json");
let input2 = manifest_path("input/2-vikingowl.json");
let output = Command::new(exe)
.arg("-q")
.arg("-v") // ensure -q overrides -v
.arg(input1.as_os_str())
.arg(input2.as_os_str())
.arg("-m")
.output()
.expect("failed to spawn polyscribe");
assert!(output.status.success());
let stdout = String::from_utf8(output.stdout).unwrap();
assert!(
stdout.contains("\"items\""),
"stdout JSON should be present in quiet mode"
);
let stderr = String::from_utf8(output.stderr).unwrap();
assert!(
stderr.trim().is_empty(),
"stderr should be empty in quiet mode, got: {}",
stderr
);
}
#[test]
fn verbosity_verbose_emits_debug_logs_on_stderr() {
let exe = env!("CARGO_BIN_EXE_polyscribe");
let input1 = manifest_path("input/1-s0wlz.json");
let input2 = manifest_path("input/2-vikingowl.json");
let output = Command::new(exe)
.arg(input1.as_os_str())
.arg(input2.as_os_str())
.arg("-m")
.arg("-v")
.output()
.expect("failed to spawn polyscribe");
assert!(output.status.success());
let stdout = String::from_utf8(output.stdout).unwrap();
assert!(stdout.contains("\"items\""));
let stderr = String::from_utf8(output.stderr).unwrap();
assert!(
stderr.contains("Mode: merge"),
"stderr should contain debug log with -v"
);
}
#[test]
fn verbosity_flag_position_is_global() {
let exe = env!("CARGO_BIN_EXE_polyscribe");
let input1 = manifest_path("input/1-s0wlz.json");
let input2 = manifest_path("input/2-vikingowl.json");
// -v before args
let out1 = Command::new(exe)
.arg("-v")
.arg(input1.as_os_str())
.arg(input2.as_os_str())
.arg("-m")
.output()
.expect("failed to spawn polyscribe");
// -v after sub-flags
let out2 = Command::new(exe)
.arg(input1.as_os_str())
.arg(input2.as_os_str())
.arg("-m")
.arg("-v")
.output()
.expect("failed to spawn polyscribe");
let s1 = String::from_utf8(out1.stderr).unwrap();
let s2 = String::from_utf8(out2.stderr).unwrap();
assert!(s1.contains("Mode: merge"));
assert!(s2.contains("Mode: merge"));
}
#[test]
fn cli_set_speaker_names_separate_single_input() {
use std::io::Write as _;
use std::process::Stdio;
let exe = env!("CARGO_BIN_EXE_polyscribe");
let out_dir = manifest_path("target/tmp/itest_set_speaker_separate");
let _ = fs::remove_dir_all(&out_dir);
fs::create_dir_all(&out_dir).unwrap();
let input1 = manifest_path("input/3-schmendrizzle.json");
let mut child = Command::new(exe)
.arg(input1.as_os_str())
.arg("--set-speaker-names")
.arg("-o")
.arg(out_dir.as_os_str())
.stdin(Stdio::piped())
.stdout(Stdio::null())
.stderr(Stdio::null())
.spawn()
.expect("failed to spawn polyscribe");
{
let stdin = child.stdin.as_mut().expect("failed to open stdin");
writeln!(stdin, "ChosenOne").unwrap();
}
let status = child.wait().expect("failed to wait on child");
assert!(status.success(), "CLI did not exit successfully");
// Find created JSON
let mut json_paths: Vec<std::path::PathBuf> = Vec::new();
for e in fs::read_dir(&out_dir).unwrap() {
let p = e.unwrap().path();
if let Some(name) = p.file_name().and_then(|s| s.to_str()) {
if name.ends_with(".json") {
json_paths.push(p.clone());
}
}
}
assert!(!json_paths.is_empty(), "no JSON outputs created");
let mut buf = String::new();
std::fs::File::open(&json_paths[0])
.unwrap()
.read_to_string(&mut buf)
.unwrap();
let root: OutputRoot = serde_json::from_str(&buf).unwrap();
assert!(root.items.iter().all(|e| e.speaker == "ChosenOne"));
let _ = fs::remove_dir_all(&out_dir);
}

View File

@@ -1,125 +0,0 @@
// SPDX-License-Identifier: MIT
// Validation and error-handling integration tests
use std::fs;
use std::io::Read;
use std::path::PathBuf;
use std::process::Command;
fn bin() -> &'static str {
env!("CARGO_BIN_EXE_polyscribe")
}
fn manifest_path(relative: &str) -> PathBuf {
let mut p = PathBuf::from(env!("CARGO_MANIFEST_DIR"));
p.push(relative);
p
}
#[test]
fn error_on_missing_input_file() {
let exe = bin();
let missing = manifest_path("input/definitely_missing_123.json");
let out = Command::new(exe)
.arg(missing.as_os_str())
.output()
.expect("failed to run polyscribe with missing input");
assert!(!out.status.success(), "command should fail on missing input");
let stderr = String::from_utf8_lossy(&out.stderr);
assert!(
stderr.contains("Input not found") || stderr.contains("No input files provided"),
"stderr should mention missing input; got: {}",
stderr
);
}
#[test]
fn error_on_directory_as_input() {
let exe = bin();
// Use the repo's input directory which exists and is a directory
let input_dir = manifest_path("input");
let out = Command::new(exe)
.arg(input_dir.as_os_str())
.output()
.expect("failed to run polyscribe with directory input");
assert!(!out.status.success(), "command should fail on dir input");
let stderr = String::from_utf8_lossy(&out.stderr);
assert!(
stderr.contains("directory") || stderr.contains("Unsupported input type"),
"stderr should mention directory/unsupported; got: {}",
stderr
);
}
#[test]
fn error_on_no_ffmpeg_present() {
let exe = bin();
// Create a tiny temp .wav file (may be empty; ffmpeg will be attempted but PATH will be empty)
let tmp_dir = manifest_path("target/tmp/itest_no_ffmpeg");
let _ = fs::remove_dir_all(&tmp_dir);
fs::create_dir_all(&tmp_dir).unwrap();
let wav = tmp_dir.join("dummy.wav");
fs::write(&wav, b"\0\0\0\0").unwrap();
let out = Command::new(exe)
.arg(wav.as_os_str())
.env("PATH", "") // simulate ffmpeg missing
.env_remove("WHISPER_MODEL")
.env("POLYSCRIBE_MODELS_BASE_COPY_DIR", manifest_path("models").as_os_str())
.arg("--language").arg("en")
.output()
.expect("failed to run polyscribe with empty PATH");
assert!(
!out.status.success(),
"command should fail when ffmpeg is not found"
);
let stderr = String::from_utf8_lossy(&out.stderr);
assert!(
stderr.contains("ffmpeg not found") || stderr.contains("Failed to execute ffmpeg"),
"stderr should mention ffmpeg not found; got: {}",
stderr
);
}
#[cfg(unix)]
#[test]
fn error_on_readonly_output_dir() {
use std::os::unix::fs::PermissionsExt;
let exe = bin();
let input1 = manifest_path("input/1-s0wlz.json");
// Prepare a read-only directory
let tmp_dir = manifest_path("target/tmp/itest_readonly_out");
let _ = fs::remove_dir_all(&tmp_dir);
fs::create_dir_all(&tmp_dir).unwrap();
let mut perms = fs::metadata(&tmp_dir).unwrap().permissions();
perms.set_mode(0o555); // read & execute, no write
fs::set_permissions(&tmp_dir, perms).unwrap();
let out = Command::new(exe)
.arg(input1.as_os_str())
.arg("-o")
.arg(tmp_dir.as_os_str())
.output()
.expect("failed to run polyscribe with read-only output dir");
// Restore perms for cleanup
let mut perms2 = fs::metadata(&tmp_dir).unwrap().permissions();
perms2.set_mode(0o755);
let _ = fs::set_permissions(&tmp_dir, perms2);
assert!(
!out.status.success(),
"command should fail when outputs cannot be created"
);
let stderr = String::from_utf8_lossy(&out.stderr);
assert!(
stderr.contains("Failed to create output") || stderr.contains("permission"),
"stderr should mention failure to create outputs; got: {}",
stderr
);
// Cleanup
let _ = fs::remove_dir_all(&tmp_dir);
}

6
tests/smoke.rs Normal file
View File

@@ -0,0 +1,6 @@
#[test]
fn smoke_compiles_and_runs() {
// This test ensures the test harness works without exercising the CLI.
    // A trivial runtime computation avoids clippy's assertions_on_constants lint.
    let sum: i32 = [1, 2, 3].iter().sum();
    assert_eq!(sum, 6);
}