Update app.py
Browse files
app.py
CHANGED
@@ -55,7 +55,7 @@ def inference(title, abstract, threshold=0.95):
|
|
55 |
attention_mask = encoding["attention_mask"].to(device)
|
56 |
|
57 |
with torch.no_grad():
|
58 |
-
res_probs =
|
59 |
|
60 |
probs = res_probs.squeeze(0) # (8,)
|
61 |
sorted_probs, sorted_indices = torch.sort(probs, descending=True)
|
|
|
55 |
attention_mask = encoding["attention_mask"].to(device)
|
56 |
|
57 |
with torch.no_grad():
|
58 |
+
res_probs = class_model(input_ids, attention_mask) # shape: (1, 8)
|
59 |
|
60 |
probs = res_probs.squeeze(0) # (8,)
|
61 |
sorted_probs, sorted_indices = torch.sort(probs, descending=True)
|