Update app.py
Browse files
app.py
CHANGED
@@ -7,14 +7,11 @@ from utils import create_vocab, setup_seed
|
|
7 |
from dataset_mlm import get_paded_token_idx_gen, add_tokens_to_vocab
|
8 |
import time
|
9 |
|
10 |
-
# 全局标志,用于控制停止
|
11 |
is_stopped = False
|
12 |
|
13 |
-
# 设置随机种子
|
14 |
seed = random.randint(0, 100000)
|
15 |
setup_seed(seed)
|
16 |
|
17 |
-
# 加载模型和数据
|
18 |
device = torch.device("cpu")
|
19 |
vocab_mlm = create_vocab()
|
20 |
vocab_mlm = add_tokens_to_vocab(vocab_mlm)
|
@@ -37,7 +34,7 @@ def stop_generation():
|
|
37 |
|
38 |
def CTXGen(τ, g_num, length_range):
|
39 |
global is_stopped
|
40 |
-
is_stopped = False
|
41 |
start, end = length_range
|
42 |
X1 = "X"
|
43 |
X2 = "X"
|
@@ -65,7 +62,7 @@ def CTXGen(τ, g_num, length_range):
|
|
65 |
|
66 |
start_time = time.time()
|
67 |
while count < gen_num:
|
68 |
-
if is_stopped:
|
69 |
return pd.DataFrame(), "output.csv"
|
70 |
|
71 |
if time.time() - start_time > 1200:
|
@@ -83,7 +80,7 @@ def CTXGen(τ, g_num, length_range):
|
|
83 |
length = gen_length - sum(1 for x in input_text if x != '[MASK]')
|
84 |
|
85 |
for i in range(length):
|
86 |
-
if is_stopped:
|
87 |
return pd.DataFrame(), "output.csv"
|
88 |
|
89 |
_, idx_seq, idx_msa, attn_idx = get_paded_token_idx_gen(vocab_mlm, seq, new_seq)
|
@@ -143,13 +140,12 @@ def CTXGen(τ, g_num, length_range):
|
|
143 |
yield out, "output.csv"
|
144 |
return out, "output.csv"
|
145 |
|
146 |
-
# 使用 gr.Blocks 构建界面
|
147 |
with gr.Blocks() as demo:
|
148 |
gr.Markdown("# Conotoxin Generation")
|
149 |
with gr.Row():
|
150 |
τ = gr.Slider(minimum=1, maximum=2, step=0.1, label="τ")
|
151 |
-
g_num = gr.Dropdown(choices=[1, 10,
|
152 |
-
length_range = RangeSlider(minimum=8, maximum=50, step=1, value=(12,
|
153 |
with gr.Row():
|
154 |
start_button = gr.Button("Start Generation")
|
155 |
stop_button = gr.Button("Stop Generation")
|
@@ -158,9 +154,7 @@ with gr.Blocks() as demo:
|
|
158 |
with gr.Row():
|
159 |
output_file = gr.File(label="Download generated conotoxins")
|
160 |
|
161 |
-
# 绑定事件
|
162 |
start_button.click(CTXGen, inputs=[τ, g_num, length_range], outputs=[output_df, output_file])
|
163 |
stop_button.click(stop_generation, outputs=None)
|
164 |
|
165 |
-
# 启动 Gradio 应用
|
166 |
demo.launch()
|
|
|
7 |
from dataset_mlm import get_paded_token_idx_gen, add_tokens_to_vocab
|
8 |
import time
|
9 |
|
|
|
10 |
is_stopped = False
|
11 |
|
|
|
12 |
seed = random.randint(0, 100000)
|
13 |
setup_seed(seed)
|
14 |
|
|
|
15 |
device = torch.device("cpu")
|
16 |
vocab_mlm = create_vocab()
|
17 |
vocab_mlm = add_tokens_to_vocab(vocab_mlm)
|
|
|
34 |
|
35 |
def CTXGen(τ, g_num, length_range):
|
36 |
global is_stopped
|
37 |
+
is_stopped = False
|
38 |
start, end = length_range
|
39 |
X1 = "X"
|
40 |
X2 = "X"
|
|
|
62 |
|
63 |
start_time = time.time()
|
64 |
while count < gen_num:
|
65 |
+
if is_stopped:
|
66 |
return pd.DataFrame(), "output.csv"
|
67 |
|
68 |
if time.time() - start_time > 1200:
|
|
|
80 |
length = gen_length - sum(1 for x in input_text if x != '[MASK]')
|
81 |
|
82 |
for i in range(length):
|
83 |
+
if is_stopped:
|
84 |
return pd.DataFrame(), "output.csv"
|
85 |
|
86 |
_, idx_seq, idx_msa, attn_idx = get_paded_token_idx_gen(vocab_mlm, seq, new_seq)
|
|
|
140 |
yield out, "output.csv"
|
141 |
return out, "output.csv"
|
142 |
|
|
|
143 |
with gr.Blocks() as demo:
|
144 |
gr.Markdown("# Conotoxin Generation")
|
145 |
with gr.Row():
|
146 |
τ = gr.Slider(minimum=1, maximum=2, step=0.1, label="τ")
|
147 |
+
g_num = gr.Dropdown(choices=[1, 10, 20, 30, 40, 50], label="Number of generations")
|
148 |
+
length_range = RangeSlider(minimum=8, maximum=50, step=1, value=(12, 16), label="Length range")
|
149 |
with gr.Row():
|
150 |
start_button = gr.Button("Start Generation")
|
151 |
stop_button = gr.Button("Stop Generation")
|
|
|
154 |
with gr.Row():
|
155 |
output_file = gr.File(label="Download generated conotoxins")
|
156 |
|
|
|
157 |
start_button.click(CTXGen, inputs=[τ, g_num, length_range], outputs=[output_df, output_file])
|
158 |
stop_button.click(stop_generation, outputs=None)
|
159 |
|
|
|
160 |
demo.launch()
|