Spaces:
Runtime error
Runtime error
Commit
·
539470a
1
Parent(s):
e4296b4
Display probabilities
Browse files
app.py
CHANGED
|
@@ -44,13 +44,16 @@ def predict(model_abstract, model_claims, tokenizer_abstract, tokenizer_claims,
|
|
| 44 |
attention_mask_claims = encoding_claims['attention_mask'].to(device)
|
| 45 |
|
| 46 |
with torch.no_grad():
|
| 47 |
-
outputs_abstract = model_abstract(input_ids=input_abstract
|
| 48 |
-
outputs_claims = model_claims(input_ids=input_claims
|
|
|
|
|
|
|
|
|
|
| 49 |
|
| 50 |
combined_prob = (outputs_abstract.logits.softmax(dim=1) + outputs_claims.logits.softmax(dim=1)) / 2
|
| 51 |
-
label = torch.argmax(combined_prob,
|
| 52 |
|
| 53 |
-
return label, combined_prob
|
| 54 |
|
| 55 |
|
| 56 |
if __name__ == '__main__':
|
|
|
|
| 44 |
attention_mask_claims = encoding_claims['attention_mask'].to(device)
|
| 45 |
|
| 46 |
with torch.no_grad():
|
| 47 |
+
outputs_abstract = model_abstract(input_ids=input_abstract)
|
| 48 |
+
outputs_claims = model_claims(input_ids=input_claims)
|
| 49 |
+
|
| 50 |
+
print(outputs_abstract.logits)
|
| 51 |
+
print(outputs_claims.logits)
|
| 52 |
|
| 53 |
combined_prob = (outputs_abstract.logits.softmax(dim=1) + outputs_claims.logits.softmax(dim=1)) / 2
|
| 54 |
+
label = torch.argmax(combined_prob, dim=1)
|
| 55 |
|
| 56 |
+
return label, combined_prob.tolist()[0]
|
| 57 |
|
| 58 |
|
| 59 |
if __name__ == '__main__':
|