[test] add comprehensive tests for select_backend ensuring proper backend priority and error guidance

This commit is contained in:
2025-08-12 04:24:51 +02:00
parent 7832545033
commit f143e66e80

View File

@@ -443,3 +443,98 @@ pub(crate) fn transcribe_with_whisper_rs(
"Transcription requires the 'whisper' feature. Rebuild with --features whisper (and optional gpu-cuda/gpu-hip)." "Transcription requires the 'whisper' feature. Rebuild with --features whisper (and optional gpu-cuda/gpu-hip)."
)) ))
} }
#[cfg(test)]
mod tests {
use super::*;
use std::env as std_env;
use std::sync::{Mutex, OnceLock};
// Serialize environment variable modifications across tests in this module
static ENV_LOCK: OnceLock<Mutex<()>> = OnceLock::new();
#[test]
fn test_select_backend_auto_prefers_cuda_then_hip_then_vulkan_then_cpu() {
let _guard = ENV_LOCK.get_or_init(|| Mutex::new(())).lock().unwrap();
// Clear overrides
unsafe {
std_env::remove_var("POLYSCRIBE_TEST_FORCE_CUDA");
std_env::remove_var("POLYSCRIBE_TEST_FORCE_HIP");
std_env::remove_var("POLYSCRIBE_TEST_FORCE_VULKAN");
}
// No GPU -> CPU
let sel = select_backend(BackendKind::Auto, &crate::Config::default()).unwrap();
assert_eq!(sel.chosen, BackendKind::Cpu);
// Vulkan only -> Vulkan
unsafe { std_env::set_var("POLYSCRIBE_TEST_FORCE_VULKAN", "1"); }
let sel = select_backend(BackendKind::Auto, &crate::Config::default()).unwrap();
assert_eq!(sel.chosen, BackendKind::Vulkan);
// HIP only -> HIP (and preferred over Vulkan)
unsafe {
std_env::set_var("POLYSCRIBE_TEST_FORCE_HIP", "1");
std_env::remove_var("POLYSCRIBE_TEST_FORCE_VULKAN");
}
let sel = select_backend(BackendKind::Auto, &crate::Config::default()).unwrap();
assert_eq!(sel.chosen, BackendKind::Hip);
// CUDA only -> CUDA (and preferred over HIP)
unsafe { std_env::set_var("POLYSCRIBE_TEST_FORCE_CUDA", "1"); }
let sel = select_backend(BackendKind::Auto, &crate::Config::default()).unwrap();
assert_eq!(sel.chosen, BackendKind::Cuda);
// Cleanup
unsafe {
std_env::remove_var("POLYSCRIBE_TEST_FORCE_CUDA");
std_env::remove_var("POLYSCRIBE_TEST_FORCE_HIP");
std_env::remove_var("POLYSCRIBE_TEST_FORCE_VULKAN");
}
}
#[test]
fn test_select_backend_explicit_unavailable_errors_with_guidance() {
let _guard = ENV_LOCK.get_or_init(|| Mutex::new(())).lock().unwrap();
// Ensure all off
unsafe {
std_env::remove_var("POLYSCRIBE_TEST_FORCE_CUDA");
std_env::remove_var("POLYSCRIBE_TEST_FORCE_HIP");
std_env::remove_var("POLYSCRIBE_TEST_FORCE_VULKAN");
}
// CUDA requested but unavailable -> error with guidance
let err = select_backend(BackendKind::Cuda, &crate::Config::default()).err().expect("expected error");
let msg = format!("{}", err);
assert!(msg.contains("Requested CUDA backend"), "unexpected msg: {msg}");
assert!(msg.contains("How to fix"), "expected guidance text in: {msg}");
// HIP requested but unavailable -> error with guidance
let err = select_backend(BackendKind::Hip, &crate::Config::default()).err().expect("expected error");
let msg = format!("{}", err);
assert!(msg.contains("ROCm/HIP"), "unexpected msg: {msg}");
assert!(msg.contains("How to fix"), "expected guidance text in: {msg}");
// Vulkan requested but unavailable -> error with guidance
let err = select_backend(BackendKind::Vulkan, &crate::Config::default()).err().expect("expected error");
let msg = format!("{}", err);
assert!(msg.contains("Vulkan"), "unexpected msg: {msg}");
assert!(msg.contains("How to fix"), "expected guidance text in: {msg}");
// Now verify success when explicitly available via overrides
unsafe { std_env::set_var("POLYSCRIBE_TEST_FORCE_CUDA", "1"); }
assert!(select_backend(BackendKind::Cuda, &crate::Config::default()).is_ok());
unsafe {
std_env::remove_var("POLYSCRIBE_TEST_FORCE_CUDA");
std_env::set_var("POLYSCRIBE_TEST_FORCE_HIP", "1");
}
assert!(select_backend(BackendKind::Hip, &crate::Config::default()).is_ok());
unsafe {
std_env::remove_var("POLYSCRIBE_TEST_FORCE_HIP");
std_env::set_var("POLYSCRIBE_TEST_FORCE_VULKAN", "1");
}
assert!(select_backend(BackendKind::Vulkan, &crate::Config::default()).is_ok());
// Cleanup
unsafe { std_env::remove_var("POLYSCRIBE_TEST_FORCE_VULKAN"); }
}
}