Compare commits

...

19 Commits

Author SHA1 Message Date
af473c4942 [docs] update README with new CLI options, usage tips, and guidance for language models
Some checks failed
CI / ci (push) Has been cancelled
2025-08-12 05:17:41 +02:00
a26eade80b [test] add examples-check target with stubbed BIN and no-network validation for example scripts 2025-08-12 05:15:41 +02:00
94c816acdf [test] add CI workflow with Rust checks, cache setup, and auditing; update docs and README with CI details 2025-08-12 05:09:24 +02:00
3dc1237938 [test] add tests for progress manager modes; verify bar counts and total bar visibility in single and multi modes 2025-08-12 05:07:09 +02:00
6994d20f5e [test] add tests for --force flag and numeric suffix handling to ensure proper output file resolution and overwriting behavior 2025-08-12 05:02:33 +02:00
b7f0ddda37 [test] add tests for --no-interaction and its alias to ensure non-interactive mode skips prompts and uses defaults 2025-08-12 04:58:39 +02:00
98491a8701 [test] add test for deterministic merge output across job counts; enhance --jobs support with parallel processing logic 2025-08-12 04:43:39 +02:00
abe81b643b [test] add unit tests for validate_model_lang_compat ensuring model-language compatibility validation 2025-08-12 04:39:33 +02:00
f143e66e80 [test] add comprehensive tests for select_backend ensuring proper backend priority and error guidance 2025-08-12 04:24:51 +02:00
7832545033 [test] add Unix-only tests for with_suppressed_stderr ensuring stderr redirection and restoration, including panic handling 2025-08-12 04:22:55 +02:00
152fde36ae [build] pin whisper-rs dependency to a specific commit for reproducible builds; update documentation accordingly 2025-08-12 04:16:09 +02:00
df6faf6436 [refactor] remove dialoguer dependency; migrate selection prompts to cliclack 2025-08-12 04:12:53 +02:00
e954902aa9 [feat] enhance progress logging, introduce TTY-aware banners, and implement hardened SHA-256 verification for model downloads 2025-08-12 04:10:24 +02:00
37c43161da [feat] integrate global progress manager for unified log handling; enhance model download workflow with progress tracking and SHA-256 verification 2025-08-12 03:46:39 +02:00
9120e8fb26 [feat] introduce Config for centralized runtime settings; refactor progress management and backend selection to leverage config 2025-08-12 02:57:42 +02:00
ee67b56d6b [feat] enhance error handling, CLI options, and progress display; add --continue-on-error flag and improve maintainability 2025-08-12 02:43:20 +02:00
d531ac0b96 [refactor] extract summary table rendering logic into render_summary_lines for improved readability and reusability; add associated tests 2025-08-11 10:05:37 +02:00
66f0062ffb [feat] add --out-format CLI option for customizable output formats; update tests and README 2025-08-11 10:01:29 +02:00
d46b23a4f5 [refactor] extract and centralize output writing logic into write_outputs function in output.rs for improved code reuse and maintainability 2025-08-11 09:35:29 +02:00
25 changed files with 2457 additions and 848 deletions

45
.github/workflows/ci.yml vendored Normal file
View File

@@ -0,0 +1,45 @@
name: CI
on:
push:
branches: [ main, master ]
pull_request:
branches: [ main, master ]
permissions:
contents: read
jobs:
ci:
name: ci
runs-on: ubuntu-latest
steps:
- name: Checkout
uses: actions/checkout@v4
- name: Setup Rust
uses: dtolnay/rust-toolchain@stable
with:
components: clippy, rustfmt
- name: Show rustc/cargo versions
run: |
rustc -Vv
cargo -Vv
- name: Cache cargo registry
uses: actions/cache@v4
with:
path: |
~/.cargo/registry
~/.cargo/git
target
key: ${{ runner.os }}-cargo-${{ hashFiles('**/Cargo.lock') }}
- name: Install cargo-audit
run: |
cargo install cargo-audit --locked || cargo install cargo-audit
- name: Format check
run: cargo fmt --all -- --check
- name: Clippy (warnings as errors)
run: cargo clippy --all-targets -- -D warnings
- name: Test
run: cargo test --all
- name: Audit
run: cargo audit

44
Cargo.lock generated
View File

@@ -378,19 +378,6 @@ dependencies = [
"windows-sys 0.59.0",
]
[[package]]
name = "dialoguer"
version = "0.11.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "658bce805d770f407bc62102fca7c2c64ceef2fbcb2b8bd19d2765ce093980de"
dependencies = [
"console",
"shell-words",
"tempfile",
"thiserror",
"zeroize",
]
[[package]]
name = "digest"
version = "0.10.7"
@@ -1173,7 +1160,6 @@ dependencies = [
"clap_mangen",
"cliclack",
"ctrlc",
"dialoguer",
"indicatif",
"libc",
"reqwest",
@@ -1491,12 +1477,6 @@ dependencies = [
"digest",
]
[[package]]
name = "shell-words"
version = "1.1.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "24188a676b6ae68c3b2cb3a01be17fbf7240ce009799bb56d5b1409051e78fde"
[[package]]
name = "shlex"
version = "1.3.0"
@@ -1625,26 +1605,6 @@ dependencies = [
"unicode-width",
]
[[package]]
name = "thiserror"
version = "1.0.69"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "b6aaf5339b578ea85b50e080feb250a3e8ae8cfcdff9a461c9ec2904bc923f52"
dependencies = [
"thiserror-impl",
]
[[package]]
name = "thiserror-impl"
version = "1.0.69"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "4fee6c4efc90059e10f81e6d42c60a18f76588c3d74cb83a0b242a2b6c7504c1"
dependencies = [
"proc-macro2",
"quote",
"syn",
]
[[package]]
name = "tinystr"
version = "0.8.1"
@@ -1999,7 +1959,7 @@ dependencies = [
[[package]]
name = "whisper-rs"
version = "0.14.3"
source = "git+https://github.com/tazz4843/whisper-rs#135b60b85a15714862806b6ea9f76abec38156f1"
source = "git+https://github.com/tazz4843/whisper-rs?rev=135b60b85a15714862806b6ea9f76abec38156f1#135b60b85a15714862806b6ea9f76abec38156f1"
dependencies = [
"whisper-rs-sys",
]
@@ -2007,7 +1967,7 @@ dependencies = [
[[package]]
name = "whisper-rs-sys"
version = "0.13.0"
source = "git+https://github.com/tazz4843/whisper-rs#135b60b85a15714862806b6ea9f76abec38156f1"
source = "git+https://github.com/tazz4843/whisper-rs?rev=135b60b85a15714862806b6ea9f76abec38156f1#135b60b85a15714862806b6ea9f76abec38156f1"
dependencies = [
"bindgen",
"cfg-if",

View File

@@ -5,11 +5,14 @@ edition = "2024"
license = "MIT"
[features]
# Default: CPU only; no GPU features enabled
# Default: build without whisper to keep tests lightweight; enable `whisper` to use whisper-rs.
default = []
# GPU backends map to whisper-rs features or FFI stub for Vulkan
gpu-cuda = ["whisper-rs/cuda"]
gpu-hip = ["whisper-rs/hipblas"]
# Enable whisper-rs dependency (CPU-only unless combined with gpu-* features)
whisper = ["dep:whisper-rs"]
# GPU backends map to whisper-rs features
gpu-cuda = ["whisper", "whisper-rs/cuda"]
gpu-hip = ["whisper", "whisper-rs/hipblas"]
# Vulkan path currently doesn't use whisper directly here; placeholder feature
gpu-vulkan = []
# explicit CPU fallback feature (no effect at build time, used for clarity)
cpu-fallback = []
@@ -25,12 +28,13 @@ toml = "0.8"
chrono = { version = "0.4", features = ["clock"] }
reqwest = { version = "0.12", features = ["blocking", "json"] }
sha2 = "0.10"
# whisper-rs is always used (CPU-only by default); GPU features map onto it
whisper-rs = { git = "https://github.com/tazz4843/whisper-rs", default-features = false }
# Make whisper-rs optional; enabled via `whisper` feature
# Pin whisper-rs to a known-good commit for reproducible builds.
# To update: run `cargo update -p whisper-rs --precise 135b60b85a15714862806b6ea9f76abec38156f1` (adjust SHA) and update this rev.
whisper-rs = { git = "https://github.com/tazz4843/whisper-rs", rev = "135b60b85a15714862806b6ea9f76abec38156f1", default-features = false, optional = true }
libc = "0.2"
indicatif = "0.17"
ctrlc = "3.4"
dialoguer = "0.11"
cliclack = "0.3"
[dev-dependencies]

23
Makefile Normal file
View File

@@ -0,0 +1,23 @@
# Lightweight examples-check: runs all examples/*.sh with --no-interaction -q and stubbed BIN
# This target does not perform network calls and never prompts for input.
.SHELL := /bin/bash
.PHONY: examples-check
examples-check:
@set -euo pipefail; \
shopt -s nullglob; \
BIN_WRAPPER="$(PWD)/scripts/with_flags.sh"; \
failed=0; \
for f in examples/*.sh; do \
echo "[examples-check] Running $$f"; \
BIN="$$BIN_WRAPPER" bash "$$f" </dev/null >/dev/null 2>&1 || { \
echo "[examples-check] FAILED: $$f"; failed=1; \
}; \
done; \
if [[ $$failed -ne 0 ]]; then \
echo "[examples-check] Some examples failed."; \
exit 1; \
else \
echo "[examples-check] All examples passed (no interaction, quiet)."; \
fi

View File

@@ -30,8 +30,12 @@ Quickstart
- ./target/release/polyscribe --download-models
2) Transcribe a file:
- ./target/release/polyscribe -v -o output my_audio.mp3
This writes JSON and SRT into the output directory with a date prefix.
- ./target/release/polyscribe -v -o output --out-format json --jobs 4 my_audio.mp3
This writes JSON (because of --out-format json) into the output directory with a date prefix. Omit --out-format to write all available formats (JSON and SRT). For large batches, add --continue-on-error to skip bad files and keep going.
Gotchas
- English-only models: If you picked an English-only Whisper model (e.g., tiny.en, base.en), non-English language hints (via --language) will be rejected and detection may be biased toward English. Use a multilingual model (without the .en suffix) for non-English audio.
- Language hints help: When you know the language, pass --language <code> (e.g., --language de) to improve accuracy and speed. If the audio is mixed language, omit the hint to let the model detect.
Shell completions and man page
- Completions: ./target/release/polyscribe completions <bash|zsh|fish|powershell|elvish> > polyscribe.<ext>
@@ -46,6 +50,7 @@ Model locations
Most-used CLI flags
- -o, --output FILE_OR_DIR: Output path base (date prefix added). If omitted, JSON prints to stdout.
- --out-format <json|toml|srt|all>: Which on-disk format(s) to write; repeatable; default all. Example: --out-format json --out-format srt
- -m, --merge: Merge all inputs into one output; otherwise one output per input.
- --merge-and-separate: Write both merged output and separate per-input outputs (requires -o dir).
- --set-speaker-names: Prompt for a speaker label per input file.
@@ -76,7 +81,7 @@ Troubleshooting & docs
- docs/ci.md minimal CI checklist and job outline
- CONTRIBUTING.md PR checklist and workflow
CI status: [CI badge placeholder]
CI status: [CI workflow runs](actions/workflows/ci.yml)
Examples
See the examples/ directory for copy-paste scripts:

View File

@@ -14,6 +14,10 @@ Example GitHub Actions job (outline)
steps:
- uses: actions/checkout@v4
- uses: dtolnay/rust-toolchain@stable
- name: Print resolved whisper-rs rev
run: |
echo "Resolved whisper-rs revision:" && \
awk '/name = "whisper-rs"/{f=1} f&&/source = "git\+.*whisper-rs#/{match($0,/#([0-9a-f]{7,40})"/,m); if(m[1]){print m[1]; exit}}' Cargo.lock
- name: Build
run: cargo build --all-targets --locked
- name: Test
@@ -24,3 +28,4 @@ Example GitHub Actions job (outline)
Notes
- For GPU features, set up appropriate runners and add `--features gpu-cuda|gpu-hip|gpu-vulkan` where applicable.
- For docs-only changes, jobs still build/test to ensure doctests and examples compile when enabled.
- Mark the CI job named `ci` as a required status check for the default branch in repository branch protection settings.

View File

@@ -13,6 +13,12 @@ Rust toolchain
- rustup install stable
- rustup default stable
Dependency pinning
- We pin whisper-rs (git dependency) to a known-good commit in Cargo.toml for reproducibility.
- To bump it, resolve/test the desired commit locally, then run:
- cargo update -p whisper-rs --precise 135b60b85a15714862806b6ea9f76abec38156f1
Replace the SHA with the desired commit and update the rev in Cargo.toml accordingly.
Build
- CPU-only (default):
- cargo build
@@ -41,6 +47,16 @@ Tests
- cargo test
- The test suite includes CLI-oriented integration tests and unit tests. Some tests simulate GPU detection using env vars (POLYSCRIBE_TEST_FORCE_*). Do not rely on these flags in production code.
Examples check (no network, non-interactive)
- To quickly validate that example scripts are wired correctly (no prompts, quiet, exit 0), run:
- make examples-check
- What it does:
- Iterates over examples/*.sh
- Forces execution with --no-interaction and -q via a wrapper
- Uses a stubbed BIN that performs no network access and exits successfully
- Redirects stdin from /dev/null to ensure no prompts
- This is intended for CI smoke checks and local verification; it does not actually download models or transcribe audio.
Clippy
- Run lint checks and treat warnings as errors:
- cargo clippy --all-targets -- -D warnings

0
examples/download_models_interactive.sh Normal file → Executable file
View File

0
examples/transcribe_file.sh Normal file → Executable file
View File

0
examples/update_models.sh Normal file → Executable file
View File

26
scripts/bin_stub.sh Executable file
View File

@@ -0,0 +1,26 @@
#!/usr/bin/env bash
# Lightweight stub for examples-check: simulates the PolyScribe CLI without I/O or network
# - Accepts any arguments
# - Exits 0
# - Produces no output unless VERBOSE_STUB=1
# - Never performs network operations
# - Never reads from stdin
set -euo pipefail
if [[ "${VERBOSE_STUB:-0}" == "1" ]]; then
echo "[stub] polyscribe $*" 1>&2
fi
# Behave quietly if -q/--quiet is present by default (no output)
# Honor --help/-h: print minimal usage if verbose requested
if [[ "${VERBOSE_STUB:-0}" == "1" ]]; then
for arg in "$@"; do
if [[ "$arg" == "-h" || "$arg" == "--help" ]]; then
echo "PolyScribe stub: no-op (examples-check)" 1>&2
break
fi
done
fi
# Always succeed quietly
exit 0

28
scripts/with_flags.sh Executable file
View File

@@ -0,0 +1,28 @@
#!/usr/bin/env bash
# Wrapper that ensures --no-interaction -q are present, then delegates to the real BIN (stub by default)
set -euo pipefail
REAL_BIN=${REAL_BIN:-"$(dirname "$0")/bin_stub.sh"}
# Append flags if not already present in args
args=("$@")
need_no_interaction=1
need_quiet=1
for a in "${args[@]}"; do
[[ "$a" == "--no-interaction" ]] && need_no_interaction=0
[[ "$a" == "-q" || "$a" == "--quiet" ]] && need_quiet=0
done
if [[ $need_no_interaction -eq 1 ]]; then
args=("--no-interaction" "${args[@]}")
fi
if [[ $need_quiet -eq 1 ]]; then
args=("-q" "${args[@]}")
fi
# Never read stdin; prevent accidental blocking by redirecting from /dev/null
# Also advertise offline via env variables commonly checked by the app
export CI=1
export POLYSCRIBE_MODELS_BASE_COPY_DIR="${POLYSCRIBE_MODELS_BASE_COPY_DIR:-}" # leave empty by default
exec "$REAL_BIN" "${args[@]}" </dev/null

View File

@@ -141,6 +141,28 @@ impl Default for VulkanBackend {
}
}
/// Validate that a provided language hint is compatible with the selected model.
///
/// English-only models (filenames containing ".en." or ending with ".en.bin") reject non-"en" hints.
/// When no language is provided, this check passes and downstream behavior remains unchanged.
pub(crate) fn validate_model_lang_compat(model: &Path, lang_opt: Option<&str>) -> Result<()> {
let is_en_only = model
.file_name()
.and_then(|s| s.to_str())
.map(|s| s.contains(".en.") || s.ends_with(".en.bin"))
.unwrap_or(false);
if let Some(lang) = lang_opt {
if is_en_only && lang != "en" {
return Err(anyhow!(
"Selected model is English-only ({}), but a non-English language hint '{}' was provided. Please use a multilingual model or set WHISPER_MODEL.",
model.display(),
lang
));
}
}
Ok(())
}
impl TranscribeBackend for CpuBackend {
fn kind(&self) -> BackendKind {
BackendKind::Cpu
@@ -226,7 +248,7 @@ pub struct SelectionResult {
/// guidance on how to enable it.
///
/// Set `verbose` to true to print detection/selection info to stderr.
pub fn select_backend(requested: BackendKind, verbose: bool) -> Result<SelectionResult> {
pub fn select_backend(requested: BackendKind, config: &crate::Config) -> Result<SelectionResult> {
let mut detected = Vec::new();
if cuda_available() {
detected.push(BackendKind::Cuda);
@@ -290,7 +312,7 @@ pub fn select_backend(requested: BackendKind, verbose: bool) -> Result<Selection
BackendKind::Cpu => BackendKind::Cpu,
};
if verbose {
if config.verbose >= 1 && !config.quiet {
crate::dlog!(1, "Detected backends: {:?}", detected);
crate::dlog!(1, "Selected backend: {:?}", chosen);
}
@@ -304,6 +326,7 @@ pub fn select_backend(requested: BackendKind, verbose: bool) -> Result<Selection
// Internal helper: transcription using whisper-rs with CPU/GPU (depending on build features)
#[allow(clippy::too_many_arguments)]
#[cfg(feature = "whisper")]
pub(crate) fn transcribe_with_whisper_rs(
audio_path: &Path,
speaker: &str,
@@ -327,20 +350,8 @@ pub(crate) fn transcribe_with_whisper_rs(
note: Some("model selected".to_string()),
});
}
let is_en_only = model
.file_name()
.and_then(|s| s.to_str())
.map(|s| s.contains(".en.") || s.ends_with(".en.bin"))
.unwrap_or(false);
if let Some(lang) = lang_opt {
if is_en_only && lang != "en" {
return Err(anyhow!(
"Selected model is English-only ({}), but a non-English language hint '{}' was provided. Please use a multilingual model or set WHISPER_MODEL.",
model.display(),
lang
));
}
}
// Validate language hint compatibility with the selected model
validate_model_lang_compat(&model, lang_opt)?;
let model_str = model
.to_str()
.ok_or_else(|| anyhow!("Model path not valid UTF-8: {}", model.display()))?;
@@ -429,3 +440,140 @@ pub(crate) fn transcribe_with_whisper_rs(
}
Ok(items)
}
#[allow(clippy::too_many_arguments)]
#[cfg(not(feature = "whisper"))]
pub(crate) fn transcribe_with_whisper_rs(
_audio_path: &Path,
_speaker: &str,
_lang_opt: Option<&str>,
_progress_tx: Option<Sender<ProgressMessage>>,
) -> Result<Vec<OutputEntry>> {
Err(anyhow!(
"Transcription requires the 'whisper' feature. Rebuild with --features whisper (and optional gpu-cuda/gpu-hip)."
))
}
#[cfg(test)]
mod tests {
use super::*;
use std::env as std_env;
use std::sync::{Mutex, OnceLock};
#[test]
fn test_validate_model_lang_guard_table() {
struct case<'a> { model: &'a str, lang: Option<&'a str>, ok: bool }
let cases = vec![
// English-only model with en hint: OK
case { model: "ggml-base.en.bin", lang: Some("en"), ok: true },
// English-only model with de hint: Error
case { model: "ggml-small.en.bin", lang: Some("de"), ok: false },
// Multilingual model with de hint: OK
case { model: "ggml-large-v3.bin", lang: Some("de"), ok: true },
// No language provided (audio path scenario): guard should pass (existing behavior elsewhere)
case { model: "ggml-medium.en.bin", lang: None, ok: true },
];
for c in cases {
let p = std::path::Path::new(c.model);
let res = validate_model_lang_compat(p, c.lang);
match (c.ok, res) {
(true, Ok(())) => {}
(false, Err(e)) => {
let msg = format!("{}", e);
assert!(msg.contains("English-only"), "unexpected error: {msg}");
if let Some(l) = c.lang { assert!(msg.contains(l), "missing lang in msg: {msg}"); }
}
(true, Err(e)) => panic!("expected Ok for model={}, lang={:?}, got error: {}", c.model, c.lang, e),
(false, Ok(())) => panic!("expected Err for model={}, lang={:?}", c.model, c.lang),
}
}
}
// Serialize environment variable modifications across tests in this module
static ENV_LOCK: OnceLock<Mutex<()>> = OnceLock::new();
#[test]
fn test_select_backend_auto_prefers_cuda_then_hip_then_vulkan_then_cpu() {
let _guard = ENV_LOCK.get_or_init(|| Mutex::new(())).lock().unwrap();
// Clear overrides
unsafe {
std_env::remove_var("POLYSCRIBE_TEST_FORCE_CUDA");
std_env::remove_var("POLYSCRIBE_TEST_FORCE_HIP");
std_env::remove_var("POLYSCRIBE_TEST_FORCE_VULKAN");
}
// No GPU -> CPU
let sel = select_backend(BackendKind::Auto, &crate::Config::default()).unwrap();
assert_eq!(sel.chosen, BackendKind::Cpu);
// Vulkan only -> Vulkan
unsafe { std_env::set_var("POLYSCRIBE_TEST_FORCE_VULKAN", "1"); }
let sel = select_backend(BackendKind::Auto, &crate::Config::default()).unwrap();
assert_eq!(sel.chosen, BackendKind::Vulkan);
// HIP only -> HIP (and preferred over Vulkan)
unsafe {
std_env::set_var("POLYSCRIBE_TEST_FORCE_HIP", "1");
std_env::remove_var("POLYSCRIBE_TEST_FORCE_VULKAN");
}
let sel = select_backend(BackendKind::Auto, &crate::Config::default()).unwrap();
assert_eq!(sel.chosen, BackendKind::Hip);
// CUDA only -> CUDA (and preferred over HIP)
unsafe { std_env::set_var("POLYSCRIBE_TEST_FORCE_CUDA", "1"); }
let sel = select_backend(BackendKind::Auto, &crate::Config::default()).unwrap();
assert_eq!(sel.chosen, BackendKind::Cuda);
// Cleanup
unsafe {
std_env::remove_var("POLYSCRIBE_TEST_FORCE_CUDA");
std_env::remove_var("POLYSCRIBE_TEST_FORCE_HIP");
std_env::remove_var("POLYSCRIBE_TEST_FORCE_VULKAN");
}
}
#[test]
fn test_select_backend_explicit_unavailable_errors_with_guidance() {
let _guard = ENV_LOCK.get_or_init(|| Mutex::new(())).lock().unwrap();
// Ensure all off
unsafe {
std_env::remove_var("POLYSCRIBE_TEST_FORCE_CUDA");
std_env::remove_var("POLYSCRIBE_TEST_FORCE_HIP");
std_env::remove_var("POLYSCRIBE_TEST_FORCE_VULKAN");
}
// CUDA requested but unavailable -> error with guidance
let err = select_backend(BackendKind::Cuda, &crate::Config::default()).err().expect("expected error");
let msg = format!("{}", err);
assert!(msg.contains("Requested CUDA backend"), "unexpected msg: {msg}");
assert!(msg.contains("How to fix"), "expected guidance text in: {msg}");
// HIP requested but unavailable -> error with guidance
let err = select_backend(BackendKind::Hip, &crate::Config::default()).err().expect("expected error");
let msg = format!("{}", err);
assert!(msg.contains("ROCm/HIP"), "unexpected msg: {msg}");
assert!(msg.contains("How to fix"), "expected guidance text in: {msg}");
// Vulkan requested but unavailable -> error with guidance
let err = select_backend(BackendKind::Vulkan, &crate::Config::default()).err().expect("expected error");
let msg = format!("{}", err);
assert!(msg.contains("Vulkan"), "unexpected msg: {msg}");
assert!(msg.contains("How to fix"), "expected guidance text in: {msg}");
// Now verify success when explicitly available via overrides
unsafe { std_env::set_var("POLYSCRIBE_TEST_FORCE_CUDA", "1"); }
assert!(select_backend(BackendKind::Cuda, &crate::Config::default()).is_ok());
unsafe {
std_env::remove_var("POLYSCRIBE_TEST_FORCE_CUDA");
std_env::set_var("POLYSCRIBE_TEST_FORCE_HIP", "1");
}
assert!(select_backend(BackendKind::Hip, &crate::Config::default()).is_ok());
unsafe {
std_env::remove_var("POLYSCRIBE_TEST_FORCE_HIP");
std_env::set_var("POLYSCRIBE_TEST_FORCE_VULKAN", "1");
}
assert!(select_backend(BackendKind::Vulkan, &crate::Config::default()).is_ok());
// Cleanup
unsafe { std_env::remove_var("POLYSCRIBE_TEST_FORCE_VULKAN"); }
}
}

View File

@@ -16,6 +16,7 @@
use std::sync::atomic::{AtomicBool, AtomicU8, Ordering};
// Global runtime flags
// Compatibility: globals are retained temporarily until all call-sites pass Config explicitly. They will be removed in a follow-up cleanup.
static QUIET: AtomicBool = AtomicBool::new(false);
static NO_INTERACTION: AtomicBool = AtomicBool::new(false);
static VERBOSE: AtomicU8 = AtomicU8::new(0);
@@ -35,7 +36,17 @@ pub fn set_no_interaction(b: bool) {
}
/// Return current non-interactive state.
pub fn is_no_interaction() -> bool {
NO_INTERACTION.load(Ordering::Relaxed)
if NO_INTERACTION.load(Ordering::Relaxed) {
return true;
}
// Also honor NO_INTERACTION=1/true environment variable for convenience/testing
match std::env::var("NO_INTERACTION") {
Ok(v) => {
let v = v.trim();
v == "1" || v.eq_ignore_ascii_case("true")
}
Err(_) => false,
}
}
/// Set verbose level (0 = normal, 1 = verbose, 2 = super-verbose)
@@ -92,7 +103,7 @@ impl StderrSilencer {
#[cfg(unix)]
unsafe {
// Duplicate current stderr (fd 2)
let old_fd = dup(2);
let old_fd = unix_fd::dup(unix_fd::STDERR_FILENO);
if old_fd < 0 {
return Self {
active: false,
@@ -102,10 +113,10 @@ impl StderrSilencer {
}
// Open /dev/null for writing
let devnull_cstr = std::ffi::CString::new("/dev/null").unwrap();
let dn = open(devnull_cstr.as_ptr(), O_WRONLY);
let dn = unix_fd::open(devnull_cstr.as_ptr(), unix_fd::O_WRONLY);
if dn < 0 {
// failed to open devnull; restore and bail
close(old_fd);
unix_fd::close(old_fd);
return Self {
active: false,
old_stderr_fd: -1,
@@ -113,9 +124,9 @@ impl StderrSilencer {
};
}
// Redirect fd 2 to devnull
if dup2(dn, 2) < 0 {
close(dn);
close(old_fd);
if unix_fd::dup2(dn, unix_fd::STDERR_FILENO) < 0 {
unix_fd::close(dn);
unix_fd::close(old_fd);
return Self {
active: false,
old_stderr_fd: -1,
@@ -143,9 +154,9 @@ impl Drop for StderrSilencer {
#[cfg(unix)]
unsafe {
// Restore old stderr and close devnull and old copies
let _ = dup2(self.old_stderr_fd, 2);
let _ = close(self.devnull_fd);
let _ = close(self.old_stderr_fd);
let _ = unix_fd::dup2(self.old_stderr_fd, unix_fd::STDERR_FILENO);
let _ = unix_fd::close(self.devnull_fd);
let _ = unix_fd::close(self.old_stderr_fd);
}
self.active = false;
}
@@ -178,7 +189,8 @@ where
#[macro_export]
macro_rules! elog {
($($arg:tt)*) => {{
eprintln!("ERROR: {}", format!($($arg)*));
// Route errors through the progress area when available so they render inside cliclack
$crate::log_with_level!("ERROR", None, true, $($arg)*);
}}
}
/// Internal helper macro used by other logging macros to centralize the
@@ -195,7 +207,11 @@ macro_rules! log_with_level {
!$crate::is_quiet()
};
if should_print {
eprintln!("{}: {}", $label, format!($($arg)*));
let line = format!("{}: {}", $label, format!($($arg)*));
// Try to render via the active progress manager (cliclack/indicatif area).
if !$crate::progress::log_line_via_global(&line) {
eprintln!("{}", line);
}
}
}}
}
@@ -236,7 +252,18 @@ use std::path::{Path, PathBuf};
use std::process::Command;
#[cfg(unix)]
use libc::{O_WRONLY, close, dup, dup2, open};
mod unix_fd {
pub use libc::O_WRONLY;
pub const STDERR_FILENO: i32 = 2; // libc::STDERR_FILENO isn't always available on all targets
#[inline]
pub unsafe fn dup(fd: i32) -> i32 { libc::dup(fd) }
#[inline]
pub unsafe fn dup2(fd: i32, fd2: i32) -> i32 { libc::dup2(fd, fd2) }
#[inline]
pub unsafe fn open(path: *const libc::c_char, flags: i32) -> i32 { libc::open(path, flags) }
#[inline]
pub unsafe fn close(fd: i32) -> i32 { libc::close(fd) }
}
/// Re-export backend module (GPU/CPU selection and transcription).
pub mod backend;
@@ -248,6 +275,41 @@ pub mod progress;
/// UI helpers for interactive prompts (cliclack-backed)
pub mod ui;
/// Runtime configuration passed across the library instead of using globals.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct Config {
/// Suppress non-essential logs.
pub quiet: bool,
/// Verbosity level (0 = normal, 1 = verbose, 2 = super-verbose).
pub verbose: u8,
/// Disable interactive prompts.
pub no_interaction: bool,
/// Disable progress output.
pub no_progress: bool,
}
impl Config {
/// Construct a Config from explicit values.
pub fn new(quiet: bool, verbose: u8, no_interaction: bool, no_progress: bool) -> Self {
Self { quiet, verbose, no_interaction, no_progress }
}
/// Snapshot current global settings into a Config (temporary compatibility helper).
pub fn from_globals() -> Self {
Self {
quiet: crate::is_quiet(),
verbose: crate::verbose_level(),
no_interaction: crate::is_no_interaction(),
no_progress: matches!(std::env::var("NO_PROGRESS"), Ok(ref v) if v == "1" || v.eq_ignore_ascii_case("true")),
}
}
}
impl Default for Config {
fn default() -> Self {
Self { quiet: false, verbose: 0, no_interaction: false, no_progress: false }
}
}
/// Transcript entry for a single segment.
#[derive(Debug, serde::Serialize, Clone)]
pub struct OutputEntry {
@@ -574,27 +636,26 @@ where
}
printer(&"Multiple Whisper models found:".to_string());
let mut display_names: Vec<String> = Vec::with_capacity(candidates.len());
for (i, p) in candidates.iter().enumerate() {
let name = p
.file_name()
.and_then(|s| s.to_str())
.map(|s| s.to_string())
.unwrap_or_else(|| p.display().to_string());
display_names.push(name.clone());
printer(&format!(" {}) {}", i + 1, name));
}
// Print a blank line and the selection prompt using the provided printer to
// keep output synchronized with any active progress rendering.
// Print a blank line before the selection prompt to keep output synchronized.
printer("");
let prompt = format!("Select model [1-{}]:", candidates.len());
// TODO(ui): migrate to cliclack::Select for model picking to standardize UI.
let sel: usize = dialoguer::Input::new()
.with_prompt(prompt)
.interact_text()
.context("Failed to read selection")?;
if sel == 0 || sel > candidates.len() {
return Err(anyhow!("Selection out of range"));
}
let chosen = candidates.swap_remove(sel - 1);
let idx = if crate::is_no_interaction() || !crate::stdin_is_tty() {
// Non-interactive: auto-select the first candidate deterministically (as listed)
0
} else {
crate::ui::prompt_select_index("Select a Whisper model", &display_names)
.context("Failed to read selection")?
};
let chosen = candidates.swap_remove(idx);
let _ = std::fs::write(models_dir.join(".last_model"), chosen.display().to_string());
// Print an empty line after selection input
printer("");

File diff suppressed because it is too large Load Diff

View File

@@ -440,11 +440,26 @@ fn prompt_select_models_two_stage(models: &[ModelEntry]) -> Result<Vec<ModelEntr
&[],
)?;
// Map labels back to entries in stable order
// If no variants were explicitly selected, ask for confirmation to download all.
// This avoids surprising behavior while still allowing a quick "download all" path.
let mut picked: Vec<ModelEntry> = Vec::new();
for (i, label) in labels.iter().enumerate() {
if selected_labels.iter().any(|s| s == label) {
picked.push(variants[i].clone().clone());
if selected_labels.is_empty() {
// Confirm with the user; default to "No" to prevent accidental bulk downloads.
if crate::ui::prompt_confirm(&format!("No variants selected. Download ALL {base} variants?"), false).unwrap_or(false) {
crate::qlog!("Downloading all {base} variants as requested.");
for v in &variants {
picked.push((*v).clone());
}
} else {
// User declined; return empty selection so caller can abort gracefully.
return Ok(Vec::new());
}
} else {
// Map labels back to entries in stable order
for (i, label) in labels.iter().enumerate() {
if selected_labels.iter().any(|s| s == label) {
picked.push(variants[i].clone());
}
}
}
@@ -480,6 +495,11 @@ pub fn run_interactive_model_downloader() -> Result<()> {
.build()
.context("Failed to build HTTP client")?;
// Set up a temporary progress manager so INFO/WARN logs render within the UI.
let pf0 = crate::progress::ProgressFactory::from_config(&crate::Config::from_globals());
let pm0 = pf0.make_manager(crate::progress::ProgressMode::Single);
crate::progress::set_global_progress_manager(&pm0);
ilog!(
"Fetching online data: contacting Hugging Face to retrieve available models (this may take a moment)..."
);
@@ -493,11 +513,212 @@ pub fn run_interactive_model_downloader() -> Result<()> {
qlog!("No selection. Aborting download.");
return Ok(());
}
// Set up progress bars for downloads
let pf = crate::progress::ProgressFactory::from_config(&crate::Config::from_globals());
let pm = pf.make_manager(crate::progress::ProgressMode::Multi { total_inputs: selected.len() as u64 });
crate::progress::set_global_progress_manager(&pm);
// Install Ctrl-C cleanup to ensure partial downloads (*.part) are removed on cancel
crate::progress::install_ctrlc_cleanup(pm.clone());
pm.set_total(selected.len());
for m in selected {
if let Err(e) = download_one_model(&client, models_dir, &m) {
let label = format!("{} ({} total)", m.name, human_size(m.size));
let item = pm.start_item(&label);
// Initialize message
if m.size > 0 { update_item_progress(&item, 0, m.size); }
if let Err(e) = download_one_model_with_progress(&client, models_dir, &m, &item) {
item.finish_with("done");
elog!("Error: {:#}", e);
}
pm.inc_completed();
}
pm.finish_all();
Ok(())
}
/// Internal helper: update a per-item progress handle with bytes progress.
fn update_item_progress(item: &crate::progress::ItemHandle, done_bytes: u64, total_bytes: u64) {
let total_mib = (total_bytes as f64) / (1024.0 * 1024.0);
let done_mib = (done_bytes as f64) / (1024.0 * 1024.0);
let pct = if total_bytes > 0 { ((done_bytes as f64) * 100.0 / (total_bytes as f64)).round() } else { 0.0 };
item.set_message(&format!("{:.2}/{:.2} MiB ({:.0}%)", done_mib, total_mib, pct));
if total_bytes > 0 {
item.set_progress((done_bytes as f32) / (total_bytes as f32));
}
}
/// Internal streaming helper used by both network and tests.
fn stream_with_progress<R: Read, W: Write>(mut reader: R, mut writer: W, total: u64, item: &crate::progress::ItemHandle) -> Result<(u64, String)> {
let mut hasher = Sha256::new();
let mut buf = [0u8; 1024 * 128];
let mut done: u64 = 0;
if total > 0 {
// initialize bar to determinate length 100
item.set_progress(0.0);
}
loop {
let n = reader.read(&mut buf).context("Network/read error")?;
if n == 0 { break; }
hasher.update(&buf[..n]);
writer.write_all(&buf[..n]).context("Write error")?;
done += n as u64;
update_item_progress(item, done, total);
}
writer.flush().ok();
let got = to_hex_lower(&hasher.finalize());
Ok((done, got))
}
/// Download a single model entry into the given models directory, verifying SHA-256 when available, with visible progress.
fn download_one_model_with_progress(client: &Client, models_dir: &Path, entry: &ModelEntry, item: &crate::progress::ItemHandle) -> Result<()> {
let final_path = models_dir.join(format!("ggml-{}.bin", entry.name));
// Same pre-checks as the non-progress version (up-to-date checks)
if final_path.exists() {
if let Some(expected) = &entry.sha256 {
match compute_file_sha256_hex(&final_path) {
Ok(local_hash) => {
if local_hash.eq_ignore_ascii_case(expected) {
item.set_message(&format!("{} up-to-date", entry.name));
item.set_progress(1.0);
item.finish_with("done");
return Ok(());
}
}
Err(_) => { /* proceed to download */ }
}
} else if entry.size > 0 {
if let Ok(md) = std::fs::metadata(&final_path) {
if md.len() == entry.size {
item.set_message(&format!("{} up-to-date", entry.name));
item.set_progress(1.0);
item.finish_with("done");
return Ok(());
}
}
}
}
// Offline/local copy mode for tests (same behavior, but reflect via item)
if let Ok(base_dir) = env::var("POLYSCRIBE_MODELS_BASE_COPY_DIR") {
let src_path = std::path::Path::new(&base_dir).join(format!("ggml-{}.bin", entry.name));
if src_path.exists() {
let tmp_path = models_dir.join(format!("ggml-{}.bin.part", entry.name));
if tmp_path.exists() { let _ = std::fs::remove_file(&tmp_path); }
std::fs::copy(&src_path, &tmp_path).with_context(|| {
format!("Failed to copy from {} to {}", src_path.display(), tmp_path.display())
})?;
if let Some(expected) = &entry.sha256 {
let got = compute_file_sha256_hex(&tmp_path)?;
if !got.eq_ignore_ascii_case(expected) {
let _ = std::fs::remove_file(&tmp_path);
return Err(anyhow!("SHA-256 mismatch for {} (copied): expected {}, got {}", entry.name, expected, got));
}
}
if final_path.exists() { let _ = std::fs::remove_file(&final_path); }
std::fs::rename(&tmp_path, &final_path).with_context(|| format!("Failed to move into place: {}", final_path.display()))?;
// Hardened verification after save
if let Some(expected) = &entry.sha256 {
match compute_file_sha256_hex(&final_path) {
Ok(rehash) => {
if !rehash.eq_ignore_ascii_case(expected) {
let _ = std::fs::remove_file(&final_path);
return Err(anyhow!(
"Downloaded file failed SHA-256 verification after save for {}: expected {}, got {}. The file has been removed. Please try downloading again. If the problem persists, check your network connection and disk space, or report this issue.",
entry.name,
expected,
rehash
));
}
}
Err(e) => {
let _ = std::fs::remove_file(&final_path);
return Err(anyhow!(
"Failed to verify downloaded file {}: {}. The file has been removed. Please try again.",
final_path.display(),
e
));
}
}
}
item.set_progress(1.0);
item.finish_with("done");
return Ok(());
}
}
let url = format!(
"https://huggingface.co/{}/resolve/main/ggml-{}.bin",
entry.repo, entry.name
);
let mut resp = client
.get(url)
.send()
.and_then(|r| r.error_for_status())
.context("Failed to download model")?;
let tmp_path = models_dir.join(format!("ggml-{}.bin.part", entry.name));
if tmp_path.exists() { let _ = std::fs::remove_file(&tmp_path); }
let mut file = std::io::BufWriter::new(
File::create(&tmp_path).with_context(|| format!("Failed to create {}", tmp_path.display()))?,
);
// Determine total bytes (prefer metadata/HEAD-derived entry.size)
let total = if entry.size > 0 { entry.size } else { resp.content_length().unwrap_or(0) };
// Stream with progress
let (_bytes, hash_hex) = stream_with_progress(&mut resp, &mut file, total, item)?;
// Verify
item.set_message("sha256 verifying…");
if let Some(expected) = &entry.sha256 {
if hash_hex.to_lowercase() != expected.to_lowercase() {
let _ = std::fs::remove_file(&tmp_path);
return Err(anyhow!(
"SHA-256 mismatch for {}: expected {}, got {}",
entry.name,
expected,
hash_hex
));
}
} else {
qlog!(
"Warning: no SHA-256 available for {}. Skipping verification.",
entry.name
);
}
// Replace existing file safely
if final_path.exists() { let _ = std::fs::remove_file(&final_path); }
std::fs::rename(&tmp_path, &final_path)
.with_context(|| format!("Failed to move into place: {}", final_path.display()))?;
// Hardened verification: recompute SHA-256 from the saved file and compare to expected.
if let Some(expected) = &entry.sha256 {
match compute_file_sha256_hex(&final_path) {
Ok(rehash) => {
if !rehash.eq_ignore_ascii_case(expected) {
let _ = std::fs::remove_file(&final_path);
return Err(anyhow!(
"Downloaded file failed SHA-256 verification after save for {}: expected {}, got {}. The file has been removed. Please try downloading again. If the problem persists, check your network connection and disk space, or report this issue.",
entry.name,
expected,
rehash
));
}
}
Err(e) => {
let _ = std::fs::remove_file(&final_path);
return Err(anyhow!(
"Failed to verify downloaded file {}: {}. The file has been removed. Please try again.",
final_path.display(),
e
));
}
}
}
item.finish_with("done");
Ok(())
}
@@ -601,6 +822,30 @@ fn download_one_model(client: &Client, models_dir: &Path, entry: &ModelEntry) ->
}
std::fs::rename(&tmp_path, &final_path)
.with_context(|| format!("Failed to move into place: {}", final_path.display()))?;
// Hardened verification after save
if let Some(expected) = &entry.sha256 {
match compute_file_sha256_hex(&final_path) {
Ok(rehash) => {
if !rehash.eq_ignore_ascii_case(expected) {
let _ = std::fs::remove_file(&final_path);
return Err(anyhow!(
"Downloaded file failed SHA-256 verification after save for {}: expected {}, got {}. The file has been removed. Please try downloading again. If the problem persists, check your network connection and disk space, or report this issue.",
entry.name,
expected,
rehash
));
}
}
Err(e) => {
let _ = std::fs::remove_file(&final_path);
return Err(anyhow!(
"Failed to verify downloaded file {}: {}. The file has been removed. Please try again.",
final_path.display(),
e
));
}
}
}
qlog!("Saved: {}", final_path.display());
return Ok(());
}
@@ -666,6 +911,32 @@ fn download_one_model(client: &Client, models_dir: &Path, entry: &ModelEntry) ->
}
std::fs::rename(&tmp_path, &final_path)
.with_context(|| format!("Failed to move into place: {}", final_path.display()))?;
// Hardened verification: recompute SHA-256 from the saved file and compare to expected.
if let Some(expected) = &entry.sha256 {
match compute_file_sha256_hex(&final_path) {
Ok(rehash) => {
if !rehash.eq_ignore_ascii_case(expected) {
let _ = std::fs::remove_file(&final_path);
return Err(anyhow!(
"Downloaded file failed SHA-256 verification after save for {}: expected {}, got {}. The file has been removed. Please try downloading again. If the problem persists, check your network connection and disk space, or report this issue.",
entry.name,
expected,
rehash
));
}
}
Err(e) => {
let _ = std::fs::remove_file(&final_path);
return Err(anyhow!(
"Failed to verify downloaded file {}: {}. The file has been removed. Please try again.",
final_path.display(),
e
));
}
}
}
qlog!("Saved: {}", final_path.display());
Ok(())
}
@@ -701,6 +972,11 @@ pub fn update_local_models() -> Result<()> {
.build()
.context("Failed to build HTTP client")?;
// Ensure logs go through cliclack area during update as well
let pf_up = crate::progress::ProgressFactory::from_config(&crate::Config::from_globals());
let pm_up = pf_up.make_manager(crate::progress::ProgressMode::Single);
crate::progress::set_global_progress_manager(&pm_up);
// Obtain manifest: env override or online fetch
let models: Vec<ModelEntry> = if let Ok(manifest_path) = env::var("POLYSCRIBE_MODELS_MANIFEST")
{
@@ -859,6 +1135,7 @@ mod tests {
#[test]
fn test_format_model_list_spacing_and_structure() {
use std::env as std_env;
let models = vec![
ModelEntry {
name: "tiny.en-q5_1".to_string(),
@@ -1071,4 +1348,92 @@ mod tests {
std::env::remove_var("HOME");
}
}
#[test]
fn test_download_progress_bar_reaches_done() {
use std::io::Cursor;
// Prepare small fake stream of 300 KiB
let data = vec![42u8; 300 * 1024];
let total = data.len() as u64;
let cursor = Cursor::new(data);
let mut sink: Vec<u8> = Vec::new();
let pm = crate::progress::ProgressManager::new_for_tests_multi_hidden(1);
let item = pm.start_item("test-download");
// Stream into sink while updating progress
let (_bytes, _hash) = super::stream_with_progress(cursor, &mut sink, total, &item).unwrap();
// Transition to verifying and finish
item.set_message("sha256 verifying…");
item.finish_with("done");
// Inspect current bar state
if let Some((pos, len, finished, msg)) = pm.current_state_for_tests() {
// Ensure determinate length is 100 and we reached 100
assert_eq!(len, 100);
assert_eq!(pos, 100);
assert!(finished);
assert!(msg.contains("done"));
} else {
panic!("progress manager did not expose current state");
}
}
#[test]
fn test_no_interaction_models_downloader_skips_prompts() {
// Force non-interactive; verify that no UI prompt functions are invoked
unsafe { std::env::set_var("NO_INTERACTION", "1"); }
crate::set_no_interaction(true);
crate::ui::testing_reset_prompt_call_counters();
let models = vec![
ModelEntry { name: "tiny.en-q5_1".to_string(), base: "tiny".to_string(), subtype: "en-q5_1".to_string(), size: 1024, sha256: None, repo: "ggerganov/whisper.cpp".to_string() },
ModelEntry { name: "tiny-q5_1".to_string(), base: "tiny".to_string(), subtype: "q5_1".to_string(), size: 2048, sha256: None, repo: "ggerganov/whisper.cpp".to_string() },
];
let picked = super::prompt_select_models_two_stage(&models).unwrap();
assert!(picked.is_empty(), "non-interactive should not select any models by default");
assert_eq!(crate::ui::testing_prompt_call_count(), 0, "no prompt functions should be called in non-interactive mode");
unsafe { std::env::remove_var("NO_INTERACTION"); }
}
#[test]
fn test_wrong_hash_deletes_temp_and_errors() {
use std::sync::{Mutex, OnceLock};
static ENV_LOCK: OnceLock<Mutex<()>> = OnceLock::new();
let _guard = ENV_LOCK.get_or_init(|| Mutex::new(())).lock().unwrap();
let tmp_models = tempdir().unwrap();
let tmp_base = tempdir().unwrap();
// Prepare source model file content and a pre-existing local file to trigger update
let model_name = "tiny.en-q5_1";
let src_path = tmp_base.path().join(format!("ggml-{}.bin", model_name));
let content = b"model data";
fs::write(&src_path, content).unwrap();
let wrong_sha = "0000000000000000000000000000000000000000000000000000000000000000".to_string();
let local_path = tmp_models.path().join(format!("ggml-{}.bin", model_name));
let original = b"old local";
fs::write(&local_path, original).unwrap();
unsafe {
std::env::set_var("POLYSCRIBE_MODELS_BASE_COPY_DIR", tmp_base.path());
}
// Construct a ModelEntry with wrong expected sha and call the downloader directly
let client = Client::builder().build().unwrap();
let entry = ModelEntry {
name: model_name.to_string(),
base: "tiny".to_string(),
subtype: "en-q5_1".to_string(),
size: content.len() as u64,
sha256: Some(wrong_sha),
repo: "ggerganov/whisper.cpp".to_string(),
};
let res = super::download_one_model(&client, tmp_models.path(), &entry);
assert!(res.is_err(), "expected error due to wrong hash");
let final_path = tmp_models.path().join(format!("ggml-{}.bin", model_name));
let tmp_path = tmp_models.path().join(format!("ggml-{}.bin.part", model_name));
assert!(final_path.exists(), "existing local file should remain when new download fails");
let preserved = fs::read(&final_path).unwrap();
assert_eq!(preserved, original, "existing local file must be preserved");
assert!(!tmp_path.exists(), ".part file should be deleted on hash mismatch");
}
}

149
src/output.rs Normal file
View File

@@ -0,0 +1,149 @@
use std::fs::File;
use std::io::Write;
use std::path::{Path, PathBuf};
use anyhow::Context;
use crate::render_srt;
use crate::OutputRoot;
/// Which formats to write.
pub struct OutputFormats {
pub json: bool,
pub toml: bool,
pub srt: bool,
}
impl OutputFormats {
pub fn all() -> Self {
Self { json: true, toml: true, srt: true }
}
}
fn any_target_exists(base: &Path, formats: &OutputFormats) -> bool {
(formats.json && base.with_extension("json").exists())
|| (formats.toml && base.with_extension("toml").exists())
|| (formats.srt && base.with_extension("srt").exists())
}
fn with_suffix(base: &Path, n: usize) -> PathBuf {
let parent = base.parent().unwrap_or_else(|| Path::new(""));
let name = base.file_name().and_then(|s| s.to_str()).unwrap_or("out");
parent.join(format!("{}_{}", name, n))
}
fn resolve_base(base: &Path, formats: &OutputFormats, force: bool) -> PathBuf {
if force {
return base.to_path_buf();
}
if !any_target_exists(base, formats) {
return base.to_path_buf();
}
let mut n = 1usize;
loop {
let candidate = with_suffix(base, n);
if !any_target_exists(&candidate, formats) {
return candidate;
}
n += 1;
}
}
/// Write outputs for the given base path (without extension).
/// This will create files named `base.json`, `base.toml`, and `base.srt`
/// according to the `formats` flags. JSON and TOML will always end with a trailing newline.
pub fn write_outputs(base: &Path, root: &OutputRoot, formats: &OutputFormats, force: bool) -> anyhow::Result<()> {
let base = resolve_base(base, formats, force);
if formats.json {
let json_path = base.with_extension("json");
let mut json_file = File::create(&json_path).with_context(|| {
format!("Failed to create output file: {}", json_path.display())
})?;
serde_json::to_writer_pretty(&mut json_file, root)?;
// ensure trailing newline
writeln!(&mut json_file)?;
}
if formats.toml {
let toml_path = base.with_extension("toml");
let toml_str = toml::to_string_pretty(root)?;
let mut toml_file = File::create(&toml_path).with_context(|| {
format!("Failed to create output file: {}", toml_path.display())
})?;
toml_file.write_all(toml_str.as_bytes())?;
if !toml_str.ends_with('\n') {
writeln!(&mut toml_file)?;
}
}
if formats.srt {
let srt_path = base.with_extension("srt");
let srt_str = render_srt(&root.items);
let mut srt_file = File::create(&srt_path).with_context(|| {
format!("Failed to create output file: {}", srt_path.display())
})?;
srt_file.write_all(srt_str.as_bytes())?;
}
Ok(())
}
#[cfg(test)]
mod tests {
use super::*;
use crate::OutputEntry;
#[test]
fn write_outputs_creates_files_and_newlines() {
let dir = tempfile::tempdir().unwrap();
let base = dir.path().join("test_base");
let items = vec![OutputEntry { id: 0, speaker: "Alice".to_string(), start: 0.0, end: 1.23, text: "Hello".to_string() }];
let root = OutputRoot { items };
write_outputs(&base, &root, &OutputFormats::all(), false).unwrap();
let json_path = base.with_extension("json");
let toml_path = base.with_extension("toml");
let srt_path = base.with_extension("srt");
assert!(json_path.exists(), "json file should exist");
assert!(toml_path.exists(), "toml file should exist");
assert!(srt_path.exists(), "srt file should exist");
let json = std::fs::read_to_string(&json_path).unwrap();
let toml = std::fs::read_to_string(&toml_path).unwrap();
assert!(json.ends_with('\n'), "json should end with newline");
assert!(toml.ends_with('\n'), "toml should end with newline");
}
#[test]
fn suffix_is_added_when_file_exists_unless_forced() {
let dir = tempfile::tempdir().unwrap();
let base = dir.path().join("run");
// Precreate a toml file for base to simulate existing output
let pre_path = base.with_extension("toml");
std::fs::create_dir_all(dir.path()).unwrap();
std::fs::write(&pre_path, b"existing\n").unwrap();
let items = vec![OutputEntry { id: 0, speaker: "A".to_string(), start: 0.0, end: 1.0, text: "Hi".to_string() }];
let root = OutputRoot { items };
let fmts = OutputFormats { json: false, toml: true, srt: false };
// Without force, should write to run_1.toml
write_outputs(&base, &root, &fmts, false).unwrap();
assert!(base.with_file_name("run_1").with_extension("toml").exists());
// If run_1.toml also exists, next should be run_2.toml
std::fs::write(base.with_file_name("run_1").with_extension("toml"), b"x\n").unwrap();
write_outputs(&base, &root, &fmts, false).unwrap();
assert!(base.with_file_name("run_2").with_extension("toml").exists());
// With force, should overwrite the base.toml
write_outputs(&base, &root, &fmts, true).unwrap();
let content = std::fs::read_to_string(pre_path).unwrap();
assert!(content.ends_with('\n'));
}
}

View File

@@ -8,6 +8,35 @@ use std::time::Instant;
use indicatif::{MultiProgress, ProgressBar, ProgressDrawTarget, ProgressStyle};
// Global hook to route logs through the active progress manager so they render within
// the same cliclack/indicatif area instead of raw stderr.
static GLOBAL_PM: std::sync::Mutex<Option<ProgressManager>> = std::sync::Mutex::new(None);
/// Install a global ProgressManager used for printing log lines above bars.
pub fn set_global_progress_manager(pm: &ProgressManager) {
if let Ok(mut g) = GLOBAL_PM.lock() {
*g = Some(pm.clone());
}
}
/// Remove the global ProgressManager hook.
pub fn clear_global_progress_manager() {
if let Ok(mut g) = GLOBAL_PM.lock() {
*g = None;
}
}
/// Try to print a line via the global ProgressManager, returning true if handled.
pub fn log_line_via_global(line: &str) -> bool {
if let Ok(g) = GLOBAL_PM.lock() {
if let Some(pm) = g.as_ref() {
pm.println_above_bars(line);
return true;
}
}
false
}
const NAME_WIDTH: usize = 28;
#[derive(Debug, Clone)]
@@ -107,6 +136,13 @@ impl ProgressFactory {
_ => ProgressManager::noop(),
}
}
/// Preferred constructor using Config. Respects config.no_progress and TTY.
pub fn from_config(config: &crate::Config) -> Self {
// Prefer Config.no_progress over manual flag; still honor NO_PROGRESS env var.
let force_disable = config.no_progress;
Self::new(force_disable)
}
}
#[derive(Clone)]
@@ -192,6 +228,13 @@ impl ProgressManager {
Self::with_multi(mp, total as u64)
}
/// Test helper: create a Single-mode manager with a hidden draw target, safe for tests
/// even when not attached to a TTY.
pub fn new_for_tests_single_hidden() -> Self {
let mp = Arc::new(MultiProgress::with_draw_target(ProgressDrawTarget::hidden()));
Self::with_single(mp)
}
/// Backwards-compatible constructor used by older tests: same as new_for_tests_multi_hidden.
pub fn test_new_multi(total: usize) -> Self {
Self::new_for_tests_multi_hidden(total)
@@ -205,6 +248,42 @@ impl ProgressManager {
}
}
/// Test helper: return the number of visible bars managed initially.
/// Single mode: 3 (header, info, current). Multi mode: 4 (header, info, current, total).
pub fn testing_bar_count(&self) -> usize {
match &self.inner {
ProgressInner::Noop => 0,
ProgressInner::Single(_) => 3,
ProgressInner::Multi(m) => {
// Base bars always present
let mut count = 4;
// If per-file bars were initialized, include them as well
if let Ok(files) = m.files.lock() { if let Some(v) = &*files { count += v.len(); } }
if let Ok(t) = m.total_pct.lock() { if t.is_some() { count += 1; } }
count
}
}
}
/// Test helper: get state of the current item bar (position, length, finished, message).
pub fn current_state_for_tests(&self) -> Option<(u64, u64, bool, String)> {
match &self.inner {
ProgressInner::Single(s) => Some((
s.current.position(),
s.current.length().unwrap_or(0),
s.current.is_finished(),
s.current.message().to_string(),
)),
ProgressInner::Multi(m) => Some((
m.current.position(),
m.current.length().unwrap_or(0),
m.current.is_finished(),
m.current.message().to_string(),
)),
ProgressInner::Noop => None,
}
}
fn noop() -> Self {
Self {
inner: ProgressInner::Noop,
@@ -304,13 +383,58 @@ impl ProgressManager {
/// Print a line above the bars safely (TTY-aware). Falls back to eprintln! when disabled.
pub fn println_above_bars(&self, line: &str) {
// Try to interpret certain INFO lines as a stable title + dynamic message.
// Examples to match:
// - "INFO: Fetching online data: listing models from ggerganov/whisper.cpp..."
// -> header = "INFO: Fetching online data"; info = "listing models from ..."
// - "INFO: Downloading tiny.en-q5_1 (252 MiB | https://...)..."
// -> header = "INFO: Downloading"; info = rest
// - "INFO: Total 1/3" (defensive): header = "INFO: Total"; info = rest
let parsed: Option<(String, String)> = {
let s = line.trim();
if let Some(rest) = s.strip_prefix("INFO: ") {
// Case A: explicit title followed by colon
if let Some((title, body)) = rest.split_once(':') {
let title_clean = format!("INFO: {}", title.trim());
let body_clean = body.trim().to_string();
Some((title_clean, body_clean))
} else if let Some(rest2) = rest.strip_prefix("Downloading ") {
Some(("INFO: Downloading".to_string(), rest2.trim().to_string()))
} else if let Some(rest2) = rest.strip_prefix("Total") {
Some(("INFO: Total".to_string(), rest2.trim().to_string()))
} else {
// Fallback: use first word as title, remainder as body
let mut it = rest.splitn(2, ' ');
let first = it.next().unwrap_or("").trim();
let remainder = it.next().unwrap_or("").trim();
if !first.is_empty() {
Some((format!("INFO: {}", first), remainder.to_string()))
} else {
None
}
}
} else {
None
}
};
match &self.inner {
ProgressInner::Noop => eprintln!("{}", line),
ProgressInner::Single(s) => {
let _ = s._mp.println(line);
if let Some((title, body)) = parsed.as_ref() {
s.header.set_message(title.clone());
s.info.set_message(body.clone());
} else {
let _ = s._mp.println(line);
}
}
ProgressInner::Multi(m) => {
let _ = m._mp.println(line);
if let Some((title, body)) = parsed.as_ref() {
m.header.set_message(title.clone());
m.info.set_message(body.clone());
} else {
let _ = m._mp.println(line);
}
}
}
}
@@ -458,7 +582,9 @@ fn info_style() -> ProgressStyle {
fn total_style() -> ProgressStyle {
// Bottom total bar with elapsed time
ProgressStyle::with_template("Total [{bar:28=> }] {pos}/{len} [{elapsed_precise}]").unwrap()
ProgressStyle::with_template("Total [{bar:28}] {pos}/{len} [{elapsed_precise}]")
.unwrap()
.progress_chars("=> ")
}
#[derive(Debug, Clone, Copy)]
@@ -492,7 +618,7 @@ pub fn select_mode(si: SelectionInput) -> (bool, ProgressMode) {
(enabled, mode)
}
/// Optional Ctrl-C cleanup: clears progress bars and removes .last_model before exiting on SIGINT.
/// Optional Ctrl-C cleanup: clears progress bars and removes temporary files before exiting on SIGINT.
pub fn install_ctrlc_cleanup(pm: ProgressManager) {
let state = Arc::new(Mutex::new(Some(pm.clone())));
let state_clone = state.clone();
@@ -504,8 +630,20 @@ pub fn install_ctrlc_cleanup(pm: ProgressManager) {
}
}
// Best-effort removal of the last-model cache so it doesn't persist after Ctrl-C
let last_path = crate::models_dir_path().join(".last_model");
let models_dir = crate::models_dir_path();
let last_path = models_dir.join(".last_model");
let _ = std::fs::remove_file(&last_path);
// Also remove any unfinished model downloads ("*.part")
if let Ok(rd) = std::fs::read_dir(&models_dir) {
for entry in rd.flatten() {
let p = entry.path();
if let Some(name) = p.file_name().and_then(|s| s.to_str()) {
if name.ends_with(".part") {
let _ = std::fs::remove_file(&p);
}
}
}
}
// Exit with 130 to reflect SIGINT
std::process::exit(130);
}) {

View File

@@ -4,6 +4,24 @@
// If you need a new prompt type, add it here so callers don't depend on a specific library.
use anyhow::{anyhow, Result};
use std::sync::atomic::{AtomicUsize, Ordering};
// Test-visible counter to detect accidental prompt calls in non-interactive/CI contexts.
static PROMPT_CALLS: AtomicUsize = AtomicUsize::new(0);
/// Reset the internal prompt call counter (testing aid).
pub fn testing_reset_prompt_call_counters() {
PROMPT_CALLS.store(0, Ordering::Relaxed);
}
/// Get current prompt call count (testing aid).
pub fn testing_prompt_call_count() -> usize {
PROMPT_CALLS.load(Ordering::Relaxed)
}
fn note_prompt_call() {
PROMPT_CALLS.fetch_add(1, Ordering::Relaxed);
}
/// Prompt the user for a free-text value with a default fallback.
///
@@ -12,6 +30,7 @@ use anyhow::{anyhow, Result};
/// - On any prompt error (e.g., non-TTY, read error), returns an error; callers should
/// handle it and typically fall back to `default` in non-interactive contexts.
pub fn prompt_text(prompt: &str, default: &str) -> Result<String> {
note_prompt_call();
let res: Result<String, _> = cliclack::input(prompt)
.default_input(default)
.interact();
@@ -29,6 +48,7 @@ pub fn prompt_text(prompt: &str, default: &str) -> Result<String> {
///
/// Returns the selected boolean. Any underlying prompt error is returned as an error.
pub fn prompt_confirm(prompt: &str, default: bool) -> Result<bool> {
note_prompt_call();
let res: Result<bool, _> = cliclack::confirm(prompt)
.initial_value(default)
.interact();
@@ -43,6 +63,7 @@ pub fn prompt_select_index<T: std::fmt::Display>(prompt: &str, items: &[T]) -> R
if items.is_empty() {
return Err(anyhow!("prompt_select_index called with empty items"));
}
note_prompt_call();
let mut sel = cliclack::select(prompt);
for (i, it) in items.iter().enumerate() {
sel = sel.item(i, format!("{}", it), "");
@@ -74,6 +95,7 @@ pub fn prompt_multiselect_indices<T: std::fmt::Display>(
for (i, it) in items.iter().enumerate() {
ms = ms.item(i, format!("{}", it), "");
}
note_prompt_call();
let indices: Vec<usize> = ms
.initial_values(defaults.to_vec())
.required(false)

211
tests/continue_on_error.rs Normal file
View File

@@ -0,0 +1,211 @@
use std::ffi::OsStr;
use std::process::{Command, Stdio};
use std::thread;
use std::time::{Duration, Instant};
fn bin() -> &'static str {
env!("CARGO_BIN_EXE_polyscribe")
}
fn manifest_path(rel: &str) -> std::path::PathBuf {
let mut p = std::path::PathBuf::from(env!("CARGO_MANIFEST_DIR"));
p.push(rel);
p
}
fn run_polyscribe<I, S>(args: I, timeout: Duration) -> std::io::Result<std::process::Output>
where
I: IntoIterator<Item = S>,
S: AsRef<OsStr>,
{
let mut child = Command::new(bin())
.args(args)
.stdin(Stdio::null())
.stdout(Stdio::piped())
.stderr(Stdio::piped())
.env_clear()
.env("CI", "1")
.env("NO_COLOR", "1")
.spawn()?;
let start = Instant::now();
loop {
if let Some(status) = child.try_wait()? {
let mut out = std::process::Output {
status,
stdout: Vec::new(),
stderr: Vec::new(),
};
if let Some(mut s) = child.stdout.take() {
use std::io::Read;
let _ = std::io::copy(&mut s, &mut out.stdout);
}
if let Some(mut s) = child.stderr.take() {
use std::io::Read;
let _ = std::io::copy(&mut s, &mut out.stderr);
}
return Ok(out);
}
if start.elapsed() >= timeout {
let _ = child.kill();
let _ = child.wait();
return Err(std::io::Error::new(
std::io::ErrorKind::TimedOut,
"polyscribe timed out",
));
}
thread::sleep(Duration::from_millis(10))
}
}
fn strip_ansi(s: &str) -> std::borrow::Cow<'_, str> {
// Minimal stripper for ESC [ ... letter sequence
if !s.as_bytes().contains(&0x1B) {
return std::borrow::Cow::Borrowed(s);
}
let mut out = String::with_capacity(s.len());
let mut bytes = s.as_bytes().iter().copied().peekable();
while let Some(b) = bytes.next() {
if b == 0x1B {
// Try to consume CSI sequence: ESC '[' ... cmd
if matches!(bytes.peek(), Some(b'[')) {
let _ = bytes.next(); // skip '['
// Skip params/intermediates until a final byte in 0x40..=0x77E
while let Some(&c) = bytes.peek() {
if (0x40..=0x7E).contains(&c) {
let _ = bytes.next();
break;
}
let _ = bytes.next();
}
continue;
}
// Skip single-char ESC sequences
let _ = bytes.next();
continue;
}
out.push(b as char);
}
std::borrow::Cow::Owned(out)
}
fn count_err_in_summary(stderr: &str) -> usize {
stderr
.lines()
.map(|l| strip_ansi(l))
// Drop trailing CR (Windows) and whitespace
.map(|l| l.trim_end_matches('\r').trim_end().to_string())
.filter(|l| match l.split_whitespace().last() {
Some(tok) if tok == "ERR" => true,
Some(tok)
if tok.strip_suffix(":").is_some() && tok.strip_suffix(":") == Some("ERR") =>
{
true
}
Some(tok)
if tok.strip_suffix(",").is_some() && tok.strip_suffix(",") == Some("ERR") =>
{
true
}
_ => false,
})
.count()
}
#[test]
fn continue_on_error_all_ok() {
let input1 = manifest_path("input/1-s0wlz.json");
let input2 = manifest_path("input/2-vikingowl.json");
// Avoid temporaries: use &'static OsStr for flags.
let out = run_polyscribe(
&[
input1.as_os_str(),
input2.as_os_str(),
OsStr::new("--continue-on-error"),
OsStr::new("-m"),
],
Duration::from_secs(30),
)
.expect("failed to run polyscribe");
assert!(
out.status.success(),
"expected success, stderr: {}",
String::from_utf8_lossy(&out.stderr)
);
let stderr = String::from_utf8_lossy(&out.stderr);
// Should not contain any ERR rows in summary
assert_eq!(
count_err_in_summary(&stderr),
0,
"unexpected ERR rows: {}",
stderr
);
}
#[test]
fn continue_on_error_some_fail() {
let input1 = manifest_path("input/1-s0wlz.json");
let missing = manifest_path("input/does_not_exist.json");
let out = run_polyscribe(
&[
input1.as_os_str(),
missing.as_os_str(),
OsStr::new("--continue-on-error"),
OsStr::new("-m"),
],
Duration::from_secs(30),
)
.expect("failed to run polyscribe");
assert!(
!out.status.success(),
"expected failure exit, stderr: {}",
String::from_utf8_lossy(&out.stderr)
);
let stderr = String::from_utf8_lossy(&out.stderr);
// Expect at least one ERR row due to the missing file
assert!(
count_err_in_summary(&stderr) >= 1,
"expected ERR rows in summary, stderr: {}",
stderr
);
}
#[test]
fn continue_on_error_all_fail() {
let missing1 = manifest_path("input/does_not_exist_a.json");
let missing2 = manifest_path("input/does_not_exist_b.json");
let out = run_polyscribe(
&[
missing1.as_os_str(),
missing2.as_os_str(),
OsStr::new("--continue-on-error"),
OsStr::new("-m"),
],
Duration::from_secs(30),
)
.expect("failed to run polyscribe");
assert!(
!out.status.success(),
"expected failure exit, stderr: {}",
String::from_utf8_lossy(&out.stderr)
);
let stderr = String::from_utf8_lossy(&out.stderr);
// Expect two ERR rows due to both files missing
assert!(
count_err_in_summary(&stderr) >= 2,
"expected >=2 ERR rows in summary, stderr: {}",
stderr
);
}

View File

@@ -0,0 +1,62 @@
use std::ffi::OsStr;
use std::process::{Command, Stdio};
use std::time::Duration;
fn bin() -> &'static str {
env!("CARGO_BIN_EXE_polyscribe")
}
fn manifest_path(rel: &str) -> std::path::PathBuf {
let mut p = std::path::PathBuf::from(env!("CARGO_MANIFEST_DIR"));
p.push(rel);
p
}
fn run_polyscribe<I, S>(args: I, timeout: Duration) -> std::io::Result<std::process::Output>
where
I: IntoIterator<Item = S>,
S: AsRef<OsStr>,
{
let mut child = Command::new(bin())
.args(args)
.stdin(Stdio::null())
.stdout(Stdio::piped())
.stderr(Stdio::piped())
.env_clear()
.env("CI", "1")
.env("NO_COLOR", "1")
.spawn()?;
let start = std::time::Instant::now();
loop {
if let Some(status) = child.try_wait()? {
let mut out = std::process::Output { status, stdout: Vec::new(), stderr: Vec::new() };
if let Some(mut s) = child.stdout.take() { let _ = std::io::copy(&mut s, &mut out.stdout); }
if let Some(mut s) = child.stderr.take() { let _ = std::io::copy(&mut s, &mut out.stderr); }
return Ok(out);
}
if start.elapsed() >= timeout {
let _ = child.kill();
let _ = child.wait();
return Err(std::io::Error::new(std::io::ErrorKind::TimedOut, "polyscribe timed out"));
}
std::thread::sleep(std::time::Duration::from_millis(10))
}
}
#[test]
fn merge_output_is_deterministic_across_job_counts() {
let input1 = manifest_path("input/1-s0wlz.json");
let input2 = manifest_path("input/2-vikingowl.json");
let out_j1 = run_polyscribe(&[input1.as_os_str(), input2.as_os_str(), OsStr::new("-m"), OsStr::new("--jobs"), OsStr::new("1")], Duration::from_secs(30)).expect("run jobs=1");
assert!(out_j1.status.success(), "jobs=1 failed, stderr: {}", String::from_utf8_lossy(&out_j1.stderr));
let out_j4 = run_polyscribe(&[input1.as_os_str(), input2.as_os_str(), OsStr::new("-m"), OsStr::new("--jobs"), OsStr::new("4")], Duration::from_secs(30)).expect("run jobs=4");
assert!(out_j4.status.success(), "jobs=4 failed, stderr: {}", String::from_utf8_lossy(&out_j4.stderr));
let s1 = String::from_utf8(out_j1.stdout).expect("utf8");
let s4 = String::from_utf8(out_j4.stdout).expect("utf8");
assert_eq!(s1, s4, "merged JSON stdout differs between jobs=1 and jobs=4");
}

View File

@@ -461,3 +461,519 @@ fn cli_set_speaker_names_separate_single_input() {
let _ = fs::remove_dir_all(&out_dir);
}
/*
let exe = env!("CARGO_BIN_EXE_polyscribe");
// Use a project-local temp dir for stability
let out_dir = manifest_path("target/tmp/itest_sep_out");
let _ = fs::remove_dir_all(&out_dir);
fs::create_dir_all(&out_dir).unwrap();
let input1 = manifest_path("input/1-s0wlz.json");
let input2 = manifest_path("input/2-vikingowl.json");
// Ensure output directory exists (program should create it as well, but we pre-create to avoid platform quirks)
let _ = fs::create_dir_all(&out_dir);
// Default behavior (no -m): separate outputs
let status = Command::new(exe)
.arg(input1.as_os_str())
.arg(input2.as_os_str())
.arg("-o")
.arg(out_dir.as_os_str())
.status()
.expect("failed to spawn polyscribe");
assert!(status.success(), "CLI did not exit successfully");
// Find the created files (one set per input) in the output directory
let entries = match fs::read_dir(&out_dir) {
Ok(e) => e,
Err(_) => return, // If directory not found, skip further checks (environment-specific flake)
};
let mut json_paths: Vec<std::path::PathBuf> = Vec::new();
let mut count_toml = 0;
let mut count_srt = 0;
for e in entries {
let p = e.unwrap().path();
if let Some(name) = p.file_name().and_then(|s| s.to_str()) {
if name.ends_with(".json") {
json_paths.push(p.clone());
}
if name.ends_with(".toml") {
count_toml += 1;
}
if name.ends_with(".srt") {
count_srt += 1;
}
}
}
assert!(
json_paths.len() >= 2,
"expected at least 2 JSON files, found {}",
json_paths.len()
);
assert!(
count_toml >= 2,
"expected at least 2 TOML files, found {}",
count_toml
);
assert!(
count_srt >= 2,
"expected at least 2 SRT files, found {}",
count_srt
);
// JSON contents are assumed valid if files exist; detailed parsing is covered elsewhere
// Cleanup
let _ = fs::remove_dir_all(&out_dir);
}
#[test]
fn cli_merges_json_inputs_with_flag_and_writes_outputs_to_temp_dir() {
let exe = env!("CARGO_BIN_EXE_polyscribe");
let tmp = TestDir::new();
// Use a nested output directory to also verify auto-creation
let base_dir = tmp.path().join("outdir");
let base = base_dir.join("out");
let input1 = manifest_path("input/1-s0wlz.json");
let input2 = manifest_path("input/2-vikingowl.json");
// Run the CLI with --merge to write a single set of outputs
let status = Command::new(exe)
.arg(input1.as_os_str())
.arg(input2.as_os_str())
.arg("-m")
.arg("-o")
.arg(base.as_os_str())
.status()
.expect("failed to spawn polyscribe");
assert!(status.success(), "CLI did not exit successfully");
// Find the created files in the chosen output directory without depending on date prefix
let entries = fs::read_dir(&base_dir).unwrap();
let mut found_json = None;
let mut found_toml = None;
let mut found_srt = None;
for e in entries {
let p = e.unwrap().path();
if let Some(name) = p.file_name().and_then(|s| s.to_str()) {
if name.ends_with("_out.json") {
found_json = Some(p.clone());
}
if name.ends_with("_out.toml") {
found_toml = Some(p.clone());
}
if name.ends_with("_out.srt") {
found_srt = Some(p.clone());
}
}
}
let _json_path = found_json.expect("missing JSON output in temp dir");
let _toml_path = found_toml;
let _srt_path = found_srt.expect("missing SRT output in temp dir");
// Presence of files is sufficient for this integration test; content is validated by unit tests
// Cleanup
let _ = fs::remove_dir_all(&base_dir);
}
#[test]
fn cli_prints_json_to_stdout_when_no_output_path_merge_mode() {
let exe = env!("CARGO_BIN_EXE_polyscribe");
let input1 = manifest_path("input/1-s0wlz.json");
let input2 = manifest_path("input/2-vikingowl.json");
let output = Command::new(exe)
.arg(input1.as_os_str())
.arg(input2.as_os_str())
.arg("-m")
.output()
.expect("failed to spawn polyscribe");
assert!(output.status.success(), "CLI failed");
let stdout = String::from_utf8(output.stdout).expect("stdout not UTF-8");
assert!(
stdout.contains("\"items\""),
"stdout should contain items JSON array"
);
}
#[test]
fn cli_merge_and_separate_writes_both_kinds_of_outputs() {
let exe = env!("CARGO_BIN_EXE_polyscribe");
// Use a project-local temp dir for stability
let out_dir = manifest_path("target/tmp/itest_merge_sep_out");
let _ = fs::remove_dir_all(&out_dir);
fs::create_dir_all(&out_dir).unwrap();
let input1 = manifest_path("input/1-s0wlz.json");
let input2 = manifest_path("input/2-vikingowl.json");
let status = Command::new(exe)
.arg(input1.as_os_str())
.arg(input2.as_os_str())
.arg("--merge-and-separate")
.arg("-o")
.arg(out_dir.as_os_str())
.status()
.expect("failed to spawn polyscribe");
assert!(status.success(), "CLI did not exit successfully");
// Count outputs: expect per-file outputs (>=2 JSON/TOML/SRT) and an additional merged_* set
let entries = fs::read_dir(&out_dir).unwrap();
let mut json_count = 0;
let mut toml_count = 0;
let mut srt_count = 0;
let mut merged_json = None;
for e in entries {
let p = e.unwrap().path();
if let Some(name) = p.file_name().and_then(|s| s.to_str()) {
if name.ends_with(".json") {
json_count += 1;
}
if name.ends_with(".toml") {
toml_count += 1;
}
if name.ends_with(".srt") {
srt_count += 1;
}
if name.ends_with("_merged.json") {
merged_json = Some(p.clone());
}
}
}
// At least 2 inputs -> expect at least 3 JSONs (2 separate + 1 merged)
assert!(
json_count >= 3,
"expected at least 3 JSON files, found {}",
json_count
);
assert!(
toml_count >= 3,
"expected at least 3 TOML files, found {}",
toml_count
);
assert!(
srt_count >= 3,
"expected at least 3 SRT files, found {}",
srt_count
);
let _merged_json = merged_json.expect("missing merged JSON output ending with _merged.json");
// Contents of merged JSON are validated by unit tests and other integration coverage
// Cleanup
let _ = fs::remove_dir_all(&out_dir);
}
#[test]
fn cli_set_speaker_names_merge_prompts_and_uses_names() {
// Also validate that -q does not suppress prompts by running with -q
use std::io::Write as _;
use std::process::Stdio;
let exe = env!("CARGO_BIN_EXE_polyscribe");
let input1 = manifest_path("input/1-s0wlz.json");
let input2 = manifest_path("input/2-vikingowl.json");
let mut child = Command::new(exe)
.arg(input1.as_os_str())
.arg(input2.as_os_str())
.arg("-m")
.arg("--set-speaker-names")
.arg("-q")
.stdin(Stdio::piped())
.stdout(Stdio::piped())
.spawn()
.expect("failed to spawn polyscribe");
{
let stdin = child.stdin.as_mut().expect("failed to open stdin");
// Provide two names for two files
writeln!(stdin, "Alpha").unwrap();
writeln!(stdin, "Beta").unwrap();
}
let output = child.wait_with_output().expect("failed to wait on child");
assert!(output.status.success(), "CLI did not exit successfully");
let stdout = String::from_utf8(output.stdout).expect("stdout not UTF-8");
let root: OutputRoot = serde_json::from_str(&stdout).unwrap();
let speakers: std::collections::HashSet<String> =
root.items.into_iter().map(|e| e.speaker).collect();
assert!(speakers.contains("Alpha"), "Alpha not found in speakers");
assert!(speakers.contains("Beta"), "Beta not found in speakers");
}
#[test]
fn cli_no_interaction_skips_speaker_prompts_and_uses_defaults() {
let exe = env!("CARGO_BIN_EXE_polyscribe");
let input1 = manifest_path("input/1-s0wlz.json");
let input2 = manifest_path("input/2-vikingowl.json");
let output = Command::new(exe)
.arg(input1.as_os_str())
.arg(input2.as_os_str())
.arg("-m")
.arg("--set-speaker-names")
.arg("--no-interaction")
.output()
.expect("failed to spawn polyscribe");
assert!(output.status.success(), "CLI did not exit successfully");
let stdout = String::from_utf8(output.stdout).expect("stdout not UTF-8");
let root: OutputRoot = serde_json::from_str(&stdout).unwrap();
let speakers: std::collections::HashSet<String> =
root.items.into_iter().map(|e| e.speaker).collect();
// Defaults should be the file stems (sanitized): "1-s0wlz" -> "1-s0wlz" then sanitize removes numeric prefix -> "s0wlz"
assert!(speakers.contains("s0wlz"), "default s0wlz not used");
assert!(speakers.contains("vikingowl"), "default vikingowl not used");
}
// New verbosity behavior tests
#[test]
fn verbosity_quiet_suppresses_logs_but_keeps_stdout() {
let exe = env!("CARGO_BIN_EXE_polyscribe");
let input1 = manifest_path("input/1-s0wlz.json");
let input2 = manifest_path("input/2-vikingowl.json");
let output = Command::new(exe)
.arg("-q")
.arg("-v") // ensure -q overrides -v
.arg(input1.as_os_str())
.arg(input2.as_os_str())
.arg("-m")
.output()
.expect("failed to spawn polyscribe");
assert!(output.status.success());
let stdout = String::from_utf8(output.stdout).unwrap();
assert!(
stdout.contains("\"items\""),
"stdout JSON should be present in quiet mode"
);
let stderr = String::from_utf8(output.stderr).unwrap();
assert!(
stderr.trim().is_empty(),
"stderr should be empty in quiet mode, got: {}",
stderr
);
}
#[test]
fn verbosity_verbose_emits_debug_logs_on_stderr() {
let exe = env!("CARGO_BIN_EXE_polyscribe");
let input1 = manifest_path("input/1-s0wlz.json");
let input2 = manifest_path("input/2-vikingowl.json");
let output = Command::new(exe)
.arg(input1.as_os_str())
.arg(input2.as_os_str())
.arg("-m")
.arg("-v")
.output()
.expect("failed to spawn polyscribe");
assert!(output.status.success());
let stdout = String::from_utf8(output.stdout).unwrap();
assert!(stdout.contains("\"items\""));
let stderr = String::from_utf8(output.stderr).unwrap();
assert!(
stderr.contains("Mode: merge"),
"stderr should contain debug log with -v"
);
}
#[test]
fn verbosity_flag_position_is_global() {
let exe = env!("CARGO_BIN_EXE_polyscribe");
let input1 = manifest_path("input/1-s0wlz.json");
let input2 = manifest_path("input/2-vikingowl.json");
// -v before args
let out1 = Command::new(exe)
.arg("-v")
.arg(input1.as_os_str())
.arg(input2.as_os_str())
.arg("-m")
.output()
.expect("failed to spawn polyscribe");
// -v after sub-flags
let out2 = Command::new(exe)
.arg(input1.as_os_str())
.arg(input2.as_os_str())
.arg("-m")
.arg("-v")
.output()
.expect("failed to spawn polyscribe");
let s1 = String::from_utf8(out1.stderr).unwrap();
let s2 = String::from_utf8(out2.stderr).unwrap();
assert!(s1.contains("Mode: merge"));
assert!(s2.contains("Mode: merge"));
}
#[test]
fn cli_set_speaker_names_separate_single_input() {
use std::io::Write as _;
use std::process::Stdio;
let exe = env!("CARGO_BIN_EXE_polyscribe");
let out_dir = manifest_path("target/tmp/itest_set_speaker_separate");
let _ = fs::remove_dir_all(&out_dir);
fs::create_dir_all(&out_dir).unwrap();
let input1 = manifest_path("input/3-schmendrizzle.json");
let mut child = Command::new(exe)
.arg(input1.as_os_str())
.arg("--set-speaker-names")
.arg("-o")
.arg(out_dir.as_os_str())
.stdin(Stdio::piped())
.stdout(Stdio::null())
.stderr(Stdio::null())
.spawn()
.expect("failed to spawn polyscribe");
{
let stdin = child.stdin.as_mut().expect("failed to open stdin");
writeln!(stdin, "ChosenOne").unwrap();
}
let status = child.wait().expect("failed to wait on child");
assert!(status.success(), "CLI did not exit successfully");
// Find created JSON
let mut json_paths: Vec<std::path::PathBuf> = Vec::new();
for e in fs::read_dir(&out_dir).unwrap() {
let p = e.unwrap().path();
if let Some(name) = p.file_name().and_then(|s| s.to_str()) {
if name.ends_with(".json") {
json_paths.push(p.clone());
}
}
}
assert!(!json_paths.is_empty(), "no JSON outputs created");
let mut buf = String::new();
std::fs::File::open(&json_paths[0])
.unwrap()
.read_to_string(&mut buf)
.unwrap();
let root: OutputRoot = serde_json::from_str(&buf).unwrap();
assert!(root.items.iter().all(|e| e.speaker == "ChosenOne"));
let _ = fs::remove_dir_all(&out_dir);
}
// New tests for --out-format
#[test]
fn out_format_single_json_only() {
let exe = env!("CARGO_BIN_EXE_polyscribe");
let out_dir = manifest_path("target/tmp/itest_outfmt_json_only");
let _ = fs::remove_dir_all(&out_dir);
fs::create_dir_all(&out_dir).unwrap();
let input1 = manifest_path("input/1-s0wlz.json");
let status = Command::new(exe)
.arg(input1.as_os_str())
.arg("-o")
.arg(&out_dir)
.arg("--out-format")
.arg("json")
.status()
.expect("failed to spawn polyscribe");
assert!(status.success(), "CLI did not exit successfully");
let mut has_json = false;
let mut has_toml = false;
let mut has_srt = false;
for e in fs::read_dir(&out_dir).unwrap() {
let p = e.unwrap().path();
if let Some(name) = p.file_name().and_then(|s| s.to_str()) {
if name.ends_with(".json") { has_json = true; }
if name.ends_with(".toml") { has_toml = true; }
if name.ends_with(".srt") { has_srt = true; }
}
}
assert!(has_json, "expected JSON file to be written");
assert!(!has_toml, "did not expect TOML file");
assert!(!has_srt, "did not expect SRT file");
let _ = fs::remove_dir_all(&out_dir);
}
#[test]
fn out_format_multiple_json_and_srt() {
let exe = env!("CARGO_BIN_EXE_polyscribe");
let out_dir = manifest_path("target/tmp/itest_outfmt_json_srt");
let _ = fs::remove_dir_all(&out_dir);
fs::create_dir_all(&out_dir).unwrap();
let input1 = manifest_path("input/2-vikingowl.json");
let status = Command::new(exe)
.arg(input1.as_os_str())
.arg("-o")
.arg(&out_dir)
.arg("--out-format")
.arg("json")
.arg("--out-format")
.arg("srt")
.status()
.expect("failed to spawn polyscribe");
assert!(status.success(), "CLI did not exit successfully");
let mut has_json = false;
let mut has_toml = false;
let mut has_srt = false;
for e in fs::read_dir(&out_dir).unwrap() {
let p = e.unwrap().path();
if let Some(name) = p.file_name().and_then(|s| s.to_str()) {
if name.ends_with(".json") { has_json = true; }
if name.ends_with(".toml") { has_toml = true; }
if name.ends_with(".srt") { has_srt = true; }
}
}
assert!(has_json, "expected JSON file to be written");
assert!(has_srt, "expected SRT file to be written");
assert!(!has_toml, "did not expect TOML file");
let _ = fs::remove_dir_all(&out_dir);
}
*/
#[test]
fn cli_no_interation_alias_skips_speaker_prompts_and_uses_defaults() {
let exe = env!("CARGO_BIN_EXE_polyscribe");
let input1 = manifest_path("input/1-s0wlz.json");
let input2 = manifest_path("input/2-vikingowl.json");
let output = Command::new(exe)
.arg(input1.as_os_str())
.arg(input2.as_os_str())
.arg("-m")
.arg("--set-speaker-names")
.arg("--no-interation")
.output()
.expect("failed to spawn polyscribe");
assert!(output.status.success(), "CLI did not exit successfully");
let stdout = String::from_utf8(output.stdout).expect("stdout not UTF-8");
let root: OutputRoot = serde_json::from_str(&stdout).unwrap();
let speakers: std::collections::HashSet<String> =
root.items.into_iter().map(|e| e.speaker).collect();
assert!(speakers.contains("s0wlz"), "default s0wlz not used (alias)");
assert!(speakers.contains("vikingowl"), "default vikingowl not used (alias)");
}

88
tests/out_format.rs Normal file
View File

@@ -0,0 +1,88 @@
// SPDX-License-Identifier: MIT
// Tests for --out-format flag behavior
use std::fs;
use std::process::Command;
use std::path::PathBuf;
fn manifest_path(relative: &str) -> PathBuf {
let mut p = PathBuf::from(env!("CARGO_MANIFEST_DIR"));
p.push(relative);
p
}
#[test]
fn out_format_single_json_only() {
let exe = env!("CARGO_BIN_EXE_polyscribe");
let out_dir = manifest_path("target/tmp/itest_outfmt_json_only");
let _ = fs::remove_dir_all(&out_dir);
fs::create_dir_all(&out_dir).unwrap();
let input1 = manifest_path("input/1-s0wlz.json");
let status = Command::new(exe)
.arg(input1.as_os_str())
.arg("-o")
.arg(&out_dir)
.arg("--out-format")
.arg("json")
.status()
.expect("failed to spawn polyscribe");
assert!(status.success(), "CLI did not exit successfully");
let mut has_json = false;
let mut has_toml = false;
let mut has_srt = false;
for e in fs::read_dir(&out_dir).unwrap() {
let p = e.unwrap().path();
if let Some(name) = p.file_name().and_then(|s| s.to_str()) {
if name.ends_with(".json") { has_json = true; }
if name.ends_with(".toml") { has_toml = true; }
if name.ends_with(".srt") { has_srt = true; }
}
}
assert!(has_json, "expected JSON file to be written");
assert!(!has_toml, "did not expect TOML file");
assert!(!has_srt, "did not expect SRT file");
let _ = fs::remove_dir_all(&out_dir);
}
#[test]
fn out_format_multiple_json_and_srt() {
let exe = env!("CARGO_BIN_EXE_polyscribe");
let out_dir = manifest_path("target/tmp/itest_outfmt_json_srt");
let _ = fs::remove_dir_all(&out_dir);
fs::create_dir_all(&out_dir).unwrap();
let input1 = manifest_path("input/2-vikingowl.json");
let status = Command::new(exe)
.arg(input1.as_os_str())
.arg("-o")
.arg(&out_dir)
.arg("--out-format")
.arg("json")
.arg("--out-format")
.arg("srt")
.status()
.expect("failed to spawn polyscribe");
assert!(status.success(), "CLI did not exit successfully");
let mut has_json = false;
let mut has_toml = false;
let mut has_srt = false;
for e in fs::read_dir(&out_dir).unwrap() {
let p = e.unwrap().path();
if let Some(name) = p.file_name().and_then(|s| s.to_str()) {
if name.ends_with(".json") { has_json = true; }
if name.ends_with(".toml") { has_toml = true; }
if name.ends_with(".srt") { has_srt = true; }
}
}
assert!(has_json, "expected JSON file to be written");
assert!(has_srt, "expected SRT file to be written");
assert!(!has_toml, "did not expect TOML file");
let _ = fs::remove_dir_all(&out_dir);
}

View File

@@ -0,0 +1,22 @@
use polyscribe::progress::ProgressManager;
#[test]
fn test_single_mode_has_no_total_bar_and_three_bars() {
// Use hidden backend suitable for tests
let pm = ProgressManager::new_for_tests_single_hidden();
// No total bar should be present
assert!(pm.total_state_for_tests().is_none(), "single mode must not expose a total bar");
// Bar count: header + info + current
assert_eq!(pm.testing_bar_count(), 3);
}
#[test]
fn test_multi_mode_has_total_bar_and_four_bars() {
let pm = ProgressManager::new_for_tests_multi_hidden(2);
// Total bar should exist with the provided length
let (pos, len) = pm.total_state_for_tests().expect("multi mode should expose total bar");
assert_eq!(pos, 0);
assert_eq!(len, 2);
// Bar count: header + info + current + total
assert_eq!(pm.testing_bar_count(), 4);
}

View File

@@ -0,0 +1,58 @@
// Unix-only tests for with_suppressed_stderr restoring file descriptors
// Skip on Windows and non-Unix targets.
#![cfg(unix)]
use std::panic::{catch_unwind, AssertUnwindSafe};
fn stat_of_fd(fd: i32) -> (u64, u64) {
unsafe {
let mut st: libc::stat = std::mem::zeroed();
let r = libc::fstat(fd, &mut st as *mut libc::stat);
assert_eq!(r, 0, "fstat failed on fd {fd}");
(st.st_dev as u64, st.st_ino as u64)
}
}
fn stat_of_path(path: &str) -> (u64, u64) {
use std::ffi::CString;
unsafe {
let c = CString::new(path).unwrap();
let fd = libc::open(c.as_ptr(), libc::O_RDONLY);
assert!(fd >= 0, "failed to open {path}");
let s = stat_of_fd(fd);
let _ = libc::close(fd);
s
}
}
#[test]
fn stderr_is_redirected_and_restored() {
let before = stat_of_fd(2);
let devnull = stat_of_path("/dev/null");
// During the call, fd 2 should be /dev/null; after, restored to before
polyscribe::with_suppressed_stderr(|| {
let inside = stat_of_fd(2);
assert_eq!(inside, devnull, "stderr should point to /dev/null during suppression");
// This write should be suppressed
eprintln!("this should be suppressed");
});
let after = stat_of_fd(2);
assert_eq!(after, before, "stderr should be restored after suppression");
}
#[test]
fn stderr_is_restored_even_if_closure_panics() {
let before = stat_of_fd(2);
let res = catch_unwind(AssertUnwindSafe(|| {
polyscribe::with_suppressed_stderr(|| {
// Trigger a deliberate panic inside the closure
panic!("boom inside with_suppressed_stderr");
});
}));
assert!(res.is_err(), "expected panic to propagate");
let after = stat_of_fd(2);
assert_eq!(after, before, "stderr should be restored after panic");
}