Spaces:
Paused
Paused
from typing import Any, Dict, Optional, Tuple, Union | |
import torch | |
import numpy as np | |
from diffusers.utils import is_torch_version | |
from diffusers.models.transformers.cogvideox_transformer_3d import CogVideoXTransformer3DModel, Transformer2DModelOutput | |
import pdb | |
class CustomCogVideoXTransformer3DModel(CogVideoXTransformer3DModel):
    """CogVideoX 3D transformer whose forward pass accepts ControlNet residuals.

    Compared with the parent ``forward``, this version can additionally:

    * concatenate an optional ``start_frame`` tensor to the input along
      dimension 2 before patch embedding, and
    * add a per-block residual (``controlnet_states[i]``), scaled by
      ``controlnet_weights``, to the hidden states after transformer block ``i``.
    """

    @staticmethod
    def _controlnet_weight_for_block(controlnet_weights, block_index):
        """Return the factor that scales the ControlNet residual of one block.

        Args:
            controlnet_weights: A scalar (``float``, ``int``, NumPy scalar, or
                0-dim tensor/array) applied to every block, or an indexable
                container (``list``, ``np.ndarray``, tensor) holding one weight
                per block.
            block_index: Index of the transformer block being processed.

        Returns:
            The weight for ``block_index``; ``1.0`` when ``controlnet_weights``
            is ``None`` or of an unsupported type.
        """
        if torch.is_tensor(controlnet_weights) or isinstance(controlnet_weights, np.ndarray):
            # A 0-dim tensor/array cannot be indexed; treat it as a single
            # global weight instead of failing with an IndexError.
            if controlnet_weights.ndim == 0:
                return controlnet_weights
            return controlnet_weights[block_index]
        if isinstance(controlnet_weights, list):
            return controlnet_weights[block_index]
        # ``np.generic`` covers NumPy scalars such as ``np.float32`` that are
        # not subclasses of the builtin ``float``/``int``.
        if isinstance(controlnet_weights, (float, int, np.generic)):
            return controlnet_weights
        return 1.0

    def forward(
        self,
        hidden_states: torch.Tensor,
        encoder_hidden_states: torch.Tensor,
        timestep: Union[int, float, torch.LongTensor],
        start_frame=None,
        timestep_cond: Optional[torch.Tensor] = None,
        image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
        controlnet_states: Optional[torch.Tensor] = None,
        controlnet_weights: Optional[Union[float, int, list, np.ndarray, torch.FloatTensor]] = 1.0,
        return_dict: bool = True,
    ):
        """Run the transformer, optionally injecting ControlNet residuals.

        Args:
            hidden_states: 5-D input unpacked as
                ``(batch_size, num_frames, channels, height, width)``.
            encoder_hidden_states: Conditioning sequence; its dimension 1 is
                used as the text sequence length.
            timestep: Value passed to ``self.time_proj``.
            start_frame: Optional tensor concatenated in front of
                ``hidden_states`` along dimension 2.
            timestep_cond: Optional extra input to ``self.time_embedding``.
            image_rotary_emb: Optional rotary embedding forwarded to every block.
            controlnet_states: Optional indexable collection; entry ``i`` is
                added to the hidden states after block ``i``. Blocks with an
                index beyond its length receive no residual.
            controlnet_weights: Global scalar weight or per-block weights for
                the residuals (see ``_controlnet_weight_for_block``).
            return_dict: When ``False`` return a 1-tuple instead of a
                ``Transformer2DModelOutput``.

        Returns:
            ``Transformer2DModelOutput(sample=output)`` or ``(output,)``.
        """
        # The shape is read before the optional start-frame concatenation, so
        # the unpatchify step below is based on the caller's original input.
        batch_size, num_frames, _, height, width = hidden_states.shape

        if start_frame is not None:
            hidden_states = torch.cat([start_frame, hidden_states], dim=2)

        # 1. Time embedding
        t_emb = self.time_proj(timestep)
        # time_proj holds no weights and always returns f32 tensors, but
        # time_embedding might be running in fp16, so cast here.
        t_emb = t_emb.to(dtype=hidden_states.dtype)
        emb = self.time_embedding(t_emb, timestep_cond)

        # 2. Patch embedding
        hidden_states = self.patch_embed(encoder_hidden_states, hidden_states)
        hidden_states = self.embedding_dropout(hidden_states)

        # The embedded sequence is [text tokens, video tokens]; split it again.
        text_seq_length = encoder_hidden_states.shape[1]
        encoder_hidden_states = hidden_states[:, :text_seq_length]
        hidden_states = hidden_states[:, text_seq_length:]

        # Loop-invariant setup, computed once instead of once per block.
        use_checkpointing = self.training and self.gradient_checkpointing
        ckpt_kwargs: Dict[str, Any] = {}
        if use_checkpointing and is_torch_version(">=", "1.11.0"):
            ckpt_kwargs = {"use_reentrant": False}
        num_controlnet_states = 0 if controlnet_states is None else len(controlnet_states)

        # 3. Transformer blocks
        for i, block in enumerate(self.transformer_blocks):
            if use_checkpointing:
                hidden_states, encoder_hidden_states = torch.utils.checkpoint.checkpoint(
                    block,
                    hidden_states,
                    encoder_hidden_states,
                    emb,
                    image_rotary_emb,
                    **ckpt_kwargs,
                )
            else:
                hidden_states, encoder_hidden_states = block(
                    hidden_states=hidden_states,
                    encoder_hidden_states=encoder_hidden_states,
                    temb=emb,
                    image_rotary_emb=image_rotary_emb,
                )

            # Inject the ControlNet residual for this block, if one was given.
            if i < num_controlnet_states:
                block_weight = self._controlnet_weight_for_block(controlnet_weights, i)
                hidden_states = hidden_states + controlnet_states[i] * block_weight

        if not self.config.use_rotary_positional_embeddings:
            # CogVideoX-2B
            hidden_states = self.norm_final(hidden_states)
        else:
            # CogVideoX-5B: normalise text and video tokens together, then
            # drop the text tokens again.
            hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)
            hidden_states = self.norm_final(hidden_states)
            hidden_states = hidden_states[:, text_seq_length:]

        # 4. Final block
        hidden_states = self.norm_out(hidden_states, temb=emb)
        hidden_states = self.proj_out(hidden_states)

        # 5. Unpatchify
        p = self.config.patch_size
        p_t = self.config.patch_size_t

        if p_t is None:
            # Spatial patches only: one token row per frame.
            output = hidden_states.reshape(batch_size, num_frames, height // p, width // p, -1, p, p)
            output = output.permute(0, 1, 4, 2, 5, 3, 6).flatten(5, 6).flatten(3, 4)
        else:
            # Temporal patches as well: frames are grouped in chunks of p_t
            # (ceil division), then expanded back along the frame axis.
            output = hidden_states.reshape(
                batch_size, (num_frames + p_t - 1) // p_t, height // p, width // p, -1, p_t, p, p
            )
            output = output.permute(0, 1, 5, 4, 2, 6, 3, 7).flatten(6, 7).flatten(4, 5).flatten(1, 2)

        if not return_dict:
            return (output,)
        return Transformer2DModelOutput(sample=output)