|
|
|
|
|
|
|
import json
import re
from typing import Dict, Literal, Tuple, get_args

from components.LLMs.Mistral import MistralTogetherClient, build_messages
|
|
|
|
|
# Closed set of intent labels the classifier may return.  Used as the static
# type of the label returned by ``ZeroShotClassifier.classify``.
Intent = Literal[
    "headlines_request",
    "preferences_update",
    "greeting",
    "help",
    "small_talk",
    "chat_question",
    "unsubscribe",
    "other",
]

# Runtime view of the same labels, derived from ``Intent`` so the two
# definitions cannot drift apart.  Deliberately kept as a ``list`` in the
# declaration order above: its repr is embedded verbatim in the system prompt
# and it is used for membership checks on the model's answer.
INTENT_SET = list(get_args(Intent))
|
|
|
|
|
# Keyword / phrase triggers grouped by intent label.
# NOTE(review): nothing in the visible part of this file reads this mapping;
# the keyword fallback in ``ZeroShotClassifier.classify`` uses its own
# hard-coded lists.  Confirm whether it is consumed elsewhere or whether the
# two should be unified.
TRIGGER_SYNONYMS = {
    "headlines_request": [
        "headlines",
        "digest",
        "daily",
        "news",
        "today's headlines",
        "view headlines",
    ],
    "unsubscribe": ["stop", "opt out", "unsubscribe", "cancel"],
    "help": ["help", "how it works", "what can you do"],
    "greeting": ["hi", "hello", "hey"],
}
|
|
|
# System message sent with every classification request.  Assembled once at
# import time from adjacent string literals; the single f-string fragment
# embeds the repr of ``INTENT_SET`` so the model sees the allowed labels.
# The prompt asks for a bare JSON object with "label", "confidence" and
# "reason" keys, which is what ``ZeroShotClassifier.classify`` reads back.
SYSTEM_PROMPT = (
    "You are an intent classifier for a WhatsApp news assistant called NuseAI. "
    "Choose exactly one label from the allowed list, and produce STRICT JSON: "
    '{"label": "<one_of_allowed>", "confidence": <0..1>, "reason": "<short>"}.\n\n'
    f"ALLOWED_LABELS = {INTENT_SET}.\n"
    "- headlines_request: user asks for today's headlines, digest, news, or to view/open the feed.\n"
    "- preferences_update: user wants to set or change interests, topics, regions, or delivery time.\n"
    "- greeting: greetings like hi/hello/hey.\n"
    "- help: asks how to use the bot, commands, or capabilities.\n"
    "- small_talk: casual chat or jokes, not seeking info.\n"
    "- chat_question: general info or Q&A unrelated to getting today's digest.\n"
    "- unsubscribe: stop/opt-out messages.\n"
    "- other: anything that doesn't fit.\n\n"
    "Rules:\n"
    "1) Output valid JSON ONLY (no backticks, no extra text).\n"
    "2) confidence in [0,1].\n"
    "3) Prefer 'headlines_request' for any phrasing that implies they want today's headlines/digest.\n"
)
|
|
|
# Per-message user prompt.  ``{text}`` is the only placeholder and is filled
# with ``str.format`` by the classifier; the message is wrapped in double
# quotes inside the prompt.
USER_INSTRUCTION_TEMPLATE = (
    "Classify this WhatsApp message into one label.\n"
    "Message: \"{text}\"\n"
    "Return JSON only."
)
|
|
|
class ZeroShotClassifier:
    """Zero-shot intent classifier backed by an LLM chat client.

    The model is prompted (``SYSTEM_PROMPT`` + ``USER_INSTRUCTION_TEMPLATE``)
    to answer with a JSON object holding ``label``, ``confidence`` and
    ``reason``.  The reply is parsed defensively, and a small keyword
    heuristic rescues a few obvious intents when the model's label is
    missing, not in ``INTENT_SET``, or ``"other"``.
    """

    def __init__(self):
        self.llm = MistralTogetherClient()

    def classify(self, text: str) -> "Tuple[Intent, Dict]":
        """Classify one incoming message.

        Args:
            text: The raw user message.

        Returns:
            A ``(label, result)`` pair.  ``label`` is always a member of
            ``INTENT_SET``.  ``result`` is a dict with the keys ``label``,
            ``confidence`` (float clamped to [0, 1]), ``reason`` (as returned
            by the model, ``""`` when absent) and ``raw`` (the unparsed
            model reply).
        """
        user_prompt = USER_INSTRUCTION_TEMPLATE.format(text=text.strip())
        msgs = build_messages(user_prompt, SYSTEM_PROMPT)

        raw, _usage = self.llm.chat(msgs, temperature=0.0, max_tokens=200)

        # ``_safe_parse_json`` always yields a dict, so ``.get`` is safe even
        # when the model ignores the "JSON only" rule.
        payload = self._safe_parse_json(raw)
        label = payload.get("label", "other")
        confidence = self._coerce_confidence(payload.get("confidence"))
        reason = payload.get("reason", "")

        # Anything outside the allowed label list is treated as "other".
        if label not in INTENT_SET:
            label = "other"

        # Only second-guess the model when it had no usable answer.
        if label == "other":
            label = self._fallback_label(text.lower().strip())

        result = {"label": label, "confidence": confidence, "reason": reason, "raw": raw}
        return label, result

    @staticmethod
    def _coerce_confidence(value, default: float = 0.5) -> float:
        """Turn the model-reported confidence into a float in [0, 1].

        Missing (``None``), non-numeric, or NaN values yield ``default``;
        numeric values outside the range are clamped to the nearest bound.
        """
        try:
            conf = float(value)
        except (TypeError, ValueError, OverflowError):
            return default
        if conf != conf:  # NaN is the only float that is not equal to itself
            return default
        return min(1.0, max(0.0, conf))

    @staticmethod
    def _fallback_label(norm: str) -> str:
        """Keyword heuristic applied to the lower-cased, stripped message.

        Returns ``"headlines_request"``, ``"unsubscribe"``, ``"greeting"``,
        or ``"other"`` when nothing matches.  The first two checks are plain
        substring tests; the greeting check requires the whole message to be
        hi/hello/hey with an optional trailing ``!`` or ``.``.
        """
        # NOTE(review): headline keywords are tested before unsubscribe
        # keywords, so a message such as "stop sending me news" resolves to
        # "headlines_request".  Order kept as-is; confirm this is intended.
        if any(s in norm for s in ("headline", "digest", "news", "today")):
            return "headlines_request"
        if any(s in norm for s in ("unsubscribe", "opt out", "stop")):
            return "unsubscribe"
        if re.fullmatch(r"(hi|hello|hey)[!.]?", norm):
            return "greeting"
        return "other"

    @staticmethod
    def _safe_parse_json(s: str) -> Dict:
        """Best-effort extraction of a JSON object from a model reply.

        Tries the whole string first, then the span from the first ``{`` to
        the last ``}`` (covers replies wrapped in code fences or extra
        prose).  Returns ``{}`` when no candidate parses to a JSON object;
        never raises for malformed or non-string input.
        """
        candidates = [s]
        if isinstance(s, str):
            start = s.find("{")
            end = s.rfind("}")
            if start != -1 and end > start:
                candidates.append(s[start : end + 1])

        for candidate in candidates:
            try:
                parsed = json.loads(candidate)
            except (TypeError, ValueError, RecursionError):
                # TypeError: non-string input; ValueError covers
                # json.JSONDecodeError; RecursionError: absurdly deep nesting.
                continue
            # Valid JSON that is not an object (string, list, number, ...)
            # is useless to the caller, which relies on ``dict.get``.
            if isinstance(parsed, dict):
                return parsed
        return {}
|
|