Spaces:
Runtime error
Runtime error
Commit
·
539470a
1
Parent(s):
e4296b4
Display probabilities
Browse files
app.py
CHANGED
@@ -44,13 +44,16 @@ def predict(model_abstract, model_claims, tokenizer_abstract, tokenizer_claims,
|
|
44 |
attention_mask_claims = encoding_claims['attention_mask'].to(device)
|
45 |
|
46 |
with torch.no_grad():
|
47 |
-
outputs_abstract = model_abstract(input_ids=input_abstract
|
48 |
-
outputs_claims = model_claims(input_ids=input_claims
|
|
|
|
|
|
|
49 |
|
50 |
combined_prob = (outputs_abstract.logits.softmax(dim=1) + outputs_claims.logits.softmax(dim=1)) / 2
|
51 |
-
label = torch.argmax(combined_prob,
|
52 |
|
53 |
-
return label, combined_prob
|
54 |
|
55 |
|
56 |
if __name__ == '__main__':
|
|
|
44 |
attention_mask_claims = encoding_claims['attention_mask'].to(device)
|
45 |
|
46 |
with torch.no_grad():
|
47 |
+
outputs_abstract = model_abstract(input_ids=input_abstract)
|
48 |
+
outputs_claims = model_claims(input_ids=input_claims)
|
49 |
+
|
50 |
+
print(outputs_abstract.logits)
|
51 |
+
print(outputs_claims.logits)
|
52 |
|
53 |
combined_prob = (outputs_abstract.logits.softmax(dim=1) + outputs_claims.logits.softmax(dim=1)) / 2
|
54 |
+
label = torch.argmax(combined_prob, dim=1)
|
55 |
|
56 |
+
return label, combined_prob.tolist()[0]
|
57 |
|
58 |
|
59 |
if __name__ == '__main__':
|