HemanM committed on
Commit 541d702 · verified · 1 Parent(s): b0ba5ba

Update inference.py

Files changed (1)
  1. inference.py +34 -27
inference.py CHANGED
@@ -10,34 +10,40 @@ import time
 # 🔐 Load OpenAI API Key securely
 openai.api_key = os.getenv("OPENAI_API_KEY")
 
-# 🔁 Track model changes
+# 📦 Constants
 MODEL_PATH = "evo_hellaswag.pt"
-last_mod_time = 0
-model = None
 tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
+model = None
+last_mod_time = 0
 
-# 📦 Load model with auto-reload if file is updated
+# 🔁 Reload model if changed on disk
 def load_model():
     global model, last_mod_time
-    current_mod_time = os.path.getmtime(MODEL_PATH)
-    if model is None or current_mod_time > last_mod_time:
-        model = EvoTransformerV22()
-        model.load_state_dict(torch.load(MODEL_PATH, map_location="cpu"))
-        model.eval()
-        last_mod_time = current_mod_time
-        print("🔁 Evo model reloaded.")
+    try:
+        current_mod_time = os.path.getmtime(MODEL_PATH)
+        if model is None or current_mod_time > last_mod_time:
+            model = EvoTransformerV22()
+            model.load_state_dict(torch.load(MODEL_PATH, map_location="cpu"))
+            model.eval()
+            last_mod_time = current_mod_time
+            print("🔁 Evo model reloaded.")
+    except Exception as e:
+        print(f"❌ Error loading Evo model: {e}")
+        model = None
     return model
 
-# 🧠 Evo decision logic with confidence scores
+# 🧠 Evo logic
 def get_evo_response(query, options, user_context=""):
     model = load_model()
+    if model is None:
+        return "Error", 0.0, "Model failed to load", ""
 
-    # Retrieve RAG context + optional user input
+    # Retrieve web search + optional user context
     context_texts = web_search(query) + ([user_context] if user_context else [])
     context_str = "\n".join(context_texts)
     input_pairs = [f"{query} [SEP] {opt} [CTX] {context_str}" for opt in options]
 
-    # Encode both options and compute scores
+    # Encode and score each option
     scores = []
     for pair in input_pairs:
         encoded = tokenizer(pair, return_tensors="pt", padding="max_length", truncation=True, max_length=128)
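The hunk above cuts off mid-loop: old lines 44-47 (new 50-53) hold the loop body that turns each encoded pair into a score, and they are not shown here. A minimal sketch of that step, assuming EvoTransformerV22's forward takes input_ids and attention_mask and returns one logit per sequence (the signature is an assumption, not something this diff shows):

import torch

def score_pair(model, encoded):
    # Sketch of the hidden loop body; assumed contract:
    # model(input_ids, attention_mask) -> tensor holding one logit.
    with torch.no_grad():
        logit = model(encoded["input_ids"], encoded["attention_mask"])
    return logit.item()  # appended to `scores` once per option

Note also the behavioural change in this hunk: load_model() now swallows load failures and returns None, and the new guard at the top of get_evo_response turns that into an explicit ("Error", 0.0, ...) tuple instead of a crash.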
@@ -48,13 +54,13 @@ def get_evo_response(query, options, user_context=""):
 
     best_idx = int(scores[1] > scores[0])
     return (
-        options[best_idx],  # 🔹 Selected answer
-        max(scores),  # 🔹 Confidence score
-        f"{options[0]}: {scores[0]:.3f} vs {options[1]}: {scores[1]:.3f}",  # 🔹 Reasoning trace
-        context_str  # 🔹 Context used
+        options[best_idx],  # ✅ Evo's answer
+        max(scores),  # ✅ Confidence
+        f"{options[0]}: {scores[0]:.3f} vs {options[1]}: {scores[1]:.3f}",  # ✅ Reasoning trace
+        context_str  # ✅ Context used
     )
 
-# 🤖 GPT-3.5 backup or comparison
+# 🔄 GPT backup response
 def get_gpt_response(query, user_context=""):
     try:
         context_block = f"\n\nContext:\n{user_context}" if user_context else ""
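Two details at this hunk are easy to miss. First, best_idx = int(scores[1] > scores[0]) hard-wires exactly two options, and the reasoning trace likewise indexes only options[0] and options[1]. Second, the lines elided before the next hunk (old 61-66, new 67-72) hold the actual OpenAI call; given the module-level openai.api_key assignment, the file is on the pre-1.0 openai SDK, so the hidden body plausibly resembles this sketch (the model name and message layout are assumptions):

import os
import openai

openai.api_key = os.getenv("OPENAI_API_KEY")

def get_gpt_response_sketch(query, user_context=""):
    # Hypothetical stand-in for the elided body of get_gpt_response.
    try:
        context_block = f"\n\nContext:\n{user_context}" if user_context else ""
        response = openai.ChatCompletion.create(
            model="gpt-3.5-turbo",  # assumed; the diff never names the model
            messages=[{"role": "user", "content": f"{query}{context_block}"}],
        )
        return response["choices"][0]["message"]["content"]
    except Exception as e:
        return f"⚠️ GPT error:\n\n{str(e)}"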
@@ -67,17 +73,10 @@ def get_gpt_response(query, user_context=""):
     except Exception as e:
         return f"⚠️ GPT error:\n\n{str(e)}"
 
-# ✅ Final callable interface
-def infer(query, options, user_context=""):
-    return get_evo_response(query, options, user_context)
-
-# 🧠 Unified chat-style interface for EvoRAG
+# 🎯 Used by app.py to display Evo live output
 def evo_chat_predict(history, query, options):
-    # Use the last few exchanges as context (up to 3 pairs)
     context = "\n".join(history[-6:]) if history else ""
-
    evo_ans, evo_score, evo_reason, evo_ctx = get_evo_response(query, options, context)
-
     return {
         "answer": evo_ans,
         "confidence": round(evo_score, 3),
@@ -85,3 +84,11 @@ def evo_chat_predict(history, query, options):
         "context_used": evo_ctx
     }
 
+# 📊 Returns current Evo architecture stats (for UI display)
+def get_model_config():
+    return {
+        "num_layers": 6,
+        "num_heads": 8,
+        "ffn_dim": 1024,
+        "memory_enabled": True
+    }
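For orientation, a usage sketch of the chat interface as app.py presumably drives it. The shape of history (a flat list of speaker-prefixed strings, the last six of which become context) follows from history[-6:]; the strings and options below are invented, and the dict key hidden between the last two hunks (old line 84 / new line 83) is left unguessed:

history = [
    "User: The pan has been on high heat for ten minutes.",
    "Evo: Then the handle is likely hot as well.",
]
result = evo_chat_predict(
    history,
    "Should I grab the pan bare-handed?",
    ["Yes, it is safe", "No, use a mitt"],
)
print(result["answer"], result["confidence"], sep=" | ")
print(result["context_used"])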
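One dependency sits outside this diff: web_search, imported earlier in inference.py, must return a list of strings for web_search(query) + [user_context] to concatenate. A purely hypothetical stand-in honouring that contract, useful for exercising the file without a retrieval backend:

def web_search(query: str) -> list[str]:
    # Hypothetical stub: the real helper performs retrieval; this one only
    # satisfies the list-of-strings contract used by get_evo_response.
    return [f"(no retrieval backend configured; query: {query})"]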