import math
import time
import torch
import random
from loguru import logger
from einops import rearrange
from hymm_sp.diffusion import load_diffusion_pipeline
from hymm_sp.helpers import get_nd_rotary_pos_embed_new
from hymm_sp.inference import Inference
from hymm_sp.diffusion.schedulers import FlowMatchDiscreteScheduler
from hymm_sp.data_kits.audio_preprocessor import encode_audio, get_facemask
def align_to(value, alignment):
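    """Round `value` up to the nearest multiple of `alignment`."""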
return int(math.ceil(value / alignment) * alignment)
class HunyuanVideoSampler(Inference):
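    """Audio-driven video sampler built on top of the Hunyuan video diffusion pipeline."""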
def __init__(self, args, vae, vae_kwargs, text_encoder, model, text_encoder_2=None, pipeline=None,
device=0, logger=None):
super().__init__(args, vae, vae_kwargs, text_encoder, model, text_encoder_2=text_encoder_2,
pipeline=pipeline, device=device, logger=logger)
self.args = args
self.pipeline = load_diffusion_pipeline(
args, 0, self.vae, self.text_encoder, self.text_encoder_2, self.model,
device=self.device)
        print('Loaded Hunyuan model successfully.')
def get_rotary_pos_embed(self, video_length, height, width, concat_dict={}):
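        """Compute 3D rotary position embeddings (time, height, width) for the target video size.

        The latent grid size is derived from the VAE compression factors and the model patch size.
        `concat_dict` is passed through to `get_nd_rotary_pos_embed_new` and describes how extra
        frames (e.g. the reference frame) are concatenated along the time axis,
        e.g. {'mode': 'timecat', 'bias': -1}.
        """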
target_ndim = 3
ndim = 5 - 2
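        # "884"-style VAEs compress time by 4x and space by 8x; other VAEs keep the full frame
        # count and only compress space by 8x.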
        if '884' in self.args.vae:
            latents_size = [(video_length - 1) // 4 + 1, height // 8, width // 8]
        else:
            latents_size = [video_length, height // 8, width // 8]
if isinstance(self.model.patch_size, int):
assert all(s % self.model.patch_size == 0 for s in latents_size), \
f"Latent size(last {ndim} dimensions) should be divisible by patch size({self.model.patch_size}), " \
f"but got {latents_size}."
rope_sizes = [s // self.model.patch_size for s in latents_size]
elif isinstance(self.model.patch_size, list):
assert all(s % self.model.patch_size[idx] == 0 for idx, s in enumerate(latents_size)), \
f"Latent size(last {ndim} dimensions) should be divisible by patch size({self.model.patch_size}), " \
f"but got {latents_size}."
rope_sizes = [s // self.model.patch_size[idx] for idx, s in enumerate(latents_size)]
if len(rope_sizes) != target_ndim:
rope_sizes = [1] * (target_ndim - len(rope_sizes)) + rope_sizes # time axis
head_dim = self.model.hidden_size // self.model.num_heads
rope_dim_list = self.model.rope_dim_list
if rope_dim_list is None:
rope_dim_list = [head_dim // target_ndim for _ in range(target_ndim)]
        assert sum(rope_dim_list) == head_dim, "sum(rope_dim_list) must equal the head_dim of the attention layer"
freqs_cos, freqs_sin = get_nd_rotary_pos_embed_new(rope_dim_list,
rope_sizes,
theta=self.args.rope_theta,
use_real=True,
theta_rescale_factor=1,
concat_dict=concat_dict)
return freqs_cos, freqs_sin
@torch.no_grad()
def predict(self,
args, batch, wav2vec, feature_extractor, align_instance,
**kwargs):
"""
Predict the image from the given text.
Args:
prompt (str or List[str]): The input text.
kwargs:
size (int): The (height, width) of the output image/video. Default is (256, 256).
video_length (int): The frame number of the output video. Default is 1.
seed (int or List[str]): The random seed for the generation. Default is a random integer.
negative_prompt (str or List[str]): The negative text prompt. Default is an empty string.
infer_steps (int): The number of inference steps. Default is 100.
guidance_scale (float): The guidance scale for the generation. Default is 6.0.
num_videos_per_prompt (int): The number of videos per prompt. Default is 1.
verbose (int): 0 for no log, 1 for all log, 2 for fewer log. Default is 1.
output_type (str): The output type of the image, can be one of `pil`, `np`, `pt`, `latent`.
Default is 'pil'.
"""
out_dict = dict()
prompt = batch['text_prompt'][0]
image_path = str(batch["image_path"][0])
audio_path = str(batch["audio_path"][0])
neg_prompt = "Aerial view, aerial view, overexposed, low quality, deformation, a poor composition, bad hands, bad teeth, bad eyes, bad limbs, distortion, blurring, Lens changes"
# videoid = batch['videoid'][0]
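        # Encode raw wav2vec features into per-frame audio embeddings, then pad along the frame
        # axis to 129 frames (or append 5 zero frames if the sequence is already longer).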
fps = batch["fps"].to(self.device)
audio_prompts = batch["audio_prompts"].to(self.device)
weight_dtype = audio_prompts.dtype
audio_prompts = [encode_audio(wav2vec, audio_feat.to(dtype=wav2vec.dtype), fps.item(), num_frames=batch["audio_len"][0]) for audio_feat in audio_prompts]
audio_prompts = torch.cat(audio_prompts, dim=0).to(device=self.device, dtype=weight_dtype)
if audio_prompts.shape[1] <= 129:
audio_prompts = torch.cat([audio_prompts, torch.zeros_like(audio_prompts[:, :1]).repeat(1,129-audio_prompts.shape[1], 1, 1, 1)], dim=1)
else:
audio_prompts = torch.cat([audio_prompts, torch.zeros_like(audio_prompts[:, :1]).repeat(1, 5, 1, 1, 1)], dim=1)
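        # Offload wav2vec to CPU to free GPU memory; it is moved back to the device after sampling.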
wav2vec.to("cpu")
torch.cuda.empty_cache()
uncond_audio_prompts = torch.zeros_like(audio_prompts[:,:129])
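        # Motion buckets, face masks from the reference image, and the reference frame repeated to
        # 129 frames and normalized to [-1, 1]; an all-zero copy provides the unconditional branch.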
motion_exp = batch["motion_bucket_id_exps"].to(self.device)
motion_pose = batch["motion_bucket_id_heads"].to(self.device)
        pixel_value_ref = batch['pixel_value_ref'].to(self.device)  # (b, f, c, h, w), values in [0, 255]
face_masks = get_facemask(pixel_value_ref.clone(), align_instance, area=3.0)
pixel_value_ref = pixel_value_ref.clone().repeat(1,129,1,1,1)
uncond_pixel_value_ref = torch.zeros_like(pixel_value_ref)
pixel_value_ref = pixel_value_ref / 127.5 - 1.
uncond_pixel_value_ref = uncond_pixel_value_ref * 2 - 1
pixel_value_ref_for_vae = rearrange(pixel_value_ref, "b f c h w -> b c f h w")
        uncond_pixel_value_ref_for_vae = rearrange(uncond_pixel_value_ref, "b f c h w -> b c f h w")
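        # Reference image tensor prepared for the LLaVA-based encoder inputs; the unconditional
        # branch reuses the same pixels.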
pixel_value_llava = batch["pixel_value_ref_llava"].to(self.device)
pixel_value_llava = rearrange(pixel_value_llava, "b f c h w -> (b f) c h w")
uncond_pixel_value_llava = pixel_value_llava.clone()
# ========== Encode reference latents ==========
vae_dtype = self.vae.dtype
with torch.autocast(device_type="cuda", dtype=vae_dtype, enabled=vae_dtype != torch.float32):
if args.cpu_offload:
self.vae.to('cuda')
self.vae.enable_tiling()
ref_latents = self.vae.encode(pixel_value_ref_for_vae.clone()).latent_dist.sample()
            uncond_ref_latents = self.vae.encode(uncond_pixel_value_ref_for_vae).latent_dist.sample()
self.vae.disable_tiling()
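            # Map the VAE outputs into the diffusion latent space (optional shift, then scaling factor).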
if hasattr(self.vae.config, 'shift_factor') and self.vae.config.shift_factor:
ref_latents.sub_(self.vae.config.shift_factor).mul_(self.vae.config.scaling_factor)
uncond_ref_latents.sub_(self.vae.config.shift_factor).mul_(self.vae.config.scaling_factor)
else:
ref_latents.mul_(self.vae.config.scaling_factor)
uncond_ref_latents.mul_(self.vae.config.scaling_factor)
if args.cpu_offload:
self.vae.to('cpu')
torch.cuda.empty_cache()
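        # Downsample the face masks to the latent spatial resolution so they align with the latents.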
face_masks = torch.nn.functional.interpolate(face_masks.float().squeeze(2),
(ref_latents.shape[-2],
ref_latents.shape[-1]),
mode="bilinear").unsqueeze(2).to(dtype=ref_latents.dtype)
size = (batch['pixel_value_ref'].shape[-2], batch['pixel_value_ref'].shape[-1])
target_length = 129
target_height = align_to(size[0], 16)
target_width = align_to(size[1], 16)
concat_dict = {'mode': 'timecat', 'bias': -1}
# concat_dict = {}
freqs_cos, freqs_sin = self.get_rotary_pos_embed(
target_length,
target_height,
target_width,
concat_dict)
n_tokens = freqs_cos.shape[0]
generator = torch.Generator(device=self.device).manual_seed(args.seed)
debug_str = f"""
prompt: {prompt}
image_path: {image_path}
audio_path: {audio_path}
negative_prompt: {neg_prompt}
seed: {args.seed}
fps: {fps.item()}
infer_steps: {args.infer_steps}
target_height: {target_height}
target_width: {target_width}
target_length: {target_length}
guidance_scale: {args.cfg_scale}
"""
self.logger.info(debug_str)
pipeline_kwargs = {
"cpu_offload": args.cpu_offload
}
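        # Run the diffusion pipeline with both conditional and unconditional (CFG) inputs.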
start_time = time.time()
samples = self.pipeline(prompt=prompt,
height=target_height,
width=target_width,
frame=target_length,
num_inference_steps=args.infer_steps,
guidance_scale=args.cfg_scale, # cfg scale
negative_prompt=neg_prompt,
num_images_per_prompt=args.num_images,
generator=generator,
prompt_embeds=None,
ref_latents=ref_latents, # [1, 16, 1, h//8, w//8]
uncond_ref_latents=uncond_ref_latents,
pixel_value_llava=pixel_value_llava, # [1, 3, 336, 336]
uncond_pixel_value_llava=uncond_pixel_value_llava,
face_masks=face_masks, # [b f h w]
audio_prompts=audio_prompts,
uncond_audio_prompts=uncond_audio_prompts,
motion_exp=motion_exp,
motion_pose=motion_pose,
fps=fps,
num_videos_per_prompt=1,
attention_mask=None,
negative_prompt_embeds=None,
negative_attention_mask=None,
output_type="pil",
freqs_cis=(freqs_cos, freqs_sin),
n_tokens=n_tokens,
data_type='video',
is_progress_bar=True,
vae_ver=self.args.vae,
enable_tiling=self.args.vae_tiling,
**pipeline_kwargs
)[0]
if samples is None:
return None
out_dict['samples'] = samples
gen_time = time.time() - start_time
logger.info(f"Success, time: {gen_time}")
wav2vec.to(self.device)
return out_dict