Spaces:
Sleeping
Sleeping
File size: 4,617 Bytes
de833bd c626b14 de833bd cadb68d |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 |
from fastapi import FastAPI, HTTPException, Request
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse
from pydantic import BaseModel
from NLP_model import chatbot
import uvicorn
import asyncio
import time
import logging
from contextlib import asynccontextmanager
import os
# Configure logging
logging.basicConfig(
level=logging.INFO,
format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
)
logger = logging.getLogger(__name__)
# Chuẩn bị RAG model tại lúc khởi động
@asynccontextmanager
async def lifespan(app: FastAPI):
# Khởi tạo retriever sẵn khi server bắt đầu
logger.info("Initializing RAG model retriever...")
# Sử dụng asyncio.to_thread để không block event loop
await asyncio.to_thread(chatbot.get_chain)
logger.info("RAG model retriever initialized successfully")
yield
# Dọn dẹp khi shutdown
logger.info("Shutting down RAG model...")
app = FastAPI(
title="Solana SuperTeam RAG API",
description="API cho mô hình RAG của Solana SuperTeam",
version="1.0.0",
lifespan=lifespan
)
# Add CORS middleware
app.add_middleware(
CORSMiddleware,
allow_origins=["*"],
allow_credentials=True,
allow_methods=["*"],
allow_headers=["*"],
)
# Request counter để theo dõi số lượng request đang xử lý
active_requests = 0
max_concurrent_requests = 5 # Giới hạn số request xử lý đồng thời
request_lock = asyncio.Lock()
class ChatRequest(BaseModel):
query: str
user_id: str = "default_user"
class ChatResponse(BaseModel):
response: str
processing_time: float = None
@app.middleware("http")
async def add_process_time_header(request: Request, call_next):
"""Middleware để đo thời gian xử lý và kiểm soát số lượng request"""
global active_requests
# Kiểm tra và tăng số request đang xử lý
async with request_lock:
# Nếu đã đạt giới hạn, từ chối request mới
if active_requests >= max_concurrent_requests and request.url.path == "/chat":
return JSONResponse(
status_code=429,
content={"detail": "Too many requests. Please try again later."}
)
active_requests += 1
try:
start_time = time.time()
response = await call_next(request)
process_time = time.time() - start_time
# Thêm thời gian xử lý vào header
response.headers["X-Process-Time"] = str(process_time)
logger.info(f"Request processed in {process_time:.2f} seconds: {request.url.path}")
return response
finally:
# Giảm counter khi xử lý xong
async with request_lock:
active_requests -= 1
@app.post("/chat", response_model=ChatResponse)
async def chat_endpoint(request: ChatRequest):
"""
Xử lý yêu cầu chat từ người dùng
"""
start_time = time.time()
try:
# Gọi hàm chat với thông tin được cung cấp
response = await asyncio.to_thread(chatbot.chat, request.query, request.user_id)
process_time = time.time() - start_time
return ChatResponse(
response=response,
processing_time=process_time
)
except Exception as e:
logger.error(f"Error processing chat request: {e}")
raise HTTPException(status_code=500, detail=str(e))
@app.get("/health")
async def health_check():
"""
Kiểm tra trạng thái của API
"""
# Kiểm tra xem retriever đã được khởi tạo chưa
retriever = chatbot.get_chain()
if retriever:
status = "healthy"
else:
status = "degraded"
return {
"status": status,
"active_requests": active_requests,
"cache_size": len(chatbot.response_cache)
}
@app.post("/clear-memory/{user_id}")
async def clear_user_memory(user_id: str):
"""
Xóa lịch sử trò chuyện của một người dùng
"""
try:
result = await asyncio.to_thread(chatbot.clear_memory, user_id)
return {"status": "success", "message": result}
except Exception as e:
logger.error(f"Error clearing memory for user {user_id}: {e}")
raise HTTPException(status_code=500, detail=str(e))
if __name__ == "__main__":
import nest_asyncio
nest_asyncio.apply()
uvicorn.run(app, host="0.0.0.0", port=7860) |