[refactor] extract model downloading functionality into a separate models module
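The seam left in main.rs is deliberately small: a mod declaration plus one public entry point on the new module. A minimal sketch of the boundary after the split (the wrapper function here is hypothetical, shown only to frame the call site; the real call sites are in the hunks below):

    // src/main.rs
    mod models;

    // Call sites change from run_interactive_model_downloader() to the
    // module-qualified path; everything else the downloader needs now
    // lives privately in src/models.rs.
    fn maybe_run_downloader() {
        if let Err(e) = models::run_interactive_model_downloader() {
            eprintln!("Model downloader failed: {:#}", e);
        }
    }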
TODO.md (7 changed lines)

--- a/TODO.md
+++ b/TODO.md
@@ -1,8 +1,5 @@
-- refactor into multiple files
-
-- fix cli output for model display
-
 - update the project to no more use features
+- update last_model to be only used during one run
 
 - rename project to "PolyScribe"
 
@@ -13,6 +10,8 @@
 - for merging (command line flag) -> if not present, treat each file as separate output (--merge | -m)
 - for merge + separate output -> if present, treat each file as separate output and also output a merged version (--merge-and-separate)
 - set speaker-names per input-file -> prompt user for each file if flag is set (--set-speaker-names)
+- fix cli output for model display
+- refactor into proper cli app
 
 - add support for video files -> use ffmpeg to extract audio
 
src/main.rs (490 changed lines)

--- a/src/main.rs
+++ b/src/main.rs
@@ -3,20 +3,17 @@ use std::io::{self, Read, Write};
 use std::path::{Path, PathBuf};
 use std::process::Command;
 use std::env;
-use std::collections::BTreeMap;
 
 use anyhow::{anyhow, Context, Result};
 use clap::Parser;
 use serde::{Deserialize, Serialize};
 use chrono::Local;
-use reqwest::blocking::Client;
-use reqwest::redirect::Policy;
-use sha2::{Digest, Sha256};
-use std::time::Duration;
 
 #[cfg(feature = "native-whisper")]
 use whisper_rs::{FullParams, SamplingStrategy, WhisperContext, WhisperContextParameters};
 
+mod models;
+
 #[derive(Parser, Debug)]
 #[command(name = "merge_transcripts", version, about = "Merge multiple JSON transcripts into one or transcribe audio using native whisper")]
 struct Args {
@@ -182,7 +179,7 @@ fn find_model_file() -> Result<PathBuf> {
     io::stdin().read_line(&mut input).ok();
     let ans = input.trim().to_lowercase();
     if ans.is_empty() || ans == "y" || ans == "yes" {
-        if let Err(e) = run_interactive_model_downloader() {
+        if let Err(e) = models::run_interactive_model_downloader() {
             eprintln!("Downloader failed: {:#}", e);
         }
         // Re-scan
@@ -346,7 +343,7 @@ fn main() -> Result<()> {
 
     // If requested, run the interactive model downloader first. If no inputs were provided, exit after downloading.
     if args.download_models {
-        if let Err(e) = run_interactive_model_downloader() {
+        if let Err(e) = models::run_interactive_model_downloader() {
             eprintln!("Model downloader failed: {:#}", e);
         }
         if args.inputs.is_empty() {
@@ -503,482 +500,3 @@ fn main() -> Result<()> {
     Ok(())
 }
 
-// --- Model downloader: list & download ggml models from Hugging Face ---
-
-#[derive(Debug, Deserialize)]
-struct HFLfsMeta {
-    oid: Option<String>,
-    size: Option<u64>,
-    sha256: Option<String>,
-}
-
-#[derive(Debug, Deserialize)]
-struct HFSibling {
-    rfilename: String,
-    size: Option<u64>,
-    sha256: Option<String>,
-    lfs: Option<HFLfsMeta>,
-}
-
-#[derive(Debug, Deserialize)]
-struct HFRepoInfo {
-    // When using ?expand=files the field is named 'siblings'
-    siblings: Option<Vec<HFSibling>>,
-}
-
-#[derive(Debug, Deserialize)]
-struct HFTreeItem {
-    path: String,
-    size: Option<u64>,
-    sha256: Option<String>,
-    lfs: Option<HFLfsMeta>,
-}
-
-#[derive(Clone, Debug)]
-struct ModelEntry {
-    // e.g. "tiny.en-q5_1"
-    name: String,
-    base: String,
-    subtype: String,
-    size: u64,
-    sha256: Option<String>,
-    repo: &'static str, // e.g. "ggerganov/whisper.cpp"
-}
-
-fn split_model_name(model: &str) -> (String, String) {
-    let mut idx = None;
-    for (i, ch) in model.char_indices() {
-        if ch == '.' || ch == '-' {
-            idx = Some(i);
-            break;
-        }
-    }
-    if let Some(i) = idx {
-        (model[..i].to_string(), model[i + 1..].to_string())
-    } else {
-        (model.to_string(), String::new())
-    }
-}
-
-fn human_size(bytes: u64) -> String {
-    const KB: f64 = 1024.0;
-    const MB: f64 = KB * 1024.0;
-    const GB: f64 = MB * 1024.0;
-    let b = bytes as f64;
-    if b >= GB { format!("{:.2} GiB", b / GB) }
-    else if b >= MB { format!("{:.2} MiB", b / MB) }
-    else if b >= KB { format!("{:.2} KiB", b / KB) }
-    else { format!("{} B", bytes) }
-}
-
-fn to_hex_lower(bytes: &[u8]) -> String {
-    let mut s = String::with_capacity(bytes.len() * 2);
-    for b in bytes { s.push_str(&format!("{:02x}", b)); }
-    s
-}
-
-fn expected_sha_from_sibling(s: &HFSibling) -> Option<String> {
-    if let Some(h) = &s.sha256 { return Some(h.to_lowercase()); }
-    if let Some(lfs) = &s.lfs {
-        if let Some(h) = &lfs.sha256 { return Some(h.to_lowercase()); }
-        if let Some(oid) = &lfs.oid {
-            // e.g. "sha256:abcdef..."
-            if let Some(rest) = oid.strip_prefix("sha256:") {
-                return Some(rest.to_lowercase().to_string());
-            }
-        }
-    }
-    None
-}
-
-fn size_from_sibling(s: &HFSibling) -> Option<u64> {
-    if let Some(sz) = s.size { return Some(sz); }
-    if let Some(lfs) = &s.lfs { return lfs.size; }
-    None
-}
-
-fn expected_sha_from_tree(s: &HFTreeItem) -> Option<String> {
-    if let Some(h) = &s.sha256 { return Some(h.to_lowercase()); }
-    if let Some(lfs) = &s.lfs {
-        if let Some(h) = &lfs.sha256 { return Some(h.to_lowercase()); }
-        if let Some(oid) = &lfs.oid {
-            if let Some(rest) = oid.strip_prefix("sha256:") {
-                return Some(rest.to_lowercase().to_string());
-            }
-        }
-    }
-    None
-}
-
-fn size_from_tree(s: &HFTreeItem) -> Option<u64> {
-    if let Some(sz) = s.size { return Some(sz); }
-    if let Some(lfs) = &s.lfs { return lfs.size; }
-    None
-}
-
-fn fill_meta_via_head(repo: &'static str, name: &str) -> (Option<u64>, Option<String>) {
-    let head_client = match Client::builder()
-        .user_agent("dialogue_merger/0.1 (+https://github.com/)")
-        .redirect(Policy::none())
-        .timeout(Duration::from_secs(30))
-        .build() {
-        Ok(c) => c,
-        Err(_) => return (None, None),
-    };
-    let url = format!("https://huggingface.co/{}/resolve/main/ggml-{}.bin", repo, name);
-    let resp = match head_client.head(url).send().and_then(|r| r.error_for_status()) {
-        Ok(r) => r,
-        Err(_) => return (None, None),
-    };
-    let headers = resp.headers();
-    let size = headers
-        .get("x-linked-size")
-        .and_then(|v| v.to_str().ok())
-        .and_then(|s| s.parse::<u64>().ok());
-    let mut sha = headers
-        .get("x-linked-etag")
-        .and_then(|v| v.to_str().ok())
-        .map(|s| s.trim().trim_matches('"').to_string());
-    if let Some(h) = &mut sha {
-        h.make_ascii_lowercase();
-        if h.len() != 64 || !h.chars().all(|c| c.is_ascii_hexdigit()) {
-            sha = None;
-        }
-    }
-    // Fallback: try x-xet-hash header if x-linked-etag is missing/invalid
-    if sha.is_none() {
-        sha = headers
-            .get("x-xet-hash")
-            .and_then(|v| v.to_str().ok())
-            .map(|s| s.trim().trim_matches('"').to_string());
-        if let Some(h) = &mut sha {
-            h.make_ascii_lowercase();
-            if h.len() != 64 || !h.chars().all(|c| c.is_ascii_hexdigit()) {
-                sha = None;
-            }
-        }
-    }
-    (size, sha)
-}
-
-fn hf_fetch_repo_models(client: &Client, repo: &'static str) -> Result<Vec<ModelEntry>> {
-    eprintln!("Fetching online data: listing models from {}...", repo);
-    // Prefer the tree endpoint for reliable size/hash metadata, then fall back to model metadata
-    let tree_url = format!("https://huggingface.co/api/models/{}/tree/main?recursive=1", repo);
-    let mut out: Vec<ModelEntry> = Vec::new();
-
-    match client.get(tree_url).send().and_then(|r| r.error_for_status()) {
-        Ok(resp) => {
-            match resp.json::<Vec<HFTreeItem>>() {
-                Ok(items) => {
-                    for it in items {
-                        let path = it.path.clone();
-                        if !(path.starts_with("ggml-") && path.ends_with(".bin")) { continue; }
-                        let model_name = path.trim_start_matches("ggml-").trim_end_matches(".bin").to_string();
-                        let (base, subtype) = split_model_name(&model_name);
-                        let size = size_from_tree(&it).unwrap_or(0);
-                        let sha256 = expected_sha_from_tree(&it);
-                        out.push(ModelEntry { name: model_name, base, subtype, size, sha256, repo });
-                    }
-                }
-                Err(_) => { /* fall back below */ }
-            }
-        }
-        Err(_) => { /* fall back below */ }
-    }
-
-    if out.is_empty() {
-        let url = format!("https://huggingface.co/api/models/{}", repo);
-        let resp = client
-            .get(url)
-            .send()
-            .and_then(|r| r.error_for_status())
-            .context("Failed to query Hugging Face API")?;
-
-        let info: HFRepoInfo = resp.json().context("Failed to parse Hugging Face API response")?;
-
-        if let Some(files) = info.siblings {
-            for s in files {
-                let rf = s.rfilename.clone();
-                if !(rf.starts_with("ggml-") && rf.ends_with(".bin")) { continue; }
-                let model_name = rf.trim_start_matches("ggml-").trim_end_matches(".bin").to_string();
-                let (base, subtype) = split_model_name(&model_name);
-                let size = size_from_sibling(&s).unwrap_or(0);
-                let sha256 = expected_sha_from_sibling(&s);
-                out.push(ModelEntry { name: model_name, base, subtype, size, sha256, repo });
-            }
-        }
-    }
-
-    // Fill missing metadata (size/hash) via HEAD request if necessary
-    if out.iter().any(|m| m.size == 0 || m.sha256.is_none()) {
-        eprintln!("Fetching online data: completing metadata checks for models in {}...", repo);
-    }
-    for m in out.iter_mut() {
-        if m.size == 0 || m.sha256.is_none() {
-            let (sz, sha) = fill_meta_via_head(m.repo, &m.name);
-            if m.size == 0 {
-                if let Some(s) = sz { m.size = s; }
-            }
-            if m.sha256.is_none() {
-                m.sha256 = sha;
-            }
-        }
-    }
-
-    // Sort by base then subtype then name for stable listing
-    out.sort_by(|a, b| a.base.cmp(&b.base).then(a.subtype.cmp(&b.subtype)).then(a.name.cmp(&b.name)));
-    Ok(out)
-}
-
-
-fn fetch_all_models(client: &Client) -> Result<Vec<ModelEntry>> {
-    eprintln!("Fetching online data: aggregating available models from Hugging Face...");
-    let mut v1 = hf_fetch_repo_models(client, "ggerganov/whisper.cpp")?; // main repo must succeed
-
-    // Optional tinydiarize repo; ignore errors but log to stderr
-    let mut v2: Vec<ModelEntry> = match hf_fetch_repo_models(client, "akashmjn/tinydiarize-whisper.cpp") {
-        Ok(v) => v,
-        Err(e) => {
-            eprintln!("Warning: failed to fetch optional repo akashmjn/tinydiarize-whisper.cpp: {:#}", e);
-            Vec::new()
-        }
-    };
-
-    v1.append(&mut v2);
-
-    // Deduplicate by name preferring ggerganov repo if duplicates
-    let mut map: BTreeMap<String, ModelEntry> = BTreeMap::new();
-    for m in v1 {
-        map.entry(m.name.clone()).and_modify(|existing| {
-            if existing.repo != "ggerganov/whisper.cpp" && m.repo == "ggerganov/whisper.cpp" {
-                *existing = m.clone();
-            }
-        }).or_insert(m);
-    }
-
-    let mut list: Vec<ModelEntry> = map.into_values().collect();
-    list.sort_by(|a, b| a.base.cmp(&b.base).then(a.subtype.cmp(&b.subtype)).then(a.name.cmp(&b.name)));
-    Ok(list)
-}
-
-fn print_grouped_models(models: &[ModelEntry]) {
-    let mut current = "".to_string();
-    for m in models {
-        if m.base != current {
-            current = m.base.clone();
-            println!("\n{}:", current);
-        }
-        let short_hash = m
-            .sha256
-            .as_ref()
-            .map(|h| h.chars().take(8).collect::<String>())
-            .unwrap_or_else(|| "-".to_string());
-        println!(" - {} [{} | {} | {}]", m.name, m.repo, human_size(m.size), short_hash);
-    }
-    println!("\nEnter selection by indices (e.g., 1 3 5-7), or 'all', '*' for all, 'q' to cancel.");
-}
-
-fn prompt_select_models(models: &[ModelEntry]) -> Result<Vec<ModelEntry>> {
-    // Build a flat list but show group headers; indices count only models
-    println!("Available ggml Whisper models:");
-    let mut idx = 1usize;
-    let mut current = "".to_string();
-    // We'll record mapping from index -> position in models
-    let mut index_map: Vec<usize> = Vec::with_capacity(models.len());
-    for (pos, m) in models.iter().enumerate() {
-        if m.base != current {
-            current = m.base.clone();
-            println!("\n{}:", current);
-        }
-        let short_hash = m
-            .sha256
-            .as_ref()
-            .map(|h| h.chars().take(8).collect::<String>())
-            .unwrap_or_else(|| "-".to_string());
-        println!(" {}) {} [{} | {} | {}]", idx, m.name, m.repo, human_size(m.size), short_hash);
-        index_map.push(pos);
-        idx += 1;
-    }
-    println!("\nEnter selection by indices (e.g., 1 3 5-7), or 'all', '*' for all, 'q' to cancel.");
-    loop {
-        eprint!("Selection: ");
-        io::stderr().flush().ok();
-        let mut line = String::new();
-        io::stdin().read_line(&mut line).context("Failed to read selection")?;
-        let s = line.trim().to_lowercase();
-        if s == "q" || s == "quit" || s == "exit" { return Ok(Vec::new()); }
-        let mut selected: Vec<usize> = Vec::new();
-        if s == "all" || s == "*" {
-            selected = (1..idx).collect();
-        } else if !s.is_empty() {
-            for part in s.split(|c| c == ',' || c == ' ' || c == ';') {
-                let part = part.trim();
-                if part.is_empty() { continue; }
-                if let Some((a, b)) = part.split_once('-') {
-                    if let (Ok(ia), Ok(ib)) = (a.parse::<usize>(), b.parse::<usize>()) {
-                        if ia >= 1 && ib < idx && ia <= ib { selected.extend(ia..=ib); }
-                    }
-                } else if let Ok(i) = part.parse::<usize>() {
-                    if i >= 1 && i < idx { selected.push(i); }
-                }
-            }
-        }
-        selected.sort_unstable();
-        selected.dedup();
-        if selected.is_empty() {
-            eprintln!("No valid selection. Please try again or 'q' to cancel.");
-            continue;
-        }
-        let chosen: Vec<ModelEntry> = selected.into_iter().map(|i| models[index_map[i - 1]].clone()).collect();
-        return Ok(chosen);
-    }
-}
-
-fn compute_file_sha256_hex(path: &Path) -> Result<String> {
-    let file = File::open(path)
-        .with_context(|| format!("Failed to open for hashing: {}", path.display()))?;
-    let mut reader = std::io::BufReader::new(file);
-    let mut hasher = Sha256::new();
-    let mut buf = [0u8; 1024 * 128];
-    loop {
-        let n = reader.read(&mut buf).context("Read error during hashing")?;
-        if n == 0 { break; }
-        hasher.update(&buf[..n]);
-    }
-    Ok(to_hex_lower(&hasher.finalize()))
-}
-
-fn download_one_model(client: &Client, models_dir: &Path, entry: &ModelEntry) -> Result<()> {
-    let final_path = models_dir.join(format!("ggml-{}.bin", entry.name));
-
-    // If the model already exists, verify against online metadata
-    if final_path.exists() {
-        if let Some(expected) = &entry.sha256 {
-            match compute_file_sha256_hex(&final_path) {
-                Ok(local_hash) => {
-                    if local_hash.eq_ignore_ascii_case(expected) {
-                        eprintln!("Model {} is up-to-date (hash match).", final_path.display());
-                        return Ok(());
-                    } else {
-                        eprintln!(
-                            "Local model {} hash differs from online (local {}.., online {}..). Updating...",
-                            final_path.display(),
-                            &local_hash[..std::cmp::min(8, local_hash.len())],
-                            &expected[..std::cmp::min(8, expected.len())]
-                        );
-                    }
-                }
-                Err(e) => {
-                    eprintln!(
-                        "Warning: failed to hash existing {}: {}. Will re-download to ensure correctness.",
-                        final_path.display(), e
-                    );
-                }
-            }
-        } else if entry.size > 0 {
-            match std::fs::metadata(&final_path) {
-                Ok(md) => {
-                    if md.len() == entry.size {
-                        eprintln!(
-                            "Model {} appears up-to-date by size ({}).",
-                            final_path.display(), entry.size
-                        );
-                        return Ok(());
-                    } else {
-                        eprintln!(
-                            "Local model {} size ({}) differs from online ({}). Updating...",
-                            final_path.display(), md.len(), entry.size
-                        );
-                    }
-                }
-                Err(e) => {
-                    eprintln!(
-                        "Warning: failed to stat existing {}: {}. Will re-download to ensure correctness.",
-                        final_path.display(), e
-                    );
-                }
-            }
-        } else {
-            eprintln!(
-                "Model {} exists but remote hash/size not available; will download to verify contents.",
-                final_path.display()
-            );
-            // Fall through to download for content comparison
-        }
-    }
-
-    let url = format!("https://huggingface.co/{}/resolve/main/ggml-{}.bin", entry.repo, entry.name);
-    eprintln!("Downloading {} ({} | {})...", entry.name, human_size(entry.size), url);
-    let mut resp = client
-        .get(url)
-        .send()
-        .and_then(|r| r.error_for_status())
-        .context("Failed to download model")?;
-
-    let tmp_path = models_dir.join(format!("ggml-{}.bin.part", entry.name));
-    if tmp_path.exists() {
-        let _ = std::fs::remove_file(&tmp_path);
-    }
-    let mut file = std::io::BufWriter::new(
-        File::create(&tmp_path)
-            .with_context(|| format!("Failed to create {}", tmp_path.display()))?
-    );
-
-    let mut hasher = Sha256::new();
-    let mut buf = [0u8; 1024 * 128];
-    loop {
-        let n = resp.read(&mut buf).context("Network read error")?;
-        if n == 0 { break; }
-        hasher.update(&buf[..n]);
-        file.write_all(&buf[..n]).context("Write error")?;
-    }
-    file.flush().ok();
-
-    let got = to_hex_lower(&hasher.finalize());
-    if let Some(expected) = &entry.sha256 {
-        if got != expected.to_lowercase() {
-            let _ = std::fs::remove_file(&tmp_path);
-            return Err(anyhow!(
-                "SHA-256 mismatch for {}: expected {}, got {}",
-                entry.name, expected, got
-            ));
-        }
-    } else {
-        eprintln!("Warning: no SHA-256 available for {}. Skipping verification.", entry.name);
-    }
-    // Replace existing file safely
-    if final_path.exists() {
-        let _ = std::fs::remove_file(&final_path);
-    }
-    std::fs::rename(&tmp_path, &final_path)
-        .with_context(|| format!("Failed to move into place: {}", final_path.display()))?;
-    eprintln!("Saved: {}", final_path.display());
-    Ok(())
-}
-
-fn run_interactive_model_downloader() -> Result<()> {
-    let models_dir = Path::new("models");
-    if !models_dir.exists() { create_dir_all(models_dir).context("Failed to create models directory")?; }
-    let client = Client::builder()
-        .user_agent("dialogue_merger/0.1 (+https://github.com/)")
-        .timeout(std::time::Duration::from_secs(600))
-        .build()
-        .context("Failed to build HTTP client")?;
-
-    eprintln!("Fetching online data: contacting Hugging Face to retrieve available models (this may take a moment)...");
-    let models = fetch_all_models(&client)?;
-    if models.is_empty() {
-        eprintln!("No models found on Hugging Face listing. Please try again later.");
-        return Ok(());
-    }
-    let selected = prompt_select_models(&models)?;
-    if selected.is_empty() {
-        eprintln!("No selection. Aborting download.");
-        return Ok(());
-    }
-    for m in selected {
-        if let Err(e) = download_one_model(&client, models_dir, &m) { eprintln!("Error: {:#}", e); }
-    }
-    Ok(())
-}
src/models.rs (474 added lines, new file)

--- /dev/null
+++ b/src/models.rs
@@ -0,0 +1,474 @@
+use std::fs::{File, create_dir_all};
+use std::io::{self, Read, Write};
+use std::path::Path;
+use std::collections::BTreeMap;
+use std::time::Duration;
+
+use anyhow::{anyhow, Context, Result};
+use serde::Deserialize;
+use reqwest::blocking::Client;
+use reqwest::redirect::Policy;
+use sha2::{Digest, Sha256};
+
+// --- Model downloader: list & download ggml models from Hugging Face ---
+
+#[derive(Debug, Deserialize)]
+struct HFLfsMeta {
+    oid: Option<String>,
+    size: Option<u64>,
+    sha256: Option<String>,
+}
+
+#[derive(Debug, Deserialize)]
+struct HFSibling {
+    rfilename: String,
+    size: Option<u64>,
+    sha256: Option<String>,
+    lfs: Option<HFLfsMeta>,
+}
+
+#[derive(Debug, Deserialize)]
+struct HFRepoInfo {
+    // When using ?expand=files the field is named 'siblings'
+    siblings: Option<Vec<HFSibling>>,
+}
+
+#[derive(Debug, Deserialize)]
+struct HFTreeItem {
+    path: String,
+    size: Option<u64>,
+    sha256: Option<String>,
+    lfs: Option<HFLfsMeta>,
+}
+
+#[derive(Clone, Debug)]
+struct ModelEntry {
+    // e.g. "tiny.en-q5_1"
+    name: String,
+    base: String,
+    subtype: String,
+    size: u64,
+    sha256: Option<String>,
+    repo: &'static str, // e.g. "ggerganov/whisper.cpp"
+}
+
+fn split_model_name(model: &str) -> (String, String) {
+    let mut idx = None;
+    for (i, ch) in model.char_indices() {
+        if ch == '.' || ch == '-' {
+            idx = Some(i);
+            break;
+        }
+    }
+    if let Some(i) = idx {
+        (model[..i].to_string(), model[i + 1..].to_string())
+    } else {
+        (model.to_string(), String::new())
+    }
+}
+
+fn human_size(bytes: u64) -> String {
+    const KB: f64 = 1024.0;
+    const MB: f64 = KB * 1024.0;
+    const GB: f64 = MB * 1024.0;
+    let b = bytes as f64;
+    if b >= GB { format!("{:.2} GiB", b / GB) }
+    else if b >= MB { format!("{:.2} MiB", b / MB) }
+    else if b >= KB { format!("{:.2} KiB", b / KB) }
+    else { format!("{} B", bytes) }
+}
+
+fn to_hex_lower(bytes: &[u8]) -> String {
+    let mut s = String::with_capacity(bytes.len() * 2);
+    for b in bytes { s.push_str(&format!("{:02x}", b)); }
+    s
+}
+
+fn expected_sha_from_sibling(s: &HFSibling) -> Option<String> {
+    if let Some(h) = &s.sha256 { return Some(h.to_lowercase()); }
+    if let Some(lfs) = &s.lfs {
+        if let Some(h) = &lfs.sha256 { return Some(h.to_lowercase()); }
+        if let Some(oid) = &lfs.oid {
+            // e.g. "sha256:abcdef..."
+            if let Some(rest) = oid.strip_prefix("sha256:") {
+                return Some(rest.to_lowercase().to_string());
+            }
+        }
+    }
+    None
+}
+
+fn size_from_sibling(s: &HFSibling) -> Option<u64> {
+    if let Some(sz) = s.size { return Some(sz); }
+    if let Some(lfs) = &s.lfs { return lfs.size; }
+    None
+}
+
+fn expected_sha_from_tree(s: &HFTreeItem) -> Option<String> {
+    if let Some(h) = &s.sha256 { return Some(h.to_lowercase()); }
+    if let Some(lfs) = &s.lfs {
+        if let Some(h) = &lfs.sha256 { return Some(h.to_lowercase()); }
+        if let Some(oid) = &lfs.oid {
+            if let Some(rest) = oid.strip_prefix("sha256:") {
+                return Some(rest.to_lowercase().to_string());
+            }
+        }
+    }
+    None
+}
+
+fn size_from_tree(s: &HFTreeItem) -> Option<u64> {
+    if let Some(sz) = s.size { return Some(sz); }
+    if let Some(lfs) = &s.lfs { return lfs.size; }
+    None
+}
+
+fn fill_meta_via_head(repo: &'static str, name: &str) -> (Option<u64>, Option<String>) {
+    let head_client = match Client::builder()
+        .user_agent("dialogue_merger/0.1 (+https://github.com/)")
+        .redirect(Policy::none())
+        .timeout(Duration::from_secs(30))
+        .build() {
+        Ok(c) => c,
+        Err(_) => return (None, None),
+    };
+    let url = format!("https://huggingface.co/{}/resolve/main/ggml-{}.bin", repo, name);
+    let resp = match head_client.head(url).send().and_then(|r| r.error_for_status()) {
+        Ok(r) => r,
+        Err(_) => return (None, None),
+    };
+    let headers = resp.headers();
+    let size = headers
+        .get("x-linked-size")
+        .and_then(|v| v.to_str().ok())
+        .and_then(|s| s.parse::<u64>().ok());
+    let mut sha = headers
+        .get("x-linked-etag")
+        .and_then(|v| v.to_str().ok())
+        .map(|s| s.trim().trim_matches('"').to_string());
+    if let Some(h) = &mut sha {
+        h.make_ascii_lowercase();
+        if h.len() != 64 || !h.chars().all(|c| c.is_ascii_hexdigit()) {
+            sha = None;
+        }
+    }
+    // Fallback: try x-xet-hash header if x-linked-etag is missing/invalid
+    if sha.is_none() {
+        sha = headers
+            .get("x-xet-hash")
+            .and_then(|v| v.to_str().ok())
+            .map(|s| s.trim().trim_matches('"').to_string());
+        if let Some(h) = &mut sha {
+            h.make_ascii_lowercase();
+            if h.len() != 64 || !h.chars().all(|c| c.is_ascii_hexdigit()) {
+                sha = None;
+            }
+        }
+    }
+    (size, sha)
+}
+
+fn hf_fetch_repo_models(client: &Client, repo: &'static str) -> Result<Vec<ModelEntry>> {
+    eprintln!("Fetching online data: listing models from {}...", repo);
+    // Prefer the tree endpoint for reliable size/hash metadata, then fall back to model metadata
+    let tree_url = format!("https://huggingface.co/api/models/{}/tree/main?recursive=1", repo);
+    let mut out: Vec<ModelEntry> = Vec::new();
+
+    match client.get(tree_url).send().and_then(|r| r.error_for_status()) {
+        Ok(resp) => {
+            match resp.json::<Vec<HFTreeItem>>() {
+                Ok(items) => {
+                    for it in items {
+                        let path = it.path.clone();
+                        if !(path.starts_with("ggml-") && path.ends_with(".bin")) { continue; }
+                        let model_name = path.trim_start_matches("ggml-").trim_end_matches(".bin").to_string();
+                        let (base, subtype) = split_model_name(&model_name);
+                        let size = size_from_tree(&it).unwrap_or(0);
+                        let sha256 = expected_sha_from_tree(&it);
+                        out.push(ModelEntry { name: model_name, base, subtype, size, sha256, repo });
+                    }
+                }
+                Err(_) => { /* fall back below */ }
+            }
+        }
+        Err(_) => { /* fall back below */ }
+    }
+
+    if out.is_empty() {
+        let url = format!("https://huggingface.co/api/models/{}", repo);
+        let resp = client
+            .get(url)
+            .send()
+            .and_then(|r| r.error_for_status())
+            .context("Failed to query Hugging Face API")?;
+
+        let info: HFRepoInfo = resp.json().context("Failed to parse Hugging Face API response")?;
+
+        if let Some(files) = info.siblings {
+            for s in files {
+                let rf = s.rfilename.clone();
+                if !(rf.starts_with("ggml-") && rf.ends_with(".bin")) { continue; }
+                let model_name = rf.trim_start_matches("ggml-").trim_end_matches(".bin").to_string();
+                let (base, subtype) = split_model_name(&model_name);
+                let size = size_from_sibling(&s).unwrap_or(0);
+                let sha256 = expected_sha_from_sibling(&s);
+                out.push(ModelEntry { name: model_name, base, subtype, size, sha256, repo });
+            }
+        }
+    }
+
+    // Fill missing metadata (size/hash) via HEAD request if necessary
+    if out.iter().any(|m| m.size == 0 || m.sha256.is_none()) {
+        eprintln!("Fetching online data: completing metadata checks for models in {}...", repo);
+    }
+    for m in out.iter_mut() {
+        if m.size == 0 || m.sha256.is_none() {
+            let (sz, sha) = fill_meta_via_head(m.repo, &m.name);
+            if m.size == 0 {
+                if let Some(s) = sz { m.size = s; }
+            }
+            if m.sha256.is_none() {
+                m.sha256 = sha;
+            }
+        }
+    }
+
+    // Sort by base then subtype then name for stable listing
+    out.sort_by(|a, b| a.base.cmp(&b.base).then(a.subtype.cmp(&b.subtype)).then(a.name.cmp(&b.name)));
+    Ok(out)
+}
+
+fn fetch_all_models(client: &Client) -> Result<Vec<ModelEntry>> {
+    eprintln!("Fetching online data: aggregating available models from Hugging Face...");
+    let mut v1 = hf_fetch_repo_models(client, "ggerganov/whisper.cpp")?; // main repo must succeed
+
+    // Optional tinydiarize repo; ignore errors but log to stderr
+    let mut v2: Vec<ModelEntry> = match hf_fetch_repo_models(client, "akashmjn/tinydiarize-whisper.cpp") {
+        Ok(v) => v,
+        Err(e) => {
+            eprintln!("Warning: failed to fetch optional repo akashmjn/tinydiarize-whisper.cpp: {:#}", e);
+            Vec::new()
+        }
+    };
+
+    v1.append(&mut v2);
+
+    // Deduplicate by name preferring ggerganov repo if duplicates
+    let mut map: BTreeMap<String, ModelEntry> = BTreeMap::new();
+    for m in v1 {
+        map.entry(m.name.clone()).and_modify(|existing| {
+            if existing.repo != "ggerganov/whisper.cpp" && m.repo == "ggerganov/whisper.cpp" {
+                *existing = m.clone();
+            }
+        }).or_insert(m);
+    }
+
+    let mut list: Vec<ModelEntry> = map.into_values().collect();
+    list.sort_by(|a, b| a.base.cmp(&b.base).then(a.subtype.cmp(&b.subtype)).then(a.name.cmp(&b.name)));
+    Ok(list)
+}
+
+
+fn prompt_select_models(models: &[ModelEntry]) -> Result<Vec<ModelEntry>> {
+    // Build a flat list but show group headers; indices count only models
+    println!("Available ggml Whisper models:");
+    let mut idx = 1usize;
+    let mut current = "".to_string();
+    // We'll record mapping from index -> position in models
+    let mut index_map: Vec<usize> = Vec::with_capacity(models.len());
+    for (pos, m) in models.iter().enumerate() {
+        if m.base != current {
+            current = m.base.clone();
+            println!("\n{}:", current);
+        }
+        let short_hash = m
+            .sha256
+            .as_ref()
+            .map(|h| h.chars().take(8).collect::<String>())
+            .unwrap_or_else(|| "-".to_string());
+        println!(" {}) {} [{} | {} | {}]", idx, m.name, m.repo, human_size(m.size), short_hash);
+        index_map.push(pos);
+        idx += 1;
+    }
+    println!("\nEnter selection by indices (e.g., 1 3 5-7), or 'all', '*' for all, 'q' to cancel.");
+    loop {
+        eprint!("Selection: ");
+        io::stderr().flush().ok();
+        let mut line = String::new();
+        io::stdin().read_line(&mut line).context("Failed to read selection")?;
+        let s = line.trim().to_lowercase();
+        if s == "q" || s == "quit" || s == "exit" { return Ok(Vec::new()); }
+        let mut selected: Vec<usize> = Vec::new();
+        if s == "all" || s == "*" {
+            selected = (1..idx).collect();
+        } else if !s.is_empty() {
+            for part in s.split(|c| c == ',' || c == ' ' || c == ';') {
+                let part = part.trim();
+                if part.is_empty() { continue; }
+                if let Some((a, b)) = part.split_once('-') {
+                    if let (Ok(ia), Ok(ib)) = (a.parse::<usize>(), b.parse::<usize>()) {
+                        if ia >= 1 && ib < idx && ia <= ib { selected.extend(ia..=ib); }
+                    }
+                } else if let Ok(i) = part.parse::<usize>() {
+                    if i >= 1 && i < idx { selected.push(i); }
+                }
+            }
+        }
+        selected.sort_unstable();
+        selected.dedup();
+        if selected.is_empty() {
+            eprintln!("No valid selection. Please try again or 'q' to cancel.");
+            continue;
+        }
+        let chosen: Vec<ModelEntry> = selected.into_iter().map(|i| models[index_map[i - 1]].clone()).collect();
+        return Ok(chosen);
+    }
+}
+
+fn compute_file_sha256_hex(path: &Path) -> Result<String> {
+    let file = File::open(path)
+        .with_context(|| format!("Failed to open for hashing: {}", path.display()))?;
+    let mut reader = std::io::BufReader::new(file);
+    let mut hasher = Sha256::new();
+    let mut buf = [0u8; 1024 * 128];
+    loop {
+        let n = reader.read(&mut buf).context("Read error during hashing")?;
+        if n == 0 { break; }
+        hasher.update(&buf[..n]);
+    }
+    Ok(to_hex_lower(&hasher.finalize()))
+}
+
+pub fn run_interactive_model_downloader() -> Result<()> {
+    let models_dir = Path::new("models");
+    if !models_dir.exists() { create_dir_all(models_dir).context("Failed to create models directory")?; }
+    let client = Client::builder()
+        .user_agent("dialogue_merger/0.1 (+https://github.com/)")
+        .timeout(std::time::Duration::from_secs(600))
+        .build()
+        .context("Failed to build HTTP client")?;
+
+    eprintln!("Fetching online data: contacting Hugging Face to retrieve available models (this may take a moment)...");
+    let models = fetch_all_models(&client)?;
+    if models.is_empty() {
+        eprintln!("No models found on Hugging Face listing. Please try again later.");
+        return Ok(());
+    }
+    let selected = prompt_select_models(&models)?;
+    if selected.is_empty() {
+        eprintln!("No selection. Aborting download.");
+        return Ok(());
+    }
+    for m in selected {
+        if let Err(e) = download_one_model(&client, models_dir, &m) { eprintln!("Error: {:#}", e); }
+    }
+    Ok(())
+}
+
+fn download_one_model(client: &Client, models_dir: &Path, entry: &ModelEntry) -> Result<()> {
+    let final_path = models_dir.join(format!("ggml-{}.bin", entry.name));
+
+    // If the model already exists, verify against online metadata
+    if final_path.exists() {
+        if let Some(expected) = &entry.sha256 {
+            match compute_file_sha256_hex(&final_path) {
+                Ok(local_hash) => {
+                    if local_hash.eq_ignore_ascii_case(expected) {
+                        eprintln!("Model {} is up-to-date (hash match).", final_path.display());
+                        return Ok(());
+                    } else {
+                        eprintln!(
+                            "Local model {} hash differs from online (local {}.., online {}..). Updating...",
+                            final_path.display(),
+                            &local_hash[..std::cmp::min(8, local_hash.len())],
+                            &expected[..std::cmp::min(8, expected.len())]
+                        );
+                    }
+                }
+                Err(e) => {
+                    eprintln!(
+                        "Warning: failed to hash existing {}: {}. Will re-download to ensure correctness.",
+                        final_path.display(), e
+                    );
+                }
+            }
+        } else if entry.size > 0 {
+            match std::fs::metadata(&final_path) {
+                Ok(md) => {
+                    if md.len() == entry.size {
+                        eprintln!(
+                            "Model {} appears up-to-date by size ({}).",
+                            final_path.display(), entry.size
+                        );
+                        return Ok(());
+                    } else {
+                        eprintln!(
+                            "Local model {} size ({}) differs from online ({}). Updating...",
+                            final_path.display(), md.len(), entry.size
+                        );
+                    }
+                }
+                Err(e) => {
+                    eprintln!(
+                        "Warning: failed to stat existing {}: {}. Will re-download to ensure correctness.",
+                        final_path.display(), e
+                    );
+                }
+            }
+        } else {
+            eprintln!(
+                "Model {} exists but remote hash/size not available; will download to verify contents.",
+                final_path.display()
+            );
+            // Fall through to download for content comparison
+        }
+    }
+
+    let url = format!("https://huggingface.co/{}/resolve/main/ggml-{}.bin", entry.repo, entry.name);
+    eprintln!("Downloading {} ({} | {})...", entry.name, human_size(entry.size), url);
+    let mut resp = client
+        .get(url)
+        .send()
+        .and_then(|r| r.error_for_status())
+        .context("Failed to download model")?;
+
+    let tmp_path = models_dir.join(format!("ggml-{}.bin.part", entry.name));
+    if tmp_path.exists() {
+        let _ = std::fs::remove_file(&tmp_path);
+    }
+    let mut file = std::io::BufWriter::new(
+        File::create(&tmp_path)
+            .with_context(|| format!("Failed to create {}", tmp_path.display()))?
+    );
+
+    let mut hasher = Sha256::new();
+    let mut buf = [0u8; 1024 * 128];
+    loop {
+        let n = resp.read(&mut buf).context("Network read error")?;
+        if n == 0 { break; }
+        hasher.update(&buf[..n]);
+        file.write_all(&buf[..n]).context("Write error")?;
+    }
+    file.flush().ok();
+
+    let got = to_hex_lower(&hasher.finalize());
+    if let Some(expected) = &entry.sha256 {
+        if got != expected.to_lowercase() {
+            let _ = std::fs::remove_file(&tmp_path);
+            return Err(anyhow!(
+                "SHA-256 mismatch for {}: expected {}, got {}",
+                entry.name, expected, got
+            ));
+        }
+    } else {
+        eprintln!("Warning: no SHA-256 available for {}. Skipping verification.", entry.name);
+    }
+    // Replace existing file safely
+    if final_path.exists() {
+        let _ = std::fs::remove_file(&final_path);
+    }
+    std::fs::rename(&tmp_path, &final_path)
+        .with_context(|| format!("Failed to move into place: {}", final_path.display()))?;
+    eprintln!("Saved: {}", final_path.display());
+    Ok(())
+}
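Since the moved helpers are pure functions, the split makes them straightforward to exercise in isolation. A minimal test sketch (hypothetical, not part of this commit) of the behavior of split_model_name and human_size as defined above:

    #[cfg(test)]
    mod tests {
        use super::*;

        #[test]
        fn splits_on_first_separator() {
            // The base/subtype split happens at the first '.' or '-', so
            // quantization suffixes stay attached to the subtype.
            assert_eq!(split_model_name("tiny.en-q5_1"),
                       ("tiny".to_string(), "en-q5_1".to_string()));
            assert_eq!(split_model_name("large"),
                       ("large".to_string(), String::new()));
        }

        #[test]
        fn formats_binary_units() {
            assert_eq!(human_size(512), "512 B");
            assert_eq!(human_size(1536), "1.50 KiB");
        }
    }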