[test] add comprehensive tests for select_backend
ensuring proper backend priority and error guidance
This commit is contained in:
@@ -443,3 +443,98 @@ pub(crate) fn transcribe_with_whisper_rs(
|
|||||||
"Transcription requires the 'whisper' feature. Rebuild with --features whisper (and optional gpu-cuda/gpu-hip)."
|
"Transcription requires the 'whisper' feature. Rebuild with --features whisper (and optional gpu-cuda/gpu-hip)."
|
||||||
))
|
))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
#[cfg(test)]
|
||||||
|
mod tests {
|
||||||
|
use super::*;
|
||||||
|
use std::env as std_env;
|
||||||
|
use std::sync::{Mutex, OnceLock};
|
||||||
|
|
||||||
|
// Serialize environment variable modifications across tests in this module
|
||||||
|
static ENV_LOCK: OnceLock<Mutex<()>> = OnceLock::new();
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_select_backend_auto_prefers_cuda_then_hip_then_vulkan_then_cpu() {
|
||||||
|
let _guard = ENV_LOCK.get_or_init(|| Mutex::new(())).lock().unwrap();
|
||||||
|
// Clear overrides
|
||||||
|
unsafe {
|
||||||
|
std_env::remove_var("POLYSCRIBE_TEST_FORCE_CUDA");
|
||||||
|
std_env::remove_var("POLYSCRIBE_TEST_FORCE_HIP");
|
||||||
|
std_env::remove_var("POLYSCRIBE_TEST_FORCE_VULKAN");
|
||||||
|
}
|
||||||
|
// No GPU -> CPU
|
||||||
|
let sel = select_backend(BackendKind::Auto, &crate::Config::default()).unwrap();
|
||||||
|
assert_eq!(sel.chosen, BackendKind::Cpu);
|
||||||
|
|
||||||
|
// Vulkan only -> Vulkan
|
||||||
|
unsafe { std_env::set_var("POLYSCRIBE_TEST_FORCE_VULKAN", "1"); }
|
||||||
|
let sel = select_backend(BackendKind::Auto, &crate::Config::default()).unwrap();
|
||||||
|
assert_eq!(sel.chosen, BackendKind::Vulkan);
|
||||||
|
|
||||||
|
// HIP only -> HIP (and preferred over Vulkan)
|
||||||
|
unsafe {
|
||||||
|
std_env::set_var("POLYSCRIBE_TEST_FORCE_HIP", "1");
|
||||||
|
std_env::remove_var("POLYSCRIBE_TEST_FORCE_VULKAN");
|
||||||
|
}
|
||||||
|
let sel = select_backend(BackendKind::Auto, &crate::Config::default()).unwrap();
|
||||||
|
assert_eq!(sel.chosen, BackendKind::Hip);
|
||||||
|
|
||||||
|
// CUDA only -> CUDA (and preferred over HIP)
|
||||||
|
unsafe { std_env::set_var("POLYSCRIBE_TEST_FORCE_CUDA", "1"); }
|
||||||
|
let sel = select_backend(BackendKind::Auto, &crate::Config::default()).unwrap();
|
||||||
|
assert_eq!(sel.chosen, BackendKind::Cuda);
|
||||||
|
|
||||||
|
// Cleanup
|
||||||
|
unsafe {
|
||||||
|
std_env::remove_var("POLYSCRIBE_TEST_FORCE_CUDA");
|
||||||
|
std_env::remove_var("POLYSCRIBE_TEST_FORCE_HIP");
|
||||||
|
std_env::remove_var("POLYSCRIBE_TEST_FORCE_VULKAN");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_select_backend_explicit_unavailable_errors_with_guidance() {
|
||||||
|
let _guard = ENV_LOCK.get_or_init(|| Mutex::new(())).lock().unwrap();
|
||||||
|
// Ensure all off
|
||||||
|
unsafe {
|
||||||
|
std_env::remove_var("POLYSCRIBE_TEST_FORCE_CUDA");
|
||||||
|
std_env::remove_var("POLYSCRIBE_TEST_FORCE_HIP");
|
||||||
|
std_env::remove_var("POLYSCRIBE_TEST_FORCE_VULKAN");
|
||||||
|
}
|
||||||
|
// CUDA requested but unavailable -> error with guidance
|
||||||
|
let err = select_backend(BackendKind::Cuda, &crate::Config::default()).err().expect("expected error");
|
||||||
|
let msg = format!("{}", err);
|
||||||
|
assert!(msg.contains("Requested CUDA backend"), "unexpected msg: {msg}");
|
||||||
|
assert!(msg.contains("How to fix"), "expected guidance text in: {msg}");
|
||||||
|
|
||||||
|
// HIP requested but unavailable -> error with guidance
|
||||||
|
let err = select_backend(BackendKind::Hip, &crate::Config::default()).err().expect("expected error");
|
||||||
|
let msg = format!("{}", err);
|
||||||
|
assert!(msg.contains("ROCm/HIP"), "unexpected msg: {msg}");
|
||||||
|
assert!(msg.contains("How to fix"), "expected guidance text in: {msg}");
|
||||||
|
|
||||||
|
// Vulkan requested but unavailable -> error with guidance
|
||||||
|
let err = select_backend(BackendKind::Vulkan, &crate::Config::default()).err().expect("expected error");
|
||||||
|
let msg = format!("{}", err);
|
||||||
|
assert!(msg.contains("Vulkan"), "unexpected msg: {msg}");
|
||||||
|
assert!(msg.contains("How to fix"), "expected guidance text in: {msg}");
|
||||||
|
|
||||||
|
// Now verify success when explicitly available via overrides
|
||||||
|
unsafe { std_env::set_var("POLYSCRIBE_TEST_FORCE_CUDA", "1"); }
|
||||||
|
assert!(select_backend(BackendKind::Cuda, &crate::Config::default()).is_ok());
|
||||||
|
unsafe {
|
||||||
|
std_env::remove_var("POLYSCRIBE_TEST_FORCE_CUDA");
|
||||||
|
std_env::set_var("POLYSCRIBE_TEST_FORCE_HIP", "1");
|
||||||
|
}
|
||||||
|
assert!(select_backend(BackendKind::Hip, &crate::Config::default()).is_ok());
|
||||||
|
unsafe {
|
||||||
|
std_env::remove_var("POLYSCRIBE_TEST_FORCE_HIP");
|
||||||
|
std_env::set_var("POLYSCRIBE_TEST_FORCE_VULKAN", "1");
|
||||||
|
}
|
||||||
|
assert!(select_backend(BackendKind::Vulkan, &crate::Config::default()).is_ok());
|
||||||
|
|
||||||
|
// Cleanup
|
||||||
|
unsafe { std_env::remove_var("POLYSCRIBE_TEST_FORCE_VULKAN"); }
|
||||||
|
}
|
||||||
|
}
|
||||||
|
Reference in New Issue
Block a user