HemanM committed on
Commit
a457e2e
·
verified ·
1 Parent(s): 06c8ded

Update inference.py

Browse files
Files changed (1) hide show
  1. inference.py +35 -15
inference.py CHANGED
@@ -1,25 +1,45 @@
1
- from transformers import AutoTokenizer
 
2
  from evo_model import EvoTransformerForClassification
3
- from init_save import initialize_and_save_model # Ensure this line is added
4
  import torch
5
 
6
- # Ensure model is initialized and saved BEFORE loading
7
- initialize_and_save_model("trained_model")
8
-
9
- # 🔁 Load tokenizer and model
10
- tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
11
  model = EvoTransformerForClassification.from_pretrained("trained_model")
12
  model.eval()
13
 
14
- def generate_response(goal, sol1, sol2):
15
- prompt = f"Goal: {goal}\nOption 1: {sol1}\nOption 2: {sol2}\nWhich is better?"
16
- inputs = tokenizer(prompt, return_tensors="pt", padding=True, truncation=True)
 
 
 
17
 
18
- if 'token_type_ids' in inputs:
19
- del inputs['token_type_ids'] # Evo doesn't use this
 
 
 
 
 
 
 
 
 
20
 
 
 
 
 
 
21
  with torch.no_grad():
22
  logits = model(**inputs)
23
-
24
- predicted = torch.argmax(logits, dim=1).item()
25
- return f"Option {predicted + 1} seems more reasonable based on EvoTransformer."
 
 
 
 
 
 
 
 
# All imports consolidated at the top of the file (PEP 8): stdlib first,
# then third-party, then local. The original placed `import os` mid-file.
import os

import openai
import torch
from transformers import BertTokenizer

from evo_model import EvoTransformerForClassification

# Load the fine-tuned Evo classifier from the local "trained_model" directory.
model = EvoTransformerForClassification.from_pretrained("trained_model")
model.eval()  # inference only: disable dropout/batch-norm updates

# Tokenizer (BERT-compatible): Evo consumes BERT-style token ids, so the
# stock uncased BERT vocabulary is reused here.
tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")

# OpenAI credentials come from the environment; this is None if the
# OPENAI_API_KEY variable is unset (GPT calls will then fail gracefully
# inside query_gpt35).
openai.api_key = os.getenv("OPENAI_API_KEY")
def query_gpt35(prompt):
    """Ask GPT-3.5-turbo for a short answer to *prompt*.

    Returns the stripped reply text on success. On any failure (missing
    API key, network error, SDK error) returns a string of the form
    "[GPT-3.5 Error] <details>" instead of raising.

    NOTE(review): this uses the legacy ``openai.ChatCompletion`` interface,
    which was removed in openai>=1.0 — confirm the pinned SDK version.
    """
    request = dict(
        model="gpt-3.5-turbo",
        messages=[{"role": "user", "content": prompt}],
        max_tokens=50,
        temperature=0.3,
    )
    try:
        reply = openai.ChatCompletion.create(**request)
        return reply['choices'][0]['message']['content'].strip()
    except Exception as exc:  # best-effort: surface the error as text
        return f"[GPT-3.5 Error] {exc}"
def generate_response(goal, option1, option2):
    """Compare two candidate solutions for *goal* with Evo and GPT-3.5.

    Returns a dict with keys:
      - "evo_suggestion": the option text the Evo classifier preferred
      - "gpt_suggestion": GPT-3.5's free-text answer (or an error string)
    """
    # Score both candidates in one batched forward pass; each row is the
    # goal concatenated with one option.
    encoded = tokenizer(
        [goal + " " + option1, goal + " " + option2],
        return_tensors="pt",
        padding=True,
        truncation=True,
    )
    # NOTE(review): BertTokenizer also emits token_type_ids; the previous
    # revision deleted that key before calling the model — confirm Evo
    # accepts it, otherwise model(**encoded) may raise.
    with torch.no_grad():
        logits = model(**encoded)
    # NOTE(review): with a batch of 2, argmax(dim=1) yields a length-2
    # tensor and .item() would raise for (2, C) logits — verify the output
    # shape of EvoTransformerForClassification.
    winner = torch.argmax(logits, dim=1).item()
    evo_pick = option1 if winner == 0 else option2

    # GPT-3.5's opinion on the same comparison, phrased as a question.
    question = (
        f"Goal: {goal}\nOption 1: {option1}\nOption 2: {option2}"
        "\nWhich is better?"
    )
    return {
        "evo_suggestion": evo_pick,
        "gpt_suggestion": query_gpt35(question),
    }