oucgc1996 commited on
Commit
fdecd8f
·
verified ·
1 Parent(s): f138db5

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +5 -11
app.py CHANGED
@@ -7,14 +7,11 @@ from utils import create_vocab, setup_seed
7
  from dataset_mlm import get_paded_token_idx_gen, add_tokens_to_vocab
8
  import time
9
 
10
- # 全局标志,用于控制停止
11
  is_stopped = False
12
 
13
- # 设置随机种子
14
  seed = random.randint(0, 100000)
15
  setup_seed(seed)
16
 
17
- # 加载模型和数据
18
  device = torch.device("cpu")
19
  vocab_mlm = create_vocab()
20
  vocab_mlm = add_tokens_to_vocab(vocab_mlm)
@@ -37,7 +34,7 @@ def stop_generation():
37
 
38
  def CTXGen(τ, g_num, length_range):
39
  global is_stopped
40
- is_stopped = False # 重置停止标志
41
  start, end = length_range
42
  X1 = "X"
43
  X2 = "X"
@@ -65,7 +62,7 @@ def CTXGen(τ, g_num, length_range):
65
 
66
  start_time = time.time()
67
  while count < gen_num:
68
- if is_stopped: # 检查是否停止
69
  return pd.DataFrame(), "output.csv"
70
 
71
  if time.time() - start_time > 1200:
@@ -83,7 +80,7 @@ def CTXGen(τ, g_num, length_range):
83
  length = gen_length - sum(1 for x in input_text if x != '[MASK]')
84
 
85
  for i in range(length):
86
- if is_stopped: # 检查是否停止
87
  return pd.DataFrame(), "output.csv"
88
 
89
  _, idx_seq, idx_msa, attn_idx = get_paded_token_idx_gen(vocab_mlm, seq, new_seq)
@@ -143,13 +140,12 @@ def CTXGen(τ, g_num, length_range):
143
  yield out, "output.csv"
144
  return out, "output.csv"
145
 
146
- # 使用 gr.Blocks 构建界面
147
  with gr.Blocks() as demo:
148
  gr.Markdown("# Conotoxin Generation")
149
  with gr.Row():
150
  τ = gr.Slider(minimum=1, maximum=2, step=0.1, label="τ")
151
- g_num = gr.Dropdown(choices=[1, 10, 100], label="Number of generations")
152
- length_range = RangeSlider(minimum=8, maximum=50, step=1, value=(12, 17), label="Length range")
153
  with gr.Row():
154
  start_button = gr.Button("Start Generation")
155
  stop_button = gr.Button("Stop Generation")
@@ -158,9 +154,7 @@ with gr.Blocks() as demo:
158
  with gr.Row():
159
  output_file = gr.File(label="Download generated conotoxins")
160
 
161
- # 绑定事件
162
  start_button.click(CTXGen, inputs=[τ, g_num, length_range], outputs=[output_df, output_file])
163
  stop_button.click(stop_generation, outputs=None)
164
 
165
- # 启动 Gradio 应用
166
  demo.launch()
 
7
  from dataset_mlm import get_paded_token_idx_gen, add_tokens_to_vocab
8
  import time
9
 
 
10
  is_stopped = False
11
 
 
12
  seed = random.randint(0, 100000)
13
  setup_seed(seed)
14
 
 
15
  device = torch.device("cpu")
16
  vocab_mlm = create_vocab()
17
  vocab_mlm = add_tokens_to_vocab(vocab_mlm)
 
34
 
35
  def CTXGen(τ, g_num, length_range):
36
  global is_stopped
37
+ is_stopped = False
38
  start, end = length_range
39
  X1 = "X"
40
  X2 = "X"
 
62
 
63
  start_time = time.time()
64
  while count < gen_num:
65
+ if is_stopped:
66
  return pd.DataFrame(), "output.csv"
67
 
68
  if time.time() - start_time > 1200:
 
80
  length = gen_length - sum(1 for x in input_text if x != '[MASK]')
81
 
82
  for i in range(length):
83
+ if is_stopped:
84
  return pd.DataFrame(), "output.csv"
85
 
86
  _, idx_seq, idx_msa, attn_idx = get_paded_token_idx_gen(vocab_mlm, seq, new_seq)
 
140
  yield out, "output.csv"
141
  return out, "output.csv"
142
 
 
143
  with gr.Blocks() as demo:
144
  gr.Markdown("# Conotoxin Generation")
145
  with gr.Row():
146
  τ = gr.Slider(minimum=1, maximum=2, step=0.1, label="τ")
147
+ g_num = gr.Dropdown(choices=[1, 10, 20, 30, 40, 50], label="Number of generations")
148
+ length_range = RangeSlider(minimum=8, maximum=50, step=1, value=(12, 16), label="Length range")
149
  with gr.Row():
150
  start_button = gr.Button("Start Generation")
151
  stop_button = gr.Button("Stop Generation")
 
154
  with gr.Row():
155
  output_file = gr.File(label="Download generated conotoxins")
156
 
 
157
  start_button.click(CTXGen, inputs=[τ, g_num, length_range], outputs=[output_df, output_file])
158
  stop_button.click(stop_generation, outputs=None)
159
 
 
160
  demo.launch()