Update inference.py
inference.py  CHANGED  +34 -27
@@ -10,34 +10,40 @@ import time
 # Load OpenAI API Key securely
 openai.api_key = os.getenv("OPENAI_API_KEY")
 
-# …
+# 📦 Constants
 MODEL_PATH = "evo_hellaswag.pt"
-last_mod_time = 0
-model = None
 tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
+model = None
+last_mod_time = 0
 
-# …
+# Reload model if changed on disk
 def load_model():
     global model, last_mod_time
-    …
+    try:
+        current_mod_time = os.path.getmtime(MODEL_PATH)
+        if model is None or current_mod_time > last_mod_time:
+            model = EvoTransformerV22()
+            model.load_state_dict(torch.load(MODEL_PATH, map_location="cpu"))
+            model.eval()
+            last_mod_time = current_mod_time
+            print("Evo model reloaded.")
+    except Exception as e:
+        print(f"❌ Error loading Evo model: {e}")
+        model = None
     return model
 
-# 🧠 Evo
+# 🧠 Evo logic
 def get_evo_response(query, options, user_context=""):
     model = load_model()
+    if model is None:
+        return "Error", 0.0, "Model failed to load", ""
 
-    # Retrieve
+    # Retrieve web search + optional user context
     context_texts = web_search(query) + ([user_context] if user_context else [])
     context_str = "\n".join(context_texts)
     input_pairs = [f"{query} [SEP] {opt} [CTX] {context_str}" for opt in options]
 
-    # Encode
+    # Encode and score each option
     scores = []
     for pair in input_pairs:
         encoded = tokenizer(pair, return_tensors="pt", padding="max_length", truncation=True, max_length=128)
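Note: the new load_model() implements a simple mtime-based hot reload, re-reading the checkpoint only when the file on disk is newer than the cached copy. A minimal standalone sketch of the same pattern (the path and loader below are illustrative stand-ins, not from this repo):

import os

WEIGHTS_PATH = "weights.bin"  # illustrative stand-in for evo_hellaswag.pt
_cached = None
_cached_mtime = 0.0

def load_cached():
    global _cached, _cached_mtime
    mtime = os.path.getmtime(WEIGHTS_PATH)  # raises OSError if the file is missing
    if _cached is None or mtime > _cached_mtime:
        with open(WEIGHTS_PATH, "rb") as f:  # stand-in for torch.load(...)
            _cached = f.read()
        _cached_mtime = mtime
    return _cached

The committed version additionally wraps the whole check in try/except, so a missing or corrupt checkpoint degrades to model = None instead of raising into the caller.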
@@ -48,13 +54,13 @@ def get_evo_response(query, options, user_context=""):
 
     best_idx = int(scores[1] > scores[0])
     return (
-        options[best_idx],
-        max(scores),
-        f"{options[0]}: {scores[0]:.3f} vs {options[1]}: {scores[1]:.3f}",  # …
-        context_str
+        options[best_idx],  # ✅ Evo's answer
+        max(scores),  # ✅ Confidence
+        f"{options[0]}: {scores[0]:.3f} vs {options[1]}: {scores[1]:.3f}",  # ✅ Reasoning trace
+        context_str  # ✅ Context used
     )
 
-# …
+# GPT backup response
 def get_gpt_response(query, user_context=""):
     try:
         context_block = f"\n\nContext:\n{user_context}" if user_context else ""
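Note: best_idx = int(scores[1] > scores[0]) assumes exactly two options, consistent with the two-choice HellaSwag-style setup used here. If the option list ever grows beyond two, a drop-in N-way equivalent (a sketch, not part of this commit):

best_idx = max(range(len(scores)), key=lambda i: scores[i])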
@@ -67,17 +73,10 @@ def get_gpt_response(query, user_context=""):
     except Exception as e:
         return f"⚠️ GPT error:\n\n{str(e)}"
 
-# …
-def infer(query, options, user_context=""):
-    return get_evo_response(query, options, user_context)
-
-# 🧠 Unified chat-style interface for EvoRAG
+# 🎯 Used by app.py to display Evo live output
 def evo_chat_predict(history, query, options):
-    # Use the last few exchanges as context (up to 3 pairs)
     context = "\n".join(history[-6:]) if history else ""
-
     evo_ans, evo_score, evo_reason, evo_ctx = get_evo_response(query, options, context)
-
     return {
         "answer": evo_ans,
         "confidence": round(evo_score, 3),
@@ -85,3 +84,11 @@ def evo_chat_predict(history, query, options):
         "context_used": evo_ctx
     }
 
+# Returns current Evo architecture stats (for UI display)
+def get_model_config():
+    return {
+        "num_layers": 6,
+        "num_heads": 8,
+        "ffn_dim": 1024,
+        "memory_enabled": True
+    }
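Note: evo_chat_predict() treats history as a flat list of alternating user/model turns and keeps only the last six entries (three exchanges) as retrieval context. A hypothetical call from app.py (the values are illustrative only):

history = ["User: The glass fell off the table.", "Evo: It likely broke."]
result = evo_chat_predict(history, "What happens next?", ["It shatters.", "It floats."])
print(result["answer"], result["confidence"])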
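Note: get_model_config() returns hard-coded stats, so it can drift from the actual checkpoint. If EvoTransformerV22 exposes its hyperparameters as attributes, the UI could read them from the live model instead; a sketch under that assumption (the attribute names are guesses, not the real EvoTransformerV22 API):

def get_model_config_live(model):
    return {
        "num_layers": getattr(model, "num_layers", 6),
        "num_heads": getattr(model, "num_heads", 8),
        "ffn_dim": getattr(model, "ffn_dim", 1024),
        "memory_enabled": getattr(model, "memory_enabled", True),
    }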