# Uploaded by bndl ("Upload 115 files", commit 4f5540c)
import torch
from torch_geometric.nn import SAGEConv, GATConv, Sequential, BatchNorm
from torch_geometric.nn import SAGPooling
class PolymerGNN_IV_EXPLAIN(torch.nn.Module):
    """Explainer-friendly intrinsic-viscosity (IV) polymer GNN.

    Two separately-weighted graph encoders embed the acid and glycol
    monomer graphs; each node-embedding set is max-pooled into a single
    vector, concatenated (optionally with extra features), and passed
    through a two-layer head.  Inputs arrive as flat tensors so that
    explanation tooling can attribute every argument individually.
    """

    def __init__(self, input_feat, hidden_channels, num_additional = 0):
        super().__init__()
        self.hidden_channels = hidden_channels
        # Identical encoder architectures, independent parameters.
        self.Asage = self._make_encoder(input_feat, hidden_channels)
        self.Gsage = self._make_encoder(input_feat, hidden_channels)
        # Prediction head over the [acid | glycol | additional] concatenation.
        self.fc1 = torch.nn.Linear(hidden_channels * 2 + num_additional, hidden_channels)
        self.leaky1 = torch.nn.PReLU()
        self.fc2 = torch.nn.Linear(hidden_channels, 1)

    @staticmethod
    def _make_encoder(input_feat, hidden_channels):
        # GAT -> BatchNorm -> PReLU -> SAGE -> BatchNorm -> PReLU -> SAGPooling
        return Sequential('x, edge_index, batch', [
            (GATConv(input_feat, hidden_channels, aggr = 'max'), 'x, edge_index -> x'),
            BatchNorm(hidden_channels, track_running_stats=False),
            torch.nn.PReLU(),
            (SAGEConv(hidden_channels, hidden_channels, aggr = 'max'), 'x, edge_index -> x'),
            BatchNorm(hidden_channels, track_running_stats=False),
            torch.nn.PReLU(),
            (SAGPooling(hidden_channels), 'x, edge_index, batch=batch -> x'),
        ])

    def forward(self,
            Abatch_X: torch.Tensor,
            Abatch_edge_index: torch.Tensor,
            Abatch_batch: torch.Tensor,
            Gbatch_X: torch.Tensor,
            Gbatch_edge_index: torch.Tensor,
            Gbatch_batch: torch.Tensor,
            add_features: torch.Tensor):
        """Return the predicted IV for one sample.

        Each encoder ends in SAGPooling, which returns a tuple; element
        0 holds the node features that survive the pooling step.
        """
        acid_nodes = self.Asage(Abatch_X, Abatch_edge_index, Abatch_batch)[0]
        glycol_nodes = self.Gsage(Gbatch_X, Gbatch_edge_index, Gbatch_batch)[0]
        # Global max-pool across the node dimension (dim 0).
        acid_vec = acid_nodes.max(dim=0).values
        glycol_vec = glycol_nodes.max(dim=0).values
        parts = [acid_vec, glycol_vec]
        if add_features is not None:
            parts.append(add_features)
        hidden = self.leaky1(self.fc1(torch.cat(parts)))
        log_pred = self.fc2(hidden)
        # Target is modeled on a log scale, so exponentiate before returning.
        return torch.exp(log_pred)
class PolymerGNN_IVMono_EXPLAIN(torch.nn.Module):
    """Mono-encoder IV polymer GNN for explanation workflows.

    A single shared encoder (``self.sage``) embeds both the acid and
    the glycol graphs, so the two monomer families share weights.  The
    head output is returned as-is (no exponentiation).
    """

    def __init__(self, input_feat, hidden_channels, num_additional = 0):
        super().__init__()
        self.hidden_channels = hidden_channels
        # One encoder, applied to both monomer graphs.
        self.sage = Sequential('x, edge_index, batch', [
            (GATConv(input_feat, hidden_channels, aggr = 'max'), 'x, edge_index -> x'),
            BatchNorm(hidden_channels, track_running_stats=False),
            torch.nn.PReLU(),
            (SAGEConv(hidden_channels, hidden_channels, aggr = 'max'), 'x, edge_index -> x'),
            BatchNorm(hidden_channels, track_running_stats=False),
            torch.nn.PReLU(),
            (SAGPooling(hidden_channels), 'x, edge_index, batch=batch -> x'),
        ])
        # Prediction head over the [acid | glycol | additional] concatenation.
        self.fc1 = torch.nn.Linear(hidden_channels * 2 + num_additional, hidden_channels)
        self.leaky1 = torch.nn.PReLU()
        self.fc2 = torch.nn.Linear(hidden_channels, 1)

    def forward(self,
            Abatch_X: torch.Tensor,
            Abatch_edge_index: torch.Tensor,
            Abatch_batch: torch.Tensor,
            Gbatch_X: torch.Tensor,
            Gbatch_edge_index: torch.Tensor,
            Gbatch_batch: torch.Tensor,
            add_features: torch.Tensor):
        """Return the raw head output for one sample.

        The shared encoder ends in SAGPooling, which returns a tuple;
        element 0 holds the surviving node features.
        """
        acid_nodes = self.sage(Abatch_X, Abatch_edge_index, Abatch_batch)[0]
        glycol_nodes = self.sage(Gbatch_X, Gbatch_edge_index, Gbatch_batch)[0]
        # Global max-pool across the node dimension (dim 0).
        acid_vec = acid_nodes.max(dim=0).values
        glycol_vec = glycol_nodes.max(dim=0).values
        parts = [acid_vec, glycol_vec]
        if add_features is not None:
            parts.append(add_features)
        hidden = self.leaky1(self.fc1(torch.cat(parts)))
        # NOTE(review): the original comment claimed a log-scale target,
        # but no exp() is applied here — the raw output is returned.
        return self.fc2(hidden)
class PolymerGNN_Tg_EXPLAIN(torch.nn.Module):
    """Glass-transition-temperature (Tg) polymer GNN, explainer variant.

    Same dual-encoder backbone as the IV model, but the head combines an
    exponentiated prediction with a tanh-squashed multiplicative factor:
    ``exp(fc2(h)) * tanh(mult_factor(h))``.
    """

    def __init__(self, input_feat, hidden_channels, num_additional = 0):
        super().__init__()
        self.hidden_channels = hidden_channels
        # Identical encoder architectures, independent parameters.
        self.Asage = self._make_encoder(input_feat, hidden_channels)
        self.Gsage = self._make_encoder(input_feat, hidden_channels)
        # Prediction head over the [acid | glycol | additional] concatenation.
        self.fc1 = torch.nn.Linear(hidden_channels * 2 + num_additional, hidden_channels)
        self.leaky1 = torch.nn.PReLU()
        self.fc2 = torch.nn.Linear(hidden_channels, 1)
        # Learned multiplicative correction, squashed to (-1, 1) in forward().
        self.mult_factor = torch.nn.Linear(hidden_channels, 1)

    @staticmethod
    def _make_encoder(input_feat, hidden_channels):
        # GAT -> BatchNorm -> PReLU -> SAGE -> BatchNorm -> PReLU -> SAGPooling
        return Sequential('x, edge_index, batch', [
            (GATConv(input_feat, hidden_channels, aggr = 'max'), 'x, edge_index -> x'),
            BatchNorm(hidden_channels, track_running_stats=False),
            torch.nn.PReLU(),
            (SAGEConv(hidden_channels, hidden_channels, aggr = 'max'), 'x, edge_index -> x'),
            BatchNorm(hidden_channels, track_running_stats=False),
            torch.nn.PReLU(),
            (SAGPooling(hidden_channels), 'x, edge_index, batch=batch -> x'),
        ])

    def forward(self,
            Abatch_X: torch.Tensor,
            Abatch_edge_index: torch.Tensor,
            Abatch_batch: torch.Tensor,
            Gbatch_X: torch.Tensor,
            Gbatch_edge_index: torch.Tensor,
            Gbatch_batch: torch.Tensor,
            add_features: torch.Tensor):
        """Return the predicted Tg for one sample.

        Each encoder ends in SAGPooling, which returns a tuple; element
        0 holds the node features that survive the pooling step.
        """
        acid_nodes = self.Asage(Abatch_X, Abatch_edge_index, Abatch_batch)[0]
        glycol_nodes = self.Gsage(Gbatch_X, Gbatch_edge_index, Gbatch_batch)[0]
        # Global max-pool across the node dimension (dim 0).
        acid_vec = acid_nodes.max(dim=0).values
        glycol_vec = glycol_nodes.max(dim=0).values
        parts = [acid_vec, glycol_vec]
        if add_features is not None:
            parts.append(add_features)
        hidden = self.leaky1(self.fc1(torch.cat(parts)))
        # Log-scale prediction rescaled by a bounded multiplicative factor.
        scale = torch.tanh(self.mult_factor(hidden))
        return torch.exp(self.fc2(hidden)) * scale