[refactor] simplify backend initialization and transcription logic using macro and trait improvements
 src/backend.rs | 124
@@ -87,107 +87,39 @@ fn vulkan_available() -> bool {
 }
 
 /// CPU-based transcription backend using whisper-rs.
+#[derive(Default)]
 pub struct CpuBackend;
 /// CUDA-accelerated transcription backend for NVIDIA GPUs.
+#[derive(Default)]
 pub struct CudaBackend;
 /// ROCm/HIP-accelerated transcription backend for AMD GPUs.
+#[derive(Default)]
 pub struct HipBackend;
 /// Vulkan-based transcription backend (experimental/incomplete).
+#[derive(Default)]
 pub struct VulkanBackend;
 
-impl CpuBackend {
-    /// Create a new CPU backend instance.
-    pub fn new() -> Self {
-        CpuBackend
-    }
-}
-impl Default for CpuBackend {
-    fn default() -> Self {
-        Self::new()
-    }
-}
-impl CudaBackend {
-    /// Create a new CUDA backend instance.
-    pub fn new() -> Self {
-        CudaBackend
-    }
-}
-impl Default for CudaBackend {
-    fn default() -> Self {
-        Self::new()
-    }
-}
-impl HipBackend {
-    /// Create a new HIP backend instance.
-    pub fn new() -> Self {
-        HipBackend
-    }
-}
-impl Default for HipBackend {
-    fn default() -> Self {
-        Self::new()
-    }
-}
-impl VulkanBackend {
-    /// Create a new Vulkan backend instance.
-    pub fn new() -> Self {
-        VulkanBackend
-    }
-}
-impl Default for VulkanBackend {
-    fn default() -> Self {
-        Self::new()
-    }
-}
+macro_rules! impl_whisper_backend {
+    ($ty:ty, $kind:expr) => {
+        impl TranscribeBackend for $ty {
+            fn kind(&self) -> BackendKind { $kind }
+            fn transcribe(
+                &self,
+                audio_path: &Path,
+                speaker: &str,
+                lang_opt: Option<&str>,
+                _gpu_layers: Option<u32>,
+                progress_cb: Option<&(dyn Fn(i32) + Send + Sync)>,
+            ) -> Result<Vec<OutputEntry>> {
+                transcribe_with_whisper_rs(audio_path, speaker, lang_opt, progress_cb)
+            }
+        }
+    };
+}
 
-impl TranscribeBackend for CpuBackend {
-    fn kind(&self) -> BackendKind {
-        BackendKind::Cpu
-    }
-    fn transcribe(
-        &self,
-        audio_path: &Path,
-        speaker: &str,
-        lang_opt: Option<&str>,
-        _gpu_layers: Option<u32>,
-        progress_cb: Option<&(dyn Fn(i32) + Send + Sync)>,
-    ) -> Result<Vec<OutputEntry>> {
-        transcribe_with_whisper_rs(audio_path, speaker, lang_opt, progress_cb)
-    }
-}
-
-impl TranscribeBackend for CudaBackend {
-    fn kind(&self) -> BackendKind {
-        BackendKind::Cuda
-    }
-    fn transcribe(
-        &self,
-        audio_path: &Path,
-        speaker: &str,
-        lang_opt: Option<&str>,
-        _gpu_layers: Option<u32>,
-        progress_cb: Option<&(dyn Fn(i32) + Send + Sync)>,
-    ) -> Result<Vec<OutputEntry>> {
-        // whisper-rs uses enabled CUDA feature at build time; call same code path
-        transcribe_with_whisper_rs(audio_path, speaker, lang_opt, progress_cb)
-    }
-}
-
-impl TranscribeBackend for HipBackend {
-    fn kind(&self) -> BackendKind {
-        BackendKind::Hip
-    }
-    fn transcribe(
-        &self,
-        audio_path: &Path,
-        speaker: &str,
-        lang_opt: Option<&str>,
-        _gpu_layers: Option<u32>,
-        progress_cb: Option<&(dyn Fn(i32) + Send + Sync)>,
-    ) -> Result<Vec<OutputEntry>> {
-        transcribe_with_whisper_rs(audio_path, speaker, lang_opt, progress_cb)
-    }
-}
+impl_whisper_backend!(CpuBackend, BackendKind::Cpu);
+impl_whisper_backend!(CudaBackend, BackendKind::Cuda);
+impl_whisper_backend!(HipBackend, BackendKind::Hip);
 
 impl TranscribeBackend for VulkanBackend {
     fn kind(&self) -> BackendKind {
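For context, a minimal standalone sketch of the declarative-macro pattern this hunk introduces: several unit structs receive the same trait impl, each forwarding to one shared function. The trait, types, and run_shared helper below are invented for illustration; they are not this crate's API.

trait Transcode {
    fn kind(&self) -> &'static str;
    fn run(&self, input: &str) -> String;
}

// Shared code path every generated impl forwards to, mirroring how each
// backend above calls transcribe_with_whisper_rs.
fn run_shared(input: &str) -> String {
    input.to_uppercase()
}

struct Fast;
struct Slow;

// One macro arm stamps out a full trait impl per (type, tag) pair.
macro_rules! impl_transcode {
    ($ty:ty, $tag:expr) => {
        impl Transcode for $ty {
            fn kind(&self) -> &'static str { $tag }
            fn run(&self, input: &str) -> String { run_shared(input) }
        }
    };
}

impl_transcode!(Fast, "fast");
impl_transcode!(Slow, "slow");

fn main() {
    let b: Box<dyn Transcode> = Box::new(Fast);
    println!("{} -> {}", b.kind(), b.run("hello"));
}

Each invocation expands to the same impl block that the removed code spelled out by hand for every backend.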
@@ -239,11 +171,11 @@ pub fn select_backend(requested: BackendKind, verbose: bool) -> Result<Selection
 
     let mk = |k: BackendKind| -> Box<dyn TranscribeBackend + Send + Sync> {
         match k {
-            BackendKind::Cpu => Box::new(CpuBackend::new()),
-            BackendKind::Cuda => Box::new(CudaBackend::new()),
-            BackendKind::Hip => Box::new(HipBackend::new()),
-            BackendKind::Vulkan => Box::new(VulkanBackend::new()),
-            BackendKind::Auto => Box::new(CpuBackend::new()), // will be replaced
+            BackendKind::Cpu => Box::new(CpuBackend::default()),
+            BackendKind::Cuda => Box::new(CudaBackend::default()),
+            BackendKind::Hip => Box::new(HipBackend::default()),
+            BackendKind::Vulkan => Box::new(VulkanBackend::default()),
+            BackendKind::Auto => Box::new(CpuBackend::default()), // will be replaced
         }
     };
 
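The switch from ::new() to ::default() works because the unit structs now carry #[derive(Default)]. A small sketch of that equivalence; Backend, Cpu, and Cuda here are placeholders, not this crate's types.

trait Backend {
    fn kind(&self) -> &'static str;
}

// For a unit struct, the derived Default::default() is exactly the
// zero-argument constructor the removed new() methods provided.
#[derive(Default)]
struct Cpu;
#[derive(Default)]
struct Cuda;

impl Backend for Cpu {
    fn kind(&self) -> &'static str { "cpu" }
}
impl Backend for Cuda {
    fn kind(&self) -> &'static str { "cuda" }
}

fn main() {
    // Same shape as the mk closure above: map a tag to a boxed trait object.
    let mk = |tag: &str| -> Box<dyn Backend> {
        match tag {
            "cuda" => Box::new(Cuda::default()),
            _ => Box::new(Cpu::default()),
        }
    };
    println!("{}", mk("cuda").kind());
}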
@@ -15,23 +15,13 @@ use std::path::{Path, PathBuf};
 ///
 /// Heuristic: choose the largest .bin file by size. Returns None if none found.
 pub fn pick_best_local_model(dir: &Path) -> Option<PathBuf> {
-    let mut best: Option<(u64, PathBuf)> = None;
     let rd = fs::read_dir(dir).ok()?;
-    for e in rd.flatten() {
-        let p = e.path();
-        if p.is_file() {
-            if p.extension().and_then(|s| s.to_str()).map(|s| s.eq_ignore_ascii_case("bin")).unwrap_or(false) {
-                if let Ok(md) = fs::metadata(&p) {
-                    let sz = md.len();
-                    match &best {
-                        Some((b_sz, _)) if *b_sz >= sz => {}
-                        _ => best = Some((sz, p.clone())),
-                    }
-                }
-            }
-        }
-    }
-    best.map(|(_, p)| p)
+    rd.flatten()
+        .map(|e| e.path())
+        .filter(|p| p.is_file() && p.extension().and_then(|s| s.to_str()).is_some_and(|s| s.eq_ignore_ascii_case("bin")))
+        .filter_map(|p| fs::metadata(&p).ok().map(|md| (md.len(), p)))
+        .max_by_key(|(sz, _)| *sz)
+        .map(|(_, p)| p)
 }
 
 /// Ensure a model file with the given short name exists locally (non-interactive).
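The rewritten pick_best_local_model keeps the same heuristic, just expressed as an iterator chain: keep .bin files, pair each with its size, take the maximum by size. Below is a self-contained sketch of that chain over in-memory (name, size) pairs instead of a real fs::read_dir listing; pick_best and the sample data are illustrative.

use std::path::PathBuf;

fn pick_best(entries: &[(&str, u64)]) -> Option<PathBuf> {
    entries
        .iter()
        // Pair each candidate path with its size.
        .map(|&(name, size)| (PathBuf::from(name), size))
        // Keep only .bin files, case-insensitively.
        .filter(|(p, _)| {
            p.extension()
                .and_then(|s| s.to_str())
                .is_some_and(|s| s.eq_ignore_ascii_case("bin"))
        })
        // Largest file wins; None if nothing matched.
        .max_by_key(|(_, size)| *size)
        .map(|(p, _)| p)
}

fn main() {
    let files: [(&str, u64); 3] = [
        ("tiny.bin", 75_000_000),
        ("readme.txt", 1_000),
        ("large-v3.bin", 3_000_000_000),
    ];
    assert_eq!(pick_best(&files), Some(PathBuf::from("large-v3.bin")));
}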
@@ -105,22 +95,19 @@ pub fn run_interactive_model_downloader() -> Result<()> {
     let selection = if selection_raw.is_empty() { "1" } else { &selection_raw };
 
     // Parse indices
-    let mut picked_indices: Vec<usize> = Vec::new();
-    for part in selection.split(|c| c == ',' || c == ' ' || c == ';') {
+    use std::collections::BTreeSet;
+    let mut picked_set: BTreeSet<usize> = BTreeSet::new();
+    for part in selection.split([',', ' ', ';']) {
         let t = part.trim();
         if t.is_empty() { continue; }
         match t.parse::<usize>() {
-            Ok(n) if n >= 1 && n <= available.len() => {
-                let idx = n - 1;
-                if !picked_indices.contains(&idx) {
-                    picked_indices.push(idx);
-                }
-            }
-            _ => {
-                ui::warn(format!("Ignoring invalid selection: '{}'", t));
-            }
+            Ok(n) if (1..=available.len()).contains(&n) => {
+                picked_set.insert(n - 1);
+            }
+            _ => ui::warn(format!("Ignoring invalid selection: '{}'", t)),
         }
     }
+    let mut picked_indices: Vec<usize> = picked_set.into_iter().collect();
     if picked_indices.is_empty() {
         // Fallback to default first item
         picked_indices.push(0);
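Using a BTreeSet keeps the picked indices deduplicated and in ascending order, so the download loop sees each model at most once, in menu order. A standalone sketch of just the parsing step; parse_selection, available_len, and the eprintln! warning are stand-ins for the surrounding function, available.len(), and ui::warn.

use std::collections::BTreeSet;

// Turn a 1-based selection string like "3, 1;3 2" into sorted, unique,
// 0-based indices, ignoring anything out of range or non-numeric.
fn parse_selection(selection: &str, available_len: usize) -> Vec<usize> {
    let mut picked: BTreeSet<usize> = BTreeSet::new();
    for part in selection.split([',', ' ', ';']) {
        let t = part.trim();
        if t.is_empty() {
            continue;
        }
        match t.parse::<usize>() {
            Ok(n) if (1..=available_len).contains(&n) => {
                picked.insert(n - 1);
            }
            _ => eprintln!("Ignoring invalid selection: '{}'", t),
        }
    }
    picked.into_iter().collect()
}

fn main() {
    // "3" is repeated and "x" is invalid; result is sorted, 0-based, deduplicated.
    assert_eq!(parse_selection("3, 1;3 x 2", 5), vec![0, 1, 2]);
}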