File size: 3,853 Bytes
1be582a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
fdc056d
1be582a
 
 
fdc056d
1be582a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
fdc056d
1be582a
fdc056d
1be582a
 
 
 
 
 
 
fdc056d
1be582a
fdc056d
1be582a
 
fdc056d
1be582a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
fdc056d
1be582a
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
"""Parakeet STT provider implementation."""

import logging
from pathlib import Path
from typing import TYPE_CHECKING

if TYPE_CHECKING:
    from ...domain.models.audio_content import AudioContent
    from ...domain.models.text_content import TextContent

from ..base.stt_provider_base import STTProviderBase
from ...domain.exceptions import SpeechRecognitionException

# Module-level logger, named after this module's import path.
logger = logging.getLogger(__name__)


class ParakeetSTTProvider(STTProviderBase):
    """Parakeet STT provider using NVIDIA NeMo implementation.

    The NeMo model is loaded lazily on the first transcription request and
    kept on the instance. If a later request asks for a different model, the
    model is reloaded so the transcription really uses what the caller
    requested.
    """

    # Name returned by ``get_default_model`` and used as the fallback when an
    # unknown (or the literal "default") model name is requested.
    _DEFAULT_MODEL = "parakeet-tdt-0.6b-v2"

    # Single source of truth: short model name -> pretrained identifier handed
    # to ``ASRModel.from_pretrained``. Insertion order is the order reported
    # by ``get_available_models``.
    _MODEL_IDS = {
        "parakeet-tdt-0.6b-v2": "nvidia/parakeet-tdt-0.6b-v2",
        "parakeet-tdt-1.1b": "nvidia/parakeet-tdt-1.1b",
        "parakeet-ctc-0.6b": "nvidia/parakeet-ctc-0.6b",
    }

    def __init__(self) -> None:
        """Initialize the Parakeet STT provider."""
        super().__init__(
            provider_name="Parakeet",
            supported_languages=["en"],  # Parakeet primarily supports English
        )
        # Loaded NeMo ASR model; None until the first transcription.
        self.model = None
        # Pretrained identifier of the model held in ``self.model``. Stays
        # None when ``self.model`` was assigned from outside ``_load_model``
        # (e.g. injected by a caller), in which case it is never replaced.
        self._loaded_model_id = None

    def _resolve_model_id(self, model_name: str) -> str:
        """Map a short model name to its pretrained identifier.

        Unknown names (and the literal ``"default"``) resolve to the default
        model's identifier.
        """
        return self._MODEL_IDS.get(
            model_name, self._MODEL_IDS[self._DEFAULT_MODEL]
        )

    def _perform_transcription(self, audio_path: Path, model: str) -> str:
        """
        Perform transcription using Parakeet.

        Args:
            audio_path: Path to the preprocessed audio file
            model: The Parakeet model to use

        Returns:
            str: The transcribed text ("" when the model returns no output)

        Note:
            Any exception is forwarded to ``_handle_provider_error``.
            NOTE(review): that helper lives in the base class and presumably
            raises; if it ever returns, this method returns None.
        """
        try:
            requested_id = self._resolve_model_id(model)

            # Load on first use, and reload when a different model than the
            # one we loaded is requested. A model assigned externally
            # (``_loaded_model_id`` is None) is left in place.
            needs_load = self.model is None or (
                self._loaded_model_id is not None
                and self._loaded_model_id != requested_id
            )
            if needs_load:
                self._load_model(model)

            logger.info("Starting Parakeet transcription with model %s", model)

            # Perform transcription
            output = self.model.transcribe([str(audio_path)])
            if not output:
                result = ""
            else:
                first = output[0]
                # The result is usually an object exposing ``.text``; plain
                # strings are accepted too so both result shapes work.
                result = first if isinstance(first, str) else first.text

            logger.info("Parakeet transcription completed successfully")
            return result

        except Exception as e:
            self._handle_provider_error(e, "transcription")

    def _load_model(self, model_name: str) -> None:
        """
        Load the Parakeet model.

        Args:
            model_name: Name of the model to load. Unknown names fall back to
                the default model (a warning is logged).

        Raises:
            SpeechRecognitionException: If nemo_toolkit cannot be imported or
                the model fails to load.
        """
        # Keep the import guard narrow so that only a missing toolkit yields
        # the "not available" message; failures while loading the model are
        # reported separately below.
        try:
            import nemo.collections.asr as nemo_asr
        except ImportError as e:
            raise SpeechRecognitionException(
                "nemo_toolkit not available. Please install with: pip install -U 'nemo_toolkit[asr]'"
            ) from e

        logger.info("Loading Parakeet model: %s", model_name)

        if model_name != "default" and model_name not in self._MODEL_IDS:
            logger.warning(
                "Unknown Parakeet model %r; falling back to %s",
                model_name,
                self._DEFAULT_MODEL,
            )
        actual_model_name = self._resolve_model_id(model_name)

        try:
            loaded = nemo_asr.models.ASRModel.from_pretrained(
                model_name=actual_model_name
            )
        except Exception as e:
            raise SpeechRecognitionException(
                f"Failed to load Parakeet model {model_name}: {e}"
            ) from e

        # Only replace the current model once the new one loaded successfully.
        self.model = loaded
        self._loaded_model_id = actual_model_name
        logger.info("Parakeet model %s loaded successfully", model_name)

    def is_available(self) -> bool:
        """
        Check if the Parakeet provider is available.

        Returns:
            bool: True if nemo_toolkit is available, False otherwise
        """
        try:
            import nemo.collections.asr  # noqa: F401  (import is the check)
        except ImportError:
            logger.warning("nemo_toolkit not available")
            return False
        return True

    def get_available_models(self) -> list[str]:
        """
        Get list of available Parakeet models.

        Returns:
            list[str]: List of available model names
        """
        return list(self._MODEL_IDS)

    def get_default_model(self) -> str:
        """
        Get the default model for this provider.

        Returns:
            str: Default model name
        """
        return self._DEFAULT_MODEL