File size: 4,617 Bytes
de833bd
 
 
 
 
 
 
 
 
 
c626b14
de833bd
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
cadb68d
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
from fastapi import FastAPI, HTTPException, Request
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse
from pydantic import BaseModel
from NLP_model import chatbot
import uvicorn
import asyncio
import time
import logging
from contextlib import asynccontextmanager
import os

# Configure logging
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
)
logger = logging.getLogger(__name__)

# Chuẩn bị RAG model tại lúc khởi động
@asynccontextmanager
async def lifespan(app: FastAPI):
    # Khởi tạo retriever sẵn khi server bắt đầu
    logger.info("Initializing RAG model retriever...")
    # Sử dụng asyncio.to_thread để không block event loop
    await asyncio.to_thread(chatbot.get_chain)
    logger.info("RAG model retriever initialized successfully")
    yield
    # Dọn dẹp khi shutdown
    logger.info("Shutting down RAG model...")

app = FastAPI(
    title="Solana SuperTeam RAG API", 
    description="API cho mô hình RAG của Solana SuperTeam",
    version="1.0.0",
    lifespan=lifespan
)

# Add CORS middleware
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Request counter để theo dõi số lượng request đang xử lý
active_requests = 0
max_concurrent_requests = 5  # Giới hạn số request xử lý đồng thời
request_lock = asyncio.Lock()

class ChatRequest(BaseModel):
    query: str
    user_id: str = "default_user"

class ChatResponse(BaseModel):
    response: str
    processing_time: float = None

@app.middleware("http")
async def add_process_time_header(request: Request, call_next):
    """Middleware để đo thời gian xử lý và kiểm soát số lượng request"""
    global active_requests
    
    # Kiểm tra và tăng số request đang xử lý
    async with request_lock:
        # Nếu đã đạt giới hạn, từ chối request mới
        if active_requests >= max_concurrent_requests and request.url.path == "/chat":
            return JSONResponse(
                status_code=429,
                content={"detail": "Too many requests. Please try again later."}
            )
        active_requests += 1
    
    try:
        start_time = time.time()
        response = await call_next(request)
        process_time = time.time() - start_time
        
        # Thêm thời gian xử lý vào header
        response.headers["X-Process-Time"] = str(process_time)
        logger.info(f"Request processed in {process_time:.2f} seconds: {request.url.path}")
        return response
    finally:
        # Giảm counter khi xử lý xong
        async with request_lock:
            active_requests -= 1

@app.post("/chat", response_model=ChatResponse)
async def chat_endpoint(request: ChatRequest):
    """

    Xử lý yêu cầu chat từ người dùng

    """
    start_time = time.time()
    try:
        # Gọi hàm chat với thông tin được cung cấp
        response = await asyncio.to_thread(chatbot.chat, request.query, request.user_id)
        process_time = time.time() - start_time
        return ChatResponse(
            response=response,
            processing_time=process_time
        )
    except Exception as e:
        logger.error(f"Error processing chat request: {e}")
        raise HTTPException(status_code=500, detail=str(e))

@app.get("/health")
async def health_check():
    """

    Kiểm tra trạng thái của API

    """
    # Kiểm tra xem retriever đã được khởi tạo chưa
    retriever = chatbot.get_chain()
    if retriever:
        status = "healthy"
    else:
        status = "degraded"
    
    return {
        "status": status,
        "active_requests": active_requests,
        "cache_size": len(chatbot.response_cache)
    }

@app.post("/clear-memory/{user_id}")
async def clear_user_memory(user_id: str):
    """

    Xóa lịch sử trò chuyện của một người dùng

    """
    try:
        result = await asyncio.to_thread(chatbot.clear_memory, user_id)
        return {"status": "success", "message": result}
    except Exception as e:
        logger.error(f"Error clearing memory for user {user_id}: {e}")
        raise HTTPException(status_code=500, detail=str(e))

if __name__ == "__main__":
    import nest_asyncio
    nest_asyncio.apply()
    uvicorn.run(app, host="0.0.0.0", port=7860)