oucgc1996 commited on
Commit
7e0b4b9
·
verified ·
1 Parent(s): 6c17270

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +1 -1
app.py CHANGED
@@ -109,7 +109,7 @@ def CTXGen(X1, X2, τ, g_num, length_range):
109
  input_ids = vocab_mlm.__getitem__(generated_seq)
110
  logits = model(torch.tensor([input_ids]).to(device), idx_msa)
111
  cls_mask_logits = logits[0, 1, :]
112
- cls_probability, cls_mask_probs = torch.topk((torch.softmax(cls_mask_logits, dim=-1)), k=5)
113
 
114
  generated_seq[2] = "[MASK]"
115
  input_ids = vocab_mlm.__getitem__(generated_seq)
 
109
  input_ids = vocab_mlm.__getitem__(generated_seq)
110
  logits = model(torch.tensor([input_ids]).to(device), idx_msa)
111
  cls_mask_logits = logits[0, 1, :]
112
+ cls_probability, cls_mask_probs = torch.topk((torch.softmax(cls_mask_logits, dim=-1)), k=10)
113
 
114
  generated_seq[2] = "[MASK]"
115
  input_ids = vocab_mlm.__getitem__(generated_seq)