[feat] enable parallel model downloads with bounded concurrency and TTY-aware progress bars
@@ -14,7 +14,7 @@ use reqwest::blocking::Client;
 use reqwest::redirect::Policy;
 use serde::Deserialize;
 use sha2::{Digest, Sha256};
-use indicatif::{ProgressBar, ProgressStyle};
+use indicatif::{ProgressBar, ProgressStyle, MultiProgress};
 use atty::Stream;
 
 // --- Model downloader: list & download ggml models from Hugging Face ---
@@ -496,16 +496,55 @@ pub fn run_interactive_model_downloader() -> Result<()> {
         qlog!("No selection. Aborting download.");
         return Ok(());
     }
-    for m in selected {
-        if let Err(e) = download_one_model(&client, models_dir, &m) {
-            elog!("Error: {:#}", e);
-        }
-    }
+
+    // Parallel downloads with bounded concurrency. Default 3; override via POLYSCRIBE_MAX_PARALLEL_DOWNLOADS (1..=6).
+    let max_jobs = std::env::var("POLYSCRIBE_MAX_PARALLEL_DOWNLOADS")
+        .ok()
+        .and_then(|s| s.parse::<usize>().ok())
+        .map(|n| n.clamp(1, 6))
+        .unwrap_or(3);
+
+    // Use a MultiProgress to render per-model bars concurrently when interactive.
+    let mp_opt = if !crate::is_quiet() && !crate::is_no_progress() && atty::is(Stream::Stderr) {
+        Some(MultiProgress::new())
+    } else {
+        None
+    };
+
+    let mut i = 0;
+    while i < selected.len() {
+        let end = std::cmp::min(i + max_jobs, selected.len());
+        let mut handles = Vec::new();
+        for m in selected[i..end].iter().cloned() {
+            let client2 = client.clone();
+            let models_dir2 = models_dir.to_path_buf();
+            let pb_opt = if let Some(mp) = &mp_opt {
+                let pb = mp.add(ProgressBar::new(m.size));
+                let style = ProgressStyle::with_template("Downloading {prefix} ({total_bytes}) [{bar:40.cyan/blue}] {bytes}/{total_bytes} ({percent}%)")
+                    .unwrap()
+                    .progress_chars("=>-");
+                pb.set_style(style);
+                pb.set_prefix(format!("{}", m.name));
+                Some(pb)
+            } else { None };
+            handles.push(std::thread::spawn(move || {
+                if let Err(e) = download_one_model(&client2, &models_dir2, &m, pb_opt) {
+                    crate::elog!("Error: {:#}", e);
+                }
+            }));
+        }
+        for h in handles { let _ = h.join(); }
+        i = end;
+    }
+
+    // Drop MultiProgress after threads are joined; bars finish naturally.
+    drop(mp_opt);
+
     Ok(())
 }
 
 /// Download a single model entry into the given models directory, verifying SHA-256 when available.
-fn download_one_model(client: &Client, models_dir: &Path, entry: &ModelEntry) -> Result<()> {
+fn download_one_model(client: &Client, models_dir: &Path, entry: &ModelEntry, pb: Option<indicatif::ProgressBar>) -> Result<()> {
     let final_path = models_dir.join(format!("ggml-{}.bin", entry.name));
 
     // If the model already exists, verify against online metadata
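
A note on the batching strategy introduced above, with a standalone sketch below: each round spawns at most max_jobs plain OS threads and joins them all before the next round starts, and every per-model bar is registered with the shared MultiProgress before its worker thread runs. The sketch assumes the indicatif 0.17 API; the Job struct, simulate_download, and the hard-coded sizes are illustrative placeholders, not PolyScribe code.

    use indicatif::{MultiProgress, ProgressBar, ProgressStyle};
    use std::thread;
    use std::time::Duration;

    // Illustrative stand-in for a model entry; not part of PolyScribe.
    struct Job {
        name: String,
        size: u64,
    }

    // Placeholder for the real streamed download: it only advances the bar in chunks.
    fn simulate_download(job: &Job, pb: &ProgressBar) {
        let chunk = job.size / 20;
        for _ in 0..20 {
            thread::sleep(Duration::from_millis(10));
            pb.inc(chunk);
        }
        pb.finish();
    }

    fn main() {
        let jobs: Vec<Job> = (1..=7)
            .map(|i| Job { name: format!("model-{i}"), size: 1_000_000 })
            .collect();

        // Same clamping rule as the commit: default 3, env override limited to 1..=6.
        let max_jobs = std::env::var("POLYSCRIBE_MAX_PARALLEL_DOWNLOADS")
            .ok()
            .and_then(|s| s.parse::<usize>().ok())
            .map(|n| n.clamp(1, 6))
            .unwrap_or(3);

        let mp = MultiProgress::new();
        let style = ProgressStyle::with_template("{prefix} [{bar:40.cyan/blue}] {bytes}/{total_bytes}")
            .unwrap()
            .progress_chars("=>-");

        // Process the jobs in rounds of at most max_jobs threads, joining each round
        // before starting the next one, like the while/for loop in the commit.
        for batch in jobs.chunks(max_jobs) {
            let handles: Vec<_> = batch
                .iter()
                .map(|job| {
                    let pb = mp.add(ProgressBar::new(job.size));
                    pb.set_style(style.clone());
                    pb.set_prefix(job.name.clone());
                    let job = Job { name: job.name.clone(), size: job.size };
                    thread::spawn(move || simulate_download(&job, &pb))
                })
                .collect();
            for h in handles {
                let _ = h.join();
            }
        }
    }

ProgressBar handles are cheap clones around shared state and are Send, so moving one into each worker thread is safe; the MultiProgress only needs to outlive the threads, which is why the commit drops mp_opt after the joins.
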
@@ -637,9 +676,11 @@ fn download_one_model(client: &Client, models_dir: &Path, entry: &ModelEntry) ->
             .with_context(|| format!("Failed to create {}", tmp_path.display()))?,
     );
 
-    // Set up progress bar if interactive and we know size
+    // Set up progress bar: use provided one if present; otherwise create if interactive and we know size
     let show_progress = !crate::is_quiet() && !crate::is_no_progress() && atty::is(Stream::Stderr) && entry.size > 0;
-    let pb_opt = if show_progress {
+    let pb_opt = if let Some(p) = pb {
+        Some(p)
+    } else if show_progress {
         let pb = ProgressBar::new(entry.size);
         let style = ProgressStyle::with_template("Downloading {prefix} ({total_bytes}) [{bar:40.cyan/blue}] {bytes}/{total_bytes} ({percent}%)")
             .unwrap()
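
The streaming body of download_one_model sits outside this hunk, so the way pb_opt is consumed is not visible here. As a rough sketch under that caveat (copy_with_progress and its parameters are hypothetical names, not functions from this crate), a bar handed in by the caller is typically advanced per chunk and finished when the copy completes:

    use indicatif::ProgressBar;
    use std::io::{Read, Write};

    // Hypothetical helper: copy `reader` into `writer`, advancing an optional bar.
    fn copy_with_progress<R: Read, W: Write>(
        mut reader: R,
        mut writer: W,
        pb: Option<&ProgressBar>,
    ) -> std::io::Result<u64> {
        let mut buf = [0u8; 64 * 1024];
        let mut total = 0u64;
        loop {
            let n = reader.read(&mut buf)?;
            if n == 0 {
                break;
            }
            writer.write_all(&buf[..n])?;
            total += n as u64;
            if let Some(pb) = pb {
                // Same call whether the bar came from the caller's MultiProgress
                // or was created locally for a single interactive download.
                pb.inc(n as u64);
            }
        }
        if let Some(pb) = pb {
            pb.finish();
        }
        Ok(total)
    }

    fn main() -> std::io::Result<()> {
        // Demo: copy 1 MiB of zeros into a sink with a visible bar.
        let pb = ProgressBar::new(1024 * 1024);
        let reader = std::io::repeat(0u8).take(1024 * 1024);
        copy_with_progress(reader, std::io::sink(), Some(&pb))?;
        Ok(())
    }

In the real function the reader would be the blocking reqwest response (which implements Read) and the writer the buffered temp file created above, with the call passing something like pb_opt.as_ref() so the bar stays owned by the caller's scope.
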
@@ -804,18 +845,18 @@ pub fn update_local_models() -> Result<()> {
                     wlog!("Failed hashing {}: {}. Re-downloading.", fname, e);
                 }
             }
-            download_one_model(&client, models_dir, remote)?;
+            download_one_model(&client, models_dir, remote, None)?;
         } else if remote.size > 0 {
             match std::fs::metadata(&path) {
                 Ok(md) => {
                     if qlog_size_comparison(&fname, md.len(), remote.size) {
                         continue;
                     }
-                    download_one_model(&client, models_dir, remote)?;
+                    download_one_model(&client, models_dir, remote, None)?;
                 }
                 Err(e) => {
                     wlog!("Stat failed for {}: {}. Updating...", fname, e);
-                    download_one_model(&client, models_dir, remote)?;
+                    download_one_model(&client, models_dir, remote, None)?;
                 }
             }
         } else {
@@ -883,7 +924,7 @@ pub fn ensure_model_available_noninteractive(model_name: &str) -> Result<std::pa
     // Prefer fetching metadata to construct a proper ModelEntry
     let models = fetch_all_models(&client)?;
     if let Some(entry) = models.into_iter().find(|m| m.name == model_name) {
-        download_one_model(&client, models_dir, &entry)?;
+        download_one_model(&client, models_dir, &entry, None)?;
         return Ok(models_dir.join(format!("ggml-{}.bin", entry.name)));
     }
     Err(anyhow!(