from typing import Optional, Union
from pathlib import Path
import os
import json

import torch
import torch.nn as nn
from einops import rearrange
from diffusers import ConfigMixin, ModelMixin
from safetensors.torch import safe_open

from ltx_video.models.autoencoders.pixel_shuffle import PixelShuffleND


class ResBlock(nn.Module):
    """Residual block: conv -> GroupNorm -> SiLU -> conv -> GroupNorm, then skip-add and SiLU.

    Args:
        channels: Number of input and output channels.
        mid_channels: Number of channels between the two convolutions.
            Defaults to ``channels`` when ``None``.
        dims: ``2`` selects ``nn.Conv2d``; any other value selects ``nn.Conv3d``.

    Both norms use 32 groups, so ``channels`` and ``mid_channels`` must be
    divisible by 32 (a ``nn.GroupNorm`` requirement).
    """

    def __init__(
        self, channels: int, mid_channels: Optional[int] = None, dims: int = 3
    ):
        super().__init__()
        if mid_channels is None:
            mid_channels = channels

        Conv = nn.Conv2d if dims == 2 else nn.Conv3d

        # kernel_size=3 with padding=1 keeps the spatial/temporal extent unchanged.
        self.conv1 = Conv(channels, mid_channels, kernel_size=3, padding=1)
        self.norm1 = nn.GroupNorm(32, mid_channels)
        self.conv2 = Conv(mid_channels, channels, kernel_size=3, padding=1)
        self.norm2 = nn.GroupNorm(32, channels)
        self.activation = nn.SiLU()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the block to ``x`` and return a tensor of the same shape."""
        residual = x
        x = self.conv1(x)
        x = self.norm1(x)
        x = self.activation(x)
        x = self.conv2(x)
        x = self.norm2(x)
        # The skip connection is added before the final activation.
        x = self.activation(x + residual)
        return x


class LatentUpsampler(ModelMixin, ConfigMixin):
    """
    Model to spatially upsample VAE latents.

    Args:
        in_channels (`int`): Number of channels in the input latent
        mid_channels (`int`): Number of channels in the middle layers
        num_blocks_per_stage (`int`): Number of ResBlocks to use in each stage (pre/post upsampling)
        dims (`int`): Number of dimensions for convolutions (2 or 3)
        spatial_upsample (`bool`): Whether to spatially upsample the latent
        temporal_upsample (`bool`): Whether to temporally upsample the latent

    Raises:
        ValueError: If both `spatial_upsample` and `temporal_upsample` are False.
    """

    def __init__(
        self,
        in_channels: int = 128,
        mid_channels: int = 512,
        num_blocks_per_stage: int = 4,
        dims: int = 3,
        spatial_upsample: bool = True,
        temporal_upsample: bool = False,
    ):
        super().__init__()

        # Kept as plain attributes so that `config()` can serialize them.
        self.in_channels = in_channels
        self.mid_channels = mid_channels
        self.num_blocks_per_stage = num_blocks_per_stage
        self.dims = dims
        self.spatial_upsample = spatial_upsample
        self.temporal_upsample = temporal_upsample

        Conv = nn.Conv2d if dims == 2 else nn.Conv3d

        self.initial_conv = Conv(in_channels, mid_channels, kernel_size=3, padding=1)
        self.initial_norm = nn.GroupNorm(32, mid_channels)
        self.initial_activation = nn.SiLU()

        self.res_blocks = nn.ModuleList(
            [ResBlock(mid_channels, dims=dims) for _ in range(num_blocks_per_stage)]
        )

        # The upsampler is chosen from the upsample flags only, independently of
        # `dims`. Each variant widens the channels by the upsampling factor
        # (8 / 4 / 2); PixelShuffleND is expected to fold that factor back into
        # the upsampled axes, since the following ResBlocks take `mid_channels`.
        # NOTE(review): with dims=2 and temporal_upsample=True, `forward` feeds a
        # 4D tensor into a Conv3d-based upsampler -- that combination looks
        # unsupported; confirm before relying on it.
        if spatial_upsample and temporal_upsample:
            self.upsampler = nn.Sequential(
                nn.Conv3d(mid_channels, 8 * mid_channels, kernel_size=3, padding=1),
                PixelShuffleND(3),
            )
        elif spatial_upsample:
            self.upsampler = nn.Sequential(
                nn.Conv2d(mid_channels, 4 * mid_channels, kernel_size=3, padding=1),
                PixelShuffleND(2),
            )
        elif temporal_upsample:
            self.upsampler = nn.Sequential(
                nn.Conv3d(mid_channels, 2 * mid_channels, kernel_size=3, padding=1),
                PixelShuffleND(1),
            )
        else:
            raise ValueError(
                "Either spatial_upsample or temporal_upsample must be True"
            )

        self.post_upsample_res_blocks = nn.ModuleList(
            [ResBlock(mid_channels, dims=dims) for _ in range(num_blocks_per_stage)]
        )

        self.final_conv = Conv(mid_channels, in_channels, kernel_size=3, padding=1)

    def forward(self, latent: torch.Tensor) -> torch.Tensor:
        """Upsample a latent.

        Args:
            latent: 5D tensor unpacked as ``(b, c, f, h, w)``.

        Returns:
            The upsampled latent as a 5D ``(b, c, f, h, w)`` tensor with
            ``in_channels`` channels.
        """
        b, _, f, _, _ = latent.shape

        if self.dims == 2:
            # 2D path: fold frames into the batch axis and process every frame
            # independently from start to finish.
            x = rearrange(latent, "b c f h w -> (b f) c h w")
            x = self.initial_conv(x)
            x = self.initial_norm(x)
            x = self.initial_activation(x)

            for block in self.res_blocks:
                x = block(x)

            x = self.upsampler(x)

            for block in self.post_upsample_res_blocks:
                x = block(x)

            x = self.final_conv(x)
            x = rearrange(x, "(b f) c h w -> b c f h w", b=b, f=f)
        else:
            # 3D path: convolutions see the frame axis as well.
            x = self.initial_conv(latent)
            x = self.initial_norm(x)
            x = self.initial_activation(x)

            for block in self.res_blocks:
                x = block(x)

            if self.temporal_upsample:
                x = self.upsampler(x)
                # Drop the first frame of the temporally upsampled result.
                # NOTE(review): presumably because the first latent frame maps
                # to a single frame and must not be duplicated -- confirm.
                x = x[:, :, 1:, :, :]
            else:
                # Spatial-only upsampling uses a Conv2d-based upsampler, so it
                # is applied per frame by folding frames into the batch axis.
                x = rearrange(x, "b c f h w -> (b f) c h w")
                x = self.upsampler(x)
                x = rearrange(x, "(b f) c h w -> b c f h w", b=b, f=f)

            for block in self.post_upsample_res_blocks:
                x = block(x)

            x = self.final_conv(x)

        return x

    @classmethod
    def from_config(cls, config):
        """Build a model from a config mapping.

        Keys missing from ``config`` fall back to the defaults listed below.
        Note that these differ from the ``__init__`` defaults
        (``in_channels=4``, ``mid_channels=128``, ``dims=2`` here versus
        ``128``, ``512``, ``3`` in ``__init__``).
        """
        return cls(
            in_channels=config.get("in_channels", 4),
            mid_channels=config.get("mid_channels", 128),
            num_blocks_per_stage=config.get("num_blocks_per_stage", 4),
            dims=config.get("dims", 2),
            spatial_upsample=config.get("spatial_upsample", True),
            temporal_upsample=config.get("temporal_upsample", False),
        )

    def config(self):
        """Return the constructor arguments as a JSON-serializable dict.

        NOTE(review): this is a regular method, so it is called as
        ``model.config()``; it appears to shadow the ``config`` attribute that
        diffusers' ``ConfigMixin`` provides -- confirm this is intended.
        """
        return {
            "_class_name": "LatentUpsampler",
            "in_channels": self.in_channels,
            "mid_channels": self.mid_channels,
            "num_blocks_per_stage": self.num_blocks_per_stage,
            "dims": self.dims,
            "spatial_upsample": self.spatial_upsample,
            "temporal_upsample": self.temporal_upsample,
        }

    @classmethod
    def from_pretrained(
        cls,
        pretrained_model_path: Optional[Union[str, os.PathLike]],
        *args,
        **kwargs,
    ):
        """Load a model from a single ``.safetensors`` checkpoint file.

        The model configuration is read from the ``"config"`` entry (a JSON
        string) of the safetensors metadata, and the weights from the tensors
        stored in the file.

        Args:
            pretrained_model_path: Path to an existing ``.safetensors`` file.
            *args: Accepted for signature compatibility; unused.
            **kwargs: Accepted for signature compatibility; unused.

        Returns:
            The model with the checkpoint weights loaded (on CPU).

        Raises:
            ValueError: If the path is not an existing ``.safetensors`` file, or
                if the file carries no ``"config"`` metadata entry.
        """
        pretrained_model_path = Path(pretrained_model_path)

        # Only single-file safetensors checkpoints are supported. Fail loudly
        # here instead of falling through to an unbound-variable error.
        if not (
            pretrained_model_path.is_file()
            and str(pretrained_model_path).endswith(".safetensors")
        ):
            raise ValueError(
                "pretrained_model_path must point to an existing .safetensors "
                f"file, got: {pretrained_model_path}"
            )

        with safe_open(pretrained_model_path, framework="pt", device="cpu") as f:
            metadata = f.metadata()
            state_dict = {k: f.get_tensor(k) for k in f.keys()}

        # `metadata()` may return None when the file has no metadata header.
        if not metadata or "config" not in metadata:
            raise ValueError(
                f"No 'config' entry in safetensors metadata of {pretrained_model_path}"
            )
        config = json.loads(metadata["config"])

        # Build the module on the meta device so no parameter memory is
        # allocated, then let `assign=True` swap in the loaded tensors.
        with torch.device("meta"):
            latent_upsampler = cls.from_config(config)
        latent_upsampler.load_state_dict(state_dict, assign=True)
        return latent_upsampler


if __name__ == "__main__":
    # Smoke test: build the default 3D model, report its size, run one forward pass.
    latent_upsampler = LatentUpsampler(num_blocks_per_stage=4, dims=3)
    print(latent_upsampler)
    total_params = sum(p.numel() for p in latent_upsampler.parameters())
    print(f"Total number of parameters: {total_params:,}")
    latent = torch.randn(1, 128, 9, 16, 16)
    upsampled_latent = latent_upsampler(latent)
    print(f"Upsampled latent shape: {upsampled_latent.shape}")