diff --git a/src/models.rs b/src/models.rs index 4144c71..8139f19 100644 --- a/src/models.rs +++ b/src/models.rs @@ -14,7 +14,7 @@ use reqwest::blocking::Client; use reqwest::redirect::Policy; use serde::Deserialize; use sha2::{Digest, Sha256}; -use indicatif::{ProgressBar, ProgressStyle}; +use indicatif::{ProgressBar, ProgressStyle, MultiProgress}; use atty::Stream; // --- Model downloader: list & download ggml models from Hugging Face --- @@ -496,16 +496,55 @@ pub fn run_interactive_model_downloader() -> Result<()> { qlog!("No selection. Aborting download."); return Ok(()); } - for m in selected { - if let Err(e) = download_one_model(&client, models_dir, &m) { - elog!("Error: {:#}", e); + + // Parallel downloads with bounded concurrency. Default 3; override via POLYSCRIBE_MAX_PARALLEL_DOWNLOADS (1..=6). + let max_jobs = std::env::var("POLYSCRIBE_MAX_PARALLEL_DOWNLOADS") + .ok() + .and_then(|s| s.parse::().ok()) + .map(|n| n.clamp(1, 6)) + .unwrap_or(3); + + // Use a MultiProgress to render per-model bars concurrently when interactive. + let mp_opt = if !crate::is_quiet() && !crate::is_no_progress() && atty::is(Stream::Stderr) { + Some(MultiProgress::new()) + } else { + None + }; + + let mut i = 0; + while i < selected.len() { + let end = std::cmp::min(i + max_jobs, selected.len()); + let mut handles = Vec::new(); + for m in selected[i..end].iter().cloned() { + let client2 = client.clone(); + let models_dir2 = models_dir.to_path_buf(); + let pb_opt = if let Some(mp) = &mp_opt { + let pb = mp.add(ProgressBar::new(m.size)); + let style = ProgressStyle::with_template("Downloading {prefix} ({total_bytes}) [{bar:40.cyan/blue}] {bytes}/{total_bytes} ({percent}%)") + .unwrap() + .progress_chars("=>-"); + pb.set_style(style); + pb.set_prefix(format!("{}", m.name)); + Some(pb) + } else { None }; + handles.push(std::thread::spawn(move || { + if let Err(e) = download_one_model(&client2, &models_dir2, &m, pb_opt) { + crate::elog!("Error: {:#}", e); + } + })); } + for h in handles { let _ = h.join(); } + i = end; } + + // Drop MultiProgress after threads are joined; bars finish naturally. + drop(mp_opt); + Ok(()) } /// Download a single model entry into the given models directory, verifying SHA-256 when available. -fn download_one_model(client: &Client, models_dir: &Path, entry: &ModelEntry) -> Result<()> { +fn download_one_model(client: &Client, models_dir: &Path, entry: &ModelEntry, pb: Option) -> Result<()> { let final_path = models_dir.join(format!("ggml-{}.bin", entry.name)); // If the model already exists, verify against online metadata @@ -637,9 +676,11 @@ fn download_one_model(client: &Client, models_dir: &Path, entry: &ModelEntry) -> .with_context(|| format!("Failed to create {}", tmp_path.display()))?, ); - // Set up progress bar if interactive and we know size + // Set up progress bar: use provided one if present; otherwise create if interactive and we know size let show_progress = !crate::is_quiet() && !crate::is_no_progress() && atty::is(Stream::Stderr) && entry.size > 0; - let pb_opt = if show_progress { + let pb_opt = if let Some(p) = pb { + Some(p) + } else if show_progress { let pb = ProgressBar::new(entry.size); let style = ProgressStyle::with_template("Downloading {prefix} ({total_bytes}) [{bar:40.cyan/blue}] {bytes}/{total_bytes} ({percent}%)") .unwrap() @@ -804,18 +845,18 @@ pub fn update_local_models() -> Result<()> { wlog!("Failed hashing {}: {}. Re-downloading.", fname, e); } } - download_one_model(&client, models_dir, remote)?; + download_one_model(&client, models_dir, remote, None)?; } else if remote.size > 0 { match std::fs::metadata(&path) { Ok(md) => { if qlog_size_comparison(&fname, md.len(), remote.size) { continue; } - download_one_model(&client, models_dir, remote)?; + download_one_model(&client, models_dir, remote, None)?; } Err(e) => { wlog!("Stat failed for {}: {}. Updating...", fname, e); - download_one_model(&client, models_dir, remote)?; + download_one_model(&client, models_dir, remote, None)?; } } } else { @@ -883,7 +924,7 @@ pub fn ensure_model_available_noninteractive(model_name: &str) -> Result