Spaces:
Running
on
Zero
Running
on
Zero
# // Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
# //
# // Licensed under the Apache License, Version 2.0 (the "License");
# // you may not use this file except in compliance with the License.
# // You may obtain a copy of the License at
# //
# //     http://www.apache.org/licenses/LICENSE-2.0
# //
# // Unless required by applicable law or agreed to in writing, software
# // distributed under the License is distributed on an "AS IS" BASIS,
# // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# // See the License for the specific language governing permissions and
# // limitations under the License.
from functools import partial | |
from typing import Literal, Optional | |
from torch import Tensor | |
from torch.nn import Conv3d | |
from models.video_vae_v3.modules.inflated_lib import ( | |
MemoryState, | |
extend_head, | |
inflate_bias, | |
inflate_weight, | |
modify_state_dict, | |
) | |
# Accepted values for the `inflation_mode` keyword of the causal conv layer and
# its factory function defined in this file.
_inflation_mode_t = Literal["none", "tail", "replicate"]
# Accepted values for `memory_device`, as interpreted by
# `InflatedCausalConv3d.forward`: None means the cache is never updated, "cpu"
# moves the cached frames to the CPU, and "same" stores them without moving.
_memory_device_t = Optional[Literal["cpu", "same"]]
class InflatedCausalConv3d(Conv3d):
    """Conv3d whose temporal padding is applied only at the head of the clip.

    The first entry of ``padding`` is remembered as ``temporal_padding`` and
    then zeroed on the module, so the underlying convolution never pads the
    temporal axis itself. ``forward`` instead prepends frames through
    ``extend_head`` (imported from ``inflated_lib``; its exact behaviour is
    not visible in this file).

    Between calls the layer can keep the trailing frames of the extended
    input in ``self.memory`` and prepend them to the next chunk, which is
    controlled by the ``memory_state`` argument of ``forward``.

    Parameters:
        *args, **kwargs: Forwarded unchanged to ``torch.nn.Conv3d``.
        inflation_mode: "none" keeps the default state-dict loading; any
            other value routes the state dict through ``modify_state_dict``
            before loading (see ``_load_from_state_dict``).
        memory_device: Where cached frames are kept. ``None`` disables
            cache updates, "cpu" moves the cache to the CPU, "same" stores
            the cache without moving it.
    """

    def __init__(
        self,
        *args,
        inflation_mode: _inflation_mode_t,
        memory_device: _memory_device_t = "same",
        **kwargs,
    ):
        # These two plain attributes are assigned ahead of the parent
        # constructor, matching the original initialisation order.
        self.memory = None
        self.inflation_mode = inflation_mode
        super().__init__(*args, **kwargs)

        self.memory_device = memory_device

        # Keep the requested temporal pad for forward(), then strip it from
        # the module so the convolution stays causal.
        # NOTE(review): only `self.padding` is overridden here; confirm the
        # behaviour when a non-"zeros" `padding_mode` is requested.
        full_padding = self.padding
        self.temporal_padding = full_padding[0]
        self.padding = (0, *full_padding[1:])

    def set_memory_device(self, memory_device: _memory_device_t):
        """Change where frames cached by ``forward`` are stored."""
        self.memory_device = memory_device

    def forward(self, input: Tensor, memory_state: MemoryState = MemoryState.DISABLED) -> Tensor:
        """Extend the head of ``input`` along dim 2, then convolve.

        Parameters:
            input: Tensor passed to the convolution; it is sliced along
                dim 2 when the cache is built.
            memory_state: ``MemoryState.ACTIVE`` prepends the cached frames
                when a cache exists; ``MemoryState.DISABLED`` neither reads
                nor writes the cache; any other state skips reading but
                still allows the cache to be written.

        Returns:
            The result of ``Conv3d.forward`` on the head-extended input.
        """
        reuse_cache = self.memory is not None and memory_state == MemoryState.ACTIVE
        if reuse_cache:
            input = extend_head(input, memory=self.memory)
        else:
            input = extend_head(input, times=self.temporal_padding * 2)

        caching_enabled = memory_state != MemoryState.DISABLED

        # stride - kernel_size is used directly as the slice start along
        # dim 2; a value of zero means there is nothing to cache.
        tail_start = self.stride[0] - self.kernel_size[0]
        tail = None
        if tail_start != 0 and caching_enabled:
            tail = input[:, :, tail_start:].detach()

        # The cache is refreshed only outside training mode and only when a
        # memory device has been configured.
        if caching_enabled and not self.training and self.memory_device is not None:
            if tail is not None and self.memory_device == "cpu":
                tail = tail.to("cpu")
            self.memory = tail

        return super().forward(input)

    def _load_from_state_dict(
        self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs
    ):
        """Load parameters, adapting the state dict first when inflation is on.

        For every ``inflation_mode`` other than "none" the state dict is
        passed through ``modify_state_dict`` with the "tail" variants of
        ``inflate_weight`` / ``inflate_bias``, and strict checking is turned
        off for the parent loader.
        """
        inflating = self.inflation_mode != "none"
        if inflating:
            weight_fn = partial(inflate_weight, position="tail")
            bias_fn = partial(inflate_bias, position="tail")
            state_dict = modify_state_dict(
                self,
                state_dict,
                prefix,
                inflate_weight_fn=weight_fn,
                inflate_bias_fn=bias_fn,
            )

        # Strict loading is honoured only when no inflation takes place.
        effective_strict = strict and not inflating
        super()._load_from_state_dict(
            state_dict,
            prefix,
            local_metadata,
            effective_strict,
            missing_keys,
            unexpected_keys,
            error_msgs,
        )
def init_causal_conv3d(
    *args,
    inflation_mode: _inflation_mode_t,
    **kwargs,
):
    """
    Build an `InflatedCausalConv3d` layer.

    Parameters:
        *args, **kwargs: Passed through unchanged to `InflatedCausalConv3d`.
        inflation_mode: Keyword-only; one of the following.
            - none: No inflation is performed and state-dict loading
              falls back to the default behaviour.
            - tail / replicate: Handled by `InflatedCausalConv3d` when it
              loads a state dict.

    Returns:
        The newly constructed `InflatedCausalConv3d` instance.
    """
    causal_conv = InflatedCausalConv3d(*args, inflation_mode=inflation_mode, **kwargs)
    return causal_conv