import math
from collections import defaultdict
from typing import Literal
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from rdkit import Chem
from scipy.sparse import coo_matrix
from torch_geometric.data import Data
from torch_geometric.nn import TopKPooling
from torch_geometric.nn import global_mean_pool as gap, global_max_pool as gmp
from torch_geometric.utils import add_self_loops, remove_self_loops
from torch_geometric.nn.conv.message_passing import MessagePassing
class CoaDTIPro(nn.Module):
def __init__(self,
esm_model_and_alphabet, n_fingerprint, dim, n_word, layer_output, layer_coa, nhead=8, dropout=0.1,
                 co_attention: Literal['stack', 'encoder', 'inter'] = 'inter', gcn_pooling=False):
super().__init__()
self.co_attention = co_attention
self.layer_output = layer_output
self.layer_coa = layer_coa
self.embed_word = nn.Embedding(n_word, dim)
self.gnn = GNN(n_fingerprint, gcn_pooling)
self.esm_model, self.alphabet = esm_model_and_alphabet
self.batch_converter = self.alphabet.get_batch_converter()
self.W_attention = nn.Linear(dim, dim)
self.W_out = nn.Sequential(
nn.Linear(2 * dim, dim),
nn.Linear(dim, 128),
nn.Linear(128, 64)
)
self.coa_layers = CoAttention(dim, nhead, dropout, layer_coa, co_attention)
        self.lin = nn.Linear(768, 512)  # project protein embeddings to 512-d (ESM: 768-d input; BERT variant: 1024-d)
self.W_interaction = nn.Linear(64, 2)
def attention_cnn(self, x, xs, layer):
"""The attention mechanism is applied to the last layer of CNN."""
xs = torch.unsqueeze(torch.unsqueeze(xs, 0), 0)
for i in range(layer):
xs = torch.relu(self.W_cnn[i](xs))
xs = torch.squeeze(torch.squeeze(xs, 0), 0)
h = torch.relu(self.W_attention(x))
hs = torch.relu(self.W_attention(xs))
weights = torch.tanh(F.linear(h, hs))
ys = torch.t(weights) * hs
return torch.unsqueeze(torch.mean(ys, 0), 0)
def forward(self, inputs, proteins):
"""Compound vector with GNN."""
compound_vector = self.gnn(inputs)
compound_vector = torch.unsqueeze(compound_vector, 0) # sequence-like GNN ouput
_, _, proteins = self.batch_converter([(None, protein) for protein in proteins])
        with torch.no_grad():
            results = self.esm_model(proteins.to(compound_vector.device), repr_layers=[6])
            token_representations = results["representations"][6]
        protein_vector = token_representations[:, 1:, :]  # drop the prepended BOS token
        protein_vector = self.lin(torch.squeeze(protein_vector, 1))  # project 768 -> 512
protein_vector, compound_vector = self.coa_layers(protein_vector, compound_vector)
protein_vector = protein_vector.mean(dim=1)
compound_vector = compound_vector.mean(dim=1)
"""Concatenate the above two vectors and output the interaction."""
cat_vector = torch.cat((compound_vector, protein_vector), 1)
cat_vector = torch.tanh(self.W_out(cat_vector))
interaction = self.W_interaction(cat_vector)
return interaction
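# Inputs to CoaDTIPro.forward: `inputs` is a torch_geometric Data/Batch of molecular graphs
# (as produced by drug_featurizer below) and `proteins` is a list of amino-acid sequence
# strings tokenized by the ESM batch converter; the returned `interaction` holds the two
# class logits for interaction prediction.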
class CoAttention(nn.Module):
def __init__(self, dim, nhead, dropout, layer_coa, co_attention):
super().__init__()
self.co_attention = co_attention
if self.co_attention == 'encoder':
self.coa_layers = EncoderCrossAtt(dim, nhead, dropout, layer_coa)
elif self.co_attention == 'stack':
self.coa_layers = nn.ModuleList([StackCrossAtt(dim, nhead, dropout) for _ in range(layer_coa)])
elif self.co_attention == 'inter':
self.coa_layers = nn.ModuleList([InterCrossAtt(dim, nhead, dropout) for _ in range(layer_coa)])
def forward(self, protein_vector, compound_vector):
        # protein_vector and compound_vector are the sequence-shaped inputs for the two modalities
if self.co_attention == 'encoder':
return self.coa_layers(protein_vector, compound_vector)
else:
# loop over the sequential layers and pass the arguments
for layer in self.coa_layers:
protein_vector, compound_vector = layer(protein_vector, compound_vector)
return protein_vector, compound_vector
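# Three co-attention variants, selected via the `co_attention` argument:
#   'encoder' - EncoderCrossAtt: transformer-style encoder-decoder; the compound passes through
#               a stack of self-attention layers, then the protein passes through self-attention
#               plus cross-attention layers that attend to the encoded compound.
#   'inter'   - InterCrossAtt: bidirectional; each modality runs self-attention and then
#               cross-attends to the other.
#   'stack'   - StackCrossAtt: unidirectional; both run self-attention but only the protein
#               cross-attends to the compound.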
class EncoderCrossAtt(nn.Module):
def __init__(self, dim, nhead, dropout, layers):
super().__init__()
# self.encoder_layers = nn.ModuleList([SEA(dim, dropout) for _ in range(layers)])
self.encoder_layers = nn.ModuleList([SA(dim, nhead, dropout) for _ in range(layers)])
self.decoder_sa = nn.ModuleList([SA(dim, nhead, dropout) for _ in range(layers)])
self.decoder_coa = nn.ModuleList([DPA(dim, nhead, dropout) for _ in range(layers)])
self.layer_coa = layers
def forward(self, protein_vector, compound_vector):
for i in range(self.layer_coa):
compound_vector = self.encoder_layers[i](compound_vector, None) # self-attention
for i in range(self.layer_coa):
protein_vector = self.decoder_sa[i](protein_vector, None)
            protein_vector = self.decoder_coa[i](protein_vector, compound_vector, None)  # co-attention
return protein_vector, compound_vector
class InterCrossAtt(nn.Module):
def __init__(self, dim, nhead, dropout):
super().__init__()
self.sca = SA(dim, nhead, dropout)
self.spa = SA(dim, nhead, dropout)
self.coa_pc = DPA(dim, nhead, dropout)
self.coa_cp = DPA(dim, nhead, dropout)
def forward(self, protein_vector, compound_vector):
compound_vector = self.sca(compound_vector, None) # self-attention
protein_vector = self.spa(protein_vector, None) # self-attention
compound_covector = self.coa_pc(compound_vector, protein_vector, None) # co-attention
protein_covector = self.coa_cp(protein_vector, compound_vector, None) # co-attention
return protein_covector, compound_covector
class StackCrossAtt(nn.Module):
def __init__(self, dim, nhead, dropout):
super().__init__()
self.sca = SA(dim, nhead, dropout)
self.spa = SA(dim, nhead, dropout)
self.coa_cp = DPA(dim, nhead, dropout)
def forward(self, protein_vector, compound_vector):
compound_vector = self.sca(compound_vector, None) # self-attention
protein_vector = self.spa(protein_vector, None) # self-attention
protein_covector = self.coa_cp(protein_vector, compound_vector, None) # co-attention
return protein_covector, compound_vector
class MHAtt(nn.Module):
def __init__(self, hid_dim, n_heads, dropout):
super().__init__()
self.linear_v = nn.Linear(hid_dim, hid_dim)
self.linear_k = nn.Linear(hid_dim, hid_dim)
self.linear_q = nn.Linear(hid_dim, hid_dim)
self.linear_merge = nn.Linear(hid_dim, hid_dim)
        self.hid_dim = hid_dim
        self.nhead = n_heads
        self.dropout = nn.Dropout(dropout)
        self.hidden_size_head = hid_dim // n_heads
def forward(self, v, k, q, mask):
n_batches = q.size(0)
v = self.linear_v(v).view(n_batches, -1, self.nhead, self.hidden_size_head).transpose(1, 2)
k = self.linear_k(k).view(n_batches, -1, self.nhead, self.hidden_size_head).transpose(1, 2)
q = self.linear_q(q).view(n_batches, -1, self.nhead, self.hidden_size_head).transpose(1, 2)
atted = self.att(v, k, q, mask)
atted = atted.transpose(1, 2).contiguous().view(n_batches, -1, self.hid_dim)
atted = self.linear_merge(atted)
return atted
def att(self, value, key, query, mask):
d_k = query.size(-1)
scores = torch.matmul(query, key.transpose(-2, -1)) / math.sqrt(d_k)
if mask is not None:
scores = scores.masked_fill(mask, -1e9)
att_map = F.softmax(scores, dim=-1)
att_map = self.dropout(att_map)
return torch.matmul(att_map, value)
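# The two blocks below are thin residual wrappers around MHAtt:
# DPA is guided (cross) attention - the query comes from x while keys/values come from y;
# SA is standard self-attention within a single sequence. Both apply dropout, a residual
# connection and LayerNorm.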
class DPA(nn.Module):
def __init__(self, hid_dim, n_heads, dropout):
super().__init__()
self.mhatt1 = MHAtt(hid_dim, n_heads, dropout)
self.dropout1 = nn.Dropout(dropout)
self.norm1 = nn.LayerNorm(hid_dim)
def forward(self, x, y, y_mask=None):
x = self.norm1(x + self.dropout1(self.mhatt1(y, y, x, y_mask)))
return x
class SA(nn.Module):
def __init__(self, hid_dim, n_heads, dropout):
super().__init__()
self.mhatt1 = MHAtt(hid_dim, n_heads, dropout)
self.dropout1 = nn.Dropout(dropout)
self.norm1 = nn.LayerNorm(hid_dim)
def forward(self, x, mask=None):
x = self.norm1(x + self.dropout1(self.mhatt1(x, x, x, mask)))
return x
class SAGEConv(MessagePassing):
def __init__(self, in_channels, out_channels):
super().__init__(aggr='max') # "Max" aggregation.
self.lin = torch.nn.Linear(in_channels, out_channels)
self.act = torch.nn.ReLU()
        self.update_lin = torch.nn.Linear(in_channels + out_channels, in_channels, bias=False)  # combine aggregated neighbor features with the node's own features
self.update_act = torch.nn.ReLU()
def forward(self, x, edge_index):
# x has shape [N, in_channels]
# edge_index has shape [2, E]
edge_index, _ = remove_self_loops(edge_index)
edge_index, _ = add_self_loops(edge_index, num_nodes=x.size(0))
return self.propagate(edge_index, size=(x.size(0), x.size(0)), x=x)
def message(self, x_j):
# x_j has shape [E, in_channels]
x_j = self.lin(x_j)
x_j = self.act(x_j)
return x_j
def update(self, aggr_out, x):
# aggr_out has shape [N, out_channels]
new_embedding = torch.cat([aggr_out, x], dim=1)
new_embedding = self.update_lin(new_embedding)
new_embedding = self.update_act(new_embedding)
return new_embedding
class GNN(nn.Module):
def __init__(self, n_fingerprint, pooling, embed_dim=128):
super().__init__()
self.pooling = pooling
self.embed_fingerprint = nn.Embedding(num_embeddings=n_fingerprint, embedding_dim=embed_dim)
self.conv1 = SAGEConv(embed_dim, 128)
self.pool1 = TopKPooling(128, ratio=0.8)
self.conv2 = SAGEConv(128, 128)
self.pool2 = TopKPooling(128, ratio=0.8)
self.conv3 = SAGEConv(128, 128)
self.pool3 = TopKPooling(128, ratio=0.8)
self.linp1 = torch.nn.Linear(256, 128)
self.linp2 = torch.nn.Linear(128, 512)
self.lin = torch.nn.Linear(128, 512)
self.bn1 = torch.nn.BatchNorm1d(128)
self.bn2 = torch.nn.BatchNorm1d(64)
self.act1 = torch.nn.ReLU()
self.act2 = torch.nn.ReLU()
def forward(self, data):
        x, edge_index, batch = data.x, data.edge_index, data.batch
        x = self.embed_fingerprint(x)
        x = x.squeeze(1)  # [num_nodes, 1, embed_dim] -> [num_nodes, embed_dim]
x = F.relu(self.conv1(x, edge_index))
if self.pooling:
x, edge_index, _, batch, _, _ = self.pool1(x, edge_index, None, batch)
x1 = torch.cat([gmp(x, batch), gap(x, batch)], dim=1)
x = F.relu(self.conv2(x, edge_index))
x, edge_index, _, batch, _, _ = self.pool2(x, edge_index, None, batch)
x2 = torch.cat([gmp(x, batch), gap(x, batch)], dim=1)
            x = F.relu(self.conv3(x, edge_index))
            x, edge_index, _, batch, _, _ = self.pool3(x, edge_index, None, batch)
x3 = torch.cat([gmp(x, batch), gap(x, batch)], dim=1)
x = x1 + x2 + x3
x = self.linp1(x)
x = self.act1(x)
            x = self.linp2(x)  # graph-level vector per molecule, [num_graphs, 512]
else:
x = F.relu(self.conv2(x, edge_index))
            x = self.lin(x)  # node-level features, [num_nodes, 512]
return x
atom_dict = defaultdict(lambda: len(atom_dict)) # 51 bindingdb: 26
bond_dict = defaultdict(lambda: len(bond_dict)) # 4 bindingdb: 4
fingerprint_dict = defaultdict(lambda: len(fingerprint_dict)) # 6341 bindingdb: 20366
edge_dict = defaultdict(lambda: len(edge_dict)) # 17536 bindingdb: 77916
word_dict = defaultdict(lambda: len(word_dict)) # 22 bindingdb: 21
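# Each dict above is a growing vocabulary: looking up an unseen key assigns it the next
# integer ID (e.g., the first lookup atom_dict['C'] returns 0, the next new key returns 1,
# and so on). The numbers in the trailing comments appear to be the vocabulary sizes
# observed on the datasets the original model was trained with.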
def drug_featurizer(smiles, radius=2):
mol = Chem.AddHs(Chem.MolFromSmiles(smiles))
atoms = create_atoms(mol)
i_jbond_dict = create_ijbonddict(mol)
fingerprints = extract_fingerprints(atoms, i_jbond_dict, radius)
    adjacency = coo_matrix(Chem.GetAdjacencyMatrix(mol))
edge_index = np.array([adjacency.row, adjacency.col])
return Data(x=torch.LongTensor(fingerprints).unsqueeze(1), edge_index=torch.LongTensor(edge_index))
def create_atoms(mol):
"""Create a list of atom (e.g., hydrogen and oxygen) IDs
considering the aromaticity."""
# GetSymbol: obtain the symbol of the atom
atoms = [a.GetSymbol() for a in mol.GetAtoms()]
for a in mol.GetAromaticAtoms():
i = a.GetIdx()
atoms[i] = (atoms[i], 'aromatic')
# turn it into index
atoms = [atom_dict[a] for a in atoms]
return np.array(atoms)
def create_ijbonddict(mol):
"""Create a dictionary, which each key is a node ID
and each value is the tuples of its neighboring node
and bond (e.g., single and double) IDs."""
i_jbond_dict = defaultdict(lambda: [])
for b in mol.GetBonds():
i, j = b.GetBeginAtomIdx(), b.GetEndAtomIdx()
bond = bond_dict[str(b.GetBondType())]
i_jbond_dict[i].append((j, bond))
i_jbond_dict[j].append((i, bond))
return i_jbond_dict
def extract_fingerprints(atoms, i_jbond_dict, radius=2):
"""Extract the r-radius subgraphs (i.e., fingerprints)
from a molecular graph using Weisfeiler-Lehman algorithm."""
fingerprints = None
if (len(atoms) == 1) or (radius == 0):
fingerprints = [fingerprint_dict[a] for a in atoms]
else:
nodes = atoms
i_jedge_dict = i_jbond_dict
for _ in range(radius):
"""Update each node ID considering its neighboring nodes and edges
(i.e., r-radius subgraphs or fingerprints)."""
fingerprints = []
for i, j_edge in i_jedge_dict.items():
neighbors = [(nodes[j], edge) for j, edge in j_edge]
fingerprint = (nodes[i], tuple(sorted(neighbors)))
fingerprints.append(fingerprint_dict[fingerprint])
nodes = fingerprints
"""Also update each edge ID considering two nodes
on its both sides."""
_i_jedge_dict = defaultdict(lambda: [])
for i, j_edge in i_jedge_dict.items():
for j, edge in j_edge:
both_side = tuple(sorted((nodes[i], nodes[j])))
edge = edge_dict[(both_side, edge)]
_i_jedge_dict[i].append((j, edge))
i_jedge_dict = _i_jedge_dict
return np.array(fingerprints)
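# Minimal usage sketch (illustrative only): featurize a SMILES string with the helpers
# above and inspect the resulting graph. Running the full CoaDTIPro model additionally
# requires a pretrained protein language model; the commented lines assume the fair-esm
# package and its esm1_t6_43M_UR50S loader (a 6-layer, 768-dim model consistent with
# repr_layers=[6] and nn.Linear(768, 512) above) purely as an example.
if __name__ == '__main__':
    graph = drug_featurizer('CCO', radius=2)  # ethanol
    print(graph.x.shape, graph.edge_index.shape)  # fingerprint IDs per atom, bond index pairs
    # import esm
    # esm_model_and_alphabet = esm.pretrained.esm1_t6_43M_UR50S()
    # model = CoaDTIPro(esm_model_and_alphabet, n_fingerprint=len(fingerprint_dict) + 1,
    #                   dim=512, n_word=len(word_dict) + 1, layer_output=3, layer_coa=1)
    # logits = model(graph, ['MKTAYIAKQRQISFVKSHFSRQLEERLGLIEVQ'])  # hypothetical sequence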