import os
import spaces
import torch
import gradio as gr
import tempfile
import subprocess
import sys
from pathlib import Path
import datetime
import math
import random
import gc
import json
import numpy as np
from PIL import Image
from moviepy import *
import librosa
from omegaconf import OmegaConf
from transformers import AutoTokenizer, Wav2Vec2Model, Wav2Vec2Processor
from diffusers import FlowMatchEulerDiscreteScheduler
from huggingface_hub import hf_hub_download, snapshot_download


def setup_repository():
    if not os.path.exists("echomimic_v3"):
        print("📦 Cloning EchoMimicV3 repository...")
        subprocess.run([
            "git", "clone",
            "https://github.com/antgroup/echomimic_v3.git"
        ], check=True)
        print("✅ Repository cloned successfully")
    sys.path.insert(0, "echomimic_v3")
    print("✅ Repository added to Python path")


def download_models():
    print("📥 Downloading models...")
    os.makedirs("models", exist_ok=True)
    try:
        print("📦 Downloading base model...")
        base_model_path = snapshot_download(
            repo_id="alibaba-pai/Wan2.1-Fun-V1.1-1.3B-InP",
            local_dir="models/Wan2.1-Fun-V1.1-1.3B-InP",
            local_dir_use_symlinks=False
        )
        print(f"✅ Base model downloaded to: {base_model_path}")
        print("📦 Downloading EchoMimicV3 transformer...")
        os.makedirs("models/transformer", exist_ok=True)
        transformer_file = hf_hub_download(
            repo_id="BadToBest/EchoMimicV3",
            filename="transformer/diffusion_pytorch_model.safetensors",
            local_dir="models",
            local_dir_use_symlinks=False
        )
        print(f"✅ Transformer downloaded to: {transformer_file}")
        config_file = hf_hub_download(
            repo_id="BadToBest/EchoMimicV3",
            filename="transformer/config.json",
            local_dir="models",
            local_dir_use_symlinks=False
        )
        print(f"✅ Config downloaded to: {config_file}")
        print("📦 Downloading Wav2Vec model...")
        wav2vec_path = snapshot_download(
            repo_id="facebook/wav2vec2-base-960h",
            local_dir="models/wav2vec2-base-960h",
            local_dir_use_symlinks=False
        )
        print(f"✅ Wav2Vec model downloaded to: {wav2vec_path}")
        print("✅ All models downloaded successfully!")
        return True
    except Exception as e:
        print(f"❌ Error downloading models: {e}")
        return False
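
# The downloads above should produce the following local layout (inferred from the
# calls, not verified against the hub repos), which must stay in sync with the
# paths in ComprehensiveConfig below:
#   models/Wan2.1-Fun-V1.1-1.3B-InP/                          <- Wan2.1-Fun base (VAE, text/image encoders, tokenizer)
#   models/transformer/diffusion_pytorch_model.safetensors    <- EchoMimicV3 transformer weights
#   models/transformer/config.json
#   models/wav2vec2-base-960h/                                <- audio feature extractor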


def download_examples():
    print("📁 Downloading example files...")
    os.makedirs("examples", exist_ok=True)
    try:
        example_files = [
            "datasets/echomimicv3_demos/imgs/demo_ch_woman_04.png",
            "datasets/echomimicv3_demos/audios/demo_ch_woman_04.WAV",
            "datasets/echomimicv3_demos/prompts/demo_ch_woman_04.txt",
            "datasets/echomimicv3_demos/imgs/guitar_woman_01.png",
            "datasets/echomimicv3_demos/audios/guitar_woman_01.WAV",
            "datasets/echomimicv3_demos/prompts/guitar_woman_01.txt"
        ]
        repo_url = "https://github.com/antgroup/echomimic_v3/raw/main/"
        for file_path in example_files:
            try:
                import urllib.request
                filename = os.path.basename(file_path)
                local_path = f"examples/{filename}"
                if not os.path.exists(local_path):
                    print(f"📄 Downloading {filename}...")
                    urllib.request.urlretrieve(f"{repo_url}{file_path}", local_path)
                    print(f"✅ Downloaded {filename}")
                else:
                    print(f"✅ {filename} already exists")
            except Exception as e:
                print(f"⚠️ Could not download {filename}: {e}")
        print("✅ Example files downloaded!")
        return True
    except Exception as e:
        print(f"❌ Error downloading examples: {e}")
        return False


setup_repository()

from src.dist import set_multi_gpus_devices
from src.wan_vae import AutoencoderKLWan
from src.wan_image_encoder import CLIPModel
from src.wan_text_encoder import WanT5EncoderModel
from src.wan_transformer3d_audio import WanTransformerAudioMask3DModel
from src.pipeline_wan_fun_inpaint_audio import WanFunInpaintAudioPipeline
from src.utils import filter_kwargs, get_image_to_video_latent3, save_videos_grid
from src.fm_solvers import FlowDPMSolverMultistepScheduler
from src.fm_solvers_unipc import FlowUniPCMultistepScheduler
from src.cache_utils import get_teacache_coefficients
from src.face_detect import get_mask_coord


class ComprehensiveConfig:
    def __init__(self):
        self.ulysses_degree = 1
        self.ring_degree = 1
        self.fsdp_dit = False
        self.config_path = "echomimic_v3/config/config.yaml"
        self.model_name = "models/Wan2.1-Fun-V1.1-1.3B-InP"
        self.transformer_path = "models/transformer/diffusion_pytorch_model.safetensors"
        self.wav2vec_model_dir = "models/wav2vec2-base-960h"
        self.weight_dtype = torch.bfloat16
        self.sample_size = [768, 768]
        self.sampler_name = "Flow_DPM++"
        self.lora_weight = 1.0


config = ComprehensiveConfig()
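
# Notes on the config above (descriptive of how the fields are used below):
# - ulysses_degree / ring_degree appear to control multi-GPU sequence parallelism in
#   set_multi_gpus_devices(); both are 1 here, i.e. single-device inference.
# - sample_size is a target pixel budget: get_sample_size() rescales each input image
#   to roughly 768x768 worth of pixels, snapped to multiples of 16.
# - weight_dtype (bfloat16) is applied to the transformer, VAE, text and image encoders.
# - sampler_name selects the scheduler class in initialize_models().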

pipeline = None
wav2vec_processor = None
wav2vec_model = None


def load_wav2vec_models(wav2vec_model_dir):
    print(f"🔄 Loading Wav2Vec models from {wav2vec_model_dir}...")
    try:
        processor = Wav2Vec2Processor.from_pretrained(wav2vec_model_dir)
        model = Wav2Vec2Model.from_pretrained(wav2vec_model_dir).eval()
        model.requires_grad_(False)
        print("✅ Wav2Vec models loaded successfully")
        return processor, model
    except Exception as e:
        print(f"❌ Error loading Wav2Vec models: {e}")
        raise


def extract_audio_features(audio_path, processor, model):
    try:
        sr = 16000
        audio_segment, sample_rate = librosa.load(audio_path, sr=sr)
        input_values = processor(audio_segment, sampling_rate=sample_rate, return_tensors="pt").input_values
        input_values = input_values.to(model.device)
        with torch.no_grad():
            features = model(input_values).last_hidden_state
        return features.squeeze(0)
    except Exception as e:
        print(f"❌ Error extracting audio features: {e}")
        raise
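
# Wav2Vec2 downsamples 16 kHz audio by a factor of 320, so the features returned above
# come out at roughly 50 vectors per second. The chunked loop in generate_video() slices
# them as audio_embeds[:, frame_idx * 2], which lines up exactly with 25 fps video
# (25 * 2 = 50); other fps settings only approximate this alignment.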


def get_sample_size(image, default_size):
    width, height = image.size
    original_area = width * height
    default_area = default_size[0] * default_size[1]
    if default_area < original_area:
        ratio = math.sqrt(original_area / default_area)
        width = width / ratio // 16 * 16
        height = height / ratio // 16 * 16
    else:
        width = width // 16 * 16
        height = height // 16 * 16
    return int(height), int(width)
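
# Illustrative example (not from the original source): a 1080x1920 portrait with the
# default 768x768 budget gives ratio = sqrt((1080*1920) / (768*768)) = 1.875, so the
# video is generated at 576x1024 (width x height) - the aspect ratio is kept, the pixel
# count is capped near 768*768, and both sides snap to multiples of 16.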


def get_ip_mask(coords):
    y1, y2, x1, x2, h, w = coords
    Y, X = torch.meshgrid(torch.arange(h), torch.arange(w), indexing='ij')
    mask = (Y.unsqueeze(-1) >= y1) & (Y.unsqueeze(-1) < y2) & (X.unsqueeze(-1) >= x1) & (X.unsqueeze(-1) < x2)
    mask = mask.reshape(-1)
    return mask.float()
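
# The coords passed in by generate_video() are in latent-token units (pixel coordinates
# divided by 16), so the returned tensor is a flattened h*w grid with 1.0 inside the
# detected face box and 0.0 elsewhere; it is handed to the pipeline as ip_mask to flag
# the face region in the token grid.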


def initialize_models():
    global pipeline, wav2vec_processor, wav2vec_model, config
    print("🚀 Initializing EchoMimicV3 models...")
    try:
        if not download_models():
            raise Exception("Failed to download required models")
        download_examples()
        device = set_multi_gpus_devices(config.ulysses_degree, config.ring_degree)
        print(f"✅ Device set to: {device}")
        cfg = OmegaConf.load(config.config_path)
        print(f"✅ Config loaded from {config.config_path}")
        print("🔄 Loading transformer...")
        transformer = WanTransformerAudioMask3DModel.from_pretrained(
            os.path.join(config.model_name, cfg['transformer_additional_kwargs'].get('transformer_subpath', 'transformer')),
            transformer_additional_kwargs=OmegaConf.to_container(cfg['transformer_additional_kwargs']),
            torch_dtype=config.weight_dtype,
        )
        if config.transformer_path is not None and os.path.exists(config.transformer_path):
            print(f"🔄 Loading custom transformer weights from {config.transformer_path}...")
            from safetensors.torch import load_file
            state_dict = load_file(config.transformer_path)
            state_dict = state_dict.get("state_dict", state_dict)
            missing, unexpected = transformer.load_state_dict(state_dict, strict=False)
            print(f"✅ Custom transformer weights loaded - Missing: {len(missing)}, Unexpected: {len(unexpected)}")
        print("🔄 Loading VAE...")
        vae = AutoencoderKLWan.from_pretrained(
            os.path.join(config.model_name, cfg['vae_kwargs'].get('vae_subpath', 'vae')),
            additional_kwargs=OmegaConf.to_container(cfg['vae_kwargs']),
        ).to(config.weight_dtype)
        print("✅ VAE loaded")
        print("🔄 Loading tokenizer...")
        tokenizer = AutoTokenizer.from_pretrained(
            os.path.join(config.model_name, cfg['text_encoder_kwargs'].get('tokenizer_subpath', 'tokenizer')),
        )
        print("✅ Tokenizer loaded")
        print("🔄 Loading text encoder...")
        text_encoder = WanT5EncoderModel.from_pretrained(
            os.path.join(config.model_name, cfg['text_encoder_kwargs'].get('text_encoder_subpath', 'text_encoder')),
            additional_kwargs=OmegaConf.to_container(cfg['text_encoder_kwargs']),
            torch_dtype=config.weight_dtype,
        ).eval()
        print("✅ Text encoder loaded")
        print("🔄 Loading CLIP image encoder...")
        clip_image_encoder = CLIPModel.from_pretrained(
            os.path.join(config.model_name, cfg['image_encoder_kwargs'].get('image_encoder_subpath', 'image_encoder')),
        ).to(config.weight_dtype).eval()
        print("✅ CLIP image encoder loaded")
        print("🔄 Loading scheduler...")
        scheduler_cls_map = {
            "Flow": FlowMatchEulerDiscreteScheduler,
            "Flow_Unipc": FlowUniPCMultistepScheduler,
            "Flow_DPM++": FlowDPMSolverMultistepScheduler,
        }
        scheduler_cls = scheduler_cls_map.get(config.sampler_name, FlowDPMSolverMultistepScheduler)
        scheduler = scheduler_cls(**filter_kwargs(scheduler_cls, OmegaConf.to_container(cfg['scheduler_kwargs'])))
        print("✅ Scheduler loaded")
        print("🔄 Creating pipeline...")
        pipeline = WanFunInpaintAudioPipeline(
            transformer=transformer,
            vae=vae,
            tokenizer=tokenizer,
            text_encoder=text_encoder,
            scheduler=scheduler,
            clip_image_encoder=clip_image_encoder,
        )
        pipeline.to(device=device)
        if torch.__version__ >= "2.0":
            print("🔄 Compiling the pipeline with torch.compile()...")
            pipeline.transformer = torch.compile(pipeline.transformer, mode="reduce-overhead", fullgraph=True)
            print("✅ Pipeline transformer compiled!")
        print("✅ Pipeline created and moved to device")
        print("🔄 Loading Wav2Vec models...")
        wav2vec_processor, wav2vec_model = load_wav2vec_models(config.wav2vec_model_dir)
        wav2vec_model.to(device)
        print("✅ Wav2Vec models loaded")
        print("🎉 All models initialized successfully!")
        return True
    except Exception as e:
        print(f"❌ Model initialization failed: {str(e)}")
        import traceback
        traceback.print_exc()
        return False
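
# Summary of what initialize_models() assembles (descriptive of the code above): the
# Wan2.1-Fun base supplies the VAE, T5 text encoder, tokenizer and CLIP image encoder;
# the EchoMimicV3 safetensors overwrite the audio-masked transformer; a flow-matching
# scheduler matching config.sampler_name ties them together inside
# WanFunInpaintAudioPipeline, which is moved to the selected device and, on
# PyTorch >= 2.0, compiled with torch.compile.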


# Assumption: this Space runs on ZeroGPU ("Running on Zero"), so the generation entry
# point is wrapped with spaces.GPU to request a GPU for the call; adjust the duration
# budget to your hardware and quota.
@spaces.GPU(duration=300)
def generate_video(
    image_path,
    audio_path,
    prompt,
    negative_prompt,
    seed_param,
    num_inference_steps,
    guidance_scale,
    audio_guidance_scale,
    fps,
    partial_video_length,
    overlap_video_length,
    neg_scale,
    neg_steps,
    use_dynamic_cfg,
    use_dynamic_acfg,
    sampler_name,  # NOTE: accepted from the UI, but the scheduler is chosen once in initialize_models() and not changed here
    shift,
    audio_scale,
    use_un_ip_mask,
    enable_teacache,
    teacache_threshold,
    teacache_offload,
    num_skip_start_steps,
    enable_riflex,
    riflex_k,
    progress=gr.Progress(track_tqdm=True)
):
    global pipeline, wav2vec_processor, wav2vec_model, config
    progress(0, desc="Starting video generation...")
    if image_path is None:
        raise gr.Error("Please upload an image")
    if audio_path is None:
        raise gr.Error("Please upload an audio file")
    if not models_ready or pipeline is None:
        raise gr.Error("Models not initialized. Please restart the space.")
    device = pipeline.device
    if seed_param < 0:
        seed = random.randint(0, np.iinfo(np.int32).max)
    else:
        seed = int(seed_param)
    print(f"🎲 Using seed: {seed}")
    try:
        generator = torch.Generator(device=device).manual_seed(seed)
        ref_img_pil = Image.open(image_path).convert("RGB")
        print(f"📸 Image loaded: {ref_img_pil.size}")
        progress(0.1, desc="Detecting face...")
        try:
            y1, y2, x1, x2, h_, w_ = get_mask_coord(image_path)
            print("✅ Face detection successful")
        except Exception as e:
            print(f"⚠️ Face detection failed: {e}, using center crop")
            h_, w_ = ref_img_pil.size[1], ref_img_pil.size[0]
            y1, y2 = h_ // 4, 3 * h_ // 4
            x1, x2 = w_ // 4, 3 * w_ // 4
        progress(0.2, desc="Processing audio...")
        audio_clip = AudioFileClip(audio_path)
        audio_features = extract_audio_features(audio_path, wav2vec_processor, wav2vec_model)
        audio_embeds = audio_features.unsqueeze(0).to(device=device, dtype=config.weight_dtype)
        progress(0.25, desc="Encoding prompts...")
        prompt_embeds, negative_prompt_embeds = pipeline.encode_prompt(
            prompt,
            device=device,
            num_images_per_prompt=1,
            do_classifier_free_guidance=(guidance_scale > 1.0),
            negative_prompt=negative_prompt
        )
        video_length = int(audio_clip.duration * fps)
        video_length = (
            int((video_length - 1) // pipeline.vae.config.temporal_compression_ratio * pipeline.vae.config.temporal_compression_ratio) + 1
            if video_length != 1 else 1
        )
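        # The rounding above snaps the frame count to 1 + k * temporal_compression_ratio so it
        # survives the VAE's temporal downsampling. Illustrative numbers only (a ratio of 4 is
        # typical for Wan-style video VAEs, but the real value comes from the VAE config):
        # 10 s of audio at 25 fps -> 250 raw frames -> ((250 - 1) // 4) * 4 + 1 = 249 frames.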
print(f"π₯ Total video length: {video_length} frames") | |
sample_height, sample_width = get_sample_size(ref_img_pil, config.sample_size) | |
print(f"π Sample size: {sample_width}x{sample_height}") | |
downratio = math.sqrt(sample_height * sample_width / h_ / w_) | |
coords = ( | |
y1 * downratio // 16, y2 * downratio // 16, | |
x1 * downratio // 16, x2 * downratio // 16, | |
sample_height // 16, sample_width // 16, | |
) | |
ip_mask = get_ip_mask(coords).unsqueeze(0) | |
ip_mask = torch.cat([ip_mask]*3).to(device=device, dtype=config.weight_dtype) | |
if enable_riflex: | |
latent_frames = (video_length - 1) // pipeline.vae.config.temporal_compression_ratio + 1 | |
pipeline.transformer.enable_riflex(k=riflex_k, L_test=latent_frames) | |
if enable_teacache: | |
try: | |
coefficients = get_teacache_coefficients(config.model_name) | |
if coefficients: | |
pipeline.transformer.enable_teacache( | |
coefficients, num_inference_steps, teacache_threshold, | |
num_skip_start_steps=num_skip_start_steps, | |
offload=teacache_offload | |
) | |
print("β TeaCache enabled for this run") | |
except Exception as e: | |
print(f"β οΈ Could not enable TeaCache: {e}") | |
        init_frames = 0
        new_sample = None
        ref_img_for_loop = ref_img_pil
        total_chunks = math.ceil(video_length / (partial_video_length - overlap_video_length)) if video_length > partial_video_length else 1
        chunk_num = 0
        while init_frames < video_length:
            chunk_num += 1
            progress(0.3 + (0.6 * (chunk_num / total_chunks)), desc=f"Generating chunk {chunk_num}/{total_chunks}...")
            current_partial_length = min(partial_video_length, video_length - init_frames)
            current_partial_length = (
                int((current_partial_length - 1) // pipeline.vae.config.temporal_compression_ratio * pipeline.vae.config.temporal_compression_ratio) + 1
                if current_partial_length > 1 else 1
            )
            if current_partial_length <= 0:
                break
            input_video, input_video_mask, clip_image = get_image_to_video_latent3(
                ref_img_for_loop, None, video_length=current_partial_length,
                sample_size=[sample_height, sample_width]
            )
            audio_start_frame = init_frames * 2
            audio_end_frame = (init_frames + current_partial_length) * 2
            if audio_embeds.shape[1] < audio_end_frame:
                repeat_times = (audio_end_frame // audio_embeds.shape[1]) + 1
                audio_embeds = audio_embeds.repeat(1, repeat_times, 1)
            partial_audio_embeds = audio_embeds[:, audio_start_frame:audio_end_frame]
            with torch.no_grad():
                sample = pipeline(
                    prompt_embeds=prompt_embeds,
                    negative_prompt_embeds=negative_prompt_embeds,
                    num_frames=current_partial_length,
                    audio_embeds=partial_audio_embeds,
                    audio_scale=audio_scale,
                    ip_mask=ip_mask,
                    use_un_ip_mask=use_un_ip_mask,
                    height=sample_height,
                    width=sample_width,
                    generator=generator,
                    neg_scale=neg_scale,
                    neg_steps=neg_steps,
                    use_dynamic_cfg=use_dynamic_cfg,
                    use_dynamic_acfg=use_dynamic_acfg,
                    guidance_scale=guidance_scale,
                    audio_guidance_scale=audio_guidance_scale,
                    num_inference_steps=num_inference_steps,
                    video=input_video,
                    mask_video=input_video_mask,
                    clip_image=clip_image,
                    shift=shift,
                ).videos
            if new_sample is None:
                new_sample = sample
            else:
                mix_ratio = torch.linspace(0, 1, steps=overlap_video_length, device=device).view(1, 1, -1, 1, 1).to(new_sample.dtype)
                new_sample[:, :, -overlap_video_length:] = (
                    new_sample[:, :, -overlap_video_length:] * (1 - mix_ratio) +
                    sample[:, :, :overlap_video_length] * mix_ratio
                )
                new_sample = torch.cat([new_sample, sample[:, :, overlap_video_length:]], dim=2)
            if new_sample.shape[2] >= video_length:
                break
            ref_img_for_loop = [
                Image.fromarray(
                    (new_sample[0, :, i].permute(1, 2, 0).cpu().numpy() * 255).astype(np.uint8)
                ) for i in range(-overlap_video_length, 0)
            ]
            init_frames += current_partial_length - overlap_video_length
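        # Long clips are generated as a sliding window: each pass is conditioned on the last
        # `overlap_video_length` frames of the previous chunk (fed back as ref_img_for_loop),
        # and the overlapping frames are blended with a linear crossfade (mix_ratio above) so
        # chunk boundaries don't pop. Lowering `partial_video_length` reduces peak VRAM at the
        # cost of more passes.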
        progress(0.9, desc="Stitching video and audio...")
        final_sample = new_sample[:, :, :video_length]
        with tempfile.NamedTemporaryFile(suffix=".mp4", delete=False) as tmp_file:
            video_path = tmp_file.name
        with tempfile.NamedTemporaryFile(suffix="_audio.mp4", delete=False) as tmp_file:
            video_audio_path = tmp_file.name
        save_videos_grid(final_sample, video_path, fps=fps)
        video_clip_final = VideoFileClip(video_path)
        # MoviePy 2.x renamed Clip.subclip to Clip.subclipped (with_audio below is also the 2.x API)
        audio_clip_trimmed = audio_clip.subclipped(0, final_sample.shape[2] / fps)
        final_video = video_clip_final.with_audio(audio_clip_trimmed)
        final_video.write_videofile(video_audio_path, codec="libx264", audio_codec="aac", threads=4, logger=None)
        video_clip_final.close()
        audio_clip.close()
        audio_clip_trimmed.close()
        final_video.close()
        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
            torch.cuda.ipc_collect()
        progress(1.0, desc="Generation complete!")
        return video_audio_path, seed
    except Exception as e:
        print(f"❌ Generation error: {str(e)}")
        import traceback
        traceback.print_exc()
        raise gr.Error(f"Generation failed: {str(e)}")


def create_demo():
    with gr.Blocks(theme=gr.themes.Soft(), title="EchoMimicV3 Demo") as demo:
        gr.Markdown("""
# 🎭 EchoMimicV3: Audio-Driven Human Animation

Transform a portrait photo into a talking video! Upload an image and an audio file to create lifelike, expressive animations. This demo showcases the power of the EchoMimicV3 model.

**Key Features:**
- 🎯 **High-Quality Lip Sync:** Accurate mouth movements that match the input audio.
- 🎨 **Natural Facial Expressions:** Generates subtle and natural facial emotions.
- 🎵 **Speech & Singing:** Works with both spoken word and singing.
- ⚡ **Efficient:** Powered by a compact 1.3B parameter model.
""")
        if not models_ready:
            gr.Warning("Models are still loading. The UI is disabled. Please wait and refresh the page if necessary.")
        with gr.Row():
            with gr.Column(scale=1):
                image_input = gr.Image(
                    label="📸 Upload Portrait Image",
                    type="filepath",
                    sources=["upload"],
                    height=400,
                )
                audio_input = gr.Audio(
                    label="🎵 Upload Audio",
                    type="filepath",
                    sources=["upload"],
                )
                with gr.Accordion("📝 Text Prompts", open=True):
                    prompt = gr.Textbox(
                        label="✏️ Prompt",
                        value="A person talking naturally with clear expressions.",
                    )
                    negative_prompt = gr.Textbox(
                        label="🚫 Negative Prompt",
                        value="Gesture is bad, unclear. Strange, twisted, bad, blurry hands and fingers.",
                        lines=2,
                    )
            with gr.Column(scale=1):
                video_output = gr.Video(
                    label="🎥 Generated Video",
                    interactive=False,
                    height=400
                )
                seed_output = gr.Number(
                    label="🎲 Used Seed",
                    interactive=False,
                    precision=0
                )
        with gr.Accordion("⚙️ Advanced Settings", open=False):
            with gr.Row():
                with gr.Column():
                    gr.Markdown("### Core Generation Parameters")
                    seed_param = gr.Number(label="🎲 Seed", value=-1, precision=0, info="-1 for random seed.")
                    num_inference_steps = gr.Slider(label="Inference Steps", minimum=5, maximum=50, value=20, step=1, info="More steps can improve quality but take longer. 15-25 is a good range.")
                    fps = gr.Slider(label="Frames Per Second (FPS)", minimum=10, maximum=30, value=25, step=1, info="Controls the smoothness of the output video.")
                with gr.Column():
                    gr.Markdown("### Classifier-Free Guidance (CFG)")
                    guidance_scale = gr.Slider(label="Text Guidance Scale (CFG)", minimum=1.0, maximum=10.0, value=4.5, step=0.1, info="How strongly to follow the text prompt. Recommended: 3.0-6.0.")
                    audio_guidance_scale = gr.Slider(label="Audio Guidance Scale (aCFG)", minimum=1.0, maximum=10.0, value=2.5, step=0.1, info="How strongly to follow the audio for lip sync. Recommended: 2.0-3.0.")
                    use_dynamic_cfg = gr.Checkbox(label="Use Dynamic Text CFG", value=True, info="Gradually adjusts CFG during generation, can improve quality.")
                    use_dynamic_acfg = gr.Checkbox(label="Use Dynamic Audio aCFG", value=True, info="Gradually adjusts aCFG during generation, can improve quality.")
            with gr.Row():
                with gr.Column():
                    gr.Markdown("### Performance & VRAM (Chunking)")
                    partial_video_length = gr.Slider(label="Partial Video Length (Chunk Size)", minimum=49, maximum=161, value=113, step=16, info="Key for VRAM usage. 24G VRAM: ~113, 16G: ~81, 12G: ~49. Lower values use less memory but may affect consistency.")
                    overlap_video_length = gr.Slider(label="Overlap Length", minimum=4, maximum=16, value=8, step=1, info="How many frames to overlap between chunks for smooth transitions.")
                with gr.Column():
                    gr.Markdown("### Sampler & Scheduler")
                    sampler_name = gr.Dropdown(label="Sampler", choices=["Flow", "Flow_Unipc", "Flow_DPM++"], value="Flow_DPM++", info="Algorithm for the diffusion process.")
                    shift = gr.Slider(label="Scheduler Shift", minimum=1.0, maximum=10.0, value=5.0, step=0.1, info="Adjusts the noise schedule. Optimal range depends on the sampler.")
                    audio_scale = gr.Slider(label="Audio Scale", minimum=0.5, maximum=2.0, value=1.0, step=0.1, info="Global scale for audio feature influence.")
                    use_un_ip_mask = gr.Checkbox(label="Use Un-IP Mask", value=False, info="Inverts the inpainting mask.")
            with gr.Row():
                with gr.Column():
                    gr.Markdown("### Negative Guidance (Advanced CFG)")
                    neg_scale = gr.Slider(label="Negative Scale", minimum=1.0, maximum=5.0, value=1.5, step=0.1, info="Strength of negative prompt in early steps.")
                    neg_steps = gr.Slider(label="Negative Steps", minimum=0, maximum=10, value=2, step=1, info="How many initial steps to apply the negative scale.")
        with gr.Accordion("🔬 Experimental Settings", open=False):
            with gr.Row():
                with gr.Column():
                    gr.Markdown("### TeaCache (Performance Boost)")
                    enable_teacache = gr.Checkbox(label="Enable TeaCache", value=True)
                    teacache_threshold = gr.Slider(label="TeaCache Threshold", minimum=0.0, maximum=0.2, value=0.1, step=0.01)
                    teacache_offload = gr.Checkbox(label="TeaCache Offload", value=True)
                with gr.Column():
                    gr.Markdown("### Riflex (Consistency)")
                    enable_riflex = gr.Checkbox(label="Enable Riflex", value=False)
                    riflex_k = gr.Slider(label="Riflex K", minimum=1, maximum=10, value=6, step=1)
                with gr.Column():
                    gr.Markdown("### Other")
                    num_skip_start_steps = gr.Slider(label="Num Skip Start Steps", minimum=0, maximum=10, value=5, step=1)
        generate_button = gr.Button(
            "🎬 Generate Video",
            variant='primary',
            size="lg",
            interactive=models_ready
        )
        all_inputs = [
            image_input, audio_input, prompt, negative_prompt, seed_param,
            num_inference_steps, guidance_scale, audio_guidance_scale, fps,
            partial_video_length, overlap_video_length, neg_scale, neg_steps,
            use_dynamic_cfg, use_dynamic_acfg, sampler_name, shift, audio_scale,
            use_un_ip_mask, enable_teacache, teacache_threshold, teacache_offload,
            num_skip_start_steps, enable_riflex, riflex_k
        ]
        if models_ready:
            generate_button.click(
                fn=generate_video,
                inputs=all_inputs,
                outputs=[video_output, seed_output]
            )
gr.Markdown("---") | |
gr.Markdown("### β¨ Click to Try Examples") | |
gr.Examples( | |
examples=[ | |
[ | |
"examples/demo_ch_woman_04.png", | |
"examples/demo_ch_woman_04.WAV", | |
"A Chinese woman is talking naturally.", | |
"bad gestures, blurry, distorted face", | |
42, 20, 4.5, 2.5, 25, 113, 8, 1.5, 2, True, True, "Flow_DPM++", 5.0, 1.0, False, True, 0.1, True, 5, False, 6 | |
], | |
[ | |
"examples/guitar_woman_01.png", | |
"examples/guitar_woman_01.WAV", | |
"A woman with glasses is singing and playing the guitar.", | |
"blurry, distorted face, bad hands", | |
123, 25, 5.0, 2.8, 25, 113, 8, 1.5, 2, True, True, "Flow_DPM++", 5.0, 1.0, False, True, 0.1, True, 5, False, 6 | |
], | |
], | |
inputs=all_inputs, | |
outputs=[video_output, seed_output], | |
fn=generate_video, | |
cache_examples=True, | |
label=None, | |
) | |
gr.Markdown("---") | |
gr.Markdown(""" | |
### π How to Use | |
1. **Upload Image:** Choose a clear portrait photo (front-facing works best). | |
2. **Upload Audio:** Add an audio file with clear speech or singing. | |
3. **Adjust Settings (Optional):** Fine-tune parameters in the advanced sections for different results. For memory issues, try lowering the "Partial Video Length". | |
4. **Generate:** Click the button and wait for your talking video! | |
**Note:** Generation time depends on settings and audio length. It can take a few minutes. | |
This demo is based on the [EchoMimicV3 repository](https://github.com/antgroup/echomimic_v3). | |
""") | |
return demo | |


if __name__ == "__main__":
    print("🚀 Starting model initialization...")
    models_ready = initialize_models()
    demo = create_demo()
    demo.launch(share=True)