diff --git a/src/backend.rs b/src/backend.rs index 33dd9a5..d35f857 100644 --- a/src/backend.rs +++ b/src/backend.rs @@ -443,3 +443,98 @@ pub(crate) fn transcribe_with_whisper_rs( "Transcription requires the 'whisper' feature. Rebuild with --features whisper (and optional gpu-cuda/gpu-hip)." )) } + + +#[cfg(test)] +mod tests { + use super::*; + use std::env as std_env; + use std::sync::{Mutex, OnceLock}; + + // Serialize environment variable modifications across tests in this module + static ENV_LOCK: OnceLock> = OnceLock::new(); + + #[test] + fn test_select_backend_auto_prefers_cuda_then_hip_then_vulkan_then_cpu() { + let _guard = ENV_LOCK.get_or_init(|| Mutex::new(())).lock().unwrap(); + // Clear overrides + unsafe { + std_env::remove_var("POLYSCRIBE_TEST_FORCE_CUDA"); + std_env::remove_var("POLYSCRIBE_TEST_FORCE_HIP"); + std_env::remove_var("POLYSCRIBE_TEST_FORCE_VULKAN"); + } + // No GPU -> CPU + let sel = select_backend(BackendKind::Auto, &crate::Config::default()).unwrap(); + assert_eq!(sel.chosen, BackendKind::Cpu); + + // Vulkan only -> Vulkan + unsafe { std_env::set_var("POLYSCRIBE_TEST_FORCE_VULKAN", "1"); } + let sel = select_backend(BackendKind::Auto, &crate::Config::default()).unwrap(); + assert_eq!(sel.chosen, BackendKind::Vulkan); + + // HIP only -> HIP (and preferred over Vulkan) + unsafe { + std_env::set_var("POLYSCRIBE_TEST_FORCE_HIP", "1"); + std_env::remove_var("POLYSCRIBE_TEST_FORCE_VULKAN"); + } + let sel = select_backend(BackendKind::Auto, &crate::Config::default()).unwrap(); + assert_eq!(sel.chosen, BackendKind::Hip); + + // CUDA only -> CUDA (and preferred over HIP) + unsafe { std_env::set_var("POLYSCRIBE_TEST_FORCE_CUDA", "1"); } + let sel = select_backend(BackendKind::Auto, &crate::Config::default()).unwrap(); + assert_eq!(sel.chosen, BackendKind::Cuda); + + // Cleanup + unsafe { + std_env::remove_var("POLYSCRIBE_TEST_FORCE_CUDA"); + std_env::remove_var("POLYSCRIBE_TEST_FORCE_HIP"); + std_env::remove_var("POLYSCRIBE_TEST_FORCE_VULKAN"); + } + } + + #[test] + fn test_select_backend_explicit_unavailable_errors_with_guidance() { + let _guard = ENV_LOCK.get_or_init(|| Mutex::new(())).lock().unwrap(); + // Ensure all off + unsafe { + std_env::remove_var("POLYSCRIBE_TEST_FORCE_CUDA"); + std_env::remove_var("POLYSCRIBE_TEST_FORCE_HIP"); + std_env::remove_var("POLYSCRIBE_TEST_FORCE_VULKAN"); + } + // CUDA requested but unavailable -> error with guidance + let err = select_backend(BackendKind::Cuda, &crate::Config::default()).err().expect("expected error"); + let msg = format!("{}", err); + assert!(msg.contains("Requested CUDA backend"), "unexpected msg: {msg}"); + assert!(msg.contains("How to fix"), "expected guidance text in: {msg}"); + + // HIP requested but unavailable -> error with guidance + let err = select_backend(BackendKind::Hip, &crate::Config::default()).err().expect("expected error"); + let msg = format!("{}", err); + assert!(msg.contains("ROCm/HIP"), "unexpected msg: {msg}"); + assert!(msg.contains("How to fix"), "expected guidance text in: {msg}"); + + // Vulkan requested but unavailable -> error with guidance + let err = select_backend(BackendKind::Vulkan, &crate::Config::default()).err().expect("expected error"); + let msg = format!("{}", err); + assert!(msg.contains("Vulkan"), "unexpected msg: {msg}"); + assert!(msg.contains("How to fix"), "expected guidance text in: {msg}"); + + // Now verify success when explicitly available via overrides + unsafe { std_env::set_var("POLYSCRIBE_TEST_FORCE_CUDA", "1"); } + assert!(select_backend(BackendKind::Cuda, &crate::Config::default()).is_ok()); + unsafe { + std_env::remove_var("POLYSCRIBE_TEST_FORCE_CUDA"); + std_env::set_var("POLYSCRIBE_TEST_FORCE_HIP", "1"); + } + assert!(select_backend(BackendKind::Hip, &crate::Config::default()).is_ok()); + unsafe { + std_env::remove_var("POLYSCRIBE_TEST_FORCE_HIP"); + std_env::set_var("POLYSCRIBE_TEST_FORCE_VULKAN", "1"); + } + assert!(select_backend(BackendKind::Vulkan, &crate::Config::default()).is_ok()); + + // Cleanup + unsafe { std_env::remove_var("POLYSCRIBE_TEST_FORCE_VULKAN"); } + } +}