Michael Hu commited on
Commit
1be582a
Β·
1 Parent(s): 1f9c751

Migrate existing STT providers to infrastructure layer

Browse files
src/infrastructure/base/stt_provider_base.py CHANGED
@@ -114,6 +114,16 @@ class STTProviderBase(ISpeechRecognitionService, ABC):
114
  """
115
  pass
116
 
 
 
 
 
 
 
 
 
 
 
117
  def _preprocess_audio(self, audio: 'AudioContent') -> Path:
118
  """
119
  Preprocess audio content for transcription.
 
114
  """
115
  pass
116
 
117
+ @abstractmethod
118
+ def get_default_model(self) -> str:
119
+ """
120
+ Get the default model for this provider.
121
+
122
+ Returns:
123
+ str: Default model name
124
+ """
125
+ pass
126
+
127
  def _preprocess_audio(self, audio: 'AudioContent') -> Path:
128
  """
129
  Preprocess audio content for transcription.
src/infrastructure/stt/__init__.py ADDED
@@ -0,0 +1,15 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """STT provider implementations."""
2
+
3
+ from .whisper_provider import WhisperSTTProvider
4
+ from .parakeet_provider import ParakeetSTTProvider
5
+ from .provider_factory import STTProviderFactory, ASRFactory
6
+ from .legacy_compatibility import transcribe_audio, create_audio_content_from_file
7
+
8
+ __all__ = [
9
+ 'WhisperSTTProvider',
10
+ 'ParakeetSTTProvider',
11
+ 'STTProviderFactory',
12
+ 'ASRFactory',
13
+ 'transcribe_audio',
14
+ 'create_audio_content_from_file'
15
+ ]
src/infrastructure/stt/legacy_compatibility.py ADDED
@@ -0,0 +1,150 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Legacy compatibility functions for STT functionality."""
2
+
3
+ import logging
4
+ from pathlib import Path
5
+ from typing import Union
6
+
7
+ from .provider_factory import STTProviderFactory
8
+ from ...domain.models.audio_content import AudioContent
9
+ from ...domain.exceptions import SpeechRecognitionException
10
+
11
+ logger = logging.getLogger(__name__)
12
+
13
+
14
+ def transcribe_audio(audio_path: Union[str, Path], model_name: str = "parakeet") -> str:
15
+ """
16
+ Convert audio file to text using specified STT model (legacy interface).
17
+
18
+ This function maintains backward compatibility with the original utils/stt.py interface.
19
+
20
+ Args:
21
+ audio_path: Path to input audio file
22
+ model_name: Name of the STT model/provider to use (whisper or parakeet)
23
+
24
+ Returns:
25
+ str: Transcribed English text
26
+
27
+ Raises:
28
+ SpeechRecognitionException: If transcription fails
29
+ """
30
+ logger.info(f"Starting transcription for: {audio_path} using {model_name} model")
31
+
32
+ try:
33
+ # Convert path to Path object
34
+ audio_path = Path(audio_path)
35
+
36
+ if not audio_path.exists():
37
+ raise SpeechRecognitionException(f"Audio file not found: {audio_path}")
38
+
39
+ # Read audio file and create AudioContent
40
+ with open(audio_path, 'rb') as f:
41
+ audio_data = f.read()
42
+
43
+ # Determine audio format from file extension
44
+ audio_format = audio_path.suffix.lower().lstrip('.')
45
+ if audio_format not in ['wav', 'mp3', 'flac', 'ogg']:
46
+ audio_format = 'wav' # Default fallback
47
+
48
+ # Create AudioContent (we'll use reasonable placeholder values)
49
+ # The provider will handle the actual audio analysis during preprocessing
50
+ try:
51
+ audio_content = AudioContent(
52
+ data=audio_data,
53
+ format=audio_format,
54
+ sample_rate=16000, # Standard rate for STT
55
+ duration=max(1.0, len(audio_data) / (16000 * 2)), # Rough estimate
56
+ filename=audio_path.name
57
+ )
58
+ except ValueError:
59
+ # If validation fails, try with minimal valid values
60
+ audio_content = AudioContent(
61
+ data=audio_data,
62
+ format=audio_format,
63
+ sample_rate=16000,
64
+ duration=1.0, # Minimum valid duration
65
+ filename=audio_path.name
66
+ )
67
+
68
+ # Get the appropriate provider
69
+ try:
70
+ provider = STTProviderFactory.create_provider(model_name)
71
+ except SpeechRecognitionException:
72
+ # Fallback to any available provider
73
+ logger.warning(f"Requested provider {model_name} not available, using fallback")
74
+ provider = STTProviderFactory.create_provider_with_fallback(model_name)
75
+
76
+ # Get the default model for the provider
77
+ model = provider.get_default_model()
78
+
79
+ # Transcribe audio
80
+ text_content = provider.transcribe(audio_content, model)
81
+ result = text_content.text
82
+
83
+ logger.info(f"Transcription completed: {result}")
84
+ return result
85
+
86
+ except Exception as e:
87
+ logger.error(f"Transcription failed: {str(e)}", exc_info=True)
88
+ raise SpeechRecognitionException(f"Transcription failed: {str(e)}") from e
89
+
90
+
91
+ def create_audio_content_from_file(audio_path: Union[str, Path]) -> AudioContent:
92
+ """
93
+ Create AudioContent from an audio file with proper metadata detection.
94
+
95
+ Args:
96
+ audio_path: Path to the audio file
97
+
98
+ Returns:
99
+ AudioContent: The audio content object
100
+
101
+ Raises:
102
+ SpeechRecognitionException: If file cannot be processed
103
+ """
104
+ try:
105
+ from pydub import AudioSegment
106
+
107
+ audio_path = Path(audio_path)
108
+
109
+ # Load audio file to get metadata
110
+ audio_segment = AudioSegment.from_file(audio_path)
111
+
112
+ # Read raw audio data
113
+ with open(audio_path, 'rb') as f:
114
+ audio_data = f.read()
115
+
116
+ # Determine format
117
+ audio_format = audio_path.suffix.lower().lstrip('.')
118
+ if audio_format not in ['wav', 'mp3', 'flac', 'ogg']:
119
+ audio_format = 'wav'
120
+
121
+ # Create AudioContent with actual metadata
122
+ return AudioContent(
123
+ data=audio_data,
124
+ format=audio_format,
125
+ sample_rate=audio_segment.frame_rate,
126
+ duration=len(audio_segment) / 1000.0, # Convert ms to seconds
127
+ filename=audio_path.name
128
+ )
129
+
130
+ except ImportError:
131
+ # Fallback without pydub
132
+ logger.warning("pydub not available, using placeholder metadata")
133
+
134
+ with open(audio_path, 'rb') as f:
135
+ audio_data = f.read()
136
+
137
+ audio_format = Path(audio_path).suffix.lower().lstrip('.')
138
+ if audio_format not in ['wav', 'mp3', 'flac', 'ogg']:
139
+ audio_format = 'wav'
140
+
141
+ return AudioContent(
142
+ data=audio_data,
143
+ format=audio_format,
144
+ sample_rate=16000, # Default
145
+ duration=1.0, # Placeholder
146
+ filename=Path(audio_path).name
147
+ )
148
+
149
+ except Exception as e:
150
+ raise SpeechRecognitionException(f"Failed to create AudioContent from file: {str(e)}") from e
src/infrastructure/stt/parakeet_provider.py ADDED
@@ -0,0 +1,122 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Parakeet STT provider implementation."""
2
+
3
+ import logging
4
+ from pathlib import Path
5
+ from typing import TYPE_CHECKING
6
+
7
+ if TYPE_CHECKING:
8
+ from ...domain.models.audio_content import AudioContent
9
+ from ...domain.models.text_content import TextContent
10
+
11
+ from ..base.stt_provider_base import STTProviderBase
12
+ from ...domain.exceptions import SpeechRecognitionException
13
+
14
+ logger = logging.getLogger(__name__)
15
+
16
+
17
+ class ParakeetSTTProvider(STTProviderBase):
18
+ """Parakeet STT provider using NVIDIA NeMo implementation."""
19
+
20
+ def __init__(self):
21
+ """Initialize the Parakeet STT provider."""
22
+ super().__init__(
23
+ provider_name="Parakeet",
24
+ supported_languages=["en"] # Parakeet primarily supports English
25
+ )
26
+ self.model = None
27
+
28
+ def _perform_transcription(self, audio_path: Path, model: str) -> str:
29
+ """
30
+ Perform transcription using Parakeet.
31
+
32
+ Args:
33
+ audio_path: Path to the preprocessed audio file
34
+ model: The Parakeet model to use
35
+
36
+ Returns:
37
+ str: The transcribed text
38
+ """
39
+ try:
40
+ # Load model if not already loaded
41
+ if self.model is None:
42
+ self._load_model(model)
43
+
44
+ logger.info(f"Starting Parakeet transcription with model {model}")
45
+
46
+ # Perform transcription
47
+ output = self.model.transcribe([str(audio_path)])
48
+ result = output[0].text if output and len(output) > 0 else ""
49
+
50
+ logger.info("Parakeet transcription completed successfully")
51
+ return result
52
+
53
+ except Exception as e:
54
+ self._handle_provider_error(e, "transcription")
55
+
56
+ def _load_model(self, model_name: str):
57
+ """
58
+ Load the Parakeet model.
59
+
60
+ Args:
61
+ model_name: Name of the model to load
62
+ """
63
+ try:
64
+ import nemo.collections.asr as nemo_asr
65
+
66
+ logger.info(f"Loading Parakeet model: {model_name}")
67
+
68
+ # Map model names to actual model identifiers
69
+ model_mapping = {
70
+ "parakeet-tdt-0.6b-v2": "nvidia/parakeet-tdt-0.6b-v2",
71
+ "parakeet-tdt-1.1b": "nvidia/parakeet-tdt-1.1b",
72
+ "parakeet-ctc-0.6b": "nvidia/parakeet-ctc-0.6b",
73
+ "default": "nvidia/parakeet-tdt-0.6b-v2"
74
+ }
75
+
76
+ actual_model_name = model_mapping.get(model_name, model_mapping["default"])
77
+
78
+ self.model = nemo_asr.models.ASRModel.from_pretrained(model_name=actual_model_name)
79
+ logger.info(f"Parakeet model {model_name} loaded successfully")
80
+
81
+ except ImportError as e:
82
+ raise SpeechRecognitionException(
83
+ "nemo_toolkit not available. Please install with: pip install -U 'nemo_toolkit[asr]'"
84
+ ) from e
85
+ except Exception as e:
86
+ raise SpeechRecognitionException(f"Failed to load Parakeet model {model_name}: {str(e)}") from e
87
+
88
+ def is_available(self) -> bool:
89
+ """
90
+ Check if the Parakeet provider is available.
91
+
92
+ Returns:
93
+ bool: True if nemo_toolkit is available, False otherwise
94
+ """
95
+ try:
96
+ import nemo.collections.asr
97
+ return True
98
+ except ImportError:
99
+ logger.warning("nemo_toolkit not available")
100
+ return False
101
+
102
+ def get_available_models(self) -> list[str]:
103
+ """
104
+ Get list of available Parakeet models.
105
+
106
+ Returns:
107
+ list[str]: List of available model names
108
+ """
109
+ return [
110
+ "parakeet-tdt-0.6b-v2",
111
+ "parakeet-tdt-1.1b",
112
+ "parakeet-ctc-0.6b"
113
+ ]
114
+
115
+ def get_default_model(self) -> str:
116
+ """
117
+ Get the default model for this provider.
118
+
119
+ Returns:
120
+ str: Default model name
121
+ """
122
+ return "parakeet-tdt-0.6b-v2"
src/infrastructure/stt/provider_factory.py ADDED
@@ -0,0 +1,189 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Factory for creating STT provider instances."""
2
+
3
+ import logging
4
+ from typing import Dict, Type, Optional
5
+
6
+ from ..base.stt_provider_base import STTProviderBase
7
+ from .whisper_provider import WhisperSTTProvider
8
+ from .parakeet_provider import ParakeetSTTProvider
9
+ from ...domain.exceptions import SpeechRecognitionException
10
+
11
+ logger = logging.getLogger(__name__)
12
+
13
+
14
+ class STTProviderFactory:
15
+ """Factory for creating STT provider instances with availability checking and fallback logic."""
16
+
17
+ _providers: Dict[str, Type[STTProviderBase]] = {
18
+ "whisper": WhisperSTTProvider,
19
+ "parakeet": ParakeetSTTProvider
20
+ }
21
+
22
+ _fallback_order = ["whisper", "parakeet"]
23
+
24
+ @classmethod
25
+ def create_provider(cls, provider_name: str) -> STTProviderBase:
26
+ """
27
+ Create an STT provider instance by name.
28
+
29
+ Args:
30
+ provider_name: Name of the provider to create
31
+
32
+ Returns:
33
+ STTProviderBase: The created provider instance
34
+
35
+ Raises:
36
+ SpeechRecognitionException: If provider is not available or creation fails
37
+ """
38
+ provider_name = provider_name.lower()
39
+
40
+ if provider_name not in cls._providers:
41
+ raise SpeechRecognitionException(f"Unknown STT provider: {provider_name}")
42
+
43
+ provider_class = cls._providers[provider_name]
44
+
45
+ try:
46
+ provider = provider_class()
47
+
48
+ if not provider.is_available():
49
+ raise SpeechRecognitionException(f"STT provider {provider_name} is not available")
50
+
51
+ logger.info(f"Created STT provider: {provider_name}")
52
+ return provider
53
+
54
+ except Exception as e:
55
+ logger.error(f"Failed to create STT provider {provider_name}: {str(e)}")
56
+ raise SpeechRecognitionException(f"Failed to create STT provider {provider_name}: {str(e)}") from e
57
+
58
+ @classmethod
59
+ def create_provider_with_fallback(cls, preferred_provider: str) -> STTProviderBase:
60
+ """
61
+ Create an STT provider with fallback to other available providers.
62
+
63
+ Args:
64
+ preferred_provider: The preferred provider name
65
+
66
+ Returns:
67
+ STTProviderBase: The created provider instance
68
+
69
+ Raises:
70
+ SpeechRecognitionException: If no providers are available
71
+ """
72
+ # Try preferred provider first
73
+ try:
74
+ return cls.create_provider(preferred_provider)
75
+ except SpeechRecognitionException as e:
76
+ logger.warning(f"Preferred STT provider {preferred_provider} failed: {str(e)}")
77
+
78
+ # Try fallback providers
79
+ for provider_name in cls._fallback_order:
80
+ if provider_name.lower() == preferred_provider.lower():
81
+ continue # Skip the preferred provider we already tried
82
+
83
+ try:
84
+ logger.info(f"Trying fallback STT provider: {provider_name}")
85
+ return cls.create_provider(provider_name)
86
+ except SpeechRecognitionException as e:
87
+ logger.warning(f"Fallback STT provider {provider_name} failed: {str(e)}")
88
+ continue
89
+
90
+ raise SpeechRecognitionException("No STT providers are available")
91
+
92
+ @classmethod
93
+ def get_available_providers(cls) -> list[str]:
94
+ """
95
+ Get list of available STT providers.
96
+
97
+ Returns:
98
+ list[str]: List of available provider names
99
+ """
100
+ available = []
101
+
102
+ for provider_name, provider_class in cls._providers.items():
103
+ try:
104
+ provider = provider_class()
105
+ if provider.is_available():
106
+ available.append(provider_name)
107
+ except Exception as e:
108
+ logger.debug(f"Provider {provider_name} not available: {str(e)}")
109
+
110
+ return available
111
+
112
+ @classmethod
113
+ def get_provider_info(cls, provider_name: str) -> Optional[dict]:
114
+ """
115
+ Get information about a specific provider.
116
+
117
+ Args:
118
+ provider_name: Name of the provider
119
+
120
+ Returns:
121
+ Optional[dict]: Provider information or None if not found
122
+ """
123
+ provider_name = provider_name.lower()
124
+
125
+ if provider_name not in cls._providers:
126
+ return None
127
+
128
+ provider_class = cls._providers[provider_name]
129
+
130
+ try:
131
+ provider = provider_class()
132
+ return {
133
+ "name": provider.provider_name,
134
+ "available": provider.is_available(),
135
+ "supported_languages": provider.supported_languages,
136
+ "available_models": provider.get_available_models() if provider.is_available() else [],
137
+ "default_model": provider.get_default_model() if provider.is_available() else None
138
+ }
139
+ except Exception as e:
140
+ logger.debug(f"Failed to get info for provider {provider_name}: {str(e)}")
141
+ return {
142
+ "name": provider_name,
143
+ "available": False,
144
+ "error": str(e)
145
+ }
146
+
147
+ @classmethod
148
+ def register_provider(cls, name: str, provider_class: Type[STTProviderBase]) -> None:
149
+ """
150
+ Register a new STT provider.
151
+
152
+ Args:
153
+ name: Name of the provider
154
+ provider_class: The provider class
155
+ """
156
+ cls._providers[name.lower()] = provider_class
157
+ logger.info(f"Registered STT provider: {name}")
158
+
159
+
160
+ # Legacy compatibility - create an ASRFactory alias
161
+ class ASRFactory:
162
+ """Legacy ASRFactory for backward compatibility."""
163
+
164
+ @staticmethod
165
+ def get_model(model_name: str = "parakeet") -> STTProviderBase:
166
+ """
167
+ Get STT provider by model name (legacy interface).
168
+
169
+ Args:
170
+ model_name: Name of the model/provider to use
171
+
172
+ Returns:
173
+ STTProviderBase: The provider instance
174
+ """
175
+ # Map legacy model names to provider names
176
+ provider_mapping = {
177
+ "whisper": "whisper",
178
+ "parakeet": "parakeet",
179
+ "faster-whisper": "whisper"
180
+ }
181
+
182
+ provider_name = provider_mapping.get(model_name.lower(), model_name.lower())
183
+
184
+ try:
185
+ return STTProviderFactory.create_provider(provider_name)
186
+ except SpeechRecognitionException:
187
+ # Fallback to any available provider
188
+ logger.warning(f"Requested provider {provider_name} not available, using fallback")
189
+ return STTProviderFactory.create_provider_with_fallback(provider_name)
src/infrastructure/stt/whisper_provider.py ADDED
@@ -0,0 +1,154 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Whisper STT provider implementation."""
2
+
3
+ import logging
4
+ from pathlib import Path
5
+ from typing import TYPE_CHECKING
6
+
7
+ if TYPE_CHECKING:
8
+ from ...domain.models.audio_content import AudioContent
9
+ from ...domain.models.text_content import TextContent
10
+
11
+ from ..base.stt_provider_base import STTProviderBase
12
+ from ...domain.exceptions import SpeechRecognitionException
13
+
14
+ logger = logging.getLogger(__name__)
15
+
16
+
17
+ class WhisperSTTProvider(STTProviderBase):
18
+ """Whisper STT provider using faster-whisper implementation."""
19
+
20
+ def __init__(self):
21
+ """Initialize the Whisper STT provider."""
22
+ super().__init__(
23
+ provider_name="Whisper",
24
+ supported_languages=["en", "es", "fr", "de", "it", "pt", "ru", "ja", "ko", "zh"]
25
+ )
26
+ self.model = None
27
+ self._device = None
28
+ self._compute_type = None
29
+ self._initialize_device_settings()
30
+
31
+ def _initialize_device_settings(self):
32
+ """Initialize device and compute type settings."""
33
+ try:
34
+ import torch
35
+ self._device = "cuda" if torch.cuda.is_available() else "cpu"
36
+ except ImportError:
37
+ # Fallback to CPU if torch is not available
38
+ self._device = "cpu"
39
+
40
+ self._compute_type = "float16" if self._device == "cuda" else "int8"
41
+ logger.info(f"Whisper provider initialized with device: {self._device}, compute_type: {self._compute_type}")
42
+
43
+ def _perform_transcription(self, audio_path: Path, model: str) -> str:
44
+ """
45
+ Perform transcription using Faster Whisper.
46
+
47
+ Args:
48
+ audio_path: Path to the preprocessed audio file
49
+ model: The Whisper model to use (e.g., 'large-v3', 'medium', 'small')
50
+
51
+ Returns:
52
+ str: The transcribed text
53
+ """
54
+ try:
55
+ # Load model if not already loaded or if model changed
56
+ if self.model is None or getattr(self.model, 'model_size_or_path', None) != model:
57
+ self._load_model(model)
58
+
59
+ logger.info(f"Starting Whisper transcription with model {model}")
60
+
61
+ # Perform transcription
62
+ segments, info = self.model.transcribe(
63
+ str(audio_path),
64
+ beam_size=5,
65
+ language="en", # Can be made configurable
66
+ task="transcribe"
67
+ )
68
+
69
+ logger.info(f"Detected language '{info.language}' with probability {info.language_probability}")
70
+
71
+ # Collect all segments into a single text
72
+ result_text = ""
73
+ for segment in segments:
74
+ result_text += segment.text + " "
75
+ logger.debug(f"[{segment.start:.2f}s -> {segment.end:.2f}s] {segment.text}")
76
+
77
+ result = result_text.strip()
78
+ logger.info("Whisper transcription completed successfully")
79
+ return result
80
+
81
+ except Exception as e:
82
+ self._handle_provider_error(e, "transcription")
83
+
84
+ def _load_model(self, model_name: str):
85
+ """
86
+ Load the Whisper model.
87
+
88
+ Args:
89
+ model_name: Name of the model to load
90
+ """
91
+ try:
92
+ from faster_whisper import WhisperModel as FasterWhisperModel
93
+
94
+ logger.info(f"Loading Whisper model: {model_name}")
95
+ logger.info(f"Using device: {self._device}, compute_type: {self._compute_type}")
96
+
97
+ self.model = FasterWhisperModel(
98
+ model_name,
99
+ device=self._device,
100
+ compute_type=self._compute_type
101
+ )
102
+
103
+ logger.info(f"Whisper model {model_name} loaded successfully")
104
+
105
+ except ImportError as e:
106
+ raise SpeechRecognitionException(
107
+ "faster-whisper not available. Please install with: pip install faster-whisper"
108
+ ) from e
109
+ except Exception as e:
110
+ raise SpeechRecognitionException(f"Failed to load Whisper model {model_name}: {str(e)}") from e
111
+
112
+ def is_available(self) -> bool:
113
+ """
114
+ Check if the Whisper provider is available.
115
+
116
+ Returns:
117
+ bool: True if faster-whisper is available, False otherwise
118
+ """
119
+ try:
120
+ import faster_whisper
121
+ return True
122
+ except ImportError:
123
+ logger.warning("faster-whisper not available")
124
+ return False
125
+
126
+ def get_available_models(self) -> list[str]:
127
+ """
128
+ Get list of available Whisper models.
129
+
130
+ Returns:
131
+ list[str]: List of available model names
132
+ """
133
+ return [
134
+ "tiny",
135
+ "tiny.en",
136
+ "base",
137
+ "base.en",
138
+ "small",
139
+ "small.en",
140
+ "medium",
141
+ "medium.en",
142
+ "large-v1",
143
+ "large-v2",
144
+ "large-v3"
145
+ ]
146
+
147
+ def get_default_model(self) -> str:
148
+ """
149
+ Get the default model for this provider.
150
+
151
+ Returns:
152
+ str: Default model name
153
+ """
154
+ return "large-v3"
test_stt_migration.py ADDED
@@ -0,0 +1,121 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """Test script for STT migration."""
3
+
4
+ import sys
5
+ import logging
6
+ from pathlib import Path
7
+
8
+ # Add src to path
9
+ sys.path.insert(0, str(Path(__file__).parent / "src"))
10
+
11
+ # Configure logging
12
+ logging.basicConfig(level=logging.INFO)
13
+ logger = logging.getLogger(__name__)
14
+
15
+ def test_provider_availability():
16
+ """Test that providers can be imported and checked for availability."""
17
+ try:
18
+ from infrastructure.stt import STTProviderFactory, WhisperSTTProvider, ParakeetSTTProvider
19
+
20
+ print("βœ“ Successfully imported STT providers")
21
+
22
+ # Test factory
23
+ available_providers = STTProviderFactory.get_available_providers()
24
+ print(f"Available providers: {available_providers}")
25
+
26
+ # Test individual providers
27
+ whisper = WhisperSTTProvider()
28
+ print(f"Whisper available: {whisper.is_available()}")
29
+ print(f"Whisper models: {whisper.get_available_models()}")
30
+ print(f"Whisper default model: {whisper.get_default_model()}")
31
+
32
+ parakeet = ParakeetSTTProvider()
33
+ print(f"Parakeet available: {parakeet.is_available()}")
34
+ print(f"Parakeet models: {parakeet.get_available_models()}")
35
+ print(f"Parakeet default model: {parakeet.get_default_model()}")
36
+
37
+ return True
38
+
39
+ except Exception as e:
40
+ print(f"βœ— Error testing providers: {e}")
41
+ import traceback
42
+ traceback.print_exc()
43
+ return False
44
+
45
+ def test_legacy_compatibility():
46
+ """Test legacy compatibility functions."""
47
+ try:
48
+ from infrastructure.stt import transcribe_audio, ASRFactory
49
+
50
+ print("βœ“ Successfully imported legacy compatibility functions")
51
+
52
+ # Test ASRFactory
53
+ try:
54
+ model = ASRFactory.get_model("whisper")
55
+ print(f"βœ“ ASRFactory created model: {model.provider_name}")
56
+ except Exception as e:
57
+ print(f"ASRFactory test failed (expected if dependencies missing): {e}")
58
+
59
+ return True
60
+
61
+ except Exception as e:
62
+ print(f"βœ— Error testing legacy compatibility: {e}")
63
+ import traceback
64
+ traceback.print_exc()
65
+ return False
66
+
67
+ def test_domain_integration():
68
+ """Test integration with domain models."""
69
+ try:
70
+ from domain.models.audio_content import AudioContent
71
+ from domain.models.text_content import TextContent
72
+ from domain.exceptions import SpeechRecognitionException
73
+
74
+ print("βœ“ Successfully imported domain models")
75
+
76
+ # Create test audio content
77
+ test_audio = AudioContent(
78
+ data=b"fake audio data for testing",
79
+ format="wav",
80
+ sample_rate=16000,
81
+ duration=1.0,
82
+ filename="test.wav"
83
+ )
84
+
85
+ print(f"βœ“ Created test AudioContent: {test_audio.filename}")
86
+
87
+ return True
88
+
89
+ except Exception as e:
90
+ print(f"βœ— Error testing domain integration: {e}")
91
+ import traceback
92
+ traceback.print_exc()
93
+ return False
94
+
95
+ if __name__ == "__main__":
96
+ print("Testing STT migration...")
97
+ print("=" * 50)
98
+
99
+ tests = [
100
+ ("Provider Availability", test_provider_availability),
101
+ ("Legacy Compatibility", test_legacy_compatibility),
102
+ ("Domain Integration", test_domain_integration)
103
+ ]
104
+
105
+ results = []
106
+ for test_name, test_func in tests:
107
+ print(f"\n{test_name}:")
108
+ print("-" * 30)
109
+ result = test_func()
110
+ results.append((test_name, result))
111
+
112
+ print("\n" + "=" * 50)
113
+ print("Test Results:")
114
+ for test_name, result in results:
115
+ status = "βœ“ PASS" if result else "βœ— FAIL"
116
+ print(f"{test_name}: {status}")
117
+
118
+ all_passed = all(result for _, result in results)
119
+ print(f"\nOverall: {'βœ“ ALL TESTS PASSED' if all_passed else 'βœ— SOME TESTS FAILED'}")
120
+
121
+ sys.exit(0 if all_passed else 1)