yl4579 committed
Commit 7b1f9ef · verified · 1 Parent(s): 51a612e

Update app.py

Files changed (1)
  1. app.py +213 -176
app.py CHANGED
@@ -14,61 +14,37 @@ from transformers import pipeline
  from infer import DMOInference
 
  # Global variables
- model_paths = {"student": None, "duration": None}
  asr_pipe = None
- model_downloaded = False
 
- # Download models on startup (CPU)
- def download_models():
-     """Download models from HuggingFace Hub."""
-     global model_downloaded, model_paths
-
-     try:
-         print("Downloading models from HuggingFace...")
-
-         # Download student model
-         student_path = hf_hub_download(
-             repo_id="yl4579/DMOSpeech2",
-             filename="model_85000.pt",
-             cache_dir="./models"
-         )
-
-         # Download duration predictor
-         duration_path = hf_hub_download(
-             repo_id="yl4579/DMOSpeech2",
-             filename="model_1500.pt",
-             cache_dir="./models"
-         )
-
-         model_paths["student"] = student_path
-         model_paths["duration"] = duration_path
-         model_downloaded = True
-
-         print(f"✓ Models downloaded successfully")
-         return True
-
-     except Exception as e:
-         print(f"Error downloading models: {e}")
-         return False
-
- # Initialize ASR pipeline on CPU
- def initialize_asr_pipeline():
      """Initialize the ASR pipeline on startup."""
      global asr_pipe
 
      print("Initializing ASR pipeline...")
      try:
          asr_pipe = pipeline(
              "automatic-speech-recognition",
              model="openai/whisper-large-v3-turbo",
-             torch_dtype=torch.float32,
-             device="cpu"  # Always use CPU for ASR to save GPU memory
          )
-         print("ASR pipeline initialized successfully")
-         return True
      except Exception as e:
          print(f"Error initializing ASR pipeline: {e}")
-         return False
 
  # Transcribe function
  def transcribe(ref_audio, language=None):
@@ -76,7 +52,7 @@ def transcribe(ref_audio, language=None):
      global asr_pipe
 
      if asr_pipe is None:
-         return ""
 
      try:
          result = asr_pipe(
@@ -91,14 +67,65 @@ def transcribe(ref_audio, language=None):
          print(f"Transcription error: {e}")
          return ""
 
- # Initialize on startup
- print("Starting DMOSpeech 2...")
- models_ready = download_models()
- asr_ready = initialize_asr_pipeline()
- status_message = f"Models: {'✅' if models_ready else '❌'} | ASR: {'✅' if asr_ready else '❌'}"
 
- @spaces.GPU(duration=120)
- def generate_speech_gpu(
      prompt_audio,
      prompt_text,
      target_text,
@@ -109,72 +136,53 @@ def generate_speech_gpu(
      custom_student_start_step,
      verbose
  ):
-     """Generate speech with GPU acceleration."""
 
-     if not model_downloaded:
-         return None, " Models not downloaded! Please refresh the page.", "", "", prompt_text
 
      if prompt_audio is None:
-         return None, "Please upload a reference audio!", "", "", prompt_text
 
      if not target_text:
-         return None, "Please enter text to generate!", "", "", prompt_text
 
      try:
-         # Initialize model on GPU
-         device = "cuda" if torch.cuda.is_available() else "cpu"
-         print(f"Initializing model on {device}...")
-
-         model = DMOInference(
-             student_checkpoint_path=model_paths["student"],
-             duration_predictor_path=model_paths["duration"],
-             device=device,
-             model_type="F5TTS_Base"
-         )
-
-         # Auto-transcribe if needed (this happens on CPU)
-         transcribed_text = prompt_text  # Default to provided text
-         if not prompt_text.strip():
              print("Auto-transcribing reference audio...")
-             transcribed_text = transcribe(prompt_audio)
-             print(f"Transcribed: {transcribed_text}")
 
          start_time = time.time()
 
          # Configure parameters based on mode
-         configs = {
-             "Student Only (4 steps)": {
-                 "teacher_steps": 0,
-                 "student_start_step": 0,
-                 "teacher_stopping_time": 1.0
-             },
-             "Teacher-Guided (8 steps)": {
-                 "teacher_steps": 16,
-                 "teacher_stopping_time": 0.07,
-                 "student_start_step": 1
-             },
-             "High Diversity (16 steps)": {
-                 "teacher_steps": 24,
-                 "teacher_stopping_time": 0.3,
-                 "student_start_step": 2
-             },
-             "Custom": {
-                 "teacher_steps": custom_teacher_steps,
-                 "teacher_stopping_time": custom_teacher_stopping_time,
-                 "student_start_step": custom_student_start_step
-             }
-         }
-
-         config = configs[mode]
 
          # Generate speech
          generated_audio = model.generate(
              gen_text=target_text,
              audio_path=prompt_audio,
-             prompt_text=transcribed_text if transcribed_text else None,
-             teacher_steps=config["teacher_steps"],
-             teacher_stopping_time=config["teacher_stopping_time"],
-             student_start_step=config["student_start_step"],
              temperature=temperature,
              verbose=verbose
          )
@@ -198,50 +206,29 @@ def generate_speech_gpu(
 
          torchaudio.save(output_path, generated_audio, 24000)
 
-         # Format output
-         metrics = f"""RTF: {rtf:.2f}x ({1/rtf:.2f}x faster)
- Processing: {processing_time:.2f}s for {audio_duration:.2f}s audio
- Device: {device.upper()}"""
 
-         info = f"Mode: {mode}"
-         if not prompt_text.strip():
-             info += f" | Auto-transcribed"
-
-         # Clean up GPU memory
-         del model
-         if device == "cuda":
-             torch.cuda.empty_cache()
-
-         # Return transcribed text to update the textbox
-         return output_path, "✅ Success!", metrics, info, transcribed_text
 
      except Exception as e:
-         import traceback
-         print(traceback.format_exc())
-         return None, f"❌ Error: {str(e)}", "", "", prompt_text
 
  # Create Gradio interface
- with gr.Blocks(
-     title="DMOSpeech 2 - Zero-Shot TTS",
-     theme=gr.themes.Soft(),
-     css="""
-         .gradio-container { max-width: 1200px !important; }
-     """
- ) as demo:
-
      gr.Markdown(f"""
-     <div style="text-align: center;">
-         <h1>🎙️ DMOSpeech 2: Zero-Shot Text-to-Speech</h1>
-         <p>Generate natural speech in any voice with just a 3-10 second reference!</p>
-         <p><b>System Status:</b> {status_message}</p>
-     </div>
      """)
 
      with gr.Row():
          with gr.Column(scale=1):
-             # Inputs
              prompt_audio = gr.Audio(
-                 label="📎 Reference Audio (3-10 seconds)",
                  type="filepath",
                  sources=["upload", "microphone"]
              )
@@ -258,6 +245,7 @@ with gr.Blocks(
                  lines=4
              )
 
              mode = gr.Radio(
                  choices=[
                      "Student Only (4 steps)",
@@ -267,10 +255,10 @@ with gr.Blocks(
                  ],
                  value="Teacher-Guided (8 steps)",
                  label="🚀 Generation Mode",
-                 info="Speed vs quality tradeoff"
              )
 
-             # Advanced settings
              with gr.Accordion("⚙️ Advanced Settings", open=False):
                  temperature = gr.Slider(
                      minimum=0.0,
@@ -278,76 +266,115 @@ with gr.Blocks(
                      value=0.0,
                      step=0.1,
                      label="Duration Temperature",
-                     info="0 = consistent, >0 = varied rhythm"
                  )
 
-                 with gr.Group(visible=False) as custom_group:
-                     custom_teacher_steps = gr.Slider(0, 32, 16, 1, label="Teacher Steps")
-                     custom_teacher_stopping_time = gr.Slider(0.0, 1.0, 0.07, 0.01, label="Stopping Time")
-                     custom_student_start_step = gr.Slider(0, 4, 1, 1, label="Student Start Step")
 
-                 verbose = gr.Checkbox(False, label="Verbose Output")
 
              generate_btn = gr.Button("🎵 Generate Speech", variant="primary", size="lg")
 
          with gr.Column(scale=1):
-             # Outputs
              output_audio = gr.Audio(
                  label="🔊 Generated Speech",
                  type="filepath",
                  autoplay=True
             )
 
-             status = gr.Textbox(label="Status", interactive=False)
-             metrics = gr.Textbox(label="Performance", interactive=False, lines=3)
-             info = gr.Textbox(label="Info", interactive=False)
 
-     # Guide
      gr.Markdown("""
-     ### 💡 Quick Guide
-
-     | Mode | Speed | Quality | Use Case |
-     |------|-------|---------|----------|
-     | Student Only | 20x realtime | Good | Real-time apps |
-     | Teacher-Guided | 10x realtime | Better | General use |
-     | High Diversity | 5x realtime | Best | Production |
-
-     **Tips:**
-     - Leave reference text empty for auto-transcription
-     - Auto-transcription only happens once - the text will be filled in
-     - Use temperature > 0 for more natural rhythm variation
-     - Custom mode lets you fine-tune all parameters
      """)
 
-     # Examples
-     gr.Markdown("### 🎯 Example Texts")
 
      gr.Markdown("""
      <details>
      <summary>English Example</summary>
 
-     **Reference:** "Some call me nature, others call me mother nature."
 
-     **Target:** "I don't really care what you call me. I've been a silent spectator, watching species evolve, empires rise and fall. But always remember, I am mighty and enduring."
      </details>
 
      <details>
      <summary>Chinese Example</summary>
 
-     **Reference:** "对,这就是我,万人敬仰的太乙真人。"
 
-     **Target:** "突然,身边一阵笑声。我看着他们,意气风发地挺直了胸膛,甩了甩那稍显肉感的双臂,轻笑道:'我身上的肉,是为了掩饰我爆棚的魅力,否则,岂不吓坏了你们呢?'"
      </details>
-     """)
 
-     # Event handlers
-     def toggle_custom(mode):
-         return gr.update(visible=(mode == "Custom"))
 
-     mode.change(toggle_custom, [mode], [custom_group])
 
      generate_btn.click(
-         generate_speech_gpu,
          inputs=[
              prompt_audio,
              prompt_text,
@@ -359,15 +386,25 @@ with gr.Blocks(
              custom_student_start_step,
              verbose
          ],
-         outputs=[
-             output_audio,
-             status,
-             metrics,
-             info,
-             prompt_text  # Update the prompt_text textbox with transcribed text
-         ]
      )
 
- # Launch
  if __name__ == "__main__":
      demo.launch()
 
  from infer import DMOInference
 
  # Global variables
+ model = None
  asr_pipe = None
+ device = "cuda" if torch.cuda.is_available() else "cpu"
 
+ # Initialize ASR pipeline
+ def initialize_asr_pipeline(device="cpu", dtype=None):  # ASR defaults to CPU so the GPU stays free for TTS
      """Initialize the ASR pipeline on startup."""
      global asr_pipe
 
+     if dtype is None:
+         dtype = (
+             torch.float16
+             if "cuda" in device
+             and torch.cuda.is_available()
+             and torch.cuda.get_device_properties(device).major >= 7
+             and not torch.cuda.get_device_name().endswith("[ZLUDA]")
+             else torch.float32
+         )
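+     # fp16 is selected above only for CUDA devices with compute capability
+     # >= 7.0 that are not running under the ZLUDA translation layer; the CPU
+     # default used by this app resolves to float32.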
+ 
      print("Initializing ASR pipeline...")
      try:
          asr_pipe = pipeline(
              "automatic-speech-recognition",
              model="openai/whisper-large-v3-turbo",
+             torch_dtype=dtype,
+             device=device  # CPU by default, to keep GPU memory free for TTS
          )
+         print("ASR pipeline initialized successfully")
      except Exception as e:
          print(f"Error initializing ASR pipeline: {e}")
+         asr_pipe = None
 
  # Transcribe function
  def transcribe(ref_audio, language=None):
      global asr_pipe
 
      if asr_pipe is None:
+         return ""  # Return empty string if ASR is not available
 
      try:
          result = asr_pipe(
 
          print(f"Transcription error: {e}")
          return ""
 
+ def download_models():
+     """Download models from HuggingFace Hub."""
+     try:
+         print("Downloading models from HuggingFace...")
+ 
+         # Download student model
+         student_path = hf_hub_download(
+             repo_id="yl4579/DMOSpeech2",
+             filename="model_85000.pt",
+             cache_dir="./models"
+         )
+ 
+         # Download duration predictor
+         duration_path = hf_hub_download(
+             repo_id="yl4579/DMOSpeech2",
+             filename="model_1500.pt",
+             cache_dir="./models"
+         )
+ 
+         print(f"Student model: {student_path}")
+         print(f"Duration model: {duration_path}")
+ 
+         return student_path, duration_path
+ 
+     except Exception as e:
+         print(f"Error downloading models: {e}")
+         return None, None
+ 
+ def initialize_model():
+     """Initialize the model on startup."""
+     global model
+ 
+     try:
+         # Download models
+         student_path, duration_path = download_models()
+ 
+         if not student_path or not duration_path:
+             return False, "Failed to download models from HuggingFace"
+ 
+         # Initialize model
+         model = DMOInference(
+             student_checkpoint_path=student_path,
+             duration_predictor_path=duration_path,
+             device=device,
+             model_type="F5TTS_Base"
+         )
+ )
116
+
117
+ return True, f"Model loaded successfully on {device.upper()}"
118
+
119
+ except Exception as e:
120
+ return False, f"Error initializing model: {str(e)}"
121
 
122
+ # Initialize models on startup
123
+ print("Initializing models...")
124
+ model_loaded, status_message = initialize_model()
125
+ initialize_asr_pipeline() # Initialize ASR pipeline
126
+
127
+ @spaces.GPU(duration=120) # Request GPU for up to 120 seconds
128
+ def generate_speech(
129
  prompt_audio,
130
  prompt_text,
131
  target_text,
 
136
  custom_student_start_step,
137
  verbose
138
  ):
139
+ """Generate speech with different configurations."""
140
 
141
+ if not model_loaded or model is None:
142
+ return None, "Model not loaded! Please refresh the page.", "", ""
143
 
144
  if prompt_audio is None:
145
+ return None, "Please upload a reference audio!", "", ""
146
 
147
  if not target_text:
148
+ return None, "Please enter text to generate!", "", ""
149
 
150
  try:
151
+ # Auto-transcribe if prompt_text is empty
152
+ if not prompt_text and prompt_text != "":
 
 
 
 
 
 
 
 
 
 
 
 
153
  print("Auto-transcribing reference audio...")
154
+ prompt_text = transcribe(prompt_audio)
155
+ print(f"Transcribed: {prompt_text}")
156
 
157
  start_time = time.time()
158
 
159
  # Configure parameters based on mode
160
+ if mode == "Student Only (4 steps)":
161
+ teacher_steps = 0
162
+ student_start_step = 0
163
+ teacher_stopping_time = 1.0
164
+ elif mode == "Teacher-Guided (8 steps)":
165
+ # Default configuration from the notebook
166
+ teacher_steps = 16
167
+ teacher_stopping_time = 0.07
168
+ student_start_step = 1
169
+ elif mode == "High Diversity (16 steps)":
170
+ teacher_steps = 24
171
+ teacher_stopping_time = 0.3
172
+ student_start_step = 2
173
+ else: # Custom
174
+ teacher_steps = custom_teacher_steps
175
+ teacher_stopping_time = custom_teacher_stopping_time
176
+ student_start_step = custom_student_start_step
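+ 
+         # The teacher sampler runs for up to teacher_steps until
+         # teacher_stopping_time, then the distilled student takes over from
+         # student_start_step; Student Only (teacher_steps=0) skips the teacher.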
 
          # Generate speech
          generated_audio = model.generate(
              gen_text=target_text,
              audio_path=prompt_audio,
+             prompt_text=prompt_text if prompt_text else None,
+             teacher_steps=teacher_steps,
+             teacher_stopping_time=teacher_stopping_time,
+             student_start_step=student_start_step,
              temperature=temperature,
              verbose=verbose
          )
 
          torchaudio.save(output_path, generated_audio, 24000)
 
+         # Format metrics
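+         # (RTF = processing time / generated audio duration, so 1/RTF is the
+         # speedup over real time: RTF 0.10x means 10x faster than real time)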
+         metrics = f"RTF: {rtf:.2f}x ({1/rtf:.2f}x faster than real-time) | Processing: {processing_time:.2f}s for {audio_duration:.2f}s audio"
 
+         generation_info = (
+             f"Mode: {mode} | Transcribed: {prompt_text[:50]}..."
+             if auto_transcribed
+             else f"Mode: {mode}"
+         )
+ 
+         return output_path, "Success!", metrics, generation_info
 
      except Exception as e:
+         return None, f"Error: {str(e)}", "", ""
 
  # Create Gradio interface
+ with gr.Blocks(title="DMOSpeech 2 - Zero-Shot TTS", theme=gr.themes.Soft()) as demo:
      gr.Markdown(f"""
+     # 🎙️ DMOSpeech 2: Zero-Shot Text-to-Speech
+ 
+     Generate natural speech in any voice from just a short reference audio!
+ 
+     **Model Status:** {status_message} | **Device:** {device.upper()} | **ASR:** {"✅ Ready" if asr_pipe else "❌ Not available"}
      """)
 
      with gr.Row():
          with gr.Column(scale=1):
+             # Reference audio input
              prompt_audio = gr.Audio(
+                 label="📎 Reference Audio",
                  type="filepath",
                  sources=["upload", "microphone"]
              )
 
                  lines=4
              )
 
+             # Generation mode
              mode = gr.Radio(
                  choices=[
                      "Student Only (4 steps)",
 
                  ],
                  value="Teacher-Guided (8 steps)",
                  label="🚀 Generation Mode",
+                 info="Choose the speed vs quality/diversity tradeoff"
              )
 
+             # Advanced settings (collapsible)
              with gr.Accordion("⚙️ Advanced Settings", open=False):
                  temperature = gr.Slider(
                      minimum=0.0,
 
                      value=0.0,
                      step=0.1,
                      label="Duration Temperature",
+                     info="0 = deterministic, >0 = more variation in speech rhythm"
                  )
 
+                 with gr.Group(visible=False) as custom_settings:
+                     gr.Markdown("### Custom Mode Settings")
+                     custom_teacher_steps = gr.Slider(
+                         minimum=0,
+                         maximum=32,
+                         value=16,
+                         step=1,
+                         label="Teacher Steps",
+                         info="More steps = higher quality"
+                     )
+ 
+                     custom_teacher_stopping_time = gr.Slider(
+                         minimum=0.0,
+                         maximum=1.0,
+                         value=0.07,
+                         step=0.01,
+                         label="Teacher Stopping Time",
+                         info="When to hand off to the student"
+                     )
+ 
+                     custom_student_start_step = gr.Slider(
+                         minimum=0,
+                         maximum=4,
+                         value=1,
+                         step=1,
+                         label="Student Start Step",
+                         info="Which student step to start from"
+                     )
 
+                 verbose = gr.Checkbox(
+                     value=False,
+                     label="Verbose Output",
+                     info="Show detailed generation steps"
+                 )
 
              generate_btn = gr.Button("🎵 Generate Speech", variant="primary", size="lg")
 
          with gr.Column(scale=1):
+             # Output
              output_audio = gr.Audio(
                  label="🔊 Generated Speech",
                  type="filepath",
                  autoplay=True
              )
 
+             status = gr.Textbox(
+                 label="Status",
+                 interactive=False
+             )
+ 
+             metrics = gr.Textbox(
+                 label="Performance Metrics",
+                 interactive=False
+             )
+ 
+             info = gr.Textbox(
+                 label="Generation Info",
+                 interactive=False
+             )
 
+     # Tips
      gr.Markdown("""
+     ### 💡 Quick Tips
 
+     - **Auto-transcription**: Leave the reference text empty to auto-transcribe
+     - **Student Only**: Fastest (4 steps), good quality
+     - **Teacher-Guided**: Best balance (8 steps), recommended
+     - **High Diversity**: More natural prosody (16 steps)
+     - **Custom Mode**: Fine-tune all parameters
 
+     ### 📊 Expected RTF (Real-Time Factor)
+     - Student Only: ~0.05x (20x faster than real time)
+     - Teacher-Guided: ~0.10x (10x faster)
+     - High Diversity: ~0.20x (5x faster)
      """)
 
+     # Examples section
+     gr.Markdown("### 🎯 Example Configurations")
 
      gr.Markdown("""
      <details>
      <summary>English Example</summary>
 
+     **Reference text:** "Some call me nature, others call me mother nature."
 
+     **Target text:** "I don't really care what you call me. I've been a silent spectator, watching species evolve, empires rise and fall. But always remember, I am mighty and enduring."
      </details>
 
      <details>
      <summary>Chinese Example</summary>
 
+     **Reference text:** "对,这就是我,万人敬仰的太乙真人。"
 
+     **Target text:** "突然,身边一阵笑声。我看着他们,意气风发地挺直了胸膛,甩了甩那稍显肉感的双臂,轻笑道:'我身上的肉,是为了掩饰我爆棚的魅力,否则,岂不吓坏了你们呢?'"
      </details>
 
+     <details>
+     <summary>High Diversity Chinese Example</summary>
 
+     Same as above, but with **Temperature: 0.8** for more natural variation in speech rhythm.
+     </details>
+     """)
 
+     # Event handler
      generate_btn.click(
+         generate_speech,
          inputs=[
              prompt_audio,
              prompt_text,
 
              custom_student_start_step,
              verbose
          ],
+         outputs=[output_audio, status, metrics, info]
+     )
+ 
+     # Update visibility of custom settings based on mode
+     def update_custom_visibility(mode):
+         is_custom = (mode == "Custom")
+         return gr.update(visible=is_custom)
+ 
+     mode.change(
+         update_custom_visibility,
+         inputs=[mode],
+         outputs=[custom_settings]
+     )
 
+ # Launch the app
  if __name__ == "__main__":
+     if not model_loaded:
+         print(f"Warning: Model failed to load - {status_message}")
+     if not asr_pipe:
+         print("Warning: ASR pipeline not available - auto-transcription disabled")
+ 
      demo.launch()
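
For reference, a minimal sketch of the inference path this app wraps, assembled from the calls in the diff above (checkpoint filenames from download_models(), defaults from the "Teacher-Guided (8 steps)" mode). The reference clip and output paths are placeholders, not files shipped with the Space:

    import torch
    import torchaudio
    from huggingface_hub import hf_hub_download
    from infer import DMOInference

    # Checkpoints as downloaded in download_models() above
    student_path = hf_hub_download(repo_id="yl4579/DMOSpeech2", filename="model_85000.pt", cache_dir="./models")
    duration_path = hf_hub_download(repo_id="yl4579/DMOSpeech2", filename="model_1500.pt", cache_dir="./models")

    model = DMOInference(
        student_checkpoint_path=student_path,
        duration_predictor_path=duration_path,
        device="cuda" if torch.cuda.is_available() else "cpu",
        model_type="F5TTS_Base",
    )

    # "Teacher-Guided (8 steps)" defaults from generate_speech()
    generated_audio = model.generate(
        gen_text="Text to synthesize.",  # placeholder target text
        audio_path="reference.wav",      # placeholder reference clip
        prompt_text=None,                # None = no reference transcript provided
        teacher_steps=16,
        teacher_stopping_time=0.07,
        student_start_step=1,
        temperature=0.0,                 # 0 = deterministic duration
        verbose=False,
    )
    torchaudio.save("output.wav", generated_audio, 24000)  # the app saves at 24 kHz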