# flare / chat_handler.py
# ciyidogan's picture
# Update chat_handler.py
# 01da95c verified
# raw
# history blame
# 8.08 kB
"""
Flare – Chat Handler (Spark /generate format)
=============================================
• X-Session-ID header
• /generate payload:
{project_name, user_input, context, system_prompt}
"""
import re, json, uuid, httpx, commentjson
from datetime import datetime
from typing import Dict, List, Optional
from fastapi import APIRouter, HTTPException, Header
from pydantic import BaseModel
from prompt_builder import build_intent_prompt, build_parameter_prompt, log
# --------------------------------------------------------------------------- #
# CONFIG
# --------------------------------------------------------------------------- #
# Load the service configuration once, at import time. The context manager
# closes the file handle deterministically instead of leaving an open handle
# behind for the garbage collector.
with open("service_config.jsonc", encoding="utf-8") as _cfg_file:
    CFG = commentjson.load(_cfg_file)

# Index the configured projects and APIs by their "name" field.
PROJECTS = {p["name"]: p for p in CFG["projects"]}
APIS = {a["name"]: a for a in CFG["apis"]}

# Spark generation URL: configured endpoint (trailing "/" removed) + "/generate".
SPARK_URL = CFG["config"]["spark_endpoint"].rstrip("/") + "/generate"
# --------------------------------------------------------------------------- #
# SESSION
# --------------------------------------------------------------------------- #
class Session:
    """In-memory state of one chat conversation.

    Attributes set by ``__init__``:
        id:        random UUID4 string identifying the session.
        project:   the entry of ``PROJECTS`` selected by ``project_name``.
        history:   chronological list of ``{role, content}`` message dicts.
        variables: values collected so far, keyed by variable name.
        awaiting:  pending follow-up state, or ``None`` when nothing is pending.
    """

    def __init__(self, project_name: str):
        session_id = str(uuid.uuid4())
        self.id = session_id
        # Raises KeyError when project_name is not a configured project.
        self.project = PROJECTS[project_name]
        self.history: List[Dict[str, str]] = []  # {role, content}
        self.variables: Dict[str, str] = {}
        self.awaiting: Optional[Dict] = None
        log(f"🆕 Session {session_id} for {project_name}")
# Process-wide registry of live sessions, keyed by Session.id.
SESSIONS: Dict[str, Session] = dict()
# --------------------------------------------------------------------------- #
# Spark client
# --------------------------------------------------------------------------- #
async def spark_generate(session: Session,
                         system_prompt: str,
                         user_input: str) -> str:
    """Send one request to the Spark ``/generate`` endpoint and return its text.

    Args:
        session: Conversation whose project name and recent history are sent.
        system_prompt: Prompt forwarded as ``system_prompt`` in the payload.
        user_input: Text forwarded as ``user_input`` in the payload.

    Returns:
        The first truthy value among the response keys ``assistant``,
        ``model_answer`` and ``text``; an empty string when none of them
        holds a usable value.

    Raises:
        httpx.HTTPStatusError: If Spark answers with an error status
            (raised by ``raise_for_status``).
    """
    payload = {
        "project_name": session.project["name"],
        "user_input": user_input,
        "context": session.history[-10:],  # only the last 10 history entries
        "system_prompt": system_prompt,
    }
    async with httpx.AsyncClient(timeout=60) as client:
        response = await client.post(SPARK_URL, json=payload)
        response.raise_for_status()
        data = response.json()

    # Depending on the Spark instance the answer key may be "model_answer".
    # The trailing `or ""` guarantees a string even when "text" is present
    # but null, so callers can safely call str methods on the result.
    return (
        data.get("assistant")
        or data.get("model_answer")
        or data.get("text")
        or ""
    )
# --------------------------------------------------------------------------- #
# FASTAPI ROUTER
# --------------------------------------------------------------------------- #
# Router that collects every endpoint declared in this module.
router = APIRouter()


@router.get("/")
def health():
    """Health check: always answers ``{"status": "ok"}``."""
    return dict(status="ok")
class StartSessionRequest(BaseModel):
    """Request body of ``POST /start_session``."""

    # Name of the project the new session should be bound to.
    project_name: str
class ChatBody(BaseModel):
    """Request body of ``POST /chat``."""

    # Raw text typed by the user.
    user_input: str
class ChatResponse(BaseModel):
    """Response returned by both ``/start_session`` and ``/chat``."""

    # Identifier of the session the answer belongs to.
    session_id: str
    # Text shown to the user.
    answer: str
# --------------------------------------------------------------------------- #
# ENDPOINTS
# --------------------------------------------------------------------------- #
@router.post("/start_session", response_model=ChatResponse)
async def start_session(req: StartSessionRequest):
    """Create a session for the requested project and register it.

    Responds with HTTP 404 when the project name is not configured;
    otherwise returns the new session id together with a greeting.
    """
    project_name = req.project_name
    if project_name not in PROJECTS:
        raise HTTPException(404, "Unknown project")

    session = Session(project_name)
    SESSIONS[session.id] = session
    return ChatResponse(session_id=session.id, answer="Nasıl yardımcı olabilirim?")
@router.post("/chat", response_model=ChatResponse)
async def chat(body: ChatBody,
               x_session_id: str = Header(...)):
    """Handle one user turn of the session named by the ``X-Session-ID`` header.

    The stripped user text is appended to the session history first. Then:
      * if a follow-up is pending, the turn is delegated to ``_followup``;
      * otherwise the project's general prompt is sent to Spark. A reply
        that does not start with ``#DETECTED_INTENT:`` is returned as-is;
        a detected intent is looked up and handled via ``_handle_intent``,
        and an unknown intent yields a fixed apology.
    The chosen answer is appended to the history and returned.

    Responds with HTTP 404 when the session id is unknown.
    """
    session = SESSIONS.get(x_session_id)
    if session is None:
        raise HTTPException(404, "Invalid session")

    text = body.user_input.strip()
    session.history.append({"role": "user", "content": text})

    if session.awaiting:
        # In follow-up mode: a previous turn is still collecting parameters.
        reply = await _followup(session, text)
    else:
        # ---------------- intent detection ----------------
        general_prompt = session.project["versions"][0]["general_prompt"]
        detected = await spark_generate(session, general_prompt, text)
        if not detected.startswith("#DETECTED_INTENT:"):
            # Plain conversational answer: pass it straight through.
            reply = detected
        else:
            intent_name = detected.split(":", 1)[1].strip()
            intent_cfg = _find_intent(session.project, intent_name)
            if intent_cfg:
                reply = await _handle_intent(session, intent_cfg, text)
            else:
                reply = "Üzgünüm, anlayamadım."

    session.history.append({"role": "assistant", "content": reply})
    return ChatResponse(session_id=session.id, answer=reply)
# --------------------------------------------------------------------------- #
# Helper functions
# --------------------------------------------------------------------------- #
def _find_intent(project, name_):
    """Return the first intent of the project's first version named *name_*.

    Returns ``None`` when no intent carries that name.
    """
    for intent in project["versions"][0]["intents"]:
        if intent["name"] == name_:
            return intent
    return None
def _missing(s, intent_cfg):
    """List the names of intent parameters that have no collected value yet.

    A parameter counts as collected when its ``variable_name`` is a key of
    ``s.variables``. Names are returned in configuration order.
    """
    names = []
    for param in intent_cfg["parameters"]:
        if param["variable_name"] in s.variables:
            continue
        names.append(param["name"])
    return names
async def _handle_intent(s, intent_cfg, user_msg):
    """Process a freshly detected intent and return the reply text.

    If parameters are missing, Spark is asked to extract them from the
    message. A validation or parse problem reported by ``_process_params``
    is returned immediately. When parameters are still missing afterwards,
    the session enters follow-up mode and the caption of the first missing
    parameter is asked for; otherwise the intent's API is called.
    """
    pending = _missing(s, intent_cfg)
    if pending:
        prompt = build_parameter_prompt(intent_cfg, pending, user_msg, s.history)
        raw = await spark_generate(s, prompt, user_msg)
        if raw.startswith("#PARAMETERS:"):
            problem = _process_params(s, intent_cfg, raw)
            if problem:
                return problem
            # Re-evaluate after the extracted values were stored.
            pending = _missing(s, intent_cfg)

    if not pending:
        s.awaiting = None
        return await _call_api(s, intent_cfg)

    # Remember what is still needed so the next turn is treated as a follow-up.
    s.awaiting = {"intent": intent_cfg, "missing": pending}
    first_param = next(
        param for param in intent_cfg["parameters"] if param["name"] == pending[0]
    )
    caption = first_param["caption"]
    return f"{caption} nedir?"
async def _followup(s, user_msg):
    """Continue collecting parameters for the intent stored in ``s.awaiting``.

    Spark is asked to extract the missing parameters from the new message.
    A reply without the ``#PARAMETERS:`` prefix yields a fixed apology and a
    problem reported by ``_process_params`` is returned as-is. If parameters
    remain missing, the pending list is updated and the next caption is asked
    for; otherwise follow-up mode ends and the intent's API is called.
    """
    state = s.awaiting
    intent_cfg = state["intent"]
    prompt = build_parameter_prompt(intent_cfg, state["missing"], user_msg, s.history)
    raw = await spark_generate(s, prompt, user_msg)
    if not raw.startswith("#PARAMETERS:"):
        return "Üzgünüm, anlayamadım."

    problem = _process_params(s, intent_cfg, raw)
    if problem:
        return problem

    still_missing = _missing(s, intent_cfg)
    if not still_missing:
        s.awaiting = None
        return await _call_api(s, intent_cfg)

    s.awaiting["missing"] = still_missing
    next_param = next(
        param for param in intent_cfg["parameters"] if param["name"] == still_missing[0]
    )
    caption = next_param["caption"]
    return f"{caption} nedir?"
def _process_params(s, intent_cfg, p_out):
    """Parse a ``#PARAMETERS:`` model reply and store the valid values.

    Args:
        s: Session whose ``variables`` dict receives the extracted values.
        intent_cfg: Intent configuration; its ``parameters`` list supplies
            the name, ``variable_name`` and validation settings.
        p_out: Raw model output starting with ``#PARAMETERS:`` followed by
            JSON of the form ``{"extracted": [{"name": ..., "value": ...}]}``.

    Returns:
        ``None`` on success, otherwise a user-facing error message: a fixed
        text when the payload cannot be decoded, or the parameter's
        ``invalid_prompt`` (default ``"Geçersiz değer."``) when a value fails
        validation. Values stored before a failing entry are kept.
    """
    try:
        data = json.loads(p_out[len("#PARAMETERS:"):])
    except json.JSONDecodeError:
        return "Parametreleri çözemedim."

    # The payload is produced by the language model, so its shape is not
    # guaranteed: anything other than a JSON object is treated as undecodable.
    if not isinstance(data, dict):
        return "Parametreleri çözemedim."

    for pair in data.get("extracted") or []:
        if not isinstance(pair, dict) or "value" not in pair:
            continue  # malformed entry: ignore instead of crashing
        p_cfg = next(
            (p for p in intent_cfg["parameters"] if p["name"] == pair.get("name")),
            None,
        )
        if p_cfg is None:
            # The model named a parameter this intent does not define.
            # A bare next() would raise StopIteration here; skip it instead.
            continue
        if not _valid(p_cfg, pair["value"]):
            return p_cfg.get("invalid_prompt", "Geçersiz değer.")
        s.variables[p_cfg["variable_name"]] = pair["value"]
    return None
def _valid(p_cfg, val):
    """Check *val* against the parameter's optional ``validation_regex``.

    Args:
        p_cfg: Parameter configuration dict.
        val: Extracted value. It originates from model-generated JSON, so it
            may be a non-string (number, boolean) or ``None``.

    Returns:
        ``True`` when no regex is configured or the value matches it at the
        start of the string (``re.match`` semantics); ``False`` otherwise.
        ``None`` never satisfies a configured regex.
    """
    rx = p_cfg.get("validation_regex")
    if not rx:
        return True
    if val is None:
        return False
    # re.match raises TypeError for non-string input, so coerce first.
    text = val if isinstance(val, str) else str(val)
    # NOTE(review): re.match anchors only the start; patterns that must cover
    # the whole value need an explicit "$" in the configuration.
    return re.match(rx, text) is not None
async def _call_api(s, intent_cfg):
    """Call the API bound to the intent and have Spark summarise the result.

    Args:
        s: Session supplying the collected ``variables``.
        intent_cfg: Intent configuration; ``action`` names the entry of
            ``APIS`` to call and ``fallback_error_prompt`` is the reply used
            when the call fails.

    Returns:
        Spark's answer for the API's ``response_prompt`` with
        ``{{api_response}}`` replaced by the JSON response, or the intent's
        ``fallback_error_prompt`` when the request fails.

    Raises:
        KeyError: If ``intent_cfg["action"]`` is not a configured API.
    """
    api = APIS[intent_cfg["action"]]

    # TODO(review): hard-coded placeholder token; replace with a real
    # credential source before this talks to a protected API.
    token = "testtoken"
    headers = {k: v.replace("{{token}}", token) for k, v in api["headers"].items()}

    # JSON round-trip makes a deep copy so the shared template is not mutated.
    body = json.loads(json.dumps(api["body_template"]))
    for key, value in body.items():
        # Only top-level values that are exactly "{{name}}" are substituted;
        # unknown variables become an empty string.
        if isinstance(value, str) and value.startswith("{{") and value.endswith("}}"):
            body[key] = s.variables.get(value[2:-2], "")

    try:
        async with httpx.AsyncClient(timeout=api["timeout_seconds"]) as client:
            response = await client.request(
                api["method"], api["url"], headers=headers, json=body
            )
            response.raise_for_status()
            api_json = response.json()
    except Exception as exc:
        # Deliberate best-effort: the user gets the fallback text, but the
        # failure is logged instead of being swallowed silently.
        log(f"API call for action {intent_cfg['action']} failed: {exc!r}")
        return intent_cfg["fallback_error_prompt"]

    summary_prompt = api["response_prompt"].replace(
        "{{api_response}}", json.dumps(api_json, ensure_ascii=False)
    )
    return await spark_generate(s, summary_prompt, "")