From a6009693ef6b503c1813175e709fd89a9282917d Mon Sep 17 00:00:00 2001 From: vikingowl Date: Fri, 8 Aug 2025 08:45:19 +0200 Subject: [PATCH] [refactor] extract model downloading functionality into a separate `models` module --- TODO.md | 7 +- src/main.rs | 490 +------------------------------------------------- src/models.rs | 474 ++++++++++++++++++++++++++++++++++++++++++++++++ 3 files changed, 481 insertions(+), 490 deletions(-) create mode 100644 src/models.rs diff --git a/TODO.md b/TODO.md index 76c75a0..c87ed08 100644 --- a/TODO.md +++ b/TODO.md @@ -1,8 +1,5 @@ -- refactor into multiple files - -- fix cli output for model display - - update the project to no more use features +- update last_model to be only used during one run - rename project to "PolyScribe" @@ -13,6 +10,8 @@ - for merging (command line flag) -> if not present, treat each file as separate output (--merge | -m) - for merge + separate output -> if present, treat each file as separate output and also output a merged version (--merge-and-separate) - set speaker-names per input-file -> prompt user for each file if flag is set (--set-speaker-names) +- fix cli output for model display +- refactor into proper cli app - add support for video files -> use ffmpeg to extract audio diff --git a/src/main.rs b/src/main.rs index 2dd77fd..2febfb6 100644 --- a/src/main.rs +++ b/src/main.rs @@ -3,20 +3,17 @@ use std::io::{self, Read, Write}; use std::path::{Path, PathBuf}; use std::process::Command; use std::env; -use std::collections::BTreeMap; use anyhow::{anyhow, Context, Result}; use clap::Parser; use serde::{Deserialize, Serialize}; use chrono::Local; -use reqwest::blocking::Client; -use reqwest::redirect::Policy; -use sha2::{Digest, Sha256}; -use std::time::Duration; #[cfg(feature = "native-whisper")] use whisper_rs::{FullParams, SamplingStrategy, WhisperContext, WhisperContextParameters}; +mod models; + #[derive(Parser, Debug)] #[command(name = "merge_transcripts", version, about = "Merge multiple JSON transcripts into one or transcribe audio using native whisper")] struct Args { @@ -182,7 +179,7 @@ fn find_model_file() -> Result { io::stdin().read_line(&mut input).ok(); let ans = input.trim().to_lowercase(); if ans.is_empty() || ans == "y" || ans == "yes" { - if let Err(e) = run_interactive_model_downloader() { + if let Err(e) = models::run_interactive_model_downloader() { eprintln!("Downloader failed: {:#}", e); } // Re-scan @@ -346,7 +343,7 @@ fn main() -> Result<()> { // If requested, run the interactive model downloader first. If no inputs were provided, exit after downloading. if args.download_models { - if let Err(e) = run_interactive_model_downloader() { + if let Err(e) = models::run_interactive_model_downloader() { eprintln!("Model downloader failed: {:#}", e); } if args.inputs.is_empty() { @@ -503,482 +500,3 @@ fn main() -> Result<()> { Ok(()) } -// --- Model downloader: list & download ggml models from Hugging Face --- - -#[derive(Debug, Deserialize)] -struct HFLfsMeta { - oid: Option, - size: Option, - sha256: Option, -} - -#[derive(Debug, Deserialize)] -struct HFSibling { - rfilename: String, - size: Option, - sha256: Option, - lfs: Option, -} - -#[derive(Debug, Deserialize)] -struct HFRepoInfo { - // When using ?expand=files the field is named 'siblings' - siblings: Option>, -} - -#[derive(Debug, Deserialize)] -struct HFTreeItem { - path: String, - size: Option, - sha256: Option, - lfs: Option, -} - -#[derive(Clone, Debug)] -struct ModelEntry { - // e.g. 
"tiny.en-q5_1" - name: String, - base: String, - subtype: String, - size: u64, - sha256: Option, - repo: &'static str, // e.g. "ggerganov/whisper.cpp" -} - -fn split_model_name(model: &str) -> (String, String) { - let mut idx = None; - for (i, ch) in model.char_indices() { - if ch == '.' || ch == '-' { - idx = Some(i); - break; - } - } - if let Some(i) = idx { - (model[..i].to_string(), model[i + 1..].to_string()) - } else { - (model.to_string(), String::new()) - } -} - -fn human_size(bytes: u64) -> String { - const KB: f64 = 1024.0; - const MB: f64 = KB * 1024.0; - const GB: f64 = MB * 1024.0; - let b = bytes as f64; - if b >= GB { format!("{:.2} GiB", b / GB) } - else if b >= MB { format!("{:.2} MiB", b / MB) } - else if b >= KB { format!("{:.2} KiB", b / KB) } - else { format!("{} B", bytes) } -} - -fn to_hex_lower(bytes: &[u8]) -> String { - let mut s = String::with_capacity(bytes.len() * 2); - for b in bytes { s.push_str(&format!("{:02x}", b)); } - s -} - -fn expected_sha_from_sibling(s: &HFSibling) -> Option { - if let Some(h) = &s.sha256 { return Some(h.to_lowercase()); } - if let Some(lfs) = &s.lfs { - if let Some(h) = &lfs.sha256 { return Some(h.to_lowercase()); } - if let Some(oid) = &lfs.oid { - // e.g. "sha256:abcdef..." - if let Some(rest) = oid.strip_prefix("sha256:") { - return Some(rest.to_lowercase().to_string()); - } - } - } - None -} - -fn size_from_sibling(s: &HFSibling) -> Option { - if let Some(sz) = s.size { return Some(sz); } - if let Some(lfs) = &s.lfs { return lfs.size; } - None -} - -fn expected_sha_from_tree(s: &HFTreeItem) -> Option { - if let Some(h) = &s.sha256 { return Some(h.to_lowercase()); } - if let Some(lfs) = &s.lfs { - if let Some(h) = &lfs.sha256 { return Some(h.to_lowercase()); } - if let Some(oid) = &lfs.oid { - if let Some(rest) = oid.strip_prefix("sha256:") { - return Some(rest.to_lowercase().to_string()); - } - } - } - None -} - -fn size_from_tree(s: &HFTreeItem) -> Option { - if let Some(sz) = s.size { return Some(sz); } - if let Some(lfs) = &s.lfs { return lfs.size; } - None -} - -fn fill_meta_via_head(repo: &'static str, name: &str) -> (Option, Option) { - let head_client = match Client::builder() - .user_agent("dialogue_merger/0.1 (+https://github.com/)") - .redirect(Policy::none()) - .timeout(Duration::from_secs(30)) - .build() { - Ok(c) => c, - Err(_) => return (None, None), - }; - let url = format!("https://huggingface.co/{}/resolve/main/ggml-{}.bin", repo, name); - let resp = match head_client.head(url).send().and_then(|r| r.error_for_status()) { - Ok(r) => r, - Err(_) => return (None, None), - }; - let headers = resp.headers(); - let size = headers - .get("x-linked-size") - .and_then(|v| v.to_str().ok()) - .and_then(|s| s.parse::().ok()); - let mut sha = headers - .get("x-linked-etag") - .and_then(|v| v.to_str().ok()) - .map(|s| s.trim().trim_matches('"').to_string()); - if let Some(h) = &mut sha { - h.make_ascii_lowercase(); - if h.len() != 64 || !h.chars().all(|c| c.is_ascii_hexdigit()) { - sha = None; - } - } - // Fallback: try x-xet-hash header if x-linked-etag is missing/invalid - if sha.is_none() { - sha = headers - .get("x-xet-hash") - .and_then(|v| v.to_str().ok()) - .map(|s| s.trim().trim_matches('"').to_string()); - if let Some(h) = &mut sha { - h.make_ascii_lowercase(); - if h.len() != 64 || !h.chars().all(|c| c.is_ascii_hexdigit()) { - sha = None; - } - } - } - (size, sha) -} - -fn hf_fetch_repo_models(client: &Client, repo: &'static str) -> Result> { - eprintln!("Fetching online data: listing models from {}...", repo); - // 
Prefer the tree endpoint for reliable size/hash metadata, then fall back to model metadata - let tree_url = format!("https://huggingface.co/api/models/{}/tree/main?recursive=1", repo); - let mut out: Vec = Vec::new(); - - match client.get(tree_url).send().and_then(|r| r.error_for_status()) { - Ok(resp) => { - match resp.json::>() { - Ok(items) => { - for it in items { - let path = it.path.clone(); - if !(path.starts_with("ggml-") && path.ends_with(".bin")) { continue; } - let model_name = path.trim_start_matches("ggml-").trim_end_matches(".bin").to_string(); - let (base, subtype) = split_model_name(&model_name); - let size = size_from_tree(&it).unwrap_or(0); - let sha256 = expected_sha_from_tree(&it); - out.push(ModelEntry { name: model_name, base, subtype, size, sha256, repo }); - } - } - Err(_) => { /* fall back below */ } - } - } - Err(_) => { /* fall back below */ } - } - - if out.is_empty() { - let url = format!("https://huggingface.co/api/models/{}", repo); - let resp = client - .get(url) - .send() - .and_then(|r| r.error_for_status()) - .context("Failed to query Hugging Face API")?; - - let info: HFRepoInfo = resp.json().context("Failed to parse Hugging Face API response")?; - - if let Some(files) = info.siblings { - for s in files { - let rf = s.rfilename.clone(); - if !(rf.starts_with("ggml-") && rf.ends_with(".bin")) { continue; } - let model_name = rf.trim_start_matches("ggml-").trim_end_matches(".bin").to_string(); - let (base, subtype) = split_model_name(&model_name); - let size = size_from_sibling(&s).unwrap_or(0); - let sha256 = expected_sha_from_sibling(&s); - out.push(ModelEntry { name: model_name, base, subtype, size, sha256, repo }); - } - } - } - - // Fill missing metadata (size/hash) via HEAD request if necessary - if out.iter().any(|m| m.size == 0 || m.sha256.is_none()) { - eprintln!("Fetching online data: completing metadata checks for models in {}...", repo); - } - for m in out.iter_mut() { - if m.size == 0 || m.sha256.is_none() { - let (sz, sha) = fill_meta_via_head(m.repo, &m.name); - if m.size == 0 { - if let Some(s) = sz { m.size = s; } - } - if m.sha256.is_none() { - m.sha256 = sha; - } - } - } - - // Sort by base then subtype then name for stable listing - out.sort_by(|a, b| a.base.cmp(&b.base).then(a.subtype.cmp(&b.subtype)).then(a.name.cmp(&b.name))); - Ok(out) -} - - -fn fetch_all_models(client: &Client) -> Result> { - eprintln!("Fetching online data: aggregating available models from Hugging Face..."); - let mut v1 = hf_fetch_repo_models(client, "ggerganov/whisper.cpp")?; // main repo must succeed - - // Optional tinydiarize repo; ignore errors but log to stderr - let mut v2: Vec = match hf_fetch_repo_models(client, "akashmjn/tinydiarize-whisper.cpp") { - Ok(v) => v, - Err(e) => { - eprintln!("Warning: failed to fetch optional repo akashmjn/tinydiarize-whisper.cpp: {:#}", e); - Vec::new() - } - }; - - v1.append(&mut v2); - - // Deduplicate by name preferring ggerganov repo if duplicates - let mut map: BTreeMap = BTreeMap::new(); - for m in v1 { - map.entry(m.name.clone()).and_modify(|existing| { - if existing.repo != "ggerganov/whisper.cpp" && m.repo == "ggerganov/whisper.cpp" { - *existing = m.clone(); - } - }).or_insert(m); - } - - let mut list: Vec = map.into_values().collect(); - list.sort_by(|a, b| a.base.cmp(&b.base).then(a.subtype.cmp(&b.subtype)).then(a.name.cmp(&b.name))); - Ok(list) -} - -fn print_grouped_models(models: &[ModelEntry]) { - let mut current = "".to_string(); - for m in models { - if m.base != current { - current = m.base.clone(); - 
println!("\n{}:", current); - } - let short_hash = m - .sha256 - .as_ref() - .map(|h| h.chars().take(8).collect::()) - .unwrap_or_else(|| "-".to_string()); - println!(" - {} [{} | {} | {}]", m.name, m.repo, human_size(m.size), short_hash); - } - println!("\nEnter selection by indices (e.g., 1 3 5-7), or 'all', '*' for all, 'q' to cancel."); -} - -fn prompt_select_models(models: &[ModelEntry]) -> Result> { - // Build a flat list but show group headers; indices count only models - println!("Available ggml Whisper models:"); - let mut idx = 1usize; - let mut current = "".to_string(); - // We'll record mapping from index -> position in models - let mut index_map: Vec = Vec::with_capacity(models.len()); - for (pos, m) in models.iter().enumerate() { - if m.base != current { - current = m.base.clone(); - println!("\n{}:", current); - } - let short_hash = m - .sha256 - .as_ref() - .map(|h| h.chars().take(8).collect::()) - .unwrap_or_else(|| "-".to_string()); - println!(" {}) {} [{} | {} | {}]", idx, m.name, m.repo, human_size(m.size), short_hash); - index_map.push(pos); - idx += 1; - } - println!("\nEnter selection by indices (e.g., 1 3 5-7), or 'all', '*' for all, 'q' to cancel."); - loop { - eprint!("Selection: "); - io::stderr().flush().ok(); - let mut line = String::new(); - io::stdin().read_line(&mut line).context("Failed to read selection")?; - let s = line.trim().to_lowercase(); - if s == "q" || s == "quit" || s == "exit" { return Ok(Vec::new()); } - let mut selected: Vec = Vec::new(); - if s == "all" || s == "*" { - selected = (1..idx).collect(); - } else if !s.is_empty() { - for part in s.split(|c| c == ',' || c == ' ' || c == ';') { - let part = part.trim(); - if part.is_empty() { continue; } - if let Some((a, b)) = part.split_once('-') { - if let (Ok(ia), Ok(ib)) = (a.parse::(), b.parse::()) { - if ia >= 1 && ib < idx && ia <= ib { selected.extend(ia..=ib); } - } - } else if let Ok(i) = part.parse::() { - if i >= 1 && i < idx { selected.push(i); } - } - } - } - selected.sort_unstable(); - selected.dedup(); - if selected.is_empty() { - eprintln!("No valid selection. Please try again or 'q' to cancel."); - continue; - } - let chosen: Vec = selected.into_iter().map(|i| models[index_map[i - 1]].clone()).collect(); - return Ok(chosen); - } -} - -fn compute_file_sha256_hex(path: &Path) -> Result { - let file = File::open(path) - .with_context(|| format!("Failed to open for hashing: {}", path.display()))?; - let mut reader = std::io::BufReader::new(file); - let mut hasher = Sha256::new(); - let mut buf = [0u8; 1024 * 128]; - loop { - let n = reader.read(&mut buf).context("Read error during hashing")?; - if n == 0 { break; } - hasher.update(&buf[..n]); - } - Ok(to_hex_lower(&hasher.finalize())) -} - -fn download_one_model(client: &Client, models_dir: &Path, entry: &ModelEntry) -> Result<()> { - let final_path = models_dir.join(format!("ggml-{}.bin", entry.name)); - - // If the model already exists, verify against online metadata - if final_path.exists() { - if let Some(expected) = &entry.sha256 { - match compute_file_sha256_hex(&final_path) { - Ok(local_hash) => { - if local_hash.eq_ignore_ascii_case(expected) { - eprintln!("Model {} is up-to-date (hash match).", final_path.display()); - return Ok(()); - } else { - eprintln!( - "Local model {} hash differs from online (local {}.., online {}..). 
Updating...", - final_path.display(), - &local_hash[..std::cmp::min(8, local_hash.len())], - &expected[..std::cmp::min(8, expected.len())] - ); - } - } - Err(e) => { - eprintln!( - "Warning: failed to hash existing {}: {}. Will re-download to ensure correctness.", - final_path.display(), e - ); - } - } - } else if entry.size > 0 { - match std::fs::metadata(&final_path) { - Ok(md) => { - if md.len() == entry.size { - eprintln!( - "Model {} appears up-to-date by size ({}).", - final_path.display(), entry.size - ); - return Ok(()); - } else { - eprintln!( - "Local model {} size ({}) differs from online ({}). Updating...", - final_path.display(), md.len(), entry.size - ); - } - } - Err(e) => { - eprintln!( - "Warning: failed to stat existing {}: {}. Will re-download to ensure correctness.", - final_path.display(), e - ); - } - } - } else { - eprintln!( - "Model {} exists but remote hash/size not available; will download to verify contents.", - final_path.display() - ); - // Fall through to download for content comparison - } - } - - let url = format!("https://huggingface.co/{}/resolve/main/ggml-{}.bin", entry.repo, entry.name); - eprintln!("Downloading {} ({} | {})...", entry.name, human_size(entry.size), url); - let mut resp = client - .get(url) - .send() - .and_then(|r| r.error_for_status()) - .context("Failed to download model")?; - - let tmp_path = models_dir.join(format!("ggml-{}.bin.part", entry.name)); - if tmp_path.exists() { - let _ = std::fs::remove_file(&tmp_path); - } - let mut file = std::io::BufWriter::new( - File::create(&tmp_path) - .with_context(|| format!("Failed to create {}", tmp_path.display()))? - ); - - let mut hasher = Sha256::new(); - let mut buf = [0u8; 1024 * 128]; - loop { - let n = resp.read(&mut buf).context("Network read error")?; - if n == 0 { break; } - hasher.update(&buf[..n]); - file.write_all(&buf[..n]).context("Write error")?; - } - file.flush().ok(); - - let got = to_hex_lower(&hasher.finalize()); - if let Some(expected) = &entry.sha256 { - if got != expected.to_lowercase() { - let _ = std::fs::remove_file(&tmp_path); - return Err(anyhow!( - "SHA-256 mismatch for {}: expected {}, got {}", - entry.name, expected, got - )); - } - } else { - eprintln!("Warning: no SHA-256 available for {}. Skipping verification.", entry.name); - } - // Replace existing file safely - if final_path.exists() { - let _ = std::fs::remove_file(&final_path); - } - std::fs::rename(&tmp_path, &final_path) - .with_context(|| format!("Failed to move into place: {}", final_path.display()))?; - eprintln!("Saved: {}", final_path.display()); - Ok(()) -} - -fn run_interactive_model_downloader() -> Result<()> { - let models_dir = Path::new("models"); - if !models_dir.exists() { create_dir_all(models_dir).context("Failed to create models directory")?; } - let client = Client::builder() - .user_agent("dialogue_merger/0.1 (+https://github.com/)") - .timeout(std::time::Duration::from_secs(600)) - .build() - .context("Failed to build HTTP client")?; - - eprintln!("Fetching online data: contacting Hugging Face to retrieve available models (this may take a moment)..."); - let models = fetch_all_models(&client)?; - if models.is_empty() { - eprintln!("No models found on Hugging Face listing. Please try again later."); - return Ok(()); - } - let selected = prompt_select_models(&models)?; - if selected.is_empty() { - eprintln!("No selection. 
Aborting download."); - return Ok(()); - } - for m in selected { - if let Err(e) = download_one_model(&client, models_dir, &m) { eprintln!("Error: {:#}", e); } - } - Ok(()) -} diff --git a/src/models.rs b/src/models.rs new file mode 100644 index 0000000..d22db8d --- /dev/null +++ b/src/models.rs @@ -0,0 +1,474 @@ +use std::fs::{File, create_dir_all}; +use std::io::{self, Read, Write}; +use std::path::Path; +use std::collections::BTreeMap; +use std::time::Duration; + +use anyhow::{anyhow, Context, Result}; +use serde::Deserialize; +use reqwest::blocking::Client; +use reqwest::redirect::Policy; +use sha2::{Digest, Sha256}; + +// --- Model downloader: list & download ggml models from Hugging Face --- + +#[derive(Debug, Deserialize)] +struct HFLfsMeta { + oid: Option, + size: Option, + sha256: Option, +} + +#[derive(Debug, Deserialize)] +struct HFSibling { + rfilename: String, + size: Option, + sha256: Option, + lfs: Option, +} + +#[derive(Debug, Deserialize)] +struct HFRepoInfo { + // When using ?expand=files the field is named 'siblings' + siblings: Option>, +} + +#[derive(Debug, Deserialize)] +struct HFTreeItem { + path: String, + size: Option, + sha256: Option, + lfs: Option, +} + +#[derive(Clone, Debug)] +struct ModelEntry { + // e.g. "tiny.en-q5_1" + name: String, + base: String, + subtype: String, + size: u64, + sha256: Option, + repo: &'static str, // e.g. "ggerganov/whisper.cpp" +} + +fn split_model_name(model: &str) -> (String, String) { + let mut idx = None; + for (i, ch) in model.char_indices() { + if ch == '.' || ch == '-' { + idx = Some(i); + break; + } + } + if let Some(i) = idx { + (model[..i].to_string(), model[i + 1..].to_string()) + } else { + (model.to_string(), String::new()) + } +} + +fn human_size(bytes: u64) -> String { + const KB: f64 = 1024.0; + const MB: f64 = KB * 1024.0; + const GB: f64 = MB * 1024.0; + let b = bytes as f64; + if b >= GB { format!("{:.2} GiB", b / GB) } + else if b >= MB { format!("{:.2} MiB", b / MB) } + else if b >= KB { format!("{:.2} KiB", b / KB) } + else { format!("{} B", bytes) } +} + +fn to_hex_lower(bytes: &[u8]) -> String { + let mut s = String::with_capacity(bytes.len() * 2); + for b in bytes { s.push_str(&format!("{:02x}", b)); } + s +} + +fn expected_sha_from_sibling(s: &HFSibling) -> Option { + if let Some(h) = &s.sha256 { return Some(h.to_lowercase()); } + if let Some(lfs) = &s.lfs { + if let Some(h) = &lfs.sha256 { return Some(h.to_lowercase()); } + if let Some(oid) = &lfs.oid { + // e.g. "sha256:abcdef..." 
+ if let Some(rest) = oid.strip_prefix("sha256:") { + return Some(rest.to_lowercase().to_string()); + } + } + } + None +} + +fn size_from_sibling(s: &HFSibling) -> Option { + if let Some(sz) = s.size { return Some(sz); } + if let Some(lfs) = &s.lfs { return lfs.size; } + None +} + +fn expected_sha_from_tree(s: &HFTreeItem) -> Option { + if let Some(h) = &s.sha256 { return Some(h.to_lowercase()); } + if let Some(lfs) = &s.lfs { + if let Some(h) = &lfs.sha256 { return Some(h.to_lowercase()); } + if let Some(oid) = &lfs.oid { + if let Some(rest) = oid.strip_prefix("sha256:") { + return Some(rest.to_lowercase().to_string()); + } + } + } + None +} + +fn size_from_tree(s: &HFTreeItem) -> Option { + if let Some(sz) = s.size { return Some(sz); } + if let Some(lfs) = &s.lfs { return lfs.size; } + None +} + +fn fill_meta_via_head(repo: &'static str, name: &str) -> (Option, Option) { + let head_client = match Client::builder() + .user_agent("dialogue_merger/0.1 (+https://github.com/)") + .redirect(Policy::none()) + .timeout(Duration::from_secs(30)) + .build() { + Ok(c) => c, + Err(_) => return (None, None), + }; + let url = format!("https://huggingface.co/{}/resolve/main/ggml-{}.bin", repo, name); + let resp = match head_client.head(url).send().and_then(|r| r.error_for_status()) { + Ok(r) => r, + Err(_) => return (None, None), + }; + let headers = resp.headers(); + let size = headers + .get("x-linked-size") + .and_then(|v| v.to_str().ok()) + .and_then(|s| s.parse::().ok()); + let mut sha = headers + .get("x-linked-etag") + .and_then(|v| v.to_str().ok()) + .map(|s| s.trim().trim_matches('"').to_string()); + if let Some(h) = &mut sha { + h.make_ascii_lowercase(); + if h.len() != 64 || !h.chars().all(|c| c.is_ascii_hexdigit()) { + sha = None; + } + } + // Fallback: try x-xet-hash header if x-linked-etag is missing/invalid + if sha.is_none() { + sha = headers + .get("x-xet-hash") + .and_then(|v| v.to_str().ok()) + .map(|s| s.trim().trim_matches('"').to_string()); + if let Some(h) = &mut sha { + h.make_ascii_lowercase(); + if h.len() != 64 || !h.chars().all(|c| c.is_ascii_hexdigit()) { + sha = None; + } + } + } + (size, sha) +} + +fn hf_fetch_repo_models(client: &Client, repo: &'static str) -> Result> { + eprintln!("Fetching online data: listing models from {}...", repo); + // Prefer the tree endpoint for reliable size/hash metadata, then fall back to model metadata + let tree_url = format!("https://huggingface.co/api/models/{}/tree/main?recursive=1", repo); + let mut out: Vec = Vec::new(); + + match client.get(tree_url).send().and_then(|r| r.error_for_status()) { + Ok(resp) => { + match resp.json::>() { + Ok(items) => { + for it in items { + let path = it.path.clone(); + if !(path.starts_with("ggml-") && path.ends_with(".bin")) { continue; } + let model_name = path.trim_start_matches("ggml-").trim_end_matches(".bin").to_string(); + let (base, subtype) = split_model_name(&model_name); + let size = size_from_tree(&it).unwrap_or(0); + let sha256 = expected_sha_from_tree(&it); + out.push(ModelEntry { name: model_name, base, subtype, size, sha256, repo }); + } + } + Err(_) => { /* fall back below */ } + } + } + Err(_) => { /* fall back below */ } + } + + if out.is_empty() { + let url = format!("https://huggingface.co/api/models/{}", repo); + let resp = client + .get(url) + .send() + .and_then(|r| r.error_for_status()) + .context("Failed to query Hugging Face API")?; + + let info: HFRepoInfo = resp.json().context("Failed to parse Hugging Face API response")?; + + if let Some(files) = info.siblings { + for s in 
files { + let rf = s.rfilename.clone(); + if !(rf.starts_with("ggml-") && rf.ends_with(".bin")) { continue; } + let model_name = rf.trim_start_matches("ggml-").trim_end_matches(".bin").to_string(); + let (base, subtype) = split_model_name(&model_name); + let size = size_from_sibling(&s).unwrap_or(0); + let sha256 = expected_sha_from_sibling(&s); + out.push(ModelEntry { name: model_name, base, subtype, size, sha256, repo }); + } + } + } + + // Fill missing metadata (size/hash) via HEAD request if necessary + if out.iter().any(|m| m.size == 0 || m.sha256.is_none()) { + eprintln!("Fetching online data: completing metadata checks for models in {}...", repo); + } + for m in out.iter_mut() { + if m.size == 0 || m.sha256.is_none() { + let (sz, sha) = fill_meta_via_head(m.repo, &m.name); + if m.size == 0 { + if let Some(s) = sz { m.size = s; } + } + if m.sha256.is_none() { + m.sha256 = sha; + } + } + } + + // Sort by base then subtype then name for stable listing + out.sort_by(|a, b| a.base.cmp(&b.base).then(a.subtype.cmp(&b.subtype)).then(a.name.cmp(&b.name))); + Ok(out) +} + +fn fetch_all_models(client: &Client) -> Result> { + eprintln!("Fetching online data: aggregating available models from Hugging Face..."); + let mut v1 = hf_fetch_repo_models(client, "ggerganov/whisper.cpp")?; // main repo must succeed + + // Optional tinydiarize repo; ignore errors but log to stderr + let mut v2: Vec = match hf_fetch_repo_models(client, "akashmjn/tinydiarize-whisper.cpp") { + Ok(v) => v, + Err(e) => { + eprintln!("Warning: failed to fetch optional repo akashmjn/tinydiarize-whisper.cpp: {:#}", e); + Vec::new() + } + }; + + v1.append(&mut v2); + + // Deduplicate by name preferring ggerganov repo if duplicates + let mut map: BTreeMap = BTreeMap::new(); + for m in v1 { + map.entry(m.name.clone()).and_modify(|existing| { + if existing.repo != "ggerganov/whisper.cpp" && m.repo == "ggerganov/whisper.cpp" { + *existing = m.clone(); + } + }).or_insert(m); + } + + let mut list: Vec = map.into_values().collect(); + list.sort_by(|a, b| a.base.cmp(&b.base).then(a.subtype.cmp(&b.subtype)).then(a.name.cmp(&b.name))); + Ok(list) +} + + +fn prompt_select_models(models: &[ModelEntry]) -> Result> { + // Build a flat list but show group headers; indices count only models + println!("Available ggml Whisper models:"); + let mut idx = 1usize; + let mut current = "".to_string(); + // We'll record mapping from index -> position in models + let mut index_map: Vec = Vec::with_capacity(models.len()); + for (pos, m) in models.iter().enumerate() { + if m.base != current { + current = m.base.clone(); + println!("\n{}:", current); + } + let short_hash = m + .sha256 + .as_ref() + .map(|h| h.chars().take(8).collect::()) + .unwrap_or_else(|| "-".to_string()); + println!(" {}) {} [{} | {} | {}]", idx, m.name, m.repo, human_size(m.size), short_hash); + index_map.push(pos); + idx += 1; + } + println!("\nEnter selection by indices (e.g., 1 3 5-7), or 'all', '*' for all, 'q' to cancel."); + loop { + eprint!("Selection: "); + io::stderr().flush().ok(); + let mut line = String::new(); + io::stdin().read_line(&mut line).context("Failed to read selection")?; + let s = line.trim().to_lowercase(); + if s == "q" || s == "quit" || s == "exit" { return Ok(Vec::new()); } + let mut selected: Vec = Vec::new(); + if s == "all" || s == "*" { + selected = (1..idx).collect(); + } else if !s.is_empty() { + for part in s.split(|c| c == ',' || c == ' ' || c == ';') { + let part = part.trim(); + if part.is_empty() { continue; } + if let Some((a, b)) = 
part.split_once('-') { + if let (Ok(ia), Ok(ib)) = (a.parse::(), b.parse::()) { + if ia >= 1 && ib < idx && ia <= ib { selected.extend(ia..=ib); } + } + } else if let Ok(i) = part.parse::() { + if i >= 1 && i < idx { selected.push(i); } + } + } + } + selected.sort_unstable(); + selected.dedup(); + if selected.is_empty() { + eprintln!("No valid selection. Please try again or 'q' to cancel."); + continue; + } + let chosen: Vec = selected.into_iter().map(|i| models[index_map[i - 1]].clone()).collect(); + return Ok(chosen); + } +} + +fn compute_file_sha256_hex(path: &Path) -> Result { + let file = File::open(path) + .with_context(|| format!("Failed to open for hashing: {}", path.display()))?; + let mut reader = std::io::BufReader::new(file); + let mut hasher = Sha256::new(); + let mut buf = [0u8; 1024 * 128]; + loop { + let n = reader.read(&mut buf).context("Read error during hashing")?; + if n == 0 { break; } + hasher.update(&buf[..n]); + } + Ok(to_hex_lower(&hasher.finalize())) +} + +pub fn run_interactive_model_downloader() -> Result<()> { + let models_dir = Path::new("models"); + if !models_dir.exists() { create_dir_all(models_dir).context("Failed to create models directory")?; } + let client = Client::builder() + .user_agent("dialogue_merger/0.1 (+https://github.com/)") + .timeout(std::time::Duration::from_secs(600)) + .build() + .context("Failed to build HTTP client")?; + + eprintln!("Fetching online data: contacting Hugging Face to retrieve available models (this may take a moment)..."); + let models = fetch_all_models(&client)?; + if models.is_empty() { + eprintln!("No models found on Hugging Face listing. Please try again later."); + return Ok(()); + } + let selected = prompt_select_models(&models)?; + if selected.is_empty() { + eprintln!("No selection. Aborting download."); + return Ok(()); + } + for m in selected { + if let Err(e) = download_one_model(&client, models_dir, &m) { eprintln!("Error: {:#}", e); } + } + Ok(()) +} + +fn download_one_model(client: &Client, models_dir: &Path, entry: &ModelEntry) -> Result<()> { + let final_path = models_dir.join(format!("ggml-{}.bin", entry.name)); + + // If the model already exists, verify against online metadata + if final_path.exists() { + if let Some(expected) = &entry.sha256 { + match compute_file_sha256_hex(&final_path) { + Ok(local_hash) => { + if local_hash.eq_ignore_ascii_case(expected) { + eprintln!("Model {} is up-to-date (hash match).", final_path.display()); + return Ok(()); + } else { + eprintln!( + "Local model {} hash differs from online (local {}.., online {}..). Updating...", + final_path.display(), + &local_hash[..std::cmp::min(8, local_hash.len())], + &expected[..std::cmp::min(8, expected.len())] + ); + } + } + Err(e) => { + eprintln!( + "Warning: failed to hash existing {}: {}. Will re-download to ensure correctness.", + final_path.display(), e + ); + } + } + } else if entry.size > 0 { + match std::fs::metadata(&final_path) { + Ok(md) => { + if md.len() == entry.size { + eprintln!( + "Model {} appears up-to-date by size ({}).", + final_path.display(), entry.size + ); + return Ok(()); + } else { + eprintln!( + "Local model {} size ({}) differs from online ({}). Updating...", + final_path.display(), md.len(), entry.size + ); + } + } + Err(e) => { + eprintln!( + "Warning: failed to stat existing {}: {}. 
Will re-download to ensure correctness.", + final_path.display(), e + ); + } + } + } else { + eprintln!( + "Model {} exists but remote hash/size not available; will download to verify contents.", + final_path.display() + ); + // Fall through to download for content comparison + } + } + + let url = format!("https://huggingface.co/{}/resolve/main/ggml-{}.bin", entry.repo, entry.name); + eprintln!("Downloading {} ({} | {})...", entry.name, human_size(entry.size), url); + let mut resp = client + .get(url) + .send() + .and_then(|r| r.error_for_status()) + .context("Failed to download model")?; + + let tmp_path = models_dir.join(format!("ggml-{}.bin.part", entry.name)); + if tmp_path.exists() { + let _ = std::fs::remove_file(&tmp_path); + } + let mut file = std::io::BufWriter::new( + File::create(&tmp_path) + .with_context(|| format!("Failed to create {}", tmp_path.display()))? + ); + + let mut hasher = Sha256::new(); + let mut buf = [0u8; 1024 * 128]; + loop { + let n = resp.read(&mut buf).context("Network read error")?; + if n == 0 { break; } + hasher.update(&buf[..n]); + file.write_all(&buf[..n]).context("Write error")?; + } + file.flush().ok(); + + let got = to_hex_lower(&hasher.finalize()); + if let Some(expected) = &entry.sha256 { + if got != expected.to_lowercase() { + let _ = std::fs::remove_file(&tmp_path); + return Err(anyhow!( + "SHA-256 mismatch for {}: expected {}, got {}", + entry.name, expected, got + )); + } + } else { + eprintln!("Warning: no SHA-256 available for {}. Skipping verification.", entry.name); + } + // Replace existing file safely + if final_path.exists() { + let _ = std::fs::remove_file(&final_path); + } + std::fs::rename(&tmp_path, &final_path) + .with_context(|| format!("Failed to move into place: {}", final_path.display()))?; + eprintln!("Saved: {}", final_path.display()); + Ok(()) +}
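
Not part of the patch itself: a minimal sketch of how the naming convention encoded in split_model_name() could be pinned down with a unit test at the bottom of src/models.rs (the function is private, so the test has to live in the same file). The expected tuples are read off the implementation above, which splits at the first '.' or '-'; the test names are illustrative only.

    // Hypothetical test module, not included in the commit above.
    #[cfg(test)]
    mod tests {
        use super::split_model_name;

        #[test]
        fn splits_at_first_separator() {
            // "tiny.en-q5_1" groups under base "tiny" with subtype "en-q5_1".
            assert_eq!(
                split_model_name("tiny.en-q5_1"),
                ("tiny".to_string(), "en-q5_1".to_string())
            );
            // '-' is also a separator: "large-v3" -> ("large", "v3").
            assert_eq!(
                split_model_name("large-v3"),
                ("large".to_string(), "v3".to_string())
            );
            // No separator: the whole name is the base, the subtype is empty.
            assert_eq!(
                split_model_name("base"),
                ("base".to_string(), String::new())
            );
        }
    }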