# evo-gov-copilot-mu / evo_inference.py
# Last commit: c487bf4 "Update evo_inference.py" (HemanM, verified)
"""
evo_inference.py — FLAN-optimized + topic router + anti-echo/off-topic
- Routes queries to the right topic (passport / driving / civil status / business)
- Prefers chunks whose filename/text match the topic; filters placeholders
- FLAN-friendly prompt; cleans prompt-echo; falls back if echo/too short/off-topic
- Labels outputs: [Generative] / [Extractive]
"""
import importlib
import logging
import re
from typing import Dict, List

from utils_lang import L, normalize_lang
# Try to load your real Evo plugin first; else use the example; else None.
_GENERATOR = None
try:
from evo_plugin import load_model as _load_real
_GENERATOR = _load_real()
except Exception:
try:
from evo_plugin_example import load_model as _load_example
_GENERATOR = _load_example()
except Exception:
_GENERATOR = None
# Keep snippets short so FLAN-T5 stays within encoder limit (512)
MAX_SNIPPET_CHARS = 220
def _snippet(text: str) -> str:
text = " ".join(text.split())
return text[:MAX_SNIPPET_CHARS] + ("..." if len(text) > MAX_SNIPPET_CHARS else "")
# Generic "next steps" checklist appended to every extractive answer, keyed by
# normalized language code. Built once at import instead of on every call.
_SUGGESTED_STEPS = {
    "en": [
        "• Step 1: Check eligibility & gather required documents.",
        "• Step 2: Confirm fees & payment options.",
        "• Step 3: Apply online or at the indicated office.",
        "• Step 4: Keep reference/receipt; track processing time.",
    ],
    "fr": [
        "• Étape 1 : Vérifiez l’éligibilité et rassemblez les documents requis.",
        "• Étape 2 : Confirmez les frais et les moyens de paiement.",
        "• Étape 3 : Déposez la demande en ligne ou au bureau indiqué.",
        "• Étape 4 : Conservez le reçu/la référence et suivez le délai de traitement.",
    ],
    "mfe": [
        "• Step 1: Get dokiman neseser ek verifie si to elegib.",
        "• Step 2: Konfirm fre ek manyer peyman.",
        "• Step 3: Fer demand online ouswa dan biro ki indike.",
        "• Step 4: Gard referans/reso; swiv letan tretman.",
    ],
}


def _extractive_answer(user_query: str, lang: str, hits: List[Dict]) -> str:
    """Build a deterministic, retrieval-only Markdown answer from *hits*.

    Args:
        user_query: The user's question; echoed back under "**Q:**".
        lang: Language code. Passed to ``L`` for the localized intro line and
            normalized with ``normalize_lang`` to pick the step checklist;
            codes without a checklist fall back to English (matching
            ``_lang_name``) instead of raising ``KeyError``.
        hits: Retrieved chunks, each with a ``"text"`` key. Only the first
            four are shown, each shortened with ``_snippet``.

    Returns:
        Markdown prefixed with ``**[Extractive]**``. With no hits, only the
        localized ``intro_err`` message follows that label.
    """
    if not hits:
        return "**[Extractive]**\n\n" + L(lang, "intro_err")
    bullets = [f"- {_snippet(h['text'])}" for h in hits[:4]]
    steps = _SUGGESTED_STEPS.get(normalize_lang(lang), _SUGGESTED_STEPS["en"])
    return (
        "**[Extractive]**\n\n"
        f"**{L(lang, 'intro_ok')}**\n\n"
        f"**Q:** {user_query}\n\n"
        "**Key information:**\n"
        + "\n".join(bullets)
        + "\n\n**Suggested steps:**\n"
        + "\n".join(steps)
    )
def _lang_name(code: str) -> str:
return {"en": "English", "fr": "French", "mfe": "Kreol Morisien"}.get(code, "English")
# --- Topic routing -------------------------------------------------------------
_TOPIC_MAP = {
"passport": {
"file_hints": ["passport_renewal", "passport"],
"word_hints": ["passport", "passeport", "paspor", "renew", "renouvel"],
"forbid_words": ["business", "cbrd", "brn", "driving", "licence", "license", "civil status"],
},
"driving": {
"file_hints": ["driving_licence", "driving_license"],
"word_hints": ["driving", "licence", "license", "permit", "idp", "pf-77"],
"forbid_words": ["passport", "cbrd", "brn", "civil status"],
},
"civil": {
"file_hints": ["birth_marriage_certificate", "civil_status"],
"word_hints": ["birth", "naissance", "nesans", "marriage", "mariage", "maryaz", "certificate", "extract"],
"forbid_words": ["passport", "driving", "cbrd", "brn"],
},
"business": {
"file_hints": ["business_registration_cbrd", "cbrd"],
"word_hints": ["business", "brn", "cbrd", "register", "trade fee"],
"forbid_words": ["passport", "driving", "civil status"],
},
}
def _guess_topic(query: str) -> str:
q = (query or "").lower()
if any(w in q for w in _TOPIC_MAP["passport"]["word_hints"]):
return "passport"
if any(w in q for w in _TOPIC_MAP["driving"]["word_hints"]):
return "driving"
if any(w in q for w in _TOPIC_MAP["civil"]["word_hints"]):
return "civil"
if any(w in q for w in _TOPIC_MAP["business"]["word_hints"]):
return "business"
return "" # unknown → no routing
def _hit_file(h: Dict) -> str:
# Try several common fields for filepath
return (
h.get("file")
or h.get("source")
or (h.get("meta") or {}).get("file")
or ""
).lower()
def _filter_hits(hits: List[Dict], query: str, keep: int = 4) -> List[Dict]:
"""
Prefer non-placeholder + topic-consistent chunks.
- 1) Drop placeholders
- 2) If topic known: score by filename hits + keyword overlap
- 3) Return top 'keep' items
"""
if not hits:
return []
# 1) remove placeholders
pool = [
h for h in hits
if "placeholder" not in h["text"].lower() and "disclaimer" not in h["text"].lower()
] or hits
topic = _guess_topic(query)
if not topic:
return pool[:keep]
hints = _TOPIC_MAP[topic]
file_hints = hints["file_hints"]
word_hints = set(hints["word_hints"])
forbid = set(hints["forbid_words"])
def score(h: Dict) -> float:
s = 0.0
f = _hit_file(h)
t = h["text"].lower()
# filename boosts
if any(k in f for k in file_hints):
s += 2.0
# keyword overlap boosts
s += sum(1.0 for w in word_hints if w in t)
# forbid words penalty
s -= sum(1.5 for w in forbid if w in t or w in f)
return s
scored = sorted(pool, key=score, reverse=True)
return scored[:keep]
# --- Prompt build & cleaning ---------------------------------------------------
_ECHO_PATTERNS = [
r"^\s*Instruction.*$", r"^\s*Context:.*$", r"^\s*Question:.*$", r"^\s*Answer.*$",
r"^\s*\[Instructions?\].*$", r"^\s*Be concise.*$", r"^\s*Do not invent.*$",
r"^\s*(en|fr|mfe)\s*$",
]
def _clean_generated(text: str) -> str:
lines = [ln.strip() for ln in text.strip().splitlines()]
out = []
for ln in lines:
if any(re.match(pat, ln, flags=re.IGNORECASE) for pat in _ECHO_PATTERNS):
continue
out.append(ln)
cleaned = "\n".join(out).strip()
cleaned = re.sub(r"\n{3,}", "\n\n", cleaned)
return cleaned
def _is_echo_or_too_short_or_offtopic(ans: str, question: str, topic: str) -> bool:
a = re.sub(r"\W+", " ", (ans or "").lower()).strip()
q = re.sub(r"\W+", " ", (question or "").lower()).strip()
if len(a) < 60:
return True
if q and (a.startswith(q) or q in a[: max(80, len(q) + 10)]):
return True
# crude off-topic guard
if topic == "passport" and ("business" in a or "cbrd" in a or "brn" in a):
return True
if topic == "driving" and ("passport" in a or "cbrd" in a or "brn" in a or "civil status" in a):
return True
if topic == "civil" and ("passport" in a or "driving" in a or "cbrd" in a or "brn" in a):
return True
if topic == "business" and ("passport" in a or "driving" in a or "civil status" in a):
return True
return False
def _build_grounded_prompt(question: str, lang: str, hits: List[Dict]) -> str:
    """Assemble the FLAN-style prompt for grounded generation.

    The prompt has four parts:
      * an instruction in French, Kreol Morisien, or English (the default for
        any other normalized code), with " Topic: <topic>." appended when the
        question routes to a known topic;
      * a numbered context block of shortened hit texts, or "(none)";
      * the question;
      * an "Answer (<language>):" header ending in "- " so the model is
        nudged to continue in bullet style.
    """
    code = normalize_lang(lang)
    readable = _lang_name(code)
    topic = _guess_topic(question)

    # Strong guardrails in the instruction: stay on topic, bullets only
    english_instruction = (
        "You are the Mauritius Government Copilot. Use ONLY the context. Stay strictly on the "
        "REQUESTED TOPIC and ignore other documents. Do NOT repeat the question. Write 6–10 short "
        "bullets covering: Required documents, Fees, Where to apply, Processing time, Steps. "
        "If something is missing, say so. No extra sections."
    )
    localized_instructions = {
        "fr": (
            "Tu es le Copilote Gouvernemental de Maurice. Réponds UNIQUEMENT à partir du contexte. "
            "Reste sur le SUJET demandé et ignore les autres documents. Ne répète pas la question. "
            "Écris 6–10 puces courtes couvrant: Documents requis, Frais, Où postuler, Délai, Étapes. "
            "Si une info manque, dis-le. Pas d'autres sections."
        ),
        "mfe": (
            "To enn Copilot Gouv Moris. Servi ZIS konteks. Reste lor SUZET ki finn demande, "
            "ignorar lezot dokiman. Pa repete kestyon. Ekri 6–10 pwin kout: Dokiman, Fre, Kot pou al, "
            "Letan tretman, Steps. Si info manke, dir li. Pa azout lezot seksion."
        ),
    }
    instruction = localized_instructions.get(code, english_instruction)

    # An explicit topic hint helps FLAN stay on track.
    if topic:
        instruction = f"{instruction} Topic: {topic}."

    numbered = [f"{n}) {_snippet(hit['text'])}" for n, hit in enumerate(hits, start=1)]
    context = "\n".join(numbered) or "(none)"

    # Trailing "- " primes the model to answer as a bullet list.
    return (
        f"Instruction ({readable}): {instruction}\n\n"
        f"Context:\n{context}\n\n"
        f"Question: {question}\n\n"
        f"Answer ({readable}):\n- "
    )
# --- Main entry ----------------------------------------------------------------
def synthesize_with_evo(
    user_query: str,
    lang: str,
    hits: List[Dict],
    mode: str = "extractive",
    max_new_tokens: int = 192,
    temperature: float = 0.0,
) -> str:
    """Answer *user_query* from retrieved *hits*, generatively when possible.

    Args:
        user_query: The user's question.
        lang: Language code; normalized with ``normalize_lang`` before use.
        hits: Retrieved chunks (dicts with at least a ``"text"`` key).
        mode: "generative" to try the loaded generator; any other value
            yields an extractive answer.
        max_new_tokens: Forwarded to the generator (coerced to ``int``).
        temperature: Forwarded to the generator (coerced to ``float``).

    Returns:
        Markdown labelled ``**[Generative]**`` or ``**[Extractive]**``. The
        extractive answer is used in four cases: there are no hits;
        generative mode is off or no generator is loaded; the cleaned output
        looks like a prompt echo, is too short, or is off-topic; or
        generation raises. In the last case the exception is logged rather
        than propagated.
    """
    lang = normalize_lang(lang)
    if not hits:
        return _extractive_answer(user_query, lang, hits)
    # Route/filter hits to keep only on-topic, high-signal chunks
    chosen = _filter_hits(hits, user_query, keep=4)
    if mode != "generative" or _GENERATOR is None:
        return _extractive_answer(user_query, lang, chosen)
    prompt = _build_grounded_prompt(user_query, lang, chosen)
    try:
        raw = _GENERATOR.generate(
            prompt,
            max_new_tokens=int(max_new_tokens),
            temperature=float(temperature),
        )
        # Cleaning stays inside the try: a generator returning a non-string
        # should degrade to the extractive answer, not crash the request.
        text = _clean_generated(raw)
    except Exception:
        # Best-effort by design, but record why we degraded so failures of
        # the model backend are visible to operators.
        logging.getLogger(__name__).exception(
            "Evo generation failed; falling back to extractive answer"
        )
        return _extractive_answer(user_query, lang, chosen)
    topic = _guess_topic(user_query)
    if _is_echo_or_too_short_or_offtopic(text, user_query, topic):
        return _extractive_answer(user_query, lang, chosen)
    return "**[Generative]**\n\n" + text