import math
from collections import defaultdict
from typing import Literal

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from rdkit import Chem
from scipy.sparse import coo_matrix
from torch_geometric.data import Data
from torch_geometric.nn import MessagePassing, TopKPooling
from torch_geometric.nn import global_mean_pool as gap, global_max_pool as gmp
from torch_geometric.utils import add_self_loops, remove_self_loops
class CoaDTIPro(nn.Module):
    def __init__(self, esm_model_and_alphabet, n_fingerprint, dim, n_word, layer_output, layer_coa,
                 nhead=8, dropout=0.1, co_attention: Literal['stack', 'encoder', 'inter'] = 'inter',
                 gcn_pooling=False):
        super().__init__()
        self.co_attention = co_attention
        self.layer_output = layer_output
        self.layer_coa = layer_coa
        self.embed_word = nn.Embedding(n_word, dim)
        self.gnn = GNN(n_fingerprint, gcn_pooling)
        self.esm_model, self.alphabet = esm_model_and_alphabet
        self.batch_converter = self.alphabet.get_batch_converter()
        self.W_attention = nn.Linear(dim, dim)
        self.W_out = nn.Sequential(
            nn.Linear(2 * dim, dim),
            nn.Linear(dim, 128),
            nn.Linear(128, 64)
        )
        self.coa_layers = CoAttention(dim, nhead, dropout, layer_coa, co_attention)
        self.lin = nn.Linear(768, 512)  # ESM: 768-dim tokens (BERT variant: 1024)
        self.W_interaction = nn.Linear(64, 2)
    def attention_cnn(self, x, xs, layer):
        """The attention mechanism is applied to the last layer of CNN.

        Note: this method is not used by forward() and expects CNN layers in
        self.W_cnn, which are not defined in this class.
        """
        xs = torch.unsqueeze(torch.unsqueeze(xs, 0), 0)
        for i in range(layer):
            xs = torch.relu(self.W_cnn[i](xs))
        xs = torch.squeeze(torch.squeeze(xs, 0), 0)
        h = torch.relu(self.W_attention(x))
        hs = torch.relu(self.W_attention(xs))
        weights = torch.tanh(F.linear(h, hs))
        ys = torch.t(weights) * hs
        return torch.unsqueeze(torch.mean(ys, 0), 0)
    def forward(self, inputs, proteins):
        """Compound vector with GNN."""
        compound_vector = self.gnn(inputs)
        compound_vector = torch.unsqueeze(compound_vector, 0)  # sequence-like GNN output

        """Protein vector with the frozen ESM encoder."""
        _, _, proteins = self.batch_converter([(None, protein) for protein in proteins])
        with torch.no_grad():
            results = self.esm_model(proteins.to(compound_vector.device), repr_layers=[6])
        token_representations = results["representations"][6]
        protein_vector = token_representations[:, 1:, :]  # drop the leading BOS token
        protein_vector = self.lin(torch.squeeze(protein_vector, 1))

        """Co-attention between the protein and compound representations."""
        protein_vector, compound_vector = self.coa_layers(protein_vector, compound_vector)
        protein_vector = protein_vector.mean(dim=1)
        compound_vector = compound_vector.mean(dim=1)

        """Concatenate the above two vectors and output the interaction."""
        cat_vector = torch.cat((compound_vector, protein_vector), 1)
        cat_vector = torch.tanh(self.W_out(cat_vector))
        interaction = self.W_interaction(cat_vector)
        return interaction
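

# Illustrative sketch (not part of the original module): one way CoaDTIPro might be
# wired up. The ESM checkpoint name below is an assumption -- a 6-layer, 768-dim
# ESM-1 model is consistent with repr_layers=[6] and nn.Linear(768, 512) above, but
# any (model, alphabet) pair with those dimensions would work. The vocabulary sizes
# follow the comments on the global dictionaries further below; the SMILES and the
# protein sequence are dummy examples.
#
#   import esm
#   model_and_alphabet = esm.pretrained.esm1_t6_43M_UR50S()  # assumed checkpoint
#   net = CoaDTIPro(model_and_alphabet, n_fingerprint=6341, dim=512, n_word=22,
#                   layer_output=3, layer_coa=1, co_attention='inter')
#   drug = drug_featurizer('CC(=O)Oc1ccccc1C(=O)O')           # defined further below
#   logits = net(drug, ['MKTAYIAKQRQISFVKSHFSRQLEERLGLIEVQ'])  # (1, 2) interaction logits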
class CoAttention(nn.Module):
    def __init__(self, dim, nhead, dropout, layer_coa, co_attention):
        super().__init__()
        self.co_attention = co_attention
        if self.co_attention == 'encoder':
            self.coa_layers = EncoderCrossAtt(dim, nhead, dropout, layer_coa)
        elif self.co_attention == 'stack':
            self.coa_layers = nn.ModuleList([StackCrossAtt(dim, nhead, dropout) for _ in range(layer_coa)])
        elif self.co_attention == 'inter':
            self.coa_layers = nn.ModuleList([InterCrossAtt(dim, nhead, dropout) for _ in range(layer_coa)])

    def forward(self, protein_vector, compound_vector):
        # protein_vector and compound_vector are the input tensors of the two modalities
        if self.co_attention == 'encoder':
            return self.coa_layers(protein_vector, compound_vector)
        else:
            # loop over the stacked layers, passing both modalities through each one
            for layer in self.coa_layers:
                protein_vector, compound_vector = layer(protein_vector, compound_vector)
            return protein_vector, compound_vector
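

# Illustrative sketch (not part of the original module): a shape check for the
# CoAttention wrapper in 'inter' mode on random 512-dim token sequences. Call it
# manually after the module has fully loaded; the classes it relies on
# (InterCrossAtt, SA, DPA, MHAtt) are defined further below.
def _coattention_shape_demo(dim=512, nhead=8):
    coa = CoAttention(dim, nhead, dropout=0.1, layer_coa=1, co_attention='inter')
    protein = torch.randn(1, 30, dim)   # (batch, protein length, dim)
    compound = torch.randn(1, 12, dim)  # (batch, number of atoms, dim)
    protein_out, compound_out = coa(protein, compound)
    # co-attention preserves the sequence lengths of both modalities
    assert protein_out.shape == protein.shape and compound_out.shape == compound.shape
    return protein_out, compound_out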
class EncoderCrossAtt(nn.Module):
    def __init__(self, dim, nhead, dropout, layers):
        super().__init__()
        # self.encoder_layers = nn.ModuleList([SEA(dim, dropout) for _ in range(layers)])
        self.encoder_layers = nn.ModuleList([SA(dim, nhead, dropout) for _ in range(layers)])
        self.decoder_sa = nn.ModuleList([SA(dim, nhead, dropout) for _ in range(layers)])
        self.decoder_coa = nn.ModuleList([DPA(dim, nhead, dropout) for _ in range(layers)])
        self.layer_coa = layers

    def forward(self, protein_vector, compound_vector):
        for i in range(self.layer_coa):
            compound_vector = self.encoder_layers[i](compound_vector, None)  # self-attention
        for i in range(self.layer_coa):
            protein_vector = self.decoder_sa[i](protein_vector, None)  # self-attention
            protein_vector = self.decoder_coa[i](protein_vector, compound_vector, None)  # co-attention
        return protein_vector, compound_vector
class InterCrossAtt(nn.Module):
    def __init__(self, dim, nhead, dropout):
        super().__init__()
        self.sca = SA(dim, nhead, dropout)
        self.spa = SA(dim, nhead, dropout)
        self.coa_pc = DPA(dim, nhead, dropout)
        self.coa_cp = DPA(dim, nhead, dropout)

    def forward(self, protein_vector, compound_vector):
        compound_vector = self.sca(compound_vector, None)  # self-attention
        protein_vector = self.spa(protein_vector, None)  # self-attention
        compound_covector = self.coa_pc(compound_vector, protein_vector, None)  # co-attention
        protein_covector = self.coa_cp(protein_vector, compound_vector, None)  # co-attention
        return protein_covector, compound_covector
class StackCrossAtt(nn.Module):
    def __init__(self, dim, nhead, dropout):
        super().__init__()
        self.sca = SA(dim, nhead, dropout)
        self.spa = SA(dim, nhead, dropout)
        self.coa_cp = DPA(dim, nhead, dropout)

    def forward(self, protein_vector, compound_vector):
        compound_vector = self.sca(compound_vector, None)  # self-attention
        protein_vector = self.spa(protein_vector, None)  # self-attention
        protein_covector = self.coa_cp(protein_vector, compound_vector, None)  # co-attention
        return protein_covector, compound_vector
class MHAtt(nn.Module):
    def __init__(self, hid_dim, n_heads, dropout):
        super().__init__()
        self.linear_v = nn.Linear(hid_dim, hid_dim)
        self.linear_k = nn.Linear(hid_dim, hid_dim)
        self.linear_q = nn.Linear(hid_dim, hid_dim)
        self.linear_merge = nn.Linear(hid_dim, hid_dim)
        self.hid_dim = hid_dim
        self.nhead = n_heads
        self.dropout = nn.Dropout(dropout)
        self.hidden_size_head = int(self.hid_dim / self.nhead)

    def forward(self, v, k, q, mask):
        n_batches = q.size(0)
        v = self.linear_v(v).view(n_batches, -1, self.nhead, self.hidden_size_head).transpose(1, 2)
        k = self.linear_k(k).view(n_batches, -1, self.nhead, self.hidden_size_head).transpose(1, 2)
        q = self.linear_q(q).view(n_batches, -1, self.nhead, self.hidden_size_head).transpose(1, 2)
        atted = self.att(v, k, q, mask)
        atted = atted.transpose(1, 2).contiguous().view(n_batches, -1, self.hid_dim)
        atted = self.linear_merge(atted)
        return atted

    def att(self, value, key, query, mask):
        d_k = query.size(-1)
        scores = torch.matmul(query, key.transpose(-2, -1)) / math.sqrt(d_k)
        if mask is not None:
            scores = scores.masked_fill(mask, -1e9)
        att_map = F.softmax(scores, dim=-1)
        att_map = self.dropout(att_map)
        return torch.matmul(att_map, value)
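

# Illustrative sketch (not part of the original module): exercises MHAtt on random
# tensors to document the expected shapes and the argument order. The optional
# boolean mask marks positions to suppress, per masked_fill above.
def _mhatt_shape_demo(hid_dim=512, n_heads=8):
    att = MHAtt(hid_dim, n_heads, dropout=0.1)
    q = torch.randn(2, 10, hid_dim)   # queries: (batch, query length, dim)
    kv = torch.randn(2, 25, hid_dim)  # keys/values: (batch, key length, dim)
    out = att(kv, kv, q, mask=None)   # argument order is (v, k, q, mask)
    # the output keeps the query length and the model dimension
    assert out.shape == (2, 10, hid_dim)
    return out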
class DPA(nn.Module):
    """Guided (cross) attention: queries from x attend over y, with a residual connection and LayerNorm."""

    def __init__(self, hid_dim, n_heads, dropout):
        super().__init__()
        self.mhatt1 = MHAtt(hid_dim, n_heads, dropout)
        self.dropout1 = nn.Dropout(dropout)
        self.norm1 = nn.LayerNorm(hid_dim)

    def forward(self, x, y, y_mask=None):
        x = self.norm1(x + self.dropout1(self.mhatt1(y, y, x, y_mask)))
        return x
class SA(nn.Module):
    """Self-attention block with a residual connection and LayerNorm."""

    def __init__(self, hid_dim, n_heads, dropout):
        super().__init__()
        self.mhatt1 = MHAtt(hid_dim, n_heads, dropout)
        self.dropout1 = nn.Dropout(dropout)
        self.norm1 = nn.LayerNorm(hid_dim)

    def forward(self, x, mask=None):
        x = self.norm1(x + self.dropout1(self.mhatt1(x, x, x, mask)))
        return x
class SAGEConv(MessagePassing):
    def __init__(self, in_channels, out_channels):
        super().__init__(aggr='max')  # "max" aggregation
        self.lin = torch.nn.Linear(in_channels, out_channels)
        self.act = torch.nn.ReLU()
        self.update_lin = torch.nn.Linear(in_channels + out_channels, in_channels, bias=False)
        self.update_act = torch.nn.ReLU()

    def forward(self, x, edge_index):
        # x has shape [N, in_channels]
        # edge_index has shape [2, E]
        edge_index, _ = remove_self_loops(edge_index)
        edge_index, _ = add_self_loops(edge_index, num_nodes=x.size(0))
        return self.propagate(edge_index, size=(x.size(0), x.size(0)), x=x)

    def message(self, x_j):
        # x_j has shape [E, in_channels]
        x_j = self.lin(x_j)
        x_j = self.act(x_j)
        return x_j

    def update(self, aggr_out, x):
        # aggr_out has shape [N, out_channels]
        new_embedding = torch.cat([aggr_out, x], dim=1)
        new_embedding = self.update_lin(new_embedding)
        new_embedding = self.update_act(new_embedding)
        return new_embedding
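

# Illustrative sketch (not part of the original module): runs the custom SAGEConv
# on a tiny 3-node path graph with random features. Note that update() projects the
# concatenated embedding back to in_channels, so node features keep shape
# [N, in_channels].
def _sageconv_demo(in_channels=128, out_channels=128):
    conv = SAGEConv(in_channels, out_channels)
    x = torch.randn(3, in_channels)                          # 3 nodes
    edge_index = torch.tensor([[0, 1, 1, 2], [1, 0, 2, 1]])  # undirected path 0-1-2
    out = conv(x, edge_index)
    assert out.shape == (3, in_channels)
    return out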
class GNN(nn.Module):
    def __init__(self, n_fingerprint, pooling, embed_dim=128):
        super().__init__()
        self.pooling = pooling
        self.embed_fingerprint = nn.Embedding(num_embeddings=n_fingerprint, embedding_dim=embed_dim)
        self.conv1 = SAGEConv(embed_dim, 128)
        self.pool1 = TopKPooling(128, ratio=0.8)
        self.conv2 = SAGEConv(128, 128)
        self.pool2 = TopKPooling(128, ratio=0.8)
        self.conv3 = SAGEConv(128, 128)
        self.pool3 = TopKPooling(128, ratio=0.8)
        self.linp1 = torch.nn.Linear(256, 128)
        self.linp2 = torch.nn.Linear(128, 512)
        self.lin = torch.nn.Linear(128, 512)
        self.bn1 = torch.nn.BatchNorm1d(128)
        self.bn2 = torch.nn.BatchNorm1d(64)
        self.act1 = torch.nn.ReLU()
        self.act2 = torch.nn.ReLU()

    def forward(self, data):
        x, edge_index, batch = data.x, data.edge_index, data.batch
        x = self.embed_fingerprint(x)
        x = x.squeeze(1)
        x = F.relu(self.conv1(x, edge_index))
        if self.pooling:
            x, edge_index, _, batch, _, _ = self.pool1(x, edge_index, None, batch)
            x1 = torch.cat([gmp(x, batch), gap(x, batch)], dim=1)
            x = F.relu(self.conv2(x, edge_index))
            x, edge_index, _, batch, _, _ = self.pool2(x, edge_index, None, batch)
            x2 = torch.cat([gmp(x, batch), gap(x, batch)], dim=1)
            # note: self.conv3 is defined in __init__ but not applied before pool3
            x, edge_index, _, batch, _, _ = self.pool3(x, edge_index, None, batch)
            x3 = torch.cat([gmp(x, batch), gap(x, batch)], dim=1)
            x = x1 + x2 + x3
            x = self.linp1(x)
            x = self.act1(x)
            x = self.linp2(x)
        else:
            x = F.relu(self.conv2(x, edge_index))
            x = self.lin(x)
        return x
# Global vocabularies: each defaultdict assigns a new integer ID the first time a
# key (atom, bond, fingerprint, edge, or word) is seen, so they grow with the data.
atom_dict = defaultdict(lambda: len(atom_dict))                # 51 (BindingDB: 26)
bond_dict = defaultdict(lambda: len(bond_dict))                # 4 (BindingDB: 4)
fingerprint_dict = defaultdict(lambda: len(fingerprint_dict))  # 6341 (BindingDB: 20366)
edge_dict = defaultdict(lambda: len(edge_dict))                # 17536 (BindingDB: 77916)
word_dict = defaultdict(lambda: len(word_dict))                # 22 (BindingDB: 21)
def drug_featurizer(smiles, radius=2):
    mol = Chem.AddHs(Chem.MolFromSmiles(smiles))
    atoms = create_atoms(mol)
    i_jbond_dict = create_ijbonddict(mol)
    fingerprints = extract_fingerprints(atoms, i_jbond_dict, radius)
    adjacency = coo_matrix(Chem.GetAdjacencyMatrix(mol))
    edge_index = np.array([adjacency.row, adjacency.col])
    return Data(x=torch.LongTensor(fingerprints).unsqueeze(1), edge_index=torch.LongTensor(edge_index))
def create_atoms(mol):
    """Create a list of atom (e.g., hydrogen and oxygen) IDs
    considering the aromaticity."""
    # GetSymbol: obtain the symbol of the atom
    atoms = [a.GetSymbol() for a in mol.GetAtoms()]
    for a in mol.GetAromaticAtoms():
        i = a.GetIdx()
        atoms[i] = (atoms[i], 'aromatic')
    # turn the symbols into integer IDs
    atoms = [atom_dict[a] for a in atoms]
    return np.array(atoms)
def create_ijbonddict(mol):
    """Create a dictionary in which each key is a node ID
    and each value is a list of tuples of its neighboring node
    and bond (e.g., single and double) IDs."""
    i_jbond_dict = defaultdict(lambda: [])
    for b in mol.GetBonds():
        i, j = b.GetBeginAtomIdx(), b.GetEndAtomIdx()
        bond = bond_dict[str(b.GetBondType())]
        i_jbond_dict[i].append((j, bond))
        i_jbond_dict[j].append((i, bond))
    return i_jbond_dict
def extract_fingerprints(atoms, i_jbond_dict, radius=2):
    """Extract the r-radius subgraphs (i.e., fingerprints)
    from a molecular graph using the Weisfeiler-Lehman algorithm."""
    fingerprints = None
    if (len(atoms) == 1) or (radius == 0):
        fingerprints = [fingerprint_dict[a] for a in atoms]
    else:
        nodes = atoms
        i_jedge_dict = i_jbond_dict
        for _ in range(radius):
            """Update each node ID considering its neighboring nodes and edges
            (i.e., r-radius subgraphs or fingerprints)."""
            fingerprints = []
            for i, j_edge in i_jedge_dict.items():
                neighbors = [(nodes[j], edge) for j, edge in j_edge]
                fingerprint = (nodes[i], tuple(sorted(neighbors)))
                fingerprints.append(fingerprint_dict[fingerprint])
            nodes = fingerprints

            """Also update each edge ID considering the two nodes
            on both of its sides."""
            _i_jedge_dict = defaultdict(lambda: [])
            for i, j_edge in i_jedge_dict.items():
                for j, edge in j_edge:
                    both_side = tuple(sorted((nodes[i], nodes[j])))
                    edge = edge_dict[(both_side, edge)]
                    _i_jedge_dict[i].append((j, edge))
            i_jedge_dict = _i_jedge_dict

    return np.array(fingerprints)
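

# Illustrative end-to-end sketch (not part of the original module): featurize a
# SMILES string and push it through the GNN branch alone. The SMILES below is
# aspirin, chosen only for illustration; n_fingerprint is padded generously because
# the fingerprint vocabulary grows as new substructures are seen. Note that calling
# this mutates the global atom/bond/fingerprint/edge dictionaries.
def _featurizer_gnn_demo():
    data = drug_featurizer('CC(=O)Oc1ccccc1C(=O)O', radius=2)
    data.batch = torch.zeros(data.x.size(0), dtype=torch.long)  # single-graph batch
    gnn = GNN(n_fingerprint=10000, pooling=False)
    compound_vector = gnn(data)  # (number of atoms, 512) in the no-pooling branch
    return compound_vector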