import numpy as np
import torch
import torch.nn as nn

from .layers import *
from .modules import *

from transformers import EsmModel, EsmTokenizer
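
# NOTE: ReciprocalLayer, ReciprocalLayerwithCNN, MultiHeadAttentionSequence and
# FFN used below are assumed to be provided by the wildcard imports from
# .layers and .modules.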


def to_var(x):
    if torch.cuda.is_available():
        x = x.cuda()
    return x
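

# RepeatedModule3: a stack of ReciprocalLayerwithCNN blocks. Both inputs are
# expected to be 1280-dimensional per-residue embeddings (the ESM-2 650M
# hidden size); they are projected to d_model, co-encoded layer by layer, and
# the per-layer attention maps are collected for inspection.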
class RepeatedModule3(nn.Module):

    def __init__(self, n_layers, d_model, d_hidden,
                 n_head, d_k, d_v, d_inner, dropout=0.1):
        super().__init__()

        self.linear1 = nn.Linear(1280, d_model)
        self.linear2 = nn.Linear(1280, d_model)
        self.sequence_embedding = nn.Embedding(20, d_model)
        self.d_model = d_model

        self.reciprocal_layer_stack = nn.ModuleList([
            ReciprocalLayerwithCNN(d_model, d_inner, d_hidden, n_head, d_k, d_v)
            for _ in range(n_layers)])

        self.dropout = nn.Dropout(dropout)
        self.dropout_2 = nn.Dropout(dropout)

    def forward(self, peptide_sequence, protein_sequence):
        sequence_attention_list = []
        prot_attention_list = []
        prot_seq_attention_list = []
        seq_prot_attention_list = []

        sequence_enc = self.dropout(self.linear1(peptide_sequence))
        prot_enc = self.dropout_2(self.linear2(protein_sequence))

        for reciprocal_layer in self.reciprocal_layer_stack:
            prot_enc, sequence_enc, prot_attention, sequence_attention, prot_seq_attention, seq_prot_attention = \
                reciprocal_layer(sequence_enc, prot_enc)

            sequence_attention_list.append(sequence_attention)
            prot_attention_list.append(prot_attention)
            prot_seq_attention_list.append(prot_seq_attention)
            seq_prot_attention_list.append(seq_prot_attention)

        return prot_enc, sequence_enc, sequence_attention_list, prot_attention_list, \
            prot_seq_attention_list, seq_prot_attention_list
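

# RepeatedModule2: the same reciprocal encoder as RepeatedModule3, but built
# from plain ReciprocalLayer blocks (no CNN branch and no d_hidden argument).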
class RepeatedModule2(nn.Module):

    def __init__(self, n_layers, d_model,
                 n_head, d_k, d_v, d_inner, dropout=0.1):
        super().__init__()

        self.linear1 = nn.Linear(1280, d_model)
        self.linear2 = nn.Linear(1280, d_model)
        self.sequence_embedding = nn.Embedding(20, d_model)
        self.d_model = d_model

        self.reciprocal_layer_stack = nn.ModuleList([
            ReciprocalLayer(d_model, d_inner, n_head, d_k, d_v)
            for _ in range(n_layers)])

        self.dropout = nn.Dropout(dropout)
        self.dropout_2 = nn.Dropout(dropout)

    def forward(self, peptide_sequence, protein_sequence):
        sequence_attention_list = []
        prot_attention_list = []
        prot_seq_attention_list = []
        seq_prot_attention_list = []

        sequence_enc = self.dropout(self.linear1(peptide_sequence))
        prot_enc = self.dropout_2(self.linear2(protein_sequence))

        for reciprocal_layer in self.reciprocal_layer_stack:
            prot_enc, sequence_enc, prot_attention, sequence_attention, prot_seq_attention, seq_prot_attention = \
                reciprocal_layer(sequence_enc, prot_enc)

            sequence_attention_list.append(sequence_attention)
            prot_attention_list.append(prot_attention)
            prot_seq_attention_list.append(prot_seq_attention)
            seq_prot_attention_list.append(seq_prot_attention)

        return prot_enc, sequence_enc, sequence_attention_list, prot_attention_list, \
            prot_seq_attention_list, seq_prot_attention_list
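

# RepeatedModule: original variant. The peptide is given as amino-acid indices
# (embedded with a learned 20-token embedding plus a sinusoidal positional
# encoding), while the protein is given as 1024-dimensional per-residue
# embeddings that are projected to d_model.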
class RepeatedModule(nn.Module):

    def __init__(self, n_layers, d_model,
                 n_head, d_k, d_v, d_inner, dropout=0.1):
        super().__init__()

        self.linear = nn.Linear(1024, d_model)
        self.sequence_embedding = nn.Embedding(20, d_model)
        self.d_model = d_model

        self.reciprocal_layer_stack = nn.ModuleList([
            ReciprocalLayer(d_model, d_inner, n_head, d_k, d_v)
            for _ in range(n_layers)])

        self.dropout = nn.Dropout(dropout)
        self.dropout_2 = nn.Dropout(dropout)

    def _positional_embedding(self, batches, number):
        # Sinusoidal positional encoding of shape (1, number, d_model): sines and
        # cosines are concatenated along the feature dimension. The batches
        # argument is unused; the leading singleton dimension broadcasts over
        # the batch when the encoding is added to the sequence embedding.
        result = torch.exp(torch.arange(0, self.d_model, 2, dtype=torch.float32)
                           * -(np.log(10000.0) / self.d_model))
        numbers = torch.arange(0, number, dtype=torch.float32)
        numbers = numbers.unsqueeze(0)
        numbers = numbers.unsqueeze(2)
        result = numbers * result
        result = torch.cat((torch.sin(result), torch.cos(result)), 2)
        return result

    def forward(self, peptide_sequence, protein_sequence):
        sequence_attention_list = []
        prot_attention_list = []
        prot_seq_attention_list = []
        seq_prot_attention_list = []

        sequence_enc = self.sequence_embedding(peptide_sequence)
        sequence_enc += to_var(self._positional_embedding(peptide_sequence.shape[0],
                                                          peptide_sequence.shape[1]))
        sequence_enc = self.dropout(sequence_enc)

        prot_enc = self.dropout_2(self.linear(protein_sequence))

        for reciprocal_layer in self.reciprocal_layer_stack:
            prot_enc, sequence_enc, prot_attention, sequence_attention, prot_seq_attention, seq_prot_attention = \
                reciprocal_layer(sequence_enc, prot_enc)

            sequence_attention_list.append(sequence_attention)
            prot_attention_list.append(prot_attention)
            prot_seq_attention_list.append(prot_seq_attention)
            seq_prot_attention_list.append(seq_prot_attention)

        return prot_enc, sequence_enc, sequence_attention_list, prot_attention_list, \
            prot_seq_attention_list, seq_prot_attention_list
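

# FullModel: embeds both chains with a frozen ESM-2 (esm2_t33_650M_UR50D),
# runs the reciprocal encoder (RepeatedModule2), applies a final
# protein-to-peptide attention layer and an FFN, and outputs a per-residue
# probability via a sigmoid.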
class FullModel(nn.Module):

    def __init__(self, n_layers, d_model, n_head,
                 d_k, d_v, d_inner, return_attention=False, dropout=0.2):
        super().__init__()

        self.esm_model = EsmModel.from_pretrained("facebook/esm2_t33_650M_UR50D")

        # ESM-2 is used as a frozen feature extractor.
        for param in self.esm_model.parameters():
            param.requires_grad = False

        self.repeated_module = RepeatedModule2(n_layers, d_model,
                                               n_head, d_k, d_v, d_inner, dropout=dropout)

        self.final_attention_layer = MultiHeadAttentionSequence(n_head, d_model,
                                                                d_k, d_v, dropout=dropout)

        self.final_ffn = FFN(d_model, d_inner, dropout=dropout)

        self.output_projection_prot = nn.Linear(d_model, 1)
        self.sigmoid = nn.Sigmoid()

        self.return_attention = return_attention

    def forward(self, binder_tokens, target_tokens):
        with torch.no_grad():
            peptide_sequence = self.esm_model(**binder_tokens).last_hidden_state
            protein_sequence = self.esm_model(**target_tokens).last_hidden_state

        prot_enc, sequence_enc, sequence_attention_list, prot_attention_list, \
            prot_seq_attention_list, seq_prot_attention_list = self.repeated_module(peptide_sequence,
                                                                                    protein_sequence)

        prot_enc, final_prot_seq_attention = self.final_attention_layer(prot_enc, sequence_enc, sequence_enc)

        prot_enc = self.final_ffn(prot_enc)

        prot_enc = self.sigmoid(self.output_projection_prot(prot_enc))

        return prot_enc
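

# Original_FullModel: like FullModel but takes precomputed inputs (peptide
# amino-acid indices and 1024-dimensional protein embeddings, see
# RepeatedModule) and returns per-residue two-class log-probabilities.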
class Original_FullModel(nn.Module):

    def __init__(self, n_layers, d_model, n_head,
                 d_k, d_v, d_inner, return_attention=False, dropout=0.2):
        super().__init__()

        self.repeated_module = RepeatedModule(n_layers, d_model,
                                              n_head, d_k, d_v, d_inner, dropout=dropout)

        self.final_attention_layer = MultiHeadAttentionSequence(n_head, d_model,
                                                                d_k, d_v, dropout=dropout)

        self.final_ffn = FFN(d_model, d_inner, dropout=dropout)
        self.output_projection_prot = nn.Linear(d_model, 2)

        self.softmax_prot = nn.LogSoftmax(dim=-1)

        self.return_attention = return_attention

    def forward(self, peptide_sequence, protein_sequence):
        prot_enc, sequence_enc, sequence_attention_list, prot_attention_list, \
            prot_seq_attention_list, seq_prot_attention_list = self.repeated_module(peptide_sequence,
                                                                                    protein_sequence)

        prot_enc, final_prot_seq_attention = self.final_attention_layer(prot_enc, sequence_enc, sequence_enc)

        prot_enc = self.final_ffn(prot_enc)

        prot_enc = self.softmax_prot(self.output_projection_prot(prot_enc))

        if not self.return_attention:
            return prot_enc
        else:
            return prot_enc, sequence_attention_list, prot_attention_list, \
                prot_seq_attention_list, seq_prot_attention_list
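

# ---------------------------------------------------------------------------
# Minimal usage sketch for FullModel. The hyperparameters below are placeholder
# assumptions chosen only for illustration, and the binder/target sequences are
# arbitrary examples; tokenization uses the same ESM-2 checkpoint that
# FullModel loads internally.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    tokenizer = EsmTokenizer.from_pretrained("facebook/esm2_t33_650M_UR50D")
    binder_tokens = tokenizer("ACDEFGHIKL", return_tensors="pt")
    target_tokens = tokenizer("MKTAYIAKQRQISFVKSHFSRQLEERLGLIEVQ", return_tensors="pt")

    # Placeholder hyperparameters (assumed, not taken from the original code).
    model = FullModel(n_layers=2, d_model=64, n_head=4,
                      d_k=16, d_v=16, d_inner=128)
    model.eval()

    with torch.no_grad():
        per_residue_prob = model(binder_tokens, target_tokens)

    # Expected shape: (1, target_length_with_special_tokens, 1), values in (0, 1).
    print(per_residue_prob.shape)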