Spaces:
Build error
Build error
File size: 3,853 Bytes
1be582a fdc056d 1be582a fdc056d 1be582a fdc056d 1be582a fdc056d 1be582a fdc056d 1be582a fdc056d 1be582a fdc056d 1be582a fdc056d 1be582a |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 |
"""Parakeet STT provider implementation."""
import logging
from pathlib import Path
from typing import TYPE_CHECKING
if TYPE_CHECKING:
from ...domain.models.audio_content import AudioContent
from ...domain.models.text_content import TextContent
from ..base.stt_provider_base import STTProviderBase
from ...domain.exceptions import SpeechRecognitionException
# Module-level logger, named after this module via __name__.
logger = logging.getLogger(__name__)
class ParakeetSTTProvider(STTProviderBase):
    """Parakeet STT provider using NVIDIA NeMo implementation.

    The NeMo model is loaded lazily on the first transcription and cached on
    ``self.model``. The checkpoint identifier of the cached model is tracked
    so that a later request for a different model triggers a reload instead
    of silently reusing the previously loaded one.
    """

    # Short model names accepted by this provider, mapped to the pretrained
    # checkpoint identifiers handed to NeMo. Insertion order is the order
    # reported by get_available_models().
    _MODEL_CHECKPOINTS = {
        "parakeet-tdt-0.6b-v2": "nvidia/parakeet-tdt-0.6b-v2",
        "parakeet-tdt-1.1b": "nvidia/parakeet-tdt-1.1b",
        "parakeet-ctc-0.6b": "nvidia/parakeet-ctc-0.6b",
    }

    # Model used when the caller passes "default" or an unrecognised name.
    _DEFAULT_MODEL = "parakeet-tdt-0.6b-v2"

    def __init__(self):
        """Initialize the Parakeet STT provider without loading any model."""
        super().__init__(
            provider_name="Parakeet",
            supported_languages=["en"],  # Parakeet primarily supports English
        )
        # Loaded NeMo model instance; None until the first load.
        self.model = None
        # Checkpoint identifier of the model held in self.model; None when
        # nothing has been loaded through _load_model().
        self._loaded_checkpoint = None

    def _resolve_checkpoint(self, model_name: str) -> str:
        """Map a short model name to its checkpoint identifier.

        Args:
            model_name: Short model name as given by the caller

        Returns:
            str: The matching checkpoint identifier, or the default model's
                identifier when the name is not in the table
        """
        return self._MODEL_CHECKPOINTS.get(
            model_name, self._MODEL_CHECKPOINTS[self._DEFAULT_MODEL]
        )

    def _perform_transcription(self, audio_path: Path, model: str) -> str:
        """
        Perform transcription using Parakeet.

        Args:
            audio_path: Path to the preprocessed audio file
            model: The Parakeet model to use

        Returns:
            str: The transcribed text, or an empty string when the model
                returns no results
        """
        try:
            requested = self._resolve_checkpoint(model)
            loaded = getattr(self, "_loaded_checkpoint", None)

            # Load when nothing is cached yet, or when the cached model is
            # known to be a different checkpoint than the one requested.
            # A model assigned to self.model from outside (checkpoint
            # unknown) is used as-is.
            if self.model is None or (loaded is not None and loaded != requested):
                self._load_model(model)

            logger.info("Starting Parakeet transcription with model %s", model)

            # transcribe() takes a list of file paths and returns one result
            # per input file.
            output = self.model.transcribe([str(audio_path)])

            if not output:
                result = ""
            else:
                first = output[0]
                # Accept both result objects exposing ``.text`` and plain
                # string results.
                result = first if isinstance(first, str) else first.text

            logger.info("Parakeet transcription completed successfully")
            return result
        except Exception as e:
            # NOTE(review): _handle_provider_error is not defined in this
            # class; presumably it raises. If it returns normally, this
            # method returns None -- confirm against the base class.
            self._handle_provider_error(e, "transcription")

    def _load_model(self, model_name: str):
        """
        Load the Parakeet model and cache it on ``self.model``.

        Unrecognised names fall back to the default model; a warning is
        logged when that happens.

        Args:
            model_name: Name of the model to load

        Raises:
            SpeechRecognitionException: If nemo_toolkit cannot be imported
                or the model fails to load
        """
        try:
            # Imported lazily so the module can be imported without NeMo
            # installed.
            import nemo.collections.asr as nemo_asr

            logger.info("Loading Parakeet model: %s", model_name)

            if model_name != "default" and model_name not in self._MODEL_CHECKPOINTS:
                logger.warning(
                    "Unknown Parakeet model '%s'; falling back to '%s'",
                    model_name,
                    self._DEFAULT_MODEL,
                )

            checkpoint = self._resolve_checkpoint(model_name)
            self.model = nemo_asr.models.ASRModel.from_pretrained(model_name=checkpoint)
            # Record what is loaded only after the load succeeded, so a
            # failed reload leaves the previous model and its tag intact.
            self._loaded_checkpoint = checkpoint

            logger.info("Parakeet model %s loaded successfully", model_name)
        except ImportError as e:
            raise SpeechRecognitionException(
                "nemo_toolkit not available. Please install with: pip install -U 'nemo_toolkit[asr]'"
            ) from e
        except Exception as e:
            raise SpeechRecognitionException(
                f"Failed to load Parakeet model {model_name}: {e}"
            ) from e

    def is_available(self) -> bool:
        """
        Check if the Parakeet provider is available.

        Returns:
            bool: True if nemo_toolkit is available, False otherwise
        """
        try:
            import nemo.collections.asr  # noqa: F401  (import probe only)
            return True
        except ImportError:
            logger.warning("nemo_toolkit not available")
            return False

    def get_available_models(self) -> list[str]:
        """
        Get list of available Parakeet models.

        Returns:
            list[str]: List of available model names
        """
        return list(self._MODEL_CHECKPOINTS)

    def get_default_model(self) -> str:
        """
        Get the default model for this provider.

        Returns:
            str: Default model name
        """
        return self._DEFAULT_MODEL