Spaces:
Sleeping
Sleeping
Update evo_inference.py
Browse files- evo_inference.py +166 -61
evo_inference.py
CHANGED
@@ -1,10 +1,9 @@
|
|
1 |
"""
|
2 |
-
evo_inference.py — FLAN-optimized + anti-echo
|
3 |
-
-
|
4 |
-
-
|
5 |
-
-
|
6 |
-
-
|
7 |
-
- Labeled outputs: [Generative] / [Extractive]
|
8 |
"""
|
9 |
|
10 |
from typing import List, Dict
|
@@ -23,12 +22,15 @@ except Exception:
|
|
23 |
except Exception:
|
24 |
_GENERATOR = None
|
25 |
|
26 |
-
|
|
|
|
|
27 |
|
28 |
def _snippet(text: str) -> str:
|
29 |
text = " ".join(text.split())
|
30 |
return text[:MAX_SNIPPET_CHARS] + ("..." if len(text) > MAX_SNIPPET_CHARS else "")
|
31 |
|
|
|
32 |
def _extractive_answer(user_query: str, lang: str, hits: List[Dict]) -> str:
|
33 |
if not hits:
|
34 |
return "**[Extractive]**\n\n" + L(lang, "intro_err")
|
@@ -61,57 +63,102 @@ def _extractive_answer(user_query: str, lang: str, hits: List[Dict]) -> str:
|
|
61 |
f"**Suggested steps:**\n" + "\n".join(steps)
|
62 |
)
|
63 |
|
|
|
64 |
def _lang_name(code: str) -> str:
|
65 |
return {"en": "English", "fr": "French", "mfe": "Kreol Morisien"}.get(code, "English")
|
66 |
|
67 |
-
def _filter_hits(hits: List[Dict], keep: int = 4) -> List[Dict]:
|
68 |
-
# Prefer non-placeholder chunks; if all are placeholders, return originals.
|
69 |
-
filtered = [h for h in hits if "placeholder" not in h["text"].lower() and "disclaimer" not in h["text"].lower()]
|
70 |
-
if not filtered:
|
71 |
-
filtered = hits
|
72 |
-
return filtered[:keep]
|
73 |
|
74 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
75 |
"""
|
76 |
-
|
77 |
-
|
78 |
-
|
79 |
-
|
80 |
-
Answer: - bullet - bullet ...
|
81 |
"""
|
82 |
-
|
83 |
-
|
84 |
|
85 |
-
|
86 |
-
|
87 |
-
|
88 |
-
|
89 |
-
|
90 |
-
)
|
91 |
-
elif lang == "mfe":
|
92 |
-
instruction = (
|
93 |
-
"To enn Copilot Gouv Moris. Reponn zis lor konteks. Pa repete kestyon. Donn 6–10 pwin kout "
|
94 |
-
"lor: Dokiman, Fre, Kot pou al, Letan tretman, Steps. Si info manke, dir li. Pa azout seksion anplis."
|
95 |
-
)
|
96 |
-
else:
|
97 |
-
instruction = (
|
98 |
-
"You are the Mauritius Government Copilot. Use ONLY the context. Do not repeat the question. "
|
99 |
-
"Write 6–10 short bullet points covering: Required documents, Fees, Where to apply, "
|
100 |
-
"Processing time, and Steps. If something is missing, say so. No extra sections."
|
101 |
-
)
|
102 |
|
103 |
-
|
104 |
-
|
105 |
-
|
106 |
|
107 |
-
|
108 |
-
|
109 |
-
|
110 |
-
|
111 |
-
|
112 |
-
|
113 |
-
|
114 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
115 |
|
116 |
_ECHO_PATTERNS = [
|
117 |
r"^\s*Instruction.*$", r"^\s*Context:.*$", r"^\s*Question:.*$", r"^\s*Answer.*$",
|
@@ -120,7 +167,6 @@ _ECHO_PATTERNS = [
|
|
120 |
]
|
121 |
|
122 |
def _clean_generated(text: str) -> str:
|
123 |
-
# Remove common echoed lines from the model output.
|
124 |
lines = [ln.strip() for ln in text.strip().splitlines()]
|
125 |
out = []
|
126 |
for ln in lines:
|
@@ -131,34 +177,92 @@ def _clean_generated(text: str) -> str:
|
|
131 |
cleaned = re.sub(r"\n{3,}", "\n\n", cleaned)
|
132 |
return cleaned
|
133 |
|
134 |
-
|
135 |
-
|
136 |
a = re.sub(r"\W+", " ", (ans or "").lower()).strip()
|
137 |
q = re.sub(r"\W+", " ", (question or "").lower()).strip()
|
138 |
-
if len(a) <
|
139 |
return True
|
140 |
if q and (a.startswith(q) or q in a[: max(80, len(q) + 10)]):
|
141 |
return True
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
142 |
return False
|
143 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
144 |
def synthesize_with_evo(
|
145 |
user_query: str,
|
146 |
lang: str,
|
147 |
hits: List[Dict],
|
148 |
mode: str = "extractive",
|
149 |
max_new_tokens: int = 192,
|
150 |
-
temperature: float = 0.
|
151 |
) -> str:
|
152 |
-
# No context → safe fallback
|
153 |
lang = normalize_lang(lang)
|
|
|
154 |
if not hits:
|
155 |
return _extractive_answer(user_query, lang, hits)
|
156 |
|
157 |
-
#
|
|
|
|
|
158 |
if mode != "generative" or _GENERATOR is None:
|
159 |
-
return _extractive_answer(user_query, lang,
|
160 |
|
161 |
-
prompt = _build_grounded_prompt(user_query, lang,
|
162 |
try:
|
163 |
text = _GENERATOR.generate(
|
164 |
prompt,
|
@@ -166,8 +270,9 @@ def synthesize_with_evo(
|
|
166 |
temperature=float(temperature),
|
167 |
)
|
168 |
text = _clean_generated(text)
|
169 |
-
|
170 |
-
|
|
|
171 |
return "**[Generative]**\n\n" + text
|
172 |
except Exception:
|
173 |
-
return _extractive_answer(user_query, lang,
|
|
|
1 |
"""
|
2 |
+
evo_inference.py — FLAN-optimized + topic router + anti-echo/off-topic
|
3 |
+
- Routes queries to the right topic (passport / driving / civil status / business)
|
4 |
+
- Prefers chunks whose filename/text match the topic; filters placeholders
|
5 |
+
- FLAN-friendly prompt; cleans prompt-echo; falls back if echo/too short/off-topic
|
6 |
+
- Labels outputs: [Generative] / [Extractive]
|
|
|
7 |
"""
|
8 |
|
9 |
from typing import List, Dict
|
|
|
22 |
except Exception:
|
23 |
_GENERATOR = None
|
24 |
|
25 |
+
# Keep snippets short so FLAN-T5 stays within encoder limit (512)
|
26 |
+
MAX_SNIPPET_CHARS = 220
|
27 |
+
|
28 |
|
29 |
def _snippet(text: str) -> str:
|
30 |
text = " ".join(text.split())
|
31 |
return text[:MAX_SNIPPET_CHARS] + ("..." if len(text) > MAX_SNIPPET_CHARS else "")
|
32 |
|
33 |
+
|
34 |
def _extractive_answer(user_query: str, lang: str, hits: List[Dict]) -> str:
|
35 |
if not hits:
|
36 |
return "**[Extractive]**\n\n" + L(lang, "intro_err")
|
|
|
63 |
f"**Suggested steps:**\n" + "\n".join(steps)
|
64 |
)
|
65 |
|
66 |
+
|
67 |
def _lang_name(code: str) -> str:
|
68 |
return {"en": "English", "fr": "French", "mfe": "Kreol Morisien"}.get(code, "English")
|
69 |
|
|
|
|
|
|
|
|
|
|
|
|
|
70 |
|
71 |
+
# --- Topic routing -------------------------------------------------------------
|
72 |
+
|
73 |
+
_TOPIC_MAP = {
|
74 |
+
"passport": {
|
75 |
+
"file_hints": ["passport_renewal", "passport"],
|
76 |
+
"word_hints": ["passport", "passeport", "paspor", "renew", "renouvel"],
|
77 |
+
"forbid_words": ["business", "cbrd", "brn", "driving", "licence", "license", "civil status"],
|
78 |
+
},
|
79 |
+
"driving": {
|
80 |
+
"file_hints": ["driving_licence", "driving_license"],
|
81 |
+
"word_hints": ["driving", "licence", "license", "permit", "idp", "pf-77"],
|
82 |
+
"forbid_words": ["passport", "cbrd", "brn", "civil status"],
|
83 |
+
},
|
84 |
+
"civil": {
|
85 |
+
"file_hints": ["birth_marriage_certificate", "civil_status"],
|
86 |
+
"word_hints": ["birth", "naissance", "nesans", "marriage", "mariage", "maryaz", "certificate", "extract"],
|
87 |
+
"forbid_words": ["passport", "driving", "cbrd", "brn"],
|
88 |
+
},
|
89 |
+
"business": {
|
90 |
+
"file_hints": ["business_registration_cbrd", "cbrd"],
|
91 |
+
"word_hints": ["business", "brn", "cbrd", "register", "trade fee"],
|
92 |
+
"forbid_words": ["passport", "driving", "civil status"],
|
93 |
+
},
|
94 |
+
}
|
95 |
+
|
96 |
+
def _guess_topic(query: str) -> str:
|
97 |
+
q = (query or "").lower()
|
98 |
+
if any(w in q for w in _TOPIC_MAP["passport"]["word_hints"]):
|
99 |
+
return "passport"
|
100 |
+
if any(w in q for w in _TOPIC_MAP["driving"]["word_hints"]):
|
101 |
+
return "driving"
|
102 |
+
if any(w in q for w in _TOPIC_MAP["civil"]["word_hints"]):
|
103 |
+
return "civil"
|
104 |
+
if any(w in q for w in _TOPIC_MAP["business"]["word_hints"]):
|
105 |
+
return "business"
|
106 |
+
return "" # unknown → no routing
|
107 |
+
|
108 |
+
|
109 |
+
def _hit_file(h: Dict) -> str:
|
110 |
+
# Try several common fields for filepath
|
111 |
+
return (
|
112 |
+
h.get("file")
|
113 |
+
or h.get("source")
|
114 |
+
or (h.get("meta") or {}).get("file")
|
115 |
+
or ""
|
116 |
+
).lower()
|
117 |
+
|
118 |
+
|
119 |
+
def _filter_hits(hits: List[Dict], query: str, keep: int = 4) -> List[Dict]:
    """
    Prefer non-placeholder + topic-consistent chunks.

    1) Drop placeholder/disclaimer chunks (fall back to the originals if all were).
    2) When the query's topic is known, rank by filename hints + keyword overlap,
       penalizing chunks that mention other topics.
    3) Return the top *keep* items.
    """
    if not hits:
        return []

    # Placeholder-free pool; if everything was a placeholder, keep the originals.
    pool = [
        h for h in hits
        if "placeholder" not in h["text"].lower()
        and "disclaimer" not in h["text"].lower()
    ] or hits

    topic = _guess_topic(query)
    if not topic:
        return pool[:keep]

    cfg = _TOPIC_MAP[topic]
    file_hints = cfg["file_hints"]
    word_hints = set(cfg["word_hints"])
    forbid = set(cfg["forbid_words"])

    def _score(h: Dict) -> float:
        text = h["text"].lower()
        fname = _hit_file(h)
        # Filename match is the strongest signal.
        total = 2.0 if any(k in fname for k in file_hints) else 0.0
        # Keyword overlap boosts; off-topic terms (in text or filename) penalize.
        total += sum(1.0 for w in word_hints if w in text)
        total -= sum(1.5 for w in forbid if w in text or w in fname)
        return total

    return sorted(pool, key=_score, reverse=True)[:keep]
|
159 |
+
|
160 |
+
|
161 |
+
# --- Prompt build & cleaning ---------------------------------------------------
|
162 |
|
163 |
_ECHO_PATTERNS = [
|
164 |
r"^\s*Instruction.*$", r"^\s*Context:.*$", r"^\s*Question:.*$", r"^\s*Answer.*$",
|
|
|
167 |
]
|
168 |
|
169 |
def _clean_generated(text: str) -> str:
|
|
|
170 |
lines = [ln.strip() for ln in text.strip().splitlines()]
|
171 |
out = []
|
172 |
for ln in lines:
|
|
|
177 |
cleaned = re.sub(r"\n{3,}", "\n\n", cleaned)
|
178 |
return cleaned
|
179 |
|
180 |
+
|
181 |
+
def _is_echo_or_too_short_or_offtopic(ans: str, question: str, topic: str) -> bool:
|
182 |
a = re.sub(r"\W+", " ", (ans or "").lower()).strip()
|
183 |
q = re.sub(r"\W+", " ", (question or "").lower()).strip()
|
184 |
+
if len(a) < 60:
|
185 |
return True
|
186 |
if q and (a.startswith(q) or q in a[: max(80, len(q) + 10)]):
|
187 |
return True
|
188 |
+
# crude off-topic guard
|
189 |
+
if topic == "passport" and ("business" in a or "cbrd" in a or "brn" in a):
|
190 |
+
return True
|
191 |
+
if topic == "driving" and ("passport" in a or "cbrd" in a or "brn" in a or "civil status" in a):
|
192 |
+
return True
|
193 |
+
if topic == "civil" and ("passport" in a or "driving" in a or "cbrd" in a or "brn" in a):
|
194 |
+
return True
|
195 |
+
if topic == "business" and ("passport" in a or "driving" in a or "civil status" in a):
|
196 |
+
return True
|
197 |
return False
|
198 |
|
199 |
+
|
200 |
+
def _build_grounded_prompt(question: str, lang: str, hits: List[Dict]) -> str:
    """Assemble a FLAN-friendly prompt: instruction + numbered context + question.

    The instruction carries strong guardrails (stay on topic, bullets only) in
    the normalized language; the answer slot is primed with a leading dash to
    bias bullet-style output.
    """
    lang = normalize_lang(lang)
    lang_readable = _lang_name(lang)
    topic = _guess_topic(question)

    # Strong guardrails in the instruction: stay on topic, bullets only.
    localized = {
        "fr": (
            "Tu es le Copilote Gouvernemental de Maurice. Réponds UNIQUEMENT à partir du contexte. "
            "Reste sur le SUJET demandé et ignore les autres documents. Ne répète pas la question. "
            "Écris 6–10 puces courtes couvrant: Documents requis, Frais, Où postuler, Délai, Étapes. "
            "Si une info manque, dis-le. Pas d'autres sections."
        ),
        "mfe": (
            "To enn Copilot Gouv Moris. Servi ZIS konteks. Reste lor SUZET ki finn demande, "
            "ignorar lezot dokiman. Pa repete kestyon. Ekri 6–10 pwin kout: Dokiman, Fre, Kot pou al, "
            "Letan tretman, Steps. Si info manke, dir li. Pa azout lezot seksion."
        ),
    }
    default_instruction = (
        "You are the Mauritius Government Copilot. Use ONLY the context. Stay strictly on the "
        "REQUESTED TOPIC and ignore other documents. Do NOT repeat the question. Write 6–10 short "
        "bullets covering: Required documents, Fees, Where to apply, Processing time, Steps. "
        "If something is missing, say so. No extra sections."
    )
    instruction = localized.get(lang, default_instruction)

    # An explicit topic hint helps FLAN stay on track.
    if topic:
        instruction += f" Topic: {topic}."

    numbered = [f"{i+1}) {_snippet(h['text'])}" for i, h in enumerate(hits)]
    ctx_block = "\n".join(numbered) if numbered else "(none)"

    # Prime with a leading dash to bias bullet style.
    return (
        f"Instruction ({lang_readable}): {instruction}\n\n"
        f"Context:\n{ctx_block}\n\n"
        f"Question: {question}\n\n"
        f"Answer ({lang_readable}):\n- "
    )
|
242 |
+
|
243 |
+
|
244 |
+
# --- Main entry ----------------------------------------------------------------
|
245 |
+
|
246 |
def synthesize_with_evo(
    user_query: str,
    lang: str,
    hits: List[Dict],
    mode: str = "extractive",
    max_new_tokens: int = 192,
    temperature: float = 0.0,
) -> str:
    """Answer *user_query* from retrieved *hits*.

    Generative mode is used only when requested AND a generator is loaded;
    otherwise — and whenever generation fails, echoes the question, comes out
    too short, or drifts off topic — the extractive fallback answers instead.

    Returns a markdown string labeled "[Generative]" or "[Extractive]".
    """
    lang = normalize_lang(lang)

    # No context -> safe fallback message.
    if not hits:
        return _extractive_answer(user_query, lang, hits)

    # Route/filter hits to keep only on-topic, high-signal chunks.
    chosen = _filter_hits(hits, user_query, keep=4)

    if mode != "generative" or _GENERATOR is None:
        return _extractive_answer(user_query, lang, chosen)

    prompt = _build_grounded_prompt(user_query, lang, chosen)
    try:
        text = _GENERATOR.generate(
            prompt,
            # NOTE(review): this kwarg line was unreadable in the source view;
            # reconstructed from the signature — confirm against the generator API.
            max_new_tokens=int(max_new_tokens),
            temperature=float(temperature),
        )
        text = _clean_generated(text)
        # Quality gate: discard echoes / too-short / off-topic generations.
        topic = _guess_topic(user_query)
        if _is_echo_or_too_short_or_offtopic(text, user_query, topic):
            return _extractive_answer(user_query, lang, chosen)
        return "**[Generative]**\n\n" + text
    except Exception:
        # Any generation failure degrades gracefully to the extractive path.
        return _extractive_answer(user_query, lang, chosen)
|