Create models/graph.py
models/graph.py
ADDED
@@ -0,0 +1,140 @@
import math

import torch
import torch.nn as nn
import torch.nn.functional as F


class ProteinGraph(nn.Module):
    def __init__(self, d_node, d_edge, d_position):
        super(ProteinGraph, self).__init__()
        self.d_node = d_node
        self.d_edge = d_edge
        self.d_position = d_position

        # Node input: ESM-2 embedding (1280) + VHSE8 descriptors (8) + positional embedding (d_position)
        d_node_original = 1280 + 8 + d_position
        self.node_mapping = nn.Linear(d_node_original, self.d_node)
        self.linear_edge = nn.Linear(1, d_edge)

        # VHSE8 physicochemical descriptors per amino acid
        vhse8_values = {
            'A': [0.15, -1.11, -1.35, -0.92, 0.02, -0.91, 0.36, -0.48],
            'R': [-1.47, 1.45, 1.24, 1.27, 1.55, 1.47, 1.30, 0.83],
            'N': [-0.99, 0.00, 0.69, -0.37, -0.55, 0.85, 0.73, -0.80],
            'D': [-1.15, 0.67, -0.41, -0.01, -2.68, 1.31, 0.03, 0.56],
            'C': [0.18, -1.67, -0.21, 0.00, 1.20, -1.61, -0.19, -0.41],
            'Q': [-0.96, 0.12, 0.18, 0.16, 0.09, 0.42, -0.20, -0.41],
            'E': [-1.18, 0.40, 0.10, 0.36, -2.16, -0.17, 0.91, 0.36],
            'G': [-0.20, -1.53, -2.63, 2.28, -0.53, -1.18, -1.34, 1.10],
            'H': [-0.43, -0.25, 0.37, 0.19, 0.51, 1.28, 0.93, 0.65],
            'I': [1.27, 0.14, 0.30, -1.80, 0.30, -1.61, -0.16, -0.13],
            'L': [1.36, 0.07, 0.26, -0.80, 0.22, -1.37, 0.08, -0.62],
            'K': [-1.17, 0.70, 0.80, 1.64, 0.67, 1.63, 0.13, -0.01],
            'M': [1.01, -0.53, 0.43, 0.00, 0.23, 0.10, -0.86, -0.68],
            'F': [1.52, 0.61, 0.95, -0.16, 0.25, 0.28, -1.33, -0.65],
            'P': [0.22, -0.17, -0.50, -0.05, 0.01, -1.34, 0.19, 3.56],
            'S': [-0.67, -0.86, -1.07, -0.41, -0.32, 0.27, -0.64, 0.11],
            'T': [-0.34, -0.51, -0.55, -1.06, 0.01, -0.01, -0.79, 0.39],
            'W': [1.50, 2.06, 1.79, 0.75, 0.75, 0.13, -1.06, -0.85],
            'Y': [0.61, 1.60, 1.17, 0.73, 0.53, 0.25, -0.96, -0.52],
            'V': [0.76, -0.92, 0.17, -1.91, 0.22, -1.40, -0.24, -0.03],
            'X': [0.15, -1.11, -1.35, -0.92, 0.02, -0.91, 0.36, -0.48],
            'B': [-1.15, 0.67, -0.41, -0.01, -2.68, 1.31, 0.03, 0.56],
        }

        # ESM alphabet token indices for the amino acids above
        aa_to_idx = {'A': 5, 'R': 10, 'N': 17, 'D': 13, 'C': 23, 'Q': 16, 'E': 9, 'G': 6, 'H': 21, 'I': 12, 'L': 4, 'K': 15, 'M': 20, 'F': 18, 'P': 14, 'S': 8, 'T': 11, 'W': 22, 'Y': 19, 'V': 7, 'X': 24, 'B': 25}

        # Lookup table indexed by ESM token id; rows for special tokens (cls, pad, eos, ...) stay zero
        self.vhse8_tensor = torch.zeros(26, 8)
        for aa, values in vhse8_values.items():
            aa_index = aa_to_idx[aa]
            self.vhse8_tensor[aa_index] = torch.tensor(values)
        self.vhse8_tensor.requires_grad = False
        # self.position_embedding = nn.Embedding(seq_len, self.d_position)

    # def one_hot_encoding(self, seq_len):
    #     positions = torch.arange(seq_len).unsqueeze(1)
    #     one_hot = torch.nn.functional.one_hot(positions, num_classes=seq_len).squeeze(1)
    #     return one_hot

    def create_sinusoidal_embeddings(self, seq_len, d_position):
        position = torch.arange(seq_len).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_position, 2) * -(math.log(10000.0) / d_position))
        pe = torch.zeros(seq_len, d_position)
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        pe = pe.unsqueeze(0)  # shape: (1, seq_len, d_position)
        return pe

    def add_cls_eos(self, tokens):
        """Re-insert <cls> (0) at the start and <eos> (2) before the first padding token (1) of each row."""
        modified_tensor = []

        for row in tokens:
            new_row = [0]  # prepend <cls>
            ones_indices = (row == 1).nonzero(as_tuple=True)[0]

            if len(ones_indices) > 0:
                # Insert <eos> just before the first padding token
                first_one_idx = ones_indices[0].item()
                new_row.extend(row[:first_one_idx].tolist())
                new_row.append(2)
                new_row.extend(row[first_one_idx:].tolist())
            else:
                # No padding in this row: append <eos> at the end
                new_row.extend(row.tolist())
                new_row.append(2)

            modified_tensor.append(torch.tensor(new_row))

        return torch.stack(modified_tensor)

    def forward(self, tokens, esm, alphabet):
        batch_size, seq_len = tokens.size()
        pad_mask = (tokens != alphabet.padding_idx).int()  # B*L
        device = tokens.device

        # ESM-2 embedding
        with torch.no_grad():
            esm_results = esm(tokens, repr_layers=[33], return_contacts=True)
        esm_embedding = esm_results["representations"][33]  # shape: B*L*1280
        esm_embedding = esm_embedding * pad_mask.unsqueeze(-1)

        # VHSE8 embedding
        vhse8_tensor = self.vhse8_tensor.to(device)
        vhse8_embedding = vhse8_tensor[tokens]  # shape: B*L*8

        # Sinusoidal positional embedding
        sin_embedding = self.create_sinusoidal_embeddings(seq_len, self.d_position).repeat(batch_size, 1, 1).to(device)  # shape: B*L*d_position
        sin_embedding = sin_embedding * pad_mask.unsqueeze(-1)

        # # One-hot position encoding
        # one_hot = torch.stack((self.one_hot_encoding(seq_len),) * batch_size)  # shape: B*L*L
        # one_hot_embedding = self.position_embedding(one_hot.view(-1, seq_len)).view(batch_size, seq_len, -1)  # shape: B*L*d_position
        # one_hot_embedding = one_hot_embedding * pad_mask.unsqueeze(-1)

        node_representation = torch.cat((esm_embedding, vhse8_embedding, sin_embedding), dim=-1)  # B*L*(1280+8+d_position)
        node_representation = self.node_mapping(node_representation)  # B*L*d_node

        # Edge representation
        with torch.no_grad():
            # add <cls> and <eos> back to the tokens for predicting contact maps
            esm_results = esm(self.add_cls_eos(tokens.cpu()).to(device), repr_layers=[33], return_contacts=True)

        contact_map = esm_results["contacts"]  # shape: B*L*L
        edge_representation = self.linear_edge(contact_map.unsqueeze(-1))  # shape: B*L*L*d_edge
        expanded_pad_mask = pad_mask.unsqueeze(1).expand(-1, seq_len, -1)
        edge_representation = edge_representation * expanded_pad_mask.unsqueeze(-1)
        # edge_representation = edge_representation * expanded_pad_mask.transpose(1, 2).unsqueeze(-1)

        return node_representation, edge_representation, pad_mask, expanded_pad_mask


if __name__ == '__main__':
    import esm

    model, alphabet = esm.pretrained.esm2_t33_650M_UR50D()

    tokens = torch.tensor([[5, 5, 5, 1], [5, 6, 7, 8]])
    seq_len = tokens.shape[1]
    graph = ProteinGraph(1024, 512, 64)
    node, edge, pad, _ = graph(tokens, model, alphabet)
    print(node.shape, edge.shape, pad.shape)
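For context, a minimal usage sketch (not part of this commit) of how tokens without <cls>/<eos> could be prepared with the standard ESM batch converter before calling ProteinGraph. The sequence strings, the chosen dimensions, and the `from models.graph import ProteinGraph` import path are illustrative assumptions; the surrounding training code may prepare tokens differently.

import torch
import esm

from models.graph import ProteinGraph

model, alphabet = esm.pretrained.esm2_t33_650M_UR50D()
batch_converter = alphabet.get_batch_converter()

# The ESM batch converter returns tokens as <cls> ... <eos> plus padding.
data = [("seq1", "MKTAYIAKQRQISFVK"), ("seq2", "KALTARQQEVF")]
_, _, batch_tokens = batch_converter(data)

# ProteinGraph.forward expects tokens *without* <cls>/<eos> (it re-inserts them via
# add_cls_eos before the contact-map pass), so strip <cls> and turn <eos> into padding.
tokens = batch_tokens[:, 1:].clone()                       # drop <cls>
tokens[tokens == alphabet.eos_idx] = alphabet.padding_idx  # treat <eos> as padding
tokens = tokens[:, :-1]                                    # drop the trailing all-padding column

graph = ProteinGraph(d_node=1024, d_edge=512, d_position=64)
node, edge, pad_mask, expanded_pad_mask = graph(tokens, model, alphabet)
print(node.shape, edge.shape, pad_mask.shape, expanded_pad_mask.shape)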