HemanM committed on
Commit
f87535f
·
verified ·
1 Parent(s): 9c277d8

Update inference.py

Browse files
Files changed (1) hide show
  1. inference.py +4 -2
inference.py CHANGED
@@ -22,9 +22,11 @@ def predict(model, tokenizer, prompt, option1, option2, device):
22
  encoded = tokenizer(inputs, padding=True, truncation=True, return_tensors="pt").to(device)
23
 
24
  with torch.no_grad():
25
- logits = model(encoded["input_ids"]) # shape: [2, 1]
26
 
27
- probs = torch.softmax(logits.squeeze(), dim=0) # shape: [2]
 
 
28
  best = torch.argmax(probs).item()
29
 
30
  return {
 
22
  encoded = tokenizer(inputs, padding=True, truncation=True, return_tensors="pt").to(device)
23
 
24
  with torch.no_grad():
25
+ outputs = model(encoded["input_ids"]) # shape: [2, 512]
26
 
27
+ # Classifier expects input shape [2, 512] → [2, 1], so we stack scores
28
+ logits = model.classifier(outputs).squeeze(-1) # shape: [2]
29
+ probs = torch.softmax(logits, dim=0)
30
  best = torch.argmax(probs).item()
31
 
32
  return {