Spaces:
Sleeping
Sleeping
from fastapi import FastAPI, HTTPException, Request | |
from fastapi.middleware.cors import CORSMiddleware | |
from fastapi.responses import JSONResponse | |
from pydantic import BaseModel | |
from NLP_model import chatbot | |
import uvicorn | |
import asyncio | |
import time | |
import logging | |
from contextlib import asynccontextmanager | |
import os | |
# Configure logging | |
logging.basicConfig( | |
level=logging.INFO, | |
format='%(asctime)s - %(name)s - %(levelname)s - %(message)s', | |
) | |
logger = logging.getLogger(__name__) | |
# Chuẩn bị RAG model tại lúc khởi động | |
async def lifespan(app: FastAPI): | |
# Khởi tạo retriever sẵn khi server bắt đầu | |
logger.info("Initializing RAG model retriever...") | |
# Sử dụng asyncio.to_thread để không block event loop | |
await asyncio.to_thread(chatbot.get_chain) | |
logger.info("RAG model retriever initialized successfully") | |
yield | |
# Dọn dẹp khi shutdown | |
logger.info("Shutting down RAG model...") | |
app = FastAPI( | |
title="Solana SuperTeam RAG API", | |
description="API cho mô hình RAG của Solana SuperTeam", | |
version="1.0.0", | |
lifespan=lifespan | |
) | |
# Add CORS middleware | |
app.add_middleware( | |
CORSMiddleware, | |
allow_origins=["*"], | |
allow_credentials=True, | |
allow_methods=["*"], | |
allow_headers=["*"], | |
) | |
# Request counter để theo dõi số lượng request đang xử lý | |
active_requests = 0 | |
max_concurrent_requests = 5 # Giới hạn số request xử lý đồng thời | |
request_lock = asyncio.Lock() | |
class ChatRequest(BaseModel): | |
query: str | |
user_id: str = "default_user" | |
class ChatResponse(BaseModel): | |
response: str | |
processing_time: float = None | |
async def add_process_time_header(request: Request, call_next): | |
"""Middleware để đo thời gian xử lý và kiểm soát số lượng request""" | |
global active_requests | |
# Kiểm tra và tăng số request đang xử lý | |
async with request_lock: | |
# Nếu đã đạt giới hạn, từ chối request mới | |
if active_requests >= max_concurrent_requests and request.url.path == "/chat": | |
return JSONResponse( | |
status_code=429, | |
content={"detail": "Too many requests. Please try again later."} | |
) | |
active_requests += 1 | |
try: | |
start_time = time.time() | |
response = await call_next(request) | |
process_time = time.time() - start_time | |
# Thêm thời gian xử lý vào header | |
response.headers["X-Process-Time"] = str(process_time) | |
logger.info(f"Request processed in {process_time:.2f} seconds: {request.url.path}") | |
return response | |
finally: | |
# Giảm counter khi xử lý xong | |
async with request_lock: | |
active_requests -= 1 | |
async def chat_endpoint(request: ChatRequest): | |
""" | |
Xử lý yêu cầu chat từ người dùng | |
""" | |
start_time = time.time() | |
try: | |
# Gọi hàm chat với thông tin được cung cấp | |
response = await asyncio.to_thread(chatbot.chat, request.query, int(request.user_id)) | |
process_time = time.time() - start_time | |
return ChatResponse( | |
response=response, | |
processing_time=process_time | |
) | |
except Exception as e: | |
logger.error(f"Error processing chat request: {e}") | |
raise HTTPException(status_code=500, detail=str(e)) | |
async def health_check(): | |
""" | |
Kiểm tra trạng thái của API | |
""" | |
# Kiểm tra xem retriever đã được khởi tạo chưa | |
retriever = chatbot.get_chain() | |
if retriever: | |
status = "healthy" | |
else: | |
status = "degraded" | |
return { | |
"status": status, | |
"active_requests": active_requests, | |
"cache_size": len(chatbot.response_cache) | |
} | |
async def clear_user_memory(user_id: str): | |
""" | |
Xóa lịch sử trò chuyện của một người dùng | |
""" | |
try: | |
result = await asyncio.to_thread(chatbot.clear_memory, user_id) | |
return {"status": "success", "message": result} | |
except Exception as e: | |
logger.error(f"Error clearing memory for user {user_id}: {e}") | |
raise HTTPException(status_code=500, detail=str(e)) | |
if __name__ == "__main__": | |
import nest_asyncio | |
nest_asyncio.apply() | |
uvicorn.run(app, host="0.0.0.0", port=7860) |