Spaces:
Running
on
Zero
Running
on
Zero
File size: 3,807 Bytes
42f2c22 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 |
# // Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
# //
# // Licensed under the Apache License, Version 2.0 (the "License");
# // you may not use this file except in compliance with the License.
# // You may obtain a copy of the License at
# //
# // http://www.apache.org/licenses/LICENSE-2.0
# //
# // Unless required by applicable law or agreed to in writing, software
# // distributed under the License is distributed on an "AS IS" BASIS,
# // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# // See the License for the specific language governing permissions and
# // limitations under the License.
from functools import partial
from typing import Literal, Optional
from torch import Tensor
from torch.nn import Conv3d
from models.video_vae_v3.modules.inflated_lib import (
MemoryState,
extend_head,
inflate_bias,
inflate_weight,
modify_state_dict,
)
# Accepted values for the `inflation_mode` argument. "none" leaves checkpoint
# loading untouched; any other value makes `_load_from_state_dict` pass the
# state dict through `modify_state_dict` and turns off strict key checking.
_inflation_mode_t = Literal["none", "tail", "replicate"]
# Accepted values for the `memory_device` argument. None means `forward` never
# stores the temporal memory, "cpu" moves the stored tensor to the CPU, and
# "same" stores it without moving it.
_memory_device_t = Optional[Literal["cpu", "same"]]
class InflatedCausalConv3d(Conv3d):
    """
    `Conv3d` variant that pads only at the head of the first kernel axis.

    The padding configured for the first kernel axis is removed from the
    underlying convolution and remembered in `temporal_padding`. `forward`
    then prepends data to the input through `extend_head` instead, either
    from the tail cached by a previous call or with `times=temporal_padding * 2`.

    Parameters:
        inflation_mode: Controls checkpoint loading, see `_load_from_state_dict`.
        memory_device: Where the cached tail is kept ("same" or "cpu"),
            or None to never cache.
        *args / **kwargs: Forwarded unchanged to `Conv3d`.
    """

    def __init__(
        self,
        *args,
        inflation_mode: _inflation_mode_t,
        memory_device: _memory_device_t = "same",
        **kwargs,
    ):
        # These two are assigned ahead of the parent constructor, as in the
        # original ordering of this class.
        self.memory = None
        self.inflation_mode = inflation_mode
        super().__init__(*args, **kwargs)
        self.memory_device = memory_device
        head_pad, *rest_pad = self.padding
        self.temporal_padding = head_pad
        # Remove temporal pad to keep causal; forward() extends the head instead.
        self.padding = (0, *rest_pad)

    def set_memory_device(self, memory_device: _memory_device_t):
        """
        Replace the `memory_device` policy used by subsequent `forward` calls.
        """
        self.memory_device = memory_device

    def forward(self, input: Tensor, memory_state: MemoryState = MemoryState.DISABLED) -> Tensor:
        """
        Extend the head of `input`, optionally cache its tail, then convolve.

        Parameters:
            input: Tensor to convolve. It is sliced on dim 2, which is
                presumably the temporal axis of a (N, C, T, H, W) batch -
                TODO confirm against callers.
            memory_state: `MemoryState.DISABLED` skips caching entirely.
                `MemoryState.ACTIVE` additionally reuses a previously cached
                tail when one exists.
        Returns:
            The result of the parent `Conv3d.forward` on the extended input.
        """
        caching_requested = memory_state != MemoryState.DISABLED
        # stride - kernel_size on the first axis; used as the start index of
        # the slice on dim 2 that is kept as memory.
        tail_start = self.stride[0] - self.kernel_size[0]

        reuse_cached = (self.memory is not None) and (memory_state == MemoryState.ACTIVE)
        if reuse_cached:
            input = extend_head(input, memory=self.memory)
        else:
            input = extend_head(input, times=self.temporal_padding * 2)

        if tail_start != 0 and caching_requested:
            tail = input[:, :, tail_start:].detach()
        else:
            tail = None

        # The cache is only written outside training and when a device policy is set.
        if caching_requested and (not self.training) and (self.memory_device is not None):
            if tail is not None and self.memory_device == "cpu":
                tail = tail.to("cpu")
            self.memory = tail

        return super().forward(input)

    def _load_from_state_dict(
        self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs
    ):
        """
        Load parameters, rewriting the state dict first unless inflation is "none".

        For any `inflation_mode` other than "none", the state dict is passed
        through `modify_state_dict` with tail-positioned weight and bias
        inflation functions, and strict checking is turned off for the
        parent loader. With "none" the call is forwarded unchanged.
        """
        needs_inflation = self.inflation_mode != "none"
        if needs_inflation:
            tail_weight_fn = partial(inflate_weight, position="tail")
            tail_bias_fn = partial(inflate_bias, position="tail")
            state_dict = modify_state_dict(
                self,
                state_dict,
                prefix,
                inflate_weight_fn=tail_weight_fn,
                inflate_bias_fn=tail_bias_fn,
            )
        super()._load_from_state_dict(
            state_dict,
            prefix,
            local_metadata,
            (strict and not needs_inflation),
            missing_keys,
            unexpected_keys,
            error_msgs,
        )
def init_causal_conv3d(
    *args,
    inflation_mode: _inflation_mode_t,
    **kwargs,
):
    """
    Build an `InflatedCausalConv3d` layer.

    Parameters:
        inflation_mode: How checkpoints are adapted when loaded into the layer.
            It's compatible with all the 3D-VAE checkpoints we have.
            - none: No inflation will be conducted.
                The loading logic of state dict will fall back to default.
            - tail / replicate: Refer to the definition of `InflatedCausalConv3d`.
        *args / **kwargs: Passed through to `InflatedCausalConv3d` unchanged.
    Returns:
        The newly constructed `InflatedCausalConv3d` instance.
    """
    # inflation_mode is keyword-only, so it can never already be in kwargs.
    layer_kwargs = {"inflation_mode": inflation_mode, **kwargs}
    return InflatedCausalConv3d(*args, **layer_kwargs)
|