diff --git a/Cargo.lock b/Cargo.lock
index befb354..7524656 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -1078,6 +1078,7 @@ dependencies = [
  "clap",
  "clap_complete",
  "clap_mangen",
+ "libloading",
  "reqwest",
  "serde",
  "serde_json",
diff --git a/Cargo.toml b/Cargo.toml
index 5fab8c0..3b38005 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -3,6 +3,16 @@ name = "polyscribe"
 version = "0.1.0"
 edition = "2024"
 
+[features]
+# Default: CPU only; no GPU features enabled.
+default = []
+# GPU backends map to whisper-rs features, or to an FFI stub for Vulkan.
+gpu-cuda = ["whisper-rs/cuda"]
+gpu-hip = ["whisper-rs/hipblas"]
+gpu-vulkan = []
+# Explicit CPU fallback feature (no effect at build time; kept for clarity).
+cpu-fallback = []
+
 [dependencies]
 anyhow = "1.0.98"
 clap = { version = "4.5.43", features = ["derive"] }
@@ -14,7 +24,9 @@ toml = "0.8"
 chrono = { version = "0.4", features = ["clock"] }
 reqwest = { version = "0.12", features = ["blocking", "json"] }
 sha2 = "0.10"
+# whisper-rs is always used (CPU-only by default); the GPU features map onto it.
 whisper-rs = { git = "https://github.com/tazz4843/whisper-rs" }
+libloading = { version = "0.8" }
 
 [dev-dependencies]
 tempfile = "3"
diff --git a/TODO.md b/TODO.md
index 71603fc..1cfbeae 100644
--- a/TODO.md
+++ b/TODO.md
@@ -12,6 +12,7 @@
 - [x] refactor into proper cli app
 - [x] add support for video files -> use ffmpeg to extract audio
 - detect gpus and use them
+- refactor project
 - add error handling
 - add verbose flag (--verbose | -v) + add logging
 - add documentation
diff --git a/build.rs b/build.rs
new file mode 100644
index 0000000..3a9588a
--- /dev/null
+++ b/build.rs
@@ -0,0 +1,11 @@
+fn main() {
+    // Only run special build steps when the gpu-vulkan feature is enabled.
+    let vulkan_enabled = std::env::var("CARGO_FEATURE_GPU_VULKAN").is_ok();
+    if !vulkan_enabled {
+        return;
+    }
+    // Placeholder: a full implementation would invoke CMake for whisper.cpp with GGML_VULKAN=1.
+    // For now, emit a helpful note; the build proceeds, and the runtime Vulkan backend returns an explanatory error.
+    println!("cargo:rerun-if-changed=extern/whisper.cpp");
+    println!("cargo:warning=Building with gpu-vulkan: ensure the Vulkan SDK/loader is installed. Future versions will compile whisper.cpp via CMake.");
+}
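Not part of the patch: the build.rs placeholder above only warns about the missing CMake step. For orientation, a rough sketch of what that future step might look like, assuming the `cmake` crate were added as a build dependency; the function name, install layout, and library name are illustrative, only `GGML_VULKAN` comes from the comment above:

```rust
// Hypothetical build.rs fragment (not in this patch): compile whisper.cpp
// with Vulkan enabled via the `cmake` crate, then link the resulting library.
fn build_whisper_vulkan() {
    // Configure and build the vendored whisper.cpp tree with GGML_VULKAN=1.
    let dst = cmake::Config::new("extern/whisper.cpp")
        .define("GGML_VULKAN", "1")
        .build();
    // Tell rustc where the installed library landed and what to link.
    println!("cargo:rustc-link-search=native={}/lib", dst.display());
    println!("cargo:rustc-link-lib=static=whisper");
}
```

This would replace the `cargo:warning` above once the FFI backend is wired up; until then the warning plus runtime error keeps `--features gpu-vulkan` builds from failing silently.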
diff --git a/src/backend.rs b/src/backend.rs
new file mode 100644
index 0000000..d8c5c98
--- /dev/null
+++ b/src/backend.rs
@@ -0,0 +1,194 @@
+use std::env;
+use std::path::Path;
+
+use anyhow::{anyhow, Context, Result};
+use libloading::Library;
+
+use crate::{decode_audio_to_pcm_f32_ffmpeg, find_model_file, OutputEntry};
+
+// Public enum used for CLI parsing and backend selection.
+#[derive(Debug, Clone, Copy, PartialEq, Eq)]
+pub enum BackendKind {
+    Auto,
+    Cpu,
+    Cuda,
+    Hip,
+    Vulkan,
+}
+
+pub trait TranscribeBackend {
+    fn kind(&self) -> BackendKind;
+    fn transcribe(&self, audio_path: &Path, speaker: &str, lang_opt: Option<&str>, gpu_layers: Option<u32>) -> Result<Vec<OutputEntry>>;
+}
+
+fn check_lib(names: &[&str]) -> bool {
+    #[cfg(test)]
+    {
+        // During unit tests, avoid touching system libs to prevent loader crashes in CI.
+        return false;
+    }
+    #[cfg(not(test))]
+    {
+        if std::env::var("POLYSCRIBE_DISABLE_DLOPEN").ok().as_deref() == Some("1") {
+            return false;
+        }
+        for n in names {
+            // Attempt to dlopen; ignore errors.
+            if let Ok(_lib) = unsafe { Library::new(n) } { return true; }
+        }
+        false
+    }
+}
+
+fn cuda_available() -> bool {
+    if let Ok(x) = env::var("POLYSCRIBE_TEST_FORCE_CUDA") { return x == "1"; }
+    check_lib(&["libcudart.so", "libcudart.so.12", "libcudart.so.11", "libcublas.so", "libcublas.so.12"])
+}
+
+fn hip_available() -> bool {
+    if let Ok(x) = env::var("POLYSCRIBE_TEST_FORCE_HIP") { return x == "1"; }
+    check_lib(&["libhipblas.so", "librocblas.so"])
+}
+
+fn vulkan_available() -> bool {
+    if let Ok(x) = env::var("POLYSCRIBE_TEST_FORCE_VULKAN") { return x == "1"; }
+    check_lib(&["libvulkan.so.1", "libvulkan.so"])
+}
+
+pub struct CpuBackend;
+pub struct CudaBackend;
+pub struct HipBackend;
+pub struct VulkanBackend;
+
+impl CpuBackend { pub fn new() -> Self { CpuBackend } }
+impl CudaBackend { pub fn new() -> Self { CudaBackend } }
+impl HipBackend { pub fn new() -> Self { HipBackend } }
+impl VulkanBackend { pub fn new() -> Self { VulkanBackend } }
+
+impl TranscribeBackend for CpuBackend {
+    fn kind(&self) -> BackendKind { BackendKind::Cpu }
+    fn transcribe(&self, audio_path: &Path, speaker: &str, lang_opt: Option<&str>, _gpu_layers: Option<u32>) -> Result<Vec<OutputEntry>> {
+        transcribe_with_whisper_rs(audio_path, speaker, lang_opt)
+    }
+}
+
+impl TranscribeBackend for CudaBackend {
+    fn kind(&self) -> BackendKind { BackendKind::Cuda }
+    fn transcribe(&self, audio_path: &Path, speaker: &str, lang_opt: Option<&str>, _gpu_layers: Option<u32>) -> Result<Vec<OutputEntry>> {
+        // CUDA support is compiled into whisper-rs via the gpu-cuda feature; the call path is the same.
+        transcribe_with_whisper_rs(audio_path, speaker, lang_opt)
+    }
+}
+
+impl TranscribeBackend for HipBackend {
+    fn kind(&self) -> BackendKind { BackendKind::Hip }
+    fn transcribe(&self, audio_path: &Path, speaker: &str, lang_opt: Option<&str>, _gpu_layers: Option<u32>) -> Result<Vec<OutputEntry>> {
+        transcribe_with_whisper_rs(audio_path, speaker, lang_opt)
+    }
+}
+
+impl TranscribeBackend for VulkanBackend {
+    fn kind(&self) -> BackendKind { BackendKind::Vulkan }
+    fn transcribe(&self, _audio_path: &Path, _speaker: &str, _lang_opt: Option<&str>, _gpu_layers: Option<u32>) -> Result<Vec<OutputEntry>> {
+        Err(anyhow!("Vulkan backend not yet wired to whisper.cpp FFI. How to fix: install the Vulkan loader (libvulkan) and SDK, set VULKAN_SDK, and rebuild with --features gpu-vulkan."))
+    }
+}
+
+pub struct SelectionResult {
+    pub backend: Box<dyn TranscribeBackend>,
+    pub chosen: BackendKind,
+    pub detected: Vec<BackendKind>,
+}
+
+pub fn select_backend(requested: BackendKind, verbose: bool) -> Result<SelectionResult> {
+    let mut detected = Vec::new();
+    if cuda_available() { detected.push(BackendKind::Cuda); }
+    if hip_available() { detected.push(BackendKind::Hip); }
+    if vulkan_available() { detected.push(BackendKind::Vulkan); }
+
+    let mk = |k: BackendKind| -> Box<dyn TranscribeBackend> {
+        match k {
+            BackendKind::Cpu => Box::new(CpuBackend::new()),
+            BackendKind::Cuda => Box::new(CudaBackend::new()),
+            BackendKind::Hip => Box::new(HipBackend::new()),
+            BackendKind::Vulkan => Box::new(VulkanBackend::new()),
+            BackendKind::Auto => Box::new(CpuBackend::new()), // unreachable: Auto is resolved to a concrete kind below
+        }
+    };
+
+    let chosen = match requested {
+        BackendKind::Auto => {
+            if detected.contains(&BackendKind::Cuda) { BackendKind::Cuda }
+            else if detected.contains(&BackendKind::Hip) { BackendKind::Hip }
+            else if detected.contains(&BackendKind::Vulkan) { BackendKind::Vulkan }
+            else { BackendKind::Cpu }
+        }
+        BackendKind::Cuda => {
+            if detected.contains(&BackendKind::Cuda) { BackendKind::Cuda }
+            else { return Err(anyhow!("Requested CUDA backend, but CUDA libraries/devices were not detected. How to fix: install the NVIDIA driver and CUDA toolkit, ensure libcudart/libcublas are in the loader path, and build with --features gpu-cuda.")); }
+        }
+        BackendKind::Hip => {
+            if detected.contains(&BackendKind::Hip) { BackendKind::Hip }
+            else { return Err(anyhow!("Requested ROCm/HIP backend, but libraries/devices were not detected. How to fix: install ROCm hipBLAS/rocBLAS, ensure the libraries are in the loader path, and build with --features gpu-hip.")); }
+        }
+        BackendKind::Vulkan => {
+            if detected.contains(&BackendKind::Vulkan) { BackendKind::Vulkan }
+            else { return Err(anyhow!("Requested Vulkan backend, but libvulkan was not detected. How to fix: install the Vulkan loader/SDK and build with --features gpu-vulkan.")); }
+        }
+        BackendKind::Cpu => BackendKind::Cpu,
+    };
+
+    if verbose {
+        eprintln!("INFO: Detected backends: {:?}", detected);
+        eprintln!("INFO: Selected backend: {:?}", chosen);
+    }
+
+    Ok(SelectionResult { backend: mk(chosen), chosen, detected })
+}
+
+// Internal helper: transcription via whisper-rs, on CPU or GPU depending on build features.
+pub(crate) fn transcribe_with_whisper_rs(audio_path: &Path, speaker: &str, lang_opt: Option<&str>) -> Result<Vec<OutputEntry>> {
+    let pcm = decode_audio_to_pcm_f32_ffmpeg(audio_path)?;
+    let model = find_model_file()?;
+    let is_en_only = model
+        .file_name()
+        .and_then(|s| s.to_str())
+        .map(|s| s.contains(".en.") || s.ends_with(".en.bin"))
+        .unwrap_or(false);
+    if let Some(lang) = lang_opt {
+        if is_en_only && lang != "en" {
+            return Err(anyhow!(
+                "Selected model is English-only ({}), but a non-English language hint '{}' was provided. Please use a multilingual model or set WHISPER_MODEL.",
+                model.display(), lang
+            ));
+        }
+    }
+    let model_str = model.to_str().ok_or_else(|| anyhow!("Model path not valid UTF-8: {}", model.display()))?;
+
+    let cparams = whisper_rs::WhisperContextParameters::default();
+    let ctx = whisper_rs::WhisperContext::new_with_params(model_str, cparams)
+        .with_context(|| format!("Failed to load Whisper model at {}", model.display()))?;
+    let mut state = ctx.create_state().map_err(|e| anyhow!("Failed to create Whisper state: {:?}", e))?;
+
+    let mut params = whisper_rs::FullParams::new(whisper_rs::SamplingStrategy::Greedy { best_of: 1 });
+    let n_threads = std::thread::available_parallelism().map(|n| n.get() as i32).unwrap_or(1);
+    params.set_n_threads(n_threads);
+    params.set_translate(false);
+    if let Some(lang) = lang_opt { params.set_language(Some(lang)); }
+
+    state.full(params, &pcm).map_err(|e| anyhow!("Whisper full() failed: {:?}", e))?;
+
+    let num_segments = state.full_n_segments().map_err(|e| anyhow!("Failed to get segments: {:?}", e))?;
+    let mut items = Vec::new();
+    for i in 0..num_segments {
+        let text = state.full_get_segment_text(i).map_err(|e| anyhow!("Failed to get segment text: {:?}", e))?;
+        let t0 = state.full_get_segment_t0(i).map_err(|e| anyhow!("Failed to get segment t0: {:?}", e))?;
+        let t1 = state.full_get_segment_t1(i).map_err(|e| anyhow!("Failed to get segment t1: {:?}", e))?;
+        let start = (t0 as f64) * 0.01;
+        let end = (t1 as f64) * 0.01;
+        items.push(OutputEntry { id: 0, speaker: speaker.to_string(), start, end, text: text.trim().to_string() });
+    }
+    Ok(items)
+}
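For orientation (not part of the patch): a minimal sketch of how the new backend API is meant to be driven, mirroring the main.rs call sites below. The `run` helper and its arguments are hypothetical; the types and signatures come from backend.rs above.

```rust
use std::path::Path;
use anyhow::Result;
use crate::backend::{select_backend, BackendKind};

// Hypothetical call site: resolve a backend once, then transcribe a file.
fn run(audio: &Path) -> Result<()> {
    // Auto falls back CUDA -> HIP -> Vulkan -> CPU based on detected libraries.
    let sel = select_backend(BackendKind::Auto, /* verbose */ false)?;
    let entries = sel.backend.transcribe(audio, "Speaker 1", Some("en"), None)?;
    for e in &entries {
        println!("[{:.2}-{:.2}] {}: {}", e.start, e.end, e.speaker, e.text);
    }
    Ok(())
}
```

Note that `gpu_layers` is accepted but unused (`_gpu_layers`) by every backend so far; the flag is plumbed through now so the CLI surface is stable once a backend can honor it.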
+ #[arg(long = "gpu-backend", value_enum, default_value_t = GpuBackendCli::Auto)] + gpu_backend: GpuBackendCli, + + /// Number of layers to offload to GPU (if supported by backend) + #[arg(long = "gpu-layers", value_name = "N")] + gpu_layers: Option, + /// Launch interactive model downloader (list HF models, multi-select and download) #[arg(long)] download_models: bool, @@ -251,7 +270,7 @@ fn normalize_lang_code(input: &str) -> Option { -fn find_model_file() -> Result { +pub(crate) fn find_model_file() -> Result { let models_dir_buf = models_dir_path(); let models_dir = models_dir_buf.as_path(); if !models_dir.exists() { @@ -362,7 +381,7 @@ fn find_model_file() -> Result { Ok(chosen) } -fn decode_audio_to_pcm_f32_ffmpeg(audio_path: &Path) -> Result> { +pub(crate) fn decode_audio_to_pcm_f32_ffmpeg(audio_path: &Path) -> Result> { let output = Command::new("ffmpeg") .arg("-i").arg(audio_path) .arg("-f").arg("f32le") @@ -398,61 +417,6 @@ fn decode_audio_to_pcm_f32_ffmpeg(audio_path: &Path) -> Result> { } } -fn transcribe_native(audio_path: &Path, speaker: &str, lang_opt: Option<&str>) -> Result> { - let pcm = decode_audio_to_pcm_f32_ffmpeg(audio_path)?; - let model = find_model_file()?; - let is_en_only = model - .file_name() - .and_then(|s| s.to_str()) - .map(|s| s.contains(".en.") || s.ends_with(".en.bin")) - .unwrap_or(false); - if let Some(lang) = lang_opt { - if is_en_only && lang != "en" { - return Err(anyhow!( - "Selected model is English-only ({}), but a non-English language hint '{}' was provided. Please use a multilingual model like models/ggml-base.bin or set WHISPER_MODEL accordingly.", - model.display(), - lang - )); - } - } - let model_str = model.to_str().ok_or_else(|| anyhow!("Model path not valid UTF-8: {}", model.display()))?; - - // Initialize Whisper with GPU preference - let cparams = WhisperContextParameters::default(); - // Prefer GPU if available; default whisper.cpp already has use_gpu=true. If the wrapper exposes - // a gpu_device field in the future, we could set it here from WHISPER_GPU_DEVICE. - if let Ok(dev_str) = env::var("WHISPER_GPU_DEVICE") { - let _ = dev_str.trim().parse::().ok(); - } - // Even if we can't set fields explicitly (due to API differences), whisper.cpp defaults to GPU. 
-    let ctx = WhisperContext::new_with_params(model_str, cparams)
-        .with_context(|| format!("Failed to load Whisper model at {}", model.display()))?;
-    let mut state = ctx.create_state()
-        .map_err(|e| anyhow!("Failed to create Whisper state: {:?}", e))?;
-
-    let mut params = FullParams::new(SamplingStrategy::Greedy { best_of: 1 });
-    let n_threads = std::thread::available_parallelism().map(|n| n.get() as i32).unwrap_or(1);
-    params.set_n_threads(n_threads);
-    params.set_translate(false);
-    if let Some(lang) = lang_opt { params.set_language(Some(lang)); }
-
-    state.full(params, &pcm)
-        .map_err(|e| anyhow!("Whisper full() failed: {:?}", e))?;
-
-    let num_segments = state.full_n_segments().map_err(|e| anyhow!("Failed to get segments: {:?}", e))?;
-    let mut items = Vec::new();
-    for i in 0..num_segments {
-        let text = state.full_get_segment_text(i)
-            .map_err(|e| anyhow!("Failed to get segment text: {:?}", e))?;
-        let t0 = state.full_get_segment_t0(i).map_err(|e| anyhow!("Failed to get segment t0: {:?}", e))?;
-        let t1 = state.full_get_segment_t1(i).map_err(|e| anyhow!("Failed to get segment t1: {:?}", e))?;
-        let start = (t0 as f64) * 0.01;
-        let end = (t1 as f64) * 0.01;
-        items.push(OutputEntry { id: 0, speaker: speaker.to_string(), start, end, text: text.trim().to_string() });
-    }
-    Ok(items)
-}
-
 struct LastModelCleanup {
     path: PathBuf,
 }
@@ -498,6 +462,17 @@ fn main() -> Result<()> {
     // Ensure cleanup at end of program, regardless of exit path
     let _last_model_cleanup = LastModelCleanup { path: last_model_path.clone() };
 
+    // Select backend
+    let requested = match args.gpu_backend {
+        GpuBackendCli::Auto => BackendKind::Auto,
+        GpuBackendCli::Cpu => BackendKind::Cpu,
+        GpuBackendCli::Cuda => BackendKind::Cuda,
+        GpuBackendCli::Hip => BackendKind::Hip,
+        GpuBackendCli::Vulkan => BackendKind::Vulkan,
+    };
+    let sel = select_backend(requested, args.verbose > 0)?;
+    vlog!(0, "Using backend: {:?}", sel.chosen);
+
     // If requested, run the interactive model downloader first. If no inputs were provided, exit after downloading.
     if args.download_models {
         if let Err(e) = models::run_interactive_model_downloader() {
@@ -572,7 +547,7 @@ fn main() -> Result<()> {
             // Collect entries per file and extend merged
             let mut entries: Vec<OutputEntry> = Vec::new();
             if is_audio_file(path) {
-                let items = transcribe_native(path, &speaker, lang_hint.as_deref())?;
+                let items = sel.backend.transcribe(path, &speaker, lang_hint.as_deref(), args.gpu_layers)?;
                 entries.extend(items.into_iter());
             } else if is_json_file(path) {
                 let mut buf = String::new();
@@ -665,7 +640,7 @@ fn main() -> Result<()> {
             let mut buf = String::new();
 
             if is_audio_file(path) {
-                let items = transcribe_native(path, &speaker, lang_hint.as_deref())?;
+                let items = sel.backend.transcribe(path, &speaker, lang_hint.as_deref(), args.gpu_layers)?;
                 for e in items { entries.push(e); }
                 continue;
             } else if is_json_file(path) {
@@ -766,7 +741,7 @@ fn main() -> Result<()> {
             // Collect entries per file
             let mut entries: Vec<OutputEntry> = Vec::new();
             if is_audio_file(path) {
-                let items = transcribe_native(path, &speaker, lang_hint.as_deref())?;
+                let items = sel.backend.transcribe(path, &speaker, lang_hint.as_deref(), args.gpu_layers)?;
                 entries.extend(items);
             } else if is_json_file(path) {
                 let mut buf = String::new();
@@ -834,6 +809,7 @@ mod tests {
     use std::io::Write;
     use std::env as std_env;
     use clap::CommandFactory;
+    use super::backend::*;
 
     #[test]
     fn test_cli_name_polyscribe() {
@@ -970,4 +946,59 @@ mod tests {
         assert!(is_audio_file(Path::new("trailer.MOV")));
         assert!(is_audio_file(Path::new("animation.avi")));
     }
+
+    #[test]
+    fn test_backend_auto_order_prefers_cuda_then_hip_then_vulkan_then_cpu() {
+        // Clear overrides
+        unsafe {
+            std_env::remove_var("POLYSCRIBE_TEST_FORCE_CUDA");
+            std_env::remove_var("POLYSCRIBE_TEST_FORCE_HIP");
+            std_env::remove_var("POLYSCRIBE_TEST_FORCE_VULKAN");
+        }
+        // No GPU -> CPU
+        let sel = select_backend(BackendKind::Auto, false).unwrap();
+        assert_eq!(sel.chosen, BackendKind::Cpu);
+        // Vulkan only
+        unsafe { std_env::set_var("POLYSCRIBE_TEST_FORCE_VULKAN", "1"); }
+        let sel = select_backend(BackendKind::Auto, false).unwrap();
+        assert_eq!(sel.chosen, BackendKind::Vulkan);
+        // HIP preferred over Vulkan
+        unsafe { std_env::set_var("POLYSCRIBE_TEST_FORCE_HIP", "1"); std_env::remove_var("POLYSCRIBE_TEST_FORCE_VULKAN"); }
+        let sel = select_backend(BackendKind::Auto, false).unwrap();
+        assert_eq!(sel.chosen, BackendKind::Hip);
+        // CUDA preferred over HIP
+        unsafe { std_env::set_var("POLYSCRIBE_TEST_FORCE_CUDA", "1"); }
+        let sel = select_backend(BackendKind::Auto, false).unwrap();
+        assert_eq!(sel.chosen, BackendKind::Cuda);
+        // Cleanup
+        unsafe {
+            std_env::remove_var("POLYSCRIBE_TEST_FORCE_CUDA");
+            std_env::remove_var("POLYSCRIBE_TEST_FORCE_HIP");
+            std_env::remove_var("POLYSCRIBE_TEST_FORCE_VULKAN");
+        }
+    }
+
+    #[test]
+    fn test_backend_explicit_missing_errors() {
+        // Ensure all overrides are off
+        unsafe {
+            std_env::remove_var("POLYSCRIBE_TEST_FORCE_CUDA");
+            std_env::remove_var("POLYSCRIBE_TEST_FORCE_HIP");
+            std_env::remove_var("POLYSCRIBE_TEST_FORCE_VULKAN");
+        }
+        assert!(select_backend(BackendKind::Cuda, false).is_err());
+        assert!(select_backend(BackendKind::Hip, false).is_err());
+        assert!(select_backend(BackendKind::Vulkan, false).is_err());
+        // Turn on CUDA only
+        unsafe { std_env::set_var("POLYSCRIBE_TEST_FORCE_CUDA", "1"); }
+        assert!(select_backend(BackendKind::Cuda, false).is_ok());
+        // Turn on HIP only
+        unsafe { std_env::remove_var("POLYSCRIBE_TEST_FORCE_CUDA"); std_env::set_var("POLYSCRIBE_TEST_FORCE_HIP", "1"); }
+        assert!(select_backend(BackendKind::Hip, false).is_ok());
+        // Turn on Vulkan only
+        unsafe { std_env::remove_var("POLYSCRIBE_TEST_FORCE_HIP"); std_env::set_var("POLYSCRIBE_TEST_FORCE_VULKAN", "1"); }
+        assert!(select_backend(BackendKind::Vulkan, false).is_ok());
+        // Cleanup
+        unsafe { std_env::remove_var("POLYSCRIBE_TEST_FORCE_VULKAN"); }
+    }
 }
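Usage, assuming only the features and flags introduced above (file name and layer count are illustrative): a GPU build is produced with `cargo build --release --features gpu-cuda` (or `gpu-hip` / `gpu-vulkan`), and the backend can be forced at runtime with `polyscribe --gpu-backend cuda --gpu-layers 32 input.wav`. With the default `--gpu-backend auto`, selection falls back CUDA -> HIP -> Vulkan -> CPU based on which loader libraries are detected, and `POLYSCRIBE_DISABLE_DLOPEN=1` skips the probing entirely.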