|
import torch |
|
|
|
|
|
class Vector_IV(torch.nn.Module): |
|
|
|
def __init__(self, input_feat, hidden_channels): |
|
super(Vector_IV, self).__init__() |
|
|
|
self.fc = torch.nn.Sequential( |
|
torch.nn.Linear(input_feat, hidden_channels), |
|
torch.nn.PReLU(), |
|
torch.nn.Linear(hidden_channels, 1) |
|
) |
|
|
|
def forward(self, x): |
|
return self.fc(x) |
|
|
|
class Vector_Tg(torch.nn.Module): |
|
|
|
def __init__(self, input_feat, hidden_channels): |
|
super(Vector_Tg, self).__init__() |
|
|
|
self.fc = torch.nn.Sequential( |
|
torch.nn.Linear(input_feat, hidden_channels), |
|
torch.nn.PReLU(), |
|
) |
|
|
|
self.TG = torch.nn.Linear(hidden_channels, 1) |
|
self.mult_factor = torch.nn.Sequential( |
|
torch.nn.Linear(hidden_channels, 1), |
|
torch.nn.Tanh(), |
|
) |
|
|
|
def forward(self, x): |
|
x = self.fc(x) |
|
return self.TG(x) * self.mult_factor(x) |
|
|
|
|
|
class Vector_Joint(torch.nn.Module): |
|
def __init__(self, input_feat, hidden_channels): |
|
super(Vector_Joint, self).__init__() |
|
|
|
self.fc1 = torch.nn.Linear(input_feat, hidden_channels) |
|
self.leaky1 = torch.nn.PReLU() |
|
self.fc2 = torch.nn.Sequential( |
|
torch.nn.Linear(hidden_channels, hidden_channels), |
|
torch.nn.PReLU(), |
|
torch.nn.Linear(hidden_channels, 1) |
|
) |
|
self.fc3 = torch.nn.Linear(hidden_channels, 1) |
|
|
|
self.mult_factor = torch.nn.Sequential( |
|
torch.nn.Linear(hidden_channels, 1), |
|
torch.nn.Tanh(), |
|
) |
|
|
|
def forward(self, x): |
|
x = self.leaky1(self.fc1(x)) |
|
|
|
|
|
x_IV = self.fc2(x) |
|
|
|
|
|
x_Tg = torch.exp(self.fc3(x)) |
|
m = self.mult_factor(x) |
|
x_Tg = x_Tg * m |
|
|
|
return {'IV': x_IV, 'Tg': x_Tg} |