HemanM commited on
Commit
cdcb82a
·
verified ·
1 Parent(s): 7987693

Update inference.py

Browse files
Files changed (1) hide show
  1. inference.py +5 -7
inference.py CHANGED
@@ -22,15 +22,13 @@ def predict(model, tokenizer, prompt, option1, option2, device):
22
  encoded = tokenizer(inputs, padding=True, truncation=True, return_tensors="pt").to(device)
23
 
24
  with torch.no_grad():
25
- outputs = model(encoded["input_ids"])
26
-
27
- # Simple linear classifier logic
28
- logits = torch.nn.functional.linear(outputs, model.classifier.weight, model.classifier.bias)
29
- probs = torch.softmax(logits, dim=1)
30
  best = torch.argmax(probs).item()
31
 
32
  return {
33
  "choice": best,
34
- "confidence": probs[0][best].item(),
35
- "scores": probs[0].tolist(),
36
  }
 
22
  encoded = tokenizer(inputs, padding=True, truncation=True, return_tensors="pt").to(device)
23
 
24
  with torch.no_grad():
25
+ logits = model(encoded["input_ids"]) # shape: [2, 1]
26
+
27
+ probs = torch.softmax(logits.squeeze(), dim=0) # shape: [2]
 
 
28
  best = torch.argmax(probs).item()
29
 
30
  return {
31
  "choice": best,
32
+ "confidence": probs[best].item(),
33
+ "scores": probs.tolist(),
34
  }