Spaces:
Running
Running
File size: 6,006 Bytes
cd1309d 0ee4f42 cd1309d c72d839 c10f1ac 0ee4f42 c72d839 cd1309d 0ee4f42 cd1309d 0ee4f42 cd1309d 0ee4f42 2477bc4 c72d839 0ee4f42 c72d839 2477bc4 0ee4f42 c72d839 0ee4f42 c72d839 0ee4f42 c72d839 2477bc4 a4f48aa 7eff88c 0ee4f42 7eff88c c72d839 7eff88c 0ee4f42 c72d839 7eff88c 0ee4f42 7eff88c c72d839 0ee4f42 c72d839 0ee4f42 31708ca 0ee4f42 31708ca 0ee4f42 c72d839 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 |
"""
Speech Recognition Module
Supports multiple ASR models including Whisper and Parakeet
Handles audio preprocessing and transcription
"""
import logging
import numpy as np
import os
from abc import ABC, abstractmethod
logger = logging.getLogger(__name__)
import torch
from transformers import AutoModelForSpeechSeq2Seq, AutoProcessor
from pydub import AudioSegment
import soundfile as sf
class ASRModel(ABC):
    """Abstract base class for ASR backends.

    Subclasses provide ``load_model`` and ``transcribe``; the shared
    ``preprocess_audio`` helper converts an input file to 16 kHz mono WAV.
    """

    @abstractmethod
    def load_model(self):
        """Load the ASR model."""

    @abstractmethod
    def transcribe(self, audio_path):
        """Transcribe the audio file at ``audio_path`` to text."""

    @staticmethod
    def _wav_path_for(audio_path):
        """Return ``audio_path`` with its final extension replaced by ``.wav``.

        Only the trailing extension is swapped. A plain ``str.replace`` on
        ".mp3" would also rewrite matching text in directory names
        (e.g. ``/x/my.mp3.files/a.mp3``), pointing the export at a path
        that may not exist.
        """
        stem, _ext = os.path.splitext(audio_path)
        return f"{stem}.wav"

    def preprocess_audio(self, audio_path):
        """Convert an audio file to 16 kHz mono WAV and return the WAV path.

        Args:
            audio_path: Path to the input audio file, in any format that
                ``AudioSegment.from_file`` can open.

        Returns:
            Path of the exported WAV file: the input path with its extension
            replaced by ``.wav``.

        Note:
            When ``audio_path`` already ends in ``.wav`` the output path is
            the same as the input path, so the file is overwritten in place
            with the resampled mono version.
        """
        logger.info("Converting audio format")
        audio = AudioSegment.from_file(audio_path)
        processed_audio = audio.set_frame_rate(16000).set_channels(1)
        wav_path = self._wav_path_for(audio_path)
        processed_audio.export(wav_path, format="wav")
        logger.info("Audio converted to: %s", wav_path)
        return wav_path
class WhisperModel(ASRModel):
    """Whisper ASR model implementation.

    The model and processor are loaded lazily on the first ``transcribe``
    call and run on CUDA when ``torch.cuda.is_available()``, otherwise on CPU.
    """

    # Checkpoint identifier shared by the model and its processor, so the two
    # cannot drift apart.
    MODEL_ID = "openai/whisper-large-v3"

    def __init__(self):
        self.model = None
        self.processor = None
        self.device = "cuda" if torch.cuda.is_available() else "cpu"

    def load_model(self):
        """Load the Whisper model and processor and move the model to ``self.device``."""
        logger.info("Loading Whisper model")
        logger.info(f"Using device: {self.device}")
        self.model = AutoModelForSpeechSeq2Seq.from_pretrained(
            self.MODEL_ID,
            torch_dtype=torch.float32,
            low_cpu_mem_usage=True,
            use_safetensors=True,
        ).to(self.device)
        self.processor = AutoProcessor.from_pretrained(self.MODEL_ID)
        logger.info("Whisper model loaded successfully")

    def transcribe(self, audio_path):
        """Transcribe audio using Whisper.

        Args:
            audio_path: Path to the input audio file; it is converted with
                ``preprocess_audio`` before being read.

        Returns:
            The decoded text of the first sequence produced by ``generate``.
        """
        if self.model is None or self.processor is None:
            self.load_model()
        wav_path = self.preprocess_audio(audio_path)

        # Processing
        logger.info("Processing audio input")
        logger.debug("Loading audio data")
        audio_data, sample_rate = sf.read(wav_path)
        audio_data = audio_data.astype(np.float32)

        # Pass the rate actually read from the file instead of a hard-coded
        # 16000: preprocess_audio resamples to 16 kHz, so the value is the
        # same today, but a future mismatch reaches the processor instead of
        # being silently mislabelled.
        # NOTE(review): chunk_length_s / stride_length_s look like
        # pipeline-level options; confirm the processor honours them for
        # long audio rather than ignoring them. TODO confirm.
        inputs = self.processor(
            audio_data,
            sampling_rate=sample_rate,
            return_tensors="pt",
            chunk_length_s=60,
            stride_length_s=10,
        ).to(self.device)

        # Transcription
        logger.info("Generating transcription")
        with torch.no_grad():
            outputs = self.model.generate(
                **inputs,
                language="en",
                task="transcribe",
                max_length=448,  # Explicitly set max output length
                no_repeat_ngram_size=3,  # Prevent repetition in output
            )
        result = self.processor.batch_decode(outputs, skip_special_tokens=True)[0]
        logger.info("Transcription completed successfully")
        return result
class ParakeetModel(ASRModel):
    """ASR backend that wraps the Parakeet model loaded through NeMo."""

    def __init__(self):
        # Populated by load_model(); None means "not loaded yet".
        self.model = None

    def load_model(self):
        """Import NeMo's ASR collection and load the pretrained Parakeet model.

        Raises:
            ImportError: Re-raised after logging an installation hint.
        """
        try:
            # Imported here rather than at module level, so the dependency is
            # only needed when this backend is actually loaded.
            import nemo.collections.asr as nemo_asr

            logger.info("Loading Parakeet model")
            checkpoint = "nvidia/parakeet-tdt-0.6b-v2"
            self.model = nemo_asr.models.ASRModel.from_pretrained(model_name=checkpoint)
            logger.info("Parakeet model loaded successfully")
        except ImportError:
            logger.error("Failed to import nemo_toolkit. Please install with: pip install -U 'nemo_toolkit[asr]'")
            raise

    def transcribe(self, audio_path):
        """Transcribe the audio file at ``audio_path`` with Parakeet.

        The model is loaded on first use, the input is converted with
        ``preprocess_audio``, and the ``text`` attribute of the first
        transcription result is returned.
        """
        if self.model is None:
            self.load_model()

        converted_path = self.preprocess_audio(audio_path)

        logger.info("Generating transcription with Parakeet")
        hypotheses = self.model.transcribe([converted_path])
        text = hypotheses[0].text
        logger.info("Transcription completed successfully")
        return text
class ASRFactory:
    """Creates ASR model instances from a model name."""

    @staticmethod
    def get_model(model_name="parakeet"):
        """Return a new ASR model instance for ``model_name``.

        Args:
            model_name: "whisper" or "parakeet", matched case-insensitively.

        Returns:
            A ``ParakeetModel`` for "parakeet"; a ``WhisperModel`` for
            "whisper" and, after logging a warning, for any other name.
        """
        requested = model_name.lower()
        if requested == "parakeet":
            return ParakeetModel()
        if requested != "whisper":
            # Unrecognised names fall back to Whisper instead of raising.
            logger.warning(f"Unknown model: {model_name}, falling back to Whisper")
        return WhisperModel()
def transcribe_audio(audio_path, model_name="parakeet"):
    """Convert an audio file to text using the specified ASR model.

    Args:
        audio_path: Path to the input audio file.
        model_name: Name of the ASR model to use ("whisper" or "parakeet").

    Returns:
        The transcription returned by the selected model's ``transcribe``.

    Raises:
        Exception: Any error raised while creating the model or transcribing
            is logged with its traceback and re-raised unchanged.
    """
    logger.info("Starting transcription for: %s using %s model", audio_path, model_name)
    try:
        # Only the calls that can actually fail are guarded.
        asr_model = ASRFactory.get_model(model_name)
        result = asr_model.transcribe(audio_path)
    except Exception as e:
        # logger.exception logs at ERROR level and attaches the traceback.
        logger.exception("Transcription failed: %s", e)
        raise
    # Lazy %-style arguments: the original built the message eagerly with
    # `"..." % result`, which raises TypeError if result is ever a tuple.
    logger.info("transcription: %s", result)
    return result