diff --git a/src/lib.rs b/src/lib.rs index c3cf47c..3d36458 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -179,7 +179,8 @@ where #[macro_export] macro_rules! elog { ($($arg:tt)*) => {{ - eprintln!("ERROR: {}", format!($($arg)*)); + // Route errors through the progress area when available so they render inside cliclack + $crate::log_with_level!("ERROR", None, true, $($arg)*); }} } /// Internal helper macro used by other logging macros to centralize the @@ -196,7 +197,11 @@ macro_rules! log_with_level { !$crate::is_quiet() }; if should_print { - eprintln!("{}: {}", $label, format!($($arg)*)); + let line = format!("{}: {}", $label, format!($($arg)*)); + // Try to render via the active progress manager (cliclack/indicatif area). + if !$crate::progress::log_line_via_global(&line) { + eprintln!("{}", line); + } } }} } diff --git a/src/main.rs b/src/main.rs index 8c5f779..6ead4d7 100644 --- a/src/main.rs +++ b/src/main.rs @@ -291,9 +291,27 @@ fn run() -> Result<()> { } } + // Handle model management modes early and exit + if args.download_models && args.update_models { + // Avoid ambiguous behavior when both flags are set + return Err(anyhow!("Choose only one: --download-models or --update-models")); + } + if args.download_models { + // Launch interactive model downloader and exit + polyscribe::models::run_interactive_model_downloader()?; + return Ok(()); + } + if args.update_models { + // Update existing local models and exit + polyscribe::models::update_local_models()?; + return Ok(()); + } + // Prefer Config-driven progress factory let pf = ProgressFactory::from_config(&config); let pm = pf.make_manager(pf.decide_mode(args.inputs.len())); + // Route subsequent INFO/WARN/DEBUG logs through the cliclack/indicatif area + polyscribe::progress::set_global_progress_manager(&pm); // Determine formats let out_formats = if args.out_format.is_empty() { @@ -313,7 +331,8 @@ fn run() -> Result<()> { let do_merge = args.merge || args.merge_and_separate; if polyscribe::verbose_level() >= 1 && !args.quiet { - eprintln!("Mode: {}", if do_merge { "merge" } else { "separate" }); + // Render mode information inside the progress/cliclack area + polyscribe::ilog!("Mode: {}", if do_merge { "merge" } else { "separate" }); } // Collect inputs and default speakers @@ -459,12 +478,13 @@ fn run() -> Result<()> { // Emit totals and summary to stderr unless quiet if !polyscribe::is_quiet() { - eprintln!("INFO: Total: {}/{} processed", summary.len(), plan.len()); - eprintln!("Summary:"); - for line in render_summary_lines(&summary) { eprintln!("{}", line); } - for (_, _, ok, _) in &summary { if !ok { eprintln!("ERR"); } } - eprintln!(); - if had_error { eprintln!("ERROR: One or more inputs failed"); } + // Print inside the progress/cliclack area + polyscribe::ilog!("Total: {}/{} processed", summary.len(), plan.len()); + polyscribe::ilog!("Summary:"); + for line in render_summary_lines(&summary) { polyscribe::ilog!("{}", line); } + for (_, _, ok, _) in &summary { if !ok { polyscribe::elog!("ERR"); } } + polyscribe::ilog!(""); + if had_error { polyscribe::elog!("One or more inputs failed"); } } if had_error { std::process::exit(2); } diff --git a/src/models.rs b/src/models.rs index c635cf5..a355ac1 100644 --- a/src/models.rs +++ b/src/models.rs @@ -440,11 +440,26 @@ fn prompt_select_models_two_stage(models: &[ModelEntry]) -> Result = Vec::new(); - for (i, label) in labels.iter().enumerate() { - if selected_labels.iter().any(|s| s == label) { - picked.push(variants[i].clone().clone()); + if selected_labels.is_empty() { + // Confirm with the user; default to "No" to prevent accidental bulk downloads. + if crate::ui::prompt_confirm(&format!("No variants selected. Download ALL {base} variants?"), false).unwrap_or(false) { + crate::qlog!("Downloading all {base} variants as requested."); + for v in &variants { + picked.push((*v).clone()); + } + } else { + // User declined; return empty selection so caller can abort gracefully. + return Ok(Vec::new()); + } + } else { + // Map labels back to entries in stable order + for (i, label) in labels.iter().enumerate() { + if selected_labels.iter().any(|s| s == label) { + picked.push(variants[i].clone()); + } } } @@ -480,6 +495,11 @@ pub fn run_interactive_model_downloader() -> Result<()> { .build() .context("Failed to build HTTP client")?; + // Set up a temporary progress manager so INFO/WARN logs render within the UI. + let pf0 = crate::progress::ProgressFactory::from_config(&crate::Config::from_globals()); + let pm0 = pf0.make_manager(crate::progress::ProgressMode::Single); + crate::progress::set_global_progress_manager(&pm0); + ilog!( "Fetching online data: contacting Hugging Face to retrieve available models (this may take a moment)..." ); @@ -493,11 +513,162 @@ pub fn run_interactive_model_downloader() -> Result<()> { qlog!("No selection. Aborting download."); return Ok(()); } + // Set up progress bars for downloads + let pf = crate::progress::ProgressFactory::from_config(&crate::Config::from_globals()); + let pm = pf.make_manager(crate::progress::ProgressMode::Multi { total_inputs: selected.len() as u64 }); + crate::progress::set_global_progress_manager(&pm); + // Install Ctrl-C cleanup to ensure partial downloads (*.part) are removed on cancel + crate::progress::install_ctrlc_cleanup(pm.clone()); + pm.set_total(selected.len()); for m in selected { - if let Err(e) = download_one_model(&client, models_dir, &m) { + let label = format!("{} ({} total)", m.name, human_size(m.size)); + let item = pm.start_item(&label); + // Initialize message + if m.size > 0 { update_item_progress(&item, 0, m.size); } + if let Err(e) = download_one_model_with_progress(&client, models_dir, &m, &item) { + item.finish_with("done"); elog!("Error: {:#}", e); } + pm.inc_completed(); } + pm.finish_all(); + Ok(()) +} + +/// Internal helper: update a per-item progress handle with bytes progress. +fn update_item_progress(item: &crate::progress::ItemHandle, done_bytes: u64, total_bytes: u64) { + let total_mib = (total_bytes as f64) / (1024.0 * 1024.0); + let done_mib = (done_bytes as f64) / (1024.0 * 1024.0); + let pct = if total_bytes > 0 { ((done_bytes as f64) * 100.0 / (total_bytes as f64)).round() } else { 0.0 }; + item.set_message(&format!("{:.2}/{:.2} MiB ({:.0}%)", done_mib, total_mib, pct)); + if total_bytes > 0 { + item.set_progress((done_bytes as f32) / (total_bytes as f32)); + } +} + +/// Internal streaming helper used by both network and tests. +fn stream_with_progress(mut reader: R, mut writer: W, total: u64, item: &crate::progress::ItemHandle) -> Result<(u64, String)> { + let mut hasher = Sha256::new(); + let mut buf = [0u8; 1024 * 128]; + let mut done: u64 = 0; + if total > 0 { + // initialize bar to determinate length 100 + item.set_progress(0.0); + } + loop { + let n = reader.read(&mut buf).context("Network/read error")?; + if n == 0 { break; } + hasher.update(&buf[..n]); + writer.write_all(&buf[..n]).context("Write error")?; + done += n as u64; + update_item_progress(item, done, total); + } + writer.flush().ok(); + let got = to_hex_lower(&hasher.finalize()); + Ok((done, got)) +} + +/// Download a single model entry into the given models directory, verifying SHA-256 when available, with visible progress. +fn download_one_model_with_progress(client: &Client, models_dir: &Path, entry: &ModelEntry, item: &crate::progress::ItemHandle) -> Result<()> { + let final_path = models_dir.join(format!("ggml-{}.bin", entry.name)); + + // Same pre-checks as the non-progress version (up-to-date checks) + if final_path.exists() { + if let Some(expected) = &entry.sha256 { + match compute_file_sha256_hex(&final_path) { + Ok(local_hash) => { + if local_hash.eq_ignore_ascii_case(expected) { + item.set_message(&format!("{} up-to-date", entry.name)); + item.set_progress(1.0); + item.finish_with("done"); + return Ok(()); + } + } + Err(_) => { /* proceed to download */ } + } + } else if entry.size > 0 { + if let Ok(md) = std::fs::metadata(&final_path) { + if md.len() == entry.size { + item.set_message(&format!("{} up-to-date", entry.name)); + item.set_progress(1.0); + item.finish_with("done"); + return Ok(()); + } + } + } + } + + // Offline/local copy mode for tests (same behavior, but reflect via item) + if let Ok(base_dir) = env::var("POLYSCRIBE_MODELS_BASE_COPY_DIR") { + let src_path = std::path::Path::new(&base_dir).join(format!("ggml-{}.bin", entry.name)); + if src_path.exists() { + let tmp_path = models_dir.join(format!("ggml-{}.bin.part", entry.name)); + if tmp_path.exists() { let _ = std::fs::remove_file(&tmp_path); } + std::fs::copy(&src_path, &tmp_path).with_context(|| { + format!("Failed to copy from {} to {}", src_path.display(), tmp_path.display()) + })?; + if let Some(expected) = &entry.sha256 { + let got = compute_file_sha256_hex(&tmp_path)?; + if !got.eq_ignore_ascii_case(expected) { + let _ = std::fs::remove_file(&tmp_path); + return Err(anyhow!("SHA-256 mismatch for {} (copied): expected {}, got {}", entry.name, expected, got)); + } + } + if final_path.exists() { let _ = std::fs::remove_file(&final_path); } + std::fs::rename(&tmp_path, &final_path).with_context(|| format!("Failed to move into place: {}", final_path.display()))?; + item.set_progress(1.0); + item.finish_with("done"); + return Ok(()); + } + } + + let url = format!( + "https://huggingface.co/{}/resolve/main/ggml-{}.bin", + entry.repo, entry.name + ); + + let mut resp = client + .get(url) + .send() + .and_then(|r| r.error_for_status()) + .context("Failed to download model")?; + + let tmp_path = models_dir.join(format!("ggml-{}.bin.part", entry.name)); + if tmp_path.exists() { let _ = std::fs::remove_file(&tmp_path); } + let mut file = std::io::BufWriter::new( + File::create(&tmp_path).with_context(|| format!("Failed to create {}", tmp_path.display()))?, + ); + + // Determine total bytes (prefer metadata/HEAD-derived entry.size) + let total = if entry.size > 0 { entry.size } else { resp.content_length().unwrap_or(0) }; + + // Stream with progress + let (_bytes, hash_hex) = stream_with_progress(&mut resp, &mut file, total, item)?; + + // Verify + item.set_message("sha256 verifying…"); + if let Some(expected) = &entry.sha256 { + if hash_hex.to_lowercase() != expected.to_lowercase() { + let _ = std::fs::remove_file(&tmp_path); + return Err(anyhow!( + "SHA-256 mismatch for {}: expected {}, got {}", + entry.name, + expected, + hash_hex + )); + } + } else { + qlog!( + "Warning: no SHA-256 available for {}. Skipping verification.", + entry.name + ); + } + + // Replace existing file safely + if final_path.exists() { let _ = std::fs::remove_file(&final_path); } + std::fs::rename(&tmp_path, &final_path) + .with_context(|| format!("Failed to move into place: {}", final_path.display()))?; + item.finish_with("done"); Ok(()) } @@ -701,6 +872,11 @@ pub fn update_local_models() -> Result<()> { .build() .context("Failed to build HTTP client")?; + // Ensure logs go through cliclack area during update as well + let pf_up = crate::progress::ProgressFactory::from_config(&crate::Config::from_globals()); + let pm_up = pf_up.make_manager(crate::progress::ProgressMode::Single); + crate::progress::set_global_progress_manager(&pm_up); + // Obtain manifest: env override or online fetch let models: Vec = if let Ok(manifest_path) = env::var("POLYSCRIBE_MODELS_MANIFEST") { @@ -1071,4 +1247,31 @@ mod tests { std::env::remove_var("HOME"); } } + + #[test] + fn test_download_progress_bar_reaches_done() { + use std::io::Cursor; + // Prepare small fake stream of 300 KiB + let data = vec![42u8; 300 * 1024]; + let total = data.len() as u64; + let cursor = Cursor::new(data); + let mut sink: Vec = Vec::new(); + let pm = crate::progress::ProgressManager::new_for_tests_multi_hidden(1); + let item = pm.start_item("test-download"); + // Stream into sink while updating progress + let (_bytes, _hash) = super::stream_with_progress(cursor, &mut sink, total, &item).unwrap(); + // Transition to verifying and finish + item.set_message("sha256 verifying…"); + item.finish_with("done"); + // Inspect current bar state + if let Some((pos, len, finished, msg)) = pm.current_state_for_tests() { + // Ensure determinate length is 100 and we reached 100 + assert_eq!(len, 100); + assert_eq!(pos, 100); + assert!(finished); + assert!(msg.contains("done")); + } else { + panic!("progress manager did not expose current state"); + } + } } diff --git a/src/progress.rs b/src/progress.rs index 4f25d00..56a2e06 100644 --- a/src/progress.rs +++ b/src/progress.rs @@ -8,6 +8,35 @@ use std::time::Instant; use indicatif::{MultiProgress, ProgressBar, ProgressDrawTarget, ProgressStyle}; +// Global hook to route logs through the active progress manager so they render within +// the same cliclack/indicatif area instead of raw stderr. +static GLOBAL_PM: std::sync::Mutex> = std::sync::Mutex::new(None); + +/// Install a global ProgressManager used for printing log lines above bars. +pub fn set_global_progress_manager(pm: &ProgressManager) { + if let Ok(mut g) = GLOBAL_PM.lock() { + *g = Some(pm.clone()); + } +} + +/// Remove the global ProgressManager hook. +pub fn clear_global_progress_manager() { + if let Ok(mut g) = GLOBAL_PM.lock() { + *g = None; + } +} + +/// Try to print a line via the global ProgressManager, returning true if handled. +pub fn log_line_via_global(line: &str) -> bool { + if let Ok(g) = GLOBAL_PM.lock() { + if let Some(pm) = g.as_ref() { + pm.println_above_bars(line); + return true; + } + } + false +} + const NAME_WIDTH: usize = 28; #[derive(Debug, Clone)] @@ -212,6 +241,25 @@ impl ProgressManager { } } + /// Test helper: get state of the current item bar (position, length, finished, message). + pub fn current_state_for_tests(&self) -> Option<(u64, u64, bool, String)> { + match &self.inner { + ProgressInner::Single(s) => Some(( + s.current.position(), + s.current.length().unwrap_or(0), + s.current.is_finished(), + s.current.message().to_string(), + )), + ProgressInner::Multi(m) => Some(( + m.current.position(), + m.current.length().unwrap_or(0), + m.current.is_finished(), + m.current.message().to_string(), + )), + ProgressInner::Noop => None, + } + } + fn noop() -> Self { Self { inner: ProgressInner::Noop, @@ -501,7 +549,7 @@ pub fn select_mode(si: SelectionInput) -> (bool, ProgressMode) { (enabled, mode) } -/// Optional Ctrl-C cleanup: clears progress bars and removes .last_model before exiting on SIGINT. +/// Optional Ctrl-C cleanup: clears progress bars and removes temporary files before exiting on SIGINT. pub fn install_ctrlc_cleanup(pm: ProgressManager) { let state = Arc::new(Mutex::new(Some(pm.clone()))); let state_clone = state.clone(); @@ -513,8 +561,20 @@ pub fn install_ctrlc_cleanup(pm: ProgressManager) { } } // Best-effort removal of the last-model cache so it doesn't persist after Ctrl-C - let last_path = crate::models_dir_path().join(".last_model"); + let models_dir = crate::models_dir_path(); + let last_path = models_dir.join(".last_model"); let _ = std::fs::remove_file(&last_path); + // Also remove any unfinished model downloads ("*.part") + if let Ok(rd) = std::fs::read_dir(&models_dir) { + for entry in rd.flatten() { + let p = entry.path(); + if let Some(name) = p.file_name().and_then(|s| s.to_str()) { + if name.ends_with(".part") { + let _ = std::fs::remove_file(&p); + } + } + } + } // Exit with 130 to reflect SIGINT std::process::exit(130); }) {