# // Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
# //
# // Licensed under the Apache License, Version 2.0 (the "License");
# // you may not use this file except in compliance with the License.
# // You may obtain a copy of the License at
# //
# // http://www.apache.org/licenses/LICENSE-2.0
# //
# // Unless required by applicable law or agreed to in writing, software
# // distributed under the License is distributed on an "AS IS" BASIS,
# // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# // See the License for the specific language governing permissions and
# // limitations under the License.

from enum import Enum
from typing import Optional

import numpy as np
import torch
from diffusers.models.normalization import RMSNorm
from einops import rearrange
from torch import Tensor, nn

from common.logger import get_logger

logger = get_logger(__name__)


class MemoryState(Enum):
    """
    State[DISABLED]: No memory bank is enabled.
    State[INITIALIZING]: The model is handling the first clip;
        the memory bank needs to be reset / initialized.
    State[ACTIVE]: The memory bank already holds features from previous clips.
    """

    DISABLED = 0
    INITIALIZING = 1
    ACTIVE = 2


def causal_norm_wrapper(norm_layer: nn.Module, x: torch.Tensor) -> torch.Tensor:
    """
    Apply a normalization layer to a 4D (image) or 5D (video) tensor,
    rearranging dimensions as required by the layer type.
    """
    if isinstance(norm_layer, (nn.LayerNorm, RMSNorm)):
        # LayerNorm / RMSNorm normalize over trailing dimensions,
        # so move channels to the last axis before applying them.
        if x.ndim == 4:
            x = rearrange(x, "b c h w -> b h w c")
            x = norm_layer(x)
            x = rearrange(x, "b h w c -> b c h w")
            return x
        if x.ndim == 5:
            x = rearrange(x, "b c t h w -> b t h w c")
            x = norm_layer(x)
            x = rearrange(x, "b t h w c -> b c t h w")
            return x
    if isinstance(norm_layer, (nn.GroupNorm, nn.BatchNorm2d, nn.SyncBatchNorm)):
        # GroupNorm / BatchNorm expect channels-first input;
        # fold the temporal axis into the batch for 5D tensors.
        if x.ndim <= 4:
            return norm_layer(x)
        if x.ndim == 5:
            t = x.size(2)
            x = rearrange(x, "b c t h w -> (b t) c h w")
            x = norm_layer(x)
            x = rearrange(x, "(b t) c h w -> b c t h w", t=t)
            return x
    raise NotImplementedError(
        f"Unsupported norm layer {type(norm_layer).__name__} for input with ndim={x.ndim}."
    )
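
# Usage sketch (illustrative only; the shapes are arbitrary example values):
# applying a GroupNorm to a 5D video tensor through causal_norm_wrapper.
#
#     norm = nn.GroupNorm(num_groups=4, num_channels=16)
#     video = torch.randn(2, 16, 8, 32, 32)  # (batch, channels, time, h, w)
#     out = causal_norm_wrapper(norm, video)
#     assert out.shape == video.shape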


def remove_head(tensor: Tensor, times: int = 1) -> Tensor:
    """
    Remove the duplicated first-frame features (added by `extend_head`)
    in the up-sampling process.
    """
    if times == 0:
        return tensor
    # Keep one copy of the first frame and drop the `times` duplicates after it.
    return torch.cat(tensors=(tensor[:, :, :1], tensor[:, :, times + 1 :]), dim=2)
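
# Shape sketch (illustrative): with times=1, a 5-frame clip whose first frame
# was duplicated once collapses back to 4 frames.
#
#     x = torch.randn(1, 8, 5, 16, 16)  # (b, c, t, h, w)
#     assert remove_head(x, times=1).shape[2] == 4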


def extend_head(
    tensor: Tensor, times: int = 2, memory: Optional[Tensor] = None
) -> Tensor:
    """
    When memory is None:
        - Duplicate the first-frame features in the down-sampling process.
    When memory is not None:
        - Concatenate the memory features with the input features
          to keep temporal consistency across clips.
    """
    if times == 0:
        return tensor
    if memory is not None:
        # Prepend cached features from the previous clip,
        # matching the input's dtype and device.
        return torch.cat((memory.to(tensor), tensor), dim=2)
    else:
        # Prepend `times` copies of the first frame along the temporal axis.
        tile_repeat = np.ones(tensor.ndim).astype(int)
        tile_repeat[2] = times
        return torch.cat(tensors=(torch.tile(tensor[:, :, :1], list(tile_repeat)), tensor), dim=2)
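
# Round-trip sketch (illustrative): extend_head followed by remove_head with
# the same `times` recovers the original clip.
#
#     x = torch.randn(1, 8, 4, 16, 16)      # (b, c, t, h, w)
#     ext = extend_head(x, times=2)         # t: 4 -> 6
#     assert ext.shape[2] == 6
#     assert torch.equal(remove_head(ext, times=2), x)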


def inflate_weight(weight_2d: torch.Tensor, weight_3d: torch.Tensor, inflation_mode: str):
    """
    Inflate a 2D convolution weight matrix to a 3D one.
    Parameters:
        weight_2d: The weight matrix of the 2D conv to be inflated.
        weight_3d: The weight matrix of the 3D conv to be initialized.
        inflation_mode: "replicate" repeats the 2D kernel along the temporal
            axis and rescales it by the depth, so the kernel sums to the 2D
            one; "constant" zero-fills every temporal slice except the last,
            which is set to the 2D kernel.
    """
    assert inflation_mode in ["constant", "replicate"]
    assert weight_3d.shape[:2] == weight_2d.shape[:2]
    with torch.no_grad():
        if inflation_mode == "replicate":
            depth = weight_3d.size(2)
            weight_3d.copy_(weight_2d.unsqueeze(2).repeat(1, 1, depth, 1, 1) / depth)
        else:
            weight_3d.fill_(0.0)
            weight_3d[:, :, -1].copy_(weight_2d)
    return weight_3d
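
# Invariant sketch (illustrative): after "replicate" inflation, summing the 3D
# kernel over its temporal axis recovers the original 2D kernel.
#
#     conv2d = nn.Conv2d(3, 8, kernel_size=3)
#     conv3d = nn.Conv3d(3, 8, kernel_size=(3, 3, 3))
#     w3d = inflate_weight(conv2d.weight, conv3d.weight, "replicate")
#     assert torch.allclose(w3d.sum(dim=2), conv2d.weight)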


def inflate_bias(bias_2d: torch.Tensor, bias_3d: torch.Tensor, inflation_mode: str):
    """
    Inflate a 2D convolution bias tensor to a 3D one.
    The bias is shared across the temporal axis, so it is copied unchanged.
    Parameters:
        bias_2d: The bias tensor of the 2D conv to be inflated.
        bias_3d: The bias tensor of the 3D conv to be initialized.
        inflation_mode: Unused; kept to mirror the signature of `inflate_weight`.
    """
    assert bias_3d.shape == bias_2d.shape
    with torch.no_grad():
        bias_3d.copy_(bias_2d)
    return bias_3d


def modify_state_dict(layer, state_dict, prefix, inflate_weight_fn, inflate_bias_fn):
    """
    The main entry point for inflating the 2D parameters in a state dict to 3D.
    """
    weight_name = prefix + "weight"
    bias_name = prefix + "bias"
    if weight_name in state_dict:
        weight_2d = state_dict[weight_name]
        if weight_2d.dim() == 4:
            # Assuming the 2D weights are 4D tensors (out_channels, in_channels, h, w).
            weight_3d = inflate_weight_fn(
                weight_2d=weight_2d,
                weight_3d=layer.weight,
                inflation_mode=layer.inflation_mode,
            )
            state_dict[weight_name] = weight_3d
        else:
            # The checkpoint already holds 3D weights; skip inflation
            # for both the weight and the bias.
            return state_dict
    if bias_name in state_dict:
        bias_2d = state_dict[bias_name]
        if bias_2d.dim() == 1:
            # Assuming the 2D biases are 1D tensors (out_channels,).
            bias_3d = inflate_bias_fn(
                bias_2d=bias_2d,
                bias_3d=layer.bias,
                inflation_mode=layer.inflation_mode,
            )
            state_dict[bias_name] = bias_3d
    return state_dict
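
# Wiring sketch (an illustrative assumption, not part of this module): a Conv3d
# subclass that carries an `inflation_mode` attribute and rewrites incoming 2D
# checkpoints through modify_state_dict before loading. The class name
# `InflatedConv3d` is hypothetical, and `_register_load_state_dict_pre_hook`
# is a private PyTorch API.
#
#     class InflatedConv3d(nn.Conv3d):
#         def __init__(self, *args, inflation_mode: str = "replicate", **kwargs):
#             super().__init__(*args, **kwargs)
#             self.inflation_mode = inflation_mode
#             self._register_load_state_dict_pre_hook(self._inflate_hook)
#
#         def _inflate_hook(self, state_dict, prefix, *args):
#             modify_state_dict(self, state_dict, prefix, inflate_weight, inflate_bias)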