yl4579 committed on
Commit 2f84825 · verified · 1 Parent(s): 1b8d1f0

Update app.py

Files changed (1):
  1. app.py +105 -338
app.py CHANGED
@@ -1,350 +1,117 @@
- import gradio as gr
  import torch
- import torchaudio
- import numpy as np
- import tempfile
- import time
  from pathlib import Path
  from huggingface_hub import hf_hub_download
- import os
-
- # Import the inference module (assuming it's named 'infer.py' based on the notebook)
- from infer import DMOInference
-
- # Global model instance
- model = None
- device = "cuda" if torch.cuda.is_available() else "cpu"
-
- def download_models():
-     """Download models from HuggingFace Hub."""
-     try:
-         print("Downloading models from HuggingFace...")
-
-         # Download student model
-         student_path = hf_hub_download(
-             repo_id="yl4579/DMOSpeech2",
-             filename="model_85000.pt",
-             cache_dir="./models"
-         )
-
-         # Download duration predictor
-         duration_path = hf_hub_download(
-             repo_id="yl4579/DMOSpeech2",
-             filename="model_1500.pt",
-             cache_dir="./models"
-         )
-
-         print(f"Student model: {student_path}")
-         print(f"Duration model: {duration_path}")
-
-         return student_path, duration_path
-
-     except Exception as e:
-         print(f"Error downloading models: {e}")
-         return None, None
-
- def initialize_model():
-     """Initialize the model on startup."""
-     global model
-
-     try:
-         # Download models
-         student_path, duration_path = download_models()
-
-         if not student_path or not duration_path:
-             return False, "Failed to download models from HuggingFace"
-
-         # Initialize model
-         model = DMOInference(
-             student_checkpoint_path=student_path,
-             duration_predictor_path=duration_path,
-             device=device,
-             model_type="F5TTS_Base"
-         )
-
-         return True, f"Model loaded successfully on {device.upper()}"
-
-     except Exception as e:
-         return False, f"Error initializing model: {str(e)}"
-
- # Initialize model on startup
- model_loaded, status_message = initialize_model()
-
- def generate_speech(
-     prompt_audio,
-     prompt_text,
-     target_text,
-     mode,
-     # Advanced settings
-     custom_teacher_steps,
-     custom_teacher_stopping_time,
-     custom_student_start_step,
-     temperature,
-     verbose
- ):
-     """Generate speech with different configurations."""
-
-     if not model_loaded or model is None:
-         return None, "Model not loaded! Please refresh the page.", "", ""
-
-     if prompt_audio is None:
-         return None, "Please upload a reference audio!", "", ""
-
-     if not target_text:
-         return None, "Please enter text to generate!", "", ""
-
-     try:
-         start_time = time.time()
-
-         # Configure parameters based on mode
-         if mode == "Student Only (4 steps)":
-             teacher_steps = 0
-             student_start_step = 0
-             teacher_stopping_time = 1.0
-         elif mode == "Teacher-Guided (8 steps)":
-             # Default configuration from the notebook
-             teacher_steps = 16
-             teacher_stopping_time = 0.07
-             student_start_step = 1
-         elif mode == "High Diversity (16 steps)":
-             teacher_steps = 24
-             teacher_stopping_time = 0.3
-             student_start_step = 2
-         else:  # Custom
-             teacher_steps = custom_teacher_steps
-             teacher_stopping_time = custom_teacher_stopping_time
-             student_start_step = custom_student_start_step
-
-         # Generate speech
-         generated_audio = model.generate(
-             gen_text=target_text,
-             audio_path=prompt_audio,
-             prompt_text=prompt_text if prompt_text else None,
-             teacher_steps=teacher_steps,
-             teacher_stopping_time=teacher_stopping_time,
-             student_start_step=student_start_step,
-             temperature=temperature,
-             verbose=verbose
-         )
-
-         end_time = time.time()
-
-         # Calculate metrics
-         processing_time = end_time - start_time
-         audio_duration = generated_audio.shape[-1] / 24000
-         rtf = processing_time / audio_duration
-
-         # Save audio
-         with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as tmp_file:
-             output_path = tmp_file.name
-
-         if isinstance(generated_audio, np.ndarray):
-             generated_audio = torch.from_numpy(generated_audio)
-
-         if generated_audio.dim() == 1:
-             generated_audio = generated_audio.unsqueeze(0)
-
-         torchaudio.save(output_path, generated_audio, 24000)
-
-         # Format metrics
-         metrics = f"RTF: {rtf:.2f}x ({1/rtf:.2f}x speed) | Processing: {processing_time:.2f}s for {audio_duration:.2f}s audio"
-
-         return output_path, "Success!", metrics, f"Mode: {mode}"
-
-     except Exception as e:
-         return None, f"Error: {str(e)}", "", ""
-
- # Create Gradio interface
- with gr.Blocks(title="DMOSpeech 2 - Zero-Shot TTS", theme=gr.themes.Soft()) as demo:
-     gr.Markdown(f"""
-     # 🎙️ DMOSpeech 2: Zero-Shot Text-to-Speech
-
-     Generate natural speech in any voice with just a short reference audio!
-
-     **Model Status:** {status_message} | **Device:** {device.upper()}
-     """)
-
-     with gr.Row():
-         with gr.Column(scale=1):
-             # Reference audio input
-             prompt_audio = gr.Audio(
-                 label="📎 Reference Audio",
-                 type="filepath",
-                 sources=["upload", "microphone"]
-             )
-
-             prompt_text = gr.Textbox(
-                 label="📝 Reference Text (optional - will auto-transcribe if empty)",
-                 placeholder="The text spoken in the reference audio...",
-                 lines=2
-             )
-
-             target_text = gr.Textbox(
-                 label="✍️ Text to Generate",
-                 placeholder="Enter the text you want to synthesize...",
-                 lines=4
-             )
-
-             # Generation mode
-             mode = gr.Radio(
-                 choices=[
-                     "Student Only (4 steps)",
-                     "Teacher-Guided (8 steps)",
-                     "High Diversity (16 steps)",
-                     "Custom"
-                 ],
-                 value="Teacher-Guided (8 steps)",
-                 label="🚀 Generation Mode",
-                 info="Choose speed vs quality/diversity tradeoff"
-             )
-
-             # Advanced settings (collapsible)
-             with gr.Accordion("⚙️ Advanced Settings", open=False):
-                 with gr.Row():
-                     custom_teacher_steps = gr.Slider(
-                         minimum=0,
-                         maximum=32,
-                         value=16,
-                         step=1,
-                         label="Teacher Steps",
-                         info="More steps = higher quality"
-                     )
-
-                     custom_teacher_stopping_time = gr.Slider(
-                         minimum=0.0,
-                         maximum=1.0,
-                         value=0.07,
-                         step=0.01,
-                         label="Teacher Stopping Time",
-                         info="When to switch to student"
-                     )
-
-                     custom_student_start_step = gr.Slider(
-                         minimum=0,
-                         maximum=4,
-                         value=1,
-                         step=1,
-                         label="Student Start Step",
-                         info="Which student step to start from"
-                     )
-
-                 temperature = gr.Slider(
-                     minimum=0.0,
-                     maximum=2.0,
-                     value=0.0,
-                     step=0.1,
-                     label="Duration Temperature",
-                     info="0 = deterministic, >0 = more variation in speech rhythm"
-                 )
-
-                 verbose = gr.Checkbox(
-                     value=False,
-                     label="Verbose Output",
-                     info="Show detailed generation steps"
-                 )
-
-             generate_btn = gr.Button("🎵 Generate Speech", variant="primary", size="lg")
-
-         with gr.Column(scale=1):
-             # Output
-             output_audio = gr.Audio(
-                 label="🔊 Generated Speech",
-                 type="filepath",
-                 autoplay=True
-             )
-
-             status = gr.Textbox(
-                 label="Status",
-                 interactive=False
-             )
-
-             metrics = gr.Textbox(
-                 label="Performance Metrics",
-                 interactive=False
-             )
-
-             info = gr.Textbox(
-                 label="Generation Info",
-                 interactive=False
-             )
-
-     # Tips
-     gr.Markdown("""
-     ### 💡 Quick Tips:
-
-     - **Student Only**: Fastest (4 steps), good quality
-     - **Teacher-Guided**: Best balance (8 steps), recommended
-     - **High Diversity**: More natural prosody (16 steps)
-     - **Temperature**: Add randomness to speech rhythm
-
-     ### 📊 Expected RTF (Real-Time Factor):
-     - Student Only: ~0.05x (20x faster than real-time)
-     - Teacher-Guided: ~0.10x (10x faster)
-     - High Diversity: ~0.20x (5x faster)
-     """)
-
-     # Examples section
-     gr.Markdown("### 🎯 Examples")
-
-     examples = [
-         [
-             None,  # Will be replaced with actual audio path
-             "Some call me nature, others call me mother nature.",
-             "I don't really care what you call me. I've been a silent spectator, watching species evolve, empires rise and fall. But always remember, I am mighty and enduring.",
-             "Teacher-Guided (8 steps)",
-             16, 0.07, 1, 0.0, False
-         ],
-         [
-             None,  # Will be replaced with actual audio path
-             "对,这就是我,万人敬仰的太乙真人。",
-             '突然,身边一阵笑声。我看着他们,意气风发地挺直了胸膛,甩了甩那稍显肉感的双臂,轻笑道:"我身上的肉,是为了掩饰我爆棚的魅力,否则,岂不吓坏了你们呢?"',
-             "Teacher-Guided (8 steps)",
-             16, 0.07, 1, 0.0, False
-         ],
-         [
-             None,
-             "对,这就是我,万人敬仰的太乙真人。",
-             '突然,身边一阵笑声。我看着他们,意气风发地挺直了胸膛,甩了甩那稍显肉感的双臂,轻笑道:"我身上的肉,是为了掩饰我爆棚的魅力,否则,岂不吓坏了你们呢?"',
-             "High Diversity (16 steps)",
-             24, 0.3, 2, 0.8, False
-         ]
-     ]
-
-     # Note about example audio files
-     gr.Markdown("""
-     *Note: Example audio files should be uploaded to the Space. The examples above show the text configurations used in the original notebook.*
-     """)
-
-     # Event handler
-     generate_btn.click(
-         generate_speech,
-         inputs=[
-             prompt_audio,
-             prompt_text,
-             target_text,
-             mode,
-             custom_teacher_steps,
-             custom_teacher_stopping_time,
-             custom_student_start_step,
-             temperature,
-             verbose
-         ],
-         outputs=[output_audio, status, metrics, info]
-     )
-
-     # Update visibility of custom settings based on mode
-     def update_custom_visibility(mode):
-         return gr.update(visible=(mode == "Custom"))
-
-     mode.change(
-         lambda x: [gr.update(interactive=(x == "Custom"))] * 3,
-         inputs=[mode],
-         outputs=[custom_teacher_steps, custom_teacher_stopping_time, custom_student_start_step]
-     )
-
- # Launch the app
- if __name__ == "__main__":
-     if not model_loaded:
-         print(f"Warning: Model failed to load - {status_message}")
-
-     demo.launch()
+ # Add this to your DMOInference class or create a wrapper
+
+ import os
  import torch
  from pathlib import Path
  from huggingface_hub import hf_hub_download
+ import re
+
+ def load_checkpoint_from_hf(checkpoint_path, device='cpu'):
+     """
+     Load a checkpoint from either a local path or HuggingFace URL.
+
+     Supports:
+     - Local paths: /path/to/model.pt
+     - HF URLs: hf://username/repo/model.pt
+     - HF hub format: username/repo/model.pt
+     """
+     if isinstance(checkpoint_path, str):
+         # Check if it's a HuggingFace URL
+         if checkpoint_path.startswith("hf://"):
+             # Parse HF URL: hf://username/repo/path/to/model.pt
+             match = re.match(r"hf://([^/]+/[^/]+)/(.+)", checkpoint_path)
+             if match:
+                 repo_id = match.group(1)
+                 filename = match.group(2)
+
+                 print(f"Loading from HuggingFace: {repo_id}/{filename}")
+
+                 # Download from HuggingFace
+                 local_path = hf_hub_download(
+                     repo_id=repo_id,
+                     filename=filename,
+                     cache_dir=os.environ.get("HF_HOME", "./models")
+                 )
+
+                 # Load the checkpoint
+                 return torch.load(local_path, map_location=device)
+
+         # Check if it's a HuggingFace repo format (username/repo/file.pt)
+         elif "/" in checkpoint_path and not os.path.exists(checkpoint_path):
+             parts = checkpoint_path.split("/")
+             if len(parts) >= 3:
+                 repo_id = "/".join(parts[:2])
+                 filename = "/".join(parts[2:])
+
+                 print(f"Loading from HuggingFace: {repo_id}/{filename}")
+
+                 local_path = hf_hub_download(
+                     repo_id=repo_id,
+                     filename=filename,
+                     cache_dir=os.environ.get("HF_HOME", "./models")
+                 )
+
+                 return torch.load(local_path, map_location=device)
+
+         # Local file path
+         elif os.path.exists(checkpoint_path):
+             print(f"Loading from local path: {checkpoint_path}")
+             return torch.load(checkpoint_path, map_location=device)
+
+     raise ValueError(f"Could not load checkpoint from: {checkpoint_path}")
+
+ # Modified DMOInference class init (partial)
+ class DMOInference:
+     def __init__(
+         self,
+         student_checkpoint_path="",
+         duration_predictor_path="",
+         device="cuda",
+         model_type="F5TTS_Base",
+         tokenizer="pinyin",
+         dataset_name="Emilia_ZH_EN",
+         cuda_device_id="0"
+     ):
+         # ... (previous initialization code) ...
+
+         # Initialize components
+         self._setup_tokenizer()
+         self._setup_models(student_checkpoint_path)  # Modified to handle HF URLs
+         self._setup_mel_spec()
+         self._setup_vocoder()
+         self._setup_duration_predictor(duration_predictor_path)  # Modified to handle HF URLs
+
+     def _setup_models(self, student_checkpoint_path):
+         """Initialize teacher and student models with HF support."""
+         # ... (model configuration code) ...
+
+         # Load student checkpoint with HF support
+         checkpoint = load_checkpoint_from_hf(student_checkpoint_path, device='cpu')
+         self.model.load_state_dict(checkpoint['model_state_dict'], strict=False)
+
+         # ... (rest of the setup) ...
+
+     def _setup_duration_predictor(self, checkpoint_path):
+         """Initialize duration predictor with HF support."""
+         # ... (model initialization code) ...
+
+         # Load checkpoint with HF support
+         checkpoint = load_checkpoint_from_hf(checkpoint_path, device='cpu')
+         self.SLP.load_state_dict(checkpoint['model_state_dict'])
+
+ # Wrapper class for easier use
+ class DMOInferenceHF(DMOInference):
+     """DMOInference with built-in HuggingFace support."""
+
+     def __init__(self, **kwargs):
+         # Override checkpoint loading to support HF URLs
+         if 'student_checkpoint_path' in kwargs:
+             self._original_student_path = kwargs['student_checkpoint_path']
+         if 'duration_predictor_path' in kwargs:
+             self._original_duration_path = kwargs['duration_predictor_path']
+
+         super().__init__(**kwargs)
+
+     def _load_checkpoint(self, checkpoint_path):
+         """Load checkpoint with HF URL support."""
+         return load_checkpoint_from_hf(checkpoint_path, self.device)
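
For reference, a minimal usage sketch of the loader and wrapper added above. It assumes the new code lives in infer.py (the module name used by the removed app.py) and reuses the repo and checkpoint names from the old download_models(); the exact import path is an assumption, not part of this commit:

# Usage sketch (hypothetical entry point; assumes the code above is importable from infer.py)
from infer import DMOInferenceHF, load_checkpoint_from_hf

# All three supported path styles resolve to the same checkpoint:
ckpt = load_checkpoint_from_hf("hf://yl4579/DMOSpeech2/model_85000.pt")  # hf:// URL
ckpt = load_checkpoint_from_hf("yl4579/DMOSpeech2/model_85000.pt")       # repo format
# ckpt = load_checkpoint_from_hf("./models/model_85000.pt")              # local path

# Instantiate directly from HF URLs instead of pre-downloading with hf_hub_download:
model = DMOInferenceHF(
    student_checkpoint_path="hf://yl4579/DMOSpeech2/model_85000.pt",
    duration_predictor_path="hf://yl4579/DMOSpeech2/model_1500.pt",
    device="cuda",
    model_type="F5TTS_Base",
)

With this in place, the Space no longer needs the download_models()/initialize_model() plumbing that this commit removes: checkpoint resolution happens inside the class itself.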