oucgc1996 commited on
Commit
f3acc3b
·
verified ·
1 Parent(s): bbb5726

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +14 -12
app.py CHANGED
@@ -12,15 +12,6 @@ is_stopped = False
12
  seed = random.randint(0, 100000)
13
  setup_seed(seed)
14
 
15
- device = torch.device("cpu")
16
- vocab_mlm = create_vocab()
17
- vocab_mlm = add_tokens_to_vocab(vocab_mlm)
18
- save_path = 'mlm-model-27.pt'
19
- train_seqs = pd.read_csv('C0_seq.csv')
20
- train_seq = train_seqs['Seq'].tolist()
21
- model = torch.load(save_path, map_location=torch.device('cpu'))
22
- model = model.to(device)
23
-
24
  def temperature_sampling(logits, temperature):
25
  logits = logits / temperature
26
  probabilities = torch.softmax(logits, dim=-1)
@@ -32,9 +23,19 @@ def stop_generation():
32
  is_stopped = True
33
  return "Generation stopped."
34
 
35
- def CTXGen(τ, g_num, length_range):
36
  global is_stopped
37
  is_stopped = False
 
 
 
 
 
 
 
 
 
 
38
  start, end = length_range
39
  X1 = "X"
40
  X2 = "X"
@@ -43,7 +44,6 @@ def CTXGen(τ, g_num, length_range):
43
  X6 = ""
44
  model.eval()
45
  with torch.no_grad():
46
- new_seq = None
47
  IDs = []
48
  generated_seqs = []
49
  generated_seqs_FINAL = []
@@ -63,6 +63,7 @@ def CTXGen(τ, g_num, length_range):
63
 
64
  start_time = time.time()
65
  while count < gen_num:
 
66
  if is_stopped:
67
  return pd.DataFrame(), "output.csv"
68
 
@@ -149,6 +150,7 @@ with gr.Blocks() as demo:
149
  τ = gr.Slider(minimum=1, maximum=2, step=0.1, label="τ")
150
  g_num = gr.Dropdown(choices=[1, 10, 20, 30, 40, 50], label="Number of generations")
151
  length_range = RangeSlider(minimum=8, maximum=50, step=1, value=(12, 16), label="Length range")
 
152
  with gr.Row():
153
  start_button = gr.Button("Start Generation")
154
  stop_button = gr.Button("Stop Generation")
@@ -157,7 +159,7 @@ with gr.Blocks() as demo:
157
  with gr.Row():
158
  output_file = gr.File(label="Download generated conotoxins")
159
 
160
- start_button.click(CTXGen, inputs=[τ, g_num, length_range], outputs=[output_df, output_file])
161
  stop_button.click(stop_generation, outputs=None)
162
 
163
  demo.launch()
 
12
  seed = random.randint(0, 100000)
13
  setup_seed(seed)
14
 
 
 
 
 
 
 
 
 
 
15
  def temperature_sampling(logits, temperature):
16
  logits = logits / temperature
17
  probabilities = torch.softmax(logits, dim=-1)
 
23
  is_stopped = True
24
  return "Generation stopped."
25
 
26
+ def CTXGen(τ, g_num, length_range, model_name):
27
  global is_stopped
28
  is_stopped = False
29
+
30
+ device = torch.device("cpu")
31
+ vocab_mlm = create_vocab()
32
+ vocab_mlm = add_tokens_to_vocab(vocab_mlm)
33
+ save_path = model_name
34
+ train_seqs = pd.read_csv('C0_seq.csv')
35
+ train_seq = train_seqs['Seq'].tolist()
36
+ model = torch.load(save_path, map_location=torch.device('cpu'))
37
+ model = model.to(device)
38
+
39
  start, end = length_range
40
  X1 = "X"
41
  X2 = "X"
 
44
  X6 = ""
45
  model.eval()
46
  with torch.no_grad():
 
47
  IDs = []
48
  generated_seqs = []
49
  generated_seqs_FINAL = []
 
63
 
64
  start_time = time.time()
65
  while count < gen_num:
66
+ new_seq = None
67
  if is_stopped:
68
  return pd.DataFrame(), "output.csv"
69
 
 
150
  τ = gr.Slider(minimum=1, maximum=2, step=0.1, label="τ")
151
  g_num = gr.Dropdown(choices=[1, 10, 20, 30, 40, 50], label="Number of generations")
152
  length_range = RangeSlider(minimum=8, maximum=50, step=1, value=(12, 16), label="Length range")
153
+ model_name = gr.Dropdown(choices=['model_final.pt','model_C1.pt','model_C2.pt','model_C3.pt','model_C4.pt','model_C5.pt','model_mlm.pt'], label="Model")
154
  with gr.Row():
155
  start_button = gr.Button("Start Generation")
156
  stop_button = gr.Button("Stop Generation")
 
159
  with gr.Row():
160
  output_file = gr.File(label="Download generated conotoxins")
161
 
162
+ start_button.click(CTXGen, inputs=[τ, g_num, length_range, model_name], outputs=[output_df, output_file])
163
  stop_button.click(stop_generation, outputs=None)
164
 
165
  demo.launch()