HemanM commited on
Commit
854864a
·
verified ·
1 Parent(s): 5feee28

Update inference.py

Browse files
Files changed (1) hide show
  1. inference.py +24 -23
inference.py CHANGED
@@ -1,33 +1,34 @@
 
1
  from transformers import AutoTokenizer
2
  from evo_model import EvoTransformerForClassification
3
- import torch
4
 
5
  # Load tokenizer and model
6
  tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
7
  model = EvoTransformerForClassification.from_pretrained("trained_model")
8
  model.eval()
9
 
10
- def generate_response(goal, sol1, sol2):
11
- input_text = f"Goal: {goal}\nOption A: {sol1}\nOption B: {sol2}"
12
-
13
- # Tokenize input
14
- inputs = tokenizer(
15
- input_text,
16
- return_tensors="pt",
17
- padding=True,
18
- truncation=True,
19
- )
20
-
21
- # ✅ Filter out unwanted keys (e.g., token_type_ids)
22
- filtered_inputs = {
23
- k: v for k, v in inputs.items()
24
- if k in ["input_ids", "attention_mask"]
25
- }
26
-
27
- # Predict
28
  with torch.no_grad():
29
- outputs = model(**filtered_inputs)
30
- logits = outputs.logits
31
- predicted_class = torch.argmax(logits, dim=1).item()
 
 
 
 
 
 
 
 
 
 
 
 
32
 
33
- return "A" if predicted_class == 0 else "B"
 
1
import torch
from transformers import AutoTokenizer
from evo_model import EvoTransformerForClassification

# --- Model setup -----------------------------------------------------------
# Tokenizer from the bert-base-uncased checkpoint; presumably the base the
# Evo classifier was fine-tuned from — confirm against training config.
tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")

# Fine-tuned classifier weights are loaded from the local "trained_model"
# directory; switch to eval mode since this module only runs inference.
model = EvoTransformerForClassification.from_pretrained("trained_model")
model.eval()
 
10
def generate_response(goal, option1, option2):
    """Suggest which of two options is better aligned with a goal.

    Each option is rendered as ``"Goal: <goal>\nOption: <option>"`` and
    scored independently by the classifier; the softmax probability of
    class index 1 serves as that option's alignment score.
    NOTE(review): assumes the classification head has at least 2 classes
    and that index 1 is the "aligned" class — confirm against training.

    Args:
        goal: Text describing the goal.
        option1: First candidate option.
        option2: Second candidate option.

    Returns:
        A human-readable suggestion string naming the stronger option,
        or a tie message when both scores are exactly equal.
    """

    def _alignment_score(option):
        # Tokenize and run one forward pass for a single option prompt.
        encoded = tokenizer(
            f"Goal: {goal}\nOption: {option}",
            return_tensors="pt",
            padding=True,
            truncation=True,
        )
        with torch.no_grad():
            result = model(**encoded)
        # Probability mass assigned to class 1 for this prompt.
        return torch.softmax(result["logits"], dim=1)[0][1].item()

    score1 = _alignment_score(option1)
    score2 = _alignment_score(option2)

    if score1 > score2:
        return "✅ Option 1 is more aligned with the goal."
    if score2 > score1:
        return "✅ Option 2 is more aligned with the goal."
    return "⚖️ Both options are equally likely."