import pdb
from typing import Any, Dict, Optional, Tuple, Union

import torch
import torch.nn.functional as F
from torch import nn
from einops import rearrange
from diffusers import CogVideoXTransformer3DModel
from diffusers.configuration_utils import ConfigMixin, FrozenDict, register_to_config
from diffusers.loaders import PeftAdapterMixin
from diffusers.models.attention import Attention, FeedForward
from diffusers.models.attention_processor import AttentionProcessor, AttnProcessor2_0
from diffusers.models.embeddings import CogVideoXPatchEmbed, TimestepEmbedding, Timesteps, get_3d_sincos_pos_embed
from diffusers.models.modeling_utils import ModelMixin
from diffusers.models.normalization import AdaLayerNorm, CogVideoXLayerNormZero, AdaLayerNormZeroSingle
from diffusers.models.transformers.cogvideox_transformer_3d import Transformer2DModelOutput, CogVideoXBlock
from diffusers.utils import is_torch_version
from diffusers.utils.torch_utils import maybe_allow_in_graph

from .cogvideox_custom_modules import CustomCogVideoXPatchEmbed, CustomCogVideoXBlock


class CogVideoXControlnet(ModelMixin, ConfigMixin, PeftAdapterMixin):
    """ControlNet-style branch built from the CogVideoX transformer trunk (FloVD).

    The video latents and the flow latents are concatenated channel-wise and
    patch-embedded together with the text embeddings. The result is then run
    through a stack of ``CustomCogVideoXBlock`` layers. ``forward`` returns the
    video-token features after every block, optionally passed through a per-block
    linear projection (``out_projectors``, zero-initialised by default).
    """

    _supports_gradient_checkpointing = True

    @register_to_config
    def __init__(
        self,
        num_attention_heads: int = 30,  # 48 for 5B, 30 for 2B.
        attention_head_dim: int = 64,
        in_channels: int = 16,
        out_channels: Optional[int] = 16,  # Not used
        flip_sin_to_cos: bool = True,
        freq_shift: int = 0,
        time_embed_dim: int = 512,
        ofs_embed_dim: Optional[int] = None,
        text_embed_dim: int = 4096,
        num_layers: int = 30,
        dropout: float = 0.0,
        attention_bias: bool = True,
        sample_width: int = 90,
        sample_height: int = 60,
        sample_frames: int = 49,
        patch_size: int = 2,
        patch_size_t: Optional[int] = None,
        temporal_compression_ratio: int = 4,
        max_text_seq_length: int = 226,
        activation_fn: str = "gelu-approximate",
        timestep_activation_fn: str = "silu",
        norm_elementwise_affine: bool = True,
        norm_eps: float = 1e-5,
        spatial_interpolation_scale: float = 1.875,
        temporal_interpolation_scale: float = 1.0,
        use_rotary_positional_embeddings: bool = False,
        use_learned_positional_embeddings: bool = False,
        patch_bias: bool = True,
        out_proj_dim_factor: Optional[int] = 8,
        out_proj_dim_zero_init: bool = True,
        notextinflow: bool = False,
    ):
        """Build the branch.

        Most arguments mirror ``CogVideoXTransformer3DModel`` so that its config can
        be reused by ``from_pretrained``. Several (``out_channels``, ``ofs_embed_dim``,
        ``text_embed_dim``, ``patch_size_t``, ``max_text_seq_length``, ``patch_bias``)
        are accepted for config compatibility but are not used here.

        Branch-specific arguments:
            out_proj_dim_factor: if not None, each block output is projected from
                ``num_attention_heads * attention_head_dim`` to
                ``num_attention_heads * out_proj_dim_factor`` features.
            out_proj_dim_zero_init: zero the projection weights and biases, so the
                branch outputs all-zero features at initialisation.
            notextinflow: stored and forwarded to every block through
                ``attention_kwargs``. It also controls whether the attention mask is
                padded to cover the text tokens.

        Raises:
            ValueError: learned positional embeddings requested without rotary
                embeddings.
        """
        super().__init__()
        inner_dim = num_attention_heads * attention_head_dim
        self.notextinflow = notextinflow

        if not use_rotary_positional_embeddings and use_learned_positional_embeddings:
            raise ValueError(
                "There are no CogVideoX checkpoints available with disable rotary embeddings and learned positional "
                "embeddings. If you're using a custom model and/or believe this should be supported, please open an "
                "issue at https://github.com/huggingface/diffusers/issues."
            )

        # NOTE: earlier revisions had a pixel-unshuffle + conv encoder for the control
        # input. FloVD passes flow latents already encoded by the 3D causal VAE, so that
        # encoder was removed. The flow latents are instead concatenated with the video
        # latents along the channel axis in ``forward``, so ``in_channels`` must cover
        # both inputs.
        # (``CustomCogVideoXPatchEmbed`` is imported for an alternative embedding that
        # keeps flow tokens separate; it is not used here.)

        # 1. Patch embedding
        self.patch_embed = CogVideoXPatchEmbed(
            patch_size=patch_size,
            in_channels=in_channels,
            embed_dim=inner_dim,
            bias=True,
            sample_width=sample_width,
            sample_height=sample_height,
            sample_frames=sample_frames,
            temporal_compression_ratio=temporal_compression_ratio,
            spatial_interpolation_scale=spatial_interpolation_scale,
            temporal_interpolation_scale=temporal_interpolation_scale,
            use_positional_embeddings=not use_rotary_positional_embeddings,
            use_learned_positional_embeddings=use_learned_positional_embeddings,
        )
        self.embedding_dropout = nn.Dropout(dropout)

        # 2. Time embeddings
        self.time_proj = Timesteps(inner_dim, flip_sin_to_cos, freq_shift)
        self.time_embedding = TimestepEmbedding(inner_dim, time_embed_dim, timestep_activation_fn)

        # 3. Spatio-temporal transformer blocks. CustomCogVideoXBlock is used instead of
        # the stock CogVideoXBlock, presumably because it consumes the
        # ``attention_mask`` / ``notextinflow`` entries of ``attention_kwargs``
        # built in ``forward`` (TODO confirm in cogvideox_custom_modules).
        self.transformer_blocks = nn.ModuleList(
            [
                CustomCogVideoXBlock(
                    dim=inner_dim,
                    num_attention_heads=num_attention_heads,
                    attention_head_dim=attention_head_dim,
                    time_embed_dim=time_embed_dim,
                    dropout=dropout,
                    activation_fn=activation_fn,
                    attention_bias=attention_bias,
                    norm_elementwise_affine=norm_elementwise_affine,
                    norm_eps=norm_eps,
                )
                for _ in range(num_layers)
            ]
        )

        # 4. Optional per-block output projections.
        self.out_projectors = None
        if out_proj_dim_factor is not None:
            out_proj_dim = num_attention_heads * out_proj_dim_factor
            self.out_projectors = nn.ModuleList(
                [nn.Linear(inner_dim, out_proj_dim) for _ in range(num_layers)]
            )
            if out_proj_dim_zero_init:
                for out_projector in self.out_projectors:
                    self.zeros_init_linear(out_projector)

        self.gradient_checkpointing = False

    def zeros_init_linear(self, linear: nn.Module):
        """Zero, in place, the weight and (if any) bias of an ``nn.Linear`` / ``nn.Conv1d``.

        Modules of any other type are left untouched.
        """
        if isinstance(linear, (nn.Linear, nn.Conv1d)):
            if getattr(linear, "weight", None) is not None:
                nn.init.zeros_(linear.weight)
            # Layers built with ``bias=False`` still have a ``bias`` attribute (set to
            # None), so ``hasattr`` is not a sufficient guard.
            if getattr(linear, "bias", None) is not None:
                nn.init.zeros_(linear.bias)

    def _set_gradient_checkpointing(self, module, value=False):
        """Toggle gradient checkpointing for the whole branch (``module`` is ignored)."""
        self.gradient_checkpointing = value

    def compress_time(self, x, num_frames):
        """Halve the temporal length of ``x`` by average-pooling pairs of frames.

        ``x`` is laid out as ``(b f) c h w`` with ``f == num_frames``. When the frame
        count is odd, the first frame is kept as-is and only the remaining frames
        are pooled, giving ``1 + (f - 1) // 2`` output frames; otherwise
        ``f // 2``. The result has the same ``(b f') c h w`` layout.

        Not called by ``forward``; retained from the removed control encoder.
        """
        x = rearrange(x, '(b f) c h w -> b f c h w', f=num_frames)
        batch_size, frames, channels, height, width = x.shape
        x = rearrange(x, 'b f c h w -> (b h w) c f')

        if x.shape[-1] % 2 == 1:
            x_first, x_rest = x[..., 0], x[..., 1:]
            if x_rest.shape[-1] > 0:
                x_rest = F.avg_pool1d(x_rest, kernel_size=2, stride=2)
            x = torch.cat([x_first[..., None], x_rest], dim=-1)
        else:
            x = F.avg_pool1d(x, kernel_size=2, stride=2)

        x = rearrange(x, '(b h w) c f -> (b f) c h w', b=batch_size, h=height, w=width)
        return x

    @classmethod
    def from_pretrained(cls, model_path, subfolder, **additional_kwargs):
        """Create the branch initialised from a pretrained ``CogVideoXTransformer3DModel``.

        The base model's config, overridden by ``additional_kwargs``, is used to
        construct this class. The base weights are then copied with
        ``strict=False``, so parameters that exist in only one of the two models
        (e.g. ``out_projectors``, or blocks beyond a reduced ``num_layers``) are
        skipped. Mismatched tensor shapes still raise.

        Args:
            model_path: path or hub id passed to
                ``CogVideoXTransformer3DModel.from_pretrained``.
            subfolder: subfolder holding the transformer weights.
            **additional_kwargs: config overrides (e.g. ``num_layers``,
                ``out_proj_dim_factor``, ``notextinflow``).
        """
        base = CogVideoXTransformer3DModel.from_pretrained(model_path, subfolder=subfolder)
        controlnet_config = FrozenDict({**base.config, **additional_kwargs})
        model = cls(**controlnet_config)
        # The returned missing/unexpected key lists are intentionally ignored; see docstring.
        model.load_state_dict(base.state_dict(), strict=False)
        print("Load CogVideoXTransformer3DModel.")

        # Drop the base model before returning to release its memory.
        del base
        torch.cuda.empty_cache()
        return model

    def _build_attention_mask(self, controlnet_valid_mask: torch.Tensor, text_seq_length: int) -> torch.Tensor:
        """Convert a dense validity mask into a per-token boolean attention mask.

        ``controlnet_valid_mask`` must be a 5-D floating tensor with a singleton
        dim 1, i.e. ``(B, 1, F, H, W)``. It is trilinearly resized to
        ``(F, H // patch_size, W // patch_size)``, thresholded at 0.5, and flattened
        to ``(B, F * H' * W')``. This assumes the mask is given at latent resolution,
        so the result matches the video-token grid — TODO confirm with callers.

        When ``notextinflow`` is False, ``text_seq_length`` False entries are
        prepended so the mask spans the ``[text, video]`` sequence.
        """
        patch_size = self.config.patch_size
        num_frames, height, width = controlnet_valid_mask.shape[2:]
        mask = F.interpolate(
            controlnet_valid_mask,
            size=(num_frames, height // patch_size, width // patch_size),
            mode='trilinear',
            align_corners=False,
        )
        mask = mask >= 0.5
        mask = rearrange(mask.squeeze(1), 'b f h w -> b (f h w)')

        if not self.notextinflow:
            # NOTE(review): this marks every text token as masked (False). The original
            # author flagged this choice as uncertain — confirm against how
            # CustomCogVideoXBlock interprets the mask.
            mask = F.pad(mask, (text_seq_length, 0), value=0.0)
        return mask

    def forward(
        self,
        hidden_states: torch.Tensor,
        encoder_hidden_states: torch.Tensor,
        controlnet_hidden_states: torch.Tensor,
        timestep: Union[int, float, torch.LongTensor],
        controlnet_valid_mask: Optional[torch.Tensor] = None,
        image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
        timestep_cond: Optional[torch.Tensor] = None,
        return_dict: bool = True,
    ):
        """Run the branch and collect per-block video-token features.

        Args:
            hidden_states: video latents laid out as ``(B, F, C, H, W)``.
            encoder_hidden_states: text embeddings; ``shape[1]`` is taken as the text
                sequence length.
            controlnet_hidden_states: flow latents (already VAE-encoded). They must
                match ``hidden_states`` on every dim except channels (dim 2), along
                which they are concatenated.
            timestep: diffusion timestep(s) fed to ``time_proj``.
            controlnet_valid_mask: optional ``(B, 1, F, H, W)`` float mask turned into
                a boolean token attention mask (see ``_build_attention_mask``).
            image_rotary_emb: rotary embeddings forwarded unchanged to every block.
            timestep_cond: optional conditioning forwarded to ``time_embedding``.
            return_dict: return a ``Transformer2DModelOutput`` if True, else a tuple.

        Returns:
            ``Transformer2DModelOutput`` whose ``sample`` is a tuple with one tensor per
            transformer block: the video tokens after that block, passed through the
            matching ``out_projectors`` layer when projections are enabled. With
            ``return_dict=False``, that tuple is wrapped in a 1-tuple.
        """
        # Flow latents are fused with the video latents channel-wise and embedded jointly.
        hidden_states = torch.cat([hidden_states, controlnet_hidden_states], dim=2)

        # 1. Time embedding
        t_emb = self.time_proj(timestep)
        # time_proj has no weights and always returns float32 tensors, but
        # time_embedding may be running in fp16, so cast to the activation dtype.
        t_emb = t_emb.to(dtype=hidden_states.dtype)
        emb = self.time_embedding(t_emb, timestep_cond)

        # 2. Patch embedding: text tokens followed by video tokens along dim 1.
        hidden_states = self.patch_embed(encoder_hidden_states, hidden_states)
        hidden_states = self.embedding_dropout(hidden_states)

        text_seq_length = encoder_hidden_states.shape[1]
        encoder_hidden_states = hidden_states[:, :text_seq_length]
        hidden_states = hidden_states[:, text_seq_length:]

        attention_mask = None
        if controlnet_valid_mask is not None:
            attention_mask = self._build_attention_mask(controlnet_valid_mask, text_seq_length)

        attention_kwargs = {
            'attention_mask': attention_mask,
            'notextinflow': self.notextinflow,
        }

        ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}

        def create_custom_forward(module):
            def custom_forward(*inputs):
                return module(*inputs)

            return custom_forward

        # 3. Transformer blocks
        block_outputs = ()
        for i, block in enumerate(self.transformer_blocks):
            if self.training and self.gradient_checkpointing:
                # Positional order must match the block's forward signature:
                # (hidden_states, encoder_hidden_states, temb, image_rotary_emb, attention_kwargs).
                hidden_states, encoder_hidden_states = torch.utils.checkpoint.checkpoint(
                    create_custom_forward(block),
                    hidden_states,
                    encoder_hidden_states,
                    emb,
                    image_rotary_emb,
                    attention_kwargs,
                    **ckpt_kwargs,
                )
            else:
                hidden_states, encoder_hidden_states = block(
                    hidden_states=hidden_states,
                    encoder_hidden_states=encoder_hidden_states,
                    temb=emb,
                    image_rotary_emb=image_rotary_emb,
                    attention_kwargs=attention_kwargs,
                )

            if self.out_projectors is not None:
                block_outputs += (self.out_projectors[i](hidden_states),)
            else:
                block_outputs += (hidden_states,)

        if not return_dict:
            return (block_outputs,)
        return Transformer2DModelOutput(sample=block_outputs)