seawolf2357 commited on
Commit
af5108f
·
verified ·
1 Parent(s): 6588b24

Create app-backup.py

Browse files
Files changed (1) hide show
  1. app-backup.py +446 -0
app-backup.py ADDED
@@ -0,0 +1,446 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # app.py
2
+
3
+ import os
4
+ import sys
5
+ import time
6
+ import gradio as gr
7
+ import spaces
8
+ from huggingface_hub import snapshot_download
9
+ from huggingface_hub.utils import GatedRepoError, RepositoryNotFoundError, RevisionNotFoundError
10
+ from pathlib import Path
11
+ import tempfile
12
+ from pydub import AudioSegment
13
+ import cv2
14
+ import numpy as np
15
+ from scipy import interpolate
16
+
17
+ # Add the src directory to the system path to allow for local imports
18
+ sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), 'src')))
19
+
20
+ from models.inference.moda_test import LiveVASAPipeline, emo_map, set_seed
21
+
22
+ # --- Configuration ---
23
+ # Set seed for reproducibility
24
+ set_seed(42)
25
+
26
+ # Paths and constants for the Gradio demo
27
+ DEFAULT_CFG_PATH = "configs/audio2motion/inference/inference.yaml"
28
+ DEFAULT_MOTION_MEAN_STD_PATH = "src/datasets/mean.pt"
29
+ DEFAULT_SILENT_AUDIO_PATH = "src/examples/silent-audio.wav"
30
+ OUTPUT_DIR = "gradio_output"
31
+ WEIGHTS_DIR = "pretrain_weights"
32
+ REPO_ID = "lixinyizju/moda"
33
+
34
+ # --- Download Pre-trained Weights from Hugging Face Hub ---
35
+ def download_weights():
36
+ """
37
+ Downloads pre-trained weights from Hugging Face Hub if they don't exist locally.
38
+ """
39
+ # A simple check for a key file to see if the download is likely complete
40
+ motion_model_file = os.path.join(WEIGHTS_DIR, "moda", "net-200.pth")
41
+
42
+ if not os.path.exists(motion_model_file):
43
+ print(f"Weights not found locally. Downloading from Hugging Face Hub repo '{REPO_ID}'...")
44
+ print(f"This may take a while depending on your internet connection.")
45
+ try:
46
+ snapshot_download(
47
+ repo_id=REPO_ID,
48
+ local_dir=WEIGHTS_DIR,
49
+ local_dir_use_symlinks=False, # Use False to copy files directly; safer for Windows
50
+ resume_download=True,
51
+ )
52
+ print("Weights downloaded successfully.")
53
+ except GatedRepoError:
54
+ raise gr.Error(f"Access to the repository '{REPO_ID}' is gated. Please visit https://huggingface.co/{REPO_ID} to request access.")
55
+ except (RepositoryNotFoundError, RevisionNotFoundError):
56
+ raise gr.Error(f"The repository '{REPO_ID}' was not found. Please check the repository ID.")
57
+ except Exception as e:
58
+ print(f"An error occurred during download: {e}")
59
+ raise gr.Error(f"Failed to download models. Please check your internet connection and try again. Error: {e}")
60
+ else:
61
+ print(f"Found existing weights at '{WEIGHTS_DIR}'. Skipping download.")
62
+
63
+ # --- Audio Conversion Function ---
64
+ def ensure_wav_format(audio_path):
65
+ """
66
+ Ensures the audio file is in WAV format. If not, converts it to WAV.
67
+ Returns the path to the WAV file (either original or converted).
68
+ """
69
+ if audio_path is None:
70
+ return None
71
+
72
+ audio_path = Path(audio_path)
73
+
74
+ # Check if already WAV
75
+ if audio_path.suffix.lower() == '.wav':
76
+ print(f"Audio is already in WAV format: {audio_path}")
77
+ return str(audio_path)
78
+
79
+ # Convert to WAV
80
+ print(f"Converting audio from {audio_path.suffix} to WAV format...")
81
+
82
+ try:
83
+ # Load the audio file
84
+ audio = AudioSegment.from_file(audio_path)
85
+
86
+ # Create a temporary WAV file
87
+ with tempfile.NamedTemporaryFile(suffix='.wav', delete=False) as tmp_file:
88
+ wav_path = tmp_file.name
89
+ # Export as WAV with higher sampling rate for better quality
90
+ audio.export(
91
+ wav_path,
92
+ format='wav',
93
+ parameters=["-ar", "24000", "-ac", "1"] # 24kHz, mono for better lip-sync
94
+ )
95
+
96
+ print(f"Audio converted successfully to: {wav_path}")
97
+ return wav_path
98
+
99
+ except Exception as e:
100
+ print(f"Error converting audio: {e}")
101
+ raise gr.Error(f"Failed to convert audio file to WAV format. Error: {e}")
102
+
103
+ # --- Frame Interpolation Function ---
104
+ def interpolate_frames(video_path, target_fps=30):
105
+ """
106
+ Interpolates frames in a video to achieve smoother motion.
107
+
108
+ Args:
109
+ video_path: Path to the input video
110
+ target_fps: Target frames per second
111
+
112
+ Returns:
113
+ Path to the interpolated video
114
+ """
115
+ try:
116
+ video_path = str(video_path)
117
+ cap = cv2.VideoCapture(video_path)
118
+
119
+ # Get original video properties
120
+ original_fps = cap.get(cv2.CAP_PROP_FPS)
121
+ width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
122
+ height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
123
+
124
+ print(f"Original FPS: {original_fps}, Target FPS: {target_fps}")
125
+
126
+ # If target FPS is not higher, return original
127
+ if original_fps >= target_fps:
128
+ cap.release()
129
+ print("Target FPS is not higher than original. Skipping interpolation.")
130
+ return video_path
131
+
132
+ # Read all frames
133
+ frames = []
134
+ while True:
135
+ ret, frame = cap.read()
136
+ if not ret:
137
+ break
138
+ frames.append(frame)
139
+ cap.release()
140
+
141
+ if len(frames) < 2:
142
+ print("Not enough frames for interpolation.")
143
+ return video_path
144
+
145
+ # Calculate interpolation factor
146
+ interpolation_factor = int(target_fps / original_fps)
147
+ interpolated_frames = []
148
+
149
+ print(f"Interpolating with factor: {interpolation_factor}")
150
+
151
+ # Perform frame interpolation
152
+ for i in range(len(frames) - 1):
153
+ interpolated_frames.append(frames[i])
154
+
155
+ # Generate intermediate frames
156
+ for j in range(1, interpolation_factor):
157
+ alpha = j / interpolation_factor
158
+ # Use weighted average for simple interpolation
159
+ interpolated_frame = cv2.addWeighted(
160
+ frames[i], 1 - alpha,
161
+ frames[i + 1], alpha,
162
+ 0
163
+ )
164
+ interpolated_frames.append(interpolated_frame)
165
+
166
+ # Add the last frame
167
+ interpolated_frames.append(frames[-1])
168
+
169
+ # Save the interpolated video
170
+ output_path = video_path.replace('.mp4', '_interpolated.mp4')
171
+ fourcc = cv2.VideoWriter_fourcc(*'mp4v')
172
+ out = cv2.VideoWriter(output_path, fourcc, target_fps, (width, height))
173
+
174
+ for frame in interpolated_frames:
175
+ out.write(frame)
176
+ out.release()
177
+
178
+ print(f"Interpolated video saved to: {output_path}")
179
+ return output_path
180
+
181
+ except Exception as e:
182
+ print(f"Error during frame interpolation: {e}")
183
+ return video_path # Return original if interpolation fails
184
+
185
+ # --- Initialization ---
186
+ # Create output directory if it doesn't exist
187
+ os.makedirs(OUTPUT_DIR, exist_ok=True)
188
+
189
+ # Download weights before initializing the pipeline
190
+ download_weights()
191
+
192
+ # Instantiate the pipeline once to avoid reloading models on every request
193
+ print("Initializing MoDA pipeline...")
194
+ try:
195
+ pipeline = LiveVASAPipeline(
196
+ cfg_path=DEFAULT_CFG_PATH,
197
+ motion_mean_std_path=DEFAULT_MOTION_MEAN_STD_PATH
198
+ )
199
+ print("MoDA pipeline initialized successfully.")
200
+ except Exception as e:
201
+ print(f"Error initializing pipeline: {e}")
202
+ pipeline = None
203
+
204
+ # Invert the emo_map for easy lookup from the dropdown value
205
+ emo_name_to_id = {v: k for k, v in emo_map.items()}
206
+
207
+ # --- Core Generation Function ---
208
+ @spaces.GPU(duration=180) # Increased duration for smoothing and interpolation
209
+ def generate_motion(source_image_path, driving_audio_path, emotion_name,
210
+ cfg_scale, smooth_enabled, target_fps,
211
+ progress=gr.Progress(track_tqdm=True)):
212
+ """
213
+ The main function that takes Gradio inputs and generates the talking head video.
214
+
215
+ Args:
216
+ source_image_path: Path to the source image
217
+ driving_audio_path: Path to the driving audio
218
+ emotion_name: Selected emotion
219
+ cfg_scale: CFG scale for generation
220
+ smooth_enabled: Whether to enable smoothing
221
+ target_fps: Target frames per second for interpolation
222
+ """
223
+ if pipeline is None:
224
+ raise gr.Error("Pipeline failed to initialize. Check the console logs for details.")
225
+
226
+ if source_image_path is None:
227
+ raise gr.Error("Please upload a source image.")
228
+ if driving_audio_path is None:
229
+ raise gr.Error("Please upload a driving audio file.")
230
+
231
+ start_time = time.time()
232
+
233
+ # Ensure audio is in WAV format with optimal sampling rate
234
+ wav_audio_path = ensure_wav_format(driving_audio_path)
235
+ temp_wav_created = wav_audio_path != driving_audio_path
236
+
237
+ # Create a unique subdirectory for this run
238
+ timestamp = time.strftime("%Y%m%d-%H%M%S")
239
+ run_output_dir = os.path.join(OUTPUT_DIR, timestamp)
240
+ os.makedirs(run_output_dir, exist_ok=True)
241
+
242
+ # Get emotion ID from its name
243
+ emotion_id = emo_name_to_id.get(emotion_name, 8) # Default to 'None' (ID 8) if not found
244
+
245
+ print(f"Starting generation with the following parameters:")
246
+ print(f" Source Image: {source_image_path}")
247
+ print(f" Driving Audio (original): {driving_audio_path}")
248
+ print(f" Driving Audio (WAV): {wav_audio_path}")
249
+ print(f" Emotion: {emotion_name} (ID: {emotion_id})")
250
+ print(f" CFG Scale: {cfg_scale}")
251
+ print(f" Smoothing: {smooth_enabled}")
252
+ print(f" Target FPS: {target_fps}")
253
+
254
+ try:
255
+ # Temporarily disable smoothing if it causes CUDA tensor issues
256
+ # Check if smooth causes issues and handle gracefully
257
+ try:
258
+ # Try with smoothing first
259
+ result_video_path = pipeline.driven_sample(
260
+ image_path=source_image_path,
261
+ audio_path=wav_audio_path,
262
+ cfg_scale=float(cfg_scale),
263
+ emo=emotion_id,
264
+ save_dir=".",
265
+ smooth=smooth_enabled, # Use the checkbox value
266
+ silent_audio_path=DEFAULT_SILENT_AUDIO_PATH,
267
+ )
268
+ except TypeError as tensor_error:
269
+ if "can't convert cuda" in str(tensor_error) and smooth_enabled:
270
+ print("Warning: Smoothing caused CUDA tensor error. Retrying without smoothing...")
271
+ # Retry without smoothing
272
+ result_video_path = pipeline.driven_sample(
273
+ image_path=source_image_path,
274
+ audio_path=wav_audio_path,
275
+ cfg_scale=float(cfg_scale),
276
+ emo=emotion_id,
277
+ save_dir=".",
278
+ smooth=False, # Disable smoothing as fallback
279
+ silent_audio_path=DEFAULT_SILENT_AUDIO_PATH,
280
+ )
281
+ print("Generated video without smoothing due to technical limitations.")
282
+ else:
283
+ raise tensor_error
284
+
285
+ # Apply frame interpolation if requested
286
+ if target_fps > 24: # Assuming default is around 24 FPS
287
+ print(f"Applying frame interpolation to achieve {target_fps} FPS...")
288
+ result_video_path = interpolate_frames(result_video_path, target_fps=target_fps)
289
+
290
+ except Exception as e:
291
+ print(f"An error occurred during video generation: {e}")
292
+ import traceback
293
+ traceback.print_exc()
294
+ raise gr.Error(f"An unexpected error occurred: {str(e)}. Please check the console for details.")
295
+ finally:
296
+ # Clean up temporary WAV file if created
297
+ if temp_wav_created and os.path.exists(wav_audio_path):
298
+ try:
299
+ os.remove(wav_audio_path)
300
+ print(f"Cleaned up temporary WAV file: {wav_audio_path}")
301
+ except Exception as e:
302
+ print(f"Warning: Could not delete temporary file {wav_audio_path}: {e}")
303
+
304
+ end_time = time.time()
305
+ processing_time = end_time - start_time
306
+
307
+ result_video_path = Path(result_video_path)
308
+ final_path = result_video_path.with_name(f"final_{result_video_path.stem}{result_video_path.suffix}")
309
+
310
+ print(f"Video generated successfully at: {final_path}")
311
+ print(f"Processing time: {processing_time:.2f} seconds.")
312
+
313
+ return final_path
314
+
315
+ # --- Gradio UI Definition ---
316
+ with gr.Blocks(theme=gr.themes.Soft(), css=".gradio-container {max-width: 960px !important; margin: 0 auto !important}") as demo:
317
+ gr.HTML(
318
+ """
319
+ <div align='center'>
320
+ <h1>MoDA: Multi-modal Diffusion Architecture for Talking Head Generation</h1>
321
+ <h2 style="color: #4A90E2;">Enhanced Version with Smooth Motion</h2>
322
+ <p style="display:flex; justify-content: center; gap: 10px;">
323
+ <a href='https://lixinyyang.github.io/MoDA.github.io/'><img src='https://img.shields.io/badge/Project-Page-blue'></a>
324
+ <a href='https://arxiv.org/abs/2507.03256'><img src='https://img.shields.io/badge/Paper-Arxiv-red'></a>
325
+ <a href='https://github.com/lixinyyang/MoDA/'><img src='https://img.shields.io/badge/Code-Github-green'></a>
326
+ </p>
327
+ </div>
328
+ """
329
+ )
330
+
331
+ with gr.Row(variant="panel"):
332
+ with gr.Column(scale=1):
333
+ gr.Markdown("### 📥 Input Settings")
334
+
335
+ with gr.Row():
336
+ source_image = gr.Image(
337
+ label="Source Image",
338
+ type="filepath",
339
+ value="src/examples/reference_images/7.jpg"
340
+ )
341
+
342
+ with gr.Row():
343
+ driving_audio = gr.Audio(
344
+ label="Driving Audio",
345
+ type="filepath",
346
+ value="src/examples/driving_audios/5.wav"
347
+ )
348
+
349
+ gr.Markdown("### ⚙️ Generation Settings")
350
+
351
+ with gr.Row():
352
+ emotion_dropdown = gr.Dropdown(
353
+ label="Emotion",
354
+ choices=list(emo_map.values()),
355
+ value="Neutral",
356
+ info="Select an emotion for more natural facial expressions"
357
+ )
358
+
359
+ with gr.Row():
360
+ cfg_slider = gr.Slider(
361
+ label="CFG Scale (Lower = Smoother motion)",
362
+ minimum=0.5,
363
+ maximum=5.0,
364
+ step=0.1,
365
+ value=0.5,
366
+ info="Lower values produce smoother but less controlled motion"
367
+ )
368
+
369
+ gr.Markdown("### 🎬 Motion Enhancement")
370
+
371
+ with gr.Row():
372
+ smooth_checkbox = gr.Checkbox(
373
+ label="Enable Smoothing (Experimental)",
374
+ value=True, # Changed to False due to CUDA issues
375
+ info="May cause errors on some systems. If errors occur, disable this option."
376
+ )
377
+
378
+ with gr.Row():
379
+ fps_slider = gr.Slider(
380
+ label="Target FPS",
381
+ minimum=24,
382
+ maximum=60,
383
+ step=6,
384
+ value=60,
385
+ info="Higher FPS for smoother motion (uses frame interpolation)"
386
+ )
387
+
388
+ submit_button = gr.Button("🎥 Generate Video", variant="primary", size="lg")
389
+
390
+ with gr.Column(scale=1):
391
+ gr.Markdown("### 📺 Output")
392
+ output_video = gr.Video(label="Generated Video")
393
+
394
+ # Processing status
395
+ with gr.Row():
396
+ gr.Markdown(
397
+ """
398
+ <div style="background-color: #f0f8ff; padding: 10px; border-radius: 5px; margin-top: 10px;">
399
+ <p style="margin: 0; font-size: 0.9em;">
400
+ <b>Tips for best results:</b><br>
401
+ • Use high-quality front-facing images<br>
402
+ • Clear audio without background noise<br>
403
+ • Enable smoothing for natural motion<br>
404
+ • Adjust CFG scale if motion seems stiff
405
+ </p>
406
+ </div>
407
+ """
408
+ )
409
+
410
+ gr.Markdown(
411
+ """
412
+ ---
413
+ ### ⚠️ **Disclaimer**
414
+ This project is intended for academic research, and we explicitly disclaim any responsibility for user-generated content.
415
+ Users are solely liable for their actions while using this generative model.
416
+
417
+ ### 🚀 **Enhancement Features**
418
+ - **Frame Smoothing**: Reduces jitter and improves transition between frames
419
+ - **Frame Interpolation**: Increases FPS for smoother motion
420
+ - **Optimized Audio Processing**: Better lip-sync with 24kHz sampling
421
+ - **Fine-tuned CFG Scale**: Better control over motion naturalness
422
+ """
423
+ )
424
+
425
+ # Examples section
426
+ gr.Examples(
427
+ examples=[
428
+ ["src/examples/reference_images/7.jpg", "src/examples/driving_audios/5.wav", "None", 1.0, False, 30],
429
+ ["src/examples/reference_images/7.jpg", "src/examples/driving_audios/5.wav", "Happy", 0.8, False, 30],
430
+ ["src/examples/reference_images/7.jpg", "src/examples/driving_audios/5.wav", "Sad", 1.2, False, 24],
431
+ ],
432
+ inputs=[source_image, driving_audio, emotion_dropdown, cfg_slider, smooth_checkbox, fps_slider],
433
+ outputs=output_video,
434
+ fn=generate_motion,
435
+ cache_examples=False,
436
+ label="Example Configurations"
437
+ )
438
+
439
+ submit_button.click(
440
+ fn=generate_motion,
441
+ inputs=[source_image, driving_audio, emotion_dropdown, cfg_slider, smooth_checkbox, fps_slider],
442
+ outputs=output_video
443
+ )
444
+
445
+ if __name__ == "__main__":
446
+ demo.launch(share=True)