# // Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
# //
# // Licensed under the Apache License, Version 2.0 (the "License");
# // you may not use this file except in compliance with the License.
# // You may obtain a copy of the License at
# //
# //     http://www.apache.org/licenses/LICENSE-2.0
# //
# // Unless required by applicable law or agreed to in writing, software
# // distributed under the License is distributed on an "AS IS" BASIS,
# // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# // See the License for the specific language governing permissions and
# // limitations under the License.

from typing import Tuple

import torch
import torch.nn as nn

# from ..cache import Cache
from common.cache import Cache

from .attention.mmattn import NaSwinAttention
from ..mm import MMArg, MMModule
from ..modulation import ada_layer_type
from ..normalization import norm_layer_type
from ..mlp import get_mlp


class NaMMSRTransformerBlock(nn.Module):
    """Multi-modal (video + text) transformer block using windowed (Swin-style) attention."""

    def __init__(
        self,
        *,
        vid_dim: int,
        txt_dim: int,
        emb_dim: int,
        heads: int,
        head_dim: int,
        expand_ratio: int,
        norm: norm_layer_type,
        norm_eps: float,
        ada: ada_layer_type,
        qk_bias: bool,
        qk_norm: norm_layer_type,
        mlp_type: str,
        shared_weights: bool,
        rope_type: str,
        rope_dim: int,
        is_last_layer: bool,
        **kwargs,
    ):
        super().__init__()
        dim = MMArg(vid_dim, txt_dim)
        self.attn_norm = MMModule(
            norm,
            dim=dim,
            eps=norm_eps,
            elementwise_affine=False,
            shared_weights=shared_weights,
        )
        self.attn = NaSwinAttention(
            vid_dim=vid_dim,
            txt_dim=txt_dim,
            heads=heads,
            head_dim=head_dim,
            qk_bias=qk_bias,
            qk_norm=qk_norm,
            qk_norm_eps=norm_eps,
            rope_type=rope_type,
            rope_dim=rope_dim,
            shared_weights=shared_weights,
            window=kwargs.pop("window", None),
            window_method=kwargs.pop("window_method", None),
        )
        self.mlp_norm = MMModule(
            norm,
            dim=dim,
            eps=norm_eps,
            elementwise_affine=False,
            shared_weights=shared_weights,
            vid_only=is_last_layer,
        )
        self.mlp = MMModule(
            get_mlp(mlp_type),
            dim=dim,
            expand_ratio=expand_ratio,
            shared_weights=shared_weights,
            vid_only=is_last_layer,
        )
        self.ada = MMModule(
            ada,
            dim=dim,
            emb_dim=emb_dim,
            layers=["attn", "mlp"],
            shared_weights=shared_weights,
            vid_only=is_last_layer,
        )
        self.is_last_layer = is_last_layer

    def forward(
        self,
        vid: torch.FloatTensor,  # l c
        txt: torch.FloatTensor,  # l c
        vid_shape: torch.LongTensor,  # b 3
        txt_shape: torch.LongTensor,  # b 1
        emb: torch.FloatTensor,
        cache: Cache,
    ) -> Tuple[
        torch.FloatTensor,
        torch.FloatTensor,
        torch.LongTensor,
        torch.LongTensor,
    ]:
        hid_len = MMArg(
            cache("vid_len", lambda: vid_shape.prod(-1)),
            cache("txt_len", lambda: txt_shape.prod(-1)),
        )
        ada_kwargs = {
            "emb": emb,
            "hid_len": hid_len,
            "cache": cache,
            "branch_tag": MMArg("vid", "txt"),
        }

        # Attention branch: norm -> ada (in) -> attention -> ada (out) -> residual.
        vid_attn, txt_attn = self.attn_norm(vid, txt)
        vid_attn, txt_attn = self.ada(vid_attn, txt_attn, layer="attn", mode="in", **ada_kwargs)
        vid_attn, txt_attn = self.attn(vid_attn, txt_attn, vid_shape, txt_shape, cache)
        vid_attn, txt_attn = self.ada(vid_attn, txt_attn, layer="attn", mode="out", **ada_kwargs)
        vid_attn, txt_attn = (vid_attn + vid), (txt_attn + txt)

        # MLP branch: norm -> ada (in) -> mlp -> ada (out) -> residual.
        vid_mlp, txt_mlp = self.mlp_norm(vid_attn, txt_attn)
        vid_mlp, txt_mlp = self.ada(vid_mlp, txt_mlp, layer="mlp", mode="in", **ada_kwargs)
        vid_mlp, txt_mlp = self.mlp(vid_mlp, txt_mlp)
        vid_mlp, txt_mlp = self.ada(vid_mlp, txt_mlp, layer="mlp", mode="out", **ada_kwargs)
        vid_mlp, txt_mlp = (vid_mlp + vid_attn), (txt_mlp + txt_attn)

        return vid_mlp, txt_mlp, vid_shape, txt_shape