# // Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
# //
# // Licensed under the Apache License, Version 2.0 (the "License");
# // you may not use this file except in compliance with the License.
# // You may obtain a copy of the License at
# //
# // http://www.apache.org/licenses/LICENSE-2.0
# //
# // Unless required by applicable law or agreed to in writing, software
# // distributed under the License is distributed on an "AS IS" BASIS,
# // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# // See the License for the specific language governing permissions and
# // limitations under the License.
from typing import Tuple

import torch
import torch.nn as nn

# from ..cache import Cache
from common.cache import Cache
from .attention.mmattn import NaSwinAttention
from ..mm import MMArg, MMModule
from ..mlp import get_mlp
from ..modulation import ada_layer_type
from ..normalization import norm_layer_type


class NaMMSRTransformerBlock(nn.Module):
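    """Joint video/text ("MM") transformer block with windowed attention.

    Each branch is normalized, modulated by an adaptive ("ada") layer around a
    shared NaSwinAttention call, then passed through a per-branch MLP, with
    residual connections around both sub-blocks. When ``is_last_layer`` is set,
    the MLP, its norm, and the modulation run on the video branch only.
    """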
def __init__(
self,
*,
vid_dim: int,
txt_dim: int,
emb_dim: int,
heads: int,
head_dim: int,
expand_ratio: int,
norm: norm_layer_type,
norm_eps: float,
ada: ada_layer_type,
qk_bias: bool,
qk_norm: norm_layer_type,
mlp_type: str,
shared_weights: bool,
rope_type: str,
rope_dim: int,
is_last_layer: bool,
**kwargs,
):
super().__init__()
dim = MMArg(vid_dim, txt_dim)
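        # MMArg pairs a video-branch value with a text-branch value; MMModule
        # instantiates the wrapped layer per branch (optionally sharing weights).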
        self.attn_norm = MMModule(norm, dim=dim, eps=norm_eps, elementwise_affine=False, shared_weights=shared_weights)
self.attn = NaSwinAttention(
vid_dim=vid_dim,
txt_dim=txt_dim,
heads=heads,
head_dim=head_dim,
qk_bias=qk_bias,
qk_norm=qk_norm,
qk_norm_eps=norm_eps,
rope_type=rope_type,
rope_dim=rope_dim,
shared_weights=shared_weights,
window=kwargs.pop("window", None),
window_method=kwargs.pop("window_method", None),
)
self.mlp_norm = MMModule(norm, dim=dim, eps=norm_eps, elementwise_affine=False, shared_weights=shared_weights, vid_only=is_last_layer)
self.mlp = MMModule(
get_mlp(mlp_type),
dim=dim,
expand_ratio=expand_ratio,
shared_weights=shared_weights,
vid_only=is_last_layer
)
self.ada = MMModule(ada, dim=dim, emb_dim=emb_dim, layers=["attn", "mlp"], shared_weights=shared_weights, vid_only=is_last_layer)
        self.is_last_layer = is_last_layer

    def forward(
self,
vid: torch.FloatTensor, # l c
txt: torch.FloatTensor, # l c
vid_shape: torch.LongTensor, # b 3
txt_shape: torch.LongTensor, # b 1
emb: torch.FloatTensor,
cache: Cache,
) -> Tuple[
torch.FloatTensor,
torch.FloatTensor,
torch.LongTensor,
torch.LongTensor,
]:
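        # Flattened token counts per sample for each branch (product over the
        # packed shape), memoized in the shared cache.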
hid_len = MMArg(
cache("vid_len", lambda: vid_shape.prod(-1)),
cache("txt_len", lambda: txt_shape.prod(-1)),
)
ada_kwargs = {
"emb": emb,
"hid_len": hid_len,
"cache": cache,
"branch_tag": MMArg("vid", "txt"),
}
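        # Attention sub-block: norm -> ada(in) -> joint video/text windowed attention -> ada(out) -> residual.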
vid_attn, txt_attn = self.attn_norm(vid, txt)
vid_attn, txt_attn = self.ada(vid_attn, txt_attn, layer="attn", mode="in", **ada_kwargs)
vid_attn, txt_attn = self.attn(vid_attn, txt_attn, vid_shape, txt_shape, cache)
vid_attn, txt_attn = self.ada(vid_attn, txt_attn, layer="attn", mode="out", **ada_kwargs)
vid_attn, txt_attn = (vid_attn + vid), (txt_attn + txt)
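        # MLP sub-block: norm -> ada(in) -> per-branch MLP -> ada(out) -> residual.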
vid_mlp, txt_mlp = self.mlp_norm(vid_attn, txt_attn)
vid_mlp, txt_mlp = self.ada(vid_mlp, txt_mlp, layer="mlp", mode="in", **ada_kwargs)
vid_mlp, txt_mlp = self.mlp(vid_mlp, txt_mlp)
vid_mlp, txt_mlp = self.ada(vid_mlp, txt_mlp, layer="mlp", mode="out", **ada_kwargs)
vid_mlp, txt_mlp = (vid_mlp + vid_attn), (txt_mlp + txt_attn)
return vid_mlp, txt_mlp, vid_shape, txt_shape
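

# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative only): how a block of this type might be
# constructed and called. The concrete `norm`/`ada` layer types, `mlp_type`,
# the window settings, and the `Cache` constructor are assumptions standing in
# for whatever the surrounding codebase supplies; the dimensions below are
# placeholders, not the model's real configuration.
# ---------------------------------------------------------------------------
# block = NaMMSRTransformerBlock(
#     vid_dim=1024,
#     txt_dim=1024,
#     emb_dim=256,
#     heads=16,
#     head_dim=64,
#     expand_ratio=4,
#     norm=some_norm_layer,       # assumed: a norm_layer_type from ..normalization
#     norm_eps=1e-6,
#     ada=some_ada_layer,         # assumed: an ada_layer_type from ..modulation
#     qk_bias=True,
#     qk_norm=some_norm_layer,
#     mlp_type="mlp",             # assumed value; resolved via get_mlp
#     shared_weights=False,
#     rope_type="rope",           # assumed value
#     rope_dim=64,
#     is_last_layer=False,
#     window=(4, 8, 8),           # assumed window config, consumed via **kwargs
#     window_method="swin",       # assumed value
# )
# vid, txt, vid_shape, txt_shape = block(vid, txt, vid_shape, txt_shape, emb, cache)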