seawolf2357 commited on
Commit
1f4cc99
·
verified ·
1 Parent(s): ebdf0db

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +22 -505
app.py CHANGED
@@ -1,518 +1,35 @@
1
- # app.py
2
-
3
  import os
4
  import sys
5
- import time
6
- import gradio as gr
7
- import spaces
8
- from huggingface_hub import snapshot_download
9
- from huggingface_hub.utils import GatedRepoError, RepositoryNotFoundError, RevisionNotFoundError
10
- from pathlib import Path
11
- import tempfile
12
- from pydub import AudioSegment
13
- import cv2
14
- import numpy as np
15
- from scipy import interpolate
16
-
17
- # Add the src directory to the system path to allow for local imports
18
- sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), 'src')))
19
-
20
- from models.inference.moda_test import LiveVASAPipeline, emo_map, set_seed
21
-
22
- # --- Configuration ---
23
- # Set seed for reproducibility
24
- set_seed(42)
25
-
26
- # Paths and constants for the Gradio demo
27
- DEFAULT_CFG_PATH = "configs/audio2motion/inference/inference.yaml"
28
- DEFAULT_MOTION_MEAN_STD_PATH = "src/datasets/mean.pt"
29
- DEFAULT_SILENT_AUDIO_PATH = "src/examples/silent-audio.wav"
30
- OUTPUT_DIR = "gradio_output"
31
- WEIGHTS_DIR = "pretrain_weights"
32
- REPO_ID = "lixinyizju/moda"
33
-
34
- # --- Download Pre-trained Weights from Hugging Face Hub ---
35
- def download_weights():
36
- """
37
- Downloads pre-trained weights from Hugging Face Hub if they don't exist locally.
38
- """
39
- # A simple check for a key file to see if the download is likely complete
40
- motion_model_file = os.path.join(WEIGHTS_DIR, "moda", "net-200.pth")
41
-
42
- if not os.path.exists(motion_model_file):
43
- print(f"Weights not found locally. Downloading from Hugging Face Hub repo '{REPO_ID}'...")
44
- print(f"This may take a while depending on your internet connection.")
45
- try:
46
- snapshot_download(
47
- repo_id=REPO_ID,
48
- local_dir=WEIGHTS_DIR,
49
- local_dir_use_symlinks=False, # Use False to copy files directly; safer for Windows
50
- resume_download=True,
51
- )
52
- print("Weights downloaded successfully.")
53
- except GatedRepoError:
54
- raise gr.Error(f"Access to the repository '{REPO_ID}' is gated. Please visit https://huggingface.co/{REPO_ID} to request access.")
55
- except (RepositoryNotFoundError, RevisionNotFoundError):
56
- raise gr.Error(f"The repository '{REPO_ID}' was not found. Please check the repository ID.")
57
- except Exception as e:
58
- print(f"An error occurred during download: {e}")
59
- raise gr.Error(f"Failed to download models. Please check your internet connection and try again. Error: {e}")
60
- else:
61
- print(f"Found existing weights at '{WEIGHTS_DIR}'. Skipping download.")
62
-
63
- # --- Audio Conversion Function ---
64
- def ensure_wav_format(audio_path):
65
- """
66
- Ensures the audio file is in WAV format. If not, converts it to WAV.
67
- Returns the path to the WAV file (either original or converted).
68
- """
69
- if audio_path is None:
70
- return None
71
-
72
- audio_path = Path(audio_path)
73
-
74
- # Check if already WAV
75
- if audio_path.suffix.lower() == '.wav':
76
- print(f"Audio is already in WAV format: {audio_path}")
77
- return str(audio_path)
78
-
79
- # Convert to WAV
80
- print(f"Converting audio from {audio_path.suffix} to WAV format...")
81
-
82
- try:
83
- # Load the audio file
84
- audio = AudioSegment.from_file(audio_path)
85
-
86
- # Create a temporary WAV file
87
- with tempfile.NamedTemporaryFile(suffix='.wav', delete=False) as tmp_file:
88
- wav_path = tmp_file.name
89
- # Export as WAV with higher sampling rate for better quality
90
- audio.export(
91
- wav_path,
92
- format='wav',
93
- parameters=["-ar", "24000", "-ac", "1"] # 24kHz, mono for better lip-sync
94
- )
95
-
96
- print(f"Audio converted successfully to: {wav_path}")
97
- return wav_path
98
-
99
- except Exception as e:
100
- print(f"Error converting audio: {e}")
101
- raise gr.Error(f"Failed to convert audio file to WAV format. Error: {e}")
102
 
103
- # --- Frame Interpolation Function ---
104
- def interpolate_frames(video_path, target_fps=30):
105
- """
106
- Interpolates frames in a video to achieve smoother motion.
107
-
108
- Args:
109
- video_path: Path to the input video
110
- target_fps: Target frames per second
111
-
112
- Returns:
113
- Path to the interpolated video
114
- """
115
  try:
116
- video_path = str(video_path)
117
- cap = cv2.VideoCapture(video_path)
118
-
119
- # Get original video properties
120
- original_fps = cap.get(cv2.CAP_PROP_FPS)
121
- width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
122
- height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
123
-
124
- # Fix for FPS detection issue
125
- if original_fps == 0 or original_fps is None:
126
- print("Warning: Could not detect original FPS. Assuming 25 FPS.")
127
- original_fps = 25.0
128
-
129
- print(f"Original FPS: {original_fps}, Target FPS: {target_fps}")
130
 
131
- # If target FPS is not higher, return original
132
- if original_fps >= target_fps:
133
- cap.release()
134
- print("Target FPS is not higher than original. Skipping interpolation.")
135
- return video_path
136
 
137
- # Read all frames
138
- frames = []
139
- while True:
140
- ret, frame = cap.read()
141
- if not ret:
142
- break
143
- frames.append(frame)
144
- cap.release()
145
 
146
- if len(frames) < 2:
147
- print("Not enough frames for interpolation.")
148
- return video_path
149
-
150
- # Calculate interpolation factor (can be fractional)
151
- interpolation_factor = target_fps / original_fps
152
-
153
- # For fractional factors, we need different approach
154
- if interpolation_factor <= 1:
155
- print("Interpolation factor too low. Skipping.")
156
- return video_path
157
-
158
- print(f"Interpolating with factor: {interpolation_factor:.2f}")
159
- print(f"Total frames to process: {len(frames)}")
160
-
161
- # Perform frame interpolation
162
- interpolated_frames = []
163
-
164
- if interpolation_factor == int(interpolation_factor):
165
- # Integer factor - simple interpolation
166
- factor = int(interpolation_factor)
167
- for i in range(len(frames) - 1):
168
- interpolated_frames.append(frames[i])
169
- # Generate intermediate frames
170
- for j in range(1, factor):
171
- alpha = j / factor
172
- interpolated_frame = cv2.addWeighted(
173
- frames[i], 1 - alpha,
174
- frames[i + 1], alpha,
175
- 0
176
- )
177
- interpolated_frames.append(interpolated_frame)
178
- interpolated_frames.append(frames[-1])
179
- else:
180
- # Fractional factor - use different approach
181
- # For 25 -> 60 fps, we need to add selective frames
182
- for i in range(len(frames) - 1):
183
- interpolated_frames.append(frames[i])
184
- # Add intermediate frame for smoother motion
185
- if i % 2 == 0: # Add extra frame every other original frame
186
- alpha = 0.4 # Blend ratio
187
- interpolated_frame = cv2.addWeighted(
188
- frames[i], 1 - alpha,
189
- frames[i + 1], alpha,
190
- 0
191
- )
192
- interpolated_frames.append(interpolated_frame)
193
- interpolated_frames.append(frames[-1])
194
-
195
- print(f"Total interpolated frames: {len(interpolated_frames)}")
196
-
197
- # Save the interpolated video
198
- output_path = video_path.replace('.mp4', '_interpolated.mp4')
199
-
200
- # Use H.264 codec for better compatibility
201
- fourcc = cv2.VideoWriter_fourcc(*'H264')
202
- out = cv2.VideoWriter(output_path, fourcc, target_fps, (width, height))
203
-
204
- if not out.isOpened():
205
- # Fallback to mp4v if H264 not available
206
- fourcc = cv2.VideoWriter_fourcc(*'mp4v')
207
- out = cv2.VideoWriter(output_path, fourcc, target_fps, (width, height))
208
-
209
- for frame in interpolated_frames:
210
- out.write(frame)
211
- out.release()
212
-
213
- print(f"Interpolated video saved to: {output_path}")
214
- return output_path
215
-
216
- except Exception as e:
217
- print(f"Error during frame interpolation: {e}")
218
- import traceback
219
- traceback.print_exc()
220
- return video_path # Return original if interpolation fails
221
-
222
- # --- Initialization ---
223
- # Create output directory if it doesn't exist
224
- os.makedirs(OUTPUT_DIR, exist_ok=True)
225
-
226
- # Download weights before initializing the pipeline
227
- download_weights()
228
-
229
- # Instantiate the pipeline once to avoid reloading models on every request
230
- print("Initializing MoDA pipeline...")
231
- try:
232
- pipeline = LiveVASAPipeline(
233
- cfg_path=DEFAULT_CFG_PATH,
234
- motion_mean_std_path=DEFAULT_MOTION_MEAN_STD_PATH
235
- )
236
- print("MoDA pipeline initialized successfully.")
237
- except Exception as e:
238
- print(f"Error initializing pipeline: {e}")
239
- pipeline = None
240
-
241
- # Invert the emo_map for easy lookup from the dropdown value
242
- emo_name_to_id = {v: k for k, v in emo_map.items()}
243
-
244
- # --- Audio Length Check Function ---
245
- def check_audio_length(audio_path):
246
- """
247
- Check the length of an audio file and warn if it's too long.
248
-
249
- Args:
250
- audio_path: Path to the audio file
251
-
252
- Returns:
253
- Duration in seconds
254
- """
255
- try:
256
- audio = AudioSegment.from_file(audio_path)
257
- duration_seconds = len(audio) / 1000.0
258
- return duration_seconds
259
- except Exception as e:
260
- print(f"Error checking audio length: {e}")
261
- return None
262
-
263
- # --- Core Generation Function ---
264
- @spaces.GPU(duration=180) # Increased duration for smoothing and interpolation
265
- def generate_motion(source_image_path, driving_audio_path, emotion_name,
266
- cfg_scale, smooth_enabled, target_fps,
267
- progress=gr.Progress(track_tqdm=True)):
268
- """
269
- The main function that takes Gradio inputs and generates the talking head video.
270
-
271
- Args:
272
- source_image_path: Path to the source image
273
- driving_audio_path: Path to the driving audio
274
- emotion_name: Selected emotion
275
- cfg_scale: CFG scale for generation
276
- smooth_enabled: Whether to enable smoothing
277
- target_fps: Target frames per second for interpolation
278
- """
279
- if pipeline is None:
280
- raise gr.Error("Pipeline failed to initialize. Check the console logs for details.")
281
 
282
- if source_image_path is None:
283
- raise gr.Error("Please upload a source image.")
284
- if driving_audio_path is None:
285
- raise gr.Error("Please upload a driving audio file.")
286
-
287
- # Check audio length
288
- audio_duration = check_audio_length(driving_audio_path)
289
- if audio_duration:
290
- print(f"Audio duration: {audio_duration:.1f} seconds")
291
- if audio_duration > 60:
292
- gr.Warning(f"⚠️ Audio is {audio_duration:.1f} seconds long. MoDA works best with audio under 60 seconds. Processing may be slow and quality may degrade.")
293
- if audio_duration > 180:
294
- raise gr.Error("Audio is too long. Please use audio files under 3 minutes (180 seconds) for best results.")
295
-
296
- start_time = time.time()
297
-
298
- # Ensure audio is in WAV format with optimal sampling rate
299
- wav_audio_path = ensure_wav_format(driving_audio_path)
300
- temp_wav_created = wav_audio_path != driving_audio_path
301
-
302
- # Create a unique subdirectory for this run
303
- timestamp = time.strftime("%Y%m%d-%H%M%S")
304
- run_output_dir = os.path.join(OUTPUT_DIR, timestamp)
305
- os.makedirs(run_output_dir, exist_ok=True)
306
-
307
- # Get emotion ID from its name
308
- emotion_id = emo_name_to_id.get(emotion_name, 8) # Default to 'None' (ID 8) if not found
309
-
310
- print(f"Starting generation with the following parameters:")
311
- print(f" Source Image: {source_image_path}")
312
- print(f" Driving Audio (original): {driving_audio_path}")
313
- print(f" Driving Audio (WAV): {wav_audio_path}")
314
- print(f" Emotion: {emotion_name} (ID: {emotion_id})")
315
- print(f" CFG Scale: {cfg_scale}")
316
- print(f" Smoothing: {smooth_enabled}")
317
- print(f" Target FPS: {target_fps}")
318
-
319
- try:
320
- # Temporarily disable smoothing if it causes CUDA tensor issues
321
- # Check if smooth causes issues and handle gracefully
322
  try:
323
- # Try with smoothing first
324
- result_video_path = pipeline.driven_sample(
325
- image_path=source_image_path,
326
- audio_path=wav_audio_path,
327
- cfg_scale=float(cfg_scale),
328
- emo=emotion_id,
329
- save_dir=".",
330
- smooth=smooth_enabled, # Use the checkbox value
331
- silent_audio_path=DEFAULT_SILENT_AUDIO_PATH,
332
- )
333
- except TypeError as tensor_error:
334
- if "can't convert cuda" in str(tensor_error) and smooth_enabled:
335
- print("Warning: Smoothing caused CUDA tensor error. Retrying without smoothing...")
336
- # Retry without smoothing
337
- result_video_path = pipeline.driven_sample(
338
- image_path=source_image_path,
339
- audio_path=wav_audio_path,
340
- cfg_scale=float(cfg_scale),
341
- emo=emotion_id,
342
- save_dir=".",
343
- smooth=False, # Disable smoothing as fallback
344
- silent_audio_path=DEFAULT_SILENT_AUDIO_PATH,
345
- )
346
- print("Generated video without smoothing due to technical limitations.")
347
- else:
348
- raise tensor_error
349
-
350
- # Apply frame interpolation if requested
351
- if target_fps > 24: # Assuming default is around 24 FPS
352
- print(f"Applying frame interpolation to achieve {target_fps} FPS...")
353
- result_video_path = interpolate_frames(result_video_path, target_fps=target_fps)
354
-
355
  except Exception as e:
356
- print(f"An error occurred during video generation: {e}")
357
  import traceback
358
- traceback.print_exc()
359
- raise gr.Error(f"An unexpected error occurred: {str(e)}. Please check the console for details.")
360
- finally:
361
- # Clean up temporary WAV file if created
362
- if temp_wav_created and os.path.exists(wav_audio_path):
363
- try:
364
- os.remove(wav_audio_path)
365
- print(f"Cleaned up temporary WAV file: {wav_audio_path}")
366
- except Exception as e:
367
- print(f"Warning: Could not delete temporary file {wav_audio_path}: {e}")
368
-
369
- end_time = time.time()
370
- processing_time = end_time - start_time
371
-
372
- result_video_path = Path(result_video_path)
373
- final_path = result_video_path.with_name(f"final_{result_video_path.stem}{result_video_path.suffix}")
374
-
375
- print(f"Video generated successfully at: {final_path}")
376
- print(f"Processing time: {processing_time:.2f} seconds.")
377
-
378
- return final_path
379
-
380
- # --- Gradio UI Definition ---
381
- with gr.Blocks(theme=gr.themes.Soft(), css=".gradio-container {max-width: 960px !important; margin: 0 auto !important}") as demo:
382
- gr.HTML(
383
- """
384
- <div align='center'>
385
- <h1>MoDA: Multi-modal Diffusion Architecture for Talking Head Generation</h1>
386
- <h2 style="color: #4A90E2;">Enhanced Version with Smooth Motion</h2>
387
- <p style="display:flex; justify-content: center; gap: 10px;">
388
- <a href='https://lixinyyang.github.io/MoDA.github.io/'><img src='https://img.shields.io/badge/Project-Page-blue'></a>
389
- <a href='https://arxiv.org/abs/2507.03256'><img src='https://img.shields.io/badge/Paper-Arxiv-red'></a>
390
- <a href='https://github.com/lixinyyang/MoDA/'><img src='https://img.shields.io/badge/Code-Github-green'></a>
391
- </p>
392
- </div>
393
- """
394
- )
395
-
396
- with gr.Row(variant="panel"):
397
- with gr.Column(scale=1):
398
- gr.Markdown("### 📥 Input Settings")
399
-
400
- with gr.Row():
401
- source_image = gr.Image(
402
- label="Source Image",
403
- type="filepath",
404
- value="src/examples/reference_images/7.jpg"
405
- )
406
-
407
- with gr.Row():
408
- driving_audio = gr.Audio(
409
- label="Driving Audio (Recommended: < 60 seconds)",
410
- type="filepath",
411
- value="src/examples/driving_audios/5.wav"
412
- )
413
-
414
- gr.Markdown("### ⚙️ Generation Settings")
415
-
416
- with gr.Row():
417
- emotion_dropdown = gr.Dropdown(
418
- label="Emotion",
419
- choices=list(emo_map.values()),
420
- value="None",
421
- info="Select an emotion for more natural facial expressions"
422
- )
423
-
424
- with gr.Row():
425
- cfg_slider = gr.Slider(
426
- label="CFG Scale (Lower = Smoother motion)",
427
- minimum=0.5,
428
- maximum=5.0,
429
- step=0.1,
430
- value=0.5,
431
- info="Lower values produce smoother but less controlled motion"
432
- )
433
-
434
- gr.Markdown("### 🎬 Motion Enhancement")
435
-
436
- with gr.Row():
437
- smooth_checkbox = gr.Checkbox(
438
- label="Enable Smoothing (Experimental)",
439
- value=False, # Changed to False due to CUDA issues
440
- info="May cause errors on some systems. If errors occur, disable this option."
441
- )
442
-
443
- with gr.Row():
444
- fps_slider = gr.Slider(
445
- label="Target FPS",
446
- minimum=24,
447
- maximum=50,
448
- step=1,
449
- value=50,
450
- info="Higher FPS for smoother motion. 30 FPS recommended, 50 FPS maximum"
451
- )
452
-
453
- submit_button = gr.Button("🎥 Generate Video", variant="primary", size="lg")
454
-
455
- with gr.Column(scale=1):
456
- gr.Markdown("### 📺 Output")
457
- output_video = gr.Video(label="Generated Video")
458
-
459
- # Processing status
460
- with gr.Row():
461
- gr.Markdown(
462
- """
463
- <div style="background-color: #f0f8ff; padding: 10px; border-radius: 5px; margin-top: 10px;">
464
- <p style="margin: 0; font-size: 0.9em;">
465
- <b>Tips for best results:</b><br>
466
- • Use high-quality front-facing images<br>
467
- • Clear audio without background noise<br>
468
- • <b>Keep audio under 60 seconds</b><br>
469
- • Adjust CFG scale if motion seems stiff<br>
470
- • For longer audio, split into segments
471
- </p>
472
- </div>
473
- """
474
- )
475
-
476
- gr.Markdown(
477
- """
478
- ---
479
- ### ⚠️ **Disclaimer**
480
- This project is intended for academic research, and we explicitly disclaim any responsibility for user-generated content.
481
- Users are solely liable for their actions while using this generative model.
482
-
483
- ### 🚀 **Enhancement Features**
484
- - **Frame Smoothing**: Reduces jitter and improves transition between frames (currently experimental)
485
- - **Frame Interpolation**: Increases FPS for smoother motion
486
- - **Optimized Audio Processing**: Better lip-sync with 24kHz sampling
487
- - **Fine-tuned CFG Scale**: Better control over motion naturalness
488
-
489
- ### ⏱️ **Audio Length Limitations**
490
- - **Optimal**: Under 30 seconds for best quality and speed
491
- - **Recommended**: Under 60 seconds
492
- - **Maximum**: 180 seconds (3 minutes) - very slow processing
493
- - For longer content, consider splitting audio into segments
494
- """
495
- )
496
-
497
- # Examples section
498
- gr.Examples(
499
- examples=[
500
- ["src/examples/reference_images/7.jpg", "src/examples/driving_audios/5.wav", "None", 1.0, False, 30],
501
- ["src/examples/reference_images/7.jpg", "src/examples/driving_audios/5.wav", "Happy", 0.8, False, 30],
502
- ["src/examples/reference_images/7.jpg", "src/examples/driving_audios/5.wav", "Sad", 1.2, False, 24],
503
- ],
504
- inputs=[source_image, driving_audio, emotion_dropdown, cfg_slider, smooth_checkbox, fps_slider],
505
- outputs=output_video,
506
- fn=generate_motion,
507
- cache_examples=False,
508
- label="Example Configurations"
509
- )
510
-
511
- submit_button.click(
512
- fn=generate_motion,
513
- inputs=[source_image, driving_audio, emotion_dropdown, cfg_slider, smooth_checkbox, fps_slider],
514
- outputs=output_video
515
- )
516
 
517
  if __name__ == "__main__":
518
- demo.launch(share=True)
 
 
 
1
  import os
2
  import sys
3
+ import streamlit as st
4
+ from tempfile import NamedTemporaryFile
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5
 
6
+ def main():
 
 
 
 
 
 
 
 
 
 
 
7
  try:
8
+ # Get the code from secrets
9
+ code = os.environ.get("MAIN_CODE")
 
 
 
 
 
 
 
 
 
 
 
 
10
 
11
+ if not code:
12
+ st.error("⚠️ The application code wasn't found in secrets. Please add the MAIN_CODE secret.")
13
+ return
 
 
14
 
15
+ # Create a temporary Python file
16
+ with NamedTemporaryFile(suffix='.py', delete=False, mode='w') as tmp:
17
+ tmp.write(code)
18
+ tmp_path = tmp.name
 
 
 
 
19
 
20
+ # Execute the code
21
+ exec(compile(code, tmp_path, 'exec'), globals())
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
22
 
23
+ # Clean up the temporary file
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
24
  try:
25
+ os.unlink(tmp_path)
26
+ except:
27
+ pass
28
+
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
29
  except Exception as e:
30
+ st.error(f"⚠️ Error loading or executing the application: {str(e)}")
31
  import traceback
32
+ st.code(traceback.format_exc())
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
33
 
34
  if __name__ == "__main__":
35
+ main()