# Copyright (c) 2023 HuggingFace Team
# Copyright (c) 2025 ByteDance Ltd. and/or its affiliates.
# SPDX-License-Identifier: Apache License, Version 2.0 (the "License")
#
# This file has been modified by ByteDance Ltd. and/or its affiliates. on 1st June 2025
#
# Original file was released under Apache License, Version 2.0 (the "License"), with the full license text
# available at http://www.apache.org/licenses/LICENSE-2.0.
#
# This modified file is released under the same license.
from contextlib import nullcontext
from typing import Callable, Literal, Optional, Tuple, Union

import torch
import torch.nn as nn
import torch.nn.functional as F
from diffusers.models.autoencoders.vae import DiagonalGaussianDistribution
from einops import rearrange

from common.distributed.advanced import get_sequence_parallel_world_size
from common.logger import get_logger
from models.video_vae_v3.modules.causal_inflation_lib import (
    InflatedCausalConv3d,
    causal_norm_wrapper,
    init_causal_conv3d,
    remove_head,
)
from models.video_vae_v3.modules.context_parallel_lib import (
    causal_conv_gather_outputs,
    causal_conv_slice_inputs,
)
from models.video_vae_v3.modules.global_config import set_norm_limit
from models.video_vae_v3.modules.types import (
    CausalAutoencoderOutput,
    CausalDecoderOutput,
    CausalEncoderOutput,
    MemoryState,
    _inflation_mode_t,
    _memory_device_t,
    _receptive_field_t,
    _selective_checkpointing_t,
)

logger = get_logger(__name__)  # pylint: disable=invalid-name


# No-op stand-in: activation checkpointing is not needed for inference,
# so the wrapped module is simply invoked and `enabled` is ignored.
def gradient_checkpointing(module: Union[Callable, nn.Module], *args, enabled: bool, **kwargs):
    return module(*args, **kwargs)
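

# A hedged sketch, not part of the original file: if training-time activation
# checkpointing were needed, the no-op above could be swapped for a variant
# that defers to torch.utils.checkpoint (assuming non-reentrant checkpointing
# is acceptable for these modules).
def gradient_checkpointing_train(
    module: Union[Callable, nn.Module], *args, enabled: bool, **kwargs
):
    if enabled:
        from torch.utils.checkpoint import checkpoint

        # Recompute activations during backward to trade compute for memory.
        return checkpoint(module, *args, use_reentrant=False, **kwargs)
    return module(*args, **kwargs)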


class ResnetBlock2D(nn.Module):
    r"""
    A Resnet block.

    Parameters:
        in_channels (`int`): The number of channels in the input.
        out_channels (`int`, *optional*, defaults to `None`):
            The number of output channels for the first conv2d layer.
            If None, same as `in_channels`.
        dropout (`float`, *optional*, defaults to `0.0`): The dropout probability to use.
    """

    def __init__(
        self, *, in_channels: int, out_channels: Optional[int] = None, dropout: float = 0.0
    ):
        super().__init__()
        self.in_channels = in_channels
        out_channels = in_channels if out_channels is None else out_channels
        self.out_channels = out_channels
        self.nonlinearity = nn.SiLU()
        self.norm1 = torch.nn.GroupNorm(
            num_groups=32, num_channels=in_channels, eps=1e-6, affine=True
        )
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
        self.norm2 = torch.nn.GroupNorm(
            num_groups=32, num_channels=out_channels, eps=1e-6, affine=True
        )
        self.dropout = torch.nn.Dropout(dropout)
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1)
        self.use_in_shortcut = self.in_channels != out_channels
        self.conv_shortcut = None
        if self.use_in_shortcut:
            self.conv_shortcut = nn.Conv2d(
                in_channels, out_channels, kernel_size=1, stride=1, padding=0
            )

    def forward(self, input_tensor: torch.Tensor) -> torch.Tensor:
        hidden = input_tensor
        hidden = self.norm1(hidden)
        hidden = self.nonlinearity(hidden)
        hidden = self.conv1(hidden)
        hidden = self.norm2(hidden)
        hidden = self.nonlinearity(hidden)
        hidden = self.dropout(hidden)
        hidden = self.conv2(hidden)
        if self.conv_shortcut is not None:
            input_tensor = self.conv_shortcut(input_tensor)
        output_tensor = input_tensor + hidden
        return output_tensor
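

# Illustrative check (not in the original file): the block preserves spatial
# size, and the 1x1 shortcut reconciles channels when in != out.
def _resnet2d_shape_demo():
    block = ResnetBlock2D(in_channels=64, out_channels=128)
    x = torch.randn(1, 64, 32, 32)
    assert block(x).shape == (1, 128, 32, 32)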


class Upsample3D(nn.Module):
    """A 3D upsampling layer."""

    def __init__(
        self,
        channels: int,
        inflation_mode: _inflation_mode_t = "tail",
        temporal_up: bool = False,
        spatial_up: bool = True,
        slicing: bool = False,
    ):
        super().__init__()
        self.channels = channels
        self.conv = init_causal_conv3d(
            self.channels, self.channels, kernel_size=3, padding=1, inflation_mode=inflation_mode
        )
        self.temporal_up = temporal_up
        self.spatial_up = spatial_up
        self.temporal_ratio = 2 if temporal_up else 1
        self.spatial_ratio = 2 if spatial_up else 1
        self.slicing = slicing
        upscale_ratio = (self.spatial_ratio**2) * self.temporal_ratio
        self.upscale_conv = nn.Conv3d(
            self.channels, self.channels * upscale_ratio, kernel_size=1, padding=0
        )
        # Initialize as an exact nearest-neighbor upsampler: every output
        # sub-position starts out as a copy of the input features.
        identity = (
            torch.eye(self.channels).repeat(upscale_ratio, 1).reshape_as(self.upscale_conv.weight)
        )
        self.upscale_conv.weight.data.copy_(identity)
        nn.init.zeros_(self.upscale_conv.bias)
        self.gradient_checkpointing = False

    def forward(
        self,
        hidden_states: torch.FloatTensor,
        memory_state: MemoryState,
    ) -> torch.FloatTensor:
        return gradient_checkpointing(
            self.custom_forward,
            hidden_states,
            memory_state,
            enabled=self.training and self.gradient_checkpointing,
        )

    def custom_forward(
        self,
        hidden_states: torch.FloatTensor,
        memory_state: MemoryState,
    ) -> torch.FloatTensor:
        assert hidden_states.shape[1] == self.channels
        if self.slicing:
            # Split along time to cap peak memory; each half is upsampled
            # independently and re-concatenated after the causal conv.
            split_size = hidden_states.size(2) // 2
            hidden_states = list(
                hidden_states.split([split_size, hidden_states.size(2) - split_size], dim=2)
            )
        else:
            hidden_states = [hidden_states]
        for i in range(len(hidden_states)):
            hidden_states[i] = self.upscale_conv(hidden_states[i])
            hidden_states[i] = rearrange(
                hidden_states[i],
                "b (x y z c) f h w -> b c (f z) (h x) (w y)",
                x=self.spatial_ratio,
                y=self.spatial_ratio,
                z=self.temporal_ratio,
            )
        # [Overridden] For causal temporal conv
        if self.temporal_up and memory_state != MemoryState.ACTIVE:
            hidden_states[0] = remove_head(hidden_states[0])
        if self.slicing:
            hidden_states = self.conv(hidden_states, memory_state=memory_state)
            return torch.cat(hidden_states, dim=2)
        else:
            return self.conv(hidden_states[0], memory_state=memory_state)
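

# Sketch of the depth-to-(space, time) trick used above (illustrative, not in
# the original file): a 1x1x1 conv expands channels by x*y*z, then einops
# folds those factors back into the time/height/width axes, i.e. a 3D pixel
# shuffle.
def _upsample_rearrange_demo():
    b, c, f, h, w = 1, 4, 2, 3, 3
    x = torch.randn(b, c * 2 * 2 * 2, f, h, w)  # channels carry the 2x factors
    y = rearrange(x, "b (x y z c) f h w -> b c (f z) (h x) (w y)", x=2, y=2, z=2)
    assert y.shape == (b, c, f * 2, h * 2, w * 2)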


class Downsample3D(nn.Module):
    """A 3D downsampling layer."""

    def __init__(
        self,
        channels: int,
        inflation_mode: _inflation_mode_t = "tail",
        temporal_down: bool = False,
        spatial_down: bool = True,
    ):
        super().__init__()
        self.channels = channels
        self.temporal_down = temporal_down
        self.spatial_down = spatial_down
        self.temporal_ratio = 2 if temporal_down else 1
        self.spatial_ratio = 2 if spatial_down else 1
        self.temporal_kernel = 3 if temporal_down else 1
        self.spatial_kernel = 3 if spatial_down else 1
        self.conv = init_causal_conv3d(
            self.channels,
            self.channels,
            kernel_size=(self.temporal_kernel, self.spatial_kernel, self.spatial_kernel),
            stride=(self.temporal_ratio, self.spatial_ratio, self.spatial_ratio),
            padding=((1 if self.temporal_down else 0), 0, 0),
            inflation_mode=inflation_mode,
        )
        self.gradient_checkpointing = False

    def forward(
        self,
        hidden_states: torch.FloatTensor,
        memory_state: MemoryState,
    ) -> torch.FloatTensor:
        return gradient_checkpointing(
            self.custom_forward,
            hidden_states,
            memory_state,
            enabled=self.training and self.gradient_checkpointing,
        )

    def custom_forward(
        self,
        hidden_states: torch.FloatTensor,
        memory_state: MemoryState,
    ) -> torch.FloatTensor:
        assert hidden_states.shape[1] == self.channels
        if self.spatial_down:
            # Asymmetric right/bottom padding so the stride-2, kernel-3 conv
            # halves even spatial sizes exactly.
            hidden_states = F.pad(hidden_states, (0, 1, 0, 1), mode="constant", value=0)
        hidden_states = self.conv(hidden_states, memory_state=memory_state)
        return hidden_states
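

# Quick numeric check of the padding scheme (illustrative, not in the
# original file): floor((H + 1 - 3) / 2) + 1 == H // 2 for even H.
def _downsample_spatial_demo():
    conv = nn.Conv3d(1, 1, kernel_size=(1, 3, 3), stride=(1, 2, 2))
    x = torch.randn(1, 1, 1, 16, 16)
    x = F.pad(x, (0, 1, 0, 1))  # pad right and bottom only
    assert conv(x).shape[-2:] == (8, 8)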


class ResnetBlock3D(ResnetBlock2D):
    def __init__(
        self,
        *args,
        inflation_mode: _inflation_mode_t = "tail",
        time_receptive_field: _receptive_field_t = "half",
        **kwargs,
    ):
        super().__init__(*args, **kwargs)
        self.conv1 = init_causal_conv3d(
            self.in_channels,
            self.out_channels,
            kernel_size=3,
            stride=1,
            padding=1,
            inflation_mode=inflation_mode,
        )
        # With "half", conv2 stays purely spatial (1, 3, 3), halving the
        # temporal receptive field of the block.
        self.conv2 = init_causal_conv3d(
            self.out_channels,
            self.out_channels,
            kernel_size=(1, 3, 3) if time_receptive_field == "half" else (3, 3, 3),
            stride=1,
            padding=(0, 1, 1) if time_receptive_field == "half" else (1, 1, 1),
            inflation_mode=inflation_mode,
        )
        if self.use_in_shortcut:
            self.conv_shortcut = init_causal_conv3d(
                self.in_channels,
                self.out_channels,
                kernel_size=1,
                stride=1,
                padding=0,
                bias=(self.conv_shortcut.bias is not None),
                inflation_mode=inflation_mode,
            )
        self.gradient_checkpointing = False

    def forward(self, input_tensor: torch.Tensor, memory_state: MemoryState = MemoryState.UNSET):
        return gradient_checkpointing(
            self.custom_forward,
            input_tensor,
            memory_state,
            enabled=self.training and self.gradient_checkpointing,
        )

    def custom_forward(
        self, input_tensor: torch.Tensor, memory_state: MemoryState = MemoryState.UNSET
    ):
        assert memory_state != MemoryState.UNSET
        hidden_states = input_tensor
        hidden_states = causal_norm_wrapper(self.norm1, hidden_states)
        hidden_states = self.nonlinearity(hidden_states)
        hidden_states = self.conv1(hidden_states, memory_state=memory_state)
        hidden_states = causal_norm_wrapper(self.norm2, hidden_states)
        hidden_states = self.nonlinearity(hidden_states)
        hidden_states = self.dropout(hidden_states)
        hidden_states = self.conv2(hidden_states, memory_state=memory_state)
        if self.conv_shortcut is not None:
            input_tensor = self.conv_shortcut(input_tensor, memory_state=memory_state)
        output_tensor = input_tensor + hidden_states
        return output_tensor


class DownEncoderBlock3D(nn.Module):
    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        dropout: float = 0.0,
        num_layers: int = 1,
        add_downsample: bool = True,
        inflation_mode: _inflation_mode_t = "tail",
        time_receptive_field: _receptive_field_t = "half",
        temporal_down: bool = True,
        spatial_down: bool = True,
    ):
        super().__init__()
        resnets = []
        for i in range(num_layers):
            in_channels = in_channels if i == 0 else out_channels
            resnets.append(
                ResnetBlock3D(
                    in_channels=in_channels,
                    out_channels=out_channels,
                    dropout=dropout,
                    inflation_mode=inflation_mode,
                    time_receptive_field=time_receptive_field,
                )
            )
        self.resnets = nn.ModuleList(resnets)
        self.downsamplers = None
        if add_downsample:
            # Todo: Refactor this line before V5 Image VAE Training.
            self.downsamplers = nn.ModuleList(
                [
                    Downsample3D(
                        channels=out_channels,
                        inflation_mode=inflation_mode,
                        temporal_down=temporal_down,
                        spatial_down=spatial_down,
                    )
                ]
            )

    def forward(
        self, hidden_states: torch.FloatTensor, memory_state: MemoryState
    ) -> torch.FloatTensor:
        for resnet in self.resnets:
            hidden_states = resnet(hidden_states, memory_state=memory_state)
        if self.downsamplers is not None:
            for downsampler in self.downsamplers:
                hidden_states = downsampler(hidden_states, memory_state=memory_state)
        return hidden_states


class UpDecoderBlock3D(nn.Module):
    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        dropout: float = 0.0,
        num_layers: int = 1,
        add_upsample: bool = True,
        inflation_mode: _inflation_mode_t = "tail",
        time_receptive_field: _receptive_field_t = "half",
        temporal_up: bool = True,
        spatial_up: bool = True,
        slicing: bool = False,
    ):
        super().__init__()
        resnets = []
        for i in range(num_layers):
            input_channels = in_channels if i == 0 else out_channels
            resnets.append(
                ResnetBlock3D(
                    in_channels=input_channels,
                    out_channels=out_channels,
                    dropout=dropout,
                    inflation_mode=inflation_mode,
                    time_receptive_field=time_receptive_field,
                )
            )
        self.resnets = nn.ModuleList(resnets)
        self.upsamplers = None
        # Todo: Refactor this line before V5 Image VAE Training.
        if add_upsample:
            self.upsamplers = nn.ModuleList(
                [
                    Upsample3D(
                        channels=out_channels,
                        inflation_mode=inflation_mode,
                        temporal_up=temporal_up,
                        spatial_up=spatial_up,
                        slicing=slicing,
                    )
                ]
            )

    def forward(
        self, hidden_states: torch.FloatTensor, memory_state: MemoryState
    ) -> torch.FloatTensor:
        for resnet in self.resnets:
            hidden_states = resnet(hidden_states, memory_state=memory_state)
        if self.upsamplers is not None:
            for upsampler in self.upsamplers:
                hidden_states = upsampler(hidden_states, memory_state=memory_state)
        return hidden_states


class UNetMidBlock3D(nn.Module):
    def __init__(
        self,
        channels: int,
        dropout: float = 0.0,
        inflation_mode: _inflation_mode_t = "tail",
        time_receptive_field: _receptive_field_t = "half",
    ):
        super().__init__()
        self.resnets = nn.ModuleList(
            [
                ResnetBlock3D(
                    in_channels=channels,
                    out_channels=channels,
                    dropout=dropout,
                    inflation_mode=inflation_mode,
                    time_receptive_field=time_receptive_field,
                ),
                ResnetBlock3D(
                    in_channels=channels,
                    out_channels=channels,
                    dropout=dropout,
                    inflation_mode=inflation_mode,
                    time_receptive_field=time_receptive_field,
                ),
            ]
        )

    def forward(self, hidden_states: torch.Tensor, memory_state: MemoryState):
        for resnet in self.resnets:
            hidden_states = resnet(hidden_states, memory_state)
        return hidden_states


class Encoder3D(nn.Module):
    r"""
    The `Encoder` layer of a variational autoencoder that encodes
    its input into a latent representation.
    """

    def __init__(
        self,
        in_channels: int = 3,
        out_channels: int = 3,
        block_out_channels: Tuple[int, ...] = (64,),
        layers_per_block: int = 2,
        double_z: bool = True,
        temporal_down_num: int = 2,
        inflation_mode: _inflation_mode_t = "tail",
        time_receptive_field: _receptive_field_t = "half",
        selective_checkpointing: Tuple[_selective_checkpointing_t, ...] = ("none",),
    ):
        super().__init__()
        self.layers_per_block = layers_per_block
        self.temporal_down_num = temporal_down_num
        self.conv_in = init_causal_conv3d(
            in_channels,
            block_out_channels[0],
            kernel_size=3,
            stride=1,
            padding=1,
            inflation_mode=inflation_mode,
        )
        self.down_blocks = nn.ModuleList([])

        # down
        output_channel = block_out_channels[0]
        for i in range(len(block_out_channels)):
            input_channel = output_channel
            output_channel = block_out_channels[i]
            is_final_block = i == len(block_out_channels) - 1
            # Note: temporal downsampling is confined to the last blocks;
            # the final block adds no downsampler at all.
            is_temporal_down_block = i >= len(block_out_channels) - self.temporal_down_num - 1
            down_block = DownEncoderBlock3D(
                num_layers=self.layers_per_block,
                in_channels=input_channel,
                out_channels=output_channel,
                add_downsample=not is_final_block,
                temporal_down=is_temporal_down_block,
                spatial_down=True,
                inflation_mode=inflation_mode,
                time_receptive_field=time_receptive_field,
            )
            self.down_blocks.append(down_block)

        # mid
        self.mid_block = UNetMidBlock3D(
            channels=block_out_channels[-1],
            inflation_mode=inflation_mode,
            time_receptive_field=time_receptive_field,
        )

        # out
        self.conv_norm_out = nn.GroupNorm(
            num_channels=block_out_channels[-1], num_groups=32, eps=1e-6
        )
        self.conv_act = nn.SiLU()
        conv_out_channels = 2 * out_channels if double_z else out_channels
        self.conv_out = init_causal_conv3d(
            block_out_channels[-1], conv_out_channels, 3, padding=1, inflation_mode=inflation_mode
        )
        assert len(selective_checkpointing) == len(self.down_blocks)
        self.set_gradient_checkpointing(selective_checkpointing)

    def set_gradient_checkpointing(self, checkpointing_types):
        gradient_checkpointing = []
        for down_block, sac_type in zip(self.down_blocks, checkpointing_types):
            if sac_type == "coarse":
                gradient_checkpointing.append(True)
            elif sac_type == "fine":
                for n, m in down_block.named_modules():
                    if hasattr(m, "gradient_checkpointing"):
                        m.gradient_checkpointing = True
                        logger.debug(f"set gradient_checkpointing: {n}")
                gradient_checkpointing.append(False)
            else:
                gradient_checkpointing.append(False)
        self.gradient_checkpointing = gradient_checkpointing
        logger.info(f"[Encoder3D] gradient_checkpointing: {checkpointing_types}")

    def forward(self, sample: torch.FloatTensor, memory_state: MemoryState) -> torch.FloatTensor:
        r"""The forward method of the `Encoder` class."""
        sample = self.conv_in(sample, memory_state=memory_state)

        # down
        for down_block, sac in zip(self.down_blocks, self.gradient_checkpointing):
            sample = gradient_checkpointing(
                down_block,
                sample,
                memory_state=memory_state,
                enabled=self.training and sac,
            )

        # middle
        sample = self.mid_block(sample, memory_state=memory_state)

        # post-process
        sample = causal_norm_wrapper(self.conv_norm_out, sample)
        sample = self.conv_act(sample)
        sample = self.conv_out(sample, memory_state=memory_state)
        return sample
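

# Hedged usage sketch (not in the original file; assumes the repo's causal
# conv modules are importable). With 4 blocks, the first 3 downsample
# spatially (8x) and, here, the middle 2 also downsample temporally (4x);
# double_z doubles the output channels for the Gaussian parameters.
def _encoder_usage_sketch():
    enc = Encoder3D(
        in_channels=3,
        out_channels=8,
        block_out_channels=(64, 128, 256, 256),
        layers_per_block=2,
        temporal_down_num=2,
        selective_checkpointing=("none",) * 4,
    )
    x = torch.randn(1, 3, 9, 64, 64)  # causal clips use T = 1 + n * 4
    h = enc(x, memory_state=MemoryState.DISABLED)  # expect [1, 16, 3, 8, 8]
    return h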


class Decoder3D(nn.Module):
    r"""
    The `Decoder` layer of a variational autoencoder that
    decodes its latent representation into an output sample.
    """

    def __init__(
        self,
        in_channels: int = 3,
        out_channels: int = 3,
        block_out_channels: Tuple[int, ...] = (64,),
        layers_per_block: int = 2,
        inflation_mode: _inflation_mode_t = "tail",
        time_receptive_field: _receptive_field_t = "half",
        temporal_up_num: int = 2,
        slicing_up_num: int = 0,
        selective_checkpointing: Tuple[_selective_checkpointing_t, ...] = ("none",),
    ):
        super().__init__()
        self.layers_per_block = layers_per_block
        self.temporal_up_num = temporal_up_num
        self.conv_in = init_causal_conv3d(
            in_channels,
            block_out_channels[-1],
            kernel_size=3,
            stride=1,
            padding=1,
            inflation_mode=inflation_mode,
        )
        self.up_blocks = nn.ModuleList([])

        # mid
        self.mid_block = UNetMidBlock3D(
            channels=block_out_channels[-1],
            inflation_mode=inflation_mode,
            time_receptive_field=time_receptive_field,
        )

        # up
        reversed_block_out_channels = list(reversed(block_out_channels))
        output_channel = reversed_block_out_channels[0]
        for i in range(len(reversed_block_out_channels)):
            prev_output_channel = output_channel
            output_channel = reversed_block_out_channels[i]
            is_final_block = i == len(block_out_channels) - 1
            # Note: keep the decoder temporally symmetric with the encoder;
            # the first `temporal_up_num` blocks undo the temporal downsampling.
            is_temporal_up_block = i < self.temporal_up_num
            is_slicing_up_block = i >= len(block_out_channels) - slicing_up_num
            up_block = UpDecoderBlock3D(
                num_layers=self.layers_per_block + 1,
                in_channels=prev_output_channel,
                out_channels=output_channel,
                add_upsample=not is_final_block,
                temporal_up=is_temporal_up_block,
                slicing=is_slicing_up_block,
                inflation_mode=inflation_mode,
                time_receptive_field=time_receptive_field,
            )
            self.up_blocks.append(up_block)

        # out
        self.conv_norm_out = nn.GroupNorm(
            num_channels=block_out_channels[0], num_groups=32, eps=1e-6
        )
        self.conv_act = nn.SiLU()
        self.conv_out = init_causal_conv3d(
            block_out_channels[0], out_channels, 3, padding=1, inflation_mode=inflation_mode
        )
        assert len(selective_checkpointing) == len(self.up_blocks)
        self.set_gradient_checkpointing(selective_checkpointing)

    def set_gradient_checkpointing(self, checkpointing_types):
        gradient_checkpointing = []
        for up_block, sac_type in zip(self.up_blocks, checkpointing_types):
            if sac_type == "coarse":
                gradient_checkpointing.append(True)
            elif sac_type == "fine":
                for n, m in up_block.named_modules():
                    if hasattr(m, "gradient_checkpointing"):
                        m.gradient_checkpointing = True
                        logger.debug(f"set gradient_checkpointing: {n}")
                gradient_checkpointing.append(False)
            else:
                gradient_checkpointing.append(False)
        self.gradient_checkpointing = gradient_checkpointing
        logger.info(f"[Decoder3D] gradient_checkpointing: {checkpointing_types}")

    def forward(self, sample: torch.FloatTensor, memory_state: MemoryState) -> torch.FloatTensor:
        r"""The forward method of the `Decoder` class."""
        sample = self.conv_in(sample, memory_state=memory_state)

        # middle
        sample = self.mid_block(sample, memory_state=memory_state)

        # up
        for up_block, sac in zip(self.up_blocks, self.gradient_checkpointing):
            sample = gradient_checkpointing(
                up_block,
                sample,
                memory_state=memory_state,
                enabled=self.training and sac,
            )

        # post-process
        sample = causal_norm_wrapper(self.conv_norm_out, sample)
        sample = self.conv_act(sample)
        sample = self.conv_out(sample, memory_state=memory_state)
        return sample


class VideoAutoencoderKL(nn.Module):
    def __init__(
        self,
        in_channels: int = 3,
        out_channels: int = 3,
        block_out_channels: Tuple[int, ...] = (64,),
        layers_per_block: int = 1,
        latent_channels: int = 4,
        use_quant_conv: bool = True,
        use_post_quant_conv: bool = True,
        enc_selective_checkpointing: Tuple[_selective_checkpointing_t, ...] = ("none",),
        dec_selective_checkpointing: Tuple[_selective_checkpointing_t, ...] = ("none",),
        temporal_scale_num: int = 3,
        slicing_up_num: int = 0,
        inflation_mode: _inflation_mode_t = "tail",
        time_receptive_field: _receptive_field_t = "half",
        slicing_sample_min_size: Optional[int] = None,
        spatial_downsample_factor: int = 16,
        temporal_downsample_factor: int = 8,
        freeze_encoder: bool = False,
    ):
        super().__init__()
        self.spatial_downsample_factor = spatial_downsample_factor
        self.temporal_downsample_factor = temporal_downsample_factor
        self.freeze_encoder = freeze_encoder
        if slicing_sample_min_size is None:
            slicing_sample_min_size = temporal_downsample_factor
        self.slicing_sample_min_size = slicing_sample_min_size
        self.slicing_latent_min_size = slicing_sample_min_size // (2**temporal_scale_num)

        # pass init params to Encoder
        self.encoder = Encoder3D(
            in_channels=in_channels,
            out_channels=latent_channels,
            block_out_channels=block_out_channels,
            layers_per_block=layers_per_block,
            double_z=True,
            temporal_down_num=temporal_scale_num,
            selective_checkpointing=enc_selective_checkpointing,
            inflation_mode=inflation_mode,
            time_receptive_field=time_receptive_field,
        )

        # pass init params to Decoder
        self.decoder = Decoder3D(
            in_channels=latent_channels,
            out_channels=out_channels,
            block_out_channels=block_out_channels,
            layers_per_block=layers_per_block,
            # [Override] add temporal_up_num parameter
            temporal_up_num=temporal_scale_num,
            slicing_up_num=slicing_up_num,
            selective_checkpointing=dec_selective_checkpointing,
            inflation_mode=inflation_mode,
            time_receptive_field=time_receptive_field,
        )

        self.quant_conv = (
            init_causal_conv3d(
                in_channels=2 * latent_channels,
                out_channels=2 * latent_channels,
                kernel_size=1,
                inflation_mode=inflation_mode,
            )
            if use_quant_conv
            else None
        )
        self.post_quant_conv = (
            init_causal_conv3d(
                in_channels=latent_channels,
                out_channels=latent_channels,
                kernel_size=1,
                inflation_mode=inflation_mode,
            )
            if use_post_quant_conv
            else None
        )
        self.use_slicing = False

    def enable_slicing(self):
        self.use_slicing = True

    def disable_slicing(self):
        self.use_slicing = False

    def encode(self, x: torch.FloatTensor) -> CausalEncoderOutput:
        if x.ndim == 4:
            x = x.unsqueeze(2)
        h = self.slicing_encode(x)
        p = DiagonalGaussianDistribution(h)
        z = p.sample()
        return CausalEncoderOutput(z, p)

    def decode(self, z: torch.FloatTensor) -> CausalDecoderOutput:
        if z.ndim == 4:
            z = z.unsqueeze(2)
        x = self.slicing_decode(z)
        return CausalDecoderOutput(x)

    def _encode(self, x: torch.Tensor, memory_state: MemoryState) -> torch.Tensor:
        x = causal_conv_slice_inputs(x, self.slicing_sample_min_size, memory_state=memory_state)
        h = self.encoder(x, memory_state=memory_state)
        h = self.quant_conv(h, memory_state=memory_state) if self.quant_conv is not None else h
        h = causal_conv_gather_outputs(h)
        return h

    def _decode(self, z: torch.Tensor, memory_state: MemoryState) -> torch.Tensor:
        z = causal_conv_slice_inputs(z, self.slicing_latent_min_size, memory_state=memory_state)
        z = (
            self.post_quant_conv(z, memory_state=memory_state)
            if self.post_quant_conv is not None
            else z
        )
        x = self.decoder(z, memory_state=memory_state)
        x = causal_conv_gather_outputs(x)
        return x

    def slicing_encode(self, x: torch.Tensor) -> torch.Tensor:
        sp_size = get_sequence_parallel_world_size()
        if self.use_slicing and (x.shape[2] - 1) > self.slicing_sample_min_size * sp_size:
            # The leading frame rides with the first chunk: it initializes the
            # causal memory that the remaining chunks then reuse.
            x_slices = x[:, :, 1:].split(split_size=self.slicing_sample_min_size * sp_size, dim=2)
            encoded_slices = [
                self._encode(
                    torch.cat((x[:, :, :1], x_slices[0]), dim=2),
                    memory_state=MemoryState.INITIALIZING,
                )
            ]
            for x_idx in range(1, len(x_slices)):
                encoded_slices.append(
                    self._encode(x_slices[x_idx], memory_state=MemoryState.ACTIVE)
                )
            return torch.cat(encoded_slices, dim=2)
        else:
            return self._encode(x, memory_state=MemoryState.DISABLED)

    def slicing_decode(self, z: torch.Tensor) -> torch.Tensor:
        sp_size = get_sequence_parallel_world_size()
        if self.use_slicing and (z.shape[2] - 1) > self.slicing_latent_min_size * sp_size:
            z_slices = z[:, :, 1:].split(split_size=self.slicing_latent_min_size * sp_size, dim=2)
            decoded_slices = [
                self._decode(
                    torch.cat((z[:, :, :1], z_slices[0]), dim=2),
                    memory_state=MemoryState.INITIALIZING,
                )
            ]
            for z_idx in range(1, len(z_slices)):
                decoded_slices.append(
                    self._decode(z_slices[z_idx], memory_state=MemoryState.ACTIVE)
                )
            return torch.cat(decoded_slices, dim=2)
        else:
            return self._decode(z, memory_state=MemoryState.DISABLED)
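
    # Illustrative sketch (not in the original file) of the chunking performed
    # by slicing_encode/slicing_decode above: the leading frame rides with the
    # first chunk (INITIALIZING builds the causal cache), and later chunks
    # reuse it (ACTIVE).
    @staticmethod
    def _slicing_plan_demo():
        T, min_size = 9, 2  # 1 leading frame + 8 trailing frames, sp_size = 1
        trailing = list(range(1, T))  # frames 1..8; frame 0 joins chunk 0
        chunks = [trailing[i : i + min_size] for i in range(0, len(trailing), min_size)]
        assert chunks == [[1, 2], [3, 4], [5, 6], [7, 8]]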
    def forward(self, x: torch.FloatTensor) -> CausalAutoencoderOutput:
        with torch.no_grad() if self.freeze_encoder else nullcontext():
            z, p = self.encode(x)
        x = self.decode(z).sample
        return CausalAutoencoderOutput(x, z, p)

    def preprocess(self, x: torch.Tensor):
        # x should be in [B, C, T, H, W] or [B, C, H, W] format.
        assert x.ndim == 4 or x.size(2) % self.temporal_downsample_factor == 1
        return x

    def postprocess(self, x: torch.Tensor):
        # x should be in [B, C, T, H, W] or [B, C, H, W] format.
        return x

    def set_causal_slicing(
        self,
        *,
        split_size: Optional[int],
        memory_device: _memory_device_t,
    ):
        assert (
            split_size is None or memory_device is not None
        ), "if split_size is set, memory_device must not be None."
        if split_size is not None:
            self.enable_slicing()
            self.slicing_sample_min_size = split_size
            self.slicing_latent_min_size = split_size // self.temporal_downsample_factor
        else:
            self.disable_slicing()
        for module in self.modules():
            if isinstance(module, InflatedCausalConv3d):
                module.set_memory_device(memory_device)

    def set_memory_limit(self, conv_max_mem: Optional[float], norm_max_mem: Optional[float]):
        set_norm_limit(norm_max_mem)
        for m in self.modules():
            if isinstance(m, InflatedCausalConv3d):
                m.set_memory_limit(conv_max_mem if conv_max_mem is not None else float("inf"))
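

# Hedged end-to-end sketch (not in the original file): the factor math follows
# __init__ above, the selective-checkpointing tuples must match the block
# count, and a single-process setup is assumed so that
# get_sequence_parallel_world_size() == 1.
def _vae_usage_sketch():
    vae = VideoAutoencoderKL(
        block_out_channels=(64, 128, 256, 512, 512),  # 4 spatial downs = 16x
        latent_channels=8,
        temporal_scale_num=3,  # 3 temporal downs = 8x
        enc_selective_checkpointing=("none",) * 5,
        dec_selective_checkpointing=("none",) * 5,
    )
    vae.set_causal_slicing(split_size=8, memory_device="cpu")
    x = torch.randn(1, 3, 17, 64, 64)  # T % 8 == 1, per preprocess()
    out = vae(x)  # CausalAutoencoderOutput(x, z, p): reconstruction, latent, posterior
    return out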


class VideoAutoencoderKLWrapper(VideoAutoencoderKL):
    def __init__(
        self, *args, spatial_downsample_factor: int, temporal_downsample_factor: int, **kwargs
    ):
        self.spatial_downsample_factor = spatial_downsample_factor
        self.temporal_downsample_factor = temporal_downsample_factor
        super().__init__(*args, **kwargs)

    def forward(self, x) -> CausalAutoencoderOutput:
        z, _, p = self.encode(x)
        x, _ = self.decode(z)
        return CausalAutoencoderOutput(x, z, None, p)

    def encode(self, x) -> CausalEncoderOutput:
        if x.ndim == 4:
            x = x.unsqueeze(2)
        p = super().encode(x).latent_dist
        z = p.sample().squeeze(2)
        return CausalEncoderOutput(z, None, p)

    def decode(self, z) -> CausalDecoderOutput:
        if z.ndim == 4:
            z = z.unsqueeze(2)
        x = super().decode(z).sample.squeeze(2)
        return CausalDecoderOutput(x, None)

    def preprocess(self, x):
        # x should be in [B, C, T, H, W] or [B, C, H, W] format.
        assert x.ndim == 4 or x.size(2) % 4 == 1
        return x

    def postprocess(self, x):
        # x should be in [B, C, T, H, W] or [B, C, H, W] format.
        return x

    def set_causal_slicing(
        self,
        *,
        split_size: Optional[int],
        memory_device: Optional[Literal["cpu", "same"]],
    ):
        assert (
            split_size is None or memory_device is not None
        ), "if split_size is set, memory_device must not be None."
        if split_size is not None:
            self.enable_slicing()
            self.slicing_sample_min_size = split_size
            self.slicing_latent_min_size = split_size // self.temporal_downsample_factor
        else:
            self.disable_slicing()
            self.slicing_sample_min_size = None
        for module in self.modules():
            if isinstance(module, InflatedCausalConv3d):
                module.set_memory_device(memory_device)