# // Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
# //
# // Licensed under the Apache License, Version 2.0 (the "License");
# // you may not use this file except in compliance with the License.
# // You may obtain a copy of the License at
# //
# //     http://www.apache.org/licenses/LICENSE-2.0
# //
# // Unless required by applicable law or agreed to in writing, software
# // distributed under the License is distributed on an "AS IS" BASIS,
# // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# // See the License for the specific language governing permissions and
# // limitations under the License.

from dataclasses import dataclass
from typing import Any, Callable, Dict, List, Sequence, Tuple

import torch
from torch import nn


@dataclass
class MMArg:
    """A pair of values for one argument: one for video, one for text.

    Wherever an ``MMArg`` appears among the arguments handled by
    ``get_args`` / ``get_kwargs``, it is replaced by its ``vid`` or ``txt``
    attribute depending on which branch is being served.
    """

    vid: Any
    txt: Any


def get_args(key: str, args: Sequence[Any]) -> List[Any]:
    """Resolve positional arguments for one branch.

    Args:
        key: Attribute name to read from each ``MMArg`` ("vid" or "txt").
        args: Positional arguments; may mix ``MMArg`` and plain values.

    Returns:
        A new list in which every ``MMArg`` is replaced by its ``key``
        attribute and every other value is passed through unchanged.
    """
    return [getattr(v, key) if isinstance(v, MMArg) else v for v in args]


def get_kwargs(key: str, kwargs: Dict[str, Any]) -> Dict[str, Any]:
    """Resolve keyword arguments for one branch.

    Args:
        key: Attribute name to read from each ``MMArg`` ("vid" or "txt").
        kwargs: Keyword arguments; values may mix ``MMArg`` and plain values.

    Returns:
        A new dict with the same keys in which every ``MMArg`` value is
        replaced by its ``key`` attribute; other values pass through.
    """
    return {k: getattr(v, key) if isinstance(v, MMArg) else v for k, v in kwargs.items()}


class MMModule(nn.Module):
    """Apply a module to a video input and a text input.

    Depending on the flags, the wrapper holds:

    * ``shared_weights=True``: a single sub-module ``self.all`` used for both
      inputs.
    * ``shared_weights=False``: ``self.vid`` plus ``self.txt``; ``self.txt``
      is ``None`` when ``vid_only=True``.

    When ``vid_only`` is set, ``forward`` returns the text input untouched.
    """

    def __init__(
        self,
        module: Callable[..., nn.Module],
        *args,
        shared_weights: bool = False,
        vid_only: bool = False,
        **kwargs,
    ):
        """Build the sub-module(s).

        Args:
            module: Factory (typically an ``nn.Module`` class) called with the
                resolved arguments to create each sub-module.
            *args: Positional constructor arguments; ``MMArg`` entries are
                resolved per branch.
            shared_weights: Use one sub-module for both inputs.
            vid_only: Only transform the video input in ``forward``.
            **kwargs: Keyword constructor arguments; ``MMArg`` values are
                resolved per branch.

        Raises:
            AssertionError: If ``shared_weights`` is set but the resolved
                video and text constructor arguments differ.
        """
        super().__init__()
        self.shared_weights = shared_weights
        self.vid_only = vid_only

        # The video-side arguments are needed on both paths; resolve them once.
        vid_args = get_args("vid", args)
        vid_kwargs = get_kwargs("vid", kwargs)

        if self.shared_weights:
            # One module serves both inputs, so both branches must describe
            # the same construction.
            assert vid_args == get_args("txt", args), (
                "shared_weights=True requires identical vid/txt positional arguments"
            )
            assert vid_kwargs == get_kwargs("txt", kwargs), (
                "shared_weights=True requires identical vid/txt keyword arguments"
            )
            self.all = module(*vid_args, **vid_kwargs)
        else:
            self.vid = module(*vid_args, **vid_kwargs)
            self.txt = (
                module(*get_args("txt", args), **get_kwargs("txt", kwargs))
                if not vid_only
                else None
            )

    def forward(
        self,
        vid: torch.FloatTensor,
        txt: torch.FloatTensor,
        *args,
        **kwargs,
    ) -> Tuple[
        torch.FloatTensor,
        torch.FloatTensor,
    ]:
        """Run the video input, and unless ``vid_only`` the text input.

        Args:
            vid: Input passed as first argument to the video sub-module.
            txt: Input passed as first argument to the text sub-module.
            *args: Extra positional arguments forwarded to the sub-modules;
                ``MMArg`` entries are resolved per branch.
            **kwargs: Extra keyword arguments forwarded to the sub-modules;
                ``MMArg`` values are resolved per branch.

        Returns:
            ``(vid, txt)`` after processing. ``txt`` is returned unchanged
            when ``vid_only`` is set.
        """
        # The conditional expression is lazy, so the attribute that was never
        # created for the current mode is not touched.
        vid_module = self.vid if not self.shared_weights else self.all
        vid = vid_module(vid, *get_args("vid", args), **get_kwargs("vid", kwargs))
        if not self.vid_only:
            txt_module = self.txt if not self.shared_weights else self.all
            txt = txt_module(txt, *get_args("txt", args), **get_kwargs("txt", kwargs))
        return vid, txt