|
|
|
|
|
import os |
|
import sys |
|
import time |
|
import gradio as gr |
|
import spaces |
|
from huggingface_hub import snapshot_download |
|
from huggingface_hub.utils import GatedRepoError, RepositoryNotFoundError, RevisionNotFoundError |
|
from pathlib import Path |
|
import tempfile |
|
from pydub import AudioSegment |
|
import cv2 |
|
import numpy as np |
|
from scipy import interpolate |
|
|
|
|
|
# Put this file's "src" directory at the front of sys.path (resolved relative
# to this file, not the working directory) so the project-local "models"
# package imported below can be found.
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), 'src')))

from models.inference.moda_test import LiveVASAPipeline, emo_map, set_seed

# Fixed seed for the project's RNGs (presumably so repeated runs with the same
# inputs give the same result — set_seed is defined in moda_test; confirm).
set_seed(42)

# Inference config and motion statistics file handed to LiveVASAPipeline.
DEFAULT_CFG_PATH = "configs/audio2motion/inference/inference.yaml"
DEFAULT_MOTION_MEAN_STD_PATH = "src/datasets/mean.pt"
# Passed to driven_sample() as silent_audio_path.
DEFAULT_SILENT_AUDIO_PATH = "src/examples/silent-audio.wav"
# Root folder for generated outputs; created at startup.
OUTPUT_DIR = "gradio_output"
# Local folder that download_weights() fills from the Hub repo below.
WEIGHTS_DIR = "pretrain_weights"
# Hugging Face Hub repository hosting the pretrained MoDA weights.
REPO_ID = "lixinyizju/moda"
|
|
|
|
|
def download_weights():
    """
    Make sure the MoDA checkpoints are available under WEIGHTS_DIR.

    The file ``<WEIGHTS_DIR>/moda/net-200.pth`` is used as the marker for an
    existing local copy. When it is absent, the full ``REPO_ID`` snapshot is
    downloaded from the Hugging Face Hub into WEIGHTS_DIR.

    Raises:
        gr.Error: if the repository is gated, cannot be found (repo or
            revision), or the download fails for any other reason.
    """
    marker_checkpoint = os.path.join(WEIGHTS_DIR, "moda", "net-200.pth")

    # Guard clause: nothing to do when the marker checkpoint is already there.
    if os.path.exists(marker_checkpoint):
        print(f"Found existing weights at '{WEIGHTS_DIR}'. Skipping download.")
        return

    print(f"Weights not found locally. Downloading from Hugging Face Hub repo '{REPO_ID}'...")
    print("This may take a while depending on your internet connection.")

    try:
        snapshot_download(
            repo_id=REPO_ID,
            local_dir=WEIGHTS_DIR,
            local_dir_use_symlinks=False,
            resume_download=True,
        )
    except GatedRepoError:
        # Checked first so a gated repo gets its specific, actionable message.
        raise gr.Error(f"Access to the repository '{REPO_ID}' is gated. Please visit https://huggingface.co/{REPO_ID} to request access.")
    except (RepositoryNotFoundError, RevisionNotFoundError):
        raise gr.Error(f"The repository '{REPO_ID}' was not found. Please check the repository ID.")
    except Exception as download_error:
        print(f"An error occurred during download: {download_error}")
        raise gr.Error(f"Failed to download models. Please check your internet connection and try again. Error: {download_error}")
    else:
        print("Weights downloaded successfully.")
|
|
|
|
|
def ensure_wav_format(audio_path):
    """
    Ensure the audio file is in WAV format, converting it if necessary.

    Files whose extension is already ``.wav`` (case-insensitive) are passed
    through untouched. Any other file is decoded with pydub and exported as a
    24 kHz mono WAV into a newly created temporary file. The caller owns that
    file and is responsible for deleting it.

    Args:
        audio_path: Path (str or os.PathLike) to the audio file, or None.

    Returns:
        None if ``audio_path`` is None. If the file is already WAV, the
        original path as a str, spelled exactly as given (no normalisation),
        so callers can tell "same file" from "new temporary file" by a plain
        comparison. Otherwise, the path of the new temporary WAV file.

    Raises:
        gr.Error: if decoding or exporting the audio fails. A partially
            written temporary file is removed first.
    """
    if audio_path is None:
        return None

    source = Path(audio_path)

    if source.suffix.lower() == '.wav':
        print(f"Audio is already in WAV format: {source}")
        # Return the caller's own spelling: str(Path(...)) would normalise
        # e.g. "./a.wav" to "a.wav", making the caller think a new temporary
        # file was created (and possibly delete the user's file).
        return os.fspath(audio_path)

    print(f"Converting audio from {source.suffix} to WAV format...")

    wav_path = None
    try:
        audio = AudioSegment.from_file(source)

        # delete=False: the file must outlive this handle; only its name is
        # needed, and export() writes to it by path.
        with tempfile.NamedTemporaryFile(suffix='.wav', delete=False) as tmp_file:
            wav_path = tmp_file.name

        # Resample to 24 kHz and downmix to mono via ffmpeg parameters.
        audio.export(
            wav_path,
            format='wav',
            parameters=["-ar", "24000", "-ac", "1"]
        )

        print(f"Audio converted successfully to: {wav_path}")
        return wav_path

    except Exception as e:
        print(f"Error converting audio: {e}")
        # Do not leak the (empty or partially written) temporary file.
        if wav_path is not None and os.path.exists(wav_path):
            try:
                os.remove(wav_path)
            except OSError as cleanup_error:
                print(f"Warning: Could not delete temporary file {wav_path}: {cleanup_error}")
        raise gr.Error(f"Failed to convert audio file to WAV format. Error: {e}") from e
|
|
|
|
|
def interpolate_frames(video_path, target_fps=30):
    """
    Raise a video's frame rate by cross-fading between consecutive frames.

    Between every pair of source frames, ``k - 1`` linearly blended frames are
    inserted, where ``k = int(target_fps / source_fps)``. The result is written
    at ``source_fps * k`` frames per second, so the clip's duration (and its
    sync with the audio) is kept to within one source frame. The output rate
    equals ``target_fps`` when it is a whole multiple of the source rate, and
    is otherwise the nearest such multiple below it.

    If the source contains an audio stream and an ``ffmpeg`` executable is on
    PATH, the audio is copied onto the interpolated video. Otherwise the
    output is video-only.

    Interpolation is best-effort. The input path is returned unchanged on any
    error, and also when interpolation would not add frames: the source FPS is
    unknown, k < 2, or the video has fewer than two frames.

    Args:
        video_path: Path (str or os.PathLike) to the input video.
        target_fps: Desired frames per second.

    Returns:
        str path to the interpolated video (``<stem>_interpolated.mp4`` next
        to the input), or the input path as a str if interpolation was skipped
        or failed.
    """
    video_path = str(video_path)
    cap = None
    try:
        cap = cv2.VideoCapture(video_path)
        if not cap.isOpened():
            print(f"Could not open video for interpolation: {video_path}")
            return video_path

        original_fps = cap.get(cv2.CAP_PROP_FPS)
        print(f"Original FPS: {original_fps}, Target FPS: {target_fps}")

        if original_fps <= 0:
            print("Could not determine the source FPS. Skipping interpolation.")
            return video_path

        if original_fps >= target_fps:
            print("Target FPS is not higher than original. Skipping interpolation.")
            return video_path

        interpolation_factor = int(target_fps / original_fps)
        if interpolation_factor < 2:
            # e.g. 25 -> 30 FPS: no whole frame fits between two source frames,
            # and re-timing the same frames would only speed the video up.
            print("Target FPS is less than twice the original. Skipping interpolation.")
            return video_path

        # Stream the video: only two source frames are held in memory at once.
        ok, previous = cap.read()
        ok_next, current = cap.read() if ok else (False, None)
        if not (ok and ok_next):
            print("Not enough frames for interpolation.")
            return video_path

        # Keep the original duration: k output frames per source frame.
        output_fps = original_fps * interpolation_factor
        print(f"Interpolating with factor: {interpolation_factor} (output FPS: {output_fps})")

        source = Path(video_path)
        output_path = str(source.with_name(f"{source.stem}_interpolated.mp4"))
        silent_path = str(source.with_name(f"{source.stem}_interpolated_noaudio.mp4"))

        # Size the writer from the decoded frames themselves; a mismatch with
        # the container's reported size would make the writer drop frames.
        height, width = previous.shape[:2]
        fourcc = cv2.VideoWriter_fourcc(*'mp4v')
        out = cv2.VideoWriter(silent_path, fourcc, output_fps, (width, height))
        if not out.isOpened():
            print(f"Could not open video writer for: {silent_path}")
            return video_path

        try:
            while True:
                out.write(previous)
                for j in range(1, interpolation_factor):
                    alpha = j / interpolation_factor
                    out.write(cv2.addWeighted(previous, 1 - alpha, current, alpha, 0))
                previous = current
                ok, current = cap.read()
                if not ok:
                    break
            out.write(previous)  # the last source frame has no successor
        finally:
            out.release()

        # cv2 writes video only; bring the source's audio back if possible.
        if _copy_audio_track(video_path, silent_path, output_path):
            os.remove(silent_path)
        else:
            os.replace(silent_path, output_path)

        print(f"Interpolated video saved to: {output_path}")
        return output_path

    except Exception as e:
        print(f"Error during frame interpolation: {e}")
        return video_path
    finally:
        if cap is not None:
            cap.release()


def _copy_audio_track(audio_source, video_only, output_path):
    """
    Mux the first audio stream of ``audio_source`` onto ``video_only``.

    Writes ``output_path`` with the video stream copied from ``video_only``
    and, if present, the first audio stream of ``audio_source`` re-encoded as
    AAC. The optional ``1:a:0?`` map means a source without audio yields a
    video-only copy rather than an error.

    Returns:
        True on success; False if ffmpeg is not on PATH or the mux fails.
    """
    import shutil
    import subprocess

    ffmpeg = shutil.which("ffmpeg")
    if ffmpeg is None:
        print("Warning: ffmpeg not found; interpolated video will have no audio.")
        return False

    command = [
        ffmpeg, "-y", "-loglevel", "error",
        "-i", video_only, "-i", audio_source,
        "-map", "0:v:0", "-map", "1:a:0?",
        "-c:v", "copy", "-c:a", "aac", "-shortest",
        output_path,
    ]
    result = subprocess.run(command, capture_output=True, text=True, check=False)
    if result.returncode != 0:
        print(f"Warning: could not copy audio onto interpolated video: {result.stderr.strip()}")
        return False
    return True
|
|
|
|
|
|
|
# Ensure the root output directory exists before the app starts.
os.makedirs(OUTPUT_DIR, exist_ok=True)

# Fetch the checkpoints at startup (no-op when they are already on disk).
download_weights()

print("Initializing MoDA pipeline...")
try:
    pipeline = LiveVASAPipeline(
        cfg_path=DEFAULT_CFG_PATH,
        motion_mean_std_path=DEFAULT_MOTION_MEAN_STD_PATH
    )
    print("MoDA pipeline initialized successfully.")
except Exception as e:
    # Deliberately non-fatal: the UI still launches, and generate_motion()
    # raises a user-facing error while pipeline is None. That error tells the
    # user to check the console, so print the full traceback, not just the
    # message.
    import traceback
    print(f"Error initializing pipeline: {e}")
    traceback.print_exc()
    pipeline = None

# Reverse lookup for the emotion dropdown: label -> emotion id.
emo_name_to_id = {v: k for k, v in emo_map.items()}
|
|
|
|
|
@spaces.GPU(duration=180)
def generate_motion(source_image_path, driving_audio_path, emotion_name,
                    cfg_scale, smooth_enabled, target_fps,
                    progress=gr.Progress(track_tqdm=True)):
    """
    Gradio handler: generate a talking-head video from an image and audio.

    Args:
        source_image_path: Path to the source image.
        driving_audio_path: Path to the driving audio. Non-WAV input is
            converted to a temporary WAV, which is deleted afterwards.
        emotion_name: Label selected in the emotion dropdown. Labels missing
            from ``emo_name_to_id`` fall back to emotion id 8.
        cfg_scale: CFG scale passed to the pipeline (converted to float).
        smooth_enabled: Whether to request smoothing. If that raises a
            TypeError mentioning "can't convert cuda", the run is retried
            once without smoothing.
        target_fps: Requested frame rate. Frame interpolation is attempted
            when it is above 24.
        progress: Gradio progress tracker (tracks tqdm bars).

    Returns:
        pathlib.Path of the video to display. This is the "final_"-prefixed
        sibling of the pipeline's output when that file exists, otherwise the
        raw output. It is then interpolated when requested and possible.

    Raises:
        gr.Error: if the pipeline is not initialised, an input is missing,
            audio conversion fails, or generation raises.
    """
    if pipeline is None:
        raise gr.Error("Pipeline failed to initialize. Check the console logs for details.")

    if source_image_path is None:
        raise gr.Error("Please upload a source image.")
    if driving_audio_path is None:
        raise gr.Error("Please upload a driving audio file.")

    start_time = time.time()

    wav_audio_path = ensure_wav_format(driving_audio_path)
    # Compare resolved locations, not raw strings. If the input was already
    # WAV but comes back spelled differently (e.g. "./a.wav" vs "a.wav"), a
    # string comparison would treat the user's own file as temporary and
    # delete it in the cleanup below.
    temp_wav_created = os.path.realpath(wav_audio_path) != os.path.realpath(driving_audio_path)

    emotion_id = emo_name_to_id.get(emotion_name, 8)

    print("Starting generation with the following parameters:")
    print(f"  Source Image: {source_image_path}")
    print(f"  Driving Audio (original): {driving_audio_path}")
    print(f"  Driving Audio (WAV): {wav_audio_path}")
    print(f"  Emotion: {emotion_name} (ID: {emotion_id})")
    print(f"  CFG Scale: {cfg_scale}")
    print(f"  Smoothing: {smooth_enabled}")
    print(f"  Target FPS: {target_fps}")

    # Arguments shared by the first attempt and the no-smoothing retry.
    # NOTE(review): outputs go to the current directory (save_dir="."), not
    # under OUTPUT_DIR.
    sample_kwargs = dict(
        image_path=source_image_path,
        audio_path=wav_audio_path,
        cfg_scale=float(cfg_scale),
        emo=emotion_id,
        save_dir=".",
        silent_audio_path=DEFAULT_SILENT_AUDIO_PATH,
    )

    try:
        try:
            result_video_path = pipeline.driven_sample(smooth=smooth_enabled, **sample_kwargs)
        except TypeError as tensor_error:
            # Only the known smoothing failure ("can't convert cuda ...",
            # presumably a GPU tensor converted to host data) is retried;
            # anything else is a genuine error.
            if not (smooth_enabled and "can't convert cuda" in str(tensor_error)):
                raise
            print("Warning: Smoothing caused CUDA tensor error. Retrying without smoothing...")
            result_video_path = pipeline.driven_sample(smooth=False, **sample_kwargs)
            print("Generated video without smoothing due to technical limitations.")
    except Exception as e:
        print(f"An error occurred during video generation: {e}")
        import traceback
        traceback.print_exc()
        raise gr.Error(f"An unexpected error occurred: {str(e)}. Please check the console for details.") from e
    finally:
        # Remove the converted WAV whether generation succeeded or not.
        if temp_wav_created and os.path.exists(wav_audio_path):
            try:
                os.remove(wav_audio_path)
                print(f"Cleaned up temporary WAV file: {wav_audio_path}")
            except Exception as cleanup_error:
                print(f"Warning: Could not delete temporary file {wav_audio_path}: {cleanup_error}")

    result_video_path = Path(result_video_path)
    # The video to show is the "final_"-prefixed sibling of the pipeline's
    # output (presumably the version with the audio track; verify against
    # LiveVASAPipeline.driven_sample). Fall back to the raw output rather
    # than handing Gradio a path that does not exist.
    final_path = result_video_path.with_name(f"final_{result_video_path.stem}{result_video_path.suffix}")
    if not final_path.exists():
        print(f"Warning: {final_path} not found; returning {result_video_path} instead.")
        final_path = result_video_path

    # Interpolate the file that is actually returned, so the returned path
    # always refers to a file that exists. interpolate_frames is best-effort
    # and returns its input on failure.
    if target_fps > 24:
        print(f"Applying frame interpolation to achieve {target_fps} FPS...")
        final_path = Path(interpolate_frames(final_path, target_fps=target_fps))

    processing_time = time.time() - start_time

    print(f"Video generated successfully at: {final_path}")
    print(f"Processing time: {processing_time:.2f} seconds.")

    return final_path
|
|
|
|
|
# Gradio UI. Component construction order inside each context manager defines
# the layout: an input/settings column and an output column side by side, then
# the disclaimer, example configurations and the button wiring.
with gr.Blocks(theme=gr.themes.Soft(), css=".gradio-container {max-width: 960px !important; margin: 0 auto !important}") as demo:
    # Header with project, paper and code links.
    gr.HTML(
        """
        <div align='center'>
            <h1>MoDA: Multi-modal Diffusion Architecture for Talking Head Generation</h1>
            <h2 style="color: #4A90E2;">Enhanced Version with Smooth Motion</h2>
            <p style="display:flex; justify-content: center; gap: 10px;">
                <a href='https://lixinyyang.github.io/MoDA.github.io/'><img src='https://img.shields.io/badge/Project-Page-blue'></a>
                <a href='https://arxiv.org/abs/2507.03256'><img src='https://img.shields.io/badge/Paper-Arxiv-red'></a>
                <a href='https://github.com/lixinyyang/MoDA/'><img src='https://img.shields.io/badge/Code-Github-green'></a>
            </p>
        </div>
        """
    )

    with gr.Row(variant="panel"):
        # Left column: inputs and generation settings.
        with gr.Column(scale=1):
            gr.Markdown("### 📥 Input Settings")

            with gr.Row():
                # type="filepath": the handler receives a path string.
                source_image = gr.Image(
                    label="Source Image",
                    type="filepath",
                    value="src/examples/reference_images/7.jpg"
                )

            with gr.Row():
                driving_audio = gr.Audio(
                    label="Driving Audio",
                    type="filepath",
                    value="src/examples/driving_audios/5.wav"
                )

            gr.Markdown("### ⚙️ Generation Settings")

            with gr.Row():
                # Choices are the labels of emo_map; generate_motion maps the
                # chosen label back to its id via emo_name_to_id.
                emotion_dropdown = gr.Dropdown(
                    label="Emotion",
                    choices=list(emo_map.values()),
                    value="Neutral",
                    info="Select an emotion for more natural facial expressions"
                )

            with gr.Row():
                cfg_slider = gr.Slider(
                    label="CFG Scale (Lower = Smoother motion)",
                    minimum=0.5,
                    maximum=5.0,
                    step=0.1,
                    value=0.5,
                    info="Lower values produce smoother but less controlled motion"
                )

            gr.Markdown("### 🎬 Motion Enhancement")

            with gr.Row():
                smooth_checkbox = gr.Checkbox(
                    label="Enable Smoothing (Experimental)",
                    value=True,
                    info="May cause errors on some systems. If errors occur, disable this option."
                )

            with gr.Row():
                # generate_motion only attempts interpolation above 24 FPS.
                fps_slider = gr.Slider(
                    label="Target FPS",
                    minimum=24,
                    maximum=60,
                    step=6,
                    value=60,
                    info="Higher FPS for smoother motion (uses frame interpolation)"
                )

            submit_button = gr.Button("🎥 Generate Video", variant="primary", size="lg")

        # Right column: generated video plus usage tips.
        with gr.Column(scale=1):
            gr.Markdown("### 📺 Output")
            output_video = gr.Video(label="Generated Video")

            with gr.Row():
                gr.Markdown(
                    """
                    <div style="background-color: #f0f8ff; padding: 10px; border-radius: 5px; margin-top: 10px;">
                        <p style="margin: 0; font-size: 0.9em;">
                            <b>Tips for best results:</b><br>
                            • Use high-quality front-facing images<br>
                            • Clear audio without background noise<br>
                            • Enable smoothing for natural motion<br>
                            • Adjust CFG scale if motion seems stiff
                        </p>
                    </div>
                    """
                )

    gr.Markdown(
        """
        ---
        ### ⚠️ **Disclaimer**
        This project is intended for academic research, and we explicitly disclaim any responsibility for user-generated content.
        Users are solely liable for their actions while using this generative model.

        ### 🚀 **Enhancement Features**
        - **Frame Smoothing**: Reduces jitter and improves transition between frames
        - **Frame Interpolation**: Increases FPS for smoother motion
        - **Optimized Audio Processing**: Better lip-sync with 24kHz sampling
        - **Fine-tuned CFG Scale**: Better control over motion naturalness
        """
    )

    # Example rows follow the order of `inputs`:
    # [image, audio, emotion label, cfg scale, smoothing, target fps].
    # NOTE(review): the dropdown only offers list(emo_map.values()); confirm
    # that "None", "Happy" and "Sad" are among those labels. Labels missing
    # from emo_name_to_id silently fall back to emotion id 8 in
    # generate_motion.
    gr.Examples(
        examples=[
            ["src/examples/reference_images/7.jpg", "src/examples/driving_audios/5.wav", "None", 1.0, False, 30],
            ["src/examples/reference_images/7.jpg", "src/examples/driving_audios/5.wav", "Happy", 0.8, False, 30],
            ["src/examples/reference_images/7.jpg", "src/examples/driving_audios/5.wav", "Sad", 1.2, False, 24],
        ],
        inputs=[source_image, driving_audio, emotion_dropdown, cfg_slider, smooth_checkbox, fps_slider],
        outputs=output_video,
        fn=generate_motion,
        cache_examples=False,
        label="Example Configurations"
    )

    # The input order must match generate_motion's positional parameters.
    submit_button.click(
        fn=generate_motion,
        inputs=[source_image, driving_audio, emotion_dropdown, cfg_slider, smooth_checkbox, fps_slider],
        outputs=output_video
    )


if __name__ == "__main__":
    demo.launch(share=True)