import numpy as np
import torch
import torch.nn as nn
from transformers import EsmModel, EsmTokenizer

from .layers import *
from .modules import *


def to_var(x):
    """Move a tensor to the GPU when one is available."""
    if torch.cuda.is_available():
        x = x.cuda()
    return x


class RepeatedModule3(nn.Module):
    """Stack of ReciprocalLayerwithCNN blocks operating on precomputed
    per-residue embeddings of width 1280 for both the peptide and the protein."""

    def __init__(self, n_layers, d_model, d_hidden,
                 n_head, d_k, d_v, d_inner, dropout=0.1):
        super().__init__()
        self.linear1 = nn.Linear(1280, d_model)
        self.linear2 = nn.Linear(1280, d_model)
        # Not used in forward(); inputs arrive as precomputed embeddings.
        self.sequence_embedding = nn.Embedding(20, d_model)
        self.d_model = d_model
        self.reciprocal_layer_stack = nn.ModuleList([
            ReciprocalLayerwithCNN(d_model, d_inner, d_hidden, n_head, d_k, d_v)
            for _ in range(n_layers)])
        self.dropout = nn.Dropout(dropout)
        self.dropout_2 = nn.Dropout(dropout)

    def forward(self, peptide_sequence, protein_sequence):
        sequence_attention_list = []
        prot_attention_list = []
        prot_seq_attention_list = []
        seq_prot_attention_list = []
        # Project both inputs from width 1280 down to d_model.
        sequence_enc = self.dropout(self.linear1(peptide_sequence))
        prot_enc = self.dropout_2(self.linear2(protein_sequence))
        for reciprocal_layer in self.reciprocal_layer_stack:
            prot_enc, sequence_enc, prot_attention, sequence_attention, prot_seq_attention, seq_prot_attention = \
                reciprocal_layer(sequence_enc, prot_enc)
            sequence_attention_list.append(sequence_attention)
            prot_attention_list.append(prot_attention)
            prot_seq_attention_list.append(prot_seq_attention)
            seq_prot_attention_list.append(seq_prot_attention)
        return prot_enc, sequence_enc, sequence_attention_list, prot_attention_list, \
            prot_seq_attention_list, seq_prot_attention_list


class RepeatedModule2(nn.Module):
    """Stack of ReciprocalLayer blocks operating on precomputed per-residue
    embeddings of width 1280 (the ESM-2 650M hidden size) for both chains."""

    def __init__(self, n_layers, d_model,
                 n_head, d_k, d_v, d_inner, dropout=0.1):
        super().__init__()
        self.linear1 = nn.Linear(1280, d_model)
        self.linear2 = nn.Linear(1280, d_model)
        # Not used in forward(); inputs arrive as precomputed embeddings.
        self.sequence_embedding = nn.Embedding(20, d_model)
        self.d_model = d_model
        self.reciprocal_layer_stack = nn.ModuleList([
            ReciprocalLayer(d_model, d_inner, n_head, d_k, d_v)
            for _ in range(n_layers)])
        self.dropout = nn.Dropout(dropout)
        self.dropout_2 = nn.Dropout(dropout)

    def forward(self, peptide_sequence, protein_sequence):
        sequence_attention_list = []
        prot_attention_list = []
        prot_seq_attention_list = []
        seq_prot_attention_list = []
        # Project both inputs from width 1280 down to d_model.
        sequence_enc = self.dropout(self.linear1(peptide_sequence))
        prot_enc = self.dropout_2(self.linear2(protein_sequence))
        for reciprocal_layer in self.reciprocal_layer_stack:
            prot_enc, sequence_enc, prot_attention, sequence_attention, prot_seq_attention, seq_prot_attention = \
                reciprocal_layer(sequence_enc, prot_enc)
            sequence_attention_list.append(sequence_attention)
            prot_attention_list.append(prot_attention)
            prot_seq_attention_list.append(prot_seq_attention)
            seq_prot_attention_list.append(seq_prot_attention)
        return prot_enc, sequence_enc, sequence_attention_list, prot_attention_list, \
            prot_seq_attention_list, seq_prot_attention_list
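

# --- Illustrative sketch (not part of the original model) -------------------
# Shows the tensor shapes RepeatedModule2 expects: both chains arrive as
# precomputed per-residue embeddings of width 1280 (the ESM-2 650M hidden
# size). The hyperparameters below are hypothetical placeholders chosen only
# to make the example run; it also assumes ReciprocalLayer accepts
# (peptide_enc, protein_enc) as used in forward() above.
def _repeated_module2_shape_sketch():
    module = RepeatedModule2(n_layers=2, d_model=64, n_head=4,
                             d_k=16, d_v=16, d_inner=128)
    peptide = torch.randn(1, 12, 1280)    # (batch, peptide_len, 1280)
    protein = torch.randn(1, 300, 1280)   # (batch, protein_len, 1280)
    prot_enc, seq_enc, *attention_lists = module(peptide, protein)
    # prot_enc -> (1, 300, 64); seq_enc -> (1, 12, 64);
    # each of the four attention lists holds one tensor per reciprocal layer.
    return prot_enc.shape, seq_enc.shape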


class RepeatedModule(nn.Module):
    """Stack of ReciprocalLayer blocks. The peptide is given as amino-acid
    indices in [0, 20) and embedded here (learned embedding plus sinusoidal
    positional encodings); the protein is given as precomputed per-residue
    embeddings of width 1024."""

    def __init__(self, n_layers, d_model,
                 n_head, d_k, d_v, d_inner, dropout=0.1):
        super().__init__()
        self.linear = nn.Linear(1024, d_model)
        self.sequence_embedding = nn.Embedding(20, d_model)
        self.d_model = d_model
        self.reciprocal_layer_stack = nn.ModuleList([
            ReciprocalLayer(d_model, d_inner, n_head, d_k, d_v)
            for _ in range(n_layers)])
        self.dropout = nn.Dropout(dropout)
        self.dropout_2 = nn.Dropout(dropout)

    def _positional_embedding(self, batches, number):
        # Sinusoidal positional encodings: sin terms in the first d_model/2
        # channels, cos terms in the second half. Returns shape
        # (1, number, d_model); `batches` is unused because the batch
        # dimension is handled by broadcasting.
        result = torch.exp(torch.arange(0, self.d_model, 2, dtype=torch.float32)
                           * -1 * (np.log(10000) / self.d_model))
        numbers = torch.arange(0, number, dtype=torch.float32)
        numbers = numbers.unsqueeze(0)
        numbers = numbers.unsqueeze(2)
        result = numbers * result
        result = torch.cat((torch.sin(result), torch.cos(result)), 2)
        return result

    def forward(self, peptide_sequence, protein_sequence):
        sequence_attention_list = []
        prot_attention_list = []
        prot_seq_attention_list = []
        seq_prot_attention_list = []
        # Embed peptide indices and add positional encodings.
        sequence_enc = self.sequence_embedding(peptide_sequence)
        sequence_enc += to_var(self._positional_embedding(peptide_sequence.shape[0],
                                                          peptide_sequence.shape[1]))
        sequence_enc = self.dropout(sequence_enc)
        # Project the precomputed protein embeddings from width 1024 to d_model.
        prot_enc = self.dropout_2(self.linear(protein_sequence))
        for reciprocal_layer in self.reciprocal_layer_stack:
            prot_enc, sequence_enc, prot_attention, sequence_attention, prot_seq_attention, seq_prot_attention = \
                reciprocal_layer(sequence_enc, prot_enc)
            sequence_attention_list.append(sequence_attention)
            prot_attention_list.append(prot_attention)
            prot_seq_attention_list.append(prot_seq_attention)
            seq_prot_attention_list.append(seq_prot_attention)
        return prot_enc, sequence_enc, sequence_attention_list, prot_attention_list, \
            prot_seq_attention_list, seq_prot_attention_list


class FullModel(nn.Module):
    """Peptide-protein model: a frozen ESM-2 encoder produces per-residue
    embeddings for both chains, RepeatedModule2 mixes them, and a final
    cross-attention layer (protein attending to peptide) plus a sigmoid head
    yields a score in [0, 1] per target residue."""

    def __init__(self, n_layers, d_model, n_head,
                 d_k, d_v, d_inner, return_attention=False, dropout=0.2):
        super().__init__()
        self.esm_model = EsmModel.from_pretrained("facebook/esm2_t33_650M_UR50D")
        # Freeze all ESM parameters; only the downstream layers are trained.
        for param in self.esm_model.parameters():
            param.requires_grad = False
        self.repeated_module = RepeatedModule2(n_layers, d_model,
                                               n_head, d_k, d_v, d_inner, dropout=dropout)
        self.final_attention_layer = MultiHeadAttentionSequence(n_head, d_model,
                                                                d_k, d_v, dropout=dropout)
        self.final_ffn = FFN(d_model, d_inner, dropout=dropout)
        self.output_projection_prot = nn.Linear(d_model, 1)
        self.sigmoid = nn.Sigmoid()
        self.return_attention = return_attention

    def forward(self, binder_tokens, target_tokens):
        # Embed both chains with the frozen ESM-2 encoder.
        with torch.no_grad():
            peptide_sequence = self.esm_model(**binder_tokens).last_hidden_state
            protein_sequence = self.esm_model(**target_tokens).last_hidden_state
        prot_enc, sequence_enc, sequence_attention_list, prot_attention_list, \
            prot_seq_attention_list, seq_prot_attention_list = self.repeated_module(peptide_sequence,
                                                                                    protein_sequence)
        prot_enc, final_prot_seq_attention = self.final_attention_layer(prot_enc, sequence_enc, sequence_enc)
        prot_enc = self.final_ffn(prot_enc)
        # Per-residue score in [0, 1] for the target protein.
        prot_enc = self.sigmoid(self.output_projection_prot(prot_enc))
        return prot_enc


class Original_FullModel(nn.Module):
    """Original variant: consumes peptide indices and precomputed protein
    embeddings via RepeatedModule and outputs per-residue log-probabilities
    over two classes."""

    def __init__(self, n_layers, d_model, n_head,
                 d_k, d_v, d_inner, return_attention=False, dropout=0.2):
        super().__init__()
        self.repeated_module = RepeatedModule(n_layers, d_model,
                                              n_head, d_k, d_v, d_inner, dropout=dropout)
        self.final_attention_layer = MultiHeadAttentionSequence(n_head, d_model,
                                                                d_k, d_v, dropout=dropout)
        self.final_ffn = FFN(d_model, d_inner, dropout=dropout)
        self.output_projection_prot = nn.Linear(d_model, 2)
        self.softmax_prot = nn.LogSoftmax(dim=-1)
        self.return_attention = return_attention

    def forward(self, peptide_sequence, protein_sequence):
        prot_enc, sequence_enc, sequence_attention_list, prot_attention_list, \
            prot_seq_attention_list, seq_prot_attention_list = self.repeated_module(peptide_sequence,
                                                                                    protein_sequence)
        prot_enc, final_prot_seq_attention = self.final_attention_layer(prot_enc, sequence_enc, sequence_enc)
        prot_enc = self.final_ffn(prot_enc)
        prot_enc = self.softmax_prot(self.output_projection_prot(prot_enc))
        if not self.return_attention:
            return prot_enc
        return prot_enc, sequence_attention_list, prot_attention_list, \
            prot_seq_attention_list, seq_prot_attention_list
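

# --- Illustrative usage sketch (not part of the original model) -------------
# End-to-end example for FullModel: tokenize a binder peptide and a target
# protein with the matching ESM-2 tokenizer and score the target residues.
# The hyperparameters and sequences below are hypothetical placeholders, the
# downstream layers are untrained here, and running this downloads the
# facebook/esm2_t33_650M_UR50D weights. Because of the relative imports above,
# run it as a module inside its package (python -m <package>.<this_module>).
if __name__ == "__main__":
    tokenizer = EsmTokenizer.from_pretrained("facebook/esm2_t33_650M_UR50D")
    model = FullModel(n_layers=2, d_model=64, n_head=4,
                      d_k=16, d_v=16, d_inner=128)
    model.eval()
    binder_tokens = tokenizer("ACDEFGHIKL", return_tensors="pt")
    target_tokens = tokenizer("MKTAYIAKQRQISFVKSHFSRQLEERLGLIEVQ", return_tensors="pt")
    with torch.no_grad():
        scores = model(binder_tokens, target_tokens)
    # scores: (batch, target_len_with_special_tokens, 1), values in [0, 1].
    print(scores.shape)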