Update app.py
Browse files
app.py
CHANGED
@@ -12,15 +12,6 @@ is_stopped = False
|
|
12 |
seed = random.randint(0, 100000)
|
13 |
setup_seed(seed)
|
14 |
|
15 |
-
device = torch.device("cpu")
|
16 |
-
vocab_mlm = create_vocab()
|
17 |
-
vocab_mlm = add_tokens_to_vocab(vocab_mlm)
|
18 |
-
save_path = 'mlm-model-27.pt'
|
19 |
-
train_seqs = pd.read_csv('C0_seq.csv')
|
20 |
-
train_seq = train_seqs['Seq'].tolist()
|
21 |
-
model = torch.load(save_path, map_location=torch.device('cpu'))
|
22 |
-
model = model.to(device)
|
23 |
-
|
24 |
def temperature_sampling(logits, temperature):
|
25 |
logits = logits / temperature
|
26 |
probabilities = torch.softmax(logits, dim=-1)
|
@@ -32,9 +23,19 @@ def stop_generation():
|
|
32 |
is_stopped = True
|
33 |
return "Generation stopped."
|
34 |
|
35 |
-
def CTXGen(τ, g_num, length_range):
|
36 |
global is_stopped
|
37 |
is_stopped = False
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
38 |
start, end = length_range
|
39 |
X1 = "X"
|
40 |
X2 = "X"
|
@@ -43,7 +44,6 @@ def CTXGen(τ, g_num, length_range):
|
|
43 |
X6 = ""
|
44 |
model.eval()
|
45 |
with torch.no_grad():
|
46 |
-
new_seq = None
|
47 |
IDs = []
|
48 |
generated_seqs = []
|
49 |
generated_seqs_FINAL = []
|
@@ -63,6 +63,7 @@ def CTXGen(τ, g_num, length_range):
|
|
63 |
|
64 |
start_time = time.time()
|
65 |
while count < gen_num:
|
|
|
66 |
if is_stopped:
|
67 |
return pd.DataFrame(), "output.csv"
|
68 |
|
@@ -149,6 +150,7 @@ with gr.Blocks() as demo:
|
|
149 |
τ = gr.Slider(minimum=1, maximum=2, step=0.1, label="τ")
|
150 |
g_num = gr.Dropdown(choices=[1, 10, 20, 30, 40, 50], label="Number of generations")
|
151 |
length_range = RangeSlider(minimum=8, maximum=50, step=1, value=(12, 16), label="Length range")
|
|
|
152 |
with gr.Row():
|
153 |
start_button = gr.Button("Start Generation")
|
154 |
stop_button = gr.Button("Stop Generation")
|
@@ -157,7 +159,7 @@ with gr.Blocks() as demo:
|
|
157 |
with gr.Row():
|
158 |
output_file = gr.File(label="Download generated conotoxins")
|
159 |
|
160 |
-
start_button.click(CTXGen, inputs=[τ, g_num, length_range], outputs=[output_df, output_file])
|
161 |
stop_button.click(stop_generation, outputs=None)
|
162 |
|
163 |
demo.launch()
|
|
|
12 |
seed = random.randint(0, 100000)
|
13 |
setup_seed(seed)
|
14 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
15 |
def temperature_sampling(logits, temperature):
|
16 |
logits = logits / temperature
|
17 |
probabilities = torch.softmax(logits, dim=-1)
|
|
|
23 |
is_stopped = True
|
24 |
return "Generation stopped."
|
25 |
|
26 |
+
def CTXGen(τ, g_num, length_range, model_name):
|
27 |
global is_stopped
|
28 |
is_stopped = False
|
29 |
+
|
30 |
+
device = torch.device("cpu")
|
31 |
+
vocab_mlm = create_vocab()
|
32 |
+
vocab_mlm = add_tokens_to_vocab(vocab_mlm)
|
33 |
+
save_path = model_name
|
34 |
+
train_seqs = pd.read_csv('C0_seq.csv')
|
35 |
+
train_seq = train_seqs['Seq'].tolist()
|
36 |
+
model = torch.load(save_path, map_location=torch.device('cpu'))
|
37 |
+
model = model.to(device)
|
38 |
+
|
39 |
start, end = length_range
|
40 |
X1 = "X"
|
41 |
X2 = "X"
|
|
|
44 |
X6 = ""
|
45 |
model.eval()
|
46 |
with torch.no_grad():
|
|
|
47 |
IDs = []
|
48 |
generated_seqs = []
|
49 |
generated_seqs_FINAL = []
|
|
|
63 |
|
64 |
start_time = time.time()
|
65 |
while count < gen_num:
|
66 |
+
new_seq = None
|
67 |
if is_stopped:
|
68 |
return pd.DataFrame(), "output.csv"
|
69 |
|
|
|
150 |
τ = gr.Slider(minimum=1, maximum=2, step=0.1, label="τ")
|
151 |
g_num = gr.Dropdown(choices=[1, 10, 20, 30, 40, 50], label="Number of generations")
|
152 |
length_range = RangeSlider(minimum=8, maximum=50, step=1, value=(12, 16), label="Length range")
|
153 |
+
model_name = gr.Dropdown(choices=['model_final.pt','model_C1.pt','model_C2.pt','model_C3.pt','model_C4.pt','model_C5.pt','model_mlm.pt'], label="Model")
|
154 |
with gr.Row():
|
155 |
start_button = gr.Button("Start Generation")
|
156 |
stop_button = gr.Button("Stop Generation")
|
|
|
159 |
with gr.Row():
|
160 |
output_file = gr.File(label="Download generated conotoxins")
|
161 |
|
162 |
+
start_button.click(CTXGen, inputs=[τ, g_num, length_range, model_name], outputs=[output_df, output_file])
|
163 |
stop_button.click(stop_generation, outputs=None)
|
164 |
|
165 |
demo.launch()
|