# bndl's picture
# Upload 115 files
# 4f5540c
import torch
class Vector_IV(torch.nn.Module):
    """Small feed-forward regressor producing one "IV" value per input vector.

    The network is ``Linear(input_feat, hidden_channels) -> PReLU ->
    Linear(hidden_channels, 1)``.

    Args:
        input_feat: Size of the last dimension of the input tensor.
        hidden_channels: Width of the single hidden layer.
    """

    def __init__(self, input_feat: int, hidden_channels: int) -> None:
        super().__init__()
        # The layers live in one Sequential called ``fc`` so that parameter
        # names (fc.0.*, fc.1.*, fc.2.*) stay compatible with saved weights.
        self.fc = torch.nn.Sequential(
            torch.nn.Linear(input_feat, hidden_channels),
            torch.nn.PReLU(),
            torch.nn.Linear(hidden_channels, 1),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Run the regressor.

        Args:
            x: Tensor whose last dimension has size ``input_feat``.

        Returns:
            Tensor with the same leading dimensions as ``x`` and a last
            dimension of size 1.
        """
        return self.fc(x)
class Vector_Tg(torch.nn.Module):
    """Regressor whose "Tg" output is a linear head scaled by a tanh gate.

    A shared hidden representation ``PReLU(Linear(x))`` feeds two heads:

    * ``TG``: a plain linear layer producing one value, and
    * ``mult_factor``: a linear layer followed by ``Tanh``, producing one
      multiplier bounded to [-1, 1].

    The model output is the element-wise product of the two heads.

    Args:
        input_feat: Size of the last dimension of the input tensor.
        hidden_channels: Width of the shared hidden layer.
    """

    def __init__(self, input_feat: int, hidden_channels: int) -> None:
        super().__init__()
        # Attribute names and creation order are kept as-is so state_dict
        # keys and seeded initialisation match previously trained models.
        self.fc = torch.nn.Sequential(
            torch.nn.Linear(input_feat, hidden_channels),
            torch.nn.PReLU(),
        )
        self.TG = torch.nn.Linear(hidden_channels, 1)
        self.mult_factor = torch.nn.Sequential(
            torch.nn.Linear(hidden_channels, 1),
            torch.nn.Tanh(),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Compute the gated output.

        Args:
            x: Tensor whose last dimension has size ``input_feat``.

        Returns:
            ``TG(h) * mult_factor(h)`` where ``h = fc(x)``; the last
            dimension has size 1.
        """
        hidden = self.fc(x)
        raw_value = self.TG(hidden)
        gate = self.mult_factor(hidden)
        return raw_value * gate
class Vector_Joint(torch.nn.Module):
    """Joint model predicting "IV" and "Tg" from one shared hidden layer.

    The input is first mapped to a hidden representation
    ``h = PReLU(Linear(x))``. Two branches then read ``h``:

    * IV branch (``fc2``): ``Linear -> PReLU -> Linear`` down to one value.
    * Tg branch: ``exp(fc3(h))`` multiplied by ``mult_factor(h)``, where
      ``mult_factor`` is ``Linear -> Tanh``. The exponential is strictly
      positive, so the sign of the Tg output comes from the tanh gate.

    Args:
        input_feat: Size of the last dimension of the input tensor.
        hidden_channels: Width of the shared hidden layer (also used as the
            width of the IV branch's inner layer).
    """

    def __init__(self, input_feat: int, hidden_channels: int) -> None:
        super().__init__()
        # Shared trunk. Attribute names and creation order are kept as-is so
        # state_dict keys and seeded initialisation match existing models.
        self.fc1 = torch.nn.Linear(input_feat, hidden_channels)
        self.leaky1 = torch.nn.PReLU()
        # IV model
        self.fc2 = torch.nn.Sequential(
            torch.nn.Linear(hidden_channels, hidden_channels),
            torch.nn.PReLU(),
            torch.nn.Linear(hidden_channels, 1),
        )
        # Tg head; its output is exponentiated in forward().
        self.fc3 = torch.nn.Linear(hidden_channels, 1)
        # Multiplying factor applied to the exponentiated Tg head.
        self.mult_factor = torch.nn.Sequential(
            torch.nn.Linear(hidden_channels, 1),
            torch.nn.Tanh(),
        )

    def forward(self, x: torch.Tensor) -> dict:
        """Run both branches on the shared hidden representation.

        Args:
            x: Tensor whose last dimension has size ``input_feat``.

        Returns:
            Dict with keys ``'IV'`` and ``'Tg'``; each value is a tensor whose
            last dimension has size 1.
        """
        shared = self.leaky1(self.fc1(x))

        # IV branch.
        iv_out = self.fc2(shared)

        # Tg branch: positive magnitude from exp(), scaled by the tanh gate.
        magnitude = torch.exp(self.fc3(shared))
        gate = self.mult_factor(shared)
        tg_out = magnitude * gate

        return {'IV': iv_out, 'Tg': tg_out}