a96123155 commited on
Commit
7a5615b
·
1 Parent(s): 82d4030
Files changed (1) hide show
  1. app.py +7 -5
app.py CHANGED
@@ -410,9 +410,10 @@ def mutated_seq(wt_seq, wt_label):
410
  print(f'Wild Type: Label = {wt_label}, Y_pred = {wt_pred.item()}, Y_prob = {wt_prob.item():.2%}')
411
 
412
  # print(n_mut, mlm_tok_num, n_designs_ep, n_sampling_designs_ep, n_mlm_recovery_sampling, mutate2stronger)
413
- pbar = tqdm(total=n_mut)
414
  mutated_seqs = []
415
  i = 1
 
416
  while i <= n_mut:
417
  if i == 1: seeds_ep = [wt_seq[1:]]
418
  seeds_next_ep, seeds_probs_next_ep, seeds_logits_next_ep = [], [], []
@@ -475,9 +476,10 @@ def mutated_seq(wt_seq, wt_label):
475
 
476
  seeds_ep = seeds_next_ep
477
  i += 1
478
- pbar.update(1)
479
- pbar.close()
480
-
 
481
  mutated_seqs.extend([(wt_seq[1:], wt_logit.item(), wt_prob.item(), 0)])
482
  mutated_seqs = sorted(mutated_seqs, key=lambda x: x[2], reverse=True)
483
  mutated_seqs = pd.DataFrame(mutated_seqs, columns = ['mutated_seq', 'predicted_logit', 'predicted_probability', 'mutated_num']).drop_duplicates('mutated_seq')
@@ -503,7 +505,7 @@ def read_raw(raw_input):
503
  return ids, sequences
504
 
505
  def predict_raw(raw_input):
506
- state_dict = torch.load('v2.7_LeidenContrastive_best_model_fold0.pt', map_location=torch.device(device))
507
  new_state_dict = OrderedDict()
508
 
509
  for k, v in state_dict.items():
 
410
  print(f'Wild Type: Label = {wt_label}, Y_pred = {wt_pred.item()}, Y_prob = {wt_prob.item():.2%}')
411
 
412
  # print(n_mut, mlm_tok_num, n_designs_ep, n_sampling_designs_ep, n_mlm_recovery_sampling, mutate2stronger)
413
+ # pbar = tqdm(total=n_mut)
414
  mutated_seqs = []
415
  i = 1
416
+ pbar = st.progress(i, text="mutated number of sequence")
417
  while i <= n_mut:
418
  if i == 1: seeds_ep = [wt_seq[1:]]
419
  seeds_next_ep, seeds_probs_next_ep, seeds_logits_next_ep = [], [], []
 
476
 
477
  seeds_ep = seeds_next_ep
478
  i += 1
479
+ # pbar.update(1)
480
+ pbar.progress(i/n_mut, text="Mutating")
481
+ # pbar.close()
482
+ st.success('Done', icon="✅")
483
  mutated_seqs.extend([(wt_seq[1:], wt_logit.item(), wt_prob.item(), 0)])
484
  mutated_seqs = sorted(mutated_seqs, key=lambda x: x[2], reverse=True)
485
  mutated_seqs = pd.DataFrame(mutated_seqs, columns = ['mutated_seq', 'predicted_logit', 'predicted_probability', 'mutated_num']).drop_duplicates('mutated_seq')
 
505
  return ids, sequences
506
 
507
  def predict_raw(raw_input):
508
+ state_dict = torch.load('model.pt', map_location=torch.device(device))
509
  new_state_dict = OrderedDict()
510
 
511
  for k, v in state_dict.items():