yl4579 committed
Commit a28c293 · verified · 1 Parent(s): 2f84825

Update app.py

Files changed (1):
  1. app.py +338 -105
app.py CHANGED
@@ -1,117 +1,350 @@
- # Add this to your DMOInference class or create a wrapper
-
- import os
+ import gradio as gr
  import torch
+ import torchaudio
+ import numpy as np
+ import tempfile
+ import time
  from pathlib import Path
  from huggingface_hub import hf_hub_download
- import re
-
- def load_checkpoint_from_hf(checkpoint_path, device='cpu'):
-     """
-     Load a checkpoint from either a local path or HuggingFace URL.
-
-     Supports:
-     - Local paths: /path/to/model.pt
-     - HF URLs: hf://username/repo/model.pt
-     - HF hub format: username/repo/model.pt
-     """
-     if isinstance(checkpoint_path, str):
-         # Check if it's a HuggingFace URL
-         if checkpoint_path.startswith("hf://"):
-             # Parse HF URL: hf://username/repo/path/to/model.pt
-             match = re.match(r"hf://([^/]+/[^/]+)/(.+)", checkpoint_path)
-             if match:
-                 repo_id = match.group(1)
-                 filename = match.group(2)
-
-                 print(f"Loading from HuggingFace: {repo_id}/{filename}")
-
-                 # Download from HuggingFace
-                 local_path = hf_hub_download(
-                     repo_id=repo_id,
-                     filename=filename,
-                     cache_dir=os.environ.get("HF_HOME", "./models")
-                 )
-
-                 # Load the checkpoint
-                 return torch.load(local_path, map_location=device)
-
-         # Check if it's a HuggingFace repo format (username/repo/file.pt)
-         elif "/" in checkpoint_path and not os.path.exists(checkpoint_path):
-             parts = checkpoint_path.split("/")
-             if len(parts) >= 3:
-                 repo_id = "/".join(parts[:2])
-                 filename = "/".join(parts[2:])
-
-                 print(f"Loading from HuggingFace: {repo_id}/{filename}")
-
-                 local_path = hf_hub_download(
-                     repo_id=repo_id,
-                     filename=filename,
-                     cache_dir=os.environ.get("HF_HOME", "./models")
-                 )
-
-                 return torch.load(local_path, map_location=device)
-
-         # Local file path
-         elif os.path.exists(checkpoint_path):
-             print(f"Loading from local path: {checkpoint_path}")
-             return torch.load(checkpoint_path, map_location=device)
-
-     raise ValueError(f"Could not load checkpoint from: {checkpoint_path}")
-
- # Modified DMOInference class init (partial)
- class DMOInference:
-     def __init__(
-         self,
-         student_checkpoint_path="",
-         duration_predictor_path="",
-         device="cuda",
-         model_type="F5TTS_Base",
-         tokenizer="pinyin",
-         dataset_name="Emilia_ZH_EN",
-         cuda_device_id="0"
-     ):
-         # ... (previous initialization code) ...
-
-         # Initialize components
-         self._setup_tokenizer()
-         self._setup_models(student_checkpoint_path)  # Modified to handle HF URLs
-         self._setup_mel_spec()
-         self._setup_vocoder()
-         self._setup_duration_predictor(duration_predictor_path)  # Modified to handle HF URLs
-
-     def _setup_models(self, student_checkpoint_path):
-         """Initialize teacher and student models with HF support."""
-         # ... (model configuration code) ...
-
-         # Load student checkpoint with HF support
-         checkpoint = load_checkpoint_from_hf(student_checkpoint_path, device='cpu')
-         self.model.load_state_dict(checkpoint['model_state_dict'], strict=False)
-
-         # ... (rest of the setup) ...
-
-     def _setup_duration_predictor(self, checkpoint_path):
-         """Initialize duration predictor with HF support."""
-         # ... (model initialization code) ...
-
-         # Load checkpoint with HF support
-         checkpoint = load_checkpoint_from_hf(checkpoint_path, device='cpu')
-         self.SLP.load_state_dict(checkpoint['model_state_dict'])
-
- # Wrapper class for easier use
- class DMOInferenceHF(DMOInference):
-     """DMOInference with built-in HuggingFace support."""
-
-     def __init__(self, **kwargs):
-         # Override checkpoint loading to support HF URLs
-         if 'student_checkpoint_path' in kwargs:
-             self._original_student_path = kwargs['student_checkpoint_path']
-         if 'duration_predictor_path' in kwargs:
-             self._original_duration_path = kwargs['duration_predictor_path']
-
-         super().__init__(**kwargs)
-
-     def _load_checkpoint(self, checkpoint_path):
-         """Load checkpoint with HF URL support."""
-         return load_checkpoint_from_hf(checkpoint_path, self.device)
+ import os
+
+ # Import the inference module (assuming it's named 'infer.py' based on the notebook)
+ from infer import DMOInference
+
+ # Global model instance
+ model = None
+ device = "cuda" if torch.cuda.is_available() else "cpu"
+
+ def download_models():
+     """Download models from HuggingFace Hub."""
+     try:
+         print("Downloading models from HuggingFace...")
+
+         # Download student model
+         student_path = hf_hub_download(
+             repo_id="yl4579/DMOSpeech2",
+             filename="model_85000.pt",
+             cache_dir="./models"
+         )
+
+         # Download duration predictor
+         duration_path = hf_hub_download(
+             repo_id="yl4579/DMOSpeech2",
+             filename="model_1500.pt",
+             cache_dir="./models"
+         )
+
+         print(f"Student model: {student_path}")
+         print(f"Duration model: {duration_path}")
+
+         return student_path, duration_path
+
+     except Exception as e:
+         print(f"Error downloading models: {e}")
+         return None, None
+
+ def initialize_model():
+     """Initialize the model on startup."""
+     global model
+
+     try:
+         # Download models
+         student_path, duration_path = download_models()
+
+         if not student_path or not duration_path:
+             return False, "Failed to download models from HuggingFace"
+
+         # Initialize model
+         model = DMOInference(
+             student_checkpoint_path=student_path,
+             duration_predictor_path=duration_path,
+             device=device,
+             model_type="F5TTS_Base"
+         )
+
+         return True, f"Model loaded successfully on {device.upper()}"
+
+     except Exception as e:
+         return False, f"Error initializing model: {str(e)}"
+
+ # Initialize model on startup
+ model_loaded, status_message = initialize_model()
+
+ def generate_speech(
+     prompt_audio,
+     prompt_text,
+     target_text,
+     mode,
+     # Advanced settings
+     custom_teacher_steps,
+     custom_teacher_stopping_time,
+     custom_student_start_step,
+     temperature,
+     verbose
+ ):
+     """Generate speech with different configurations."""
+
+     if not model_loaded or model is None:
+         return None, "Model not loaded! Please refresh the page.", "", ""
+
+     if prompt_audio is None:
+         return None, "Please upload a reference audio!", "", ""
+
+     if not target_text:
+         return None, "Please enter text to generate!", "", ""
+
+     try:
+         start_time = time.time()
+
+         # Configure parameters based on mode
+         if mode == "Student Only (4 steps)":
+             teacher_steps = 0
+             student_start_step = 0
+             teacher_stopping_time = 1.0
+         elif mode == "Teacher-Guided (8 steps)":
+             # Default configuration from the notebook
+             teacher_steps = 16
+             teacher_stopping_time = 0.07
+             student_start_step = 1
+         elif mode == "High Diversity (16 steps)":
+             teacher_steps = 24
+             teacher_stopping_time = 0.3
+             student_start_step = 2
+         else:  # Custom
+             teacher_steps = custom_teacher_steps
+             teacher_stopping_time = custom_teacher_stopping_time
+             student_start_step = custom_student_start_step
+
+         # Generate speech
+         generated_audio = model.generate(
+             gen_text=target_text,
+             audio_path=prompt_audio,
+             prompt_text=prompt_text if prompt_text else None,
+             teacher_steps=teacher_steps,
+             teacher_stopping_time=teacher_stopping_time,
+             student_start_step=student_start_step,
+             temperature=temperature,
+             verbose=verbose
+         )
+
+         end_time = time.time()
+
+         # Calculate metrics (output is 24 kHz audio)
+         processing_time = end_time - start_time
+         audio_duration = generated_audio.shape[-1] / 24000
+         rtf = processing_time / audio_duration
+
+         # Create a temporary file; the handle is closed before writing
+         with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as tmp_file:
+             output_path = tmp_file.name
+
+         if isinstance(generated_audio, np.ndarray):
+             generated_audio = torch.from_numpy(generated_audio)
+
+         if generated_audio.dim() == 1:
+             generated_audio = generated_audio.unsqueeze(0)
+
+         torchaudio.save(output_path, generated_audio.cpu(), 24000)  # move to CPU in case generation ran on GPU
+
+         # Format metrics
+         metrics = f"RTF: {rtf:.2f}x ({1/rtf:.2f}x speed) | Processing: {processing_time:.2f}s for {audio_duration:.2f}s audio"
+
+         return output_path, "Success!", metrics, f"Mode: {mode}"
+
+     except Exception as e:
+         return None, f"Error: {str(e)}", "", ""
+
+ # Create Gradio interface
+ with gr.Blocks(title="DMOSpeech 2 - Zero-Shot TTS", theme=gr.themes.Soft()) as demo:
+     gr.Markdown(f"""
+     # 🎙️ DMOSpeech 2: Zero-Shot Text-to-Speech
+
+     Generate natural speech in any voice from just a short reference audio!
+
+     **Model Status:** {status_message} | **Device:** {device.upper()}
+     """)
+
+     with gr.Row():
+         with gr.Column(scale=1):
+             # Reference audio input
+             prompt_audio = gr.Audio(
+                 label="📎 Reference Audio",
+                 type="filepath",
+                 sources=["upload", "microphone"]
+             )
+
+             prompt_text = gr.Textbox(
+                 label="📝 Reference Text (optional - will auto-transcribe if empty)",
+                 placeholder="The text spoken in the reference audio...",
+                 lines=2
+             )
+
+             target_text = gr.Textbox(
+                 label="✍️ Text to Generate",
+                 placeholder="Enter the text you want to synthesize...",
+                 lines=4
+             )
+
+             # Generation mode
+             mode = gr.Radio(
+                 choices=[
+                     "Student Only (4 steps)",
+                     "Teacher-Guided (8 steps)",
+                     "High Diversity (16 steps)",
+                     "Custom"
+                 ],
+                 value="Teacher-Guided (8 steps)",
+                 label="🚀 Generation Mode",
+                 info="Choose the speed vs. quality/diversity tradeoff"
+             )
+
+             # Advanced settings (collapsible)
+             with gr.Accordion("⚙️ Advanced Settings", open=False):
+                 with gr.Row():
+                     custom_teacher_steps = gr.Slider(
+                         minimum=0,
+                         maximum=32,
+                         value=16,
+                         step=1,
+                         label="Teacher Steps",
+                         info="More steps = higher quality"
+                     )
+
+                     custom_teacher_stopping_time = gr.Slider(
+                         minimum=0.0,
+                         maximum=1.0,
+                         value=0.07,
+                         step=0.01,
+                         label="Teacher Stopping Time",
+                         info="When to switch to the student"
+                     )
+
+                     custom_student_start_step = gr.Slider(
+                         minimum=0,
+                         maximum=4,
+                         value=1,
+                         step=1,
+                         label="Student Start Step",
+                         info="Which student step to start from"
+                     )
+
+                 temperature = gr.Slider(
+                     minimum=0.0,
+                     maximum=2.0,
+                     value=0.0,
+                     step=0.1,
+                     label="Duration Temperature",
+                     info="0 = deterministic, >0 = more variation in speech rhythm"
+                 )
+
+                 verbose = gr.Checkbox(
+                     value=False,
+                     label="Verbose Output",
+                     info="Show detailed generation steps"
+                 )
+
+             generate_btn = gr.Button("🎵 Generate Speech", variant="primary", size="lg")
+
+         with gr.Column(scale=1):
+             # Output
+             output_audio = gr.Audio(
+                 label="🔊 Generated Speech",
+                 type="filepath",
+                 autoplay=True
+             )
+
+             status = gr.Textbox(
+                 label="Status",
+                 interactive=False
+             )
+
+             metrics = gr.Textbox(
+                 label="Performance Metrics",
+                 interactive=False
+             )
+
+             info = gr.Textbox(
+                 label="Generation Info",
+                 interactive=False
+             )
+
+             # Tips
+             gr.Markdown("""
+             ### 💡 Quick Tips:
+
+             - **Student Only**: fastest (4 steps), good quality
+             - **Teacher-Guided**: best balance (8 steps), recommended
+             - **High Diversity**: more natural prosody (16 steps)
+             - **Temperature**: adds randomness to speech rhythm
+
+             ### 📊 Expected RTF (Real-Time Factor):
+             - Student Only: ~0.05x (20x faster than real time)
+             - Teacher-Guided: ~0.10x (10x faster)
+             - High Diversity: ~0.20x (5x faster)
+             """)
+
+     # Examples section
+     gr.Markdown("### 🎯 Examples")
+
+     examples = [
+         [
+             None,  # Will be replaced with an actual audio path
+             "Some call me nature, others call me mother nature.",
+             "I don't really care what you call me. I've been a silent spectator, watching species evolve, empires rise and fall. But always remember, I am mighty and enduring.",
+             "Teacher-Guided (8 steps)",
+             16, 0.07, 1, 0.0, False
+         ],
+         [
+             None,  # Will be replaced with an actual audio path
+             "对,这就是我,万人敬仰的太乙真人。",
+             '突然,身边一阵笑声。我看着他们,意气风发地挺直了胸膛,甩了甩那稍显肉感的双臂,轻笑道:"我身上的肉,是为了掩饰我爆棚的魅力,否则,岂不吓坏了你们呢?"',
+             "Teacher-Guided (8 steps)",
+             16, 0.07, 1, 0.0, False
+         ],
+         [
+             None,
+             "对,这就是我,万人敬仰的太乙真人。",
+             '突然,身边一阵笑声。我看着他们,意气风发地挺直了胸膛,甩了甩那稍显肉感的双臂,轻笑道:"我身上的肉,是为了掩饰我爆棚的魅力,否则,岂不吓坏了你们呢?"',
+             "High Diversity (16 steps)",
+             24, 0.3, 2, 0.8, False
+         ]
+     ]
+
+     # Note about example audio files
+     gr.Markdown("""
+     *Note: Example audio files should be uploaded to the Space. The examples above show the text configurations used in the original notebook.*
+     """)
+
+     # Event handler
+     generate_btn.click(
+         generate_speech,
+         inputs=[
+             prompt_audio,
+             prompt_text,
+             target_text,
+             mode,
+             custom_teacher_steps,
+             custom_teacher_stopping_time,
+             custom_student_start_step,
+             temperature,
+             verbose
+         ],
+         outputs=[output_audio, status, metrics, info]
+     )
+
+     # Enable the custom sliders only when "Custom" mode is selected
+     def update_custom_settings(mode):
+         return [gr.update(interactive=(mode == "Custom"))] * 3
+
+     mode.change(
+         update_custom_settings,
+         inputs=[mode],
+         outputs=[custom_teacher_steps, custom_teacher_stopping_time, custom_student_start_step]
+     )
+
+ # Launch the app
+ if __name__ == "__main__":
+     if not model_loaded:
+         print(f"Warning: Model failed to load - {status_message}")
+
+     demo.launch()
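
For reference, the API this app wraps can also be driven without the Gradio UI. The following is a minimal sketch, not part of the commit, assuming the infer.py module, checkpoint filenames, and generate() keyword arguments shown in the diff above; "reference.wav", "output.wav", and the sample text are hypothetical, and the preset values mirror the app's "Teacher-Guided (8 steps)" branch of generate_speech().

import torch
import torchaudio
from huggingface_hub import hf_hub_download
from infer import DMOInference  # same module the app imports

# Fetch the same checkpoints the app downloads
student_path = hf_hub_download("yl4579/DMOSpeech2", "model_85000.pt", cache_dir="./models")
duration_path = hf_hub_download("yl4579/DMOSpeech2", "model_1500.pt", cache_dir="./models")

model = DMOInference(
    student_checkpoint_path=student_path,
    duration_predictor_path=duration_path,
    device="cuda" if torch.cuda.is_available() else "cpu",
    model_type="F5TTS_Base",
)

# "Teacher-Guided (8 steps)" preset from the app's generate_speech()
audio = model.generate(
    gen_text="Hello from DMOSpeech 2!",  # hypothetical text to synthesize
    audio_path="reference.wav",          # hypothetical reference clip
    prompt_text=None,                    # None = auto-transcribe, per the app's UI label
    teacher_steps=16,
    teacher_stopping_time=0.07,
    student_start_step=1,
    temperature=0.0,
)

# The app treats the output as 24 kHz audio; save it the same way
if not torch.is_tensor(audio):
    audio = torch.from_numpy(audio)
torchaudio.save("output.wav", audio.reshape(1, -1).cpu(), 24000)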