yl4579 committed
Commit 51a612e · verified · 1 Parent(s): 4972f24

Update app.py

Files changed (1): app.py (+176, -213)
app.py CHANGED
@@ -14,37 +14,61 @@ from transformers import pipeline
 from infer import DMOInference
 
 # Global variables
-model = None
+model_paths = {"student": None, "duration": None}
 asr_pipe = None
-device = "cuda" if torch.cuda.is_available() else "cpu"
+model_downloaded = False
 
-# Initialize ASR pipeline
-def initialize_asr_pipeline(device=device, dtype=None):
-    """Initialize the ASR pipeline on startup."""
-    global asr_pipe
+# Download models on startup (CPU)
+def download_models():
+    """Download models from HuggingFace Hub."""
+    global model_downloaded, model_paths
 
-    if dtype is None:
-        dtype = (
-            torch.float16
-            if "cuda" in device
-            and torch.cuda.is_available()
-            and torch.cuda.get_device_properties(device).major >= 7
-            and not torch.cuda.get_device_name().endswith("[ZLUDA]")
-            else torch.float32
+    try:
+        print("Downloading models from HuggingFace...")
+
+        # Download student model
+        student_path = hf_hub_download(
+            repo_id="yl4579/DMOSpeech2",
+            filename="model_85000.pt",
+            cache_dir="./models"
         )
+
+        # Download duration predictor
+        duration_path = hf_hub_download(
+            repo_id="yl4579/DMOSpeech2",
+            filename="model_1500.pt",
+            cache_dir="./models"
+        )
+
+        model_paths["student"] = student_path
+        model_paths["duration"] = duration_path
+        model_downloaded = True
+
+        print(f"✓ Models downloaded successfully")
+        return True
+
+    except Exception as e:
+        print(f"Error downloading models: {e}")
+        return False
+
+# Initialize ASR pipeline on CPU
+def initialize_asr_pipeline():
+    """Initialize the ASR pipeline on startup."""
+    global asr_pipe
 
     print("Initializing ASR pipeline...")
     try:
         asr_pipe = pipeline(
             "automatic-speech-recognition",
             model="openai/whisper-large-v3-turbo",
-            torch_dtype=dtype,
-            device="cpu"  # Keep ASR on CPU to save GPU memory
+            torch_dtype=torch.float32,
+            device="cpu"  # Always use CPU for ASR to save GPU memory
         )
-        print("ASR pipeline initialized successfully")
+        print("ASR pipeline initialized successfully")
+        return True
     except Exception as e:
         print(f"Error initializing ASR pipeline: {e}")
-        asr_pipe = None
+        return False
 
 # Transcribe function
 def transcribe(ref_audio, language=None):
@@ -52,7 +76,7 @@ def transcribe(ref_audio, language=None):
     global asr_pipe
 
     if asr_pipe is None:
-        return ""  # Return empty string if ASR is not available
+        return ""
 
     try:
         result = asr_pipe(
@@ -67,65 +91,14 @@
         print(f"Transcription error: {e}")
         return ""
 
-def download_models():
-    """Download models from HuggingFace Hub."""
-    try:
-        print("Downloading models from HuggingFace...")
-
-        # Download student model
-        student_path = hf_hub_download(
-            repo_id="yl4579/DMOSpeech2",
-            filename="model_85000.pt",
-            cache_dir="./models"
-        )
-
-        # Download duration predictor
-        duration_path = hf_hub_download(
-            repo_id="yl4579/DMOSpeech2",
-            filename="model_1500.pt",
-            cache_dir="./models"
-        )
-
-        print(f"Student model: {student_path}")
-        print(f"Duration model: {duration_path}")
-
-        return student_path, duration_path
-
-    except Exception as e:
-        print(f"Error downloading models: {e}")
-        return None, None
-
-def initialize_model():
-    """Initialize the model on startup."""
-    global model
-
-    try:
-        # Download models
-        student_path, duration_path = download_models()
-
-        if not student_path or not duration_path:
-            return False, "Failed to download models from HuggingFace"
-
-        # Initialize model
-        model = DMOInference(
-            student_checkpoint_path=student_path,
-            duration_predictor_path=duration_path,
-            device=device,
-            model_type="F5TTS_Base"
-        )
-
-        return True, f"Model loaded successfully on {device.upper()}"
-
-    except Exception as e:
-        return False, f"Error initializing model: {str(e)}"
+# Initialize on startup
+print("Starting DMOSpeech 2...")
+models_ready = download_models()
+asr_ready = initialize_asr_pipeline()
+status_message = f"Models: {'✅' if models_ready else '❌'} | ASR: {'✅' if asr_ready else '❌'}"
 
-# Initialize models on startup
-print("Initializing models...")
-model_loaded, status_message = initialize_model()
-initialize_asr_pipeline()  # Initialize ASR pipeline
-
-@spaces.GPU(duration=120)  # Request GPU for up to 120 seconds
-def generate_speech(
+@spaces.GPU(duration=120)
+def generate_speech_gpu(
     prompt_audio,
     prompt_text,
     target_text,
@@ -136,53 +109,72 @@ def generate_speech(
    custom_student_start_step,
    verbose
 ):
-    """Generate speech with different configurations."""
+    """Generate speech with GPU acceleration."""
 
-    if not model_loaded or model is None:
-        return None, "Model not loaded! Please refresh the page.", "", ""
+    if not model_downloaded:
+        return None, "❌ Models not downloaded! Please refresh the page.", "", "", prompt_text
 
     if prompt_audio is None:
-        return None, "Please upload a reference audio!", "", ""
+        return None, "Please upload a reference audio!", "", "", prompt_text
 
     if not target_text:
-        return None, "Please enter text to generate!", "", ""
+        return None, "Please enter text to generate!", "", "", prompt_text
 
     try:
-        # Auto-transcribe if prompt_text is empty
-        if not prompt_text and prompt_text != "":
+        # Initialize model on GPU
+        device = "cuda" if torch.cuda.is_available() else "cpu"
+        print(f"Initializing model on {device}...")
+
+        model = DMOInference(
+            student_checkpoint_path=model_paths["student"],
+            duration_predictor_path=model_paths["duration"],
+            device=device,
+            model_type="F5TTS_Base"
+        )
+
+        # Auto-transcribe if needed (this happens on CPU)
+        transcribed_text = prompt_text  # Default to provided text
+        if not prompt_text.strip():
            print("Auto-transcribing reference audio...")
-            prompt_text = transcribe(prompt_audio)
-            print(f"Transcribed: {prompt_text}")
+            transcribed_text = transcribe(prompt_audio)
+            print(f"Transcribed: {transcribed_text}")
 
         start_time = time.time()
 
         # Configure parameters based on mode
-        if mode == "Student Only (4 steps)":
-            teacher_steps = 0
-            student_start_step = 0
-            teacher_stopping_time = 1.0
-        elif mode == "Teacher-Guided (8 steps)":
-            # Default configuration from the notebook
-            teacher_steps = 16
-            teacher_stopping_time = 0.07
-            student_start_step = 1
-        elif mode == "High Diversity (16 steps)":
-            teacher_steps = 24
-            teacher_stopping_time = 0.3
-            student_start_step = 2
-        else:  # Custom
-            teacher_steps = custom_teacher_steps
-            teacher_stopping_time = custom_teacher_stopping_time
-            student_start_step = custom_student_start_step
+        configs = {
+            "Student Only (4 steps)": {
+                "teacher_steps": 0,
+                "student_start_step": 0,
+                "teacher_stopping_time": 1.0
+            },
+            "Teacher-Guided (8 steps)": {
+                "teacher_steps": 16,
+                "teacher_stopping_time": 0.07,
+                "student_start_step": 1
+            },
+            "High Diversity (16 steps)": {
+                "teacher_steps": 24,
+                "teacher_stopping_time": 0.3,
+                "student_start_step": 2
+            },
+            "Custom": {
+                "teacher_steps": custom_teacher_steps,
+                "teacher_stopping_time": custom_teacher_stopping_time,
+                "student_start_step": custom_student_start_step
+            }
+        }
+
+        config = configs[mode]
 
         # Generate speech
         generated_audio = model.generate(
             gen_text=target_text,
             audio_path=prompt_audio,
-            prompt_text=prompt_text if prompt_text else None,
-            teacher_steps=teacher_steps,
-            teacher_stopping_time=teacher_stopping_time,
-            student_start_step=student_start_step,
+            prompt_text=transcribed_text if transcribed_text else None,
+            teacher_steps=config["teacher_steps"],
+            teacher_stopping_time=config["teacher_stopping_time"],
+            student_start_step=config["student_start_step"],
             temperature=temperature,
             verbose=verbose
         )
@@ -206,29 +198,50 @@ def generate_speech(
 
         torchaudio.save(output_path, generated_audio, 24000)
 
-        # Format metrics
-        metrics = f"RTF: {rtf:.2f}x ({1/rtf:.2f}x speed) | Processing: {processing_time:.2f}s for {audio_duration:.2f}s audio"
+        # Format output
+        metrics = f"""RTF: {rtf:.2f}x ({1/rtf:.2f}x faster)
+        Processing: {processing_time:.2f}s for {audio_duration:.2f}s audio
+        Device: {device.upper()}"""
 
-        return output_path, "Success!", metrics, f"Mode: {mode} | Transcribed: {prompt_text[:50]}..." if not prompt_text else f"Mode: {mode}"
+        info = f"Mode: {mode}"
+        if not prompt_text.strip():
+            info += f" | Auto-transcribed"
+
+        # Clean up GPU memory
+        del model
+        if device == "cuda":
+            torch.cuda.empty_cache()
+
+        # Return transcribed text to update the textbox
+        return output_path, "✅ Success!", metrics, info, transcribed_text
 
     except Exception as e:
-        return None, f"Error: {str(e)}", "", ""
+        import traceback
+        print(traceback.format_exc())
+        return None, f"❌ Error: {str(e)}", "", "", prompt_text
 
 # Create Gradio interface
-with gr.Blocks(title="DMOSpeech 2 - Zero-Shot TTS", theme=gr.themes.Soft()) as demo:
-    gr.Markdown(f"""
-    # 🎙️ DMOSpeech 2: Zero-Shot Text-to-Speech
+with gr.Blocks(
+    title="DMOSpeech 2 - Zero-Shot TTS",
+    theme=gr.themes.Soft(),
+    css="""
+    .gradio-container { max-width: 1200px !important; }
+    """
+) as demo:
 
-    Generate natural speech in any voice with just a short reference audio!
-
-    **Model Status:** {status_message} | **Device:** {device.upper()} | **ASR:** {"✅ Ready" if asr_pipe else "❌ Not available"}
+    gr.Markdown(f"""
+    <div style="text-align: center;">
+        <h1>🎙️ DMOSpeech 2: Zero-Shot Text-to-Speech</h1>
+        <p>Generate natural speech in any voice with just a 3-10 second reference!</p>
+        <p><b>System Status:</b> {status_message}</p>
+    </div>
    """)
 
     with gr.Row():
         with gr.Column(scale=1):
-            # Reference audio input
+            # Inputs
             prompt_audio = gr.Audio(
-                label="📎 Reference Audio",
+                label="📎 Reference Audio (3-10 seconds)",
                 type="filepath",
                 sources=["upload", "microphone"]
             )
@@ -245,7 +258,6 @@ with gr.Blocks(title="DMOSpeech 2 - Zero-Shot TTS", theme=gr.themes.Soft()) as d
                lines=4
            )
 
-            # Generation mode
            mode = gr.Radio(
                choices=[
                    "Student Only (4 steps)",
@@ -255,10 +267,10 @@ with gr.Blocks(title="DMOSpeech 2 - Zero-Shot TTS", theme=gr.themes.Soft()) as d
                ],
                value="Teacher-Guided (8 steps)",
                label="🚀 Generation Mode",
-                info="Choose speed vs quality/diversity tradeoff"
+                info="Speed vs quality tradeoff"
            )
 
-            # Advanced settings (collapsible)
+            # Advanced settings
            with gr.Accordion("⚙️ Advanced Settings", open=False):
                temperature = gr.Slider(
                    minimum=0.0,
@@ -266,115 +278,76 @@ with gr.Blocks(title="DMOSpeech 2 - Zero-Shot TTS", theme=gr.themes.Soft()) as d
                    value=0.0,
                    step=0.1,
                    label="Duration Temperature",
-                    info="0 = deterministic, >0 = more variation in speech rhythm"
+                    info="0 = consistent, >0 = varied rhythm"
                )
 
-                with gr.Group(visible=False) as custom_settings:
-                    gr.Markdown("### Custom Mode Settings")
-                    custom_teacher_steps = gr.Slider(
-                        minimum=0,
-                        maximum=32,
-                        value=16,
-                        step=1,
-                        label="Teacher Steps",
-                        info="More steps = higher quality"
-                    )
-
-                    custom_teacher_stopping_time = gr.Slider(
-                        minimum=0.0,
-                        maximum=1.0,
-                        value=0.07,
-                        step=0.01,
-                        label="Teacher Stopping Time",
-                        info="When to switch to student"
-                    )
-
-                    custom_student_start_step = gr.Slider(
-                        minimum=0,
-                        maximum=4,
-                        value=1,
-                        step=1,
-                        label="Student Start Step",
-                        info="Which student step to start from"
-                    )
+                with gr.Group(visible=False) as custom_group:
+                    custom_teacher_steps = gr.Slider(0, 32, 16, 1, label="Teacher Steps")
+                    custom_teacher_stopping_time = gr.Slider(0.0, 1.0, 0.07, 0.01, label="Stopping Time")
+                    custom_student_start_step = gr.Slider(0, 4, 1, 1, label="Student Start Step")
 
-                verbose = gr.Checkbox(
-                    value=False,
-                    label="Verbose Output",
-                    info="Show detailed generation steps"
-                )
+                verbose = gr.Checkbox(False, label="Verbose Output")
 
            generate_btn = gr.Button("🎵 Generate Speech", variant="primary", size="lg")
 
        with gr.Column(scale=1):
-            # Output
            output_audio = gr.Audio(
                label="🔊 Generated Speech",
                type="filepath",
                autoplay=True
            )
 
-            status = gr.Textbox(
-                label="Status",
-                interactive=False
-            )
-
-            metrics = gr.Textbox(
-                label="Performance Metrics",
-                interactive=False
-            )
-
-            info = gr.Textbox(
-                label="Generation Info",
-                interactive=False
-            )
+            # Outputs
+            status = gr.Textbox(label="Status", interactive=False)
+            metrics = gr.Textbox(label="Performance", interactive=False, lines=3)
+            info = gr.Textbox(label="Info", interactive=False)
 
-    # Tips
+    # Guide
    gr.Markdown("""
-    ### 💡 Quick Tips:
+    ### 💡 Quick Guide
 
-    - **Auto-transcription**: Leave reference text empty to auto-transcribe
-    - **Student Only**: Fastest (4 steps), good quality
-    - **Teacher-Guided**: Best balance (8 steps), recommended
-    - **High Diversity**: More natural prosody (16 steps)
-    - **Custom Mode**: Fine-tune all parameters
+    | Mode | Speed | Quality | Use Case |
+    |------|-------|---------|----------|
+    | Student Only | 20x realtime | Good | Real-time apps |
+    | Teacher-Guided | 10x realtime | Better | General use |
+    | High Diversity | 5x realtime | Best | Production |
 
-    ### 📊 Expected RTF (Real-Time Factor):
-    - Student Only: ~0.05x (20x faster than real-time)
-    - Teacher-Guided: ~0.10x (10x faster)
-    - High Diversity: ~0.20x (5x faster)
+    **Tips:**
+    - Leave reference text empty for auto-transcription
+    - Auto-transcription only happens once - the text will be filled in
+    - Use temperature > 0 for more natural rhythm variation
+    - Custom mode lets you fine-tune all parameters
    """)
 
-    # Examples section
-    gr.Markdown("### 🎯 Example Configurations")
+    # Examples
+    gr.Markdown("### 🎯 Example Texts")
 
    gr.Markdown("""
    <details>
    <summary>English Example</summary>
 
-    **Reference text:** "Some call me nature, others call me mother nature."
+    **Reference:** "Some call me nature, others call me mother nature."
 
-    **Target text:** "I don't really care what you call me. I've been a silent spectator, watching species evolve, empires rise and fall. But always remember, I am mighty and enduring."
+    **Target:** "I don't really care what you call me. I've been a silent spectator, watching species evolve, empires rise and fall. But always remember, I am mighty and enduring."
    </details>
 
    <details>
    <summary>Chinese Example</summary>
 
-    **Reference text:** "对,这就是我,万人敬仰的太乙真人。"
+    **Reference:** "对,这就是我,万人敬仰的太乙真人。"
 
-    **Target text:** "突然,身边一阵笑声。我看着他们,意气风发地挺直了胸膛,甩了甩那稍显肉感的双臂,轻笑道:'我身上的肉,是为了掩饰我爆棚的魅力,否则,岂不吓坏了你们呢?'"
+    **Target:** "突然,身边一阵笑声。我看着他们,意气风发地挺直了胸膛,甩了甩那稍显肉感的双臂,轻笑道:'我身上的肉,是为了掩饰我爆棚的魅力,否则,岂不吓坏了你们呢?'"
    </details>
+    """)
 
-    <details>
-    <summary>High Diversity Chinese Example</summary>
+    # Event handlers
+    def toggle_custom(mode):
+        return gr.update(visible=(mode == "Custom"))
 
-    Same as above but with **Temperature: 0.8** for more natural variation in speech rhythm.
-    </details>
-    """)
+    mode.change(toggle_custom, [mode], [custom_group])
 
-    # Event handler
    generate_btn.click(
-        generate_speech,
+        generate_speech_gpu,
        inputs=[
            prompt_audio,
            prompt_text,
@@ -386,25 +359,15 @@ with gr.Blocks(title="DMOSpeech 2 - Zero-Shot TTS", theme=gr.themes.Soft()) as d
            custom_student_start_step,
            verbose
        ],
-        outputs=[output_audio, status, metrics, info]
-    )
-
-    # Update visibility of custom settings based on mode
-    def update_custom_visibility(mode):
-        is_custom = (mode == "Custom")
-        return gr.update(visible=is_custom)
-
-    mode.change(
-        update_custom_visibility,
-        inputs=[mode],
-        outputs=[custom_settings]
+        outputs=[
+            output_audio,
+            status,
+            metrics,
+            info,
+            prompt_text  # Update the prompt_text textbox with transcribed text
+        ]
    )
 
-# Launch the app
+# Launch
 if __name__ == "__main__":
-    if not model_loaded:
-        print(f"Warning: Model failed to load - {status_message}")
-    if not asr_pipe:
-        print("Warning: ASR pipeline not available - auto-transcription disabled")
-
    demo.launch()
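
The diff above amounts to a switch to the usual ZeroGPU Spaces flow: checkpoints are downloaded once at import time on CPU, the DMOInference model is constructed inside the @spaces.GPU window, and GPU memory is freed before the lease ends. A minimal sketch of that flow, assuming hard-coded checkpoint paths and a hypothetical synthesize() wrapper in place of the full Gradio app:

    import torch
    import spaces  # Hugging Face Spaces SDK; supplies the GPU decorator used above
    from infer import DMOInference  # same inference class app.py imports

    # Import time: CPU only, no GPU attached yet. app.py fills these
    # paths via hf_hub_download(); the literals here are illustrative.
    model_paths = {"student": "./models/model_85000.pt",
                   "duration": "./models/model_1500.pt"}

    @spaces.GPU(duration=120)  # a GPU is leased only while this function runs
    def synthesize(text, ref_audio):
        device = "cuda" if torch.cuda.is_available() else "cpu"
        # Build the model inside the GPU window, from pre-downloaded weights.
        model = DMOInference(
            student_checkpoint_path=model_paths["student"],
            duration_predictor_path=model_paths["duration"],
            device=device,
            model_type="F5TTS_Base"
        )
        audio = model.generate(gen_text=text, audio_path=ref_audio)
        # Release GPU memory before the lease expires.
        del model
        if device == "cuda":
            torch.cuda.empty_cache()
        return audio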