MostoHF commited on
Commit
2472ad0
·
verified ·
1 Parent(s): 8b209d1

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +2 -2
app.py CHANGED
@@ -55,8 +55,8 @@ def inference(title, abstract, threshold=0.95):
55
  attention_mask = encoding["attention_mask"].to(device)
56
 
57
  with torch.no_grad():
58
- res_probs = class_model(input_ids, attention_mask) # shape: (1, 8)
59
-
60
  probs = res_probs.squeeze(0) # (8,)
61
  sorted_probs, sorted_indices = torch.sort(probs, descending=True)
62
 
 
55
  attention_mask = encoding["attention_mask"].to(device)
56
 
57
  with torch.no_grad():
58
+ res_probs = torch.exp(class_model(input_ids, attention_mask))
59
+
60
  probs = res_probs.squeeze(0) # (8,)
61
  sorted_probs, sorted_indices = torch.sort(probs, descending=True)
62