File size: 229 Bytes
6ae852e
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
from torch import cat

from deepscreen.models.components.mlp import MLP


class ConcatMLP(MLP):
    """MLP variant that accepts several tensors and joins them before the layers.

    The positional inputs are concatenated along dimension 1, and the combined
    tensor is then fed through every child module of this container in
    iteration order.
    """

    def forward(self, *inputs):
        """Concatenate ``inputs`` along dim 1 and run the result through each layer.

        Args:
            *inputs: Tensors to join. They must agree in every dimension except
                dim 1 for ``cat`` to succeed. They are presumably
                (batch, features) each; TODO confirm against callers.

        Returns:
            The output produced by the last module yielded when iterating
            ``self``.
        """
        # ``cat`` accepts any sequence of tensors, so the argument tuple is
        # passed straight through.
        hidden = cat(inputs, dim=1)
        # Walk the container's child modules directly. This assumes MLP is an
        # iterable container such as nn.Sequential; confirm in
        # deepscreen.models.components.mlp.
        for layer in self:
            hidden = layer(hidden)
        return hidden