from collections import defaultdict

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
import torch
from tqdm import tqdm
from transformers import AutoTokenizer

# `predict` is assumed to provide the muPPIt model class used in main().
from predict import *


def compute_mean(tuple_list):
    """Group (key, value) pairs by key, then print and return the per-key means."""
    sum_count_dict = defaultdict(lambda: [0, 0])  # key -> [running sum, count]

    for key, value in tuple_list:
        sum_count_dict[key][0] += value
        sum_count_dict[key][1] += 1

    mean_dict = {key: round(total / count, 2) for key, (total, count) in sum_count_dict.items()}
    print(dict(sorted(mean_dict.items())))
    return mean_dict
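

# Illustrative usage (not from the original script):
#   compute_mean([(0, 1.0), (0, 3.0), (1, 2.5)])
#   # prints {0: 2.0, 1: 2.5} and returns the same dict

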
def main():
    df = pd.read_csv('/home/tc415/muPPIt_embedding/dataset/correct_skempi.csv')
    results = []

    tokenizer = AutoTokenizer.from_pretrained("facebook/esm2_t33_650M_UR50D")

    model = muPPIt(1288, 8, 0.1, 10, 1e-4)
    model.load_weights('/home/tc415/muPPIt_embedding/checkpoints/new_train_1/model-epoch=15-val_acc=0.57.ckpt')

    device = 'cuda:1'
    model.to(device)
    model.eval()

    for _, row in tqdm(df.iterrows(), total=len(df)):
        binder = row['binder']
        wildtype = row['wt']
        mutant = row['mut']
        mut_aff = np.log10(row['mut_affinity'])
        wt_aff = np.log10(row['wt_affinity'])

        # Tokenize each sequence; [1:-1] strips the <cls>/<eos> special tokens
        # added by the ESM-2 tokenizer, and unsqueeze(0) adds a batch dimension.
        binder_tokens = torch.tensor(tokenizer(binder)['input_ids'][1:-1]).unsqueeze(0).to(device)
        mut_tokens = torch.tensor(tokenizer(mutant)['input_ids'][1:-1]).unsqueeze(0).to(device)
        wt_tokens = torch.tensor(tokenizer(wildtype)['input_ids'][1:-1]).unsqueeze(0).to(device)

        with torch.no_grad():
            distance = model(binder_tokens, wt_tokens, mut_tokens)

        # Bucket each pair by the integer |delta log10(affinity)| between wild
        # type and mutant, recording the model's predicted distance.
        results.append((int(abs(wt_aff - mut_aff)), distance.item()))

    compute_mean(results)
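

# The matplotlib/seaborn imports at the top are not used by main(), which hints
# that the per-bucket means were meant to be plotted. Below is a minimal, hedged
# sketch of such a plot; the function name, labels, and output path are
# assumptions, not from the original script.
# Example: plot_bucket_means(compute_mean(results))
def plot_bucket_means(mean_dict, out_path='bucket_means.png'):
    # Sort buckets so the x-axis runs from small to large affinity changes.
    keys = sorted(mean_dict)
    sns.barplot(x=keys, y=[mean_dict[k] for k in keys])
    plt.xlabel('|delta log10 affinity| bucket')
    plt.ylabel('mean predicted distance')
    plt.tight_layout()
    plt.savefig(out_path)

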
if __name__ == '__main__':
    main()