HemanM committed
Commit 819cc50 · verified · 1 Parent(s): 4f2bf95

Update inference.py

Files changed (1)
  1. inference.py +10 -17
inference.py CHANGED
@@ -9,10 +9,8 @@ import time
 import psutil
 import platform
 
-# 🔐 Load OpenAI API Key
 openai.api_key = os.getenv("OPENAI_API_KEY")
 
-# 📦 Constants
 MODEL_PATH = "evo_hellaswag.pt"
 tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
 model = None
@@ -40,7 +38,6 @@ def evo_infer(query, options, user_context=""):
     if model is None:
         return "Model Error", 0.0, "Model not available", ""
 
-    # ✅ Smart context logic: avoid search for numeric/logical queries
     def is_fact_or_math(q):
         q_lower = q.lower()
         return any(char.isdigit() for char in q_lower) or any(op in q_lower for op in ["+", "-", "*", "/", "=", "what is", "solve", "calculate"])
@@ -69,8 +66,7 @@ def evo_infer(query, options, user_context=""):
         context_str
     )
 
-
-# 💬 GPT fallback (used for comparison only)
+# 🤖 GPT fallback (for comparison)
 def get_gpt_response(query, user_context=""):
     try:
         context_block = f"\n\nContext:\n{user_context}" if user_context else ""
@@ -83,10 +79,9 @@ def get_gpt_response(query, user_context=""):
     except Exception as e:
         return f"⚠️ GPT error:\n{str(e)}"
 
-# 🤖 Evo live chat prediction
+# 🧠 Live Evo prediction logic
 def evo_chat_predict(history, query, options):
     try:
-        # Support list or DataFrame
         if isinstance(history, list):
             context = "\n".join(history[-6:])
         elif hasattr(history, "empty") and not history.empty:
@@ -104,7 +99,7 @@ def evo_chat_predict(history, query, options):
         "context_used": evo_ctx
     }
 
-# 📊 Evo architecture stats
+# 📊 Evo model config metadata
 def get_model_config():
     return {
         "num_layers": 6,
@@ -115,7 +110,7 @@ def get_model_config():
         "accuracy": "~64.5%"
     }
 
-# 🖥️ System runtime stats
+# 🖥️ Runtime stats
 def get_system_stats():
     gpu_info = torch.cuda.get_device_properties(0) if torch.cuda.is_available() else None
     memory = psutil.virtual_memory()
@@ -130,9 +125,9 @@ def get_system_stats():
         "platform": platform.platform()
     }
 
-# 🧪 Fine-tune Evo from feedback data (CSV or in-memory list)
-def retrain_from_feedback(feedback_data):
-    if not feedback_data:
+# 🔁 Retrain from in-memory feedback_log
+def retrain_from_feedback(feedback_log):
+    if not feedback_log:
         return "⚠️ No feedback data to retrain from."
 
     model = load_model()
@@ -142,11 +137,10 @@ def retrain_from_feedback(feedback_data):
     model.train()
     optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
 
-    for row in feedback_data:
+    for row in feedback_log:
         question, opt1, opt2, answer, *_ = row
-        label = torch.tensor([1.0 if answer.strip() == opt2.strip() else 0.0])  # opt2 is class 1
+        label = torch.tensor([1.0 if answer.strip() == opt2.strip() else 0.0])  # opt2 = class 1
 
-        # Build input pair
         input_text = f"{question} [SEP] {opt2 if label.item() == 1 else opt1}"
         encoded = tokenizer(input_text, return_tensors="pt", padding="max_length", truncation=True, max_length=128)
 
@@ -157,5 +151,4 @@ def retrain_from_feedback(feedback_data):
         optimizer.zero_grad()
 
     torch.save(model.state_dict(), MODEL_PATH)
-    return "✅ Evo retrained from feedback."
-
+    return "✅ Evo retrained and reloaded from memory."
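Between the tokenizer call and optimizer.zero_grad(), the diff leaves the forward/backward lines unchanged and unshown. A minimal sketch of what that elided step plausibly looks like, assuming the Evo model takes input_ids/attention_mask and returns a single logit per pair (the call signature and BCE loss are assumptions, not part of this commit):

    # Hypothetical training step; model signature and loss choice are assumed.
    logits = model(encoded["input_ids"], attention_mask=encoded["attention_mask"]).squeeze(-1)
    loss = torch.nn.functional.binary_cross_entropy_with_logits(logits, label)
    loss.backward()
    optimizer.step()  # zero_grad() then follows, as shown in the diff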
 
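For context, a minimal usage sketch of the new in-memory retraining path. Only the row shape (question, opt1, opt2, answer, ...) and the return string come from the code above; the sample rows are illustrative:

    # Hypothetical feedback rows matching the loop's tuple unpacking.
    feedback_log = [
        ("What is 2 + 2?", "3", "4", "4"),
        ("The capital of France is", "London", "Paris", "Paris"),
    ]
    status = retrain_from_feedback(feedback_log)  # fine-tunes and saves to MODEL_PATH
    print(status)  # "✅ Evo retrained and reloaded from memory." on success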
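Likewise, evo_chat_predict accepts either a plain list of prior turns or a DataFrame-like history (the isinstance/hasattr branch above) and keeps the last six entries as context. A hedged call sketch; "context_used" is the only return key visible in this diff:

    # Hypothetical call with history as a plain list of turns.
    result = evo_chat_predict(["Q: 2 + 2?", "A: 4"], "What is 3 + 3?", ["5", "6"])
    print(result["context_used"])  # the context string Evo actually used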