from typing import Any, List, Tuple, Optional, Union, Dict
from einops import rearrange
import torch
import torch.nn as nn
import torch.nn.functional as F
from diffusers.models import ModelMixin
from diffusers.configuration_utils import ConfigMixin, register_to_config

from .activation_layers import get_activation_layer
from .norm_layers import get_norm_layer
from .embed_layers import TimestepEmbedder, PatchEmbed, TextProjection
from .attenion import attention, parallel_attention, get_cu_seqlens
from .posemb_layers import apply_rotary_emb
from .mlp_layers import MLP, MLPEmbedder, FinalLayer
from .modulate_layers import ModulateDiT, modulate, modulate_, apply_gate, apply_gate_and_accumulate_
from .token_refiner import SingleTokenRefiner
import numpy as np
from mmgp import offload
from wan.modules.attention import pay_attention
from .audio_adapters import AudioProjNet2, PerceiverAttentionCA


def get_linear_split_map():
    hidden_size = 3072
    split_linear_modules_map = {
        "img_attn_qkv": {
            "mapped_modules": ["img_attn_q", "img_attn_k", "img_attn_v"],
            "split_sizes": [hidden_size, hidden_size, hidden_size],
        },
        "linear1": {
            "mapped_modules": ["linear1_attn_q", "linear1_attn_k", "linear1_attn_v", "linear1_mlp"],
            "split_sizes": [hidden_size, hidden_size, hidden_size, 7 * hidden_size - 3 * hidden_size],
        },
    }
    return split_linear_modules_map


try:
    from xformers.ops.fmha.attn_bias import BlockDiagonalPaddedKeysMask
except ImportError:
    BlockDiagonalPaddedKeysMask = None


class MMDoubleStreamBlock(nn.Module):
    """
    A multimodal DiT block with separate modulation for text and image/video,
    see more details (SD3): https://arxiv.org/abs/2403.03206
                    (Flux.1): https://github.com/black-forest-labs/flux
    """

    def __init__(
        self,
        hidden_size: int,
        heads_num: int,
        mlp_width_ratio: float,
        mlp_act_type: str = "gelu_tanh",
        qk_norm: bool = True,
        qk_norm_type: str = "rms",
        qkv_bias: bool = False,
        dtype: Optional[torch.dtype] = None,
        device: Optional[torch.device] = None,
        attention_mode: str = "sdpa",
    ):
        factory_kwargs = {"device": device, "dtype": dtype}
        super().__init__()
        self.attention_mode = attention_mode
        self.deterministic = False
        self.heads_num = heads_num
        head_dim = hidden_size // heads_num
        mlp_hidden_dim = int(hidden_size * mlp_width_ratio)

        self.img_mod = ModulateDiT(
            hidden_size,
            factor=6,
            act_layer=get_activation_layer("silu"),
            **factory_kwargs,
        )
        self.img_norm1 = nn.LayerNorm(
            hidden_size, elementwise_affine=False, eps=1e-6, **factory_kwargs
        )

        self.img_attn_qkv = nn.Linear(
            hidden_size, hidden_size * 3, bias=qkv_bias, **factory_kwargs
        )
        qk_norm_layer = get_norm_layer(qk_norm_type)
        self.img_attn_q_norm = (
            qk_norm_layer(head_dim, elementwise_affine=True, eps=1e-6, **factory_kwargs)
            if qk_norm
            else nn.Identity()
        )
        self.img_attn_k_norm = (
            qk_norm_layer(head_dim, elementwise_affine=True, eps=1e-6, **factory_kwargs)
            if qk_norm
            else nn.Identity()
        )
        self.img_attn_proj = nn.Linear(
            hidden_size, hidden_size, bias=qkv_bias, **factory_kwargs
        )

        self.img_norm2 = nn.LayerNorm(
            hidden_size, elementwise_affine=False, eps=1e-6, **factory_kwargs
        )
        self.img_mlp = MLP(
            hidden_size,
            mlp_hidden_dim,
            act_layer=get_activation_layer(mlp_act_type),
            bias=True,
            **factory_kwargs,
        )

        self.txt_mod = ModulateDiT(
            hidden_size,
            factor=6,
            act_layer=get_activation_layer("silu"),
            **factory_kwargs,
        )
        self.txt_norm1 = nn.LayerNorm(
            hidden_size, elementwise_affine=False, eps=1e-6, **factory_kwargs
        )

        self.txt_attn_qkv = nn.Linear(
            hidden_size, hidden_size * 3, bias=qkv_bias, **factory_kwargs
        )
        self.txt_attn_q_norm = (
            qk_norm_layer(head_dim, elementwise_affine=True, eps=1e-6, **factory_kwargs)
            if qk_norm
            else nn.Identity()
        )
        self.txt_attn_k_norm = (
            qk_norm_layer(head_dim, elementwise_affine=True, eps=1e-6, **factory_kwargs)
            if qk_norm
            else nn.Identity()
        )
        self.txt_attn_proj = nn.Linear(
            hidden_size, hidden_size, bias=qkv_bias, **factory_kwargs
        )
        self.txt_norm2 = nn.LayerNorm(
            hidden_size, elementwise_affine=False, eps=1e-6, **factory_kwargs
        )
        self.txt_mlp = MLP(
            hidden_size,
            mlp_hidden_dim,
            act_layer=get_activation_layer(mlp_act_type),
            bias=True,
            **factory_kwargs,
        )
        self.hybrid_seq_parallel_attn = None

    def enable_deterministic(self):
        self.deterministic = True

    def disable_deterministic(self):
        self.deterministic = False

    def forward(
        self,
        img: torch.Tensor,
        txt: torch.Tensor,
        vec: torch.Tensor,
        attn_mask=None,
        seqlens_q: Optional[torch.Tensor] = None,
        seqlens_kv: Optional[torch.Tensor] = None,
        freqs_cis: tuple = None,
        condition_type: str = None,
        token_replace_vec: torch.Tensor = None,
        frist_frame_token_num: int = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:

        if condition_type == "token_replace":
            img_mod1, token_replace_img_mod1 = self.img_mod(
                vec, condition_type=condition_type, token_replace_vec=token_replace_vec
            )
            (img_mod1_shift,
             img_mod1_scale,
             img_mod1_gate,
             img_mod2_shift,
             img_mod2_scale,
             img_mod2_gate) = img_mod1.chunk(6, dim=-1)
            (tr_img_mod1_shift,
             tr_img_mod1_scale,
             tr_img_mod1_gate,
             tr_img_mod2_shift,
             tr_img_mod2_scale,
             tr_img_mod2_gate) = token_replace_img_mod1.chunk(6, dim=-1)
        else:
            (
                img_mod1_shift,
                img_mod1_scale,
                img_mod1_gate,
                img_mod2_shift,
                img_mod2_scale,
                img_mod2_gate,
            ) = self.img_mod(vec).chunk(6, dim=-1)
        (
            txt_mod1_shift,
            txt_mod1_scale,
            txt_mod1_gate,
            txt_mod2_shift,
            txt_mod2_scale,
            txt_mod2_gate,
        ) = self.txt_mod(vec).chunk(6, dim=-1)

        ##### Enjoy this spaghetti of VRAM optimizations done by DeepBeepMeep!
        # I am sure you are a nice person and as you copy this code, you will give me officially proper credits:
        # Please link to https://github.com/deepbeepmeep/HunyuanVideoGP and @deepbeepmeep on twitter.

        # Prepare image for attention.
        img_modulated = self.img_norm1(img)
        img_modulated = img_modulated.to(torch.bfloat16)
        if condition_type == "token_replace":
            modulate_(img_modulated[:, :frist_frame_token_num],
                      shift=tr_img_mod1_shift, scale=tr_img_mod1_scale)
            modulate_(img_modulated[:, frist_frame_token_num:],
                      shift=img_mod1_shift, scale=img_mod1_scale)
        else:
            modulate_(img_modulated, shift=img_mod1_shift, scale=img_mod1_scale)

        # img_attn_q / img_attn_k / img_attn_v are the split parts of img_attn_qkv
        # (see get_linear_split_map at the top of this file).
        shape = (*img_modulated.shape[:2], self.heads_num, int(img_modulated.shape[-1] / self.heads_num))
        img_q = self.img_attn_q(img_modulated).view(*shape)
        img_k = self.img_attn_k(img_modulated).view(*shape)
        img_v = self.img_attn_v(img_modulated).view(*shape)
        del img_modulated

        # Apply QK-Norm if needed.
        self.img_attn_q_norm.apply_(img_q).to(img_v)
        img_q_len = img_q.shape[1]
        self.img_attn_k_norm.apply_(img_k).to(img_v)
        img_kv_len = img_k.shape[1]
        batch_size = img_k.shape[0]

        # Apply RoPE if needed.
        qklist = [img_q, img_k]
        del img_q, img_k
        img_q, img_k = apply_rotary_emb(qklist, freqs_cis, head_first=False)

        # Prepare txt for attention.
        txt_modulated = self.txt_norm1(txt)
        modulate_(txt_modulated, shift=txt_mod1_shift, scale=txt_mod1_scale)
        txt_qkv = self.txt_attn_qkv(txt_modulated)
        del txt_modulated
        txt_q, txt_k, txt_v = rearrange(
            txt_qkv, "B L (K H D) -> K B L H D", K=3, H=self.heads_num
        )
        del txt_qkv

        # Apply QK-Norm if needed.
        self.txt_attn_q_norm.apply_(txt_q).to(txt_v)
        self.txt_attn_k_norm.apply_(txt_k).to(txt_v)

        # Run actual attention.
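        # The image/video tokens and the text tokens are packed into a single sequence:
        # img tokens come first, txt tokens are appended after them, and pay_attention
        # receives the per-sample valid lengths through q_lens / k_lens.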
        q = torch.cat((img_q, txt_q), dim=1)
        del img_q, txt_q
        k = torch.cat((img_k, txt_k), dim=1)
        del img_k, txt_k
        v = torch.cat((img_v, txt_v), dim=1)
        del img_v, txt_v

        # attention computation start
        qkv_list = [q, k, v]
        del q, k, v
        attn = pay_attention(
            qkv_list,
            attention_mask=attn_mask,
            q_lens=seqlens_q,
            k_lens=seqlens_kv,
        )
        b, s, a, d = attn.shape
        attn = attn.reshape(b, s, -1)
        del qkv_list
        # attention computation end

        img_attn, txt_attn = attn[:, : img.shape[1]], attn[:, img.shape[1]:]
        del attn

        # Calculate the img blocks.
        if condition_type == "token_replace":
            img_attn = self.img_attn_proj(img_attn)
            apply_gate_and_accumulate_(img[:, :frist_frame_token_num],
                                       img_attn[:, :frist_frame_token_num],
                                       gate=tr_img_mod1_gate)
            apply_gate_and_accumulate_(img[:, frist_frame_token_num:],
                                       img_attn[:, frist_frame_token_num:],
                                       gate=img_mod1_gate)
            del img_attn
            img_modulated = self.img_norm2(img)
            img_modulated = img_modulated.to(torch.bfloat16)
            modulate_(img_modulated[:, :frist_frame_token_num],
                      shift=tr_img_mod2_shift, scale=tr_img_mod2_scale)
            modulate_(img_modulated[:, frist_frame_token_num:],
                      shift=img_mod2_shift, scale=img_mod2_scale)
            self.img_mlp.apply_(img_modulated)
            apply_gate_and_accumulate_(img[:, :frist_frame_token_num],
                                       img_modulated[:, :frist_frame_token_num],
                                       gate=tr_img_mod2_gate)
            apply_gate_and_accumulate_(img[:, frist_frame_token_num:],
                                       img_modulated[:, frist_frame_token_num:],
                                       gate=img_mod2_gate)
            del img_modulated
        else:
            img_attn = self.img_attn_proj(img_attn)
            apply_gate_and_accumulate_(img, img_attn, gate=img_mod1_gate)
            del img_attn
            img_modulated = self.img_norm2(img)
            img_modulated = img_modulated.to(torch.bfloat16)
            modulate_(img_modulated, shift=img_mod2_shift, scale=img_mod2_scale)
            self.img_mlp.apply_(img_modulated)
            apply_gate_and_accumulate_(img, img_modulated, gate=img_mod2_gate)
            del img_modulated

        # Calculate the txt blocks.
        txt_attn = self.txt_attn_proj(txt_attn)
        apply_gate_and_accumulate_(txt, txt_attn, gate=txt_mod1_gate)
        del txt_attn
        txt_modulated = self.txt_norm2(txt)
        txt_modulated = txt_modulated.to(torch.bfloat16)
        modulate_(txt_modulated, shift=txt_mod2_shift, scale=txt_mod2_scale)
        txt_mlp = self.txt_mlp(txt_modulated)
        del txt_modulated
        apply_gate_and_accumulate_(txt, txt_mlp, gate=txt_mod2_gate)
        return img, txt


class MMSingleStreamBlock(nn.Module):
    """
    A DiT block with parallel linear layers as described in
    https://arxiv.org/abs/2302.05442 and adapted modulation interface.
    Also refer to (SD3): https://arxiv.org/abs/2403.03206
                  (Flux.1): https://github.com/black-forest-labs/flux
    """

    def __init__(
        self,
        hidden_size: int,
        heads_num: int,
        mlp_width_ratio: float = 4.0,
        mlp_act_type: str = "gelu_tanh",
        qk_norm: bool = True,
        qk_norm_type: str = "rms",
        qk_scale: float = None,
        dtype: Optional[torch.dtype] = None,
        device: Optional[torch.device] = None,
        attention_mode: str = "sdpa",
    ):
        factory_kwargs = {"device": device, "dtype": dtype}
        super().__init__()
        self.attention_mode = attention_mode
        self.deterministic = False
        self.hidden_size = hidden_size
        self.heads_num = heads_num
        head_dim = hidden_size // heads_num
        mlp_hidden_dim = int(hidden_size * mlp_width_ratio)
        self.mlp_hidden_dim = mlp_hidden_dim
        self.scale = qk_scale or head_dim ** -0.5

        # qkv and mlp_in
        self.linear1 = nn.Linear(
            hidden_size, hidden_size * 3 + mlp_hidden_dim, **factory_kwargs
        )
        # proj and mlp_out
        self.linear2 = nn.Linear(
            hidden_size + mlp_hidden_dim, hidden_size, **factory_kwargs
        )

        qk_norm_layer = get_norm_layer(qk_norm_type)
        self.q_norm = (
            qk_norm_layer(head_dim, elementwise_affine=True, eps=1e-6, **factory_kwargs)
            if qk_norm
            else nn.Identity()
        )
        self.k_norm = (
            qk_norm_layer(head_dim, elementwise_affine=True, eps=1e-6, **factory_kwargs)
            if qk_norm
            else nn.Identity()
        )

        self.pre_norm = nn.LayerNorm(
            hidden_size, elementwise_affine=False, eps=1e-6, **factory_kwargs
        )

        self.mlp_act = get_activation_layer(mlp_act_type)()
        self.modulation = ModulateDiT(
            hidden_size,
            factor=3,
            act_layer=get_activation_layer("silu"),
            **factory_kwargs,
        )
        self.hybrid_seq_parallel_attn = None

    def enable_deterministic(self):
        self.deterministic = True

    def disable_deterministic(self):
        self.deterministic = False

    def forward(
        self,
        # x: torch.Tensor,
        img: torch.Tensor,
        txt: torch.Tensor,
        vec: torch.Tensor,
        txt_len: int,
        attn_mask=None,
        seqlens_q: Optional[torch.Tensor] = None,
        seqlens_kv: Optional[torch.Tensor] = None,
        freqs_cis: Tuple[torch.Tensor, torch.Tensor] = None,
        condition_type: str = None,
        token_replace_vec: torch.Tensor = None,
        frist_frame_token_num: int = None,
    ) -> torch.Tensor:

        ##### More spaghetti VRAM optimizations done by DeepBeepMeep!
        # I am sure you are a nice person and as you copy this code, you will give me proper credits:
        # Please link to https://github.com/deepbeepmeep/HunyuanVideoGP and @deepbeepmeep on twitter.

        if condition_type == "token_replace":
            mod, tr_mod = self.modulation(vec,
                                          condition_type=condition_type,
                                          token_replace_vec=token_replace_vec)
            (mod_shift, mod_scale, mod_gate) = mod.chunk(3, dim=-1)
            (tr_mod_shift, tr_mod_scale, tr_mod_gate) = tr_mod.chunk(3, dim=-1)
        else:
            mod_shift, mod_scale, mod_gate = self.modulation(vec).chunk(3, dim=-1)

        img_mod = self.pre_norm(img)
        img_mod = img_mod.to(torch.bfloat16)
        if condition_type == "token_replace":
            modulate_(img_mod[:, :frist_frame_token_num], shift=tr_mod_shift, scale=tr_mod_scale)
            modulate_(img_mod[:, frist_frame_token_num:], shift=mod_shift, scale=mod_scale)
        else:
            modulate_(img_mod, shift=mod_shift, scale=mod_scale)
        txt_mod = self.pre_norm(txt)
        txt_mod = txt_mod.to(torch.bfloat16)
        modulate_(txt_mod, shift=mod_shift, scale=mod_scale)

        # linear1_attn_q / k / v and linear1_mlp are the split parts of linear1
        # (see get_linear_split_map at the top of this file).
        shape = (*img_mod.shape[:2], self.heads_num, int(img_mod.shape[-1] / self.heads_num))
        img_q = self.linear1_attn_q(img_mod).view(*shape)
        img_k = self.linear1_attn_k(img_mod).view(*shape)
        img_v = self.linear1_attn_v(img_mod).view(*shape)

        shape = (*txt_mod.shape[:2], self.heads_num, int(txt_mod.shape[-1] / self.heads_num))
        txt_q = self.linear1_attn_q(txt_mod).view(*shape)
        txt_k = self.linear1_attn_k(txt_mod).view(*shape)
        txt_v = self.linear1_attn_v(txt_mod).view(*shape)

        batch_size = img_mod.shape[0]

        # Apply QK-Norm if needed.
        # q = self.q_norm(q).to(v)
        self.q_norm.apply_(img_q)
        self.k_norm.apply_(img_k)
        self.q_norm.apply_(txt_q)
        self.k_norm.apply_(txt_k)

        qklist = [img_q, img_k]
        del img_q, img_k
        img_q, img_k = apply_rotary_emb(qklist, freqs_cis, head_first=False)
        img_q_len = img_q.shape[1]
        q = torch.cat((img_q, txt_q), dim=1)
        del img_q, txt_q
        k = torch.cat((img_k, txt_k), dim=1)
        img_kv_len = img_k.shape[1]
        del img_k, txt_k
        v = torch.cat((img_v, txt_v), dim=1)
        del img_v, txt_v

        # attention computation start
        qkv_list = [q, k, v]
        del q, k, v
        attn = pay_attention(
            qkv_list,
            attention_mask=attn_mask,
            q_lens=seqlens_q,
            k_lens=seqlens_kv,
        )
        b, s, a, d = attn.shape
        attn = attn.reshape(b, s, -1)
        del qkv_list
        # attention computation end

        x_mod = torch.cat((img_mod, txt_mod), 1)
        del img_mod, txt_mod
        x_mod_shape = x_mod.shape
        x_mod = x_mod.view(-1, x_mod.shape[-1])
        chunk_size = int(x_mod_shape[1] / 6)
        x_chunks = torch.split(x_mod, chunk_size)
        attn = attn.view(-1, attn.shape[-1])
        attn_chunks = torch.split(attn, chunk_size)
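        # Recombine the attention output and the MLP branch in chunks: each chunk is
        # projected through linear2 and written back in place, so the full
        # (attn | mlp) activation never has to be materialized at once.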
        for x_chunk, attn_chunk in zip(x_chunks, attn_chunks):
            mlp_chunk = self.linear1_mlp(x_chunk)
            mlp_chunk = self.mlp_act(mlp_chunk)
            attn_mlp_chunk = torch.cat((attn_chunk, mlp_chunk), -1)
            del attn_chunk, mlp_chunk
            x_chunk[...] = self.linear2(attn_mlp_chunk)
            del attn_mlp_chunk
        x_mod = x_mod.view(x_mod_shape)

        if condition_type == "token_replace":
            apply_gate_and_accumulate_(img[:, :frist_frame_token_num, :],
                                       x_mod[:, :frist_frame_token_num, :],
                                       gate=tr_mod_gate)
            apply_gate_and_accumulate_(img[:, frist_frame_token_num:, :],
                                       x_mod[:, frist_frame_token_num:-txt_len, :],
                                       gate=mod_gate)
        else:
            apply_gate_and_accumulate_(img, x_mod[:, :-txt_len, :], gate=mod_gate)
        apply_gate_and_accumulate_(txt, x_mod[:, -txt_len:, :], gate=mod_gate)

        return img, txt


class HYVideoDiffusionTransformer(ModelMixin, ConfigMixin):
    """
    HunyuanVideo Transformer backbone

    Inherited from ModelMixin and ConfigMixin for compatibility with diffusers'
    sampler StableDiffusionPipeline.

    Reference:
    [1] Flux.1: https://github.com/black-forest-labs/flux
    [2] MMDiT: http://arxiv.org/abs/2403.03206

    Parameters
    ----------
    args: argparse.Namespace
        The arguments parsed by argparse.
    patch_size: list
        The size of the patch.
    in_channels: int
        The number of input channels.
    out_channels: int
        The number of output channels.
    hidden_size: int
        The hidden size of the transformer backbone.
    heads_num: int
        The number of attention heads.
    mlp_width_ratio: float
        The ratio of the hidden size of the MLP in the transformer block.
    mlp_act_type: str
        The activation function of the MLP in the transformer block.
    depth_double_blocks: int
        The number of transformer blocks in the double blocks.
    depth_single_blocks: int
        The number of transformer blocks in the single blocks.
    rope_dim_list: list
        The dimension of the rotary embedding for t, h, w.
    qkv_bias: bool
        Whether to use bias in the qkv linear layer.
    qk_norm: bool
        Whether to use qk norm.
    qk_norm_type: str
        The type of qk norm.
    guidance_embed: bool
        Whether to use guidance embedding for distillation.
    text_projection: str
        The type of the text projection, default is single_refiner.
    use_attention_mask: bool
        Whether to use attention mask for text encoder.
    dtype: torch.dtype
        The dtype of the model.
    device: torch.device
        The device of the model.
    """

    def preprocess_loras(self, model_type, sd):
        if model_type != "i2v":
            return sd

        new_sd = {}
        for k, v in sd.items():
            repl_list = ["double_blocks", "single_blocks", "final_layer",
                         "img_mlp", "img_attn_qkv", "img_attn_proj", "img_mod",
                         "txt_mlp", "txt_attn_qkv", "txt_attn_proj", "txt_mod",
                         "linear1", "linear2", "modulation", "mlp_fc1"]
            src_list = [name + "_" for name in repl_list] + ["_" + name for name in repl_list]
            tgt_list = [name + "." for name in repl_list] + ["." + name for name in repl_list]

            if k.startswith("Hunyuan_video_I2V_lora_"):
                # crappy conversion script for non reversible lora naming
                k = k.replace("Hunyuan_video_I2V_lora_", "diffusion_model.")
                k = k.replace("lora_up", "lora_B")
                k = k.replace("lora_down", "lora_A")
                if "txt_in_individual" in k:
                    pass
                for s, t in zip(src_list, tgt_list):
                    k = k.replace(s, t)
                if "individual_token_refiner" in k:
                    k = k.replace("txt_in_individual_token_refiner_blocks_",
                                  "txt_in.individual_token_refiner.blocks.")
                    k = k.replace("_mlp_fc", ".mlp.fc")
                    k = k.replace(".mlp_fc", ".mlp.fc")
            new_sd[k] = v
        return new_sd

    @register_to_config
    def __init__(
        self,
        i2v_condition_type,
        patch_size: list = [1, 2, 2],
        in_channels: int = 4,  # Should be VAE.config.latent_channels.
        out_channels: int = None,
        hidden_size: int = 3072,
        heads_num: int = 24,
        mlp_width_ratio: float = 4.0,
        mlp_act_type: str = "gelu_tanh",
        mm_double_blocks_depth: int = 20,
        mm_single_blocks_depth: int = 40,
        rope_dim_list: List[int] = [16, 56, 56],
        qkv_bias: bool = True,
        qk_norm: bool = True,
        qk_norm_type: str = "rms",
        guidance_embed: bool = False,  # For modulation.
        text_projection: str = "single_refiner",
        use_attention_mask: bool = True,
        dtype: Optional[torch.dtype] = None,
        device: Optional[torch.device] = None,
        attention_mode: Optional[str] = "sdpa",
        video_condition: bool = False,
        audio_condition: bool = False,
        avatar=False,
        custom=False,
    ):
        factory_kwargs = {"device": device, "dtype": dtype}
        super().__init__()

        # mm_double_blocks_depth, mm_single_blocks_depth = 5, 5
        self.patch_size = patch_size
        self.in_channels = in_channels
        self.out_channels = in_channels if out_channels is None else out_channels
        self.unpatchify_channels = self.out_channels
        self.guidance_embed = guidance_embed
        self.rope_dim_list = rope_dim_list
        self.i2v_condition_type = i2v_condition_type
        self.attention_mode = attention_mode
        self.video_condition = video_condition
        self.audio_condition = audio_condition
        self.avatar = avatar
        self.custom = custom

        # Text projection. Default to linear projection.
        # Alternative: TokenRefiner. See more details (LI-DiT): http://arxiv.org/abs/2406.11831
        self.use_attention_mask = use_attention_mask
        self.text_projection = text_projection

        self.text_states_dim = 4096
        self.text_states_dim_2 = 768

        if hidden_size % heads_num != 0:
            raise ValueError(
                f"Hidden size {hidden_size} must be divisible by heads_num {heads_num}"
            )
        pe_dim = hidden_size // heads_num
        if sum(rope_dim_list) != pe_dim:
            raise ValueError(
                f"Got {rope_dim_list} but expected positional dim {pe_dim}"
            )
        self.hidden_size = hidden_size
        self.heads_num = heads_num

        # image projection
        self.img_in = PatchEmbed(
            self.patch_size, self.in_channels, self.hidden_size, **factory_kwargs
        )

        # text projection
        if self.text_projection == "linear":
            self.txt_in = TextProjection(
                self.text_states_dim,
                self.hidden_size,
                get_activation_layer("silu"),
                **factory_kwargs,
            )
        elif self.text_projection == "single_refiner":
            self.txt_in = SingleTokenRefiner(
                self.text_states_dim, hidden_size, heads_num, depth=2, **factory_kwargs
            )
        else:
            raise NotImplementedError(
                f"Unsupported text_projection: {self.text_projection}"
            )

        # time modulation
        self.time_in = TimestepEmbedder(
            self.hidden_size, get_activation_layer("silu"), **factory_kwargs
        )

        # text modulation
        self.vector_in = MLPEmbedder(
            self.text_states_dim_2, self.hidden_size, **factory_kwargs
        )

        # guidance modulation
        self.guidance_in = (
            TimestepEmbedder(
                self.hidden_size, get_activation_layer("silu"), **factory_kwargs
            )
            if guidance_embed
            else None
        )

        # double blocks
        self.double_blocks = nn.ModuleList(
            [
                MMDoubleStreamBlock(
                    self.hidden_size,
                    self.heads_num,
                    mlp_width_ratio=mlp_width_ratio,
                    mlp_act_type=mlp_act_type,
                    qk_norm=qk_norm,
                    qk_norm_type=qk_norm_type,
                    qkv_bias=qkv_bias,
                    attention_mode=attention_mode,
                    **factory_kwargs,
                )
                for _ in range(mm_double_blocks_depth)
            ]
        )

        # single blocks
        self.single_blocks = nn.ModuleList(
            [
                MMSingleStreamBlock(
                    self.hidden_size,
                    self.heads_num,
                    mlp_width_ratio=mlp_width_ratio,
                    mlp_act_type=mlp_act_type,
                    qk_norm=qk_norm,
                    qk_norm_type=qk_norm_type,
                    attention_mode=attention_mode,
                    **factory_kwargs,
                )
                for _ in range(mm_single_blocks_depth)
            ]
        )

        self.final_layer = FinalLayer(
            self.hidden_size,
            self.patch_size,
            self.out_channels,
            get_activation_layer("silu"),
            **factory_kwargs,
        )

        if self.video_condition:
            self.bg_in = PatchEmbed(
                self.patch_size, self.in_channels * 2, self.hidden_size, **factory_kwargs
            )
            self.bg_proj = nn.Linear(self.hidden_size, self.hidden_size)

        if audio_condition:
            if avatar:
                self.ref_in = PatchEmbed(
                    self.patch_size, self.in_channels, self.hidden_size, **factory_kwargs
                )

                # -------------------- audio_proj_model --------------------
                self.audio_proj = AudioProjNet2(
                    seq_len=10, blocks=5, channels=384,
                    intermediate_dim=1024, output_dim=3072, context_tokens=4,
                )

                # -------------------- motion embedder --------------------
                self.motion_exp = TimestepEmbedder(
                    self.hidden_size // 4, get_activation_layer("silu"), **factory_kwargs
                )
                self.motion_pose = TimestepEmbedder(
                    self.hidden_size // 4, get_activation_layer("silu"), **factory_kwargs
                )
                self.fps_proj = TimestepEmbedder(
                    self.hidden_size, get_activation_layer("silu"), **factory_kwargs
                )
                self.before_proj = nn.Linear(self.hidden_size, self.hidden_size)

                # -------------------- audio_insert_model --------------------
                self.double_stream_list = [1, 3, 5, 7, 9, 11, 13, 15, 17, 19]
                audio_block_name = "audio_adapter_blocks"
            elif custom:
                self.audio_proj = AudioProjNet2(
                    seq_len=10, blocks=5, channels=384,
                    intermediate_dim=1024, output_dim=3072, context_tokens=4,
                )
                self.double_stream_list = [1, 3, 5, 7, 9, 11]
                audio_block_name = "audio_models"

            self.double_stream_map = {str(i): j for j, i in enumerate(self.double_stream_list)}
            self.single_stream_list = []
            self.single_stream_map = {
                str(i): j + len(self.double_stream_list)
                for j, i in enumerate(self.single_stream_list)
            }
            setattr(self, audio_block_name, nn.ModuleList([
                PerceiverAttentionCA(dim=3072, dim_head=1024, heads=33)
                for _ in range(len(self.double_stream_list) + len(self.single_stream_list))
            ]))

    def lock_layers_dtypes(self, dtype=torch.float32):
        layer_list = [self.final_layer, self.final_layer.linear, self.final_layer.adaLN_modulation[1]]
        target_dtype = dtype

        for current_layer_list, current_dtype in zip([layer_list], [target_dtype]):
            for layer in current_layer_list:
                layer._lock_dtype = dtype

                if hasattr(layer, "weight") and layer.weight.dtype != current_dtype:
                    layer.weight.data = layer.weight.data.to(current_dtype)
                    if hasattr(layer, "bias"):
                        layer.bias.data = layer.bias.data.to(current_dtype)

        self._lock_dtype = dtype

    def enable_deterministic(self):
        for block in self.double_blocks:
            block.enable_deterministic()
        for block in self.single_blocks:
            block.enable_deterministic()

    def disable_deterministic(self):
        for block in self.double_blocks:
            block.disable_deterministic()
        for block in self.single_blocks:
            block.disable_deterministic()

    def forward(
        self,
        x: torch.Tensor,
        t: torch.Tensor,  # Should be in range(0, 1000).
        ref_latents: torch.Tensor = None,
        text_states: torch.Tensor = None,
        text_mask: torch.Tensor = None,  # Now we don't use it.
        text_states_2: Optional[torch.Tensor] = None,  # Text embedding for modulation.
        freqs_cos: Optional[torch.Tensor] = None,
        freqs_sin: Optional[torch.Tensor] = None,
        guidance: torch.Tensor = None,  # Guidance for modulation, should be cfg_scale x 1000.
        pipeline=None,
        x_id=0,
        step_no=0,
        callback=None,
        audio_prompts=None,
        motion_exp=None,
        motion_pose=None,
        fps=None,
        face_mask=None,
        audio_strength=None,
        bg_latents=None,
    ) -> Union[torch.Tensor, Dict[str, torch.Tensor]]:

        img = x
        bsz, _, ot, oh, ow = x.shape
        del x
        txt = text_states
        tt, th, tw = (
            ot // self.patch_size[0],
            oh // self.patch_size[1],
            ow // self.patch_size[2],
        )

        # Prepare modulation vectors.
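        # "vec" is the per-sample modulation vector: the timestep embedding, plus the
        # pooled text embedding (vector_in) and, when provided, motion/pose/fps and
        # guidance embeddings.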
        vec = self.time_in(t)
        if motion_exp is not None:
            vec += self.motion_exp(motion_exp.view(-1)).view(bsz, -1)  # (b, 3072)
        if motion_pose is not None:
            vec += self.motion_pose(motion_pose.view(-1)).view(bsz, -1)  # (b, 3072)
        if fps is not None:
            vec += self.fps_proj(fps)  # (b, 3072)
        if audio_prompts is not None:
            audio_feature_all = self.audio_proj(audio_prompts)
            audio_feature_pad = audio_feature_all[:, :1].repeat(1, 3, 1, 1)
            audio_feature_all_insert = torch.cat([audio_feature_pad, audio_feature_all], dim=1).view(bsz, ot, 16, 3072)
            audio_feature_all = None

        if self.i2v_condition_type == "token_replace":
            token_replace_t = torch.zeros_like(t)
            token_replace_vec = self.time_in(token_replace_t)
            frist_frame_token_num = th * tw
        else:
            token_replace_vec = None
            frist_frame_token_num = None
            # token_replace_mask_img = None
            # token_replace_mask_txt = None

        # text modulation
        vec_2 = self.vector_in(text_states_2)
        del text_states_2
        vec += vec_2
        if self.i2v_condition_type == "token_replace":
            token_replace_vec += vec_2
        del vec_2

        # guidance modulation
        if self.guidance_embed:
            if guidance is None:
                raise ValueError(
                    "Didn't get guidance strength for guidance distilled model."
                )
            # our timestep_embedding is merged into guidance_in (TimestepEmbedder)
            vec += self.guidance_in(guidance)

        # Embed image and text.
        img, shape_mask = self.img_in(img)
        if self.avatar:
            ref_latents_first = ref_latents[:, :, :1].clone()
            ref_latents, _ = self.ref_in(ref_latents)
            ref_latents_first, _ = self.img_in(ref_latents_first)
        elif self.custom:
            if ref_latents is not None:
                ref_latents, _ = self.img_in(ref_latents)
            if bg_latents is not None and self.video_condition:
                bg_latents, _ = self.bg_in(bg_latents)
                img += self.bg_proj(bg_latents)

        if self.text_projection == "linear":
            txt = self.txt_in(txt)
        elif self.text_projection == "single_refiner":
            txt = self.txt_in(txt, t, text_mask if self.use_attention_mask else None)
        else:
            raise NotImplementedError(
                f"Unsupported text_projection: {self.text_projection}"
            )

        if self.avatar:
            img += self.before_proj(ref_latents)
            ref_length = ref_latents_first.shape[-2]  # [b s c]
            img = torch.cat([ref_latents_first, img], dim=-2)  # t c
            img_len = img.shape[1]
            mask_len = img_len - ref_length
            if face_mask.shape[2] == 1:
                face_mask = face_mask.repeat(1, 1, ot, 1, 1)  # repeat if the number of mask frames is 1
            face_mask = torch.nn.functional.interpolate(
                face_mask, size=[ot, shape_mask[-2], shape_mask[-1]], mode="nearest"
            )
            # face_mask = face_mask.view(-1, mask_len, 1).repeat(1, 1, img.shape[-1]).type_as(img)
            face_mask = face_mask.view(-1, mask_len, 1).type_as(img)
        elif ref_latents is None:
            ref_length = None
        else:
            ref_length = ref_latents.shape[-2]
            img = torch.cat([ref_latents, img], dim=-2)  # t c

        txt_seq_len = txt.shape[1]
        img_seq_len = img.shape[1]

        text_len = text_mask.sum(1)
        total_len = text_len + img_seq_len
        seqlens_q = seqlens_kv = total_len
        attn_mask = None

        freqs_cis = (freqs_cos, freqs_sin) if freqs_cos is not None else None
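        # Optional step caching (TeaCache-style): the relative change of the first double
        # block's modulated input is accumulated across steps and, while it stays below
        # rel_l1_thresh, the block stack is skipped and the cached residual is reused.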
        if self.enable_cache:
            if x_id == 0:
                self.should_calc = True
                inp = img[0:1]
                vec_ = vec[0:1]
                (
                    img_mod1_shift,
                    img_mod1_scale,
                    _,
                    _,
                    _,
                    _,
                ) = self.double_blocks[0].img_mod(vec_).chunk(6, dim=-1)
                normed_inp = self.double_blocks[0].img_norm1(inp)
                normed_inp = normed_inp.to(torch.bfloat16)
                modulated_inp = modulate(normed_inp, shift=img_mod1_shift, scale=img_mod1_scale)
                del normed_inp, img_mod1_shift, img_mod1_scale
                if step_no <= self.cache_start_step or step_no == self.num_steps - 1:
                    self.accumulated_rel_l1_distance = 0
                else:
                    coefficients = [7.33226126e+02, -4.01131952e+02, 6.75869174e+01, -3.14987800e+00, 9.61237896e-02]
                    rescale_func = np.poly1d(coefficients)
                    self.accumulated_rel_l1_distance += rescale_func(
                        ((modulated_inp - self.previous_modulated_input).abs().mean()
                         / self.previous_modulated_input.abs().mean()).cpu().item()
                    )
                    if self.accumulated_rel_l1_distance < self.rel_l1_thresh:
                        self.should_calc = False
                        self.teacache_skipped_steps += 1
                    else:
                        self.accumulated_rel_l1_distance = 0
                self.previous_modulated_input = modulated_inp
        else:
            self.should_calc = True

        if not self.should_calc:
            img += self.previous_residual[x_id]
        else:
            if self.enable_cache:
                self.previous_residual[x_id] = None
                ori_img = img[0:1].clone()

            # --------------------- Pass through DiT blocks ------------------------
            for layer_num, block in enumerate(self.double_blocks):
                for i in range(len(img)):
                    if callback is not None:
                        callback(-1, None, False, True)
                    if pipeline._interrupt:
                        return None
                    double_block_args = [
                        img[i:i+1],
                        txt[i:i+1],
                        vec[i:i+1],
                        attn_mask,
                        seqlens_q[i:i+1],
                        seqlens_kv[i:i+1],
                        freqs_cis,
                        self.i2v_condition_type,
                        token_replace_vec,
                        frist_frame_token_num,
                    ]
                    img[i], txt[i] = block(*double_block_args)
                    double_block_args = None
                    # insert audio feature into img
                    if audio_prompts is not None:
                        audio_adapter = getattr(self.double_blocks[layer_num], "audio_adapter", None)
                        if audio_adapter is not None:
                            real_img = img[i:i+1, ref_length:].view(1, ot, -1, 3072)
                            real_img = audio_adapter(audio_feature_all_insert[i:i+1], real_img).view(1, -1, 3072)
                            if face_mask is not None:
                                real_img *= face_mask[i:i+1]
                            if audio_strength is not None and audio_strength != 1:
                                real_img *= audio_strength
                            img[i:i+1, ref_length:] += real_img
                            real_img = None

            for _, block in enumerate(self.single_blocks):
                for i in range(len(img)):
                    if callback is not None:
                        callback(-1, None, False, True)
                    if pipeline._interrupt:
                        return None
                    single_block_args = [
                        # x,
                        img[i:i+1],
                        txt[i:i+1],
                        vec[i:i+1],
                        txt_seq_len,
                        attn_mask,
                        seqlens_q[i:i+1],
                        seqlens_kv[i:i+1],
                        (freqs_cos, freqs_sin),
                        self.i2v_condition_type,
                        token_replace_vec,
                        frist_frame_token_num,
                    ]
                    img[i], txt[i] = block(*single_block_args)
                    single_block_args = None

            # img = x[:, :img_seq_len, ...]
            if self.enable_cache:
                if len(img) > 1:
                    self.previous_residual[0] = torch.empty_like(img)
                    for i, (x, residual) in enumerate(zip(img, self.previous_residual[0])):
                        if i < len(img) - 1:
                            residual[...] = torch.sub(x, ori_img)
                        else:
                            residual[...] = ori_img
                            torch.sub(x, ori_img, out=residual)
                    x = None
                else:
                    self.previous_residual[x_id] = ori_img
                    torch.sub(img, ori_img, out=self.previous_residual[x_id])

        if ref_length is not None:
            img = img[:, ref_length:]

        # ---------------------------- Final layer ------------------------------
        out_dtype = self.final_layer.linear.weight.dtype
        vec = vec.to(out_dtype)
        img_list = []
        for img_chunk, vec_chunk in zip(img, vec):
            img_list.append(
                self.final_layer(img_chunk.to(out_dtype).unsqueeze(0), vec_chunk.unsqueeze(0))
            )  # (N, T, patch_size ** 2 * out_channels)
        img = torch.cat(img_list)
        img_list = None

        img = self.unpatchify(img, tt, th, tw)
        return img

    def unpatchify(self, x, t, h, w):
        """
        x: (N, T, patch_size**2 * C)
        imgs: (N, H, W, C)
        """
        c = self.unpatchify_channels
        pt, ph, pw = self.patch_size
        assert t * h * w == x.shape[1]

        x = x.reshape(shape=(x.shape[0], t, h, w, c, pt, ph, pw))
        x = torch.einsum("nthwcopq->nctohpwq", x)
        imgs = x.reshape(shape=(x.shape[0], c, t * pt, h * ph, w * pw))

        return imgs

    def params_count(self):
        counts = {
            "double": sum(
                [
                    sum(p.numel() for p in block.img_attn_qkv.parameters())
                    + sum(p.numel() for p in block.img_attn_proj.parameters())
                    + sum(p.numel() for p in block.img_mlp.parameters())
                    + sum(p.numel() for p in block.txt_attn_qkv.parameters())
                    + sum(p.numel() for p in block.txt_attn_proj.parameters())
                    + sum(p.numel() for p in block.txt_mlp.parameters())
                    for block in self.double_blocks
                ]
            ),
            "single": sum(
                [
                    sum(p.numel() for p in block.linear1.parameters())
                    + sum(p.numel() for p in block.linear2.parameters())
                    for block in self.single_blocks
                ]
            ),
            "total": sum(p.numel() for p in self.parameters()),
        }
        counts["attn+mlp"] = counts["double"] + counts["single"]
        return counts


#################################################################################
#                             HunyuanVideo Configs                              #
#################################################################################

HUNYUAN_VIDEO_CONFIG = {
    "HYVideo-T/2": {
        "mm_double_blocks_depth": 20,
        "mm_single_blocks_depth": 40,
        "rope_dim_list": [16, 56, 56],
        "hidden_size": 3072,
        "heads_num": 24,
        "mlp_width_ratio": 4,
    },
    "HYVideo-T/2-cfgdistill": {
        "mm_double_blocks_depth": 20,
        "mm_single_blocks_depth": 40,
        "rope_dim_list": [16, 56, 56],
        "hidden_size": 3072,
        "heads_num": 24,
        "mlp_width_ratio": 4,
        "guidance_embed": True,
    },
    "HYVideo-S/2": {
        "mm_double_blocks_depth": 6,
        "mm_single_blocks_depth": 12,
        "rope_dim_list": [12, 42, 42],
        "hidden_size": 480,
        "heads_num": 5,
        "mlp_width_ratio": 4,
    },
    "HYVideo-T/2-custom": {  # 9.0B / 12.5B
        "mm_double_blocks_depth": 20,
        "mm_single_blocks_depth": 40,
        "rope_dim_list": [16, 56, 56],
        "hidden_size": 3072,
        "heads_num": 24,
        "mlp_width_ratio": 4,
        "custom": True,
    },
    "HYVideo-T/2-custom-audio": {  # 9.0B / 12.5B
        "mm_double_blocks_depth": 20,
        "mm_single_blocks_depth": 40,
        "rope_dim_list": [16, 56, 56],
        "hidden_size": 3072,
        "heads_num": 24,
        "mlp_width_ratio": 4,
        "custom": True,
        "audio_condition": True,
    },
    "HYVideo-T/2-custom-edit": {  # 9.0B / 12.5B
        "mm_double_blocks_depth": 20,
        "mm_single_blocks_depth": 40,
        "rope_dim_list": [16, 56, 56],
        "hidden_size": 3072,
        "heads_num": 24,
        "mlp_width_ratio": 4,
        "custom": True,
        "video_condition": True,
    },
    "HYVideo-T/2-avatar": {  # 9.0B / 12.5B
        "mm_double_blocks_depth": 20,
        "mm_single_blocks_depth": 40,
        "rope_dim_list": [16, 56, 56],
        "hidden_size": 3072,
        "heads_num": 24,
        "mlp_width_ratio": 4,
        "avatar": True,
        "audio_condition": True,
    },
}
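

# Illustrative sketch (not part of the original module): how a model might be built from
# one of the configs above. The i2v_condition_type, in_channels, dtype, and device values
# below are placeholders, not values prescribed by this file; in_channels should match
# VAE.config.latent_channels, as noted in __init__.
#
#   cfg = HUNYUAN_VIDEO_CONFIG["HYVideo-T/2-cfgdistill"]
#   model = HYVideoDiffusionTransformer(
#       i2v_condition_type=None,   # placeholder
#       in_channels=16,            # placeholder; use VAE.config.latent_channels
#       **cfg,
#       dtype=torch.bfloat16,
#       device="cpu",
#   )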