diff --git a/Cargo.lock b/Cargo.lock index 8a68318..81f6c41 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -291,6 +291,20 @@ dependencies = [ "roff", ] +[[package]] +name = "cliclack" +version = "0.3.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "57c420bdc04c123a2df04d9c5a07289195f00007af6e45ab18f55e56dc7e04b8" +dependencies = [ + "console", + "indicatif", + "once_cell", + "strsim", + "textwrap", + "zeroize", +] + [[package]] name = "cmake" version = "0.1.54" @@ -364,6 +378,19 @@ dependencies = [ "windows-sys 0.59.0", ] +[[package]] +name = "dialoguer" +version = "0.11.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "658bce805d770f407bc62102fca7c2c64ceef2fbcb2b8bd19d2765ce093980de" +dependencies = [ + "console", + "shell-words", + "tempfile", + "thiserror", + "zeroize", +] + [[package]] name = "digest" version = "0.10.7" @@ -1144,7 +1171,9 @@ dependencies = [ "clap", "clap_complete", "clap_mangen", + "cliclack", "ctrlc", + "dialoguer", "indicatif", "libc", "reqwest", @@ -1462,6 +1491,12 @@ dependencies = [ "digest", ] +[[package]] +name = "shell-words" +version = "1.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "24188a676b6ae68c3b2cb3a01be17fbf7240ce009799bb56d5b1409051e78fde" + [[package]] name = "shlex" version = "1.3.0" @@ -1480,6 +1515,12 @@ version = "1.15.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "67b1b7a3b5fe4f1376887184045fcf45c69e92af734b7aaddc05fb777b6fbd03" +[[package]] +name = "smawk" +version = "0.3.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b7c388c1b5e93756d0c740965c41e8822f866621d41acbdf6336a6a168f8840c" + [[package]] name = "socket2" version = "0.6.0" @@ -1573,6 +1614,37 @@ dependencies = [ "windows-sys 0.59.0", ] +[[package]] +name = "textwrap" +version = "0.16.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c13547615a44dc9c452a8a534638acdf07120d4b6847c8178705da06306a3057" +dependencies = [ + "smawk", + "unicode-linebreak", + "unicode-width", +] + +[[package]] +name = "thiserror" +version = "1.0.69" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b6aaf5339b578ea85b50e080feb250a3e8ae8cfcdff9a461c9ec2904bc923f52" +dependencies = [ + "thiserror-impl", +] + +[[package]] +name = "thiserror-impl" +version = "1.0.69" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4fee6c4efc90059e10f81e6d42c60a18f76588c3d74cb83a0b242a2b6c7504c1" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] + [[package]] name = "tinystr" version = "0.8.1" @@ -1756,6 +1828,12 @@ version = "1.0.18" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "5a5f39404a5da50712a4c1eecf25e90dd62b613502b7e925fd4e4d19b5c96512" +[[package]] +name = "unicode-linebreak" +version = "0.1.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3b09c83c3c29d37506a3e260c08c03743a6bb66a9cd432c6934ab501a190571f" + [[package]] name = "unicode-width" version = "0.2.1" @@ -2237,6 +2315,20 @@ name = "zeroize" version = "1.8.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "ced3678a2879b30306d323f4542626697a464a97c0a07c9aebf7ebca65cd4dde" +dependencies = [ + "zeroize_derive", +] + +[[package]] +name = "zeroize_derive" +version = "1.4.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ce36e65b0d2999d2aafac989fb249189a141aee1f53c612c1f37d72631959f69" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] [[package]] name = "zerotrie" diff --git a/Cargo.toml b/Cargo.toml index c5af94f..c596932 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -30,6 +30,8 @@ whisper-rs = { git = "https://github.com/tazz4843/whisper-rs", default-features libc = "0.2" indicatif = "0.17" ctrlc = "3.4" +dialoguer = "0.11" +cliclack = "0.3" [dev-dependencies] tempfile = "3" diff --git a/src/lib.rs b/src/lib.rs index f03c9b5..437133b 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -245,6 +245,9 @@ pub mod models; /// Progress and progress bar abstraction (TTY-aware, stderr-only) pub mod progress; +/// UI helpers for interactive prompts (cliclack-backed) +pub mod ui; + /// Transcript entry for a single segment. #[derive(Debug, serde::Serialize, Clone)] pub struct OutputEntry { @@ -515,11 +518,10 @@ where "No models available and interactive mode is disabled. Please set WHISPER_MODEL or run with --download-models." )); } - printer("Would you like to download models now? [Y/n]:"); - let mut input = String::new(); - io::stdin().read_line(&mut input).ok(); - let ans = input.trim().to_lowercase(); - if ans.is_empty() || ans == "y" || ans == "yes" { + // Use unified cliclack confirm via UI helper + let download_now = crate::ui::prompt_confirm("Download models now?", true) + .context("prompt error during confirmation")?; + if download_now { if let Err(e) = models::run_interactive_model_downloader() { elog!("Downloader failed: {:#}", e); } @@ -583,20 +585,19 @@ where // Print a blank line and the selection prompt using the provided printer to // keep output synchronized with any active progress rendering. printer(""); - printer(&format!("Select model by number [1-{}]:", candidates.len())); - let mut input = String::new(); - io::stdin() - .read_line(&mut input) + let prompt = format!("Select model [1-{}]:", candidates.len()); + // TODO(ui): migrate to cliclack::Select for model picking to standardize UI. + let sel: usize = dialoguer::Input::new() + .with_prompt(prompt) + .interact_text() .context("Failed to read selection")?; - let sel: usize = input - .trim() - .parse() - .map_err(|_| anyhow!("Invalid selection: {}", input.trim()))?; if sel == 0 || sel > candidates.len() { return Err(anyhow!("Selection out of range")); } let chosen = candidates.swap_remove(sel - 1); let _ = std::fs::write(models_dir.join(".last_model"), chosen.display().to_string()); + // Print an empty line after selection input + printer(""); Ok(chosen) } diff --git a/src/main.rs b/src/main.rs index 9381033..b585d68 100644 --- a/src/main.rs +++ b/src/main.rs @@ -144,38 +144,29 @@ fn prompt_speaker_name_for_path(path: &Path, default_name: &str, enabled: bool, // Explicitly non-interactive: never prompt return default_name.to_string(); } + let display_owned: String = path .file_name() .and_then(|s| s.to_str()) .map(|s| s.to_string()) .unwrap_or_else(|| path.to_string_lossy().to_string()); - // Synchronized prompt above any progress bars + // Render prompt above any progress bars pm.pause_for_prompt(); - pm.println_above_bars(&format!( - "Enter speaker name for {} [default: {}]:", - display_owned, default_name - )); - - let mut buf = String::new(); - let res = io::stdin().read_line(&mut buf); + let answer = { + let prompt = format!("Enter speaker name for {} [default: {}]", display_owned, default_name); + match polyscribe::ui::prompt_text(&prompt, default_name) { + Ok(ans) => ans, + Err(_) => default_name.to_string(), + } + }; pm.resume_after_prompt(); - match res { - Ok(_) => { - let raw = buf.trim(); - if raw.is_empty() { - return default_name.to_string(); - } - let sanitized = sanitize_speaker_name(raw); - if sanitized.is_empty() { - default_name.to_string() - } else { - // Defer echoing of the chosen name; caller will print a permanent line later - sanitized - } - } - Err(_) => default_name.to_string(), + let sanitized = sanitize_speaker_name(&answer); + if sanitized.is_empty() { + default_name.to_string() + } else { + sanitized } } diff --git a/src/models.rs b/src/models.rs index 6687b81..c635cf5 100644 --- a/src/models.rs +++ b/src/models.rs @@ -393,130 +393,62 @@ fn format_model_list(models: &[ModelEntry]) -> String { } fn prompt_select_models_two_stage(models: &[ModelEntry]) -> Result> { + // Non-interactive safeguard: return empty (caller will handle as cancel/skip) if crate::is_no_interaction() || !crate::stdin_is_tty() { - // Non-interactive: do not prompt, return empty selection to skip return Ok(Vec::new()); } - // 1) Choose base (tiny, small, medium, etc.) + + if models.is_empty() { + return Ok(Vec::new()); + } + + // Stage 1: pick a base family; preserve order from input list let mut bases: Vec = Vec::new(); - let mut last = String::new(); + let mut seen = std::collections::BTreeSet::new(); for m in models.iter() { - if m.base != last { - // models are sorted by base; avoid duplicates while preserving order - if !bases.last().map(|b| b == &m.base).unwrap_or(false) { - bases.push(m.base.clone()); - } - last = m.base.clone(); + if !seen.contains(&m.base) { + seen.insert(m.base.clone()); + bases.push(m.base.clone()); } } - if bases.is_empty() { - return Ok(Vec::new()); - } - // Print base selection on stderr - eprintln!("Available base model families:"); - for (i, b) in bases.iter().enumerate() { - eprintln!(" {}) {}", i + 1, b); - } - loop { - eprint!("Select base (number or name, 'q' to cancel): "); - io::stderr().flush().ok(); - let mut line = String::new(); - io::stdin() - .read_line(&mut line) - .context("Failed to read base selection")?; - let s = line.trim(); - if s.eq_ignore_ascii_case("q") - || s.eq_ignore_ascii_case("quit") - || s.eq_ignore_ascii_case("exit") - { - return Ok(Vec::new()); - } - let chosen_base = if let Ok(i) = s.parse::() { - if i >= 1 && i <= bases.len() { - Some(bases[i - 1].clone()) + let base = if bases.len() == 1 { + bases[0].clone() + } else { + crate::ui::prompt_select_one("Select model family/base:", &bases)? + }; + + // Stage 2: within base, present variants + let mut variants: Vec<&ModelEntry> = models.iter().filter(|m| m.base == base).collect(); + variants.sort_by_key(|m| (m.size, m.subtype.clone(), m.name.clone())); + + let labels: Vec = variants + .iter() + .map(|m| { + let size_h = human_size(m.size); + if let Some(sha) = &m.sha256 { + format!("{} ({}, {}, sha: {}…)", m.name, m.subtype, size_h, &sha[..std::cmp::min(8, sha.len())]) } else { - None + format!("{} ({}, {})", m.name, m.subtype, size_h) } - } else if !s.is_empty() { - // accept exact name match (case-insensitive) - bases.iter().find(|b| b.eq_ignore_ascii_case(s)).cloned() - } else { - None - }; + }) + .collect(); - if let Some(base) = chosen_base { - // 2) Choose sub-type(s) within that base - let filtered: Vec = - models.iter().filter(|m| m.base == base).cloned().collect(); - if filtered.is_empty() { - eprintln!("No models found for base '{base}'."); - continue; - } - // Reuse the formatter but only for the chosen base list - let listing = format_model_list(&filtered); - eprint!("{listing}"); + let selected_labels = crate::ui::prompt_multiselect( + "Select one or more variants to download:", + &labels, + &[], + )?; - // Build index map for filtered list - let mut index_map: Vec = Vec::with_capacity(filtered.len()); - let mut idx = 1usize; - for (pos, _m) in filtered.iter().enumerate() { - index_map.push(pos); - idx += 1; - } - // Second prompt: sub-type selection - loop { - eprint!("Selection: "); - io::stderr().flush().ok(); - let mut line2 = String::new(); - io::stdin() - .read_line(&mut line2) - .context("Failed to read selection")?; - let s2 = line2.trim().to_lowercase(); - if s2 == "q" || s2 == "quit" || s2 == "exit" { - return Ok(Vec::new()); - } - let mut selected: Vec = Vec::new(); - if s2 == "all" || s2 == "*" { - selected = (1..idx).collect(); - } else if !s2.is_empty() { - for part in s2.split([',', ' ', ';']) { - let part = part.trim(); - if part.is_empty() { - continue; - } - if let Some((a, b)) = part.split_once('-') { - if let (Ok(ia), Ok(ib)) = (a.parse::(), b.parse::()) { - if ia >= 1 && ib < idx && ia <= ib { - selected.extend(ia..=ib); - } - } - } else if let Ok(i) = part.parse::() { - if i >= 1 && i < idx { - selected.push(i); - } - } - } - } - selected.sort_unstable(); - selected.dedup(); - if selected.is_empty() { - eprintln!("No valid selection. Please try again or 'q' to cancel."); - continue; - } - let chosen: Vec = selected - .into_iter() - .map(|i| filtered[index_map[i - 1]].clone()) - .collect(); - return Ok(chosen); - } - } else { - eprintln!( - "Invalid base selection. Please enter a number from 1-{} or a base name.", - bases.len() - ); + // Map labels back to entries in stable order + let mut picked: Vec = Vec::new(); + for (i, label) in labels.iter().enumerate() { + if selected_labels.iter().any(|s| s == label) { + picked.push(variants[i].clone().clone()); } } + + Ok(picked) } fn compute_file_sha256_hex(path: &Path) -> Result { diff --git a/src/progress.rs b/src/progress.rs index f872f7d..4e50227 100644 --- a/src/progress.rs +++ b/src/progress.rs @@ -124,6 +124,8 @@ enum ProgressInner { #[derive(Debug)] struct SingleBars { + header: ProgressBar, + info: ProgressBar, current: ProgressBar, // keep MultiProgress alive for suspend/println behavior _mp: Arc, @@ -131,10 +133,14 @@ struct SingleBars { #[derive(Debug)] struct MultiBars { - // Legacy bars for compatibility (used when not using per-file init) - total: ProgressBar, + // Header row shown above bars + header: ProgressBar, + // Single info/status row shown under header and above bars + info: ProgressBar, + // Bars: current file and total current: ProgressBar, - // Optional per-file bars and aggregated total percent bar + total: ProgressBar, + // Optional per-file bars and aggregated total percent bar (unused in new UX) files: Mutex>>, // each length 100 total_pct: Mutex>, // length 100 // Metadata for aggregation @@ -206,24 +212,34 @@ impl ProgressManager { } fn with_single(mp: Arc) -> Self { + // Order: header, info row, then current file bar + let header = mp.add(ProgressBar::new(0)); + header.set_style(info_style()); + let info = mp.add(ProgressBar::new(0)); + info.set_style(info_style()); let current = mp.add(ProgressBar::new(100)); - current.set_style(spinner_style()); + current.set_style(current_style()); Self { - inner: ProgressInner::Single(Arc::new(SingleBars { current, _mp: mp })), + inner: ProgressInner::Single(Arc::new(SingleBars { header, info, current, _mp: mp })), } } fn with_multi(mp: Arc, total_inputs: u64) -> Self { - // Add current first, then total so that total stays anchored at the bottom line + // Order: header, info row, then current file bar, then total bar at the bottom + let header = mp.add(ProgressBar::new(0)); + header.set_style(info_style()); + let info = mp.add(ProgressBar::new(0)); + info.set_style(info_style()); let current = mp.add(ProgressBar::new(100)); - current.set_style(spinner_style()); + current.set_style(current_style()); let total = mp.add(ProgressBar::new(total_inputs)); total.set_style(total_style()); - total.set_message("total"); Self { inner: ProgressInner::Multi(Arc::new(MultiBars { - total, + header, + info, current, + total, files: Mutex::new(None), total_pct: Mutex::new(None), sizes: Mutex::new(None), @@ -430,15 +446,19 @@ impl ProgressManager { } } -fn spinner_style() -> ProgressStyle { - // Style for per-item determinate progress: 0-100% with a compact bar and message - ProgressStyle::with_template("{bar:24.green/green} {percent:>3}% {msg}") - .unwrap() +fn current_style() -> ProgressStyle { + // Per-item determinate progress: show 0..100 as pos/len with a simple bar + ProgressStyle::with_template("{spinner:.green} [{elapsed_precise}] {pos}/{len} {bar:40.cyan/blue} {msg}") + .expect("invalid progress template in current_style()") +} + +fn info_style() -> ProgressStyle { + ProgressStyle::with_template("{msg}").unwrap() } fn total_style() -> ProgressStyle { - // Persistent bottom bar showing total completed/total inputs - ProgressStyle::with_template("{bar:40.cyan/blue} {pos}/{len} {msg}").unwrap() + // Bottom total bar with elapsed time + ProgressStyle::with_template("Total [{bar:28=> }] {pos}/{len} [{elapsed_precise}]").unwrap() } #[derive(Debug, Clone, Copy)] @@ -515,7 +535,7 @@ impl ProgressManager { for (label_in, size_opt) in labels_and_sizes { let label: String = label_in.into(); let pb = m._mp.add(ProgressBar::new(100)); - pb.set_style(spinner_style()); + pb.set_style(current_style()); let short = truncate_label(&label, NAME_WIDTH); pb.set_message(format!("{: Result { + let res: Result = cliclack::input(prompt) + .default_input(default) + .interact(); + let value = res.map_err(|e| anyhow!("prompt error: {e}"))?; + + let trimmed = value.trim(); + Ok(if trimmed.is_empty() { + default.to_string() + } else { + trimmed.to_string() + }) +} + +/// Ask for yes/no confirmation with a default choice. +/// +/// Returns the selected boolean. Any underlying prompt error is returned as an error. +pub fn prompt_confirm(prompt: &str, default: bool) -> Result { + let res: Result = cliclack::confirm(prompt) + .initial_value(default) + .interact(); + res.map_err(|e| anyhow!("prompt error: {e}")) +} + +/// Single-select from a list of displayable items, returning the selected index. +/// +/// - `items`: non-empty slice of displayable items. +/// - Returns the index into `items`. +pub fn prompt_select_index(prompt: &str, items: &[T]) -> Result { + if items.is_empty() { + return Err(anyhow!("prompt_select_index called with empty items")); + } + let mut sel = cliclack::select(prompt); + for (i, it) in items.iter().enumerate() { + sel = sel.item(i, format!("{}", it), ""); + } + let idx: usize = sel + .interact() + .map_err(|e| anyhow!("prompt error: {e}"))?; + Ok(idx) +} + +/// Single-select from a list of clonable displayable items, returning the chosen item. +pub fn prompt_select_one(prompt: &str, items: &[T]) -> Result { + let idx = prompt_select_index(prompt, items)?; + Ok(items[idx].clone()) +} + +/// Multi-select from a list, returning the selected indices. +/// +/// - `defaults`: indices that should be pre-selected. +pub fn prompt_multiselect_indices( + prompt: &str, + items: &[T], + defaults: &[usize], +) -> Result> { + if items.is_empty() { + return Err(anyhow!("prompt_multiselect_indices called with empty items")); + } + let mut ms = cliclack::multiselect(prompt); + for (i, it) in items.iter().enumerate() { + ms = ms.item(i, format!("{}", it), ""); + } + let indices: Vec = ms + .initial_values(defaults.to_vec()) + .required(false) + .interact() + .map_err(|e| anyhow!("prompt error: {e}"))?; + Ok(indices) +} + +/// Multi-select from a list, returning the chosen items in order of appearance. +pub fn prompt_multiselect( + prompt: &str, + items: &[T], + defaults: &[usize], +) -> Result> { + let indices = prompt_multiselect_indices(prompt, items, defaults)?; + Ok(indices.into_iter().map(|i| items[i].clone()).collect()) +}