# Flov-space/finetune/modules/cogvideox_custom_modules.py
import math
from typing import List, Optional, Tuple, Union, Dict, Any
import copy
import numpy as np
import torch
import torch.nn.functional as F
from torch import nn
from diffusers import CogVideoXTransformer3DModel
from diffusers.models.transformers.cogvideox_transformer_3d import CogVideoXBlock
from diffusers.models.normalization import CogVideoXLayerNormZero
from diffusers.models.attention import FeedForward
from diffusers.models.attention_processor import CogVideoXAttnProcessor2_0, Attention
from diffusers.models.embeddings import CogVideoXPatchEmbed
from diffusers.models.modeling_outputs import Transformer2DModelOutput
from diffusers.utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers
from diffusers.utils.torch_utils import maybe_allow_in_graph
from contextlib import contextmanager
from peft.tuners.lora.layer import LoraLayer  # base class for PEFT's LoRA layers
import pdb
# Code heavily borrowed from https://github.com/huggingface/diffusers
class enable_lora:
def __init__(self, modules, enable=True):
self.modules = modules
self.enable = enable
self.prev_states = {}
def __enter__(self):
for module in self.modules:
self.prev_states[module] = getattr(module, "lora_enabled", True)
setattr(module, "lora_enabled", self.enable)
return self
def __exit__(self, exc_type, exc_val, exc_tb):
for module in self.modules:
setattr(module, "lora_enabled", self.prev_states[module])
return False
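# Usage sketch for `enable_lora` (illustrative only; `transformer` and the module
# selection below are hypothetical and not defined in this file). Any iterable of
# modules that consult a `lora_enabled` attribute in their forward pass can be
# passed; the previous flags are restored on exit even if an exception is raised.
#
#   lora_modules = [m for m in transformer.modules() if isinstance(m, LoraLayer)]
#   with enable_lora(lora_modules, enable=False):
#       output = transformer(hidden_states, encoder_hidden_states, timestep)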
class CustomCogVideoXPatchEmbed(CogVideoXPatchEmbed):
def __init__(self, **kwargs):
super().__init__(**kwargs)
patch_size = kwargs['patch_size']
patch_size_t = kwargs['patch_size_t']
bias = kwargs['bias']
in_channels = kwargs['in_channels']
embed_dim = kwargs['embed_dim']
# projection layer for flow latents
if patch_size_t is None:
# CogVideoX 1.0 checkpoints
self.flow_proj = nn.Conv2d(in_channels//2, embed_dim, kernel_size=(patch_size, patch_size), stride=patch_size, bias=bias)
else:
# CogVideoX 1.5 checkpoints
self.flow_proj = nn.Linear(in_channels//2 * patch_size * patch_size * patch_size_t, embed_dim)
# Add positional embedding for flow_embeds
if self.use_positional_embeddings or self.use_learned_positional_embeddings:
flow_pos_embedding = self._get_positional_embeddings(self.sample_height, self.sample_width, self.sample_frames)[:,self.max_text_seq_length:] # shape: [1, 17550, 3072]
self.flow_pos_embedding = nn.Parameter(flow_pos_embedding)
def forward(self, text_embeds: torch.Tensor, image_embeds: torch.Tensor, flow_embeds: torch.Tensor):
r"""
Args:
text_embeds (`torch.Tensor`):
Input text embeddings. Expected shape: (batch_size, seq_length, embedding_dim).
image_embeds (`torch.Tensor`):
Input image embeddings. Expected shape: (batch_size, num_frames, channels, height, width).
flow_embeds (`torch.Tensor`):
Input flow embeddings. Expected shape: (batch_size, num_frames, channels, height, width).
"""
text_embeds = self.text_proj(text_embeds)
batch_size, num_frames, channels, height, width = image_embeds.shape
if self.patch_size_t is None:
# embed video latents
image_embeds = image_embeds.reshape(-1, channels, height, width)
image_embeds = self.proj(image_embeds)
image_embeds = image_embeds.view(batch_size, num_frames, *image_embeds.shape[1:])
image_embeds = image_embeds.flatten(3).transpose(2, 3) # [batch, num_frames, height x width, channels]
image_embeds = image_embeds.flatten(1, 2) # [batch, num_frames x height x width, channels]
# embed flow latents
flow_embeds = flow_embeds.reshape(-1, channels//2, height, width)
flow_embeds = self.flow_proj(flow_embeds)
flow_embeds = flow_embeds.view(batch_size, num_frames, *flow_embeds.shape[1:])
flow_embeds = flow_embeds.flatten(3).transpose(2, 3) # [batch, num_frames, height x width, channels]
flow_embeds = flow_embeds.flatten(1, 2) # [batch, num_frames x height x width, channels]
else:
p = self.patch_size
p_t = self.patch_size_t
# embed video latents
image_embeds = image_embeds.permute(0, 1, 3, 4, 2)
image_embeds = image_embeds.reshape(
batch_size, num_frames // p_t, p_t, height // p, p, width // p, p, channels
)
image_embeds = image_embeds.permute(0, 1, 3, 5, 7, 2, 4, 6).flatten(4, 7).flatten(1, 3)
image_embeds = self.proj(image_embeds)
# embed flow latents
flow_embeds = flow_embeds.permute(0, 1, 3, 4, 2)
flow_embeds = flow_embeds.reshape(
batch_size, num_frames // p_t, p_t, height // p, p, width // p, p, channels//2
)
flow_embeds = flow_embeds.permute(0, 1, 3, 5, 7, 2, 4, 6).flatten(4, 7).flatten(1, 3)
flow_embeds = self.flow_proj(flow_embeds)
# Curriculum learning of flow token
# flow_embeds = self.flow_scale * flow_embeds
embeds = torch.cat(
[text_embeds, image_embeds, flow_embeds], dim=1
        ).contiguous()  # [batch, seq_length + num_frames x height x width (video) + num_frames x height x width (flow), channels]
if self.use_positional_embeddings or self.use_learned_positional_embeddings:
if self.use_learned_positional_embeddings and (self.sample_width != width or self.sample_height != height):
raise ValueError(
"It is currently not possible to generate videos at a different resolution that the defaults. This should only be the case with 'THUDM/CogVideoX-5b-I2V'."
"If you think this is incorrect, please open an issue at https://github.com/huggingface/diffusers/issues."
)
pre_time_compression_frames = (num_frames - 1) * self.temporal_compression_ratio + 1
if (
self.sample_height != height
or self.sample_width != width
or self.sample_frames != pre_time_compression_frames
):
pos_embedding = self._get_positional_embeddings(
height, width, pre_time_compression_frames, device=embeds.device
)
else:
pos_embedding = self.pos_embedding
# Previous version..
# pos_embedding = pos_embedding.to(dtype=embeds.dtype)
# embeds = embeds + pos_embedding
# Add flow embedding..
# flow_pos_embedding = self.flow_pos_scale * self.flow_pos_embedding
flow_pos_embedding = self.flow_pos_embedding
pos_embedding_total = torch.cat([pos_embedding, flow_pos_embedding], dim=1).to(dtype=embeds.dtype)
embeds = embeds + pos_embedding_total
return embeds
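# Shape sketch for CustomCogVideoXPatchEmbed.forward (hypothetical values chosen to
# match the [1, 17550, 3072] comment above: embed_dim=3072, patch_size=2,
# patch_size_t=None, max_text_seq_length=226, a 49-frame video compressed to 13
# latent frames at 60x90 latent resolution):
#
#   text_embeds:  [B, 226, 4096]      -> text_proj -> [B, 226, 3072]
#   image_embeds: [B, 13, 16, 60, 90] -> proj      -> [B, 13 * 30 * 45, 3072] = [B, 17550, 3072]
#   flow_embeds:  [B, 13,  8, 60, 90] -> flow_proj -> [B, 17550, 3072]
#   returned sequence: [B, 226 + 17550 + 17550, 3072] (text tokens, then video, then flow)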
@maybe_allow_in_graph
class CustomCogVideoXBlock(nn.Module):
r"""
    Transformer block used in the [CogVideoX](https://github.com/THUDM/CogVideo) model.
Parameters:
dim (`int`):
The number of channels in the input and output.
num_attention_heads (`int`):
The number of heads to use for multi-head attention.
attention_head_dim (`int`):
The number of channels in each head.
time_embed_dim (`int`):
The number of channels in timestep embedding.
dropout (`float`, defaults to `0.0`):
The dropout probability to use.
activation_fn (`str`, defaults to `"gelu-approximate"`):
Activation function to be used in feed-forward.
attention_bias (`bool`, defaults to `False`):
Whether or not to use bias in attention projection layers.
qk_norm (`bool`, defaults to `True`):
Whether or not to use normalization after query and key projections in Attention.
norm_elementwise_affine (`bool`, defaults to `True`):
Whether to use learnable elementwise affine parameters for normalization.
norm_eps (`float`, defaults to `1e-5`):
Epsilon value for normalization layers.
        final_dropout (`bool`, defaults to `True`):
Whether to apply a final dropout after the last feed-forward layer.
ff_inner_dim (`int`, *optional*, defaults to `None`):
Custom hidden dimension of Feed-forward layer. If not provided, `4 * dim` is used.
ff_bias (`bool`, defaults to `True`):
Whether or not to use bias in Feed-forward layer.
attention_out_bias (`bool`, defaults to `True`):
Whether or not to use bias in Attention output projection layer.
"""
def __init__(
self,
dim: int,
num_attention_heads: int,
attention_head_dim: int,
time_embed_dim: int,
dropout: float = 0.0,
activation_fn: str = "gelu-approximate",
attention_bias: bool = False,
qk_norm: bool = True,
norm_elementwise_affine: bool = True,
norm_eps: float = 1e-5,
final_dropout: bool = True,
ff_inner_dim: Optional[int] = None,
ff_bias: bool = True,
attention_out_bias: bool = True,
):
super().__init__()
# 1. Self Attention
self.norm1 = CogVideoXLayerNormZero(time_embed_dim, dim, norm_elementwise_affine, norm_eps, bias=True)
self.attn1 = Attention(
query_dim=dim,
dim_head=attention_head_dim,
heads=num_attention_heads,
qk_norm="layer_norm" if qk_norm else None,
eps=1e-6,
bias=attention_bias,
out_bias=attention_out_bias,
processor=CustomCogVideoXAttnProcessor2_0(),
)
# 2. Feed Forward
self.norm2 = CogVideoXLayerNormZero(time_embed_dim, dim, norm_elementwise_affine, norm_eps, bias=True)
self.ff = FeedForward(
dim,
dropout=dropout,
activation_fn=activation_fn,
final_dropout=final_dropout,
inner_dim=ff_inner_dim,
bias=ff_bias,
)
def forward(
self,
hidden_states: torch.Tensor,
encoder_hidden_states: torch.Tensor,
temb: torch.Tensor,
image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
attention_kwargs: Optional[Dict[str, Any]] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
text_seq_length = encoder_hidden_states.size(1)
attention_kwargs = attention_kwargs or {}
# norm & modulate
norm_hidden_states, norm_encoder_hidden_states, gate_msa, enc_gate_msa = self.norm1(
hidden_states, encoder_hidden_states, temb
)
# attention
attn_hidden_states, attn_encoder_hidden_states = self.attn1(
hidden_states=norm_hidden_states,
encoder_hidden_states=norm_encoder_hidden_states,
image_rotary_emb=image_rotary_emb,
**attention_kwargs,
)
hidden_states = hidden_states + gate_msa * attn_hidden_states
encoder_hidden_states = encoder_hidden_states + enc_gate_msa * attn_encoder_hidden_states
# norm & modulate
norm_hidden_states, norm_encoder_hidden_states, gate_ff, enc_gate_ff = self.norm2(
hidden_states, encoder_hidden_states, temb
)
# feed-forward
norm_hidden_states = torch.cat([norm_encoder_hidden_states, norm_hidden_states], dim=1)
ff_output = self.ff(norm_hidden_states)
hidden_states = hidden_states + gate_ff * ff_output[:, text_seq_length:]
encoder_hidden_states = encoder_hidden_states + enc_gate_ff * ff_output[:, :text_seq_length]
return hidden_states, encoder_hidden_states
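# Usage sketch for CustomCogVideoXBlock (hypothetical sizes mirroring a
# CogVideoX-5b-style configuration; the tensors below are assumed to already carry
# the concatenated video + flow tokens produced by CustomCogVideoXPatchEmbed):
#
#   block = CustomCogVideoXBlock(
#       dim=3072, num_attention_heads=48, attention_head_dim=64, time_embed_dim=512
#   )
#   hidden_states, encoder_hidden_states = block(
#       hidden_states,          # [B, video_tokens + flow_tokens, 3072]
#       encoder_hidden_states,  # [B, text_seq_length, 3072]
#       temb,                   # [B, 512]
#       image_rotary_emb=image_rotary_emb,
#       attention_kwargs={"notextinflow": False},
#   )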
class CustomCogVideoXAttnProcessor2_0:
r"""
Processor for implementing scaled dot-product attention for the CogVideoX model. It applies a rotary embedding on
query and key vectors, but does not include spatial normalization.
"""
def __init__(self):
if not hasattr(F, "scaled_dot_product_attention"):
            raise ImportError("CogVideoXAttnProcessor requires PyTorch 2.0. To use it, please upgrade PyTorch to 2.0.")
def __call__(
self,
attn: Attention,
hidden_states: torch.Tensor,
encoder_hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
image_rotary_emb: Optional[torch.Tensor] = None,
        notextinflow: bool = False,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
text_seq_length = encoder_hidden_states.size(1)
if not notextinflow:
hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)
batch_size, sequence_length, _ = hidden_states.shape
if attention_mask is not None:
attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])
query = attn.to_q(hidden_states)
key = attn.to_k(hidden_states)
value = attn.to_v(hidden_states)
inner_dim = key.shape[-1]
head_dim = inner_dim // attn.heads
query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
if attn.norm_q is not None:
query = attn.norm_q(query)
if attn.norm_k is not None:
key = attn.norm_k(key)
# Apply RoPE if needed
if image_rotary_emb is not None:
from diffusers.models.embeddings import apply_rotary_emb
if not notextinflow:
query[:, :, text_seq_length:] = apply_rotary_emb(query[:, :, text_seq_length:], image_rotary_emb)
if not attn.is_cross_attention:
key[:, :, text_seq_length:] = apply_rotary_emb(key[:, :, text_seq_length:], image_rotary_emb)
else:
query[:, :, :] = apply_rotary_emb(query[:, :, :], image_rotary_emb)
if not attn.is_cross_attention:
key[:, :, :] = apply_rotary_emb(key[:, :, :], image_rotary_emb)
hidden_states = F.scaled_dot_product_attention(
query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
)
hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
# linear proj
hidden_states = attn.to_out[0](hidden_states)
# dropout
hidden_states = attn.to_out[1](hidden_states)
if not notextinflow:
encoder_hidden_states, hidden_states = hidden_states.split(
[text_seq_length, hidden_states.size(1) - text_seq_length], dim=1
)
return hidden_states, encoder_hidden_states
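# Note on the `notextinflow` flag (a summary of the two code paths above, not new
# behavior): when False (the default), the text tokens are prepended to the
# visual/flow tokens before joint attention and RoPE is applied only to the
# non-text portion of query/key, matching the stock CogVideoXAttnProcessor2_0;
# when True, `hidden_states` is assumed to contain no text tokens, RoPE is applied
# to the full sequence, and the incoming `encoder_hidden_states` is returned
# unchanged.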