import pickle
from collections import defaultdict

import numpy as np
import pandas as pd
import torch
import matplotlib.pyplot as plt
import seaborn as sns
from tqdm import tqdm
from transformers import AutoTokenizer

# muPPIt is defined in the local predict.py
from predict import muPPIt

def compute_mean(tuple_list):
    """Group (key, value) pairs by key and report the mean value per key."""
    sum_count_dict = defaultdict(lambda: [0.0, 0])  # key -> [running sum, count]

    # Accumulate the sum and the number of values seen for each key
    for key, value in tuple_list:
        sum_count_dict[key][0] += value
        sum_count_dict[key][1] += 1

    # Mean (rounded to two decimals) for each unique key
    mean_dict = {key: round(total / count, 2)
                 for key, (total, count) in sum_count_dict.items()}

    print(dict(sorted(mean_dict.items())))
    return mean_dict
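
# Quick sanity check on hypothetical data (not part of the SKEMPI run):
# keys act as integer buckets and the second elements are averaged per bucket.
#   compute_mean([(0, 1.5), (0, 2.5), (1, 4.0)])  # prints {0: 2.0, 1: 4.0}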


def main():
    df = pd.read_csv('/home/tc415/muPPIt_embedding/dataset/correct_skempi.csv')
    results = []

    # ESM-2 protein-sequence tokenizer
    tokenizer = AutoTokenizer.from_pretrained("facebook/esm2_t33_650M_UR50D")

    # Constructor arguments follow the muPPIt signature in predict.py
    model = muPPIt(1288, 8, 0.1, 10, 1e-4)
    model.load_weights('/home/tc415/muPPIt_embedding/checkpoints/new_train_1/model-epoch=15-val_acc=0.57.ckpt')

    device = 'cuda:1'
    model.to(device)
    model.eval()

    for index, row in tqdm(df.iterrows(), total=len(df)):
        binder = row['binder']
        wildtype = row['wt']
        mutant = row['mut']
        # Compare binding affinities on a log10 scale
        mut_aff = np.log10(row['mut_affinity'])
        wt_aff = np.log10(row['wt_affinity'])

        # Tokenize each sequence and strip the special tokens (<cls>, <eos>)
        # that the ESM-2 tokenizer adds, keeping only the residue ids
        binder_tokens = torch.tensor(tokenizer(binder)['input_ids'][1:-1]).unsqueeze(0).to(device)
        mut_tokens = torch.tensor(tokenizer(mutant)['input_ids'][1:-1]).unsqueeze(0).to(device)
        wt_tokens = torch.tensor(tokenizer(wildtype)['input_ids'][1:-1]).unsqueeze(0).to(device)
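        # Illustrative, assuming the standard ESM-2 vocabulary:
        #   tokenizer("MKT")['input_ids'] -> [0, 20, 15, 11, 2]
        #   so the [1:-1] slice keeps only the residue ids [20, 15, 11]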

        with torch.no_grad():
            distance = model(binder_tokens, wt_tokens, mut_tokens)

        # if distance > 20:
        #     continue

        # Bucket each pair by the integer |log10 affinity difference| between
        # wildtype and mutant, paired with the model's predicted distance
        results.append((int(abs(wt_aff - mut_aff)), distance.item()))

    # Report the mean predicted distance per affinity-difference bucket
    compute_mean(results)

    # with open('skempi_distance.pkl', 'wb') as f:
    #     pickle.dump(results, f)

    # x_values = [t[0] for t in results]
    # y_values = [t[1] for t in results]

    # sns.kdeplot(x=x_values, y=y_values, fill=True, cmap='viridis')

    # plt.xlim(0, None)
    # plt.ylim(0, None)

    # plt.xlabel('Affinity difference')
    # plt.ylabel('Distance')

    # plt.savefig('skempi_distance.png')

if __name__ == '__main__':
    main()