from itertools import zip_longest
from typing import Dict, Mapping, Sequence, Union

import torch
from lightning_utilities.core.rank_zero import rank_zero_warn
from torch import nn


def _as_module_collection(
    modules: Union[None, nn.Module, Sequence[nn.Module], Dict[str, nn.Module]],
    allow_none: bool = False,
) -> Union[None, nn.ModuleList, nn.ModuleDict]:
    """Validate ``modules`` and wrap it in a registered torch container.

    Plain ``list``/``dict`` attributes are invisible to ``nn.Module``: their
    parameters are not returned by ``.parameters()``, not moved by ``.to()``
    and not saved in ``state_dict()``. Wrapping them in ``nn.ModuleList`` /
    ``nn.ModuleDict`` registers every child as a submodule.

    Args:
        modules: a single module, a sequence of modules, or a mapping of
            ``str`` keys to modules (``nn.ModuleDict`` requires ``str`` keys).
            Existing ``nn.ModuleList``/``nn.ModuleDict`` containers are
            returned unchanged.
        allow_none: if ``True``, ``None`` is accepted and returned as-is.

    Returns:
        ``nn.ModuleList`` for a single module or a sequence, ``nn.ModuleDict``
        for a mapping, or ``None`` (only when ``allow_none`` is set).

    Raises:
        ValueError: if an element is not an ``nn.Module`` or ``modules`` is of
            an unsupported type.
    """
    if modules is None and allow_none:
        return None
    # Torch containers are nn.Module subclasses too, so test them before the
    # single-module case; otherwise they would be wrapped as one "encoder"
    # whose (non-existent) forward() fails at call time.
    if isinstance(modules, (nn.ModuleList, nn.ModuleDict)):
        return modules
    if isinstance(modules, nn.Module):
        # set compatible with original type expectations
        return nn.ModuleList([modules])
    if isinstance(modules, Mapping):
        # Check all values are modules
        for k, module in modules.items():
            if not isinstance(module, nn.Module):
                raise ValueError(
                    f"Value {module} at key {k} is not an instance of `nn.Module`."
                )
        return nn.ModuleDict(modules)
    if isinstance(modules, Sequence):
        # Check all values are modules
        for i, module in enumerate(modules):
            if not isinstance(module, nn.Module):
                raise ValueError(
                    f"Value {module} at index {i} is not an instance of `nn.Module`."
                )
        return nn.ModuleList(list(modules))
    raise ValueError(
        "Unknown input to MultiEntityInteraction. Expected, `nn.Module`, or `dict`/`sequence` of the"
        f" previous, but got {modules}"
    )


class MultiEntityInteraction(nn.Module):
    """Apply one encoder per input entity and collect the predictions.

    Encoders (and decoders) may be given as a single ``nn.Module``, a
    sequence of modules, or a ``str``-keyed dict of modules. They are stored
    as ``nn.ModuleList`` / ``nn.ModuleDict`` so their parameters are
    registered with this module (optimizers, ``.to()``, ``state_dict()``).
    """

    def __init__(
        self,
        encoders: Union[nn.Module, Sequence[nn.Module], Dict[str, nn.Module]],
        decoders: Union[nn.Module, Sequence[nn.Module], Dict[str, nn.Module]],
    ):
        """Validate and register the encoders and decoders.

        Raises:
            ValueError: if either argument, or one of its elements, is not an
                ``nn.Module`` (see ``_as_module_collection``).
        """
        super().__init__()
        self.encoders = _as_module_collection(encoders)
        # Decoders are validated and registered the same way but are not used
        # by forward(). None is tolerated because the original code stored
        # decoders unchecked — NOTE(review): confirm whether None is intended.
        self.decoders = _as_module_collection(decoders, allow_none=True)

    def forward(self, inputs):
        """Run each encoder on its matching input.

        Args:
            inputs: when encoders were given as a dict, a mapping with one
                entry per encoder key; otherwise an iterable with exactly one
                element per encoder, in encoder order.

        Returns:
            A dict ``{key: encoder(inputs[key])}`` for dict encoders, else a
            list ``[encoder_i(inputs[i])]``.

        Raises:
            ValueError: if an encoder key is missing from ``inputs`` or the
                number of inputs differs from the number of encoders.
        """
        if isinstance(self.encoders, nn.ModuleDict):
            if not isinstance(inputs, Mapping):
                raise ValueError(
                    "Encoders were given as a dict, so `inputs` must be a mapping"
                    f" with keys {list(self.encoders.keys())}, but got {type(inputs).__name__}."
                )
            missing = [k for k in self.encoders if k not in inputs]
            if missing:
                raise ValueError(f"Missing inputs for encoder keys {missing}.")
            return {k: encoder(inputs[k]) for k, encoder in self.encoders.items()}

        # Materialize once so the length can be checked; zip() alone would
        # silently truncate and zip_longest() would pair with None.
        inputs = list(inputs)
        if len(inputs) != len(self.encoders):
            raise ValueError(
                f"Expected {len(self.encoders)} inputs (one per encoder), got {len(inputs)}."
            )
        return [encoder(x) for encoder, x in zip(self.encoders, inputs)]