oucgc1996 commited on
Commit
39d3f04
Β·
verified Β·
1 Parent(s): 0d11814

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +10 -5
app.py CHANGED
@@ -9,9 +9,6 @@ import time
9
 
10
  is_stopped = False
11
 
12
- seed = random.randint(0, 100000)
13
- setup_seed(seed)
14
-
15
  def temperature_sampling(logits, temperature):
16
  logits = logits / temperature
17
  probabilities = torch.softmax(logits, dim=-1)
@@ -23,7 +20,13 @@ def stop_generation():
23
  is_stopped = True
24
  return "Generation stopped."
25
 
26
- def CTXGen(Ο„, g_num, length_range, model_name):
 
 
 
 
 
 
27
  global is_stopped
28
  is_stopped = False
29
 
@@ -157,11 +160,13 @@ with gr.Blocks() as demo:
157
  gr.Markdown("βœ…**Number of generations**: if it is not completed within 1200 seconds, it will automatically stop.")
158
  gr.Markdown("βœ…**Length range**: expected length range of conotoxins generated")
159
  gr.Markdown("βœ…**Model**: model parameters trained at different stages of data augmentation. Please refer to the paper for details.")
 
160
  with gr.Row():
161
  Ο„ = gr.Slider(minimum=1, maximum=2, step=0.1, label="Ο„")
162
  g_num = gr.Dropdown(choices=[1, 10, 20, 30, 40, 50], label="Number of generations")
163
  length_range = RangeSlider(minimum=8, maximum=50, step=1, value=(12, 16), label="Length range")
164
  model_name = gr.Dropdown(choices=['model_final.pt','model_C1.pt','model_C2.pt','model_C3.pt','model_C4.pt','model_C5.pt','model_mlm.pt'], label="Model")
 
165
  with gr.Row():
166
  start_button = gr.Button("Start Generation")
167
  stop_button = gr.Button("Stop Generation")
@@ -170,7 +175,7 @@ with gr.Blocks() as demo:
170
  with gr.Row():
171
  output_df = gr.DataFrame(label="Generated Conotoxins")
172
 
173
- start_button.click(CTXGen, inputs=[Ο„, g_num, length_range, model_name], outputs=[output_file, output_df])
174
  stop_button.click(stop_generation, outputs=None)
175
 
176
  demo.launch()
 
9
 
10
  is_stopped = False
11
 
 
 
 
12
  def temperature_sampling(logits, temperature):
13
  logits = logits / temperature
14
  probabilities = torch.softmax(logits, dim=-1)
 
20
  is_stopped = True
21
  return "Generation stopped."
22
 
23
+ def CTXGen(Ο„, g_num, length_range, model_name, seed):
24
+ if seed =='random':
25
+ seed = random.randint(0,100000)
26
+ setup_seed(seed)
27
+ else:
28
+ setup_seed(int(seed))
29
+
30
  global is_stopped
31
  is_stopped = False
32
 
 
160
  gr.Markdown("βœ…**Number of generations**: if it is not completed within 1200 seconds, it will automatically stop.")
161
  gr.Markdown("βœ…**Length range**: expected length range of conotoxins generated")
162
  gr.Markdown("βœ…**Model**: model parameters trained at different stages of data augmentation. Please refer to the paper for details.")
163
+ gr.Markdown("βœ…**Seed**: Enter an integer as the random seed to ensure reproducible results. The default is random")
164
  with gr.Row():
165
  Ο„ = gr.Slider(minimum=1, maximum=2, step=0.1, label="Ο„")
166
  g_num = gr.Dropdown(choices=[1, 10, 20, 30, 40, 50], label="Number of generations")
167
  length_range = RangeSlider(minimum=8, maximum=50, step=1, value=(12, 16), label="Length range")
168
  model_name = gr.Dropdown(choices=['model_final.pt','model_C1.pt','model_C2.pt','model_C3.pt','model_C4.pt','model_C5.pt','model_mlm.pt'], label="Model")
169
+ seed = gr.Textbox(label="Seed", value="random")
170
  with gr.Row():
171
  start_button = gr.Button("Start Generation")
172
  stop_button = gr.Button("Stop Generation")
 
175
  with gr.Row():
176
  output_df = gr.DataFrame(label="Generated Conotoxins")
177
 
178
+ start_button.click(CTXGen, inputs=[Ο„, g_num, length_range, model_name, seed], outputs=[output_file, output_df])
179
  stop_button.click(stop_generation, outputs=None)
180
 
181
  demo.launch()