multimodalart HF Staff commited on
Commit
b040570
·
verified ·
1 Parent(s): 7758cff

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +215 -0
app.py ADDED
@@ -0,0 +1,215 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # app.py
2
+
3
+ import os
4
+ import sys
5
+ import time
6
+ import gradio as gr
7
+ from huggingface_hub import snapshot_download
8
+ from huggingface_hub.utils import GatedRepoError, RepositoryNotFoundError, RevisionNotFoundError
9
+ from pathlib import Path
10
+ import spaces
11
+
12
+ # Add the src directory to the system path to allow for local imports
13
+ sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), 'src')))
14
+
15
+ from models.inference.moda_test import LiveVASAPipeline, emo_map, set_seed
16
+
17
+ # --- Configuration ---
18
+ # Set seed for reproducibility
19
+ set_seed(42)
20
+
21
+ # Paths and constants for the Gradio demo
22
+ DEFAULT_CFG_PATH = "configs/audio2motion/inference/inference.yaml"
23
+ DEFAULT_MOTION_MEAN_STD_PATH = "src/datasets/mean.pt"
24
+ DEFAULT_SILENT_AUDIO_PATH = "src/examples/silent-audio.wav"
25
+ OUTPUT_DIR = "gradio_output"
26
+ WEIGHTS_DIR = "pretrain_weights"
27
+ REPO_ID = "lixinyizju/moda"
28
+
29
# --- Download Pre-trained Weights from Hugging Face Hub ---
def download_weights():
    """Fetch the pre-trained MoDA weights from the Hub if not present locally.

    Uses the main motion checkpoint (``moda/net-200.pth``) as a sentinel: if
    it exists under ``WEIGHTS_DIR`` the download is skipped entirely.

    Raises:
        gr.Error: if the repo is gated, not found, or the download fails.
    """
    # A simple check for a key file to see if the download is likely complete.
    motion_model_file = os.path.join(WEIGHTS_DIR, "moda", "net-200.pth")

    if os.path.exists(motion_model_file):
        print(f"Found existing weights at '{WEIGHTS_DIR}'. Skipping download.")
        return

    print(f"Weights not found locally. Downloading from Hugging Face Hub repo '{REPO_ID}'...")
    print("This may take a while depending on your internet connection.")
    try:
        # NOTE: the previously passed `local_dir_use_symlinks` and
        # `resume_download` kwargs are deprecated (and ignored) in recent
        # huggingface_hub releases: downloads into a `local_dir` always
        # materialize real files and resume automatically.
        snapshot_download(
            repo_id=REPO_ID,
            local_dir=WEIGHTS_DIR,
        )
        print("Weights downloaded successfully.")
    except GatedRepoError:
        raise gr.Error(f"Access to the repository '{REPO_ID}' is gated. Please visit https://huggingface.co/{REPO_ID} to request access.")
    except (RepositoryNotFoundError, RevisionNotFoundError):
        raise gr.Error(f"The repository '{REPO_ID}' was not found. Please check the repository ID.")
    except Exception as e:
        # Surface the underlying error in the UI while keeping the console log.
        print(f"An error occurred during download: {e}")
        raise gr.Error(f"Failed to download models. Please check your internet connection and try again. Error: {e}")
58
+
59
# --- Initialization ---
# The order of these module-level statements matters: the output directory and
# weights must exist before the pipeline is constructed.

# Create output directory if it doesn't exist.
os.makedirs(OUTPUT_DIR, exist_ok=True)

# Download weights before initializing the pipeline.
download_weights()

# Instantiate the pipeline once at import time to avoid reloading models on
# every request.
print("Initializing MoDA pipeline...")
try:
    pipeline = LiveVASAPipeline(
        cfg_path=DEFAULT_CFG_PATH,
        motion_mean_std_path=DEFAULT_MOTION_MEAN_STD_PATH
    )
    print("MoDA pipeline initialized successfully.")
except Exception as e:
    # Keep the app alive on failure; `pipeline is None` is checked later in
    # generate_motion, which then raises a user-visible gr.Error.
    print(f"Error initializing pipeline: {e}")
    pipeline = None

# Invert the emo_map (id -> name) so the dropdown's name can be mapped back
# to its numeric emotion ID.
emo_name_to_id = {v: k for k, v in emo_map.items()}
80
+
81
# --- Core Generation Function ---
@spaces.GPU
def generate_motion(source_image_path, driving_audio_path, emotion_name, cfg_scale, progress=gr.Progress(track_tqdm=True)):
    """Generate a talking-head video from a source image and driving audio.

    Args:
        source_image_path: Filepath of the reference portrait image.
        driving_audio_path: Filepath of the driving speech audio.
        emotion_name: Human-readable emotion label selected in the dropdown.
        cfg_scale: Classifier-free guidance scale passed to the pipeline.
        progress: Gradio progress tracker (tqdm bars are tracked automatically).

    Returns:
        pathlib.Path to the final rendered video file.

    Raises:
        gr.Error: when inputs are missing, the pipeline failed to load, or
            generation itself fails.
    """
    # Guard clauses: fail fast with user-visible errors.
    if pipeline is None:
        raise gr.Error("Pipeline failed to initialize. Check the console logs for details.")
    if source_image_path is None:
        raise gr.Error("Please upload a source image.")
    if driving_audio_path is None:
        raise gr.Error("Please upload a driving audio file.")

    t0 = time.time()

    # Create a unique, timestamped subdirectory for this run.
    # NOTE(review): this directory is created but `save_dir="."` is passed to
    # the pipeline below — presumably intentional for the Space; confirm.
    run_dir = os.path.join(OUTPUT_DIR, time.strftime("%Y%m%d-%H%M%S"))
    os.makedirs(run_dir, exist_ok=True)

    # Map the dropdown label back to its numeric ID; default to 'None' (ID 8).
    emo_id = emo_name_to_id.get(emotion_name, 8)

    print(f"Starting generation with the following parameters:")
    print(f" Source Image: {source_image_path}")
    print(f" Driving Audio: {driving_audio_path}")
    print(f" Emotion: {emotion_name} (ID: {emo_id})")
    print(f" CFG Scale: {cfg_scale}")

    try:
        # Run the pipeline's inference method.
        raw_video_path = pipeline.driven_sample(
            image_path=source_image_path,
            audio_path=driving_audio_path,
            cfg_scale=float(cfg_scale),
            emo=emo_id,
            save_dir=".",
            smooth=False,  # Smoothing can be slow, disable for a faster demo
            silent_audio_path=DEFAULT_SILENT_AUDIO_PATH,
        )
    except Exception as e:
        print(f"An error occurred during video generation: {e}")
        import traceback
        traceback.print_exc()
        raise gr.Error(f"An unexpected error occurred: {str(e)}. Please check the console for details.")

    elapsed = time.time() - t0

    # Derive the "final_" sibling of the pipeline's returned path — presumably
    # the pipeline writes its muxed result under that naming convention; verify
    # against driven_sample's implementation.
    raw_video_path = Path(raw_video_path)
    final_path = raw_video_path.with_name(f"final_{raw_video_path.stem}{raw_video_path.suffix}")

    print(f"Video generated successfully at: {final_path}")
    print(f"Processing time: {elapsed:.2f} seconds.")

    return final_path
139
+
140
# --- Gradio UI Definition ---
# Declarative layout: inputs (image, audio, emotion, CFG scale) in the left
# column, the generated video in the right column, followed by examples and a
# disclaimer. The click handler at the bottom wires the button to
# generate_motion.
with gr.Blocks(theme=gr.themes.Soft(), css=".gradio-container {max-width: 90% !important;}") as demo:
    # Header banner with project / paper / code badges.
    gr.HTML(
        """
        <div align='center'>
            <h1>MoDA: Multi-modal Diffusion Architecture for Talking Head Generation</h1>
            <p style="display:flex">
                <a href='https://lixinyyang.github.io/MoDA.github.io/'><img src='https://img.shields.io/badge/Project-Page-blue'></a>
                <a href='https://arxiv.org/abs/2507.03256'><img src='https://img.shields.io/badge/Paper-Arxiv-red'></a>
                <a href='https://github.com/lixinyyang/MoDA/'><img src='https://img.shields.io/badge/Code-Github-green'></a>
            </p>
            <p>
                This demo allows you to generate a talking head video from a source image and a driving audio file.
            </p>
        </div>
        """
    )

    with gr.Row(variant="panel"):
        with gr.Column(scale=1):
            with gr.Row():
                # type="filepath" delivers the value to generate_motion as a path string.
                source_image = gr.Image(label="Source Image", type="filepath", value="src/examples/reference_images/6.jpg")

            with gr.Row():
                driving_audio = gr.Audio(label="Driving Audio", type="filepath", value="src/examples/driving_audios/5.wav")

            with gr.Row():
                # Choices are the human-readable names from emo_map;
                # generate_motion maps the chosen name back to its numeric ID.
                emotion_dropdown = gr.Dropdown(
                    label="Emotion",
                    choices=list(emo_map.values()),
                    value="None"
                )

            with gr.Row():
                cfg_slider = gr.Slider(
                    label="CFG Scale",
                    minimum=1.0,
                    maximum=3.0,
                    step=0.05,
                    value=1.2
                )

            submit_button = gr.Button("Generate Video", variant="primary")

        with gr.Column(scale=1):
            output_video = gr.Video(label="Generated Video")

    gr.Markdown("## Examples")
    gr.Examples(
        examples=[
            ["src/examples/reference_images/monalisa.jpg", "src/examples/driving_audios/monalisa.wav", "None", 1.2],
            ["src/examples/reference_images/girl.png", "src/examples/driving_audios/girl.wav", "Happiness", 1.25],
            ["src/examples/reference_images/jobs.jpg", "src/examples/driving_audios/jobs.wav", "Neutral", 1.15],
        ],
        inputs=[source_image, driving_audio, emotion_dropdown, cfg_slider],
        outputs=output_video,
        fn=generate_motion,
        # Examples run the full GPU pipeline, so caching at build time is off.
        cache_examples=False,
    )

    gr.Markdown(
        """
        ---
        ### **Disclaimer**
        This project is intended for academic research, and we explicitly disclaim any responsibility for user-generated content. Users are solely liable for their actions while using this generative model.
        """
    )

    # Wire the submit button to the core generation function.
    submit_button.click(
        fn=generate_motion,
        inputs=[source_image, driving_audio, emotion_dropdown, cfg_slider],
        outputs=output_video
    )
213
+
214
if __name__ == "__main__":
    # share=True requests a public tunnel link — presumably intended for local
    # runs (it has no effect when already hosted on a Space); confirm.
    demo.launch(share=True)