oucgc1996 committed
Commit 9ca4aa5 · verified · 1 parent: a464b00

Delete dataset_mlm.py

Files changed (1): dataset_mlm.py (+0 -151)
dataset_mlm.py DELETED
@@ -1,151 +0,0 @@
-import pandas as pd
-from copy import deepcopy
-
-import torch
-from torch.utils.data import TensorDataset, DataLoader
-from sklearn.model_selection import train_test_split
-
-from vocab import PepVocab
-from utils import mask, create_vocab
-
-addition_tokens = ['<K16>', '<α1β1γδ>', '<Ca22>', '<AChBP>', '<K13>', '<α1BAR>', '<α1β1ε>', '<α1AAR>', '<GluN3A>', '<α4β2>',
-                   '<GluN2B>', '<α75HT3>', '<Na14>', '<α7>', '<GluN2C>', '<NET>', '<NavBh>', '<α6β3β4>', '<Na11>', '<Ca13>',
-                   '<Ca12>', '<Na16>', '<α6α3β2>', '<GluN2A>', '<GluN2D>', '<K17>', '<α1β1δε>', '<GABA>', '<α9>', '<K12>',
-                   '<Kshaker>', '<α3β4>', '<Na18>', '<α3β2>', '<α6α3β2β3>', '<α1β1δ>', '<α6α3β4β3>', '<α2β2>', '<α6β4>', '<α2β4>',
-                   '<Na13>', '<Na12>', '<Na15>', '<α4β4>', '<α7α6β2>', '<α1β1γ>', '<NaTTXR>', '<K11>', '<Ca23>',
-                   '<α9α10>', '<α6α3β4>', '<NaTTXS>', '<Na17>', '<high>', '<low>']
-
-def add_tokens_to_vocab(vocab_mlm: PepVocab):
-    vocab_mlm.add_special_token(addition_tokens)
-    return vocab_mlm
-
-def split_seq(seq, vocab, get_seq=False):
-    '''
-    Note: this function expects sequences in the format
-    "cls_label|act_label|sequence|msa1|msa2|msa3".
-    '''
-    start = '[CLS]'
-    end = '[SEP]'
-    pad = '[PAD]'
-    cls_label = seq.split('|')[0]
-    act_label = seq.split('|')[1]
-
-    if get_seq:
-        add = lambda x: [start] + [cls_label] + [act_label] + x + [end]
-        pep_seq = seq.split('|')[2]
-        return add(vocab.split_seq(pep_seq))
-    else:
-        add = lambda x: [start] + [pad] + [pad] + x + [end]
-        msa1_seq = seq.split('|')[3]
-        msa2_seq = seq.split('|')[4]
-        msa3_seq = seq.split('|')[5]
-        return [add(vocab.split_seq(msa1_seq))] + [add(vocab.split_seq(msa2_seq))] + [add(vocab.split_seq(msa3_seq))]
-
-def get_paded_token_idx(vocab_mlm):
-    cono_path = '/home/ubuntu/work/gecheng/conoGen_final/FinalCono/new_cycle/conoData_C5.csv'
-    seq = pd.read_csv(cono_path)['Sequences']
-
-    splited_seq = list(seq.apply(split_seq, args=(vocab_mlm, True)))
-    splited_msa = list(seq.apply(split_seq, args=(vocab_mlm, False)))
-
-    vocab_mlm.set_get_attn(is_get=True)
-    padded_seq = vocab_mlm.truncate_pad(splited_seq, num_steps=54, padding_token='[PAD]')
-    attn_idx = vocab_mlm.get_attention_mask_mat()
-
-    vocab_mlm.set_get_attn(is_get=False)
-    padded_msa = vocab_mlm.truncate_pad(splited_msa, num_steps=54, padding_token='[PAD]')
-
-    idx_seq = vocab_mlm[padded_seq]  # [b, 54]: start, cls_label, act_label, sequence, end
-    idx_msa = vocab_mlm[padded_msa]  # [b, 3, 50]
-
-    return padded_seq, idx_seq, idx_msa, attn_idx
-
-# NOTE: shadowed by the three-argument definition below; Python binds the name
-# to whichever definition runs last, so this version is never reachable.
-def get_paded_token_idx_gen(vocab_mlm, seq):
-    splited_seq = split_seq(seq[0], vocab_mlm, True)
-    splited_msa = split_seq(seq[0], vocab_mlm, False)
-
-    vocab_mlm.set_get_attn(is_get=True)
-    padded_seq = vocab_mlm.truncate_pad(splited_seq, num_steps=54, padding_token='[PAD]')
-    attn_idx = vocab_mlm.get_attention_mask_mat()
-
-    vocab_mlm.set_get_attn(is_get=False)
-    padded_msa = vocab_mlm.truncate_pad(splited_msa, num_steps=54, padding_token='[PAD]')
-
-    idx_seq = vocab_mlm[padded_seq]  # [b, 54]: start, cls_label, act_label, sequence, end
-    idx_msa = vocab_mlm[padded_msa]  # [b, 3, 50]
-
-    return padded_seq, idx_seq, idx_msa, attn_idx
-
-def get_paded_token_idx_gen(vocab_mlm, seq, new_seq):
-    if new_seq is None:
-        splited_seq = split_seq(seq[0], vocab_mlm, True)
-        splited_msa = split_seq(seq[0], vocab_mlm, False)
-
-        vocab_mlm.set_get_attn(is_get=True)
-        padded_seq = vocab_mlm.truncate_pad(splited_seq, num_steps=54, padding_token='[PAD]')
-        attn_idx = vocab_mlm.get_attention_mask_mat()
-        vocab_mlm.set_get_attn(is_get=False)
-
-        padded_msa = vocab_mlm.truncate_pad(splited_msa, num_steps=54, padding_token='[PAD]')
-
-        idx_seq = vocab_mlm[padded_seq]  # [b, 54]: start, cls_label, act_label, sequence, end
-        idx_msa = vocab_mlm[padded_msa]  # [b, 3, 50]
-    else:
-        splited_seq = split_seq(seq[0], vocab_mlm, True)
-        splited_msa = split_seq(seq[0], vocab_mlm, False)
-        vocab_mlm.set_get_attn(is_get=True)
-        padded_seq = vocab_mlm.truncate_pad(splited_seq, num_steps=54, padding_token='[PAD]')
-        attn_idx = vocab_mlm.get_attention_mask_mat()
-        vocab_mlm.set_get_attn(is_get=False)
-        padded_msa = vocab_mlm.truncate_pad(splited_msa, num_steps=54, padding_token='[PAD]')
-        idx_msa = vocab_mlm[padded_msa]  # [b, 3, 50]
-
-        # A partially generated sequence was supplied: index it directly.
-        idx_seq = vocab_mlm[new_seq]
-    return padded_seq, idx_seq, idx_msa, attn_idx
-
-def make_mask(seq_ser, start, end, time, vocab_mlm, labels, idx_msa, attn_idx):
-    seq_ser = pd.Series(seq_ser)
-    masked_seq = seq_ser.apply(mask, args=(start, end, time))
-    masked_idx = torch.tensor(vocab_mlm[list(masked_seq)])
-    device = torch.device('cuda:1')  # hard-coded GPU; adjust to the available device
-    data_arrays = (masked_idx.to(device), labels.to(device), idx_msa.to(device), attn_idx.to(device))
-    dataset = TensorDataset(*data_arrays)
-    train_dataset, test_dataset = train_test_split(dataset, test_size=0.1, random_state=42, shuffle=True)
-    train_loader = DataLoader(train_dataset, batch_size=128, shuffle=True)
-    test_loader = DataLoader(test_dataset, batch_size=128, shuffle=True)
-
-    return train_loader, test_loader
-
-if __name__ == '__main__':
-    import numpy as np
-
-    vocab_mlm = create_vocab()
-    vocab_mlm = add_tokens_to_vocab(vocab_mlm)
-    padded_seq, idx_seq, idx_msa, attn_idx = get_paded_token_idx(vocab_mlm)
-    labels = torch.tensor(idx_seq)
-    idx_msa = torch.tensor(idx_msa)
-    attn_idx = torch.tensor(attn_idx)
-
-    for t in np.arange(1, 50):
-        padded_seq_copy = deepcopy(padded_seq)
-        train_loader, test_loader = make_mask(padded_seq_copy, start=0, end=49, time=t,
-                                              vocab_mlm=vocab_mlm, labels=labels, idx_msa=idx_msa, attn_idx=attn_idx)
-        for i, (masked_idx, label, msa, attn) in enumerate(train_loader):
-            print(f"batch {i}: masked_idx {masked_idx.shape}, labels {label.shape}, idx_msa {msa.shape}")
-        print(f"time step {t} done")
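For reference, a minimal, self-contained sketch of the record layout the deleted split_seq expected. The record string below is invented ('<α7>' and '<high>' are real entries from the token list, but the peptide and MSA strings are made up), and list(pep_seq) merely stands in for vocab.split_seq, whose actual tokenization lived in vocab.py:

# Hypothetical record in the "cls_label|act_label|sequence|msa1|msa2|msa3" layout.
record = "<α7>|<high>|GCCSDPRCAWRC|GCCSD--CAWRC|GCC-PRCAW-RC|G-CSDPRCAWRC"
fields = record.split('|')
cls_label, act_label, pep_seq = fields[0], fields[1], fields[2]

# get_seq=True path: both labels sit between [CLS] and the residue tokens.
seq_tokens = ['[CLS]', cls_label, act_label] + list(pep_seq) + ['[SEP]']

# get_seq=False path: the label slots are padded out for each of the three MSA rows.
msa_tokens = [['[CLS]', '[PAD]', '[PAD]'] + list(m) + ['[SEP]'] for m in fields[3:6]]

print(seq_tokens)     # ['[CLS]', '<α7>', '<high>', 'G', 'C', 'C', ...]
print(msa_tokens[0])  # ['[CLS]', '[PAD]', '[PAD]', 'G', 'C', 'C', ...]

And a sketch of the dataset plumbing that make_mask relied on, with random placeholder tensors standing in for real token indices; the 90/10 split, batch size, and tensor shapes mirror the deleted code, everything else is illustrative:

import torch
from torch.utils.data import TensorDataset, DataLoader
from sklearn.model_selection import train_test_split

n, seq_len = 256, 54
masked_idx = torch.randint(0, 100, (n, seq_len))       # stand-in for masked token ids
labels     = torch.randint(0, 100, (n, seq_len))       # stand-in for unmasked targets
idx_msa    = torch.randint(0, 100, (n, 3, seq_len))    # stand-in for MSA token ids
attn_idx   = torch.ones(n, seq_len, dtype=torch.long)  # stand-in attention mask

# TensorDataset is indexable, so sklearn's train_test_split can shuffle and
# split it directly into lists of (masked, label, msa, attn) tuples.
dataset = TensorDataset(masked_idx, labels, idx_msa, attn_idx)
train_set, test_set = train_test_split(dataset, test_size=0.1, random_state=42, shuffle=True)
train_loader = DataLoader(train_set, batch_size=128, shuffle=True)
test_loader = DataLoader(test_set, batch_size=128, shuffle=True)

for masked, label, msa, attn in train_loader:
    print(masked.shape, label.shape, msa.shape, attn.shape)
    break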