"""Compare affinity changes in a SKEMPI-derived CSV with muPPIt distances.

For every row of the CSV the script feeds the binder, wild-type and mutant
sequences to a muPPIt checkpoint, pairs the resulting distance with the
truncated absolute difference of the log10 affinities, and prints the mean
distance per affinity-difference bucket.
"""
import pickle
from collections import defaultdict

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
import torch
from tqdm import tqdm

# Supplies ``muPPIt`` and ``AutoTokenizer``. Kept last so that anything it
# re-exports keeps the same precedence it had before.
from predict import *

CSV_PATH = '/home/tc415/muPPIt_embedding/dataset/correct_skempi.csv'
CHECKPOINT_PATH = (
    '/home/tc415/muPPIt_embedding/checkpoints/new_train_1/'
    'model-epoch=15-val_acc=0.57.ckpt'
)
TOKENIZER_NAME = "facebook/esm2_t33_650M_UR50D"
DEVICE = 'cuda:1'


def compute_mean(tuple_list):
    """Print and return the mean value for each distinct key.

    Args:
        tuple_list: Iterable of ``(key, value)`` pairs. Keys must be hashable
            and mutually sortable; values must support ``+`` and ``/``.

    Returns:
        A dict, ordered by ascending key, mapping each key to the mean of its
        values rounded to two decimals. The same dict is printed to stdout.
    """
    sum_count = defaultdict(lambda: [0, 0])  # key -> [running sum, count]
    for key, value in tuple_list:
        entry = sum_count[key]
        entry[0] += value
        entry[1] += 1

    mean_dict = {
        key: round(sum_count[key][0] / sum_count[key][1], 2)
        for key in sorted(sum_count)
    }
    print(mean_dict)
    return mean_dict


def _encode(tokenizer, sequence, device):
    """Return the token ids of ``sequence`` as a batch-of-one tensor on ``device``."""
    # The first and last ids are dropped -- presumably the special tokens the
    # tokenizer adds around the sequence; TODO confirm against the tokenizer.
    token_ids = tokenizer(sequence)['input_ids'][1:-1]
    return torch.tensor(token_ids).unsqueeze(0).to(device)


def main(csv_path=CSV_PATH, checkpoint_path=CHECKPOINT_PATH, device=DEVICE):
    """Score every CSV row with muPPIt and print mean distance per bucket.

    Args:
        csv_path: CSV with columns ``binder``, ``wt``, ``mut``,
            ``wt_affinity`` and ``mut_affinity``.
        checkpoint_path: Checkpoint handed to ``model.load_weights``.
        device: Torch device string the model and token tensors are moved to.
    """
    df = pd.read_csv(csv_path)

    tokenizer = AutoTokenizer.from_pretrained(TOKENIZER_NAME)
    # Positional hyperparameters are defined by ``muPPIt`` in ``predict``;
    # presumably they have to match the checkpoint being loaded.
    model = muPPIt(1288, 8, 0.1, 10, 1e-4)
    model.load_weights(checkpoint_path)
    model.to(device)
    model.eval()

    results = []
    for _, row in tqdm(df.iterrows(), total=len(df)):
        mut_aff = np.log10(row['mut_affinity'])
        wt_aff = np.log10(row['wt_affinity'])

        binder_tokens = _encode(tokenizer, row['binder'], device)
        mut_tokens = _encode(tokenizer, row['mut'], device)
        wt_tokens = _encode(tokenizer, row['wt'], device)

        with torch.no_grad():
            distance = model(binder_tokens, wt_tokens, mut_tokens)

        # if distance > 20:
        #     continue

        # Bucket key: absolute log10 affinity difference truncated to an int.
        results.append((int(abs(wt_aff - mut_aff)), distance.item()))

    compute_mean(results)

    # with open('skempi_distance.pkl', 'wb') as f:
    #     pickle.dump(results, f)

    # x_values = [t[0] for t in results]
    # y_values = [t[1] for t in results]

    # sns.kdeplot(x=x_values, y=y_values, fill=True, cmap='viridis')
    # plt.xlim(0, None)
    # plt.ylim(0, None)
    # plt.xlabel('Affinity difference')
    # plt.ylabel('Distance')

    # NOTE(review): every plotting call above is commented out, so this saves
    # an empty figure. Re-enable the plot or drop this call -- confirm intent.
    plt.savefig('skempi_distance.png')


if __name__ == '__main__':
    main()