Delete dataset
- dataset/PPI_README.md +0 -8
- dataset/PPI_contamination.py +0 -169
- dataset/PPI_extract_full_sequence.py +0 -176
- dataset/PPI_final_contamination.py +0 -165
- dataset/compute_class_weights.py +0 -47
- dataset/peptide_static_batching.py +0 -83
- dataset/prebatching.py +0 -197
- dataset/static_prebatching.py +0 -162
dataset/PPI_README.md
DELETED
@@ -1,8 +0,0 @@
# 1. Run contamination.py to get initial results and error sequences.

`python -u contamination.py -i 1,2,3,4`

# 2. Run extract_full_sequence.py to get full sequences for the error sequences. extract_full_sequence.py can only be run for one id at a time.

`python -u extract_full_sequence.py -id 1`

# 3. After getting full sequences for the error sequences, run final_contamination.py to get the final results.

`python -u final_contamination.py -i 1,2,3,4`
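For reference, the three steps above can also be chained from a single driver script. The sketch below is hypothetical (it is not part of the deleted files): it assumes the three scripts sit in the current working directory and that batches 1-4 exist, and it simply reproduces the order of the commands listed above.

```python
# Hypothetical driver that runs the three contamination steps in order.
import subprocess

batches = "1,2,3,4"

# Step 1: initial contamination pass, which also emits error sequences.
subprocess.run(["python", "-u", "contamination.py", "-i", batches], check=True)

# Step 2: pull full sequences for the error sequences, one batch id at a time.
for batch_id in batches.split(","):
    subprocess.run(["python", "-u", "extract_full_sequence.py", "-id", batch_id], check=True)

# Step 3: final contamination pass over the corrected sequences.
subprocess.run(["python", "-u", "final_contamination.py", "-i", batches], check=True)
```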
dataset/PPI_contamination.py
DELETED
@@ -1,169 +0,0 @@
"""BLOSUM-guided motif contamination"""
import pandas as pd
import blosum as bl
import ast
import pickle
from Bio import SeqIO
from math import ceil
from sklearn.model_selection import train_test_split
import random
from Bio.Seq import Seq
from Bio.SeqRecord import SeqRecord
import argparse


def main(i):
    random.seed(42)

    blosum = bl.BLOSUM(62)

    def get_least_likely_substitution(residue):
        if residue not in blosum:
            return residue  # If the residue is not in the BLOSUM matrix, return it as is
        matrix_keys = list(blosum.keys())
        min_score = min(blosum[residue][r] for r in matrix_keys if r != '*' and r != 'J')
        least_likely_residues = [r for r in matrix_keys if r != '*' and r != 'J' and blosum[residue][r] == min_score]
        least_likely_residue = random.choice(least_likely_residues)
        return least_likely_residue

    # Convert the raw JSON batch to CSV, then reload it
    json_file = f"raw_data/processed_6A_results_batch_{i}.json"
    df = pd.read_json(json_file)
    df.to_csv(f"raw_data/processed_6A_results_batch_{i}.csv", index=False)
    df = pd.read_csv(f"raw_data/processed_6A_results_batch_{i}.csv")

    output_csv = f"contaminated_data/processed_6A_results_batch_{i}.csv"
    error_csv = f"contaminated_data/error_6A_results_batch_{i}.csv"

    new_rows = []
    error_rows = []

    for idx, row in df.iterrows():
        flag1 = False  # whether there are errors when mutating Sequence1
        flag2 = False  # whether there are errors when mutating Sequence2

        chain1 = row['Chain1'].upper()
        chain2 = row['Chain2'].upper()
        sequence1 = row['Sequence1']
        sequence2 = row['Sequence2']
        chain_1_motifs = ast.literal_eval(row['Chain_1_motifs'])
        chain_2_motifs = ast.literal_eval(row['Chain_2_motifs'])
        chain_1_offset = row['Chain_1_offset']
        chain_2_offset = row['Chain_2_offset']

        # Create a new entry by mutating sequence1
        sequence1_list = list(sequence1)
        modified_chain_1_motifs = []
        if len(chain_1_motifs) > 0:

            # Ignore entries where the motif length equals the sequence length, because they would be too hard for models to learn
            if len(chain_1_motifs) == len(sequence1):
                flag1 = True

            for motif in chain_1_motifs:
                res, pos = motif.split('_')
                # Errors in the motifs, or misalignments between sequence and motif
                if int(pos) >= len(sequence1) or int(pos) < 0 or res != sequence1[int(pos)]:
                    error_rows.append({
                        'PDB_ID': row['PDB_ID'] + '_' + chain1 + '_' + chain2,
                        'Chain': chain1,
                        'Sequence': sequence1,
                        'Error_motif': motif,
                        'Chain_offset': row['Chain_1_offset']
                    })
                    flag1 = True
                    break

                least_likely_residue = get_least_likely_substitution(res)
                sequence1_list[int(pos)] = least_likely_residue
                modified_chain_1_motifs.append(res + '_' + pos + '_' + least_likely_residue)

            # Only save the entries without errors that do not need to be ignored
            if flag1 is False:
                modified_sequence1 = ''.join(sequence1_list)
                new_rows.append({
                    'PDB_ID': row['PDB_ID'] + '_' + chain1 + '_' + chain2,
                    'Chain1': chain1,
                    'Sequence1': modified_sequence1,
                    'Chain2': chain2,
                    'Sequence2': sequence2,
                    'Chain_1_motifs': str(modified_chain_1_motifs),
                    'Chain_2_motifs': row['Chain_2_motifs'],
                    'Chain_1_offset': row['Chain_1_offset'],
                    'Chain_2_offset': row['Chain_2_offset'],
                    'Modified_chain': chain1,
                    'Original_sequence': sequence1,
                })

        # If sequence2 is identical to sequence1 and the motifs match, there is no need to mutate sequence2
        if sequence1 == sequence2 and chain_1_motifs == chain_2_motifs:
            continue

        # Create a new entry by mutating sequence2, using the same logic as for sequence1
        if len(chain_2_motifs) > 0:
            if len(chain_2_motifs) == len(sequence2):
                flag2 = True
            sequence2_list = list(sequence2)
            modified_chain_2_motifs = []
            for motif in chain_2_motifs:
                res, pos = motif.split('_')
                if int(pos) >= len(sequence2) or int(pos) < 0 or res != sequence2[int(pos)]:
                    error_rows.append({
                        'PDB_ID': row['PDB_ID'] + '_' + chain1 + '_' + chain2,
                        'Chain': chain2,
                        'Sequence': sequence2,
                        'Error_motif': motif,
                        'Chain_offset': row['Chain_2_offset']
                    })
                    flag2 = True
                    break

                least_likely_residue = get_least_likely_substitution(res)
                sequence2_list[int(pos)] = least_likely_residue
                modified_chain_2_motifs.append(res + '_' + pos + '_' + least_likely_residue)

            if flag2 is False:
                modified_sequence2 = ''.join(sequence2_list)
                new_rows.append({
                    'PDB_ID': row['PDB_ID'] + '_' + chain2 + '_' + chain1,
                    'Chain1': chain1,
                    'Sequence1': sequence1,
                    'Chain2': chain2,
                    'Sequence2': modified_sequence2,
                    'Chain_1_motifs': row['Chain_1_motifs'],
                    'Chain_2_motifs': str(modified_chain_2_motifs),
                    'Chain_1_offset': row['Chain_1_offset'],
                    'Chain_2_offset': row['Chain_2_offset'],
                    'Modified_chain': chain2,
                    'Original_sequence': sequence2,
                })

    # Finished mutation
    new_df = pd.DataFrame(new_rows)

    # Deduplicate
    columns_to_check = ['Sequence1', 'Sequence2', 'Chain_1_motifs', 'Chain_2_motifs', 'Chain_1_offset', 'Chain_2_offset']
    deduplicated_new_df = new_df.drop_duplicates(subset=columns_to_check)
    print(f"Number of rows before deduplication: {len(new_df)}")
    print(f"Number of rows after deduplication: {len(deduplicated_new_df)}")

    deduplicated_new_df.to_csv(output_csv, index=False)

    # Save error sequences to another file
    error_df = pd.DataFrame(error_rows)
    error_df.to_csv(error_csv, index=False)


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('-i')
    args = parser.parse_args()

    i_s = args.i  # e.g. 2,3,4,5,6,7,8,9,10
    for i in i_s.split(','):
        print(int(i))
        main(int(i))
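The core of this script is the BLOSUM62-guided substitution: for each interface residue it picks one of the lowest-scoring (least likely) replacements. Below is a minimal standalone sketch of that lookup using the same `blosum` package as the script above; the example residue is arbitrary and the sketch is only illustrative.

```python
import random
import blosum as bl

blosum = bl.BLOSUM(62)

def least_likely(residue):
    # Residues absent from the matrix are returned unchanged.
    if residue not in blosum:
        return residue
    keys = [r for r in blosum.keys() if r not in ('*', 'J')]
    min_score = min(blosum[residue][r] for r in keys)
    return random.choice([r for r in keys if blosum[residue][r] == min_score])

random.seed(42)
print(least_likely('W'))  # one of the lowest-scoring substitutions for tryptophan
```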
dataset/PPI_extract_full_sequence.py
DELETED
@@ -1,176 +0,0 @@
"""Pull full sequences from PDB files for error sequences"""
import json
import os
import logging
from Bio import PDB
import warnings
import requests
import pickle
import pandas as pd
import argparse

warnings.filterwarnings("ignore", category=PDB.PDBExceptions.PDBConstructionWarning)
logging.basicConfig(filename='pdb.log', level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')

AA_CODE_MAP = {
    'ALA': 'A', 'ARG': 'R', 'ASN': 'N', 'ASP': 'D',
    'CYS': 'C', 'GLN': 'Q', 'GLU': 'E', 'GLY': 'G',
    'HIS': 'H', 'ILE': 'I', 'LEU': 'L', 'LYS': 'K',
    'MET': 'M', 'PHE': 'F', 'PRO': 'P', 'SER': 'S',
    'THR': 'T', 'TRP': 'W', 'TYR': 'Y', 'VAL': 'V'
}


# Missing residues are recorded in REMARK 465 fields of PDB files
def extract_remark_465(pdb_file_path):
    remark_465_lines = []
    with open(pdb_file_path, 'r') as file:
        for line in file:
            if line.startswith("REMARK 465 "):
                remark_465_lines.append(line.strip())
    return remark_465_lines[2:]


def parse_remark_465(remark_465_lines):
    missing_residues = {}
    for line in remark_465_lines:
        parts = line.split()
        if len(parts) < 5:
            continue
        chain_id = parts[3]
        resseq = int(parts[4])
        resname = parts[2]
        if chain_id not in missing_residues:
            missing_residues[chain_id] = []
        missing_residues[chain_id].append((resseq, resname))
    return missing_residues


def extract_sequences(structure, target_chain_id, missing_residues):
    for chain in structure.get_chains():
        chain_id = chain.get_id()
        residues = list(chain.get_residues())
        if chain_id == target_chain_id:
            seq_list = []
            resseq_set = set(res.get_id()[1] for res in residues)
            min_resseq_struct = min(resseq_set, default=1)
            max_resseq_struct = max(resseq_set, default=0)
            max_resseq_missing = max((x[0] for x in missing_residues.get(chain_id, [])), default=0)
            resseq_max = max(max_resseq_struct, max_resseq_missing)

            # Walk over the full residue range, filling observed and missing positions
            for i in range(min_resseq_struct, resseq_max + 1):
                if i in resseq_set:
                    resname = next(res.get_resname() for res in residues if res.get_id()[1] == i)
                    seq_list.append(AA_CODE_MAP.get(resname, 'X'))
                elif chain_id in missing_residues and i in [x[0] for x in missing_residues[chain_id]]:
                    resname = next(x[1] for x in missing_residues[chain_id] if x[0] == i)
                    seq_list.append(AA_CODE_MAP.get(resname, 'X'))

            chain_seq = ''.join(seq_list).strip('X')

            return chain_seq


def download_pdb(pdb_id, id):
    url = f"https://files.rcsb.org/download/{pdb_id}.pdb"
    file_path = f"pdb{id}/{pdb_id}.pdb"

    while True:
        try:
            # Download the PDB file, retrying on network errors
            response = requests.get(url)
            response.raise_for_status()
            with open(file_path, "wb") as file:
                file.write(response.content)
            return file_path
        except requests.exceptions.RequestException as e:
            print(f"Failed to download {pdb_id}.pdb: {e}")
            continue


def delete_pdb(pdb_id, chain_id, id):
    file_path = f"pdb{id}/{pdb_id}.pdb"
    if os.path.exists(file_path):
        os.remove(file_path)
        print(f"Deleted {pdb_id}.pdb for {chain_id}")
    else:
        print(f"File {pdb_id}.pdb does not exist")


def process_entry(entry, chain_id, results, id):
    pdb_id = entry[0:4]

    pdb_file_path = download_pdb(pdb_id, id)

    if os.path.exists(pdb_file_path):
        try:
            parser = PDB.PDBParser()
            structure = parser.get_structure(pdb_id, pdb_file_path)

            # Extract and parse REMARK 465
            remark_465 = extract_remark_465(pdb_file_path)
            missing_residues = parse_remark_465(remark_465)

            # Get the full sequence for the target chain
            chain_seq = extract_sequences(structure, chain_id, missing_residues)

            for index, row in results.iterrows():
                if row['PDB_ID'] == pdb_id:
                    if row['Chain1'] == chain_id:
                        results.at[index, 'Sequence1'] = chain_seq
                    elif row['Chain2'] == chain_id:
                        results.at[index, 'Sequence2'] = chain_seq
                    else:
                        raise NotImplementedError

            delete_pdb(pdb_id, chain_id, id)

        except Exception as e:
            logging.error(f'Failed to process {pdb_id}: {str(e)}')
    else:
        logging.error(f'PDB file {pdb_id}.pdb not found')


def main(id):
    # Load the PDB_ID list and the corresponding chain IDs
    df = pd.read_csv(f'contaminated_data/error_6A_results_batch_{id}.csv')
    pdb_id_list = df['PDB_ID'].tolist()
    chain_id_list = df['Chain'].tolist()
    processed = []

    rs = pd.read_csv(f'raw_data/processed_6A_results_batch_{id}.csv')

    for i in range(len(pdb_id_list)):
        entry, chain_id = pdb_id_list[i], chain_id_list[i].upper()  # e.g. 6x85_D_F, D
        if {entry: chain_id} not in processed:
            processed.append({entry: chain_id})
            process_entry(entry, chain_id, rs, id)

        if i % 100 == 0:
            rs.to_csv(f'raw_data/corrected_processed_6A_results_batch_{id}.csv', index=False)
            print(f"Saving for i={i}")

    rs.to_csv(f'raw_data/corrected_processed_6A_results_batch_{id}.csv', index=False)


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument('-id')
    args = parser.parse_args()

    print(int(args.id))

    if not os.path.exists(f"pdb{int(args.id)}"):
        os.makedirs(f"pdb{int(args.id)}")
    main(int(args.id))
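As a quick illustration of the REMARK 465 handling above: the parser splits each record on whitespace and reads the residue name, chain id, and residue number from fixed positions. The sketch below uses a hypothetical record (not taken from any specific PDB entry) and only mirrors the parsing already shown in the script.

```python
# A hypothetical REMARK 465 record: residue name, chain id, residue number.
line = "REMARK 465     MET A     1"
parts = line.split()               # ['REMARK', '465', 'MET', 'A', '1']
resname, chain_id, resseq = parts[2], parts[3], int(parts[4])
print(chain_id, resseq, resname)   # A 1 MET
```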
dataset/PPI_final_contamination.py
DELETED
@@ -1,165 +0,0 @@
"""Final motif contamination after pulling full sequences"""
import pandas as pd
import blosum as bl
import ast
import pickle
from Bio import SeqIO
from math import ceil
from sklearn.model_selection import train_test_split
import random
from Bio.Seq import Seq
from Bio.SeqRecord import SeqRecord
import argparse


def main(i):
    random.seed(42)

    blosum = bl.BLOSUM(62)

    def get_least_likely_substitution(residue):
        if residue not in blosum:
            return residue  # If the residue is not in the BLOSUM matrix, return it as is
        matrix_keys = list(blosum.keys())
        min_score = min(blosum[residue][r] for r in matrix_keys if r != '*' and r != 'J')
        least_likely_residues = [r for r in matrix_keys if r != '*' and r != 'J' and blosum[residue][r] == min_score]
        least_likely_residue = random.choice(least_likely_residues)
        return least_likely_residue

    df = pd.read_csv(f"raw_data/corrected_processed_6A_results_batch_{i}.csv")

    output_csv = f"contaminated_data/processed_6A_results_batch_{i}.csv"
    error_csv = f"contaminated_data/error_6A_results_batch_{i}.csv"

    new_rows = []
    error_rows = []

    for idx, row in df.iterrows():
        flag1 = False  # whether there are errors when mutating Sequence1
        flag2 = False  # whether there are errors when mutating Sequence2

        chain1 = row['Chain1'].upper()
        chain2 = row['Chain2'].upper()
        sequence1 = row['Sequence1']
        sequence2 = row['Sequence2']
        chain_1_motifs = ast.literal_eval(row['Chain_1_motifs'])
        chain_2_motifs = ast.literal_eval(row['Chain_2_motifs'])
        chain_1_offset = row['Chain_1_offset']
        chain_2_offset = row['Chain_2_offset']

        # Create a new entry by mutating sequence1
        sequence1_list = list(sequence1)
        modified_chain_1_motifs = []
        if len(chain_1_motifs) > 0:

            # Ignore entries where the motif length equals the sequence length, because they would be too hard for models to learn
            if len(chain_1_motifs) == len(sequence1):
                flag1 = True

            for motif in chain_1_motifs:
                res, pos = motif.split('_')
                # Errors in the motifs, or misalignments between sequence and motif
                if int(pos) >= len(sequence1) or int(pos) < 0 or res != sequence1[int(pos)]:
                    error_rows.append({
                        'PDB_ID': row['PDB_ID'] + '_' + chain1 + '_' + chain2,
                        'Chain': chain1,
                        'Sequence': sequence1,
                        'Error_motif': motif,
                        'Chain_offset': row['Chain_1_offset']
                    })
                    flag1 = True
                    break

                least_likely_residue = get_least_likely_substitution(res)
                sequence1_list[int(pos)] = least_likely_residue
                modified_chain_1_motifs.append(res + '_' + pos + '_' + least_likely_residue)

            # Only save the entries without errors that do not need to be ignored
            if flag1 is False:
                modified_sequence1 = ''.join(sequence1_list)
                new_rows.append({
                    'PDB_ID': row['PDB_ID'] + '_' + chain1 + '_' + chain2,
                    'Chain1': chain1,
                    'Sequence1': modified_sequence1,
                    'Chain2': chain2,
                    'Sequence2': sequence2,
                    'Chain_1_motifs': str(modified_chain_1_motifs),
                    'Chain_2_motifs': row['Chain_2_motifs'],
                    'Chain_1_offset': row['Chain_1_offset'],
                    'Chain_2_offset': row['Chain_2_offset'],
                    'Modified_chain': chain1,
                    'Original_sequence': sequence1,
                })

        # If sequence2 is identical to sequence1 and the motifs match, there is no need to mutate sequence2
        if sequence1 == sequence2 and chain_1_motifs == chain_2_motifs:
            continue

        # Create a new entry by mutating sequence2, using the same logic as for sequence1
        if len(chain_2_motifs) > 0:
            if len(chain_2_motifs) == len(sequence2):
                flag2 = True
            sequence2_list = list(sequence2)
            modified_chain_2_motifs = []
            for motif in chain_2_motifs:
                res, pos = motif.split('_')
                if int(pos) >= len(sequence2) or int(pos) < 0 or res != sequence2[int(pos)]:
                    error_rows.append({
                        'PDB_ID': row['PDB_ID'] + '_' + chain1 + '_' + chain2,
                        'Chain': chain2,
                        'Sequence': sequence2,
                        'Error_motif': motif,
                        'Chain_offset': row['Chain_2_offset']
                    })
                    flag2 = True
                    break

                least_likely_residue = get_least_likely_substitution(res)
                sequence2_list[int(pos)] = least_likely_residue
                modified_chain_2_motifs.append(res + '_' + pos + '_' + least_likely_residue)

            if flag2 is False:
                modified_sequence2 = ''.join(sequence2_list)
                new_rows.append({
                    'PDB_ID': row['PDB_ID'] + '_' + chain2 + '_' + chain1,
                    'Chain1': chain1,
                    'Sequence1': sequence1,
                    'Chain2': chain2,
                    'Sequence2': modified_sequence2,
                    'Chain_1_motifs': row['Chain_1_motifs'],
                    'Chain_2_motifs': str(modified_chain_2_motifs),
                    'Chain_1_offset': row['Chain_1_offset'],
                    'Chain_2_offset': row['Chain_2_offset'],
                    'Modified_chain': chain2,
                    'Original_sequence': sequence2,
                })

    # Finished mutation
    new_df = pd.DataFrame(new_rows)

    # Deduplicate
    columns_to_check = ['Sequence1', 'Sequence2', 'Chain_1_motifs', 'Chain_2_motifs', 'Chain_1_offset', 'Chain_2_offset']
    deduplicated_new_df = new_df.drop_duplicates(subset=columns_to_check)
    print(f"Number of rows before deduplication: {len(new_df)}")
    print(f"Number of rows after deduplication: {len(deduplicated_new_df)}")

    deduplicated_new_df.to_csv(output_csv, index=False)

    error_df = pd.DataFrame(error_rows)
    error_df.to_csv(error_csv, index=False)


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('-i')
    args = parser.parse_args()

    i_s = args.i  # e.g. 2,3,4,5,6,7,8,9,10
    for i in i_s.split(','):
        print(int(i))
        main(int(i))
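Both contamination scripts share a simple string convention for motif bookkeeping: an input motif `RES_POS` is rewritten as `RES_POS_NEWRES` once the position has been mutated. A tiny worked example, with made-up residue and position values, just to make the convention explicit:

```python
motif = "A_12"                 # original residue 'A' at 0-based position 12
res, pos = motif.split('_')
least_likely_residue = "W"     # stand-in for get_least_likely_substitution(res)
modified = res + '_' + pos + '_' + least_likely_residue
print(modified)                # A_12_W
```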
dataset/compute_class_weights.py
DELETED
@@ -1,47 +0,0 @@
import pandas as pd
import ast
from sklearn.model_selection import train_test_split
import numpy as np
import torch.nn.functional as F
import torch
from torch.utils.data import Dataset
from datasets import Dataset as HFDataset, DatasetDict
from transformers import AutoTokenizer


def main():

    data = pd.read_csv('dataset_drop_500.csv')

    print(len(data))

    binding_sites = data['mutTarget_motifs'].tolist()
    targets = data['Target'].tolist()

    # No need to pad the first position of the binding sites for class-weight calculations
    binding_sites = [ast.literal_eval(binding_site) for binding_site in binding_sites]
    binding_sites = [len(binding_site) for binding_site in binding_sites]
    targets = [len(seq) for seq in targets]

    train_val_data, test_data = train_test_split(data, test_size=0.2, random_state=42)
    train_data, val_data = train_test_split(train_val_data, test_size=0.25, random_state=42)

    train_index = train_data.index.to_numpy()
    print(len(train_index))

    train_binding_dataset = [binding_sites[i] for i in train_index]
    train_targets = [targets[i] for i in train_index]

    # Balanced class weights: weight_c = N_total / (2 * N_c)
    num_binding_sites = sum(train_binding_dataset)
    num_total = sum(train_targets)
    num_non_binding_sites = num_total - num_binding_sites
    weight_for_binding = num_total / (2 * num_binding_sites)
    weight_for_non_binding = num_total / (2 * num_non_binding_sites)

    print(weight_for_binding, weight_for_non_binding)


if __name__ == "__main__":
    main()
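The weighting above is the standard balanced-class formula, weight_c = N_total / (2 * N_c), applied to binding versus non-binding residues counted over the training split. A small numerical example with made-up counts, purely to illustrate the formula:

```python
# Made-up counts, only to show how the weights behave.
num_binding_sites = 10_000
num_total = 250_000
num_non_binding_sites = num_total - num_binding_sites

weight_for_binding = num_total / (2 * num_binding_sites)          # 12.5
weight_for_non_binding = num_total / (2 * num_non_binding_sites)  # ~0.52
print(weight_for_binding, weight_for_non_binding)
```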
dataset/peptide_static_batching.py
DELETED
@@ -1,83 +0,0 @@
import pandas as pd
import ast
from sklearn.model_selection import train_test_split
import numpy as np
import torch.nn.functional as F
import torch
from torch.utils.data import Dataset
from datasets import Dataset as HFDataset, DatasetDict
from transformers import AutoTokenizer


class TripletDataset(Dataset):
    def __init__(self, anchors, positives, binding_sites, tokenizer, max_sequence_length=40000):
        self.anchors = anchors
        self.positives = positives
        self.binding_sites = binding_sites
        self.tokenizer = tokenizer
        self.max_sequence_length = max_sequence_length
        self.triplets = []
        self.precompute_triplets()

    def __len__(self):
        return len(self.triplets)

    def __getitem__(self, index):
        return self.triplets[index]

    def precompute_triplets(self):
        self.triplets = []
        for anchor, positive, binding_site in zip(self.anchors, self.positives, self.binding_sites):
            anchor_tokens = self.tokenizer(anchor, return_tensors='pt', padding=True, truncation=True,
                                           max_length=self.max_sequence_length)
            positive_tokens = self.tokenizer(positive, return_tensors='pt', padding=True, truncation=True,
                                             max_length=self.max_sequence_length)

            # Mask out the first and last tokens because they are <bos> and <eos>
            anchor_tokens['attention_mask'][0][0] = 0
            anchor_tokens['attention_mask'][0][-1] = 0
            positive_tokens['attention_mask'][0][0] = 0
            positive_tokens['attention_mask'][0][-1] = 0

            self.triplets.append((anchor_tokens, positive_tokens, binding_site))
        return self.triplets


def main():

    data = pd.read_csv('/home/tc415/muPPIt/dataset/pep_prot/pep_prot_test.csv')

    print(len(data))

    positives = data['Binder'].tolist()
    anchors = data['Target'].tolist()
    binding_sites = data['Motif'].tolist()

    # Add 1 to each site index because ESM-2 prepends a start token to the sequence
    binding_sites = [binding_site.split(',') for binding_site in binding_sites]
    binding_sites = [[int(site) + 1 for site in binding_site] for binding_site in binding_sites]

    train_anchor_dataset = np.array(anchors)
    train_positive_dataset = np.array(positives)
    train_binding_dataset = binding_sites

    # Create an instance of the tokenizer
    tokenizer = AutoTokenizer.from_pretrained("facebook/esm2_t33_650M_UR50D")

    # Initialize the TripletDataset
    train_dataset = TripletDataset(train_anchor_dataset, train_positive_dataset, train_binding_dataset, tokenizer=tokenizer, max_sequence_length=50000)
    train_prebatched_data_dict = {
        'anchors': [batch[0] for batch in train_dataset.triplets],
        'positives': [batch[1] for batch in train_dataset.triplets],
        'binding_site': [batch[2] for batch in train_dataset.triplets]
    }

    # Convert the dictionary to a HuggingFace Dataset
    train_hf_dataset = HFDataset.from_dict(train_prebatched_data_dict)
    train_hf_dataset.save_to_disk('/home/tc415/muPPIt/dataset/pep_prot_test')


if __name__ == "__main__":
    main()
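Once written with `save_to_disk`, the pre-tokenized dataset can be read back with `datasets.load_from_disk`. A minimal sketch; the path is the one used in the script above, so adjust it for your own environment:

```python
from datasets import load_from_disk

dataset = load_from_disk('/home/tc415/muPPIt/dataset/pep_prot_test')
print(dataset)                      # columns: 'anchors', 'positives', 'binding_site'
print(dataset[0]['binding_site'])   # 1-shifted binding-site indices for the first pair
```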
dataset/prebatching.py
DELETED
@@ -1,197 +0,0 @@
import pandas as pd
import ast
from sklearn.model_selection import train_test_split
import numpy as np
import torch.nn.functional as F
import torch
from torch.utils.data import Dataset
from datasets import Dataset as HFDataset, DatasetDict
from transformers import AutoTokenizer
from lightning.pytorch import seed_everything


class TripletDataset(Dataset):
    def __init__(self, anchors, positives, negatives, binding_sites, tokenizer, max_sequence_length=40000):
        self.anchors = anchors
        self.positives = positives
        self.negatives = negatives
        self.binding_sites = binding_sites
        self.tokenizer = tokenizer
        self.max_sequence_length = max_sequence_length
        self.triplets = self.precompute_triplets()
        self.batch_indices = self.get_batch_indices()
        self.prebatched_data = self.create_prebatched_data()

    def __len__(self):
        return len(self.batch_indices)

    def __getitem__(self, index):
        batch = self.prebatched_data[index]
        return batch

    def precompute_triplets(self):
        triplets = []
        for anchor, positive, negative, binding_site in zip(self.anchors, self.positives, self.negatives,
                                                            self.binding_sites):
            triplets.append((anchor, positive, negative, binding_site))
        return triplets

    def get_batch_indices(self):
        # Sort triplets by combined length and greedily pack them under the token budget
        sizes = [(len(anchor) + len(positive) + len(negative), i) for i, (anchor, positive, negative, _) in
                 enumerate(self.triplets)]
        sizes.sort()
        batches = []
        buf = []
        current_buf_len = 0

        def _flush_current_buf():
            nonlocal current_buf_len, buf
            if len(buf) == 0:
                return
            batches.append(buf)
            buf = []
            current_buf_len = 0

        for sz, i in sizes:
            if current_buf_len + sz > self.max_sequence_length:
                _flush_current_buf()
            buf.append(i)
            current_buf_len += sz

        _flush_current_buf()
        return batches

    def create_prebatched_data(self):
        prebatched_data = []
        for batch_indices in self.batch_indices:
            anchor_batch = []
            positive_batch = []
            negative_batch = []
            binding_site_batch = []

            for index in batch_indices:
                anchor, positive, negative, binding_site = self.triplets[index]
                anchor_batch.append(anchor)
                positive_batch.append(positive)
                negative_batch.append(negative)
                binding_site_batch.append(binding_site)

            anchor_tokens = self.tokenizer(anchor_batch, return_tensors='pt', padding=True, truncation=True,
                                           max_length=self.max_sequence_length)
            positive_tokens = self.tokenizer(positive_batch, return_tensors='pt', padding=True, truncation=True,
                                             max_length=self.max_sequence_length)
            negative_tokens = self.tokenizer(negative_batch, return_tensors='pt', padding=True, truncation=True,
                                             max_length=self.max_sequence_length)

            # Binary per-position targets marking the binding-site residues
            n, max_length = negative_tokens['input_ids'].shape[0], negative_tokens['input_ids'].shape[1]
            target = torch.zeros(n, max_length)
            for i in range(len(binding_site_batch)):
                binding_site = binding_site_batch[i]
                target[i, binding_site] = 1

            # Mask out the first column because it corresponds to the start token
            anchor_tokens['attention_mask'][:, 0] = 0
            positive_tokens['attention_mask'][:, 0] = 0
            negative_tokens['attention_mask'][:, 0] = 0

            prebatched_data.append({
                'anchor_input_ids': anchor_tokens['input_ids'],
                'anchor_attention_mask': anchor_tokens['attention_mask'],
                'positive_input_ids': positive_tokens['input_ids'],
                'positive_attention_mask': positive_tokens['attention_mask'],
                'negative_input_ids': negative_tokens['input_ids'],
                'negative_attention_mask': negative_tokens['attention_mask'],
                'binding_site': target
            })

        return prebatched_data


def main():
    seed_everything(42)

    data = pd.read_csv('dataset/dataset.csv')

    negatives = data['mutTarget'].tolist()
    positives = data['Binder'].tolist()
    anchors = data['Target'].tolist()
    binding_sites = data['mutTarget_motifs'].tolist()

    # Add 1 to each site index because ESM-2 prepends a start token to the sequence
    binding_sites = [ast.literal_eval(binding_site) for binding_site in binding_sites]
    binding_sites = [[int(site.split('_')[1]) + 1 for site in binding_site] for binding_site in binding_sites]

    train_val_data, test_data = train_test_split(data, test_size=0.2, random_state=42)
    train_data, val_data = train_test_split(train_val_data, test_size=0.25, random_state=42)

    train_index = train_data.index.to_numpy()
    val_index = val_data.index.to_numpy()
    test_index = test_data.index.to_numpy()

    train_anchor_dataset = np.array(anchors)[train_index]
    train_negative_dataset = np.array(negatives)[train_index]
    train_positive_dataset = np.array(positives)[train_index]
    train_binding_dataset = [binding_sites[i] for i in train_index]

    val_anchor_dataset = np.array(anchors)[val_index]
    val_negative_dataset = np.array(negatives)[val_index]
    val_positive_dataset = np.array(positives)[val_index]
    val_binding_dataset = [binding_sites[i] for i in val_index]

    test_anchor_dataset = np.array(anchors)[test_index]
    test_negative_dataset = np.array(negatives)[test_index]
    test_positive_dataset = np.array(positives)[test_index]
    test_binding_dataset = [binding_sites[i] for i in test_index]

    # Create an instance of the tokenizer
    tokenizer = AutoTokenizer.from_pretrained("facebook/esm2_t33_650M_UR50D")

    # Initialize the TripletDatasets
    train_dataset = TripletDataset(train_anchor_dataset, train_positive_dataset, train_negative_dataset, train_binding_dataset, tokenizer=tokenizer, max_sequence_length=40000)
    val_dataset = TripletDataset(val_anchor_dataset, val_positive_dataset, val_negative_dataset, val_binding_dataset, tokenizer=tokenizer, max_sequence_length=40000)
    test_dataset = TripletDataset(test_anchor_dataset, test_positive_dataset, test_negative_dataset, test_binding_dataset, tokenizer=tokenizer, max_sequence_length=40000)

    # Convert the prebatched data to a dictionary with each batch as an entry
    train_prebatched_data_dict = {
        'anchor_input_ids': [batch['anchor_input_ids'].numpy() for batch in train_dataset.prebatched_data],
        'anchor_attention_mask': [batch['anchor_attention_mask'].numpy() for batch in train_dataset.prebatched_data],
        'positive_input_ids': [batch['positive_input_ids'].numpy() for batch in train_dataset.prebatched_data],
        'positive_attention_mask': [batch['positive_attention_mask'].numpy() for batch in train_dataset.prebatched_data],
        'negative_input_ids': [batch['negative_input_ids'].numpy() for batch in train_dataset.prebatched_data],
        'negative_attention_mask': [batch['negative_attention_mask'].numpy() for batch in train_dataset.prebatched_data],
        'binding_site': [batch['binding_site'].numpy() for batch in train_dataset.prebatched_data]
    }

    val_prebatched_data_dict = {
        'anchor_input_ids': [batch['anchor_input_ids'].numpy() for batch in val_dataset.prebatched_data],
        'anchor_attention_mask': [batch['anchor_attention_mask'].numpy() for batch in val_dataset.prebatched_data],
        'positive_input_ids': [batch['positive_input_ids'].numpy() for batch in val_dataset.prebatched_data],
        'positive_attention_mask': [batch['positive_attention_mask'].numpy() for batch in val_dataset.prebatched_data],
        'negative_input_ids': [batch['negative_input_ids'].numpy() for batch in val_dataset.prebatched_data],
        'negative_attention_mask': [batch['negative_attention_mask'].numpy() for batch in val_dataset.prebatched_data],
        'binding_site': [batch['binding_site'].numpy() for batch in val_dataset.prebatched_data]
    }

    test_prebatched_data_dict = {
        'anchor_input_ids': [batch['anchor_input_ids'].numpy() for batch in test_dataset.prebatched_data],
        'anchor_attention_mask': [batch['anchor_attention_mask'].numpy() for batch in test_dataset.prebatched_data],
        'positive_input_ids': [batch['positive_input_ids'].numpy() for batch in test_dataset.prebatched_data],
        'positive_attention_mask': [batch['positive_attention_mask'].numpy() for batch in test_dataset.prebatched_data],
        'negative_input_ids': [batch['negative_input_ids'].numpy() for batch in test_dataset.prebatched_data],
        'negative_attention_mask': [batch['negative_attention_mask'].numpy() for batch in test_dataset.prebatched_data],
        'binding_site': [batch['binding_site'].numpy() for batch in test_dataset.prebatched_data]
    }

    # Convert the dictionaries to HuggingFace Datasets
    train_hf_dataset = HFDataset.from_dict(train_prebatched_data_dict)
    train_hf_dataset.save_to_disk('train_mut')

    val_hf_dataset = HFDataset.from_dict(val_prebatched_data_dict)
    val_hf_dataset.save_to_disk('val_mut')

    test_hf_dataset = HFDataset.from_dict(test_prebatched_data_dict)
    test_hf_dataset.save_to_disk('test_mut')


if __name__ == "__main__":
    main()
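The batching strategy in `get_batch_indices` sorts triplets by combined length and starts a new batch whenever adding the next example would push the running token count past `max_sequence_length`. A toy illustration of the same greedy flush logic with made-up lengths and a small budget:

```python
# Toy combined lengths (anchor + positive + negative) for six hypothetical triplets.
sizes = sorted([(1200, 0), (300, 1), (900, 2), (450, 3), (2500, 4), (700, 5)])
max_sequence_length = 2000

batches, buf, current = [], [], 0
for sz, i in sizes:
    if current + sz > max_sequence_length:  # flush before overflowing the budget
        if buf:
            batches.append(buf)
        buf, current = [], 0
    buf.append(i)
    current += sz
if buf:
    batches.append(buf)

print(batches)  # [[1, 3, 5], [2], [0], [4]] -- an oversized example still gets its own batch
```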
dataset/static_prebatching.py
DELETED
@@ -1,162 +0,0 @@
import pandas as pd
import ast
from sklearn.model_selection import train_test_split
import numpy as np
import torch.nn.functional as F
import torch
from torch.utils.data import Dataset
from datasets import Dataset as HFDataset, DatasetDict
from transformers import AutoTokenizer


class TripletDataset(Dataset):
    def __init__(self, anchors, positives, negatives, binding_sites, tokenizer, max_sequence_length=40000):
        self.anchors = anchors
        self.positives = positives
        self.negatives = negatives
        self.binding_sites = binding_sites
        self.tokenizer = tokenizer
        self.max_sequence_length = max_sequence_length
        self.triplets = []
        self.precompute_triplets()

    def __len__(self):
        return len(self.triplets)

    def __getitem__(self, index):
        return self.triplets[index]

    def precompute_triplets(self):
        self.triplets = []
        for anchor, positive, negative, binding_site in zip(self.anchors, self.positives, self.negatives,
                                                            self.binding_sites):
            anchor_tokens = self.tokenizer(anchor, return_tensors='pt', padding=True, truncation=True,
                                           max_length=self.max_sequence_length)
            positive_tokens = self.tokenizer(positive, return_tensors='pt', padding=True, truncation=True,
                                             max_length=self.max_sequence_length)
            negative_tokens = self.tokenizer(negative, return_tensors='pt', padding=True, truncation=True,
                                             max_length=self.max_sequence_length)

            # Mask out the first and last tokens because they are <bos> and <eos>
            anchor_tokens['attention_mask'][0][0] = 0
            anchor_tokens['attention_mask'][0][-1] = 0
            positive_tokens['attention_mask'][0][0] = 0
            positive_tokens['attention_mask'][0][-1] = 0
            negative_tokens['attention_mask'][0][0] = 0
            negative_tokens['attention_mask'][0][-1] = 0

            self.triplets.append((anchor_tokens, positive_tokens, negative_tokens, binding_site))
        return self.triplets


def main():

    data = pd.read_csv('dataset_drop_500.csv')

    print(len(data))

    negatives = data['mutTarget'].tolist()
    positives = data['Binder'].tolist()
    anchors = data['Target'].tolist()
    binding_sites = data['mutTarget_motifs'].tolist()

    # Add 1 to each site index because ESM-2 prepends a start token to the sequence
    binding_sites = [ast.literal_eval(binding_site) for binding_site in binding_sites]
    binding_sites = [[int(site.split('_')[1]) + 1 for site in binding_site] for binding_site in binding_sites]

    train_val_data, test_data = train_test_split(data, test_size=0.2, random_state=42)
    train_data, val_data = train_test_split(train_val_data, test_size=0.25, random_state=42)

    train_index = train_data.index.to_numpy()
    val_index = val_data.index.to_numpy()
    test_index = test_data.index.to_numpy()

    train_anchor_dataset = np.array(anchors)[train_index]
    train_negative_dataset = np.array(negatives)[train_index]
    train_positive_dataset = np.array(positives)[train_index]
    train_binding_dataset = [binding_sites[i] for i in train_index]

    val_anchor_dataset = np.array(anchors)[val_index]
    val_negative_dataset = np.array(negatives)[val_index]
    val_positive_dataset = np.array(positives)[val_index]
    val_binding_dataset = [binding_sites[i] for i in val_index]

    test_anchor_dataset = np.array(anchors)[test_index]
    test_negative_dataset = np.array(negatives)[test_index]
    test_positive_dataset = np.array(positives)[test_index]
    test_binding_dataset = [binding_sites[i] for i in test_index]

    # Create an instance of the tokenizer
    tokenizer = AutoTokenizer.from_pretrained("facebook/esm2_t33_650M_UR50D")

    # Initialize the TripletDatasets
    train_dataset = TripletDataset(train_anchor_dataset, train_positive_dataset, train_negative_dataset, train_binding_dataset, tokenizer=tokenizer, max_sequence_length=50000)
    val_dataset = TripletDataset(val_anchor_dataset, val_positive_dataset, val_negative_dataset, val_binding_dataset, tokenizer=tokenizer, max_sequence_length=50000)
    test_dataset = TripletDataset(test_anchor_dataset, test_positive_dataset, test_negative_dataset, test_binding_dataset, tokenizer=tokenizer, max_sequence_length=50000)

    train_prebatched_data_dict = {
        'anchors': [batch[0] for batch in train_dataset.triplets],
        'positives': [batch[1] for batch in train_dataset.triplets],
        # 'negatives': [batch[2] for batch in train_dataset.triplets],
        'binding_site': [batch[3] for batch in train_dataset.triplets]
    }

    val_prebatched_data_dict = {
        'anchors': [batch[0] for batch in val_dataset.triplets],
        'positives': [batch[1] for batch in val_dataset.triplets],
        # 'negatives': [batch[2] for batch in val_dataset.triplets],
        'binding_site': [batch[3] for batch in val_dataset.triplets]
    }

    test_prebatched_data_dict = {
        'anchors': [batch[0] for batch in test_dataset.triplets],
        'positives': [batch[1] for batch in test_dataset.triplets],
        # 'negatives': [batch[2] for batch in test_dataset.triplets],
        'binding_site': [batch[3] for batch in test_dataset.triplets]
    }

    # Convert the dictionaries to HuggingFace Datasets
    train_hf_dataset = HFDataset.from_dict(train_prebatched_data_dict)
    train_hf_dataset.save_to_disk('/home/tc415/muPPIt/dataset/train_dataset_drop_500')

    val_hf_dataset = HFDataset.from_dict(val_prebatched_data_dict)
    val_hf_dataset.save_to_disk('/home/tc415/muPPIt/dataset/val_dataset_drop_500')

    test_hf_dataset = HFDataset.from_dict(test_prebatched_data_dict)
    test_hf_dataset.save_to_disk('/home/tc415/muPPIt/dataset/test_dataset_drop_500')


if __name__ == "__main__":
    main()
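The `+ 1` shift applied to the binding-site indices in both prebatching scripts accounts for the start token that the ESM-2 tokenizer prepends. A quick sanity-check sketch using the same tokenizer name as above; the peptide sequence is arbitrary:

```python
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("facebook/esm2_t33_650M_UR50D")
tokens = tokenizer("MKTAYIAK", return_tensors='pt')

# One token per residue plus start and end special tokens,
# so residue i of the raw sequence sits at position i + 1 of input_ids.
print(tokens['input_ids'].shape)  # torch.Size([1, 10]) for an 8-residue peptide
```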