[feat] add models_dir_path helper and implement dynamic models directory resolution

This commit is contained in:
2025-08-08 12:12:42 +02:00
parent aa8ea14407
commit 9ebe46b7fc
5 changed files with 134 additions and 10 deletions

View File

@@ -340,8 +340,17 @@ fn compute_file_sha256_hex(path: &Path) -> Result<String> {
Ok(to_hex_lower(&hasher.finalize()))
}
/// Resolves the directory in which model files are stored.
///
/// Returns the value of the `POLYSCRIBE_MODELS_DIR` environment variable when it is set
/// to a non-empty value; otherwise returns the relative path `models`.
fn models_dir_path() -> std::path::PathBuf {
    // `var_os` (unlike `var`) also yields values that are not valid Unicode, so a
    // non-UTF-8 directory path is honoured instead of silently falling back to the default.
    env::var_os("POLYSCRIBE_MODELS_DIR")
        .filter(|dir| !dir.is_empty())
        .map(std::path::PathBuf::from)
        .unwrap_or_else(|| std::path::PathBuf::from("models"))
}
pub fn run_interactive_model_downloader() -> Result<()> {
let models_dir = Path::new("models");
let models_dir_buf = models_dir_path();
let models_dir = models_dir_buf.as_path();
if !models_dir.exists() { create_dir_all(models_dir).context("Failed to create models directory")?; }
let client = Client::builder()
.user_agent("PolyScribe/0.1 (+https://github.com/)")
@@ -504,7 +513,8 @@ pub fn download_one_model(client: &Client, models_dir: &Path, entry: &ModelEntry
}
pub fn update_local_models() -> Result<()> {
let models_dir = Path::new("models");
let models_dir_buf = models_dir_path();
let models_dir = models_dir_buf.as_path();
if !models_dir.exists() {
create_dir_all(models_dir).context("Failed to create models directory")?;
}
@@ -591,3 +601,69 @@ pub fn update_local_models() -> Result<()> {
Ok(())
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::ffi::{OsStr, OsString};
    use std::fs;
    use tempfile::tempdir;

    /// Sets an environment variable for the lifetime of the guard and restores the
    /// previous state (old value, or unset) when dropped, so overrides made by one test
    /// do not leak into other tests running in the same process.
    struct EnvVarGuard {
        key: &'static str,
        previous: Option<OsString>,
    }

    impl EnvVarGuard {
        /// Remembers the current value of `key`, then sets it to `value`.
        fn set(key: &'static str, value: impl AsRef<OsStr>) -> Self {
            let previous = std::env::var_os(key);
            // SAFETY: mutating the environment is only sound while no other thread reads
            // or writes it concurrently. NOTE(review): this relies on no other test in
            // this binary touching the environment in parallel; confirm, or run
            // env-dependent tests with `--test-threads=1`.
            unsafe {
                std::env::set_var(key, value);
            }
            Self { key, previous }
        }
    }

    impl Drop for EnvVarGuard {
        fn drop(&mut self) {
            // SAFETY: same requirement as in `EnvVarGuard::set`.
            unsafe {
                match self.previous.take() {
                    Some(value) => std::env::set_var(self.key, value),
                    None => std::env::remove_var(self.key),
                }
            }
        }
    }

    /// Returns the lowercase hexadecimal SHA-256 digest of `data`.
    fn sha256_hex(data: &[u8]) -> String {
        use sha2::{Digest, Sha256};
        use std::fmt::Write as _;

        let mut hasher = Sha256::new();
        hasher.update(data);
        let digest = hasher.finalize();

        // A SHA-256 digest is 32 bytes, i.e. 64 hex characters.
        let mut hex = String::with_capacity(64);
        for byte in digest {
            // Append in place instead of allocating a temporary `String` per byte.
            write!(hex, "{:02x}", byte).expect("writing to a String cannot fail");
        }
        hex
    }

    /// With a stale local model file present, `update_local_models` must end up with the
    /// local file holding the content of the source file described by the manifest.
    #[test]
    fn test_update_local_models_offline_copy_and_manifest() {
        let tmp_models = tempdir().unwrap();
        let tmp_base = tempdir().unwrap();
        let tmp_manifest = tempdir().unwrap();

        // Prepare source model file content and hash
        let model_name = "tiny.en-q5_1";
        let src_path = tmp_base.path().join(format!("ggml-{}.bin", model_name));
        let new_content = b"new model content";
        fs::write(&src_path, new_content).unwrap();
        let expected_sha = sha256_hex(new_content);
        let expected_size = new_content.len() as u64;

        // Write a wrong existing local file to trigger update
        let local_path = tmp_models.path().join(format!("ggml-{}.bin", model_name));
        fs::write(&local_path, b"old content").unwrap();

        // Write manifest JSON
        let manifest_path = tmp_manifest.path().join("manifest.json");
        let manifest = serde_json::json!([
            {
                "name": model_name,
                "base": "tiny",
                "subtype": "en-q5_1",
                "size": expected_size,
                "sha256": expected_sha,
                "repo": "ggerganov/whisper.cpp"
            }
        ]);
        fs::write(&manifest_path, serde_json::to_string_pretty(&manifest).unwrap()).unwrap();

        // Set env vars to force offline behavior and directories. The guards are declared
        // after the temp dirs, so they are dropped first and the environment is restored
        // before the directories are deleted, even if an assertion below panics.
        let _manifest_env = EnvVarGuard::set("POLYSCRIBE_MODELS_MANIFEST", &manifest_path);
        let _base_env = EnvVarGuard::set("POLYSCRIBE_MODELS_BASE_COPY_DIR", tmp_base.path());
        let _models_env = EnvVarGuard::set("POLYSCRIBE_MODELS_DIR", tmp_models.path());

        // Run update
        update_local_models().unwrap();

        // Verify local file equals source content
        let got = fs::read(&local_path).unwrap();
        assert_eq!(got, new_content);
    }
}