seawolf2357 committed
Commit ebdf0db · verified · 1 parent: c94b781

Delete app-backup.py

Files changed (1)
  1. app-backup.py +0 -446
app-backup.py DELETED
@@ -1,446 +0,0 @@
# app.py

import os
import sys
import time
import gradio as gr
import spaces
from huggingface_hub import snapshot_download
from huggingface_hub.utils import GatedRepoError, RepositoryNotFoundError, RevisionNotFoundError
from pathlib import Path
import tempfile
from pydub import AudioSegment
import cv2

# Add the src directory to the system path to allow for local imports
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), 'src')))

from models.inference.moda_test import LiveVASAPipeline, emo_map, set_seed

# --- Configuration ---
# Set seed for reproducibility
set_seed(42)

# Paths and constants for the Gradio demo
DEFAULT_CFG_PATH = "configs/audio2motion/inference/inference.yaml"
DEFAULT_MOTION_MEAN_STD_PATH = "src/datasets/mean.pt"
DEFAULT_SILENT_AUDIO_PATH = "src/examples/silent-audio.wav"
OUTPUT_DIR = "gradio_output"
WEIGHTS_DIR = "pretrain_weights"
REPO_ID = "lixinyizju/moda"

# --- Download Pre-trained Weights from Hugging Face Hub ---
def download_weights():
    """
    Downloads pre-trained weights from Hugging Face Hub if they don't exist locally.
    """
    # A simple check for a key file to see if the download is likely complete
    motion_model_file = os.path.join(WEIGHTS_DIR, "moda", "net-200.pth")

    if not os.path.exists(motion_model_file):
        print(f"Weights not found locally. Downloading from Hugging Face Hub repo '{REPO_ID}'...")
        print("This may take a while depending on your internet connection.")
        try:
            snapshot_download(
                repo_id=REPO_ID,
                local_dir=WEIGHTS_DIR,
                local_dir_use_symlinks=False,  # Use False to copy files directly; safer for Windows
                resume_download=True,
            )
            print("Weights downloaded successfully.")
        except GatedRepoError:
            raise gr.Error(f"Access to the repository '{REPO_ID}' is gated. Please visit https://huggingface.co/{REPO_ID} to request access.")
        except (RepositoryNotFoundError, RevisionNotFoundError):
            raise gr.Error(f"The repository '{REPO_ID}' was not found. Please check the repository ID.")
        except Exception as e:
            print(f"An error occurred during download: {e}")
            raise gr.Error(f"Failed to download models. Please check your internet connection and try again. Error: {e}")
    else:
        print(f"Found existing weights at '{WEIGHTS_DIR}'. Skipping download.")

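# A minimal verification sketch (hedged): 'net-200.pth' is the only file the
# check above asserts; anything else about the repo layout would be an
# assumption. A caller could fail fast before loading the pipeline like so:
#
#   sentinel = Path(WEIGHTS_DIR) / "moda" / "net-200.pth"
#   if not sentinel.exists():
#       download_weights()
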
# --- Audio Conversion Function ---
def ensure_wav_format(audio_path):
    """
    Ensures the audio file is in WAV format. If not, converts it to WAV.
    Returns the path to the WAV file (either original or converted).
    """
    if audio_path is None:
        return None

    audio_path = Path(audio_path)

    # Check if already WAV
    if audio_path.suffix.lower() == '.wav':
        print(f"Audio is already in WAV format: {audio_path}")
        return str(audio_path)

    # Convert to WAV
    print(f"Converting audio from {audio_path.suffix} to WAV format...")

    try:
        # Load the audio file
        audio = AudioSegment.from_file(audio_path)

        # Create a temporary WAV file
        with tempfile.NamedTemporaryFile(suffix='.wav', delete=False) as tmp_file:
            wav_path = tmp_file.name
            # Export as WAV with a fixed sampling rate for better quality
            audio.export(
                wav_path,
                format='wav',
                parameters=["-ar", "24000", "-ac", "1"]  # 24 kHz, mono for better lip-sync
            )

        print(f"Audio converted successfully to: {wav_path}")
        return wav_path

    except Exception as e:
        print(f"Error converting audio: {e}")
        raise gr.Error(f"Failed to convert audio file to WAV format. Error: {e}")

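# Usage sketch for the converter above (assumes pydub can locate an ffmpeg
# binary; "speech.mp3" is a hypothetical input). It mirrors the temp-file
# bookkeeping that generate_motion does below:
#
#   wav_path = ensure_wav_format("speech.mp3")   # -> /tmp/tmpXXXX.wav (24 kHz mono)
#   is_temp = wav_path != "speech.mp3"
#   try:
#       ...  # run inference on wav_path
#   finally:
#       if is_temp:
#           os.remove(wav_path)
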
# --- Frame Interpolation Function ---
def interpolate_frames(video_path, target_fps=30):
    """
    Interpolates frames in a video to achieve smoother motion.

    Args:
        video_path: Path to the input video
        target_fps: Target frames per second

    Returns:
        Path to the interpolated video
    """
    try:
        video_path = str(video_path)
        cap = cv2.VideoCapture(video_path)

        # Get original video properties
        original_fps = cap.get(cv2.CAP_PROP_FPS)
        width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
        height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))

        print(f"Original FPS: {original_fps}, Target FPS: {target_fps}")

        # If target FPS is not higher, return original
        if original_fps >= target_fps:
            cap.release()
            print("Target FPS is not higher than original. Skipping interpolation.")
            return video_path

        # Read all frames
        frames = []
        while True:
            ret, frame = cap.read()
            if not ret:
                break
            frames.append(frame)
        cap.release()

        if len(frames) < 2:
            print("Not enough frames for interpolation.")
            return video_path

        # Calculate the integer interpolation factor
        interpolation_factor = int(target_fps / original_fps)

        # A factor below 2 would rewrite the same frames at a higher FPS and
        # speed up playback instead of smoothing it, so skip in that case.
        if interpolation_factor < 2:
            print("Interpolation factor is below 2. Skipping interpolation.")
            return video_path

        interpolated_frames = []

        print(f"Interpolating with factor: {interpolation_factor}")

        # Perform frame interpolation
        for i in range(len(frames) - 1):
            interpolated_frames.append(frames[i])

            # Generate intermediate frames
            for j in range(1, interpolation_factor):
                alpha = j / interpolation_factor
                # Use weighted average for simple interpolation
                interpolated_frame = cv2.addWeighted(
                    frames[i], 1 - alpha,
                    frames[i + 1], alpha,
                    0
                )
                interpolated_frames.append(interpolated_frame)

        # Add the last frame
        interpolated_frames.append(frames[-1])

        # Save the interpolated video. Write at the frame rate the inserted
        # frames actually represent so duration and playback speed are preserved.
        # NOTE: cv2.VideoWriter writes video only; the source audio track is not
        # carried over and would need to be re-muxed (e.g. with ffmpeg).
        effective_fps = original_fps * interpolation_factor
        output_path = video_path.replace('.mp4', '_interpolated.mp4')
        fourcc = cv2.VideoWriter_fourcc(*'mp4v')
        out = cv2.VideoWriter(output_path, fourcc, effective_fps, (width, height))

        for frame in interpolated_frames:
            out.write(frame)
        out.release()

        print(f"Interpolated video saved to: {output_path}")
        return output_path

    except Exception as e:
        print(f"Error during frame interpolation: {e}")
        return video_path  # Return original if interpolation fails

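# Worked example of the blend above (illustrative numbers): going from a 24 FPS
# source to a 60 FPS target gives int(60 / 24) = 2, so one intermediate frame
# is inserted per pair with alpha = 1/2, i.e. a plain average:
#
#   mid = cv2.addWeighted(frames[i], 0.5, frames[i + 1], 0.5, 0)
#
# and the output is written at 24 * 2 = 48 FPS. Linear blending cross-fades
# pixels rather than tracking motion, so fast movement can ghost; optical-flow
# interpolation would be the heavier alternative.
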
# --- Initialization ---
# Create output directory if it doesn't exist
os.makedirs(OUTPUT_DIR, exist_ok=True)

# Download weights before initializing the pipeline
download_weights()

# Instantiate the pipeline once to avoid reloading models on every request
print("Initializing MoDA pipeline...")
try:
    pipeline = LiveVASAPipeline(
        cfg_path=DEFAULT_CFG_PATH,
        motion_mean_std_path=DEFAULT_MOTION_MEAN_STD_PATH
    )
    print("MoDA pipeline initialized successfully.")
except Exception as e:
    print(f"Error initializing pipeline: {e}")
    pipeline = None

# Invert the emo_map for easy lookup from the dropdown value
emo_name_to_id = {v: k for k, v in emo_map.items()}

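# Sketch of the lookup above, assuming emo_map maps integer IDs to display
# names (contents here are illustrative, except 'None' = 8, which the default
# in generate_motion relies on):
#
#   emo_map          # e.g. {..., 8: 'None', ...}
#   emo_name_to_id   # e.g. {..., 'None': 8, ...}
#   emo_name_to_id.get("Happy", 8)  # falls back to 'None' for unknown names
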
# --- Core Generation Function ---
@spaces.GPU(duration=180)  # Increased duration for smoothing and interpolation
def generate_motion(source_image_path, driving_audio_path, emotion_name,
                    cfg_scale, smooth_enabled, target_fps,
                    progress=gr.Progress(track_tqdm=True)):
    """
    The main function that takes Gradio inputs and generates the talking head video.

    Args:
        source_image_path: Path to the source image
        driving_audio_path: Path to the driving audio
        emotion_name: Selected emotion
        cfg_scale: CFG scale for generation
        smooth_enabled: Whether to enable smoothing
        target_fps: Target frames per second for interpolation
    """
    if pipeline is None:
        raise gr.Error("Pipeline failed to initialize. Check the console logs for details.")

    if source_image_path is None:
        raise gr.Error("Please upload a source image.")
    if driving_audio_path is None:
        raise gr.Error("Please upload a driving audio file.")

    start_time = time.time()

    # Ensure audio is in WAV format with optimal sampling rate
    wav_audio_path = ensure_wav_format(driving_audio_path)
    temp_wav_created = wav_audio_path != driving_audio_path

    # Create a unique subdirectory for this run
    timestamp = time.strftime("%Y%m%d-%H%M%S")
    run_output_dir = os.path.join(OUTPUT_DIR, timestamp)
    os.makedirs(run_output_dir, exist_ok=True)

    # Get emotion ID from its name
    emotion_id = emo_name_to_id.get(emotion_name, 8)  # Default to 'None' (ID 8) if not found

    print("Starting generation with the following parameters:")
    print(f"  Source Image: {source_image_path}")
    print(f"  Driving Audio (original): {driving_audio_path}")
    print(f"  Driving Audio (WAV): {wav_audio_path}")
    print(f"  Emotion: {emotion_name} (ID: {emotion_id})")
    print(f"  CFG Scale: {cfg_scale}")
    print(f"  Smoothing: {smooth_enabled}")
    print(f"  Target FPS: {target_fps}")

    try:
        # Smoothing can trigger CUDA tensor-conversion errors on some systems,
        # so try the requested setting first and fall back gracefully.
        try:
            # Try with smoothing first
            result_video_path = pipeline.driven_sample(
                image_path=source_image_path,
                audio_path=wav_audio_path,
                cfg_scale=float(cfg_scale),
                emo=emotion_id,
                save_dir=".",
                smooth=smooth_enabled,  # Use the checkbox value
                silent_audio_path=DEFAULT_SILENT_AUDIO_PATH,
            )
        except TypeError as tensor_error:
            if "can't convert cuda" in str(tensor_error) and smooth_enabled:
                print("Warning: Smoothing caused CUDA tensor error. Retrying without smoothing...")
                # Retry without smoothing
                result_video_path = pipeline.driven_sample(
                    image_path=source_image_path,
                    audio_path=wav_audio_path,
                    cfg_scale=float(cfg_scale),
                    emo=emotion_id,
                    save_dir=".",
                    smooth=False,  # Disable smoothing as fallback
                    silent_audio_path=DEFAULT_SILENT_AUDIO_PATH,
                )
                print("Generated video without smoothing due to technical limitations.")
            else:
                raise tensor_error

        # Apply frame interpolation if requested
        if target_fps > 24:  # Assuming the pipeline's native output is around 24 FPS
            print(f"Applying frame interpolation to achieve {target_fps} FPS...")
            result_video_path = interpolate_frames(result_video_path, target_fps=target_fps)

    except Exception as e:
        print(f"An error occurred during video generation: {e}")
        import traceback
        traceback.print_exc()
        raise gr.Error(f"An unexpected error occurred: {str(e)}. Please check the console for details.")
    finally:
        # Clean up temporary WAV file if created
        if temp_wav_created and os.path.exists(wav_audio_path):
            try:
                os.remove(wav_audio_path)
                print(f"Cleaned up temporary WAV file: {wav_audio_path}")
            except Exception as e:
                print(f"Warning: Could not delete temporary file {wav_audio_path}: {e}")

    end_time = time.time()
    processing_time = end_time - start_time

    result_video_path = Path(result_video_path)
    final_path = result_video_path.with_name(f"final_{result_video_path.stem}{result_video_path.suffix}")

    print(f"Video generated successfully at: {final_path}")
    print(f"Processing time: {processing_time:.2f} seconds.")

    return final_path

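# Headless usage sketch (hedged): the function can in principle be called
# outside the UI with the same example assets the Gradio defaults use below.
# Note that the @spaces.GPU decorator assumes a ZeroGPU Space at runtime:
#
#   video_path = generate_motion(
#       source_image_path="src/examples/reference_images/7.jpg",
#       driving_audio_path="src/examples/driving_audios/5.wav",
#       emotion_name="Neutral",
#       cfg_scale=1.0,
#       smooth_enabled=False,
#       target_fps=30,
#   )
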
# --- Gradio UI Definition ---
with gr.Blocks(theme=gr.themes.Soft(), css=".gradio-container {max-width: 960px !important; margin: 0 auto !important}") as demo:
    gr.HTML(
        """
        <div align='center'>
            <h1>MoDA: Multi-modal Diffusion Architecture for Talking Head Generation</h1>
            <h2 style="color: #4A90E2;">Enhanced Version with Smooth Motion</h2>
            <p style="display:flex; justify-content: center; gap: 10px;">
                <a href='https://lixinyyang.github.io/MoDA.github.io/'><img src='https://img.shields.io/badge/Project-Page-blue'></a>
                <a href='https://arxiv.org/abs/2507.03256'><img src='https://img.shields.io/badge/Paper-Arxiv-red'></a>
                <a href='https://github.com/lixinyyang/MoDA/'><img src='https://img.shields.io/badge/Code-Github-green'></a>
            </p>
        </div>
        """
    )

    with gr.Row(variant="panel"):
        with gr.Column(scale=1):
            gr.Markdown("### 📥 Input Settings")

            with gr.Row():
                source_image = gr.Image(
                    label="Source Image",
                    type="filepath",
                    value="src/examples/reference_images/7.jpg"
                )

            with gr.Row():
                driving_audio = gr.Audio(
                    label="Driving Audio",
                    type="filepath",
                    value="src/examples/driving_audios/5.wav"
                )

            gr.Markdown("### ⚙️ Generation Settings")

            with gr.Row():
                emotion_dropdown = gr.Dropdown(
                    label="Emotion",
                    choices=list(emo_map.values()),
                    value="Neutral",
                    info="Select an emotion for more natural facial expressions"
                )

            with gr.Row():
                cfg_slider = gr.Slider(
                    label="CFG Scale (Lower = Smoother motion)",
                    minimum=0.5,
                    maximum=5.0,
                    step=0.1,
                    value=0.5,
                    info="Lower values produce smoother but less controlled motion"
                )

            gr.Markdown("### 🎬 Motion Enhancement")

            with gr.Row():
                smooth_checkbox = gr.Checkbox(
                    label="Enable Smoothing (Experimental)",
                    value=True,  # On by default; generate_motion falls back automatically on CUDA errors
                    info="May cause errors on some systems. If errors occur, disable this option."
                )

            with gr.Row():
                fps_slider = gr.Slider(
                    label="Target FPS",
                    minimum=24,
                    maximum=60,
                    step=6,
                    value=60,
                    info="Higher FPS for smoother motion (uses frame interpolation)"
                )

            submit_button = gr.Button("🎥 Generate Video", variant="primary", size="lg")

        with gr.Column(scale=1):
            gr.Markdown("### 📺 Output")
            output_video = gr.Video(label="Generated Video")

            # Processing status
            with gr.Row():
                gr.Markdown(
                    """
                    <div style="background-color: #f0f8ff; padding: 10px; border-radius: 5px; margin-top: 10px;">
                    <p style="margin: 0; font-size: 0.9em;">
                    <b>Tips for best results:</b><br>
                    • Use high-quality front-facing images<br>
                    • Clear audio without background noise<br>
                    • Enable smoothing for natural motion<br>
                    • Adjust CFG scale if motion seems stiff
                    </p>
                    </div>
                    """
                )

    gr.Markdown(
        """
        ---
        ### ⚠️ **Disclaimer**
        This project is intended for academic research, and we explicitly disclaim any responsibility for user-generated content.
        Users are solely liable for their actions while using this generative model.

        ### 🚀 **Enhancement Features**
        - **Frame Smoothing**: Reduces jitter and improves transitions between frames
        - **Frame Interpolation**: Increases FPS for smoother motion
        - **Optimized Audio Processing**: Better lip-sync with 24 kHz sampling
        - **Fine-tuned CFG Scale**: Better control over motion naturalness
        """
    )

    # Examples section
    gr.Examples(
        examples=[
            ["src/examples/reference_images/7.jpg", "src/examples/driving_audios/5.wav", "None", 1.0, False, 30],
            ["src/examples/reference_images/7.jpg", "src/examples/driving_audios/5.wav", "Happy", 0.8, False, 30],
            ["src/examples/reference_images/7.jpg", "src/examples/driving_audios/5.wav", "Sad", 1.2, False, 24],
        ],
        inputs=[source_image, driving_audio, emotion_dropdown, cfg_slider, smooth_checkbox, fps_slider],
        outputs=output_video,
        fn=generate_motion,
        cache_examples=False,
        label="Example Configurations"
    )

    submit_button.click(
        fn=generate_motion,
        inputs=[source_image, driving_audio, emotion_dropdown, cfg_slider, smooth_checkbox, fps_slider],
        outputs=output_video
    )

if __name__ == "__main__":
    demo.launch(share=True)
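
# Deployment note (a sketch of common alternatives, not the original call):
# share=True opens a temporary public gradio.live tunnel. On a self-hosted
# machine one would more likely bind locally and queue requests for GPU work:
#
#   demo.queue(max_size=20).launch(server_name="0.0.0.0", server_port=7860)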