HemanM committed on
Commit
c487bf4
·
verified ·
1 Parent(s): 2c2c567

Update evo_inference.py

Browse files
Files changed (1) hide show
  1. evo_inference.py +166 -61
evo_inference.py CHANGED
@@ -1,10 +1,9 @@
1
  """
2
- evo_inference.py — FLAN-optimized + anti-echo
3
- - FLAN-friendly prompt with explicit bullet structure
4
- - Filters placeholder chunks
5
- - Cleans prompt-echo lines
6
- - Anti-echo guard: if the model repeats the question or outputs too little, we fall back to Extractive
7
- - Labeled outputs: [Generative] / [Extractive]
8
  """
9
 
10
  from typing import List, Dict
@@ -23,12 +22,15 @@ except Exception:
23
  except Exception:
24
  _GENERATOR = None
25
 
26
- MAX_SNIPPET_CHARS = 200
 
 
27
 
28
  def _snippet(text: str) -> str:
29
  text = " ".join(text.split())
30
  return text[:MAX_SNIPPET_CHARS] + ("..." if len(text) > MAX_SNIPPET_CHARS else "")
31
 
 
32
  def _extractive_answer(user_query: str, lang: str, hits: List[Dict]) -> str:
33
  if not hits:
34
  return "**[Extractive]**\n\n" + L(lang, "intro_err")
@@ -61,57 +63,102 @@ def _extractive_answer(user_query: str, lang: str, hits: List[Dict]) -> str:
61
  f"**Suggested steps:**\n" + "\n".join(steps)
62
  )
63
 
 
64
  def _lang_name(code: str) -> str:
65
  return {"en": "English", "fr": "French", "mfe": "Kreol Morisien"}.get(code, "English")
66
 
67
- def _filter_hits(hits: List[Dict], keep: int = 4) -> List[Dict]:
68
- # Prefer non-placeholder chunks; if all are placeholders, return originals.
69
- filtered = [h for h in hits if "placeholder" not in h["text"].lower() and "disclaimer" not in h["text"].lower()]
70
- if not filtered:
71
- filtered = hits
72
- return filtered[:keep]
73
 
74
- def _build_grounded_prompt(question: str, lang: str, hits: List[Dict]) -> str:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
75
  """
76
- FLAN-style prompt:
77
- Instruction: (clear constraints)
78
- Context: 1) ... 2) ...
79
- Question: ...
80
- Answer: - bullet - bullet ...
81
  """
82
- lang = normalize_lang(lang)
83
- lang_readable = _lang_name(lang)
84
 
85
- if lang == "fr":
86
- instruction = (
87
- "Tu es le Copilote Gouvernemental de Maurice. Réponds UNIQUEMENT à partir du contexte. "
88
- "Ne répète pas la question. Donne 6–10 puces courtes couvrant: Documents requis, Frais, "
89
- "Où postuler, Délai de traitement, Étapes. Si une info manque, dis-le. Pas d'autres sections."
90
- )
91
- elif lang == "mfe":
92
- instruction = (
93
- "To enn Copilot Gouv Moris. Reponn zis lor konteks. Pa repete kestyon. Donn 6–10 pwin kout "
94
- "lor: Dokiman, Fre, Kot pou al, Letan tretman, Steps. Si info manke, dir li. Pa azout seksion anplis."
95
- )
96
- else:
97
- instruction = (
98
- "You are the Mauritius Government Copilot. Use ONLY the context. Do not repeat the question. "
99
- "Write 6–10 short bullet points covering: Required documents, Fees, Where to apply, "
100
- "Processing time, and Steps. If something is missing, say so. No extra sections."
101
- )
102
 
103
- chosen = _filter_hits(hits, keep=6)
104
- ctx_lines = [f"{i+1}) {_snippet(h['text'])}" for i, h in enumerate(chosen)]
105
- ctx_block = "\n".join(ctx_lines) if ctx_lines else "(none)"
106
 
107
- # Prime with a leading dash to encourage bullets.
108
- prompt = (
109
- f"Instruction ({lang_readable}): {instruction}\n\n"
110
- f"Context:\n{ctx_block}\n\n"
111
- f"Question: {question}\n\n"
112
- f"Answer ({lang_readable}):\n- "
113
- )
114
- return prompt
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
115
 
116
  _ECHO_PATTERNS = [
117
  r"^\s*Instruction.*$", r"^\s*Context:.*$", r"^\s*Question:.*$", r"^\s*Answer.*$",
@@ -120,7 +167,6 @@ _ECHO_PATTERNS = [
120
  ]
121
 
122
  def _clean_generated(text: str) -> str:
123
- # Remove common echoed lines from the model output.
124
  lines = [ln.strip() for ln in text.strip().splitlines()]
125
  out = []
126
  for ln in lines:
@@ -131,34 +177,92 @@ def _clean_generated(text: str) -> str:
131
  cleaned = re.sub(r"\n{3,}", "\n\n", cleaned)
132
  return cleaned
133
 
134
- def _is_echo_or_too_short(ans: str, question: str) -> bool:
135
- # Normalize and check if answer is basically the question or too short.
136
  a = re.sub(r"\W+", " ", (ans or "").lower()).strip()
137
  q = re.sub(r"\W+", " ", (question or "").lower()).strip()
138
- if len(a) < 40:
139
  return True
140
  if q and (a.startswith(q) or q in a[: max(80, len(q) + 10)]):
141
  return True
 
 
 
 
 
 
 
 
 
142
  return False
143
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
144
  def synthesize_with_evo(
145
  user_query: str,
146
  lang: str,
147
  hits: List[Dict],
148
  mode: str = "extractive",
149
  max_new_tokens: int = 192,
150
- temperature: float = 0.4,
151
  ) -> str:
152
- # No context → safe fallback
153
  lang = normalize_lang(lang)
 
154
  if not hits:
155
  return _extractive_answer(user_query, lang, hits)
156
 
157
- # Extractive path or no generator available
 
 
158
  if mode != "generative" or _GENERATOR is None:
159
- return _extractive_answer(user_query, lang, hits)
160
 
161
- prompt = _build_grounded_prompt(user_query, lang, hits)
162
  try:
163
  text = _GENERATOR.generate(
164
  prompt,
@@ -166,8 +270,9 @@ def synthesize_with_evo(
166
  temperature=float(temperature),
167
  )
168
  text = _clean_generated(text)
169
- if _is_echo_or_too_short(text, user_query):
170
- return _extractive_answer(user_query, lang, hits)
 
171
  return "**[Generative]**\n\n" + text
172
  except Exception:
173
- return _extractive_answer(user_query, lang, hits)
 
1
  """
2
+ evo_inference.py — FLAN-optimized + topic router + anti-echo/off-topic
3
+ - Routes queries to the right topic (passport / driving / civil status / business)
4
+ - Prefers chunks whose filename/text match the topic; filters placeholders
5
+ - FLAN-friendly prompt; cleans prompt-echo; falls back if echo/too short/off-topic
6
+ - Labels outputs: [Generative] / [Extractive]
 
7
  """
8
 
9
  from typing import List, Dict
 
22
  except Exception:
23
  _GENERATOR = None
24
 
25
+ # Keep snippets short so FLAN-T5 stays within encoder limit (512)
26
+ MAX_SNIPPET_CHARS = 220
27
+
28
 
29
  def _snippet(text: str) -> str:
30
  text = " ".join(text.split())
31
  return text[:MAX_SNIPPET_CHARS] + ("..." if len(text) > MAX_SNIPPET_CHARS else "")
32
 
33
+
34
  def _extractive_answer(user_query: str, lang: str, hits: List[Dict]) -> str:
35
  if not hits:
36
  return "**[Extractive]**\n\n" + L(lang, "intro_err")
 
63
  f"**Suggested steps:**\n" + "\n".join(steps)
64
  )
65
 
66
+
67
  def _lang_name(code: str) -> str:
68
  return {"en": "English", "fr": "French", "mfe": "Kreol Morisien"}.get(code, "English")
69
 
 
 
 
 
 
 
70
 
71
+ # --- Topic routing -------------------------------------------------------------
72
+
73
+ _TOPIC_MAP = {
74
+ "passport": {
75
+ "file_hints": ["passport_renewal", "passport"],
76
+ "word_hints": ["passport", "passeport", "paspor", "renew", "renouvel"],
77
+ "forbid_words": ["business", "cbrd", "brn", "driving", "licence", "license", "civil status"],
78
+ },
79
+ "driving": {
80
+ "file_hints": ["driving_licence", "driving_license"],
81
+ "word_hints": ["driving", "licence", "license", "permit", "idp", "pf-77"],
82
+ "forbid_words": ["passport", "cbrd", "brn", "civil status"],
83
+ },
84
+ "civil": {
85
+ "file_hints": ["birth_marriage_certificate", "civil_status"],
86
+ "word_hints": ["birth", "naissance", "nesans", "marriage", "mariage", "maryaz", "certificate", "extract"],
87
+ "forbid_words": ["passport", "driving", "cbrd", "brn"],
88
+ },
89
+ "business": {
90
+ "file_hints": ["business_registration_cbrd", "cbrd"],
91
+ "word_hints": ["business", "brn", "cbrd", "register", "trade fee"],
92
+ "forbid_words": ["passport", "driving", "civil status"],
93
+ },
94
+ }
95
+
96
+ def _guess_topic(query: str) -> str:
97
+ q = (query or "").lower()
98
+ if any(w in q for w in _TOPIC_MAP["passport"]["word_hints"]):
99
+ return "passport"
100
+ if any(w in q for w in _TOPIC_MAP["driving"]["word_hints"]):
101
+ return "driving"
102
+ if any(w in q for w in _TOPIC_MAP["civil"]["word_hints"]):
103
+ return "civil"
104
+ if any(w in q for w in _TOPIC_MAP["business"]["word_hints"]):
105
+ return "business"
106
+ return "" # unknown → no routing
107
+
108
+
109
+ def _hit_file(h: Dict) -> str:
110
+ # Try several common fields for filepath
111
+ return (
112
+ h.get("file")
113
+ or h.get("source")
114
+ or (h.get("meta") or {}).get("file")
115
+ or ""
116
+ ).lower()
117
+
118
+
119
def _filter_hits(hits: List[Dict], query: str, keep: int = 4) -> List[Dict]:
    """
    Prefer non-placeholder + topic-consistent chunks.
    - 1) Drop placeholders
    - 2) If topic known: score by filename hits + keyword overlap
    - 3) Return top 'keep' items
    """
    if not hits:
        return []

    # 1) Drop placeholder/disclaimer chunks; if that empties the pool, fall
    #    back to the raw hits rather than returning nothing.
    non_placeholder = [
        h for h in hits
        if "placeholder" not in h["text"].lower() and "disclaimer" not in h["text"].lower()
    ]
    pool = non_placeholder if non_placeholder else hits

    topic = _guess_topic(query)
    if not topic:
        # Unknown topic → no re-ranking, just truncate.
        return pool[:keep]

    hints = _TOPIC_MAP[topic]
    file_hints = hints["file_hints"]
    word_hints = set(hints["word_hints"])
    forbid = set(hints["forbid_words"])

    def _score(h: Dict) -> float:
        """Higher = more on-topic: filename boost + keyword overlap − forbidden terms."""
        filename = _hit_file(h)
        body = h["text"].lower()
        total = 2.0 if any(k in filename for k in file_hints) else 0.0
        total += sum(1.0 for w in word_hints if w in body)
        total -= sum(1.5 for w in forbid if w in body or w in filename)
        return total

    # sorted() is stable, so equally scored hits keep their retrieval order.
    return sorted(pool, key=_score, reverse=True)[:keep]
159
+
160
+
161
+ # --- Prompt build & cleaning ---------------------------------------------------
162
 
163
  _ECHO_PATTERNS = [
164
  r"^\s*Instruction.*$", r"^\s*Context:.*$", r"^\s*Question:.*$", r"^\s*Answer.*$",
 
167
  ]
168
 
169
  def _clean_generated(text: str) -> str:
 
170
  lines = [ln.strip() for ln in text.strip().splitlines()]
171
  out = []
172
  for ln in lines:
 
177
  cleaned = re.sub(r"\n{3,}", "\n\n", cleaned)
178
  return cleaned
179
 
180
+
181
+ def _is_echo_or_too_short_or_offtopic(ans: str, question: str, topic: str) -> bool:
182
  a = re.sub(r"\W+", " ", (ans or "").lower()).strip()
183
  q = re.sub(r"\W+", " ", (question or "").lower()).strip()
184
+ if len(a) < 60:
185
  return True
186
  if q and (a.startswith(q) or q in a[: max(80, len(q) + 10)]):
187
  return True
188
+ # crude off-topic guard
189
+ if topic == "passport" and ("business" in a or "cbrd" in a or "brn" in a):
190
+ return True
191
+ if topic == "driving" and ("passport" in a or "cbrd" in a or "brn" in a or "civil status" in a):
192
+ return True
193
+ if topic == "civil" and ("passport" in a or "driving" in a or "cbrd" in a or "brn" in a):
194
+ return True
195
+ if topic == "business" and ("passport" in a or "driving" in a or "civil status" in a):
196
+ return True
197
  return False
198
 
199
+
200
def _build_grounded_prompt(question: str, lang: str, hits: List[Dict]) -> str:
    """Assemble the FLAN-style Instruction/Context/Question/Answer prompt.

    The instruction carries strong guardrails (answer only from context,
    stay on the requested topic, bullets only), optionally extended with the
    detected topic.  Context is the numbered snippets of *hits*, and the
    prompt ends with a leading dash to bias the model toward bullet output.
    """
    lang = normalize_lang(lang)
    lang_readable = _lang_name(lang)
    topic = _guess_topic(question)

    # Per-language guardrail instructions; unknown languages fall back to English.
    instructions = {
        "fr": (
            "Tu es le Copilote Gouvernemental de Maurice. Réponds UNIQUEMENT à partir du contexte. "
            "Reste sur le SUJET demandé et ignore les autres documents. Ne répète pas la question. "
            "Écris 6–10 puces courtes couvrant: Documents requis, Frais, Où postuler, Délai, Étapes. "
            "Si une info manque, dis-le. Pas d'autres sections."
        ),
        "mfe": (
            "To enn Copilot Gouv Moris. Servi ZIS konteks. Reste lor SUZET ki finn demande, "
            "ignorar lezot dokiman. Pa repete kestyon. Ekri 6–10 pwin kout: Dokiman, Fre, Kot pou al, "
            "Letan tretman, Steps. Si info manke, dir li. Pa azout lezot seksion."
        ),
        "en": (
            "You are the Mauritius Government Copilot. Use ONLY the context. Stay strictly on the "
            "REQUESTED TOPIC and ignore other documents. Do NOT repeat the question. Write 6–10 short "
            "bullets covering: Required documents, Fees, Where to apply, Processing time, Steps. "
            "If something is missing, say so. No extra sections."
        ),
    }
    instruction = instructions.get(lang, instructions["en"])

    # Explicit topic hint in the instruction helps FLAN stay on track.
    if topic:
        instruction += f" Topic: {topic}."

    numbered = [f"{i+1}) {_snippet(h['text'])}" for i, h in enumerate(hits)]
    ctx_block = "\n".join(numbered) if numbered else "(none)"

    # Prime with a leading dash to bias bullet style.
    return (
        f"Instruction ({lang_readable}): {instruction}\n\n"
        f"Context:\n{ctx_block}\n\n"
        f"Question: {question}\n\n"
        f"Answer ({lang_readable}):\n- "
    )
242
+
243
+
244
+ # --- Main entry ----------------------------------------------------------------
245
+
246
  def synthesize_with_evo(
247
  user_query: str,
248
  lang: str,
249
  hits: List[Dict],
250
  mode: str = "extractive",
251
  max_new_tokens: int = 192,
252
+ temperature: float = 0.0,
253
  ) -> str:
 
254
  lang = normalize_lang(lang)
255
+
256
  if not hits:
257
  return _extractive_answer(user_query, lang, hits)
258
 
259
+ # Route/filter hits to keep only on-topic, high-signal chunks
260
+ chosen = _filter_hits(hits, user_query, keep=4)
261
+
262
  if mode != "generative" or _GENERATOR is None:
263
+ return _extractive_answer(user_query, lang, chosen)
264
 
265
+ prompt = _build_grounded_prompt(user_query, lang, chosen)
266
  try:
267
  text = _GENERATOR.generate(
268
  prompt,
 
270
  temperature=float(temperature),
271
  )
272
  text = _clean_generated(text)
273
+ topic = _guess_topic(user_query)
274
+ if _is_echo_or_too_short_or_offtopic(text, user_query, topic):
275
+ return _extractive_answer(user_query, lang, chosen)
276
  return "**[Generative]**\n\n" + text
277
  except Exception:
278
+ return _extractive_answer(user_query, lang, chosen)