HemanM committed
Commit 2b652a8 · verified · 1 Parent(s): 0eff587

Update inference.py

Files changed (1):
  inference.py +2 -3
inference.py CHANGED
@@ -22,10 +22,9 @@ def predict(model, tokenizer, prompt, option1, option2, device):
     encoded = tokenizer(inputs, padding=True, truncation=True, return_tensors="pt").to(device)
 
     with torch.no_grad():
-        outputs = model(encoded["input_ids"])  # shape: [2, 512]
+        outputs = model(encoded["input_ids"])  # already includes classifier
 
-        # Classifier expects input shape [2, 512] → [2, 1], so we stack scores
-        logits = model.classifier(outputs).squeeze(-1)  # shape: [2]
+        logits = outputs.squeeze(-1)  # shape: [2]
     probs = torch.softmax(logits, dim=0)
     best = torch.argmax(probs).item()
 
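For context, here is a minimal sketch of how the updated predict function could look end to end. The diff only shows the changed hunk, so the construction of inputs (concatenating the prompt with each option) and the return value are assumptions, not part of the commit.

import torch

def predict(model, tokenizer, prompt, option1, option2, device):
    # Assumption: the two candidates are built by appending each option to the prompt;
    # the diff does not show how `inputs` is constructed.
    inputs = [prompt + " " + option1, prompt + " " + option2]

    encoded = tokenizer(inputs, padding=True, truncation=True, return_tensors="pt").to(device)

    with torch.no_grad():
        outputs = model(encoded["input_ids"])  # forward pass already includes the classifier
        logits = outputs.squeeze(-1)           # shape: [2], one score per candidate

    probs = torch.softmax(logits, dim=0)       # normalize the two scores against each other
    best = torch.argmax(probs).item()

    # Assumption: the caller wants the winning option and its probability.
    return (option1 if best == 0 else option2), probs[best].item()

Note that the softmax is taken over dim=0, i.e. across the two candidate sequences in the batch rather than across classes, which is why the classifier output can be squeezed to a flat [2] tensor before normalization.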