seawolf2357 committed on
Commit
3e242b8
·
verified ·
1 Parent(s): f1a281a

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +189 -19
app.py CHANGED
@@ -10,6 +10,9 @@ from huggingface_hub.utils import GatedRepoError, RepositoryNotFoundError, Revis
10
  from pathlib import Path
11
  import tempfile
12
  from pydub import AudioSegment
 
 
 
13
 
14
  # Add the src directory to the system path to allow for local imports
15
  sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), 'src')))
@@ -83,11 +86,11 @@ def ensure_wav_format(audio_path):
83
  # Create a temporary WAV file
84
  with tempfile.NamedTemporaryFile(suffix='.wav', delete=False) as tmp_file:
85
  wav_path = tmp_file.name
86
- # Export as WAV with standard settings
87
  audio.export(
88
  wav_path,
89
  format='wav',
90
- parameters=["-ar", "16000", "-ac", "1"] # 16kHz, mono - adjust if your model needs different settings
91
  )
92
 
93
  print(f"Audio converted successfully to: {wav_path}")
@@ -97,6 +100,88 @@ def ensure_wav_format(audio_path):
97
  print(f"Error converting audio: {e}")
98
  raise gr.Error(f"Failed to convert audio file to WAV format. Error: {e}")
99
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
100
  # --- Initialization ---
101
  # Create output directory if it doesn't exist
102
  os.makedirs(OUTPUT_DIR, exist_ok=True)
@@ -120,10 +205,20 @@ except Exception as e:
120
  emo_name_to_id = {v: k for k, v in emo_map.items()}
121
 
122
  # --- Core Generation Function ---
123
- @spaces.GPU(duration=120)
124
- def generate_motion(source_image_path, driving_audio_path, emotion_name, cfg_scale, progress=gr.Progress(track_tqdm=True)):
 
 
125
  """
126
  The main function that takes Gradio inputs and generates the talking head video.
 
 
 
 
 
 
 
 
127
  """
128
  if pipeline is None:
129
  raise gr.Error("Pipeline failed to initialize. Check the console logs for details.")
@@ -135,7 +230,7 @@ def generate_motion(source_image_path, driving_audio_path, emotion_name, cfg_sca
135
 
136
  start_time = time.time()
137
 
138
- # Ensure audio is in WAV format
139
  wav_audio_path = ensure_wav_format(driving_audio_path)
140
  temp_wav_created = wav_audio_path != driving_audio_path
141
 
@@ -153,6 +248,8 @@ def generate_motion(source_image_path, driving_audio_path, emotion_name, cfg_sca
153
  print(f" Driving Audio (WAV): {wav_audio_path}")
154
  print(f" Emotion: {emotion_name} (ID: {emotion_id})")
155
  print(f" CFG Scale: {cfg_scale}")
 
 
156
 
157
  try:
158
  # Call the pipeline's inference method with the WAV audio
@@ -162,9 +259,15 @@ def generate_motion(source_image_path, driving_audio_path, emotion_name, cfg_sca
162
  cfg_scale=float(cfg_scale),
163
  emo=emotion_id,
164
  save_dir=".",
165
- smooth=False, # Smoothing can be slow, disable for a faster demo
166
  silent_audio_path=DEFAULT_SILENT_AUDIO_PATH,
167
  )
 
 
 
 
 
 
168
  except Exception as e:
169
  print(f"An error occurred during video generation: {e}")
170
  import traceback
@@ -180,7 +283,6 @@ def generate_motion(source_image_path, driving_audio_path, emotion_name, cfg_sca
180
  print(f"Warning: Could not delete temporary file {wav_audio_path}: {e}")
181
 
182
  end_time = time.time()
183
-
184
  processing_time = end_time - start_time
185
 
186
  result_video_path = Path(result_video_path)
@@ -197,7 +299,8 @@ with gr.Blocks(theme=gr.themes.Soft(), css=".gradio-container {max-width: 960px
197
  """
198
  <div align='center'>
199
  <h1>MoDA: Multi-modal Diffusion Architecture for Talking Head Generation</h1>
200
- <p style="display:flex">
 
201
  <a href='https://lixinyyang.github.io/MoDA.github.io/'><img src='https://img.shields.io/badge/Project-Page-blue'></a>
202
  <a href='https://arxiv.org/abs/2507.03256'><img src='https://img.shields.io/badge/Paper-Arxiv-red'></a>
203
  <a href='https://github.com/lixinyyang/MoDA/'><img src='https://img.shields.io/badge/Code-Github-green'></a>
@@ -208,8 +311,14 @@ with gr.Blocks(theme=gr.themes.Soft(), css=".gradio-container {max-width: 960px
208
 
209
  with gr.Row(variant="panel"):
210
  with gr.Column(scale=1):
 
 
211
  with gr.Row():
212
- source_image = gr.Image(label="Source Image", type="filepath", value="src/examples/reference_images/7.jpg")
 
 
 
 
213
 
214
  with gr.Row():
215
  driving_audio = gr.Audio(
@@ -218,38 +327,99 @@ with gr.Blocks(theme=gr.themes.Soft(), css=".gradio-container {max-width: 960px
218
  value="src/examples/driving_audios/5.wav"
219
  )
220
 
 
 
221
  with gr.Row():
222
  emotion_dropdown = gr.Dropdown(
223
  label="Emotion",
224
  choices=list(emo_map.values()),
225
- value="None"
 
226
  )
227
 
228
  with gr.Row():
229
  cfg_slider = gr.Slider(
230
- label="CFG Scale",
231
- minimum=1.0,
232
- maximum=3.0,
233
- step=0.05,
234
- value=1.2
 
 
 
 
 
 
 
 
 
 
235
  )
236
 
237
- submit_button = gr.Button("Generate Video", variant="primary")
 
 
 
 
 
 
 
 
 
 
238
 
239
  with gr.Column(scale=1):
 
240
  output_video = gr.Video(label="Generated Video")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
241
 
242
  gr.Markdown(
243
  """
244
  ---
245
- ### **Disclaimer**
246
- This project is intended for academic research, and we explicitly disclaim any responsibility for user-generated content. Users are solely liable for their actions while using this generative model.
 
 
 
 
 
 
 
247
  """
248
  )
249
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
250
  submit_button.click(
251
  fn=generate_motion,
252
- inputs=[source_image, driving_audio, emotion_dropdown, cfg_slider],
253
  outputs=output_video
254
  )
255
 
 
10
  from pathlib import Path
11
  import tempfile
12
  from pydub import AudioSegment
13
+ import cv2
14
+ import numpy as np
15
+ from scipy import interpolate
16
 
17
  # Add the src directory to the system path to allow for local imports
18
  sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), 'src')))
 
86
  # Create a temporary WAV file
87
  with tempfile.NamedTemporaryFile(suffix='.wav', delete=False) as tmp_file:
88
  wav_path = tmp_file.name
89
+ # Export as WAV with higher sampling rate for better quality
90
  audio.export(
91
  wav_path,
92
  format='wav',
93
+ parameters=["-ar", "24000", "-ac", "1"] # 24kHz, mono for better lip-sync
94
  )
95
 
96
  print(f"Audio converted successfully to: {wav_path}")
 
100
  print(f"Error converting audio: {e}")
101
  raise gr.Error(f"Failed to convert audio file to WAV format. Error: {e}")
102
 
103
# --- Frame Interpolation Function ---
def interpolate_frames(video_path, target_fps=30):
    """Increase a video's frame rate by inserting cross-faded frames.

    Reads every frame of ``video_path``, inserts ``factor - 1`` linearly
    blended frames between each consecutive pair, and writes the result
    next to the input with an ``_interpolated`` filename suffix.

    Args:
        video_path: Path to the input video (str or os.PathLike).
        target_fps: Desired frames per second; interpolation only runs when
            this exceeds the source FPS by at least 2x.

    Returns:
        Path to the interpolated video, or the original ``video_path`` if
        interpolation is skipped or fails.

    NOTE(review): cv2.VideoWriter emits a video-only stream, so the output
    loses the original audio track — confirm downstream playback/muxing
    re-attaches the audio (e.g. remux with ffmpeg after this step).
    """
    try:
        video_path = str(video_path)
        cap = cv2.VideoCapture(video_path)

        # Original video properties
        original_fps = cap.get(cv2.CAP_PROP_FPS)
        width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
        height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))

        print(f"Original FPS: {original_fps}, Target FPS: {target_fps}")

        # Bail out when FPS metadata is missing/zero, or no speed-up is needed.
        if original_fps <= 0 or original_fps >= target_fps:
            cap.release()
            print("Target FPS is not higher than original. Skipping interpolation.")
            return video_path

        # Read all frames into memory (videos here are short talking-head clips).
        frames = []
        while True:
            ret, frame = cap.read()
            if not ret:
                break
            frames.append(frame)
        cap.release()

        if len(frames) < 2:
            print("Not enough frames for interpolation.")
            return video_path

        # Whole-number up-sampling factor. The output is written at
        # original_fps * factor (NOT target_fps): writing N*factor frames at
        # target_fps would change the video duration whenever the FPS ratio
        # is not a whole number (e.g. 24 -> 30), desyncing the audio.
        interpolation_factor = int(target_fps / original_fps)
        if interpolation_factor < 2:
            # factor == 1 would rewrite the same frames at a higher FPS,
            # speeding the video up and breaking A/V sync — skip instead.
            print("FPS ratio below 2x; skipping interpolation to preserve sync.")
            return video_path
        output_fps = original_fps * interpolation_factor

        print(f"Interpolating with factor: {interpolation_factor}")

        # Insert (factor - 1) linearly blended frames between each pair.
        interpolated_frames = []
        for i in range(len(frames) - 1):
            interpolated_frames.append(frames[i])
            for j in range(1, interpolation_factor):
                alpha = j / interpolation_factor
                # Simple linear cross-fade between neighbouring frames.
                interpolated_frames.append(
                    cv2.addWeighted(frames[i], 1 - alpha, frames[i + 1], alpha, 0)
                )
        interpolated_frames.append(frames[-1])

        # Derive the output name robustly for any extension (str.replace on
        # '.mp4' would silently collide for non-.mp4 inputs).
        root, ext = os.path.splitext(video_path)
        output_path = f"{root}_interpolated{ext or '.mp4'}"
        fourcc = cv2.VideoWriter_fourcc(*'mp4v')
        out = cv2.VideoWriter(output_path, fourcc, output_fps, (width, height))

        for frame in interpolated_frames:
            out.write(frame)
        out.release()

        print(f"Interpolated video saved to: {output_path}")
        return output_path

    except Exception as e:
        print(f"Error during frame interpolation: {e}")
        return video_path  # Return original if interpolation fails
184
+
185
  # --- Initialization ---
186
  # Create output directory if it doesn't exist
187
  os.makedirs(OUTPUT_DIR, exist_ok=True)
 
205
  emo_name_to_id = {v: k for k, v in emo_map.items()}
206
 
207
  # --- Core Generation Function ---
208
+ @spaces.GPU(duration=180) # Increased duration for smoothing and interpolation
209
+ def generate_motion(source_image_path, driving_audio_path, emotion_name,
210
+ cfg_scale, smooth_enabled, target_fps,
211
+ progress=gr.Progress(track_tqdm=True)):
212
  """
213
  The main function that takes Gradio inputs and generates the talking head video.
214
+
215
+ Args:
216
+ source_image_path: Path to the source image
217
+ driving_audio_path: Path to the driving audio
218
+ emotion_name: Selected emotion
219
+ cfg_scale: CFG scale for generation
220
+ smooth_enabled: Whether to enable smoothing
221
+ target_fps: Target frames per second for interpolation
222
  """
223
  if pipeline is None:
224
  raise gr.Error("Pipeline failed to initialize. Check the console logs for details.")
 
230
 
231
  start_time = time.time()
232
 
233
+ # Ensure audio is in WAV format with optimal sampling rate
234
  wav_audio_path = ensure_wav_format(driving_audio_path)
235
  temp_wav_created = wav_audio_path != driving_audio_path
236
 
 
248
  print(f" Driving Audio (WAV): {wav_audio_path}")
249
  print(f" Emotion: {emotion_name} (ID: {emotion_id})")
250
  print(f" CFG Scale: {cfg_scale}")
251
+ print(f" Smoothing: {smooth_enabled}")
252
+ print(f" Target FPS: {target_fps}")
253
 
254
  try:
255
  # Call the pipeline's inference method with the WAV audio
 
259
  cfg_scale=float(cfg_scale),
260
  emo=emotion_id,
261
  save_dir=".",
262
+ smooth=smooth_enabled, # Use the checkbox value
263
  silent_audio_path=DEFAULT_SILENT_AUDIO_PATH,
264
  )
265
+
266
+ # Apply frame interpolation if requested
267
+ if target_fps > 24: # Assuming default is around 24 FPS
268
+ print(f"Applying frame interpolation to achieve {target_fps} FPS...")
269
+ result_video_path = interpolate_frames(result_video_path, target_fps=target_fps)
270
+
271
  except Exception as e:
272
  print(f"An error occurred during video generation: {e}")
273
  import traceback
 
283
  print(f"Warning: Could not delete temporary file {wav_audio_path}: {e}")
284
 
285
  end_time = time.time()
 
286
  processing_time = end_time - start_time
287
 
288
  result_video_path = Path(result_video_path)
 
299
  """
300
  <div align='center'>
301
  <h1>MoDA: Multi-modal Diffusion Architecture for Talking Head Generation</h1>
302
+ <h2 style="color: #4A90E2;">Enhanced Version with Smooth Motion</h2>
303
+ <p style="display:flex; justify-content: center; gap: 10px;">
304
  <a href='https://lixinyyang.github.io/MoDA.github.io/'><img src='https://img.shields.io/badge/Project-Page-blue'></a>
305
  <a href='https://arxiv.org/abs/2507.03256'><img src='https://img.shields.io/badge/Paper-Arxiv-red'></a>
306
  <a href='https://github.com/lixinyyang/MoDA/'><img src='https://img.shields.io/badge/Code-Github-green'></a>
 
311
 
312
  with gr.Row(variant="panel"):
313
  with gr.Column(scale=1):
314
+ gr.Markdown("### 📥 Input Settings")
315
+
316
  with gr.Row():
317
+ source_image = gr.Image(
318
+ label="Source Image",
319
+ type="filepath",
320
+ value="src/examples/reference_images/7.jpg"
321
+ )
322
 
323
  with gr.Row():
324
  driving_audio = gr.Audio(
 
327
  value="src/examples/driving_audios/5.wav"
328
  )
329
 
330
+ gr.Markdown("### ⚙️ Generation Settings")
331
+
332
  with gr.Row():
333
  emotion_dropdown = gr.Dropdown(
334
  label="Emotion",
335
  choices=list(emo_map.values()),
336
+ value="None",
337
+ info="Select an emotion for more natural facial expressions"
338
  )
339
 
340
  with gr.Row():
341
  cfg_slider = gr.Slider(
342
+ label="CFG Scale (Lower = Smoother motion)",
343
+ minimum=0.5,
344
+ maximum=5.0,
345
+ step=0.1,
346
+ value=1.0,
347
+ info="Lower values produce smoother but less controlled motion"
348
+ )
349
+
350
+ gr.Markdown("### 🎬 Motion Enhancement")
351
+
352
+ with gr.Row():
353
+ smooth_checkbox = gr.Checkbox(
354
+ label="Enable Smoothing",
355
+ value=True,
356
+ info="Enables frame smoothing for more natural motion (increases processing time)"
357
  )
358
 
359
+ with gr.Row():
360
+ fps_slider = gr.Slider(
361
+ label="Target FPS",
362
+ minimum=24,
363
+ maximum=60,
364
+ step=6,
365
+ value=30,
366
+ info="Higher FPS for smoother motion (uses frame interpolation)"
367
+ )
368
+
369
+ submit_button = gr.Button("🎥 Generate Video", variant="primary", size="lg")
370
 
371
  with gr.Column(scale=1):
372
+ gr.Markdown("### 📺 Output")
373
  output_video = gr.Video(label="Generated Video")
374
+
375
+ # Processing status
376
+ with gr.Row():
377
+ gr.Markdown(
378
+ """
379
+ <div style="background-color: #f0f8ff; padding: 10px; border-radius: 5px; margin-top: 10px;">
380
+ <p style="margin: 0; font-size: 0.9em;">
381
+ <b>Tips for best results:</b><br>
382
+ • Use high-quality front-facing images<br>
383
+ • Clear audio without background noise<br>
384
+ • Enable smoothing for natural motion<br>
385
+ • Adjust CFG scale if motion seems stiff
386
+ </p>
387
+ </div>
388
+ """
389
+ )
390
 
391
  gr.Markdown(
392
  """
393
  ---
394
+ ### ⚠️ **Disclaimer**
395
+ This project is intended for academic research, and we explicitly disclaim any responsibility for user-generated content.
396
+ Users are solely liable for their actions while using this generative model.
397
+
398
+ ### 🚀 **Enhancement Features**
399
+ - **Frame Smoothing**: Reduces jitter and improves transition between frames
400
+ - **Frame Interpolation**: Increases FPS for smoother motion
401
+ - **Optimized Audio Processing**: Better lip-sync with 24kHz sampling
402
+ - **Fine-tuned CFG Scale**: Better control over motion naturalness
403
  """
404
  )
405
 
406
+ # Examples section
407
+ gr.Examples(
408
+ examples=[
409
+ ["src/examples/reference_images/7.jpg", "src/examples/driving_audios/5.wav", "None", 1.0, True, 30],
410
+ ["src/examples/reference_images/7.jpg", "src/examples/driving_audios/5.wav", "Happy", 0.8, True, 30],
411
+ ["src/examples/reference_images/7.jpg", "src/examples/driving_audios/5.wav", "Sad", 1.2, True, 24],
412
+ ],
413
+ inputs=[source_image, driving_audio, emotion_dropdown, cfg_slider, smooth_checkbox, fps_slider],
414
+ outputs=output_video,
415
+ fn=generate_motion,
416
+ cache_examples=False,
417
+ label="Example Configurations"
418
+ )
419
+
420
  submit_button.click(
421
  fn=generate_motion,
422
+ inputs=[source_image, driving_audio, emotion_dropdown, cfg_slider, smooth_checkbox, fps_slider],
423
  outputs=output_video
424
  )
425