Compare commits

...

8 Commits

9 changed files with 842 additions and 267 deletions

196
Cargo.lock generated
View File

@@ -93,9 +93,9 @@ dependencies = [
[[package]]
name = "anyhow"
version = "1.0.98"
version = "1.0.99"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e16d2d3311acee920a9eb8d33b8cbc1787ce4a264e85f964c2404b969bdcd487"
checksum = "b0674a1ddeecb70197781e945de4b3b8ffb61fa939a5597bcf48503737663100"
[[package]]
name = "atomic-waker"
@@ -103,6 +103,17 @@ version = "1.1.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "1505bd5d3d116872e7271a6d4e16d81d0c8570876c8de68093a09ac269d8aac0"
[[package]]
name = "atty"
version = "0.2.14"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "d9b39be18770d11421cdb1b9947a45dd3f37e93092cbf377614828a319d5fee8"
dependencies = [
"hermit-abi",
"libc",
"winapi",
]
[[package]]
name = "autocfg"
version = "1.5.0"
@@ -179,9 +190,9 @@ checksum = "d71b6127be86fdcfddb610f7182ac57211d4b18a3e9c82eb2d17662f2227ad6a"
[[package]]
name = "cc"
version = "1.2.31"
version = "1.2.32"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "c3a42d84bb6b69d3a8b3eaacf0d88f179e1929695e1ad012b6cf64d9caaa5fd2"
checksum = "2352e5597e9c544d5e6d9c95190d5d27738ade584fa8db0a16e130e5c2b5296e"
dependencies = [
"shlex",
]
@@ -228,9 +239,9 @@ dependencies = [
[[package]]
name = "clap"
version = "4.5.43"
version = "4.5.44"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "50fd97c9dc2399518aa331917ac6f274280ec5eb34e555dd291899745c48ec6f"
checksum = "1c1f056bae57e3e54c3375c41ff79619ddd13460a17d7438712bd0d83fda4ff8"
dependencies = [
"clap_builder",
"clap_derive",
@@ -238,9 +249,9 @@ dependencies = [
[[package]]
name = "clap_builder"
version = "4.5.43"
version = "4.5.44"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "c35b5830294e1fa0462034af85cc95225a4cb07092c088c55bda3147cfcd8f65"
checksum = "b3e7f4214277f3c7aa526a59dd3fbe306a370daee1f8b7b8c987069cd8e888a8"
dependencies = [
"anstream",
"anstyle",
@@ -250,9 +261,9 @@ dependencies = [
[[package]]
name = "clap_complete"
version = "4.5.56"
version = "4.5.57"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "67e4efcbb5da11a92e8a609233aa1e8a7d91e38de0be865f016d14700d45a7fd"
checksum = "4d9501bd3f5f09f7bbee01da9a511073ed30a80cd7a509f1214bb74eadea71ad"
dependencies = [
"clap",
]
@@ -285,6 +296,20 @@ dependencies = [
"roff",
]
[[package]]
name = "cliclack"
version = "0.3.6"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "57c420bdc04c123a2df04d9c5a07289195f00007af6e45ab18f55e56dc7e04b8"
dependencies = [
"console",
"indicatif",
"once_cell",
"strsim",
"textwrap",
"zeroize",
]
[[package]]
name = "cmake"
version = "0.1.54"
@@ -300,6 +325,19 @@ version = "1.0.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "b05b61dc5112cbb17e4b6cd61790d9845d13888356391624cbe7e41efeac1e75"
[[package]]
name = "console"
version = "0.15.11"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "054ccb5b10f9f2cbf51eb355ca1d05c2d279ce1804688d0db74b4733a5aeafd8"
dependencies = [
"encode_unicode",
"libc",
"once_cell",
"unicode-width",
"windows-sys 0.59.0",
]
[[package]]
name = "core-foundation"
version = "0.9.4"
@@ -362,6 +400,12 @@ version = "1.15.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "48c757948c5ede0e46177b7add2e67155f70e33c07fea8284df6576da70b3719"
[[package]]
name = "encode_unicode"
version = "1.0.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "34aa73646ffb006b8f5147f3dc182bd4bcb190227ce861fc4a4844bf8e3cb2c0"
[[package]]
name = "encoding_rs"
version = "0.8.35"
@@ -520,9 +564,9 @@ checksum = "07e28edb80900c19c28f1072f2e8aeca7fa06b23cd4169cefe1af5aa3260783f"
[[package]]
name = "glob"
version = "0.3.2"
version = "0.3.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "a8d1add55171497b4705a648c6b583acafb01d58050a51727785f0b2c8e0a2b2"
checksum = "0cc23270f6e1808e30a928bdc84dea0b9b4136a8bc82338574f23baf47bbd280"
[[package]]
name = "h2"
@@ -555,6 +599,15 @@ version = "0.5.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "2304e00983f87ffb38b55b444b5e3b60a884b5d30c0fca7d82fe33449bbe55ea"
[[package]]
name = "hermit-abi"
version = "0.1.19"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "62b467343b94ba476dcb2500d242dadbb39557df889310ac77c5d99100aaac33"
dependencies = [
"libc",
]
[[package]]
name = "http"
version = "1.3.1"
@@ -814,6 +867,19 @@ dependencies = [
"hashbrown",
]
[[package]]
name = "indicatif"
version = "0.17.11"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "183b3088984b400f4cfac3620d5e076c84da5364016b4f49473de574b2586235"
dependencies = [
"console",
"number_prefix",
"portable-atomic",
"unicode-width",
"web-time",
]
[[package]]
name = "io-uring"
version = "0.7.9"
@@ -874,9 +940,9 @@ dependencies = [
[[package]]
name = "libc"
version = "0.2.174"
version = "0.2.175"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "1171693293099992e19cddea4e8b849964e9846f4acee11b3948bcc337be8776"
checksum = "6a82ae493e598baaea5209805c49bbf2ea7de956d50d7da0da1164f9c6d28543"
[[package]]
name = "libloading"
@@ -980,6 +1046,12 @@ dependencies = [
"autocfg",
]
[[package]]
name = "number_prefix"
version = "0.4.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "830b246a0e5f20af87141b25c173cd1b609bd7779a4617d6ec582abaf90870f3"
[[package]]
name = "object"
version = "0.36.7"
@@ -1074,10 +1146,13 @@ name = "polyscribe"
version = "0.1.0"
dependencies = [
"anyhow",
"atty",
"chrono",
"clap",
"clap_complete",
"clap_mangen",
"cliclack",
"indicatif",
"libc",
"reqwest",
"serde",
@@ -1088,6 +1163,12 @@ dependencies = [
"whisper-rs",
]
[[package]]
name = "portable-atomic"
version = "1.11.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f84267b20a16ea918e43c6a88433c2d54fa145c92a811b5b047ccbe153674483"
[[package]]
name = "potential_utf"
version = "0.1.2"
@@ -1109,9 +1190,9 @@ dependencies = [
[[package]]
name = "proc-macro2"
version = "1.0.95"
version = "1.0.97"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "02b3e5e68a3a1a02aad3ec490a98007cbc13c37cbe84a3cd7b8e406d76e7f778"
checksum = "d61789d7719defeb74ea5fe81f2fdfdbd28a803847077cecce2ff14e1472f6f1"
dependencies = [
"unicode-ident",
]
@@ -1282,9 +1363,9 @@ dependencies = [
[[package]]
name = "rustversion"
version = "1.0.21"
version = "1.0.22"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "8a0d197bd2c9dc6e53b84da9556a69ba4cdfab8619eb41a8bd1cc2027a0f6b1d"
checksum = "b39cdef0fa800fc44525c84ccb54a029961a8215f9619753635a9c0d2538d46d"
[[package]]
name = "ryu"
@@ -1396,9 +1477,9 @@ checksum = "0fda2ff0d084019ba4d7c6f371c95d8fd75ce3524c3cb8fb653a3023f6323e64"
[[package]]
name = "slab"
version = "0.4.10"
version = "0.4.11"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "04dc19736151f35336d325007ac991178d504a119863a2fcb3758cdb5e52c50d"
checksum = "7a2ae44ef20feb57a68b23d846850f861394c2e02dc425a50098ae8c90267589"
[[package]]
name = "smallvec"
@@ -1406,6 +1487,12 @@ version = "1.15.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "67b1b7a3b5fe4f1376887184045fcf45c69e92af734b7aaddc05fb777b6fbd03"
[[package]]
name = "smawk"
version = "0.3.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "b7c388c1b5e93756d0c740965c41e8822f866621d41acbdf6336a6a168f8840c"
[[package]]
name = "socket2"
version = "0.6.0"
@@ -1499,6 +1586,17 @@ dependencies = [
"windows-sys 0.59.0",
]
[[package]]
name = "textwrap"
version = "0.16.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "c13547615a44dc9c452a8a534638acdf07120d4b6847c8178705da06306a3057"
dependencies = [
"smawk",
"unicode-linebreak",
"unicode-width",
]
[[package]]
name = "tinystr"
version = "0.8.1"
@@ -1682,6 +1780,18 @@ version = "1.0.18"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "5a5f39404a5da50712a4c1eecf25e90dd62b613502b7e925fd4e4d19b5c96512"
[[package]]
name = "unicode-linebreak"
version = "0.1.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "3b09c83c3c29d37506a3e260c08c03743a6bb66a9cd432c6934ab501a190571f"
[[package]]
name = "unicode-width"
version = "0.2.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "4a1a07cc7db3810833284e8d372ccdc6da29741639ecc70c9ec107df0fa6154c"
[[package]]
name = "untrusted"
version = "0.9.0"
@@ -1828,6 +1938,16 @@ dependencies = [
"wasm-bindgen",
]
[[package]]
name = "web-time"
version = "1.1.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "5a6580f308b1fad9207618087a65c04e7a10bc77e02c8e84e9b00dd4b12fa0bb"
dependencies = [
"js-sys",
"wasm-bindgen",
]
[[package]]
name = "whisper-rs"
version = "0.14.3"
@@ -1847,6 +1967,28 @@ dependencies = [
"fs_extra",
]
[[package]]
name = "winapi"
version = "0.3.9"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "5c839a674fcd7a98952e593242ea400abe93992746761e38641405d28b00f419"
dependencies = [
"winapi-i686-pc-windows-gnu",
"winapi-x86_64-pc-windows-gnu",
]
[[package]]
name = "winapi-i686-pc-windows-gnu"
version = "0.4.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "ac3b87c63620426dd9b991e5ce0329eff545bccbbb34f3be09ff6fb6ab51b7b6"
[[package]]
name = "winapi-x86_64-pc-windows-gnu"
version = "0.4.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "712e227841d057c1ee1cd2fb22fa7e5a5461ae8e48fa2ca79ec42cfc1931183f"
[[package]]
name = "windows-core"
version = "0.61.2"
@@ -2147,6 +2289,20 @@ name = "zeroize"
version = "1.8.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "ced3678a2879b30306d323f4542626697a464a97c0a07c9aebf7ebca65cd4dde"
dependencies = [
"zeroize_derive",
]
[[package]]
name = "zeroize_derive"
version = "1.4.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "ce36e65b0d2999d2aafac989fb249189a141aee1f53c612c1f37d72631959f69"
dependencies = [
"proc-macro2",
"quote",
"syn",
]
[[package]]
name = "zerotrie"

View File

@@ -29,6 +29,9 @@ sha2 = "0.10"
# whisper-rs is always used (CPU-only by default); GPU features map onto it
whisper-rs = { git = "https://github.com/tazz4843/whisper-rs" }
libc = "0.2"
cliclack = "0.3"
indicatif = "0.17"
atty = "0.2"
[dev-dependencies]
tempfile = "3"

View File

@@ -28,6 +28,7 @@ Installation
Quickstart
1) Download a model (first run can prompt you):
- ./target/release/polyscribe --download-models
- In the interactive picker, use Up/Down to navigate, Space to toggle selections, and Enter to confirm. Models are grouped by base (e.g., tiny, base, small).
2) Transcribe a file:
- ./target/release/polyscribe -v -o output my_audio.mp3

View File

@@ -32,6 +32,7 @@ CLI reference
- Number of layers to offload to the GPU when supported.
- --download-models
- Launch interactive model downloader (lists Hugging Face models; multi-select to download).
- Controls: Use Up/Down to navigate, Space to toggle selections, and Enter to confirm. Models are grouped by base (e.g., tiny, base, small).
- --update-models
- Verify/update local models by comparing sizes and hashes with the upstream manifest.
- -v, --verbose (repeatable)

View File

@@ -35,12 +35,14 @@ pub trait TranscribeBackend {
/// - speaker: label to attach to all produced segments.
/// - lang_opt: optional language hint (e.g., "en"); None means auto/multilingual model default.
/// - gpu_layers: optional GPU layer count if applicable (ignored by some backends).
/// - progress_cb: optional callback receiving percentage [0..=100] updates.
fn transcribe(
&self,
audio_path: &Path,
speaker: &str,
lang_opt: Option<&str>,
gpu_layers: Option<u32>,
progress_cb: Option<&(dyn Fn(i32) + Send + Sync)>,
) -> Result<Vec<OutputEntry>>;
}
@@ -148,8 +150,9 @@ impl TranscribeBackend for CpuBackend {
speaker: &str,
lang_opt: Option<&str>,
_gpu_layers: Option<u32>,
progress_cb: Option<&(dyn Fn(i32) + Send + Sync)>,
) -> Result<Vec<OutputEntry>> {
transcribe_with_whisper_rs(audio_path, speaker, lang_opt)
transcribe_with_whisper_rs(audio_path, speaker, lang_opt, progress_cb)
}
}
@@ -163,9 +166,10 @@ impl TranscribeBackend for CudaBackend {
speaker: &str,
lang_opt: Option<&str>,
_gpu_layers: Option<u32>,
progress_cb: Option<&(dyn Fn(i32) + Send + Sync)>,
) -> Result<Vec<OutputEntry>> {
// whisper-rs uses enabled CUDA feature at build time; call same code path
transcribe_with_whisper_rs(audio_path, speaker, lang_opt)
transcribe_with_whisper_rs(audio_path, speaker, lang_opt, progress_cb)
}
}
@@ -179,8 +183,9 @@ impl TranscribeBackend for HipBackend {
speaker: &str,
lang_opt: Option<&str>,
_gpu_layers: Option<u32>,
progress_cb: Option<&(dyn Fn(i32) + Send + Sync)>,
) -> Result<Vec<OutputEntry>> {
transcribe_with_whisper_rs(audio_path, speaker, lang_opt)
transcribe_with_whisper_rs(audio_path, speaker, lang_opt, progress_cb)
}
}
@@ -194,6 +199,7 @@ impl TranscribeBackend for VulkanBackend {
_speaker: &str,
_lang_opt: Option<&str>,
_gpu_layers: Option<u32>,
_progress_cb: Option<&(dyn Fn(i32) + Send + Sync)>,
) -> Result<Vec<OutputEntry>> {
Err(anyhow!(
"Vulkan backend not yet wired to whisper.cpp FFI. Build with --features gpu-vulkan and ensure Vulkan SDK is installed. How to fix: install Vulkan loader (libvulkan), set VULKAN_SDK, and run cargo build --features gpu-vulkan."
@@ -301,8 +307,13 @@ pub(crate) fn transcribe_with_whisper_rs(
audio_path: &Path,
speaker: &str,
lang_opt: Option<&str>,
progress_cb: Option<&(dyn Fn(i32) + Send + Sync)>,
) -> Result<Vec<OutputEntry>> {
if let Some(cb) = progress_cb { cb(0); }
let pcm = decode_audio_to_pcm_f32_ffmpeg(audio_path)?;
if let Some(cb) = progress_cb { cb(5); }
let model = find_model_file()?;
let is_en_only = model
.file_name()
@@ -341,6 +352,7 @@ pub(crate) fn transcribe_with_whisper_rs(
.map_err(|e| anyhow!("Failed to create Whisper state: {:?}", e))?;
Ok::<_, anyhow::Error>((ctx, state))
})?;
if let Some(cb) = progress_cb { cb(20); }
let mut params =
whisper_rs::FullParams::new(whisper_rs::SamplingStrategy::Greedy { best_of: 1 });
@@ -352,13 +364,16 @@ pub(crate) fn transcribe_with_whisper_rs(
if let Some(lang) = lang_opt {
params.set_language(Some(lang));
}
if let Some(cb) = progress_cb { cb(30); }
crate::with_suppressed_stderr(|| {
if let Some(cb) = progress_cb { cb(40); }
state
.full(params, &pcm)
.map_err(|e| anyhow!("Whisper full() failed: {:?}", e))
})?;
if let Some(cb) = progress_cb { cb(90); }
let num_segments = state
.full_n_segments()
.map_err(|e| anyhow!("Failed to get segments: {:?}", e))?;
@@ -383,5 +398,6 @@ pub(crate) fn transcribe_with_whisper_rs(
text: text.trim().to_string(),
});
}
if let Some(cb) = progress_cb { cb(100); }
Ok(items)
}

View File

@@ -19,6 +19,7 @@ use std::sync::atomic::{AtomicBool, AtomicU8, Ordering};
static QUIET: AtomicBool = AtomicBool::new(false);
static NO_INTERACTION: AtomicBool = AtomicBool::new(false);
static VERBOSE: AtomicU8 = AtomicU8::new(0);
static NO_PROGRESS: AtomicBool = AtomicBool::new(false);
/// Set quiet mode: when true, non-interactive logs should be suppressed.
pub fn set_quiet(q: bool) {
@@ -47,6 +48,15 @@ pub fn verbose_level() -> u8 {
VERBOSE.load(Ordering::Relaxed)
}
/// Disable interactive progress indicators (bars/spinners)
pub fn set_no_progress(b: bool) {
NO_PROGRESS.store(b, Ordering::Relaxed);
}
/// Return current no-progress state
pub fn is_no_progress() -> bool {
NO_PROGRESS.load(Ordering::Relaxed)
}
/// Check whether stdin is connected to a TTY. Used to avoid blocking prompts when not interactive.
pub fn stdin_is_tty() -> bool {
#[cfg(unix)]
@@ -173,50 +183,171 @@ where
}
}
/// Centralized UI helpers (TTY-aware, quiet/verbose-aware)
pub mod ui {
use std::io;
// Prefer cliclack for all user-visible messages to ensure consistent, TTY-aware output.
// Falls back to stderr printing if needed.
/// Startup intro/banner (suppressed when quiet).
pub fn intro(msg: impl AsRef<str>) {
if crate::is_quiet() { return; }
// Use cliclack intro to render a nice banner when TTY
let _ = cliclack::intro(msg.as_ref());
}
/// Print an informational line (suppressed when quiet).
pub fn info(msg: impl AsRef<str>) {
if crate::is_quiet() { return; }
let _ = cliclack::log::info(msg.as_ref());
}
/// Print a warning (always printed).
pub fn warn(msg: impl AsRef<str>) {
// cliclack provides a warning-level log utility
let _ = cliclack::log::warning(msg.as_ref());
}
/// Print an error (always printed).
pub fn error(msg: impl AsRef<str>) {
let _ = cliclack::log::error(msg.as_ref());
}
/// Print a line above any progress bars (maps to cliclack log; synchronized).
pub fn println_above_bars(msg: impl AsRef<str>) {
if crate::is_quiet() { return; }
// cliclack logs are synchronized with its spinners/bars
let _ = cliclack::log::info(msg.as_ref());
}
/// Final outro/summary printed below any progress indicators (suppressed when quiet).
pub fn outro(msg: impl AsRef<str>) {
if crate::is_quiet() { return; }
let _ = cliclack::outro(msg.as_ref());
}
/// Prompt the user (TTY-aware via cliclack) and read a line from stdin. Returns the raw line with trailing newline removed.
pub fn prompt_line(prompt: &str) -> io::Result<String> {
// Route prompt through cliclack to keep consistent styling and avoid direct eprint!/println!
let _ = cliclack::log::info(prompt);
let mut s = String::new();
io::stdin().read_line(&mut s)?;
Ok(s)
}
// Progress manager built on indicatif MultiProgress for per-file and aggregate bars
/// TTY-aware progress UI built on `indicatif` for per-file and aggregate progress bars.
///
/// This small helper encapsulates a `MultiProgress` with one aggregate (total) bar and
/// one per-file bar. It is intentionally minimal to keep integration lightweight.
pub mod progress {
use atty::Stream;
use indicatif::{MultiProgress, ProgressBar, ProgressStyle};
/// Manages a set of per-file progress bars plus a top aggregate bar.
pub struct ProgressManager {
enabled: bool,
mp: Option<MultiProgress>,
per: Vec<ProgressBar>,
total: Option<ProgressBar>,
total_n: usize,
completed: usize,
done: Vec<bool>,
}
impl ProgressManager {
/// Create a new manager with the given enabled flag.
pub fn new(enabled: bool) -> Self {
Self { enabled, mp: None, per: Vec::new(), total: None, total_n: 0, completed: 0, done: Vec::new() }
}
/// Create a manager that enables bars when `n > 1`, stderr is a TTY, and not quiet.
pub fn default_for_files(n: usize) -> Self {
let enabled = n > 1 && atty::is(Stream::Stderr) && !crate::is_quiet() && !crate::is_no_progress();
Self::new(enabled)
}
/// Initialize bars for the given file labels. If disabled or single file, no-op.
pub fn init_files(&mut self, labels: &[String]) {
self.total_n = labels.len();
if !self.enabled || self.total_n <= 1 {
// No bars in single-file mode or when disabled
self.enabled = false;
return;
}
let mp = MultiProgress::new();
// Aggregate bar at the top
let total = mp.add(ProgressBar::new(labels.len() as u64));
total.set_style(ProgressStyle::with_template("{prefix} [{bar:40.cyan/blue}] {pos}/{len}")
.unwrap()
.progress_chars("=>-"));
total.set_prefix("Total");
self.total = Some(total);
// Per-file bars
for label in labels {
let pb = mp.add(ProgressBar::new(100));
pb.set_style(ProgressStyle::with_template("{prefix} [{bar:40.green/black}] {pos}% {msg}")
.unwrap()
.progress_chars("=>-"));
pb.set_position(0);
pb.set_prefix(label.clone());
self.per.push(pb);
}
self.mp = Some(mp);
}
/// Returns true when bars are enabled (multi-file TTY mode).
pub fn is_enabled(&self) -> bool { self.enabled }
/// Get a clone of the per-file progress bar at index, if enabled.
pub fn per_bar(&self, idx: usize) -> Option<ProgressBar> {
if !self.enabled { return None; }
self.per.get(idx).cloned()
}
/// Get a clone of the aggregate (total) progress bar, if enabled.
pub fn total_bar(&self) -> Option<ProgressBar> {
if !self.enabled { return None; }
self.total.as_ref().cloned()
}
/// Mark a file as finished (set to 100% and update total counter).
pub fn mark_file_done(&mut self, idx: usize) {
if !self.enabled { return; }
if let Some(pb) = self.per.get(idx) {
pb.set_position(100);
pb.finish_with_message("done");
}
self.completed += 1;
if let Some(total) = &self.total { total.set_position(self.completed as u64); }
}
}
}
}
/// Logging macros and helpers
/// Log an error to stderr (always printed). Recommended for user-visible errors.
/// Log an error using the UI helper (always printed). Recommended for user-visible errors.
#[macro_export]
macro_rules! elog {
($($arg:tt)*) => {{
eprintln!("ERROR: {}", format!($($arg)*));
}}
}
/// Internal helper macro used by other logging macros to centralize the
/// common behavior: build formatted message, check quiet/verbose flags,
/// and print to stderr with a label.
#[macro_export]
macro_rules! log_with_level {
($label:expr, $min_lvl:expr, $always:expr, $($arg:tt)*) => {{
let should_print = if $always {
true
} else if let Some(minv) = $min_lvl {
!$crate::is_quiet() && $crate::verbose_level() >= minv
} else {
!$crate::is_quiet()
};
if should_print {
eprintln!("{}: {}", $label, format!($($arg)*));
}
$crate::ui::error(format!($($arg)*));
}}
}
/// Log a warning to stderr (printed even in quiet mode).
/// Log a warning using the UI helper (printed even in quiet mode).
#[macro_export]
macro_rules! wlog {
($($arg:tt)*) => {{ $crate::log_with_level!("WARN", None, true, $($arg)*); }}
($($arg:tt)*) => {{
$crate::ui::warn(format!($($arg)*));
}}
}
/// Log an informational line to stderr unless quiet mode is enabled.
/// Log an informational line using the UI helper unless quiet mode is enabled.
#[macro_export]
macro_rules! ilog {
($($arg:tt)*) => {{ $crate::log_with_level!("INFO", None, false, $($arg)*); }}
($($arg:tt)*) => {{
if !$crate::is_quiet() { $crate::ui::info(format!($($arg)*)); }
}}
}
/// Log a debug/trace line when verbose level is at least the given level (u8).
#[macro_export]
macro_rules! dlog {
($lvl:expr, $($arg:tt)*) => {{
$crate::log_with_level!(&format!("DEBUG{}", &$lvl), Some($lvl), false, $($arg)*);
if !$crate::is_quiet() && $crate::verbose_level() >= $lvl { $crate::ui::info(format!("DEBUG{}: {}", $lvl, format!($($arg)*))); }
}}
}
@@ -230,7 +361,6 @@ use anyhow::{Context, Result, anyhow};
use chrono::Local;
use std::env;
use std::fs::create_dir_all;
use std::io::{self, Write};
use std::path::{Path, PathBuf};
use std::process::Command;
@@ -462,10 +592,7 @@ pub fn find_model_file() -> Result<PathBuf> {
"No models available and interactive mode is disabled. Please set WHISPER_MODEL or run with --download-models."
));
}
eprint!("Would you like to download models now? [Y/n]: ");
io::stderr().flush().ok();
let mut input = String::new();
io::stdin().read_line(&mut input).ok();
let input = crate::ui::prompt_line("Would you like to download models now? [Y/n]: ").unwrap_or_default();
let ans = input.trim().to_lowercase();
if ans.is_empty() || ans == "y" || ans == "yes" {
if let Err(e) = models::run_interactive_model_downloader() {
@@ -519,16 +646,12 @@ pub fn find_model_file() -> Result<PathBuf> {
}
}
eprintln!("Multiple Whisper models found in {}:", models_dir.display());
crate::ui::println_above_bars(format!("Multiple Whisper models found in {}:", models_dir.display()));
for (i, p) in candidates.iter().enumerate() {
eprintln!(" {}) {}", i + 1, p.display());
crate::ui::println_above_bars(format!(" {}) {}", i + 1, p.display()));
}
eprint!("Select model by number [1-{}]: ", candidates.len());
io::stderr().flush().ok();
let mut input = String::new();
io::stdin()
.read_line(&mut input)
.context("Failed to read selection")?;
let input = crate::ui::prompt_line(&format!("Select model by number [1-{}]: ", candidates.len()))
.map_err(|_| anyhow!("Failed to read selection"))?;
let sel: usize = input
.trim()
.parse()
@@ -571,10 +694,11 @@ pub fn decode_audio_to_pcm_f32_ffmpeg(audio_path: &Path) -> Result<Vec<f32>> {
}
};
if !output.status.success() {
let stderr = String::from_utf8_lossy(&output.stderr);
return Err(anyhow!(
"ffmpeg failed for {}: {}",
"Failed to decode audio from {} using ffmpeg. This may indicate the file is not a valid or supported audio/video file, is corrupted, or cannot be opened. ffmpeg stderr: {}",
audio_path.display(),
String::from_utf8_lossy(&output.stderr)
stderr.trim()
));
}
let bytes = output.stdout;

View File

@@ -55,6 +55,10 @@ struct Args {
#[arg(long = "no-interaction", global = true)]
no_interaction: bool,
/// Disable interactive progress indicators (bars/spinners)
#[arg(long = "no-progress", global = true)]
no_progress: bool,
/// Optional auxiliary subcommands (completions, man)
#[command(subcommand)]
aux: Option<AuxCommands>,
@@ -129,7 +133,12 @@ fn sanitize_speaker_name(raw: &str) -> String {
raw.to_string()
}
fn prompt_speaker_name_for_path(path: &Path, default_name: &str, enabled: bool) -> String {
fn prompt_speaker_name_for_path(
path: &Path,
default_name: &str,
enabled: bool,
_pm: Option<&polyscribe::ui::progress::ProgressManager>,
) -> String {
if !enabled {
return default_name.to_string();
}
@@ -142,25 +151,18 @@ fn prompt_speaker_name_for_path(path: &Path, default_name: &str, enabled: bool)
.and_then(|s| s.to_str())
.map(|s| s.to_string())
.unwrap_or_else(|| path.to_string_lossy().to_string());
eprint!(
let buf = polyscribe::ui::prompt_line(&format!(
"Enter speaker name for {display_owned} [default: {default_name}]: "
);
io::stderr().flush().ok();
let mut buf = String::new();
match io::stdin().read_line(&mut buf) {
Ok(_) => {
let raw = buf.trim();
if raw.is_empty() {
return default_name.to_string();
}
let sanitized = sanitize_speaker_name(raw);
if sanitized.is_empty() {
default_name.to_string()
} else {
sanitized
}
}
Err(_) => default_name.to_string(),
)).unwrap_or_default();
let raw = buf.trim();
if raw.is_empty() {
return default_name.to_string();
}
let sanitized = sanitize_speaker_name(raw);
if sanitized.is_empty() {
default_name.to_string()
} else {
sanitized
}
}
@@ -184,6 +186,22 @@ fn is_audio_file(path: &Path) -> bool {
false
}
fn validate_input_path(path: &Path) -> anyhow::Result<()> {
use anyhow::{anyhow, Context};
let display = path.display();
if !path.exists() {
return Err(anyhow!("Input not found: {}", display));
}
let md = std::fs::metadata(path).with_context(|| format!("Failed to stat input: {}", display))?;
if md.is_dir() {
return Err(anyhow!("Input is a directory (expected a file): {}", display));
}
// Attempt to open to catch permission errors early
std::fs::File::open(path)
.with_context(|| format!("Failed to open input file: {}", display))
.map(|_| ())
}
struct LastModelCleanup {
path: PathBuf,
}
@@ -217,6 +235,7 @@ where
}
fn run() -> Result<()> {
let _t0 = std::time::Instant::now();
// Parse CLI
let args = Args::parse();
@@ -224,6 +243,10 @@ fn run() -> Result<()> {
polyscribe::set_verbose(args.verbose);
polyscribe::set_quiet(args.quiet);
polyscribe::set_no_interaction(args.no_interaction);
polyscribe::set_no_progress(args.no_progress);
// Startup banner via UI (TTY-aware through cliclack), suppressed when quiet
polyscribe::ui::intro(format!("PolyScribe v{}", env!("CARGO_PKG_VERSION")));
// Handle auxiliary subcommands that write to stdout and exit early
if let Some(aux) = &args.aux {
@@ -266,6 +289,10 @@ fn run() -> Result<()> {
polyscribe::dlog!(1, "Using backend: {:?}", sel.chosen);
// If requested, run the interactive model downloader first. If no inputs were provided, exit after downloading.
let mut summary_inputs_total: usize = 0;
let mut summary_audio_count: usize = 0;
let mut summary_json_count: usize = 0;
let mut summary_segments_total: usize = 0;
if args.download_models {
if let Err(e) = polyscribe::models::run_interactive_model_downloader() {
polyscribe::elog!("Model downloader failed: {:#}", e);
@@ -290,6 +317,7 @@ fn run() -> Result<()> {
// Determine inputs and optional output path
polyscribe::dlog!(1, "Parsed {} input(s)", args.inputs.len());
let mut inputs = args.inputs;
summary_inputs_total = inputs.len();
let mut output_path = args.output;
if output_path.is_none() && inputs.len() >= 2 {
if let Some(last) = inputs.last().cloned() {
@@ -303,6 +331,18 @@ fn run() -> Result<()> {
return Err(anyhow!("No input files provided"));
}
// Preflight: validate each input path and type
for inp in &inputs {
let p = Path::new(inp);
validate_input_path(p)?;
if !(is_audio_file(p) || is_json_file(p)) {
return Err(anyhow!(
"Unsupported input type (expected .json transcript or common audio/video): {}",
p.display()
));
}
}
// Language must be provided via CLI when transcribing audio (no detection from JSON/env)
let lang_hint: Option<String> = if let Some(ref l) = args.language {
normalize_lang_code(l).or_else(|| Some(l.trim().to_lowercase()))
@@ -316,6 +356,31 @@ fn run() -> Result<()> {
));
}
// Initialize progress manager early to coordinate prompts
let mut pm = polyscribe::ui::progress::ProgressManager::default_for_files(inputs.len());
// Initialize progress manager early to coordinate prompts
let mut pm = polyscribe::ui::progress::ProgressManager::default_for_files(inputs.len());
// Collect all speaker names up front (one per input), before any file reading/transcription
let speakers: Vec<String> = inputs
.iter()
.map(|input_path| {
let path = Path::new(input_path);
let default_speaker = sanitize_speaker_name(
path.file_stem()
.and_then(|s| s.to_str())
.unwrap_or("speaker"),
);
prompt_speaker_name_for_path(path, &default_speaker, args.set_speaker_names, Some(&pm))
})
.collect();
// Initialize multi-file progress bars (TTY-aware); suppressed for single-file/non-TTY/quiet
let mut pm = polyscribe::ui::progress::ProgressManager::default_for_files(speakers.len());
// Use speaker names (derived from file names or prompted) as labels
pm.init_files(&speakers);
if args.merge_and_separate {
polyscribe::dlog!(1, "Mode: merge-and-separate; output_dir={:?}", output_path);
// Combined mode: write separate outputs per input and also a merged output set
@@ -332,38 +397,55 @@ fn run() -> Result<()> {
let mut merged_entries: Vec<OutputEntry> = Vec::new();
for input_path in &inputs {
for (idx, input_path) in inputs.iter().enumerate() {
let path = Path::new(input_path);
let default_speaker = sanitize_speaker_name(
path.file_stem()
.and_then(|s| s.to_str())
.unwrap_or("speaker"),
);
let speaker =
prompt_speaker_name_for_path(path, &default_speaker, args.set_speaker_names);
let speaker = speakers[idx].clone();
// Collect entries per file and extend merged
let mut entries: Vec<OutputEntry> = Vec::new();
if is_audio_file(path) {
// Progress log to stderr (suppressed by -q); avoid partial lines
polyscribe::ilog!("Processing file: {} ...", path.display());
summary_audio_count += 1;
// Progress log only when multi-bars are not enabled
if !pm.is_enabled() {
polyscribe::ilog!("Processing file: {} ...", path.display());
}
// Prepare per-file progress callback if multi-bars enabled
let mut cb_holder: Option<Box<dyn Fn(i32) + Send + Sync>> = None;
if let Some(pb) = pm.per_bar(idx) {
let pb = pb.clone();
cb_holder = Some(Box::new(move |p: i32| {
let p = p.clamp(0, 100) as u64;
pb.set_position(p);
}));
}
let res = with_quiet_stdio_if_needed(args.quiet, || {
let cb_ref = cb_holder.as_ref().map(|b| &**b as &(dyn Fn(i32) + Send + Sync));
sel.backend
.transcribe(path, &speaker, lang_hint.as_deref(), args.gpu_layers)
.transcribe(path, &speaker, lang_hint.as_deref(), args.gpu_layers, cb_ref)
});
match res {
Ok(items) => {
polyscribe::ilog!("done");
if pm.is_enabled() {
pm.mark_file_done(idx);
} else {
polyscribe::ilog!("done");
}
entries.extend(items.into_iter());
}
Err(e) => {
if !polyscribe::is_no_interaction() && polyscribe::stdin_is_tty() {
polyscribe::elog!("{:#}", e);
if let Some(pb) = pm.per_bar(idx) {
pb.finish_with_message("error");
}
if !pm.is_enabled() {
if !polyscribe::is_no_interaction() && polyscribe::stdin_is_tty() {
polyscribe::elog!("{:#}", e);
}
}
return Err(e);
}
}
} else if is_json_file(path) {
summary_json_count += 1;
let mut buf = String::new();
File::open(path)
.with_context(|| format!("Failed to open: {input_path}"))?
@@ -401,6 +483,7 @@ fn run() -> Result<()> {
for (i, e) in entries.iter_mut().enumerate() {
e.id = i as u64;
}
summary_segments_total += entries.len();
// Write separate outputs to out_dir
let out = OutputRoot {
@@ -484,40 +567,57 @@ fn run() -> Result<()> {
polyscribe::dlog!(1, "Mode: merge; output_base={:?}", output_path);
// MERGED MODE (previous default)
let mut entries: Vec<OutputEntry> = Vec::new();
for input_path in &inputs {
for (idx, input_path) in inputs.iter().enumerate() {
let path = Path::new(input_path);
let default_speaker = sanitize_speaker_name(
path.file_stem()
.and_then(|s| s.to_str())
.unwrap_or("speaker"),
);
let speaker =
prompt_speaker_name_for_path(path, &default_speaker, args.set_speaker_names);
let speaker = speakers[idx].clone();
let mut buf = String::new();
if is_audio_file(path) {
// Progress log to stderr (suppressed by -q)
polyscribe::ilog!("Processing file: {} ...", path.display());
summary_audio_count += 1;
// Progress log only when multi-bars are not enabled
if !pm.is_enabled() {
polyscribe::ilog!("Processing file: {} ...", path.display());
}
// Prepare per-file progress callback if multi-bars enabled
let mut cb_holder: Option<Box<dyn Fn(i32) + Send + Sync>> = None;
if let Some(pb) = pm.per_bar(idx) {
let pb = pb.clone();
cb_holder = Some(Box::new(move |p: i32| {
let p = p.clamp(0, 100) as u64;
pb.set_position(p);
}));
}
let res = with_quiet_stdio_if_needed(args.quiet, || {
let cb_ref = cb_holder.as_ref().map(|b| &**b as &(dyn Fn(i32) + Send + Sync));
sel.backend
.transcribe(path, &speaker, lang_hint.as_deref(), args.gpu_layers)
.transcribe(path, &speaker, lang_hint.as_deref(), args.gpu_layers, cb_ref)
});
match res {
Ok(items) => {
polyscribe::ilog!("done");
if pm.is_enabled() {
pm.mark_file_done(idx);
} else {
polyscribe::ilog!("done");
}
for e in items {
entries.push(e);
}
continue;
}
Err(e) => {
if !(polyscribe::is_no_interaction() || !polyscribe::stdin_is_tty()) {
polyscribe::elog!("{:#}", e);
if let Some(pb) = pm.per_bar(idx) {
pb.finish_with_message("error");
}
if !pm.is_enabled() {
if !(polyscribe::is_no_interaction() || !polyscribe::stdin_is_tty()) {
polyscribe::elog!("{:#}", e);
}
}
return Err(e);
}
}
} else if is_json_file(path) {
summary_json_count += 1;
File::open(path)
.with_context(|| format!("Failed to open: {}", input_path))?
.read_to_string(&mut buf)
@@ -557,6 +657,7 @@ fn run() -> Result<()> {
e.id = i as u64;
}
let out = OutputRoot { items: entries };
summary_segments_total = out.items.len();
if let Some(path) = output_path {
let base_path = Path::new(&path);
@@ -627,38 +728,55 @@ fn run() -> Result<()> {
}
}
for input_path in &inputs {
for (idx, input_path) in inputs.iter().enumerate() {
let path = Path::new(input_path);
let default_speaker = sanitize_speaker_name(
path.file_stem()
.and_then(|s| s.to_str())
.unwrap_or("speaker"),
);
let speaker =
prompt_speaker_name_for_path(path, &default_speaker, args.set_speaker_names);
let speaker = speakers[idx].clone();
// Collect entries per file
let mut entries: Vec<OutputEntry> = Vec::new();
if is_audio_file(path) {
// Progress log to stderr (suppressed by -q)
polyscribe::ilog!("Processing file: {} ...", path.display());
summary_audio_count += 1;
// Progress log only when multi-bars are not enabled
if !pm.is_enabled() {
polyscribe::ilog!("Processing file: {} ...", path.display());
}
// Prepare per-file progress callback if multi-bars enabled
let mut cb_holder: Option<Box<dyn Fn(i32) + Send + Sync>> = None;
if let Some(pb) = pm.per_bar(idx) {
let pb = pb.clone();
cb_holder = Some(Box::new(move |p: i32| {
let p = p.clamp(0, 100) as u64;
pb.set_position(p);
}));
}
let res = with_quiet_stdio_if_needed(args.quiet, || {
let cb_ref = cb_holder.as_ref().map(|b| &**b as &(dyn Fn(i32) + Send + Sync));
sel.backend
.transcribe(path, &speaker, lang_hint.as_deref(), args.gpu_layers)
.transcribe(path, &speaker, lang_hint.as_deref(), args.gpu_layers, cb_ref)
});
match res {
Ok(items) => {
polyscribe::ilog!("done");
if pm.is_enabled() {
pm.mark_file_done(idx);
} else {
polyscribe::ilog!("done");
}
entries.extend(items);
}
Err(e) => {
if !polyscribe::is_no_interaction() && polyscribe::stdin_is_tty() {
polyscribe::elog!("{:#}", e);
if let Some(pb) = pm.per_bar(idx) {
pb.finish_with_message("error");
}
if !pm.is_enabled() {
if !polyscribe::is_no_interaction() && polyscribe::stdin_is_tty() {
polyscribe::elog!("{:#}", e);
}
}
return Err(e);
}
}
} else if is_json_file(path) {
summary_json_count += 1;
let mut buf = String::new();
File::open(path)
.with_context(|| format!("Failed to open: {input_path}"))?
@@ -696,6 +814,7 @@ fn run() -> Result<()> {
for (i, e) in entries.iter_mut().enumerate() {
e.id = i as u64;
}
summary_segments_total += entries.len();
let out = OutputRoot { items: entries };
if let Some(dir) = &out_dir {
@@ -740,6 +859,20 @@ fn run() -> Result<()> {
}
}
// Final summary (TTY-aware via UI), only when not quiet
if !polyscribe::is_quiet() {
let elapsed = _t0.elapsed();
let secs = elapsed.as_secs_f32();
let mut out = String::new();
out.push_str("Summary:\n");
out.push_str(&format!("{:<12} {:>8}\n", "Files:", summary_inputs_total));
out.push_str(&format!("{:<12} {:>8}\n", "Audio:", summary_audio_count));
out.push_str(&format!("{:<12} {:>8}\n", "JSON:", summary_json_count));
out.push_str(&format!("{:<12} {:>8}\n", "Segments:", summary_segments_total));
out.push_str(&format!("{:<12} {:>8.2}s\n", "Time:", secs));
polyscribe::ui::outro(out);
}
Ok(())
}

View File

@@ -5,7 +5,7 @@
use std::collections::BTreeMap;
use std::env;
use std::fs::{File, create_dir_all};
use std::io::{self, Read, Write};
use std::io::{Read, Write};
use std::path::Path;
use std::time::Duration;
@@ -14,6 +14,8 @@ use reqwest::blocking::Client;
use reqwest::redirect::Policy;
use serde::Deserialize;
use sha2::{Digest, Sha256};
use indicatif::{ProgressBar, ProgressStyle, MultiProgress};
use atty::Stream;
// --- Model downloader: list & download ggml models from Hugging Face ---
@@ -326,8 +328,8 @@ fn fetch_all_models(client: &Client) -> Result<Vec<ModelEntry>> {
match hf_fetch_repo_models(client, "akashmjn/tinydiarize-whisper.cpp") {
Ok(v) => v,
Err(e) => {
ilog!(
"Warning: failed to fetch optional repo akashmjn/tinydiarize-whisper.cpp: {:#}",
wlog!(
"Failed to fetch optional repo akashmjn/tinydiarize-whisper.cpp: {:#}",
e
);
Vec::new()
@@ -393,128 +395,61 @@ fn format_model_list(models: &[ModelEntry]) -> String {
}
fn prompt_select_models_two_stage(models: &[ModelEntry]) -> Result<Vec<ModelEntry>> {
// Replaced by cliclack-based multiselect; keep function to preserve signature but delegate.
prompt_select_models_cliclack(models)
}
fn prompt_select_models_cliclack(models: &[ModelEntry]) -> Result<Vec<ModelEntry>> {
if crate::is_no_interaction() || !crate::stdin_is_tty() {
// Non-interactive: do not prompt, return empty selection to skip
return Ok(Vec::new());
}
// 1) Choose base (tiny, small, medium, etc.)
let mut bases: Vec<String> = Vec::new();
let mut last = String::new();
for m in models.iter() {
if m.base != last {
// models are sorted by base; avoid duplicates while preserving order
if !bases.last().map(|b| b == &m.base).unwrap_or(false) {
bases.push(m.base.clone());
// Build grouped, aligned labels for selection items (include base prefix for grouping).
let mut item_labels: Vec<String> = Vec::new();
let mut item_model_indices: Vec<usize> = Vec::new();
// Compute widths for alignment
let name_width = models.iter().map(|m| m.name.len()).max().unwrap_or(0);
let base_width = models.iter().map(|m| m.base.len()).max().unwrap_or(0);
for (i, m) in models.iter().enumerate() {
let label = format!(
"{base:<bw$}: {name:<nw$} [{repo} | {size}]",
base = m.base,
bw = base_width,
name = m.name,
nw = name_width,
repo = m.repo,
size = human_size(m.size)
);
item_labels.push(label);
item_model_indices.push(i);
}
// Use cliclack multiselect builder with (value, label, help) tuples.
let prompt = "Select Whisper model(s) to download (↑/↓ move, space toggle, enter confirm)";
let mut items: Vec<(usize, String, String)> = Vec::with_capacity(item_labels.len());
for (idx, label) in item_labels.iter().cloned().enumerate() {
items.push((item_model_indices[idx], label, String::new()));
}
match cliclack::multiselect::<usize>(prompt)
.items(&items)
.interact()
{
Ok(selected_indices) => {
let mut chosen: Vec<ModelEntry> = Vec::new();
for mi in selected_indices {
if let Some(m) = models.get(mi) {
chosen.push(m.clone());
}
}
last = m.base.clone();
Ok(chosen)
}
}
if bases.is_empty() {
return Ok(Vec::new());
}
// Print base selection on stderr
eprintln!("Available base model families:");
for (i, b) in bases.iter().enumerate() {
eprintln!(" {}) {}", i + 1, b);
}
loop {
eprint!("Select base (number or name, 'q' to cancel): ");
io::stderr().flush().ok();
let mut line = String::new();
io::stdin()
.read_line(&mut line)
.context("Failed to read base selection")?;
let s = line.trim();
if s.eq_ignore_ascii_case("q")
|| s.eq_ignore_ascii_case("quit")
|| s.eq_ignore_ascii_case("exit")
{
return Ok(Vec::new());
}
let chosen_base = if let Ok(i) = s.parse::<usize>() {
if i >= 1 && i <= bases.len() {
Some(bases[i - 1].clone())
} else {
None
}
} else if !s.is_empty() {
// accept exact name match (case-insensitive)
bases.iter().find(|b| b.eq_ignore_ascii_case(s)).cloned()
} else {
None
};
if let Some(base) = chosen_base {
// 2) Choose sub-type(s) within that base
let filtered: Vec<ModelEntry> =
models.iter().filter(|m| m.base == base).cloned().collect();
if filtered.is_empty() {
eprintln!("No models found for base '{base}'.");
continue;
}
// Reuse the formatter but only for the chosen base list
let listing = format_model_list(&filtered);
eprint!("{listing}");
// Build index map for filtered list
let mut index_map: Vec<usize> = Vec::with_capacity(filtered.len());
let mut idx = 1usize;
for (pos, _m) in filtered.iter().enumerate() {
index_map.push(pos);
idx += 1;
}
// Second prompt: sub-type selection
loop {
eprint!("Selection: ");
io::stderr().flush().ok();
let mut line2 = String::new();
io::stdin()
.read_line(&mut line2)
.context("Failed to read selection")?;
let s2 = line2.trim().to_lowercase();
if s2 == "q" || s2 == "quit" || s2 == "exit" {
return Ok(Vec::new());
}
let mut selected: Vec<usize> = Vec::new();
if s2 == "all" || s2 == "*" {
selected = (1..idx).collect();
} else if !s2.is_empty() {
for part in s2.split([',', ' ', ';']) {
let part = part.trim();
if part.is_empty() {
continue;
}
if let Some((a, b)) = part.split_once('-') {
if let (Ok(ia), Ok(ib)) = (a.parse::<usize>(), b.parse::<usize>()) {
if ia >= 1 && ib < idx && ia <= ib {
selected.extend(ia..=ib);
}
}
} else if let Ok(i) = part.parse::<usize>() {
if i >= 1 && i < idx {
selected.push(i);
}
}
}
}
selected.sort_unstable();
selected.dedup();
if selected.is_empty() {
eprintln!("No valid selection. Please try again or 'q' to cancel.");
continue;
}
let chosen: Vec<ModelEntry> = selected
.into_iter()
.map(|i| filtered[index_map[i - 1]].clone())
.collect();
return Ok(chosen);
}
} else {
eprintln!(
"Invalid base selection. Please enter a number from 1-{} or a base name.",
bases.len()
);
Err(e) => {
// If interaction fails (e.g., not a TTY), return empty to gracefully skip
wlog!("Selection canceled or failed: {}", e);
Ok(Vec::new())
}
}
}
@@ -561,16 +496,55 @@ pub fn run_interactive_model_downloader() -> Result<()> {
qlog!("No selection. Aborting download.");
return Ok(());
}
for m in selected {
if let Err(e) = download_one_model(&client, models_dir, &m) {
elog!("Error: {:#}", e);
// Parallel downloads with bounded concurrency. Default 3; override via POLYSCRIBE_MAX_PARALLEL_DOWNLOADS (1..=6).
let max_jobs = std::env::var("POLYSCRIBE_MAX_PARALLEL_DOWNLOADS")
.ok()
.and_then(|s| s.parse::<usize>().ok())
.map(|n| n.clamp(1, 6))
.unwrap_or(3);
// Use a MultiProgress to render per-model bars concurrently when interactive.
let mp_opt = if !crate::is_quiet() && !crate::is_no_progress() && atty::is(Stream::Stderr) {
Some(MultiProgress::new())
} else {
None
};
let mut i = 0;
while i < selected.len() {
let end = std::cmp::min(i + max_jobs, selected.len());
let mut handles = Vec::new();
for m in selected[i..end].iter().cloned() {
let client2 = client.clone();
let models_dir2 = models_dir.to_path_buf();
let pb_opt = if let Some(mp) = &mp_opt {
let pb = mp.add(ProgressBar::new(m.size));
let style = ProgressStyle::with_template("Downloading {prefix} ({total_bytes}) [{bar:40.cyan/blue}] {bytes}/{total_bytes} ({percent}%)")
.unwrap()
.progress_chars("=>-");
pb.set_style(style);
pb.set_prefix(format!("{}", m.name));
Some(pb)
} else { None };
handles.push(std::thread::spawn(move || {
if let Err(e) = download_one_model(&client2, &models_dir2, &m, pb_opt) {
crate::elog!("Error: {:#}", e);
}
}));
}
for h in handles { let _ = h.join(); }
i = end;
}
// Drop MultiProgress after threads are joined; bars finish naturally.
drop(mp_opt);
Ok(())
}
/// Download a single model entry into the given models directory, verifying SHA-256 when available.
fn download_one_model(client: &Client, models_dir: &Path, entry: &ModelEntry) -> Result<()> {
fn download_one_model(client: &Client, models_dir: &Path, entry: &ModelEntry, pb: Option<indicatif::ProgressBar>) -> Result<()> {
let final_path = models_dir.join(format!("ggml-{}.bin", entry.name));
// If the model already exists, verify against online metadata
@@ -591,8 +565,8 @@ fn download_one_model(client: &Client, models_dir: &Path, entry: &ModelEntry) ->
}
}
Err(e) => {
qlog!(
"Warning: failed to hash existing {}: {}. Will re-download to ensure correctness.",
wlog!(
"Failed to hash existing {}: {}. Will re-download to ensure correctness.",
final_path.display(),
e
);
@@ -618,8 +592,8 @@ fn download_one_model(client: &Client, models_dir: &Path, entry: &ModelEntry) ->
}
}
Err(e) => {
qlog!(
"Warning: failed to stat existing {}: {}. Will re-download to ensure correctness.",
wlog!(
"Failed to stat existing {}: {}. Will re-download to ensure correctness.",
final_path.display(),
e
);
@@ -685,10 +659,13 @@ fn download_one_model(client: &Client, models_dir: &Path, entry: &ModelEntry) ->
url
);
let mut resp = client
.get(url)
.get(url.clone())
.send()
.and_then(|r| r.error_for_status())
.context("Failed to download model")?;
.with_context(|| format!(
"Failed to download model {} from {}. If your terminal has display/TTY issues, try running with --no-progress.",
entry.name, url
))?;
let tmp_path = models_dir.join(format!("ggml-{}.bin.part", entry.name));
if tmp_path.exists() {
@@ -699,21 +676,49 @@ fn download_one_model(client: &Client, models_dir: &Path, entry: &ModelEntry) ->
.with_context(|| format!("Failed to create {}", tmp_path.display()))?,
);
// Set up progress bar: use provided one if present; otherwise create if interactive and we know size
let show_progress = !crate::is_quiet() && !crate::is_no_progress() && atty::is(Stream::Stderr) && entry.size > 0;
let pb_opt = if let Some(p) = pb {
Some(p)
} else if show_progress {
let pb = ProgressBar::new(entry.size);
let style = ProgressStyle::with_template("Downloading {prefix} ({total_bytes}) [{bar:40.cyan/blue}] {bytes}/{total_bytes} ({percent}%)")
.unwrap()
.progress_chars("=>-");
pb.set_style(style);
pb.set_prefix(format!("{}", entry.name));
Some(pb)
} else { None };
let mut hasher = Sha256::new();
let mut downloaded: u64 = 0;
let mut buf = [0u8; 1024 * 128];
let mut read_err: Option<anyhow::Error> = None;
loop {
let n = resp.read(&mut buf).context("Network read error")?;
if n == 0 {
break;
let nres = resp.read(&mut buf);
match nres {
Ok(n) => {
if n == 0 { break; }
hasher.update(&buf[..n]);
if let Err(e) = file.write_all(&buf[..n]) { read_err = Some(anyhow!(e)); break; }
downloaded += n as u64;
if let Some(pb) = &pb_opt { pb.set_position(downloaded.min(entry.size)); }
}
Err(e) => { read_err = Some(anyhow!("Network read error: {}", e)); break; }
}
hasher.update(&buf[..n]);
file.write_all(&buf[..n]).context("Write error")?;
}
file.flush().ok();
if let Some(err) = read_err {
if let Some(pb) = &pb_opt { pb.abandon_with_message("download failed"); }
let _ = std::fs::remove_file(&tmp_path);
return Err(err);
}
let got = to_hex_lower(&hasher.finalize());
if let Some(expected) = &entry.sha256 {
if got != expected.to_lowercase() {
if let Some(pb) = &pb_opt { pb.abandon_with_message("hash mismatch"); }
let _ = std::fs::remove_file(&tmp_path);
return Err(anyhow!(
"SHA-256 mismatch for {}: expected {}, got {}",
@@ -723,8 +728,8 @@ fn download_one_model(client: &Client, models_dir: &Path, entry: &ModelEntry) ->
));
}
} else {
qlog!(
"Warning: no SHA-256 available for {}. Skipping verification.",
wlog!(
"No SHA-256 available for {}. Skipping verification.",
entry.name
);
}
@@ -734,6 +739,7 @@ fn download_one_model(client: &Client, models_dir: &Path, entry: &ModelEntry) ->
}
std::fs::rename(&tmp_path, &final_path)
.with_context(|| format!("Failed to move into place: {}", final_path.display()))?;
if let Some(pb) = &pb_opt { pb.finish_with_message("saved"); }
qlog!("Saved: {}", final_path.display());
Ok(())
}
@@ -811,7 +817,17 @@ pub fn update_local_models() -> Result<()> {
if let Some(remote) = map.get(&model_name) {
// If SHA256 available, verify and update if mismatch
if let Some(expected) = &remote.sha256 {
match compute_file_sha256_hex(&path) {
// Show a small spinner while verifying hash (TTY, not quiet, not no-progress)
let show_spin = !crate::is_quiet() && !crate::is_no_progress() && atty::is(Stream::Stderr);
let spinner = if show_spin {
let pb = ProgressBar::new_spinner();
pb.enable_steady_tick(std::time::Duration::from_millis(100));
pb.set_message(format!("Verifying {}", fname));
Some(pb)
} else { None };
let verify_res = compute_file_sha256_hex(&path);
if let Some(pb) = &spinner { pb.finish_and_clear(); }
match verify_res {
Ok(local_hash) => {
if local_hash.eq_ignore_ascii_case(expected) {
qlog!("{} is up-to-date.", fname);
@@ -826,21 +842,21 @@ pub fn update_local_models() -> Result<()> {
}
}
Err(e) => {
qlog!("Warning: failed hashing {}: {}. Re-downloading.", fname, e);
wlog!("Failed hashing {}: {}. Re-downloading.", fname, e);
}
}
download_one_model(&client, models_dir, remote)?;
download_one_model(&client, models_dir, remote, None)?;
} else if remote.size > 0 {
match std::fs::metadata(&path) {
Ok(md) => {
if qlog_size_comparison(&fname, md.len(), remote.size) {
continue;
}
download_one_model(&client, models_dir, remote)?;
download_one_model(&client, models_dir, remote, None)?;
}
Err(e) => {
qlog!("Warning: stat failed for {}: {}. Updating...", fname, e);
download_one_model(&client, models_dir, remote)?;
wlog!("Stat failed for {}: {}. Updating...", fname, e);
download_one_model(&client, models_dir, remote, None)?;
}
}
} else {
@@ -908,7 +924,7 @@ pub fn ensure_model_available_noninteractive(model_name: &str) -> Result<std::pa
// Prefer fetching metadata to construct a proper ModelEntry
let models = fetch_all_models(&client)?;
if let Some(entry) = models.into_iter().find(|m| m.name == model_name) {
download_one_model(&client, models_dir, &entry)?;
download_one_model(&client, models_dir, &entry, None)?;
return Ok(models_dir.join(format!("ggml-{}.bin", entry.name)));
}
Err(anyhow!(

View File

@@ -0,0 +1,125 @@
// SPDX-License-Identifier: MIT
// Validation and error-handling integration tests
use std::fs;
use std::io::Read;
use std::path::PathBuf;
use std::process::Command;
fn bin() -> &'static str {
env!("CARGO_BIN_EXE_polyscribe")
}
fn manifest_path(relative: &str) -> PathBuf {
let mut p = PathBuf::from(env!("CARGO_MANIFEST_DIR"));
p.push(relative);
p
}
#[test]
fn error_on_missing_input_file() {
let exe = bin();
let missing = manifest_path("input/definitely_missing_123.json");
let out = Command::new(exe)
.arg(missing.as_os_str())
.output()
.expect("failed to run polyscribe with missing input");
assert!(!out.status.success(), "command should fail on missing input");
let stderr = String::from_utf8_lossy(&out.stderr);
assert!(
stderr.contains("Input not found") || stderr.contains("No input files provided"),
"stderr should mention missing input; got: {}",
stderr
);
}
#[test]
fn error_on_directory_as_input() {
let exe = bin();
// Use the repo's input directory which exists and is a directory
let input_dir = manifest_path("input");
let out = Command::new(exe)
.arg(input_dir.as_os_str())
.output()
.expect("failed to run polyscribe with directory input");
assert!(!out.status.success(), "command should fail on dir input");
let stderr = String::from_utf8_lossy(&out.stderr);
assert!(
stderr.contains("directory") || stderr.contains("Unsupported input type"),
"stderr should mention directory/unsupported; got: {}",
stderr
);
}
#[test]
fn error_on_no_ffmpeg_present() {
let exe = bin();
// Create a tiny temp .wav file (may be empty; ffmpeg will be attempted but PATH will be empty)
let tmp_dir = manifest_path("target/tmp/itest_no_ffmpeg");
let _ = fs::remove_dir_all(&tmp_dir);
fs::create_dir_all(&tmp_dir).unwrap();
let wav = tmp_dir.join("dummy.wav");
fs::write(&wav, b"\0\0\0\0").unwrap();
let out = Command::new(exe)
.arg(wav.as_os_str())
.env("PATH", "") // simulate ffmpeg missing
.env_remove("WHISPER_MODEL")
.env("POLYSCRIBE_MODELS_BASE_COPY_DIR", manifest_path("models").as_os_str())
.arg("--language").arg("en")
.output()
.expect("failed to run polyscribe with empty PATH");
assert!(
!out.status.success(),
"command should fail when ffmpeg is not found"
);
let stderr = String::from_utf8_lossy(&out.stderr);
assert!(
stderr.contains("ffmpeg not found") || stderr.contains("Failed to execute ffmpeg"),
"stderr should mention ffmpeg not found; got: {}",
stderr
);
}
#[cfg(unix)]
#[test]
fn error_on_readonly_output_dir() {
use std::os::unix::fs::PermissionsExt;
let exe = bin();
let input1 = manifest_path("input/1-s0wlz.json");
// Prepare a read-only directory
let tmp_dir = manifest_path("target/tmp/itest_readonly_out");
let _ = fs::remove_dir_all(&tmp_dir);
fs::create_dir_all(&tmp_dir).unwrap();
let mut perms = fs::metadata(&tmp_dir).unwrap().permissions();
perms.set_mode(0o555); // read & execute, no write
fs::set_permissions(&tmp_dir, perms).unwrap();
let out = Command::new(exe)
.arg(input1.as_os_str())
.arg("-o")
.arg(tmp_dir.as_os_str())
.output()
.expect("failed to run polyscribe with read-only output dir");
// Restore perms for cleanup
let mut perms2 = fs::metadata(&tmp_dir).unwrap().permissions();
perms2.set_mode(0o755);
let _ = fs::set_permissions(&tmp_dir, perms2);
assert!(
!out.status.success(),
"command should fail when outputs cannot be created"
);
let stderr = String::from_utf8_lossy(&out.stderr);
assert!(
stderr.contains("Failed to create output") || stderr.contains("permission"),
"stderr should mention failure to create outputs; got: {}",
stderr
);
// Cleanup
let _ = fs::remove_dir_all(&tmp_dir);
}