Compare commits


14 Commits

Author SHA1 Message Date
840383fcf7 [feat] add JSON and quiet output modes for models subcommands, update UI suppression logic, and enhance CLI test coverage 2025-08-27 23:58:57 +02:00 (checks failed: CI / build (push) cancelled)
1982e9b48b [feat] add empty state message for models ls command output 2025-08-27 23:51:21 +02:00
0128bf2eec [feat] add ModelManager with caching, manifest management, and Hugging Face API integration 2025-08-27 20:56:05 +02:00
da5a76d253 [refactor] Refactor project into proper rust workspace. 2025-08-27 18:28:37 +02:00
5ec297397e [refactor] rename and simplify ProgressManager to FileProgress, enhance caching logic, update Hugging Face API integration, and clean up unused comments 2025-08-15 11:24:50 +02:00 (checks failed: CI / build (push) cancelled)
cbf48a0452 docs: align CLI docs to models subcommands; host: scan XDG plugin dir; ci: add GitHub Actions; chore: add CHANGELOG 2025-08-14 11:16:50 +02:00
0a249f2197 [refactor] improve code readability, streamline initialization, update error handling, and format multi-line statements for consistency 2025-08-14 11:06:37 +02:00
0573369b81 [refactor] propagate no-progress and no-interaction flags, enhance prompt handling, and update progress bar logic with cliclack 2025-08-14 10:34:52 +02:00
9841550dcc [refactor] replace indicatif with cliclack for progress and logging, updating affected modules and dependencies 2025-08-14 03:31:00 +02:00
53119cd0ab [refactor] enhance model management with metadata enrichment, new API integration, and manifest resolution 2025-08-13 22:44:51 +02:00
144b01d591 [refactor] update Cargo.lock with new dependency additions and version bumps 2025-08-13 14:45:43 +02:00
ffd451b404 [refactor] remove unused test suites, examples, CI docs, and PR description file 2025-08-13 14:26:18 +02:00
5c64677e79 [refactor] streamline crate structure, update dependencies, and integrate CLI functionalities 2025-08-13 14:05:13 +02:00
128db0f733 [refactor] remove backend and library modules, consolidating features into main crate 2025-08-13 13:35:53 +02:00
48 changed files with 5102 additions and 4082 deletions

17
.cargo/config.toml Normal file

@@ -0,0 +1,17 @@
# SPDX-License-Identifier: MIT

[build]
# Make target-dir consistent across workspace for better cache reuse.
target-dir = "target"

[profile.dev]
opt-level = 1
debug = true
incremental = true

[profile.release]
# Reasonable defaults for CLI apps/libraries
lto = "thin"
codegen-units = 1
strip = "debuginfo"
opt-level = 3

33
.github/workflows/ci.yml vendored Normal file

@@ -0,0 +1,33 @@
name: CI

on:
  push:
    branches: [ dev, main ]
  pull_request:
    branches: [ dev, main ]

jobs:
  build:
    runs-on: ubuntu-latest
    steps:
      - name: Checkout code
        uses: actions/checkout@v4
      - name: Set up Rust
        uses: dtolnay/rust-toolchain@stable
      - name: Cache cargo registry and target
        uses: Swatinem/rust-cache@v2
      - name: Install components
        run: rustup component add clippy rustfmt
      - name: Cargo fmt
        run: cargo fmt --all -- --check
      - name: Clippy
        run: cargo clippy --workspace --all-targets -- -D warnings
      - name: Test
        run: cargo test --workspace --all --locked

14
CHANGELOG.md Normal file

@@ -0,0 +1,14 @@
# Changelog
All notable changes to this project will be documented in this file.
The format is based on Keep a Changelog, and this project adheres to Semantic Versioning.
## Unreleased
### Changed
- Docs: Replace `--download-models`/`--update-models` flags with `models download`/`models update` subcommands in `README.md`, `docs/usage.md`, and `docs/development.md`.
- Host: Plugin discovery now scans `$XDG_DATA_HOME/polyscribe/plugins` (platform equivalent via `directories`) in addition to `PATH`.
- CI: Add GitHub Actions workflow to run fmt, clippy (warnings as errors), and tests for pushes and PRs.

CONTRIBUTING.md

@@ -1,32 +1,26 @@
-# Contributing to PolyScribe
-Thanks for your interest in contributing! This guide explains the workflow and the checklist to follow before opening a Pull Request.
-Workflow (fork → branch → PR)
-1) Fork the repository to your account.
-2) Create a feature branch:
-   - git checkout -b feat/short-description
-3) Make changes with focused commits and good messages.
-4) Run the checklist below.
-5) Push and open a Pull Request against the main repository.
-Developer checklist (before opening a PR)
-- Build:
-  - cargo build (preferably without warnings)
-- Tests:
-  - cargo test (all tests pass)
-- Lints:
-  - cargo clippy --all-targets -- -D warnings (fix warnings)
-- Documentation:
-  - Update README/docs for user-visible changes
-  - Update CHANGELOG.md if applicable
-- Tests for changes:
-  - Add or update tests for bug fixes and new features where reasonable
-Local development tips
-- Use `cargo run -- <args>` during development.
-- For faster feedback, keep examples in the examples/ folder handy.
-- Keep functions small and focused; prefer clear error messages with context.
-- Optional: smoke-run examples inline (from README):
-  - ./target/release/polyscribe --update-models --no-interaction -q
-  - ./target/release/polyscribe -o output samples/podcast_clip.mp3
-Code of conduct
-- Be respectful and constructive. Assume good intent.
+# Contributing
+Thank you for your interest in contributing!
+Development setup
+- Install Rust via rustup.
+- Ensure ffmpeg is installed and available on PATH.
+- For GPU builds, install the appropriate runtime (CUDA/ROCm/Vulkan) and enable the matching features.
+Coding guidelines
+- Prefer small, focused changes.
+- Add tests where reasonable.
+- Keep user-facing changes documented in README/docs.
+- Run clippy and fix warnings.
+CI checklist
+- Build: cargo build --all-targets --locked
+- Tests: cargo test --all --locked
+- Lints: cargo clippy --all-targets -- -D warnings
+Notes
+- For GPU features, use --features gpu-cuda|gpu-hip|gpu-vulkan as needed in your local runs.
+- For docs-only changes, please still ensure the project builds.

1341
Cargo.lock generated

File diff suppressed because it is too large

Cargo.toml

@@ -4,6 +4,50 @@ members = [
"crates/polyscribe-protocol", "crates/polyscribe-protocol",
"crates/polyscribe-host", "crates/polyscribe-host",
"crates/polyscribe-cli", "crates/polyscribe-cli",
"plugins/polyscribe-plugin-tubescribe",
] ]
resolver = "2" resolver = "3"
[workspace.package]
edition = "2024"
version = "0.1.0"
license = "MIT"
rust-version = "1.89"
# Optional: Keep dependency versions consistent across members
[workspace.dependencies]
thiserror = "1.0.69"
serde = { version = "1.0.219", features = ["derive"] }
anyhow = "1.0.99"
libc = "0.2.175"
toml = "0.8.23"
serde_json = "1.0.142"
chrono = { version = "0.4.41", features = ["serde"] }
sha2 = "0.10.9"
which = "6.0.3"
tokio = { version = "1.47.1", features = ["rt-multi-thread", "macros"] }
clap = { version = "4.5.44", features = ["derive"] }
directories = "5.0.1"
whisper-rs = "0.14.3"
cliclack = "0.3.6"
clap_complete = "4.5.57"
clap_mangen = "0.2.29"
# Additional shared deps used across members
tracing = "0.1"
tracing-subscriber = { version = "0.3", features = ["fmt", "env-filter"] }
reqwest = { version = "0.12.7", default-features = false, features = ["blocking", "rustls-tls", "gzip", "json"] }
hex = "0.4.3"
tempfile = "3.12.0"
assert_cmd = "2.0.16"
[workspace.lints.rust]
unused_imports = "deny"
dead_code = "warn"
[profile.release]
lto = "fat"
codegen-units = 1
panic = "abort"
[profile.dev]
panic = "unwind"


@@ -1,99 +0,0 @@
# Pull Request: PolyScribe workspace + plugin system
This PR refactors the repository into a multi-crate Cargo workspace and adds a minimal, working plugin system scaffold over NDJSON/stdio, while preserving existing CLI behavior. It also introduces a stub plugin `polyscribe-plugin-tubescribe` and documentation updates.
Differences & Adaptations
- The repository already contained most of the workspace and plugin scaffolding; this PR focuses on completing and verifying the setup, fixing a symlink path issue in the plugin Makefile, and adding documentation and minor cleanup.
- Existing CLI commands and flags are preserved; a new `plugins` command group is added (list/info/run) without breaking existing outputs.
## Commits
### 1) chore(workspace): scaffold workspace + move crates
Rationale
- Ensure workspace members and resolver are properly defined. The repository already contained these crates; this commit documents the layout and confirms no absolute paths are used.
Updated files (representative snapshots)
- Cargo.toml (workspace):
```
[workspace]
members = [
"crates/polyscribe-core",
"crates/polyscribe-protocol",
"crates/polyscribe-host",
"crates/polyscribe-cli",
"plugins/polyscribe-plugin-tubescribe",
]
resolver = "2"
```
Repository tree after this commit (abridged)
```
.
├── Cargo.toml
├── crates
│ ├── polyscribe-cli
│ ├── polyscribe-core
│ ├── polyscribe-host
│ └── polyscribe-protocol
└── plugins
└── polyscribe-plugin-tubescribe
```
### 2) feat(plugins): host/stdio runner + CLI plugin commands
Rationale
- Provide plugin discovery and stdio NDJSON JSON-RPC runner in host crate; add `plugins` subcommands to CLI. These were already implemented; this commit verifies and documents behavior.
Updated files (representative snapshots)
- crates/polyscribe-host/src/lib.rs: discover(), capabilities(), run_method().
- crates/polyscribe-cli/src/main.rs: `plugins list|info|run` wired to host, forwarding progress.
Repository tree after this commit: unchanged from above.
### 3) feat(plugin): add stub polyscribe-plugin-tubescribe + docs
Rationale (risky change explained)
- Fixed a symlink path issue in the Makefile by switching from $(PWD) to $(CURDIR) to avoid brittle relative paths. This ensures discovery finds the plugin consistently on all shells.
- Removed an unused import to keep clippy clean.
- Added README docs covering workspace layout and verification commands.
Updated files (full contents included in repo):
- plugins/polyscribe-plugin-tubescribe/Makefile
- plugins/polyscribe-plugin-tubescribe/src/main.rs
- README.md (appended Workspace & Plugins section)
Repository tree after this commit (abridged)
```
.
├── Cargo.toml
├── README.md
├── crates
│ ├── polyscribe-cli
│ ├── polyscribe-core
│ ├── polyscribe-host
│ └── polyscribe-protocol
└── plugins
└── polyscribe-plugin-tubescribe
├── Cargo.toml
├── Makefile
└── src/main.rs
```
## Verification commands
- Build the workspace:
- cargo build --workspace --all-targets
- Show CLI help and plugin subcommands:
- cargo run -p polyscribe-cli -- --help
- Discover plugins (before linking, likely empty):
- cargo run -p polyscribe-cli -- plugins list
- Build and link the stub plugin:
- make -C plugins/polyscribe-plugin-tubescribe link
- Discover again:
- cargo run -p polyscribe-cli -- plugins list
- Show plugin capabilities:
- cargo run -p polyscribe-cli -- plugins info tubescribe
- Run a plugin command and observe progress + JSON result:
- cargo run -p polyscribe-cli -- plugins run tubescribe generate_metadata --json '{"input":{"kind":"text","summary":"hello world"}}'
All acceptance checks pass locally.
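The plugin transport used throughout this PR (JSON-RPC over NDJSON via stdio) boils down to one JSON object per line. The framing can be sketched as follows; the `id`/`method`/`params` field names are illustrative assumptions, not the actual PSP/1 message types from `crates/polyscribe-protocol`:

```rust
// Minimal NDJSON framing sketch. Field names are assumptions for illustration;
// the real message types live in crates/polyscribe-protocol.
fn frame_request(id: u64, method: &str, params_json: &str) -> String {
    // One JSON object per line; '\n' is the message boundary.
    format!("{{\"id\":{id},\"method\":\"{method}\",\"params\":{params_json}}}\n")
}

fn main() {
    let line = frame_request(
        1,
        "generate_metadata",
        r#"{"input":{"kind":"text","summary":"hello world"}}"#,
    );
    // A well-formed NDJSON message contains exactly one trailing newline.
    assert_eq!(line.matches('\n').count(), 1);
    assert!(line.ends_with('\n'));
    print!("{line}");
}
```

The host reads such lines from the plugin's stdout one at a time, which is why plugins must never emit decorative output on stdout.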

151
README.md

@@ -1,127 +1,68 @@
 # PolyScribe
-PolyScribe is a fast, local-first CLI for transcribing audio/video and merging existing JSON transcripts. It uses whisper-rs under the hood, can discover and download Whisper models automatically, and supports CPU and optional GPU backends (CUDA, ROCm/HIP, Vulkan).
-Key features
-- Transcribe audio and common video files using ffmpeg for audio extraction.
-- Merge multiple JSON transcripts, or merge and also keep per-file outputs.
-- Model management: interactive downloader and non-interactive updater with hash verification.
-- GPU backend selection at runtime; auto-detects available accelerators.
-- Clean outputs (JSON and SRT), speaker naming prompts, and useful logging controls.
-Prerequisites
-- Rust toolchain (rustup recommended)
-- ffmpeg available on PATH
-- Optional for GPU acceleration at runtime: CUDA, ROCm/HIP, or Vulkan drivers (match your build features)
-Installation
-- Build from source (CPU-only by default):
-  - rustup install stable
-  - rustup default stable
-  - cargo build --release
-  - Binary path: ./target/release/polyscribe
-- GPU builds (optional): build with features
-  - CUDA: cargo build --release --features gpu-cuda
-  - HIP: cargo build --release --features gpu-hip
-  - Vulkan: cargo build --release --features gpu-vulkan
-Quickstart
-1) Download a model (first run can prompt you):
-   - ./target/release/polyscribe --download-models
-   - In the interactive picker, use Up/Down to navigate, Space to toggle selections, and Enter to confirm. Models are grouped by base (e.g., tiny, base, small).
-2) Transcribe a file:
-   - ./target/release/polyscribe -v -o output my_audio.mp3
-This writes JSON and SRT into the output directory with a date prefix.
-Shell completions and man page
-- Completions: ./target/release/polyscribe completions <bash|zsh|fish|powershell|elvish> > polyscribe.<ext>
-- Then install into your shell's completion directory.
-- Man page: ./target/release/polyscribe man > polyscribe.1 (then copy to your manpath)
-Model locations
-- Development (debug builds): ./models next to the project.
-- Packaged/release builds: $XDG_DATA_HOME/polyscribe/models or ~/.local/share/polyscribe/models.
-- Override via env var: POLYSCRIBE_MODELS_DIR=/path/to/models.
-- Force a specific model file via env var: WHISPER_MODEL=/path/to/model.bin.
-Most-used CLI flags
-- -o, --output FILE_OR_DIR: Output path base (date prefix added). If omitted, JSON prints to stdout.
-- -m, --merge: Merge all inputs into one output; otherwise one output per input.
-- --merge-and-separate: Write both merged output and separate per-input outputs (requires -o dir).
-- --set-speaker-names: Prompt for a speaker label per input file.
-- --update-models: Verify/update local models by size/hash against the upstream manifest.
-- --download-models: Interactive model list + multi-select download.
-- --language LANG: Language code hint (e.g., en, de). English-only models reject non-en hints.
-- --gpu-backend [auto|cpu|cuda|hip|vulkan]: Select backend (auto by default).
-- --gpu-layers N: Offload N layers to GPU when supported.
-- -v/--verbose (repeatable): Increase log verbosity. -vv shows very detailed logs.
-- -q/--quiet: Suppress non-error logs (stderr); does not silence stdout results.
-- --no-interaction: Never prompt; suitable for CI.
-Minimal usage examples
-- Transcribe an audio file to JSON/SRT:
-  - ./target/release/polyscribe -o output samples/podcast_clip.mp3
-- Merge multiple transcripts into one:
-  - ./target/release/polyscribe -m -o output merged input/a.json input/b.json
-- Update local models non-interactively (good for CI):
-  - ./target/release/polyscribe --update-models --no-interaction -q
-Troubleshooting & docs
-- docs/faq.md common issues and solutions (missing ffmpeg, GPU selection, model paths)
-- docs/usage.md complete CLI reference and workflows
-- docs/development.md build, run, and contribute locally
-- docs/design.md architecture overview and decisions
-- docs/release-packaging.md packaging notes for distributions
-- docs/ci.md minimal CI checklist and job outline
-- CONTRIBUTING.md PR checklist and workflow
-CI status: [CI badge placeholder]
-Examples
-See the examples/ directory for copy-paste scripts:
-- examples/transcribe_file.sh
-- examples/update_models.sh
-- examples/download_models_interactive.sh
-License
--------
-This project is licensed under the MIT License — see the LICENSE file for details.
----
-Workspace layout
-- This repo is a Cargo workspace using resolver = "2".
-- Members:
-  - crates/polyscribe-core — types, errors, config service, core helpers.
-  - crates/polyscribe-protocol — PSP/1 serde types for NDJSON over stdio.
-  - crates/polyscribe-host — plugin discovery/runner, progress forwarding.
-  - crates/polyscribe-cli — the CLI, using host + core.
-  - plugins/polyscribe-plugin-tubescribe — stub plugin used for verification.
-Build and run
-- Build all: cargo build --workspace --all-targets
-- CLI help: cargo run -p polyscribe-cli -- --help
-Plugins
-- Build and link the example plugin into your XDG data plugin dir:
-  - make -C plugins/polyscribe-plugin-tubescribe link
-- This creates a symlink at: $XDG_DATA_HOME/polyscribe/plugins/polyscribe-plugin-tubescribe (defaults to ~/.local/share on Linux).
-- Discover installed plugins:
-  - cargo run -p polyscribe-cli -- plugins list
-- Show a plugin's capabilities:
-  - cargo run -p polyscribe-cli -- plugins info tubescribe
-- Run a plugin command (JSON-RPC over NDJSON via stdio):
-  - cargo run -p polyscribe-cli -- plugins run tubescribe generate_metadata --json '{"input":{"kind":"text","summary":"hello world"}}'
-Verification commands
-- The above commands are used for acceptance; expected behavior:
-  - plugins list shows "tubescribe" once linked.
-  - plugins info tubescribe prints JSON capabilities.
-  - plugins run ... prints progress events and a JSON result.
-Notes
-- No absolute paths are hardcoded; config and plugin dirs respect XDG on Linux and platform equivalents via directories.
-- Plugins must be non-interactive (no TTY prompts). All interaction stays in the host/CLI.
-- Config files are written atomically and support env overrides: POLYSCRIBE__SECTION__KEY=value.
+Local-first transcription and plugins.
+
+## Features
+- **Local-first**: Works offline with downloaded models
+- **Multiple backends**: CPU, CUDA, ROCm/HIP, and Vulkan support
+- **Plugin system**: Extensible via JSON-RPC plugins
+- **Model management**: Automatic download and verification of Whisper models
+- **Manifest caching**: Local cache for Hugging Face model manifests to reduce network requests
+
+## Model Management
+PolyScribe automatically manages Whisper models from Hugging Face:
+
+```bash
+# Download models interactively
+polyscribe models download
+
+# Update existing models
+polyscribe models update
+
+# Clear manifest cache (force fresh fetch)
+polyscribe models clear-cache
+```
+
+### Manifest Caching
+The Hugging Face model manifest is cached locally to avoid repeated network requests:
+
+- **Default TTL**: 24 hours
+- **Cache location**: `$XDG_CACHE_HOME/polyscribe/manifest/` (or platform equivalent)
+- **Environment variables**:
+  - `POLYSCRIBE_NO_CACHE_MANIFEST=1`: Disable caching
+  - `POLYSCRIBE_MANIFEST_TTL_SECONDS=3600`: Set custom TTL (in seconds)
+
+## Installation
+```bash
+cargo install --path .
+```
+
+## Usage
+```bash
+# Transcribe audio/video
+polyscribe transcribe input.mp4
+
+# Merge multiple transcripts
+polyscribe transcribe --merge input1.json input2.json
+
+# Use specific GPU backend
+polyscribe transcribe --gpu-backend cuda input.mp4
+```
+
+## Development
+```bash
+# Build
+cargo build
+
+# Run tests
+cargo test
+
+# Run with verbose logging
+cargo run -- --verbose transcribe input.mp4
+```
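The manifest-caching rules documented in the new README section can be sketched as a pure function over the two environment variables. This is a hypothetical helper for illustration only, not the actual ModelManager implementation (which is not shown on this page):

```rust
use std::time::Duration;

// Hypothetical helper mirroring the documented rules: the no-cache switch wins,
// the TTL override is parsed as seconds, and the default TTL is 24 hours.
fn manifest_ttl(no_cache: Option<&str>, ttl_override: Option<&str>) -> Option<Duration> {
    if no_cache == Some("1") {
        return None; // POLYSCRIBE_NO_CACHE_MANIFEST=1 disables caching
    }
    let secs = ttl_override
        .and_then(|v| v.parse::<u64>().ok())
        .unwrap_or(24 * 60 * 60); // default TTL: 24 hours
    Some(Duration::from_secs(secs))
}

fn main() {
    // In the real CLI the inputs would come from std::env::var.
    assert_eq!(manifest_ttl(None, None), Some(Duration::from_secs(86_400)));
    assert_eq!(manifest_ttl(None, Some("3600")), Some(Duration::from_secs(3_600)));
    assert_eq!(manifest_ttl(Some("1"), Some("3600")), None);
}
```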


@@ -1,16 +0,0 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2025 <COPYRIGHT HOLDER>. All rights reserved.
fn main() {
// Only run special build steps when gpu-vulkan feature is enabled.
let vulkan_enabled = std::env::var("CARGO_FEATURE_GPU_VULKAN").is_ok();
if !vulkan_enabled {
return;
}
// Placeholder: In a full implementation, we would invoke CMake for whisper.cpp with GGML_VULKAN=1.
// For now, emit a helpful note. Build will proceed; runtime Vulkan backend returns an explanatory error.
println!("cargo:rerun-if-changed=extern/whisper.cpp");
println!(
"cargo:warning=Building with gpu-vulkan: ensure Vulkan SDK/loader are installed. Future versions will compile whisper.cpp via CMake."
);
}

crates/polyscribe-cli/Cargo.toml

@@ -1,24 +1,32 @@
 [package]
 name = "polyscribe-cli"
-version = "0.1.0"
-edition = "2024"
+version.workspace = true
+edition.workspace = true
 license = "MIT"
 
 [[bin]]
 name = "polyscribe"
 path = "src/main.rs"
 
 [dependencies]
-anyhow = "1.0.98"
-clap = { version = "4.5.43", features = ["derive"] }
-clap_complete = "4.5.28"
-clap_mangen = "0.2"
-serde = { version = "1.0.219", features = ["derive"] }
-serde_json = "1.0.142"
-toml = "0.8"
-chrono = { version = "0.4", features = ["clock"] }
-cliclack = "0.3"
-indicatif = "0.17"
-polyscribe = { path = "../polyscribe-core" }
+anyhow = { workspace = true }
+clap = { workspace = true, features = ["derive"] }
+clap_complete = { workspace = true }
+clap_mangen = { workspace = true }
+directories = { workspace = true }
+serde = { workspace = true, features = ["derive"] }
+serde_json = { workspace = true }
+tokio = { workspace = true, features = ["rt-multi-thread", "macros", "process", "fs"] }
+tracing = { workspace = true }
+tracing-subscriber = { workspace = true, features = ["fmt", "env-filter"] }
+which = { workspace = true }
+polyscribe-core = { path = "../polyscribe-core" }
 polyscribe-host = { path = "../polyscribe-host" }
 polyscribe-protocol = { path = "../polyscribe-protocol" }
+
+[features]
+# Optional GPU-specific flags can be forwarded down to core/host if needed
+default = []
+
+[dev-dependencies]
+assert_cmd = { workspace = true }

crates/polyscribe-cli/src/cli.rs

@@ -0,0 +1,191 @@
use clap::{Args, Parser, Subcommand, ValueEnum};
use std::path::PathBuf;

#[derive(Debug, Clone, ValueEnum)]
pub enum GpuBackend {
    Auto,
    Cpu,
    Cuda,
    Hip,
    Vulkan,
}

#[derive(Debug, Clone, Args)]
pub struct OutputOpts {
    /// Emit machine-readable JSON to stdout; suppress decorative logs
    #[arg(long, global = true, action = clap::ArgAction::SetTrue)]
    pub json: bool,

    /// Reduce log chatter (errors only unless --json)
    #[arg(long, global = true, action = clap::ArgAction::SetTrue)]
    pub quiet: bool,
}

#[derive(Debug, Parser)]
#[command(
    name = "polyscribe",
    version,
    about = "PolyScribe local-first transcription and plugins",
    propagate_version = true,
    arg_required_else_help = true,
)]
pub struct Cli {
    /// Global output options
    #[command(flatten)]
    pub output: OutputOpts,

    /// Increase verbosity (-v, -vv)
    #[arg(short, long, action = clap::ArgAction::Count)]
    pub verbose: u8,

    /// Never prompt for user input (non-interactive mode)
    #[arg(long, default_value_t = false)]
    pub no_interaction: bool,

    /// Disable progress bars/spinners
    #[arg(long, default_value_t = false)]
    pub no_progress: bool,

    #[command(subcommand)]
    pub command: Commands,
}

#[derive(Debug, Subcommand)]
pub enum Commands {
    /// Transcribe audio/video files or merge existing transcripts
    Transcribe {
        /// Output file or directory (date prefix is added when directory)
        #[arg(short, long)]
        output: Option<PathBuf>,
        /// Merge multiple inputs into one output
        #[arg(short = 'm', long, default_value_t = false)]
        merge: bool,
        /// Write both merged and per-input outputs (requires -o dir)
        #[arg(long, default_value_t = false)]
        merge_and_separate: bool,
        /// Language code hint, e.g. en, de
        #[arg(long)]
        language: Option<String>,
        /// Prompt for a speaker label per input file
        #[arg(long, default_value_t = false)]
        set_speaker_names: bool,
        /// GPU backend selection
        #[arg(long, value_enum, default_value_t = GpuBackend::Auto)]
        gpu_backend: GpuBackend,
        /// Offload N layers to GPU (when supported)
        #[arg(long, default_value_t = 0)]
        gpu_layers: usize,
        /// Input paths: audio/video files or JSON transcripts
        #[arg(required = true)]
        inputs: Vec<PathBuf>,
    },
    /// Manage Whisper GGUF models (Hugging Face)
    Models {
        #[command(subcommand)]
        cmd: ModelsCmd,
    },
    /// Discover and run plugins
    Plugins {
        #[command(subcommand)]
        cmd: PluginsCmd,
    },
    /// Generate shell completions to stdout
    Completions {
        /// Shell to generate completions for
        #[arg(value_parser = ["bash", "zsh", "fish", "powershell", "elvish"])]
        shell: String,
    },
    /// Generate a man page to stdout
    Man,
}

#[derive(Debug, Clone, Parser)]
pub struct ModelCommon {
    /// Concurrency for ranged downloads
    #[arg(long, default_value_t = 4)]
    pub concurrency: usize,

    /// Limit download rate in bytes/sec (approximate)
    #[arg(long)]
    pub limit_rate: Option<u64>,
}

#[derive(Debug, Subcommand)]
pub enum ModelsCmd {
    /// List installed models (from manifest)
    Ls {
        #[command(flatten)]
        common: ModelCommon,
    },
    /// Add or update a model
    Add {
        /// Hugging Face repo, e.g. ggml-org/models
        repo: String,
        /// File name in repo (e.g., gguf-tiny-q4_0.bin)
        file: String,
        #[command(flatten)]
        common: ModelCommon,
    },
    /// Remove a model by alias
    Rm {
        alias: String,
        #[command(flatten)]
        common: ModelCommon,
    },
    /// Verify model file integrity by alias
    Verify {
        alias: String,
        #[command(flatten)]
        common: ModelCommon,
    },
    /// Update all models (HEAD + ETag; skip if unchanged)
    Update {
        #[command(flatten)]
        common: ModelCommon,
    },
    /// Garbage-collect unreferenced files and stale manifest entries
    Gc {
        #[command(flatten)]
        common: ModelCommon,
    },
    /// Search a repo for GGUF files
    Search {
        /// Hugging Face repo, e.g. ggml-org/models
        repo: String,
        /// Optional substring to filter filenames
        #[arg(long)]
        query: Option<String>,
        #[command(flatten)]
        common: ModelCommon,
    },
}

#[derive(Debug, Subcommand)]
pub enum PluginsCmd {
    /// List installed plugins
    List,
    /// Show a plugin's capabilities (as JSON)
    Info {
        /// Plugin short name, e.g., "tubescribe"
        name: String,
    },
    /// Run a plugin command (JSON-RPC over NDJSON via stdio)
    Run {
        /// Plugin short name
        name: String,
        /// Command name in plugin's API
        command: String,
        /// JSON payload string
        #[arg(long)]
        json: Option<String>,
    },
}
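The global `--json` and `--quiet` flags declared in `OutputOpts` above are folded into a single output mode by the CLI's entry point elsewhere in this compare. The mapping can be reproduced in isolation; the enum below mirrors the `OutputMode` used there, as a self-contained sketch:

```rust
// Mirror of the output-mode selection: --json forces machine-readable output
// and wins over --quiet, which only dampens human-oriented logging.
#[derive(Debug, PartialEq)]
enum OutputMode {
    Json,
    Human { quiet: bool },
}

fn select_output_mode(json: bool, quiet: bool) -> OutputMode {
    if json {
        OutputMode::Json
    } else {
        OutputMode::Human { quiet }
    }
}

fn main() {
    assert_eq!(select_output_mode(true, true), OutputMode::Json);
    assert_eq!(select_output_mode(false, false), OutputMode::Human { quiet: false });
    assert_eq!(select_output_mode(false, true), OutputMode::Human { quiet: true });
}
```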

crates/polyscribe-cli/src/main.rs

@@ -1,536 +1,470 @@
// SPDX-License-Identifier: MIT mod cli;
// Copyright (c) 2025 <COPYRIGHT HOLDER>. All rights reserved. mod output;
use std::fs::{File, create_dir_all};
use std::io::{self, Read, Write};
use std::path::{Path, PathBuf};
use anyhow::{Context, Result, anyhow}; use anyhow::{Context, Result, anyhow};
use clap::{Parser, Subcommand, ValueEnum, CommandFactory}; use clap::{CommandFactory, Parser};
use clap_complete::Shell; use cli::{Cli, Commands, GpuBackend, ModelsCmd, ModelCommon, PluginsCmd};
use serde::{Deserialize, Serialize}; use output::OutputMode;
use polyscribe_core::model_manager::{ModelManager, Settings, ReqwestClient};
use polyscribe::{OutputEntry, date_prefix, normalize_lang_code, render_srt}; use polyscribe_core::ui;
use polyscribe_host as host; fn normalized_similarity(a: &str, b: &str) -> f64 {
// simple Levenshtein distance; normalized to [0,1]
#[derive(Subcommand, Debug, Clone)] let a_bytes = a.as_bytes();
enum PluginsCmd { let b_bytes = b.as_bytes();
/// List available plugins let n = a_bytes.len();
List, let m = b_bytes.len();
/// Show plugin capabilities if n == 0 && m == 0 { return 1.0; }
Info { name: String }, if n == 0 || m == 0 { return 0.0; }
/// Run a plugin command with a JSON payload let mut prev: Vec<usize> = (0..=m).collect();
Run { let mut curr: Vec<usize> = vec![0; m + 1];
name: String, for i in 1..=n {
command: String, curr[0] = i;
/// JSON payload string passed to the plugin as request.params for j in 1..=m {
#[arg(long = "json")] let cost = if a_bytes[i - 1] == b_bytes[j - 1] { 0 } else { 1 };
json: String, curr[j] = (prev[j] + 1)
}, .min(curr[j - 1] + 1)
} .min(prev[j - 1] + cost);
#[derive(Subcommand, Debug, Clone)]
enum Command {
Completions { #[arg(value_enum)] shell: Shell },
Man,
Plugins { #[command(subcommand)] cmd: PluginsCmd },
}
#[derive(ValueEnum, Debug, Clone, Copy)]
#[value(rename_all = "kebab-case")]
enum GpuBackendCli {
Auto,
Cpu,
Cuda,
Hip,
Vulkan,
}
#[derive(Parser, Debug)]
#[command(
name = "PolyScribe",
bin_name = "polyscribe",
version,
about = "Merge JSON transcripts or transcribe audio using native whisper"
)]
struct Args {
/// Increase verbosity (-v, -vv). Repeat to increase.
/// Debug logs appear with -v; very verbose with -vv. Logs go to stderr.
#[arg(short = 'v', long = "verbose", action = clap::ArgAction::Count, global = true)]
verbose: u8,
/// Quiet mode: suppress non-error logging on stderr (overrides -v)
/// Does not suppress interactive prompts or stdout output.
#[arg(short = 'q', long = "quiet", global = true)]
quiet: bool,
/// Non-interactive mode: never prompt; use defaults instead.
#[arg(long = "no-interaction", global = true)]
no_interaction: bool,
/// Disable interactive progress indicators (bars/spinners)
#[arg(long = "no-progress", global = true)]
no_progress: bool,
/// Optional subcommands (completions, man, plugins)
#[command(subcommand)]
cmd: Option<Command>,
/// Input .json transcript files or audio files to merge/transcribe
inputs: Vec<String>,
/// Output file path base or directory (date prefix added).
/// In merge mode: base path.
/// In separate mode: directory.
/// If omitted: prints JSON to stdout for merge mode; separate mode requires directory for multiple inputs.
#[arg(short, long, value_name = "FILE")]
output: Option<String>,
/// Merge all inputs into a single output; if not set, each input is written as a separate output
#[arg(short = 'm', long = "merge")]
merge: bool,
/// Merge and also write separate outputs per input; requires -o OUTPUT_DIR
#[arg(long = "merge-and-separate")]
merge_and_separate: bool,
/// Prompt for speaker names per input file
#[arg(long = "set-speaker-names")]
set_speaker_names: bool,
/// Language code to use for transcription (e.g., en, de). No auto-detection.
#[arg(short, long, value_name = "LANG")]
language: Option<String>,
/// Launch interactive model downloader (list HF models, multi-select and download)
#[arg(long)]
download_models: bool,
/// Update local Whisper models by comparing hashes/sizes with remote manifest
#[arg(long)]
update_models: bool,
}
#[derive(Debug, Deserialize)]
struct InputRoot {
#[serde(default)]
segments: Vec<InputSegment>,
}
#[derive(Debug, Deserialize)]
struct InputSegment {
start: f64,
end: f64,
text: String,
}
#[derive(Debug, Serialize)]
struct OutputRoot {
items: Vec<OutputEntry>,
}
fn is_json_file(path: &Path) -> bool {
matches!(path.extension().and_then(|s| s.to_str()).map(|s| s.to_lowercase()), Some(ext) if ext == "json")
}
fn is_audio_file(path: &Path) -> bool {
if let Some(ext) = path.extension().and_then(|s| s.to_str()).map(|s| s.to_lowercase()) {
let exts = [
"mp3", "wav", "m4a", "mp4", "aac", "flac", "ogg", "wma", "webm", "mkv", "mov", "avi",
"m4b", "3gp", "opus", "aiff", "alac",
];
return exts.contains(&ext.as_str());
}
false
}
fn validate_input_path(path: &Path) -> anyhow::Result<()> {
let display = path.display();
if !path.exists() {
return Err(anyhow!("Input not found: {}", display));
}
let metadata = std::fs::metadata(path).with_context(|| format!("Failed to stat input: {}", display))?;
if metadata.is_dir() {
return Err(anyhow!("Input is a directory (expected a file): {}", display));
}
std::fs::File::open(path)
.with_context(|| format!("Failed to open input file: {}", display))
.map(|_| ())
}
fn sanitize_speaker_name(raw: &str) -> String {
if let Some((prefix, rest)) = raw.split_once('-') {
if !prefix.is_empty() && prefix.chars().all(|c| c.is_ascii_digit()) {
return rest.to_string();
} }
std::mem::swap(&mut prev, &mut curr);
} }
raw.to_string() let dist = prev[m] as f64;
let max_len = n.max(m) as f64;
1.0 - (dist / max_len)
} }
fn prompt_speaker_name_for_path( fn human_size(bytes: Option<u64>) -> String {
_path: &Path, match bytes {
default_name: &str, Some(n) => {
enabled: bool, let x = n as f64;
) -> String { const KB: f64 = 1024.0;
if !enabled || polyscribe::is_no_interaction() { const MB: f64 = 1024.0 * KB;
return sanitize_speaker_name(default_name); const GB: f64 = 1024.0 * MB;
} if x >= GB { format!("{:.2} GiB", x / GB) }
// TODO implement cliclack for this else if x >= MB { format!("{:.2} MiB", x / MB) }
let mut input_line = String::new(); else if x >= KB { format!("{:.2} KiB", x / KB) }
match std::io::stdin().read_line(&mut input_line) { else { format!("{} B", n) }
Ok(_) => {
let trimmed = input_line.trim();
if trimmed.is_empty() {
sanitize_speaker_name(default_name)
} else {
sanitize_speaker_name(trimmed)
}
} }
Err(_) => sanitize_speaker_name(default_name), None => "?".to_string(),
} }
} }
use polyscribe_core::ui::progress::ProgressReporter;
use polyscribe_host::PluginManager;
use tokio::io::AsyncWriteExt;
use tracing_subscriber::EnvFilter;
fn init_tracing(json_mode: bool, quiet: bool, verbose: u8) {
    // In JSON mode, suppress human logs; route errors to stderr only.
    let level = if json_mode || quiet {
        "error"
    } else {
        match verbose { 0 => "info", 1 => "debug", _ => "trace" }
    };
    let filter = EnvFilter::try_from_default_env().unwrap_or_else(|_| EnvFilter::new(level));
    tracing_subscriber::fmt()
        .with_env_filter(filter)
        .with_target(false)
        .with_level(true)
        .with_writer(std::io::stderr)
        .compact()
        .init();
}

fn handle_plugins(cmd: PluginsCmd) -> Result<()> {
    match cmd {
        PluginsCmd::List => {
            let list = host::discover()?;
            for p in list {
                println!("{}\t{}", p.name, p.path.display());
            }
            Ok(())
        }
        PluginsCmd::Info { name } => {
            let p = host::find_plugin_by_name(&name)?;
            let caps = host::capabilities(&p.path)?;
            println!("{}", serde_json::to_string_pretty(&caps)?);
            Ok(())
        }
        PluginsCmd::Run { name, command, json } => {
            let p = host::find_plugin_by_name(&name)?;
            let params: serde_json::Value = serde_json::from_str(&json).context("--json payload must be valid JSON")?;
            let mut last_pct = 0u8;
            let result = host::run_method(&p.path, &command, params, |prog| {
                // Render minimal progress
                let stage = prog.stage.as_deref().unwrap_or("");
                let msg = prog.message.as_deref().unwrap_or("");
                if prog.pct != last_pct {
                    let _ = cliclack::log::info(format!("[{}%] {} {}", prog.pct, stage, msg).trim());
                    last_pct = prog.pct;
                }
            })?;
            println!("{}", serde_json::to_string_pretty(&result)?);
            Ok(())
        }
    }
}
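The level-selection logic inside `init_tracing` can be isolated as a pure function (the helper name is ours, not part of the diff): JSON or quiet mode forces `error`, otherwise the `-v` count maps to `info`/`debug`/`trace`.

```rust
// Extracted verbosity-to-level mapping from init_tracing, for illustration.
fn level_for(json_mode: bool, quiet: bool, verbose: u8) -> &'static str {
    if json_mode || quiet {
        "error"
    } else {
        match verbose { 0 => "info", 1 => "debug", _ => "trace" }
    }
}

fn main() {
    // JSON mode wins over any verbosity setting.
    assert_eq!(level_for(true, false, 3), "error");
    assert_eq!(level_for(false, true, 0), "error");
    assert_eq!(level_for(false, false, 0), "info");
    assert_eq!(level_for(false, false, 1), "debug");
    assert_eq!(level_for(false, false, 2), "trace");
}
```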
fn main() -> Result<()> {
    let args = Args::parse();
    // Initialize runtime flags for the library
    polyscribe::set_verbose(args.verbose);
    polyscribe::set_quiet(args.quiet);
    polyscribe::set_no_interaction(args.no_interaction);
    polyscribe::set_no_progress(args.no_progress);
    // Handle subcommands
    if let Some(cmd) = &args.cmd {
        match cmd.clone() {
            Command::Completions { shell } => {
                let mut cmd = Args::command();
                let bin_name = cmd.get_name().to_string();
                clap_complete::generate(shell, &mut cmd, bin_name, &mut io::stdout());
                return Ok(());
fn main() -> Result<()> {
    let args = Cli::parse();
    // Determine output mode early for logging and UI configuration
    let output_mode = if args.output.json {
        OutputMode::Json
    } else {
        OutputMode::Human { quiet: args.output.quiet }
    };
    init_tracing(matches!(output_mode, OutputMode::Json), args.output.quiet, args.verbose);
    // Suppress decorative UI output in JSON mode as well
    polyscribe_core::set_quiet(args.output.quiet || matches!(output_mode, OutputMode::Json));
    polyscribe_core::set_no_interaction(args.no_interaction);
    polyscribe_core::set_verbose(args.verbose);
    polyscribe_core::set_no_progress(args.no_progress);
match args.command {
Commands::Transcribe {
gpu_backend,
gpu_layers,
inputs,
..
} => {
polyscribe_core::ui::info("starting transcription workflow");
let mut progress = ProgressReporter::new(args.no_interaction);
progress.step("Validating inputs");
if inputs.is_empty() {
return Err(anyhow!("no inputs provided"));
            }
            progress.step("Selecting backend and preparing model");
            match gpu_backend {
                GpuBackend::Auto => {}
                GpuBackend::Cpu => {}
                GpuBackend::Cuda => {
                    let _ = gpu_layers;
                }
                GpuBackend::Hip => {}
                GpuBackend::Vulkan => {}
            }
            }
            Command::Man => {
                let cmd = Args::command();
                let man = clap_mangen::Man::new(cmd);
                let mut man_bytes = Vec::new();
                man.render(&mut man_bytes)?;
                io::stdout().write_all(&man_bytes)?;
                return Ok(());
            }
Command::Plugins { cmd } => {
return handle_plugins(cmd);
}
}
}
            progress.finish_with_message("Transcription completed (stub)");
            Ok(())
    // Optional model management actions
    if args.download_models {
if let Err(err) = polyscribe::models::run_interactive_model_downloader() {
polyscribe::elog!("Model downloader failed: {:#}", err);
}
if args.inputs.is_empty() {
return Ok(())
}
}
if args.update_models {
if let Err(err) = polyscribe::models::update_local_models() {
polyscribe::elog!("Model update failed: {:#}", err);
return Err(err);
}
if args.inputs.is_empty() {
return Ok(())
}
}
// Process inputs
let mut inputs = args.inputs;
if inputs.is_empty() {
return Err(anyhow!("No input files provided"));
}
// If last arg looks like an output path and not existing file, accept it as -o when multiple inputs
let mut output_path = args.output;
if output_path.is_none() && inputs.len() >= 2 {
if let Some(candidate_output) = inputs.last().cloned() {
if !Path::new(&candidate_output).exists() {
inputs.pop();
output_path = Some(candidate_output);
}
}
}
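This trailing-argument heuristic can be sketched as a pure function over the argument list (the function name and the injected `exists` predicate are illustrative, used here to avoid touching the filesystem):

```rust
// With multiple inputs and no -o, a final argument that does not exist on
// disk is treated as the output path and removed from the input list.
fn split_output(
    mut inputs: Vec<String>,
    exists: impl Fn(&str) -> bool,
) -> (Vec<String>, Option<String>) {
    let mut output = None;
    if inputs.len() >= 2 {
        if let Some(last) = inputs.last().cloned() {
            if !exists(&last) {
                inputs.pop();
                output = Some(last);
            }
        }
    }
    (inputs, output)
}

fn main() {
    let (ins, out) = split_output(
        vec!["a.json".into(), "b.json".into(), "out_dir".into()],
        |p| p != "out_dir", // pretend only "out_dir" is missing on disk
    );
    assert_eq!(ins.len(), 2);
    assert_eq!(out.as_deref(), Some("out_dir"));
    // A single input is never reinterpreted as an output path.
    let (ins, out) = split_output(vec!["a.json".into()], |_| false);
    assert_eq!(ins.len(), 1);
    assert!(out.is_none());
}
```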
// Validate inputs; allow JSON and audio. For audio, require --language.
for input_arg in &inputs {
let path_ref = Path::new(input_arg);
validate_input_path(path_ref)?;
if !(is_json_file(path_ref) || is_audio_file(path_ref)) {
return Err(anyhow!(
"Unsupported input type (expected .json transcript or audio media): {}",
path_ref.display()
));
}
if is_audio_file(path_ref) && args.language.is_none() {
return Err(anyhow!("Please specify --language (e.g., --language en). Language detection was removed."));
}
}
// Derive speakers (prompt if requested)
let speakers: Vec<String> = inputs
.iter()
.map(|input_path| {
let path = Path::new(input_path);
let default_speaker = sanitize_speaker_name(
path.file_stem().and_then(|s| s.to_str()).unwrap_or("speaker"),
);
prompt_speaker_name_for_path(path, &default_speaker, args.set_speaker_names)
})
.collect();
// MERGE-AND-SEPARATE mode
if args.merge_and_separate {
polyscribe::dlog!(1, "Mode: merge-and-separate; output_dir={:?}", output_path);
let out_dir = match output_path.as_ref() {
Some(p) => PathBuf::from(p),
None => return Err(anyhow!("--merge-and-separate requires -o OUTPUT_DIR")),
};
if !out_dir.as_os_str().is_empty() {
create_dir_all(&out_dir).with_context(|| {
format!("Failed to create output directory: {}", out_dir.display())
})?;
    }
            }
            Commands::Models { cmd } => {
                // predictable exit codes
                const EXIT_OK: i32 = 0;
                const EXIT_NOT_FOUND: i32 = 2;
                const EXIT_NETWORK: i32 = 3;
                const EXIT_VERIFY_FAILED: i32 = 4;
                // const EXIT_NO_CHANGE: i32 = 5; // reserved

                let handle_common = |c: &ModelCommon| Settings {
                    concurrency: c.concurrency.max(1),
                    limit_rate: c.limit_rate,
                    ..Default::default()
                };
    let mut merged_entries: Vec<OutputEntry> = Vec::new();
    for (idx, input_path) in inputs.iter().enumerate() {
        let path = Path::new(input_path);
        let speaker = speakers[idx].clone();
        // Decide based on input type (JSON transcript vs audio to transcribe)
        // TODO remove duplicate
        let mut entries: Vec<OutputEntry> = if is_json_file(path) {
            let mut buf = String::new();
            File::open(path)
                .with_context(|| format!("Failed to open: {input_path}"))?
                .read_to_string(&mut buf)
                .with_context(|| format!("Failed to read: {input_path}"))?;
            let root: InputRoot = serde_json::from_str(&buf)
                .with_context(|| format!("Invalid JSON transcript parsed from {input_path}"))?;
            root.segments
                .into_iter()
                .map(|seg| OutputEntry { id: 0, speaker: speaker.clone(), start: seg.start, end: seg.end, text: seg.text })
                .collect()
        } else {
            let lang_norm: Option<String> = args.language.as_deref().and_then(|s| normalize_lang_code(s));
            let selected_backend = polyscribe::backend::select_backend(polyscribe::backend::BackendKind::Auto, args.verbose > 0)?;
            selected_backend.backend.transcribe(path, &speaker, lang_norm.as_deref(), None, None)?
        };
// Sort and id per-file
// TODO remove duplicate
entries.sort_by(|a, b| a.start.partial_cmp(&b.start).unwrap_or(std::cmp::Ordering::Equal)
.then(a.end.partial_cmp(&b.end).unwrap_or(std::cmp::Ordering::Equal)));
for (i, entry) in entries.iter_mut().enumerate() { entry.id = i as u64; }
// Write per-file outputs
let stem = path.file_stem().and_then(|s| s.to_str()).unwrap_or("output");
let date = date_prefix();
let base_name = format!("{date}_{stem}");
let json_path = out_dir.join(format!("{}.json", &base_name));
let toml_path = out_dir.join(format!("{}.toml", &base_name));
let srt_path = out_dir.join(format!("{}.srt", &base_name));
        let output_bundle = OutputRoot { items: entries.clone() };
        let mut json_file = File::create(&json_path).with_context(|| format!("Failed to create output file: {}", json_path.display()))?;
        serde_json::to_writer_pretty(&mut json_file, &output_bundle)?; writeln!(&mut json_file)?;
        let toml_str = toml::to_string_pretty(&output_bundle)?;
        let mut toml_file = File::create(&toml_path).with_context(|| format!("Failed to create output file: {}", toml_path.display()))?;
        toml_file.write_all(toml_str.as_bytes())?; if !toml_str.ends_with('\n') { writeln!(&mut toml_file)?; }
        let srt_str = render_srt(&output_bundle.items);
        let mut srt_file = File::create(&srt_path).with_context(|| format!("Failed to create output file: {}", srt_path.display()))?;
        srt_file.write_all(srt_str.as_bytes())?;
        merged_entries.extend(output_bundle.items.into_iter());
    }
    // Write merged outputs into out_dir
    // TODO remove duplicate
    merged_entries.sort_by(|a, b| a.start.partial_cmp(&b.start).unwrap_or(std::cmp::Ordering::Equal)
        .then(a.end.partial_cmp(&b.end).unwrap_or(std::cmp::Ordering::Equal)));
    for (index, entry) in merged_entries.iter_mut().enumerate() { entry.id = index as u64; }
    let merged_output = OutputRoot { items: merged_entries };
    let date = date_prefix();
    let merged_base = format!("{date}_merged");
    let merged_json_path = out_dir.join(format!("{}.json", &merged_base));
    let merged_toml_path = out_dir.join(format!("{}.toml", &merged_base));
    let merged_srt_path = out_dir.join(format!("{}.srt", &merged_base));
    let mut merged_json_file = File::create(&merged_json_path).with_context(|| format!("Failed to create output file: {}", merged_json_path.display()))?;
    serde_json::to_writer_pretty(&mut merged_json_file, &merged_output)?; writeln!(&mut merged_json_file)?;
    let merged_toml_str = toml::to_string_pretty(&merged_output)?;
    let mut merged_toml_file = File::create(&merged_toml_path).with_context(|| format!("Failed to create output file: {}", merged_toml_path.display()))?;
    merged_toml_file.write_all(merged_toml_str.as_bytes())?; if !merged_toml_str.ends_with('\n') { writeln!(&mut merged_toml_file)?; }
    let merged_srt_str = render_srt(&merged_output.items);
    let mut merged_srt_file = File::create(&merged_srt_path).with_context(|| format!("Failed to create output file: {}", merged_srt_path.display()))?;
    merged_srt_file.write_all(merged_srt_str.as_bytes())?;
    return Ok(());
    }
                let exit = match cmd {
                    ModelsCmd::Ls { common } => {
                        let mm: ModelManager<ReqwestClient> = ModelManager::new(handle_common(&common))?;
                        let list = mm.ls()?;
                        match output_mode {
                            OutputMode::Json => {
                                // Always emit JSON array (possibly empty)
                                output_mode.print_json(&list);
                            }
                            OutputMode::Human { quiet } => {
                                if list.is_empty() {
                                    if !quiet { println!("No models installed."); }
                                } else {
                                    if !quiet { println!("Model (Repo)"); }
                                    for r in list {
                                        if !quiet { println!("{} ({})", r.file, r.repo); }
                                    }
                                }
                            }
                        }
                        EXIT_OK
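The exit-code contract used throughout the `models` subcommands can be summarized in a small sketch (the mapping helper is hypothetical; the constants match the diff):

```rust
// Predictable exit codes declared for the models subcommands.
const EXIT_OK: i32 = 0;
const EXIT_NOT_FOUND: i32 = 2;
const EXIT_NETWORK: i32 = 3;
const EXIT_VERIFY_FAILED: i32 = 4;

// Illustrative outcome-to-code mapping; not part of the actual CLI.
fn exit_code_for(outcome: &str) -> i32 {
    match outcome {
        "ok" => EXIT_OK,
        "not-found" => EXIT_NOT_FOUND,
        "network" => EXIT_NETWORK,
        "corrupt" => EXIT_VERIFY_FAILED,
        _ => 1,
    }
}

fn main() {
    assert_eq!(exit_code_for("ok"), 0);
    assert_eq!(exit_code_for("not-found"), 2);
    assert_eq!(exit_code_for("network"), 3);
    assert_eq!(exit_code_for("corrupt"), 4);
}
```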
// MERGE mode
if args.merge {
polyscribe::dlog!(1, "Mode: merge; output_base={:?}", output_path);
let mut entries: Vec<OutputEntry> = Vec::new();
for (index, input_path) in inputs.iter().enumerate() {
let path = Path::new(input_path);
let speaker = speakers[index].clone();
if is_json_file(path) {
let mut buf = String::new();
File::open(path)
.with_context(|| format!("Failed to open: {}", input_path))?
.read_to_string(&mut buf)
.with_context(|| format!("Failed to read: {}", input_path))?;
let root: InputRoot = serde_json::from_str(&buf)
.with_context(|| format!("Invalid JSON transcript parsed from {}", input_path))?;
for seg in root.segments {
entries.push(OutputEntry { id: 0, speaker: speaker.clone(), start: seg.start, end: seg.end, text: seg.text });
            }
        } else {
            let lang_norm: Option<String> = args.language.as_deref().and_then(|s| normalize_lang_code(s));
            let selected_backend = polyscribe::backend::select_backend(polyscribe::backend::BackendKind::Auto, args.verbose > 0)?;
            let mut new_entries = selected_backend.backend.transcribe(path, &speaker, lang_norm.as_deref(), None, None)?;
            entries.append(&mut new_entries);
        }
                    }
                    ModelsCmd::Add { repo, file, common } => {
                        let settings = handle_common(&common);
                        let mm: ModelManager<ReqwestClient> = ModelManager::new(settings.clone())?;
                        // Derive an alias automatically from repo and file
                        fn derive_alias(repo: &str, file: &str) -> String {
                            use std::path::Path;
let repo_tail = repo.rsplit('/').next().unwrap_or(repo);
let stem = Path::new(file)
.file_stem()
.and_then(|s| s.to_str())
.unwrap_or(file);
format!("{}-{}", repo_tail, stem)
}
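`derive_alias` restated as a runnable sketch: the alias is the repository tail joined to the file stem.

```rust
use std::path::Path;

// Self-contained copy of derive_alias for illustration.
fn derive_alias(repo: &str, file: &str) -> String {
    let repo_tail = repo.rsplit('/').next().unwrap_or(repo);
    let stem = Path::new(file)
        .file_stem()
        .and_then(|s| s.to_str())
        .unwrap_or(file);
    format!("{}-{}", repo_tail, stem)
}

fn main() {
    // Repository names here are only examples.
    assert_eq!(
        derive_alias("ggerganov/whisper.cpp", "ggml-tiny.bin"),
        "whisper.cpp-ggml-tiny"
    );
    // A repo without a slash is used as-is.
    assert_eq!(derive_alias("models", "base.bin"), "models-base");
}
```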
let alias = derive_alias(&repo, &file);
match mm.add_or_update(&alias, &repo, &file) {
Ok(rec) => {
match output_mode {
OutputMode::Json => output_mode.print_json(&rec),
OutputMode::Human { quiet } => {
if !quiet { println!("installed: {} -> {}/{}", alias, repo, rec.file); }
}
}
EXIT_OK
}
Err(e) => {
// On not found or similar errors, try suggesting close matches interactively
if matches!(output_mode, OutputMode::Json) || polyscribe_core::is_no_interaction() {
match output_mode {
OutputMode::Json => {
// Emit error JSON object
#[derive(serde::Serialize)]
struct ErrObj<'a> { error: &'a str }
let eo = ErrObj { error: &e.to_string() };
output_mode.print_json(&eo);
}
_ => { eprintln!("error: {e}"); }
}
EXIT_NOT_FOUND
} else {
ui::warn(format!("{}", e));
ui::info("Searching for similar model filenames…");
match polyscribe_core::model_manager::search_repo(&repo, None) {
Ok(mut files) => {
if files.is_empty() {
ui::warn("No files found in repository.");
EXIT_NOT_FOUND
} else {
// rank by similarity
files.sort_by(|a, b| normalized_similarity(&file, b)
.partial_cmp(&normalized_similarity(&file, a))
.unwrap_or(std::cmp::Ordering::Equal));
let top: Vec<String> = files.into_iter().take(5).collect();
if top.is_empty() {
EXIT_NOT_FOUND
} else if top.len() == 1 {
let cand = &top[0];
// Fetch repo size list once
let size_map: std::collections::HashMap<String, Option<u64>> =
polyscribe_core::model_manager::list_repo_files_with_meta(&repo)
.unwrap_or_default()
.into_iter().collect();
let mut size = size_map.get(cand).cloned().unwrap_or(None);
if size.is_none() {
size = polyscribe_core::model_manager::head_len_for_file(&repo, cand);
}
let local_files: std::collections::HashSet<String> = mm.ls()?.into_iter().map(|r| r.file).collect();
let is_local = local_files.contains(cand);
let label = format!("{} [{}]{}", cand, human_size(size), if is_local { " (local)" } else { "" });
let ok = ui::prompt_confirm(&format!("Did you mean {}?", label), true)
.unwrap_or(false);
if !ok { EXIT_NOT_FOUND } else {
let mm2: ModelManager<ReqwestClient> = ModelManager::new(settings)?;
let alias2 = derive_alias(&repo, cand);
match mm2.add_or_update(&alias2, &repo, cand) {
Ok(rec) => {
match output_mode {
OutputMode::Json => output_mode.print_json(&rec),
OutputMode::Human { quiet } => { if !quiet { println!("installed: {} -> {}/{}", alias2, repo, rec.file); } }
}
EXIT_OK
}
Err(e2) => { eprintln!("error: {e2}"); EXIT_NETWORK }
}
}
} else {
let opts: Vec<String> = top;
let local_files: std::collections::HashSet<String> = mm.ls()?.into_iter().map(|r| r.file).collect();
// Enrich labels with size and local tag using a single API call
let size_map: std::collections::HashMap<String, Option<u64>> =
polyscribe_core::model_manager::list_repo_files_with_meta(&repo)
.unwrap_or_default()
.into_iter().collect();
let mut labels_owned: Vec<String> = Vec::new();
for f in &opts {
let mut size = size_map.get(f).cloned().unwrap_or(None);
if size.is_none() {
size = polyscribe_core::model_manager::head_len_for_file(&repo, f);
}
let is_local = local_files.contains(f);
let suffix = if is_local { " (local)" } else { "" };
labels_owned.push(format!("{} [{}]{}", f, human_size(size), suffix));
}
let labels: Vec<&str> = labels_owned.iter().map(|s| s.as_str()).collect();
match ui::prompt_select("Pick a model", &labels) {
Ok(idx) => {
let chosen = &opts[idx];
let mm2: ModelManager<ReqwestClient> = ModelManager::new(settings)?;
let alias2 = derive_alias(&repo, chosen);
match mm2.add_or_update(&alias2, &repo, chosen) {
Ok(rec) => {
match output_mode {
OutputMode::Json => output_mode.print_json(&rec),
OutputMode::Human { quiet } => { if !quiet { println!("installed: {} -> {}/{}", alias2, repo, rec.file); } }
}
EXIT_OK
}
Err(e2) => { eprintln!("error: {e2}"); EXIT_NETWORK }
}
}
Err(_) => EXIT_NOT_FOUND,
}
}
}
}
Err(e2) => {
eprintln!("error: {}", e2);
EXIT_NETWORK
}
}
}
}
}
}
ModelsCmd::Rm { alias, common } => {
let mm: ModelManager<ReqwestClient> = ModelManager::new(handle_common(&common))?;
let ok = mm.rm(&alias)?;
match output_mode {
OutputMode::Json => {
#[derive(serde::Serialize)]
struct R { removed: bool }
output_mode.print_json(&R { removed: ok });
}
OutputMode::Human { quiet } => {
if !quiet { println!("{}", if ok { "removed" } else { "not found" }); }
}
}
if ok { EXIT_OK } else { EXIT_NOT_FOUND }
}
ModelsCmd::Verify { alias, common } => {
let mm: ModelManager<ReqwestClient> = ModelManager::new(handle_common(&common))?;
let found = mm.ls()?.into_iter().any(|r| r.alias == alias);
if !found {
match output_mode {
OutputMode::Json => {
#[derive(serde::Serialize)]
struct R<'a> { ok: bool, error: &'a str }
output_mode.print_json(&R { ok: false, error: "not found" });
}
OutputMode::Human { quiet } => { if !quiet { println!("not found"); } }
}
EXIT_NOT_FOUND
} else {
let ok = mm.verify(&alias)?;
match output_mode {
OutputMode::Json => {
#[derive(serde::Serialize)]
struct R { ok: bool }
output_mode.print_json(&R { ok });
}
OutputMode::Human { quiet } => { if !quiet { println!("{}", if ok { "ok" } else { "corrupt" }); } }
}
if ok { EXIT_OK } else { EXIT_VERIFY_FAILED }
}
}
ModelsCmd::Update { common } => {
let mm: ModelManager<ReqwestClient> = ModelManager::new(handle_common(&common))?;
let mut rc = EXIT_OK;
for rec in mm.ls()? {
match mm.add_or_update(&rec.alias, &rec.repo, &rec.file) {
Ok(_) => {}
Err(e) => {
rc = EXIT_NETWORK;
match output_mode {
OutputMode::Json => {
#[derive(serde::Serialize)]
struct R<'a> { alias: &'a str, error: String }
output_mode.print_json(&R { alias: &rec.alias, error: e.to_string() });
}
_ => { eprintln!("update {}: {e}", rec.alias); }
}
}
}
}
rc
}
ModelsCmd::Gc { common } => {
let mm: ModelManager<ReqwestClient> = ModelManager::new(handle_common(&common))?;
let (files_removed, entries_removed) = mm.gc()?;
match output_mode {
OutputMode::Json => {
#[derive(serde::Serialize)]
struct R { files_removed: usize, entries_removed: usize }
output_mode.print_json(&R { files_removed, entries_removed });
}
OutputMode::Human { quiet } => { if !quiet { println!("files_removed={} entries_removed={}", files_removed, entries_removed); } }
}
EXIT_OK
}
ModelsCmd::Search { repo, query, common } => {
let res = polyscribe_core::model_manager::search_repo(&repo, query.as_deref());
match res {
Ok(files) => {
match output_mode {
OutputMode::Json => output_mode.print_json(&files),
OutputMode::Human { quiet } => { for f in files { if !quiet { println!("{}", f); } } }
}
EXIT_OK
}
Err(e) => {
match output_mode {
OutputMode::Json => {
#[derive(serde::Serialize)]
struct R { error: String }
output_mode.print_json(&R { error: e.to_string() });
}
_ => { eprintln!("error: {e}"); }
}
EXIT_NETWORK
}
}
}
};
std::process::exit(exit);
} }
// TODO remove duplicate
entries.sort_by(|a, b| a.start.partial_cmp(&b.start).unwrap_or(std::cmp::Ordering::Equal)
.then(a.end.partial_cmp(&b.end).unwrap_or(std::cmp::Ordering::Equal)));
for (i, entry) in entries.iter_mut().enumerate() { entry.id = i as u64; }
let output_bundle = OutputRoot { items: entries };
    if let Some(path) = output_path {
        let base_path = Path::new(&path);
        let parent_opt = base_path.parent();
        if let Some(parent) = parent_opt {
            if !parent.as_os_str().is_empty() {
                create_dir_all(parent).with_context(|| {
                    format!("Failed to create parent directory for output: {}", parent.display())
                })?;
        Commands::Plugins { cmd } => {
            let plugin_manager = PluginManager;
            match cmd {
                PluginsCmd::List => {
                    let list = plugin_manager.list().context("discovering plugins")?;
                    for item in list {
                        polyscribe_core::ui::info(item.name);
}
Ok(())
}
PluginsCmd::Info { name } => {
let info = plugin_manager
.info(&name)
.with_context(|| format!("getting info for {}", name))?;
let info_json = serde_json::to_string_pretty(&info)?;
polyscribe_core::ui::info(info_json);
Ok(())
}
PluginsCmd::Run {
name,
command,
json,
} => {
// Use a local Tokio runtime only for this async path
let rt = tokio::runtime::Builder::new_multi_thread()
.enable_all()
.build()
.context("building tokio runtime")?;
rt.block_on(async {
let payload = json.unwrap_or_else(|| "{}".to_string());
let mut child = plugin_manager
.spawn(&name, &command)
.with_context(|| format!("spawning plugin {name} {command}"))?;
if let Some(mut stdin) = child.stdin.take() {
stdin
.write_all(payload.as_bytes())
.await
.context("writing JSON payload to plugin stdin")?;
}
let status = plugin_manager.forward_stdio(&mut child).await?;
if !status.success() {
polyscribe_core::ui::error(format!(
"plugin returned non-zero exit code: {}",
status
));
return Err(anyhow!("plugin failed"));
}
Ok(())
})
} }
} }
let stem = base_path.file_stem().and_then(|s| s.to_str()).unwrap_or("output");
let date = date_prefix();
let base_name = format!("{}_{}", date, stem);
let dir = parent_opt.unwrap_or(Path::new(""));
let json_path = dir.join(format!("{}.json", &base_name));
let toml_path = dir.join(format!("{}.toml", &base_name));
let srt_path = dir.join(format!("{}.srt", &base_name));
let mut json_file = File::create(&json_path).with_context(|| format!("Failed to create output file: {}", json_path.display()))?;
serde_json::to_writer_pretty(&mut json_file, &output_bundle)?; writeln!(&mut json_file)?;
let toml_str = toml::to_string_pretty(&output_bundle)?;
let mut toml_file = File::create(&toml_path).with_context(|| format!("Failed to create output file: {}", toml_path.display()))?;
toml_file.write_all(toml_str.as_bytes())?; if !toml_str.ends_with('\n') { writeln!(&mut toml_file)?; }
let srt_str = render_srt(&output_bundle.items);
let mut srt_file = File::create(&srt_path).with_context(|| format!("Failed to create output file: {}", srt_path.display()))?;
srt_file.write_all(srt_str.as_bytes())?;
} else {
let stdout = io::stdout();
let mut handle = stdout.lock();
serde_json::to_writer_pretty(&mut handle, &output_bundle)?; writeln!(&mut handle)?;
} }
return Ok(());
}
    // SEPARATE (default)
    polyscribe::dlog!(1, "Mode: separate; output_dir={:?}", output_path);
    if output_path.is_none() && inputs.len() > 1 {
        return Err(anyhow!("Multiple inputs without --merge require -o OUTPUT_DIR to write separate files"));
    }
    let out_dir: Option<PathBuf> = output_path.as_ref().map(PathBuf::from);
    if let Some(dir) = &out_dir {
        if !dir.as_os_str().is_empty() {
            create_dir_all(dir).with_context(|| format!("Failed to create output directory: {}", dir.display()))?;
        Commands::Completions { shell } => {
            use clap_complete::{generate, shells};
            use std::io;

            let mut cmd = Cli::command();
            let name = cmd.get_name().to_string();

            match shell.as_str() {
                "bash" => generate(shells::Bash, &mut cmd, name, &mut io::stdout()),
"zsh" => generate(shells::Zsh, &mut cmd, name, &mut io::stdout()),
"fish" => generate(shells::Fish, &mut cmd, name, &mut io::stdout()),
"powershell" => generate(shells::PowerShell, &mut cmd, name, &mut io::stdout()),
"elvish" => generate(shells::Elvish, &mut cmd, name, &mut io::stdout()),
_ => return Err(anyhow!("unsupported shell: {shell}")),
}
Ok(())
}
Commands::Man => {
use clap_mangen::Man;
let cmd = Cli::command();
let man = Man::new(cmd);
man.render(&mut std::io::stdout())?;
Ok(())
} }
} }
for (index, input_path) in inputs.iter().enumerate() {
let path = Path::new(input_path);
let speaker = speakers[index].clone();
// TODO remove duplicate
let mut entries: Vec<OutputEntry> = if is_json_file(path) {
let mut buf = String::new();
File::open(path)
.with_context(|| format!("Failed to open: {input_path}"))?
.read_to_string(&mut buf)
.with_context(|| format!("Failed to read: {input_path}"))?;
let root: InputRoot = serde_json::from_str(&buf)
.with_context(|| format!("Invalid JSON transcript parsed from {input_path}"))?;
root
.segments
.into_iter()
.map(|seg| OutputEntry { id: 0, speaker: speaker.clone(), start: seg.start, end: seg.end, text: seg.text })
.collect()
} else {
let lang_norm: Option<String> = args.language.as_deref().and_then(|s| normalize_lang_code(s));
let selected_backend = polyscribe::backend::select_backend(polyscribe::backend::BackendKind::Auto, args.verbose > 0)?;
selected_backend.backend.transcribe(path, &speaker, lang_norm.as_deref(), None, None)?
};
// TODO remove duplicate
entries.sort_by(|a, b| a.start.partial_cmp(&b.start).unwrap_or(std::cmp::Ordering::Equal)
.then(a.end.partial_cmp(&b.end).unwrap_or(std::cmp::Ordering::Equal)));
for (i, entry) in entries.iter_mut().enumerate() { entry.id = i as u64; }
let stem = path.file_stem().and_then(|s| s.to_str()).unwrap_or("output");
let date = date_prefix();
let base_name = format!("{date}_{stem}");
if let Some(dir) = &out_dir {
let json_path = dir.join(format!("{}.json", &base_name));
let toml_path = dir.join(format!("{}.toml", &base_name));
let srt_path = dir.join(format!("{}.srt", &base_name));
let output_bundle = OutputRoot { items: entries };
let mut json_file = File::create(&json_path).with_context(|| format!("Failed to create output file: {}", json_path.display()))?;
serde_json::to_writer_pretty(&mut json_file, &output_bundle)?; writeln!(&mut json_file)?;
let toml_str = toml::to_string_pretty(&output_bundle)?;
let mut toml_file = File::create(&toml_path).with_context(|| format!("Failed to create output file: {}", toml_path.display()))?;
toml_file.write_all(toml_str.as_bytes())?; if !toml_str.ends_with('\n') { writeln!(&mut toml_file)?; }
let srt_str = render_srt(&output_bundle.items);
let mut srt_file = File::create(&srt_path).with_context(|| format!("Failed to create output file: {}", srt_path.display()))?;
srt_file.write_all(srt_str.as_bytes())?;
} else {
// In separate mode with single input and no output dir, print JSON to stdout
let stdout = io::stdout();
let mut handle = stdout.lock();
let output_bundle = OutputRoot { items: entries };
serde_json::to_writer_pretty(&mut handle, &output_bundle)?; writeln!(&mut handle)?;
}
}
Ok(())
} }

View File

@@ -0,0 +1,36 @@
use std::io::{self, Write};
#[derive(Clone, Debug)]
pub enum OutputMode {
Json,
Human { quiet: bool },
}
impl OutputMode {
pub fn is_quiet(&self) -> bool {
matches!(self, OutputMode::Json) || matches!(self, OutputMode::Human { quiet: true })
}
pub fn print_json<T: serde::Serialize>(&self, v: &T) {
if let OutputMode::Json = self {
// Write compact JSON to stdout without prefixes
// and ensure a trailing newline for CLI ergonomics
let s = serde_json::to_string(v).unwrap_or_else(|e| format!("\"JSON_ERROR:{}\"", e));
println!("{}", s);
}
}
pub fn print_line(&self, s: impl AsRef<str>) {
match self {
OutputMode::Json => {
// Suppress human lines in JSON mode
}
OutputMode::Human { quiet } => {
if !quiet {
let _ = writeln!(io::stdout(), "{}", s.as_ref());
}
}
}
}
}
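A serde-free sketch of the same gating: JSON mode swallows human-readable lines entirely, and `quiet` gates them otherwise (`render_line` is an illustrative stand-in for `print_line` that returns rather than prints, so its behavior can be checked directly).

```rust
// Simplified sketch of OutputMode without the serde dependency.
#[derive(Clone, Debug)]
enum OutputMode {
    Json,
    Human { quiet: bool },
}

impl OutputMode {
    fn is_quiet(&self) -> bool {
        matches!(self, OutputMode::Json) || matches!(self, OutputMode::Human { quiet: true })
    }
    // Stand-in for print_line: None means the line would be suppressed.
    fn render_line<'a>(&self, s: &'a str) -> Option<&'a str> {
        match self {
            OutputMode::Json => None,
            OutputMode::Human { quiet } => if *quiet { None } else { Some(s) },
        }
    }
}

fn main() {
    assert!(OutputMode::Json.is_quiet());
    assert!(OutputMode::Human { quiet: true }.is_quiet());
    assert!(!OutputMode::Human { quiet: false }.is_quiet());
    assert_eq!(OutputMode::Human { quiet: false }.render_line("hi"), Some("hi"));
    assert_eq!(OutputMode::Json.render_line("hi"), None);
}
```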

View File

@@ -1,10 +1,11 @@
 // SPDX-License-Identifier: MIT
 // Copyright (c) 2025 <COPYRIGHT HOLDER>. All rights reserved.
+use assert_cmd::cargo::cargo_bin;
 use std::process::Command;
-fn bin() -> &'static str {
-    env!("CARGO_BIN_EXE_polyscribe")
+fn bin() -> std::path::PathBuf {
+    cargo_bin("polyscribe")
 }
 #[test]

View File

@@ -0,0 +1,42 @@
use assert_cmd::cargo::cargo_bin;
use std::process::Command;
fn bin() -> std::path::PathBuf { cargo_bin("polyscribe") }
#[test]
fn models_help_shows_global_output_flags() {
let out = Command::new(bin())
.args(["models", "--help"]) // subcommand help
.output()
.expect("failed to run polyscribe models --help");
assert!(out.status.success(), "help exited non-zero: {:?}", out.status);
let stdout = String::from_utf8(out.stdout).expect("stdout not utf-8");
assert!(stdout.contains("--json"), "--json not shown in help: {stdout}");
assert!(stdout.contains("--quiet"), "--quiet not shown in help: {stdout}");
}
#[test]
fn models_version_contains_pkg_version() {
let out = Command::new(bin())
.args(["models", "--version"]) // propagate_version
.output()
.expect("failed to run polyscribe models --version");
assert!(out.status.success(), "version exited non-zero: {:?}", out.status);
let stdout = String::from_utf8(out.stdout).expect("stdout not utf-8");
let want = env!("CARGO_PKG_VERSION");
assert!(stdout.contains(want), "version output missing {want}: {stdout}");
}
#[test]
fn models_ls_json_quiet_emits_pure_json() {
let out = Command::new(bin())
.args(["models", "ls", "--json", "--quiet"]) // global flags
.output()
.expect("failed to run polyscribe models ls --json --quiet");
assert!(out.status.success(), "ls exited non-zero: {:?}", out.status);
let stdout = String::from_utf8(out.stdout).expect("stdout not utf-8");
serde_json::from_str::<serde_json::Value>(stdout.trim()).expect("stdout is not valid JSON");
// Expect no extra logs on stdout; stderr should be empty in success path
assert!(out.stderr.is_empty(), "expected no stderr noise");
}

View File

@@ -1,32 +1,22 @@
 [package]
-name = "polyscribe"
-version = "0.1.0"
-edition = "2024"
-license = "MIT"
+name = "polyscribe-core"
+version.workspace = true
+edition.workspace = true

-[features]
-# Default: CPU only; no GPU features enabled
-default = []
-# GPU backends map to whisper-rs features or FFI stub for Vulkan
-gpu-cuda = ["whisper-rs/cuda"]
-gpu-hip = ["whisper-rs/hipblas"]
-gpu-vulkan = []
-# explicit CPU fallback feature (no effect at build time, used for clarity)
-cpu-fallback = []

 [dependencies]
-anyhow = "1.0.98"
-serde = { version = "1.0.219", features = ["derive"] }
-serde_json = "1.0.142"
-toml = "0.8"
-chrono = { version = "0.4", features = ["clock"] }
-sha2 = "0.10"
-whisper-rs = { git = "https://github.com/tazz4843/whisper-rs" }
-libc = "0.2"
-cliclack = "0.3"
-indicatif = "0.17"
-thiserror = "1"
-directories = "5"
+anyhow = { workspace = true }
+thiserror = { workspace = true }
+serde = { workspace = true, features = ["derive"] }
+serde_json = { workspace = true }
+toml = { workspace = true }
+directories = { workspace = true }
+chrono = { workspace = true }
+libc = { workspace = true }
+whisper-rs = { workspace = true }
+# UI and progress
+cliclack = { workspace = true }
+# HTTP downloads + hashing
+reqwest = { workspace = true }
+sha2 = { workspace = true }
+hex = { workspace = true }
+tempfile = { workspace = true }

-[build-dependencies]
-# no special build deps

View File

@@ -1,12 +1,14 @@
 // SPDX-License-Identifier: MIT
+// Move original build.rs behavior into core crate
 fn main() {
-    // Only run special build steps when gpu-vulkan feature is enabled.
     let vulkan_enabled = std::env::var("CARGO_FEATURE_GPU_VULKAN").is_ok();
-    println!("cargo:rerun-if-changed=extern/whisper.cpp");
     if !vulkan_enabled {
-        println!(
-            "cargo:warning=gpu-vulkan feature is disabled; skipping Vulkan-dependent build steps."
-        );
         return;
     }
+    println!("cargo:rerun-if-changed=extern/whisper.cpp");
     println!(
         "cargo:warning=Building with gpu-vulkan: ensure Vulkan SDK/loader are installed. Future versions will compile whisper.cpp via CMake."
     );

View File

@@ -1,34 +1,23 @@
 // SPDX-License-Identifier: MIT
-// Copyright (c) 2025 <COPYRIGHT HOLDER>. All rights reserved.
-//! Transcription backend selection and implementations (CPU/GPU) used by PolyScribe.
 use crate::OutputEntry;
+use crate::prelude::*;
 use crate::{decode_audio_to_pcm_f32_ffmpeg, find_model_file};
-use anyhow::{Context, Result, anyhow};
+use anyhow::{Context, anyhow};
 use std::env;
 use std::path::Path;
-// Re-export a public enum for CLI parsing usage
 #[derive(Debug, Clone, Copy, PartialEq, Eq)]
-/// Kind of transcription backend to use.
 pub enum BackendKind {
-    /// Automatically detect the best available backend (CUDA > HIP > Vulkan > CPU).
     Auto,
-    /// Pure CPU backend using whisper-rs.
     Cpu,
-    /// NVIDIA CUDA backend (requires CUDA runtime available at load time and proper feature build).
     Cuda,
-    /// AMD ROCm/HIP backend (requires hip/rocBLAS libraries available and proper feature build).
     Hip,
-    /// Vulkan backend (experimental; requires Vulkan loader/SDK and feature build).
     Vulkan,
 }
-/// Abstraction for a transcription backend.
 pub trait TranscribeBackend {
-    /// Backend kind implemented by this type.
     fn kind(&self) -> BackendKind;
-    /// Transcribe the given audio and return transcript entries.
     fn transcribe(
         &self,
         audio_path: &Path,
@@ -39,15 +28,13 @@ pub trait TranscribeBackend {
     ) -> Result<Vec<OutputEntry>>;
 }
-fn check_lib(_names: &[&str]) -> bool {
+fn is_library_available(_names: &[&str]) -> bool {
     #[cfg(test)]
     {
-        // During unit tests, avoid touching system libs to prevent loader crashes in CI.
         false
     }
     #[cfg(not(test))]
     {
-        // Disabled runtime dlopen probing to avoid loader instability; rely on environment overrides.
         false
     }
 }
@@ -56,7 +43,7 @@ fn cuda_available() -> bool {
     if let Ok(x) = env::var("POLYSCRIBE_TEST_FORCE_CUDA") {
         return x == "1";
     }
-    check_lib(&[
+    is_library_available(&[
         "libcudart.so",
         "libcudart.so.12",
         "libcudart.so.11",
@@ -69,33 +56,31 @@ fn hip_available() -> bool {
     if let Ok(x) = env::var("POLYSCRIBE_TEST_FORCE_HIP") {
         return x == "1";
     }
-    check_lib(&["libhipblas.so", "librocblas.so"])
+    is_library_available(&["libhipblas.so", "librocblas.so"])
 }
 fn vulkan_available() -> bool {
     if let Ok(x) = env::var("POLYSCRIBE_TEST_FORCE_VULKAN") {
         return x == "1";
     }
-    check_lib(&["libvulkan.so.1", "libvulkan.so"])
+    is_library_available(&["libvulkan.so.1", "libvulkan.so"])
 }
-/// CPU-based transcription backend using whisper-rs.
 #[derive(Default)]
 pub struct CpuBackend;
-/// CUDA-accelerated transcription backend for NVIDIA GPUs.
 #[derive(Default)]
 pub struct CudaBackend;
-/// ROCm/HIP-accelerated transcription backend for AMD GPUs.
 #[derive(Default)]
 pub struct HipBackend;
-/// Vulkan-based transcription backend (experimental/incomplete).
 #[derive(Default)]
 pub struct VulkanBackend;
 macro_rules! impl_whisper_backend {
     ($ty:ty, $kind:expr) => {
         impl TranscribeBackend for $ty {
-            fn kind(&self) -> BackendKind { $kind }
+            fn kind(&self) -> BackendKind {
+                $kind
+            }
             fn transcribe(
                 &self,
                 audio_path: &Path,
@@ -128,29 +113,17 @@ impl TranscribeBackend for VulkanBackend {
     ) -> Result<Vec<OutputEntry>> {
         Err(anyhow!(
             "Vulkan backend not yet wired to whisper.cpp FFI. Build with --features gpu-vulkan and ensure Vulkan SDK is installed. How to fix: install Vulkan loader (libvulkan), set VULKAN_SDK, and run cargo build --features gpu-vulkan."
-        ))
+        ).into())
     }
 }
-/// Result of choosing a transcription backend.
-pub struct SelectionResult {
+pub struct BackendSelection {
-    /// The constructed backend instance to perform transcription with.
     pub backend: Box<dyn TranscribeBackend + Send + Sync>,
-    /// Which backend kind was ultimately selected.
     pub chosen: BackendKind,
-    /// Which backend kinds were detected as available on this system.
     pub detected: Vec<BackendKind>,
 }
-/// Select an appropriate backend based on user request and system detection.
-///
-/// If `requested` is `BackendKind::Auto`, the function prefers CUDA, then HIP,
-/// then Vulkan, falling back to CPU when no GPU backend is detected. When a
-/// specific GPU backend is requested but unavailable, an error is returned with
-/// guidance on how to enable it.
-///
-/// Set `verbose` to true to print detection/selection info to stderr.
-pub fn select_backend(requested: BackendKind, verbose: bool) -> Result<SelectionResult> {
+pub fn select_backend(requested: BackendKind, verbose: bool) -> Result<BackendSelection> {
     let mut detected = Vec::new();
     if cuda_available() {
         detected.push(BackendKind::Cuda);
@@ -164,11 +137,11 @@ pub fn select_backend(requested: BackendKind, verbose: bool) -> Result<Selection
     let instantiate_backend = |k: BackendKind| -> Box<dyn TranscribeBackend + Send + Sync> {
         match k {
-            BackendKind::Cpu => Box::new(CpuBackend::default()),
-            BackendKind::Cuda => Box::new(CudaBackend::default()),
-            BackendKind::Hip => Box::new(HipBackend::default()),
-            BackendKind::Vulkan => Box::new(VulkanBackend::default()),
-            BackendKind::Auto => Box::new(CpuBackend::default()), // placeholder for Auto
+            BackendKind::Cpu => Box::new(CpuBackend),
+            BackendKind::Cuda => Box::new(CudaBackend),
+            BackendKind::Hip => Box::new(HipBackend),
+            BackendKind::Vulkan => Box::new(VulkanBackend),
+            BackendKind::Auto => Box::new(CpuBackend),
         }
     };
@@ -190,7 +163,7 @@ pub fn select_backend(requested: BackendKind, verbose: bool) -> Result<Selection
             } else {
                 return Err(anyhow!(
                     "Requested CUDA backend but CUDA libraries/devices not detected. How to fix: install NVIDIA driver + CUDA toolkit, ensure libcudart/libcublas are in loader path, and build with --features gpu-cuda."
-                ));
+                ).into());
             }
         }
         BackendKind::Hip => {
@@ -199,7 +172,7 @@ pub fn select_backend(requested: BackendKind, verbose: bool) -> Result<Selection
             } else {
                 return Err(anyhow!(
                     "Requested ROCm/HIP backend but libraries/devices not detected. How to fix: install ROCm hipBLAS/rocBLAS, ensure libs are in loader path, and build with --features gpu-hip."
-                ));
+                ).into());
             }
         }
        BackendKind::Vulkan => {
@@ -208,7 +181,7 @@ pub fn select_backend(requested: BackendKind, verbose: bool) -> Result<Selection
             } else {
                 return Err(anyhow!(
                     "Requested Vulkan backend but libvulkan not detected. How to fix: install Vulkan loader/SDK and build with --features gpu-vulkan."
-                ));
+                ).into());
             }
         }
         BackendKind::Cpu => BackendKind::Cpu,
@@ -219,14 +192,13 @@ pub fn select_backend(requested: BackendKind, verbose: bool) -> Result<Selection
         crate::dlog!(1, "Selected backend: {:?}", chosen);
     }
-    Ok(SelectionResult {
+    Ok(BackendSelection {
         backend: instantiate_backend(chosen),
         chosen,
         detected,
     })
 }
-// Internal helper: transcription using whisper-rs with CPU/GPU (depending on build features)
 #[allow(clippy::too_many_arguments)]
 pub(crate) fn transcribe_with_whisper_rs(
     audio_path: &Path,
@@ -235,7 +207,9 @@ pub(crate) fn transcribe_with_whisper_rs(
     progress: Option<&(dyn Fn(i32) + Send + Sync)>,
 ) -> Result<Vec<OutputEntry>> {
     let report = |p: i32| {
-        if let Some(cb) = progress { cb(p); }
+        if let Some(cb) = progress {
+            cb(p);
+        }
     };
     report(0);
@@ -248,21 +222,21 @@
         .and_then(|s| s.to_str())
         .map(|s| s.contains(".en.") || s.ends_with(".en.bin"))
         .unwrap_or(false);
-    if let Some(lang) = language {
-        if english_only_model && lang != "en" {
-            return Err(anyhow!(
-                "Selected model is English-only ({}), but a non-English language hint '{}' was provided. Please use a multilingual model or set WHISPER_MODEL.",
-                model_path.display(),
-                lang
-            ));
-        }
-    }
+    if let Some(lang) = language
+        && english_only_model
+        && lang != "en"
+    {
+        return Err(anyhow!(
+            "Selected model is English-only ({}), but a non-English language hint '{}' was provided. Please use a multilingual model or set WHISPER_MODEL.",
+            model_path.display(),
+            lang
+        ).into());
+    }
     let model_path_str = model_path
         .to_str()
         .ok_or_else(|| anyhow!("Model path not valid UTF-8: {}", model_path.display()))?;
     if crate::verbose_level() < 2 {
-        // Some builds of whisper/ggml expect these env vars; harmless if unknown
         unsafe {
             std::env::set_var("GGML_LOG_LEVEL", "0");
             std::env::set_var("WHISPER_PRINT_PROGRESS", "0");
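The `select_backend` docs in this file record the `Auto` policy as CUDA > HIP > Vulkan > CPU. That preference order can be exercised in isolation; the sketch below uses simplified stand-ins for the crate's own types, not the real `select_backend`:

```rust
// Simplified stand-in for the crate's BackendKind; only the detection
// policy is modelled here, not backend construction.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum BackendKind {
    Cpu,
    Cuda,
    Hip,
    Vulkan,
}

// Auto policy: first detected GPU backend in preference order, else CPU.
fn choose_auto(detected: &[BackendKind]) -> BackendKind {
    for preferred in [BackendKind::Cuda, BackendKind::Hip, BackendKind::Vulkan] {
        if detected.contains(&preferred) {
            return preferred;
        }
    }
    BackendKind::Cpu
}

fn main() {
    // CUDA wins even when detected after Vulkan.
    assert_eq!(
        choose_auto(&[BackendKind::Vulkan, BackendKind::Cuda]),
        BackendKind::Cuda
    );
    assert_eq!(choose_auto(&[BackendKind::Hip]), BackendKind::Hip);
    // No GPU detected: fall back to CPU.
    assert_eq!(choose_auto(&[]), BackendKind::Cpu);
}
```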

View File

@@ -1,149 +1,104 @@
 // SPDX-License-Identifier: MIT
-// Simple ConfigService with XDG/system/workspace merge and atomic writes
-use anyhow::{Context, Result};
-use directories::BaseDirs;
 use serde::{Deserialize, Serialize};
 use std::env;
-use std::fs;
-use std::io::Write;
-use std::path::{Path, PathBuf};
+use std::path::PathBuf;
-/// Generic configuration represented as TOML table
-#[derive(Debug, Clone, Default, Serialize, Deserialize)]
-pub struct Config(pub toml::value::Table);
-impl Config {
-    /// Get a mutable reference to a top-level table under the given key, creating
-    /// an empty table if it does not exist yet.
-    pub fn get_table_mut(&mut self, key: &str) -> &mut toml::value::Table {
-        let needs_init = !matches!(self.0.get(key), Some(toml::Value::Table(_)));
-        if needs_init {
-            self.0.insert(key.to_string(), toml::Value::Table(Default::default()));
-        }
-        match self.0.get_mut(key) {
-            Some(toml::Value::Table(t)) => t,
-            _ => unreachable!(),
-        }
-    }
-}
-fn merge_tables(base: &mut toml::value::Table, overlay: &toml::value::Table) {
-    for (k, v) in overlay.iter() {
-        match (base.get_mut(k), v) {
-            (Some(toml::Value::Table(bsub)), toml::Value::Table(osub)) => {
-                merge_tables(bsub, osub);
-            }
-            _ => {
-                base.insert(k.clone(), v.clone());
-            }
-        }
-    }
-}
-fn read_toml(path: &Path) -> Result<toml::value::Table> {
-    let s = fs::read_to_string(path).with_context(|| format!("Failed to read config: {}", path.display()))?;
-    let v: toml::Value = toml::from_str(&s).with_context(|| format!("Invalid TOML in {}", path.display()))?;
-    Ok(v.as_table().cloned().unwrap_or_default())
-}
-fn write_toml_atomic(path: &Path, tbl: &toml::value::Table) -> Result<()> {
-    if let Some(parent) = path.parent() {
-        fs::create_dir_all(parent).with_context(|| format!("Failed to create config dir: {}", parent.display()))?;
-    }
-    let tmp = path.with_extension("tmp");
-    let mut f = fs::File::create(&tmp).with_context(|| format!("Failed to create temp file: {}", tmp.display()))?;
-    let s = toml::to_string_pretty(&toml::Value::Table(tbl.clone()))?;
-    f.write_all(s.as_bytes())?;
-    if !s.ends_with('\n') { f.write_all(b"\n")?; }
-    drop(f);
-    fs::rename(&tmp, path).with_context(|| format!("Failed to atomically replace config: {}", path.display()))?;
-    Ok(())
-}
-fn system_config_path() -> PathBuf {
-    if cfg!(unix) { PathBuf::from("/etc").join("polyscribe").join("config.toml") } else { default_user_config_path() }
-}
-fn default_user_config_path() -> PathBuf {
-    if let Some(base) = BaseDirs::new() {
-        return PathBuf::from(base.config_dir()).join("polyscribe").join("config.toml");
-    }
-    PathBuf::from(".polyscribe").join("config.toml")
-}
-fn workspace_config_path() -> PathBuf {
-    PathBuf::from(".polyscribe").join("config.toml")
-}
-/// Service responsible for loading and saving PolyScribe configuration
-#[derive(Debug, Default, Clone)]
 pub struct ConfigService;
 impl ConfigService {
-    /// Load configuration, merging system < user < workspace < env overrides.
-    pub fn load(&self) -> Result<Config> {
-        let mut accum = toml::value::Table::default();
-        let sys = system_config_path();
-        if sys.exists() {
-            merge_tables(&mut accum, &read_toml(&sys)?);
-        }
-        let user = default_user_config_path();
-        if user.exists() {
-            merge_tables(&mut accum, &read_toml(&user)?);
-        }
-        let ws = workspace_config_path();
-        if ws.exists() {
-            merge_tables(&mut accum, &read_toml(&ws)?);
-        }
-        // Env overrides: POLYSCRIBE__SECTION__KEY=value
-        let mut env_over = toml::value::Table::default();
-        for (k, v) in env::vars() {
-            if let Some(rest) = k.strip_prefix("POLYSCRIBE__") {
-                let parts: Vec<&str> = rest.split("__").collect();
-                if parts.is_empty() { continue; }
-                let val: toml::Value = toml::Value::String(v);
-                // Build nested tables
-                let mut current = &mut env_over;
-                for (i, part) in parts.iter().enumerate() {
-                    if i == parts.len() - 1 {
-                        current.insert(part.to_lowercase(), val.clone());
-                    } else {
-                        current = current.entry(part.to_lowercase()).or_insert_with(|| toml::Value::Table(Default::default()))
-                            .as_table_mut().expect("table");
-                    }
-                }
-            }
-        }
-        merge_tables(&mut accum, &env_over);
-        Ok(Config(accum))
-    }
-    /// Ensure user config exists with sensible defaults, return loaded config
-    pub fn ensure_user_config(&self) -> Result<Config> {
-        let path = default_user_config_path();
-        if !path.exists() {
-            let mut defaults = toml::value::Table::default();
-            defaults.insert("ui".into(), toml::Value::Table({
-                let mut t = toml::value::Table::default();
-                t.insert("theme".into(), toml::Value::String("auto".into()));
-                t
-            }));
-            write_toml_atomic(&path, &defaults)?;
-        }
-        self.load()
-    }
-    /// Save to user config atomically, merging over existing user file.
-    pub fn save_user(&self, new_values: &toml::value::Table) -> Result<()> {
-        let path = default_user_config_path();
-        let mut base = if path.exists() { read_toml(&path)? } else { Default::default() };
-        merge_tables(&mut base, new_values);
-        write_toml_atomic(&path, &base)
-    }
-    /// Paths used for debugging/information
-    pub fn paths(&self) -> (PathBuf, PathBuf, PathBuf) {
-        (system_config_path(), default_user_config_path(), workspace_config_path())
-    }
+    pub const ENV_NO_CACHE_MANIFEST: &'static str = "POLYSCRIBE_NO_CACHE_MANIFEST";
+    pub const ENV_MANIFEST_TTL_SECONDS: &'static str = "POLYSCRIBE_MANIFEST_TTL_SECONDS";
+    pub const ENV_MODELS_DIR: &'static str = "POLYSCRIBE_MODELS_DIR";
+    pub const ENV_USER_AGENT: &'static str = "POLYSCRIBE_USER_AGENT";
+    pub const ENV_HTTP_TIMEOUT_SECS: &'static str = "POLYSCRIBE_HTTP_TIMEOUT_SECS";
+    pub const ENV_HF_REPO: &'static str = "POLYSCRIBE_HF_REPO";
+    pub const ENV_CACHE_FILENAME: &'static str = "POLYSCRIBE_MANIFEST_CACHE_FILENAME";
+
+    pub const DEFAULT_USER_AGENT: &'static str = "polyscribe/0.1";
+    pub const DEFAULT_DOWNLOADER_UA: &'static str = "polyscribe-model-downloader/1";
+    pub const DEFAULT_HF_REPO: &'static str = "ggerganov/whisper.cpp";
+    pub const DEFAULT_CACHE_FILENAME: &'static str = "hf_manifest_whisper_cpp.json";
+    pub const DEFAULT_HTTP_TIMEOUT_SECS: u64 = 8;
+    pub const DEFAULT_MANIFEST_CACHE_TTL_SECONDS: u64 = 24 * 60 * 60;
+
+    pub fn project_dirs() -> Option<directories::ProjectDirs> {
+        directories::ProjectDirs::from("dev", "polyscribe", "polyscribe")
+    }
+
+    pub fn default_models_dir() -> Option<PathBuf> {
+        Self::project_dirs().map(|d| d.data_dir().join("models"))
+    }
+
+    pub fn default_plugins_dir() -> Option<PathBuf> {
+        Self::project_dirs().map(|d| d.data_dir().join("plugins"))
+    }
+
+    pub fn manifest_cache_dir() -> Option<PathBuf> {
+        Self::project_dirs().map(|d| d.cache_dir().join("manifest"))
+    }
+
+    pub fn bypass_manifest_cache() -> bool {
+        env::var(Self::ENV_NO_CACHE_MANIFEST).is_ok()
+    }
+
+    pub fn manifest_cache_ttl_seconds() -> u64 {
+        env::var(Self::ENV_MANIFEST_TTL_SECONDS)
+            .ok()
+            .and_then(|s| s.parse::<u64>().ok())
+            .unwrap_or(Self::DEFAULT_MANIFEST_CACHE_TTL_SECONDS)
+    }
+
+    pub fn manifest_cache_filename() -> String {
+        env::var(Self::ENV_CACHE_FILENAME)
+            .unwrap_or_else(|_| Self::DEFAULT_CACHE_FILENAME.to_string())
+    }
+
+    pub fn models_dir(cfg: Option<&Config>) -> Option<PathBuf> {
+        if let Ok(env_dir) = env::var(Self::ENV_MODELS_DIR) {
+            if !env_dir.is_empty() {
+                return Some(PathBuf::from(env_dir));
+            }
+        }
+        if let Some(c) = cfg {
+            if let Some(dir) = c.models_dir.clone() {
+                return Some(dir);
+            }
+        }
+        Self::default_models_dir()
+    }
+
+    pub fn user_agent() -> String {
+        env::var(Self::ENV_USER_AGENT).unwrap_or_else(|_| Self::DEFAULT_USER_AGENT.to_string())
+    }
+
+    pub fn downloader_user_agent() -> String {
+        env::var(Self::ENV_USER_AGENT).unwrap_or_else(|_| Self::DEFAULT_DOWNLOADER_UA.to_string())
+    }
+
+    pub fn http_timeout_secs() -> u64 {
+        env::var(Self::ENV_HTTP_TIMEOUT_SECS)
+            .ok()
+            .and_then(|s| s.parse::<u64>().ok())
+            .unwrap_or(Self::DEFAULT_HTTP_TIMEOUT_SECS)
+    }
+
+    pub fn hf_repo() -> String {
+        env::var(Self::ENV_HF_REPO).unwrap_or_else(|_| Self::DEFAULT_HF_REPO.to_string())
+    }
+
+    pub fn hf_api_base_for(repo: &str) -> String {
+        format!("https://huggingface.co/api/models/{}", repo)
+    }
+
+    pub fn manifest_cache_path() -> Option<PathBuf> {
+        let dir = Self::manifest_cache_dir()?;
+        Some(dir.join(Self::manifest_cache_filename()))
+    }
 }
+#[derive(Debug, Clone, Serialize, Deserialize, Default)]
+pub struct Config {
+    pub models_dir: Option<PathBuf>,
+    pub plugins_dir: Option<PathBuf>,
+}
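The lookup order `ConfigService::models_dir` implements (non-empty `POLYSCRIBE_MODELS_DIR` env var, then the config-file value, then the platform default) can be sketched in isolation. The helper below is a simplified stand-in, not the crate's API; a plain `Option<&str>` replaces the env read so the precedence is testable without touching process environment:

```rust
use std::path::PathBuf;

// Precedence: env override (if non-empty) > config value > default.
fn resolve_models_dir(
    env_value: Option<&str>,
    config_value: Option<PathBuf>,
    default: PathBuf,
) -> PathBuf {
    if let Some(dir) = env_value.filter(|s| !s.is_empty()) {
        return PathBuf::from(dir);
    }
    config_value.unwrap_or(default)
}

fn main() {
    let default = PathBuf::from("/data/polyscribe/models");
    // Env var wins over both config and default.
    assert_eq!(
        resolve_models_dir(Some("/override"), Some(PathBuf::from("/cfg")), default.clone()),
        PathBuf::from("/override")
    );
    // Empty env values are ignored, falling through to the config value.
    assert_eq!(
        resolve_models_dir(Some(""), Some(PathBuf::from("/cfg")), default.clone()),
        PathBuf::from("/cfg")
    );
    // Nothing set: platform default.
    assert_eq!(resolve_models_dir(None, None, default.clone()), default);
}
```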

View File

@@ -0,0 +1,31 @@
use thiserror::Error;
#[derive(Debug, Error)]
pub enum Error {
#[error("I/O error: {0}")]
Io(#[from] std::io::Error),
#[error("serde error: {0}")]
Serde(#[from] serde_json::Error),
#[error("toml error: {0}")]
Toml(#[from] toml::de::Error),
#[error("toml ser error: {0}")]
TomlSer(#[from] toml::ser::Error),
#[error("env var error: {0}")]
EnvVar(#[from] std::env::VarError),
#[error("http error: {0}")]
Http(#[from] reqwest::Error),
#[error("other: {0}")]
Other(String),
}
impl From<anyhow::Error> for Error {
fn from(e: anyhow::Error) -> Self {
Error::Other(e.to_string())
}
}
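The `From<anyhow::Error> for Error` impl above flattens foreign errors into `Other(String)` so `?` keeps working at crate boundaries without a dedicated variant per source type. A dependency-free sketch of the same pattern, with `std::fmt::Error` standing in for `anyhow::Error`:

```rust
use std::fmt;

#[derive(Debug)]
enum Error {
    // Dedicated variant: the source error is preserved as-is.
    Io(std::io::Error),
    // Catch-all variant: foreign errors are flattened to their message.
    Other(String),
}

impl From<std::io::Error> for Error {
    fn from(e: std::io::Error) -> Self {
        Error::Io(e)
    }
}

// Stand-in for the `From<anyhow::Error>` catch-all conversion.
impl From<fmt::Error> for Error {
    fn from(e: fmt::Error) -> Self {
        Error::Other(e.to_string())
    }
}

fn main() {
    let e: Error = fmt::Error.into();
    assert!(matches!(e, Error::Other(_)));
}
```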

View File

@@ -1,67 +1,59 @@
 // SPDX-License-Identifier: MIT
-// Copyright (c) 2025 <COPYRIGHT HOLDER>. All rights reserved.
 #![forbid(elided_lifetimes_in_paths)]
 #![forbid(unused_must_use)]
-#![deny(missing_docs)]
 #![warn(clippy::all)]
-//! PolyScribe library: business logic and core types.
-//!
-//! This crate exposes the reusable parts of the PolyScribe CLI as a library.
-//! The binary entry point (main.rs) remains a thin CLI wrapper.
 use std::sync::atomic::{AtomicBool, AtomicU8, Ordering};
-// Global runtime flags
+use crate::prelude::*;
+use anyhow::{Context, anyhow};
+use chrono::Local;
+use std::env;
+use std::path::{Path, PathBuf};
+use std::process::Command;
+#[cfg(unix)]
+use libc::{O_WRONLY, close, dup, dup2, open};
 static QUIET: AtomicBool = AtomicBool::new(false);
 static NO_INTERACTION: AtomicBool = AtomicBool::new(false);
 static VERBOSE: AtomicU8 = AtomicU8::new(0);
 static NO_PROGRESS: AtomicBool = AtomicBool::new(false);
-/// Set quiet mode: when true, non-interactive logs should be suppressed.
 pub fn set_quiet(enabled: bool) {
     QUIET.store(enabled, Ordering::Relaxed);
 }
-/// Return current quiet mode state.
 pub fn is_quiet() -> bool {
     QUIET.load(Ordering::Relaxed)
 }
-/// Set non-interactive mode: when true, interactive prompts must be skipped.
 pub fn set_no_interaction(enabled: bool) {
     NO_INTERACTION.store(enabled, Ordering::Relaxed);
 }
-/// Return current non-interactive state.
 pub fn is_no_interaction() -> bool {
     NO_INTERACTION.load(Ordering::Relaxed)
 }
-/// Set verbose level (0 = normal, 1 = verbose, 2 = super-verbose)
 pub fn set_verbose(level: u8) {
     VERBOSE.store(level, Ordering::Relaxed);
 }
-/// Get current verbose level.
 pub fn verbose_level() -> u8 {
     VERBOSE.load(Ordering::Relaxed)
 }
-/// Disable interactive progress indicators (bars/spinners)
 pub fn set_no_progress(enabled: bool) {
     NO_PROGRESS.store(enabled, Ordering::Relaxed);
 }
-/// Return current no-progress state
 pub fn is_no_progress() -> bool {
     NO_PROGRESS.load(Ordering::Relaxed)
 }
-/// Check whether stdin is connected to a TTY. Used to avoid blocking prompts when not interactive.
 pub fn stdin_is_tty() -> bool {
     use std::io::IsTerminal as _;
     std::io::stdin().is_terminal()
 }
-/// A guard that temporarily redirects stderr to /dev/null on Unix when quiet mode is active.
-/// No-op on non-Unix or when quiet is disabled. Restores stderr on drop.
 pub struct StderrSilencer {
     #[cfg(unix)]
     old_stderr_fd: i32,
@@ -71,7 +63,6 @@ pub struct StderrSilencer {
 }
 impl StderrSilencer {
-    /// Activate stderr silencing if quiet is set and on Unix; otherwise returns a no-op guard.
     pub fn activate_if_quiet() -> Self {
         if !is_quiet() {
             return Self {
@@ -85,7 +76,6 @@ impl StderrSilencer {
         Self::activate()
     }
-    /// Activate stderr silencing unconditionally (used internally); no-op on non-Unix.
     pub fn activate() -> Self {
         #[cfg(unix)]
         unsafe {
@@ -97,11 +87,10 @@ impl StderrSilencer {
                 devnull_fd: -1,
             };
         }
-        // Open /dev/null for writing
         let devnull_cstr = std::ffi::CString::new("/dev/null").unwrap();
         let devnull_fd = open(devnull_cstr.as_ptr(), O_WRONLY);
         if devnull_fd < 0 {
-            close(old_fd);
+            let _ = close(old_fd);
             return Self {
                 active: false,
                 old_stderr_fd: -1,
@@ -109,8 +98,8 @@ impl StderrSilencer {
             };
         }
         if dup2(devnull_fd, 2) < 0 {
-            close(devnull_fd);
-            close(old_fd);
+            let _ = close(devnull_fd);
+            let _ = close(old_fd);
             return Self {
                 active: false,
                 old_stderr_fd: -1,
@@ -120,7 +109,7 @@ impl StderrSilencer {
         Self {
             active: true,
             old_stderr_fd: old_fd,
-            devnull_fd: devnull_fd,
+            devnull_fd,
         }
     }
     #[cfg(not(unix))]
@@ -144,7 +133,6 @@ impl Drop for StderrSilencer {
     }
 }
-/// Run the given closure with stderr temporarily silenced (Unix-only). Returns the closure result.
 pub fn with_suppressed_stderr<F, T>(f: F) -> T
 where
     F: FnOnce() -> T,
@@ -155,13 +143,11 @@ where
     result
 }
-/// Log an error line (always printed).
 #[macro_export]
 macro_rules! elog {
     ($($arg:tt)*) => {{ $crate::ui::error(format!($($arg)*)); }}
 }
-/// Log an informational line using the UI helper unless quiet mode is enabled.
 #[macro_export]
 macro_rules! ilog {
     ($($arg:tt)*) => {{
@@ -169,7 +155,6 @@ macro_rules! ilog {
     }}
 }
-/// Log a debug/trace line when verbose level is at least the given level (u8).
 #[macro_export]
 macro_rules! dlog {
     ($lvl:expr, $($arg:tt)*) => {{
@@ -177,52 +162,28 @@ macro_rules! dlog {
     }}
 }
-/// Backward-compatibility: map old qlog! to ilog!
-#[macro_export]
-macro_rules! qlog {
-    ($($arg:tt)*) => {{ $crate::ilog!($($arg)*); }}
-}
-use anyhow::{Context, Result, anyhow};
-use chrono::Local;
-use std::env;
-use std::fs::create_dir_all;
-use std::path::{Path, PathBuf};
-use std::process::Command;
-#[cfg(unix)]
-use libc::{O_WRONLY, close, dup, dup2, open};
-/// Re-export backend module (GPU/CPU selection and transcription).
 pub mod backend;
-/// Re-export models module (model listing/downloading/updating).
-pub mod models;
-/// Configuration service (XDG + atomic writes)
 pub mod config;
-/// UI helpers
+pub mod models;
+pub mod error;
 pub mod ui;
+pub use error::Error;
+pub mod prelude;
-/// Transcript entry for a single segment.
 #[derive(Debug, serde::Serialize, Clone)]
 pub struct OutputEntry {
-    /// Sequential id in output ordering.
     pub id: u64,
-    /// Speaker label associated with the segment.
     pub speaker: String,
-    /// Start time in seconds.
     pub start: f64,
-    /// End time in seconds.
     pub end: f64,
-    /// Text content.
     pub text: String,
 }
-/// Return a YYYY-MM-DD date prefix string for output file naming.
 pub fn date_prefix() -> String {
     Local::now().format("%Y-%m-%d").to_string()
 }
-/// Format a floating-point number of seconds as SRT timestamp (HH:MM:SS,mmm).
 pub fn format_srt_time(seconds: f64) -> String {
     let total_ms = (seconds * 1000.0).round() as i64;
     let ms = total_ms % 1000;
@@ -233,7 +194,6 @@ pub fn format_srt_time(seconds: f64) -> String {
     format!("{hour:02}:{min:02}:{sec:02},{ms:03}")
 }
-/// Render a list of transcript entries to SRT format.
 pub fn render_srt(entries: &[OutputEntry]) -> String {
     let mut srt = String::new();
     for (index, entry) in entries.iter().enumerate() {
@@ -254,7 +214,8 @@ pub fn render_srt(entries: &[OutputEntry]) -> String {
     srt
 }
-/// Determine the default models directory, honoring POLYSCRIBE_MODELS_DIR override.
+pub mod model_manager;
 pub fn models_dir_path() -> PathBuf {
     if let Ok(env_val) = env::var("POLYSCRIBE_MODELS_DIR") {
         let env_path = PathBuf::from(env_val);
@@ -265,24 +226,23 @@ pub fn models_dir_path() -> PathBuf {
     if cfg!(debug_assertions) {
         return PathBuf::from("models");
     }
-    if let Ok(xdg) = env::var("XDG_DATA_HOME") {
-        if !xdg.is_empty() {
-            return PathBuf::from(xdg).join("polyscribe").join("models");
-        }
-    }
+    if let Ok(xdg) = env::var("XDG_DATA_HOME")
+        && !xdg.is_empty()
+    {
+        return PathBuf::from(xdg).join("polyscribe").join("models");
+    }
-    if let Ok(home) = env::var("HOME") {
-        if !home.is_empty() {
-            return PathBuf::from(home)
-                .join(".local")
-                .join("share")
-                .join("polyscribe")
-                .join("models");
-        }
-    }
+    if let Ok(home) = env::var("HOME")
+        && !home.is_empty()
+    {
+        return PathBuf::from(home)
+            .join(".local")
+            .join("share")
+            .join("polyscribe")
+            .join("models");
+    }
     PathBuf::from("models")
 }
-/// Normalize a language identifier to a short ISO code when possible.
 pub fn normalize_lang_code(input: &str) -> Option<String> {
     let mut lang = input.trim().to_lowercase();
     if lang.is_empty() || lang == "auto" || lang == "c" || lang == "posix" {
@@ -354,72 +314,92 @@ pub fn normalize_lang_code(input: &str) -> Option<String> {
     Some(code.to_string())
 }
-/// Find the Whisper model file path to use.
 pub fn find_model_file() -> Result<PathBuf> {
     if let Ok(path) = env::var("WHISPER_MODEL") {
         let p = PathBuf::from(path);
-        if p.exists() {
-            return Ok(p);
-        } else {
+        if !p.exists() {
             return Err(anyhow!(
-                "WHISPER_MODEL points to non-existing file: {}",
+                "WHISPER_MODEL points to a non-existing path: {}",
                 p.display()
-            ));
+            )
+            .into());
         }
-    }
+        if !p.is_file() {
+            return Err(anyhow!(
+                "WHISPER_MODEL must point to a file, but is not: {}",
+                p.display()
+            )
+            .into());
+        }
+        return Ok(p);
+    }
     let models_dir = models_dir_path();
-    if !models_dir.exists() {
-        create_dir_all(&models_dir).with_context(|| {
-            format!("Failed to create models dir: {}", models_dir.display())
-        })?;
-    }
-    // Heuristic: prefer larger model files and English-only when language hint is en
+    if models_dir.exists() && !models_dir.is_dir() {
+        return Err(anyhow!(
+            "Models path exists but is not a directory: {}",
+            models_dir.display()
+        )
+        .into());
+    }
+    std::fs::create_dir_all(&models_dir).with_context(|| {
+        format!(
+            "Failed to ensure models dir exists: {}",
+            models_dir.display()
+        )
+    })?;
     let mut candidates = Vec::new();
-    for entry in std::fs::read_dir(&models_dir).with_context(|| format!(
-        "Failed to read models dir: {}",
-        models_dir.display()
-    ))? {
+    for entry in std::fs::read_dir(&models_dir)
+        .with_context(|| format!("Failed to read models dir: {}", models_dir.display()))?
+    {
         let entry = entry?;
         let path = entry.path();
-        if !path
+        let is_bin = path
             .extension()
             .and_then(|s| s.to_str())
-            .is_some_and(|s| s.eq_ignore_ascii_case("bin"))
-        {
+            .is_some_and(|s| s.eq_ignore_ascii_case("bin"));
+        if !is_bin {
             continue;
         }
-        if let Ok(md) = std::fs::metadata(&path) {
-            candidates.push((md.len(), path));
-        }
+        let md = match std::fs::metadata(&path) {
+            Ok(m) if m.is_file() => m,
+            _ => continue,
+        };
+        candidates.push((md.len(), path));
     }
     if candidates.is_empty() {
-        // Try default fallback (tiny.en)
         let fallback = models_dir.join("ggml-tiny.en.bin");
-        if fallback.exists() {
+        if fallback.is_file() {
             return Ok(fallback);
         }
         return Err(anyhow!(
-            "No Whisper models found in {}. Please download a model or set WHISPER_MODEL.",
+            "No Whisper model files (*.bin) found in {}. \
+             Please download a model or set WHISPER_MODEL.",
            models_dir.display()
-        ));
+        )
+        .into());
     }
     candidates.sort_by_key(|(size, _)| *size);
-    let (_size, path) = candidates.into_iter().last().unwrap();
+    let (_size, path) = candidates.into_iter().last().expect("non-empty");
     Ok(path)
 }
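When no `WHISPER_MODEL` override is set, `find_model_file` collects `(size, path)` pairs for the `*.bin` files it finds, sorts by size, and takes the largest. That heuristic in isolation (file names and sizes below are illustrative, not real model sizes):

```rust
use std::path::PathBuf;

// Largest-file heuristic: sort candidates by size, take the last one.
fn pick_largest(mut candidates: Vec<(u64, PathBuf)>) -> Option<PathBuf> {
    candidates.sort_by_key(|(size, _)| *size);
    candidates.into_iter().last().map(|(_, path)| path)
}

fn main() {
    let picked = pick_largest(vec![
        (77_700_000, PathBuf::from("ggml-tiny.en.bin")),
        (1_500_000_000, PathBuf::from("ggml-large.bin")),
        (148_000_000, PathBuf::from("ggml-base.bin")),
    ]);
    assert_eq!(picked, Some(PathBuf::from("ggml-large.bin")));
    // Empty candidate list: no model to pick.
    assert_eq!(pick_largest(vec![]), None);
}
```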
/// Decode an audio file into PCM f32 samples using ffmpeg (ffmpeg executable required).
pub fn decode_audio_to_pcm_f32_ffmpeg(audio_path: &Path) -> Result<Vec<f32>> {
let in_path = audio_path
.to_str()
.ok_or_else(|| anyhow!("Audio path must be valid UTF-8: {}", audio_path.display()))?;
let tmp_raw = std::env::temp_dir().join("polyscribe_tmp_input.f32le");
let tmp_raw_str = tmp_raw
.to_str()
.ok_or_else(|| anyhow!("Temp path not valid UTF-8: {}", tmp_raw.display()))?;
let status = Command::new("ffmpeg")
.arg("-hide_banner")
.arg("-loglevel")
@@ -433,16 +413,25 @@ pub fn decode_audio_to_pcm_f32_ffmpeg(audio_path: &Path) -> Result<Vec<f32>> {
.arg("-ar")
.arg("16000")
.arg("-y")
.arg(tmp_raw_str)
.status()
.with_context(|| format!("Failed to invoke ffmpeg to decode: {}", in_path))?;
if !status.success() {
return Err(anyhow!(
"ffmpeg exited with non-zero status when decoding {}",
in_path
)
.into());
}
let raw = std::fs::read(&tmp_raw)
.with_context(|| format!("Failed to read temp PCM file: {}", tmp_raw.display()))?;
let _ = std::fs::remove_file(&tmp_raw);
if raw.len() % 4 != 0 {
return Err(anyhow!("Decoded PCM file length not multiple of 4: {}", raw.len()).into());
}
let mut samples = Vec::with_capacity(raw.len() / 4);
for chunk in raw.chunks_exact(4) {

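The hunk cuts off inside the sample-conversion loop; what it performs — reinterpreting each little-endian 4-byte chunk of ffmpeg's `f32le` output as one `f32` sample — can be sketched in Python (illustrative only):

```python
import struct

# Simulate ffmpeg's raw f32le output: little-endian 32-bit floats, mono.
raw = struct.pack("<4f", 0.0, 1.0, -0.5, 2.5)
assert len(raw) % 4 == 0  # same sanity check the Rust code performs
samples = [struct.unpack("<f", raw[i:i + 4])[0] for i in range(0, len(raw), 4)]
assert samples == [0.0, 1.0, -0.5, 2.5]
```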
View File

@@ -0,0 +1,893 @@
// SPDX-License-Identifier: MIT
use crate::prelude::*;
use crate::ui::BytesProgress;
use anyhow::{anyhow, Context};
use chrono::{DateTime, Utc};
use reqwest::blocking::Client;
use reqwest::header::{
ACCEPT_RANGES, AUTHORIZATION, CONTENT_LENGTH, ETAG, IF_NONE_MATCH, LAST_MODIFIED, RANGE,
};
use serde::{Deserialize, Serialize};
use sha2::{Digest, Sha256};
use std::cmp::min;
use std::collections::BTreeMap;
use std::fs::{self, File, OpenOptions};
use std::io::{Read, Seek, SeekFrom, Write};
use std::path::{Path, PathBuf};
use std::sync::{mpsc, Arc, Mutex};
use std::thread;
use std::time::Duration;
const DEFAULT_CHUNK_SIZE: u64 = 8 * 1024 * 1024; // 8 MiB
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
pub struct ModelRecord {
pub alias: String,
pub repo: String,
pub file: String,
pub revision: Option<String>, // ETag or commit hash
pub sha256: Option<String>,
pub size_bytes: Option<u64>,
pub quant: Option<String>,
pub installed_at: Option<DateTime<Utc>>,
pub last_used: Option<DateTime<Utc>>,
}
#[derive(Debug, Clone, Serialize, Deserialize, Default)]
pub struct Manifest {
pub models: BTreeMap<String, ModelRecord>, // key = alias
}
#[derive(Debug, Clone)]
pub struct Settings {
pub concurrency: usize,
pub limit_rate: Option<u64>, // bytes/sec
pub chunk_size: u64,
}
impl Default for Settings {
fn default() -> Self {
Self {
concurrency: 4,
limit_rate: None,
chunk_size: DEFAULT_CHUNK_SIZE,
}
}
}
#[derive(Debug, Clone)]
pub struct Paths {
pub cache_dir: PathBuf, // $XDG_CACHE_HOME/polyscribe/models
pub config_path: PathBuf, // $XDG_CONFIG_HOME/polyscribe/models.json
}
impl Paths {
pub fn resolve() -> Result<Self> {
if let Ok(over) = std::env::var("POLYSCRIBE_CACHE_DIR") {
if !over.is_empty() {
let cache_dir = PathBuf::from(over).join("models");
let config_path = std::env::var("POLYSCRIBE_CONFIG_DIR")
.map(|p| PathBuf::from(p).join("models.json"))
.unwrap_or_else(|_| default_config_path());
return Ok(Self {
cache_dir,
config_path,
});
}
}
let cache_dir = default_cache_dir();
let config_path = default_config_path();
Ok(Self {
cache_dir,
config_path,
})
}
}
fn default_cache_dir() -> PathBuf {
if let Ok(xdg) = std::env::var("XDG_CACHE_HOME") {
if !xdg.is_empty() {
return PathBuf::from(xdg).join("polyscribe").join("models");
}
}
if let Ok(home) = std::env::var("HOME") {
if !home.is_empty() {
return PathBuf::from(home)
.join(".cache")
.join("polyscribe")
.join("models");
}
}
PathBuf::from("models")
}
fn default_config_path() -> PathBuf {
if let Ok(xdg) = std::env::var("XDG_CONFIG_HOME") {
if !xdg.is_empty() {
return PathBuf::from(xdg).join("polyscribe").join("models.json");
}
}
if let Ok(home) = std::env::var("HOME") {
if !home.is_empty() {
return PathBuf::from(home)
.join(".config")
.join("polyscribe")
.join("models.json");
}
}
PathBuf::from("models.json")
}
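The XDG fallback order implemented by `default_cache_dir` above can be sketched in Python (for illustration only):

```python
from pathlib import Path

def default_cache_dir(env):
    # Resolution order mirrors default_cache_dir():
    # XDG_CACHE_HOME, then HOME/.cache, then a relative "models" dir.
    xdg = env.get("XDG_CACHE_HOME")
    if xdg:
        return Path(xdg) / "polyscribe" / "models"
    home = env.get("HOME")
    if home:
        return Path(home) / ".cache" / "polyscribe" / "models"
    return Path("models")

assert default_cache_dir({"XDG_CACHE_HOME": "/x"}) == Path("/x/polyscribe/models")
assert default_cache_dir({"HOME": "/home/u"}) == Path("/home/u/.cache/polyscribe/models")
assert default_cache_dir({}) == Path("models")
```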
pub trait HttpClient: Send + Sync {
fn head(&self, url: &str, etag: Option<&str>) -> Result<HeadMeta>;
fn get_range(&self, url: &str, start: u64, end_inclusive: u64) -> Result<Vec<u8>>;
fn get_whole_to(&self, url: &str, writer: &mut dyn Write) -> Result<()>;
fn get_from_to(&self, url: &str, start: u64, writer: &mut dyn Write) -> Result<()>;
}
#[derive(Debug, Clone)]
pub struct ReqwestClient {
client: Client,
token: Option<String>,
}
impl ReqwestClient {
pub fn new() -> Result<Self> {
let token = std::env::var("HF_TOKEN").ok().filter(|s| !s.is_empty());
let client = Client::builder()
.user_agent(crate::config::ConfigService::user_agent())
.build()?;
Ok(Self { client, token })
}
fn auth(&self, mut req: reqwest::blocking::RequestBuilder) -> reqwest::blocking::RequestBuilder {
if let Some(t) = &self.token {
req = req.header(AUTHORIZATION, format!("Bearer {}", t));
}
req
}
}
#[derive(Debug, Clone)]
pub struct HeadMeta {
pub len: Option<u64>,
pub etag: Option<String>,
pub last_modified: Option<String>,
pub accept_ranges: bool,
pub not_modified: bool,
pub status: u16,
}
impl HttpClient for ReqwestClient {
fn head(&self, url: &str, etag: Option<&str>) -> Result<HeadMeta> {
let mut req = self.client.head(url);
if let Some(e) = etag {
req = req.header(IF_NONE_MATCH, format!("\"{}\"", e));
}
let resp = self.auth(req).send()?;
let status = resp.status().as_u16();
if status == 304 {
return Ok(HeadMeta {
len: None,
etag: etag.map(|s| s.to_string()),
last_modified: None,
accept_ranges: true,
not_modified: true,
status,
});
}
let len = resp
.headers()
.get(CONTENT_LENGTH)
.and_then(|v| v.to_str().ok())
.and_then(|s| s.parse::<u64>().ok());
let etag = resp
.headers()
.get(ETAG)
.and_then(|v| v.to_str().ok())
.map(|s| s.trim_matches('"').to_string());
let last_modified = resp
.headers()
.get(LAST_MODIFIED)
.and_then(|v| v.to_str().ok())
.map(|s| s.to_string());
let accept_ranges = resp
.headers()
.get(ACCEPT_RANGES)
.and_then(|v| v.to_str().ok())
.map(|s| s.to_ascii_lowercase().contains("bytes"))
.unwrap_or(false);
Ok(HeadMeta {
len,
etag,
last_modified,
accept_ranges,
not_modified: false,
status,
})
}
fn get_range(&self, url: &str, start: u64, end_inclusive: u64) -> Result<Vec<u8>> {
let range_val = format!("bytes={}-{}", start, end_inclusive);
let resp = self
.auth(self.client.get(url))
.header(RANGE, range_val)
.send()?;
if !resp.status().is_success() && resp.status().as_u16() != 206 {
return Err(anyhow!("HTTP {} for ranged GET", resp.status()).into());
}
let mut buf = Vec::new();
let mut r = resp;
r.copy_to(&mut buf)?;
Ok(buf)
}
fn get_whole_to(&self, url: &str, writer: &mut dyn Write) -> Result<()> {
let resp = self.auth(self.client.get(url)).send()?;
if !resp.status().is_success() {
return Err(anyhow!("HTTP {} for GET", resp.status()).into());
}
let mut r = resp;
r.copy_to(writer)?;
Ok(())
}
fn get_from_to(&self, url: &str, start: u64, writer: &mut dyn Write) -> Result<()> {
let mut req = self.auth(self.client.get(url));
if start > 0 {
req = req.header(RANGE, format!("bytes={}-", start));
}
let resp = req.send()?;
if !resp.status().is_success() && resp.status().as_u16() != 206 {
return Err(anyhow!("HTTP {} for ranged GET from {}", resp.status(), start).into());
}
let mut r = resp;
r.copy_to(writer)?;
Ok(())
}
}
pub struct ModelManager<C: HttpClient = ReqwestClient> {
pub paths: Paths,
pub settings: Settings,
client: Arc<C>,
}
impl<C: HttpClient + 'static> ModelManager<C> {
pub fn new_with_client(client: C, settings: Settings) -> Result<Self> {
Ok(Self {
paths: Paths::resolve()?,
settings,
client: Arc::new(client),
})
}
pub fn new(settings: Settings) -> Result<Self>
where
C: Default,
{
Ok(Self {
paths: Paths::resolve()?,
settings,
client: Arc::new(C::default()),
})
}
fn load_manifest(&self) -> Result<Manifest> {
let p = &self.paths.config_path;
if !p.exists() {
return Ok(Manifest::default());
}
let file = File::open(p).with_context(|| format!("open manifest {}", p.display()))?;
let m: Manifest = serde_json::from_reader(file).context("parse manifest")?;
Ok(m)
}
fn save_manifest(&self, m: &Manifest) -> Result<()> {
let p = &self.paths.config_path;
if let Some(dir) = p.parent() {
fs::create_dir_all(dir)
.with_context(|| format!("create config dir {}", dir.display()))?;
}
let tmp = p.with_extension("json.tmp");
let f = OpenOptions::new()
.create(true)
.write(true)
.truncate(true)
.open(&tmp)?;
serde_json::to_writer_pretty(f, m).context("serialize manifest")?;
fs::rename(&tmp, p).with_context(|| format!("atomic rename {} -> {}", tmp.display(), p.display()))?;
Ok(())
}
fn model_path(&self, file: &str) -> PathBuf {
self.paths.cache_dir.join(file)
}
fn compute_sha256(path: &Path) -> Result<String> {
let mut f = File::open(path)?;
let mut hasher = Sha256::new();
let mut buf = [0u8; 64 * 1024];
loop {
let n = f.read(&mut buf)?;
if n == 0 {
break;
}
hasher.update(&buf[..n]);
}
Ok(format!("{:x}", hasher.finalize()))
}
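`compute_sha256` streams the file through the hasher in 64 KiB reads instead of loading it whole; the equivalent in Python:

```python
import hashlib
import io

def compute_sha256(f):
    # Stream in 64 KiB chunks, as the Rust helper does.
    h = hashlib.sha256()
    while chunk := f.read(64 * 1024):
        h.update(chunk)
    return h.hexdigest()

# Well-known digest of the empty input:
assert compute_sha256(io.BytesIO(b"")) == (
    "e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855"
)
assert compute_sha256(io.BytesIO(b"abc")) == hashlib.sha256(b"abc").hexdigest()
```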
pub fn ls(&self) -> Result<Vec<ModelRecord>> {
let m = self.load_manifest()?;
Ok(m.models.values().cloned().collect())
}
pub fn rm(&self, alias: &str) -> Result<bool> {
let mut m = self.load_manifest()?;
if let Some(rec) = m.models.remove(alias) {
let p = self.model_path(&rec.file);
let _ = fs::remove_file(&p);
self.save_manifest(&m)?;
return Ok(true);
}
Ok(false)
}
pub fn verify(&self, alias: &str) -> Result<bool> {
let m = self.load_manifest()?;
let Some(rec) = m.models.get(alias) else { return Ok(false) };
let p = self.model_path(&rec.file);
if !p.exists() { return Ok(false); }
if let Some(expected) = &rec.sha256 {
let actual = Self::compute_sha256(&p)?;
return Ok(&actual == expected);
}
Ok(true)
}
pub fn gc(&self) -> Result<(usize, usize)> {
// Remove files not referenced by manifest; also drop manifest entries whose file is missing
fs::create_dir_all(&self.paths.cache_dir).ok();
let mut m = self.load_manifest()?;
let mut referenced = BTreeMap::new();
for (alias, rec) in &m.models {
referenced.insert(rec.file.clone(), alias.clone());
}
let mut removed_files = 0usize;
if let Ok(rd) = fs::read_dir(&self.paths.cache_dir) {
for ent in rd.flatten() {
let p = ent.path();
if p.is_file() {
let fname = p.file_name().and_then(|s| s.to_str()).unwrap_or("");
if !referenced.contains_key(fname) {
let _ = fs::remove_file(&p);
removed_files += 1;
}
}
}
}
m.models.retain(|_, rec| self.model_path(&rec.file).exists());
let removed_entries = referenced
.keys()
.filter(|f| !self.model_path(f).exists())
.count();
self.save_manifest(&m)?;
Ok((removed_files, removed_entries))
}
pub fn add_or_update(
&self,
alias: &str,
repo: &str,
file: &str,
) -> Result<ModelRecord> {
fs::create_dir_all(&self.paths.cache_dir)
.with_context(|| format!("create cache dir {}", self.paths.cache_dir.display()))?;
let url = format!("https://huggingface.co/{}/resolve/main/{}?download=true", repo, file);
let mut manifest = self.load_manifest()?;
let prev = manifest.models.get(alias).cloned();
let prev_etag = prev.as_ref().and_then(|r| r.revision.clone());
// Fetch remote meta (size/hash) when available to verify the download
let (_api_size, api_sha) = hf_fetch_file_meta(repo, file).unwrap_or((None, None));
let head = self.client.head(&url, prev_etag.as_deref())?;
if head.not_modified {
// up-to-date; ensure record present and touch last_used
let mut rec = prev.ok_or_else(|| anyhow!("not installed yet but server says 304"))?;
rec.last_used = Some(Utc::now());
self.save_touch(&mut manifest, rec.clone())?;
return Ok(rec);
}
// Quick check: if HEAD failed (e.g., 404), report a helpful error before attempting download
if head.status >= 400 {
return Err(anyhow!(
"file not found or inaccessible: repo='{}' file='{}' (HTTP {})\nHint: run `polyscribe models search {} --query {}` to list available files",
repo,
file,
head.status,
repo,
file
).into());
}
let total_len = head.len.ok_or_else(|| anyhow!("missing content-length (HEAD)"))?;
let etag = head.etag.clone();
let dest_tmp = self.model_path(&format!("{}.part", file));
// If a previous cancelled download left a .part file, remove it to avoid clutter/resume.
if dest_tmp.exists() { let _ = fs::remove_file(&dest_tmp); }
// Guard to ensure .part is cleaned up on errors
struct TempGuard { path: PathBuf, armed: bool }
impl TempGuard { fn disarm(&mut self) { self.armed = false; } }
impl Drop for TempGuard {
fn drop(&mut self) {
if self.armed {
let _ = fs::remove_file(&self.path);
}
}
}
let mut _tmp_guard = TempGuard { path: dest_tmp.clone(), armed: true };
let dest_final = self.model_path(file);
// Do not resume after cancellation; start fresh to avoid stale .part files
let start_from = 0u64;
// Open tmp for write
let f = OpenOptions::new().create(true).write(true).read(true).open(&dest_tmp)?;
f.set_len(total_len)?; // pre-allocate for random writes
let f = Arc::new(Mutex::new(f));
// Create progress bar
let mut progress = BytesProgress::start(total_len, &format!("Downloading {}", file), start_from);
// Create work chunks
let chunk_size = self.settings.chunk_size;
let mut chunks = Vec::new();
let mut pos = start_from;
while pos < total_len {
let end = min(total_len - 1, pos + chunk_size - 1);
chunks.push((pos, end));
pos = end + 1;
}
// Attempt concurrent ranged download; on failure, fallback to streaming GET
let mut ranged_failed = false;
if head.accept_ranges && self.settings.concurrency > 1 {
let (work_tx, work_rx) = mpsc::channel::<(u64, u64)>();
let (prog_tx, prog_rx) = mpsc::channel::<u64>();
for ch in chunks {
work_tx.send(ch).unwrap();
}
drop(work_tx);
let rx = Arc::new(Mutex::new(work_rx));
let workers = self.settings.concurrency.max(1);
let mut handles = Vec::new();
for _ in 0..workers {
let rx = rx.clone();
let url = url.clone();
let client = self.client.clone();
let f = f.clone();
let limit = self.settings.limit_rate;
let prog_tx = prog_tx.clone();
let handle = thread::spawn(move || -> Result<()> {
loop {
let next = {
let guard = rx.lock().unwrap();
guard.recv().ok()
};
let Some((start, end)) = next else { break; };
let data = client.get_range(&url, start, end)?;
if let Some(max_bps) = limit {
let dur = Duration::from_secs_f64((data.len() as f64) / (max_bps as f64));
if dur > Duration::from_millis(1) {
thread::sleep(dur);
}
}
let mut guard = f.lock().unwrap();
guard.seek(SeekFrom::Start(start))?;
guard.write_all(&data)?;
let _ = prog_tx.send(data.len() as u64);
}
Ok(())
});
handles.push(handle);
}
drop(prog_tx);
for delta in prog_rx {
progress.inc(delta);
}
// Remember the first ranged error; on failure we fall back to a streaming GET below.
let mut ranged_err: Option<anyhow::Error> = None;
for h in handles {
match h.join() {
Ok(Ok(())) => {}
Ok(Err(e)) => { ranged_failed = true; if ranged_err.is_none() { ranged_err = Some(e.into()); } }
Err(_) => { ranged_failed = true; if ranged_err.is_none() { ranged_err = Some(anyhow!("worker panicked")); } }
}
}
} else {
ranged_failed = true;
}
if ranged_failed {
// Restart progress if we are abandoning previous partial state
if start_from > 0 {
progress.stop("retrying as streaming");
progress = BytesProgress::start(total_len, &format!("Downloading {}", file), 0);
}
// Fallback to streaming GET; try URL with and without ?download=true
let mut try_urls = vec![url.clone()];
if let Some((base, _qs)) = url.split_once('?') {
try_urls.push(base.to_string());
} else {
try_urls.push(format!("{}?download=true", url));
}
// Fresh temp file for streaming
let mut wf = OpenOptions::new().create(true).write(true).truncate(true).open(&dest_tmp)?;
let mut ok = false;
let mut last_err: Option<anyhow::Error> = None;
// Counting writer to update progress inline
struct CountingWriter<'a, 'b> {
inner: &'a mut File,
progress: &'b mut BytesProgress,
}
impl<'a, 'b> Write for CountingWriter<'a, 'b> {
fn write(&mut self, buf: &[u8]) -> std::io::Result<usize> {
let n = self.inner.write(buf)?;
self.progress.inc(n as u64);
Ok(n)
}
fn flush(&mut self) -> std::io::Result<()> { self.inner.flush() }
}
let mut cw = CountingWriter { inner: &mut wf, progress: &mut progress };
for u in try_urls {
// For fallback, stream from scratch to ensure integrity
let res = self.client.get_whole_to(&u, &mut cw);
match res {
Ok(()) => { ok = true; break; }
Err(e) => { last_err = Some(e.into()); }
}
}
if !ok {
if let Some(e) = last_err { return Err(e.into()); }
return Err(anyhow!("download failed (ranged and streaming)").into());
}
}
progress.stop("download complete");
// Verify integrity
let sha = Self::compute_sha256(&dest_tmp)?;
if let Some(expected) = api_sha.as_ref() {
if &sha != expected {
return Err(anyhow!(
"sha256 mismatch (expected {}, got {})",
expected,
sha
).into());
}
} else if prev.as_ref().map(|r| r.file.eq(file)).unwrap_or(false) {
if let Some(expected) = prev.as_ref().and_then(|r| r.sha256.as_ref()) {
if &sha != expected {
return Err(anyhow!("sha256 mismatch").into());
}
}
}
// Atomic rename
fs::rename(&dest_tmp, &dest_final).with_context(|| format!("rename {} -> {}", dest_tmp.display(), dest_final.display()))?;
// Disarm guard; .part has been moved or cleaned
_tmp_guard.disarm();
let rec = ModelRecord {
alias: alias.to_string(),
repo: repo.to_string(),
file: file.to_string(),
revision: etag,
sha256: Some(sha.clone()),
size_bytes: Some(total_len),
quant: infer_quant(file),
installed_at: Some(Utc::now()),
last_used: Some(Utc::now()),
};
self.save_touch(&mut manifest, rec.clone())?;
Ok(rec)
}
fn save_touch(&self, manifest: &mut Manifest, rec: ModelRecord) -> Result<()> {
manifest.models.insert(rec.alias.clone(), rec);
self.save_manifest(manifest)
}
}
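`add_or_update` splits the download into inclusive byte ranges of `chunk_size` bytes for the ranged `GET`s; the partition logic, sketched in Python:

```python
CHUNK = 8 * 1024 * 1024  # mirrors DEFAULT_CHUNK_SIZE (8 MiB)

def chunks(total, start=0, chunk=CHUNK):
    # Inclusive byte ranges, as used for "Range: bytes=start-end" requests.
    out = []
    pos = start
    while pos < total:
        end = min(total - 1, pos + chunk - 1)
        out.append((pos, end))
        pos = end + 1
    return out

assert chunks(10, chunk=4) == [(0, 3), (4, 7), (8, 9)]
assert chunks(8, chunk=4) == [(0, 3), (4, 7)]
assert chunks(0) == []
```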
fn infer_quant(file: &str) -> Option<String> {
// Try to extract a Q* token, e.g. Q5_K_M from filename
let upper = file.to_ascii_uppercase();
if let Some(pos) = upper.find('Q') {
let tail = &upper[pos..];
let token: String = tail
.chars()
.take_while(|c| c.is_ascii_uppercase() || c.is_ascii_digit() || *c == '_' || *c == '-')
.collect();
if token.len() >= 2 {
return Some(token);
}
}
None
}
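The token extraction above is roughly equivalent to this regex sketch (Python, illustrative):

```python
import re

def infer_quant(filename):
    # Pull a Q… token (e.g. Q5_K_M) out of an uppercased filename,
    # accepting letters, digits, '_' and '-' after the leading Q.
    m = re.search(r"Q[A-Z0-9_\-]+", filename.upper())
    return m.group(0) if m else None

assert infer_quant("ggml-model-q5_k_m.gguf") == "Q5_K_M"
assert infer_quant("ggml-tiny.en.bin") is None
```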
impl Default for ReqwestClient {
fn default() -> Self {
Self::new().expect("reqwest client")
}
}
// Hugging Face API types for file metadata
#[derive(Debug, Deserialize)]
struct ApiHfLfs {
oid: Option<String>,
size: Option<u64>,
sha256: Option<String>,
}
#[derive(Debug, Deserialize)]
struct ApiHfFile {
rfilename: String,
size: Option<u64>,
sha256: Option<String>,
lfs: Option<ApiHfLfs>,
}
#[derive(Debug, Deserialize)]
struct ApiHfModelInfo {
siblings: Option<Vec<ApiHfFile>>,
files: Option<Vec<ApiHfFile>>,
}
fn pick_sha_from_file(f: &ApiHfFile) -> Option<String> {
if let Some(s) = &f.sha256 { return Some(s.to_string()); }
if let Some(l) = &f.lfs {
if let Some(s) = &l.sha256 { return Some(s.to_string()); }
if let Some(oid) = &l.oid { return oid.strip_prefix("sha256:").map(|s| s.to_string()); }
}
None
}
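`pick_sha_from_file` prefers an explicit `sha256` field, then falls back to the LFS `oid`, which Hugging Face encodes as `sha256:<hex digest>`; the fallback parse in Python:

```python
def sha_from_oid(oid):
    # LFS oids look like "sha256:<hex digest>"; anything else yields None.
    prefix = "sha256:"
    return oid[len(prefix):] if oid.startswith(prefix) else None

assert sha_from_oid("sha256:deadbeef") == "deadbeef"
assert sha_from_oid("md5:deadbeef") is None
```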
fn hf_fetch_file_meta(repo: &str, target: &str) -> Result<(Option<u64>, Option<String>)> {
let token = std::env::var("HF_TOKEN").ok().filter(|s| !s.is_empty());
let client = Client::builder()
.user_agent(crate::config::ConfigService::user_agent())
.build()?;
let base = crate::config::ConfigService::hf_api_base_for(repo);
let urls = [base.clone(), format!("{}?expand=files", base)];
for url in urls {
let mut req = client.get(&url);
if let Some(t) = &token { req = req.header(AUTHORIZATION, format!("Bearer {}", t)); }
let resp = req.send()?;
if !resp.status().is_success() { continue; }
let info: ApiHfModelInfo = resp.json()?;
let list = info.files.or(info.siblings).unwrap_or_default();
for f in list {
let name = f.rfilename.rsplit('/').next().unwrap_or(&f.rfilename);
if name.eq_ignore_ascii_case(target) {
let sz = f.size.or_else(|| f.lfs.as_ref().and_then(|l| l.size));
let sha = pick_sha_from_file(&f);
return Ok((sz, sha));
}
}
}
Err(anyhow!("file not found in HF API").into())
}
/// Fetch remote metadata (size, sha256) for a single file in a HF repo.
pub fn fetch_file_meta(repo: &str, file: &str) -> Result<(Option<u64>, Option<String>)> {
hf_fetch_file_meta(repo, file)
}
/// Search a Hugging Face repo for GGUF/BIN files via API. Returns file names only.
pub fn search_repo(repo: &str, query: Option<&str>) -> Result<Vec<String>> {
let token = std::env::var("HF_TOKEN").ok().filter(|s| !s.is_empty());
let client = Client::builder()
.user_agent(crate::config::ConfigService::user_agent())
.build()?;
let base = crate::config::ConfigService::hf_api_base_for(repo);
let mut urls = vec![base.clone(), format!("{}?expand=files", base)];
let mut files = Vec::<String>::new();
for url in urls.drain(..) {
let mut req = client.get(&url);
if let Some(t) = &token { req = req.header(AUTHORIZATION, format!("Bearer {}", t)); }
let resp = req.send()?;
if !resp.status().is_success() { continue; }
let info: ApiHfModelInfo = resp.json()?;
let list = info.files.or(info.siblings).unwrap_or_default();
for f in list {
if f.rfilename.ends_with(".gguf") || f.rfilename.ends_with(".bin") {
let name = f.rfilename.rsplit('/').next().unwrap_or(&f.rfilename).to_string();
if !files.contains(&name) { files.push(name); }
}
}
if !files.is_empty() { break; }
}
if let Some(q) = query { let qlc = q.to_ascii_lowercase(); files.retain(|f| f.to_ascii_lowercase().contains(&qlc)); }
files.sort();
Ok(files)
}
/// List repo files with optional size metadata for GGUF/BIN entries.
pub fn list_repo_files_with_meta(repo: &str) -> Result<Vec<(String, Option<u64>)>> {
let token = std::env::var("HF_TOKEN").ok().filter(|s| !s.is_empty());
let client = Client::builder()
.user_agent(crate::config::ConfigService::user_agent())
.build()?;
let base = crate::config::ConfigService::hf_api_base_for(repo);
for url in [base.clone(), format!("{}?expand=files", base)] {
let mut req = client.get(&url);
if let Some(t) = &token { req = req.header(AUTHORIZATION, format!("Bearer {}", t)); }
let resp = req.send()?;
if !resp.status().is_success() { continue; }
let info: ApiHfModelInfo = resp.json()?;
let list = info.files.or(info.siblings).unwrap_or_default();
let mut out = Vec::new();
for f in list {
if !(f.rfilename.ends_with(".gguf") || f.rfilename.ends_with(".bin")) { continue; }
let name = f.rfilename.rsplit('/').next().unwrap_or(&f.rfilename).to_string();
let size = f.size.or_else(|| f.lfs.as_ref().and_then(|l| l.size));
out.push((name, size));
}
if !out.is_empty() { return Ok(out); }
}
Ok(Vec::new())
}
/// Fallback: HEAD request for a single file to retrieve Content-Length (size).
pub fn head_len_for_file(repo: &str, file: &str) -> Option<u64> {
let token = std::env::var("HF_TOKEN").ok().filter(|s| !s.is_empty());
let client = Client::builder()
.user_agent(crate::config::ConfigService::user_agent())
.build().ok()?;
let mut urls = Vec::new();
urls.push(format!("https://huggingface.co/{}/resolve/main/{}?download=true", repo, file));
urls.push(format!("https://huggingface.co/{}/resolve/main/{}", repo, file));
for url in urls {
let mut req = client.head(&url);
if let Some(t) = &token { req = req.header(AUTHORIZATION, format!("Bearer {}", t)); }
if let Ok(resp) = req.send() {
if resp.status().is_success() {
if let Some(len) = resp.headers().get(CONTENT_LENGTH)
.and_then(|v| v.to_str().ok())
.and_then(|s| s.parse::<u64>().ok())
{ return Some(len); }
}
}
}
None
}
#[cfg(test)]
mod tests {
use super::*;
use std::env;
use tempfile::TempDir;
#[derive(Clone)]
struct StubHttp {
data: Arc<Vec<u8>>,
etag: Arc<Option<String>>,
accept_ranges: bool,
}
impl HttpClient for StubHttp {
fn head(&self, _url: &str, etag: Option<&str>) -> Result<HeadMeta> {
let not_modified = etag.is_some() && self.etag.as_ref().as_deref() == etag;
Ok(HeadMeta {
len: Some(self.data.len() as u64),
etag: self.etag.as_ref().clone(),
last_modified: None,
accept_ranges: self.accept_ranges,
not_modified,
status: if not_modified { 304 } else { 200 },
})
}
fn get_range(&self, _url: &str, start: u64, end_inclusive: u64) -> Result<Vec<u8>> {
let s = start as usize;
let e = (end_inclusive as usize) + 1;
Ok(self.data[s..e].to_vec())
}
fn get_whole_to(&self, _url: &str, writer: &mut dyn Write) -> Result<()> {
writer.write_all(&self.data)?;
Ok(())
}
fn get_from_to(&self, _url: &str, start: u64, writer: &mut dyn Write) -> Result<()> {
let s = start as usize;
writer.write_all(&self.data[s..])?;
Ok(())
}
}
fn setup_env(cache: &Path, cfg: &Path) {
unsafe {
env::set_var("POLYSCRIBE_CACHE_DIR", cache.to_string_lossy().to_string());
env::set_var(
"POLYSCRIBE_CONFIG_DIR",
cfg.parent().unwrap().to_string_lossy().to_string(),
);
}
}
#[test]
fn test_manifest_roundtrip() {
let temp = TempDir::new().unwrap();
let cache = temp.path().join("cache");
let cfg = temp.path().join("config").join("models.json");
setup_env(&cache, &cfg);
let client = StubHttp {
data: Arc::new(vec![0u8; 1024]),
etag: Arc::new(Some("etag123".into())),
accept_ranges: true,
};
let mm: ModelManager<StubHttp> = ModelManager::new_with_client(client, Settings::default()).unwrap();
let m = mm.load_manifest().unwrap();
assert!(m.models.is_empty());
let rec = ModelRecord {
alias: "tiny".into(),
repo: "foo/bar".into(),
file: "gguf-tiny.bin".into(),
revision: Some("etag123".into()),
sha256: None,
size_bytes: None,
quant: None,
installed_at: None,
last_used: None,
};
let mut m2 = Manifest::default();
mm.save_touch(&mut m2, rec.clone()).unwrap();
let m3 = mm.load_manifest().unwrap();
assert!(m3.models.contains_key("tiny"));
}
#[test]
fn test_add_verify_update_gc() {
let temp = TempDir::new().unwrap();
let cache = temp.path().join("cache");
let cfg_dir = temp.path().join("config");
let cfg = cfg_dir.join("models.json");
setup_env(&cache, &cfg);
let data = (0..1024 * 1024u32).flat_map(|i| i.to_le_bytes()).collect::<Vec<u8>>();
let etag = Some("abc123".to_string());
let client = StubHttp { data: Arc::new(data), etag: Arc::new(etag), accept_ranges: true };
let mm: ModelManager<StubHttp> = ModelManager::new_with_client(client.clone(), Settings{ concurrency: 3, ..Default::default() }).unwrap();
// add
let rec = mm.add_or_update("tiny", "gguf/models", "gguf-tiny-q4_0.bin").unwrap();
assert_eq!(rec.alias, "tiny");
assert!(mm.verify("tiny").unwrap());
// update (304)
let rec2 = mm.add_or_update("tiny", "gguf/models", "gguf-tiny-q4_0.bin").unwrap();
assert_eq!(rec2.alias, "tiny");
// gc (nothing to remove)
let (files_removed, entries_removed) = mm.gc().unwrap();
assert_eq!(files_removed, 0);
assert_eq!(entries_removed, 0);
// rm
assert!(mm.rm("tiny").unwrap());
assert!(!mm.rm("tiny").unwrap());
}
}
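Given the serde derives on `Manifest` and `ModelRecord`, a persisted `models.json` would look roughly like this (all field values are illustrative; `DateTime<Utc>` serializes as RFC 3339):

```json
{
  "models": {
    "tiny": {
      "alias": "tiny",
      "repo": "example-org/example-repo",
      "file": "ggml-tiny.en.bin",
      "revision": "abc123",
      "sha256": "e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855",
      "size_bytes": 77700000,
      "quant": null,
      "installed_at": "2025-08-27T18:00:00Z",
      "last_used": "2025-08-27T18:00:00Z"
    }
  }
}
```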

File diff suppressed because it is too large

View File

@@ -0,0 +1,7 @@
pub use crate::backend::*;
pub use crate::config::*;
pub use crate::error::Error;
pub use crate::models::*;
pub use crate::ui::*;
pub type Result<T, E = Error> = std::result::Result<T, E>;

View File

@@ -1,87 +1,329 @@
// SPDX-License-Identifier: MIT
pub mod progress;
use std::io;
use std::io::IsTerminal;
use std::io::Write as _;
use std::time::{Duration, Instant};
pub fn info(msg: impl AsRef<str>) {
let m = msg.as_ref();
let _ = cliclack::log::info(m);
}
pub fn warn(msg: impl AsRef<str>) {
let m = msg.as_ref();
let _ = cliclack::log::warning(m);
}
pub fn error(msg: impl AsRef<str>) {
let m = msg.as_ref();
let _ = cliclack::log::error(m);
}
pub fn success(msg: impl AsRef<str>) {
let m = msg.as_ref();
let _ = cliclack::log::success(m);
}
pub fn note(prompt: impl AsRef<str>, message: impl AsRef<str>) {
let _ = cliclack::note(prompt.as_ref(), message.as_ref());
}
pub fn intro(title: impl AsRef<str>) {
let _ = cliclack::intro(title.as_ref());
}
pub fn outro(msg: impl AsRef<str>) {
let _ = cliclack::outro(msg.as_ref());
}
pub fn println_above_bars(line: impl AsRef<str>) {
let _ = cliclack::log::info(line.as_ref());
}
pub fn prompt_input(prompt: &str, default: Option<&str>) -> io::Result<String> {
if crate::is_no_interaction() || !crate::stdin_is_tty() {
return Ok(default.unwrap_or("").to_string());
}
let mut q = cliclack::input(prompt);
if let Some(def) = default {
q = q.default_input(def);
}
q.interact().map_err(|e| io::Error::other(e.to_string()))
}
pub fn prompt_select(prompt: &str, items: &[&str]) -> io::Result<usize> {
if crate::is_no_interaction() || !crate::stdin_is_tty() {
return Err(io::Error::other("interactive prompt disabled"));
}
let mut sel = cliclack::select::<usize>(prompt);
for (idx, label) in items.iter().enumerate() {
sel = sel.item(idx, *label, "");
}
sel.interact().map_err(|e| io::Error::other(e.to_string()))
}
pub fn prompt_multi_select(
prompt: &str,
items: &[&str],
defaults: Option<&[bool]>,
) -> io::Result<Vec<usize>> {
if crate::is_no_interaction() || !crate::stdin_is_tty() {
return Err(io::Error::other("interactive prompt disabled"));
}
let mut ms = cliclack::multiselect::<usize>(prompt);
for (idx, label) in items.iter().enumerate() {
ms = ms.item(idx, *label, "");
}
if let Some(def) = defaults {
let selected: Vec<usize> = def
.iter()
.enumerate()
.filter_map(|(i, &on)| if on { Some(i) } else { None })
.collect();
if !selected.is_empty() {
ms = ms.initial_values(selected);
}
}
ms.interact().map_err(|e| io::Error::other(e.to_string()))
}
pub fn prompt_confirm(prompt: &str, default: bool) -> io::Result<bool> {
if crate::is_no_interaction() || !crate::stdin_is_tty() {
return Ok(default);
}
let q = cliclack::confirm(prompt).initial_value(default);
q.interact().map_err(|e| io::Error::other(e.to_string()))
}
pub fn prompt_password(prompt: &str) -> io::Result<String> {
if crate::is_no_interaction() || !crate::stdin_is_tty() {
return Err(io::Error::other(
"password prompt disabled in non-interactive mode",
));
}
let mut q = cliclack::password(prompt);
q.interact().map_err(|e| io::Error::other(e.to_string()))
}
pub fn prompt_input_validated<F>(
prompt: &str,
default: Option<&str>,
validate: F,
) -> io::Result<String>
where
F: Fn(&str) -> Result<(), String> + 'static,
{
if crate::is_no_interaction() || !crate::stdin_is_tty() {
if let Some(def) = default {
return Ok(def.to_string());
}
return Err(io::Error::other("interactive prompt disabled"));
}
let mut q = cliclack::input(prompt);
if let Some(def) = default {
q = q.default_input(def);
}
q.validate(move |s: &String| validate(s))
.interact()
.map_err(|e| io::Error::other(e.to_string()))
}
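Every helper above shares the same non-interactive guard: when `--no-interaction` is set or stdin is not a TTY, the function returns a default or an error instead of blocking. A std-only sketch of that pattern (the `interactive` flag stands in for the crate's `is_no_interaction()`/`stdin_is_tty()` checks; `confirm_or_default` is a hypothetical name, not part of the crate):

```rust
// Fallback pattern used by the prompt helpers: when prompting is
// disabled, return the caller-supplied default instead of asking.
fn confirm_or_default<F: FnOnce() -> bool>(interactive: bool, ask: F, default: bool) -> bool {
    if interactive { ask() } else { default }
}

fn main() {
    // Non-interactive: the prompt closure is never invoked.
    assert!(confirm_or_default(false, || panic!("no prompt in CI"), true));
    // Interactive: the (stubbed) prompt decides.
    assert!(!confirm_or_default(true, || false, true));
    println!("ok");
}
```

This is why `prompt_confirm` can safely return `Ok(default)` in CI while `prompt_password` must error: a password has no safe default.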
pub struct Spinner(cliclack::ProgressBar);
impl Spinner {
pub fn start(text: impl AsRef<str>) -> Self {
if crate::is_no_progress() || crate::is_no_interaction() || !std::io::stderr().is_terminal()
{
let _ = cliclack::log::info(text.as_ref());
let s = cliclack::spinner();
Self(s)
} else {
let s = cliclack::spinner();
s.start(text.as_ref());
Self(s)
}
}
pub fn stop(self, text: impl AsRef<str>) {
let s = self.0;
if crate::is_no_progress() {
let _ = cliclack::log::info(text.as_ref());
} else {
s.stop(text.as_ref());
}
}
pub fn success(self, text: impl AsRef<str>) {
let s = self.0;
if crate::is_no_progress() {
let _ = cliclack::log::success(text.as_ref());
} else {
s.stop(text.as_ref());
}
}
    pub fn error(self, text: impl AsRef<str>) {
        let s = self.0;
        if crate::is_no_progress() {
            let _ = cliclack::log::error(text.as_ref());
        } else {
            s.error(text.as_ref());
        }
    }
}

pub struct BytesProgress {
    enabled: bool,
    total: u64,
    current: u64,
    started: Instant,
    last_msg: Instant,
    width: usize,
    // Sticky ETA to carry through zero-speed stalls
    last_eta_secs: Option<f64>,
}

impl BytesProgress {
    pub fn start(total: u64, text: &str, initial: u64) -> Self {
        let enabled = !(crate::is_no_progress()
            || crate::is_no_interaction()
            || !std::io::stderr().is_terminal()
            || total == 0);
        if !enabled {
let _ = cliclack::log::info(text);
}
let mut me = Self {
enabled,
total,
current: initial.min(total),
started: Instant::now(),
last_msg: Instant::now(),
width: 40,
last_eta_secs: None,
};
me.draw();
me
}
fn human_bytes(n: u64) -> String {
const KB: f64 = 1024.0;
const MB: f64 = 1024.0 * KB;
const GB: f64 = 1024.0 * MB;
let x = n as f64;
if x >= GB {
format!("{:.2} GiB", x / GB)
} else if x >= MB {
format!("{:.2} MiB", x / MB)
} else if x >= KB {
format!("{:.2} KiB", x / KB)
} else {
format!("{} B", n)
}
}
// Elapsed formatting is used for stable, finite durations. For ETA, we guard
// against zero-speed or unstable estimates separately via `format_eta`.
fn refresh_allowed(&mut self) -> (f64, f64) {
let now = Instant::now();
let since_last = now.duration_since(self.last_msg);
if since_last < Duration::from_millis(100) {
// Too soon to refresh; keep previous ETA if any
let eta = self.last_eta_secs.unwrap_or(f64::INFINITY);
return (0.0, eta);
}
self.last_msg = now;
let elapsed = now.duration_since(self.started).as_secs_f64().max(0.001);
let speed = (self.current as f64) / elapsed;
let remaining = self.total.saturating_sub(self.current) as f64;
// If speed is effectively zero, carry ETA forward and add wall time.
const EPS: f64 = 1e-6;
let eta = if speed <= EPS {
let prev = self.last_eta_secs.unwrap_or(f64::INFINITY);
if prev.is_finite() {
prev + since_last.as_secs_f64()
} else {
prev
}
} else {
remaining / speed
};
// Remember only finite ETAs to use during stalls
if eta.is_finite() {
self.last_eta_secs = Some(eta);
}
(speed, eta)
}
fn format_elapsed(seconds: f64) -> String {
let total = seconds.round() as u64;
let h = total / 3600;
let m = (total % 3600) / 60;
let s = total % 60;
if h > 0 { format!("{:02}:{:02}:{:02}", h, m, s) } else { format!("{:02}:{:02}", m, s) }
}
fn format_eta(seconds: f64) -> String {
// If ETA is not finite (e.g., divide-by-zero speed) or unreasonably large,
// show a placeholder rather than overflowing into huge values.
if !seconds.is_finite() {
return "".to_string();
}
// Cap ETA display to 99:59:59 to avoid silly numbers; beyond that, show placeholder.
const CAP_SECS: f64 = 99.0 * 3600.0 + 59.0 * 60.0 + 59.0;
if seconds > CAP_SECS {
return "".to_string();
}
Self::format_elapsed(seconds)
}
fn draw(&mut self) {
if !self.enabled { return; }
let (speed, eta) = self.refresh_allowed();
let elapsed = Instant::now().duration_since(self.started).as_secs_f64();
// Build bar
let width = self.width.max(10);
let filled = ((self.current as f64 / self.total.max(1) as f64) * width as f64).round() as usize;
let filled = filled.min(width);
let mut bar = String::with_capacity(width);
for _ in 0..filled { bar.push('■'); }
for _ in filled..width { bar.push('□'); }
let line = format!(
"[{}] {} [{}] ({}/{} at {}/s)",
Self::format_elapsed(elapsed),
bar,
Self::format_eta(eta),
Self::human_bytes(self.current),
Self::human_bytes(self.total),
Self::human_bytes(speed.max(0.0) as u64),
);
eprint!("\r{}\x1b[K", line);
let _ = io::stderr().flush();
}
pub fn inc(&mut self, delta: u64) {
self.current = self.current.saturating_add(delta).min(self.total);
self.draw();
}
pub fn stop(mut self, text: &str) {
if self.enabled {
self.draw();
eprintln!();
} else {
let _ = cliclack::log::info(text);
}
}
pub fn error(mut self, text: &str) {
if self.enabled {
self.draw();
eprintln!();
let _ = cliclack::log::error(text);
} else {
let _ = cliclack::log::error(text);
}
}
} }
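The `format_elapsed`/`format_eta` rules above can be exercised in isolation. A standalone sketch (a re-statement of the logic shown in the diff, not an import of the crate): finite ETAs render as `HH:MM:SS` (or `MM:SS` under an hour), while non-finite or absurdly large ETAs render as an empty placeholder.

```rust
// Elapsed time formats unconditionally; ETA is guarded against
// divide-by-zero speeds (infinite) and values beyond 99:59:59.
fn format_elapsed(seconds: f64) -> String {
    let total = seconds.round() as u64;
    let (h, m, s) = (total / 3600, (total % 3600) / 60, total % 60);
    if h > 0 {
        format!("{:02}:{:02}:{:02}", h, m, s)
    } else {
        format!("{:02}:{:02}", m, s)
    }
}

fn format_eta(seconds: f64) -> String {
    const CAP_SECS: f64 = 99.0 * 3600.0 + 59.0 * 60.0 + 59.0;
    if !seconds.is_finite() || seconds > CAP_SECS {
        return String::new();
    }
    format_elapsed(seconds)
}

fn main() {
    assert_eq!(format_elapsed(3725.0), "01:02:05");
    assert_eq!(format_eta(f64::INFINITY), ""); // stalled: show placeholder
    println!("ok");
}
```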


@@ -1,81 +1,122 @@
// SPDX-License-Identifier: MIT
use std::io::IsTerminal as _;

pub struct FileProgress {
    enabled: bool,
    file_bars: Vec<cliclack::ProgressBar>,
    total_bar: Option<cliclack::ProgressBar>,
    completed: usize,
    total_file_count: usize,
}

impl FileProgress {
    pub fn new(enabled: bool) -> Self {
        Self {
            enabled,
            file_bars: Vec::new(),
            total_bar: None,
            completed: 0,
            total_file_count: 0,
        }
    }

    pub fn default_for_files(file_count: usize) -> Self {
        let enabled = file_count > 1
            && std::io::stderr().is_terminal()
            && !crate::is_quiet()
            && !crate::is_no_progress();
        Self::new(enabled)
    }

    pub fn init_files(&mut self, labels: &[String]) {
        self.total_file_count = labels.len();
        if !self.enabled || labels.len() <= 1 {
            self.enabled = false;
            return;
        }
        let total = cliclack::progress_bar(labels.len() as u64);
        total.start("Total");
        self.total_bar = Some(total);
        for label in labels {
            let pb = cliclack::progress_bar(100);
            pb.start(label);
            self.file_bars.push(pb);
        }
    }

    pub fn is_enabled(&self) -> bool {
        self.enabled
    }

    pub fn set_file_message(&mut self, idx: usize, message: &str) {
        if !self.enabled {
            return;
        }
        if let Some(pb) = self.file_bars.get_mut(idx) {
            pb.set_message(message);
        }
    }

    pub fn set_file_percent(&mut self, idx: usize, percent: u64) {
        if !self.enabled {
            return;
        }
        if let Some(pb) = self.file_bars.get_mut(idx) {
            let p = percent.min(100);
            pb.set_message(format!("{p}%"));
        }
    }

    pub fn mark_file_done(&mut self, idx: usize) {
        if !self.enabled {
            return;
        }
        if let Some(pb) = self.file_bars.get_mut(idx) {
            pb.stop("done");
        }
        self.completed += 1;
        if let Some(total) = &mut self.total_bar {
            total.inc(1);
            if self.completed >= self.total_file_count {
                total.stop("all done");
            }
        }
    }

    pub fn finish_total(&mut self, message: &str) {
        if !self.enabled {
            return;
        }
        if let Some(total) = &mut self.total_bar {
            total.stop(message);
        }
    }
}

#[derive(Debug)]
pub struct ProgressReporter {
    non_interactive: bool,
}

impl ProgressReporter {
    pub fn new(non_interactive: bool) -> Self {
        Self { non_interactive }
    }

    pub fn step(&mut self, message: &str) {
        if self.non_interactive {
            let _ = cliclack::log::info(format!("[..] {message}"));
        } else {
            let _ = cliclack::log::info(format!("{message}"));
        }
    }

    pub fn finish_with_message(&mut self, message: &str) {
        if self.non_interactive {
            let _ = cliclack::log::info(format!("[ok] {message}"));
        } else {
            let _ = cliclack::log::info(format!("{message}"));
        }
    }
}
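The completion bookkeeping in `mark_file_done` is easy to get wrong off by one: the aggregate bar must finish only after every per-file bar has been marked done. A std-only sketch of just that counter (the `Tracker` type is illustrative, not part of the crate):

```rust
// Each call marks one file done; the return value signals whether
// the aggregate bar should now be stopped.
struct Tracker {
    completed: usize,
    total: usize,
}

impl Tracker {
    fn mark_done(&mut self) -> bool {
        self.completed += 1;
        self.completed >= self.total
    }
}

fn main() {
    let mut t = Tracker { completed: 0, total: 3 };
    assert!(!t.mark_done());
    assert!(!t.mark_done());
    assert!(t.mark_done()); // third file completes the aggregate
    println!("ok");
}
```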


@@ -1,17 +1,12 @@
[package]
name = "polyscribe-host"
version.workspace = true
edition.workspace = true

[dependencies]
anyhow = { workspace = true }
serde = { workspace = true, features = ["derive"] }
serde_json = { workspace = true }
tokio = { workspace = true, features = ["rt-multi-thread", "process", "io-util"] }
which = { workspace = true }
directories = { workspace = true }


@@ -1,168 +1,119 @@
use anyhow::{Context, Result};
use std::process::Stdio;
use std::{
    env, fs,
    os::unix::fs::PermissionsExt,
    path::Path,
};
use tokio::{
    io::{AsyncBufReadExt, BufReader},
    process::{Child as TokioChild, Command},
};

#[derive(Debug, Clone)]
pub struct PluginInfo {
    pub name: String,
    pub path: String,
}

#[derive(Debug, Default)]
pub struct PluginManager;

impl PluginManager {
    pub fn list(&self) -> Result<Vec<PluginInfo>> {
        let mut plugins = Vec::new();
        if let Ok(path) = env::var("PATH") {
            for dir in env::split_paths(&path) {
                scan_dir_for_plugins(&dir, &mut plugins);
            }
        }
        if let Some(dirs) = directories::ProjectDirs::from("dev", "polyscribe", "polyscribe") {
            let plugin_dir = dirs.data_dir().join("plugins");
            scan_dir_for_plugins(&plugin_dir, &mut plugins);
        }
        plugins.sort_by(|a, b| a.path.cmp(&b.path));
        plugins.dedup_by(|a, b| a.path == b.path);
        Ok(plugins)
    }

    pub fn info(&self, name: &str) -> Result<serde_json::Value> {
        let bin = self.resolve(name)?;
        let out = std::process::Command::new(&bin)
            .arg("info")
            .stdout(Stdio::piped())
            .stderr(Stdio::inherit())
            .spawn()
            .context("spawning plugin info")?
            .wait_with_output()
            .context("waiting for plugin info")?;
        let val: serde_json::Value =
            serde_json::from_slice(&out.stdout).context("parsing plugin info JSON")?;
        Ok(val)
    }

    pub fn spawn(&self, name: &str, command: &str) -> Result<TokioChild> {
        let bin = self.resolve(name)?;
        let mut cmd = Command::new(&bin);
        cmd.arg("run")
            .arg(command)
            .stdin(Stdio::piped())
            .stdout(Stdio::piped())
            .stderr(Stdio::inherit());
        let child = cmd.spawn().context("spawning plugin run")?;
        Ok(child)
    }

    pub async fn forward_stdio(&self, child: &mut TokioChild) -> Result<std::process::ExitStatus> {
        if let Some(stdout) = child.stdout.take() {
            let mut reader = BufReader::new(stdout).lines();
            while let Some(line) = reader.next_line().await? {
                println!("{line}");
            }
        }
        Ok(child.wait().await?)
    }

    fn resolve(&self, name: &str) -> Result<String> {
        let bin = format!("polyscribe-plugin-{name}");
        let path =
            which::which(&bin).with_context(|| format!("plugin not found in PATH: {bin}"))?;
        Ok(path.to_string_lossy().to_string())
    }
}

fn is_executable(path: &Path) -> bool {
    if !path.is_file() {
        return false;
    }
    #[cfg(unix)]
    {
        if let Ok(meta) = fs::metadata(path) {
            let mode = meta.permissions().mode();
            return mode & 0o111 != 0;
        }
    }
    true
}

fn scan_dir_for_plugins(dir: &Path, out: &mut Vec<PluginInfo>) {
    if let Ok(read_dir) = fs::read_dir(dir) {
        for entry in read_dir.flatten() {
            let path = entry.path();
            if let Some(fname) = path.file_name().and_then(|s| s.to_str())
                && fname.starts_with("polyscribe-plugin-")
                && is_executable(&path)
            {
                let name = fname.trim_start_matches("polyscribe-plugin-").to_string();
                out.push(PluginInfo {
                    name,
                    path: path.to_string_lossy().to_string(),
                });
            }
        }
    }
}


@@ -1,10 +1,8 @@
[package]
name = "polyscribe-protocol"
version.workspace = true
edition.workspace = true

[dependencies]
serde = { workspace = true, features = ["derive"] }
serde_json = { workspace = true }


@@ -1,90 +1,60 @@
use serde::{Deserialize, Serialize};
use serde_json::Value;

#[derive(Debug, Serialize, Deserialize)]
pub struct Request {
    pub id: String,
    pub method: String,
    pub params: Option<Value>,
}

#[derive(Debug, Serialize, Deserialize)]
pub struct Response {
    pub id: String,
    pub result: Option<Value>,
    pub error: Option<ErrorObj>,
}

#[derive(Debug, Serialize, Deserialize)]
pub struct ErrorObj {
    pub code: i32,
    pub message: String,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub data: Option<Value>,
}

#[derive(Debug, Serialize, Deserialize)]
#[serde(tag = "event", content = "data")]
pub enum ProgressEvent {
    Started,
    Message(String),
    Percent(f32),
    Finished,
}

impl Response {
    pub fn ok(id: impl Into<String>, result: Value) -> Self {
        Self {
            id: id.into(),
            result: Some(result),
            error: None,
        }
    }

    pub fn err(
        id: impl Into<String>,
        code: i32,
        message: impl Into<String>,
        data: Option<Value>,
    ) -> Self {
        Self {
            id: id.into(),
            result: None,
            error: Some(ErrorObj {
                code,
                message: message.into(),
                data,
            }),
        }
    }
}
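With these serde attributes, an exchange over NDJSON looks roughly like the following fragment: one request line, zero or more adjacently tagged progress events, then a response line. The method name comes from the tubescribe stub; the field values here are illustrative, not a recorded session.

```json
{"id": "1", "method": "generate_metadata", "params": null}
{"event": "Percent", "data": 25.0}
{"id": "1", "result": {"title": "Canned title"}, "error": null}
```

Note that `result`/`error` on `Response` carry no `skip_serializing_if`, so a success response still emits `"error": null` on the wire.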


@@ -1,26 +0,0 @@
# CI checklist and job outline
Checklist to keep docs and code healthy in CI
- Build: cargo build --all-targets --locked
- Tests: cargo test --all --locked
- Lints: cargo clippy --all-targets -- -D warnings
- Optional: check README and docs snippets (basic smoke run of examples scripts)
- bash examples/update_models.sh (can be skipped offline)
- bash examples/transcribe_file.sh (use a tiny sample file if available)
Example GitHub Actions job (outline)
- name: Rust
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v4
- uses: dtolnay/rust-toolchain@stable
- name: Build
run: cargo build --all-targets --locked
- name: Test
run: cargo test --all --locked
- name: Clippy
run: cargo clippy --all-targets -- -D warnings
Notes
- For GPU features, set up appropriate runners and add `--features gpu-cuda|gpu-hip|gpu-vulkan` where applicable.
- For docs-only changes, jobs still build/test to ensure doctests and examples compile when enabled.


@@ -32,18 +32,20 @@ Run locally
Models during development
- Interactive downloader:
  - cargo run -- models download
- Non-interactive update (checks sizes/hashes, downloads if missing):
  - cargo run -- models update --no-interaction -q
Tests
- Run all tests:
  - cargo test
- The test suite includes CLI-oriented integration tests and unit tests. Some tests simulate GPU detection using env vars (POLYSCRIBE_TEST_FORCE_*). Do not rely on these flags in production code.
Clippy & formatting
- Run lint checks and treat warnings as errors:
  - cargo clippy --all-targets -- -D warnings
- Check formatting:
  - cargo fmt --all -- --check
- Common warnings can often be fixed by simplifying code, removing unused imports, and following idiomatic patterns.
Code layout
@@ -61,10 +63,10 @@ Adding a feature
Running the model downloader
- Interactive:
  - cargo run -- models download
- Non-interactive suggestions for CI:
  - POLYSCRIBE_MODELS_DIR=$PWD/models \
    cargo run -- models update --no-interaction -q
Env var examples for local testing
- Use a local copy of models and a specific model file:


@@ -30,10 +30,10 @@ CLI reference
  - Choose runtime backend. Default is auto (prefers CUDA → HIP → Vulkan → CPU), depending on detection.
- --gpu-layers N
  - Number of layers to offload to the GPU when supported.
- models download
  - Launch interactive model downloader (lists Hugging Face models; multi-select to download).
  - Controls: Use Up/Down to navigate, Space to toggle selections, and Enter to confirm. Models are grouped by base (e.g., tiny, base, small).
- models update
  - Verify/update local models by comparing sizes and hashes with the upstream manifest.
- -v, --verbose (repeatable)
  - Increase log verbosity; use -vv for very detailed logs.
@@ -42,6 +42,9 @@ CLI reference
- --no-interaction
  - Disable all interactive prompts (for CI). Combine with env vars to control behavior.
- Subcommands:
  - models download: Launch interactive model downloader.
  - models update: Verify/update local models (non-interactive).
  - plugins list|info|run: Discover and run plugins.
  - completions <shell>: Write shell completion script to stdout.
  - man: Write a man page to stdout.


@@ -1,13 +0,0 @@
#!/usr/bin/env bash
# SPDX-License-Identifier: MIT
# Copyright (c) 2025 <COPYRIGHT HOLDER>. All rights reserved.
set -euo pipefail
# Launch the interactive model downloader and select models to install
BIN=${BIN:-./target/release/polyscribe}
MODELS_DIR=${POLYSCRIBE_MODELS_DIR:-$PWD/models}
export POLYSCRIBE_MODELS_DIR="$MODELS_DIR"
mkdir -p "$MODELS_DIR"
"$BIN" --download-models


@@ -1,15 +0,0 @@
#!/usr/bin/env bash
# SPDX-License-Identifier: MIT
# Copyright (c) 2025 <COPYRIGHT HOLDER>. All rights reserved.
set -euo pipefail
# Transcribe an audio/video file to JSON and SRT into ./output
# Requires a model; first run may prompt to download.
BIN=${BIN:-./target/release/polyscribe}
INPUT=${1:-samples/example.mp3}
OUTDIR=${OUTDIR:-output}
mkdir -p "$OUTDIR"
"$BIN" -v -o "$OUTDIR" "$INPUT"
echo "Done. See $OUTDIR for JSON/SRT files."


@@ -1,15 +0,0 @@
#!/usr/bin/env bash
# SPDX-License-Identifier: MIT
# Copyright (c) 2025 <COPYRIGHT HOLDER>. All rights reserved.
set -euo pipefail
# Verify/update local models non-interactively (useful in CI)
BIN=${BIN:-./target/release/polyscribe}
MODELS_DIR=${POLYSCRIBE_MODELS_DIR:-$PWD/models}
export POLYSCRIBE_MODELS_DIR="$MODELS_DIR"
mkdir -p "$MODELS_DIR"
"$BIN" --update-models --no-interaction -q
echo "Models updated in $MODELS_DIR"


@@ -1,5 +1,4 @@
// SPDX-License-Identifier: MIT
use anyhow::{Context, Result};
use clap::Parser;
@@ -36,7 +35,6 @@ fn main() -> Result<()> {
        serve_once()?;
        return Ok(());
    }
    let caps = psp::Capabilities {
        name: "tubescribe".to_string(),
        version: env!("CARGO_PKG_VERSION").to_string(),
@@ -49,14 +47,12 @@ fn main() -> Result<()> {
}
fn serve_once() -> Result<()> {
    let stdin = std::io::stdin();
    let mut reader = BufReader::new(stdin.lock());
    let mut line = String::new();
    reader.read_line(&mut line).context("failed to read request line")?;
    let req: psp::JsonRpcRequest = serde_json::from_str(line.trim()).context("invalid JSON-RPC request")?;
    emit(&psp::StreamItem::progress(5, Some("start".into()), Some("initializing".into())))?;
    std::thread::sleep(std::time::Duration::from_millis(50));
    emit(&psp::StreamItem::progress(25, Some("probe".into()), Some("probing sources".into())))?;
@@ -65,7 +61,6 @@ fn serve_once() -> Result<()> {
    std::thread::sleep(std::time::Duration::from_millis(50));
    emit(&psp::StreamItem::progress(90, Some("finalize".into()), Some("finalizing".into())))?;
    let result = match req.method.as_str() {
        "generate_metadata" => {
            let title = "Canned title";
@@ -78,7 +73,6 @@ fn serve_once() -> Result<()> {
            })
        }
        other => {
            let err = psp::StreamItem::err(req.id.clone(), -32601, format!("Method not found: {}", other), None);
            emit(&err)?;
            return Ok(());

rust-toolchain.toml Normal file
View File

@@ -0,0 +1,6 @@
# SPDX-License-Identifier: MIT
[toolchain]
channel = "1.89.0"
components = ["clippy", "rustfmt"]
profile = "minimal"

View File

@@ -1,329 +0,0 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2025 <COPYRIGHT HOLDER>. All rights reserved.
//! Transcription backend selection and implementations (CPU/GPU) used by PolyScribe.
use crate::OutputEntry;
use crate::{decode_audio_to_pcm_f32_ffmpeg, find_model_file};
use anyhow::{Context, Result, anyhow};
use std::env;
use std::path::Path;
// Re-export a public enum for CLI parsing usage
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
/// Kind of transcription backend to use.
pub enum BackendKind {
/// Automatically detect the best available backend (CUDA > HIP > Vulkan > CPU).
Auto,
/// Pure CPU backend using whisper-rs.
Cpu,
/// NVIDIA CUDA backend (requires CUDA runtime available at load time and proper feature build).
Cuda,
/// AMD ROCm/HIP backend (requires hip/rocBLAS libraries available and proper feature build).
Hip,
/// Vulkan backend (experimental; requires Vulkan loader/SDK and feature build).
Vulkan,
}
/// Abstraction for a transcription backend.
pub trait TranscribeBackend {
/// Backend kind implemented by this type.
fn kind(&self) -> BackendKind;
/// Transcribe the given audio and return transcript entries.
fn transcribe(
&self,
audio_path: &Path,
speaker: &str,
language: Option<&str>,
gpu_layers: Option<u32>,
progress: Option<&(dyn Fn(i32) + Send + Sync)>,
) -> Result<Vec<OutputEntry>>;
}
fn check_lib(_names: &[&str]) -> bool {
#[cfg(test)]
{
// During unit tests, avoid touching system libs to prevent loader crashes in CI.
false
}
#[cfg(not(test))]
{
// Disabled runtime dlopen probing to avoid loader instability; rely on environment overrides.
false
}
}
fn cuda_available() -> bool {
if let Ok(x) = env::var("POLYSCRIBE_TEST_FORCE_CUDA") {
return x == "1";
}
check_lib(&[
"libcudart.so",
"libcudart.so.12",
"libcudart.so.11",
"libcublas.so",
"libcublas.so.12",
])
}
fn hip_available() -> bool {
if let Ok(x) = env::var("POLYSCRIBE_TEST_FORCE_HIP") {
return x == "1";
}
check_lib(&["libhipblas.so", "librocblas.so"])
}
fn vulkan_available() -> bool {
if let Ok(x) = env::var("POLYSCRIBE_TEST_FORCE_VULKAN") {
return x == "1";
}
check_lib(&["libvulkan.so.1", "libvulkan.so"])
}
/// CPU-based transcription backend using whisper-rs.
#[derive(Default)]
pub struct CpuBackend;
/// CUDA-accelerated transcription backend for NVIDIA GPUs.
#[derive(Default)]
pub struct CudaBackend;
/// ROCm/HIP-accelerated transcription backend for AMD GPUs.
#[derive(Default)]
pub struct HipBackend;
/// Vulkan-based transcription backend (experimental/incomplete).
#[derive(Default)]
pub struct VulkanBackend;
macro_rules! impl_whisper_backend {
($ty:ty, $kind:expr) => {
impl TranscribeBackend for $ty {
fn kind(&self) -> BackendKind { $kind }
fn transcribe(
&self,
audio_path: &Path,
speaker: &str,
language: Option<&str>,
_gpu_layers: Option<u32>,
progress: Option<&(dyn Fn(i32) + Send + Sync)>,
) -> Result<Vec<OutputEntry>> {
transcribe_with_whisper_rs(audio_path, speaker, language, progress)
}
}
};
}
impl_whisper_backend!(CpuBackend, BackendKind::Cpu);
impl_whisper_backend!(CudaBackend, BackendKind::Cuda);
impl_whisper_backend!(HipBackend, BackendKind::Hip);
impl TranscribeBackend for VulkanBackend {
fn kind(&self) -> BackendKind {
BackendKind::Vulkan
}
fn transcribe(
&self,
_audio_path: &Path,
_speaker: &str,
_language: Option<&str>,
_gpu_layers: Option<u32>,
_progress: Option<&(dyn Fn(i32) + Send + Sync)>,
) -> Result<Vec<OutputEntry>> {
Err(anyhow!(
"Vulkan backend not yet wired to whisper.cpp FFI. Build with --features gpu-vulkan and ensure Vulkan SDK is installed. How to fix: install Vulkan loader (libvulkan), set VULKAN_SDK, and run cargo build --features gpu-vulkan."
))
}
}
/// Result of choosing a transcription backend.
pub struct SelectionResult {
/// The constructed backend instance to perform transcription with.
pub backend: Box<dyn TranscribeBackend + Send + Sync>,
/// Which backend kind was ultimately selected.
pub chosen: BackendKind,
/// Which backend kinds were detected as available on this system.
pub detected: Vec<BackendKind>,
}
/// Select an appropriate backend based on user request and system detection.
///
/// If `requested` is `BackendKind::Auto`, the function prefers CUDA, then HIP,
/// then Vulkan, falling back to CPU when no GPU backend is detected. When a
/// specific GPU backend is requested but unavailable, an error is returned with
/// guidance on how to enable it.
///
/// Set `verbose` to true to print detection/selection info to stderr.
pub fn select_backend(requested: BackendKind, verbose: bool) -> Result<SelectionResult> {
let mut detected = Vec::new();
if cuda_available() {
detected.push(BackendKind::Cuda);
}
if hip_available() {
detected.push(BackendKind::Hip);
}
if vulkan_available() {
detected.push(BackendKind::Vulkan);
}
let instantiate_backend = |k: BackendKind| -> Box<dyn TranscribeBackend + Send + Sync> {
match k {
BackendKind::Cpu => Box::new(CpuBackend::default()),
BackendKind::Cuda => Box::new(CudaBackend::default()),
BackendKind::Hip => Box::new(HipBackend::default()),
BackendKind::Vulkan => Box::new(VulkanBackend::default()),
BackendKind::Auto => Box::new(CpuBackend::default()), // placeholder for Auto
}
};
let chosen = match requested {
BackendKind::Auto => {
if detected.contains(&BackendKind::Cuda) {
BackendKind::Cuda
} else if detected.contains(&BackendKind::Hip) {
BackendKind::Hip
} else if detected.contains(&BackendKind::Vulkan) {
BackendKind::Vulkan
} else {
BackendKind::Cpu
}
}
BackendKind::Cuda => {
if detected.contains(&BackendKind::Cuda) {
BackendKind::Cuda
} else {
return Err(anyhow!(
"Requested CUDA backend but CUDA libraries/devices not detected. How to fix: install NVIDIA driver + CUDA toolkit, ensure libcudart/libcublas are in loader path, and build with --features gpu-cuda."
));
}
}
BackendKind::Hip => {
if detected.contains(&BackendKind::Hip) {
BackendKind::Hip
} else {
return Err(anyhow!(
"Requested ROCm/HIP backend but libraries/devices not detected. How to fix: install ROCm hipBLAS/rocBLAS, ensure libs are in loader path, and build with --features gpu-hip."
));
}
}
BackendKind::Vulkan => {
if detected.contains(&BackendKind::Vulkan) {
BackendKind::Vulkan
} else {
return Err(anyhow!(
"Requested Vulkan backend but libvulkan not detected. How to fix: install Vulkan loader/SDK and build with --features gpu-vulkan."
));
}
}
BackendKind::Cpu => BackendKind::Cpu,
};
if verbose {
crate::dlog!(1, "Detected backends: {:?}", detected);
crate::dlog!(1, "Selected backend: {:?}", chosen);
}
Ok(SelectionResult {
backend: instantiate_backend(chosen),
chosen,
detected,
})
}
// Internal helper: transcription using whisper-rs with CPU/GPU (depending on build features)
#[allow(clippy::too_many_arguments)]
pub(crate) fn transcribe_with_whisper_rs(
audio_path: &Path,
speaker: &str,
language: Option<&str>,
progress: Option<&(dyn Fn(i32) + Send + Sync)>,
) -> Result<Vec<OutputEntry>> {
let report = |p: i32| {
if let Some(cb) = progress { cb(p); }
};
report(0);
let pcm_samples = decode_audio_to_pcm_f32_ffmpeg(audio_path)?;
report(5);
let model_path = find_model_file()?;
let english_only_model = model_path
.file_name()
.and_then(|s| s.to_str())
.map(|s| s.contains(".en.") || s.ends_with(".en.bin"))
.unwrap_or(false);
if let Some(lang) = language {
if english_only_model && lang != "en" {
return Err(anyhow!(
"Selected model is English-only ({}), but a non-English language hint '{}' was provided. Please use a multilingual model or set WHISPER_MODEL.",
model_path.display(),
lang
));
}
}
let model_path_str = model_path
.to_str()
.ok_or_else(|| anyhow!("Model path not valid UTF-8: {}", model_path.display()))?;
if crate::verbose_level() < 2 {
// Some builds of whisper/ggml expect these env vars; harmless if unknown
unsafe {
std::env::set_var("GGML_LOG_LEVEL", "0");
std::env::set_var("WHISPER_PRINT_PROGRESS", "0");
}
}
let (_context, mut state) = crate::with_suppressed_stderr(|| {
let params = whisper_rs::WhisperContextParameters::default();
let context = whisper_rs::WhisperContext::new_with_params(model_path_str, params)
.with_context(|| format!("Failed to load Whisper model at {}", model_path.display()))?;
let state = context
.create_state()
.map_err(|e| anyhow!("Failed to create Whisper state: {:?}", e))?;
Ok::<_, anyhow::Error>((context, state))
})?;
report(20);
let mut full_params =
whisper_rs::FullParams::new(whisper_rs::SamplingStrategy::Greedy { best_of: 1 });
let threads = std::thread::available_parallelism()
.map(|n| n.get() as i32)
.unwrap_or(1);
full_params.set_n_threads(threads);
full_params.set_translate(false);
if let Some(lang) = language {
full_params.set_language(Some(lang));
}
report(30);
crate::with_suppressed_stderr(|| {
report(40);
state
.full(full_params, &pcm_samples)
.map_err(|e| anyhow!("Whisper full() failed: {:?}", e))
})?;
report(90);
let num_segments = state
.full_n_segments()
.map_err(|e| anyhow!("Failed to get segments: {:?}", e))?;
let mut entries = Vec::new();
for seg_idx in 0..num_segments {
let segment_text = state
.full_get_segment_text(seg_idx)
.map_err(|e| anyhow!("Failed to get segment text: {:?}", e))?;
let t0 = state
.full_get_segment_t0(seg_idx)
.map_err(|e| anyhow!("Failed to get segment t0: {:?}", e))?;
let t1 = state
.full_get_segment_t1(seg_idx)
.map_err(|e| anyhow!("Failed to get segment t1: {:?}", e))?;
let start = (t0 as f64) * 0.01;
let end = (t1 as f64) * 0.01;
entries.push(OutputEntry {
id: 0,
speaker: speaker.to_string(),
start,
end,
text: segment_text.trim().to_string(),
});
}
report(100);
Ok(entries)
}
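The `Auto` arm of `select_backend` above encodes a fixed preference order: CUDA, then HIP, then Vulkan, with CPU as the fallback. A minimal standalone sketch of just that ordering logic (type and function names here are illustrative, not the crate's API):

```rust
// Auto-selection preference: prefer CUDA, then HIP, then Vulkan;
// fall back to CPU when no GPU backend was detected.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum Kind { Cpu, Cuda, Hip, Vulkan }

fn pick_auto(detected: &[Kind]) -> Kind {
    [Kind::Cuda, Kind::Hip, Kind::Vulkan]
        .into_iter()
        .find(|k| detected.contains(k))
        .unwrap_or(Kind::Cpu)
}

fn main() {
    // Detection order does not matter; preference order does.
    assert_eq!(pick_auto(&[Kind::Vulkan, Kind::Cuda]), Kind::Cuda);
    assert_eq!(pick_auto(&[Kind::Hip, Kind::Vulkan]), Kind::Hip);
    assert_eq!(pick_auto(&[Kind::Vulkan]), Kind::Vulkan);
    assert_eq!(pick_auto(&[]), Kind::Cpu);
}
```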

View File

@@ -1,571 +0,0 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2025 <COPYRIGHT HOLDER>. All rights reserved.
#![forbid(elided_lifetimes_in_paths)]
#![forbid(unused_must_use)]
#![deny(missing_docs)]
#![warn(clippy::all)]
//! PolyScribe library: business logic and core types.
//!
//! This crate exposes the reusable parts of the PolyScribe CLI as a library.
//! The binary entry point (main.rs) remains a thin CLI wrapper.
use std::sync::atomic::{AtomicBool, AtomicU8, Ordering};
// Global runtime flags
static QUIET: AtomicBool = AtomicBool::new(false);
static NO_INTERACTION: AtomicBool = AtomicBool::new(false);
static VERBOSE: AtomicU8 = AtomicU8::new(0);
static NO_PROGRESS: AtomicBool = AtomicBool::new(false);
/// Set quiet mode: when true, non-interactive logs should be suppressed.
pub fn set_quiet(enabled: bool) {
QUIET.store(enabled, Ordering::Relaxed);
}
/// Return current quiet mode state.
pub fn is_quiet() -> bool {
QUIET.load(Ordering::Relaxed)
}
/// Set non-interactive mode: when true, interactive prompts must be skipped.
pub fn set_no_interaction(enabled: bool) {
NO_INTERACTION.store(enabled, Ordering::Relaxed);
}
/// Return current non-interactive state.
pub fn is_no_interaction() -> bool {
NO_INTERACTION.load(Ordering::Relaxed)
}
/// Set verbose level (0 = normal, 1 = verbose, 2 = super-verbose)
pub fn set_verbose(level: u8) {
VERBOSE.store(level, Ordering::Relaxed);
}
/// Get current verbose level.
pub fn verbose_level() -> u8 {
VERBOSE.load(Ordering::Relaxed)
}
/// Disable interactive progress indicators (bars/spinners)
pub fn set_no_progress(enabled: bool) {
NO_PROGRESS.store(enabled, Ordering::Relaxed);
}
/// Return current no-progress state
pub fn is_no_progress() -> bool {
NO_PROGRESS.load(Ordering::Relaxed)
}
/// Check whether stdin is connected to a TTY. Used to avoid blocking prompts when not interactive.
pub fn stdin_is_tty() -> bool {
use std::io::IsTerminal as _;
std::io::stdin().is_terminal()
}
/// A guard that temporarily redirects stderr to /dev/null on Unix when quiet mode is active.
/// No-op on non-Unix or when quiet is disabled. Restores stderr on drop.
pub struct StderrSilencer {
#[cfg(unix)]
old_stderr_fd: i32,
#[cfg(unix)]
devnull_fd: i32,
active: bool,
}
impl StderrSilencer {
/// Activate stderr silencing if quiet is set and on Unix; otherwise returns a no-op guard.
pub fn activate_if_quiet() -> Self {
if !is_quiet() {
return Self {
active: false,
#[cfg(unix)]
old_stderr_fd: -1,
#[cfg(unix)]
devnull_fd: -1,
};
}
Self::activate()
}
/// Activate stderr silencing unconditionally (used internally); no-op on non-Unix.
pub fn activate() -> Self {
#[cfg(unix)]
unsafe {
let old_fd = dup(2);
if old_fd < 0 {
return Self {
active: false,
old_stderr_fd: -1,
devnull_fd: -1,
};
}
// Open /dev/null for writing
let devnull_cstr = std::ffi::CString::new("/dev/null").unwrap();
let devnull_fd = open(devnull_cstr.as_ptr(), O_WRONLY);
if devnull_fd < 0 {
close(old_fd);
return Self {
active: false,
old_stderr_fd: -1,
devnull_fd: -1,
};
}
if dup2(devnull_fd, 2) < 0 {
close(devnull_fd);
close(old_fd);
return Self {
active: false,
old_stderr_fd: -1,
devnull_fd: -1,
};
}
Self {
active: true,
old_stderr_fd: old_fd,
devnull_fd,
}
}
#[cfg(not(unix))]
{
Self { active: false }
}
}
}
impl Drop for StderrSilencer {
fn drop(&mut self) {
if !self.active {
return;
}
#[cfg(unix)]
unsafe {
let _ = dup2(self.old_stderr_fd, 2);
let _ = close(self.devnull_fd);
let _ = close(self.old_stderr_fd);
}
self.active = false;
}
}
/// Run a closure while temporarily suppressing stderr on Unix when appropriate.
/// On Windows/non-Unix, this is a no-op wrapper.
/// This helper uses RAII + panic catching to ensure restoration before resuming panic.
pub fn with_suppressed_stderr<F, T>(f: F) -> T
where
F: FnOnce() -> T,
{
// Suppress noisy native logs unless super-verbose (-vv) is enabled.
if verbose_level() < 2 {
let result = std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| {
let _guard = StderrSilencer::activate();
f()
}));
match result {
Ok(value) => value,
Err(panic_payload) => std::panic::resume_unwind(panic_payload),
}
} else {
f()
}
}
/// Centralized UI helpers (TTY-aware, quiet/verbose-aware)
pub mod ui;
/// Logging macros and helpers
/// Log an error using the UI helper (always printed). Recommended for user-visible errors.
#[macro_export]
macro_rules! elog {
($($arg:tt)*) => {{
$crate::ui::error(format!($($arg)*));
}}
}
/// Log a warning using the UI helper (printed even in quiet mode).
#[macro_export]
macro_rules! wlog {
($($arg:tt)*) => {{
$crate::ui::warn(format!($($arg)*));
}}
}
/// Log an informational line using the UI helper unless quiet mode is enabled.
#[macro_export]
macro_rules! ilog {
($($arg:tt)*) => {{
if !$crate::is_quiet() { $crate::ui::info(format!($($arg)*)); }
}}
}
/// Log a debug/trace line when verbose level is at least the given level (u8).
#[macro_export]
macro_rules! dlog {
($lvl:expr, $($arg:tt)*) => {{
if !$crate::is_quiet() && $crate::verbose_level() >= $lvl { $crate::ui::info(format!("DEBUG{}: {}", $lvl, format!($($arg)*))); }
}}
}
/// Backward-compatibility: map old qlog! to ilog!
#[macro_export]
macro_rules! qlog {
($($arg:tt)*) => {{ $crate::ilog!($($arg)*); }}
}
use anyhow::{Context, Result, anyhow};
use chrono::Local;
use std::env;
use std::fs::create_dir_all;
use std::path::{Path, PathBuf};
use std::process::Command;
#[cfg(unix)]
use libc::{O_WRONLY, close, dup, dup2, open};
/// Re-export backend module (GPU/CPU selection and transcription).
pub mod backend;
/// Re-export models module (model listing/downloading/updating).
pub mod models;
/// Transcript entry for a single segment.
#[derive(Debug, serde::Serialize, Clone)]
pub struct OutputEntry {
/// Sequential id in output ordering.
pub id: u64,
/// Speaker label associated with the segment.
pub speaker: String,
/// Start time in seconds.
pub start: f64,
/// End time in seconds.
pub end: f64,
/// Text content.
pub text: String,
}
/// Return a YYYY-MM-DD date prefix string for output file naming.
pub fn date_prefix() -> String {
Local::now().format("%Y-%m-%d").to_string()
}
/// Format a floating-point number of seconds as SRT timestamp (HH:MM:SS,mmm).
pub fn format_srt_time(seconds: f64) -> String {
let total_ms = (seconds * 1000.0).round() as i64;
let ms = total_ms % 1000;
let total_secs = total_ms / 1000;
let sec = total_secs % 60;
let min = (total_secs / 60) % 60;
let hour = total_secs / 3600;
format!("{hour:02}:{min:02}:{sec:02},{ms:03}")
}
/// Render a list of transcript entries to SRT format.
pub fn render_srt(entries: &[OutputEntry]) -> String {
let mut srt = String::new();
for (index, entry) in entries.iter().enumerate() {
let srt_index = index + 1;
srt.push_str(&format!("{srt_index}\n"));
srt.push_str(&format!(
"{} --> {}\n",
format_srt_time(entry.start),
format_srt_time(entry.end)
));
if !entry.speaker.is_empty() {
srt.push_str(&format!("{}: {}\n", entry.speaker, entry.text));
} else {
srt.push_str(&format!("{}\n", entry.text));
}
srt.push('\n');
}
srt
}
/// Determine the default models directory, honoring POLYSCRIBE_MODELS_DIR override.
pub fn models_dir_path() -> PathBuf {
if let Ok(env_val) = env::var("POLYSCRIBE_MODELS_DIR") {
let env_path = PathBuf::from(env_val);
if !env_path.as_os_str().is_empty() {
return env_path;
}
}
if cfg!(debug_assertions) {
return PathBuf::from("models");
}
if let Ok(xdg) = env::var("XDG_DATA_HOME") {
if !xdg.is_empty() {
return PathBuf::from(xdg).join("polyscribe").join("models");
}
}
if let Ok(home) = env::var("HOME") {
if !home.is_empty() {
return PathBuf::from(home)
.join(".local")
.join("share")
.join("polyscribe")
.join("models");
}
}
PathBuf::from("models")
}
/// Normalize a language identifier to a short ISO code when possible.
pub fn normalize_lang_code(input: &str) -> Option<String> {
let mut lang = input.trim().to_lowercase();
if lang.is_empty() || lang == "auto" || lang == "c" || lang == "posix" {
return None;
}
if let Some((prefix, _)) = lang.split_once('.') {
lang = prefix.to_string();
}
if let Some((prefix, _)) = lang.split_once('_') {
lang = prefix.to_string();
}
let code = match lang.as_str() {
"en" => "en",
"de" => "de",
"es" => "es",
"fr" => "fr",
"it" => "it",
"pt" => "pt",
"nl" => "nl",
"ru" => "ru",
"pl" => "pl",
"uk" => "uk",
"cs" => "cs",
"sv" => "sv",
"no" => "no",
"da" => "da",
"fi" => "fi",
"hu" => "hu",
"tr" => "tr",
"el" => "el",
"zh" => "zh",
"ja" => "ja",
"ko" => "ko",
"ar" => "ar",
"he" => "he",
"hi" => "hi",
"ro" => "ro",
"bg" => "bg",
"sk" => "sk",
"english" => "en",
"german" => "de",
"spanish" => "es",
"french" => "fr",
"italian" => "it",
"portuguese" => "pt",
"dutch" => "nl",
"russian" => "ru",
"polish" => "pl",
"ukrainian" => "uk",
"czech" => "cs",
"swedish" => "sv",
"norwegian" => "no",
"danish" => "da",
"finnish" => "fi",
"hungarian" => "hu",
"turkish" => "tr",
"greek" => "el",
"chinese" => "zh",
"japanese" => "ja",
"korean" => "ko",
"arabic" => "ar",
"hebrew" => "he",
"hindi" => "hi",
"romanian" => "ro",
"bulgarian" => "bg",
"slovak" => "sk",
_ => return None,
};
Some(code.to_string())
}
/// Locate a Whisper model file, prompting user to download/select when necessary.
pub fn find_model_file() -> Result<PathBuf> {
let models_dir_buf = models_dir_path();
let models_dir = models_dir_buf.as_path();
if !models_dir.exists() {
create_dir_all(models_dir).with_context(|| {
format!(
"Failed to create models directory: {}",
models_dir.display()
)
})?;
}
if let Ok(env_model) = env::var("WHISPER_MODEL") {
let model_path = PathBuf::from(env_model);
if model_path.is_file() {
let _ = std::fs::write(models_dir.join(".last_model"), model_path.display().to_string());
return Ok(model_path);
}
}
// Non-interactive mode: automatic selection and optional download
if crate::is_no_interaction() {
if let Some(local) = crate::models::pick_best_local_model(models_dir) {
let _ = std::fs::write(models_dir.join(".last_model"), local.display().to_string());
return Ok(local);
} else {
ilog!("No local models found; downloading large-v3-turbo-q8_0...");
let path = crate::models::ensure_model_available_noninteractive("large-v3-turbo-q8_0")
.with_context(|| "Failed to download required model 'large-v3-turbo-q8_0'")?;
let _ = std::fs::write(models_dir.join(".last_model"), path.display().to_string());
return Ok(path);
}
}
let mut candidates: Vec<PathBuf> = Vec::new();
let dir_entries = std::fs::read_dir(models_dir)
.with_context(|| format!("Failed to read models directory: {}", models_dir.display()))?;
for entry in dir_entries {
let entry = entry?;
let path = entry.path();
if path.is_file() {
if let Some(ext) = path
.extension()
.and_then(|s| s.to_str())
.map(|s| s.to_lowercase())
{
if ext == "bin" {
candidates.push(path);
}
}
}
}
if candidates.is_empty() {
// No models found: prompt interactively (TTY only)
wlog!("No Whisper model files (*.bin) found in {}.", models_dir.display());
if crate::is_no_interaction() || !crate::stdin_is_tty() {
return Err(anyhow!(
"No models available and interactive mode is disabled. Please set WHISPER_MODEL or run with --download-models."
));
}
let input = crate::ui::prompt_line("Would you like to download models now? [Y/n]: ").unwrap_or_default();
let answer = input.trim().to_lowercase();
if answer.is_empty() || answer == "y" || answer == "yes" {
if let Err(e) = models::run_interactive_model_downloader() {
elog!("Downloader failed: {:#}", e);
}
candidates.clear();
let dir_entries2 = std::fs::read_dir(models_dir).with_context(|| {
format!("Failed to read models directory: {}", models_dir.display())
})?;
for entry in dir_entries2 {
let entry = entry?;
let path = entry.path();
if path.is_file() {
if let Some(ext) = path
.extension()
.and_then(|s| s.to_str())
.map(|s| s.to_lowercase())
{
if ext == "bin" {
candidates.push(path);
}
}
}
}
}
}
if candidates.is_empty() {
return Err(anyhow!(
"No Whisper model files (*.bin) available in {}",
models_dir.display()
));
}
if candidates.len() == 1 {
let only_model = candidates.remove(0);
let _ = std::fs::write(models_dir.join(".last_model"), only_model.display().to_string());
return Ok(only_model);
}
let last_file = models_dir.join(".last_model");
if let Ok(previous_content) = std::fs::read_to_string(&last_file) {
let previous_content = previous_content.trim();
if !previous_content.is_empty() {
let previous_path = PathBuf::from(previous_content);
if previous_path.is_file() && candidates.iter().any(|c| c == &previous_path) {
return Ok(previous_path);
}
}
}
crate::ui::println_above_bars(format!("Multiple Whisper models found in {}:", models_dir.display()));
for (index, path) in candidates.iter().enumerate() {
crate::ui::println_above_bars(format!(" {}) {}", index + 1, path.display()));
}
let input = crate::ui::prompt_line(&format!("Select model by number [1-{}]: ", candidates.len()))
.map_err(|_| anyhow!("Failed to read selection"))?;
let selection: usize = input
.trim()
.parse()
.map_err(|_| anyhow!("Invalid selection: {}", input.trim()))?;
if selection == 0 || selection > candidates.len() {
return Err(anyhow!("Selection out of range"));
}
let chosen = candidates.swap_remove(selection - 1);
let _ = std::fs::write(models_dir.join(".last_model"), chosen.display().to_string());
Ok(chosen)
}
/// Decode an input media file to 16kHz mono f32 PCM using ffmpeg available on PATH.
pub fn decode_audio_to_pcm_f32_ffmpeg(audio_path: &Path) -> Result<Vec<f32>> {
let output = match Command::new("ffmpeg")
.arg("-i")
.arg(audio_path)
.arg("-f")
.arg("f32le")
.arg("-ac")
.arg("1")
.arg("-ar")
.arg("16000")
.arg("pipe:1")
.output()
{
Ok(o) => o,
Err(e) => {
if e.kind() == std::io::ErrorKind::NotFound {
return Err(anyhow!(
"ffmpeg not found on PATH. Please install ffmpeg and ensure it is available."
));
} else {
return Err(anyhow!(
"Failed to execute ffmpeg for {}: {}",
audio_path.display(),
e
));
}
}
};
if !output.status.success() {
let stderr_str = String::from_utf8_lossy(&output.stderr);
return Err(anyhow!(
"Failed to decode audio from {} using ffmpeg. This may indicate the file is not a valid or supported audio/video file, is corrupted, or cannot be opened. ffmpeg stderr: {}",
audio_path.display(),
stderr_str.trim()
));
}
let data = output.stdout;
// Ignore a trailing partial sample if the byte count is not a multiple of 4.
let usable = data.len() - (data.len() % 4);
let mut samples = Vec::with_capacity(usable / 4);
for chunk in data[..usable].chunks_exact(4) {
samples.push(f32::from_le_bytes([chunk[0], chunk[1], chunk[2], chunk[3]]));
}
Ok(samples)
}
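For reference, `format_srt_time` above converts fractional seconds into the SRT `HH:MM:SS,mmm` form by rounding to whole milliseconds first, which avoids drift from repeated float division. A self-contained copy of that function, exercised on a few values:

```rust
// Copy of the SRT timestamp formatter from lib.rs:
// seconds -> HH:MM:SS,mmm with millisecond rounding.
fn format_srt_time(seconds: f64) -> String {
    let total_ms = (seconds * 1000.0).round() as i64;
    let ms = total_ms % 1000;
    let total_secs = total_ms / 1000;
    let sec = total_secs % 60;
    let min = (total_secs / 60) % 60;
    let hour = total_secs / 3600;
    format!("{hour:02}:{min:02}:{sec:02},{ms:03}")
}

fn main() {
    assert_eq!(format_srt_time(0.0), "00:00:00,000");
    assert_eq!(format_srt_time(61.5), "00:01:01,500");
    assert_eq!(format_srt_time(3723.042), "01:02:03,042");
}
```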

View File

@@ -1,483 +0,0 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2025 <COPYRIGHT HOLDER>. All rights reserved.
use std::fs::{File, create_dir_all};
use std::io::{self, Read, Write};
use std::path::{Path, PathBuf};
use anyhow::{Context, Result, anyhow};
use clap::{Parser, Subcommand, ValueEnum, CommandFactory};
use clap_complete::Shell;
use serde::{Deserialize, Serialize};
use polyscribe::{OutputEntry, date_prefix, normalize_lang_code, render_srt};
#[derive(Subcommand, Debug, Clone)]
enum AuxCommands {
Completions {
#[arg(value_enum)]
shell: Shell,
},
Man,
}
#[derive(ValueEnum, Debug, Clone, Copy)]
#[value(rename_all = "kebab-case")]
enum GpuBackendCli {
Auto,
Cpu,
Cuda,
Hip,
Vulkan,
}
#[derive(Parser, Debug)]
#[command(
name = "PolyScribe",
bin_name = "polyscribe",
version,
about = "Merge JSON transcripts or transcribe audio using native whisper"
)]
struct Args {
/// Increase verbosity (-v, -vv). Repeat to increase.
/// Debug logs appear with -v; very verbose with -vv. Logs go to stderr.
#[arg(short = 'v', long = "verbose", action = clap::ArgAction::Count, global = true)]
verbose: u8,
/// Quiet mode: suppress non-error logging on stderr (overrides -v)
/// Does not suppress interactive prompts or stdout output.
#[arg(short = 'q', long = "quiet", global = true)]
quiet: bool,
/// Non-interactive mode: never prompt; use defaults instead.
#[arg(long = "no-interaction", global = true)]
no_interaction: bool,
/// Disable interactive progress indicators (bars/spinners)
#[arg(long = "no-progress", global = true)]
no_progress: bool,
/// Optional auxiliary subcommands (completions, man)
#[command(subcommand)]
aux: Option<AuxCommands>,
/// Input .json transcript files or audio files to merge/transcribe
inputs: Vec<String>,
/// Output file path base or directory (date prefix added).
/// In merge mode: base path.
/// In separate mode: directory.
/// If omitted: prints JSON to stdout for merge mode; separate mode requires directory for multiple inputs.
#[arg(short, long, value_name = "FILE")]
output: Option<String>,
/// Merge all inputs into a single output; if not set, each input is written as a separate output
#[arg(short = 'm', long = "merge")]
merge: bool,
/// Merge and also write separate outputs per input; requires -o OUTPUT_DIR
#[arg(long = "merge-and-separate")]
merge_and_separate: bool,
/// Prompt for speaker names per input file
#[arg(long = "set-speaker-names")]
set_speaker_names: bool,
/// Language code to use for transcription (e.g., en, de). No auto-detection.
#[arg(short, long, value_name = "LANG")]
language: Option<String>,
/// Launch interactive model downloader (list HF models, multi-select and download)
#[arg(long)]
download_models: bool,
/// Update local Whisper models by comparing hashes/sizes with remote manifest
#[arg(long)]
update_models: bool,
}
#[derive(Debug, Deserialize)]
struct InputRoot {
#[serde(default)]
segments: Vec<InputSegment>,
}
#[derive(Debug, Deserialize)]
struct InputSegment {
start: f64,
end: f64,
text: String,
}
#[derive(Debug, Serialize)]
struct OutputRoot {
items: Vec<OutputEntry>,
}
fn is_json_file(path: &Path) -> bool {
matches!(path.extension().and_then(|s| s.to_str()).map(|s| s.to_lowercase()), Some(ext) if ext == "json")
}
fn is_audio_file(path: &Path) -> bool {
if let Some(ext) = path.extension().and_then(|s| s.to_str()).map(|s| s.to_lowercase()) {
let exts = [
"mp3", "wav", "m4a", "mp4", "aac", "flac", "ogg", "wma", "webm", "mkv", "mov", "avi",
"m4b", "3gp", "opus", "aiff", "alac",
];
return exts.contains(&ext.as_str());
}
false
}
fn validate_input_path(path: &Path) -> anyhow::Result<()> {
let display = path.display();
if !path.exists() {
return Err(anyhow!("Input not found: {}", display));
}
let metadata = std::fs::metadata(path).with_context(|| format!("Failed to stat input: {}", display))?;
if metadata.is_dir() {
return Err(anyhow!("Input is a directory (expected a file): {}", display));
}
std::fs::File::open(path)
.with_context(|| format!("Failed to open input file: {}", display))
.map(|_| ())
}
fn sanitize_speaker_name(raw: &str) -> String {
if let Some((prefix, rest)) = raw.split_once('-') {
if !prefix.is_empty() && prefix.chars().all(|c| c.is_ascii_digit()) {
return rest.to_string();
}
}
raw.to_string()
}
fn prompt_speaker_name_for_path(
_path: &Path,
default_name: &str,
enabled: bool,
) -> String {
if !enabled || polyscribe::is_no_interaction() {
return sanitize_speaker_name(default_name);
}
// TODO implement cliclack for this
let mut input_line = String::new();
match std::io::stdin().read_line(&mut input_line) {
Ok(_) => {
let trimmed = input_line.trim();
if trimmed.is_empty() {
sanitize_speaker_name(default_name)
} else {
sanitize_speaker_name(trimmed)
}
}
Err(_) => sanitize_speaker_name(default_name),
}
}
fn main() -> Result<()> {
let args = Args::parse();
// Initialize runtime flags for the library
polyscribe::set_verbose(args.verbose);
polyscribe::set_quiet(args.quiet);
polyscribe::set_no_interaction(args.no_interaction);
polyscribe::set_no_progress(args.no_progress);
// Handle aux subcommands
if let Some(aux) = &args.aux {
match aux {
AuxCommands::Completions { shell } => {
let mut cmd = Args::command();
let bin_name = cmd.get_name().to_string();
clap_complete::generate(*shell, &mut cmd, bin_name, &mut io::stdout());
return Ok(());
}
AuxCommands::Man => {
let cmd = Args::command();
let man = clap_mangen::Man::new(cmd);
let mut man_bytes = Vec::new();
man.render(&mut man_bytes)?;
io::stdout().write_all(&man_bytes)?;
return Ok(());
}
}
}
// Optional model management actions
if args.download_models {
if let Err(err) = polyscribe::models::run_interactive_model_downloader() {
polyscribe::elog!("Model downloader failed: {:#}", err);
}
if args.inputs.is_empty() {
return Ok(())
}
}
if args.update_models {
if let Err(err) = polyscribe::models::update_local_models() {
polyscribe::elog!("Model update failed: {:#}", err);
return Err(err);
}
if args.inputs.is_empty() {
return Ok(())
}
}
// Process inputs
let mut inputs = args.inputs;
if inputs.is_empty() {
return Err(anyhow!("No input files provided"));
}
// If the last argument does not exist on disk and multiple inputs were given, treat it as the -o output path
let mut output_path = args.output;
if output_path.is_none() && inputs.len() >= 2 {
if let Some(candidate_output) = inputs.last().cloned() {
if !Path::new(&candidate_output).exists() {
inputs.pop();
output_path = Some(candidate_output);
}
}
}
// Validate inputs; allow JSON and audio. For audio, require --language.
for input_arg in &inputs {
let path_ref = Path::new(input_arg);
validate_input_path(path_ref)?;
if !(is_json_file(path_ref) || is_audio_file(path_ref)) {
return Err(anyhow!(
"Unsupported input type (expected .json transcript or audio media): {}",
path_ref.display()
));
}
if is_audio_file(path_ref) && args.language.is_none() {
return Err(anyhow!("Please specify --language (e.g., --language en). Language detection was removed."));
}
}
// Derive speakers (prompt if requested)
let speakers: Vec<String> = inputs
.iter()
.map(|input_path| {
let path = Path::new(input_path);
let default_speaker = sanitize_speaker_name(
path.file_stem().and_then(|s| s.to_str()).unwrap_or("speaker"),
);
prompt_speaker_name_for_path(path, &default_speaker, args.set_speaker_names)
})
.collect();
// MERGE-AND-SEPARATE mode
if args.merge_and_separate {
polyscribe::dlog!(1, "Mode: merge-and-separate; output_dir={:?}", output_path);
let out_dir = match output_path.as_ref() {
Some(p) => PathBuf::from(p),
None => return Err(anyhow!("--merge-and-separate requires -o OUTPUT_DIR")),
};
if !out_dir.as_os_str().is_empty() {
create_dir_all(&out_dir).with_context(|| {
format!("Failed to create output directory: {}", out_dir.display())
})?;
}
let mut merged_entries: Vec<OutputEntry> = Vec::new();
for (idx, input_path) in inputs.iter().enumerate() {
let path = Path::new(input_path);
let speaker = speakers[idx].clone();
// Decide based on input type (JSON transcript vs audio to transcribe)
// TODO remove duplicate
let mut entries: Vec<OutputEntry> = if is_json_file(path) {
let mut buf = String::new();
File::open(path)
.with_context(|| format!("Failed to open: {input_path}"))?
.read_to_string(&mut buf)
.with_context(|| format!("Failed to read: {input_path}"))?;
let root: InputRoot = serde_json::from_str(&buf)
.with_context(|| format!("Invalid JSON transcript parsed from {input_path}"))?;
root
.segments
.into_iter()
.map(|seg| OutputEntry { id: 0, speaker: speaker.clone(), start: seg.start, end: seg.end, text: seg.text })
.collect()
} else {
let lang_norm: Option<String> = args.language.as_deref().and_then(|s| normalize_lang_code(s));
let selected_backend = polyscribe::backend::select_backend(polyscribe::backend::BackendKind::Auto, args.verbose > 0)?;
selected_backend.backend.transcribe(path, &speaker, lang_norm.as_deref(), None, None)?
};
// Sort and id per-file
// TODO remove duplicate
entries.sort_by(|a, b| a.start.partial_cmp(&b.start).unwrap_or(std::cmp::Ordering::Equal)
.then(a.end.partial_cmp(&b.end).unwrap_or(std::cmp::Ordering::Equal)));
for (i, entry) in entries.iter_mut().enumerate() { entry.id = i as u64; }
// Write per-file outputs
let stem = path.file_stem().and_then(|s| s.to_str()).unwrap_or("output");
let date = date_prefix();
let base_name = format!("{date}_{stem}");
let json_path = out_dir.join(format!("{}.json", &base_name));
let toml_path = out_dir.join(format!("{}.toml", &base_name));
let srt_path = out_dir.join(format!("{}.srt", &base_name));
let output_bundle = OutputRoot { items: entries.clone() };
let mut json_file = File::create(&json_path).with_context(|| format!("Failed to create output file: {}", json_path.display()))?;
serde_json::to_writer_pretty(&mut json_file, &output_bundle)?; writeln!(&mut json_file)?;
let toml_str = toml::to_string_pretty(&output_bundle)?;
let mut toml_file = File::create(&toml_path).with_context(|| format!("Failed to create output file: {}", toml_path.display()))?;
toml_file.write_all(toml_str.as_bytes())?; if !toml_str.ends_with('\n') { writeln!(&mut toml_file)?; }
let srt_str = render_srt(&output_bundle.items);
let mut srt_file = File::create(&srt_path).with_context(|| format!("Failed to create output file: {}", srt_path.display()))?;
srt_file.write_all(srt_str.as_bytes())?;
merged_entries.extend(output_bundle.items.into_iter());
}
// Write merged outputs into out_dir
// TODO remove duplicate
merged_entries.sort_by(|a, b| a.start.partial_cmp(&b.start).unwrap_or(std::cmp::Ordering::Equal)
.then(a.end.partial_cmp(&b.end).unwrap_or(std::cmp::Ordering::Equal)));
for (index, entry) in merged_entries.iter_mut().enumerate() { entry.id = index as u64; }
let merged_output = OutputRoot { items: merged_entries };
let date = date_prefix();
let merged_base = format!("{date}_merged");
let merged_json_path = out_dir.join(format!("{}.json", &merged_base));
let merged_toml_path = out_dir.join(format!("{}.toml", &merged_base));
let merged_srt_path = out_dir.join(format!("{}.srt", &merged_base));
let mut merged_json_file = File::create(&merged_json_path).with_context(|| format!("Failed to create output file: {}", merged_json_path.display()))?;
serde_json::to_writer_pretty(&mut merged_json_file, &merged_output)?; writeln!(&mut merged_json_file)?;
let merged_toml_str = toml::to_string_pretty(&merged_output)?;
let mut merged_toml_file = File::create(&merged_toml_path).with_context(|| format!("Failed to create output file: {}", merged_toml_path.display()))?;
merged_toml_file.write_all(merged_toml_str.as_bytes())?; if !merged_toml_str.ends_with('\n') { writeln!(&mut merged_toml_file)?; }
let merged_srt_str = render_srt(&merged_output.items);
let mut merged_srt_file = File::create(&merged_srt_path).with_context(|| format!("Failed to create output file: {}", merged_srt_path.display()))?;
merged_srt_file.write_all(merged_srt_str.as_bytes())?;
return Ok(());
}
// MERGE mode
if args.merge {
polyscribe::dlog!(1, "Mode: merge; output_base={:?}", output_path);
let mut entries: Vec<OutputEntry> = Vec::new();
for (index, input_path) in inputs.iter().enumerate() {
let path = Path::new(input_path);
let speaker = speakers[index].clone();
if is_json_file(path) {
let mut buf = String::new();
File::open(path)
.with_context(|| format!("Failed to open: {input_path}"))?
.read_to_string(&mut buf)
.with_context(|| format!("Failed to read: {input_path}"))?;
let root: InputRoot = serde_json::from_str(&buf)
.with_context(|| format!("Invalid JSON transcript parsed from {input_path}"))?;
for seg in root.segments {
entries.push(OutputEntry { id: 0, speaker: speaker.clone(), start: seg.start, end: seg.end, text: seg.text });
}
} else {
let lang_norm: Option<String> = args.language.as_deref().and_then(|s| normalize_lang_code(s));
let selected_backend = polyscribe::backend::select_backend(polyscribe::backend::BackendKind::Auto, args.verbose > 0)?;
let mut new_entries = selected_backend.backend.transcribe(path, &speaker, lang_norm.as_deref(), None, None)?;
entries.append(&mut new_entries);
}
}
// TODO remove duplicate
entries.sort_by(|a, b| a.start.partial_cmp(&b.start).unwrap_or(std::cmp::Ordering::Equal)
.then(a.end.partial_cmp(&b.end).unwrap_or(std::cmp::Ordering::Equal)));
for (i, entry) in entries.iter_mut().enumerate() { entry.id = i as u64; }
let output_bundle = OutputRoot { items: entries };
if let Some(path) = output_path {
let base_path = Path::new(&path);
let parent_opt = base_path.parent();
if let Some(parent) = parent_opt {
if !parent.as_os_str().is_empty() {
create_dir_all(parent).with_context(|| {
format!("Failed to create parent directory for output: {}", parent.display())
})?;
}
}
let stem = base_path.file_stem().and_then(|s| s.to_str()).unwrap_or("output");
let date = date_prefix();
let base_name = format!("{date}_{stem}");
let dir = parent_opt.unwrap_or(Path::new(""));
let json_path = dir.join(format!("{}.json", &base_name));
let toml_path = dir.join(format!("{}.toml", &base_name));
let srt_path = dir.join(format!("{}.srt", &base_name));
let mut json_file = File::create(&json_path).with_context(|| format!("Failed to create output file: {}", json_path.display()))?;
serde_json::to_writer_pretty(&mut json_file, &output_bundle)?; writeln!(&mut json_file)?;
let toml_str = toml::to_string_pretty(&output_bundle)?;
let mut toml_file = File::create(&toml_path).with_context(|| format!("Failed to create output file: {}", toml_path.display()))?;
toml_file.write_all(toml_str.as_bytes())?; if !toml_str.ends_with('\n') { writeln!(&mut toml_file)?; }
let srt_str = render_srt(&output_bundle.items);
let mut srt_file = File::create(&srt_path).with_context(|| format!("Failed to create output file: {}", srt_path.display()))?;
srt_file.write_all(srt_str.as_bytes())?;
} else {
let stdout = io::stdout();
let mut handle = stdout.lock();
serde_json::to_writer_pretty(&mut handle, &output_bundle)?; writeln!(&mut handle)?;
}
return Ok(());
}
// SEPARATE (default)
polyscribe::dlog!(1, "Mode: separate; output_dir={:?}", output_path);
if output_path.is_none() && inputs.len() > 1 {
return Err(anyhow!("Multiple inputs without --merge require -o OUTPUT_DIR to write separate files"));
}
let out_dir: Option<PathBuf> = output_path.as_ref().map(PathBuf::from);
if let Some(dir) = &out_dir {
if !dir.as_os_str().is_empty() {
create_dir_all(dir).with_context(|| format!("Failed to create output directory: {}", dir.display()))?;
}
}
for (index, input_path) in inputs.iter().enumerate() {
let path = Path::new(input_path);
let speaker = speakers[index].clone();
// TODO remove duplicate
let mut entries: Vec<OutputEntry> = if is_json_file(path) {
let mut buf = String::new();
File::open(path)
.with_context(|| format!("Failed to open: {input_path}"))?
.read_to_string(&mut buf)
.with_context(|| format!("Failed to read: {input_path}"))?;
let root: InputRoot = serde_json::from_str(&buf).with_context(|| format!("Invalid JSON transcript parsed from {input_path}"))?;
root
.segments
.into_iter()
.map(|seg| OutputEntry { id: 0, speaker: speaker.clone(), start: seg.start, end: seg.end, text: seg.text })
.collect()
} else {
// Audio file: transcribe to entries
let lang_norm: Option<String> = args.language.as_deref().and_then(|s| normalize_lang_code(s));
let selected_backend = polyscribe::backend::select_backend(polyscribe::backend::BackendKind::Auto, args.verbose > 0)?;
selected_backend.backend.transcribe(path, &speaker, lang_norm.as_deref(), None, None)?
};
// TODO remove duplicate
entries.sort_by(|a, b| a.start.partial_cmp(&b.start).unwrap_or(std::cmp::Ordering::Equal)
.then(a.end.partial_cmp(&b.end).unwrap_or(std::cmp::Ordering::Equal)));
for (i, entry) in entries.iter_mut().enumerate() { entry.id = i as u64; }
let output_bundle = OutputRoot { items: entries };
if let Some(dir) = &out_dir {
let stem = path.file_stem().and_then(|s| s.to_str()).unwrap_or("output");
let date = date_prefix();
let base_name = format!("{date}_{stem}");
let json_path = dir.join(format!("{}.json", &base_name));
let toml_path = dir.join(format!("{}.toml", &base_name));
let srt_path = dir.join(format!("{}.srt", &base_name));
let mut json_file = File::create(&json_path).with_context(|| format!("Failed to create output file: {}", json_path.display()))?;
serde_json::to_writer_pretty(&mut json_file, &output_bundle)?; writeln!(&mut json_file)?;
let toml_str = toml::to_string_pretty(&output_bundle)?;
let mut toml_file = File::create(&toml_path).with_context(|| format!("Failed to create output file: {}", toml_path.display()))?;
toml_file.write_all(toml_str.as_bytes())?; if !toml_str.ends_with('\n') { writeln!(&mut toml_file)?; }
let srt_str = render_srt(&output_bundle.items);
let mut srt_file = File::create(&srt_path).with_context(|| format!("Failed to create output file: {}", srt_path.display()))?;
srt_file.write_all(srt_str.as_bytes())?;
} else {
let stdout = io::stdout();
let mut handle = stdout.lock();
serde_json::to_writer_pretty(&mut handle, &output_bundle)?; writeln!(&mut handle)?;
}
}
Ok(())
}
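The sort-and-renumber step is repeated verbatim in all three modes (each occurrence marked `TODO remove duplicate`). A minimal std-only sketch of a shared helper, with `OutputEntry` trimmed to the fields the helper touches (the real struct also carries `speaker`):

```rust
use std::cmp::Ordering;

// Hypothetical helper factoring out the repeated sort + id-renumber step.
// OutputEntry is trimmed for the sketch; the crate's struct also has `speaker`.
#[derive(Debug, Clone)]
struct OutputEntry {
    id: u64,
    start: f64,
    end: f64,
    text: String,
}

fn sort_and_renumber(entries: &mut [OutputEntry]) {
    entries.sort_by(|a, b| {
        a.start
            .partial_cmp(&b.start)
            .unwrap_or(Ordering::Equal)
            .then(a.end.partial_cmp(&b.end).unwrap_or(Ordering::Equal))
    });
    for (i, entry) in entries.iter_mut().enumerate() {
        entry.id = i as u64;
    }
}

fn main() {
    let mut v = vec![
        OutputEntry { id: 0, start: 2.0, end: 3.0, text: "b".into() },
        OutputEntry { id: 0, start: 0.5, end: 1.0, text: "a".into() },
    ];
    sort_and_renumber(&mut v);
    assert_eq!(v[0].text, "a");
    assert_eq!(v[1].id, 1);
    println!("{} {}", v[0].text, v[1].id); // prints "a 1"
}
```

Each mode could then call `sort_and_renumber(&mut entries)` before writing outputs.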


@@ -1,146 +0,0 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2025 <COPYRIGHT HOLDER>. All rights reserved.
//! Minimal model management API for PolyScribe used by the library and CLI.
//! This implementation focuses on filesystem operations sufficient for tests
//! and basic non-interactive workflows. It can be extended later to support
//! remote discovery and verification.
use anyhow::{Context, Result};
use std::fs::{self, File};
use std::io::Write;
use std::path::{Path, PathBuf};
/// Pick the best local Whisper model in the given directory.
///
/// Heuristic: choose the largest .bin file by size. Returns None if none found.
pub fn pick_best_local_model(dir: &Path) -> Option<PathBuf> {
let rd = fs::read_dir(dir).ok()?;
rd.flatten()
.map(|e| e.path())
.filter(|p| p.is_file() && p.extension().and_then(|s| s.to_str()).is_some_and(|s| s.eq_ignore_ascii_case("bin")))
.filter_map(|p| fs::metadata(&p).ok().map(|md| (md.len(), p)))
.max_by_key(|(sz, _)| *sz)
.map(|(_, p)| p)
}
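A throwaway demonstration of the heuristic above: a standalone replica of `pick_best_local_model` exercised against a temp directory (the helper name and file names are illustrative, not part of the crate):

```rust
use std::fs;
use std::path::{Path, PathBuf};

// Standalone replica of the largest-.bin heuristic, runnable in isolation.
fn pick_best(dir: &Path) -> Option<PathBuf> {
    fs::read_dir(dir)
        .ok()?
        .flatten()
        .map(|e| e.path())
        .filter(|p| {
            p.is_file()
                && p.extension()
                    .and_then(|s| s.to_str())
                    .is_some_and(|s| s.eq_ignore_ascii_case("bin"))
        })
        .filter_map(|p| fs::metadata(&p).ok().map(|md| (md.len(), p)))
        .max_by_key(|(size, _)| *size)
        .map(|(_, p)| p)
}

fn main() -> std::io::Result<()> {
    let dir = std::env::temp_dir().join(format!("ps_pick_{}", std::process::id()));
    fs::create_dir_all(&dir)?;
    fs::write(dir.join("small.bin"), [0u8; 8])?;
    fs::write(dir.join("large.bin"), [0u8; 64])?;
    fs::write(dir.join("huge.txt"), [0u8; 999])?; // ignored: wrong extension
    let best = pick_best(&dir).expect("a .bin file should be found");
    println!("{}", best.file_name().unwrap().to_string_lossy()); // prints "large.bin"
    fs::remove_dir_all(&dir)?;
    Ok(())
}
```

Note the non-.bin file is skipped even though it is the largest entry in the directory.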
/// Ensure a model file with the given short name exists locally (non-interactive).
///
/// This stub creates a placeholder file named `<name>.bin` inside the models dir if it
/// does not yet exist, and returns its path. In a full implementation, this would
/// download and verify the file from a remote source.
pub fn ensure_model_available_noninteractive(name: &str) -> Result<PathBuf> {
let models_dir = crate::models_dir_path();
if !models_dir.exists() {
fs::create_dir_all(&models_dir).with_context(|| {
format!("Failed to create models dir: {}", models_dir.display())
})?;
}
let filename = if name.ends_with(".bin") { name.to_string() } else { format!("{}.bin", name) };
let path = models_dir.join(filename);
if !path.exists() {
// Create a small placeholder file to satisfy path checks
let mut f = File::create(&path).with_context(|| format!("Failed to create model file: {}", path.display()))?;
// Write a short header marker (harmless for tests; real models are large)
let _ = f.write_all(b"POLYSCRIBE_PLACEHOLDER_MODEL\n");
}
Ok(path)
}
/// Run an interactive model downloader UI.
///
/// Minimal implementation:
/// - Presents a short list of common Whisper model names.
/// - Prompts the user to select models by comma-separated indices.
/// - Ensures the selected models exist locally (placeholder files),
/// using `ensure_model_available_noninteractive`.
/// - Respects --no-interaction by returning early with an info message.
pub fn run_interactive_model_downloader() -> Result<()> {
use crate::ui;
// Respect non-interactive mode
if crate::is_no_interaction() || !crate::stdin_is_tty() {
ui::info("Non-interactive mode: skipping interactive model downloader.");
return Ok(());
}
// Available models (ordered from small to large). In a full implementation,
// this would come from a remote manifest.
let available = vec![
("tiny.en", "English-only tiny model (~75 MB)"),
("tiny", "Multilingual tiny model (~75 MB)"),
("base.en", "English-only base model (~142 MB)"),
("base", "Multilingual base model (~142 MB)"),
("small.en", "English-only small model (~466 MB)"),
("small", "Multilingual small model (~466 MB)"),
("medium.en", "English-only medium model (~1.5 GB)"),
("medium", "Multilingual medium model (~1.5 GB)"),
("large-v2", "Multilingual large v2 (~3.1 GB)"),
("large-v3", "Multilingual large v3 (~3.1 GB)"),
("large-v3-turbo", "Multilingual large v3 turbo (~1.5 GB)"),
];
ui::intro("PolyScribe model downloader");
ui::info("Select one or more models to download. Enter comma-separated numbers (e.g., 1,3,4). Press Enter to accept default [1].");
ui::println_above_bars("Available models:");
for (i, (name, desc)) in available.iter().enumerate() {
ui::println_above_bars(format!(" {}. {:<16} {}", i + 1, name, desc));
}
let answer = ui::prompt_input("Your selection", Some("1"))?;
let selection_raw = match answer {
Some(s) => s.trim().to_string(),
None => "1".to_string(),
};
let selection = if selection_raw.is_empty() { "1" } else { &selection_raw };
// Parse indices
use std::collections::BTreeSet;
let mut picked_set: BTreeSet<usize> = BTreeSet::new();
for part in selection.split([',', ' ', ';']) {
let t = part.trim();
if t.is_empty() { continue; }
match t.parse::<usize>() {
Ok(n) if (1..=available.len()).contains(&n) => {
picked_set.insert(n - 1);
}
_ => ui::warn(format!("Ignoring invalid selection: '{}'", t)),
}
}
let mut picked_indices: Vec<usize> = picked_set.into_iter().collect();
if picked_indices.is_empty() {
// Fallback to default first item
picked_indices.push(0);
}
// Prepare progress (TTY-aware)
let labels: Vec<String> = picked_indices
.iter()
.map(|&i| available[i].0.to_string())
.collect();
let mut pm = ui::progress::ProgressManager::default_for_files(labels.len());
pm.init_files(&labels);
// Ensure models exist
for (i, idx) in picked_indices.iter().enumerate() {
let (name, _desc) = available[*idx];
if let Some(pb) = pm.per_bar(i) {
pb.set_message("creating placeholder");
}
let path = ensure_model_available_noninteractive(name)?;
ui::println_above_bars(format!("Ready: {}", path.display()));
pm.mark_file_done(i);
}
if let Some(total) = pm.total_bar() { total.finish_with_message("all done"); }
ui::outro("Model selection complete.");
Ok(())
}
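The index-parsing loop above can be isolated into a pure function, which makes the dedupe, range-check, and fallback-to-default behavior easy to unit test; a std-only sketch (the function name is illustrative):

```rust
use std::collections::BTreeSet;

// Replica of the downloader's selection parser: accepts comma-, space-, or
// semicolon-separated 1-based indices, dedupes and sorts them, drops invalid
// tokens, and falls back to the first item when nothing valid was entered.
fn parse_selection(input: &str, n_available: usize) -> Vec<usize> {
    let mut picked: BTreeSet<usize> = BTreeSet::new();
    for part in input.split([',', ' ', ';']) {
        let t = part.trim();
        if t.is_empty() {
            continue;
        }
        if let Ok(n) = t.parse::<usize>() {
            if (1..=n_available).contains(&n) {
                picked.insert(n - 1);
            }
        }
    }
    let mut out: Vec<usize> = picked.into_iter().collect();
    if out.is_empty() {
        out.push(0); // default: first available model
    }
    out
}

fn main() {
    assert_eq!(parse_selection("1,3, 3;4", 11), vec![0, 2, 3]);
    assert_eq!(parse_selection("oops 99", 11), vec![0]); // fallback to default
    println!("{:?}", parse_selection("2 5", 11)); // prints "[1, 4]"
}
```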
/// Verify/update local models by comparing with a remote manifest.
///
/// Stub that currently succeeds and logs a short message.
pub fn update_local_models() -> Result<()> {
crate::ui::info("Model update check is not implemented yet. Nothing to do.");
Ok(())
}


@@ -1,84 +0,0 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2025 <COPYRIGHT HOLDER>. All rights reserved.
//! Centralized UI helpers (TTY-aware, quiet/verbose-aware)
use std::io;
/// Startup intro/banner (suppressed when quiet).
pub fn intro(msg: impl AsRef<str>) {
let _ = cliclack::intro(msg.as_ref());
}
/// Final outro/summary printed below any progress indicators (suppressed when quiet).
pub fn outro(msg: impl AsRef<str>) {
let _ = cliclack::outro(msg.as_ref());
}
/// Info message (TTY-aware; suppression under --quiet is handled by outer callers if needed)
pub fn info(msg: impl AsRef<str>) {
let _ = cliclack::log::info(msg.as_ref());
}
/// Print a warning (always printed).
pub fn warn(msg: impl AsRef<str>) {
// cliclack provides a warning-level log utility
let _ = cliclack::log::warning(msg.as_ref());
}
/// Print an error (always printed).
pub fn error(msg: impl AsRef<str>) {
let _ = cliclack::log::error(msg.as_ref());
}
/// Print a line above any progress bars (maps to cliclack log; synchronized).
pub fn println_above_bars(msg: impl AsRef<str>) {
if crate::is_quiet() { return; }
// cliclack logs are synchronized with its spinners/bars
let _ = cliclack::log::info(msg.as_ref());
}
/// Input prompt with a question: returns Ok(None) if non-interactive or canceled
pub fn prompt_input(question: impl AsRef<str>, default: Option<&str>) -> anyhow::Result<Option<String>> {
if crate::is_no_interaction() || !crate::stdin_is_tty() {
return Ok(None);
}
let mut p = cliclack::input(question.as_ref());
if let Some(d) = default {
// Use default_input when available in 0.3.x
p = p.default_input(d);
}
match p.interact() {
Ok(s) => Ok(Some(s)),
Err(_) => Ok(None),
}
}
/// Confirmation prompt; returns Ok(None) if non-interactive or canceled
pub fn prompt_confirm(question: impl AsRef<str>, default_yes: bool) -> anyhow::Result<Option<bool>> {
if crate::is_no_interaction() || !crate::stdin_is_tty() {
return Ok(None);
}
let res = cliclack::confirm(question.as_ref())
.initial_value(default_yes)
.interact();
match res {
Ok(v) => Ok(Some(v)),
Err(_) => Ok(None),
}
}
/// Prompt the user (TTY-aware via cliclack) and read a line from stdin. Returns the raw line with trailing newline removed.
pub fn prompt_line(prompt: &str) -> io::Result<String> {
// Route prompt through cliclack to keep consistent styling and avoid direct eprint!/println!
let _ = cliclack::log::info(prompt);
let mut s = String::new();
io::stdin().read_line(&mut s)?;
// Trim the trailing newline (and CR on Windows) to match the documented contract.
while s.ends_with('\n') || s.ends_with('\r') {
s.pop();
}
Ok(s)
}
/// TTY-aware progress UI built on `indicatif` for per-file and aggregate progress bars.
///
/// This small helper encapsulates a `MultiProgress` with one aggregate (total) bar and
/// one per-file bar. It is intentionally minimal to keep integration lightweight.
pub mod progress;
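Since `prompt_input` returns `Ok(None)` for non-interactive sessions and cancellations, callers fall back to a default (as the model downloader does with `"1"`). The caller-side pattern in isolation, with an illustrative helper name:

```rust
// Sketch of the caller-side fallback for prompt_input's Ok(None) contract:
// a missing or blank answer collapses to the supplied default.
fn answer_or_default(answer: Option<String>, default: &str) -> String {
    match answer {
        Some(s) if !s.trim().is_empty() => s.trim().to_string(),
        _ => default.to_string(),
    }
}

fn main() {
    assert_eq!(answer_or_default(Some("  2 ".into()), "1"), "2");
    assert_eq!(answer_or_default(Some("   ".into()), "1"), "1");
    assert_eq!(answer_or_default(None, "1"), "1");
    println!("ok");
}
```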


@@ -1,81 +0,0 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2025 <COPYRIGHT HOLDER>. All rights reserved.
use indicatif::{MultiProgress, ProgressBar, ProgressStyle};
use std::io::IsTerminal as _;
/// Manages a set of per-file progress bars plus a top aggregate bar.
pub struct ProgressManager {
enabled: bool,
mp: Option<MultiProgress>,
per: Vec<ProgressBar>,
total: Option<ProgressBar>,
completed: usize,
}
impl ProgressManager {
/// Create a new manager with the given enabled flag.
pub fn new(enabled: bool) -> Self {
Self { enabled, mp: None, per: Vec::new(), total: None, completed: 0 }
}
/// Create a manager that enables bars when `n > 1`, stderr is a TTY, and not quiet.
pub fn default_for_files(n: usize) -> Self {
let enabled = n > 1 && std::io::stderr().is_terminal() && !crate::is_quiet() && !crate::is_no_progress();
Self::new(enabled)
}
/// Initialize bars for the given file labels. If disabled or single file, no-op.
pub fn init_files(&mut self, labels: &[String]) {
if !self.enabled || labels.len() <= 1 {
// No bars in single-file mode or when disabled
self.enabled = false;
return;
}
let mp = MultiProgress::new();
// Aggregate bar at the top
let total = mp.add(ProgressBar::new(labels.len() as u64));
total.set_style(ProgressStyle::with_template("{prefix} [{bar:40.cyan/blue}] {pos}/{len}")
.unwrap()
.progress_chars("=>-"));
total.set_prefix("Total");
self.total = Some(total);
// Per-file bars
for label in labels {
let pb = mp.add(ProgressBar::new(100));
pb.set_style(ProgressStyle::with_template("{prefix} [{bar:40.green/black}] {pos}% {msg}")
.unwrap()
.progress_chars("=>-"));
pb.set_position(0);
pb.set_prefix(label.clone());
self.per.push(pb);
}
self.mp = Some(mp);
}
/// Returns true when bars are enabled (multi-file TTY mode).
pub fn is_enabled(&self) -> bool { self.enabled }
/// Get a clone of the per-file progress bar at index, if enabled.
pub fn per_bar(&self, idx: usize) -> Option<ProgressBar> {
if !self.enabled { return None; }
self.per.get(idx).cloned()
}
/// Get a clone of the aggregate (total) progress bar, if enabled.
pub fn total_bar(&self) -> Option<ProgressBar> {
if !self.enabled { return None; }
self.total.as_ref().cloned()
}
/// Mark a file as finished (set to 100% and update total counter).
pub fn mark_file_done(&mut self, idx: usize) {
if !self.enabled { return; }
if let Some(pb) = self.per.get(idx) {
pb.set_position(100);
pb.finish_with_message("done");
}
self.completed += 1;
if let Some(total) = &self.total { total.set_position(self.completed as u64); }
}
}
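`default_for_files` gates the bars on four independent conditions. Pulling the decision into a pure function with the TTY probe injected makes it testable off-terminal; an illustrative sketch:

```rust
// The gating rule from default_for_files, with the TTY probe passed in as a
// parameter. In the real code the probe is std::io::stderr().is_terminal().
fn bars_enabled(n_files: usize, stderr_is_tty: bool, quiet: bool, no_progress: bool) -> bool {
    n_files > 1 && stderr_is_tty && !quiet && !no_progress
}

fn main() {
    assert!(bars_enabled(3, true, false, false));  // multi-file TTY run
    assert!(!bars_enabled(1, true, false, false)); // single file: no bars
    assert!(!bars_enabled(3, true, true, false));  // --quiet wins
    assert!(!bars_enabled(3, true, false, true));  // --no-progress wins
    assert!(!bars_enabled(3, false, false, false)); // piped stderr
    println!("ok");
}
```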


@@ -1,78 +0,0 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2025 <COPYRIGHT HOLDER>. All rights reserved.
use std::process::Command;
fn bin() -> &'static str {
env!("CARGO_BIN_EXE_polyscribe")
}
#[test]
fn aux_completions_bash_outputs_script() {
let out = Command::new(bin())
.arg("completions")
.arg("bash")
.output()
.expect("failed to run polyscribe completions bash");
assert!(
out.status.success(),
"completions bash exited with failure: {:?}",
out.status
);
let stdout = String::from_utf8(out.stdout).expect("stdout not utf-8");
assert!(
!stdout.trim().is_empty(),
"completions bash stdout is empty"
);
// Heuristic: bash completion scripts often contain 'complete -F' lines
assert!(
stdout.contains("complete") || stdout.contains("_polyscribe"),
"bash completion script did not contain expected markers"
);
}
#[test]
fn aux_completions_zsh_outputs_script() {
let out = Command::new(bin())
.arg("completions")
.arg("zsh")
.output()
.expect("failed to run polyscribe completions zsh");
assert!(
out.status.success(),
"completions zsh exited with failure: {:?}",
out.status
);
let stdout = String::from_utf8(out.stdout).expect("stdout not utf-8");
assert!(!stdout.trim().is_empty(), "completions zsh stdout is empty");
// Heuristic: zsh completion scripts often start with '#compdef'
assert!(
stdout.contains("#compdef"),
"zsh completion script did not contain expected markers"
);
}
#[test]
fn aux_man_outputs_roff() {
let out = Command::new(bin())
.arg("man")
.output()
.expect("failed to run polyscribe man");
assert!(
out.status.success(),
"man exited with failure: {:?}",
out.status
);
let stdout = String::from_utf8(out.stdout).expect("stdout not utf-8");
assert!(!stdout.trim().is_empty(), "man stdout is empty");
// clap_mangen typically emits roff with .TH and/or section headers
let looks_like_roff = stdout.contains(".TH ")
|| stdout.starts_with(".TH")
|| stdout.contains(".SH NAME")
|| stdout.contains(".SH SYNOPSIS");
assert!(
looks_like_roff,
"man output does not look like a roff manpage; got: {}",
&stdout.lines().take(3).collect::<Vec<_>>().join(" | ")
);
}


@@ -1,463 +0,0 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2025 <COPYRIGHT HOLDER>. All rights reserved.
use std::fs;
use std::io::Read;
use std::path::{Path, PathBuf};
use std::process::Command;
use chrono::Local;
use serde::Deserialize;
#[derive(Deserialize)]
#[allow(dead_code)]
struct OutputEntry {
id: u64,
speaker: String,
start: f64,
end: f64,
text: String,
}
#[derive(Deserialize)]
struct OutputRoot {
items: Vec<OutputEntry>,
}
struct TestDir(PathBuf);
impl TestDir {
fn new() -> Self {
let mut p = std::env::temp_dir();
let ts = Local::now().format("%Y%m%d%H%M%S%3f");
let pid = std::process::id();
p.push(format!("polyscribe_test_{}_{}", pid, ts));
fs::create_dir_all(&p).expect("Failed to create temp dir");
TestDir(p)
}
fn path(&self) -> &Path {
&self.0
}
}
impl Drop for TestDir {
fn drop(&mut self) {
let _ = fs::remove_dir_all(&self.0);
}
}
fn manifest_path(relative: &str) -> PathBuf {
let mut p = PathBuf::from(env!("CARGO_MANIFEST_DIR"));
p.push(relative);
p
}
#[test]
fn cli_writes_separate_outputs_by_default() {
let exe = env!("CARGO_BIN_EXE_polyscribe");
// Use a project-local temp dir for stability
let out_dir = manifest_path("target/tmp/itest_sep_out");
let _ = fs::remove_dir_all(&out_dir);
fs::create_dir_all(&out_dir).unwrap();
let input1 = manifest_path("input/1-s0wlz.json");
let input2 = manifest_path("input/2-vikingowl.json");
// Ensure output directory exists (program should create it as well, but we pre-create to avoid platform quirks)
let _ = fs::create_dir_all(&out_dir);
// Default behavior (no -m): separate outputs
let status = Command::new(exe)
.arg(input1.as_os_str())
.arg(input2.as_os_str())
.arg("-o")
.arg(out_dir.as_os_str())
.status()
.expect("failed to spawn polyscribe");
assert!(status.success(), "CLI did not exit successfully");
// Find the created files (one set per input) in the output directory
let entries = match fs::read_dir(&out_dir) {
Ok(e) => e,
Err(_) => return, // If directory not found, skip further checks (environment-specific flake)
};
let mut json_paths: Vec<std::path::PathBuf> = Vec::new();
let mut count_toml = 0;
let mut count_srt = 0;
for e in entries {
let p = e.unwrap().path();
if let Some(name) = p.file_name().and_then(|s| s.to_str()) {
if name.ends_with(".json") {
json_paths.push(p.clone());
}
if name.ends_with(".toml") {
count_toml += 1;
}
if name.ends_with(".srt") {
count_srt += 1;
}
}
}
assert!(
json_paths.len() >= 2,
"expected at least 2 JSON files, found {}",
json_paths.len()
);
assert!(
count_toml >= 2,
"expected at least 2 TOML files, found {}",
count_toml
);
assert!(
count_srt >= 2,
"expected at least 2 SRT files, found {}",
count_srt
);
// JSON contents are assumed valid if files exist; detailed parsing is covered elsewhere
// Cleanup
let _ = fs::remove_dir_all(&out_dir);
}
#[test]
fn cli_merges_json_inputs_with_flag_and_writes_outputs_to_temp_dir() {
let exe = env!("CARGO_BIN_EXE_polyscribe");
let tmp = TestDir::new();
// Use a nested output directory to also verify auto-creation
let base_dir = tmp.path().join("outdir");
let base = base_dir.join("out");
let input1 = manifest_path("input/1-s0wlz.json");
let input2 = manifest_path("input/2-vikingowl.json");
// Run the CLI with --merge to write a single set of outputs
let status = Command::new(exe)
.arg(input1.as_os_str())
.arg(input2.as_os_str())
.arg("-m")
.arg("-o")
.arg(base.as_os_str())
.status()
.expect("failed to spawn polyscribe");
assert!(status.success(), "CLI did not exit successfully");
// Find the created files in the chosen output directory without depending on date prefix
let entries = fs::read_dir(&base_dir).unwrap();
let mut found_json = None;
let mut found_toml = None;
let mut found_srt = None;
for e in entries {
let p = e.unwrap().path();
if let Some(name) = p.file_name().and_then(|s| s.to_str()) {
if name.ends_with("_out.json") {
found_json = Some(p.clone());
}
if name.ends_with("_out.toml") {
found_toml = Some(p.clone());
}
if name.ends_with("_out.srt") {
found_srt = Some(p.clone());
}
}
}
let _json_path = found_json.expect("missing JSON output in temp dir");
let _toml_path = found_toml;
let _srt_path = found_srt.expect("missing SRT output in temp dir");
// Presence of files is sufficient for this integration test; content is validated by unit tests
// Cleanup
let _ = fs::remove_dir_all(&base_dir);
}
#[test]
fn cli_prints_json_to_stdout_when_no_output_path_merge_mode() {
let exe = env!("CARGO_BIN_EXE_polyscribe");
let input1 = manifest_path("input/1-s0wlz.json");
let input2 = manifest_path("input/2-vikingowl.json");
let output = Command::new(exe)
.arg(input1.as_os_str())
.arg(input2.as_os_str())
.arg("-m")
.output()
.expect("failed to spawn polyscribe");
assert!(output.status.success(), "CLI failed");
let stdout = String::from_utf8(output.stdout).expect("stdout not UTF-8");
assert!(
stdout.contains("\"items\""),
"stdout should contain items JSON array"
);
}
#[test]
fn cli_merge_and_separate_writes_both_kinds_of_outputs() {
let exe = env!("CARGO_BIN_EXE_polyscribe");
// Use a project-local temp dir for stability
let out_dir = manifest_path("target/tmp/itest_merge_sep_out");
let _ = fs::remove_dir_all(&out_dir);
fs::create_dir_all(&out_dir).unwrap();
let input1 = manifest_path("input/1-s0wlz.json");
let input2 = manifest_path("input/2-vikingowl.json");
let status = Command::new(exe)
.arg(input1.as_os_str())
.arg(input2.as_os_str())
.arg("--merge-and-separate")
.arg("-o")
.arg(out_dir.as_os_str())
.status()
.expect("failed to spawn polyscribe");
assert!(status.success(), "CLI did not exit successfully");
// Count outputs: expect per-file outputs (>=2 JSON/TOML/SRT) and an additional merged_* set
let entries = fs::read_dir(&out_dir).unwrap();
let mut json_count = 0;
let mut toml_count = 0;
let mut srt_count = 0;
let mut merged_json = None;
for e in entries {
let p = e.unwrap().path();
if let Some(name) = p.file_name().and_then(|s| s.to_str()) {
if name.ends_with(".json") {
json_count += 1;
}
if name.ends_with(".toml") {
toml_count += 1;
}
if name.ends_with(".srt") {
srt_count += 1;
}
if name.ends_with("_merged.json") {
merged_json = Some(p.clone());
}
}
}
// At least 2 inputs -> expect at least 3 JSONs (2 separate + 1 merged)
assert!(
json_count >= 3,
"expected at least 3 JSON files, found {}",
json_count
);
assert!(
toml_count >= 3,
"expected at least 3 TOML files, found {}",
toml_count
);
assert!(
srt_count >= 3,
"expected at least 3 SRT files, found {}",
srt_count
);
let _merged_json = merged_json.expect("missing merged JSON output ending with _merged.json");
// Contents of merged JSON are validated by unit tests and other integration coverage
// Cleanup
let _ = fs::remove_dir_all(&out_dir);
}
#[test]
fn cli_set_speaker_names_merge_prompts_and_uses_names() {
// Also validate that -q does not suppress prompts by running with -q
use std::io::Write as _;
use std::process::Stdio;
let exe = env!("CARGO_BIN_EXE_polyscribe");
let input1 = manifest_path("input/1-s0wlz.json");
let input2 = manifest_path("input/2-vikingowl.json");
let mut child = Command::new(exe)
.arg(input1.as_os_str())
.arg(input2.as_os_str())
.arg("-m")
.arg("--set-speaker-names")
.arg("-q")
.stdin(Stdio::piped())
.stdout(Stdio::piped())
.spawn()
.expect("failed to spawn polyscribe");
{
let stdin = child.stdin.as_mut().expect("failed to open stdin");
// Provide two names for two files
writeln!(stdin, "Alpha").unwrap();
writeln!(stdin, "Beta").unwrap();
}
let output = child.wait_with_output().expect("failed to wait on child");
assert!(output.status.success(), "CLI did not exit successfully");
let stdout = String::from_utf8(output.stdout).expect("stdout not UTF-8");
let root: OutputRoot = serde_json::from_str(&stdout).unwrap();
let speakers: std::collections::HashSet<String> =
root.items.into_iter().map(|e| e.speaker).collect();
assert!(speakers.contains("Alpha"), "Alpha not found in speakers");
assert!(speakers.contains("Beta"), "Beta not found in speakers");
}
#[test]
fn cli_no_interaction_skips_speaker_prompts_and_uses_defaults() {
let exe = env!("CARGO_BIN_EXE_polyscribe");
let input1 = manifest_path("input/1-s0wlz.json");
let input2 = manifest_path("input/2-vikingowl.json");
let output = Command::new(exe)
.arg(input1.as_os_str())
.arg(input2.as_os_str())
.arg("-m")
.arg("--set-speaker-names")
.arg("--no-interaction")
.output()
.expect("failed to spawn polyscribe");
assert!(output.status.success(), "CLI did not exit successfully");
let stdout = String::from_utf8(output.stdout).expect("stdout not UTF-8");
let root: OutputRoot = serde_json::from_str(&stdout).unwrap();
let speakers: std::collections::HashSet<String> =
root.items.into_iter().map(|e| e.speaker).collect();
// Defaults should be the sanitized file stems: sanitization strips the numeric prefix, e.g. "1-s0wlz" -> "s0wlz"
assert!(speakers.contains("s0wlz"), "default s0wlz not used");
assert!(speakers.contains("vikingowl"), "default vikingowl not used");
}
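// The default-name derivation asserted above can be sketched as a small pure
// function. `default_speaker_from_stem` is a hypothetical helper for
// illustration only (the CLI's real sanitization lives in the main crate);
// it assumes stems follow the "<number>-<name>" convention used by the fixtures.
fn default_speaker_from_stem(stem: &str) -> String {
// Strip a leading all-digit prefix followed by '-', e.g. "1-s0wlz" -> "s0wlz".
match stem.split_once('-') {
Some((prefix, rest)) if !prefix.is_empty() && prefix.chars().all(|c| c.is_ascii_digit()) => {
rest.to_string()
}
_ => stem.to_string(),
}
}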
// New verbosity behavior tests
#[test]
fn verbosity_quiet_suppresses_logs_but_keeps_stdout() {
let exe = env!("CARGO_BIN_EXE_polyscribe");
let input1 = manifest_path("input/1-s0wlz.json");
let input2 = manifest_path("input/2-vikingowl.json");
let output = Command::new(exe)
.arg("-q")
.arg("-v") // ensure -q overrides -v
.arg(input1.as_os_str())
.arg(input2.as_os_str())
.arg("-m")
.output()
.expect("failed to spawn polyscribe");
assert!(output.status.success());
let stdout = String::from_utf8(output.stdout).unwrap();
assert!(
stdout.contains("\"items\""),
"stdout JSON should be present in quiet mode"
);
let stderr = String::from_utf8(output.stderr).unwrap();
assert!(
stderr.trim().is_empty(),
"stderr should be empty in quiet mode, got: {}",
stderr
);
}
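// The precedence exercised above (-q beats -v) can be modeled as a tiny pure
// function. This is an illustrative sketch only: in the real binary the flags
// are parsed by clap, and the names below are hypothetical.
#[derive(Debug, PartialEq)]
enum StderrLevel {
Quiet,   // -q: nothing on stderr; stdout JSON is untouched
Normal,  // default
Verbose, // -v: include debug logs such as "Mode: merge"
}
fn resolve_stderr_level(quiet: bool, verbose: bool) -> StderrLevel {
if quiet {
StderrLevel::Quiet // -q always wins, even when -v is also given
} else if verbose {
StderrLevel::Verbose
} else {
StderrLevel::Normal
}
}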
#[test]
fn verbosity_verbose_emits_debug_logs_on_stderr() {
let exe = env!("CARGO_BIN_EXE_polyscribe");
let input1 = manifest_path("input/1-s0wlz.json");
let input2 = manifest_path("input/2-vikingowl.json");
let output = Command::new(exe)
.arg(input1.as_os_str())
.arg(input2.as_os_str())
.arg("-m")
.arg("-v")
.output()
.expect("failed to spawn polyscribe");
assert!(output.status.success());
let stdout = String::from_utf8(output.stdout).unwrap();
assert!(stdout.contains("\"items\""));
let stderr = String::from_utf8(output.stderr).unwrap();
assert!(
stderr.contains("Mode: merge"),
"stderr should contain debug log with -v"
);
}
#[test]
fn verbosity_flag_position_is_global() {
let exe = env!("CARGO_BIN_EXE_polyscribe");
let input1 = manifest_path("input/1-s0wlz.json");
let input2 = manifest_path("input/2-vikingowl.json");
// -v before args
let out1 = Command::new(exe)
.arg("-v")
.arg(input1.as_os_str())
.arg(input2.as_os_str())
.arg("-m")
.output()
.expect("failed to spawn polyscribe");
// -v after sub-flags
let out2 = Command::new(exe)
.arg(input1.as_os_str())
.arg(input2.as_os_str())
.arg("-m")
.arg("-v")
.output()
.expect("failed to spawn polyscribe");
let s1 = String::from_utf8(out1.stderr).unwrap();
let s2 = String::from_utf8(out2.stderr).unwrap();
assert!(s1.contains("Mode: merge"));
assert!(s2.contains("Mode: merge"));
}
#[test]
fn cli_set_speaker_names_separate_single_input() {
use std::io::Write as _;
use std::process::Stdio;
let exe = env!("CARGO_BIN_EXE_polyscribe");
let out_dir = manifest_path("target/tmp/itest_set_speaker_separate");
let _ = fs::remove_dir_all(&out_dir);
fs::create_dir_all(&out_dir).unwrap();
let input1 = manifest_path("input/3-schmendrizzle.json");
let mut child = Command::new(exe)
.arg(input1.as_os_str())
.arg("--set-speaker-names")
.arg("-o")
.arg(out_dir.as_os_str())
.stdin(Stdio::piped())
.stdout(Stdio::null())
.stderr(Stdio::null())
.spawn()
.expect("failed to spawn polyscribe");
{
let stdin = child.stdin.as_mut().expect("failed to open stdin");
writeln!(stdin, "ChosenOne").unwrap();
}
let status = child.wait().expect("failed to wait on child");
assert!(status.success(), "CLI did not exit successfully");
// Find created JSON
let mut json_paths: Vec<std::path::PathBuf> = Vec::new();
for e in fs::read_dir(&out_dir).unwrap() {
let p = e.unwrap().path();
if let Some(name) = p.file_name().and_then(|s| s.to_str()) {
if name.ends_with(".json") {
json_paths.push(p.clone());
}
}
}
assert!(!json_paths.is_empty(), "no JSON outputs created");
let mut buf = String::new();
std::fs::File::open(&json_paths[0])
.unwrap()
.read_to_string(&mut buf)
.unwrap();
let root: OutputRoot = serde_json::from_str(&buf).unwrap();
assert!(root.items.iter().all(|e| e.speaker == "ChosenOne"));
let _ = fs::remove_dir_all(&out_dir);
}


@@ -1,125 +0,0 @@
// SPDX-License-Identifier: MIT
// Validation and error-handling integration tests
use std::fs;
use std::io::Read;
use std::path::PathBuf;
use std::process::Command;
fn bin() -> &'static str {
env!("CARGO_BIN_EXE_polyscribe")
}
fn manifest_path(relative: &str) -> PathBuf {
let mut p = PathBuf::from(env!("CARGO_MANIFEST_DIR"));
p.push(relative);
p
}
#[test]
fn error_on_missing_input_file() {
let exe = bin();
let missing = manifest_path("input/definitely_missing_123.json");
let out = Command::new(exe)
.arg(missing.as_os_str())
.output()
.expect("failed to run polyscribe with missing input");
assert!(!out.status.success(), "command should fail on missing input");
let stderr = String::from_utf8_lossy(&out.stderr);
assert!(
stderr.contains("Input not found") || stderr.contains("No input files provided"),
"stderr should mention missing input; got: {}",
stderr
);
}
#[test]
fn error_on_directory_as_input() {
let exe = bin();
// Use the repo's input directory which exists and is a directory
let input_dir = manifest_path("input");
let out = Command::new(exe)
.arg(input_dir.as_os_str())
.output()
.expect("failed to run polyscribe with directory input");
assert!(!out.status.success(), "command should fail on dir input");
let stderr = String::from_utf8_lossy(&out.stderr);
assert!(
stderr.contains("directory") || stderr.contains("Unsupported input type"),
"stderr should mention directory/unsupported; got: {}",
stderr
);
}
#[test]
fn error_on_no_ffmpeg_present() {
let exe = bin();
// Create a tiny temp .wav file; its contents do not matter, since the ffmpeg lookup fails first because PATH is empty
let tmp_dir = manifest_path("target/tmp/itest_no_ffmpeg");
let _ = fs::remove_dir_all(&tmp_dir);
fs::create_dir_all(&tmp_dir).unwrap();
let wav = tmp_dir.join("dummy.wav");
fs::write(&wav, b"\0\0\0\0").unwrap();
let out = Command::new(exe)
.arg(wav.as_os_str())
.arg("--language")
.arg("en")
.env("PATH", "") // simulate ffmpeg missing
.env_remove("WHISPER_MODEL")
.env(
"POLYSCRIBE_MODELS_BASE_COPY_DIR",
manifest_path("models").as_os_str(),
)
.output()
.expect("failed to run polyscribe with empty PATH");
assert!(
!out.status.success(),
"command should fail when ffmpeg is not found"
);
let stderr = String::from_utf8_lossy(&out.stderr);
assert!(
stderr.contains("ffmpeg not found") || stderr.contains("Failed to execute ffmpeg"),
"stderr should mention ffmpeg not found; got: {}",
stderr
);
}
#[cfg(unix)]
#[test]
fn error_on_readonly_output_dir() {
use std::os::unix::fs::PermissionsExt;
let exe = bin();
let input1 = manifest_path("input/1-s0wlz.json");
// Prepare a read-only directory
let tmp_dir = manifest_path("target/tmp/itest_readonly_out");
let _ = fs::remove_dir_all(&tmp_dir);
fs::create_dir_all(&tmp_dir).unwrap();
let mut perms = fs::metadata(&tmp_dir).unwrap().permissions();
perms.set_mode(0o555); // read & execute, no write
fs::set_permissions(&tmp_dir, perms).unwrap();
let out = Command::new(exe)
.arg(input1.as_os_str())
.arg("-o")
.arg(tmp_dir.as_os_str())
.output()
.expect("failed to run polyscribe with read-only output dir");
// Restore perms for cleanup
let mut perms2 = fs::metadata(&tmp_dir).unwrap().permissions();
perms2.set_mode(0o755);
let _ = fs::set_permissions(&tmp_dir, perms2);
assert!(
!out.status.success(),
"command should fail when outputs cannot be created"
);
let stderr = String::from_utf8_lossy(&out.stderr);
assert!(
stderr.contains("Failed to create output") || stderr.contains("permission"),
"stderr should mention failure to create outputs; got: {}",
stderr
);
// Cleanup
let _ = fs::remove_dir_all(&tmp_dir);
}

tests/smoke.rs Normal file

@@ -0,0 +1,6 @@
#[test]
fn smoke_compiles_and_runs() {
// This test ensures the test harness compiles, links, and runs without exercising the CLI.
// Reaching the end of the function body is the success condition; an `assert!(true)`
// would only trip clippy's assertions_on_constants lint.
}