[feat] enable parallel model downloads with bounded concurrency and TTY-aware progress bars

This commit is contained in:
2025-08-12 10:16:21 +02:00
parent 75cfb6f160
commit f41f1a4117

View File

@@ -14,7 +14,7 @@ use reqwest::blocking::Client;
use reqwest::redirect::Policy;
use serde::Deserialize;
use sha2::{Digest, Sha256};
use indicatif::{ProgressBar, ProgressStyle};
use indicatif::{ProgressBar, ProgressStyle, MultiProgress};
use atty::Stream;
// --- Model downloader: list & download ggml models from Hugging Face ---
@@ -496,16 +496,55 @@ pub fn run_interactive_model_downloader() -> Result<()> {
qlog!("No selection. Aborting download.");
return Ok(());
}
for m in selected {
if let Err(e) = download_one_model(&client, models_dir, &m) {
elog!("Error: {:#}", e);
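// Parallel downloads with bounded concurrency, run in batches of `max_jobs` (each batch is joined before the next starts).
// Default 3; override via POLYSCRIBE_MAX_PARALLEL_DOWNLOADS. Values are clamped to 1..=6; unset or unparsable values fall back to 3.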
let max_jobs = std::env::var("POLYSCRIBE_MAX_PARALLEL_DOWNLOADS")
.ok()
.and_then(|s| s.parse::<usize>().ok())
.map(|n| n.clamp(1, 6))
.unwrap_or(3);
// Use a MultiProgress to render per-model bars concurrently when interactive.
let mp_opt = if !crate::is_quiet() && !crate::is_no_progress() && atty::is(Stream::Stderr) {
Some(MultiProgress::new())
} else {
None
};
let mut i = 0;
while i < selected.len() {
let end = std::cmp::min(i + max_jobs, selected.len());
let mut handles = Vec::new();
for m in selected[i..end].iter().cloned() {
let client2 = client.clone();
let models_dir2 = models_dir.to_path_buf();
let pb_opt = if let Some(mp) = &mp_opt {
let pb = mp.add(ProgressBar::new(m.size));
let style = ProgressStyle::with_template("Downloading {prefix} ({total_bytes}) [{bar:40.cyan/blue}] {bytes}/{total_bytes} ({percent}%)")
.unwrap()
.progress_chars("=>-");
pb.set_style(style);
pb.set_prefix(format!("{}", m.name));
Some(pb)
} else { None };
handles.push(std::thread::spawn(move || {
if let Err(e) = download_one_model(&client2, &models_dir2, &m, pb_opt) {
crate::elog!("Error: {:#}", e);
}
}));
}
for h in handles { let _ = h.join(); }
i = end;
}
// Drop MultiProgress after threads are joined; bars finish naturally.
drop(mp_opt);
Ok(())
}
/// Download a single model entry into the given models directory, verifying SHA-256 when available.
fn download_one_model(client: &Client, models_dir: &Path, entry: &ModelEntry) -> Result<()> {
fn download_one_model(client: &Client, models_dir: &Path, entry: &ModelEntry, pb: Option<indicatif::ProgressBar>) -> Result<()> {
let final_path = models_dir.join(format!("ggml-{}.bin", entry.name));
// If the model already exists, verify against online metadata
@@ -637,9 +676,11 @@ fn download_one_model(client: &Client, models_dir: &Path, entry: &ModelEntry) ->
.with_context(|| format!("Failed to create {}", tmp_path.display()))?,
);
// Set up progress bar if interactive and we know size
// Set up progress bar: use provided one if present; otherwise create if interactive and we know size
let show_progress = !crate::is_quiet() && !crate::is_no_progress() && atty::is(Stream::Stderr) && entry.size > 0;
let pb_opt = if show_progress {
let pb_opt = if let Some(p) = pb {
Some(p)
} else if show_progress {
let pb = ProgressBar::new(entry.size);
let style = ProgressStyle::with_template("Downloading {prefix} ({total_bytes}) [{bar:40.cyan/blue}] {bytes}/{total_bytes} ({percent}%)")
.unwrap()
@@ -804,18 +845,18 @@ pub fn update_local_models() -> Result<()> {
wlog!("Failed hashing {}: {}. Re-downloading.", fname, e);
}
}
download_one_model(&client, models_dir, remote)?;
download_one_model(&client, models_dir, remote, None)?;
} else if remote.size > 0 {
match std::fs::metadata(&path) {
Ok(md) => {
if qlog_size_comparison(&fname, md.len(), remote.size) {
continue;
}
download_one_model(&client, models_dir, remote)?;
download_one_model(&client, models_dir, remote, None)?;
}
Err(e) => {
wlog!("Stat failed for {}: {}. Updating...", fname, e);
download_one_model(&client, models_dir, remote)?;
download_one_model(&client, models_dir, remote, None)?;
}
}
} else {
@@ -883,7 +924,7 @@ pub fn ensure_model_available_noninteractive(model_name: &str) -> Result<std::pa
// Prefer fetching metadata to construct a proper ModelEntry
let models = fetch_all_models(&client)?;
if let Some(entry) = models.into_iter().find(|m| m.name == model_name) {
download_one_model(&client, models_dir, &entry)?;
download_one_model(&client, models_dir, &entry, None)?;
return Ok(models_dir.join(format!("ggml-{}.bin", entry.name)));
}
Err(anyhow!(