[feat] enable parallel model downloads with bounded concurrency and TTY-aware progress bars
This commit is contained in:
@@ -14,7 +14,7 @@ use reqwest::blocking::Client;
|
|||||||
use reqwest::redirect::Policy;
|
use reqwest::redirect::Policy;
|
||||||
use serde::Deserialize;
|
use serde::Deserialize;
|
||||||
use sha2::{Digest, Sha256};
|
use sha2::{Digest, Sha256};
|
||||||
use indicatif::{ProgressBar, ProgressStyle};
|
use indicatif::{ProgressBar, ProgressStyle, MultiProgress};
|
||||||
use atty::Stream;
|
use atty::Stream;
|
||||||
|
|
||||||
// --- Model downloader: list & download ggml models from Hugging Face ---
|
// --- Model downloader: list & download ggml models from Hugging Face ---
|
||||||
@@ -496,16 +496,55 @@ pub fn run_interactive_model_downloader() -> Result<()> {
|
|||||||
qlog!("No selection. Aborting download.");
|
qlog!("No selection. Aborting download.");
|
||||||
return Ok(());
|
return Ok(());
|
||||||
}
|
}
|
||||||
for m in selected {
|
|
||||||
if let Err(e) = download_one_model(&client, models_dir, &m) {
|
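// Parallel downloads in batches of at most `max_jobs` threads; each batch is joined before the next starts. Default 3; override via POLYSCRIBE_MAX_PARALLEL_DOWNLOADS (clamped to 1..=6; unparsable values fall back to 3).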
|
||||||
elog!("Error: {:#}", e);
|
let max_jobs = std::env::var("POLYSCRIBE_MAX_PARALLEL_DOWNLOADS")
|
||||||
|
.ok()
|
||||||
|
.and_then(|s| s.parse::<usize>().ok())
|
||||||
|
.map(|n| n.clamp(1, 6))
|
||||||
|
.unwrap_or(3);
|
||||||
|
|
||||||
|
// Use a MultiProgress to render per-model bars concurrently when interactive.
|
||||||
|
let mp_opt = if !crate::is_quiet() && !crate::is_no_progress() && atty::is(Stream::Stderr) {
|
||||||
|
Some(MultiProgress::new())
|
||||||
|
} else {
|
||||||
|
None
|
||||||
|
};
|
||||||
|
|
||||||
|
let mut i = 0;
|
||||||
|
while i < selected.len() {
|
||||||
|
let end = std::cmp::min(i + max_jobs, selected.len());
|
||||||
|
let mut handles = Vec::new();
|
||||||
|
for m in selected[i..end].iter().cloned() {
|
||||||
|
let client2 = client.clone();
|
||||||
|
let models_dir2 = models_dir.to_path_buf();
|
||||||
|
let pb_opt = if let Some(mp) = &mp_opt {
|
||||||
|
let pb = mp.add(ProgressBar::new(m.size));
|
||||||
|
let style = ProgressStyle::with_template("Downloading {prefix} ({total_bytes}) [{bar:40.cyan/blue}] {bytes}/{total_bytes} ({percent}%)")
|
||||||
|
.unwrap()
|
||||||
|
.progress_chars("=>-");
|
||||||
|
pb.set_style(style);
|
||||||
|
pb.set_prefix(format!("{}", m.name));
|
||||||
|
Some(pb)
|
||||||
|
} else { None };
|
||||||
|
handles.push(std::thread::spawn(move || {
|
||||||
|
if let Err(e) = download_one_model(&client2, &models_dir2, &m, pb_opt) {
|
||||||
|
crate::elog!("Error: {:#}", e);
|
||||||
|
}
|
||||||
|
}));
|
||||||
}
|
}
|
||||||
|
for h in handles { let _ = h.join(); }
|
||||||
|
i = end;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Drop MultiProgress after threads are joined; bars finish naturally.
|
||||||
|
drop(mp_opt);
|
||||||
|
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Download a single model entry into the given models directory, verifying SHA-256 when available.
|
/// Download a single model entry into the given models directory, verifying SHA-256 when available.
|
||||||
fn download_one_model(client: &Client, models_dir: &Path, entry: &ModelEntry) -> Result<()> {
|
fn download_one_model(client: &Client, models_dir: &Path, entry: &ModelEntry, pb: Option<indicatif::ProgressBar>) -> Result<()> {
|
||||||
let final_path = models_dir.join(format!("ggml-{}.bin", entry.name));
|
let final_path = models_dir.join(format!("ggml-{}.bin", entry.name));
|
||||||
|
|
||||||
// If the model already exists, verify against online metadata
|
// If the model already exists, verify against online metadata
|
||||||
@@ -637,9 +676,11 @@ fn download_one_model(client: &Client, models_dir: &Path, entry: &ModelEntry) ->
|
|||||||
.with_context(|| format!("Failed to create {}", tmp_path.display()))?,
|
.with_context(|| format!("Failed to create {}", tmp_path.display()))?,
|
||||||
);
|
);
|
||||||
|
|
||||||
// Set up progress bar if interactive and we know size
|
// Set up progress bar: use provided one if present; otherwise create if interactive and we know size
|
||||||
let show_progress = !crate::is_quiet() && !crate::is_no_progress() && atty::is(Stream::Stderr) && entry.size > 0;
|
let show_progress = !crate::is_quiet() && !crate::is_no_progress() && atty::is(Stream::Stderr) && entry.size > 0;
|
||||||
let pb_opt = if show_progress {
|
let pb_opt = if let Some(p) = pb {
|
||||||
|
Some(p)
|
||||||
|
} else if show_progress {
|
||||||
let pb = ProgressBar::new(entry.size);
|
let pb = ProgressBar::new(entry.size);
|
||||||
let style = ProgressStyle::with_template("Downloading {prefix} ({total_bytes}) [{bar:40.cyan/blue}] {bytes}/{total_bytes} ({percent}%)")
|
let style = ProgressStyle::with_template("Downloading {prefix} ({total_bytes}) [{bar:40.cyan/blue}] {bytes}/{total_bytes} ({percent}%)")
|
||||||
.unwrap()
|
.unwrap()
|
||||||
@@ -804,18 +845,18 @@ pub fn update_local_models() -> Result<()> {
|
|||||||
wlog!("Failed hashing {}: {}. Re-downloading.", fname, e);
|
wlog!("Failed hashing {}: {}. Re-downloading.", fname, e);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
download_one_model(&client, models_dir, remote)?;
|
download_one_model(&client, models_dir, remote, None)?;
|
||||||
} else if remote.size > 0 {
|
} else if remote.size > 0 {
|
||||||
match std::fs::metadata(&path) {
|
match std::fs::metadata(&path) {
|
||||||
Ok(md) => {
|
Ok(md) => {
|
||||||
if qlog_size_comparison(&fname, md.len(), remote.size) {
|
if qlog_size_comparison(&fname, md.len(), remote.size) {
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
download_one_model(&client, models_dir, remote)?;
|
download_one_model(&client, models_dir, remote, None)?;
|
||||||
}
|
}
|
||||||
Err(e) => {
|
Err(e) => {
|
||||||
wlog!("Stat failed for {}: {}. Updating...", fname, e);
|
wlog!("Stat failed for {}: {}. Updating...", fname, e);
|
||||||
download_one_model(&client, models_dir, remote)?;
|
download_one_model(&client, models_dir, remote, None)?;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
@@ -883,7 +924,7 @@ pub fn ensure_model_available_noninteractive(model_name: &str) -> Result<std::pa
|
|||||||
// Prefer fetching metadata to construct a proper ModelEntry
|
// Prefer fetching metadata to construct a proper ModelEntry
|
||||||
let models = fetch_all_models(&client)?;
|
let models = fetch_all_models(&client)?;
|
||||||
if let Some(entry) = models.into_iter().find(|m| m.name == model_name) {
|
if let Some(entry) = models.into_iter().find(|m| m.name == model_name) {
|
||||||
download_one_model(&client, models_dir, &entry)?;
|
download_one_model(&client, models_dir, &entry, None)?;
|
||||||
return Ok(models_dir.join(format!("ggml-{}.bin", entry.name)));
|
return Ok(models_dir.join(format!("ggml-{}.bin", entry.name)));
|
||||||
}
|
}
|
||||||
Err(anyhow!(
|
Err(anyhow!(
|
||||||
|
Reference in New Issue
Block a user