Delete dataset_mlm.py
Browse files- dataset_mlm.py +0 -151
dataset_mlm.py
DELETED
@@ -1,151 +0,0 @@
|
|
1 |
-
import pandas as pd
|
2 |
-
from copy import deepcopy
|
3 |
-
|
4 |
-
import torch
|
5 |
-
from torch.utils.data import TensorDataset, DataLoader
|
6 |
-
from sklearn.model_selection import train_test_split
|
7 |
-
|
8 |
-
from vocab import PepVocab
|
9 |
-
from utils import mask, create_vocab
|
10 |
-
|
11 |
-
# Special tokens that add_tokens_to_vocab() registers on the vocabulary.
# Order is kept exactly as written, since it may determine the index each
# token receives (depends on PepVocab.add_special_token -- confirm there).
addtition_tokens = [
    '<K16>', '<α1β1γδ>', '<Ca22>', '<AChBP>', '<K13>',
    '<α1BAR>', '<α1β1ε>', '<α1AAR>', '<GluN3A>', '<α4β2>',
    '<GluN2B>', '<α75HT3>', '<Na14>', '<α7>', '<GluN2C>',
    '<NET>', '<NavBh>', '<α6β3β4>', '<Na11>', '<Ca13>',
    '<Ca12>', '<Na16>', '<α6α3β2>', '<GluN2A>', '<GluN2D>',
    '<K17>', '<α1β1δε>', '<GABA>', '<α9>', '<K12>',
    '<Kshaker>', '<α3β4>', '<Na18>', '<α3β2>', '<α6α3β2β3>',
    '<α1β1δ>', '<α6α3β4β3>', '<α2β2>', '<α6β4>', '<α2β4>',
    '<Na13>', '<Na12>', '<Na15>', '<α4β4>', '<α7α6β2>',
    '<α1β1γ>', '<NaTTXR>', '<K11>', '<Ca23>', '<α9α10>',
    '<α6α3β4>', '<NaTTXS>', '<Na17>', '<high>', '<low>',
]
|
17 |
-
|
18 |
-
def add_tokens_to_vocab(vocab_mlm: PepVocab):
    """Register the module's extra special tokens on ``vocab_mlm``.

    Passes the module-level ``addtition_tokens`` list to
    ``vocab_mlm.add_special_token`` and hands the very same vocabulary
    object back so the call can be chained.
    """
    extra_tokens = addtition_tokens
    vocab_mlm.add_special_token(extra_tokens)
    return vocab_mlm
|
21 |
-
|
22 |
-
def split_seq(seq, vocab, get_seq=False):
    """Tokenise one pipe-delimited record and frame it with marker tokens.

    The record is expected in the form
    ``"label|label|sequence|msa1|msa2|msa3"``.

    Args:
        seq: The raw record string.
        vocab: Object exposing ``split_seq(str)``. Its result is
            concatenated with lists, so it has to return a list.
        get_seq: If true, tokenise the main sequence (field 2) and keep the
            two leading labels. Otherwise tokenise the three MSA fields
            (fields 3-5) and put ``'[PAD]'`` where the labels would be.

    Returns:
        With ``get_seq`` true, a flat list
        ``['[CLS]', label0, label1, *tokens, '[SEP]']``.
        Otherwise a list of three lists, each shaped
        ``['[CLS]', '[PAD]', '[PAD]', *tokens, '[SEP]']``.

    Raises:
        IndexError: If the record has fewer ``'|'``-separated fields than
            the selected mode reads (3 for the sequence, 6 for the MSAs).
    """
    start, end, pad = '[CLS]', '[SEP]', '[PAD]'

    # Split once; every field below is read from this single list.
    fields = seq.split('|')
    cls_label, act_label = fields[0], fields[1]

    if get_seq:
        pep_seq = fields[2]
        return [start, cls_label, act_label] + vocab.split_seq(pep_seq) + [end]

    # Index explicitly (rather than slicing) so that a short record still
    # fails loudly instead of silently yielding fewer than three MSAs.
    msa_fields = (fields[3], fields[4], fields[5])
    return [[start, pad, pad] + vocab.split_seq(msa) + [end] for msa in msa_fields]
|
46 |
-
|
47 |
-
def get_paded_token_idx(
    vocab_mlm,
    cono_path='/home/ubuntu/work/gecheng/conoGen_final/FinalCono/new_cycle/conoData_C5.csv',
    num_steps=54,
):
    """Load the sequence CSV and turn it into padded tokens and indices.

    Args:
        vocab_mlm: Vocabulary object providing ``split_seq``,
            ``set_get_attn``, ``truncate_pad``, ``get_attention_mask_mat``
            and item lookup.
        cono_path: CSV file with a ``'Sequences'`` column of records in the
            format accepted by ``split_seq``. Defaults to the path the
            function originally had hard-coded.
        num_steps: Target length handed to ``truncate_pad`` for both the
            sequences and the MSAs. Defaults to the original value, 54.

    Returns:
        ``(padded_seq, idx_seq, idx_msa, attn_idx)``: the padded token
        sequences, their vocabulary indices, the vocabulary indices of the
        padded MSAs, and whatever ``get_attention_mask_mat`` returns after
        the sequences were padded.
        NOTE(review): the original inline comments gave the shapes as
        ``[b, 54]`` for ``idx_seq`` and ``[b, 3, 50]`` for ``idx_msa``; the
        50 does not match ``num_steps=54`` -- confirm against PepVocab.
    """
    records = pd.read_csv(cono_path)['Sequences']

    splited_seq = [split_seq(rec, vocab_mlm, True) for rec in records]
    splited_msa = [split_seq(rec, vocab_mlm, False) for rec in records]

    # The attention flag is on only while the main sequences are padded and
    # is switched off again before the MSAs are padded. Presumably it makes
    # truncate_pad record the mask read on the next line -- confirm in
    # PepVocab. The call order must stay as is.
    vocab_mlm.set_get_attn(is_get=True)
    padded_seq = vocab_mlm.truncate_pad(splited_seq, num_steps=num_steps, padding_token='[PAD]')
    attn_idx = vocab_mlm.get_attention_mask_mat()
    vocab_mlm.set_get_attn(is_get=False)

    padded_msa = vocab_mlm.truncate_pad(splited_msa, num_steps=num_steps, padding_token='[PAD]')

    idx_seq = vocab_mlm[padded_seq]
    idx_msa = vocab_mlm[padded_msa]

    return padded_seq, idx_seq, idx_msa, attn_idx
|
66 |
-
|
67 |
-
def get_paded_token_idx_gen(vocab_mlm, seq, new_seq=None, num_steps=54):
    """Pad and index a single record, optionally indexing a replacement sequence.

    This single definition replaces two same-named functions. The later one
    shadowed the earlier and required ``new_seq``, so two-argument calls
    failed; ``new_seq`` is now optional and both call styles work.

    Args:
        vocab_mlm: Vocabulary object providing ``split_seq``,
            ``set_get_attn``, ``truncate_pad``, ``get_attention_mask_mat``
            and item lookup.
        seq: Indexable whose first element (``seq[0]``) is a record string
            in the format accepted by ``split_seq``.
        new_seq: Optional token sequence. When given, ``idx_seq`` is looked
            up from it instead of from the padded record sequence.
        num_steps: Target length handed to ``truncate_pad``. Defaults to the
            original value, 54.

    Returns:
        ``(padded_seq, idx_seq, idx_msa, attn_idx)``. ``padded_seq`` always
        comes from ``seq[0]``, even when ``new_seq`` is supplied.
    """
    record = seq[0]
    splited_seq = split_seq(record, vocab_mlm, True)
    splited_msa = split_seq(record, vocab_mlm, False)

    # The attention flag is on only while the main sequence is padded and is
    # switched off again before the MSAs are padded. The order matters.
    vocab_mlm.set_get_attn(is_get=True)
    padded_seq = vocab_mlm.truncate_pad(splited_seq, num_steps=num_steps, padding_token='[PAD]')
    attn_idx = vocab_mlm.get_attention_mask_mat()
    vocab_mlm.set_get_attn(is_get=False)

    padded_msa = vocab_mlm.truncate_pad(splited_msa, num_steps=num_steps, padding_token='[PAD]')

    # Identity test on purpose: `== None` on an array/tensor argument is
    # evaluated element-wise and cannot be used as a plain condition.
    tokens_to_index = padded_seq if new_seq is None else new_seq
    idx_seq = vocab_mlm[tokens_to_index]
    idx_msa = vocab_mlm[padded_msa]

    return padded_seq, idx_seq, idx_msa, attn_idx
|
112 |
-
|
113 |
-
|
114 |
-
|
115 |
-
def make_mask(seq_ser, start, end, time, vocab_mlm, labels, idx_msa, attn_idx,
              device=None, test_size=0.1, batch_size=128):
    """Mask the token sequences and wrap everything in train/test loaders.

    Args:
        seq_ser: Token sequences; converted to a ``pandas.Series``.
        start, end, time: Passed through unchanged to ``mask`` (imported
            from ``utils``) for every element of the series.
        vocab_mlm: Vocabulary used to look up indices of the masked tokens.
        labels, idx_msa, attn_idx: Tensors stored alongside the masked
            indices; each is moved to ``device``.
        device: Target device (anything ``torch.device`` accepts). If
            ``None``, ``'cuda:1'`` is used as before when at least two CUDA
            devices are visible, otherwise ``'cuda:0'`` or ``'cpu'``.
        test_size: Held-out fraction given to ``train_test_split``.
        batch_size: Batch size of both returned loaders.

    Returns:
        ``(train_loader, test_loader)``; both loaders shuffle.
    """
    if device is None:
        # The original always used 'cuda:1', which fails on hosts with fewer
        # than two GPUs. Fall back only in that case.
        gpu_count = torch.cuda.device_count()
        if gpu_count > 1:
            device = 'cuda:1'
        elif gpu_count == 1:
            device = 'cuda:0'
        else:
            device = 'cpu'
    device = torch.device(device)

    masked_seq = pd.Series(seq_ser).apply(mask, args=(start, end, time))
    masked_idx = torch.tensor(vocab_mlm[list(masked_seq)])

    data_arrays = tuple(
        tensor.to(device) for tensor in (masked_idx, labels, idx_msa, attn_idx)
    )
    dataset = TensorDataset(*data_arrays)

    # Fixed seed keeps the train/test partition identical across calls.
    train_dataset, test_dataset = train_test_split(
        dataset, test_size=test_size, random_state=42, shuffle=True
    )
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=True)

    return train_loader, test_loader
|
128 |
-
|
129 |
-
if __name__ == '__main__':
    # Manual smoke run: build the vocabulary, load the data once, then build
    # the loaders for each mask step 1..49 and print the batch shapes.
    import numpy as np

    vocab_mlm = add_tokens_to_vocab(create_vocab())
    padded_seq, idx_seq, idx_msa, attn_idx = get_paded_token_idx(vocab_mlm)

    labels = torch.tensor(idx_seq)
    idx_msa = torch.tensor(idx_msa)
    attn_idx = torch.tensor(attn_idx)

    for t in np.arange(1, 50):
        # Each step gets its own deep copy of the padded sequences.
        padded_seq_copy = deepcopy(padded_seq)
        train_loader, test_loader = make_mask(
            padded_seq_copy,
            start=0,
            end=49,
            time=t,
            vocab_mlm=vocab_mlm,
            labels=labels,
            idx_msa=idx_msa,
            attn_idx=attn_idx,
        )
        for i, (masked_idx, label, msa, attn) in enumerate(train_loader):
            print(f"the {i}th batch is that masked_idx is {masked_idx.shape}, labels is {label.shape}, idx_msa is {msa.shape}")
        print(f"the {t}th time step is done")
|
149 |
-
|
150 |
-
|
151 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|