import torch
import torch.nn as nn
import torch.nn.functional as F


class DihedralFeatures(nn.Module):
    """Embed backbone dihedral angles as per-residue node features.

    Three dihedral angles per residue are lifted to the unit circle
    (cos and sin of each, 6 values), linearly projected to
    ``node_embed_dim`` and normalized with ``Normalize``.
    """

    def __init__(self, node_embed_dim):
        """Embed dihedral angle features.

        Args:
            node_embed_dim: Size of the output node embedding.
        """
        super().__init__()
        # 3 dihedral angles; sin and cos of each angle
        node_in = 6
        # Normalization and embedding
        self.node_embedding = nn.Linear(node_in, node_embed_dim, bias=True)
        self.norm_nodes = Normalize(node_embed_dim)

    def forward(self, X):
        """Featurize coordinates into node embeddings.

        Args:
            X: Coordinate tensor accepted by ``_dihedrals``; for 4-D input
                the expected shape is (B, L, n_atoms, 3).

        Returns:
            Node embeddings whose last dimension is ``node_embed_dim``.
            For 4-D input this is (B, L, node_embed_dim), except that a
            size-1 dimension 1 is squeezed away.
        """
        # Angle extraction is fixed preprocessing: no gradient flows back
        # into the coordinates. The learnable embedding/normalization below
        # must stay outside this context so their parameters are trained.
        with torch.no_grad():
            V = self._dihedrals(X)
            V = V.squeeze(1)
        V = self.node_embedding(V)
        V = self.norm_nodes(V)
        return V

    @staticmethod
    def _dihedrals(X, eps=1e-7):
        """Compute per-residue dihedral angles lifted to the unit circle.

        Args:
            X: For 4-D input, coordinates of shape (B, L, n_atoms, 3) whose
                first three atoms per residue are [N, CA, C] (protein) or
                [C4', C1', N1/N9] (RNA). Any other rank takes the fallback
                branch below.
            eps: Clamp margin keeping the cosine strictly inside (-1, 1)
                before ``acos``.

        Returns:
            For 4-D input, a tensor of shape (B, L, 6) laid out as
            [cos(a0), cos(a1), cos(a2), sin(a0), sin(a1), sin(a2)].
            Angles that cannot be computed at the chain ends are zero-padded
            (cos=1, sin=0).
        """
        # First 3 coordinates are [N, CA, C] / [C4', C1', N1/N9]
        if X.dim() == 4:
            # Flatten the first three atoms of every residue into a single
            # chain of 3L points per batch item.
            X = X[..., :3, :].reshape(X.shape[0], 3 * X.shape[1], 3)
        else:
            # NOTE(review): this keeps at most 3 points per row, so no full
            # dihedral (which needs 4 consecutive points) exists; every angle
            # comes from the zero padding below and the features are constant
            # (cos=1, sin=0). Confirm this path is intended.
            X = X[:, :3, :]

        # Shifted slices of unit bond vectors between consecutive points.
        dX = X[:, 1:, :] - X[:, :-1, :]
        U = F.normalize(dX, dim=-1)
        u_2 = U[:, :-2, :]
        u_1 = U[:, 1:-1, :]
        u_0 = U[:, 2:, :]

        # Normals of the planes spanned by consecutive bond pairs.
        n_2 = F.normalize(torch.cross(u_2, u_1, dim=-1), dim=-1)
        n_1 = F.normalize(torch.cross(u_1, u_0, dim=-1), dim=-1)

        # Unsigned angle between the normals; the clamp guards acos against
        # values pushed slightly outside [-1, 1] by rounding. The sign comes
        # from the side of the second plane the first bond lies on (a sign
        # of exactly 0 yields an angle of 0).
        cosD = (n_2 * n_1).sum(-1)
        cosD = torch.clamp(cosD, -1 + eps, 1 - eps)
        D = torch.sign((u_2 * n_1).sum(-1)) * torch.acos(cosD)

        # This scheme will remove phi[0], psi[-1], omega[-1]
        D = F.pad(D, (1, 2), 'constant', 0)
        # Group into 3 angles per residue. For 4-D input the padded width is
        # exactly 3L, so integer division is exact.
        D = D.view((D.size(0), D.size(1) // 3, 3))

        # Lift angle representations to the circle
        D_features = torch.cat((torch.cos(D), torch.sin(D)), 2)
        return D_features


class Normalize(nn.Module):
    """Layer normalization with a learnable per-feature gain and bias."""

    def __init__(self, features, epsilon=1e-6):
        """Create the gain/bias parameters.

        Args:
            features: Length of the normalized dimension (size of gain/bias).
            epsilon: Small constant added for numerical stability.
        """
        super().__init__()
        self.gain = nn.Parameter(torch.ones(features))
        self.bias = nn.Parameter(torch.zeros(features))
        self.epsilon = epsilon

    def forward(self, x, dim=-1):
        """Standardize ``x`` along ``dim``, then apply gain and bias.

        Args:
            x: Input tensor; ``x.size(dim)`` is expected to equal ``features``.
            dim: Dimension to normalize over (default: last).

        Returns:
            Tensor with the same shape as ``x``.
        """
        mu = x.mean(dim, keepdim=True)
        # Tensor.var uses the unbiased (N - 1) estimator by default.
        sigma = torch.sqrt(x.var(dim, keepdim=True) + self.epsilon)
        gain = self.gain
        bias = self.bias
        # Reshape so gain/bias broadcast along `dim` when it is not the
        # last axis.
        if dim != -1:
            shape = [1] * mu.dim()
            shape[dim] = gain.size(0)
            gain = gain.view(shape)
            bias = bias.view(shape)
        # epsilon is applied both inside the sqrt and in the denominator;
        # kept as-is so outputs match the original formulation.
        return gain * (x - mu) / (sigma + self.epsilon) + bias