[refactor] improve variable naming and simplify logic across multiple functions and structs
src/backend.rs
@@ -24,25 +24,18 @@ pub enum BackendKind {
     Vulkan,
 }
 
-/// Abstraction for a transcription backend implementation.
+/// Abstraction for a transcription backend.
 pub trait TranscribeBackend {
-    /// Return the backend kind for this implementation.
+    /// Backend kind implemented by this type.
     fn kind(&self) -> BackendKind;
-    /// Transcribe the given audio file path and return transcript entries.
-    ///
-    /// Parameters:
-    /// - audio_path: path to input media (audio or video) to be decoded/transcribed.
-    /// - speaker: label to attach to all produced segments.
-    /// - lang_opt: optional language hint (e.g., "en"); None means auto/multilingual model default.
-    /// - gpu_layers: optional GPU layer count if applicable (ignored by some backends).
-    /// - progress_cb: optional callback receiving percentage [0..=100] updates.
+    /// Transcribe the given audio and return transcript entries.
     fn transcribe(
         &self,
         audio_path: &Path,
         speaker: &str,
-        lang_opt: Option<&str>,
+        language: Option<&str>,
         gpu_layers: Option<u32>,
-        progress_cb: Option<&(dyn Fn(i32) + Send + Sync)>,
+        progress: Option<&(dyn Fn(i32) + Send + Sync)>,
     ) -> Result<Vec<OutputEntry>>;
 }
 
@@ -107,11 +100,11 @@ macro_rules! impl_whisper_backend {
                 &self,
                 audio_path: &Path,
                 speaker: &str,
-                lang_opt: Option<&str>,
+                language: Option<&str>,
                 _gpu_layers: Option<u32>,
-                progress_cb: Option<&(dyn Fn(i32) + Send + Sync)>,
+                progress: Option<&(dyn Fn(i32) + Send + Sync)>,
             ) -> Result<Vec<OutputEntry>> {
-                transcribe_with_whisper_rs(audio_path, speaker, lang_opt, progress_cb)
+                transcribe_with_whisper_rs(audio_path, speaker, language, progress)
             }
         }
     };
@@ -129,9 +122,9 @@ impl TranscribeBackend for VulkanBackend {
         &self,
         _audio_path: &Path,
         _speaker: &str,
-        _lang_opt: Option<&str>,
+        _language: Option<&str>,
         _gpu_layers: Option<u32>,
-        _progress_cb: Option<&(dyn Fn(i32) + Send + Sync)>,
+        _progress: Option<&(dyn Fn(i32) + Send + Sync)>,
     ) -> Result<Vec<OutputEntry>> {
         Err(anyhow!(
             "Vulkan backend not yet wired to whisper.cpp FFI. Build with --features gpu-vulkan and ensure Vulkan SDK is installed. How to fix: install Vulkan loader (libvulkan), set VULKAN_SDK, and run cargo build --features gpu-vulkan."
@@ -169,13 +162,13 @@ pub fn select_backend(requested: BackendKind, verbose: bool) -> Result<Selection
         detected.push(BackendKind::Vulkan);
     }
 
-    let mk = |k: BackendKind| -> Box<dyn TranscribeBackend + Send + Sync> {
+    let instantiate_backend = |k: BackendKind| -> Box<dyn TranscribeBackend + Send + Sync> {
         match k {
             BackendKind::Cpu => Box::new(CpuBackend::default()),
             BackendKind::Cuda => Box::new(CudaBackend::default()),
             BackendKind::Hip => Box::new(HipBackend::default()),
             BackendKind::Vulkan => Box::new(VulkanBackend::default()),
-            BackendKind::Auto => Box::new(CpuBackend::default()), // will be replaced
+            BackendKind::Auto => Box::new(CpuBackend::default()), // placeholder for Auto
         }
     };
 
@@ -227,7 +220,7 @@ pub fn select_backend(requested: BackendKind, verbose: bool) -> Result<Selection
     }
 
     Ok(SelectionResult {
-        backend: mk(chosen),
+        backend: instantiate_backend(chosen),
         chosen,
         detected,
     })
@@ -238,98 +231,99 @@ pub fn select_backend(requested: BackendKind, verbose: bool) -> Result<Selection
 pub(crate) fn transcribe_with_whisper_rs(
     audio_path: &Path,
     speaker: &str,
-    lang_opt: Option<&str>,
-    progress_cb: Option<&(dyn Fn(i32) + Send + Sync)>,
+    language: Option<&str>,
+    progress: Option<&(dyn Fn(i32) + Send + Sync)>,
 ) -> Result<Vec<OutputEntry>> {
-    if let Some(cb) = progress_cb { cb(0); }
+    let report = |p: i32| {
+        if let Some(cb) = progress { cb(p); }
+    };
+    report(0);
 
-    let pcm = decode_audio_to_pcm_f32_ffmpeg(audio_path)?;
-    if let Some(cb) = progress_cb { cb(5); }
+    let pcm_samples = decode_audio_to_pcm_f32_ffmpeg(audio_path)?;
+    report(5);
 
-    let model = find_model_file()?;
-    let is_en_only = model
+    let model_path = find_model_file()?;
+    let english_only_model = model_path
         .file_name()
         .and_then(|s| s.to_str())
         .map(|s| s.contains(".en.") || s.ends_with(".en.bin"))
         .unwrap_or(false);
-    if let Some(lang) = lang_opt {
-        if is_en_only && lang != "en" {
+    if let Some(lang) = language {
+        if english_only_model && lang != "en" {
             return Err(anyhow!(
                 "Selected model is English-only ({}), but a non-English language hint '{}' was provided. Please use a multilingual model or set WHISPER_MODEL.",
-                model.display(),
+                model_path.display(),
                 lang
             ));
         }
     }
-    let model_str = model
+    let model_path_str = model_path
         .to_str()
-        .ok_or_else(|| anyhow!("Model path not valid UTF-8: {}", model.display()))?;
+        .ok_or_else(|| anyhow!("Model path not valid UTF-8: {}", model_path.display()))?;
 
     // Try to reduce native library logging via environment variables when not super-verbose.
     if crate::verbose_level() < 2 {
-        // These env vars are recognized by ggml/whisper in many builds; harmless if unknown.
+        // Some builds of whisper/ggml expect these env vars; harmless if unknown
         unsafe {
             std::env::set_var("GGML_LOG_LEVEL", "0");
             std::env::set_var("WHISPER_PRINT_PROGRESS", "0");
         }
     }
 
     // Suppress stderr from whisper/ggml during model load and inference when quiet and not verbose.
-    let (_ctx, mut state) = crate::with_suppressed_stderr(|| {
-        let cparams = whisper_rs::WhisperContextParameters::default();
-        let ctx = whisper_rs::WhisperContext::new_with_params(model_str, cparams)
-            .with_context(|| format!("Failed to load Whisper model at {}", model.display()))?;
-        let state = ctx
+    let (_context, mut state) = crate::with_suppressed_stderr(|| {
+        let params = whisper_rs::WhisperContextParameters::default();
+        let context = whisper_rs::WhisperContext::new_with_params(model_path_str, params)
+            .with_context(|| format!("Failed to load Whisper model at {}", model_path.display()))?;
+        let state = context
             .create_state()
             .map_err(|e| anyhow!("Failed to create Whisper state: {:?}", e))?;
-        Ok::<_, anyhow::Error>((ctx, state))
+        Ok::<_, anyhow::Error>((context, state))
     })?;
-    if let Some(cb) = progress_cb { cb(20); }
+    report(20);
 
-    let mut params =
+    let mut full_params =
         whisper_rs::FullParams::new(whisper_rs::SamplingStrategy::Greedy { best_of: 1 });
-    let n_threads = std::thread::available_parallelism()
+    let threads = std::thread::available_parallelism()
         .map(|n| n.get() as i32)
        .unwrap_or(1);
-    params.set_n_threads(n_threads);
-    params.set_translate(false);
-    if let Some(lang) = lang_opt {
-        params.set_language(Some(lang));
+    full_params.set_n_threads(threads);
+    full_params.set_translate(false);
+    if let Some(lang) = language {
+        full_params.set_language(Some(lang));
     }
-    if let Some(cb) = progress_cb { cb(30); }
+    report(30);
 
     crate::with_suppressed_stderr(|| {
-        if let Some(cb) = progress_cb { cb(40); }
+        report(40);
         state
-            .full(params, &pcm)
+            .full(full_params, &pcm_samples)
             .map_err(|e| anyhow!("Whisper full() failed: {:?}", e))
     })?;
 
-    if let Some(cb) = progress_cb { cb(90); }
+    report(90);
     let num_segments = state
         .full_n_segments()
         .map_err(|e| anyhow!("Failed to get segments: {:?}", e))?;
-    let mut items = Vec::new();
-    for i in 0..num_segments {
-        let text = state
-            .full_get_segment_text(i)
+    let mut entries = Vec::new();
+    for seg_idx in 0..num_segments {
+        let segment_text = state
+            .full_get_segment_text(seg_idx)
            .map_err(|e| anyhow!("Failed to get segment text: {:?}", e))?;
         let t0 = state
-            .full_get_segment_t0(i)
+            .full_get_segment_t0(seg_idx)
            .map_err(|e| anyhow!("Failed to get segment t0: {:?}", e))?;
         let t1 = state
-            .full_get_segment_t1(i)
+            .full_get_segment_t1(seg_idx)
            .map_err(|e| anyhow!("Failed to get segment t1: {:?}", e))?;
         let start = (t0 as f64) * 0.01;
         let end = (t1 as f64) * 0.01;
-        items.push(OutputEntry {
+        entries.push(OutputEntry {
             id: 0,
             speaker: speaker.to_string(),
             start,
             end,
-            text: text.trim().to_string(),
+            text: segment_text.trim().to_string(),
         });
     }
-    if let Some(cb) = progress_cb { cb(100); }
-    Ok(items)
+    report(100);
+    Ok(entries)
 }
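
For context, a minimal, self-contained sketch of how a caller might drive the renamed transcribe signature and the report closure pattern from this commit. DummyBackend, the OutputEntry field types, the String error alias, and the main wiring are assumptions made for illustration; only the trait shape and the progress-callback pattern mirror the diff above.

use std::path::Path;

// Plain String error so the sketch compiles without the anyhow crate.
type Result<T> = std::result::Result<T, String>;

// Assumed shape of OutputEntry; only the field names appear in the diff above.
#[derive(Debug)]
struct OutputEntry {
    id: usize,
    speaker: String,
    start: f64,
    end: f64,
    text: String,
}

// Trait shape mirroring the renamed signature from this commit.
trait TranscribeBackend {
    fn transcribe(
        &self,
        audio_path: &Path,
        speaker: &str,
        language: Option<&str>,
        gpu_layers: Option<u32>,
        progress: Option<&(dyn Fn(i32) + Send + Sync)>,
    ) -> Result<Vec<OutputEntry>>;
}

// Hypothetical stand-in backend: emits one fixed segment and reports progress
// through the same `report` closure pattern the commit introduces.
struct DummyBackend;

impl TranscribeBackend for DummyBackend {
    fn transcribe(
        &self,
        _audio_path: &Path,
        speaker: &str,
        _language: Option<&str>,
        _gpu_layers: Option<u32>,
        progress: Option<&(dyn Fn(i32) + Send + Sync)>,
    ) -> Result<Vec<OutputEntry>> {
        // Wrap the optional callback once, then call report(p) at each stage.
        let report = |p: i32| {
            if let Some(cb) = progress {
                cb(p);
            }
        };
        report(0);
        let entries = vec![OutputEntry {
            id: 0,
            speaker: speaker.to_string(),
            start: 0.0,
            end: 1.5,
            text: "hello world".to_string(),
        }];
        report(100);
        Ok(entries)
    }
}

fn main() -> Result<()> {
    let backend = DummyBackend;
    let print_progress = |p: i32| println!("progress: {p}%");
    let progress_cb: &(dyn Fn(i32) + Send + Sync) = &print_progress;
    let entries = backend.transcribe(
        Path::new("example.wav"),
        "Speaker 1",
        Some("en"),
        None,
        Some(progress_cb),
    )?;
    for e in &entries {
        println!("#{} [{:.2}s-{:.2}s] {}: {}", e.id, e.start, e.end, e.speaker, e.text);
    }
    Ok(())
}

The point of the report closure is that each stage simply calls report(n) instead of repeating the if let Some(cb) = progress { cb(n); } pattern at every progress checkpoint.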