Spaces:
				
			
			
	
			
			
		Running
		
			on 
			
			Zero
	
	
	
			
			
	
	
	
	
		
		
		Running
		
			on 
			
			Zero
	| # original code from: https://github.com/nvidia-cosmos/cosmos-predict2 | |
| import torch | |
| from torch import nn | |
| from einops import rearrange | |
| from einops.layers.torch import Rearrange | |
| import logging | |
| from typing import Callable, Optional, Tuple | |
| import math | |
| from .position_embedding import VideoRopePosition3DEmb, LearnablePosEmbAxis | |
| from torchvision import transforms | |
| import comfy.patcher_extension | |
| from comfy.ldm.modules.attention import optimized_attention | |
def apply_rotary_pos_emb(
    t: torch.Tensor,
    freqs: torch.Tensor,
) -> torch.Tensor:
    """Rotate pairs of features of ``t`` with the 2x2 matrices held in ``freqs``.

    The last axis of ``t`` (size D) is split into two halves. Element ``i`` of the first
    half and element ``i`` of the second half form one 2-vector, which is multiplied by
    the 2x2 matrix ``freqs[..., i, :, :]``. The arithmetic runs in float32 and the result
    is cast back to ``t``'s dtype.

    Args:
        t: Tensor whose last dimension has even size D.
        freqs: Rotation matrices that broadcast as ``(..., D // 2, 2, 2)`` against ``t``.
            The exact leading layout comes from the RoPE embedder (TODO confirm).

    Returns:
        Tensor with the same shape and dtype as ``t``.
    """
    # (..., D) -> (..., 2, D/2) -> (..., D/2, 2): pair first-half with second-half entries.
    halves = t.unflatten(-1, (2, -1)).transpose(-2, -1)
    # Insert a broadcast axis so each pair meets both rows of its matrix; compute in fp32.
    pairs = halves[..., None, :].to(torch.float32)
    rotated = freqs[..., 0] * pairs[..., 0] + freqs[..., 1] * pairs[..., 1]
    # (..., D/2, 2) -> (..., 2, D/2) -> original layout and dtype.
    return rotated.transpose(-1, -2).reshape(t.shape).type_as(t)
| # ---------------------- Feed Forward Network ----------------------- | |
class GPT2FeedForward(nn.Module):
    """Two-layer, bias-free MLP computing ``layer2(GELU(layer1(x)))``.

    Args:
        d_model: Input and output feature width.
        d_ff: Hidden feature width.
        device, dtype: Forwarded to the ``operations.Linear`` constructors.
        operations: Namespace that provides a ``Linear`` layer class.
    """

    def __init__(self, d_model: int, d_ff: int, device=None, dtype=None, operations=None) -> None:
        super().__init__()
        self.activation = nn.GELU()
        linear_kwargs = dict(bias=False, device=device, dtype=dtype)
        self.layer1 = operations.Linear(d_model, d_ff, **linear_kwargs)
        self.layer2 = operations.Linear(d_ff, d_model, **linear_kwargs)
        # Bookkeeping attributes; they are not read anywhere inside this class.
        self._layer_id = None
        self._dim = d_model
        self._hidden_dim = d_ff

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the MLP to the last dimension of ``x``."""
        hidden = self.activation(self.layer1(x))
        return self.layer2(hidden)
def torch_attention_op(q_B_S_H_D: torch.Tensor, k_B_S_H_D: torch.Tensor, v_B_S_H_D: torch.Tensor, transformer_options: Optional[dict] = None) -> torch.Tensor:
    """Compute multi-head attention through ComfyUI's ``optimized_attention``.

    Inputs are laid out batch-first with heads and head dim last, i.e. ``(B, ..., H, D)``.
    Any sequence axes in between are flattened into one. The tensors are moved to
    ``(B, H, S, D)`` and handed to ``optimized_attention`` with ``skip_reshape=True``.

    The input tensor names use the following dimension conventions:
    - B: batch size
    - S: sequence length (one or more axes, flattened)
    - H: number of attention heads
    - D: head dimension

    Args:
        q_B_S_H_D: Query tensor with shape (batch, seq_len_q, n_heads, head_dim).
        k_B_S_H_D: Key tensor with shape (batch, seq_len_kv, n_heads, head_dim).
        v_B_S_H_D: Value tensor, reshaped using the key tensor's shape.
        transformer_options: Options forwarded to ``optimized_attention``. ``None`` (the
            default) means an empty dict. A fresh dict is made on every call instead of
            sharing one mutable default object between calls.

    Returns:
        The output of ``optimized_attention``. ``Attention.compute_attention`` feeds it to a
        Linear of width ``n_heads * head_dim``, so it is expected to be
        (batch, seq_len_q, n_heads * head_dim).
    """
    if transformer_options is None:
        transformer_options = {}
    in_q_shape = q_B_S_H_D.shape
    in_k_shape = k_B_S_H_D.shape
    # Move heads next to batch and flatten all sequence axes: (B, ..., H, D) -> (B, H, S, D).
    q_B_H_S_D = rearrange(q_B_S_H_D, "b ... h k -> b h ... k").view(in_q_shape[0], in_q_shape[-2], -1, in_q_shape[-1])
    k_B_H_S_D = rearrange(k_B_S_H_D, "b ... h v -> b h ... v").view(in_k_shape[0], in_k_shape[-2], -1, in_k_shape[-1])
    v_B_H_S_D = rearrange(v_B_S_H_D, "b ... h v -> b h ... v").view(in_k_shape[0], in_k_shape[-2], -1, in_k_shape[-1])
    return optimized_attention(q_B_H_S_D, k_B_H_S_D, v_B_H_S_D, in_q_shape[-2], skip_reshape=True, transformer_options=transformer_options)
class Attention(nn.Module):
    """
    Multi-head attention layer for either self-attention or cross-attention.

    The mode is fixed when the layer is built:
    - When ``context_dim`` is None, the layer runs self-attention, with keys and values
      projected from the query input.
    - Otherwise, keys and values are projected from a separate ``context`` tensor of width
      ``context_dim``.

    Queries and keys are RMS-normalised per head. Rotary position embeddings are applied to
    queries and keys in self-attention mode only. Every projection is bias-free.

    Args:
        query_dim (int): Width of the query input and of the output.
        context_dim (int, optional): Width of the key/value input. None selects
            self-attention. Default: None
        n_heads (int, optional): Number of attention heads. Default: 8
        head_dim (int, optional): Width of each attention head. Default: 64
        dropout (float, optional): Dropout probability applied to the output. Values
            <= 1e-4 disable dropout entirely. Default: 0.0
        device, dtype: Forwarded to the layers created from ``operations``.
        operations: Namespace providing ``Linear`` and ``RMSNorm`` layer classes.

    Example (``ops`` is any namespace with ``Linear``/``RMSNorm``)::

        self_attn = Attention(query_dim=512, operations=ops)
        x = torch.randn(32, 16, 512)          # (batch_size, seq_len, dim)
        out = self_attn(x)                    # (32, 16, 512)

        cross_attn = Attention(query_dim=512, context_dim=256, operations=ops)
        query = torch.randn(32, 16, 512)
        context = torch.randn(32, 8, 256)
        out = cross_attn(query, context)      # (32, 16, 512)
    """

    def __init__(
        self,
        query_dim: int,
        context_dim: Optional[int] = None,
        n_heads: int = 8,
        head_dim: int = 64,
        dropout: float = 0.0,
        device=None,
        dtype=None,
        operations=None,
    ) -> None:
        super().__init__()
        logging.debug(
            f"Setting up {self.__class__.__name__}. Query dim is {query_dim}, context_dim is {context_dim} and using "
            f"{n_heads} heads with a dimension of {head_dim}."
        )
        self.is_selfattn = context_dim is None  # self attention
        context_dim = query_dim if context_dim is None else context_dim
        inner_dim = head_dim * n_heads
        self.n_heads = n_heads
        self.head_dim = head_dim
        self.query_dim = query_dim
        self.context_dim = context_dim
        self.q_proj = operations.Linear(query_dim, inner_dim, bias=False, device=device, dtype=dtype)
        # Norms act on each head separately (width head_dim), after the split into heads.
        self.q_norm = operations.RMSNorm(self.head_dim, eps=1e-6, device=device, dtype=dtype)
        self.k_proj = operations.Linear(context_dim, inner_dim, bias=False, device=device, dtype=dtype)
        self.k_norm = operations.RMSNorm(self.head_dim, eps=1e-6, device=device, dtype=dtype)
        self.v_proj = operations.Linear(context_dim, inner_dim, bias=False, device=device, dtype=dtype)
        self.v_norm = nn.Identity()
        self.output_proj = operations.Linear(inner_dim, query_dim, bias=False, device=device, dtype=dtype)
        self.output_dropout = nn.Dropout(dropout) if dropout > 1e-4 else nn.Identity()
        self.attn_op = torch_attention_op
        self._query_dim = query_dim
        self._context_dim = context_dim
        self._inner_dim = inner_dim

    def compute_qkv(
        self,
        x: torch.Tensor,
        context: Optional[torch.Tensor] = None,
        rope_emb: Optional[torch.Tensor] = None,
    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Project to q/k/v, split into heads, normalise, and (self-attention only) apply RoPE.

        Returns:
            q, k, v, each shaped ``(B, ..., n_heads, head_dim)``.
        """
        q = self.q_proj(x)
        context = x if context is None else context
        k = self.k_proj(context)
        v = self.v_proj(context)
        q, k, v = map(
            lambda t: rearrange(t, "b ... (h d) -> b ... h d", h=self.n_heads, d=self.head_dim),
            (q, k, v),
        )

        def apply_norm_and_rotary_pos_emb(
            q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, rope_emb: Optional[torch.Tensor]
        ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
            q = self.q_norm(q)
            k = self.k_norm(k)
            v = self.v_norm(v)
            if self.is_selfattn and rope_emb is not None:  # only apply to self-attention!
                q = apply_rotary_pos_emb(q, rope_emb)
                k = apply_rotary_pos_emb(k, rope_emb)
            return q, k, v

        q, k, v = apply_norm_and_rotary_pos_emb(q, k, v, rope_emb)
        return q, k, v

    def compute_attention(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, transformer_options: Optional[dict] = None) -> torch.Tensor:
        """Attend with ``self.attn_op``, then apply the output projection and dropout.

        ``transformer_options=None`` means an empty dict. A fresh one is created per call
        rather than sharing a mutable default between calls.
        """
        if transformer_options is None:
            transformer_options = {}
        result = self.attn_op(q, k, v, transformer_options=transformer_options)  # [B, S, H * D]
        return self.output_dropout(self.output_proj(result))

    def forward(
        self,
        x: torch.Tensor,
        context: Optional[torch.Tensor] = None,
        rope_emb: Optional[torch.Tensor] = None,
        transformer_options: Optional[dict] = None,
    ) -> torch.Tensor:
        """
        Args:
            x (Tensor): The query tensor of shape [B, Mq, K].
            context (Optional[Tensor]): The key/value tensor of shape [B, Mk, K'], or None to
                use x as context (self-attention).
            rope_emb (Optional[Tensor]): Rotary embedding. Only used in self-attention mode.
            transformer_options (Optional[dict]): Forwarded to the attention op. None means {}.

        Returns:
            Tensor of shape [B, Mq, query_dim].
        """
        q, k, v = self.compute_qkv(x, context, rope_emb=rope_emb)
        return self.compute_attention(q, k, v, transformer_options=transformer_options)
class Timesteps(nn.Module):
    """Sinusoidal features for a ``(B, T)`` grid of timesteps.

    Uses ``num_channels // 2`` geometrically spaced frequencies ``10000 ** (-i / half)``.
    The output concatenates all cosines followed by all sines, so each timestep gets
    ``2 * (num_channels // 2)`` float32 features.
    """

    def __init__(self, num_channels: int):
        super().__init__()
        self.num_channels = num_channels

    def forward(self, timesteps_B_T: torch.Tensor) -> torch.Tensor:
        """Return a float32 tensor of shape ``(B, T, 2 * (num_channels // 2))``."""
        assert timesteps_B_T.ndim == 2, f"Expected 2D input, got {timesteps_B_T.ndim}"
        batch, frames = timesteps_B_T.shape
        flat_steps = timesteps_B_T.flatten().float()
        half = self.num_channels // 2
        freq_index = torch.arange(half, dtype=torch.float32, device=flat_steps.device)
        frequencies = torch.exp(-math.log(10000) * freq_index / (half - 0.0))
        angles = flat_steps[:, None] * frequencies[None, :]
        features = torch.cat([torch.cos(angles), torch.sin(angles)], dim=-1)
        # (B*T, F) -> (B, T, F); `features` is freshly concatenated, hence contiguous.
        return features.view(batch, frames, features.shape[-1])
class TimestepEmbedding(nn.Module):
    """MLP over sinusoidal timestep features: ``linear_2(SiLU(linear_1(sample)))``.

    Without AdaLN-LoRA, the MLP output is the embedding and no LoRA term is returned.
    With AdaLN-LoRA:
    - ``linear_2`` widens to ``3 * out_features`` and its output becomes the LoRA term.
    - The raw ``sample`` is passed through unchanged as the embedding.
    - ``linear_1`` is created without bias.
    """

    def __init__(self, in_features: int, out_features: int, use_adaln_lora: bool = False, device=None, dtype=None, operations=None):
        super().__init__()
        logging.debug(
            f"Using AdaLN LoRA Flag: {use_adaln_lora}. We enable bias if no AdaLN LoRA for backward compatibility."
        )
        self.in_dim = in_features
        self.out_dim = out_features
        self.linear_1 = operations.Linear(in_features, out_features, bias=not use_adaln_lora, device=device, dtype=dtype)
        self.activation = nn.SiLU()
        self.use_adaln_lora = use_adaln_lora
        second_width = 3 * out_features if use_adaln_lora else out_features
        self.linear_2 = operations.Linear(out_features, second_width, bias=False, device=device, dtype=dtype)

    def forward(self, sample: torch.Tensor) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        """Return ``(emb_B_T_D, adaln_lora_B_T_3D)``; the second item is None without AdaLN-LoRA."""
        projected = self.linear_2(self.activation(self.linear_1(sample)))
        if not self.use_adaln_lora:
            return projected, None
        return sample, projected
class PatchEmbed(nn.Module):
    """
    Split a video tensor into non-overlapping spatio-temporal patches and project each patch
    to an ``out_channels``-dimensional vector.

    The input ``(B, C, T, H, W)`` is rearranged so that every
    ``temporal_patch_size x spatial_patch_size x spatial_patch_size`` patch becomes one flat
    vector of length ``C * temporal_patch_size * spatial_patch_size ** 2``. That vector is
    then passed through a bias-free linear layer. No convolution is involved. A single image
    is the case ``T == 1`` with ``temporal_patch_size == 1``.

    Parameters:
    - spatial_patch_size (int): Height and width of each patch.
    - temporal_patch_size (int): Number of frames in each patch.
    - in_channels (int): Number of input channels. Default: 3.
    - out_channels (int): The dimension of the embedding vector for each patch. Default: 768.
    - device, dtype: Forwarded to ``operations.Linear``.
    - operations: Namespace that provides a ``Linear`` layer class.
    """

    def __init__(
        self,
        spatial_patch_size: int,
        temporal_patch_size: int,
        in_channels: int = 3,
        out_channels: int = 768,
        device=None, dtype=None, operations=None
    ):
        super().__init__()
        self.spatial_patch_size = spatial_patch_size
        self.temporal_patch_size = temporal_patch_size
        # Index 0 is a parameter-free rearrange; index 1 holds the weights (state-dict key proj.1.weight).
        self.proj = nn.Sequential(
            Rearrange(
                "b c (t r) (h m) (w n) -> b t h w (c r m n)",
                r=temporal_patch_size,
                m=spatial_patch_size,
                n=spatial_patch_size,
            ),
            operations.Linear(
                in_channels * spatial_patch_size * spatial_patch_size * temporal_patch_size, out_channels, bias=False, device=device, dtype=dtype
            ),
        )
        # Length of one flattened patch vector, i.e. the input width of the projection.
        self.dim = in_channels * spatial_patch_size * spatial_patch_size * temporal_patch_size

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Forward pass of the PatchEmbed module.

        Parameters:
        - x (torch.Tensor): The input tensor of shape (B, C, T, H, W) where
            B is the batch size,
            C is the number of channels,
            T is the temporal dimension,
            H is the height, and
            W is the width of the input.
            T must be divisible by temporal_patch_size, and H and W by spatial_patch_size.

        Returns:
        - torch.Tensor: The embedded patches with shape
          (B, T // temporal_patch_size, H // spatial_patch_size, W // spatial_patch_size, out_channels).
        """
        assert x.dim() == 5, f"Expected a 5D (B, C, T, H, W) tensor, got {x.dim()}D"
        _, _, T, H, W = x.shape
        assert (
            H % self.spatial_patch_size == 0 and W % self.spatial_patch_size == 0
        ), f"H,W {(H, W)} should be divisible by spatial_patch_size {self.spatial_patch_size}"
        assert T % self.temporal_patch_size == 0, f"T {T} should be divisible by temporal_patch_size {self.temporal_patch_size}"
        x = self.proj(x)
        return x
class FinalLayer(nn.Module):
    """
    Output head of the video DiT.

    Normalizes tokens with a non-affine LayerNorm and applies a per-(B, T) AdaLN shift and
    scale. It then projects each token to ``spatial_patch_size**2 * temporal_patch_size *
    out_channels`` values, one patch worth of output.
    """

    def __init__(
        self,
        hidden_size: int,
        spatial_patch_size: int,
        temporal_patch_size: int,
        out_channels: int,
        use_adaln_lora: bool = False,
        adaln_lora_dim: int = 256,
        device=None, dtype=None, operations=None
    ):
        super().__init__()
        self.layer_norm = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
        patch_volume = spatial_patch_size * spatial_patch_size * temporal_patch_size
        self.linear = operations.Linear(
            hidden_size, patch_volume * out_channels, bias=False, device=device, dtype=dtype
        )
        self.hidden_size = hidden_size
        self.n_adaln_chunks = 2
        self.use_adaln_lora = use_adaln_lora
        self.adaln_lora_dim = adaln_lora_dim
        modulation_width = self.n_adaln_chunks * hidden_size
        if use_adaln_lora:
            # Low-rank bottleneck: hidden -> lora_dim -> shift/scale.
            projections = [
                operations.Linear(hidden_size, adaln_lora_dim, bias=False, device=device, dtype=dtype),
                operations.Linear(adaln_lora_dim, modulation_width, bias=False, device=device, dtype=dtype),
            ]
        else:
            projections = [
                operations.Linear(hidden_size, modulation_width, bias=False, device=device, dtype=dtype),
            ]
        self.adaln_modulation = nn.Sequential(nn.SiLU(), *projections)

    def forward(
        self,
        x_B_T_H_W_D: torch.Tensor,
        emb_B_T_D: torch.Tensor,
        adaln_lora_B_T_3D: Optional[torch.Tensor] = None,
    ):
        """Modulate, normalize and project ``x`` of shape (B, T, H, W, D) to (B, T, H, W, O).

        With AdaLN-LoRA, ``adaln_lora_B_T_3D`` is required; its first ``2 * hidden_size``
        features are added to the modulation.
        """
        if self.use_adaln_lora:
            assert adaln_lora_B_T_3D is not None
            modulation = self.adaln_modulation(emb_B_T_D) + adaln_lora_B_T_3D[:, :, : 2 * self.hidden_size]
        else:
            modulation = self.adaln_modulation(emb_B_T_D)
        shift_B_T_D, scale_B_T_D = modulation.chunk(2, dim=-1)
        # (B, T, D) -> (B, T, 1, 1, D) so the modulation broadcasts over H and W.
        shift_B_T_1_1_D = shift_B_T_D[:, :, None, None, :]
        scale_B_T_1_1_D = scale_B_T_D[:, :, None, None, :]
        modulated = self.layer_norm(x_B_T_H_W_D) * (1 + scale_B_T_1_1_D) + shift_B_T_1_1_D
        return self.linear(modulated)
class Block(nn.Module):
    """
    A transformer block that combines self-attention, cross-attention and MLP layers with AdaLN modulation.
    Each component (self-attention, cross-attention, MLP) has its own layer normalization and AdaLN modulation.

    Parameters:
        x_dim (int): Dimension of input features
        context_dim (int): Dimension of context features for cross-attention
        num_heads (int): Number of attention heads (each of width x_dim // num_heads)
        mlp_ratio (float): Multiplier for MLP hidden dimension. Default: 4.0
        use_adaln_lora (bool): Whether to use AdaLN-LoRA modulation. Default: False
        adaln_lora_dim (int): Hidden dimension for AdaLN-LoRA layers. Default: 256
        device, dtype: Forwarded to the layers created from ``operations``.
        operations: Namespace providing ``Linear``, ``LayerNorm`` and ``RMSNorm`` classes.

    The block applies the following sequence:
    1. Self-attention with AdaLN modulation
    2. Cross-attention with AdaLN modulation
    3. MLP with AdaLN modulation
    Each step normalizes its input with a non-affine LayerNorm and applies a per-(B, T)
    shift/scale. The result, scaled by a per-(B, T) gate, is added back to the residual stream.
    """

    def __init__(
        self,
        x_dim: int,
        context_dim: int,
        num_heads: int,
        mlp_ratio: float = 4.0,
        use_adaln_lora: bool = False,
        adaln_lora_dim: int = 256,
        device=None,
        dtype=None,
        operations=None,
    ):
        super().__init__()
        self.x_dim = x_dim
        self.layer_norm_self_attn = operations.LayerNorm(x_dim, elementwise_affine=False, eps=1e-6, device=device, dtype=dtype)
        self.self_attn = Attention(x_dim, None, num_heads, x_dim // num_heads, device=device, dtype=dtype, operations=operations)
        self.layer_norm_cross_attn = operations.LayerNorm(x_dim, elementwise_affine=False, eps=1e-6, device=device, dtype=dtype)
        self.cross_attn = Attention(
            x_dim, context_dim, num_heads, x_dim // num_heads, device=device, dtype=dtype, operations=operations
        )
        self.layer_norm_mlp = operations.LayerNorm(x_dim, elementwise_affine=False, eps=1e-6, device=device, dtype=dtype)
        self.mlp = GPT2FeedForward(x_dim, int(x_dim * mlp_ratio), device=device, dtype=dtype, operations=operations)
        self.use_adaln_lora = use_adaln_lora
        # Each modulation head emits shift, scale and gate (3 * x_dim) for its sub-layer.
        if self.use_adaln_lora:
            self.adaln_modulation_self_attn = nn.Sequential(
                nn.SiLU(),
                operations.Linear(x_dim, adaln_lora_dim, bias=False, device=device, dtype=dtype),
                operations.Linear(adaln_lora_dim, 3 * x_dim, bias=False, device=device, dtype=dtype),
            )
            self.adaln_modulation_cross_attn = nn.Sequential(
                nn.SiLU(),
                operations.Linear(x_dim, adaln_lora_dim, bias=False, device=device, dtype=dtype),
                operations.Linear(adaln_lora_dim, 3 * x_dim, bias=False, device=device, dtype=dtype),
            )
            self.adaln_modulation_mlp = nn.Sequential(
                nn.SiLU(),
                operations.Linear(x_dim, adaln_lora_dim, bias=False, device=device, dtype=dtype),
                operations.Linear(adaln_lora_dim, 3 * x_dim, bias=False, device=device, dtype=dtype),
            )
        else:
            self.adaln_modulation_self_attn = nn.Sequential(nn.SiLU(), operations.Linear(x_dim, 3 * x_dim, bias=False, device=device, dtype=dtype))
            self.adaln_modulation_cross_attn = nn.Sequential(nn.SiLU(), operations.Linear(x_dim, 3 * x_dim, bias=False, device=device, dtype=dtype))
            self.adaln_modulation_mlp = nn.Sequential(nn.SiLU(), operations.Linear(x_dim, 3 * x_dim, bias=False, device=device, dtype=dtype))

    def forward(
        self,
        x_B_T_H_W_D: torch.Tensor,
        emb_B_T_D: torch.Tensor,
        crossattn_emb: torch.Tensor,
        rope_emb_L_1_1_D: Optional[torch.Tensor] = None,
        adaln_lora_B_T_3D: Optional[torch.Tensor] = None,
        extra_per_block_pos_emb: Optional[torch.Tensor] = None,
        transformer_options: Optional[dict] = None,
    ) -> torch.Tensor:
        """
        Args:
            x_B_T_H_W_D: Token grid of shape (B, T, H, W, D).
            emb_B_T_D: Conditioning embedding of shape (B, T, D). It drives every AdaLN shift/scale/gate.
            crossattn_emb: Key/value input for the cross-attention layer (width ``context_dim``).
            rope_emb_L_1_1_D: Rotary embedding for self-attention over the flattened T*H*W tokens.
            adaln_lora_B_T_3D: Additive AdaLN-LoRA term of width 3 * D. Required when ``use_adaln_lora``.
            extra_per_block_pos_emb: Optional tensor added to ``x`` before anything else.
            transformer_options: Forwarded to the attention layers. ``None`` (the default) means an
                empty dict. A fresh dict is made per call instead of sharing a mutable default.

        Returns:
            The updated token grid, with the same shape as ``x_B_T_H_W_D``.
        """
        if transformer_options is None:
            transformer_options = {}
        if extra_per_block_pos_emb is not None:
            x_B_T_H_W_D = x_B_T_H_W_D + extra_per_block_pos_emb
        if self.use_adaln_lora:
            # Same contract check as FinalLayer: fail clearly instead of a TypeError on `+ None`.
            assert adaln_lora_B_T_3D is not None
            shift_self_attn_B_T_D, scale_self_attn_B_T_D, gate_self_attn_B_T_D = (
                self.adaln_modulation_self_attn(emb_B_T_D) + adaln_lora_B_T_3D
            ).chunk(3, dim=-1)
            shift_cross_attn_B_T_D, scale_cross_attn_B_T_D, gate_cross_attn_B_T_D = (
                self.adaln_modulation_cross_attn(emb_B_T_D) + adaln_lora_B_T_3D
            ).chunk(3, dim=-1)
            shift_mlp_B_T_D, scale_mlp_B_T_D, gate_mlp_B_T_D = (
                self.adaln_modulation_mlp(emb_B_T_D) + adaln_lora_B_T_3D
            ).chunk(3, dim=-1)
        else:
            shift_self_attn_B_T_D, scale_self_attn_B_T_D, gate_self_attn_B_T_D = self.adaln_modulation_self_attn(
                emb_B_T_D
            ).chunk(3, dim=-1)
            shift_cross_attn_B_T_D, scale_cross_attn_B_T_D, gate_cross_attn_B_T_D = self.adaln_modulation_cross_attn(
                emb_B_T_D
            ).chunk(3, dim=-1)
            shift_mlp_B_T_D, scale_mlp_B_T_D, gate_mlp_B_T_D = self.adaln_modulation_mlp(emb_B_T_D).chunk(3, dim=-1)

        # Reshape tensors from (B, T, D) to (B, T, 1, 1, D) for broadcasting
        shift_self_attn_B_T_1_1_D = rearrange(shift_self_attn_B_T_D, "b t d -> b t 1 1 d")
        scale_self_attn_B_T_1_1_D = rearrange(scale_self_attn_B_T_D, "b t d -> b t 1 1 d")
        gate_self_attn_B_T_1_1_D = rearrange(gate_self_attn_B_T_D, "b t d -> b t 1 1 d")
        shift_cross_attn_B_T_1_1_D = rearrange(shift_cross_attn_B_T_D, "b t d -> b t 1 1 d")
        scale_cross_attn_B_T_1_1_D = rearrange(scale_cross_attn_B_T_D, "b t d -> b t 1 1 d")
        gate_cross_attn_B_T_1_1_D = rearrange(gate_cross_attn_B_T_D, "b t d -> b t 1 1 d")
        shift_mlp_B_T_1_1_D = rearrange(shift_mlp_B_T_D, "b t d -> b t 1 1 d")
        scale_mlp_B_T_1_1_D = rearrange(scale_mlp_B_T_D, "b t d -> b t 1 1 d")
        gate_mlp_B_T_1_1_D = rearrange(gate_mlp_B_T_D, "b t d -> b t 1 1 d")

        B, T, H, W, D = x_B_T_H_W_D.shape

        def _fn(_x_B_T_H_W_D, _norm_layer, _scale_B_T_1_1_D, _shift_B_T_1_1_D):
            # AdaLN: normalize, then scale/shift with the per-(B, T) modulation.
            return _norm_layer(_x_B_T_H_W_D) * (1 + _scale_B_T_1_1_D) + _shift_B_T_1_1_D

        # 1. Self-attention over all T*H*W tokens.
        normalized_x_B_T_H_W_D = _fn(
            x_B_T_H_W_D,
            self.layer_norm_self_attn,
            scale_self_attn_B_T_1_1_D,
            shift_self_attn_B_T_1_1_D,
        )
        result_B_T_H_W_D = rearrange(
            self.self_attn(
                rearrange(normalized_x_B_T_H_W_D, "b t h w d -> b (t h w) d"),
                None,
                rope_emb=rope_emb_L_1_1_D,
                transformer_options=transformer_options,
            ),
            "b (t h w) d -> b t h w d",
            t=T,
            h=H,
            w=W,
        )
        x_B_T_H_W_D = x_B_T_H_W_D + gate_self_attn_B_T_1_1_D * result_B_T_H_W_D

        # 2. Cross-attention to crossattn_emb.
        def _x_fn(
            _x_B_T_H_W_D: torch.Tensor,
            layer_norm_cross_attn: Callable,
            _scale_cross_attn_B_T_1_1_D: torch.Tensor,
            _shift_cross_attn_B_T_1_1_D: torch.Tensor,
            transformer_options: Optional[dict] = None,
        ) -> torch.Tensor:
            _normalized_x_B_T_H_W_D = _fn(
                _x_B_T_H_W_D, layer_norm_cross_attn, _scale_cross_attn_B_T_1_1_D, _shift_cross_attn_B_T_1_1_D
            )
            _result_B_T_H_W_D = rearrange(
                self.cross_attn(
                    rearrange(_normalized_x_B_T_H_W_D, "b t h w d -> b (t h w) d"),
                    crossattn_emb,
                    # Ignored by cross-attention: Attention applies RoPE only in self-attention mode.
                    rope_emb=rope_emb_L_1_1_D,
                    transformer_options=transformer_options,
                ),
                "b (t h w) d -> b t h w d",
                t=T,
                h=H,
                w=W,
            )
            return _result_B_T_H_W_D

        result_B_T_H_W_D = _x_fn(
            x_B_T_H_W_D,
            self.layer_norm_cross_attn,
            scale_cross_attn_B_T_1_1_D,
            shift_cross_attn_B_T_1_1_D,
            transformer_options=transformer_options,
        )
        x_B_T_H_W_D = result_B_T_H_W_D * gate_cross_attn_B_T_1_1_D + x_B_T_H_W_D

        # 3. Position-wise MLP.
        normalized_x_B_T_H_W_D = _fn(
            x_B_T_H_W_D,
            self.layer_norm_mlp,
            scale_mlp_B_T_1_1_D,
            shift_mlp_B_T_1_1_D,
        )
        result_B_T_H_W_D = self.mlp(normalized_x_B_T_H_W_D)
        x_B_T_H_W_D = x_B_T_H_W_D + gate_mlp_B_T_1_1_D * result_B_T_H_W_D
        return x_B_T_H_W_D
| class MiniTrainDIT(nn.Module): | |
| """ | |
| A clean impl of DIT that can load and reproduce the training results of the original DIT model in~(cosmos 1) | |
| A general implementation of adaln-modulated VIT-like~(DiT) transformer for video processing. | |
| Args: | |
| max_img_h (int): Maximum height of the input images. | |
| max_img_w (int): Maximum width of the input images. | |
| max_frames (int): Maximum number of frames in the video sequence. | |
| in_channels (int): Number of input channels (e.g., RGB channels for color images). | |
| out_channels (int): Number of output channels. | |
| patch_spatial (tuple): Spatial resolution of patches for input processing. | |
| patch_temporal (int): Temporal resolution of patches for input processing. | |
| concat_padding_mask (bool): If True, includes a mask channel in the input to handle padding. | |
| model_channels (int): Base number of channels used throughout the model. | |
| num_blocks (int): Number of transformer blocks. | |
| num_heads (int): Number of heads in the multi-head attention layers. | |
| mlp_ratio (float): Expansion ratio for MLP blocks. | |
| crossattn_emb_channels (int): Number of embedding channels for cross-attention. | |
| pos_emb_cls (str): Type of positional embeddings. | |
| pos_emb_learnable (bool): Whether positional embeddings are learnable. | |
| pos_emb_interpolation (str): Method for interpolating positional embeddings. | |
| min_fps (int): Minimum frames per second. | |
| max_fps (int): Maximum frames per second. | |
| use_adaln_lora (bool): Whether to use AdaLN-LoRA. | |
| adaln_lora_dim (int): Dimension for AdaLN-LoRA. | |
| rope_h_extrapolation_ratio (float): Height extrapolation ratio for RoPE. | |
| rope_w_extrapolation_ratio (float): Width extrapolation ratio for RoPE. | |
| rope_t_extrapolation_ratio (float): Temporal extrapolation ratio for RoPE. | |
| extra_per_block_abs_pos_emb (bool): Whether to use extra per-block absolute positional embeddings. | |
| extra_h_extrapolation_ratio (float): Height extrapolation ratio for extra embeddings. | |
| extra_w_extrapolation_ratio (float): Width extrapolation ratio for extra embeddings. | |
| extra_t_extrapolation_ratio (float): Temporal extrapolation ratio for extra embeddings. | |
| """ | |
    def __init__(
        self,
        max_img_h: int,
        max_img_w: int,
        max_frames: int,
        in_channels: int,
        out_channels: int,
        patch_spatial: int,  # int, not a tuple: build_pos_embed divides by it with `//`
        patch_temporal: int,
        concat_padding_mask: bool = True,
        # attention settings
        model_channels: int = 768,
        num_blocks: int = 10,
        num_heads: int = 16,
        mlp_ratio: float = 4.0,
        # cross attention settings
        crossattn_emb_channels: int = 1024,
        # positional embedding settings
        pos_emb_cls: str = "sincos",
        pos_emb_learnable: bool = False,
        pos_emb_interpolation: str = "crop",
        min_fps: int = 1,
        max_fps: int = 30,
        use_adaln_lora: bool = False,
        adaln_lora_dim: int = 256,
        rope_h_extrapolation_ratio: float = 1.0,
        rope_w_extrapolation_ratio: float = 1.0,
        rope_t_extrapolation_ratio: float = 1.0,
        extra_per_block_abs_pos_emb: bool = False,
        extra_h_extrapolation_ratio: float = 1.0,
        extra_w_extrapolation_ratio: float = 1.0,
        extra_t_extrapolation_ratio: float = 1.0,
        rope_enable_fps_modulation: bool = True,
        image_model=None,
        device=None,
        dtype=None,
        operations=None,
    ) -> None:
        """Build the positional, timestep and patch embedders, the transformer blocks and the final layer.

        Parameter meanings are listed in the class docstring. Notes specific to this constructor:
          - ``build_pos_embed`` accepts only ``pos_emb_cls == "rope3d"``. The default ``"sincos"``
            makes it raise ``ValueError``, so callers must pass ``"rope3d"`` explicitly.
          - ``image_model`` is accepted but not referenced in this constructor.
          - When ``concat_padding_mask`` is True, the patch embedder gets one extra input channel.
        """
        super().__init__()
        self.dtype = dtype
        self.max_img_h = max_img_h
        self.max_img_w = max_img_w
        self.max_frames = max_frames
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.patch_spatial = patch_spatial
        self.patch_temporal = patch_temporal
        self.num_heads = num_heads
        self.num_blocks = num_blocks
        self.model_channels = model_channels
        self.concat_padding_mask = concat_padding_mask
        # positional embedding settings
        self.pos_emb_cls = pos_emb_cls
        self.pos_emb_learnable = pos_emb_learnable
        self.pos_emb_interpolation = pos_emb_interpolation
        self.min_fps = min_fps
        self.max_fps = max_fps
        self.rope_h_extrapolation_ratio = rope_h_extrapolation_ratio
        self.rope_w_extrapolation_ratio = rope_w_extrapolation_ratio
        self.rope_t_extrapolation_ratio = rope_t_extrapolation_ratio
        self.extra_per_block_abs_pos_emb = extra_per_block_abs_pos_emb
        self.extra_h_extrapolation_ratio = extra_h_extrapolation_ratio
        self.extra_w_extrapolation_ratio = extra_w_extrapolation_ratio
        self.extra_t_extrapolation_ratio = extra_t_extrapolation_ratio
        self.rope_enable_fps_modulation = rope_enable_fps_modulation
        # Must run after the size / positional-embedding attributes above, which it reads.
        self.build_pos_embed(device=device, dtype=dtype)
        self.use_adaln_lora = use_adaln_lora
        self.adaln_lora_dim = adaln_lora_dim
        # timestep -> sinusoidal features -> MLP, returning (embedding, optional AdaLN-LoRA term).
        self.t_embedder = nn.Sequential(
            Timesteps(model_channels),
            TimestepEmbedding(model_channels, model_channels, use_adaln_lora=use_adaln_lora, device=device, dtype=dtype, operations=operations,),
        )

        # One extra input channel for the padding mask concatenated in prepare_embedded_sequence.
        in_channels = in_channels + 1 if concat_padding_mask else in_channels
        self.x_embedder = PatchEmbed(
            spatial_patch_size=patch_spatial,
            temporal_patch_size=patch_temporal,
            in_channels=in_channels,
            out_channels=model_channels,
            device=device, dtype=dtype, operations=operations,
        )

        self.blocks = nn.ModuleList(
            [
                Block(
                    x_dim=model_channels,
                    context_dim=crossattn_emb_channels,
                    num_heads=num_heads,
                    mlp_ratio=mlp_ratio,
                    use_adaln_lora=use_adaln_lora,
                    adaln_lora_dim=adaln_lora_dim,
                    device=device, dtype=dtype, operations=operations,
                )
                for _ in range(num_blocks)
            ]
        )

        self.final_layer = FinalLayer(
            hidden_size=self.model_channels,
            spatial_patch_size=self.patch_spatial,
            temporal_patch_size=self.patch_temporal,
            out_channels=self.out_channels,
            use_adaln_lora=self.use_adaln_lora,
            adaln_lora_dim=self.adaln_lora_dim,
            device=device, dtype=dtype, operations=operations,
        )

        # Not used in this constructor; presumably normalizes the timestep embedding in forward — confirm.
        self.t_embedding_norm = operations.RMSNorm(model_channels, eps=1e-6, device=device, dtype=dtype)
| def build_pos_embed(self, device=None, dtype=None) -> None: | |
| if self.pos_emb_cls == "rope3d": | |
| cls_type = VideoRopePosition3DEmb | |
| else: | |
| raise ValueError(f"Unknown pos_emb_cls {self.pos_emb_cls}") | |
| logging.debug(f"Building positional embedding with {self.pos_emb_cls} class, impl {cls_type}") | |
| kwargs = dict( | |
| model_channels=self.model_channels, | |
| len_h=self.max_img_h // self.patch_spatial, | |
| len_w=self.max_img_w // self.patch_spatial, | |
| len_t=self.max_frames // self.patch_temporal, | |
| max_fps=self.max_fps, | |
| min_fps=self.min_fps, | |
| is_learnable=self.pos_emb_learnable, | |
| interpolation=self.pos_emb_interpolation, | |
| head_dim=self.model_channels // self.num_heads, | |
| h_extrapolation_ratio=self.rope_h_extrapolation_ratio, | |
| w_extrapolation_ratio=self.rope_w_extrapolation_ratio, | |
| t_extrapolation_ratio=self.rope_t_extrapolation_ratio, | |
| enable_fps_modulation=self.rope_enable_fps_modulation, | |
| device=device, | |
| ) | |
| self.pos_embedder = cls_type( | |
| **kwargs, # type: ignore | |
| ) | |
| if self.extra_per_block_abs_pos_emb: | |
| kwargs["h_extrapolation_ratio"] = self.extra_h_extrapolation_ratio | |
| kwargs["w_extrapolation_ratio"] = self.extra_w_extrapolation_ratio | |
| kwargs["t_extrapolation_ratio"] = self.extra_t_extrapolation_ratio | |
| kwargs["device"] = device | |
| kwargs["dtype"] = dtype | |
| self.extra_pos_embedder = LearnablePosEmbAxis( | |
| **kwargs, # type: ignore | |
| ) | |
def prepare_embedded_sequence(
    self,
    x_B_C_T_H_W: torch.Tensor,
    fps: Optional[torch.Tensor] = None,
    padding_mask: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor]]:
    """
    Patch-embed the input video and compute its positional embeddings.

    Args:
        x_B_C_T_H_W (torch.Tensor): input video, indexed as (B, C, T, H, W).
        fps (Optional[torch.Tensor]): frames-per-second tensor. It is forwarded unchanged to the
            RoPE embedder and to the extra per-block embedder; this method does not substitute
            a default when it is None.
        padding_mask (Optional[torch.Tensor]): spatial padding mask, only used when
            `self.concat_padding_mask` is True. When None, an all-zeros (B, 1, H, W) mask is
            used. Otherwise it is resized (nearest) to the input's (H, W) and cast to the
            input's dtype/device. Presumably shaped (B, 1, H', W') — TODO confirm with callers.

    Returns:
        A 3-tuple:
          - the output of `self.x_embedder` (named (B, T, H, W, D));
          - the RoPE embedding produced by `self.pos_embedder` when 'rope' is in
            `self.pos_emb_cls` (case-insensitive), else None. In the non-RoPE case, the
            `self.pos_embedder` output is added to the embedded sequence instead;
          - the extra per-block absolute positional embedding from `self.extra_pos_embedder`
            when `self.extra_per_block_abs_pos_emb` is True, else None.
    """
    x_device, x_dtype = x_B_C_T_H_W.device, x_B_C_T_H_W.dtype

    if self.concat_padding_mask:
        B, _, T, H, W = x_B_C_T_H_W.shape
        if padding_mask is None:
            padding_mask = torch.zeros(B, 1, H, W, dtype=x_dtype, device=x_device)
        else:
            padding_mask = transforms.functional.resize(
                padding_mask, [H, W], interpolation=transforms.InterpolationMode.NEAREST
            )
            # Match the input so torch.cat neither fails on device nor promotes the dtype.
            padding_mask = padding_mask.to(dtype=x_dtype, device=x_device)
        # (B, 1, H, W) -> (B, 1, 1, H, W) -> broadcast view over T; torch.cat makes the copy.
        mask_B_1_T_H_W = padding_mask.unsqueeze(1).expand(-1, -1, T, -1, -1)
        x_B_C_T_H_W = torch.cat([x_B_C_T_H_W, mask_B_1_T_H_W], dim=1)

    x_B_T_H_W_D = self.x_embedder(x_B_C_T_H_W)

    extra_pos_emb = None
    if self.extra_per_block_abs_pos_emb:
        extra_pos_emb = self.extra_pos_embedder(x_B_T_H_W_D, fps=fps, device=x_device, dtype=x_dtype)

    if "rope" in self.pos_emb_cls.lower():
        # RoPE is applied later inside attention, so it is returned rather than added here.
        return x_B_T_H_W_D, self.pos_embedder(x_B_T_H_W_D, fps=fps, device=x_device), extra_pos_emb

    # Absolute positional embedding: add directly to the token features.
    x_B_T_H_W_D = x_B_T_H_W_D + self.pos_embedder(x_B_T_H_W_D, device=x_device)
    return x_B_T_H_W_D, None, extra_pos_emb
def unpatchify(self, x_B_T_H_W_M: torch.Tensor) -> torch.Tensor:
    """
    Fold per-token patch channels back into a dense video tensor.

    The last input axis is split as (p1, p2, t, C), where p1 = p2 = self.patch_spatial and
    t = self.patch_temporal. The patch factors are then merged into the time, height and
    width axes, giving an output laid out as (B, C, T*t, H*p1, W*p2).
    """
    spatial = self.patch_spatial
    temporal = self.patch_temporal
    pattern = "B T H W (p1 p2 t C) -> B C (T t) (H p1) (W p2)"
    return rearrange(x_B_T_H_W_M, pattern, p1=spatial, p2=spatial, t=temporal)
def forward(
    self,
    x: torch.Tensor,
    timesteps: torch.Tensor,
    context: torch.Tensor,
    fps: Optional[torch.Tensor] = None,
    padding_mask: Optional[torch.Tensor] = None,
    **kwargs,
):
    """
    Public entry point: dispatch to `self._forward` through the registered wrappers.

    The DIFFUSION_MODEL wrappers are looked up in `kwargs["transformer_options"]`, or in an
    empty dict when that key is absent. All positional arguments and kwargs are passed
    through to the executor unchanged.
    """
    patcher = comfy.patcher_extension
    transformer_options = kwargs.get("transformer_options", {})
    wrappers = patcher.get_all_wrappers(patcher.WrappersMP.DIFFUSION_MODEL, transformer_options)
    executor = patcher.WrapperExecutor.new_class_executor(self._forward, self, wrappers)
    return executor.execute(x, timesteps, context, fps, padding_mask, **kwargs)
| def _forward( | |
| self, | |
| x: torch.Tensor, | |
| timesteps: torch.Tensor, | |
| context: torch.Tensor, | |
| fps: Optional[torch.Tensor] = None, | |
| padding_mask: Optional[torch.Tensor] = None, | |
| **kwargs, | |
| ): | |
| x_B_C_T_H_W = x | |
| timesteps_B_T = timesteps | |
| crossattn_emb = context | |
| """ | |
| Args: | |
| x: (B, C, T, H, W) tensor of spatial-temp inputs | |
| timesteps: (B, ) tensor of timesteps | |
| crossattn_emb: (B, N, D) tensor of cross-attention embeddings | |
| """ | |
| x_B_T_H_W_D, rope_emb_L_1_1_D, extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D = self.prepare_embedded_sequence( | |
| x_B_C_T_H_W, | |
| fps=fps, | |
| padding_mask=padding_mask, | |
| ) | |
| if timesteps_B_T.ndim == 1: | |
| timesteps_B_T = timesteps_B_T.unsqueeze(1) | |
| t_embedding_B_T_D, adaln_lora_B_T_3D = self.t_embedder[1](self.t_embedder[0](timesteps_B_T).to(x_B_T_H_W_D.dtype)) | |
| t_embedding_B_T_D = self.t_embedding_norm(t_embedding_B_T_D) | |
| # for logging purpose | |
| affline_scale_log_info = {} | |
| affline_scale_log_info["t_embedding_B_T_D"] = t_embedding_B_T_D.detach() | |
| self.affline_scale_log_info = affline_scale_log_info | |
| self.affline_emb = t_embedding_B_T_D | |
| self.crossattn_emb = crossattn_emb | |
| if extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D is not None: | |
| assert ( | |
| x_B_T_H_W_D.shape == extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D.shape | |
| ), f"{x_B_T_H_W_D.shape} != {extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D.shape}" | |
| block_kwargs = { | |
| "rope_emb_L_1_1_D": rope_emb_L_1_1_D.unsqueeze(1).unsqueeze(0), | |
| "adaln_lora_B_T_3D": adaln_lora_B_T_3D, | |
| "extra_per_block_pos_emb": extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D, | |
| "transformer_options": kwargs.get("transformer_options", {}), | |
| } | |
| for block in self.blocks: | |
| x_B_T_H_W_D = block( | |
| x_B_T_H_W_D, | |
| t_embedding_B_T_D, | |
| crossattn_emb, | |
| **block_kwargs, | |
| ) | |
| x_B_T_H_W_O = self.final_layer(x_B_T_H_W_D, t_embedding_B_T_D, adaln_lora_B_T_3D=adaln_lora_B_T_3D) | |
| x_B_C_Tt_Hp_Wp = self.unpatchify(x_B_T_H_W_O) | |
| return x_B_C_Tt_Hp_Wp | |