AlienChen committed on
Commit d732b63 · verified · 1 Parent(s): f084507

Create peptide_static_batching.py

Files changed (1)
  1. dataset/peptide_static_batching.py +83 -0
dataset/peptide_static_batching.py ADDED
@@ -0,0 +1,83 @@
import pandas as pd
import numpy as np
import torch
from torch.utils.data import Dataset
from datasets import Dataset as HFDataset
from transformers import AutoTokenizer


class TripletDataset(Dataset):
    """Pre-tokenizes (anchor, positive, binding_site) triplets so batching is static."""

    def __init__(self, anchors, positives, binding_sites, tokenizer, max_sequence_length=40000):
        self.anchors = anchors
        self.positives = positives
        self.binding_sites = binding_sites
        self.tokenizer = tokenizer
        self.max_sequence_length = max_sequence_length
        self.triplets = []
        self.precompute_triplets()

    def __len__(self):
        return len(self.triplets)

    def __getitem__(self, index):
        return self.triplets[index]

    def precompute_triplets(self):
        # Tokenize every anchor/positive pair once up front instead of on the fly.
        self.triplets = []
        for anchor, positive, binding_site in zip(self.anchors, self.positives, self.binding_sites):
            anchor_tokens = self.tokenizer(anchor, return_tensors='pt', padding=True, truncation=True,
                                           max_length=self.max_sequence_length)
            positive_tokens = self.tokenizer(positive, return_tensors='pt', padding=True, truncation=True,
                                             max_length=self.max_sequence_length)

            # Mask out the first and last tokens, which are the <bos>/<cls> and <eos> special tokens
            anchor_tokens['attention_mask'][0][0] = 0
            anchor_tokens['attention_mask'][0][-1] = 0
            positive_tokens['attention_mask'][0][0] = 0
            positive_tokens['attention_mask'][0][-1] = 0

            self.triplets.append((anchor_tokens, positive_tokens, binding_site))
        return self.triplets


def main():
    data = pd.read_csv('/home/tc415/muPPIt/dataset/pep_prot/pep_prot_test.csv')
    print(len(data))

    positives = data['Binder'].tolist()
    anchors = data['Target'].tolist()
    binding_sites = data['Motif'].tolist()

    # Shift binding-site indices by 1 because ESM-2 tokenization prepends a start token
    binding_sites = [binding_site.split(',') for binding_site in binding_sites]
    binding_sites = [[int(site) + 1 for site in binding_site] for binding_site in binding_sites]

    train_anchor_dataset = np.array(anchors)
    train_positive_dataset = np.array(positives)
    train_binding_dataset = binding_sites

    # Create an instance of the ESM-2 tokenizer
    tokenizer = AutoTokenizer.from_pretrained("facebook/esm2_t33_650M_UR50D")

    # Initialize the TripletDataset, which pre-tokenizes all triplets
    train_dataset = TripletDataset(train_anchor_dataset, train_positive_dataset, train_binding_dataset,
                                   tokenizer=tokenizer, max_sequence_length=50000)
    train_prebatched_data_dict = {
        'anchors': [triplet[0] for triplet in train_dataset.triplets],
        'positives': [triplet[1] for triplet in train_dataset.triplets],
        'binding_site': [triplet[2] for triplet in train_dataset.triplets]
    }

    # Convert the dictionary to a HuggingFace Dataset and save it to disk
    train_hf_dataset = HFDataset.from_dict(train_prebatched_data_dict)
    train_hf_dataset.save_to_disk('/home/tc415/muPPIt/dataset/pep_prot_test')


if __name__ == "__main__":
    main()
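
For reference, a minimal sketch of reading the saved dataset back (the path is the one the script writes to; the inspection code itself is an illustration and not part of this commit):

# Sketch (assumption, not part of this commit): reload the dataset written by
# save_to_disk above and inspect one pre-tokenized triplet.
from datasets import load_from_disk

reloaded = load_from_disk('/home/tc415/muPPIt/dataset/pep_prot_test')
print(len(reloaded))

example = reloaded[0]
print(example['anchors'].keys())    # input_ids / attention_mask, stored as nested lists
print(example['binding_site'])      # binding-site indices, already shifted by +1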