import torch
import torch.nn as nn
import torch.nn.functional as F


class DihedralFeatures(nn.Module):
    """Embed backbone dihedral angle features as node embeddings."""

    def __init__(self, node_embed_dim):
        super().__init__()
        # Six input features: (cos, sin) of the three backbone dihedral angles.
        node_in = 6
        self.node_embedding = nn.Linear(node_in, node_embed_dim, bias=True)
        self.norm_nodes = Normalize(node_embed_dim)

    def forward(self, X):
        """Featurize backbone coordinates as dihedral-angle node features."""
        # The dihedral computation is a fixed geometric transform, so no
        # gradients are needed through it; the embedding layers stay trainable.
        with torch.no_grad():
            V = self._dihedrals(X)
            V = V.squeeze(1)
        V = self.node_embedding(V)
        V = self.norm_nodes(V)
        return V

    @staticmethod
    def _dihedrals(X, eps=1e-7):
        # X: [batch, length, atoms, 3] or [batch, atoms, 3]. Keep the first
        # three backbone atoms (N, CA, C) and flatten them into a single
        # atom sequence along the length dimension.
        if len(X.shape) == 4:
            X = X[..., :3, :].reshape(X.shape[0], 3 * X.shape[1], 3)
        else:
            X = X[:, :3, :]

        # Unit bond vectors between consecutive backbone atoms
        dX = X[:, 1:, :] - X[:, :-1, :]
        U = F.normalize(dX, dim=-1)
        u_2 = U[:, :-2, :]
        u_1 = U[:, 1:-1, :]
        u_0 = U[:, 2:, :]

        # Normals of the two planes spanned by consecutive bond vectors
        n_2 = F.normalize(torch.cross(u_2, u_1, dim=-1), dim=-1)
        n_1 = F.normalize(torch.cross(u_1, u_0, dim=-1), dim=-1)

        # Signed dihedral angle between the two planes; clamp keeps acos stable
        cosD = (n_2 * n_1).sum(-1)
        cosD = torch.clamp(cosD, -1 + eps, 1 - eps)
        D = torch.sign((u_2 * n_1).sum(-1)) * torch.acos(cosD)

        # Pad so each residue gets three angles, then regroup per residue
        D = F.pad(D, (1, 2), 'constant', 0)
        D = D.view(D.size(0), D.size(1) // 3, 3)

        # Lift angles onto the unit circle: (cos, sin) per angle, 6 features
        D_features = torch.cat((torch.cos(D), torch.sin(D)), 2)
        return D_features


class Normalize(nn.Module):
    """Layer normalization with learnable per-feature gain and bias."""

    def __init__(self, features, epsilon=1e-6):
        super().__init__()
        self.gain = nn.Parameter(torch.ones(features))
        self.bias = nn.Parameter(torch.zeros(features))
        self.epsilon = epsilon

    def forward(self, x, dim=-1):
        mu = x.mean(dim, keepdim=True)
        sigma = torch.sqrt(x.var(dim, keepdim=True) + self.epsilon)
        gain = self.gain
        bias = self.bias
        # Reshape the per-feature parameters so they broadcast along `dim`
        if dim != -1:
            shape = [1] * len(mu.size())
            shape[dim] = self.gain.size()[0]
            gain = gain.view(shape)
            bias = bias.view(shape)
        return gain * (x - mu) / (sigma + self.epsilon) + bias
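

# Minimal usage sketch (an illustrative addition, not part of the original
# module): the shapes below are assumptions, with X as [batch, length,
# atoms, 3] backbone coordinates ordered N, CA, C, O per residue.
if __name__ == "__main__":
    batch, length, embed_dim = 2, 16, 128
    X = torch.randn(batch, length, 4, 3)  # random stand-in for real coordinates
    features = DihedralFeatures(embed_dim)
    V = features(X)
    print(V.shape)  # expected: torch.Size([2, 16, 128])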