Update app.py
Browse files
app.py
CHANGED
@@ -109,7 +109,7 @@ def CTXGen(X1, X2, τ, g_num, length_range):
|
|
109 |
input_ids = vocab_mlm.__getitem__(generated_seq)
|
110 |
logits = model(torch.tensor([input_ids]).to(device), idx_msa)
|
111 |
cls_mask_logits = logits[0, 1, :]
|
112 |
-
cls_probability, cls_mask_probs = torch.topk((torch.softmax(cls_mask_logits, dim=-1)), k=
|
113 |
|
114 |
generated_seq[2] = "[MASK]"
|
115 |
input_ids = vocab_mlm.__getitem__(generated_seq)
|
|
|
109 |
input_ids = vocab_mlm.__getitem__(generated_seq)
|
110 |
logits = model(torch.tensor([input_ids]).to(device), idx_msa)
|
111 |
cls_mask_logits = logits[0, 1, :]
|
112 |
+
cls_probability, cls_mask_probs = torch.topk((torch.softmax(cls_mask_logits, dim=-1)), k=10)
|
113 |
|
114 |
generated_seq[2] = "[MASK]"
|
115 |
input_ids = vocab_mlm.__getitem__(generated_seq)
|