HemanM committed on
Commit
b530936
·
verified ·
1 Parent(s): d2a18e8

Update inference.py

Browse files
Files changed (1) hide show
  1. inference.py +24 -9
inference.py CHANGED
@@ -1,18 +1,33 @@
1
- from evo_model import EvoTransformerForClassification
2
  from transformers import AutoTokenizer
 
3
  import torch
4
 
5
- # Load model and tokenizer once
6
- model = EvoTransformerForClassification.from_pretrained("trained_model")
7
  tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
 
8
  model.eval()
9
 
10
  def generate_response(goal, sol1, sol2):
11
- prompt = f"Goal: {goal} Option 1: {sol1} Option 2: {sol2}"
12
- inputs = tokenizer(prompt, return_tensors="pt", padding=True, truncation=True)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
13
  with torch.no_grad():
14
- outputs = model(**inputs)
15
- logits = outputs.logits if hasattr(outputs, "logits") else outputs[0]
16
- prediction = torch.argmax(logits, dim=1).item()
17
- return "Solution 1" if prediction == 0 else "Solution 2"
18
 
 
 
 
1
import torch
from transformers import AutoTokenizer
from evo_model import EvoTransformerForClassification

# Load the tokenizer and the fine-tuned Evo classifier once at import
# time so repeated calls to generate_response() reuse the same objects.
tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
model = EvoTransformerForClassification.from_pretrained("trained_model")
# Inference-only: switch off dropout / batch-norm training behavior.
model.eval()
10
def generate_response(goal, sol1, sol2):
    """Pick the better of two candidate solutions for a goal.

    Builds a single prompt from the goal and both options, runs the
    Evo classifier, and maps class 0 -> "A" (sol1), class 1 -> "B" (sol2).

    Args:
        goal: Text describing the task or goal.
        sol1: First candidate solution (reported as "A").
        sol2: Second candidate solution (reported as "B").

    Returns:
        str: "A" if the model prefers sol1, otherwise "B".
    """
    input_text = f"Goal: {goal}\nOption A: {sol1}\nOption B: {sol2}"

    # Tokenize input
    inputs = tokenizer(
        input_text,
        return_tensors="pt",
        padding=True,
        truncation=True,
    )

    # The Evo model only accepts input_ids / attention_mask; drop extras
    # such as token_type_ids that the BERT tokenizer also emits.
    filtered_inputs = {
        k: v for k, v in inputs.items()
        if k in ["input_ids", "attention_mask"]
    }

    # Predict (no gradient tracking needed for inference).
    with torch.no_grad():
        outputs = model(**filtered_inputs)
        # Robustness: the previous revision handled models that return a
        # plain tuple instead of a ModelOutput; restore that fallback so a
        # tuple-returning forward() does not crash on `.logits`.
        logits = outputs.logits if hasattr(outputs, "logits") else outputs[0]
        predicted_class = torch.argmax(logits, dim=1).item()

    return "A" if predicted_class == 0 else "B"