Spaces:
Running
on
Zero
Running
on
Zero
import os | |
import torch | |
import torch.utils.checkpoint | |
from PIL import Image | |
import numpy as np | |
from omegaconf import OmegaConf | |
from tqdm import tqdm | |
import cv2 | |
from diffusers import AutoencoderKLTemporalDecoder | |
from diffusers.schedulers import EulerDiscreteScheduler | |
from transformers import WhisperModel, CLIPVisionModelWithProjection, AutoFeatureExtractor | |
from src.utils.util import save_videos_grid, seed_everything | |
from src.dataset.test_preprocess import process_bbox, image_audio_to_tensor | |
from src.models.base.unet_spatio_temporal_condition import UNetSpatioTemporalConditionModel, add_ip_adapters | |
from src.pipelines.pipeline_sonic import SonicPipeline | |
from src.models.audio_adapter.audio_proj import AudioProjModel | |
from src.models.audio_adapter.audio_to_bucket import Audio2bucketModel | |
from src.utils.RIFE.RIFE_HDv3 import RIFEModel | |
from src.dataset.face_align.align import AlignImage | |
# Absolute directory containing this file; the config, pretrained-model and
# checkpoint paths used below are all joined onto it.
BASE_DIR = os.path.dirname(os.path.abspath(__file__))
@torch.no_grad()
def test(
    pipe,
    config,
    wav_enc,
    audio_pe,
    audio2bucket,
    image_encoder,
    width,
    height,
    batch
):
    """Run the Sonic pipeline once and return the generated video tensor.

    Inference only: gradients are disabled for the whole call.

    Args:
        pipe: Pipeline object; ``pipe.device`` is used for all inputs and
            ``pipe(...)`` is called with the conditioning built here.
        config: Inference config. Reads ``step``, ``decode_chunk_size``,
            ``motion_bucket_scale``, ``fps``, ``noise_aug_strength``,
            ``min_appearance_guidance_scale``,
            ``max_appearance_guidance_scale``, ``audio_guidance_scale``,
            ``overlap``, ``shift_offset``, ``n_sample_frames``,
            ``num_inference_steps`` and ``i2i_noise_strength``.
        wav_enc: Audio model exposing ``.encoder(...)`` whose output has
            ``hidden_states`` and ``last_hidden_state``.
        audio_pe: Module projecting a clip of audio hidden states to tokens.
        audio2bucket: Module predicting a motion bucket from an audio clip
            and the CLIP image embeddings.
        image_encoder: Vision encoder whose output has ``image_embeds``.
        width: Output frame width, forwarded to ``pipe``.
        height: Output frame height, forwarded to ``pipe``.
        batch: Dict with at least ``ref_img``, ``clip_images``,
            ``face_mask``, ``audio_feature`` and ``audio_len``. NOTE: every
            tensor value is replaced IN PLACE by a float32 copy with a new
            leading batch dimension on ``pipe.device``.

    Returns:
        ``pipe(...).frames`` rescaled with ``x * 0.5 + 0.5``, clamped to
        [0, 1] and moved to the CPU.
    """
    # Add a leading batch dim to every tensor and move it (as float32) to the
    # pipeline's device. Only existing keys are rebound, so mutating while
    # iterating is safe.
    for key, value in batch.items():
        if isinstance(value, torch.Tensor):
            batch[key] = value.unsqueeze(0).to(pipe.device).float()

    ref_img = batch['ref_img']
    clip_img = batch['clip_images']
    face_mask = batch['face_mask']
    image_embeds = image_encoder(clip_img).image_embeds

    audio_feature = batch['audio_feature']
    audio_len = batch['audio_len']
    step = int(config.step)

    # Encode the audio features in windows of 3000 frames along the last axis
    # (presumably Whisper's fixed 30 s input length -- confirm). One encoder
    # pass with output_hidden_states=True provides both the per-layer hidden
    # states and the final hidden state, so the encoder runs once per window.
    window = 3000
    audio_prompts = []
    last_audio_prompts = []
    for start in range(0, audio_feature.shape[-1], window):
        encoded = wav_enc.encoder(
            audio_feature[:, :, start:start + window], output_hidden_states=True
        )
        audio_prompts.append(torch.stack(encoded.hidden_states, dim=2))
        last_audio_prompts.append(encoded.last_hidden_state.unsqueeze(-2))

    # Concatenate windows, keep 2 audio frames per unit of audio_len, then
    # zero-pad 4 in front and 6 behind (10 total = the clip width used below).
    audio_prompts = torch.cat(audio_prompts, dim=1)
    audio_prompts = audio_prompts[:, :audio_len * 2]
    audio_prompts = torch.cat([torch.zeros_like(audio_prompts[:, :4]), audio_prompts,
                               torch.zeros_like(audio_prompts[:, :6])], 1)

    # Same for the final hidden states, padded 24 + 26 = 50 (the bucket
    # window width used below).
    last_audio_prompts = torch.cat(last_audio_prompts, dim=1)
    last_audio_prompts = last_audio_prompts[:, :audio_len * 2]
    last_audio_prompts = torch.cat([torch.zeros_like(last_audio_prompts[:, :24]), last_audio_prompts,
                                    torch.zeros_like(last_audio_prompts[:, :26])], 1)

    audio_tensor_list = []
    uncond_audio_tensor_list = []
    motion_buckets = []
    for i in tqdm(range(audio_len // step)):
        # Audio frames advance 2 per unit of audio_len (see the trimming above).
        offset = i * 2 * step
        audio_clip = audio_prompts[:, offset:offset + 10].unsqueeze(0)
        audio_clip_for_bucket = last_audio_prompts[:, offset:offset + 50].unsqueeze(0)

        # Predicted bucket is rescaled affinely (x * 16 + 16) before use.
        motion_bucket = audio2bucket(audio_clip_for_bucket, image_embeds)
        motion_bucket = motion_bucket * 16 + 16
        motion_buckets.append(motion_bucket[0])

        # Conditional tokens plus an "unconditional" counterpart from an
        # all-zero clip (presumably for audio guidance -- confirm).
        cond_audio_clip = audio_pe(audio_clip).squeeze(0)
        uncond_audio_clip = audio_pe(torch.zeros_like(audio_clip)).squeeze(0)
        audio_tensor_list.append(cond_audio_clip[0])
        uncond_audio_tensor_list.append(uncond_audio_clip[0])

    video = pipe(
        ref_img,
        clip_img,
        face_mask,
        audio_tensor_list,
        uncond_audio_tensor_list,
        motion_buckets,
        height=height,
        width=width,
        num_frames=len(audio_tensor_list),
        decode_chunk_size=config.decode_chunk_size,
        motion_bucket_scale=config.motion_bucket_scale,
        fps=config.fps,
        noise_aug_strength=config.noise_aug_strength,
        min_guidance_scale1=config.min_appearance_guidance_scale,
        max_guidance_scale1=config.max_appearance_guidance_scale,
        min_guidance_scale2=config.audio_guidance_scale,
        max_guidance_scale2=config.audio_guidance_scale,
        overlap=config.overlap,
        shift_offset=config.shift_offset,
        frames_per_batch=config.n_sample_frames,
        num_inference_steps=config.num_inference_steps,
        i2i_noise_strength=config.i2i_noise_strength
    ).frames

    video = (video * 0.5 + 0.5).clamp(0, 1)
    return video.cpu()
class Sonic:
    """Wrapper class for the Sonic portrait animation pipeline.

    Construction loads all sub-models and assembles the pipeline; helper
    methods such as ``preprocess`` and ``crop_image`` handle face detection
    and cropping of input images.
    """
    # Inference config path, resolved relative to this file's directory.
    config_file = os.path.join(BASE_DIR, 'config/inference/sonic.yaml')
    # Parsed once when the class body executes (i.e. at import time). As a
    # class attribute it is shared by every instance unless an instance
    # rebinds ``self.config``.
    config = OmegaConf.load(config_file)
def __init__(self, device_id: int = 0, enable_interpolate_frame: bool = True):
    """Load all Sonic sub-models and assemble the inference pipeline.

    Args:
        device_id: CUDA device index; any value ``<= -1`` selects ``'cpu'``.
        enable_interpolate_frame: Stored as ``config.use_interframe``; when
            true, a ``RIFEModel`` frame interpolator is also loaded.

    Raises:
        ValueError: If ``config.weight_dtype`` is not ``"fp16"``, ``"fp32"``
            or ``"bf16"`` (checked before any model is loaded).
    """
    import copy  # only needed here, to clone the class-level config

    # --------- load config & device ---------
    # The class-level ``config`` is shared by all instances. Work on a
    # private copy so per-instance settings (use_interframe, resolved paths)
    # do not leak into other instances.
    config = copy.deepcopy(self.config)
    self.config = config
    config.use_interframe = enable_interpolate_frame
    device = f'cuda:{device_id}' if device_id > -1 else 'cpu'
    self.device = device

    # --------- dtype (validated early so a bad config fails fast) ---------
    dtype_by_name = {
        "fp16": torch.float16,
        "fp32": torch.float32,
        "bf16": torch.bfloat16,
    }
    if config.weight_dtype not in dtype_by_name:
        raise ValueError(f"Unsupported weight dtype: {config.weight_dtype}")
    weight_dtype = dtype_by_name[config.weight_dtype]

    # --------- Model paths ---------
    config.pretrained_model_name_or_path = os.path.join(BASE_DIR, config.pretrained_model_name_or_path)

    # --------- Load sub-modules ---------
    vae = AutoencoderKLTemporalDecoder.from_pretrained(
        config.pretrained_model_name_or_path,
        subfolder="vae",
        variant="fp16"
    )
    val_noise_scheduler = EulerDiscreteScheduler.from_pretrained(
        config.pretrained_model_name_or_path,
        subfolder="scheduler"
    )
    image_encoder = CLIPVisionModelWithProjection.from_pretrained(
        config.pretrained_model_name_or_path,
        subfolder="image_encoder",
        variant="fp16"
    )
    unet = UNetSpatioTemporalConditionModel.from_pretrained(
        config.pretrained_model_name_or_path,
        subfolder="unet",
        variant="fp16"
    )
    add_ip_adapters(unet, [32], [config.ip_audio_scale])
    audio2token = AudioProjModel(seq_len=10, blocks=5, channels=384, intermediate_dim=1024, output_dim=1024,
                                 context_tokens=32).to(device)
    audio2bucket = Audio2bucketModel(seq_len=50, blocks=1, channels=384, clip_channels=1024, intermediate_dim=1024,
                                     output_dim=1, context_tokens=2).to(device)

    # --------- Load checkpoints ---------
    # NOTE(review): torch.load unpickles arbitrary objects. These paths come
    # from the local config; consider weights_only=True if checkpoints could
    # ever come from an untrusted source.
    unet_ckpt = torch.load(os.path.join(BASE_DIR, config.unet_checkpoint_path), map_location="cpu")
    audio2token_ckpt = torch.load(os.path.join(BASE_DIR, config.audio2token_checkpoint_path), map_location="cpu")
    audio2bucket_ckpt = torch.load(os.path.join(BASE_DIR, config.audio2bucket_checkpoint_path), map_location="cpu")
    unet.load_state_dict(unet_ckpt, strict=True)
    audio2token.load_state_dict(audio2token_ckpt, strict=True)
    audio2bucket.load_state_dict(audio2bucket_ckpt, strict=True)
    # load_state_dict copies into the modules' parameters; drop the CPU
    # state dicts now to lower peak memory while the remaining models load.
    del unet_ckpt, audio2token_ckpt, audio2bucket_ckpt
    # Inference only: put the audio adapters in eval mode and freeze them,
    # matching the treatment of the Whisper encoder below.
    audio2token.eval().requires_grad_(False)
    audio2bucket.eval().requires_grad_(False)

    # --------- Whisper encoder for audio ---------
    whisper_path = os.path.join(BASE_DIR, 'checkpoints/whisper-tiny/')
    whisper = WhisperModel.from_pretrained(whisper_path).to(device).eval()
    whisper.requires_grad_(False)
    self.feature_extractor = AutoFeatureExtractor.from_pretrained(whisper_path)

    # --------- Face detector & frame interpolator ---------
    det_path = os.path.join(BASE_DIR, 'checkpoints/yoloface_v5m.pt')
    self.face_det = AlignImage(device, det_path=det_path)
    # Always define ``self.rife`` so the attribute exists even when frame
    # interpolation is disabled.
    self.rife = None
    if config.use_interframe:
        self.rife = RIFEModel(device=device)
        self.rife.load_model(os.path.join(BASE_DIR, 'checkpoints', 'RIFE/'))

    # --------- Move modules to dtype (nn.Module.to casts in place) ---------
    image_encoder.to(weight_dtype)
    vae.to(weight_dtype)
    unet.to(weight_dtype)

    # --------- Compose pipeline ---------
    pipe = SonicPipeline(
        unet=unet,
        image_encoder=image_encoder,
        vae=vae,
        scheduler=val_noise_scheduler,
    )
    self.pipe = pipe.to(device=device, dtype=weight_dtype)
    self.whisper = whisper
    self.audio2token = audio2token
    self.audio2bucket = audio2bucket
    self.image_encoder = image_encoder
    print('Sonic initialization complete.')
# -------------------------- Public helpers -------------------------- | |
def preprocess(self, image_path: str, expand_ratio: float = 1.0):
    """Detect a face in an image and compute its (optionally expanded) crop box.

    Args:
        image_path: Path to an image file readable by ``cv2.imread``.
        expand_ratio: Forwarded to ``process_bbox`` (as its ``expand_radio``
            keyword) to enlarge the detected box.

    Returns:
        Dict with ``face_num`` (number of boxes returned by the detector,
        called with ``maxface=True``) and ``crop_bbox`` (``process_bbox``'s
        result for the first box, or ``[]`` when no face was found).

    Raises:
        ValueError: If the image cannot be read.
    """
    face_image = cv2.imread(image_path)
    if face_image is None:
        # cv2.imread returns None (instead of raising) for missing or
        # unreadable files; fail with a clear message rather than an opaque
        # AttributeError on ``.shape``.
        raise ValueError(f"Could not read image: {image_path}")
    img_h, img_w = face_image.shape[:2]

    _, _, bboxes = self.face_det(face_image, maxface=True)
    face_num = len(bboxes)

    crop_bbox = []
    if face_num > 0:
        # Boxes are unpacked as (x, y, w, h); process_bbox takes corners.
        left, top, box_w, box_h = bboxes[0]
        corners = (left, top, left + box_w, top + box_h)
        crop_bbox = process_bbox(corners, expand_radio=expand_ratio, height=img_h, width=img_w)

    return {
        'face_num': face_num,
        'crop_bbox': crop_bbox,
    }
def crop_image(self, input_image_path: str, output_image_path: str, crop_bbox): | |
face_image = cv2.imread(input_image_path) | |
crop_image = face_image[crop_bbox[1]:crop_bbox[3], crop_bbox[0]:crop_bbox[2]] | |
cv2.imwrite(output | |