[refactor] simplify backend initialization and transcription logic using macro and trait improvements

This commit is contained in:
2025-08-12 12:05:32 +02:00
parent 9fd44a2e37
commit 79397a3b9c
2 changed files with 41 additions and 122 deletions

View File

@@ -87,63 +87,22 @@ fn vulkan_available() -> bool {
} }
/// CPU-based transcription backend using whisper-rs. /// CPU-based transcription backend using whisper-rs.
#[derive(Default)]
pub struct CpuBackend; pub struct CpuBackend;
/// CUDA-accelerated transcription backend for NVIDIA GPUs. /// CUDA-accelerated transcription backend for NVIDIA GPUs.
#[derive(Default)]
pub struct CudaBackend; pub struct CudaBackend;
/// ROCm/HIP-accelerated transcription backend for AMD GPUs. /// ROCm/HIP-accelerated transcription backend for AMD GPUs.
#[derive(Default)]
pub struct HipBackend; pub struct HipBackend;
/// Vulkan-based transcription backend (experimental/incomplete). /// Vulkan-based transcription backend (experimental/incomplete).
#[derive(Default)]
pub struct VulkanBackend; pub struct VulkanBackend;
impl CpuBackend { macro_rules! impl_whisper_backend {
/// Create a new CPU backend instance. ($ty:ty, $kind:expr) => {
pub fn new() -> Self { impl TranscribeBackend for $ty {
CpuBackend fn kind(&self) -> BackendKind { $kind }
}
}
impl Default for CpuBackend {
fn default() -> Self {
Self::new()
}
}
impl CudaBackend {
/// Create a new CUDA backend instance.
pub fn new() -> Self {
CudaBackend
}
}
impl Default for CudaBackend {
fn default() -> Self {
Self::new()
}
}
impl HipBackend {
/// Create a new HIP backend instance.
pub fn new() -> Self {
HipBackend
}
}
impl Default for HipBackend {
fn default() -> Self {
Self::new()
}
}
impl VulkanBackend {
/// Create a new Vulkan backend instance.
pub fn new() -> Self {
VulkanBackend
}
}
impl Default for VulkanBackend {
fn default() -> Self {
Self::new()
}
}
impl TranscribeBackend for CpuBackend {
fn kind(&self) -> BackendKind {
BackendKind::Cpu
}
fn transcribe( fn transcribe(
&self, &self,
audio_path: &Path, audio_path: &Path,
@@ -154,40 +113,13 @@ impl TranscribeBackend for CpuBackend {
) -> Result<Vec<OutputEntry>> { ) -> Result<Vec<OutputEntry>> {
transcribe_with_whisper_rs(audio_path, speaker, lang_opt, progress_cb) transcribe_with_whisper_rs(audio_path, speaker, lang_opt, progress_cb)
} }
}
};
} }
impl TranscribeBackend for CudaBackend { impl_whisper_backend!(CpuBackend, BackendKind::Cpu);
fn kind(&self) -> BackendKind { impl_whisper_backend!(CudaBackend, BackendKind::Cuda);
BackendKind::Cuda impl_whisper_backend!(HipBackend, BackendKind::Hip);
}
fn transcribe(
&self,
audio_path: &Path,
speaker: &str,
lang_opt: Option<&str>,
_gpu_layers: Option<u32>,
progress_cb: Option<&(dyn Fn(i32) + Send + Sync)>,
) -> Result<Vec<OutputEntry>> {
// whisper-rs uses enabled CUDA feature at build time; call same code path
transcribe_with_whisper_rs(audio_path, speaker, lang_opt, progress_cb)
}
}
impl TranscribeBackend for HipBackend {
fn kind(&self) -> BackendKind {
BackendKind::Hip
}
fn transcribe(
&self,
audio_path: &Path,
speaker: &str,
lang_opt: Option<&str>,
_gpu_layers: Option<u32>,
progress_cb: Option<&(dyn Fn(i32) + Send + Sync)>,
) -> Result<Vec<OutputEntry>> {
transcribe_with_whisper_rs(audio_path, speaker, lang_opt, progress_cb)
}
}
impl TranscribeBackend for VulkanBackend { impl TranscribeBackend for VulkanBackend {
fn kind(&self) -> BackendKind { fn kind(&self) -> BackendKind {
@@ -239,11 +171,11 @@ pub fn select_backend(requested: BackendKind, verbose: bool) -> Result<Selection
let mk = |k: BackendKind| -> Box<dyn TranscribeBackend + Send + Sync> { let mk = |k: BackendKind| -> Box<dyn TranscribeBackend + Send + Sync> {
match k { match k {
BackendKind::Cpu => Box::new(CpuBackend::new()), BackendKind::Cpu => Box::new(CpuBackend::default()),
BackendKind::Cuda => Box::new(CudaBackend::new()), BackendKind::Cuda => Box::new(CudaBackend::default()),
BackendKind::Hip => Box::new(HipBackend::new()), BackendKind::Hip => Box::new(HipBackend::default()),
BackendKind::Vulkan => Box::new(VulkanBackend::new()), BackendKind::Vulkan => Box::new(VulkanBackend::default()),
BackendKind::Auto => Box::new(CpuBackend::new()), // will be replaced BackendKind::Auto => Box::new(CpuBackend::default()), // will be replaced
} }
}; };

View File

@@ -15,23 +15,13 @@ use std::path::{Path, PathBuf};
/// ///
/// Heuristic: choose the largest .bin file by size. Returns None if none found. /// Heuristic: choose the largest .bin file by size. Returns None if none found.
pub fn pick_best_local_model(dir: &Path) -> Option<PathBuf> { pub fn pick_best_local_model(dir: &Path) -> Option<PathBuf> {
let mut best: Option<(u64, PathBuf)> = None;
let rd = fs::read_dir(dir).ok()?; let rd = fs::read_dir(dir).ok()?;
for e in rd.flatten() { rd.flatten()
let p = e.path(); .map(|e| e.path())
if p.is_file() { .filter(|p| p.is_file() && p.extension().and_then(|s| s.to_str()).is_some_and(|s| s.eq_ignore_ascii_case("bin")))
if p.extension().and_then(|s| s.to_str()).map(|s| s.eq_ignore_ascii_case("bin")).unwrap_or(false) { .filter_map(|p| fs::metadata(&p).ok().map(|md| (md.len(), p)))
if let Ok(md) = fs::metadata(&p) { .max_by_key(|(sz, _)| *sz)
let sz = md.len(); .map(|(_, p)| p)
match &best {
Some((b_sz, _)) if *b_sz >= sz => {}
_ => best = Some((sz, p.clone())),
}
}
}
}
}
best.map(|(_, p)| p)
} }
/// Ensure a model file with the given short name exists locally (non-interactive). /// Ensure a model file with the given short name exists locally (non-interactive).
@@ -105,22 +95,19 @@ pub fn run_interactive_model_downloader() -> Result<()> {
let selection = if selection_raw.is_empty() { "1" } else { &selection_raw }; let selection = if selection_raw.is_empty() { "1" } else { &selection_raw };
// Parse indices // Parse indices
let mut picked_indices: Vec<usize> = Vec::new(); use std::collections::BTreeSet;
for part in selection.split(|c| c == ',' || c == ' ' || c == ';') { let mut picked_set: BTreeSet<usize> = BTreeSet::new();
for part in selection.split([',', ' ', ';']) {
let t = part.trim(); let t = part.trim();
if t.is_empty() { continue; } if t.is_empty() { continue; }
match t.parse::<usize>() { match t.parse::<usize>() {
Ok(n) if n >= 1 && n <= available.len() => { Ok(n) if (1..=available.len()).contains(&n) => {
let idx = n - 1; picked_set.insert(n - 1);
if !picked_indices.contains(&idx) { }
picked_indices.push(idx); _ => ui::warn(format!("Ignoring invalid selection: '{}'", t)),
}
}
_ => {
ui::warn(format!("Ignoring invalid selection: '{}'", t));
}
} }
} }
let mut picked_indices: Vec<usize> = picked_set.into_iter().collect();
if picked_indices.is_empty() { if picked_indices.is_empty() {
// Fallback to default first item // Fallback to default first item
picked_indices.push(0); picked_indices.push(0);