Spaces:
Build error
Build error
Michael Hu
commited on
Commit
·
fdc056d
1
Parent(s):
f7aaf3b
add more logs
Browse files- DEVELOPER_GUIDE.md +2 -2
- src/application/dtos/dto_validation.py +36 -36
- src/application/error_handling/error_mapper.py +1 -1
- src/application/error_handling/structured_logger.py +1 -1
- src/application/services/configuration_service.py +4 -4
- src/infrastructure/base/file_utils.py +9 -9
- src/infrastructure/base/stt_provider_base.py +4 -4
- src/infrastructure/base/translation_provider_base.py +2 -2
- src/infrastructure/base/tts_provider_base.py +1 -1
- src/infrastructure/config/app_config.py +8 -8
- src/infrastructure/config/dependency_container.py +9 -3
- src/infrastructure/stt/legacy_compatibility.py +26 -26
- src/infrastructure/stt/parakeet_provider.py +8 -8
- src/infrastructure/stt/provider_factory.py +17 -17
- src/infrastructure/stt/whisper_provider.py +8 -8
- src/infrastructure/translation/nllb_provider.py +3 -3
- src/infrastructure/translation/provider_factory.py +3 -3
- src/infrastructure/tts/dia_provider.py +118 -35
- src/infrastructure/tts/dummy_provider.py +15 -15
- src/infrastructure/tts/kokoro_provider.py +5 -5
- src/infrastructure/tts/provider_factory.py +24 -6
- src/infrastructure/utils/dependency_installer.py +304 -0
- tests/unit/application/error_handling/test_structured_logger.py +1 -1
- utils/stt.py +21 -21
- utils/translation.py +4 -4
- utils/tts.py +16 -16
- utils/tts_dia.py +28 -28
- utils/tts_dummy.py +12 -12
- utils/tts_kokoro.py +23 -23
DEVELOPER_GUIDE.md
CHANGED
@@ -173,7 +173,7 @@ def _register_default_providers(self):
|
|
173 |
self._providers['my_tts'] = MyTTSProvider
|
174 |
logger.info("Registered MyTTS provider")
|
175 |
except ImportError as e:
|
176 |
-
logger.
|
177 |
```
|
178 |
|
179 |
### Step 3: Add Configuration Support
|
@@ -590,7 +590,7 @@ import logging
|
|
590 |
logger = logging.getLogger(__name__)
|
591 |
|
592 |
# Use appropriate log levels
|
593 |
-
logger.
|
594 |
logger.info("General information about program execution")
|
595 |
logger.warning("Something unexpected happened")
|
596 |
logger.error("A serious error occurred")
|
|
|
173 |
self._providers['my_tts'] = MyTTSProvider
|
174 |
logger.info("Registered MyTTS provider")
|
175 |
except ImportError as e:
|
176 |
+
logger.info(f"MyTTS provider not available: {e}")
|
177 |
```
|
178 |
|
179 |
### Step 3: Add Configuration Support
|
|
|
590 |
logger = logging.getLogger(__name__)
|
591 |
|
592 |
# Use appropriate log levels
|
593 |
+
logger.info("Detailed debugging information")
|
594 |
logger.info("General information about program execution")
|
595 |
logger.warning("Something unexpected happened")
|
596 |
logger.error("A serious error occurred")
|
src/application/dtos/dto_validation.py
CHANGED
@@ -15,13 +15,13 @@ T = TypeVar('T')
|
|
15 |
|
16 |
class ValidationError(Exception):
|
17 |
"""Custom exception for DTO validation errors"""
|
18 |
-
|
19 |
def __init__(self, message: str, field: str = None, value: Any = None):
|
20 |
self.message = message
|
21 |
self.field = field
|
22 |
self.value = value
|
23 |
super().__init__(self.message)
|
24 |
-
|
25 |
def __str__(self):
|
26 |
if self.field:
|
27 |
return f"Validation error for field '{self.field}': {self.message}"
|
@@ -30,13 +30,13 @@ class ValidationError(Exception):
|
|
30 |
|
31 |
def validate_dto(dto_instance: Any) -> bool:
|
32 |
"""Validate a DTO instance
|
33 |
-
|
34 |
Args:
|
35 |
dto_instance: The DTO instance to validate
|
36 |
-
|
37 |
Returns:
|
38 |
bool: True if validation passes
|
39 |
-
|
40 |
Raises:
|
41 |
ValidationError: If validation fails
|
42 |
"""
|
@@ -44,11 +44,11 @@ def validate_dto(dto_instance: Any) -> bool:
|
|
44 |
# Call the DTO's validation method if it exists
|
45 |
if hasattr(dto_instance, '_validate'):
|
46 |
dto_instance._validate()
|
47 |
-
|
48 |
# Additional validation can be added here
|
49 |
-
logger.
|
50 |
return True
|
51 |
-
|
52 |
except ValueError as e:
|
53 |
logger.error(f"Validation failed for {type(dto_instance).__name__}: {e}")
|
54 |
raise ValidationError(str(e)) from e
|
@@ -59,10 +59,10 @@ def validate_dto(dto_instance: Any) -> bool:
|
|
59 |
|
60 |
def validation_required(func: Callable[..., T]) -> Callable[..., T]:
|
61 |
"""Decorator to ensure DTO validation before method execution
|
62 |
-
|
63 |
Args:
|
64 |
func: The method to decorate
|
65 |
-
|
66 |
Returns:
|
67 |
Decorated function that validates 'self' before execution
|
68 |
"""
|
@@ -75,23 +75,23 @@ def validation_required(func: Callable[..., T]) -> Callable[..., T]:
|
|
75 |
raise
|
76 |
except Exception as e:
|
77 |
raise ValidationError(f"Error in {func.__name__}: {e}") from e
|
78 |
-
|
79 |
return wrapper
|
80 |
|
81 |
|
82 |
-
def validate_field(value: Any, field_name: str, validator: Callable[[Any], bool],
|
83 |
error_message: str = None) -> Any:
|
84 |
"""Validate a single field value
|
85 |
-
|
86 |
Args:
|
87 |
value: The value to validate
|
88 |
field_name: Name of the field being validated
|
89 |
validator: Function that returns True if value is valid
|
90 |
error_message: Custom error message
|
91 |
-
|
92 |
Returns:
|
93 |
The validated value
|
94 |
-
|
95 |
Raises:
|
96 |
ValidationError: If validation fails
|
97 |
"""
|
@@ -108,37 +108,37 @@ def validate_field(value: Any, field_name: str, validator: Callable[[Any], bool]
|
|
108 |
|
109 |
def validate_required(value: Any, field_name: str) -> Any:
|
110 |
"""Validate that a field is not None or empty
|
111 |
-
|
112 |
Args:
|
113 |
value: The value to validate
|
114 |
field_name: Name of the field being validated
|
115 |
-
|
116 |
Returns:
|
117 |
The validated value
|
118 |
-
|
119 |
Raises:
|
120 |
ValidationError: If field is None or empty
|
121 |
"""
|
122 |
if value is None:
|
123 |
raise ValidationError(f"Field '{field_name}' is required", field_name, value)
|
124 |
-
|
125 |
if isinstance(value, (str, list, dict)) and len(value) == 0:
|
126 |
raise ValidationError(f"Field '{field_name}' cannot be empty", field_name, value)
|
127 |
-
|
128 |
return value
|
129 |
|
130 |
|
131 |
def validate_type(value: Any, field_name: str, expected_type: Union[type, tuple]) -> Any:
|
132 |
"""Validate that a field is of the expected type
|
133 |
-
|
134 |
Args:
|
135 |
value: The value to validate
|
136 |
field_name: Name of the field being validated
|
137 |
expected_type: Expected type or tuple of types
|
138 |
-
|
139 |
Returns:
|
140 |
The validated value
|
141 |
-
|
142 |
Raises:
|
143 |
ValidationError: If type doesn't match
|
144 |
"""
|
@@ -148,30 +148,30 @@ def validate_type(value: Any, field_name: str, expected_type: Union[type, tuple]
|
|
148 |
expected_str = " or ".join(type_names)
|
149 |
else:
|
150 |
expected_str = expected_type.__name__
|
151 |
-
|
152 |
actual_type = type(value).__name__
|
153 |
raise ValidationError(
|
154 |
f"Field '{field_name}' must be of type {expected_str}, got {actual_type}",
|
155 |
field_name, value
|
156 |
)
|
157 |
-
|
158 |
return value
|
159 |
|
160 |
|
161 |
-
def validate_range(value: Union[int, float], field_name: str,
|
162 |
-
min_value: Union[int, float] = None,
|
163 |
max_value: Union[int, float] = None) -> Union[int, float]:
|
164 |
"""Validate that a numeric value is within a specified range
|
165 |
-
|
166 |
Args:
|
167 |
value: The numeric value to validate
|
168 |
field_name: Name of the field being validated
|
169 |
min_value: Minimum allowed value (inclusive)
|
170 |
max_value: Maximum allowed value (inclusive)
|
171 |
-
|
172 |
Returns:
|
173 |
The validated value
|
174 |
-
|
175 |
Raises:
|
176 |
ValidationError: If value is outside the range
|
177 |
"""
|
@@ -180,27 +180,27 @@ def validate_range(value: Union[int, float], field_name: str,
|
|
180 |
f"Field '{field_name}' must be >= {min_value}, got {value}",
|
181 |
field_name, value
|
182 |
)
|
183 |
-
|
184 |
if max_value is not None and value > max_value:
|
185 |
raise ValidationError(
|
186 |
f"Field '{field_name}' must be <= {max_value}, got {value}",
|
187 |
field_name, value
|
188 |
)
|
189 |
-
|
190 |
return value
|
191 |
|
192 |
|
193 |
def validate_choices(value: Any, field_name: str, choices: list) -> Any:
|
194 |
"""Validate that a value is one of the allowed choices
|
195 |
-
|
196 |
Args:
|
197 |
value: The value to validate
|
198 |
field_name: Name of the field being validated
|
199 |
choices: List of allowed values
|
200 |
-
|
201 |
Returns:
|
202 |
The validated value
|
203 |
-
|
204 |
Raises:
|
205 |
ValidationError: If value is not in choices
|
206 |
"""
|
@@ -209,5 +209,5 @@ def validate_choices(value: Any, field_name: str, choices: list) -> Any:
|
|
209 |
f"Field '{field_name}' must be one of {choices}, got '{value}'",
|
210 |
field_name, value
|
211 |
)
|
212 |
-
|
213 |
return value
|
|
|
15 |
|
16 |
class ValidationError(Exception):
|
17 |
"""Custom exception for DTO validation errors"""
|
18 |
+
|
19 |
def __init__(self, message: str, field: str = None, value: Any = None):
|
20 |
self.message = message
|
21 |
self.field = field
|
22 |
self.value = value
|
23 |
super().__init__(self.message)
|
24 |
+
|
25 |
def __str__(self):
|
26 |
if self.field:
|
27 |
return f"Validation error for field '{self.field}': {self.message}"
|
|
|
30 |
|
31 |
def validate_dto(dto_instance: Any) -> bool:
|
32 |
"""Validate a DTO instance
|
33 |
+
|
34 |
Args:
|
35 |
dto_instance: The DTO instance to validate
|
36 |
+
|
37 |
Returns:
|
38 |
bool: True if validation passes
|
39 |
+
|
40 |
Raises:
|
41 |
ValidationError: If validation fails
|
42 |
"""
|
|
|
44 |
# Call the DTO's validation method if it exists
|
45 |
if hasattr(dto_instance, '_validate'):
|
46 |
dto_instance._validate()
|
47 |
+
|
48 |
# Additional validation can be added here
|
49 |
+
logger.info(f"Successfully validated {type(dto_instance).__name__}")
|
50 |
return True
|
51 |
+
|
52 |
except ValueError as e:
|
53 |
logger.error(f"Validation failed for {type(dto_instance).__name__}: {e}")
|
54 |
raise ValidationError(str(e)) from e
|
|
|
59 |
|
60 |
def validation_required(func: Callable[..., T]) -> Callable[..., T]:
|
61 |
"""Decorator to ensure DTO validation before method execution
|
62 |
+
|
63 |
Args:
|
64 |
func: The method to decorate
|
65 |
+
|
66 |
Returns:
|
67 |
Decorated function that validates 'self' before execution
|
68 |
"""
|
|
|
75 |
raise
|
76 |
except Exception as e:
|
77 |
raise ValidationError(f"Error in {func.__name__}: {e}") from e
|
78 |
+
|
79 |
return wrapper
|
80 |
|
81 |
|
82 |
+
def validate_field(value: Any, field_name: str, validator: Callable[[Any], bool],
|
83 |
error_message: str = None) -> Any:
|
84 |
"""Validate a single field value
|
85 |
+
|
86 |
Args:
|
87 |
value: The value to validate
|
88 |
field_name: Name of the field being validated
|
89 |
validator: Function that returns True if value is valid
|
90 |
error_message: Custom error message
|
91 |
+
|
92 |
Returns:
|
93 |
The validated value
|
94 |
+
|
95 |
Raises:
|
96 |
ValidationError: If validation fails
|
97 |
"""
|
|
|
108 |
|
109 |
def validate_required(value: Any, field_name: str) -> Any:
|
110 |
"""Validate that a field is not None or empty
|
111 |
+
|
112 |
Args:
|
113 |
value: The value to validate
|
114 |
field_name: Name of the field being validated
|
115 |
+
|
116 |
Returns:
|
117 |
The validated value
|
118 |
+
|
119 |
Raises:
|
120 |
ValidationError: If field is None or empty
|
121 |
"""
|
122 |
if value is None:
|
123 |
raise ValidationError(f"Field '{field_name}' is required", field_name, value)
|
124 |
+
|
125 |
if isinstance(value, (str, list, dict)) and len(value) == 0:
|
126 |
raise ValidationError(f"Field '{field_name}' cannot be empty", field_name, value)
|
127 |
+
|
128 |
return value
|
129 |
|
130 |
|
131 |
def validate_type(value: Any, field_name: str, expected_type: Union[type, tuple]) -> Any:
|
132 |
"""Validate that a field is of the expected type
|
133 |
+
|
134 |
Args:
|
135 |
value: The value to validate
|
136 |
field_name: Name of the field being validated
|
137 |
expected_type: Expected type or tuple of types
|
138 |
+
|
139 |
Returns:
|
140 |
The validated value
|
141 |
+
|
142 |
Raises:
|
143 |
ValidationError: If type doesn't match
|
144 |
"""
|
|
|
148 |
expected_str = " or ".join(type_names)
|
149 |
else:
|
150 |
expected_str = expected_type.__name__
|
151 |
+
|
152 |
actual_type = type(value).__name__
|
153 |
raise ValidationError(
|
154 |
f"Field '{field_name}' must be of type {expected_str}, got {actual_type}",
|
155 |
field_name, value
|
156 |
)
|
157 |
+
|
158 |
return value
|
159 |
|
160 |
|
161 |
+
def validate_range(value: Union[int, float], field_name: str,
|
162 |
+
min_value: Union[int, float] = None,
|
163 |
max_value: Union[int, float] = None) -> Union[int, float]:
|
164 |
"""Validate that a numeric value is within a specified range
|
165 |
+
|
166 |
Args:
|
167 |
value: The numeric value to validate
|
168 |
field_name: Name of the field being validated
|
169 |
min_value: Minimum allowed value (inclusive)
|
170 |
max_value: Maximum allowed value (inclusive)
|
171 |
+
|
172 |
Returns:
|
173 |
The validated value
|
174 |
+
|
175 |
Raises:
|
176 |
ValidationError: If value is outside the range
|
177 |
"""
|
|
|
180 |
f"Field '{field_name}' must be >= {min_value}, got {value}",
|
181 |
field_name, value
|
182 |
)
|
183 |
+
|
184 |
if max_value is not None and value > max_value:
|
185 |
raise ValidationError(
|
186 |
f"Field '{field_name}' must be <= {max_value}, got {value}",
|
187 |
field_name, value
|
188 |
)
|
189 |
+
|
190 |
return value
|
191 |
|
192 |
|
193 |
def validate_choices(value: Any, field_name: str, choices: list) -> Any:
|
194 |
"""Validate that a value is one of the allowed choices
|
195 |
+
|
196 |
Args:
|
197 |
value: The value to validate
|
198 |
field_name: Name of the field being validated
|
199 |
choices: List of allowed values
|
200 |
+
|
201 |
Returns:
|
202 |
The validated value
|
203 |
+
|
204 |
Raises:
|
205 |
ValidationError: If value is not in choices
|
206 |
"""
|
|
|
209 |
f"Field '{field_name}' must be one of {choices}, got '{value}'",
|
210 |
field_name, value
|
211 |
)
|
212 |
+
|
213 |
return value
|
src/application/error_handling/error_mapper.py
CHANGED
@@ -262,7 +262,7 @@ class ErrorMapper:
|
|
262 |
if context:
|
263 |
mapping = self._enhance_mapping_with_context(mapping, exception, context)
|
264 |
|
265 |
-
logger.
|
266 |
return mapping
|
267 |
|
268 |
except Exception as e:
|
|
|
262 |
if context:
|
263 |
mapping = self._enhance_mapping_with_context(mapping, exception, context)
|
264 |
|
265 |
+
logger.info(f"Mapped {type(exception).__name__} to {mapping.error_code}")
|
266 |
return mapping
|
267 |
|
268 |
except Exception as e:
|
src/application/error_handling/structured_logger.py
CHANGED
@@ -125,7 +125,7 @@ class StructuredLogger:
|
|
125 |
if self.logger.isEnabledFor(logging.DEBUG):
|
126 |
log_data = self._get_log_data(message, LogLevel.DEBUG.value, context, extra)
|
127 |
# Use 'structured_data' to avoid conflicts with LogRecord attributes
|
128 |
-
self.logger.
|
129 |
|
130 |
def info(self, message: str, context: Optional[LogContext] = None,
|
131 |
extra: Optional[Dict[str, Any]] = None) -> None:
|
|
|
125 |
if self.logger.isEnabledFor(logging.DEBUG):
|
126 |
log_data = self._get_log_data(message, LogLevel.DEBUG.value, context, extra)
|
127 |
# Use 'structured_data' to avoid conflicts with LogRecord attributes
|
128 |
+
self.logger.info(message, extra={'structured_data': log_data})
|
129 |
|
130 |
def info(self, message: str, context: Optional[LogContext] = None,
|
131 |
extra: Optional[Dict[str, Any]] = None) -> None:
|
src/application/services/configuration_service.py
CHANGED
@@ -153,7 +153,7 @@ class ConfigurationApplicationService:
|
|
153 |
# Update the actual config object
|
154 |
if hasattr(self._config.tts, key):
|
155 |
setattr(self._config.tts, key, value)
|
156 |
-
logger.
|
157 |
else:
|
158 |
logger.warning(f"Unknown TTS configuration key: {key}")
|
159 |
|
@@ -192,7 +192,7 @@ class ConfigurationApplicationService:
|
|
192 |
# Update the actual config object
|
193 |
if hasattr(self._config.stt, key):
|
194 |
setattr(self._config.stt, key, value)
|
195 |
-
logger.
|
196 |
else:
|
197 |
logger.warning(f"Unknown STT configuration key: {key}")
|
198 |
|
@@ -231,7 +231,7 @@ class ConfigurationApplicationService:
|
|
231 |
# Update the actual config object
|
232 |
if hasattr(self._config.translation, key):
|
233 |
setattr(self._config.translation, key, value)
|
234 |
-
logger.
|
235 |
else:
|
236 |
logger.warning(f"Unknown translation configuration key: {key}")
|
237 |
|
@@ -270,7 +270,7 @@ class ConfigurationApplicationService:
|
|
270 |
# Update the actual config object
|
271 |
if hasattr(self._config.processing, key):
|
272 |
setattr(self._config.processing, key, value)
|
273 |
-
logger.
|
274 |
else:
|
275 |
logger.warning(f"Unknown processing configuration key: {key}")
|
276 |
|
|
|
153 |
# Update the actual config object
|
154 |
if hasattr(self._config.tts, key):
|
155 |
setattr(self._config.tts, key, value)
|
156 |
+
logger.info(f"Updated TTS config: {key} = {value}")
|
157 |
else:
|
158 |
logger.warning(f"Unknown TTS configuration key: {key}")
|
159 |
|
|
|
192 |
# Update the actual config object
|
193 |
if hasattr(self._config.stt, key):
|
194 |
setattr(self._config.stt, key, value)
|
195 |
+
logger.info(f"Updated STT config: {key} = {value}")
|
196 |
else:
|
197 |
logger.warning(f"Unknown STT configuration key: {key}")
|
198 |
|
|
|
231 |
# Update the actual config object
|
232 |
if hasattr(self._config.translation, key):
|
233 |
setattr(self._config.translation, key, value)
|
234 |
+
logger.info(f"Updated translation config: {key} = {value}")
|
235 |
else:
|
236 |
logger.warning(f"Unknown translation configuration key: {key}")
|
237 |
|
|
|
270 |
# Update the actual config object
|
271 |
if hasattr(self._config.processing, key):
|
272 |
setattr(self._config.processing, key, value)
|
273 |
+
logger.info(f"Updated processing config: {key} = {value}")
|
274 |
else:
|
275 |
logger.warning(f"Unknown processing configuration key: {key}")
|
276 |
|
src/infrastructure/base/file_utils.py
CHANGED
@@ -27,7 +27,7 @@ class FileManager:
|
|
27 |
self.base_dir = Path(tempfile.gettempdir()) / "tts_app"
|
28 |
|
29 |
self.base_dir.mkdir(exist_ok=True)
|
30 |
-
logger.
|
31 |
|
32 |
def create_temp_file(self, suffix: str = ".tmp", prefix: str = "temp", content: bytes = None) -> Path:
|
33 |
"""
|
@@ -51,7 +51,7 @@ class FileManager:
|
|
51 |
else:
|
52 |
file_path.touch()
|
53 |
|
54 |
-
logger.
|
55 |
return file_path
|
56 |
|
57 |
def create_unique_filename(self, base_name: str, extension: str = "", content_hash: bool = False, content: bytes = None) -> str:
|
@@ -103,7 +103,7 @@ class FileManager:
|
|
103 |
with open(file_path, 'wb') as f:
|
104 |
f.write(audio_data)
|
105 |
|
106 |
-
logger.
|
107 |
return file_path
|
108 |
|
109 |
def save_text_file(self, text_content: str, encoding: str = "utf-8", prefix: str = "text") -> Path:
|
@@ -124,7 +124,7 @@ class FileManager:
|
|
124 |
with open(file_path, 'w', encoding=encoding) as f:
|
125 |
f.write(text_content)
|
126 |
|
127 |
-
logger.
|
128 |
return file_path
|
129 |
|
130 |
def cleanup_file(self, file_path: Union[str, Path]) -> bool:
|
@@ -141,7 +141,7 @@ class FileManager:
|
|
141 |
path = Path(file_path)
|
142 |
if path.exists() and path.is_file():
|
143 |
path.unlink()
|
144 |
-
logger.
|
145 |
return True
|
146 |
return False
|
147 |
except Exception as e:
|
@@ -223,7 +223,7 @@ class FileManager:
|
|
223 |
"""
|
224 |
path = Path(dir_path)
|
225 |
path.mkdir(parents=True, exist_ok=True)
|
226 |
-
logger.
|
227 |
return path
|
228 |
|
229 |
def get_disk_usage(self) -> dict:
|
@@ -282,7 +282,7 @@ class AudioFileGenerator:
|
|
282 |
wav_file.setframerate(sample_rate)
|
283 |
wav_file.writeframes(audio_data)
|
284 |
|
285 |
-
logger.
|
286 |
return path
|
287 |
|
288 |
except Exception as e:
|
@@ -318,7 +318,7 @@ class AudioFileGenerator:
|
|
318 |
|
319 |
sf.write(str(path), audio_array, sample_rate)
|
320 |
|
321 |
-
logger.
|
322 |
return path
|
323 |
|
324 |
except ImportError:
|
@@ -406,4 +406,4 @@ class ErrorHandler:
|
|
406 |
debug_msg += f" ({context})"
|
407 |
debug_msg += f": {message}"
|
408 |
|
409 |
-
self.logger.
|
|
|
27 |
self.base_dir = Path(tempfile.gettempdir()) / "tts_app"
|
28 |
|
29 |
self.base_dir.mkdir(exist_ok=True)
|
30 |
+
logger.info(f"FileManager initialized with base directory: {self.base_dir}")
|
31 |
|
32 |
def create_temp_file(self, suffix: str = ".tmp", prefix: str = "temp", content: bytes = None) -> Path:
|
33 |
"""
|
|
|
51 |
else:
|
52 |
file_path.touch()
|
53 |
|
54 |
+
logger.info(f"Created temporary file: {file_path}")
|
55 |
return file_path
|
56 |
|
57 |
def create_unique_filename(self, base_name: str, extension: str = "", content_hash: bool = False, content: bytes = None) -> str:
|
|
|
103 |
with open(file_path, 'wb') as f:
|
104 |
f.write(audio_data)
|
105 |
|
106 |
+
logger.info(f"Saved audio file: {file_path} ({len(audio_data)} bytes)")
|
107 |
return file_path
|
108 |
|
109 |
def save_text_file(self, text_content: str, encoding: str = "utf-8", prefix: str = "text") -> Path:
|
|
|
124 |
with open(file_path, 'w', encoding=encoding) as f:
|
125 |
f.write(text_content)
|
126 |
|
127 |
+
logger.info(f"Saved text file: {file_path} ({len(text_content)} characters)")
|
128 |
return file_path
|
129 |
|
130 |
def cleanup_file(self, file_path: Union[str, Path]) -> bool:
|
|
|
141 |
path = Path(file_path)
|
142 |
if path.exists() and path.is_file():
|
143 |
path.unlink()
|
144 |
+
logger.info(f"Cleaned up file: {path}")
|
145 |
return True
|
146 |
return False
|
147 |
except Exception as e:
|
|
|
223 |
"""
|
224 |
path = Path(dir_path)
|
225 |
path.mkdir(parents=True, exist_ok=True)
|
226 |
+
logger.info(f"Ensured directory exists: {path}")
|
227 |
return path
|
228 |
|
229 |
def get_disk_usage(self) -> dict:
|
|
|
282 |
wav_file.setframerate(sample_rate)
|
283 |
wav_file.writeframes(audio_data)
|
284 |
|
285 |
+
logger.info(f"Saved WAV file: {path} (sample_rate={sample_rate}, channels={channels})")
|
286 |
return path
|
287 |
|
288 |
except Exception as e:
|
|
|
318 |
|
319 |
sf.write(str(path), audio_array, sample_rate)
|
320 |
|
321 |
+
logger.info(f"Converted numpy array to WAV: {path}")
|
322 |
return path
|
323 |
|
324 |
except ImportError:
|
|
|
406 |
debug_msg += f" ({context})"
|
407 |
debug_msg += f": {message}"
|
408 |
|
409 |
+
self.logger.info(debug_msg)
|
src/infrastructure/base/stt_provider_base.py
CHANGED
@@ -145,7 +145,7 @@ class STTProviderBase(ISpeechRecognitionService, ABC):
|
|
145 |
# Convert to required format if needed
|
146 |
processed_file = self._convert_audio_format(temp_file, audio)
|
147 |
|
148 |
-
logger.
|
149 |
return processed_file
|
150 |
|
151 |
except Exception as e:
|
@@ -191,7 +191,7 @@ class STTProviderBase(ISpeechRecognitionService, ABC):
|
|
191 |
# Export converted audio
|
192 |
standardized_audio.export(output_path, format="wav")
|
193 |
|
194 |
-
logger.
|
195 |
return output_path
|
196 |
|
197 |
except ImportError:
|
@@ -273,7 +273,7 @@ class STTProviderBase(ISpeechRecognitionService, ABC):
|
|
273 |
try:
|
274 |
if file_path.exists():
|
275 |
file_path.unlink()
|
276 |
-
logger.
|
277 |
except Exception as e:
|
278 |
logger.warning(f"Failed to cleanup temp file {file_path}: {str(e)}")
|
279 |
|
@@ -294,7 +294,7 @@ class STTProviderBase(ISpeechRecognitionService, ABC):
|
|
294 |
file_age = current_time - file_path.stat().st_mtime
|
295 |
if file_age > max_age_seconds:
|
296 |
file_path.unlink()
|
297 |
-
logger.
|
298 |
|
299 |
except Exception as e:
|
300 |
logger.warning(f"Failed to cleanup old temp files: {str(e)}")
|
|
|
145 |
# Convert to required format if needed
|
146 |
processed_file = self._convert_audio_format(temp_file, audio)
|
147 |
|
148 |
+
logger.info(f"Audio preprocessed and saved to: {processed_file}")
|
149 |
return processed_file
|
150 |
|
151 |
except Exception as e:
|
|
|
191 |
# Export converted audio
|
192 |
standardized_audio.export(output_path, format="wav")
|
193 |
|
194 |
+
logger.info(f"Audio converted from {audio.format} to WAV: {output_path}")
|
195 |
return output_path
|
196 |
|
197 |
except ImportError:
|
|
|
273 |
try:
|
274 |
if file_path.exists():
|
275 |
file_path.unlink()
|
276 |
+
logger.info(f"Cleaned up temp file: {file_path}")
|
277 |
except Exception as e:
|
278 |
logger.warning(f"Failed to cleanup temp file {file_path}: {str(e)}")
|
279 |
|
|
|
294 |
file_age = current_time - file_path.stat().st_mtime
|
295 |
if file_age > max_age_seconds:
|
296 |
file_path.unlink()
|
297 |
+
logger.info(f"Cleaned up old temp file: {file_path}")
|
298 |
|
299 |
except Exception as e:
|
300 |
logger.warning(f"Failed to cleanup old temp files: {str(e)}")
|
src/infrastructure/base/translation_provider_base.py
CHANGED
@@ -56,7 +56,7 @@ class TranslationProviderBase(ITranslationService, ABC):
|
|
56 |
# Translate each chunk
|
57 |
translated_chunks = []
|
58 |
for i, chunk in enumerate(text_chunks):
|
59 |
-
logger.
|
60 |
translated_chunk = self._translate_chunk(
|
61 |
chunk,
|
62 |
request.source_text.language,
|
@@ -160,7 +160,7 @@ class TranslationProviderBase(ITranslationService, ABC):
|
|
160 |
if current_chunk.strip():
|
161 |
chunks.append(current_chunk.strip())
|
162 |
|
163 |
-
logger.
|
164 |
return chunks
|
165 |
|
166 |
def _split_into_sentences(self, text: str) -> List[str]:
|
|
|
56 |
# Translate each chunk
|
57 |
translated_chunks = []
|
58 |
for i, chunk in enumerate(text_chunks):
|
59 |
+
logger.info(f"Translating chunk {i+1}/{len(text_chunks)}")
|
60 |
translated_chunk = self._translate_chunk(
|
61 |
chunk,
|
62 |
request.source_text.language,
|
|
|
160 |
if current_chunk.strip():
|
161 |
chunks.append(current_chunk.strip())
|
162 |
|
163 |
+
logger.info(f"Text chunked into {len(chunks)} pieces")
|
164 |
return chunks
|
165 |
|
166 |
def _split_into_sentences(self, text: str) -> List[str]:
|
src/infrastructure/base/tts_provider_base.py
CHANGED
@@ -322,7 +322,7 @@ class TTSProviderBase(ISpeechSynthesisService, ABC):
|
|
322 |
file_age = current_time - file_path.stat().st_mtime
|
323 |
if file_age > max_age_seconds:
|
324 |
file_path.unlink()
|
325 |
-
logger.
|
326 |
|
327 |
except Exception as e:
|
328 |
logger.warning(f"Failed to cleanup temp files: {str(e)}")
|
|
|
322 |
file_age = current_time - file_path.stat().st_mtime
|
323 |
if file_age > max_age_seconds:
|
324 |
file_path.unlink()
|
325 |
+
logger.info(f"Cleaned up old temp file: {file_path}")
|
326 |
|
327 |
except Exception as e:
|
328 |
logger.warning(f"Failed to cleanup temp files: {str(e)}")
|
src/infrastructure/config/app_config.py
CHANGED
@@ -73,14 +73,14 @@ class AppConfig:
|
|
73 |
"""
|
74 |
self.config_file = config_file
|
75 |
self._config_data: Dict[str, Any] = {}
|
76 |
-
|
77 |
# Initialize configuration sections
|
78 |
self.tts = TTSConfig()
|
79 |
self.stt = STTConfig()
|
80 |
self.translation = TranslationConfig()
|
81 |
self.processing = ProcessingConfig()
|
82 |
self.logging = LoggingConfig()
|
83 |
-
|
84 |
# Load configuration
|
85 |
self._load_configuration()
|
86 |
|
@@ -89,16 +89,16 @@ class AppConfig:
|
|
89 |
try:
|
90 |
# Load from environment variables first
|
91 |
self._load_from_environment()
|
92 |
-
|
93 |
# Load from config file if provided
|
94 |
if self.config_file and os.path.exists(self.config_file):
|
95 |
self._load_from_file()
|
96 |
-
|
97 |
# Validate configuration
|
98 |
self._validate_configuration()
|
99 |
-
|
100 |
logger.info("Configuration loaded successfully")
|
101 |
-
|
102 |
except Exception as e:
|
103 |
logger.error(f"Failed to load configuration: {e}")
|
104 |
# Use default configuration
|
@@ -158,7 +158,7 @@ class AppConfig:
|
|
158 |
"""Load configuration from file (JSON or YAML)."""
|
159 |
try:
|
160 |
import json
|
161 |
-
|
162 |
with open(self.config_file, 'r') as f:
|
163 |
if self.config_file.endswith('.json'):
|
164 |
self._config_data = json.load(f)
|
@@ -175,7 +175,7 @@ class AppConfig:
|
|
175 |
|
176 |
# Apply configuration from file
|
177 |
self._apply_config_data()
|
178 |
-
|
179 |
except Exception as e:
|
180 |
logger.error(f"Failed to load config file {self.config_file}: {e}")
|
181 |
|
|
|
73 |
"""
|
74 |
self.config_file = config_file
|
75 |
self._config_data: Dict[str, Any] = {}
|
76 |
+
|
77 |
# Initialize configuration sections
|
78 |
self.tts = TTSConfig()
|
79 |
self.stt = STTConfig()
|
80 |
self.translation = TranslationConfig()
|
81 |
self.processing = ProcessingConfig()
|
82 |
self.logging = LoggingConfig()
|
83 |
+
|
84 |
# Load configuration
|
85 |
self._load_configuration()
|
86 |
|
|
|
89 |
try:
|
90 |
# Load from environment variables first
|
91 |
self._load_from_environment()
|
92 |
+
|
93 |
# Load from config file if provided
|
94 |
if self.config_file and os.path.exists(self.config_file):
|
95 |
self._load_from_file()
|
96 |
+
|
97 |
# Validate configuration
|
98 |
self._validate_configuration()
|
99 |
+
|
100 |
logger.info("Configuration loaded successfully")
|
101 |
+
|
102 |
except Exception as e:
|
103 |
logger.error(f"Failed to load configuration: {e}")
|
104 |
# Use default configuration
|
|
|
158 |
"""Load configuration from file (JSON or YAML)."""
|
159 |
try:
|
160 |
import json
|
161 |
+
|
162 |
with open(self.config_file, 'r') as f:
|
163 |
if self.config_file.endswith('.json'):
|
164 |
self._config_data = json.load(f)
|
|
|
175 |
|
176 |
# Apply configuration from file
|
177 |
self._apply_config_data()
|
178 |
+
|
179 |
except Exception as e:
|
180 |
logger.error(f"Failed to load config file {self.config_file}: {e}")
|
181 |
|
src/infrastructure/config/dependency_container.py
CHANGED
@@ -309,19 +309,25 @@ class DependencyContainer:
|
|
309 |
Returns:
|
310 |
ISpeechSynthesisService: TTS provider instance
|
311 |
"""
|
|
|
312 |
factory = self.resolve(TTSProviderFactory)
|
313 |
|
314 |
if provider_name:
|
|
|
315 |
try:
|
316 |
-
|
|
|
|
|
317 |
except Exception as e:
|
318 |
-
logger.warning(f"Failed to create specific TTS provider {provider_name}: {e}")
|
319 |
-
logger.info("Falling back to default provider selection")
|
320 |
# Fall back to default provider selection
|
321 |
preferred_providers = self._config.tts.preferred_providers
|
|
|
322 |
return factory.get_provider_with_fallback(preferred_providers, **kwargs)
|
323 |
else:
|
324 |
preferred_providers = self._config.tts.preferred_providers
|
|
|
325 |
return factory.get_provider_with_fallback(preferred_providers, **kwargs)
|
326 |
|
327 |
def get_stt_provider(self, provider_name: Optional[str] = None) -> ISpeechRecognitionService:
|
|
|
309 |
Returns:
|
310 |
ISpeechSynthesisService: TTS provider instance
|
311 |
"""
|
312 |
+
logger.info(f"🎯 Requesting TTS provider: {provider_name or 'default'}")
|
313 |
factory = self.resolve(TTSProviderFactory)
|
314 |
|
315 |
if provider_name:
|
316 |
+
logger.info(f"🔧 Attempting to create specific TTS provider: {provider_name}")
|
317 |
try:
|
318 |
+
provider = factory.create_provider(provider_name, **kwargs)
|
319 |
+
logger.info(f"✅ Successfully created TTS provider: {provider_name}")
|
320 |
+
return provider
|
321 |
except Exception as e:
|
322 |
+
logger.warning(f"❌ Failed to create specific TTS provider {provider_name}: {e}")
|
323 |
+
logger.info("🔄 Falling back to default provider selection")
|
324 |
# Fall back to default provider selection
|
325 |
preferred_providers = self._config.tts.preferred_providers
|
326 |
+
logger.info(f"📋 Preferred providers for fallback: {preferred_providers}")
|
327 |
return factory.get_provider_with_fallback(preferred_providers, **kwargs)
|
328 |
else:
|
329 |
preferred_providers = self._config.tts.preferred_providers
|
330 |
+
logger.info(f"📋 Using preferred providers: {preferred_providers}")
|
331 |
return factory.get_provider_with_fallback(preferred_providers, **kwargs)
|
332 |
|
333 |
def get_stt_provider(self, provider_name: Optional[str] = None) -> ISpeechRecognitionService:
|
src/infrastructure/stt/legacy_compatibility.py
CHANGED
@@ -14,37 +14,37 @@ logger = logging.getLogger(__name__)
|
|
14 |
def transcribe_audio(audio_path: Union[str, Path], model_name: str = "parakeet") -> str:
|
15 |
"""
|
16 |
Convert audio file to text using specified STT model (legacy interface).
|
17 |
-
|
18 |
This function maintains backward compatibility with the original utils/stt.py interface.
|
19 |
-
|
20 |
Args:
|
21 |
audio_path: Path to input audio file
|
22 |
model_name: Name of the STT model/provider to use (whisper or parakeet)
|
23 |
-
|
24 |
Returns:
|
25 |
str: Transcribed English text
|
26 |
-
|
27 |
Raises:
|
28 |
SpeechRecognitionException: If transcription fails
|
29 |
"""
|
30 |
logger.info(f"Starting transcription for: {audio_path} using {model_name} model")
|
31 |
-
|
32 |
try:
|
33 |
# Convert path to Path object
|
34 |
audio_path = Path(audio_path)
|
35 |
-
|
36 |
if not audio_path.exists():
|
37 |
raise SpeechRecognitionException(f"Audio file not found: {audio_path}")
|
38 |
-
|
39 |
# Read audio file and create AudioContent
|
40 |
with open(audio_path, 'rb') as f:
|
41 |
audio_data = f.read()
|
42 |
-
|
43 |
# Determine audio format from file extension
|
44 |
audio_format = audio_path.suffix.lower().lstrip('.')
|
45 |
if audio_format not in ['wav', 'mp3', 'flac', 'ogg']:
|
46 |
audio_format = 'wav' # Default fallback
|
47 |
-
|
48 |
# Create AudioContent (we'll use reasonable placeholder values)
|
49 |
# The provider will handle the actual audio analysis during preprocessing
|
50 |
try:
|
@@ -64,7 +64,7 @@ def transcribe_audio(audio_path: Union[str, Path], model_name: str = "parakeet")
|
|
64 |
duration=1.0, # Minimum valid duration
|
65 |
filename=audio_path.name
|
66 |
)
|
67 |
-
|
68 |
# Get the appropriate provider
|
69 |
try:
|
70 |
provider = STTProviderFactory.create_provider(model_name)
|
@@ -72,14 +72,14 @@ def transcribe_audio(audio_path: Union[str, Path], model_name: str = "parakeet")
|
|
72 |
# Fallback to any available provider
|
73 |
logger.warning(f"Requested provider {model_name} not available, using fallback")
|
74 |
provider = STTProviderFactory.create_provider_with_fallback(model_name)
|
75 |
-
|
76 |
# Get the default model for the provider
|
77 |
model = provider.get_default_model()
|
78 |
-
|
79 |
# Transcribe audio
|
80 |
text_content = provider.transcribe(audio_content, model)
|
81 |
result = text_content.text
|
82 |
-
|
83 |
logger.info(f"Transcription completed: {result}")
|
84 |
return result
|
85 |
|
@@ -91,33 +91,33 @@ def transcribe_audio(audio_path: Union[str, Path], model_name: str = "parakeet")
|
|
91 |
def create_audio_content_from_file(audio_path: Union[str, Path]) -> AudioContent:
|
92 |
"""
|
93 |
Create AudioContent from an audio file with proper metadata detection.
|
94 |
-
|
95 |
Args:
|
96 |
audio_path: Path to the audio file
|
97 |
-
|
98 |
Returns:
|
99 |
AudioContent: The audio content object
|
100 |
-
|
101 |
Raises:
|
102 |
SpeechRecognitionException: If file cannot be processed
|
103 |
"""
|
104 |
try:
|
105 |
from pydub import AudioSegment
|
106 |
-
|
107 |
audio_path = Path(audio_path)
|
108 |
-
|
109 |
# Load audio file to get metadata
|
110 |
audio_segment = AudioSegment.from_file(audio_path)
|
111 |
-
|
112 |
# Read raw audio data
|
113 |
with open(audio_path, 'rb') as f:
|
114 |
audio_data = f.read()
|
115 |
-
|
116 |
# Determine format
|
117 |
audio_format = audio_path.suffix.lower().lstrip('.')
|
118 |
if audio_format not in ['wav', 'mp3', 'flac', 'ogg']:
|
119 |
audio_format = 'wav'
|
120 |
-
|
121 |
# Create AudioContent with actual metadata
|
122 |
return AudioContent(
|
123 |
data=audio_data,
|
@@ -126,18 +126,18 @@ def create_audio_content_from_file(audio_path: Union[str, Path]) -> AudioContent
|
|
126 |
duration=len(audio_segment) / 1000.0, # Convert ms to seconds
|
127 |
filename=audio_path.name
|
128 |
)
|
129 |
-
|
130 |
except ImportError:
|
131 |
# Fallback without pydub
|
132 |
logger.warning("pydub not available, using placeholder metadata")
|
133 |
-
|
134 |
with open(audio_path, 'rb') as f:
|
135 |
audio_data = f.read()
|
136 |
-
|
137 |
audio_format = Path(audio_path).suffix.lower().lstrip('.')
|
138 |
if audio_format not in ['wav', 'mp3', 'flac', 'ogg']:
|
139 |
audio_format = 'wav'
|
140 |
-
|
141 |
return AudioContent(
|
142 |
data=audio_data,
|
143 |
format=audio_format,
|
@@ -145,6 +145,6 @@ def create_audio_content_from_file(audio_path: Union[str, Path]) -> AudioContent
|
|
145 |
duration=1.0, # Placeholder
|
146 |
filename=Path(audio_path).name
|
147 |
)
|
148 |
-
|
149 |
except Exception as e:
|
150 |
raise SpeechRecognitionException(f"Failed to create AudioContent from file: {str(e)}") from e
|
|
|
14 |
def transcribe_audio(audio_path: Union[str, Path], model_name: str = "parakeet") -> str:
|
15 |
"""
|
16 |
Convert audio file to text using specified STT model (legacy interface).
|
17 |
+
|
18 |
This function maintains backward compatibility with the original utils/stt.py interface.
|
19 |
+
|
20 |
Args:
|
21 |
audio_path: Path to input audio file
|
22 |
model_name: Name of the STT model/provider to use (whisper or parakeet)
|
23 |
+
|
24 |
Returns:
|
25 |
str: Transcribed English text
|
26 |
+
|
27 |
Raises:
|
28 |
SpeechRecognitionException: If transcription fails
|
29 |
"""
|
30 |
logger.info(f"Starting transcription for: {audio_path} using {model_name} model")
|
31 |
+
|
32 |
try:
|
33 |
# Convert path to Path object
|
34 |
audio_path = Path(audio_path)
|
35 |
+
|
36 |
if not audio_path.exists():
|
37 |
raise SpeechRecognitionException(f"Audio file not found: {audio_path}")
|
38 |
+
|
39 |
# Read audio file and create AudioContent
|
40 |
with open(audio_path, 'rb') as f:
|
41 |
audio_data = f.read()
|
42 |
+
|
43 |
# Determine audio format from file extension
|
44 |
audio_format = audio_path.suffix.lower().lstrip('.')
|
45 |
if audio_format not in ['wav', 'mp3', 'flac', 'ogg']:
|
46 |
audio_format = 'wav' # Default fallback
|
47 |
+
|
48 |
# Create AudioContent (we'll use reasonable placeholder values)
|
49 |
# The provider will handle the actual audio analysis during preprocessing
|
50 |
try:
|
|
|
64 |
duration=1.0, # Minimum valid duration
|
65 |
filename=audio_path.name
|
66 |
)
|
67 |
+
|
68 |
# Get the appropriate provider
|
69 |
try:
|
70 |
provider = STTProviderFactory.create_provider(model_name)
|
|
|
72 |
# Fallback to any available provider
|
73 |
logger.warning(f"Requested provider {model_name} not available, using fallback")
|
74 |
provider = STTProviderFactory.create_provider_with_fallback(model_name)
|
75 |
+
|
76 |
# Get the default model for the provider
|
77 |
model = provider.get_default_model()
|
78 |
+
|
79 |
# Transcribe audio
|
80 |
text_content = provider.transcribe(audio_content, model)
|
81 |
result = text_content.text
|
82 |
+
|
83 |
logger.info(f"Transcription completed: {result}")
|
84 |
return result
|
85 |
|
|
|
91 |
def create_audio_content_from_file(audio_path: Union[str, Path]) -> AudioContent:
|
92 |
"""
|
93 |
Create AudioContent from an audio file with proper metadata detection.
|
94 |
+
|
95 |
Args:
|
96 |
audio_path: Path to the audio file
|
97 |
+
|
98 |
Returns:
|
99 |
AudioContent: The audio content object
|
100 |
+
|
101 |
Raises:
|
102 |
SpeechRecognitionException: If file cannot be processed
|
103 |
"""
|
104 |
try:
|
105 |
from pydub import AudioSegment
|
106 |
+
|
107 |
audio_path = Path(audio_path)
|
108 |
+
|
109 |
# Load audio file to get metadata
|
110 |
audio_segment = AudioSegment.from_file(audio_path)
|
111 |
+
|
112 |
# Read raw audio data
|
113 |
with open(audio_path, 'rb') as f:
|
114 |
audio_data = f.read()
|
115 |
+
|
116 |
# Determine format
|
117 |
audio_format = audio_path.suffix.lower().lstrip('.')
|
118 |
if audio_format not in ['wav', 'mp3', 'flac', 'ogg']:
|
119 |
audio_format = 'wav'
|
120 |
+
|
121 |
# Create AudioContent with actual metadata
|
122 |
return AudioContent(
|
123 |
data=audio_data,
|
|
|
126 |
duration=len(audio_segment) / 1000.0, # Convert ms to seconds
|
127 |
filename=audio_path.name
|
128 |
)
|
129 |
+
|
130 |
except ImportError:
|
131 |
# Fallback without pydub
|
132 |
logger.warning("pydub not available, using placeholder metadata")
|
133 |
+
|
134 |
with open(audio_path, 'rb') as f:
|
135 |
audio_data = f.read()
|
136 |
+
|
137 |
audio_format = Path(audio_path).suffix.lower().lstrip('.')
|
138 |
if audio_format not in ['wav', 'mp3', 'flac', 'ogg']:
|
139 |
audio_format = 'wav'
|
140 |
+
|
141 |
return AudioContent(
|
142 |
data=audio_data,
|
143 |
format=audio_format,
|
|
|
145 |
duration=1.0, # Placeholder
|
146 |
filename=Path(audio_path).name
|
147 |
)
|
148 |
+
|
149 |
except Exception as e:
|
150 |
raise SpeechRecognitionException(f"Failed to create AudioContent from file: {str(e)}") from e
|
src/infrastructure/stt/parakeet_provider.py
CHANGED
@@ -42,11 +42,11 @@ class ParakeetSTTProvider(STTProviderBase):
|
|
42 |
self._load_model(model)
|
43 |
|
44 |
logger.info(f"Starting Parakeet transcription with model {model}")
|
45 |
-
|
46 |
# Perform transcription
|
47 |
output = self.model.transcribe([str(audio_path)])
|
48 |
result = output[0].text if output and len(output) > 0 else ""
|
49 |
-
|
50 |
logger.info("Parakeet transcription completed successfully")
|
51 |
return result
|
52 |
|
@@ -62,9 +62,9 @@ class ParakeetSTTProvider(STTProviderBase):
|
|
62 |
"""
|
63 |
try:
|
64 |
import nemo.collections.asr as nemo_asr
|
65 |
-
|
66 |
logger.info(f"Loading Parakeet model: {model_name}")
|
67 |
-
|
68 |
# Map model names to actual model identifiers
|
69 |
model_mapping = {
|
70 |
"parakeet-tdt-0.6b-v2": "nvidia/parakeet-tdt-0.6b-v2",
|
@@ -72,12 +72,12 @@ class ParakeetSTTProvider(STTProviderBase):
|
|
72 |
"parakeet-ctc-0.6b": "nvidia/parakeet-ctc-0.6b",
|
73 |
"default": "nvidia/parakeet-tdt-0.6b-v2"
|
74 |
}
|
75 |
-
|
76 |
actual_model_name = model_mapping.get(model_name, model_mapping["default"])
|
77 |
-
|
78 |
self.model = nemo_asr.models.ASRModel.from_pretrained(model_name=actual_model_name)
|
79 |
logger.info(f"Parakeet model {model_name} loaded successfully")
|
80 |
-
|
81 |
except ImportError as e:
|
82 |
raise SpeechRecognitionException(
|
83 |
"nemo_toolkit not available. Please install with: pip install -U 'nemo_toolkit[asr]'"
|
@@ -108,7 +108,7 @@ class ParakeetSTTProvider(STTProviderBase):
|
|
108 |
"""
|
109 |
return [
|
110 |
"parakeet-tdt-0.6b-v2",
|
111 |
-
"parakeet-tdt-1.1b",
|
112 |
"parakeet-ctc-0.6b"
|
113 |
]
|
114 |
|
|
|
42 |
self._load_model(model)
|
43 |
|
44 |
logger.info(f"Starting Parakeet transcription with model {model}")
|
45 |
+
|
46 |
# Perform transcription
|
47 |
output = self.model.transcribe([str(audio_path)])
|
48 |
result = output[0].text if output and len(output) > 0 else ""
|
49 |
+
|
50 |
logger.info("Parakeet transcription completed successfully")
|
51 |
return result
|
52 |
|
|
|
62 |
"""
|
63 |
try:
|
64 |
import nemo.collections.asr as nemo_asr
|
65 |
+
|
66 |
logger.info(f"Loading Parakeet model: {model_name}")
|
67 |
+
|
68 |
# Map model names to actual model identifiers
|
69 |
model_mapping = {
|
70 |
"parakeet-tdt-0.6b-v2": "nvidia/parakeet-tdt-0.6b-v2",
|
|
|
72 |
"parakeet-ctc-0.6b": "nvidia/parakeet-ctc-0.6b",
|
73 |
"default": "nvidia/parakeet-tdt-0.6b-v2"
|
74 |
}
|
75 |
+
|
76 |
actual_model_name = model_mapping.get(model_name, model_mapping["default"])
|
77 |
+
|
78 |
self.model = nemo_asr.models.ASRModel.from_pretrained(model_name=actual_model_name)
|
79 |
logger.info(f"Parakeet model {model_name} loaded successfully")
|
80 |
+
|
81 |
except ImportError as e:
|
82 |
raise SpeechRecognitionException(
|
83 |
"nemo_toolkit not available. Please install with: pip install -U 'nemo_toolkit[asr]'"
|
|
|
108 |
"""
|
109 |
return [
|
110 |
"parakeet-tdt-0.6b-v2",
|
111 |
+
"parakeet-tdt-1.1b",
|
112 |
"parakeet-ctc-0.6b"
|
113 |
]
|
114 |
|
src/infrastructure/stt/provider_factory.py
CHANGED
@@ -36,21 +36,21 @@ class STTProviderFactory:
|
|
36 |
SpeechRecognitionException: If provider is not available or creation fails
|
37 |
"""
|
38 |
provider_name = provider_name.lower()
|
39 |
-
|
40 |
if provider_name not in cls._providers:
|
41 |
raise SpeechRecognitionException(f"Unknown STT provider: {provider_name}")
|
42 |
|
43 |
provider_class = cls._providers[provider_name]
|
44 |
-
|
45 |
try:
|
46 |
provider = provider_class()
|
47 |
-
|
48 |
if not provider.is_available():
|
49 |
raise SpeechRecognitionException(f"STT provider {provider_name} is not available")
|
50 |
-
|
51 |
logger.info(f"Created STT provider: {provider_name}")
|
52 |
return provider
|
53 |
-
|
54 |
except Exception as e:
|
55 |
logger.error(f"Failed to create STT provider {provider_name}: {str(e)}")
|
56 |
raise SpeechRecognitionException(f"Failed to create STT provider {provider_name}: {str(e)}") from e
|
@@ -79,7 +79,7 @@ class STTProviderFactory:
|
|
79 |
for provider_name in cls._fallback_order:
|
80 |
if provider_name.lower() == preferred_provider.lower():
|
81 |
continue # Skip the preferred provider we already tried
|
82 |
-
|
83 |
try:
|
84 |
logger.info(f"Trying fallback STT provider: {provider_name}")
|
85 |
return cls.create_provider(provider_name)
|
@@ -98,15 +98,15 @@ class STTProviderFactory:
|
|
98 |
list[str]: List of available provider names
|
99 |
"""
|
100 |
available = []
|
101 |
-
|
102 |
for provider_name, provider_class in cls._providers.items():
|
103 |
try:
|
104 |
provider = provider_class()
|
105 |
if provider.is_available():
|
106 |
available.append(provider_name)
|
107 |
except Exception as e:
|
108 |
-
logger.
|
109 |
-
|
110 |
return available
|
111 |
|
112 |
@classmethod
|
@@ -121,12 +121,12 @@ class STTProviderFactory:
|
|
121 |
Optional[dict]: Provider information or None if not found
|
122 |
"""
|
123 |
provider_name = provider_name.lower()
|
124 |
-
|
125 |
if provider_name not in cls._providers:
|
126 |
return None
|
127 |
|
128 |
provider_class = cls._providers[provider_name]
|
129 |
-
|
130 |
try:
|
131 |
provider = provider_class()
|
132 |
return {
|
@@ -137,7 +137,7 @@ class STTProviderFactory:
|
|
137 |
"default_model": provider.get_default_model() if provider.is_available() else None
|
138 |
}
|
139 |
except Exception as e:
|
140 |
-
logger.
|
141 |
return {
|
142 |
"name": provider_name,
|
143 |
"available": False,
|
@@ -160,15 +160,15 @@ class STTProviderFactory:
|
|
160 |
# Legacy compatibility - create an ASRFactory alias
|
161 |
class ASRFactory:
|
162 |
"""Legacy ASRFactory for backward compatibility."""
|
163 |
-
|
164 |
@staticmethod
|
165 |
def get_model(model_name: str = "parakeet") -> STTProviderBase:
|
166 |
"""
|
167 |
Get STT provider by model name (legacy interface).
|
168 |
-
|
169 |
Args:
|
170 |
model_name: Name of the model/provider to use
|
171 |
-
|
172 |
Returns:
|
173 |
STTProviderBase: The provider instance
|
174 |
"""
|
@@ -178,9 +178,9 @@ class ASRFactory:
|
|
178 |
"parakeet": "parakeet",
|
179 |
"faster-whisper": "whisper"
|
180 |
}
|
181 |
-
|
182 |
provider_name = provider_mapping.get(model_name.lower(), model_name.lower())
|
183 |
-
|
184 |
try:
|
185 |
return STTProviderFactory.create_provider(provider_name)
|
186 |
except SpeechRecognitionException:
|
|
|
36 |
SpeechRecognitionException: If provider is not available or creation fails
|
37 |
"""
|
38 |
provider_name = provider_name.lower()
|
39 |
+
|
40 |
if provider_name not in cls._providers:
|
41 |
raise SpeechRecognitionException(f"Unknown STT provider: {provider_name}")
|
42 |
|
43 |
provider_class = cls._providers[provider_name]
|
44 |
+
|
45 |
try:
|
46 |
provider = provider_class()
|
47 |
+
|
48 |
if not provider.is_available():
|
49 |
raise SpeechRecognitionException(f"STT provider {provider_name} is not available")
|
50 |
+
|
51 |
logger.info(f"Created STT provider: {provider_name}")
|
52 |
return provider
|
53 |
+
|
54 |
except Exception as e:
|
55 |
logger.error(f"Failed to create STT provider {provider_name}: {str(e)}")
|
56 |
raise SpeechRecognitionException(f"Failed to create STT provider {provider_name}: {str(e)}") from e
|
|
|
79 |
for provider_name in cls._fallback_order:
|
80 |
if provider_name.lower() == preferred_provider.lower():
|
81 |
continue # Skip the preferred provider we already tried
|
82 |
+
|
83 |
try:
|
84 |
logger.info(f"Trying fallback STT provider: {provider_name}")
|
85 |
return cls.create_provider(provider_name)
|
|
|
98 |
list[str]: List of available provider names
|
99 |
"""
|
100 |
available = []
|
101 |
+
|
102 |
for provider_name, provider_class in cls._providers.items():
|
103 |
try:
|
104 |
provider = provider_class()
|
105 |
if provider.is_available():
|
106 |
available.append(provider_name)
|
107 |
except Exception as e:
|
108 |
+
logger.info(f"Provider {provider_name} not available: {str(e)}")
|
109 |
+
|
110 |
return available
|
111 |
|
112 |
@classmethod
|
|
|
121 |
Optional[dict]: Provider information or None if not found
|
122 |
"""
|
123 |
provider_name = provider_name.lower()
|
124 |
+
|
125 |
if provider_name not in cls._providers:
|
126 |
return None
|
127 |
|
128 |
provider_class = cls._providers[provider_name]
|
129 |
+
|
130 |
try:
|
131 |
provider = provider_class()
|
132 |
return {
|
|
|
137 |
"default_model": provider.get_default_model() if provider.is_available() else None
|
138 |
}
|
139 |
except Exception as e:
|
140 |
+
logger.info(f"Failed to get info for provider {provider_name}: {str(e)}")
|
141 |
return {
|
142 |
"name": provider_name,
|
143 |
"available": False,
|
|
|
160 |
# Legacy compatibility - create an ASRFactory alias
|
161 |
class ASRFactory:
|
162 |
"""Legacy ASRFactory for backward compatibility."""
|
163 |
+
|
164 |
@staticmethod
|
165 |
def get_model(model_name: str = "parakeet") -> STTProviderBase:
|
166 |
"""
|
167 |
Get STT provider by model name (legacy interface).
|
168 |
+
|
169 |
Args:
|
170 |
model_name: Name of the model/provider to use
|
171 |
+
|
172 |
Returns:
|
173 |
STTProviderBase: The provider instance
|
174 |
"""
|
|
|
178 |
"parakeet": "parakeet",
|
179 |
"faster-whisper": "whisper"
|
180 |
}
|
181 |
+
|
182 |
provider_name = provider_mapping.get(model_name.lower(), model_name.lower())
|
183 |
+
|
184 |
try:
|
185 |
return STTProviderFactory.create_provider(provider_name)
|
186 |
except SpeechRecognitionException:
|
src/infrastructure/stt/whisper_provider.py
CHANGED
@@ -36,7 +36,7 @@ class WhisperSTTProvider(STTProviderBase):
|
|
36 |
except ImportError:
|
37 |
# Fallback to CPU if torch is not available
|
38 |
self._device = "cpu"
|
39 |
-
|
40 |
self._compute_type = "float16" if self._device == "cuda" else "int8"
|
41 |
logger.info(f"Whisper provider initialized with device: {self._device}, compute_type: {self._compute_type}")
|
42 |
|
@@ -57,7 +57,7 @@ class WhisperSTTProvider(STTProviderBase):
|
|
57 |
self._load_model(model)
|
58 |
|
59 |
logger.info(f"Starting Whisper transcription with model {model}")
|
60 |
-
|
61 |
# Perform transcription
|
62 |
segments, info = self.model.transcribe(
|
63 |
str(audio_path),
|
@@ -72,7 +72,7 @@ class WhisperSTTProvider(STTProviderBase):
|
|
72 |
result_text = ""
|
73 |
for segment in segments:
|
74 |
result_text += segment.text + " "
|
75 |
-
logger.
|
76 |
|
77 |
result = result_text.strip()
|
78 |
logger.info("Whisper transcription completed successfully")
|
@@ -90,18 +90,18 @@ class WhisperSTTProvider(STTProviderBase):
|
|
90 |
"""
|
91 |
try:
|
92 |
from faster_whisper import WhisperModel as FasterWhisperModel
|
93 |
-
|
94 |
logger.info(f"Loading Whisper model: {model_name}")
|
95 |
logger.info(f"Using device: {self._device}, compute_type: {self._compute_type}")
|
96 |
-
|
97 |
self.model = FasterWhisperModel(
|
98 |
model_name,
|
99 |
device=self._device,
|
100 |
compute_type=self._compute_type
|
101 |
)
|
102 |
-
|
103 |
logger.info(f"Whisper model {model_name} loaded successfully")
|
104 |
-
|
105 |
except ImportError as e:
|
106 |
raise SpeechRecognitionException(
|
107 |
"faster-whisper not available. Please install with: pip install faster-whisper"
|
@@ -134,7 +134,7 @@ class WhisperSTTProvider(STTProviderBase):
|
|
134 |
"tiny",
|
135 |
"tiny.en",
|
136 |
"base",
|
137 |
-
"base.en",
|
138 |
"small",
|
139 |
"small.en",
|
140 |
"medium",
|
|
|
36 |
except ImportError:
|
37 |
# Fallback to CPU if torch is not available
|
38 |
self._device = "cpu"
|
39 |
+
|
40 |
self._compute_type = "float16" if self._device == "cuda" else "int8"
|
41 |
logger.info(f"Whisper provider initialized with device: {self._device}, compute_type: {self._compute_type}")
|
42 |
|
|
|
57 |
self._load_model(model)
|
58 |
|
59 |
logger.info(f"Starting Whisper transcription with model {model}")
|
60 |
+
|
61 |
# Perform transcription
|
62 |
segments, info = self.model.transcribe(
|
63 |
str(audio_path),
|
|
|
72 |
result_text = ""
|
73 |
for segment in segments:
|
74 |
result_text += segment.text + " "
|
75 |
+
logger.info(f"[{segment.start:.2f}s -> {segment.end:.2f}s] {segment.text}")
|
76 |
|
77 |
result = result_text.strip()
|
78 |
logger.info("Whisper transcription completed successfully")
|
|
|
90 |
"""
|
91 |
try:
|
92 |
from faster_whisper import WhisperModel as FasterWhisperModel
|
93 |
+
|
94 |
logger.info(f"Loading Whisper model: {model_name}")
|
95 |
logger.info(f"Using device: {self._device}, compute_type: {self._compute_type}")
|
96 |
+
|
97 |
self.model = FasterWhisperModel(
|
98 |
model_name,
|
99 |
device=self._device,
|
100 |
compute_type=self._compute_type
|
101 |
)
|
102 |
+
|
103 |
logger.info(f"Whisper model {model_name} loaded successfully")
|
104 |
+
|
105 |
except ImportError as e:
|
106 |
raise SpeechRecognitionException(
|
107 |
"faster-whisper not available. Please install with: pip install faster-whisper"
|
|
|
134 |
"tiny",
|
135 |
"tiny.en",
|
136 |
"base",
|
137 |
+
"base.en",
|
138 |
"small",
|
139 |
"small.en",
|
140 |
"medium",
|
src/infrastructure/translation/nllb_provider.py
CHANGED
@@ -430,7 +430,7 @@ class NLLBTranslationProvider(TranslationProviderBase):
|
|
430 |
# For simplicity, assume all languages can translate to all other languages
|
431 |
# In practice, you might want to be more specific about supported pairs
|
432 |
supported_languages[lang_code] = [
|
433 |
-
target for target in self.LANGUAGE_MAPPINGS.keys()
|
434 |
if target != lang_code
|
435 |
]
|
436 |
|
@@ -465,7 +465,7 @@ class NLLBTranslationProvider(TranslationProviderBase):
|
|
465 |
source_nllb = self._map_language_code(source_language)
|
466 |
target_nllb = self._map_language_code(target_language)
|
467 |
|
468 |
-
logger.
|
469 |
|
470 |
# Tokenize with source language specification
|
471 |
inputs = self._tokenizer(
|
@@ -490,7 +490,7 @@ class NLLBTranslationProvider(TranslationProviderBase):
|
|
490 |
# Post-process the translation
|
491 |
translated = self._postprocess_text(translated)
|
492 |
|
493 |
-
logger.
|
494 |
return translated
|
495 |
|
496 |
except Exception as e:
|
|
|
430 |
# For simplicity, assume all languages can translate to all other languages
|
431 |
# In practice, you might want to be more specific about supported pairs
|
432 |
supported_languages[lang_code] = [
|
433 |
+
target for target in self.LANGUAGE_MAPPINGS.keys()
|
434 |
if target != lang_code
|
435 |
]
|
436 |
|
|
|
465 |
source_nllb = self._map_language_code(source_language)
|
466 |
target_nllb = self._map_language_code(target_language)
|
467 |
|
468 |
+
logger.info(f"Translating chunk from {source_nllb} to {target_nllb}")
|
469 |
|
470 |
# Tokenize with source language specification
|
471 |
inputs = self._tokenizer(
|
|
|
490 |
# Post-process the translation
|
491 |
translated = self._postprocess_text(translated)
|
492 |
|
493 |
+
logger.info(f"Chunk translation completed: {len(text)} -> {len(translated)} chars")
|
494 |
return translated
|
495 |
|
496 |
except Exception as e:
|
src/infrastructure/translation/provider_factory.py
CHANGED
@@ -67,7 +67,7 @@ class TranslationProviderFactory:
|
|
67 |
|
68 |
# Return cached instance if available and requested
|
69 |
if use_cache and cache_key in self._provider_cache:
|
70 |
-
logger.
|
71 |
return self._provider_cache[cache_key]
|
72 |
|
73 |
# Check if provider type is registered
|
@@ -86,7 +86,7 @@ class TranslationProviderFactory:
|
|
86 |
final_config.update(config)
|
87 |
|
88 |
logger.info(f"Creating {provider_type.value} translation provider")
|
89 |
-
logger.
|
90 |
|
91 |
# Create provider instance
|
92 |
provider = provider_class(**final_config)
|
@@ -258,7 +258,7 @@ class TranslationProviderFactory:
|
|
258 |
# Cache the result
|
259 |
self._availability_cache[provider_type] = is_available
|
260 |
|
261 |
-
logger.
|
262 |
return is_available
|
263 |
|
264 |
except Exception as e:
|
|
|
67 |
|
68 |
# Return cached instance if available and requested
|
69 |
if use_cache and cache_key in self._provider_cache:
|
70 |
+
logger.info(f"Returning cached {provider_type.value} provider")
|
71 |
return self._provider_cache[cache_key]
|
72 |
|
73 |
# Check if provider type is registered
|
|
|
86 |
final_config.update(config)
|
87 |
|
88 |
logger.info(f"Creating {provider_type.value} translation provider")
|
89 |
+
logger.info(f"Provider config: {final_config}")
|
90 |
|
91 |
# Create provider instance
|
92 |
provider = provider_class(**final_config)
|
|
|
258 |
# Cache the result
|
259 |
self._availability_cache[provider_type] = is_available
|
260 |
|
261 |
+
logger.info(f"Provider {provider_type.value} availability: {is_available}")
|
262 |
return is_available
|
263 |
|
264 |
except Exception as e:
|
src/infrastructure/tts/dia_provider.py
CHANGED
@@ -19,19 +19,70 @@ DIA_AVAILABLE = False
|
|
19 |
DEFAULT_SAMPLE_RATE = 24000
|
20 |
|
21 |
# Try to import Dia dependencies
|
22 |
-
|
23 |
-
|
24 |
-
|
25 |
-
|
26 |
-
logger.info("Dia TTS
|
27 |
-
|
28 |
-
|
29 |
-
|
30 |
-
|
31 |
-
logger.
|
32 |
-
|
33 |
-
logger.
|
34 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
35 |
|
36 |
|
37 |
class DiaTTSProvider(TTSProviderBase):
|
@@ -48,26 +99,58 @@ class DiaTTSProvider(TTSProviderBase):
|
|
48 |
|
49 |
def _ensure_model(self):
|
50 |
"""Ensure the model is loaded."""
|
51 |
-
|
52 |
-
|
53 |
-
|
54 |
-
|
55 |
-
|
56 |
-
|
57 |
-
|
58 |
-
logger.
|
59 |
-
|
60 |
-
|
61 |
-
|
62 |
-
|
63 |
-
|
64 |
-
|
65 |
-
|
66 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
67 |
|
68 |
def is_available(self) -> bool:
|
69 |
"""Check if Dia TTS is available."""
|
70 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
71 |
|
72 |
def get_available_voices(self) -> list[str]:
|
73 |
"""Get available voices for Dia."""
|
@@ -81,7 +164,7 @@ class DiaTTSProvider(TTSProviderBase):
|
|
81 |
|
82 |
try:
|
83 |
import torch
|
84 |
-
|
85 |
# Extract parameters from request
|
86 |
text = request.text_content.text
|
87 |
|
@@ -120,7 +203,7 @@ class DiaTTSProvider(TTSProviderBase):
|
|
120 |
|
121 |
try:
|
122 |
import torch
|
123 |
-
|
124 |
# Extract parameters from request
|
125 |
text = request.text_content.text
|
126 |
|
@@ -158,13 +241,13 @@ class DiaTTSProvider(TTSProviderBase):
|
|
158 |
try:
|
159 |
# Create an in-memory buffer
|
160 |
buffer = io.BytesIO()
|
161 |
-
|
162 |
# Write audio data to buffer as WAV
|
163 |
sf.write(buffer, audio_array, sample_rate, format='WAV')
|
164 |
-
|
165 |
# Get bytes from buffer
|
166 |
buffer.seek(0)
|
167 |
return buffer.read()
|
168 |
-
|
169 |
except Exception as e:
|
170 |
raise SpeechSynthesisException(f"Failed to convert audio to bytes: {str(e)}") from e
|
|
|
19 |
DEFAULT_SAMPLE_RATE = 24000
|
20 |
|
21 |
# Try to import Dia dependencies
|
22 |
+
def _check_and_install_dia_dependencies():
|
23 |
+
"""Check and install Dia dependencies if needed."""
|
24 |
+
global DIA_AVAILABLE
|
25 |
+
|
26 |
+
logger.info("🔍 Checking Dia TTS dependencies...")
|
27 |
+
|
28 |
+
try:
|
29 |
+
logger.info("Attempting to import torch...")
|
30 |
+
import torch
|
31 |
+
logger.info("✓ Successfully imported torch")
|
32 |
+
|
33 |
+
logger.info("Attempting to import dia.model...")
|
34 |
+
from dia.model import Dia
|
35 |
+
logger.info("✓ Successfully imported dia.model")
|
36 |
+
|
37 |
+
DIA_AVAILABLE = True
|
38 |
+
logger.info("✅ Dia TTS engine is available")
|
39 |
+
return True
|
40 |
+
except ImportError as e:
|
41 |
+
logger.warning(f"⚠️ Dia TTS engine dependencies not available: {e}")
|
42 |
+
logger.info(f"ImportError details: {type(e).__name__}: {e}")
|
43 |
+
except ModuleNotFoundError as e:
|
44 |
+
if "dac" in str(e):
|
45 |
+
logger.warning("❌ Dia TTS engine is not available due to missing 'dac' module")
|
46 |
+
elif "dia" in str(e):
|
47 |
+
logger.warning("❌ Dia TTS engine is not available due to missing 'dia' module")
|
48 |
+
else:
|
49 |
+
logger.warning(f"❌ Dia TTS engine is not available: {str(e)}")
|
50 |
+
logger.info(f"ModuleNotFoundError details: {type(e).__name__}: {e}")
|
51 |
+
|
52 |
+
# Try to install missing dependencies
|
53 |
+
logger.info("🔧 Attempting to install Dia TTS dependencies...")
|
54 |
+
try:
|
55 |
+
installer = get_dependency_installer()
|
56 |
+
success, errors = installer.install_dia_dependencies()
|
57 |
+
|
58 |
+
if success:
|
59 |
+
logger.info("✅ Successfully installed Dia TTS dependencies")
|
60 |
+
# Try importing again after installation
|
61 |
+
try:
|
62 |
+
logger.info("Re-attempting import after installation...")
|
63 |
+
import torch
|
64 |
+
from dia.model import Dia
|
65 |
+
DIA_AVAILABLE = True
|
66 |
+
logger.info("🎉 Dia TTS engine is now available after installation")
|
67 |
+
return True
|
68 |
+
except Exception as e:
|
69 |
+
logger.error(f"❌ Dia TTS still not available after installation: {e}")
|
70 |
+
logger.info(f"Post-installation import error: {type(e).__name__}: {e}")
|
71 |
+
DIA_AVAILABLE = False
|
72 |
+
return False
|
73 |
+
else:
|
74 |
+
logger.error(f"❌ Failed to install Dia TTS dependencies: {errors}")
|
75 |
+
DIA_AVAILABLE = False
|
76 |
+
return False
|
77 |
+
except Exception as e:
|
78 |
+
logger.error(f"❌ Error during dependency installation: {e}")
|
79 |
+
logger.info(f"Installation error details: {type(e).__name__}: {e}")
|
80 |
+
DIA_AVAILABLE = False
|
81 |
+
return False
|
82 |
+
|
83 |
+
# Initial check
|
84 |
+
logger.info("🚀 Initializing Dia TTS provider...")
|
85 |
+
_check_and_install_dia_dependencies()
|
86 |
|
87 |
|
88 |
class DiaTTSProvider(TTSProviderBase):
|
|
|
99 |
|
100 |
def _ensure_model(self):
|
101 |
"""Ensure the model is loaded."""
|
102 |
+
global DIA_AVAILABLE
|
103 |
+
|
104 |
+
if self.model is None:
|
105 |
+
logger.info("🔄 Ensuring Dia model is loaded...")
|
106 |
+
|
107 |
+
# If Dia is not available, try to install dependencies
|
108 |
+
if not DIA_AVAILABLE:
|
109 |
+
logger.info("⚠️ Dia not available, attempting to install dependencies...")
|
110 |
+
if _check_and_install_dia_dependencies():
|
111 |
+
DIA_AVAILABLE = True
|
112 |
+
logger.info("✅ Dependencies installed, Dia is now available")
|
113 |
+
else:
|
114 |
+
logger.error("❌ Failed to install dependencies, Dia remains unavailable")
|
115 |
+
return False
|
116 |
+
|
117 |
+
if DIA_AVAILABLE:
|
118 |
+
try:
|
119 |
+
logger.info("📥 Loading Dia model from pretrained...")
|
120 |
+
import torch
|
121 |
+
from dia.model import Dia
|
122 |
+
self.model = Dia.from_pretrained()
|
123 |
+
logger.info("🎉 Dia model successfully loaded")
|
124 |
+
except ImportError as e:
|
125 |
+
logger.error(f"❌ Failed to import Dia dependencies: {str(e)}")
|
126 |
+
self.model = None
|
127 |
+
except FileNotFoundError as e:
|
128 |
+
logger.error(f"❌ Failed to load Dia model files: {str(e)}")
|
129 |
+
logger.info("ℹ️ This might be the first time loading the model. It will be downloaded automatically.")
|
130 |
+
self.model = None
|
131 |
+
except Exception as e:
|
132 |
+
logger.error(f"❌ Failed to initialize Dia model: {str(e)}")
|
133 |
+
logger.info(f"Model initialization error: {type(e).__name__}: {e}")
|
134 |
+
self.model = None
|
135 |
+
|
136 |
+
is_available = self.model is not None
|
137 |
+
logger.info(f"Model availability check result: {is_available}")
|
138 |
+
return is_available
|
139 |
|
140 |
def is_available(self) -> bool:
|
141 |
"""Check if Dia TTS is available."""
|
142 |
+
logger.info(f"🔍 Checking Dia availability: DIA_AVAILABLE={DIA_AVAILABLE}")
|
143 |
+
|
144 |
+
if not DIA_AVAILABLE:
|
145 |
+
logger.info("❌ Dia dependencies not available")
|
146 |
+
return False
|
147 |
+
|
148 |
+
model_available = self._ensure_model()
|
149 |
+
logger.info(f"🔍 Model availability: {model_available}")
|
150 |
+
|
151 |
+
result = DIA_AVAILABLE and model_available
|
152 |
+
logger.info(f"🎯 Dia TTS availability result: {result}")
|
153 |
+
return result
|
154 |
|
155 |
def get_available_voices(self) -> list[str]:
|
156 |
"""Get available voices for Dia."""
|
|
|
164 |
|
165 |
try:
|
166 |
import torch
|
167 |
+
|
168 |
# Extract parameters from request
|
169 |
text = request.text_content.text
|
170 |
|
|
|
203 |
|
204 |
try:
|
205 |
import torch
|
206 |
+
|
207 |
# Extract parameters from request
|
208 |
text = request.text_content.text
|
209 |
|
|
|
241 |
try:
|
242 |
# Create an in-memory buffer
|
243 |
buffer = io.BytesIO()
|
244 |
+
|
245 |
# Write audio data to buffer as WAV
|
246 |
sf.write(buffer, audio_array, sample_rate, format='WAV')
|
247 |
+
|
248 |
# Get bytes from buffer
|
249 |
buffer.seek(0)
|
250 |
return buffer.read()
|
251 |
+
|
252 |
except Exception as e:
|
253 |
raise SpeechSynthesisException(f"Failed to convert audio to bytes: {str(e)}") from e
|
src/infrastructure/tts/dummy_provider.py
CHANGED
@@ -44,14 +44,14 @@ class DummyTTSProvider(TTSProviderBase):
|
|
44 |
sample_rate = 24000
|
45 |
# Rough approximation of speech duration adjusted by speed
|
46 |
duration = min(len(text) / (20 * speed), 10)
|
47 |
-
|
48 |
# Create time array
|
49 |
t = np.linspace(0, duration, int(sample_rate * duration), endpoint=False)
|
50 |
-
|
51 |
# Generate sine wave (440 Hz base frequency)
|
52 |
frequency = 440
|
53 |
audio = 0.5 * np.sin(2 * np.pi * frequency * t)
|
54 |
-
|
55 |
# Add some variation based on voice setting
|
56 |
voice = request.voice_settings.voice_id
|
57 |
if voice == 'male':
|
@@ -66,7 +66,7 @@ class DummyTTSProvider(TTSProviderBase):
|
|
66 |
|
67 |
# Convert to bytes
|
68 |
audio_bytes = self._numpy_to_bytes(audio, sample_rate)
|
69 |
-
|
70 |
logger.info(f"Generated dummy audio: duration={duration:.2f}s, voice={voice}")
|
71 |
return audio_bytes, sample_rate
|
72 |
|
@@ -84,24 +84,24 @@ class DummyTTSProvider(TTSProviderBase):
|
|
84 |
sample_rate = 24000
|
85 |
chunk_duration = 1.0 # 1 second chunks
|
86 |
total_duration = min(len(text) / (20 * speed), 10)
|
87 |
-
|
88 |
chunks_count = int(np.ceil(total_duration / chunk_duration))
|
89 |
-
|
90 |
for chunk_idx in range(chunks_count):
|
91 |
start_time = chunk_idx * chunk_duration
|
92 |
end_time = min((chunk_idx + 1) * chunk_duration, total_duration)
|
93 |
actual_duration = end_time - start_time
|
94 |
-
|
95 |
if actual_duration <= 0:
|
96 |
break
|
97 |
-
|
98 |
# Create time array for this chunk
|
99 |
t = np.linspace(0, actual_duration, int(sample_rate * actual_duration), endpoint=False)
|
100 |
-
|
101 |
# Generate sine wave
|
102 |
frequency = 440
|
103 |
audio = 0.5 * np.sin(2 * np.pi * frequency * t)
|
104 |
-
|
105 |
# Apply voice variations
|
106 |
voice = request.voice_settings.voice_id
|
107 |
if voice == 'male':
|
@@ -113,10 +113,10 @@ class DummyTTSProvider(TTSProviderBase):
|
|
113 |
|
114 |
# Convert to bytes
|
115 |
audio_bytes = self._numpy_to_bytes(audio, sample_rate)
|
116 |
-
|
117 |
# Check if this is the final chunk
|
118 |
is_final = (chunk_idx == chunks_count - 1)
|
119 |
-
|
120 |
yield audio_bytes, sample_rate, is_final
|
121 |
|
122 |
except Exception as e:
|
@@ -127,13 +127,13 @@ class DummyTTSProvider(TTSProviderBase):
|
|
127 |
try:
|
128 |
# Create an in-memory buffer
|
129 |
buffer = io.BytesIO()
|
130 |
-
|
131 |
# Write audio data to buffer as WAV
|
132 |
sf.write(buffer, audio_array, sample_rate, format='WAV')
|
133 |
-
|
134 |
# Get bytes from buffer
|
135 |
buffer.seek(0)
|
136 |
return buffer.read()
|
137 |
-
|
138 |
except Exception as e:
|
139 |
raise SpeechSynthesisException(f"Failed to convert audio to bytes: {str(e)}") from e
|
|
|
44 |
sample_rate = 24000
|
45 |
# Rough approximation of speech duration adjusted by speed
|
46 |
duration = min(len(text) / (20 * speed), 10)
|
47 |
+
|
48 |
# Create time array
|
49 |
t = np.linspace(0, duration, int(sample_rate * duration), endpoint=False)
|
50 |
+
|
51 |
# Generate sine wave (440 Hz base frequency)
|
52 |
frequency = 440
|
53 |
audio = 0.5 * np.sin(2 * np.pi * frequency * t)
|
54 |
+
|
55 |
# Add some variation based on voice setting
|
56 |
voice = request.voice_settings.voice_id
|
57 |
if voice == 'male':
|
|
|
66 |
|
67 |
# Convert to bytes
|
68 |
audio_bytes = self._numpy_to_bytes(audio, sample_rate)
|
69 |
+
|
70 |
logger.info(f"Generated dummy audio: duration={duration:.2f}s, voice={voice}")
|
71 |
return audio_bytes, sample_rate
|
72 |
|
|
|
84 |
sample_rate = 24000
|
85 |
chunk_duration = 1.0 # 1 second chunks
|
86 |
total_duration = min(len(text) / (20 * speed), 10)
|
87 |
+
|
88 |
chunks_count = int(np.ceil(total_duration / chunk_duration))
|
89 |
+
|
90 |
for chunk_idx in range(chunks_count):
|
91 |
start_time = chunk_idx * chunk_duration
|
92 |
end_time = min((chunk_idx + 1) * chunk_duration, total_duration)
|
93 |
actual_duration = end_time - start_time
|
94 |
+
|
95 |
if actual_duration <= 0:
|
96 |
break
|
97 |
+
|
98 |
# Create time array for this chunk
|
99 |
t = np.linspace(0, actual_duration, int(sample_rate * actual_duration), endpoint=False)
|
100 |
+
|
101 |
# Generate sine wave
|
102 |
frequency = 440
|
103 |
audio = 0.5 * np.sin(2 * np.pi * frequency * t)
|
104 |
+
|
105 |
# Apply voice variations
|
106 |
voice = request.voice_settings.voice_id
|
107 |
if voice == 'male':
|
|
|
113 |
|
114 |
# Convert to bytes
|
115 |
audio_bytes = self._numpy_to_bytes(audio, sample_rate)
|
116 |
+
|
117 |
# Check if this is the final chunk
|
118 |
is_final = (chunk_idx == chunks_count - 1)
|
119 |
+
|
120 |
yield audio_bytes, sample_rate, is_final
|
121 |
|
122 |
except Exception as e:
|
|
|
127 |
try:
|
128 |
# Create an in-memory buffer
|
129 |
buffer = io.BytesIO()
|
130 |
+
|
131 |
# Write audio data to buffer as WAV
|
132 |
sf.write(buffer, audio_array, sample_rate, format='WAV')
|
133 |
+
|
134 |
# Get bytes from buffer
|
135 |
buffer.seek(0)
|
136 |
return buffer.read()
|
137 |
+
|
138 |
except Exception as e:
|
139 |
raise SpeechSynthesisException(f"Failed to convert audio to bytes: {str(e)}") from e
|
src/infrastructure/tts/kokoro_provider.py
CHANGED
@@ -77,7 +77,7 @@ class KokoroTTSProvider(TTSProviderBase):
|
|
77 |
|
78 |
# Generate speech using Kokoro
|
79 |
generator = self.pipeline(text, voice=voice, speed=speed)
|
80 |
-
|
81 |
for _, _, audio in generator:
|
82 |
# Convert numpy array to bytes
|
83 |
audio_bytes = self._numpy_to_bytes(audio, sample_rate=24000)
|
@@ -101,7 +101,7 @@ class KokoroTTSProvider(TTSProviderBase):
|
|
101 |
|
102 |
# Generate speech stream using Kokoro
|
103 |
generator = self.pipeline(text, voice=voice, speed=speed)
|
104 |
-
|
105 |
chunk_count = 0
|
106 |
for _, _, audio in generator:
|
107 |
chunk_count += 1
|
@@ -119,13 +119,13 @@ class KokoroTTSProvider(TTSProviderBase):
|
|
119 |
try:
|
120 |
# Create an in-memory buffer
|
121 |
buffer = io.BytesIO()
|
122 |
-
|
123 |
# Write audio data to buffer as WAV
|
124 |
sf.write(buffer, audio_array, sample_rate, format='WAV')
|
125 |
-
|
126 |
# Get bytes from buffer
|
127 |
buffer.seek(0)
|
128 |
return buffer.read()
|
129 |
-
|
130 |
except Exception as e:
|
131 |
raise SpeechSynthesisException(f"Failed to convert audio to bytes: {str(e)}") from e
|
|
|
77 |
|
78 |
# Generate speech using Kokoro
|
79 |
generator = self.pipeline(text, voice=voice, speed=speed)
|
80 |
+
|
81 |
for _, _, audio in generator:
|
82 |
# Convert numpy array to bytes
|
83 |
audio_bytes = self._numpy_to_bytes(audio, sample_rate=24000)
|
|
|
101 |
|
102 |
# Generate speech stream using Kokoro
|
103 |
generator = self.pipeline(text, voice=voice, speed=speed)
|
104 |
+
|
105 |
chunk_count = 0
|
106 |
for _, _, audio in generator:
|
107 |
chunk_count += 1
|
|
|
119 |
try:
|
120 |
# Create an in-memory buffer
|
121 |
buffer = io.BytesIO()
|
122 |
+
|
123 |
# Write audio data to buffer as WAV
|
124 |
sf.write(buffer, audio_array, sample_rate, format='WAV')
|
125 |
+
|
126 |
# Get bytes from buffer
|
127 |
buffer.seek(0)
|
128 |
return buffer.read()
|
129 |
+
|
130 |
except Exception as e:
|
131 |
raise SpeechSynthesisException(f"Failed to convert audio to bytes: {str(e)}") from e
|
src/infrastructure/tts/provider_factory.py
CHANGED
@@ -31,7 +31,7 @@ class TTSProviderFactory:
|
|
31 |
self._providers['kokoro'] = KokoroTTSProvider
|
32 |
logger.info("Registered Kokoro TTS provider")
|
33 |
except ImportError as e:
|
34 |
-
logger.
|
35 |
|
36 |
# Try to register Dia provider
|
37 |
try:
|
@@ -56,18 +56,23 @@ class TTSProviderFactory:
|
|
56 |
self._providers['cosyvoice2'] = CosyVoice2TTSProvider
|
57 |
logger.info("Registered CosyVoice2 TTS provider")
|
58 |
except ImportError as e:
|
59 |
-
logger.
|
60 |
|
61 |
def get_available_providers(self) -> List[str]:
|
62 |
"""Get list of available TTS providers."""
|
|
|
63 |
available = []
|
|
|
64 |
for name, provider_class in self._providers.items():
|
|
|
65 |
try:
|
66 |
# Create instance if not cached
|
67 |
if name not in self._provider_instances:
|
|
|
68 |
if name == 'kokoro':
|
69 |
self._provider_instances[name] = provider_class()
|
70 |
elif name == 'dia':
|
|
|
71 |
self._provider_instances[name] = provider_class()
|
72 |
elif name == 'cosyvoice2':
|
73 |
self._provider_instances[name] = provider_class()
|
@@ -75,12 +80,18 @@ class TTSProviderFactory:
|
|
75 |
self._provider_instances[name] = provider_class()
|
76 |
|
77 |
# Check if provider is available
|
78 |
-
|
|
|
|
|
|
|
|
|
79 |
available.append(name)
|
80 |
|
81 |
except Exception as e:
|
82 |
-
logger.warning(f"Failed to check availability of {name} provider: {e}")
|
|
|
83 |
|
|
|
84 |
return available
|
85 |
|
86 |
def create_provider(self, provider_name: str, **kwargs) -> TTSProviderBase:
|
@@ -147,16 +158,23 @@ class TTSProviderFactory:
|
|
147 |
if preferred_providers is None:
|
148 |
preferred_providers = ['kokoro', 'dia', 'cosyvoice2', 'dummy']
|
149 |
|
|
|
150 |
available_providers = self.get_available_providers()
|
151 |
|
152 |
# Try preferred providers in order
|
153 |
for provider_name in preferred_providers:
|
|
|
154 |
if provider_name in available_providers:
|
|
|
155 |
try:
|
156 |
-
|
|
|
|
|
157 |
except Exception as e:
|
158 |
-
logger.warning(f"Failed to create preferred provider {provider_name}: {e}")
|
159 |
continue
|
|
|
|
|
160 |
|
161 |
# If no preferred providers work, try any available provider
|
162 |
for provider_name in available_providers:
|
|
|
31 |
self._providers['kokoro'] = KokoroTTSProvider
|
32 |
logger.info("Registered Kokoro TTS provider")
|
33 |
except ImportError as e:
|
34 |
+
logger.info(f"Kokoro TTS provider not available: {e}")
|
35 |
|
36 |
# Try to register Dia provider
|
37 |
try:
|
|
|
56 |
self._providers['cosyvoice2'] = CosyVoice2TTSProvider
|
57 |
logger.info("Registered CosyVoice2 TTS provider")
|
58 |
except ImportError as e:
|
59 |
+
logger.info(f"CosyVoice2 TTS provider not available: {e}")
|
60 |
|
61 |
def get_available_providers(self) -> List[str]:
|
62 |
"""Get list of available TTS providers."""
|
63 |
+
logger.info("🔍 Checking availability of TTS providers...")
|
64 |
available = []
|
65 |
+
|
66 |
for name, provider_class in self._providers.items():
|
67 |
+
logger.info(f"Checking provider: {name}")
|
68 |
try:
|
69 |
# Create instance if not cached
|
70 |
if name not in self._provider_instances:
|
71 |
+
logger.info(f"Creating instance for {name} provider")
|
72 |
if name == 'kokoro':
|
73 |
self._provider_instances[name] = provider_class()
|
74 |
elif name == 'dia':
|
75 |
+
logger.info(f"🔧 Creating Dia TTS provider instance...")
|
76 |
self._provider_instances[name] = provider_class()
|
77 |
elif name == 'cosyvoice2':
|
78 |
self._provider_instances[name] = provider_class()
|
|
|
80 |
self._provider_instances[name] = provider_class()
|
81 |
|
82 |
# Check if provider is available
|
83 |
+
logger.info(f"Checking availability for {name}")
|
84 |
+
is_available = self._provider_instances[name].is_available()
|
85 |
+
logger.info(f"Provider {name} availability: {'✅ Available' if is_available else '❌ Not Available'}")
|
86 |
+
|
87 |
+
if is_available:
|
88 |
available.append(name)
|
89 |
|
90 |
except Exception as e:
|
91 |
+
logger.warning(f"❌ Failed to check availability of {name} provider: {e}")
|
92 |
+
logger.info(f"Provider check error details: {type(e).__name__}: {e}")
|
93 |
|
94 |
+
logger.info(f"📋 Available TTS providers: {available}")
|
95 |
return available
|
96 |
|
97 |
def create_provider(self, provider_name: str, **kwargs) -> TTSProviderBase:
|
|
|
158 |
if preferred_providers is None:
|
159 |
preferred_providers = ['kokoro', 'dia', 'cosyvoice2', 'dummy']
|
160 |
|
161 |
+
logger.info(f"🔄 Getting TTS provider with fallback, preferred order: {preferred_providers}")
|
162 |
available_providers = self.get_available_providers()
|
163 |
|
164 |
# Try preferred providers in order
|
165 |
for provider_name in preferred_providers:
|
166 |
+
logger.info(f"🔍 Trying preferred provider: {provider_name}")
|
167 |
if provider_name in available_providers:
|
168 |
+
logger.info(f"✅ Provider {provider_name} is available, attempting to create...")
|
169 |
try:
|
170 |
+
provider = self.create_provider(provider_name, **kwargs)
|
171 |
+
logger.info(f"🎉 Successfully created provider: {provider_name}")
|
172 |
+
return provider
|
173 |
except Exception as e:
|
174 |
+
logger.warning(f"❌ Failed to create preferred provider {provider_name}: {e}")
|
175 |
continue
|
176 |
+
else:
|
177 |
+
logger.info(f"❌ Provider {provider_name} is not in available providers list")
|
178 |
|
179 |
# If no preferred providers work, try any available provider
|
180 |
for provider_name in available_providers:
|
src/infrastructure/utils/dependency_installer.py
ADDED
@@ -0,0 +1,304 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""Automatic dependency installer for TTS providers."""
|
2 |
+
|
3 |
+
import logging
|
4 |
+
import subprocess
|
5 |
+
import sys
|
6 |
+
import importlib
|
7 |
+
from typing import List, Dict, Optional, Tuple
|
8 |
+
import os
|
9 |
+
|
10 |
+
logger = logging.getLogger(__name__)
|
11 |
+
|
12 |
+
|
13 |
+
class DependencyInstaller:
    """Utility class for automatically installing missing dependencies."""

    def __init__(self):
        """Initialize the dependency installer."""
        # Names of packages successfully installed during this interpreter
        # session; used to short-circuit repeat install requests.
        self.installed_packages = set()

    def check_module_available(self, module_name: str) -> bool:
        """
        Check if a module is available for import.

        Args:
            module_name: Name of the module to check

        Returns:
            bool: True if module is available, False otherwise
        """
        try:
            importlib.import_module(module_name)
            return True
        except ImportError:
            return False

    def install_package(self, package_name: str, upgrade: bool = False) -> bool:
        """
        Install a package using pip.

        Args:
            package_name: Name of the package to install
            upgrade: Whether to upgrade if already installed

        Returns:
            bool: True if installation succeeded, False otherwise
        """
        if package_name in self.installed_packages:
            logger.info(f"Package {package_name} already installed in this session")
            return True

        try:
            # Run pip through the current interpreter so the package lands in
            # the same environment this process is using.
            cmd = [sys.executable, "-m", "pip", "install"]
            if upgrade:
                cmd.append("--upgrade")
            cmd.append(package_name)

            logger.info(f"Installing package: {package_name}")
            result = subprocess.run(
                cmd,
                capture_output=True,
                text=True,
                timeout=300  # 5 minute timeout
            )

            if result.returncode == 0:
                logger.info(f"Successfully installed {package_name}")
                self.installed_packages.add(package_name)
                return True
            else:
                logger.error(f"Failed to install {package_name}: {result.stderr}")
                return False

        except subprocess.TimeoutExpired:
            logger.error(f"Installation of {package_name} timed out")
            return False
        except Exception as e:
            logger.error(f"Error installing {package_name}: {e}")
            return False

    def install_from_git(self, git_url: str, package_name: Optional[str] = None) -> bool:
        """
        Install a package from a git repository.

        Args:
            git_url: Git repository URL
            package_name: Optional package name for tracking

        Returns:
            bool: True if installation succeeded, False otherwise
        """
        # Derive a tracking name from the repo URL when none is supplied.
        package_name = package_name or git_url.split('/')[-1].replace('.git', '')

        if package_name in self.installed_packages:
            logger.info(f"Package {package_name} already installed in this session")
            return True

        try:
            cmd = [sys.executable, "-m", "pip", "install", f"git+{git_url}"]

            logger.info(f"Installing package from git: {git_url}")
            result = subprocess.run(
                cmd,
                capture_output=True,
                text=True,
                timeout=600  # 10 minute timeout for git installs
            )

            if result.returncode == 0:
                logger.info(f"Successfully installed {package_name} from git")
                self.installed_packages.add(package_name)
                return True
            else:
                logger.error(f"Failed to install {package_name} from git: {result.stderr}")
                return False

        except subprocess.TimeoutExpired:
            logger.error(f"Git installation of {package_name} timed out")
            return False
        except Exception as e:
            logger.error(f"Error installing {package_name} from git: {e}")
            return False

    def install_dia_dependencies(self) -> Tuple[bool, List[str]]:
        """
        Install all dependencies required for Dia TTS.

        Returns:
            Tuple[bool, List[str]]: (success, list of error messages)
        """
        errors = []

        # Check if Dia is already available
        if self.check_module_available("dia"):
            logger.info("Dia TTS is already available")
            return True, []

        # Install Dia TTS from git - this will automatically install all dependencies
        # including descript-audio-codec as specified in pyproject.toml
        logger.info("Installing Dia TTS and all dependencies from GitHub")
        if self.install_from_git("https://github.com/nari-labs/dia.git", "dia"):
            logger.info("Successfully installed Dia TTS and dependencies")
            return True, []
        else:
            errors.append("Failed to install Dia TTS from git")

        # Fallback: try installing individual dependencies if git install fails
        logger.info("Git install failed, trying individual dependencies...")
        dependencies = [
            ("torch", "torch"),
            ("transformers", "transformers"),
            ("accelerate", "accelerate"),
            ("soundfile", "soundfile"),
            ("dac", "descript-audio-codec"),
        ]

        success = True
        for module_name, package_name in dependencies:
            if not self.check_module_available(module_name):
                logger.info(f"Installing missing dependency: {package_name}")
                if not self.install_package(package_name):
                    errors.append(f"Failed to install {package_name}")
                    success = False

        # Try installing Dia again after dependencies
        if success and not self.check_module_available("dia"):
            if self.install_from_git("https://github.com/nari-labs/dia.git", "dia"):
                return True, []
            else:
                errors.append("Failed to install Dia TTS after installing dependencies")

        # NOTE(review): reports success only when the fallback deps all
        # installed and the sole recorded error is the initial git failure
        # (i.e. "dia" became importable without a reinstall) -- confirm this
        # is the intended contract.
        return success and len(errors) == 1, errors  # Only the initial git error if dependencies succeeded

    def install_dependencies_for_provider(self, provider_name: str) -> Tuple[bool, List[str]]:
        """
        Install dependencies for a specific TTS provider.

        Args:
            provider_name: Name of the TTS provider

        Returns:
            Tuple[bool, List[str]]: (success, list of error messages)
        """
        if provider_name.lower() == "dia":
            return self.install_dia_dependencies()
        else:
            return False, [f"Unknown provider: {provider_name}"]

    def verify_installation(self, module_name: str) -> bool:
        """
        Verify that a module was installed correctly.

        Args:
            module_name: Name of the module to verify

        Returns:
            bool: True if module can be imported, False otherwise
        """
        try:
            # Clear import cache to ensure fresh import
            # NOTE(review): this drops only the top-level sys.modules entry;
            # previously imported submodules remain cached.
            if module_name in sys.modules:
                del sys.modules[module_name]

            importlib.import_module(module_name)
            logger.info(f"Successfully verified installation of {module_name}")
            return True
        except ImportError as e:
            logger.error(f"Failed to verify installation of {module_name}: {e}")
            return False

    def get_installation_status(self) -> Dict[str, bool]:
        """
        Get the installation status of key dependencies.

        Returns:
            Dict[str, bool]: Dictionary mapping module names to availability status
        """
        modules_to_check = [
            "torch",
            "transformers",
            "accelerate",
            "soundfile",
            "numpy",
            "dac",
            "dia"
        ]

        status = {}
        for module in modules_to_check:
            status[module] = self.check_module_available(module)

        return status

    def install_with_retry(self, package_name: str, max_retries: int = 3) -> bool:
        """
        Install a package with retry logic.

        Args:
            package_name: Name of the package to install
            max_retries: Maximum number of retry attempts

        Returns:
            bool: True if installation succeeded, False otherwise
        """
        for attempt in range(max_retries):
            if self.install_package(package_name):
                return True

            if attempt < max_retries - 1:
                logger.warning(f"Installation attempt {attempt + 1} failed for {package_name}, retrying...")
            else:
                logger.error(f"All {max_retries} installation attempts failed for {package_name}")

        return False


# Global instance for reuse
_dependency_installer = None


def get_dependency_installer() -> DependencyInstaller:
    """
    Get a global dependency installer instance.

    Returns:
        DependencyInstaller: Global dependency installer instance
    """
    global _dependency_installer
    # Lazily create the process-wide singleton on first use.
    if _dependency_installer is None:
        _dependency_installer = DependencyInstaller()
    return _dependency_installer


def install_dia_dependencies() -> Tuple[bool, List[str]]:
    """
    Convenience function to install Dia TTS dependencies.

    Returns:
        Tuple[bool, List[str]]: (success, list of error messages)
    """
    installer = get_dependency_installer()
    return installer.install_dia_dependencies()


def check_and_install_module(module_name: str, package_name: Optional[str] = None) -> bool:
    """
    Check if a module is available and install it if not.

    Args:
        module_name: Name of the module to check
        package_name: Name of the package to install (defaults to module_name)

    Returns:
        bool: True if module is available after check/install, False otherwise
    """
    installer = get_dependency_installer()

    if installer.check_module_available(module_name):
        return True

    package_name = package_name or module_name
    if installer.install_package(package_name):
        # Only report success if the module actually imports after install.
        return installer.verify_installation(module_name)

    return False
|
tests/unit/application/error_handling/test_structured_logger.py
CHANGED
@@ -60,7 +60,7 @@ class TestStructuredLogger:
|
|
60 |
context = LogContext(correlation_id="test-123", operation="test_op")
|
61 |
|
62 |
with patch.object(self.logger.logger, 'debug') as mock_debug:
|
63 |
-
self.logger.
|
64 |
|
65 |
mock_debug.assert_called_once()
|
66 |
args, kwargs = mock_debug.call_args
|
|
|
60 |
context = LogContext(correlation_id="test-123", operation="test_op")
|
61 |
|
62 |
with patch.object(self.logger.logger, 'debug') as mock_debug:
|
63 |
+
self.logger.info("Test debug message", context=context)
|
64 |
|
65 |
mock_debug.assert_called_once()
|
66 |
args, kwargs = mock_debug.call_args
|
utils/stt.py
CHANGED
@@ -16,17 +16,17 @@ from pydub import AudioSegment
|
|
16 |
|
17 |
class ASRModel(ABC):
|
18 |
"""Base class for ASR models"""
|
19 |
-
|
20 |
@abstractmethod
|
21 |
def load_model(self):
|
22 |
"""Load the ASR model"""
|
23 |
pass
|
24 |
-
|
25 |
@abstractmethod
|
26 |
def transcribe(self, audio_path):
|
27 |
"""Transcribe audio to text"""
|
28 |
pass
|
29 |
-
|
30 |
def preprocess_audio(self, audio_path):
|
31 |
"""Convert audio to required format"""
|
32 |
logger.info("Converting audio format")
|
@@ -42,7 +42,7 @@ class ASRModel(ABC):
|
|
42 |
|
43 |
class WhisperModel(ASRModel):
|
44 |
"""Faster Whisper ASR model implementation"""
|
45 |
-
|
46 |
def __init__(self):
|
47 |
self.model = None
|
48 |
# Check for CUDA availability without torch dependency
|
@@ -53,13 +53,13 @@ class WhisperModel(ASRModel):
|
|
53 |
# Fallback to CPU if torch is not available
|
54 |
self.device = "cpu"
|
55 |
self.compute_type = "float16" if self.device == "cuda" else "int8"
|
56 |
-
|
57 |
def load_model(self):
|
58 |
"""Load Faster Whisper model"""
|
59 |
logger.info("Loading Faster Whisper model")
|
60 |
logger.info(f"Using device: {self.device}")
|
61 |
logger.info(f"Using compute type: {self.compute_type}")
|
62 |
-
|
63 |
# Use large-v3 model with appropriate compute type based on device
|
64 |
self.model = FasterWhisperModel(
|
65 |
"large-v3",
|
@@ -67,14 +67,14 @@ class WhisperModel(ASRModel):
|
|
67 |
compute_type=self.compute_type
|
68 |
)
|
69 |
logger.info("Faster Whisper model loaded successfully")
|
70 |
-
|
71 |
def transcribe(self, audio_path):
|
72 |
"""Transcribe audio using Faster Whisper"""
|
73 |
if self.model is None:
|
74 |
self.load_model()
|
75 |
-
|
76 |
wav_path = self.preprocess_audio(audio_path)
|
77 |
-
|
78 |
# Transcription with Faster Whisper
|
79 |
logger.info("Generating transcription with Faster Whisper")
|
80 |
segments, info = self.model.transcribe(
|
@@ -83,15 +83,15 @@ class WhisperModel(ASRModel):
|
|
83 |
language="en",
|
84 |
task="transcribe"
|
85 |
)
|
86 |
-
|
87 |
logger.info(f"Detected language '{info.language}' with probability {info.language_probability}")
|
88 |
-
|
89 |
# Collect all segments into a single text
|
90 |
result_text = ""
|
91 |
for segment in segments:
|
92 |
result_text += segment.text + " "
|
93 |
-
logger.
|
94 |
-
|
95 |
result = result_text.strip()
|
96 |
logger.info(f"Transcription completed successfully")
|
97 |
return result
|
@@ -99,10 +99,10 @@ class WhisperModel(ASRModel):
|
|
99 |
|
100 |
class ParakeetModel(ASRModel):
|
101 |
"""Parakeet ASR model implementation"""
|
102 |
-
|
103 |
def __init__(self):
|
104 |
self.model = None
|
105 |
-
|
106 |
def load_model(self):
|
107 |
"""Load Parakeet model"""
|
108 |
try:
|
@@ -113,14 +113,14 @@ class ParakeetModel(ASRModel):
|
|
113 |
except ImportError:
|
114 |
logger.error("Failed to import nemo_toolkit. Please install with: pip install -U 'nemo_toolkit[asr]'")
|
115 |
raise
|
116 |
-
|
117 |
def transcribe(self, audio_path):
|
118 |
"""Transcribe audio using Parakeet"""
|
119 |
if self.model is None:
|
120 |
self.load_model()
|
121 |
-
|
122 |
wav_path = self.preprocess_audio(audio_path)
|
123 |
-
|
124 |
# Transcription
|
125 |
logger.info("Generating transcription with Parakeet")
|
126 |
output = self.model.transcribe([wav_path])
|
@@ -131,7 +131,7 @@ class ParakeetModel(ASRModel):
|
|
131 |
|
132 |
class ASRFactory:
|
133 |
"""Factory for creating ASR model instances"""
|
134 |
-
|
135 |
@staticmethod
|
136 |
def get_model(model_name="parakeet"):
|
137 |
"""
|
@@ -160,11 +160,11 @@ def transcribe_audio(audio_path, model_name="parakeet"):
|
|
160 |
Transcribed English text
|
161 |
"""
|
162 |
logger.info(f"Starting transcription for: {audio_path} using {model_name} model")
|
163 |
-
|
164 |
try:
|
165 |
# Get the appropriate model
|
166 |
asr_model = ASRFactory.get_model(model_name)
|
167 |
-
|
168 |
# Transcribe audio
|
169 |
result = asr_model.transcribe(audio_path)
|
170 |
logger.info(f"transcription: %s" % result)
|
|
|
16 |
|
17 |
class ASRModel(ABC):
|
18 |
"""Base class for ASR models"""
|
19 |
+
|
20 |
@abstractmethod
|
21 |
def load_model(self):
|
22 |
"""Load the ASR model"""
|
23 |
pass
|
24 |
+
|
25 |
@abstractmethod
|
26 |
def transcribe(self, audio_path):
|
27 |
"""Transcribe audio to text"""
|
28 |
pass
|
29 |
+
|
30 |
def preprocess_audio(self, audio_path):
|
31 |
"""Convert audio to required format"""
|
32 |
logger.info("Converting audio format")
|
|
|
42 |
|
43 |
class WhisperModel(ASRModel):
|
44 |
"""Faster Whisper ASR model implementation"""
|
45 |
+
|
46 |
def __init__(self):
|
47 |
self.model = None
|
48 |
# Check for CUDA availability without torch dependency
|
|
|
53 |
# Fallback to CPU if torch is not available
|
54 |
self.device = "cpu"
|
55 |
self.compute_type = "float16" if self.device == "cuda" else "int8"
|
56 |
+
|
57 |
def load_model(self):
|
58 |
"""Load Faster Whisper model"""
|
59 |
logger.info("Loading Faster Whisper model")
|
60 |
logger.info(f"Using device: {self.device}")
|
61 |
logger.info(f"Using compute type: {self.compute_type}")
|
62 |
+
|
63 |
# Use large-v3 model with appropriate compute type based on device
|
64 |
self.model = FasterWhisperModel(
|
65 |
"large-v3",
|
|
|
67 |
compute_type=self.compute_type
|
68 |
)
|
69 |
logger.info("Faster Whisper model loaded successfully")
|
70 |
+
|
71 |
def transcribe(self, audio_path):
|
72 |
"""Transcribe audio using Faster Whisper"""
|
73 |
if self.model is None:
|
74 |
self.load_model()
|
75 |
+
|
76 |
wav_path = self.preprocess_audio(audio_path)
|
77 |
+
|
78 |
# Transcription with Faster Whisper
|
79 |
logger.info("Generating transcription with Faster Whisper")
|
80 |
segments, info = self.model.transcribe(
|
|
|
83 |
language="en",
|
84 |
task="transcribe"
|
85 |
)
|
86 |
+
|
87 |
logger.info(f"Detected language '{info.language}' with probability {info.language_probability}")
|
88 |
+
|
89 |
# Collect all segments into a single text
|
90 |
result_text = ""
|
91 |
for segment in segments:
|
92 |
result_text += segment.text + " "
|
93 |
+
logger.info(f"[{segment.start:.2f}s -> {segment.end:.2f}s] {segment.text}")
|
94 |
+
|
95 |
result = result_text.strip()
|
96 |
logger.info(f"Transcription completed successfully")
|
97 |
return result
|
|
|
99 |
|
100 |
class ParakeetModel(ASRModel):
|
101 |
"""Parakeet ASR model implementation"""
|
102 |
+
|
103 |
def __init__(self):
|
104 |
self.model = None
|
105 |
+
|
106 |
def load_model(self):
|
107 |
"""Load Parakeet model"""
|
108 |
try:
|
|
|
113 |
except ImportError:
|
114 |
logger.error("Failed to import nemo_toolkit. Please install with: pip install -U 'nemo_toolkit[asr]'")
|
115 |
raise
|
116 |
+
|
117 |
def transcribe(self, audio_path):
|
118 |
"""Transcribe audio using Parakeet"""
|
119 |
if self.model is None:
|
120 |
self.load_model()
|
121 |
+
|
122 |
wav_path = self.preprocess_audio(audio_path)
|
123 |
+
|
124 |
# Transcription
|
125 |
logger.info("Generating transcription with Parakeet")
|
126 |
output = self.model.transcribe([wav_path])
|
|
|
131 |
|
132 |
class ASRFactory:
|
133 |
"""Factory for creating ASR model instances"""
|
134 |
+
|
135 |
@staticmethod
|
136 |
def get_model(model_name="parakeet"):
|
137 |
"""
|
|
|
160 |
Transcribed English text
|
161 |
"""
|
162 |
logger.info(f"Starting transcription for: {audio_path} using {model_name} model")
|
163 |
+
|
164 |
try:
|
165 |
# Get the appropriate model
|
166 |
asr_model = ASRFactory.get_model(model_name)
|
167 |
+
|
168 |
# Transcribe audio
|
169 |
result = asr_model.transcribe(audio_path)
|
170 |
logger.info(f"transcription: %s" % result)
|
utils/translation.py
CHANGED
@@ -17,7 +17,7 @@ def translate_text(text):
|
|
17 |
Translated Chinese text
|
18 |
"""
|
19 |
logger.info(f"Starting translation for text length: {len(text)}")
|
20 |
-
|
21 |
try:
|
22 |
# Model initialization with explicit language codes
|
23 |
logger.info("Loading NLLB model")
|
@@ -36,7 +36,7 @@ def translate_text(text):
|
|
36 |
translated_chunks = []
|
37 |
for i, chunk in enumerate(text_chunks):
|
38 |
logger.info(f"Processing chunk {i+1}/{len(text_chunks)}")
|
39 |
-
|
40 |
# Tokenize with source language specification
|
41 |
inputs = tokenizer(
|
42 |
chunk,
|
@@ -44,14 +44,14 @@ def translate_text(text):
|
|
44 |
max_length=1024,
|
45 |
truncation=True
|
46 |
)
|
47 |
-
|
48 |
# Generate translation with target language specification
|
49 |
outputs = model.generate(
|
50 |
**inputs,
|
51 |
forced_bos_token_id=tokenizer.convert_tokens_to_ids("zho_Hans"),
|
52 |
max_new_tokens=1024
|
53 |
)
|
54 |
-
|
55 |
translated = tokenizer.decode(outputs[0], skip_special_tokens=True)
|
56 |
translated_chunks.append(translated)
|
57 |
logger.info(f"Chunk {i+1} translated successfully")
|
|
|
17 |
Translated Chinese text
|
18 |
"""
|
19 |
logger.info(f"Starting translation for text length: {len(text)}")
|
20 |
+
|
21 |
try:
|
22 |
# Model initialization with explicit language codes
|
23 |
logger.info("Loading NLLB model")
|
|
|
36 |
translated_chunks = []
|
37 |
for i, chunk in enumerate(text_chunks):
|
38 |
logger.info(f"Processing chunk {i+1}/{len(text_chunks)}")
|
39 |
+
|
40 |
# Tokenize with source language specification
|
41 |
inputs = tokenizer(
|
42 |
chunk,
|
|
|
44 |
max_length=1024,
|
45 |
truncation=True
|
46 |
)
|
47 |
+
|
48 |
# Generate translation with target language specification
|
49 |
outputs = model.generate(
|
50 |
**inputs,
|
51 |
forced_bos_token_id=tokenizer.convert_tokens_to_ids("zho_Hans"),
|
52 |
max_new_tokens=1024
|
53 |
)
|
54 |
+
|
55 |
translated = tokenizer.decode(outputs[0], skip_special_tokens=True)
|
56 |
translated_chunks.append(translated)
|
57 |
logger.info(f"Chunk {i+1} translated successfully")
|
utils/tts.py
CHANGED
@@ -17,42 +17,42 @@ logger = logging.getLogger(__name__)
|
|
17 |
|
18 |
def get_available_engines() -> List[str]:
|
19 |
"""Get a list of available TTS engines
|
20 |
-
|
21 |
Returns:
|
22 |
List[str]: List of available engine names
|
23 |
"""
|
24 |
available = []
|
25 |
-
|
26 |
if KOKORO_AVAILABLE:
|
27 |
available.append('kokoro')
|
28 |
-
|
29 |
if DIA_AVAILABLE:
|
30 |
available.append('dia')
|
31 |
-
|
32 |
if COSYVOICE2_AVAILABLE:
|
33 |
available.append('cosyvoice2')
|
34 |
-
|
35 |
# Dummy is always available
|
36 |
available.append('dummy')
|
37 |
-
|
38 |
return available
|
39 |
|
40 |
|
41 |
def get_tts_engine(engine_type: Optional[str] = None, lang_code: str = 'z') -> TTSBase:
|
42 |
"""Get a TTS engine instance
|
43 |
-
|
44 |
Args:
|
45 |
engine_type (str, optional): Type of engine to create ('kokoro', 'dia', 'cosyvoice2', 'dummy')
|
46 |
If None, the best available engine will be used
|
47 |
lang_code (str): Language code for the engine
|
48 |
-
|
49 |
Returns:
|
50 |
TTSBase: An instance of a TTS engine
|
51 |
"""
|
52 |
# Get available engines
|
53 |
available_engines = get_available_engines()
|
54 |
logger.info(f"Available TTS engines: {available_engines}")
|
55 |
-
|
56 |
# If engine_type is specified, try to create that specific engine
|
57 |
if engine_type is not None:
|
58 |
if engine_type == 'kokoro' and KOKORO_AVAILABLE:
|
@@ -69,7 +69,7 @@ def get_tts_engine(engine_type: Optional[str] = None, lang_code: str = 'z') -> T
|
|
69 |
return DummyTTS(lang_code)
|
70 |
else:
|
71 |
logger.warning(f"Requested engine '{engine_type}' is not available")
|
72 |
-
|
73 |
# If no specific engine is requested or the requested engine is not available,
|
74 |
# use the best available engine based on priority
|
75 |
priority_order = ['cosyvoice2', 'kokoro', 'dia', 'dummy']
|
@@ -84,23 +84,23 @@ def get_tts_engine(engine_type: Optional[str] = None, lang_code: str = 'z') -> T
|
|
84 |
return CosyVoice2TTS(lang_code)
|
85 |
elif engine == 'dummy':
|
86 |
return DummyTTS(lang_code)
|
87 |
-
|
88 |
# Fallback to dummy engine if no engines are available
|
89 |
logger.warning("No TTS engines available, falling back to dummy engine")
|
90 |
return DummyTTS(lang_code)
|
91 |
|
92 |
|
93 |
-
def generate_speech(text: str, engine_type: Optional[str] = None, lang_code: str = 'z',
|
94 |
voice: str = 'default', speed: float = 1.0) -> Optional[str]:
|
95 |
"""Generate speech using the specified or best available TTS engine
|
96 |
-
|
97 |
Args:
|
98 |
text (str): Input text to synthesize
|
99 |
engine_type (str, optional): Type of engine to use
|
100 |
lang_code (str): Language code
|
101 |
voice (str): Voice ID to use
|
102 |
speed (float): Speech speed multiplier
|
103 |
-
|
104 |
Returns:
|
105 |
Optional[str]: Path to the generated audio file or None if generation fails
|
106 |
"""
|
@@ -111,14 +111,14 @@ def generate_speech(text: str, engine_type: Optional[str] = None, lang_code: str
|
|
111 |
def generate_speech_stream(text: str, engine_type: Optional[str] = None, lang_code: str = 'z',
|
112 |
voice: str = 'default', speed: float = 1.0) -> Generator[Tuple[int, np.ndarray], None, None]:
|
113 |
"""Generate speech stream using the specified or best available TTS engine
|
114 |
-
|
115 |
Args:
|
116 |
text (str): Input text to synthesize
|
117 |
engine_type (str, optional): Type of engine to use
|
118 |
lang_code (str): Language code
|
119 |
voice (str): Voice ID to use
|
120 |
speed (float): Speech speed multiplier
|
121 |
-
|
122 |
Yields:
|
123 |
tuple: (sample_rate, audio_data) pairs for each segment
|
124 |
"""
|
|
|
17 |
|
18 |
def get_available_engines() -> List[str]:
|
19 |
"""Get a list of available TTS engines
|
20 |
+
|
21 |
Returns:
|
22 |
List[str]: List of available engine names
|
23 |
"""
|
24 |
available = []
|
25 |
+
|
26 |
if KOKORO_AVAILABLE:
|
27 |
available.append('kokoro')
|
28 |
+
|
29 |
if DIA_AVAILABLE:
|
30 |
available.append('dia')
|
31 |
+
|
32 |
if COSYVOICE2_AVAILABLE:
|
33 |
available.append('cosyvoice2')
|
34 |
+
|
35 |
# Dummy is always available
|
36 |
available.append('dummy')
|
37 |
+
|
38 |
return available
|
39 |
|
40 |
|
41 |
def get_tts_engine(engine_type: Optional[str] = None, lang_code: str = 'z') -> TTSBase:
|
42 |
"""Get a TTS engine instance
|
43 |
+
|
44 |
Args:
|
45 |
engine_type (str, optional): Type of engine to create ('kokoro', 'dia', 'cosyvoice2', 'dummy')
|
46 |
If None, the best available engine will be used
|
47 |
lang_code (str): Language code for the engine
|
48 |
+
|
49 |
Returns:
|
50 |
TTSBase: An instance of a TTS engine
|
51 |
"""
|
52 |
# Get available engines
|
53 |
available_engines = get_available_engines()
|
54 |
logger.info(f"Available TTS engines: {available_engines}")
|
55 |
+
|
56 |
# If engine_type is specified, try to create that specific engine
|
57 |
if engine_type is not None:
|
58 |
if engine_type == 'kokoro' and KOKORO_AVAILABLE:
|
|
|
69 |
return DummyTTS(lang_code)
|
70 |
else:
|
71 |
logger.warning(f"Requested engine '{engine_type}' is not available")
|
72 |
+
|
73 |
# If no specific engine is requested or the requested engine is not available,
|
74 |
# use the best available engine based on priority
|
75 |
priority_order = ['cosyvoice2', 'kokoro', 'dia', 'dummy']
|
|
|
84 |
return CosyVoice2TTS(lang_code)
|
85 |
elif engine == 'dummy':
|
86 |
return DummyTTS(lang_code)
|
87 |
+
|
88 |
# Fallback to dummy engine if no engines are available
|
89 |
logger.warning("No TTS engines available, falling back to dummy engine")
|
90 |
return DummyTTS(lang_code)
|
91 |
|
92 |
|
93 |
+
def generate_speech(text: str, engine_type: Optional[str] = None, lang_code: str = 'z',
|
94 |
voice: str = 'default', speed: float = 1.0) -> Optional[str]:
|
95 |
"""Generate speech using the specified or best available TTS engine
|
96 |
+
|
97 |
Args:
|
98 |
text (str): Input text to synthesize
|
99 |
engine_type (str, optional): Type of engine to use
|
100 |
lang_code (str): Language code
|
101 |
voice (str): Voice ID to use
|
102 |
speed (float): Speech speed multiplier
|
103 |
+
|
104 |
Returns:
|
105 |
Optional[str]: Path to the generated audio file or None if generation fails
|
106 |
"""
|
|
|
111 |
def generate_speech_stream(text: str, engine_type: Optional[str] = None, lang_code: str = 'z',
|
112 |
voice: str = 'default', speed: float = 1.0) -> Generator[Tuple[int, np.ndarray], None, None]:
|
113 |
"""Generate speech stream using the specified or best available TTS engine
|
114 |
+
|
115 |
Args:
|
116 |
text (str): Input text to synthesize
|
117 |
engine_type (str, optional): Type of engine to use
|
118 |
lang_code (str): Language code
|
119 |
voice (str): Voice ID to use
|
120 |
speed (float): Speech speed multiplier
|
121 |
+
|
122 |
Yields:
|
123 |
tuple: (sample_rate, audio_data) pairs for each segment
|
124 |
"""
|
utils/tts_dia.py
CHANGED
@@ -30,18 +30,18 @@ except ModuleNotFoundError as e:
|
|
30 |
|
31 |
def _get_model():
|
32 |
"""Lazy-load the Dia model
|
33 |
-
|
34 |
Returns:
|
35 |
Dia or None: The Dia model or None if not available
|
36 |
"""
|
37 |
if not DIA_AVAILABLE:
|
38 |
logger.warning("Dia TTS engine is not available")
|
39 |
return None
|
40 |
-
|
41 |
try:
|
42 |
import torch
|
43 |
from dia.model import Dia
|
44 |
-
|
45 |
# Initialize the model
|
46 |
model = Dia.from_pretrained()
|
47 |
logger.info("Dia model successfully loaded")
|
@@ -59,59 +59,59 @@ def _get_model():
|
|
59 |
|
60 |
class DiaTTS(TTSBase):
|
61 |
"""Dia TTS engine implementation
|
62 |
-
|
63 |
This engine uses the Dia model for TTS generation.
|
64 |
"""
|
65 |
-
|
66 |
def __init__(self, lang_code: str = 'z'):
|
67 |
"""Initialize the Dia TTS engine
|
68 |
-
|
69 |
Args:
|
70 |
lang_code (str): Language code for the engine
|
71 |
"""
|
72 |
super().__init__(lang_code)
|
73 |
self.model = None
|
74 |
-
|
75 |
def _ensure_model(self):
|
76 |
"""Ensure the model is loaded
|
77 |
-
|
78 |
Returns:
|
79 |
bool: True if model is available, False otherwise
|
80 |
"""
|
81 |
if self.model is None:
|
82 |
self.model = _get_model()
|
83 |
-
|
84 |
return self.model is not None
|
85 |
-
|
86 |
def generate_speech(self, text: str, voice: str = 'default', speed: float = 1.0) -> Optional[str]:
|
87 |
"""Generate speech using Dia TTS engine
|
88 |
-
|
89 |
Args:
|
90 |
text (str): Input text to synthesize
|
91 |
voice (str): Voice ID (not used in Dia)
|
92 |
speed (float): Speech speed multiplier (not used in Dia)
|
93 |
-
|
94 |
Returns:
|
95 |
Optional[str]: Path to the generated audio file or None if generation fails
|
96 |
"""
|
97 |
logger.info(f"Generating speech with Dia for text length: {len(text)}")
|
98 |
-
|
99 |
# Check if Dia is available
|
100 |
if not DIA_AVAILABLE:
|
101 |
logger.error("Dia TTS engine is not available")
|
102 |
return None
|
103 |
-
|
104 |
# Ensure model is loaded
|
105 |
if not self._ensure_model():
|
106 |
logger.error("Failed to load Dia model")
|
107 |
return None
|
108 |
-
|
109 |
try:
|
110 |
import torch
|
111 |
-
|
112 |
# Generate unique output path
|
113 |
output_path = self._generate_output_path(prefix="dia")
|
114 |
-
|
115 |
# Generate audio
|
116 |
with torch.inference_mode():
|
117 |
output_audio_np = self.model.generate(
|
@@ -124,7 +124,7 @@ class DiaTTS(TTSBase):
|
|
124 |
use_torch_compile=False,
|
125 |
verbose=False
|
126 |
)
|
127 |
-
|
128 |
if output_audio_np is not None:
|
129 |
logger.info(f"Successfully generated audio with Dia (length: {len(output_audio_np)})")
|
130 |
sf.write(output_path, output_audio_np, DEFAULT_SAMPLE_RATE)
|
@@ -133,7 +133,7 @@ class DiaTTS(TTSBase):
|
|
133 |
else:
|
134 |
logger.error("Dia model returned None for audio output")
|
135 |
return None
|
136 |
-
|
137 |
except ModuleNotFoundError as e:
|
138 |
if "dac" in str(e):
|
139 |
logger.error("Dia TTS engine failed due to missing 'dac' module")
|
@@ -143,33 +143,33 @@ class DiaTTS(TTSBase):
|
|
143 |
except Exception as e:
|
144 |
logger.error(f"Error generating speech with Dia: {str(e)}", exc_info=True)
|
145 |
return None
|
146 |
-
|
147 |
def generate_speech_stream(self, text: str, voice: str = 'default', speed: float = 1.0) -> Generator[Tuple[int, np.ndarray], None, None]:
|
148 |
"""Generate speech stream using Dia TTS engine
|
149 |
-
|
150 |
Args:
|
151 |
text (str): Input text to synthesize
|
152 |
voice (str): Voice ID (not used in Dia)
|
153 |
speed (float): Speech speed multiplier (not used in Dia)
|
154 |
-
|
155 |
Yields:
|
156 |
tuple: (sample_rate, audio_data) pairs for each segment
|
157 |
"""
|
158 |
logger.info(f"Generating speech stream with Dia for text length: {len(text)}")
|
159 |
-
|
160 |
# Check if Dia is available
|
161 |
if not DIA_AVAILABLE:
|
162 |
logger.error("Dia TTS engine is not available")
|
163 |
return
|
164 |
-
|
165 |
# Ensure model is loaded
|
166 |
if not self._ensure_model():
|
167 |
logger.error("Failed to load Dia model")
|
168 |
return
|
169 |
-
|
170 |
try:
|
171 |
import torch
|
172 |
-
|
173 |
# Generate audio
|
174 |
with torch.inference_mode():
|
175 |
output_audio_np = self.model.generate(
|
@@ -182,14 +182,14 @@ class DiaTTS(TTSBase):
|
|
182 |
use_torch_compile=False,
|
183 |
verbose=False
|
184 |
)
|
185 |
-
|
186 |
if output_audio_np is not None:
|
187 |
logger.info(f"Successfully generated audio with Dia (length: {len(output_audio_np)})")
|
188 |
yield DEFAULT_SAMPLE_RATE, output_audio_np
|
189 |
else:
|
190 |
logger.error("Dia model returned None for audio output")
|
191 |
return
|
192 |
-
|
193 |
except ModuleNotFoundError as e:
|
194 |
if "dac" in str(e):
|
195 |
logger.error("Dia TTS engine failed due to missing 'dac' module")
|
|
|
30 |
|
31 |
def _get_model():
|
32 |
"""Lazy-load the Dia model
|
33 |
+
|
34 |
Returns:
|
35 |
Dia or None: The Dia model or None if not available
|
36 |
"""
|
37 |
if not DIA_AVAILABLE:
|
38 |
logger.warning("Dia TTS engine is not available")
|
39 |
return None
|
40 |
+
|
41 |
try:
|
42 |
import torch
|
43 |
from dia.model import Dia
|
44 |
+
|
45 |
# Initialize the model
|
46 |
model = Dia.from_pretrained()
|
47 |
logger.info("Dia model successfully loaded")
|
|
|
59 |
|
60 |
class DiaTTS(TTSBase):
|
61 |
"""Dia TTS engine implementation
|
62 |
+
|
63 |
This engine uses the Dia model for TTS generation.
|
64 |
"""
|
65 |
+
|
66 |
def __init__(self, lang_code: str = 'z'):
|
67 |
"""Initialize the Dia TTS engine
|
68 |
+
|
69 |
Args:
|
70 |
lang_code (str): Language code for the engine
|
71 |
"""
|
72 |
super().__init__(lang_code)
|
73 |
self.model = None
|
74 |
+
|
75 |
def _ensure_model(self):
|
76 |
"""Ensure the model is loaded
|
77 |
+
|
78 |
Returns:
|
79 |
bool: True if model is available, False otherwise
|
80 |
"""
|
81 |
if self.model is None:
|
82 |
self.model = _get_model()
|
83 |
+
|
84 |
return self.model is not None
|
85 |
+
|
86 |
def generate_speech(self, text: str, voice: str = 'default', speed: float = 1.0) -> Optional[str]:
|
87 |
"""Generate speech using Dia TTS engine
|
88 |
+
|
89 |
Args:
|
90 |
text (str): Input text to synthesize
|
91 |
voice (str): Voice ID (not used in Dia)
|
92 |
speed (float): Speech speed multiplier (not used in Dia)
|
93 |
+
|
94 |
Returns:
|
95 |
Optional[str]: Path to the generated audio file or None if generation fails
|
96 |
"""
|
97 |
logger.info(f"Generating speech with Dia for text length: {len(text)}")
|
98 |
+
|
99 |
# Check if Dia is available
|
100 |
if not DIA_AVAILABLE:
|
101 |
logger.error("Dia TTS engine is not available")
|
102 |
return None
|
103 |
+
|
104 |
# Ensure model is loaded
|
105 |
if not self._ensure_model():
|
106 |
logger.error("Failed to load Dia model")
|
107 |
return None
|
108 |
+
|
109 |
try:
|
110 |
import torch
|
111 |
+
|
112 |
# Generate unique output path
|
113 |
output_path = self._generate_output_path(prefix="dia")
|
114 |
+
|
115 |
# Generate audio
|
116 |
with torch.inference_mode():
|
117 |
output_audio_np = self.model.generate(
|
|
|
124 |
use_torch_compile=False,
|
125 |
verbose=False
|
126 |
)
|
127 |
+
|
128 |
if output_audio_np is not None:
|
129 |
logger.info(f"Successfully generated audio with Dia (length: {len(output_audio_np)})")
|
130 |
sf.write(output_path, output_audio_np, DEFAULT_SAMPLE_RATE)
|
|
|
133 |
else:
|
134 |
logger.error("Dia model returned None for audio output")
|
135 |
return None
|
136 |
+
|
137 |
except ModuleNotFoundError as e:
|
138 |
if "dac" in str(e):
|
139 |
logger.error("Dia TTS engine failed due to missing 'dac' module")
|
|
|
143 |
except Exception as e:
|
144 |
logger.error(f"Error generating speech with Dia: {str(e)}", exc_info=True)
|
145 |
return None
|
146 |
+
|
147 |
def generate_speech_stream(self, text: str, voice: str = 'default', speed: float = 1.0) -> Generator[Tuple[int, np.ndarray], None, None]:
|
148 |
"""Generate speech stream using Dia TTS engine
|
149 |
+
|
150 |
Args:
|
151 |
text (str): Input text to synthesize
|
152 |
voice (str): Voice ID (not used in Dia)
|
153 |
speed (float): Speech speed multiplier (not used in Dia)
|
154 |
+
|
155 |
Yields:
|
156 |
tuple: (sample_rate, audio_data) pairs for each segment
|
157 |
"""
|
158 |
logger.info(f"Generating speech stream with Dia for text length: {len(text)}")
|
159 |
+
|
160 |
# Check if Dia is available
|
161 |
if not DIA_AVAILABLE:
|
162 |
logger.error("Dia TTS engine is not available")
|
163 |
return
|
164 |
+
|
165 |
# Ensure model is loaded
|
166 |
if not self._ensure_model():
|
167 |
logger.error("Failed to load Dia model")
|
168 |
return
|
169 |
+
|
170 |
try:
|
171 |
import torch
|
172 |
+
|
173 |
# Generate audio
|
174 |
with torch.inference_mode():
|
175 |
output_audio_np = self.model.generate(
|
|
|
182 |
use_torch_compile=False,
|
183 |
verbose=False
|
184 |
)
|
185 |
+
|
186 |
if output_audio_np is not None:
|
187 |
logger.info(f"Successfully generated audio with Dia (length: {len(output_audio_np)})")
|
188 |
yield DEFAULT_SAMPLE_RATE, output_audio_np
|
189 |
else:
|
190 |
logger.error("Dia model returned None for audio output")
|
191 |
return
|
192 |
+
|
193 |
except ModuleNotFoundError as e:
|
194 |
if "dac" in str(e):
|
195 |
logger.error("Dia TTS engine failed due to missing 'dac' module")
|
utils/tts_dummy.py
CHANGED
@@ -12,54 +12,54 @@ logger = logging.getLogger(__name__)
|
|
12 |
|
13 |
class DummyTTS(TTSBase):
|
14 |
"""Dummy TTS engine that generates sine wave audio
|
15 |
-
|
16 |
This class is used as a fallback when no other TTS engine is available.
|
17 |
"""
|
18 |
-
|
19 |
def generate_speech(self, text: str, voice: str = 'default', speed: float = 1.0) -> str:
|
20 |
"""Generate a dummy sine wave audio file
|
21 |
-
|
22 |
Args:
|
23 |
text (str): Input text (not used)
|
24 |
voice (str): Voice ID (not used)
|
25 |
speed (float): Speech speed multiplier (not used)
|
26 |
-
|
27 |
Returns:
|
28 |
str: Path to the generated audio file
|
29 |
"""
|
30 |
logger.info(f"Generating dummy speech for text length: {len(text)}")
|
31 |
-
|
32 |
# Generate a simple sine wave
|
33 |
sample_rate = 24000
|
34 |
duration = min(len(text) / 20, 10) # Rough approximation of speech duration
|
35 |
t = np.linspace(0, duration, int(sample_rate * duration), endpoint=False)
|
36 |
audio = 0.5 * np.sin(2 * np.pi * 440 * t) # 440 Hz sine wave
|
37 |
-
|
38 |
# Save to file
|
39 |
output_path = self._generate_output_path(prefix="dummy")
|
40 |
sf.write(output_path, audio, sample_rate)
|
41 |
-
|
42 |
logger.info(f"Generated dummy audio: {output_path}")
|
43 |
return output_path
|
44 |
-
|
45 |
def generate_speech_stream(self, text: str, voice: str = 'default', speed: float = 1.0) -> Generator[Tuple[int, np.ndarray], None, None]:
|
46 |
"""Generate a dummy sine wave audio stream
|
47 |
-
|
48 |
Args:
|
49 |
text (str): Input text (not used)
|
50 |
voice (str): Voice ID (not used)
|
51 |
speed (float): Speech speed multiplier (not used)
|
52 |
-
|
53 |
Yields:
|
54 |
tuple: (sample_rate, audio_data) pairs
|
55 |
"""
|
56 |
logger.info(f"Generating dummy speech stream for text length: {len(text)}")
|
57 |
-
|
58 |
# Generate a simple sine wave
|
59 |
sample_rate = 24000
|
60 |
duration = min(len(text) / 20, 10) # Rough approximation of speech duration
|
61 |
t = np.linspace(0, duration, int(sample_rate * duration), endpoint=False)
|
62 |
audio = 0.5 * np.sin(2 * np.pi * 440 * t) # 440 Hz sine wave
|
63 |
-
|
64 |
# Yield the audio data
|
65 |
yield sample_rate, audio
|
|
|
12 |
|
13 |
class DummyTTS(TTSBase):
|
14 |
"""Dummy TTS engine that generates sine wave audio
|
15 |
+
|
16 |
This class is used as a fallback when no other TTS engine is available.
|
17 |
"""
|
18 |
+
|
19 |
def generate_speech(self, text: str, voice: str = 'default', speed: float = 1.0) -> str:
|
20 |
"""Generate a dummy sine wave audio file
|
21 |
+
|
22 |
Args:
|
23 |
text (str): Input text (not used)
|
24 |
voice (str): Voice ID (not used)
|
25 |
speed (float): Speech speed multiplier (not used)
|
26 |
+
|
27 |
Returns:
|
28 |
str: Path to the generated audio file
|
29 |
"""
|
30 |
logger.info(f"Generating dummy speech for text length: {len(text)}")
|
31 |
+
|
32 |
# Generate a simple sine wave
|
33 |
sample_rate = 24000
|
34 |
duration = min(len(text) / 20, 10) # Rough approximation of speech duration
|
35 |
t = np.linspace(0, duration, int(sample_rate * duration), endpoint=False)
|
36 |
audio = 0.5 * np.sin(2 * np.pi * 440 * t) # 440 Hz sine wave
|
37 |
+
|
38 |
# Save to file
|
39 |
output_path = self._generate_output_path(prefix="dummy")
|
40 |
sf.write(output_path, audio, sample_rate)
|
41 |
+
|
42 |
logger.info(f"Generated dummy audio: {output_path}")
|
43 |
return output_path
|
44 |
+
|
45 |
def generate_speech_stream(self, text: str, voice: str = 'default', speed: float = 1.0) -> Generator[Tuple[int, np.ndarray], None, None]:
|
46 |
"""Generate a dummy sine wave audio stream
|
47 |
+
|
48 |
Args:
|
49 |
text (str): Input text (not used)
|
50 |
voice (str): Voice ID (not used)
|
51 |
speed (float): Speech speed multiplier (not used)
|
52 |
+
|
53 |
Yields:
|
54 |
tuple: (sample_rate, audio_data) pairs
|
55 |
"""
|
56 |
logger.info(f"Generating dummy speech stream for text length: {len(text)}")
|
57 |
+
|
58 |
# Generate a simple sine wave
|
59 |
sample_rate = 24000
|
60 |
duration = min(len(text) / 20, 10) # Rough approximation of speech duration
|
61 |
t = np.linspace(0, duration, int(sample_rate * duration), endpoint=False)
|
62 |
audio = 0.5 * np.sin(2 * np.pi * 440 * t) # 440 Hz sine wave
|
63 |
+
|
64 |
# Yield the audio data
|
65 |
yield sample_rate, audio
|
utils/tts_kokoro.py
CHANGED
@@ -25,17 +25,17 @@ except Exception as e:
|
|
25 |
|
26 |
def _get_pipeline(lang_code: str = 'z'):
|
27 |
"""Lazy-load the Kokoro pipeline
|
28 |
-
|
29 |
Args:
|
30 |
lang_code (str): Language code for the pipeline
|
31 |
-
|
32 |
Returns:
|
33 |
KPipeline or None: The Kokoro pipeline or None if not available
|
34 |
"""
|
35 |
if not KOKORO_AVAILABLE:
|
36 |
logger.warning("Kokoro TTS engine is not available")
|
37 |
return None
|
38 |
-
|
39 |
try:
|
40 |
pipeline = KPipeline(lang_code=lang_code)
|
41 |
logger.info("Kokoro pipeline successfully loaded")
|
@@ -47,93 +47,93 @@ def _get_pipeline(lang_code: str = 'z'):
|
|
47 |
|
48 |
class KokoroTTS(TTSBase):
|
49 |
"""Kokoro TTS engine implementation
|
50 |
-
|
51 |
This engine uses the Kokoro library for TTS generation.
|
52 |
"""
|
53 |
-
|
54 |
def __init__(self, lang_code: str = 'z'):
|
55 |
"""Initialize the Kokoro TTS engine
|
56 |
-
|
57 |
Args:
|
58 |
lang_code (str): Language code for the engine
|
59 |
"""
|
60 |
super().__init__(lang_code)
|
61 |
self.pipeline = None
|
62 |
-
|
63 |
def _ensure_pipeline(self):
|
64 |
"""Ensure the pipeline is loaded
|
65 |
-
|
66 |
Returns:
|
67 |
bool: True if pipeline is available, False otherwise
|
68 |
"""
|
69 |
if self.pipeline is None:
|
70 |
self.pipeline = _get_pipeline(self.lang_code)
|
71 |
-
|
72 |
return self.pipeline is not None
|
73 |
-
|
74 |
def generate_speech(self, text: str, voice: str = 'af_heart', speed: float = 1.0) -> Optional[str]:
|
75 |
"""Generate speech using Kokoro TTS engine
|
76 |
-
|
77 |
Args:
|
78 |
text (str): Input text to synthesize
|
79 |
voice (str): Voice ID to use (e.g., 'af_heart', 'af_bella', etc.)
|
80 |
speed (float): Speech speed multiplier (0.5 to 2.0)
|
81 |
-
|
82 |
Returns:
|
83 |
Optional[str]: Path to the generated audio file or None if generation fails
|
84 |
"""
|
85 |
logger.info(f"Generating speech with Kokoro for text length: {len(text)}")
|
86 |
-
|
87 |
# Check if Kokoro is available
|
88 |
if not KOKORO_AVAILABLE:
|
89 |
logger.error("Kokoro TTS engine is not available")
|
90 |
return None
|
91 |
-
|
92 |
# Ensure pipeline is loaded
|
93 |
if not self._ensure_pipeline():
|
94 |
logger.error("Failed to load Kokoro pipeline")
|
95 |
return None
|
96 |
-
|
97 |
try:
|
98 |
# Generate unique output path
|
99 |
output_path = self._generate_output_path(prefix="kokoro")
|
100 |
-
|
101 |
# Generate speech
|
102 |
generator = self.pipeline(text, voice=voice, speed=speed)
|
103 |
for _, _, audio in generator:
|
104 |
logger.info(f"Saving Kokoro audio to {output_path}")
|
105 |
sf.write(output_path, audio, 24000)
|
106 |
break
|
107 |
-
|
108 |
logger.info(f"Kokoro audio generation complete: {output_path}")
|
109 |
return output_path
|
110 |
except Exception as e:
|
111 |
logger.error(f"Error generating speech with Kokoro: {str(e)}", exc_info=True)
|
112 |
return None
|
113 |
-
|
114 |
def generate_speech_stream(self, text: str, voice: str = 'af_heart', speed: float = 1.0) -> Generator[Tuple[int, np.ndarray], None, None]:
|
115 |
"""Generate speech stream using Kokoro TTS engine
|
116 |
-
|
117 |
Args:
|
118 |
text (str): Input text to synthesize
|
119 |
voice (str): Voice ID to use
|
120 |
speed (float): Speech speed multiplier
|
121 |
-
|
122 |
Yields:
|
123 |
tuple: (sample_rate, audio_data) pairs for each segment
|
124 |
"""
|
125 |
logger.info(f"Generating speech stream with Kokoro for text length: {len(text)}")
|
126 |
-
|
127 |
# Check if Kokoro is available
|
128 |
if not KOKORO_AVAILABLE:
|
129 |
logger.error("Kokoro TTS engine is not available")
|
130 |
return
|
131 |
-
|
132 |
# Ensure pipeline is loaded
|
133 |
if not self._ensure_pipeline():
|
134 |
logger.error("Failed to load Kokoro pipeline")
|
135 |
return
|
136 |
-
|
137 |
try:
|
138 |
# Generate speech stream
|
139 |
generator = self.pipeline(text, voice=voice, speed=speed)
|
|
|
25 |
|
26 |
def _get_pipeline(lang_code: str = 'z'):
|
27 |
"""Lazy-load the Kokoro pipeline
|
28 |
+
|
29 |
Args:
|
30 |
lang_code (str): Language code for the pipeline
|
31 |
+
|
32 |
Returns:
|
33 |
KPipeline or None: The Kokoro pipeline or None if not available
|
34 |
"""
|
35 |
if not KOKORO_AVAILABLE:
|
36 |
logger.warning("Kokoro TTS engine is not available")
|
37 |
return None
|
38 |
+
|
39 |
try:
|
40 |
pipeline = KPipeline(lang_code=lang_code)
|
41 |
logger.info("Kokoro pipeline successfully loaded")
|
|
|
47 |
|
48 |
class KokoroTTS(TTSBase):
|
49 |
"""Kokoro TTS engine implementation
|
50 |
+
|
51 |
This engine uses the Kokoro library for TTS generation.
|
52 |
"""
|
53 |
+
|
54 |
def __init__(self, lang_code: str = 'z'):
|
55 |
"""Initialize the Kokoro TTS engine
|
56 |
+
|
57 |
Args:
|
58 |
lang_code (str): Language code for the engine
|
59 |
"""
|
60 |
super().__init__(lang_code)
|
61 |
self.pipeline = None
|
62 |
+
|
63 |
def _ensure_pipeline(self):
|
64 |
"""Ensure the pipeline is loaded
|
65 |
+
|
66 |
Returns:
|
67 |
bool: True if pipeline is available, False otherwise
|
68 |
"""
|
69 |
if self.pipeline is None:
|
70 |
self.pipeline = _get_pipeline(self.lang_code)
|
71 |
+
|
72 |
return self.pipeline is not None
|
73 |
+
|
74 |
def generate_speech(self, text: str, voice: str = 'af_heart', speed: float = 1.0) -> Optional[str]:
|
75 |
"""Generate speech using Kokoro TTS engine
|
76 |
+
|
77 |
Args:
|
78 |
text (str): Input text to synthesize
|
79 |
voice (str): Voice ID to use (e.g., 'af_heart', 'af_bella', etc.)
|
80 |
speed (float): Speech speed multiplier (0.5 to 2.0)
|
81 |
+
|
82 |
Returns:
|
83 |
Optional[str]: Path to the generated audio file or None if generation fails
|
84 |
"""
|
85 |
logger.info(f"Generating speech with Kokoro for text length: {len(text)}")
|
86 |
+
|
87 |
# Check if Kokoro is available
|
88 |
if not KOKORO_AVAILABLE:
|
89 |
logger.error("Kokoro TTS engine is not available")
|
90 |
return None
|
91 |
+
|
92 |
# Ensure pipeline is loaded
|
93 |
if not self._ensure_pipeline():
|
94 |
logger.error("Failed to load Kokoro pipeline")
|
95 |
return None
|
96 |
+
|
97 |
try:
|
98 |
# Generate unique output path
|
99 |
output_path = self._generate_output_path(prefix="kokoro")
|
100 |
+
|
101 |
# Generate speech
|
102 |
generator = self.pipeline(text, voice=voice, speed=speed)
|
103 |
for _, _, audio in generator:
|
104 |
logger.info(f"Saving Kokoro audio to {output_path}")
|
105 |
sf.write(output_path, audio, 24000)
|
106 |
break
|
107 |
+
|
108 |
logger.info(f"Kokoro audio generation complete: {output_path}")
|
109 |
return output_path
|
110 |
except Exception as e:
|
111 |
logger.error(f"Error generating speech with Kokoro: {str(e)}", exc_info=True)
|
112 |
return None
|
113 |
+
|
114 |
def generate_speech_stream(self, text: str, voice: str = 'af_heart', speed: float = 1.0) -> Generator[Tuple[int, np.ndarray], None, None]:
|
115 |
"""Generate speech stream using Kokoro TTS engine
|
116 |
+
|
117 |
Args:
|
118 |
text (str): Input text to synthesize
|
119 |
voice (str): Voice ID to use
|
120 |
speed (float): Speech speed multiplier
|
121 |
+
|
122 |
Yields:
|
123 |
tuple: (sample_rate, audio_data) pairs for each segment
|
124 |
"""
|
125 |
logger.info(f"Generating speech stream with Kokoro for text length: {len(text)}")
|
126 |
+
|
127 |
# Check if Kokoro is available
|
128 |
if not KOKORO_AVAILABLE:
|
129 |
logger.error("Kokoro TTS engine is not available")
|
130 |
return
|
131 |
+
|
132 |
# Ensure pipeline is loaded
|
133 |
if not self._ensure_pipeline():
|
134 |
logger.error("Failed to load Kokoro pipeline")
|
135 |
return
|
136 |
+
|
137 |
try:
|
138 |
# Generate speech stream
|
139 |
generator = self.pipeline(text, voice=voice, speed=speed)
|