[feat] add GPU backend support with runtime selection; refactor transcription logic; update CLI and tests
Cargo.lock (generated): 1 line changed
@@ -1078,6 +1078,7 @@ dependencies = [
 "clap",
 "clap_complete",
 "clap_mangen",
 "libloading",
 "reqwest",
 "serde",
 "serde_json",
Cargo.toml: 12 lines changed
@@ -3,6 +3,16 @@ name = "polyscribe"
version = "0.1.0"
edition = "2024"

[features]
# Default: CPU only; no GPU features enabled
default = []
# GPU backends map to whisper-rs features or FFI stub for Vulkan
gpu-cuda = ["whisper-rs/cuda"]
gpu-hip = ["whisper-rs/hipblas"]
gpu-vulkan = []
# explicit CPU fallback feature (no effect at build time, used for clarity)
cpu-fallback = []

[dependencies]
anyhow = "1.0.98"
clap = { version = "4.5.43", features = ["derive"] }

@@ -14,7 +24,9 @@ toml = "0.8"
chrono = { version = "0.4", features = ["clock"] }
reqwest = { version = "0.12", features = ["blocking", "json"] }
sha2 = "0.10"
# whisper-rs is always used (CPU-only by default); GPU features map onto it
whisper-rs = { git = "https://github.com/tazz4843/whisper-rs" }
libloading = { version = "0.8" }

[dev-dependencies]
tempfile = "3"
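Not part of this commit: a small illustrative sketch of how code could report which of the Cargo features above a binary was compiled with. The helper name is hypothetical and nothing in this diff adds it; at runtime the actual choice is made by select_backend in src/backend.rs.

// Hypothetical helper: cfg!(...) is resolved at compile time against the [features] above.
pub fn compiled_gpu_features() -> Vec<&'static str> {
    let mut v = Vec::new();
    if cfg!(feature = "gpu-cuda") { v.push("gpu-cuda"); }
    if cfg!(feature = "gpu-hip") { v.push("gpu-hip"); }
    if cfg!(feature = "gpu-vulkan") { v.push("gpu-vulkan"); }
    v
}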
TODO.md: 1 line changed
@@ -12,6 +12,7 @@
- [x] refactor into proper cli app
- [x] add support for video files -> use ffmpeg to extract audio
- detect gpus and use them
- refactor project
- add error handling
- add verbose flag (--verbose | -v) + add logging
- add documentation
build.rs (new file): 11 lines
@@ -0,0 +1,11 @@
fn main() {
    // Only run special build steps when the gpu-vulkan feature is enabled.
    let vulkan_enabled = std::env::var("CARGO_FEATURE_GPU_VULKAN").is_ok();
    if !vulkan_enabled {
        return;
    }
    // Placeholder: In a full implementation, we would invoke CMake for whisper.cpp with GGML_VULKAN=1.
    // For now, emit a helpful note. Build will proceed; the runtime Vulkan backend returns an explanatory error.
    println!("cargo:rerun-if-changed=extern/whisper.cpp");
    println!("cargo:warning=Building with gpu-vulkan: ensure Vulkan SDK/loader are installed. Future versions will compile whisper.cpp via CMake.");
}
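Not part of this commit: a minimal sketch of what the placeholder above hints at, assuming whisper.cpp is vendored at extern/whisper.cpp and the cmake crate is added under [build-dependencies]. Only GGML_VULKAN=1 is mentioned in the diff; the linked library name below is an assumption about whisper.cpp's CMake output.

// Hypothetical future build step (assumes `cmake = "0.1"` in [build-dependencies]).
fn build_whisper_cpp_with_vulkan() {
    // Configure and build the vendored whisper.cpp with the Vulkan backend enabled.
    let dst = cmake::Config::new("extern/whisper.cpp")
        .define("GGML_VULKAN", "1")
        .build();
    // Tell rustc where the resulting libraries live; "whisper" as the library name is an assumption.
    println!("cargo:rustc-link-search=native={}/lib", dst.display());
    println!("cargo:rustc-link-lib=static=whisper");
}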
src/backend.rs (new file): 194 lines
@@ -0,0 +1,194 @@
use std::path::Path;
use anyhow::{anyhow, Context, Result};
use libloading::Library;
use crate::{OutputEntry};
use crate::{decode_audio_to_pcm_f32_ffmpeg, find_model_file};
use std::env;

// Re-export a public enum for CLI parsing usage
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum BackendKind {
    Auto,
    Cpu,
    Cuda,
    Hip,
    Vulkan,
}

pub trait TranscribeBackend {
    fn kind(&self) -> BackendKind;
    fn transcribe(&self, audio_path: &Path, speaker: &str, lang_opt: Option<&str>, gpu_layers: Option<u32>) -> Result<Vec<OutputEntry>>;
}

fn check_lib(names: &[&str]) -> bool {
    #[cfg(test)]
    {
        // During unit tests, avoid touching system libs to prevent loader crashes in CI.
        return false;
    }
    #[cfg(not(test))]
    {
        if std::env::var("POLYSCRIBE_DISABLE_DLOPEN").ok().as_deref() == Some("1") {
            return false;
        }
        for n in names {
            // Attempt to dlopen; ignore errors
            if let Ok(_lib) = unsafe { Library::new(n) } { return true; }
        }
        false
    }
}

fn cuda_available() -> bool {
    if let Ok(x) = env::var("POLYSCRIBE_TEST_FORCE_CUDA") { return x == "1"; }
    check_lib(&["libcudart.so", "libcudart.so.12", "libcudart.so.11", "libcublas.so", "libcublas.so.12"])
}

fn hip_available() -> bool {
    if let Ok(x) = env::var("POLYSCRIBE_TEST_FORCE_HIP") { return x == "1"; }
    check_lib(&["libhipblas.so", "librocblas.so"])
}

fn vulkan_available() -> bool {
    if let Ok(x) = env::var("POLYSCRIBE_TEST_FORCE_VULKAN") { return x == "1"; }
    check_lib(&["libvulkan.so.1", "libvulkan.so"])
}

pub struct CpuBackend;
pub struct CudaBackend;
pub struct HipBackend;
pub struct VulkanBackend;

impl CpuBackend {
    pub fn new() -> Self { CpuBackend }
}
impl CudaBackend { pub fn new() -> Self { CudaBackend } }
impl HipBackend { pub fn new() -> Self { HipBackend } }
impl VulkanBackend { pub fn new() -> Self { VulkanBackend } }

impl TranscribeBackend for CpuBackend {
    fn kind(&self) -> BackendKind { BackendKind::Cpu }
    fn transcribe(&self, audio_path: &Path, speaker: &str, lang_opt: Option<&str>, _gpu_layers: Option<u32>) -> Result<Vec<OutputEntry>> {
        transcribe_with_whisper_rs(audio_path, speaker, lang_opt)
    }
}

impl TranscribeBackend for CudaBackend {
    fn kind(&self) -> BackendKind { BackendKind::Cuda }
    fn transcribe(&self, audio_path: &Path, speaker: &str, lang_opt: Option<&str>, _gpu_layers: Option<u32>) -> Result<Vec<OutputEntry>> {
        // whisper-rs uses enabled CUDA feature at build time; call same code path
        transcribe_with_whisper_rs(audio_path, speaker, lang_opt)
    }
}

impl TranscribeBackend for HipBackend {
    fn kind(&self) -> BackendKind { BackendKind::Hip }
    fn transcribe(&self, audio_path: &Path, speaker: &str, lang_opt: Option<&str>, _gpu_layers: Option<u32>) -> Result<Vec<OutputEntry>> {
        transcribe_with_whisper_rs(audio_path, speaker, lang_opt)
    }
}

impl TranscribeBackend for VulkanBackend {
    fn kind(&self) -> BackendKind { BackendKind::Vulkan }
    fn transcribe(&self, _audio_path: &Path, _speaker: &str, _lang_opt: Option<&str>, _gpu_layers: Option<u32>) -> Result<Vec<OutputEntry>> {
        Err(anyhow!("Vulkan backend not yet wired to whisper.cpp FFI. Build with --features gpu-vulkan and ensure Vulkan SDK is installed. How to fix: install Vulkan loader (libvulkan), set VULKAN_SDK, and run cargo build --features gpu-vulkan."))
    }
}

pub struct SelectionResult {
    pub backend: Box<dyn TranscribeBackend + Send + Sync>,
    pub chosen: BackendKind,
    pub detected: Vec<BackendKind>,
}

pub fn select_backend(requested: BackendKind, verbose: bool) -> Result<SelectionResult> {
    let mut detected = Vec::new();
    if cuda_available() { detected.push(BackendKind::Cuda); }
    if hip_available() { detected.push(BackendKind::Hip); }
    if vulkan_available() { detected.push(BackendKind::Vulkan); }

    let mk = |k: BackendKind| -> Box<dyn TranscribeBackend + Send + Sync> {
        match k {
            BackendKind::Cpu => Box::new(CpuBackend::new()),
            BackendKind::Cuda => Box::new(CudaBackend::new()),
            BackendKind::Hip => Box::new(HipBackend::new()),
            BackendKind::Vulkan => Box::new(VulkanBackend::new()),
            BackendKind::Auto => Box::new(CpuBackend::new()), // will be replaced
        }
    };

    let chosen = match requested {
        BackendKind::Auto => {
            if detected.contains(&BackendKind::Cuda) { BackendKind::Cuda }
            else if detected.contains(&BackendKind::Hip) { BackendKind::Hip }
            else if detected.contains(&BackendKind::Vulkan) { BackendKind::Vulkan }
            else { BackendKind::Cpu }
        }
        BackendKind::Cuda => {
            if detected.contains(&BackendKind::Cuda) { BackendKind::Cuda }
            else { return Err(anyhow!("Requested CUDA backend but CUDA libraries/devices not detected. How to fix: install NVIDIA driver + CUDA toolkit, ensure libcudart/libcublas are in loader path, and build with --features gpu-cuda.")); }
        }
        BackendKind::Hip => {
            if detected.contains(&BackendKind::Hip) { BackendKind::Hip }
            else { return Err(anyhow!("Requested ROCm/HIP backend but libraries/devices not detected. How to fix: install ROCm hipBLAS/rocBLAS, ensure libs are in loader path, and build with --features gpu-hip.")); }
        }
        BackendKind::Vulkan => {
            if detected.contains(&BackendKind::Vulkan) { BackendKind::Vulkan }
            else { return Err(anyhow!("Requested Vulkan backend but libvulkan not detected. How to fix: install Vulkan loader/SDK and build with --features gpu-vulkan.")); }
        }
        BackendKind::Cpu => BackendKind::Cpu,
    };

    if verbose {
        eprintln!("INFO: Detected backends: {:?}", detected);
        eprintln!("INFO: Selected backend: {:?}", chosen);
    }

    Ok(SelectionResult { backend: mk(chosen), chosen, detected })
}

// Internal helper: transcription using whisper-rs with CPU/GPU (depending on build features)
#[allow(clippy::too_many_arguments)]
pub(crate) fn transcribe_with_whisper_rs(audio_path: &Path, speaker: &str, lang_opt: Option<&str>) -> Result<Vec<OutputEntry>> {
    let pcm = decode_audio_to_pcm_f32_ffmpeg(audio_path)?;
    let model = find_model_file()?;
    let is_en_only = model
        .file_name()
        .and_then(|s| s.to_str())
        .map(|s| s.contains(".en.") || s.ends_with(".en.bin"))
        .unwrap_or(false);
    if let Some(lang) = lang_opt {
        if is_en_only && lang != "en" {
            return Err(anyhow!(
                "Selected model is English-only ({}), but a non-English language hint '{}' was provided. Please use a multilingual model or set WHISPER_MODEL.",
                model.display(), lang
            ));
        }
    }
    let model_str = model.to_str().ok_or_else(|| anyhow!("Model path not valid UTF-8: {}", model.display()))?;

    let cparams = whisper_rs::WhisperContextParameters::default();
    let ctx = whisper_rs::WhisperContext::new_with_params(model_str, cparams)
        .with_context(|| format!("Failed to load Whisper model at {}", model.display()))?;
    let mut state = ctx.create_state().map_err(|e| anyhow!("Failed to create Whisper state: {:?}", e))?;

    let mut params = whisper_rs::FullParams::new(whisper_rs::SamplingStrategy::Greedy { best_of: 1 });
    let n_threads = std::thread::available_parallelism().map(|n| n.get() as i32).unwrap_or(1);
    params.set_n_threads(n_threads);
    params.set_translate(false);
    if let Some(lang) = lang_opt { params.set_language(Some(lang)); }

    state.full(params, &pcm).map_err(|e| anyhow!("Whisper full() failed: {:?}", e))?;

    let num_segments = state.full_n_segments().map_err(|e| anyhow!("Failed to get segments: {:?}", e))?;
    let mut items = Vec::new();
    for i in 0..num_segments {
        let text = state.full_get_segment_text(i).map_err(|e| anyhow!("Failed to get segment text: {:?}", e))?;
        let t0 = state.full_get_segment_t0(i).map_err(|e| anyhow!("Failed to get segment t0: {:?}", e))?;
        let t1 = state.full_get_segment_t1(i).map_err(|e| anyhow!("Failed to get segment t1: {:?}", e))?;
        let start = (t0 as f64) * 0.01;
        let end = (t1 as f64) * 0.01;
        items.push(OutputEntry { id: 0, speaker: speaker.to_string(), start, end, text: text.trim().to_string() });
    }
    Ok(items)
}
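Not part of this commit: a minimal caller-side sketch of the API defined above, mirroring how src/main.rs (below) drives it. select_backend, BackendKind, the TranscribeBackend trait, and OutputEntry come from this file; the file path, speaker label, and language hint are placeholder values.

use std::path::Path;
use anyhow::Result;
use crate::backend::{select_backend, BackendKind};

fn transcribe_one_file() -> Result<()> {
    // Auto-detection prefers CUDA, then HIP, then Vulkan, and falls back to CPU.
    let sel = select_backend(BackendKind::Auto, /* verbose */ true)?;
    eprintln!("Using backend: {:?}", sel.chosen);

    // gpu_layers is plumbed through the trait but ignored by the current backends.
    let entries = sel.backend.transcribe(Path::new("input.wav"), "Speaker 1", Some("en"), None)?;
    for e in &entries {
        eprintln!("[{:.2}-{:.2}] {}: {}", e.start, e.end, e.speaker, e.text);
    }
    Ok(())
}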
src/main.rs: 155 lines changed
@@ -11,9 +11,10 @@ use chrono::Local;
use std::sync::atomic::{AtomicBool, AtomicU8, Ordering};
use clap_complete::Shell;

use whisper_rs::{FullParams, SamplingStrategy, WhisperContext, WhisperContextParameters};

// whisper-rs is used in backend module
mod models;
mod backend;
use backend::{BackendKind, select_backend, TranscribeBackend};

static LAST_MODEL_WRITTEN: AtomicBool = AtomicBool::new(false);
static VERBOSE: AtomicU8 = AtomicU8::new(0);
@@ -82,6 +83,16 @@ enum AuxCommands {
    Man,
}

#[derive(clap::ValueEnum, Debug, Clone, Copy)]
#[value(rename_all = "kebab-case")]
enum GpuBackendCli {
    Auto,
    Cpu,
    Cuda,
    Hip,
    Vulkan,
}

#[derive(Parser, Debug)]
#[command(name = "PolyScribe", bin_name = "polyscribe", version, about = "Merge JSON transcripts or transcribe audio using native whisper")]
struct Args {
@@ -112,6 +123,14 @@ struct Args {
    #[arg(short, long, value_name = "LANG")]
    language: Option<String>,

    /// Choose GPU backend at runtime (auto|cpu|cuda|hip|vulkan). Default: auto.
    #[arg(long = "gpu-backend", value_enum, default_value_t = GpuBackendCli::Auto)]
    gpu_backend: GpuBackendCli,

    /// Number of layers to offload to GPU (if supported by backend)
    #[arg(long = "gpu-layers", value_name = "N")]
    gpu_layers: Option<u32>,

    /// Launch interactive model downloader (list HF models, multi-select and download)
    #[arg(long)]
    download_models: bool,
@@ -251,7 +270,7 @@ fn normalize_lang_code(input: &str) -> Option<String> {

fn find_model_file() -> Result<PathBuf> {
pub(crate) fn find_model_file() -> Result<PathBuf> {
    let models_dir_buf = models_dir_path();
    let models_dir = models_dir_buf.as_path();
    if !models_dir.exists() {

@@ -362,7 +381,7 @@ fn find_model_file() -> Result<PathBuf> {
    Ok(chosen)
}

fn decode_audio_to_pcm_f32_ffmpeg(audio_path: &Path) -> Result<Vec<f32>> {
pub(crate) fn decode_audio_to_pcm_f32_ffmpeg(audio_path: &Path) -> Result<Vec<f32>> {
    let output = Command::new("ffmpeg")
        .arg("-i").arg(audio_path)
        .arg("-f").arg("f32le")
@@ -398,61 +417,6 @@ fn decode_audio_to_pcm_f32_ffmpeg(audio_path: &Path) -> Result<Vec<f32>> {
    }
}

fn transcribe_native(audio_path: &Path, speaker: &str, lang_opt: Option<&str>) -> Result<Vec<OutputEntry>> {
    let pcm = decode_audio_to_pcm_f32_ffmpeg(audio_path)?;
    let model = find_model_file()?;
    let is_en_only = model
        .file_name()
        .and_then(|s| s.to_str())
        .map(|s| s.contains(".en.") || s.ends_with(".en.bin"))
        .unwrap_or(false);
    if let Some(lang) = lang_opt {
        if is_en_only && lang != "en" {
            return Err(anyhow!(
                "Selected model is English-only ({}), but a non-English language hint '{}' was provided. Please use a multilingual model like models/ggml-base.bin or set WHISPER_MODEL accordingly.",
                model.display(),
                lang
            ));
        }
    }
    let model_str = model.to_str().ok_or_else(|| anyhow!("Model path not valid UTF-8: {}", model.display()))?;

    // Initialize Whisper with GPU preference
    let cparams = WhisperContextParameters::default();
    // Prefer GPU if available; default whisper.cpp already has use_gpu=true. If the wrapper exposes
    // a gpu_device field in the future, we could set it here from WHISPER_GPU_DEVICE.
    if let Ok(dev_str) = env::var("WHISPER_GPU_DEVICE") {
        let _ = dev_str.trim().parse::<i32>().ok();
    }
    // Even if we can't set fields explicitly (due to API differences), whisper.cpp defaults to GPU.
    let ctx = WhisperContext::new_with_params(model_str, cparams)
        .with_context(|| format!("Failed to load Whisper model at {}", model.display()))?;
    let mut state = ctx.create_state()
        .map_err(|e| anyhow!("Failed to create Whisper state: {:?}", e))?;

    let mut params = FullParams::new(SamplingStrategy::Greedy { best_of: 1 });
    let n_threads = std::thread::available_parallelism().map(|n| n.get() as i32).unwrap_or(1);
    params.set_n_threads(n_threads);
    params.set_translate(false);
    if let Some(lang) = lang_opt { params.set_language(Some(lang)); }

    state.full(params, &pcm)
        .map_err(|e| anyhow!("Whisper full() failed: {:?}", e))?;

    let num_segments = state.full_n_segments().map_err(|e| anyhow!("Failed to get segments: {:?}", e))?;
    let mut items = Vec::new();
    for i in 0..num_segments {
        let text = state.full_get_segment_text(i)
            .map_err(|e| anyhow!("Failed to get segment text: {:?}", e))?;
        let t0 = state.full_get_segment_t0(i).map_err(|e| anyhow!("Failed to get segment t0: {:?}", e))?;
        let t1 = state.full_get_segment_t1(i).map_err(|e| anyhow!("Failed to get segment t1: {:?}", e))?;
        let start = (t0 as f64) * 0.01;
        let end = (t1 as f64) * 0.01;
        items.push(OutputEntry { id: 0, speaker: speaker.to_string(), start, end, text: text.trim().to_string() });
    }
    Ok(items)
}

struct LastModelCleanup {
    path: PathBuf,
}
@@ -498,6 +462,17 @@ fn main() -> Result<()> {
    // Ensure cleanup at end of program, regardless of exit path
    let _last_model_cleanup = LastModelCleanup { path: last_model_path.clone() };

    // Select backend
    let requested = match args.gpu_backend {
        GpuBackendCli::Auto => BackendKind::Auto,
        GpuBackendCli::Cpu => BackendKind::Cpu,
        GpuBackendCli::Cuda => BackendKind::Cuda,
        GpuBackendCli::Hip => BackendKind::Hip,
        GpuBackendCli::Vulkan => BackendKind::Vulkan,
    };
    let sel = select_backend(requested, args.verbose > 0)?;
    vlog!(0, "Using backend: {:?}", sel.chosen);

    // If requested, run the interactive model downloader first. If no inputs were provided, exit after downloading.
    if args.download_models {
        if let Err(e) = models::run_interactive_model_downloader() {
@@ -572,7 +547,7 @@ fn main() -> Result<()> {
        // Collect entries per file and extend merged
        let mut entries: Vec<OutputEntry> = Vec::new();
        if is_audio_file(path) {
            let items = transcribe_native(path, &speaker, lang_hint.as_deref())?;
            let items = sel.backend.transcribe(path, &speaker, lang_hint.as_deref(), args.gpu_layers)?;
            entries.extend(items.into_iter());
        } else if is_json_file(path) {
            let mut buf = String::new();

@@ -665,7 +640,7 @@ fn main() -> Result<()> {

        let mut buf = String::new();
        if is_audio_file(path) {
            let items = transcribe_native(path, &speaker, lang_hint.as_deref())?;
            let items = sel.backend.transcribe(path, &speaker, lang_hint.as_deref(), args.gpu_layers)?;
            for e in items { entries.push(e); }
            continue;
        } else if is_json_file(path) {

@@ -766,7 +741,7 @@ fn main() -> Result<()> {
        // Collect entries per file
        let mut entries: Vec<OutputEntry> = Vec::new();
        if is_audio_file(path) {
            let items = transcribe_native(path, &speaker, lang_hint.as_deref())?;
            let items = sel.backend.transcribe(path, &speaker, lang_hint.as_deref(), args.gpu_layers)?;
            entries.extend(items);
        } else if is_json_file(path) {
            let mut buf = String::new();
@@ -834,6 +809,7 @@ mod tests {
    use std::io::Write;
    use std::env as std_env;
    use clap::CommandFactory;
    use super::backend::*;

    #[test]
    fn test_cli_name_polyscribe() {
@@ -970,4 +946,59 @@ mod tests {
        assert!(is_audio_file(Path::new("trailer.MOV")));
        assert!(is_audio_file(Path::new("animation.avi")));
    }

    #[test]
    fn test_backend_auto_order_prefers_cuda_then_hip_then_vulkan_then_cpu() {
        // Clear overrides
        unsafe {
            std_env::remove_var("POLYSCRIBE_TEST_FORCE_CUDA");
            std_env::remove_var("POLYSCRIBE_TEST_FORCE_HIP");
            std_env::remove_var("POLYSCRIBE_TEST_FORCE_VULKAN");
        }
        // No GPU -> CPU
        let sel = select_backend(BackendKind::Auto, false).unwrap();
        assert_eq!(sel.chosen, BackendKind::Cpu);
        // Vulkan only
        unsafe { std_env::set_var("POLYSCRIBE_TEST_FORCE_VULKAN", "1"); }
        let sel = select_backend(BackendKind::Auto, false).unwrap();
        assert_eq!(sel.chosen, BackendKind::Vulkan);
        // HIP preferred over Vulkan
        unsafe { std_env::set_var("POLYSCRIBE_TEST_FORCE_HIP", "1"); std_env::remove_var("POLYSCRIBE_TEST_FORCE_VULKAN"); }
        let sel = select_backend(BackendKind::Auto, false).unwrap();
        assert_eq!(sel.chosen, BackendKind::Hip);
        // CUDA preferred over HIP
        unsafe { std_env::set_var("POLYSCRIBE_TEST_FORCE_CUDA", "1"); }
        let sel = select_backend(BackendKind::Auto, false).unwrap();
        assert_eq!(sel.chosen, BackendKind::Cuda);
        // Cleanup
        unsafe {
            std_env::remove_var("POLYSCRIBE_TEST_FORCE_CUDA");
            std_env::remove_var("POLYSCRIBE_TEST_FORCE_HIP");
            std_env::remove_var("POLYSCRIBE_TEST_FORCE_VULKAN");
        }
    }

    #[test]
    fn test_backend_explicit_missing_errors() {
        // Ensure all off
        unsafe {
            std_env::remove_var("POLYSCRIBE_TEST_FORCE_CUDA");
            std_env::remove_var("POLYSCRIBE_TEST_FORCE_HIP");
            std_env::remove_var("POLYSCRIBE_TEST_FORCE_VULKAN");
        }
        assert!(select_backend(BackendKind::Cuda, false).is_err());
        assert!(select_backend(BackendKind::Hip, false).is_err());
        assert!(select_backend(BackendKind::Vulkan, false).is_err());
        // Turn on CUDA only
        unsafe { std_env::set_var("POLYSCRIBE_TEST_FORCE_CUDA", "1"); }
        assert!(select_backend(BackendKind::Cuda, false).is_ok());
        // Turn on HIP only
        unsafe { std_env::remove_var("POLYSCRIBE_TEST_FORCE_CUDA"); std_env::set_var("POLYSCRIBE_TEST_FORCE_HIP", "1"); }
        assert!(select_backend(BackendKind::Hip, false).is_ok());
        // Turn on Vulkan only
        unsafe { std_env::remove_var("POLYSCRIBE_TEST_FORCE_HIP"); std_env::set_var("POLYSCRIBE_TEST_FORCE_VULKAN", "1"); }
        assert!(select_backend(BackendKind::Vulkan, false).is_ok());
        // Cleanup
        unsafe { std_env::remove_var("POLYSCRIBE_TEST_FORCE_VULKAN"); }
    }
}