Update app.py
app.py
CHANGED
@@ -4,16 +4,22 @@ import gradio as gr
 from gradio_rangeslider import RangeSlider
 import pandas as pd
 from utils import create_vocab, setup_seed
-from dataset_mlm import get_paded_token_idx_gen,add_tokens_to_vocab
-import time
-seed = random.randint(0,100000)
+from dataset_mlm import get_paded_token_idx_gen, add_tokens_to_vocab
+import time
 
+# Global flag used to control stopping
+is_stopped = False
+
+# Set the random seed
+seed = random.randint(0, 100000)
 setup_seed(seed)
+
+# Load the model and data
 device = torch.device("cpu")
 vocab_mlm = create_vocab()
 vocab_mlm = add_tokens_to_vocab(vocab_mlm)
-save_path = 'mlm-model-27.pt'
-train_seqs = pd.read_csv('C0_seq.csv')
+save_path = 'mlm-model-27.pt'
+train_seqs = pd.read_csv('C0_seq.csv')
 train_seq = train_seqs['Seq'].tolist()
 model = torch.load(save_path, map_location=torch.device('cpu'))
 model = model.to(device)
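Note on the loading step above: torch.load on a fully pickled module requires the model's class definition to be importable under the same module path at load time. A minimal sketch of the state_dict alternative, under the assumption of a hypothetical ConoModel class (not part of this commit):

    import torch

    # Persist only the weights; restore them into a freshly constructed model.
    # ConoModel is a hypothetical stand-in for the actual model class.
    torch.save(model.state_dict(), 'mlm-model-27-state.pt')

    model = ConoModel()
    model.load_state_dict(torch.load('mlm-model-27-state.pt', map_location='cpu'))
    model = model.to(device)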
@@ -24,7 +30,14 @@ def temperature_sampling(logits, temperature):
     sampled_token = torch.multinomial(probabilities, 1)
     return sampled_token
 
+def stop_generation():
+    global is_stopped
+    is_stopped = True
+    return "Generation stopped."
+
 def CTXGen(τ, g_num, length_range, progress=gr.Progress()):
+    global is_stopped
+    is_stopped = False  # Reset the stop flag
     start, end = length_range
     X1 = "X"
     X2 = "X"
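For reference, the hunk above shows only the tail of temperature_sampling. A minimal body consistent with the two visible lines (an assumption; the full definition lies outside this diff) would be:

    import torch

    def temperature_sampling(logits, temperature):
        # Dividing the logits by the temperature flattens (τ > 1) or
        # sharpens (τ < 1) the distribution before normalization.
        probabilities = torch.softmax(logits / temperature, dim=-1)
        # Draw one token id from the resulting categorical distribution.
        sampled_token = torch.multinomial(probabilities, 1)
        return sampled_token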
@@ -43,29 +56,36 @@ def CTXGen(τ, g_num, length_range, progress=gr.Progress()):
 
     count = 0
     gen_num = int(g_num)
-    NON_AA = ["B","O","U","Z","X",'<K16>', '<α1β1γδ>', '<Ca22>', '<AChBP>', '<K13>', '<α1BAR>', '<α1β1ε>', '<α1AAR>', '<GluN3A>', '<α4β2>',
-              '<GluN2B>', '<α75HT3>', '<Na14>', '<α7>', '<GluN2C>', '<NET>', '<NavBh>', '<α6β3β4>', '<Na11>', '<Ca13>',
-              '<Ca12>', '<Na16>', '<α6α3β2>', '<GluN2A>', '<GluN2D>', '<K17>', '<α1β1δε>', '<GABA>', '<α9>', '<K12>',
-              '<Kshaker>', '<α3β4>', '<Na18>', '<α3β2>', '<α6α3β2β3>', '<α1β1δ>', '<α6α3β4β3>', '<α2β2>', '<α6β4>', '<α2β4>',
-              '<Na13>', '<Na12>', '<Na15>', '<α4β4>', '<α7α6β2>', '<α1β1γ>', '<NaTTXR>', '<K11>', '<Ca23>',
-              '<α9α10>', '<α6α3β4>', '<NaTTXS>', '<Na17>', '<high>', '<low>', '[UNK]', '[SEP]', '[PAD]', '[CLS]', '[MASK]']
+    NON_AA = ["B", "O", "U", "Z", "X", '<K16>', '<α1β1γδ>', '<Ca22>', '<AChBP>', '<K13>', '<α1BAR>', '<α1β1ε>', '<α1AAR>', '<GluN3A>', '<α4β2>',
+              '<GluN2B>', '<α75HT3>', '<Na14>', '<α7>', '<GluN2C>', '<NET>', '<NavBh>', '<α6β3β4>', '<Na11>', '<Ca13>',
+              '<Ca12>', '<Na16>', '<α6α3β2>', '<GluN2A>', '<GluN2D>', '<K17>', '<α1β1δε>', '<GABA>', '<α9>', '<K12>',
+              '<Kshaker>', '<α3β4>', '<Na18>', '<α3β2>', '<α6α3β2β3>', '<α1β1δ>', '<α6α3β4β3>', '<α2β2>', '<α6β4>', '<α2β4>',
+              '<Na13>', '<Na12>', '<Na15>', '<α4β4>', '<α7α6β2>', '<α1β1γ>', '<NaTTXR>', '<K11>', '<Ca23>',
+              '<α9α10>', '<α6α3β4>', '<NaTTXS>', '<Na17>', '<high>', '<low>', '[UNK]', '[SEP]', '[PAD]', '[CLS]', '[MASK]']
 
     start_time = time.time()
     while count < gen_num:
+        if is_stopped:  # Check for a stop request
+            return 'output.csv', f"Generation stopped. Generated {count} conotoxins."
+
         if time.time() - start_time > 1200:
             break
+
         gen_len = random.randint(int(start), int(end))
         X3 = "X" * gen_len
         seq = [f"{X1}|{X2}|{X3}|{X4}|{X5}|{X6}"]
         vocab_mlm.token_to_idx["X"] = 4
 
         padded_seq, _, _, _ = get_paded_token_idx_gen(vocab_mlm, seq, new_seq)
-        input_text = ["[MASK]" if i=="X" else i for i in padded_seq]
+        input_text = ["[MASK]" if i == "X" else i for i in padded_seq]
 
         gen_length = len(input_text)
         length = gen_length - sum(1 for x in input_text if x != '[MASK]')
 
         for i in range(length):
+            if is_stopped:  # Check for a stop request
+                return 'output.csv', f"Generation stopped. Generated {count} conotoxins."
+
             _, idx_seq, idx_msa, attn_idx = get_paded_token_idx_gen(vocab_mlm, seq, new_seq)
             idx_seq = torch.tensor(idx_seq).unsqueeze(0).to(device)
             idx_msa = torch.tensor(idx_msa).unsqueeze(0).to(device)
@@ -73,27 +93,26 @@ def CTXGen(τ, g_num, length_range, progress=gr.Progress()):
 
             mask_positions = [j for j in range(gen_length) if input_text[j] == "[MASK]"]
             mask_position = torch.tensor([mask_positions[torch.randint(len(mask_positions), (1,))]])
-
-            logits = model(idx_seq,idx_msa, attn_idx)
+
+            logits = model(idx_seq, idx_msa, attn_idx)
             mask_logits = logits[0, mask_position.item(), :]
 
             predicted_token_id = temperature_sampling(mask_logits, τ)
-
             predicted_token = vocab_mlm.to_tokens(int(predicted_token_id))
             input_text[mask_position.item()] = predicted_token
             padded_seq[mask_position.item()] = predicted_token.strip()
             new_seq = padded_seq
 
         generated_seq = input_text
-
+
         generated_seq[1] = "[MASK]"
         generated_seq[2] = "[MASK]"
         input_ids = vocab_mlm.__getitem__(generated_seq)
         logits = model(torch.tensor([input_ids]).to(device), idx_msa)
-
+
         cls_mask_logits = logits[0, 1, :]
         act_mask_logits = logits[0, 2, :]
-
+
         cls_probability, cls_mask_probs = torch.topk((torch.softmax(cls_mask_logits, dim=-1)), k=1)
         act_probability, act_mask_probs = torch.topk((torch.softmax(act_mask_logits, dim=-1)), k=1)
 
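In the scoring step above, torch.topk returns a (values, indices) pair, so cls_probability/act_probability hold the top-1 probabilities while cls_mask_probs/act_mask_probs hold the corresponding vocabulary ids. A short sketch of mapping those ids back to label tokens, reusing the vocabulary helper this file already calls:

    # The ids decode to label tokens from the NON_AA vocabulary,
    # e.g. a '<high>' or '<low>' activity marker.
    cls_token = vocab_mlm.to_tokens(int(cls_mask_probs))
    act_token = vocab_mlm.to_tokens(int(act_mask_probs))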
@@ -117,16 +136,24 @@ def CTXGen(τ, g_num, length_range, progress=gr.Progress()):
         progress(count / gen_num, desc="Generating conotoxins...")
     return 'output.csv', f"Generated {count} conotoxins."
 
+# Define the Gradio interface
 iface = gr.Interface(
     fn=CTXGen,
     inputs=[
         gr.Slider(minimum=1, maximum=2, step=0.1, label="τ"),
-        gr.Dropdown(choices=[1,10,100], label="Number of generations"),
+        gr.Dropdown(choices=[1, 10, 100], label="Number of generations"),
         RangeSlider(minimum=8, maximum=50, step=1, value=(12, 17), label="Length range")
     ],
     outputs=[
         gr.File(label="Download generated conotoxins"),
         gr.Textbox(label="Progress")
-    ]
+    ],
+    live=True
 )
+
+# Add a stop button
+stop_button = gr.Button("Stop Generation")
+stop_button.click(stop_generation, outputs=gr.Textbox(label="Status"))
+
+# Launch the Gradio app
 iface.launch()
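One caveat on the wiring above: a gr.Button created at module level, outside any gr.Blocks context, is not rendered as part of iface, so the stop event never reaches the page (recent Gradio versions reject event registration outside a Blocks context altogether). A minimal sketch of wiring the same stop flag inside a single layout; the component names here are assumptions, not part of the commit:

    with gr.Blocks() as demo:
        # Inputs mirror the gr.Interface definition above.
        tau = gr.Slider(minimum=1, maximum=2, step=0.1, label="τ")
        g_num = gr.Dropdown(choices=[1, 10, 100], label="Number of generations")
        length_range = RangeSlider(minimum=8, maximum=50, step=1,
                                   value=(12, 17), label="Length range")

        generate_button = gr.Button("Generate")
        stop_button = gr.Button("Stop Generation")

        out_file = gr.File(label="Download generated conotoxins")
        out_status = gr.Textbox(label="Progress")

        # Both events are registered inside the same Blocks, so both render.
        generate_button.click(CTXGen, inputs=[tau, g_num, length_range],
                              outputs=[out_file, out_status])
        stop_button.click(stop_generation, outputs=out_status)

    demo.launch()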