# // Copyright (c) 2025 Bytedance Ltd. and/or its affiliates # // # // Licensed under the Apache License, Version 2.0 (the "License"); # // you may not use this file except in compliance with the License. # // You may obtain a copy of the License at # // # // http://www.apache.org/licenses/LICENSE-2.0 # // # // Unless required by applicable law or agreed to in writing, software # // distributed under the License is distributed on an "AS IS" BASIS, # // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # // See the License for the specific language governing permissions and # // limitations under the License. from functools import partial from typing import Literal, Optional from torch import Tensor from torch.nn import Conv3d from models.video_vae_v3.modules.inflated_lib import ( MemoryState, extend_head, inflate_bias, inflate_weight, modify_state_dict, ) _inflation_mode_t = Literal["none", "tail", "replicate"] _memory_device_t = Optional[Literal["cpu", "same"]] class InflatedCausalConv3d(Conv3d): def __init__( self, *args, inflation_mode: _inflation_mode_t, memory_device: _memory_device_t = "same", **kwargs, ): self.inflation_mode = inflation_mode self.memory = None super().__init__(*args, **kwargs) self.temporal_padding = self.padding[0] self.memory_device = memory_device self.padding = (0, *self.padding[1:]) # Remove temporal pad to keep causal. def set_memory_device(self, memory_device: _memory_device_t): self.memory_device = memory_device def forward(self, input: Tensor, memory_state: MemoryState = MemoryState.DISABLED) -> Tensor: mem_size = self.stride[0] - self.kernel_size[0] if (self.memory is not None) and (memory_state == MemoryState.ACTIVE): input = extend_head(input, memory=self.memory) else: input = extend_head(input, times=self.temporal_padding * 2) memory = ( input[:, :, mem_size:].detach() if (mem_size != 0 and memory_state != MemoryState.DISABLED) else None ) if ( memory_state != MemoryState.DISABLED and not self.training and (self.memory_device is not None) ): self.memory = memory if self.memory_device == "cpu" and self.memory is not None: self.memory = self.memory.to("cpu") return super().forward(input) def _load_from_state_dict( self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs ): if self.inflation_mode != "none": state_dict = modify_state_dict( self, state_dict, prefix, inflate_weight_fn=partial(inflate_weight, position="tail"), inflate_bias_fn=partial(inflate_bias, position="tail"), ) super()._load_from_state_dict( state_dict, prefix, local_metadata, (strict and self.inflation_mode == "none"), missing_keys, unexpected_keys, error_msgs, ) def init_causal_conv3d( *args, inflation_mode: _inflation_mode_t, **kwargs, ): """ Initialize a Causal-3D convolution layer. Parameters: inflation_mode: Listed as below. It's compatible with all the 3D-VAE checkpoints we have. - none: No inflation will be conducted. The loading logic of state dict will fall back to default. - tail / replicate: Refer to the definition of `InflatedCausalConv3d`. """ return InflatedCausalConv3d(*args, inflation_mode=inflation_mode, **kwargs)