# Source: ReCEP/src/bce/model/dihedral.py
# Uploaded by NielTT ("Upload 108 files", commit e611d1f, verified)
import torch
import torch.nn as nn
import torch.nn.functional as F
class DihedralFeatures(nn.Module):
    """Embed backbone dihedral angles as per-residue node features.

    For each residue the three backbone dihedrals (phi, psi, omega) are
    computed from the first three atoms of every residue
    ([N, CA, C] / [C4', C1', N1/N9]), lifted onto the unit circle as
    (cos, sin) pairs, then linearly embedded and normalized.
    """

    def __init__(self, node_embed_dim):
        """Embed dihedral angle features.

        Args:
            node_embed_dim: Size of the output embedding per residue.
        """
        super(DihedralFeatures, self).__init__()
        # 3 dihedral angles; sin and cos of each angle
        node_in = 6
        # Normalization and embedding
        self.node_embedding = nn.Linear(node_in, node_embed_dim, bias=True)
        self.norm_nodes = Normalize(node_embed_dim)

    def forward(self, X):
        """Featurize backbone coordinates into embedded dihedral features.

        Args:
            X: Coordinates of shape ``(L, A, 3)`` (one chain) or
                ``(B, L, A, 3)`` (batch of chains); only the first three
                atoms of each residue are used.

        Returns:
            ``(L, node_embed_dim)`` for 3-D input, ``(B, L, node_embed_dim)``
            for 4-D input. Note that ``squeeze(1)`` also drops the residue
            axis of a 4-D input when ``L == 1``.
        """
        # Angles are fixed inputs: no gradient flows back into the coordinates.
        with torch.no_grad():
            V = self._dihedrals(X)
        # 3-D input yields (L, 1, 6); drop the singleton axis -> (L, 6).
        V = V.squeeze(1)
        V = self.node_embedding(V)
        V = self.norm_nodes(V)
        return V

    @staticmethod
    def _dihedrals(X, eps=1e-7):
        """Return (cos, sin) of the phi/psi/omega dihedrals of each residue.

        Args:
            X: ``(B, L, A, 3)`` batch of chains or ``(L, A, 3)`` single chain,
                with ``A >= 3``; only the first three atoms per residue are used.
            eps: Margin keeping the cosine inside (-1, 1) so ``acos`` stays
                away from +-1, where its gradient is infinite.

        Returns:
            ``(B, L, 6)`` for 4-D input or ``(L, 1, 6)`` for 3-D input, laid out
            as ``[cos phi, cos psi, cos omega, sin phi, sin psi, sin omega]``.
            phi of the first residue and psi/omega of the last residue are
            undefined and reported as angle 0 (cos 1, sin 0).

        Raises:
            ValueError: If ``X`` is neither 3-D nor 4-D.
        """
        if X.dim() == 3:
            # Dihedrals span consecutive residues, so a single chain must be
            # processed as ONE chain (batch of 1). Slicing it as L independent
            # 3-atom "batches" produced only padding, i.e. constant features.
            # NOTE(review): models trained with that old behavior saw constant
            # inputs here; verify checkpoints against this fix.
            # Transpose keeps the historical (L, 1, 6) output shape.
            return DihedralFeatures._chain_dihedrals(X.unsqueeze(0), eps).transpose(0, 1)
        if X.dim() != 4:
            raise ValueError(
                f"expected coordinates of shape (L, A, 3) or (B, L, A, 3), got {tuple(X.shape)}"
            )
        return DihedralFeatures._chain_dihedrals(X, eps)

    @staticmethod
    def _chain_dihedrals(X, eps):
        """Compute dihedral features for a batch of chains: (B, L, A, 3) -> (B, L, 6)."""
        # First 3 coordinates are [N, CA, C] / [C4', C1', N1/N9]; flatten them
        # into one backbone trace of 3L points per chain.
        X = X[..., :3, :].reshape(X.shape[0], 3 * X.shape[1], 3)
        # Unit vectors between consecutive backbone atoms (shifted slices).
        dX = X[:, 1:, :] - X[:, :-1, :]
        U = F.normalize(dX, dim=-1)
        u_2 = U[:, :-2, :]
        u_1 = U[:, 1:-1, :]
        u_0 = U[:, 2:, :]
        # Normals of the planes spanned by consecutive bond-vector pairs.
        n_2 = F.normalize(torch.cross(u_2, u_1, dim=-1), dim=-1)
        n_1 = F.normalize(torch.cross(u_1, u_0, dim=-1), dim=-1)
        # Unsigned angle between the normals; the sign comes from u_2 . n_1.
        cosD = torch.clamp((n_2 * n_1).sum(-1), -1 + eps, 1 - eps)
        D = torch.sign((u_2 * n_1).sum(-1)) * torch.acos(cosD)
        # 3L - 3 angles -> pad to 3L; this leaves phi[0], psi[-1], omega[-1] = 0.
        D = F.pad(D, (1, 2), 'constant', 0)
        D = D.view(D.size(0), D.size(1) // 3, 3)
        # Lift angle representations to the circle.
        return torch.cat((torch.cos(D), torch.sin(D)), 2)
class Normalize(nn.Module):
    """Layer-norm-style normalization with a learnable per-feature gain and bias.

    Inputs are centered and scaled along ``dim`` using the (unbiased)
    ``Tensor.var`` estimate. ``epsilon`` is added both inside the square root
    and to the resulting standard deviation, then ``gain``/``bias`` are applied.
    """

    def __init__(self, features, epsilon=1e-6):
        super().__init__()
        # Learnable affine parameters, one entry per feature.
        self.gain = nn.Parameter(torch.ones(features))
        self.bias = nn.Parameter(torch.zeros(features))
        self.epsilon = epsilon

    def forward(self, x, dim=-1):
        """Normalize ``x`` along ``dim`` and apply the affine parameters."""
        gain, bias = self.gain, self.bias
        if dim != -1:
            # Shape the 1-D parameters so they broadcast along `dim` only.
            bshape = [1] * x.dim()
            bshape[dim] = gain.numel()
            gain = gain.view(bshape)
            bias = bias.view(bshape)
        centered = x - x.mean(dim, keepdim=True)
        std = torch.sqrt(x.var(dim, keepdim=True) + self.epsilon)
        return gain * centered / (std + self.epsilon) + bias