[test] add comprehensive tests for select_backend
ensuring proper backend priority and error guidance
This commit is contained in:
@@ -443,3 +443,98 @@ pub(crate) fn transcribe_with_whisper_rs(
|
||||
"Transcription requires the 'whisper' feature. Rebuild with --features whisper (and optional gpu-cuda/gpu-hip)."
|
||||
))
|
||||
}
|
||||
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use std::env as std_env;
|
||||
use std::sync::{Mutex, OnceLock};
|
||||
|
||||
// Serialize environment variable modifications across tests in this module
|
||||
static ENV_LOCK: OnceLock<Mutex<()>> = OnceLock::new();
|
||||
|
||||
#[test]
|
||||
fn test_select_backend_auto_prefers_cuda_then_hip_then_vulkan_then_cpu() {
|
||||
let _guard = ENV_LOCK.get_or_init(|| Mutex::new(())).lock().unwrap();
|
||||
// Clear overrides
|
||||
unsafe {
|
||||
std_env::remove_var("POLYSCRIBE_TEST_FORCE_CUDA");
|
||||
std_env::remove_var("POLYSCRIBE_TEST_FORCE_HIP");
|
||||
std_env::remove_var("POLYSCRIBE_TEST_FORCE_VULKAN");
|
||||
}
|
||||
// No GPU -> CPU
|
||||
let sel = select_backend(BackendKind::Auto, &crate::Config::default()).unwrap();
|
||||
assert_eq!(sel.chosen, BackendKind::Cpu);
|
||||
|
||||
// Vulkan only -> Vulkan
|
||||
unsafe { std_env::set_var("POLYSCRIBE_TEST_FORCE_VULKAN", "1"); }
|
||||
let sel = select_backend(BackendKind::Auto, &crate::Config::default()).unwrap();
|
||||
assert_eq!(sel.chosen, BackendKind::Vulkan);
|
||||
|
||||
// HIP only -> HIP (and preferred over Vulkan)
|
||||
unsafe {
|
||||
std_env::set_var("POLYSCRIBE_TEST_FORCE_HIP", "1");
|
||||
std_env::remove_var("POLYSCRIBE_TEST_FORCE_VULKAN");
|
||||
}
|
||||
let sel = select_backend(BackendKind::Auto, &crate::Config::default()).unwrap();
|
||||
assert_eq!(sel.chosen, BackendKind::Hip);
|
||||
|
||||
// CUDA only -> CUDA (and preferred over HIP)
|
||||
unsafe { std_env::set_var("POLYSCRIBE_TEST_FORCE_CUDA", "1"); }
|
||||
let sel = select_backend(BackendKind::Auto, &crate::Config::default()).unwrap();
|
||||
assert_eq!(sel.chosen, BackendKind::Cuda);
|
||||
|
||||
// Cleanup
|
||||
unsafe {
|
||||
std_env::remove_var("POLYSCRIBE_TEST_FORCE_CUDA");
|
||||
std_env::remove_var("POLYSCRIBE_TEST_FORCE_HIP");
|
||||
std_env::remove_var("POLYSCRIBE_TEST_FORCE_VULKAN");
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_select_backend_explicit_unavailable_errors_with_guidance() {
|
||||
let _guard = ENV_LOCK.get_or_init(|| Mutex::new(())).lock().unwrap();
|
||||
// Ensure all off
|
||||
unsafe {
|
||||
std_env::remove_var("POLYSCRIBE_TEST_FORCE_CUDA");
|
||||
std_env::remove_var("POLYSCRIBE_TEST_FORCE_HIP");
|
||||
std_env::remove_var("POLYSCRIBE_TEST_FORCE_VULKAN");
|
||||
}
|
||||
// CUDA requested but unavailable -> error with guidance
|
||||
let err = select_backend(BackendKind::Cuda, &crate::Config::default()).err().expect("expected error");
|
||||
let msg = format!("{}", err);
|
||||
assert!(msg.contains("Requested CUDA backend"), "unexpected msg: {msg}");
|
||||
assert!(msg.contains("How to fix"), "expected guidance text in: {msg}");
|
||||
|
||||
// HIP requested but unavailable -> error with guidance
|
||||
let err = select_backend(BackendKind::Hip, &crate::Config::default()).err().expect("expected error");
|
||||
let msg = format!("{}", err);
|
||||
assert!(msg.contains("ROCm/HIP"), "unexpected msg: {msg}");
|
||||
assert!(msg.contains("How to fix"), "expected guidance text in: {msg}");
|
||||
|
||||
// Vulkan requested but unavailable -> error with guidance
|
||||
let err = select_backend(BackendKind::Vulkan, &crate::Config::default()).err().expect("expected error");
|
||||
let msg = format!("{}", err);
|
||||
assert!(msg.contains("Vulkan"), "unexpected msg: {msg}");
|
||||
assert!(msg.contains("How to fix"), "expected guidance text in: {msg}");
|
||||
|
||||
// Now verify success when explicitly available via overrides
|
||||
unsafe { std_env::set_var("POLYSCRIBE_TEST_FORCE_CUDA", "1"); }
|
||||
assert!(select_backend(BackendKind::Cuda, &crate::Config::default()).is_ok());
|
||||
unsafe {
|
||||
std_env::remove_var("POLYSCRIBE_TEST_FORCE_CUDA");
|
||||
std_env::set_var("POLYSCRIBE_TEST_FORCE_HIP", "1");
|
||||
}
|
||||
assert!(select_backend(BackendKind::Hip, &crate::Config::default()).is_ok());
|
||||
unsafe {
|
||||
std_env::remove_var("POLYSCRIBE_TEST_FORCE_HIP");
|
||||
std_env::set_var("POLYSCRIBE_TEST_FORCE_VULKAN", "1");
|
||||
}
|
||||
assert!(select_backend(BackendKind::Vulkan, &crate::Config::default()).is_ok());
|
||||
|
||||
// Cleanup
|
||||
unsafe { std_env::remove_var("POLYSCRIBE_TEST_FORCE_VULKAN"); }
|
||||
}
|
||||
}
|
||||
|
Reference in New Issue
Block a user