AlienChen committed
Commit 0977aa0 · verified · 1 Parent(s): dbeb56d

Delete dataset

dataset/PPI_README.md DELETED
@@ -1,8 +0,0 @@
- # 1. Run contamination.py to get initial results and error sequences.
- `python -u contamination.py -i 1,2,3,4`
-
- # 2. Run extract_full_sequence.py to get full sequences for the error sequences. extract_full_sequence.py can only run for one id at a time.
- `python -u extract_full_sequence.py -id 1`
-
- # 3. After getting full sequences for the error sequences, run final_contamination.py to get the final results.
- `python -u final_contamination.py -i 1,2,3,4`
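All three scripts share the same motif bookkeeping: each interface residue is stored as a `<residue>_<position>` string, and contamination rewrites it to `<residue>_<position>_<substitution>`. A minimal sketch of that convention, with made-up values:

```python
import ast

# A toy 'Chain_1_motifs' cell as it appears in the processed CSVs (values are illustrative)
raw_cell = "['K_12', 'D_45']"
motifs = ast.literal_eval(raw_cell)

for motif in motifs:
    res, pos = motif.split('_')   # original residue and its 0-based position
    print(res, int(pos))

# After contamination the scripts append the substituted residue,
# e.g. 'K_12' becomes 'K_12_W' if W was chosen as the least-likely replacement.
```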
 
dataset/PPI_contamination.py DELETED
@@ -1,169 +0,0 @@
- """BLOSUM-guided motif contamination"""
- import pandas as pd
- import blosum as bl
- import ast
- import pickle
- from Bio import SeqIO
- from math import ceil
- from sklearn.model_selection import train_test_split
- import random
- from Bio.Seq import Seq
- from Bio.SeqRecord import SeqRecord
- import argparse
-
- def main(i):
-     random.seed(42)
-
-     blosum = bl.BLOSUM(62)
-     def get_least_likely_substitution(residue):
-         if residue not in blosum:
-             return residue  # If the residue is not in the BLOSUM matrix, return it unchanged
-         matrix_keys = list(blosum.keys())
-         min_score = min(blosum[residue][r] for r in matrix_keys if r != '*' and r != 'J')
-         least_likely_residues = [r for r in matrix_keys if r != '*' and r != 'J' and blosum[residue][r] == min_score]
-         least_likely_residue = random.choice(least_likely_residues)
-         return least_likely_residue
-
-     json_file = f"raw_data/processed_6A_results_batch_{i}.json"
-     df = pd.read_json(json_file)
-     df.to_csv(f"raw_data/processed_6A_results_batch_{i}.csv", index=False)
-     df = pd.read_csv(f"raw_data/processed_6A_results_batch_{i}.csv")
-
-     output_csv = f"contaminated_data/processed_6A_results_batch_{i}.csv"
-     error_csv = f"contaminated_data/error_6A_results_batch_{i}.csv"
-
-     new_rows = []
-     error_rows = []
-
-     for idx, row in df.iterrows():
-         flag1 = False  # set to True when an error occurs while mutating Sequence1
-         flag2 = False  # set to True when an error occurs while mutating Sequence2
-
-         chain1 = row['Chain1'].upper()
-         chain2 = row['Chain2'].upper()
-         sequence1 = row['Sequence1']
-         sequence2 = row['Sequence2']
-         chain_1_motifs = ast.literal_eval(row['Chain_1_motifs'])
-         chain_2_motifs = ast.literal_eval(row['Chain_2_motifs'])
-         chain_1_offset = row['Chain_1_offset']
-         chain_2_offset = row['Chain_2_offset']
-
-         # Create a new entry by mutating sequence1
-         sequence1_list = list(sequence1)
-         modified_chain_1_motifs = []
-         if len(chain_1_motifs) > 0:
-
-             # Ignore entries whose motif length equals the sequence length, because they would be too hard for models to learn
-             if len(chain_1_motifs) == len(sequence1):
-                 flag1 = True
-
-             for motif in chain_1_motifs:
-                 res, pos = motif.split('_')  # motif format: "<residue>_<position>"
-                 # The motif is malformed or does not align with the sequence
-                 if int(pos) >= len(sequence1) or int(pos) < 0 or res != sequence1[int(pos)]:
-                     error_rows.append({
-                         'PDB_ID': row['PDB_ID'] + '_' + chain1 + '_' + chain2,
-                         'Chain': chain1,
-                         'Sequence': sequence1,
-                         'Error_motif': motif,
-                         'Chain_offset': row['Chain_1_offset']
-                     })
-                     flag1 = True
-                     break
-
-                 least_likely_residue = get_least_likely_substitution(res)
-                 sequence1_list[int(pos)] = least_likely_residue
-                 modified_chain_1_motifs.append(res + '_' + pos + '_' + least_likely_residue)
-
-             # Only save entries that have no errors and are not ignored
-             if flag1 is False:
-                 modified_sequence1 = ''.join(sequence1_list)
-                 new_rows.append({
-                     'PDB_ID': row['PDB_ID'] + '_' + chain1 + '_' + chain2,
-                     'Chain1': chain1,
-                     'Sequence1': modified_sequence1,
-                     'Chain2': chain2,
-                     'Sequence2': sequence2,
-                     'Chain_1_motifs': str(modified_chain_1_motifs),
-                     'Chain_2_motifs': row['Chain_2_motifs'],
-                     'Chain_1_offset': row['Chain_1_offset'],
-                     'Chain_2_offset': row['Chain_2_offset'],
-                     'Modified_chain': chain1,
-                     'Original_sequence': sequence1,
-                 })
-
-         # If sequence2 is identical to sequence1 and so are the motifs, there is no need to mutate sequence2
-         if sequence1 == sequence2 and chain_1_motifs == chain_2_motifs:
-             continue
-
-         # Create a new entry by mutating sequence2, using the same logic as for sequence1
-         if len(chain_2_motifs) > 0:
-             if len(chain_2_motifs) == len(sequence2):
-                 flag2 = True
-             sequence2_list = list(sequence2)
-             modified_chain_2_motifs = []
-             for motif in chain_2_motifs:
-                 res, pos = motif.split('_')
-                 if int(pos) >= len(sequence2) or int(pos) < 0 or res != sequence2[int(pos)]:
-                     error_rows.append({
-                         'PDB_ID': row['PDB_ID'] + '_' + chain1 + '_' + chain2,
-                         'Chain': chain2,
-                         'Sequence': sequence2,
-                         'Error_motif': motif,
-                         'Chain_offset': row['Chain_2_offset']
-                     })
-                     flag2 = True
-                     break
-
-                 least_likely_residue = get_least_likely_substitution(res)
-                 sequence2_list[int(pos)] = least_likely_residue
-                 modified_chain_2_motifs.append(res + '_' + pos + '_' + least_likely_residue)
-
-             if flag2 is False:
-                 modified_sequence2 = ''.join(sequence2_list)
-                 new_rows.append({
-                     'PDB_ID': row['PDB_ID'] + '_' + chain2 + '_' + chain1,
-                     'Chain1': chain1,
-                     'Sequence1': sequence1,
-                     'Chain2': chain2,
-                     'Sequence2': modified_sequence2,
-                     'Chain_1_motifs': row['Chain_1_motifs'],
-                     'Chain_2_motifs': str(modified_chain_2_motifs),
-                     'Chain_1_offset': row['Chain_1_offset'],
-                     'Chain_2_offset': row['Chain_2_offset'],
-                     'Modified_chain': chain2,
-                     'Original_sequence': sequence2,
-                 })
-
-     # Finished mutating; build the output DataFrame
-     new_df = pd.DataFrame(new_rows)
-
-     # Deduplicate
-     columns_to_check = ['Sequence1', 'Sequence2', 'Chain_1_motifs', 'Chain_2_motifs', 'Chain_1_offset', 'Chain_2_offset']
-     deduplicated_new_df = new_df.drop_duplicates(subset=columns_to_check)
-     print(f"Number of rows before deduplication: {len(new_df)}")
-     print(f"Number of rows after deduplication: {len(deduplicated_new_df)}")
-
-     deduplicated_new_df.to_csv(output_csv, index=False)
-
-     # Save error sequences to a separate file
-     error_df = pd.DataFrame(error_rows)
-     error_df.to_csv(error_csv, index=False)
-
- if __name__ == '__main__':
-     parser = argparse.ArgumentParser()
-     parser.add_argument('-i')
-     args = parser.parse_args()
-
-     i_s = args.i  # e.g. 2,3,4,5,6,7,8,9,10
-     for i in i_s.split(','):
-         print(int(i))
-         main(int(i))
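The contamination rule above picks, for each motif residue, the substitution with the lowest BLOSUM62 score, i.e. the replacement the original residue is least likely to tolerate. A standalone sketch of that rule on a toy sequence, mirroring the helper in the script (the sequence and motif values are illustrative):

```python
import ast
import random
import blosum as bl

random.seed(42)
blosum62 = bl.BLOSUM(62)

def least_likely_substitution(residue):
    # Same rule as in the script: lowest BLOSUM62 score, ignoring '*' and 'J'
    if residue not in blosum62:
        return residue
    keys = [k for k in blosum62.keys() if k not in ('*', 'J')]
    min_score = min(blosum62[residue][k] for k in keys)
    return random.choice([k for k in keys if blosum62[residue][k] == min_score])

# Toy example: contaminate two motif positions of a short sequence
sequence = "MKTAYIAKQR"
motifs = ast.literal_eval("['K_1', 'Q_8']")   # "<residue>_<position>" strings

seq_list = list(sequence)
modified_motifs = []
for motif in motifs:
    res, pos = motif.split('_')
    assert seq_list[int(pos)] == res          # motif must align with the sequence
    sub = least_likely_substitution(res)
    seq_list[int(pos)] = sub
    modified_motifs.append(f"{res}_{pos}_{sub}")

print(''.join(seq_list), modified_motifs)
```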
 
dataset/PPI_extract_full_sequence.py DELETED
@@ -1,176 +0,0 @@
- """Pull full sequences from PDB files for error sequences"""
- import json
- import os
- import logging
- from Bio import PDB
- import warnings
- import requests
- import pickle
- import pandas as pd
- import argparse
-
- warnings.filterwarnings("ignore", category=PDB.PDBExceptions.PDBConstructionWarning)
- logging.basicConfig(filename='pdb.log', level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
-
- AA_CODE_MAP = {
-     'ALA': 'A', 'ARG': 'R', 'ASN': 'N', 'ASP': 'D',
-     'CYS': 'C', 'GLN': 'Q', 'GLU': 'E', 'GLY': 'G',
-     'HIS': 'H', 'ILE': 'I', 'LEU': 'L', 'LYS': 'K',
-     'MET': 'M', 'PHE': 'F', 'PRO': 'P', 'SER': 'S',
-     'THR': 'T', 'TRP': 'W', 'TYR': 'Y', 'VAL': 'V'
- }
-
-
- # Missing residues are recorded in the REMARK 465 records of PDB files
- def extract_remark_465(pdb_file_path):
-     remark_465_lines = []
-     with open(pdb_file_path, 'r') as file:
-         for line in file:
-             if line.startswith("REMARK 465 "):
-                 remark_465_lines.append(line.strip())
-     return remark_465_lines[2:]  # skip the leading header lines of the REMARK 465 block
-
-
- def parse_remark_465(remark_465_lines):
-     missing_residues = {}
-     for line in remark_465_lines:
-         parts = line.split()
-         # Data lines look like "REMARK 465     MET A     1"; skip header or malformed lines
-         if len(parts) < 5 or not parts[4].lstrip('-').isdigit():
-             continue
-         chain_id = parts[3]
-         resseq = int(parts[4])
-         resname = parts[2]
-         if chain_id not in missing_residues:
-             missing_residues[chain_id] = []
-         missing_residues[chain_id].append((resseq, resname))
-     return missing_residues
-
-
- def extract_sequences(structure, target_chain_id, missing_residues):
-     for chain in structure.get_chains():
-         chain_id = chain.get_id()
-         residues = list(chain.get_residues())
-         if chain_id == target_chain_id:
-             seq_list = []
-             resseq_set = set(res.get_id()[1] for res in residues)
-             min_resseq_struct = min(resseq_set, default=1)
-             max_resseq_struct = max(resseq_set, default=0)
-             max_resseq_missing = max((x[0] for x in missing_residues.get(chain_id, [])), default=0)
-             resseq_max = max(max_resseq_struct, max_resseq_missing)
-
-             # Walk over every residue number, filling gaps from the missing-residue list
-             for i in range(min_resseq_struct, resseq_max + 1):
-                 if i in resseq_set:
-                     resname = next(res.get_resname() for res in residues if res.get_id()[1] == i)
-                     seq_list.append(AA_CODE_MAP.get(resname, 'X'))
-                 elif chain_id in missing_residues and i in [x[0] for x in missing_residues[chain_id]]:
-                     resname = next(x[1] for x in missing_residues[chain_id] if x[0] == i)
-                     seq_list.append(AA_CODE_MAP.get(resname, 'X'))
-
-             chain_seq = ''.join(seq_list).strip('X')
-
-             return chain_seq
-
-
- def download_pdb(pdb_id, id):
-     # proxies = {
-     #     "http": "http://127.0.0.1:1080",
-     #     "https": "http://127.0.0.1:1080",
-     # }
-
-     url = f"https://files.rcsb.org/download/{pdb_id}.pdb"
-     file_path = f"pdb{id}/{pdb_id}.pdb"
-
-     # Retry until the download succeeds
-     while True:
-         try:
-             # Download the PDB file
-             # response = requests.get(url, proxies=proxies)
-             response = requests.get(url)
-             response.raise_for_status()
-             with open(file_path, "wb") as file:
-                 file.write(response.content)
-             return file_path
-         except requests.exceptions.RequestException as e:
-             print(f"Failed to download {pdb_id}.pdb: {e}")
-             continue
-
-
- def delete_pdb(pdb_id, chain_id, id):
-     file_path = f"pdb{id}/{pdb_id}.pdb"
-     if os.path.exists(file_path):
-         os.remove(file_path)
-         print(f"Deleted {pdb_id}.pdb for {chain_id}")
-     else:
-         print(f"File {pdb_id}.pdb does not exist")
-
-
- def process_entry(entry, chain_id, results, id):
-     pdb_id = entry[0:4]
-
-     pdb_file_path = download_pdb(pdb_id, id)
-
-     if os.path.exists(pdb_file_path):
-         try:
-             parser = PDB.PDBParser()
-             structure = parser.get_structure(pdb_id, pdb_file_path)
-
-             # Extract and parse REMARK 465
-             remark_465 = extract_remark_465(pdb_file_path)
-             missing_residues = parse_remark_465(remark_465)
-
-             # Get the full sequence for the target chain
-             chain_seq = extract_sequences(structure, chain_id, missing_residues)
-
-             for index, row in results.iterrows():
-                 if row['PDB_ID'] == pdb_id:
-                     if row['Chain1'] == chain_id:
-                         results.at[index, 'Sequence1'] = chain_seq
-                     elif row['Chain2'] == chain_id:
-                         results.at[index, 'Sequence2'] = chain_seq
-                     else:
-                         # The row belongs to a different chain pair of the same PDB entry; leave it unchanged
-                         pass
-
-             delete_pdb(pdb_id, chain_id, id)
-
-         except Exception as e:
-             logging.error(f'Failed to process {pdb_id}: {str(e)}')
-     else:
-         logging.error(f'PDB file {pdb_id}.pdb not found')
-
-
- def main(id):
-     # Load the PDB_ID list and the corresponding chain IDs
-     df = pd.read_csv(f'contaminated_data/error_6A_results_batch_{id}.csv')
-     pdb_id_list = df['PDB_ID'].tolist()
-     chain_id_list = df['Chain'].tolist()
-     processed = []
-
-     rs = pd.read_csv(f'raw_data/processed_6A_results_batch_{id}.csv')
-
-     for i in range(len(pdb_id_list)):
-         entry, chain_id = pdb_id_list[i], chain_id_list[i].upper()  # e.g. 6x85_D_F, D
-         if {entry: chain_id} not in processed:
-             processed.append({entry: chain_id})
-             process_entry(entry, chain_id, rs, id)
-
-         # Checkpoint the corrected results periodically
-         if i % 100 == 0:
-             rs.to_csv(f'raw_data/corrected_processed_6A_results_batch_{id}.csv', index=False)
-             print(f"Saving for i={i}")
-
-     rs.to_csv(f'raw_data/corrected_processed_6A_results_batch_{id}.csv', index=False)
-
-
- if __name__ == "__main__":
-     parser = argparse.ArgumentParser()
-     parser.add_argument('-id')
-     args = parser.parse_args()
-
-     print(int(args.id))
-
-     if not os.path.exists(f"pdb{int(args.id)}"):
-         os.makedirs(f"pdb{int(args.id)}")
-     main(int(args.id))
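The full-sequence reconstruction relies on the REMARK 465 records, which list residues missing from the deposited coordinates. A self-contained sketch of the parsing step; the record lines are illustrative and abridged, since real PDB files pad the columns and prepend several header lines:

```python
# Illustrative REMARK 465 data lines (abridged)
remark_465_lines = [
    "REMARK 465     MET A     1",
    "REMARK 465     GLY A     2",
    "REMARK 465     SER B    10",
]

missing_residues = {}
for line in remark_465_lines:
    parts = line.split()   # -> ['REMARK', '465', resname, chain, resseq]
    if len(parts) < 5 or not parts[4].lstrip('-').isdigit():
        continue           # header or malformed line
    resname, chain_id, resseq = parts[2], parts[3], int(parts[4])
    missing_residues.setdefault(chain_id, []).append((resseq, resname))

print(missing_residues)
# {'A': [(1, 'MET'), (2, 'GLY')], 'B': [(10, 'SER')]}
```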
 
dataset/PPI_final_contamination.py DELETED
@@ -1,165 +0,0 @@
- """Final motif contamination after pulling full sequences"""
- import pandas as pd
- import blosum as bl
- import ast
- import pickle
- from Bio import SeqIO
- from math import ceil
- from sklearn.model_selection import train_test_split
- import random
- from Bio.Seq import Seq
- from Bio.SeqRecord import SeqRecord
- import argparse
-
- def main(i):
-     random.seed(42)
-
-     blosum = bl.BLOSUM(62)
-     def get_least_likely_substitution(residue):
-         if residue not in blosum:
-             return residue  # If the residue is not in the BLOSUM matrix, return it unchanged
-         matrix_keys = list(blosum.keys())
-         min_score = min(blosum[residue][r] for r in matrix_keys if r != '*' and r != 'J')
-         least_likely_residues = [r for r in matrix_keys if r != '*' and r != 'J' and blosum[residue][r] == min_score]
-         least_likely_residue = random.choice(least_likely_residues)
-         return least_likely_residue
-
-     df = pd.read_csv(f"raw_data/corrected_processed_6A_results_batch_{i}.csv")
-
-     output_csv = f"contaminated_data/processed_6A_results_batch_{i}.csv"
-     error_csv = f"contaminated_data/error_6A_results_batch_{i}.csv"
-
-     new_rows = []
-     error_rows = []
-
-     for idx, row in df.iterrows():
-         flag1 = False  # set to True when an error occurs while mutating Sequence1
-         flag2 = False  # set to True when an error occurs while mutating Sequence2
-
-         chain1 = row['Chain1'].upper()
-         chain2 = row['Chain2'].upper()
-         sequence1 = row['Sequence1']
-         sequence2 = row['Sequence2']
-         chain_1_motifs = ast.literal_eval(row['Chain_1_motifs'])
-         chain_2_motifs = ast.literal_eval(row['Chain_2_motifs'])
-         chain_1_offset = row['Chain_1_offset']
-         chain_2_offset = row['Chain_2_offset']
-
-         # Create a new entry by mutating sequence1
-         sequence1_list = list(sequence1)
-         modified_chain_1_motifs = []
-         if len(chain_1_motifs) > 0:
-
-             # Ignore entries whose motif length equals the sequence length, because they would be too hard for models to learn
-             if len(chain_1_motifs) == len(sequence1):
-                 flag1 = True
-
-             for motif in chain_1_motifs:
-                 res, pos = motif.split('_')  # motif format: "<residue>_<position>"
-                 # The motif is malformed or does not align with the sequence
-                 if int(pos) >= len(sequence1) or int(pos) < 0 or res != sequence1[int(pos)]:
-                     error_rows.append({
-                         'PDB_ID': row['PDB_ID'] + '_' + chain1 + '_' + chain2,
-                         'Chain': chain1,
-                         'Sequence': sequence1,
-                         'Error_motif': motif,
-                         'Chain_offset': row['Chain_1_offset']
-                     })
-                     flag1 = True
-                     break
-
-                 least_likely_residue = get_least_likely_substitution(res)
-                 sequence1_list[int(pos)] = least_likely_residue
-                 modified_chain_1_motifs.append(res + '_' + pos + '_' + least_likely_residue)
-
-             # Only save entries that have no errors and are not ignored
-             if flag1 is False:
-                 modified_sequence1 = ''.join(sequence1_list)
-                 new_rows.append({
-                     'PDB_ID': row['PDB_ID'] + '_' + chain1 + '_' + chain2,
-                     'Chain1': chain1,
-                     'Sequence1': modified_sequence1,
-                     'Chain2': chain2,
-                     'Sequence2': sequence2,
-                     'Chain_1_motifs': str(modified_chain_1_motifs),
-                     'Chain_2_motifs': row['Chain_2_motifs'],
-                     'Chain_1_offset': row['Chain_1_offset'],
-                     'Chain_2_offset': row['Chain_2_offset'],
-                     'Modified_chain': chain1,
-                     'Original_sequence': sequence1,
-                 })
-
-         # If sequence2 is identical to sequence1 and so are the motifs, there is no need to mutate sequence2
-         if sequence1 == sequence2 and chain_1_motifs == chain_2_motifs:
-             continue
-
-         # Create a new entry by mutating sequence2, using the same logic as for sequence1
-         if len(chain_2_motifs) > 0:
-             if len(chain_2_motifs) == len(sequence2):
-                 flag2 = True
-             sequence2_list = list(sequence2)
-             modified_chain_2_motifs = []
-             for motif in chain_2_motifs:
-                 res, pos = motif.split('_')
-                 if int(pos) >= len(sequence2) or int(pos) < 0 or res != sequence2[int(pos)]:
-                     error_rows.append({
-                         'PDB_ID': row['PDB_ID'] + '_' + chain1 + '_' + chain2,
-                         'Chain': chain2,
-                         'Sequence': sequence2,
-                         'Error_motif': motif,
-                         'Chain_offset': row['Chain_2_offset']
-                     })
-                     flag2 = True
-                     break
-
-                 least_likely_residue = get_least_likely_substitution(res)
-                 sequence2_list[int(pos)] = least_likely_residue
-                 modified_chain_2_motifs.append(res + '_' + pos + '_' + least_likely_residue)
-
-             if flag2 is False:
-                 modified_sequence2 = ''.join(sequence2_list)
-                 new_rows.append({
-                     'PDB_ID': row['PDB_ID'] + '_' + chain2 + '_' + chain1,
-                     'Chain1': chain1,
-                     'Sequence1': sequence1,
-                     'Chain2': chain2,
-                     'Sequence2': modified_sequence2,
-                     'Chain_1_motifs': row['Chain_1_motifs'],
-                     'Chain_2_motifs': str(modified_chain_2_motifs),
-                     'Chain_1_offset': row['Chain_1_offset'],
-                     'Chain_2_offset': row['Chain_2_offset'],
-                     'Modified_chain': chain2,
-                     'Original_sequence': sequence2,
-                 })
-
-     # Finished mutating; build the output DataFrame
-     new_df = pd.DataFrame(new_rows)
-
-     # Deduplicate
-     columns_to_check = ['Sequence1', 'Sequence2', 'Chain_1_motifs', 'Chain_2_motifs', 'Chain_1_offset', 'Chain_2_offset']
-     deduplicated_new_df = new_df.drop_duplicates(subset=columns_to_check)
-     print(f"Number of rows before deduplication: {len(new_df)}")
-     print(f"Number of rows after deduplication: {len(deduplicated_new_df)}")
-
-     deduplicated_new_df.to_csv(output_csv, index=False)
-
-     # Save error sequences to a separate file
-     error_df = pd.DataFrame(error_rows)
-     error_df.to_csv(error_csv, index=False)
-
- if __name__ == '__main__':
-     parser = argparse.ArgumentParser()
-     parser.add_argument('-i')
-     args = parser.parse_args()
-
-     i_s = args.i  # e.g. 2,3,4,5,6,7,8,9,10
-     for i in i_s.split(','):
-         print(int(i))
-         main(int(i))
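The final pass rebuilds the contaminated rows and then deduplicates on the sequence, motif, and offset columns. A tiny illustration of `drop_duplicates` with a column subset, using toy rows:

```python
import pandas as pd

# Toy rows: the first two agree on every column used for deduplication
rows = [
    {'PDB_ID': '1abc_A_B', 'Sequence1': 'MKT', 'Sequence2': 'GHE', 'Chain_1_motifs': "['K_1_W']"},
    {'PDB_ID': '1abc_C_D', 'Sequence1': 'MKT', 'Sequence2': 'GHE', 'Chain_1_motifs': "['K_1_W']"},
    {'PDB_ID': '2xyz_A_B', 'Sequence1': 'MRT', 'Sequence2': 'GHE', 'Chain_1_motifs': "['R_1_D']"},
]
df = pd.DataFrame(rows)

# Keep the first occurrence of each (sequence, motif) combination, as in the script above
deduped = df.drop_duplicates(subset=['Sequence1', 'Sequence2', 'Chain_1_motifs'])
print(len(df), len(deduped))  # 3 2
```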
 
dataset/compute_class_weights.py DELETED
@@ -1,47 +0,0 @@
- import pandas as pd
- import ast
- from sklearn.model_selection import train_test_split
- import numpy as np
- import torch.nn.functional as F
- import torch
- from torch.utils.data import Dataset
- from datasets import Dataset as HFDataset, DatasetDict
- from transformers import AutoTokenizer
-
-
- def main():
-     data = pd.read_csv('dataset_drop_500.csv')
-
-     print(len(data))
-
-     binding_sites = data['mutTarget_motifs'].tolist()
-     targets = data['Target'].tolist()
-
-     # No start-token offset is needed here: only the counts of binding and non-binding positions matter for the class weights
-     binding_sites = [ast.literal_eval(binding_site) for binding_site in binding_sites]
-     binding_sites = [len(binding_site) for binding_site in binding_sites]
-     targets = [len(seq) for seq in targets]
-
-     train_val_data, test_data = train_test_split(data, test_size=0.2, random_state=42)
-     train_data, val_data = train_test_split(train_val_data, test_size=0.25, random_state=42)
-
-     train_index = train_data.index.to_numpy()
-     print(len(train_index))
-
-     train_binding_dataset = [binding_sites[i] for i in train_index]
-     train_targets = [targets[i] for i in train_index]
-
-     # Balanced inverse-frequency class weights: weight_c = num_total / (2 * num_c)
-     num_binding_sites = sum(train_binding_dataset)
-     num_total = sum(train_targets)
-     num_non_binding_sites = num_total - num_binding_sites
-     weight_for_binding = num_total / (2 * num_binding_sites)
-     weight_for_non_binding = num_total / (2 * num_non_binding_sites)
-
-     print(weight_for_binding, weight_for_non_binding)
-
-
- if __name__ == "__main__":
-     main()
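The weights printed above follow the balanced inverse-frequency scheme `weight_c = num_total / (2 * num_c)`, so the two classes contribute equally after weighting. A quick worked example with toy counts:

```python
# Toy counts: 10,000 residue positions in total, 500 of them binding sites
num_total = 10_000
num_binding_sites = 500
num_non_binding_sites = num_total - num_binding_sites             # 9,500

weight_for_binding = num_total / (2 * num_binding_sites)          # 10.0
weight_for_non_binding = num_total / (2 * num_non_binding_sites)  # ~0.526

# Both classes now carry the same total weight: 500 * 10.0 == 9500 * (10000 / 19000) == 5000
print(weight_for_binding, weight_for_non_binding)
```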
 
dataset/peptide_static_batching.py DELETED
@@ -1,83 +0,0 @@
- import pandas as pd
- import ast
- from sklearn.model_selection import train_test_split
- import numpy as np
- import torch.nn.functional as F
- import torch
- from torch.utils.data import Dataset
- from datasets import Dataset as HFDataset, DatasetDict
- from transformers import AutoTokenizer
-
-
- class TripletDataset(Dataset):
-     def __init__(self, anchors, positives, binding_sites, tokenizer, max_sequence_length=40000):
-         self.anchors = anchors
-         self.positives = positives
-         self.binding_sites = binding_sites
-         self.tokenizer = tokenizer
-         self.max_sequence_length = max_sequence_length
-         self.triplets = []
-         self.precompute_triplets()
-
-     def __len__(self):
-         return len(self.triplets)
-
-     def __getitem__(self, index):
-         return self.triplets[index]
-
-     def precompute_triplets(self):
-         self.triplets = []
-         for anchor, positive, binding_site in zip(self.anchors, self.positives, self.binding_sites):
-             anchor_tokens = self.tokenizer(anchor, return_tensors='pt', padding=True, truncation=True,
-                                            max_length=self.max_sequence_length)
-             positive_tokens = self.tokenizer(positive, return_tensors='pt', padding=True, truncation=True,
-                                              max_length=self.max_sequence_length)
-
-             # Mask out the first and last tokens, which are <bos> and <eos>
-             anchor_tokens['attention_mask'][0][0] = 0
-             anchor_tokens['attention_mask'][0][-1] = 0
-             positive_tokens['attention_mask'][0][0] = 0
-             positive_tokens['attention_mask'][0][-1] = 0
-
-             self.triplets.append((anchor_tokens, positive_tokens, binding_site))
-         return self.triplets
-
-
- def main():
-     data = pd.read_csv('/home/tc415/muPPIt/dataset/pep_prot/pep_prot_test.csv')
-
-     print(len(data))
-
-     positives = data['Binder'].tolist()
-     anchors = data['Target'].tolist()
-     binding_sites = data['Motif'].tolist()
-
-     # Shift every binding-site index by 1 because ESM-2 prepends a start token
-     binding_sites = [binding_site.split(',') for binding_site in binding_sites]
-     binding_sites = [[int(site) + 1 for site in binding_site] for binding_site in binding_sites]
-
-     train_anchor_dataset = np.array(anchors)
-     train_positive_dataset = np.array(positives)
-     train_binding_dataset = binding_sites
-
-     # Create an instance of the tokenizer
-     tokenizer = AutoTokenizer.from_pretrained("facebook/esm2_t33_650M_UR50D")
-
-     # Initialize the TripletDataset
-     train_dataset = TripletDataset(train_anchor_dataset, train_positive_dataset, train_binding_dataset, tokenizer=tokenizer, max_sequence_length=50000)
-     train_prebatched_data_dict = {
-         'anchors': [batch[0] for batch in train_dataset.triplets],
-         'positives': [batch[1] for batch in train_dataset.triplets],
-         'binding_site': [batch[2] for batch in train_dataset.triplets]
-     }
-
-     # Convert the dictionary to a HuggingFace Dataset
-     train_hf_dataset = HFDataset.from_dict(train_prebatched_data_dict)
-     train_hf_dataset.save_to_disk('/home/tc415/muPPIt/dataset/pep_prot_test')
-
-
- if __name__ == "__main__":
-     main()
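The `+ 1` shift and the attention-mask edits both come from the ESM-2 tokenizer, which wraps each sequence in a start and an end token. A hedged sketch of that bookkeeping on a toy target (it assumes one token per residue plus the two special tokens, which is how the script above treats the tokenizer's output, and it needs network access to download the tokenizer):

```python
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("facebook/esm2_t33_650M_UR50D")

target = "MKTAYIAK"           # toy sequence
binding_sites = [1, 4]        # 0-based indices into the raw sequence

tokens = tokenizer(target, return_tensors='pt')
# input_ids has shape [1, len(target) + 2]: a start token, one token per residue, an end token

# Shift the indices so they still point at the same residues after tokenization
shifted_binding_sites = [i + 1 for i in binding_sites]

# Zero out the special-token positions so downstream losses ignore them
tokens['attention_mask'][0][0] = 0
tokens['attention_mask'][0][-1] = 0

print(shifted_binding_sites)
print(tokens['attention_mask'])
```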
 
dataset/prebatching.py DELETED
@@ -1,197 +0,0 @@
- import pandas as pd
- import ast
- from sklearn.model_selection import train_test_split
- import numpy as np
- import torch.nn.functional as F
- import torch
- from torch.utils.data import Dataset
- from datasets import Dataset as HFDataset, DatasetDict
- from transformers import AutoTokenizer
- from lightning.pytorch import seed_everything
-
-
- class TripletDataset(Dataset):
-     def __init__(self, anchors, positives, negatives, binding_sites, tokenizer, max_sequence_length=40000):
-         self.anchors = anchors
-         self.positives = positives
-         self.negatives = negatives
-         self.binding_sites = binding_sites
-         self.tokenizer = tokenizer
-         self.max_sequence_length = max_sequence_length
-         self.triplets = self.precompute_triplets()
-         self.batch_indices = self.get_batch_indices()
-         self.prebatched_data = self.create_prebatched_data()
-
-     def __len__(self):
-         return len(self.batch_indices)
-
-     def __getitem__(self, index):
-         batch = self.prebatched_data[index]
-         return batch
-
-     def precompute_triplets(self):
-         triplets = []
-         for anchor, positive, negative, binding_site in zip(self.anchors, self.positives, self.negatives,
-                                                             self.binding_sites):
-             triplets.append((anchor, positive, negative, binding_site))
-         return triplets
-
-     def get_batch_indices(self):
-         # Greedily pack length-sorted triplets into batches whose total sequence length stays within max_sequence_length
-         sizes = [(len(anchor) + len(positive) + len(negative), i) for i, (anchor, positive, negative, _) in
-                  enumerate(self.triplets)]
-         sizes.sort()
-         batches = []
-         buf = []
-         current_buf_len = 0
-
-         def _flush_current_buf():
-             nonlocal current_buf_len, buf
-             if len(buf) == 0:
-                 return
-             batches.append(buf)
-             buf = []
-             current_buf_len = 0
-
-         for sz, i in sizes:
-             if current_buf_len + sz > self.max_sequence_length:
-                 _flush_current_buf()
-             buf.append(i)
-             current_buf_len += sz
-
-         _flush_current_buf()
-         return batches
-
-     def create_prebatched_data(self):
-         prebatched_data = []
-         for batch_indices in self.batch_indices:
-             anchor_batch = []
-             positive_batch = []
-             negative_batch = []
-             binding_site_batch = []
-
-             for index in batch_indices:
-                 anchor, positive, negative, binding_site = self.triplets[index]
-                 anchor_batch.append(anchor)
-                 positive_batch.append(positive)
-                 negative_batch.append(negative)
-                 binding_site_batch.append(binding_site)
-
-             anchor_tokens = self.tokenizer(anchor_batch, return_tensors='pt', padding=True, truncation=True,
-                                            max_length=self.max_sequence_length)
-             positive_tokens = self.tokenizer(positive_batch, return_tensors='pt', padding=True, truncation=True,
-                                              max_length=self.max_sequence_length)
-             negative_tokens = self.tokenizer(negative_batch, return_tensors='pt', padding=True, truncation=True,
-                                              max_length=self.max_sequence_length)
-
-             # Build a binary per-position target marking the binding sites
-             n, max_length = negative_tokens['input_ids'].shape[0], negative_tokens['input_ids'].shape[1]
-             target = torch.zeros(n, max_length)
-             for i in range(len(binding_site_batch)):
-                 binding_site = binding_site_batch[i]
-                 target[i, binding_site] = 1
-
-             # Mask out the first column because it corresponds to the start token
-             anchor_tokens['attention_mask'][:, 0] = 0
-             positive_tokens['attention_mask'][:, 0] = 0
-             negative_tokens['attention_mask'][:, 0] = 0
-
-             prebatched_data.append({
-                 'anchor_input_ids': anchor_tokens['input_ids'],
-                 'anchor_attention_mask': anchor_tokens['attention_mask'],
-                 'positive_input_ids': positive_tokens['input_ids'],
-                 'positive_attention_mask': positive_tokens['attention_mask'],
-                 'negative_input_ids': negative_tokens['input_ids'],
-                 'negative_attention_mask': negative_tokens['attention_mask'],
-                 'binding_site': target
-             })
-
-         return prebatched_data
-
-
- def main():
-     seed_everything(42)
-
-     data = pd.read_csv('dataset/dataset.csv')
-
-     negatives = data['mutTarget'].tolist()
-     positives = data['Binder'].tolist()
-     anchors = data['Target'].tolist()
-     binding_sites = data['mutTarget_motifs'].tolist()
-
-     # Shift every binding-site index by 1 because ESM-2 prepends a start token
-     binding_sites = [ast.literal_eval(binding_site) for binding_site in binding_sites]
-     binding_sites = [[int(site.split('_')[1]) + 1 for site in binding_site] for binding_site in binding_sites]
-
-     train_val_data, test_data = train_test_split(data, test_size=0.2, random_state=42)
-     train_data, val_data = train_test_split(train_val_data, test_size=0.25, random_state=42)
-
-     train_index = train_data.index.to_numpy()
-     val_index = val_data.index.to_numpy()
-     test_index = test_data.index.to_numpy()
-
-     train_anchor_dataset = np.array(anchors)[train_index]
-     train_negative_dataset = np.array(negatives)[train_index]
-     train_positive_dataset = np.array(positives)[train_index]
-     train_binding_dataset = [binding_sites[i] for i in train_index]
-
-     val_anchor_dataset = np.array(anchors)[val_index]
-     val_negative_dataset = np.array(negatives)[val_index]
-     val_positive_dataset = np.array(positives)[val_index]
-     val_binding_dataset = [binding_sites[i] for i in val_index]
-
-     test_anchor_dataset = np.array(anchors)[test_index]
-     test_negative_dataset = np.array(negatives)[test_index]
-     test_positive_dataset = np.array(positives)[test_index]
-     test_binding_dataset = [binding_sites[i] for i in test_index]
-
-     # Create an instance of the tokenizer
-     tokenizer = AutoTokenizer.from_pretrained("facebook/esm2_t33_650M_UR50D")
-
-     # Initialize the TripletDatasets
-     train_dataset = TripletDataset(train_anchor_dataset, train_positive_dataset, train_negative_dataset, train_binding_dataset, tokenizer=tokenizer, max_sequence_length=40000)
-     val_dataset = TripletDataset(val_anchor_dataset, val_positive_dataset, val_negative_dataset, val_binding_dataset, tokenizer=tokenizer, max_sequence_length=40000)
-     test_dataset = TripletDataset(test_anchor_dataset, test_positive_dataset, test_negative_dataset, test_binding_dataset, tokenizer=tokenizer, max_sequence_length=40000)
-
-     # Convert the prebatched data to a dictionary with each batch as an entry
-     train_prebatched_data_dict = {
-         'anchor_input_ids': [batch['anchor_input_ids'].numpy() for batch in train_dataset.prebatched_data],
-         'anchor_attention_mask': [batch['anchor_attention_mask'].numpy() for batch in train_dataset.prebatched_data],
-         'positive_input_ids': [batch['positive_input_ids'].numpy() for batch in train_dataset.prebatched_data],
-         'positive_attention_mask': [batch['positive_attention_mask'].numpy() for batch in train_dataset.prebatched_data],
-         'negative_input_ids': [batch['negative_input_ids'].numpy() for batch in train_dataset.prebatched_data],
-         'negative_attention_mask': [batch['negative_attention_mask'].numpy() for batch in train_dataset.prebatched_data],
-         'binding_site': [batch['binding_site'].numpy() for batch in train_dataset.prebatched_data]
-     }
-
-     val_prebatched_data_dict = {
-         'anchor_input_ids': [batch['anchor_input_ids'].numpy() for batch in val_dataset.prebatched_data],
-         'anchor_attention_mask': [batch['anchor_attention_mask'].numpy() for batch in val_dataset.prebatched_data],
-         'positive_input_ids': [batch['positive_input_ids'].numpy() for batch in val_dataset.prebatched_data],
-         'positive_attention_mask': [batch['positive_attention_mask'].numpy() for batch in val_dataset.prebatched_data],
-         'negative_input_ids': [batch['negative_input_ids'].numpy() for batch in val_dataset.prebatched_data],
-         'negative_attention_mask': [batch['negative_attention_mask'].numpy() for batch in val_dataset.prebatched_data],
-         'binding_site': [batch['binding_site'].numpy() for batch in val_dataset.prebatched_data]
-     }
-
-     test_prebatched_data_dict = {
-         'anchor_input_ids': [batch['anchor_input_ids'].numpy() for batch in test_dataset.prebatched_data],
-         'anchor_attention_mask': [batch['anchor_attention_mask'].numpy() for batch in test_dataset.prebatched_data],
-         'positive_input_ids': [batch['positive_input_ids'].numpy() for batch in test_dataset.prebatched_data],
-         'positive_attention_mask': [batch['positive_attention_mask'].numpy() for batch in test_dataset.prebatched_data],
-         'negative_input_ids': [batch['negative_input_ids'].numpy() for batch in test_dataset.prebatched_data],
-         'negative_attention_mask': [batch['negative_attention_mask'].numpy() for batch in test_dataset.prebatched_data],
-         'binding_site': [batch['binding_site'].numpy() for batch in test_dataset.prebatched_data]
-     }
-
-     # Convert the dictionaries to HuggingFace Datasets
-     train_hf_dataset = HFDataset.from_dict(train_prebatched_data_dict)
-     train_hf_dataset.save_to_disk('train_mut')
-
-     val_hf_dataset = HFDataset.from_dict(val_prebatched_data_dict)
-     val_hf_dataset.save_to_disk('val_mut')
-
-     test_hf_dataset = HFDataset.from_dict(test_prebatched_data_dict)
-     test_hf_dataset.save_to_disk('test_mut')
-
-
- if __name__ == "__main__":
-     main()
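`get_batch_indices` is a greedy length-bucketing pass: triplets are sorted by total length and packed into batches until a length budget is hit, so similarly sized sequences end up padded together. A condensed, self-contained re-implementation of the same idea with toy lengths (`pack_by_length` is an illustrative name, not part of the original code):

```python
def pack_by_length(lengths, budget):
    """Greedily pack item indices into batches whose summed length stays within budget.
    Sorting by length first keeps similarly sized items together, which minimizes padding."""
    sizes = sorted((length, i) for i, length in enumerate(lengths))
    batches, buf, buf_len = [], [], 0
    for size, i in sizes:
        if buf and buf_len + size > budget:
            batches.append(buf)
            buf, buf_len = [], 0
        buf.append(i)
        buf_len += size
    if buf:
        batches.append(buf)
    return batches


# Toy triplet lengths (anchor + positive + negative) packed under a budget of 500
print(pack_by_length([120, 300, 90, 450, 200, 80], budget=500))
# [[5, 2, 0, 4], [1], [3]] -- batch totals 490, 300, 450
```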
 
dataset/static_prebatching.py DELETED
@@ -1,162 +0,0 @@
- import pandas as pd
- import ast
- from sklearn.model_selection import train_test_split
- import numpy as np
- import torch.nn.functional as F
- import torch
- from torch.utils.data import Dataset
- from datasets import Dataset as HFDataset, DatasetDict
- from transformers import AutoTokenizer
-
-
- class TripletDataset(Dataset):
-     def __init__(self, anchors, positives, negatives, binding_sites, tokenizer, max_sequence_length=40000):
-         self.anchors = anchors
-         self.positives = positives
-         self.negatives = negatives
-         self.binding_sites = binding_sites
-         self.tokenizer = tokenizer
-         self.max_sequence_length = max_sequence_length
-         self.triplets = []
-         self.precompute_triplets()
-
-     def __len__(self):
-         return len(self.triplets)
-
-     def __getitem__(self, index):
-         return self.triplets[index]
-
-     def precompute_triplets(self):
-         self.triplets = []
-         for anchor, positive, negative, binding_site in zip(self.anchors, self.positives, self.negatives,
-                                                             self.binding_sites):
-             anchor_tokens = self.tokenizer(anchor, return_tensors='pt', padding=True, truncation=True,
-                                            max_length=self.max_sequence_length)
-             positive_tokens = self.tokenizer(positive, return_tensors='pt', padding=True, truncation=True,
-                                              max_length=self.max_sequence_length)
-             negative_tokens = self.tokenizer(negative, return_tensors='pt', padding=True, truncation=True,
-                                              max_length=self.max_sequence_length)
-
-             # Mask out the first and last tokens, which are <bos> and <eos>
-             anchor_tokens['attention_mask'][0][0] = 0
-             anchor_tokens['attention_mask'][0][-1] = 0
-             positive_tokens['attention_mask'][0][0] = 0
-             positive_tokens['attention_mask'][0][-1] = 0
-             negative_tokens['attention_mask'][0][0] = 0
-             negative_tokens['attention_mask'][0][-1] = 0
-
-             self.triplets.append((anchor_tokens, positive_tokens, negative_tokens, binding_site))
-         return self.triplets
-
-
- def main():
-     data = pd.read_csv('dataset_drop_500.csv')
-
-     print(len(data))
-
-     negatives = data['mutTarget'].tolist()
-     positives = data['Binder'].tolist()
-     anchors = data['Target'].tolist()
-     binding_sites = data['mutTarget_motifs'].tolist()
-
-     # Shift every binding-site index by 1 because ESM-2 prepends a start token
-     binding_sites = [ast.literal_eval(binding_site) for binding_site in binding_sites]
-     binding_sites = [[int(site.split('_')[1]) + 1 for site in binding_site] for binding_site in binding_sites]
-
-     train_val_data, test_data = train_test_split(data, test_size=0.2, random_state=42)
-     train_data, val_data = train_test_split(train_val_data, test_size=0.25, random_state=42)
-
-     train_index = train_data.index.to_numpy()
-     val_index = val_data.index.to_numpy()
-     test_index = test_data.index.to_numpy()
-
-     train_anchor_dataset = np.array(anchors)[train_index]
-     train_negative_dataset = np.array(negatives)[train_index]
-     train_positive_dataset = np.array(positives)[train_index]
-     train_binding_dataset = [binding_sites[i] for i in train_index]
-
-     val_anchor_dataset = np.array(anchors)[val_index]
-     val_negative_dataset = np.array(negatives)[val_index]
-     val_positive_dataset = np.array(positives)[val_index]
-     val_binding_dataset = [binding_sites[i] for i in val_index]
-
-     test_anchor_dataset = np.array(anchors)[test_index]
-     test_negative_dataset = np.array(negatives)[test_index]
-     test_positive_dataset = np.array(positives)[test_index]
-     test_binding_dataset = [binding_sites[i] for i in test_index]
-
-     # Create an instance of the tokenizer
-     tokenizer = AutoTokenizer.from_pretrained("facebook/esm2_t33_650M_UR50D")
-
-     # Initialize the TripletDatasets
-     train_dataset = TripletDataset(train_anchor_dataset, train_positive_dataset, train_negative_dataset, train_binding_dataset, tokenizer=tokenizer, max_sequence_length=50000)
-     val_dataset = TripletDataset(val_anchor_dataset, val_positive_dataset, val_negative_dataset, val_binding_dataset, tokenizer=tokenizer, max_sequence_length=50000)
-     test_dataset = TripletDataset(test_anchor_dataset, test_positive_dataset, test_negative_dataset, test_binding_dataset, tokenizer=tokenizer, max_sequence_length=50000)
-
-     train_prebatched_data_dict = {
-         'anchors': [batch[0] for batch in train_dataset.triplets],
-         'positives': [batch[1] for batch in train_dataset.triplets],
-         # 'negatives': [batch[2] for batch in train_dataset.triplets],
-         'binding_site': [batch[3] for batch in train_dataset.triplets]
-     }
-
-     val_prebatched_data_dict = {
-         'anchors': [batch[0] for batch in val_dataset.triplets],
-         'positives': [batch[1] for batch in val_dataset.triplets],
-         # 'negatives': [batch[2] for batch in val_dataset.triplets],
-         'binding_site': [batch[3] for batch in val_dataset.triplets]
-     }
-
-     test_prebatched_data_dict = {
-         'anchors': [batch[0] for batch in test_dataset.triplets],
-         'positives': [batch[1] for batch in test_dataset.triplets],
-         # 'negatives': [batch[2] for batch in test_dataset.triplets],
-         'binding_site': [batch[3] for batch in test_dataset.triplets]
-     }
-
-     # Convert the dictionaries to HuggingFace Datasets
-     train_hf_dataset = HFDataset.from_dict(train_prebatched_data_dict)
-     train_hf_dataset.save_to_disk('/home/tc415/muPPIt/dataset/train_dataset_drop_500')
-
-     val_hf_dataset = HFDataset.from_dict(val_prebatched_data_dict)
-     val_hf_dataset.save_to_disk('/home/tc415/muPPIt/dataset/val_dataset_drop_500')
-
-     test_hf_dataset = HFDataset.from_dict(test_prebatched_data_dict)
-     test_hf_dataset.save_to_disk('/home/tc415/muPPIt/dataset/test_dataset_drop_500')
-
-
- if __name__ == "__main__":
-     main()
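The nested `train_test_split` calls in these prebatching scripts yield a 60/20/20 split: 20% is held out for test, then 25% of the remaining 80% becomes validation. A quick check of that arithmetic:

```python
from sklearn.model_selection import train_test_split

rows = list(range(1000))
train_val, test = train_test_split(rows, test_size=0.2, random_state=42)   # 800 / 200
train, val = train_test_split(train_val, test_size=0.25, random_state=42)  # 600 / 200

print(len(train), len(val), len(test))  # 600 200 200
```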