import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_scatter import scatter
from torch_geometric.nn import InstanceNorm


class EGNNLayer(nn.Module):
    """
    EGNN layer with an optional feed-forward block and normalization.

    Args:
        input_nf: Number of input node features
        output_nf: Number of output node features
        hidden_nf: Number of hidden features
        edges_in_d: Number of input edge features
        act_fn: Activation function
        residual: Whether to use residual connections
        attention: Whether to use an attention mechanism for edge features
        normalize: Whether to normalize coordinate differences
        coords_agg: Aggregation method for coordinate updates ('mean', 'sum', 'max', 'min')
        tanh: Whether to apply a tanh activation to coordinate updates
        dropout: Dropout rate
        ffn: Whether to use a feed-forward block after the node update
        batch_norm: Whether to normalize node features and coordinates (graph-wise InstanceNorm)
    """
    def __init__(self, input_nf, output_nf, hidden_nf,
                 edges_in_d=0, act_fn=nn.SiLU(),
                 residual=True, attention=False, normalize=False,
                 coords_agg='mean', tanh=False, dropout=0.0,
                 ffn=False, batch_norm=True):
        super().__init__()
        self.input_nf = input_nf
        self.output_nf = output_nf
        self.hidden_nf = hidden_nf
        self.residual = residual
        self.attention = attention
        self.normalize = normalize
        self.coords_agg = coords_agg
        self.tanh = tanh
        self.epsilon = 1e-8
        self.dropout = dropout
        self.ffn = ffn
        self.batch_norm = batch_norm

        # Edge MLP: maps [h_i, h_j, squared distance, edge_attr] to edge messages.
        in_edge = input_nf * 2 + 1 + edges_in_d
        self.edge_mlp = nn.Sequential(
            nn.Linear(in_edge, hidden_nf),
            act_fn, nn.Dropout(dropout),
            nn.Linear(hidden_nf, hidden_nf),
            act_fn, nn.Dropout(dropout),
        )
        if attention:
            # Scalar gate in (0, 1) applied to each edge message.
            self.att_mlp = nn.Sequential(nn.Linear(hidden_nf, 1), nn.Sigmoid())

        # Coordinate MLP: projects edge messages to a scalar weight per edge.
        # The final layer uses a small initialization gain to keep early
        # coordinate updates small.
        layer = nn.Linear(hidden_nf, 1, bias=False)
        nn.init.xavier_uniform_(layer.weight, gain=0.001)
        coord_blocks = [nn.Linear(hidden_nf, hidden_nf), act_fn,
                        nn.Dropout(dropout), layer]
        if tanh:
            coord_blocks.append(nn.Tanh())
        self.coord_mlp = nn.Sequential(*coord_blocks)

        # Node MLP: updates node features from [h_i, aggregated edge messages].
        # Note: this input size assumes node_attr is not passed to forward();
        # concatenating node_attr there requires widening the first layer.
        self.node_mlp = nn.Sequential(
            nn.Linear(hidden_nf + input_nf, hidden_nf),
            act_fn, nn.Dropout(dropout),
            nn.Linear(hidden_nf, output_nf),
        )

        # Graph-wise instance normalization for node features and coordinates.
        if batch_norm:
            self.norm_node = InstanceNorm(output_nf, affine=True)
            self.norm_coord = InstanceNorm(3, affine=True)

        # Optional position-wise feed-forward block with its own normalization.
        if ffn:
            self.ff1 = nn.Linear(output_nf, output_nf * 2)
            self.ff2 = nn.Linear(output_nf * 2, output_nf)
            self.act_ff = act_fn
            self.drop_ff = nn.Dropout(dropout)
            if batch_norm:
                self.norm_ff1 = InstanceNorm(output_nf, affine=True)
                self.norm_ff2 = InstanceNorm(output_nf, affine=True)

    def coord2radial(self, edge_index, coord):
        """Compute squared distances and coordinate differences per edge."""
        row, col = edge_index
        diff = coord[row] - coord[col]
        dist2 = (diff ** 2).sum(dim=-1, keepdim=True)

        # Clamp squared distances for numerical stability.
        dist2 = torch.clamp(dist2, min=self.epsilon, max=100.0)

        if self.normalize:
            # Normalize coordinate differences to roughly unit length; the norm
            # is detached so gradients do not flow through it.
            norm = dist2.sqrt().detach() + self.epsilon
            diff = diff / norm

        # Guard against NaN/Inf values.
        diff = torch.where(torch.isfinite(diff), diff, torch.zeros_like(diff))
        return dist2, diff

    def _ff_block(self, x):
        """Feed-forward block."""
        x = self.drop_ff(self.act_ff(self.ff1(x)))
        return self.ff2(x)

    def forward(self, h, coord, edge_index, batch, edge_attr=None, node_attr=None):
        row, col = edge_index
        radial, coord_diff = self.coord2radial(edge_index, coord)

        # Edge messages from the two endpoint features and the squared distance.
        e_in = [h[row], h[col], radial]
        if edge_attr is not None:
            e_in.append(edge_attr)
        e = torch.cat(e_in, dim=-1)
        e = self.edge_mlp(e)
        if self.attention:
            att = self.att_mlp(e)
            e = e * att

        # Coordinate update: scale each edge's coordinate difference by a
        # clamped scalar weight, then aggregate per node.
        coord_update = self.coord_mlp(e)
        coord_update = torch.clamp(coord_update, -1.0, 1.0)
        trans = coord_diff * coord_update
        trans = torch.where(torch.isfinite(trans), trans, torch.zeros_like(trans))

        agg_coord = scatter(trans, row, dim=0,
                            dim_size=coord.size(0),
                            reduce=self.coords_agg)
        coord = coord + agg_coord
        coord = torch.where(torch.isfinite(coord), coord, torch.zeros_like(coord))

        if self.batch_norm:
            coord = self.norm_coord(coord, batch)

        # Node update: aggregate edge messages per node and combine with h.
        agg_node = scatter(e, row, dim=0,
                           dim_size=h.size(0), reduce='sum')
        x_in = torch.cat([h, agg_node], dim=-1)
        if node_attr is not None:
            # Requires node_mlp to be sized for the extra features (see __init__).
            x_in = torch.cat([x_in, node_attr], dim=-1)
        h_new = self.node_mlp(x_in)
        if self.batch_norm:
            h_new = self.norm_node(h_new, batch)
        if self.residual and h_new.shape[-1] == h.shape[-1]:
            h_new = h + h_new

        # Optional feed-forward block with a residual connection.
        if self.ffn:
            if self.batch_norm:
                h_new = self.norm_ff1(h_new, batch)
            h_new = h_new + self._ff_block(h_new)
            if self.batch_norm:
                h_new = self.norm_ff2(h_new, batch)

        return h_new, coord, e
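

# Minimal usage sketch (illustrative only): runs one EGNNLayer on a small random
# graph. The feature sizes and the random edge_index below are hypothetical
# example values, not part of the layer's API.
if __name__ == "__main__":
    num_nodes, num_edges = 6, 12
    h = torch.randn(num_nodes, 16)                     # node features
    coord = torch.randn(num_nodes, 3)                  # 3D coordinates
    edge_index = torch.randint(0, num_nodes, (2, num_edges))
    batch = torch.zeros(num_nodes, dtype=torch.long)   # all nodes in one graph
    layer = EGNNLayer(input_nf=16, output_nf=16, hidden_nf=32)
    h_out, coord_out, edge_feat = layer(h, coord, edge_index, batch)
    print(h_out.shape, coord_out.shape, edge_feat.shape)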