Spaces:
Sleeping
Sleeping
File size: 1,747 Bytes
953417b |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 |
from itertools import zip_longest
from typing import Sequence, Dict, Union
import torch
from lightning_utilities.core.rank_zero import rank_zero_warn
from torch import nn
class MultiEntityInteraction(nn.Module):
    """Apply one encoder per entity and collect the encoder outputs.

    Encoders (and decoders) are stored in ``nn.ModuleList`` / ``nn.ModuleDict``
    containers so that their parameters are registered with this module, i.e.
    they appear in ``parameters()`` / ``state_dict()`` and follow ``.to(...)``.
    A plain ``list``/``dict`` attribute would NOT be registered by
    ``nn.Module.__setattr__``. Decoders are stored but not used by ``forward``.

    Args:
        encoders: A single ``nn.Module``, a sequence of modules, or a dict
            mapping string names to modules (an ``nn.ModuleList`` /
            ``nn.ModuleDict`` is used directly). A single module is wrapped in a
            one-element ``nn.ModuleList``.
        decoders: Same accepted forms as ``encoders``. A single module is kept
            as-is (it is already registered as a submodule); ``None`` is stored
            unchanged.

    Raises:
        ValueError: If ``encoders``/``decoders`` is not one of the accepted
            forms, or a contained value is not an ``nn.Module``.
    """

    def __init__(
        self,
        encoders: Union[nn.Module, Sequence[nn.Module], Dict[str, nn.Module]],
        decoders: Union[nn.Module, Sequence[nn.Module], Dict[str, nn.Module]],
    ):
        super().__init__()
        # Wrap so that forward() can always iterate over a container of encoders.
        self.encoders = self._to_module_container(encoders, "encoders", wrap_single=True)
        # A single decoder is kept unwrapped so existing ``self.decoders(...)``
        # call sites keep working; None was accepted before, so keep allowing it.
        self.decoders = (
            None
            if decoders is None
            else self._to_module_container(decoders, "decoders", wrap_single=False)
        )

    @staticmethod
    def _to_module_container(modules, kind, *, wrap_single):
        """Validate ``modules`` and return it in a form nn.Module registers."""
        if isinstance(modules, (nn.ModuleList, nn.ModuleDict)):
            # Already a registering container; checked before the generic
            # nn.Module branch because these containers are nn.Modules too.
            return modules
        if isinstance(modules, nn.Module):
            return nn.ModuleList([modules]) if wrap_single else modules
        if isinstance(modules, dict):
            for key, module in modules.items():
                if not isinstance(module, nn.Module):
                    raise ValueError(
                        f"Value {module} at key {key} is not an instance of `nn.Module`."
                    )
            return nn.ModuleDict(modules)
        if isinstance(modules, Sequence) and not isinstance(modules, str):
            for i, module in enumerate(modules):
                if not isinstance(module, nn.Module):
                    raise ValueError(
                        f"Value {module} at index {i} is not an instance of `nn.Module`."
                    )
            return nn.ModuleList(modules)
        raise ValueError(
            f"Unknown `{kind}` input to MultiEntityInteraction. Expected `nn.Module`,"
            f" or a `dict`/`sequence` of the previous, but got {modules}"
        )

    def forward(self, inputs):
        """Run each encoder on its matching input.

        Args:
            inputs: For list-style encoders, an iterable with exactly one item
                per encoder, in order. For dict-style encoders, either a dict
                with exactly the encoders' keys, or an iterable ordered like the
                encoders' insertion order.

        Returns:
            list: One output per encoder, in encoder order.

        Raises:
            ValueError: If the inputs do not line up one-to-one with the encoders.
        """
        if isinstance(self.encoders, nn.ModuleDict):
            names = list(self.encoders.keys())
            if isinstance(inputs, dict):
                if set(inputs) != set(names):
                    raise ValueError(
                        f"Input keys {sorted(map(str, inputs))} do not match encoder"
                        f" keys {sorted(names)}."
                    )
                return [self.encoders[name](inputs[name]) for name in names]
            encoders = [self.encoders[name] for name in names]
        else:
            encoders = list(self.encoders)

        # Materialize once so generators work and the length can be checked.
        inputs = list(inputs)
        if len(inputs) != len(encoders):
            raise ValueError(
                f"Expected {len(encoders)} inputs (one per encoder), got {len(inputs)}."
            )
        return [encoder(x) for encoder, x in zip(encoders, inputs)]
|