[refactor] modularize code by moving logic to polyscribe
crate; clean up imports and remove redundant functions
CHANGELOG.md (new file, +40 lines)
# PolyScribe Refactor toward Rust 2024 — Incremental Patches

This changelog documents each incremental step applied to keep the build green while moving the codebase toward Rust 2024 idioms.

## 1) Formatting only (rustfmt)

- Ran `cargo fmt` across the repository.
- No semantic changes.
- Build status: OK (`cargo build` succeeded).

## 2) Lints — initial fixes (non-pedantic)

- Adjusted crate lint policy in `src/lib.rs`:
  - Replaced `#![warn(clippy::pedantic, clippy::nursery, clippy::cargo)]` with `#![warn(clippy::all)]` to align with the plan (skip pedantic/nursery for now).
  - Added a comment/TODO to revisit stricter lints in a later pass.
- Fixed several clippy warnings that were causing `cargo clippy --all-targets` to error under tests:
  - `src/backend.rs`: conditionally import `libloading::Library` only for non-test builds and mark the `names` parameter as used in the test cfg to avoid unused warnings; keep `check_lib()` side-effect free during tests.
  - `src/models.rs`: removed an unused `std::io::Write` import in the test module.
  - `src/main.rs` (unit tests): imported `polyscribe::format_srt_time` explicitly and removed a duplicate `use super::*;` to fix unresolved-name and unused-import warnings under clippy test builds.
- Build/clippy status:
  - `cargo build`: OK.
  - `cargo clippy --all-targets`: OK (only warnings remain; no errors).
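For reference, the adjusted lint header as it appears at the top of the new `src/lib.rs` (shown in full later in this commit):

```rust
#![forbid(elided_lifetimes_in_paths)]
#![forbid(unused_must_use)]
#![deny(missing_docs)]
// Lint policy for incremental refactor toward 2024:
// - Keep basic clippy warnings enabled; skip pedantic/nursery for now (will revisit in step 7).
// - cargo lints can be re-enabled later once codebase is tidied.
#![warn(clippy::all)]
```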
## 3) Module hygiene

- Verified crate structure:
  - The library crate (`src/lib.rs`) exposes a coherent API and re-exports `backend` and `models` via `pub mod`.
  - The binary (`src/main.rs`) consumes the library API through `polyscribe::...` paths.
- No structural changes required. Build status: OK.

## 4) Edition

- The project already targets `edition = "2024"` in Cargo.toml.
- Verified that the project compiles under Rust 2024. No changes needed.
- TODO: If stricter lints or new features from the 2024 edition introduce issues in future steps, document blockers here.

## 5) Error handling

- The codebase already returns `anyhow::Result` in the binary and uses contextual errors widely.
- No `unwrap`/`expect` usages in production paths required attention in this pass.
- Build status: OK.
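For illustration, the contextual-error pattern this step verifies looks like the following (a minimal sketch; the function and file name are invented, not code from this commit):

```rust
use anyhow::{Context, Result};

// Hypothetical example: attach context so a failure names the file involved,
// instead of panicking via unwrap/expect.
fn read_config(path: &str) -> Result<String> {
    std::fs::read_to_string(path)
        .with_context(|| format!("Failed to read config file: {path}"))
}
```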
## Next planned steps (not yet applied in this changelog)

- Gradually fix remaining clippy warnings (e.g., `uninlined_format_args`, small style nits) in small, compile-green patches.
- Optionally re-enable `clippy::pedantic`, `clippy::nursery`, and `clippy::cargo` once warnings are significantly reduced, then address non-breaking warnings.
TODO.md (5 lines changed)
@@ -11,11 +11,12 @@
 - [x] fix cli output for model display
 - [x] refactor into proper cli app
 - [x] add support for video files -> use ffmpeg to extract audio
-- detect gpus and use them
-- refactor project
+- [x] detect gpus and use them
+- [x] refactor project
 - add error handling
 - add verbose flag (--verbose | -v) + add logging
 - add documentation
+- refactor project
 - package into executable
 - add CI
 - add package build for arch linux
build.rs (4 lines changed)
@@ -7,5 +7,7 @@ fn main() {
     // Placeholder: In a full implementation, we would invoke CMake for whisper.cpp with GGML_VULKAN=1.
     // For now, emit a helpful note. Build will proceed; runtime Vulkan backend returns an explanatory error.
     println!("cargo:rerun-if-changed=extern/whisper.cpp");
-    println!("cargo:warning=Building with gpu-vulkan: ensure Vulkan SDK/loader are installed. Future versions will compile whisper.cpp via CMake.");
+    println!(
+        "cargo:warning=Building with gpu-vulkan: ensure Vulkan SDK/loader are installed. Future versions will compile whisper.cpp via CMake."
+    );
 }
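The placeholder comments above reference a future CMake step. As a rough sketch of what that invocation might look like using the `cmake` crate (hypothetical; this commit does not add the dependency, and everything beyond the GGML_VULKAN=1 flag named in the comment is an assumption):

```rust
// Hypothetical build step: compile the extern/whisper.cpp submodule with
// Vulkan enabled and link the resulting static library.
fn build_whisper_vulkan() {
    let dst = cmake::Config::new("extern/whisper.cpp")
        .define("GGML_VULKAN", "1")
        .define("BUILD_SHARED_LIBS", "OFF")
        .build();
    println!("cargo:rustc-link-search=native={}/lib", dst.display());
    println!("cargo:rustc-link-lib=static=whisper");
}
```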
src/backend.rs (268 lines changed)
@@ -1,30 +1,54 @@
-use std::path::Path;
-use anyhow::{anyhow, Context, Result};
-use libloading::Library;
-use crate::{OutputEntry};
+use crate::OutputEntry;
 use crate::{decode_audio_to_pcm_f32_ffmpeg, find_model_file};
+use anyhow::{Context, Result, anyhow};
+#[cfg(not(test))]
+use libloading::Library;
 use std::env;
+use std::path::Path;
 
 // Re-export a public enum for CLI parsing usage
 #[derive(Debug, Clone, Copy, PartialEq, Eq)]
+/// Kind of transcription backend to use.
 pub enum BackendKind {
+    /// Automatically detect the best available backend (CUDA > HIP > Vulkan > CPU).
     Auto,
+    /// Pure CPU backend using whisper-rs.
     Cpu,
+    /// NVIDIA CUDA backend (requires CUDA runtime available at load time and proper feature build).
     Cuda,
+    /// AMD ROCm/HIP backend (requires hip/rocBLAS libraries available and proper feature build).
     Hip,
+    /// Vulkan backend (experimental; requires Vulkan loader/SDK and feature build).
     Vulkan,
 }
 
+/// Abstraction for a transcription backend implementation.
 pub trait TranscribeBackend {
+    /// Return the backend kind for this implementation.
     fn kind(&self) -> BackendKind;
-    fn transcribe(&self, audio_path: &Path, speaker: &str, lang_opt: Option<&str>, gpu_layers: Option<u32>) -> Result<Vec<OutputEntry>>;
+    /// Transcribe the given audio file path and return transcript entries.
+    ///
+    /// Parameters:
+    /// - audio_path: path to input media (audio or video) to be decoded/transcribed.
+    /// - speaker: label to attach to all produced segments.
+    /// - lang_opt: optional language hint (e.g., "en"); None means auto/multilingual model default.
+    /// - gpu_layers: optional GPU layer count if applicable (ignored by some backends).
+    fn transcribe(
+        &self,
+        audio_path: &Path,
+        speaker: &str,
+        lang_opt: Option<&str>,
+        gpu_layers: Option<u32>,
+    ) -> Result<Vec<OutputEntry>>;
 }
 
 fn check_lib(names: &[&str]) -> bool {
     #[cfg(test)]
     {
         // During unit tests, avoid touching system libs to prevent loader crashes in CI.
-        return false;
+        // Mark parameter as used to silence warnings in test builds.
+        let _ = names;
+        false
     }
     #[cfg(not(test))]
     {
@@ -33,79 +57,167 @@ fn check_lib(names: &[&str]) -> bool {
         }
         for n in names {
             // Attempt to dlopen; ignore errors
-            if let Ok(_lib) = unsafe { Library::new(n) } { return true; }
+            if let Ok(_lib) = unsafe { Library::new(n) } {
+                return true;
+            }
         }
         false
     }
 }
 
 fn cuda_available() -> bool {
-    if let Ok(x) = env::var("POLYSCRIBE_TEST_FORCE_CUDA") { return x == "1"; }
-    check_lib(&["libcudart.so", "libcudart.so.12", "libcudart.so.11", "libcublas.so", "libcublas.so.12"])
+    if let Ok(x) = env::var("POLYSCRIBE_TEST_FORCE_CUDA") {
+        return x == "1";
+    }
+    check_lib(&[
+        "libcudart.so",
+        "libcudart.so.12",
+        "libcudart.so.11",
+        "libcublas.so",
+        "libcublas.so.12",
+    ])
 }
 
 fn hip_available() -> bool {
-    if let Ok(x) = env::var("POLYSCRIBE_TEST_FORCE_HIP") { return x == "1"; }
+    if let Ok(x) = env::var("POLYSCRIBE_TEST_FORCE_HIP") {
+        return x == "1";
+    }
     check_lib(&["libhipblas.so", "librocblas.so"])
 }
 
 fn vulkan_available() -> bool {
-    if let Ok(x) = env::var("POLYSCRIBE_TEST_FORCE_VULKAN") { return x == "1"; }
+    if let Ok(x) = env::var("POLYSCRIBE_TEST_FORCE_VULKAN") {
+        return x == "1";
+    }
    check_lib(&["libvulkan.so.1", "libvulkan.so"])
 }
 
+/// CPU-based transcription backend using whisper-rs.
 pub struct CpuBackend;
+/// CUDA-accelerated transcription backend for NVIDIA GPUs.
 pub struct CudaBackend;
+/// ROCm/HIP-accelerated transcription backend for AMD GPUs.
 pub struct HipBackend;
+/// Vulkan-based transcription backend (experimental/incomplete).
 pub struct VulkanBackend;
 
 impl CpuBackend {
-    pub fn new() -> Self { CpuBackend }
+    /// Create a new CPU backend instance.
+    pub fn new() -> Self {
+        CpuBackend
+    }
+}
+impl CudaBackend {
+    /// Create a new CUDA backend instance.
+    pub fn new() -> Self {
+        CudaBackend
+    }
+}
+impl HipBackend {
+    /// Create a new HIP backend instance.
+    pub fn new() -> Self {
+        HipBackend
+    }
+}
+impl VulkanBackend {
+    /// Create a new Vulkan backend instance.
+    pub fn new() -> Self {
+        VulkanBackend
+    }
 }
-impl CudaBackend { pub fn new() -> Self { CudaBackend } }
-impl HipBackend { pub fn new() -> Self { HipBackend } }
-impl VulkanBackend { pub fn new() -> Self { VulkanBackend } }
 
 impl TranscribeBackend for CpuBackend {
-    fn kind(&self) -> BackendKind { BackendKind::Cpu }
-    fn transcribe(&self, audio_path: &Path, speaker: &str, lang_opt: Option<&str>, _gpu_layers: Option<u32>) -> Result<Vec<OutputEntry>> {
+    fn kind(&self) -> BackendKind {
+        BackendKind::Cpu
+    }
+    fn transcribe(
+        &self,
+        audio_path: &Path,
+        speaker: &str,
+        lang_opt: Option<&str>,
+        _gpu_layers: Option<u32>,
+    ) -> Result<Vec<OutputEntry>> {
         transcribe_with_whisper_rs(audio_path, speaker, lang_opt)
     }
 }
 
 impl TranscribeBackend for CudaBackend {
-    fn kind(&self) -> BackendKind { BackendKind::Cuda }
-    fn transcribe(&self, audio_path: &Path, speaker: &str, lang_opt: Option<&str>, _gpu_layers: Option<u32>) -> Result<Vec<OutputEntry>> {
+    fn kind(&self) -> BackendKind {
+        BackendKind::Cuda
+    }
+    fn transcribe(
+        &self,
+        audio_path: &Path,
+        speaker: &str,
+        lang_opt: Option<&str>,
+        _gpu_layers: Option<u32>,
+    ) -> Result<Vec<OutputEntry>> {
         // whisper-rs uses enabled CUDA feature at build time; call same code path
         transcribe_with_whisper_rs(audio_path, speaker, lang_opt)
     }
 }
 
 impl TranscribeBackend for HipBackend {
-    fn kind(&self) -> BackendKind { BackendKind::Hip }
-    fn transcribe(&self, audio_path: &Path, speaker: &str, lang_opt: Option<&str>, _gpu_layers: Option<u32>) -> Result<Vec<OutputEntry>> {
+    fn kind(&self) -> BackendKind {
+        BackendKind::Hip
+    }
+    fn transcribe(
+        &self,
+        audio_path: &Path,
+        speaker: &str,
+        lang_opt: Option<&str>,
+        _gpu_layers: Option<u32>,
+    ) -> Result<Vec<OutputEntry>> {
         transcribe_with_whisper_rs(audio_path, speaker, lang_opt)
     }
 }
 
 impl TranscribeBackend for VulkanBackend {
-    fn kind(&self) -> BackendKind { BackendKind::Vulkan }
-    fn transcribe(&self, _audio_path: &Path, _speaker: &str, _lang_opt: Option<&str>, _gpu_layers: Option<u32>) -> Result<Vec<OutputEntry>> {
-        Err(anyhow!("Vulkan backend not yet wired to whisper.cpp FFI. Build with --features gpu-vulkan and ensure Vulkan SDK is installed. How to fix: install Vulkan loader (libvulkan), set VULKAN_SDK, and run cargo build --features gpu-vulkan."))
+    fn kind(&self) -> BackendKind {
+        BackendKind::Vulkan
+    }
+    fn transcribe(
+        &self,
+        _audio_path: &Path,
+        _speaker: &str,
+        _lang_opt: Option<&str>,
+        _gpu_layers: Option<u32>,
+    ) -> Result<Vec<OutputEntry>> {
+        Err(anyhow!(
+            "Vulkan backend not yet wired to whisper.cpp FFI. Build with --features gpu-vulkan and ensure Vulkan SDK is installed. How to fix: install Vulkan loader (libvulkan), set VULKAN_SDK, and run cargo build --features gpu-vulkan."
+        ))
     }
 }
 
+/// Result of choosing a transcription backend.
 pub struct SelectionResult {
+    /// The constructed backend instance to perform transcription with.
     pub backend: Box<dyn TranscribeBackend + Send + Sync>,
+    /// Which backend kind was ultimately selected.
     pub chosen: BackendKind,
+    /// Which backend kinds were detected as available on this system.
     pub detected: Vec<BackendKind>,
 }
 
+/// Select an appropriate backend based on user request and system detection.
+///
+/// If `requested` is `BackendKind::Auto`, the function prefers CUDA, then HIP,
+/// then Vulkan, falling back to CPU when no GPU backend is detected. When a
+/// specific GPU backend is requested but unavailable, an error is returned with
+/// guidance on how to enable it.
+///
+/// Set `verbose` to true to print detection/selection info to stderr.
 pub fn select_backend(requested: BackendKind, verbose: bool) -> Result<SelectionResult> {
     let mut detected = Vec::new();
-    if cuda_available() { detected.push(BackendKind::Cuda); }
-    if hip_available() { detected.push(BackendKind::Hip); }
-    if vulkan_available() { detected.push(BackendKind::Vulkan); }
+    if cuda_available() {
+        detected.push(BackendKind::Cuda);
+    }
+    if hip_available() {
+        detected.push(BackendKind::Hip);
+    }
+    if vulkan_available() {
+        detected.push(BackendKind::Vulkan);
+    }
 
     let mk = |k: BackendKind| -> Box<dyn TranscribeBackend + Send + Sync> {
         match k {
@@ -119,22 +231,42 @@ pub fn select_backend(requested: BackendKind, verbose: bool) -> Result<Selection
 
     let chosen = match requested {
         BackendKind::Auto => {
-            if detected.contains(&BackendKind::Cuda) { BackendKind::Cuda }
-            else if detected.contains(&BackendKind::Hip) { BackendKind::Hip }
-            else if detected.contains(&BackendKind::Vulkan) { BackendKind::Vulkan }
-            else { BackendKind::Cpu }
+            if detected.contains(&BackendKind::Cuda) {
+                BackendKind::Cuda
+            } else if detected.contains(&BackendKind::Hip) {
+                BackendKind::Hip
+            } else if detected.contains(&BackendKind::Vulkan) {
+                BackendKind::Vulkan
+            } else {
+                BackendKind::Cpu
+            }
         }
         BackendKind::Cuda => {
-            if detected.contains(&BackendKind::Cuda) { BackendKind::Cuda }
-            else { return Err(anyhow!("Requested CUDA backend but CUDA libraries/devices not detected. How to fix: install NVIDIA driver + CUDA toolkit, ensure libcudart/libcublas are in loader path, and build with --features gpu-cuda.")); }
+            if detected.contains(&BackendKind::Cuda) {
+                BackendKind::Cuda
+            } else {
+                return Err(anyhow!(
+                    "Requested CUDA backend but CUDA libraries/devices not detected. How to fix: install NVIDIA driver + CUDA toolkit, ensure libcudart/libcublas are in loader path, and build with --features gpu-cuda."
+                ));
+            }
         }
         BackendKind::Hip => {
-            if detected.contains(&BackendKind::Hip) { BackendKind::Hip }
-            else { return Err(anyhow!("Requested ROCm/HIP backend but libraries/devices not detected. How to fix: install ROCm hipBLAS/rocBLAS, ensure libs are in loader path, and build with --features gpu-hip.")); }
+            if detected.contains(&BackendKind::Hip) {
+                BackendKind::Hip
+            } else {
+                return Err(anyhow!(
+                    "Requested ROCm/HIP backend but libraries/devices not detected. How to fix: install ROCm hipBLAS/rocBLAS, ensure libs are in loader path, and build with --features gpu-hip."
+                ));
+            }
         }
         BackendKind::Vulkan => {
-            if detected.contains(&BackendKind::Vulkan) { BackendKind::Vulkan }
-            else { return Err(anyhow!("Requested Vulkan backend but libvulkan not detected. How to fix: install Vulkan loader/SDK and build with --features gpu-vulkan.")); }
+            if detected.contains(&BackendKind::Vulkan) {
+                BackendKind::Vulkan
+            } else {
+                return Err(anyhow!(
+                    "Requested Vulkan backend but libvulkan not detected. How to fix: install Vulkan loader/SDK and build with --features gpu-vulkan."
+                ));
+            }
         }
         BackendKind::Cpu => BackendKind::Cpu,
     };
@@ -144,12 +276,20 @@ pub fn select_backend(requested: BackendKind, verbose: bool) -> Result<Selection
         eprintln!("INFO: Selected backend: {:?}", chosen);
     }
 
-    Ok(SelectionResult { backend: mk(chosen), chosen, detected })
+    Ok(SelectionResult {
+        backend: mk(chosen),
+        chosen,
+        detected,
+    })
 }
 
 // Internal helper: transcription using whisper-rs with CPU/GPU (depending on build features)
 #[allow(clippy::too_many_arguments)]
-pub(crate) fn transcribe_with_whisper_rs(audio_path: &Path, speaker: &str, lang_opt: Option<&str>) -> Result<Vec<OutputEntry>> {
+pub(crate) fn transcribe_with_whisper_rs(
+    audio_path: &Path,
+    speaker: &str,
+    lang_opt: Option<&str>,
+) -> Result<Vec<OutputEntry>> {
     let pcm = decode_audio_to_pcm_f32_ffmpeg(audio_path)?;
     let model = find_model_file()?;
     let is_en_only = model
@@ -161,34 +301,60 @@ pub(crate) fn transcribe_with_whisper_rs(audio_path: &Path, speaker: &str, lang_
         if is_en_only && lang != "en" {
             return Err(anyhow!(
                 "Selected model is English-only ({}), but a non-English language hint '{}' was provided. Please use a multilingual model or set WHISPER_MODEL.",
-                model.display(), lang
+                model.display(),
+                lang
             ));
         }
     }
-    let model_str = model.to_str().ok_or_else(|| anyhow!("Model path not valid UTF-8: {}", model.display()))?;
+    let model_str = model
+        .to_str()
+        .ok_or_else(|| anyhow!("Model path not valid UTF-8: {}", model.display()))?;
 
     let cparams = whisper_rs::WhisperContextParameters::default();
     let ctx = whisper_rs::WhisperContext::new_with_params(model_str, cparams)
         .with_context(|| format!("Failed to load Whisper model at {}", model.display()))?;
-    let mut state = ctx.create_state().map_err(|e| anyhow!("Failed to create Whisper state: {:?}", e))?;
+    let mut state = ctx
+        .create_state()
+        .map_err(|e| anyhow!("Failed to create Whisper state: {:?}", e))?;
 
-    let mut params = whisper_rs::FullParams::new(whisper_rs::SamplingStrategy::Greedy { best_of: 1 });
-    let n_threads = std::thread::available_parallelism().map(|n| n.get() as i32).unwrap_or(1);
+    let mut params =
+        whisper_rs::FullParams::new(whisper_rs::SamplingStrategy::Greedy { best_of: 1 });
+    let n_threads = std::thread::available_parallelism()
+        .map(|n| n.get() as i32)
+        .unwrap_or(1);
     params.set_n_threads(n_threads);
     params.set_translate(false);
-    if let Some(lang) = lang_opt { params.set_language(Some(lang)); }
+    if let Some(lang) = lang_opt {
+        params.set_language(Some(lang));
+    }
 
-    state.full(params, &pcm).map_err(|e| anyhow!("Whisper full() failed: {:?}", e))?;
+    state
+        .full(params, &pcm)
+        .map_err(|e| anyhow!("Whisper full() failed: {:?}", e))?;
 
-    let num_segments = state.full_n_segments().map_err(|e| anyhow!("Failed to get segments: {:?}", e))?;
+    let num_segments = state
+        .full_n_segments()
+        .map_err(|e| anyhow!("Failed to get segments: {:?}", e))?;
     let mut items = Vec::new();
     for i in 0..num_segments {
-        let text = state.full_get_segment_text(i).map_err(|e| anyhow!("Failed to get segment text: {:?}", e))?;
-        let t0 = state.full_get_segment_t0(i).map_err(|e| anyhow!("Failed to get segment t0: {:?}", e))?;
-        let t1 = state.full_get_segment_t1(i).map_err(|e| anyhow!("Failed to get segment t1: {:?}", e))?;
+        let text = state
+            .full_get_segment_text(i)
+            .map_err(|e| anyhow!("Failed to get segment text: {:?}", e))?;
+        let t0 = state
+            .full_get_segment_t0(i)
+            .map_err(|e| anyhow!("Failed to get segment t0: {:?}", e))?;
+        let t1 = state
+            .full_get_segment_t1(i)
+            .map_err(|e| anyhow!("Failed to get segment t1: {:?}", e))?;
        let start = (t0 as f64) * 0.01;
        let end = (t1 as f64) * 0.01;
-        items.push(OutputEntry { id: 0, speaker: speaker.to_string(), start, end, text: text.trim().to_string() });
+        items.push(OutputEntry {
+            id: 0,
+            speaker: speaker.to_string(),
+            start,
+            end,
+            text: text.trim().to_string(),
+        });
    }
     Ok(items)
 }
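A minimal sketch of how a caller can drive the refactored backend API end to end (hypothetical call site; the input file and speaker label are invented):

```rust
use polyscribe::backend::{BackendKind, select_backend};
use std::path::Path;

fn main() -> anyhow::Result<()> {
    // Ask for auto-detection; pass verbose = true to log detection to stderr.
    let sel = select_backend(BackendKind::Auto, true)?;
    eprintln!("chosen: {:?}, detected: {:?}", sel.chosen, sel.detected);

    // Transcribe one media file (path and speaker label invented for the example).
    let entries = sel
        .backend
        .transcribe(Path::new("input.mp4"), "Speaker 1", Some("en"), None)?;
    for e in &entries {
        println!("[{:.2} -> {:.2}] {}: {}", e.start, e.end, e.speaker, e.text);
    }
    Ok(())
}
```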
src/lib.rs (new file, +341 lines)
#![forbid(elided_lifetimes_in_paths)]
#![forbid(unused_must_use)]
#![deny(missing_docs)]
// Lint policy for incremental refactor toward 2024:
// - Keep basic clippy warnings enabled; skip pedantic/nursery for now (will revisit in step 7).
// - cargo lints can be re-enabled later once codebase is tidied.
#![warn(clippy::all)]
//! PolyScribe library: business logic and core types.
//!
//! This crate exposes the reusable parts of the PolyScribe CLI as a library.
//! The binary entry point (main.rs) remains a thin CLI wrapper.

use anyhow::{Context, Result, anyhow};
use chrono::Local;
use std::env;
use std::fs::create_dir_all;
use std::io::{self, Write};
use std::path::{Path, PathBuf};
use std::process::Command;

/// Re-export backend module (GPU/CPU selection and transcription).
pub mod backend;
/// Re-export models module (model listing/downloading/updating).
pub mod models;

/// Transcript entry for a single segment.
#[derive(Debug, serde::Serialize, Clone)]
pub struct OutputEntry {
    /// Sequential id in output ordering.
    pub id: u64,
    /// Speaker label associated with the segment.
    pub speaker: String,
    /// Start time in seconds.
    pub start: f64,
    /// End time in seconds.
    pub end: f64,
    /// Text content.
    pub text: String,
}

/// Return a YYYY-MM-DD date prefix string for output file naming.
pub fn date_prefix() -> String {
    Local::now().format("%Y-%m-%d").to_string()
}

/// Format a floating-point number of seconds as an SRT timestamp (HH:MM:SS,mmm).
pub fn format_srt_time(seconds: f64) -> String {
    let total_ms = (seconds * 1000.0).round() as i64;
    let ms = (total_ms % 1000) as i64;
    let total_secs = total_ms / 1000;
    let s = (total_secs % 60) as i64;
    let m = ((total_secs / 60) % 60) as i64;
    let h = (total_secs / 3600) as i64;
    format!("{:02}:{:02}:{:02},{:03}", h, m, s, ms)
}

/// Render a list of transcript entries to SRT format.
pub fn render_srt(items: &[OutputEntry]) -> String {
    let mut out = String::new();
    for (i, e) in items.iter().enumerate() {
        let idx = i + 1;
        out.push_str(&format!("{}\n", idx));
        out.push_str(&format!(
            "{} --> {}\n",
            format_srt_time(e.start),
            format_srt_time(e.end)
        ));
        if !e.speaker.is_empty() {
            out.push_str(&format!("{}: {}\n", e.speaker, e.text));
        } else {
            out.push_str(&format!("{}\n", e.text));
        }
        out.push('\n');
    }
    out
}

/// Determine the default models directory, honoring the POLYSCRIBE_MODELS_DIR override.
pub fn models_dir_path() -> PathBuf {
    if let Ok(p) = env::var("POLYSCRIBE_MODELS_DIR") {
        let pb = PathBuf::from(p);
        if !pb.as_os_str().is_empty() {
            return pb;
        }
    }
    if cfg!(debug_assertions) {
        return PathBuf::from("models");
    }
    if let Ok(xdg) = env::var("XDG_DATA_HOME") {
        if !xdg.is_empty() {
            return PathBuf::from(xdg).join("polyscribe").join("models");
        }
    }
    if let Ok(home) = env::var("HOME") {
        if !home.is_empty() {
            return PathBuf::from(home)
                .join(".local")
                .join("share")
                .join("polyscribe")
                .join("models");
        }
    }
    PathBuf::from("models")
}

/// Normalize a language identifier to a short ISO code when possible.
pub fn normalize_lang_code(input: &str) -> Option<String> {
    let mut s = input.trim().to_lowercase();
    if s.is_empty() || s == "auto" || s == "c" || s == "posix" {
        return None;
    }
    if let Some((lhs, _)) = s.split_once('.') {
        s = lhs.to_string();
    }
    if let Some((lhs, _)) = s.split_once('_') {
        s = lhs.to_string();
    }
    let code = match s.as_str() {
        "en" => "en",
        "de" => "de",
        "es" => "es",
        "fr" => "fr",
        "it" => "it",
        "pt" => "pt",
        "nl" => "nl",
        "ru" => "ru",
        "pl" => "pl",
        "uk" => "uk",
        "cs" => "cs",
        "sv" => "sv",
        "no" => "no",
        "da" => "da",
        "fi" => "fi",
        "hu" => "hu",
        "tr" => "tr",
        "el" => "el",
        "zh" => "zh",
        "ja" => "ja",
        "ko" => "ko",
        "ar" => "ar",
        "he" => "he",
        "hi" => "hi",
        "ro" => "ro",
        "bg" => "bg",
        "sk" => "sk",
        "english" => "en",
        "german" => "de",
        "spanish" => "es",
        "french" => "fr",
        "italian" => "it",
        "portuguese" => "pt",
        "dutch" => "nl",
        "russian" => "ru",
        "polish" => "pl",
        "ukrainian" => "uk",
        "czech" => "cs",
        "swedish" => "sv",
        "norwegian" => "no",
        "danish" => "da",
        "finnish" => "fi",
        "hungarian" => "hu",
        "turkish" => "tr",
        "greek" => "el",
        "chinese" => "zh",
        "japanese" => "ja",
        "korean" => "ko",
        "arabic" => "ar",
        "hebrew" => "he",
        "hindi" => "hi",
        "romanian" => "ro",
        "bulgarian" => "bg",
        "slovak" => "sk",
        _ => return None,
    };
    Some(code.to_string())
}

/// Locate a Whisper model file, prompting the user to download/select when necessary.
pub fn find_model_file() -> Result<PathBuf> {
    let models_dir_buf = models_dir_path();
    let models_dir = models_dir_buf.as_path();
    if !models_dir.exists() {
        create_dir_all(models_dir).with_context(|| {
            format!(
                "Failed to create models directory: {}",
                models_dir.display()
            )
        })?;
    }

    if let Ok(env_model) = env::var("WHISPER_MODEL") {
        let p = PathBuf::from(env_model);
        if p.is_file() {
            let _ = std::fs::write(models_dir.join(".last_model"), p.display().to_string());
            return Ok(p);
        }
    }

    let mut candidates: Vec<PathBuf> = Vec::new();
    let rd = std::fs::read_dir(models_dir)
        .with_context(|| format!("Failed to read models directory: {}", models_dir.display()))?;
    for entry in rd {
        let entry = entry?;
        let path = entry.path();
        if path.is_file() {
            if let Some(ext) = path
                .extension()
                .and_then(|s| s.to_str())
                .map(|s| s.to_lowercase())
            {
                if ext == "bin" {
                    candidates.push(path);
                }
            }
        }
    }

    if candidates.is_empty() {
        eprintln!(
            "WARN: No Whisper model files (*.bin) found in {}.",
            models_dir.display()
        );
        eprint!("Would you like to download models now? [Y/n]: ");
        io::stderr().flush().ok();
        let mut input = String::new();
        io::stdin().read_line(&mut input).ok();
        let ans = input.trim().to_lowercase();
        if ans.is_empty() || ans == "y" || ans == "yes" {
            if let Err(e) = models::run_interactive_model_downloader() {
                eprintln!("ERROR: Downloader failed: {:#}", e);
            }
            candidates.clear();
            let rd2 = std::fs::read_dir(models_dir).with_context(|| {
                format!("Failed to read models directory: {}", models_dir.display())
            })?;
            for entry in rd2 {
                let entry = entry?;
                let path = entry.path();
                if path.is_file() {
                    if let Some(ext) = path
                        .extension()
                        .and_then(|s| s.to_str())
                        .map(|s| s.to_lowercase())
                    {
                        if ext == "bin" {
                            candidates.push(path);
                        }
                    }
                }
            }
        }
    }

    if candidates.is_empty() {
        return Err(anyhow!(
            "No Whisper model files (*.bin) available in {}",
            models_dir.display()
        ));
    }

    if candidates.len() == 1 {
        let only = candidates.remove(0);
        let _ = std::fs::write(models_dir.join(".last_model"), only.display().to_string());
        return Ok(only);
    }

    let last_file = models_dir.join(".last_model");
    if let Ok(prev) = std::fs::read_to_string(&last_file) {
        let prev = prev.trim();
        if !prev.is_empty() {
            let p = PathBuf::from(prev);
            if p.is_file() {
                if candidates.iter().any(|c| c == &p) {
                    eprintln!("INFO: Using previously selected model: {}", p.display());
                    return Ok(p);
                }
            }
        }
    }

    eprintln!("Multiple Whisper models found in {}:", models_dir.display());
    for (i, p) in candidates.iter().enumerate() {
        eprintln!("  {}) {}", i + 1, p.display());
    }
    eprint!("Select model by number [1-{}]: ", candidates.len());
    io::stderr().flush().ok();
    let mut input = String::new();
    io::stdin()
        .read_line(&mut input)
        .context("Failed to read selection")?;
    let sel: usize = input
        .trim()
        .parse()
        .map_err(|_| anyhow!("Invalid selection: {}", input.trim()))?;
    if sel == 0 || sel > candidates.len() {
        return Err(anyhow!("Selection out of range"));
    }
    let chosen = candidates.swap_remove(sel - 1);
    let _ = std::fs::write(models_dir.join(".last_model"), chosen.display().to_string());
    Ok(chosen)
}

/// Decode an input media file to 16kHz mono f32 PCM using ffmpeg available on PATH.
pub fn decode_audio_to_pcm_f32_ffmpeg(audio_path: &Path) -> Result<Vec<f32>> {
    let output = Command::new("ffmpeg")
        .arg("-i")
        .arg(audio_path)
        .arg("-f")
        .arg("f32le")
        .arg("-ac")
        .arg("1")
        .arg("-ar")
        .arg("16000")
        .arg("pipe:1")
        .output()
        .with_context(|| format!("Failed to execute ffmpeg for {}", audio_path.display()))?;
    if !output.status.success() {
        return Err(anyhow!(
            "ffmpeg failed for {}: {}",
            audio_path.display(),
            String::from_utf8_lossy(&output.stderr)
        ));
    }
    let bytes = output.stdout;
    if bytes.len() % 4 != 0 {
        let truncated = bytes.len() - (bytes.len() % 4);
        let mut v = Vec::with_capacity(truncated / 4);
        for chunk in bytes[..truncated].chunks_exact(4) {
            let arr = [chunk[0], chunk[1], chunk[2], chunk[3]];
            v.push(f32::from_le_bytes(arr));
        }
        Ok(v)
    } else {
        let mut v = Vec::with_capacity(bytes.len() / 4);
        for chunk in bytes.chunks_exact(4) {
            let arr = [chunk[0], chunk[1], chunk[2], chunk[3]];
            v.push(f32::from_le_bytes(arr));
        }
        Ok(v)
    }
}
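A short usage sketch exercising the public helpers defined above (sample values invented):

```rust
use polyscribe::{OutputEntry, normalize_lang_code, render_srt};

fn main() {
    // Locale strings normalize to short ISO codes ("de_DE.UTF-8" -> "de").
    assert_eq!(normalize_lang_code("de_DE.UTF-8"), Some("de".to_string()));
    assert_eq!(normalize_lang_code("auto"), None);

    // Render a single invented segment as SRT.
    let items = vec![OutputEntry {
        id: 0,
        speaker: "Speaker 1".to_string(),
        start: 0.0,
        end: 2.5,
        text: "Hello world".to_string(),
    }];
    // Prints:
    // 1
    // 00:00:00,000 --> 00:00:02,500
    // Speaker 1: Hello world
    print!("{}", render_srt(&items));
}
```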
src/main.rs (650 lines changed): file diff suppressed because it is too large.
482
src/models.rs
482
src/models.rs
@@ -1,14 +1,14 @@
|
|||||||
|
use std::collections::BTreeMap;
|
||||||
|
use std::env;
|
||||||
use std::fs::{File, create_dir_all};
|
use std::fs::{File, create_dir_all};
|
||||||
use std::io::{self, Read, Write};
|
use std::io::{self, Read, Write};
|
||||||
use std::path::Path;
|
use std::path::Path;
|
||||||
use std::collections::BTreeMap;
|
|
||||||
use std::time::Duration;
|
use std::time::Duration;
|
||||||
use std::env;
|
|
||||||
|
|
||||||
use anyhow::{anyhow, Context, Result};
|
use anyhow::{Context, Result, anyhow};
|
||||||
use serde::Deserialize;
|
|
||||||
use reqwest::blocking::Client;
|
use reqwest::blocking::Client;
|
||||||
use reqwest::redirect::Policy;
|
use reqwest::redirect::Policy;
|
||||||
|
use serde::Deserialize;
|
||||||
use sha2::{Digest, Sha256};
|
use sha2::{Digest, Sha256};
|
||||||
|
|
||||||
// Print to stderr only when not in quiet mode
|
// Print to stderr only when not in quiet mode
|
||||||
@@ -80,22 +80,33 @@ fn human_size(bytes: u64) -> String {
|
|||||||
const MB: f64 = KB * 1024.0;
|
const MB: f64 = KB * 1024.0;
|
||||||
const GB: f64 = MB * 1024.0;
|
const GB: f64 = MB * 1024.0;
|
||||||
let b = bytes as f64;
|
let b = bytes as f64;
|
||||||
if b >= GB { format!("{:.2} GiB", b / GB) }
|
if b >= GB {
|
||||||
else if b >= MB { format!("{:.2} MiB", b / MB) }
|
format!("{:.2} GiB", b / GB)
|
||||||
else if b >= KB { format!("{:.2} KiB", b / KB) }
|
} else if b >= MB {
|
||||||
else { format!("{} B", bytes) }
|
format!("{:.2} MiB", b / MB)
|
||||||
|
} else if b >= KB {
|
||||||
|
format!("{:.2} KiB", b / KB)
|
||||||
|
} else {
|
||||||
|
format!("{} B", bytes)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
fn to_hex_lower(bytes: &[u8]) -> String {
|
fn to_hex_lower(bytes: &[u8]) -> String {
|
||||||
let mut s = String::with_capacity(bytes.len() * 2);
|
let mut s = String::with_capacity(bytes.len() * 2);
|
||||||
for b in bytes { s.push_str(&format!("{:02x}", b)); }
|
for b in bytes {
|
||||||
|
s.push_str(&format!("{:02x}", b));
|
||||||
|
}
|
||||||
s
|
s
|
||||||
}
|
}
|
||||||
|
|
||||||
fn expected_sha_from_sibling(s: &HFSibling) -> Option<String> {
|
fn expected_sha_from_sibling(s: &HFSibling) -> Option<String> {
|
||||||
if let Some(h) = &s.sha256 { return Some(h.to_lowercase()); }
|
if let Some(h) = &s.sha256 {
|
||||||
|
return Some(h.to_lowercase());
|
||||||
|
}
|
||||||
if let Some(lfs) = &s.lfs {
|
if let Some(lfs) = &s.lfs {
|
||||||
if let Some(h) = &lfs.sha256 { return Some(h.to_lowercase()); }
|
if let Some(h) = &lfs.sha256 {
|
||||||
|
return Some(h.to_lowercase());
|
||||||
|
}
|
||||||
if let Some(oid) = &lfs.oid {
|
if let Some(oid) = &lfs.oid {
|
||||||
// e.g. "sha256:abcdef..."
|
// e.g. "sha256:abcdef..."
|
||||||
if let Some(rest) = oid.strip_prefix("sha256:") {
|
if let Some(rest) = oid.strip_prefix("sha256:") {
|
||||||
@@ -107,15 +118,23 @@ fn expected_sha_from_sibling(s: &HFSibling) -> Option<String> {
|
|||||||
}
|
}
|
||||||
|
|
||||||
fn size_from_sibling(s: &HFSibling) -> Option<u64> {
|
fn size_from_sibling(s: &HFSibling) -> Option<u64> {
|
||||||
if let Some(sz) = s.size { return Some(sz); }
|
if let Some(sz) = s.size {
|
||||||
if let Some(lfs) = &s.lfs { return lfs.size; }
|
return Some(sz);
|
||||||
|
}
|
||||||
|
if let Some(lfs) = &s.lfs {
|
||||||
|
return lfs.size;
|
||||||
|
}
|
||||||
None
|
None
|
||||||
}
|
}
|
||||||
|
|
||||||
fn expected_sha_from_tree(s: &HFTreeItem) -> Option<String> {
|
fn expected_sha_from_tree(s: &HFTreeItem) -> Option<String> {
|
||||||
if let Some(h) = &s.sha256 { return Some(h.to_lowercase()); }
|
if let Some(h) = &s.sha256 {
|
||||||
|
return Some(h.to_lowercase());
|
||||||
|
}
|
||||||
if let Some(lfs) = &s.lfs {
|
if let Some(lfs) = &s.lfs {
|
||||||
if let Some(h) = &lfs.sha256 { return Some(h.to_lowercase()); }
|
if let Some(h) = &lfs.sha256 {
|
||||||
|
return Some(h.to_lowercase());
|
||||||
|
}
|
||||||
if let Some(oid) = &lfs.oid {
|
if let Some(oid) = &lfs.oid {
|
||||||
if let Some(rest) = oid.strip_prefix("sha256:") {
|
if let Some(rest) = oid.strip_prefix("sha256:") {
|
||||||
return Some(rest.to_lowercase().to_string());
|
return Some(rest.to_lowercase().to_string());
|
||||||
@@ -126,8 +145,12 @@ fn expected_sha_from_tree(s: &HFTreeItem) -> Option<String> {
|
|||||||
}
|
}
|
||||||
|
|
||||||
fn size_from_tree(s: &HFTreeItem) -> Option<u64> {
|
fn size_from_tree(s: &HFTreeItem) -> Option<u64> {
|
||||||
if let Some(sz) = s.size { return Some(sz); }
|
if let Some(sz) = s.size {
|
||||||
if let Some(lfs) = &s.lfs { return lfs.size; }
|
return Some(sz);
|
||||||
|
}
|
||||||
|
if let Some(lfs) = &s.lfs {
|
||||||
|
return lfs.size;
|
||||||
|
}
|
||||||
None
|
None
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -136,12 +159,20 @@ fn fill_meta_via_head(repo: &str, name: &str) -> (Option<u64>, Option<String>) {
|
|||||||
.user_agent("PolyScribe/0.1 (+https://github.com/)")
|
.user_agent("PolyScribe/0.1 (+https://github.com/)")
|
||||||
.redirect(Policy::none())
|
.redirect(Policy::none())
|
||||||
.timeout(Duration::from_secs(30))
|
.timeout(Duration::from_secs(30))
|
||||||
.build() {
|
.build()
|
||||||
|
{
|
||||||
Ok(c) => c,
|
Ok(c) => c,
|
||||||
Err(_) => return (None, None),
|
Err(_) => return (None, None),
|
||||||
};
|
};
|
||||||
let url = format!("https://huggingface.co/{}/resolve/main/ggml-{}.bin", repo, name);
|
let url = format!(
|
||||||
let resp = match head_client.head(url).send().and_then(|r| r.error_for_status()) {
|
"https://huggingface.co/{}/resolve/main/ggml-{}.bin",
|
||||||
|
repo, name
|
||||||
|
);
|
||||||
|
let resp = match head_client
|
||||||
|
.head(url)
|
||||||
|
.send()
|
||||||
|
.and_then(|r| r.error_for_status())
|
||||||
|
{
|
||||||
Ok(r) => r,
|
Ok(r) => r,
|
||||||
Err(_) => return (None, None),
|
Err(_) => return (None, None),
|
||||||
};
|
};
|
||||||
@@ -179,21 +210,40 @@ fn fill_meta_via_head(repo: &str, name: &str) -> (Option<u64>, Option<String>) {
|
|||||||
fn hf_fetch_repo_models(client: &Client, repo: &'static str) -> Result<Vec<ModelEntry>> {
|
fn hf_fetch_repo_models(client: &Client, repo: &'static str) -> Result<Vec<ModelEntry>> {
|
||||||
qlog!("Fetching online data: listing models from {}...", repo);
|
qlog!("Fetching online data: listing models from {}...", repo);
|
||||||
// Prefer the tree endpoint for reliable size/hash metadata, then fall back to model metadata
|
// Prefer the tree endpoint for reliable size/hash metadata, then fall back to model metadata
|
||||||
let tree_url = format!("https://huggingface.co/api/models/{}/tree/main?recursive=1", repo);
|
let tree_url = format!(
|
||||||
|
"https://huggingface.co/api/models/{}/tree/main?recursive=1",
|
||||||
|
repo
|
||||||
|
);
|
||||||
let mut out: Vec<ModelEntry> = Vec::new();
|
let mut out: Vec<ModelEntry> = Vec::new();
|
||||||
|
|
||||||
match client.get(tree_url).send().and_then(|r| r.error_for_status()) {
|
match client
|
||||||
|
.get(tree_url)
|
||||||
|
.send()
|
||||||
|
.and_then(|r| r.error_for_status())
|
||||||
|
{
|
||||||
Ok(resp) => {
|
Ok(resp) => {
|
||||||
match resp.json::<Vec<HFTreeItem>>() {
|
match resp.json::<Vec<HFTreeItem>>() {
|
||||||
Ok(items) => {
|
Ok(items) => {
|
||||||
for it in items {
|
for it in items {
|
||||||
let path = it.path.clone();
|
let path = it.path.clone();
|
||||||
if !(path.starts_with("ggml-") && path.ends_with(".bin")) { continue; }
|
if !(path.starts_with("ggml-") && path.ends_with(".bin")) {
|
||||||
let model_name = path.trim_start_matches("ggml-").trim_end_matches(".bin").to_string();
|
continue;
|
||||||
|
}
|
||||||
|
let model_name = path
|
||||||
|
.trim_start_matches("ggml-")
|
||||||
|
.trim_end_matches(".bin")
|
||||||
|
.to_string();
|
||||||
let (base, subtype) = split_model_name(&model_name);
|
let (base, subtype) = split_model_name(&model_name);
|
||||||
let size = size_from_tree(&it).unwrap_or(0);
|
let size = size_from_tree(&it).unwrap_or(0);
|
||||||
let sha256 = expected_sha_from_tree(&it);
|
let sha256 = expected_sha_from_tree(&it);
|
||||||
out.push(ModelEntry { name: model_name, base, subtype, size, sha256, repo: repo.to_string() });
|
out.push(ModelEntry {
|
||||||
|
name: model_name,
|
||||||
|
base,
|
||||||
|
subtype,
|
||||||
|
size,
|
||||||
|
sha256,
|
||||||
|
repo: repo.to_string(),
|
||||||
|
});
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
Err(_) => { /* fall back below */ }
|
Err(_) => { /* fall back below */ }
|
||||||
@@ -210,30 +260,49 @@ fn hf_fetch_repo_models(client: &Client, repo: &'static str) -> Result<Vec<Model
|
|||||||
.and_then(|r| r.error_for_status())
|
.and_then(|r| r.error_for_status())
|
||||||
.context("Failed to query Hugging Face API")?;
|
.context("Failed to query Hugging Face API")?;
|
||||||
|
|
||||||
let info: HFRepoInfo = resp.json().context("Failed to parse Hugging Face API response")?;
|
let info: HFRepoInfo = resp
|
||||||
|
.json()
|
||||||
|
.context("Failed to parse Hugging Face API response")?;
|
||||||
|
|
||||||
if let Some(files) = info.siblings {
|
if let Some(files) = info.siblings {
|
||||||
for s in files {
|
for s in files {
|
||||||
let rf = s.rfilename.clone();
|
let rf = s.rfilename.clone();
|
||||||
if !(rf.starts_with("ggml-") && rf.ends_with(".bin")) { continue; }
|
if !(rf.starts_with("ggml-") && rf.ends_with(".bin")) {
|
||||||
let model_name = rf.trim_start_matches("ggml-").trim_end_matches(".bin").to_string();
|
continue;
|
||||||
|
}
|
||||||
|
let model_name = rf
|
||||||
|
.trim_start_matches("ggml-")
|
||||||
|
.trim_end_matches(".bin")
|
||||||
|
.to_string();
|
||||||
let (base, subtype) = split_model_name(&model_name);
|
let (base, subtype) = split_model_name(&model_name);
|
||||||
let size = size_from_sibling(&s).unwrap_or(0);
|
let size = size_from_sibling(&s).unwrap_or(0);
|
||||||
let sha256 = expected_sha_from_sibling(&s);
|
let sha256 = expected_sha_from_sibling(&s);
|
||||||
out.push(ModelEntry { name: model_name, base, subtype, size, sha256, repo: repo.to_string() });
|
out.push(ModelEntry {
|
||||||
|
name: model_name,
|
||||||
|
base,
|
||||||
|
subtype,
|
||||||
|
size,
|
||||||
|
sha256,
|
||||||
|
repo: repo.to_string(),
|
||||||
|
});
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Fill missing metadata (size/hash) via HEAD request if necessary
|
// Fill missing metadata (size/hash) via HEAD request if necessary
|
||||||
if out.iter().any(|m| m.size == 0 || m.sha256.is_none()) {
|
if out.iter().any(|m| m.size == 0 || m.sha256.is_none()) {
|
||||||
qlog!("Fetching online data: completing metadata checks for models in {}...", repo);
|
qlog!(
|
||||||
|
"Fetching online data: completing metadata checks for models in {}...",
|
||||||
|
repo
|
||||||
|
);
|
||||||
}
|
}
|
||||||
for m in out.iter_mut() {
|
for m in out.iter_mut() {
|
||||||
if m.size == 0 || m.sha256.is_none() {
|
if m.size == 0 || m.sha256.is_none() {
|
||||||
let (sz, sha) = fill_meta_via_head(&m.repo, &m.name);
|
let (sz, sha) = fill_meta_via_head(&m.repo, &m.name);
|
||||||
if m.size == 0 {
|
if m.size == 0 {
|
||||||
if let Some(s) = sz { m.size = s; }
|
if let Some(s) = sz {
|
||||||
|
m.size = s;
|
||||||
|
}
|
||||||
}
|
}
|
||||||
if m.sha256.is_none() {
|
if m.sha256.is_none() {
|
||||||
m.sha256 = sha;
|
m.sha256 = sha;
|
||||||
@@ -242,7 +311,12 @@ fn hf_fetch_repo_models(client: &Client, repo: &'static str) -> Result<Vec<Model
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Sort by base then subtype then name for stable listing
|
// Sort by base then subtype then name for stable listing
|
||||||
out.sort_by(|a, b| a.base.cmp(&b.base).then(a.subtype.cmp(&b.subtype)).then(a.name.cmp(&b.name)));
|
out.sort_by(|a, b| {
|
||||||
|
a.base
|
||||||
|
.cmp(&b.base)
|
||||||
|
.then(a.subtype.cmp(&b.subtype))
|
||||||
|
.then(a.name.cmp(&b.name))
|
||||||
|
});
|
||||||
Ok(out)
|
Ok(out)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -251,32 +325,42 @@ fn fetch_all_models(client: &Client) -> Result<Vec<ModelEntry>> {
|
|||||||
let mut v1 = hf_fetch_repo_models(client, "ggerganov/whisper.cpp")?; // main repo must succeed
|
let mut v1 = hf_fetch_repo_models(client, "ggerganov/whisper.cpp")?; // main repo must succeed
|
||||||
|
|
||||||
// Optional tinydiarize repo; ignore errors but log to stderr
|
// Optional tinydiarize repo; ignore errors but log to stderr
|
||||||
let mut v2: Vec<ModelEntry> = match hf_fetch_repo_models(client, "akashmjn/tinydiarize-whisper.cpp") {
|
let mut v2: Vec<ModelEntry> =
|
||||||
Ok(v) => v,
|
match hf_fetch_repo_models(client, "akashmjn/tinydiarize-whisper.cpp") {
|
||||||
Err(e) => {
|
Ok(v) => v,
|
||||||
qlog!("Warning: failed to fetch optional repo akashmjn/tinydiarize-whisper.cpp: {:#}", e);
|
Err(e) => {
|
||||||
Vec::new()
|
qlog!(
|
||||||
}
|
"Warning: failed to fetch optional repo akashmjn/tinydiarize-whisper.cpp: {:#}",
|
||||||
};
|
e
|
||||||
|
);
|
||||||
|
Vec::new()
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
v1.append(&mut v2);
|
v1.append(&mut v2);
|
||||||
|
|
||||||
// Deduplicate by name preferring ggerganov repo if duplicates
|
// Deduplicate by name preferring ggerganov repo if duplicates
|
||||||
let mut map: BTreeMap<String, ModelEntry> = BTreeMap::new();
|
let mut map: BTreeMap<String, ModelEntry> = BTreeMap::new();
|
||||||
for m in v1 {
|
for m in v1 {
|
||||||
map.entry(m.name.clone()).and_modify(|existing| {
|
map.entry(m.name.clone())
|
||||||
if existing.repo != "ggerganov/whisper.cpp" && m.repo == "ggerganov/whisper.cpp" {
|
.and_modify(|existing| {
|
||||||
*existing = m.clone();
|
if existing.repo != "ggerganov/whisper.cpp" && m.repo == "ggerganov/whisper.cpp" {
|
||||||
}
|
*existing = m.clone();
|
||||||
}).or_insert(m);
|
}
|
||||||
|
})
|
||||||
|
.or_insert(m);
|
||||||
}
|
}
|
||||||
|
|
||||||
let mut list: Vec<ModelEntry> = map.into_values().collect();
|
let mut list: Vec<ModelEntry> = map.into_values().collect();
|
||||||
list.sort_by(|a, b| a.base.cmp(&b.base).then(a.subtype.cmp(&b.subtype)).then(a.name.cmp(&b.name)));
|
list.sort_by(|a, b| {
|
||||||
|
a.base
|
||||||
|
.cmp(&b.base)
|
||||||
|
.then(a.subtype.cmp(&b.subtype))
|
||||||
|
.then(a.name.cmp(&b.name))
|
||||||
|
});
|
||||||
Ok(list)
|
Ok(list)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
fn format_model_list(models: &[ModelEntry]) -> String {
|
fn format_model_list(models: &[ModelEntry]) -> String {
|
||||||
let mut out = String::new();
|
let mut out = String::new();
|
||||||
out.push_str("Available ggml Whisper models:\n");
|
out.push_str("Available ggml Whisper models:\n");
|
||||||
@@ -305,7 +389,9 @@ fn format_model_list(models: &[ModelEntry]) -> String {
|
|||||||
));
|
));
|
||||||
idx += 1;
|
idx += 1;
|
||||||
}
|
}
|
||||||
out.push_str("\nEnter selection by indices (e.g., 1 3 5-7), or 'all', '*' for all, 'q' to cancel.\n");
|
out.push_str(
|
||||||
|
"\nEnter selection by indices (e.g., 1 3 5-7), or 'all', '*' for all, 'q' to cancel.\n",
|
||||||
|
);
|
||||||
out
|
out
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -335,21 +421,33 @@ fn prompt_select_models_two_stage(models: &[ModelEntry]) -> Result<Vec<ModelEntr
|
|||||||
eprint!("Select base (number or name, 'q' to cancel): ");
|
eprint!("Select base (number or name, 'q' to cancel): ");
|
||||||
io::stderr().flush().ok();
|
io::stderr().flush().ok();
|
||||||
let mut line = String::new();
|
let mut line = String::new();
|
||||||
io::stdin().read_line(&mut line).context("Failed to read base selection")?;
|
io::stdin()
|
||||||
|
.read_line(&mut line)
|
||||||
|
.context("Failed to read base selection")?;
|
||||||
let s = line.trim();
|
let s = line.trim();
|
||||||
if s.eq_ignore_ascii_case("q") || s.eq_ignore_ascii_case("quit") || s.eq_ignore_ascii_case("exit") {
|
if s.eq_ignore_ascii_case("q")
|
||||||
|
|| s.eq_ignore_ascii_case("quit")
|
||||||
|
|| s.eq_ignore_ascii_case("exit")
|
||||||
|
{
|
||||||
return Ok(Vec::new());
|
return Ok(Vec::new());
|
||||||
}
|
}
|
||||||
let chosen_base = if let Ok(i) = s.parse::<usize>() {
|
let chosen_base = if let Ok(i) = s.parse::<usize>() {
|
||||||
if i >= 1 && i <= bases.len() { Some(bases[i - 1].clone()) } else { None }
|
if i >= 1 && i <= bases.len() {
|
||||||
|
Some(bases[i - 1].clone())
|
||||||
|
} else {
|
||||||
|
None
|
||||||
|
}
|
||||||
} else if !s.is_empty() {
|
} else if !s.is_empty() {
|
||||||
// accept exact name match (case-insensitive)
|
// accept exact name match (case-insensitive)
|
||||||
bases.iter().find(|b| b.eq_ignore_ascii_case(s)).cloned()
|
bases.iter().find(|b| b.eq_ignore_ascii_case(s)).cloned()
|
||||||
} else { None };
|
} else {
|
||||||
|
None
|
||||||
|
};
|
||||||
|
|
||||||
if let Some(base) = chosen_base {
|
if let Some(base) = chosen_base {
|
||||||
// 2) Choose sub-type(s) within that base
|
// 2) Choose sub-type(s) within that base
|
||||||
let filtered: Vec<ModelEntry> = models.iter().filter(|m| m.base == base).cloned().collect();
|
let filtered: Vec<ModelEntry> =
|
||||||
|
models.iter().filter(|m| m.base == base).cloned().collect();
|
||||||
if filtered.is_empty() {
|
if filtered.is_empty() {
|
||||||
eprintln!("No models found for base '{}'.", base);
|
eprintln!("No models found for base '{}'.", base);
|
||||||
continue;
|
continue;
|
||||||
@@ -370,22 +468,32 @@ fn prompt_select_models_two_stage(models: &[ModelEntry]) -> Result<Vec<ModelEntry>> {
             eprint!("Selection: ");
             io::stderr().flush().ok();
             let mut line2 = String::new();
-            io::stdin().read_line(&mut line2).context("Failed to read selection")?;
+            io::stdin()
+                .read_line(&mut line2)
+                .context("Failed to read selection")?;
             let s2 = line2.trim().to_lowercase();
-            if s2 == "q" || s2 == "quit" || s2 == "exit" { return Ok(Vec::new()); }
+            if s2 == "q" || s2 == "quit" || s2 == "exit" {
+                return Ok(Vec::new());
+            }
             let mut selected: Vec<usize> = Vec::new();
             if s2 == "all" || s2 == "*" {
                 selected = (1..idx).collect();
             } else if !s2.is_empty() {
                 for part in s2.split(|c| c == ',' || c == ' ' || c == ';') {
                     let part = part.trim();
-                    if part.is_empty() { continue; }
+                    if part.is_empty() {
+                        continue;
+                    }
                     if let Some((a, b)) = part.split_once('-') {
                         if let (Ok(ia), Ok(ib)) = (a.parse::<usize>(), b.parse::<usize>()) {
-                            if ia >= 1 && ib < idx && ia <= ib { selected.extend(ia..=ib); }
+                            if ia >= 1 && ib < idx && ia <= ib {
+                                selected.extend(ia..=ib);
+                            }
                         }
                     } else if let Ok(i) = part.parse::<usize>() {
-                        if i >= 1 && i < idx { selected.push(i); }
+                        if i >= 1 && i < idx {
+                            selected.push(i);
+                        }
                     }
                 }
             }
@@ -395,12 +503,17 @@ fn prompt_select_models_two_stage(models: &[ModelEntry]) -> Result<Vec<ModelEntry>> {
                 eprintln!("No valid selection. Please try again or 'q' to cancel.");
                 continue;
             }
-            let chosen: Vec<ModelEntry> = selected.into_iter().map(|i| filtered[index_map[i - 1]].clone()).collect();
+            let chosen: Vec<ModelEntry> = selected
+                .into_iter()
+                .map(|i| filtered[index_map[i - 1]].clone())
+                .collect();
             return Ok(chosen);
         }
     } else {
-        eprintln!("Invalid base selection. Please enter a number from 1-{} or a base name.", bases.len());
-        continue;
+        eprintln!(
+            "Invalid base selection. Please enter a number from 1-{} or a base name.",
+            bases.len()
+        );
     }
 }
 }
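Taken together, the two stages above accept `all`/`*`, single indices, or `a-b` ranges separated by commas, spaces, or semicolons. A minimal standalone sketch of that parsing rule (the `parse_selection` helper is hypothetical; the real loop additionally tracks an `idx` bound and an `index_map`):

```rust
/// Parse a selection string such as "1,3-5 7" into 1-based indices.
/// Out-of-range or malformed parts are silently skipped, mirroring the
/// tolerant behavior of the selection loop in the diff above.
fn parse_selection(input: &str, max: usize) -> Vec<usize> {
    let mut selected = Vec::new();
    for part in input.split(|c| c == ',' || c == ' ' || c == ';') {
        let part = part.trim();
        if part.is_empty() {
            continue;
        }
        if let Some((a, b)) = part.split_once('-') {
            // Range form "a-b", inclusive on both ends.
            if let (Ok(ia), Ok(ib)) = (a.parse::<usize>(), b.parse::<usize>()) {
                if ia >= 1 && ib <= max && ia <= ib {
                    selected.extend(ia..=ib);
                }
            }
        } else if let Ok(i) = part.parse::<usize>() {
            // Single index form.
            if i >= 1 && i <= max {
                selected.push(i);
            }
        }
    }
    selected
}

fn main() {
    assert_eq!(parse_selection("1,3-5 7", 9), vec![1, 3, 4, 5, 7]);
    assert_eq!(parse_selection("0, 99, x", 9), Vec::<usize>::new());
}
```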
@@ -413,52 +526,30 @@ fn compute_file_sha256_hex(path: &Path) -> Result<String> {
     let mut buf = [0u8; 1024 * 128];
     loop {
         let n = reader.read(&mut buf).context("Read error during hashing")?;
-        if n == 0 { break; }
+        if n == 0 {
+            break;
+        }
         hasher.update(&buf[..n]);
     }
     Ok(to_hex_lower(&hasher.finalize()))
 }

-fn models_dir_path() -> std::path::PathBuf {
-    // Highest priority: explicit override
-    if let Ok(p) = env::var("POLYSCRIBE_MODELS_DIR") {
-        let pb = std::path::PathBuf::from(p);
-        if !pb.as_os_str().is_empty() { return pb; }
-    }
-    // In debug builds, keep local ./models for convenience
-    if cfg!(debug_assertions) {
-        return std::path::PathBuf::from("models");
-    }
-    // In release builds, choose a user-writable data directory
-    if let Ok(xdg) = env::var("XDG_DATA_HOME") {
-        if !xdg.is_empty() {
-            return std::path::PathBuf::from(xdg).join("polyscribe").join("models");
-        }
-    }
-    if let Ok(home) = env::var("HOME") {
-        if !home.is_empty() {
-            return std::path::PathBuf::from(home)
-                .join(".local")
-                .join("share")
-                .join("polyscribe")
-                .join("models");
-        }
-    }
-    // Last resort fallback
-    std::path::PathBuf::from("models")
-}
-
+/// Interactively list and download Whisper models from Hugging Face into the models directory.
 pub fn run_interactive_model_downloader() -> Result<()> {
-    let models_dir_buf = models_dir_path();
+    let models_dir_buf = crate::models_dir_path();
     let models_dir = models_dir_buf.as_path();
-    if !models_dir.exists() { create_dir_all(models_dir).context("Failed to create models directory")?; }
+    if !models_dir.exists() {
+        create_dir_all(models_dir).context("Failed to create models directory")?;
+    }
     let client = Client::builder()
         .user_agent("PolyScribe/0.1 (+https://github.com/)")
         .timeout(std::time::Duration::from_secs(600))
         .build()
         .context("Failed to build HTTP client")?;

-    qlog!("Fetching online data: contacting Hugging Face to retrieve available models (this may take a moment)...");
+    qlog!(
+        "Fetching online data: contacting Hugging Face to retrieve available models (this may take a moment)..."
+    );
     let models = fetch_all_models(&client)?;
     if models.is_empty() {
         qlog!("No models found on Hugging Face listing. Please try again later.");
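The hunk above deletes `models_dir_path` from this module; callers now go through `crate::models_dir_path`. For readers skimming the diff, a consolidated sketch of the lookup order the removed body implemented (explicit override, debug-build local dir, then XDG/HOME data dirs), reconstructed from the left-hand side:

```rust
use std::env;
use std::path::PathBuf;

// Sketch of the removed function's resolution order; the real
// implementation now lives at the crate root as `crate::models_dir_path`.
fn models_dir_path() -> PathBuf {
    // 1) Explicit override always wins.
    if let Ok(p) = env::var("POLYSCRIBE_MODELS_DIR") {
        if !p.is_empty() {
            return PathBuf::from(p);
        }
    }
    // 2) Debug builds keep a local ./models for convenience.
    if cfg!(debug_assertions) {
        return PathBuf::from("models");
    }
    // 3) Release builds prefer user-writable data directories.
    if let Ok(xdg) = env::var("XDG_DATA_HOME") {
        if !xdg.is_empty() {
            return PathBuf::from(xdg).join("polyscribe").join("models");
        }
    }
    if let Ok(home) = env::var("HOME") {
        if !home.is_empty() {
            return PathBuf::from(home)
                .join(".local")
                .join("share")
                .join("polyscribe")
                .join("models");
        }
    }
    // 4) Last-resort fallback.
    PathBuf::from("models")
}

fn main() {
    println!("{}", models_dir_path().display());
}
```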
@@ -470,12 +561,15 @@ pub fn run_interactive_model_downloader() -> Result<()> {
         return Ok(());
     }
     for m in selected {
-        if let Err(e) = download_one_model(&client, models_dir, &m) { qlog!("Error: {:#}", e); }
+        if let Err(e) = download_one_model(&client, models_dir, &m) {
+            qlog!("Error: {:#}", e);
+        }
     }
     Ok(())
 }

-pub fn download_one_model(client: &Client, models_dir: &Path, entry: &ModelEntry) -> Result<()> {
+/// Download a single model entry into the given models directory, verifying SHA-256 when available.
+fn download_one_model(client: &Client, models_dir: &Path, entry: &ModelEntry) -> Result<()> {
     let final_path = models_dir.join(format!("ggml-{}.bin", entry.name));

     // If the model already exists, verify against online metadata
@@ -497,9 +591,10 @@ pub fn download_one_model(client: &Client, models_dir: &Path, entry: &ModelEntry) -> Result<()> {
             }
             Err(e) => {
                 qlog!(
                     "Warning: failed to hash existing {}: {}. Will re-download to ensure correctness.",
-                    final_path.display(), e
+                    final_path.display(),
+                    e
                 );
             }
         }
     } else if entry.size > 0 {
@@ -508,20 +603,24 @@ pub fn download_one_model(client: &Client, models_dir: &Path, entry: &ModelEntry) -> Result<()> {
                 if md.len() == entry.size {
                     qlog!(
                         "Model {} appears up-to-date by size ({}).",
-                        final_path.display(), entry.size
+                        final_path.display(),
+                        entry.size
                     );
                     return Ok(());
                 } else {
                     qlog!(
                         "Local model {} size ({}) differs from online ({}). Updating...",
-                        final_path.display(), md.len(), entry.size
+                        final_path.display(),
+                        md.len(),
+                        entry.size
                     );
                 }
             }
             Err(e) => {
                 qlog!(
                     "Warning: failed to stat existing {}: {}. Will re-download to ensure correctness.",
-                    final_path.display(), e
+                    final_path.display(),
+                    e
                 );
             }
         }
@@ -540,9 +639,16 @@ pub fn download_one_model(client: &Client, models_dir: &Path, entry: &ModelEntry) -> Result<()> {
     if src_path.exists() {
         qlog!("Copying {} from {}...", entry.name, src_path.display());
         let tmp_path = models_dir.join(format!("ggml-{}.bin.part", entry.name));
-        if tmp_path.exists() { let _ = std::fs::remove_file(&tmp_path); }
-        std::fs::copy(&src_path, &tmp_path)
-            .with_context(|| format!("Failed to copy from {} to {}", src_path.display(), tmp_path.display()))?;
+        if tmp_path.exists() {
+            let _ = std::fs::remove_file(&tmp_path);
+        }
+        std::fs::copy(&src_path, &tmp_path).with_context(|| {
+            format!(
+                "Failed to copy from {} to {}",
+                src_path.display(),
+                tmp_path.display()
+            )
+        })?;
         // Verify hash if available
         if let Some(expected) = &entry.sha256 {
             let got = compute_file_sha256_hex(&tmp_path)?;
@@ -550,12 +656,16 @@ pub fn download_one_model(client: &Client, models_dir: &Path, entry: &ModelEntry) -> Result<()> {
                 let _ = std::fs::remove_file(&tmp_path);
                 return Err(anyhow!(
                     "SHA-256 mismatch for {} (copied): expected {}, got {}",
-                    entry.name, expected, got
+                    entry.name,
+                    expected,
+                    got
                 ));
             }
         }
         // Replace existing file safely
-        if final_path.exists() { let _ = std::fs::remove_file(&final_path); }
+        if final_path.exists() {
+            let _ = std::fs::remove_file(&final_path);
+        }
         std::fs::rename(&tmp_path, &final_path)
             .with_context(|| format!("Failed to move into place: {}", final_path.display()))?;
         qlog!("Saved: {}", final_path.display());
@@ -563,8 +673,16 @@ pub fn download_one_model(client: &Client, models_dir: &Path, entry: &ModelEntry) -> Result<()> {
         }
     }

-    let url = format!("https://huggingface.co/{}/resolve/main/ggml-{}.bin", entry.repo, entry.name);
-    qlog!("Downloading {} ({} | {})...", entry.name, human_size(entry.size), url);
+    let url = format!(
+        "https://huggingface.co/{}/resolve/main/ggml-{}.bin",
+        entry.repo, entry.name
+    );
+    qlog!(
+        "Downloading {} ({} | {})...",
+        entry.name,
+        human_size(entry.size),
+        url
+    );
     let mut resp = client
         .get(url)
         .send()
@@ -577,14 +695,16 @@ pub fn download_one_model(client: &Client, models_dir: &Path, entry: &ModelEntry) -> Result<()> {
     }
     let mut file = std::io::BufWriter::new(
         File::create(&tmp_path)
-            .with_context(|| format!("Failed to create {}", tmp_path.display()))?
+            .with_context(|| format!("Failed to create {}", tmp_path.display()))?,
     );

     let mut hasher = Sha256::new();
     let mut buf = [0u8; 1024 * 128];
     loop {
         let n = resp.read(&mut buf).context("Network read error")?;
-        if n == 0 { break; }
+        if n == 0 {
+            break;
+        }
         hasher.update(&buf[..n]);
         file.write_all(&buf[..n]).context("Write error")?;
     }
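The download loop hashes while it writes, so the SHA-256 check needs no second pass over the file. A self-contained sketch of that pattern over any `Read`/`Write` pair (assumes the `sha2` crate, which the code above evidently uses; `copy_and_hash` is a hypothetical helper):

```rust
use sha2::{Digest, Sha256};
use std::io::{Read, Write};

// Copy `src` into `dst` while feeding every chunk to a SHA-256 hasher;
// returns the lowercase hex digest alongside the byte count.
fn copy_and_hash<R: Read, W: Write>(src: &mut R, dst: &mut W) -> std::io::Result<(String, u64)> {
    let mut hasher = Sha256::new();
    let mut buf = [0u8; 1024 * 128];
    let mut total = 0u64;
    loop {
        let n = src.read(&mut buf)?;
        if n == 0 {
            break;
        }
        hasher.update(&buf[..n]);
        dst.write_all(&buf[..n])?;
        total += n as u64;
    }
    let hex = hasher
        .finalize()
        .iter()
        .map(|b| format!("{:02x}", b))
        .collect::<String>();
    Ok((hex, total))
}

fn main() -> std::io::Result<()> {
    let mut src: &[u8] = b"hello";
    let mut dst = Vec::new();
    let (hex, n) = copy_and_hash(&mut src, &mut dst)?;
    assert_eq!(n, 5);
    println!("{hex}");
    Ok(())
}
```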
@@ -596,11 +716,16 @@ pub fn download_one_model(client: &Client, models_dir: &Path, entry: &ModelEntry) -> Result<()> {
             let _ = std::fs::remove_file(&tmp_path);
             return Err(anyhow!(
                 "SHA-256 mismatch for {}: expected {}, got {}",
-                entry.name, expected, got
+                entry.name,
+                expected,
+                got
             ));
         }
     } else {
-        qlog!("Warning: no SHA-256 available for {}. Skipping verification.", entry.name);
+        qlog!(
+            "Warning: no SHA-256 available for {}. Skipping verification.",
+            entry.name
+        );
     }
     // Replace existing file safely
     if final_path.exists() {
@@ -612,8 +737,9 @@ pub fn download_one_model(client: &Client, models_dir: &Path, entry: &ModelEntry) -> Result<()> {
     Ok(())
 }

+/// Update locally stored models by re-downloading when size or hash does not match online metadata.
 pub fn update_local_models() -> Result<()> {
-    let models_dir_buf = models_dir_path();
+    let models_dir_buf = crate::models_dir_path();
     let models_dir = models_dir_buf.as_path();
     if !models_dir.exists() {
         create_dir_all(models_dir).context("Failed to create models directory")?;
@@ -627,13 +753,14 @@ pub fn update_local_models() -> Result<()> {
         .context("Failed to build HTTP client")?;

     // Obtain manifest: env override or online fetch
-    let models: Vec<ModelEntry> = if let Ok(manifest_path) = env::var("POLYSCRIBE_MODELS_MANIFEST") {
+    let models: Vec<ModelEntry> = if let Ok(manifest_path) = env::var("POLYSCRIBE_MODELS_MANIFEST")
+    {
         let data = std::fs::read_to_string(&manifest_path)
             .with_context(|| format!("Failed to read manifest at {}", manifest_path))?;
         let mut list: Vec<ModelEntry> = serde_json::from_str(&data)
             .with_context(|| format!("Invalid JSON manifest: {}", manifest_path))?;
         // sort for stability
-        list.sort_by(|a,b| a.name.cmp(&b.name));
+        list.sort_by(|a, b| a.name.cmp(&b.name));
         list
     } else {
         fetch_all_models(&client)?
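`POLYSCRIBE_MODELS_MANIFEST` lets tests and offline runs substitute a local JSON manifest for the Hugging Face listing. A hedged sketch of a minimal manifest that would deserialize into `Vec<ModelEntry>`, using only the fields visible in this diff (`name`, `base`, `subtype`, `size`, `sha256`, `repo`); whether any field is optional besides `sha256` is an assumption:

```rust
use serde_json::json;

fn main() {
    // One-entry manifest mirroring the shape used by the tests below.
    let manifest = json!([
        {
            "name": "tiny.en-q5_1",
            "base": "tiny",
            "subtype": "en-q5_1",
            "size": 1024 * 1024,
            "sha256": "0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdef",
            "repo": "ggerganov/whisper.cpp"
        }
    ]);
    std::fs::write(
        "manifest.json",
        serde_json::to_string_pretty(&manifest).unwrap(),
    )
    .unwrap();
    // Point the updater at it instead of the network:
    //   POLYSCRIBE_MODELS_MANIFEST=manifest.json polyscribe ...
}
```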
@@ -641,7 +768,9 @@ pub fn update_local_models() -> Result<()> {

     // Map name -> entry for fast lookup
     let mut map: BTreeMap<String, ModelEntry> = BTreeMap::new();
-    for m in models { map.insert(m.name.clone(), m); }
+    for m in models {
+        map.insert(m.name.clone(), m);
+    }

     // Scan local ggml-*.bin models
     let rd = std::fs::read_dir(models_dir)
@@ -649,10 +778,20 @@ pub fn update_local_models() -> Result<()> {
     for entry in rd {
         let entry = entry?;
         let path = entry.path();
-        if !path.is_file() { continue; }
-        let fname = match path.file_name().and_then(|s| s.to_str()) { Some(s) => s.to_string(), None => continue };
-        if !fname.starts_with("ggml-") || !fname.ends_with(".bin") { continue; }
-        let model_name = fname.trim_start_matches("ggml-").trim_end_matches(".bin").to_string();
+        if !path.is_file() {
+            continue;
+        }
+        let fname = match path.file_name().and_then(|s| s.to_str()) {
+            Some(s) => s.to_string(),
+            None => continue,
+        };
+        if !fname.starts_with("ggml-") || !fname.ends_with(".bin") {
+            continue;
+        }
+        let model_name = fname
+            .trim_start_matches("ggml-")
+            .trim_end_matches(".bin")
+            .to_string();

         if let Some(remote) = map.get(&model_name) {
             // If SHA256 available, verify and update if mismatch
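Local model files follow the `ggml-<name>.bin` convention, and the loop above recovers `<name>` with `trim_start_matches`/`trim_end_matches` after checking both affixes. A quick check of that round-trip:

```rust
fn main() {
    let fname = "ggml-tiny.en-q5_1.bin";
    // Both affix checks must pass before trimming, as in the loop above.
    assert!(fname.starts_with("ggml-") && fname.ends_with(".bin"));
    let model_name = fname
        .trim_start_matches("ggml-")
        .trim_end_matches(".bin")
        .to_string();
    assert_eq!(model_name, "tiny.en-q5_1");
}
```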
@@ -664,11 +803,11 @@ pub fn update_local_models() -> Result<()> {
                         continue;
                     } else {
                         qlog!(
                             "{} hash differs (local {}.. != remote {}..). Updating...",
                             fname,
                             &local_hash[..std::cmp::min(8, local_hash.len())],
                             &expected[..std::cmp::min(8, expected.len())]
                         );
                     }
                 }
                 Err(e) => {
@@ -683,7 +822,12 @@ pub fn update_local_models() -> Result<()> {
                         continue;
                     }
                     Ok(md) => {
-                        qlog!("{} size {} differs from remote {}. Updating...", fname, md.len(), remote.size);
+                        qlog!(
+                            "{} size {} differs from remote {}. Updating...",
+                            fname,
+                            md.len(),
+                            remote.size
+                        );
                         download_one_model(&client, models_dir, remote)?;
                     }
                     Err(e) => {
@@ -702,20 +846,43 @@ pub fn update_local_models() -> Result<()> {
     Ok(())
 }


 #[cfg(test)]
 mod tests {
     use super::*;
-    use tempfile::tempdir;
     use std::fs;
-    use std::io::Write;
+    use tempfile::tempdir;

     #[test]
     fn test_format_model_list_spacing_and_structure() {
         let models = vec![
-            ModelEntry { name: "tiny.en-q5_1".to_string(), base: "tiny".to_string(), subtype: "en-q5_1".to_string(), size: 1024*1024, sha256: Some("0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdef".to_string()), repo: "ggerganov/whisper.cpp".to_string() },
-            ModelEntry { name: "tiny-q5_1".to_string(), base: "tiny".to_string(), subtype: "q5_1".to_string(), size: 2048, sha256: None, repo: "ggerganov/whisper.cpp".to_string() },
-            ModelEntry { name: "base.en-q5_1".to_string(), base: "base".to_string(), subtype: "en-q5_1".to_string(), size: 10, sha256: Some("aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa".to_string()), repo: "akashmjn/tinydiarize-whisper.cpp".to_string() },
+            ModelEntry {
+                name: "tiny.en-q5_1".to_string(),
+                base: "tiny".to_string(),
+                subtype: "en-q5_1".to_string(),
+                size: 1024 * 1024,
+                sha256: Some(
+                    "0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdef".to_string(),
+                ),
+                repo: "ggerganov/whisper.cpp".to_string(),
+            },
+            ModelEntry {
+                name: "tiny-q5_1".to_string(),
+                base: "tiny".to_string(),
+                subtype: "q5_1".to_string(),
+                size: 2048,
+                sha256: None,
+                repo: "ggerganov/whisper.cpp".to_string(),
+            },
+            ModelEntry {
+                name: "base.en-q5_1".to_string(),
+                base: "base".to_string(),
+                subtype: "en-q5_1".to_string(),
+                size: 10,
+                sha256: Some(
+                    "aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa".to_string(),
+                ),
+                repo: "akashmjn/tinydiarize-whisper.cpp".to_string(),
+            },
         ];
         let s = format_model_list(&models);
         // Header present
@@ -724,7 +891,10 @@ mod tests {
         assert!(s.contains("\ntiny:\n"));
         assert!(s.contains("\nbase:\n"));
         // No immediate double space before a bracket after parenthesis
-        assert!(!s.contains(")  ["), "should not have double space immediately before bracket");
+        assert!(
+            !s.contains(")  ["),
+            "should not have double space immediately before bracket"
+        );
         // Lines contain normalized spacing around pipes and no hash
         assert!(s.contains("[ggerganov/whisper.cpp | 1.00 MiB]"));
         assert!(s.contains("[ggerganov/whisper.cpp | 2.00 KiB]"));
@@ -748,7 +918,9 @@ mod tests {
         hasher.update(data);
         let out = hasher.finalize();
         let mut s = String::new();
-        for b in out { s.push_str(&format!("{:02x}", b)); }
+        for b in out {
+            s.push_str(&format!("{:02x}", b));
+        }
         s
     }

@@ -786,7 +958,11 @@ mod tests {
                 "repo": "ggerganov/whisper.cpp"
             }
         ]);
-        fs::write(&manifest_path, serde_json::to_string_pretty(&manifest).unwrap()).unwrap();
+        fs::write(
+            &manifest_path,
+            serde_json::to_string_pretty(&manifest).unwrap(),
+        )
+        .unwrap();

         // Set env vars to force offline behavior and directories
         unsafe {
@@ -807,34 +983,54 @@ mod tests {
     #[cfg(debug_assertions)]
     fn test_models_dir_path_default_debug_and_env_override_models_mod() {
         // clear override
-        unsafe { std::env::remove_var("POLYSCRIBE_MODELS_DIR"); }
-        assert_eq!(super::models_dir_path(), std::path::PathBuf::from("models"));
+        unsafe {
+            std::env::remove_var("POLYSCRIBE_MODELS_DIR");
+        }
+        assert_eq!(crate::models_dir_path(), std::path::PathBuf::from("models"));
         // override
         let tmp = tempfile::tempdir().unwrap();
-        unsafe { std::env::set_var("POLYSCRIBE_MODELS_DIR", tmp.path()); }
-        assert_eq!(super::models_dir_path(), tmp.path().to_path_buf());
+        unsafe {
+            std::env::set_var("POLYSCRIBE_MODELS_DIR", tmp.path());
+        }
+        assert_eq!(crate::models_dir_path(), tmp.path().to_path_buf());
         // cleanup
-        unsafe { std::env::remove_var("POLYSCRIBE_MODELS_DIR"); }
+        unsafe {
+            std::env::remove_var("POLYSCRIBE_MODELS_DIR");
+        }
     }

     #[test]
     #[cfg(not(debug_assertions))]
     fn test_models_dir_path_default_release_models_mod() {
-        unsafe { std::env::remove_var("POLYSCRIBE_MODELS_DIR"); }
+        unsafe {
+            std::env::remove_var("POLYSCRIBE_MODELS_DIR");
+        }
         // With XDG_DATA_HOME set
         let tmp_xdg = tempfile::tempdir().unwrap();
         unsafe {
             std::env::set_var("XDG_DATA_HOME", tmp_xdg.path());
             std::env::remove_var("HOME");
         }
-        assert_eq!(super::models_dir_path(), std::path::PathBuf::from(tmp_xdg.path()).join("polyscribe").join("models"));
+        assert_eq!(
+            crate::models_dir_path(),
+            std::path::PathBuf::from(tmp_xdg.path())
+                .join("polyscribe")
+                .join("models")
+        );
         // With HOME fallback
         let tmp_home = tempfile::tempdir().unwrap();
         unsafe {
             std::env::remove_var("XDG_DATA_HOME");
             std::env::set_var("HOME", tmp_home.path());
         }
-        assert_eq!(super::models_dir_path(), std::path::PathBuf::from(tmp_home.path()).join(".local").join("share").join("polyscribe").join("models"));
+        assert_eq!(
+            super::models_dir_path(),
+            std::path::PathBuf::from(tmp_home.path())
+                .join(".local")
+                .join("share")
+                .join("polyscribe")
+                .join("models")
+        );
         unsafe {
             std::env::remove_var("XDG_DATA_HOME");
             std::env::remove_var("HOME");
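The `unsafe { ... }` blocks around `set_var`/`remove_var` are not stylistic: under the 2024 edition this project targets, `std::env::set_var` and `std::env::remove_var` are `unsafe` because mutating the process environment can race with concurrent reads from other threads. A minimal illustration:

```rust
fn main() {
    // Rust 2024 requires an unsafe block for environment mutation; the
    // caller promises no other thread touches the environment concurrently.
    unsafe {
        std::env::set_var("POLYSCRIBE_MODELS_DIR", "/tmp/models");
    }
    assert_eq!(
        std::env::var("POLYSCRIBE_MODELS_DIR").as_deref(),
        Ok("/tmp/models")
    );
    unsafe {
        std::env::remove_var("POLYSCRIBE_MODELS_DIR");
    }
}
```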
@@ -1,6 +1,8 @@
 use std::process::Command;

-fn bin() -> &'static str { env!("CARGO_BIN_EXE_polyscribe") }
+fn bin() -> &'static str {
+    env!("CARGO_BIN_EXE_polyscribe")
+}

 #[test]
 fn aux_completions_bash_outputs_script() {
@@ -9,11 +11,21 @@ fn aux_completions_bash_outputs_script() {
         .arg("bash")
         .output()
         .expect("failed to run polyscribe completions bash");
-    assert!(out.status.success(), "completions bash exited with failure: {:?}", out.status);
+    assert!(
+        out.status.success(),
+        "completions bash exited with failure: {:?}",
+        out.status
+    );
     let stdout = String::from_utf8(out.stdout).expect("stdout not utf-8");
-    assert!(!stdout.trim().is_empty(), "completions bash stdout is empty");
+    assert!(
+        !stdout.trim().is_empty(),
+        "completions bash stdout is empty"
+    );
     // Heuristic: bash completion scripts often contain 'complete -F' lines
-    assert!(stdout.contains("complete") || stdout.contains("_polyscribe"), "bash completion script did not contain expected markers");
+    assert!(
+        stdout.contains("complete") || stdout.contains("_polyscribe"),
+        "bash completion script did not contain expected markers"
+    );
 }

 #[test]
@@ -23,11 +35,18 @@ fn aux_completions_zsh_outputs_script() {
         .arg("zsh")
         .output()
         .expect("failed to run polyscribe completions zsh");
-    assert!(out.status.success(), "completions zsh exited with failure: {:?}", out.status);
+    assert!(
+        out.status.success(),
+        "completions zsh exited with failure: {:?}",
+        out.status
+    );
     let stdout = String::from_utf8(out.stdout).expect("stdout not utf-8");
     assert!(!stdout.trim().is_empty(), "completions zsh stdout is empty");
     // Heuristic: zsh completion scripts often start with '#compdef'
-    assert!(stdout.contains("#compdef") || stdout.contains("#compdef polyscribe"), "zsh completion script did not contain expected markers");
+    assert!(
+        stdout.contains("#compdef") || stdout.contains("#compdef polyscribe"),
+        "zsh completion script did not contain expected markers"
+    );
 }

 #[test]
@@ -36,10 +55,21 @@ fn aux_man_outputs_roff() {
         .arg("man")
         .output()
         .expect("failed to run polyscribe man");
-    assert!(out.status.success(), "man exited with failure: {:?}", out.status);
+    assert!(
+        out.status.success(),
+        "man exited with failure: {:?}",
+        out.status
+    );
     let stdout = String::from_utf8(out.stdout).expect("stdout not utf-8");
     assert!(!stdout.trim().is_empty(), "man stdout is empty");
     // clap_mangen typically emits roff with .TH and/or section headers
-    let looks_like_roff = stdout.contains(".TH ") || stdout.starts_with(".TH") || stdout.contains(".SH NAME") || stdout.contains(".SH SYNOPSIS");
-    assert!(looks_like_roff, "man output does not look like a roff manpage; got: {}", &stdout.lines().take(3).collect::<Vec<_>>().join(" | "));
+    let looks_like_roff = stdout.contains(".TH ")
+        || stdout.starts_with(".TH")
+        || stdout.contains(".SH NAME")
+        || stdout.contains(".SH SYNOPSIS");
+    assert!(
+        looks_like_roff,
+        "man output does not look like a roff manpage; got: {}",
+        &stdout.lines().take(3).collect::<Vec<_>>().join(" | ")
+    );
 }
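These tests assert only on output markers (`complete`, `#compdef`, `.TH`), not on how the subcommands produce them. A plausible sketch of the generating side using `clap_complete` and `clap_mangen` (the crate the comment above names); the `Cli` definition and the exact wiring inside PolyScribe are assumptions:

```rust
use std::io::Write;

use clap::CommandFactory;
use clap_complete::{generate, Shell};
use clap_mangen::Man;

/// Hypothetical CLI surface; the real PolyScribe command tree is larger.
#[derive(clap::Parser)]
#[command(name = "polyscribe")]
struct Cli {
    /// Input media or JSON files.
    inputs: Vec<String>,
}

fn main() -> std::io::Result<()> {
    // Shell completion script (bash here); zsh output starts with `#compdef`.
    let mut cmd = Cli::command();
    generate(Shell::Bash, &mut cmd, "polyscribe", &mut std::io::stdout());

    // Roff manpage; clap_mangen emits `.TH` / `.SH NAME` sections.
    let mut page = Vec::new();
    Man::new(Cli::command()).render(&mut page)?;
    std::io::stdout().write_all(&page)
}
```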
@@ -30,7 +30,9 @@ impl TestDir {
         fs::create_dir_all(&p).expect("Failed to create temp dir");
         TestDir(p)
     }
-    fn path(&self) -> &Path { &self.0 }
+    fn path(&self) -> &Path {
+        &self.0
+    }
 }
 impl Drop for TestDir {
     fn drop(&mut self) {
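`TestDir` is a small RAII guard: the directory is created up front and removed in `Drop`, so each test cleans up after itself even on panic. A self-contained version of the same pattern (the `new(name)` signature is an assumption; only `path` and the `Drop` impl are visible in this diff):

```rust
use std::fs;
use std::path::{Path, PathBuf};

struct TestDir(PathBuf);

impl TestDir {
    fn new(name: &str) -> Self {
        let p = std::env::temp_dir().join(name);
        fs::create_dir_all(&p).expect("Failed to create temp dir");
        TestDir(p)
    }
    fn path(&self) -> &Path {
        &self.0
    }
}

impl Drop for TestDir {
    fn drop(&mut self) {
        // Best-effort cleanup; errors on teardown are deliberately ignored.
        let _ = fs::remove_dir_all(&self.0);
    }
}

fn main() {
    let dir = TestDir::new("polyscribe-demo");
    fs::write(dir.path().join("probe.txt"), "ok").unwrap();
} // `dir` is dropped here and the directory removed.
```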
@@ -79,14 +81,32 @@ fn cli_writes_separate_outputs_by_default() {
     for e in entries {
         let p = e.unwrap().path();
         if let Some(name) = p.file_name().and_then(|s| s.to_str()) {
-            if name.ends_with(".json") { json_paths.push(p.clone()); }
-            if name.ends_with(".toml") { count_toml += 1; }
-            if name.ends_with(".srt") { count_srt += 1; }
+            if name.ends_with(".json") {
+                json_paths.push(p.clone());
+            }
+            if name.ends_with(".toml") {
+                count_toml += 1;
+            }
+            if name.ends_with(".srt") {
+                count_srt += 1;
+            }
         }
     }
-    assert!(json_paths.len() >= 2, "expected at least 2 JSON files, found {}", json_paths.len());
-    assert!(count_toml >= 2, "expected at least 2 TOML files, found {}", count_toml);
-    assert!(count_srt >= 2, "expected at least 2 SRT files, found {}", count_srt);
+    assert!(
+        json_paths.len() >= 2,
+        "expected at least 2 JSON files, found {}",
+        json_paths.len()
+    );
+    assert!(
+        count_toml >= 2,
+        "expected at least 2 TOML files, found {}",
+        count_toml
+    );
+    assert!(
+        count_srt >= 2,
+        "expected at least 2 SRT files, found {}",
+        count_srt
+    );

     // JSON contents are assumed valid if files exist; detailed parsing is covered elsewhere

@@ -124,9 +144,15 @@ fn cli_merges_json_inputs_with_flag_and_writes_outputs_to_temp_dir() {
     for e in entries {
         let p = e.unwrap().path();
         if let Some(name) = p.file_name().and_then(|s| s.to_str()) {
-            if name.ends_with("_out.json") { found_json = Some(p.clone()); }
-            if name.ends_with("_out.toml") { found_toml = Some(p.clone()); }
-            if name.ends_with("_out.srt") { found_srt = Some(p.clone()); }
+            if name.ends_with("_out.json") {
+                found_json = Some(p.clone());
+            }
+            if name.ends_with("_out.toml") {
+                found_toml = Some(p.clone());
+            }
+            if name.ends_with("_out.srt") {
+                found_srt = Some(p.clone());
+            }
         }
     }
     let _json_path = found_json.expect("missing JSON output in temp dir");
@@ -154,7 +180,10 @@ fn cli_prints_json_to_stdout_when_no_output_path_merge_mode() {
     assert!(output.status.success(), "CLI failed");

     let stdout = String::from_utf8(output.stdout).expect("stdout not UTF-8");
-    assert!(stdout.contains("\"items\""), "stdout should contain items JSON array");
+    assert!(
+        stdout.contains("\"items\""),
+        "stdout should contain items JSON array"
+    );
 }

 #[test]
@@ -187,16 +216,36 @@ fn cli_merge_and_separate_writes_both_kinds_of_outputs() {
     for e in entries {
         let p = e.unwrap().path();
         if let Some(name) = p.file_name().and_then(|s| s.to_str()) {
-            if name.ends_with(".json") { json_count += 1; }
-            if name.ends_with(".toml") { toml_count += 1; }
-            if name.ends_with(".srt") { srt_count += 1; }
-            if name.ends_with("_merged.json") { merged_json = Some(p.clone()); }
+            if name.ends_with(".json") {
+                json_count += 1;
+            }
+            if name.ends_with(".toml") {
+                toml_count += 1;
+            }
+            if name.ends_with(".srt") {
+                srt_count += 1;
+            }
+            if name.ends_with("_merged.json") {
+                merged_json = Some(p.clone());
+            }
         }
     }
     // At least 2 inputs -> expect at least 3 JSONs (2 separate + 1 merged)
-    assert!(json_count >= 3, "expected at least 3 JSON files, found {}", json_count);
-    assert!(toml_count >= 3, "expected at least 3 TOML files, found {}", toml_count);
-    assert!(srt_count >= 3, "expected at least 3 SRT files, found {}", srt_count);
+    assert!(
+        json_count >= 3,
+        "expected at least 3 JSON files, found {}",
+        json_count
+    );
+    assert!(
+        toml_count >= 3,
+        "expected at least 3 TOML files, found {}",
+        toml_count
+    );
+    assert!(
+        srt_count >= 3,
+        "expected at least 3 SRT files, found {}",
+        srt_count
+    );

     let _merged_json = merged_json.expect("missing merged JSON output ending with _merged.json");
     // Contents of merged JSON are validated by unit tests and other integration coverage
@@ -205,7 +254,6 @@ fn cli_merge_and_separate_writes_both_kinds_of_outputs() {
     let _ = fs::remove_dir_all(&out_dir);
 }

-
 #[test]
 fn cli_set_speaker_names_merge_prompts_and_uses_names() {
     use std::io::{Read as _, Write as _};
@@ -238,7 +286,8 @@ fn cli_set_speaker_names_merge_prompts_and_uses_names() {

     let stdout = String::from_utf8(output.stdout).expect("stdout not UTF-8");
     let root: OutputRoot = serde_json::from_str(&stdout).unwrap();
-    let speakers: std::collections::HashSet<String> = root.items.into_iter().map(|e| e.speaker).collect();
+    let speakers: std::collections::HashSet<String> =
+        root.items.into_iter().map(|e| e.speaker).collect();
     assert!(speakers.contains("Alpha"), "Alpha not found in speakers");
     assert!(speakers.contains("Beta"), "Beta not found in speakers");
 }
@@ -279,12 +328,17 @@ fn cli_set_speaker_names_separate_single_input() {
     for e in fs::read_dir(&out_dir).unwrap() {
         let p = e.unwrap().path();
         if let Some(name) = p.file_name().and_then(|s| s.to_str()) {
-            if name.ends_with(".json") { json_paths.push(p.clone()); }
+            if name.ends_with(".json") {
+                json_paths.push(p.clone());
+            }
         }
     }
     assert!(!json_paths.is_empty(), "no JSON outputs created");
     let mut buf = String::new();
-    std::fs::File::open(&json_paths[0]).unwrap().read_to_string(&mut buf).unwrap();
+    std::fs::File::open(&json_paths[0])
+        .unwrap()
+        .read_to_string(&mut buf)
+        .unwrap();
     let root: OutputRoot = serde_json::from_str(&buf).unwrap();
     assert!(root.items.iter().all(|e| e.speaker == "ChosenOne"));
