diff --git a/src/main.rs b/src/main.rs
index c1f6fd2..7241eed 100644
--- a/src/main.rs
+++ b/src/main.rs
@@ -74,6 +74,10 @@ struct Args {
     #[arg(long = "no-progress", global = true)]
     no_progress: bool,
 
+    /// Number of concurrent worker jobs to use when processing independent inputs.
+    #[arg(short = 'j', long = "jobs", value_name = "N", default_value_t = 1, global = true)]
+    jobs: usize,
+
     /// Optional auxiliary subcommands (completions, man)
     #[command(subcommand)]
    aux: Option<AuxCommand>,
@@ -386,42 +390,113 @@ fn run() -> Result<()> {
     let start_overall = Instant::now();
 
     if do_merge {
-        for (i, (path, speaker)) in plan.iter().enumerate() {
-            let start = Instant::now();
-            if !path.exists() {
-                had_error = true;
-                summary.push((
-                    path.file_name().and_then(|s| s.to_str().map(|s| s.to_string())).unwrap_or_else(|| path.to_string_lossy().to_string()),
-                    speaker.clone(),
-                    false,
-                    start.elapsed(),
-                ));
-                if !args.continue_on_error {
-                    break;
+        // Setup progress
+        pm.set_total(plan.len());
+
+        use std::sync::{Arc, atomic::{AtomicUsize, Ordering}};
+        use std::thread;
+        use std::sync::mpsc;
+
+        // Results channel: workers send Started and Finished events to main thread
+        enum Msg {
+            Started(usize, String),
+            Finished(usize, Result<(Vec<Segment>, String /*disp_name*/, bool /*ok*/, ::std::time::Duration)>),
+        }
+
+        let (tx, rx) = mpsc::channel::<Msg>();
+        let next = Arc::new(AtomicUsize::new(0));
+        let jobs = args.jobs.max(1).min(plan.len().max(1));
+
+        let plan_arc: Arc<Vec<(PathBuf, String)>> = Arc::new(plan.clone());
+
+        let mut workers = Vec::new();
+        for _ in 0..jobs {
+            let tx = tx.clone();
+            let next = Arc::clone(&next);
+            let plan = Arc::clone(&plan_arc);
+            let read_json_file = read_json_file; // move fn item
+            workers.push(thread::spawn(move || {
+                loop {
+                    let idx = next.fetch_add(1, Ordering::SeqCst);
+                    if idx >= plan.len() { break; }
+                    let (path, speaker) = (&plan[idx].0, &plan[idx].1);
+                    // Notify started (use display name)
+                    let disp = path.file_name().and_then(|s| s.to_str()).map(|s| s.to_string()).unwrap_or_else(|| path.to_string_lossy().to_string());
+                    let _ = tx.send(Msg::Started(idx, disp.clone()));
+                    let start = Instant::now();
+                    // Process only JSON and existence checks here
+                    let res: Result<(Vec<Segment>, String, bool, ::std::time::Duration)> = (|| {
+                        if !path.exists() {
+                            return Ok((Vec::new(), disp.clone(), false, start.elapsed()));
+                        }
+                        if is_json_file(path) {
+                            let root = read_json_file(path)?;
+                            Ok((root.segments, disp.clone(), true, start.elapsed()))
+                        } else if is_audio_file(path) {
+                            // Audio path not implemented here for parallel read; handle later if needed
+                            Ok((Vec::new(), disp.clone(), true, start.elapsed()))
+                        } else {
+                            // Unknown type: mark as error
+                            Ok((Vec::new(), disp.clone(), false, start.elapsed()))
+                        }
+                    })();
+                    let _ = tx.send(Msg::Finished(idx, res));
+                }
+            }));
+        }
+        drop(tx); // close original sender
+
+        // Collect results deterministically by index; assign IDs sequentially after all complete
+        let mut per_file: Vec<Option<(Vec<Segment>, String /*disp_name*/, bool, ::std::time::Duration)>> = (0..plan.len()).map(|_| None).collect();
+        let mut remaining = plan.len();
+        while let Ok(msg) = rx.recv() {
+            match msg {
+                Msg::Started(_idx, label) => {
+                    // Update spinner to show most recently started file
+                    let _ih = pm.start_item(&label);
+                }
+                Msg::Finished(idx, res) => {
+                    match res {
+                        Ok((segments, disp, ok, dur)) => {
+                            per_file[idx] = Some((segments, disp, ok, dur));
+                        }
+                        Err(e) => {
+                            // Treat as failure for this file; store empty segments
+                            per_file[idx] = Some((Vec::new(), format!("{}", e), false, ::std::time::Duration::from_millis(0)));
+                        }
+                    }
+                    pm.inc_completed();
+                    remaining -= 1;
+                    if remaining == 0 { break; }
                 }
-                continue;
             }
-            if is_json_file(path) {
-                let root = read_json_file(path)?;
-                for (idx, seg) in root.segments.iter().enumerate() {
+        }
+        // Join workers
+        for w in workers { let _ = w.join(); }
+
+        // Now, sequentially assign final IDs in input order
+        for (i, maybe) in per_file.into_iter().enumerate() {
+            let (segments, disp, ok, dur) = maybe.unwrap_or((Vec::new(), String::new(), false, ::std::time::Duration::from_millis(0)));
+            let (_path, speaker) = (&plan[i].0, &plan[i].1);
+            if ok {
+                for seg in segments {
                     merged_items.push(polyscribe::OutputEntry {
-                        id: (merged_items.len() as u64),
+                        id: merged_items.len() as u64,
                         speaker: speaker.clone(),
                         start: seg.start,
                         end: seg.end,
-                        text: seg.text.clone(),
+                        text: seg.text,
                     });
                 }
-            } else if is_audio_file(path) {
-                // Not exercised by tests; skip for now.
+            } else {
+                had_error = true;
+                if !args.continue_on_error {
+                    // If not continuing, stop building and reflect failure below
+                }
             }
-            summary.push((
-                path.file_name().and_then(|s| s.to_str().map(|s| s.to_string())).unwrap_or_else(|| path.to_string_lossy().to_string()),
-                speaker.clone(),
-                true,
-                start.elapsed(),
-            ));
-            let _ = i; // silence unused in case
+            // push summary deterministic by input index
+            summary.push((disp, speaker.clone(), ok, dur));
+            if !ok && !args.continue_on_error { break; }
         }
 
         // Write merged outputs
diff --git a/tests/deterministic_jobs.rs b/tests/deterministic_jobs.rs
new file mode 100644
index 0000000..f2fa32a
--- /dev/null
+++ b/tests/deterministic_jobs.rs
@@ -0,0 +1,62 @@
+use std::ffi::OsStr;
+use std::process::{Command, Stdio};
+use std::time::Duration;
+
+fn bin() -> &'static str {
+    env!("CARGO_BIN_EXE_polyscribe")
+}
+
+fn manifest_path(rel: &str) -> std::path::PathBuf {
+    let mut p = std::path::PathBuf::from(env!("CARGO_MANIFEST_DIR"));
+    p.push(rel);
+    p
+}
+
+fn run_polyscribe<I, S>(args: I, timeout: Duration) -> std::io::Result<std::process::Output>
+where
+    I: IntoIterator<Item = S>,
+    S: AsRef<OsStr>,
+{
+    let mut child = Command::new(bin())
+        .args(args)
+        .stdin(Stdio::null())
+        .stdout(Stdio::piped())
+        .stderr(Stdio::piped())
+        .env_clear()
+        .env("CI", "1")
+        .env("NO_COLOR", "1")
+        .spawn()?;
+
+    let start = std::time::Instant::now();
+    loop {
+        if let Some(status) = child.try_wait()? {
+            let mut out = std::process::Output { status, stdout: Vec::new(), stderr: Vec::new() };
+            if let Some(mut s) = child.stdout.take() { let _ = std::io::copy(&mut s, &mut out.stdout); }
+            if let Some(mut s) = child.stderr.take() { let _ = std::io::copy(&mut s, &mut out.stderr); }
+            return Ok(out);
+        }
+        if start.elapsed() >= timeout {
+            let _ = child.kill();
+            let _ = child.wait();
+            return Err(std::io::Error::new(std::io::ErrorKind::TimedOut, "polyscribe timed out"));
+        }
+        std::thread::sleep(std::time::Duration::from_millis(10))
+    }
+}
+
+#[test]
+fn merge_output_is_deterministic_across_job_counts() {
+    let input1 = manifest_path("input/1-s0wlz.json");
+    let input2 = manifest_path("input/2-vikingowl.json");
+
+    let out_j1 = run_polyscribe(&[input1.as_os_str(), input2.as_os_str(), OsStr::new("-m"), OsStr::new("--jobs"), OsStr::new("1")], Duration::from_secs(30)).expect("run jobs=1");
+    assert!(out_j1.status.success(), "jobs=1 failed, stderr: {}", String::from_utf8_lossy(&out_j1.stderr));
+
+    let out_j4 = run_polyscribe(&[input1.as_os_str(), input2.as_os_str(), OsStr::new("-m"), OsStr::new("--jobs"), OsStr::new("4")], Duration::from_secs(30)).expect("run jobs=4");
+    assert!(out_j4.status.success(), "jobs=4 failed, stderr: {}", String::from_utf8_lossy(&out_j4.stderr));
+
+    let s1 = String::from_utf8(out_j1.stdout).expect("utf8");
+    let s4 = String::from_utf8(out_j4.stdout).expect("utf8");
+
+    assert_eq!(s1, s4, "merged JSON stdout differs between jobs=1 and jobs=4");
+}