File size: 10,811 Bytes
e3cb97b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1be582a
 
 
 
 
 
 
 
 
 
e3cb97b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
fdc056d
e3cb97b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
fdc056d
e3cb97b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
fdc056d
e3cb97b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
fdc056d
e3cb97b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6514731
e3cb97b
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
"""Base class for STT provider implementations."""

import logging
import os
import tempfile
from abc import ABC, abstractmethod
from pathlib import Path
from typing import Optional, TYPE_CHECKING

if TYPE_CHECKING:
    from ...domain.models.audio_content import AudioContent
    from ...domain.models.text_content import TextContent

from ...domain.interfaces.speech_recognition import ISpeechRecognitionService
from ...domain.exceptions import SpeechRecognitionException

logger = logging.getLogger(__name__)


class STTProviderBase(ISpeechRecognitionService, ABC):
    """
    Abstract base class for STT provider implementations.

    ``transcribe`` implements the shared pipeline: validate the audio, write it
    (converted to 16 kHz mono WAV when possible) to a temp file, run the
    provider-specific ``_perform_transcription``, wrap the result in
    ``TextContent`` and delete the temp file. Subclasses implement
    ``_perform_transcription``, ``is_available``, ``get_available_models`` and
    ``get_default_model``.
    """

    def __init__(self, provider_name: str, supported_languages: Optional[list[str]] = None):
        """
        Initialize the STT provider.

        Args:
            provider_name: Name of the STT provider (used in log/error messages)
            supported_languages: List of supported language codes; an empty
                list is used when omitted or None

        Side effects:
            Creates ``<system temp dir>/stt_temp`` if it does not exist.
        """
        self.provider_name = provider_name
        self.supported_languages = supported_languages or []
        self._temp_dir = self._ensure_temp_directory()

    def transcribe(self, audio: 'AudioContent', model: str) -> 'TextContent':
        """
        Transcribe audio content to text.

        Args:
            audio: The audio content to transcribe
            model: The STT model to use for transcription

        Returns:
            TextContent: The transcribed text (UTF-8; language from
            ``_detect_language``, falling back to 'en')

        Raises:
            SpeechRecognitionException: If validation, preprocessing or
                transcription fails (any underlying error is wrapped)
        """
        try:
            logger.info(
                "Starting transcription with %s provider using model %s",
                self.provider_name, model,
            )
            self._validate_audio(audio)

            # Write the audio to a temp file (converted when possible)
            processed_audio_path = self._preprocess_audio(audio)

            try:
                transcribed_text = self._perform_transcription(processed_audio_path, model)

                # Imported lazily; presumably to avoid an import cycle, since the
                # module only imports TextContent under TYPE_CHECKING.
                from ...domain.models.text_content import TextContent

                # _detect_language may return None; fall back to English
                detected_language = self._detect_language(transcribed_text) or 'en'

                text_content = TextContent(
                    text=transcribed_text,
                    language=detected_language,
                    encoding='utf-8',
                )

                logger.info("Transcription completed successfully with %s", self.provider_name)
                return text_content

            finally:
                # Always remove the temp audio file, even when transcription fails
                self._cleanup_temp_file(processed_audio_path)

        except Exception as e:
            logger.error("Transcription failed with %s: %s", self.provider_name, e)
            raise SpeechRecognitionException(f"STT transcription failed: {str(e)}") from e

    @abstractmethod
    def _perform_transcription(self, audio_path: Path, model: str) -> str:
        """
        Perform the actual transcription using provider-specific implementation.

        Args:
            audio_path: Path to the preprocessed audio file
            model: The STT model to use

        Returns:
            str: The transcribed text
        """
        pass

    @abstractmethod
    def is_available(self) -> bool:
        """
        Check if the STT provider is available and ready to use.

        Returns:
            bool: True if provider is available, False otherwise
        """
        pass

    @abstractmethod
    def get_available_models(self) -> list[str]:
        """
        Get list of available models for this provider.

        Returns:
            list[str]: List of model identifiers
        """
        pass

    @abstractmethod
    def get_default_model(self) -> str:
        """
        Get the default model for this provider.

        Returns:
            str: Default model name
        """
        pass

    def _preprocess_audio(self, audio: 'AudioContent') -> Path:
        """
        Write audio content to a temp file and convert it for transcription.

        Only the returned file is left on disk: the raw copy is removed once a
        separate converted file exists, and on any failure.

        Args:
            audio: The audio content to preprocess

        Returns:
            Path: Path to the preprocessed audio file (the caller deletes it)

        Raises:
            SpeechRecognitionException: If the audio cannot be written
        """
        raw_path: Optional[Path] = None
        try:
            # A unique name per call: the old id(audio)-based name could collide
            # when object ids are reused or several processes share the temp dir.
            with tempfile.NamedTemporaryFile(
                mode='wb', prefix='audio_', suffix='.wav', dir=self._temp_dir, delete=False
            ) as f:
                raw_path = Path(f.name)
                f.write(audio.data)

            processed_file = self._convert_audio_format(raw_path, audio)

            # Conversion wrote a separate file; the raw copy is no longer needed
            # and would otherwise leak (the caller only deletes processed_file).
            if processed_file != raw_path:
                self._cleanup_temp_file(raw_path)

            logger.info("Audio preprocessed and saved to: %s", processed_file)
            return processed_file

        except Exception as e:
            if raw_path is not None:
                self._cleanup_temp_file(raw_path)
            logger.error("Audio preprocessing failed: %s", e)
            raise SpeechRecognitionException(f"Audio preprocessing failed: {str(e)}") from e

    def _convert_audio_format(self, audio_path: Path, audio: 'AudioContent') -> Path:
        """
        Convert audio to 16 kHz mono WAV for transcription (best effort).

        If pydub is unavailable or conversion fails, a warning is logged and
        the original ``audio_path`` is returned unchanged.

        Args:
            audio_path: Path to the original audio file
            audio: The audio content metadata (``format`` selects the decoder)

        Returns:
            Path: Path to the converted WAV file, or ``audio_path`` on fallback
        """
        output_path: Optional[Path] = None
        try:
            from pydub import AudioSegment

            fmt = audio.format.lower()
            if fmt == 'mp3':
                audio_segment = AudioSegment.from_mp3(audio_path)
            elif fmt == 'wav':
                audio_segment = AudioSegment.from_wav(audio_path)
            elif fmt == 'flac':
                audio_segment = AudioSegment.from_file(audio_path, format='flac')
            elif fmt == 'ogg':
                audio_segment = AudioSegment.from_ogg(audio_path)
            else:
                # Unknown format: let pydub work out the container itself
                audio_segment = AudioSegment.from_file(audio_path)

            standardized_audio = audio_segment.set_frame_rate(16000).set_channels(1)

            # Never overwrite the input: if it already ends in .wav, prefix the name
            output_path = audio_path.with_suffix('.wav')
            if output_path == audio_path:
                output_path = audio_path.with_name(f"converted_{audio_path.name}")

            standardized_audio.export(output_path, format="wav")

            logger.info("Audio converted from %s to WAV: %s", audio.format, output_path)
            return output_path

        except ImportError:
            logger.warning("pydub not available, using original audio file")
            return audio_path
        except Exception as e:
            # Drop a partially exported output so it does not linger in the temp dir
            if output_path is not None and output_path != audio_path:
                self._cleanup_temp_file(output_path)
            logger.warning("Audio conversion failed, using original file: %s", e)
            return audio_path

    def _validate_audio(self, audio: 'AudioContent') -> None:
        """
        Validate the audio content for transcription.

        Args:
            audio: The audio content to validate

        Raises:
            SpeechRecognitionException: If the data is empty, the duration is
                outside [0.1 s, 3600 s], or the format is not valid
        """
        if not audio.data:
            raise SpeechRecognitionException("Audio data cannot be empty")

        if audio.duration > 3600:  # 1 hour limit
            raise SpeechRecognitionException("Audio duration exceeds maximum limit of 1 hour")

        if audio.duration < 0.1:  # Minimum 100ms
            raise SpeechRecognitionException("Audio duration too short (minimum 100ms)")

        if not audio.is_valid_format:
            raise SpeechRecognitionException(f"Unsupported audio format: {audio.format}")

    def _detect_language(self, text: str) -> Optional[str]:
        """
        Detect the language of transcribed text.

        Placeholder heuristic: it currently returns 'en' for any text and None
        only if inspecting the text raises (e.g. ``text`` is None).

        Args:
            text: The transcribed text

        Returns:
            Optional[str]: Detected language code or None if detection fails
        """
        try:
            # Basic heuristic; in production use langdetect or similar.
            # Match whole words (the old substring test counted 'a' inside any word).
            english_indicators = {'the', 'and', 'is', 'in', 'to', 'of', 'a', 'that', 'it', 'with'}
            words = set(text.lower().split())
            english_count = len(english_indicators & words)

            if english_count >= 2:
                return 'en'

            # Default to English if uncertain
            return 'en'

        except Exception as e:
            logger.warning("Language detection failed: %s", e)
            return None

    def _ensure_temp_directory(self) -> Path:
        """
        Ensure the shared temporary directory exists and return its path.

        Returns:
            Path: ``<system temp dir>/stt_temp``
        """
        temp_dir = Path(tempfile.gettempdir()) / "stt_temp"
        temp_dir.mkdir(exist_ok=True)
        return temp_dir

    def _cleanup_temp_file(self, file_path: Path) -> None:
        """
        Delete a temporary file (best effort; failures are only logged).

        Args:
            file_path: Path to the file to clean up
        """
        try:
            if file_path.exists():
                file_path.unlink()
                logger.info("Cleaned up temp file: %s", file_path)
        except Exception as e:
            logger.warning("Failed to cleanup temp file %s: %s", file_path, e)

    def _cleanup_old_temp_files(self, max_age_hours: int = 24) -> None:
        """
        Delete files in the temp directory older than ``max_age_hours`` (best effort).

        Args:
            max_age_hours: Maximum age of files to keep in hours (by mtime)
        """
        try:
            import time

            cutoff = time.time() - max_age_hours * 3600

            for file_path in self._temp_dir.glob("*"):
                try:
                    if file_path.is_file() and file_path.stat().st_mtime < cutoff:
                        file_path.unlink()
                        logger.info("Cleaned up old temp file: %s", file_path)
                except OSError as e:
                    # One vanished or locked file must not abort the whole sweep
                    logger.warning("Failed to cleanup old temp file %s: %s", file_path, e)

        except Exception as e:
            logger.warning("Failed to cleanup old temp files: %s", e)

    def _handle_provider_error(self, error: Exception, context: str = "") -> None:
        """
        Log a provider-specific error and re-raise it as a domain exception.

        Args:
            error: The original error
            context: Additional context about when the error occurred

        Raises:
            SpeechRecognitionException: Always, chained from ``error``
        """
        error_msg = f"{self.provider_name} error"
        if context:
            error_msg += f" during {context}"
        error_msg += f": {str(error)}"

        # stdlib logging has no 'exception' kwarg (it raised TypeError);
        # exc_info attaches the original traceback to the log record.
        logger.error(error_msg, exc_info=error)
        raise SpeechRecognitionException(error_msg) from error