# ReCEP/src/bce/model/EGNN.py
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_scatter import scatter
from torch_geometric.nn import InstanceNorm


class EGNNLayer(nn.Module):
"""
    E(n)-equivariant GNN layer (EGNN, Satorras et al., 2021) with an optional
    feed-forward block and per-graph instance normalization.
Args:
input_nf: Number of input node features
output_nf: Number of output node features
hidden_nf: Number of hidden features
edges_in_d: Number of input edge features
act_fn: Activation function
residual: Whether to use residual connections
attention: Whether to use attention mechanism for edge features
normalize: Whether to normalize coordinates
coords_agg: Aggregation method for coordinates (mean, sum, max, min)
tanh: Whether to use tanh activation for coordinate updates
dropout: Dropout rate
ffn: Whether to use feed forward network
        batch_norm: Whether to apply per-graph InstanceNorm to node features
            and coordinates
        nodes_att_dim: Number of extra per-node features passed as node_attr
            in forward (0 disables them)
    """
    def __init__(self, input_nf, output_nf, hidden_nf,
                 edges_in_d=0, act_fn=nn.SiLU(),
                 residual=True, attention=False, normalize=False,
                 coords_agg='mean', tanh=False, dropout=0.0,
                 ffn=False, batch_norm=True, nodes_att_dim=0):
super().__init__()
self.input_nf = input_nf
self.output_nf = output_nf
self.hidden_nf = hidden_nf
self.residual = residual
self.attention = attention
self.normalize = normalize
self.coords_agg = coords_agg
self.tanh = tanh
self.epsilon = 1e-8
self.dropout = dropout
self.ffn = ffn
self.batch_norm = batch_norm
# Edge MLP
in_edge = input_nf*2 + 1 + edges_in_d
self.edge_mlp = nn.Sequential(
nn.Linear(in_edge, hidden_nf),
act_fn, nn.Dropout(dropout),
nn.Linear(hidden_nf, hidden_nf),
act_fn, nn.Dropout(dropout),
)
        if attention:
            self.att_mlp = nn.Sequential(nn.Linear(hidden_nf, 1), nn.Sigmoid())
        # Coord MLP: maps each edge message to a single scalar weight
        layer = nn.Linear(hidden_nf, 1, bias=False)
        nn.init.xavier_uniform_(layer.weight, gain=0.001)
        coord_blocks = [nn.Linear(hidden_nf, hidden_nf), act_fn,
                        nn.Dropout(dropout), layer]
        if tanh:
            coord_blocks.append(nn.Tanh())
        self.coord_mlp = nn.Sequential(*coord_blocks)
        # Node MLP (input width accounts for optional node_attr features)
        self.node_mlp = nn.Sequential(
            nn.Linear(hidden_nf + input_nf + nodes_att_dim, hidden_nf),
            act_fn, nn.Dropout(dropout),
            nn.Linear(hidden_nf, output_nf),
        )
# per-graph normalization
if batch_norm:
self.norm_node = InstanceNorm(output_nf, affine=True)
self.norm_coord = InstanceNorm(3, affine=True)
# FFN
if ffn:
self.ff1 = nn.Linear(output_nf, output_nf*2)
self.ff2 = nn.Linear(output_nf*2, output_nf)
self.act_ff = act_fn
self.drop_ff = nn.Dropout(dropout)
if batch_norm:
self.norm_ff1 = InstanceNorm(output_nf, affine=True)
self.norm_ff2 = InstanceNorm(output_nf, affine=True)

    def coord2radial(self, edge_index, coord):
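        """Squared edge lengths and coordinate offsets for each edge.

        Returns:
            dist2: Clamped squared distances, shape [E, 1].
            diff: Offsets coord[row] - coord[col], shape [E, 3]; scaled to
                (approximately) unit length when self.normalize is True.
        """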
row, col = edge_index
diff = coord[row] - coord[col]
dist2 = (diff**2).sum(dim=-1, keepdim=True)
# Clamp distance to prevent extreme values
dist2 = torch.clamp(dist2, min=self.epsilon, max=100.0)
if self.normalize:
norm = (dist2.sqrt().detach() + self.epsilon)
diff = diff / norm
# Check for NaN/Inf in normalized diff
diff = torch.where(torch.isfinite(diff), diff, torch.zeros_like(diff))
return dist2, diff

    def _ff_block(self, x):
        """Position-wise feed-forward block with dropout."""
        x = self.drop_ff(self.act_ff(self.ff1(x)))
        return self.ff2(x)

    def forward(self, h, coord, edge_index, batch, edge_attr=None, node_attr=None):
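        """One E(n)-equivariant message-passing step.

        Args:
            h: Node features, shape [N, input_nf].
            coord: Node coordinates, shape [N, 3].
            edge_index: Edge indices (row, col), shape [2, E].
            batch: Graph id per node, shape [N]; used by InstanceNorm.
            edge_attr: Optional edge features, shape [E, edges_in_d].
            node_attr: Optional extra node features, shape [N, nodes_att_dim].

        Returns:
            Tuple of updated node features [N, output_nf], updated
            coordinates [N, 3], and edge messages [E, hidden_nf].
        """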
row, col = edge_index
radial, coord_diff = self.coord2radial(edge_index, coord)
# -- edge features --
e_in = [h[row], h[col], radial]
        if edge_attr is not None:
            e_in.append(edge_attr)
e = torch.cat(e_in, dim=-1)
e = self.edge_mlp(e)
if self.attention:
att = self.att_mlp(e)
e = e * att
# -- coordinate update --
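        # EGNN rule: x_i <- x_i + agg_j (x_i - x_j) * phi_x(m_ij), where the
        # scalar weight phi_x(m_ij) comes from coord_mlp.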
coord_update = self.coord_mlp(e) # [E,1]
# Clamp coordinate updates to prevent explosion
coord_update = torch.clamp(coord_update, -1.0, 1.0)
trans = coord_diff * coord_update # [E,3]
# Check for NaN/Inf in coordinate updates
trans = torch.where(torch.isfinite(trans), trans, torch.zeros_like(trans))
agg_coord = scatter(trans, row, dim=0,
dim_size=coord.size(0),
reduce=self.coords_agg)
coord = coord + agg_coord
# Check for NaN/Inf in final coordinates
coord = torch.where(torch.isfinite(coord), coord, torch.zeros_like(coord))
if self.batch_norm:
coord = self.norm_coord(coord, batch)
# -- node update --
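        # Node rule: h_i <- phi_h(h_i, sum_j m_ij), plus a residual
        # connection when input and output widths match.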
agg_node = scatter(e, row, dim=0,
dim_size=h.size(0), reduce='sum')
x_in = torch.cat([h, agg_node], dim=-1)
if node_attr is not None:
x_in = torch.cat([x_in, node_attr], dim=-1)
h_new = self.node_mlp(x_in)
if self.batch_norm:
h_new = self.norm_node(h_new, batch)
if self.residual and h_new.shape[-1] == h.shape[-1]:
h_new = h + h_new
# -- optional FFN --
if self.ffn:
if self.batch_norm:
h_new = self.norm_ff1(h_new, batch)
h_new = h_new + self._ff_block(h_new)
if self.batch_norm:
h_new = self.norm_ff2(h_new, batch)
return h_new, coord, e


class EGNNLayer2(nn.Module):
"""
    E(n)-equivariant GNN layer (EGNN, Satorras et al., 2021) with an optional
    feed-forward block and per-graph instance normalization.
Args:
input_nf: Number of input node features
output_nf: Number of output node features
hidden_nf: Number of hidden features
edges_in_d: Number of input edge features
act_fn: Activation function
residual: Whether to use residual connections
attention: Whether to use attention mechanism for edge features
normalize: Whether to normalize coordinates
coords_agg: Aggregation method for coordinates (mean, sum, max, min)
tanh: Whether to use tanh activation for coordinate updates
dropout: Dropout rate
ffn: Whether to use feed forward network
        batch_norm: Whether to apply per-graph InstanceNorm to node features
            and coordinates
        nodes_att_dim: Number of extra per-node features passed as node_attr
            in forward (0 disables them)
    """
    def __init__(self, input_nf, output_nf, hidden_nf,
                 edges_in_d=0, act_fn=nn.SiLU(),
                 residual=True, attention=False, normalize=False,
                 coords_agg='mean', tanh=False, dropout=0.0,
                 ffn=False, batch_norm=True, nodes_att_dim=0):
super().__init__()
self.input_nf = input_nf
self.output_nf = output_nf
self.hidden_nf = hidden_nf
self.residual = residual
self.attention = attention
self.normalize = normalize
self.coords_agg = coords_agg
self.tanh = tanh
self.epsilon = 1e-8
self.dropout = dropout
self.ffn = ffn
self.batch_norm = batch_norm
# Edge MLP
in_edge = input_nf*2 + 1 + edges_in_d
self.edge_mlp = nn.Sequential(
nn.Linear(in_edge, hidden_nf),
act_fn, nn.Dropout(dropout),
nn.Linear(hidden_nf, hidden_nf),
act_fn, nn.Dropout(dropout),
)
        if attention:
            self.att_mlp = nn.Sequential(nn.Linear(hidden_nf, 1), nn.Sigmoid())
        # Coord MLP: maps each edge message to a single scalar weight
        layer = nn.Linear(hidden_nf, 1, bias=False)
        nn.init.xavier_uniform_(layer.weight, gain=0.001)
        coord_blocks = [nn.Linear(hidden_nf, hidden_nf), act_fn,
                        nn.Dropout(dropout), layer]
        if tanh:
            coord_blocks.append(nn.Tanh())
        self.coord_mlp = nn.Sequential(*coord_blocks)
        # Node MLP (input width accounts for optional node_attr features)
        self.node_mlp = nn.Sequential(
            nn.Linear(hidden_nf + input_nf + nodes_att_dim, hidden_nf),
            act_fn, nn.Dropout(dropout),
            nn.Linear(hidden_nf, output_nf),
        )
# per-graph normalization
if batch_norm:
self.norm_node = InstanceNorm(output_nf, affine=True)
self.norm_coord = InstanceNorm(3, affine=True)
# FFN
if ffn:
self.ff1 = nn.Linear(output_nf, output_nf*2)
self.ff2 = nn.Linear(output_nf*2, output_nf)
self.act_ff = act_fn
self.drop_ff = nn.Dropout(dropout)
if batch_norm:
self.norm_ff1 = InstanceNorm(output_nf, affine=True)
self.norm_ff2 = InstanceNorm(output_nf, affine=True)

    def coord2radial(self, edge_index, coord):
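        """Squared edge lengths and coordinate offsets for each edge.

        Returns:
            dist2: Clamped squared distances, shape [E, 1].
            diff: Offsets coord[row] - coord[col], shape [E, 3]; scaled to
                (approximately) unit length when self.normalize is True.
        """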
row, col = edge_index
diff = coord[row] - coord[col]
dist2 = (diff**2).sum(dim=-1, keepdim=True)
# Clamp distance to prevent extreme values
dist2 = torch.clamp(dist2, min=self.epsilon, max=100.0)
if self.normalize:
norm = (dist2.sqrt().detach() + self.epsilon)
diff = diff / norm
# Check for NaN/Inf in normalized diff
diff = torch.where(torch.isfinite(diff), diff, torch.zeros_like(diff))
return dist2, diff

    def _ff_block(self, x):
        """Position-wise feed-forward block with dropout."""
        x = self.drop_ff(self.act_ff(self.ff1(x)))
        return self.ff2(x)

    def forward(self, h, coord, edge_index, batch, edge_attr=None, node_attr=None):
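        """One E(n)-equivariant message-passing step.

        Args:
            h: Node features, shape [N, input_nf].
            coord: Node coordinates, shape [N, 3].
            edge_index: Edge indices (row, col), shape [2, E].
            batch: Graph id per node, shape [N]; used by InstanceNorm.
            edge_attr: Optional edge features, shape [E, edges_in_d].
            node_attr: Optional extra node features, shape [N, nodes_att_dim].

        Returns:
            Tuple of updated node features [N, output_nf], updated
            coordinates [N, 3], and edge messages [E, hidden_nf].
        """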
row, col = edge_index
radial, coord_diff = self.coord2radial(edge_index, coord)
# -- edge features --
e_in = [h[row], h[col], radial]
        if edge_attr is not None:
            e_in.append(edge_attr)
e = torch.cat(e_in, dim=-1)
e = self.edge_mlp(e)
if self.attention:
att = self.att_mlp(e)
e = e * att
# -- coordinate update --
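        # EGNN rule: x_i <- x_i + agg_j (x_i - x_j) * phi_x(m_ij), where the
        # scalar weight phi_x(m_ij) comes from coord_mlp.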
coord_update = self.coord_mlp(e) # [E,1]
# Clamp coordinate updates to prevent explosion
coord_update = torch.clamp(coord_update, -1.0, 1.0)
trans = coord_diff * coord_update # [E,3]
# Check for NaN/Inf in coordinate updates
trans = torch.where(torch.isfinite(trans), trans, torch.zeros_like(trans))
agg_coord = scatter(trans, row, dim=0,
dim_size=coord.size(0),
reduce=self.coords_agg)
coord = coord + agg_coord
# Check for NaN/Inf in final coordinates
coord = torch.where(torch.isfinite(coord), coord, torch.zeros_like(coord))
if self.batch_norm:
coord = self.norm_coord(coord, batch)
# -- node update --
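        # Node rule: h_i <- phi_h(h_i, sum_j m_ij), plus a residual
        # connection when input and output widths match.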
agg_node = scatter(e, row, dim=0,
dim_size=h.size(0), reduce='sum')
x_in = torch.cat([h, agg_node], dim=-1)
if node_attr is not None:
x_in = torch.cat([x_in, node_attr], dim=-1)
h_new = self.node_mlp(x_in)
if self.batch_norm:
h_new = self.norm_node(h_new, batch)
if self.residual and h_new.shape[-1] == h.shape[-1]:
h_new = h + h_new
# -- optional FFN --
if self.ffn:
if self.batch_norm:
h_new = self.norm_ff1(h_new, batch)
h_new = h_new + self._ff_block(h_new)
if self.batch_norm:
h_new = self.norm_ff2(h_new, batch)
return h_new, coord, e
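

if __name__ == "__main__":
    # Minimal smoke-test sketch: the toy graph, feature widths, and
    # hyperparameters below are illustrative assumptions, not values taken
    # from the ReCEP pipeline.
    torch.manual_seed(0)
    num_nodes, num_edges = 8, 24
    h = torch.randn(num_nodes, 32)
    coord = torch.randn(num_nodes, 3)
    edge_index = torch.randint(0, num_nodes, (2, num_edges))
    batch = torch.zeros(num_nodes, dtype=torch.long)  # all nodes in one graph

    layer = EGNNLayer(input_nf=32, output_nf=32, hidden_nf=64,
                      attention=True, ffn=True)
    h_out, coord_out, e_out = layer(h, coord, edge_index, batch)
    print(h_out.shape, coord_out.shape, e_out.shape)
    # -> torch.Size([8, 32]) torch.Size([8, 3]) torch.Size([24, 64])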