Spaces:
Sleeping
Sleeping
from itertools import zip_longest | |
from typing import Sequence, Dict, Union | |
import torch | |
from lightning_utilities.core.rank_zero import rank_zero_warn | |
from torch import nn | |
class MultiEntityInteraction(nn.Module):
    """Apply one encoder per entity to a collection of per-entity inputs.

    Encoders and decoders may each be given as a single ``nn.Module``, a
    sequence of modules, or a ``dict`` mapping string names to modules. They
    are stored in an ``nn.ModuleList`` / ``nn.ModuleDict`` so that they are
    registered as submodules. This means their parameters show up in
    ``parameters()`` and ``state_dict()`` and move with ``.to()``. A plain
    Python list or dict would hide them from ``nn.Module``.

    Args:
        encoders: Encoder module(s), one per entity.
        decoders: Decoder module(s), validated and registered like
            ``encoders``. ``None`` is accepted and stored as-is. Decoders are
            not used by ``forward``.

    Raises:
        ValueError: If ``encoders`` (or a non-``None`` ``decoders``) is not an
            ``nn.Module``, a sequence of them, or a dict of them.
    """

    def __init__(
        self,
        encoders: Union[nn.Module, Sequence[nn.Module], Dict[str, nn.Module]],
        decoders: Union[nn.Module, Sequence[nn.Module], Dict[str, nn.Module]],
    ):
        super().__init__()
        self.encoders = self._as_module_container(encoders)
        # The original code accepted any value here unchecked; keep None
        # working for callers that build an encoder-only model.
        self.decoders = None if decoders is None else self._as_module_container(decoders)

    @staticmethod
    def _as_module_container(modules):
        """Validate ``modules`` and return it as an ``nn.ModuleList``/``nn.ModuleDict``."""
        # ModuleList/ModuleDict are nn.Modules too, so they must be checked
        # before the generic nn.Module branch. Otherwise they would be wrapped
        # as one (non-callable) module.
        if isinstance(modules, (nn.ModuleList, nn.ModuleDict)):
            return modules
        if isinstance(modules, nn.Module):
            # A single module becomes a one-element list.
            return nn.ModuleList([modules])
        # str is a Sequence, but never a valid collection of modules.
        if isinstance(modules, Sequence) and not isinstance(modules, str):
            for i, module in enumerate(modules):
                if not isinstance(module, nn.Module):
                    raise ValueError(
                        f"Value {module} at index {i} is not an instance of `nn.Module`."
                    )
            return nn.ModuleList(modules)
        if isinstance(modules, dict):
            for k, module in modules.items():
                if not isinstance(module, nn.Module):
                    raise ValueError(
                        f"Value {module} at key {k} is not an instance of `nn.Module`."
                    )
            return nn.ModuleDict(modules)
        raise ValueError(
            "Unknown input to MultiEntityInteraction. Expected, `nn.Module`, or `dict`/`sequence` of the"
            f" previous, but got {modules}"
        )

    def forward(self, inputs):
        """Encode each entity's input with its matching encoder.

        Args:
            inputs: One input per encoder. Any iterable is accepted and is
                paired with the encoders positionally. For dict encoders, the
                pairing follows dict insertion order. When the encoders were
                given as a dict, ``inputs`` may instead be a dict keyed by the
                same names.

        Returns:
            list: ``encoder(x)`` for each encoder, in encoder order.

        Raises:
            ValueError: If the number of inputs differs from the number of
                encoders.
            KeyError: If dict ``inputs`` is missing an encoder's name.
        """
        if isinstance(self.encoders, nn.ModuleDict):
            if isinstance(inputs, dict):
                return [encoder(inputs[name]) for name, encoder in self.encoders.items()]
            # Iterating a ModuleDict yields names, not modules.
            encoders = list(self.encoders.values())
        else:
            encoders = list(self.encoders)
        inputs = list(inputs)
        # Fail loudly instead of zip_longest's silent None padding.
        if len(inputs) != len(encoders):
            raise ValueError(
                f"Expected {len(encoders)} inputs (one per encoder), but got {len(inputs)}."
            )
        return [encoder(x) for encoder, x in zip(encoders, inputs)]