import torch
import torch.nn as nn
import torch.nn.functional as F
import pytorch_lightning as pl
from transformers import EsmModel

from .bindevaluator_modules import *
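

# BindEvaluator scores target-protein positions for binding against a peptide:
# both sequences are embedded with a frozen ESM-2 backbone, refined by the
# stacked cross-attention blocks in RepeatedModule3, and projected to one
# logit per target position.  RepeatedModule3, MultiHeadAttentionSequence,
# and FFN are provided by the wildcard import above.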
class BindEvaluator(pl.LightningModule):
    def __init__(self, n_layers, d_model, d_hidden, n_head,
                 d_k, d_v, d_inner, dropout=0.2,
                 learning_rate=0.00001, max_epochs=15, kl_weight=1):
        super().__init__()
        self.esm_model = EsmModel.from_pretrained("facebook/esm2_t33_650M_UR50D")
        # Freeze all ESM-2 parameters; only the layers below are trained.
        for param in self.esm_model.parameters():
            param.requires_grad = False
        self.repeated_module = RepeatedModule3(n_layers, d_model, d_hidden,
                                               n_head, d_k, d_v, d_inner,
                                               dropout=dropout)
        self.final_attention_layer = MultiHeadAttentionSequence(n_head, d_model,
                                                                d_k, d_v,
                                                                dropout=dropout)
        self.final_ffn = FFN(d_model, d_inner, dropout=dropout)
        self.output_projection_prot = nn.Linear(d_model, 1)
        self.learning_rate = learning_rate
        self.max_epochs = max_epochs
        self.kl_weight = kl_weight
        self.classification_threshold = nn.Parameter(torch.tensor(0.5))  # initial threshold
        self.historical_memory = 0.9
        # Class weights: [binding-site, non-binding-site].
        self.class_weights = torch.tensor([3.000471363174231, 0.5999811490272925])
    def forward(self, binder_tokens, target_tokens):
        peptide_sequence = self.esm_model(**binder_tokens).last_hidden_state
        protein_sequence = self.esm_model(**target_tokens).last_hidden_state
        # Only the two encodings are used downstream; the attention lists are
        # returned by RepeatedModule3 for inspection.
        (prot_enc, sequence_enc, sequence_attention_list, prot_attention_list,
         seq_prot_attention_list, prot_seq_attention_list) = self.repeated_module(
            peptide_sequence, protein_sequence)
        prot_enc, final_prot_seq_attention = self.final_attention_layer(
            prot_enc, sequence_enc, sequence_enc)
        prot_enc = self.final_ffn(prot_enc)
        prot_enc = self.output_projection_prot(prot_enc)
        return prot_enc
    def get_probs(self, xt, target_sequence):
        '''
        Inputs:
            - xt: Shape (bsz*seq_len*vocab_size, seq_len)
            - target_sequence: Shape (bsz*seq_len*vocab_size, tgt_len)
        '''
        binder_attention_mask = torch.ones_like(xt)
        target_attention_mask = torch.ones_like(target_sequence)
        # Mask out the <cls>/<eos> positions at both ends of each sequence.
        binder_attention_mask[:, 0] = binder_attention_mask[:, -1] = 0
        target_attention_mask[:, 0] = target_attention_mask[:, -1] = 0
        binder_tokens = {'input_ids': xt,
                         'attention_mask': binder_attention_mask.to(xt.device)}
        target_tokens = {'input_ids': target_sequence,
                         'attention_mask': target_attention_mask.to(target_sequence.device)}
        logits = self.forward(binder_tokens, target_tokens).squeeze(-1)
        # Exclude the special-token positions from the distribution before
        # the softmax; -100 acts as an effective float('-inf').
        logits[:, 0] = logits[:, -1] = -100
        probs = F.softmax(logits, dim=-1)
        return probs  # shape (bsz*seq_len*vocab_size, tgt_len)
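

# Minimal usage sketch (not from the original file): the hyperparameter values
# and toy token ids below are hypothetical placeholders.  d_model is set to
# 1280 to match the hidden size of facebook/esm2_t33_650M_UR50D, whose
# embeddings feed RepeatedModule3 directly.  Because of the relative import
# above, run this as a module, e.g. `python -m <package>.<this_module>`.
if __name__ == "__main__":
    model = BindEvaluator(n_layers=6, d_model=1280, d_hidden=2048, n_head=8,
                          d_k=64, d_v=64, d_inner=2048)
    model.eval()

    bsz, seq_len, tgt_len = 2, 12, 30
    # Dummy ESM-2 token ids: 0 = <cls>, 2 = <eos>, 4..23 = standard residues.
    xt = torch.randint(4, 24, (bsz, seq_len))
    xt[:, 0], xt[:, -1] = 0, 2
    target = torch.randint(4, 24, (bsz, tgt_len))
    target[:, 0], target[:, -1] = 0, 2

    with torch.no_grad():
        probs = model.get_probs(xt, target)
    print(probs.shape)  # torch.Size([2, 30]); each row sums to 1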