From 79397a3b9cc2eb2ace7178625d30aec600c8d042 Mon Sep 17 00:00:00 2001 From: vikingowl Date: Tue, 12 Aug 2025 12:05:32 +0200 Subject: [PATCH] [refactor] simplify backend initialization and transcription logic using macro and trait improvements --- src/backend.rs | 124 +++++++++++-------------------------------------- src/models.rs | 39 ++++++---------- 2 files changed, 41 insertions(+), 122 deletions(-) diff --git a/src/backend.rs b/src/backend.rs index 4f3ffae..eb5be7e 100644 --- a/src/backend.rs +++ b/src/backend.rs @@ -87,107 +87,39 @@ fn vulkan_available() -> bool { } /// CPU-based transcription backend using whisper-rs. +#[derive(Default)] pub struct CpuBackend; /// CUDA-accelerated transcription backend for NVIDIA GPUs. +#[derive(Default)] pub struct CudaBackend; /// ROCm/HIP-accelerated transcription backend for AMD GPUs. +#[derive(Default)] pub struct HipBackend; /// Vulkan-based transcription backend (experimental/incomplete). +#[derive(Default)] pub struct VulkanBackend; -impl CpuBackend { - /// Create a new CPU backend instance. - pub fn new() -> Self { - CpuBackend - } -} -impl Default for CpuBackend { - fn default() -> Self { - Self::new() - } -} -impl CudaBackend { - /// Create a new CUDA backend instance. - pub fn new() -> Self { - CudaBackend - } -} -impl Default for CudaBackend { - fn default() -> Self { - Self::new() - } -} -impl HipBackend { - /// Create a new HIP backend instance. - pub fn new() -> Self { - HipBackend - } -} -impl Default for HipBackend { - fn default() -> Self { - Self::new() - } -} -impl VulkanBackend { - /// Create a new Vulkan backend instance. - pub fn new() -> Self { - VulkanBackend - } -} -impl Default for VulkanBackend { - fn default() -> Self { - Self::new() - } +macro_rules! impl_whisper_backend { + ($ty:ty, $kind:expr) => { + impl TranscribeBackend for $ty { + fn kind(&self) -> BackendKind { $kind } + fn transcribe( + &self, + audio_path: &Path, + speaker: &str, + lang_opt: Option<&str>, + _gpu_layers: Option, + progress_cb: Option<&(dyn Fn(i32) + Send + Sync)>, + ) -> Result> { + transcribe_with_whisper_rs(audio_path, speaker, lang_opt, progress_cb) + } + } + }; } -impl TranscribeBackend for CpuBackend { - fn kind(&self) -> BackendKind { - BackendKind::Cpu - } - fn transcribe( - &self, - audio_path: &Path, - speaker: &str, - lang_opt: Option<&str>, - _gpu_layers: Option, - progress_cb: Option<&(dyn Fn(i32) + Send + Sync)>, - ) -> Result> { - transcribe_with_whisper_rs(audio_path, speaker, lang_opt, progress_cb) - } -} - -impl TranscribeBackend for CudaBackend { - fn kind(&self) -> BackendKind { - BackendKind::Cuda - } - fn transcribe( - &self, - audio_path: &Path, - speaker: &str, - lang_opt: Option<&str>, - _gpu_layers: Option, - progress_cb: Option<&(dyn Fn(i32) + Send + Sync)>, - ) -> Result> { - // whisper-rs uses enabled CUDA feature at build time; call same code path - transcribe_with_whisper_rs(audio_path, speaker, lang_opt, progress_cb) - } -} - -impl TranscribeBackend for HipBackend { - fn kind(&self) -> BackendKind { - BackendKind::Hip - } - fn transcribe( - &self, - audio_path: &Path, - speaker: &str, - lang_opt: Option<&str>, - _gpu_layers: Option, - progress_cb: Option<&(dyn Fn(i32) + Send + Sync)>, - ) -> Result> { - transcribe_with_whisper_rs(audio_path, speaker, lang_opt, progress_cb) - } -} +impl_whisper_backend!(CpuBackend, BackendKind::Cpu); +impl_whisper_backend!(CudaBackend, BackendKind::Cuda); +impl_whisper_backend!(HipBackend, BackendKind::Hip); impl TranscribeBackend for VulkanBackend { fn kind(&self) -> BackendKind { @@ -239,11 +171,11 @@ pub fn select_backend(requested: BackendKind, verbose: bool) -> Result Box { match k { - BackendKind::Cpu => Box::new(CpuBackend::new()), - BackendKind::Cuda => Box::new(CudaBackend::new()), - BackendKind::Hip => Box::new(HipBackend::new()), - BackendKind::Vulkan => Box::new(VulkanBackend::new()), - BackendKind::Auto => Box::new(CpuBackend::new()), // will be replaced + BackendKind::Cpu => Box::new(CpuBackend::default()), + BackendKind::Cuda => Box::new(CudaBackend::default()), + BackendKind::Hip => Box::new(HipBackend::default()), + BackendKind::Vulkan => Box::new(VulkanBackend::default()), + BackendKind::Auto => Box::new(CpuBackend::default()), // will be replaced } }; diff --git a/src/models.rs b/src/models.rs index c40f27c..3cc4eb0 100644 --- a/src/models.rs +++ b/src/models.rs @@ -15,23 +15,13 @@ use std::path::{Path, PathBuf}; /// /// Heuristic: choose the largest .bin file by size. Returns None if none found. pub fn pick_best_local_model(dir: &Path) -> Option { - let mut best: Option<(u64, PathBuf)> = None; let rd = fs::read_dir(dir).ok()?; - for e in rd.flatten() { - let p = e.path(); - if p.is_file() { - if p.extension().and_then(|s| s.to_str()).map(|s| s.eq_ignore_ascii_case("bin")).unwrap_or(false) { - if let Ok(md) = fs::metadata(&p) { - let sz = md.len(); - match &best { - Some((b_sz, _)) if *b_sz >= sz => {} - _ => best = Some((sz, p.clone())), - } - } - } - } - } - best.map(|(_, p)| p) + rd.flatten() + .map(|e| e.path()) + .filter(|p| p.is_file() && p.extension().and_then(|s| s.to_str()).is_some_and(|s| s.eq_ignore_ascii_case("bin"))) + .filter_map(|p| fs::metadata(&p).ok().map(|md| (md.len(), p))) + .max_by_key(|(sz, _)| *sz) + .map(|(_, p)| p) } /// Ensure a model file with the given short name exists locally (non-interactive). @@ -105,22 +95,19 @@ pub fn run_interactive_model_downloader() -> Result<()> { let selection = if selection_raw.is_empty() { "1" } else { &selection_raw }; // Parse indices - let mut picked_indices: Vec = Vec::new(); - for part in selection.split(|c| c == ',' || c == ' ' || c == ';') { + use std::collections::BTreeSet; + let mut picked_set: BTreeSet = BTreeSet::new(); + for part in selection.split([',', ' ', ';']) { let t = part.trim(); if t.is_empty() { continue; } match t.parse::() { - Ok(n) if n >= 1 && n <= available.len() => { - let idx = n - 1; - if !picked_indices.contains(&idx) { - picked_indices.push(idx); - } - } - _ => { - ui::warn(format!("Ignoring invalid selection: '{}'", t)); + Ok(n) if (1..=available.len()).contains(&n) => { + picked_set.insert(n - 1); } + _ => ui::warn(format!("Ignoring invalid selection: '{}'", t)), } } + let mut picked_indices: Vec = picked_set.into_iter().collect(); if picked_indices.is_empty() { // Fallback to default first item picked_indices.push(0);