diff --git a/src/backend.rs b/src/backend.rs index d35f857..af54a9f 100644 --- a/src/backend.rs +++ b/src/backend.rs @@ -141,6 +141,28 @@ impl Default for VulkanBackend { } } +/// Validate that a provided language hint is compatible with the selected model. +/// +/// English-only models (filenames containing ".en." or ending with ".en.bin") reject non-"en" hints. +/// When no language is provided, this check passes and downstream behavior remains unchanged. +pub(crate) fn validate_model_lang_compat(model: &Path, lang_opt: Option<&str>) -> Result<()> { + let is_en_only = model + .file_name() + .and_then(|s| s.to_str()) + .map(|s| s.contains(".en.") || s.ends_with(".en.bin")) + .unwrap_or(false); + if let Some(lang) = lang_opt { + if is_en_only && lang != "en" { + return Err(anyhow!( + "Selected model is English-only ({}), but a non-English language hint '{}' was provided. Please use a multilingual model or set WHISPER_MODEL.", + model.display(), + lang + )); + } + } + Ok(()) +} + impl TranscribeBackend for CpuBackend { fn kind(&self) -> BackendKind { BackendKind::Cpu @@ -328,20 +350,8 @@ pub(crate) fn transcribe_with_whisper_rs( note: Some("model selected".to_string()), }); } - let is_en_only = model - .file_name() - .and_then(|s| s.to_str()) - .map(|s| s.contains(".en.") || s.ends_with(".en.bin")) - .unwrap_or(false); - if let Some(lang) = lang_opt { - if is_en_only && lang != "en" { - return Err(anyhow!( - "Selected model is English-only ({}), but a non-English language hint '{}' was provided. Please use a multilingual model or set WHISPER_MODEL.", - model.display(), - lang - )); - } - } + // Validate language hint compatibility with the selected model + validate_model_lang_compat(&model, lang_opt)?; let model_str = model .to_str() .ok_or_else(|| anyhow!("Model path not valid UTF-8: {}", model.display()))?; @@ -451,6 +461,35 @@ mod tests { use std::env as std_env; use std::sync::{Mutex, OnceLock}; + #[test] + fn test_validate_model_lang_guard_table() { + struct case<'a> { model: &'a str, lang: Option<&'a str>, ok: bool } + let cases = vec![ + // English-only model with en hint: OK + case { model: "ggml-base.en.bin", lang: Some("en"), ok: true }, + // English-only model with de hint: Error + case { model: "ggml-small.en.bin", lang: Some("de"), ok: false }, + // Multilingual model with de hint: OK + case { model: "ggml-large-v3.bin", lang: Some("de"), ok: true }, + // No language provided (audio path scenario): guard should pass (existing behavior elsewhere) + case { model: "ggml-medium.en.bin", lang: None, ok: true }, + ]; + for c in cases { + let p = std::path::Path::new(c.model); + let res = validate_model_lang_compat(p, c.lang); + match (c.ok, res) { + (true, Ok(())) => {} + (false, Err(e)) => { + let msg = format!("{}", e); + assert!(msg.contains("English-only"), "unexpected error: {msg}"); + if let Some(l) = c.lang { assert!(msg.contains(l), "missing lang in msg: {msg}"); } + } + (true, Err(e)) => panic!("expected Ok for model={}, lang={:?}, got error: {}", c.model, c.lang, e), + (false, Ok(())) => panic!("expected Err for model={}, lang={:?}", c.model, c.lang), + } + } + } + // Serialize environment variable modifications across tests in this module static ENV_LOCK: OnceLock> = OnceLock::new();