Michael Hu commited on
Commit
462d6de
·
1 Parent(s): 0aa0b99

Remove Streamlit dependency and simplify Dia TTS dependency handling

Browse files

- Remove streamlit from pyproject.toml and requirements.txt
- Change gradio from pinned to minimum version (>=5.9.1)
- Delete dependency_installer utility module
- Simplify Dia TTS provider to only check dependencies without auto-installation
- Replace _check_and_install_dia_dependencies with _check_dia_dependencies
- Add helpful installation messages for missing dac and dia modules

pyproject.toml CHANGED
@@ -9,8 +9,7 @@ license = {text = "MIT"}
9
  readme = "README.md"
10
  requires-python = ">=3.10"
11
  dependencies = [
12
- "streamlit>=1.44.1",
13
- "gradio==5.9.1",
14
  "nltk>=3.8",
15
  "librosa>=0.10",
16
  "ffmpeg-python>=0.2",
 
9
  readme = "README.md"
10
  requires-python = ">=3.10"
11
  dependencies = [
12
+ "gradio>=5.9.1",
 
13
  "nltk>=3.8",
14
  "librosa>=0.10",
15
  "ffmpeg-python>=0.2",
requirements.txt CHANGED
@@ -1,5 +1,4 @@
1
- streamlit>=1.44.1
2
- gradio==5.9.1
3
  nltk>=3.8
4
  librosa>=0.10
5
  ffmpeg-python>=0.2
 
1
+ gradio>=5.9.1
 
2
  nltk>=3.8
3
  librosa>=0.10
4
  ffmpeg-python>=0.2
src/infrastructure/tts/dia_provider.py CHANGED
@@ -11,7 +11,7 @@ if TYPE_CHECKING:
11
 
12
  from ..base.tts_provider_base import TTSProviderBase
13
  from ...domain.exceptions import SpeechSynthesisException
14
- from ..utils.dependency_installer import get_dependency_installer
15
 
16
  logger = logging.getLogger(__name__)
17
 
@@ -20,8 +20,8 @@ DIA_AVAILABLE = False
20
  DEFAULT_SAMPLE_RATE = 24000
21
 
22
  # Try to import Dia dependencies
23
- def _check_and_install_dia_dependencies():
24
- """Check and install Dia dependencies if needed."""
25
  global DIA_AVAILABLE
26
 
27
  logger.info("🔍 Checking Dia TTS dependencies...")
@@ -41,49 +41,24 @@ def _check_and_install_dia_dependencies():
41
  except ImportError as e:
42
  logger.warning(f"⚠️ Dia TTS engine dependencies not available: {e}")
43
  logger.info(f"ImportError details: {type(e).__name__}: {e}")
 
 
44
  except ModuleNotFoundError as e:
45
  if "dac" in str(e):
46
  logger.warning("❌ Dia TTS engine is not available due to missing 'dac' module")
 
47
  elif "dia" in str(e):
48
  logger.warning("❌ Dia TTS engine is not available due to missing 'dia' module")
 
49
  else:
50
  logger.warning(f"❌ Dia TTS engine is not available: {str(e)}")
51
  logger.info(f"ModuleNotFoundError details: {type(e).__name__}: {e}")
52
-
53
- # Try to install missing dependencies
54
- logger.info("🔧 Attempting to install Dia TTS dependencies...")
55
- try:
56
- installer = get_dependency_installer()
57
- success, errors = installer.install_dia_dependencies()
58
-
59
- if success:
60
- logger.info("✅ Successfully installed Dia TTS dependencies")
61
- # Try importing again after installation
62
- try:
63
- logger.info("Re-attempting import after installation...")
64
- import torch
65
- from dia.model import Dia
66
- DIA_AVAILABLE = True
67
- logger.info("🎉 Dia TTS engine is now available after installation")
68
- return True
69
- except Exception as e:
70
- logger.error(f"❌ Dia TTS still not available after installation: {e}")
71
- logger.info(f"Post-installation import error: {type(e).__name__}: {e}")
72
- DIA_AVAILABLE = False
73
- return False
74
- else:
75
- logger.error(f"❌ Failed to install Dia TTS dependencies: {errors}")
76
- DIA_AVAILABLE = False
77
- return False
78
- except Exception as e:
79
- logger.error(f"❌ Error during dependency installation: {e}")
80
- logger.info(f"Installation error details: {type(e).__name__}: {e}")
81
  DIA_AVAILABLE = False
82
  return False
83
 
84
  # Initial check
85
  logger.info("🚀 Initializing Dia TTS provider...")
86
- _check_and_install_dia_dependencies()
87
 
88
 
89
  class DiaTTSProvider(TTSProviderBase):
@@ -105,14 +80,14 @@ class DiaTTSProvider(TTSProviderBase):
105
  if self.model is None:
106
  logger.info("🔄 Ensuring Dia model is loaded...")
107
 
108
- # If Dia is not available, try to install dependencies
109
  if not DIA_AVAILABLE:
110
- logger.info("⚠️ Dia not available, attempting to install dependencies...")
111
- if _check_and_install_dia_dependencies():
112
  DIA_AVAILABLE = True
113
- logger.info("✅ Dependencies installed, Dia is now available")
114
  else:
115
- logger.error("❌ Failed to install dependencies, Dia remains unavailable")
116
  return False
117
 
118
  if DIA_AVAILABLE:
 
11
 
12
  from ..base.tts_provider_base import TTSProviderBase
13
  from ...domain.exceptions import SpeechSynthesisException
14
+
15
 
16
  logger = logging.getLogger(__name__)
17
 
 
20
  DEFAULT_SAMPLE_RATE = 24000
21
 
22
  # Try to import Dia dependencies
23
+ def _check_dia_dependencies():
24
+ """Check if Dia dependencies are available."""
25
  global DIA_AVAILABLE
26
 
27
  logger.info("🔍 Checking Dia TTS dependencies...")
 
41
  except ImportError as e:
42
  logger.warning(f"⚠️ Dia TTS engine dependencies not available: {e}")
43
  logger.info(f"ImportError details: {type(e).__name__}: {e}")
44
+ DIA_AVAILABLE = False
45
+ return False
46
  except ModuleNotFoundError as e:
47
  if "dac" in str(e):
48
  logger.warning("❌ Dia TTS engine is not available due to missing 'dac' module")
49
+ logger.info("Please install descript-audio-codec: pip install descript-audio-codec")
50
  elif "dia" in str(e):
51
  logger.warning("❌ Dia TTS engine is not available due to missing 'dia' module")
52
+ logger.info("Please install dia: pip install git+https://github.com/nari-labs/dia.git")
53
  else:
54
  logger.warning(f"❌ Dia TTS engine is not available: {str(e)}")
55
  logger.info(f"ModuleNotFoundError details: {type(e).__name__}: {e}")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
56
  DIA_AVAILABLE = False
57
  return False
58
 
59
  # Initial check
60
  logger.info("🚀 Initializing Dia TTS provider...")
61
+ _check_dia_dependencies()
62
 
63
 
64
  class DiaTTSProvider(TTSProviderBase):
 
80
  if self.model is None:
81
  logger.info("🔄 Ensuring Dia model is loaded...")
82
 
83
+ # If Dia is not available, check dependencies again
84
  if not DIA_AVAILABLE:
85
+ logger.info("⚠️ Dia not available, checking dependencies again...")
86
+ if _check_dia_dependencies():
87
  DIA_AVAILABLE = True
88
+ logger.info("✅ Dependencies are now available")
89
  else:
90
+ logger.error("❌ Dependencies still not available")
91
  return False
92
 
93
  if DIA_AVAILABLE:
src/infrastructure/utils/dependency_installer.py DELETED
@@ -1,304 +0,0 @@
1
- """Automatic dependency installer for TTS providers."""
2
-
3
- import logging
4
- import subprocess
5
- import sys
6
- import importlib
7
- from typing import List, Dict, Optional, Tuple
8
- import os
9
-
10
- logger = logging.getLogger(__name__)
11
-
12
-
13
- class DependencyInstaller:
14
- """Utility class for automatically installing missing dependencies."""
15
-
16
- def __init__(self):
17
- """Initialize the dependency installer."""
18
- self.installed_packages = set()
19
-
20
- def check_module_available(self, module_name: str) -> bool:
21
- """
22
- Check if a module is available for import.
23
-
24
- Args:
25
- module_name: Name of the module to check
26
-
27
- Returns:
28
- bool: True if module is available, False otherwise
29
- """
30
- try:
31
- importlib.import_module(module_name)
32
- return True
33
- except ImportError:
34
- return False
35
-
36
- def install_package(self, package_name: str, upgrade: bool = False) -> bool:
37
- """
38
- Install a package using pip.
39
-
40
- Args:
41
- package_name: Name of the package to install
42
- upgrade: Whether to upgrade if already installed
43
-
44
- Returns:
45
- bool: True if installation succeeded, False otherwise
46
- """
47
- if package_name in self.installed_packages:
48
- logger.info(f"Package {package_name} already installed in this session")
49
- return True
50
-
51
- try:
52
- cmd = [sys.executable, "-m", "pip", "install"]
53
- if upgrade:
54
- cmd.append("--upgrade")
55
- cmd.append(package_name)
56
-
57
- logger.info(f"Installing package: {package_name}")
58
- result = subprocess.run(
59
- cmd,
60
- capture_output=True,
61
- text=True,
62
- timeout=300 # 5 minute timeout
63
- )
64
-
65
- if result.returncode == 0:
66
- logger.info(f"Successfully installed {package_name}")
67
- self.installed_packages.add(package_name)
68
- return True
69
- else:
70
- logger.error(f"Failed to install {package_name}: {result.stderr}")
71
- return False
72
-
73
- except subprocess.TimeoutExpired:
74
- logger.error(f"Installation of {package_name} timed out")
75
- return False
76
- except Exception as e:
77
- logger.error(f"Error installing {package_name}: {e}")
78
- return False
79
-
80
- def install_from_git(self, git_url: str, package_name: Optional[str] = None) -> bool:
81
- """
82
- Install a package from a git repository.
83
-
84
- Args:
85
- git_url: Git repository URL
86
- package_name: Optional package name for tracking
87
-
88
- Returns:
89
- bool: True if installation succeeded, False otherwise
90
- """
91
- package_name = package_name or git_url.split('/')[-1].replace('.git', '')
92
-
93
- if package_name in self.installed_packages:
94
- logger.info(f"Package {package_name} already installed in this session")
95
- return True
96
-
97
- try:
98
- cmd = [sys.executable, "-m", "pip", "install", f"git+{git_url}"]
99
-
100
- logger.info(f"Installing package from git: {git_url}")
101
- result = subprocess.run(
102
- cmd,
103
- capture_output=True,
104
- text=True,
105
- timeout=600 # 10 minute timeout for git installs
106
- )
107
-
108
- if result.returncode == 0:
109
- logger.info(f"Successfully installed {package_name} from git")
110
- self.installed_packages.add(package_name)
111
- return True
112
- else:
113
- logger.error(f"Failed to install {package_name} from git: {result.stderr}")
114
- return False
115
-
116
- except subprocess.TimeoutExpired:
117
- logger.error(f"Git installation of {package_name} timed out")
118
- return False
119
- except Exception as e:
120
- logger.error(f"Error installing {package_name} from git: {e}")
121
- return False
122
-
123
- def install_dia_dependencies(self) -> Tuple[bool, List[str]]:
124
- """
125
- Install all dependencies required for Dia TTS.
126
-
127
- Returns:
128
- Tuple[bool, List[str]]: (success, list of error messages)
129
- """
130
- errors = []
131
-
132
- # Check if Dia is already available
133
- if self.check_module_available("dia"):
134
- logger.info("Dia TTS is already available")
135
- return True, []
136
-
137
- # Install Dia TTS from git - this will automatically install all dependencies
138
- # including descript-audio-codec as specified in pyproject.toml
139
- logger.info("Installing Dia TTS and all dependencies from GitHub")
140
- if self.install_from_git("https://github.com/nari-labs/dia.git", "dia"):
141
- logger.info("Successfully installed Dia TTS and dependencies")
142
- return True, []
143
- else:
144
- errors.append("Failed to install Dia TTS from git")
145
-
146
- # Fallback: try installing individual dependencies if git install fails
147
- logger.info("Git install failed, trying individual dependencies...")
148
- dependencies = [
149
- ("torch", "torch"),
150
- ("transformers", "transformers"),
151
- ("accelerate", "accelerate"),
152
- ("soundfile", "soundfile"),
153
- ("dac", "descript-audio-codec"),
154
- ]
155
-
156
- success = True
157
- for module_name, package_name in dependencies:
158
- if not self.check_module_available(module_name):
159
- logger.info(f"Installing missing dependency: {package_name}")
160
- if not self.install_package(package_name):
161
- errors.append(f"Failed to install {package_name}")
162
- success = False
163
-
164
- # Try installing Dia again after dependencies
165
- if success and not self.check_module_available("dia"):
166
- if self.install_from_git("https://github.com/nari-labs/dia.git", "dia"):
167
- return True, []
168
- else:
169
- errors.append("Failed to install Dia TTS after installing dependencies")
170
-
171
- return success and len(errors) == 1, errors # Only the initial git error if dependencies succeeded
172
-
173
- def install_dependencies_for_provider(self, provider_name: str) -> Tuple[bool, List[str]]:
174
- """
175
- Install dependencies for a specific TTS provider.
176
-
177
- Args:
178
- provider_name: Name of the TTS provider
179
-
180
- Returns:
181
- Tuple[bool, List[str]]: (success, list of error messages)
182
- """
183
- if provider_name.lower() == "dia":
184
- return self.install_dia_dependencies()
185
- else:
186
- return False, [f"Unknown provider: {provider_name}"]
187
-
188
- def verify_installation(self, module_name: str) -> bool:
189
- """
190
- Verify that a module was installed correctly.
191
-
192
- Args:
193
- module_name: Name of the module to verify
194
-
195
- Returns:
196
- bool: True if module can be imported, False otherwise
197
- """
198
- try:
199
- # Clear import cache to ensure fresh import
200
- if module_name in sys.modules:
201
- del sys.modules[module_name]
202
-
203
- importlib.import_module(module_name)
204
- logger.info(f"Successfully verified installation of {module_name}")
205
- return True
206
- except ImportError as e:
207
- logger.error(f"Failed to verify installation of {module_name}: {e}")
208
- return False
209
-
210
- def get_installation_status(self) -> Dict[str, bool]:
211
- """
212
- Get the installation status of key dependencies.
213
-
214
- Returns:
215
- Dict[str, bool]: Dictionary mapping module names to availability status
216
- """
217
- modules_to_check = [
218
- "torch",
219
- "transformers",
220
- "accelerate",
221
- "soundfile",
222
- "numpy",
223
- "dac",
224
- "dia"
225
- ]
226
-
227
- status = {}
228
- for module in modules_to_check:
229
- status[module] = self.check_module_available(module)
230
-
231
- return status
232
-
233
- def install_with_retry(self, package_name: str, max_retries: int = 3) -> bool:
234
- """
235
- Install a package with retry logic.
236
-
237
- Args:
238
- package_name: Name of the package to install
239
- max_retries: Maximum number of retry attempts
240
-
241
- Returns:
242
- bool: True if installation succeeded, False otherwise
243
- """
244
- for attempt in range(max_retries):
245
- if self.install_package(package_name):
246
- return True
247
-
248
- if attempt < max_retries - 1:
249
- logger.warning(f"Installation attempt {attempt + 1} failed for {package_name}, retrying...")
250
- else:
251
- logger.error(f"All {max_retries} installation attempts failed for {package_name}")
252
-
253
- return False
254
-
255
-
256
- # Global instance for reuse
257
- _dependency_installer = None
258
-
259
-
260
- def get_dependency_installer() -> DependencyInstaller:
261
- """
262
- Get a global dependency installer instance.
263
-
264
- Returns:
265
- DependencyInstaller: Global dependency installer instance
266
- """
267
- global _dependency_installer
268
- if _dependency_installer is None:
269
- _dependency_installer = DependencyInstaller()
270
- return _dependency_installer
271
-
272
-
273
- def install_dia_dependencies() -> Tuple[bool, List[str]]:
274
- """
275
- Convenience function to install Dia TTS dependencies.
276
-
277
- Returns:
278
- Tuple[bool, List[str]]: (success, list of error messages)
279
- """
280
- installer = get_dependency_installer()
281
- return installer.install_dia_dependencies()
282
-
283
-
284
- def check_and_install_module(module_name: str, package_name: Optional[str] = None) -> bool:
285
- """
286
- Check if a module is available and install it if not.
287
-
288
- Args:
289
- module_name: Name of the module to check
290
- package_name: Name of the package to install (defaults to module_name)
291
-
292
- Returns:
293
- bool: True if module is available after check/install, False otherwise
294
- """
295
- installer = get_dependency_installer()
296
-
297
- if installer.check_module_available(module_name):
298
- return True
299
-
300
- package_name = package_name or module_name
301
- if installer.install_package(package_name):
302
- return installer.verify_installation(module_name)
303
-
304
- return False