Spaces:
Running
Running
a96123155
commited on
Commit
·
7a5615b
1
Parent(s):
82d4030
app
Browse files
app.py
CHANGED
@@ -410,9 +410,10 @@ def mutated_seq(wt_seq, wt_label):
|
|
410 |
print(f'Wild Type: Label = {wt_label}, Y_pred = {wt_pred.item()}, Y_prob = {wt_prob.item():.2%}')
|
411 |
|
412 |
# print(n_mut, mlm_tok_num, n_designs_ep, n_sampling_designs_ep, n_mlm_recovery_sampling, mutate2stronger)
|
413 |
-
pbar = tqdm(total=n_mut)
|
414 |
mutated_seqs = []
|
415 |
i = 1
|
|
|
416 |
while i <= n_mut:
|
417 |
if i == 1: seeds_ep = [wt_seq[1:]]
|
418 |
seeds_next_ep, seeds_probs_next_ep, seeds_logits_next_ep = [], [], []
|
@@ -475,9 +476,10 @@ def mutated_seq(wt_seq, wt_label):
|
|
475 |
|
476 |
seeds_ep = seeds_next_ep
|
477 |
i += 1
|
478 |
-
pbar.update(1)
|
479 |
-
|
480 |
-
|
|
|
481 |
mutated_seqs.extend([(wt_seq[1:], wt_logit.item(), wt_prob.item(), 0)])
|
482 |
mutated_seqs = sorted(mutated_seqs, key=lambda x: x[2], reverse=True)
|
483 |
mutated_seqs = pd.DataFrame(mutated_seqs, columns = ['mutated_seq', 'predicted_logit', 'predicted_probability', 'mutated_num']).drop_duplicates('mutated_seq')
|
@@ -503,7 +505,7 @@ def read_raw(raw_input):
|
|
503 |
return ids, sequences
|
504 |
|
505 |
def predict_raw(raw_input):
|
506 |
-
state_dict = torch.load('
|
507 |
new_state_dict = OrderedDict()
|
508 |
|
509 |
for k, v in state_dict.items():
|
|
|
410 |
print(f'Wild Type: Label = {wt_label}, Y_pred = {wt_pred.item()}, Y_prob = {wt_prob.item():.2%}')
|
411 |
|
412 |
# print(n_mut, mlm_tok_num, n_designs_ep, n_sampling_designs_ep, n_mlm_recovery_sampling, mutate2stronger)
|
413 |
+
# pbar = tqdm(total=n_mut)
|
414 |
mutated_seqs = []
|
415 |
i = 1
|
416 |
+
pbar = st.progress(i, text="mutated number of sequence")
|
417 |
while i <= n_mut:
|
418 |
if i == 1: seeds_ep = [wt_seq[1:]]
|
419 |
seeds_next_ep, seeds_probs_next_ep, seeds_logits_next_ep = [], [], []
|
|
|
476 |
|
477 |
seeds_ep = seeds_next_ep
|
478 |
i += 1
|
479 |
+
# pbar.update(1)
|
480 |
+
pbar.progress(i/n_mut, text="Mutating")
|
481 |
+
# pbar.close()
|
482 |
+
st.success('Done', icon="✅")
|
483 |
mutated_seqs.extend([(wt_seq[1:], wt_logit.item(), wt_prob.item(), 0)])
|
484 |
mutated_seqs = sorted(mutated_seqs, key=lambda x: x[2], reverse=True)
|
485 |
mutated_seqs = pd.DataFrame(mutated_seqs, columns = ['mutated_seq', 'predicted_logit', 'predicted_probability', 'mutated_num']).drop_duplicates('mutated_seq')
|
|
|
505 |
return ids, sequences
|
506 |
|
507 |
def predict_raw(raw_input):
|
508 |
+
state_dict = torch.load('model.pt', map_location=torch.device(device))
|
509 |
new_state_dict = OrderedDict()
|
510 |
|
511 |
for k, v in state_dict.items():
|