Spaces:
Build error
Build error
"""Parakeet STT provider implementation.""" | |
import logging | |
from pathlib import Path | |
from typing import TYPE_CHECKING | |
if TYPE_CHECKING: | |
from ...domain.models.audio_content import AudioContent | |
from ...domain.models.text_content import TextContent | |
from ..base.stt_provider_base import STTProviderBase | |
from ...domain.exceptions import SpeechRecognitionException | |
logger = logging.getLogger(__name__) | |
class ParakeetSTTProvider(STTProviderBase):
    """Parakeet STT provider using NVIDIA NeMo implementation.

    The NeMo model is loaded lazily on the first transcription request and
    cached on the instance. The cache is keyed by the requested model name,
    so asking for a different model triggers a reload instead of silently
    reusing whichever model happened to be loaded first.
    """

    # Single source of truth: short model name -> identifier handed to
    # ``ASRModel.from_pretrained``. Insertion order is the order reported by
    # ``get_available_models``.
    _MODEL_IDS: dict[str, str] = {
        "parakeet-tdt-0.6b-v2": "nvidia/parakeet-tdt-0.6b-v2",
        "parakeet-tdt-1.1b": "nvidia/parakeet-tdt-1.1b",
        "parakeet-ctc-0.6b": "nvidia/parakeet-ctc-0.6b",
    }

    # Short name used when the caller asks for "default" or an unknown model.
    _DEFAULT_MODEL: str = "parakeet-tdt-0.6b-v2"

    def __init__(self):
        """Initialize the Parakeet STT provider."""
        super().__init__(
            provider_name="Parakeet",
            supported_languages=["en"],  # Parakeet primarily supports English
        )
        # Loaded NeMo ASR model, or None until the first transcription.
        self.model = None
        # Short name the cached model was loaded for; None when unknown.
        self._loaded_model_name = None

    def _perform_transcription(self, audio_path: Path, model: str) -> str:
        """
        Perform transcription using Parakeet.

        Args:
            audio_path: Path to the preprocessed audio file
            model: The Parakeet model to use

        Returns:
            str: The transcribed text, or an empty string when the model
                produced no output
        """
        try:
            # Load on first use, and reload when a different model is asked
            # for. A model assigned to ``self.model`` from outside (loaded
            # name unknown) is used as-is rather than being replaced.
            cached_name = getattr(self, "_loaded_model_name", None)
            if self.model is None or (cached_name is not None and cached_name != model):
                self._load_model(model)

            logger.info(f"Starting Parakeet transcription with model {model}")

            # Perform transcription
            output = self.model.transcribe([str(audio_path)])
            if not output:
                result = ""
            else:
                first = output[0]
                # Accept a plain string as well as an object exposing
                # ``.text`` (presumably older vs. newer NeMo return types --
                # TODO confirm against the pinned nemo_toolkit version).
                result = first if isinstance(first, str) else first.text

            logger.info("Parakeet transcription completed successfully")
            return result
        except Exception as e:
            # NOTE(review): the base-class handler is presumably expected to
            # raise; if it returns normally this method returns None.
            self._handle_provider_error(e, "transcription")

    def _load_model(self, model_name: str):
        """
        Load the Parakeet model and cache it on the instance.

        Unknown names (including the literal "default") resolve to the
        default model.

        Args:
            model_name: Name of the model to load

        Raises:
            SpeechRecognitionException: If nemo_toolkit cannot be imported or
                the model fails to load
        """
        try:
            import nemo.collections.asr as nemo_asr

            logger.info(f"Loading Parakeet model: {model_name}")

            # Map model names to actual model identifiers
            if model_name not in self._MODEL_IDS and model_name != "default":
                logger.warning(
                    "Unknown Parakeet model %r; falling back to %s",
                    model_name,
                    self._DEFAULT_MODEL,
                )
            actual_model_name = self._MODEL_IDS.get(
                model_name, self._MODEL_IDS[self._DEFAULT_MODEL]
            )

            self.model = nemo_asr.models.ASRModel.from_pretrained(model_name=actual_model_name)
            # Remember what the cache holds so a later request for another
            # model can be detected.
            self._loaded_model_name = model_name
            logger.info(f"Parakeet model {model_name} loaded successfully")
        except ImportError as e:
            raise SpeechRecognitionException(
                "nemo_toolkit not available. Please install with: pip install -U 'nemo_toolkit[asr]'"
            ) from e
        except Exception as e:
            raise SpeechRecognitionException(f"Failed to load Parakeet model {model_name}: {str(e)}") from e

    def is_available(self) -> bool:
        """
        Check if the Parakeet provider is available.

        Returns:
            bool: True if nemo_toolkit is available, False otherwise
        """
        try:
            import nemo.collections.asr  # noqa: F401 -- imported only to probe availability
            return True
        except ImportError:
            logger.warning("nemo_toolkit not available")
            return False

    def get_available_models(self) -> list[str]:
        """
        Get list of available Parakeet models.

        Returns:
            list[str]: List of available model names (a fresh list per call)
        """
        return list(self._MODEL_IDS)

    def get_default_model(self) -> str:
        """
        Get the default model for this provider.

        Returns:
            str: Default model name
        """
        return self._DEFAULT_MODEL