# File size: 2,332 Bytes
# 65bd8af
import pandas as pd
import matplotlib.pyplot as plt
from tqdm import tqdm
import seaborn as sns
import pickle
from collections import defaultdict
from predict import *
def compute_mean(tuple_list):
    """Print and return the mean value for every distinct key.

    Args:
        tuple_list: Iterable of ``(key, value)`` pairs. Keys must be hashable
            and mutually comparable (they are sorted); values must be numeric.

    Returns:
        A dict mapping each distinct key to the mean of its values, rounded
        to two decimals, with entries ordered by ascending key. An empty
        input yields an empty dict. The same dict is printed to stdout.
    """
    # key -> [running sum, number of occurrences]
    sum_count_dict = defaultdict(lambda: [0, 0])
    for key, value in tuple_list:
        accumulator = sum_count_dict[key]
        accumulator[0] += value
        accumulator[1] += 1

    # A key only exists once it has been seen, so the count is never zero.
    mean_dict = {
        key: round(total / count, 2)
        for key, (total, count) in sum_count_dict.items()
    }
    sorted_means = dict(sorted(mean_dict.items()))
    print(sorted_means)
    # Returning the result (in addition to printing it) lets callers reuse it.
    return sorted_means
def main():
    """Run the muPPIt model over the SKEMPI CSV and summarise its output.

    For every row, the binder, wild-type and mutant sequences are tokenised
    and passed to the model. The model output is paired with the absolute
    difference of the log10 wild-type and mutant affinities, truncated to an
    integer. The collected pairs are handed to ``compute_mean``, which prints
    the mean model output per integer affinity-difference bucket.
    """
    skempi = pd.read_csv('/home/tc415/muPPIt_embedding/dataset/correct_skempi.csv')
    results = []

    tokenizer = AutoTokenizer.from_pretrained("facebook/esm2_t33_650M_UR50D")
    model = muPPIt(1288, 8, 0.1, 10, 1e-4)
    model.load_weights('/home/tc415/muPPIt_embedding/checkpoints/new_train_1/model-epoch=15-val_acc=0.57.ckpt')
    device = 'cuda:1'  # hard-coded: CUDA device with index 1
    model.to(device)
    model.eval()

    def to_token_batch(sequence):
        # [1:-1] drops the first and last token ids (presumably the special
        # start/end tokens the tokenizer adds -- confirm against tokenizer);
        # unsqueeze(0) adds a leading dimension of size one.
        token_ids = tokenizer(sequence)['input_ids'][1:-1]
        return torch.tensor(token_ids).unsqueeze(0).to(device)

    for _, record in tqdm(skempi.iterrows(), total=len(skempi)):
        binder = record['binder']
        wildtype = record['wt']
        mutant = record['mut']

        log_mut_affinity = np.log10(record['mut_affinity'])
        log_wt_affinity = np.log10(record['wt_affinity'])

        binder_tokens = to_token_batch(binder)
        mut_tokens = to_token_batch(mutant)
        wt_tokens = to_token_batch(wildtype)

        # Inference only: no gradients are needed.
        with torch.no_grad():
            distance = model(binder_tokens, wt_tokens, mut_tokens)

        # if distance > 20:
        #     continue

        # int() truncates, so rows are grouped by whole orders of magnitude
        # of affinity change.
        affinity_bucket = int(abs(log_wt_affinity - log_mut_affinity))
        results.append((affinity_bucket, distance.item()))

    compute_mean(results)

    # Disabled: persist the raw pairs and draw a 2-D density plot.
    # with open('skempi_distance.pkl', 'wb') as f:
    #     pickle.dump(results, f)
    # x_values = [t[0] for t in results]
    # y_values = [t[1] for t in results]
    # sns.kdeplot(x=x_values, y=y_values, fill=True, cmap='viridis')
    # plt.xlim(0, None)
    # plt.ylim(0, None)
    # plt.xlabel('Affinity difference')
    # plt.ylabel('Distance')
    # plt.savefig('skempi_distance.png')
# Run the evaluation only when executed as a script, not on import.
if __name__ == '__main__':
    main()