diff --git a/src/main.rs b/src/main.rs index 7edfcd0..2850e7a 100644 --- a/src/main.rs +++ b/src/main.rs @@ -23,7 +23,11 @@ fn models_dir_path() -> PathBuf { return pb; } } - PathBuf::from("models") + if cfg!(debug_assertions) { + PathBuf::from("models") + } else { + PathBuf::from("/usr/share/polyscribe/models") + } } #[derive(Parser, Debug)] @@ -633,4 +637,25 @@ mod tests { assert_eq!(bytes[7], b'-'); assert!(bytes[8].is_ascii_digit() && bytes[9].is_ascii_digit()); } + + #[test] + #[cfg(debug_assertions)] + fn test_models_dir_path_default_debug_and_env_override() { + // clear override + unsafe { std_env::remove_var("POLYSCRIBE_MODELS_DIR"); } + assert_eq!(models_dir_path(), PathBuf::from("models")); + // override + let tmp = tempfile::tempdir().unwrap(); + unsafe { std_env::set_var("POLYSCRIBE_MODELS_DIR", tmp.path()); } + assert_eq!(models_dir_path(), tmp.path().to_path_buf()); + // cleanup + unsafe { std_env::remove_var("POLYSCRIBE_MODELS_DIR"); } + } + + #[test] + #[cfg(not(debug_assertions))] + fn test_models_dir_path_default_release() { + unsafe { std_env::remove_var("POLYSCRIBE_MODELS_DIR"); } + assert_eq!(models_dir_path(), PathBuf::from("/usr/share/polyscribe/models")); + } } diff --git a/src/models.rs b/src/models.rs index a7cf1f5..3f50340 100644 --- a/src/models.rs +++ b/src/models.rs @@ -345,7 +345,11 @@ fn models_dir_path() -> std::path::PathBuf { let pb = std::path::PathBuf::from(p); if !pb.as_os_str().is_empty() { return pb; } } - std::path::PathBuf::from("models") + if cfg!(debug_assertions) { + std::path::PathBuf::from("models") + } else { + std::path::PathBuf::from("/usr/share/polyscribe/models") + } } pub fn run_interactive_model_downloader() -> Result<()> { @@ -666,4 +670,25 @@ mod tests { let got = fs::read(&local_path).unwrap(); assert_eq!(got, new_content); } + + #[test] + #[cfg(debug_assertions)] + fn test_models_dir_path_default_debug_and_env_override_models_mod() { + // clear override + unsafe { std::env::remove_var("POLYSCRIBE_MODELS_DIR"); } + assert_eq!(super::models_dir_path(), std::path::PathBuf::from("models")); + // override + let tmp = tempfile::tempdir().unwrap(); + unsafe { std::env::set_var("POLYSCRIBE_MODELS_DIR", tmp.path()); } + assert_eq!(super::models_dir_path(), tmp.path().to_path_buf()); + // cleanup + unsafe { std::env::remove_var("POLYSCRIBE_MODELS_DIR"); } + } + + #[test] + #[cfg(not(debug_assertions))] + fn test_models_dir_path_default_release_models_mod() { + unsafe { std::env::remove_var("POLYSCRIBE_MODELS_DIR"); } + assert_eq!(super::models_dir_path(), std::path::PathBuf::from("/usr/share/polyscribe/models")); + } }