HemanM commited on
Commit
5b3d26d
·
verified ·
1 Parent(s): 5a9508e

Update inference.py

Browse files
Files changed (1) hide show
  1. inference.py +35 -39
inference.py CHANGED
@@ -1,51 +1,47 @@
1
- # inference.py
2
-
3
  import torch
4
- import torch.nn.functional as F
5
  from transformers import AutoTokenizer
6
- from evo_model import EvoTransformerV22
 
7
  import openai
8
 
9
- # Load tokenizer
10
- tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
11
-
12
- # Load EvoTransformer model
13
  model = EvoTransformerV22()
14
  model.load_state_dict(torch.load("evo_hellaswag.pt", map_location="cpu"))
15
  model.eval()
16
 
17
- # GPT-3.5 API
18
- openai.api_key = "sk-..." # Replace with your key
19
-
20
- def get_evo_response(question, option1, option2):
21
- pair1 = f"{question} {option1}"
22
- pair2 = f"{question} {option2}"
23
-
24
- def score(pair):
25
- encoded = tokenizer(pair, return_tensors="pt", padding=True, truncation=True, max_length=128)
26
- with torch.no_grad():
27
- logits = model(encoded["input_ids"])
28
- prob = torch.sigmoid(logits).item()
29
- return prob
30
-
31
- score1 = score(pair1)
32
- score2 = score(pair2)
33
 
34
- better = option1 if score1 > score2 else option2
35
- confidence = max(score1, score2)
36
 
37
- return better, confidence, score1, score2
 
 
 
38
 
39
- def get_gpt_response(question, option1, option2):
40
- prompt = (
41
- f"Question: {question}\n"
42
- f"Option 1: {option1}\n"
43
- f"Option 2: {option2}\n"
44
- f"Which option makes more sense and why?"
45
- )
46
- response = openai.ChatCompletion.create(
47
- model="gpt-3.5-turbo",
48
- messages=[{"role": "user", "content": prompt}],
49
- temperature=0.7
 
 
 
50
  )
51
- return response.choices[0].message.content.strip()
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import torch
 
2
  from transformers import AutoTokenizer
3
+ from model import EvoTransformerV22
4
+ from search_utils import web_search
5
  import openai
6
 
7
+ # Load Evo model and tokenizer
 
 
 
8
  model = EvoTransformerV22()
9
  model.load_state_dict(torch.load("evo_hellaswag.pt", map_location="cpu"))
10
  model.eval()
11
 
12
+ tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
13
 
14
+ # GPT Setup
15
+ openai.api_key = "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" # 🔑 Set your actual key securely
16
 
17
+ def get_evo_response(query, options, user_context=""):
18
+ context_texts = web_search(query) + ([user_context] if user_context else [])
19
+ context_str = "\n".join(context_texts)
20
+ input_pairs = [f"{query} [SEP] {opt} [CTX] {context_str}" for opt in options]
21
 
22
+ scores = []
23
+ for pair in input_pairs:
24
+ encoded = tokenizer(pair, return_tensors="pt", truncation=True, padding="max_length", max_length=128)
25
+ with torch.no_grad():
26
+ output = model(encoded["input_ids"])
27
+ score = torch.sigmoid(output).item()
28
+ scores.append(score)
29
+
30
+ best_idx = int(scores[1] > scores[0])
31
+ return (
32
+ options[best_idx],
33
+ f"{options[0]}: {scores[0]:.3f} vs {options[1]}: {scores[1]:.3f}",
34
+ max(scores),
35
+ context_str
36
  )
37
+
38
+ def get_gpt_response(query, user_context=""):
39
+ try:
40
+ context_block = f"\n\nContext:\n{user_context}" if user_context else ""
41
+ completion = openai.ChatCompletion.create(
42
+ model="gpt-3.5-turbo",
43
+ messages=[{"role": "user", "content": query + context_block}]
44
+ )
45
+ return completion.choices[0].message.content.strip()
46
+ except Exception as e:
47
+ return f"⚠️ GPT error: {str(e)}"