[feat] add GPU backend support with runtime selection; refactor transcription logic; update CLI and tests

2025-08-08 16:19:02 +02:00
parent 7a6a313107
commit bc8bbdc381
6 changed files with 312 additions and 62 deletions


@@ -11,9 +11,10 @@ use chrono::Local;
use std::sync::atomic::{AtomicBool, AtomicU8, Ordering};
use clap_complete::Shell;
use whisper_rs::{FullParams, SamplingStrategy, WhisperContext, WhisperContextParameters};
// whisper-rs is used in backend module
mod models;
mod backend;
use backend::{BackendKind, select_backend, TranscribeBackend};
static LAST_MODEL_WRITTEN: AtomicBool = AtomicBool::new(false);
static VERBOSE: AtomicU8 = AtomicU8::new(0);
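The backend module itself is added in a separate file that is not part of this hunk. From the call sites later in this diff (select_backend(requested, verbose), sel.chosen, and sel.backend.transcribe(path, speaker, lang, gpu_layers)), its public surface is roughly the following sketch; the struct name BackendSelection and everything beyond the three imported items are assumptions:

// Hypothetical sketch of src/backend.rs as implied by this diff; the real
// module may differ in naming and detail.
use std::path::Path;
use anyhow::Result;

#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum BackendKind { Auto, Cpu, Cuda, Hip, Vulkan }

pub trait TranscribeBackend {
    /// Transcribe one audio file; `gpu_layers` is a hint for backends that can
    /// offload layers to the GPU and is ignored by the CPU path.
    fn transcribe(
        &self,
        audio_path: &Path,
        speaker: &str,
        lang_opt: Option<&str>,
        gpu_layers: Option<u32>,
    ) -> Result<Vec<crate::OutputEntry>>;
}

/// What select_backend hands back to main(): the kind that was actually
/// chosen plus a boxed implementation to run transcription with.
pub struct BackendSelection {
    pub chosen: BackendKind,
    pub backend: Box<dyn TranscribeBackend>,
}

// pub fn select_backend(requested: BackendKind, verbose: bool) -> Result<BackendSelection>
// is sketched after its call site in main() below.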
@@ -82,6 +83,16 @@ enum AuxCommands {
Man,
}
#[derive(clap::ValueEnum, Debug, Clone, Copy)]
#[value(rename_all = "kebab-case")]
enum GpuBackendCli {
Auto,
Cpu,
Cuda,
Hip,
Vulkan,
}
#[derive(Parser, Debug)]
#[command(name = "PolyScribe", bin_name = "polyscribe", version, about = "Merge JSON transcripts or transcribe audio using native whisper")]
struct Args {
@@ -112,6 +123,14 @@ struct Args {
#[arg(short, long, value_name = "LANG")]
language: Option<String>,
/// Choose GPU backend at runtime (auto|cpu|cuda|hip|vulkan). Default: auto.
#[arg(long = "gpu-backend", value_enum, default_value_t = GpuBackendCli::Auto)]
gpu_backend: GpuBackendCli,
/// Number of layers to offload to GPU (if supported by backend)
#[arg(long = "gpu-layers", value_name = "N")]
gpu_layers: Option<u32>,
/// Launch interactive model downloader (list HF models, multi-select and download)
#[arg(long)]
download_models: bool,
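Hypothetical invocations of the new flags (file names are illustrative, and passing audio files as positional inputs is assumed from context rather than shown in this hunk):

polyscribe --gpu-backend cuda --gpu-layers 32 --language en interview.wav
polyscribe --gpu-backend auto interview.wav    # auto: CUDA, then HIP, then Vulkan, then CPU
polyscribe --gpu-backend vulkan interview.wav  # errors if the Vulkan backend is unavailable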
@@ -251,7 +270,7 @@ fn normalize_lang_code(input: &str) -> Option<String> {
fn find_model_file() -> Result<PathBuf> {
pub(crate) fn find_model_file() -> Result<PathBuf> {
let models_dir_buf = models_dir_path();
let models_dir = models_dir_buf.as_path();
if !models_dir.exists() {
@@ -362,7 +381,7 @@ fn find_model_file() -> Result<PathBuf> {
Ok(chosen)
}
fn decode_audio_to_pcm_f32_ffmpeg(audio_path: &Path) -> Result<Vec<f32>> {
pub(crate) fn decode_audio_to_pcm_f32_ffmpeg(audio_path: &Path) -> Result<Vec<f32>> {
let output = Command::new("ffmpeg")
.arg("-i").arg(audio_path)
.arg("-f").arg("f32le")
@@ -398,61 +417,6 @@ fn decode_audio_to_pcm_f32_ffmpeg(audio_path: &Path) -> Result<Vec<f32>> {
}
}
fn transcribe_native(audio_path: &Path, speaker: &str, lang_opt: Option<&str>) -> Result<Vec<OutputEntry>> {
let pcm = decode_audio_to_pcm_f32_ffmpeg(audio_path)?;
let model = find_model_file()?;
let is_en_only = model
.file_name()
.and_then(|s| s.to_str())
.map(|s| s.contains(".en.") || s.ends_with(".en.bin"))
.unwrap_or(false);
if let Some(lang) = lang_opt {
if is_en_only && lang != "en" {
return Err(anyhow!(
"Selected model is English-only ({}), but a non-English language hint '{}' was provided. Please use a multilingual model like models/ggml-base.bin or set WHISPER_MODEL accordingly.",
model.display(),
lang
));
}
}
let model_str = model.to_str().ok_or_else(|| anyhow!("Model path not valid UTF-8: {}", model.display()))?;
// Initialize Whisper with GPU preference
let cparams = WhisperContextParameters::default();
// Prefer GPU if available; default whisper.cpp already has use_gpu=true. If the wrapper exposes
// a gpu_device field in the future, we could set it here from WHISPER_GPU_DEVICE.
if let Ok(dev_str) = env::var("WHISPER_GPU_DEVICE") {
let _ = dev_str.trim().parse::<i32>().ok();
}
// Even if we can't set fields explicitly (due to API differences), whisper.cpp defaults to GPU.
let ctx = WhisperContext::new_with_params(model_str, cparams)
.with_context(|| format!("Failed to load Whisper model at {}", model.display()))?;
let mut state = ctx.create_state()
.map_err(|e| anyhow!("Failed to create Whisper state: {:?}", e))?;
let mut params = FullParams::new(SamplingStrategy::Greedy { best_of: 1 });
let n_threads = std::thread::available_parallelism().map(|n| n.get() as i32).unwrap_or(1);
params.set_n_threads(n_threads);
params.set_translate(false);
if let Some(lang) = lang_opt { params.set_language(Some(lang)); }
state.full(params, &pcm)
.map_err(|e| anyhow!("Whisper full() failed: {:?}", e))?;
let num_segments = state.full_n_segments().map_err(|e| anyhow!("Failed to get segments: {:?}", e))?;
let mut items = Vec::new();
for i in 0..num_segments {
let text = state.full_get_segment_text(i)
.map_err(|e| anyhow!("Failed to get segment text: {:?}", e))?;
let t0 = state.full_get_segment_t0(i).map_err(|e| anyhow!("Failed to get segment t0: {:?}", e))?;
let t1 = state.full_get_segment_t1(i).map_err(|e| anyhow!("Failed to get segment t1: {:?}", e))?;
let start = (t0 as f64) * 0.01;
let end = (t1 as f64) * 0.01;
items.push(OutputEntry { id: 0, speaker: speaker.to_string(), start, end, text: text.trim().to_string() });
}
Ok(items)
}
struct LastModelCleanup {
path: PathBuf,
}
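The transcribe_native() body deleted above is not re-added anywhere in this file, so it presumably moves into the backend module; the same commit also makes find_model_file and decode_audio_to_pcm_f32_ffmpeg pub(crate), which fits that reading. A sketch of a CPU implementation of the new trait wrapping the same whisper-rs pipeline follows; the CpuBackend name is an assumption, and the English-only-model check from the old function is omitted for brevity but would carry over as well:

// Hypothetical continuation of src/backend.rs: the removed pipeline re-homed
// behind the TranscribeBackend trait sketched earlier.
use std::path::Path;
use anyhow::{anyhow, Context, Result};
use whisper_rs::{FullParams, SamplingStrategy, WhisperContext, WhisperContextParameters};
use crate::{decode_audio_to_pcm_f32_ffmpeg, find_model_file, OutputEntry};

pub struct CpuBackend;

impl TranscribeBackend for CpuBackend {
    fn transcribe(
        &self,
        audio_path: &Path,
        speaker: &str,
        lang_opt: Option<&str>,
        _gpu_layers: Option<u32>, // no offload on the CPU path
    ) -> Result<Vec<OutputEntry>> {
        // Same steps as the deleted transcribe_native(): decode via ffmpeg,
        // load the model, run full(), collect segments into OutputEntry items.
        let pcm = decode_audio_to_pcm_f32_ffmpeg(audio_path)?;
        let model = find_model_file()?;
        let model_str = model
            .to_str()
            .ok_or_else(|| anyhow!("Model path not valid UTF-8: {}", model.display()))?;
        let ctx = WhisperContext::new_with_params(model_str, WhisperContextParameters::default())
            .with_context(|| format!("Failed to load Whisper model at {}", model.display()))?;
        let mut state = ctx
            .create_state()
            .map_err(|e| anyhow!("Failed to create Whisper state: {:?}", e))?;
        let mut params = FullParams::new(SamplingStrategy::Greedy { best_of: 1 });
        let n_threads = std::thread::available_parallelism().map(|n| n.get() as i32).unwrap_or(1);
        params.set_n_threads(n_threads);
        params.set_translate(false);
        if let Some(lang) = lang_opt {
            params.set_language(Some(lang));
        }
        state.full(params, &pcm)
            .map_err(|e| anyhow!("Whisper full() failed: {:?}", e))?;
        let num_segments = state.full_n_segments()
            .map_err(|e| anyhow!("Failed to get segments: {:?}", e))?;
        let mut items = Vec::new();
        for i in 0..num_segments {
            let text = state.full_get_segment_text(i)
                .map_err(|e| anyhow!("Failed to get segment text: {:?}", e))?;
            let t0 = state.full_get_segment_t0(i).map_err(|e| anyhow!("Failed to get segment t0: {:?}", e))?;
            let t1 = state.full_get_segment_t1(i).map_err(|e| anyhow!("Failed to get segment t1: {:?}", e))?;
            items.push(OutputEntry {
                id: 0,
                speaker: speaker.to_string(),
                start: t0 as f64 * 0.01,
                end: t1 as f64 * 0.01,
                text: text.trim().to_string(),
            });
        }
        Ok(items)
    }
}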
@@ -498,6 +462,17 @@ fn main() -> Result<()> {
// Ensure cleanup at end of program, regardless of exit path
let _last_model_cleanup = LastModelCleanup { path: last_model_path.clone() };
// Select backend
let requested = match args.gpu_backend {
GpuBackendCli::Auto => BackendKind::Auto,
GpuBackendCli::Cpu => BackendKind::Cpu,
GpuBackendCli::Cuda => BackendKind::Cuda,
GpuBackendCli::Hip => BackendKind::Hip,
GpuBackendCli::Vulkan => BackendKind::Vulkan,
};
let sel = select_backend(requested, args.verbose > 0)?;
vlog!(0, "Using backend: {:?}", sel.chosen);
// If requested, run the interactive model downloader first. If no inputs were provided, exit after downloading.
if args.download_models {
if let Err(e) = models::run_interactive_model_downloader() {
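select_backend is defined in the backend module, which this section does not show. A sketch consistent with this call site and with the tests added at the bottom of this file (Auto prefers CUDA, then HIP, then Vulkan, then CPU; an explicitly requested GPU backend errors when unavailable; POLYSCRIBE_TEST_FORCE_* fakes availability). The forced() helper is an assumption, and real hardware probing outside of tests, as well as the GPU backend constructors, are not covered here:

// Hypothetical continuation of src/backend.rs, reusing BackendKind,
// BackendSelection, TranscribeBackend, and CpuBackend from the earlier sketches.
use anyhow::{anyhow, Result};

fn forced(var: &str) -> bool {
    // The tests drive availability solely through these env vars.
    std::env::var(var).map(|v| v == "1").unwrap_or(false)
}

pub fn select_backend(requested: BackendKind, verbose: bool) -> Result<BackendSelection> {
    let cuda = forced("POLYSCRIBE_TEST_FORCE_CUDA");
    let hip = forced("POLYSCRIBE_TEST_FORCE_HIP");
    let vulkan = forced("POLYSCRIBE_TEST_FORCE_VULKAN");

    let chosen = match requested {
        BackendKind::Auto => {
            // Preference order pinned down by the tests: CUDA > HIP > Vulkan > CPU.
            if cuda { BackendKind::Cuda }
            else if hip { BackendKind::Hip }
            else if vulkan { BackendKind::Vulkan }
            else { BackendKind::Cpu }
        }
        BackendKind::Cuda if !cuda => return Err(anyhow!("CUDA backend requested but not available")),
        BackendKind::Hip if !hip => return Err(anyhow!("HIP backend requested but not available")),
        BackendKind::Vulkan if !vulkan => return Err(anyhow!("Vulkan backend requested but not available")),
        other => other,
    };
    if verbose {
        eprintln!("backend selection: requested {:?}, chosen {:?}", requested, chosen);
    }
    let backend: Box<dyn TranscribeBackend> = match chosen {
        BackendKind::Cpu => Box::new(CpuBackend),
        // CUDA/HIP/Vulkan implementations are not sketched here; each kind
        // would get its own TranscribeBackend implementation in the real module.
        _ => Box::new(CpuBackend),
    };
    Ok(BackendSelection { chosen, backend })
}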
@@ -572,7 +547,7 @@ fn main() -> Result<()> {
// Collect entries per file and extend merged
let mut entries: Vec<OutputEntry> = Vec::new();
if is_audio_file(path) {
let items = transcribe_native(path, &speaker, lang_hint.as_deref())?;
let items = sel.backend.transcribe(path, &speaker, lang_hint.as_deref(), args.gpu_layers)?;
entries.extend(items.into_iter());
} else if is_json_file(path) {
let mut buf = String::new();
@@ -665,7 +640,7 @@ fn main() -> Result<()> {
let mut buf = String::new();
if is_audio_file(path) {
let items = transcribe_native(path, &speaker, lang_hint.as_deref())?;
let items = sel.backend.transcribe(path, &speaker, lang_hint.as_deref(), args.gpu_layers)?;
for e in items { entries.push(e); }
continue;
} else if is_json_file(path) {
@@ -766,7 +741,7 @@ fn main() -> Result<()> {
// Collect entries per file
let mut entries: Vec<OutputEntry> = Vec::new();
if is_audio_file(path) {
let items = transcribe_native(path, &speaker, lang_hint.as_deref())?;
let items = sel.backend.transcribe(path, &speaker, lang_hint.as_deref(), args.gpu_layers)?;
entries.extend(items);
} else if is_json_file(path) {
let mut buf = String::new();
@@ -834,6 +809,7 @@ mod tests {
use std::io::Write;
use std::env as std_env;
use clap::CommandFactory;
use super::backend::*;
#[test]
fn test_cli_name_polyscribe() {
@@ -970,4 +946,59 @@ mod tests {
assert!(is_audio_file(Path::new("trailer.MOV")));
assert!(is_audio_file(Path::new("animation.avi")));
}
#[test]
fn test_backend_auto_order_prefers_cuda_then_hip_then_vulkan_then_cpu() {
// Clear overrides
unsafe {
std_env::remove_var("POLYSCRIBE_TEST_FORCE_CUDA");
std_env::remove_var("POLYSCRIBE_TEST_FORCE_HIP");
std_env::remove_var("POLYSCRIBE_TEST_FORCE_VULKAN");
}
// No GPU -> CPU
let sel = select_backend(BackendKind::Auto, false).unwrap();
assert_eq!(sel.chosen, BackendKind::Cpu);
// Vulkan only
unsafe { std_env::set_var("POLYSCRIBE_TEST_FORCE_VULKAN", "1"); }
let sel = select_backend(BackendKind::Auto, false).unwrap();
assert_eq!(sel.chosen, BackendKind::Vulkan);
// HIP preferred over Vulkan
unsafe { std_env::set_var("POLYSCRIBE_TEST_FORCE_HIP", "1"); std_env::remove_var("POLYSCRIBE_TEST_FORCE_VULKAN"); }
let sel = select_backend(BackendKind::Auto, false).unwrap();
assert_eq!(sel.chosen, BackendKind::Hip);
// CUDA preferred over HIP
unsafe { std_env::set_var("POLYSCRIBE_TEST_FORCE_CUDA", "1"); }
let sel = select_backend(BackendKind::Auto, false).unwrap();
assert_eq!(sel.chosen, BackendKind::Cuda);
// Cleanup
unsafe {
std_env::remove_var("POLYSCRIBE_TEST_FORCE_CUDA");
std_env::remove_var("POLYSCRIBE_TEST_FORCE_HIP");
std_env::remove_var("POLYSCRIBE_TEST_FORCE_VULKAN");
}
}
#[test]
fn test_backend_explicit_missing_errors() {
// Ensure all off
unsafe {
std_env::remove_var("POLYSCRIBE_TEST_FORCE_CUDA");
std_env::remove_var("POLYSCRIBE_TEST_FORCE_HIP");
std_env::remove_var("POLYSCRIBE_TEST_FORCE_VULKAN");
}
assert!(select_backend(BackendKind::Cuda, false).is_err());
assert!(select_backend(BackendKind::Hip, false).is_err());
assert!(select_backend(BackendKind::Vulkan, false).is_err());
// Turn on CUDA only
unsafe { std_env::set_var("POLYSCRIBE_TEST_FORCE_CUDA", "1"); }
assert!(select_backend(BackendKind::Cuda, false).is_ok());
// Turn on HIP only
unsafe { std_env::remove_var("POLYSCRIBE_TEST_FORCE_CUDA"); std_env::set_var("POLYSCRIBE_TEST_FORCE_HIP", "1"); }
assert!(select_backend(BackendKind::Hip, false).is_ok());
// Turn on Vulkan only
unsafe { std_env::remove_var("POLYSCRIBE_TEST_FORCE_HIP"); std_env::set_var("POLYSCRIBE_TEST_FORCE_VULKAN", "1"); }
assert!(select_backend(BackendKind::Vulkan, false).is_ok());
// Cleanup
unsafe { std_env::remove_var("POLYSCRIBE_TEST_FORCE_VULKAN"); }
}
}