import torch
from torch_geometric.data import Batch
from torch_geometric.nn import SAGEConv, GATConv, Sequential, BatchNorm
from torch_geometric.nn import SAGPooling
class PolymerGNN_IV_evidential(torch.nn.Module):
    def __init__(self, input_feat, hidden_channels, num_additional=0):
        super(PolymerGNN_IV_evidential, self).__init__()
        self.hidden_channels = hidden_channels
        # Graph encoder: GAT -> SAGE (max aggregation), each followed by batch
        # norm and PReLU, then SAGPooling to keep the top-scoring nodes.
        # SAGPooling returns a tuple (x, edge_index, edge_attr, batch, perm,
        # score), so forward() indexes [0] to recover the node features.
        self.Asage = Sequential('x, edge_index, batch', [
            (GATConv(input_feat, hidden_channels, aggr='max'), 'x, edge_index -> x'),
            BatchNorm(hidden_channels, track_running_stats=False),
            torch.nn.PReLU(),
            (SAGEConv(hidden_channels, hidden_channels, aggr='max'), 'x, edge_index -> x'),
            BatchNorm(hidden_channels, track_running_stats=False),
            torch.nn.PReLU(),
            (SAGPooling(hidden_channels), 'x, edge_index, batch=batch -> x'),
        ])
        # Gsage removed: the acid and glycol graphs share the Asage encoder.
        # self.Gsage = Sequential('x, edge_index, batch', [
        #     (GATConv(input_feat, hidden_channels, aggr='max'), 'x, edge_index -> x'),
        #     BatchNorm(hidden_channels, track_running_stats=False),
        #     torch.nn.PReLU(),
        #     (SAGEConv(hidden_channels, hidden_channels, aggr='max'), 'x, edge_index -> x'),
        #     BatchNorm(hidden_channels, track_running_stats=False),
        #     torch.nn.PReLU(),
        #     (SAGPooling(hidden_channels), 'x, edge_index, batch=batch -> x'),
        # ])
        self.fc1 = torch.nn.Linear(hidden_channels * 2 + num_additional, hidden_channels)
        self.leaky1 = torch.nn.PReLU()
        self.fc2 = torch.nn.Linear(hidden_channels, 4)  # Output 4 parameters of the NIG distribution
        self.evidence = torch.nn.Softplus()
    def forward(self, Abatch: Batch, Gbatch: Batch, add_features: torch.Tensor = None):
        '''
        Embed the acid graph (Abatch) and glycol graph (Gbatch) with the
        shared encoder, concatenate the pooled embeddings with any additional
        features, and predict the four parameters (gamma, v, alpha, beta) of
        a Normal-Inverse-Gamma (NIG) distribution.
        '''
        # Encode the acid and glycol graphs with the shared Asage encoder;
        # [0] takes the node features from the tuple returned by SAGPooling.
        Aembeddings = self.Asage(Abatch.x, Abatch.edge_index, Abatch.batch)[0]
        Gembeddings = self.Asage(Gbatch.x, Gbatch.edge_index, Gbatch.batch)[0]
        # self.saveA_embed = Aembed.clone()
        # self.saveG_embed = Gembed.clone()
        # Max-pool the remaining node embeddings into one vector per graph
        Aembed, _ = torch.max(Aembeddings, dim=0)
        Gembed, _ = torch.max(Gembeddings, dim=0)
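        # Note: the dim=0 max runs over every node in the (mini-)batch, so
        # this readout assumes each call carries a single acid/glycol pair;
        # per-graph pooling would need global_max_pool over `batch` instead.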
        # Aggregate pooled vectors
        if add_features is not None:
            poolAgg = torch.cat([Aembed, Gembed, add_features])
        else:
            poolAgg = torch.cat([Aembed, Gembed])
        x = self.leaky1(self.fc1(poolAgg))
        # Unpack last layer:
        gamma, logv, logalpha, logbeta = self.fc2(x).squeeze()
        # Because we predict log, exp transform gamma:
        # gamma = torch.exp(loggamma)
        # KEEP IN MIND: may need to alter distribution b/c exp of gamma
        # Activate all other parameters:
        v = self.evidence(logv)
        alpha = self.evidence(logalpha) + 1
        beta = self.evidence(logbeta)
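        # The softplus "evidence" activation keeps v and beta strictly
        # positive, and the +1 offset keeps alpha > 1, as a valid NIG
        # posterior requires in evidential regression (Amini et al., 2020):
        # the expected variance beta / (alpha - 1) is only finite for alpha > 1.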
        # yhat = torch.exp(self.fc2(x))
        multi_dict = {
            'gamma': gamma,
            'v': v,
            'alpha': alpha,
            'beta': beta
        }
        return multi_dict
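

# Usage sketch (illustrative only): builds two random toy graphs and reads
# the evidential outputs back out. The feature sizes, graph shapes, and the
# toy_graph helper below are assumptions for demonstration, not values from
# the original training pipeline.
if __name__ == '__main__':
    from torch_geometric.data import Data

    input_feat, hidden, n_add = 16, 32, 2
    model = PolymerGNN_IV_evidential(input_feat, hidden, num_additional=n_add)

    def toy_graph(num_nodes=6):
        # Chain graph with bidirectional edges and random node features
        x = torch.randn(num_nodes, input_feat)
        src = torch.arange(num_nodes - 1)
        edge_index = torch.stack([torch.cat([src, src + 1]),
                                  torch.cat([src + 1, src])])
        return Data(x=x, edge_index=edge_index,
                    batch=torch.zeros(num_nodes, dtype=torch.long))

    out = model(toy_graph(), toy_graph(), torch.randn(n_add))

    # Standard evidential-regression readout: predictive mean = gamma,
    # aleatoric variance = beta / (alpha - 1),
    # epistemic variance = beta / (v * (alpha - 1)).
    mean = out['gamma']
    aleatoric = out['beta'] / (out['alpha'] - 1)
    epistemic = out['beta'] / (out['v'] * (out['alpha'] - 1))
    print(f'mean={mean.item():.4f}  aleatoric={aleatoric.item():.4f}  '
          f'epistemic={epistemic.item():.4f}')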