HemanM committed · Commit 7a7ebad · verified · 1 Parent(s): a568566

Update inference.py

Files changed (1)
  1. inference.py  +43 -46
inference.py CHANGED
@@ -1,54 +1,51 @@
 
 
  import torch
- import openai
- import os
  from transformers import AutoTokenizer
- from evo_model import EvoTransformerV22
- from rag_utils import extract_text_from_file
- from search_utils import web_search

  tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
  model = EvoTransformerV22()
  model.load_state_dict(torch.load("evo_hellaswag.pt", map_location="cpu"))
  model.eval()

- def format_input(question, options, context, web_results):
-     prompt = f"{question}\n"
-     if context:
-         prompt += f"\nContext:\n{context}\n"
-     if web_results:
-         prompt += f"\nWeb Search Results:\n" + "\n".join(web_results)
-     prompt += "\nOptions:\n"
-     for idx, opt in enumerate(options):
-         prompt += f"{idx+1}. {opt}\n"
-     return prompt.strip()
-
- def get_evo_response(question, context, options, enable_search=True):
-     web_results = web_search(question) if enable_search else []
-     input_text = format_input(question, options, context, web_results)
-     encoded = tokenizer(input_text, return_tensors="pt", padding=True, truncation=True, max_length=256)
-     with torch.no_grad():
-         logits = model(encoded["input_ids"])
-         probs = torch.softmax(logits, dim=1).squeeze()
-         pred_index = torch.argmax(probs).item()
-         confidence = probs[pred_index].item()
-
-     suggestion = options[pred_index] if pred_index < len(options) else "N/A"
-     evo_reasoning = f"Evo suggests: **{suggestion}** (Confidence: {confidence:.2f})\n\nContext used:\n" + "\n".join(web_results)
-     return suggestion, evo_reasoning
-
- def get_gpt_response(question, context, options):
-     openai.api_key = os.getenv("OPENAI_API_KEY", "")
-     formatted_options = "\n".join([f"{i+1}. {opt}" for i, opt in enumerate(options)])
-     prompt = f"Question: {question}\n\nContext:\n{context}\n\nOptions:\n{formatted_options}\n\nWhich option makes the most sense and why?"
-
-     try:
-         response = openai.ChatCompletion.create(
-             model="gpt-3.5-turbo",
-             messages=[
-                 {"role": "system", "content": "You are a helpful reasoning assistant."},
-                 {"role": "user", "content": prompt}
-             ]
-         )
-         return response['choices'][0]['message']['content']
-     except Exception as e:
-         return f"⚠️ GPT error: {str(e)}"
 
+ # inference.py
+
  import torch
+ import torch.nn.functional as F
  from transformers import AutoTokenizer
+ from model import EvoTransformerV22
+ import openai

+ # Load tokenizer
  tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
+
+ # Load EvoTransformer model
  model = EvoTransformerV22()
  model.load_state_dict(torch.load("evo_hellaswag.pt", map_location="cpu"))
  model.eval()

+ # GPT-3.5 API
+ openai.api_key = "sk-..."  # Replace with your key
+
+ def get_evo_response(question, option1, option2):
+     pair1 = f"{question} {option1}"
+     pair2 = f"{question} {option2}"
+
+     def score(pair):
+         encoded = tokenizer(pair, return_tensors="pt", padding=True, truncation=True, max_length=128)
+         with torch.no_grad():
+             logits = model(encoded["input_ids"])
+         prob = torch.sigmoid(logits).item()
+         return prob
+
+     score1 = score(pair1)
+     score2 = score(pair2)
+
+     better = option1 if score1 > score2 else option2
+     confidence = max(score1, score2)
+
+     return better, confidence, score1, score2
+
+ def get_gpt_response(question, option1, option2):
+     prompt = (
+         f"Question: {question}\n"
+         f"Option 1: {option1}\n"
+         f"Option 2: {option2}\n"
+         f"Which option makes more sense and why?"
+     )
+     response = openai.ChatCompletion.create(
+         model="gpt-3.5-turbo",
+         messages=[{"role": "user", "content": prompt}],
+         temperature=0.7
+     )
+     return response.choices[0].message.content.strip()
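For reference, a minimal usage sketch of the new two-option interface. The function names, module name, and return shape come from the diff above; the question and the two candidate endings are hypothetical inputs, and the call assumes model.py, evo_hellaswag.pt, and a configured OpenAI key are in place.

# Usage sketch: hypothetical inputs against the new inference.py API.
from inference import get_evo_response, get_gpt_response

question = "She plugged in the kettle."        # hypothetical example
option1 = "Soon the water started to boil."    # hypothetical example
option2 = "The kettle drove to the store."     # hypothetical example

better, confidence, score1, score2 = get_evo_response(question, option1, option2)
print(f"Evo picks: {better!r} (confidence {confidence:.2f}, scores {score1:.2f} vs {score2:.2f})")

# Needs openai.api_key set to a real key inside inference.py.
print(get_gpt_response(question, option1, option2))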
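One difference worth flagging: the removed get_gpt_response read the key via os.getenv("OPENAI_API_KEY", ""), while the new module hardcodes a placeholder. A minimal sketch restoring the environment-variable pattern from the previous revision:

import os
import openai

# Read the key from the environment, as the removed code did,
# rather than committing a literal "sk-..." string.
openai.api_key = os.getenv("OPENAI_API_KEY", "")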
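Note also that openai.ChatCompletion.create is the legacy pre-1.0 interface and no longer exists in openai>=1.0, so the diff as written pins the project to the older SDK. A sketch of the equivalent call under the 1.x client, assuming an upgrade (get_gpt_response_v1 is a hypothetical name; the prompt mirrors the diff):

from openai import OpenAI

client = OpenAI()  # reads OPENAI_API_KEY from the environment

def get_gpt_response_v1(question, option1, option2):
    # Same prompt as the diff's get_gpt_response, sent via the openai>=1.0 client.
    prompt = (
        f"Question: {question}\n"
        f"Option 1: {option1}\n"
        f"Option 2: {option2}\n"
        f"Which option makes more sense and why?"
    )
    response = client.chat.completions.create(
        model="gpt-3.5-turbo",
        messages=[{"role": "user", "content": prompt}],
        temperature=0.7,
    )
    return response.choices[0].message.content.strip()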