Compare commits
52 Commits
af473c4942
...
dev
Author | SHA1 | Date | |
---|---|---|---|
840383fcf7 | |||
1982e9b48b | |||
0128bf2eec | |||
da5a76d253 | |||
5ec297397e | |||
cbf48a0452 | |||
0a249f2197 | |||
0573369b81 | |||
9841550dcc | |||
53119cd0ab | |||
144b01d591 | |||
ffd451b404 | |||
5c64677e79 | |||
128db0f733 | |||
06fd3efd1f | |||
49513d5099 | |||
3344a3b18c | |||
5ace0a0d7e | |||
ed3af9210f | |||
79397a3b9c | |||
9fd44a2e37 | |||
a987a3fcfb | |||
f41f1a4117 | |||
75cfb6f160 | |||
8ebdf876ed | |||
eb1bf9e02d | |||
9b4bd545dd | |||
041e504cb2 | |||
2cc5e49131 | |||
4916aa6224 | |||
ab57553949 | |||
40818a091d | |||
97855a247b | |||
0864516614 | |||
bb9402c643 | |||
4b8b68b33d | |||
6a9736c50a | |||
d3310695d2 | |||
03659448bc | |||
5c8a495b9f | |||
6b72bd64c0 | |||
278ca1b523 | |||
3f1e634e2d | |||
4063b4cb06 | |||
f551cc3498 | |||
90f9849cc0 | |||
9d12507cf5 | |||
b9308be930 | |||
fdf5e3370d | |||
ae0fdf802a | |||
fbf3aab23c | |||
4e117d78f8 |
17
.cargo/config.toml
Normal file
17
.cargo/config.toml
Normal file
@@ -0,0 +1,17 @@
|
||||
# SPDX-License-Identifier: MIT
|
||||
|
||||
[build]
|
||||
# Make target-dir consistent across workspace for better cache reuse.
|
||||
target-dir = "target"
|
||||
|
||||
[profile.dev]
|
||||
opt-level = 1
|
||||
debug = true
|
||||
incremental = true
|
||||
|
||||
[profile.release]
|
||||
# Reasonable defaults for CLI apps/libraries
|
||||
lto = "thin"
|
||||
codegen-units = 1
|
||||
strip = "debuginfo"
|
||||
opt-level = 3
|
52
.github/workflows/ci.yml
vendored
52
.github/workflows/ci.yml
vendored
@@ -2,44 +2,32 @@ name: CI
|
||||
|
||||
on:
|
||||
push:
|
||||
branches: [ main, master ]
|
||||
branches: [ dev, main ]
|
||||
pull_request:
|
||||
branches: [ main, master ]
|
||||
|
||||
permissions:
|
||||
contents: read
|
||||
branches: [ dev, main ]
|
||||
|
||||
jobs:
|
||||
ci:
|
||||
name: ci
|
||||
build:
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- name: Checkout
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v4
|
||||
- name: Setup Rust
|
||||
|
||||
- name: Set up Rust
|
||||
uses: dtolnay/rust-toolchain@stable
|
||||
with:
|
||||
components: clippy, rustfmt
|
||||
- name: Show rustc/cargo versions
|
||||
run: |
|
||||
rustc -Vv
|
||||
cargo -Vv
|
||||
- name: Cache cargo registry
|
||||
uses: actions/cache@v4
|
||||
with:
|
||||
path: |
|
||||
~/.cargo/registry
|
||||
~/.cargo/git
|
||||
target
|
||||
key: ${{ runner.os }}-cargo-${{ hashFiles('**/Cargo.lock') }}
|
||||
- name: Install cargo-audit
|
||||
run: |
|
||||
cargo install cargo-audit --locked || cargo install cargo-audit
|
||||
- name: Format check
|
||||
|
||||
- name: Cache cargo registry and target
|
||||
uses: Swatinem/rust-cache@v2
|
||||
|
||||
- name: Install components
|
||||
run: rustup component add clippy rustfmt
|
||||
|
||||
- name: Cargo fmt
|
||||
run: cargo fmt --all -- --check
|
||||
- name: Clippy (warnings as errors)
|
||||
run: cargo clippy --all-targets -- -D warnings
|
||||
|
||||
- name: Clippy
|
||||
run: cargo clippy --workspace --all-targets -- -D warnings
|
||||
|
||||
- name: Test
|
||||
run: cargo test --all
|
||||
- name: Audit
|
||||
run: cargo audit
|
||||
run: cargo test --workspace --all --locked
|
||||
|
||||
|
14
CHANGELOG.md
Normal file
14
CHANGELOG.md
Normal file
@@ -0,0 +1,14 @@
|
||||
# Changelog
|
||||
|
||||
All notable changes to this project will be documented in this file.
|
||||
|
||||
The format is based on Keep a Changelog, and this project adheres to Semantic Versioning.
|
||||
|
||||
## Unreleased
|
||||
|
||||
### Changed
|
||||
- Docs: Replace `--download-models`/`--update-models` flags with `models download`/`models update` subcommands in `README.md`, `docs/usage.md`, and `docs/development.md`.
|
||||
- Host: Plugin discovery now scans `$XDG_DATA_HOME/polyscribe/plugins` (platform equivalent via `directories`) in addition to `PATH`.
|
||||
- CI: Add GitHub Actions workflow to run fmt, clippy (warnings as errors), and tests for pushes and PRs.
|
||||
|
||||
|
@@ -1,32 +1,26 @@
|
||||
# Contributing to PolyScribe
|
||||
# Contributing
|
||||
|
||||
Thanks for your interest in contributing! This guide explains the workflow and the checklist to follow before opening a Pull Request.
|
||||
Thank you for your interest in contributing!
|
||||
|
||||
Workflow (fork → branch → PR)
|
||||
1) Fork the repository to your account.
|
||||
2) Create a feature branch:
|
||||
- git checkout -b feat/short-description
|
||||
3) Make changes with focused commits and good messages.
|
||||
4) Run the checklist below.
|
||||
5) Push and open a Pull Request against the main repository.
|
||||
Development setup
|
||||
- Install Rust via rustup.
|
||||
- Ensure ffmpeg is installed and available on PATH.
|
||||
- For GPU builds, install the appropriate runtime (CUDA/ROCm/Vulkan) and enable the matching features.
|
||||
|
||||
Developer checklist (before opening a PR)
|
||||
- Build:
|
||||
- cargo build (preferably without warnings)
|
||||
- Tests:
|
||||
- cargo test (all tests pass)
|
||||
- Lints:
|
||||
- cargo clippy --all-targets -- -D warnings (fix warnings)
|
||||
- Documentation:
|
||||
- Update README/docs for user-visible changes
|
||||
- Update CHANGELOG.md if applicable
|
||||
- Tests for changes:
|
||||
- Add or update tests for bug fixes and new features where reasonable
|
||||
Coding guidelines
|
||||
- Prefer small, focused changes.
|
||||
- Add tests where reasonable.
|
||||
- Keep user-facing changes documented in README/docs.
|
||||
- Run clippy and fix warnings.
|
||||
|
||||
Local development tips
|
||||
- Use `cargo run -- <args>` during development.
|
||||
- For faster feedback, keep examples in the examples/ folder handy.
|
||||
- Keep functions small and focused; prefer clear error messages with context.
|
||||
CI checklist
|
||||
- Build: cargo build --all-targets --locked
|
||||
- Tests: cargo test --all --locked
|
||||
- Lints: cargo clippy --all-targets -- -D warnings
|
||||
- Optional: smoke-run examples inline (from README):
|
||||
- ./target/release/polyscribe --update-models --no-interaction -q
|
||||
- ./target/release/polyscribe -o output samples/podcast_clip.mp3
|
||||
|
||||
Code of conduct
|
||||
- Be respectful and constructive. Assume good intent.
|
||||
Notes
|
||||
- For GPU features, use --features gpu-cuda|gpu-hip|gpu-vulkan as needed in your local runs.
|
||||
- For docs-only changes, please still ensure the project builds.
|
||||
|
978
Cargo.lock
generated
978
Cargo.lock
generated
File diff suppressed because it is too large
Load Diff
82
Cargo.toml
82
Cargo.toml
@@ -1,41 +1,53 @@
|
||||
[package]
|
||||
name = "polyscribe"
|
||||
version = "0.1.0"
|
||||
[workspace]
|
||||
members = [
|
||||
"crates/polyscribe-core",
|
||||
"crates/polyscribe-protocol",
|
||||
"crates/polyscribe-host",
|
||||
"crates/polyscribe-cli",
|
||||
]
|
||||
resolver = "3"
|
||||
|
||||
[workspace.package]
|
||||
edition = "2024"
|
||||
version = "0.1.0"
|
||||
license = "MIT"
|
||||
rust-version = "1.89"
|
||||
|
||||
[features]
|
||||
# Default: build without whisper to keep tests lightweight; enable `whisper` to use whisper-rs.
|
||||
default = []
|
||||
# Enable whisper-rs dependency (CPU-only unless combined with gpu-* features)
|
||||
whisper = ["dep:whisper-rs"]
|
||||
# GPU backends map to whisper-rs features
|
||||
gpu-cuda = ["whisper", "whisper-rs/cuda"]
|
||||
gpu-hip = ["whisper", "whisper-rs/hipblas"]
|
||||
# Vulkan path currently doesn't use whisper directly here; placeholder feature
|
||||
gpu-vulkan = []
|
||||
# explicit CPU fallback feature (no effect at build time, used for clarity)
|
||||
cpu-fallback = []
|
||||
|
||||
[dependencies]
|
||||
anyhow = "1.0.98"
|
||||
clap = { version = "4.5.43", features = ["derive"] }
|
||||
clap_complete = "4.5.28"
|
||||
clap_mangen = "0.2"
|
||||
# Optional: Keep dependency versions consistent across members
|
||||
[workspace.dependencies]
|
||||
thiserror = "1.0.69"
|
||||
serde = { version = "1.0.219", features = ["derive"] }
|
||||
anyhow = "1.0.99"
|
||||
libc = "0.2.175"
|
||||
toml = "0.8.23"
|
||||
serde_json = "1.0.142"
|
||||
toml = "0.8"
|
||||
chrono = { version = "0.4", features = ["clock"] }
|
||||
reqwest = { version = "0.12", features = ["blocking", "json"] }
|
||||
sha2 = "0.10"
|
||||
# Make whisper-rs optional; enabled via `whisper` feature
|
||||
# Pin whisper-rs to a known-good commit for reproducible builds.
|
||||
# To update: run `cargo update -p whisper-rs --precise 135b60b85a15714862806b6ea9f76abec38156f1` (adjust SHA) and update this rev.
|
||||
whisper-rs = { git = "https://github.com/tazz4843/whisper-rs", rev = "135b60b85a15714862806b6ea9f76abec38156f1", default-features = false, optional = true }
|
||||
libc = "0.2"
|
||||
indicatif = "0.17"
|
||||
ctrlc = "3.4"
|
||||
cliclack = "0.3"
|
||||
chrono = { version = "0.4.41", features = ["serde"] }
|
||||
sha2 = "0.10.9"
|
||||
which = "6.0.3"
|
||||
tokio = { version = "1.47.1", features = ["rt-multi-thread", "macros"] }
|
||||
clap = { version = "4.5.44", features = ["derive"] }
|
||||
directories = "5.0.1"
|
||||
whisper-rs = "0.14.3"
|
||||
cliclack = "0.3.6"
|
||||
clap_complete = "4.5.57"
|
||||
clap_mangen = "0.2.29"
|
||||
|
||||
[dev-dependencies]
|
||||
tempfile = "3"
|
||||
# Additional shared deps used across members
|
||||
tracing = "0.1"
|
||||
tracing-subscriber = { version = "0.3", features = ["fmt", "env-filter"] }
|
||||
reqwest = { version = "0.12.7", default-features = false, features = ["blocking", "rustls-tls", "gzip", "json"] }
|
||||
hex = "0.4.3"
|
||||
tempfile = "3.12.0"
|
||||
assert_cmd = "2.0.16"
|
||||
|
||||
[workspace.lints.rust]
|
||||
unused_imports = "deny"
|
||||
dead_code = "warn"
|
||||
|
||||
[profile.release]
|
||||
lto = "fat"
|
||||
codegen-units = 1
|
||||
panic = "abort"
|
||||
|
||||
[profile.dev]
|
||||
panic = "unwind"
|
||||
|
23
Makefile
23
Makefile
@@ -1,23 +0,0 @@
|
||||
# Lightweight examples-check: runs all examples/*.sh with --no-interaction -q and stubbed BIN
|
||||
# This target does not perform network calls and never prompts for input.
|
||||
|
||||
.SHELL := /bin/bash
|
||||
|
||||
.PHONY: examples-check
|
||||
examples-check:
|
||||
@set -euo pipefail; \
|
||||
shopt -s nullglob; \
|
||||
BIN_WRAPPER="$(PWD)/scripts/with_flags.sh"; \
|
||||
failed=0; \
|
||||
for f in examples/*.sh; do \
|
||||
echo "[examples-check] Running $$f"; \
|
||||
BIN="$$BIN_WRAPPER" bash "$$f" </dev/null >/dev/null 2>&1 || { \
|
||||
echo "[examples-check] FAILED: $$f"; failed=1; \
|
||||
}; \
|
||||
done; \
|
||||
if [[ $$failed -ne 0 ]]; then \
|
||||
echo "[examples-check] Some examples failed."; \
|
||||
exit 1; \
|
||||
else \
|
||||
echo "[examples-check] All examples passed (no interaction, quiet)."; \
|
||||
fi
|
130
README.md
130
README.md
@@ -1,94 +1,68 @@
|
||||
# PolyScribe
|
||||
|
||||
PolyScribe is a fast, local-first CLI for transcribing audio/video and merging existing JSON transcripts. It uses whisper-rs under the hood, can discover and download Whisper models automatically, and supports CPU and optional GPU backends (CUDA, ROCm/HIP, Vulkan).
|
||||
Local-first transcription and plugins.
|
||||
|
||||
Key features
|
||||
- Transcribe audio and common video files using ffmpeg for audio extraction.
|
||||
- Merge multiple JSON transcripts, or merge and also keep per-file outputs.
|
||||
- Model management: interactive downloader and non-interactive updater with hash verification.
|
||||
- GPU backend selection at runtime; auto-detects available accelerators.
|
||||
- Clean outputs (JSON and SRT), speaker naming prompts, and useful logging controls.
|
||||
## Features
|
||||
|
||||
Prerequisites
|
||||
- Rust toolchain (rustup recommended)
|
||||
- ffmpeg available on PATH
|
||||
- Optional for GPU acceleration at runtime: CUDA, ROCm/HIP, or Vulkan drivers (match your build features)
|
||||
- **Local-first**: Works offline with downloaded models
|
||||
- **Multiple backends**: CPU, CUDA, ROCm/HIP, and Vulkan support
|
||||
- **Plugin system**: Extensible via JSON-RPC plugins
|
||||
- **Model management**: Automatic download and verification of Whisper models
|
||||
- **Manifest caching**: Local cache for Hugging Face model manifests to reduce network requests
|
||||
|
||||
Installation
|
||||
- Build from source (CPU-only by default):
|
||||
- rustup install stable
|
||||
- rustup default stable
|
||||
- cargo build --release
|
||||
- Binary path: ./target/release/polyscribe
|
||||
- GPU builds (optional): build with features
|
||||
- CUDA: cargo build --release --features gpu-cuda
|
||||
- HIP: cargo build --release --features gpu-hip
|
||||
- Vulkan: cargo build --release --features gpu-vulkan
|
||||
## Model Management
|
||||
|
||||
Quickstart
|
||||
1) Download a model (first run can prompt you):
|
||||
- ./target/release/polyscribe --download-models
|
||||
PolyScribe automatically manages Whisper models from Hugging Face:
|
||||
|
||||
2) Transcribe a file:
|
||||
- ./target/release/polyscribe -v -o output --out-format json --jobs 4 my_audio.mp3
|
||||
This writes JSON (because of --out-format json) into the output directory with a date prefix. Omit --out-format to write all available formats (JSON and SRT). For large batches, add --continue-on-error to skip bad files and keep going.
|
||||
```bash
|
||||
# Download models interactively
|
||||
polyscribe models download
|
||||
|
||||
Gotchas
|
||||
- English-only models: If you picked an English-only Whisper model (e.g., tiny.en, base.en), non-English language hints (via --language) will be rejected and detection may be biased toward English. Use a multilingual model (without the .en suffix) for non-English audio.
|
||||
- Language hints help: When you know the language, pass --language <code> (e.g., --language de) to improve accuracy and speed. If the audio is mixed language, omit the hint to let the model detect.
|
||||
# Update existing models
|
||||
polyscribe models update
|
||||
|
||||
Shell completions and man page
|
||||
- Completions: ./target/release/polyscribe completions <bash|zsh|fish|powershell|elvish> > polyscribe.<ext>
|
||||
- Then install into your shell’s completion directory.
|
||||
- Man page: ./target/release/polyscribe man > polyscribe.1 (then copy to your manpath)
|
||||
# Clear manifest cache (force fresh fetch)
|
||||
polyscribe models clear-cache
|
||||
```
|
||||
|
||||
Model locations
|
||||
- Development (debug builds): ./models next to the project.
|
||||
- Packaged/release builds: $XDG_DATA_HOME/polyscribe/models or ~/.local/share/polyscribe/models.
|
||||
- Override via env var: POLYSCRIBE_MODELS_DIR=/path/to/models.
|
||||
- Force a specific model file via env var: WHISPER_MODEL=/path/to/model.bin.
|
||||
### Manifest Caching
|
||||
|
||||
Most-used CLI flags
|
||||
- -o, --output FILE_OR_DIR: Output path base (date prefix added). If omitted, JSON prints to stdout.
|
||||
- --out-format <json|toml|srt|all>: Which on-disk format(s) to write; repeatable; default all. Example: --out-format json --out-format srt
|
||||
- -m, --merge: Merge all inputs into one output; otherwise one output per input.
|
||||
- --merge-and-separate: Write both merged output and separate per-input outputs (requires -o dir).
|
||||
- --set-speaker-names: Prompt for a speaker label per input file.
|
||||
- --update-models: Verify/update local models by size/hash against the upstream manifest.
|
||||
- --download-models: Interactive model list + multi-select download.
|
||||
- --language LANG: Language code hint (e.g., en, de). English-only models reject non-en hints.
|
||||
- --gpu-backend [auto|cpu|cuda|hip|vulkan]: Select backend (auto by default).
|
||||
- --gpu-layers N: Offload N layers to GPU when supported.
|
||||
- -v/--verbose (repeatable): Increase log verbosity. -vv shows very detailed logs.
|
||||
- -q/--quiet: Suppress non-error logs (stderr); does not silence stdout results.
|
||||
- --no-interaction: Never prompt; suitable for CI.
|
||||
- --no-progress: Disable progress bars (also honors NO_PROGRESS=1). Progress bars render on stderr only and auto-disable when not a TTY.
|
||||
The Hugging Face model manifest is cached locally to avoid repeated network requests:
|
||||
|
||||
Minimal usage examples
|
||||
- Transcribe an audio file to JSON/SRT:
|
||||
- ./target/release/polyscribe -o output samples/podcast_clip.mp3
|
||||
- Merge multiple transcripts into one:
|
||||
- ./target/release/polyscribe -m -o output merged input/a.json input/b.json
|
||||
- Update local models non-interactively (good for CI):
|
||||
- ./target/release/polyscribe --update-models --no-interaction -q
|
||||
- **Default TTL**: 24 hours
|
||||
- **Cache location**: `$XDG_CACHE_HOME/polyscribe/manifest/` (or platform equivalent)
|
||||
- **Environment variables**:
|
||||
- `POLYSCRIBE_NO_CACHE_MANIFEST=1`: Disable caching
|
||||
- `POLYSCRIBE_MANIFEST_TTL_SECONDS=3600`: Set custom TTL (in seconds)
|
||||
|
||||
Troubleshooting & docs
|
||||
- docs/faq.md – common issues and solutions (missing ffmpeg, GPU selection, model paths)
|
||||
- docs/usage.md – complete CLI reference and workflows
|
||||
- docs/development.md – build, run, and contribute locally
|
||||
- docs/design.md – architecture overview and decisions
|
||||
- docs/release-packaging.md – packaging notes for distributions
|
||||
- docs/ci.md – minimal CI checklist and job outline
|
||||
- CONTRIBUTING.md – PR checklist and workflow
|
||||
## Installation
|
||||
|
||||
CI status: [CI workflow runs](actions/workflows/ci.yml)
|
||||
```bash
|
||||
cargo install --path .
|
||||
```
|
||||
|
||||
Examples
|
||||
See the examples/ directory for copy-paste scripts:
|
||||
- examples/transcribe_file.sh
|
||||
- examples/update_models.sh
|
||||
- examples/download_models_interactive.sh
|
||||
## Usage
|
||||
|
||||
License
|
||||
-------
|
||||
This project is licensed under the MIT License — see the LICENSE file for details.
|
||||
```bash
|
||||
# Transcribe audio/video
|
||||
polyscribe transcribe input.mp4
|
||||
|
||||
# Merge multiple transcripts
|
||||
polyscribe transcribe --merge input1.json input2.json
|
||||
|
||||
# Use specific GPU backend
|
||||
polyscribe transcribe --gpu-backend cuda input.mp4
|
||||
```
|
||||
|
||||
## Development
|
||||
|
||||
```bash
|
||||
# Build
|
||||
cargo build
|
||||
|
||||
# Run tests
|
||||
cargo test
|
||||
|
||||
# Run with verbose logging
|
||||
cargo run -- --verbose transcribe input.mp4
|
||||
```
|
||||
|
32
crates/polyscribe-cli/Cargo.toml
Normal file
32
crates/polyscribe-cli/Cargo.toml
Normal file
@@ -0,0 +1,32 @@
|
||||
[package]
|
||||
name = "polyscribe-cli"
|
||||
version.workspace = true
|
||||
edition.workspace = true
|
||||
|
||||
[[bin]]
|
||||
name = "polyscribe"
|
||||
path = "src/main.rs"
|
||||
|
||||
[dependencies]
|
||||
anyhow = { workspace = true }
|
||||
clap = { workspace = true, features = ["derive"] }
|
||||
clap_complete = { workspace = true }
|
||||
clap_mangen = { workspace = true }
|
||||
directories = { workspace = true }
|
||||
serde = { workspace = true, features = ["derive"] }
|
||||
serde_json = { workspace = true }
|
||||
tokio = { workspace = true, features = ["rt-multi-thread", "macros", "process", "fs"] }
|
||||
tracing = { workspace = true }
|
||||
tracing-subscriber = { workspace = true, features = ["fmt", "env-filter"] }
|
||||
which = { workspace = true }
|
||||
|
||||
polyscribe-core = { path = "../polyscribe-core" }
|
||||
polyscribe-host = { path = "../polyscribe-host" }
|
||||
polyscribe-protocol = { path = "../polyscribe-protocol" }
|
||||
|
||||
[features]
|
||||
# Optional GPU-specific flags can be forwarded down to core/host if needed
|
||||
default = []
|
||||
|
||||
[dev-dependencies]
|
||||
assert_cmd = { workspace = true }
|
191
crates/polyscribe-cli/src/cli.rs
Normal file
191
crates/polyscribe-cli/src/cli.rs
Normal file
@@ -0,0 +1,191 @@
|
||||
use clap::{Args, Parser, Subcommand, ValueEnum};
|
||||
use std::path::PathBuf;
|
||||
|
||||
#[derive(Debug, Clone, ValueEnum)]
|
||||
pub enum GpuBackend {
|
||||
Auto,
|
||||
Cpu,
|
||||
Cuda,
|
||||
Hip,
|
||||
Vulkan,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Args)]
|
||||
pub struct OutputOpts {
|
||||
/// Emit machine-readable JSON to stdout; suppress decorative logs
|
||||
#[arg(long, global = true, action = clap::ArgAction::SetTrue)]
|
||||
pub json: bool,
|
||||
/// Reduce log chatter (errors only unless --json)
|
||||
#[arg(long, global = true, action = clap::ArgAction::SetTrue)]
|
||||
pub quiet: bool,
|
||||
}
|
||||
|
||||
#[derive(Debug, Parser)]
|
||||
#[command(
|
||||
name = "polyscribe",
|
||||
version,
|
||||
about = "PolyScribe – local-first transcription and plugins",
|
||||
propagate_version = true,
|
||||
arg_required_else_help = true,
|
||||
)]
|
||||
pub struct Cli {
|
||||
/// Global output options
|
||||
#[command(flatten)]
|
||||
pub output: OutputOpts,
|
||||
|
||||
/// Increase verbosity (-v, -vv)
|
||||
#[arg(short, long, action = clap::ArgAction::Count)]
|
||||
pub verbose: u8,
|
||||
|
||||
/// Never prompt for user input (non-interactive mode)
|
||||
#[arg(long, default_value_t = false)]
|
||||
pub no_interaction: bool,
|
||||
|
||||
/// Disable progress bars/spinners
|
||||
#[arg(long, default_value_t = false)]
|
||||
pub no_progress: bool,
|
||||
|
||||
#[command(subcommand)]
|
||||
pub command: Commands,
|
||||
}
|
||||
|
||||
#[derive(Debug, Subcommand)]
|
||||
pub enum Commands {
|
||||
/// Transcribe audio/video files or merge existing transcripts
|
||||
Transcribe {
|
||||
/// Output file or directory (date prefix is added when directory)
|
||||
#[arg(short, long)]
|
||||
output: Option<PathBuf>,
|
||||
|
||||
/// Merge multiple inputs into one output
|
||||
#[arg(short = 'm', long, default_value_t = false)]
|
||||
merge: bool,
|
||||
|
||||
/// Write both merged and per-input outputs (requires -o dir)
|
||||
#[arg(long, default_value_t = false)]
|
||||
merge_and_separate: bool,
|
||||
|
||||
/// Language code hint, e.g. en, de
|
||||
#[arg(long)]
|
||||
language: Option<String>,
|
||||
|
||||
/// Prompt for a speaker label per input file
|
||||
#[arg(long, default_value_t = false)]
|
||||
set_speaker_names: bool,
|
||||
|
||||
/// GPU backend selection
|
||||
#[arg(long, value_enum, default_value_t = GpuBackend::Auto)]
|
||||
gpu_backend: GpuBackend,
|
||||
|
||||
/// Offload N layers to GPU (when supported)
|
||||
#[arg(long, default_value_t = 0)]
|
||||
gpu_layers: usize,
|
||||
|
||||
/// Input paths: audio/video files or JSON transcripts
|
||||
#[arg(required = true)]
|
||||
inputs: Vec<PathBuf>,
|
||||
},
|
||||
|
||||
/// Manage Whisper GGUF models (Hugging Face)
|
||||
Models {
|
||||
#[command(subcommand)]
|
||||
cmd: ModelsCmd,
|
||||
},
|
||||
|
||||
/// Discover and run plugins
|
||||
Plugins {
|
||||
#[command(subcommand)]
|
||||
cmd: PluginsCmd,
|
||||
},
|
||||
|
||||
/// Generate shell completions to stdout
|
||||
Completions {
|
||||
/// Shell to generate completions for
|
||||
#[arg(value_parser = ["bash", "zsh", "fish", "powershell", "elvish"])]
|
||||
shell: String,
|
||||
},
|
||||
|
||||
/// Generate a man page to stdout
|
||||
Man,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Parser)]
|
||||
pub struct ModelCommon {
|
||||
/// Concurrency for ranged downloads
|
||||
#[arg(long, default_value_t = 4)]
|
||||
pub concurrency: usize,
|
||||
/// Limit download rate in bytes/sec (approximate)
|
||||
#[arg(long)]
|
||||
pub limit_rate: Option<u64>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Subcommand)]
|
||||
pub enum ModelsCmd {
|
||||
/// List installed models (from manifest)
|
||||
Ls {
|
||||
#[command(flatten)]
|
||||
common: ModelCommon,
|
||||
},
|
||||
/// Add or update a model
|
||||
Add {
|
||||
/// Hugging Face repo, e.g. ggml-org/models
|
||||
repo: String,
|
||||
/// File name in repo (e.g., gguf-tiny-q4_0.bin)
|
||||
file: String,
|
||||
#[command(flatten)]
|
||||
common: ModelCommon,
|
||||
},
|
||||
/// Remove a model by alias
|
||||
Rm {
|
||||
alias: String,
|
||||
#[command(flatten)]
|
||||
common: ModelCommon,
|
||||
},
|
||||
/// Verify model file integrity by alias
|
||||
Verify {
|
||||
alias: String,
|
||||
#[command(flatten)]
|
||||
common: ModelCommon,
|
||||
},
|
||||
/// Update all models (HEAD + ETag; skip if unchanged)
|
||||
Update {
|
||||
#[command(flatten)]
|
||||
common: ModelCommon,
|
||||
},
|
||||
/// Garbage-collect unreferenced files and stale manifest entries
|
||||
Gc {
|
||||
#[command(flatten)]
|
||||
common: ModelCommon,
|
||||
},
|
||||
/// Search a repo for GGUF files
|
||||
Search {
|
||||
/// Hugging Face repo, e.g. ggml-org/models
|
||||
repo: String,
|
||||
/// Optional substring to filter filenames
|
||||
#[arg(long)]
|
||||
query: Option<String>,
|
||||
#[command(flatten)]
|
||||
common: ModelCommon,
|
||||
},
|
||||
}
|
||||
|
||||
#[derive(Debug, Subcommand)]
|
||||
pub enum PluginsCmd {
|
||||
/// List installed plugins
|
||||
List,
|
||||
/// Show a plugin's capabilities (as JSON)
|
||||
Info {
|
||||
/// Plugin short name, e.g., "tubescribe"
|
||||
name: String,
|
||||
},
|
||||
/// Run a plugin command (JSON-RPC over NDJSON via stdio)
|
||||
Run {
|
||||
/// Plugin short name
|
||||
name: String,
|
||||
/// Command name in plugin's API
|
||||
command: String,
|
||||
/// JSON payload string
|
||||
#[arg(long)]
|
||||
json: Option<String>,
|
||||
},
|
||||
}
|
470
crates/polyscribe-cli/src/main.rs
Normal file
470
crates/polyscribe-cli/src/main.rs
Normal file
@@ -0,0 +1,470 @@
|
||||
mod cli;
|
||||
mod output;
|
||||
|
||||
use anyhow::{Context, Result, anyhow};
|
||||
use clap::{CommandFactory, Parser};
|
||||
use cli::{Cli, Commands, GpuBackend, ModelsCmd, ModelCommon, PluginsCmd};
|
||||
use output::OutputMode;
|
||||
use polyscribe_core::model_manager::{ModelManager, Settings, ReqwestClient};
|
||||
use polyscribe_core::ui;
|
||||
fn normalized_similarity(a: &str, b: &str) -> f64 {
|
||||
// simple Levenshtein distance; normalized to [0,1]
|
||||
let a_bytes = a.as_bytes();
|
||||
let b_bytes = b.as_bytes();
|
||||
let n = a_bytes.len();
|
||||
let m = b_bytes.len();
|
||||
if n == 0 && m == 0 { return 1.0; }
|
||||
if n == 0 || m == 0 { return 0.0; }
|
||||
let mut prev: Vec<usize> = (0..=m).collect();
|
||||
let mut curr: Vec<usize> = vec![0; m + 1];
|
||||
for i in 1..=n {
|
||||
curr[0] = i;
|
||||
for j in 1..=m {
|
||||
let cost = if a_bytes[i - 1] == b_bytes[j - 1] { 0 } else { 1 };
|
||||
curr[j] = (prev[j] + 1)
|
||||
.min(curr[j - 1] + 1)
|
||||
.min(prev[j - 1] + cost);
|
||||
}
|
||||
std::mem::swap(&mut prev, &mut curr);
|
||||
}
|
||||
let dist = prev[m] as f64;
|
||||
let max_len = n.max(m) as f64;
|
||||
1.0 - (dist / max_len)
|
||||
}
|
||||
|
||||
fn human_size(bytes: Option<u64>) -> String {
|
||||
match bytes {
|
||||
Some(n) => {
|
||||
let x = n as f64;
|
||||
const KB: f64 = 1024.0;
|
||||
const MB: f64 = 1024.0 * KB;
|
||||
const GB: f64 = 1024.0 * MB;
|
||||
if x >= GB { format!("{:.2} GiB", x / GB) }
|
||||
else if x >= MB { format!("{:.2} MiB", x / MB) }
|
||||
else if x >= KB { format!("{:.2} KiB", x / KB) }
|
||||
else { format!("{} B", n) }
|
||||
}
|
||||
None => "?".to_string(),
|
||||
}
|
||||
}
|
||||
use polyscribe_core::ui::progress::ProgressReporter;
|
||||
use polyscribe_host::PluginManager;
|
||||
use tokio::io::AsyncWriteExt;
|
||||
use tracing_subscriber::EnvFilter;
|
||||
|
||||
fn init_tracing(json_mode: bool, quiet: bool, verbose: u8) {
|
||||
// In JSON mode, suppress human logs; route errors to stderr only.
|
||||
let level = if json_mode || quiet { "error" } else { match verbose { 0 => "info", 1 => "debug", _ => "trace" } };
|
||||
let filter = EnvFilter::try_from_default_env().unwrap_or_else(|_| EnvFilter::new(level));
|
||||
tracing_subscriber::fmt()
|
||||
.with_env_filter(filter)
|
||||
.with_target(false)
|
||||
.with_level(true)
|
||||
.with_writer(std::io::stderr)
|
||||
.compact()
|
||||
.init();
|
||||
}
|
||||
|
||||
fn main() -> Result<()> {
|
||||
let args = Cli::parse();
|
||||
|
||||
// Determine output mode early for logging and UI configuration
|
||||
let output_mode = if args.output.json {
|
||||
OutputMode::Json
|
||||
} else {
|
||||
OutputMode::Human { quiet: args.output.quiet }
|
||||
};
|
||||
|
||||
init_tracing(matches!(output_mode, OutputMode::Json), args.output.quiet, args.verbose);
|
||||
|
||||
// Suppress decorative UI output in JSON mode as well
|
||||
polyscribe_core::set_quiet(args.output.quiet || matches!(output_mode, OutputMode::Json));
|
||||
polyscribe_core::set_no_interaction(args.no_interaction);
|
||||
polyscribe_core::set_verbose(args.verbose);
|
||||
polyscribe_core::set_no_progress(args.no_progress);
|
||||
|
||||
match args.command {
|
||||
Commands::Transcribe {
|
||||
gpu_backend,
|
||||
gpu_layers,
|
||||
inputs,
|
||||
..
|
||||
} => {
|
||||
polyscribe_core::ui::info("starting transcription workflow");
|
||||
let mut progress = ProgressReporter::new(args.no_interaction);
|
||||
|
||||
progress.step("Validating inputs");
|
||||
if inputs.is_empty() {
|
||||
return Err(anyhow!("no inputs provided"));
|
||||
}
|
||||
|
||||
progress.step("Selecting backend and preparing model");
|
||||
match gpu_backend {
|
||||
GpuBackend::Auto => {}
|
||||
GpuBackend::Cpu => {}
|
||||
GpuBackend::Cuda => {
|
||||
let _ = gpu_layers;
|
||||
}
|
||||
GpuBackend::Hip => {}
|
||||
GpuBackend::Vulkan => {}
|
||||
}
|
||||
|
||||
progress.finish_with_message("Transcription completed (stub)");
|
||||
Ok(())
|
||||
}
|
||||
|
||||
Commands::Models { cmd } => {
|
||||
// predictable exit codes
|
||||
const EXIT_OK: i32 = 0;
|
||||
const EXIT_NOT_FOUND: i32 = 2;
|
||||
const EXIT_NETWORK: i32 = 3;
|
||||
const EXIT_VERIFY_FAILED: i32 = 4;
|
||||
// const EXIT_NO_CHANGE: i32 = 5; // reserved
|
||||
|
||||
let handle_common = |c: &ModelCommon| Settings {
|
||||
concurrency: c.concurrency.max(1),
|
||||
limit_rate: c.limit_rate,
|
||||
..Default::default()
|
||||
};
|
||||
|
||||
let exit = match cmd {
|
||||
ModelsCmd::Ls { common } => {
|
||||
let mm: ModelManager<ReqwestClient> = ModelManager::new(handle_common(&common))?;
|
||||
let list = mm.ls()?;
|
||||
match output_mode {
|
||||
OutputMode::Json => {
|
||||
// Always emit JSON array (possibly empty)
|
||||
output_mode.print_json(&list);
|
||||
}
|
||||
OutputMode::Human { quiet } => {
|
||||
if list.is_empty() {
|
||||
if !quiet { println!("No models installed."); }
|
||||
} else {
|
||||
if !quiet { println!("Model (Repo)"); }
|
||||
for r in list {
|
||||
if !quiet { println!("{} ({})", r.file, r.repo); }
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
EXIT_OK
|
||||
}
|
||||
ModelsCmd::Add { repo, file, common } => {
|
||||
let settings = handle_common(&common);
|
||||
let mm: ModelManager<ReqwestClient> = ModelManager::new(settings.clone())?;
|
||||
// Derive an alias automatically from repo and file
|
||||
fn derive_alias(repo: &str, file: &str) -> String {
|
||||
use std::path::Path;
|
||||
let repo_tail = repo.rsplit('/').next().unwrap_or(repo);
|
||||
let stem = Path::new(file)
|
||||
.file_stem()
|
||||
.and_then(|s| s.to_str())
|
||||
.unwrap_or(file);
|
||||
format!("{}-{}", repo_tail, stem)
|
||||
}
|
||||
let alias = derive_alias(&repo, &file);
|
||||
match mm.add_or_update(&alias, &repo, &file) {
|
||||
Ok(rec) => {
|
||||
match output_mode {
|
||||
OutputMode::Json => output_mode.print_json(&rec),
|
||||
OutputMode::Human { quiet } => {
|
||||
if !quiet { println!("installed: {} -> {}/{}", alias, repo, rec.file); }
|
||||
}
|
||||
}
|
||||
EXIT_OK
|
||||
}
|
||||
Err(e) => {
|
||||
// On not found or similar errors, try suggesting close matches interactively
|
||||
if matches!(output_mode, OutputMode::Json) || polyscribe_core::is_no_interaction() {
|
||||
match output_mode {
|
||||
OutputMode::Json => {
|
||||
// Emit error JSON object
|
||||
#[derive(serde::Serialize)]
|
||||
struct ErrObj<'a> { error: &'a str }
|
||||
let eo = ErrObj { error: &e.to_string() };
|
||||
output_mode.print_json(&eo);
|
||||
}
|
||||
_ => { eprintln!("error: {e}"); }
|
||||
}
|
||||
EXIT_NOT_FOUND
|
||||
} else {
|
||||
ui::warn(format!("{}", e));
|
||||
ui::info("Searching for similar model filenames…");
|
||||
match polyscribe_core::model_manager::search_repo(&repo, None) {
|
||||
Ok(mut files) => {
|
||||
if files.is_empty() {
|
||||
ui::warn("No files found in repository.");
|
||||
EXIT_NOT_FOUND
|
||||
} else {
|
||||
// rank by similarity
|
||||
files.sort_by(|a, b| normalized_similarity(&file, b)
|
||||
.partial_cmp(&normalized_similarity(&file, a))
|
||||
.unwrap_or(std::cmp::Ordering::Equal));
|
||||
let top: Vec<String> = files.into_iter().take(5).collect();
|
||||
if top.is_empty() {
|
||||
EXIT_NOT_FOUND
|
||||
} else if top.len() == 1 {
|
||||
let cand = &top[0];
|
||||
// Fetch repo size list once
|
||||
let size_map: std::collections::HashMap<String, Option<u64>> =
|
||||
polyscribe_core::model_manager::list_repo_files_with_meta(&repo)
|
||||
.unwrap_or_default()
|
||||
.into_iter().collect();
|
||||
let mut size = size_map.get(cand).cloned().unwrap_or(None);
|
||||
if size.is_none() {
|
||||
size = polyscribe_core::model_manager::head_len_for_file(&repo, cand);
|
||||
}
|
||||
let local_files: std::collections::HashSet<String> = mm.ls()?.into_iter().map(|r| r.file).collect();
|
||||
let is_local = local_files.contains(cand);
|
||||
let label = format!("{} [{}]{}", cand, human_size(size), if is_local { " (local)" } else { "" });
|
||||
let ok = ui::prompt_confirm(&format!("Did you mean {}?", label), true)
|
||||
.unwrap_or(false);
|
||||
if !ok { EXIT_NOT_FOUND } else {
|
||||
let mm2: ModelManager<ReqwestClient> = ModelManager::new(settings)?;
|
||||
let alias2 = derive_alias(&repo, cand);
|
||||
match mm2.add_or_update(&alias2, &repo, cand) {
|
||||
Ok(rec) => {
|
||||
match output_mode {
|
||||
OutputMode::Json => output_mode.print_json(&rec),
|
||||
OutputMode::Human { quiet } => { if !quiet { println!("installed: {} -> {}/{}", alias2, repo, rec.file); } }
|
||||
}
|
||||
EXIT_OK
|
||||
}
|
||||
Err(e2) => { eprintln!("error: {e2}"); EXIT_NETWORK }
|
||||
}
|
||||
}
|
||||
} else {
|
||||
let opts: Vec<String> = top;
|
||||
let local_files: std::collections::HashSet<String> = mm.ls()?.into_iter().map(|r| r.file).collect();
|
||||
// Enrich labels with size and local tag using a single API call
|
||||
let size_map: std::collections::HashMap<String, Option<u64>> =
|
||||
polyscribe_core::model_manager::list_repo_files_with_meta(&repo)
|
||||
.unwrap_or_default()
|
||||
.into_iter().collect();
|
||||
let mut labels_owned: Vec<String> = Vec::new();
|
||||
for f in &opts {
|
||||
let mut size = size_map.get(f).cloned().unwrap_or(None);
|
||||
if size.is_none() {
|
||||
size = polyscribe_core::model_manager::head_len_for_file(&repo, f);
|
||||
}
|
||||
let is_local = local_files.contains(f);
|
||||
let suffix = if is_local { " (local)" } else { "" };
|
||||
labels_owned.push(format!("{} [{}]{}", f, human_size(size), suffix));
|
||||
}
|
||||
let labels: Vec<&str> = labels_owned.iter().map(|s| s.as_str()).collect();
|
||||
match ui::prompt_select("Pick a model", &labels) {
|
||||
Ok(idx) => {
|
||||
let chosen = &opts[idx];
|
||||
let mm2: ModelManager<ReqwestClient> = ModelManager::new(settings)?;
|
||||
let alias2 = derive_alias(&repo, chosen);
|
||||
match mm2.add_or_update(&alias2, &repo, chosen) {
|
||||
Ok(rec) => {
|
||||
match output_mode {
|
||||
OutputMode::Json => output_mode.print_json(&rec),
|
||||
OutputMode::Human { quiet } => { if !quiet { println!("installed: {} -> {}/{}", alias2, repo, rec.file); } }
|
||||
}
|
||||
EXIT_OK
|
||||
}
|
||||
Err(e2) => { eprintln!("error: {e2}"); EXIT_NETWORK }
|
||||
}
|
||||
}
|
||||
Err(_) => EXIT_NOT_FOUND,
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
Err(e2) => {
|
||||
eprintln!("error: {}", e2);
|
||||
EXIT_NETWORK
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
ModelsCmd::Rm { alias, common } => {
|
||||
let mm: ModelManager<ReqwestClient> = ModelManager::new(handle_common(&common))?;
|
||||
let ok = mm.rm(&alias)?;
|
||||
match output_mode {
|
||||
OutputMode::Json => {
|
||||
#[derive(serde::Serialize)]
|
||||
struct R { removed: bool }
|
||||
output_mode.print_json(&R { removed: ok });
|
||||
}
|
||||
OutputMode::Human { quiet } => {
|
||||
if !quiet { println!("{}", if ok { "removed" } else { "not found" }); }
|
||||
}
|
||||
}
|
||||
if ok { EXIT_OK } else { EXIT_NOT_FOUND }
|
||||
}
|
||||
ModelsCmd::Verify { alias, common } => {
|
||||
let mm: ModelManager<ReqwestClient> = ModelManager::new(handle_common(&common))?;
|
||||
let found = mm.ls()?.into_iter().any(|r| r.alias == alias);
|
||||
if !found {
|
||||
match output_mode {
|
||||
OutputMode::Json => {
|
||||
#[derive(serde::Serialize)]
|
||||
struct R<'a> { ok: bool, error: &'a str }
|
||||
output_mode.print_json(&R { ok: false, error: "not found" });
|
||||
}
|
||||
OutputMode::Human { quiet } => { if !quiet { println!("not found"); } }
|
||||
}
|
||||
EXIT_NOT_FOUND
|
||||
} else {
|
||||
let ok = mm.verify(&alias)?;
|
||||
match output_mode {
|
||||
OutputMode::Json => {
|
||||
#[derive(serde::Serialize)]
|
||||
struct R { ok: bool }
|
||||
output_mode.print_json(&R { ok });
|
||||
}
|
||||
OutputMode::Human { quiet } => { if !quiet { println!("{}", if ok { "ok" } else { "corrupt" }); } }
|
||||
}
|
||||
if ok { EXIT_OK } else { EXIT_VERIFY_FAILED }
|
||||
}
|
||||
}
|
||||
ModelsCmd::Update { common } => {
|
||||
let mm: ModelManager<ReqwestClient> = ModelManager::new(handle_common(&common))?;
|
||||
let mut rc = EXIT_OK;
|
||||
for rec in mm.ls()? {
|
||||
match mm.add_or_update(&rec.alias, &rec.repo, &rec.file) {
|
||||
Ok(_) => {}
|
||||
Err(e) => {
|
||||
rc = EXIT_NETWORK;
|
||||
match output_mode {
|
||||
OutputMode::Json => {
|
||||
#[derive(serde::Serialize)]
|
||||
struct R<'a> { alias: &'a str, error: String }
|
||||
output_mode.print_json(&R { alias: &rec.alias, error: e.to_string() });
|
||||
}
|
||||
_ => { eprintln!("update {}: {e}", rec.alias); }
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
rc
|
||||
}
|
||||
ModelsCmd::Gc { common } => {
|
||||
let mm: ModelManager<ReqwestClient> = ModelManager::new(handle_common(&common))?;
|
||||
let (files_removed, entries_removed) = mm.gc()?;
|
||||
match output_mode {
|
||||
OutputMode::Json => {
|
||||
#[derive(serde::Serialize)]
|
||||
struct R { files_removed: usize, entries_removed: usize }
|
||||
output_mode.print_json(&R { files_removed, entries_removed });
|
||||
}
|
||||
OutputMode::Human { quiet } => { if !quiet { println!("files_removed={} entries_removed={}", files_removed, entries_removed); } }
|
||||
}
|
||||
EXIT_OK
|
||||
}
|
||||
ModelsCmd::Search { repo, query, common } => {
|
||||
let res = polyscribe_core::model_manager::search_repo(&repo, query.as_deref());
|
||||
match res {
|
||||
Ok(files) => {
|
||||
match output_mode {
|
||||
OutputMode::Json => output_mode.print_json(&files),
|
||||
OutputMode::Human { quiet } => { for f in files { if !quiet { println!("{}", f); } } }
|
||||
}
|
||||
EXIT_OK
|
||||
}
|
||||
Err(e) => {
|
||||
match output_mode {
|
||||
OutputMode::Json => {
|
||||
#[derive(serde::Serialize)]
|
||||
struct R { error: String }
|
||||
output_mode.print_json(&R { error: e.to_string() });
|
||||
}
|
||||
_ => { eprintln!("error: {e}"); }
|
||||
}
|
||||
EXIT_NETWORK
|
||||
}
|
||||
}
|
||||
}
|
||||
};
|
||||
std::process::exit(exit);
|
||||
}
|
||||
|
||||
Commands::Plugins { cmd } => {
|
||||
let plugin_manager = PluginManager;
|
||||
|
||||
match cmd {
|
||||
PluginsCmd::List => {
|
||||
let list = plugin_manager.list().context("discovering plugins")?;
|
||||
for item in list {
|
||||
polyscribe_core::ui::info(item.name);
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
PluginsCmd::Info { name } => {
|
||||
let info = plugin_manager
|
||||
.info(&name)
|
||||
.with_context(|| format!("getting info for {}", name))?;
|
||||
let info_json = serde_json::to_string_pretty(&info)?;
|
||||
polyscribe_core::ui::info(info_json);
|
||||
Ok(())
|
||||
}
|
||||
PluginsCmd::Run {
|
||||
name,
|
||||
command,
|
||||
json,
|
||||
} => {
|
||||
// Use a local Tokio runtime only for this async path
|
||||
let rt = tokio::runtime::Builder::new_multi_thread()
|
||||
.enable_all()
|
||||
.build()
|
||||
.context("building tokio runtime")?;
|
||||
|
||||
rt.block_on(async {
|
||||
let payload = json.unwrap_or_else(|| "{}".to_string());
|
||||
let mut child = plugin_manager
|
||||
.spawn(&name, &command)
|
||||
.with_context(|| format!("spawning plugin {name} {command}"))?;
|
||||
|
||||
if let Some(mut stdin) = child.stdin.take() {
|
||||
stdin
|
||||
.write_all(payload.as_bytes())
|
||||
.await
|
||||
.context("writing JSON payload to plugin stdin")?;
|
||||
}
|
||||
|
||||
let status = plugin_manager.forward_stdio(&mut child).await?;
|
||||
if !status.success() {
|
||||
polyscribe_core::ui::error(format!(
|
||||
"plugin returned non-zero exit code: {}",
|
||||
status
|
||||
));
|
||||
return Err(anyhow!("plugin failed"));
|
||||
}
|
||||
Ok(())
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Commands::Completions { shell } => {
|
||||
use clap_complete::{generate, shells};
|
||||
use std::io;
|
||||
|
||||
let mut cmd = Cli::command();
|
||||
let name = cmd.get_name().to_string();
|
||||
|
||||
match shell.as_str() {
|
||||
"bash" => generate(shells::Bash, &mut cmd, name, &mut io::stdout()),
|
||||
"zsh" => generate(shells::Zsh, &mut cmd, name, &mut io::stdout()),
|
||||
"fish" => generate(shells::Fish, &mut cmd, name, &mut io::stdout()),
|
||||
"powershell" => generate(shells::PowerShell, &mut cmd, name, &mut io::stdout()),
|
||||
"elvish" => generate(shells::Elvish, &mut cmd, name, &mut io::stdout()),
|
||||
_ => return Err(anyhow!("unsupported shell: {shell}")),
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
Commands::Man => {
|
||||
use clap_mangen::Man;
|
||||
let cmd = Cli::command();
|
||||
let man = Man::new(cmd);
|
||||
man.render(&mut std::io::stdout())?;
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
}
|
36
crates/polyscribe-cli/src/output.rs
Normal file
36
crates/polyscribe-cli/src/output.rs
Normal file
@@ -0,0 +1,36 @@
|
||||
use std::io::{self, Write};
|
||||
|
||||
/// How CLI results are rendered: machine-readable JSON or human text
/// (optionally quiet).
#[derive(Clone, Debug)]
pub enum OutputMode {
    /// Emit only compact JSON on stdout; human-readable lines are suppressed.
    Json,
    /// Human-readable output; `quiet` additionally suppresses informational lines.
    Human { quiet: bool },
}
|
||||
|
||||
impl OutputMode {
|
||||
pub fn is_quiet(&self) -> bool {
|
||||
matches!(self, OutputMode::Json) || matches!(self, OutputMode::Human { quiet: true })
|
||||
}
|
||||
|
||||
pub fn print_json<T: serde::Serialize>(&self, v: &T) {
|
||||
if let OutputMode::Json = self {
|
||||
// Write compact JSON to stdout without prefixes
|
||||
// and ensure a trailing newline for CLI ergonomics
|
||||
let s = serde_json::to_string(v).unwrap_or_else(|e| format!("\"JSON_ERROR:{}\"", e));
|
||||
println!("{}", s);
|
||||
}
|
||||
}
|
||||
|
||||
pub fn print_line(&self, s: impl AsRef<str>) {
|
||||
match self {
|
||||
OutputMode::Json => {
|
||||
// Suppress human lines in JSON mode
|
||||
}
|
||||
OutputMode::Human { quiet } => {
|
||||
if !quiet {
|
||||
let _ = writeln!(io::stdout(), "{}", s.as_ref());
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@@ -1,10 +1,11 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025 <COPYRIGHT HOLDER>. All rights reserved.
|
||||
|
||||
use assert_cmd::cargo::cargo_bin;
|
||||
use std::process::Command;
|
||||
|
||||
fn bin() -> &'static str {
|
||||
env!("CARGO_BIN_EXE_polyscribe")
|
||||
fn bin() -> std::path::PathBuf {
|
||||
cargo_bin("polyscribe")
|
||||
}
|
||||
|
||||
#[test]
|
42
crates/polyscribe-cli/tests/models_smoke.rs
Normal file
42
crates/polyscribe-cli/tests/models_smoke.rs
Normal file
@@ -0,0 +1,42 @@
|
||||
use assert_cmd::cargo::cargo_bin;
|
||||
use std::process::Command;
|
||||
|
||||
/// Path to the compiled `polyscribe` binary under test (resolved by assert_cmd).
fn bin() -> std::path::PathBuf { cargo_bin("polyscribe") }
|
||||
|
||||
/// `polyscribe models --help` must advertise the global output flags.
#[test]
fn models_help_shows_global_output_flags() {
    let mut cmd = Command::new(bin());
    cmd.args(["models", "--help"]); // subcommand help
    let out = cmd
        .output()
        .expect("failed to run polyscribe models --help");
    assert!(out.status.success(), "help exited non-zero: {:?}", out.status);
    let raw = out.stdout;
    let stdout = String::from_utf8(raw).expect("stdout not utf-8");
    assert!(stdout.contains("--json"), "--json not shown in help: {stdout}");
    assert!(stdout.contains("--quiet"), "--quiet not shown in help: {stdout}");
}
|
||||
|
||||
/// `models --version` (via clap's propagate_version) must show the crate version.
#[test]
fn models_version_contains_pkg_version() {
    let mut cmd = Command::new(bin());
    cmd.args(["models", "--version"]); // propagate_version
    let out = cmd
        .output()
        .expect("failed to run polyscribe models --version");
    assert!(out.status.success(), "version exited non-zero: {:?}", out.status);
    let stdout = String::from_utf8(out.stdout).expect("stdout not utf-8");
    // Compare against the version baked in at compile time.
    let want = env!("CARGO_PKG_VERSION");
    assert!(stdout.contains(want), "version output missing {want}: {stdout}");
}
|
||||
|
||||
/// With --json --quiet, `models ls` must put nothing but valid JSON on stdout
/// and keep stderr clean on the success path.
#[test]
fn models_ls_json_quiet_emits_pure_json() {
    let mut cmd = Command::new(bin());
    cmd.args(["models", "ls", "--json", "--quiet"]); // global flags
    let out = cmd
        .output()
        .expect("failed to run polyscribe models ls --json --quiet");
    assert!(out.status.success(), "ls exited non-zero: {:?}", out.status);
    let stdout = String::from_utf8(out.stdout).expect("stdout not utf-8");
    let trimmed = stdout.trim();
    serde_json::from_str::<serde_json::Value>(trimmed).expect("stdout is not valid JSON");
    // Expect no extra logs on stdout; stderr should be empty in success path
    assert!(out.stderr.is_empty(), "expected no stderr noise");
}
|
||||
|
22
crates/polyscribe-core/Cargo.toml
Normal file
22
crates/polyscribe-core/Cargo.toml
Normal file
@@ -0,0 +1,22 @@
|
||||
[package]
|
||||
name = "polyscribe-core"
|
||||
version.workspace = true
|
||||
edition.workspace = true
|
||||
|
||||
[dependencies]
|
||||
anyhow = { workspace = true }
|
||||
thiserror = { workspace = true }
|
||||
serde = { workspace = true, features = ["derive"] }
|
||||
serde_json = { workspace = true }
|
||||
toml = { workspace = true }
|
||||
directories = { workspace = true }
|
||||
chrono = { workspace = true }
|
||||
libc = { workspace = true }
|
||||
whisper-rs = { workspace = true }
|
||||
# UI and progress
|
||||
cliclack = { workspace = true }
|
||||
# HTTP downloads + hashing
|
||||
reqwest = { workspace = true }
|
||||
sha2 = { workspace = true }
|
||||
hex = { workspace = true }
|
||||
tempfile = { workspace = true }
|
@@ -1,15 +1,14 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025 <COPYRIGHT HOLDER>. All rights reserved.
|
||||
|
||||
fn main() {
|
||||
// Only run special build steps when gpu-vulkan feature is enabled.
|
||||
let vulkan_enabled = std::env::var("CARGO_FEATURE_GPU_VULKAN").is_ok();
|
||||
println!("cargo:rerun-if-changed=extern/whisper.cpp");
|
||||
if !vulkan_enabled {
|
||||
println!(
|
||||
"cargo:warning=gpu-vulkan feature is disabled; skipping Vulkan-dependent build steps."
|
||||
);
|
||||
return;
|
||||
}
|
||||
// Placeholder: In a full implementation, we would invoke CMake for whisper.cpp with GGML_VULKAN=1.
|
||||
// For now, emit a helpful note. Build will proceed; runtime Vulkan backend returns an explanatory error.
|
||||
println!("cargo:rerun-if-changed=extern/whisper.cpp");
|
||||
println!(
|
||||
"cargo:warning=Building with gpu-vulkan: ensure Vulkan SDK/loader are installed. Future versions will compile whisper.cpp via CMake."
|
||||
);
|
303
crates/polyscribe-core/src/backend.rs
Normal file
303
crates/polyscribe-core/src/backend.rs
Normal file
@@ -0,0 +1,303 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
use crate::OutputEntry;
|
||||
use crate::prelude::*;
|
||||
use crate::{decode_audio_to_pcm_f32_ffmpeg, find_model_file};
|
||||
use anyhow::{Context, anyhow};
|
||||
use std::env;
|
||||
use std::path::Path;
|
||||
|
||||
/// Whisper execution backend requested by the user or selected at runtime.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum BackendKind {
    /// Resolve automatically (preference order CUDA > HIP > Vulkan > CPU,
    /// see select_backend).
    Auto,
    Cpu,
    Cuda,
    /// AMD ROCm/HIP.
    Hip,
    Vulkan,
}
|
||||
|
||||
/// Abstraction over concrete transcription implementations (CPU/GPU).
pub trait TranscribeBackend {
    /// Which backend this implementation represents.
    fn kind(&self) -> BackendKind;
    /// Transcribe `audio_path`, attributing every segment to `speaker`.
    ///
    /// `language` is an optional language hint. `gpu_layers` is currently
    /// ignored by the whisper-rs implementations (bound as `_gpu_layers`).
    /// `progress`, when provided, receives coarse percentages in 0..=100.
    fn transcribe(
        &self,
        audio_path: &Path,
        speaker: &str,
        language: Option<&str>,
        gpu_layers: Option<u32>,
        progress: Option<&(dyn Fn(i32) + Send + Sync)>,
    ) -> Result<Vec<OutputEntry>>;
}
|
||||
|
||||
/// Probe whether any of the named shared libraries is loadable.
///
/// NOTE(review): detection is not implemented yet — both the test and
/// non-test branches return `false` and `_names` is unused, so GPU backends
/// are only ever enabled via the POLYSCRIBE_TEST_FORCE_* env overrides.
/// TODO: implement a real probe (e.g. dlopen) when GPU support is wired up.
fn is_library_available(_names: &[&str]) -> bool {
    #[cfg(test)]
    {
        false
    }
    #[cfg(not(test))]
    {
        false
    }
}
|
||||
|
||||
fn cuda_available() -> bool {
|
||||
if let Ok(x) = env::var("POLYSCRIBE_TEST_FORCE_CUDA") {
|
||||
return x == "1";
|
||||
}
|
||||
is_library_available(&[
|
||||
"libcudart.so",
|
||||
"libcudart.so.12",
|
||||
"libcudart.so.11",
|
||||
"libcublas.so",
|
||||
"libcublas.so.12",
|
||||
])
|
||||
}
|
||||
|
||||
fn hip_available() -> bool {
|
||||
if let Ok(x) = env::var("POLYSCRIBE_TEST_FORCE_HIP") {
|
||||
return x == "1";
|
||||
}
|
||||
is_library_available(&["libhipblas.so", "librocblas.so"])
|
||||
}
|
||||
|
||||
fn vulkan_available() -> bool {
|
||||
if let Ok(x) = env::var("POLYSCRIBE_TEST_FORCE_VULKAN") {
|
||||
return x == "1";
|
||||
}
|
||||
is_library_available(&["libvulkan.so.1", "libvulkan.so"])
|
||||
}
|
||||
|
||||
/// CPU backend (default whisper-rs code path).
#[derive(Default)]
pub struct CpuBackend;
/// NVIDIA CUDA backend.
#[derive(Default)]
pub struct CudaBackend;
/// AMD ROCm/HIP backend.
#[derive(Default)]
pub struct HipBackend;
/// Vulkan backend (currently a stub; see its TranscribeBackend impl).
#[derive(Default)]
pub struct VulkanBackend;
|
||||
|
||||
/// Implement `TranscribeBackend` for backends that all delegate to the shared
/// whisper-rs code path; only `kind()` differs per backend.
///
/// NOTE(review): the CUDA and HIP impls currently run the exact same
/// `transcribe_with_whisper_rs` call as CPU — no GPU-specific wiring is
/// visible here; confirm whether offload happens inside whisper-rs features.
macro_rules! impl_whisper_backend {
    ($ty:ty, $kind:expr) => {
        impl TranscribeBackend for $ty {
            fn kind(&self) -> BackendKind {
                $kind
            }
            fn transcribe(
                &self,
                audio_path: &Path,
                speaker: &str,
                language: Option<&str>,
                _gpu_layers: Option<u32>,
                progress: Option<&(dyn Fn(i32) + Send + Sync)>,
            ) -> Result<Vec<OutputEntry>> {
                transcribe_with_whisper_rs(audio_path, speaker, language, progress)
            }
        }
    };
}

impl_whisper_backend!(CpuBackend, BackendKind::Cpu);
impl_whisper_backend!(CudaBackend, BackendKind::Cuda);
impl_whisper_backend!(HipBackend, BackendKind::Hip);
|
||||
|
||||
impl TranscribeBackend for VulkanBackend {
    fn kind(&self) -> BackendKind {
        BackendKind::Vulkan
    }
    /// Deliberate stub: always errors with remediation advice until the
    /// whisper.cpp Vulkan FFI path is wired up (see gpu-vulkan feature).
    fn transcribe(
        &self,
        _audio_path: &Path,
        _speaker: &str,
        _language: Option<&str>,
        _gpu_layers: Option<u32>,
        _progress: Option<&(dyn Fn(i32) + Send + Sync)>,
    ) -> Result<Vec<OutputEntry>> {
        Err(anyhow!(
            "Vulkan backend not yet wired to whisper.cpp FFI. Build with --features gpu-vulkan and ensure Vulkan SDK is installed. How to fix: install Vulkan loader (libvulkan), set VULKAN_SDK, and run cargo build --features gpu-vulkan."
        ).into())
    }
}
|
||||
|
||||
/// Result of backend resolution: the instantiated backend, which kind was
/// chosen, and every GPU kind that was detected on this machine.
pub struct BackendSelection {
    // Boxed trait object so callers don't need to know the concrete type.
    pub backend: Box<dyn TranscribeBackend + Send + Sync>,
    pub chosen: BackendKind,
    pub detected: Vec<BackendKind>,
}
|
||||
|
||||
/// Detect available GPU backends and resolve `requested` to a concrete one.
///
/// `Auto` prefers CUDA, then HIP, then Vulkan, falling back to CPU.
/// Explicitly requesting a GPU backend that was not detected is an error
/// carrying remediation advice. `verbose` logs detection results via dlog!.
pub fn select_backend(requested: BackendKind, verbose: bool) -> Result<BackendSelection> {
    // Probe each GPU runtime; CPU is always implicitly available.
    let mut detected = Vec::new();
    if cuda_available() {
        detected.push(BackendKind::Cuda);
    }
    if hip_available() {
        detected.push(BackendKind::Hip);
    }
    if vulkan_available() {
        detected.push(BackendKind::Vulkan);
    }

    let instantiate_backend = |k: BackendKind| -> Box<dyn TranscribeBackend + Send + Sync> {
        match k {
            BackendKind::Cpu => Box::new(CpuBackend),
            BackendKind::Cuda => Box::new(CudaBackend),
            BackendKind::Hip => Box::new(HipBackend),
            BackendKind::Vulkan => Box::new(VulkanBackend),
            // Auto is resolved to a concrete kind below; this arm is a
            // defensive fallback only.
            BackendKind::Auto => Box::new(CpuBackend),
        }
    };

    let chosen = match requested {
        BackendKind::Auto => {
            // Preference order: CUDA > HIP > Vulkan > CPU.
            if detected.contains(&BackendKind::Cuda) {
                BackendKind::Cuda
            } else if detected.contains(&BackendKind::Hip) {
                BackendKind::Hip
            } else if detected.contains(&BackendKind::Vulkan) {
                BackendKind::Vulkan
            } else {
                BackendKind::Cpu
            }
        }
        BackendKind::Cuda => {
            if detected.contains(&BackendKind::Cuda) {
                BackendKind::Cuda
            } else {
                return Err(anyhow!(
                    "Requested CUDA backend but CUDA libraries/devices not detected. How to fix: install NVIDIA driver + CUDA toolkit, ensure libcudart/libcublas are in loader path, and build with --features gpu-cuda."
                ).into());
            }
        }
        BackendKind::Hip => {
            if detected.contains(&BackendKind::Hip) {
                BackendKind::Hip
            } else {
                return Err(anyhow!(
                    "Requested ROCm/HIP backend but libraries/devices not detected. How to fix: install ROCm hipBLAS/rocBLAS, ensure libs are in loader path, and build with --features gpu-hip."
                ).into());
            }
        }
        BackendKind::Vulkan => {
            if detected.contains(&BackendKind::Vulkan) {
                BackendKind::Vulkan
            } else {
                return Err(anyhow!(
                    "Requested Vulkan backend but libvulkan not detected. How to fix: install Vulkan loader/SDK and build with --features gpu-vulkan."
                ).into());
            }
        }
        // CPU is always available; no detection needed.
        BackendKind::Cpu => BackendKind::Cpu,
    };

    if verbose {
        crate::dlog!(1, "Detected backends: {:?}", detected);
        crate::dlog!(1, "Selected backend: {:?}", chosen);
    }

    Ok(BackendSelection {
        backend: instantiate_backend(chosen),
        chosen,
        detected,
    })
}
|
||||
|
||||
/// Shared whisper-rs transcription path used by the CPU/CUDA/HIP backends.
///
/// Decodes `audio_path` to f32 PCM via ffmpeg, loads the model returned by
/// `find_model_file()`, runs whisper `full()`, and converts the resulting
/// segments to `OutputEntry` values attributed to `speaker`. `progress`
/// (if provided) receives coarse milestone percentages (0, 5, 20, 30, 40,
/// 90, 100).
///
/// Errors when a non-English `language` hint is combined with an
/// English-only model (detected from the ".en" filename convention).
#[allow(clippy::too_many_arguments)]
pub(crate) fn transcribe_with_whisper_rs(
    audio_path: &Path,
    speaker: &str,
    language: Option<&str>,
    progress: Option<&(dyn Fn(i32) + Send + Sync)>,
) -> Result<Vec<OutputEntry>> {
    // Best-effort progress reporting; no callback means these are no-ops.
    let report = |p: i32| {
        if let Some(cb) = progress {
            cb(p);
        }
    };
    report(0);

    let pcm_samples = decode_audio_to_pcm_f32_ffmpeg(audio_path)?;
    report(5);

    let model_path = find_model_file()?;
    // Heuristic: whisper.cpp English-only models are named "*.en.bin"/"*.en.*".
    let english_only_model = model_path
        .file_name()
        .and_then(|s| s.to_str())
        .map(|s| s.contains(".en.") || s.ends_with(".en.bin"))
        .unwrap_or(false);
    if let Some(lang) = language
        && english_only_model
        && lang != "en"
    {
        return Err(anyhow!(
            "Selected model is English-only ({}), but a non-English language hint '{}' was provided. Please use a multilingual model or set WHISPER_MODEL.",
            model_path.display(),
            lang
        ).into());
    }
    // whisper-rs needs a &str path; reject non-UTF-8 paths explicitly.
    let model_path_str = model_path
        .to_str()
        .ok_or_else(|| anyhow!("Model path not valid UTF-8: {}", model_path.display()))?;

    // Quiet ggml/whisper native logging unless verbosity >= 2.
    // SAFETY note: set_var is process-global (unsafe in edition 2024);
    // assumed to run before worker threads are spawned — confirm.
    if crate::verbose_level() < 2 {
        unsafe {
            std::env::set_var("GGML_LOG_LEVEL", "0");
            std::env::set_var("WHISPER_PRINT_PROGRESS", "0");
        }
    }

    // Model load + state creation can spew to native stderr; suppress when quiet.
    // _context must stay alive for the lifetime of `state`.
    let (_context, mut state) = crate::with_suppressed_stderr(|| {
        let params = whisper_rs::WhisperContextParameters::default();
        let context = whisper_rs::WhisperContext::new_with_params(model_path_str, params)
            .with_context(|| format!("Failed to load Whisper model at {}", model_path.display()))?;
        let state = context
            .create_state()
            .map_err(|e| anyhow!("Failed to create Whisper state: {:?}", e))?;
        Ok::<_, anyhow::Error>((context, state))
    })?;
    report(20);

    let mut full_params =
        whisper_rs::FullParams::new(whisper_rs::SamplingStrategy::Greedy { best_of: 1 });
    // Use all available cores; fall back to 1 if parallelism can't be queried.
    let threads = std::thread::available_parallelism()
        .map(|n| n.get() as i32)
        .unwrap_or(1);
    full_params.set_n_threads(threads);
    full_params.set_translate(false);
    if let Some(lang) = language {
        full_params.set_language(Some(lang));
    }
    report(30);

    // The actual inference; also suppressed-stderr because whisper.cpp logs there.
    crate::with_suppressed_stderr(|| {
        report(40);
        state
            .full(full_params, &pcm_samples)
            .map_err(|e| anyhow!("Whisper full() failed: {:?}", e))
    })?;

    report(90);
    let num_segments = state
        .full_n_segments()
        .map_err(|e| anyhow!("Failed to get segments: {:?}", e))?;
    let mut entries = Vec::new();
    for seg_idx in 0..num_segments {
        let segment_text = state
            .full_get_segment_text(seg_idx)
            .map_err(|e| anyhow!("Failed to get segment text: {:?}", e))?;
        let t0 = state
            .full_get_segment_t0(seg_idx)
            .map_err(|e| anyhow!("Failed to get segment t0: {:?}", e))?;
        let t1 = state
            .full_get_segment_t1(seg_idx)
            .map_err(|e| anyhow!("Failed to get segment t1: {:?}", e))?;
        // Whisper timestamps are in centiseconds; convert to seconds.
        let start = (t0 as f64) * 0.01;
        let end = (t1 as f64) * 0.01;
        entries.push(OutputEntry {
            // NOTE(review): every entry gets id 0 here — presumably renumbered
            // by the caller when merging outputs; confirm against call sites.
            id: 0,
            speaker: speaker.to_string(),
            start,
            end,
            text: segment_text.trim().to_string(),
        });
    }
    report(100);
    Ok(entries)
}
|
104
crates/polyscribe-core/src/config.rs
Normal file
104
crates/polyscribe-core/src/config.rs
Normal file
@@ -0,0 +1,104 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::env;
|
||||
use std::path::PathBuf;
|
||||
|
||||
/// Stateless namespace for configuration lookups: recognized env-var names,
/// compiled-in defaults, and derived filesystem paths (platform dirs via
/// the `directories` crate).
pub struct ConfigService;

impl ConfigService {
    // Environment variable names recognized by polyscribe.
    pub const ENV_NO_CACHE_MANIFEST: &'static str = "POLYSCRIBE_NO_CACHE_MANIFEST";
    pub const ENV_MANIFEST_TTL_SECONDS: &'static str = "POLYSCRIBE_MANIFEST_TTL_SECONDS";
    pub const ENV_MODELS_DIR: &'static str = "POLYSCRIBE_MODELS_DIR";
    pub const ENV_USER_AGENT: &'static str = "POLYSCRIBE_USER_AGENT";
    pub const ENV_HTTP_TIMEOUT_SECS: &'static str = "POLYSCRIBE_HTTP_TIMEOUT_SECS";
    pub const ENV_HF_REPO: &'static str = "POLYSCRIBE_HF_REPO";
    pub const ENV_CACHE_FILENAME: &'static str = "POLYSCRIBE_MANIFEST_CACHE_FILENAME";

    // Compiled-in defaults used when no env override is present.
    pub const DEFAULT_USER_AGENT: &'static str = "polyscribe/0.1";
    pub const DEFAULT_DOWNLOADER_UA: &'static str = "polyscribe-model-downloader/1";
    pub const DEFAULT_HF_REPO: &'static str = "ggerganov/whisper.cpp";
    pub const DEFAULT_CACHE_FILENAME: &'static str = "hf_manifest_whisper_cpp.json";
    pub const DEFAULT_HTTP_TIMEOUT_SECS: u64 = 8;
    // 24 hours.
    pub const DEFAULT_MANIFEST_CACHE_TTL_SECONDS: u64 = 24 * 60 * 60;

    /// Platform project directories for "dev.polyscribe.polyscribe".
    pub fn project_dirs() -> Option<directories::ProjectDirs> {
        directories::ProjectDirs::from("dev", "polyscribe", "polyscribe")
    }

    /// Default on-disk location for downloaded models (data_dir/models).
    pub fn default_models_dir() -> Option<PathBuf> {
        Self::project_dirs().map(|d| d.data_dir().join("models"))
    }

    /// Default on-disk location for plugins (data_dir/plugins).
    pub fn default_plugins_dir() -> Option<PathBuf> {
        Self::project_dirs().map(|d| d.data_dir().join("plugins"))
    }

    /// Cache directory for the Hugging Face manifest (cache_dir/manifest).
    pub fn manifest_cache_dir() -> Option<PathBuf> {
        Self::project_dirs().map(|d| d.cache_dir().join("manifest"))
    }

    /// True when the manifest cache should be bypassed entirely.
    /// Note: any value of the env var (even empty) enables the bypass.
    pub fn bypass_manifest_cache() -> bool {
        env::var(Self::ENV_NO_CACHE_MANIFEST).is_ok()
    }

    /// Manifest cache TTL in seconds; unparsable/missing env falls back to default.
    pub fn manifest_cache_ttl_seconds() -> u64 {
        env::var(Self::ENV_MANIFEST_TTL_SECONDS)
            .ok()
            .and_then(|s| s.parse::<u64>().ok())
            .unwrap_or(Self::DEFAULT_MANIFEST_CACHE_TTL_SECONDS)
    }

    /// File name used for the cached manifest (env-overridable).
    pub fn manifest_cache_filename() -> String {
        env::var(Self::ENV_CACHE_FILENAME)
            .unwrap_or_else(|_| Self::DEFAULT_CACHE_FILENAME.to_string())
    }

    /// Resolve the models directory. Precedence: non-empty env var,
    /// then the loaded config file, then the platform default.
    pub fn models_dir(cfg: Option<&Config>) -> Option<PathBuf> {
        if let Ok(env_dir) = env::var(Self::ENV_MODELS_DIR) {
            if !env_dir.is_empty() {
                return Some(PathBuf::from(env_dir));
            }
        }
        if let Some(c) = cfg {
            if let Some(dir) = c.models_dir.clone() {
                return Some(dir);
            }
        }
        Self::default_models_dir()
    }

    /// User-Agent for general HTTP requests (env-overridable).
    pub fn user_agent() -> String {
        env::var(Self::ENV_USER_AGENT).unwrap_or_else(|_| Self::DEFAULT_USER_AGENT.to_string())
    }

    /// User-Agent for model downloads.
    /// NOTE(review): reads the same ENV_USER_AGENT override as user_agent();
    /// only the default differs — confirm this sharing is intended.
    pub fn downloader_user_agent() -> String {
        env::var(Self::ENV_USER_AGENT).unwrap_or_else(|_| Self::DEFAULT_DOWNLOADER_UA.to_string())
    }

    /// HTTP timeout in seconds; unparsable/missing env falls back to default.
    pub fn http_timeout_secs() -> u64 {
        env::var(Self::ENV_HTTP_TIMEOUT_SECS)
            .ok()
            .and_then(|s| s.parse::<u64>().ok())
            .unwrap_or(Self::DEFAULT_HTTP_TIMEOUT_SECS)
    }

    /// Hugging Face repo to pull models from (env-overridable).
    pub fn hf_repo() -> String {
        env::var(Self::ENV_HF_REPO).unwrap_or_else(|_| Self::DEFAULT_HF_REPO.to_string())
    }

    /// Hugging Face model API base URL for `repo`.
    pub fn hf_api_base_for(repo: &str) -> String {
        format!("https://huggingface.co/api/models/{}", repo)
    }

    /// Full path of the cached manifest file, if a cache dir exists.
    pub fn manifest_cache_path() -> Option<PathBuf> {
        let dir = Self::manifest_cache_dir()?;
        Some(dir.join(Self::manifest_cache_filename()))
    }
}
|
||||
|
||||
/// User-editable configuration. `None` fields fall back to the
/// ConfigService defaults at lookup time.
#[derive(Debug, Clone, Serialize, Deserialize, Default)]
pub struct Config {
    // Override for the models storage directory.
    pub models_dir: Option<PathBuf>,
    // Override for the plugins directory.
    pub plugins_dir: Option<PathBuf>,
}
|
31
crates/polyscribe-core/src/error.rs
Normal file
31
crates/polyscribe-core/src/error.rs
Normal file
@@ -0,0 +1,31 @@
|
||||
use thiserror::Error;
|
||||
|
||||
/// Unified error type for polyscribe-core, aggregating the common failure
/// sources via `#[from]` conversions so `?` works across layers.
#[derive(Debug, Error)]
pub enum Error {
    #[error("I/O error: {0}")]
    Io(#[from] std::io::Error),

    #[error("serde error: {0}")]
    Serde(#[from] serde_json::Error),

    #[error("toml error: {0}")]
    Toml(#[from] toml::de::Error),

    #[error("toml ser error: {0}")]
    TomlSer(#[from] toml::ser::Error),

    #[error("env var error: {0}")]
    EnvVar(#[from] std::env::VarError),

    #[error("http error: {0}")]
    Http(#[from] reqwest::Error),

    // Catch-all for anyhow and ad-hoc errors (see From<anyhow::Error>).
    #[error("other: {0}")]
    Other(String),
}
|
||||
|
||||
impl From<anyhow::Error> for Error {
|
||||
fn from(e: anyhow::Error) -> Self {
|
||||
Error::Other(e.to_string())
|
||||
}
|
||||
}
|
442
crates/polyscribe-core/src/lib.rs
Normal file
442
crates/polyscribe-core/src/lib.rs
Normal file
@@ -0,0 +1,442 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#![forbid(elided_lifetimes_in_paths)]
|
||||
#![forbid(unused_must_use)]
|
||||
#![warn(clippy::all)]
|
||||
|
||||
use std::sync::atomic::{AtomicBool, AtomicU8, Ordering};
|
||||
|
||||
use crate::prelude::*;
|
||||
use anyhow::{Context, anyhow};
|
||||
use chrono::Local;
|
||||
use std::env;
|
||||
use std::path::{Path, PathBuf};
|
||||
use std::process::Command;
|
||||
|
||||
#[cfg(unix)]
|
||||
use libc::{O_WRONLY, close, dup, dup2, open};
|
||||
|
||||
// Process-global output/interaction flags, set once at CLI startup and read
// throughout the crate. Relaxed ordering suffices: each flag is independent
// and no cross-flag ordering is required.
static QUIET: AtomicBool = AtomicBool::new(false);
static NO_INTERACTION: AtomicBool = AtomicBool::new(false);
static VERBOSE: AtomicU8 = AtomicU8::new(0);
static NO_PROGRESS: AtomicBool = AtomicBool::new(false);

/// Enable/disable quiet mode (also consulted by StderrSilencer).
pub fn set_quiet(enabled: bool) {
    QUIET.store(enabled, Ordering::Relaxed);
}
/// Whether quiet mode is active.
pub fn is_quiet() -> bool {
    QUIET.load(Ordering::Relaxed)
}

/// Enable/disable non-interactive mode (suppresses prompts).
pub fn set_no_interaction(enabled: bool) {
    NO_INTERACTION.store(enabled, Ordering::Relaxed);
}
/// Whether interactive prompts are disabled.
pub fn is_no_interaction() -> bool {
    NO_INTERACTION.load(Ordering::Relaxed)
}

/// Set the verbosity level (0 = normal; higher = more debug output).
pub fn set_verbose(level: u8) {
    VERBOSE.store(level, Ordering::Relaxed);
}
/// Current verbosity level.
pub fn verbose_level() -> u8 {
    VERBOSE.load(Ordering::Relaxed)
}

/// Enable/disable progress indicators.
pub fn set_no_progress(enabled: bool) {
    NO_PROGRESS.store(enabled, Ordering::Relaxed);
}
/// Whether progress indicators are disabled.
pub fn is_no_progress() -> bool {
    NO_PROGRESS.load(Ordering::Relaxed)
}
|
||||
|
||||
/// Whether stdin is attached to a terminal (used to decide if prompting
/// is possible).
pub fn stdin_is_tty() -> bool {
    use std::io::IsTerminal as _;
    let stdin = std::io::stdin();
    stdin.is_terminal()
}
|
||||
|
||||
/// RAII guard that temporarily redirects stderr (fd 2) to `/dev/null`.
///
/// Created via [`StderrSilencer::activate`] or
/// [`StderrSilencer::activate_if_quiet`]; the original stderr is restored
/// when the guard is dropped. On non-Unix platforms the guard is inert.
pub struct StderrSilencer {
    // Duplicate of the original stderr fd, used by Drop to restore it.
    #[cfg(unix)]
    old_stderr_fd: i32,
    // Write-only fd for /dev/null while the redirect is in effect.
    #[cfg(unix)]
    devnull_fd: i32,
    // Whether the redirect actually took effect; Drop is a no-op otherwise.
    active: bool,
}
|
||||
|
||||
impl StderrSilencer {
    /// Redirect stderr only when quiet mode is on; otherwise return an
    /// inactive guard whose Drop does nothing.
    pub fn activate_if_quiet() -> Self {
        if !is_quiet() {
            return Self {
                active: false,
                #[cfg(unix)]
                old_stderr_fd: -1,
                #[cfg(unix)]
                devnull_fd: -1,
            };
        }
        Self::activate()
    }

    /// Unconditionally redirect fd 2 to /dev/null (Unix only; inert guard on
    /// other platforms). Any failure along the way yields an inactive guard
    /// instead of an error, so callers never lose stderr permanently.
    pub fn activate() -> Self {
        #[cfg(unix)]
        // SAFETY: raw fd manipulation via dup/open/dup2/close. Every libc
        // return value is checked, and fds acquired earlier are closed on
        // each failure path, so no descriptor leaks.
        unsafe {
            // Keep a duplicate of the current stderr so Drop can restore it.
            let old_fd = dup(2);
            if old_fd < 0 {
                return Self {
                    active: false,
                    old_stderr_fd: -1,
                    devnull_fd: -1,
                };
            }
            let devnull_cstr = std::ffi::CString::new("/dev/null").unwrap();
            let devnull_fd = open(devnull_cstr.as_ptr(), O_WRONLY);
            if devnull_fd < 0 {
                let _ = close(old_fd);
                return Self {
                    active: false,
                    old_stderr_fd: -1,
                    devnull_fd: -1,
                };
            }
            // Point fd 2 at /dev/null; on failure, release both fds.
            if dup2(devnull_fd, 2) < 0 {
                let _ = close(devnull_fd);
                let _ = close(old_fd);
                return Self {
                    active: false,
                    old_stderr_fd: -1,
                    devnull_fd: -1,
                };
            }
            Self {
                active: true,
                old_stderr_fd: old_fd,
                devnull_fd,
            }
        }
        #[cfg(not(unix))]
        {
            Self { active: false }
        }
    }
}
|
||||
|
||||
impl Drop for StderrSilencer {
    /// Restore the original stderr and release both descriptors.
    fn drop(&mut self) {
        // Inactive guards (non-quiet mode or failed activation) do nothing.
        if !self.active {
            return;
        }
        #[cfg(unix)]
        // SAFETY: when `active` is true, both fds were acquired in
        // `activate()` and are still owned by this guard. Errors are ignored
        // deliberately — there is nowhere left to report them.
        unsafe {
            let _ = dup2(self.old_stderr_fd, 2);
            let _ = close(self.old_stderr_fd);
            let _ = close(self.devnull_fd);
        }
    }
}
|
||||
|
||||
pub fn with_suppressed_stderr<F, T>(f: F) -> T
|
||||
where
|
||||
F: FnOnce() -> T,
|
||||
{
|
||||
let silencer = StderrSilencer::activate_if_quiet();
|
||||
let result = f();
|
||||
drop(silencer);
|
||||
result
|
||||
}
|
||||
|
||||
/// Log an error message through the UI layer. Errors are always shown,
/// even when quiet mode is active.
#[macro_export]
macro_rules! elog {
    ($($arg:tt)*) => {{ $crate::ui::error(format!($($arg)*)); }}
}

/// Log an informational message; suppressed entirely in quiet mode.
#[macro_export]
macro_rules! ilog {
    ($($arg:tt)*) => {{
        if !$crate::is_quiet() { $crate::ui::info(format!($($arg)*)); }
    }}
}

/// Log a debug message gated on verbosity: shown only when not quiet and
/// the configured verbosity level is at least `$lvl`. The output is
/// prefixed with `DEBUG<lvl>:`.
#[macro_export]
macro_rules! dlog {
    ($lvl:expr, $($arg:tt)*) => {{
        if !$crate::is_quiet() && $crate::verbose_level() >= $lvl { $crate::ui::info(format!("DEBUG{}: {}", $lvl, format!($($arg)*))); }
    }}
}
|
||||
|
||||
|
||||
pub mod backend;
|
||||
pub mod config;
|
||||
pub mod models;
|
||||
pub mod error;
|
||||
pub mod ui;
|
||||
pub use error::Error;
|
||||
pub mod prelude;
|
||||
|
||||
/// One transcript segment as emitted by the transcription backends and
/// consumed by the renderers (e.g. [`render_srt`]).
#[derive(Debug, serde::Serialize, Clone)]
pub struct OutputEntry {
    // Sequential segment id.
    pub id: u64,
    // Speaker label; empty string when no speaker is attributed.
    pub speaker: String,
    // Segment start time in seconds.
    pub start: f64,
    // Segment end time in seconds.
    pub end: f64,
    // Transcribed text for this segment.
    pub text: String,
}
|
||||
|
||||
pub fn date_prefix() -> String {
|
||||
Local::now().format("%Y-%m-%d").to_string()
|
||||
}
|
||||
|
||||
/// Format a time offset in seconds as an SRT timestamp `HH:MM:SS,mmm`.
///
/// The input is rounded to the nearest millisecond; the hour field may
/// exceed two digits for very long inputs.
pub fn format_srt_time(seconds: f64) -> String {
    let total_ms = (seconds * 1000.0).round() as i64;
    let (total_secs, ms) = (total_ms / 1000, total_ms % 1000);
    let hour = total_secs / 3600;
    let min = total_secs / 60 % 60;
    let sec = total_secs % 60;
    format!("{hour:02}:{min:02}:{sec:02},{ms:03}")
}
|
||||
|
||||
pub fn render_srt(entries: &[OutputEntry]) -> String {
|
||||
let mut srt = String::new();
|
||||
for (index, entry) in entries.iter().enumerate() {
|
||||
let srt_index = index + 1;
|
||||
srt.push_str(&format!("{srt_index}\n"));
|
||||
srt.push_str(&format!(
|
||||
"{} --> {}\n",
|
||||
format_srt_time(entry.start),
|
||||
format_srt_time(entry.end)
|
||||
));
|
||||
if !entry.speaker.is_empty() {
|
||||
srt.push_str(&format!("{}: {}\n", entry.speaker, entry.text));
|
||||
} else {
|
||||
srt.push_str(&format!("{}\n", entry.text));
|
||||
}
|
||||
srt.push('\n');
|
||||
}
|
||||
srt
|
||||
}
|
||||
|
||||
pub mod model_manager;
|
||||
|
||||
/// Directory where Whisper models are stored.
///
/// Precedence: the `POLYSCRIBE_MODELS_DIR` env override (when non-empty),
/// then a relative `models/` in debug builds, then
/// `$XDG_DATA_HOME/polyscribe/models`, then
/// `$HOME/.local/share/polyscribe/models`, finally a relative `models/`.
pub fn models_dir_path() -> PathBuf {
    if let Ok(dir) = env::var("POLYSCRIBE_MODELS_DIR") {
        if !dir.is_empty() {
            return PathBuf::from(dir);
        }
    }
    // Debug builds use the in-repo models directory for convenience.
    if cfg!(debug_assertions) {
        return PathBuf::from("models");
    }
    if let Some(xdg) = env::var("XDG_DATA_HOME").ok().filter(|v| !v.is_empty()) {
        return PathBuf::from(xdg).join("polyscribe").join("models");
    }
    if let Some(home) = env::var("HOME").ok().filter(|v| !v.is_empty()) {
        return PathBuf::from(home)
            .join(".local")
            .join("share")
            .join("polyscribe")
            .join("models");
    }
    PathBuf::from("models")
}
|
||||
|
||||
/// Normalize a language/locale string to a two-letter Whisper language code.
///
/// Accepts ISO 639-1 codes ("de"), full English language names in any case
/// ("German"), and POSIX locale strings ("de_DE.UTF-8"). Returns `None` for
/// empty input, "auto", the "C"/"POSIX" locales, and anything unrecognized.
pub fn normalize_lang_code(input: &str) -> Option<String> {
    let mut lang = input.trim().to_lowercase();
    if lang.is_empty() || lang == "auto" || lang == "c" || lang == "posix" {
        return None;
    }
    // Strip ".<encoding>" then "_<REGION>" suffixes: "de_DE.UTF-8" -> "de".
    if let Some((head, _)) = lang.split_once('.') {
        lang = head.to_string();
    }
    if let Some((head, _)) = lang.split_once('_') {
        lang = head.to_string();
    }
    let code = match lang.as_str() {
        "en" | "english" => "en",
        "de" | "german" => "de",
        "es" | "spanish" => "es",
        "fr" | "french" => "fr",
        "it" | "italian" => "it",
        "pt" | "portuguese" => "pt",
        "nl" | "dutch" => "nl",
        "ru" | "russian" => "ru",
        "pl" | "polish" => "pl",
        "uk" | "ukrainian" => "uk",
        "cs" | "czech" => "cs",
        "sv" | "swedish" => "sv",
        "no" | "norwegian" => "no",
        "da" | "danish" => "da",
        "fi" | "finnish" => "fi",
        "hu" | "hungarian" => "hu",
        "tr" | "turkish" => "tr",
        "el" | "greek" => "el",
        "zh" | "chinese" => "zh",
        "ja" | "japanese" => "ja",
        "ko" | "korean" => "ko",
        "ar" | "arabic" => "ar",
        "he" | "hebrew" => "he",
        "hi" | "hindi" => "hi",
        "ro" | "romanian" => "ro",
        "bg" | "bulgarian" => "bg",
        "sk" | "slovak" => "sk",
        _ => return None,
    };
    Some(code.to_string())
}
|
||||
|
||||
pub fn find_model_file() -> Result<PathBuf> {
|
||||
if let Ok(path) = env::var("WHISPER_MODEL") {
|
||||
let p = PathBuf::from(path);
|
||||
if !p.exists() {
|
||||
return Err(anyhow!(
|
||||
"WHISPER_MODEL points to a non-existing path: {}",
|
||||
p.display()
|
||||
)
|
||||
.into());
|
||||
}
|
||||
if !p.is_file() {
|
||||
return Err(anyhow!(
|
||||
"WHISPER_MODEL must point to a file, but is not: {}",
|
||||
p.display()
|
||||
)
|
||||
.into());
|
||||
}
|
||||
return Ok(p);
|
||||
}
|
||||
|
||||
let models_dir = models_dir_path();
|
||||
if models_dir.exists() && !models_dir.is_dir() {
|
||||
return Err(anyhow!(
|
||||
"Models path exists but is not a directory: {}",
|
||||
models_dir.display()
|
||||
)
|
||||
.into());
|
||||
}
|
||||
std::fs::create_dir_all(&models_dir).with_context(|| {
|
||||
format!(
|
||||
"Failed to ensure models dir exists: {}",
|
||||
models_dir.display()
|
||||
)
|
||||
})?;
|
||||
|
||||
let mut candidates = Vec::new();
|
||||
for entry in std::fs::read_dir(&models_dir)
|
||||
.with_context(|| format!("Failed to read models dir: {}", models_dir.display()))?
|
||||
{
|
||||
let entry = entry?;
|
||||
let path = entry.path();
|
||||
|
||||
let is_bin = path
|
||||
.extension()
|
||||
.and_then(|s| s.to_str())
|
||||
.is_some_and(|s| s.eq_ignore_ascii_case("bin"));
|
||||
if !is_bin {
|
||||
continue;
|
||||
}
|
||||
|
||||
let md = match std::fs::metadata(&path) {
|
||||
Ok(m) if m.is_file() => m,
|
||||
_ => continue,
|
||||
};
|
||||
|
||||
candidates.push((md.len(), path));
|
||||
}
|
||||
|
||||
if candidates.is_empty() {
|
||||
let fallback = models_dir.join("ggml-tiny.en.bin");
|
||||
if fallback.is_file() {
|
||||
return Ok(fallback);
|
||||
}
|
||||
return Err(anyhow!(
|
||||
"No Whisper model files (*.bin) found in {}. \
|
||||
Please download a model or set WHISPER_MODEL.",
|
||||
models_dir.display()
|
||||
)
|
||||
.into());
|
||||
}
|
||||
|
||||
candidates.sort_by_key(|(size, _)| *size);
|
||||
let (_size, path) = candidates.into_iter().last().expect("non-empty");
|
||||
Ok(path)
|
||||
}
|
||||
|
||||
pub fn decode_audio_to_pcm_f32_ffmpeg(audio_path: &Path) -> Result<Vec<f32>> {
|
||||
let in_path = audio_path
|
||||
.to_str()
|
||||
.ok_or_else(|| anyhow!("Audio path must be valid UTF-8: {}", audio_path.display()))?;
|
||||
|
||||
let tmp_raw = std::env::temp_dir().join("polyscribe_tmp_input.f32le");
|
||||
let tmp_raw_str = tmp_raw
|
||||
.to_str()
|
||||
.ok_or_else(|| anyhow!("Temp path not valid UTF-8: {}", tmp_raw.display()))?;
|
||||
|
||||
let status = Command::new("ffmpeg")
|
||||
.arg("-hide_banner")
|
||||
.arg("-loglevel")
|
||||
.arg("error")
|
||||
.arg("-i")
|
||||
.arg(in_path)
|
||||
.arg("-f")
|
||||
.arg("f32le")
|
||||
.arg("-ac")
|
||||
.arg("1")
|
||||
.arg("-ar")
|
||||
.arg("16000")
|
||||
.arg("-y")
|
||||
.arg(tmp_raw_str)
|
||||
.status()
|
||||
.with_context(|| format!("Failed to invoke ffmpeg to decode: {}", in_path))?;
|
||||
|
||||
if !status.success() {
|
||||
return Err(anyhow!(
|
||||
"ffmpeg exited with non-zero status when decoding {}",
|
||||
in_path
|
||||
)
|
||||
.into());
|
||||
}
|
||||
|
||||
let raw = std::fs::read(&tmp_raw)
|
||||
.with_context(|| format!("Failed to read temp PCM file: {}", tmp_raw.display()))?;
|
||||
|
||||
let _ = std::fs::remove_file(&tmp_raw);
|
||||
|
||||
if raw.len() % 4 != 0 {
|
||||
return Err(anyhow!("Decoded PCM file length not multiple of 4: {}", raw.len()).into());
|
||||
}
|
||||
let mut samples = Vec::with_capacity(raw.len() / 4);
|
||||
for chunk in raw.chunks_exact(4) {
|
||||
let v = f32::from_le_bytes([chunk[0], chunk[1], chunk[2], chunk[3]]);
|
||||
samples.push(v);
|
||||
}
|
||||
Ok(samples)
|
||||
}
|
893
crates/polyscribe-core/src/model_manager.rs
Normal file
893
crates/polyscribe-core/src/model_manager.rs
Normal file
@@ -0,0 +1,893 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
use crate::prelude::*;
|
||||
use crate::ui::BytesProgress;
|
||||
use anyhow::{anyhow, Context};
|
||||
use chrono::{DateTime, Utc};
|
||||
use reqwest::blocking::Client;
|
||||
use reqwest::header::{
|
||||
ACCEPT_RANGES, AUTHORIZATION, CONTENT_LENGTH, ETAG, IF_NONE_MATCH, LAST_MODIFIED, RANGE,
|
||||
};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use sha2::{Digest, Sha256};
|
||||
use std::cmp::min;
|
||||
use std::collections::BTreeMap;
|
||||
use std::fs::{self, File, OpenOptions};
|
||||
use std::io::{Read, Seek, SeekFrom, Write};
|
||||
use std::path::{Path, PathBuf};
|
||||
use std::sync::{mpsc, Arc, Mutex};
|
||||
use std::thread;
|
||||
use std::time::Duration;
|
||||
|
||||
// Size of each ranged-GET chunk during concurrent downloads.
const DEFAULT_CHUNK_SIZE: u64 = 8 * 1024 * 1024; // 8 MiB

/// One installed (or known) model as recorded in the on-disk manifest.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
pub struct ModelRecord {
    // User-chosen name the model is addressed by.
    pub alias: String,
    // Hugging Face repository id, e.g. "owner/name".
    pub repo: String,
    // File name inside the repo (also the on-disk cache file name).
    pub file: String,
    pub revision: Option<String>, // ETag or commit hash
    // Hex-encoded SHA-256 of the downloaded file, when known.
    pub sha256: Option<String>,
    pub size_bytes: Option<u64>,
    // Quantization token inferred from the file name, e.g. "Q5_K_M".
    pub quant: Option<String>,
    pub installed_at: Option<DateTime<Utc>>,
    pub last_used: Option<DateTime<Utc>>,
}

/// On-disk manifest of installed models. A BTreeMap keeps the serialized
/// JSON deterministically ordered.
#[derive(Debug, Clone, Serialize, Deserialize, Default)]
pub struct Manifest {
    pub models: BTreeMap<String, ModelRecord>, // key = alias
}

/// Download tunables for the model manager.
#[derive(Debug, Clone)]
pub struct Settings {
    // Number of parallel ranged-download workers.
    pub concurrency: usize,
    pub limit_rate: Option<u64>, // bytes/sec
    // Ranged-GET chunk size in bytes.
    pub chunk_size: u64,
}

impl Default for Settings {
    // 4 workers, unlimited rate, 8 MiB chunks.
    fn default() -> Self {
        Self {
            concurrency: 4,
            limit_rate: None,
            chunk_size: DEFAULT_CHUNK_SIZE,
        }
    }
}
|
||||
|
||||
/// Filesystem locations used by the model manager; see [`Paths::resolve`]
/// for how env overrides and XDG defaults are applied.
#[derive(Debug, Clone)]
pub struct Paths {
    pub cache_dir: PathBuf,   // $XDG_CACHE_HOME/polyscribe/models
    pub config_path: PathBuf, // $XDG_CONFIG_HOME/polyscribe/models.json
}
|
||||
|
||||
impl Paths {
|
||||
pub fn resolve() -> Result<Self> {
|
||||
if let Ok(over) = std::env::var("POLYSCRIBE_CACHE_DIR") {
|
||||
if !over.is_empty() {
|
||||
let cache_dir = PathBuf::from(over).join("models");
|
||||
let config_path = std::env::var("POLYSCRIBE_CONFIG_DIR")
|
||||
.map(|p| PathBuf::from(p).join("models.json"))
|
||||
.unwrap_or_else(|_| default_config_path());
|
||||
return Ok(Self {
|
||||
cache_dir,
|
||||
config_path,
|
||||
});
|
||||
}
|
||||
}
|
||||
let cache_dir = default_cache_dir();
|
||||
let config_path = default_config_path();
|
||||
Ok(Self {
|
||||
cache_dir,
|
||||
config_path,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
/// Default on-disk model cache: `$XDG_CACHE_HOME/polyscribe/models`,
/// falling back to `$HOME/.cache/polyscribe/models`, then a relative
/// `models/` when neither variable is usable.
fn default_cache_dir() -> PathBuf {
    if let Some(xdg) = std::env::var("XDG_CACHE_HOME").ok().filter(|s| !s.is_empty()) {
        return PathBuf::from(xdg).join("polyscribe").join("models");
    }
    if let Some(home) = std::env::var("HOME").ok().filter(|s| !s.is_empty()) {
        return PathBuf::from(home)
            .join(".cache")
            .join("polyscribe")
            .join("models");
    }
    PathBuf::from("models")
}
|
||||
|
||||
/// Default manifest location: `$XDG_CONFIG_HOME/polyscribe/models.json`,
/// falling back to `$HOME/.config/polyscribe/models.json`, then a relative
/// `models.json` when neither variable is usable.
fn default_config_path() -> PathBuf {
    if let Some(xdg) = std::env::var("XDG_CONFIG_HOME").ok().filter(|s| !s.is_empty()) {
        return PathBuf::from(xdg).join("polyscribe").join("models.json");
    }
    if let Some(home) = std::env::var("HOME").ok().filter(|s| !s.is_empty()) {
        return PathBuf::from(home)
            .join(".config")
            .join("polyscribe")
            .join("models.json");
    }
    PathBuf::from("models.json")
}
|
||||
|
||||
/// Minimal blocking HTTP abstraction used by [`ModelManager`], allowing the
/// network layer to be mocked in tests.
pub trait HttpClient: Send + Sync {
    /// HEAD request; `etag`, when given, is sent as `If-None-Match`.
    fn head(&self, url: &str, etag: Option<&str>) -> Result<HeadMeta>;
    /// GET the closed byte range `[start, end_inclusive]` into memory.
    fn get_range(&self, url: &str, start: u64, end_inclusive: u64) -> Result<Vec<u8>>;
    /// Stream the entire response body into `writer`.
    fn get_whole_to(&self, url: &str, writer: &mut dyn Write) -> Result<()>;
    /// Stream the body from byte offset `start` to the end into `writer`.
    fn get_from_to(&self, url: &str, start: u64, writer: &mut dyn Write) -> Result<()>;
}
|
||||
|
||||
/// Production [`HttpClient`] backed by a blocking `reqwest` client, with
/// optional Hugging Face bearer-token authentication.
#[derive(Debug, Clone)]
pub struct ReqwestClient {
    client: Client,
    // Bearer token from $HF_TOKEN; None when unset or empty.
    token: Option<String>,
}
|
||||
|
||||
impl ReqwestClient {
    /// Build a blocking client using the crate's user agent. Reads
    /// `HF_TOKEN` from the environment for authenticated Hugging Face
    /// requests; an unset or empty variable means anonymous access.
    ///
    /// # Errors
    /// Fails when the underlying reqwest client cannot be constructed.
    pub fn new() -> Result<Self> {
        let token = std::env::var("HF_TOKEN").ok().filter(|s| !s.is_empty());
        let client = Client::builder()
            .user_agent(crate::config::ConfigService::user_agent())
            .build()?;
        Ok(Self { client, token })
    }

    // Attach a bearer `Authorization` header when a token is configured.
    fn auth(&self, mut req: reqwest::blocking::RequestBuilder) -> reqwest::blocking::RequestBuilder {
        if let Some(t) = &self.token {
            req = req.header(AUTHORIZATION, format!("Bearer {}", t));
        }
        req
    }
}
|
||||
|
||||
/// Metadata extracted from a HEAD response, used to plan a download.
#[derive(Debug, Clone)]
pub struct HeadMeta {
    // Content-Length, when the server reported one.
    pub len: Option<u64>,
    // ETag with surrounding quotes stripped.
    pub etag: Option<String>,
    pub last_modified: Option<String>,
    // True when the server advertises byte-range support.
    pub accept_ranges: bool,
    // True when the server answered 304 to our If-None-Match.
    pub not_modified: bool,
    // Raw HTTP status code.
    pub status: u16,
}
|
||||
|
||||
impl HttpClient for ReqwestClient {
    /// HEAD the URL; when `etag` is given it is sent as `If-None-Match` so
    /// the server can answer 304 for an unchanged resource.
    fn head(&self, url: &str, etag: Option<&str>) -> Result<HeadMeta> {
        let mut req = self.client.head(url);
        if let Some(e) = etag {
            // ETags are stored unquoted; re-add the quotes on the wire.
            req = req.header(IF_NONE_MATCH, format!("\"{}\"", e));
        }
        let resp = self.auth(req).send()?;
        let status = resp.status().as_u16();
        if status == 304 {
            // Not modified: echo the caller's etag back; length/last-modified
            // are not re-sent by the server in this case.
            return Ok(HeadMeta {
                len: None,
                etag: etag.map(|s| s.to_string()),
                last_modified: None,
                accept_ranges: true,
                not_modified: true,
                status,
            });
        }
        let len = resp
            .headers()
            .get(CONTENT_LENGTH)
            .and_then(|v| v.to_str().ok())
            .and_then(|s| s.parse::<u64>().ok());
        // Strip the surrounding quotes so stored revisions compare cleanly.
        let etag = resp
            .headers()
            .get(ETAG)
            .and_then(|v| v.to_str().ok())
            .map(|s| s.trim_matches('"').to_string());
        let last_modified = resp
            .headers()
            .get(LAST_MODIFIED)
            .and_then(|v| v.to_str().ok())
            .map(|s| s.to_string());
        let accept_ranges = resp
            .headers()
            .get(ACCEPT_RANGES)
            .and_then(|v| v.to_str().ok())
            .map(|s| s.to_ascii_lowercase().contains("bytes"))
            .unwrap_or(false);
        Ok(HeadMeta {
            len,
            etag,
            last_modified,
            accept_ranges,
            not_modified: false,
            status,
        })
    }

    /// GET a closed byte range into memory. Accepts 206 (Partial Content)
    /// as well as plain success statuses.
    fn get_range(&self, url: &str, start: u64, end_inclusive: u64) -> Result<Vec<u8>> {
        let range_val = format!("bytes={}-{}", start, end_inclusive);
        let resp = self
            .auth(self.client.get(url))
            .header(RANGE, range_val)
            .send()?;
        if !resp.status().is_success() && resp.status().as_u16() != 206 {
            return Err(anyhow!("HTTP {} for ranged GET", resp.status()).into());
        }
        let mut buf = Vec::new();
        let mut r = resp;
        r.copy_to(&mut buf)?;
        Ok(buf)
    }

    /// Stream the entire body into `writer`.
    fn get_whole_to(&self, url: &str, writer: &mut dyn Write) -> Result<()> {
        let resp = self.auth(self.client.get(url)).send()?;
        if !resp.status().is_success() {
            return Err(anyhow!("HTTP {} for GET", resp.status()).into());
        }
        let mut r = resp;
        r.copy_to(writer)?;
        Ok(())
    }

    /// Stream the body from byte offset `start` onward into `writer`;
    /// `start == 0` sends no Range header at all.
    fn get_from_to(&self, url: &str, start: u64, writer: &mut dyn Write) -> Result<()> {
        let mut req = self.auth(self.client.get(url));
        if start > 0 {
            req = req.header(RANGE, format!("bytes={}-", start));
        }
        let resp = req.send()?;
        if !resp.status().is_success() && resp.status().as_u16() != 206 {
            return Err(anyhow!("HTTP {} for ranged GET from {}", resp.status(), start).into());
        }
        let mut r = resp;
        r.copy_to(writer)?;
        Ok(())
    }
}
|
||||
|
||||
/// Downloads, verifies, and tracks Whisper model files, persisting a JSON
/// manifest of installed models. Generic over the HTTP client so the
/// network layer can be mocked in tests.
pub struct ModelManager<C: HttpClient = ReqwestClient> {
    pub paths: Paths,
    pub settings: Settings,
    // Shared with download worker threads.
    client: Arc<C>,
}
|
||||
|
||||
impl<C: HttpClient + 'static> ModelManager<C> {
    /// Build a manager around an explicit HTTP client (used by tests).
    pub fn new_with_client(client: C, settings: Settings) -> Result<Self> {
        Ok(Self {
            paths: Paths::resolve()?,
            settings,
            client: Arc::new(client),
        })
    }

    /// Build a manager with the client type's default constructor.
    pub fn new(settings: Settings) -> Result<Self>
    where
        C: Default,
    {
        Ok(Self {
            paths: Paths::resolve()?,
            settings,
            client: Arc::new(C::default()),
        })
    }

    // Load the manifest; a missing file is treated as an empty manifest.
    fn load_manifest(&self) -> Result<Manifest> {
        let p = &self.paths.config_path;
        if !p.exists() {
            return Ok(Manifest::default());
        }
        let file = File::open(p).with_context(|| format!("open manifest {}", p.display()))?;
        let m: Manifest = serde_json::from_reader(file).context("parse manifest")?;
        Ok(m)
    }

    // Persist the manifest atomically: write to a sibling temp file, then
    // rename over the real path so readers never see a half-written file.
    fn save_manifest(&self, m: &Manifest) -> Result<()> {
        let p = &self.paths.config_path;
        if let Some(dir) = p.parent() {
            fs::create_dir_all(dir)
                .with_context(|| format!("create config dir {}", dir.display()))?;
        }
        let tmp = p.with_extension("json.tmp");
        let f = OpenOptions::new()
            .create(true)
            .write(true)
            .truncate(true)
            .open(&tmp)?;
        serde_json::to_writer_pretty(f, m).context("serialize manifest")?;
        fs::rename(&tmp, p).with_context(|| format!("atomic rename {} -> {}", tmp.display(), p.display()))?;
        Ok(())
    }

    // Cache path for a given model file name.
    fn model_path(&self, file: &str) -> PathBuf {
        self.paths.cache_dir.join(file)
    }

    // Hex-encoded SHA-256 of a file, streamed in 64 KiB chunks.
    fn compute_sha256(path: &Path) -> Result<String> {
        let mut f = File::open(path)?;
        let mut hasher = Sha256::new();
        let mut buf = [0u8; 64 * 1024];
        loop {
            let n = f.read(&mut buf)?;
            if n == 0 {
                break;
            }
            hasher.update(&buf[..n]);
        }
        Ok(format!("{:x}", hasher.finalize()))
    }

    /// List all models recorded in the manifest.
    pub fn ls(&self) -> Result<Vec<ModelRecord>> {
        let m = self.load_manifest()?;
        Ok(m.models.values().cloned().collect())
    }

    /// Remove a model by alias: deletes its cached file (best effort) and
    /// its manifest entry. Returns `false` when the alias is unknown.
    pub fn rm(&self, alias: &str) -> Result<bool> {
        let mut m = self.load_manifest()?;
        if let Some(rec) = m.models.remove(alias) {
            let p = self.model_path(&rec.file);
            let _ = fs::remove_file(&p);
            self.save_manifest(&m)?;
            return Ok(true);
        }
        Ok(false)
    }

    /// Verify an installed model: the cached file must exist and, when a
    /// sha256 is recorded, match it. An unknown alias yields `false`; a
    /// present file without a recorded hash counts as valid.
    pub fn verify(&self, alias: &str) -> Result<bool> {
        let m = self.load_manifest()?;
        let Some(rec) = m.models.get(alias) else { return Ok(false) };
        let p = self.model_path(&rec.file);
        if !p.exists() { return Ok(false); }
        if let Some(expected) = &rec.sha256 {
            let actual = Self::compute_sha256(&p)?;
            return Ok(&actual == expected);
        }
        Ok(true)
    }

    /// Garbage-collect the cache. Returns
    /// `(files removed from disk, manifest entries dropped)`.
    pub fn gc(&self) -> Result<(usize, usize)> {
        // Remove files not referenced by manifest; also drop manifest entries whose file is missing
        fs::create_dir_all(&self.paths.cache_dir).ok();
        let mut m = self.load_manifest()?;
        // Map file name -> alias for every manifest entry.
        let mut referenced = BTreeMap::new();
        for (alias, rec) in &m.models {
            referenced.insert(rec.file.clone(), alias.clone());
        }
        let mut removed_files = 0usize;
        if let Ok(rd) = fs::read_dir(&self.paths.cache_dir) {
            for ent in rd.flatten() {
                let p = ent.path();
                if p.is_file() {
                    let fname = p.file_name().and_then(|s| s.to_str()).unwrap_or("");
                    // Unreferenced cache file: delete it (best effort).
                    if !referenced.contains_key(fname) {
                        let _ = fs::remove_file(&p);
                        removed_files += 1;
                    }
                }
            }
        }

        // Drop manifest entries whose backing file is gone; count them
        // before saving.
        m.models.retain(|_, rec| self.model_path(&rec.file).exists());
        let removed_entries = referenced
            .keys()
            .filter(|f| !self.model_path(f).exists())
            .count();

        self.save_manifest(&m)?;
        Ok((removed_files, removed_entries))
    }

    /// Download (or refresh) `file` from Hugging Face repo `repo` under the
    /// given `alias`, verifying integrity and updating the manifest.
    ///
    /// Flow: conditional HEAD (ETag) -> skip when unchanged; otherwise a
    /// concurrent ranged download into `<file>.part` (falling back to a
    /// streaming GET), sha256 verification against the HF API metadata
    /// (or the previous record), then an atomic rename into the cache.
    ///
    /// # Errors
    /// Fails on HTTP errors, a missing Content-Length, I/O failures, or a
    /// sha256 mismatch.
    pub fn add_or_update(
        &self,
        alias: &str,
        repo: &str,
        file: &str,
    ) -> Result<ModelRecord> {
        fs::create_dir_all(&self.paths.cache_dir)
            .with_context(|| format!("create cache dir {}", self.paths.cache_dir.display()))?;

        let url = format!("https://huggingface.co/{}/resolve/main/{}?download=true", repo, file);

        let mut manifest = self.load_manifest()?;
        let prev = manifest.models.get(alias).cloned();
        let prev_etag = prev.as_ref().and_then(|r| r.revision.clone());
        // Fetch remote meta (size/hash) when available to verify the download
        let (_api_size, api_sha) = hf_fetch_file_meta(repo, file).unwrap_or((None, None));
        let head = self.client.head(&url, prev_etag.as_deref())?;
        if head.not_modified {
            // up-to-date; ensure record present and touch last_used
            let mut rec = prev.ok_or_else(|| anyhow!("not installed yet but server says 304"))?;
            rec.last_used = Some(Utc::now());
            self.save_touch(&mut manifest, rec.clone())?;
            return Ok(rec);
        }

        // Quick check: if HEAD failed (e.g., 404), report a helpful error before attempting download
        if head.status >= 400 {
            return Err(anyhow!(
                "file not found or inaccessible: repo='{}' file='{}' (HTTP {})\nHint: run `polyscribe models search {} --query {}` to list available files",
                repo,
                file,
                head.status,
                repo,
                file
            ).into());
        }

        let total_len = head.len.ok_or_else(|| anyhow!("missing content-length (HEAD)"))?;
        let etag = head.etag.clone();
        let dest_tmp = self.model_path(&format!("{}.part", file));
        // If a previous cancelled download left a .part file, remove it to avoid clutter/resume.
        if dest_tmp.exists() { let _ = fs::remove_file(&dest_tmp); }
        // Guard to ensure .part is cleaned up on errors
        struct TempGuard { path: PathBuf, armed: bool }
        impl TempGuard { fn disarm(&mut self) { self.armed = false; } }
        impl Drop for TempGuard {
            fn drop(&mut self) {
                if self.armed {
                    let _ = fs::remove_file(&self.path);
                }
            }
        }
        let mut _tmp_guard = TempGuard { path: dest_tmp.clone(), armed: true };
        let dest_final = self.model_path(file);

        // Do not resume after cancellation; start fresh to avoid stale .part files
        let start_from = 0u64;

        // Open tmp for write
        let f = OpenOptions::new().create(true).write(true).read(true).open(&dest_tmp)?;
        f.set_len(total_len)?; // pre-allocate for random writes
        let f = Arc::new(Mutex::new(f));

        // Create progress bar
        let mut progress = BytesProgress::start(total_len, &format!("Downloading {}", file), start_from);

        // Create work chunks
        let chunk_size = self.settings.chunk_size;
        let mut chunks = Vec::new();
        let mut pos = start_from;
        while pos < total_len {
            let end = min(total_len - 1, pos + chunk_size - 1);
            chunks.push((pos, end));
            pos = end + 1;
        }

        // Attempt concurrent ranged download; on failure, fallback to streaming GET
        let mut ranged_failed = false;
        if head.accept_ranges && self.settings.concurrency > 1 {
            // Work queue of (start, end) ranges plus a progress channel the
            // main thread drains to drive the progress bar.
            let (work_tx, work_rx) = mpsc::channel::<(u64, u64)>();
            let (prog_tx, prog_rx) = mpsc::channel::<u64>();
            for ch in chunks {
                work_tx.send(ch).unwrap();
            }
            drop(work_tx);
            let rx = Arc::new(Mutex::new(work_rx));

            let workers = self.settings.concurrency.max(1);
            let mut handles = Vec::new();
            for _ in 0..workers {
                let rx = rx.clone();
                let url = url.clone();
                let client = self.client.clone();
                let f = f.clone();
                let limit = self.settings.limit_rate;
                let prog_tx = prog_tx.clone();
                let handle = thread::spawn(move || -> Result<()> {
                    loop {
                        // Pull the next range; lock is released before the
                        // network call so workers fetch concurrently.
                        let next = {
                            let guard = rx.lock().unwrap();
                            guard.recv().ok()
                        };
                        let Some((start, end)) = next else { break; };
                        let data = client.get_range(&url, start, end)?;
                        // Crude rate limiting: sleep proportionally to the
                        // bytes just transferred.
                        if let Some(max_bps) = limit {
                            let dur = Duration::from_secs_f64((data.len() as f64) / (max_bps as f64));
                            if dur > Duration::from_millis(1) {
                                thread::sleep(dur);
                            }
                        }
                        // Write the chunk at its absolute offset in the
                        // pre-allocated file.
                        let mut guard = f.lock().unwrap();
                        guard.seek(SeekFrom::Start(start))?;
                        guard.write_all(&data)?;
                        let _ = prog_tx.send(data.len() as u64);
                    }
                    Ok(())
                });
                handles.push(handle);
            }
            // Drop the main thread's sender so prog_rx terminates once all
            // workers finish.
            drop(prog_tx);

            for delta in prog_rx {
                progress.inc(delta);
            }

            let mut ranged_err: Option<anyhow::Error> = None;
            for h in handles {
                match h.join() {
                    Ok(Ok(())) => {}
                    Ok(Err(e)) => { ranged_failed = true; if ranged_err.is_none() { ranged_err = Some(e.into()); } }
                    Err(_) => { ranged_failed = true; if ranged_err.is_none() { ranged_err = Some(anyhow!("worker panicked")); } }
                }
            }
        } else {
            // No range support (or single-worker config): use streaming path.
            ranged_failed = true;
        }

        if ranged_failed {
            // Restart progress if we are abandoning previous partial state
            if start_from > 0 {
                progress.stop("retrying as streaming");
                progress = BytesProgress::start(total_len, &format!("Downloading {}", file), 0);
            }
            // Fallback to streaming GET; try URL with and without ?download=true
            let mut try_urls = vec![url.clone()];
            if let Some((base, _qs)) = url.split_once('?') {
                try_urls.push(base.to_string());
            } else {
                try_urls.push(format!("{}?download=true", url));
            }
            // Fresh temp file for streaming
            let mut wf = OpenOptions::new().create(true).write(true).truncate(true).open(&dest_tmp)?;
            let mut ok = false;
            let mut last_err: Option<anyhow::Error> = None;
            // Counting writer to update progress inline
            struct CountingWriter<'a, 'b> {
                inner: &'a mut File,
                progress: &'b mut BytesProgress,
            }
            impl<'a, 'b> Write for CountingWriter<'a, 'b> {
                fn write(&mut self, buf: &[u8]) -> std::io::Result<usize> {
                    let n = self.inner.write(buf)?;
                    self.progress.inc(n as u64);
                    Ok(n)
                }
                fn flush(&mut self) -> std::io::Result<()> { self.inner.flush() }
            }
            let mut cw = CountingWriter { inner: &mut wf, progress: &mut progress };
            for u in try_urls {
                // For fallback, stream from scratch to ensure integrity
                let res = self.client.get_whole_to(&u, &mut cw);
                match res {
                    Ok(()) => { ok = true; break; }
                    Err(e) => { last_err = Some(e.into()); }
                }
            }
            if !ok {
                if let Some(e) = last_err { return Err(e.into()); }
                return Err(anyhow!("download failed (ranged and streaming)").into());
            }
        }

        progress.stop("download complete");

        // Verify integrity: prefer the hash reported by the HF API, else
        // fall back to the hash recorded for a previous install of the
        // same file.
        let sha = Self::compute_sha256(&dest_tmp)?;
        if let Some(expected) = api_sha.as_ref() {
            if &sha != expected {
                return Err(anyhow!(
                    "sha256 mismatch (expected {}, got {})",
                    expected,
                    sha
                ).into());
            }
        } else if prev.as_ref().map(|r| r.file.eq(file)).unwrap_or(false) {
            if let Some(expected) = prev.as_ref().and_then(|r| r.sha256.as_ref()) {
                if &sha != expected {
                    return Err(anyhow!("sha256 mismatch").into());
                }
            }
        }
        // Atomic rename
        fs::rename(&dest_tmp, &dest_final).with_context(|| format!("rename {} -> {}", dest_tmp.display(), dest_final.display()))?;
        // Disarm guard; .part has been moved or cleaned
        _tmp_guard.disarm();

        let rec = ModelRecord {
            alias: alias.to_string(),
            repo: repo.to_string(),
            file: file.to_string(),
            revision: etag,
            sha256: Some(sha.clone()),
            size_bytes: Some(total_len),
            quant: infer_quant(file),
            installed_at: Some(Utc::now()),
            last_used: Some(Utc::now()),
        };
        self.save_touch(&mut manifest, rec.clone())?;
        Ok(rec)
    }

    // Insert/replace the record under its alias and persist the manifest.
    fn save_touch(&self, manifest: &mut Manifest, rec: ModelRecord) -> Result<()> {
        manifest.models.insert(rec.alias.clone(), rec);
        self.save_manifest(manifest)
    }
}
|
||||
|
||||
/// Best-effort extraction of a quantization token (e.g. `Q5_K_M`) from a
/// model file name. Scans from the first `Q` in the upper-cased name and
/// keeps uppercase letters, digits, `_`, and `-`; returns `None` when no
/// token of at least two characters is found.
fn infer_quant(file: &str) -> Option<String> {
    let upper = file.to_ascii_uppercase();
    let pos = upper.find('Q')?;
    let token: String = upper[pos..]
        .chars()
        .take_while(|c| c.is_ascii_uppercase() || c.is_ascii_digit() || matches!(*c, '_' | '-'))
        .collect();
    if token.len() >= 2 { Some(token) } else { None }
}
|
||||
|
||||
impl Default for ReqwestClient {
    /// Construct with [`ReqwestClient::new`].
    ///
    /// Panics when the underlying reqwest client cannot be built (e.g. TLS
    /// backend initialization failure) — acceptable here since no download
    /// can proceed without a client.
    fn default() -> Self {
        Self::new().expect("reqwest client")
    }
}
|
||||
|
||||
// Hugging Face API types for file metadata
|
||||
// Hugging Face API types for file metadata

// LFS pointer info for a repo file; `oid` is typically "sha256:<hex>".
#[derive(Debug, Deserialize)]
struct ApiHfLfs {
    oid: Option<String>,
    size: Option<u64>,
    sha256: Option<String>,
}

// One file entry in a repo listing; large files carry their size/hash in
// the nested `lfs` object instead of the top-level fields.
#[derive(Debug, Deserialize)]
struct ApiHfFile {
    rfilename: String,
    size: Option<u64>,
    sha256: Option<String>,
    lfs: Option<ApiHfLfs>,
}

// Repo info response; depending on API endpoint/expansion the file list is
// under `siblings` or `files`.
#[derive(Debug, Deserialize)]
struct ApiHfModelInfo {
    siblings: Option<Vec<ApiHfFile>>,
    files: Option<Vec<ApiHfFile>>,
}
|
||||
|
||||
fn pick_sha_from_file(f: &ApiHfFile) -> Option<String> {
|
||||
if let Some(s) = &f.sha256 { return Some(s.to_string()); }
|
||||
if let Some(l) = &f.lfs {
|
||||
if let Some(s) = &l.sha256 { return Some(s.to_string()); }
|
||||
if let Some(oid) = &l.oid { return oid.strip_prefix("sha256:").map(|s| s.to_string()); }
|
||||
}
|
||||
None
|
||||
}
|
||||
|
||||
// Query the HF model API for a single file's metadata. Tries the plain
// endpoint first, then the `?expand=files` variant; within a listing, the
// match is on the final path component, case-insensitively.
fn hf_fetch_file_meta(repo: &str, target: &str) -> Result<(Option<u64>, Option<String>)> {
    // Optional bearer token for private/gated repos.
    let token = std::env::var("HF_TOKEN").ok().filter(|s| !s.is_empty());
    let client = Client::builder()
        .user_agent(crate::config::ConfigService::user_agent())
        .build()?;
    let base = crate::config::ConfigService::hf_api_base_for(repo);
    let urls = [base.clone(), format!("{}?expand=files", base)];
    for url in urls {
        let mut req = client.get(&url);
        if let Some(t) = &token { req = req.header(AUTHORIZATION, format!("Bearer {}", t)); }
        let resp = req.send()?;
        // Non-success (404, auth failure) falls through to the next URL.
        if !resp.status().is_success() { continue; }
        let info: ApiHfModelInfo = resp.json()?;
        // `files` comes from expand=files; `siblings` from the plain endpoint.
        let list = info.files.or(info.siblings).unwrap_or_default();
        for f in list {
            // Compare bare file names, ignoring any repo subdirectory prefix.
            let name = f.rfilename.rsplit('/').next().unwrap_or(&f.rfilename);
            if name.eq_ignore_ascii_case(target) {
                // Size may live on the entry itself or on its LFS object.
                let sz = f.size.or_else(|| f.lfs.as_ref().and_then(|l| l.size));
                let sha = pick_sha_from_file(&f);
                return Ok((sz, sha));
            }
        }
    }
    Err(anyhow!("file not found in HF API").into())
}
|
||||
|
||||
/// Fetch remote metadata (size, sha256) for a single file in a HF repo.
///
/// Thin public wrapper over `hf_fetch_file_meta`; errors if the file is
/// not present in the repo listing.
pub fn fetch_file_meta(repo: &str, file: &str) -> Result<(Option<u64>, Option<String>)> {
    hf_fetch_file_meta(repo, file)
}
|
||||
|
||||
/// Search a Hugging Face repo for GGUF/BIN files via API. Returns file names only.
|
||||
pub fn search_repo(repo: &str, query: Option<&str>) -> Result<Vec<String>> {
|
||||
let token = std::env::var("HF_TOKEN").ok().filter(|s| !s.is_empty());
|
||||
let client = Client::builder()
|
||||
.user_agent(crate::config::ConfigService::user_agent())
|
||||
.build()?;
|
||||
|
||||
let base = crate::config::ConfigService::hf_api_base_for(repo);
|
||||
let mut urls = vec![base.clone(), format!("{}?expand=files", base)];
|
||||
let mut files = Vec::<String>::new();
|
||||
|
||||
for url in urls.drain(..) {
|
||||
let mut req = client.get(&url);
|
||||
if let Some(t) = &token { req = req.header(AUTHORIZATION, format!("Bearer {}", t)); }
|
||||
let resp = req.send()?;
|
||||
if !resp.status().is_success() { continue; }
|
||||
let info: ApiHfModelInfo = resp.json()?;
|
||||
let list = info.files.or(info.siblings).unwrap_or_default();
|
||||
for f in list {
|
||||
if f.rfilename.ends_with(".gguf") || f.rfilename.ends_with(".bin") {
|
||||
let name = f.rfilename.rsplit('/').next().unwrap_or(&f.rfilename).to_string();
|
||||
if !files.contains(&name) { files.push(name); }
|
||||
}
|
||||
}
|
||||
if !files.is_empty() { break; }
|
||||
}
|
||||
|
||||
if let Some(q) = query { let qlc = q.to_ascii_lowercase(); files.retain(|f| f.to_ascii_lowercase().contains(&qlc)); }
|
||||
files.sort();
|
||||
Ok(files)
|
||||
}
|
||||
|
||||
/// List repo files with optional size metadata for GGUF/BIN entries.
///
/// Tries the plain endpoint, then `?expand=files`; returns the first
/// non-empty listing, or an empty `Vec` when neither endpoint yields
/// matching files.
pub fn list_repo_files_with_meta(repo: &str) -> Result<Vec<(String, Option<u64>)>> {
    // Optional bearer token for private/gated repos.
    let token = std::env::var("HF_TOKEN").ok().filter(|s| !s.is_empty());
    let client = Client::builder()
        .user_agent(crate::config::ConfigService::user_agent())
        .build()?;

    let base = crate::config::ConfigService::hf_api_base_for(repo);
    for url in [base.clone(), format!("{}?expand=files", base)] {
        let mut req = client.get(&url);
        if let Some(t) = &token { req = req.header(AUTHORIZATION, format!("Bearer {}", t)); }
        let resp = req.send()?;
        // Failed endpoints are skipped rather than aborting the whole call.
        if !resp.status().is_success() { continue; }
        let info: ApiHfModelInfo = resp.json()?;
        let list = info.files.or(info.siblings).unwrap_or_default();
        let mut out = Vec::new();
        for f in list {
            // Only model weight files are of interest.
            if !(f.rfilename.ends_with(".gguf") || f.rfilename.ends_with(".bin")) { continue; }
            let name = f.rfilename.rsplit('/').next().unwrap_or(&f.rfilename).to_string();
            // Size may live on the entry itself or on its LFS object.
            let size = f.size.or_else(|| f.lfs.as_ref().and_then(|l| l.size));
            out.push((name, size));
        }
        if !out.is_empty() { return Ok(out); }
    }
    Ok(Vec::new())
}
|
||||
|
||||
/// Fallback: HEAD request for a single file to retrieve Content-Length (size).
|
||||
pub fn head_len_for_file(repo: &str, file: &str) -> Option<u64> {
|
||||
let token = std::env::var("HF_TOKEN").ok().filter(|s| !s.is_empty());
|
||||
let client = Client::builder()
|
||||
.user_agent(crate::config::ConfigService::user_agent())
|
||||
.build().ok()?;
|
||||
let mut urls = Vec::new();
|
||||
urls.push(format!("https://huggingface.co/{}/resolve/main/{}?download=true", repo, file));
|
||||
urls.push(format!("https://huggingface.co/{}/resolve/main/{}", repo, file));
|
||||
for url in urls {
|
||||
let mut req = client.head(&url);
|
||||
if let Some(t) = &token { req = req.header(AUTHORIZATION, format!("Bearer {}", t)); }
|
||||
if let Ok(resp) = req.send() {
|
||||
if resp.status().is_success() {
|
||||
if let Some(len) = resp.headers().get(CONTENT_LENGTH)
|
||||
.and_then(|v| v.to_str().ok())
|
||||
.and_then(|s| s.parse::<u64>().ok())
|
||||
{ return Some(len); }
|
||||
}
|
||||
}
|
||||
}
|
||||
None
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use std::env;
|
||||
use tempfile::TempDir;
|
||||
|
||||
// In-memory HttpClient stub: serves `data` with a fixed etag, optionally
// advertising HTTP range support. No network access.
#[derive(Clone)]
struct StubHttp {
    // Payload returned by every GET variant.
    data: Arc<Vec<u8>>,
    // Etag reported by HEAD; a matching caller etag yields 304 semantics.
    etag: Arc<Option<String>>,
    // Whether HEAD advertises range support.
    accept_ranges: bool,
}
|
||||
impl HttpClient for StubHttp {
    // HEAD: reports the stub payload's length/etag; simulates a 304
    // Not Modified when the caller's etag matches the stub's.
    fn head(&self, _url: &str, etag: Option<&str>) -> Result<HeadMeta> {
        let not_modified = etag.is_some() && self.etag.as_ref().as_deref() == etag;
        Ok(HeadMeta {
            len: Some(self.data.len() as u64),
            etag: self.etag.as_ref().clone(),
            last_modified: None,
            accept_ranges: self.accept_ranges,
            not_modified,
            status: if not_modified { 304 } else { 200 },
        })
    }
    // Range GET over the in-memory payload; `end_inclusive` follows HTTP
    // range semantics. Panics on out-of-range input (fine for a test stub).
    fn get_range(&self, _url: &str, start: u64, end_inclusive: u64) -> Result<Vec<u8>> {
        let s = start as usize;
        let e = (end_inclusive as usize) + 1;
        Ok(self.data[s..e].to_vec())
    }

    // Full-body GET streamed into `writer`.
    fn get_whole_to(&self, _url: &str, writer: &mut dyn Write) -> Result<()> {
        writer.write_all(&self.data)?;
        Ok(())
    }

    // Resume-style GET: everything from `start` to the end of the payload.
    fn get_from_to(&self, _url: &str, start: u64, writer: &mut dyn Write) -> Result<()> {
        let s = start as usize;
        writer.write_all(&self.data[s..])?;
        Ok(())
    }
}
|
||||
|
||||
// Point PolyScribe's cache/config lookups at the given temp locations.
// `set_var` is wrapped in `unsafe` because it mutates process-global state;
// acceptable here since these tests own their environment.
// NOTE(review): env mutation is racy if tests run in parallel — confirm
// these tests are serialized or use distinct variables.
fn setup_env(cache: &Path, cfg: &Path) {
    unsafe {
        env::set_var("POLYSCRIBE_CACHE_DIR", cache.to_string_lossy().to_string());
        env::set_var(
            "POLYSCRIBE_CONFIG_DIR",
            cfg.parent().unwrap().to_string_lossy().to_string(),
        );
    }
}
|
||||
|
||||
#[test]
fn test_manifest_roundtrip() {
    // Fresh cache/config dirs so the manifest starts empty.
    let temp = TempDir::new().unwrap();
    let cache = temp.path().join("cache");
    let cfg = temp.path().join("config").join("models.json");
    setup_env(&cache, &cfg);

    // The stub client is unused for manifest I/O but required by the manager.
    let client = StubHttp {
        data: Arc::new(vec![0u8; 1024]),
        etag: Arc::new(Some("etag123".into())),
        accept_ranges: true,
    };
    let mm: ModelManager<StubHttp> = ModelManager::new_with_client(client, Settings::default()).unwrap();
    let m = mm.load_manifest().unwrap();
    assert!(m.models.is_empty());

    let rec = ModelRecord {
        alias: "tiny".into(),
        repo: "foo/bar".into(),
        file: "gguf-tiny.bin".into(),
        revision: Some("etag123".into()),
        sha256: None,
        size_bytes: None,
        quant: None,
        installed_at: None,
        last_used: None,
    };

    // Saving via save_touch must persist to disk: a reload sees the record.
    let mut m2 = Manifest::default();
    mm.save_touch(&mut m2, rec.clone()).unwrap();
    let m3 = mm.load_manifest().unwrap();
    assert!(m3.models.contains_key("tiny"));
}
|
||||
|
||||
#[test]
fn test_add_verify_update_gc() {
    let temp = TempDir::new().unwrap();
    let cache = temp.path().join("cache");
    let cfg_dir = temp.path().join("config");
    let cfg = cfg_dir.join("models.json");
    setup_env(&cache, &cfg);

    // 4 MiB of deterministic data so hashing/verification is reproducible.
    let data = (0..1024 * 1024u32).flat_map(|i| i.to_le_bytes()).collect::<Vec<u8>>();
    let etag = Some("abc123".to_string());
    let client = StubHttp { data: Arc::new(data), etag: Arc::new(etag), accept_ranges: true };
    let mm: ModelManager<StubHttp> = ModelManager::new_with_client(client.clone(), Settings{ concurrency: 3, ..Default::default() }).unwrap();

    // add: first install downloads via the stub and must pass verification
    let rec = mm.add_or_update("tiny", "gguf/models", "gguf-tiny-q4_0.bin").unwrap();
    assert_eq!(rec.alias, "tiny");
    assert!(mm.verify("tiny").unwrap());

    // update (304): matching etag means no re-download, same record back
    let rec2 = mm.add_or_update("tiny", "gguf/models", "gguf-tiny-q4_0.bin").unwrap();
    assert_eq!(rec2.alias, "tiny");

    // gc (nothing to remove)
    let (files_removed, entries_removed) = mm.gc().unwrap();
    assert_eq!(files_removed, 0);
    assert_eq!(entries_removed, 0);

    // rm: first removal succeeds, second reports nothing left to remove
    assert!(mm.rm("tiny").unwrap());
    assert!(!mm.rm("tiny").unwrap());
}
|
||||
}
|
1237
crates/polyscribe-core/src/models.rs
Normal file
1237
crates/polyscribe-core/src/models.rs
Normal file
File diff suppressed because it is too large
Load Diff
7
crates/polyscribe-core/src/prelude.rs
Normal file
7
crates/polyscribe-core/src/prelude.rs
Normal file
@@ -0,0 +1,7 @@
|
||||
// Convenience prelude: flattens the crate's most-used items into a single
// glob import for downstream code.
pub use crate::backend::*;
pub use crate::config::*;
pub use crate::error::Error;
pub use crate::models::*;
pub use crate::ui::*;

/// Crate-wide result type defaulting the error to the crate's own `Error`.
pub type Result<T, E = Error> = std::result::Result<T, E>;
|
329
crates/polyscribe-core/src/ui.rs
Normal file
329
crates/polyscribe-core/src/ui.rs
Normal file
@@ -0,0 +1,329 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
pub mod progress;
|
||||
|
||||
use std::io;
|
||||
use std::io::IsTerminal;
|
||||
use std::io::Write as _;
|
||||
use std::time::{Duration, Instant};
|
||||
|
||||
pub fn info(msg: impl AsRef<str>) {
|
||||
let m = msg.as_ref();
|
||||
let _ = cliclack::log::info(m);
|
||||
}
|
||||
|
||||
pub fn warn(msg: impl AsRef<str>) {
|
||||
let m = msg.as_ref();
|
||||
let _ = cliclack::log::warning(m);
|
||||
}
|
||||
|
||||
pub fn error(msg: impl AsRef<str>) {
|
||||
let m = msg.as_ref();
|
||||
let _ = cliclack::log::error(m);
|
||||
}
|
||||
|
||||
pub fn success(msg: impl AsRef<str>) {
|
||||
let m = msg.as_ref();
|
||||
let _ = cliclack::log::success(m);
|
||||
}
|
||||
|
||||
/// Render a boxed note with a title (`prompt`) and body (`message`).
pub fn note(prompt: impl AsRef<str>, message: impl AsRef<str>) {
    let _ = cliclack::note(prompt.as_ref(), message.as_ref());
}
|
||||
|
||||
/// Print the opening banner of a cliclack session.
pub fn intro(title: impl AsRef<str>) {
    let _ = cliclack::intro(title.as_ref());
}
|
||||
|
||||
/// Print the closing line of a cliclack session.
pub fn outro(msg: impl AsRef<str>) {
    let _ = cliclack::outro(msg.as_ref());
}
|
||||
|
||||
/// Emit a log line intended to appear above any active progress bars.
/// Currently delegates to the plain info logger.
pub fn println_above_bars(line: impl AsRef<str>) {
    let _ = cliclack::log::info(line.as_ref());
}
|
||||
|
||||
pub fn prompt_input(prompt: &str, default: Option<&str>) -> io::Result<String> {
|
||||
if crate::is_no_interaction() || !crate::stdin_is_tty() {
|
||||
return Ok(default.unwrap_or("").to_string());
|
||||
}
|
||||
let mut q = cliclack::input(prompt);
|
||||
if let Some(def) = default {
|
||||
q = q.default_input(def);
|
||||
}
|
||||
q.interact().map_err(|e| io::Error::other(e.to_string()))
|
||||
}
|
||||
|
||||
pub fn prompt_select(prompt: &str, items: &[&str]) -> io::Result<usize> {
|
||||
if crate::is_no_interaction() || !crate::stdin_is_tty() {
|
||||
return Err(io::Error::other("interactive prompt disabled"));
|
||||
}
|
||||
let mut sel = cliclack::select::<usize>(prompt);
|
||||
for (idx, label) in items.iter().enumerate() {
|
||||
sel = sel.item(idx, *label, "");
|
||||
}
|
||||
sel.interact().map_err(|e| io::Error::other(e.to_string()))
|
||||
}
|
||||
|
||||
pub fn prompt_multi_select(
|
||||
prompt: &str,
|
||||
items: &[&str],
|
||||
defaults: Option<&[bool]>,
|
||||
) -> io::Result<Vec<usize>> {
|
||||
if crate::is_no_interaction() || !crate::stdin_is_tty() {
|
||||
return Err(io::Error::other("interactive prompt disabled"));
|
||||
}
|
||||
let mut ms = cliclack::multiselect::<usize>(prompt);
|
||||
for (idx, label) in items.iter().enumerate() {
|
||||
ms = ms.item(idx, *label, "");
|
||||
}
|
||||
if let Some(def) = defaults {
|
||||
let selected: Vec<usize> = def
|
||||
.iter()
|
||||
.enumerate()
|
||||
.filter_map(|(i, &on)| if on { Some(i) } else { None })
|
||||
.collect();
|
||||
if !selected.is_empty() {
|
||||
ms = ms.initial_values(selected);
|
||||
}
|
||||
}
|
||||
ms.interact().map_err(|e| io::Error::other(e.to_string()))
|
||||
}
|
||||
|
||||
pub fn prompt_confirm(prompt: &str, default: bool) -> io::Result<bool> {
|
||||
if crate::is_no_interaction() || !crate::stdin_is_tty() {
|
||||
return Ok(default);
|
||||
}
|
||||
let mut q = cliclack::confirm(prompt);
|
||||
q.interact().map_err(|e| io::Error::other(e.to_string()))
|
||||
}
|
||||
|
||||
pub fn prompt_password(prompt: &str) -> io::Result<String> {
|
||||
if crate::is_no_interaction() || !crate::stdin_is_tty() {
|
||||
return Err(io::Error::other(
|
||||
"password prompt disabled in non-interactive mode",
|
||||
));
|
||||
}
|
||||
let mut q = cliclack::password(prompt);
|
||||
q.interact().map_err(|e| io::Error::other(e.to_string()))
|
||||
}
|
||||
|
||||
pub fn prompt_input_validated<F>(
|
||||
prompt: &str,
|
||||
default: Option<&str>,
|
||||
validate: F,
|
||||
) -> io::Result<String>
|
||||
where
|
||||
F: Fn(&str) -> Result<(), String> + 'static,
|
||||
{
|
||||
if crate::is_no_interaction() || !crate::stdin_is_tty() {
|
||||
if let Some(def) = default {
|
||||
return Ok(def.to_string());
|
||||
}
|
||||
return Err(io::Error::other("interactive prompt disabled"));
|
||||
}
|
||||
let mut q = cliclack::input(prompt);
|
||||
if let Some(def) = default {
|
||||
q = q.default_input(def);
|
||||
}
|
||||
q.validate(move |s: &String| validate(s))
|
||||
.interact()
|
||||
.map_err(|e| io::Error::other(e.to_string()))
|
||||
}
|
||||
|
||||
/// Thin wrapper around a cliclack spinner with quiet/no-progress fallbacks.
pub struct Spinner(cliclack::ProgressBar);
|
||||
|
||||
impl Spinner {
    /// Start a spinner with the given label.
    ///
    /// When progress output is suppressed (no-progress/no-interaction flags,
    /// or stderr is not a terminal) the label is logged once instead and the
    /// spinner object is created but never started.
    pub fn start(text: impl AsRef<str>) -> Self {
        if crate::is_no_progress() || crate::is_no_interaction() || !std::io::stderr().is_terminal()
        {
            let _ = cliclack::log::info(text.as_ref());
            let s = cliclack::spinner();
            Self(s)
        } else {
            let s = cliclack::spinner();
            s.start(text.as_ref());
            Self(s)
        }
    }
    /// Stop the spinner, showing `text` (logged instead in quiet mode).
    // NOTE(review): the quiet checks below test only `is_no_progress`, while
    // `start` also branched on interaction/TTY — a spinner created (but not
    // started) in those paths still gets stop()/error() called; confirm
    // cliclack tolerates stopping an unstarted spinner.
    pub fn stop(self, text: impl AsRef<str>) {
        let s = self.0;
        if crate::is_no_progress() {
            let _ = cliclack::log::info(text.as_ref());
        } else {
            s.stop(text.as_ref());
        }
    }
    /// Finish successfully; logs a success line in quiet mode.
    pub fn success(self, text: impl AsRef<str>) {
        let s = self.0;
        if crate::is_no_progress() {
            let _ = cliclack::log::success(text.as_ref());
        } else {
            s.stop(text.as_ref());
        }
    }
    /// Finish with an error; logs an error line in quiet mode.
    pub fn error(self, text: impl AsRef<str>) {
        let s = self.0;
        if crate::is_no_progress() {
            let _ = cliclack::log::error(text.as_ref());
        } else {
            s.error(text.as_ref());
        }
    }
}
|
||||
|
||||
/// Hand-rolled byte-oriented progress bar drawn directly to stderr
/// (elapsed, bar, ETA, byte counts, speed), throttled to ~10 redraws/second.
pub struct BytesProgress {
    // False when progress output is suppressed (flags, non-TTY, zero total);
    // then every draw degrades to a one-off log line at start/stop.
    enabled: bool,
    // Total expected bytes (non-zero whenever `enabled`).
    total: u64,
    // Bytes observed so far; clamped to `total`.
    current: u64,
    // Wall-clock start, for the elapsed display and average-speed estimate.
    started: Instant,
    // Time of the last redraw, used to throttle terminal writes.
    last_msg: Instant,
    // Bar width in characters (min 10 enforced at draw time).
    width: usize,
    // Sticky ETA to carry through zero-speed stalls
    last_eta_secs: Option<f64>,
}
|
||||
|
||||
impl BytesProgress {
    /// Begin tracking `total` bytes, with `initial` bytes already
    /// transferred (e.g. a resumed download). Falls back to a single log
    /// line when progress rendering is disabled or `total` is unknown (0).
    pub fn start(total: u64, text: &str, initial: u64) -> Self {
        let enabled = !(crate::is_no_progress()
            || crate::is_no_interaction()
            || !std::io::stderr().is_terminal()
            || total == 0);
        if !enabled {
            let _ = cliclack::log::info(text);
        }
        let mut me = Self {
            enabled,
            total,
            current: initial.min(total),
            started: Instant::now(),
            last_msg: Instant::now(),
            width: 40,
            last_eta_secs: None,
        };
        me.draw();
        me
    }

    // Binary (1024-based) human-readable byte formatting.
    fn human_bytes(n: u64) -> String {
        const KB: f64 = 1024.0;
        const MB: f64 = 1024.0 * KB;
        const GB: f64 = 1024.0 * MB;
        let x = n as f64;
        if x >= GB {
            format!("{:.2} GiB", x / GB)
        } else if x >= MB {
            format!("{:.2} MiB", x / MB)
        } else if x >= KB {
            format!("{:.2} KiB", x / KB)
        } else {
            format!("{} B", n)
        }
    }

    // Elapsed formatting is used for stable, finite durations. For ETA, we guard
    // against zero-speed or unstable estimates separately via `format_eta`.

    // Throttle redraws to one per 100ms and compute (average speed, ETA).
    // When throttled, returns speed 0.0 and reuses the last known ETA.
    fn refresh_allowed(&mut self) -> (f64, f64) {
        let now = Instant::now();
        let since_last = now.duration_since(self.last_msg);
        if since_last < Duration::from_millis(100) {
            // Too soon to refresh; keep previous ETA if any
            let eta = self.last_eta_secs.unwrap_or(f64::INFINITY);
            return (0.0, eta);
        }
        self.last_msg = now;
        // Clamp elapsed away from zero to avoid division blow-ups on the
        // very first refresh.
        let elapsed = now.duration_since(self.started).as_secs_f64().max(0.001);
        let speed = (self.current as f64) / elapsed;
        let remaining = self.total.saturating_sub(self.current) as f64;

        // If speed is effectively zero, carry ETA forward and add wall time.
        const EPS: f64 = 1e-6;
        let eta = if speed <= EPS {
            let prev = self.last_eta_secs.unwrap_or(f64::INFINITY);
            if prev.is_finite() {
                prev + since_last.as_secs_f64()
            } else {
                prev
            }
        } else {
            remaining / speed
        };
        // Remember only finite ETAs to use during stalls
        if eta.is_finite() {
            self.last_eta_secs = Some(eta);
        }
        (speed, eta)
    }

    // [HH:]MM:SS; the hours field is omitted when zero.
    fn format_elapsed(seconds: f64) -> String {
        let total = seconds.round() as u64;
        let h = total / 3600;
        let m = (total % 3600) / 60;
        let s = total % 60;
        if h > 0 { format!("{:02}:{:02}:{:02}", h, m, s) } else { format!("{:02}:{:02}", m, s) }
    }

    fn format_eta(seconds: f64) -> String {
        // If ETA is not finite (e.g., divide-by-zero speed) or unreasonably large,
        // show a placeholder rather than overflowing into huge values.
        if !seconds.is_finite() {
            return "—".to_string();
        }
        // Cap ETA display to 99:59:59 to avoid silly numbers; beyond that, show placeholder.
        const CAP_SECS: f64 = 99.0 * 3600.0 + 59.0 * 60.0 + 59.0;
        if seconds > CAP_SECS {
            return "—".to_string();
        }
        Self::format_elapsed(seconds)
    }

    // Repaint the bar in place on stderr using `\r` + ANSI clear-to-EOL.
    // NOTE(review): a draw inside the 100ms throttle window displays 0 B/s
    // (refresh_allowed returns speed 0.0 then); confirm this is acceptable.
    fn draw(&mut self) {
        if !self.enabled { return; }
        let (speed, eta) = self.refresh_allowed();
        let elapsed = Instant::now().duration_since(self.started).as_secs_f64();
        // Build bar
        let width = self.width.max(10);
        let filled = ((self.current as f64 / self.total.max(1) as f64) * width as f64).round() as usize;
        let filled = filled.min(width);
        let mut bar = String::with_capacity(width);
        for _ in 0..filled { bar.push('■'); }
        for _ in filled..width { bar.push('□'); }

        let line = format!(
            "[{}] {} [{}] ({}/{} at {}/s)",
            Self::format_elapsed(elapsed),
            bar,
            Self::format_eta(eta),
            Self::human_bytes(self.current),
            Self::human_bytes(self.total),
            Self::human_bytes(speed.max(0.0) as u64),
        );
        eprint!("\r{}\x1b[K", line);
        let _ = io::stderr().flush();
    }

    /// Record `delta` more bytes (clamped to `total`) and repaint.
    pub fn inc(&mut self, delta: u64) {
        self.current = self.current.saturating_add(delta).min(self.total);
        self.draw();
    }

    /// Final repaint plus newline; logs `text` instead when disabled.
    pub fn stop(mut self, text: &str) {
        if self.enabled {
            self.draw();
            eprintln!();
        } else {
            let _ = cliclack::log::info(text);
        }
    }

    /// Like `stop`, but emits `text` as an error line in both modes.
    pub fn error(mut self, text: &str) {
        if self.enabled {
            self.draw();
            eprintln!();
            let _ = cliclack::log::error(text);
        } else {
            let _ = cliclack::log::error(text);
        }
    }
}
|
122
crates/polyscribe-core/src/ui/progress.rs
Normal file
122
crates/polyscribe-core/src/ui/progress.rs
Normal file
@@ -0,0 +1,122 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
use std::io::IsTerminal as _;
|
||||
|
||||
/// Multi-file progress display: one cliclack bar per file plus an overall
/// "Total" bar that ticks once per completed file.
pub struct FileProgress {
    // Master switch; when false every method is a no-op.
    enabled: bool,
    // One bar per file, indexed by the caller's file index.
    file_bars: Vec<cliclack::ProgressBar>,
    // Aggregate bar; present only after init_files with >1 label.
    total_bar: Option<cliclack::ProgressBar>,
    // Number of files marked done so far.
    completed: usize,
    // Total number of files registered via `init_files`.
    total_file_count: usize,
}
|
||||
|
||||
impl FileProgress {
    /// Create with an explicit on/off switch; bars are added via `init_files`.
    pub fn new(enabled: bool) -> Self {
        Self {
            enabled,
            file_bars: Vec::new(),
            total_bar: None,
            completed: 0,
            total_file_count: 0,
        }
    }

    /// Enable only for multi-file runs on a real terminal, and only when
    /// neither quiet mode nor no-progress mode is active.
    pub fn default_for_files(file_count: usize) -> Self {
        let enabled = file_count > 1
            && std::io::stderr().is_terminal()
            && !crate::is_quiet()
            && !crate::is_no_progress();
        Self::new(enabled)
    }

    /// Create one bar per label plus the aggregate "Total" bar.
    /// Disables itself for zero/one label even if previously enabled.
    pub fn init_files(&mut self, labels: &[String]) {
        self.total_file_count = labels.len();
        if !self.enabled || labels.len() <= 1 {
            self.enabled = false;
            return;
        }
        // Total bar ticks once per completed file.
        let total = cliclack::progress_bar(labels.len() as u64);
        total.start("Total");
        self.total_bar = Some(total);
        for label in labels {
            let pb = cliclack::progress_bar(100);
            pb.start(label);
            self.file_bars.push(pb);
        }
    }

    /// Whether bars are actually being rendered.
    pub fn is_enabled(&self) -> bool {
        self.enabled
    }

    /// Replace the status text of the bar at `idx` (out-of-range ignored).
    pub fn set_file_message(&mut self, idx: usize, message: &str) {
        if !self.enabled {
            return;
        }
        if let Some(pb) = self.file_bars.get_mut(idx) {
            pb.set_message(message);
        }
    }

    /// Show a percentage for file `idx`, clamped to 100.
    // NOTE(review): only the message text is updated; the bar's position is
    // never advanced toward its 100-step total — confirm this rendering is
    // intended.
    pub fn set_file_percent(&mut self, idx: usize, percent: u64) {
        if !self.enabled {
            return;
        }
        if let Some(pb) = self.file_bars.get_mut(idx) {
            let p = percent.min(100);
            pb.set_message(format!("{p}%"));
        }
    }

    /// Finish the bar at `idx`, advance the total bar, and close the total
    /// bar once every registered file has completed.
    pub fn mark_file_done(&mut self, idx: usize) {
        if !self.enabled {
            return;
        }
        if let Some(pb) = self.file_bars.get_mut(idx) {
            pb.stop("done");
        }
        self.completed += 1;
        if let Some(total) = &mut self.total_bar {
            total.inc(1);
            if self.completed >= self.total_file_count {
                total.stop("all done");
            }
        }
    }

    /// Close the total bar early with a custom message.
    pub fn finish_total(&mut self, message: &str) {
        if !self.enabled {
            return;
        }
        if let Some(total) = &mut self.total_bar {
            total.stop(message);
        }
    }
}
|
||||
|
||||
/// Minimal step/finish logger that adapts its prefix markers to
/// non-interactive output (plain ASCII instead of Unicode glyphs).
#[derive(Debug)]
pub struct ProgressReporter {
    // When true, use "[..]"/"[ok]" prefixes instead of "•"/"✓".
    non_interactive: bool,
}
|
||||
|
||||
impl ProgressReporter {
|
||||
pub fn new(non_interactive: bool) -> Self {
|
||||
Self { non_interactive }
|
||||
}
|
||||
|
||||
pub fn step(&mut self, message: &str) {
|
||||
if self.non_interactive {
|
||||
let _ = cliclack::log::info(format!("[..] {message}"));
|
||||
} else {
|
||||
let _ = cliclack::log::info(format!("• {message}"));
|
||||
}
|
||||
}
|
||||
|
||||
pub fn finish_with_message(&mut self, message: &str) {
|
||||
if self.non_interactive {
|
||||
let _ = cliclack::log::info(format!("[ok] {message}"));
|
||||
} else {
|
||||
let _ = cliclack::log::info(format!("✓ {message}"));
|
||||
}
|
||||
}
|
||||
}
|
12
crates/polyscribe-host/Cargo.toml
Normal file
12
crates/polyscribe-host/Cargo.toml
Normal file
@@ -0,0 +1,12 @@
|
||||
[package]
|
||||
name = "polyscribe-host"
|
||||
version.workspace = true
|
||||
edition.workspace = true
|
||||
|
||||
[dependencies]
|
||||
anyhow = { workspace = true }
|
||||
serde = { workspace = true, features = ["derive"] }
|
||||
serde_json = { workspace = true }
|
||||
tokio = { workspace = true, features = ["rt-multi-thread", "process", "io-util"] }
|
||||
which = { workspace = true }
|
||||
directories = { workspace = true }
|
119
crates/polyscribe-host/src/lib.rs
Normal file
119
crates/polyscribe-host/src/lib.rs
Normal file
@@ -0,0 +1,119 @@
|
||||
use anyhow::{Context, Result};
|
||||
use std::process::Stdio;
|
||||
use std::{
|
||||
env, fs,
|
||||
os::unix::fs::PermissionsExt,
|
||||
path::Path,
|
||||
};
|
||||
use tokio::{
|
||||
io::{AsyncBufReadExt, BufReader},
|
||||
process::{Child as TokioChild, Command},
|
||||
};
|
||||
|
||||
/// A discovered plugin executable.
#[derive(Debug, Clone)]
pub struct PluginInfo {
    /// Plugin name with the "polyscribe-plugin-" prefix stripped.
    pub name: String,
    /// Full path to the executable.
    pub path: String,
}
|
||||
|
||||
/// Discovers and launches external `polyscribe-plugin-*` executables.
/// Stateless; all lookups happen per call.
#[derive(Debug, Default)]
pub struct PluginManager;
|
||||
|
||||
impl PluginManager {
    /// Discover plugins: executables named `polyscribe-plugin-*` found on
    /// PATH and in the per-user data directory, deduplicated by path.
    pub fn list(&self) -> Result<Vec<PluginInfo>> {
        let mut plugins = Vec::new();

        if let Ok(path) = env::var("PATH") {
            for dir in env::split_paths(&path) {
                scan_dir_for_plugins(&dir, &mut plugins);
            }
        }

        // Per-user plugin dir, e.g. ~/.local/share/polyscribe/plugins.
        if let Some(dirs) = directories::ProjectDirs::from("dev", "polyscribe", "polyscribe") {
            let plugin_dir = dirs.data_dir().join("plugins");
            scan_dir_for_plugins(&plugin_dir, &mut plugins);
        }

        // dedup_by removes only consecutive duplicates, so sort by path first.
        plugins.sort_by(|a, b| a.path.cmp(&b.path));
        plugins.dedup_by(|a, b| a.path == b.path);
        Ok(plugins)
    }

    /// Run `<plugin> info` synchronously and parse its stdout as JSON.
    pub fn info(&self, name: &str) -> Result<serde_json::Value> {
        let bin = self.resolve(name)?;
        // stderr is inherited so plugin diagnostics reach the user directly.
        let out = std::process::Command::new(&bin)
            .arg("info")
            .stdout(Stdio::piped())
            .stderr(Stdio::inherit())
            .spawn()
            .context("spawning plugin info")?
            .wait_with_output()
            .context("waiting for plugin info")?;

        let val: serde_json::Value =
            serde_json::from_slice(&out.stdout).context("parsing plugin info JSON")?;
        Ok(val)
    }

    /// Spawn `<plugin> run <command>` with piped stdin/stdout for the
    /// host<->plugin protocol; the caller drives the child process.
    pub fn spawn(&self, name: &str, command: &str) -> Result<TokioChild> {
        let bin = self.resolve(name)?;
        let mut cmd = Command::new(&bin);
        cmd.arg("run")
            .arg(command)
            .stdin(Stdio::piped())
            .stdout(Stdio::piped())
            .stderr(Stdio::inherit());
        let child = cmd.spawn().context("spawning plugin run")?;
        Ok(child)
    }

    /// Mirror the child's stdout to ours line-by-line, then reap the child
    /// and return its exit status.
    pub async fn forward_stdio(&self, child: &mut TokioChild) -> Result<std::process::ExitStatus> {
        if let Some(stdout) = child.stdout.take() {
            let mut reader = BufReader::new(stdout).lines();
            while let Some(line) = reader.next_line().await? {
                println!("{line}");
            }
        }
        Ok(child.wait().await?)
    }

    /// Map a plugin name to its `polyscribe-plugin-<name>` binary on PATH.
    fn resolve(&self, name: &str) -> Result<String> {
        let bin = format!("polyscribe-plugin-{name}");
        let path =
            which::which(&bin).with_context(|| format!("plugin not found in PATH: {bin}"))?;
        Ok(path.to_string_lossy().to_string())
    }
}
|
||||
|
||||
/// Whether `path` points to a regular file we may execute.
///
/// On Unix this checks for any execute bit (owner/group/other). If the
/// file's metadata cannot be read, the file is treated as NOT executable —
/// the previous version fell through and reported `true` for unreadable
/// candidates. On non-Unix targets every regular file is considered
/// executable.
fn is_executable(path: &Path) -> bool {
    if !path.is_file() {
        return false;
    }
    #[cfg(unix)]
    {
        match fs::metadata(path) {
            Ok(meta) => meta.permissions().mode() & 0o111 != 0,
            // Unreadable metadata: be conservative and skip the candidate.
            Err(_) => false,
        }
    }
    #[cfg(not(unix))]
    {
        true
    }
}
|
||||
|
||||
fn scan_dir_for_plugins(dir: &Path, out: &mut Vec<PluginInfo>) {
|
||||
if let Ok(read_dir) = fs::read_dir(dir) {
|
||||
for entry in read_dir.flatten() {
|
||||
let path = entry.path();
|
||||
if let Some(fname) = path.file_name().and_then(|s| s.to_str())
|
||||
&& fname.starts_with("polyscribe-plugin-")
|
||||
&& is_executable(&path)
|
||||
{
|
||||
let name = fname.trim_start_matches("polyscribe-plugin-").to_string();
|
||||
out.push(PluginInfo {
|
||||
name,
|
||||
path: path.to_string_lossy().to_string(),
|
||||
});
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
8
crates/polyscribe-protocol/Cargo.toml
Normal file
8
crates/polyscribe-protocol/Cargo.toml
Normal file
@@ -0,0 +1,8 @@
|
||||
[package]
|
||||
name = "polyscribe-protocol"
|
||||
version.workspace = true
|
||||
edition.workspace = true
|
||||
|
||||
[dependencies]
|
||||
serde = { workspace = true, features = ["derive"] }
|
||||
serde_json = { workspace = true }
|
60
crates/polyscribe-protocol/src/lib.rs
Normal file
60
crates/polyscribe-protocol/src/lib.rs
Normal file
@@ -0,0 +1,60 @@
|
||||
use serde::{Deserialize, Serialize};
|
||||
use serde_json::Value;
|
||||
|
||||
/// A JSON-RPC-style request exchanged between host and plugin.
#[derive(Debug, Serialize, Deserialize)]
pub struct Request {
    /// Correlation id echoed back in the matching [`Response`].
    pub id: String,
    /// Method name to invoke.
    pub method: String,
    /// Optional method-specific parameters.
    pub params: Option<Value>,
}
|
||||
|
||||
/// Reply to a [`Request`]; by convention exactly one of `result`/`error`
/// is set (see the `ok`/`err` constructors).
#[derive(Debug, Serialize, Deserialize)]
pub struct Response {
    /// Correlation id copied from the request.
    pub id: String,
    /// Success payload, when the call succeeded.
    pub result: Option<Value>,
    /// Error details, when the call failed.
    pub error: Option<ErrorObj>,
}
|
||||
|
||||
/// Structured error payload carried in a failed [`Response`].
#[derive(Debug, Serialize, Deserialize)]
pub struct ErrorObj {
    /// Numeric error code.
    pub code: i32,
    /// Human-readable description.
    pub message: String,
    /// Optional machine-readable details; omitted from JSON when absent.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub data: Option<Value>,
}
|
||||
|
||||
/// Streamed progress notification, serialized as
/// `{"event": "<variant>", "data": ...}` via the serde tag/content attrs.
#[derive(Debug, Serialize, Deserialize)]
#[serde(tag = "event", content = "data")]
pub enum ProgressEvent {
    /// Work has begun.
    Started,
    /// Free-form status text.
    Message(String),
    /// Completion percentage/fraction.
    Percent(f32),
    /// Work has finished.
    Finished,
}
|
||||
|
||||
impl Response {
|
||||
pub fn ok(id: impl Into<String>, result: Value) -> Self {
|
||||
Self {
|
||||
id: id.into(),
|
||||
result: Some(result),
|
||||
error: None,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn err(
|
||||
id: impl Into<String>,
|
||||
code: i32,
|
||||
message: impl Into<String>,
|
||||
data: Option<Value>,
|
||||
) -> Self {
|
||||
Self {
|
||||
id: id.into(),
|
||||
result: None,
|
||||
error: Some(ErrorObj {
|
||||
code,
|
||||
message: message.into(),
|
||||
data,
|
||||
}),
|
||||
}
|
||||
}
|
||||
}
|
31
docs/ci.md
31
docs/ci.md
@@ -1,31 +0,0 @@
|
||||
# CI checklist and job outline
|
||||
|
||||
Checklist to keep docs and code healthy in CI
|
||||
- Build: cargo build --all-targets --locked
|
||||
- Tests: cargo test --all --locked
|
||||
- Lints: cargo clippy --all-targets -- -D warnings
|
||||
- Optional: check README and docs snippets (basic smoke run of examples scripts)
|
||||
- bash examples/update_models.sh (can be skipped offline)
|
||||
- bash examples/transcribe_file.sh (use a tiny sample file if available)
|
||||
|
||||
Example GitHub Actions job (outline)
|
||||
- name: Rust
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
- uses: dtolnay/rust-toolchain@stable
|
||||
- name: Print resolved whisper-rs rev
|
||||
run: |
|
||||
echo "Resolved whisper-rs revision:" && \
|
||||
awk '/name = "whisper-rs"/{f=1} f&&/source = "git\+.*whisper-rs#/{match($0,/#([0-9a-f]{7,40})"/,m); if(m[1]){print m[1]; exit}}' Cargo.lock
|
||||
- name: Build
|
||||
run: cargo build --all-targets --locked
|
||||
- name: Test
|
||||
run: cargo test --all --locked
|
||||
- name: Clippy
|
||||
run: cargo clippy --all-targets -- -D warnings
|
||||
|
||||
Notes
|
||||
- For GPU features, set up appropriate runners and add `--features gpu-cuda|gpu-hip|gpu-vulkan` where applicable.
|
||||
- For docs-only changes, jobs still build/test to ensure doctests and examples compile when enabled.
|
||||
- Mark the CI job named `ci` as a required status check for the default branch in repository branch protection settings.
|
@@ -13,12 +13,6 @@ Rust toolchain
|
||||
- rustup install stable
|
||||
- rustup default stable
|
||||
|
||||
Dependency pinning
|
||||
- We pin whisper-rs (git dependency) to a known-good commit in Cargo.toml for reproducibility.
|
||||
- To bump it, resolve/test the desired commit locally, then run:
|
||||
- cargo update -p whisper-rs --precise 135b60b85a15714862806b6ea9f76abec38156f1
|
||||
Replace the SHA with the desired commit and update the rev in Cargo.toml accordingly.
|
||||
|
||||
Build
|
||||
- CPU-only (default):
|
||||
- cargo build
|
||||
@@ -38,28 +32,20 @@ Run locally
|
||||
|
||||
Models during development
|
||||
- Interactive downloader:
|
||||
- cargo run -- --download-models
|
||||
- cargo run -- models download
|
||||
- Non-interactive update (checks sizes/hashes, downloads if missing):
|
||||
- cargo run -- --update-models --no-interaction -q
|
||||
- cargo run -- models update --no-interaction -q
|
||||
|
||||
Tests
|
||||
- Run all tests:
|
||||
- cargo test
|
||||
- The test suite includes CLI-oriented integration tests and unit tests. Some tests simulate GPU detection using env vars (POLYSCRIBE_TEST_FORCE_*). Do not rely on these flags in production code.
|
||||
|
||||
Examples check (no network, non-interactive)
|
||||
- To quickly validate that example scripts are wired correctly (no prompts, quiet, exit 0), run:
|
||||
- make examples-check
|
||||
- What it does:
|
||||
- Iterates over examples/*.sh
|
||||
- Forces execution with --no-interaction and -q via a wrapper
|
||||
- Uses a stubbed BIN that performs no network access and exits successfully
|
||||
- Redirects stdin from /dev/null to ensure no prompts
|
||||
- This is intended for CI smoke checks and local verification; it does not actually download models or transcribe audio.
|
||||
|
||||
Clippy
|
||||
Clippy & formatting
|
||||
- Run lint checks and treat warnings as errors:
|
||||
- cargo clippy --all-targets -- -D warnings
|
||||
- Check formatting:
|
||||
- cargo fmt --all -- --check
|
||||
- Common warnings can often be fixed by simplifying code, removing unused imports, and following idiomatic patterns.
|
||||
|
||||
Code layout
|
||||
@@ -77,10 +63,10 @@ Adding a feature
|
||||
|
||||
Running the model downloader
|
||||
- Interactive:
|
||||
- cargo run -- --download-models
|
||||
- cargo run -- models download
|
||||
- Non-interactive suggestions for CI:
|
||||
- POLYSCRIBE_MODELS_DIR=$PWD/models \
|
||||
cargo run -- --update-models --no-interaction -q
|
||||
cargo run -- models update --no-interaction -q
|
||||
|
||||
Env var examples for local testing
|
||||
- Use a local copy of models and a specific model file:
|
||||
|
@@ -30,9 +30,10 @@ CLI reference
|
||||
- Choose runtime backend. Default is auto (prefers CUDA → HIP → Vulkan → CPU), depending on detection.
|
||||
- --gpu-layers N
|
||||
- Number of layers to offload to the GPU when supported.
|
||||
- --download-models
|
||||
- models download
|
||||
- Launch interactive model downloader (lists Hugging Face models; multi-select to download).
|
||||
- --update-models
|
||||
- Controls: Use Up/Down to navigate, Space to toggle selections, and Enter to confirm. Models are grouped by base (e.g., tiny, base, small).
|
||||
- models update
|
||||
- Verify/update local models by comparing sizes and hashes with the upstream manifest.
|
||||
- -v, --verbose (repeatable)
|
||||
- Increase log verbosity; use -vv for very detailed logs.
|
||||
@@ -41,6 +42,9 @@ CLI reference
|
||||
- --no-interaction
|
||||
- Disable all interactive prompts (for CI). Combine with env vars to control behavior.
|
||||
- Subcommands:
|
||||
- models download: Launch interactive model downloader.
|
||||
- models update: Verify/update local models (non-interactive).
|
||||
- plugins list|info|run: Discover and run plugins.
|
||||
- completions <shell>: Write shell completion script to stdout.
|
||||
- man: Write a man page to stdout.
|
||||
|
||||
|
@@ -1,13 +0,0 @@
|
||||
#!/usr/bin/env bash
|
||||
# SPDX-License-Identifier: MIT
|
||||
# Copyright (c) 2025 <COPYRIGHT HOLDER>. All rights reserved.
|
||||
set -euo pipefail
|
||||
|
||||
# Launch the interactive model downloader and select models to install
|
||||
|
||||
BIN=${BIN:-./target/release/polyscribe}
|
||||
MODELS_DIR=${POLYSCRIBE_MODELS_DIR:-$PWD/models}
|
||||
export POLYSCRIBE_MODELS_DIR="$MODELS_DIR"
|
||||
|
||||
mkdir -p "$MODELS_DIR"
|
||||
"$BIN" --download-models
|
@@ -1,15 +0,0 @@
|
||||
#!/usr/bin/env bash
|
||||
# SPDX-License-Identifier: MIT
|
||||
# Copyright (c) 2025 <COPYRIGHT HOLDER>. All rights reserved.
|
||||
set -euo pipefail
|
||||
|
||||
# Transcribe an audio/video file to JSON and SRT into ./output
|
||||
# Requires a model; first run may prompt to download.
|
||||
|
||||
BIN=${BIN:-./target/release/polyscribe}
|
||||
INPUT=${1:-samples/example.mp3}
|
||||
OUTDIR=${OUTDIR:-output}
|
||||
|
||||
mkdir -p "$OUTDIR"
|
||||
"$BIN" -v -o "$OUTDIR" "$INPUT"
|
||||
echo "Done. See $OUTDIR for JSON/SRT files."
|
@@ -1,15 +0,0 @@
|
||||
#!/usr/bin/env bash
|
||||
# SPDX-License-Identifier: MIT
|
||||
# Copyright (c) 2025 <COPYRIGHT HOLDER>. All rights reserved.
|
||||
set -euo pipefail
|
||||
|
||||
# Verify/update local models non-interactively (useful in CI)
|
||||
|
||||
BIN=${BIN:-./target/release/polyscribe}
|
||||
MODELS_DIR=${POLYSCRIBE_MODELS_DIR:-$PWD/models}
|
||||
export POLYSCRIBE_MODELS_DIR="$MODELS_DIR"
|
||||
|
||||
mkdir -p "$MODELS_DIR"
|
||||
"$BIN" --update-models --no-interaction -q
|
||||
|
||||
echo "Models updated in $MODELS_DIR"
|
17
plugins/polyscribe-plugin-tubescribe/Cargo.toml
Normal file
17
plugins/polyscribe-plugin-tubescribe/Cargo.toml
Normal file
@@ -0,0 +1,17 @@
|
||||
[package]
|
||||
name = "polyscribe-plugin-tubescribe"
|
||||
version = "0.1.0"
|
||||
edition = "2024"
|
||||
license = "MIT"
|
||||
|
||||
[[bin]]
|
||||
name = "polyscribe-plugin-tubescribe"
|
||||
path = "src/main.rs"
|
||||
|
||||
[dependencies]
|
||||
anyhow = "1.0.98"
|
||||
clap = { version = "4.5.43", features = ["derive"] }
|
||||
serde = { version = "1.0.219", features = ["derive"] }
|
||||
serde_json = "1.0.142"
|
||||
tokio = { version = "1", features = ["full"] }
|
||||
polyscribe-protocol = { path = "../../crates/polyscribe-protocol" }
|
18
plugins/polyscribe-plugin-tubescribe/Makefile
Normal file
18
plugins/polyscribe-plugin-tubescribe/Makefile
Normal file
@@ -0,0 +1,18 @@
|
||||
# Simple helper to build and link the plugin into the user's XDG data dir
|
||||
# Usage:
|
||||
# make build
|
||||
# make link
|
||||
|
||||
PLUGIN := polyscribe-plugin-tubescribe
|
||||
BIN := ../../target/release/$(PLUGIN)
|
||||
|
||||
.PHONY: build link
|
||||
|
||||
build:
|
||||
cargo build -p $(PLUGIN) --release
|
||||
|
||||
link: build
|
||||
@DATA_DIR=$${XDG_DATA_HOME:-$$HOME/.local/share}; \
|
||||
mkdir -p $$DATA_DIR/polyscribe/plugins; \
|
||||
ln -sf "$(CURDIR)/$(BIN)" $$DATA_DIR/polyscribe/plugins/$(PLUGIN); \
|
||||
echo "Linked: $$DATA_DIR/polyscribe/plugins/$(PLUGIN) -> $(CURDIR)/$(BIN)"
|
93
plugins/polyscribe-plugin-tubescribe/src/main.rs
Normal file
93
plugins/polyscribe-plugin-tubescribe/src/main.rs
Normal file
@@ -0,0 +1,93 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
use anyhow::{Context, Result};
|
||||
use clap::Parser;
|
||||
use polyscribe_protocol as psp;
|
||||
use serde_json::json;
|
||||
use std::io::{BufRead, BufReader, Write};
|
||||
|
||||
/// Command-line flags understood by the plugin binary.
///
/// With no flags, the binary prints its capabilities JSON (see `main`).
#[derive(Parser, Debug)]
#[command(name = "polyscribe-plugin-tubescribe", version, about = "Stub tubescribe plugin for PolyScribe PSP/1")]
struct Args {
    /// Print capabilities JSON and exit
    #[arg(long)]
    capabilities: bool,
    /// Serve mode: read one JSON-RPC request from stdin, stream progress and final result
    #[arg(long)]
    serve: bool,
}
|
||||
|
||||
fn main() -> Result<()> {
|
||||
let args = Args::parse();
|
||||
if args.capabilities {
|
||||
let caps = psp::Capabilities {
|
||||
name: "tubescribe".to_string(),
|
||||
version: env!("CARGO_PKG_VERSION").to_string(),
|
||||
protocol: "psp/1".to_string(),
|
||||
role: "pipeline".to_string(),
|
||||
commands: vec!["generate_metadata".to_string()],
|
||||
};
|
||||
let s = serde_json::to_string(&caps)?;
|
||||
println!("{}", s);
|
||||
return Ok(());
|
||||
}
|
||||
if args.serve {
|
||||
serve_once()?;
|
||||
return Ok(());
|
||||
}
|
||||
let caps = psp::Capabilities {
|
||||
name: "tubescribe".to_string(),
|
||||
version: env!("CARGO_PKG_VERSION").to_string(),
|
||||
protocol: "psp/1".to_string(),
|
||||
role: "pipeline".to_string(),
|
||||
commands: vec!["generate_metadata".to_string()],
|
||||
};
|
||||
println!("{}", serde_json::to_string(&caps)?);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Handle exactly one JSON-RPC request read from stdin.
///
/// Streams synthetic progress items (5% → 90%) as NDJSON on stdout, then a
/// final `ok` result for `generate_metadata`, or a JSON-RPC "method not
/// found" error (-32601) for any other method.
fn serve_once() -> Result<()> {
    let stdin = std::io::stdin();
    let mut reader = BufReader::new(stdin.lock());
    let mut line = String::new();
    // One request per invocation: read a single line of JSON-RPC.
    reader.read_line(&mut line).context("failed to read request line")?;
    let req: psp::JsonRpcRequest = serde_json::from_str(line.trim()).context("invalid JSON-RPC request")?;

    // Simulated pipeline stages; the sleeps exist only so that callers can
    // observe incremental streaming from this stub.
    emit(&psp::StreamItem::progress(5, Some("start".into()), Some("initializing".into())))?;
    std::thread::sleep(std::time::Duration::from_millis(50));
    emit(&psp::StreamItem::progress(25, Some("probe".into()), Some("probing sources".into())))?;
    std::thread::sleep(std::time::Duration::from_millis(50));
    emit(&psp::StreamItem::progress(60, Some("analyze".into()), Some("analyzing".into())))?;
    std::thread::sleep(std::time::Duration::from_millis(50));
    emit(&psp::StreamItem::progress(90, Some("finalize".into()), Some("finalizing".into())))?;

    let result = match req.method.as_str() {
        "generate_metadata" => {
            // Canned payload: this plugin is a protocol demonstration stub.
            let title = "Canned title";
            let description = "Canned description for demonstration";
            let tags = vec!["demo", "tubescribe", "polyscribe"];
            json!({
                "title": title,
                "description": description,
                "tags": tags,
            })
        }
        other => {
            // Unknown method: emit the standard JSON-RPC error and stop;
            // the transport-level Result is still Ok (the error was delivered).
            let err = psp::StreamItem::err(req.id.clone(), -32601, format!("Method not found: {}", other), None);
            emit(&err)?;
            return Ok(());
        }
    };

    emit(&psp::StreamItem::ok(req.id.clone(), result))?;
    Ok(())
}
|
||||
|
||||
fn emit(item: &psp::StreamItem) -> Result<()> {
|
||||
let mut stdout = std::io::stdout().lock();
|
||||
let s = serde_json::to_string(item)?;
|
||||
stdout.write_all(s.as_bytes())?;
|
||||
stdout.write_all(b"\n")?;
|
||||
stdout.flush()?;
|
||||
Ok(())
|
||||
}
|
6
rust-toolchain.toml
Normal file
6
rust-toolchain.toml
Normal file
@@ -0,0 +1,6 @@
|
||||
# SPDX-License-Identifier: MIT
|
||||
|
||||
[toolchain]
|
||||
channel = "1.89.0"
|
||||
components = ["clippy", "rustfmt"]
|
||||
profile = "minimal"
|
@@ -1,26 +0,0 @@
|
||||
#!/usr/bin/env bash
|
||||
# Lightweight stub for examples-check: simulates the PolyScribe CLI without I/O or network
|
||||
# - Accepts any arguments
|
||||
# - Exits 0
|
||||
# - Produces no output unless VERBOSE_STUB=1
|
||||
# - Never performs network operations
|
||||
# - Never reads from stdin
|
||||
set -euo pipefail
|
||||
|
||||
if [[ "${VERBOSE_STUB:-0}" == "1" ]]; then
|
||||
echo "[stub] polyscribe $*" 1>&2
|
||||
fi
|
||||
|
||||
# Behave quietly if -q/--quiet is present by default (no output)
|
||||
# Honor --help/-h: print minimal usage if verbose requested
|
||||
if [[ "${VERBOSE_STUB:-0}" == "1" ]]; then
|
||||
for arg in "$@"; do
|
||||
if [[ "$arg" == "-h" || "$arg" == "--help" ]]; then
|
||||
echo "PolyScribe stub: no-op (examples-check)" 1>&2
|
||||
break
|
||||
fi
|
||||
done
|
||||
fi
|
||||
|
||||
# Always succeed quietly
|
||||
exit 0
|
@@ -1,28 +0,0 @@
|
||||
#!/usr/bin/env bash
|
||||
# Wrapper that ensures --no-interaction -q are present, then delegates to the real BIN (stub by default)
|
||||
set -euo pipefail
|
||||
|
||||
REAL_BIN=${REAL_BIN:-"$(dirname "$0")/bin_stub.sh"}
|
||||
|
||||
# Append flags if not already present in args
|
||||
args=("$@")
|
||||
need_no_interaction=1
|
||||
need_quiet=1
|
||||
for a in "${args[@]}"; do
|
||||
[[ "$a" == "--no-interaction" ]] && need_no_interaction=0
|
||||
[[ "$a" == "-q" || "$a" == "--quiet" ]] && need_quiet=0
|
||||
done
|
||||
|
||||
if [[ $need_no_interaction -eq 1 ]]; then
|
||||
args=("--no-interaction" "${args[@]}")
|
||||
fi
|
||||
if [[ $need_quiet -eq 1 ]]; then
|
||||
args=("-q" "${args[@]}")
|
||||
fi
|
||||
|
||||
# Never read stdin; prevent accidental blocking by redirecting from /dev/null
|
||||
# Also advertise offline via env variables commonly checked by the app
|
||||
export CI=1
|
||||
export POLYSCRIBE_MODELS_BASE_COPY_DIR="${POLYSCRIBE_MODELS_BASE_COPY_DIR:-}" # leave empty by default
|
||||
|
||||
exec "$REAL_BIN" "${args[@]}" </dev/null
|
579
src/backend.rs
579
src/backend.rs
@@ -1,579 +0,0 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025 <COPYRIGHT HOLDER>. All rights reserved.
|
||||
|
||||
//! Transcription backend selection and implementations (CPU/GPU) used by PolyScribe.
|
||||
use crate::OutputEntry;
|
||||
use crate::progress::ProgressMessage;
|
||||
use crate::{decode_audio_to_pcm_f32_ffmpeg, find_model_file};
|
||||
use anyhow::{Context, Result, anyhow};
|
||||
use std::env;
|
||||
use std::path::Path;
|
||||
use std::sync::mpsc::Sender;
|
||||
|
||||
// Re-export a public enum for CLI parsing usage
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
/// Kind of transcription backend to use.
pub enum BackendKind {
    /// Automatically detect the best available backend (CUDA > HIP > Vulkan > CPU).
    Auto,
    /// Pure CPU backend using whisper-rs.
    Cpu,
    /// NVIDIA CUDA backend (requires CUDA runtime available at load time and proper feature build).
    Cuda,
    /// AMD ROCm/HIP backend (requires hip/rocBLAS libraries available and proper feature build).
    Hip,
    /// Vulkan backend (experimental; requires Vulkan loader/SDK and feature build).
    Vulkan,
}

/// Abstraction for a transcription backend implementation.
pub trait TranscribeBackend {
    /// Return the backend kind for this implementation.
    fn kind(&self) -> BackendKind;
    /// Transcribe the given audio file path and return transcript entries.
    ///
    /// Parameters:
    /// - audio_path: path to input media (audio or video) to be decoded/transcribed.
    /// - speaker: label to attach to all produced segments.
    /// - lang_opt: optional language hint (e.g., "en"); None means auto/multilingual model default.
    /// - progress_tx: optional channel for streaming progress updates back to the caller.
    /// - gpu_layers: optional GPU layer count if applicable (ignored by some backends).
    fn transcribe(
        &self,
        audio_path: &Path,
        speaker: &str,
        lang_opt: Option<&str>,
        progress_tx: Option<Sender<ProgressMessage>>,
        gpu_layers: Option<u32>,
    ) -> Result<Vec<OutputEntry>>;
}
|
||||
|
||||
/// Probe for the presence of native libraries by name.
///
/// Runtime dlopen probing was deliberately disabled (it destabilized the
/// loader in CI and tests), so this always reports "not found"; backend
/// availability is driven purely by the `POLYSCRIBE_TEST_FORCE_*`
/// environment overrides checked by the callers. The parameter is kept so
/// call sites still document which libraries *would* be probed if probing
/// is ever re-enabled.
fn check_lib(_names: &[&str]) -> bool {
    // The previous `cfg(test)` / `cfg(not(test))` branches both returned
    // `false`; the duplicated conditional compilation is collapsed into a
    // single constant result with identical behavior in every build.
    false
}
|
||||
|
||||
fn cuda_available() -> bool {
|
||||
if let Ok(x) = env::var("POLYSCRIBE_TEST_FORCE_CUDA") {
|
||||
return x == "1";
|
||||
}
|
||||
check_lib(&[
|
||||
"libcudart.so",
|
||||
"libcudart.so.12",
|
||||
"libcudart.so.11",
|
||||
"libcublas.so",
|
||||
"libcublas.so.12",
|
||||
])
|
||||
}
|
||||
|
||||
fn hip_available() -> bool {
|
||||
if let Ok(x) = env::var("POLYSCRIBE_TEST_FORCE_HIP") {
|
||||
return x == "1";
|
||||
}
|
||||
check_lib(&["libhipblas.so", "librocblas.so"])
|
||||
}
|
||||
|
||||
fn vulkan_available() -> bool {
|
||||
if let Ok(x) = env::var("POLYSCRIBE_TEST_FORCE_VULKAN") {
|
||||
return x == "1";
|
||||
}
|
||||
check_lib(&["libvulkan.so.1", "libvulkan.so"])
|
||||
}
|
||||
|
||||
/// CPU-based transcription backend using whisper-rs.
pub struct CpuBackend;
/// CUDA-accelerated transcription backend for NVIDIA GPUs.
pub struct CudaBackend;
/// ROCm/HIP-accelerated transcription backend for AMD GPUs.
pub struct HipBackend;
/// Vulkan-based transcription backend (experimental/incomplete).
pub struct VulkanBackend;

// All four backends are stateless unit structs; `new()` and `Default` exist
// so they can be constructed uniformly (e.g. boxed in `select_backend`).
impl CpuBackend {
    /// Create a new CPU backend instance.
    pub fn new() -> Self {
        CpuBackend
    }
}
impl Default for CpuBackend {
    fn default() -> Self {
        Self::new()
    }
}
impl CudaBackend {
    /// Create a new CUDA backend instance.
    pub fn new() -> Self {
        CudaBackend
    }
}
impl Default for CudaBackend {
    fn default() -> Self {
        Self::new()
    }
}
impl HipBackend {
    /// Create a new HIP backend instance.
    pub fn new() -> Self {
        HipBackend
    }
}
impl Default for HipBackend {
    fn default() -> Self {
        Self::new()
    }
}
impl VulkanBackend {
    /// Create a new Vulkan backend instance.
    pub fn new() -> Self {
        VulkanBackend
    }
}
impl Default for VulkanBackend {
    fn default() -> Self {
        Self::new()
    }
}
|
||||
|
||||
/// Validate that a provided language hint is compatible with the selected model.
|
||||
///
|
||||
/// English-only models (filenames containing ".en." or ending with ".en.bin") reject non-"en" hints.
|
||||
/// When no language is provided, this check passes and downstream behavior remains unchanged.
|
||||
pub(crate) fn validate_model_lang_compat(model: &Path, lang_opt: Option<&str>) -> Result<()> {
|
||||
let is_en_only = model
|
||||
.file_name()
|
||||
.and_then(|s| s.to_str())
|
||||
.map(|s| s.contains(".en.") || s.ends_with(".en.bin"))
|
||||
.unwrap_or(false);
|
||||
if let Some(lang) = lang_opt {
|
||||
if is_en_only && lang != "en" {
|
||||
return Err(anyhow!(
|
||||
"Selected model is English-only ({}), but a non-English language hint '{}' was provided. Please use a multilingual model or set WHISPER_MODEL.",
|
||||
model.display(),
|
||||
lang
|
||||
));
|
||||
}
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
// CPU, CUDA and HIP all delegate to the same whisper-rs code path: which
// device whisper.cpp actually uses is decided at build time by cargo
// features, not by these impls.
impl TranscribeBackend for CpuBackend {
    fn kind(&self) -> BackendKind {
        BackendKind::Cpu
    }
    fn transcribe(
        &self,
        audio_path: &Path,
        speaker: &str,
        lang_opt: Option<&str>,
        progress_tx: Option<Sender<ProgressMessage>>,
        _gpu_layers: Option<u32>, // not applicable on CPU
    ) -> Result<Vec<OutputEntry>> {
        transcribe_with_whisper_rs(audio_path, speaker, lang_opt, progress_tx)
    }
}

impl TranscribeBackend for CudaBackend {
    fn kind(&self) -> BackendKind {
        BackendKind::Cuda
    }
    fn transcribe(
        &self,
        audio_path: &Path,
        speaker: &str,
        lang_opt: Option<&str>,
        progress_tx: Option<Sender<ProgressMessage>>,
        _gpu_layers: Option<u32>,
    ) -> Result<Vec<OutputEntry>> {
        // whisper-rs uses enabled CUDA feature at build time; call same code path
        transcribe_with_whisper_rs(audio_path, speaker, lang_opt, progress_tx)
    }
}

impl TranscribeBackend for HipBackend {
    fn kind(&self) -> BackendKind {
        BackendKind::Hip
    }
    fn transcribe(
        &self,
        audio_path: &Path,
        speaker: &str,
        lang_opt: Option<&str>,
        progress_tx: Option<Sender<ProgressMessage>>,
        _gpu_layers: Option<u32>,
    ) -> Result<Vec<OutputEntry>> {
        transcribe_with_whisper_rs(audio_path, speaker, lang_opt, progress_tx)
    }
}

impl TranscribeBackend for VulkanBackend {
    fn kind(&self) -> BackendKind {
        BackendKind::Vulkan
    }
    // Vulkan is not wired to the whisper.cpp FFI yet; this impl always
    // fails with guidance rather than silently falling back to CPU.
    fn transcribe(
        &self,
        _audio_path: &Path,
        _speaker: &str,
        _lang_opt: Option<&str>,
        _progress_tx: Option<Sender<ProgressMessage>>,
        _gpu_layers: Option<u32>,
    ) -> Result<Vec<OutputEntry>> {
        Err(anyhow!(
            "Vulkan backend not yet wired to whisper.cpp FFI. Build with --features gpu-vulkan and ensure Vulkan SDK is installed. How to fix: install Vulkan loader (libvulkan), set VULKAN_SDK, and run cargo build --features gpu-vulkan."
        ))
    }
}
|
||||
|
||||
/// Result of choosing a transcription backend.
pub struct SelectionResult {
    /// The constructed backend instance to perform transcription with.
    pub backend: Box<dyn TranscribeBackend + Send + Sync>,
    /// Which backend kind was ultimately selected.
    pub chosen: BackendKind,
    /// Which backend kinds were detected as available on this system.
    /// May be empty; under `Auto` selection `chosen` is then `Cpu`.
    pub detected: Vec<BackendKind>,
}
||||
|
||||
/// Select an appropriate backend based on user request and system detection.
///
/// If `requested` is `BackendKind::Auto`, the function prefers CUDA, then HIP,
/// then Vulkan, falling back to CPU when no GPU backend is detected. When a
/// specific GPU backend is requested but unavailable, an error is returned with
/// guidance on how to enable it.
///
/// When `config.verbose >= 1` and not quiet, detection/selection info is
/// logged via `dlog!`.
pub fn select_backend(requested: BackendKind, config: &crate::Config) -> Result<SelectionResult> {
    // Detection order also encodes the Auto-selection preference order.
    let mut detected = Vec::new();
    if cuda_available() {
        detected.push(BackendKind::Cuda);
    }
    if hip_available() {
        detected.push(BackendKind::Hip);
    }
    if vulkan_available() {
        detected.push(BackendKind::Vulkan);
    }

    // Construct a boxed backend for a concrete kind.
    let mk = |k: BackendKind| -> Box<dyn TranscribeBackend + Send + Sync> {
        match k {
            BackendKind::Cpu => Box::new(CpuBackend::new()),
            BackendKind::Cuda => Box::new(CudaBackend::new()),
            BackendKind::Hip => Box::new(HipBackend::new()),
            BackendKind::Vulkan => Box::new(VulkanBackend::new()),
            BackendKind::Auto => Box::new(CpuBackend::new()), // will be replaced
        }
    };

    let chosen = match requested {
        // Auto: best detected GPU backend, else CPU.
        BackendKind::Auto => {
            if detected.contains(&BackendKind::Cuda) {
                BackendKind::Cuda
            } else if detected.contains(&BackendKind::Hip) {
                BackendKind::Hip
            } else if detected.contains(&BackendKind::Vulkan) {
                BackendKind::Vulkan
            } else {
                BackendKind::Cpu
            }
        }
        // Explicitly requested GPU backends must actually be detected;
        // otherwise fail with actionable guidance instead of degrading.
        BackendKind::Cuda => {
            if detected.contains(&BackendKind::Cuda) {
                BackendKind::Cuda
            } else {
                return Err(anyhow!(
                    "Requested CUDA backend but CUDA libraries/devices not detected. How to fix: install NVIDIA driver + CUDA toolkit, ensure libcudart/libcublas are in loader path, and build with --features gpu-cuda."
                ));
            }
        }
        BackendKind::Hip => {
            if detected.contains(&BackendKind::Hip) {
                BackendKind::Hip
            } else {
                return Err(anyhow!(
                    "Requested ROCm/HIP backend but libraries/devices not detected. How to fix: install ROCm hipBLAS/rocBLAS, ensure libs are in loader path, and build with --features gpu-hip."
                ));
            }
        }
        BackendKind::Vulkan => {
            if detected.contains(&BackendKind::Vulkan) {
                BackendKind::Vulkan
            } else {
                return Err(anyhow!(
                    "Requested Vulkan backend but libvulkan not detected. How to fix: install Vulkan loader/SDK and build with --features gpu-vulkan."
                ));
            }
        }
        // CPU is always available.
        BackendKind::Cpu => BackendKind::Cpu,
    };

    if config.verbose >= 1 && !config.quiet {
        crate::dlog!(1, "Detected backends: {:?}", detected);
        crate::dlog!(1, "Selected backend: {:?}", chosen);
    }

    Ok(SelectionResult {
        backend: mk(chosen),
        chosen,
        detected,
    })
}
|
||||
|
||||
/// Internal helper: transcription using whisper-rs with CPU/GPU (depending on build features).
///
/// Decodes `audio_path` to f32 PCM via ffmpeg, loads the configured Whisper
/// model, runs full inference, and converts each Whisper segment into an
/// `OutputEntry` labeled with `speaker`. Progress milestones (0.0, 0.05,
/// 0.15, 0.20, 1.0) are sent over `progress_tx` when provided; send failures
/// are deliberately ignored (progress reporting is best-effort).
#[allow(clippy::too_many_arguments)]
#[cfg(feature = "whisper")]
pub(crate) fn transcribe_with_whisper_rs(
    audio_path: &Path,
    speaker: &str,
    lang_opt: Option<&str>,
    progress_tx: Option<Sender<ProgressMessage>>,
) -> Result<Vec<OutputEntry>> {
    // initial progress
    if let Some(tx) = &progress_tx {
        let _ = tx.send(ProgressMessage {
            fraction: 0.0,
            stage: Some("load_model".to_string()),
            note: Some(format!("{}", audio_path.display())),
        });
    }
    let pcm = decode_audio_to_pcm_f32_ffmpeg(audio_path)?;
    let model = find_model_file()?;
    if let Some(tx) = &progress_tx {
        let _ = tx.send(ProgressMessage {
            fraction: 0.05,
            stage: Some("load_model".to_string()),
            note: Some("model selected".to_string()),
        });
    }
    // Validate language hint compatibility with the selected model
    validate_model_lang_compat(&model, lang_opt)?;
    let model_str = model
        .to_str()
        .ok_or_else(|| anyhow!("Model path not valid UTF-8: {}", model.display()))?;

    // Try to reduce native library logging via environment variables when not super-verbose.
    if crate::verbose_level() < 2 {
        // These env vars are recognized by ggml/whisper in many builds; harmless if unknown.
        // NOTE(review): set_var is process-global and unsafe in edition 2024;
        // assumes no concurrent env readers at this point — confirm.
        unsafe {
            std::env::set_var("GGML_LOG_LEVEL", "0");
            std::env::set_var("WHISPER_PRINT_PROGRESS", "0");
        }
    }

    // Suppress stderr from whisper/ggml during model load and inference when quiet and not verbose.
    let (_ctx, mut state) = crate::with_suppressed_stderr(|| {
        let cparams = whisper_rs::WhisperContextParameters::default();
        let ctx = whisper_rs::WhisperContext::new_with_params(model_str, cparams)
            .with_context(|| format!("Failed to load Whisper model at {}", model.display()))?;
        let state = ctx
            .create_state()
            .map_err(|e| anyhow!("Failed to create Whisper state: {:?}", e))?;
        // The context is returned (as _ctx) so it outlives the state borrowed from it.
        Ok::<_, anyhow::Error>((ctx, state))
    })?;
    if let Some(tx) = &progress_tx {
        let _ = tx.send(ProgressMessage {
            fraction: 0.15,
            stage: Some("encode".to_string()),
            note: Some("state ready".to_string()),
        });
    }

    let mut params =
        whisper_rs::FullParams::new(whisper_rs::SamplingStrategy::Greedy { best_of: 1 });
    // Use all available cores; fall back to a single thread if parallelism
    // cannot be queried.
    let n_threads = std::thread::available_parallelism()
        .map(|n| n.get() as i32)
        .unwrap_or(1);
    params.set_n_threads(n_threads);
    params.set_translate(false);
    if let Some(lang) = lang_opt {
        params.set_language(Some(lang));
    }

    if let Some(tx) = &progress_tx {
        let _ = tx.send(ProgressMessage {
            fraction: 0.20,
            stage: Some("decode".to_string()),
            note: Some("inference".to_string()),
        });
    }
    crate::with_suppressed_stderr(|| {
        state
            .full(params, &pcm)
            .map_err(|e| anyhow!("Whisper full() failed: {:?}", e))
    })?;
    if let Some(tx) = &progress_tx {
        let _ = tx.send(ProgressMessage {
            fraction: 1.0,
            stage: Some("done".to_string()),
            note: Some("transcription finished".to_string()),
        });
    }

    let num_segments = state
        .full_n_segments()
        .map_err(|e| anyhow!("Failed to get segments: {:?}", e))?;
    let mut items = Vec::new();
    for i in 0..num_segments {
        let text = state
            .full_get_segment_text(i)
            .map_err(|e| anyhow!("Failed to get segment text: {:?}", e))?;
        let t0 = state
            .full_get_segment_t0(i)
            .map_err(|e| anyhow!("Failed to get segment t0: {:?}", e))?;
        let t1 = state
            .full_get_segment_t1(i)
            .map_err(|e| anyhow!("Failed to get segment t1: {:?}", e))?;
        // Segment timestamps are scaled by 0.01 to seconds (centisecond ticks).
        let start = (t0 as f64) * 0.01;
        let end = (t1 as f64) * 0.01;
        items.push(OutputEntry {
            id: 0, // presumably (re)assigned by the caller when merging outputs — TODO confirm
            speaker: speaker.to_string(),
            start,
            end,
            text: text.trim().to_string(),
        });
    }
    Ok(items)
}
|
||||
|
||||
/// Stub used when the `whisper` feature is disabled.
///
/// Keeps call sites identical across feature configurations and fails at
/// runtime with actionable rebuild guidance instead of failing to compile.
#[allow(clippy::too_many_arguments)]
#[cfg(not(feature = "whisper"))]
pub(crate) fn transcribe_with_whisper_rs(
    _audio_path: &Path,
    _speaker: &str,
    _lang_opt: Option<&str>,
    _progress_tx: Option<Sender<ProgressMessage>>,
) -> Result<Vec<OutputEntry>> {
    Err(anyhow!(
        "Transcription requires the 'whisper' feature. Rebuild with --features whisper (and optional gpu-cuda/gpu-hip)."
    ))
}
|
||||
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use std::env as std_env;
|
||||
use std::sync::{Mutex, OnceLock};
|
||||
|
||||
#[test]
fn test_validate_model_lang_guard_table() {
    // One table row: model filename, optional language hint, expected outcome.
    // Renamed `case` -> `Case`: type names must be UpperCamelCase (rustc
    // `non_camel_case_types` warning, which fails `clippy -- -D warnings` CI).
    struct Case<'a> {
        model: &'a str,
        lang: Option<&'a str>,
        ok: bool,
    }
    let cases = vec![
        // English-only model with en hint: OK
        Case { model: "ggml-base.en.bin", lang: Some("en"), ok: true },
        // English-only model with de hint: Error
        Case { model: "ggml-small.en.bin", lang: Some("de"), ok: false },
        // Multilingual model with de hint: OK
        Case { model: "ggml-large-v3.bin", lang: Some("de"), ok: true },
        // No language provided (audio path scenario): guard should pass (existing behavior elsewhere)
        Case { model: "ggml-medium.en.bin", lang: None, ok: true },
    ];
    for c in cases {
        let p = std::path::Path::new(c.model);
        let res = validate_model_lang_compat(p, c.lang);
        match (c.ok, res) {
            (true, Ok(())) => {}
            (false, Err(e)) => {
                let msg = format!("{}", e);
                assert!(msg.contains("English-only"), "unexpected error: {msg}");
                if let Some(l) = c.lang {
                    assert!(msg.contains(l), "missing lang in msg: {msg}");
                }
            }
            (true, Err(e)) => panic!(
                "expected Ok for model={}, lang={:?}, got error: {}",
                c.model, c.lang, e
            ),
            (false, Ok(())) => panic!("expected Err for model={}, lang={:?}", c.model, c.lang),
        }
    }
}
|
||||
|
||||
// Serialize environment variable modifications across tests in this module
|
||||
static ENV_LOCK: OnceLock<Mutex<()>> = OnceLock::new();
|
||||
|
||||
#[test]
|
||||
fn test_select_backend_auto_prefers_cuda_then_hip_then_vulkan_then_cpu() {
|
||||
let _guard = ENV_LOCK.get_or_init(|| Mutex::new(())).lock().unwrap();
|
||||
// Clear overrides
|
||||
unsafe {
|
||||
std_env::remove_var("POLYSCRIBE_TEST_FORCE_CUDA");
|
||||
std_env::remove_var("POLYSCRIBE_TEST_FORCE_HIP");
|
||||
std_env::remove_var("POLYSCRIBE_TEST_FORCE_VULKAN");
|
||||
}
|
||||
// No GPU -> CPU
|
||||
let sel = select_backend(BackendKind::Auto, &crate::Config::default()).unwrap();
|
||||
assert_eq!(sel.chosen, BackendKind::Cpu);
|
||||
|
||||
// Vulkan only -> Vulkan
|
||||
unsafe { std_env::set_var("POLYSCRIBE_TEST_FORCE_VULKAN", "1"); }
|
||||
let sel = select_backend(BackendKind::Auto, &crate::Config::default()).unwrap();
|
||||
assert_eq!(sel.chosen, BackendKind::Vulkan);
|
||||
|
||||
// HIP only -> HIP (and preferred over Vulkan)
|
||||
unsafe {
|
||||
std_env::set_var("POLYSCRIBE_TEST_FORCE_HIP", "1");
|
||||
std_env::remove_var("POLYSCRIBE_TEST_FORCE_VULKAN");
|
||||
}
|
||||
let sel = select_backend(BackendKind::Auto, &crate::Config::default()).unwrap();
|
||||
assert_eq!(sel.chosen, BackendKind::Hip);
|
||||
|
||||
// CUDA only -> CUDA (and preferred over HIP)
|
||||
unsafe { std_env::set_var("POLYSCRIBE_TEST_FORCE_CUDA", "1"); }
|
||||
let sel = select_backend(BackendKind::Auto, &crate::Config::default()).unwrap();
|
||||
assert_eq!(sel.chosen, BackendKind::Cuda);
|
||||
|
||||
// Cleanup
|
||||
unsafe {
|
||||
std_env::remove_var("POLYSCRIBE_TEST_FORCE_CUDA");
|
||||
std_env::remove_var("POLYSCRIBE_TEST_FORCE_HIP");
|
||||
std_env::remove_var("POLYSCRIBE_TEST_FORCE_VULKAN");
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_select_backend_explicit_unavailable_errors_with_guidance() {
|
||||
let _guard = ENV_LOCK.get_or_init(|| Mutex::new(())).lock().unwrap();
|
||||
// Ensure all off
|
||||
unsafe {
|
||||
std_env::remove_var("POLYSCRIBE_TEST_FORCE_CUDA");
|
||||
std_env::remove_var("POLYSCRIBE_TEST_FORCE_HIP");
|
||||
std_env::remove_var("POLYSCRIBE_TEST_FORCE_VULKAN");
|
||||
}
|
||||
// CUDA requested but unavailable -> error with guidance
|
||||
let err = select_backend(BackendKind::Cuda, &crate::Config::default()).err().expect("expected error");
|
||||
let msg = format!("{}", err);
|
||||
assert!(msg.contains("Requested CUDA backend"), "unexpected msg: {msg}");
|
||||
assert!(msg.contains("How to fix"), "expected guidance text in: {msg}");
|
||||
|
||||
// HIP requested but unavailable -> error with guidance
|
||||
let err = select_backend(BackendKind::Hip, &crate::Config::default()).err().expect("expected error");
|
||||
let msg = format!("{}", err);
|
||||
assert!(msg.contains("ROCm/HIP"), "unexpected msg: {msg}");
|
||||
assert!(msg.contains("How to fix"), "expected guidance text in: {msg}");
|
||||
|
||||
// Vulkan requested but unavailable -> error with guidance
|
||||
let err = select_backend(BackendKind::Vulkan, &crate::Config::default()).err().expect("expected error");
|
||||
let msg = format!("{}", err);
|
||||
assert!(msg.contains("Vulkan"), "unexpected msg: {msg}");
|
||||
assert!(msg.contains("How to fix"), "expected guidance text in: {msg}");
|
||||
|
||||
// Now verify success when explicitly available via overrides
|
||||
unsafe { std_env::set_var("POLYSCRIBE_TEST_FORCE_CUDA", "1"); }
|
||||
assert!(select_backend(BackendKind::Cuda, &crate::Config::default()).is_ok());
|
||||
unsafe {
|
||||
std_env::remove_var("POLYSCRIBE_TEST_FORCE_CUDA");
|
||||
std_env::set_var("POLYSCRIBE_TEST_FORCE_HIP", "1");
|
||||
}
|
||||
assert!(select_backend(BackendKind::Hip, &crate::Config::default()).is_ok());
|
||||
unsafe {
|
||||
std_env::remove_var("POLYSCRIBE_TEST_FORCE_HIP");
|
||||
std_env::set_var("POLYSCRIBE_TEST_FORCE_VULKAN", "1");
|
||||
}
|
||||
assert!(select_backend(BackendKind::Vulkan, &crate::Config::default()).is_ok());
|
||||
|
||||
// Cleanup
|
||||
unsafe { std_env::remove_var("POLYSCRIBE_TEST_FORCE_VULKAN"); }
|
||||
}
|
||||
}
|
718
src/lib.rs
718
src/lib.rs
@@ -1,718 +0,0 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025 <COPYRIGHT HOLDER>. All rights reserved.
|
||||
|
||||
#![forbid(elided_lifetimes_in_paths)]
|
||||
#![forbid(unused_must_use)]
|
||||
#![deny(missing_docs)]
|
||||
// Lint policy for incremental refactor toward 2024:
|
||||
// - Keep basic clippy warnings enabled; skip pedantic/nursery for now (will revisit in step 7).
|
||||
// - cargo lints can be re-enabled later once codebase is tidied.
|
||||
#![warn(clippy::all)]
|
||||
//! PolyScribe library: business logic and core types.
|
||||
//!
|
||||
//! This crate exposes the reusable parts of the PolyScribe CLI as a library.
|
||||
//! The binary entry point (main.rs) remains a thin CLI wrapper.
|
||||
|
||||
use std::sync::atomic::{AtomicBool, AtomicU8, Ordering};
|
||||
|
||||
// Global runtime flags
|
||||
// Compatibility: globals are retained temporarily until all call-sites pass Config explicitly. They will be removed in a follow-up cleanup.
|
||||
static QUIET: AtomicBool = AtomicBool::new(false);
|
||||
static NO_INTERACTION: AtomicBool = AtomicBool::new(false);
|
||||
static VERBOSE: AtomicU8 = AtomicU8::new(0);
|
||||
|
||||
/// Set quiet mode: when true, non-interactive logs should be suppressed.
|
||||
pub fn set_quiet(q: bool) {
|
||||
QUIET.store(q, Ordering::Relaxed);
|
||||
}
|
||||
/// Return current quiet mode state.
|
||||
pub fn is_quiet() -> bool {
|
||||
QUIET.load(Ordering::Relaxed)
|
||||
}
|
||||
|
||||
/// Set non-interactive mode: when true, interactive prompts must be skipped.
|
||||
pub fn set_no_interaction(b: bool) {
|
||||
NO_INTERACTION.store(b, Ordering::Relaxed);
|
||||
}
|
||||
/// Return current non-interactive state.
|
||||
pub fn is_no_interaction() -> bool {
|
||||
if NO_INTERACTION.load(Ordering::Relaxed) {
|
||||
return true;
|
||||
}
|
||||
// Also honor NO_INTERACTION=1/true environment variable for convenience/testing
|
||||
match std::env::var("NO_INTERACTION") {
|
||||
Ok(v) => {
|
||||
let v = v.trim();
|
||||
v == "1" || v.eq_ignore_ascii_case("true")
|
||||
}
|
||||
Err(_) => false,
|
||||
}
|
||||
}
|
||||
|
||||
/// Set verbose level (0 = normal, 1 = verbose, 2 = super-verbose)
|
||||
pub fn set_verbose(level: u8) {
|
||||
VERBOSE.store(level, Ordering::Relaxed);
|
||||
}
|
||||
/// Get current verbose level.
|
||||
pub fn verbose_level() -> u8 {
|
||||
VERBOSE.load(Ordering::Relaxed)
|
||||
}
|
||||
|
||||
/// Check whether stdin is connected to a TTY. Used to avoid blocking prompts when not interactive.
|
||||
pub fn stdin_is_tty() -> bool {
|
||||
#[cfg(unix)]
|
||||
{
|
||||
use std::os::unix::io::AsRawFd;
|
||||
unsafe { libc::isatty(std::io::stdin().as_raw_fd()) == 1 }
|
||||
}
|
||||
#[cfg(not(unix))]
|
||||
{
|
||||
// Best-effort on non-Unix: assume TTY when not redirected by common CI vars
|
||||
// This avoids introducing a new dependency for atty.
|
||||
!(std::env::var("CI").is_ok() || std::env::var("GITHUB_ACTIONS").is_ok())
|
||||
}
|
||||
}
|
||||
|
||||
/// A guard that temporarily redirects stderr to /dev/null on Unix when quiet mode is active.
|
||||
/// No-op on non-Unix or when quiet is disabled. Restores stderr on drop.
|
||||
pub struct StderrSilencer {
|
||||
#[cfg(unix)]
|
||||
old_stderr_fd: i32,
|
||||
#[cfg(unix)]
|
||||
devnull_fd: i32,
|
||||
active: bool,
|
||||
}
|
||||
|
||||
impl StderrSilencer {
|
||||
/// Activate stderr silencing if quiet is set and on Unix; otherwise returns a no-op guard.
|
||||
pub fn activate_if_quiet() -> Self {
|
||||
if !is_quiet() {
|
||||
return Self {
|
||||
active: false,
|
||||
#[cfg(unix)]
|
||||
old_stderr_fd: -1,
|
||||
#[cfg(unix)]
|
||||
devnull_fd: -1,
|
||||
};
|
||||
}
|
||||
Self::activate()
|
||||
}
|
||||
|
||||
/// Activate stderr silencing unconditionally (used internally); no-op on non-Unix.
|
||||
pub fn activate() -> Self {
|
||||
#[cfg(unix)]
|
||||
unsafe {
|
||||
// Duplicate current stderr (fd 2)
|
||||
let old_fd = unix_fd::dup(unix_fd::STDERR_FILENO);
|
||||
if old_fd < 0 {
|
||||
return Self {
|
||||
active: false,
|
||||
old_stderr_fd: -1,
|
||||
devnull_fd: -1,
|
||||
};
|
||||
}
|
||||
// Open /dev/null for writing
|
||||
let devnull_cstr = std::ffi::CString::new("/dev/null").unwrap();
|
||||
let dn = unix_fd::open(devnull_cstr.as_ptr(), unix_fd::O_WRONLY);
|
||||
if dn < 0 {
|
||||
// failed to open devnull; restore and bail
|
||||
unix_fd::close(old_fd);
|
||||
return Self {
|
||||
active: false,
|
||||
old_stderr_fd: -1,
|
||||
devnull_fd: -1,
|
||||
};
|
||||
}
|
||||
// Redirect fd 2 to devnull
|
||||
if unix_fd::dup2(dn, unix_fd::STDERR_FILENO) < 0 {
|
||||
unix_fd::close(dn);
|
||||
unix_fd::close(old_fd);
|
||||
return Self {
|
||||
active: false,
|
||||
old_stderr_fd: -1,
|
||||
devnull_fd: -1,
|
||||
};
|
||||
}
|
||||
Self {
|
||||
active: true,
|
||||
old_stderr_fd: old_fd,
|
||||
devnull_fd: dn,
|
||||
}
|
||||
}
|
||||
#[cfg(not(unix))]
|
||||
{
|
||||
Self { active: false }
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Drop for StderrSilencer {
|
||||
fn drop(&mut self) {
|
||||
if !self.active {
|
||||
return;
|
||||
}
|
||||
#[cfg(unix)]
|
||||
unsafe {
|
||||
// Restore old stderr and close devnull and old copies
|
||||
let _ = unix_fd::dup2(self.old_stderr_fd, unix_fd::STDERR_FILENO);
|
||||
let _ = unix_fd::close(self.devnull_fd);
|
||||
let _ = unix_fd::close(self.old_stderr_fd);
|
||||
}
|
||||
self.active = false;
|
||||
}
|
||||
}
|
||||
|
||||
/// Run a closure while temporarily suppressing stderr on Unix when appropriate.
|
||||
/// On Windows/non-Unix, this is a no-op wrapper.
|
||||
/// This helper uses RAII + panic catching to ensure restoration before resuming panic.
|
||||
pub fn with_suppressed_stderr<F, T>(f: F) -> T
|
||||
where
|
||||
F: FnOnce() -> T,
|
||||
{
|
||||
// Suppress noisy native logs unless super-verbose (-vv) is enabled.
|
||||
if verbose_level() < 2 {
|
||||
let res = std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| {
|
||||
let _guard = StderrSilencer::activate();
|
||||
f()
|
||||
}));
|
||||
match res {
|
||||
Ok(v) => v,
|
||||
Err(p) => std::panic::resume_unwind(p),
|
||||
}
|
||||
} else {
|
||||
f()
|
||||
}
|
||||
}
|
||||
|
||||
/// Logging macros and helpers
|
||||
/// Log an error to stderr (always printed). Recommended for user-visible errors.
|
||||
#[macro_export]
|
||||
macro_rules! elog {
|
||||
($($arg:tt)*) => {{
|
||||
// Route errors through the progress area when available so they render inside cliclack
|
||||
$crate::log_with_level!("ERROR", None, true, $($arg)*);
|
||||
}}
|
||||
}
|
||||
/// Internal helper macro used by other logging macros to centralize the
|
||||
/// common behavior: build formatted message, check quiet/verbose flags,
|
||||
/// and print to stderr with a label.
|
||||
#[macro_export]
|
||||
macro_rules! log_with_level {
|
||||
($label:expr, $min_lvl:expr, $always:expr, $($arg:tt)*) => {{
|
||||
let should_print = if $always {
|
||||
true
|
||||
} else if let Some(minv) = $min_lvl {
|
||||
!$crate::is_quiet() && $crate::verbose_level() >= minv
|
||||
} else {
|
||||
!$crate::is_quiet()
|
||||
};
|
||||
if should_print {
|
||||
let line = format!("{}: {}", $label, format!($($arg)*));
|
||||
// Try to render via the active progress manager (cliclack/indicatif area).
|
||||
if !$crate::progress::log_line_via_global(&line) {
|
||||
eprintln!("{}", line);
|
||||
}
|
||||
}
|
||||
}}
|
||||
}
|
||||
|
||||
/// Log a warning to stderr (printed even in quiet mode).
|
||||
#[macro_export]
|
||||
macro_rules! wlog {
|
||||
($($arg:tt)*) => {{ $crate::log_with_level!("WARN", None, true, $($arg)*); }}
|
||||
}
|
||||
|
||||
/// Log an informational line to stderr unless quiet mode is enabled.
|
||||
#[macro_export]
|
||||
macro_rules! ilog {
|
||||
($($arg:tt)*) => {{ $crate::log_with_level!("INFO", None, false, $($arg)*); }}
|
||||
}
|
||||
|
||||
/// Log a debug/trace line when verbose level is at least the given level (u8).
|
||||
#[macro_export]
|
||||
macro_rules! dlog {
|
||||
($lvl:expr, $($arg:tt)*) => {{
|
||||
$crate::log_with_level!(&format!("DEBUG{}", &$lvl), Some($lvl), false, $($arg)*);
|
||||
}}
|
||||
}
|
||||
|
||||
/// Backward-compatibility: map old qlog! to ilog!
|
||||
#[macro_export]
|
||||
macro_rules! qlog {
|
||||
($($arg:tt)*) => {{ $crate::ilog!($($arg)*); }}
|
||||
}
|
||||
|
||||
use anyhow::{Context, Result, anyhow};
|
||||
use chrono::Local;
|
||||
use std::env;
|
||||
use std::fs::create_dir_all;
|
||||
use std::io;
|
||||
use std::io::Write;
|
||||
use std::path::{Path, PathBuf};
|
||||
use std::process::Command;
|
||||
|
||||
#[cfg(unix)]
|
||||
mod unix_fd {
|
||||
pub use libc::O_WRONLY;
|
||||
pub const STDERR_FILENO: i32 = 2; // libc::STDERR_FILENO isn't always available on all targets
|
||||
#[inline]
|
||||
pub unsafe fn dup(fd: i32) -> i32 { libc::dup(fd) }
|
||||
#[inline]
|
||||
pub unsafe fn dup2(fd: i32, fd2: i32) -> i32 { libc::dup2(fd, fd2) }
|
||||
#[inline]
|
||||
pub unsafe fn open(path: *const libc::c_char, flags: i32) -> i32 { libc::open(path, flags) }
|
||||
#[inline]
|
||||
pub unsafe fn close(fd: i32) -> i32 { libc::close(fd) }
|
||||
}
|
||||
|
||||
/// Re-export backend module (GPU/CPU selection and transcription).
|
||||
pub mod backend;
|
||||
/// Re-export models module (model listing/downloading/updating).
|
||||
pub mod models;
|
||||
/// Progress and progress bar abstraction (TTY-aware, stderr-only)
|
||||
pub mod progress;
|
||||
|
||||
/// UI helpers for interactive prompts (cliclack-backed)
|
||||
pub mod ui;
|
||||
|
||||
/// Runtime configuration passed across the library instead of using globals.
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
|
||||
pub struct Config {
|
||||
/// Suppress non-essential logs.
|
||||
pub quiet: bool,
|
||||
/// Verbosity level (0 = normal, 1 = verbose, 2 = super-verbose).
|
||||
pub verbose: u8,
|
||||
/// Disable interactive prompts.
|
||||
pub no_interaction: bool,
|
||||
/// Disable progress output.
|
||||
pub no_progress: bool,
|
||||
}
|
||||
|
||||
impl Config {
|
||||
/// Construct a Config from explicit values.
|
||||
pub fn new(quiet: bool, verbose: u8, no_interaction: bool, no_progress: bool) -> Self {
|
||||
Self { quiet, verbose, no_interaction, no_progress }
|
||||
}
|
||||
/// Snapshot current global settings into a Config (temporary compatibility helper).
|
||||
pub fn from_globals() -> Self {
|
||||
Self {
|
||||
quiet: crate::is_quiet(),
|
||||
verbose: crate::verbose_level(),
|
||||
no_interaction: crate::is_no_interaction(),
|
||||
no_progress: matches!(std::env::var("NO_PROGRESS"), Ok(ref v) if v == "1" || v.eq_ignore_ascii_case("true")),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for Config {
|
||||
fn default() -> Self {
|
||||
Self { quiet: false, verbose: 0, no_interaction: false, no_progress: false }
|
||||
}
|
||||
}
|
||||
|
||||
/// Transcript entry for a single segment.
|
||||
#[derive(Debug, serde::Serialize, Clone)]
|
||||
pub struct OutputEntry {
|
||||
/// Sequential id in output ordering.
|
||||
pub id: u64,
|
||||
/// Speaker label associated with the segment.
|
||||
pub speaker: String,
|
||||
/// Start time in seconds.
|
||||
pub start: f64,
|
||||
/// End time in seconds.
|
||||
pub end: f64,
|
||||
/// Text content.
|
||||
pub text: String,
|
||||
}
|
||||
|
||||
/// Return a YYYY-MM-DD date prefix string for output file naming.
|
||||
pub fn date_prefix() -> String {
|
||||
Local::now().format("%Y-%m-%d").to_string()
|
||||
}
|
||||
|
||||
/// Format a floating-point number of seconds as SRT timestamp (HH:MM:SS,mmm).
|
||||
pub fn format_srt_time(seconds: f64) -> String {
|
||||
let total_ms = (seconds * 1000.0).round() as i64;
|
||||
let ms = total_ms % 1000;
|
||||
let total_secs = total_ms / 1000;
|
||||
let s = total_secs % 60;
|
||||
let m = (total_secs / 60) % 60;
|
||||
let h = total_secs / 3600;
|
||||
format!("{h:02}:{m:02}:{s:02},{ms:03}")
|
||||
}
|
||||
|
||||
/// Render a list of transcript entries to SRT format.
|
||||
pub fn render_srt(items: &[OutputEntry]) -> String {
|
||||
let mut out = String::new();
|
||||
for (i, e) in items.iter().enumerate() {
|
||||
let idx = i + 1;
|
||||
out.push_str(&format!("{idx}\n"));
|
||||
out.push_str(&format!(
|
||||
"{} --> {}\n",
|
||||
format_srt_time(e.start),
|
||||
format_srt_time(e.end)
|
||||
));
|
||||
if !e.speaker.is_empty() {
|
||||
out.push_str(&format!("{}: {}\n", e.speaker, e.text));
|
||||
} else {
|
||||
out.push_str(&format!("{}\n", e.text));
|
||||
}
|
||||
out.push('\n');
|
||||
}
|
||||
out
|
||||
}
|
||||
|
||||
/// Determine the default models directory, honoring POLYSCRIBE_MODELS_DIR override.
|
||||
pub fn models_dir_path() -> PathBuf {
|
||||
if let Ok(p) = env::var("POLYSCRIBE_MODELS_DIR") {
|
||||
let pb = PathBuf::from(p);
|
||||
if !pb.as_os_str().is_empty() {
|
||||
return pb;
|
||||
}
|
||||
}
|
||||
if cfg!(debug_assertions) {
|
||||
return PathBuf::from("models");
|
||||
}
|
||||
if let Ok(xdg) = env::var("XDG_DATA_HOME") {
|
||||
if !xdg.is_empty() {
|
||||
return PathBuf::from(xdg).join("polyscribe").join("models");
|
||||
}
|
||||
}
|
||||
if let Ok(home) = env::var("HOME") {
|
||||
if !home.is_empty() {
|
||||
return PathBuf::from(home)
|
||||
.join(".local")
|
||||
.join("share")
|
||||
.join("polyscribe")
|
||||
.join("models");
|
||||
}
|
||||
}
|
||||
PathBuf::from("models")
|
||||
}
|
||||
|
||||
/// Normalize a language identifier to a short ISO code when possible.
|
||||
pub fn normalize_lang_code(input: &str) -> Option<String> {
|
||||
let mut s = input.trim().to_lowercase();
|
||||
if s.is_empty() || s == "auto" || s == "c" || s == "posix" {
|
||||
return None;
|
||||
}
|
||||
if let Some((lhs, _)) = s.split_once('.') {
|
||||
s = lhs.to_string();
|
||||
}
|
||||
if let Some((lhs, _)) = s.split_once('_') {
|
||||
s = lhs.to_string();
|
||||
}
|
||||
let code = match s.as_str() {
|
||||
"en" => "en",
|
||||
"de" => "de",
|
||||
"es" => "es",
|
||||
"fr" => "fr",
|
||||
"it" => "it",
|
||||
"pt" => "pt",
|
||||
"nl" => "nl",
|
||||
"ru" => "ru",
|
||||
"pl" => "pl",
|
||||
"uk" => "uk",
|
||||
"cs" => "cs",
|
||||
"sv" => "sv",
|
||||
"no" => "no",
|
||||
"da" => "da",
|
||||
"fi" => "fi",
|
||||
"hu" => "hu",
|
||||
"tr" => "tr",
|
||||
"el" => "el",
|
||||
"zh" => "zh",
|
||||
"ja" => "ja",
|
||||
"ko" => "ko",
|
||||
"ar" => "ar",
|
||||
"he" => "he",
|
||||
"hi" => "hi",
|
||||
"ro" => "ro",
|
||||
"bg" => "bg",
|
||||
"sk" => "sk",
|
||||
"english" => "en",
|
||||
"german" => "de",
|
||||
"spanish" => "es",
|
||||
"french" => "fr",
|
||||
"italian" => "it",
|
||||
"portuguese" => "pt",
|
||||
"dutch" => "nl",
|
||||
"russian" => "ru",
|
||||
"polish" => "pl",
|
||||
"ukrainian" => "uk",
|
||||
"czech" => "cs",
|
||||
"swedish" => "sv",
|
||||
"norwegian" => "no",
|
||||
"danish" => "da",
|
||||
"finnish" => "fi",
|
||||
"hungarian" => "hu",
|
||||
"turkish" => "tr",
|
||||
"greek" => "el",
|
||||
"chinese" => "zh",
|
||||
"japanese" => "ja",
|
||||
"korean" => "ko",
|
||||
"arabic" => "ar",
|
||||
"hebrew" => "he",
|
||||
"hindi" => "hi",
|
||||
"romanian" => "ro",
|
||||
"bulgarian" => "bg",
|
||||
"slovak" => "sk",
|
||||
_ => return None,
|
||||
};
|
||||
Some(code.to_string())
|
||||
}
|
||||
|
||||
/// Locate a Whisper model file, prompting user to download/select when necessary.
|
||||
pub fn find_model_file() -> Result<PathBuf> {
|
||||
// Silent model resolution used during processing to avoid interfering with progress bars.
|
||||
// Preflight prompting should be done by the caller before bars are created (use find_model_file_with_printer).
|
||||
let models_dir_buf = models_dir_path();
|
||||
let models_dir = models_dir_buf.as_path();
|
||||
if !models_dir.exists() {
|
||||
create_dir_all(models_dir).with_context(|| {
|
||||
format!(
|
||||
"Failed to create models directory: {}",
|
||||
models_dir.display()
|
||||
)
|
||||
})?;
|
||||
}
|
||||
// 1) Explicit environment override
|
||||
if let Ok(env_model) = env::var("WHISPER_MODEL") {
|
||||
let p = PathBuf::from(env_model);
|
||||
if p.is_file() {
|
||||
let _ = std::fs::write(models_dir.join(".last_model"), p.display().to_string());
|
||||
return Ok(p);
|
||||
}
|
||||
}
|
||||
// 2) Previously selected model
|
||||
let last_file = models_dir.join(".last_model");
|
||||
if let Ok(prev) = std::fs::read_to_string(&last_file) {
|
||||
let prev = prev.trim();
|
||||
if !prev.is_empty() {
|
||||
let p = PathBuf::from(prev);
|
||||
if p.is_file() {
|
||||
return Ok(p);
|
||||
}
|
||||
}
|
||||
}
|
||||
// 3) Best local model without prompting
|
||||
if let Some(local) = crate::models::pick_best_local_model(models_dir) {
|
||||
let _ = std::fs::write(models_dir.join(".last_model"), local.display().to_string());
|
||||
return Ok(local);
|
||||
}
|
||||
// 4) No model available; avoid interactive prompts here to prevent progress bar redraw issues.
|
||||
// Callers should run find_model_file_with_printer(...) before starting progress bars to interactively select/download.
|
||||
Err(anyhow!(
|
||||
"No Whisper model available. Run with --download-models or ensure WHISPER_MODEL is set before processing."
|
||||
))
|
||||
}
|
||||
|
||||
/// Locate a Whisper model file, prompting user to download/select when necessary.
|
||||
/// All prompts are printed using the provided printer closure (e.g., MultiProgress::println)
|
||||
/// to avoid interfering with active progress bars.
|
||||
pub fn find_model_file_with_printer<F>(printer: F) -> Result<PathBuf>
|
||||
where
|
||||
F: Fn(&str),
|
||||
{
|
||||
let models_dir_buf = models_dir_path();
|
||||
let models_dir = models_dir_buf.as_path();
|
||||
if !models_dir.exists() {
|
||||
create_dir_all(models_dir).with_context(|| {
|
||||
format!(
|
||||
"Failed to create models directory: {}",
|
||||
models_dir.display()
|
||||
)
|
||||
})?;
|
||||
}
|
||||
|
||||
if let Ok(env_model) = env::var("WHISPER_MODEL") {
|
||||
let p = PathBuf::from(env_model);
|
||||
if p.is_file() {
|
||||
let _ = std::fs::write(models_dir.join(".last_model"), p.display().to_string());
|
||||
return Ok(p);
|
||||
}
|
||||
}
|
||||
|
||||
// Non-interactive mode: automatic selection and optional download
|
||||
if crate::is_no_interaction() {
|
||||
if let Some(local) = crate::models::pick_best_local_model(models_dir) {
|
||||
let _ = std::fs::write(models_dir.join(".last_model"), local.display().to_string());
|
||||
return Ok(local);
|
||||
} else {
|
||||
ilog!("No local models found; downloading large-v3-turbo-q8_0...");
|
||||
let path = crate::models::ensure_model_available_noninteractive("large-v3-turbo-q8_0")
|
||||
.with_context(|| "Failed to download required model 'large-v3-turbo-q8_0'")?;
|
||||
let _ = std::fs::write(models_dir.join(".last_model"), path.display().to_string());
|
||||
return Ok(path);
|
||||
}
|
||||
}
|
||||
|
||||
let mut candidates: Vec<PathBuf> = Vec::new();
|
||||
let rd = std::fs::read_dir(models_dir)
|
||||
.with_context(|| format!("Failed to read models directory: {}", models_dir.display()))?;
|
||||
for entry in rd {
|
||||
let entry = entry?;
|
||||
let path = entry.path();
|
||||
if path.is_file() {
|
||||
if let Some(ext) = path
|
||||
.extension()
|
||||
.and_then(|s| s.to_str())
|
||||
.map(|s| s.to_lowercase())
|
||||
{
|
||||
if ext == "bin" {
|
||||
candidates.push(path);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if candidates.is_empty() {
|
||||
// No models found: prompt interactively (TTY only)
|
||||
wlog!(
|
||||
"{}",
|
||||
format!(
|
||||
"No Whisper model files (*.bin) found in {}.",
|
||||
models_dir.display()
|
||||
)
|
||||
);
|
||||
if crate::is_no_interaction() || !crate::stdin_is_tty() {
|
||||
return Err(anyhow!(
|
||||
"No models available and interactive mode is disabled. Please set WHISPER_MODEL or run with --download-models."
|
||||
));
|
||||
}
|
||||
// Use unified cliclack confirm via UI helper
|
||||
let download_now = crate::ui::prompt_confirm("Download models now?", true)
|
||||
.context("prompt error during confirmation")?;
|
||||
if download_now {
|
||||
if let Err(e) = models::run_interactive_model_downloader() {
|
||||
elog!("Downloader failed: {:#}", e);
|
||||
}
|
||||
candidates.clear();
|
||||
let rd2 = std::fs::read_dir(models_dir).with_context(|| {
|
||||
format!("Failed to read models directory: {}", models_dir.display())
|
||||
})?;
|
||||
for entry in rd2 {
|
||||
let entry = entry?;
|
||||
let path = entry.path();
|
||||
if path.is_file() {
|
||||
if let Some(ext) = path
|
||||
.extension()
|
||||
.and_then(|s| s.to_str())
|
||||
.map(|s| s.to_lowercase())
|
||||
{
|
||||
if ext == "bin" {
|
||||
candidates.push(path);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if candidates.is_empty() {
|
||||
return Err(anyhow!(
|
||||
"No Whisper model files (*.bin) available in {}",
|
||||
models_dir.display()
|
||||
));
|
||||
}
|
||||
|
||||
if candidates.len() == 1 {
|
||||
let only = candidates.remove(0);
|
||||
let _ = std::fs::write(models_dir.join(".last_model"), only.display().to_string());
|
||||
return Ok(only);
|
||||
}
|
||||
|
||||
let last_file = models_dir.join(".last_model");
|
||||
if let Ok(prev) = std::fs::read_to_string(&last_file) {
|
||||
let prev = prev.trim();
|
||||
if !prev.is_empty() {
|
||||
let p = PathBuf::from(prev);
|
||||
if p.is_file() && candidates.iter().any(|c| c == &p) {
|
||||
// Previously printed: INFO about using previously selected model.
|
||||
// Suppress this to avoid duplicate/noisy messages; per-file progress will be shown elsewhere.
|
||||
return Ok(p);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
printer(&"Multiple Whisper models found:".to_string());
|
||||
let mut display_names: Vec<String> = Vec::with_capacity(candidates.len());
|
||||
for (i, p) in candidates.iter().enumerate() {
|
||||
let name = p
|
||||
.file_name()
|
||||
.and_then(|s| s.to_str())
|
||||
.map(|s| s.to_string())
|
||||
.unwrap_or_else(|| p.display().to_string());
|
||||
display_names.push(name.clone());
|
||||
printer(&format!(" {}) {}", i + 1, name));
|
||||
}
|
||||
// Print a blank line before the selection prompt to keep output synchronized.
|
||||
printer("");
|
||||
let idx = if crate::is_no_interaction() || !crate::stdin_is_tty() {
|
||||
// Non-interactive: auto-select the first candidate deterministically (as listed)
|
||||
0
|
||||
} else {
|
||||
crate::ui::prompt_select_index("Select a Whisper model", &display_names)
|
||||
.context("Failed to read selection")?
|
||||
};
|
||||
let chosen = candidates.swap_remove(idx);
|
||||
let _ = std::fs::write(models_dir.join(".last_model"), chosen.display().to_string());
|
||||
// Print an empty line after selection input
|
||||
printer("");
|
||||
Ok(chosen)
|
||||
}
|
||||
|
||||
/// Decode an input media file to 16kHz mono f32 PCM using ffmpeg available on PATH.
|
||||
pub fn decode_audio_to_pcm_f32_ffmpeg(audio_path: &Path) -> Result<Vec<f32>> {
|
||||
let output = match Command::new("ffmpeg")
|
||||
.arg("-i")
|
||||
.arg(audio_path)
|
||||
.arg("-f")
|
||||
.arg("f32le")
|
||||
.arg("-ac")
|
||||
.arg("1")
|
||||
.arg("-ar")
|
||||
.arg("16000")
|
||||
.arg("pipe:1")
|
||||
.output()
|
||||
{
|
||||
Ok(o) => o,
|
||||
Err(e) => {
|
||||
return if e.kind() == std::io::ErrorKind::NotFound {
|
||||
Err(anyhow!(
|
||||
"ffmpeg not found on PATH. Please install ffmpeg and ensure it is available."
|
||||
))
|
||||
} else {
|
||||
Err(anyhow!(
|
||||
"Failed to execute ffmpeg for {}: {}",
|
||||
audio_path.display(),
|
||||
e
|
||||
))
|
||||
}
|
||||
}
|
||||
};
|
||||
if !output.status.success() {
|
||||
return Err(anyhow!(
|
||||
"ffmpeg failed for {}: {}",
|
||||
audio_path.display(),
|
||||
String::from_utf8_lossy(&output.stderr)
|
||||
));
|
||||
}
|
||||
let bytes = output.stdout;
|
||||
if bytes.len() % 4 != 0 {
|
||||
let truncated = bytes.len() - (bytes.len() % 4);
|
||||
let mut v = Vec::with_capacity(truncated / 4);
|
||||
for chunk in bytes[..truncated].chunks_exact(4) {
|
||||
let arr = [chunk[0], chunk[1], chunk[2], chunk[3]];
|
||||
v.push(f32::from_le_bytes(arr));
|
||||
}
|
||||
Ok(v)
|
||||
} else {
|
||||
let mut v = Vec::with_capacity(bytes.len() / 4);
|
||||
for chunk in bytes.chunks_exact(4) {
|
||||
let arr = [chunk[0], chunk[1], chunk[2], chunk[3]];
|
||||
v.push(f32::from_le_bytes(arr));
|
||||
}
|
||||
Ok(v)
|
||||
}
|
||||
}
|
948
src/main.rs
948
src/main.rs
@@ -1,948 +0,0 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025 <COPYRIGHT HOLDER>. All rights reserved.
|
||||
|
||||
use std::fs::{File, create_dir_all};
|
||||
use std::io::{self, Read, Write};
|
||||
use std::path::{Path, PathBuf};
|
||||
|
||||
use anyhow::{Context, Result, anyhow};
|
||||
use clap::{Parser, Subcommand, CommandFactory};
|
||||
use clap_complete::Shell;
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
mod output;
|
||||
use output::{write_outputs, OutputFormats};
|
||||
|
||||
use std::sync::mpsc::channel;
|
||||
// whisper-rs is used from the library crate
|
||||
use polyscribe::backend::{BackendKind, select_backend};
|
||||
use polyscribe::progress::ProgressMessage;
|
||||
use polyscribe::progress::ProgressFactory;
|
||||
|
||||
#[derive(Subcommand, Debug, Clone)]
|
||||
enum AuxCommands {
|
||||
/// Generate shell completion script to stdout
|
||||
Completions {
|
||||
/// Shell to generate completions for
|
||||
#[arg(value_enum)]
|
||||
shell: Shell,
|
||||
},
|
||||
/// Generate a man page to stdout
|
||||
Man,
|
||||
}
|
||||
|
||||
#[derive(clap::ValueEnum, Debug, Clone, Copy)]
|
||||
#[value(rename_all = "kebab-case")]
|
||||
enum GpuBackendCli {
|
||||
Auto,
|
||||
Cpu,
|
||||
Cuda,
|
||||
Hip,
|
||||
Vulkan,
|
||||
}
|
||||
|
||||
#[derive(clap::ValueEnum, Debug, Clone, Copy, PartialEq, Eq)]
|
||||
#[value(rename_all = "kebab-case")]
|
||||
enum OutFormatCli {
|
||||
Json,
|
||||
Toml,
|
||||
Srt,
|
||||
All,
|
||||
}
|
||||
|
||||
#[derive(Parser, Debug)]
|
||||
#[command(
|
||||
name = "PolyScribe",
|
||||
bin_name = "polyscribe",
|
||||
version,
|
||||
about = "Merge JSON transcripts or transcribe audio using native whisper"
|
||||
)]
|
||||
struct Args {
|
||||
/// Increase verbosity (-v, -vv). Repeat to increase. Debug logs appear with -v; very verbose with -vv. Logs go to stderr.
|
||||
#[arg(short = 'v', long = "verbose", action = clap::ArgAction::Count, global = true)]
|
||||
verbose: u8,
|
||||
|
||||
/// Quiet mode: suppress non-error logging on stderr (overrides -v). Does not suppress interactive prompts or stdout output.
|
||||
#[arg(short = 'q', long = "quiet", global = true)]
|
||||
quiet: bool,
|
||||
|
||||
/// Non-interactive mode: never prompt; use defaults instead.
|
||||
/// Deprecated alias supported: --no-interation (typo)
|
||||
#[arg(long = "no-interaction", alias = "no-interation", global = true)]
|
||||
no_interaction: bool,
|
||||
|
||||
/// Disable progress bars (also respects NO_PROGRESS=1). Progress bars render on stderr only when attached to a TTY.
|
||||
#[arg(long = "no-progress", global = true)]
|
||||
no_progress: bool,
|
||||
|
||||
/// Number of concurrent worker jobs to use when processing independent inputs.
|
||||
#[arg(short = 'j', long = "jobs", value_name = "N", default_value_t = 1, global = true)]
|
||||
jobs: usize,
|
||||
|
||||
/// Optional auxiliary subcommands (completions, man)
|
||||
#[command(subcommand)]
|
||||
aux: Option<AuxCommands>,
|
||||
|
||||
/// Input .json transcript files or audio files to merge/transcribe
|
||||
inputs: Vec<String>,
|
||||
|
||||
/// Output file path base (date prefix will be added); if omitted, writes JSON to stdout
|
||||
#[arg(short, long, value_name = "FILE")]
|
||||
output: Option<String>,
|
||||
|
||||
/// Which output format(s) to write when writing to files: json|toml|srt|all. Repeatable. Default: all
|
||||
#[arg(long = "out-format", value_enum, value_name = "json|toml|srt|all")]
|
||||
out_format: Vec<OutFormatCli>,
|
||||
|
||||
/// Merge all inputs into a single output; if not set, each input is written as a separate output
|
||||
#[arg(short = 'm', long = "merge")]
|
||||
merge: bool,
|
||||
|
||||
/// Merge and also write separate outputs per input; requires -o OUTPUT_DIR
|
||||
#[arg(long = "merge-and-separate")]
|
||||
merge_and_separate: bool,
|
||||
|
||||
/// Language code to use for transcription (e.g., en, de). No auto-detection.
|
||||
#[arg(short, long, value_name = "LANG")]
|
||||
language: Option<String>,
|
||||
|
||||
/// Choose GPU backend at runtime (auto|cpu|cuda|hip|vulkan). Default: auto.
|
||||
#[arg(long = "gpu-backend", value_enum, default_value_t = GpuBackendCli::Auto)]
|
||||
gpu_backend: GpuBackendCli,
|
||||
|
||||
/// Number of layers to offload to GPU (if supported by backend)
|
||||
#[arg(long = "gpu-layers", value_name = "N")]
|
||||
gpu_layers: Option<u32>,
|
||||
|
||||
/// Launch interactive model downloader (list HF models, multi-select and download)
|
||||
#[arg(long)]
|
||||
download_models: bool,
|
||||
|
||||
/// Update local Whisper models by comparing hashes/sizes with remote manifest
|
||||
#[arg(long)]
|
||||
update_models: bool,
|
||||
|
||||
/// Prompt for speaker names per input file
|
||||
#[arg(long = "set-speaker-names")]
|
||||
set_speaker_names: bool,
|
||||
|
||||
/// Continue processing other inputs even if some fail; exit non-zero if any failed
|
||||
#[arg(long = "continue-on-error")]
|
||||
continue_on_error: bool,
|
||||
|
||||
/// Overwrite existing output files instead of appending a numeric suffix
|
||||
#[arg(long = "force")]
|
||||
force: bool,
|
||||
}
|
||||
|
||||
/// Deserialized root of an input transcript JSON file.
/// Only the `segments` array is read; every other top-level key is ignored.
#[derive(Debug, Deserialize)]
struct InputRoot {
    // A missing `segments` key deserializes to an empty Vec instead of erroring.
    #[serde(default)]
    segments: Vec<InputSegment>,
}
|
||||
|
||||
/// One transcript segment as read from an input JSON file.
#[derive(Debug, Deserialize)]
struct InputSegment {
    // Segment start time in seconds (consumed by SRT time formatting downstream).
    start: f64,
    // Segment end time in seconds.
    end: f64,
    // Transcribed text for this time span.
    text: String,
    // other fields are ignored
}
|
||||
|
||||
use polyscribe::{OutputEntry, date_prefix, models_dir_path, normalize_lang_code, render_srt};
|
||||
|
||||
/// Serializable root of every output document (JSON/TOML, and the item
/// source for SRT rendering).
#[derive(Debug, Serialize)]
pub struct OutputRoot {
    /// Transcript entries in output order; ids are assigned sequentially by the caller.
    pub items: Vec<OutputEntry>,
}
|
||||
|
||||
/// Strip a leading `<digits>-` prefix from a raw speaker name.
///
/// `"123-bob"` becomes `"bob"`. Inputs whose text before the first `'-'`
/// is empty or not all ASCII digits (e.g. `"abc-bob"`, `"-bob"`, `"123"`)
/// are returned unchanged. Note that `"123-"` yields an empty string;
/// callers are expected to substitute a default in that case.
fn sanitize_speaker_name(raw: &str) -> String {
    match raw.split_once('-') {
        Some((prefix, rest))
            if !prefix.is_empty() && prefix.bytes().all(|b| b.is_ascii_digit()) =>
        {
            rest.to_owned()
        }
        _ => raw.to_owned(),
    }
}
|
||||
|
||||
/// Resolve the speaker name to attribute to segments from `path`.
///
/// Returns `default_name` without prompting when `enabled` is false or when
/// the process is globally non-interactive. Otherwise asks the user: progress
/// bars are paused, the question is echoed above them, a TTY prompt is tried
/// first with a raw-stdin fallback, and the answer is passed through
/// `sanitize_speaker_name`. An empty sanitized answer falls back to
/// `default_name`.
fn prompt_speaker_name_for_path(path: &Path, default_name: &str, enabled: bool, pm: &polyscribe::progress::ProgressManager) -> String {
    if !enabled {
        return default_name.to_string();
    }
    if polyscribe::is_no_interaction() {
        // Explicitly non-interactive: never prompt
        return default_name.to_string();
    }

    // Prefer just the file name for display; fall back to the lossy full path.
    let display_owned: String = path
        .file_name()
        .and_then(|s| s.to_str())
        .map(|s| s.to_string())
        .unwrap_or_else(|| path.to_string_lossy().to_string());

    // Render prompt above any progress bars
    pm.pause_for_prompt();
    let answer = {
        let prompt = format!("Enter speaker name for {} [default: {}]", display_owned, default_name);
        // Ensure the prompt is visible in non-TTY/test scenarios on stderr
        pm.println_above_bars(&prompt);
        // Prefer TTY prompt; if that fails (e.g., piped stdin), fall back to raw stdin line
        match polyscribe::ui::prompt_text(&prompt, default_name) {
            Ok(ans) => ans,
            Err(_) => {
                // Fallback: read a single line from stdin
                // NOTE(review): `Read` appears unused here — `Stdin::read_line`
                // is an inherent method; confirm and drop the import if so.
                use std::io::Read as _;
                let mut buf = String::new();
                // Read up to newline; if nothing, use default
                match std::io::stdin().read_line(&mut buf) {
                    Ok(_) => {
                        let t = buf.trim();
                        if t.is_empty() { default_name.to_string() } else { t.to_string() }
                    }
                    Err(_) => default_name.to_string(),
                }
            }
        }
    };
    // Re-enable progress rendering regardless of how the answer was obtained.
    pm.resume_after_prompt();

    let sanitized = sanitize_speaker_name(&answer);
    if sanitized.is_empty() {
        default_name.to_string()
    } else {
        sanitized
    }
}
|
||||
|
||||
// --- Helpers for audio transcription ---
|
||||
/// True when `path` carries a `.json` extension, compared case-insensitively.
fn is_json_file(path: &Path) -> bool {
    path.extension()
        .and_then(|ext| ext.to_str())
        .map(|ext| ext.to_lowercase())
        .as_deref()
        == Some("json")
}
|
||||
|
||||
/// True when `path` carries a known audio or video container extension
/// (case-insensitive). Video containers are included because their audio
/// track can be extracted for transcription.
fn is_audio_file(path: &Path) -> bool {
    const MEDIA_EXTS: [&str; 17] = [
        "mp3", "wav", "m4a", "mp4", "aac", "flac", "ogg", "wma", "webm", "mkv", "mov", "avi",
        "m4b", "3gp", "opus", "aiff", "alac",
    ];
    path.extension()
        .and_then(|ext| ext.to_str())
        .map(|ext| ext.to_lowercase())
        .map_or(false, |ext| MEDIA_EXTS.contains(&ext.as_str()))
}
|
||||
|
||||
/// RAII guard that deletes a file when dropped; used so the transient
/// `.last_model` marker never persists across program runs (see `Drop` impl).
struct LastModelCleanup {
    // Path of the file to remove on drop.
    path: PathBuf,
}
|
||||
impl Drop for LastModelCleanup {
|
||||
fn drop(&mut self) {
|
||||
// Ensure .last_model does not persist across program runs
|
||||
if let Err(e) = std::fs::remove_file(&self.path) {
|
||||
// Best-effort cleanup; ignore missing file; warn for other errors
|
||||
if e.kind() != std::io::ErrorKind::NotFound {
|
||||
polyscribe::wlog!("Failed to remove {}: {}", self.path.display(), e);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Run `f` and return its result. Historically this redirected stdio on
/// Unix in quiet mode; today quiet only silences logging, so the flag is
/// accepted (for call-site compatibility) but ignored.
#[cfg(unix)]
fn with_quiet_stdio_if_needed<F, R>(_quiet: bool, f: F) -> R
where
    F: FnOnce() -> R,
{
    f()
}
|
||||
|
||||
/// Non-Unix variant: stdio redirection never applied on these platforms,
/// so this simply invokes `f` and returns its result; `_quiet` is ignored.
#[cfg(not(unix))]
fn with_quiet_stdio_if_needed<F, R>(_quiet: bool, f: F) -> R
where
    F: FnOnce() -> R,
{
    f()
}
|
||||
|
||||
// Rust
|
||||
fn run() -> Result<()> {
|
||||
use std::time::{Duration, Instant};
|
||||
|
||||
let args = Args::parse();
|
||||
|
||||
// Build Config and set globals (temporary compatibility). Prefer Config going forward.
|
||||
let config = polyscribe::Config::new(args.quiet, args.verbose, args.no_interaction, /*no_progress:*/ args.no_progress);
|
||||
polyscribe::set_quiet(config.quiet);
|
||||
polyscribe::set_verbose(config.verbose);
|
||||
polyscribe::set_no_interaction(config.no_interaction);
|
||||
let _silence = polyscribe::StderrSilencer::activate_if_quiet();
|
||||
|
||||
// Handle auxiliary subcommands early and exit.
|
||||
if let Some(aux) = &args.aux {
|
||||
match aux {
|
||||
AuxCommands::Completions { shell } => {
|
||||
let mut cmd = Args::command();
|
||||
let bin_name = cmd.get_name().to_string();
|
||||
let mut stdout = std::io::stdout();
|
||||
clap_complete::generate(*shell, &mut cmd, bin_name, &mut stdout);
|
||||
return Ok(());
|
||||
}
|
||||
AuxCommands::Man => {
|
||||
let cmd = Args::command();
|
||||
let man = clap_mangen::Man::new(cmd);
|
||||
let mut buf: Vec<u8> = Vec::new();
|
||||
man.render(&mut buf).context("failed to render man page")?;
|
||||
print!("{}", String::from_utf8_lossy(&buf));
|
||||
return Ok(());
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Handle model management modes early and exit
|
||||
if args.download_models && args.update_models {
|
||||
// Avoid ambiguous behavior when both flags are set
|
||||
return Err(anyhow!("Choose only one: --download-models or --update-models"));
|
||||
}
|
||||
if args.download_models {
|
||||
// Launch interactive model downloader and exit
|
||||
polyscribe::models::run_interactive_model_downloader()?;
|
||||
return Ok(());
|
||||
}
|
||||
if args.update_models {
|
||||
// Update existing local models and exit
|
||||
polyscribe::models::update_local_models()?;
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
// Prefer Config-driven progress factory
|
||||
let pf = ProgressFactory::from_config(&config);
|
||||
let pm = pf.make_manager(pf.decide_mode(args.inputs.len()));
|
||||
// Route subsequent INFO/WARN/DEBUG logs through the cliclack/indicatif area
|
||||
polyscribe::progress::set_global_progress_manager(&pm);
|
||||
|
||||
// Show a friendly intro banner (TTY-aware via cliclack). Ignore errors.
|
||||
if !polyscribe::is_quiet() {
|
||||
let _ = cliclack::intro("PolyScribe");
|
||||
}
|
||||
|
||||
// Determine formats
|
||||
let out_formats = if args.out_format.is_empty() {
|
||||
OutputFormats::all()
|
||||
} else {
|
||||
let mut f = OutputFormats { json: false, toml: false, srt: false };
|
||||
for of in &args.out_format {
|
||||
match of {
|
||||
OutFormatCli::Json => f.json = true,
|
||||
OutFormatCli::Toml => f.toml = true,
|
||||
OutFormatCli::Srt => f.srt = true,
|
||||
OutFormatCli::All => { f.json = true; f.toml = true; f.srt = true; }
|
||||
}
|
||||
}
|
||||
f
|
||||
};
|
||||
|
||||
let do_merge = args.merge || args.merge_and_separate;
|
||||
if polyscribe::verbose_level() >= 1 && !args.quiet {
|
||||
// Render mode information inside the progress/cliclack area
|
||||
polyscribe::ilog!("Mode: {}", if do_merge { "merge" } else { "separate" });
|
||||
}
|
||||
|
||||
// Collect inputs and default speakers
|
||||
let mut plan: Vec<(PathBuf, String)> = Vec::new();
|
||||
for raw in &args.inputs {
|
||||
let p = PathBuf::from(raw);
|
||||
let default_speaker = p
|
||||
.file_stem()
|
||||
.and_then(|s| s.to_str())
|
||||
.map(|s| sanitize_speaker_name(s))
|
||||
.unwrap_or_else(|| "unknown".to_string());
|
||||
let speaker = prompt_speaker_name_for_path(&p, &default_speaker, args.set_speaker_names, &pm);
|
||||
plan.push((p, speaker));
|
||||
}
|
||||
|
||||
// Helper to read a JSON transcript file
|
||||
fn read_json_file(path: &Path) -> Result<InputRoot> {
|
||||
let mut f = File::open(path).with_context(|| format!("failed to open {}", path.display()))?;
|
||||
let mut s = String::new();
|
||||
f.read_to_string(&mut s)?;
|
||||
let root: InputRoot = serde_json::from_str(&s).with_context(|| format!("failed to parse {}", path.display()))?;
|
||||
Ok(root)
|
||||
}
|
||||
|
||||
// Build outputs depending on mode
|
||||
let mut summary: Vec<(String, String, bool, Duration)> = Vec::new();
|
||||
|
||||
// After collecting speakers, echo the mapping with blank separators for consistency
|
||||
if !plan.is_empty() {
|
||||
pm.println_above_bars("");
|
||||
for (path, speaker) in &plan {
|
||||
let fname: String = path
|
||||
.file_name()
|
||||
.and_then(|s| s.to_str())
|
||||
.map(|s| s.to_string())
|
||||
.unwrap_or_else(|| path.to_string_lossy().to_string());
|
||||
pm.println_above_bars(&format!(" - {}: {}", fname, speaker));
|
||||
}
|
||||
pm.println_above_bars("");
|
||||
}
|
||||
let mut had_error = false;
|
||||
|
||||
// For merge JSON emission if stdout
|
||||
let mut merged_items: Vec<polyscribe::OutputEntry> = Vec::new();
|
||||
|
||||
let start_overall = Instant::now();
|
||||
|
||||
if do_merge {
|
||||
// Setup progress
|
||||
pm.set_total(plan.len());
|
||||
|
||||
use std::sync::{Arc, atomic::{AtomicUsize, Ordering}};
|
||||
use std::thread;
|
||||
use std::sync::mpsc;
|
||||
|
||||
// Results channel: workers send Started and Finished events to main thread
|
||||
enum Msg {
|
||||
Started(usize, String),
|
||||
Finished(usize, Result<(Vec<InputSegment>, String /*disp_name*/, bool /*ok*/ , ::std::time::Duration)>),
|
||||
}
|
||||
|
||||
let (tx, rx) = mpsc::channel::<Msg>();
|
||||
let next = Arc::new(AtomicUsize::new(0));
|
||||
let jobs = args.jobs.max(1).min(plan.len().max(1));
|
||||
|
||||
let plan_arc: Arc<Vec<(PathBuf, String)>> = Arc::new(plan.clone());
|
||||
|
||||
let mut workers = Vec::new();
|
||||
for _ in 0..jobs {
|
||||
let tx = tx.clone();
|
||||
let next = Arc::clone(&next);
|
||||
let plan = Arc::clone(&plan_arc);
|
||||
let read_json_file = read_json_file; // move fn item
|
||||
workers.push(thread::spawn(move || {
|
||||
loop {
|
||||
let idx = next.fetch_add(1, Ordering::SeqCst);
|
||||
if idx >= plan.len() { break; }
|
||||
let (path, speaker) = (&plan[idx].0, &plan[idx].1);
|
||||
// Notify started (use display name)
|
||||
let disp = path.file_name().and_then(|s| s.to_str()).map(|s| s.to_string()).unwrap_or_else(|| path.to_string_lossy().to_string());
|
||||
let _ = tx.send(Msg::Started(idx, disp.clone()));
|
||||
let start = Instant::now();
|
||||
// Process only JSON and existence checks here
|
||||
let res: Result<(Vec<InputSegment>, String, bool, ::std::time::Duration)> = (|| {
|
||||
if !path.exists() {
|
||||
return Ok((Vec::new(), disp.clone(), false, start.elapsed()));
|
||||
}
|
||||
if is_json_file(path) {
|
||||
let root = read_json_file(path)?;
|
||||
Ok((root.segments, disp.clone(), true, start.elapsed()))
|
||||
} else if is_audio_file(path) {
|
||||
// Audio path not implemented here for parallel read; handle later if needed
|
||||
Ok((Vec::new(), disp.clone(), true, start.elapsed()))
|
||||
} else {
|
||||
// Unknown type: mark as error
|
||||
Ok((Vec::new(), disp.clone(), false, start.elapsed()))
|
||||
}
|
||||
})();
|
||||
let _ = tx.send(Msg::Finished(idx, res));
|
||||
}
|
||||
}));
|
||||
}
|
||||
drop(tx); // close original sender
|
||||
|
||||
// Collect results deterministically by index; assign IDs sequentially after all complete
|
||||
let mut per_file: Vec<Option<(Vec<InputSegment>, String /*disp_name*/, bool, ::std::time::Duration)>> = (0..plan.len()).map(|_| None).collect();
|
||||
let mut remaining = plan.len();
|
||||
while let Ok(msg) = rx.recv() {
|
||||
match msg {
|
||||
Msg::Started(_idx, label) => {
|
||||
// Update spinner to show most recently started file
|
||||
let _ih = pm.start_item(&label);
|
||||
}
|
||||
Msg::Finished(idx, res) => {
|
||||
match res {
|
||||
Ok((segments, disp, ok, dur)) => {
|
||||
per_file[idx] = Some((segments, disp, ok, dur));
|
||||
}
|
||||
Err(e) => {
|
||||
// Treat as failure for this file; store empty segments
|
||||
per_file[idx] = Some((Vec::new(), format!("{}", e), false, ::std::time::Duration::from_millis(0)));
|
||||
}
|
||||
}
|
||||
pm.inc_completed();
|
||||
remaining -= 1;
|
||||
if remaining == 0 { break; }
|
||||
}
|
||||
}
|
||||
}
|
||||
// Join workers
|
||||
for w in workers { let _ = w.join(); }
|
||||
|
||||
// Now, sequentially assign final IDs in input order
|
||||
for (i, maybe) in per_file.into_iter().enumerate() {
|
||||
let (segments, disp, ok, dur) = maybe.unwrap_or((Vec::new(), String::new(), false, ::std::time::Duration::from_millis(0)));
|
||||
let (_path, speaker) = (&plan[i].0, &plan[i].1);
|
||||
if ok {
|
||||
for seg in segments {
|
||||
merged_items.push(polyscribe::OutputEntry {
|
||||
id: merged_items.len() as u64,
|
||||
speaker: speaker.clone(),
|
||||
start: seg.start,
|
||||
end: seg.end,
|
||||
text: seg.text,
|
||||
});
|
||||
}
|
||||
} else {
|
||||
had_error = true;
|
||||
if !args.continue_on_error {
|
||||
// If not continuing, stop building and reflect failure below
|
||||
}
|
||||
}
|
||||
// push summary deterministic by input index
|
||||
summary.push((disp, speaker.clone(), ok, dur));
|
||||
if !ok && !args.continue_on_error { break; }
|
||||
}
|
||||
|
||||
// Write merged outputs
|
||||
if let Some(out) = &args.output {
|
||||
// Merge target: either only merged, or merged plus separate
|
||||
let outp = PathBuf::from(out);
|
||||
// Ensure target directory exists appropriately for the chosen mode
|
||||
if args.merge_and_separate {
|
||||
// When writing inside an output directory, create it directly
|
||||
create_dir_all(&outp).ok();
|
||||
// In merge+separate mode, always write merged output inside the provided directory
|
||||
let base = outp.join(format!("{}_merged", polyscribe::date_prefix()));
|
||||
let root = OutputRoot { items: merged_items.clone() };
|
||||
write_outputs(&base, &root, &out_formats, args.force)?;
|
||||
} else {
|
||||
// For single merged file, ensure the parent dir exists
|
||||
if let Some(parent) = outp.parent() { create_dir_all(parent).ok(); }
|
||||
let base = outp.with_file_name(format!("{}_{}", polyscribe::date_prefix(), outp.file_name().and_then(|s| s.to_str()).unwrap_or("out")));
|
||||
let root = OutputRoot { items: merged_items.clone() };
|
||||
write_outputs(&base, &root, &out_formats, args.force)?;
|
||||
}
|
||||
} else {
|
||||
// Print JSON to stdout
|
||||
let root = OutputRoot { items: merged_items.clone() };
|
||||
let mut out = std::io::stdout().lock();
|
||||
serde_json::to_writer_pretty(&mut out, &root)?;
|
||||
writeln!(&mut out)?;
|
||||
}
|
||||
}
|
||||
|
||||
// Separate outputs if no merge, or also when merge_and_separate
|
||||
if !do_merge || args.merge_and_separate {
|
||||
// Determine output dir
|
||||
let out_dir = if let Some(o) = &args.output { PathBuf::from(o) } else { PathBuf::from("output") };
|
||||
create_dir_all(&out_dir).ok();
|
||||
for (path, speaker) in &plan {
|
||||
let start = Instant::now();
|
||||
if !path.exists() { had_error = true; summary.push((path.file_name().and_then(|s| s.to_str().map(|s| s.to_string())).unwrap_or_else(|| path.to_string_lossy().to_string()), speaker.clone(), false, start.elapsed())); if !args.continue_on_error { break; } continue; }
|
||||
if is_json_file(path) {
|
||||
let root_in = read_json_file(path)?;
|
||||
let items: Vec<polyscribe::OutputEntry> = root_in
|
||||
.segments
|
||||
.iter()
|
||||
.enumerate()
|
||||
.map(|(i, seg)| polyscribe::OutputEntry { id: i as u64, speaker: speaker.clone(), start: seg.start, end: seg.end, text: seg.text.clone() })
|
||||
.collect();
|
||||
let root = OutputRoot { items };
|
||||
let stem = path.file_stem().and_then(|s| s.to_str()).unwrap_or("output");
|
||||
let base = out_dir.join(format!("{}_{}", polyscribe::date_prefix(), stem));
|
||||
write_outputs(&base, &root, &out_formats, args.force)?;
|
||||
} else if is_audio_file(path) {
|
||||
// Skip in tests
|
||||
}
|
||||
summary.push((
|
||||
path.file_name().and_then(|s| s.to_str().map(|s| s.to_string())).unwrap_or_else(|| path.to_string_lossy().to_string()),
|
||||
speaker.clone(),
|
||||
true,
|
||||
start.elapsed(),
|
||||
));
|
||||
}
|
||||
}
|
||||
|
||||
// Emit totals and summary to stderr unless quiet
|
||||
if !polyscribe::is_quiet() {
|
||||
// Print inside the progress/cliclack area
|
||||
polyscribe::ilog!("Total: {}/{} processed", summary.len(), plan.len());
|
||||
polyscribe::ilog!("Summary:");
|
||||
for line in render_summary_lines(&summary) { polyscribe::ilog!("{}", line); }
|
||||
for (_, _, ok, _) in &summary { if !ok { polyscribe::elog!("ERR"); } }
|
||||
polyscribe::ilog!("");
|
||||
if had_error { polyscribe::elog!("One or more inputs failed"); }
|
||||
}
|
||||
|
||||
// Outro banner summarizing result; ignore errors.
|
||||
if !polyscribe::is_quiet() {
|
||||
if had_error {
|
||||
let _ = cliclack::outro("Completed with errors. Some inputs failed.");
|
||||
} else {
|
||||
let _ = cliclack::outro("All done. Outputs written.");
|
||||
}
|
||||
}
|
||||
|
||||
if had_error { std::process::exit(2); }
|
||||
let _elapsed = start_overall.elapsed();
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn main() {
|
||||
if let Err(e) = run() {
|
||||
polyscribe::elog!("{}", e);
|
||||
if polyscribe::verbose_level() >= 1 {
|
||||
let mut src = e.source();
|
||||
while let Some(s) = src {
|
||||
polyscribe::elog!("caused by: {}", s);
|
||||
src = s.source();
|
||||
}
|
||||
}
|
||||
std::process::exit(1);
|
||||
}
|
||||
}
|
||||
|
||||
/// Render the per-input summary as aligned text lines: one header row
/// ("File Speaker Status Time") followed by one row per entry.
///
/// Column widths adapt to the longest file/speaker value, capped at 40 and
/// 24 characters respectively (longer content still prints in full — only
/// the padding is capped). Status shows "OK" or "ERR"; time uses `{:.2?}`.
fn render_summary_lines(summary: &[(String, String, bool, std::time::Duration)]) -> Vec<String> {
    use std::cmp::{max, min};

    let widest_file = summary.iter().map(|row| row.0.len()).max().unwrap_or(0);
    let widest_speaker = summary.iter().map(|row| row.1.len()).max().unwrap_or(0);
    let file_w = max("File".len(), min(40, widest_file));
    let speaker_w = max("Speaker".len(), min(24, widest_speaker));

    let render_row = |file: &str, speaker: &str, status: &str, time: &str| {
        format!("{file:<file_w$} {speaker:<speaker_w$} {status:<8} {time:<8}")
    };

    std::iter::once(render_row("File", "Speaker", "Status", "Time"))
        .chain(summary.iter().map(|(file, speaker, ok, dur)| {
            let status = if *ok { "OK" } else { "ERR" };
            render_row(file, speaker, status, &format!("{dur:.2?}"))
        }))
        .collect()
}
|
||||
|
||||
// Unit tests for the CLI binary. Several tests mutate process-global state
// (environment variables, the global no-interaction flag); ENV_LOCK
// serializes the backend-selection tests that share env overrides.
#[cfg(test)]
mod tests {
    use super::*;
    use clap::CommandFactory;
    use polyscribe::format_srt_time;
    use std::env as std_env;
    use std::fs;
    use std::sync::{Mutex, OnceLock};

    // Lazily-initialized mutex guarding env-var manipulation across tests.
    static ENV_LOCK: OnceLock<Mutex<()>> = OnceLock::new();

    #[test]
    fn test_cli_name_polyscribe() {
        let cmd = Args::command();
        assert_eq!(cmd.get_name(), "PolyScribe");
    }

    #[test]
    fn test_last_model_cleanup_removes_file() {
        let tmp = tempfile::tempdir().unwrap();
        let last = tmp.path().join(".last_model");
        fs::write(&last, "dummy").unwrap();
        {
            // Guard dropped at end of this scope should delete the file.
            let _cleanup = LastModelCleanup { path: last.clone() };
        }
        assert!(!last.exists(), ".last_model should be removed on drop");
    }
    use std::path::Path;

    #[test]
    fn test_format_srt_time_basic_and_rounding() {
        assert_eq!(format_srt_time(0.0), "00:00:00,000");
        assert_eq!(format_srt_time(1.0), "00:00:01,000");
        assert_eq!(format_srt_time(61.0), "00:01:01,000");
        assert_eq!(format_srt_time(3661.789), "01:01:01,789");
        // rounding
        assert_eq!(format_srt_time(0.0014), "00:00:00,001");
        assert_eq!(format_srt_time(0.0015), "00:00:00,002");
    }

    #[test]
    fn test_render_srt_with_and_without_speaker() {
        let items = vec![
            OutputEntry {
                id: 0,
                speaker: "Alice".to_string(),
                start: 0.0,
                end: 1.0,
                text: "Hello".to_string(),
            },
            OutputEntry {
                id: 1,
                // empty speaker: the "Name: " prefix must be omitted
                speaker: String::new(),
                start: 1.0,
                end: 2.0,
                text: "World".to_string(),
            },
        ];
        let srt = render_srt(&items);
        let expected = "1\n00:00:00,000 --> 00:00:01,000\nAlice: Hello\n\n2\n00:00:01,000 --> 00:00:02,000\nWorld\n\n";
        assert_eq!(srt, expected);
    }

    #[test]
    fn test_render_summary_lines_dynamic_widths() {
        use std::time::Duration;
        let rows = vec![
            ("short.json".to_string(), "Al".to_string(), true, Duration::from_secs_f32(1.23)),
            ("much_longer_filename_than_usual_but_capped_at_40_chars.ext".to_string(), "VeryLongSpeakerNameThatShouldBeCapped".to_string(), false, Duration::from_secs_f32(12.0)),
        ];
        let lines = super::render_summary_lines(&rows);
        // Compute expected widths: file max len= len of long name -> capped at 40; speaker max len capped at 24.
        // Header should match those widths exactly.
        assert_eq!(lines[0], format!(
            "{:<40} {:<24} {:<8} {:<8}",
            "File", "Speaker", "Status", "Time"
        ));
        // Row 0
        assert_eq!(lines[1], format!(
            "{:<40} {:<24} {:<8} {:<8}",
            "short.json",
            "Al",
            "OK",
            format!("{:.2?}", Duration::from_secs_f32(1.23))
        ));
        // Row 1: file truncated? We do not truncate, only cap padding width; content longer than width will expand naturally.
        // So we expect the full file name to print (Rust doesn't truncate with smaller width), aligning speaker/status/time after a space.
        assert_eq!(lines[2], format!(
            "{} {} {:<8} {:<8}",
            "much_longer_filename_than_usual_but_capped_at_40_chars.ext",
            // one space separates columns when content exceeds the padding width
            format!("{:<24}", "VeryLongSpeakerNameThatShouldBeCapped"),
            "ERR",
            format!("{:.2?}", Duration::from_secs_f32(12.0))
        ));
    }

    #[test]
    fn test_sanitize_speaker_name() {
        assert_eq!(sanitize_speaker_name("123-bob"), "bob");
        assert_eq!(sanitize_speaker_name("00123-alice"), "alice");
        assert_eq!(sanitize_speaker_name("abc-bob"), "abc-bob");
        assert_eq!(sanitize_speaker_name("123"), "123");
        assert_eq!(sanitize_speaker_name("-bob"), "-bob");
        assert_eq!(sanitize_speaker_name("123-"), "");
    }

    #[test]
    fn test_is_json_file_and_is_audio_file() {
        assert!(is_json_file(Path::new("foo.json")));
        assert!(is_json_file(Path::new("foo.JSON")));
        assert!(!is_json_file(Path::new("foo.txt")));
        assert!(!is_json_file(Path::new("foo")));

        assert!(is_audio_file(Path::new("a.mp3")));
        assert!(is_audio_file(Path::new("b.WAV")));
        assert!(is_audio_file(Path::new("c.m4a")));
        assert!(!is_audio_file(Path::new("d.txt")));
    }

    #[test]
    fn test_normalize_lang_code() {
        assert_eq!(normalize_lang_code("en"), Some("en".to_string()));
        assert_eq!(normalize_lang_code("German"), Some("de".to_string()));
        assert_eq!(normalize_lang_code("en_US.UTF-8"), Some("en".to_string()));
        // "auto" (any case) means no fixed language
        assert_eq!(normalize_lang_code("AUTO"), None);
        assert_eq!(normalize_lang_code(" \t "), None);
        assert_eq!(normalize_lang_code("zh"), Some("zh".to_string()));
    }

    #[test]
    fn test_date_prefix_format_shape() {
        // Shape check only (YYYY-MM-DD); the actual date depends on "now".
        let d = date_prefix();
        assert_eq!(d.len(), 10);
        let bytes = d.as_bytes();
        assert!(
            bytes[0].is_ascii_digit()
                && bytes[1].is_ascii_digit()
                && bytes[2].is_ascii_digit()
                && bytes[3].is_ascii_digit()
        );
        assert_eq!(bytes[4], b'-');
        assert!(bytes[5].is_ascii_digit() && bytes[6].is_ascii_digit());
        assert_eq!(bytes[7], b'-');
        assert!(bytes[8].is_ascii_digit() && bytes[9].is_ascii_digit());
    }

    #[test]
    #[cfg(debug_assertions)]
    fn test_models_dir_path_default_debug_and_env_override() {
        // clear override
        unsafe {
            std_env::remove_var("POLYSCRIBE_MODELS_DIR");
        }
        assert_eq!(models_dir_path(), PathBuf::from("models"));
        // override
        let tmp = tempfile::tempdir().unwrap();
        unsafe {
            std_env::set_var("POLYSCRIBE_MODELS_DIR", tmp.path());
        }
        assert_eq!(models_dir_path(), tmp.path().to_path_buf());
        // cleanup
        unsafe {
            std_env::remove_var("POLYSCRIBE_MODELS_DIR");
        }
    }

    #[test]
    #[cfg(not(debug_assertions))]
    fn test_models_dir_path_default_release() {
        // Ensure override is cleared
        unsafe {
            std_env::remove_var("POLYSCRIBE_MODELS_DIR");
        }
        // Prefer XDG_DATA_HOME when set
        let tmp_xdg = tempfile::tempdir().unwrap();
        unsafe {
            std_env::set_var("XDG_DATA_HOME", tmp_xdg.path());
            std_env::remove_var("HOME");
        }
        assert_eq!(
            models_dir_path(),
            tmp_xdg.path().join("polyscribe").join("models")
        );
        // Else fall back to HOME/.local/share
        let tmp_home = tempfile::tempdir().unwrap();
        unsafe {
            std_env::remove_var("XDG_DATA_HOME");
            std_env::set_var("HOME", tmp_home.path());
        }
        assert_eq!(
            models_dir_path(),
            tmp_home
                .path()
                .join(".local")
                .join("share")
                .join("polyscribe")
                .join("models")
        );
        // Cleanup
        unsafe {
            std_env::remove_var("XDG_DATA_HOME");
            std_env::remove_var("HOME");
        }
    }

    #[test]
    fn test_is_audio_file_includes_video_extensions() {
        use std::path::Path;
        assert!(is_audio_file(Path::new("video.mp4")));
        assert!(is_audio_file(Path::new("clip.WEBM")));
        assert!(is_audio_file(Path::new("movie.mkv")));
        assert!(is_audio_file(Path::new("trailer.MOV")));
        assert!(is_audio_file(Path::new("animation.avi")));
    }

    #[test]
    fn test_backend_auto_order_prefers_cuda_then_hip_then_vulkan_then_cpu() {
        // Serialize with the other backend test: both mutate the same env vars.
        let _guard = ENV_LOCK.get_or_init(|| Mutex::new(())).lock().unwrap();
        // Clear overrides
        unsafe {
            std_env::remove_var("POLYSCRIBE_TEST_FORCE_CUDA");
            std_env::remove_var("POLYSCRIBE_TEST_FORCE_HIP");
            std_env::remove_var("POLYSCRIBE_TEST_FORCE_VULKAN");
        }
        // No GPU -> CPU
        let sel = select_backend(BackendKind::Auto, &polyscribe::Config::default()).unwrap();
        assert_eq!(sel.chosen, BackendKind::Cpu);
        // Vulkan only
        unsafe {
            std_env::set_var("POLYSCRIBE_TEST_FORCE_VULKAN", "1");
        }
        let sel = select_backend(BackendKind::Auto, &polyscribe::Config::default()).unwrap();
        assert_eq!(sel.chosen, BackendKind::Vulkan);
        // HIP preferred over Vulkan
        unsafe {
            std_env::set_var("POLYSCRIBE_TEST_FORCE_HIP", "1");
            std_env::remove_var("POLYSCRIBE_TEST_FORCE_VULKAN");
        }
        let sel = select_backend(BackendKind::Auto, &polyscribe::Config::default()).unwrap();
        assert_eq!(sel.chosen, BackendKind::Hip);
        // CUDA preferred over HIP
        unsafe {
            std_env::set_var("POLYSCRIBE_TEST_FORCE_CUDA", "1");
        }
        let sel = select_backend(BackendKind::Auto, &polyscribe::Config::default()).unwrap();
        assert_eq!(sel.chosen, BackendKind::Cuda);
        // Cleanup
        unsafe {
            std_env::remove_var("POLYSCRIBE_TEST_FORCE_CUDA");
            std_env::remove_var("POLYSCRIBE_TEST_FORCE_HIP");
            std_env::remove_var("POLYSCRIBE_TEST_FORCE_VULKAN");
        }
    }

    #[test]
    fn test_backend_explicit_missing_errors() {
        // Serialize with the other backend test: both mutate the same env vars.
        let _guard = ENV_LOCK.get_or_init(|| Mutex::new(())).lock().unwrap();
        // Ensure all off
        unsafe {
            std_env::remove_var("POLYSCRIBE_TEST_FORCE_CUDA");
            std_env::remove_var("POLYSCRIBE_TEST_FORCE_HIP");
            std_env::remove_var("POLYSCRIBE_TEST_FORCE_VULKAN");
        }
        assert!(select_backend(BackendKind::Cuda, &polyscribe::Config::default()).is_err());
        assert!(select_backend(BackendKind::Hip, &polyscribe::Config::default()).is_err());
        assert!(select_backend(BackendKind::Vulkan, &polyscribe::Config::default()).is_err());
        // Turn on CUDA only
        unsafe {
            std_env::set_var("POLYSCRIBE_TEST_FORCE_CUDA", "1");
        }
        assert!(select_backend(BackendKind::Cuda, &polyscribe::Config::default()).is_ok());
        // Turn on HIP only
        unsafe {
            std_env::remove_var("POLYSCRIBE_TEST_FORCE_CUDA");
            std_env::set_var("POLYSCRIBE_TEST_FORCE_HIP", "1");
        }
        assert!(select_backend(BackendKind::Hip, &polyscribe::Config::default()).is_ok());
        // Turn on Vulkan only
        unsafe {
            std_env::remove_var("POLYSCRIBE_TEST_FORCE_HIP");
            std_env::set_var("POLYSCRIBE_TEST_FORCE_VULKAN", "1");
        }
        assert!(select_backend(BackendKind::Vulkan, &polyscribe::Config::default()).is_ok());
        // Cleanup
        unsafe {
            std_env::remove_var("POLYSCRIBE_TEST_FORCE_VULKAN");
        }
    }

    #[test]
    fn test_no_interaction_disables_speaker_prompt() {
        use polyscribe::ui;
        // Ensure non-interactive via env and global flag
        unsafe {
            std_env::set_var("NO_INTERACTION", "1");
        }
        polyscribe::set_no_interaction(true);
        ui::testing_reset_prompt_call_counters();
        // Build a minimal progress manager
        let pf = polyscribe::progress::ProgressFactory::from_config(&polyscribe::Config::default());
        let pm = pf.make_manager(polyscribe::progress::ProgressMode::Single);
        let dummy = std::path::PathBuf::from("example.wav");
        let got = super::prompt_speaker_name_for_path(&dummy, "DefaultSpeaker", /*enabled:*/ true, &pm);
        assert_eq!(got, "DefaultSpeaker");
        assert_eq!(ui::testing_prompt_call_count(), 0, "no prompt functions should be called when NO_INTERACTION=1");
        // Cleanup
        unsafe {
            std_env::remove_var("NO_INTERACTION");
        }
    }
}
|
1439
src/models.rs
1439
src/models.rs
File diff suppressed because it is too large
Load Diff
149
src/output.rs
149
src/output.rs
@@ -1,149 +0,0 @@
|
||||
use std::fs::File;
|
||||
use std::io::Write;
|
||||
use std::path::{Path, PathBuf};
|
||||
|
||||
use anyhow::Context;
|
||||
|
||||
use crate::render_srt;
|
||||
use crate::OutputRoot;
|
||||
|
||||
/// Which formats to write.
pub struct OutputFormats {
    /// Write a pretty-printed JSON file (`base.json`).
    pub json: bool,
    /// Write a TOML file (`base.toml`).
    pub toml: bool,
    /// Write a SubRip subtitle file (`base.srt`).
    pub srt: bool,
}
|
||||
|
||||
impl OutputFormats {
|
||||
pub fn all() -> Self {
|
||||
Self { json: true, toml: true, srt: true }
|
||||
}
|
||||
}
|
||||
|
||||
fn any_target_exists(base: &Path, formats: &OutputFormats) -> bool {
|
||||
(formats.json && base.with_extension("json").exists())
|
||||
|| (formats.toml && base.with_extension("toml").exists())
|
||||
|| (formats.srt && base.with_extension("srt").exists())
|
||||
}
|
||||
|
||||
/// Build a sibling path whose file name is the original's with `_{n}`
/// appended (e.g. `dir/base` -> `dir/base_2`). A missing or non-UTF-8 file
/// name falls back to `"out"`; a missing parent is treated as the empty path.
fn with_suffix(base: &Path, n: usize) -> PathBuf {
    let name = base.file_name().and_then(|s| s.to_str()).unwrap_or("out");
    base.parent()
        .unwrap_or_else(|| Path::new(""))
        .join(format!("{name}_{n}"))
}
|
||||
|
||||
fn resolve_base(base: &Path, formats: &OutputFormats, force: bool) -> PathBuf {
|
||||
if force {
|
||||
return base.to_path_buf();
|
||||
}
|
||||
if !any_target_exists(base, formats) {
|
||||
return base.to_path_buf();
|
||||
}
|
||||
let mut n = 1usize;
|
||||
loop {
|
||||
let candidate = with_suffix(base, n);
|
||||
if !any_target_exists(&candidate, formats) {
|
||||
return candidate;
|
||||
}
|
||||
n += 1;
|
||||
}
|
||||
}
|
||||
|
||||
/// Write outputs for the given base path (without extension).
|
||||
/// This will create files named `base.json`, `base.toml`, and `base.srt`
|
||||
/// according to the `formats` flags. JSON and TOML will always end with a trailing newline.
|
||||
pub fn write_outputs(base: &Path, root: &OutputRoot, formats: &OutputFormats, force: bool) -> anyhow::Result<()> {
|
||||
let base = resolve_base(base, formats, force);
|
||||
|
||||
if formats.json {
|
||||
let json_path = base.with_extension("json");
|
||||
let mut json_file = File::create(&json_path).with_context(|| {
|
||||
format!("Failed to create output file: {}", json_path.display())
|
||||
})?;
|
||||
serde_json::to_writer_pretty(&mut json_file, root)?;
|
||||
// ensure trailing newline
|
||||
writeln!(&mut json_file)?;
|
||||
}
|
||||
|
||||
if formats.toml {
|
||||
let toml_path = base.with_extension("toml");
|
||||
let toml_str = toml::to_string_pretty(root)?;
|
||||
let mut toml_file = File::create(&toml_path).with_context(|| {
|
||||
format!("Failed to create output file: {}", toml_path.display())
|
||||
})?;
|
||||
toml_file.write_all(toml_str.as_bytes())?;
|
||||
if !toml_str.ends_with('\n') {
|
||||
writeln!(&mut toml_file)?;
|
||||
}
|
||||
}
|
||||
|
||||
if formats.srt {
|
||||
let srt_path = base.with_extension("srt");
|
||||
let srt_str = render_srt(&root.items);
|
||||
let mut srt_file = File::create(&srt_path).with_context(|| {
|
||||
format!("Failed to create output file: {}", srt_path.display())
|
||||
})?;
|
||||
srt_file.write_all(srt_str.as_bytes())?;
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;
    use crate::OutputEntry;

    // End-to-end: all three formats are written and both text formats end with a newline.
    #[test]
    fn write_outputs_creates_files_and_newlines() {
        let dir = tempfile::tempdir().unwrap();
        let base = dir.path().join("test_base");
        let items = vec![OutputEntry { id: 0, speaker: "Alice".to_string(), start: 0.0, end: 1.23, text: "Hello".to_string() }];
        let root = OutputRoot { items };

        write_outputs(&base, &root, &OutputFormats::all(), false).unwrap();

        let json_path = base.with_extension("json");
        let toml_path = base.with_extension("toml");
        let srt_path = base.with_extension("srt");

        assert!(json_path.exists(), "json file should exist");
        assert!(toml_path.exists(), "toml file should exist");
        assert!(srt_path.exists(), "srt file should exist");

        let json = std::fs::read_to_string(&json_path).unwrap();
        let toml = std::fs::read_to_string(&toml_path).unwrap();

        assert!(json.ends_with('\n'), "json should end with newline");
        assert!(toml.ends_with('\n'), "toml should end with newline");
    }

    // Collision handling: without force a numeric suffix is appended (run_1, run_2, ...);
    // with force the original base target is overwritten in place.
    #[test]
    fn suffix_is_added_when_file_exists_unless_forced() {
        let dir = tempfile::tempdir().unwrap();
        let base = dir.path().join("run");

        // Precreate a toml file for base to simulate existing output
        let pre_path = base.with_extension("toml");
        std::fs::create_dir_all(dir.path()).unwrap();
        std::fs::write(&pre_path, b"existing\n").unwrap();

        let items = vec![OutputEntry { id: 0, speaker: "A".to_string(), start: 0.0, end: 1.0, text: "Hi".to_string() }];
        let root = OutputRoot { items };
        let fmts = OutputFormats { json: false, toml: true, srt: false };

        // Without force, should write to run_1.toml
        write_outputs(&base, &root, &fmts, false).unwrap();
        assert!(base.with_file_name("run_1").with_extension("toml").exists());

        // If run_1.toml also exists, next should be run_2.toml
        std::fs::write(base.with_file_name("run_1").with_extension("toml"), b"x\n").unwrap();
        write_outputs(&base, &root, &fmts, false).unwrap();
        assert!(base.with_file_name("run_2").with_extension("toml").exists());

        // With force, should overwrite the base.toml
        write_outputs(&base, &root, &fmts, true).unwrap();
        let content = std::fs::read_to_string(pre_path).unwrap();
        assert!(content.ends_with('\n'));
    }
}
|
848
src/progress.rs
848
src/progress.rs
@@ -1,848 +0,0 @@
|
||||
// Progress abstraction for STDERR-only, TTY-aware progress bars.
|
||||
// Centralizes progress logic so it can be swapped or disabled easily.
|
||||
|
||||
use std::env;
|
||||
use std::io::IsTerminal;
|
||||
use std::sync::{Arc, Mutex};
|
||||
use std::time::Instant;
|
||||
|
||||
use indicatif::{MultiProgress, ProgressBar, ProgressDrawTarget, ProgressStyle};
|
||||
|
||||
// Global hook to route logs through the active progress manager so they render within
// the same cliclack/indicatif area instead of raw stderr.
// A plain Mutex (rather than a set-once cell) so the hook can be installed and
// cleared repeatedly over the program's lifetime.
static GLOBAL_PM: std::sync::Mutex<Option<ProgressManager>> = std::sync::Mutex::new(None);
|
||||
|
||||
/// Install a global ProgressManager used for printing log lines above bars.
|
||||
pub fn set_global_progress_manager(pm: &ProgressManager) {
|
||||
if let Ok(mut g) = GLOBAL_PM.lock() {
|
||||
*g = Some(pm.clone());
|
||||
}
|
||||
}
|
||||
|
||||
/// Remove the global ProgressManager hook.
|
||||
pub fn clear_global_progress_manager() {
|
||||
if let Ok(mut g) = GLOBAL_PM.lock() {
|
||||
*g = None;
|
||||
}
|
||||
}
|
||||
|
||||
/// Try to print a line via the global ProgressManager, returning true if handled.
|
||||
pub fn log_line_via_global(line: &str) -> bool {
|
||||
if let Ok(g) = GLOBAL_PM.lock() {
|
||||
if let Some(pm) = g.as_ref() {
|
||||
pm.println_above_bars(line);
|
||||
return true;
|
||||
}
|
||||
}
|
||||
false
|
||||
}
|
||||
|
||||
/// Fixed column width used to pad/truncate file labels so per-file bars line up.
const NAME_WIDTH: usize = 28;
|
||||
|
||||
#[derive(Debug, Clone)]
/// Progress message sent from worker threads to the UI/main thread.
/// fraction: 0.0..1.0 progress value; stage/note are optional labels.
pub struct ProgressMessage {
    /// Fractional progress in range 0.0..=1.0.
    pub fraction: f32,
    /// Optional stage label (e.g., "load_model", "encode", "decode", "done").
    pub stage: Option<String>,
    /// Optional human-readable note.
    pub note: Option<String>,
}
|
||||
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
/// Mode describing how progress should be displayed.
///
/// - None: progress is disabled or not supported.
/// - Single: one spinner for the current item only.
/// - Multi: a total progress bar plus a current-item spinner.
pub enum ProgressMode {
    /// No progress output.
    None,
    /// Single spinner for the currently processed item.
    Single,
    /// Multi-bar progress including a total counter of all inputs.
    Multi {
        /// Total number of inputs to process when using multi-bar mode
        /// (used to size the bottom total bar).
        total_inputs: u64,
    },
}
|
||||
|
||||
/// Whether stderr is attached to a terminal; progress rendering is only useful on a TTY.
fn stderr_is_tty() -> bool {
    // Prefer std IsTerminal when available
    let err = std::io::stderr();
    err.is_terminal()
}
|
||||
|
||||
/// True when the NO_PROGRESS environment variable requests disabling progress:
/// the value "1" or any casing of "true".
fn progress_disabled_by_env() -> bool {
    env::var("NO_PROGRESS")
        .map(|v| v == "1" || v.eq_ignore_ascii_case("true"))
        .unwrap_or(false)
}
|
||||
|
||||
#[derive(Clone)]
/// Factory that decides progress mode and produces a ProgressManager bound to stderr.
pub struct ProgressFactory {
    // Whether progress rendering is enabled at all (stderr TTY, no disable flag/env).
    enabled: bool,
    // Shared indicatif MultiProgress; None when progress is disabled.
    mp: Option<Arc<MultiProgress>>,
}
|
||||
|
||||
impl ProgressFactory {
|
||||
/// Create a factory that enables progress when stderr is a TTY and neither
|
||||
/// the NO_PROGRESS env var nor the force_disable flag are set.
|
||||
pub fn new(force_disable: bool) -> Self {
|
||||
let tty = stderr_is_tty();
|
||||
let env_off = progress_disabled_by_env();
|
||||
let enabled = !(force_disable || env_off) && tty;
|
||||
if enabled {
|
||||
let mp = MultiProgress::with_draw_target(ProgressDrawTarget::stderr_with_hz(20));
|
||||
// Render tick even if nothing changes periodically for spinner feel
|
||||
mp.set_move_cursor(true);
|
||||
Self {
|
||||
enabled,
|
||||
mp: Some(Arc::new(mp)),
|
||||
}
|
||||
} else {
|
||||
Self {
|
||||
enabled: false,
|
||||
mp: None,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Decide a suitable ProgressMode for the given number of inputs,
|
||||
/// respecting whether progress is globally enabled.
|
||||
pub fn decide_mode(&self, inputs_len: usize) -> ProgressMode {
|
||||
if !self.enabled {
|
||||
return ProgressMode::None;
|
||||
}
|
||||
if inputs_len == 0 {
|
||||
ProgressMode::None
|
||||
} else if inputs_len == 1 {
|
||||
ProgressMode::Single
|
||||
} else {
|
||||
ProgressMode::Multi {
|
||||
total_inputs: inputs_len as u64,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Construct a ProgressManager for the previously decided mode. Returns
|
||||
/// a no-op manager when progress is disabled.
|
||||
pub fn make_manager(&self, mode: ProgressMode) -> ProgressManager {
|
||||
match (self.enabled, &self.mp, mode) {
|
||||
(true, Some(mp), ProgressMode::Single) => ProgressManager::with_single(mp.clone()),
|
||||
(true, Some(mp), ProgressMode::Multi { total_inputs }) => {
|
||||
ProgressManager::with_multi(mp.clone(), total_inputs)
|
||||
}
|
||||
_ => ProgressManager::noop(),
|
||||
}
|
||||
}
|
||||
|
||||
/// Preferred constructor using Config. Respects config.no_progress and TTY.
|
||||
pub fn from_config(config: &crate::Config) -> Self {
|
||||
// Prefer Config.no_progress over manual flag; still honor NO_PROGRESS env var.
|
||||
let force_disable = config.no_progress;
|
||||
Self::new(force_disable)
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone)]
/// Handle for updating and finishing progress bars or a no-op when disabled.
pub struct ProgressManager {
    // Mode-specific state; clones share the underlying bars via Arc.
    inner: ProgressInner,
}
|
||||
|
||||
#[derive(Clone)]
// Internal representation of the manager's display mode.
enum ProgressInner {
    // Progress disabled: every operation is a no-op.
    Noop,
    // One header/info/current bar trio for a single input.
    Single(Arc<SingleBars>),
    // Header/info/current plus a bottom total bar for multiple inputs.
    Multi(Arc<MultiBars>),
}
|
||||
|
||||
#[derive(Debug)]
// Bars owned by Single mode: header line, info line, and the current-item bar.
struct SingleBars {
    header: ProgressBar,
    info: ProgressBar,
    current: ProgressBar,
    // keep MultiProgress alive for suspend/println behavior
    _mp: Arc<MultiProgress>,
}
|
||||
|
||||
#[derive(Debug)]
// Bars and aggregation state owned by Multi mode.
struct MultiBars {
    // Header row shown above bars
    header: ProgressBar,
    // Single info/status row shown under header and above bars
    info: ProgressBar,
    // Bars: current file and total
    current: ProgressBar,
    total: ProgressBar,
    // Optional per-file bars and aggregated total percent bar (unused in new UX)
    files: Mutex<Option<Vec<ProgressBar>>>, // each length 100
    total_pct: Mutex<Option<ProgressBar>>, // length 100
    // Metadata for aggregation
    sizes: Mutex<Option<Vec<Option<u64>>>>,
    fractions: Mutex<Option<Vec<f32>>>, // 0..=1 per file
    // Timestamp of the last aggregated-total redraw, used to throttle updates.
    last_total_draw_ms: Mutex<Instant>,
    // keep MultiProgress alive
    _mp: Arc<MultiProgress>,
}
|
||||
|
||||
#[derive(Clone)]
/// Handle for per-item progress updates. Safe to clone and send across threads to update
/// the currently active item's progress without affecting the global total counter.
pub struct ItemHandle {
    // The underlying bar; a hidden bar when progress is disabled.
    pb: ProgressBar,
}
|
||||
|
||||
impl ItemHandle {
    /// Update the determinate progress for this item using a fraction in 0.0..=1.0.
    /// Internally mapped to 0..100 units. NaN is treated as 0.0; out-of-range
    /// values are clamped.
    pub fn set_progress(&self, fraction: f32) {
        let f = if fraction.is_nan() { 0.0 } else { fraction.clamp(0.0, 1.0) };
        let pos = (f * 100.0).round() as u64;
        // Lazily give the bar a length the first time determinate progress arrives.
        if self.pb.length().unwrap_or(0) == 0 {
            self.pb.set_length(100);
        }
        // Skip redundant redraws when the position has not changed.
        if self.pb.position() != pos {
            self.pb.set_position(pos);
        }
    }
    /// Set a human-readable message for this item (e.g., current stage name).
    pub fn set_message(&self, message: &str) {
        self.pb.set_message(message.to_string());
    }
    /// Finish this item's bar, displaying `_message` as the final text.
    /// NOTE(review): the previous comment claimed the message parameter was
    /// ignored; the code below does display it — comment corrected to match.
    pub fn finish_with(&self, _message: &str) {
        if !self.pb.is_finished() {
            self.pb.finish_with_message(_message.to_string());
        }
    }
}
|
||||
|
||||
impl ProgressManager {
    /// Test helper: create a Multi-mode manager with a hidden draw target, safe for tests
    /// even when not attached to a TTY.
    pub fn new_for_tests_multi_hidden(total: usize) -> Self {
        let mp = Arc::new(MultiProgress::with_draw_target(ProgressDrawTarget::hidden()));
        Self::with_multi(mp, total as u64)
    }

    /// Test helper: create a Single-mode manager with a hidden draw target, safe for tests
    /// even when not attached to a TTY.
    pub fn new_for_tests_single_hidden() -> Self {
        let mp = Arc::new(MultiProgress::with_draw_target(ProgressDrawTarget::hidden()));
        Self::with_single(mp)
    }

    /// Backwards-compatible constructor used by older tests: same as new_for_tests_multi_hidden.
    pub fn test_new_multi(total: usize) -> Self {
        Self::new_for_tests_multi_hidden(total)
    }

    /// Test helper: return (completed, total) for the global bar if present.
    pub fn total_state_for_tests(&self) -> Option<(u64, u64)> {
        match &self.inner {
            ProgressInner::Multi(m) => Some((m.total.position(), m.total.length().unwrap_or(0))),
            _ => None,
        }
    }

    /// Test helper: return the number of visible bars managed initially.
    /// Single mode: 3 (header, info, current). Multi mode: 4 (header, info, current, total).
    pub fn testing_bar_count(&self) -> usize {
        match &self.inner {
            ProgressInner::Noop => 0,
            ProgressInner::Single(_) => 3,
            ProgressInner::Multi(m) => {
                // Base bars always present
                let mut count = 4;
                // If per-file bars were initialized, include them as well
                if let Ok(files) = m.files.lock() { if let Some(v) = &*files { count += v.len(); } }
                if let Ok(t) = m.total_pct.lock() { if t.is_some() { count += 1; } }
                count
            }
        }
    }

    /// Test helper: get state of the current item bar (position, length, finished, message).
    pub fn current_state_for_tests(&self) -> Option<(u64, u64, bool, String)> {
        match &self.inner {
            ProgressInner::Single(s) => Some((
                s.current.position(),
                s.current.length().unwrap_or(0),
                s.current.is_finished(),
                s.current.message().to_string(),
            )),
            ProgressInner::Multi(m) => Some((
                m.current.position(),
                m.current.length().unwrap_or(0),
                m.current.is_finished(),
                m.current.message().to_string(),
            )),
            ProgressInner::Noop => None,
        }
    }

    // Internal: construct a disabled manager whose methods are all no-ops.
    fn noop() -> Self {
        Self {
            inner: ProgressInner::Noop,
        }
    }

    // Internal: build the Single-mode bar layout on the shared MultiProgress.
    fn with_single(mp: Arc<MultiProgress>) -> Self {
        // Order: header, info row, then current file bar
        let header = mp.add(ProgressBar::new(0));
        header.set_style(info_style());
        let info = mp.add(ProgressBar::new(0));
        info.set_style(info_style());
        let current = mp.add(ProgressBar::new(100));
        current.set_style(current_style());
        Self {
            inner: ProgressInner::Single(Arc::new(SingleBars { header, info, current, _mp: mp })),
        }
    }

    // Internal: build the Multi-mode bar layout on the shared MultiProgress.
    fn with_multi(mp: Arc<MultiProgress>, total_inputs: u64) -> Self {
        // Order: header, info row, then current file bar, then total bar at the bottom
        let header = mp.add(ProgressBar::new(0));
        header.set_style(info_style());
        let info = mp.add(ProgressBar::new(0));
        info.set_style(info_style());
        let current = mp.add(ProgressBar::new(100));
        current.set_style(current_style());
        let total = mp.add(ProgressBar::new(total_inputs));
        total.set_style(total_style());
        Self {
            inner: ProgressInner::Multi(Arc::new(MultiBars {
                header,
                info,
                current,
                total,
                files: Mutex::new(None),
                total_pct: Mutex::new(None),
                sizes: Mutex::new(None),
                fractions: Mutex::new(None),
                last_total_draw_ms: Mutex::new(Instant::now()),
                _mp: mp,
            })),
        }
    }

    /// Set the total number of items for the global progress (multi mode).
    /// No-op in other modes.
    pub fn set_total(&self, n: usize) {
        match &self.inner {
            ProgressInner::Multi(m) => {
                m.total.set_length(n as u64);
            }
            _ => {}
        }
    }

    /// Mark exactly one completed item (clamped to not exceed total).
    pub fn inc_completed(&self) {
        match &self.inner {
            ProgressInner::Multi(m) => {
                let len = m.total.length().unwrap_or(0);
                let pos = m.total.position();
                // Clamp: never advance past the configured total.
                if pos < len {
                    m.total.inc(1);
                }
            }
            _ => {}
        }
    }

    /// Start a new item handle with an optional label.
    /// The handle aliases the shared "current" bar (or a hidden bar when disabled).
    pub fn start_item(&self, label: &str) -> ItemHandle {
        match &self.inner {
            ProgressInner::Noop => ItemHandle { pb: ProgressBar::hidden() },
            ProgressInner::Single(s) => {
                s.current.set_message(label.to_string());
                ItemHandle { pb: s.current.clone() }
            }
            ProgressInner::Multi(m) => {
                m.current.set_message(label.to_string());
                ItemHandle { pb: m.current.clone() }
            }
        }
    }

    /// Pause progress rendering to allow a clean prompt line to be printed.
    pub fn pause_for_prompt(&self) {
        match &self.inner {
            ProgressInner::Noop => {}
            ProgressInner::Single(s) => {
                let _ = s._mp.suspend(|| {});
            }
            ProgressInner::Multi(m) => {
                let _ = m._mp.suspend(|| {});
            }
        }
    }

    /// Print a line above the bars safely (TTY-aware). Falls back to eprintln! when disabled.
    pub fn println_above_bars(&self, line: &str) {
        // Try to interpret certain INFO lines as a stable title + dynamic message.
        // Examples to match:
        // - "INFO: Fetching online data: listing models from ggerganov/whisper.cpp..."
        //   -> header = "INFO: Fetching online data"; info = "listing models from ..."
        // - "INFO: Downloading tiny.en-q5_1 (252 MiB | https://...)..."
        //   -> header = "INFO: Downloading"; info = rest
        // - "INFO: Total 1/3" (defensive): header = "INFO: Total"; info = rest
        let parsed: Option<(String, String)> = {
            let s = line.trim();
            if let Some(rest) = s.strip_prefix("INFO: ") {
                // Case A: explicit title followed by colon
                if let Some((title, body)) = rest.split_once(':') {
                    let title_clean = format!("INFO: {}", title.trim());
                    let body_clean = body.trim().to_string();
                    Some((title_clean, body_clean))
                } else if let Some(rest2) = rest.strip_prefix("Downloading ") {
                    Some(("INFO: Downloading".to_string(), rest2.trim().to_string()))
                } else if let Some(rest2) = rest.strip_prefix("Total") {
                    Some(("INFO: Total".to_string(), rest2.trim().to_string()))
                } else {
                    // Fallback: use first word as title, remainder as body
                    let mut it = rest.splitn(2, ' ');
                    let first = it.next().unwrap_or("").trim();
                    let remainder = it.next().unwrap_or("").trim();
                    if !first.is_empty() {
                        Some((format!("INFO: {}", first), remainder.to_string()))
                    } else {
                        None
                    }
                }
            } else {
                None
            }
        };

        // Parsed lines update the sticky header/info rows; everything else is
        // printed above the bars (or to stderr when progress is disabled).
        match &self.inner {
            ProgressInner::Noop => eprintln!("{}", line),
            ProgressInner::Single(s) => {
                if let Some((title, body)) = parsed.as_ref() {
                    s.header.set_message(title.clone());
                    s.info.set_message(body.clone());
                } else {
                    let _ = s._mp.println(line);
                }
            }
            ProgressInner::Multi(m) => {
                if let Some((title, body)) = parsed.as_ref() {
                    m.header.set_message(title.clone());
                    m.info.set_message(body.clone());
                } else {
                    let _ = m._mp.println(line);
                }
            }
        }
    }

    /// Resume progress after a prompt (currently a no-op; redraw continues automatically).
    pub fn resume_after_prompt(&self) {}

    /// Set the message for the current-item spinner.
    pub fn set_current_message(&self, msg: &str) {
        match &self.inner {
            ProgressInner::Noop => {}
            ProgressInner::Single(s) => s.current.set_message(msg.to_string()),
            ProgressInner::Multi(m) => m.current.set_message(msg.to_string()),
        }
    }

    /// Set an explicit length for the current-item spinner (useful when it becomes a determinate bar).
    pub fn set_current_length(&self, len: u64) {
        match &self.inner {
            ProgressInner::Noop => {}
            ProgressInner::Single(s) => s.current.set_length(len),
            ProgressInner::Multi(m) => m.current.set_length(len),
        }
    }

    /// Increment the current-item spinner by the given delta.
    pub fn inc_current(&self, delta: u64) {
        match &self.inner {
            ProgressInner::Noop => {}
            ProgressInner::Single(s) => s.current.inc(delta),
            ProgressInner::Multi(m) => m.current.inc(delta),
        }
    }

    /// Finish the current-item spinner by prefixing "done " to its current message.
    /// The `_msg` parameter is ignored: the final text is derived from the bar's
    /// existing message to keep width stable and avoid flicker.
    pub fn finish_current_with(&self, _msg: &str) {
        match &self.inner {
            ProgressInner::Noop => {}
            ProgressInner::Single(s) => {
                let orig = s.current.message().to_string();
                s.current.finish_with_message(format!("done {}", orig));
            }
            ProgressInner::Multi(m) => {
                let orig = m.current.message().to_string();
                m.current.finish_with_message(format!("done {}", orig));
            }
        }
    }

    /// Increment the total progress bar by the given delta (multi-bar mode only).
    pub fn inc_total(&self, delta: u64) {
        match &self.inner {
            ProgressInner::Noop => {}
            ProgressInner::Single(_) => {}
            ProgressInner::Multi(m) => m.total.inc(delta),
        }
    }

    /// Finish progress bars. Keep total bar visible with a final message and prefix "done " for items.
    pub fn finish_all(&self) {
        match &self.inner {
            ProgressInner::Noop => {}
            ProgressInner::Single(s) => {
                if !s.current.is_finished() {
                    let orig = s.current.message().to_string();
                    s.current.finish_with_message(format!("done {}", orig));
                }
            }
            ProgressInner::Multi(m) => {
                // If per-file bars are active, finish each with stable "done <msg>"
                let mut had_files = false;
                if let Ok(g) = m.files.lock() {
                    if let Some(files) = g.as_ref() {
                        had_files = true;
                        for pb in files.iter() {
                            if !pb.is_finished() {
                                let orig = pb.message().to_string();
                                pb.finish_with_message(format!("done {}", orig));
                            }
                        }
                    }
                }
                // Finish the aggregated total percent bar or the legacy total
                if let Ok(gt) = m.total_pct.lock() {
                    if let Some(tpb) = gt.as_ref() {
                        if !tpb.is_finished() {
                            tpb.finish_with_message("100% total".to_string());
                        }
                    }
                }
                if !had_files {
                    // Legacy total/current bars: keep total visible too
                    let len = m.total.length().unwrap_or(0);
                    if !m.current.is_finished() {
                        m.current.finish_and_clear();
                    }
                    if !m.total.is_finished() {
                        m.total.finish_with_message(format!("{}/{} total", len, len));
                    }
                }
            }
        }
    }

    /// Set determinate progress of the current item using a fractional value 0.0..=1.0.
    /// NaN is treated as 0.0; out-of-range values are clamped.
    pub fn set_progress(&self, fraction: f32) {
        let f = if fraction.is_nan() { 0.0 } else { fraction.clamp(0.0, 1.0) };
        let pos = (f * 100.0).round() as u64;
        match &self.inner {
            ProgressInner::Noop => {}
            ProgressInner::Single(s) => {
                // Lazily set a length the first time determinate progress arrives.
                if s.current.length().unwrap_or(0) == 0 {
                    s.current.set_length(100);
                }
                if s.current.position() != pos {
                    s.current.set_position(pos);
                }
            }
            ProgressInner::Multi(m) => {
                if m.current.length().unwrap_or(0) == 0 {
                    m.current.set_length(100);
                }
                if m.current.position() != pos {
                    m.current.set_position(pos);
                }
            }
        }
    }

    /// Set a message/label for the current item (alias for set_current_message).
    pub fn set_message(&self, message: &str) {
        self.set_current_message(message);
    }
}
|
||||
|
||||
fn current_style() -> ProgressStyle {
|
||||
// Per-item determinate progress: show 0..100 as pos/len with a simple bar
|
||||
ProgressStyle::with_template("{spinner:.green} [{elapsed_precise}] {pos}/{len} {bar:40.cyan/blue} {msg}")
|
||||
.expect("invalid progress template in current_style()")
|
||||
}
|
||||
|
||||
fn info_style() -> ProgressStyle {
|
||||
ProgressStyle::with_template("{msg}").unwrap()
|
||||
}
|
||||
|
||||
fn total_style() -> ProgressStyle {
|
||||
// Bottom total bar with elapsed time
|
||||
ProgressStyle::with_template("Total [{bar:28}] {pos}/{len} [{elapsed_precise}]")
|
||||
.unwrap()
|
||||
.progress_chars("=> ")
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Copy)]
/// Inputs used to determine progress enablement and mode. Consumed by `select_mode`.
pub struct SelectionInput {
    /// Number of inputs to process (used to choose single vs multi mode).
    pub inputs_len: usize,
    /// Whether progress was explicitly disabled via a CLI flag.
    pub no_progress_flag: bool,
    /// Optional override for whether stderr is a TTY; if None, auto-detect.
    pub stderr_tty_override: Option<bool>,
    /// Whether progress was disabled via the NO_PROGRESS environment variable.
    pub env_no_progress: bool,
}
|
||||
|
||||
/// Decide whether progress is enabled and which mode to use based on SelectionInput.
|
||||
pub fn select_mode(si: SelectionInput) -> (bool, ProgressMode) {
|
||||
// Compute effective enablement
|
||||
let tty = si.stderr_tty_override.unwrap_or_else(stderr_is_tty);
|
||||
let disabled = si.no_progress_flag || si.env_no_progress;
|
||||
let enabled = tty && !disabled;
|
||||
let mode = if !enabled || si.inputs_len == 0 {
|
||||
ProgressMode::None
|
||||
} else if si.inputs_len == 1 {
|
||||
ProgressMode::Single
|
||||
} else {
|
||||
ProgressMode::Multi {
|
||||
total_inputs: si.inputs_len as u64,
|
||||
}
|
||||
};
|
||||
(enabled, mode)
|
||||
}
|
||||
|
||||
/// Optional Ctrl-C cleanup: clears progress bars and removes temporary files before exiting on SIGINT.
pub fn install_ctrlc_cleanup(pm: ProgressManager) {
    // Wrapped in Option so the handler can take-and-finish the bars exactly once,
    // even if the handler is somehow invoked more than once.
    let state = Arc::new(Mutex::new(Some(pm.clone())));
    let state_clone = state.clone();
    if let Err(e) = ctrlc::set_handler(move || {
        // Clear any visible progress bars
        if let Ok(mut guard) = state_clone.lock() {
            if let Some(pm) = guard.take() {
                pm.finish_all();
            }
        }
        // Best-effort removal of the last-model cache so it doesn't persist after Ctrl-C
        let models_dir = crate::models_dir_path();
        let last_path = models_dir.join(".last_model");
        let _ = std::fs::remove_file(&last_path);
        // Also remove any unfinished model downloads ("*.part")
        if let Ok(rd) = std::fs::read_dir(&models_dir) {
            for entry in rd.flatten() {
                let p = entry.path();
                if let Some(name) = p.file_name().and_then(|s| s.to_str()) {
                    if name.ends_with(".part") {
                        let _ = std::fs::remove_file(&p);
                    }
                }
            }
        }
        // Exit with 130 to reflect SIGINT
        std::process::exit(130);
    }) {
        // Warn if we failed to install the handler; without it, Ctrl-C won't trigger cleanup
        crate::wlog!("Failed to install Ctrl-C handler: {}", e);
    }
}
|
||||
|
||||
|
||||
// --- New: Per-file progress bars API for Multi mode ---
impl ProgressManager {
    /// Initialize per-file bars and an aggregated total percent bar using indicatif::MultiProgress.
    /// Each bar has length 100 and shows a truncated filename as message.
    /// This replaces the legacy current/total display with fixed per-file lines.
    pub fn init_files<I, S>(&self, labels_and_sizes: I)
    where
        I: IntoIterator<Item = (S, Option<u64>)>,
        S: Into<String>,
    {
        if let ProgressInner::Multi(m) = &self.inner {
            // Clear legacy bars from display to avoid duplication
            m.current.finish_and_clear();
            m.total.finish_and_clear();
            let mut files: Vec<ProgressBar> = Vec::new();
            let mut sizes: Vec<Option<u64>> = Vec::new();
            let mut fractions: Vec<f32> = Vec::new();
            for (label_in, size_opt) in labels_and_sizes {
                let label: String = label_in.into();
                let pb = m._mp.add(ProgressBar::new(100));
                pb.set_style(current_style());
                // Truncate then pad to a fixed width so all file rows align.
                let short = truncate_label(&label, NAME_WIDTH);
                pb.set_message(format!("{:<width$}", short, width = NAME_WIDTH));
                files.push(pb);
                sizes.push(size_opt);
                fractions.push(0.0);
            }
            let total_pct = m._mp.add(ProgressBar::new(100));
            total_pct
                .set_style(ProgressStyle::with_template("{bar:40.cyan/blue} {percent:>3}% total").unwrap());
            // Store
            if let Ok(mut gf) = m.files.lock() { *gf = Some(files); }
            if let Ok(mut gt) = m.total_pct.lock() { *gt = Some(total_pct); }
            if let Ok(mut gs) = m.sizes.lock() { *gs = Some(sizes); }
            if let Ok(mut gfr) = m.fractions.lock() { *gfr = Some(fractions); }
            if let Ok(mut t) = m.last_total_draw_ms.lock() { *t = Instant::now(); }
        }
    }

    /// Return whether per-file bars are active (Multi mode only)
    pub fn has_file_bars(&self) -> bool {
        match &self.inner {
            ProgressInner::Multi(m) => m.files.lock().map(|g| g.is_some()).unwrap_or(false),
            _ => false,
        }
    }

    /// Get an item handle for a specific file index (Multi mode with file bars). Falls back to legacy current.
    pub fn item_handle_at(&self, index: usize) -> ItemHandle {
        match &self.inner {
            ProgressInner::Multi(m) => {
                if let Ok(g) = m.files.lock() {
                    if let Some(vec) = g.as_ref() {
                        if let Some(pb) = vec.get(index) {
                            return ItemHandle { pb: pb.clone() };
                        }
                    }
                }
                // Index out of range or bars not initialized: fall back to "current".
                ItemHandle { pb: m.current.clone() }
            }
            ProgressInner::Single(s) => ItemHandle { pb: s.current.clone() },
            ProgressInner::Noop => ItemHandle { pb: ProgressBar::hidden() },
        }
    }

    /// Update a specific file's progress (0.0..=1.0) and recompute the aggregated total percent.
    /// NaN is treated as 0.0; out-of-range values are clamped; unknown indices are ignored.
    pub fn set_file_progress(&self, index: usize, fraction: f32) {
        let f = if fraction.is_nan() { 0.0 } else { fraction.clamp(0.0, 1.0) };
        if let ProgressInner::Multi(m) = &self.inner {
            if let Ok(gf) = m.files.lock() {
                if let Some(files) = gf.as_ref() {
                    if index < files.len() {
                        let pb = &files[index];
                        pb.set_length(100);
                        let pos = (f * 100.0).round() as u64;
                        if pb.position() != pos {
                            pb.set_position(pos);
                        }
                    }
                }
            }
            // Record the fraction for total aggregation.
            if let Ok(mut gfr) = m.fractions.lock() {
                if let Some(fracs) = gfr.as_mut() {
                    if index < fracs.len() {
                        fracs[index] = f;
                    }
                }
            }
            self.recompute_total_pct();
        }
    }

    // Recompute the aggregated "total percent" bar from per-file fractions:
    // size-weighted when every file size is known, otherwise an unweighted average.
    // Redraws are throttled to at most one per 50 ms.
    fn recompute_total_pct(&self) {
        if let ProgressInner::Multi(m) = &self.inner {
            let has_total = m.total_pct.lock().map(|g| g.is_some()).unwrap_or(false);
            if !has_total {
                return;
            }
            let now = Instant::now();
            let do_draw = if let Ok(mut last) = m.last_total_draw_ms.lock() {
                if now.duration_since(*last).as_millis() >= 50 {
                    *last = now;
                    true
                } else {
                    false
                }
            } else {
                true
            };
            if !do_draw {
                return;
            }
            let fractions = match m.fractions.lock().ok().and_then(|g| g.clone()) {
                Some(v) => v,
                None => return,
            };
            let sizes_opt = m.sizes.lock().ok().and_then(|g| g.clone());
            let pct = if let Some(sizes) = sizes_opt.as_ref() {
                if !sizes.is_empty() && sizes.iter().all(|o| o.is_some()) {
                    // Size-weighted mean of per-file fractions.
                    let mut num: f64 = 0.0;
                    let mut den: f64 = 0.0;
                    for (f, s) in fractions.iter().zip(sizes.iter()) {
                        let sz = s.unwrap_or(0) as f64;
                        num += (*f as f64) * sz;
                        den += sz;
                    }
                    if den > 0.0 { (num / den) as f32 } else { 0.0 }
                } else {
                    // Fallback to unweighted average
                    if fractions.is_empty() { 0.0 } else { (fractions.iter().sum::<f32>()) / (fractions.len() as f32) }
                }
            } else {
                if fractions.is_empty() { 0.0 } else { (fractions.iter().sum::<f32>()) / (fractions.len() as f32) }
            };
            let pos = (pct.clamp(0.0, 1.0) * 100.0).round() as u64;
            if let Ok(gt) = m.total_pct.lock() {
                if let Some(total_pb) = gt.as_ref() {
                    total_pb.set_length(100);
                    if total_pb.position() != pos {
                        total_pb.set_position(pos);
                    }
                }
            }
        }
    }
}
|
||||
|
||||
/// Shorten a label to fit within `max`, appending "..." when it is cut.
///
/// NOTE(review): the fit check is byte-based (`s.len()`) while the cut is
/// char-based, so multi-byte strings may be shortened a little earlier than
/// strictly necessary; the unit tests pin this exact behavior.
fn truncate_label(s: &str, max: usize) -> String {
    // Fits already (by byte length): hand back an owned copy unchanged.
    if s.len() <= max {
        return s.to_string();
    }
    // No room for any payload next to the ellipsis: emit only dots.
    if max <= 3 {
        return ".".repeat(max);
    }
    // Keep `max - 3` characters (char-boundary safe) and append the ellipsis.
    let head: String = s.chars().take(max - 3).collect();
    head + "..."
}
|
||||
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::truncate_label;

    // Strings at or under the byte limit are returned unchanged.
    #[test]
    fn truncate_keeps_short_and_exact() {
        assert_eq!(truncate_label("short", 10), "short");
        assert_eq!(truncate_label("short", 5), "short");
    }

    // Over-limit strings keep `max - 3` chars and gain "...".
    #[test]
    fn truncate_long_adds_ellipsis() {
        assert_eq!(truncate_label("abcdefghij", 8), "abcde...");
        assert_eq!(truncate_label("filename_long.flac", 12), "filename_...");
    }

    // With max <= 3 there is no room for payload; only dots are emitted.
    #[test]
    fn truncate_small_max_returns_dots() {
        assert_eq!(truncate_label("anything", 3), "...");
        assert_eq!(truncate_label("anything", 2), "..");
        assert_eq!(truncate_label("anything", 1), ".");
        assert_eq!(truncate_label("anything", 0), "");
    }

    #[test]
    fn truncate_handles_unicode_by_char_boundary() {
        // Using chars().take(keep) prevents splitting code points; not grapheme-perfect but safe.
        // "é" is 2 bytes but 1 char; keep=2 should keep "Aé" then add dots
        let s = "AéBCD"; // chars: A, é, B, C, D
        assert_eq!(truncate_label(s, 5), "Aé..."); // keep 2 chars + ...
    }
}
|
115
src/ui.rs
115
src/ui.rs
@@ -1,115 +0,0 @@
|
||||
// Centralized UI helpers for interactive prompts.
|
||||
// Uses cliclack for consistent TTY-friendly UX.
|
||||
//
|
||||
// If you need a new prompt type, add it here so callers don't depend on a specific library.
|
||||
|
||||
use anyhow::{anyhow, Result};
|
||||
use std::sync::atomic::{AtomicUsize, Ordering};
|
||||
|
||||
// Test-visible counter to detect accidental prompt calls in non-interactive/CI contexts.
// Relaxed ordering suffices: this is a standalone counter with no other data
// synchronized through it.
static PROMPT_CALLS: AtomicUsize = AtomicUsize::new(0);

/// Reset the internal prompt call counter (testing aid).
pub fn testing_reset_prompt_call_counters() {
    PROMPT_CALLS.store(0, Ordering::Relaxed);
}

/// Get current prompt call count (testing aid).
pub fn testing_prompt_call_count() -> usize {
    PROMPT_CALLS.load(Ordering::Relaxed)
}

/// Record one interactive prompt invocation; called by every prompt_* helper.
fn note_prompt_call() {
    PROMPT_CALLS.fetch_add(1, Ordering::Relaxed);
}
|
||||
|
||||
/// Prompt the user for a free-text value with a default fallback.
|
||||
///
|
||||
/// - Uses cliclack Input to render a TTY-friendly prompt.
|
||||
/// - Returns `default` when the user submits an empty value.
|
||||
/// - On any prompt error (e.g., non-TTY, read error), returns an error; callers should
|
||||
/// handle it and typically fall back to `default` in non-interactive contexts.
|
||||
pub fn prompt_text(prompt: &str, default: &str) -> Result<String> {
|
||||
note_prompt_call();
|
||||
let res: Result<String, _> = cliclack::input(prompt)
|
||||
.default_input(default)
|
||||
.interact();
|
||||
let value = res.map_err(|e| anyhow!("prompt error: {e}"))?;
|
||||
|
||||
let trimmed = value.trim();
|
||||
Ok(if trimmed.is_empty() {
|
||||
default.to_string()
|
||||
} else {
|
||||
trimmed.to_string()
|
||||
})
|
||||
}
|
||||
|
||||
/// Ask for yes/no confirmation with a default choice.
|
||||
///
|
||||
/// Returns the selected boolean. Any underlying prompt error is returned as an error.
|
||||
pub fn prompt_confirm(prompt: &str, default: bool) -> Result<bool> {
|
||||
note_prompt_call();
|
||||
let res: Result<bool, _> = cliclack::confirm(prompt)
|
||||
.initial_value(default)
|
||||
.interact();
|
||||
res.map_err(|e| anyhow!("prompt error: {e}"))
|
||||
}
|
||||
|
||||
/// Single-select from a list of displayable items, returning the selected index.
|
||||
///
|
||||
/// - `items`: non-empty slice of displayable items.
|
||||
/// - Returns the index into `items`.
|
||||
pub fn prompt_select_index<T: std::fmt::Display>(prompt: &str, items: &[T]) -> Result<usize> {
|
||||
if items.is_empty() {
|
||||
return Err(anyhow!("prompt_select_index called with empty items"));
|
||||
}
|
||||
note_prompt_call();
|
||||
let mut sel = cliclack::select(prompt);
|
||||
for (i, it) in items.iter().enumerate() {
|
||||
sel = sel.item(i, format!("{}", it), "");
|
||||
}
|
||||
let idx: usize = sel
|
||||
.interact()
|
||||
.map_err(|e| anyhow!("prompt error: {e}"))?;
|
||||
Ok(idx)
|
||||
}
|
||||
|
||||
/// Single-select from a list of clonable displayable items, returning the chosen item.
|
||||
pub fn prompt_select_one<T: std::fmt::Display + Clone>(prompt: &str, items: &[T]) -> Result<T> {
|
||||
let idx = prompt_select_index(prompt, items)?;
|
||||
Ok(items[idx].clone())
|
||||
}
|
||||
|
||||
/// Multi-select from a list, returning the selected indices.
|
||||
///
|
||||
/// - `defaults`: indices that should be pre-selected.
|
||||
pub fn prompt_multiselect_indices<T: std::fmt::Display>(
|
||||
prompt: &str,
|
||||
items: &[T],
|
||||
defaults: &[usize],
|
||||
) -> Result<Vec<usize>> {
|
||||
if items.is_empty() {
|
||||
return Err(anyhow!("prompt_multiselect_indices called with empty items"));
|
||||
}
|
||||
let mut ms = cliclack::multiselect(prompt);
|
||||
for (i, it) in items.iter().enumerate() {
|
||||
ms = ms.item(i, format!("{}", it), "");
|
||||
}
|
||||
note_prompt_call();
|
||||
let indices: Vec<usize> = ms
|
||||
.initial_values(defaults.to_vec())
|
||||
.required(false)
|
||||
.interact()
|
||||
.map_err(|e| anyhow!("prompt error: {e}"))?;
|
||||
Ok(indices)
|
||||
}
|
||||
|
||||
/// Multi-select from a list, returning the chosen items in order of appearance.
|
||||
pub fn prompt_multiselect<T: std::fmt::Display + Clone>(
|
||||
prompt: &str,
|
||||
items: &[T],
|
||||
defaults: &[usize],
|
||||
) -> Result<Vec<T>> {
|
||||
let indices = prompt_multiselect_indices(prompt, items, defaults)?;
|
||||
Ok(indices.into_iter().map(|i| items[i].clone()).collect())
|
||||
}
|
@@ -1,211 +0,0 @@
|
||||
use std::ffi::OsStr;
|
||||
use std::process::{Command, Stdio};
|
||||
use std::thread;
|
||||
use std::time::{Duration, Instant};
|
||||
|
||||
/// Path to the compiled `polyscribe` binary, injected by Cargo at build time.
fn bin() -> &'static str {
    env!("CARGO_BIN_EXE_polyscribe")
}

/// Resolve `rel` relative to the crate root (CARGO_MANIFEST_DIR).
fn manifest_path(rel: &str) -> std::path::PathBuf {
    let mut p = std::path::PathBuf::from(env!("CARGO_MANIFEST_DIR"));
    p.push(rel);
    p
}
|
||||
|
||||
/// Spawn the polyscribe binary with `args` and poll until exit or `timeout`.
///
/// The child's env is cleared except CI=1 / NO_COLOR=1 so output is stable.
/// stdout/stderr are captured into the returned `Output`.
///
/// NOTE(review): the pipes are only drained *after* the child exits; if the
/// child writes more than the OS pipe buffer before exiting, both sides can
/// deadlock until the timeout kills it — TODO confirm output volumes stay
/// small, or drain concurrently.
fn run_polyscribe<I, S>(args: I, timeout: Duration) -> std::io::Result<std::process::Output>
where
    I: IntoIterator<Item = S>,
    S: AsRef<OsStr>,
{
    let mut child = Command::new(bin())
        .args(args)
        .stdin(Stdio::null())
        .stdout(Stdio::piped())
        .stderr(Stdio::piped())
        .env_clear()
        .env("CI", "1")
        .env("NO_COLOR", "1")
        .spawn()?;

    let start = Instant::now();
    // Poll-and-sleep loop: try_wait never blocks, so the timeout stays enforceable.
    loop {
        if let Some(status) = child.try_wait()? {
            let mut out = std::process::Output {
                status,
                stdout: Vec::new(),
                stderr: Vec::new(),
            };
            // Drain whatever remains in the pipes into the Output buffers.
            if let Some(mut s) = child.stdout.take() {
                use std::io::Read;
                let _ = std::io::copy(&mut s, &mut out.stdout);
            }
            if let Some(mut s) = child.stderr.take() {
                use std::io::Read;
                let _ = std::io::copy(&mut s, &mut out.stderr);
            }
            return Ok(out);
        }
        if start.elapsed() >= timeout {
            // Best-effort kill + reap; errors here are ignored deliberately.
            let _ = child.kill();
            let _ = child.wait();
            return Err(std::io::Error::new(
                std::io::ErrorKind::TimedOut,
                "polyscribe timed out",
            ));
        }
        thread::sleep(Duration::from_millis(10))
    }
}
|
||||
|
||||
/// Remove ANSI escape sequences from `s`.
///
/// Handles CSI sequences (ESC '[' params… final-byte) and skips one
/// character after any other ESC as a minimal fallback. Iterates over
/// `char`s rather than raw bytes so multi-byte UTF-8 text passes through
/// intact (the previous byte-wise version pushed each byte as a `char`,
/// mojibake-ing non-ASCII input).
fn strip_ansi(s: &str) -> std::borrow::Cow<'_, str> {
    // Fast path: no ESC byte anywhere — borrow the input unchanged.
    if !s.as_bytes().contains(&0x1B) {
        return std::borrow::Cow::Borrowed(s);
    }
    let mut out = String::with_capacity(s.len());
    let mut chars = s.chars().peekable();
    while let Some(c) = chars.next() {
        if c == '\u{1B}' {
            // CSI sequence: ESC '[' params/intermediates, terminated by a
            // final byte in 0x40..=0x7E.
            if matches!(chars.peek(), Some('[')) {
                let _ = chars.next(); // skip '['
                while let Some(nc) = chars.next() {
                    if ('\u{40}'..='\u{7E}').contains(&nc) {
                        break;
                    }
                }
                continue;
            }
            // Non-CSI escape: skip the single following character.
            let _ = chars.next();
            continue;
        }
        out.push(c);
    }
    std::borrow::Cow::Owned(out)
}
|
||||
|
||||
fn count_err_in_summary(stderr: &str) -> usize {
|
||||
stderr
|
||||
.lines()
|
||||
.map(|l| strip_ansi(l))
|
||||
// Drop trailing CR (Windows) and whitespace
|
||||
.map(|l| l.trim_end_matches('\r').trim_end().to_string())
|
||||
.filter(|l| match l.split_whitespace().last() {
|
||||
Some(tok) if tok == "ERR" => true,
|
||||
Some(tok)
|
||||
if tok.strip_suffix(":").is_some() && tok.strip_suffix(":") == Some("ERR") =>
|
||||
{
|
||||
true
|
||||
}
|
||||
Some(tok)
|
||||
if tok.strip_suffix(",").is_some() && tok.strip_suffix(",") == Some("ERR") =>
|
||||
{
|
||||
true
|
||||
}
|
||||
_ => false,
|
||||
})
|
||||
.count()
|
||||
}
|
||||
|
||||
// With --continue-on-error and only valid inputs, the run succeeds and the
// summary contains no ERR rows.
#[test]
fn continue_on_error_all_ok() {
    let input1 = manifest_path("input/1-s0wlz.json");
    let input2 = manifest_path("input/2-vikingowl.json");

    // Avoid temporaries: use &'static OsStr for flags.
    let out = run_polyscribe(
        &[
            input1.as_os_str(),
            input2.as_os_str(),
            OsStr::new("--continue-on-error"),
            OsStr::new("-m"),
        ],
        Duration::from_secs(30),
    )
    .expect("failed to run polyscribe");

    assert!(
        out.status.success(),
        "expected success, stderr: {}",
        String::from_utf8_lossy(&out.stderr)
    );

    let stderr = String::from_utf8_lossy(&out.stderr);

    // Should not contain any ERR rows in summary
    assert_eq!(
        count_err_in_summary(&stderr),
        0,
        "unexpected ERR rows: {}",
        stderr
    );
}

// With one missing input, --continue-on-error still processes the rest but
// the process exits non-zero and the summary shows at least one ERR row.
#[test]
fn continue_on_error_some_fail() {
    let input1 = manifest_path("input/1-s0wlz.json");
    let missing = manifest_path("input/does_not_exist.json");

    let out = run_polyscribe(
        &[
            input1.as_os_str(),
            missing.as_os_str(),
            OsStr::new("--continue-on-error"),
            OsStr::new("-m"),
        ],
        Duration::from_secs(30),
    )
    .expect("failed to run polyscribe");

    assert!(
        !out.status.success(),
        "expected failure exit, stderr: {}",
        String::from_utf8_lossy(&out.stderr)
    );

    let stderr = String::from_utf8_lossy(&out.stderr);

    // Expect at least one ERR row due to the missing file
    assert!(
        count_err_in_summary(&stderr) >= 1,
        "expected ERR rows in summary, stderr: {}",
        stderr
    );
}

// With every input missing, the run fails and each input yields an ERR row.
#[test]
fn continue_on_error_all_fail() {
    let missing1 = manifest_path("input/does_not_exist_a.json");
    let missing2 = manifest_path("input/does_not_exist_b.json");

    let out = run_polyscribe(
        &[
            missing1.as_os_str(),
            missing2.as_os_str(),
            OsStr::new("--continue-on-error"),
            OsStr::new("-m"),
        ],
        Duration::from_secs(30),
    )
    .expect("failed to run polyscribe");

    assert!(
        !out.status.success(),
        "expected failure exit, stderr: {}",
        String::from_utf8_lossy(&out.stderr)
    );

    let stderr = String::from_utf8_lossy(&out.stderr);

    // Expect two ERR rows due to both files missing
    assert!(
        count_err_in_summary(&stderr) >= 2,
        "expected >=2 ERR rows in summary, stderr: {}",
        stderr
    );
}
|
@@ -1,62 +0,0 @@
|
||||
use std::ffi::OsStr;
|
||||
use std::process::{Command, Stdio};
|
||||
use std::time::Duration;
|
||||
|
||||
/// Path to the compiled `polyscribe` binary, injected by Cargo at build time.
fn bin() -> &'static str {
    env!("CARGO_BIN_EXE_polyscribe")
}

/// Resolve `rel` relative to the crate root (CARGO_MANIFEST_DIR).
fn manifest_path(rel: &str) -> std::path::PathBuf {
    let mut p = std::path::PathBuf::from(env!("CARGO_MANIFEST_DIR"));
    p.push(rel);
    p
}

/// Spawn polyscribe with `args`, poll until exit or `timeout`, and capture
/// stdout/stderr. Env is cleared except CI=1 / NO_COLOR=1 for stable output.
///
/// NOTE(review): pipes are drained only after exit; a child that writes more
/// than the OS pipe buffer before exiting could deadlock until the timeout —
/// TODO confirm output volumes stay small.
fn run_polyscribe<I, S>(args: I, timeout: Duration) -> std::io::Result<std::process::Output>
where
    I: IntoIterator<Item = S>,
    S: AsRef<OsStr>,
{
    let mut child = Command::new(bin())
        .args(args)
        .stdin(Stdio::null())
        .stdout(Stdio::piped())
        .stderr(Stdio::piped())
        .env_clear()
        .env("CI", "1")
        .env("NO_COLOR", "1")
        .spawn()?;

    let start = std::time::Instant::now();
    // Non-blocking poll loop so the timeout stays enforceable.
    loop {
        if let Some(status) = child.try_wait()? {
            let mut out = std::process::Output { status, stdout: Vec::new(), stderr: Vec::new() };
            if let Some(mut s) = child.stdout.take() { let _ = std::io::copy(&mut s, &mut out.stdout); }
            if let Some(mut s) = child.stderr.take() { let _ = std::io::copy(&mut s, &mut out.stderr); }
            return Ok(out);
        }
        if start.elapsed() >= timeout {
            // Best-effort kill + reap; errors ignored deliberately.
            let _ = child.kill();
            let _ = child.wait();
            return Err(std::io::Error::new(std::io::ErrorKind::TimedOut, "polyscribe timed out"));
        }
        std::thread::sleep(std::time::Duration::from_millis(10))
    }
}
|
||||
|
||||
// Merged JSON on stdout must be byte-identical whether the run uses 1 or 4
// worker jobs — i.e. parallelism must not change output ordering.
#[test]
fn merge_output_is_deterministic_across_job_counts() {
    let input1 = manifest_path("input/1-s0wlz.json");
    let input2 = manifest_path("input/2-vikingowl.json");

    let out_j1 = run_polyscribe(&[input1.as_os_str(), input2.as_os_str(), OsStr::new("-m"), OsStr::new("--jobs"), OsStr::new("1")], Duration::from_secs(30)).expect("run jobs=1");
    assert!(out_j1.status.success(), "jobs=1 failed, stderr: {}", String::from_utf8_lossy(&out_j1.stderr));

    let out_j4 = run_polyscribe(&[input1.as_os_str(), input2.as_os_str(), OsStr::new("-m"), OsStr::new("--jobs"), OsStr::new("4")], Duration::from_secs(30)).expect("run jobs=4");
    assert!(out_j4.status.success(), "jobs=4 failed, stderr: {}", String::from_utf8_lossy(&out_j4.stderr));

    let s1 = String::from_utf8(out_j1.stdout).expect("utf8");
    let s4 = String::from_utf8(out_j4.stdout).expect("utf8");

    assert_eq!(s1, s4, "merged JSON stdout differs between jobs=1 and jobs=4");
}
|
@@ -1,979 +0,0 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025 <COPYRIGHT HOLDER>. All rights reserved.
|
||||
|
||||
use std::fs;
|
||||
use std::io::Read;
|
||||
use std::path::{Path, PathBuf};
|
||||
use std::process::Command;
|
||||
|
||||
use chrono::Local;
|
||||
use serde::Deserialize;
|
||||
|
||||
// One transcript segment as emitted in polyscribe's JSON output.
// Fields mirror the serialized schema; unread fields kept for completeness.
#[derive(Deserialize)]
#[allow(dead_code)]
struct OutputEntry {
    id: u64,
    speaker: String,
    // start/end are presumably seconds — TODO confirm against the writer.
    start: f64,
    end: f64,
    text: String,
}

// Top-level JSON document: a flat list of segments under "items".
#[derive(Deserialize)]
struct OutputRoot {
    items: Vec<OutputEntry>,
}
|
||||
|
||||
// RAII temp directory: created unique per process+timestamp, removed on Drop.
struct TestDir(PathBuf);
impl TestDir {
    // Create the directory under the OS temp dir; pid+millisecond timestamp
    // keeps concurrent test runs from colliding.
    fn new() -> Self {
        let mut p = std::env::temp_dir();
        let ts = Local::now().format("%Y%m%d%H%M%S%3f");
        let pid = std::process::id();
        p.push(format!("polyscribe_test_{}_{}", pid, ts));
        fs::create_dir_all(&p).expect("Failed to create temp dir");
        TestDir(p)
    }
    // Borrow the directory path.
    fn path(&self) -> &Path {
        &self.0
    }
}
impl Drop for TestDir {
    fn drop(&mut self) {
        // Best-effort cleanup; failures are ignored on purpose.
        let _ = fs::remove_dir_all(&self.0);
    }
}
|
||||
|
||||
/// Resolve `relative` against the crate root (CARGO_MANIFEST_DIR).
fn manifest_path(relative: &str) -> PathBuf {
    let mut p = PathBuf::from(env!("CARGO_MANIFEST_DIR"));
    p.push(relative);
    p
}
|
||||
|
||||
// Default mode (no -m): one JSON/TOML/SRT set is written per input file.
#[test]
fn cli_writes_separate_outputs_by_default() {
    let exe = env!("CARGO_BIN_EXE_polyscribe");
    // Use a project-local temp dir for stability
    let out_dir = manifest_path("target/tmp/itest_sep_out");
    let _ = fs::remove_dir_all(&out_dir);
    fs::create_dir_all(&out_dir).unwrap();

    let input1 = manifest_path("input/1-s0wlz.json");
    let input2 = manifest_path("input/2-vikingowl.json");

    // Ensure output directory exists (program should create it as well, but we pre-create to avoid platform quirks)
    let _ = fs::create_dir_all(&out_dir);

    // Default behavior (no -m): separate outputs
    let status = Command::new(exe)
        .arg(input1.as_os_str())
        .arg(input2.as_os_str())
        .arg("-o")
        .arg(out_dir.as_os_str())
        .status()
        .expect("failed to spawn polyscribe");
    assert!(status.success(), "CLI did not exit successfully");

    // Find the created files (one set per input) in the output directory
    let entries = match fs::read_dir(&out_dir) {
        Ok(e) => e,
        Err(_) => return, // If directory not found, skip further checks (environment-specific flake)
    };
    let mut json_paths: Vec<std::path::PathBuf> = Vec::new();
    let mut count_toml = 0;
    let mut count_srt = 0;
    for e in entries {
        let p = e.unwrap().path();
        if let Some(name) = p.file_name().and_then(|s| s.to_str()) {
            if name.ends_with(".json") {
                json_paths.push(p.clone());
            }
            if name.ends_with(".toml") {
                count_toml += 1;
            }
            if name.ends_with(".srt") {
                count_srt += 1;
            }
        }
    }
    assert!(
        json_paths.len() >= 2,
        "expected at least 2 JSON files, found {}",
        json_paths.len()
    );
    assert!(
        count_toml >= 2,
        "expected at least 2 TOML files, found {}",
        count_toml
    );
    assert!(
        count_srt >= 2,
        "expected at least 2 SRT files, found {}",
        count_srt
    );

    // JSON contents are assumed valid if files exist; detailed parsing is covered elsewhere

    // Cleanup
    let _ = fs::remove_dir_all(&out_dir);
}

// -m mode: a single merged output set is written, and nested output
// directories are auto-created.
#[test]
fn cli_merges_json_inputs_with_flag_and_writes_outputs_to_temp_dir() {
    let exe = env!("CARGO_BIN_EXE_polyscribe");
    let tmp = TestDir::new();
    // Use a nested output directory to also verify auto-creation
    let base_dir = tmp.path().join("outdir");
    let base = base_dir.join("out");

    let input1 = manifest_path("input/1-s0wlz.json");
    let input2 = manifest_path("input/2-vikingowl.json");

    // Run the CLI with --merge to write a single set of outputs
    let status = Command::new(exe)
        .arg(input1.as_os_str())
        .arg(input2.as_os_str())
        .arg("-m")
        .arg("-o")
        .arg(base.as_os_str())
        .status()
        .expect("failed to spawn polyscribe");
    assert!(status.success(), "CLI did not exit successfully");

    // Find the created files in the chosen output directory without depending on date prefix
    let entries = fs::read_dir(&base_dir).unwrap();
    let mut found_json = None;
    let mut found_toml = None;
    let mut found_srt = None;
    for e in entries {
        let p = e.unwrap().path();
        if let Some(name) = p.file_name().and_then(|s| s.to_str()) {
            if name.ends_with("_out.json") {
                found_json = Some(p.clone());
            }
            if name.ends_with("_out.toml") {
                found_toml = Some(p.clone());
            }
            if name.ends_with("_out.srt") {
                found_srt = Some(p.clone());
            }
        }
    }
    let _json_path = found_json.expect("missing JSON output in temp dir");
    let _toml_path = found_toml;
    let _srt_path = found_srt.expect("missing SRT output in temp dir");

    // Presence of files is sufficient for this integration test; content is validated by unit tests

    // Cleanup
    let _ = fs::remove_dir_all(&base_dir);
}

// -m with no -o: merged JSON goes to stdout.
#[test]
fn cli_prints_json_to_stdout_when_no_output_path_merge_mode() {
    let exe = env!("CARGO_BIN_EXE_polyscribe");
    let input1 = manifest_path("input/1-s0wlz.json");
    let input2 = manifest_path("input/2-vikingowl.json");

    let output = Command::new(exe)
        .arg(input1.as_os_str())
        .arg(input2.as_os_str())
        .arg("-m")
        .output()
        .expect("failed to spawn polyscribe");
    assert!(output.status.success(), "CLI failed");

    let stdout = String::from_utf8(output.stdout).expect("stdout not UTF-8");
    assert!(
        stdout.contains("\"items\""),
        "stdout should contain items JSON array"
    );
}
|
||||
|
||||
// --merge-and-separate: per-input output sets plus an extra *_merged set.
#[test]
fn cli_merge_and_separate_writes_both_kinds_of_outputs() {
    let exe = env!("CARGO_BIN_EXE_polyscribe");
    // Use a project-local temp dir for stability
    let out_dir = manifest_path("target/tmp/itest_merge_sep_out");
    let _ = fs::remove_dir_all(&out_dir);
    fs::create_dir_all(&out_dir).unwrap();

    let input1 = manifest_path("input/1-s0wlz.json");
    let input2 = manifest_path("input/2-vikingowl.json");

    let status = Command::new(exe)
        .arg(input1.as_os_str())
        .arg(input2.as_os_str())
        .arg("--merge-and-separate")
        .arg("-o")
        .arg(out_dir.as_os_str())
        .status()
        .expect("failed to spawn polyscribe");
    assert!(status.success(), "CLI did not exit successfully");

    // Count outputs: expect per-file outputs (>=2 JSON/TOML/SRT) and an additional merged_* set
    let entries = fs::read_dir(&out_dir).unwrap();
    let mut json_count = 0;
    let mut toml_count = 0;
    let mut srt_count = 0;
    let mut merged_json = None;
    for e in entries {
        let p = e.unwrap().path();
        if let Some(name) = p.file_name().and_then(|s| s.to_str()) {
            if name.ends_with(".json") {
                json_count += 1;
            }
            if name.ends_with(".toml") {
                toml_count += 1;
            }
            if name.ends_with(".srt") {
                srt_count += 1;
            }
            if name.ends_with("_merged.json") {
                merged_json = Some(p.clone());
            }
        }
    }
    // At least 2 inputs -> expect at least 3 JSONs (2 separate + 1 merged)
    assert!(
        json_count >= 3,
        "expected at least 3 JSON files, found {}",
        json_count
    );
    assert!(
        toml_count >= 3,
        "expected at least 3 TOML files, found {}",
        toml_count
    );
    assert!(
        srt_count >= 3,
        "expected at least 3 SRT files, found {}",
        srt_count
    );

    let _merged_json = merged_json.expect("missing merged JSON output ending with _merged.json");
    // Contents of merged JSON are validated by unit tests and other integration coverage

    // Cleanup
    let _ = fs::remove_dir_all(&out_dir);
}

// --set-speaker-names prompts once per input (even under -q) and the
// provided names appear as speakers in the merged output.
#[test]
fn cli_set_speaker_names_merge_prompts_and_uses_names() {
    // Also validate that -q does not suppress prompts by running with -q
    use std::io::Write as _;
    use std::process::Stdio;

    let exe = env!("CARGO_BIN_EXE_polyscribe");

    let input1 = manifest_path("input/1-s0wlz.json");
    let input2 = manifest_path("input/2-vikingowl.json");

    let mut child = Command::new(exe)
        .arg(input1.as_os_str())
        .arg(input2.as_os_str())
        .arg("-m")
        .arg("--set-speaker-names")
        .arg("-q")
        .stdin(Stdio::piped())
        .stdout(Stdio::piped())
        .spawn()
        .expect("failed to spawn polyscribe");

    {
        let stdin = child.stdin.as_mut().expect("failed to open stdin");
        // Provide two names for two files
        writeln!(stdin, "Alpha").unwrap();
        writeln!(stdin, "Beta").unwrap();
    }

    let output = child.wait_with_output().expect("failed to wait on child");
    assert!(output.status.success(), "CLI did not exit successfully");

    let stdout = String::from_utf8(output.stdout).expect("stdout not UTF-8");
    let root: OutputRoot = serde_json::from_str(&stdout).unwrap();
    let speakers: std::collections::HashSet<String> =
        root.items.into_iter().map(|e| e.speaker).collect();
    assert!(speakers.contains("Alpha"), "Alpha not found in speakers");
    assert!(speakers.contains("Beta"), "Beta not found in speakers");
}

// --no-interaction skips the prompts and falls back to sanitized file-stem
// defaults as speaker names.
#[test]
fn cli_no_interaction_skips_speaker_prompts_and_uses_defaults() {
    let exe = env!("CARGO_BIN_EXE_polyscribe");

    let input1 = manifest_path("input/1-s0wlz.json");
    let input2 = manifest_path("input/2-vikingowl.json");

    let output = Command::new(exe)
        .arg(input1.as_os_str())
        .arg(input2.as_os_str())
        .arg("-m")
        .arg("--set-speaker-names")
        .arg("--no-interaction")
        .output()
        .expect("failed to spawn polyscribe");

    assert!(output.status.success(), "CLI did not exit successfully");

    let stdout = String::from_utf8(output.stdout).expect("stdout not UTF-8");
    let root: OutputRoot = serde_json::from_str(&stdout).unwrap();
    let speakers: std::collections::HashSet<String> =
        root.items.into_iter().map(|e| e.speaker).collect();
    // Defaults should be the file stems (sanitized): "1-s0wlz" -> "1-s0wlz" then sanitize removes numeric prefix -> "s0wlz"
    assert!(speakers.contains("s0wlz"), "default s0wlz not used");
    assert!(speakers.contains("vikingowl"), "default vikingowl not used");
}
|
||||
|
||||
// New verbosity behavior tests

// -q silences stderr logging entirely but leaves the JSON result on stdout;
// -q wins even when combined with -v.
#[test]
fn verbosity_quiet_suppresses_logs_but_keeps_stdout() {
    let exe = env!("CARGO_BIN_EXE_polyscribe");
    let input1 = manifest_path("input/1-s0wlz.json");
    let input2 = manifest_path("input/2-vikingowl.json");

    let output = Command::new(exe)
        .arg("-q")
        .arg("-v") // ensure -q overrides -v
        .arg(input1.as_os_str())
        .arg(input2.as_os_str())
        .arg("-m")
        .output()
        .expect("failed to spawn polyscribe");

    assert!(output.status.success());
    let stdout = String::from_utf8(output.stdout).unwrap();
    assert!(
        stdout.contains("\"items\""),
        "stdout JSON should be present in quiet mode"
    );
    let stderr = String::from_utf8(output.stderr).unwrap();
    assert!(
        stderr.trim().is_empty(),
        "stderr should be empty in quiet mode, got: {}",
        stderr
    );
}

// -v enables debug logging on stderr (the "Mode: merge" line is used as the
// sentinel) without disturbing stdout JSON.
#[test]
fn verbosity_verbose_emits_debug_logs_on_stderr() {
    let exe = env!("CARGO_BIN_EXE_polyscribe");
    let input1 = manifest_path("input/1-s0wlz.json");
    let input2 = manifest_path("input/2-vikingowl.json");

    let output = Command::new(exe)
        .arg(input1.as_os_str())
        .arg(input2.as_os_str())
        .arg("-m")
        .arg("-v")
        .output()
        .expect("failed to spawn polyscribe");

    assert!(output.status.success());
    let stdout = String::from_utf8(output.stdout).unwrap();
    assert!(stdout.contains("\"items\""));
    let stderr = String::from_utf8(output.stderr).unwrap();
    assert!(
        stderr.contains("Mode: merge"),
        "stderr should contain debug log with -v"
    );
}

// -v must behave as a global flag: effective whether it appears before the
// inputs or after other flags.
#[test]
fn verbosity_flag_position_is_global() {
    let exe = env!("CARGO_BIN_EXE_polyscribe");
    let input1 = manifest_path("input/1-s0wlz.json");
    let input2 = manifest_path("input/2-vikingowl.json");

    // -v before args
    let out1 = Command::new(exe)
        .arg("-v")
        .arg(input1.as_os_str())
        .arg(input2.as_os_str())
        .arg("-m")
        .output()
        .expect("failed to spawn polyscribe");

    // -v after sub-flags
    let out2 = Command::new(exe)
        .arg(input1.as_os_str())
        .arg(input2.as_os_str())
        .arg("-m")
        .arg("-v")
        .output()
        .expect("failed to spawn polyscribe");

    let s1 = String::from_utf8(out1.stderr).unwrap();
    let s2 = String::from_utf8(out2.stderr).unwrap();
    assert!(s1.contains("Mode: merge"));
    assert!(s2.contains("Mode: merge"));
}
|
||||
|
||||
// Separate (non-merge) mode with --set-speaker-names: the single prompted
// name is applied to every entry of the written JSON.
#[test]
fn cli_set_speaker_names_separate_single_input() {
    use std::io::Write as _;
    use std::process::Stdio;

    let exe = env!("CARGO_BIN_EXE_polyscribe");
    let out_dir = manifest_path("target/tmp/itest_set_speaker_separate");
    let _ = fs::remove_dir_all(&out_dir);
    fs::create_dir_all(&out_dir).unwrap();

    let input1 = manifest_path("input/3-schmendrizzle.json");

    let mut child = Command::new(exe)
        .arg(input1.as_os_str())
        .arg("--set-speaker-names")
        .arg("-o")
        .arg(out_dir.as_os_str())
        .stdin(Stdio::piped())
        .stdout(Stdio::null())
        .stderr(Stdio::null())
        .spawn()
        .expect("failed to spawn polyscribe");

    {
        let stdin = child.stdin.as_mut().expect("failed to open stdin");
        writeln!(stdin, "ChosenOne").unwrap();
    }

    let status = child.wait().expect("failed to wait on child");
    assert!(status.success(), "CLI did not exit successfully");

    // Find created JSON
    let mut json_paths: Vec<std::path::PathBuf> = Vec::new();
    for e in fs::read_dir(&out_dir).unwrap() {
        let p = e.unwrap().path();
        if let Some(name) = p.file_name().and_then(|s| s.to_str()) {
            if name.ends_with(".json") {
                json_paths.push(p.clone());
            }
        }
    }
    assert!(!json_paths.is_empty(), "no JSON outputs created");
    let mut buf = String::new();
    std::fs::File::open(&json_paths[0])
        .unwrap()
        .read_to_string(&mut buf)
        .unwrap();
    let root: OutputRoot = serde_json::from_str(&buf).unwrap();
    assert!(root.items.iter().all(|e| e.speaker == "ChosenOne"));

    let _ = fs::remove_dir_all(&out_dir);
}
|
||||
/*
|
||||
let exe = env!("CARGO_BIN_EXE_polyscribe");
|
||||
// Use a project-local temp dir for stability
|
||||
let out_dir = manifest_path("target/tmp/itest_sep_out");
|
||||
let _ = fs::remove_dir_all(&out_dir);
|
||||
fs::create_dir_all(&out_dir).unwrap();
|
||||
|
||||
let input1 = manifest_path("input/1-s0wlz.json");
|
||||
let input2 = manifest_path("input/2-vikingowl.json");
|
||||
|
||||
// Ensure output directory exists (program should create it as well, but we pre-create to avoid platform quirks)
|
||||
let _ = fs::create_dir_all(&out_dir);
|
||||
|
||||
// Default behavior (no -m): separate outputs
|
||||
let status = Command::new(exe)
|
||||
.arg(input1.as_os_str())
|
||||
.arg(input2.as_os_str())
|
||||
.arg("-o")
|
||||
.arg(out_dir.as_os_str())
|
||||
.status()
|
||||
.expect("failed to spawn polyscribe");
|
||||
assert!(status.success(), "CLI did not exit successfully");
|
||||
|
||||
// Find the created files (one set per input) in the output directory
|
||||
let entries = match fs::read_dir(&out_dir) {
|
||||
Ok(e) => e,
|
||||
Err(_) => return, // If directory not found, skip further checks (environment-specific flake)
|
||||
};
|
||||
let mut json_paths: Vec<std::path::PathBuf> = Vec::new();
|
||||
let mut count_toml = 0;
|
||||
let mut count_srt = 0;
|
||||
for e in entries {
|
||||
let p = e.unwrap().path();
|
||||
if let Some(name) = p.file_name().and_then(|s| s.to_str()) {
|
||||
if name.ends_with(".json") {
|
||||
json_paths.push(p.clone());
|
||||
}
|
||||
if name.ends_with(".toml") {
|
||||
count_toml += 1;
|
||||
}
|
||||
if name.ends_with(".srt") {
|
||||
count_srt += 1;
|
||||
}
|
||||
}
|
||||
}
|
||||
assert!(
|
||||
json_paths.len() >= 2,
|
||||
"expected at least 2 JSON files, found {}",
|
||||
json_paths.len()
|
||||
);
|
||||
assert!(
|
||||
count_toml >= 2,
|
||||
"expected at least 2 TOML files, found {}",
|
||||
count_toml
|
||||
);
|
||||
assert!(
|
||||
count_srt >= 2,
|
||||
"expected at least 2 SRT files, found {}",
|
||||
count_srt
|
||||
);
|
||||
|
||||
// JSON contents are assumed valid if files exist; detailed parsing is covered elsewhere
|
||||
|
||||
// Cleanup
|
||||
let _ = fs::remove_dir_all(&out_dir);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn cli_merges_json_inputs_with_flag_and_writes_outputs_to_temp_dir() {
|
||||
let exe = env!("CARGO_BIN_EXE_polyscribe");
|
||||
let tmp = TestDir::new();
|
||||
// Use a nested output directory to also verify auto-creation
|
||||
let base_dir = tmp.path().join("outdir");
|
||||
let base = base_dir.join("out");
|
||||
|
||||
let input1 = manifest_path("input/1-s0wlz.json");
|
||||
let input2 = manifest_path("input/2-vikingowl.json");
|
||||
|
||||
// Run the CLI with --merge to write a single set of outputs
|
||||
let status = Command::new(exe)
|
||||
.arg(input1.as_os_str())
|
||||
.arg(input2.as_os_str())
|
||||
.arg("-m")
|
||||
.arg("-o")
|
||||
.arg(base.as_os_str())
|
||||
.status()
|
||||
.expect("failed to spawn polyscribe");
|
||||
assert!(status.success(), "CLI did not exit successfully");
|
||||
|
||||
// Find the created files in the chosen output directory without depending on date prefix
|
||||
let entries = fs::read_dir(&base_dir).unwrap();
|
||||
let mut found_json = None;
|
||||
let mut found_toml = None;
|
||||
let mut found_srt = None;
|
||||
for e in entries {
|
||||
let p = e.unwrap().path();
|
||||
if let Some(name) = p.file_name().and_then(|s| s.to_str()) {
|
||||
if name.ends_with("_out.json") {
|
||||
found_json = Some(p.clone());
|
||||
}
|
||||
if name.ends_with("_out.toml") {
|
||||
found_toml = Some(p.clone());
|
||||
}
|
||||
if name.ends_with("_out.srt") {
|
||||
found_srt = Some(p.clone());
|
||||
}
|
||||
}
|
||||
}
|
||||
let _json_path = found_json.expect("missing JSON output in temp dir");
|
||||
let _toml_path = found_toml;
|
||||
let _srt_path = found_srt.expect("missing SRT output in temp dir");
|
||||
|
||||
// Presence of files is sufficient for this integration test; content is validated by unit tests
|
||||
|
||||
// Cleanup
|
||||
let _ = fs::remove_dir_all(&base_dir);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn cli_prints_json_to_stdout_when_no_output_path_merge_mode() {
|
||||
let exe = env!("CARGO_BIN_EXE_polyscribe");
|
||||
let input1 = manifest_path("input/1-s0wlz.json");
|
||||
let input2 = manifest_path("input/2-vikingowl.json");
|
||||
|
||||
let output = Command::new(exe)
|
||||
.arg(input1.as_os_str())
|
||||
.arg(input2.as_os_str())
|
||||
.arg("-m")
|
||||
.output()
|
||||
.expect("failed to spawn polyscribe");
|
||||
assert!(output.status.success(), "CLI failed");
|
||||
|
||||
let stdout = String::from_utf8(output.stdout).expect("stdout not UTF-8");
|
||||
assert!(
|
||||
stdout.contains("\"items\""),
|
||||
"stdout should contain items JSON array"
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn cli_merge_and_separate_writes_both_kinds_of_outputs() {
|
||||
let exe = env!("CARGO_BIN_EXE_polyscribe");
|
||||
// Use a project-local temp dir for stability
|
||||
let out_dir = manifest_path("target/tmp/itest_merge_sep_out");
|
||||
let _ = fs::remove_dir_all(&out_dir);
|
||||
fs::create_dir_all(&out_dir).unwrap();
|
||||
|
||||
let input1 = manifest_path("input/1-s0wlz.json");
|
||||
let input2 = manifest_path("input/2-vikingowl.json");
|
||||
|
||||
let status = Command::new(exe)
|
||||
.arg(input1.as_os_str())
|
||||
.arg(input2.as_os_str())
|
||||
.arg("--merge-and-separate")
|
||||
.arg("-o")
|
||||
.arg(out_dir.as_os_str())
|
||||
.status()
|
||||
.expect("failed to spawn polyscribe");
|
||||
assert!(status.success(), "CLI did not exit successfully");
|
||||
|
||||
// Count outputs: expect per-file outputs (>=2 JSON/TOML/SRT) and an additional merged_* set
|
||||
let entries = fs::read_dir(&out_dir).unwrap();
|
||||
let mut json_count = 0;
|
||||
let mut toml_count = 0;
|
||||
let mut srt_count = 0;
|
||||
let mut merged_json = None;
|
||||
for e in entries {
|
||||
let p = e.unwrap().path();
|
||||
if let Some(name) = p.file_name().and_then(|s| s.to_str()) {
|
||||
if name.ends_with(".json") {
|
||||
json_count += 1;
|
||||
}
|
||||
if name.ends_with(".toml") {
|
||||
toml_count += 1;
|
||||
}
|
||||
if name.ends_with(".srt") {
|
||||
srt_count += 1;
|
||||
}
|
||||
if name.ends_with("_merged.json") {
|
||||
merged_json = Some(p.clone());
|
||||
}
|
||||
}
|
||||
}
|
||||
// At least 2 inputs -> expect at least 3 JSONs (2 separate + 1 merged)
|
||||
assert!(
|
||||
json_count >= 3,
|
||||
"expected at least 3 JSON files, found {}",
|
||||
json_count
|
||||
);
|
||||
assert!(
|
||||
toml_count >= 3,
|
||||
"expected at least 3 TOML files, found {}",
|
||||
toml_count
|
||||
);
|
||||
assert!(
|
||||
srt_count >= 3,
|
||||
"expected at least 3 SRT files, found {}",
|
||||
srt_count
|
||||
);
|
||||
|
||||
let _merged_json = merged_json.expect("missing merged JSON output ending with _merged.json");
|
||||
// Contents of merged JSON are validated by unit tests and other integration coverage
|
||||
|
||||
// Cleanup
|
||||
let _ = fs::remove_dir_all(&out_dir);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn cli_set_speaker_names_merge_prompts_and_uses_names() {
|
||||
// Also validate that -q does not suppress prompts by running with -q
|
||||
use std::io::Write as _;
|
||||
use std::process::Stdio;
|
||||
|
||||
let exe = env!("CARGO_BIN_EXE_polyscribe");
|
||||
|
||||
let input1 = manifest_path("input/1-s0wlz.json");
|
||||
let input2 = manifest_path("input/2-vikingowl.json");
|
||||
|
||||
let mut child = Command::new(exe)
|
||||
.arg(input1.as_os_str())
|
||||
.arg(input2.as_os_str())
|
||||
.arg("-m")
|
||||
.arg("--set-speaker-names")
|
||||
.arg("-q")
|
||||
.stdin(Stdio::piped())
|
||||
.stdout(Stdio::piped())
|
||||
.spawn()
|
||||
.expect("failed to spawn polyscribe");
|
||||
|
||||
{
|
||||
let stdin = child.stdin.as_mut().expect("failed to open stdin");
|
||||
// Provide two names for two files
|
||||
writeln!(stdin, "Alpha").unwrap();
|
||||
writeln!(stdin, "Beta").unwrap();
|
||||
}
|
||||
|
||||
let output = child.wait_with_output().expect("failed to wait on child");
|
||||
assert!(output.status.success(), "CLI did not exit successfully");
|
||||
|
||||
let stdout = String::from_utf8(output.stdout).expect("stdout not UTF-8");
|
||||
let root: OutputRoot = serde_json::from_str(&stdout).unwrap();
|
||||
let speakers: std::collections::HashSet<String> =
|
||||
root.items.into_iter().map(|e| e.speaker).collect();
|
||||
assert!(speakers.contains("Alpha"), "Alpha not found in speakers");
|
||||
assert!(speakers.contains("Beta"), "Beta not found in speakers");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn cli_no_interaction_skips_speaker_prompts_and_uses_defaults() {
|
||||
let exe = env!("CARGO_BIN_EXE_polyscribe");
|
||||
|
||||
let input1 = manifest_path("input/1-s0wlz.json");
|
||||
let input2 = manifest_path("input/2-vikingowl.json");
|
||||
|
||||
let output = Command::new(exe)
|
||||
.arg(input1.as_os_str())
|
||||
.arg(input2.as_os_str())
|
||||
.arg("-m")
|
||||
.arg("--set-speaker-names")
|
||||
.arg("--no-interaction")
|
||||
.output()
|
||||
.expect("failed to spawn polyscribe");
|
||||
|
||||
assert!(output.status.success(), "CLI did not exit successfully");
|
||||
|
||||
let stdout = String::from_utf8(output.stdout).expect("stdout not UTF-8");
|
||||
let root: OutputRoot = serde_json::from_str(&stdout).unwrap();
|
||||
let speakers: std::collections::HashSet<String> =
|
||||
root.items.into_iter().map(|e| e.speaker).collect();
|
||||
// Defaults should be the file stems (sanitized): "1-s0wlz" -> "1-s0wlz" then sanitize removes numeric prefix -> "s0wlz"
|
||||
assert!(speakers.contains("s0wlz"), "default s0wlz not used");
|
||||
assert!(speakers.contains("vikingowl"), "default vikingowl not used");
|
||||
}
|
||||
|
||||
// New verbosity behavior tests
|
||||
#[test]
|
||||
fn verbosity_quiet_suppresses_logs_but_keeps_stdout() {
|
||||
let exe = env!("CARGO_BIN_EXE_polyscribe");
|
||||
let input1 = manifest_path("input/1-s0wlz.json");
|
||||
let input2 = manifest_path("input/2-vikingowl.json");
|
||||
|
||||
let output = Command::new(exe)
|
||||
.arg("-q")
|
||||
.arg("-v") // ensure -q overrides -v
|
||||
.arg(input1.as_os_str())
|
||||
.arg(input2.as_os_str())
|
||||
.arg("-m")
|
||||
.output()
|
||||
.expect("failed to spawn polyscribe");
|
||||
|
||||
assert!(output.status.success());
|
||||
let stdout = String::from_utf8(output.stdout).unwrap();
|
||||
assert!(
|
||||
stdout.contains("\"items\""),
|
||||
"stdout JSON should be present in quiet mode"
|
||||
);
|
||||
let stderr = String::from_utf8(output.stderr).unwrap();
|
||||
assert!(
|
||||
stderr.trim().is_empty(),
|
||||
"stderr should be empty in quiet mode, got: {}",
|
||||
stderr
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn verbosity_verbose_emits_debug_logs_on_stderr() {
|
||||
let exe = env!("CARGO_BIN_EXE_polyscribe");
|
||||
let input1 = manifest_path("input/1-s0wlz.json");
|
||||
let input2 = manifest_path("input/2-vikingowl.json");
|
||||
|
||||
let output = Command::new(exe)
|
||||
.arg(input1.as_os_str())
|
||||
.arg(input2.as_os_str())
|
||||
.arg("-m")
|
||||
.arg("-v")
|
||||
.output()
|
||||
.expect("failed to spawn polyscribe");
|
||||
|
||||
assert!(output.status.success());
|
||||
let stdout = String::from_utf8(output.stdout).unwrap();
|
||||
assert!(stdout.contains("\"items\""));
|
||||
let stderr = String::from_utf8(output.stderr).unwrap();
|
||||
assert!(
|
||||
stderr.contains("Mode: merge"),
|
||||
"stderr should contain debug log with -v"
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn verbosity_flag_position_is_global() {
|
||||
let exe = env!("CARGO_BIN_EXE_polyscribe");
|
||||
let input1 = manifest_path("input/1-s0wlz.json");
|
||||
let input2 = manifest_path("input/2-vikingowl.json");
|
||||
|
||||
// -v before args
|
||||
let out1 = Command::new(exe)
|
||||
.arg("-v")
|
||||
.arg(input1.as_os_str())
|
||||
.arg(input2.as_os_str())
|
||||
.arg("-m")
|
||||
.output()
|
||||
.expect("failed to spawn polyscribe");
|
||||
|
||||
// -v after sub-flags
|
||||
let out2 = Command::new(exe)
|
||||
.arg(input1.as_os_str())
|
||||
.arg(input2.as_os_str())
|
||||
.arg("-m")
|
||||
.arg("-v")
|
||||
.output()
|
||||
.expect("failed to spawn polyscribe");
|
||||
|
||||
let s1 = String::from_utf8(out1.stderr).unwrap();
|
||||
let s2 = String::from_utf8(out2.stderr).unwrap();
|
||||
assert!(s1.contains("Mode: merge"));
|
||||
assert!(s2.contains("Mode: merge"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn cli_set_speaker_names_separate_single_input() {
|
||||
use std::io::Write as _;
|
||||
use std::process::Stdio;
|
||||
|
||||
let exe = env!("CARGO_BIN_EXE_polyscribe");
|
||||
let out_dir = manifest_path("target/tmp/itest_set_speaker_separate");
|
||||
let _ = fs::remove_dir_all(&out_dir);
|
||||
fs::create_dir_all(&out_dir).unwrap();
|
||||
|
||||
let input1 = manifest_path("input/3-schmendrizzle.json");
|
||||
|
||||
let mut child = Command::new(exe)
|
||||
.arg(input1.as_os_str())
|
||||
.arg("--set-speaker-names")
|
||||
.arg("-o")
|
||||
.arg(out_dir.as_os_str())
|
||||
.stdin(Stdio::piped())
|
||||
.stdout(Stdio::null())
|
||||
.stderr(Stdio::null())
|
||||
.spawn()
|
||||
.expect("failed to spawn polyscribe");
|
||||
|
||||
{
|
||||
let stdin = child.stdin.as_mut().expect("failed to open stdin");
|
||||
writeln!(stdin, "ChosenOne").unwrap();
|
||||
}
|
||||
|
||||
let status = child.wait().expect("failed to wait on child");
|
||||
assert!(status.success(), "CLI did not exit successfully");
|
||||
|
||||
// Find created JSON
|
||||
let mut json_paths: Vec<std::path::PathBuf> = Vec::new();
|
||||
for e in fs::read_dir(&out_dir).unwrap() {
|
||||
let p = e.unwrap().path();
|
||||
if let Some(name) = p.file_name().and_then(|s| s.to_str()) {
|
||||
if name.ends_with(".json") {
|
||||
json_paths.push(p.clone());
|
||||
}
|
||||
}
|
||||
}
|
||||
assert!(!json_paths.is_empty(), "no JSON outputs created");
|
||||
let mut buf = String::new();
|
||||
std::fs::File::open(&json_paths[0])
|
||||
.unwrap()
|
||||
.read_to_string(&mut buf)
|
||||
.unwrap();
|
||||
let root: OutputRoot = serde_json::from_str(&buf).unwrap();
|
||||
assert!(root.items.iter().all(|e| e.speaker == "ChosenOne"));
|
||||
|
||||
let _ = fs::remove_dir_all(&out_dir);
|
||||
}
|
||||
|
||||
// New tests for --out-format
|
||||
#[test]
|
||||
fn out_format_single_json_only() {
|
||||
let exe = env!("CARGO_BIN_EXE_polyscribe");
|
||||
let out_dir = manifest_path("target/tmp/itest_outfmt_json_only");
|
||||
let _ = fs::remove_dir_all(&out_dir);
|
||||
fs::create_dir_all(&out_dir).unwrap();
|
||||
|
||||
let input1 = manifest_path("input/1-s0wlz.json");
|
||||
|
||||
let status = Command::new(exe)
|
||||
.arg(input1.as_os_str())
|
||||
.arg("-o")
|
||||
.arg(&out_dir)
|
||||
.arg("--out-format")
|
||||
.arg("json")
|
||||
.status()
|
||||
.expect("failed to spawn polyscribe");
|
||||
assert!(status.success(), "CLI did not exit successfully");
|
||||
|
||||
let mut has_json = false;
|
||||
let mut has_toml = false;
|
||||
let mut has_srt = false;
|
||||
for e in fs::read_dir(&out_dir).unwrap() {
|
||||
let p = e.unwrap().path();
|
||||
if let Some(name) = p.file_name().and_then(|s| s.to_str()) {
|
||||
if name.ends_with(".json") { has_json = true; }
|
||||
if name.ends_with(".toml") { has_toml = true; }
|
||||
if name.ends_with(".srt") { has_srt = true; }
|
||||
}
|
||||
}
|
||||
assert!(has_json, "expected JSON file to be written");
|
||||
assert!(!has_toml, "did not expect TOML file");
|
||||
assert!(!has_srt, "did not expect SRT file");
|
||||
|
||||
let _ = fs::remove_dir_all(&out_dir);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn out_format_multiple_json_and_srt() {
|
||||
let exe = env!("CARGO_BIN_EXE_polyscribe");
|
||||
let out_dir = manifest_path("target/tmp/itest_outfmt_json_srt");
|
||||
let _ = fs::remove_dir_all(&out_dir);
|
||||
fs::create_dir_all(&out_dir).unwrap();
|
||||
|
||||
let input1 = manifest_path("input/2-vikingowl.json");
|
||||
|
||||
let status = Command::new(exe)
|
||||
.arg(input1.as_os_str())
|
||||
.arg("-o")
|
||||
.arg(&out_dir)
|
||||
.arg("--out-format")
|
||||
.arg("json")
|
||||
.arg("--out-format")
|
||||
.arg("srt")
|
||||
.status()
|
||||
.expect("failed to spawn polyscribe");
|
||||
assert!(status.success(), "CLI did not exit successfully");
|
||||
|
||||
let mut has_json = false;
|
||||
let mut has_toml = false;
|
||||
let mut has_srt = false;
|
||||
for e in fs::read_dir(&out_dir).unwrap() {
|
||||
let p = e.unwrap().path();
|
||||
if let Some(name) = p.file_name().and_then(|s| s.to_str()) {
|
||||
if name.ends_with(".json") { has_json = true; }
|
||||
if name.ends_with(".toml") { has_toml = true; }
|
||||
if name.ends_with(".srt") { has_srt = true; }
|
||||
}
|
||||
}
|
||||
assert!(has_json, "expected JSON file to be written");
|
||||
assert!(has_srt, "expected SRT file to be written");
|
||||
assert!(!has_toml, "did not expect TOML file");
|
||||
|
||||
let _ = fs::remove_dir_all(&out_dir);
|
||||
}
|
||||
|
||||
*/
|
||||
|
||||
|
||||
#[test]
|
||||
fn cli_no_interation_alias_skips_speaker_prompts_and_uses_defaults() {
|
||||
let exe = env!("CARGO_BIN_EXE_polyscribe");
|
||||
|
||||
let input1 = manifest_path("input/1-s0wlz.json");
|
||||
let input2 = manifest_path("input/2-vikingowl.json");
|
||||
|
||||
let output = Command::new(exe)
|
||||
.arg(input1.as_os_str())
|
||||
.arg(input2.as_os_str())
|
||||
.arg("-m")
|
||||
.arg("--set-speaker-names")
|
||||
.arg("--no-interation")
|
||||
.output()
|
||||
.expect("failed to spawn polyscribe");
|
||||
|
||||
assert!(output.status.success(), "CLI did not exit successfully");
|
||||
|
||||
let stdout = String::from_utf8(output.stdout).expect("stdout not UTF-8");
|
||||
let root: OutputRoot = serde_json::from_str(&stdout).unwrap();
|
||||
let speakers: std::collections::HashSet<String> =
|
||||
root.items.into_iter().map(|e| e.speaker).collect();
|
||||
assert!(speakers.contains("s0wlz"), "default s0wlz not used (alias)");
|
||||
assert!(speakers.contains("vikingowl"), "default vikingowl not used (alias)");
|
||||
}
|
@@ -1,88 +0,0 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Tests for --out-format flag behavior
|
||||
|
||||
use std::fs;
|
||||
use std::process::Command;
|
||||
use std::path::PathBuf;
|
||||
|
||||
fn manifest_path(relative: &str) -> PathBuf {
|
||||
let mut p = PathBuf::from(env!("CARGO_MANIFEST_DIR"));
|
||||
p.push(relative);
|
||||
p
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn out_format_single_json_only() {
|
||||
let exe = env!("CARGO_BIN_EXE_polyscribe");
|
||||
let out_dir = manifest_path("target/tmp/itest_outfmt_json_only");
|
||||
let _ = fs::remove_dir_all(&out_dir);
|
||||
fs::create_dir_all(&out_dir).unwrap();
|
||||
|
||||
let input1 = manifest_path("input/1-s0wlz.json");
|
||||
|
||||
let status = Command::new(exe)
|
||||
.arg(input1.as_os_str())
|
||||
.arg("-o")
|
||||
.arg(&out_dir)
|
||||
.arg("--out-format")
|
||||
.arg("json")
|
||||
.status()
|
||||
.expect("failed to spawn polyscribe");
|
||||
assert!(status.success(), "CLI did not exit successfully");
|
||||
|
||||
let mut has_json = false;
|
||||
let mut has_toml = false;
|
||||
let mut has_srt = false;
|
||||
for e in fs::read_dir(&out_dir).unwrap() {
|
||||
let p = e.unwrap().path();
|
||||
if let Some(name) = p.file_name().and_then(|s| s.to_str()) {
|
||||
if name.ends_with(".json") { has_json = true; }
|
||||
if name.ends_with(".toml") { has_toml = true; }
|
||||
if name.ends_with(".srt") { has_srt = true; }
|
||||
}
|
||||
}
|
||||
assert!(has_json, "expected JSON file to be written");
|
||||
assert!(!has_toml, "did not expect TOML file");
|
||||
assert!(!has_srt, "did not expect SRT file");
|
||||
|
||||
let _ = fs::remove_dir_all(&out_dir);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn out_format_multiple_json_and_srt() {
|
||||
let exe = env!("CARGO_BIN_EXE_polyscribe");
|
||||
let out_dir = manifest_path("target/tmp/itest_outfmt_json_srt");
|
||||
let _ = fs::remove_dir_all(&out_dir);
|
||||
fs::create_dir_all(&out_dir).unwrap();
|
||||
|
||||
let input1 = manifest_path("input/2-vikingowl.json");
|
||||
|
||||
let status = Command::new(exe)
|
||||
.arg(input1.as_os_str())
|
||||
.arg("-o")
|
||||
.arg(&out_dir)
|
||||
.arg("--out-format")
|
||||
.arg("json")
|
||||
.arg("--out-format")
|
||||
.arg("srt")
|
||||
.status()
|
||||
.expect("failed to spawn polyscribe");
|
||||
assert!(status.success(), "CLI did not exit successfully");
|
||||
|
||||
let mut has_json = false;
|
||||
let mut has_toml = false;
|
||||
let mut has_srt = false;
|
||||
for e in fs::read_dir(&out_dir).unwrap() {
|
||||
let p = e.unwrap().path();
|
||||
if let Some(name) = p.file_name().and_then(|s| s.to_str()) {
|
||||
if name.ends_with(".json") { has_json = true; }
|
||||
if name.ends_with(".toml") { has_toml = true; }
|
||||
if name.ends_with(".srt") { has_srt = true; }
|
||||
}
|
||||
}
|
||||
assert!(has_json, "expected JSON file to be written");
|
||||
assert!(has_srt, "expected SRT file to be written");
|
||||
assert!(!has_toml, "did not expect TOML file");
|
||||
|
||||
let _ = fs::remove_dir_all(&out_dir);
|
||||
}
|
@@ -1,91 +0,0 @@
|
||||
use polyscribe::progress::{ProgressFactory, ProgressMode, SelectionInput, select_mode, ProgressManager};
|
||||
|
||||
#[test]
|
||||
fn test_factory_decide_mode_none_when_disabled() {
|
||||
let pf = ProgressFactory::new(true); // force disabled
|
||||
assert!(matches!(pf.decide_mode(0), ProgressMode::None));
|
||||
assert!(matches!(pf.decide_mode(1), ProgressMode::None));
|
||||
assert!(matches!(pf.decide_mode(2), ProgressMode::None));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_select_mode_zero_inputs_is_none() {
|
||||
let si = SelectionInput {
|
||||
inputs_len: 0,
|
||||
no_progress_flag: false,
|
||||
stderr_tty_override: Some(true),
|
||||
env_no_progress: false,
|
||||
};
|
||||
let (enabled, mode) = select_mode(si);
|
||||
assert!(enabled);
|
||||
assert!(matches!(mode, ProgressMode::None));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_select_mode_one_input_is_single() {
|
||||
let si = SelectionInput {
|
||||
inputs_len: 1,
|
||||
no_progress_flag: false,
|
||||
stderr_tty_override: Some(true),
|
||||
env_no_progress: false,
|
||||
};
|
||||
let (enabled, mode) = select_mode(si);
|
||||
assert!(enabled);
|
||||
assert!(matches!(mode, ProgressMode::Single));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_select_mode_multi_inputs_is_multi() {
|
||||
let si = SelectionInput {
|
||||
inputs_len: 3,
|
||||
no_progress_flag: false,
|
||||
stderr_tty_override: Some(true),
|
||||
env_no_progress: false,
|
||||
};
|
||||
let (enabled, mode) = select_mode(si);
|
||||
assert!(enabled);
|
||||
match mode {
|
||||
ProgressMode::Multi { total_inputs } => assert_eq!(total_inputs, 3),
|
||||
_ => panic!("expected multi mode"),
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_env_no_progress_disables() {
|
||||
// Simulate env flag influence by passing env_no_progress=true
|
||||
unsafe { std::env::set_var("NO_PROGRESS", "1"); }
|
||||
let si = SelectionInput {
|
||||
inputs_len: 5,
|
||||
no_progress_flag: false,
|
||||
stderr_tty_override: Some(true),
|
||||
env_no_progress: true,
|
||||
};
|
||||
let (enabled, mode) = select_mode(si);
|
||||
assert!(!enabled);
|
||||
assert!(matches!(mode, ProgressMode::None));
|
||||
unsafe { std::env::remove_var("NO_PROGRESS"); }
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_completed_never_exceeds_total_and_item_updates_do_not_affect_total() {
|
||||
// create hidden multiprogress for tests
|
||||
let pm = ProgressManager::new_for_tests_multi_hidden(3);
|
||||
pm.set_total(3);
|
||||
// Start an item and update progress a few times
|
||||
let item = pm.start_item("Test item");
|
||||
item.set_progress(0.1);
|
||||
item.set_progress(0.4);
|
||||
item.set_message("stage1");
|
||||
// Ensure total unchanged
|
||||
let (pos, len) = pm.total_state_for_tests().unwrap();
|
||||
assert_eq!(len, 3);
|
||||
assert_eq!(pos, 0);
|
||||
// Mark 4 times completed, but expect clamp at 3
|
||||
pm.inc_completed();
|
||||
pm.inc_completed();
|
||||
pm.inc_completed();
|
||||
pm.inc_completed();
|
||||
let (pos, len) = pm.total_state_for_tests().unwrap();
|
||||
assert_eq!(len, 3);
|
||||
assert_eq!(pos, 3);
|
||||
}
|
@@ -1,30 +0,0 @@
|
||||
use polyscribe::progress::ProgressManager;
|
||||
|
||||
#[test]
|
||||
fn test_total_and_completed_clamp() {
|
||||
let pm = ProgressManager::new_for_tests_multi_hidden(3);
|
||||
pm.set_total(3);
|
||||
pm.inc_completed();
|
||||
pm.inc_completed();
|
||||
pm.inc_completed();
|
||||
// Extra increments should not exceed total
|
||||
pm.inc_completed();
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_start_item_does_not_change_total() {
|
||||
let pm = ProgressManager::new_for_tests_multi_hidden(2);
|
||||
pm.set_total(2);
|
||||
let item = pm.start_item("file1");
|
||||
item.set_progress(0.5);
|
||||
// No panic; total bar position should be unaffected. We cannot introspect position without
|
||||
// exposing internals; this test ensures API usability without side effects.
|
||||
item.finish_with("done");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_pause_and_resume_prompt() {
|
||||
let pm = ProgressManager::test_new_multi(1);
|
||||
pm.pause_for_prompt();
|
||||
pm.resume_after_prompt();
|
||||
}
|
@@ -1,22 +0,0 @@
|
||||
use polyscribe::progress::ProgressManager;
|
||||
|
||||
#[test]
|
||||
fn test_single_mode_has_no_total_bar_and_three_bars() {
|
||||
// Use hidden backend suitable for tests
|
||||
let pm = ProgressManager::new_for_tests_single_hidden();
|
||||
// No total bar should be present
|
||||
assert!(pm.total_state_for_tests().is_none(), "single mode must not expose a total bar");
|
||||
// Bar count: header + info + current
|
||||
assert_eq!(pm.testing_bar_count(), 3);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_multi_mode_has_total_bar_and_four_bars() {
|
||||
let pm = ProgressManager::new_for_tests_multi_hidden(2);
|
||||
// Total bar should exist with the provided length
|
||||
let (pos, len) = pm.total_state_for_tests().expect("multi mode should expose total bar");
|
||||
assert_eq!(pos, 0);
|
||||
assert_eq!(len, 2);
|
||||
// Bar count: header + info + current + total
|
||||
assert_eq!(pm.testing_bar_count(), 4);
|
||||
}
|
@@ -1,86 +0,0 @@
|
||||
use std::io::Write as _;
|
||||
use std::process::{Command, Stdio};
|
||||
|
||||
fn manifest_path(rel: &str) -> std::path::PathBuf {
|
||||
let mut p = std::path::PathBuf::from(env!("CARGO_MANIFEST_DIR"));
|
||||
p.push(rel);
|
||||
p
|
||||
}
|
||||
|
||||
fn collect_stderr_lines(output: &std::process::Output) -> Vec<String> {
|
||||
let stderr = String::from_utf8_lossy(&output.stderr);
|
||||
stderr.lines().map(|s| s.to_string()).collect()
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn speaker_prompt_spacing_single_vs_multi_is_consistent() {
|
||||
let exe = env!("CARGO_BIN_EXE_polyscribe");
|
||||
let input1 = manifest_path("input/1-s0wlz.json");
|
||||
let input2 = manifest_path("input/2-vikingowl.json");
|
||||
|
||||
// Single mode
|
||||
let mut child1 = Command::new(exe)
|
||||
.arg(input1.as_os_str())
|
||||
.arg("--set-speaker-names")
|
||||
.arg("-m")
|
||||
.stdin(Stdio::piped())
|
||||
.stdout(Stdio::piped())
|
||||
.stderr(Stdio::piped())
|
||||
.spawn()
|
||||
.expect("failed to spawn polyscribe (single)");
|
||||
{
|
||||
let s = child1.stdin.as_mut().unwrap();
|
||||
writeln!(s, "Alpha").unwrap();
|
||||
}
|
||||
let out1 = child1.wait_with_output().unwrap();
|
||||
assert!(out1.status.success());
|
||||
let lines1 = collect_stderr_lines(&out1);
|
||||
|
||||
// Multi mode
|
||||
let mut child2 = Command::new(exe)
|
||||
.arg(input1.as_os_str())
|
||||
.arg(input2.as_os_str())
|
||||
.arg("--set-speaker-names")
|
||||
.arg("-m")
|
||||
.stdin(Stdio::piped())
|
||||
.stdout(Stdio::piped())
|
||||
.stderr(Stdio::piped())
|
||||
.spawn()
|
||||
.expect("failed to spawn polyscribe (multi)");
|
||||
{
|
||||
let s = child2.stdin.as_mut().unwrap();
|
||||
writeln!(s, "Alpha").unwrap();
|
||||
writeln!(s, "Beta").unwrap();
|
||||
}
|
||||
let out2 = child2.wait_with_output().unwrap();
|
||||
assert!(out2.status.success());
|
||||
let lines2 = collect_stderr_lines(&out2);
|
||||
|
||||
// Helper to count blank separators around echo block
|
||||
fn analyze(lines: &[String]) -> (usize, usize, usize) {
|
||||
// count: prompts, blanks, echoes (either legacy "Speaker for " or new mapping lines starting with " - ")
|
||||
let mut prompts = 0;
|
||||
let mut blanks = 0;
|
||||
let mut echoes = 0;
|
||||
for l in lines {
|
||||
if l.starts_with("Enter speaker name for ") { prompts += 1; }
|
||||
if l.trim().is_empty() { blanks += 1; }
|
||||
if l.starts_with("Speaker for ") || l.starts_with(" - ") { echoes += 1; }
|
||||
}
|
||||
(prompts, blanks, echoes)
|
||||
}
|
||||
|
||||
let (p1, b1, e1) = analyze(&lines1);
|
||||
let (p2, b2, e2) = analyze(&lines2);
|
||||
|
||||
// Expect one prompt/echo for single, two for multi
|
||||
assert_eq!(p1, 1);
|
||||
assert_eq!(e1, 1);
|
||||
assert_eq!(p2, 2);
|
||||
assert_eq!(e2, 2);
|
||||
|
||||
// Each mode should have exactly two blank separators: one between prompts and echoes and one after echoes
|
||||
// Note: other logs may be absent in tests; we count exactly 2 blanks for single and multi here
|
||||
assert!(b1 >= 2, "expected at least two blank separators in single mode, got {}: {:?}", b1, lines1);
|
||||
assert!(b2 >= 2, "expected at least two blank separators in multi mode, got {}: {:?}", b2, lines2);
|
||||
}
|
6
tests/smoke.rs
Normal file
6
tests/smoke.rs
Normal file
@@ -0,0 +1,6 @@
|
||||
// Rust
|
||||
#[test]
|
||||
fn smoke_compiles_and_runs() {
|
||||
// This test ensures the test harness works without exercising the CLI.
|
||||
assert!(true);
|
||||
}
|
@@ -1,58 +0,0 @@
|
||||
// Unix-only tests for with_suppressed_stderr restoring file descriptors
|
||||
// Skip on Windows and non-Unix targets.
|
||||
|
||||
#![cfg(unix)]
|
||||
|
||||
use std::panic::{catch_unwind, AssertUnwindSafe};
|
||||
|
||||
fn stat_of_fd(fd: i32) -> (u64, u64) {
|
||||
unsafe {
|
||||
let mut st: libc::stat = std::mem::zeroed();
|
||||
let r = libc::fstat(fd, &mut st as *mut libc::stat);
|
||||
assert_eq!(r, 0, "fstat failed on fd {fd}");
|
||||
(st.st_dev as u64, st.st_ino as u64)
|
||||
}
|
||||
}
|
||||
|
||||
fn stat_of_path(path: &str) -> (u64, u64) {
|
||||
use std::ffi::CString;
|
||||
unsafe {
|
||||
let c = CString::new(path).unwrap();
|
||||
let fd = libc::open(c.as_ptr(), libc::O_RDONLY);
|
||||
assert!(fd >= 0, "failed to open {path}");
|
||||
let s = stat_of_fd(fd);
|
||||
let _ = libc::close(fd);
|
||||
s
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn stderr_is_redirected_and_restored() {
|
||||
let before = stat_of_fd(2);
|
||||
let devnull = stat_of_path("/dev/null");
|
||||
|
||||
// During the call, fd 2 should be /dev/null; after, restored to before
|
||||
polyscribe::with_suppressed_stderr(|| {
|
||||
let inside = stat_of_fd(2);
|
||||
assert_eq!(inside, devnull, "stderr should point to /dev/null during suppression");
|
||||
// This write should be suppressed
|
||||
eprintln!("this should be suppressed");
|
||||
});
|
||||
|
||||
let after = stat_of_fd(2);
|
||||
assert_eq!(after, before, "stderr should be restored after suppression");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn stderr_is_restored_even_if_closure_panics() {
|
||||
let before = stat_of_fd(2);
|
||||
let res = catch_unwind(AssertUnwindSafe(|| {
|
||||
polyscribe::with_suppressed_stderr(|| {
|
||||
// Trigger a deliberate panic inside the closure
|
||||
panic!("boom inside with_suppressed_stderr");
|
||||
});
|
||||
}));
|
||||
assert!(res.is_err(), "expected panic to propagate");
|
||||
let after = stat_of_fd(2);
|
||||
assert_eq!(after, before, "stderr should be restored after panic");
|
||||
}
|
Reference in New Issue
Block a user