flare / inference_test.py
ciyidogan's picture
Update inference_test.py
d3cb3ec verified
raw
history blame
6.9 kB
# Standard library
import os, threading, uvicorn, time, traceback, random, json, asyncio, uuid

# Third-party
from fastapi import FastAPI, Request
from fastapi.responses import HTMLResponse, JSONResponse
from transformers import AutoTokenizer, AutoModelForSequenceClassification

# Project-local
import intent_test_runner
from service_config import ServiceConfig
import intent, llm_model  # fix: 'intent' was imported twice on this line
from log import log

# Global service configuration (paths, thresholds, intent definitions);
# setup_environment() prepares env vars before any model is loaded.
s_config = ServiceConfig()
s_config.setup_environment()

# === FastAPI application and shared in-memory chat history ===
app = FastAPI()
chat_history = []
@app.get("/")
def health():
    """Liveness probe: confirm the service process is responsive."""
    payload = {"status": "ok"}
    return payload
# fix: a redundant 'import uuid' used to sit here — uuid is already imported
# at the top of the file, so the duplicate has been removed.

@app.post("/run_tests", status_code=202)
def run_tests():
    """Start the full intent test suite in a background daemon thread.

    Returns 202 immediately; progress is reported through the log, not
    through this response.
    """
    log("🚦 /run_tests çağrıldı. Testler başlatılıyor...")
    threading.Thread(target=intent_test_runner.run_all_tests, daemon=True).start()
    return {"status": "running", "message": "Test süreci başlatıldı."}
@app.get("/start", response_class=HTMLResponse)
def root():
    """Create a fresh chat session and serve the minimal HTML chat UI.

    The generated session id is embedded in the returned page; the
    client echoes it back on every /chat call via the X-Session-ID header.
    """
    # Fresh session id plus its initial, empty state.
    session_id = str(uuid.uuid4())
    fresh_session = {
        "session_id": session_id,
        "variables": {},
        "auth_tokens": {},
        "last_intent": None,
        "awaiting_variable": None,
    }

    # Lazily create the per-process session store, then register the session.
    store = getattr(app.state, "session_store", None)
    if store is None:
        store = {}
        app.state.session_store = store
    store[session_id] = fresh_session

    log(f"🌐 /start ile yeni session başlatıldı: {session_id}")

    # HTML page with the session_id baked in; {{ }} escapes literal braces
    # inside the f-string for the JavaScript object literals.
    return f"""
<html><body>
<h2>Turkcell LLM Chat</h2>
<textarea id='input' rows='4' cols='60'></textarea><br>
<button onclick='send()'>Gönder</button><br><br>
<label>Model Cevabı:</label><br>
<textarea id='output' rows='10' cols='80' readonly style='white-space: pre-wrap;'></textarea>
<script>
const sessionId = "{session_id}";
localStorage.setItem("session_id", sessionId);
async function send() {{
    const input = document.getElementById("input").value;
    const res = await fetch('/chat', {{
        method: 'POST',
        headers: {{
            'Content-Type': 'application/json',
            'X-Session-ID': sessionId
        }},
        body: JSON.stringify({{ user_input: input }})
    }});
    const data = await res.json();
    document.getElementById('output').value = data.reply || data.response || data.error || 'Hata oluştu.';
}}
</script>
</body></html>
"""
@app.post("/start_chat")
def start_chat():
    """Open a new chat session (API variant of /start, no HTML).

    Fails fast with an error payload when the LLM is not yet loaded;
    otherwise registers an empty session and returns its id.
    """
    # The LLM is loaded asynchronously at startup — refuse until ready.
    if llm_model.model is None or llm_model.tokenizer is None:
        return {"error": "Model yüklenmedi."}

    # Lazily create the per-process session store.
    if not hasattr(app.state, "session_store"):
        app.state.session_store = {}

    new_id = str(uuid.uuid4())
    app.state.session_store[new_id] = {
        "session_id": new_id,
        "variables": {},
        "auth_tokens": {},
        "last_intent": None,
        "awaiting_variable": None,
    }

    log(f"🆕 Yeni session başlatıldı: {new_id}")
    return {"session_id": new_id}
@app.post("/train_intents", status_code=202)
def train_intents(train_input: intent.TrainInput):
    """Accept intent definitions and launch training in the background.

    The definitions are indexed by name into s_config for later lookup
    during /chat; actual training runs in a daemon thread.
    """
    log("📥 POST /train_intents çağrıldı.")
    intents = train_input.intents
    # Loop variable renamed from 'intent' to avoid visually shadowing the module.
    s_config.INTENT_DEFINITIONS = {definition["name"]: definition for definition in intents}
    threading.Thread(target=lambda: intent.background_training(intents, s_config), daemon=True).start()
    return {"status": "accepted", "message": "Intent eğitimi arka planda başlatıldı."}
@app.post("/load_intent_model")
def load_intent_model():
    """Load the fine-tuned intent classifier, tokenizer and label map from disk.

    Any failure (missing files, corrupt model) is reported as a 500 with
    the exception message rather than crashing the worker.
    """
    try:
        model_dir = s_config.INTENT_MODEL_PATH
        intent.INTENT_TOKENIZER = AutoTokenizer.from_pretrained(model_dir)
        intent.INTENT_MODEL = AutoModelForSequenceClassification.from_pretrained(model_dir)
        label_file = os.path.join(model_dir, "label2id.json")
        with open(label_file) as f:
            intent.LABEL2ID = json.load(f)
    except Exception as e:
        return JSONResponse(content={"error": str(e)}, status_code=500)
    return {"status": "ok", "message": "Intent modeli yüklendi."}
@app.post("/chat")
async def chat(msg: llm_model.Message, request: Request):
    """Handle one chat turn: classify intent (when available) and reply.

    The caller identifies its session via the X-Session-ID header; an
    unknown id gets a fresh, empty session state. When an intent model is
    loaded, intent detection and LLM generation run concurrently and the
    intent result decides which answer is used.
    """
    user_input = msg.user_input.strip()
    session_id = request.headers.get("X-Session-ID", "demo-session")

    # Lazily create the per-process session store; the previous extra
    # getattr(...) re-read of the same attribute was redundant.
    if not hasattr(app.state, "session_store"):
        app.state.session_store = {}
    session_store = app.state.session_store
    session = session_store.get(session_id, {
        "session_id": session_id,
        "variables": {},
        "auth_tokens": {},
        "last_intent": None
    })

    try:
        if llm_model.model is None or llm_model.tokenizer is None:
            return {"error": "Model yüklenmedi."}
        if s_config.INTENT_MODEL:
            # Kick off intent detection and LLM generation concurrently so
            # the fallback answer is already in flight if the intent misses.
            intent_task = asyncio.create_task(intent.detect_intent(user_input))
            response_task = asyncio.create_task(llm_model.generate_response(user_input, s_config))
            # BUG FIX: this used to unpack into a local named 'intent',
            # shadowing the imported 'intent' module so that the
            # intent.execute_intent(...) call below raised AttributeError
            # on a str. Unpack into a distinct name instead.
            intent_name, intent_conf = await intent_task
            log(f"🎯 Intent: {intent_name} (conf={intent_conf:.2f})")
            if intent_conf > s_config.INTENT_CONFIDENCE_THRESHOLD and intent_name in s_config.INTENT_DEFINITIONS:
                result = intent.execute_intent(intent_name, user_input, session)
                if "reply" in result:
                    # Persist the session state mutated by the intent handler.
                    session_store[session_id] = result["session"]
                    app.state.session_store = session_store
                    return {"reply": result["reply"]}
                elif "errors" in result:
                    session_store[session_id] = result["session"]
                    app.state.session_store = session_store
                    # Surface the first validation error to the user.
                    return {"response": list(result["errors"].values())[0]}
                else:
                    return {"response": random.choice(s_config.FALLBACK_ANSWERS)}
            else:
                # Low-confidence or unknown intent: fall back to the LLM answer.
                response, response_conf = await response_task
                if response_conf is not None and response_conf < s_config.LLM_CONFIDENCE_THRESHOLD:
                    return {"response": random.choice(s_config.FALLBACK_ANSWERS)}
                return {"response": response}
        else:
            # No intent model loaded: plain LLM generation.
            response, response_conf = await llm_model.generate_response(user_input, s_config)
            if response_conf is not None and response_conf < s_config.LLM_CONFIDENCE_THRESHOLD:
                return {"response": random.choice(s_config.FALLBACK_ANSWERS)}
            return {"response": response}
    except Exception as e:
        traceback.print_exc()
        return JSONResponse(content={"error": str(e)}, status_code=500)
# Load the LLM in a background daemon thread so the web server can come up
# immediately; endpoints report "Model yüklenmedi." until it finishes.
threading.Thread(target=llm_model.setup_model, kwargs={"s_config": s_config}, daemon=True).start()
# Serve the FastAPI app on all interfaces, port 7860 (the port Hugging Face
# Spaces expects), also in a daemon thread.
threading.Thread(target=lambda: uvicorn.run(app, host="0.0.0.0", port=7860), daemon=True).start()
# Keep the main thread alive — both worker threads are daemons and would be
# killed if the main thread exited.
while True:
    time.sleep(60)