""" Speech Recognition Module Supports multiple ASR models including Whisper and Parakeet Handles audio preprocessing and transcription """ import logging import numpy as np import os from abc import ABC, abstractmethod logger = logging.getLogger(__name__) import torch from transformers import AutoModelForSpeechSeq2Seq, AutoProcessor from pydub import AudioSegment import soundfile as sf class ASRModel(ABC): """Base class for ASR models""" @abstractmethod def load_model(self): """Load the ASR model""" pass @abstractmethod def transcribe(self, audio_path): """Transcribe audio to text""" pass def preprocess_audio(self, audio_path): """Convert audio to required format""" logger.info("Converting audio format") audio = AudioSegment.from_file(audio_path) processed_audio = audio.set_frame_rate(16000).set_channels(1) wav_path = audio_path.replace(".mp3", ".wav") if audio_path.endswith(".mp3") else audio_path if not wav_path.endswith(".wav"): wav_path = f"{os.path.splitext(wav_path)[0]}.wav" processed_audio.export(wav_path, format="wav") logger.info(f"Audio converted to: {wav_path}") return wav_path class WhisperModel(ASRModel): """Whisper ASR model implementation""" def __init__(self): self.model = None self.processor = None self.device = "cuda" if torch.cuda.is_available() else "cpu" def load_model(self): """Load Whisper model""" logger.info("Loading Whisper model") logger.info(f"Using device: {self.device}") self.model = AutoModelForSpeechSeq2Seq.from_pretrained( "openai/whisper-large-v3", torch_dtype=torch.float32, low_cpu_mem_usage=True, use_safetensors=True ).to(self.device) self.processor = AutoProcessor.from_pretrained("openai/whisper-large-v3") logger.info("Whisper model loaded successfully") def transcribe(self, audio_path): """Transcribe audio using Whisper""" if self.model is None or self.processor is None: self.load_model() wav_path = self.preprocess_audio(audio_path) # Processing logger.info("Processing audio input") logger.debug("Loading audio data") audio_data, sample_rate = sf.read(wav_path) audio_data = audio_data.astype(np.float32) # Increase chunk length and stride for longer transcriptions inputs = self.processor( audio_data, sampling_rate=16000, return_tensors="pt", # Increase chunk length to handle longer segments chunk_length_s=60, stride_length_s=10 ).to(self.device) # Transcription logger.info("Generating transcription") with torch.no_grad(): # Add max_length parameter to allow for longer outputs outputs = self.model.generate( **inputs, language="en", task="transcribe", max_length=448, # Explicitly set max output length no_repeat_ngram_size=3 # Prevent repetition in output ) result = self.processor.batch_decode(outputs, skip_special_tokens=True)[0] logger.info(f"Transcription completed successfully") return result class ParakeetModel(ASRModel): """Parakeet ASR model implementation""" def __init__(self): self.model = None def load_model(self): """Load Parakeet model""" try: import nemo.collections.asr as nemo_asr logger.info("Loading Parakeet model") self.model = nemo_asr.models.ASRModel.from_pretrained(model_name="nvidia/parakeet-tdt-0.6b-v2") logger.info("Parakeet model loaded successfully") except ImportError: logger.error("Failed to import nemo_toolkit. Please install with: pip install -U 'nemo_toolkit[asr]'") raise def transcribe(self, audio_path): """Transcribe audio using Parakeet""" if self.model is None: self.load_model() wav_path = self.preprocess_audio(audio_path) # Transcription logger.info("Generating transcription with Parakeet") output = self.model.transcribe([wav_path]) result = output[0].text logger.info(f"Transcription completed successfully") return result class ASRFactory: """Factory for creating ASR model instances""" @staticmethod def get_model(model_name="parakeet"): """ Get ASR model by name Args: model_name: Name of the model to use (whisper or parakeet) Returns: ASR model instance """ if model_name.lower() == "whisper": return WhisperModel() elif model_name.lower() == "parakeet": return ParakeetModel() else: logger.warning(f"Unknown model: {model_name}, falling back to Whisper") return WhisperModel() def transcribe_audio(audio_path, model_name="parakeet"): """ Convert audio file to text using specified ASR model Args: audio_path: Path to input audio file model_name: Name of the ASR model to use (whisper or parakeet) Returns: Transcribed English text """ logger.info(f"Starting transcription for: {audio_path} using {model_name} model") try: # Get the appropriate model asr_model = ASRFactory.get_model(model_name) # Transcribe audio result = asr_model.transcribe(audio_path) logger.info(f"transcription: %s" % result) return result except Exception as e: logger.error(f"Transcription failed: {str(e)}", exc_info=True) raise