Spaces:
Build error
Build error
import math | |
import time | |
import torch | |
import random | |
from loguru import logger | |
from einops import rearrange | |
from hymm_sp.diffusion import load_diffusion_pipeline | |
from hymm_sp.helpers import get_nd_rotary_pos_embed_new | |
from hymm_sp.inference import Inference | |
from hymm_sp.diffusion.schedulers import FlowMatchDiscreteScheduler | |
from hymm_sp.data_kits.audio_preprocessor import encode_audio, get_facemask | |
def align_to(value, alignment):
    """Round ``value`` up to the nearest multiple of ``alignment``.

    Args:
        value (int | float): Value to align.
        alignment (int): Non-zero alignment step.

    Returns:
        int: The smallest multiple of ``alignment`` that is >= ``value``
        (for a positive ``alignment``).

    Raises:
        ZeroDivisionError: If ``alignment`` is 0.
    """
    # Ceil-division via negated floor-division stays exact for arbitrarily large
    # ints; going through ``value / alignment`` (float) loses precision past 2**53.
    return int(-(-value // alignment) * alignment)
class HunyuanVideoSampler(Inference):
    """Audio-driven video sampler built on ``Inference``.

    Loads a diffusion pipeline via ``load_diffusion_pipeline`` and exposes
    ``predict`` to generate a video from a reference image, a text prompt and
    pre-extracted audio features.
    """

    # Negative prompt used for every ``predict`` call.
    _NEGATIVE_PROMPT = (
        "Aerial view, aerial view, overexposed, low quality, deformation, "
        "a poor composition, bad hands, bad teeth, bad eyes, bad limbs, "
        "distortion, blurring, Lens changes"
    )
    # Frame count requested from the pipeline; also the length the reference
    # frames are repeated to and the audio features are padded to.
    _TARGET_LENGTH = 129

    def __init__(self, args, vae, vae_kwargs, text_encoder, model, text_encoder_2=None, pipeline=None,
                 device=0, logger=None):
        """Initialise the base ``Inference`` state and load the diffusion pipeline.

        All arguments are forwarded to ``Inference.__init__``. Afterwards
        ``self.pipeline`` is always (re)assigned from ``load_diffusion_pipeline``
        using the VAE, text encoders and model held by the instance.
        """
        super().__init__(args, vae, vae_kwargs, text_encoder, model, text_encoder_2=text_encoder_2,
                         pipeline=pipeline, device=device, logger=logger)
        self.args = args
        self.pipeline = load_diffusion_pipeline(
            args, 0, self.vae, self.text_encoder, self.text_encoder_2, self.model,
            device=self.device)
        print('load hunyuan model successful... ')

    def get_rotary_pos_embed(self, video_length, height, width, concat_dict=None):
        """Compute real-valued rotary position embeddings for a video latent grid.

        The pixel-space size is first mapped to a latent grid: the time axis
        becomes ``(video_length - 1) // 4 + 1`` when ``'884'`` appears in
        ``self.args.vae`` (otherwise it is left as ``video_length``), and both
        spatial axes are divided by 8. Each latent axis is then divided by the
        model's patch size to get the rope sizes.

        Args:
            video_length (int): Number of pixel-space frames.
            height (int): Pixel height.
            width (int): Pixel width.
            concat_dict (dict, optional): Forwarded to
                ``get_nd_rotary_pos_embed_new``. Defaults to a new empty dict.

        Returns:
            tuple: ``(freqs_cos, freqs_sin)`` as returned by
            ``get_nd_rotary_pos_embed_new``.

        Raises:
            AssertionError: If a latent size is not divisible by its patch size,
                or if ``rope_dim_list`` does not sum to the attention head dim.
            TypeError: If ``self.model.patch_size`` is not an int, list or tuple.
        """
        # Fresh dict per call: a shared mutable default would leak any mutation
        # made downstream into later calls.
        if concat_dict is None:
            concat_dict = {}
        target_ndim = 3
        if '884' in self.args.vae:
            latents_size = [(video_length - 1) // 4 + 1, height // 8, width // 8]
        else:
            latents_size = [video_length, height // 8, width // 8]

        patch_size = self.model.patch_size
        if isinstance(patch_size, int):
            per_axis_patch = [patch_size] * len(latents_size)
        elif isinstance(patch_size, (list, tuple)):
            # Index explicitly so a too-short patch list still fails loudly.
            per_axis_patch = [patch_size[idx] for idx in range(len(latents_size))]
        else:
            raise TypeError(
                f"patch_size must be an int, list or tuple, got {type(patch_size).__name__}")
        assert all(s % p == 0 for s, p in zip(latents_size, per_axis_patch)), \
            f"Latent size(last {target_ndim} dimensions) should be divisible by patch size({patch_size}), " \
            f"but got {latents_size}."
        # latents_size always has target_ndim entries, so no left-padding is needed.
        rope_sizes = [s // p for s, p in zip(latents_size, per_axis_patch)]

        head_dim = self.model.hidden_size // self.model.num_heads
        rope_dim_list = self.model.rope_dim_list
        if rope_dim_list is None:
            rope_dim_list = [head_dim // target_ndim for _ in range(target_ndim)]
        assert sum(rope_dim_list) == head_dim, "sum(rope_dim_list) should equal to head_dim of attention layer"
        freqs_cos, freqs_sin = get_nd_rotary_pos_embed_new(rope_dim_list,
                                                           rope_sizes,
                                                           theta=self.args.rope_theta,
                                                           use_real=True,
                                                           theta_rescale_factor=1,
                                                           concat_dict=concat_dict)
        return freqs_cos, freqs_sin

    def _encode_audio_prompts(self, batch, wav2vec, fps, target_length):
        """Encode ``batch['audio_prompts']`` with ``wav2vec`` and zero-pad dim 1.

        If the encoded features have at most ``target_length`` entries along
        dim 1 they are padded up to exactly ``target_length``; otherwise 5 extra
        zero entries are appended. The result is cast back to the dtype of the
        incoming audio features.
        """
        raw_audio = batch["audio_prompts"].to(self.device)
        weight_dtype = raw_audio.dtype
        encoded = torch.cat(
            [encode_audio(wav2vec, audio_feat.to(dtype=wav2vec.dtype), fps.item(),
                          num_frames=batch["audio_len"][0])
             for audio_feat in raw_audio],
            dim=0,
        ).to(device=self.device, dtype=weight_dtype)
        current = encoded.shape[1]
        pad_frames = target_length - current if current <= target_length else 5
        padding = torch.zeros_like(encoded[:, :1]).repeat(1, pad_frames, 1, 1, 1)
        return torch.cat([encoded, padding], dim=1)

    def _encode_reference_latents(self, args, pixel_values, uncond_pixel_values):
        """VAE-encode the conditional and unconditional reference clips.

        Applies ``shift_factor`` (when present and truthy) and ``scaling_factor``
        from ``self.vae.config`` in place. With ``args.cpu_offload`` the VAE is
        moved to CUDA for encoding and back to CPU afterwards, even on error.
        """
        vae_dtype = self.vae.dtype
        with torch.autocast(device_type="cuda", dtype=vae_dtype, enabled=vae_dtype != torch.float32):
            if args.cpu_offload:
                self.vae.to('cuda')
            self.vae.enable_tiling()
            try:
                ref_latents = self.vae.encode(pixel_values.clone()).latent_dist.sample()
                uncond_ref_latents = self.vae.encode(uncond_pixel_values).latent_dist.sample()
            finally:
                # Restore VAE state even if encoding fails.
                self.vae.disable_tiling()
                if args.cpu_offload:
                    self.vae.to('cpu')
                    torch.cuda.empty_cache()
        config = self.vae.config
        shift_factor = getattr(config, 'shift_factor', None)
        for latents in (ref_latents, uncond_ref_latents):
            if shift_factor:
                latents.sub_(shift_factor)
            latents.mul_(config.scaling_factor)
        return ref_latents, uncond_ref_latents

    def _prepare_reference_inputs(self, args, batch, align_instance, target_length):
        """Build reference latents, latent-resolution face masks and LLaVA pixels.

        Returns:
            tuple: ``(ref_latents, uncond_ref_latents, face_masks,
            pixel_value_llava, uncond_pixel_value_llava)``.
        """
        # Laid out as (b f c h w), with pixel values in [0, 255].
        pixel_value_ref = batch['pixel_value_ref'].to(self.device)
        face_masks = get_facemask(pixel_value_ref.clone(), align_instance, area=3.0)
        pixel_value_ref = pixel_value_ref.repeat(1, target_length, 1, 1, 1)
        # Unconditional reference: a tensor of -1s (zero pixels after normalisation).
        uncond_pixel_value_ref = torch.zeros_like(pixel_value_ref) * 2 - 1
        pixel_value_ref = pixel_value_ref / 127.5 - 1.
        pixel_value_ref_for_vae = rearrange(pixel_value_ref, "b f c h w -> b c f h w")
        uncond_pixel_value_ref_for_vae = rearrange(uncond_pixel_value_ref, "b f c h w -> b c f h w")

        pixel_value_llava = rearrange(batch["pixel_value_ref_llava"].to(self.device),
                                      "b f c h w -> (b f) c h w")
        uncond_pixel_value_llava = pixel_value_llava.clone()

        ref_latents, uncond_ref_latents = self._encode_reference_latents(
            args, pixel_value_ref_for_vae, uncond_pixel_value_ref_for_vae)

        # Resize the masks to the latent spatial resolution (bilinear).
        face_masks = torch.nn.functional.interpolate(
            face_masks.float().squeeze(2),
            (ref_latents.shape[-2], ref_latents.shape[-1]),
            mode="bilinear").unsqueeze(2).to(dtype=ref_latents.dtype)
        return ref_latents, uncond_ref_latents, face_masks, pixel_value_llava, uncond_pixel_value_llava

    def predict(self,
                args, batch, wav2vec, feature_extractor, align_instance,
                **kwargs):
        """Generate an audio-driven video from a reference image, prompt and audio.

        Args:
            args: Run configuration. The fields read here are ``seed``,
                ``infer_steps``, ``cfg_scale``, ``num_images`` and ``cpu_offload``.
            batch (dict): Single-sample batch. The keys read are 'text_prompt',
                'image_path', 'audio_path', 'fps', 'audio_prompts', 'audio_len',
                'motion_bucket_id_exps', 'motion_bucket_id_heads',
                'pixel_value_ref' and 'pixel_value_ref_llava'.
            wav2vec: Audio encoder passed to ``encode_audio``. It is moved to the
                CPU while the diffusion pipeline runs, and always moved back to
                ``self.device`` before this method returns or raises.
            feature_extractor: Not used by this method.
            align_instance: Face aligner forwarded to ``get_facemask``.
            **kwargs: Not used by this method.

        Returns:
            dict | None: ``{'samples': ...}`` holding element 0 of the pipeline
            output, or ``None`` if that element is ``None``.
        """
        prompt = batch['text_prompt'][0]
        image_path = str(batch["image_path"][0])
        audio_path = str(batch["audio_path"][0])
        neg_prompt = self._NEGATIVE_PROMPT
        target_length = self._TARGET_LENGTH

        fps = batch["fps"].to(self.device)
        audio_prompts = self._encode_audio_prompts(batch, wav2vec, fps, target_length)
        # wav2vec is not needed during diffusion; park it on the CPU to free GPU memory.
        wav2vec.to("cpu")
        torch.cuda.empty_cache()
        try:
            uncond_audio_prompts = torch.zeros_like(audio_prompts[:, :target_length])
            motion_exp = batch["motion_bucket_id_exps"].to(self.device)
            motion_pose = batch["motion_bucket_id_heads"].to(self.device)

            (ref_latents, uncond_ref_latents, face_masks,
             pixel_value_llava, uncond_pixel_value_llava) = self._prepare_reference_inputs(
                args, batch, align_instance, target_length)

            size = (batch['pixel_value_ref'].shape[-2], batch['pixel_value_ref'].shape[-1])
            target_height = align_to(size[0], 16)
            target_width = align_to(size[1], 16)
            concat_dict = {'mode': 'timecat', 'bias': -1}
            freqs_cos, freqs_sin = self.get_rotary_pos_embed(
                target_length, target_height, target_width, concat_dict)
            n_tokens = freqs_cos.shape[0]
            generator = torch.Generator(device=self.device).manual_seed(args.seed)

            debug_str = "\n".join([
                "",
                f"prompt: {prompt}",
                f"image_path: {image_path}",
                f"audio_path: {audio_path}",
                f"negative_prompt: {neg_prompt}",
                f"seed: {args.seed}",
                f"fps: {fps.item()}",
                f"infer_steps: {args.infer_steps}",
                f"target_height: {target_height}",
                f"target_width: {target_width}",
                f"target_length: {target_length}",
                f"guidance_scale: {args.cfg_scale}",
                "",
            ])
            self.logger.info(debug_str)

            start_time = time.time()
            samples = self.pipeline(prompt=prompt,
                                    height=target_height,
                                    width=target_width,
                                    frame=target_length,
                                    num_inference_steps=args.infer_steps,
                                    guidance_scale=args.cfg_scale,
                                    negative_prompt=neg_prompt,
                                    num_images_per_prompt=args.num_images,
                                    generator=generator,
                                    prompt_embeds=None,
                                    ref_latents=ref_latents,
                                    uncond_ref_latents=uncond_ref_latents,
                                    pixel_value_llava=pixel_value_llava,
                                    uncond_pixel_value_llava=uncond_pixel_value_llava,
                                    face_masks=face_masks,
                                    audio_prompts=audio_prompts,
                                    uncond_audio_prompts=uncond_audio_prompts,
                                    motion_exp=motion_exp,
                                    motion_pose=motion_pose,
                                    fps=fps,
                                    num_videos_per_prompt=1,
                                    attention_mask=None,
                                    negative_prompt_embeds=None,
                                    negative_attention_mask=None,
                                    output_type="pil",
                                    freqs_cis=(freqs_cos, freqs_sin),
                                    n_tokens=n_tokens,
                                    data_type='video',
                                    is_progress_bar=True,
                                    vae_ver=self.args.vae,
                                    enable_tiling=self.args.vae_tiling,
                                    cpu_offload=args.cpu_offload,
                                    )[0]
            if samples is None:
                return None
            gen_time = time.time() - start_time
            self.logger.info(f"Success, time: {gen_time}")
            return {'samples': samples}
        finally:
            # Restore wav2vec on every exit path; otherwise a `None` result or an
            # exception would leave it on the CPU and break the next call.
            wav2vec.to(self.device)