# --- standard library ---------------------------------------------------------
import gc
import math
import os
import pdb
from collections import defaultdict

# --- third party --------------------------------------------------------------
import esm
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
from datasets import load_from_disk
from torch.cuda.amp import GradScaler, autocast
from torch.optim import AdamW
from torch.optim.lr_scheduler import CosineAnnealingLR
from torch.utils.data import DataLoader, Subset
from tqdm import tqdm
from transformers import AutoTokenizer, get_linear_schedule_with_warmup

# Restrict the process to a single physical GPU.  This assignment is kept ahead
# of the `ppl` import so that anything `ppl` does at import time sees it.
os.environ['CUDA_VISIBLE_DEVICES'] = '1'

# --- project local ------------------------------------------------------------
# The rest of this script uses `compute_pseudo_perplexity` and `pepmlm`, which
# are not defined in this file and therefore come from this wildcard import.
from ppl import *

##################### Hyper-parameters #############################################
# Candidate binder peptides to screen.
binders = [
    'RPPGPAPARRLYA', 'KKKVAAKLLDLST', 'SQVKVVKYFLTKR', 'NGRPKHKVYLYLK',
    'SLGTEIILDTMNK', 'CPGGIRCCAGSYG', 'MPPRSRPNSLPDG', 'ATEVELVTNILYR',
    'WLYDRITSVDLSA', 'RPPPPMKAKKRPD', 'KSPPKSRPRPRLH', 'LQGVIGIRLILA',
    'TGNVFVHIQHFPE', 'QDAEEAIYELAAI', 'TSYCKILSVETCA', 'TLRRPEYKKLRLD',
    'KSKSEVSTQLQNQ', 'AASSSKKNDLQAS', 'ATGKAPRGPRKTG', 'AKWLIVRASRPCP',
    'SPRKQRRSRTASA', 'SICQQCFWSSSED', 'KRNLAIRPSLVAP', 'AACAVEQPWSCCC',
    'ADKYRSFTDKFLT', 'VGNFEVVDDKFNK', 'NEEVALKWTVHTS', 'AERQRVRRLLCGP',
    'KPKKKNPMEKLHD', 'EEEDEETYEGLFE', 'KRKVTMTPLRQSS', 'RCGGGRYGFRRYQ',
    'ARCRPLYGRFKCV', 'CQATFGCSWRFAD', 'LQGLLRLLSDSDD', 'GRSRAGPPAAAIN',
    'NTLPNFPKSMLSS',
]

# Target protein sequences: the reference (wild-type) form and its mutated form.
wildtype = 'MAEYLASIFGTEKDKVNCSFYFKIGACRHGDRCSRLHNKPTFSQTIALLNIYRNPQNSSQSADGLRCAVSDVEMQEHYDEFFEEVFTEMEEKYGEVEEMNVCDNLGDHLVGNVYVKFRREEDAEKAVIDLNNRWFNGQPIHAELSPVTDFREACCRQYEMGECTRGGFCNFMHLKPISRELRRELYGRRRKKHRSRSRSRERRSRSRDRGRGGGGGGGGGGGGRERDRRRSRDRERSGRF'
mutant = 'MAEYLASIFGTEKDKVNCSFYFKIGACRHGDRCFRLHNKPTFSQTIALLNIYRNPQNSSQSADGLRCAVSDVEMQEHYDEFFEEVFTEMEEKYGEVEEMNVCDNLGDHLVGNVYVKFRREEDAEKAVIDLNNRWFNGQPIHAELSPVTDFREACCRQYEMGECTRGGFCNFMHLKPISRELRRELYGRRRKKHRSRSRSRERRSRSRDRGRGGGGGGGGGGGGRERDRRRSRDRERSGRF'

# Cross-attention geometry; must match the values the checkpoint was trained with.
n_heads = 4
d_k = 32
d_v = 32

# Trained MutBind weights to load.
checkpoint_path = '/home/tc415/muPPIt_embedding/checkpoints/mutBind_small'

# "cuda:0" is the first *visible* device, i.e. the GPU selected above.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
####################################################################################

# VHSE8 descriptors: eight numeric values per standard amino acid.
vhse8_values = {
    'A': [0.15, -1.11, -1.35, -0.92, 0.02, -0.91, 0.36, -0.48],
    'R': [-1.47, 1.45, 1.24, 1.27, 1.55, 1.47, 1.30, 0.83],
    'N': [-0.99, 0.00, 0.69, -0.37, -0.55, 0.85, 0.73, -0.80],
    'D': [-1.15, 0.67, -0.41, -0.01, -2.68, 1.31, 0.03, 0.56],
    'C': [0.18, -1.67, -0.21, 0.00, 1.20, -1.61, -0.19, -0.41],
    'Q': [-0.96, 0.12, 0.18, 0.16, 0.09, 0.42, -0.20, -0.41],
    'E': [-1.18, 0.40, 0.10, 0.36, -2.16, -0.17, 0.91, 0.36],
    'G': [-0.20, -1.53, -2.63, 2.28, -0.53, -1.18, -1.34, 1.10],
    'H': [-0.43, -0.25, 0.37, 0.19, 0.51, 1.28, 0.93, 0.65],
    'I': [1.27, 0.14, 0.30, -1.80, 0.30, -1.61, -0.16, -0.13],
    'L': [1.36, 0.07, 0.26, -0.80, 0.22, -1.37, 0.08, -0.62],
    'K': [-1.17, 0.70, 0.80, 1.64, 0.67, 1.63, 0.13, -0.01],
    'M': [1.01, -0.53, 0.43, 0.00, 0.23, 0.10, -0.86, -0.68],
    'F': [1.52, 0.61, 0.95, -0.16, 0.25, 0.28, -1.33, -0.65],
    'P': [0.22, -0.17, -0.50, -0.05, 0.01, -1.34, 0.19, 3.56],
    'S': [-0.67, -0.86, -1.07, -0.41, -0.32, 0.27, -0.64, 0.11],
    'T': [-0.34, -0.51, -0.55, -1.06, 0.01, -0.01, -0.79, 0.39],
    'W': [1.50, 2.06, 1.79, 0.75, 0.75, 0.13, -1.06, -0.85],
    'Y': [0.61, 1.60, 1.17, 0.73, 0.53, 0.25, -0.96, -0.52],
    'V': [0.76, -0.92, 0.17, -1.91, 0.22, -1.40, -0.24, -0.03],
}

# Token id of each amino acid.  These ids select rows of `vhse8_tensor`, so they
# must agree with the ids found in the token tensors given to MutBind.forward.
aa_to_idx = {
    'A': 5, 'R': 10, 'N': 17, 'D': 13, 'C': 23, 'Q': 16, 'E': 9, 'G': 6,
    'H': 21, 'I': 12, 'L': 4, 'K': 15, 'M': 20, 'F': 18, 'P': 14, 'S': 8,
    'T': 11, 'W': 22, 'Y': 19, 'V': 7,
}

# (33, 8) lookup table indexed by token id.  Rows for ids that are not listed in
# `aa_to_idx` are left at zero.
vhse8_tensor = torch.zeros(33, 8)
for aa, values in vhse8_values.items():
    vhse8_tensor[aa_to_idx[aa]] = torch.tensor(values)
vhse8_tensor = vhse8_tensor.to(device)
vhse8_tensor.requires_grad = False


class MutBind(torch.nn.Module):
    """Score whether a binder peptide prefers the wild-type or the mutant target.

    Each sequence is embedded with a frozen ESM-2 model, and the per-residue
    embedding is concatenated with its 8 VHSE8 descriptors.  The binder then
    attends to each target through a shared multi-head cross-attention; the two
    pooled results are subtracted and mapped to two logits by a small MLP.

    Args:
        d_node: Width of the per-residue feature fed to the Q/K/V projections,
            i.e. the ESM embedding width plus the 8 VHSE8 values.
        d_k: Per-head query/key width.
        d_v: Per-head value width.
        n_heads: Number of attention heads.
        lr: Stored as ``self.learning_rate``; not used inside this class.
    """

    def __init__(self, d_node, d_k, d_v, n_heads, lr):
        super().__init__()

        # Frozen ESM-2 encoder: only the attention projections and MLP train.
        self.esm, self.alphabet = esm.pretrained.esm2_t33_650M_UR50D()
        for param in self.esm.parameters():
            param.requires_grad = False

        self.q = nn.Linear(d_node, n_heads * d_k)
        self.k = nn.Linear(d_node, n_heads * d_k)
        self.v = nn.Linear(d_node, n_heads * d_v)

        self.layer_norm = torch.nn.LayerNorm(n_heads * d_v)

        self.map = torch.nn.Sequential(
            torch.nn.Linear(n_heads * d_v, (n_heads * d_v) // 2),
            torch.nn.SiLU(),
            torch.nn.Linear((n_heads * d_v) // 2, (n_heads * d_v) // 4),
            torch.nn.SiLU(),
            torch.nn.Linear((n_heads * d_v) // 4, 2),
        )

        self.learning_rate = lr
        self.n_heads = n_heads
        self.d_k = d_k
        self.d_v = d_v
        self.d_node = d_node

    def _embed(self, tokens):
        """Return per-token features: masked ESM layer-33 output + VHSE8 values.

        Args:
            tokens: Integer tensor of token ids, shape (B, L).

        Returns:
            Tensor of shape (B, L, esm_width + 8).
        """
        # Zero the ESM representation at padding positions.
        pad_mask = (tokens != self.alphabet.padding_idx).int()
        esm_out = self.esm(tokens, repr_layers=[33], return_contacts=False)
        residue_repr = esm_out["representations"][33] * pad_mask.unsqueeze(-1)
        return torch.concat([residue_repr, vhse8_tensor[tokens]], dim=-1)

    def forward(self, binder_tokens, wt_tokens, mut_tokens):
        """Compute the two-class logits for one batch.

        Args:
            binder_tokens: Token ids of the binder peptides, shape (B, L_binder).
            wt_tokens: Token ids of the wild-type targets, shape (B, L_wt).
            mut_tokens: Token ids of the mutant targets, shape (B, L_mut).

        Returns:
            Logits of shape (B, 2).  The calling script reports index 0 as
            "WILDTYPE" and index 1 as "MUTANT".
        """
        # The encoder is frozen, so no graph is needed for the embeddings.
        with torch.no_grad():
            binder_embed = self._embed(binder_tokens)
            mut_embed = self._embed(mut_tokens)
            wt_embed = self._embed(wt_tokens)

        # Binder residues are the queries; target residues supply keys/values.
        binder_wt_reciprocal = self.cross_attention(binder_embed, wt_embed)
        binder_mut_reciprocal = self.cross_attention(binder_embed, mut_embed)

        # Normalise, then average over the binder positions.
        binder_wt_reciprocal = self.layer_norm(binder_wt_reciprocal).mean(dim=1)
        binder_mut_reciprocal = self.layer_norm(binder_mut_reciprocal).mean(dim=1)

        difference = binder_wt_reciprocal - binder_mut_reciprocal  # (B, n_heads * d_v)

        logits = self.map(difference)  # (B, 2)
        return logits

    def cross_attention(self, embed_1, embed_2):
        """Scaled dot-product multi-head attention from `embed_1` onto `embed_2`.

        Args:
            embed_1: Query-side features, shape (B, L1, d_node).
            embed_2: Key/value-side features, shape (B, L2, d_node).

        Returns:
            Tensor of shape (B, L1, n_heads * d_v).
        """
        B, L1, _ = embed_1.shape
        _, L2, _ = embed_2.shape

        Q = self.q(embed_1).view(B, L1, self.n_heads, self.d_k)  # (B, L1, n_heads, d_k)
        K = self.k(embed_2).view(B, L2, self.n_heads, self.d_k)  # (B, L2, n_heads, d_k)
        V = self.v(embed_2).view(B, L2, self.n_heads, self.d_v)  # (B, L2, n_heads, d_v)

        Q = Q.transpose(1, 2)  # (B, n_heads, L1, d_k)
        K = K.transpose(1, 2)  # (B, n_heads, L2, d_k)
        V = V.transpose(1, 2)  # (B, n_heads, L2, d_v)

        attention_scores = torch.matmul(Q, K.transpose(-2, -1)) / (self.d_k ** 0.5)  # (B, n_heads, L1, L2)
        attention_weights = F.softmax(attention_scores, dim=-1)  # (B, n_heads, L1, L2)

        output = torch.matmul(attention_weights, V)  # (B, n_heads, L1, d_v)
        output = output.transpose(1, 2).contiguous().view(B, L1, self.n_heads * self.d_v)  # (B, L1, n_heads * d_v)
        return output


# NOTE(review): if the wildcard import from `ppl` also provides a `tokenizer`,
# this assignment replaces it, so `compute_pseudo_perplexity` below receives
# this ESM-2 tokenizer -- confirm that is intended.
tokenizer = AutoTokenizer.from_pretrained("facebook/esm2_t33_650M_UR50D")


def _encode(sequence):
    """Tokenize one sequence into a (1, L) id tensor on `device`."""
    return torch.tensor(tokenizer(sequence)['input_ids']).unsqueeze(0).to(device)


def predict(model, binder, wildtype, mutant):
    """Return the model's class probabilities for one binder/target pair.

    Args:
        model: A MutBind instance (expected to be in eval mode).
        binder: Binder peptide sequence.
        wildtype: Wild-type target sequence.
        mutant: Mutant target sequence.

    Returns:
        1-D tensor of two probabilities: index 0 is reported by this script as
        "WILDTYPE", index 1 as "MUTANT".
    """
    binder_tokens = _encode(binder)
    wt_tokens = _encode(wildtype)
    mut_tokens = _encode(mutant)

    # Inference only: skip autograd bookkeeping for the trainable layers too.
    with torch.no_grad():
        # Call the module itself (not .forward) so registered hooks still run.
        logits = model(binder_tokens, wt_tokens, mut_tokens)

    # Explicit dim: the implicit-dim form of softmax is deprecated.
    return F.softmax(logits, dim=-1).squeeze()


model = MutBind(d_node=1288, d_k=d_k, d_v=d_v, n_heads=n_heads, lr=None).to(device)
# map_location lets a checkpoint saved from a GPU load on whatever `device` is.
model.load_state_dict(torch.load(checkpoint_path, map_location=device))
model.eval()

# Stage 1: keep the binders that the classifier assigns to the mutant.
good = []
for binder in binders:
    prob = predict(model, binder, wildtype, mutant)
    if prob[0].item() > 0.5:
        print(f"{binder} -> WILDTYPE: {prob[0].item():.4f}\n")
    else:
        print(f"{binder} -> MUTANT: {prob[1].item():.4f}\n")
        good.append(binder)

print(f"Good:\n{good}")

# Stage 2: of those, keep the binders whose pseudo-perplexity is lower against
# the mutant than against the wild type.
final = []
for binder in good:
    wt_ppl = compute_pseudo_perplexity(pepmlm, tokenizer, wildtype, binder)
    mut_ppl = compute_pseudo_perplexity(pepmlm, tokenizer, mutant, binder)
    print(f"{binder}:\n{wt_ppl}\n{mut_ppl}\n")
    if wt_ppl > mut_ppl:
        final.append(binder)

print(f"Final:\n{final}")

# Report the point substitutions between the two target sequences (1-based).
if len(wildtype) != len(mutant):
    # zip() stops at the shorter sequence; make the truncation visible.
    print(f"Length mismatch: wildtype={len(wildtype)}, mutant={len(mutant)}; "
          f"comparing the first {min(len(wildtype), len(mutant))} positions only")

num = 0
for pos, (wt_res, mut_res) in enumerate(zip(wildtype, mutant), start=1):
    if wt_res != mut_res:
        num += 1
        print(f"Pos {pos}: {wt_res} -> {mut_res}")

print(f"\n# Mutations = {num}")