[refactor] improve code readability, streamline initialization, update error handling, and format multi-line statements for consistency

This commit is contained in:
2025-08-14 11:06:37 +02:00
parent 0573369b81
commit 0a249f2197
11 changed files with 289 additions and 190 deletions

View File

@@ -11,7 +11,11 @@ pub enum GpuBackend {
}
#[derive(Debug, Parser)]
#[command(name = "polyscribe", version, about = "PolyScribe local-first transcription and plugins")]
#[command(
name = "polyscribe",
version,
about = "PolyScribe local-first transcription and plugins"
)]
pub struct Cli {
/// Increase verbosity (-v, -vv)
#[arg(short, long, action = clap::ArgAction::Count)]
@@ -120,4 +124,4 @@ pub enum PluginsCmd {
#[arg(long)]
json: Option<String>,
},
}
}

View File

@@ -1,10 +1,10 @@
mod cli;
use anyhow::{anyhow, Context, Result};
use clap::{Parser, CommandFactory};
use anyhow::{Context, Result, anyhow};
use clap::{CommandFactory, Parser};
use cli::{Cli, Commands, GpuBackend, ModelsCmd, PluginsCmd};
use polyscribe_core::{config::ConfigService, ui::progress::ProgressReporter};
use polyscribe_core::models; // Added: call into core models
use polyscribe_core::{config::ConfigService, ui::progress::ProgressReporter};
use polyscribe_host::PluginManager;
use tokio::io::AsyncWriteExt;
use tracing_subscriber::EnvFilter;
@@ -81,26 +81,25 @@ async fn main() -> Result<()> {
match cmd {
ModelsCmd::Update => {
polyscribe_core::ui::info("verifying/updating local models");
tokio::task::spawn_blocking(|| models::update_local_models())
tokio::task::spawn_blocking(models::update_local_models)
.await
.map_err(|e| anyhow!("blocking task join error: {e}"))?
.context("updating models")?;
}
ModelsCmd::Download => {
polyscribe_core::ui::info("interactive model selection and download");
tokio::task::spawn_blocking(|| models::run_interactive_model_downloader())
tokio::task::spawn_blocking(models::run_interactive_model_downloader)
.await
.map_err(|e| anyhow!("blocking task join error: {e}"))?
.context("running downloader")?;
polyscribe_core::ui::success("Model download complete.");
}
}
Ok(())
}
Commands::Plugins { cmd } => {
let pm = PluginManager::default();
let pm = PluginManager;
match cmd {
PluginsCmd::List => {
@@ -111,12 +110,18 @@ async fn main() -> Result<()> {
Ok(())
}
PluginsCmd::Info { name } => {
let info = pm.info(&name).with_context(|| format!("getting info for {}", name))?;
let info = pm
.info(&name)
.with_context(|| format!("getting info for {}", name))?;
let s = serde_json::to_string_pretty(&info)?;
polyscribe_core::ui::info(s);
Ok(())
}
PluginsCmd::Run { name, command, json } => {
PluginsCmd::Run {
name,
command,
json,
} => {
let payload = json.unwrap_or_else(|| "{}".to_string());
let mut child = pm
.spawn(&name, &command)
@@ -131,7 +136,10 @@ async fn main() -> Result<()> {
let status = pm.forward_stdio(&mut child).await?;
if !status.success() {
polyscribe_core::ui::error(format!("plugin returned non-zero exit code: {}", status));
polyscribe_core::ui::error(format!(
"plugin returned non-zero exit code: {}",
status
));
return Err(anyhow!("plugin failed"));
}
Ok(())

View File

@@ -1,10 +1,12 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2025 <COPYRIGHT HOLDER>. All rights reserved.
use std::process::Command;
use assert_cmd::cargo::cargo_bin;
use std::process::Command;
fn bin() -> std::path::PathBuf { cargo_bin("polyscribe") }
fn bin() -> std::path::PathBuf {
cargo_bin("polyscribe")
}
#[test]
fn aux_completions_bash_outputs_script() {

View File

@@ -3,8 +3,9 @@
//! Transcription backend selection and implementations (CPU/GPU) used by PolyScribe.
use crate::OutputEntry;
use crate::prelude::*;
use crate::{decode_audio_to_pcm_f32_ffmpeg, find_model_file};
use anyhow::{Context, Result, anyhow};
use anyhow::{Context, anyhow};
use std::env;
use std::path::Path;
@@ -95,7 +96,9 @@ pub struct VulkanBackend;
macro_rules! impl_whisper_backend {
($ty:ty, $kind:expr) => {
impl TranscribeBackend for $ty {
fn kind(&self) -> BackendKind { $kind }
fn kind(&self) -> BackendKind {
$kind
}
fn transcribe(
&self,
audio_path: &Path,
@@ -128,7 +131,7 @@ impl TranscribeBackend for VulkanBackend {
) -> Result<Vec<OutputEntry>> {
Err(anyhow!(
"Vulkan backend not yet wired to whisper.cpp FFI. Build with --features gpu-vulkan and ensure Vulkan SDK is installed. How to fix: install Vulkan loader (libvulkan), set VULKAN_SDK, and run cargo build --features gpu-vulkan."
))
).into())
}
}
@@ -164,11 +167,11 @@ pub fn select_backend(requested: BackendKind, verbose: bool) -> Result<Selection
let instantiate_backend = |k: BackendKind| -> Box<dyn TranscribeBackend + Send + Sync> {
match k {
BackendKind::Cpu => Box::new(CpuBackend::default()),
BackendKind::Cuda => Box::new(CudaBackend::default()),
BackendKind::Hip => Box::new(HipBackend::default()),
BackendKind::Vulkan => Box::new(VulkanBackend::default()),
BackendKind::Auto => Box::new(CpuBackend::default()), // placeholder for Auto
BackendKind::Cpu => Box::new(CpuBackend),
BackendKind::Cuda => Box::new(CudaBackend),
BackendKind::Hip => Box::new(HipBackend),
BackendKind::Vulkan => Box::new(VulkanBackend),
BackendKind::Auto => Box::new(CpuBackend), // placeholder for Auto
}
};
@@ -190,7 +193,7 @@ pub fn select_backend(requested: BackendKind, verbose: bool) -> Result<Selection
} else {
return Err(anyhow!(
"Requested CUDA backend but CUDA libraries/devices not detected. How to fix: install NVIDIA driver + CUDA toolkit, ensure libcudart/libcublas are in loader path, and build with --features gpu-cuda."
));
).into());
}
}
BackendKind::Hip => {
@@ -199,7 +202,7 @@ pub fn select_backend(requested: BackendKind, verbose: bool) -> Result<Selection
} else {
return Err(anyhow!(
"Requested ROCm/HIP backend but libraries/devices not detected. How to fix: install ROCm hipBLAS/rocBLAS, ensure libs are in loader path, and build with --features gpu-hip."
));
).into());
}
}
BackendKind::Vulkan => {
@@ -208,7 +211,7 @@ pub fn select_backend(requested: BackendKind, verbose: bool) -> Result<Selection
} else {
return Err(anyhow!(
"Requested Vulkan backend but libvulkan not detected. How to fix: install Vulkan loader/SDK and build with --features gpu-vulkan."
));
).into());
}
}
BackendKind::Cpu => BackendKind::Cpu,
@@ -235,7 +238,9 @@ pub(crate) fn transcribe_with_whisper_rs(
progress: Option<&(dyn Fn(i32) + Send + Sync)>,
) -> Result<Vec<OutputEntry>> {
let report = |p: i32| {
if let Some(cb) = progress { cb(p); }
if let Some(cb) = progress {
cb(p);
}
};
report(0);
@@ -248,14 +253,15 @@ pub(crate) fn transcribe_with_whisper_rs(
.and_then(|s| s.to_str())
.map(|s| s.contains(".en.") || s.ends_with(".en.bin"))
.unwrap_or(false);
if let Some(lang) = language {
if english_only_model && lang != "en" {
return Err(anyhow!(
"Selected model is English-only ({}), but a non-English language hint '{}' was provided. Please use a multilingual model or set WHISPER_MODEL.",
model_path.display(),
lang
));
}
if let Some(lang) = language
&& english_only_model
&& lang != "en"
{
return Err(anyhow!(
"Selected model is English-only ({}), but a non-English language hint '{}' was provided. Please use a multilingual model or set WHISPER_MODEL.",
model_path.display(),
lang
).into());
}
let model_path_str = model_path
.to_str()

View File

@@ -9,7 +9,7 @@ const ENV_PREFIX: &str = "POLYSCRIBE";
///
/// Contains paths to models and plugins directories that can be customized
/// through configuration files or environment variables.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[derive(Debug, Clone, Serialize, Deserialize, Default)]
pub struct Config {
/// Directory path where ML models are stored
pub models_dir: Option<PathBuf>,
@@ -17,14 +17,7 @@ pub struct Config {
pub plugins_dir: Option<PathBuf>,
}
impl Default for Config {
fn default() -> Self {
Self {
models_dir: None,
plugins_dir: None,
}
}
}
// Default is derived
/// Service for managing Polyscribe configuration
///
@@ -36,7 +29,7 @@ impl ConfigService {
/// Loads configuration from disk or returns default values if not found
///
/// This function attempts to read the configuration file from disk. If the file
/// doesn't exist or can't be parsed, it falls back to default values.
/// doesn't exist or can't be parsed, it falls back to default values.
/// Environment variable overrides are then applied to the configuration.
pub fn load_or_default() -> Result<Config> {
let mut cfg = Self::read_disk().unwrap_or_default();

View File

@@ -1,7 +1,7 @@
use thiserror::Error;
#[derive(Debug, Error)]
/// Error types for the polyscribe-core crate.
#[derive(Debug, Error)]
///
/// This enum represents various error conditions that can occur during
/// operations in this crate, including I/O errors, serialization/deserialization
@@ -27,6 +27,10 @@ pub enum Error {
/// Represents an error that occurred during environment variable access
EnvVar(#[from] std::env::VarError),
#[error("http error: {0}")]
/// Represents an HTTP client error from reqwest
Http(#[from] reqwest::Error),
#[error("other: {0}")]
/// Represents a general error condition with a custom message
Other(String),

View File

@@ -12,7 +12,8 @@
use std::sync::atomic::{AtomicBool, AtomicU8, Ordering};
use anyhow::{anyhow, Context, Result};
use crate::prelude::*;
use anyhow::{Context, anyhow};
use chrono::Local;
use std::env;
use std::path::{Path, PathBuf};
@@ -193,13 +194,13 @@ macro_rules! qlog {
}
pub mod backend;
pub mod models;
/// Configuration handling for PolyScribe
pub mod config;
pub mod models;
// Use the file-backed ui.rs module, which also declares its own `progress` submodule.
pub mod ui;
/// Error definitions for the PolyScribe library
pub mod error;
pub mod ui;
pub use error::Error;
pub mod prelude;
@@ -266,19 +267,19 @@ pub fn models_dir_path() -> PathBuf {
if cfg!(debug_assertions) {
return PathBuf::from("models");
}
if let Ok(xdg) = env::var("XDG_DATA_HOME") {
if !xdg.is_empty() {
return PathBuf::from(xdg).join("polyscribe").join("models");
}
if let Ok(xdg) = env::var("XDG_DATA_HOME")
&& !xdg.is_empty()
{
return PathBuf::from(xdg).join("polyscribe").join("models");
}
if let Ok(home) = env::var("HOME") {
if !home.is_empty() {
return PathBuf::from(home)
.join(".local")
.join("share")
.join("polyscribe")
.join("models");
}
if let Ok(home) = env::var("HOME")
&& !home.is_empty()
{
return PathBuf::from(home)
.join(".local")
.join("share")
.join("polyscribe")
.join("models");
}
PathBuf::from("models")
}
@@ -364,13 +365,15 @@ pub fn find_model_file() -> Result<PathBuf> {
return Err(anyhow!(
"WHISPER_MODEL points to a non-existing path: {}",
p.display()
));
)
.into());
}
if !p.is_file() {
return Err(anyhow!(
"WHISPER_MODEL must point to a file, but is not: {}",
p.display()
));
)
.into());
}
return Ok(p);
}
@@ -381,17 +384,21 @@ pub fn find_model_file() -> Result<PathBuf> {
return Err(anyhow!(
"Models path exists but is not a directory: {}",
models_dir.display()
));
)
.into());
}
std::fs::create_dir_all(&models_dir).with_context(|| {
format!("Failed to ensure models dir exists: {}", models_dir.display())
format!(
"Failed to ensure models dir exists: {}",
models_dir.display()
)
})?;
// 3) Gather candidate .bin files (regular files only), prefer largest
let mut candidates = Vec::new();
for entry in std::fs::read_dir(&models_dir).with_context(|| {
format!("Failed to read models dir: {}", models_dir.display())
})? {
for entry in std::fs::read_dir(&models_dir)
.with_context(|| format!("Failed to read models dir: {}", models_dir.display()))?
{
let entry = entry?;
let path = entry.path();
@@ -423,7 +430,8 @@ pub fn find_model_file() -> Result<PathBuf> {
"No Whisper model files (*.bin) found in {}. \
Please download a model or set WHISPER_MODEL.",
models_dir.display()
));
)
.into());
}
candidates.sort_by_key(|(size, _)| *size);
@@ -465,7 +473,8 @@ pub fn decode_audio_to_pcm_f32_ffmpeg(audio_path: &Path) -> Result<Vec<f32>> {
return Err(anyhow!(
"ffmpeg exited with non-zero status when decoding {}",
in_path
));
)
.into());
}
let raw = std::fs::read(&tmp_raw)
@@ -476,10 +485,7 @@ pub fn decode_audio_to_pcm_f32_ffmpeg(audio_path: &Path) -> Result<Vec<f32>> {
// Interpret raw bytes as f32 little-endian
if raw.len() % 4 != 0 {
return Err(anyhow!(
"Decoded PCM file length not multiple of 4: {}",
raw.len()
));
return Err(anyhow!("Decoded PCM file length not multiple of 4: {}", raw.len()).into());
}
let mut samples = Vec::with_capacity(raw.len() / 4);
for chunk in raw.chunks_exact(4) {

View File

@@ -4,7 +4,8 @@
//! data for verification. Falls back to scraping the repository tree page
//! if the JSON API is unavailable or incomplete. No built-in manifest.
use anyhow::{anyhow, Context, Result};
use crate::prelude::*;
use anyhow::{Context, anyhow};
use chrono::{DateTime, Utc};
use hex::ToHex;
use reqwest::blocking::Client;
@@ -34,7 +35,6 @@ fn format_size_gib(bytes: u64) -> String {
format!("{gib:.2} GiB")
}
// Short date formatter (RFC -> yyyy-mm-dd)
fn short_date(s: &str) -> String {
DateTime::parse_from_rfc3339(s)
@@ -45,7 +45,7 @@ fn short_date(s: &str) -> String {
// Free disk space using libc::statvfs (already in Cargo)
fn free_space_bytes_for_path(path: &Path) -> Result<u64> {
use libc::{statvfs, statvfs as statvfs_t};
use libc::statvfs;
use std::ffi::CString;
// use parent dir or current dir if none
@@ -58,9 +58,9 @@ fn free_space_bytes_for_path(path: &Path) -> Result<u64> {
let cpath = CString::new(dir.as_os_str().to_string_lossy().as_bytes())
.map_err(|_| anyhow!("invalid path for statvfs"))?;
unsafe {
let mut s: statvfs_t = std::mem::zeroed();
let mut s: libc::statvfs = std::mem::zeroed();
if statvfs(cpath.as_ptr(), &mut s) != 0 {
return Err(anyhow!("statvfs failed for {}", dir.display()));
return Err(anyhow!("statvfs failed for {}", dir.display()).into());
}
Ok((s.f_bsize as u64) * (s.f_bavail as u64))
}
@@ -78,9 +78,10 @@ fn mirror_label(url: &str) -> &'static str {
}
}
// Perform a HEAD to get size/etag/last-modified and fill what we can
fn head_entry(client: &Client, url: &str) -> Result<(Option<u64>, Option<String>, Option<String>, bool)> {
type HeadMeta = (Option<u64>, Option<String>, Option<String>, bool);
fn head_entry(client: &Client, url: &str) -> Result<HeadMeta> {
let resp = client.head(url).send()?.error_for_status()?;
let len = resp
.headers()
@@ -189,9 +190,7 @@ fn parse_base_variant(display_name: &str) -> (String, String) {
/// Build a manifest by calling the Hugging Face API for a repo.
/// Prefers the plain API URL, then retries with `?expand=files` if needed.
fn hf_repo_manifest_api(repo: &str) -> Result<Vec<ModelEntry>> {
let client = Client::builder()
.user_agent("polyscribe/0.1")
.build()?;
let client = Client::builder().user_agent("polyscribe/0.1").build()?;
// 1) Try the plain API you specified
let base = format!("https://huggingface.co/api/models/{}", repo);
@@ -208,14 +207,14 @@ fn hf_repo_manifest_api(repo: &str) -> Result<Vec<ModelEntry>> {
let url = format!("{base}?expand=files");
let resp2 = client.get(&url).send()?;
if !resp2.status().is_success() {
return Err(anyhow!("HF API {} for {}", resp2.status(), url));
return Err(anyhow!("HF API {} for {}", resp2.status(), url).into());
}
let info: HfModelInfo = resp2.json()?;
entries = hf_info_to_entries(repo, info)?;
}
if entries.is_empty() {
return Err(anyhow!("HF API returned no usable .bin files"));
return Err(anyhow!("HF API returned no usable .bin files").into());
}
Ok(entries)
}
@@ -274,14 +273,12 @@ fn hf_info_to_entries(repo: &str, info: HfModelInfo) -> Result<Vec<ModelEntry>>
/// Scrape the repository tree page when the API doesn't return a usable list.
/// Note: sizes and hashes are generally unavailable in this path.
fn scrape_tree_manifest(repo: &str) -> Result<Vec<ModelEntry>> {
let client = Client::builder()
.user_agent("polyscribe/0.1")
.build()?;
let client = Client::builder().user_agent("polyscribe/0.1").build()?;
let url = format!("https://huggingface.co/{}/tree/main?recursive=1", repo);
let resp = client.get(&url).send()?;
if !resp.status().is_success() {
return Err(anyhow!("tree page HTTP {} for {}", resp.status(), url));
return Err(anyhow!("tree page HTTP {} for {}", resp.status(), url).into());
}
let html = resp.text()?;
@@ -344,7 +341,7 @@ fn scrape_tree_manifest(repo: &str) -> Result<Vec<ModelEntry>> {
}
if out.is_empty() {
return Err(anyhow!("tree scraper found no .bin files"));
return Err(anyhow!("tree scraper found no .bin files").into());
}
Ok(out)
}
@@ -401,50 +398,51 @@ fn enrich_entry_via_head(entry: &mut ModelEntry) -> Result<()> {
let mut filled_lm = false;
// Content-Length
if entry.size.is_none() {
if let Some(sz) = resp
if entry.size.is_none()
&& let Some(sz) = resp
.headers()
.get(CONTENT_LENGTH)
.and_then(|v| v.to_str().ok())
.and_then(|s| s.parse::<u64>().ok())
{
entry.size = Some(sz);
filled_size = true;
}
{
entry.size = Some(sz);
filled_size = true;
}
// SHA256 from headers if available
if entry.sha256.is_none() {
if let Some(v) = resp.headers().get("x-linked-etag").and_then(|v| v.to_str().ok()) {
if let Some(hex) = parse_sha_from_header_value(v) {
let _ = resp
.headers()
.get("x-linked-etag")
.and_then(|v| v.to_str().ok())
.and_then(parse_sha_from_header_value)
.map(|hex| {
entry.sha256 = Some(hex);
filled_sha = true;
}
}
});
if !filled_sha {
if let Some(v) = resp
let _ = resp
.headers()
.get(ETAG)
.and_then(|v| v.to_str().ok())
{
if let Some(hex) = parse_sha_from_header_value(v) {
.and_then(parse_sha_from_header_value)
.map(|hex| {
entry.sha256 = Some(hex);
filled_sha = true;
}
}
});
}
}
// Last-Modified
if entry.last_modified.is_none() {
if let Some(v) = resp
let _ = resp
.headers()
.get(LAST_MODIFIED)
.and_then(|v| v.to_str().ok())
{
entry.last_modified = Some(v.to_string());
filled_lm = true;
}
.map(|v| {
entry.last_modified = Some(v.to_string());
filled_lm = true;
});
}
let elapsed_ms = started.elapsed().as_millis();
@@ -453,9 +451,27 @@ fn enrich_entry_via_head(entry: &mut ModelEntry) -> Result<()> {
"HEAD ok in {} ms for {} (size: {}, sha256: {}, last-modified: {})",
elapsed_ms,
entry.file,
if filled_size { "new" } else { if entry.size.is_some() { "kept" } else { "missing" } },
if filled_sha { "new" } else { if entry.sha256.is_some() { "kept" } else { "missing" } },
if filled_lm { "new" } else { if entry.last_modified.is_some() { "kept" } else { "missing" } },
if filled_size {
"new"
} else if entry.size.is_some() {
"kept"
} else {
"missing"
},
if filled_sha {
"new"
} else if entry.sha256.is_some() {
"kept"
} else {
"missing"
},
if filled_lm {
"new"
} else if entry.last_modified.is_some() {
"kept"
} else {
"missing"
},
);
Ok(())
@@ -511,7 +527,7 @@ fn current_manifest() -> Result<Vec<ModelEntry>> {
);
if list.is_empty() {
return Err(anyhow!("no usable .bin files discovered"));
return Err(anyhow!("no usable .bin files discovered").into());
}
Ok(list)
}
@@ -535,7 +551,7 @@ pub fn pick_best_local_model(dir: &Path) -> Option<PathBuf> {
/// Returns the directory where models should be stored based on platform conventions.
fn resolve_models_dir() -> Result<PathBuf> {
let dirs = directories::ProjectDirs::from("org", "polyscribe", "polyscribe")
let dirs = directories::ProjectDirs::from("dev", "polyscribe", "polyscribe")
.ok_or_else(|| anyhow!("could not determine platform directories"))?;
let data_dir = dirs.data_dir().join("models");
Ok(data_dir)
@@ -552,8 +568,7 @@ fn resolve_models_dir() -> Result<PathBuf> {
/// # Returns
/// * `Result<PathBuf>` - Path to the downloaded model file on success
pub fn ensure_model_available_noninteractive(name: &str) -> Result<PathBuf> {
let entry = find_manifest_entry(name)?
.ok_or_else(|| anyhow!("unknown model: {name}"))?;
let entry = find_manifest_entry(name)?.ok_or_else(|| anyhow!("unknown model: {name}"))?;
// Resolve destination file path; ensure XDG path (or your existing logic)
let dir = resolve_models_dir()?; // implement or reuse your existing directory resolver
@@ -655,8 +670,8 @@ fn download_with_progress(dest_path: &Path, entry: &ModelEntry) -> Result<()> {
crate::ui::info(format!("Resolving source: {} ({})", mirror_label(url), url));
// HEAD for size/etag/ranges
let (mut total_len, remote_etag, _remote_last_mod, ranges_ok) = head_entry(&client, url)
.context("probing remote file")?;
let (mut total_len, remote_etag, _remote_last_mod, ranges_ok) =
head_entry(&client, url).context("probing remote file")?;
if total_len.is_none() {
total_len = entry.size;
@@ -670,15 +685,14 @@ fn download_with_progress(dest_path: &Path, entry: &ModelEntry) -> Result<()> {
"insufficient disk space: need {}, have {}",
format_size_mb(Some(need)),
format_size_gib(free)
));
)
.into());
}
}
if dest_path.exists() {
if file_matches(dest_path, total_len, entry.sha256.as_deref())? {
crate::ui::info(format!("Already up to date: {}", dest_path.display()));
return Ok(());
}
if dest_path.exists() && file_matches(dest_path, total_len, entry.sha256.as_deref())? {
crate::ui::info(format!("Already up to date: {}", dest_path.display()));
return Ok(());
}
let part_path = dest_path.with_extension("part");
@@ -691,7 +705,6 @@ fn download_with_progress(dest_path: &Path, entry: &ModelEntry) -> Result<()> {
let mut part_file = OpenOptions::new()
.create(true)
.write(true)
.read(true)
.append(true)
.open(&part_path)
@@ -719,7 +732,7 @@ fn download_with_progress(dest_path: &Path, entry: &ModelEntry) -> Result<()> {
// Defensive: if server returns 304 but we don't have a valid cached copy, retry without conditionals.
if resp.status().as_u16() == 304 && resume_from == 0 {
// Fresh download must not be conditional; redo as plain GET
let mut req2 = client.get(url);
let req2 = client.get(url);
resp = req2.send()?.error_for_status()?;
}
@@ -729,10 +742,12 @@ fn download_with_progress(dest_path: &Path, entry: &ModelEntry) -> Result<()> {
// Server did not honor range → start over
drop(part_file);
fs::remove_file(&part_path).ok();
resume_from = 0;
// Reset local accounting; we also reinitialize the progress bar below
// and reopen the part file. No need to re-read this variable afterwards.
let _ = 0; // avoid unused-assignment lint for resume_from
// Plain GET without conditional headers
let mut req2 = client.get(url);
let req2 = client.get(url);
resp = req2.send()?.error_for_status()?;
bar.stop("restarting");
bar = crate::ui::BytesProgress::start(pb_total, "Downloading", 0);
@@ -740,7 +755,6 @@ fn download_with_progress(dest_path: &Path, entry: &ModelEntry) -> Result<()> {
// Reopen the part file since we dropped it
part_file = OpenOptions::new()
.create(true)
.write(true)
.read(true)
.append(true)
.open(&part_path)
@@ -782,7 +796,8 @@ fn download_with_progress(dest_path: &Path, entry: &ModelEntry) -> Result<()> {
"checksum mismatch: expected {}, got {}",
expected_hex,
actual_hex
));
)
.into());
}
} else {
crate::ui::info("Verify: checksum not provided by source (skipped)");
@@ -830,7 +845,7 @@ fn download_with_progress(dest_path: &Path, entry: &ModelEntry) -> Result<()> {
/// Run an interactive model downloader UI (2-step):
/// 1) Choose model base (tiny, small, base, medium, large)
/// 2) Choose model type/variant specific to that base
/// Displays meta info (size and last updated). Does not show raw ggml filenames.
/// Displays meta info (size and last updated). Does not show raw ggml filenames.
pub fn run_interactive_model_downloader() -> Result<()> {
use crate::ui;
@@ -892,8 +907,14 @@ pub fn run_interactive_model_downloader() -> Result<()> {
// Prepare variant list for chosen base
let mut variants = by_base.remove(&chosen_base).unwrap_or_default();
variants.sort_by(|a, b| {
let rank = |v: &str| match v { "default" => 0, "en" => 1, _ => 2 };
rank(&a.variant).cmp(&rank(&b.variant)).then_with(|| a.variant.cmp(&b.variant))
let rank = |v: &str| match v {
"default" => 0,
"en" => 1,
_ => 2,
};
rank(&a.variant)
.cmp(&rank(&b.variant))
.then_with(|| a.variant.cmp(&b.variant))
});
// Build Multi-Select items for variants
@@ -906,12 +927,18 @@ pub fn run_interactive_model_downloader() -> Result<()> {
.map(short_date)
.map(|d| format!(" • updated {}", d))
.unwrap_or_default();
let variant_label = if m.variant == "default" { "default" } else { &m.variant };
let variant_label = if m.variant == "default" {
"default"
} else {
&m.variant
};
variant_labels.push(format!("{} ({}{})", variant_label, size, updated));
}
let variant_refs: Vec<&str> = variant_labels.iter().map(|s| s.as_str()).collect();
let mut defaults = vec![false; variant_refs.len()];
if !defaults.is_empty() { defaults[0] = true; }
if !defaults.is_empty() {
defaults[0] = true;
}
let picks = ui::prompt_multi_select(
&format!("Select types for '{}'", chosen_base),
&variant_refs,
@@ -984,8 +1011,8 @@ pub fn update_local_models() -> Result<()> {
let rd = fs::read_dir(&dir).with_context(|| format!("reading models dir {}", dir.display()))?;
let entries: Vec<_> = rd.flatten().collect();
if entries.len() == 0 {
ui::info("No local models found.".to_string());
if entries.is_empty() {
ui::info("No local models found.");
} else {
for ent in entries {
let path = ent.path();

View File

@@ -62,27 +62,32 @@ pub fn prompt_input(prompt: &str, default: Option<&str>) -> io::Result<String> {
return Ok(default.unwrap_or("").to_string());
}
let mut q = cliclack::input(prompt);
if let Some(def) = default { q = q.default_input(def); }
q.interact().map_err(|e| io::Error::new(io::ErrorKind::Other, e.to_string()))
if let Some(def) = default {
q = q.default_input(def);
}
q.interact().map_err(|e| io::Error::other(e.to_string()))
}
/// Present a single-choice selector and return the selected index.
pub fn prompt_select<'a>(prompt: &str, items: &[&'a str]) -> io::Result<usize> {
pub fn prompt_select(prompt: &str, items: &[&str]) -> io::Result<usize> {
if crate::is_no_interaction() || !crate::stdin_is_tty() {
return Err(io::Error::new(io::ErrorKind::Other, "interactive prompt disabled"));
return Err(io::Error::other("interactive prompt disabled"));
}
let mut sel = cliclack::select::<usize>(prompt);
for (idx, label) in items.iter().enumerate() {
sel = sel.item(idx, *label, "");
}
sel.interact()
.map_err(|e| io::Error::new(io::ErrorKind::Other, e.to_string()))
sel.interact().map_err(|e| io::Error::other(e.to_string()))
}
/// Present a multi-choice selector and return indices of selected items.
pub fn prompt_multi_select<'a>(prompt: &str, items: &[&'a str], defaults: Option<&[bool]>) -> io::Result<Vec<usize>> {
pub fn prompt_multi_select(
prompt: &str,
items: &[&str],
defaults: Option<&[bool]>,
) -> io::Result<Vec<usize>> {
if crate::is_no_interaction() || !crate::stdin_is_tty() {
return Err(io::Error::new(io::ErrorKind::Other, "interactive prompt disabled"));
return Err(io::Error::other("interactive prompt disabled"));
}
let mut ms = cliclack::multiselect::<usize>(prompt);
for (idx, label) in items.iter().enumerate() {
@@ -98,8 +103,7 @@ pub fn prompt_multi_select<'a>(prompt: &str, items: &[&'a str], defaults: Option
ms = ms.initial_values(selected);
}
}
ms.interact()
.map_err(|e| io::Error::new(io::ErrorKind::Other, e.to_string()))
ms.interact().map_err(|e| io::Error::other(e.to_string()))
}
/// Confirm prompt with default, respecting non-interactive mode.
@@ -109,32 +113,42 @@ pub fn prompt_confirm(prompt: &str, default: bool) -> io::Result<bool> {
}
let mut q = cliclack::confirm(prompt);
// If `cliclack::confirm` lacks default, we simply ask; caller can handle ESC/cancel if needed.
q.interact().map_err(|e| io::Error::new(io::ErrorKind::Other, e.to_string()))
q.interact().map_err(|e| io::Error::other(e.to_string()))
}
/// Read a secret/password without echoing, respecting non-interactive mode.
pub fn prompt_password(prompt: &str) -> io::Result<String> {
if crate::is_no_interaction() || !crate::stdin_is_tty() {
return Err(io::Error::new(io::ErrorKind::Other, "password prompt disabled in non-interactive mode"));
return Err(io::Error::other(
"password prompt disabled in non-interactive mode",
));
}
let mut q = cliclack::password(prompt);
q.interact().map_err(|e| io::Error::new(io::ErrorKind::Other, e.to_string()))
q.interact().map_err(|e| io::Error::other(e.to_string()))
}
/// Input with validation closure; on non-interactive returns default or error when no default.
pub fn prompt_input_validated<F>(prompt: &str, default: Option<&str>, validate: F) -> io::Result<String>
pub fn prompt_input_validated<F>(
prompt: &str,
default: Option<&str>,
validate: F,
) -> io::Result<String>
where
F: Fn(&str) -> Result<(), String> + 'static,
{
if crate::is_no_interaction() || !crate::stdin_is_tty() {
if let Some(def) = default { return Ok(def.to_string()); }
return Err(io::Error::new(io::ErrorKind::Other, "interactive prompt disabled"));
if let Some(def) = default {
return Ok(def.to_string());
}
return Err(io::Error::other("interactive prompt disabled"));
}
let mut q = cliclack::input(prompt);
if let Some(def) = default { q = q.default_input(def); }
if let Some(def) = default {
q = q.default_input(def);
}
q.validate(move |s: &String| validate(s))
.interact()
.map_err(|e| io::Error::new(io::ErrorKind::Other, e.to_string()))
.map_err(|e| io::Error::other(e.to_string()))
}
/// A simple spinner wrapper built on top of `cliclack::spinner()`.
@@ -146,7 +160,8 @@ pub struct Spinner(cliclack::ProgressBar);
impl Spinner {
/// Creates and starts a new spinner with the provided status text.
pub fn start(text: impl AsRef<str>) -> Self {
if crate::is_no_progress() || crate::is_no_interaction() || !std::io::stderr().is_terminal() {
if crate::is_no_progress() || crate::is_no_interaction() || !std::io::stderr().is_terminal()
{
// Fallback: no spinner, but log start
let _ = cliclack::log::info(text.as_ref());
let s = cliclack::spinner();
@@ -193,28 +208,44 @@ pub struct BytesProgress(Option<cliclack::ProgressBar>);
impl BytesProgress {
/// Start a new progress bar with a total and initial position.
pub fn start(total: u64, text: &str, initial: u64) -> Self {
if crate::is_no_progress() || crate::is_no_interaction() || !std::io::stderr().is_terminal() || total == 0 {
if crate::is_no_progress()
|| crate::is_no_interaction()
|| !std::io::stderr().is_terminal()
|| total == 0
{
let _ = cliclack::log::info(text);
return Self(None);
}
let mut b = cliclack::progress_bar(total);
let b = cliclack::progress_bar(total);
b.start(text);
if initial > 0 { b.inc(initial); }
if initial > 0 {
b.inc(initial);
}
Self(Some(b))
}
/// Increment by delta bytes.
pub fn inc(&mut self, delta: u64) {
if let Some(b) = self.0.as_mut() { b.inc(delta); }
if let Some(b) = self.0.as_mut() {
b.inc(delta);
}
}
/// Stop with a message.
pub fn stop(mut self, text: &str) {
if let Some(b) = self.0.take() { b.stop(text); } else { let _ = cliclack::log::info(text); }
if let Some(b) = self.0.take() {
b.stop(text);
} else {
let _ = cliclack::log::info(text);
}
}
/// Mark as error with a message.
pub fn error(mut self, text: &str) {
if let Some(b) = self.0.take() { b.error(text); } else { let _ = cliclack::log::error(text); }
if let Some(b) = self.0.take() {
b.error(text);
} else {
let _ = cliclack::log::error(text);
}
}
}

View File

@@ -15,12 +15,21 @@ pub struct ProgressManager {
impl ProgressManager {
/// Create a new manager with the given enabled flag.
pub fn new(enabled: bool) -> Self {
Self { enabled, per: Vec::new(), total: None, completed: 0, total_len: 0 }
Self {
enabled,
per: Vec::new(),
total: None,
completed: 0,
total_len: 0,
}
}
/// Create a manager that enables bars when `n > 1`, stderr is a TTY, and not quiet.
pub fn default_for_files(n: usize) -> Self {
let enabled = n > 1 && std::io::stderr().is_terminal() && !crate::is_quiet() && !crate::is_no_progress();
let enabled = n > 1
&& std::io::stderr().is_terminal()
&& !crate::is_quiet()
&& !crate::is_no_progress();
Self::new(enabled)
}
@@ -33,23 +42,27 @@ impl ProgressManager {
return;
}
// Aggregate bar at the top
let mut total = cliclack::progress_bar(labels.len() as u64);
let total = cliclack::progress_bar(labels.len() as u64);
total.start("Total");
self.total = Some(total);
// Per-file bars (100% scale for each)
for label in labels {
let mut pb = cliclack::progress_bar(100);
let pb = cliclack::progress_bar(100);
pb.start(label);
self.per.push(pb);
}
}
/// Returns true when bars are enabled (multi-file TTY mode).
pub fn is_enabled(&self) -> bool { self.enabled }
pub fn is_enabled(&self) -> bool {
self.enabled
}
/// Update a per-file bar message.
pub fn set_per_message(&mut self, idx: usize, message: &str) {
if !self.enabled { return; }
if !self.enabled {
return;
}
if let Some(pb) = self.per.get_mut(idx) {
pb.set_message(message);
}
@@ -57,16 +70,20 @@ impl ProgressManager {
/// Update a per-file bar percent (0..=100).
pub fn set_per_percent(&mut self, idx: usize, percent: u64) {
if !self.enabled { return; }
if !self.enabled {
return;
}
if let Some(pb) = self.per.get_mut(idx) {
let p = percent.min(100);
pb.set_message(&format!("{p}%"));
pb.set_message(format!("{p}%"));
}
}
/// Mark a file as finished (set to 100% and update total counter).
pub fn mark_file_done(&mut self, idx: usize) {
if !self.enabled { return; }
if !self.enabled {
return;
}
if let Some(pb) = self.per.get_mut(idx) {
pb.stop("done");
}
@@ -81,7 +98,9 @@ impl ProgressManager {
/// Finish the aggregate bar with a custom message.
pub fn finish_total(&mut self, message: &str) {
if !self.enabled { return; }
if !self.enabled {
return;
}
if let Some(total) = &mut self.total {
total.stop(message);
}
@@ -96,7 +115,9 @@ pub struct ProgressReporter {
impl ProgressReporter {
/// Creates a new progress reporter.
pub fn new(non_interactive: bool) -> Self { Self { non_interactive } }
pub fn new(non_interactive: bool) -> Self {
Self { non_interactive }
}
/// Displays a progress step message.
pub fn step(&mut self, message: &str) {

View File

@@ -1,16 +1,11 @@
use anyhow::{Context, Result};
use serde::Deserialize;
use std::{
env,
fs,
os::unix::fs::PermissionsExt,
path::Path,
};
use std::process::Stdio;
use std::{env, fs, os::unix::fs::PermissionsExt, path::Path};
use tokio::{
io::{AsyncBufReadExt, BufReader},
process::{Child as TokioChild, Command},
};
use std::process::Stdio;
#[derive(Debug, Clone)]
pub struct PluginInfo {
@@ -31,14 +26,15 @@ impl PluginManager {
if let Ok(read_dir) = fs::read_dir(&dir) {
for entry in read_dir.flatten() {
let path = entry.path();
if let Some(fname) = path.file_name().and_then(|s| s.to_str()) {
if fname.starts_with("polyscribe-plugin-") && is_executable(&path) {
let name = fname.trim_start_matches("polyscribe-plugin-").to_string();
plugins.push(PluginInfo {
name,
path: path.to_string_lossy().to_string(),
});
}
if let Some(fname) = path.file_name().and_then(|s| s.to_str())
&& fname.starts_with("polyscribe-plugin-")
&& is_executable(&path)
{
let name = fname.trim_start_matches("polyscribe-plugin-").to_string();
plugins.push(PluginInfo {
name,
path: path.to_string_lossy().to_string(),
});
}
}
}
@@ -89,7 +85,8 @@ impl PluginManager {
fn resolve(&self, name: &str) -> Result<String> {
let bin = format!("polyscribe-plugin-{name}");
let path = which::which(&bin).with_context(|| format!("plugin not found in PATH: {bin}"))?;
let path =
which::which(&bin).with_context(|| format!("plugin not found in PATH: {bin}"))?;
Ok(path.to_string_lossy().to_string())
}
}