libokj's picture
Upload 299 files
953417b
from itertools import zip_longest
from typing import Sequence, Dict, Union
import torch
from lightning_utilities.core.rank_zero import rank_zero_warn
from torch import nn
class MultiEntityInteraction(nn.Module):
    """Apply one encoder per input entity and collect the encoded outputs.

    Encoders may be given as a single ``nn.Module``, a sequence of modules, a
    ``dict`` mapping string names to modules, or an existing ``nn.ModuleList`` /
    ``nn.ModuleDict``. They are always stored in an ``nn.ModuleList`` /
    ``nn.ModuleDict``. That registers them as submodules, so their parameters
    appear in ``parameters()`` and ``state_dict()`` and follow ``.to(device)``.

    Dict keys become submodule names. They must therefore be strings without
    ``"."`` and must not clash with existing ``nn.Module`` attribute names
    (``nn.ModuleDict`` rejects such keys).

    Decoders are stored as given, with one exception: a plain list/tuple/dict
    whose elements are all modules is wrapped in ``nn.ModuleList`` /
    ``nn.ModuleDict`` so its parameters are registered too. ``forward`` does
    not use the decoders.
    """

    def __init__(
        self,
        encoders: Union[nn.Module, Sequence[nn.Module], Dict[str, nn.Module]],
        decoders: Union[nn.Module, Sequence[nn.Module], Dict[str, nn.Module]],
    ):
        super().__init__()
        # Assigning an nn.Module subclass (ModuleList/ModuleDict) via
        # nn.Module.__setattr__ registers it; a plain list/dict would not be.
        self.encoders = self._build_encoders(encoders)
        self.decoders = self._register_decoders(decoders)

    @staticmethod
    def _build_encoders(encoders):
        """Validate ``encoders`` and return them as a registering module container.

        Raises:
            ValueError: If ``encoders`` (or any element/value of it) is not an
                ``nn.Module``.
        """
        if isinstance(encoders, (nn.ModuleList, nn.ModuleDict)):
            # Already a registering container. Wrapping it as one encoder would
            # break ``forward``, because these containers define no forward().
            return encoders
        if isinstance(encoders, nn.Module):
            # A single encoder is treated as a one-element sequence.
            return nn.ModuleList([encoders])
        if isinstance(encoders, Sequence):
            for i, encoder in enumerate(encoders):
                if not isinstance(encoder, nn.Module):
                    raise ValueError(
                        f"Value {encoder} at index {i} is not an instance of `nn.Module`."
                    )
            return nn.ModuleList(encoders)
        if isinstance(encoders, dict):
            container = nn.ModuleDict()
            # Insert one by one so the dict's insertion order is kept explicitly.
            for k, encoder in encoders.items():
                if not isinstance(encoder, nn.Module):
                    raise ValueError(
                        f"Value {encoder} at key {k} is not an instance of `nn.Module`."
                    )
                container[k] = encoder
            return container
        raise ValueError(
            "Unknown input to MultiEntityInteraction. Expected `nn.Module`, or `dict`/`sequence` of the"
            f" previous, but got {encoders}"
        )

    @staticmethod
    def _register_decoders(decoders):
        """Wrap a plain list/tuple/dict made only of modules in a module container.

        Any other value (a single module, an existing container, ``None``, or
        a collection holding non-module values) is returned unchanged.
        """
        if isinstance(decoders, (list, tuple)) and all(
            isinstance(d, nn.Module) for d in decoders
        ):
            return nn.ModuleList(decoders)
        if (
            isinstance(decoders, dict)
            and all(isinstance(k, str) for k in decoders)
            and all(isinstance(d, nn.Module) for d in decoders.values())
        ):
            container = nn.ModuleDict()
            for k, decoder in decoders.items():
                container[k] = decoder
            return container
        return decoders

    def forward(self, inputs):
        """Encode each entity input with its matching encoder.

        Args:
            inputs: For sequence-style encoders, an iterable with exactly one
                item per encoder, in encoder order. For dict-style encoders,
                either a ``dict`` with exactly the encoder keys (matched by
                key), or an iterable matched positionally to the encoders'
                insertion order.

        Returns:
            list: One encoder output per encoder, in encoder order.

        Raises:
            ValueError: If the number of inputs, or the set of input keys,
                does not match the encoders.
        """
        if isinstance(self.encoders, nn.ModuleDict):
            if isinstance(inputs, dict):
                if set(inputs) != set(self.encoders.keys()):
                    raise ValueError(
                        f"Input keys {sorted(map(str, inputs))} do not match encoder keys "
                        f"{sorted(self.encoders.keys())}."
                    )
                return [encoder(inputs[key]) for key, encoder in self.encoders.items()]
            encoders = list(self.encoders.values())
        else:
            encoders = list(self.encoders)

        # Materialise once so generators/iterators can be length-checked.
        items = list(inputs)
        if len(items) != len(encoders):
            # zip_longest previously produced None here, which was then either
            # called as an encoder or passed to one.
            raise ValueError(
                f"Expected {len(encoders)} inputs (one per encoder), but got {len(items)}."
            )
        return [encoder(x) for encoder, x in zip(encoders, items)]