HemanM committed on
Commit f7f9a8a · verified · 1 Parent(s): 1622676

Update inference.py

Files changed (1)
  1. inference.py +37 -40
inference.py CHANGED
@@ -1,75 +1,72 @@
+ import os
  import torch
+ import torch.nn.functional as F
  from transformers import AutoTokenizer
- from evo_model import EvoTransformerV22
+ from evo_model import EvoTransformerV22
  from search_utils import web_search
  import openai
- import os
-
- # GPT Setup
- openai.api_key = os.getenv("OPENAI_API_KEY")  # 🔒 Load securely from environment
+ import time

- def load_model_and_tokenizer(model_path=None):
-     model = EvoTransformerV22()
-     tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
+ # 🔐 Load OpenAI API Key securely
+ openai.api_key = os.getenv("OPENAI_API_KEY")

-     # Smart load logic
-     if model_path and os.path.exists(model_path):
-         model.load_state_dict(torch.load(model_path, map_location="cpu"))
-         print(f"🔁 Loaded Evo model from {model_path}.")
-     elif os.path.exists("trained_model/evo_retrained.pt"):
-         model.load_state_dict(torch.load("trained_model/evo_retrained.pt", map_location="cpu"))
-         print("🔁 Loaded retrained Evo model.")
-     elif os.path.exists("trained_model/evo_pretrained.pt"):
-         model.load_state_dict(torch.load("trained_model/evo_pretrained.pt", map_location="cpu"))
-         print("📦 Loaded pretrained Evo model.")
-     elif os.path.exists("evo_hellaswag.pt"):
-         model.load_state_dict(torch.load("evo_hellaswag.pt", map_location="cpu"))
-         print("📥 Loaded default Evo model.")
-     else:
-         raise FileNotFoundError("❌ No Evo model file found.")
+ # 🔁 Track model changes
+ MODEL_PATH = "evo_hellaswag.pt"
+ last_mod_time = 0
+ model = None
+ tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")

-     model.eval()
-     return model, tokenizer
-
- # Default model + tokenizer loaded once at startup
- model, tokenizer = load_model_and_tokenizer()
+ # 📦 Load model with auto-reload if file is updated
+ def load_model():
+     global model, last_mod_time
+     current_mod_time = os.path.getmtime(MODEL_PATH)
+     if model is None or current_mod_time > last_mod_time:
+         model = EvoTransformerV22()
+         model.load_state_dict(torch.load(MODEL_PATH, map_location="cpu"))
+         model.eval()
+         last_mod_time = current_mod_time
+         print("🔁 Evo model reloaded.")
+     return model

+ # 🧠 Evo decision logic with confidence scores
  def get_evo_response(query, options, user_context=""):
+     model = load_model()
+
+     # Retrieve RAG context + optional user input
      context_texts = web_search(query) + ([user_context] if user_context else [])
      context_str = "\n".join(context_texts)
      input_pairs = [f"{query} [SEP] {opt} [CTX] {context_str}" for opt in options]

+     # Encode both options and compute scores
      scores = []
      for pair in input_pairs:
-         encoded = tokenizer(pair, return_tensors="pt", truncation=True, padding="max_length", max_length=128)
+         encoded = tokenizer(pair, return_tensors="pt", padding="max_length", truncation=True, max_length=128)
          with torch.no_grad():
-             output = model(encoded["input_ids"])
-             score = torch.sigmoid(output).item()
+             logits = model(encoded["input_ids"])
+             score = torch.sigmoid(logits).item()
          scores.append(score)

      best_idx = int(scores[1] > scores[0])
      return (
-         options[best_idx],  # evo_ans
-         max(scores),  # evo_score (float)
-         f"{options[0]}: {scores[0]:.3f} vs {options[1]}: {scores[1]:.3f}",  # evo_reason
-         context_str  # evo_ctx
+         options[best_idx],  # 🔹 Selected answer
+         max(scores),  # 🔹 Confidence score
+         f"{options[0]}: {scores[0]:.3f} vs {options[1]}: {scores[1]:.3f}",  # 🔹 Reasoning trace
+         context_str  # 🔹 Context used
      )

-
+ # 🤖 GPT-3.5 backup or comparison
  def get_gpt_response(query, user_context=""):
      try:
          context_block = f"\n\nContext:\n{user_context}" if user_context else ""
          response = openai.chat.completions.create(
              model="gpt-3.5-turbo",
-             messages=[
-                 {"role": "user", "content": query + context_block}
-             ],
+             messages=[{"role": "user", "content": query + context_block}],
              temperature=0.7,
          )
          return response.choices[0].message.content.strip()
      except Exception as e:
          return f"⚠️ GPT error:\n\n{str(e)}"

- # ✅ Compatibility exports
+ # ✅ Final callable interface
  def infer(query, options, user_context=""):
      return get_evo_response(query, options, user_context)
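
For reference, a minimal usage sketch of the updated module (not part of the commit): it assumes inference.py is importable, an Evo checkpoint exists at evo_hellaswag.pt, and OPENAI_API_KEY is set in the environment; the query and options below are made-up placeholders.

    # Hypothetical usage sketch (assumptions: evo_hellaswag.pt present,
    # OPENAI_API_KEY exported; query/options are invented examples).
    from inference import infer, get_gpt_response

    query = "A pot of oil catches fire on the stove. What should you do?"
    options = ["Pour water on it", "Cover it with a metal lid"]

    # infer() returns the 4-tuple from get_evo_response:
    # (selected answer, confidence score, reasoning trace, context used)
    answer, confidence, trace, context_used = infer(query, options)
    print(f"Evo: {answer} (confidence {confidence:.3f})")
    print(f"Trace: {trace}")

    # Optional GPT-3.5 cross-check using the same retrieved context
    print("GPT:", get_gpt_response(query, user_context=context_used))

Because load_model() compares os.path.getmtime(MODEL_PATH) against the time of the last load, overwriting evo_hellaswag.pt with a retrained checkpoint is picked up on the next infer() call without restarting the process.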