[refactor] modularize code by moving logic to the polyscribe
crate; clean up imports and remove redundant functions
This commit is contained in:
482
src/models.rs
482
src/models.rs
@@ -1,14 +1,14 @@
|
||||
use std::collections::BTreeMap;
|
||||
use std::env;
|
||||
use std::fs::{File, create_dir_all};
|
||||
use std::io::{self, Read, Write};
|
||||
use std::path::Path;
|
||||
use std::collections::BTreeMap;
|
||||
use std::time::Duration;
|
||||
use std::env;
|
||||
|
||||
use anyhow::{anyhow, Context, Result};
|
||||
use serde::Deserialize;
|
||||
use anyhow::{Context, Result, anyhow};
|
||||
use reqwest::blocking::Client;
|
||||
use reqwest::redirect::Policy;
|
||||
use serde::Deserialize;
|
||||
use sha2::{Digest, Sha256};
|
||||
|
||||
// Print to stderr only when not in quiet mode
|
||||
@@ -80,22 +80,33 @@ fn human_size(bytes: u64) -> String {
|
||||
const MB: f64 = KB * 1024.0;
|
||||
const GB: f64 = MB * 1024.0;
|
||||
let b = bytes as f64;
|
||||
if b >= GB { format!("{:.2} GiB", b / GB) }
|
||||
else if b >= MB { format!("{:.2} MiB", b / MB) }
|
||||
else if b >= KB { format!("{:.2} KiB", b / KB) }
|
||||
else { format!("{} B", bytes) }
|
||||
if b >= GB {
|
||||
format!("{:.2} GiB", b / GB)
|
||||
} else if b >= MB {
|
||||
format!("{:.2} MiB", b / MB)
|
||||
} else if b >= KB {
|
||||
format!("{:.2} KiB", b / KB)
|
||||
} else {
|
||||
format!("{} B", bytes)
|
||||
}
|
||||
}
|
||||
|
||||
fn to_hex_lower(bytes: &[u8]) -> String {
|
||||
let mut s = String::with_capacity(bytes.len() * 2);
|
||||
for b in bytes { s.push_str(&format!("{:02x}", b)); }
|
||||
for b in bytes {
|
||||
s.push_str(&format!("{:02x}", b));
|
||||
}
|
||||
s
|
||||
}
|
||||
|
||||
fn expected_sha_from_sibling(s: &HFSibling) -> Option<String> {
|
||||
if let Some(h) = &s.sha256 { return Some(h.to_lowercase()); }
|
||||
if let Some(h) = &s.sha256 {
|
||||
return Some(h.to_lowercase());
|
||||
}
|
||||
if let Some(lfs) = &s.lfs {
|
||||
if let Some(h) = &lfs.sha256 { return Some(h.to_lowercase()); }
|
||||
if let Some(h) = &lfs.sha256 {
|
||||
return Some(h.to_lowercase());
|
||||
}
|
||||
if let Some(oid) = &lfs.oid {
|
||||
// e.g. "sha256:abcdef..."
|
||||
if let Some(rest) = oid.strip_prefix("sha256:") {
|
||||
@@ -107,15 +118,23 @@ fn expected_sha_from_sibling(s: &HFSibling) -> Option<String> {
|
||||
}
|
||||
|
||||
fn size_from_sibling(s: &HFSibling) -> Option<u64> {
|
||||
if let Some(sz) = s.size { return Some(sz); }
|
||||
if let Some(lfs) = &s.lfs { return lfs.size; }
|
||||
if let Some(sz) = s.size {
|
||||
return Some(sz);
|
||||
}
|
||||
if let Some(lfs) = &s.lfs {
|
||||
return lfs.size;
|
||||
}
|
||||
None
|
||||
}
|
||||
|
||||
fn expected_sha_from_tree(s: &HFTreeItem) -> Option<String> {
|
||||
if let Some(h) = &s.sha256 { return Some(h.to_lowercase()); }
|
||||
if let Some(h) = &s.sha256 {
|
||||
return Some(h.to_lowercase());
|
||||
}
|
||||
if let Some(lfs) = &s.lfs {
|
||||
if let Some(h) = &lfs.sha256 { return Some(h.to_lowercase()); }
|
||||
if let Some(h) = &lfs.sha256 {
|
||||
return Some(h.to_lowercase());
|
||||
}
|
||||
if let Some(oid) = &lfs.oid {
|
||||
if let Some(rest) = oid.strip_prefix("sha256:") {
|
||||
return Some(rest.to_lowercase().to_string());
|
||||
@@ -126,8 +145,12 @@ fn expected_sha_from_tree(s: &HFTreeItem) -> Option<String> {
|
||||
}
|
||||
|
||||
fn size_from_tree(s: &HFTreeItem) -> Option<u64> {
|
||||
if let Some(sz) = s.size { return Some(sz); }
|
||||
if let Some(lfs) = &s.lfs { return lfs.size; }
|
||||
if let Some(sz) = s.size {
|
||||
return Some(sz);
|
||||
}
|
||||
if let Some(lfs) = &s.lfs {
|
||||
return lfs.size;
|
||||
}
|
||||
None
|
||||
}
|
||||
|
||||
@@ -136,12 +159,20 @@ fn fill_meta_via_head(repo: &str, name: &str) -> (Option<u64>, Option<String>) {
|
||||
.user_agent("PolyScribe/0.1 (+https://github.com/)")
|
||||
.redirect(Policy::none())
|
||||
.timeout(Duration::from_secs(30))
|
||||
.build() {
|
||||
.build()
|
||||
{
|
||||
Ok(c) => c,
|
||||
Err(_) => return (None, None),
|
||||
};
|
||||
let url = format!("https://huggingface.co/{}/resolve/main/ggml-{}.bin", repo, name);
|
||||
let resp = match head_client.head(url).send().and_then(|r| r.error_for_status()) {
|
||||
let url = format!(
|
||||
"https://huggingface.co/{}/resolve/main/ggml-{}.bin",
|
||||
repo, name
|
||||
);
|
||||
let resp = match head_client
|
||||
.head(url)
|
||||
.send()
|
||||
.and_then(|r| r.error_for_status())
|
||||
{
|
||||
Ok(r) => r,
|
||||
Err(_) => return (None, None),
|
||||
};
|
||||
@@ -179,21 +210,40 @@ fn fill_meta_via_head(repo: &str, name: &str) -> (Option<u64>, Option<String>) {
|
||||
fn hf_fetch_repo_models(client: &Client, repo: &'static str) -> Result<Vec<ModelEntry>> {
|
||||
qlog!("Fetching online data: listing models from {}...", repo);
|
||||
// Prefer the tree endpoint for reliable size/hash metadata, then fall back to model metadata
|
||||
let tree_url = format!("https://huggingface.co/api/models/{}/tree/main?recursive=1", repo);
|
||||
let tree_url = format!(
|
||||
"https://huggingface.co/api/models/{}/tree/main?recursive=1",
|
||||
repo
|
||||
);
|
||||
let mut out: Vec<ModelEntry> = Vec::new();
|
||||
|
||||
match client.get(tree_url).send().and_then(|r| r.error_for_status()) {
|
||||
match client
|
||||
.get(tree_url)
|
||||
.send()
|
||||
.and_then(|r| r.error_for_status())
|
||||
{
|
||||
Ok(resp) => {
|
||||
match resp.json::<Vec<HFTreeItem>>() {
|
||||
Ok(items) => {
|
||||
for it in items {
|
||||
let path = it.path.clone();
|
||||
if !(path.starts_with("ggml-") && path.ends_with(".bin")) { continue; }
|
||||
let model_name = path.trim_start_matches("ggml-").trim_end_matches(".bin").to_string();
|
||||
if !(path.starts_with("ggml-") && path.ends_with(".bin")) {
|
||||
continue;
|
||||
}
|
||||
let model_name = path
|
||||
.trim_start_matches("ggml-")
|
||||
.trim_end_matches(".bin")
|
||||
.to_string();
|
||||
let (base, subtype) = split_model_name(&model_name);
|
||||
let size = size_from_tree(&it).unwrap_or(0);
|
||||
let sha256 = expected_sha_from_tree(&it);
|
||||
out.push(ModelEntry { name: model_name, base, subtype, size, sha256, repo: repo.to_string() });
|
||||
out.push(ModelEntry {
|
||||
name: model_name,
|
||||
base,
|
||||
subtype,
|
||||
size,
|
||||
sha256,
|
||||
repo: repo.to_string(),
|
||||
});
|
||||
}
|
||||
}
|
||||
Err(_) => { /* fall back below */ }
|
||||
@@ -210,30 +260,49 @@ fn hf_fetch_repo_models(client: &Client, repo: &'static str) -> Result<Vec<Model
|
||||
.and_then(|r| r.error_for_status())
|
||||
.context("Failed to query Hugging Face API")?;
|
||||
|
||||
let info: HFRepoInfo = resp.json().context("Failed to parse Hugging Face API response")?;
|
||||
let info: HFRepoInfo = resp
|
||||
.json()
|
||||
.context("Failed to parse Hugging Face API response")?;
|
||||
|
||||
if let Some(files) = info.siblings {
|
||||
for s in files {
|
||||
let rf = s.rfilename.clone();
|
||||
if !(rf.starts_with("ggml-") && rf.ends_with(".bin")) { continue; }
|
||||
let model_name = rf.trim_start_matches("ggml-").trim_end_matches(".bin").to_string();
|
||||
if !(rf.starts_with("ggml-") && rf.ends_with(".bin")) {
|
||||
continue;
|
||||
}
|
||||
let model_name = rf
|
||||
.trim_start_matches("ggml-")
|
||||
.trim_end_matches(".bin")
|
||||
.to_string();
|
||||
let (base, subtype) = split_model_name(&model_name);
|
||||
let size = size_from_sibling(&s).unwrap_or(0);
|
||||
let sha256 = expected_sha_from_sibling(&s);
|
||||
out.push(ModelEntry { name: model_name, base, subtype, size, sha256, repo: repo.to_string() });
|
||||
out.push(ModelEntry {
|
||||
name: model_name,
|
||||
base,
|
||||
subtype,
|
||||
size,
|
||||
sha256,
|
||||
repo: repo.to_string(),
|
||||
});
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Fill missing metadata (size/hash) via HEAD request if necessary
|
||||
if out.iter().any(|m| m.size == 0 || m.sha256.is_none()) {
|
||||
qlog!("Fetching online data: completing metadata checks for models in {}...", repo);
|
||||
qlog!(
|
||||
"Fetching online data: completing metadata checks for models in {}...",
|
||||
repo
|
||||
);
|
||||
}
|
||||
for m in out.iter_mut() {
|
||||
if m.size == 0 || m.sha256.is_none() {
|
||||
let (sz, sha) = fill_meta_via_head(&m.repo, &m.name);
|
||||
if m.size == 0 {
|
||||
if let Some(s) = sz { m.size = s; }
|
||||
if let Some(s) = sz {
|
||||
m.size = s;
|
||||
}
|
||||
}
|
||||
if m.sha256.is_none() {
|
||||
m.sha256 = sha;
|
||||
@@ -242,7 +311,12 @@ fn hf_fetch_repo_models(client: &Client, repo: &'static str) -> Result<Vec<Model
|
||||
}
|
||||
|
||||
// Sort by base then subtype then name for stable listing
|
||||
out.sort_by(|a, b| a.base.cmp(&b.base).then(a.subtype.cmp(&b.subtype)).then(a.name.cmp(&b.name)));
|
||||
out.sort_by(|a, b| {
|
||||
a.base
|
||||
.cmp(&b.base)
|
||||
.then(a.subtype.cmp(&b.subtype))
|
||||
.then(a.name.cmp(&b.name))
|
||||
});
|
||||
Ok(out)
|
||||
}
|
||||
|
||||
@@ -251,32 +325,42 @@ fn fetch_all_models(client: &Client) -> Result<Vec<ModelEntry>> {
|
||||
let mut v1 = hf_fetch_repo_models(client, "ggerganov/whisper.cpp")?; // main repo must succeed
|
||||
|
||||
// Optional tinydiarize repo; ignore errors but log to stderr
|
||||
let mut v2: Vec<ModelEntry> = match hf_fetch_repo_models(client, "akashmjn/tinydiarize-whisper.cpp") {
|
||||
Ok(v) => v,
|
||||
Err(e) => {
|
||||
qlog!("Warning: failed to fetch optional repo akashmjn/tinydiarize-whisper.cpp: {:#}", e);
|
||||
Vec::new()
|
||||
}
|
||||
};
|
||||
let mut v2: Vec<ModelEntry> =
|
||||
match hf_fetch_repo_models(client, "akashmjn/tinydiarize-whisper.cpp") {
|
||||
Ok(v) => v,
|
||||
Err(e) => {
|
||||
qlog!(
|
||||
"Warning: failed to fetch optional repo akashmjn/tinydiarize-whisper.cpp: {:#}",
|
||||
e
|
||||
);
|
||||
Vec::new()
|
||||
}
|
||||
};
|
||||
|
||||
v1.append(&mut v2);
|
||||
|
||||
// Deduplicate by name preferring ggerganov repo if duplicates
|
||||
let mut map: BTreeMap<String, ModelEntry> = BTreeMap::new();
|
||||
for m in v1 {
|
||||
map.entry(m.name.clone()).and_modify(|existing| {
|
||||
if existing.repo != "ggerganov/whisper.cpp" && m.repo == "ggerganov/whisper.cpp" {
|
||||
*existing = m.clone();
|
||||
}
|
||||
}).or_insert(m);
|
||||
map.entry(m.name.clone())
|
||||
.and_modify(|existing| {
|
||||
if existing.repo != "ggerganov/whisper.cpp" && m.repo == "ggerganov/whisper.cpp" {
|
||||
*existing = m.clone();
|
||||
}
|
||||
})
|
||||
.or_insert(m);
|
||||
}
|
||||
|
||||
let mut list: Vec<ModelEntry> = map.into_values().collect();
|
||||
list.sort_by(|a, b| a.base.cmp(&b.base).then(a.subtype.cmp(&b.subtype)).then(a.name.cmp(&b.name)));
|
||||
list.sort_by(|a, b| {
|
||||
a.base
|
||||
.cmp(&b.base)
|
||||
.then(a.subtype.cmp(&b.subtype))
|
||||
.then(a.name.cmp(&b.name))
|
||||
});
|
||||
Ok(list)
|
||||
}
|
||||
|
||||
|
||||
fn format_model_list(models: &[ModelEntry]) -> String {
|
||||
let mut out = String::new();
|
||||
out.push_str("Available ggml Whisper models:\n");
|
||||
@@ -305,7 +389,9 @@ fn format_model_list(models: &[ModelEntry]) -> String {
|
||||
));
|
||||
idx += 1;
|
||||
}
|
||||
out.push_str("\nEnter selection by indices (e.g., 1 3 5-7), or 'all', '*' for all, 'q' to cancel.\n");
|
||||
out.push_str(
|
||||
"\nEnter selection by indices (e.g., 1 3 5-7), or 'all', '*' for all, 'q' to cancel.\n",
|
||||
);
|
||||
out
|
||||
}
|
||||
|
||||
@@ -335,21 +421,33 @@ fn prompt_select_models_two_stage(models: &[ModelEntry]) -> Result<Vec<ModelEntr
|
||||
eprint!("Select base (number or name, 'q' to cancel): ");
|
||||
io::stderr().flush().ok();
|
||||
let mut line = String::new();
|
||||
io::stdin().read_line(&mut line).context("Failed to read base selection")?;
|
||||
io::stdin()
|
||||
.read_line(&mut line)
|
||||
.context("Failed to read base selection")?;
|
||||
let s = line.trim();
|
||||
if s.eq_ignore_ascii_case("q") || s.eq_ignore_ascii_case("quit") || s.eq_ignore_ascii_case("exit") {
|
||||
if s.eq_ignore_ascii_case("q")
|
||||
|| s.eq_ignore_ascii_case("quit")
|
||||
|| s.eq_ignore_ascii_case("exit")
|
||||
{
|
||||
return Ok(Vec::new());
|
||||
}
|
||||
let chosen_base = if let Ok(i) = s.parse::<usize>() {
|
||||
if i >= 1 && i <= bases.len() { Some(bases[i - 1].clone()) } else { None }
|
||||
if i >= 1 && i <= bases.len() {
|
||||
Some(bases[i - 1].clone())
|
||||
} else {
|
||||
None
|
||||
}
|
||||
} else if !s.is_empty() {
|
||||
// accept exact name match (case-insensitive)
|
||||
bases.iter().find(|b| b.eq_ignore_ascii_case(s)).cloned()
|
||||
} else { None };
|
||||
} else {
|
||||
None
|
||||
};
|
||||
|
||||
if let Some(base) = chosen_base {
|
||||
// 2) Choose sub-type(s) within that base
|
||||
let filtered: Vec<ModelEntry> = models.iter().filter(|m| m.base == base).cloned().collect();
|
||||
let filtered: Vec<ModelEntry> =
|
||||
models.iter().filter(|m| m.base == base).cloned().collect();
|
||||
if filtered.is_empty() {
|
||||
eprintln!("No models found for base '{}'.", base);
|
||||
continue;
|
||||
@@ -370,22 +468,32 @@ fn prompt_select_models_two_stage(models: &[ModelEntry]) -> Result<Vec<ModelEntr
|
||||
eprint!("Selection: ");
|
||||
io::stderr().flush().ok();
|
||||
let mut line2 = String::new();
|
||||
io::stdin().read_line(&mut line2).context("Failed to read selection")?;
|
||||
io::stdin()
|
||||
.read_line(&mut line2)
|
||||
.context("Failed to read selection")?;
|
||||
let s2 = line2.trim().to_lowercase();
|
||||
if s2 == "q" || s2 == "quit" || s2 == "exit" { return Ok(Vec::new()); }
|
||||
if s2 == "q" || s2 == "quit" || s2 == "exit" {
|
||||
return Ok(Vec::new());
|
||||
}
|
||||
let mut selected: Vec<usize> = Vec::new();
|
||||
if s2 == "all" || s2 == "*" {
|
||||
selected = (1..idx).collect();
|
||||
} else if !s2.is_empty() {
|
||||
for part in s2.split(|c| c == ',' || c == ' ' || c == ';') {
|
||||
let part = part.trim();
|
||||
if part.is_empty() { continue; }
|
||||
if part.is_empty() {
|
||||
continue;
|
||||
}
|
||||
if let Some((a, b)) = part.split_once('-') {
|
||||
if let (Ok(ia), Ok(ib)) = (a.parse::<usize>(), b.parse::<usize>()) {
|
||||
if ia >= 1 && ib < idx && ia <= ib { selected.extend(ia..=ib); }
|
||||
if ia >= 1 && ib < idx && ia <= ib {
|
||||
selected.extend(ia..=ib);
|
||||
}
|
||||
}
|
||||
} else if let Ok(i) = part.parse::<usize>() {
|
||||
if i >= 1 && i < idx { selected.push(i); }
|
||||
if i >= 1 && i < idx {
|
||||
selected.push(i);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -395,12 +503,17 @@ fn prompt_select_models_two_stage(models: &[ModelEntry]) -> Result<Vec<ModelEntr
|
||||
eprintln!("No valid selection. Please try again or 'q' to cancel.");
|
||||
continue;
|
||||
}
|
||||
let chosen: Vec<ModelEntry> = selected.into_iter().map(|i| filtered[index_map[i - 1]].clone()).collect();
|
||||
let chosen: Vec<ModelEntry> = selected
|
||||
.into_iter()
|
||||
.map(|i| filtered[index_map[i - 1]].clone())
|
||||
.collect();
|
||||
return Ok(chosen);
|
||||
}
|
||||
} else {
|
||||
eprintln!("Invalid base selection. Please enter a number from 1-{} or a base name.", bases.len());
|
||||
continue;
|
||||
eprintln!(
|
||||
"Invalid base selection. Please enter a number from 1-{} or a base name.",
|
||||
bases.len()
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -413,52 +526,30 @@ fn compute_file_sha256_hex(path: &Path) -> Result<String> {
|
||||
let mut buf = [0u8; 1024 * 128];
|
||||
loop {
|
||||
let n = reader.read(&mut buf).context("Read error during hashing")?;
|
||||
if n == 0 { break; }
|
||||
if n == 0 {
|
||||
break;
|
||||
}
|
||||
hasher.update(&buf[..n]);
|
||||
}
|
||||
Ok(to_hex_lower(&hasher.finalize()))
|
||||
}
|
||||
|
||||
fn models_dir_path() -> std::path::PathBuf {
|
||||
// Highest priority: explicit override
|
||||
if let Ok(p) = env::var("POLYSCRIBE_MODELS_DIR") {
|
||||
let pb = std::path::PathBuf::from(p);
|
||||
if !pb.as_os_str().is_empty() { return pb; }
|
||||
}
|
||||
// In debug builds, keep local ./models for convenience
|
||||
if cfg!(debug_assertions) {
|
||||
return std::path::PathBuf::from("models");
|
||||
}
|
||||
// In release builds, choose a user-writable data directory
|
||||
if let Ok(xdg) = env::var("XDG_DATA_HOME") {
|
||||
if !xdg.is_empty() {
|
||||
return std::path::PathBuf::from(xdg).join("polyscribe").join("models");
|
||||
}
|
||||
}
|
||||
if let Ok(home) = env::var("HOME") {
|
||||
if !home.is_empty() {
|
||||
return std::path::PathBuf::from(home)
|
||||
.join(".local")
|
||||
.join("share")
|
||||
.join("polyscribe")
|
||||
.join("models");
|
||||
}
|
||||
}
|
||||
// Last resort fallback
|
||||
std::path::PathBuf::from("models")
|
||||
}
|
||||
|
||||
/// Interactively list and download Whisper models from Hugging Face into the models directory.
|
||||
pub fn run_interactive_model_downloader() -> Result<()> {
|
||||
let models_dir_buf = models_dir_path();
|
||||
let models_dir_buf = crate::models_dir_path();
|
||||
let models_dir = models_dir_buf.as_path();
|
||||
if !models_dir.exists() { create_dir_all(models_dir).context("Failed to create models directory")?; }
|
||||
if !models_dir.exists() {
|
||||
create_dir_all(models_dir).context("Failed to create models directory")?;
|
||||
}
|
||||
let client = Client::builder()
|
||||
.user_agent("PolyScribe/0.1 (+https://github.com/)")
|
||||
.timeout(std::time::Duration::from_secs(600))
|
||||
.build()
|
||||
.context("Failed to build HTTP client")?;
|
||||
|
||||
qlog!("Fetching online data: contacting Hugging Face to retrieve available models (this may take a moment)...");
|
||||
qlog!(
|
||||
"Fetching online data: contacting Hugging Face to retrieve available models (this may take a moment)..."
|
||||
);
|
||||
let models = fetch_all_models(&client)?;
|
||||
if models.is_empty() {
|
||||
qlog!("No models found on Hugging Face listing. Please try again later.");
|
||||
@@ -470,12 +561,15 @@ pub fn run_interactive_model_downloader() -> Result<()> {
|
||||
return Ok(());
|
||||
}
|
||||
for m in selected {
|
||||
if let Err(e) = download_one_model(&client, models_dir, &m) { qlog!("Error: {:#}", e); }
|
||||
if let Err(e) = download_one_model(&client, models_dir, &m) {
|
||||
qlog!("Error: {:#}", e);
|
||||
}
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub fn download_one_model(client: &Client, models_dir: &Path, entry: &ModelEntry) -> Result<()> {
|
||||
/// Download a single model entry into the given models directory, verifying SHA-256 when available.
|
||||
fn download_one_model(client: &Client, models_dir: &Path, entry: &ModelEntry) -> Result<()> {
|
||||
let final_path = models_dir.join(format!("ggml-{}.bin", entry.name));
|
||||
|
||||
// If the model already exists, verify against online metadata
|
||||
@@ -497,9 +591,10 @@ pub fn download_one_model(client: &Client, models_dir: &Path, entry: &ModelEntry
|
||||
}
|
||||
Err(e) => {
|
||||
qlog!(
|
||||
"Warning: failed to hash existing {}: {}. Will re-download to ensure correctness.",
|
||||
final_path.display(), e
|
||||
);
|
||||
"Warning: failed to hash existing {}: {}. Will re-download to ensure correctness.",
|
||||
final_path.display(),
|
||||
e
|
||||
);
|
||||
}
|
||||
}
|
||||
} else if entry.size > 0 {
|
||||
@@ -508,20 +603,24 @@ pub fn download_one_model(client: &Client, models_dir: &Path, entry: &ModelEntry
|
||||
if md.len() == entry.size {
|
||||
qlog!(
|
||||
"Model {} appears up-to-date by size ({}).",
|
||||
final_path.display(), entry.size
|
||||
final_path.display(),
|
||||
entry.size
|
||||
);
|
||||
return Ok(());
|
||||
} else {
|
||||
qlog!(
|
||||
"Local model {} size ({}) differs from online ({}). Updating...",
|
||||
final_path.display(), md.len(), entry.size
|
||||
final_path.display(),
|
||||
md.len(),
|
||||
entry.size
|
||||
);
|
||||
}
|
||||
}
|
||||
Err(e) => {
|
||||
qlog!(
|
||||
"Warning: failed to stat existing {}: {}. Will re-download to ensure correctness.",
|
||||
final_path.display(), e
|
||||
final_path.display(),
|
||||
e
|
||||
);
|
||||
}
|
||||
}
|
||||
@@ -540,9 +639,16 @@ pub fn download_one_model(client: &Client, models_dir: &Path, entry: &ModelEntry
|
||||
if src_path.exists() {
|
||||
qlog!("Copying {} from {}...", entry.name, src_path.display());
|
||||
let tmp_path = models_dir.join(format!("ggml-{}.bin.part", entry.name));
|
||||
if tmp_path.exists() { let _ = std::fs::remove_file(&tmp_path); }
|
||||
std::fs::copy(&src_path, &tmp_path)
|
||||
.with_context(|| format!("Failed to copy from {} to {}", src_path.display(), tmp_path.display()))?;
|
||||
if tmp_path.exists() {
|
||||
let _ = std::fs::remove_file(&tmp_path);
|
||||
}
|
||||
std::fs::copy(&src_path, &tmp_path).with_context(|| {
|
||||
format!(
|
||||
"Failed to copy from {} to {}",
|
||||
src_path.display(),
|
||||
tmp_path.display()
|
||||
)
|
||||
})?;
|
||||
// Verify hash if available
|
||||
if let Some(expected) = &entry.sha256 {
|
||||
let got = compute_file_sha256_hex(&tmp_path)?;
|
||||
@@ -550,12 +656,16 @@ pub fn download_one_model(client: &Client, models_dir: &Path, entry: &ModelEntry
|
||||
let _ = std::fs::remove_file(&tmp_path);
|
||||
return Err(anyhow!(
|
||||
"SHA-256 mismatch for {} (copied): expected {}, got {}",
|
||||
entry.name, expected, got
|
||||
entry.name,
|
||||
expected,
|
||||
got
|
||||
));
|
||||
}
|
||||
}
|
||||
// Replace existing file safely
|
||||
if final_path.exists() { let _ = std::fs::remove_file(&final_path); }
|
||||
if final_path.exists() {
|
||||
let _ = std::fs::remove_file(&final_path);
|
||||
}
|
||||
std::fs::rename(&tmp_path, &final_path)
|
||||
.with_context(|| format!("Failed to move into place: {}", final_path.display()))?;
|
||||
qlog!("Saved: {}", final_path.display());
|
||||
@@ -563,8 +673,16 @@ pub fn download_one_model(client: &Client, models_dir: &Path, entry: &ModelEntry
|
||||
}
|
||||
}
|
||||
|
||||
let url = format!("https://huggingface.co/{}/resolve/main/ggml-{}.bin", entry.repo, entry.name);
|
||||
qlog!("Downloading {} ({} | {})...", entry.name, human_size(entry.size), url);
|
||||
let url = format!(
|
||||
"https://huggingface.co/{}/resolve/main/ggml-{}.bin",
|
||||
entry.repo, entry.name
|
||||
);
|
||||
qlog!(
|
||||
"Downloading {} ({} | {})...",
|
||||
entry.name,
|
||||
human_size(entry.size),
|
||||
url
|
||||
);
|
||||
let mut resp = client
|
||||
.get(url)
|
||||
.send()
|
||||
@@ -577,14 +695,16 @@ pub fn download_one_model(client: &Client, models_dir: &Path, entry: &ModelEntry
|
||||
}
|
||||
let mut file = std::io::BufWriter::new(
|
||||
File::create(&tmp_path)
|
||||
.with_context(|| format!("Failed to create {}", tmp_path.display()))?
|
||||
.with_context(|| format!("Failed to create {}", tmp_path.display()))?,
|
||||
);
|
||||
|
||||
let mut hasher = Sha256::new();
|
||||
let mut buf = [0u8; 1024 * 128];
|
||||
loop {
|
||||
let n = resp.read(&mut buf).context("Network read error")?;
|
||||
if n == 0 { break; }
|
||||
if n == 0 {
|
||||
break;
|
||||
}
|
||||
hasher.update(&buf[..n]);
|
||||
file.write_all(&buf[..n]).context("Write error")?;
|
||||
}
|
||||
@@ -596,11 +716,16 @@ pub fn download_one_model(client: &Client, models_dir: &Path, entry: &ModelEntry
|
||||
let _ = std::fs::remove_file(&tmp_path);
|
||||
return Err(anyhow!(
|
||||
"SHA-256 mismatch for {}: expected {}, got {}",
|
||||
entry.name, expected, got
|
||||
entry.name,
|
||||
expected,
|
||||
got
|
||||
));
|
||||
}
|
||||
} else {
|
||||
qlog!("Warning: no SHA-256 available for {}. Skipping verification.", entry.name);
|
||||
qlog!(
|
||||
"Warning: no SHA-256 available for {}. Skipping verification.",
|
||||
entry.name
|
||||
);
|
||||
}
|
||||
// Replace existing file safely
|
||||
if final_path.exists() {
|
||||
@@ -612,8 +737,9 @@ pub fn download_one_model(client: &Client, models_dir: &Path, entry: &ModelEntry
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Update locally stored models by re-downloading when size or hash does not match online metadata.
|
||||
pub fn update_local_models() -> Result<()> {
|
||||
let models_dir_buf = models_dir_path();
|
||||
let models_dir_buf = crate::models_dir_path();
|
||||
let models_dir = models_dir_buf.as_path();
|
||||
if !models_dir.exists() {
|
||||
create_dir_all(models_dir).context("Failed to create models directory")?;
|
||||
@@ -627,13 +753,14 @@ pub fn update_local_models() -> Result<()> {
|
||||
.context("Failed to build HTTP client")?;
|
||||
|
||||
// Obtain manifest: env override or online fetch
|
||||
let models: Vec<ModelEntry> = if let Ok(manifest_path) = env::var("POLYSCRIBE_MODELS_MANIFEST") {
|
||||
let models: Vec<ModelEntry> = if let Ok(manifest_path) = env::var("POLYSCRIBE_MODELS_MANIFEST")
|
||||
{
|
||||
let data = std::fs::read_to_string(&manifest_path)
|
||||
.with_context(|| format!("Failed to read manifest at {}", manifest_path))?;
|
||||
let mut list: Vec<ModelEntry> = serde_json::from_str(&data)
|
||||
.with_context(|| format!("Invalid JSON manifest: {}", manifest_path))?;
|
||||
// sort for stability
|
||||
list.sort_by(|a,b| a.name.cmp(&b.name));
|
||||
list.sort_by(|a, b| a.name.cmp(&b.name));
|
||||
list
|
||||
} else {
|
||||
fetch_all_models(&client)?
|
||||
@@ -641,7 +768,9 @@ pub fn update_local_models() -> Result<()> {
|
||||
|
||||
// Map name -> entry for fast lookup
|
||||
let mut map: BTreeMap<String, ModelEntry> = BTreeMap::new();
|
||||
for m in models { map.insert(m.name.clone(), m); }
|
||||
for m in models {
|
||||
map.insert(m.name.clone(), m);
|
||||
}
|
||||
|
||||
// Scan local ggml-*.bin models
|
||||
let rd = std::fs::read_dir(models_dir)
|
||||
@@ -649,10 +778,20 @@ pub fn update_local_models() -> Result<()> {
|
||||
for entry in rd {
|
||||
let entry = entry?;
|
||||
let path = entry.path();
|
||||
if !path.is_file() { continue; }
|
||||
let fname = match path.file_name().and_then(|s| s.to_str()) { Some(s) => s.to_string(), None => continue };
|
||||
if !fname.starts_with("ggml-") || !fname.ends_with(".bin") { continue; }
|
||||
let model_name = fname.trim_start_matches("ggml-").trim_end_matches(".bin").to_string();
|
||||
if !path.is_file() {
|
||||
continue;
|
||||
}
|
||||
let fname = match path.file_name().and_then(|s| s.to_str()) {
|
||||
Some(s) => s.to_string(),
|
||||
None => continue,
|
||||
};
|
||||
if !fname.starts_with("ggml-") || !fname.ends_with(".bin") {
|
||||
continue;
|
||||
}
|
||||
let model_name = fname
|
||||
.trim_start_matches("ggml-")
|
||||
.trim_end_matches(".bin")
|
||||
.to_string();
|
||||
|
||||
if let Some(remote) = map.get(&model_name) {
|
||||
// If SHA256 available, verify and update if mismatch
|
||||
@@ -664,11 +803,11 @@ pub fn update_local_models() -> Result<()> {
|
||||
continue;
|
||||
} else {
|
||||
qlog!(
|
||||
"{} hash differs (local {}.. != remote {}..). Updating...",
|
||||
fname,
|
||||
&local_hash[..std::cmp::min(8, local_hash.len())],
|
||||
&expected[..std::cmp::min(8, expected.len())]
|
||||
);
|
||||
"{} hash differs (local {}.. != remote {}..). Updating...",
|
||||
fname,
|
||||
&local_hash[..std::cmp::min(8, local_hash.len())],
|
||||
&expected[..std::cmp::min(8, expected.len())]
|
||||
);
|
||||
}
|
||||
}
|
||||
Err(e) => {
|
||||
@@ -683,7 +822,12 @@ pub fn update_local_models() -> Result<()> {
|
||||
continue;
|
||||
}
|
||||
Ok(md) => {
|
||||
qlog!("{} size {} differs from remote {}. Updating...", fname, md.len(), remote.size);
|
||||
qlog!(
|
||||
"{} size {} differs from remote {}. Updating...",
|
||||
fname,
|
||||
md.len(),
|
||||
remote.size
|
||||
);
|
||||
download_one_model(&client, models_dir, remote)?;
|
||||
}
|
||||
Err(e) => {
|
||||
@@ -702,20 +846,43 @@ pub fn update_local_models() -> Result<()> {
|
||||
Ok(())
|
||||
}
|
||||
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use tempfile::tempdir;
|
||||
use std::fs;
|
||||
use std::io::Write;
|
||||
use tempfile::tempdir;
|
||||
|
||||
#[test]
|
||||
fn test_format_model_list_spacing_and_structure() {
|
||||
let models = vec![
|
||||
ModelEntry { name: "tiny.en-q5_1".to_string(), base: "tiny".to_string(), subtype: "en-q5_1".to_string(), size: 1024*1024, sha256: Some("0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdef".to_string()), repo: "ggerganov/whisper.cpp".to_string() },
|
||||
ModelEntry { name: "tiny-q5_1".to_string(), base: "tiny".to_string(), subtype: "q5_1".to_string(), size: 2048, sha256: None, repo: "ggerganov/whisper.cpp".to_string() },
|
||||
ModelEntry { name: "base.en-q5_1".to_string(), base: "base".to_string(), subtype: "en-q5_1".to_string(), size: 10, sha256: Some("aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa".to_string()), repo: "akashmjn/tinydiarize-whisper.cpp".to_string() },
|
||||
ModelEntry {
|
||||
name: "tiny.en-q5_1".to_string(),
|
||||
base: "tiny".to_string(),
|
||||
subtype: "en-q5_1".to_string(),
|
||||
size: 1024 * 1024,
|
||||
sha256: Some(
|
||||
"0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdef".to_string(),
|
||||
),
|
||||
repo: "ggerganov/whisper.cpp".to_string(),
|
||||
},
|
||||
ModelEntry {
|
||||
name: "tiny-q5_1".to_string(),
|
||||
base: "tiny".to_string(),
|
||||
subtype: "q5_1".to_string(),
|
||||
size: 2048,
|
||||
sha256: None,
|
||||
repo: "ggerganov/whisper.cpp".to_string(),
|
||||
},
|
||||
ModelEntry {
|
||||
name: "base.en-q5_1".to_string(),
|
||||
base: "base".to_string(),
|
||||
subtype: "en-q5_1".to_string(),
|
||||
size: 10,
|
||||
sha256: Some(
|
||||
"aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa".to_string(),
|
||||
),
|
||||
repo: "akashmjn/tinydiarize-whisper.cpp".to_string(),
|
||||
},
|
||||
];
|
||||
let s = format_model_list(&models);
|
||||
// Header present
|
||||
@@ -724,7 +891,10 @@ mod tests {
|
||||
assert!(s.contains("\ntiny:\n"));
|
||||
assert!(s.contains("\nbase:\n"));
|
||||
// No immediate double space before a bracket after parenthesis
|
||||
assert!(!s.contains(") ["), "should not have double space immediately before bracket");
|
||||
assert!(
|
||||
!s.contains(") ["),
|
||||
"should not have double space immediately before bracket"
|
||||
);
|
||||
// Lines contain normalized spacing around pipes and no hash
|
||||
assert!(s.contains("[ggerganov/whisper.cpp | 1.00 MiB]"));
|
||||
assert!(s.contains("[ggerganov/whisper.cpp | 2.00 KiB]"));
|
||||
@@ -748,7 +918,9 @@ mod tests {
|
||||
hasher.update(data);
|
||||
let out = hasher.finalize();
|
||||
let mut s = String::new();
|
||||
for b in out { s.push_str(&format!("{:02x}", b)); }
|
||||
for b in out {
|
||||
s.push_str(&format!("{:02x}", b));
|
||||
}
|
||||
s
|
||||
}
|
||||
|
||||
@@ -786,7 +958,11 @@ mod tests {
|
||||
"repo": "ggerganov/whisper.cpp"
|
||||
}
|
||||
]);
|
||||
fs::write(&manifest_path, serde_json::to_string_pretty(&manifest).unwrap()).unwrap();
|
||||
fs::write(
|
||||
&manifest_path,
|
||||
serde_json::to_string_pretty(&manifest).unwrap(),
|
||||
)
|
||||
.unwrap();
|
||||
|
||||
// Set env vars to force offline behavior and directories
|
||||
unsafe {
|
||||
@@ -807,34 +983,54 @@ mod tests {
|
||||
#[cfg(debug_assertions)]
|
||||
fn test_models_dir_path_default_debug_and_env_override_models_mod() {
|
||||
// clear override
|
||||
unsafe { std::env::remove_var("POLYSCRIBE_MODELS_DIR"); }
|
||||
assert_eq!(super::models_dir_path(), std::path::PathBuf::from("models"));
|
||||
unsafe {
|
||||
std::env::remove_var("POLYSCRIBE_MODELS_DIR");
|
||||
}
|
||||
assert_eq!(crate::models_dir_path(), std::path::PathBuf::from("models"));
|
||||
// override
|
||||
let tmp = tempfile::tempdir().unwrap();
|
||||
unsafe { std::env::set_var("POLYSCRIBE_MODELS_DIR", tmp.path()); }
|
||||
assert_eq!(super::models_dir_path(), tmp.path().to_path_buf());
|
||||
unsafe {
|
||||
std::env::set_var("POLYSCRIBE_MODELS_DIR", tmp.path());
|
||||
}
|
||||
assert_eq!(crate::models_dir_path(), tmp.path().to_path_buf());
|
||||
// cleanup
|
||||
unsafe { std::env::remove_var("POLYSCRIBE_MODELS_DIR"); }
|
||||
unsafe {
|
||||
std::env::remove_var("POLYSCRIBE_MODELS_DIR");
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
#[cfg(not(debug_assertions))]
|
||||
fn test_models_dir_path_default_release_models_mod() {
|
||||
unsafe { std::env::remove_var("POLYSCRIBE_MODELS_DIR"); }
|
||||
unsafe {
|
||||
std::env::remove_var("POLYSCRIBE_MODELS_DIR");
|
||||
}
|
||||
// With XDG_DATA_HOME set
|
||||
let tmp_xdg = tempfile::tempdir().unwrap();
|
||||
unsafe {
|
||||
std::env::set_var("XDG_DATA_HOME", tmp_xdg.path());
|
||||
std::env::remove_var("HOME");
|
||||
}
|
||||
assert_eq!(super::models_dir_path(), std::path::PathBuf::from(tmp_xdg.path()).join("polyscribe").join("models"));
|
||||
assert_eq!(
|
||||
crate::models_dir_path(),
|
||||
std::path::PathBuf::from(tmp_xdg.path())
|
||||
.join("polyscribe")
|
||||
.join("models")
|
||||
);
|
||||
// With HOME fallback
|
||||
let tmp_home = tempfile::tempdir().unwrap();
|
||||
unsafe {
|
||||
std::env::remove_var("XDG_DATA_HOME");
|
||||
std::env::set_var("HOME", tmp_home.path());
|
||||
}
|
||||
assert_eq!(super::models_dir_path(), std::path::PathBuf::from(tmp_home.path()).join(".local").join("share").join("polyscribe").join("models"));
|
||||
assert_eq!(
|
||||
super::models_dir_path(),
|
||||
std::path::PathBuf::from(tmp_home.path())
|
||||
.join(".local")
|
||||
.join("share")
|
||||
.join("polyscribe")
|
||||
.join("models")
|
||||
);
|
||||
unsafe {
|
||||
std::env::remove_var("XDG_DATA_HOME");
|
||||
std::env::remove_var("HOME");
|
||||
|
Reference in New Issue
Block a user