import math
import time
import torch
import random
from loguru import logger
from einops import rearrange
from hymm_sp.diffusion import load_diffusion_pipeline
from hymm_sp.helpers import get_nd_rotary_pos_embed_new
from hymm_sp.inference import Inference
from hymm_sp.diffusion.schedulers import FlowMatchDiscreteScheduler
from hymm_sp.data_kits.audio_preprocessor import encode_audio, get_facemask


def align_to(value, alignment):
    """Round ``value`` up to the nearest multiple of ``alignment``.

    Args:
        value: Number to align.
        alignment: Step size; the result is an integer multiple of it.

    Returns:
        int: The smallest multiple of ``alignment`` that is >= ``value``.
    """
    return int(math.ceil(value / alignment) * alignment)


class HunyuanVideoSampler(Inference):
    """Audio-driven video sampler built on top of ``Inference``.

    Construction loads the diffusion pipeline; ``predict`` prepares the
    conditioning tensors from a data batch and runs that pipeline.
    """

    def __init__(self, args, vae, vae_kwargs, text_encoder, model, text_encoder_2=None, pipeline=None,
                 device=0, logger=None):
        """Initialise the base ``Inference`` object and build the diffusion pipeline.

        All arguments are forwarded unchanged to ``Inference.__init__``.
        Note that the ``pipeline`` argument is passed to the base class, but
        ``self.pipeline`` is always replaced by a freshly loaded pipeline.
        """
        super().__init__(args, vae, vae_kwargs, text_encoder, model, text_encoder_2=text_encoder_2,
                         pipeline=pipeline, device=device, logger=logger)

        self.args = args
        self.pipeline = load_diffusion_pipeline(
            args, 0, self.vae, self.text_encoder, self.text_encoder_2, self.model,
            device=self.device)
        print('load hunyuan model successful... ')

    def get_rotary_pos_embed(self, video_length, height, width, concat_dict=None):
        """Build the rotary position embedding tables for a video of the given size.

        Args:
            video_length: Number of frames in pixel space.
            height: Frame height in pixels.
            width: Frame width in pixels.
            concat_dict: Optional options dict forwarded to
                ``get_nd_rotary_pos_embed_new``. ``None`` (the default) is
                treated as an empty dict.

        Returns:
            tuple: ``(freqs_cos, freqs_sin)`` as returned by
            ``get_nd_rotary_pos_embed_new``.

        Raises:
            AssertionError: If the latent size is not divisible by the model's
                patch size, or if ``rope_dim_list`` does not sum to the
                attention head dimension.
            TypeError: If ``self.model.patch_size`` is neither an int nor a
                list/tuple.
        """
        # A fresh dict per call: a `{}` default would be one shared object
        # handed to an external helper on every call.
        if concat_dict is None:
            concat_dict = {}

        target_ndim = 3
        ndim = 5 - 2  # only used in the assertion messages below

        # Spatial dims are always divided by 8; the frame axis is compressed
        # by 4 only when the VAE name contains '884'.
        if '884' in self.args.vae:
            latents_size = [(video_length - 1) // 4 + 1, height // 8, width // 8]
        else:
            latents_size = [video_length, height // 8, width // 8]

        patch_size = self.model.patch_size
        if isinstance(patch_size, int):
            assert all(s % patch_size == 0 for s in latents_size), \
                f"Latent size(last {ndim} dimensions) should be divisible by patch size({patch_size}), " \
                f"but got {latents_size}."
            rope_sizes = [s // patch_size for s in latents_size]
        elif isinstance(patch_size, (list, tuple)):
            assert all(s % patch_size[idx] == 0 for idx, s in enumerate(latents_size)), \
                f"Latent size(last {ndim} dimensions) should be divisible by patch size({patch_size}), " \
                f"but got {latents_size}."
            rope_sizes = [s // patch_size[idx] for idx, s in enumerate(latents_size)]
        else:
            # Previously this fell through and failed later with an opaque
            # UnboundLocalError on `rope_sizes`.
            raise TypeError(
                f"patch_size must be an int or a list/tuple of ints, got {type(patch_size).__name__}")

        if len(rope_sizes) != target_ndim:
            # Left-pad with 1s so there is one entry per rope axis.
            rope_sizes = [1] * (target_ndim - len(rope_sizes)) + rope_sizes  # time axis

        head_dim = self.model.hidden_size // self.model.num_heads
        rope_dim_list = self.model.rope_dim_list
        if rope_dim_list is None:
            rope_dim_list = [head_dim // target_ndim for _ in range(target_ndim)]
        assert sum(rope_dim_list) == head_dim, "sum(rope_dim_list) should equal to head_dim of attention layer"

        freqs_cos, freqs_sin = get_nd_rotary_pos_embed_new(rope_dim_list,
                                                           rope_sizes,
                                                           theta=self.args.rope_theta,
                                                           use_real=True,
                                                           theta_rescale_factor=1,
                                                           concat_dict=concat_dict)
        return freqs_cos, freqs_sin

    @torch.no_grad()
    def predict(self, args, batch, wav2vec, feature_extractor, align_instance, **kwargs):
        """Generate a video sample from one batch of reference image, audio and text.

        Args:
            args: Run configuration. Reads ``cpu_offload``, ``seed``,
                ``infer_steps``, ``cfg_scale`` and ``num_images``.
            batch: Mapping providing ``text_prompt``, ``image_path``,
                ``audio_path``, ``fps``, ``audio_prompts``, ``audio_len``,
                ``motion_bucket_id_exps``, ``motion_bucket_id_heads``,
                ``pixel_value_ref`` and ``pixel_value_ref_llava``. Only the
                first element of the list-like entries is used.
            wav2vec: Audio encoder passed to ``encode_audio``. It is moved to
                the CPU while the diffusion pipeline runs and moved back to
                ``self.device`` afterwards.
            feature_extractor: Unused here; kept for interface compatibility.
            align_instance: Passed to ``get_facemask`` to compute the face mask.
            **kwargs: Ignored.

        Returns:
            dict | None: ``{'samples': ...}`` holding the first element of the
            pipeline output, or ``None`` if that element is ``None``.
        """
        out_dict = dict()

        # Fixed clip length used for padding, reference repetition and rope.
        target_length = 129

        prompt = batch['text_prompt'][0]
        image_path = str(batch["image_path"][0])
        audio_path = str(batch["audio_path"][0])
        neg_prompt = "Aerial view, aerial view, overexposed, low quality, deformation, a poor composition, bad hands, bad teeth, bad eyes, bad limbs, distortion, blurring, Lens changes"
        # videoid = batch['videoid'][0]
        fps = batch["fps"].to(self.device)

        # ========== Encode audio features ==========
        audio_prompts = batch["audio_prompts"].to(self.device)
        weight_dtype = audio_prompts.dtype

        audio_prompts = [
            encode_audio(wav2vec, audio_feat.to(dtype=wav2vec.dtype), fps.item(),
                         num_frames=batch["audio_len"][0])
            for audio_feat in audio_prompts
        ]
        audio_prompts = torch.cat(audio_prompts, dim=0).to(device=self.device, dtype=weight_dtype)

        # Zero-pad along dim 1: up to `target_length` for short clips,
        # otherwise by 5 extra frames.
        if audio_prompts.shape[1] <= target_length:
            pad_frames = target_length - audio_prompts.shape[1]
        else:
            pad_frames = 5
        audio_prompts = torch.cat(
            [audio_prompts,
             torch.zeros_like(audio_prompts[:, :1]).repeat(1, pad_frames, 1, 1, 1)],
            dim=1)

        # The audio encoder is not needed again in this call; free GPU memory.
        wav2vec.to("cpu")
        torch.cuda.empty_cache()

        uncond_audio_prompts = torch.zeros_like(audio_prompts[:, :target_length])
        motion_exp = batch["motion_bucket_id_exps"].to(self.device)
        motion_pose = batch["motion_bucket_id_heads"].to(self.device)

        # ========== Prepare reference images ==========
        pixel_value_ref = batch['pixel_value_ref'].to(self.device)  # (b f c h w), value range [0, 255]
        face_masks = get_facemask(pixel_value_ref.clone(), align_instance, area=3.0)

        pixel_value_ref = pixel_value_ref.clone().repeat(1, target_length, 1, 1, 1)
        uncond_pixel_value_ref = torch.zeros_like(pixel_value_ref)
        pixel_value_ref = pixel_value_ref / 127.5 - 1.
        # All zeros * 2 - 1: the unconditional reference is filled with -1.
        uncond_pixel_value_ref = uncond_pixel_value_ref * 2 - 1

        pixel_value_ref_for_vae = rearrange(pixel_value_ref, "b f c h w -> b c f h w")
        uncond_uncond_pixel_value_ref = rearrange(uncond_pixel_value_ref, "b f c h w -> b c f h w")

        pixel_value_llava = batch["pixel_value_ref_llava"].to(self.device)
        pixel_value_llava = rearrange(pixel_value_llava, "b f c h w -> (b f) c h w")
        uncond_pixel_value_llava = pixel_value_llava.clone()

        # ========== Encode reference latents ==========
        vae_dtype = self.vae.dtype
        with torch.autocast(device_type="cuda", dtype=vae_dtype, enabled=vae_dtype != torch.float32):

            if args.cpu_offload:
                self.vae.to('cuda')

            self.vae.enable_tiling()
            ref_latents = self.vae.encode(pixel_value_ref_for_vae.clone()).latent_dist.sample()
            uncond_ref_latents = self.vae.encode(uncond_uncond_pixel_value_ref).latent_dist.sample()
            self.vae.disable_tiling()

            # Apply the VAE's latent normalisation in place.
            if hasattr(self.vae.config, 'shift_factor') and self.vae.config.shift_factor:
                ref_latents.sub_(self.vae.config.shift_factor).mul_(self.vae.config.scaling_factor)
                uncond_ref_latents.sub_(self.vae.config.shift_factor).mul_(self.vae.config.scaling_factor)
            else:
                ref_latents.mul_(self.vae.config.scaling_factor)
                uncond_ref_latents.mul_(self.vae.config.scaling_factor)

            if args.cpu_offload:
                self.vae.to('cpu')
                torch.cuda.empty_cache()

        # Resize the face mask to the latent spatial resolution.
        face_masks = torch.nn.functional.interpolate(face_masks.float().squeeze(2),
                                                     (ref_latents.shape[-2], ref_latents.shape[-1]),
                                                     mode="bilinear").unsqueeze(2).to(dtype=ref_latents.dtype)

        # ========== Target size and rotary embeddings ==========
        size = (batch['pixel_value_ref'].shape[-2], batch['pixel_value_ref'].shape[-1])
        target_height = align_to(size[0], 16)
        target_width = align_to(size[1], 16)
        concat_dict = {'mode': 'timecat', 'bias': -1}
        # concat_dict = {}
        freqs_cos, freqs_sin = self.get_rotary_pos_embed(
            target_length,
            target_height,
            target_width,
            concat_dict)
        n_tokens = freqs_cos.shape[0]

        generator = torch.Generator(device=self.device).manual_seed(args.seed)

        debug_str = f"""
                    prompt: {prompt}
                image_path: {image_path}
                audio_path: {audio_path}
           negative_prompt: {neg_prompt}
                      seed: {args.seed}
                       fps: {fps.item()}
               infer_steps: {args.infer_steps}
             target_height: {target_height}
              target_width: {target_width}
             target_length: {target_length}
            guidance_scale: {args.cfg_scale}
            """
        self.logger.info(debug_str)
        pipeline_kwargs = {
            "cpu_offload": args.cpu_offload
        }
        start_time = time.time()
        samples = self.pipeline(prompt=prompt,
                                height=target_height,
                                width=target_width,
                                frame=target_length,
                                num_inference_steps=args.infer_steps,
                                guidance_scale=args.cfg_scale,  # cfg scale
                                negative_prompt=neg_prompt,
                                num_images_per_prompt=args.num_images,
                                generator=generator,
                                prompt_embeds=None,

                                ref_latents=ref_latents,  # [1, 16, 1, h//8, w//8]
                                uncond_ref_latents=uncond_ref_latents,
                                pixel_value_llava=pixel_value_llava,  # [1, 3, 336, 336]
                                uncond_pixel_value_llava=uncond_pixel_value_llava,
                                face_masks=face_masks,  # [b f h w]
                                audio_prompts=audio_prompts,
                                uncond_audio_prompts=uncond_audio_prompts,
                                motion_exp=motion_exp,
                                motion_pose=motion_pose,
                                fps=fps,

                                num_videos_per_prompt=1,
                                attention_mask=None,
                                negative_attention_mask=None,
                                negative_prompt_embeds=None,
                                output_type="pil",
                                freqs_cis=(freqs_cos, freqs_sin),
                                n_tokens=n_tokens,
                                data_type='video',
                                is_progress_bar=True,
                                vae_ver=self.args.vae,
                                enable_tiling=self.args.vae_tiling,
                                **pipeline_kwargs
                                )[0]

        # Restore the audio encoder before any return path. Previously the
        # `samples is None` early return left it on the CPU for the next call.
        wav2vec.to(self.device)

        if samples is None:
            return None
        out_dict['samples'] = samples
        gen_time = time.time() - start_time
        logger.info(f"Success, time: {gen_time}")

        return out_dict