import math from typing import List, Optional, Tuple, Union, Dict, Any import copy import numpy as np import torch import torch.nn.functional as F from torch import nn from diffusers import CogVideoXTransformer3DModel from diffusers.models.transformers.cogvideox_transformer_3d import CogVideoXBlock from diffusers.models.normalization import CogVideoXLayerNormZero from diffusers.models.attention import FeedForward from diffusers.models.attention_processor import CogVideoXAttnProcessor2_0, Attention from diffusers.models.embeddings import CogVideoXPatchEmbed from diffusers.models.modeling_outputs import Transformer2DModelOutput from diffusers.utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers from diffusers.utils.torch_utils import maybe_allow_in_graph from contextlib import contextmanager from peft.tuners.lora.layer import LoraLayer # PEFT의 LoRA 레이어 기본 클래스 import pdb # Code heavily borrowed from https://github.com/huggingface/diffusers class enable_lora: def __init__(self, modules, enable=True): self.modules = modules self.enable = enable self.prev_states = {} def __enter__(self): for module in self.modules: self.prev_states[module] = getattr(module, "lora_enabled", True) setattr(module, "lora_enabled", self.enable) return self def __exit__(self, exc_type, exc_val, exc_tb): for module in self.modules: setattr(module, "lora_enabled", self.prev_states[module]) return False class CustomCogVideoXPatchEmbed(CogVideoXPatchEmbed): def __init__(self, **kwargs): super().__init__(**kwargs) patch_size = kwargs['patch_size'] patch_size_t = kwargs['patch_size_t'] bias = kwargs['bias'] in_channels = kwargs['in_channels'] embed_dim = kwargs['embed_dim'] # projection layer for flow latents if patch_size_t is None: # CogVideoX 1.0 checkpoints self.flow_proj = nn.Conv2d(in_channels//2, embed_dim, kernel_size=(patch_size, patch_size), stride=patch_size, bias=bias) else: # CogVideoX 1.5 checkpoints self.flow_proj = nn.Linear(in_channels//2 * patch_size * patch_size * patch_size_t, embed_dim) # Add positional embedding for flow_embeds if self.use_positional_embeddings or self.use_learned_positional_embeddings: flow_pos_embedding = self._get_positional_embeddings(self.sample_height, self.sample_width, self.sample_frames)[:,self.max_text_seq_length:] # shape: [1, 17550, 3072] self.flow_pos_embedding = nn.Parameter(flow_pos_embedding) def forward(self, text_embeds: torch.Tensor, image_embeds: torch.Tensor, flow_embeds: torch.Tensor): r""" Args: text_embeds (`torch.Tensor`): Input text embeddings. Expected shape: (batch_size, seq_length, embedding_dim). image_embeds (`torch.Tensor`): Input image embeddings. Expected shape: (batch_size, num_frames, channels, height, width). flow_embeds (`torch.Tensor`): Input flow embeddings. Expected shape: (batch_size, num_frames, channels, height, width). """ text_embeds = self.text_proj(text_embeds) batch_size, num_frames, channels, height, width = image_embeds.shape if self.patch_size_t is None: # embed video latents image_embeds = image_embeds.reshape(-1, channels, height, width) image_embeds = self.proj(image_embeds) image_embeds = image_embeds.view(batch_size, num_frames, *image_embeds.shape[1:]) image_embeds = image_embeds.flatten(3).transpose(2, 3) # [batch, num_frames, height x width, channels] image_embeds = image_embeds.flatten(1, 2) # [batch, num_frames x height x width, channels] # embed flow latents flow_embeds = flow_embeds.reshape(-1, channels//2, height, width) flow_embeds = self.flow_proj(flow_embeds) flow_embeds = flow_embeds.view(batch_size, num_frames, *flow_embeds.shape[1:]) flow_embeds = flow_embeds.flatten(3).transpose(2, 3) # [batch, num_frames, height x width, channels] flow_embeds = flow_embeds.flatten(1, 2) # [batch, num_frames x height x width, channels] else: p = self.patch_size p_t = self.patch_size_t # embed video latents image_embeds = image_embeds.permute(0, 1, 3, 4, 2) image_embeds = image_embeds.reshape( batch_size, num_frames // p_t, p_t, height // p, p, width // p, p, channels ) image_embeds = image_embeds.permute(0, 1, 3, 5, 7, 2, 4, 6).flatten(4, 7).flatten(1, 3) image_embeds = self.proj(image_embeds) # embed flow latents flow_embeds = flow_embeds.permute(0, 1, 3, 4, 2) flow_embeds = flow_embeds.reshape( batch_size, num_frames // p_t, p_t, height // p, p, width // p, p, channels//2 ) flow_embeds = flow_embeds.permute(0, 1, 3, 5, 7, 2, 4, 6).flatten(4, 7).flatten(1, 3) flow_embeds = self.flow_proj(flow_embeds) # Curriculum learning of flow token # flow_embeds = self.flow_scale * flow_embeds embeds = torch.cat( [text_embeds, image_embeds, flow_embeds], dim=1 ).contiguous() # [batch, num_frames x height x width + seq_length + num_frames x height x width, channels] if self.use_positional_embeddings or self.use_learned_positional_embeddings: if self.use_learned_positional_embeddings and (self.sample_width != width or self.sample_height != height): raise ValueError( "It is currently not possible to generate videos at a different resolution that the defaults. This should only be the case with 'THUDM/CogVideoX-5b-I2V'." "If you think this is incorrect, please open an issue at https://github.com/huggingface/diffusers/issues." ) pre_time_compression_frames = (num_frames - 1) * self.temporal_compression_ratio + 1 if ( self.sample_height != height or self.sample_width != width or self.sample_frames != pre_time_compression_frames ): pos_embedding = self._get_positional_embeddings( height, width, pre_time_compression_frames, device=embeds.device ) else: pos_embedding = self.pos_embedding # Previous version.. # pos_embedding = pos_embedding.to(dtype=embeds.dtype) # embeds = embeds + pos_embedding # Add flow embedding.. # flow_pos_embedding = self.flow_pos_scale * self.flow_pos_embedding flow_pos_embedding = self.flow_pos_embedding pos_embedding_total = torch.cat([pos_embedding, flow_pos_embedding], dim=1).to(dtype=embeds.dtype) embeds = embeds + pos_embedding_total return embeds @maybe_allow_in_graph class CustomCogVideoXBlock(nn.Module): r""" Transformer block used in [CogVideoX](https://github.com/THUDM/CogVideo) model. Parameters: dim (`int`): The number of channels in the input and output. num_attention_heads (`int`): The number of heads to use for multi-head attention. attention_head_dim (`int`): The number of channels in each head. time_embed_dim (`int`): The number of channels in timestep embedding. dropout (`float`, defaults to `0.0`): The dropout probability to use. activation_fn (`str`, defaults to `"gelu-approximate"`): Activation function to be used in feed-forward. attention_bias (`bool`, defaults to `False`): Whether or not to use bias in attention projection layers. qk_norm (`bool`, defaults to `True`): Whether or not to use normalization after query and key projections in Attention. norm_elementwise_affine (`bool`, defaults to `True`): Whether to use learnable elementwise affine parameters for normalization. norm_eps (`float`, defaults to `1e-5`): Epsilon value for normalization layers. final_dropout (`bool` defaults to `False`): Whether to apply a final dropout after the last feed-forward layer. ff_inner_dim (`int`, *optional*, defaults to `None`): Custom hidden dimension of Feed-forward layer. If not provided, `4 * dim` is used. ff_bias (`bool`, defaults to `True`): Whether or not to use bias in Feed-forward layer. attention_out_bias (`bool`, defaults to `True`): Whether or not to use bias in Attention output projection layer. """ def __init__( self, dim: int, num_attention_heads: int, attention_head_dim: int, time_embed_dim: int, dropout: float = 0.0, activation_fn: str = "gelu-approximate", attention_bias: bool = False, qk_norm: bool = True, norm_elementwise_affine: bool = True, norm_eps: float = 1e-5, final_dropout: bool = True, ff_inner_dim: Optional[int] = None, ff_bias: bool = True, attention_out_bias: bool = True, ): super().__init__() # 1. Self Attention self.norm1 = CogVideoXLayerNormZero(time_embed_dim, dim, norm_elementwise_affine, norm_eps, bias=True) self.attn1 = Attention( query_dim=dim, dim_head=attention_head_dim, heads=num_attention_heads, qk_norm="layer_norm" if qk_norm else None, eps=1e-6, bias=attention_bias, out_bias=attention_out_bias, processor=CustomCogVideoXAttnProcessor2_0(), ) # 2. Feed Forward self.norm2 = CogVideoXLayerNormZero(time_embed_dim, dim, norm_elementwise_affine, norm_eps, bias=True) self.ff = FeedForward( dim, dropout=dropout, activation_fn=activation_fn, final_dropout=final_dropout, inner_dim=ff_inner_dim, bias=ff_bias, ) def forward( self, hidden_states: torch.Tensor, encoder_hidden_states: torch.Tensor, temb: torch.Tensor, image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, attention_kwargs: Optional[Dict[str, Any]] = None, ) -> torch.Tensor: text_seq_length = encoder_hidden_states.size(1) attention_kwargs = attention_kwargs or {} # norm & modulate norm_hidden_states, norm_encoder_hidden_states, gate_msa, enc_gate_msa = self.norm1( hidden_states, encoder_hidden_states, temb ) # attention attn_hidden_states, attn_encoder_hidden_states = self.attn1( hidden_states=norm_hidden_states, encoder_hidden_states=norm_encoder_hidden_states, image_rotary_emb=image_rotary_emb, **attention_kwargs, ) hidden_states = hidden_states + gate_msa * attn_hidden_states encoder_hidden_states = encoder_hidden_states + enc_gate_msa * attn_encoder_hidden_states # norm & modulate norm_hidden_states, norm_encoder_hidden_states, gate_ff, enc_gate_ff = self.norm2( hidden_states, encoder_hidden_states, temb ) # feed-forward norm_hidden_states = torch.cat([norm_encoder_hidden_states, norm_hidden_states], dim=1) ff_output = self.ff(norm_hidden_states) hidden_states = hidden_states + gate_ff * ff_output[:, text_seq_length:] encoder_hidden_states = encoder_hidden_states + enc_gate_ff * ff_output[:, :text_seq_length] return hidden_states, encoder_hidden_states class CustomCogVideoXAttnProcessor2_0: r""" Processor for implementing scaled dot-product attention for the CogVideoX model. It applies a rotary embedding on query and key vectors, but does not include spatial normalization. """ def __init__(self): if not hasattr(F, "scaled_dot_product_attention"): raise ImportError("CogVideoXAttnProcessor requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.") def __call__( self, attn: Attention, hidden_states: torch.Tensor, encoder_hidden_states: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, image_rotary_emb: Optional[torch.Tensor] = None, notextinflow: Optional[bool] = False, ) -> torch.Tensor: text_seq_length = encoder_hidden_states.size(1) if not notextinflow: hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1) batch_size, sequence_length, _ = hidden_states.shape if attention_mask is not None: attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size) attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1]) query = attn.to_q(hidden_states) key = attn.to_k(hidden_states) value = attn.to_v(hidden_states) inner_dim = key.shape[-1] head_dim = inner_dim // attn.heads query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) if attn.norm_q is not None: query = attn.norm_q(query) if attn.norm_k is not None: key = attn.norm_k(key) # Apply RoPE if needed if image_rotary_emb is not None: from diffusers.models.embeddings import apply_rotary_emb if not notextinflow: query[:, :, text_seq_length:] = apply_rotary_emb(query[:, :, text_seq_length:], image_rotary_emb) if not attn.is_cross_attention: key[:, :, text_seq_length:] = apply_rotary_emb(key[:, :, text_seq_length:], image_rotary_emb) else: query[:, :, :] = apply_rotary_emb(query[:, :, :], image_rotary_emb) if not attn.is_cross_attention: key[:, :, :] = apply_rotary_emb(key[:, :, :], image_rotary_emb) hidden_states = F.scaled_dot_product_attention( query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False ) hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim) # linear proj hidden_states = attn.to_out[0](hidden_states) # dropout hidden_states = attn.to_out[1](hidden_states) if not notextinflow: encoder_hidden_states, hidden_states = hidden_states.split( [text_seq_length, hidden_states.size(1) - text_seq_length], dim=1 ) return hidden_states, encoder_hidden_states