MostoHF commited on
Commit
2794a6d
·
verified ·
1 Parent(s): 2f5d2f7

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +1 -1
app.py CHANGED
@@ -55,7 +55,7 @@ def inference(title, abstract, threshold=0.95):
55
  attention_mask = encoding["attention_mask"].to(device)
56
 
57
  with torch.no_grad():
58
- res_probs = torch.exp(class_model(input_ids, attention_mask)) # shape: (1, 8)
59
 
60
  probs = res_probs.squeeze(0) # (8,)
61
  sorted_probs, sorted_indices = torch.sort(probs, descending=True)
 
55
  attention_mask = encoding["attention_mask"].to(device)
56
 
57
  with torch.no_grad():
58
+ res_probs = class_model(input_ids, attention_mask) # shape: (1, 8)
59
 
60
  probs = res_probs.squeeze(0) # (8,)
61
  sorted_probs, sorted_indices = torch.sort(probs, descending=True)