Michael Hu
Revert "lgos"
781eb5f
raw
history blame
5.77 kB
"""
Speech Recognition Module
Supports multiple ASR models including Whisper and Parakeet
Handles audio preprocessing and transcription
"""
import logging
import numpy as np
import os
from abc import ABC, abstractmethod
logger = logging.getLogger(__name__)
from faster_whisper import WhisperModel as FasterWhisperModel
from pydub import AudioSegment
class ASRModel(ABC):
    """Base class for ASR models.

    Subclasses provide ``load_model`` and ``transcribe``; this base class
    supplies the shared audio conversion step used before transcription.
    """

    @abstractmethod
    def load_model(self):
        """Load the ASR model."""

    @abstractmethod
    def transcribe(self, audio_path):
        """Transcribe audio to text."""

    @staticmethod
    def _wav_path_for(audio_path):
        """Return ``audio_path`` with only its final extension swapped for ``.wav``."""
        # splitext touches just the trailing extension, so a ".mp3" that
        # happens to appear in a directory name is left alone (a plain
        # str.replace would rewrite every occurrence in the path).
        return f"{os.path.splitext(audio_path)[0]}.wav"

    def preprocess_audio(self, audio_path):
        """Convert audio to 16 kHz mono WAV and return the WAV file path.

        The output is written next to the input, using the input's name with
        its extension replaced by ``.wav``. If the input already ends in
        ``.wav`` the converted audio is written back to that same path.

        Args:
            audio_path: Path to the input audio file.

        Returns:
            Path of the exported WAV file.
        """
        logger.info("Converting audio format")
        audio = AudioSegment.from_file(audio_path)
        processed_audio = audio.set_frame_rate(16000).set_channels(1)
        wav_path = self._wav_path_for(audio_path)
        processed_audio.export(wav_path, format="wav")
        logger.info("Audio converted to: %s", wav_path)
        return wav_path
class WhisperModel(ASRModel):
    """Faster Whisper ASR model implementation.

    Args:
        model_size: Model identifier passed to faster-whisper when loading.
        language: Language code passed to faster-whisper when transcribing.
    """

    def __init__(self, model_size="large-v3", language="en"):
        self.model = None
        self.model_size = model_size
        self.language = language
        # torch is used only to probe for CUDA; when it is not installed the
        # model simply runs on CPU.
        try:
            import torch
            self.device = "cuda" if torch.cuda.is_available() else "cpu"
        except ImportError:
            self.device = "cpu"
        # Half precision on GPU, 8-bit integers on CPU.
        self.compute_type = "float16" if self.device == "cuda" else "int8"

    def load_model(self):
        """Load the Faster Whisper model onto the selected device."""
        logger.info("Loading Faster Whisper model")
        logger.info("Using device: %s", self.device)
        logger.info("Using compute type: %s", self.compute_type)
        self.model = FasterWhisperModel(
            self.model_size,
            device=self.device,
            compute_type=self.compute_type,
        )
        logger.info("Faster Whisper model loaded successfully")

    def transcribe(self, audio_path):
        """Transcribe audio using Faster Whisper.

        Loads the model on first use, converts the input with
        ``preprocess_audio``, and joins the text of all returned segments.

        Args:
            audio_path: Path to the input audio file.

        Returns:
            The segment texts joined by single spaces, with surrounding
            whitespace stripped.
        """
        if self.model is None:
            self.load_model()
        wav_path = self.preprocess_audio(audio_path)

        logger.info("Generating transcription with Faster Whisper")
        segments, info = self.model.transcribe(
            wav_path,
            beam_size=5,
            language=self.language,
            task="transcribe",
        )
        logger.info(
            "Detected language '%s' with probability %s",
            info.language,
            info.language_probability,
        )

        # Collect the pieces and join once instead of growing a string
        # with += on every segment.
        pieces = []
        for segment in segments:
            pieces.append(segment.text)
            logger.debug(
                "[%.2fs -> %.2fs] %s", segment.start, segment.end, segment.text
            )
        result = " ".join(pieces).strip()

        logger.info("Transcription completed successfully")
        return result
class ParakeetModel(ASRModel):
    """ASR backend for the Parakeet model, loaded through NeMo."""

    def __init__(self):
        # Filled in by load_model(); transcribe() triggers that on first use.
        self.model = None

    def load_model(self):
        """Import NeMo's ASR collection and fetch the pretrained Parakeet model.

        Raises:
            ImportError: Re-raised after logging an installation hint.
        """
        try:
            import nemo.collections.asr as nemo_asr

            logger.info("Loading Parakeet model")
            asr_model_cls = nemo_asr.models.ASRModel
            self.model = asr_model_cls.from_pretrained(
                model_name="nvidia/parakeet-tdt-0.6b-v2"
            )
            logger.info("Parakeet model loaded successfully")
        except ImportError:
            logger.error(
                "Failed to import nemo_toolkit. "
                "Please install with: pip install -U 'nemo_toolkit[asr]'"
            )
            raise

    def transcribe(self, audio_path):
        """Return the Parakeet transcription of the given audio file.

        The model is loaded on first use and the input is converted with
        ``preprocess_audio`` before being handed to the model.
        """
        if self.model is None:
            self.load_model()

        converted_path = self.preprocess_audio(audio_path)

        logger.info("Generating transcription with Parakeet")
        hypotheses = self.model.transcribe([converted_path])
        # One file was submitted, so the first entry holds its text.
        text = hypotheses[0].text

        logger.info("Transcription completed successfully")
        return text
class ASRFactory:
    """Builds ASR model objects from a model name."""

    @staticmethod
    def get_model(model_name="parakeet"):
        """
        Create the ASR model matching ``model_name`` (case-insensitive).

        Args:
            model_name: "whisper" or "parakeet"; any other name logs a
                warning and yields a Whisper model.

        Returns:
            ASR model instance
        """
        normalized = model_name.lower()
        if normalized == "parakeet":
            return ParakeetModel()
        # Everything else resolves to Whisper; warn only when the name was
        # not an explicit request for it.
        if normalized != "whisper":
            logger.warning(f"Unknown model: {model_name}, falling back to Whisper")
        return WhisperModel()
def transcribe_audio(audio_path, model_name="parakeet"):
    """
    Convert audio file to text using specified ASR model.

    Args:
        audio_path: Path to input audio file
        model_name: Name of the ASR model to use (whisper or parakeet)

    Returns:
        The text returned by the selected model's ``transcribe`` method.

    Raises:
        Exception: Any error raised while creating the model or transcribing
            is logged with its traceback and re-raised unchanged.
    """
    logger.info(
        "Starting transcription for: %s using %s model", audio_path, model_name
    )
    try:
        asr_model = ASRFactory.get_model(model_name)
        result = asr_model.transcribe(audio_path)
    except Exception as e:
        # Top-level boundary: record the traceback, then let the caller decide.
        logger.exception("Transcription failed: %s", e)
        raise
    else:
        # Pass the result as a logging argument rather than through the
        # "%" operator, which would misformat a tuple result.
        logger.info("transcription: %s", result)
        return result