artificialguybr committed on
Commit
12f2db2
·
verified ·
1 Parent(s): a2cb8bc

Create app.py

Files changed (1)
  1. app.py +667 -0
app.py ADDED
@@ -0,0 +1,667 @@
import os
import spaces
import torch
import gradio as gr
import tempfile
import subprocess
import sys
from pathlib import Path
import datetime
import math
import random
import gc
import json
import numpy as np
from PIL import Image
from moviepy.editor import VideoFileClip, AudioFileClip
import librosa
from omegaconf import OmegaConf
from transformers import AutoTokenizer, Wav2Vec2Model, Wav2Vec2Processor
from diffusers import FlowMatchEulerDiscreteScheduler
from huggingface_hub import hf_hub_download, snapshot_download

def setup_repository():
    if not os.path.exists("echomimic_v3"):
        print("🔄 Cloning EchoMimicV3 repository...")
        subprocess.run([
            "git", "clone",
            "https://github.com/antgroup/echomimic_v3.git"
        ], check=True)
        print("✅ Repository cloned successfully")

    sys.path.insert(0, "echomimic_v3")
    print("✅ Repository added to Python path")

def download_models():
    print("📥 Downloading models...")
    os.makedirs("models", exist_ok=True)
    try:
        print("🔄 Downloading base model...")
        base_model_path = snapshot_download(
            repo_id="alibaba-pai/Wan2.1-Fun-V1.1-1.3B-InP",
            local_dir="models/Wan2.1-Fun-V1.1-1.3B-InP",
            local_dir_use_symlinks=False
        )
        print(f"✅ Base model downloaded to: {base_model_path}")

        print("🔄 Downloading EchoMimicV3 transformer...")
        os.makedirs("models/transformer", exist_ok=True)
        transformer_file = hf_hub_download(
            repo_id="BadToBest/EchoMimicV3",
            filename="transformer/diffusion_pytorch_model.safetensors",
            local_dir="models",
            local_dir_use_symlinks=False
        )
        print(f"✅ Transformer downloaded to: {transformer_file}")

        config_file = hf_hub_download(
            repo_id="BadToBest/EchoMimicV3",
            filename="transformer/config.json",
            local_dir="models",
            local_dir_use_symlinks=False
        )
        print(f"✅ Config downloaded to: {config_file}")

        print("🔄 Downloading Wav2Vec model...")
        wav2vec_path = snapshot_download(
            repo_id="facebook/wav2vec2-base-960h",
            local_dir="models/wav2vec2-base-960h",
            local_dir_use_symlinks=False
        )
        print(f"✅ Wav2Vec model downloaded to: {wav2vec_path}")

        print("✅ All models downloaded successfully!")
        return True

    except Exception as e:
        print(f"❌ Error downloading models: {e}")
        return False

def download_examples():
    print("📁 Downloading example files...")
    os.makedirs("examples", exist_ok=True)
    try:
        example_files = [
            "datasets/echomimicv3_demos/imgs/demo_ch_woman_04.png",
            "datasets/echomimicv3_demos/audios/demo_ch_woman_04.WAV",
            "datasets/echomimicv3_demos/prompts/demo_ch_woman_04.txt",
            "datasets/echomimicv3_demos/imgs/guitar_woman_01.png",
            "datasets/echomimicv3_demos/audios/guitar_woman_01.WAV",
            "datasets/echomimicv3_demos/prompts/guitar_woman_01.txt"
        ]
        repo_url = "https://github.com/antgroup/echomimic_v3/raw/main/"
        for file_path in example_files:
            try:
                import urllib.request
                filename = os.path.basename(file_path)
                local_path = f"examples/{filename}"
                if not os.path.exists(local_path):
                    print(f"🔄 Downloading {filename}...")
                    urllib.request.urlretrieve(f"{repo_url}{file_path}", local_path)
                    print(f"✅ Downloaded {filename}")
                else:
                    print(f"✅ {filename} already exists")
            except Exception as e:
                print(f"⚠️ Could not download {filename}: {e}")
        print("✅ Example files downloaded!")
        return True
    except Exception as e:
        print(f"❌ Error downloading examples: {e}")
        return False

setup_repository()

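# The `src.*` imports below resolve inside the cloned echomimic_v3 checkout,
# which setup_repository() has just prepended to sys.path; they will fail if
# the clone step above did not complete.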
from src.dist import set_multi_gpus_devices
from src.wan_vae import AutoencoderKLWan
from src.wan_image_encoder import CLIPModel
from src.wan_text_encoder import WanT5EncoderModel
from src.wan_transformer3d_audio import WanTransformerAudioMask3DModel
from src.pipeline_wan_fun_inpaint_audio import WanFunInpaintAudioPipeline
from src.utils import filter_kwargs, get_image_to_video_latent3, save_videos_grid
from src.fm_solvers import FlowDPMSolverMultistepScheduler
from src.fm_solvers_unipc import FlowUniPCMultistepScheduler
from src.cache_utils import get_teacache_coefficients
from src.face_detect import get_mask_coord

class ComprehensiveConfig:
    def __init__(self):
        self.ulysses_degree = 1
        self.ring_degree = 1
        self.fsdp_dit = False
        self.config_path = "echomimic_v3/config/config.yaml"
        self.model_name = "models/Wan2.1-Fun-V1.1-1.3B-InP"
        self.transformer_path = "models/transformer/diffusion_pytorch_model.safetensors"
        self.wav2vec_model_dir = "models/wav2vec2-base-960h"
        self.weight_dtype = torch.bfloat16
        self.sample_size = [768, 768]
        self.sampler_name = "Flow_DPM++"
        self.lora_weight = 1.0

config = ComprehensiveConfig()
pipeline = None
wav2vec_processor = None
wav2vec_model = None

def load_wav2vec_models(wav2vec_model_dir):
    print(f"🔄 Loading Wav2Vec models from {wav2vec_model_dir}...")
    try:
        processor = Wav2Vec2Processor.from_pretrained(wav2vec_model_dir)
        model = Wav2Vec2Model.from_pretrained(wav2vec_model_dir).eval()
        model.requires_grad_(False)
        print("✅ Wav2Vec models loaded successfully")
        return processor, model
    except Exception as e:
        print(f"❌ Error loading Wav2Vec models: {e}")
        raise

def extract_audio_features(audio_path, processor, model):
    try:
        sr = 16000
        audio_segment, sample_rate = librosa.load(audio_path, sr=sr)
        input_values = processor(audio_segment, sampling_rate=sample_rate, return_tensors="pt").input_values
        input_values = input_values.to(model.device)
        with torch.no_grad():
            features = model(input_values).last_hidden_state
        return features.squeeze(0)
    except Exception as e:
        print(f"❌ Error extracting audio features: {e}")
        raise

def get_sample_size(image, default_size):
    width, height = image.size
    original_area = width * height
    default_area = default_size[0] * default_size[1]
    if default_area < original_area:
        ratio = math.sqrt(original_area / default_area)
        width = width / ratio // 16 * 16
        height = height / ratio // 16 * 16
    else:
        width = width // 16 * 16
        height = height // 16 * 16
    return int(height), int(width)
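# Worked example (illustrative): for a 1920x1080 input and the default 768x768
# target, ratio = sqrt(1920*1080 / (768*768)) = 1.875, so the frame is scaled
# to 1024x576 and then snapped down to multiples of 16 (already exact here),
# preserving aspect ratio while bounding the pixel count.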

def get_ip_mask(coords):
    y1, y2, x1, x2, h, w = coords
    Y, X = torch.meshgrid(torch.arange(h), torch.arange(w), indexing='ij')
    mask = (Y.unsqueeze(-1) >= y1) & (Y.unsqueeze(-1) < y2) & (X.unsqueeze(-1) >= x1) & (X.unsqueeze(-1) < x2)
    mask = mask.reshape(-1)
    return mask.float()
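# Illustrative note: `coords` is expressed in 16-pixel latent patches. For a
# 768x768 sample (h = w = 48), a face box of y1, y2, x1, x2 = 12, 36, 16, 32
# produces a flat mask of length 48 * 48 = 2304 with ones over the 24x16 patch
# region covering the face; the actual values come from face detection.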

def initialize_models():
    global pipeline, wav2vec_processor, wav2vec_model, config
    print("🚀 Initializing EchoMimicV3 models...")
    try:
        if not download_models():
            raise Exception("Failed to download required models")
        download_examples()
        device = set_multi_gpus_devices(config.ulysses_degree, config.ring_degree)
        print(f"✅ Device set to: {device}")
        cfg = OmegaConf.load(config.config_path)
        print(f"✅ Config loaded from {config.config_path}")
        print("🔄 Loading transformer...")
        transformer = WanTransformerAudioMask3DModel.from_pretrained(
            os.path.join(config.model_name, cfg['transformer_additional_kwargs'].get('transformer_subpath', 'transformer')),
            transformer_additional_kwargs=OmegaConf.to_container(cfg['transformer_additional_kwargs']),
            torch_dtype=config.weight_dtype,
        )
        if config.transformer_path is not None and os.path.exists(config.transformer_path):
            print(f"🔄 Loading custom transformer weights from {config.transformer_path}...")
            from safetensors.torch import load_file
            state_dict = load_file(config.transformer_path)
            state_dict = state_dict.get("state_dict", state_dict)
            missing, unexpected = transformer.load_state_dict(state_dict, strict=False)
            print(f"✅ Custom transformer weights loaded - Missing: {len(missing)}, Unexpected: {len(unexpected)}")

        print("🔄 Loading VAE...")
        vae = AutoencoderKLWan.from_pretrained(
            os.path.join(config.model_name, cfg['vae_kwargs'].get('vae_subpath', 'vae')),
            additional_kwargs=OmegaConf.to_container(cfg['vae_kwargs']),
        ).to(config.weight_dtype)
        print("✅ VAE loaded")

        print("🔄 Loading tokenizer...")
        tokenizer = AutoTokenizer.from_pretrained(
            os.path.join(config.model_name, cfg['text_encoder_kwargs'].get('tokenizer_subpath', 'tokenizer')),
        )
        print("✅ Tokenizer loaded")

        print("🔄 Loading text encoder...")
        text_encoder = WanT5EncoderModel.from_pretrained(
            os.path.join(config.model_name, cfg['text_encoder_kwargs'].get('text_encoder_subpath', 'text_encoder')),
            additional_kwargs=OmegaConf.to_container(cfg['text_encoder_kwargs']),
            torch_dtype=config.weight_dtype,
        ).eval()
        print("✅ Text encoder loaded")

        print("🔄 Loading CLIP image encoder...")
        clip_image_encoder = CLIPModel.from_pretrained(
            os.path.join(config.model_name, cfg['image_encoder_kwargs'].get('image_encoder_subpath', 'image_encoder')),
        ).to(config.weight_dtype).eval()
        print("✅ CLIP image encoder loaded")

        print("🔄 Loading scheduler...")
        scheduler_cls_map = {
            "Flow": FlowMatchEulerDiscreteScheduler,
            "Flow_Unipc": FlowUniPCMultistepScheduler,
            "Flow_DPM++": FlowDPMSolverMultistepScheduler,
        }
        scheduler_cls = scheduler_cls_map.get(config.sampler_name, FlowDPMSolverMultistepScheduler)
        scheduler = scheduler_cls(**filter_kwargs(scheduler_cls, OmegaConf.to_container(cfg['scheduler_kwargs'])))
        print("✅ Scheduler loaded")

        print("🔄 Creating pipeline...")
        pipeline = WanFunInpaintAudioPipeline(
            transformer=transformer,
            vae=vae,
            tokenizer=tokenizer,
            text_encoder=text_encoder,
            scheduler=scheduler,
            clip_image_encoder=clip_image_encoder,
        )
        pipeline.to(device=device)
        print("✅ Pipeline created and moved to device")

        print("🔄 Loading Wav2Vec models...")
        wav2vec_processor, wav2vec_model = load_wav2vec_models(config.wav2vec_model_dir)
        wav2vec_model.to(device)
        print("✅ Wav2Vec models loaded")

        print("🎉 All models initialized successfully!")
        return True
    except Exception as e:
        print(f"❌ Model initialization failed: {str(e)}")
        import traceback
        traceback.print_exc()
        return False

@spaces.GPU(duration=120)
def generate_video(
    image_path,
    audio_path,
    prompt,
    negative_prompt,
    seed_param,
    num_inference_steps,
    guidance_scale,
    audio_guidance_scale,
    fps,
    partial_video_length,
    overlap_video_length,
    neg_scale,
    neg_steps,
    use_dynamic_cfg,
    use_dynamic_acfg,
    sampler_name,
    shift,
    audio_scale,
    use_un_ip_mask,
    enable_teacache,
    teacache_threshold,
    teacache_offload,
    num_skip_start_steps,
    enable_riflex,
    riflex_k,
    progress=gr.Progress(track_tqdm=True)
):
    global pipeline, wav2vec_processor, wav2vec_model, config

    progress(0, desc="Starting video generation...")

    if image_path is None: raise gr.Error("Please upload an image")
    if audio_path is None: raise gr.Error("Please upload an audio file")
    if not models_ready or pipeline is None: raise gr.Error("Models not initialized. Please restart the space.")

    device = pipeline.device

    if seed_param < 0:
        seed = random.randint(0, np.iinfo(np.int32).max)
    else:
        seed = int(seed_param)

    print(f"🎲 Using seed: {seed}")

    try:
        generator = torch.Generator(device=device).manual_seed(seed)
        ref_img_pil = Image.open(image_path).convert("RGB")
        print(f"📸 Image loaded: {ref_img_pil.size}")

        progress(0.1, desc="Detecting face...")
        try:
            y1, y2, x1, x2, h_, w_ = get_mask_coord(image_path)
            print("✅ Face detection successful")
        except Exception as e:
            print(f"⚠️ Face detection failed: {e}, using center crop")
            h_, w_ = ref_img_pil.size[1], ref_img_pil.size[0]
            y1, y2 = h_ // 4, 3 * h_ // 4
            x1, x2 = w_ // 4, 3 * w_ // 4

        progress(0.2, desc="Processing audio...")
        audio_clip = AudioFileClip(audio_path)
        audio_features = extract_audio_features(audio_path, wav2vec_processor, wav2vec_model)
        audio_embeds = audio_features.unsqueeze(0).to(device=device, dtype=config.weight_dtype)

        video_length = int(audio_clip.duration * fps)
        video_length = (
            int((video_length - 1) // pipeline.vae.config.temporal_compression_ratio * pipeline.vae.config.temporal_compression_ratio) + 1
            if video_length != 1 else 1
        )
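        # Illustrative note: the frame count is snapped to (multiple of the
        # VAE's temporal compression ratio) + 1. Assuming a ratio of 4, a
        # 4-second clip at 25 fps (100 frames) becomes (100 - 1) // 4 * 4 + 1 = 97.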
        print(f"🎥 Total video length: {video_length} frames")

        sample_height, sample_width = get_sample_size(ref_img_pil, config.sample_size)
        print(f"📏 Sample size: {sample_width}x{sample_height}")

        downratio = math.sqrt(sample_height * sample_width / h_ / w_)
        coords = (
            y1 * downratio // 16, y2 * downratio // 16,
            x1 * downratio // 16, x2 * downratio // 16,
            sample_height // 16, sample_width // 16,
        )
        ip_mask = get_ip_mask(coords).unsqueeze(0)
        ip_mask = torch.cat([ip_mask]*3).to(device=device, dtype=config.weight_dtype)

        if enable_riflex:
            latent_frames = (video_length - 1) // pipeline.vae.config.temporal_compression_ratio + 1
            pipeline.transformer.enable_riflex(k=riflex_k, L_test=latent_frames)

        if enable_teacache:
            try:
                coefficients = get_teacache_coefficients(config.model_name)
                if coefficients:
                    pipeline.transformer.enable_teacache(
                        coefficients, num_inference_steps, teacache_threshold,
                        num_skip_start_steps=num_skip_start_steps,
                        offload=teacache_offload
                    )
                    print("✅ TeaCache enabled for this run")
            except Exception as e:
                print(f"⚠️ Could not enable TeaCache: {e}")

        init_frames = 0
        new_sample = None
        ref_img_for_loop = ref_img_pil
        total_chunks = math.ceil(video_length / (partial_video_length - overlap_video_length)) if video_length > partial_video_length else 1
        chunk_num = 0

        while init_frames < video_length:
            chunk_num += 1
            progress(0.3 + (0.6 * (chunk_num / total_chunks)), desc=f"Generating chunk {chunk_num}/{total_chunks}...")

            current_partial_length = min(partial_video_length, video_length - init_frames)
            current_partial_length = (
                int((current_partial_length - 1) // pipeline.vae.config.temporal_compression_ratio * pipeline.vae.config.temporal_compression_ratio) + 1
                if current_partial_length > 1 else 1
            )
            if current_partial_length <= 0: break

            input_video, input_video_mask, clip_image = get_image_to_video_latent3(
                ref_img_for_loop, None, video_length=current_partial_length,
                sample_size=[sample_height, sample_width]
            )

            audio_start_frame = init_frames * 2
            audio_end_frame = (init_frames + current_partial_length) * 2
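            # Illustrative note: wav2vec2 on 16 kHz audio emits one feature
            # roughly every 20 ms (~50 per second), i.e. about two features per
            # video frame at 25 fps, hence the factor of 2 in the indices above.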

            # Ensure audio embeds are long enough
            if audio_embeds.shape[1] < audio_end_frame:
                repeat_times = (audio_end_frame // audio_embeds.shape[1]) + 1
                audio_embeds = audio_embeds.repeat(1, repeat_times, 1)

            partial_audio_embeds = audio_embeds[:, audio_start_frame:audio_end_frame]

            with torch.no_grad():
                sample = pipeline(
                    prompt,
                    num_frames=current_partial_length,
                    negative_prompt=negative_prompt,
                    audio_embeds=partial_audio_embeds,
                    audio_scale=audio_scale,
                    ip_mask=ip_mask,
                    use_un_ip_mask=use_un_ip_mask,
                    height=sample_height,
                    width=sample_width,
                    generator=generator,
                    neg_scale=neg_scale,
                    neg_steps=neg_steps,
                    use_dynamic_cfg=use_dynamic_cfg,
                    use_dynamic_acfg=use_dynamic_acfg,
                    guidance_scale=guidance_scale,
                    audio_guidance_scale=audio_guidance_scale,
                    num_inference_steps=num_inference_steps,
                    video=input_video,
                    mask_video=input_video_mask,
                    clip_image=clip_image,
                    shift=shift,
                ).videos

            if new_sample is None:
                new_sample = sample
            else:
                mix_ratio = torch.linspace(0, 1, steps=overlap_video_length, device=device).view(1, 1, -1, 1, 1).to(new_sample.dtype)
                new_sample[:, :, -overlap_video_length:] = (
                    new_sample[:, :, -overlap_video_length:] * (1 - mix_ratio) +
                    sample[:, :, :overlap_video_length] * mix_ratio
                )
                new_sample = torch.cat([new_sample, sample[:, :, overlap_video_length:]], dim=2)
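                # Illustrative note: the overlapping frames of consecutive
                # chunks are blended with a linear crossfade (weights running
                # 0 -> 1 across the overlap) so the chunks join without a seam.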

            if new_sample.shape[2] >= video_length:
                break

            ref_img_for_loop = [
                Image.fromarray(
                    (new_sample[0, :, i].permute(1, 2, 0).cpu().numpy() * 255).astype(np.uint8)
                ) for i in range(-overlap_video_length, 0)
            ]

            init_frames += current_partial_length - overlap_video_length

        progress(0.9, desc="Stitching video and audio...")
        final_sample = new_sample[:, :, :video_length]

        with tempfile.NamedTemporaryFile(suffix=".mp4", delete=False) as tmp_file:
            video_path = tmp_file.name
        with tempfile.NamedTemporaryFile(suffix="_audio.mp4", delete=False) as tmp_file:
            video_audio_path = tmp_file.name

        save_videos_grid(final_sample, video_path, fps=fps)

        video_clip_final = VideoFileClip(video_path)
        audio_clip_trimmed = audio_clip.subclip(0, final_sample.shape[2] / fps)
        final_video = video_clip_final.set_audio(audio_clip_trimmed)
        final_video.write_videofile(video_audio_path, codec="libx264", audio_codec="aac", threads=4, logger=None)

        video_clip_final.close()
        audio_clip.close()
        audio_clip_trimmed.close()
        final_video.close()

        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
            torch.cuda.ipc_collect()

        progress(1.0, desc="Generation complete!")
        return video_audio_path, seed

    except Exception as e:
        print(f"❌ Generation error: {str(e)}")
        import traceback
        traceback.print_exc()
        raise gr.Error(f"Generation failed: {str(e)}")


def create_demo():
    with gr.Blocks(theme=gr.themes.Soft(), title="EchoMimicV3 Demo") as demo:
        gr.Markdown("""
# 🎭 EchoMimicV3: Audio-Driven Human Animation

Transform a portrait photo into a talking video! Upload an image and an audio file to create lifelike, expressive animations. This demo showcases the power of the EchoMimicV3 model.

**Key Features:**
- 🎯 **High-Quality Lip Sync:** Accurate mouth movements that match the input audio.
- 🎨 **Natural Facial Expressions:** Generates subtle and natural facial emotions.
- 🎵 **Speech & Singing:** Works with both spoken word and singing.
- ⚡ **Efficient:** Powered by a compact 1.3B parameter model.
""")

        if not models_ready:
            gr.Warning("Models are still loading. The UI is disabled. Please wait and refresh the page if necessary.")

        with gr.Row():
            with gr.Column(scale=1):
                image_input = gr.Image(
                    label="📸 Upload Portrait Image",
                    type="filepath",
                    sources=["upload"],
                    height=400,
                    info="Use a clear, front-facing portrait photo for best results."
                )
                audio_input = gr.Audio(
                    label="🎵 Upload Audio",
                    type="filepath",
                    sources=["upload"],
                    info="Clear speech or singing without background noise works best."
                )

                with gr.Accordion("📝 Text Prompts", open=True):
                    prompt = gr.Textbox(
                        label="✍️ Prompt",
                        value="A person talking naturally with clear expressions.",
                        info="Describe the desired animation. Can influence style and expression."
                    )
                    negative_prompt = gr.Textbox(
                        label="🚫 Negative Prompt",
                        value="Gesture is bad, unclear. Strange, twisted, bad, blurry hands and fingers.",
                        lines=2,
                        info="Describe what to avoid. Helps prevent artifacts."
                    )

            with gr.Column(scale=1):
                video_output = gr.Video(
                    label="🎥 Generated Video",
                    interactive=False,
                    height=400
                )
                seed_output = gr.Number(
                    label="🎲 Used Seed",
                    interactive=False,
                    precision=0
                )

        with gr.Accordion("⚙️ Advanced Settings", open=False):
            with gr.Row():
                with gr.Column():
                    gr.Markdown("### Core Generation Parameters")
                    seed_param = gr.Number(label="🎲 Seed", value=-1, precision=0, info="-1 for random seed.")
                    num_inference_steps = gr.Slider(label="Inference Steps", minimum=5, maximum=50, value=20, step=1, info="More steps can improve quality but take longer. 15-25 is a good range.")
                    fps = gr.Slider(label="Frames Per Second (FPS)", minimum=10, maximum=30, value=25, step=1, info="Controls the smoothness of the output video.")
                with gr.Column():
                    gr.Markdown("### Classifier-Free Guidance (CFG)")
                    guidance_scale = gr.Slider(label="Text Guidance Scale (CFG)", minimum=1.0, maximum=10.0, value=4.5, step=0.1, info="How strongly to follow the text prompt. Recommended: 3.0-6.0.")
                    audio_guidance_scale = gr.Slider(label="Audio Guidance Scale (aCFG)", minimum=1.0, maximum=10.0, value=2.5, step=0.1, info="How strongly to follow the audio for lip sync. Recommended: 2.0-3.0.")
                    use_dynamic_cfg = gr.Checkbox(label="Use Dynamic Text CFG", value=True, info="Gradually adjusts CFG during generation, can improve quality.")
                    use_dynamic_acfg = gr.Checkbox(label="Use Dynamic Audio aCFG", value=True, info="Gradually adjusts aCFG during generation, can improve quality.")

            with gr.Row():
                with gr.Column():
                    gr.Markdown("### Performance & VRAM (Chunking)")
                    partial_video_length = gr.Slider(label="Partial Video Length (Chunk Size)", minimum=49, maximum=161, value=113, step=16, info="Key for VRAM usage. 24G VRAM: ~113, 16G: ~81, 12G: ~49. Lower values use less memory but may affect consistency.")
                    overlap_video_length = gr.Slider(label="Overlap Length", minimum=4, maximum=16, value=8, step=1, info="How many frames to overlap between chunks for smooth transitions.")
                with gr.Column():
                    gr.Markdown("### Sampler & Scheduler")
                    sampler_name = gr.Dropdown(label="Sampler", choices=["Flow", "Flow_Unipc", "Flow_DPM++"], value="Flow_DPM++", info="Algorithm for the diffusion process.")
                    shift = gr.Slider(label="Scheduler Shift", minimum=1.0, maximum=10.0, value=5.0, step=0.1, info="Adjusts the noise schedule. Optimal range depends on the sampler.")
                    audio_scale = gr.Slider(label="Audio Scale", minimum=0.5, maximum=2.0, value=1.0, step=0.1, info="Global scale for audio feature influence.")
                    use_un_ip_mask = gr.Checkbox(label="Use Un-IP Mask", value=False, info="Inverts the inpainting mask.")

            with gr.Row():
                with gr.Column():
                    gr.Markdown("### Negative Guidance (Advanced CFG)")
                    neg_scale = gr.Slider(label="Negative Scale", minimum=1.0, maximum=5.0, value=1.5, step=0.1, info="Strength of negative prompt in early steps.")
                    neg_steps = gr.Slider(label="Negative Steps", minimum=0, maximum=10, value=2, step=1, info="How many initial steps to apply the negative scale.")

        with gr.Accordion("🔬 Experimental Settings", open=False):
            with gr.Row():
                with gr.Column():
                    gr.Markdown("### TeaCache (Performance Boost)")
                    enable_teacache = gr.Checkbox(label="Enable TeaCache", value=True)
                    teacache_threshold = gr.Slider(label="TeaCache Threshold", minimum=0.0, maximum=0.2, value=0.1, step=0.01)
                    teacache_offload = gr.Checkbox(label="TeaCache Offload", value=True)
                with gr.Column():
                    gr.Markdown("### Riflex (Consistency)")
                    enable_riflex = gr.Checkbox(label="Enable Riflex", value=False)
                    riflex_k = gr.Slider(label="Riflex K", minimum=1, maximum=10, value=6, step=1)
                with gr.Column():
                    gr.Markdown("### Other")
                    num_skip_start_steps = gr.Slider(label="Num Skip Start Steps", minimum=0, maximum=10, value=5, step=1)

        generate_button = gr.Button(
            "🎬 Generate Video",
            variant='primary',
            size="lg",
            interactive=models_ready
        )

        all_inputs = [
            image_input, audio_input, prompt, negative_prompt, seed_param,
            num_inference_steps, guidance_scale, audio_guidance_scale, fps,
            partial_video_length, overlap_video_length, neg_scale, neg_steps,
            use_dynamic_cfg, use_dynamic_acfg, sampler_name, shift, audio_scale,
            use_un_ip_mask, enable_teacache, teacache_threshold, teacache_offload,
            num_skip_start_steps, enable_riflex, riflex_k
        ]

        if models_ready:
            generate_button.click(
                fn=generate_video,
                inputs=all_inputs,
                outputs=[video_output, seed_output]
            )

        gr.Markdown("---")
        gr.Markdown("### ✨ Click to Try Examples")

        gr.Examples(
            examples=[
                [
                    "examples/demo_ch_woman_04.png",
                    "examples/demo_ch_woman_04.WAV",
                    "A Chinese woman is talking naturally.",
                    "bad gestures, blurry, distorted face",
                    42, 20, 4.5, 2.5, 25, 113, 8, 1.5, 2, True, True, "Flow_DPM++", 5.0, 1.0, False, True, 0.1, True, 5, False, 6
                ],
                [
                    "examples/guitar_woman_01.png",
                    "examples/guitar_woman_01.WAV",
                    "A woman with glasses is singing and playing the guitar.",
                    "blurry, distorted face, bad hands",
                    123, 25, 5.0, 2.8, 25, 113, 8, 1.5, 2, True, True, "Flow_DPM++", 5.0, 1.0, False, True, 0.1, True, 5, False, 6
                ],
            ],
            inputs=all_inputs,
            outputs=[video_output, seed_output],
            fn=generate_video,
            cache_examples=True,
            label=None,
        )

        gr.Markdown("---")
        gr.Markdown("""
### 📋 How to Use
1. **Upload Image:** Choose a clear portrait photo (front-facing works best).
2. **Upload Audio:** Add an audio file with clear speech or singing.
3. **Adjust Settings (Optional):** Fine-tune parameters in the advanced sections for different results. For memory issues, try lowering the "Partial Video Length".
4. **Generate:** Click the button and wait for your talking video!

**Note:** Generation time depends on settings and audio length. It can take a few minutes.

This demo is based on the [EchoMimicV3 repository](https://github.com/antgroup/echomimic_v3).
""")

    return demo

if __name__ == "__main__":
    print("🔄 Starting model initialization...")
    models_ready = initialize_models()

    demo = create_demo()
    demo.launch(share=True)