fastAPIv2 / components /LLMs /Classifier.py
ragV98's picture
initiated mistral and integrated classifier, cross fingers
8d4ff93
raw
history blame
4.31 kB
# components/LLMs/Classifier.py
# Zero-shot intent classifier using Together AI (Mistral-7B-Instruct)
import json
import re
from typing import Dict, Literal, Tuple, get_args

from components.LLMs.Mistral import MistralTogetherClient, build_messages
# ---- Define your canonical intents (keep this list stable) ----
# ``Intent`` is the single source of truth for the label set: ``INTENT_SET``
# below is derived from it, so the static type and the runtime list cannot
# drift apart when a label is added, renamed or removed.
Intent = Literal[
    "headlines_request",  # user wants today's headlines/digest/news
    "preferences_update",  # change topics, regions, time, etc.
    "greeting",  # hi/hello/hey
    "help",  # how to use / what can you do
    "small_talk",  # chitchat, jokes, idle talk
    "chat_question",  # general Q&A about news/economy/etc.
    "unsubscribe",  # stop/opt-out
    "other",  # anything else
]

# Runtime list of the allowed labels, in declaration order. Kept as a list
# (not a tuple) so its repr stays the same where it is interpolated into the
# LLM system prompt.
INTENT_SET = list(get_args(Intent))
# Optional per-intent keyword aliases, for callers that want to normalize raw
# user text before handing it to downstream logic.
# NOTE: ZeroShotClassifier in this file does not consult this mapping.
TRIGGER_SYNONYMS = dict(
    headlines_request=[
        "headlines",
        "digest",
        "daily",
        "news",
        "today's headlines",
        "view headlines",
    ],
    unsubscribe=["stop", "opt out", "unsubscribe", "cancel"],
    help=["help", "how it works", "what can you do"],
    greeting=["hi", "hello", "hey"],
)
# System prompt for the zero-shot classifier: states the JSON reply format,
# embeds the allowed labels (the repr of INTENT_SET), gives a one-line
# description per label, and ends with numbered output rules.
SYSTEM_PROMPT = (
    "You are an intent classifier for a WhatsApp news assistant called NuseAI. "
    "Choose exactly one label from the allowed list, and produce STRICT JSON: "
    '{"label": "<one_of_allowed>", "confidence": <0..1>, "reason": "<short>"}.\n\n'
    + "ALLOWED_LABELS = {}.\n".format(INTENT_SET)
    + "".join(
        f"- {name}: {meaning}\n"
        for name, meaning in (
            ("headlines_request", "user asks for today's headlines, digest, news, or to view/open the feed."),
            ("preferences_update", "user wants to set or change interests, topics, regions, or delivery time."),
            ("greeting", "greetings like hi/hello/hey."),
            ("help", "asks how to use the bot, commands, or capabilities."),
            ("small_talk", "casual chat or jokes, not seeking info."),
            ("chat_question", "general info or Q&A unrelated to getting today's digest."),
            ("unsubscribe", "stop/opt-out messages."),
            ("other", "anything that doesn't fit."),
        )
    )
    + "\nRules:\n"
    + "".join(
        f"{number}) {rule}\n"
        for number, rule in enumerate(
            (
                "Output valid JSON ONLY (no backticks, no extra text).",
                "confidence in [0,1].",
                "Prefer 'headlines_request' for any phrasing that implies they want today's headlines/digest.",
            ),
            start=1,
        )
    )
)
# Per-message user prompt. ``{text}`` is a str.format placeholder that the
# classifier fills with the (stripped) incoming message.
USER_INSTRUCTION_TEMPLATE = "\n".join(
    (
        "Classify this WhatsApp message into one label.",
        'Message: "{text}"',
        "Return JSON only.",
    )
)
class ZeroShotClassifier:
    """Zero-shot intent classifier backed by an LLM chat client.

    ``classify`` sends the message to the LLM together with ``SYSTEM_PROMPT``,
    parses the model's JSON reply defensively, and post-processes it so the
    returned label is always a member of ``INTENT_SET``.
    """

    # Keyword fallbacks, consulted only when the model's answer resolves to
    # "other". Opt-out is checked FIRST so that e.g. "stop sending me news" or
    # "unsubscribe from the newsletter" is not misread as a headline request;
    # word boundaries keep "stop" from matching inside words like "nonstop".
    _UNSUBSCRIBE_RE = re.compile(r"\b(?:unsubscribe\w*|opt[- ]?out|stop)\b")
    _HEADLINE_KEYWORDS = ("headline", "digest", "news", "today")
    _GREETING_RE = re.compile(r"(hi|hello|hey)[!.]?")

    def __init__(self, llm=None):
        """Create a classifier.

        Args:
            llm: Optional chat client used by ``classify``; it must provide
                ``chat(messages, temperature=..., max_tokens=...)`` returning a
                ``(text, usage)`` pair. When omitted, a new
                ``MistralTogetherClient`` is constructed (the original behavior).
        """
        self.llm = llm if llm is not None else MistralTogetherClient()

    def classify(self, text: str) -> Tuple["Intent", Dict]:
        """Classify a single message into one canonical intent.

        Args:
            text: The raw incoming message.

        Returns:
            ``(label, result)`` where ``label`` is always a member of
            ``INTENT_SET`` and ``result`` is a dict with keys ``label``,
            ``confidence`` (float clamped to [0, 1]; 0.5 when the model gave
            no usable number), ``reason`` (str) and ``raw`` (the unparsed
            LLM reply, kept for debugging).
        """
        user_prompt = USER_INSTRUCTION_TEMPLATE.format(text=text.strip())
        msgs = build_messages(user_prompt, SYSTEM_PROMPT)

        # temperature=0.0 requests greedy decoding so repeated calls on the
        # same message should agree.
        raw, _usage = self.llm.chat(msgs, temperature=0.0, max_tokens=200)

        payload = self._safe_parse_json(raw)  # always a dict (possibly empty)
        label = self._normalize_label(payload.get("label"))
        confidence = self._coerce_confidence(payload.get("confidence"))
        reason = payload.get("reason")
        reason = "" if reason is None else str(reason)

        # Cheap keyword guardrails for obvious cases the model punted on.
        if label == "other":
            label = self._keyword_fallback(text)

        result = {"label": label, "confidence": confidence, "reason": reason, "raw": raw}
        return label, result

    @staticmethod
    def _normalize_label(value) -> str:
        """Map the model's label to a member of INTENT_SET, else "other"."""
        if not isinstance(value, str):
            return "other"
        # Tolerate cosmetic variations such as "Greeting" or "small talk".
        label = value.strip().lower().replace("-", "_").replace(" ", "_")
        return label if label in INTENT_SET else "other"

    @staticmethod
    def _coerce_confidence(value, default: float = 0.5) -> float:
        """Return ``value`` as a float clamped to [0, 1], or ``default`` if unusable."""
        try:
            conf = float(value)
        except (TypeError, ValueError):
            # Missing (None) or non-numeric values such as "high".
            return default
        if conf != conf:  # NaN is the only float that is not equal to itself
            return default
        return min(1.0, max(0.0, conf))

    @classmethod
    def _keyword_fallback(cls, text: str) -> str:
        """Return a best-effort label for a message the model classified as "other"."""
        norm = text.lower().strip()
        if cls._UNSUBSCRIBE_RE.search(norm):
            return "unsubscribe"
        if any(keyword in norm for keyword in cls._HEADLINE_KEYWORDS):
            return "headlines_request"
        if cls._GREETING_RE.fullmatch(norm):
            return "greeting"
        return "other"

    @staticmethod
    def _safe_parse_json(s: str) -> Dict:
        """Extract a JSON object from an LLM reply; return ``{}`` if none is found.

        First tries the whole string; if that fails or yields a non-object
        (e.g. a bare string or list), scans each ``{`` and returns the first
        position from which a complete JSON object decodes. This copes with
        chatty prefixes/suffixes and with replies containing several objects.
        """
        if not isinstance(s, str):
            return {}
        try:
            parsed = json.loads(s)
        except ValueError:  # json.JSONDecodeError subclasses ValueError
            parsed = None
        if isinstance(parsed, dict):
            return parsed

        decoder = json.JSONDecoder()
        start = s.find("{")
        while start != -1:
            try:
                obj, _end = decoder.raw_decode(s, start)
            except ValueError:
                obj = None
            if isinstance(obj, dict):
                return obj
            start = s.find("{", start + 1)
        return {}