import inspect
from pathlib import Path

import torch
import whisperx
from whisperx.audio import SAMPLE_RATE

from src.data.chapters import sec_to_hms

# Set device and disable TF32 for consistent results
device = "cuda" if torch.cuda.is_available() else "cpu"
torch.backends.cuda.matmul.allow_tf32 = False
torch.backends.cudnn.allow_tf32 = False


class ASRProcessor:
    """
    Automatic Speech Recognition processor using WhisperX.

    Transcribes audio files and returns time-aligned transcription segments.
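
    Output lines look like "<timestamp>: <segment text>", one per aligned
    segment; the timestamp string is produced by src.data.chapters.sec_to_hms.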
"""

    def __init__(self, model_name="large-v2", compute_type="float16"):
        self.model_name = model_name
        # Older whisperx versions do not accept a compute_type argument, so
        # only pass it when the installed load_model signature supports it.
        if "compute_type" in inspect.signature(whisperx.load_model).parameters:
            self.model = whisperx.load_model(
                model_name, device, compute_type=compute_type
            )
        else:
            self.model = whisperx.load_model(model_name, device)

    def get_asr(self, audio_file, return_duration=True):
        """Transcribe `audio_file` with WhisperX and return the transcript as
        "<timestamp>: <text>" lines, plus the audio duration in seconds when
        `return_duration` is True."""
        assert Path(audio_file).exists(), f"File {audio_file} does not exist"
        audio = whisperx.load_audio(audio_file)

        # Older whisperx versions do not expose a batch_size parameter on transcribe
        if "batch_size" in inspect.signature(self.model.transcribe).parameters:
            result = self.model.transcribe(audio, batch_size=1)
        else:
            result = self.model.transcribe(audio)
        language = result["language"]
        duration = audio.shape[0] / SAMPLE_RATE

        # Align the transcription
        model_a, metadata = whisperx.load_align_model(
            language_code=language, device=device
        )
        aligned_result = whisperx.align(
            result["segments"],
            model_a,
            metadata,
            audio,
            device,
            return_char_alignments=False,
        )

        # Format the output
        segments = [
            {field: segment[field] for field in ["start", "end", "text"]}
            for segment in aligned_result["segments"]
        ]
        asr_clean = []
        for segment in segments:
            t = segment["text"].strip()
            s = sec_to_hms(segment["start"])
            asr_clean.append(f"{s}: {t}")

        if return_duration:
            return "\n".join(asr_clean) + "\n", duration
        return "\n".join(asr_clean) + "\n"