[feat] add auxiliary CLI commands for shell completions and man page generation; refactor logging with verbosity levels and macros; update tests and TODOs
This commit is contained in:
27
Cargo.lock
generated
27
Cargo.lock
generated
@@ -248,6 +248,15 @@ dependencies = [
|
|||||||
"strsim",
|
"strsim",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "clap_complete"
|
||||||
|
version = "4.5.56"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "67e4efcbb5da11a92e8a609233aa1e8a7d91e38de0be865f016d14700d45a7fd"
|
||||||
|
dependencies = [
|
||||||
|
"clap",
|
||||||
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "clap_derive"
|
name = "clap_derive"
|
||||||
version = "4.5.41"
|
version = "4.5.41"
|
||||||
@@ -266,6 +275,16 @@ version = "0.7.5"
|
|||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
checksum = "b94f61472cee1439c0b966b47e3aca9ae07e45d070759512cd390ea2bebc6675"
|
checksum = "b94f61472cee1439c0b966b47e3aca9ae07e45d070759512cd390ea2bebc6675"
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "clap_mangen"
|
||||||
|
version = "0.2.29"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "27b4c3c54b30f0d9adcb47f25f61fcce35c4dd8916638c6b82fbd5f4fb4179e2"
|
||||||
|
dependencies = [
|
||||||
|
"clap",
|
||||||
|
"roff",
|
||||||
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "cmake"
|
name = "cmake"
|
||||||
version = "0.1.54"
|
version = "0.1.54"
|
||||||
@@ -1057,6 +1076,8 @@ dependencies = [
|
|||||||
"anyhow",
|
"anyhow",
|
||||||
"chrono",
|
"chrono",
|
||||||
"clap",
|
"clap",
|
||||||
|
"clap_complete",
|
||||||
|
"clap_mangen",
|
||||||
"reqwest",
|
"reqwest",
|
||||||
"serde",
|
"serde",
|
||||||
"serde_json",
|
"serde_json",
|
||||||
@@ -1194,6 +1215,12 @@ dependencies = [
|
|||||||
"windows-sys 0.52.0",
|
"windows-sys 0.52.0",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "roff"
|
||||||
|
version = "0.2.2"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "88f8660c1ff60292143c98d08fc6e2f654d722db50410e3f3797d40baaf9d8f3"
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "rustc-demangle"
|
name = "rustc-demangle"
|
||||||
version = "0.1.26"
|
version = "0.1.26"
|
||||||
|
@@ -6,6 +6,8 @@ edition = "2024"
|
|||||||
[dependencies]
|
[dependencies]
|
||||||
anyhow = "1.0.98"
|
anyhow = "1.0.98"
|
||||||
clap = { version = "4.5.43", features = ["derive"] }
|
clap = { version = "4.5.43", features = ["derive"] }
|
||||||
|
clap_complete = "4.5.28"
|
||||||
|
clap_mangen = "0.2"
|
||||||
serde = { version = "1.0.219", features = ["derive"] }
|
serde = { version = "1.0.219", features = ["derive"] }
|
||||||
serde_json = "1.0.142"
|
serde_json = "1.0.142"
|
||||||
toml = "0.8"
|
toml = "0.8"
|
||||||
|
2
TODO.md
2
TODO.md
@@ -9,7 +9,7 @@
|
|||||||
- [x] for merge + separate output -> if present, treat each file as separate output and also output a merged version (--merge-and-separate)
|
- [x] for merge + separate output -> if present, treat each file as separate output and also output a merged version (--merge-and-separate)
|
||||||
- [x] set speaker-names per input-file -> prompt user for each file if flag is set (--set-speaker-names)
|
- [x] set speaker-names per input-file -> prompt user for each file if flag is set (--set-speaker-names)
|
||||||
- [x] fix cli output for model display
|
- [x] fix cli output for model display
|
||||||
- refactor into proper cli app
|
- [x] refactor into proper cli app
|
||||||
- add support for video files -> use ffmpeg to extract audio
|
- add support for video files -> use ffmpeg to extract audio
|
||||||
- detect gpus and use them
|
- detect gpus and use them
|
||||||
- add error handling
|
- add error handling
|
||||||
|
92
src/main.rs
92
src/main.rs
@@ -5,16 +5,38 @@ use std::process::Command;
|
|||||||
use std::env;
|
use std::env;
|
||||||
|
|
||||||
use anyhow::{anyhow, Context, Result};
|
use anyhow::{anyhow, Context, Result};
|
||||||
use clap::Parser;
|
use clap::{Parser, Subcommand};
|
||||||
use serde::{Deserialize, Serialize};
|
use serde::{Deserialize, Serialize};
|
||||||
use chrono::Local;
|
use chrono::Local;
|
||||||
use std::sync::atomic::{AtomicBool, Ordering};
|
use std::sync::atomic::{AtomicBool, AtomicU8, Ordering};
|
||||||
|
use clap_complete::Shell;
|
||||||
|
|
||||||
use whisper_rs::{FullParams, SamplingStrategy, WhisperContext, WhisperContextParameters};
|
use whisper_rs::{FullParams, SamplingStrategy, WhisperContext, WhisperContextParameters};
|
||||||
|
|
||||||
mod models;
|
mod models;
|
||||||
|
|
||||||
static LAST_MODEL_WRITTEN: AtomicBool = AtomicBool::new(false);
|
static LAST_MODEL_WRITTEN: AtomicBool = AtomicBool::new(false);
|
||||||
|
static VERBOSE: AtomicU8 = AtomicU8::new(0);
|
||||||
|
|
||||||
|
/// Emit an informational message on stderr when the global verbosity allows it.
///
/// The first argument is the minimum verbosity the message requires: `0` is
/// always shown, `1` needs a verbosity of at least 1, `2` at least 2, and so
/// on for higher levels. The remaining arguments are a `format!`-style format
/// string and its values; the line is prefixed with `INFO: `.
///
/// Reads the `VERBOSE` static, which must be in scope at the call site.
macro_rules! vlog {
    ($lvl:expr, $($arg:tt)*) => {{
        // Bind the level to a typed local so a literal `0` does not produce a
        // "comparison is useless due to type limits" warning below.
        let required: u8 = $lvl;
        // The expansion is a single block so the macro is usable both as a
        // statement and in expression position (e.g. as a match arm).
        if VERBOSE.load(Ordering::Relaxed) >= required {
            // `format_args!` avoids building an intermediate String.
            eprintln!("INFO: {}", format_args!($($arg)*));
        }
    }};
}
|
||||||
|
|
||||||
|
/// Print a warning line to stderr, prefixed with `WARN: `.
///
/// Takes `format!`-style arguments. The macro does not consult the verbosity
/// setting, so warnings are always emitted.
macro_rules! warnlog {
    ($($arg:tt)*) => {{
        // Wrapped in a block so the macro can be used in expression position;
        // `format_args!` avoids allocating an intermediate String.
        eprintln!("WARN: {}", format_args!($($arg)*));
    }};
}
|
||||||
|
|
||||||
|
/// Print an error line to stderr, prefixed with `ERROR: `.
///
/// Takes `format!`-style arguments. The macro does not consult the verbosity
/// setting, so errors are always emitted.
macro_rules! errorlog {
    ($($arg:tt)*) => {{
        // Wrapped in a block so the macro can be used in expression position;
        // `format_args!` avoids allocating an intermediate String.
        eprintln!("ERROR: {}", format_args!($($arg)*));
    }};
}
|
||||||
|
|
||||||
fn models_dir_path() -> PathBuf {
|
fn models_dir_path() -> PathBuf {
|
||||||
// Highest priority: explicit override
|
// Highest priority: explicit override
|
||||||
@@ -47,9 +69,30 @@ fn models_dir_path() -> PathBuf {
|
|||||||
PathBuf::from("models")
|
PathBuf::from("models")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
#[derive(Subcommand, Debug, Clone)]
|
||||||
|
enum AuxCommands {
|
||||||
|
/// Generate shell completion script to stdout
|
||||||
|
Completions {
|
||||||
|
/// Shell to generate completions for
|
||||||
|
#[arg(value_enum)]
|
||||||
|
shell: Shell,
|
||||||
|
},
|
||||||
|
/// Generate a man page to stdout
|
||||||
|
Man,
|
||||||
|
}
|
||||||
|
|
||||||
#[derive(Parser, Debug)]
|
#[derive(Parser, Debug)]
|
||||||
#[command(name = "PolyScribe", version, about = "Merge multiple JSON transcripts into one or transcribe audio using native whisper")]
|
#[command(name = "PolyScribe", bin_name = "polyscribe", version, about = "Merge JSON transcripts or transcribe audio using native whisper")]
|
||||||
struct Args {
|
struct Args {
|
||||||
|
/// Increase verbosity (-v, -vv). Logs go to stderr.
|
||||||
|
#[arg(short = 'v', long = "verbose", action = clap::ArgAction::Count, global = true)]
|
||||||
|
verbose: u8,
|
||||||
|
|
||||||
|
/// Optional auxiliary subcommands (completions, man)
|
||||||
|
#[command(subcommand)]
|
||||||
|
aux: Option<AuxCommands>,
|
||||||
|
|
||||||
/// Input .json transcript files or audio files to merge/transcribe
|
/// Input .json transcript files or audio files to merge/transcribe
|
||||||
inputs: Vec<String>,
|
inputs: Vec<String>,
|
||||||
|
|
||||||
@@ -243,7 +286,8 @@ fn find_model_file() -> Result<PathBuf> {
|
|||||||
}
|
}
|
||||||
|
|
||||||
if candidates.is_empty() {
|
if candidates.is_empty() {
|
||||||
eprintln!("No Whisper model files (*.bin) found in {}.", models_dir.display());
|
// In quiet mode we still prompt for models; suppress only non-essential logs
|
||||||
|
warnlog!("No Whisper model files (*.bin) found in {}.", models_dir.display());
|
||||||
eprint!("Would you like to download models now? [Y/n]: ");
|
eprint!("Would you like to download models now? [Y/n]: ");
|
||||||
io::stderr().flush().ok();
|
io::stderr().flush().ok();
|
||||||
let mut input = String::new();
|
let mut input = String::new();
|
||||||
@@ -251,7 +295,7 @@ fn find_model_file() -> Result<PathBuf> {
|
|||||||
let ans = input.trim().to_lowercase();
|
let ans = input.trim().to_lowercase();
|
||||||
if ans.is_empty() || ans == "y" || ans == "yes" {
|
if ans.is_empty() || ans == "y" || ans == "yes" {
|
||||||
if let Err(e) = models::run_interactive_model_downloader() {
|
if let Err(e) = models::run_interactive_model_downloader() {
|
||||||
eprintln!("Downloader failed: {:#}", e);
|
errorlog!("Downloader failed: {:#}", e);
|
||||||
}
|
}
|
||||||
// Re-scan
|
// Re-scan
|
||||||
candidates.clear();
|
candidates.clear();
|
||||||
@@ -292,7 +336,7 @@ fn find_model_file() -> Result<PathBuf> {
|
|||||||
if p.is_file() {
|
if p.is_file() {
|
||||||
// Also ensure it's one of the candidates (same dir)
|
// Also ensure it's one of the candidates (same dir)
|
||||||
if candidates.iter().any(|c| c == &p) {
|
if candidates.iter().any(|c| c == &p) {
|
||||||
eprintln!("Using previously selected model: {}", p.display());
|
vlog!(0, "Using previously selected model: {}", p.display());
|
||||||
return Ok(p);
|
return Ok(p);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -419,8 +463,34 @@ impl Drop for LastModelCleanup {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
fn main() -> Result<()> {
|
fn main() -> Result<()> {
|
||||||
let args = Args::parse();
|
// Parse CLI
|
||||||
|
let mut args = Args::parse();
|
||||||
|
|
||||||
|
// Initialize verbosity
|
||||||
|
VERBOSE.store(args.verbose, Ordering::Relaxed);
|
||||||
|
|
||||||
|
// Handle auxiliary subcommands that write to stdout and exit early
|
||||||
|
if let Some(aux) = &args.aux {
|
||||||
|
use clap::CommandFactory;
|
||||||
|
match aux {
|
||||||
|
AuxCommands::Completions { shell } => {
|
||||||
|
let mut cmd = Args::command();
|
||||||
|
let bin_name = cmd.get_name().to_string();
|
||||||
|
clap_complete::generate(*shell, &mut cmd, bin_name, &mut io::stdout());
|
||||||
|
return Ok(())
|
||||||
|
}
|
||||||
|
AuxCommands::Man => {
|
||||||
|
let cmd = Args::command();
|
||||||
|
let man = clap_mangen::Man::new(cmd);
|
||||||
|
let mut out = Vec::new();
|
||||||
|
man.render(&mut out)?;
|
||||||
|
io::stdout().write_all(&out)?;
|
||||||
|
return Ok(())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// Defer cleanup of .last_model until program exit
|
// Defer cleanup of .last_model until program exit
|
||||||
let models_dir_buf = models_dir_path();
|
let models_dir_buf = models_dir_path();
|
||||||
@@ -431,7 +501,7 @@ fn main() -> Result<()> {
|
|||||||
// If requested, run the interactive model downloader first. If no inputs were provided, exit after downloading.
|
// If requested, run the interactive model downloader first. If no inputs were provided, exit after downloading.
|
||||||
if args.download_models {
|
if args.download_models {
|
||||||
if let Err(e) = models::run_interactive_model_downloader() {
|
if let Err(e) = models::run_interactive_model_downloader() {
|
||||||
eprintln!("Model downloader failed: {:#}", e);
|
errorlog!("Model downloader failed: {:#}", e);
|
||||||
}
|
}
|
||||||
if args.inputs.is_empty() {
|
if args.inputs.is_empty() {
|
||||||
return Ok(());
|
return Ok(());
|
||||||
@@ -441,7 +511,7 @@ fn main() -> Result<()> {
|
|||||||
// If requested, update local models and exit unless inputs provided to continue
|
// If requested, update local models and exit unless inputs provided to continue
|
||||||
if args.update_models {
|
if args.update_models {
|
||||||
if let Err(e) = models::update_local_models() {
|
if let Err(e) = models::update_local_models() {
|
||||||
eprintln!("Model update failed: {:#}", e);
|
errorlog!("Model update failed: {:#}", e);
|
||||||
return Err(e);
|
return Err(e);
|
||||||
}
|
}
|
||||||
// if only updating models and no inputs, exit
|
// if only updating models and no inputs, exit
|
||||||
@@ -451,6 +521,7 @@ fn main() -> Result<()> {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Determine inputs and optional output path
|
// Determine inputs and optional output path
|
||||||
|
vlog!(1, "Parsed {} input(s)", args.inputs.len());
|
||||||
let mut inputs = args.inputs;
|
let mut inputs = args.inputs;
|
||||||
let mut output_path = args.output;
|
let mut output_path = args.output;
|
||||||
if output_path.is_none() && inputs.len() >= 2 {
|
if output_path.is_none() && inputs.len() >= 2 {
|
||||||
@@ -477,6 +548,7 @@ fn main() -> Result<()> {
|
|||||||
}
|
}
|
||||||
|
|
||||||
if args.merge_and_separate {
|
if args.merge_and_separate {
|
||||||
|
vlog!(1, "Mode: merge-and-separate; output_dir={:?}", output_path);
|
||||||
// Combined mode: write separate outputs per input and also a merged output set
|
// Combined mode: write separate outputs per input and also a merged output set
|
||||||
// Require an output directory
|
// Require an output directory
|
||||||
let out_dir = match output_path.as_ref() {
|
let out_dir = match output_path.as_ref() {
|
||||||
@@ -579,6 +651,7 @@ fn main() -> Result<()> {
|
|||||||
.with_context(|| format!("Failed to create output file: {}", m_srt.display()))?;
|
.with_context(|| format!("Failed to create output file: {}", m_srt.display()))?;
|
||||||
ms.write_all(m_srt_str.as_bytes())?;
|
ms.write_all(m_srt_str.as_bytes())?;
|
||||||
} else if args.merge {
|
} else if args.merge {
|
||||||
|
vlog!(1, "Mode: merge; output_base={:?}", output_path);
|
||||||
// MERGED MODE (previous default)
|
// MERGED MODE (previous default)
|
||||||
let mut entries: Vec<OutputEntry> = Vec::new();
|
let mut entries: Vec<OutputEntry> = Vec::new();
|
||||||
for input_path in &inputs {
|
for input_path in &inputs {
|
||||||
@@ -668,6 +741,7 @@ fn main() -> Result<()> {
|
|||||||
serde_json::to_writer_pretty(&mut handle, &out)?; writeln!(&mut handle)?;
|
serde_json::to_writer_pretty(&mut handle, &out)?; writeln!(&mut handle)?;
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
|
vlog!(1, "Mode: separate; output_dir={:?}", output_path);
|
||||||
// SEPARATE MODE (default now)
|
// SEPARATE MODE (default now)
|
||||||
// If writing to stdout with multiple inputs, not supported
|
// If writing to stdout with multiple inputs, not supported
|
||||||
if output_path.is_none() && inputs.len() > 1 {
|
if output_path.is_none() && inputs.len() > 1 {
|
||||||
|
@@ -11,6 +11,13 @@ use reqwest::blocking::Client;
|
|||||||
use reqwest::redirect::Policy;
|
use reqwest::redirect::Policy;
|
||||||
use sha2::{Digest, Sha256};
|
use sha2::{Digest, Sha256};
|
||||||
|
|
||||||
|
// Print a progress/diagnostic line to stderr using `eprintln!`-style arguments.
// NOTE: this macro prints unconditionally; it does not check any quiet or
// verbosity setting. It exists as a single choke point for this module's
// stderr output.
macro_rules! qlog {
    ($($arg:tt)*) => {{
        // Wrapped in a block so the macro can be used in expression position.
        eprintln!($($arg)*);
    }};
}
|
||||||
|
|
||||||
// --- Model downloader: list & download ggml models from Hugging Face ---
|
// --- Model downloader: list & download ggml models from Hugging Face ---
|
||||||
|
|
||||||
#[derive(Debug, Deserialize)]
|
#[derive(Debug, Deserialize)]
|
||||||
@@ -170,7 +177,7 @@ fn fill_meta_via_head(repo: &str, name: &str) -> (Option<u64>, Option<String>) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
fn hf_fetch_repo_models(client: &Client, repo: &'static str) -> Result<Vec<ModelEntry>> {
|
fn hf_fetch_repo_models(client: &Client, repo: &'static str) -> Result<Vec<ModelEntry>> {
|
||||||
eprintln!("Fetching online data: listing models from {}...", repo);
|
qlog!("Fetching online data: listing models from {}...", repo);
|
||||||
// Prefer the tree endpoint for reliable size/hash metadata, then fall back to model metadata
|
// Prefer the tree endpoint for reliable size/hash metadata, then fall back to model metadata
|
||||||
let tree_url = format!("https://huggingface.co/api/models/{}/tree/main?recursive=1", repo);
|
let tree_url = format!("https://huggingface.co/api/models/{}/tree/main?recursive=1", repo);
|
||||||
let mut out: Vec<ModelEntry> = Vec::new();
|
let mut out: Vec<ModelEntry> = Vec::new();
|
||||||
@@ -220,7 +227,7 @@ fn hf_fetch_repo_models(client: &Client, repo: &'static str) -> Result<Vec<Model
|
|||||||
|
|
||||||
// Fill missing metadata (size/hash) via HEAD request if necessary
|
// Fill missing metadata (size/hash) via HEAD request if necessary
|
||||||
if out.iter().any(|m| m.size == 0 || m.sha256.is_none()) {
|
if out.iter().any(|m| m.size == 0 || m.sha256.is_none()) {
|
||||||
eprintln!("Fetching online data: completing metadata checks for models in {}...", repo);
|
qlog!("Fetching online data: completing metadata checks for models in {}...", repo);
|
||||||
}
|
}
|
||||||
for m in out.iter_mut() {
|
for m in out.iter_mut() {
|
||||||
if m.size == 0 || m.sha256.is_none() {
|
if m.size == 0 || m.sha256.is_none() {
|
||||||
@@ -240,14 +247,14 @@ fn hf_fetch_repo_models(client: &Client, repo: &'static str) -> Result<Vec<Model
|
|||||||
}
|
}
|
||||||
|
|
||||||
fn fetch_all_models(client: &Client) -> Result<Vec<ModelEntry>> {
|
fn fetch_all_models(client: &Client) -> Result<Vec<ModelEntry>> {
|
||||||
eprintln!("Fetching online data: aggregating available models from Hugging Face...");
|
qlog!("Fetching online data: aggregating available models from Hugging Face...");
|
||||||
let mut v1 = hf_fetch_repo_models(client, "ggerganov/whisper.cpp")?; // main repo must succeed
|
let mut v1 = hf_fetch_repo_models(client, "ggerganov/whisper.cpp")?; // main repo must succeed
|
||||||
|
|
||||||
// Optional tinydiarize repo; ignore errors but log to stderr
|
// Optional tinydiarize repo; ignore errors but log to stderr
|
||||||
let mut v2: Vec<ModelEntry> = match hf_fetch_repo_models(client, "akashmjn/tinydiarize-whisper.cpp") {
|
let mut v2: Vec<ModelEntry> = match hf_fetch_repo_models(client, "akashmjn/tinydiarize-whisper.cpp") {
|
||||||
Ok(v) => v,
|
Ok(v) => v,
|
||||||
Err(e) => {
|
Err(e) => {
|
||||||
eprintln!("Warning: failed to fetch optional repo akashmjn/tinydiarize-whisper.cpp: {:#}", e);
|
qlog!("Warning: failed to fetch optional repo akashmjn/tinydiarize-whisper.cpp: {:#}", e);
|
||||||
Vec::new()
|
Vec::new()
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
@@ -451,19 +458,19 @@ pub fn run_interactive_model_downloader() -> Result<()> {
|
|||||||
.build()
|
.build()
|
||||||
.context("Failed to build HTTP client")?;
|
.context("Failed to build HTTP client")?;
|
||||||
|
|
||||||
eprintln!("Fetching online data: contacting Hugging Face to retrieve available models (this may take a moment)...");
|
qlog!("Fetching online data: contacting Hugging Face to retrieve available models (this may take a moment)...");
|
||||||
let models = fetch_all_models(&client)?;
|
let models = fetch_all_models(&client)?;
|
||||||
if models.is_empty() {
|
if models.is_empty() {
|
||||||
eprintln!("No models found on Hugging Face listing. Please try again later.");
|
qlog!("No models found on Hugging Face listing. Please try again later.");
|
||||||
return Ok(());
|
return Ok(());
|
||||||
}
|
}
|
||||||
let selected = prompt_select_models_two_stage(&models)?;
|
let selected = prompt_select_models_two_stage(&models)?;
|
||||||
if selected.is_empty() {
|
if selected.is_empty() {
|
||||||
eprintln!("No selection. Aborting download.");
|
qlog!("No selection. Aborting download.");
|
||||||
return Ok(());
|
return Ok(());
|
||||||
}
|
}
|
||||||
for m in selected {
|
for m in selected {
|
||||||
if let Err(e) = download_one_model(&client, models_dir, &m) { eprintln!("Error: {:#}", e); }
|
if let Err(e) = download_one_model(&client, models_dir, &m) { qlog!("Error: {:#}", e); }
|
||||||
}
|
}
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
@@ -477,10 +484,10 @@ pub fn download_one_model(client: &Client, models_dir: &Path, entry: &ModelEntry
|
|||||||
match compute_file_sha256_hex(&final_path) {
|
match compute_file_sha256_hex(&final_path) {
|
||||||
Ok(local_hash) => {
|
Ok(local_hash) => {
|
||||||
if local_hash.eq_ignore_ascii_case(expected) {
|
if local_hash.eq_ignore_ascii_case(expected) {
|
||||||
eprintln!("Model {} is up-to-date (hash match).", final_path.display());
|
qlog!("Model {} is up-to-date (hash match).", final_path.display());
|
||||||
return Ok(());
|
return Ok(());
|
||||||
} else {
|
} else {
|
||||||
eprintln!(
|
qlog!(
|
||||||
"Local model {} hash differs from online (local {}.., online {}..). Updating...",
|
"Local model {} hash differs from online (local {}.., online {}..). Updating...",
|
||||||
final_path.display(),
|
final_path.display(),
|
||||||
&local_hash[..std::cmp::min(8, local_hash.len())],
|
&local_hash[..std::cmp::min(8, local_hash.len())],
|
||||||
@@ -489,7 +496,7 @@ pub fn download_one_model(client: &Client, models_dir: &Path, entry: &ModelEntry
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
Err(e) => {
|
Err(e) => {
|
||||||
eprintln!(
|
qlog!(
|
||||||
"Warning: failed to hash existing {}: {}. Will re-download to ensure correctness.",
|
"Warning: failed to hash existing {}: {}. Will re-download to ensure correctness.",
|
||||||
final_path.display(), e
|
final_path.display(), e
|
||||||
);
|
);
|
||||||
@@ -499,27 +506,27 @@ pub fn download_one_model(client: &Client, models_dir: &Path, entry: &ModelEntry
|
|||||||
match std::fs::metadata(&final_path) {
|
match std::fs::metadata(&final_path) {
|
||||||
Ok(md) => {
|
Ok(md) => {
|
||||||
if md.len() == entry.size {
|
if md.len() == entry.size {
|
||||||
eprintln!(
|
qlog!(
|
||||||
"Model {} appears up-to-date by size ({}).",
|
"Model {} appears up-to-date by size ({}).",
|
||||||
final_path.display(), entry.size
|
final_path.display(), entry.size
|
||||||
);
|
);
|
||||||
return Ok(());
|
return Ok(());
|
||||||
} else {
|
} else {
|
||||||
eprintln!(
|
qlog!(
|
||||||
"Local model {} size ({}) differs from online ({}). Updating...",
|
"Local model {} size ({}) differs from online ({}). Updating...",
|
||||||
final_path.display(), md.len(), entry.size
|
final_path.display(), md.len(), entry.size
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
Err(e) => {
|
Err(e) => {
|
||||||
eprintln!(
|
qlog!(
|
||||||
"Warning: failed to stat existing {}: {}. Will re-download to ensure correctness.",
|
"Warning: failed to stat existing {}: {}. Will re-download to ensure correctness.",
|
||||||
final_path.display(), e
|
final_path.display(), e
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
eprintln!(
|
qlog!(
|
||||||
"Model {} exists but remote hash/size not available; will download to verify contents.",
|
"Model {} exists but remote hash/size not available; will download to verify contents.",
|
||||||
final_path.display()
|
final_path.display()
|
||||||
);
|
);
|
||||||
@@ -531,7 +538,7 @@ pub fn download_one_model(client: &Client, models_dir: &Path, entry: &ModelEntry
|
|||||||
if let Ok(base_dir) = env::var("POLYSCRIBE_MODELS_BASE_COPY_DIR") {
|
if let Ok(base_dir) = env::var("POLYSCRIBE_MODELS_BASE_COPY_DIR") {
|
||||||
let src_path = std::path::Path::new(&base_dir).join(format!("ggml-{}.bin", entry.name));
|
let src_path = std::path::Path::new(&base_dir).join(format!("ggml-{}.bin", entry.name));
|
||||||
if src_path.exists() {
|
if src_path.exists() {
|
||||||
eprintln!("Copying {} from {}...", entry.name, src_path.display());
|
qlog!("Copying {} from {}...", entry.name, src_path.display());
|
||||||
let tmp_path = models_dir.join(format!("ggml-{}.bin.part", entry.name));
|
let tmp_path = models_dir.join(format!("ggml-{}.bin.part", entry.name));
|
||||||
if tmp_path.exists() { let _ = std::fs::remove_file(&tmp_path); }
|
if tmp_path.exists() { let _ = std::fs::remove_file(&tmp_path); }
|
||||||
std::fs::copy(&src_path, &tmp_path)
|
std::fs::copy(&src_path, &tmp_path)
|
||||||
@@ -551,13 +558,13 @@ pub fn download_one_model(client: &Client, models_dir: &Path, entry: &ModelEntry
|
|||||||
if final_path.exists() { let _ = std::fs::remove_file(&final_path); }
|
if final_path.exists() { let _ = std::fs::remove_file(&final_path); }
|
||||||
std::fs::rename(&tmp_path, &final_path)
|
std::fs::rename(&tmp_path, &final_path)
|
||||||
.with_context(|| format!("Failed to move into place: {}", final_path.display()))?;
|
.with_context(|| format!("Failed to move into place: {}", final_path.display()))?;
|
||||||
eprintln!("Saved: {}", final_path.display());
|
qlog!("Saved: {}", final_path.display());
|
||||||
return Ok(());
|
return Ok(());
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
let url = format!("https://huggingface.co/{}/resolve/main/ggml-{}.bin", entry.repo, entry.name);
|
let url = format!("https://huggingface.co/{}/resolve/main/ggml-{}.bin", entry.repo, entry.name);
|
||||||
eprintln!("Downloading {} ({} | {})...", entry.name, human_size(entry.size), url);
|
qlog!("Downloading {} ({} | {})...", entry.name, human_size(entry.size), url);
|
||||||
let mut resp = client
|
let mut resp = client
|
||||||
.get(url)
|
.get(url)
|
||||||
.send()
|
.send()
|
||||||
@@ -593,7 +600,7 @@ pub fn download_one_model(client: &Client, models_dir: &Path, entry: &ModelEntry
|
|||||||
));
|
));
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
eprintln!("Warning: no SHA-256 available for {}. Skipping verification.", entry.name);
|
qlog!("Warning: no SHA-256 available for {}. Skipping verification.", entry.name);
|
||||||
}
|
}
|
||||||
// Replace existing file safely
|
// Replace existing file safely
|
||||||
if final_path.exists() {
|
if final_path.exists() {
|
||||||
@@ -601,7 +608,7 @@ pub fn download_one_model(client: &Client, models_dir: &Path, entry: &ModelEntry
|
|||||||
}
|
}
|
||||||
std::fs::rename(&tmp_path, &final_path)
|
std::fs::rename(&tmp_path, &final_path)
|
||||||
.with_context(|| format!("Failed to move into place: {}", final_path.display()))?;
|
.with_context(|| format!("Failed to move into place: {}", final_path.display()))?;
|
||||||
eprintln!("Saved: {}", final_path.display());
|
qlog!("Saved: {}", final_path.display());
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -653,10 +660,10 @@ pub fn update_local_models() -> Result<()> {
|
|||||||
match compute_file_sha256_hex(&path) {
|
match compute_file_sha256_hex(&path) {
|
||||||
Ok(local_hash) => {
|
Ok(local_hash) => {
|
||||||
if local_hash.eq_ignore_ascii_case(expected) {
|
if local_hash.eq_ignore_ascii_case(expected) {
|
||||||
eprintln!("{} is up-to-date.", fname);
|
qlog!("{} is up-to-date.", fname);
|
||||||
continue;
|
continue;
|
||||||
} else {
|
} else {
|
||||||
eprintln!(
|
qlog!(
|
||||||
"{} hash differs (local {}.. != remote {}..). Updating...",
|
"{} hash differs (local {}.. != remote {}..). Updating...",
|
||||||
fname,
|
fname,
|
||||||
&local_hash[..std::cmp::min(8, local_hash.len())],
|
&local_hash[..std::cmp::min(8, local_hash.len())],
|
||||||
@@ -665,30 +672,30 @@ pub fn update_local_models() -> Result<()> {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
Err(e) => {
|
Err(e) => {
|
||||||
eprintln!("Warning: failed hashing {}: {}. Re-downloading.", fname, e);
|
qlog!("Warning: failed hashing {}: {}. Re-downloading.", fname, e);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
download_one_model(&client, models_dir, remote)?;
|
download_one_model(&client, models_dir, remote)?;
|
||||||
} else if remote.size > 0 {
|
} else if remote.size > 0 {
|
||||||
match std::fs::metadata(&path) {
|
match std::fs::metadata(&path) {
|
||||||
Ok(md) if md.len() == remote.size => {
|
Ok(md) if md.len() == remote.size => {
|
||||||
eprintln!("{} appears up-to-date by size ({}).", fname, remote.size);
|
qlog!("{} appears up-to-date by size ({}).", fname, remote.size);
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
Ok(md) => {
|
Ok(md) => {
|
||||||
eprintln!("{} size {} differs from remote {}. Updating...", fname, md.len(), remote.size);
|
qlog!("{} size {} differs from remote {}. Updating...", fname, md.len(), remote.size);
|
||||||
download_one_model(&client, models_dir, remote)?;
|
download_one_model(&client, models_dir, remote)?;
|
||||||
}
|
}
|
||||||
Err(e) => {
|
Err(e) => {
|
||||||
eprintln!("Warning: stat failed for {}: {}. Updating...", fname, e);
|
qlog!("Warning: stat failed for {}: {}. Updating...", fname, e);
|
||||||
download_one_model(&client, models_dir, remote)?;
|
download_one_model(&client, models_dir, remote)?;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
eprintln!("No remote hash/size for {}. Skipping.", fname);
|
qlog!("No remote hash/size for {}. Skipping.", fname);
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
eprintln!("No remote metadata for {}. Skipping.", fname);
|
qlog!("No remote metadata for {}. Skipping.", fname);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
45
tests/integration_aux.rs
Normal file
45
tests/integration_aux.rs
Normal file
@@ -0,0 +1,45 @@
|
|||||||
|
use std::process::Command;
|
||||||
|
|
||||||
|
/// Path to the compiled `polyscribe` binary under test, taken at compile time
/// from the `CARGO_BIN_EXE_polyscribe` environment variable.
fn bin() -> &'static str {
    const POLYSCRIBE_EXE: &str = env!("CARGO_BIN_EXE_polyscribe");
    POLYSCRIBE_EXE
}
|
||||||
|
|
||||||
|
/// `completions bash` must exit successfully and print a non-empty script
/// that contains at least one of the expected bash completion markers.
#[test]
fn aux_completions_bash_outputs_script() {
    let mut cmd = Command::new(bin());
    cmd.args(["completions", "bash"]);
    let out = cmd
        .output()
        .expect("failed to run polyscribe completions bash");

    assert!(out.status.success(), "completions bash exited with failure: {:?}", out.status);

    let script = String::from_utf8(out.stdout).expect("stdout not utf-8");
    assert!(!script.trim().is_empty(), "completions bash stdout is empty");

    // Heuristic: bash completion scripts often contain 'complete -F' lines
    let has_marker = ["complete", "_polyscribe"]
        .iter()
        .any(|marker| script.contains(*marker));
    assert!(has_marker, "bash completion script did not contain expected markers");
}
|
||||||
|
|
||||||
|
/// `completions zsh` must exit successfully and print a non-empty script
/// that contains a `#compdef` directive.
#[test]
fn aux_completions_zsh_outputs_script() {
    let out = Command::new(bin())
        .args(["completions", "zsh"])
        .output()
        .expect("failed to run polyscribe completions zsh");

    assert!(out.status.success(), "completions zsh exited with failure: {:?}", out.status);

    let stdout = String::from_utf8(out.stdout).expect("stdout not utf-8");
    assert!(!stdout.trim().is_empty(), "completions zsh stdout is empty");

    // Heuristic: zsh completion scripts often start with '#compdef'.
    // A separate check for "#compdef polyscribe" would be redundant: any
    // output containing it also contains "#compdef".
    assert!(stdout.contains("#compdef"), "zsh completion script did not contain expected markers");
}
|
||||||
|
|
||||||
|
/// `man` must exit successfully and print non-empty output that carries at
/// least one of the roff markers checked below.
#[test]
fn aux_man_outputs_roff() {
    let mut cmd = Command::new(bin());
    cmd.arg("man");
    let out = cmd.output().expect("failed to run polyscribe man");

    assert!(out.status.success(), "man exited with failure: {:?}", out.status);

    let page = String::from_utf8(out.stdout).expect("stdout not utf-8");
    assert!(!page.trim().is_empty(), "man stdout is empty");

    // clap_mangen typically emits roff with .TH and/or section headers
    let looks_like_roff = page.starts_with(".TH")
        || [".TH ", ".SH NAME", ".SH SYNOPSIS"]
            .iter()
            .any(|marker| page.contains(*marker));
    assert!(looks_like_roff, "man output does not look like a roff manpage; got: {}", &page.lines().take(3).collect::<Vec<_>>().join(" | "));
}
|
Reference in New Issue
Block a user