HemanM committed
Commit d838202 · verified · 1 Parent(s): 3489232

Update inference.py

Files changed (1)
  1. inference.py +15 -24
inference.py CHANGED
@@ -1,34 +1,25 @@
-import torch
 from transformers import AutoTokenizer
 from evo_model import EvoTransformerForClassification
+from init_save import initialize_and_save_model  # Ensure this line is added
+import torch
 
-# Load tokenizer and model
+# Ensure model is initialized and saved BEFORE loading
+initialize_and_save_model("trained_model")
+
+# 🔁 Load tokenizer and model
 tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
 model = EvoTransformerForClassification.from_pretrained("trained_model")
 model.eval()
 
-def generate_response(goal, option1, option2):
-    prompt1 = f"Goal: {goal}\nOption: {option1}"
-    prompt2 = f"Goal: {goal}\nOption: {option2}"
+def generate_response(goal, sol1, sol2):
+    prompt = f"Goal: {goal}\nOption 1: {sol1}\nOption 2: {sol2}\nWhich is better?"
+    inputs = tokenizer(prompt, return_tensors="pt", padding=True, truncation=True)
 
-    inputs1 = tokenizer(prompt1, return_tensors="pt", padding=True, truncation=True)
-    inputs2 = tokenizer(prompt2, return_tensors="pt", padding=True, truncation=True)
+    if 'token_type_ids' in inputs:
+        del inputs['token_type_ids']  # Evo doesn't use this
 
     with torch.no_grad():
-        output1 = model(**inputs1)
-        output2 = model(**inputs2)
-
-    logits1 = output1["logits"]
-    logits2 = output2["logits"]
-
-    prob1 = torch.softmax(logits1, dim=1)[0][1].item()
-    prob2 = torch.softmax(logits2, dim=1)[0][1].item()
-
-    if prob1 > prob2:
-        suggestion = "✅ Option 1 is more aligned with the goal."
-    elif prob2 > prob1:
-        suggestion = "✅ Option 2 is more aligned with the goal."
-    else:
-        suggestion = "⚖️ Both options are equally likely."
-
-    return suggestion
+        logits = model(**inputs)
+
+    predicted = torch.argmax(logits, dim=1).item()
+    return f"Option {predicted + 1} seems more reasonable based on EvoTransformer."
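For anyone trying out this commit, a minimal usage sketch of the updated generate_response follows. It assumes the script runs from this repo, so that evo_model.py, init_save.py, and the trained_model checkpoint are present; the goal and option strings below are invented for illustration, and the printed text is only the expected shape of the output, not a guaranteed result.

    from inference import generate_response

    # A PIQA-style query: one goal, two candidate solutions.
    goal = "Keep a houseplant alive"
    sol1 = "Water it regularly and place it near indirect sunlight."
    sol2 = "Keep it in a dark closet and never water it."

    print(generate_response(goal, sol1, sol2))
    # e.g. "Option 1 seems more reasonable based on EvoTransformer."

Note the design change this diff makes: instead of scoring each option separately and comparing softmax probabilities, the new code packs both options into one prompt and takes the argmax over the two-class logits, so a single forward pass decides between them.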